Skip to content

Type protocol for internal variable mapping #6086

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
23 changes: 12 additions & 11 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@
from .options import OPTIONS, _get_keep_attrs
from .pycompat import is_duck_dask_array, sparse_array_type
from .utils import (
CopyableMutableMapping,
Default,
Frozen,
HybridMappingProxy,
Expand Down Expand Up @@ -704,7 +705,7 @@ class Dataset(DataWithCoords, DatasetArithmetic, Mapping):
_encoding: Optional[Dict[Hashable, Any]]
_close: Optional[Callable[[], None]]
_indexes: Optional[Dict[Hashable, Index]]
_variables: Dict[Hashable, Variable]
_variables: CopyableMutableMapping[Hashable, Variable]

__slots__ = (
"_attrs",
Expand Down Expand Up @@ -1068,14 +1069,14 @@ def persist(self, **kwargs) -> "Dataset":
@classmethod
def _construct_direct(
cls,
variables,
coord_names,
dims=None,
attrs=None,
indexes=None,
encoding=None,
close=None,
):
variables: Mapping[Hashable, Variable],
coord_names: Set[Hashable],
dims: Dict[Hashable, int] = None,
attrs: Dict[Hashable, Any] = None,
indexes: Dict[Hashable, Index] = None,
encoding: Dict[Hashable, Any] = None,
close: Callable[[], None] = None,
) -> "Dataset":
"""Shortcut around __init__ for internal use when we want to skip
costly validation
"""
Expand All @@ -1093,7 +1094,7 @@ def _construct_direct(

def _replace(
self,
variables: Dict[Hashable, Variable] = None,
variables: Mapping[Hashable, Variable] = None,
coord_names: Set[Hashable] = None,
dims: Dict[Any, int] = None,
attrs: Union[Dict[Hashable, Any], None, Default] = _default,
Expand Down Expand Up @@ -2359,7 +2360,7 @@ def isel(
indexers = drop_dims_from_indexers(indexers, self.dims, missing_dims)

variables = {}
dims: Dict[Hashable, Tuple[int, ...]] = {}
dims: Dict[Hashable, int] = {}
coord_names = self._coord_names.copy()
indexes = self._indexes.copy() if self._indexes is not None else None

Expand Down
93 changes: 93 additions & 0 deletions xarray/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,23 @@
Container,
Dict,
Hashable,
ItemsView,
Iterable,
Iterator,
KeysView,
Mapping,
MutableMapping,
MutableSet,
Optional,
Protocol,
Sequence,
Tuple,
TypeVar,
Union,
ValuesView,
cast,
overload,
runtime_checkable,
)

import numpy as np
Expand Down Expand Up @@ -444,6 +450,93 @@ def compat_dict_union(
return new_dict


# Self-type for ``copy`` so that implementations return their own concrete type.
TCopyableMutableMapping = TypeVar(
    "TCopyableMutableMapping", bound="CopyableMutableMapping"
)


@runtime_checkable
class CopyableMutableMapping(Protocol[K, V]):
    """
    Protocol for the type behaviour of a class which acts essentially like a
    mutable mapping plus a ``copy`` method.

    Classes implementing this protocol are used to store variables inside
    Dataset internally.

    This type flexibility allows someone to extend Dataset to have different
    rules on what variables can be stored, which they would do by implementing
    their own custom VariableMapping class. As long as their class conforms to
    this protocol then all type checks will pass.

    By default the storage class is just a dict[Hashable, Variable].

    (It would be nice to inherit from MutableMapping[Hashable, Variable] so
    that we didn't have to write out all these abstract typed methods, but
    protocols can only inherit from protocols, and MutableMapping is not a
    Protocol.)

    Notes
    -----
    The class is decorated with ``runtime_checkable``, so ``isinstance`` only
    verifies that the listed members exist; it does not validate their
    signatures or the key/value types.
    """

    # ---- read-only mapping interface -------------------------------------

    def __len__(self) -> int:
        ...

    def __iter__(self) -> Iterator[K]:
        ...

    def __contains__(self, key: K) -> bool:
        ...

    def __getitem__(self, key: K) -> V:
        ...

    # ``default`` is optional, matching ``dict.get``, so ``mapping.get(key)``
    # type-checks against this protocol.
    def get(self, key: K, default: Optional[V] = None) -> Optional[V]:
        ...

    def keys(self) -> KeysView[K]:
        ...

    def items(self) -> ItemsView[K, V]:
        ...

    def values(self) -> ValuesView[V]:
        ...

    def __eq__(self, other: Any) -> bool:
        ...

    # ---- mutation --------------------------------------------------------

    def __setitem__(self, key: K, value: V) -> None:
        ...

    def __delitem__(self, key: K) -> None:
        ...

    @overload
    def pop(self, key: K) -> V:
        ...

    # ``default`` is required in this overload so that it does not overlap
    # with the single-argument form above.
    @overload
    def pop(self, key: K, default: V) -> V:
        ...

    def pop(self, key, default=None):
        ...

    def popitem(self) -> Tuple[K, V]:
        ...

    @overload
    def update(self, other: Mapping[K, V], **kwargs: V) -> None:
        ...

    @overload
    def update(self, other: Iterable[Tuple[K, V]], **kwargs: V) -> None:
        ...

    @overload
    def update(self, **kwargs: V) -> None:
        ...

    # ``other`` needs a default here, otherwise the implementation signature
    # cannot accept the keyword-only overload declared above.
    def update(self, other=(), **kwargs):
        ...

    # ---- copying ---------------------------------------------------------

    def copy(self: TCopyableMutableMapping) -> TCopyableMutableMapping:
        ...


class Frozen(Mapping[K, V]):
"""Wrapper around an object implementing the mapping interface to make it
immutable. If you really want to modify the mapping, the mutable version is
Expand Down
3 changes: 2 additions & 1 deletion xarray/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset
from xarray.core.indexes import Index, default_indexes
from xarray.core.utils import CopyableMutableMapping
from xarray.core.variable import IndexVariable, Variable

__all__ = (
Expand Down Expand Up @@ -306,7 +307,7 @@ def _assert_dataarray_invariants(da: DataArray):


def _assert_dataset_invariants(ds: Dataset):
assert isinstance(ds._variables, dict), type(ds._variables)
assert isinstance(ds._variables, CopyableMutableMapping), type(ds._variables)
assert all(isinstance(v, Variable) for v in ds._variables.values()), ds._variables
for k, v in ds._variables.items():
_assert_variable_invariants(v, k)
Expand Down
115 changes: 114 additions & 1 deletion xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,22 @@
import pickle
import sys
import warnings
from collections.abc import ItemsView, KeysView, Mapping, ValuesView
from copy import copy, deepcopy
from io import StringIO
from textwrap import dedent
from typing import (
Any,
Dict,
Generic,
Hashable,
Iterable,
Iterator,
Optional,
Tuple,
TypeVar,
overload,
)

import numpy as np
import pandas as pd
Expand All @@ -29,7 +42,7 @@
from xarray.core.common import duck_array_ops, full_like
from xarray.core.indexes import Index
from xarray.core.pycompat import integer_types, sparse_array_type
from xarray.core.utils import is_scalar
from xarray.core.utils import CopyableMutableMapping, is_scalar

from . import (
InaccessibleArray,
Expand Down Expand Up @@ -6588,3 +6601,103 @@ def test_string_keys_typing() -> None:
ds = xr.Dataset(dict(x=da))
mapping = {"y": da}
ds.assign(variables=mapping)


TCustomMirroredMapping = TypeVar(
    "TCustomMirroredMapping", bound="CustomMirroredMapping"
)
K = TypeVar("K")  # keys type
V = TypeVar("V")  # values type


class CustomMirroredMapping(CopyableMutableMapping[K, V]):
    """
    Test implementation of the CopyableMutableMapping protocol.

    Whilst this implementation presents a dict-like API, it differs from dict
    by storing a copy of every key-value pair in a hidden "mirror" dict
    internally. The behaviour is therefore the exact opposite of the custom
    variable mapping we need for DataTree: instead of ensuring no collisions
    between variables and children, it ensures that there is a collision
    between every stored object and its mirrored duplicate.

    Every mutating method updates both ``_original`` and ``_mirrored`` so the
    two dicts always hold the same key-value pairs.

    See GH issue #6086 and GH pull #5961 for more context.
    """

    def __init__(self, *args, **kwargs):
        """Accept the same arguments as ``dict`` and populate both stores."""
        self._original = dict(*args, **kwargs)
        # Copy the already-built dict instead of calling ``dict(*args)`` a
        # second time: a one-shot iterator argument would be exhausted by the
        # first call and leave the mirror empty.
        self._mirrored = dict(self._original)

    def __repr__(self) -> str:
        return f"{type(self).__name__}({self._original!r})"

    def __len__(self) -> int:
        return len(self._original)

    def __iter__(self) -> Iterator[K]:
        return iter(self._original)

    def __contains__(self, key: K) -> bool:
        return key in self._original

    def __getitem__(self, key: K) -> V:
        return self._original[key]

    def get(self, key: K, default: Optional[V] = None) -> Optional[V]:
        """Return the value stored under ``key``, or ``default`` if absent."""
        # ``dict.get`` accepts positional arguments only; passing
        # ``default=`` as a keyword raises TypeError.
        return self._original.get(key, default)

    # The view annotations are quoted because the ``collections.abc`` views
    # are only subscriptable at runtime on Python 3.9+.
    def keys(self) -> "KeysView[K]":
        return self._original.keys()

    def items(self) -> "ItemsView[K, V]":
        return self._original.items()

    def values(self) -> "ValuesView[V]":
        return self._original.values()

    def __eq__(self, other: Any) -> bool:
        return self._original == other

    def __setitem__(self, key: K, value: V) -> None:
        self._original[key] = value
        self._mirrored[key] = value

    def __delitem__(self, key: K) -> None:
        del self._original[key]
        del self._mirrored[key]

    def pop(self, key, default=None):
        """Remove ``key`` from both stores and return its value.

        The signature mirrors the protocol's implementation stub, so a
        missing key returns ``default`` (``None`` unless given) rather than
        raising ``KeyError``.
        """
        self._mirrored.pop(key, None)
        return self._original.pop(key, default)

    def popitem(self) -> Tuple[K, V]:
        """Remove and return the last inserted pair from both stores.

        Raises ``KeyError`` if the mapping is empty (propagated from
        ``dict.popitem``).
        """
        key, value = self._original.popitem()
        # Delete by key rather than calling ``popitem`` on the mirror so the
        # same entry is guaranteed to be removed from both stores.
        del self._mirrored[key]
        return key, value

    def update(self, other=(), **kwargs):
        """Merge ``other`` (a mapping or iterable of pairs) and ``kwargs``
        into both stores."""
        # Materialise once so an iterator argument is not consumed twice.
        additions = dict(other, **kwargs)
        self._original.update(additions)
        self._mirrored.update(additions)

    def copy(self: TCustomMirroredMapping) -> TCustomMirroredMapping:
        """Return a new mapping of the same type with shallow-copied stores."""
        # ``type(self)`` keeps the concrete class for subclasses.
        duplicate = type(self)()
        duplicate._original = self._original.copy()
        duplicate._mirrored = self._mirrored.copy()
        return duplicate


class TestCustomVariableMapping:
    """Checks that a custom variable mapping can be used in place of a dict."""

    def test_instantiate(self):
        """The custom mapping can be built from a plain dict of variables."""
        contents = {
            "a": Variable(data=0, dims=()),
            "b": Variable(data=1, dims=()),
        }
        CustomMirroredMapping(contents)

    def test_construct_direct_dataset(self):
        """``Dataset._construct_direct`` gives equal results for a dict and for
        the custom mapping holding the same variables."""
        first = Variable(data=0, dims=())
        second = Variable(data=1, dims=())

        # Reference dataset backed by an ordinary dict.
        expected = Dataset._construct_direct(
            variables={"a": first, "b": second}, coord_names=set(), dims=None
        )

        # Same variables, stored in the custom mapping instead.
        cm: CustomMirroredMapping[Any, Variable] = CustomMirroredMapping(
            {"a": first, "b": second}
        )
        actual = Dataset._construct_direct(variables=cm, coord_names=set(), dims=None)

        assert_equal(actual, expected)

    def test_replace(self):
        """Placeholder: no assertions are made yet."""
        ...

    def test_mirror_new_variable(self):
        """Placeholder: no assertions are made yet."""
        ...