# Source code for mpi_array.unittest

"""
======================================
The :mod:`mpi_array.unittest` Module
======================================

Some simple wrappers of python built-in :mod:`unittest` module
for :mod:`mpi_array` unit-tests.

Classes and Functions
=====================

.. autosummary::
   :toctree: generated/
   :template: autosummary/inherits_TestCase_class.rst

   TestCase - Extends :obj:`unittest.TestCase` with :obj:`TestCase.assertArraySplitEqual`.

.. autosummary::
   :toctree: generated/

   TestProgram - Over-ride to use :obj:`logging.Logger` output.
   TextTestRunner - Over-ride to use :obj:`logging.Logger` output.
   TextTestResult - Over-ride to use :obj:`logging.Logger` output.
   main - Convenience command-line test-case *search and run* function.

"""
from __future__ import absolute_import

import unittest as _builtin_unittest
import mpi_array.logging
import numpy as _np
import mpi4py.MPI as _mpi
import time as _time
import warnings as _warnings


def _fix_docstring_for_sphinx(docstr):
    """
    Remove one eight-space indent from every line of a doc-string.

    :type docstr: :obj:`str`
    :param docstr: Doc-string text, lines separated by :samp:`"\\n"`.
    :rtype: :obj:`str`
    :return: The text with the first eight characters dropped from each
       line which begins with eight spaces, all other lines are unchanged.
    """
    indent = " " * 8
    return "\n".join(
        line[len(indent):] if line.startswith(indent) else line
        for line in docstr.split("\n")
    )


def _builtin_assert(*names):
    """
    Return the first method of :obj:`unittest.TestCase` found among
    the :samp:`{names}`. Some assert methods were renamed between python
    versions, so callers list the preferred name first followed by
    the alternative spellings.

    :type names: :obj:`str`
    :param names: Candidate method names, in order of preference.
    :raises AttributeError: If :obj:`unittest.TestCase` has none of the methods.
    """
    for name in names:
        method = getattr(_builtin_unittest.TestCase, name, None)
        if method is not None:
            return method
    raise AttributeError(
        "unittest.TestCase has none of the methods: %s" % (", ".join(names),)
    )


class TestCase(_builtin_unittest.TestCase):
    """
    Extends :obj:`unittest.TestCase` with the :meth:`assertArraySplitEqual`.
    """

    def assertArraySplitEqual(self, splt1, splt2):
        """
        Compares two :obj:`list` of :obj:`numpy.ndarray` element by element.
        A pair of elements is equal when both have the same shape and values
        or when both are empty (zero size).

        :type splt1: :obj:`list` of :obj:`numpy.ndarray`
        :param splt1: First object in equality comparison.
        :type splt2: :obj:`list` of :obj:`numpy.ndarray`
        :param splt2: Second object in equality comparison.
        :raises unittest.AssertionError: If the lengths differ or if any element
           of :samp:`{splt1}` is not equal to the corresponding
           element of :samp:`{splt2}`.
        """
        self.assertEqual(len(splt1), len(splt2))
        for i, (elem1, elem2) in enumerate(zip(splt1, splt2)):
            ary1 = _np.array(elem1)
            ary2 = _np.array(elem2)
            # array_equal compares shapes before values, a plain "==" would
            # broadcast (e.g. shape (1,) against shape (3,)) or raise on
            # shapes which cannot be broadcast.
            self.assertTrue(
                (
                    _np.array_equal(ary1, ary2)
                    or
                    ((ary1.size == 0) and (ary2.size == 0))
                ),
                msg=(
                    "element %d of split is not equal %s != %s"
                    %
                    (i, ary1, ary2)
                )
            )

    #
    # Method over-rides below are just to avoid sphinx warnings.
    # Each one returns the result of the wrapped method so that the
    # context-manager forms (e.g. "with self.assertRaisesRegex(...):") work.
    #
    def assertItemsEqual(self, *args, **kwargs):
        """
        See :obj:`unittest.TestCase.assertItemsEqual`.
        """
        # Falls back to the python-3 name when the python-2 name is absent.
        method = _builtin_assert("assertItemsEqual", "assertCountEqual")
        return method(self, *args, **kwargs)

    def assertListEqual(self, *args, **kwargs):
        """
        See :obj:`unittest.TestCase.assertListEqual`.
        """
        return _builtin_unittest.TestCase.assertListEqual(self, *args, **kwargs)

    def assertRaisesRegexp(self, *args, **kwargs):
        """
        See :obj:`unittest.TestCase.assertRaisesRegexp`.
        """
        method = _builtin_assert("assertRaisesRegexp", "assertRaisesRegex")
        return method(self, *args, **kwargs)

    def assertRaisesRegex(self, *args, **kwargs):
        """
        See :obj:`unittest.TestCase.assertRaisesRegex`.
        """
        method = _builtin_assert("assertRaisesRegex", "assertRaisesRegexp")
        return method(self, *args, **kwargs)

    def assertSetEqual(self, *args, **kwargs):
        """
        See :obj:`unittest.TestCase.assertSetEqual`.
        """
        return _builtin_unittest.TestCase.assertSetEqual(self, *args, **kwargs)

    def assertTupleEqual(self, *args, **kwargs):
        """
        See :obj:`unittest.TestCase.assertTupleEqual`.
        """
        return _builtin_unittest.TestCase.assertTupleEqual(self, *args, **kwargs)

    def assertWarnsRegex(self, *args, **kwargs):
        """
        See :obj:`unittest.TestCase.assertWarnsRegex`.
        """
        return _builtin_unittest.TestCase.assertWarnsRegex(self, *args, **kwargs)
