"""
=========================================
The :mod:`mpi_array.types_test` Module
=========================================
Module defining :mod:`mpi_array.types` unit-tests.
Execute as::
python -m mpi_array.types_test
Classes
=======
.. autosummary::
:toctree: generated/
:template: autosummary/inherits_TestCase_class.rst
TypesTest - Tests for :func:`mpi_array.types.to_datatype`.
"""
from __future__ import absolute_import
import numpy as _np # noqa: E402,F401
import mpi4py.MPI as _mpi
from .license import license as _license, copyright as _copyright, version as _version
from . import types as _types
from . import logging as _logging
from . import unittest as _unittest
__author__ = "Shane J. Latham"
__license__ = _license()
__copyright__ = _copyright()
__version__ = _version
[docs]class TypesTest(_unittest.TestCase):
"""
:obj:`unittest.TestCase` for :obj:`mpi_array.types.to_datatype`.
"""
[docs] def setUp(self):
"""
Set up, assign :obj:`logging.Logger` object.
"""
self.logger = _logging.get_rank_logger(self.id())
[docs] def test_basic(self):
"""
Test for :obj:`mpi_array.types.to_datatype` converting basic :obj:`numpy.dtype`
to MPI data types.
"""
type_tuples = [
["bool", _mpi.BOOL, ],
["uint8", _mpi.UNSIGNED_CHAR, ],
["int8", _mpi.SIGNED_CHAR, ],
["uint16", _mpi.UNSIGNED_SHORT, ],
["int16", _mpi.SIGNED_SHORT, ],
["uint32", _mpi.UNSIGNED_INT, _mpi.UNSIGNED_LONG, ],
["int32", _mpi.SIGNED_INT, _mpi.INT, _mpi.LONG],
["uint64", _mpi.UNSIGNED_LONG_LONG, _mpi.UNSIGNED_LONG],
["int64", _mpi.AINT, _mpi.LONG, _mpi.LONG_INT, _mpi.SIGNED_LONG_LONG, _mpi.SIGNED_LONG],
["float32", _mpi.FLOAT, ],
["float64", _mpi.DOUBLE, _mpi.LONG_DOUBLE, ],
]
for tup in type_tuples:
np_dt = tup[0]
mpi_dt_from_np_dt = _types.to_datatype(np_dt)
matches =\
_np.any(
tuple(
mpi_dt_from_np_dt == mpi_dt for mpi_dt in tup[1:]
)
)
self.assertTrue(
matches,
"%s converted to %s is not equal to any of %s"
%
(
np_dt,
mpi_dt_from_np_dt.Get_name(),
tuple(mpi_dt.Get_name() for mpi_dt in tup[1:])
)
)
[docs] def test_contiguous(self):
"""
Test for :obj:`mpi_array.types.to_datatype` converting contiguous :obj:`numpy.dtype`
to MPI data types.
"""
mpi_dt_from_np_dt = _types.to_datatype("float16")
self.assertEqual(2, mpi_dt_from_np_dt.Get_extent()[1])
mpi_dt_from_np_dt = _types.to_datatype(("float16", (4,)))
self.assertEqual(8, mpi_dt_from_np_dt.Get_extent()[1])
mpi_dt_from_np_dt = _types.to_datatype(("int32", (4,)))
self.assertEqual(16, mpi_dt_from_np_dt.Get_extent()[1])
mpi_dt_from_np_dt = _types.to_datatype(("int8", (3,)))
self.assertEqual(3, mpi_dt_from_np_dt.Get_extent()[1])
[docs] def test_struct(self):
"""
Test for :obj:`mpi_array.types.to_datatype` converting structure :obj:`numpy.dtype`
to MPI data types.
"""
mpi_dt_from_np_dt = _types.to_datatype([("m0", "float16"), ])
self.assertEqual(2, mpi_dt_from_np_dt.Get_extent()[1])
mpi_dt_from_np_dt = _types.to_datatype([("m0", "float16"), ("m1", "float16")])
self.assertEqual(4, mpi_dt_from_np_dt.Get_extent()[1])
mpi_dt_from_np_dt = _types.to_datatype([("b0", "int8"), ("b1", "int8")])
self.assertEqual(2, mpi_dt_from_np_dt.Get_extent()[1])
mpi_dt_from_np_dt = _types.to_datatype([("b0", "int8"), ("b1", "int8"), ("b2", "int8")])
self.assertEqual(3, mpi_dt_from_np_dt.Get_extent()[1])
mpi_dt_from_np_dt = \
_types.to_datatype("float64")
self.assertEqual(8, mpi_dt_from_np_dt.Get_extent()[1])
mpi_dt_from_np_dt = \
_types.to_datatype([("m0", "float16"), ])
self.assertEqual(2, mpi_dt_from_np_dt.Get_extent()[1])
mpi_dt_from_np_dt = \
_types.to_datatype([("m0", "int32"), ])
self.assertEqual(4, mpi_dt_from_np_dt.Get_extent()[1])
mpi_dt_from_np_dt = \
_types.to_datatype([("m0", "float64"), ("m1", "float64"), ("m2", "float64")])
self.assertEqual(24, mpi_dt_from_np_dt.Get_extent()[1])
mpi_dt_from_np_dt = \
_types.to_datatype(
[("m0", "float64"), ("m1", "int16"), ("m2", "float16"), ("m3", "int32")]
)
self.assertEqual(0, mpi_dt_from_np_dt.Get_extent()[0])
self.assertEqual(16, mpi_dt_from_np_dt.Get_extent()[1])
[docs] def test_struct_bcast(self):
"""
Tests MPI communications with
`structured arrays <https://docs.scipy.org/doc/numpy/user/basics.rec.html>`_.
"""
comm = _mpi.COMM_WORLD
for align in [True, False]:
np_dt = _np.dtype([("m0", "int8"), ("m1", "float64"), ("m2", "int16"), ], align=align)
mpi_dt = _types.to_datatype(np_dt)
ary = _np.empty((101,), dtype=np_dt)
self.logger.info("ary.dtype.isalignedstruct = %s" % (ary.dtype.isalignedstruct,))
self.logger.info("ary.nbytes = %s" % (ary.nbytes,))
self.logger.info("ary.itemsize = %s" % (ary.itemsize,))
self.logger.info("mpi_dt.Get_extent() = %s" % (mpi_dt.Get_extent(),))
ary["m0"] = _np.arange(0, ary.shape[0])
ary["m1"] = _np.arange(ary.shape[0], 2 * ary.shape[0])
ary["m2"] = _np.arange(2 * ary.shape[0], 3 * ary.shape[0])
root_rank = 0
bcast_ary = _np.empty_like(ary)
if comm.rank == root_rank:
bcast_ary = ary.copy()
comm.Bcast([bcast_ary, mpi_dt], root_rank)
all_bcast_ary = _np.empty((comm.size, bcast_ary.shape[0]), dtype=bcast_ary.dtype)
comm.Allgather([bcast_ary, mpi_dt], [all_bcast_ary, mpi_dt])
self.assertSequenceEqual(ary.tolist(), bcast_ary.tolist())
for i in range(all_bcast_ary.shape[0]):
self.assertSequenceEqual(ary.tolist(), all_bcast_ary[i].tolist())
_unittest.main(__name__)
__all__ = [s for s in dir() if not s.startswith('_')]