"""
==================================
The :mod:`mpi_array.update` Module
==================================

Helper classes for calculating sub-extent intersections in order
to perform remote array element copying/updates.

Classes and Functions
=====================

.. autosummary::
   :toctree: generated/

   ExtentAndRegion - Container for locale extent and update region (sub-extent).
   MpiExtentAndRegion - Provides MPI datatype creation.
   ExtentUpdate - Base class for describing a sub-extent update.
   PairExtentUpdate - Describes sub-extent source and sub-extent destination.
   MpiPairExtentUpdate - Extends :obj:`PairExtentUpdate` with MPI data type factory.
   MpiPairExtentUpdateDifferentDtypes - Over-rides :meth:`MpiPairExtentUpdate.do_get`.
   HaloSingleExtentUpdate - Describes sub-extent for halo region update.
   MpiHaloSingleExtentUpdate - Extends :obj:`HaloSingleExtentUpdate` with MPI data type factory.
   UpdatesForRedistribute - Calculate sequence of overlapping extents between two distributions.
   RmaUpdateExecutor - Execute updates using one-sided RMA fetch.
"""
from __future__ import absolute_import
import mpi4py.MPI as _mpi
import collections as _collections
import copy as _copy
import numpy as _np
from .license import license as _license, copyright as _copyright, version as _version
from .indexing import HaloIndexingExtent
from .indexing import calc_intersection_split as _calc_intersection_split
from . import types as _types
__author__ = "Shane J. Latham"
__license__ = _license()
__copyright__ = _copyright()
__version__ = _version()
class ExtentAndRegion(object):
    """
    Container for :obj:`mpi_array.distribution.LocaleExtent` and
    an update region (:obj:`mpi_array.indexing.IndexingExtent`).

    Declared as a new-style class (derives from :obj:`object`) so that
    the :attr:`region_extent` property setter is honoured on Python 2 as
    well as Python 3.
    """

    def __init__(self, locale_extent, region_extent=None):
        """
        Initialise.

        :type locale_extent: :obj:`mpi_array.distribution.LocaleExtent`
        :param locale_extent: The locale extent, exposed read-only
           via :attr:`locale_extent`.
        :type region_extent: :obj:`mpi_array.indexing.IndexingExtent`
        :param region_extent: The update region (sub-extent), may
           be :samp:`None` and assigned later via :attr:`region_extent`.
        """
        self._locale_extent = locale_extent
        self._region_extent = region_extent

    @property
    def locale_extent(self):
        """
        The locale extent passed to the constructor (read-only).
        """
        return self._locale_extent

    @property
    def region_extent(self):
        """
        The update region (sub-extent), can be re-assigned.
        """
        return self._region_extent

    @region_extent.setter
    def region_extent(self, region):
        self._region_extent = region
class ExtentUpdate(object):
    """
    Base class pairing the destination and source indexing info
    used when updating a sub-extent region.
    """

    def __init__(self, dst_extent_info, src_extent_info):
        """
        Store the destination and source extent info.

        :type dst_extent_info: :obj:`ExtentAndRegion`
        :param dst_extent_info: Info containing locale extent which is to receive region update.
        :type src_extent_info: :obj:`ExtentAndRegion`
        :param src_extent_info: Info containing locale extent from which the region update is read.
        """
        object.__init__(self)
        self._dst, self._src = dst_extent_info, src_extent_info

    @property
    def dst_extent(self):
        """
        The locale extent (:samp:`locale_extent` attribute of the destination info)
        which is to receive the sub-array update.
        """
        dst_info = self._dst
        return dst_info.locale_extent

    @property
    def src_extent(self):
        """
        The locale extent (:samp:`locale_extent` attribute of the source info)
        from which the sub-array update is read.
        """
        src_info = self._src
        return src_info.locale_extent
def pair_extent_update_copyto(peu, dst_array, src_array, casting):
    """
    Copy the :samp:`{peu}.src_update_extent` region of :samp:`{src_array}`
    into the :samp:`{peu}.dst_update_extent` region of :samp:`{dst_array}`.

    The globale update extents are first converted to locale (halo inclusive)
    extents via :samp:`globale_to_locale_extent_h` and then to slices.

    :type peu: :obj:`PairExtentUpdate`
    :param peu: Object describing extent of :samp:`dst_array` and :samp:`src_array`
       and the source and destination regions.
    :type dst_array: :obj:`numpy.ndarray`
    :param dst_array: Destination for copy.
    :type src_array: :obj:`numpy.ndarray`
    :param src_array: Source for copy.
    :type casting: :obj:`str`
    :param casting: Casting regime, forwarded as the :samp:`casting`
       argument of :func:`numpy.copyto`.
    """
    src_locale_region = peu.src_extent.globale_to_locale_extent_h(peu.src_update_extent)
    src_index = src_locale_region.to_slice()
    dst_locale_region = peu.dst_extent.globale_to_locale_extent_h(peu.dst_update_extent)
    dst_index = dst_locale_region.to_slice()

    dst_view = dst_array[dst_index]
    src_view = src_array[src_index]
    _np.copyto(dst_view, src_view, casting=casting)
class PairExtentUpdate(ExtentUpdate):
    """
    Source and destination indexing info for updating a sub-extent region,
    where source and destination each carry their own update region.
    """

    def __init__(self, dst_extent, src_extent, dst_update_extent, src_update_extent):
        """
        Wrap each locale extent with its update region and initialise the base class.

        :param dst_extent: Locale extent which receives the update.
        :param src_extent: Locale extent from which the update is read.
        :param dst_update_extent: Region of the destination to be updated.
        :param src_update_extent: Region of the source from which the update is read.
        """
        dst_info = ExtentAndRegion(dst_extent, dst_update_extent)
        src_info = ExtentAndRegion(src_extent, src_update_extent)
        ExtentUpdate.__init__(self, dst_info, src_info)

    def copyto(self, dst_array, src_array, casting):
        """
        Copies the :attr:`src_update_extent` region from :samp:`{src_array}`
        to the :attr:`dst_update_extent` region of :samp:`{dst_array}`,
        delegating to :func:`pair_extent_update_copyto`.
        """
        pair_extent_update_copyto(
            self,
            dst_array,
            src_array,
            casting
        )

    @property
    def dst_update_extent(self):
        """
        The locale sub-extent (:obj:`IndexingExtent`) to be updated.
        """
        dst_info = self._dst
        return dst_info.region_extent

    @property
    def src_update_extent(self):
        """
        The locale sub-extent (:obj:`IndexingExtent`) from which the update is read.
        """
        src_info = self._src
        return src_info.region_extent
