Source code for mpi_array.indexing_test

"""
=========================================
The :mod:`mpi_array.indexing_test` Module
=========================================

Module defining :mod:`mpi_array.indexing` unit-tests.
Execute as::

   python -m mpi_array.indexing_test


Classes
=======

.. autosummary::
   :toctree: generated/
   :template: autosummary/inherits_TestCase_class.rst

   IndexingExtentTest - Tests for :obj:`mpi_array.indexing.IndexingExtent`.
   HaloIndexingExtentTest - Tests for :obj:`mpi_array.indexing.HaloIndexingExtent`.


"""
from __future__ import absolute_import

import numpy as _np  # noqa: E402,F401

from .license import license as _license, copyright as _copyright, version as _version
from . import unittest as _unittest
from . import logging as _logging  # noqa: E402,F401
from .indexing import IndexingExtent, HaloIndexingExtent, calc_intersection_split


__author__ = "Shane J. Latham"
__license__ = _license()
__copyright__ = _copyright()
__version__ = _version


class IndexingExtentTest(_unittest.TestCase):

    """
    :obj:`unittest.TestCase` for :obj:`mpi_array.indexing.IndexingExtent`.
    """

    def _assert_extent(self, extent, start, stop, shape):
        """
        Asserts that :samp:`{extent}` has exactly the given per-axis
        :samp:`{shape}`, :samp:`{start}` and :samp:`{stop}` values.

        The comparison is element-wise on lists (no broadcasting), so an
        extent of the wrong dimension fails the assertion.

        :param extent: Object with :samp:`shape`, :samp:`start` and :samp:`stop`
           attributes which provide a :samp:`tolist()` method.
        :param start: Expected sequence of per-axis start indices.
        :param stop: Expected sequence of per-axis stop indices.
        :param shape: Expected sequence of per-axis sizes.
        """
        self.assertSequenceEqual(list(shape), extent.shape.tolist())
        self.assertSequenceEqual(list(start), extent.start.tolist())
        self.assertSequenceEqual(list(stop), extent.stop.tolist())

    def test_repr(self):
        """
        Test for :samp:`str` and :samp:`repr` of
        :samp:`IndexingExtent(start=(10,), stop=(32,))`.
        """
        ie = IndexingExtent(start=(10,), stop=(32,))
        self.assertIsNotNone(str(ie))
        self.assertNotEqual("", str(ie))
        # eval is applied only to the repr of the object constructed above,
        # checking that repr round-trips to an equal extent.
        self.assertEqual(ie, eval(repr(ie)))

    def test_to_tuple(self):
        """
        Test for :meth:`IndexingExtent.to_tuple`, the returned tuple
        is expected to re-construct an equal :obj:`IndexingExtent`.
        """
        ie = IndexingExtent(start=(10, 15), stop=(32, 66))
        self.assertEqual(ie, IndexingExtent(*(ie.to_tuple())))

    def test_assign_different_dimension_index(self):
        """
        Test that assigning a 1D index to the :samp:`start` or :samp:`stop`
        attribute of a 2D :obj:`IndexingExtent` raises :obj:`ValueError`.
        """
        ie = IndexingExtent(start=(10, 15), stop=(32, 66))

        with self.assertRaises(ValueError):
            ie.start = (1,)

        with self.assertRaises(ValueError):
            ie.stop = (1,)

    def test_attributes(self):
        """
        Tests :attr:`mpi_array.indexing.IndexingExtent.start`
        and :attr:`mpi_array.indexing.IndexingExtent.stop`
        and :attr:`mpi_array.indexing.IndexingExtent.shape`
        attributes, for both the :samp:`start`/:samp:`stop` and
        the :obj:`slice` sequence forms of construction.
        """
        # 1D, keyword and slice construction.
        ie = IndexingExtent(start=(10,), stop=(32,))
        self._assert_extent(ie, start=(10,), stop=(32,), shape=(22,))

        ie = IndexingExtent((slice(10, 32),))
        self._assert_extent(ie, start=(10,), stop=(32,), shape=(22,))

        # 2D, keyword and slice construction.
        ie = IndexingExtent(start=(10, 25), stop=(32, 55))
        self._assert_extent(ie, start=(10, 25), stop=(32, 55), shape=(22, 30))

        ie = IndexingExtent((slice(10, 32), slice(25, 55)))
        self._assert_extent(ie, start=(10, 25), stop=(32, 55), shape=(22, 30))

        # Attribute assignment.
        ie = IndexingExtent((slice(10, 32), slice(25, 55)))
        ie.start = (5, 6)
        self.assertSequenceEqual([5, 6], ie.start.tolist())
        ie.stop = (11, 12)
        self.assertSequenceEqual([11, 12], ie.stop.tolist())

    def test_intersection_1d(self):
        """
        Tests :meth:`mpi_array.indexing.IndexingExtent.calc_intersection` method, 1D indexing.
        """
        ie0 = IndexingExtent(start=(10,), stop=(32,))

        # Intersection with self.
        iei = ie0.calc_intersection(ie0)
        self._assert_extent(iei, start=(10,), stop=(32,), shape=(22,))

        # Other extent covers all of ie0: lower start, same stop.
        ie1 = IndexingExtent(start=(5,), stop=(32,))
        iei = ie0.calc_intersection(ie1)
        self._assert_extent(iei, start=(10,), stop=(32,), shape=(22,))

        # Other extent covers all of ie0: same start, higher stop.
        ie1 = IndexingExtent(start=(10,), stop=(39,))
        iei = ie0.calc_intersection(ie1)
        self._assert_extent(iei, start=(10,), stop=(32,), shape=(22,))

        # Other extent covers all of ie0: lower start and higher stop.
        ie1 = IndexingExtent(start=(-5,), stop=(39,))
        iei = ie0.calc_intersection(ie1)
        self._assert_extent(iei, start=(10,), stop=(32,), shape=(22,))

        # Other extent lies strictly inside ie0.
        ie1 = IndexingExtent(start=(11,), stop=(31,))
        iei = ie0.calc_intersection(ie1)
        self._assert_extent(iei, start=(11,), stop=(31,), shape=(20,))

        # Other extent ends where ie0 starts: no intersection.
        ie1 = IndexingExtent(start=(5,), stop=(10,))
        iei = ie0.calc_intersection(ie1)
        self.assertIsNone(iei)

        # Other extent starts where ie0 stops: no intersection.
        ie1 = IndexingExtent(start=(32,), stop=(55,))
        iei = ie0.calc_intersection(ie1)
        self.assertIsNone(iei)

    def test_intersection_2d(self):
        """
        Tests :meth:`mpi_array.indexing.IndexingExtent.calc_intersection` method, 2D indexing.
        """
        ie0 = IndexingExtent(start=(10, 20), stop=(32, 64))

        # Intersection with self.
        iei = ie0.calc_intersection(ie0)
        self._assert_extent(
            iei,
            start=ie0.start.tolist(),
            stop=ie0.stop.tolist(),
            shape=ie0.shape.tolist()
        )

        # Extents which only touch ie0 on a face or at a corner: no intersection.
        ie1 = IndexingExtent(start=(0, 20), stop=(10, 64))
        iei = ie0.calc_intersection(ie1)
        self.assertIsNone(iei)

        ie1 = IndexingExtent(start=(10, 0), stop=(32, 20))
        iei = ie0.calc_intersection(ie1)
        self.assertIsNone(iei)

        ie1 = IndexingExtent(start=(0, 0), stop=(10, 20))
        iei = ie0.calc_intersection(ie1)
        self.assertIsNone(iei)

        ie1 = IndexingExtent(start=(32, 64), stop=(110, 120))
        iei = ie0.calc_intersection(ie1)
        self.assertIsNone(iei)

        # Partial overlaps.
        ie1 = IndexingExtent(start=(20, 10), stop=(30, 40))
        iei = ie0.calc_intersection(ie1)
        self._assert_extent(iei, start=[20, 20], stop=[30, 40], shape=[10, 20])

        ie1 = IndexingExtent(start=(22, 54), stop=(80, 90))
        iei = ie0.calc_intersection(ie1)
        self._assert_extent(iei, start=[22, 54], stop=[32, 64], shape=[10, 10])

    def test_split(self):
        """
        Test for :meth:`mpi_array.indexing.IndexingExtent.split`.
        """
        ie = IndexingExtent(start=(10, 3), stop=(32, 20))

        # Split at the start index of an axis: everything is in the high part.
        lo, hi = ie.split(0, 10)
        self.assertIsNone(lo)
        self.assertIs(hi, ie)

        lo, hi = ie.split(1, 3)
        self.assertIsNone(lo)
        self.assertIs(hi, ie)

        # Split at the stop index of an axis: everything is in the low part.
        lo, hi = ie.split(0, 32)
        self.assertIs(lo, ie)
        self.assertIsNone(hi)

        lo, hi = ie.split(1, 20)
        self.assertIs(lo, ie)
        self.assertIsNone(hi)

        # Splits strictly inside the extent produce two new extents.
        lo, hi = ie.split(0, 11)
        self.assertEqual(IndexingExtent(start=(10, 3), stop=(11, 20)), lo)
        self.assertEqual(IndexingExtent(start=(11, 3), stop=(32, 20)), hi)

        lo, hi = ie.split(1, 4)
        self.assertEqual(IndexingExtent(start=(10, 3), stop=(32, 4)), lo)
        self.assertEqual(IndexingExtent(start=(10, 4), stop=(32, 20)), hi)

        lo, hi = ie.split(0, 31)
        self.assertEqual(IndexingExtent(start=(10, 3), stop=(31, 20)), lo)
        self.assertEqual(IndexingExtent(start=(31, 3), stop=(32, 20)), hi)

        lo, hi = ie.split(1, 19)
        self.assertEqual(IndexingExtent(start=(10, 3), stop=(32, 19)), lo)
        self.assertEqual(IndexingExtent(start=(10, 19), stop=(32, 20)), hi)

    def test_calc_intersection_split(self):
        """
        Test for :meth:`mpi_array.indexing.IndexingExtent.calc_intersection_split`.
        """
        ie = IndexingExtent(start=(0, 50), stop=(50, 100))

        # Identical extent: nothing is left over.
        other = IndexingExtent(start=(0, 50), stop=(50, 100))
        leftovers, intersection = ie.calc_intersection_split(other)
        self.assertEqual(0, len(leftovers))
        self.assertEqual(intersection, ie)
        self.assertEqual(intersection, other)

        # Each case is (other-extent, expected leftover extents in returned order),
        # in every case the intersection is expected to equal the other-extent.
        cases = (
            (
                IndexingExtent(start=(25, 50), stop=(50, 100)),
                (IndexingExtent(start=(0, 50), stop=(25, 100)),),
            ),
            (
                IndexingExtent(start=(0, 50), stop=(25, 100)),
                (IndexingExtent(start=(25, 50), stop=(50, 100)),),
            ),
            (
                IndexingExtent(start=(0, 50), stop=(50, 75)),
                (IndexingExtent(start=(0, 75), stop=(50, 100)),),
            ),
            (
                IndexingExtent(start=(0, 75), stop=(50, 100)),
                (IndexingExtent(start=(0, 50), stop=(50, 75)),),
            ),
            (
                IndexingExtent(start=(0, 50), stop=(25, 75)),
                (
                    IndexingExtent(start=(25, 50), stop=(50, 100)),
                    IndexingExtent(start=(0, 75), stop=(25, 100)),
                ),
            ),
            (
                IndexingExtent(start=(25, 75), stop=(50, 100)),
                (
                    IndexingExtent(start=(0, 50), stop=(25, 100)),
                    IndexingExtent(start=(25, 50), stop=(50, 75)),
                ),
            ),
            (
                IndexingExtent(start=(20, 60), stop=(40, 80)),
                (
                    IndexingExtent(start=(0, 50), stop=(20, 100)),
                    IndexingExtent(start=(40, 50), stop=(50, 100)),
                    IndexingExtent(start=(20, 50), stop=(40, 60)),
                    IndexingExtent(start=(20, 80), stop=(40, 100)),
                ),
            ),
        )
        for other, expected_leftovers in cases:
            leftovers, intersection = ie.calc_intersection_split(other)
            self.assertEqual(intersection, other)
            self.assertEqual(len(expected_leftovers), len(leftovers))
            for idx, expected in enumerate(expected_leftovers):
                self.assertEqual(expected, leftovers[idx])

