Skip to content

[v3] Buffer ensure correct subclass based on the BufferPrototype argument #1974

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jun 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 24 additions & 1 deletion src/zarr/buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def create_zero_length(cls) -> Self:

@classmethod
def from_array_like(cls, array_like: ArrayLike) -> Self:
"""Create a new buffer of a array-like object
"""Create a new buffer of an array-like object

Parameters
----------
Expand All @@ -159,6 +159,29 @@ def from_array_like(cls, array_like: ArrayLike) -> Self:
"""
return cls(array_like)

@classmethod
def from_buffer(cls, buffer: Buffer) -> Self:
"""Create a new buffer of an existing Buffer

This is useful if you want to ensure that an existing buffer is
of the correct subclass of Buffer. E.g., MemoryStore uses this
to return a buffer instance of the subclass specified by its
BufferPrototype argument.

Typically, this only copies data if the data has to be moved between
memory types, such as from host to device memory.

Parameters
----------
buffer
buffer object.

Returns
-------
A new buffer representing the content of the input buffer
"""
return cls.from_array_like(buffer.as_array_like())

@classmethod
def from_bytes(cls, bytes_like: BytesLike) -> Self:
"""Create a new buffer of a bytes-like object (host memory)
Expand Down
2 changes: 1 addition & 1 deletion src/zarr/store/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ async def get(
try:
value = self._store_dict[key]
start, length = _normalize_interval_index(value, byte_range)
return value[start : start + length]
return prototype.buffer.from_buffer(value[start : start + length])
except KeyError:
return None

Expand Down
6 changes: 3 additions & 3 deletions src/zarr/store/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import fsspec

from zarr.abc.store import Store
from zarr.buffer import Buffer, BufferPrototype, default_buffer_prototype
from zarr.buffer import Buffer, BufferPrototype
from zarr.common import OpenMode
from zarr.store.core import _dereference_path

Expand Down Expand Up @@ -84,7 +84,7 @@ def __repr__(self) -> str:
async def get(
self,
key: str,
prototype: BufferPrototype = default_buffer_prototype,
prototype: BufferPrototype,
byte_range: tuple[int | None, int | None] | None = None,
) -> Buffer | None:
path = _dereference_path(self.path, key)
Expand All @@ -99,7 +99,7 @@ async def get(
end = length
else:
end = None
value: Buffer = prototype.buffer.from_bytes(
value = prototype.buffer.from_bytes(
await (
self._fs._cat_file(path, start=byte_range[0], end=end)
if byte_range
Expand Down
5 changes: 4 additions & 1 deletion tests/v3/test_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,10 @@ async def get(
) -> Buffer | None:
if "json" not in key:
assert prototype.buffer is MyBuffer
return await super().get(key, byte_range)
ret = await super().get(key=key, prototype=prototype, byte_range=byte_range)
if ret is not None:
assert isinstance(ret, prototype.buffer)
return ret


def test_nd_array_like(xp):
Expand Down
2 changes: 1 addition & 1 deletion tests/v3/test_store/test_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ async def test_basic():
data = b"hello"
await store.set("foo", Buffer.from_bytes(data))
assert await store.exists("foo")
assert (await store.get("foo")).to_bytes() == data
assert (await store.get("foo", prototype=default_buffer_prototype)).to_bytes() == data
out = await store.get_partial_values(
prototype=default_buffer_prototype, key_ranges=[("foo", (1, None))]
)
Expand Down