from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any


@dataclass
class UnitInfo(ABC):
    """Describe one unit library so the tests can be written library-agnostically.

    A concrete subclass supplies a set of representative units, the
    library's unit and quantity types, and three static helpers for
    working with that library's quantities.
    """

    # A unit (e.g. m)
    unit: Any
    # A different unit that ``unit`` can be converted to (e.g. mm)
    compatible_unit: Any
    # A unit that cannot be converted to ``unit`` (e.g. s)
    incompatible_unit: Any
    # Dimensionless unit
    dimensionless: Any
    # Class used by the library to represent units
    unit_type: type
    # Class used by the library to represent quantities
    quantity_type: type

    # The hooks take the quantity itself as their first argument, so they
    # are declared as abstract *static* methods.  ``@staticmethod`` has to be
    # the outermost decorator for ``ABCMeta`` to register the method as
    # abstract; ``abc.abstractclassmethod`` is deprecated and would bind the
    # class to ``quantity``.
    @staticmethod
    @abstractmethod
    def strip_units(quantity):
        """Remove the units from ``quantity``."""

    @staticmethod
    @abstractmethod
    def get_unit(quantity):
        """Return the unit of ``quantity``."""

    @staticmethod
    @abstractmethod
    def assert_equal(q1, q2):
        """Assert that the quantities ``q1`` and ``q2`` are equal."""
def get_unit_lib(obj):
    """Return the registered unit library that ``obj`` belongs to.

    The entries of ``unit_libs`` are checked in order; the first one whose
    ``quantity_type`` ``obj`` is an instance of is returned.  ``None`` is
    returned when no registered library matches.
    """
    matching = (lib for lib in unit_libs if isinstance(obj, lib.quantity_type))
    return next(matching, None)


def array_strip_units(array):
    """Return ``array`` with its units removed.

    Objects that are not quantities of any registered unit library are
    returned unchanged.
    """
    lib = get_unit_lib(array)
    return array if lib is None else lib.strip_units(array)


def test_extract_units(unit_lib):
    """``extract_units`` reports a bare quantity's unit under the ``None`` key."""
    base_unit = unit_lib.unit
    assert extract_units(np.arange(10) * base_unit) == {None: base_unit}
def test_strip_units(unit_lib):
    """Check ``strip_units`` on a bare quantity and on a ``DataArray``."""
    unit = unit_lib.unit

    # Array
    quantity = np.arange(10) * unit
    np.testing.assert_equal(strip_units(quantity), np.arange(10))

    # DataArray
    array = np.linspace(0, 10, 20) * unit
    x = np.arange(20)
    da = xr.DataArray(data=array, dims="x", coords={"x": x})
    # Exact-type identity on purpose: ``isinstance`` would also accept any
    # ndarray subclass, which would not prove that the units are gone.
    assert type(strip_units(da).data) is np.ndarray


def test_attach_units(unit_lib):
    """Check ``attach_units`` on a plain array and on a ``DataArray``."""
    unit = unit_lib.unit

    # Array
    array = np.arange(10)
    quantity = attach_units(array, units={"data": unit})
    unit_lib.assert_equal(quantity, np.arange(10) * unit)

    # DataArray
    array = np.linspace(0, 10, 20)
    x = np.arange(20)
    da = xr.DataArray(data=array, dims="x", coords={"x": x})
    da = attach_units(da, {None: unit})
    assert isinstance(da.data, unit_lib.quantity_type)
