"""
=======================================
The :mod:`mpi_array.update_test` Module
=======================================
Module defining :mod:`mpi_array.update` unit-tests.
Execute as::

   python -m mpi_array.update_test

Classes
=======

.. autosummary::
   :toctree: generated/
   :template: autosummary/inherits_TestCase_class.rst

   MpiPairExtentUpdateTest - Tests for :obj:`mpi_array.update.MpiPairExtentUpdate`.
   MpiHaloSingleExtentUpdateTest - Tests for :obj:`mpi_array.update.MpiHaloSingleExtentUpdate`.
   HalosUpdateTest - Tests for :obj:`mpi_array.update.HalosUpdate`.
   UpdatesForRedistributeTest - Tests for :obj:`mpi_array.update.UpdatesForRedistribute`.
"""
from __future__ import absolute_import
import mpi4py.MPI as _mpi
import numpy as _np # noqa: E402,F401
from array_split import shape_split as _shape_split
from .license import license as _license, copyright as _copyright, version as _version
from . import unittest as _unittest
from . import logging as _logging # noqa: E402,F401
from .indexing import IndexingExtent
from .distribution import CartLocaleExtent, GlobaleExtent, BlockPartition
from .update import MpiHaloSingleExtentUpdate, HalosUpdate
from .update import MpiPairExtentUpdate, UpdatesForRedistribute
# Module metadata; licence, copyright and version strings come from the
# package's :mod:`.license` module.
__author__ = "Shane J. Latham"
__license__, __copyright__, __version__ = _license(), _copyright(), _version()
class MpiPairExtentUpdateTest(_unittest.TestCase):
    """
    Tests for :obj:`mpi_array.update.MpiPairExtentUpdate`.
    """

    def setUp(self):
        """
        Creates two adjacent 1D :obj:`mpi_array.distribution.CartLocaleExtent`
        fixtures (source :samp:`self.se`, destination :samp:`self.de`) and the
        destination/source update extents (:samp:`self.due`, :samp:`self.sue`).
        """
        # Source locale extent: cartesian coordinate 0, slice [0, 100), halo 10 per side.
        self.se = \
            CartLocaleExtent(
                peer_rank=0,
                inter_locale_rank=0,
                cart_coord=(0,),
                cart_shape=(2,),
                globale_extent=GlobaleExtent(stop=(100,)),
                slice=(slice(0, 100),),
                halo=((10, 10),)
            )
        # Destination locale extent: cartesian coordinate 1, slice [100, 200), halo 10 per side.
        # NOTE(review): globale_extent is stop=(100,) although this slice runs to 200;
        # confirm whether stop=(200,) was intended.
        self.de = \
            CartLocaleExtent(
                peer_rank=1,
                inter_locale_rank=1,
                cart_coord=(1,),
                cart_shape=(2,),
                globale_extent=GlobaleExtent(stop=(100,)),
                slice=(slice(100, 200),),
                halo=((10, 10),)
            )
        # Both update extents cover indices [90, 100): presumably the low-side
        # halo of self.de, which lies inside the slice of self.se.
        self.due = IndexingExtent(start=(90,), stop=(100,))
        self.sue = IndexingExtent(start=(90,), stop=(100,))

    def _create_update(self):
        """
        Returns a new :obj:`mpi_array.update.MpiPairExtentUpdate` built from
        the :meth:`setUp` fixtures.
        """
        return MpiPairExtentUpdate(self.de, self.se, self.due, self.sue)

    def test_construct(self):
        """
        Tests for :meth:`mpi_array.update.MpiPairExtentUpdate.__init__`.
        """
        u = self._create_update()
        # The constructor must keep references to the given objects, not copies.
        self.assertIs(u.dst_extent, self.de)
        self.assertIs(u.src_extent, self.se)
        self.assertIs(u.dst_update_extent, self.due)
        self.assertIs(u.src_update_extent, self.sue)

    def test_str(self):
        """
        Tests for :meth:`mpi_array.update.MpiPairExtentUpdate.__str__`.
        """
        u = self._create_update()
        self.assertGreater(len(str(u)), 0)
        # __str__ must also work once the MPI data types have been created.
        u.initialise_data_types(dst_dtype="int32", src_dtype="int32", dst_order="C", src_order="C")
        self.assertGreater(len(str(u)), 0)

    def test_data_type(self):
        """
        Tests for :meth:`mpi_array.update.MpiPairExtentUpdate.initialise_data_types`
        and the :samp:`dst_data_type` / :samp:`src_data_type` attributes it sets.
        """
        u = self._create_update()
        u.initialise_data_types(dst_dtype="int32", src_dtype="int32", dst_order="C", src_order="C")
        self.assertIsNotNone(u.dst_data_type)
        self.assertIsInstance(u.dst_data_type, _mpi.Datatype)
        self.assertIsNotNone(u.src_data_type)
        self.assertIsInstance(u.src_data_type, _mpi.Datatype)

        # Re-initialising with an unchanged dtype and order must reuse the existing data types.
        ddt = u.dst_data_type
        sdt = u.src_data_type
        u.initialise_data_types(dst_dtype="int32", src_dtype="int32", dst_order="C", src_order="C")
        self.assertIs(u.dst_data_type, ddt)
        self.assertIs(u.src_data_type, sdt)

        # Changing the memory order must produce new data types.
        u.initialise_data_types(dst_dtype="int32", src_dtype="int32", dst_order="F", src_order="F")
        self.assertIsNot(u.dst_data_type, ddt)
        self.assertIsNot(u.src_data_type, sdt)