class MpiExtentAndRegion(ExtentAndRegion):
    """
    Extends :obj:`ExtentAndRegion` with creation (and caching)
    of a :obj:`mpi4py.MPI.Datatype` sub-array type which describes
    the region within the (halo inclusive) locale extent.
    """

    def __init__(
        self,
        locale_extent,
        region_extent,
        dtype=None,
        order=None,
        mpi_data_type=None,
        mpi_order=None
    ):
        """
        Initialise.

        :param locale_extent: Locale extent, see :obj:`ExtentAndRegion`.
        :param region_extent: Update region, see :obj:`ExtentAndRegion`.
        :param dtype: Element type associated with :samp:`{mpi_data_type}` (if any).
        :param order: Memory layout associated with :samp:`{mpi_data_type}` (if any).
        :param mpi_data_type: Pre-existing MPI sub-array data type (if any).
        :param mpi_order: MPI order constant associated with :samp:`{mpi_data_type}` (if any).
        """
        ExtentAndRegion.__init__(self, locale_extent, region_extent)
        self._dtype = dtype
        self._parent_mpi_data_type = None
        self._order = order
        self._mpi_data_type = mpi_data_type
        self._mpi_order = mpi_order

    def create_data_type(self, dtype, order="C"):
        """
        Creates and commits a new MPI sub-array data type for the region.

        :type dtype: :obj:`numpy.dtype`
        :param dtype: Array element type.
        :type order: :obj:`str`
        :param order: :samp:`"F"` (case insensitive) selects Fortran ordering,
           anything else selects C ordering.
        :rtype: :obj:`tuple`
        :return: Triple :samp:`(mpi_data_type, mpi_order, parent_mpi_data_type)`.
        """
        mpi_order = _mpi.ORDER_C
        # Compare case-insensitively: initialise_mpi_data_type passes a
        # lower-cased order string, which previously never matched "F"
        # and silently fell back to C ordering.
        if str(order).upper() == "F":
            mpi_order = _mpi.ORDER_FORTRAN
        parent_mpi_data_type = _types.to_datatype(dtype)
        mpi_data_type = \
            parent_mpi_data_type.Create_subarray(
                self.locale_extent.shape_h,
                self.region_extent.shape,
                self.locale_extent.globale_to_locale_h(self.region_extent.start),
                order=mpi_order
            )
        mpi_data_type.Commit()

        return mpi_data_type, mpi_order, parent_mpi_data_type

    def initialise_mpi_data_type(self, dtype, order):
        """
        (Re)creates the MPI data type, but only when no type has been
        created yet or when :samp:`{dtype}` or :samp:`{order}` differ from
        the values used for the current type.

        :type dtype: :obj:`numpy.dtype`
        :param dtype: Array element type.
        :type order: :obj:`str`
        :param order: Array memory layout, :samp:`"C"` or :samp:`"F"` (any case).
        """
        dtype = _np.dtype(dtype)
        order = order.lower()
        if (
            (self._dtype is None)
            or
            (self._dtype != dtype)
            or
            (self._order != order)
        ):
            # NOTE(review): a previously committed data type is not freed here,
            # confirm whether ownership allows calling Free() on replacement.
            self._mpi_data_type, self._mpi_order, self._parent_mpi_data_type = \
                self.create_data_type(dtype, order)
            self._dtype = dtype
            self._order = order

    @property
    def mpi_data_type(self):
        """
        The current MPI sub-array data type (may be :samp:`None`).
        """
        return self._mpi_data_type
class MpiPairExtentUpdate(ExtentUpdate):
    """
    Source and destination indexing info for updating a sub-extent region.
    Extends :obj:`ExtentUpdate` with API to create :obj:`mpi4py.MPI.Datatype`
    instances (using :meth:`mpi4py.MPI.Datatype.Create_subarray`) for convenient
    transfer of sub-array data.
    """

    def __init__(self, dst_extent, src_extent, dst_update_extent, src_update_extent):
        """
        Initialise, casting defaults to :samp:`"same_kind"`.

        :param dst_extent: Locale extent which receives the update.
        :param src_extent: Locale extent from which the update is read.
        :param dst_update_extent: Region of the destination to be updated.
        :param src_update_extent: Region of the source from which the update is read.
        """
        self._casting = "same_kind"
        dst_info = MpiExtentAndRegion(dst_extent, dst_update_extent)
        src_info = MpiExtentAndRegion(src_extent, src_update_extent)
        ExtentUpdate.__init__(self, dst_info, src_info)

        # One row: eight destination columns followed by eight source columns.
        self._str_format = (
            "%8s, %20s, %20s, %20s, %20s, %20s, %20s, %16s, "
            +
            "%8s, %20s, %20s, %20s, %20s, %20s, %20s, %16s"
        )
        column_titles = (
            "dst rank",
            "dst ext glb start",
            "dst ext glb stop ",
            "dst updt loc start",
            "dst updt loc stop ",
            "dst updt glb start",
            "dst updt glb stop ",
            "dst MPI datatype",
            "src rank",
            "src ext glb start",
            "src ext glb stop ",
            "src updt loc start",
            "src updt loc stop ",
            "src updt glb start",
            "src updt glb stop ",
            "src MPI datatype",
        )
        self._header_str = self._str_format % column_titles

    def initialise_data_types(self, dst_dtype, src_dtype, dst_order, src_order):
        """
        Assigns new instances of `mpi4py.MPI.Datatype` for
        the :attr:`dst_data_type` and :attr:`src_data_type`
        attributes. New instances are only created when the
        dtype or order do not match the existing instances.

        :type dst_dtype: :obj:`numpy.dtype`
        :param dst_dtype: The array element type of the array which is to receive data.
        :type src_dtype: :obj:`numpy.dtype`
        :param src_dtype: The array element type of the array from which data is copied.
        :type dst_order: :obj:`str`
        :param dst_order: Destination array memory layout, :samp:`"C"` for C array,
            or :samp:`"F"` for fortran array.
        :type src_order: :obj:`str`
        :param src_order: Source array memory layout, :samp:`"C"` for C array,
            or :samp:`"F"` for fortran array.
        """
        dst_info, src_info = self._dst, self._src
        dst_info.initialise_mpi_data_type(dst_dtype, dst_order)
        src_info.initialise_mpi_data_type(src_dtype, src_order)

    def copyto(self, dst_array, src_array, casting):
        """
        Copies the :attr:`src_update_extent` region from :samp:`{src_array}`
        to the :attr:`dst_update_extent` region of :samp:`{dst_array}`,
        delegating to :func:`pair_extent_update_copyto`.
        """
        pair_extent_update_copyto(
            self,
            dst_array,
            src_array,
            casting
        )

    @property
    def dst_update_extent(self):
        """
        The locale sub-extent (:obj:`IndexingExtent`) to be updated.
        """
        dst_info = self._dst
        return dst_info.region_extent

    @property
    def src_update_extent(self):
        """
        The locale sub-extent (:obj:`IndexingExtent`) from which the update is read.
        """
        src_info = self._src
        return src_info.region_extent

    @property
    def dst_dtype(self):
        """
        A :obj:`numpy.dtype` object indicating the element type
        of the destination array.
        """
        dst_info = self._dst
        return dst_info._dtype

    @property
    def src_dtype(self):
        """
        A :obj:`numpy.dtype` object indicating the element type
        of the source array.
        """
        src_info = self._src
        return src_info._dtype

    @property
    def dst_data_type(self):
        """
        A :obj:`mpi4py.MPI.Datatype` object created
        using :meth:`mpi4py.MPI.Datatype.Create_subarray` which
        defines the sub-array of elements which are to
        receive update values.
        """
        dst_info = self._dst
        return dst_info.mpi_data_type

    @property
    def src_data_type(self):
        """
        A :obj:`mpi4py.MPI.Datatype` object created
        using :meth:`mpi4py.MPI.Datatype.Create_subarray` which
        defines the sub-array of elements from which
        the update values are read.
        """
        src_info = self._src
        return src_info.mpi_data_type

    @property
    def casting(self):
        """
        A :obj:`str` indicating the casting allowed between different :obj:`numpy.dtype` elements.
        See the :samp:`casting` parameter for the :func:`numpy.copyto` function.
        """
        return self._casting

    @casting.setter
    def casting(self, casting):
        self._casting = casting

    def do_get(self, mpi_win, target_src_rank, origin_dst_buffer):
        """
        Calls the :meth:`mpi4py.MPI.Win.Get` method of :samp:`mpi_win`
        to perform the RMA data-transfer.

        :type mpi_win: :obj:`mpi4py.MPI.Win`
        :param mpi_win: Window used to retrieve update region for array.
        :type target_src_rank: :obj:`int`
        :param target_src_rank: The rank of the target process in :samp:`mpi_win.group.rank`.
        :type origin_dst_buffer: :obj:`memoryview`
        :param origin_dst_buffer: The destination memory for the update, size of buffer
            should correspond to the size of the :attr:`dst_extent`.
        """
        origin_spec = [origin_dst_buffer, 1, self.dst_data_type]
        target_spec = [0, 1, self.src_data_type]
        mpi_win.Get(origin_spec, target_src_rank, target_spec)

    def do_rget(self, mpi_win, target_src_rank, origin_dst_buffer):
        """
        Calls the :meth:`mpi4py.MPI.Win.Rget` method of :samp:`mpi_win`
        to perform the RMA data-transfer.

        :type mpi_win: :obj:`mpi4py.MPI.Win`
        :param mpi_win: Window used to retrieve update region for array.
        :type target_src_rank: :obj:`int`
        :param target_src_rank: The rank of the target process in :samp:`mpi_win.group.rank`.
        :type origin_dst_buffer: :obj:`memoryview`
        :param origin_dst_buffer: The destination memory for the update, size of buffer
            should correspond to the size of the :attr:`dst_extent`.
        :return: The object returned by :samp:`mpi_win.Rget`.
        """
        origin_spec = [origin_dst_buffer, 1, self.dst_data_type]
        target_spec = [0, 1, self.src_data_type]
        return mpi_win.Rget(origin_spec, target_src_rank, target_spec)

    def conclude(self):
        """
        Hook called after the transfer, nothing to do here because
        data is received directly into the destination buffer.
        """
        pass

    def __str__(self):
        """
        Stringify, one row laid out according to the column header format.
        """
        dst_info, src_info = self._dst, self._src
        # The MPI type name is only available once the data types were initialised.
        dst_mpi_dtype = (
            dst_info._parent_mpi_data_type.Get_name() if dst_info._dtype is not None else None
        )
        src_mpi_dtype = (
            src_info._parent_mpi_data_type.Get_name() if src_info._dtype is not None else None
        )
        row = (
            self.dst_extent.inter_locale_rank,
            self.dst_extent.start_h,
            self.dst_extent.stop_h,
            self.dst_extent.globale_to_locale_h(self.dst_update_extent.start),
            self.dst_extent.globale_to_locale_h(self.dst_update_extent.stop),
            self.dst_update_extent.start,
            self.dst_update_extent.stop,
            dst_mpi_dtype,
            self.src_extent.inter_locale_rank,
            self.src_extent.start_h,
            self.src_extent.stop_h,
            self.src_extent.globale_to_locale_h(self.src_update_extent.start),
            self.src_extent.globale_to_locale_h(self.src_update_extent.stop),
            self.src_update_extent.start,
            self.src_update_extent.stop,
            src_mpi_dtype
        )
        return self._str_format % row
