Skip to content

Handle kwargs better in store #14

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Aug 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion cubed_xarray/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from importlib.metadata import version


try:
__version__ = version("cubed-xarray")
except Exception:
Expand Down
28 changes: 23 additions & 5 deletions cubed_xarray/cubedmanager.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,15 @@
from __future__ import annotations

from collections.abc import Sequence
from typing import TYPE_CHECKING, Any, Callable, Union
from typing import TYPE_CHECKING, Any, Callable, Iterable, Union

import numpy as np

from tlz import partition

from xarray.namedarray.parallelcompat import ChunkManagerEntrypoint


if TYPE_CHECKING:
from xarray.core.types import T_Chunks, T_NormalizedChunks
from cubed import Array as CubedArray
from xarray.core.types import T_Chunks, T_NormalizedChunks


class CubedManager(ChunkManagerEntrypoint["CubedArray"]):
Expand Down Expand Up @@ -204,6 +201,27 @@ def store(
"""Used when writing to any backend."""
from cubed.core.ops import store

compute = kwargs.pop("compute", True)
if not compute:
raise NotImplementedError("Delayed compute is not supported.")

lock = kwargs.pop("lock", None)
if lock:
raise NotImplementedError("Locking is not supported.")

regions = kwargs.pop("regions", None)
if regions:
# regions is either a tuple of slices or a collection of tuples of slices
if isinstance(regions, tuple):
regions = [regions]
for t in regions:
if not all(r == slice(None) for r in t):
raise NotImplementedError(
"Only whole slices are supported for regions."
)

kwargs.pop("flush", None) # not used

return store(
sources,
targets,
Expand Down
53 changes: 47 additions & 6 deletions cubed_xarray/tests/test_wrapping.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,63 @@
import sys

import cubed
import pytest
import xarray as xr
from cubed.runtime.create import create_executor
from xarray.namedarray.parallelcompat import list_chunkmanagers
import cubed
from xarray.tests import assert_allclose, create_test_data

from cubed_xarray.cubedmanager import CubedManager

EXECUTORS = [create_executor("single-threaded")]

if sys.version_info >= (3, 11):
EXECUTORS.append(create_executor("processes"))


@pytest.fixture(
scope="module",
params=EXECUTORS,
ids=[executor.name for executor in EXECUTORS],
)
def executor(request):
return request.param


class TestDiscoverCubedManager:
def test_list_cubedmanager(self):
chunkmanagers = list_chunkmanagers()
assert 'cubed' in chunkmanagers
assert isinstance(chunkmanagers['cubed'], CubedManager)
assert "cubed" in chunkmanagers
assert isinstance(chunkmanagers["cubed"], CubedManager)

def test_chunk(self):
da = xr.DataArray([1, 2], dims='x')
chunked = da.chunk(x=1, chunked_array_type='cubed')
da = xr.DataArray([1, 2], dims="x")
chunked = da.chunk(x=1, chunked_array_type="cubed")
assert isinstance(chunked.data, cubed.Array)
assert chunked.chunksizes == {'x': (1, 1)}
assert chunked.chunksizes == {"x": (1, 1)}

# TODO test cubed is default when dask not installed

# TODO test dask is default over cubed when both installed


def test_to_zarr(tmpdir, executor):
spec = cubed.Spec(allowed_mem="200MB", executor=executor)

original = create_test_data().chunk(
chunked_array_type="cubed", from_array_kwargs={"spec": spec}
)

filename = tmpdir / "out.zarr"
original.to_zarr(filename)

with xr.open_dataset(
filename,
chunks="auto",
engine="zarr",
chunked_array_type="cubed",
from_array_kwargs={"spec": spec},
) as restored:
assert isinstance(restored.var1.data, cubed.Array)
computed = restored.compute()
assert_allclose(original, computed)