class HaloIndexingExtentTest(_unittest.TestCase):

    """
    :obj:`unittest.TestCase` for :obj:`mpi_array.indexing.HaloIndexingExtent`.
    """

    def test_repr(self):
        """
        Test for :samp:`str` and :samp:`repr` of
        :samp:`HaloIndexingExtent(start=(10, 15), stop=(32, 66), halo=((1, 2), (3, 4)))`.
        """
        ie = HaloIndexingExtent(start=(10, 15), stop=(32, 66), halo=((1, 2), (3, 4)))
        self.assertIsNotNone(str(ie))
        self.assertNotEqual("", str(ie))
        # eval is applied only to the repr of the object constructed above,
        # checking that repr round-trips to an equal extent.
        self.assertEqual(ie, eval(repr(ie)))

    def test_to_tuple(self):
        """
        Test for :meth:`HaloIndexingExtent.to_tuple`, the returned tuple
        is expected to re-construct an equal :obj:`HaloIndexingExtent`.
        """
        ie = HaloIndexingExtent(start=(10, 15), stop=(32, 66), halo=((1, 2), (3, 4)))
        self.assertEqual(ie, HaloIndexingExtent(*(ie.to_tuple())))

    def test_attributes(self):
        """
        Tests the :samp:`start_n`, :samp:`start_h`, :samp:`stop_n`, :samp:`stop_h`,
        :samp:`shape_n`, :samp:`shape_h`, :samp:`size_n`, :samp:`size_h`
        and :samp:`halo` attributes of :obj:`mpi_array.indexing.HaloIndexingExtent`.
        """
        # Zero halo: the *_h attributes equal the *_n attributes.
        hie1 = HaloIndexingExtent(start=(10, 0), stop=(32, 20), halo=_np.array(((0, 0), (0, 0))))
        self.assertSequenceEqual([10, 0], hie1.start_n.tolist())
        self.assertSequenceEqual([10, 0], hie1.start_h.tolist())
        self.assertSequenceEqual([32, 20], hie1.stop_n.tolist())
        self.assertSequenceEqual([32, 20], hie1.stop_h.tolist())
        self.assertSequenceEqual([22, 20], hie1.shape_n.tolist())
        self.assertSequenceEqual([22, 20], hie1.shape_h.tolist())
        self.assertEqual(22 * 20, hie1.size_n)
        self.assertEqual(22 * 20, hie1.size_h)

        # Non-zero halo: the *_h attributes are the *_n attributes grown by the halo.
        hie1 = HaloIndexingExtent(start=(10, 3), stop=(32, 20), halo=_np.array(((1, 2), (3, 4))))
        self.assertSequenceEqual([10, 3], hie1.start_n.tolist())
        self.assertSequenceEqual([9, 0], hie1.start_h.tolist())
        self.assertSequenceEqual([32, 20], hie1.stop_n.tolist())
        self.assertSequenceEqual([34, 24], hie1.stop_h.tolist())
        self.assertSequenceEqual([22, 17], hie1.shape_n.tolist())
        self.assertSequenceEqual([25, 24], hie1.shape_h.tolist())
        self.assertEqual(22 * 17, hie1.size_n)
        self.assertEqual(25 * 24, hie1.size_h)

        # Attribute assignment.
        ie = HaloIndexingExtent((slice(10, 32), slice(25, 55)))
        ie.start_n = (3, 4)
        self.assertSequenceEqual([3, 4], ie.start_n.tolist())
        self.assertSequenceEqual([3, 4], ie.start.tolist())
        ie.stop_n = (8, 9)
        self.assertSequenceEqual([8, 9], ie.stop_n.tolist())
        self.assertSequenceEqual([8, 9], ie.stop.tolist())

        ie.halo = [[1, 2], [4, 8]]
        self.assertSequenceEqual([[1, 2], [4, 8]], ie.halo.tolist())
        # A scalar halo assignment is expected to apply to every axis and side.
        ie.halo = 0
        self.assertSequenceEqual([[0, 0], [0, 0]], ie.halo.tolist())

    def test_globale_and_locale_index_conversion(self):
        """
        Test for :meth:`mpi_array.indexing.HaloIndexingExtent.globale_to_locale_h`,
        :meth:`mpi_array.indexing.HaloIndexingExtent.locale_to_globale_h`,
        :meth:`mpi_array.indexing.HaloIndexingExtent.globale_to_locale_n`
        and :meth:`mpi_array.indexing.HaloIndexingExtent.locale_to_globale_n`.
        """
        hie = HaloIndexingExtent(start=(10, 3), stop=(32, 20), halo=_np.array(((1, 2), (3, 4))))

        # Halo (_h) variants, including globale -> locale -> globale round trip.
        self.assertSequenceEqual(
            [1, 3],
            list(hie.globale_to_locale_h((10, 3)))
        )
        self.assertSequenceEqual(
            [10, 3],
            list(hie.locale_to_globale_h(hie.globale_to_locale_h((10, 3))))
        )

        # No-halo (_n) variants, including globale -> locale -> globale round trip.
        self.assertSequenceEqual(
            [0, 0],
            list(hie.globale_to_locale_n((10, 3)))
        )
        self.assertSequenceEqual(
            [10, 3],
            list(hie.locale_to_globale_n(hie.globale_to_locale_n((10, 3))))
        )

    def test_globale_and_locale_extent_conversion(self):
        """
        Test for :meth:`mpi_array.indexing.HaloIndexingExtent.globale_to_locale_extent_h`,
        and :meth:`mpi_array.indexing.HaloIndexingExtent.locale_to_globale_extent_h`.
        """
        hie = HaloIndexingExtent(start=(10, 3), stop=(32, 20), halo=_np.array(((1, 2), (3, 4))))

        # Globale to locale, HaloIndexingExtent argument.
        gext = HaloIndexingExtent(start=(10, 3), stop=(32, 20), halo=_np.array(((1, 2), (3, 4))))
        self.assertEqual(
            HaloIndexingExtent(start=(1, 3), stop=(23, 20), halo=_np.array(((1, 2), (3, 4)))),
            hie.globale_to_locale_extent_h(gext)
        )

        # Globale to locale, IndexingExtent argument.
        gext = IndexingExtent(start=(10, 3), stop=(32, 20))
        self.assertEqual(
            IndexingExtent(start=(1, 3), stop=(23, 20)),
            hie.globale_to_locale_extent_h(gext)
        )

        # Locale to globale, HaloIndexingExtent argument.
        lext = HaloIndexingExtent(start=(1, 3), stop=(23, 20), halo=_np.array(((1, 2), (3, 4))))
        self.assertEqual(
            hie.ndim,
            hie.locale_to_globale_extent_h(lext).ndim
        )
        self.assertEqual(
            hie,
            hie.locale_to_globale_extent_h(lext)
        )

        # Locale to globale, IndexingExtent argument.
        lext = IndexingExtent(start=(1, 3), stop=(23, 20))
        self.assertEqual(
            IndexingExtent(start=hie.start, stop=hie.stop),
            hie.locale_to_globale_extent_h(lext)
        )

    def test_globale_and_locale_slice_conversion(self):
        """
        Test for :meth:`mpi_array.indexing.HaloIndexingExtent.globale_to_locale_slice_h`,
        :meth:`mpi_array.indexing.HaloIndexingExtent.locale_to_globale_slice_h`,
        :meth:`mpi_array.indexing.HaloIndexingExtent.globale_to_locale_slice_n`
        and :meth:`mpi_array.indexing.HaloIndexingExtent.locale_to_globale_slice_n`.
        """
        hie = HaloIndexingExtent(start=(10, 3), stop=(32, 20), halo=_np.array(((1, 2), (3, 4))))

        # Globale to locale slices.
        gext = HaloIndexingExtent(start=(10, 3), stop=(32, 20), halo=_np.array(((1, 2), (3, 4))))
        self.assertEqual(
            HaloIndexingExtent(
                start=(1, 3),
                stop=(23, 20),
                halo=_np.array(((1, 2), (3, 4)))
            ).to_slice_h(),
            hie.globale_to_locale_slice_h(gext.to_slice_h())
        )
        self.assertEqual(
            HaloIndexingExtent(
                start=(0, 0),
                stop=(22, 17),
                halo=_np.array(((1, 2), (3, 4)))
            ).to_slice_n(),
            hie.globale_to_locale_slice_n(gext.to_slice_n())
        )

        # Locale to globale slices.
        lext = HaloIndexingExtent(start=(1, 3), stop=(23, 20), halo=_np.array(((1, 2), (3, 4))))
        self.assertEqual(
            hie.to_slice_h(),
            hie.locale_to_globale_slice_h(lext.to_slice_h())
        )

        lext = HaloIndexingExtent(start=(0, 0), stop=(22, 17), halo=_np.array(((1, 2), (3, 4))))
        self.assertEqual(
            hie.to_slice_n(),
            hie.locale_to_globale_slice_n(lext.to_slice_n())
        )

    def test_to_slice(self):
        """
        Tests the :obj:`mpi_array.indexing.HaloIndexingExtent`
        methods: :samp:`to_slice`, :samp:`to_slice_n`, and :samp:`to_slice_h`.
        """
        hie1 = HaloIndexingExtent(start=(10, 3), stop=(32, 20), halo=_np.array(((1, 2), (3, 4))))

        self.assertSequenceEqual(
            (slice(10, 32, None), slice(3, 20, None)),
            hie1.to_slice_n()
        )
        # to_slice is expected to give the same (no-halo) result as to_slice_n.
        self.assertSequenceEqual(
            (slice(10, 32, None), slice(3, 20, None)),
            hie1.to_slice()
        )
        self.assertSequenceEqual(
            (slice(9, 34, None), slice(0, 24, None)),
            hie1.to_slice_h()
        )

    def test_start_stop_shape(self):
        """
        Tests that the :obj:`mpi_array.indexing.HaloIndexingExtent`
        attributes :samp:`start`, :samp:`stop`, and :samp:`shape` equal
        the no-halo :samp:`start_n`, :samp:`stop_n`, and :samp:`shape_n` attributes.
        """
        hie1 = HaloIndexingExtent(start=(10, 3), stop=(32, 20), halo=_np.array(((1, 2), (3, 4))))

        self.assertSequenceEqual(hie1.start_n.tolist(), hie1.start.tolist())
        self.assertSequenceEqual(hie1.stop_n.tolist(), hie1.stop.tolist())
        self.assertSequenceEqual(hie1.shape_n.tolist(), hie1.shape.tolist())

    def test_calc_intersection_split(self):
        """
        Tests for :obj:`mpi_array.indexing.calc_intersection_split`.
        """
        def update_factory(dst_extent, src_extent, intersection):
            # Records the arguments, the intersection is element [2] of the update tuple.
            return [(dst_extent, src_extent, intersection), ]

        # Source only overlaps the destination halo and the halo is not updated:
        # no updates, the destination extent itself is the only leftover.
        dst_extent = HaloIndexingExtent(start=(4, 50), stop=(50, 100), halo=4)
        src_extent = HaloIndexingExtent(start=(50, 46), stop=(150, 104), halo=4)
        update_dst_halo = False
        leftovers, updates = \
            calc_intersection_split(
                dst_extent,
                src_extent,
                update_factory,
                update_dst_halo
            )
        self.assertEqual(0, len(updates))
        self.assertEqual(1, len(leftovers))
        self.assertIs(leftovers[0], dst_extent)

        # Same extents, but the destination halo is updated.
        dst_extent = HaloIndexingExtent(start=(4, 50), stop=(50, 100), halo=4)
        src_extent = HaloIndexingExtent(start=(50, 46), stop=(150, 104), halo=4)
        update_dst_halo = True
        leftovers, updates = \
            calc_intersection_split(
                dst_extent,
                src_extent,
                update_factory,
                update_dst_halo
            )
        self.assertEqual(1, len(updates))
        self.assertEqual(
            IndexingExtent(start=(50, 46), stop=(54, 104)),
            updates[0][2]
        )
        self.assertEqual(1, len(leftovers))
        self.assertEqual(
            HaloIndexingExtent(start=(0, 46), stop=(50, 104)),
            leftovers[0]
        )

        # Source overlaps the middle of the destination along axis 0,
        # leaving a leftover extent on each side.
        dst_extent = HaloIndexingExtent(start=(4, 50), stop=(50, 100), halo=4)
        src_extent = HaloIndexingExtent(start=(20, 30), stop=(32, 128), halo=4)
        update_dst_halo = False
        leftovers, updates = \
            calc_intersection_split(
                dst_extent,
                src_extent,
                update_factory,
                update_dst_halo
            )
        self.assertEqual(1, len(updates))
        self.assertEqual(
            IndexingExtent(start=(20, 50), stop=(32, 100)),
            updates[0][2]
        )
        self.assertEqual(2, len(leftovers))
        self.assertEqual(
            HaloIndexingExtent(start=(4, 50), stop=(20, 100), halo=src_extent.halo),
            leftovers[0]
        )
        self.assertEqual(
            HaloIndexingExtent(start=(32, 50), stop=(50, 100), halo=src_extent.halo),
            leftovers[1]
        )

# Hand this module's name to the package-local ``unittest`` helper
# (presumably it only runs the tests when the module is executed
# as ``__main__`` -- TODO confirm against mpi_array.unittest.main).
_unittest.main(__name__)

# Export every module-level name which does not begin with an underscore.
__all__ = [s for s in dir() if not s.startswith('_')]