Skip to content

Commit 9b54b44

Browse files
authored
Refactor groupby binary ops code. (#6789)
1 parent 392a614 commit 9b54b44

File tree

1 file changed

+33
-38
lines changed

1 file changed

+33
-38
lines changed

xarray/core/groupby.py

+33-38
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
from . import dtypes, duck_array_ops, nputils, ops
2525
from ._reductions import DataArrayGroupByReductions, DatasetGroupByReductions
26+
from .alignment import align
2627
from .arithmetic import DataArrayGroupbyArithmetic, DatasetGroupbyArithmetic
2728
from .concat import concat
2829
from .formatting import format_array_flat
@@ -309,7 +310,7 @@ class GroupBy(Generic[T_Xarray]):
309310
"_squeeze",
310311
# Save unstacked object for flox
311312
"_original_obj",
312-
"_unstacked_group",
313+
"_original_group",
313314
"_bins",
314315
)
315316
_obj: T_Xarray
@@ -374,7 +375,7 @@ def __init__(
374375
group.name = "group"
375376

376377
self._original_obj: T_Xarray = obj
377-
self._unstacked_group = group
378+
self._original_group = group
378379
self._bins = bins
379380

380381
group, obj, stacked_dim, inserted_dims = _ensure_1d(group, obj)
@@ -571,11 +572,22 @@ def _binary_op(self, other, f, reflexive=False):
571572

572573
g = f if not reflexive else lambda x, y: f(y, x)
573574

574-
obj = self._obj
575-
group = self._group
576-
dim = self._group_dim
575+
if self._bins is None:
576+
obj = self._original_obj
577+
group = self._original_group
578+
dims = group.dims
579+
else:
580+
obj = self._maybe_unstack(self._obj)
581+
group = self._maybe_unstack(self._group)
582+
dims = (self._group_dim,)
583+
577584
if isinstance(group, _DummyGroup):
578-
group = obj[dim]
585+
group = obj[group.name]
586+
coord = group
587+
else:
588+
coord = self._unique_coord
589+
if not isinstance(coord, DataArray):
590+
coord = DataArray(self._unique_coord)
579591
name = group.name
580592

581593
if not isinstance(other, (Dataset, DataArray)):
@@ -592,37 +604,19 @@ def _binary_op(self, other, f, reflexive=False):
592604
"is not a dimension on the other argument"
593605
)
594606

595-
try:
596-
expanded = other.sel({name: group})
597-
except KeyError:
598-
# some labels are absent i.e. other is not aligned
599-
# so we align by reindexing and then rename dimensions.
600-
601-
# Broadcast out scalars for backwards compatibility
602-
# TODO: get rid of this when fixing GH2145
603-
for var in other.coords:
604-
if other[var].ndim == 0:
605-
other[var] = (
606-
other[var].drop_vars(var).expand_dims({name: other.sizes[name]})
607-
)
608-
expanded = (
609-
other.reindex({name: group.data})
610-
.rename({name: dim})
611-
.assign_coords({dim: obj[dim]})
612-
)
607+
# Broadcast out scalars for backwards compatibility
608+
# TODO: get rid of this when fixing GH2145
609+
for var in other.coords:
610+
if other[var].ndim == 0:
611+
other[var] = (
612+
other[var].drop_vars(var).expand_dims({name: other.sizes[name]})
613+
)
613614

614-
if self._bins is not None and name == dim and dim not in obj.xindexes:
615-
# When binning by unindexed coordinate we need to reindex obj.
616-
# _full_index is IntervalIndex, so idx will be -1 where
617-
# a value does not belong to any bin. Using IntervalIndex
618-
# accounts for any non-default cut_kwargs passed to the constructor
619-
idx = pd.cut(group, bins=self._full_index).codes
620-
obj = obj.isel({dim: np.arange(group.size)[idx != -1]})
615+
other, _ = align(other, coord, join="outer")
616+
expanded = other.sel({name: group})
621617

622618
result = g(obj, expanded)
623619

624-
result = self._maybe_unstack(result)
625-
group = self._maybe_unstack(group)
626620
if group.ndim > 1:
627621
# backcompat:
628622
# TODO: get rid of this when fixing GH2145
@@ -632,8 +626,9 @@ def _binary_op(self, other, f, reflexive=False):
632626

633627
if isinstance(result, Dataset) and isinstance(obj, Dataset):
634628
for var in set(result):
635-
if dim not in obj[var].dims:
636-
result[var] = result[var].transpose(dim, ...)
629+
for d in dims:
630+
if d not in obj[var].dims:
631+
result[var] = result[var].transpose(d, ...)
637632
return result
638633

639634
def _maybe_restore_empty_groups(self, combined):
@@ -695,10 +690,10 @@ def _flox_reduce(self, dim, keep_attrs=None, **kwargs):
695690
# group is only passed by resample
696691
group = kwargs.pop("group", None)
697692
if group is None:
698-
if isinstance(self._unstacked_group, _DummyGroup):
699-
group = self._unstacked_group.name
693+
if isinstance(self._original_group, _DummyGroup):
694+
group = self._original_group.name
700695
else:
701-
group = self._unstacked_group
696+
group = self._original_group
702697

703698
unindexed_dims = tuple()
704699
if isinstance(group, str):

0 commit comments

Comments
 (0)