class MpiHaloSingleExtentUpdateTest(_unittest.TestCase):
    """
    Tests for :obj:`mpi_array.update.MpiHaloSingleExtentUpdate`.
    """

    def setUp(self):
        """
        Creates two adjacent 1D :obj:`mpi_array.distribution.CartLocaleExtent`
        fixtures (source :samp:`self.se`, destination :samp:`self.de`) and the
        update extent :samp:`self.ue` exchanged between them.
        """
        # Source locale extent: cartesian coordinate 0, slice [0, 100), halo 10 per side.
        self.se = \
            CartLocaleExtent(
                peer_rank=0,
                inter_locale_rank=0,
                cart_coord=(0,),
                cart_shape=(2,),
                globale_extent=GlobaleExtent(stop=(100,)),
                slice=(slice(0, 100),),
                halo=((10, 10),)
            )
        # Destination locale extent: cartesian coordinate 1, slice [100, 200), halo 10 per side.
        # NOTE(review): globale_extent is stop=(100,) although this slice runs to 200;
        # confirm whether stop=(200,) was intended.
        self.de = \
            CartLocaleExtent(
                peer_rank=1,
                inter_locale_rank=1,
                cart_coord=(1,),
                cart_shape=(2,),
                globale_extent=GlobaleExtent(stop=(100,)),
                slice=(slice(100, 200),),
                halo=((10, 10),)
            )
        # Indices [90, 100): presumably the low-side halo of self.de that is
        # owned by self.se.
        self.ue = IndexingExtent(start=(90,), stop=(100,))

    def _create_update(self):
        """
        Returns a new :obj:`mpi_array.update.MpiHaloSingleExtentUpdate` built
        from the :meth:`setUp` fixtures.
        """
        return MpiHaloSingleExtentUpdate(self.de, self.se, self.ue)

    def test_construct(self):
        """
        Tests for :meth:`mpi_array.update.MpiHaloSingleExtentUpdate.__init__`.
        """
        u = self._create_update()
        # The constructor must keep references to the given objects, not copies.
        self.assertIs(u.dst_extent, self.de)
        self.assertIs(u.src_extent, self.se)
        self.assertIs(u.update_extent, self.ue)

    def test_str(self):
        """
        Tests for :meth:`mpi_array.update.MpiHaloSingleExtentUpdate.__str__`.
        """
        u = self._create_update()
        self.assertGreater(len(str(u)), 0)
        # __str__ must also work once the MPI data types have been created.
        u.initialise_data_types(dtype="int32", order="C")
        self.assertGreater(len(str(u)), 0)

    def test_data_type(self):
        """
        Tests for :meth:`mpi_array.update.MpiHaloSingleExtentUpdate.initialise_data_types`
        and the :samp:`dst_data_type` / :samp:`src_data_type` attributes it sets.
        """
        u = self._create_update()
        u.initialise_data_types(dtype="int32", order="C")
        self.assertIsNotNone(u.dst_data_type)
        self.assertIsInstance(u.dst_data_type, _mpi.Datatype)
        self.assertIsNotNone(u.src_data_type)
        self.assertIsInstance(u.src_data_type, _mpi.Datatype)

        # Re-initialising with an unchanged dtype and order must reuse the existing data types.
        ddt = u.dst_data_type
        sdt = u.src_data_type
        u.initialise_data_types(dtype="int32", order="C")
        self.assertIs(u.dst_data_type, ddt)
        self.assertIs(u.src_data_type, sdt)

        # Changing the memory order must produce new data types.
        u.initialise_data_types(dtype="int32", order="F")
        self.assertIsNot(u.dst_data_type, ddt)
        self.assertIsNot(u.src_data_type, sdt)
