Skip to content

Commit 8514a9e

Browse files
committed
Remove internal use of IndexVariable
1 parent 62595fd commit 8514a9e

File tree

2 files changed

+55
-20
lines changed

2 files changed

+55
-20
lines changed

xarray/core/groupby.py

+29-9
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,7 @@ class ResolvedGrouper(Generic[T_DataWithCoords]):
324324
codes: DataArray = field(init=False)
325325
full_index: pd.Index = field(init=False)
326326
group_indices: T_GroupIndices = field(init=False)
327-
unique_coord: IndexVariable | _DummyGroup = field(init=False)
327+
unique_coord: Variable | _DummyGroup = field(init=False)
328328

329329
# _ensure_1d:
330330
group1d: T_Group = field(init=False)
@@ -353,14 +353,18 @@ def __post_init__(self) -> None:
353353

354354
@property
355355
def name(self) -> Hashable:
356+
"""Name for the grouped coordinate after reduction."""
356357
# the name has to come from unique_coord because we need `_bins` suffix for BinGrouper
357-
return self.unique_coord.name
358+
(name,) = self.unique_coord.dims
359+
return name
358360

359361
@property
360362
def size(self) -> int:
363+
"""Number of groups."""
361364
return len(self)
362365

363366
def __len__(self) -> int:
367+
"""Number of groups."""
364368
return len(self.full_index)
365369

