Skip to content

Commit feb9aad

Browse files
hmaarrfkdcherian
authored andcommitted
Micro optimizations to improve indexing (#9002)
* conda instead of mamba * Make speedups using fastpath * Change core logic to apply_indexes_fast * Always have fastpath=True in one path * Remove basicindexer fastpath=True * Duplicate a comment * Add comments * revert asv changes * Avoid fastpath=True assignment * Remove changes to basicindexer * Do not do fast fastpath for IndexVariable * Remove one unecessary change * Remove one more fastpath * Revert uneeded change to PandasIndexingAdapter * Update xarray/core/indexes.py * Update whats-new.rst * Update whats-new.rst * fix whats-new --------- Co-authored-by: Deepak Cherian <[email protected]> Co-authored-by: Deepak Cherian <[email protected]>
1 parent f36494e commit feb9aad

File tree

3 files changed

+65
-10
lines changed

3 files changed

+65
-10
lines changed

doc/whats-new.rst

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ Performance
2828

2929
- Small optimization to the netCDF4 and h5netcdf backends (:issue:`9058`, :pull:`9067`).
3030
By `Deepak Cherian <https://github.com/dcherian>`_.
31+
- Small optimizations to help reduce indexing speed of datasets (:pull:`9002`).
32+
By `Mark Harfouche <https://github.com/hmaarrfk>`_.
3133

3234

3335
Breaking changes
@@ -2906,7 +2908,7 @@ Bug fixes
29062908
process (:issue:`4045`, :pull:`4684`). It also enables encoding and decoding standard
29072909
calendar dates with time units of nanoseconds (:pull:`4400`).
29082910
By `Spencer Clark <https://github.com/spencerkclark>`_ and `Mark Harfouche
2909-
<http://github.com/hmaarrfk>`_.
2911+
<https://github.com/hmaarrfk>`_.
29102912
- :py:meth:`DataArray.astype`, :py:meth:`Dataset.astype` and :py:meth:`Variable.astype` support
29112913
the ``order`` and ``subok`` parameters again. This fixes a regression introduced in version 0.16.1
29122914
(:issue:`4644`, :pull:`4683`).

xarray/core/indexes.py

Lines changed: 59 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -575,13 +575,24 @@ class PandasIndex(Index):
575575

576576
__slots__ = ("index", "dim", "coord_dtype")
577577

578-
def __init__(self, array: Any, dim: Hashable, coord_dtype: Any = None):
579-
# make a shallow copy: cheap and because the index name may be updated
580-
# here or in other constructors (cannot use pd.Index.rename as this
581-
# constructor is also called from PandasMultiIndex)
582-
index = safe_cast_to_index(array).copy()
578+
def __init__(
579+
self,
580+
array: Any,
581+
dim: Hashable,
582+
coord_dtype: Any = None,
583+
*,
584+
fastpath: bool = False,
585+
):
586+
if fastpath:
587+
index = array
588+
else:
589+
index = safe_cast_to_index(array)
583590

584591
if index.name is None:
592+
# make a shallow copy: cheap and because the index name may be updated
593+
# here or in other constructors (cannot use pd.Index.rename as this
594+
# constructor is also called from PandasMultiIndex)
595+
index = index.copy()
585596
index.name = dim
586597

587598
self.index = index
@@ -596,7 +607,7 @@ def _replace(self, index, dim=None, coord_dtype=None):
596607
dim = self.dim
597608
if coord_dtype is None:
598609
coord_dtype = self.coord_dtype
599-
return type(self)(index, dim, coord_dtype)
610+
return type(self)(index, dim, coord_dtype, fastpath=True)
600611

601612
@classmethod
602613
def from_variables(
@@ -642,6 +653,11 @@ def from_variables(
642653

643654
obj = cls(data, dim, coord_dtype=var.dtype)
644655
assert not isinstance(obj.index, pd.MultiIndex)
656+
# Rename safely
657+
# make a shallow copy: cheap and because the index name may be updated
658+
# here or in other constructors (cannot use pd.Index.rename as this
659+
# constructor is also called from PandasMultiIndex)
660+
obj.index = obj.index.copy()
645661
obj.index.name = name
646662

647663
return obj
@@ -1773,6 +1789,36 @@ def check_variables():
17731789
return not not_equal
17741790

17751791

1792+
def _apply_indexes_fast(indexes: Indexes[Index], args: Mapping[Any, Any], func: str):
1793+
# This function avoids the call to indexes.group_by_index
1794+
# which is really slow when repeatidly iterating through
1795+
# an array. However, it fails to return the correct ID for
1796+
# multi-index arrays
1797+
indexes_fast, coords = indexes._indexes, indexes._variables
1798+
1799+
new_indexes: dict[Hashable, Index] = {k: v for k, v in indexes_fast.items()}
1800+
new_index_variables: dict[Hashable, Variable] = {}
1801+
for name, index in indexes_fast.items():
1802+
coord = coords[name]
1803+
if hasattr(coord, "_indexes"):
1804+
index_vars = {n: coords[n] for n in coord._indexes}
1805+
else:
1806+
index_vars = {name: coord}
1807+
index_dims = {d for var in index_vars.values() for d in var.dims}
1808+
index_args = {k: v for k, v in args.items() if k in index_dims}
1809+
1810+
if index_args:
1811+
new_index = getattr(index, func)(index_args)
1812+
if new_index is not None:
1813+
new_indexes.update({k: new_index for k in index_vars})
1814+
new_index_vars = new_index.create_variables(index_vars)
1815+
new_index_variables.update(new_index_vars)
1816+
else:
1817+
for k in index_vars:
1818+
new_indexes.pop(k, None)
1819+
return new_indexes, new_index_variables
1820+
1821+
17761822
def _apply_indexes(
17771823
indexes: Indexes[Index],
17781824
args: Mapping[Any, Any],
@@ -1801,7 +1847,13 @@ def isel_indexes(
18011847
indexes: Indexes[Index],
18021848
indexers: Mapping[Any, Any],
18031849
) -> tuple[dict[Hashable, Index], dict[Hashable, Variable]]:
1804-
return _apply_indexes(indexes, indexers, "isel")
1850+
# TODO: remove if clause in the future. It should be unnecessary.
1851+
# See failure introduced when removed
1852+
# https://github.com/pydata/xarray/pull/9002#discussion_r1590443756
1853+
if any(isinstance(v, PandasMultiIndex) for v in indexes._indexes.values()):
1854+
return _apply_indexes(indexes, indexers, "isel")
1855+
else:
1856+
return _apply_indexes_fast(indexes, indexers, "isel")
18051857

18061858

18071859
def roll_indexes(

xarray/core/indexing.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -297,15 +297,16 @@ def slice_slice(old_slice: slice, applied_slice: slice, size: int) -> slice:
297297

298298

299299
def _index_indexer_1d(old_indexer, applied_indexer, size: int):
300-
assert isinstance(applied_indexer, integer_types + (slice, np.ndarray))
301300
if isinstance(applied_indexer, slice) and applied_indexer == slice(None):
302301
# shortcut for the usual case
303302
return old_indexer
304303
if isinstance(old_indexer, slice):
305304
if isinstance(applied_indexer, slice):
306305
indexer = slice_slice(old_indexer, applied_indexer, size)
306+
elif isinstance(applied_indexer, integer_types):
307+
indexer = range(*old_indexer.indices(size))[applied_indexer] # type: ignore[assignment]
307308
else:
308-
indexer = _expand_slice(old_indexer, size)[applied_indexer] # type: ignore[assignment]
309+
indexer = _expand_slice(old_indexer, size)[applied_indexer]
309310
else:
310311
indexer = old_indexer[applied_indexer]
311312
return indexer

0 commit comments

Comments
 (0)