diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 1ca5de965d0..6bb0d59428d 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -877,7 +877,22 @@ def reduce_array(ar): return self.map(reduce_array, shortcut=shortcut) -class DatasetGroupBy(GroupBy, DatasetGroupbyArithmetic): +class DatasetGroupByCombine(GroupBy): + def _combine(self, applied): + """Recombine the applied objects like the original.""" + applied_example, applied = peek_at(applied) + coord, dim, positions = self._infer_concat_args(applied_example) + combined = concat(applied, dim) + combined = _maybe_reorder(combined, dim, positions) + # assign coord when the applied function does not return that coord + if coord is not None and dim not in applied_example.dims: + combined[coord.name] = coord + combined = self._maybe_restore_empty_groups(combined) + combined = self._maybe_unstack(combined) + return combined + + +class DatasetGroupBy(DatasetGroupByCombine, DatasetGroupbyArithmetic): __slots__ = () @@ -931,19 +946,6 @@ def apply(self, func, args=(), shortcut=None, **kwargs): ) return self.map(func, shortcut=shortcut, args=args, **kwargs) - def _combine(self, applied): - """Recombine the applied objects like the original.""" - applied_example, applied = peek_at(applied) - coord, dim, positions = self._infer_concat_args(applied_example) - combined = concat(applied, dim) - combined = _maybe_reorder(combined, dim, positions) - # assign coord when the applied function does not return that coord - if coord is not None and dim not in applied_example.dims: - combined[coord.name] = coord - combined = self._maybe_restore_empty_groups(combined) - combined = self._maybe_unstack(combined) - return combined - def reduce(self, func, dim=None, keep_attrs=None, **kwargs): """Reduce the items in this group by applying `func` along some dimension(s). @@ -994,3 +996,53 @@ def assign(self, **kwargs): Dataset.assign """ return self.map(lambda ds: ds.assign(**kwargs)) + + +class WeightedDatasetGroupBy(DatasetGroupByCombine): + def __init__( + self, + obj, + weights, + group, + squeeze=False, + restore_coord_dims=None, + ): + # Only used for repr + + weights_name = weights.name + if weights_name: + name_str = f"by {weights_name!r}" + else: + name_str = "" + weight_dims = ", ".join(weights.dims) + + self.weights_repr = f"\nweighted along dimensions: {weight_dims} " f"{name_str}" + super().__init__( + # assigning as coords means weights get sliced out like + # groups + obj=obj.assign_coords({"__weights__": weights}), + group=group, + squeeze=squeeze, + restore_coord_dims=restore_coord_dims, + ) + + def __repr__(self): + return f"{super().__repr__()}" + self.weights_repr + + def _reduce(self, func, dim=None, skipna=None, keep_attrs=None): + if dim is None: + dim = self._group_dim + + applied = ( + getattr(ds.weighted(ds.__weights__), func)( + dim=dim, skipna=skipna, keep_attrs=keep_attrs + ) + for ds in super()._iter_grouped() + ) + return super()._combine(applied) + + def mean(self): + return self._reduce("mean") + + def sum(self): + return self._reduce("sum") diff --git a/xarray/core/weighted.py b/xarray/core/weighted.py index e8838b07157..a1d4ecfac4f 100644 --- a/xarray/core/weighted.py +++ b/xarray/core/weighted.py @@ -263,6 +263,25 @@ def _implementation(self, func, dim, **kwargs) -> "Dataset": return self.obj.map(func, dim=dim, **kwargs) + def groupby( + self, + group, + squeeze=False, + grouper=None, + bins=None, + restore_coord_dims=None, + cut_kwargs=None, + ): + from .groupby import WeightedDatasetGroupBy + + return WeightedDatasetGroupBy( + weights=self.weights, + obj=self.obj, + group=group, + squeeze=squeeze, + restore_coord_dims=restore_coord_dims, + ) + def _inject_docstring(cls, cls_name):