Skip to content

Tests for module-level functions with units #3493

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
7bc2105
add tests for replication functions
keewis Nov 7, 2019
007ad7b
add tests for `xarray.dot`
keewis Nov 7, 2019
4f6daf2
Merge branch 'master' into tests-for-toplevel-functions-with-units
keewis Nov 7, 2019
c1407e2
add tests for apply_ufunc
keewis Nov 7, 2019
a21c872
explicitly set the test ids to repr
keewis Nov 7, 2019
9abb729
add tests for align
keewis Nov 7, 2019
4664630
cover a bit more of align
keewis Nov 8, 2019
3ecdb35
add tests for broadcast
keewis Nov 8, 2019
cf75525
black changed how tuple unpacking should look like
keewis Nov 8, 2019
b79d961
correct the xfail message for full_like tests
keewis Nov 8, 2019
6b6729e
add tests for where
keewis Nov 8, 2019
93b003c
add tests for concat
keewis Nov 8, 2019
f8351ec
add tests for combine_by_coords
keewis Nov 8, 2019
a1cecdc
Merge branch 'master' into tests-for-toplevel-functions-with-units
keewis Nov 9, 2019
eb8fe4e
fix a bug in convert_units
keewis Nov 9, 2019
f9f727e
convert the align results to the same units
keewis Nov 9, 2019
3d0dfb1
rename the combine_by_coords test
keewis Nov 9, 2019
2e426a3
convert the units for expected in combine_by_coords
keewis Nov 9, 2019
341ffbc
add tests for combine_nested
keewis Nov 9, 2019
a474203
add tests for merge with datasets
keewis Nov 9, 2019
61627f0
only use three datasets for merging
keewis Nov 10, 2019
d989ae8
add tests for merge with dataarrays
keewis Nov 10, 2019
c1d8e92
update whats-new.rst
keewis Nov 10, 2019
2c6e604
Merge branch 'master' into tests-for-toplevel-functions-with-units
keewis Nov 10, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion doc/whats-new.rst
Original file line number Diff line number Diff line change
@@ -111,7 +111,8 @@ Internal Changes
~~~~~~~~~~~~~~~~

- Added integration tests against `pint <https://pint.readthedocs.io/>`_.
(:pull:`3238`, :pull:`3447`, :pull:`3508`) by `Justus Magin <https://github.com/keewis>`_.
(:pull:`3238`, :pull:`3447`, :pull:`3493`, :pull:`3508`)
by `Justus Magin <https://github.com/keewis>`_.

.. note::

871 changes: 863 additions & 8 deletions xarray/tests/test_units.py
Original file line number Diff line number Diff line change
@@ -222,7 +222,9 @@ def convert_units(obj, to):
if name != obj.name
}

new_obj = xr.DataArray(name=name, data=data, coords=coords, attrs=obj.attrs)
new_obj = xr.DataArray(
name=name, data=data, coords=coords, attrs=obj.attrs, dims=obj.dims
)
elif isinstance(obj, unit_registry.Quantity):
units = to.get(None)
new_obj = obj.to(units) if units is not None else obj
@@ -307,19 +309,689 @@ def __repr__(self):


class function:
def __init__(self, name):
self.name = name
self.func = getattr(np, name)
def __init__(self, name_or_function, *args, **kwargs):
if callable(name_or_function):
self.name = name_or_function.__name__
self.func = name_or_function
else:
self.name = name_or_function
self.func = getattr(np, name_or_function)
if self.func is None:
raise AttributeError(
f"module 'numpy' has no attribute named '{self.name}'"
)

self.args = args
self.kwargs = kwargs

def __call__(self, *args, **kwargs):
return self.func(*args, **kwargs)
all_args = list(self.args) + list(args)
all_kwargs = {**self.kwargs, **kwargs}

return self.func(*all_args, **all_kwargs)

def __repr__(self):
return f"function_{self.name}"


def test_apply_ufunc_dataarray(dtype):
func = function(
xr.apply_ufunc, np.mean, input_core_dims=[["x"]], kwargs={"axis": -1}
)