class HalosUpdateTest(_unittest.TestCase):
    """
    Tests for :obj:`mpi_array.update.HalosUpdate`.
    """

    def setUp(self):
        """
        Creates two adjacent 1D :obj:`mpi_array.distribution.CartLocaleExtent`
        fixtures (:samp:`self.se`, :samp:`self.de`) and the extent
        :samp:`self.ue` expected to be updated in :samp:`self.de`'s halo.
        """
        # Locale extent at cartesian coordinate 0, slice [0, 100), halo 10 per side.
        self.se = \
            CartLocaleExtent(
                peer_rank=0,
                inter_locale_rank=0,
                cart_coord=(0,),
                cart_shape=(2,),
                globale_extent=GlobaleExtent(stop=(100,)),
                slice=(slice(0, 100),),
                halo=((10, 10),)
            )
        # Locale extent at cartesian coordinate 1, slice [100, 200), halo 10 per side.
        # NOTE(review): globale_extent is stop=(100,) although this slice runs to 200;
        # confirm whether stop=(200,) was intended.
        self.de = \
            CartLocaleExtent(
                peer_rank=1,
                inter_locale_rank=1,
                cart_coord=(1,),
                cart_shape=(2,),
                globale_extent=GlobaleExtent(stop=(100,)),
                slice=(slice(100, 200),),
                halo=((10, 10),)
            )
        # Expected halo update region of self.de: indices [90, 100).
        self.ue = IndexingExtent(start=(90,), stop=(100,))

    def test_construct(self):
        """
        Tests for :meth:`mpi_array.update.HalosUpdate.__init__`.
        """
        rank_to_extent_dict = {
            self.se.cart_rank: self.se,
            self.de.cart_rank: self.de,
        }
        hu = HalosUpdate(self.de.cart_rank, rank_to_extent_dict)

        # One axis, with two entries per axis (presumably low side [0], high side [1]).
        self.assertEqual(1, len(hu.updates_per_axis))
        self.assertEqual(2, len(hu.updates_per_axis[0]))
        # No updates are expected for entry [1]; exactly one for entry [0].
        self.assertEqual(0, len(hu.updates_per_axis[0][1]))
        self.assertEqual(1, len(hu.updates_per_axis[0][0]))

        update = hu.updates_per_axis[0][0][0]
        self.assertIs(self.de, update.dst_extent)
        self.assertIs(self.se, update.src_extent)
        self.assertEqual(self.ue, update.update_extent)
