Skip to content

Weighted quantile #6059

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 52 commits into from
Mar 27, 2022
Merged
Changes from all commits
Commits
Show all changes
52 commits
Select commit Hold shift + click to select a range
dcd1b24
Add weighted quantile
cjauvin Dec 10, 2021
875f766
Add weighted quantile to documentation
cjauvin Dec 10, 2021
217d2f0
Apply suggestions from code review
cjauvin Dec 10, 2021
278bed9
Apply suggestions from code review
cjauvin Dec 10, 2021
80ae229
Improve _weighted_quantile_type7_1d ufunc with suggestions
cjauvin Dec 11, 2021
77bb84e
Merge remote-tracking branch 'origin/main' into weighted-quantile
cjauvin Dec 11, 2021
58af567
Expand scope of invalid q value test
cjauvin Dec 11, 2021
15e3834
Fix weighted quantile with zero weights
cjauvin Dec 13, 2021
83e4210
Replace np.ones by xr.ones_like in weighted quantile test
cjauvin Dec 13, 2021
2237399
Process weighted quantile data with all nans
cjauvin Dec 13, 2021
b936e21
Fix operator precedence bug
cjauvin Dec 13, 2021
c94fa16
Merge branch 'main' into pr/6059
Illviljan Dec 29, 2021
ab810d7
Used effective sample size. Generalize to different quantile types su…
huard Jan 11, 2022
7bcf09e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 11, 2022
8427637
Apply suggestions from code review
huard Jan 12, 2022
3217962
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 12, 2022
7379d22
added missing Typing hints
huard Jan 12, 2022
c8871d1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 12, 2022
b26a5fc
update what's new and pep8 fixes
huard Jan 12, 2022
42ebcc2
merge
huard Jan 12, 2022
5aa22a4
Merge branch 'main' into weighted-quantile
huard Jan 12, 2022
abe253e
add docstring paragraph discussing weight interpretation
huard Jan 19, 2022
82147aa
recognize numpy names for quantile interpolation methods
huard Jan 19, 2022
784cedd
tweak to avoid warning with all nans data. simplify test
huard Jan 19, 2022
3ee62fd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 19, 2022
4d6a4fd
remove integers from quantile interpolation available methods
huard Jan 19, 2022
9f93f55
merge
huard Jan 19, 2022
8132320
Merge remote-tracking branch 'upstream/main' into weighted-quantile
huard Jan 19, 2022
4d7f5f5
remove merge artifacts
Illviljan Jan 19, 2022
c268ddd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 19, 2022
db706aa
[skip-ci] fix bad merge in whats-new
dcherian Jan 20, 2022
33ee96c
Add references
huard Jan 20, 2022
2ffd3d3
renamed htype argument to method in private functions
huard Feb 8, 2022
9806db8
Merge branch 'main' into weighted-quantile
huard Feb 8, 2022
15ee999
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 8, 2022
9559c87
Update xarray/core/weighted.py
huard Feb 9, 2022
73dde79
Add skipped test to verify equal weights quantile with methods
cjauvin Feb 17, 2022
59714af
Apply suggestions from code review
huard Feb 17, 2022
585b705
Update xarray/core/weighted.py
huard Feb 17, 2022
7443b82
modifications suggested by review: comments, remove align, clarify te…
huard Feb 17, 2022
42a6a49
adjust typing. resolve conflicts
huard Feb 17, 2022
2e0c16e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 17, 2022
4186a24
Apply suggestions from code review
huard Feb 21, 2022
1be1f92
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 21, 2022
5c251e0
use broadcast
mathause Feb 25, 2022
9060f8e
Merge branch 'main' into weighted-quantile
mathause Mar 4, 2022
c112d2c
move whatsnew entry
mathause Mar 4, 2022
d4ba8ee
Apply suggestions from code review
mathause Mar 4, 2022
fd0a54e
switch skipna determination
mathause Mar 4, 2022
8bd83f9
Merge branch 'weighted-quantile' of https://github.com/cjauvin/xarray…
mathause Mar 4, 2022
343b47e
use align and broadcast
mathause Mar 4, 2022
c298bd0
Merge branch 'main' into pr/6059
Illviljan Mar 27, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
@@ -944,6 +944,7 @@ Dataset

