Skip to content

Commit a4c0210

Browse files
author
Joseph Hamman
committed
Merge branch 'master' of https://github.com/pydata/xarray into no_more_ordereddict
2 parents 5dc458f + 3f0049f commit a4c0210

File tree

6 files changed

+113
-68
lines changed

6 files changed

+113
-68
lines changed

doc/whats-new.rst

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,9 @@ Breaking changes
4242

4343
(:issue:`3222`, :issue:`3293`, :issue:`3340`, :issue:`3346`, :issue:`3358`).
4444
By `Guido Imperiale <https://github.com/crusaderky>`_.
45+
- Dropped the 'drop=False' optional parameter from :meth:`Variable.isel`.
46+
It was unused and doesn't make sense for a Variable.
47+
(:pull:`3375`) by `Guido Imperiale <https://github.com/crusaderky>`_.
4548

4649
- Remove internal usage of `collections.OrderedDict`. After dropping support for
4750
Python <=3.5, most uses of `OrderedDict` in Xarray were no longer necessary. We
@@ -57,14 +60,17 @@ New functions/methods
5760
Enhancements
5861
~~~~~~~~~~~~
5962

60-
- Add a repr for :py:class:`~xarray.core.GroupBy` objects (:issue:`3344`).
63+
- Add a repr for :py:class:`~xarray.core.GroupBy` objects.
6164
Example::
6265

6366
>>> da.groupby("time.season")
6467
DataArrayGroupBy, grouped over 'season'
6568
4 groups with labels 'DJF', 'JJA', 'MAM', 'SON'
6669

67-
By `Deepak Cherian <https://github.com/dcherian>`_.
70+
(:issue:`3344`) by `Deepak Cherian <https://github.com/dcherian>`_.
71+
- Speed up :meth:`Dataset.isel` by up to 33% and :meth:`DataArray.isel` by up to 25% for small
72+
arrays (:issue:`2799`, :pull:`3375`) by
73+
`Guido Imperiale <https://github.com/crusaderky>`_.
6874

6975
Bug fixes
7076
~~~~~~~~~
@@ -75,6 +81,8 @@ Bug fixes
7581
- Line plots with the ``x`` or ``y`` argument set to a 1D non-dimensional coord
7682
now plot the correct data for 2D DataArrays
7783
(:issue:`3334`). By `Tom Nicholas <https://github.com/TomNicholas>`_.
84+
- Fix error in concatenating unlabeled dimensions (:pull:`3362`).
85+
By `Deepak Cherian <https://github.com/dcherian/>`_.
7886

7987
Documentation
8088
~~~~~~~~~~~~~

xarray/core/concat.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -175,8 +175,6 @@ def _calc_concat_over(datasets, dim, dim_names, data_vars, coords, compat):
175175
if dim not in ds.dims:
176176
if dim in ds:
177177
ds = ds.set_coords(dim)
178-
else:
179-
raise ValueError("%r is not present in all datasets" % dim)
180178
concat_over.update(k for k, v in ds.variables.items() if dim in v.dims)
181179
concat_dim_lengths.append(ds.dims.get(dim, 1))
182180

@@ -360,12 +358,21 @@ def ensure_common_dims(vars):
360358
# n.b. this loop preserves variable order, needed for groupby.
361359
for k in datasets[0].variables:
362360
if k in concat_over:
363-
vars = ensure_common_dims([ds.variables[k] for ds in datasets])
361+
try:
362+
vars = ensure_common_dims([ds.variables[k] for ds in datasets])
363+
except KeyError:
364+
raise ValueError("%r is not present in all datasets." % k)
364365
combined = concat_vars(vars, dim, positions)
365366
assert isinstance(combined, Variable)
366367
result_vars[k] = combined
367368

368369
result = Dataset(result_vars, attrs=result_attrs)
370+
absent_coord_names = coord_names - set(result.variables)
371+
if absent_coord_names:
372+
raise ValueError(
373+
"Variables %r are coordinates in some datasets but not others."
374+
% absent_coord_names
375+
)
369376
result = result.set_coords(coord_names)
370377
result.encoding = result_encoding
371378

xarray/core/dataset.py

Lines changed: 59 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1737,8 +1737,8 @@ def maybe_chunk(name, var, chunks):
17371737
return self._replace(variables)
17381738