array = np.linspace(0, 10, 20).astype(dtype) * unit_registry.m
x = np.arange(20) * unit_registry.s
data_array = xr.DataArray(data=array, dims="x", coords={"x": x})

expected = attach_units(func(strip_units(data_array)), extract_units(data_array))
result = func(data_array)

assert_equal_with_units(expected, result)


@pytest.mark.xfail(
reason="pint does not implement `np.result_type` and align strips units"
)
@pytest.mark.parametrize(
"unit,error",
(
pytest.param(1, DimensionalityError, id="no_unit"),
pytest.param(
unit_registry.dimensionless, DimensionalityError, id="dimensionless"
),
pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"),
pytest.param(unit_registry.mm, None, id="compatible_unit"),
pytest.param(unit_registry.m, None, id="identical_unit"),
),
ids=repr,
)
@pytest.mark.parametrize(
"variant",
(
"data",
pytest.param("dims", marks=pytest.mark.xfail(reason="indexes strip units")),
"coords",
),
)
@pytest.mark.parametrize("fill_value", (np.float64(10), np.float64(np.nan)))
def test_align_dataarray(fill_value, variant, unit, error, dtype):
original_unit = unit_registry.m

variants = {
"data": (unit, original_unit, original_unit),
"dims": (original_unit, unit, original_unit),
"coords": (original_unit, original_unit, unit),
}
data_unit, dim_unit, coord_unit = variants.get(variant)

array1 = np.linspace(0, 10, 2 * 5).reshape(2, 5).astype(dtype) * original_unit
array2 = np.linspace(0, 8, 2 * 5).reshape(2, 5).astype(dtype) * data_unit
x = np.arange(2) * original_unit
x_a1 = np.array([10, 5]) * original_unit
x_a2 = np.array([10, 5]) * coord_unit

y1 = np.arange(5) * original_unit
y2 = np.arange(2, 7) * dim_unit

data_array1 = xr.DataArray(
data=array1, coords={"x": x, "x_a": ("x", x_a1), "y": y1}, dims=("x", "y")
)
data_array2 = xr.DataArray(
data=array2, coords={"x": x, "x_a": ("x", x_a2), "y": y2}, dims=("x", "y")
)

fill_value = fill_value * data_unit
func = function(xr.align, join="outer", fill_value=fill_value)
if error is not None:
with pytest.raises(error):
func(data_array1, data_array2)

return

stripped_kwargs = {
key: strip_units(
convert_units(value, {None: original_unit})
if isinstance(value, unit_registry.Quantity)
else value
)
for key, value in func.kwargs.items()
}
units = extract_units(data_array1)
# FIXME: should the expected_b have the same units as data_array1
# or data_array2?
expected_a, expected_b = tuple(
attach_units(elem, units)
for elem in func(
strip_units(data_array1),
strip_units(convert_units(data_array2, units)),
**stripped_kwargs,
)
)
result_a, result_b = func(data_array1, data_array2)

assert_equal_with_units(expected_a, result_a)
assert_equal_with_units(expected_b, result_b)


@pytest.mark.xfail(
reason="pint does not implement `np.result_type` and align strips units"
)
@pytest.mark.parametrize(
"unit,error",
(
pytest.param(1, DimensionalityError, id="no_unit"),
pytest.param(
unit_registry.dimensionless, DimensionalityError, id="dimensionless"
),
pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"),
pytest.param(unit_registry.mm, None, id="compatible_unit"),
pytest.param(unit_registry.m, None, id="identical_unit"),
),
ids=repr,
)
@pytest.mark.parametrize(
"variant",
(
"data",
pytest.param("dims", marks=pytest.mark.xfail(reason="indexes strip units")),
"coords",
),
)
@pytest.mark.parametrize("fill_value", (np.float64(10), np.float64(np.nan)))
def test_align_dataset(fill_value, unit, variant, error, dtype):
original_unit = unit_registry.m

variants = {
"data": (unit, original_unit, original_unit),
"dims": (original_unit, unit, original_unit),
"coords": (original_unit, original_unit, unit),
}
data_unit, dim_unit, coord_unit = variants.get(variant)

