Improve performance and typing in _protect_dataset_variables_inplace #9069


Draft
wants to merge 2 commits into main
18 changes: 10 additions & 8 deletions xarray/backends/api.py
@@ -55,6 +55,7 @@
     JoinOptions,
     NestedSequence,
     T_Chunks,
+    T_Dataset,
 )

 T_NetcdfEngine = Literal["netcdf4", "scipy", "h5netcdf"]
@@ -231,14 +232,15 @@ def _get_mtime(filename_or_obj):
     return mtime


-def _protect_dataset_variables_inplace(dataset, cache):
+def _protect_dataset_variables_inplace(dataset: T_Dataset, cache: bool) -> None:
     for name, variable in dataset.variables.items():
         if name not in dataset._indexes:
             # no need to protect IndexVariable objects
             data = indexing.CopyOnWriteArray(variable._data)
             if cache:
-                data = indexing.MemoryCachedArray(data)
-            variable.data = data
+                variable._data = indexing.MemoryCachedArray(data)
+            else:
+                variable._data = data


 def _finalize_store(write, store):
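The core behavioural change is in the hunk above: the wrapping array is written straight to the private Variable._data attribute instead of going through the public Variable.data setter, which coerces the value via as_compatible_data and checks its shape for every variable, presumably the per-variable overhead the PR title refers to. Below is a minimal sketch of that pattern, assuming only public xarray objects plus the internal xarray.core.indexing wrappers used in the diff; protect_variables is a made-up stand-in name, not the library helper itself.

import numpy as np
import xarray as xr
from xarray.core import indexing


def protect_variables(dataset: xr.Dataset, cache: bool) -> None:
    # Illustrative copy of the updated helper's logic (hypothetical name).
    for name, variable in dataset.variables.items():
        if name not in dataset._indexes:
            # no need to protect IndexVariable objects
            data = indexing.CopyOnWriteArray(variable._data)
            if cache:
                # assigning to _data bypasses the setter's per-variable validation
                variable._data = indexing.MemoryCachedArray(data)
            else:
                variable._data = data


ds = xr.Dataset({"a": ("x", np.arange(4))})
protect_variables(ds, cache=True)
print(type(ds.variables["a"]._data))  # MemoryCachedArray wrapper kept in place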
@@ -305,7 +307,7 @@ def load_dataarray(filename_or_obj, **kwargs):


 def _chunk_ds(
-    backend_ds,
+    backend_ds: T_Dataset,
     filename_or_obj,
     engine,
     chunks,
@@ -314,15 +316,15 @@
     chunked_array_type,
     from_array_kwargs,
     **extra_tokens,
-):
+) -> T_Dataset:
     chunkmanager = guess_chunkmanager(chunked_array_type)

     # TODO refactor to move this dask-specific logic inside the DaskManager class
     if isinstance(chunkmanager, DaskManager):
         from dask.base import tokenize

         mtime = _get_mtime(filename_or_obj)
-        token = tokenize(filename_or_obj, mtime, engine, chunks, **extra_tokens)
+        token: Any = tokenize(filename_or_obj, mtime, engine, chunks, **extra_tokens)
         name_prefix = "open_dataset-"
     else:
         # not used
Expand All @@ -347,7 +349,7 @@ def _chunk_ds(


def _dataset_from_backend_dataset(
backend_ds,
backend_ds: T_Dataset,
filename_or_obj,
engine,
chunks,
@@ -357,7 +359,7 @@
     chunked_array_type,
     from_array_kwargs,
     **extra_tokens,
-):
+) -> T_Dataset:
     if not isinstance(chunks, (int, dict)) and chunks not in {None, "auto"}:
         raise ValueError(
             f"chunks must be an int, dict, 'auto', or None. Instead found {chunks}."
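On the typing side, T_Dataset (a TypeVar bound to Dataset in xarray.core.types) is threaded through the helpers so the dataset type returned by _chunk_ds and _dataset_from_backend_dataset matches the type passed in, and token gets an explicit Any annotation (presumably to keep the type checker happy across the dask and non-dask branches). A small sketch of why a bound TypeVar is preferable to a plain Dataset annotation; passthrough is a hypothetical name used only for illustration, and the local T_Dataset definition is assumed to mirror the one in xarray.core.types.

from typing import TypeVar

import xarray as xr

# Assumed to mirror xarray.core.types.T_Dataset: a TypeVar bound to Dataset.
T_Dataset = TypeVar("T_Dataset", bound=xr.Dataset)


def passthrough(ds: T_Dataset) -> T_Dataset:
    # A type checker sees the return type as the argument's exact type,
    # so a Dataset subclass is not widened to plain Dataset on the way out.
    return ds


same = passthrough(xr.Dataset())  # still typed as Dataset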