DatasetWeighted
DatasetWeighted.mean
DatasetWeighted.quantile
DatasetWeighted.sum
DatasetWeighted.std
DatasetWeighted.var
@@ -958,6 +959,7 @@ DataArray

DataArrayWeighted
DataArrayWeighted.mean
DataArrayWeighted.quantile
DataArrayWeighted.sum
DataArrayWeighted.std
DataArrayWeighted.var
8 changes: 7 additions & 1 deletion doc/user-guide/computation.rst
Original file line number Diff line number Diff line change
@@ -265,7 +265,7 @@ Weighted array reductions

:py:class:`DataArray` and :py:class:`Dataset` objects include :py:meth:`DataArray.weighted`
and :py:meth:`Dataset.weighted` array reduction methods. They currently
support weighted ``sum``, ``mean``, ``std`` and ``var``.
support weighted ``sum``, ``mean``, ``std``, ``var`` and ``quantile``.

.. ipython:: python

@@ -293,6 +293,12 @@ Calculate the weighted mean:

weighted_prec.mean(dim="month")

Calculate the weighted quantile:

.. ipython:: python

weighted_prec.quantile(q=0.5, dim="month")

The weighted sum corresponds to:

.. ipython:: python
3 changes: 3 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
@@ -22,6 +22,9 @@ v2022.03.1 (unreleased)
New Features
~~~~~~~~~~~~

- Add a weighted ``quantile`` method to :py:class:`~core.weighted.DatasetWeighted` and
:py:class:`~core.weighted.DataArrayWeighted` (:pull:`6059`). By
`Christian Jauvin <https://github.com/cjauvin>`_ and `David Huard <https://github.com/huard>`_.
- Add a ``create_index=True`` parameter to :py:meth:`Dataset.stack` and
:py:meth:`DataArray.stack` so that the creation of multi-indexes is optional
(:pull:`5692`). By `Benoît Bovy <https://github.com/benbovy>`_.
223 changes: 220 additions & 3 deletions xarray/core/weighted.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,26 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Generic, Hashable, Iterable, cast
from typing import TYPE_CHECKING, Generic, Hashable, Iterable, Literal, Sequence, cast

import numpy as np

from . import duck_array_ops
from .computation import dot
from . import duck_array_ops, utils
from .alignment import align, broadcast
from .computation import apply_ufunc, dot
from .npcompat import ArrayLike
from .pycompat import is_duck_dask_array
from .types import T_Xarray

# Weighted quantile methods are a subset of the numpy supported quantile methods.
QUANTILE_METHODS = Literal[
"linear",
"interpolated_inverted_cdf",
"hazen",
"weibull",
"median_unbiased",
"normal_unbiased",
]

_WEIGHTED_REDUCE_DOCSTRING_TEMPLATE = """
Reduce this {cls}'s data by a weighted ``{fcn}`` along some dimension(s).

@@ -56,6 +68,61 @@
New {cls} object with the sum of the weights over the given dimension.
"""

_WEIGHTED_QUANTILE_DOCSTRING_TEMPLATE = """
Apply a weighted ``quantile`` to this {cls}'s data along some dimension(s).

Weights are interpreted as *sampling weights* (or probability weights) and
describe how a sample is scaled to the whole population [1]_. There are
other possible interpretations for weights, *precision weights* describing the
precision of observations, or *frequency weights* counting the number of identical
observations, however, they are not implemented here.

For compatibility with NumPy's non-weighted ``quantile`` (which is used by
``DataArray.quantile`` and ``Dataset.quantile``), the only interpolation
method supported by this weighted version corresponds to the default "linear"
option of ``numpy.quantile``. This is "Type 7" option, described in Hyndman
and Fan (1996) [2]_. The implementation is largely inspired by a blog post
from A. Akinshin's [3]_.

Parameters
----------
q : float or sequence of float
Quantile to compute, which must be between 0 and 1 inclusive.
dim : str or sequence of str, optional
Dimension(s) over which to apply the weighted ``quantile``.
skipna : bool, optional
If True, skip missing values (as marked by NaN). By default, only
skips missing values for float dtypes; other dtypes either do not
have a sentinel missing value (int) or skipna=True has not been
implemented (object, datetime64 or timedelta64).
keep_attrs : bool, optional
If True, the attributes (``attrs``) will be copied from the original
object to the new one. If False (default), the new object will be
returned without attributes.

Returns
-------
quantiles : {cls}
New {cls} object with weighted ``quantile`` applied to its data and
the indicated dimension(s) removed.

See Also
--------
numpy.nanquantile, pandas.Series.quantile, Dataset.quantile, DataArray.quantile

Notes
-----
Returns NaN if the ``weights`` sum to 0.0 along the reduced
dimension(s).

References
----------
.. [1] https://notstatschat.rbind.io/2020/08/04/weights-in-statistics/
.. [2] Hyndman, R. J. & Fan, Y. (1996). Sample Quantiles in Statistical Packages.
The American Statistician, 50(4), 361–365. https://doi.org/10.2307/2684934
.. [3] https://aakinshin.net/posts/weighted-quantiles
"""


