Skip to content

Speed up isel and __getitem__ #3375

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Oct 9, 2019
Merged
10 changes: 8 additions & 2 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
@@ -42,21 +42,27 @@ Breaking changes

(:issue:`3222`, :issue:`3293`, :issue:`3340`, :issue:`3346`, :issue:`3358`).
By `Guido Imperiale <https://github.com/crusaderky>`_.
- Dropped the 'drop=False' optional parameter from :meth:`Variable.isel`.
It was unused and doesn't make sense for a Variable.
(:pull:`3375`) by `Guido Imperiale <https://github.com/crusaderky>`_.

New functions/methods
~~~~~~~~~~~~~~~~~~~~~

Enhancements
~~~~~~~~~~~~

- Add a repr for :py:class:`~xarray.core.GroupBy` objects (:issue:`3344`).
- Add a repr for :py:class:`~xarray.core.GroupBy` objects.
Example::

>>> da.groupby("time.season")
DataArrayGroupBy, grouped over 'season'
4 groups with labels 'DJF', 'JJA', 'MAM', 'SON'

By `Deepak Cherian <https://github.com/dcherian>`_.
(:issue:`3344`) by `Deepak Cherian <https://github.com/dcherian>`_.
- Speed up :meth:`Dataset.isel` by up to 33% and :meth:`DataArray.isel` by up to 25%
for small arrays (:issue:`2799`, :pull:`3375`) by
`Guido Imperiale <https://github.com/crusaderky>`_.

Bug fixes
~~~~~~~~~
108 changes: 59 additions & 49 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
@@ -1745,8 +1745,8 @@ def maybe_chunk(name, var, chunks):
return self._replace(variables)

def _validate_indexers(
self, indexers: Mapping
) -> List[Tuple[Any, Union[slice, Variable]]]:
self, indexers: Mapping[Hashable, Any]
) -> Iterator[Tuple[Hashable, Union[int, slice, np.ndarray, Variable]]]:
""" Here we make sure
+ indexer has valid keys
+ indexer is of a valid data type
@@ -1755,50 +1755,61 @@ def _validate_indexers(
"""
from .dataarray import DataArray

invalid = [k for k in indexers if k not in self.dims]
invalid = indexers.keys() - self.dims.keys()
if invalid:
raise ValueError("dimensions %r do not exist" % invalid)

# all indexers should be int, slice, np.ndarrays, or Variable
indexers_list: List[Tuple[Any, Union[slice, Variable]]] = []
for k, v in indexers.items():
if isinstance(v, slice):
indexers_list.append((k, v))
continue

if isinstance(v, Variable):
pass
if isinstance(v, (int, slice, Variable)):
yield k, v
elif isinstance(v, DataArray):
v = v.variable
yield k, v.variable
elif isinstance(v, tuple):
v = as_variable(v)
yield k, as_variable(v)
elif isinstance(v, Dataset):
raise TypeError("cannot use a Dataset as an indexer")
elif isinstance(v, Sequence) and len(v) == 0:
v = Variable((k,), np.zeros((0,), dtype="int64"))
yield k, np.empty((0,), dtype="int64")
else:
v = np.asarray(v)

if v.dtype.kind == "U" or v.dtype.kind == "S":
if v.dtype.kind in "US":
index = self.indexes[k]
if isinstance(index, pd.DatetimeIndex):
v = v.astype("datetime64[ns]")
elif isinstance(index, xr.CFTimeIndex):
v = _parse_array_of_cftime_strings(v, index.date_type)

if v.ndim == 0:
v = Variable((), v)
elif v.ndim == 1:
v = Variable((k,), v)
else:
if v.ndim > 1:
raise IndexError(
"Unlabeled multi-dimensional array cannot be "
"used for indexing: {}".format(k)
)
yield k, v

indexers_list.append((k, v))

