
I want to write some unit tests for my app, and I need to compare two arrays. Since array.__eq__ returns a new array (so TestCase.assertEqual fails), what is the best way to assert equality?

Currently I'm using

self.assertTrue((arr1 == arr2).all())

but I don't really like it :\
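A minimal sketch of the situation described above, assuming NumPy arrays (the test class and method names are hypothetical). Strictly speaking, assertEqual does not report a normal failure here. It raises ValueError, because unittest evaluates "not first == second", and the truth value of a multi-element boolean array is ambiguous.

```python
import unittest

import numpy as np


class WhyAssertEqualFails(unittest.TestCase):
    def test_assert_equal_on_arrays(self):
        arr1 = np.array([1, 2, 3])
        arr2 = np.array([1, 2, 3])
        # arr1 == arr2 is an elementwise boolean array, not a single bool,
        # so unittest's internal "if not first == second" raises ValueError.
        with self.assertRaises(ValueError):
            self.assertEqual(arr1, arr2)

    def test_all_workaround(self):
        arr1 = np.array([1, 2, 3])
        arr2 = np.array([1, 2, 3])
        # The workaround from the question: collapse the boolean array to one bool.
        self.assertTrue((arr1 == arr2).all())
```

Run it with python -m unittest as usual.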

+2  A: 

One issue with self.assertTrue((arr1 == arr2).all()) is that if arr1 and arr2 both contain np.nan at the same position, (arr1 == arr2).all() evaluates to False (because nan != nan), and an AssertionError is raised.

Maybe this is what you want (depending on how you want np.nan handled), but perhaps not. If not, you might want to consider using this:

Subclass unittest.TestCase:

import collections.abc
import itertools
import unittest

_MISSING = object()  # fill value marking where the shorter iterable ran out

class MyTestCase(unittest.TestCase):
    def assertEq(self, first, second, *args, **kwargs):
        # http://stackoverflow.com/questions/3022952/test-assertions-for-tuples-with-floats/3124155#3124155
        # Recurse into iterables, but not into strings (iterating 'a' yields 'a' forever).
        if (isinstance(first, collections.abc.Iterable)
                and isinstance(second, collections.abc.Iterable)
                and not isinstance(first, str)
                and not isinstance(second, str)):
            for a, b in itertools.zip_longest(first, second, fillvalue=_MISSING):
                if a is _MISSING or b is _MISSING:
                    self.fail('iterables have different lengths')
                self.assertEq(a, b, *args, **kwargs)
        else:
            # nan != nan, so treat two nans in the same position as equal;
            # a nan compared with a non-nan still fails below.
            if first != first and second != second:
                return
            try:
                self.assertEqual(first, second, *args, **kwargs)
            except self.failureException:
                self.assertAlmostEqual(first, second, *args, **kwargs)

Then you can use it like this:

class Test(MyTestCase):
    def test_foo(self):
        self.assertEq(arr1, arr2)

It is recursive, so it can handle arbitrarily nested iterables (as long as the nesting fits within Python's adjustable recursion limit). Note that it doesn't compare the containers themselves, only the values of the non-iterable elements. For example, self.assertEq([1,2],(1,2)) would pass.
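The NaN behavior that motivates assertEq can be reproduced directly. As a point of comparison (a NumPy-only option, not part of this answer), numpy.array_equal accepts equal_nan=True since NumPy 1.19, which treats NaNs in matching positions as equal.

```python
import numpy as np

arr1 = np.array([1.0, np.nan, 3.0])
arr2 = np.array([1.0, np.nan, 3.0])

# nan is not equal to itself, so the elementwise comparison has a False
# exactly where both arrays hold nan.
print(arr1 == arr2)                                # [ True False  True]
print((arr1 == arr2).all())                        # False

# NaN-aware exact comparison (NumPy >= 1.19).
print(np.array_equal(arr1, arr2, equal_nan=True))  # True
```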

unutbu
+1  A: 

I think (arr1 == arr2).all() looks pretty nice. But you could use:

numpy.allclose(arr1, arr2)

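but it's not quite the same: allclose checks that the elements are equal within a tolerance (by default rtol=1e-05 and atol=1e-08), not that they are exactly equal.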
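A small sketch of where exact equality and allclose disagree, assuming ordinary float64 arrays. The allclose equal_nan parameter is shown as well, since NaNs otherwise make allclose return False too.

```python
import numpy as np

arr1 = np.array([0.1 + 0.2, 1.0])
arr2 = np.array([0.3, 1.0])

# 0.1 + 0.2 is 0.30000000000000004 in binary floating point,
# so exact elementwise equality fails...
print((arr1 == arr2).all())                            # False
# ...while allclose accepts differences within its default tolerances.
print(np.allclose(arr1, arr2))                         # True

# NaNs never compare close unless equal_nan=True is passed.
print(np.allclose([np.nan], [np.nan]))                 # False
print(np.allclose([np.nan], [np.nan], equal_nan=True)) # True
```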

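An alternative, almost identical to your example, is:

numpy.alltrue(arr1 == arr2)

Note that numpy.alltrue was deprecated in NumPy 1.25 and removed in NumPy 2.0; numpy.all(arr1 == arr2) is the equivalent call.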

Note that scipy.array is actually just a reference to numpy.array; knowing that makes the documentation easier to find.
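A short sketch of the numpy.all form, plus one caveat that applies to every ==-based check, including the one in the question: broadcasting can make arrays of different shapes compare as "all equal". numpy.array_equal (not mentioned in this answer) also compares shapes.

```python
import numpy as np

arr1 = np.array([1, 2, 3])
arr2 = np.array([1, 2, 3])

# numpy.all reduces the elementwise result to a single bool.
print(np.all(arr1 == arr2))          # True

# Broadcasting hides the shape mismatch: [1] is stretched to [1, 1].
short = np.array([1])
ones = np.array([1, 1])
print(np.all(ones == short))         # True, although the shapes differ
print(np.array_equal(ones, short))   # False: array_equal checks shape first
```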

DiggyF
+2  A: 

Check out the assertion functions in numpy.testing, e.g.

assert_array_equal

For floating-point arrays, an exact equality test may fail because of rounding error, so assert_almost_equal is more reliable.
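A sketch of how these functions fit into a unittest.TestCase (the class and test names are hypothetical). They raise AssertionError on mismatch, which unittest reports as an ordinary test failure. According to the NumPy documentation, assert_array_equal treats NaNs in the same positions as equal, and assert_almost_equal checks agreement to 7 decimal places by default.

```python
import unittest

import numpy as np
from numpy.testing import assert_almost_equal, assert_array_equal


class ArrayAssertions(unittest.TestCase):
    def test_exact(self):
        # NaNs in matching positions do not trigger a failure here.
        assert_array_equal(np.array([1.0, np.nan]), np.array([1.0, np.nan]))

    def test_float_rounding(self):
        # Exact comparison trips over rounding error (0.1 + 0.2 != 0.3)...
        with self.assertRaises(AssertionError):
            assert_array_equal(np.array([0.1 + 0.2]), np.array([0.3]))
        # ...but agreement to the default 7 decimal places passes.
        assert_almost_equal(np.array([0.1 + 0.2]), np.array([0.3]))
```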