17391739
def _validate_indexers(
1740-
self, indexers: Mapping
1741-
) -> List[Tuple[Any, Union[slice, Variable]]]:
1740+
self, indexers: Mapping[Hashable, Any]
1741+
) -> Iterator[Tuple[Hashable, Union[int, slice, np.ndarray, Variable]]]:
17421742
""" Here we make sure
17431743
+ indexer has valid keys
17441744
+ indexer is in a valid data type
@@ -1747,50 +1747,61 @@ def _validate_indexers(
17471747
"""
17481748
from .dataarray import DataArray
17491749

1750-
invalid = [k for k in indexers if k not in self.dims]
1750+
invalid = indexers.keys() - self.dims.keys()
17511751
if invalid:
17521752
raise ValueError("dimensions %r do not exist" % invalid)
17531753

17541754
# all indexers should be int, slice, np.ndarrays, or Variable
1755-
indexers_list: List[Tuple[Any, Union[slice, Variable]]] = []
17561755
for k, v in indexers.items():
1757-
if isinstance(v, slice):
1758-
indexers_list.append((k, v))
1759-
continue
1760-
1761-
if isinstance(v, Variable):
1762-
pass
1756+
if isinstance(v, (int, slice, Variable)):
1757+
yield k, v
17631758
elif isinstance(v, DataArray):
1764-
v = v.variable
1759+
yield k, v.variable
17651760
elif isinstance(v, tuple):
1766-
v = as_variable(v)
1761+
yield k, as_variable(v)
17671762
elif isinstance(v, Dataset):
17681763
raise TypeError("cannot use a Dataset as an indexer")
17691764
elif isinstance(v, Sequence) and len(v) == 0:
1770-
v = Variable((k,), np.zeros((0,), dtype="int64"))
1765+
yield k, np.empty((0,), dtype="int64")
17711766
else:
17721767
v = np.asarray(v)
17731768

1774-
if v.dtype.kind == "U" or v.dtype.kind == "S":
1769+
if v.dtype.kind in "US":
17751770
index = self.indexes[k]
17761771
if isinstance(index, pd.DatetimeIndex):
17771772
v = v.astype("datetime64[ns]")
17781773
elif isinstance(index, xr.CFTimeIndex):
17791774
v = _parse_array_of_cftime_strings(v, index.date_type)
17801775

1781-
if v.ndim == 0:
1782-
v = Variable((), v)
1783-
elif v.ndim == 1:
1784-
v = Variable((k,), v)
1785-
else:
1776+
if v.ndim > 1:
17861777
raise IndexError(
17871778
"Unlabeled multi-dimensional array cannot be "
17881779
"used for indexing: {}".format(k)
17891780
)
1781+
yield k, v
17901782

1791-
indexers_list.append((k, v))
1792-
1793-
return indexers_list
1783+
def _validate_interp_indexers(
1784+
self, indexers: Mapping[Hashable, Any]
1785+
) -> Iterator[Tuple[Hashable, Variable]]:
1786+
"""Variant of _validate_indexers to be used for interpolation
1787+
"""
1788+
for k, v in self._validate_indexers(indexers):
1789+
if isinstance(v, Variable):
1790+
if v.ndim == 1:
1791+
yield k, v.to_index_variable()
1792+
else:
1793+
yield k, v
1794+
elif isinstance(v, int):
1795+
yield k, Variable((), v)
1796+
elif isinstance(v, np.ndarray):
1797+
if v.ndim == 0:
1798+
yield k, Variable((), v)
1799+
elif v.ndim == 1:
1800+
yield k, IndexVariable((k,), v)
1801+
else:
1802+
raise AssertionError() # Already tested by _validate_indexers
1803+
else:
1804+
raise TypeError(type(v))
17941805

17951806
def _get_indexers_coords_and_indexes(self, indexers):
17961807
"""Extract coordinates and indexes from indexers.
@@ -1875,10 +1886,10 @@ def isel(
18751886
Dataset.sel
18761887
DataArray.isel
18771888
"""
1878-
18791889
indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "isel")
1880-
1881-
indexers_list = self._validate_indexers(indexers)
1890+
# Note: we need to preserve the original indexers variable in order to merge the
1891+
# coords below
1892+
indexers_list = list(self._validate_indexers(indexers))
18821893

18831894
variables = {} # type: Dict[Hashable, Variable]
18841895
indexes = {} # type: Dict[Hashable, pd.Index]
@@ -1894,19 +1905,21 @@ def isel(
18941905
)
18951906
if new_index is not None:
18961907
indexes[name] = new_index
1897-
else:
1908+
elif var_indexers:
18981909
new_var = var.isel(indexers=var_indexers)
1910+
else:
1911+
new_var = var.copy(deep=False)
18991912

