Source code for mpi_array.distribution_test

"""
=============================================
The :mod:`mpi_array.distribution_test` Module
=============================================

Module defining :mod:`mpi_array.distribution` unit-tests.
Execute as::

   python -m mpi_array.distribution_test


Classes
=======

.. autosummary::
   :toctree: generated/
   :template: autosummary/inherits_TestCase_class.rst

   CartLocaleExtentTest - Tests for :obj:`mpi_array.distribution.CartLocaleExtent`.
   BlockPartitionTest - Tests for :obj:`mpi_array.distribution.BlockPartition`.


"""
from __future__ import absolute_import

import mpi4py.MPI as _mpi
import numpy as _np  # noqa: E402,F401
import array_split as _array_split

from .license import license as _license, copyright as _copyright, version as _version
from . import unittest as _unittest
from . import logging as _logging  # noqa: E402,F401

from .indexing import IndexingExtent
from .distribution import BlockPartition, Distribution
from .distribution import CartLocaleExtent, GlobaleExtent, LocaleExtent


__author__ = "Shane J. Latham"
__license__ = _license()
__copyright__ = _copyright()
__version__ = _version()


class LocaleExtentTest(_unittest.TestCase):

    """
    :obj:`unittest.TestCase` for :obj:`mpi_array.distribution.LocaleExtent`.
    """

    def do_test_construct_empty_with_axis(self, halo=0):
        """
        Tests :obj:`mpi_array.distribution.LocaleExtent` with empty axis.
        """
        le = \
            LocaleExtent(
                start=(0, 0, 1),
                stop=(10, 20, 1),
                peer_rank=1,
                inter_locale_rank=1,
                globale_extent=GlobaleExtent(start=(0, 0, 0), stop=(10, 20, 0)),
                halo=halo
            )
        self.assertSequenceEqual(
            le.halo.tolist(),
            [[0, 0], [0, 0], [0, 0]]
        )
        self.assertSequenceEqual(
            tuple(le.start_n),
            (0, 0, 1)
        )
        self.assertSequenceEqual(
            tuple(le.start_h),
            (0, 0, 1)
        )
        self.assertSequenceEqual(
            tuple(le.stop_n),
            (10, 20, 1)
        )
        self.assertSequenceEqual(
            tuple(le.stop_n),
            (10, 20, 1)
        )

    def test_construct_empty_with_axis_no_halo(self):
        """
        Tests :obj:`mpi_array.distribution.LocaleExtent` with empty axis :samp:`halo=0`.
        """
        self.do_test_construct_empty_with_axis(halo=0)

    def test_construct_empty_with_axis_halo(self):
        """
        Tests :obj:`mpi_array.distribution.LocaleExtent` with empty axis with non-zero
        halo for all axes.
        """
        self.do_test_construct_empty_with_axis(halo=((1, 2), (3, 4), (3, 2)))

    def test_repr(self):
        """
        Tests :meth:`mpi_array.distribution.LocaleExtent.__repr__`.
        """
        le = \
            LocaleExtent(
                start=(25, 25),
                stop=(50, 50),
                peer_rank=0,
                inter_locale_rank=0,
                globale_extent=GlobaleExtent(start=(0, 0), stop=(50, 50)),
                halo=(2, 2)
            )
        le_repr = repr(le)
        le_eval = eval(le_repr)

        self.assertEqual(le, le_eval)

        le_str = str(le)
        self.assertEqual(le_repr, le_str)


