Skip to content

Commit 9d572a5

Browse files
authored
Return slices when possible from CFTimeIndex.get_loc() (#2569)
This makes string indexing usable in a pandas.MultiIndex.
1 parent 57fdcc5 commit 9d572a5

File tree

3 files changed

+41
-13
lines changed

3 files changed

+41
-13
lines changed

doc/whats-new.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,10 @@ Breaking changes
3636
Enhancements
3737
~~~~~~~~~~~~
3838

39+
- :py:class:`CFTimeIndex` uses slicing for string indexing when possible (like
40+
:py:class:`pandas.DatetimeIndex`), which avoids unnecessary copies.
41+
By `Stephan Hoyer <https://github.com/shoyer>`_
42+
3943
Bug fixes
4044
~~~~~~~~~
4145

xarray/coding/cftimeindex.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -263,14 +263,31 @@ def _partial_date_slice(self, resolution, parsed):
263263
"""
264264
start, end = _parsed_string_to_bounds(self.date_type, resolution,
265265
parsed)
266-
lhs_mask = (self._data >= start)
267-
rhs_mask = (self._data <= end)
268-
return (lhs_mask & rhs_mask).nonzero()[0]
266+
267+
times = self._data
268+
269+
if self.is_monotonic:
270+
if (len(times) and ((start < times[0] and end < times[0]) or
271+
(start > times[-1] and end > times[-1]))):
272+
# we are out of range
273+
raise KeyError
274+
275+
# a monotonic (sorted) series can be sliced
276+
left = times.searchsorted(start, side='left')
277+
right = times.searchsorted(end, side='right')
278+
return slice(left, right)
279+
280+
lhs_mask = times >= start
281+
rhs_mask = times <= end
282+
return np.flatnonzero(lhs_mask & rhs_mask)
269283

270284
def _get_string_slice(self, key):
271285
"""Adapted from pandas.tseries.index.DatetimeIndex._get_string_slice"""
272286
parsed, resolution = _parse_iso8601_with_reso(self.date_type, key)
273-
loc = self._partial_date_slice(resolution, parsed)
287+
try:
288+
loc = self._partial_date_slice(resolution, parsed)
289+
except KeyError:
290+
raise KeyError(key)
274291
return loc
275292

276293
def get_loc(self, key, method=None, tolerance=None):
@@ -431,7 +448,7 @@ def to_datetimeindex(self, unsafe=False):
431448
'calendar, {!r}, to a pandas.DatetimeIndex, which uses dates '
432449
'from the standard calendar. This may lead to subtle errors '
433450
'in operations that depend on the length of time between '
434-
'dates.'.format(calendar), RuntimeWarning)
451+
'dates.'.format(calendar), RuntimeWarning, stacklevel=2)
435452
return pd.DatetimeIndex(nptimes)
436453

437454

xarray/tests/test_cftimeindex.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
_parsed_string_to_bounds, assert_all_valid_date_type, parse_iso8601)
1313
from xarray.tests import assert_array_equal, assert_identical
1414

15-
from . import has_cftime, has_cftime_or_netCDF4, requires_cftime
15+
from . import has_cftime, has_cftime_or_netCDF4, requires_cftime, raises_regex
1616
from .test_coding_times import (_all_cftime_date_types, _ALL_CALENDARS,
1717
_NON_STANDARD_CALENDARS)
1818

@@ -251,16 +251,16 @@ def test_parsed_string_to_bounds_raises(date_type):
251251
@pytest.mark.skipif(not has_cftime, reason='cftime not installed')
252252
def test_get_loc(date_type, index):
253253
result = index.get_loc('0001')
254-
expected = [0, 1]
255-
assert_array_equal(result, expected)
254+
assert result == slice(0, 2)
256255

257256
result = index.get_loc(date_type(1, 2, 1))
258-
expected = 1
259-
assert result == expected
257+
assert result == 1
260258

261259
result = index.get_loc('0001-02-01')
262-
expected = 1
263-
assert result == expected
260+
assert result == slice(1, 2)
261+
262+
with raises_regex(KeyError, '1234'):
263+
index.get_loc('1234')
264264

265265

266266
@pytest.mark.skipif(not has_cftime, reason='cftime not installed')
@@ -758,7 +758,7 @@ def test_to_datetimeindex(calendar, unsafe):
758758
with pytest.warns(RuntimeWarning, match='non-standard'):
759759
result = index.to_datetimeindex()
760760
else:
761-
result = index.to_datetimeindex()
761+
result = index.to_datetimeindex(unsafe=unsafe)
762762

763763
assert result.equals(expected)
764764
np.testing.assert_array_equal(result, expected)
@@ -779,3 +779,10 @@ def test_to_datetimeindex_feb_29(calendar):
779779
index = xr.cftime_range('2001-02-28', periods=2, calendar=calendar)
780780
with pytest.raises(ValueError, match='29'):
781781
index.to_datetimeindex()
782+
783+
784+
@pytest.mark.skipif(not has_cftime, reason='cftime not installed')
785+
def test_multiindex():
786+
index = xr.cftime_range('2001-01-01', periods=100, calendar='360_day')
787+
mindex = pd.MultiIndex.from_arrays([index])
788+
assert mindex.get_loc('2001-01') == slice(0, 30)

0 commit comments

Comments
 (0)