class MpiPairExtentUpdateDifferentDtypes(MpiPairExtentUpdate):
    """
    Over-rides :meth:`MpiPairExtentUpdate.do_get` to buffer-copy and
    subsequent casting when source and destination arrays have different :obj:`numpy.dtype`.
    """

    def __init__(self, dst_extent, src_extent, dst_update_extent, src_update_extent):
        """
        Initialise, the staging buffer is only allocated when a
        transfer is started (:meth:`do_get` or :meth:`do_rget`).
        """
        MpiPairExtentUpdate.__init__(
            self,
            dst_extent,
            src_extent,
            dst_update_extent,
            src_update_extent
        )
        self._buffer = None
        self._dst_buffer = None

    def _start_staged_transfer(self, origin_dst_buffer):
        """
        Allocates the staging buffer (source dtype, shape of the source region),
        remembers the final destination and returns the origin buffer
        specification for the RMA call.
        """
        self._buffer = _np.empty(shape=self._src.region_extent.shape, dtype=self._src._dtype)
        self._dst_buffer = origin_dst_buffer
        # ndarray.size is the element count as a python int, this replaces
        # numpy.product which no longer exists in NumPy 2.
        return [self._buffer, self._buffer.size, self._src._parent_mpi_data_type]

    def do_get(self, mpi_win, target_src_rank, origin_dst_buffer):
        """
        Calls the :meth:`mpi4py.MPI.Win.Get` method of :samp:`mpi_win`
        to perform the RMA data-transfer. Uses a locally allocated buffer
        to receive the data, the conversion from :attr:`src_dtype`
        to :attr:`dst_dtype` (using :func:`numpy.copyto`) is performed
        by :meth:`conclude`.

        :type mpi_win: :obj:`mpi4py.MPI.Win`
        :param mpi_win: Window used to retrieve update region for array.
        :type target_src_rank: :obj:`int`
        :param target_src_rank: The rank of the target process in :samp:`mpi_win.group.rank`.
        :type origin_dst_buffer: :obj:`numpy.ndarray`
        :param origin_dst_buffer: The destination memory for the update, size of buffer
            should correspond to the size of the :attr:`dst_extent`. It is indexed
            with slices in :meth:`conclude`.
        """
        origin_spec = self._start_staged_transfer(origin_dst_buffer)
        mpi_win.Get(
            origin_spec,
            target_src_rank,
            [0, 1, self.src_data_type]
        )

    def do_rget(self, mpi_win, target_src_rank, origin_dst_buffer):
        """
        Same as :meth:`do_get` but uses :meth:`mpi4py.MPI.Win.Rget`.

        :return: The object returned by :samp:`mpi_win.Rget`. :meth:`conclude` is
            presumably only valid once that request has completed (confirm with caller).
        """
        origin_spec = self._start_staged_transfer(origin_dst_buffer)
        return mpi_win.Rget(
            origin_spec,
            target_src_rank,
            [0, 1, self.src_data_type]
        )

    def conclude(self):
        """
        Copies (with casting :attr:`casting`) the staging buffer into the
        destination update region and releases both buffer references.
        """
        origin_dst_buffer_slice = \
            self._dst.locale_extent.globale_to_locale_extent_h(self._dst.region_extent).to_slice()
        _np.copyto(
            self._dst_buffer[origin_dst_buffer_slice],
            self._buffer,
            casting=self.casting
        )
        self._buffer = None
        self._dst_buffer = None
class HaloSingleExtentUpdate(ExtentUpdate):
    """
    Source and destination indexing info for updating a halo portion,
    the same update region is used for both source and destination.
    """

    def __init__(self, dst_extent, src_extent, update_extent):
        """
        Initialise.

        :param dst_extent: Locale extent which receives the halo update.
        :param src_extent: Locale extent from which the halo update is read.
        :param update_extent: The halo region shared by destination and source.
        """
        dst_info = ExtentAndRegion(dst_extent, update_extent)
        src_info = ExtentAndRegion(src_extent, update_extent)
        ExtentUpdate.__init__(self, dst_info, src_info)

    @property
    def update_extent(self):
        """
        The :obj:`IndexingExtent` indicating the halo sub-array which is to be updated.
        """
        src_info = self._src
        return src_info.region_extent