class UpdatesForRedistributeTest(_unittest.TestCase):
    """
    Tests for :obj:`mpi_array.update.UpdatesForRedistribute`.
    """

    def setUp(self):
        """
        Creates the per-rank logger :samp:`self.rank_logger` for this test.
        """
        self.rank_logger = _logging.get_rank_logger(self.id())

    @staticmethod
    def _build_partition(gshape, dims, num_locales, ilr2pr):
        """
        Returns a :obj:`mpi_array.distribution.BlockPartition` of :samp:`gshape`
        over the :samp:`dims` cartesian grid, in which cartesian rank :samp:`r`
        (for :samp:`r` in :samp:`0..num_locales-1`) sits at coordinate
        :samp:`numpy.unravel_index(r, dims)`, and inter-locale ranks map to
        peer ranks via :samp:`ilr2pr`.
        """
        cart_ranks = _np.array(range(0, num_locales))
        coord_rows = _np.array(_np.unravel_index(cart_ranks, tuple(dims))).T
        cc2cr = {tuple(row): rank for row, rank in zip(coord_rows, cart_ranks)}
        return BlockPartition(
            gshape,
            dims,
            cc2cr,
            inter_locale_rank_to_peer_rank=ilr2pr
        )

    def test_slab_to_block(self):
        """
        Builds a per-peer-rank block partition and a per-node slab partition of
        the same globale shape, then constructs
        :obj:`mpi_array.update.UpdatesForRedistribute` in both directions, with
        and without a rank translator, and runs
        :meth:`mpi_array.update.UpdatesForRedistribute.check_updates` on each.
        """
        num_peer_ranks = 64
        num_peer_ranks_per_node = 4
        num_nodes = num_peer_ranks // num_peer_ranks_per_node
        gshape = (num_peer_ranks * 16, num_peer_ranks * 16)

        # Block partition: one locale per peer rank, each locale holding one peer rank.
        d_proc_blck = self._build_partition(
            gshape,
            _shape_split(gshape, num_peer_ranks).shape,
            num_peer_ranks,
            _np.array(range(0, num_peer_ranks))
        )
        d_proc_blck.peer_ranks_per_locale = \
            _np.arange(0, d_proc_blck.num_locales).reshape((d_proc_blck.num_locales, 1))

        # Slab partition split along axis 0: one locale per node; inter-locale
        # rank i maps to peer rank i * num_peer_ranks_per_node, the first of the
        # num_peer_ranks_per_node consecutive peer ranks held by that locale.
        d_node_slab = self._build_partition(
            gshape,
            _shape_split(gshape, num_nodes, axis=0).shape,
            num_nodes,
            _np.array(range(0, num_nodes)) * num_peer_ranks_per_node
        )
        d_node_slab.peer_ranks_per_locale = \
            _np.arange(0, num_peer_ranks).reshape(
                (d_node_slab.num_locales, num_peer_ranks_per_node)
            )

        class RankTranslator(object):
            """Identity rank translation; both directions return a copy."""

            def dst_to_src(self, ranks):
                return _np.array(ranks, copy=True)

            src_to_dst = dst_to_src

        # (first partition, second partition, translator class or None,
        #  log message before construction, log message after construction)
        scenarios = (
            (
                d_proc_blck, d_node_slab, None,
                "BEG: u4r = UpdatesForRedistribute(d_proc_blck, d_node_slab)...",
                "END: UpdatesForRedistribute(d_proc_blck, d_node_slab)."
            ),
            (
                d_proc_blck, d_node_slab, RankTranslator,
                "BEG: u4r = UpdatesForRedistribute(d_proc_blck, d_node_slab, RankTranslator())...",
                "END: u4r = UpdatesForRedistribute(d_proc_blck, d_node_slab, RankTranslator())."
            ),
            (
                d_node_slab, d_proc_blck, None,
                "BEG: u4r = UpdatesForRedistribute(d_node_slab, d_proc_blck)...",
                "END: u4r = UpdatesForRedistribute(d_node_slab, d_proc_blck)."
            ),
            (
                d_node_slab, d_proc_blck, RankTranslator,
                "BEG: u4r = UpdatesForRedistribute(d_node_slab, d_proc_blck, RankTranslator())...",
                "END: u4r = UpdatesForRedistribute(d_node_slab, d_proc_blck, RankTranslator())..."
            ),
        )
        for first, second, translator_cls, beg_msg, end_msg in scenarios:
            self.rank_logger.info(beg_msg)
            if translator_cls is None:
                u4r = UpdatesForRedistribute(first, second)
            else:
                u4r = UpdatesForRedistribute(first, second, translator_cls())
            self.rank_logger.info(end_msg)
            self.rank_logger.info("BEG: u4r.check_updates()...")
            u4r.check_updates()
            self.rank_logger.info("END: u4r.check_updates()...")
# Hand the module name to mpi_array.unittest.main; presumably it only runs the
# tests when this module is executed as __main__ (TODO confirm).
_unittest.main(__name__)

# Publish every name in the module namespace that does not start with '_'.
__all__ = sorted(name for name in dir() if not name.startswith('_'))