return indexers_list
def _validate_interp_indexers(
self, indexers: Mapping[Hashable, Any]
) -> Iterator[Tuple[Hashable, Variable]]:
"""Variant of _validate_indexers to be used for interpolation
"""
for k, v in self._validate_indexers(indexers):
if isinstance(v, Variable):
if v.ndim == 1:
yield k, v.to_index_variable()
else:
yield k, v
elif isinstance(v, int):
yield k, Variable((), v)
elif isinstance(v, np.ndarray):
if v.ndim == 0:
yield k, Variable((), v)
elif v.ndim == 1:
yield k, IndexVariable((k,), v)
else:
raise AssertionError() # Already tested by _validate_indexers
else:
raise TypeError(type(v))

def _get_indexers_coords_and_indexes(self, indexers):
"""Extract coordinates and indexes from indexers.
@@ -1885,10 +1896,10 @@ def isel(
Dataset.sel
DataArray.isel
"""

indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "isel")

indexers_list = self._validate_indexers(indexers)
# Note: we need to preserve the original indexers variable in order to merge the
# coords below
indexers_list = list(self._validate_indexers(indexers))

variables = OrderedDict() # type: OrderedDict[Hashable, Variable]
indexes = OrderedDict() # type: OrderedDict[Hashable, pd.Index]
@@ -1904,19 +1915,21 @@ def isel(
)
if new_index is not None:
indexes[name] = new_index
else:
elif var_indexers:
new_var = var.isel(indexers=var_indexers)
else:
new_var = var.copy(deep=False)

variables[name] = new_var

coord_names = set(variables).intersection(self._coord_names)
coord_names = self._coord_names & variables.keys()
selected = self._replace_with_new_dims(variables, coord_names, indexes)

# Extract coordinates from indexers
coord_vars, new_indexes = selected._get_indexers_coords_and_indexes(indexers)
variables.update(coord_vars)
indexes.update(new_indexes)
coord_names = set(variables).intersection(self._coord_names).union(coord_vars)
coord_names = self._coord_names & variables.keys() | coord_vars.keys()
return self._replace_with_new_dims(variables, coord_names, indexes=indexes)