class MpiHaloSingleExtentUpdate(ExtentUpdate):
    """
    Source and destination indexing info for updating the whole of a halo portion.
    Extends :obj:`ExtentUpdate` with API to create :obj:`mpi4py.MPI.Datatype`
    instances (using :meth:`mpi4py.MPI.Datatype.Create_subarray`) for convenient
    transfer of sub-array data.
    """

    def __init__(self, dst_extent, src_extent, update_extent):
        """
        Initialise.

        :param dst_extent: Locale extent which receives the halo update.
        :param src_extent: Locale extent from which the halo update is read.
        :param update_extent: The halo region shared by destination and source.
        """
        dst_info = MpiExtentAndRegion(dst_extent, update_extent)
        src_info = MpiExtentAndRegion(src_extent, update_extent)
        ExtentUpdate.__init__(self, dst_info, src_info)

        self._str_format = \
            "%8s, %20s, %20s, %20s, %20s, %8s, %20s, %20s, %20s, %20s, %20s, %20s, %16s"
        column_titles = (
            "dst rank",
            "dst ext glb start",
            "dst ext glb stop ",
            "dst halo loc start",
            "dst halo loc stop ",
            "src rank",
            "src ext glb start",
            "src ext glb stop ",
            "src halo loc start",
            "src halo loc stop ",
            "    halo glb start",
            "    halo glb stop ",
            "MPI datatype",
        )
        self._header_str = self._str_format % column_titles

    def initialise_data_types(self, dtype, order):
        """
        Assigns new instances of `mpi4py.MPI.Datatype` for
        the :attr:`dst_data_type` and :attr:`src_data_type`
        attributes. Only creates new instances when
        the :samp:`{dtype}` and :samp:`{order}` do not match
        existing instances.

        :type dtype: :obj:`numpy.dtype`
        :param dtype: The array element type.
        :type order: :obj:`str`
        :param order: Array memory layout, :samp:`"C"` for C array,
            or :samp:`"F"` for fortran array.
        """
        dst_info, src_info = self._dst, self._src
        dst_info.initialise_mpi_data_type(dtype=dtype, order=order)
        src_info.initialise_mpi_data_type(dtype=dtype, order=order)

    @property
    def dst_data_type(self):
        """
        A :obj:`mpi4py.MPI.Datatype` object created
        using :meth:`mpi4py.MPI.Datatype.Create_subarray` which
        defines the sub-array of halo elements which are to
        receive update values.
        """
        dst_info = self._dst
        return dst_info.mpi_data_type

    @property
    def src_data_type(self):
        """
        A :obj:`mpi4py.MPI.Datatype` object created
        using :meth:`mpi4py.MPI.Datatype.Create_subarray` which
        defines the sub-array of elements from which the halo
        update values are read.
        """
        src_info = self._src
        return src_info.mpi_data_type

    @property
    def update_extent(self):
        """
        The :obj:`IndexingExtent` indicating the halo sub-array which is to be updated.
        """
        src_info = self._src
        return src_info.region_extent

    def __str__(self):
        """
        Stringify, one row laid out according to the column header format.
        """
        # The MPI type name is only available once the data types were initialised.
        dst_dtype = self._dst._dtype
        mpi_dtype = _types.to_datatype(dst_dtype).Get_name() if dst_dtype is not None else None
        row = (
            self.dst_extent.cart_rank,
            self.dst_extent.start_h,
            self.dst_extent.stop_h,
            self.dst_extent.globale_to_locale_h(self.update_extent.start),
            self.dst_extent.globale_to_locale_h(self.update_extent.stop),
            self.src_extent.cart_rank,
            self.src_extent.start_h,
            self.src_extent.stop_h,
            self.src_extent.globale_to_locale_h(self.update_extent.start),
            self.src_extent.globale_to_locale_h(self.update_extent.stop),
            self.update_extent.start,
            self.update_extent.stop,
            mpi_dtype
        )
        return self._str_format % row
class HalosUpdate(object):
    """
    Indexing info for updating the halo regions of a single locale
    on MPI rank :samp:`self.dst_rank`.
    """

    #: The "low index" indices.
    LO = HaloIndexingExtent.LO
    #: The "high index" indices.
    HI = HaloIndexingExtent.HI

    def __init__(self, dst_rank, rank_to_extents_map):
        """
        Construct.

        :type dst_rank: :obj:`int`
        :param dst_rank: The MPI rank (:samp:`cart_comm`) of the MPI
           process which is to receive the halo updates.
        :type rank_to_extents_map: :obj:`dict`
        :param rank_to_extents_map: Dictionary of :samp:`(r, extent)`
           pairs for all ranks :samp:`r` (of :samp:`cart_comm`), where :samp:`extent`
           is a :obj:`CartLocaleExtent` object indicating the indexing extent
           (tile) on MPI rank :samp:`r.`
        """
        self.initialise(dst_rank, rank_to_extents_map)

    def create_single_extent_update(self, dst_extent, src_extent, halo_extent):
        """
        Factory method for creating instances of type :obj:`HaloSingleExtentUpdate`.

        :type dst_extent: :obj:`IndexingExtent`
        :param dst_extent: The destination locale extent for halo element update.
        :type src_extent: :obj:`IndexingExtent`
        :param src_extent: The source locale extent for obtaining halo element update.
        :type halo_extent: :obj:`IndexingExtent`
        :param halo_extent: The extent indicating the sub-array of halo elements.
        :rtype: :obj:`HaloSingleExtentUpdate`
        :return: Returns new instance of :obj:`HaloSingleExtentUpdate`.
        """
        return HaloSingleExtentUpdate(dst_extent, src_extent, halo_extent)

    def calc_halo_intersection(self, dst_extent, src_extent, axis, dir):
        """
        Calculates the intersection of :samp:`{dst_extent}` halo slab with
        the update region of :samp:`{src_extent}`.

        :type dst_extent: :obj:`CartLocaleExtent`
        :param dst_extent: Halo slab indicated by :samp:`{axis}` and :samp:`{dir}`
           taken from this extent.
        :type src_extent: :obj:`CartLocaleExtent`
        :param src_extent: This extent, minus the halo in the :samp:`{axis}` dimension,
           is intersected with the halo slab.
        :type axis: :obj:`int`
        :param axis: Axis dimension indicating slab.
        :type dir: :attr:`LO` or :attr:`HI`
        :param dir: :attr:`LO` for low-index slab or :attr:`HI` for high-index slab.
        :rtype: :obj:`IndexingExtent`
        :return: Overlap extent of :samp:`{dst_extent}` halo-slab and
           the :samp:`{src_extent}` update region (:samp:`None` is treated
           by :meth:`initialise` as "no overlap").
        """
        halo_slab = dst_extent.halo_slab_extent(axis, dir)
        return halo_slab.calc_intersection(src_extent.no_halo_extent(axis))

    def split_extent_for_max_elements(self, extent, max_elements=None):
        """
        Partitions the specified extent into smaller extents with number
        of elements no more than :samp:`{max_elements}`.

        This implementation performs no splitting: it returns a
        single element list containing :samp:`{extent}` and
        ignores :samp:`{max_elements}`.

        :type extent: :obj:`CartLocaleExtent`
        :param extent: The extent to be split.
        :type max_elements: :obj:`int`
        :param max_elements: Each partition of the returned split has no more
           than this many elements (currently unused).
        :rtype: :obj:`list` of :obj:`CartLocaleExtent`
        :return: List of extents forming a partition of :samp:`{extent}`.
        """
        return [extent, ]

    def initialise(self, dst_rank, rank_to_extents_map):
        """
        Calculates the ranks and regions required to update the
        halo regions of the :samp:`dst_rank` MPI rank.

        :type dst_rank: :obj:`int`
        :param dst_rank: The MPI rank (:samp:`cart_comm`) of the MPI
           process which is to receive the halo updates.
        :type rank_to_extents_map: :obj:`dict`
        :param rank_to_extents_map: Dictionary of :samp:`(r, extent)`
           pairs for all ranks :samp:`r` (of :samp:`cart_comm`), where :samp:`extent`
           is a :obj:`CartLocaleExtent` object indicating the indexing extent
           (tile) on MPI rank :samp:`r.` A sequence indexed by rank is
           also accepted.
        """
        self._dst_rank = dst_rank
        self._dst_extent = rank_to_extents_map[dst_rank]
        ndim = self._dst_extent.ndim

        # Build an independent [LO-list, HI-list] pair per axis. The former
        # "[[[], []]] * ndim" repeated a reference to one single pair, so the
        # updates of every axis ended up in the same two lists.
        self._updates = [[[], []] for _ in range(ndim)]

        if hasattr(rank_to_extents_map, "keys"):
            ranks = rank_to_extents_map.keys()
        else:
            ranks = range(0, len(rank_to_extents_map))
        cart_coord_to_extents_dict = \
            {
                tuple(rank_to_extents_map[r].cart_coord): rank_to_extents_map[r]
                for r in ranks
            }

        # NOTE(review): direction values are used as list indices below,
        # this assumes LO and HI are 0 and 1 - confirm in HaloIndexingExtent.
        for direction in [self.LO, self.HI]:
            for a in range(ndim):
                # Walk away from the destination tile along axis a, nearest
                # neighbour first, until a neighbour no longer overlaps the halo.
                if direction == self.LO:
                    i_range = range(-1, -self._dst_extent.cart_coord[a] - 1, -1)
                else:
                    i_range = \
                        range(1, self._dst_extent.cart_shape[a] - self._dst_extent.cart_coord[a], 1)
                for i in i_range:
                    src_cart_coord = _np.array(self._dst_extent.cart_coord, copy=True)
                    src_cart_coord[a] += i
                    src_extent = cart_coord_to_extents_dict[tuple(src_cart_coord)]
                    halo_extent = \
                        self.calc_halo_intersection(self._dst_extent, src_extent, a, direction)
                    if halo_extent is None:
                        break
                    self._updates[a][direction] += \
                        self.split_extent_for_max_elements(
                            self.create_single_extent_update(
                                self._dst_extent,
                                src_extent,
                                halo_extent
                            )
                        )

    @property
    def updates_per_axis(self):
        """
        A :attr:`ndim` length list of pair elements, each element of the pair
        is a list of :obj:`HaloSingleExtentUpdate` objects.
        """
        return self._updates