if TYPE_CHECKING:
from .dataarray import DataArray
@@ -241,6 +308,141 @@ def _weighted_std(

return cast("DataArray", np.sqrt(self._weighted_var(da, dim, skipna)))

def _weighted_quantile(
self,
da: DataArray,
q: ArrayLike,
dim: Hashable | Iterable[Hashable] | None = None,
skipna: bool = None,
) -> DataArray:
"""Apply a weighted ``quantile`` to a DataArray along some dimension(s)."""

def _get_h(n: float, q: np.ndarray, method: QUANTILE_METHODS) -> np.ndarray:
"""Return the interpolation parameter."""
# Note that options are not yet exposed in the public API.
if method == "linear":
h = (n - 1) * q + 1
elif method == "interpolated_inverted_cdf":
h = n * q
elif method == "hazen":
h = n * q + 0.5
elif method == "weibull":
h = (n + 1) * q
elif method == "median_unbiased":
h = (n + 1 / 3) * q + 1 / 3
elif method == "normal_unbiased":
h = (n + 1 / 4) * q + 3 / 8
else:
raise ValueError(f"Invalid method: {method}.")
return h.clip(1, n)

def _weighted_quantile_1d(
data: np.ndarray,
weights: np.ndarray,
q: np.ndarray,
skipna: bool,
method: QUANTILE_METHODS = "linear",
) -> np.ndarray:

# This algorithm has been adapted from:
# https://aakinshin.net/posts/weighted-quantiles/#reference-implementation
is_nan = np.isnan(data)
if skipna:
# Remove nans from data and weights
not_nan = ~is_nan
data = data[not_nan]
weights = weights[not_nan]
elif is_nan.any():
# Return nan if data contains any nan
return np.full(q.size, np.nan)

# Filter out data (and weights) associated with zero weights, which also flattens them
nonzero_weights = weights != 0
data = data[nonzero_weights]
weights = weights[nonzero_weights]
n = data.size

if n == 0:
# Possibly empty after nan or zero weight filtering above
return np.full(q.size, np.nan)

# Kish's effective sample size
nw = weights.sum() ** 2 / (weights**2).sum()

# Sort data and weights
sorter = np.argsort(data)
data = data[sorter]
weights = weights[sorter]

# Normalize and sum the weights
weights = weights / weights.sum()
weights_cum = np.append(0, weights.cumsum())

# Vectorize the computation by transposing q with respect to weights
q = np.atleast_2d(q).T

# Get the interpolation parameter for each q
h = _get_h(nw, q, method)

# Find the samples contributing to the quantile computation (at *positions* between (h-1)/nw and h/nw)
u = np.maximum((h - 1) / nw, np.minimum(h / nw, weights_cum))

# Compute their relative weight
v = u * nw - h + 1
w = np.diff(v)

# Apply the weights
return (data * w).sum(axis=1)

if skipna is None and da.dtype.kind in "cfO":
skipna = True

q = np.atleast_1d(np.asarray(q, dtype=np.float64))

if q.ndim > 1:
raise ValueError("q must be a scalar or 1d")

if np.any((q < 0) | (q > 1)):
raise ValueError("q values must be between 0 and 1")

if dim is None:
dim = da.dims

if utils.is_scalar(dim):
dim = [dim]

# To satisfy mypy
dim = cast(Sequence, dim)

# need to align *and* broadcast
# - `_weighted_quantile_1d` requires arrays with the same shape
# - broadcast does an outer join, which can introduce NaN to weights
# - therefore we first need to do align(..., join="inner")

# TODO: use broadcast(..., join="inner") once available
# see https://github.com/pydata/xarray/issues/6304

da, weights = align(da, self.weights, join="inner")
da, weights = broadcast(da, weights)

result = apply_ufunc(
_weighted_quantile_1d,
da,
weights,
input_core_dims=[dim, dim],
output_core_dims=[["quantile"]],
output_dtypes=[np.float64],
dask_gufunc_kwargs=dict(output_sizes={"quantile": len(q)}),
dask="parallelized",
vectorize=True,
kwargs={"q": q, "skipna": skipna},
)

result = result.transpose("quantile", ...)
result = result.assign_coords(quantile=q).squeeze()

return result

def _implementation(self, func, dim, **kwargs):

raise NotImplementedError("Use `Dataset.weighted` or `DataArray.weighted`")
@@ -310,6 +512,19 @@ def std(
self._weighted_std, dim=dim, skipna=skipna, keep_attrs=keep_attrs
)

def quantile(
self,
q: ArrayLike,
*,
dim: Hashable | Sequence[Hashable] | None = None,
keep_attrs: bool = None,
skipna: bool = True,
) -> T_Xarray:

return self._implementation(
self._weighted_quantile, q=q, dim=dim, skipna=skipna, keep_attrs=keep_attrs
)

def __repr__(self):
"""provide a nice str repr of our Weighted object"""

@@ -360,6 +575,8 @@ def _inject_docstring(cls, cls_name):
cls=cls_name, fcn="std", on_zero="NaN"
)

