|
5 | 5 | import pandas as pd
|
6 | 6 |
|
7 | 7 | from . import dtypes, duck_array_ops, nputils, ops
|
| 8 | +from .alignment import align |
8 | 9 | from .arithmetic import DataArrayGroupbyArithmetic, DatasetGroupbyArithmetic
|
9 | 10 | from .concat import concat
|
10 | 11 | from .formatting import format_array_flat
|
@@ -258,6 +259,7 @@ class GroupBy:
|
258 | 259 | "_stacked_dim",
|
259 | 260 | "_unique_coord",
|
260 | 261 | "_dims",
|
| 262 | + "_weighted", |
261 | 263 | )
|
262 | 264 |
|
263 | 265 | def __init__(
|
@@ -395,6 +397,7 @@ def __init__(
|
395 | 397 | self._inserted_dims = inserted_dims
|
396 | 398 | self._full_index = full_index
|
397 | 399 | self._restore_coord_dims = restore_coord_dims
|
| 400 | + self._weighted = False |
398 | 401 |
|
399 | 402 | # cached attributes
|
400 | 403 | self._groups = None
|
@@ -431,12 +434,32 @@ def __len__(self):
|
def __iter__(self):
    """Iterate over ``(label, item)`` pairs.

    Each pair combines one entry of ``self._unique_coord.values`` with the
    corresponding item produced by ``self._iter_grouped()``, in order.
    """
    labels = self._unique_coord.values
    grouped = self._iter_grouped()
    return zip(labels, grouped)
|
433 | 436 |
|
def weighted(self, weights):
    """Set weights for GroupBy reductions.

    The weights are stored as a ``"__weights__"`` coordinate on the grouped
    object so that they are sliced together with the data, and the name of
    ``weights`` is kept in that coordinate's attrs for use by ``__repr__``.

    Parameters
    ----------
    weights
        Object holding the weights. It must align exactly with the grouped
        object and expose a ``name`` attribute (presumably a DataArray --
        confirm against callers).

    Returns
    -------
    self
        This same GroupBy object, now flagged as weighted.

    Raises
    ------
    Whatever ``align(..., join="exact")`` raises when the indexes of
    ``weights`` do not match those of the grouped object.
    """
    # Validate first, before touching any state: if the exact alignment
    # fails, this GroupBy must stay a plain unweighted one. Exact alignment
    # ensures we don't add NaNs to the weights.
    align(self._obj, weights, join="exact")

    # Ensures that weights will get sliced appropriately.
    # NOTE(review): this assigns into ``self._obj`` in place; if ``_obj`` is
    # the caller's own object, the coordinate becomes visible there as well
    # -- confirm this is intended.
    self._obj.coords["__weights__"] = weights

    # save the name for repr
    self._obj.coords["__weights__"].attrs["__name__"] = weights.name

    # Flag as weighted only once the coordinate is actually in place, so that
    # ``__repr__`` and reductions never see the flag without the weights.
    self._weighted = True

    return self
| 451 | + |
def __repr__(self):
    """Return a short text summary of this GroupBy object.

    For an unweighted object the text is exactly what it was before weights
    were supported: class name, ``repr`` of the grouping coordinate's name,
    number of groups and a flat, truncated listing of the group labels.
    When weights have been set, a final ``weighted by <name>`` line is
    appended.
    """
    group_summary = ", ".join(format_array_flat(self._unique_coord, 30).split())
    summary = (
        f"{self.__class__.__name__}, grouped over {self._unique_coord.name!r}"
        f"\n{self._unique_coord.size!r} groups with labels {group_summary}."
    )
    if self._weighted:
        # Read the coordinate the same way ``weighted`` stored it, rather
        # than relying on attribute-style access with a dunder-like name.
        weights_name = self._obj.coords["__weights__"].attrs["__name__"]
        summary += f"\nweighted by {weights_name}"
    return summary
|
441 | 464 |
|
442 | 465 | def _get_index_and_items(self, index, grouper):
|
@@ -980,7 +1003,20 @@ def reduce(self, func, dim=None, keep_attrs=None, **kwargs):
|
980 | 1003 | keep_attrs = _get_keep_attrs(default=False)
|
981 | 1004 |
|
def reduce_dataset(ds):
    """Reduce a single dataset with ``func``, honouring GroupBy weights.

    Unweighted: delegates to ``ds.reduce(func, dim, keep_attrs, **kwargs)``.

    Weighted: only reductions whose function is named ``"sum"`` or
    ``"mean"`` are supported. They are dispatched to the method of the same
    name on ``ds.weighted(ds["__weights__"])``.

    Raises
    ------
    NotImplementedError
        If the object is weighted and ``func`` is not a supported reduction.
    """
    if not self._weighted:
        return ds.reduce(func, dim, keep_attrs, **kwargs)

    # yuck: the weighted method is selected by the reduction function's name.
    # Look the name up defensively, since callables such as
    # ``functools.partial`` objects have no ``__name__``.
    func_name = getattr(func, "__name__", None)
    if func_name not in ("sum", "mean"):
        label = func_name if func_name is not None else repr(func)
        raise NotImplementedError(f"Weighted {label} has not been implemented.")

    # TODO: fix this -- ``numeric_only`` is dropped before calling the
    # weighted method. Filter into a copy instead of popping so the kwargs
    # dict of the enclosing function is left untouched.
    weighted_kwargs = {
        key: value for key, value in kwargs.items() if key != "numeric_only"
    }
    weighted_ds = ds.weighted(ds["__weights__"])
    return getattr(weighted_ds, func_name)(
        dim=dim, keep_attrs=keep_attrs, **weighted_kwargs
    )
984 | 1020 |
|
985 | 1021 | check_reduce_dims(dim, self.dims)
|
986 | 1022 |
|
|
0 commit comments