From 627df2272a848e12da2c3f2dd0d4e1e6063c2b4d Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 4 Jun 2024 21:27:59 +0200 Subject: [PATCH 1/2] Add typing --- xarray/backends/api.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/xarray/backends/api.py b/xarray/backends/api.py index 76fcac62cd3..561cd7abc2a 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -55,6 +55,7 @@ JoinOptions, NestedSequence, T_Chunks, + T_Dataset, ) T_NetcdfEngine = Literal["netcdf4", "scipy", "h5netcdf"] @@ -231,14 +232,15 @@ def _get_mtime(filename_or_obj): return mtime -def _protect_dataset_variables_inplace(dataset, cache): +def _protect_dataset_variables_inplace(dataset: T_Dataset, cache: bool) -> None: for name, variable in dataset.variables.items(): if name not in dataset._indexes: # no need to protect IndexVariable objects data = indexing.CopyOnWriteArray(variable._data) if cache: - data = indexing.MemoryCachedArray(data) - variable.data = data + variable.data = indexing.MemoryCachedArray(data) + else: + variable.data = data def _finalize_store(write, store): @@ -305,7 +307,7 @@ def load_dataarray(filename_or_obj, **kwargs): def _chunk_ds( - backend_ds, + backend_ds: T_Dataset, filename_or_obj, engine, chunks, @@ -314,7 +316,7 @@ def _chunk_ds( chunked_array_type, from_array_kwargs, **extra_tokens, -): +) -> T_Dataset: chunkmanager = guess_chunkmanager(chunked_array_type) # TODO refactor to move this dask-specific logic inside the DaskManager class @@ -322,7 +324,7 @@ def _chunk_ds( from dask.base import tokenize mtime = _get_mtime(filename_or_obj) - token = tokenize(filename_or_obj, mtime, engine, chunks, **extra_tokens) + token: Any = tokenize(filename_or_obj, mtime, engine, chunks, **extra_tokens) name_prefix = "open_dataset-" else: # not used @@ -347,7 +349,7 @@ def _chunk_ds( def _dataset_from_backend_dataset( - backend_ds, + backend_ds: T_Dataset, filename_or_obj, engine, chunks, @@ -357,7 +359,7 @@ def _dataset_from_backend_dataset( chunked_array_type, from_array_kwargs, **extra_tokens, -): +) -> T_Dataset: if not isinstance(chunks, (int, dict)) and chunks not in {None, "auto"}: raise ValueError( f"chunks must be an int, dict, 'auto', or None. Instead found {chunks}." From 44199793952a946aeb410996ca043508471b7e38 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 4 Jun 2024 21:32:51 +0200 Subject: [PATCH 2/2] data is already a valid duckarray, no need to recheck --- xarray/backends/api.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/backends/api.py b/xarray/backends/api.py index 561cd7abc2a..a0b9b4ecd18 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -238,9 +238,9 @@ def _protect_dataset_variables_inplace(dataset: T_Dataset, cache: bool) -> None: # no need to protect IndexVariable objects data = indexing.CopyOnWriteArray(variable._data) if cache: - variable.data = indexing.MemoryCachedArray(data) + variable._data = indexing.MemoryCachedArray(data) else: - variable.data = data + variable._data = data def _finalize_store(write, store):