Skip to content

Commit 6e0bbeb

Browse files
committed
BUG: fix edge case in np.ma.allequal
Fix an edge case in numpy/ma/core.py::allequal where calling the function on the same input (i.e. `allequal(x, x)`) where the input is an unmasked array (i.e. `mask=np.ma.nomask`) would return `False`. The fix involves updating the `np.ma.mask_or` function to call `_shrink_mask` on the mask returned in this case. See issue numpy#27201. add test for mask_or(x, x) where x is all False fix linting issue
1 parent 9e43697 commit 6e0bbeb

File tree

2 files changed

+21
-1
lines changed

2 files changed

+21
-1
lines changed

numpy/ma/core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1787,7 +1787,7 @@ def mask_or(m1, m2, copy=False, shrink=True):
17871787
dtype = getattr(m1, 'dtype', MaskType)
17881788
return make_mask(m1, copy=copy, shrink=shrink, dtype=dtype)
17891789
if m1 is m2 and is_mask(m1):
1790-
return m1
1790+
return _shrink_mask(m1) if shrink else m1
17911791
(dtype1, dtype2) = (getattr(m1, 'dtype', None), getattr(m2, 'dtype', None))
17921792
if dtype1 != dtype2:
17931793
raise ValueError("Incompatible dtypes '%s'<>'%s'" % (dtype1, dtype2))

numpy/ma/tests/test_core.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4849,6 +4849,26 @@ def test_mask_or(self):
48494849
cntrl = np.array([(1, (1, 1)), (0, (1, 0))], dtype=dtype)
48504850
assert_equal(mask_or(amask, bmask), cntrl)
48514851

4852+
a = np.array([False, False])
4853+
assert mask_or(a, a) is nomask # gh-27360
4854+
4855+
def test_allequal(self):
4856+
x = array([1, 2, 3], mask=[0, 0, 0])
4857+
y = array([1, 2, 3], mask=[1, 0, 0])
4858+
z = array([[1, 2, 3], [4, 5, 6]], mask=[[0, 0, 0], [1, 1, 1]])
4859+
4860+
assert allequal(x, y)
4861+
assert not allequal(x, y, fill_value=False)
4862+
assert allequal(x, z)
4863+
4864+
# test allequal for the same input, with mask=nomask, this test is for
4865+
# the scenario raised in https://github.com/numpy/numpy/issues/27201
4866+
assert allequal(x, x)
4867+
assert allequal(x, x, fill_value=False)
4868+
4869+
assert allequal(y, y)
4870+
assert not allequal(y, y, fill_value=False)
4871+
48524872
def test_flatten_mask(self):
48534873
# Tests flatten mask
48544874
# Standard dtype

0 commit comments

Comments
 (0)