Skip to content

Commit fc85ba2

Browse files
committed
Remove internal use of IndexVariable
1 parent 599b779 commit fc85ba2

File tree

2 files changed

+55
-20
lines changed

2 files changed

+55
-20
lines changed

xarray/core/groupby.py

+29-9
Original file line numberDiff line numberDiff line change
@@ -299,7 +299,7 @@ class ResolvedGrouper(Generic[T_DataWithCoords]):
299299
codes: DataArray = field(init=False)
300300
full_index: pd.Index = field(init=False)
301301
group_indices: T_GroupIndices = field(init=False)
302-
unique_coord: IndexVariable | _DummyGroup = field(init=False)
302+
unique_coord: Variable | _DummyGroup = field(init=False)
303303

304304
# _ensure_1d:
305305
group1d: T_Group = field(init=False)
@@ -328,14 +328,18 @@ def __post_init__(self) -> None:
328328

329329
@property
330330
def name(self) -> Hashable:
331+
"""Name for the grouped coordinate after reduction."""
331332
# the name has to come from unique_coord because we need `_bins` suffix for BinGrouper
332-
return self.unique_coord.name
333+
(name,) = self.unique_coord.dims
334+
return name
333335

334336
@property
335337
def size(self) -> int:
338+
"""Number of groups."""
336339
return len(self)
337340

338341
def __len__(self) -> int:
342+
"""Number of groups."""
339343
return len(self.full_index)
340344

