19
19
from xarray .core .arithmetic import DataArrayGroupbyArithmetic , DatasetGroupbyArithmetic
20
20
from xarray .core .common import ImplementsArrayReduce , ImplementsDatasetReduce
21
21
from xarray .core .concat import concat
22
+ from xarray .core .coordinates import Coordinates
22
23
from xarray .core .formatting import format_array_flat
23
24
from xarray .core .indexes import (
25
+ PandasIndex ,
24
26
create_default_index_implicit ,
25
27
filter_indexes_from_coords ,
26
28
)
@@ -246,7 +248,7 @@ def to_array(self) -> DataArray:
246
248
return self .to_dataarray ()
247
249
248
250
249
- T_Group = Union ["T_DataArray" , "IndexVariable" , _DummyGroup ]
251
+ T_Group = Union ["T_DataArray" , _DummyGroup ]
250
252
251
253
252
254
def _ensure_1d (group : T_Group , obj : T_DataWithCoords ) -> tuple [
@@ -256,7 +258,7 @@ def _ensure_1d(group: T_Group, obj: T_DataWithCoords) -> tuple[
256
258
list [Hashable ],
257
259
]:
258
260
# 1D cases: do nothing
259
- if isinstance (group , ( IndexVariable , _DummyGroup ) ) or group .ndim == 1 :
261
+ if isinstance (group , _DummyGroup ) or group .ndim == 1 :
260
262
return group , obj , None , []
261
263
262
264
from xarray .core .dataarray import DataArray
@@ -271,9 +273,7 @@ def _ensure_1d(group: T_Group, obj: T_DataWithCoords) -> tuple[
271
273
newobj = obj .stack ({stacked_dim : orig_dims })
272
274
return newgroup , newobj , stacked_dim , inserted_dims
273
275
274
- raise TypeError (
275
- f"group must be DataArray, IndexVariable or _DummyGroup, got { type (group )!r} ."
276
- )
276
+ raise TypeError (f"group must be DataArray or _DummyGroup, got { type (group )!r} ." )
277
277
278
278
279
279
@dataclass
@@ -299,7 +299,7 @@ class ResolvedGrouper(Generic[T_DataWithCoords]):
299
299
codes : DataArray = field (init = False )
300
300
full_index : pd .Index = field (init = False )
301
301
group_indices : T_GroupIndices = field (init = False )
302
- unique_coord : IndexVariable | _DummyGroup = field (init = False )
302
+ unique_coord : Variable | _DummyGroup = field (init = False )
303
303
304
304
# _ensure_1d:
305
305
group1d : T_Group = field (init = False )
@@ -315,7 +315,7 @@ def __post_init__(self) -> None:
315
315
# might be used multiple times.
316
316
self .grouper = copy .deepcopy (self .grouper )
317
317
318
- self .group : T_Group = _resolve_group (self .obj , self .group )
318
+ self .group = _resolve_group (self .obj , self .group )
319
319
320
320
(
321
321
self .group1d ,
@@ -328,14 +328,18 @@ def __post_init__(self) -> None:
328
328
329
329
@property
330
330
def name (self ) -> Hashable :
331
+ """Name for the grouped coordinate after reduction."""
331
332
# The name has to come from unique_coord (not from the group itself) because
# BinGrouper needs the `_bins` suffix.
332
- return self .unique_coord .name
333
# The added lines below read the name from unique_coord's dims instead of its
# `.name` attribute. The single-element unpacking raises ValueError unless
# `dims` holds exactly one entry.
+ (name ,) = self .unique_coord .dims
334
+ return name
333
335
334
336
@property
335
337
def size (self ) -> int :
338
+ """Number of groups."""
336
339
# Delegates to this object's own __len__ so the two always agree.
return len (self )
337
340
338
341
def __len__ (self ) -> int :
342
+ """Number of groups."""
339
343
# The count is the length of full_index.
# NOTE(review): factorize builds unique_coord from
# full_index[np.unique(codes)], so full_index presumably can hold more
# entries than the groups actually present — confirm this is intended.
return len (self .full_index )
340
344
341
345
@property
@@ -358,8 +362,8 @@ def factorize(self) -> None:
358
362
]
359
363
if encoded .unique_coord is None :
360
364
unique_values = self .full_index [np .unique (encoded .codes )]
361
- self .unique_coord = IndexVariable (
362
- self .codes .name , unique_values , attrs = self .group .attrs
365
+ self .unique_coord = Variable (
366
+ dims = self .codes .name , data = unique_values , attrs = self .group .attrs
363
367
)
364
368
else :
365
369
self .unique_coord = encoded .unique_coord
@@ -378,7 +382,9 @@ def _validate_groupby_squeeze(squeeze: bool | None) -> None:
378
382
)
379
383
380
384
381
- def _resolve_group (obj : T_DataWithCoords , group : T_Group | Hashable ) -> T_Group :
385
+ def _resolve_group (
386
+ obj : T_DataWithCoords , group : T_Group | Hashable | IndexVariable
387
+ ) -> T_Group :
382
388
from xarray .core .dataarray import DataArray
383
389
384
390
error_msg = (
@@ -620,6 +626,8 @@ def _iter_grouped(self, warn_squeeze=True) -> Iterator[T_Xarray]:
620
626
yield self ._obj .isel ({self ._group_dim : indices })
621
627
622
628
def _infer_concat_args (self , applied_example ):
629
+ from xarray .core .groupers import BinGrouper
630
+
623
631
(grouper ,) = self .groupers
624
632
if self ._group_dim in applied_example .dims :
625
633
coord = grouper .group1d
@@ -628,7 +636,10 @@ def _infer_concat_args(self, applied_example):
628
636
coord = grouper .unique_coord
629
637
positions = None
630
638
(dim ,) = coord .dims
631
- if isinstance (coord , _DummyGroup ):
639
+ if isinstance (grouper .group , _DummyGroup ) and not isinstance (
640
+ grouper .grouper , BinGrouper
641
+ ):
642
+ # When binning we actually do set the index
632
643
coord = None
633
644
coord = getattr (coord , "variable" , coord )
634
645
return coord , dim , positions
@@ -641,6 +652,7 @@ def _binary_op(self, other, f, reflexive=False):
641
652
642
653
(grouper ,) = self .groupers
643
654
obj = self ._original_obj
655
+ name = grouper .name
644
656
group = grouper .group
645
657
codes = self ._codes
646
658
dims = group .dims
@@ -649,9 +661,11 @@ def _binary_op(self, other, f, reflexive=False):
649
661
group = coord = group .to_dataarray ()
650
662
else :
651
663
coord = grouper .unique_coord
652
- if not isinstance (coord , DataArray ):
653
- coord = DataArray (grouper .unique_coord )
654
- name = grouper .name
664
+ if isinstance (coord , Variable ):
665
+ assert coord .ndim == 1
666
+ (coord_dim ,) = coord .dims
667
+ # TODO: explicitly create Index here
668
+ coord = DataArray (coord , coords = {coord_dim : coord .data })
655
669
656
670
if not isinstance (other , (Dataset , DataArray )):
657
671
raise TypeError (
@@ -766,6 +780,7 @@ def _flox_reduce(
766
780
767
781
obj = self ._original_obj
768
782
(grouper ,) = self .groupers
783
+ name = grouper .name
769
784
isbin = isinstance (grouper .grouper , BinGrouper )
770
785
771
786
if keep_attrs is None :
@@ -797,14 +812,14 @@ def _flox_reduce(
797
812
# weird backcompat
798
813
# reducing along a unique indexed dimension with squeeze=True
799
814
# should raise an error
800
- if (dim is None or dim == grouper . name ) and grouper . name in obj .xindexes :
801
- index = obj .indexes [grouper . name ]
815
+ if (dim is None or dim == name ) and name in obj .xindexes :
816
+ index = obj .indexes [name ]
802
817
if index .is_unique and self ._squeeze :
803
- raise ValueError (f"cannot reduce over dimensions { grouper . name !r} " )
818
+ raise ValueError (f"cannot reduce over dimensions { name !r} " )
804
819
805
820
unindexed_dims : tuple [Hashable , ...] = tuple ()
806
821
if isinstance (grouper .group , _DummyGroup ) and not isbin :
807
- unindexed_dims = (grouper . name ,)
822
+ unindexed_dims = (name ,)
808
823
809
824
parsed_dim : tuple [Hashable , ...]
810
825
if isinstance (dim , str ):
@@ -848,15 +863,19 @@ def _flox_reduce(
848
863
# in the grouped variable
849
864
group_dims = grouper .group .dims
850
865
if set (group_dims ).issubset (set (parsed_dim )):
851
- result [grouper .name ] = output_index
866
+ result = result .assign_coords (
867
+ Coordinates (
868
+ coords = {name : (name , np .array (output_index ))},
869
+ indexes = {name : PandasIndex (output_index , dim = name )},
870
+ )
871
+ )
852
872
result = result .drop_vars (unindexed_dims )
853
873
854
874
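# broadcast and restore non-numeric data variables (backcompat)
# NOTE(review): the loop variable `name` below rebinds the `name = grouper.name`
# local assigned earlier in this function. The added `(name ,) + var .dims` /
# `result .sizes [name ]` line therefore uses the data-variable name where the
# removed lines used `grouper .name` — confirm this is intended; it looks like
# the grouper name should be kept under a separate local.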
855
875
for name , var in non_numeric .items ():
856
876
if all (d not in var .dims for d in parsed_dim ):
857
877
result [name ] = var .variable .set_dims (
858
- (grouper .name ,) + var .dims ,
859
- (result .sizes [grouper .name ],) + var .shape ,
878
+ (name ,) + var .dims , (result .sizes [name ],) + var .shape
860
879
)
861
880
862
881
if not isinstance (result , Dataset ):
0 commit comments