def sel(
@@ -2478,11 +2491,9 @@ def interp(

if kwargs is None:
kwargs = {}

coords = either_dict_or_kwargs(coords, coords_kwargs, "interp")
indexers = OrderedDict(
(k, v.to_index_variable() if isinstance(v, Variable) and v.ndim == 1 else v)
for k, v in self._validate_indexers(coords)
)
indexers = OrderedDict(self._validate_interp_indexers(coords))

obj = self if assume_sorted else self.sortby([k for k in coords])

@@ -2507,26 +2518,25 @@ def _validate_interp_indexer(x, new_x):
"strings or datetimes. "
"Instead got\n{}".format(new_x)
)
else:
return (x, new_x)
return x, new_x

variables = OrderedDict() # type: OrderedDict[Hashable, Variable]
for name, var in obj._variables.items():
if name not in indexers:
if var.dtype.kind in "uifc":
var_indexers = {
k: _validate_interp_indexer(maybe_variable(obj, k), v)
for k, v in indexers.items()
if k in var.dims
}
variables[name] = missing.interp(
var, var_indexers, method, **kwargs
)
elif all(d not in indexers for d in var.dims):
# keep unrelated object array
variables[name] = var
if name in indexers:
continue

if var.dtype.kind in "uifc":
var_indexers = {
k: _validate_interp_indexer(maybe_variable(obj, k), v)
for k, v in indexers.items()
if k in var.dims
}
variables[name] = missing.interp(var, var_indexers, method, **kwargs)
elif all(d not in indexers for d in var.dims):
# keep unrelated object array
variables[name] = var

coord_names = set(variables).intersection(obj._coord_names)
coord_names = obj._coord_names & variables.keys()
indexes = OrderedDict(
(k, v) for k, v in obj.indexes.items() if k not in indexers
)
@@ -2546,7 +2556,7 @@ def _validate_interp_indexer(x, new_x):
variables.update(coord_vars)
indexes.update(new_indexes)

coord_names = set(variables).intersection(obj._coord_names).union(coord_vars)
coord_names = obj._coord_names & variables.keys() | coord_vars.keys()
return self._replace_with_new_dims(variables, coord_names, indexes=indexes)

def interp_like(
3 changes: 2 additions & 1 deletion xarray/core/indexes.py
Original file line number Diff line number Diff line change
@@ -2,6 +2,7 @@
from collections import OrderedDict
from typing import Any, Hashable, Iterable, Mapping, Optional, Tuple, Union

import numpy as np
import pandas as pd

from . import formatting
@@ -63,7 +64,7 @@ def isel_variable_and_index(
name: Hashable,
variable: Variable,
index: pd.Index,
indexers: Mapping[Any, Union[slice, Variable]],
indexers: Mapping[Hashable, Union[int, slice, np.ndarray, Variable]],
) -> Tuple[Variable, Optional[pd.Index]]:
"""Index a Variable and pandas.Index together."""
if not indexers:
35 changes: 24 additions & 11 deletions xarray/core/variable.py
Original file line number Diff line number Diff line change
@@ -3,7 +3,7 @@
from collections import OrderedDict, defaultdict
from datetime import timedelta
from distutils.version import LooseVersion
from typing import Any, Hashable, Mapping, Union
from typing import Any, Hashable, Mapping, Union, TypeVar

import numpy as np
import pandas as pd
@@ -41,6 +41,18 @@
# https://github.com/python/mypy/issues/224
BASIC_INDEXING_TYPES = integer_types + (slice,) # type: ignore

VariableType = TypeVar("VariableType", bound="Variable")
"""Type annotation to be used when methods of Variable return self or a copy of self.
When called from an instance of a subclass, e.g. IndexVariable, mypy identifies the
output as an instance of the subclass.

Usage::

class Variable:
def f(self: VariableType, ...) -> VariableType:
...
"""


class MissingDimensionsError(ValueError):
"""Error class used when we can't safely guess a dimension name.
@@ -663,8 +675,8 @@ def _broadcast_indexes_vectorized(self, key):

return out_dims, VectorizedIndexer(tuple(out_key)), new_order

def __getitem__(self, key):
"""Return a new Array object whose contents are consistent with
def __getitem__(self: VariableType, key) -> VariableType:
"""Return a new Variable object whose contents are consistent with
getting the provided key from the underlying data.

NB. __getitem__ and __setitem__ implement xarray-style indexing,
@@ -682,7 +694,7 @@ def __getitem__(self, key):
data = duck_array_ops.moveaxis(data, range(len(new_order)), new_order)
return self._finalize_indexing_result(dims, data)

def _finalize_indexing_result(self, dims, data):
def _finalize_indexing_result(self: VariableType, dims, data) -> VariableType:
"""Used by IndexVariable to return IndexVariable objects when possible.
"""
return type(self)(dims, data, self._attrs, self._encoding, fastpath=True)
@@ -957,7 +969,11 @@ def chunk(self, chunks=None, name=None, lock=False):

return type(self)(self.dims, data, self._attrs, self._encoding, fastpath=True)

def isel(self, indexers=None, drop=False, **indexers_kwargs):
def isel(
self: VariableType,
indexers: Mapping[Hashable, Any] = None,
**indexers_kwargs: Any
) -> VariableType:
"""Return a new array indexed along the specified dimension(s).

Parameters
@@ -976,15 +992,12 @@ def isel(self, indexers=None, drop=False, **indexers_kwargs):
"""
indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "isel")

invalid = [k for k in indexers if k not in self.dims]
invalid = indexers.keys() - set(self.dims)
if invalid:
raise ValueError("dimensions %r do not exist" % invalid)

key = [slice(None)] * self.ndim
for i, dim in enumerate(self.dims):
if dim in indexers:
key[i] = indexers[dim]
return self[tuple(key)]
key = tuple(indexers.get(dim, slice(None)) for dim in self.dims)
return self[key]

def squeeze(self, dim=None):
"""Return a new object with squeezed data.