class MpiHalosUpdate(HalosUpdate):
    """
    Indexing info for updating the halo regions of a single tile
    on MPI rank :samp:`self.dst_rank`.
    Over-rides the :meth:`create_single_extent_update` to
    return :obj:`MpiHaloSingleExtentUpdate` instances.
    """

    def create_single_extent_update(self, dst_extent, src_extent, halo_extent):
        """
        Factory method for creating instances of type :obj:`MpiHaloSingleExtentUpdate`.

        :type dst_extent: :obj:`IndexingExtent`
        :param dst_extent: The destination locale extent for halo element update.
        :type src_extent: :obj:`IndexingExtent`
        :param src_extent: The source locale extent for obtaining halo element update.
        :type halo_extent: :obj:`IndexingExtent`
        :param halo_extent: The extent indicating the sub-array of halo elements.
        :rtype: :obj:`MpiHaloSingleExtentUpdate`
        :return: Returns new instance of :obj:`MpiHaloSingleExtentUpdate`.
        """
        single_update = MpiHaloSingleExtentUpdate(
            dst_extent,
            src_extent,
            halo_extent
        )
        return single_update
class UpdatesForRedistribute(object):
    """
    Collection of update extents for re-distribution of array
    elements from one distribution to another.
    """

    def __init__(
        self,
        dst_distrib,
        src_distrib,
        peer_rank_translator=None
    ):
        """
        Initialise and calculate the updates (calls :meth:`initialise`).

        :param dst_distrib: Destination distribution, provides :samp:`locale_extents`,
           :samp:`globale_extent` and :samp:`peer_ranks_per_locale`.
        :param src_distrib: Source distribution, same attributes as :samp:`{dst_distrib}`.
        :param peer_rank_translator: Object providing :samp:`dst_to_src`
           and :samp:`src_to_dst` rank translation methods. When :samp:`None`,
           no peer rank info is stored and :meth:`initialise_cpy2_updates`
           generates no updates.
        """
        object.__init__(self)
        self._dst_distrib = dst_distrib
        self._src_distrib = src_distrib
        self._dst_extent_queue = None
        self._dst_cpy2_updates = None
        self._dst_rget_updates = None
        self.update_dst_halo = False
        self._dst_translated_peer_ranks = None
        self._dst_peer_ranks = None
        self._src_translated_peer_ranks = None
        self._src_peer_ranks = None
        if peer_rank_translator is not None:
            self._dst_peer_ranks, self._dst_translated_peer_ranks = \
                self._sort_and_translate_peer_ranks(
                    dst_distrib.peer_ranks_per_locale,
                    peer_rank_translator.dst_to_src
                )
            self._src_peer_ranks, self._src_translated_peer_ranks = \
                self._sort_and_translate_peer_ranks(
                    src_distrib.peer_ranks_per_locale,
                    peer_rank_translator.src_to_dst
                )

        self.initialise()

    @staticmethod
    def _sort_and_translate_peer_ranks(peer_ranks_per_locale, translate):
        """
        Returns the pair :samp:`(sorted_peer_ranks, sorted_translated_peer_ranks)`
        for the per-locale peer ranks, using :samp:`{translate}` to map ranks.
        """
        if peer_ranks_per_locale.ndim == 2:
            # Regular 2D array: sort all rows in a single call.
            peer_ranks = _np.sort(peer_ranks_per_locale)
            translated_peer_ranks = _np.sort(translate(peer_ranks))
        else:
            # Otherwise each locale entry is sorted/translated individually.
            peer_ranks = _copy.deepcopy(peer_ranks_per_locale)
            translated_peer_ranks = _copy.deepcopy(peer_ranks_per_locale)
            for r in range(len(peer_ranks)):
                peer_ranks[r] = _np.sort(peer_ranks[r])
                translated_peer_ranks[r] = _np.sort(translate(peer_ranks[r]))
        return peer_ranks, translated_peer_ranks

    def create_pair_extent_update(
        self,
        dst_extent,
        src_extent,
        intersection_extent
    ):
        """
        Factory method for creating :obj:`PairExtentUpdate` objects.

        :type dst_extent: :obj:`mpi_array.distribution.LocaleExtent`
        :param dst_extent: Destination extent.
        :type src_extent: :obj:`mpi_array.distribution.LocaleExtent`
        :param src_extent: Source extent.
        :type intersection_extent: :obj:`mpi_array.indexing.IndexingExtent`
        :param intersection_extent: The intersection of :samp:`{src_extent}`
           and :samp:`{dst_extent}` which defines the region of array elements which
           are to be transferred from source to destination.
        :rtype: :obj:`list` of :obj:`PairExtentUpdate`
        :return: Single element list, the element defines the source sub-array
           and destination sub-array.
        """
        # The locale extents are looked up from the distributions by rank,
        # dst_extent/src_extent may be leftover pieces rather than whole locales.
        peu = \
            PairExtentUpdate(
                self._dst_distrib.locale_extents[dst_extent.inter_locale_rank],
                self._src_distrib.locale_extents[src_extent.inter_locale_rank],
                intersection_extent,
                intersection_extent
            )
        return [peu, ]

    def calc_intersection_split(self, dst_extent, src_extent):
        """
        Calculates intersection between :samp:`{dst_extent}` and :samp:`{src_extent}`.
        Any regions of :samp:`{dst_extent}` which **do not** intersect with :samp:`{src_extent}`
        are returned as a :obj:`list` of *left-over* :samp:`type({dst_extent})` elements.
        The regions of :samp:`{dst_extent}` which **do** intersect with :samp:`{src_extent}`
        are returned as a :obj:`list` of *update* :obj:`PairExtentUpdate` elements.
        Returns :obj:`tuple` pair :samp:`(leftovers, updates)`

        :type dst_extent: :obj:`HaloIndexingExtent`
        :param dst_extent: Extent which is to receive update from intersection
           with :samp:`{src_extent}`.
        :type src_extent: :obj:`HaloIndexingExtent`
        :param src_extent: Extent which is to provide update for the intersecting
           region of :samp:`{dst_extent}`.
        :rtype: :obj:`tuple`
        :return: Returns :obj:`tuple` pair of :samp:`(leftovers, updates)`.
        """
        return \
            _calc_intersection_split(
                dst_extent,
                src_extent,
                self.create_pair_extent_update,
                self.update_dst_halo
            )

    def get_cpy2_src_extents(self, dst_inter_locale_rank):
        """
        Returns the source locale extents which share at least one peer rank
        with the destination locale (after translating the destination peer ranks).

        :type dst_inter_locale_rank: :obj:`int`
        :param dst_inter_locale_rank: Rank (index) of the destination locale.
        :rtype: :obj:`tuple`
        :return: The matching elements of :samp:`src_distrib.locale_extents`.
        """
        dst_translated_peer_ranks = self._dst_translated_peer_ranks[dst_inter_locale_rank]
        src_locale_extents = self._src_distrib.locale_extents
        return \
            tuple(
                src_locale_extents[src_inter_locale_rank]
                for src_inter_locale_rank in range(0, len(src_locale_extents))
                if _np.intersect1d(
                    dst_translated_peer_ranks,
                    self._src_peer_ranks[src_inter_locale_rank]
                ).size > 0
            )

    def initialise_cpy2_updates(self):
        """
        Generates the updates between destination extents and the source
        extents returned by :meth:`get_cpy2_src_extents`. The pieces of
        destination extents which were not covered are put back on the queue.
        Does nothing when no peer rank info is available.
        """
        if (self._dst_translated_peer_ranks is None) or (self._src_peer_ranks is None):
            return

        all_dst_leftovers = []
        # Process exactly the extents which are queued at entry.
        for _ in range(len(self._dst_extent_queue)):
            dst_extent = self._dst_extent_queue.pop()
            dst_inter_locale_rank = dst_extent.inter_locale_rank
            src_extents = self.get_cpy2_src_extents(dst_inter_locale_rank)
            dst_extent_leftovers = [dst_extent, ]
            if (src_extents is not None) and (len(src_extents) > 0):
                for src_extent in src_extents:
                    # Each source extent carves its overlap out of the current leftovers.
                    new_dst_extent_leftovers = []
                    for leftover_extent in dst_extent_leftovers:
                        dst_leftovers, dst_updates = \
                            self.calc_intersection_split(leftover_extent, src_extent)
                        self._dst_cpy2_updates[dst_inter_locale_rank] += dst_updates
                        new_dst_extent_leftovers += dst_leftovers
                    dst_extent_leftovers = new_dst_extent_leftovers
            all_dst_leftovers += dst_extent_leftovers
        self._dst_extent_queue.extend(all_dst_leftovers)

    def initialise_rget_updates(self):
        """
        Generates updates for whatever is left on the destination
        queue by intersecting with each source locale extent in turn.
        Logs a warning if destination pieces remain after all sources were tried.
        """
        for src_extent in self._src_distrib.locale_extents:
            all_dst_leftovers = []
            while len(self._dst_extent_queue) > 0:
                dst_extent = self._dst_extent_queue.pop()
                dst_rank = dst_extent.inter_locale_rank
                dst_leftovers, dst_updates = \
                    self.calc_intersection_split(dst_extent, src_extent)
                self._dst_rget_updates[dst_rank] += dst_updates
                all_dst_leftovers += dst_leftovers
            self._dst_extent_queue.extend(all_dst_leftovers)
            if len(self._dst_extent_queue) <= 0:
                break
        if len(self._dst_extent_queue) > 0:
            # The former code used the non-existent attribute self._dst_cad here.
            # Use the distribution's rank_logger if it has one, otherwise
            # fall back to a standard library logger.
            rank_logger = getattr(self._dst_distrib, "rank_logger", None)
            if rank_logger is None:
                import logging
                rank_logger = logging.getLogger(__name__)
            rank_logger.warning(
                "Non-empty leftover queue=%s",
                self._dst_extent_queue
            )

    def check_updates(self):
        """
        Runs consistency checks on the calculated updates, assumes that
        the :attr:`dst_distrib` and :attr:`src_distrib` distributed
        as a partitioning (no locale extent overlaps except for halo).

        Checks that no two destination update regions intersect and that
        the total number of destination (and source) update elements equals
        the number of elements in the intersection of the two globale extents.

        :raises RuntimeError: If update inconsistency discovered.
        """
        import itertools

        msg = ""
        all_updates = \
            tuple(
                itertools.chain(
                    *(
                        tuple(self._dst_cpy2_updates.values())
                        +
                        tuple(self._dst_rget_updates.values())
                    )
                )
            )
        total_dst_update_elems = 0
        total_src_update_elems = 0
        for i in range(len(all_updates)):
            u0 = all_updates[i]
            total_dst_update_elems += _np.prod(u0.dst_update_extent.shape)
            total_src_update_elems += _np.prod(u0.src_update_extent.shape)
            # Compare with every earlier update (each pair checked once).
            for j in range(0, i):
                u1 = all_updates[j]
                isect = u0.dst_update_extent.calc_intersection(u1.dst_update_extent)
                if isect is not None:
                    msg += \
                        (
                            "Got intersecting updates, intersection=%s, updates:\n%s\n%s\n\n"
                            %
                            (isect, u0, u1)
                        )

        globale_intersect = \
            self._dst_distrib.globale_extent.calc_intersection(self._src_distrib.globale_extent)
        # A None intersection means the globale extents do not overlap at all.
        if globale_intersect is None:
            total_intersect_elems = 0
        else:
            total_intersect_elems = _np.prod(globale_intersect.shape)
        if total_intersect_elems != total_dst_update_elems:
            msg += \
                (
                    "total_intersect_elems=%s != total_dst_update_elems=%s\n"
                    %
                    (total_intersect_elems, total_dst_update_elems)
                )
        if total_intersect_elems != total_src_update_elems:
            msg += \
                (
                    "total_intersect_elems=%s != total_src_update_elems=%s\n"
                    %
                    (total_intersect_elems, total_src_update_elems)
                )
        if (len(msg) > 0):
            raise \
                RuntimeError(
                    "%s.check_updates failed checks:\n%s"
                    %
                    (self.__class__.__name__, msg)
                )

    def initialise_updates(self):
        """
        Calculates the copy updates first, then the RMA get updates
        for the destination pieces which remain.
        """
        self.initialise_cpy2_updates()
        self.initialise_rget_updates()

    def initialise(self):
        """
        Resets the destination queue (all destination locale extents)
        and the update containers, then calculates all updates.
        """
        self._dst_extent_queue = _collections.deque()
        self._dst_extent_queue.extend(self._dst_distrib.locale_extents)
        self._dst_cpy2_updates = _collections.defaultdict(list)
        self._dst_rget_updates = _collections.defaultdict(list)
        self.initialise_updates()