cls.quantile.__doc__ = _WEIGHTED_QUANTILE_DOCSTRING_TEMPLATE.format(cls=cls_name)


_inject_docstring(DataArrayWeighted, "DataArray")
_inject_docstring(DatasetWeighted, "Dataset")
237 changes: 221 additions & 16 deletions xarray/tests/test_weighted.py
Original file line number Diff line number Diff line change
@@ -194,6 +194,160 @@ def test_weighted_mean_no_nan(weights, expected):
assert_equal(expected, result)


@pytest.mark.parametrize(
("weights", "expected"),
(
(
[0.25, 0.05, 0.15, 0.25, 0.15, 0.1, 0.05],
[1.554595, 2.463784, 3.000000, 3.518378],
),
(
[0.05, 0.05, 0.1, 0.15, 0.15, 0.25, 0.25],
[2.840000, 3.632973, 4.076216, 4.523243],
),
),
)
def test_weighted_quantile_no_nan(weights, expected):
# Expected values were calculated by running the reference implementation
# proposed in https://aakinshin.net/posts/weighted-quantiles/

da = DataArray([1, 1.9, 2.2, 3, 3.7, 4.1, 5])
q = [0.2, 0.4, 0.6, 0.8]
weights = DataArray(weights)

expected = DataArray(expected, coords={"quantile": q})
result = da.weighted(weights).quantile(q)

assert_allclose(expected, result)


def test_weighted_quantile_zero_weights():

da = DataArray([0, 1, 2, 3])
weights = DataArray([1, 0, 1, 0])
q = 0.75

result = da.weighted(weights).quantile(q)
expected = DataArray([0, 2]).quantile(0.75)

assert_allclose(expected, result)


def test_weighted_quantile_simple():
# Check that weighted quantiles return the same value as numpy quantiles
da = DataArray([0, 1, 2, 3])
w = DataArray([1, 0, 1, 0])

w_eps = DataArray([1, 0.0001, 1, 0.0001])
q = 0.75

expected = DataArray(np.quantile([0, 2], q), coords={"quantile": q}) # 1.5

assert_equal(expected, da.weighted(w).quantile(q))
assert_allclose(expected, da.weighted(w_eps).quantile(q), rtol=0.001)


@pytest.mark.parametrize("skipna", (True, False))
def test_weighted_quantile_nan(skipna):
# Check skipna behavior
da = DataArray([0, 1, 2, 3, np.nan])
w = DataArray([1, 0, 1, 0, 1])
q = [0.5, 0.75]

result = da.weighted(w).quantile(q, skipna=skipna)

if skipna:
expected = DataArray(np.quantile([0, 2], q), coords={"quantile": q})
else:
expected = DataArray(np.full(len(q), np.nan), coords={"quantile": q})

assert_allclose(expected, result)