array1 = np.linspace(0, 10, 2 * 5).reshape(2, 5).astype(dtype) * original_unit
array2 = np.linspace(0, 10, 2 * 5).reshape(2, 5).astype(dtype) * data_unit

x = np.arange(2) * original_unit
x_a1 = np.array([10, 5]) * original_unit
x_a2 = np.array([10, 5]) * coord_unit

y1 = np.arange(5) * original_unit
y2 = np.arange(2, 7) * dim_unit

ds1 = xr.Dataset(
data_vars={"a": (("x", "y"), array1)},
coords={"x": x, "x_a": ("x", x_a1), "y": y1},
)
ds2 = xr.Dataset(
data_vars={"a": (("x", "y"), array2)},
coords={"x": x, "x_a": ("x", x_a2), "y": y2},
)

fill_value = fill_value * data_unit
func = function(xr.align, join="outer", fill_value=fill_value)
if error is not None:
with pytest.raises(error):
func(ds1, ds2)

return

stripped_kwargs = {
key: strip_units(
convert_units(value, {None: original_unit})
if isinstance(value, unit_registry.Quantity)
else value
)
for key, value in func.kwargs.items()
}
units = extract_units(ds1)
# FIXME: should the expected_b have the same units as ds1 or ds2?
expected_a, expected_b = tuple(
attach_units(elem, units)
for elem in func(
strip_units(ds1), strip_units(convert_units(ds2, units)), **stripped_kwargs
)
)
result_a, result_b = func(ds1, ds2)

assert_equal_with_units(expected_a, result_a)
assert_equal_with_units(expected_b, result_b)


def test_broadcast_dataarray(dtype):
array1 = np.linspace(0, 10, 2) * unit_registry.Pa
array2 = np.linspace(0, 10, 3) * unit_registry.Pa

a = xr.DataArray(data=array1, dims="x")
b = xr.DataArray(data=array2, dims="y")

expected_a, expected_b = tuple(
attach_units(elem, extract_units(a))
for elem in xr.broadcast(strip_units(a), strip_units(b))
)
result_a, result_b = xr.broadcast(a, b)

assert_equal_with_units(expected_a, result_a)
assert_equal_with_units(expected_b, result_b)


def test_broadcast_dataset(dtype):
array1 = np.linspace(0, 10, 2) * unit_registry.Pa
array2 = np.linspace(0, 10, 3) * unit_registry.Pa

ds = xr.Dataset(data_vars={"a": ("x", array1), "b": ("y", array2)})

(expected,) = tuple(
attach_units(elem, extract_units(ds)) for elem in xr.broadcast(strip_units(ds))
)
(result,) = xr.broadcast(ds)

assert_equal_with_units(expected, result)


@pytest.mark.xfail(reason="`combine_by_coords` strips units")
@pytest.mark.parametrize(
"unit,error",
(
pytest.param(1, DimensionalityError, id="no_unit"),
pytest.param(
unit_registry.dimensionless, DimensionalityError, id="dimensionless"
),
pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"),
pytest.param(unit_registry.mm, None, id="compatible_unit"),
pytest.param(unit_registry.m, None, id="identical_unit"),
),
ids=repr,
)
@pytest.mark.parametrize(
"variant",
(
"data",
pytest.param("dims", marks=pytest.mark.xfail(reason="indexes strip units")),
"coords",
),
)
def test_combine_by_coords(variant, unit, error, dtype):
original_unit = unit_registry.m

variants = {
"data": (unit, original_unit, original_unit),
"dims": (original_unit, unit, original_unit),
"coords": (original_unit, original_unit, unit),
}
data_unit, dim_unit, coord_unit = variants.get(variant)

array1 = np.zeros(shape=(2, 3), dtype=dtype) * original_unit
array2 = np.zeros(shape=(2, 3), dtype=dtype) * original_unit
x = np.arange(1, 4) * 10 * original_unit
y = np.arange(2) * original_unit
z = np.arange(3) * original_unit