class UpdatesForGet(object):
"""
Collection of update extents for fetching an arbitrary sub-extent
from the globale array.
"""
def __init__(
    self,
    dst_extent,
    src_distrib,
    update_dst_halo=False
):
    """
    Store the destination extent and source distribution, then
    calculate the updates by calling :samp:`self.initialise()`.

    :param dst_extent: The extent which is to receive the fetched elements.
    :param src_distrib: Source distribution, provides :samp:`locale_extents`.
    :type update_dst_halo: :obj:`bool`
    :param update_dst_halo: Forwarded to the intersection-split calculation
       (see :meth:`calc_intersection_split`).
    """
    object.__init__(self)
    self._dst_extent, self._src_distrib = dst_extent, src_distrib
    # Work queue and update containers are created by initialise().
    self._dst_extent_queue = None
    self._dst_cpy2_updates = None
    self._dst_rget_updates = None
    self._update_dst_halo = update_dst_halo

    self.initialise()
def create_pair_extent_update(
    self,
    dst_extent,
    src_extent,
    intersection_extent
):
    """
    Factory method for creating :obj:`PairExtentUpdate` objects.

    :type dst_extent: :obj:`mpi_array.distribution.LocaleExtent`
    :param dst_extent: Destination extent (not used, the update always
       refers to the destination extent given at construction).
    :type src_extent: :obj:`mpi_array.distribution.LocaleExtent`
    :param src_extent: Source extent.
    :type intersection_extent: :obj:`mpi_array.indexing.IndexingExtent`
    :param intersection_extent: The intersection of :samp:`{src_extent}`
       and :samp:`{dst_extent}` which defines the region of array elements which
       are to be transferred from source to destination.
    :rtype: :obj:`list` of :obj:`PairExtentUpdate`
    :return: Single element list, the element defines the source sub-array
       and destination sub-array.
    """
    src_locale_extent = self._src_distrib.locale_extents[src_extent.inter_locale_rank]
    pair_update = PairExtentUpdate(
        self._dst_extent,
        src_locale_extent,
        intersection_extent,
        intersection_extent
    )
    return [pair_update]
