Source code for sporco.mpiutil

# -*- coding: utf-8 -*-
# Copyright (C) 2017-2021 by Cristina Garcia-Cardona <cgarciac@lanl.gov>
#                            Brendt Wohlberg <brendt@ieee.org>
# All rights reserved. BSD 3-clause License.
# This file is part of the SPORCO package. Details of the copyright
# and user license can be found in the 'LICENSE.txt' file distributed
# with the package.

"""Utility functions that make use of MPI for parallel computing."""

from __future__ import absolute_import, division, print_function
from builtins import range

from mpi4py import MPI
import itertools
import numpy as np


# Module authors, one per line.
__author__ = '\n'.join((
    'Cristina Garcia-Cardona <cgarciac@lanl.gov>',
    'Brendt Wohlberg <brendt@ieee.org>',
))


# Names exported by a wildcard import of this module.
__all__ = ['grid_search']


def _get_rank_limits(comm, arrlen):
    """Determine the chunk of the grid that has to be computed per
    process. The grid has been 'flattened' and has arrlen length. The
    chunk assigned to each process depends on its rank in the MPI
    communicator.

    The points are split into ``size`` contiguous chunks whose lengths
    differ by at most one: every process receives ``arrlen // size``
    points, and the processes with rank less than ``arrlen % size``
    receive one additional point. The limits are computed in closed
    form from the rank and the communicator size, so no inter-process
    communication is performed.

    Parameters
    ----------
    comm : MPI communicator object
      Describes topology of network: number of processes, rank
    arrlen : int
      Number of points in grid search.

    Returns
    -------
    begin : int
      Index, with respect to 'flattened' grid, where the chunk
      for this process starts (inclusive).
    end : int
      Index, with respect to 'flattened' grid, where the chunk
      for this process ends (exclusive), i.e. the chunk for this
      process is ``[begin, end)``.
    """

    rank = comm.Get_rank()  # Id of this process
    size = comm.Get_size()  # Total number of processes in communicator
    # Use exact integer division: computing int(arrlen / size) goes
    # through a float and loses precision once arrlen exceeds 2**53.
    # int() also normalises e.g. numpy integer inputs to Python ints.
    base, extra = divmod(int(arrlen), size)
    # The first `extra` ranks absorb the remainder, one point each
    ranklen = base + 1 if rank < extra else base
    # Every lower-ranked process holds `base` points and the first
    # `extra` of them hold one more, so the start offset (the exclusive
    # prefix sum of the chunk lengths) follows directly from the rank
    begin = rank * base + min(rank, extra)
    end = begin + ranklen

    return (begin, end)