Skip to content

Commit db71acf

Browse files
committed
Implement virtual array indexing using ndindex
1 parent c2283f3 commit db71acf

File tree

9 files changed

+62
-86
lines changed

9 files changed

+62
-86
lines changed

cubed/core/ops.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
OrthogonalIndexer,
1717
SliceDimIndexer,
1818
is_integer_list,
19-
is_slice,
2019
replace_ellipsis,
2120
)
2221

@@ -408,7 +407,7 @@ def index(x, key):
408407
key = (key,)
409408

410409
# No op case
411-
if all(is_slice(ind) and ind == slice(None) for ind in key):
410+
if all(isinstance(ind, slice) and ind == slice(None) for ind in key):
412411
return x
413412

414413
# Remove None values, to be filled in with expand_dims at end
@@ -435,7 +434,9 @@ def index(x, key):
435434
selection = replace_ellipsis(selection, x.shape)
436435

437436
# Check selection is supported
438-
if any(s.step is not None and s.step < 1 for s in selection if is_slice(s)):
437+
if any(
438+
s.step is not None and s.step < 1 for s in selection if isinstance(s, slice)
439+
):
439440
raise NotImplementedError(f"Slice step must be >= 1: {key}")
440441
assert all(isinstance(s, (slice, list, Integral)) for s in selection)
441442
where_list = [i for i, ind in enumerate(selection) if is_integer_list(ind)]
@@ -489,7 +490,6 @@ def merged_chunk_len_for_indexer(s):
489490
extra_projected_mem=extra_projected_mem,
490491
target_chunks=target_chunks,
491492
selection=selection,
492-
advanced_indexing=len(where_list) > 0,
493493
)
494494