other_array1 = np.ones_like(array1) * data_unit
other_array2 = np.ones_like(array2) * data_unit
other_x = np.arange(1, 4) * 10 * dim_unit
other_y = np.arange(2, 4) * dim_unit
other_z = np.arange(3, 6) * coord_unit

ds = xr.Dataset(
data_vars={"a": (("y", "x"), array1), "b": (("y", "x"), array2)},
coords={"x": x, "y": y, "z": ("x", z)},
)
other = xr.Dataset(
data_vars={"a": (("y", "x"), other_array1), "b": (("y", "x"), other_array2)},
coords={"x": other_x, "y": other_y, "z": ("x", other_z)},
)

if error is not None:
with pytest.raises(error):
xr.combine_by_coords([ds, other])

return

units = extract_units(ds)
expected = attach_units(
xr.combine_by_coords(
[strip_units(ds), strip_units(convert_units(other, units))]
),
units,
)
result = xr.combine_by_coords([ds, other])

assert_equal_with_units(expected, result)


@pytest.mark.xfail(reason="blocked by `where`")
@pytest.mark.parametrize(
"unit,error",
(
pytest.param(1, DimensionalityError, id="no_unit"),
pytest.param(
unit_registry.dimensionless, DimensionalityError, id="dimensionless"
),
pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"),
pytest.param(unit_registry.mm, None, id="compatible_unit"),
pytest.param(unit_registry.m, None, id="identical_unit"),
),
ids=repr,
)
@pytest.mark.parametrize(
"variant",
(
"data",
pytest.param("dims", marks=pytest.mark.xfail(reason="indexes strip units")),
"coords",
),
)
def test_combine_nested(variant, unit, error, dtype):
original_unit = unit_registry.m

variants = {
"data": (unit, original_unit, original_unit),
"dims": (original_unit, unit, original_unit),
"coords": (original_unit, original_unit, unit),
}
data_unit, dim_unit, coord_unit = variants.get(variant)

array1 = np.zeros(shape=(2, 3), dtype=dtype) * original_unit
array2 = np.zeros(shape=(2, 3), dtype=dtype) * original_unit

x = np.arange(1, 4) * 10 * original_unit
y = np.arange(2) * original_unit
z = np.arange(3) * original_unit

ds1 = xr.Dataset(
data_vars={"a": (("y", "x"), array1), "b": (("y", "x"), array2)},
coords={"x": x, "y": y, "z": ("x", z)},
)
ds2 = xr.Dataset(
data_vars={
"a": (("y", "x"), np.ones_like(array1) * data_unit),
"b": (("y", "x"), np.ones_like(array2) * data_unit),
},
coords={
"x": np.arange(3) * dim_unit,
"y": np.arange(2, 4) * dim_unit,
"z": ("x", np.arange(-3, 0) * coord_unit),
},
)
ds3 = xr.Dataset(
data_vars={
"a": (("y", "x"), np.zeros_like(array1) * np.nan * data_unit),
"b": (("y", "x"), np.zeros_like(array2) * np.nan * data_unit),
},
coords={
"x": np.arange(3, 6) * dim_unit,
"y": np.arange(4, 6) * dim_unit,
"z": ("x", np.arange(3, 6) * coord_unit),
},
)
ds4 = xr.Dataset(
data_vars={
"a": (("y", "x"), -1 * np.ones_like(array1) * data_unit),
"b": (("y", "x"), -1 * np.ones_like(array2) * data_unit),
},
coords={
"x": np.arange(6, 9) * dim_unit,
"y": np.arange(6, 8) * dim_unit,
"z": ("x", np.arange(6, 9) * coord_unit),
},
)

func = function(xr.combine_nested, concat_dim=["x", "y"])
if error is not None:
with pytest.raises(error):
func([[ds1, ds2], [ds3, ds4]])

return

units = extract_units(ds1)
convert_and_strip = lambda ds: strip_units(convert_units(ds, units))
expected = attach_units(
func(
[
[strip_units(ds1), convert_and_strip(ds2)],
[convert_and_strip(ds3), convert_and_strip(ds4)],
]
),
units,
)
result = func([[ds1, ds2], [ds3, ds4]])

