Skip to content

Commit 5e57f75

Browse files
jhamman and d-v-b authored
fix: selection with zarr arrays (#2137)
* fix: selection with zarr arrays
* fixup
* Update src/zarr/core/indexing.py

  Co-authored-by: Davis Bennett <[email protected]>

* Apply suggestions from code review
* lint

---------

Co-authored-by: Davis Bennett <[email protected]>
1 parent 30e2bc3 commit 5e57f75

File tree

2 files changed

+41
-3
lines changed

2 files changed

+41
-3
lines changed

src/zarr/core/indexing.py

Lines changed: 24 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -35,7 +35,6 @@
3535
ArrayOfIntOrBool = npt.NDArray[np.intp] | npt.NDArray[np.bool_]
3636
BasicSelector = int | slice | EllipsisType
3737
Selector = BasicSelector | ArrayOfIntOrBool
38-
3938
BasicSelection = BasicSelector | tuple[BasicSelector, ...] # also used for BlockIndex
4039
CoordinateSelection = IntSequence | tuple[IntSequence, ...]
4140
MaskSelection = npt.NDArray[np.bool_]
@@ -75,6 +74,15 @@ def err_too_many_indices(selection: Any, shape: ChunkCoords) -> None:
7574
raise IndexError(f"too many indices for array; expected {len(shape)}, got {len(selection)}")
7675

7776

77+
def _zarr_array_to_int_or_bool_array(arr: Array) -> npt.NDArray[np.intp] | npt.NDArray[np.bool_]:
    """Materialize an array used as a selection into an in-memory NumPy array.

    Parameters
    ----------
    arr : Array
        Array whose values are to be used as an index. Only its ``dtype``
        attribute and its convertibility via ``np.asarray`` are relied upon.

    Returns
    -------
    numpy.ndarray
        The result of ``np.asarray(arr)``; the dtype of ``arr`` is preserved.

    Raises
    ------
    IndexError
        If the dtype kind of ``arr`` is not signed integer (``"i"``),
        unsigned integer (``"u"``) or boolean (``"b"``).
    """
    # Validate before materializing so that an unsuitable array is rejected
    # without first being converted.
    # "u" is accepted alongside "i": unsigned integers are integers, which is
    # what the error message below promises, and NumPy itself accepts
    # unsigned integer arrays as index arrays.
    if arr.dtype.kind not in ("i", "u", "b"):
        raise IndexError(
            f"Invalid array dtype: {arr.dtype}. Arrays used as indices must be of integer or boolean type"
        )
    return np.asarray(arr)
84+
85+
7886
@runtime_checkable
7987
class Indexer(Protocol):
8088
shape: ChunkCoords
@@ -842,7 +850,14 @@ def __iter__(self) -> Iterator[ChunkProjection]:
842850
class OIndex:
843851
array: Array
844852

845-
def __getitem__(self, selection: OrthogonalSelection) -> NDArrayLike:
853+
# TODO: develop Array generic and move zarr.Array[np.intp] | zarr.Array[np.bool_] to ArrayOfIntOrBool
854+
def __getitem__(self, selection: OrthogonalSelection | Array) -> NDArrayLike:
855+
from zarr.core.array import Array
856+
857+
# if input is a Zarr array, we materialize it now.
858+
if isinstance(selection, Array):
859+
selection = _zarr_array_to_int_or_bool_array(selection)
860+
846861
fields, new_selection = pop_fields(selection)
847862
new_selection = ensure_tuple(new_selection)
848863
new_selection = replace_lists(new_selection)
@@ -1130,7 +1145,13 @@ def __init__(self, selection: MaskSelection, shape: ChunkCoords, chunk_grid: Chu
11301145
class VIndex:
11311146
array: Array
11321147

1133-
def __getitem__(self, selection: CoordinateSelection | MaskSelection) -> NDArrayLike:
1148+
# TODO: develop Array generic and move zarr.Array[np.intp] | zarr.Array[np.bool_] to ArrayOfIntOrBool
1149+
def __getitem__(self, selection: CoordinateSelection | MaskSelection | Array) -> NDArrayLike:
1150+
from zarr.core.array import Array
1151+
1152+
# if input is a Zarr array, we materialize it now.
1153+
if isinstance(selection, Array):
1154+
selection = _zarr_array_to_int_or_bool_array(selection)
11341155
fields, new_selection = pop_fields(selection)
11351156
new_selection = ensure_tuple(new_selection)
11361157
new_selection = replace_lists(new_selection)

tests/v3/test_indexing.py

Lines changed: 17 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1861,3 +1861,20 @@ def test_orthogonal_bool_indexing_like_numpy_ix(
18611861
# note: in python 3.10 z[*selection] is not valid unpacking syntax
18621862
actual = z[(*selection,)]
18631863
assert_array_equal(expected, actual, err_msg=f"{selection=}")
1864+
1865+
1866+
def test_indexing_with_zarr_array(store: StorePath) -> None:
    """Check that Zarr arrays can themselves be used as selections.

    Regression test for
    https://github.com/zarr-developers/zarr-python/issues/2133
    """
    data = np.arange(10)
    z_data = zarr.array(data, chunks=2, store=store, path="a")

    # Plain Python selections: a boolean mask picking every odd position,
    # and a short list of integer coordinates.
    mask = [False, True] * 5
    coords = [0, 2, 4, 5]

    # The same selections stored as Zarr arrays in the same store.
    z_mask = zarr.array(mask, chunks=2, store=store, dtype="bool", path="ix")
    z_coords = zarr.array(coords, chunks=2, store=store, dtype="i4", path="ii")

    # Boolean Zarr array through basic, orthogonal and vectorized indexing.
    assert_array_equal(data[mask], z_data[z_mask])
    assert_array_equal(data[mask], z_data.oindex[z_mask])
    assert_array_equal(data[mask], z_data.vindex[z_mask])

    # Integer Zarr array through basic and orthogonal indexing.
    assert_array_equal(data[coords], z_data[z_coords])
    assert_array_equal(data[coords], z_data.oindex[z_coords])

0 commit comments

Comments
 (0)