func = functools.partial( @@ -415,11 +547,12 @@ def test_apply_ufunc_dataarray(variant, dtype): "coords", ), ) -def test_apply_ufunc_dataset(variant, dtype): +def test_apply_ufunc_dataset(variant, unit_lib, dtype): + unit = unit_lib.unit variants = { - "data": (unit_registry.m, 1, 1), - "dims": (1, unit_registry.m, 1), - "coords": (1, 1, unit_registry.s), + "data": (unit, 1, 1), + "dims": (1, unit, 1), + "coords": (1, 1, unit), } data_unit, dim_unit, coord_unit = variants.get(variant) @@ -451,12 +584,10 @@ def test_apply_ufunc_dataset(variant, dtype): "unit,error", ( pytest.param(1, DimensionalityError, id="no_unit"), - pytest.param( - unit_registry.dimensionless, DimensionalityError, id="dimensionless" - ), - pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), - pytest.param(unit_registry.mm, None, id="compatible_unit"), - pytest.param(unit_registry.m, None, id="identical_unit"), + pytest.param("dimensionless", DimensionalityError, id="dimensionless"), + pytest.param("incompatible_unit", DimensionalityError, id="incompatible_unit"), + pytest.param("compatible_unit", None, id="compatible_unit"), + pytest.param("unit", None, id="identical_unit"), ), ids=repr, ) @@ -471,9 +602,12 @@ def test_apply_ufunc_dataset(variant, dtype): ), ) @pytest.mark.parametrize("value", (10, dtypes.NA)) -def test_align_dataarray(value, variant, unit, error, dtype): +def test_align_dataarray(value, variant, unit, error, dtype, unit_lib): + if isinstance(unit, str): + unit = getattr(unit_lib, unit) + if variant == "coords" and ( - value != dtypes.NA or isinstance(unit, unit_registry.Unit) + value != dtypes.NA or isinstance(unit, known_unit_types) ): pytest.xfail( reason=( @@ -484,7 +618,7 @@ def test_align_dataarray(value, variant, unit, error, dtype): fill_value = dtypes.get_fill_value(dtype) if value == dtypes.NA else value - original_unit = unit_registry.m + original_unit = unit_lib.unit variants = { "data": ((original_unit, unit), (1, 1), (1, 1)), @@ -546,9 
+680,9 @@ def test_align_dataarray(value, variant, unit, error, dtype): actual_a, actual_b = func(data_array1, data_array2) assert_units_equal(expected_a, actual_a) - assert_allclose(expected_a, actual_a) + assert_allclose(expected_a, actual_a, atol=0) assert_units_equal(expected_b, actual_b) - assert_allclose(expected_b, actual_b) + assert_allclose(expected_b, actual_b, atol=0) @pytest.mark.parametrize( @@ -577,7 +711,7 @@ def test_align_dataarray(value, variant, unit, error, dtype): @pytest.mark.parametrize("value", (10, dtypes.NA)) def test_align_dataset(value, unit, variant, error, dtype): if variant == "coords" and ( - value != dtypes.NA or isinstance(unit, unit_registry.Unit) + value != dtypes.NA or isinstance(unit, known_unit_types) ): pytest.xfail( reason=( @@ -650,15 +784,16 @@ def test_align_dataset(value, unit, variant, error, dtype): actual_a, actual_b = func(ds1, ds2) assert_units_equal(expected_a, actual_a) - assert_allclose(expected_a, actual_a) + assert_allclose(expected_a, actual_a, atol=0) assert_units_equal(expected_b, actual_b) - assert_allclose(expected_b, actual_b) + assert_allclose(expected_b, actual_b, atol=0) -def test_broadcast_dataarray(dtype): +def test_broadcast_dataarray(dtype, unit_lib): + unit = unit_lib.unit # uses align internally so more thorough tests are not needed - array1 = np.linspace(0, 10, 2) * unit_registry.Pa - array2 = np.linspace(0, 10, 3) * unit_registry.Pa + array1 = np.linspace(0, 10, 2) * unit + array2 = np.linspace(0, 10, 3) * unit a = xr.DataArray(data=array1, dims="x") b = xr.DataArray(data=array2, dims="y") @@ -677,10 +812,12 @@ def test_broadcast_dataarray(dtype): assert_identical(expected_b, actual_b) -def test_broadcast_dataset(dtype): +def test_broadcast_dataset(dtype, unit_lib): + unit = unit_lib.unit + compatible_unit = unit_lib.compatible_unit # uses align internally so more thorough tests are not needed - array1 = np.linspace(0, 10, 2) * unit_registry.Pa - array2 = np.linspace(0, 10, 3) * 
unit_registry.Pa + array1 = np.linspace(0, 10, 2) * unit + array2 = np.linspace(0, 10, 3) * unit x1 = np.arange(2) y1 = np.arange(3) @@ -693,8 +830,8 @@ def test_broadcast_dataset(dtype): ) other = xr.Dataset( data_vars={ - "a": ("x", array1.to(unit_registry.hPa)), - "b": ("y", array2.to(unit_registry.hPa)), + "a": ("x", array1.to(compatible_unit)), + "b": ("y", array2.to(compatible_unit)), }, coords={"x": x2, "y": y2}, ) @@ -1131,7 +1268,7 @@ def test_merge_dataarray(variant, unit, error, dtype): actual = xr.merge([arr1, arr2, arr3]) assert_units_equal(expected, actual) - assert_allclose(expected, actual) + assert_allclose(expected, actual, atol=0) @pytest.mark.parametrize( @@ -1221,7 +1358,7 @@ def test_merge_dataset(variant, unit, error, dtype): actual = func([ds1, ds2, ds3]) assert_units_equal(expected, actual) - assert_allclose(expected, actual) + assert_allclose(expected, actual, atol=0) @pytest.mark.parametrize( @@ -1551,7 +1688,7 @@ def test_aggregation(self, func, dtype): actual = func(variable) assert_units_equal(expected, actual) - assert_allclose(expected, actual) + assert_allclose(expected, actual, atol=0) def test_aggregate_complex(self): variable = xr.Variable("x", [1, 2j, np.nan] * unit_registry.m) @@ -1559,7 +1696,7 @@ def test_aggregate_complex(self): actual = variable.mean() assert_units_equal(expected, actual) - assert_allclose(expected, actual) + assert_allclose(expected, actual, atol=0) @pytest.mark.parametrize( "func", @@ -1617,7 +1754,7 @@ def test_numpy_methods(self, func, unit, error, dtype): actual = func(variable, *args, **kwargs) assert_units_equal(expected, actual) - assert_allclose(expected, actual) + assert_allclose(expected, actual, atol=0) @pytest.mark.parametrize( "func", (method("item", 5), method("searchsorted", 5)), ids=repr @@ -1940,7 +2077,7 @@ def test_1d_math(self, func, unit, error, dtype): actual = func(variable, y) assert_units_equal(expected, actual) - assert_allclose(expected, actual) + assert_allclose(expected, 
actual, atol=0) @pytest.mark.parametrize( "unit,error", @@ -2414,7 +2551,7 @@ def test_aggregation(self, func, dtype): actual = func(data_array) assert_units_equal(expected, actual) - assert_allclose(expected, actual) + assert_allclose(expected, actual, atol=0) @pytest.mark.parametrize( "func", @@ -3497,7 +3634,7 @@ def test_interp_reindex(self, variant, func, dtype): actual = func(data_array, x=new_x) assert_units_equal(expected, actual) - assert_allclose(expected, actual) + assert_allclose(expected, actual, atol=0) @pytest.mark.skip(reason="indexes don't support units") @pytest.mark.parametrize( @@ -3573,7 +3710,7 @@ def test_interp_reindex_like(self, variant, func, dtype): actual = func(data_array, other) assert_units_equal(expected, actual) - assert_allclose(expected, actual) + assert_allclose(expected, actual, atol=0) @pytest.mark.skip(reason="indexes don't support units") @pytest.mark.parametrize( @@ -3878,7 +4015,7 @@ def test_computation_objects(self, func, variant, dtype): actual = func(data_array).mean() assert_units_equal(expected, actual) - assert_allclose(expected, actual) + assert_allclose(expected, actual, atol=0) def test_resample(self, dtype): array = np.linspace(0, 5, 10).astype(dtype) * unit_registry.m @@ -3978,21 +4115,21 @@ class TestDataset: ) def test_init(self, shared, unit, error, dtype): original_unit = unit_registry.m - scaled_unit = unit_registry.mm + compatible_unit = unit_registry.mm a = np.linspace(0, 1, 10).astype(dtype) * unit_registry.Pa b = np.linspace(-1, 0, 10).astype(dtype) * unit_registry.degK values_a = np.arange(a.shape[0]) dim_a = values_a * original_unit - coord_a = dim_a.to(scaled_unit) + coord_a = dim_a.to(compatible_unit) values_b = np.arange(b.shape[0]) dim_b = values_b * unit coord_b = ( - dim_b.to(scaled_unit) - if unit_registry.is_compatible_with(dim_b, scaled_unit) - and unit != scaled_unit + dim_b.to(compatible_unit) + if unit_registry.is_compatible_with(dim_b, compatible_unit) + and unit != compatible_unit else 
dim_b * 1000 ) @@ -4131,7 +4268,7 @@ def test_aggregation(self, func, dtype): expected = attach_units(func(strip_units(ds)), units) assert_units_equal(expected, actual) - assert_allclose(expected, actual) + assert_allclose(expected, actual, atol=0) @pytest.mark.parametrize("property", ("imag", "real")) def test_numpy_properties(self, property, dtype): @@ -5366,7 +5503,7 @@ def test_computation_objects(self, func, variant, dtype): actual = func(ds).mean(*args, **kwargs) assert_units_equal(expected, actual) - assert_allclose(expected, actual) + assert_allclose(expected, actual, atol=0) @pytest.mark.parametrize( "variant",