assert_equal_with_units(expected, result)


@pytest.mark.xfail(reason="`concat` strips units")
@pytest.mark.parametrize(
"unit,error",
(
pytest.param(1, DimensionalityError, id="no_unit"),
pytest.param(
unit_registry.dimensionless, DimensionalityError, id="dimensionless"
),
pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"),
pytest.param(unit_registry.mm, None, id="compatible_unit"),
pytest.param(unit_registry.m, None, id="identical_unit"),
),
ids=repr,
)
@pytest.mark.parametrize(
"variant",
(
"data",
pytest.param("dims", marks=pytest.mark.xfail(reason="indexes strip units")),
),
)
def test_concat_dataarray(variant, unit, error, dtype):
original_unit = unit_registry.m

variants = {"data": (unit, original_unit), "dims": (original_unit, unit)}
data_unit, dims_unit = variants.get(variant)

array1 = np.linspace(0, 5, 10).astype(dtype) * unit_registry.m
array2 = np.linspace(-5, 0, 5).astype(dtype) * data_unit
x1 = np.arange(5, 15) * original_unit
x2 = np.arange(5) * dims_unit

arr1 = xr.DataArray(data=array1, coords={"x": x1}, dims="x")
arr2 = xr.DataArray(data=array2, coords={"x": x2}, dims="x")

if error is not None:
with pytest.raises(error):
xr.concat([arr1, arr2], dim="x")

return

expected = attach_units(
xr.concat([strip_units(arr1), strip_units(arr2)], dim="x"), extract_units(arr1)
)
result = xr.concat([arr1, arr2], dim="x")

assert_equal_with_units(expected, result)


@pytest.mark.xfail(reason="`concat` strips units")
@pytest.mark.parametrize(
"unit,error",
(
pytest.param(1, DimensionalityError, id="no_unit"),
pytest.param(
unit_registry.dimensionless, DimensionalityError, id="dimensionless"
),
pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"),
pytest.param(unit_registry.mm, None, id="compatible_unit"),
pytest.param(unit_registry.m, None, id="identical_unit"),
),
ids=repr,
)
@pytest.mark.parametrize(
"variant",
(
"data",
pytest.param("dims", marks=pytest.mark.xfail(reason="indexes strip units")),
),
)
def test_concat_dataset(variant, unit, error, dtype):
original_unit = unit_registry.m

variants = {"data": (unit, original_unit), "dims": (original_unit, unit)}
data_unit, dims_unit = variants.get(variant)

array1 = np.linspace(0, 5, 10).astype(dtype) * unit_registry.m
array2 = np.linspace(-5, 0, 5).astype(dtype) * data_unit
x1 = np.arange(5, 15) * original_unit
x2 = np.arange(5) * dims_unit

ds1 = xr.Dataset(data_vars={"a": ("x", array1)}, coords={"x": x1})
ds2 = xr.Dataset(data_vars={"a": ("x", array2)}, coords={"x": x2})

if error is not None:
with pytest.raises(error):
xr.concat([ds1, ds2], dim="x")

return

expected = attach_units(
xr.concat([strip_units(ds1), strip_units(ds2)], dim="x"), extract_units(ds1)
)
result = xr.concat([ds1, ds2], dim="x")

assert_equal_with_units(expected, result)


@pytest.mark.xfail(reason="blocked by `where`")
@pytest.mark.parametrize(
"unit,error",
(
pytest.param(1, DimensionalityError, id="no_unit"),
pytest.param(
unit_registry.dimensionless, DimensionalityError, id="dimensionless"
),
pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"),
pytest.param(unit_registry.mm, None, id="compatible_unit"),
pytest.param(unit_registry.m, None, id="identical_unit"),
),
ids=repr,
)
@pytest.mark.parametrize(
"variant",
(
"data",
pytest.param("dims", marks=pytest.mark.xfail(reason="indexes strip units")),
"coords",
),
)
def test_merge_dataarray(variant, unit, error, dtype):
original_unit = unit_registry.m

