
Commit 2645d7f

groupby: remove some internal use of IndexVariable (#9123)
* Remove internal use of IndexVariable
* cleanup
* cleanup more
* cleanup
1 parent af722f0 commit 2645d7f
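
The core of the change is constructing plain Variable objects instead of IndexVariable in the groupby internals. A minimal sketch (not part of this commit) contrasting the two constructors; the values and attrs below are made up:

    import numpy as np
    import xarray as xr

    values = np.array([10, 20, 30])

    # IndexVariable wraps its data in a pandas.Index and is intended for dimension coordinates
    iv = xr.IndexVariable(dims="x", data=values, attrs={"units": "m"})
    print(iv.to_index())   # the backing pandas.Index

    # Variable is a plain labelled array; no pandas.Index is created up front
    v = xr.Variable(dims="x", data=values, attrs={"units": "m"})
    print(type(v.data))    # numpy.ndarray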

File tree

2 files changed (+67, -33 lines changed)


xarray/core/groupby.py

Lines changed: 41 additions & 22 deletions
@@ -19,8 +19,10 @@
 from xarray.core.arithmetic import DataArrayGroupbyArithmetic, DatasetGroupbyArithmetic
 from xarray.core.common import ImplementsArrayReduce, ImplementsDatasetReduce
 from xarray.core.concat import concat
+from xarray.core.coordinates import Coordinates
 from xarray.core.formatting import format_array_flat
 from xarray.core.indexes import (
+    PandasIndex,
     create_default_index_implicit,
     filter_indexes_from_coords,
 )
@@ -246,7 +248,7 @@ def to_array(self) -> DataArray:
         return self.to_dataarray()
 
 
-T_Group = Union["T_DataArray", "IndexVariable", _DummyGroup]
+T_Group = Union["T_DataArray", _DummyGroup]
 
 
 def _ensure_1d(group: T_Group, obj: T_DataWithCoords) -> tuple[
@@ -256,7 +258,7 @@ def _ensure_1d(group: T_Group, obj: T_DataWithCoords) -> tuple[
     list[Hashable],
 ]:
     # 1D cases: do nothing
-    if isinstance(group, (IndexVariable, _DummyGroup)) or group.ndim == 1:
+    if isinstance(group, _DummyGroup) or group.ndim == 1:
         return group, obj, None, []
 
     from xarray.core.dataarray import DataArray
@@ -271,9 +273,7 @@ def _ensure_1d(group: T_Group, obj: T_DataWithCoords) -> tuple[
         newobj = obj.stack({stacked_dim: orig_dims})
         return newgroup, newobj, stacked_dim, inserted_dims
 
-    raise TypeError(
-        f"group must be DataArray, IndexVariable or _DummyGroup, got {type(group)!r}."
-    )
+    raise TypeError(f"group must be DataArray or _DummyGroup, got {type(group)!r}.")
 
 
 @dataclass
@@ -299,7 +299,7 @@ class ResolvedGrouper(Generic[T_DataWithCoords]):
     codes: DataArray = field(init=False)
     full_index: pd.Index = field(init=False)
     group_indices: T_GroupIndices = field(init=False)
-    unique_coord: IndexVariable | _DummyGroup = field(init=False)
+    unique_coord: Variable | _DummyGroup = field(init=False)
 
     # _ensure_1d:
     group1d: T_Group = field(init=False)
@@ -315,7 +315,7 @@ def __post_init__(self) -> None:
         # might be used multiple times.
         self.grouper = copy.deepcopy(self.grouper)
 
-        self.group: T_Group = _resolve_group(self.obj, self.group)
+        self.group = _resolve_group(self.obj, self.group)
 
         (
             self.group1d,
@@ -328,14 +328,18 @@ def __post_init__(self) -> None:
 
     @property
     def name(self) -> Hashable:
+        """Name for the grouped coordinate after reduction."""
         # the name has to come from unique_coord because we need `_bins` suffix for BinGrouper
-        return self.unique_coord.name
+        (name,) = self.unique_coord.dims
+        return name
 
     @property
     def size(self) -> int:
+        """Number of groups."""
         return len(self)
 
     def __len__(self) -> int:
+        """Number of groups."""
         return len(self.full_index)
 
     @property
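
Because a plain Variable does not carry a name the way IndexVariable does, the grouper name is now read off the single dimension of unique_coord. A rough illustration with a made-up dimension name:

    import xarray as xr

    # hypothetical unique_coord, as the BinGrouper path builds it
    unique_coord = xr.Variable(dims="x_bins", data=[0, 1, 2])
    (name,) = unique_coord.dims   # "x_bins"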
@@ -358,8 +362,8 @@ def factorize(self) -> None:
             ]
         if encoded.unique_coord is None:
             unique_values = self.full_index[np.unique(encoded.codes)]
-            self.unique_coord = IndexVariable(
-                self.codes.name, unique_values, attrs=self.group.attrs
+            self.unique_coord = Variable(
+                dims=self.codes.name, data=unique_values, attrs=self.group.attrs
             )
         else:
             self.unique_coord = encoded.unique_coord
@@ -378,7 +382,9 @@ def _validate_groupby_squeeze(squeeze: bool | None) -> None:
         )
 
 
-def _resolve_group(obj: T_DataWithCoords, group: T_Group | Hashable) -> T_Group:
+def _resolve_group(
+    obj: T_DataWithCoords, group: T_Group | Hashable | IndexVariable
+) -> T_Group:
     from xarray.core.dataarray import DataArray
 
     error_msg = (
@@ -620,6 +626,8 @@ def _iter_grouped(self, warn_squeeze=True) -> Iterator[T_Xarray]:
             yield self._obj.isel({self._group_dim: indices})
 
     def _infer_concat_args(self, applied_example):
+        from xarray.core.groupers import BinGrouper
+
         (grouper,) = self.groupers
         if self._group_dim in applied_example.dims:
             coord = grouper.group1d
@@ -628,7 +636,10 @@ def _infer_concat_args(self, applied_example):
             coord = grouper.unique_coord
             positions = None
         (dim,) = coord.dims
-        if isinstance(coord, _DummyGroup):
+        if isinstance(grouper.group, _DummyGroup) and not isinstance(
+            grouper.grouper, BinGrouper
+        ):
+            # When binning we actually do set the index
             coord = None
         coord = getattr(coord, "variable", coord)
         return coord, dim, positions
@@ -641,6 +652,7 @@ def _binary_op(self, other, f, reflexive=False):
 
         (grouper,) = self.groupers
         obj = self._original_obj
+        name = grouper.name
         group = grouper.group
         codes = self._codes
         dims = group.dims
@@ -649,9 +661,11 @@ def _binary_op(self, other, f, reflexive=False):
             group = coord = group.to_dataarray()
         else:
             coord = grouper.unique_coord
-            if not isinstance(coord, DataArray):
-                coord = DataArray(grouper.unique_coord)
-        name = grouper.name
+            if isinstance(coord, Variable):
+                assert coord.ndim == 1
+                (coord_dim,) = coord.dims
+                # TODO: explicitly create Index here
+                coord = DataArray(coord, coords={coord_dim: coord.data})
 
         if not isinstance(other, (Dataset, DataArray)):
             raise TypeError(
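
In `_binary_op` the 1-D unique_coord Variable is promoted to a DataArray whose dimension also carries the coordinate values, so it can align against `other`. A standalone sketch of that pattern, with made-up names and data:

    import numpy as np
    import xarray as xr

    coord_var = xr.Variable(dims="x_bins", data=np.array([1.5, 2.5, 3.5]))
    (coord_dim,) = coord_var.dims
    # wrap the Variable and reuse its data as the dimension coordinate
    coord = xr.DataArray(coord_var, coords={coord_dim: coord_var.data})
    print(coord.indexes[coord_dim])   # a default pandas-backed index is created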
@@ -766,6 +780,7 @@ def _flox_reduce(
 
         obj = self._original_obj
         (grouper,) = self.groupers
+        name = grouper.name
         isbin = isinstance(grouper.grouper, BinGrouper)
 
         if keep_attrs is None:
@@ -797,14 +812,14 @@ def _flox_reduce(
         # weird backcompat
         # reducing along a unique indexed dimension with squeeze=True
         # should raise an error
-        if (dim is None or dim == grouper.name) and grouper.name in obj.xindexes:
-            index = obj.indexes[grouper.name]
+        if (dim is None or dim == name) and name in obj.xindexes:
+            index = obj.indexes[name]
             if index.is_unique and self._squeeze:
-                raise ValueError(f"cannot reduce over dimensions {grouper.name!r}")
+                raise ValueError(f"cannot reduce over dimensions {name!r}")
 
         unindexed_dims: tuple[Hashable, ...] = tuple()
         if isinstance(grouper.group, _DummyGroup) and not isbin:
-            unindexed_dims = (grouper.name,)
+            unindexed_dims = (name,)
 
         parsed_dim: tuple[Hashable, ...]
         if isinstance(dim, str):
@@ -848,15 +863,19 @@ def _flox_reduce(
         # in the grouped variable
         group_dims = grouper.group.dims
         if set(group_dims).issubset(set(parsed_dim)):
-            result[grouper.name] = output_index
+            result = result.assign_coords(
+                Coordinates(
+                    coords={name: (name, np.array(output_index))},
+                    indexes={name: PandasIndex(output_index, dim=name)},
+                )
+            )
             result = result.drop_vars(unindexed_dims)
 
         # broadcast and restore non-numeric data variables (backcompat)
         for name, var in non_numeric.items():
             if all(d not in var.dims for d in parsed_dim):
                 result[name] = var.variable.set_dims(
-                    (grouper.name,) + var.dims,
-                    (result.sizes[grouper.name],) + var.shape,
+                    (name,) + var.dims, (result.sizes[name],) + var.shape
                 )
 
         if not isinstance(result, Dataset):
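
Here the reduced output coordinate is assigned together with an explicitly constructed PandasIndex via the Coordinates API rather than plain item assignment. A self-contained sketch of the same pattern, assuming a made-up dataset and bin name:

    import numpy as np
    import pandas as pd
    import xarray as xr
    from xarray.core.indexes import PandasIndex

    name = "x_bins"
    output_index = pd.Index([0.5, 1.5, 2.5])

    result = xr.Dataset({"a": ((name,), np.arange(3.0))})
    # attach both the coordinate values and an explicit index in one step
    result = result.assign_coords(
        xr.Coordinates(
            coords={name: (name, np.array(output_index))},
            indexes={name: PandasIndex(output_index, dim=name)},
        )
    )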

xarray/core/groupers.py

Lines changed: 26 additions & 11 deletions
@@ -23,7 +23,7 @@
 from xarray.core.resample_cftime import CFTimeGrouper
 from xarray.core.types import DatetimeLike, SideOptions, T_GroupIndices
 from xarray.core.utils import emit_user_level_warning
-from xarray.core.variable import IndexVariable
+from xarray.core.variable import Variable
 
 __all__ = [
     "EncodedGroups",
@@ -55,7 +55,17 @@ class EncodedGroups:
     codes: DataArray
     full_index: pd.Index
     group_indices: T_GroupIndices | None = field(default=None)
-    unique_coord: IndexVariable | _DummyGroup | None = field(default=None)
+    unique_coord: Variable | _DummyGroup | None = field(default=None)
+
+    def __post_init__(self):
+        assert isinstance(self.codes, DataArray)
+        if self.codes.name is None:
+            raise ValueError("Please set a name on the array you are grouping by.")
+        assert isinstance(self.full_index, pd.Index)
+        assert (
+            isinstance(self.unique_coord, (Variable, _DummyGroup))
+            or self.unique_coord is None
+        )
 
 
 class Grouper(ABC):
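
The new `__post_init__` validates EncodedGroups eagerly, so malformed inputs fail at construction time instead of deep inside the groupby machinery. A generic sketch of the pattern with stand-in fields, not the real EncodedGroups:

    from dataclasses import dataclass, field
    import pandas as pd

    @dataclass
    class ValidatedGroups:
        full_index: pd.Index
        name: str | None = field(default=None)

        def __post_init__(self):
            # fail fast with a clear message rather than later, mid-reduction
            assert isinstance(self.full_index, pd.Index)
            if self.name is None:
                raise ValueError("Please set a name on the array you are grouping by.")

    ValidatedGroups(pd.Index([1, 2, 3]), name="x")   # ok
    # ValidatedGroups(pd.Index([1, 2, 3]))           # would raise ValueError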
@@ -134,10 +144,10 @@ def _factorize_unique(self) -> EncodedGroups:
                 "Failed to group data. Are you grouping by a variable that is all NaN?"
             )
         codes = self.group.copy(data=codes_)
-        unique_coord = IndexVariable(
-            self.group.name, unique_values, attrs=self.group.attrs
+        unique_coord = Variable(
+            dims=codes.name, data=unique_values, attrs=self.group.attrs
         )
-        full_index = unique_coord
+        full_index = pd.Index(unique_values)
 
         return EncodedGroups(
             codes=codes, full_index=full_index, unique_coord=unique_coord
@@ -152,12 +162,13 @@ def _factorize_dummy(self) -> EncodedGroups:
         size_range = np.arange(size)
         if isinstance(self.group, _DummyGroup):
             codes = self.group.to_dataarray().copy(data=size_range)
+            unique_coord = self.group
+            full_index = pd.RangeIndex(self.group.size)
         else:
             codes = self.group.copy(data=size_range)
-        unique_coord = self.group
-        full_index = IndexVariable(
-            self.group.name, unique_coord.values, self.group.attrs
-        )
+            unique_coord = self.group.variable.to_base_variable()
+            full_index = pd.Index(unique_coord.data)
+
         return EncodedGroups(
             codes=codes,
             group_indices=group_indices,
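
The non-dummy branch now demotes the grouped coordinate with `to_base_variable()` and builds a `pd.Index` from its data. A small sketch of what that call does, using a made-up DataArray:

    import xarray as xr

    da = xr.DataArray([1.0, 2.0], coords={"x": [10, 20]}, dims="x")
    idx_var = da["x"].variable          # IndexVariable backing the "x" dimension coordinate
    plain = idx_var.to_base_variable()  # same dims/data/attrs, but a plain Variable
    print(type(idx_var).__name__, type(plain).__name__)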
@@ -201,7 +212,9 @@ def factorize(self, group) -> EncodedGroups:
         codes = DataArray(
             binned_codes, getattr(group, "coords", None), name=new_dim_name
         )
-        unique_coord = IndexVariable(new_dim_name, pd.Index(unique_values), group.attrs)
+        unique_coord = Variable(
+            dims=new_dim_name, data=unique_values, attrs=group.attrs
+        )
         return EncodedGroups(
             codes=codes, full_index=full_index, unique_coord=unique_coord
         )
@@ -318,7 +331,9 @@ def factorize(self, group) -> EncodedGroups:
         ]
         group_indices += [slice(sbins[-1], None)]
 
-        unique_coord = IndexVariable(group.name, first_items.index, group.attrs)
+        unique_coord = Variable(
+            dims=group.name, data=first_items.index, attrs=group.attrs
+        )
         codes = group.copy(data=codes_)
 
         return EncodedGroups(