# Install assertSequenceEqual (and the helpers it relies on) on TestCase
# when the inherited unittest.TestCase does not already provide it,
# otherwise install a thin documented wrapper around the inherited method.
if not hasattr(TestCase, "assertSequenceEqual"):
    # code from python-2.7 unittest.case.TestCase

    # Maximum length of a repr returned by safe_repr when short=True.
    _MAX_LENGTH = 80

    def safe_repr(obj, short=False):
        """
        Return :samp:`repr({obj})`, falling back to :samp:`object.__repr__({obj})`
        if :samp:`repr` raises. When :samp:`{short}` is true, results
        of :samp:`_MAX_LENGTH` or more characters are truncated.
        """
        try:
            result = repr(obj)
        except Exception:
            result = object.__repr__(obj)
        if not short or len(result) < _MAX_LENGTH:
            return result
        return result[:_MAX_LENGTH] + ' [truncated]...'

    def strclass(cls):
        """
        Return the :samp:`"module.name"` string for class :samp:`{cls}`.
        Not referenced by the other functions of this block.
        """
        return "%s.%s" % (cls.__module__, cls.__name__)

    def assertSequenceEqual(self, seq1, seq2, msg=None, seq_type=None):
        """An equality assertion for ordered sequences (like lists and tuples).

        For the purposes of this function, a valid ordered sequence type is one
        which can be indexed, has a length, and has an equality operator.

        :param seq1: The first sequence to compare.
        :param seq2: The second sequence to compare.
        :param seq_type: The expected datatype of the sequences, or None if no
           datatype should be enforced.
        :param msg: Optional message to use on failure instead of a list of
           differences.
        """
        import pprint
        import difflib
        if seq_type is not None:
            seq_type_name = seq_type.__name__
            if not isinstance(seq1, seq_type):
                raise self.failureException('First sequence is not a %s: %s'
                                            % (seq_type_name, safe_repr(seq1)))
            if not isinstance(seq2, seq_type):
                raise self.failureException('Second sequence is not a %s: %s'
                                            % (seq_type_name, safe_repr(seq2)))
        else:
            seq_type_name = "sequence"

        # "differing" accumulates the failure description, it remains None
        # only while both lengths have been obtained successfully.
        differing = None
        try:
            len1 = len(seq1)
        except (TypeError, NotImplementedError):
            differing = 'First %s has no length. Non-sequence?' % (
                seq_type_name)

        if differing is None:
            try:
                len2 = len(seq2)
            except (TypeError, NotImplementedError):
                differing = 'Second %s has no length. Non-sequence?' % (
                    seq_type_name)

        if differing is None:
            if seq1 == seq2:
                return

            seq1_repr = safe_repr(seq1)
            seq2_repr = safe_repr(seq2)
            if len(seq1_repr) > 30:
                seq1_repr = seq1_repr[:30] + '...'
            if len(seq2_repr) > 30:
                seq2_repr = seq2_repr[:30] + '...'
            elements = (seq_type_name.capitalize(), seq1_repr, seq2_repr)
            differing = '%ss differ: %s != %s\n' % elements

            # Report the first index at which the elements differ
            # (or at which an element cannot be indexed).
            for i in range(min(len1, len2)):
                try:
                    item1 = seq1[i]
                except (TypeError, IndexError, NotImplementedError):
                    differing += ('\nUnable to index element %d of first %s\n' %
                                  (i, seq_type_name))
                    break

                try:
                    item2 = seq2[i]
                except (TypeError, IndexError, NotImplementedError):
                    differing += ('\nUnable to index element %d of second %s\n' %
                                  (i, seq_type_name))
                    break

                if item1 != item2:
                    differing += ('\nFirst differing element %d:\n%s\n%s\n' %
                                  (i, item1, item2))
                    break
            else:
                # Loop completed without break: all common elements are equal.
                if (len1 == len2 and seq_type is None and
                        not isinstance(seq1, type(seq2))):
                    # The sequences are the same, but have differing types.
                    return

            if len1 > len2:
                differing += ('\nFirst %s contains %d additional '
                              'elements.\n' % (seq_type_name, len1 - len2))
                try:
                    differing += ('First extra element %d:\n%s\n' %
                                  (len2, seq1[len2]))
                except (TypeError, IndexError, NotImplementedError):
                    differing += ('Unable to index element %d '
                                  'of first %s\n' % (len2, seq_type_name))
            elif len1 < len2:
                differing += ('\nSecond %s contains %d additional '
                              'elements.\n' % (seq_type_name, len2 - len1))
                try:
                    differing += ('First extra element %d:\n%s\n' %
                                  (len1, seq2[len1]))
                except (TypeError, IndexError, NotImplementedError):
                    differing += ('Unable to index element %d '
                                  'of second %s\n' % (len1, seq_type_name))
        standardMsg = differing
        diffMsg = '\n' + '\n'.join(
            difflib.ndiff(pprint.pformat(seq1).splitlines(),
                          pprint.pformat(seq2).splitlines()))
        standardMsg = self._truncateMessage(standardMsg, diffMsg)
        msg = self._formatMessage(msg, standardMsg)
        self.fail(msg)

    def _formatMessage(self, msg, standardMsg):
        """Honour the longMessage attribute when generating failure messages.

        If longMessage is False this means:

        * Use only an explicit message if it is provided
        * Otherwise use the standard message for the assert

        If longMessage is True:

        * Use the standard message
        * If an explicit message is provided, plus ' : ' and the explicit message
        """
        if not self.longMessage:
            return msg or standardMsg
        if msg is None:
            return standardMsg
        try:
            # don't switch to '{}' formatting in Python 2.X
            # it changes the way unicode input is handled
            return '%s : %s' % (standardMsg, msg)
        except UnicodeDecodeError:
            return '%s : %s' % (safe_repr(standardMsg), safe_repr(msg))

    def _truncateMessage(self, message, diff):
        """
        Return :samp:`{message} + {diff}` unless :samp:`{diff}` is longer
        than :samp:`self.maxDiff`, in which case a note giving the length
        of the omitted diff is appended instead.
        """
        DIFF_OMITTED = ('\nDiff is %s characters long. '
                        'Set self.maxDiff to None to see it.')
        max_diff = self.maxDiff
        if max_diff is None or len(diff) <= max_diff:
            return message + diff
        return message + (DIFF_OMITTED % len(diff))

    # Value assigned to TestCase.maxDiff below (read by _truncateMessage).
    _maxDiff = 80 * 8

    setattr(TestCase, "maxDiff", _maxDiff)
    setattr(TestCase, "_truncateMessage", _truncateMessage)
    setattr(TestCase, "_formatMessage", _formatMessage)
    setattr(TestCase, "assertSequenceEqual", assertSequenceEqual)
