One issue with self.assertTrue((arr1 == arr2).all()) is that if arr1 and arr2 contain np.nan at the same location, (arr1 == arr2).all() evaluates to False (because nan != nan), and an AssertionError is raised.
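To make the NaN behaviour concrete, here is a minimal sketch using plain Python floats rather than NumPy arrays. It assumes only that NumPy's element-wise == follows the same IEEE 754 rule (nan != nan); same_value is a hypothetical helper introduced for illustration.

```python
import math

nan = float("nan")
first = [1.0, nan, 3.0]
second = [1.0, nan, 3.0]

# Element-by-element comparison, analogous to (arr1 == arr2).all():
# the NaN pair compares unequal, so the overall result is False.
strict_equal = all(a == b for a, b in zip(first, second))


def same_value(a, b):
    """Hypothetical helper: equal values, or NaN on both sides."""
    return a == b or (math.isnan(a) and math.isnan(b))


# NaN-tolerant comparison: NaNs in matching positions count as equal.
nan_tolerant_equal = all(same_value(a, b) for a, b in zip(first, second))

print(strict_equal)        # False
print(nan_tolerant_equal)  # True
```

The two sequences are identical position by position, yet the strict comparison reports them as different purely because of the NaN pair.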
Failing on matching NaNs may be what you want (depending on how you want NaNs handled), but perhaps not. If not, you might consider the following:
Subclass unittest.TestCase:
import unittest
import itertools
from collections.abc import Iterable


class Sentinel(object):
    """Fill value for zip_longest; stands in for a missing element when lengths differ."""


class MyTestCase(unittest.TestCase):
    def assertEq(self, first, second, *args, **kwargs):
        # http://stackoverflow.com/questions/3022952/test-assertions-for-tuples-with-floats/3124155#3124155
        # On Python 2, use collections.Iterable and itertools.izip_longest instead.
        if (isinstance(first, Iterable)
                and isinstance(second, Iterable)):
            # Recurse pairwise into both iterables.
            for a, b in itertools.zip_longest(first, second, fillvalue=Sentinel()):
                self.assertEq(a, b, *args, **kwargs)
        else:
            # Ignore nans. (which have the strange property that nan != nan)
            if first == first and second == second:
                try:
                    self.assertEqual(first, second, *args, **kwargs)
                except self.failureException:
                    # Fall back to an approximate comparison for floats.
                    self.assertAlmostEqual(first, second, *args, **kwargs)
Then you can use it like this:
class Test(MyTestCase):
    def test_foo(self):
        self.assertEq(arr1, arr2)
It is recursive, so it can handle arbitrarily nested iterables (as long as the nesting depth stays within Python's adjustable recursion limit). Note that it doesn't compare the containers themselves, only the values of the non-iterable elements. For example, self.assertEq([1,2],(1,2)) would pass.
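The behaviour described above can be checked with a small self-contained run. This is only a sketch: it repeats a compact Python 3 port of the class (using collections.abc.Iterable and itertools.zip_longest), and the Demo class and its test names are made up for illustration.

```python
import itertools
import unittest
from collections.abc import Iterable


class Sentinel(object):
    """Fill value marking a missing element when lengths differ."""


class MyTestCase(unittest.TestCase):
    def assertEq(self, first, second, *args, **kwargs):
        if isinstance(first, Iterable) and isinstance(second, Iterable):
            for a, b in itertools.zip_longest(first, second, fillvalue=Sentinel()):
                self.assertEq(a, b, *args, **kwargs)
        elif first == first and second == second:  # skip pairs involving NaN
            try:
                self.assertEqual(first, second, *args, **kwargs)
            except self.failureException:
                self.assertAlmostEqual(first, second, *args, **kwargs)


class Demo(MyTestCase):  # hypothetical test case, for illustration only
    def test_container_type_is_ignored(self):
        self.assertEq([1, 2], (1, 2))

    def test_nested_with_matching_nans(self):
        nan = float("nan")
        self.assertEq([[1.0, nan], [3.0]], ((1.0, nan), (3.0,)))

    def test_falls_back_to_almost_equal(self):
        self.assertEq([0.1 + 0.2], [0.3])

    def test_one_sided_nan_is_also_ignored(self):
        # Caveat: a NaN on either side skips the comparison entirely.
        self.assertEq([float("nan")], [1.0])

    def test_real_mismatch_still_fails(self):
        with self.assertRaises(self.failureException):
            self.assertEq([1, 2], [1, 3])


suite = unittest.defaultTestLoader.loadTestsFromTestCase(Demo)
result = unittest.TextTestRunner(verbosity=0).run(suite)
```

Note the fourth test: because the check is first == first and second == second, a NaN paired with an ordinary number is skipped too, not just a NaN paired with another NaN. If that is too lenient for your data, the NaN check would need to require NaN on both sides.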