def calc_intersection_split(self, dst_extent, src_extent):
"""
Calculates intersection between :samp:`{dst_extent}` and `{src_extent}`.
Any regions of :samp:`{dst_extent}` which **do not** intersect with :samp:`{src_extent}`
are returned as a :obj:`list` of *left-over* :samp:`type({dst_extent})` elements.
The regions of :samp:`{dst_extent}` which **do** intersect with :samp:`{src_extent}`
are returned as a :obj:`list` of *update* :obj:`PairExtentUpdate` elements.
Returns :obj:`tuple` pair :samp:`(leftovers, updates)`
:type dst_extent: :obj:`HaloIndexingExtent`
:param dst_extent: Extent which is to receive update from intersection
with :samp:`{src_extent}`.
:type src_extent: :obj:`HaloIndexingExtent`
:param src_extent: Extent which is to provide update for the intersecting
region of :samp:`{dst_extent}`.
:rtype: :obj:`tuple`
:return: Returns :obj:`tuple` pair of :samp:`(leftovers, updates)`.
"""
return \
_calc_intersection_split(
dst_extent,
src_extent,
self.create_pair_extent_update,
self._update_dst_halo
)
def get_cpy2_src_extents(self, dst_inter_locale_rank):
"""
"""
src_extents = (self._src_distrib.locale_extents[dst_inter_locale_rank],)
return src_extents
def initialise_cpy2_updates(self):
"""
"""
all_dst_leftovers = []
for dst_extent_idx in range(len(self._dst_extent_queue)):
dst_extent = self._dst_extent_queue.pop()
dst_inter_locale_rank = dst_extent.inter_locale_rank
src_extents = self.get_cpy2_src_extents(dst_inter_locale_rank)
dst_extent_leftovers = [dst_extent, ]
if (src_extents is not None) and (len(src_extents) > 0):
for src_extent in src_extents:
new_dst_extent_leftovers = []
for dst_extent in dst_extent_leftovers:
dst_leftovers, dst_updates = \
self.calc_intersection_split(dst_extent, src_extent)
self._dst_cpy2_updates[dst_inter_locale_rank] += dst_updates
new_dst_extent_leftovers += dst_leftovers
dst_extent_leftovers = new_dst_extent_leftovers
all_dst_leftovers += dst_extent_leftovers
self._dst_extent_queue.extend(all_dst_leftovers)
def initialise_rget_updates(self):
"""
"""
for src_rank in range(len(self._src_distrib.locale_extents)):
src_extent = self._src_distrib.locale_extents[src_rank]
all_dst_leftovers = []
while len(self._dst_extent_queue) > 0:
dst_extent = self._dst_extent_queue.pop()
dst_rank = dst_extent.inter_locale_rank
dst_leftovers, dst_updates = \
self.calc_intersection_split(dst_extent, src_extent)
self._dst_rget_updates[dst_rank] += dst_updates
all_dst_leftovers += dst_leftovers
self._dst_extent_queue.extend(all_dst_leftovers)
if len(self._dst_extent_queue) <= 0:
break
if len(self._dst_extent_queue) > 0:
self._dst_cad.rank_logger.warning(
"Non-empty leftover queue=%s",
self._dst_extent_queue
)
def initialise_updates(self):
"""
"""
self.initialise_cpy2_updates()
self.initialise_rget_updates()
def initialise(self):
self._dst_extent_queue = _collections.deque()
self._dst_extent_queue.extend((self._dst_extent,))
self._dst_cpy2_updates = _collections.defaultdict(list)
self._dst_rget_updates = _collections.defaultdict(list)
self.initialise_updates()
class MpiUpdatesForGet(UpdatesForGet):
    """
    Specialisation of :obj:`UpdatesForGet` whose :meth:`create_pair_extent_update`
    factory produces :obj:`MpiPairExtentUpdate` objects with
    initialised data types.
    """

    def __init__(
        self,
        dst_extent,
        src_distrib,
        dtype,
        order,
        update_dst_halo=False,
    ):
        """
        Stores the element type and array ordering and then initialises
        the :obj:`UpdatesForGet` base.

        :type dtype: :obj:`numpy.dtype`
        :param dtype: Converted with :samp:`numpy.dtype({dtype})`, used as both the
            source and the destination element type.
        :param order: Used as both the source and the destination array ordering.
        """
        # These attributes are set before the base initialiser is called; they
        # are read by create_pair_extent_update.
        self.order = order
        self.dtype = _np.dtype(dtype)
        UpdatesForGet.__init__(
            self,
            dst_extent=dst_extent,
            src_distrib=src_distrib,
            update_dst_halo=update_dst_halo
        )

    def create_pair_extent_update(
        self,
        dst_extent,
        src_extent,
        intersection_extent
    ):
        """
        Factory method for creating :obj:`MpiPairExtentUpdate` objects.

        :type dst_extent: :obj:`mpi_array.distribution.LocaleExtent`
        :param dst_extent: Destination extent (not used, the destination
            extent supplied at construction is used instead).
        :type src_extent: :obj:`mpi_array.distribution.LocaleExtent`
        :param src_extent: Source extent, its :samp:`inter_locale_rank` selects
            the extent from the source distribution.
        :type intersection_extent: :obj:`mpi_array.indexing.IndexingExtent`
        :param intersection_extent: The intersection of :samp:`{src_extent}`
            and :samp:`{dst_extent}` which defines the region of array elements which
            are to be transferred from source to destination.
        :rtype: :obj:`list`
        :return: Single element list containing the :obj:`MpiPairExtentUpdate`
            which defines the source sub-array and destination sub-array.
        """
        src_locale_extent = self._src_distrib.locale_extents[src_extent.inter_locale_rank]
        update = \
            MpiPairExtentUpdate(
                self._dst_extent,
                src_locale_extent,
                intersection_extent,
                intersection_extent
            )
        # Source and destination share the same element type and ordering.
        update.initialise_data_types(
            dst_dtype=self.dtype,
            src_dtype=self.dtype,
            dst_order=self.order,
            src_order=self.order
        )
        return [update, ]