else:
    def assertSequenceEqual(self, *args, **kwargs):
        """
        See :obj:`unittest.TestCase.assertSequenceEqual`.
        """
        _builtin_unittest.TestCase.assertSequenceEqual(self, *args, **kwargs)
    setattr(TestCase, "assertSequenceEqual", assertSequenceEqual)


class LoggerDecorator:
    """
    Decorator for :obj:`logging.Logger` to provide :meth:`write`,
    :meth:`writeln` and :meth:`flush` methods.
    """

    def __init__(self, logger):
        """
        :param logger: Object whose :samp:`info` and :samp:`error` methods
           receive the written text.
        """
        self.logger = logger

    def write(self, v=""):
        """
        Passes :samp:`{v}` to :samp:`self.logger.info`.
        """
        self.logger.info(v)

    def writeln(self, v=""):
        """
        Same as :meth:`write`, passes :samp:`{v}` to :samp:`self.logger.info`.
        """
        self.logger.info(v)

    def write_error(self, v):
        """
        Passes :samp:`{v}` to :samp:`self.logger.error`.
        """
        self.logger.error(v)

    def flush(self):
        """
        Does nothing, present so that instances can be used where
        a :samp:`flush` method is called.
        """
        pass
class TextTestResult(_builtin_unittest.TextTestResult):
    """
    Test result which reports progress and errors by calling
    the :samp:`write`, :samp:`write_error` and :samp:`flush` methods
    of :samp:`self.stream`. The book-keeping of each event is delegated
    directly to :obj:`unittest.result.TestResult`.
    """

    def _emit_status(self, verbose_text, dot_text, as_error=False):
        """
        Write :samp:`{verbose_text}` when :samp:`self.showAll` is set,
        otherwise write :samp:`{dot_text}` (followed by a flush)
        when :samp:`self.dots` is set, otherwise write nothing.

        :type as_error: :obj:`bool`
        :param as_error: If :samp:`True` the text is written
           with :samp:`self.stream.write_error` rather
           than :samp:`self.stream.write`.
        """
        verbose = self.showAll
        if not (verbose or self.dots):
            return
        text = verbose_text if verbose else dot_text
        if as_error:
            self.stream.write_error(text)
        else:
            self.stream.write(text)
        if not verbose:
            self.stream.flush()

    def startTest(self, test):
        _builtin_unittest.result.TestResult.startTest(self, test)
        if not self.showAll:
            return
        self.stream.write(self.getDescription(test) + "...")
        self.stream.flush()

    def addSuccess(self, test):
        _builtin_unittest.result.TestResult.addSuccess(self, test)
        # The description is only generated in the verbose (showAll) mode.
        if self.showAll:
            self.stream.write(self.getDescription(test) + "..." + "ok")
        elif self.dots:
            self.stream.write('.')
            self.stream.flush()

    def addError(self, test, err):
        _builtin_unittest.result.TestResult.addError(self, test, err)
        self._emit_status("ERROR", 'E', as_error=True)

    def addFailure(self, test, err):
        _builtin_unittest.result.TestResult.addFailure(self, test, err)
        self._emit_status("FAIL", 'F', as_error=True)

    def addSkip(self, test, reason):
        _builtin_unittest.result.TestResult.addSkip(self, test, reason)
        # The reason is only formatted in the verbose (showAll) mode.
        if self.showAll:
            self.stream.write("skipped {0!r}".format(reason))
        elif self.dots:
            self.stream.write("s")
            self.stream.flush()

    def addExpectedFailure(self, test, err):
        _builtin_unittest.result.TestResult.addExpectedFailure(self, test, err)
        self._emit_status("expected failure", "x")

    def addUnexpectedSuccess(self, test):
        _builtin_unittest.result.TestResult.addUnexpectedSuccess(self, test)
        self._emit_status("unexpected success", "u")

    def printErrors(self):
        if self.dots or self.showAll:
            self.stream.write()
        for flavour, errors in (('ERROR', self.errors), ('FAIL', self.failures)):
            self.printErrorList(flavour, errors)

    def printErrorList(self, flavour, errors):
        report = self.stream.write_error
        for test, err in errors:
            report(self.separator1)
            report("%s: %s" % (flavour, self.getDescription(test)))
            report(self.separator2)
            report("%s" % err)
