One issue that arises when using self.assertTrue((arr1 == arr2).all()) is that if arr1 and arr2 contain np.nan in the same location, then (arr1 == arr2).all() evaluates to False (because nan != nan), and an AssertionError is raised.
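
A minimal sketch of that behavior, using plain Python floats and lists instead of NumPy arrays (an assumption made only to keep the snippet dependency-free; np.nan is an ordinary IEEE 754 float nan). The helper name `same` is hypothetical:

```python
import math

nan = float("nan")  # stands in for np.nan

arr1 = [1.0, nan, 3.0]
arr2 = [1.0, nan, 3.0]

# Compare element by element, mirroring what (arr1 == arr2) does on arrays.
elementwise = [a == b for a, b in zip(arr1, arr2)]
all_equal = all(elementwise)

# Hypothetical nan-aware comparison: two nans in the same slot count as equal.
def same(a, b):
    return a == b or (math.isnan(a) and math.isnan(b))

nan_aware = all(same(a, b) for a, b in zip(arr1, arr2))

print(elementwise, all_equal, nan_aware)
```

The middle comparison is the one that makes the plain `.all()` check fail, even though the two sequences look identical.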
Maybe this is what you want (depending on how you want np.nan values handled), but perhaps not. In that case, you might want to consider subclassing unittest.TestCase:

    import unittest
    import collections.abc
    import itertools

    class Sentinel(object): pass

    class MyTestCase(unittest.TestCase):
        def assertEq(self, first, second, *args, **kwargs):
            # http://stackoverflow.com/questions/3022952/test-assertions-for-tuples-with-floats/3124155#3124155
            # Python 3 names; on Python 2 these were collections.Iterable
            # and itertools.izip_longest.
            if (isinstance(first, collections.abc.Iterable)
                    and isinstance(second, collections.abc.Iterable)):
                # Sentinel() pads the shorter iterable so length mismatches surface.
                for a, b in itertools.zip_longest(first, second, fillvalue=Sentinel()):
                    self.assertEq(a, b, *args, **kwargs)
            else:
                # Ignore nans (which have the strange property that nan != nan).
                if first == first and second == second:
                    try:
                        self.assertEqual(first, second, *args, **kwargs)
                    except self.failureException:
                        self.assertAlmostEqual(first, second, *args, **kwargs)

Then you can use it like this:

    class Test(MyTestCase):
        def test_foo(self):
            self.assertEq(arr1, arr2)

It is recursive, so it can handle arbitrarily nested iterables (so long as the nesting fits within Python's adjustable recursion limit). Note that it doesn't compare the containers, only the values of the non-iterable elements. For example, self.assertEq([1,2],(1,2)) would pass.
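
The behavior described above can be checked with a short script. This is only a sketch: it repeats a Python 3 version of the class so that it runs on its own, and the helper `fails` is a hypothetical name introduced for the demonstration.

```python
import unittest
from collections.abc import Iterable
from itertools import zip_longest


class Sentinel(object):
    pass


class MyTestCase(unittest.TestCase):
    def assertEq(self, first, second, *args, **kwargs):
        if isinstance(first, Iterable) and isinstance(second, Iterable):
            for a, b in zip_longest(first, second, fillvalue=Sentinel()):
                self.assertEq(a, b, *args, **kwargs)
        elif first == first and second == second:  # skip pairs involving nan
            try:
                self.assertEqual(first, second, *args, **kwargs)
            except self.failureException:
                self.assertAlmostEqual(first, second, *args, **kwargs)


nan = float("nan")
case = MyTestCase()  # instantiated directly, outside a test runner

# Container types are ignored; only the leaf values are compared.
case.assertEq([1, 2], (1, 2))
# Nested iterables with nan in the same position pass.
case.assertEq([[1.0, nan], [3.0]], [(1.0, nan), (3.0,)])
# Values that differ only by rounding error pass via assertAlmostEqual.
case.assertEq([0.1 + 0.2], [0.3])


def fails(first, second):
    """Hypothetical helper: True if assertEq raises an AssertionError."""
    try:
        case.assertEq(first, second)
    except AssertionError:
        return True
    return False


mismatch_detected = fails([1, 2], [1, 3])
nan_vs_number_detected = fails([nan], [1.0])
print(mismatch_detected, nan_vs_number_detected)
```

One consequence of the `first == first and second == second` guard is that a nan on either side skips the comparison entirely, so a nan paired with an ordinary number also passes. If that is not the handling you want, the guard is the place to change it.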