class RmaUpdateExecutor(object):
    """
    Performs one-sided fetch of data from remote (source) locale arrays to
    update destination locale array.
    """

    def __init__(
        self,
        inter_win,
        dst_lndarray,
        src_inter_win_rank_attr,
        rank_logger=None,
        casting="same_kind"
    ):
        """
        :type inter_win: :obj:`mpi4py.MPI.Win`
        :param inter_win: Window used for the remote fetches, can be :samp:`None`
            or :samp:`mpi4py.MPI.WIN_NULL` in which
            case :meth:`do_locale_rma_update` does nothing.
        :type dst_lndarray: :obj:`numpy.ndarray`
        :param dst_lndarray: Destination array for the fetched/copied elements.
        :type src_inter_win_rank_attr: :obj:`str`
        :param src_inter_win_rank_attr: Name of the source-extent attribute
            which holds the :samp:`{inter_win}` target rank.
        :type rank_logger: :obj:`logging.Logger`
        :param rank_logger: Logger for debug messages, if :samp:`None`
            no messages are logged.
        :type casting: :obj:`str`
        :param casting: Passed as the :samp:`casting` argument when copying
            in :meth:`do_direct_cpy2_update`.
        """
        object.__init__(self)
        self._dst_lndarray = dst_lndarray
        self._casting = casting
        self._inter_win = inter_win
        # Maximum number of target ranks per group in do_locale_rma_update.
        self._num_requests_per_group = 32
        self._random_state = None
        self._rank_logger = rank_logger
        self._src_inter_win_rank_attr = src_inter_win_rank_attr

    def get_src_win_rank(self, src_extent):
        """
        Returns target rank integer (:attr:`inter_win`) for specified :samp:`{src_extent}` extent.

        :type src_extent: :obj:`mpi_array.distribution.LocaleExtent`
        :param src_extent: Return target rank for this extent.
        :rtype: :obj:`int`
        :return: Target rank for window :attr:`inter_win`.
        """
        return getattr(src_extent, self._src_inter_win_rank_attr)

    @property
    def inter_win(self):
        """
        The :obj:`mpi4py.MPI.Win` instance used for remote fetch of data.
        """
        return self._inter_win

    @property
    def dst_lndarray(self):
        """
        The destination :obj:`numpy.ndarray` for remote fetches.
        """
        return self._dst_lndarray

    @property
    def rank_logger(self):
        """
        The :obj:`logging.Logger` used for log messages (can be :samp:`None`).
        """
        return self._rank_logger

    @rank_logger.setter
    def rank_logger(self, logger):
        """
        Sets the logger used for log messages.
        """
        self._rank_logger = logger

    def _has_inter_win(self):
        """
        Returns :samp:`True` if :attr:`inter_win` is neither :samp:`None`
        nor :samp:`mpi4py.MPI.WIN_NULL`.
        """
        return (self._inter_win is not None) and (self._inter_win != _mpi.WIN_NULL)

    def _log_debug(self, msg, *args):
        """
        Logs a debug message on :attr:`rank_logger`, does nothing when
        no logger has been set.
        """
        logger = self.rank_logger
        if logger is not None:
            logger.debug(msg, *args)

    @property
    def random_state(self):
        """
        A :obj:`numpy.random.RandomState` instance, used to permute target
        rank ordering to alleviate swamping single rank with *get* requests
        from multiple source ranks. Created on first access and seeded from
        the rank of this process in the :attr:`inter_win` group. Is :samp:`None`
        when there is no valid :attr:`inter_win`.
        """
        if (self._random_state is None) and self._has_inter_win():
            # Rank dependent seed: start with the digits "147483648", overwrite
            # the leading digits with (rank + 1) and the trailing digits
            # with the reversed digits of (rank + 1).
            seed_str = str(2 ** 31)[1:]
            rank_str = str(self._inter_win.group.rank + 1)
            seed_str = rank_str + seed_str[len(rank_str):]
            seed_str = seed_str[0:-len(rank_str)] + rank_str[::-1]
            self._random_state = _np.random.RandomState(seed=int(seed_str))
        return self._random_state

    def do_direct_cpy2_update(self, updates, src_lndarray):
        """
        Does direct copy update to :attr:`dst_lndarray` from the
        specified :samp:`{src_lndarray}` array.

        :type updates: sequence of :obj:`PairExtentUpdate`
        :param updates: Sequence of destination and source extents.
        :type src_lndarray: :obj:`numpy.ndarray`
        :param src_lndarray: Elements copied from this array.
        """
        for single_update in updates:
            single_update.copyto(self._dst_lndarray, src_lndarray, casting=self._casting)

    def do_locale_rma_update(self, updates):
        """
        Performs RMA to get elements from remote (source) locales to
        update the (destination) locale extent array. Does nothing
        if :attr:`inter_win` is :samp:`None` or :samp:`mpi4py.MPI.WIN_NULL`.

        The updates are grouped by target rank, the target ranks are visited
        in a random order (see :attr:`random_state`). Each target rank is
        locked (shared lock) for the duration of the *get* calls for that rank.

        :type updates: sequence of :obj:`PairExtentUpdate`
        :param updates: Sequence of destination and source extents.
        """
        if not self._has_inter_win():
            return

        # Group the updates by the window rank of the source extent.
        update_dict = _collections.defaultdict(list)
        for single_update in updates:
            update_dict[self.get_src_win_rank(single_update.src_extent)].append(
                single_update
            )

        src_win_ranks = self.random_state.permutation(tuple(update_dict.keys()))
        src_win_ranks_per_group = self._num_requests_per_group
        if len(src_win_ranks) > src_win_ranks_per_group:
            # ceil(number-of-ranks / ranks-per-group) sub-groups.
            src_win_rank_sub_groups = \
                _np.array_split(
                    src_win_ranks,
                    (len(src_win_ranks) - 1) // src_win_ranks_per_group + 1
                )
        else:
            src_win_rank_sub_groups = (src_win_ranks,)

        num_groups = len(src_win_rank_sub_groups)
        for group_idx, group_ranks in enumerate(src_win_rank_sub_groups, 1):
            self._log_debug(
                "BEG: Getting updates from src_win_ranks group %4d of %4d ranks: %s",
                group_idx,
                num_groups,
                group_ranks
            )
            for src_win_rank in group_ranks:
                rank_updates = update_dict[src_win_rank]
                self._inter_win.Lock(src_win_rank, _mpi.LOCK_SHARED)
                try:
                    for single_update in rank_updates:
                        self._log_debug(
                            "Getting update:\n%s\n%s",
                            single_update._header_str,
                            single_update
                        )
                        single_update.do_get(
                            self._inter_win,
                            src_win_rank,
                            self._dst_lndarray
                        )
                finally:
                    # Always end the access epoch, even if a get fails.
                    self._inter_win.Unlock(src_win_rank)
                # conclude() is only called once the target has been unlocked.
                for single_update in rank_updates:
                    single_update.conclude()
            self._log_debug(
                "END: Getting updates from src_win_ranks group %4d of %4d ranks: %s",
                group_idx,
                num_groups,
                group_ranks
            )
# Export every name bound in the module namespace at this point which does
# not start with an underscore. ``dir()`` is the outermost iterable of the
# comprehension, so it is evaluated in module scope, not in the
# comprehension's own scope.
__all__ = [s for s in dir() if not s.startswith('_')]