19001913
variables[name] = new_var
19011914

1902-
coord_names = set(variables).intersection(self._coord_names)
1915+
coord_names = self._coord_names & variables.keys()
19031916
selected = self._replace_with_new_dims(variables, coord_names, indexes)
19041917

19051918
# Extract coordinates from indexers
19061919
coord_vars, new_indexes = selected._get_indexers_coords_and_indexes(indexers)
19071920
variables.update(coord_vars)
19081921
indexes.update(new_indexes)
1909-
coord_names = set(variables).intersection(self._coord_names).union(coord_vars)
1922+
coord_names = self._coord_names & variables.keys() | coord_vars.keys()
19101923
return self._replace_with_new_dims(variables, coord_names, indexes=indexes)
19111924

19121925
def sel(
@@ -2468,11 +2481,9 @@ def interp(
24682481

24692482
if kwargs is None:
24702483
kwargs = {}
2484+
24712485
coords = either_dict_or_kwargs(coords, coords_kwargs, "interp")
2472-
indexers = {
2473-
k: v.to_index_variable() if isinstance(v, Variable) and v.ndim == 1 else v
2474-
for k, v in self._validate_indexers(coords)
2475-
}
2486+
indexers = dict(self._validate_interp_indexers(coords))
24762487

24772488
obj = self if assume_sorted else self.sortby([k for k in coords])
24782489

@@ -2497,26 +2508,25 @@ def _validate_interp_indexer(x, new_x):
24972508
"strings or datetimes. "
24982509
"Instead got\n{}".format(new_x)
24992510
)
2500-
else:
2501-
return (x, new_x)
2511+
return x, new_x
25022512

25032513
variables = {} # type: Dict[Hashable, Variable]
25042514
for name, var in obj._variables.items():
2505-
if name not in indexers:
2506-
if var.dtype.kind in "uifc":
2507-
var_indexers = {
2508-
k: _validate_interp_indexer(maybe_variable(obj, k), v)
2509-
for k, v in indexers.items()
2510-
if k in var.dims
2511-
}
2512-
variables[name] = missing.interp(
2513-
var, var_indexers, method, **kwargs
2514-
)
2515-
elif all(d not in indexers for d in var.dims):
2516-
# keep unrelated object array
2517-
variables[name] = var
2515+
if name in indexers:
2516+
continue
2517+
2518+
if var.dtype.kind in "uifc":
2519+
var_indexers = {
2520+
k: _validate_interp_indexer(maybe_variable(obj, k), v)
2521+
for k, v in indexers.items()
2522+
if k in var.dims
2523+
}
2524+
variables[name] = missing.interp(var, var_indexers, method, **kwargs)
2525+
elif all(d not in indexers for d in var.dims):
2526+
# keep unrelated object array
2527+
variables[name] = var
25182528

2519-
coord_names = set(variables).intersection(obj._coord_names)
2529+
coord_names = obj._coord_names & variables.keys()
25202530
indexes = {k: v for k, v in obj.indexes.items() if k not in indexers}
25212531
selected = self._replace_with_new_dims(
25222532
variables.copy(), coord_names, indexes=indexes
@@ -2534,7 +2544,7 @@ def _validate_interp_indexer(x, new_x):
25342544
variables.update(coord_vars)
25352545
indexes.update(new_indexes)
25362546

2537-
coord_names = set(variables).intersection(obj._coord_names).union(coord_vars)
2547+
coord_names = obj._coord_names & variables.keys() | coord_vars.keys()
25382548
return self._replace_with_new_dims(variables, coord_names, indexes=indexes)
25392549

25402550
def interp_like(

xarray/core/indexes.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import collections.abc
22
from typing import Any, Hashable, Iterable, Mapping, Optional, Tuple, Union
33

4+
import numpy as np
45
import pandas as pd
56

67
from . import formatting
@@ -62,7 +63,7 @@ def isel_variable_and_index(
6263
name: Hashable,
6364
variable: Variable,
6465
index: pd.Index,
65-
indexers: Mapping[Any, Union[slice, Variable]],
66+
indexers: Mapping[Hashable, Union[int, slice, np.ndarray, Variable]],
6667
) -> Tuple[Variable, Optional[pd.Index]]:
6768
"""Index a Variable and pandas.Index together."""
6869
if not indexers:

xarray/core/variable.py

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from collections import defaultdict
44
from datetime import timedelta
55
from distutils.version import LooseVersion
6-
from typing import Any, Hashable, Mapping, Union
6+
from typing import Any, Hashable, Mapping, Union, TypeVar
77

88
import numpy as np
99
import pandas as pd
@@ -41,6 +41,18 @@
4141
# https://github.com/python/mypy/issues/224
4242
BASIC_INDEXING_TYPES = integer_types + (slice,) # type: ignore
4343

44+
VariableType = TypeVar("VariableType", bound="Variable")
45+
"""Type annotation to be used when methods of Variable return self or a copy of self.
46+
When called from an instance of a subclass, e.g. IndexVariable, mypy identifies the
47+
output as an instance of the subclass.
48+
49+
Usage::
50+
51+
class Variable:
52+
def f(self: VariableType, ...) -> VariableType:
53+
...
54+
"""
55+
4456

4557
class MissingDimensionsError(ValueError):
4658
"""Error class used when we can't safely guess a dimension name.
@@ -663,8 +675,8 @@ def _broadcast_indexes_vectorized(self, key):
663675

664676
return out_dims, VectorizedIndexer(tuple(out_key)), new_order
665677

666-
def __getitem__(self, key):
667-
"""Return a new Array object whose contents are consistent with
678+
def __getitem__(self: VariableType, key) -> VariableType:
679+
"""Return a new Variable object whose contents are consistent with
668680
getting the provided key from the underlying data.
669681
670682
NB. __getitem__ and __setitem__ implement xarray-style indexing,
@@ -682,7 +694,7 @@ def __getitem__(self, key):
682694
data = duck_array_ops.moveaxis(data, range(len(new_order)), new_order)
683695
return self._finalize_indexing_result(dims, data)
684696

685-
def _finalize_indexing_result(self, dims, data):
697+
def _finalize_indexing_result(self: VariableType, dims, data) -> VariableType:
686698
"""Used by IndexVariable to return IndexVariable objects when possible.
687699
"""
688700
return type(self)(dims, data, self._attrs, self._encoding, fastpath=True)
@@ -957,7 +969,11 @@ def chunk(self, chunks=None, name=None, lock=False):
957969

958970
return type(self)(self.dims, data, self._attrs, self._encoding, fastpath=True)
959971

960-
def isel(self, indexers=None, drop=False, **indexers_kwargs):
972+
def isel(
973+
self: VariableType,
974+
indexers: Mapping[Hashable, Any] = None,
975+
**indexers_kwargs: Any
976+
) -> VariableType:
961977
"""Return a new array indexed along the specified dimension(s).
962978
963979
Parameters
@@ -976,15 +992,12 @@ def isel(self, indexers=None, drop=False, **indexers_kwargs):
976992
"""
977993
indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "isel")
978994

979-
invalid = [k for k in indexers if k not in self.dims]
995+
invalid = indexers.keys() - set(self.dims)
980996
if invalid:
981997
raise ValueError("dimensions %r do not exist" % invalid)
982998

983-
key = [slice(None)] * self.ndim
984-
for i, dim in enumerate(self.dims):
985-
if dim in indexers:
986-
key[i] = indexers[dim]
987-
return self[tuple(key)]
999+
key = tuple(indexers.get(dim, slice(None)) for dim in self.dims)
1000+
return self[key]
9881001

9891002
def squeeze(self, dim=None):
9901003
"""Return a new object with squeezed data.

xarray/tests/test_concat.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,10 @@ def test_concat_compat():
4242
for var in ["has_x", "no_x_y"]:
4343
assert "y" not in result[var]
4444

45+
with raises_regex(ValueError, "coordinates in some datasets but not others"):
46+
concat([ds1, ds2], dim="q")
4547
with raises_regex(ValueError, "'q' is not present in all datasets"):
46-
concat([ds1, ds2], dim="q", data_vars="all", compat="broadcast_equals")
48+
concat([ds2, ds1], dim="q")
4749

4850

4951
class TestConcatDataset:
@@ -90,7 +92,11 @@ def test_concat_coords_kwarg(self, data, dim, coords):
9092
assert_equal(data["extra"], actual["extra"])
9193

9294
def test_concat(self, data):
93-
split_data = [data.isel(dim1=slice(3)), data.isel(dim1=slice(3, None))]
95+
split_data = [
96+
data.isel(dim1=slice(3)),
97+
data.isel(dim1=3),
98+
data.isel(dim1=slice(4, None)),
99+
]
94100
assert_identical(data, concat(split_data, "dim1"))
95101

96102
def test_concat_dim_precedence(self, data):

0 commit comments

Comments
 (0)