[docs]class CartLocaleExtentTest(_unittest.TestCase): """ :obj:`unittest.TestCase` for :obj:`mpi_array.distribution.CartLocaleExtent`. """
[docs] def test_construct_attribs(self): """ Assertions for properties. """ de = \ CartLocaleExtent( peer_rank=0, inter_locale_rank=0, cart_coord=(0,), cart_shape=(1,), globale_extent=GlobaleExtent(stop=(100,)), slice=(slice(0, 100),), halo=((10, 10),) ) self.assertEqual(0, de.peer_rank) self.assertEqual(0, de.cart_rank) self.assertTrue(_np.all(de.cart_coord == (0,))) self.assertTrue(_np.all(de.cart_shape == (1,))) self.assertTrue(_np.all(de.halo == 0)) de = \ CartLocaleExtent( peer_rank=56, inter_locale_rank=7, cart_coord=(7,), cart_shape=(8,), globale_extent=GlobaleExtent(stop=(640,)), slice=(slice(560, 640),), halo=((10, 10),) ) self.assertEqual(56, de.peer_rank) self.assertEqual(7, de.cart_rank)
[docs] def test_repr(self): """ Tests :meth:`mpi_array.distribution.CartLocaleExtent.__repr__`. """ cle = \ CartLocaleExtent( start=(25, 25), stop=(50, 50), cart_coord=(3, 3), cart_shape=(4, 4), peer_rank=0, inter_locale_rank=0, globale_extent=GlobaleExtent(start=(0, 0), stop=(50, 50)), halo=(2, 2) ) cle_repr = repr(cle) cle_eval = eval(cle_repr) self.assertEqual(cle, cle_eval) cle_str = str(cle) self.assertEqual(cle_repr, cle_str)
[docs] def test_extent_calcs_1d_thick_tiles(self): """ Tests :meth:`mpi_array.distribution.CartLocaleExtent.halo_slab_extent` and :meth:`mpi_array.distribution.CartLocaleExtent.no_halo_extent` methods when halo size is smaller than the tile size. """ halo = ((10, 10),) splt = _array_split.shape_split((300,), axis=(3,), halo=0) de = \ [ CartLocaleExtent( peer_rank=r, inter_locale_rank=r, cart_coord=(r,), cart_shape=(splt.shape[0],), globale_extent=GlobaleExtent(stop=(300,)), slice=splt[r], halo=halo ) for r in range(0, splt.shape[0]) ] self.assertEqual(0, de[0].cart_rank) self.assertTrue(_np.all(de[0].cart_coord == (0,))) self.assertTrue(_np.all(de[0].cart_shape == (3,))) self.assertSequenceEqual(de[0].halo.tolist(), [[0, 10], ]) self.assertEqual( IndexingExtent(start=(0,), stop=(0,)), de[0].halo_slab_extent(0, de[0].LO) ) self.assertEqual( IndexingExtent(start=(100,), stop=(110,)), de[0].halo_slab_extent(0, de[0].HI) ) self.assertEqual( IndexingExtent(start=(0,), stop=(100,)), de[0].no_halo_extent(0) ) self.assertEqual(1, de[1].cart_rank) self.assertTrue(_np.all(de[1].cart_coord == (1,))) self.assertTrue(_np.all(de[1].cart_shape == (3,))) self.assertTrue(_np.all(de[1].halo == ((10, 10),))) self.assertEqual( IndexingExtent(start=(90,), stop=(100,)), de[1].halo_slab_extent(0, de[0].LO) ) self.assertEqual( IndexingExtent(start=(200,), stop=(210,)), de[1].halo_slab_extent(0, de[0].HI) ) self.assertEqual( IndexingExtent(start=(100,), stop=(200,)), de[1].no_halo_extent(0) ) self.assertEqual(2, de[2].cart_rank) self.assertTrue(_np.all(de[2].cart_coord == (2,))) self.assertTrue(_np.all(de[2].cart_shape == (3,))) self.assertTrue(_np.all(de[2].halo == ((10, 0),))) self.assertEqual( IndexingExtent(start=(190,), stop=(200,)), de[2].halo_slab_extent(0, de[0].LO) ) self.assertEqual( IndexingExtent(start=(300,), stop=(300,)), de[2].halo_slab_extent(0, de[0].HI) ) self.assertEqual( IndexingExtent(start=(200,), stop=(300,)), de[2].no_halo_extent(0) )
[docs] def test_extent_calcs_1d_thin_tiles(self): """ Tests :meth:`mpi_array.distribution.CartLocaleExtent.halo_slab_extent` and :meth:`mpi_array.distribution.CartLocaleExtent.no_halo_extent` methods when halo size is larger than the tile size, 1D fixture. """ halo = ((5, 5),) splt = _array_split.shape_split((15,), axis=(5,), halo=0) de = \ [ CartLocaleExtent( peer_rank=r, inter_locale_rank=r, cart_coord=(r,), cart_shape=(splt.shape[0],), globale_extent=GlobaleExtent(stop=(15,)), slice=splt[r], halo=halo ) for r in range(0, splt.shape[0]) ] self.assertEqual(0, de[0].cart_rank) self.assertTrue(_np.all(de[0].cart_coord == (0,))) self.assertTrue(_np.all(de[0].cart_shape == (5,))) self.assertTrue(_np.all(de[0].halo == ((0, 5),))) self.assertEqual( IndexingExtent(start=(0,), stop=(0,)), de[0].halo_slab_extent(0, de[0].LO) ) self.assertEqual( IndexingExtent(start=(3,), stop=(8,)), de[0].halo_slab_extent(0, de[0].HI) ) self.assertEqual( IndexingExtent(start=(0,), stop=(3,)), de[0].no_halo_extent(0) ) self.assertEqual(1, de[1].cart_rank) self.assertTrue(_np.all(de[1].cart_coord == (1,))) self.assertTrue(_np.all(de[1].cart_shape == (5,))) self.assertTrue(_np.all(de[1].halo == ((3, 5),))) self.assertEqual( IndexingExtent(start=(0,), stop=(3,)), de[1].halo_slab_extent(0, de[0].LO) ) self.assertEqual( IndexingExtent(start=(6,), stop=(11,)), de[1].halo_slab_extent(0, de[0].HI) ) self.assertEqual( IndexingExtent(start=(3,), stop=(6,)), de[1].no_halo_extent(0) ) self.assertEqual(2, de[2].cart_rank) self.assertTrue(_np.all(de[2].cart_coord == (2,))) self.assertTrue(_np.all(de[2].cart_shape == (5,))) self.assertTrue(_np.all(de[2].halo == ((5, 5),))) self.assertEqual( IndexingExtent(start=(1,), stop=(6,)), de[2].halo_slab_extent(0, de[0].LO) ) self.assertEqual( IndexingExtent(start=(9,), stop=(14,)), de[2].halo_slab_extent(0, de[0].HI) ) self.assertEqual( IndexingExtent(start=(6,), stop=(9,)), de[2].no_halo_extent(0) ) self.assertEqual(3, de[3].cart_rank) self.assertTrue(_np.all(de[3].cart_coord == (3,))) self.assertTrue(_np.all(de[3].cart_shape == (5,))) self.assertTrue(_np.all(de[3].halo == ((5, 3),))) self.assertEqual( IndexingExtent(start=(4,), stop=(9,)), de[3].halo_slab_extent(0, de[0].LO) ) self.assertEqual( IndexingExtent(start=(12,), stop=(15,)), de[3].halo_slab_extent(0, de[0].HI) ) self.assertEqual( IndexingExtent(start=(9,), stop=(12,)), de[3].no_halo_extent(0) ) self.assertEqual(4, de[4].cart_rank) self.assertTrue(_np.all(de[4].cart_coord == (4,))) self.assertTrue(_np.all(de[4].cart_shape == (5,))) self.assertTrue(_np.all(de[4].halo == ((5, 0),))) self.assertEqual( IndexingExtent(start=(7,), stop=(12,)), de[4].halo_slab_extent(0, de[0].LO) ) self.assertEqual( IndexingExtent(start=(15,), stop=(15,)), de[4].halo_slab_extent(0, de[0].HI) ) self.assertEqual( IndexingExtent(start=(12,), stop=(15,)), de[4].no_halo_extent(0) )
[docs] def test_extent_calcs_2d_thick_tiles(self): """ Tests :meth:`mpi_array.distribution.CartLocaleExtent.halo_slab_extent` and :meth:`mpi_array.distribution.CartLocaleExtent.no_halo_extent` methods when halo size is smaller than the tile size, 2D fixture. """ halo = ((10, 10), (5, 5)) splt = _array_split.shape_split((300, 600), axis=(3, 3), halo=0) de = \ [ CartLocaleExtent( peer_rank=r, inter_locale_rank=r, cart_coord=_np.unravel_index(r, splt.shape), cart_shape=splt.shape, globale_extent=GlobaleExtent(stop=(300, 600)), slice=splt[tuple(_np.unravel_index(r, splt.shape))], halo=halo ) for r in range(0, _np.product(splt.shape)) ] self.assertEqual(0, de[0].cart_rank) self.assertTrue(_np.all(de[0].cart_coord == (0, 0))) self.assertTrue(_np.all(de[0].cart_shape == (3, 3))) self.assertSequenceEqual([[0, 10], [0, 5]], de[0].halo.tolist()) self.assertEqual( IndexingExtent(start=(0, 0), stop=(0, 205)), de[0].halo_slab_extent(0, de[0].LO) ) self.assertEqual( IndexingExtent(start=(100, 0), stop=(110, 205)), de[0].halo_slab_extent(0, de[0].HI) ) self.assertEqual( IndexingExtent(start=(0, 0), stop=(110, 0)), de[0].halo_slab_extent(1, de[0].LO) ) self.assertEqual( IndexingExtent(start=(0, 200), stop=(110, 205)), de[0].halo_slab_extent(1, de[0].HI) ) self.assertEqual( IndexingExtent(start=(0, 0), stop=(100, 205)), de[0].no_halo_extent(0) ) self.assertEqual( IndexingExtent(start=(0, 0), stop=(110, 200)), de[0].no_halo_extent(1) ) self.assertEqual(1, de[1].cart_rank) self.assertTrue(_np.all(de[1].cart_coord == (0, 1))) self.assertTrue(_np.all(de[1].cart_shape == (3, 3))) self.assertSequenceEqual([[0, 10], [5, 5]], de[1].halo.tolist()) self.assertEqual( IndexingExtent(start=(0, 195), stop=(0, 405)), de[1].halo_slab_extent(0, de[1].LO) ) self.assertEqual( IndexingExtent(start=(100, 195), stop=(110, 405)), de[1].halo_slab_extent(0, de[1].HI) ) self.assertEqual( IndexingExtent(start=(0, 195), stop=(110, 200)), de[1].halo_slab_extent(1, de[0].LO) ) self.assertEqual( IndexingExtent(start=(0, 400), stop=(110, 405)), de[1].halo_slab_extent(1, de[0].HI) ) self.assertEqual( IndexingExtent(start=(0, 195), stop=(100, 405)), de[1].no_halo_extent(0) ) self.assertEqual( IndexingExtent(start=(0, 200), stop=(110, 400)), de[1].no_halo_extent(1) ) self.assertEqual(2, de[2].cart_rank) self.assertTrue(_np.all(de[2].cart_coord == (0, 2))) self.assertTrue(_np.all(de[2].cart_shape == (3, 3))) self.assertSequenceEqual([[0, 10], [5, 0]], de[2].halo.tolist()) self.assertEqual( IndexingExtent(start=(0, 395), stop=(0, 600)), de[2].halo_slab_extent(0, de[2].LO) ) self.assertEqual( IndexingExtent(start=(100, 395), stop=(110, 600)), de[2].halo_slab_extent(0, de[2].HI) ) self.assertEqual( IndexingExtent(start=(0, 395), stop=(110, 400)), de[2].halo_slab_extent(1, de[0].LO) ) self.assertEqual( IndexingExtent(start=(0, 600), stop=(110, 600)), de[2].halo_slab_extent(1, de[0].HI) ) self.assertEqual( IndexingExtent(start=(0, 395), stop=(100, 600)), de[2].no_halo_extent(0) ) self.assertEqual( IndexingExtent(start=(0, 400), stop=(110, 600)), de[2].no_halo_extent(1) ) self.assertEqual(3, de[3].cart_rank) self.assertTrue(_np.all(de[3].cart_coord == (1, 0))) self.assertTrue(_np.all(de[3].cart_shape == (3, 3))) self.assertSequenceEqual([[10, 10], [0, 5]], de[3].halo.tolist()) self.assertEqual( IndexingExtent(start=(90, 0), stop=(100, 205)), de[3].halo_slab_extent(0, de[3].LO) ) self.assertEqual( IndexingExtent(start=(200, 0), stop=(210, 205)), de[3].halo_slab_extent(0, de[3].HI) ) self.assertEqual( IndexingExtent(start=(90, 0), stop=(210, 0)), de[3].halo_slab_extent(1, de[3].LO) ) self.assertEqual( IndexingExtent(start=(90, 200), stop=(210, 205)), de[3].halo_slab_extent(1, de[3].HI) ) self.assertEqual( IndexingExtent(start=(100, 0), stop=(200, 205)), de[3].no_halo_extent(0) ) self.assertEqual( IndexingExtent(start=(90, 0), stop=(210, 200)), de[3].no_halo_extent(1) ) self.assertEqual(4, de[4].cart_rank) self.assertTrue(_np.all(de[4].cart_coord == (1, 1))) self.assertTrue(_np.all(de[4].cart_shape == (3, 3))) self.assertSequenceEqual([[10, 10], [5, 5]], de[4].halo.tolist()) self.assertEqual( IndexingExtent(start=(90, 195), stop=(100, 405)), de[4].halo_slab_extent(0, de[4].LO) ) self.assertEqual( IndexingExtent(start=(200, 195), stop=(210, 405)), de[4].halo_slab_extent(0, de[4].HI) ) self.assertEqual( IndexingExtent(start=(90, 195), stop=(210, 200)), de[4].halo_slab_extent(1, de[0].LO) ) self.assertEqual( IndexingExtent(start=(90, 400), stop=(210, 405)), de[4].halo_slab_extent(1, de[0].HI) ) self.assertEqual( IndexingExtent(start=(100, 195), stop=(200, 405)), de[4].no_halo_extent(0) ) self.assertEqual( IndexingExtent(start=(90, 200), stop=(210, 400)), de[4].no_halo_extent(1) ) self.assertEqual(5, de[5].cart_rank) self.assertTrue(_np.all(de[5].cart_coord == (1, 2))) self.assertTrue(_np.all(de[5].cart_shape == (3, 3))) self.assertSequenceEqual([[10, 10], [5, 0]], de[5].halo.tolist()) self.assertEqual( IndexingExtent(start=(90, 395), stop=(100, 600)), de[5].halo_slab_extent(0, de[5].LO) ) self.assertEqual( IndexingExtent(start=(200, 395), stop=(210, 600)), de[5].halo_slab_extent(0, de[5].HI) ) self.assertEqual( IndexingExtent(start=(90, 395), stop=(210, 400)), de[5].halo_slab_extent(1, de[0].LO) ) self.assertEqual( IndexingExtent(start=(90, 600), stop=(210, 600)), de[5].halo_slab_extent(1, de[0].HI) ) self.assertEqual( IndexingExtent(start=(100, 395), stop=(200, 600)), de[5].no_halo_extent(0) ) self.assertEqual( IndexingExtent(start=(90, 400), stop=(210, 600)), de[5].no_halo_extent(1) ) self.assertEqual(6, de[6].cart_rank) self.assertTrue(_np.all(de[6].cart_coord == (2, 0))) self.assertTrue(_np.all(de[6].cart_shape == (3, 3))) self.assertSequenceEqual([[10, 0], [0, 5]], de[6].halo.tolist()) self.assertEqual( IndexingExtent(start=(190, 0), stop=(200, 205)), de[6].halo_slab_extent(0, de[6].LO) ) self.assertEqual( IndexingExtent(start=(300, 0), stop=(300, 205)), de[6].halo_slab_extent(0, de[6].HI) ) self.assertEqual( IndexingExtent(start=(190, 0), stop=(300, 0)), de[6].halo_slab_extent(1, de[6].LO) ) self.assertEqual( IndexingExtent(start=(190, 200), stop=(300, 205)), de[6].halo_slab_extent(1, de[6].HI) ) self.assertEqual( IndexingExtent(start=(200, 0), stop=(300, 205)), de[6].no_halo_extent(0) ) self.assertEqual( IndexingExtent(start=(190, 0), stop=(300, 200)), de[6].no_halo_extent(1) ) self.assertEqual(7, de[7].cart_rank) self.assertTrue(_np.all(de[7].cart_coord == (2, 1))) self.assertTrue(_np.all(de[7].cart_shape == (3, 3))) self.assertSequenceEqual([[10, 0], [5, 5]], de[7].halo.tolist()) self.assertEqual( IndexingExtent(start=(190, 195), stop=(200, 405)), de[7].halo_slab_extent(0, de[7].LO) ) self.assertEqual( IndexingExtent(start=(300, 195), stop=(300, 405)), de[7].halo_slab_extent(0, de[7].HI) ) self.assertEqual( IndexingExtent(start=(190, 195), stop=(300, 200)), de[7].halo_slab_extent(1, de[0].LO) ) self.assertEqual( IndexingExtent(start=(190, 400), stop=(300, 405)), de[7].halo_slab_extent(1, de[0].HI) ) self.assertEqual( IndexingExtent(start=(200, 195), stop=(300, 405)), de[7].no_halo_extent(0) ) self.assertEqual( IndexingExtent(start=(190, 200), stop=(300, 400)), de[7].no_halo_extent(1) ) self.assertEqual(8, de[8].cart_rank) self.assertTrue(_np.all(de[8].cart_coord == (2, 2))) self.assertTrue(_np.all(de[8].cart_shape == (3, 3))) self.assertSequenceEqual([[10, 0], [5, 0]], de[8].halo.tolist()) self.assertEqual( IndexingExtent(start=(190, 395), stop=(200, 600)), de[8].halo_slab_extent(0, de[8].LO) ) self.assertEqual( IndexingExtent(start=(300, 395), stop=(300, 600)), de[8].halo_slab_extent(0, de[8].HI) ) self.assertEqual( IndexingExtent(start=(190, 395), stop=(300, 400)), de[8].halo_slab_extent(1, de[0].LO) ) self.assertEqual( IndexingExtent(start=(190, 600), stop=(300, 600)), de[8].halo_slab_extent(1, de[0].HI) ) self.assertEqual( IndexingExtent(start=(200, 395), stop=(300, 600)), de[8].no_halo_extent(0) ) self.assertEqual( IndexingExtent(start=(190, 400), stop=(300, 600)), de[8].no_halo_extent(1) )
class DistributionTest(_unittest.TestCase): """ :obj:`unittest.TestCase` for :obj:`mpi_array.distribution.Distribution`. """ def test_invalid_args(self): """ Tests for :meth:`mpi_array.distribution.Distribution.__init__` """ globale_extent = IndexingExtent(start=(0, 0, 0), stop=(100, 200, 300)) locale_extent = IndexingExtent(start=(0, 0, 0), stop=(100, 200, 300)) self.assertRaises( ValueError, Distribution, globale_extent=1, locale_extents=[locale_extent, ] ) self.assertRaises( ValueError, Distribution, globale_extent=globale_extent, locale_extents=[1, ] ) def test_construct(self): """ Tests for :meth:`mpi_array.distribution.Distribution.__init__` """ globale_extent = IndexingExtent(start=(0, 0, 0), stop=(100, 200, 300)) locale_extent = IndexingExtent(start=(0, 0, 0), stop=(100, 200, 300)) d = \ Distribution( globale_extent=globale_extent, locale_extents=[locale_extent, ] ) self.assertEqual( GlobaleExtent(start=globale_extent.start, stop=globale_extent.stop), d.globale_extent ) self.assertEqual(_mpi.UNDEFINED, d.get_peer_rank(0)) self.assertEqual( LocaleExtent( peer_rank=_mpi.UNDEFINED, inter_locale_rank=0, start=locale_extent.start, stop=locale_extent.stop, globale_extent=None ), d.locale_extents[0] ) d = \ Distribution( globale_extent=globale_extent.to_slice(), locale_extents=[locale_extent.to_slice(), ] ) self.assertEqual( GlobaleExtent(start=globale_extent.start, stop=globale_extent.stop), d.globale_extent ) self.assertEqual(_mpi.UNDEFINED, d.get_peer_rank(0)) self.assertEqual( LocaleExtent( peer_rank=_mpi.UNDEFINED, inter_locale_rank=0, start=locale_extent.start, stop=locale_extent.stop, globale_extent=None ), d.locale_extents[0] ) self.assertEqual( LocaleExtent( peer_rank=_mpi.UNDEFINED, inter_locale_rank=0, start=locale_extent.start, stop=locale_extent.stop, globale_extent=None ), d.get_extent_for_rank(0) ) le = \ LocaleExtent( peer_rank=_mpi.UNDEFINED, inter_locale_rank=0, start=locale_extent.start, stop=locale_extent.stop, globale_extent=None ) d = \ Distribution( globale_extent=globale_extent.to_slice(), locale_extents=[le, ], inter_locale_rank_to_peer_rank=[5] ) self.assertEqual( GlobaleExtent(start=globale_extent.start, stop=globale_extent.stop), d.globale_extent ) self.assertEqual(5, d.get_peer_rank(0)) self.assertEqual( LocaleExtent( peer_rank=5, inter_locale_rank=0, start=le.start, stop=le.stop, globale_extent=None ), d.get_extent_for_rank(0) )
[docs]class BlockPartitionTest(_unittest.TestCase): """ :obj:`unittest.TestCase` for :obj:`mpi_array.distribution.BlockPartition`. """
[docs] def setUp(self): """ Initialise self.root_logger. """ self.root_logger = _logging.get_root_logger(__name__ + "." + self.id()) self.rank_logger = _logging.get_rank_logger(__name__ + "." + self.id())
[docs] def test_construct_single_locale_1d(self): """ Test :obj:`mpi_array.distribution.BlockPartition` construction. """ distrib = \ BlockPartition( globale_extent=(8,), dims=[1, ], cart_coord_to_cart_rank={(0,): 0} ) self.assertEqual(1, len(distrib.locale_extents)) self.assertEqual(GlobaleExtent(stop=(8,)), distrib.globale_extent) self.assertEqual(1, distrib.num_locales) self.root_logger.info("START " + self.id()) self.root_logger.info(str(distrib)) self.root_logger.info("END " + self.id()) self.root_logger.info("distrib.locale_extents[0]=\n%s" % (distrib.locale_extents[0],)) self.assertEqual( CartLocaleExtent( peer_rank=_mpi.UNDEFINED, inter_locale_rank=0, globale_extent=GlobaleExtent(stop=(8,)), cart_coord=(0,), cart_shape=(1,), start=(0,), stop=(8,) ), distrib.locale_extents[0] ) distrib = \ BlockPartition( globale_extent=distrib.globale_extent, dims=[4, ], cart_coord_to_cart_rank={(i,): i for i in range(0, 4)} ) self.assertEqual(4, len(distrib.locale_extents)) self.assertEqual(GlobaleExtent(stop=(8,)), distrib.globale_extent) self.assertEqual(4, distrib.num_locales) self.root_logger.info("START " + self.id()) self.root_logger.info(str(distrib)) self.root_logger.info("END " + self.id())
[docs] def do_test_construct_1d_with_halo(self, halo=0): """ Test :obj:`mpi_array.distribution.BlockPartition` construction. """ distrib = \ BlockPartition( (32,), dims=[4, ], halo=halo, cart_coord_to_cart_rank={(i,): i for i in range(0, 4)} ) valid_extent = \ CartLocaleExtent( peer_rank=_mpi.UNDEFINED, inter_locale_rank=0, globale_extent=GlobaleExtent(stop=(32,)), cart_coord=(0,), cart_shape=(4,), start=(0,), stop=(8,), halo=halo ) self.rank_logger.debug("valid_extent=\n%s" % (valid_extent,)) self.rank_logger.debug("distrib.locale_extents[0]=\n%s" % (distrib.locale_extents[0],)) self.assertEqual( valid_extent, distrib.locale_extents[0] ) valid_extent = \ CartLocaleExtent( peer_rank=_mpi.UNDEFINED, inter_locale_rank=1, globale_extent=GlobaleExtent(stop=(32,)), cart_coord=(1,), cart_shape=(4,), start=(8,), stop=(16,), halo=halo ) self.rank_logger.debug("valid_extent=\n%s" % (valid_extent,)) self.rank_logger.debug("distrib.locale_extents[1]=\n%s" % (distrib.locale_extents[1],)) self.assertEqual( valid_extent, distrib.locale_extents[1] ) self.assertEqual( CartLocaleExtent( peer_rank=_mpi.UNDEFINED, inter_locale_rank=2, globale_extent=GlobaleExtent(stop=(32,)), cart_coord=(2,), cart_shape=(4,), start=(16,), stop=(24,), halo=halo ), distrib.locale_extents[2] ) self.assertEqual( CartLocaleExtent( peer_rank=_mpi.UNDEFINED, inter_locale_rank=3, globale_extent=GlobaleExtent(stop=(32,)), cart_coord=(3,), cart_shape=(4,), start=(24,), stop=(32,), halo=halo ), distrib.locale_extents[3] ) self.root_logger.info("START " + self.id()) self.root_logger.info(str(distrib)) self.root_logger.info("END " + self.id())
[docs] def test_construct_1d_no_halo(self): """ Test :obj:`mpi_array.distribution.BlockPartition` construction. """ self.do_test_construct_1d_with_halo(halo=0)
[docs] def test_construct_1d_with_halo(self): """ Test :obj:`mpi_array.distribution.BlockPartition` construction. """ self.do_test_construct_1d_with_halo(halo=[[2, 4], ])
[docs] def test_construct_1d_empty_tiles(self): """ Test :obj:`mpi_array.distribution.BlockPartition` construction when the partition leads to empty extents. """ halo = 0 distrib = \ BlockPartition( globale_extent=(slice(0, 2),), dims=(4,), halo=halo, cart_coord_to_cart_rank={(i,): i for i in range(0, 4)} ) self.root_logger.info("START " + self.id()) self.root_logger.info(str(distrib)) self.root_logger.info("END " + self.id()) self.assertEqual( CartLocaleExtent( peer_rank=_mpi.UNDEFINED, inter_locale_rank=0, globale_extent=GlobaleExtent(stop=(2,)), cart_coord=(0,), cart_shape=(4,), start=(0,), stop=(1,), halo=halo ), distrib.locale_extents[0] ) self.assertEqual( CartLocaleExtent( peer_rank=_mpi.UNDEFINED, inter_locale_rank=1, globale_extent=GlobaleExtent(stop=(2,)), cart_coord=(1,), cart_shape=(4,), start=(1,), stop=(2,), halo=halo ), distrib.locale_extents[1] ) self.assertEqual( CartLocaleExtent( peer_rank=_mpi.UNDEFINED, inter_locale_rank=2, globale_extent=GlobaleExtent(stop=(2,)), cart_coord=(2,), cart_shape=(4,), start=(2,), stop=(2,), halo=halo ), distrib.locale_extents[2] ) self.assertEqual( CartLocaleExtent( peer_rank=_mpi.UNDEFINED, inter_locale_rank=3, globale_extent=GlobaleExtent(stop=(2,)), cart_coord=(3,), cart_shape=(4,), start=(2,), stop=(2,), halo=halo ), distrib.locale_extents[3] )
[docs] def do_test_construct_2d_with_halo(self, halo=0): """ Test :obj:`mpi_array.distribution.BlockPartition` construction. """ distrib = \ BlockPartition( globale_extent=(16, 32), dims=(2, 4), halo=halo, cart_coord_to_cart_rank={ tuple(_np.unravel_index(i, (2, 4))): i for i in range(0, 8) } ) self.root_logger.info("START " + self.id()) self.root_logger.info(str(distrib)) self.root_logger.info("END " + self.id()) self.assertEqual(8, distrib.num_locales) self.assertEqual(8, len(distrib.locale_extents)) self.assertEqual( CartLocaleExtent( peer_rank=_mpi.UNDEFINED, inter_locale_rank=0, globale_extent=GlobaleExtent(stop=(16, 32,)), cart_coord=(0, 0), cart_shape=(2, 4), start=(0, 0), stop=(8, 8), halo=halo ), distrib.locale_extents[0] ) self.assertEqual( CartLocaleExtent( peer_rank=_mpi.UNDEFINED, inter_locale_rank=7, globale_extent=GlobaleExtent(stop=(16, 32,)), cart_coord=(1, 3), cart_shape=(2, 4), start=(8, 24), stop=(16, 32), halo=halo ), distrib.locale_extents[7] )
[docs] def test_construct_2d_no_halo(self): """ Test :obj:`mpi_array.distribution.BlockPartition` construction. """ self.do_test_construct_2d_with_halo(halo=0)
[docs] def test_construct_2d_with_halo(self): """ Test :obj:`mpi_array.distribution.BlockPartition` construction. """ self.do_test_construct_2d_with_halo(halo=[[1, 2], [3, 4]])
_unittest.main(__name__) __all__ = [s for s in dir() if not s.startswith('_')]