def handle_arg(arg_index, arg_key, arg_value, args, kwargs):
    """
    Resolve the value of an argument from the :samp:`{args}`
    and :samp:`{kwargs}` argument lists, substituting :samp:`{arg_value}`
    when the argument is absent or is :samp:`None`. The resolved value
    is stored back into the argument lists (in place).

    :type arg_index: :obj:`int`
    :param arg_index: Index of argument in :samp:`{args}`
    :type arg_key: :obj:`str`
    :param arg_key: String key of argument in :samp:`{kwargs}`
    :type arg_value: :obj:`object`
    :param arg_value: Value for argument if it does not appear in argument lists or
        has :samp:`None` value in argument lists.
    :type args: :obj:`list`
    :param args: List of arguments, must support item assignment
        when it has more than :samp:`{arg_index}` elements.
    :type kwargs: :obj:`dict`
    :param kwargs: Dictionary of key-word arguments.
    :rtype: :obj:`object`
    :return: The resolved value of the argument.
    """
    is_positional = arg_index < len(args)
    is_keyword = arg_key in kwargs

    # A key-word value takes precedence over a positional value.
    if is_keyword:
        value = kwargs[arg_key]
    elif is_positional:
        value = args[arg_index]
    else:
        value = None
    if value is None:
        value = arg_value

    # Write the value back wherever the argument was found,
    # an argument found nowhere is added as a key-word argument.
    if is_positional:
        args[arg_index] = value
    if is_keyword or not is_positional:
        kwargs[arg_key] = value

    return value
