Skip to content

Commit 661acb3

Browse files
authored
Buffer Prototype Argument (#1910)
1 parent b431cf7 commit 661acb3

25 files changed

+438
-252
lines changed

src/zarr/abc/codec.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -13,11 +13,10 @@
1313
if TYPE_CHECKING:
1414
from typing_extensions import Self
1515

16-
from zarr.common import ArraySpec
16+
from zarr.array_spec import ArraySpec
1717
from zarr.indexing import SelectorTuple
1818
from zarr.metadata import ArrayMetadata
1919

20-
2120
CodecInput = TypeVar("CodecInput", bound=NDBuffer | Buffer)
2221
CodecOutput = TypeVar("CodecOutput", bound=NDBuffer | Buffer)
2322

src/zarr/abc/store.py

Lines changed: 14 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@
22
from collections.abc import AsyncGenerator
33
from typing import Protocol, runtime_checkable
44

5-
from zarr.buffer import Buffer
5+
from zarr.buffer import Buffer, BufferPrototype
66
from zarr.common import BytesLike, OpenMode
77

88

@@ -30,7 +30,10 @@ def _check_writable(self) -> None:
3030

3131
@abstractmethod
3232
async def get(
33-
self, key: str, byte_range: tuple[int | None, int | None] | None = None
33+
self,
34+
key: str,
35+
prototype: BufferPrototype,
36+
byte_range: tuple[int | None, int | None] | None = None,
3437
) -> Buffer | None:
3538
"""Retrieve the value associated with a given key.
3639
@@ -47,7 +50,9 @@ async def get(
4750

4851
@abstractmethod
4952
async def get_partial_values(
50-
self, key_ranges: list[tuple[str, tuple[int | None, int | None]]]
53+
self,
54+
prototype: BufferPrototype,
55+
key_ranges: list[tuple[str, tuple[int | None, int | None]]],
5156
) -> list[Buffer | None]:
5257
"""Retrieve possibly partial values from given key_ranges.
5358
@@ -175,12 +180,16 @@ def close(self) -> None: # noqa: B027
175180

176181
@runtime_checkable
177182
class ByteGetter(Protocol):
178-
async def get(self, byte_range: tuple[int, int | None] | None = None) -> Buffer | None: ...
183+
async def get(
184+
self, prototype: BufferPrototype, byte_range: tuple[int, int | None] | None = None
185+
) -> Buffer | None: ...
179186

180187

181188
@runtime_checkable
182189
class ByteSetter(Protocol):
183-
async def get(self, byte_range: tuple[int, int | None] | None = None) -> Buffer | None: ...
190+
async def get(
191+
self, prototype: BufferPrototype, byte_range: tuple[int, int | None] | None = None
192+
) -> Buffer | None: ...
184193

185194
async def set(self, value: Buffer, byte_range: tuple[int, int] | None = None) -> None: ...
186195

src/zarr/array.py

Lines changed: 79 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -20,7 +20,7 @@
2020
from zarr.abc.codec import Codec
2121
from zarr.abc.store import set_or_delete
2222
from zarr.attributes import Attributes
23-
from zarr.buffer import Factory, NDArrayLike, NDBuffer
23+
from zarr.buffer import BufferPrototype, NDArrayLike, NDBuffer, default_buffer_prototype
2424
from zarr.chunk_grids import RegularChunkGrid
2525
from zarr.chunk_key_encodings import ChunkKeyEncoding, DefaultChunkKeyEncoding, V2ChunkKeyEncoding
2626
from zarr.codecs import BytesCodec
@@ -414,8 +414,8 @@ async def _get_selection(
414414
self,
415415
indexer: Indexer,
416416
*,
417+
prototype: BufferPrototype,
417418
out: NDBuffer | None = None,
418-
factory: Factory.Create = NDBuffer.create,
419419
fields: Fields | None = None,
420420
) -> NDArrayLike:
421421
# check fields are sensible
@@ -432,7 +432,7 @@ async def _get_selection(
432432
f"shape of out argument doesn't match. Expected {indexer.shape}, got {out.shape}"
433433
)
434434
else:
435-
out_buffer = factory(
435+
out_buffer = prototype.nd_buffer.create(
436436
shape=indexer.shape,
437437
dtype=out_dtype,
438438
order=self.order,
@@ -444,7 +444,7 @@ async def _get_selection(
444444
[
445445
(
446446
self.store_path / self.metadata.encode_chunk_key(chunk_coords),
447-
self.metadata.get_chunk_spec(chunk_coords, self.order),
447+
self.metadata.get_chunk_spec(chunk_coords, self.order, prototype=prototype),
448448
chunk_selection,
449449
out_selection,
450450
)
@@ -456,14 +456,14 @@ async def _get_selection(
456456
return out_buffer.as_ndarray_like()
457457

458458
async def getitem(
459-
self, selection: Selection, *, factory: Factory.Create = NDBuffer.create
459+
self, selection: Selection, *, prototype: BufferPrototype = default_buffer_prototype
460460
) -> NDArrayLike:
461461
indexer = BasicIndexer(
462462
selection,
463463
shape=self.metadata.shape,
464464
chunk_grid=self.metadata.chunk_grid,
465465
)
466-
return await self._get_selection(indexer, factory=factory)
466+
return await self._get_selection(indexer, prototype=prototype)
467467

468468
async def _save_metadata(self, metadata: ArrayMetadata) -> None:
469469
to_save = metadata.to_buffer_dict()
@@ -475,7 +475,7 @@ async def _set_selection(
475475
indexer: Indexer,
476476
value: NDArrayLike,
477477
*,
478-
factory: Factory.NDArrayLike = NDBuffer.from_ndarray_like,
478+
prototype: BufferPrototype,
479479
fields: Fields | None = None,
480480
) -> None:
481481
# check fields are sensible
@@ -497,14 +497,14 @@ async def _set_selection(
497497
# We accept any ndarray like object from the user and convert it
498498
# to a NDBuffer (or subclass). From this point onwards, we only pass
499499
# Buffer and NDBuffer between components.
500-
value_buffer = factory(value)
500+
value_buffer = prototype.nd_buffer.from_ndarray_like(value)
501501

502502
# merging with existing data and encoding chunks
503503
await self.metadata.codec_pipeline.write(
504504
[
505505
(
506506
self.store_path / self.metadata.encode_chunk_key(chunk_coords),
507-
self.metadata.get_chunk_spec(chunk_coords, self.order),
507+
self.metadata.get_chunk_spec(chunk_coords, self.order, prototype),
508508
chunk_selection,
509509
out_selection,
510510
)
@@ -518,14 +518,14 @@ async def setitem(
518518
self,
519519
selection: Selection,
520520
value: NDArrayLike,
521-
factory: Factory.NDArrayLike = NDBuffer.from_ndarray_like,
521+
prototype: BufferPrototype = default_buffer_prototype,
522522
) -> None:
523523
indexer = BasicIndexer(
524524
selection,
525525
shape=self.metadata.shape,
526526
chunk_grid=self.metadata.chunk_grid,
527527
)
528-
return await self._set_selection(indexer, value, factory=factory)
528+
return await self._set_selection(indexer, value, prototype=prototype)
529529

530530
async def resize(
531531
self, new_shape: ChunkCoords, delete_outside_chunks: bool = True
@@ -714,7 +714,9 @@ def __setitem__(self, selection: Selection, value: NDArrayLike) -> None:
714714
def get_basic_selection(
715715
self,
716716
selection: BasicSelection = Ellipsis,
717+
*,
717718
out: NDBuffer | None = None,
719+
prototype: BufferPrototype = default_buffer_prototype,
718720
fields: Fields | None = None,
719721
) -> NDArrayLike:
720722
if self.shape == ():
@@ -725,57 +727,101 @@ def get_basic_selection(
725727
BasicIndexer(selection, self.shape, self.metadata.chunk_grid),
726728
out=out,
727729
fields=fields,
730+
prototype=prototype,
728731
)
729732
)
730733

731734
def set_basic_selection(
732-
self, selection: BasicSelection, value: NDArrayLike, fields: Fields | None = None
735+
self,
736+
selection: BasicSelection,
737+
value: NDArrayLike,
738+
*,
739+
fields: Fields | None = None,
740+
prototype: BufferPrototype = default_buffer_prototype,
733741
) -> None:
734742
indexer = BasicIndexer(selection, self.shape, self.metadata.chunk_grid)
735-
sync(self._async_array._set_selection(indexer, value, fields=fields))
743+
sync(self._async_array._set_selection(indexer, value, fields=fields, prototype=prototype))
736744

737745
def get_orthogonal_selection(
738746
self,
739747
selection: OrthogonalSelection,
748+
*,
740749
out: NDBuffer | None = None,
741750
fields: Fields | None = None,
751+
prototype: BufferPrototype = default_buffer_prototype,
742752
) -> NDArrayLike:
743753
indexer = OrthogonalIndexer(selection, self.shape, self.metadata.chunk_grid)
744-
return sync(self._async_array._get_selection(indexer=indexer, out=out, fields=fields))
754+
return sync(
755+
self._async_array._get_selection(
756+
indexer=indexer, out=out, fields=fields, prototype=prototype
757+
)
758+
)
745759

746760
def set_orthogonal_selection(
747-
self, selection: OrthogonalSelection, value: NDArrayLike, fields: Fields | None = None
761+
self,
762+
selection: OrthogonalSelection,
763+
value: NDArrayLike,
764+
*,
765+
fields: Fields | None = None,
766+
prototype: BufferPrototype = default_buffer_prototype,
748767
) -> None:
749768
indexer = OrthogonalIndexer(selection, self.shape, self.metadata.chunk_grid)
750-
return sync(self._async_array._set_selection(indexer, value, fields=fields))
769+
return sync(
770+
self._async_array._set_selection(indexer, value, fields=fields, prototype=prototype)
771+
)
751772

752773
def get_mask_selection(
753-
self, mask: MaskSelection, out: NDBuffer | None = None, fields: Fields | None = None
774+
self,
775+
mask: MaskSelection,
776+
*,
777+
out: NDBuffer | None = None,
778+
fields: Fields | None = None,
779+
prototype: BufferPrototype = default_buffer_prototype,
754780
) -> NDArrayLike:
755781
indexer = MaskIndexer(mask, self.shape, self.metadata.chunk_grid)
756-
return sync(self._async_array._get_selection(indexer=indexer, out=out, fields=fields))
782+
return sync(
783+
self._async_array._get_selection(
784+
indexer=indexer, out=out, fields=fields, prototype=prototype
785+
)
786+
)
757787

758788
def set_mask_selection(
759-
self, mask: MaskSelection, value: NDArrayLike, fields: Fields | None = None
789+
self,
790+
mask: MaskSelection,
791+
value: NDArrayLike,
792+
*,
793+
fields: Fields | None = None,
794+
prototype: BufferPrototype = default_buffer_prototype,
760795
) -> None:
761796
indexer = MaskIndexer(mask, self.shape, self.metadata.chunk_grid)
762-
sync(self._async_array._set_selection(indexer, value, fields=fields))
797+
sync(self._async_array._set_selection(indexer, value, fields=fields, prototype=prototype))
763798

764799
def get_coordinate_selection(
765800
self,
766801
selection: CoordinateSelection,
802+
*,
767803
out: NDBuffer | None = None,
768804
fields: Fields | None = None,
805+
prototype: BufferPrototype = default_buffer_prototype,
769806
) -> NDArrayLike:
770807
indexer = CoordinateIndexer(selection, self.shape, self.metadata.chunk_grid)
771-
out_array = sync(self._async_array._get_selection(indexer=indexer, out=out, fields=fields))
808+
out_array = sync(
809+
self._async_array._get_selection(
810+
indexer=indexer, out=out, fields=fields, prototype=prototype
811+
)
812+
)
772813

773814
# restore shape
774815
out_array = out_array.reshape(indexer.sel_shape)
775816
return out_array
776817

777818
def set_coordinate_selection(
778-
self, selection: CoordinateSelection, value: NDArrayLike, fields: Fields | None = None
819+
self,
820+
selection: CoordinateSelection,
821+
value: NDArrayLike,
822+
*,
823+
fields: Fields | None = None,
824+
prototype: BufferPrototype = default_buffer_prototype,
779825
) -> None:
780826
# setup indexer
781827
indexer = CoordinateIndexer(selection, self.shape, self.metadata.chunk_grid)
@@ -792,25 +838,33 @@ def set_coordinate_selection(
792838
if hasattr(value, "shape") and len(value.shape) > 1:
793839
value = value.reshape(-1)
794840

795-
sync(self._async_array._set_selection(indexer, value, fields=fields))
841+
sync(self._async_array._set_selection(indexer, value, fields=fields, prototype=prototype))
796842

797843
def get_block_selection(
798844
self,
799845
selection: BlockSelection,
846+
*,
800847
out: NDBuffer | None = None,
801848
fields: Fields | None = None,
849+
prototype: BufferPrototype = default_buffer_prototype,
802850
) -> NDArrayLike:
803851
indexer = BlockIndexer(selection, self.shape, self.metadata.chunk_grid)
804-
return sync(self._async_array._get_selection(indexer=indexer, out=out, fields=fields))
852+
return sync(
853+
self._async_array._get_selection(
854+
indexer=indexer, out=out, fields=fields, prototype=prototype
855+
)
856+
)
805857

806858
def set_block_selection(
807859
self,
808860
selection: BlockSelection,
809861
value: NDArrayLike,
862+
*,
810863
fields: Fields | None = None,
864+
prototype: BufferPrototype = default_buffer_prototype,
811865
) -> None:
812866
indexer = BlockIndexer(selection, self.shape, self.metadata.chunk_grid)
813-
sync(self._async_array._set_selection(indexer, value, fields=fields))
867+
sync(self._async_array._set_selection(indexer, value, fields=fields, prototype=prototype))
814868

815869
@property
816870
def vindex(self) -> VIndex:

src/zarr/array_spec.py

Lines changed: 41 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,41 @@
1+
from __future__ import annotations
2+
3+
from dataclasses import dataclass
4+
from typing import Any, Literal
5+
6+
import numpy as np
7+
8+
from zarr.buffer import BufferPrototype
9+
from zarr.common import ChunkCoords, parse_dtype, parse_fill_value, parse_order, parse_shapelike
10+
11+
12+
@dataclass(frozen=True)
class ArraySpec:
    """Immutable bundle of shape, dtype, fill value, memory order and buffer prototype.

    The constructor is hand-written rather than dataclass-generated so that
    ``shape``, ``dtype``, ``fill_value`` and ``order`` each pass through the
    matching ``parse_*`` helper from ``zarr.common`` before being stored.
    ``prototype`` is stored exactly as given.
    """

    # Stored as the return value of parse_shapelike(shape).
    shape: ChunkCoords
    # Stored as the return value of parse_dtype(dtype).
    dtype: np.dtype[Any]
    # Stored as the return value of parse_fill_value(fill_value).
    fill_value: Any
    # Stored as the return value of parse_order(order).
    order: Literal["C", "F"]
    # Stored unchanged; no parsing or validation is applied here.
    prototype: BufferPrototype

    def __init__(
        self,
        shape: ChunkCoords,
        dtype: np.dtype[Any],
        fill_value: Any,
        order: Literal["C", "F"],
        prototype: BufferPrototype,
    ) -> None:
        """Parse the array-describing arguments and store them on the instance.

        Parameters
        ----------
        shape
            Passed through ``parse_shapelike``.
        dtype
            Passed through ``parse_dtype``.
        fill_value
            Passed through ``parse_fill_value``.
        order
            Passed through ``parse_order``.
        prototype
            Buffer prototype; stored as-is.
        """
        # Run every parser before touching the instance, so an exception from
        # any of them is raised before a single attribute has been assigned.
        shape_parsed = parse_shapelike(shape)
        dtype_parsed = parse_dtype(dtype)
        fill_value_parsed = parse_fill_value(fill_value)
        order_parsed = parse_order(order)

        # The dataclass is frozen, so ordinary ``self.x = ...`` assignment
        # raises; ``object.__setattr__`` bypasses the frozen guard.
        object.__setattr__(self, "shape", shape_parsed)
        object.__setattr__(self, "dtype", dtype_parsed)
        object.__setattr__(self, "fill_value", fill_value_parsed)
        object.__setattr__(self, "order", order_parsed)
        object.__setattr__(self, "prototype", prototype)

    @property
    def ndim(self) -> int:
        """Number of dimensions, i.e. the length of ``shape``."""
        return len(self.shape)

0 commit comments

Comments
 (0)