diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 83c7b154658..a985fbf035e 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -82,6 +82,7 @@ from .options import OPTIONS, _get_keep_attrs from .pycompat import is_duck_dask_array, sparse_array_type from .utils import ( + CopyableMutableMapping, Default, Frozen, HybridMappingProxy, @@ -704,7 +705,7 @@ class Dataset(DataWithCoords, DatasetArithmetic, Mapping): _encoding: Optional[Dict[Hashable, Any]] _close: Optional[Callable[[], None]] _indexes: Optional[Dict[Hashable, Index]] - _variables: Dict[Hashable, Variable] + _variables: CopyableMutableMapping[Hashable, Variable] __slots__ = ( "_attrs", @@ -1068,14 +1069,14 @@ def persist(self, **kwargs) -> "Dataset": @classmethod def _construct_direct( cls, - variables, - coord_names, - dims=None, - attrs=None, - indexes=None, - encoding=None, - close=None, - ): + variables: Mapping[Hashable, Variable], + coord_names: Set[Hashable], + dims: Dict[Hashable, int] = None, + attrs: Dict[Hashable, Any] = None, + indexes: Dict[Hashable, Index] = None, + encoding: Dict[Hashable, Any] = None, + close: Callable[[], None] = None, + ) -> "Dataset": """Shortcut around __init__ for internal use when we want to skip costly validation """ @@ -1093,7 +1094,7 @@ def _construct_direct( def _replace( self, - variables: Dict[Hashable, Variable] = None, + variables: Mapping[Hashable, Variable] = None, coord_names: Set[Hashable] = None, dims: Dict[Any, int] = None, attrs: Union[Dict[Hashable, Any], None, Default] = _default, @@ -2359,7 +2360,7 @@ def isel( indexers = drop_dims_from_indexers(indexers, self.dims, missing_dims) variables = {} - dims: Dict[Hashable, Tuple[int, ...]] = {} + dims: Dict[Hashable, int] = {} coord_names = self._coord_names.copy() indexes = self._indexes.copy() if self._indexes is not None else None diff --git a/xarray/core/utils.py b/xarray/core/utils.py index 89e3714ffff..edc0f195484 100644 --- a/xarray/core/utils.py +++ 
b/xarray/core/utils.py
@@ -16,17 +16,23 @@
     Container,
     Dict,
     Hashable,
+    ItemsView,
     Iterable,
     Iterator,
+    KeysView,
     Mapping,
     MutableMapping,
     MutableSet,
     Optional,
+    Protocol,
     Sequence,
     Tuple,
     TypeVar,
     Union,
+    ValuesView,
     cast,
+    overload,
+    runtime_checkable,
 )
 
 import numpy as np
@@ -444,6 +450,106 @@ def compat_dict_union(
     return new_dict
 
 
+TCopyableMutableMapping = TypeVar(
+    "TCopyableMutableMapping", bound="CopyableMutableMapping"
+)
+
+
+@runtime_checkable
+class CopyableMutableMapping(Protocol[K, V]):
+    """
+    Protocol for the type behaviour of a class which acts essentially like a
+    mutable mapping plus a copy method.
+
+    Classes implementing this protocol are used to store variables inside
+    Dataset internally. This type flexibility allows someone to extend Dataset
+    to have different rules on what variables can be stored, which they would
+    do by implementing their own custom VariableMapping class. As long as
+    their class conforms to this protocol then all type checks will pass.
+
+    By default the storage class is just a dict[Hashable, Variable].
+
+    (It would be nice to inherit from MutableMapping[Hashable, Variable] so
+    that we didn't have to write out all these abstract typed methods, but
+    protocols can only inherit from protocols, and MutableMapping is not a
+    Protocol.)
+
+    The protocol is decorated with ``runtime_checkable`` so that it can be
+    used in ``isinstance`` checks. Such checks only test that the methods
+    exist, not that their signatures match.
+    """
+
+    def __len__(self) -> int:
+        ...
+
+    def __iter__(self) -> Iterator[K]:
+        ...
+
+    def __contains__(self, key: K) -> bool:
+        ...
+
+    def __getitem__(self, key: K) -> V:
+        ...
+
+    # NOTE(review): ``default`` has no default value here, so the protocol
+    # only describes the two-argument call ``get(key, default)`` -- confirm
+    # whether the one-argument form should be covered as well.
+    def get(self, key: K, default: Optional[V]):
+        ...
+
+    def keys(self) -> KeysView[K]:
+        ...
+
+    def items(self) -> ItemsView[K, V]:
+        ...
+
+    def values(self) -> ValuesView[V]:
+        ...
+
+    def __eq__(self, other: Any) -> bool:
+        ...
+
+    def __setitem__(self, key: K, value: V) -> None:
+        ...
+
+    def __delitem__(self, key: K) -> None:
+        ...
+
+    @overload
+    def pop(self, key: K) -> V:
+        ...
+
+    @overload
+    def pop(self, key: K, default: V = ...) -> V:
+        ...
+
+    def pop(self, key, default=None):
+        ...
+
+    def popitem(self) -> Tuple[K, V]:
+        ...
+
+    @overload
+    def update(self, other: Mapping[K, V], **kwargs: V) -> None:
+        ...
+
+    @overload
+    def update(self, other: Iterable[Tuple[K, V]], **kwargs: V) -> None:
+        ...
+
+    @overload
+    def update(self, **kwargs: V) -> None:
+        ...
+
+    def update(self, other, **kwargs):
+        ...
+
+    # Annotating ``self`` with the bound TypeVar makes ``copy`` return the
+    # implementing class's own type rather than the protocol type.
+    def copy(self: TCopyableMutableMapping) -> TCopyableMutableMapping:
+        ...
+
+
 class Frozen(Mapping[K, V]):
     """Wrapper around an object implementing the mapping interface to make it
     immutable. If you really want to modify the mapping, the mutable version is
diff --git a/xarray/testing.py b/xarray/testing.py
index 40ca12852b9..2385ad8209c 100644
--- a/xarray/testing.py
+++ b/xarray/testing.py
@@ -9,6 +9,7 @@
 from xarray.core.dataarray import DataArray
 from xarray.core.dataset import Dataset
 from xarray.core.indexes import Index, default_indexes
+from xarray.core.utils import CopyableMutableMapping
 from xarray.core.variable import IndexVariable, Variable
 
 __all__ = (
@@ -306,7 +307,7 @@ def _assert_dataarray_invariants(da: DataArray):
 
 
 def _assert_dataset_invariants(ds: Dataset):
-    assert isinstance(ds._variables, dict), type(ds._variables)
+    assert isinstance(ds._variables, CopyableMutableMapping), type(ds._variables)
     assert all(isinstance(v, Variable) for v in ds._variables.values()), ds._variables
     for k, v in ds._variables.items():
         _assert_variable_invariants(v, k)
diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py
index 16148c21b43..74bd07d7c15 100644
--- a/xarray/tests/test_dataset.py
+++ b/xarray/tests/test_dataset.py
@@ -1,9 +1,22 @@
 import pickle
 import sys
 import warnings
+from collections.abc import ItemsView, KeysView, Mapping, ValuesView
 from copy import copy, deepcopy
 from io import StringIO
 from textwrap import dedent
+from typing import (
+    Any,
+    Dict,
+    Generic,
+    Hashable,
+    Iterable,
+    Iterator,
+    Optional,
+    Tuple,
+    TypeVar,
+    overload,
+)
 
 import numpy as np
 import pandas as pd
@@ -29,7 +42,7 @@ from
xarray.core.common import duck_array_ops, full_like
 from xarray.core.indexes import Index
 from xarray.core.pycompat import integer_types, sparse_array_type
-from xarray.core.utils import is_scalar
+from xarray.core.utils import CopyableMutableMapping, is_scalar
 
 from . import (
     InaccessibleArray,
@@ -6588,3 +6601,148 @@ def test_string_keys_typing() -> None:
     ds = xr.Dataset(dict(x=da))
     mapping = {"y": da}
     ds.assign(variables=mapping)
+
+
+TCustomMirroredMapping = TypeVar(
+    "TCustomMirroredMapping", bound="CustomMirroredMapping"
+)
+K = TypeVar("K")  # keys type
+V = TypeVar("V")  # values type
+
+
+class CustomMirroredMapping(CopyableMutableMapping[K, V]):
+    """
+    Test implementation of the CopyableMutableMapping protocol.
+
+    Whilst this implementation presents a dict-like API, it differs from dict
+    by storing a copy of every key-value pair in a hidden "mirror" dict
+    internally. The behaviour is therefore the exact opposite of the custom
+    variable mapping we need for DataTree: instead of ensuring no collisions
+    between variables and children, it ensures that there is a collision
+    between every stored object and its mirrored duplicate.
+
+    Every mutating method writes to both dicts, so ``_original`` and
+    ``_mirrored`` always hold the same keys and values.
+
+    See GH issue #6086 and GH pull #5961 for more context.
+    """
+
+    # Sentinel distinguishing "no default given" from ``default=None`` in pop.
+    _MISSING = object()
+
+    def __init__(self, *args, **kwargs):
+        self._original = dict(*args, **kwargs)
+        # Copy the first dict instead of re-reading ``args``: a one-shot
+        # iterable of pairs would already be exhausted on a second pass.
+        self._mirrored = dict(self._original)
+
+    def __len__(self) -> int:
+        return len(self._original)
+
+    def __iter__(self) -> Iterator[K]:
+        return iter(self._original)
+
+    def __contains__(self, key: K) -> bool:
+        return key in self._original
+
+    def __getitem__(self, key: K) -> V:
+        return self._original[key]
+
+    def get(self, key: K, default: Optional[V] = None):
+        # dict.get accepts its arguments positionally only.
+        return self._original.get(key, default)
+
+    # The view annotations below are quoted so that the subscripted
+    # collections.abc classes are never evaluated at runtime (they are only
+    # subscriptable from Python 3.9 onwards).
+    def keys(self) -> "KeysView[K]":
+        return self._original.keys()
+
+    def items(self) -> "ItemsView[K, V]":
+        return self._original.items()
+
+    def values(self) -> "ValuesView[V]":
+        return self._original.values()
+
+    def __eq__(self, other: Any) -> bool:
+        return self._original == other
+
+    def __setitem__(self, key: K, value: V):
+        self._original[key] = value
+        self._mirrored[key] = value
+
+    def __delitem__(self, key: K):
+        del self._original[key]
+        del self._mirrored[key]
+
+    def pop(self, key, default=_MISSING):
+        try:
+            value = self._original.pop(key)
+        except KeyError:
+            if default is self._MISSING:
+                raise
+            return default
+        del self._mirrored[key]
+        return value
+
+    def popitem(self) -> Tuple[K, V]:
+        key, value = self._original.popitem()
+        del self._mirrored[key]
+        return key, value
+
+    def update(self, other=(), **kwargs):
+        # dict() normalises every accepted input form (mapping, iterable of
+        # pairs, keyword arguments); going through __setitem__ keeps the
+        # mirror in sync.
+        for key, value in dict(other, **kwargs).items():
+            self[key] = value
+
+    def copy(self: TCustomMirroredMapping) -> TCustomMirroredMapping:
+        # type(self) preserves subclasses, matching the annotated return type.
+        duplicate = type(self)()
+        duplicate._original = self._original.copy()
+        duplicate._mirrored = self._mirrored.copy()
+        return duplicate
+
+
+class TestCustomVariableMapping:
+    def test_instantiate(self):
+        var1 = Variable(data=0, dims=())
+        var2 = Variable(data=1, dims=())
+        cm = CustomMirroredMapping({"a": var1, "b": var2})
+        assert list(cm) == ["a", "b"]
+        assert cm["a"] is var1
+        assert cm._mirrored["b"] is var2
+
+    def test_construct_direct_dataset(self):
+        var1 = Variable(data=0, dims=())
+        var2 = Variable(data=1, dims=())
+        cm: CustomMirroredMapping[Any, Variable] = CustomMirroredMapping(
+            {"a": var1, "b": var2}
+        )
+
+        expected = Dataset._construct_direct(
+            variables={"a": var1, "b": var2}, coord_names=set(), dims=None
+        )
+        actual = Dataset._construct_direct(variables=cm, coord_names=set(), dims=None)
+        assert_equal(actual, expected)
+
+    def test_replace(self):
+        # TODO: not implemented yet
+        ...
+
+    def test_mirror_new_variable(self):
+        var1 = Variable(data=0, dims=())
+        var2 = Variable(data=1, dims=())
+        cm: CustomMirroredMapping[Any, Variable] = CustomMirroredMapping({"a": var1})
+
+        cm["b"] = var2
+        assert cm._mirrored["b"] is var2
+
+        cm.update({"c": var1})
+        assert cm._mirrored["c"] is var1
+
+        assert cm.pop("b") is var2
+        assert "b" not in cm._mirrored
+        assert cm.pop("b", None) is None
+        assert cm.get("missing") is None