class TextTestRunner(_builtin_unittest.TextTestRunner):
    """
    A test runner class that displays results in textual form.
    Extends :obj:`unittest.TextTestRunner` with logging output
    instead of :obj:`sys.stderr` output.
    """

    def __init__(self, *args, **kwargs):
        """
        Takes the arguments of :obj:`unittest.TextTestRunner`. Arguments
        which are absent (or :samp:`None`) are given defaults:
        the :samp:`resultclass` defaults to :obj:`TextTestResult`,
        the :samp:`verbosity` to :samp:`0` and the :samp:`stream`
        to a :obj:`LoggerDecorator` wrapped rank-logger.
        """
        # handle_arg assigns into the positional argument list,
        # the *args tuple is immutable so convert it to a list.
        args = list(args)
        handle_arg(5, "resultclass", TextTestResult, args, kwargs)
        verbosity = handle_arg(2, "verbosity", 0, args, kwargs)
        logger_name = __name__ + ".TextTestRunner"
        logger = mpi_array.logging.get_rank_logger(logger_name)
        # For verbosity <= 1 only the rank-0 process uses the INFO level,
        # for higher verbosity all ranks use the INFO level.
        log_level = mpi_array.logging.WARN
        if verbosity <= 1:
            if _mpi.COMM_WORLD.rank == 0:
                log_level = mpi_array.logging.INFO
        else:
            log_level = mpi_array.logging.INFO
        mpi_array.logging.initialise_loggers([logger_name, ], log_level)
        stream = LoggerDecorator(logger)
        handle_arg(0, "stream", stream, args, kwargs)
        _builtin_unittest.TextTestRunner.__init__(self, *args, **kwargs)

        # Remove _WritelnDecorator decoration for LoggerDecorator
        if (
            hasattr(self.stream, "stream")
            and
            isinstance(self.stream.stream, LoggerDecorator)
        ):
            self.stream = self.stream.stream
        if not hasattr(self, "warnings"):
            self.warnings = None

    def run(self, test):
        """
        Run the given test case or test suite.

        :param test: Callable test case (or test suite), it is called with
           the result object as its single argument.
        :return: The result object created by :samp:`self._makeResult()`.
        """
        result = self._makeResult()
        _builtin_unittest.registerResult(result)
        result.failfast = self.failfast
        result.buffer = self.buffer
        with _warnings.catch_warnings():
            if self.warnings:
                # if self.warnings is set, use it to filter all the warnings
                _warnings.simplefilter(self.warnings)
                # if the filter is 'default' or 'always', special-case the
                # warnings from the deprecated unittest methods to show them
                # no more than once per module, because they can be fairly
                # noisy.  The -Wd and -Wa flags can be used to bypass this
                # only when self.warnings is None.
                if self.warnings in ['default', 'always']:
                    _warnings.filterwarnings(
                        'module',
                        category=DeprecationWarning,
                        # raw string, "\w" is not a valid string escape
                        message=r'Please use assert\w+ instead.'
                    )
            startTime = _time.time()
            startTestRun = getattr(result, 'startTestRun', None)
            if startTestRun is not None:
                startTestRun()
            try:
                test(result)
            finally:
                stopTestRun = getattr(result, 'stopTestRun', None)
                if stopTestRun is not None:
                    stopTestRun()
            stopTime = _time.time()
        timeTaken = stopTime - startTime
        result.printErrors()
        if hasattr(result, 'separator2'):
            self.stream.writeln(result.separator2)
        run = result.testsRun
        self.stream.writeln(
            "Ran %d test%s in %.3fs (COMM_WORLD.size=%3d)"
            %
            (run, "s" if run != 1 else "", timeTaken, _mpi.COMM_WORLD.size)
        )
        self.stream.writeln()

        # Result objects lacking these attributes leave the counts at zero.
        expectedFails = unexpectedSuccesses = skipped = 0
        try:
            results = map(len, (result.expectedFailures,
                                result.unexpectedSuccesses,
                                result.skipped))
        except AttributeError:
            pass
        else:
            expectedFails, unexpectedSuccesses, skipped = results

        infos = []
        if not result.wasSuccessful():
            self.stream.write("FAILED")
            failed, errored = len(result.failures), len(result.errors)
            if failed:
                infos.append("failures=%d" % failed)
            if errored:
                infos.append("errors=%d" % errored)
        else:
            self.stream.write("OK")
        if skipped:
            infos.append("skipped=%d" % skipped)
        if expectedFails:
            infos.append("expected failures=%d" % expectedFails)
        if unexpectedSuccesses:
            infos.append("unexpected successes=%d" % unexpectedSuccesses)
        if infos:
            self.stream.writeln(" (%s)" % (", ".join(infos),))
        else:
            self.stream.write("\n")
        return result
