Skip to content

Commit 201f950

Browse files
committed
Implement virtual array indexing using ndindex
1 parent d9fddc7 commit 201f950

File tree

9 files changed

+62
-91
lines changed

9 files changed

+62
-91
lines changed

cubed/core/ops.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
OrthogonalIndexer,
1818
SliceDimIndexer,
1919
is_integer_list,
20-
is_slice,
2120
replace_ellipsis,
2221
)
2322

@@ -409,7 +408,7 @@ def index(x, key):
409408
key = (key,)
410409

411410
# No op case
412-
if all(is_slice(ind) and ind == slice(None) for ind in key):
411+
if all(isinstance(ind, slice) and ind == slice(None) for ind in key):
413412
return x
414413

415414
# Remove None values, to be filled in with expand_dims at end
@@ -436,7 +435,9 @@ def index(x, key):
436435
selection = replace_ellipsis(selection, x.shape)
437436

438437
# Check selection is supported
439-
if any(s.step is not None and s.step < 1 for s in selection if is_slice(s)):
438+
if any(
439+
s.step is not None and s.step < 1 for s in selection if isinstance(s, slice)
440+
):
440441
raise NotImplementedError(f"Slice step must be >= 1: {key}")
441442
assert all(isinstance(s, (slice, list, Integral)) for s in selection)
442443
where_list = [i for i, ind in enumerate(selection) if is_integer_list(ind)]
@@ -490,7 +491,6 @@ def merged_chunk_len_for_indexer(s):
490491
extra_projected_mem=extra_projected_mem,
491492
target_chunks=target_chunks,
492493
selection=selection,
493-
advanced_indexing=len(where_list) > 0,
494494
)
495495