@pytest.mark.parametrize(
"da",
(
[1, 1.9, 2.2, 3, 3.7, 4.1, 5],
[1, 1.9, 2.2, 3, 3.7, 4.1, np.nan],
[np.nan, np.nan, np.nan],
),
)
@pytest.mark.parametrize("q", (0.5, (0.2, 0.8)))
@pytest.mark.parametrize("skipna", (True, False))
@pytest.mark.parametrize("factor", [1, 3.14])
def test_weighted_quantile_equal_weights(da, q, skipna, factor):
# if all weights are equal (!= 0), should yield the same result as quantile

da = DataArray(da)
weights = xr.full_like(da, factor)

expected = da.quantile(q, skipna=skipna)
result = da.weighted(weights).quantile(q, skipna=skipna)

assert_allclose(expected, result)


@pytest.mark.skip(reason="`method` argument is not currently exposed")
@pytest.mark.parametrize(
"da",
(
[1, 1.9, 2.2, 3, 3.7, 4.1, 5],
[1, 1.9, 2.2, 3, 3.7, 4.1, np.nan],
[np.nan, np.nan, np.nan],
),
)
@pytest.mark.parametrize("q", (0.5, (0.2, 0.8)))
@pytest.mark.parametrize("skipna", (True, False))
@pytest.mark.parametrize(
"method",
[
"linear",
"interpolated_inverted_cdf",
"hazen",
"weibull",
"median_unbiased",
"normal_unbiased2",
],
)
def test_weighted_quantile_equal_weights_all_methods(da, q, skipna, factor, method):
# If all weights are equal (!= 0), should yield the same result as numpy quantile

da = DataArray(da)
weights = xr.full_like(da, 3.14)

expected = da.quantile(q, skipna=skipna, method=method)
result = da.weighted(weights).quantile(q, skipna=skipna, method=method)

assert_allclose(expected, result)


def test_weighted_quantile_bool():
# https://github.com/pydata/xarray/issues/4074
da = DataArray([1, 1])
weights = DataArray([True, True])
q = 0.5

expected = DataArray([1], coords={"quantile": [q]}).squeeze()
result = da.weighted(weights).quantile(q)

assert_equal(expected, result)


@pytest.mark.parametrize("q", (-1, 1.1, (0.5, 1.1), ((0.2, 0.4), (0.6, 0.8))))
def test_weighted_quantile_with_invalid_q(q):

da = DataArray([1, 1.9, 2.2, 3, 3.7, 4.1, 5])
q = np.asarray(q)
weights = xr.ones_like(da)

if q.ndim <= 1:
with pytest.raises(ValueError, match="q values must be between 0 and 1"):
da.weighted(weights).quantile(q)
else:
with pytest.raises(ValueError, match="q must be a scalar or 1d"):
da.weighted(weights).quantile(q)


@pytest.mark.parametrize(
("weights", "expected"), (([4, 6], 2.0), ([1, 0], np.nan), ([0, 0], np.nan))
)
@@ -466,16 +620,56 @@ def test_weighted_operations_3D(dim, add_nans, skipna):
check_weighted_operations(data, weights, dim, skipna)


def test_weighted_operations_nonequal_coords():
@pytest.mark.parametrize("dim", ("a", "b", "c", ("a", "b"), ("a", "b", "c"), None))
@pytest.mark.parametrize("q", (0.5, (0.1, 0.9), (0.2, 0.4, 0.6, 0.8)))
@pytest.mark.parametrize("add_nans", (True, False))
@pytest.mark.parametrize("skipna", (None, True, False))
def test_weighted_quantile_3D(dim, q, add_nans, skipna):

dims = ("a", "b", "c")
coords = dict(a=[0, 1, 2], b=[0, 1, 2, 3], c=[0, 1, 2, 3, 4])

data = np.arange(60).reshape(3, 4, 5).astype(float)

# add approximately 25 % NaNs (https://stackoverflow.com/a/32182680/3010700)
if add_nans:
c = int(data.size * 0.25)
data.ravel()[np.random.choice(data.size, c, replace=False)] = np.NaN

da = DataArray(data, dims=dims, coords=coords)

# Weights are all ones, because we will compare against DataArray.quantile (non-weighted)
weights = xr.ones_like(da)

result = da.weighted(weights).quantile(q, dim=dim, skipna=skipna)
expected = da.quantile(q, dim=dim, skipna=skipna)

assert_allclose(expected, result)

ds = da.to_dataset(name="data")
result2 = ds.weighted(weights).quantile(q, dim=dim, skipna=skipna)

