Skip to content

Commit f6c41ae

Browse files
committed
Implement groupby.weighted
1 parent fe87162 commit f6c41ae

File tree

1 file changed

+42
-6
lines changed

1 file changed

+42
-6
lines changed

xarray/core/groupby.py

Lines changed: 42 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import pandas as pd
66

77
from . import dtypes, duck_array_ops, nputils, ops
8+
from .alignment import align
89
from .arithmetic import DataArrayGroupbyArithmetic, DatasetGroupbyArithmetic
910
from .concat import concat
1011
from .formatting import format_array_flat
@@ -258,6 +259,7 @@ class GroupBy:
258259
"_stacked_dim",
259260
"_unique_coord",
260261
"_dims",
262+
"_weighted",
261263
)
262264

263265
def __init__(
@@ -395,6 +397,7 @@ def __init__(
395397
self._inserted_dims = inserted_dims
396398
self._full_index = full_index
397399
self._restore_coord_dims = restore_coord_dims
400+
self._weighted = False
398401

399402
# cached attributes
400403
self._groups = None
@@ -431,12 +434,32 @@ def __len__(self):
431434
def __iter__(self):
432435
return zip(self._unique_coord.values, self._iter_grouped())
433436

437+
def weighted(self, weights):
438+
"""Set weights for GroupBy reductions."""
439+
self._weighted = True
440+
441+
# exact alignment is used purely as a check (the aligned result is discarded) so that we don't add NaNs to the weights
442+
align(self._obj, weights, join="exact")
443+
444+
# Ensures that weights will get sliced appropriately
445+
self._obj.coords["__weights__"] = weights
446+
447+
# save the name for repr
448+
self._obj.coords["__weights__"].attrs["__name__"] = weights.name
449+
450+
return self
451+
434452
def __repr__(self):
435-
return "{}, grouped over {!r}\n{!r} groups with labels {}.".format(
436-
self.__class__.__name__,
437-
self._unique_coord.name,
438-
self._unique_coord.size,
439-
", ".join(format_array_flat(self._unique_coord, 30).split()),
453+
if self._weighted:
454+
weights_name = self._obj.__weights__.attrs["__name__"]
455+
weights_summary = f"\nweighted by {weights_name}"
456+
else:
457+
weights_summary = "unweighted"
458+
group_summary = ", ".join(format_array_flat(self._unique_coord, 30).split())
459+
return (
460+
f"{self.__class__.__name__}, grouped over {self._unique_coord.name},"
461+
f"\n{self._unique_coord.size} groups with labels {group_summary}."
462+
f"{weights_summary}"
440463
)
441464

442465
def _get_index_and_items(self, index, grouper):
@@ -980,7 +1003,20 @@ def reduce(self, func, dim=None, keep_attrs=None, **kwargs):
9801003
keep_attrs = _get_keep_attrs(default=False)
9811004

9821005
def reduce_dataset(ds):
983-
return ds.reduce(func, dim, keep_attrs, **kwargs)
1006+
if self._weighted:
1007+
# HACK: dispatches on the reduction function's name; only "sum" and "mean" are handled
1008+
if func.__name__ in ["sum", "mean"]:
1009+
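# TODO: fix this -- "numeric_only" is dropped before calling the weighted reduction (presumably it is not accepted there; confirm)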
1010+
kwargs.pop("numeric_only", None)
1011+
return getattr(ds.weighted(ds["__weights__"]), func.__name__)(
1012+
dim=dim, keep_attrs=keep_attrs, **kwargs
1013+
)
1014+
else:
1015+
raise NotImplementedError(
1016+
f"Weighted {func} has not been implemented."
1017+
)
1018+
else:
1019+
return ds.reduce(func, dim, keep_attrs, **kwargs)
9841020

9851021
check_reduce_dims(dim, self.dims)
9861022

0 commit comments

Comments
 (0)