23
23
24
24
from . import dtypes , duck_array_ops , nputils , ops
25
25
from ._reductions import DataArrayGroupByReductions , DatasetGroupByReductions
26
+ from .alignment import align
26
27
from .arithmetic import DataArrayGroupbyArithmetic , DatasetGroupbyArithmetic
27
28
from .concat import concat
28
29
from .formatting import format_array_flat
@@ -309,7 +310,7 @@ class GroupBy(Generic[T_Xarray]):
309
310
"_squeeze" ,
310
311
# Save unstacked object for flox
311
312
"_original_obj" ,
312
- "_unstacked_group " ,
313
+ "_original_group " ,
313
314
"_bins" ,
314
315
)
315
316
_obj : T_Xarray
@@ -374,7 +375,7 @@ def __init__(
374
375
group .name = "group"
375
376
376
377
self ._original_obj : T_Xarray = obj
377
- self ._unstacked_group = group
378
+ self ._original_group = group
378
379
self ._bins = bins
379
380
380
381
group , obj , stacked_dim , inserted_dims = _ensure_1d (group , obj )
@@ -571,11 +572,22 @@ def _binary_op(self, other, f, reflexive=False):
571
572
572
573
g = f if not reflexive else lambda x , y : f (y , x )
573
574
574
- obj = self ._obj
575
- group = self ._group
576
- dim = self ._group_dim
575
+ if self ._bins is None :
576
+ obj = self ._original_obj
577
+ group = self ._original_group
578
+ dims = group .dims
579
+ else :
580
+ obj = self ._maybe_unstack (self ._obj )
581
+ group = self ._maybe_unstack (self ._group )
582
+ dims = (self ._group_dim ,)
583
+
577
584
if isinstance (group , _DummyGroup ):
578
- group = obj [dim ]
585
+ group = obj [group .name ]
586
+ coord = group
587
+ else :
588
+ coord = self ._unique_coord
589
+ if not isinstance (coord , DataArray ):
590
+ coord = DataArray (self ._unique_coord )
579
591
name = group .name
580
592
581
593
if not isinstance (other , (Dataset , DataArray )):
@@ -592,37 +604,19 @@ def _binary_op(self, other, f, reflexive=False):
592
604
"is not a dimension on the other argument"
593
605
)
594
606
595
- try :
596
- expanded = other .sel ({name : group })
597
- except KeyError :
598
- # some labels are absent i.e. other is not aligned
599
- # so we align by reindexing and then rename dimensions.
600
-
601
- # Broadcast out scalars for backwards compatibility
602
- # TODO: get rid of this when fixing GH2145
603
- for var in other .coords :
604
- if other [var ].ndim == 0 :
605
- other [var ] = (
606
- other [var ].drop_vars (var ).expand_dims ({name : other .sizes [name ]})
607
- )
608
- expanded = (
609
- other .reindex ({name : group .data })
610
- .rename ({name : dim })
611
- .assign_coords ({dim : obj [dim ]})
612
- )
607
+ # Broadcast out scalars for backwards compatibility
608
+ # TODO: get rid of this when fixing GH2145
609
+ for var in other .coords :
610
+ if other [var ].ndim == 0 :
611
+ other [var ] = (
612
+ other [var ].drop_vars (var ).expand_dims ({name : other .sizes [name ]})
613
+ )
613
614
614
- if self ._bins is not None and name == dim and dim not in obj .xindexes :
615
- # When binning by unindexed coordinate we need to reindex obj.
616
- # _full_index is IntervalIndex, so idx will be -1 where
617
- # a value does not belong to any bin. Using IntervalIndex
618
- # accounts for any non-default cut_kwargs passed to the constructor
619
- idx = pd .cut (group , bins = self ._full_index ).codes
620
- obj = obj .isel ({dim : np .arange (group .size )[idx != - 1 ]})
615
+ other , _ = align (other , coord , join = "outer" )
616
+ expanded = other .sel ({name : group })
621
617
622
618
result = g (obj , expanded )
623
619
624
- result = self ._maybe_unstack (result )
625
- group = self ._maybe_unstack (group )
626
620
if group .ndim > 1 :
627
621
# backcompat:
628
622
# TODO: get rid of this when fixing GH2145
@@ -632,8 +626,9 @@ def _binary_op(self, other, f, reflexive=False):
632
626
633
627
if isinstance (result , Dataset ) and isinstance (obj , Dataset ):
634
628
for var in set (result ):
635
- if dim not in obj [var ].dims :
636
- result [var ] = result [var ].transpose (dim , ...)
629
+ for d in dims :
630
+ if d not in obj [var ].dims :
631
+ result [var ] = result [var ].transpose (d , ...)
637
632
return result
638
633
639
634
def _maybe_restore_empty_groups (self , combined ):
@@ -695,10 +690,10 @@ def _flox_reduce(self, dim, keep_attrs=None, **kwargs):
695
690
# group is only passed by resample
696
691
group = kwargs .pop ("group" , None )
697
692
if group is None :
698
- if isinstance (self ._unstacked_group , _DummyGroup ):
699
- group = self ._unstacked_group .name
693
+ if isinstance (self ._original_group , _DummyGroup ):
694
+ group = self ._original_group .name
700
695
else :
701
- group = self ._unstacked_group
696
+ group = self ._original_group
702
697
703
698
unindexed_dims = tuple ()
704
699
if isinstance (group , str ):
0 commit comments