diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 40e206aaa86..0457b77f20c 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -40,6 +40,10 @@ Enhancements Bug fixes ~~~~~~~~~ +- Fixed labeled indexing with slice bounds given by xarray objects with + datetime64 or timedelta64 dtypes (:issue:`1240`). + By `Stephan Hoyer `_. + .. _whats-new.0.10.2: v0.10.2 (13 March 2018) diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index f7477a3e6b2..2c1f08379ab 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -48,11 +48,25 @@ def _expand_slice(slice_, size): return np.arange(*slice_.indices(size)) -def _try_get_item(x): - try: - return x.item() - except AttributeError: - return x +def _sanitize_slice_element(x): + from .variable import Variable + from .dataarray import DataArray + + if isinstance(x, (Variable, DataArray)): + x = x.values + + if isinstance(x, np.ndarray): + if x.ndim != 0: + raise ValueError('cannot use non-scalar arrays in a slice for ' + 'xarray indexing: {}'.format(x)) + x = x[()] + + if isinstance(x, np.timedelta64): + # pandas does not support indexing with np.timedelta64 yet: + # https://github.com/pandas-dev/pandas/issues/20393 + x = pd.Timedelta(x) + + return x def _asarray_tuplesafe(values): @@ -119,9 +133,9 @@ def convert_label_indexer(index, label, index_name='', method=None, raise NotImplementedError( 'cannot use ``method`` argument if any indexers are ' 'slice objects') - indexer = index.slice_indexer(_try_get_item(label.start), - _try_get_item(label.stop), - _try_get_item(label.step)) + indexer = index.slice_indexer(_sanitize_slice_element(label.start), + _sanitize_slice_element(label.stop), + _sanitize_slice_element(label.step)) if not isinstance(indexer, slice): # unlike pandas, in xarray we never want to silently convert a # slice indexer into an array indexer diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 059e93fc70c..3fd229cf394 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -779,6 +779,22 @@ def test_sel_dataarray(self): assert 'new_dim' in actual.coords assert_equal(actual['new_dim'].drop('x'), ind['new_dim']) + def test_sel_invalid_slice(self): + array = DataArray(np.arange(10), [('x', np.arange(10))]) + with raises_regex(ValueError, 'cannot use non-scalar arrays'): + array.sel(x=slice(array.x)) + + def test_sel_dataarray_datetime(self): + # regression test for GH1240 + times = pd.date_range('2000-01-01', freq='D', periods=365) + array = DataArray(np.arange(365), [('time', times)]) + result = array.sel(time=slice(array.time[0], array.time[-1])) + assert_equal(result, array) + + array = DataArray(np.arange(365), [('delta', times - times[0])]) + result = array.sel(delta=slice(array.delta[0], array.delta[-1])) + assert_equal(result, array) + def test_sel_no_index(self): array = DataArray(np.arange(10), dims='x') assert_identical(array[0], array.sel(x=0))