class TestProgram(_builtin_unittest.TestProgram):
    """
    A command-line program that runs a set of tests, extends :obj:`unittest.TestProgram`
    by using logging rather than standard stream.
    """

    def __init__(self, *args, **kwargs):
        """
        Takes the arguments of :obj:`unittest.TestProgram`,
        the :samp:`testRunner` argument defaults to :obj:`TextTestRunner`
        when it is absent or :samp:`None`.
        """
        # handle_arg assigns into the positional argument list,
        # the *args tuple is immutable so convert it to a list.
        args = list(args)
        handle_arg(3, "testRunner", TextTestRunner, args, kwargs)
        _builtin_unittest.TestProgram.__init__(self, *args, **kwargs)
def main(
    module_name,
    log_level=mpi_array.logging.DEBUG,
    init_logger_names=None,
    verbosity=None,
    failfast=None
):
    """
    Like :func:`unittest.main`, initialises :mod:`logging.Logger` objects
    and instantiates a :obj:`TestProgram` to discover and
    run :obj:`TestCase` objects. Does nothing
    unless :samp:`{module_name} == "__main__"`.

    Typical use is a single line at the end of a test module::

       mpi_array.unittest.main(__name__)

    The logging level of a group of modules can be set explicitly::

       import logging
       mpi_array.unittest.main(
           __name__,
           logging.DEBUG,
           [__name__, "module_name_0", "module_name_1", "package.module_name_2"]
       )

    :type module_name: :obj:`str`
    :param module_name: If :samp:`{module_name} == "__main__"` then unit-tests
       are *discovered* and run.
    :type log_level: :obj:`int`
    :param log_level: The logging level passed
       to :func:`mpi_array.logging.initialise_loggers`.
    :type init_logger_names: sequence of :obj:`str`
    :param init_logger_names: List of logger names to initialise
       (using :func:`mpi_array.logging.initialise_loggers`). If :samp:`None`,
       then the list defaults to :samp:`[{module_name}, "mpi_array"]`. If list
       is empty no loggers are initialised.
    :param verbosity: If not :samp:`None`, passed to :obj:`TestProgram` as
       the :samp:`verbosity` key-word argument.
    :param failfast: If not :samp:`None`, passed to :obj:`TestProgram` as
       the :samp:`failfast` key-word argument.
    """
    if module_name != "__main__":
        return

    logger_names = init_logger_names
    if logger_names is None:
        logger_names = [module_name, "mpi_array"]
    if len(logger_names) > 0:
        mpi_array.logging.initialise_loggers(logger_names, log_level=log_level)

    # Only forward the options which were explicitly specified.
    options = (("failfast", failfast), ("verbosity", verbosity))
    program_kwargs = dict(
        (key, value) for (key, value) in options if value is not None
    )

    TestProgram(**program_kwargs)
# Export every module-level name which does not begin with an underscore.
__all__ = [name for name in dir() if name[:1] != '_']