variants = {
"data": (unit, original_unit, original_unit),
"dims": (original_unit, unit, original_unit),
"coords": (original_unit, original_unit, unit),
}
data_unit, dim_unit, coord_unit = variants.get(variant)

array1 = np.linspace(0, 1, 2 * 3).reshape(2, 3).astype(dtype) * original_unit
array2 = np.linspace(1, 2, 2 * 4).reshape(2, 4).astype(dtype) * data_unit
array3 = np.linspace(0, 2, 3 * 4).reshape(3, 4).astype(dtype) * data_unit

x = np.arange(2) * original_unit
y = np.arange(3) * original_unit
z = np.arange(4) * original_unit
u = np.linspace(10, 20, 2) * original_unit
v = np.linspace(10, 20, 3) * original_unit
w = np.linspace(10, 20, 4) * original_unit

arr1 = xr.DataArray(
name="a",
data=array1,
coords={"x": x, "y": y, "u": ("x", u), "v": ("y", v)},
dims=("x", "y"),
)
arr2 = xr.DataArray(
name="b",
data=array2,
coords={
"x": np.arange(2, 4) * dim_unit,
"z": z,
"u": ("x", np.linspace(20, 30, 2) * coord_unit),
"w": ("z", w),
},
dims=("x", "z"),
)
arr3 = xr.DataArray(
name="c",
data=array3,
coords={
"y": np.arange(3, 6) * dim_unit,
"z": np.arange(4, 8) * dim_unit,
"v": ("y", np.linspace(10, 20, 3) * coord_unit),
"w": ("z", np.linspace(10, 20, 4) * coord_unit),
},
dims=("y", "z"),
)

func = function(xr.merge)
if error is not None:
with pytest.raises(error):
func([arr1, arr2, arr3])

return

units = {name: original_unit for name in list("abcuvwxyz")}
convert_and_strip = lambda arr: strip_units(convert_units(arr, units))
expected = attach_units(
func([strip_units(arr1), convert_and_strip(arr2), convert_and_strip(arr3)]),
units,
)
result = func([arr1, arr2, arr3])

assert_equal_with_units(expected, result)


@pytest.mark.xfail(reason="blocked by `where`")
@pytest.mark.parametrize(
"unit,error",
(
pytest.param(1, DimensionalityError, id="no_unit"),
pytest.param(
unit_registry.dimensionless, DimensionalityError, id="dimensionless"
),
pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"),
pytest.param(unit_registry.mm, None, id="compatible_unit"),
pytest.param(unit_registry.m, None, id="identical_unit"),
),
ids=repr,
)
@pytest.mark.parametrize(
"variant",
(
"data",
pytest.param("dims", marks=pytest.mark.xfail(reason="indexes strip units")),
"coords",
),
)
def test_merge_dataset(variant, unit, error, dtype):
original_unit = unit_registry.m

variants = {
"data": (unit, original_unit, original_unit),
"dims": (original_unit, unit, original_unit),
"coords": (original_unit, original_unit, unit),
}
data_unit, dim_unit, coord_unit = variants.get(variant)

array1 = np.zeros(shape=(2, 3), dtype=dtype) * original_unit
array2 = np.zeros(shape=(2, 3), dtype=dtype) * original_unit

x = np.arange(11, 14) * original_unit
y = np.arange(2) * original_unit
z = np.arange(3) * original_unit

ds1 = xr.Dataset(
data_vars={"a": (("y", "x"), array1), "b": (("y", "x"), array2)},
coords={"x": x, "y": y, "z": ("x", z)},
)
ds2 = xr.Dataset(
data_vars={
"a": (("y", "x"), np.ones_like(array1) * data_unit),
"b": (("y", "x"), np.ones_like(array2) * data_unit),
},
coords={
"x": np.arange(3) * dim_unit,
"y": np.arange(2, 4) * dim_unit,
"z": ("x", np.arange(-3, 0) * coord_unit),
},
)
ds3 = xr.Dataset(
data_vars={
"a": (("y", "x"), np.zeros_like(array1) * np.nan * data_unit),
"b": (("y", "x"), np.zeros_like(array2) * np.nan * data_unit),
},
coords={
"x": np.arange(3, 6) * dim_unit,
"y": np.arange(4, 6) * dim_unit,
"z": ("x", np.arange(3, 6) * coord_unit),
},
)