341345
@property
@@ -358,8 +362,8 @@ def factorize(self) -> None:
358362
]
359363
if encoded.unique_coord is None:
360364
unique_values = self.full_index[np.unique(encoded.codes)]
361-
self.unique_coord = IndexVariable(
362-
self.codes.name, unique_values, attrs=self.group.attrs
365+
self.unique_coord = Variable(
366+
dims=self.codes.name, data=unique_values, attrs=self.group.attrs
363367
)
364368
else:
365369
self.unique_coord = encoded.unique_coord
@@ -620,6 +624,8 @@ def _iter_grouped(self, warn_squeeze=True) -> Iterator[T_Xarray]:
620624
yield self._obj.isel({self._group_dim: indices})
621625

622626
def _infer_concat_args(self, applied_example):
627+
from xarray.core.groupers import BinGrouper
628+
623629
(grouper,) = self.groupers
624630
if self._group_dim in applied_example.dims:
625631
coord = grouper.group1d
@@ -628,7 +634,10 @@ def _infer_concat_args(self, applied_example):
628634
coord = grouper.unique_coord
629635
positions = None
630636
(dim,) = coord.dims
631-
if isinstance(coord, _DummyGroup):
637+
if isinstance(grouper.group, _DummyGroup) and not isinstance(
638+
grouper.grouper, BinGrouper
639+
):
640+
# When binning we actually do set the index
632641
coord = None
633642
coord = getattr(coord, "variable", coord)
634643
return coord, dim, positions
@@ -645,12 +654,20 @@ def _binary_op(self, other, f, reflexive=False):
645654
codes = self._codes
646655
dims = group.dims
647656

648-
if isinstance(group, _DummyGroup):
657+
if isinstance(grouper.group, _DummyGroup):
649658
group = coord = group.to_dataarray()
650659
else:
651660
coord = grouper.unique_coord
652-
if not isinstance(coord, DataArray):
653-
coord = DataArray(grouper.unique_coord)
661+
if isinstance(coord, Variable):
662+
assert coord.ndim == 1
663+
(coord_dim,) = coord.dims
664+
# TODO: explicitly create Index here
665+
coord = DataArray(
666+
dims=coord_dim,
667+
data=coord.data,
668+
attrs=coord.attrs,
669+
coords={coord_dim: coord.data},
670+
)
654671
name = grouper.name
655672

656673
if not isinstance(other, (Dataset, DataArray)):
@@ -848,7 +865,10 @@ def _flox_reduce(
848865
# in the grouped variable
849866
group_dims = grouper.group.dims
850867
if set(group_dims).issubset(set(parsed_dim)):
851-
result[grouper.name] = output_index
868+
new_coord = Variable(
869+
dims=grouper.name, data=np.array(output_index), attrs=self._codes.attrs
870+
)
871+
result = result.assign_coords({grouper.name: new_coord})
852872
result = result.drop_vars(unindexed_dims)
853873

854874
# broadcast and restore non-numeric data variables (backcompat)

xarray/core/groupers.py

+26-11
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from xarray.core.resample_cftime import CFTimeGrouper
2424
from xarray.core.types import DatetimeLike, SideOptions, T_GroupIndices
2525
from xarray.core.utils import emit_user_level_warning
26-
from xarray.core.variable import IndexVariable
26+
from xarray.core.variable import Variable
2727

2828
__all__ = [
2929
"EncodedGroups",
@@ -55,7 +55,17 @@ class EncodedGroups:
5555
codes: DataArray
5656
full_index: pd.Index
5757
group_indices: T_GroupIndices | None = field(default=None)
58-
unique_coord: IndexVariable | _DummyGroup | None = field(default=None)
58+
unique_coord: Variable | _DummyGroup | None = field(default=None)
59+
60+
def __post_init__(self):
61+
assert isinstance(self.codes, DataArray)
62+
if self.codes.name is None:
63+
raise ValueError("Please set a name on the array you are grouping by.")
64+
assert isinstance(self.full_index, pd.Index)
65+
assert (
66+
isinstance(self.unique_coord, (Variable, _DummyGroup))
67+
or self.unique_coord is None
68+
)
5969

6070

6171
class Grouper(ABC):
@@ -134,10 +144,10 @@ def _factorize_unique(self) -> EncodedGroups:
134144
"Failed to group data. Are you grouping by a variable that is all NaN?"
135145
)
136146
codes = self.group.copy(data=codes_)
137-
unique_coord = IndexVariable(
138-
self.group.name, unique_values, attrs=self.group.attrs
147+
unique_coord = Variable(
148+
dims=codes.name, data=unique_values, attrs=self.group.attrs
139149
)
140-
full_index = unique_coord
150+
full_index = pd.Index(unique_values)
141151

142152
return EncodedGroups(
143153
codes=codes, full_index=full_index, unique_coord=unique_coord
@@ -152,12 +162,13 @@ def _factorize_dummy(self) -> EncodedGroups:
152162
size_range = np.arange(size)
153163
if isinstance(self.group, _DummyGroup):
154164
codes = self.group.to_dataarray().copy(data=size_range)
165+
unique_coord = self.group
166+
full_index = pd.RangeIndex(self.group.size)
155167
else:
156168
codes = self.group.copy(data=size_range)
157-
unique_coord = self.group
158-
full_index = IndexVariable(
159-
self.group.name, unique_coord.values, self.group.attrs
160-
)
169+
unique_coord = self.group.variable
170+
full_index = pd.Index(unique_coord.data)
171+
161172
return EncodedGroups(
162173
codes=codes,
163174
group_indices=group_indices,
@@ -201,7 +212,9 @@ def factorize(self, group) -> EncodedGroups:
201212
codes = DataArray(
202213
binned_codes, getattr(group, "coords", None), name=new_dim_name
203214
)
204-
unique_coord = IndexVariable(new_dim_name, pd.Index(unique_values), group.attrs)
215+
unique_coord = Variable(
216+
dims=new_dim_name, data=pd.Index(unique_values), attrs=group.attrs
217+
)
205218
return EncodedGroups(
206219
codes=codes, full_index=full_index, unique_coord=unique_coord
207220
)
@@ -318,7 +331,9 @@ def factorize(self, group) -> EncodedGroups:
318331
]
319332
group_indices += [slice(sbins[-1], None)]
320333

321-
unique_coord = IndexVariable(group.name, first_items.index, group.attrs)
334+
unique_coord = Variable(
335+
dims=group.name, data=first_items.index, attrs=group.attrs
336+
)
322337
codes = group.copy(data=codes_)
323338

324339
return EncodedGroups(

0 commit comments

Comments
 (0)