Commit 5a602cd

Replace SortedKeysDict with dict (#4753)
* Remove SortedKeysDict
* Retain sorted reprs
* Update tests
* _
* Removing sorting in repr
* whatsnew
* Reset snapshot tests
* Fix unstable ordering
* Existing bug?
* Overwrite doctest test with combine_by_coords bug
* Revert adding override to join docstring
* Ordered not sorted!
1 parent 6d2a730 commit 5a602cd

7 files changed (+71, -104 lines changed)

doc/whats-new.rst

Lines changed: 2 additions & 0 deletions
@@ -100,6 +100,8 @@ Internal Changes
   ``PandasIndexAdapter`` to ``PandasIndex``, which now inherits from
   ``xarray.Index`` (:pull:`5102`).
   By `Benoit Bovy <https://github.com/benbovy>`_.
+- Replace ``SortedKeysDict`` with python's ``dict``, given dicts are now ordered.
+  By `Maximilian Roos <https://github.com/max-sixty>`_.
 - Updated the release guide for developers. Now accounts for actions that are automated via github
   actions. (:pull:`5274`).
   By `Tom Nicholas <https://github.com/TomNicholas>`_.
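
Background for the new entry, as a minimal illustrative sketch (not part of the diff): since Python 3.7, ``dict`` is guaranteed to preserve insertion order, so a sorted-keys wrapper is no longer needed for deterministic iteration. The user-visible difference is that mappings now iterate in insertion order rather than alphabetical order, which is why the doctest reprs below change from ``(x: 3, y: 2)`` to ``(y: 2, x: 3)``.

>>> dims = {"y": 2, "x": 3}  # illustrative values, matching the x1 dataset below
>>> list(dims)  # plain dict: insertion order (guaranteed since Python 3.7)
['y', 'x']
>>> sorted(dims)  # the order SortedKeysDict used to impose
['x', 'y']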

xarray/core/combine.py

Lines changed: 12 additions & 13 deletions
@@ -675,7 +675,6 @@ def combine_by_coords(
 they are concatenated based on the values in their dimension coordinates,
 not on their position in the list passed to `combine_by_coords`.
 
-
 >>> x1 = xr.Dataset(
 ...     {
 ...         "temperature": (("y", "x"), 20 * np.random.rand(6).reshape(2, 3)),

@@ -700,7 +699,7 @@ def combine_by_coords(
 
 >>> x1
 <xarray.Dataset>
-Dimensions: (x: 3, y: 2)
+Dimensions: (y: 2, x: 3)
 Coordinates:
   * y (y) int64 0 1
   * x (x) int64 10 20 30

@@ -710,7 +709,7 @@ def combine_by_coords(
 
 >>> x2
 <xarray.Dataset>
-Dimensions: (x: 3, y: 2)
+Dimensions: (y: 2, x: 3)
 Coordinates:
   * y (y) int64 2 3
   * x (x) int64 10 20 30

@@ -720,7 +719,7 @@ def combine_by_coords(
 
 >>> x3
 <xarray.Dataset>
-Dimensions: (x: 3, y: 2)
+Dimensions: (y: 2, x: 3)
 Coordinates:
   * y (y) int64 2 3
   * x (x) int64 40 50 60

@@ -730,7 +729,7 @@ def combine_by_coords(
 
 >>> xr.combine_by_coords([x2, x1])
 <xarray.Dataset>
-Dimensions: (x: 3, y: 4)
+Dimensions: (y: 4, x: 3)
 Coordinates:
   * y (y) int64 0 1 2 3
   * x (x) int64 10 20 30

@@ -740,30 +739,30 @@ def combine_by_coords(
 
 >>> xr.combine_by_coords([x3, x1])
 <xarray.Dataset>
-Dimensions: (x: 6, y: 4)
+Dimensions: (y: 4, x: 6)
 Coordinates:
-  * x (x) int64 10 20 30 40 50 60
   * y (y) int64 0 1 2 3
+  * x (x) int64 10 20 30 40 50 60
 Data variables:
     temperature (y, x) float64 10.98 14.3 12.06 nan ... nan 18.89 10.44 8.293
     precipitation (y, x) float64 0.4376 0.8918 0.9637 ... 0.5684 0.01879 0.6176
 
 >>> xr.combine_by_coords([x3, x1], join="override")
 <xarray.Dataset>
-Dimensions: (x: 3, y: 4)
+Dimensions: (y: 2, x: 6)
 Coordinates:
-  * x (x) int64 10 20 30
-  * y (y) int64 0 1 2 3
+  * y (y) int64 0 1
+  * x (x) int64 10 20 30 40 50 60
 Data variables:
-    temperature (y, x) float64 10.98 14.3 12.06 10.9 ... 18.89 10.44 8.293
+    temperature (y, x) float64 10.98 14.3 12.06 2.365 ... 18.89 10.44 8.293
     precipitation (y, x) float64 0.4376 0.8918 0.9637 ... 0.5684 0.01879 0.6176
 
 >>> xr.combine_by_coords([x1, x2, x3])
 <xarray.Dataset>
-Dimensions: (x: 6, y: 4)
+Dimensions: (y: 4, x: 6)
 Coordinates:
-  * x (x) int64 10 20 30 40 50 60
   * y (y) int64 0 1 2 3
+  * x (x) int64 10 20 30 40 50 60
 Data variables:
     temperature (y, x) float64 10.98 14.3 12.06 nan ... 18.89 10.44 8.293
     precipitation (y, x) float64 0.4376 0.8918 0.9637 ... 0.5684 0.01879 0.6176

xarray/core/dataset.py

Lines changed: 24 additions & 20 deletions
@@ -83,7 +83,7 @@
     Default,
     Frozen,
     HybridMappingProxy,
-    SortedKeysDict,
+    OrderedSet,
     _default,
     decode_numpy_dict_values,
     drop_dims_from_indexers,

@@ -654,7 +654,7 @@ class Dataset(DataWithCoords, DatasetArithmetic, Mapping):
 ...     )
 >>> ds
 <xarray.Dataset>
-Dimensions: (time: 3, x: 2, y: 2)
+Dimensions: (x: 2, y: 2, time: 3)
 Coordinates:
     lon (x, y) float64 -99.83 -99.32 -99.79 -99.23
     lat (x, y) float64 42.25 42.21 42.63 42.59

@@ -803,7 +803,7 @@ def dims(self) -> Mapping[Hashable, int]:
         See `Dataset.sizes` and `DataArray.sizes` for consistently named
         properties.
         """
-        return Frozen(SortedKeysDict(self._dims))
+        return Frozen(self._dims)
 
     @property
     def sizes(self) -> Mapping[Hashable, int]:

@@ -1344,7 +1344,7 @@ def _copy_listed(self, names: Iterable[Hashable]) -> "Dataset":
             if (var_name,) == var.dims:
                 indexes[var_name] = var._to_xindex()
 
-        needed_dims: Set[Hashable] = set()
+        needed_dims: OrderedSet[Hashable] = OrderedSet()
         for v in variables.values():
             needed_dims.update(v.dims)
 

@@ -1992,7 +1992,7 @@ def chunks(self) -> Mapping[Hashable, Tuple[int, ...]]:
                     "This can be fixed by calling unify_chunks()."
                 )
             chunks[dim] = c
-        return Frozen(SortedKeysDict(chunks))
+        return Frozen(chunks)
 
     def chunk(
         self,

@@ -6384,30 +6384,32 @@ def filter_by_attrs(self, **kwargs):
 
 Examples
 --------
->>> # Create an example dataset:
 >>> temp = 15 + 8 * np.random.randn(2, 2, 3)
 >>> precip = 10 * np.random.rand(2, 2, 3)
 >>> lon = [[-99.83, -99.32], [-99.79, -99.23]]
 >>> lat = [[42.25, 42.21], [42.63, 42.59]]
 >>> dims = ["x", "y", "time"]
 >>> temp_attr = dict(standard_name="air_potential_temperature")
 >>> precip_attr = dict(standard_name="convective_precipitation_flux")
+
 >>> ds = xr.Dataset(
-...     {
-...         "temperature": (dims, temp, temp_attr),
-...         "precipitation": (dims, precip, precip_attr),
-...     },
-...     coords={
-...         "lon": (["x", "y"], lon),
-...         "lat": (["x", "y"], lat),
-...         "time": pd.date_range("2014-09-06", periods=3),
-...         "reference_time": pd.Timestamp("2014-09-05"),
-...     },
+...     dict(
+...         temperature=(dims, temp, temp_attr),
+...         precipitation=(dims, precip, precip_attr),
+...     ),
+...     coords=dict(
+...         lon=(["x", "y"], lon),
+...         lat=(["x", "y"], lat),
+...         time=pd.date_range("2014-09-06", periods=3),
+...         reference_time=pd.Timestamp("2014-09-05"),
+...     ),
 ...     )
->>> # Get variables matching a specific standard_name.
+
+Get variables matching a specific standard_name:
+
 >>> ds.filter_by_attrs(standard_name="convective_precipitation_flux")
 <xarray.Dataset>
-Dimensions: (time: 3, x: 2, y: 2)
+Dimensions: (x: 2, y: 2, time: 3)
 Coordinates:
     lon (x, y) float64 -99.83 -99.32 -99.79 -99.23
     lat (x, y) float64 42.25 42.21 42.63 42.59

@@ -6416,11 +6418,13 @@ def filter_by_attrs(self, **kwargs):
 Dimensions without coordinates: x, y
 Data variables:
     precipitation (x, y, time) float64 5.68 9.256 0.7104 ... 7.992 4.615 7.805
->>> # Get all variables that have a standard_name attribute.
+
+Get all variables that have a standard_name attribute:
+
 >>> standard_name = lambda v: v is not None
 >>> ds.filter_by_attrs(standard_name=standard_name)
 <xarray.Dataset>
-Dimensions: (time: 3, x: 2, y: 2)
+Dimensions: (x: 2, y: 2, time: 3)
 Coordinates:
     lon (x, y) float64 -99.83 -99.32 -99.79 -99.23
     lat (x, y) float64 42.25 42.21 42.63 42.59
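
Two things are worth noting about the dataset.py changes, with a small illustrative sketch below (the dataset and values here are hypothetical, not taken from the diff): ``Dataset.dims`` and ``Dataset.chunks`` still return a read-only ``Frozen`` mapping, only no longer wrapped to sort keys, so dimensions now appear in the order they were added; and ``_copy_listed`` switches from ``set()`` to xarray's internal ``OrderedSet`` so the collected dimensions keep a deterministic, insertion-based order now that ordering is user-visible.

>>> import numpy as np
>>> import xarray as xr
>>> ds = xr.Dataset({"t": (("y", "x"), np.zeros((2, 3)))})
>>> list(ds.dims)  # insertion order ("y" first); previously sorted to ['x', 'y']
['y', 'x']
>>> from xarray.core.utils import OrderedSet  # internal utility, visible in the diff context
>>> needed_dims = OrderedSet()
>>> for dims in [("y", "x"), ("x",)]:
...     needed_dims.update(dims)
>>> list(needed_dims)  # insertion order, duplicates ignored
['y', 'x']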

xarray/core/groupby.py

Lines changed: 1 addition & 1 deletion
@@ -632,7 +632,7 @@ def quantile(
   * x (x) int64 0 1
 >>> ds.groupby("y").quantile([0, 0.5, 1], dim=...)
 <xarray.Dataset>
-Dimensions: (quantile: 3, y: 2)
+Dimensions: (y: 2, quantile: 3)
 Coordinates:
   * quantile (quantile) float64 0.0 0.5 1.0
   * y (y) int64 1 2

xarray/core/utils.py

Lines changed: 0 additions & 34 deletions
@@ -475,40 +475,6 @@ def __len__(self) -> int:
         return len(self._keys)
 
 
-class SortedKeysDict(MutableMapping[K, V]):
-    """An wrapper for dictionary-like objects that always iterates over its
-    items in sorted order by key but is otherwise equivalent to the underlying
-    mapping.
-    """
-
-    __slots__ = ("mapping",)
-
-    def __init__(self, mapping: MutableMapping[K, V] = None):
-        self.mapping = {} if mapping is None else mapping
-
-    def __getitem__(self, key: K) -> V:
-        return self.mapping[key]
-
-    def __setitem__(self, key: K, value: V) -> None:
-        self.mapping[key] = value
-
-    def __delitem__(self, key: K) -> None:
-        del self.mapping[key]
-
-    def __iter__(self) -> Iterator[K]:
-        # see #4571 for the reason of the type ignore
-        return iter(sorted(self.mapping))  # type: ignore[type-var]
-
-    def __len__(self) -> int:
-        return len(self.mapping)
-
-    def __contains__(self, key: object) -> bool:
-        return key in self.mapping
-
-    def __repr__(self) -> str:
-        return "{}({!r})".format(type(self).__name__, self.mapping)
-
-
 class OrderedSet(MutableSet[T]):
     """A simple ordered set.
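
For context on the snapshot-test updates that follow, a small sketch (illustrative only, using the dims that appear in create_test_data): the removed wrapper's ``__iter__`` was equivalent to calling ``sorted()`` on a plain dict, so any remaining caller that genuinely wants alphabetical keys can simply sort at the call site, while everything else now sees insertion order.

>>> mapping = {"dim2": 9, "dim3": 10, "time": 20, "dim1": 8}
>>> list(mapping)  # plain dict: insertion order, as the updated tests expect
['dim2', 'dim3', 'time', 'dim1']
>>> sorted(mapping)  # what SortedKeysDict.__iter__ produced
['dim1', 'dim2', 'dim3', 'time']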

xarray/tests/test_dataset.py

Lines changed: 32 additions & 30 deletions
@@ -72,9 +72,9 @@ def create_test_data(seed=None, add_attrs=True):
     _dims = {"dim1": 8, "dim2": 9, "dim3": 10}
 
     obj = Dataset()
-    obj["time"] = ("time", pd.date_range("2000-01-01", periods=20))
     obj["dim2"] = ("dim2", 0.5 * np.arange(_dims["dim2"]))
     obj["dim3"] = ("dim3", list("abcdefghij"))
+    obj["time"] = ("time", pd.date_range("2000-01-01", periods=20))
     for v, dims in sorted(_vars.items()):
         data = rs.normal(size=tuple(_dims[d] for d in dims))
         obj[v] = (dims, data)

@@ -205,11 +205,11 @@ def test_repr(self):
         expected = dedent(
             """\
             <xarray.Dataset>
-            Dimensions: (dim1: 8, dim2: 9, dim3: 10, time: 20)
+            Dimensions: (dim2: 9, dim3: 10, time: 20, dim1: 8)
             Coordinates:
-              * time (time) datetime64[ns] 2000-01-01 2000-01-02 ... 2000-01-20
               * dim2 (dim2) float64 0.0 0.5 1.0 1.5 2.0 2.5 3.0 3.5 4.0
               * dim3 (dim3) %s 'a' 'b' 'c' 'd' 'e' 'f' 'g' 'h' 'i' 'j'
+              * time (time) datetime64[ns] 2000-01-01 2000-01-02 ... 2000-01-20
               numbers (dim3) int64 0 1 2 0 0 1 1 2 2 3
             Dimensions without coordinates: dim1
             Data variables:

@@ -357,14 +357,14 @@ def test_info(self):
             """\
             xarray.Dataset {
             dimensions:
-            \tdim1 = 8 ;
             \tdim2 = 9 ;
-            \tdim3 = 10 ;
             \ttime = 20 ;
+            \tdim1 = 8 ;
+            \tdim3 = 10 ;
 
             variables:
-            \tdatetime64[ns] time(time) ;
             \tfloat64 dim2(dim2) ;
+            \tdatetime64[ns] time(time) ;
             \tfloat64 var1(dim1, dim2) ;
             \t\tvar1:foo = variable ;
             \tfloat64 var2(dim1, dim2) ;

@@ -560,14 +560,13 @@ def test_constructor_with_coords(self):
     def test_properties(self):
         ds = create_test_data()
         assert ds.dims == {"dim1": 8, "dim2": 9, "dim3": 10, "time": 20}
-        assert list(ds.dims) == sorted(ds.dims)
         assert ds.sizes == ds.dims
 
         # These exact types aren't public API, but this makes sure we don't
         # change them inadvertently:
         assert isinstance(ds.dims, utils.Frozen)
-        assert isinstance(ds.dims.mapping, utils.SortedKeysDict)
-        assert type(ds.dims.mapping.mapping) is dict
+        assert isinstance(ds.dims.mapping, dict)
+        assert type(ds.dims.mapping) is dict
 
         assert list(ds) == list(ds.data_vars)
         assert list(ds.keys()) == list(ds.data_vars)

@@ -593,7 +592,7 @@ def test_properties(self):
         assert "dim2" in repr(ds.indexes)
         assert all([isinstance(idx, pd.Index) for idx in ds.indexes.values()])
 
-        assert list(ds.coords) == ["time", "dim2", "dim3", "numbers"]
+        assert list(ds.coords) == ["dim2", "dim3", "time", "numbers"]
         assert "dim2" in ds.coords
         assert "numbers" in ds.coords
         assert "var1" not in ds.coords

@@ -1035,14 +1034,14 @@ def test_isel(self):
             ValueError,
             match=r"Dimensions {'not_a_dim'} do not exist. Expected "
             r"one or more of "
-            r"[\w\W]*'time'[\w\W]*'dim\d'[\w\W]*'dim\d'[\w\W]*'dim\d'[\w\W]*",
+            r"[\w\W]*'dim\d'[\w\W]*'dim\d'[\w\W]*'time'[\w\W]*'dim\d'[\w\W]*",
         ):
             data.isel(not_a_dim=slice(0, 2))
         with pytest.warns(
             UserWarning,
             match=r"Dimensions {'not_a_dim'} do not exist. "
             r"Expected one or more of "
-            r"[\w\W]*'time'[\w\W]*'dim\d'[\w\W]*'dim\d'[\w\W]*'dim\d'[\w\W]*",
+            r"[\w\W]*'dim\d'[\w\W]*'dim\d'[\w\W]*'time'[\w\W]*'dim\d'[\w\W]*",
         ):
             data.isel(not_a_dim=slice(0, 2), missing_dims="warn")
         assert_identical(data, data.isel(not_a_dim=slice(0, 2), missing_dims="ignore"))

@@ -4823,10 +4822,10 @@ def test_reduce(self):
         assert_equal(data.min(dim=["dim1"]), data.min(dim="dim1"))
 
         for reduct, expected in [
-            ("dim2", ["dim1", "dim3", "time"]),
-            (["dim2", "time"], ["dim1", "dim3"]),
-            (("dim2", "time"), ["dim1", "dim3"]),
-            ((), ["dim1", "dim2", "dim3", "time"]),
+            ("dim2", ["dim3", "time", "dim1"]),
+            (["dim2", "time"], ["dim3", "dim1"]),
+            (("dim2", "time"), ["dim3", "dim1"]),
+            ((), ["dim2", "dim3", "time", "dim1"]),
         ]:
             actual = list(data.min(dim=reduct).dims)
             assert actual == expected

@@ -4876,21 +4875,24 @@ def test_reduce_cumsum(self):
         )
         assert_identical(expected, data.cumsum())
 
-    def test_reduce_cumsum_test_dims(self):
+    @pytest.mark.parametrize(
+        "reduct, expected",
+        [
+            ("dim1", ["dim2", "dim3", "time", "dim1"]),
+            ("dim2", ["dim3", "time", "dim1", "dim2"]),
+            ("dim3", ["dim2", "time", "dim1", "dim3"]),
+            ("time", ["dim2", "dim3", "dim1"]),
+        ],
+    )
+    @pytest.mark.parametrize("func", ["cumsum", "cumprod"])
+    def test_reduce_cumsum_test_dims(self, reduct, expected, func):
         data = create_test_data()
-        for cumfunc in ["cumsum", "cumprod"]:
-            with pytest.raises(ValueError, match=r"Dataset does not contain"):
-                getattr(data, cumfunc)(dim="bad_dim")
-
-            # ensure dimensions are correct
-            for reduct, expected in [
-                ("dim1", ["dim1", "dim2", "dim3", "time"]),
-                ("dim2", ["dim1", "dim2", "dim3", "time"]),
-                ("dim3", ["dim1", "dim2", "dim3", "time"]),
-                ("time", ["dim1", "dim2", "dim3"]),
-            ]:
-                actual = getattr(data, cumfunc)(dim=reduct).dims
-                assert list(actual) == expected
+        with pytest.raises(ValueError, match=r"Dataset does not contain"):
+            getattr(data, func)(dim="bad_dim")
+
+        # ensure dimensions are correct
+        actual = getattr(data, func)(dim=reduct).dims
+        assert list(actual) == expected
 
     def test_reduce_non_numeric(self):
         data1 = create_test_data(seed=44)
data1 = create_test_data(seed=44)