495495
# merge chunks for any dims with step > 1 so they are
@@ -515,13 +515,14 @@ def _read_index_chunk(
515515
*arrays,
516516
target_chunks=None,
517517
selection=None,
518-
advanced_indexing=None,
519518
block_id=None,
520519
):
521520
array = arrays[0].zarray
522-
if advanced_indexing:
523-
array = array.oindex
524521
idx = block_id
522+
# Note that since we only have a maximum of one integer array index
523+
# we don't need to use Zarr orthogonal indexing, since it is
524+
# "available directly on the array" according to
525+
# https://zarr.readthedocs.io/en/stable/tutorial.html#orthogonal-indexing
525526
out = array[_target_chunk_selection(target_chunks, idx, selection)]
526527
out = numpy_array_to_backend_array(out)
527528
return out
@@ -534,7 +535,7 @@ def _target_chunk_selection(target_chunks, idx, selection):
534535
sel = []
535536
i = 0 # index into target_chunks and idx
536537
for s in selection:
537-
if is_slice(s):
538+
if isinstance(s, slice):
538539
offset = s.start or 0
539540
step = s.step if s.step is not None else 1
540541
start = tuple(

cubed/runtime/executors/modal.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
"donfig",
2222
"fsspec",
2323
"mypy_extensions", # for rechunker
24+
"ndindex",
2425
"networkx",
2526
"pytest-mock", # TODO: only needed for tests
2627
"s3fs",
@@ -35,6 +36,7 @@
3536
"donfig",
3637
"fsspec",
3738
"mypy_extensions", # for rechunker
39+
"ndindex",
3840
"networkx",
3941
"pytest-mock", # TODO: only needed for tests
4042
"gcsfs",

cubed/storage/virtual.py

Lines changed: 24 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,8 @@
22
from typing import Any
33

44
import numpy as np
5-
import zarr
6-
from zarr.indexing import BasicIndexer, is_slice
5+
from ndindex import ndindex
76

8-
from cubed.backend_array_api import backend_array_to_numpy_array
97
from cubed.backend_array_api import namespace as nxp
108
from cubed.backend_array_api import numpy_array_to_backend_array
119
from cubed.types import T_DType, T_RegularChunks, T_Shape
@@ -21,25 +19,15 @@ def __init__(
2119
dtype: T_DType,
2220
chunks: T_RegularChunks,
2321
):
24-
# use an empty in-memory Zarr array as a template since it normalizes its properties
25-
template = zarr.empty(
26-
shape, dtype=dtype, chunks=chunks, store=zarr.storage.MemoryStore()
27-
)
28-
self.shape = template.shape
29-
self.dtype = template.dtype
30-
self.chunks = template.chunks
31-
self.template = template
22+
self.shape = shape
23+
self.dtype = np.dtype(dtype)
24+
self.chunks = chunks
3225

3326
def __getitem__(self, key):
34-
if not isinstance(key, tuple):
35-
key = (key,)
36-
indexer = BasicIndexer(key, self.template)
27+
idx = ndindex[key]
28+
newshape = idx.newshape(self.shape)
3729
# use broadcast trick so array chunks only occupy a single value in memory
38-
return broadcast_trick(nxp.empty)(indexer.shape, dtype=self.dtype)
39-
40-
@property
41-
def oindex(self):
42-
return self.template.oindex
30+
return broadcast_trick(nxp.empty)(newshape, dtype=self.dtype)
4331

4432

4533
class VirtualFullArray:
@@ -52,48 +40,29 @@ def __init__(
5240
chunks: T_RegularChunks,
5341
fill_value: Any = None,
5442
):
55-
# use an empty in-memory Zarr array as a template since it normalizes its properties
56-
template = zarr.full(
57-
shape,
58-
fill_value,
59-
dtype=dtype,
60-
chunks=chunks,
61-
store=zarr.storage.MemoryStore(),
62-
)
63-
self.shape = template.shape
64-
self.dtype = template.dtype
65-
self.chunks = template.chunks
66-
self.template = template
43+
self.shape = shape
44+
self.dtype = np.dtype(dtype)
45+
self.chunks = chunks
6746
self.fill_value = fill_value
6847

6948
def __getitem__(self, key):
70-
if not isinstance(key, tuple):
71-
key = (key,)
72-
indexer = BasicIndexer(key, self.template)
49+
idx = ndindex[key]
50+
newshape = idx.newshape(self.shape)
7351
# use broadcast trick so array chunks only occupy a single value in memory
7452
return broadcast_trick(nxp.full)(
75-
indexer.shape, fill_value=self.fill_value, dtype=self.dtype
53+
newshape, fill_value=self.fill_value, dtype=self.dtype
7654
)
7755

78-
@property
79-
def oindex(self):
80-
return self.template.oindex
81-
8256

8357
class VirtualOffsetsArray:
8458
"""An array that is never materialized (in memory or on disk) and contains sequentially incrementing integers."""
8559

8660
def __init__(self, shape: T_Shape):
8761
dtype = nxp.int32
8862
chunks = (1,) * len(shape)
89-
# use an empty in-memory Zarr array as a template since it normalizes its properties
90-
template = zarr.empty(
91-
shape, dtype=dtype, chunks=chunks, store=zarr.storage.MemoryStore()
92-
)
93-
self.shape = template.shape
94-
self.dtype = template.dtype
95-
self.chunks = template.chunks
96-
self.ndim = template.ndim
63+
self.shape = shape
64+
self.dtype = np.dtype(dtype)
65+
self.chunks = chunks
9766

9867
def __getitem__(self, key):
9968
if key == () and self.shape == ():
@@ -117,28 +86,13 @@ def __init__(
11786
f"Size of in memory array is {memory_repr(array.nbytes)} which exceeds maximum of {memory_repr(max_nbytes)}. Consider loading the array from storage using `from_array`."
11887
)
11988
self.array = array
120-
# use an in-memory Zarr array as a template since it normalizes its properties
121-
# and is needed for oindex
122-
template = zarr.empty(
123-
array.shape,
124-
dtype=array.dtype,
125-
chunks=chunks,
126-
store=zarr.storage.MemoryStore(),
127-
)
128-
self.shape = template.shape
129-
self.dtype = template.dtype
130-
self.chunks = template.chunks
131-
self.template = template
132-
if array.size > 0:
133-
template[...] = backend_array_to_numpy_array(array)
89+
self.shape = array.shape
90+
self.dtype = array.dtype
91+
self.chunks = chunks
13492

13593
def __getitem__(self, key):
13694
return self.array.__getitem__(key)
13795

138-
@property
139-
def oindex(self):
140-
return self.template.oindex
141-
14296

14397
def _key_to_index_tuple(selection):
14498
if isinstance(selection, slice):
@@ -148,7 +102,11 @@ def _key_to_index_tuple(selection):
148102
for s in selection:
149103
if isinstance(s, Integral):
150104
sel.append(s)
151-
elif is_slice(s) and s.stop == s.start + 1 and (s.step is None or s.step == 1):
105+
elif (
106+
isinstance(s, slice)
107+
and s.stop == s.start + 1
108+
and (s.step is None or s.step == 1)
109+
):
152110
sel.append(s.start)
153111
else:
154112
raise NotImplementedError(f"Offset selection not supported: {selection}")

cubed/storage/zarr.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1+
from operator import mul
12
from typing import Optional, Union
23

4+
import numpy as np
35
import zarr
6+
from toolz import reduce
47

58
from cubed.types import T_DType, T_RegularChunks, T_Shape, T_Store
69

@@ -23,18 +26,23 @@ def __init__(
2326
**kwargs,
2427
):
2528
"""Create a Zarr array lazily in memory."""
26-
# use an empty in-memory Zarr array as a template since it normalizes its properties
27-
template = zarr.empty(
28-
shape, dtype=dtype, chunks=chunks, store=zarr.storage.MemoryStore()
29-
)
30-
self.shape = template.shape
31-
self.dtype = template.dtype
32-
self.chunks = template.chunks
33-
self.nbytes = template.nbytes
29+
self.shape = shape
30+
self.dtype = np.dtype(dtype)
31+
self.chunks = chunks
3432
self.store = store
3533
self.path = path
3634
self.kwargs = kwargs
3735

36+
@property
37+
def size(self):
38+
"""Number of elements in the array."""
39+
return reduce(mul, self.shape, 1)
40+
41+
@property
42+
def nbytes(self) -> int:
43+
"""Number of bytes in the array."""
44+
return self.size * self.dtype.itemsize
45+
3846
def create(self, mode: str = "w-") -> zarr.Array:
3947
"""Create the Zarr array in storage.
4048

cubed/tests/runtime/test_modal_async.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
"donfig",
2424
"fsspec",
2525
"mypy_extensions", # for rechunker
26+
"ndindex",
2627
"networkx",
2728
"pytest-mock", # TODO: only needed for tests
2829
"s3fs",

cubed/tests/test_indexing.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,9 @@ def spec(tmp_path):
2323
],
2424
)
2525
def test_int_array_index_1d(spec, ind):
26-
a = xp.arange(12, chunks=(4,), spec=spec)
27-
assert_array_equal(a[ind].compute(), np.arange(12)[ind])
26+
a = xp.arange(12, chunks=(3,), spec=spec)
27+
b = a.rechunk((4,)) # force materialization to test indexing against zarr
28+
assert_array_equal(b[ind].compute(), np.arange(12)[ind])
2829

2930

3031
@pytest.mark.parametrize(
@@ -40,11 +41,12 @@ def test_int_array_index_1d(spec, ind):
4041
def test_int_array_index_2d(spec, ind):
4142
a = xp.asarray(
4243
[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]],
43-
chunks=(2, 2),
44+
chunks=(3, 3),
4445
spec=spec,
4546
)
47+
b = a.rechunk((2, 2)) # force materialization to test indexing against zarr
4648
x = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]])
47-
assert_array_equal(a[ind].compute(), x[ind])
49+
assert_array_equal(b[ind].compute(), x[ind])
4850

4951

5052
def test_multiple_int_array_indexes(spec):

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ dependencies = [
2727
"donfig",
2828
"fsspec",
2929
"mypy_extensions", # for rechunker
30+
"ndindex",
3031
"networkx < 2.8.3",
3132
"numpy >= 1.22",
3233
"tenacity",

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ aiostream
66
array-api-compat
77
fsspec
88
mypy_extensions # for rechunker
9+
ndindex
910
networkx < 2.8.3
1011
numpy >= 1.22
1112
tenacity

setup.cfg

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ ignore_missing_imports = True
4444
ignore_missing_imports = True
4545
[mypy-matplotlib.*]
4646
ignore_missing_imports = True
47+
[mypy-ndindex.*]
48+
ignore_missing_imports = True
4749
[mypy-networkx.*]
4850
ignore_missing_imports = True
4951
[mypy-numpy.*]

0 commit comments

Comments
 (0)