Commit cc142f4

fujiisoup, dcherian, and keewis authored

sel with categorical index (#3670)

* Added support for a categorical index
* Fix from_dataframe
* Update a test
* Added more tests
* black
* Added a test to make sure ValueErrors are raised
* Remove unnecessary print
* Added a test for reindex
* Fix according to reviews
* blacken
* Delete trailing whitespace

Co-authored-by: Deepak Cherian <[email protected]>
Co-authored-by: keewis <[email protected]>
1 parent: 9c72866

5 files changed: +110 -9 lines

doc/whats-new.rst

Lines changed: 2 additions & 0 deletions
@@ -40,6 +40,8 @@ Breaking changes
 
 New Features
 ~~~~~~~~~~~~
+- :py:meth:`DataArray.sel` and :py:meth:`Dataset.sel` now support :py:class:`pandas.CategoricalIndex`. (:issue:`3669`)
+  By `Keisuke Fujii <https://github.com/fujiisoup>`_.
 - Support using an existing, opened h5netcdf ``File`` with
   :py:class:`~xarray.backends.H5NetCDFStore`. This permits creating an
   :py:class:`~xarray.Dataset` from a h5netcdf ``File`` that has been opened

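For reference, a minimal sketch of what this entry enables, adapted from ``test_sel_categorical`` added later in this commit (the data here is illustrative only):

import pandas as pd

# Round-trip a DataFrame with a categorical index into an xarray Dataset;
# the "ind" dimension is then backed by a pandas.CategoricalIndex.
ind = pd.Series(["foo", "bar"], dtype="category")
df = pd.DataFrame({"ind": ind, "values": [1, 2]})
ds = df.set_index("ind").to_xarray()

# With this change, label-based selection works on the categorical index.
assert ds.sel(ind="bar").identical(ds.isel(ind=1))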
xarray/core/dataset.py

Lines changed: 13 additions & 9 deletions
@@ -64,6 +64,7 @@
     default_indexes,
     isel_variable_and_index,
     propagate_indexes,
+    remove_unused_levels_categories,
     roll_index,
 )
 from .indexing import is_fancy_indexer
@@ -3411,7 +3412,7 @@ def ensure_stackable(val):
 
     def _unstack_once(self, dim: Hashable, fill_value, sparse) -> "Dataset":
         index = self.get_index(dim)
-        index = index.remove_unused_levels()
+        index = remove_unused_levels_categories(index)
         full_idx = pd.MultiIndex.from_product(index.levels, names=index.names)
 
         # take a shortcut in case the MultiIndex was not modified.
@@ -4460,17 +4461,19 @@ def to_dataframe(self):
         return self._to_dataframe(self.dims)
 
     def _set_sparse_data_from_dataframe(
-        self, dataframe: pd.DataFrame, dims: tuple, shape: Tuple[int, ...]
+        self, dataframe: pd.DataFrame, dims: tuple
     ) -> None:
         from sparse import COO
 
         idx = dataframe.index
         if isinstance(idx, pd.MultiIndex):
             coords = np.stack([np.asarray(code) for code in idx.codes], axis=0)
             is_sorted = idx.is_lexsorted
+            shape = tuple(lev.size for lev in idx.levels)
         else:
             coords = np.arange(idx.size).reshape(1, -1)
             is_sorted = True
+            shape = (idx.size,)
 
         for name, series in dataframe.items():
             # Cast to a NumPy array first, in case the Series is a pandas
@@ -4495,14 +4498,16 @@ def _set_sparse_data_from_dataframe(
             self[name] = (dims, data)
 
     def _set_numpy_data_from_dataframe(
-        self, dataframe: pd.DataFrame, dims: tuple, shape: Tuple[int, ...]
+        self, dataframe: pd.DataFrame, dims: tuple
     ) -> None:
         idx = dataframe.index
         if isinstance(idx, pd.MultiIndex):
             # expand the DataFrame to include the product of all levels
             full_idx = pd.MultiIndex.from_product(idx.levels, names=idx.names)
             dataframe = dataframe.reindex(full_idx)
-
+            shape = tuple(lev.size for lev in idx.levels)
+        else:
+            shape = (idx.size,)
         for name, series in dataframe.items():
             data = np.asarray(series).reshape(shape)
             self[name] = (dims, data)
@@ -4543,7 +4548,8 @@ def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> "Dataset":
         if not dataframe.columns.is_unique:
             raise ValueError("cannot convert DataFrame with non-unique columns")
 
-        idx = dataframe.index
+        idx = remove_unused_levels_categories(dataframe.index)
+        dataframe = dataframe.set_index(idx)
         obj = cls()
 
         if isinstance(idx, pd.MultiIndex):
@@ -4553,17 +4559,15 @@ def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> "Dataset":
             )
             for dim, lev in zip(dims, idx.levels):
                 obj[dim] = (dim, lev)
-            shape = tuple(lev.size for lev in idx.levels)
         else:
             index_name = idx.name if idx.name is not None else "index"
             dims = (index_name,)
             obj[index_name] = (dims, idx)
-            shape = (idx.size,)
 
         if sparse:
-            obj._set_sparse_data_from_dataframe(dataframe, dims, shape)
+            obj._set_sparse_data_from_dataframe(dataframe, dims)
         else:
-            obj._set_numpy_data_from_dataframe(dataframe, dims, shape)
+            obj._set_numpy_data_from_dataframe(dataframe, dims)
         return obj
 
     def to_dask_dataframe(self, dim_order=None, set_index=False):

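The user-visible effect of these ``from_dataframe`` changes, restated as a small sketch based on ``test_from_dataframe_categorical`` further down in this commit: unused categories are now dropped before the MultiIndex level product is built, so the resulting dimensions only cover the categories that actually occur.

import pandas as pd

cat = pd.CategoricalDtype(categories=["foo", "bar", "baz", "qux", "quux", "corge"])
i1 = pd.Series(["foo", "bar", "foo"], dtype=cat)
i2 = pd.Series(["bar", "bar", "baz"], dtype=cat)
df = pd.DataFrame({"i1": i1, "i2": i2, "values": [1, 2, 3]})

# Unused categories are removed from each level, so the unstacked
# dimensions are 2 x 2 rather than spanning all six declared categories.
ds = df.set_index(["i1", "i2"]).to_xarray()
assert len(ds["i1"]) == 2  # only "foo" and "bar" occur in i1
assert len(ds["i2"]) == 2  # only "bar" and "baz" occur in i2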
xarray/core/indexes.py

Lines changed: 20 additions & 0 deletions
@@ -9,6 +9,26 @@
 from .variable import Variable
 
 
+def remove_unused_levels_categories(index):
+    """
+    Remove unused levels from MultiIndex and unused categories from CategoricalIndex
+    """
+    if isinstance(index, pd.MultiIndex):
+        index = index.remove_unused_levels()
+        # if it contains CategoricalIndex, we need to remove unused categories
+        # manually. See https://github.com/pandas-dev/pandas/issues/30846
+        if any(isinstance(lev, pd.CategoricalIndex) for lev in index.levels):
+            levels = []
+            for i, level in enumerate(index.levels):
+                if isinstance(level, pd.CategoricalIndex):
+                    level = level[index.codes[i]].remove_unused_categories()
+                levels.append(level)
+            index = pd.MultiIndex.from_arrays(levels, names=index.names)
+    elif isinstance(index, pd.CategoricalIndex):
+        index = index.remove_unused_categories()
+    return index
+
+
 class Indexes(collections.abc.Mapping):
     """Immutable proxy for Dataset or DataArrary indexes."""

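For a flat ``CategoricalIndex`` the new helper defers to pandas directly; a minimal sketch of the pandas behaviour it relies on (values are illustrative only). The ``MultiIndex`` branch is needed because ``remove_unused_levels()`` does not drop unused categories from categorical levels, per the pandas issue referenced in the comment above.

import pandas as pd

idx = pd.CategoricalIndex(["foo", "bar"], categories=["foo", "bar", "baz", "qux"])
trimmed = idx.remove_unused_categories()
print(list(trimmed.categories))  # ['foo', 'bar']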
xarray/core/indexing.py

Lines changed: 10 additions & 0 deletions
@@ -175,6 +175,16 @@ def convert_label_indexer(index, label, index_name="", method=None, tolerance=None):
         if label.ndim == 0:
             if isinstance(index, pd.MultiIndex):
                 indexer, new_index = index.get_loc_level(label.item(), level=0)
+            elif isinstance(index, pd.CategoricalIndex):
+                if method is not None:
+                    raise ValueError(
+                        "'method' is not a valid kwarg when indexing using a CategoricalIndex."
+                    )
+                if tolerance is not None:
+                    raise ValueError(
+                        "'tolerance' is not a valid kwarg when indexing using a CategoricalIndex."
+                    )
+                indexer = index.get_loc(label.item())
             else:
                 indexer = index.get_loc(
                     label.item(), method=method, tolerance=tolerance

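The new branch turns ``method`` and ``tolerance`` into explicit errors for categorical indexes; a short sketch of the resulting behaviour, mirroring ``test_sel_categorical_error`` below:

import pandas as pd

ind = pd.Series(["foo", "bar"], dtype="category")
ds = pd.DataFrame({"ind": ind, "values": [1, 2]}).set_index("ind").to_xarray()

try:
    ds.sel(ind="bar", method="nearest")
except ValueError as err:
    # 'method' is not a valid kwarg when indexing using a CategoricalIndex.
    print(err)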
xarray/tests/test_dataset.py

Lines changed: 65 additions & 0 deletions
@@ -1408,6 +1408,56 @@ def test_sel_dataarray_mindex(self):
             )
         )
 
+    def test_sel_categorical(self):
+        ind = pd.Series(["foo", "bar"], dtype="category")
+        df = pd.DataFrame({"ind": ind, "values": [1, 2]})
+        ds = df.set_index("ind").to_xarray()
+        actual = ds.sel(ind="bar")
+        expected = ds.isel(ind=1)
+        assert_identical(expected, actual)
+
+    def test_sel_categorical_error(self):
+        ind = pd.Series(["foo", "bar"], dtype="category")
+        df = pd.DataFrame({"ind": ind, "values": [1, 2]})
+        ds = df.set_index("ind").to_xarray()
+        with pytest.raises(ValueError):
+            ds.sel(ind="bar", method="nearest")
+        with pytest.raises(ValueError):
+            ds.sel(ind="bar", tolerance="nearest")
+
+    def test_categorical_index(self):
+        cat = pd.CategoricalIndex(
+            ["foo", "bar", "foo"],
+            categories=["foo", "bar", "baz", "qux", "quux", "corge"],
+        )
+        ds = xr.Dataset(
+            {"var": ("cat", np.arange(3))},
+            coords={"cat": ("cat", cat), "c": ("cat", [0, 1, 1])},
+        )
+        # test slice
+        actual = ds.sel(cat="foo")
+        expected = ds.isel(cat=[0, 2])
+        assert_identical(expected, actual)
+        # make sure the conversion to the array works
+        actual = ds.sel(cat="foo")["cat"].values
+        assert (actual == np.array(["foo", "foo"])).all()
+
+        ds = ds.set_index(index=["cat", "c"])
+        actual = ds.unstack("index")
+        assert actual["var"].shape == (2, 2)
+
+    def test_categorical_reindex(self):
+        cat = pd.CategoricalIndex(
+            ["foo", "bar", "baz"],
+            categories=["foo", "bar", "baz", "qux", "quux", "corge"],
+        )
+        ds = xr.Dataset(
+            {"var": ("cat", np.arange(3))},
+            coords={"cat": ("cat", cat), "c": ("cat", [0, 1, 2])},
+        )
+        actual = ds.reindex(cat=["foo"])["cat"].values
+        assert (actual == np.array(["foo"])).all()
+
     def test_sel_drop(self):
         data = Dataset({"foo": ("x", [1, 2, 3])}, {"x": [0, 1, 2]})
         expected = Dataset({"foo": 1})
@@ -3865,6 +3915,21 @@ def test_to_and_from_dataframe(self):
         expected = pd.DataFrame([[]], index=idx)
         assert expected.equals(actual), (expected, actual)
 
+    def test_from_dataframe_categorical(self):
+        cat = pd.CategoricalDtype(
+            categories=["foo", "bar", "baz", "qux", "quux", "corge"]
+        )
+        i1 = pd.Series(["foo", "bar", "foo"], dtype=cat)
+        i2 = pd.Series(["bar", "bar", "baz"], dtype=cat)
+
+        df = pd.DataFrame({"i1": i1, "i2": i2, "values": [1, 2, 3]})
+        ds = df.set_index("i1").to_xarray()
+        assert len(ds["i1"]) == 3
+
+        ds = df.set_index(["i1", "i2"]).to_xarray()
+        assert len(ds["i1"]) == 2
+        assert len(ds["i2"]) == 2
+
     @requires_sparse
     def test_from_dataframe_sparse(self):
         import sparse