func = function(xr.merge)
if error is not None:
with pytest.raises(error):
func([ds1, ds2, ds3])

return

units = extract_units(ds1)
convert_and_strip = lambda ds: strip_units(convert_units(ds, units))
expected = attach_units(
func([strip_units(ds1), convert_and_strip(ds2), convert_and_strip(ds3)]), units
)
result = func([ds1, ds2, ds3])

assert_equal_with_units(expected, result)


@pytest.mark.parametrize("func", (xr.zeros_like, xr.ones_like))
def test_replication(func, dtype):
def test_replication_dataarray(func, dtype):
array = np.linspace(0, 10, 20).astype(dtype) * unit_registry.s
data_array = xr.DataArray(data=array, dims="x")

@@ -330,8 +1002,33 @@ def test_replication(func, dtype):
assert_equal_with_units(expected, result)


@pytest.mark.parametrize("func", (xr.zeros_like, xr.ones_like))
def test_replication_dataset(func, dtype):
array1 = np.linspace(0, 10, 20).astype(dtype) * unit_registry.s
array2 = np.linspace(5, 10, 10).astype(dtype) * unit_registry.Pa
x = np.arange(20).astype(dtype) * unit_registry.m
y = np.arange(10).astype(dtype) * unit_registry.m
z = y.to(unit_registry.mm)

ds = xr.Dataset(
data_vars={"a": ("x", array1), "b": ("y", array2)},
coords={"x": x, "y": y, "z": ("y", z)},
)

numpy_func = getattr(np, func.__name__)
expected = ds.copy(
data={name: numpy_func(array.data) for name, array in ds.data_vars.items()}
)
result = func(ds)

assert_equal_with_units(expected, result)


@pytest.mark.xfail(
reason="np.full_like on Variable strips the unit and pint does not allow mixed args"
reason=(
"pint is undecided on how `full_like` should work, so incorrect errors "
"may be expected: hgrecco/pint#882"
)
)
@pytest.mark.parametrize(
"unit,error",
@@ -344,8 +1041,9 @@ def test_replication(func, dtype):
pytest.param(unit_registry.ms, None, id="compatible_unit"),
pytest.param(unit_registry.s, None, id="identical_unit"),
),
ids=repr,
)
def test_replication_full_like(unit, error, dtype):
def test_replication_full_like_dataarray(unit, error, dtype):
array = np.linspace(0, 5, 10) * unit_registry.s
data_array = xr.DataArray(data=array, dims="x")

@@ -360,6 +1058,163 @@ def test_replication_full_like(unit, error, dtype):
assert_equal_with_units(expected, result)


@pytest.mark.xfail(
reason=(
"pint is undecided on how `full_like` should work, so incorrect errors "
"may be expected: hgrecco/pint#882"
)
)
@pytest.mark.parametrize(
"unit,error",
(
pytest.param(1, DimensionalityError, id="no_unit"),
pytest.param(
unit_registry.dimensionless, DimensionalityError, id="dimensionless"
),
pytest.param(unit_registry.m, DimensionalityError, id="incompatible_unit"),
pytest.param(unit_registry.ms, None, id="compatible_unit"),
pytest.param(unit_registry.s, None, id="identical_unit"),
),
ids=repr,
)
def test_replication_full_like_dataset(unit, error, dtype):
array1 = np.linspace(0, 10, 20).astype(dtype) * unit_registry.s
array2 = np.linspace(5, 10, 10).astype(dtype) * unit_registry.Pa
x = np.arange(20).astype(dtype) * unit_registry.m
y = np.arange(10).astype(dtype) * unit_registry.m
z = y.to(unit_registry.mm)

ds = xr.Dataset(
data_vars={"a": ("x", array1), "b": ("y", array2)},
coords={"x": x, "y": y, "z": ("y", z)},
)