assert_allclose(expected, result2.data)


def test_weighted_operations_nonequal_coords():
# There are no weights for a == 4, so that data point is ignored.
weights = DataArray(np.random.randn(4), dims=("a",), coords=dict(a=[0, 1, 2, 3]))
data = DataArray(np.random.randn(4), dims=("a",), coords=dict(a=[1, 2, 3, 4]))

check_weighted_operations(data, weights, dim="a", skipna=None)

q = 0.5
result = data.weighted(weights).quantile(q, dim="a")
# Expected value computed using code from https://aakinshin.net/posts/weighted-quantiles/ with values at a=1,2,3
expected = DataArray([0.9308707], coords={"quantile": [q]}).squeeze()
assert_allclose(result, expected)

data = data.to_dataset(name="data")
check_weighted_operations(data, weights, dim="a", skipna=None)

result = data.weighted(weights).quantile(q, dim="a")
assert_allclose(result, expected.to_dataset(name="data"))


@pytest.mark.parametrize("shape_data", ((4,), (4, 4), (4, 4, 4)))
@pytest.mark.parametrize("shape_weights", ((4,), (4, 4), (4, 4, 4)))
@@ -506,7 +700,8 @@ def test_weighted_operations_different_shapes(


@pytest.mark.parametrize(
"operation", ("sum_of_weights", "sum", "mean", "sum_of_squares", "var", "std")
"operation",
("sum_of_weights", "sum", "mean", "sum_of_squares", "var", "std", "quantile"),
)
@pytest.mark.parametrize("as_dataset", (True, False))
@pytest.mark.parametrize("keep_attrs", (True, False, None))
@@ -520,22 +715,23 @@ def test_weighted_operations_keep_attr(operation, as_dataset, keep_attrs):

data.attrs = dict(attr="weights")

result = getattr(data.weighted(weights), operation)(keep_attrs=True)
kwargs = {"keep_attrs": keep_attrs}
if operation == "quantile":
kwargs["q"] = 0.5

result = getattr(data.weighted(weights), operation)(**kwargs)

if operation == "sum_of_weights":
assert weights.attrs == result.attrs
assert result.attrs == (weights.attrs if keep_attrs else {})
assert result.attrs == (weights.attrs if keep_attrs else {})
else:
assert data.attrs == result.attrs

result = getattr(data.weighted(weights), operation)(keep_attrs=None)
assert not result.attrs

result = getattr(data.weighted(weights), operation)(keep_attrs=False)
assert not result.attrs
assert result.attrs == (weights.attrs if keep_attrs else {})
assert result.attrs == (data.attrs if keep_attrs else {})


@pytest.mark.parametrize(
"operation", ("sum_of_weights", "sum", "mean", "sum_of_squares", "var", "std")
"operation",
("sum_of_weights", "sum", "mean", "sum_of_squares", "var", "std", "quantile"),
)
def test_weighted_operations_keep_attr_da_in_ds(operation):
# GH #3595
@@ -544,22 +740,31 @@ def test_weighted_operations_keep_attr_da_in_ds(operation):
data = DataArray(np.random.randn(2, 2), attrs=dict(attr="data"))
data = data.to_dataset(name="a")

result = getattr(data.weighted(weights), operation)(keep_attrs=True)
kwargs = {"keep_attrs": True}
if operation == "quantile":
kwargs["q"] = 0.5

result = getattr(data.weighted(weights), operation)(**kwargs)

assert data.a.attrs == result.a.attrs


@pytest.mark.parametrize("operation", ("sum_of_weights", "sum", "mean", "quantile"))
@pytest.mark.parametrize("as_dataset", (True, False))
def test_weighted_bad_dim(as_dataset):
def test_weighted_bad_dim(operation, as_dataset):

data = DataArray(np.random.randn(2, 2))
weights = xr.ones_like(data)
if as_dataset:
data = data.to_dataset(name="data")

kwargs = {"dim": "bad_dim"}
if operation == "quantile":
kwargs["q"] = 0.5

error_msg = (
f"{data.__class__.__name__}Weighted"
" does not contain the dimensions: {'bad_dim'}"
)
with pytest.raises(ValueError, match=error_msg):
data.weighted(weights).mean("bad_dim")
getattr(data.weighted(weights), operation)(**kwargs)