Skip to content

Commit e42a6be

Browse files
brokkoli71jhammand-v-b
authored
Fix indexing with bools (#1968)
* test z[selection] for orthogonal selection * include boolean indexing in is_pure_orthogonal_indexing * Revert "test z[selection] for orthogonal selection" This reverts commit 38578dd. * add test_indexing_equals_numpy * extend _test_get_mask_selection for square bracket notation * fix is_pure_fancy_indexing for mask selection * add test_orthogonal_bool_indexing_like_numpy_ix * fix for mypy * ruff format * fix is_pure_orthogonal_indexing * fix is_pure_orthogonal_indexing * replace deprecated ~ by not * restrict is_integer to not bool * correct typing Co-authored-by: Joe Hamman <[email protected]> * correct typing * check if bool list has only bools * check if bool list has only bools * fix list unpacking in test for python3.10 * Apply spelling suggestions from code review Co-authored-by: Davis Bennett <[email protected]> * fix mypy --------- Co-authored-by: Joe Hamman <[email protected]> Co-authored-by: Davis Bennett <[email protected]>
1 parent 08b5763 commit e42a6be

File tree

3 files changed

+88
-24
lines changed

3 files changed

+88
-24
lines changed

src/zarr/array.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -403,7 +403,7 @@ def attrs(self) -> dict[str, JSON]:
403403

404404
@property
405405
def read_only(self) -> bool:
406-
return bool(~self.store_path.store.writeable)
406+
return bool(not self.store_path.store.writeable)
407407

408408
@property
409409
def path(self) -> str:

src/zarr/indexing.py

Lines changed: 34 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -87,21 +87,23 @@ def ceildiv(a: float, b: float) -> int:
8787

8888

8989
def is_integer(x: Any) -> TypeGuard[int]:
90-
"""True if x is an integer (both pure Python or NumPy).
90+
"""True if x is an integer (both pure Python or NumPy)."""
91+
return isinstance(x, numbers.Integral) and not is_bool(x)
9192

92-
Note that Python's bool is considered an integer too.
93-
"""
94-
return isinstance(x, numbers.Integral)
93+
94+
def is_bool(x: Any) -> TypeGuard[bool | np.bool_]:
95+
"""True if x is a boolean (both pure Python or NumPy)."""
96+
return type(x) in [bool, np.bool_]
9597

9698

9799
def is_integer_list(x: Any) -> TypeGuard[list[int]]:
98-
"""True if x is a list of integers.
100+
"""True if x is a list of integers."""
101+
return isinstance(x, list) and len(x) > 0 and all(is_integer(i) for i in x)
99102

100-
This function assumes ie *does not check* that all elements of the list
101-
have the same type. Mixed type lists will result in other errors that will
102-
bubble up anyway.
103-
"""
104-
return isinstance(x, list) and len(x) > 0 and is_integer(x[0])
103+
104+
def is_bool_list(x: Any) -> TypeGuard[list[bool | np.bool_]]:
105+
"""True if x is a list of boolean."""
106+
return isinstance(x, list) and len(x) > 0 and all(is_bool(i) for i in x)
105107

106108

107109
def is_integer_array(x: Any, ndim: int | None = None) -> TypeGuard[npt.NDArray[np.intp]]:
@@ -118,6 +120,10 @@ def is_bool_array(x: Any, ndim: int | None = None) -> TypeGuard[npt.NDArray[np.b
118120
return t
119121

120122

123+
def is_int_or_bool_iterable(x: Any) -> bool:
124+
return is_integer_list(x) or is_integer_array(x) or is_bool_array(x) or is_bool_list(x)
125+
126+
121127
def is_scalar(value: Any, dtype: np.dtype[Any]) -> bool:
122128
if np.isscalar(value):
123129
return True
@@ -129,7 +135,7 @@ def is_scalar(value: Any, dtype: np.dtype[Any]) -> bool:
129135

130136

131137
def is_pure_fancy_indexing(selection: Any, ndim: int) -> bool:
132-
"""Check whether a selection contains only scalars or integer array-likes.
138+
"""Check whether a selection contains only scalars or integer/bool array-likes.
133139
134140
Parameters
135141
----------
@@ -142,9 +148,14 @@ def is_pure_fancy_indexing(selection: Any, ndim: int) -> bool:
142148
True if the selection is a pure fancy indexing expression (ie not mixed
143149
with boolean or slices).
144150
"""
151+
if is_bool_array(selection):
152+
# is mask selection
153+
return True
154+
145155
if ndim == 1:
146-
if is_integer_list(selection) or is_integer_array(selection):
156+
if is_integer_list(selection) or is_integer_array(selection) or is_bool_list(selection):
147157
return True
158+
148159
# if not, we go through the normal path below, because a 1-tuple
149160
# of integers is also allowed.
150161
no_slicing = (
@@ -166,19 +177,21 @@ def is_pure_orthogonal_indexing(selection: Selection, ndim: int) -> TypeGuard[Or
166177
if not ndim:
167178
return False
168179

169-
# Case 1: Selection is a single iterable of integers
170-
if is_integer_list(selection) or is_integer_array(selection, ndim=1):
180+
selection_normalized = (selection,) if not isinstance(selection, tuple) else selection
181+
182+
# Case 1: Selection contains of iterable of integers or boolean
183+
if len(selection_normalized) == ndim and all(
184+
is_int_or_bool_iterable(s) for s in selection_normalized
185+
):
171186
return True
172187

173-
# Case two: selection contains either zero or one integer iterables.
188+
# Case 2: selection contains either zero or one integer iterables.
174189
# All other selection elements are slices or integers
175190
return (
176-
isinstance(selection, tuple)
177-
and len(selection) == ndim
178-
and sum(is_integer_list(elem) or is_integer_array(elem) for elem in selection) <= 1
191+
len(selection_normalized) <= ndim
192+
and sum(is_int_or_bool_iterable(s) for s in selection_normalized) <= 1
179193
and all(
180-
is_integer_list(elem) or is_integer_array(elem) or isinstance(elem, int | slice)
181-
for elem in selection
194+
is_int_or_bool_iterable(s) or isinstance(s, int | slice) for s in selection_normalized
182195
)
183196
)
184197

@@ -1023,7 +1036,7 @@ def __init__(self, selection: CoordinateSelection, shape: ChunkCoords, chunk_gri
10231036
# flatten selection
10241037
selection_broadcast = tuple(dim_sel.reshape(-1) for dim_sel in selection_broadcast)
10251038
chunks_multi_index_broadcast = tuple(
1026-
dim_chunks.reshape(-1) for dim_chunks in chunks_multi_index_broadcast
1039+
[dim_chunks.reshape(-1) for dim_chunks in chunks_multi_index_broadcast]
10271040
)
10281041

10291042
# ravel chunk indices

tests/v3/test_indexing.py

Lines changed: 53 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,6 @@ def test_get_basic_selection_0d(store: StorePath, use_out: bool, value: Any, dty
204204
slice(50, 150, 10),
205205
]
206206

207-
208207
basic_selections_1d_bad = [
209208
# only positive step supported
210209
slice(None, None, -1),
@@ -305,7 +304,6 @@ def test_get_basic_selection_1d(store: StorePath):
305304
(Ellipsis, slice(None), slice(None)),
306305
]
307306

308-
309307
basic_selections_2d_bad = [
310308
# bad stuff
311309
2.3,
@@ -1272,6 +1270,8 @@ def _test_get_mask_selection(a, z, selection):
12721270
assert_array_equal(expect, actual)
12731271
actual = z.vindex[selection]
12741272
assert_array_equal(expect, actual)
1273+
actual = z[selection]
1274+
assert_array_equal(expect, actual)
12751275

12761276

12771277
mask_selections_1d_bad = [
@@ -1344,6 +1344,9 @@ def _test_set_mask_selection(v, a, z, selection):
13441344
z[:] = 0
13451345
z.vindex[selection] = v[selection]
13461346
assert_array_equal(a, z[:])
1347+
z[:] = 0
1348+
z[selection] = v[selection]
1349+
assert_array_equal(a, z[:])
13471350

13481351

13491352
def test_set_mask_selection_1d(store: StorePath):
@@ -1726,3 +1729,51 @@ def test_accessed_chunks(shape, chunks, ops):
17261729
) == 1
17271730
# Check that no other chunks were accessed
17281731
assert len(delta_counts) == 0
1732+
1733+
1734+
@pytest.mark.parametrize(
1735+
"selection",
1736+
[
1737+
# basic selection
1738+
[...],
1739+
[1, ...],
1740+
[slice(None)],
1741+
[1, 3],
1742+
[[1, 2, 3], 9],
1743+
[np.arange(1000)],
1744+
[slice(5, 15)],
1745+
[slice(2, 4), 4],
1746+
[[1, 3]],
1747+
# mask selection
1748+
[np.tile([True, False], (1000, 5))],
1749+
[np.full((1000, 10), False)],
1750+
# coordinate selection
1751+
[[1, 2, 3, 4], [5, 6, 7, 8]],
1752+
[[100, 200, 300], [4, 5, 6]],
1753+
],
1754+
)
1755+
def test_indexing_equals_numpy(store, selection):
1756+
a = np.arange(10000, dtype=int).reshape(1000, 10)
1757+
z = zarr_array_from_numpy_array(store, a, chunk_shape=(300, 3))
1758+
# note: in python 3.10 a[*selection] is not valid unpacking syntax
1759+
expected = a[(*selection,)]
1760+
actual = z[(*selection,)]
1761+
assert_array_equal(expected, actual, err_msg=f"selection: {selection}")
1762+
1763+
1764+
@pytest.mark.parametrize(
1765+
"selection",
1766+
[
1767+
[np.tile([True, False], 500), np.tile([True, False], 5)],
1768+
[np.full(1000, False), np.tile([True, False], 5)],
1769+
[np.full(1000, True), np.full(10, True)],
1770+
[np.full(1000, True), [True, False] * 5],
1771+
],
1772+
)
1773+
def test_orthogonal_bool_indexing_like_numpy_ix(store, selection):
1774+
a = np.arange(10000, dtype=int).reshape(1000, 10)
1775+
z = zarr_array_from_numpy_array(store, a, chunk_shape=(300, 3))
1776+
expected = a[np.ix_(*selection)]
1777+
# note: in python 3.10 z[*selection] is not valid unpacking syntax
1778+
actual = z[(*selection,)]
1779+
assert_array_equal(expected, actual, err_msg=f"{selection=}")

0 commit comments

Comments
 (0)