fill_value = -1 * unit
if error is not None:
with pytest.raises(error):
xr.full_like(ds, fill_value=fill_value)

return

expected = ds.copy(
data={
name: np.full_like(array, fill_value=fill_value)
for name, array in ds.data_vars.items()
}
)
result = xr.full_like(ds, fill_value=fill_value)

assert_equal_with_units(expected, result)


@pytest.mark.xfail(reason="`where` strips units")
@pytest.mark.parametrize(
"unit,error",
(
pytest.param(1, DimensionalityError, id="no_unit"),
pytest.param(
unit_registry.dimensionless, DimensionalityError, id="dimensionless"
),
pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"),
pytest.param(unit_registry.mm, None, id="compatible_unit"),
pytest.param(unit_registry.m, None, id="identical_unit"),
),
ids=repr,
)
@pytest.mark.parametrize("fill_value", (np.nan, 10.2))
def test_where_dataarray(fill_value, unit, error, dtype):
array = np.linspace(0, 5, 10).astype(dtype) * unit_registry.m

x = xr.DataArray(data=array, dims="x")
cond = x < 5 * unit_registry.m
# FIXME: this should work without wrapping in array()
fill_value = np.array(fill_value) * unit

if error is not None:
with pytest.raises(error):
xr.where(cond, x, fill_value)

return

fill_value_ = (
fill_value.to(unit_registry.m)
if isinstance(fill_value, unit_registry.Quantity)
and fill_value.check(unit_registry.m)
else fill_value
)
expected = attach_units(
xr.where(cond, strip_units(x), strip_units(fill_value_)), extract_units(x)
)
result = xr.where(cond, x, fill_value)

assert_equal_with_units(expected, result)


@pytest.mark.xfail(reason="`where` strips units")
@pytest.mark.parametrize(
"unit,error",
(
pytest.param(1, DimensionalityError, id="no_unit"),
pytest.param(
unit_registry.dimensionless, DimensionalityError, id="dimensionless"
),
pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"),
pytest.param(unit_registry.mm, None, id="compatible_unit"),
pytest.param(unit_registry.m, None, id="identical_unit"),
),
ids=repr,
)
@pytest.mark.parametrize("fill_value", (np.nan, 10.2))
def test_where_dataset(fill_value, unit, error, dtype):
array1 = np.linspace(0, 5, 10).astype(dtype) * unit_registry.m
array2 = np.linspace(-5, 0, 10).astype(dtype) * unit_registry.m
x = np.arange(10) * unit_registry.s

ds = xr.Dataset(data_vars={"a": ("x", array1), "b": ("x", array2)}, coords={"x": x})
cond = ds.x < 5 * unit_registry.s
# FIXME: this should work without wrapping in array()
fill_value = np.array(fill_value) * unit

if error is not None:
with pytest.raises(error):
xr.where(cond, ds, fill_value)

return

fill_value_ = (
fill_value.to(unit_registry.m)
if isinstance(fill_value, unit_registry.Quantity)
and fill_value.check(unit_registry.m)
else fill_value
)
expected = attach_units(
xr.where(cond, strip_units(ds), strip_units(fill_value_)), extract_units(ds)
)
result = xr.where(cond, ds, fill_value)

assert_equal_with_units(expected, result)


@pytest.mark.xfail(reason="pint does not implement `np.einsum`")
def test_dot_dataarray(dtype):
array1 = (
np.linspace(0, 10, 5 * 10).reshape(5, 10).astype(dtype)
* unit_registry.m
/ unit_registry.s
)
array2 = (
np.linspace(10, 20, 10 * 20).reshape(10, 20).astype(dtype) * unit_registry.s
)

arr1 = xr.DataArray(data=array1, dims=("x", "y"))
arr2 = xr.DataArray(data=array2, dims=("y", "z"))

expected = array1.dot(array2)
result = xr.dot(arr1, arr2)

assert_equal_with_units(expected, result)


class TestDataArray:
@pytest.mark.filterwarnings("error:::pint[.*]")
@pytest.mark.parametrize(