366370
@property
@@ -383,8 +387,8 @@ def factorize(self) -> None:
383387
]
384388
if encoded.unique_coord is None:
385389
unique_values = self.full_index[np.unique(encoded.codes)]
386-
self.unique_coord = IndexVariable(
387-
self.codes.name, unique_values, attrs=self.group.attrs
390+
self.unique_coord = Variable(
391+
dims=self.codes.name, data=unique_values, attrs=self.group.attrs
388392
)
389393
else:
390394
self.unique_coord = encoded.unique_coord
@@ -645,6 +649,8 @@ def _iter_grouped(self, warn_squeeze=True) -> Iterator[T_Xarray]:
645649
yield self._obj.isel({self._group_dim: indices})
646650

647651
def _infer_concat_args(self, applied_example):
652+
from xarray.core.groupers import BinGrouper
653+
648654
(grouper,) = self.groupers
649655
if self._group_dim in applied_example.dims:
650656
coord = grouper.group1d
@@ -653,7 +659,10 @@ def _infer_concat_args(self, applied_example):
653659
coord = grouper.unique_coord
654660
positions = None
655661
(dim,) = coord.dims
656-
if isinstance(coord, _DummyGroup):
662+
if isinstance(grouper.group, _DummyGroup) and not isinstance(
663+
grouper.grouper, BinGrouper
664+
):
665+
# When binning we actually do set the index
657666
coord = None
658667
coord = getattr(coord, "variable", coord)
659668
return coord, dim, positions
@@ -670,12 +679,20 @@ def _binary_op(self, other, f, reflexive=False):
670679
codes = self._codes
671680
dims = group.dims
672681

673-
if isinstance(group, _DummyGroup):
682+
if isinstance(grouper.group, _DummyGroup):
674683
group = coord = group.to_dataarray()
675684
else:
676685
coord = grouper.unique_coord
677-
if not isinstance(coord, DataArray):
678-
coord = DataArray(grouper.unique_coord)
686+
if isinstance(coord, Variable):
687+
assert coord.ndim == 1
688+
(coord_dim,) = coord.dims
689+
# TODO: explicitly create Index here
690+
coord = DataArray(
691+
dims=coord_dim,
692+
data=coord.data,
693+
attrs=coord.attrs,
694+
coords={coord_dim: coord.data},
695+
)
679696
name = grouper.name
680697

681698
if not isinstance(other, (Dataset, DataArray)):
@@ -873,7 +890,10 @@ def _flox_reduce(
873890
# in the grouped variable
874891
group_dims = grouper.group.dims
875892
if set(group_dims).issubset(set(parsed_dim)):
876-
result[grouper.name] = output_index
893+
new_coord = Variable(
894+
dims=grouper.name, data=np.array(output_index), attrs=self._codes.attrs
895+
)
896+
result = result.assign_coords({grouper.name: new_coord})
877897
result = result.drop_vars(unindexed_dims)
878898

879899
# broadcast and restore non-numeric data variables (backcompat)

xarray/core/groupers.py

+26-11
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from xarray.core.resample_cftime import CFTimeGrouper
2424
from xarray.core.types import DatetimeLike, SideOptions, T_GroupIndices
2525
from xarray.core.utils import emit_user_level_warning
26-
from xarray.core.variable import IndexVariable
26+
from xarray.core.variable import Variable
2727

2828
__all__ = [
2929
"EncodedGroups",
@@ -55,7 +55,17 @@ class EncodedGroups:
5555
codes: DataArray
5656
full_index: pd.Index
5757
group_indices: T_GroupIndices | None = field(default=None)
58-
unique_coord: IndexVariable | _DummyGroup | None = field(default=None)
58+
unique_coord: Variable | _DummyGroup | None = field(default=None)
59+
60+
def __post_init__(self):
61+
assert isinstance(self.codes, DataArray)
62+
if self.codes.name is None:
63+
raise ValueError("Please set a name on the array you are grouping by.")
64+
assert isinstance(self.full_index, pd.Index)
65+
assert (
66+
isinstance(self.unique_coord, (Variable, _DummyGroup))
67+
or self.unique_coord is None
68+
)
5969

6070

6171
class Grouper(ABC):
@@ -134,10 +144,10 @@ def _factorize_unique(self) -> EncodedGroups:
134144
"Failed to group data. Are you grouping by a variable that is all NaN?"
135145
)
136146
codes = self.group.copy(data=codes_)
137-
unique_coord = IndexVariable(
138-
self.group.name, unique_values, attrs=self.group.attrs
147+
unique_coord = Variable(
148+
dims=codes.name, data=unique_values, attrs=self.group.attrs
139149
)
140-
full_index = unique_coord
150+
full_index = pd.Index(unique_values)
141151

142152
return EncodedGroups(
143153
codes=codes, full_index=full_index, unique_coord=unique_coord
@@ -152,12 +162,13 @@ def _factorize_dummy(self) -> EncodedGroups:
152162
size_range = np.arange(size)
153163
if isinstance(self.group, _DummyGroup):
154164
codes = self.group.to_dataarray().copy(data=size_range)
165+
unique_coord = self.group
166+
full_index = pd.RangeIndex(self.group.size)
155167
else:
156168
codes = self.group.copy(data=size_range)
157-
unique_coord = self.group
158-
full_index = IndexVariable(
159-
self.group.name, unique_coord.values, self.group.attrs
160-
)
169+
unique_coord = self.group.variable
170+
full_index = pd.Index(unique_coord.data)
171+
161172
return EncodedGroups(
162173
codes=codes,
163174
group_indices=group_indices,
@@ -201,7 +212,9 @@ def factorize(self, group) -> EncodedGroups:
201212
codes = DataArray(
202213
binned_codes, getattr(group, "coords", None), name=new_dim_name
203214
)
204-
unique_coord = IndexVariable(new_dim_name, pd.Index(unique_values), group.attrs)
215+
unique_coord = Variable(
216+
dims=new_dim_name, data=pd.Index(unique_values), attrs=group.attrs
217+
)
205218
return EncodedGroups(
206219
codes=codes, full_index=full_index, unique_coord=unique_coord
207220
)
@@ -318,7 +331,9 @@ def factorize(self, group) -> EncodedGroups:
318331
]
319332
group_indices += [slice(sbins[-1], None)]
320333

321-
unique_coord = IndexVariable(group.name, first_items.index, group.attrs)
334+
unique_coord = Variable(
335+
dims=group.name, data=first_items.index, attrs=group.attrs
336+
)
322337
codes = group.copy(data=codes_)
323338

324339
return EncodedGroups(

0 commit comments

Comments
 (0)