Skip to content

Commit c111e9d

Browse files
committed
Avoid making multiple ApproxNumpy types.
1 parent 8524a57 commit c111e9d

File tree

1 file changed

+16
-9
lines changed

1 file changed

+16
-9
lines changed

_pytest/python_api.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -49,13 +49,9 @@ class ApproxNumpyBase(ApproxBase):
4949
"""
5050
Perform approximate comparisons for numpy arrays.
5151
52-
This class should not be used directly. Instead, it should be used to make
53-
a subclass that also inherits from `np.ndarray`, e.g.::
54-
55-
import numpy as np
56-
ApproxNumpy = type('ApproxNumpy', (ApproxNumpyBase, np.ndarray), {})
57-
58-
This bizarre invocation is necessary because the object doing the
52+
This class should not be used directly. Instead, the `inherit_ndarray()`
53+
class method should be used to make a subclass that also inherits from
54+
`np.ndarray`. This indirection is necessary because the object doing the
5955
approximate comparison must inherit from `np.ndarray`, or it will only work
6056
on the left side of the `==` operator. But importing numpy is relatively
6157
expensive, so we also want to avoid that unless we actually have a numpy
@@ -81,6 +77,18 @@ class ApproxNumpyBase(ApproxBase):
8177
it appears on.
8278
"""
8379

80+
subclass = None
81+
82+
@classmethod
83+
def inherit_ndarray(cls):
84+
import numpy as np
85+
assert not isinstance(cls, np.ndarray)
86+
87+
if cls.subclass is None:
88+
cls.subclass = type('ApproxNumpy', (ApproxNumpyBase, np.ndarray), {})
89+
90+
return cls.subclass
91+
8492
def __new__(cls, expected, rel=None, abs=None, nan_ok=False):
8593
"""
8694
Numpy uses __new__ (rather than __init__) to initialize objects.
@@ -416,8 +424,7 @@ def approx(expected, rel=None, abs=None, nan_ok=False):
416424
if _is_numpy_array(expected):
417425
# Create the delegate class on the fly. This allow us to inherit from
418426
# ``np.ndarray`` while still not importing numpy unless we need to.
419-
import numpy as np
420-
cls = type('ApproxNumpy', (ApproxNumpyBase, np.ndarray), {})
427+
cls = ApproxNumpyBase.inherit_ndarray()
421428
elif isinstance(expected, Mapping):
422429
cls = ApproxMapping
423430
elif isinstance(expected, Sequence) and not isinstance(expected, String):

0 commit comments

Comments
 (0)