496496
# merge chunks for any dims with step > 1 so they are
@@ -516,13 +516,14 @@ def _read_index_chunk(
516516
*arrays,
517517
target_chunks=None,
518518
selection=None,
519-
advanced_indexing=None,
520519
block_id=None,
521520
):
522521
array = arrays[0].zarray
523-
if advanced_indexing:
524-
array = array.oindex
525522
idx = block_id
523+
# Note that since we only have a maximum of one integer array index
524+
# we don't need to use Zarr orthogonal indexing, since it is
525+
# "available directly on the array" according to
526+
# https://zarr.readthedocs.io/en/stable/tutorial.html#orthogonal-indexing
526527
out = array[_target_chunk_selection(target_chunks, idx, selection)]
527528
out = numpy_array_to_backend_array(out)
528529
return out
@@ -535,7 +536,7 @@ def _target_chunk_selection(target_chunks, idx, selection):
535536
sel = []
536537
i = 0 # index into target_chunks and idx
537538
for s in selection:
538-
if is_slice(s):
539+
if isinstance(s, slice):
539540
offset = s.start or 0
540541
step = s.step if s.step is not None else 1
541542
start = tuple(

cubed/runtime/executors/modal.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
"donfig",
3939
"fsspec",
4040
"mypy_extensions", # for rechunker
41+
"ndindex",
4142
"networkx",
4243
"pytest-mock", # TODO: only needed for tests
4344
"s3fs",
@@ -52,6 +53,7 @@
5253
"donfig",
5354
"fsspec",
5455
"mypy_extensions", # for rechunker
56+
"ndindex",
5557
"networkx",
5658
"pytest-mock", # TODO: only needed for tests
5759
"gcsfs",

cubed/storage/virtual.py

Lines changed: 24 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,8 @@
22
from typing import Any
33

44
import numpy as np
5-
import zarr
6-
from zarr.indexing import BasicIndexer, is_slice
5+
from ndindex import ndindex
76

8-
from cubed.backend_array_api import backend_array_to_numpy_array
97
from cubed.backend_array_api import namespace as nxp
108
from cubed.backend_array_api import numpy_array_to_backend_array
119
from cubed.types import T_DType, T_RegularChunks, T_Shape
@@ -21,31 +19,21 @@ def __init__(
2119
dtype: T_DType,
2220
chunks: T_RegularChunks,
2321
):
24-
# use an empty in-memory Zarr array as a template since it normalizes its properties
25-
template = zarr.empty(
26-
shape, dtype=dtype, chunks=chunks, store=zarr.storage.MemoryStore()
27-
)
28-
self.shape = template.shape
29-
self.dtype = template.dtype
30-
self.chunks = template.chunks
31-
self.template = template
22+
self.shape = shape
23+
self.dtype = np.dtype(dtype)
24+
self.chunks = chunks
3225

3326
def __getitem__(self, key):
34-
if not isinstance(key, tuple):
35-
key = (key,)
36-
indexer = BasicIndexer(key, self.template)
27+
idx = ndindex[key]
28+
newshape = idx.newshape(self.shape)
3729
# use broadcast trick so array chunks only occupy a single value in memory
38-
return broadcast_trick(nxp.empty)(indexer.shape, dtype=self.dtype)
30+
return broadcast_trick(nxp.empty)(newshape, dtype=self.dtype)
3931

4032
@property
4133
def chunkmem(self):
4234
# take broadcast trick into account
4335
return array_memory(self.dtype, (1,))
4436

45-
@property
46-
def oindex(self):
47-
return self.template.oindex
48-
4937

5038
class VirtualFullArray:
5139
"""An array that is never materialized (in memory or on disk) and contains a single fill value."""
@@ -57,53 +45,29 @@ def __init__(
5745
chunks: T_RegularChunks,
5846
fill_value: Any = None,
5947
):
60-
# use an empty in-memory Zarr array as a template since it normalizes its properties
61-
template = zarr.full(
62-
shape,
63-
fill_value,
64-
dtype=dtype,
65-
chunks=chunks,
66-
store=zarr.storage.MemoryStore(),
67-
)
68-
self.shape = template.shape
69-
self.dtype = template.dtype
70-
self.chunks = template.chunks
71-
self.template = template
48+
self.shape = shape
49+
self.dtype = np.dtype(dtype)
50+
self.chunks = chunks
7251
self.fill_value = fill_value
7352

7453
def __getitem__(self, key):
75-
if not isinstance(key, tuple):
76-
key = (key,)
77-
indexer = BasicIndexer(key, self.template)
54+
idx = ndindex[key]
55+
newshape = idx.newshape(self.shape)
7856
# use broadcast trick so array chunks only occupy a single value in memory
7957
return broadcast_trick(nxp.full)(
80-
indexer.shape, fill_value=self.fill_value, dtype=self.dtype
58+
newshape, fill_value=self.fill_value, dtype=self.dtype
8159
)
8260

83-
@property
84-
def chunkmem(self):
85-
# take broadcast trick into account
86-
return array_memory(self.dtype, (1,))
87-
88-
@property
89-
def oindex(self):
90-
return self.template.oindex
91-
9261

9362
class VirtualOffsetsArray:
9463
"""An array that is never materialized (in memory or on disk) and contains sequentially incrementing integers."""
9564

9665
def __init__(self, shape: T_Shape):
9766
dtype = nxp.int32
9867
chunks = (1,) * len(shape)
99-
# use an empty in-memory Zarr array as a template since it normalizes its properties
100-
template = zarr.empty(
101-
shape, dtype=dtype, chunks=chunks, store=zarr.storage.MemoryStore()
102-
)
103-
self.shape = template.shape
104-
self.dtype = template.dtype
105-
self.chunks = template.chunks
106-
self.ndim = template.ndim
68+
self.shape = shape
69+
self.dtype = np.dtype(dtype)
70+
self.chunks = chunks
10771

10872
def __getitem__(self, key):
10973
if key == () and self.shape == ():
@@ -127,28 +91,13 @@ def __init__(
12791
f"Size of in memory array is {memory_repr(array.nbytes)} which exceeds maximum of {memory_repr(max_nbytes)}. Consider loading the array from storage using `from_array`."
12892
)
12993
self.array = array
130-
# use an in-memory Zarr array as a template since it normalizes its properties
131-
# and is needed for oindex
132-
template = zarr.empty(
133-
array.shape,
134-
dtype=array.dtype,
135-
chunks=chunks,
136-
store=zarr.storage.MemoryStore(),
137-
)
138-
self.shape = template.shape
139-
self.dtype = template.dtype
140-
self.chunks = template.chunks
141-
self.template = template
142-
if array.size > 0:
143-
template[...] = backend_array_to_numpy_array(array)
94+
self.shape = array.shape
95+
self.dtype = array.dtype
96+
self.chunks = chunks
14497

14598
def __getitem__(self, key):
14699
return self.array.__getitem__(key)
147100

148-
@property
149-
def oindex(self):
150-
return self.template.oindex
151-
152101

153102
def _key_to_index_tuple(selection):
154103
if isinstance(selection, slice):
@@ -158,7 +107,11 @@ def _key_to_index_tuple(selection):
158107
for s in selection:
159108
if isinstance(s, Integral):
160109
sel.append(s)
161-
elif is_slice(s) and s.stop == s.start + 1 and (s.step is None or s.step == 1):
110+
elif (
111+
isinstance(s, slice)
112+
and s.stop == s.start + 1
113+
and (s.step is None or s.step == 1)
114+
):
162115
sel.append(s.start)
163116
else:
164117
raise NotImplementedError(f"Offset selection not supported: {selection}")

cubed/storage/zarr.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1+
from operator import mul
12
from typing import Optional, Union
23

4+
import numpy as np
35
import zarr
6+
from toolz import reduce
47

58
from cubed.types import T_DType, T_RegularChunks, T_Shape, T_Store
69

@@ -23,18 +26,23 @@ def __init__(
2326
**kwargs,
2427
):
2528
"""Create a Zarr array lazily in memory."""
26-
# use an empty in-memory Zarr array as a template since it normalizes its properties
27-
template = zarr.empty(
28-
shape, dtype=dtype, chunks=chunks, store=zarr.storage.MemoryStore()
29-
)
30-
self.shape = template.shape
31-
self.dtype = template.dtype
32-
self.chunks = template.chunks
33-
self.nbytes = template.nbytes
29+
self.shape = shape
30+
self.dtype = np.dtype(dtype)
31+
self.chunks = chunks
3432
self.store = store
3533
self.path = path
3634
self.kwargs = kwargs
3735

36+
@property
37+
def size(self):
38+
"""Number of elements in the array."""
39+
return reduce(mul, self.shape, 1)
40+
41+
@property
42+
def nbytes(self) -> int:
43+
"""Number of bytes in array"""
44+
return self.size * self.dtype.itemsize
45+
3846
def create(self, mode: str = "w-") -> zarr.Array:
3947
"""Create the Zarr array in storage.
4048

cubed/tests/runtime/test_modal.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
"donfig",
2424
"fsspec",
2525
"mypy_extensions", # for rechunker
26+
"ndindex",
2627
"networkx",
2728
"pytest-mock", # TODO: only needed for tests
2829
"s3fs",

cubed/tests/test_indexing.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,9 @@ def spec(tmp_path):
2323
],
2424
)
2525
def test_int_array_index_1d(spec, ind):
26-
a = xp.arange(12, chunks=(4,), spec=spec)
27-
assert_array_equal(a[ind].compute(), np.arange(12)[ind])
26+
a = xp.arange(12, chunks=(3,), spec=spec)
27+
b = a.rechunk((4,)) # force materialization to test indexing against zarr
28+
assert_array_equal(b[ind].compute(), np.arange(12)[ind])
2829

2930

3031
@pytest.mark.parametrize(
@@ -40,11 +41,12 @@ def test_int_array_index_1d(spec, ind):
4041
def test_int_array_index_2d(spec, ind):
4142
a = xp.asarray(
4243
[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]],
43-
chunks=(2, 2),
44+
chunks=(3, 3),
4445
spec=spec,
4546
)
47+
b = a.rechunk((2, 2)) # force materialization to test indexing against zarr
4648
x = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]])
47-
assert_array_equal(a[ind].compute(), x[ind])
49+
assert_array_equal(b[ind].compute(), x[ind])
4850

4951

5052
def test_multiple_int_array_indexes(spec):

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ dependencies = [
3030
"donfig",
3131
"fsspec",
3232
"mypy_extensions", # for rechunker
33+
"ndindex",
3334
"networkx != 2.8.3, != 2.8.4, != 2.8.5, != 2.8.6, != 2.8.7, != 2.8.8, != 3.0.*, != 3.1.*, != 3.2.*",
3435
"numpy >= 1.22",
3536
"tenacity",

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ aiostream
66
array-api-compat
77
fsspec
88
mypy_extensions # for rechunker
9+
ndindex
910
networkx != 2.8.3, != 2.8.4, != 2.8.5, != 2.8.6, != 2.8.7, != 2.8.8, != 3.0.*, != 3.1.*, != 3.2.*
1011
numpy >= 1.22
1112
tenacity

setup.cfg

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ ignore_missing_imports = True
4444
ignore_missing_imports = True
4545
[mypy-matplotlib.*]
4646
ignore_missing_imports = True
47+
[mypy-ndindex.*]
48+
ignore_missing_imports = True
4749
[mypy-networkx.*]
4850
ignore_missing_imports = True
4951
[mypy-numpy.*]

0 commit comments

Comments
 (0)