-
-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Flexible coordinate transform #9543
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
4a595df
0b545cf
8af6614
e9a11ef
0b3fd9e
acf1c47
e101585
09667c5
0a5b798
b6b9175
4c7ce28
5cfb1af
632c71b
ae8b318
1c425e3
952faa7
03fdc90
406b03b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
from collections.abc import Hashable, Iterable, Mapping | ||
from typing import Any | ||
|
||
import numpy as np | ||
|
||
|
||
class CoordinateTransform: | ||
"""Abstract coordinate transform with dimension & coordinate names. | ||
|
||
EXPERIMENTAL (not ready for public use yet). | ||
|
||
""" | ||
|
||
coord_names: tuple[Hashable, ...] | ||
dims: tuple[str, ...] | ||
dim_size: dict[str, int] | ||
dtype: Any | ||
|
||
def __init__( | ||
self, | ||
coord_names: Iterable[Hashable], | ||
dim_size: Mapping[str, int], | ||
dtype: Any = None, | ||
): | ||
self.coord_names = tuple(coord_names) | ||
self.dims = tuple(dim_size) | ||
self.dim_size = dict(dim_size) | ||
|
||
if dtype is None: | ||
dtype = np.dtype(np.float64) | ||
self.dtype = dtype | ||
|
||
def forward(self, dim_positions: dict[str, Any]) -> dict[Hashable, Any]: | ||
"""Perform grid -> world coordinate transformation. | ||
|
||
Parameters | ||
---------- | ||
dim_positions : dict | ||
Grid location(s) along each dimension (axis). | ||
|
||
Returns | ||
------- | ||
coord_labels : dict | ||
World coordinate labels. | ||
|
||
""" | ||
# TODO: cache the results in order to avoid re-computing | ||
# all labels when accessing the values of each coordinate one at a time | ||
raise NotImplementedError | ||
|
||
def reverse(self, coord_labels: dict[Hashable, Any]) -> dict[str, Any]: | ||
"""Perform world -> grid coordinate reverse transformation. | ||
|
||
Parameters | ||
---------- | ||
labels : dict | ||
World coordinate labels. | ||
|
||
Returns | ||
------- | ||
dim_positions : dict | ||
Grid relative location(s) along each dimension (axis). | ||
|
||
""" | ||
raise NotImplementedError | ||
|
||
def equals(self, other: "CoordinateTransform") -> bool: | ||
"""Check equality with another CoordinateTransform of the same kind.""" | ||
raise NotImplementedError | ||
|
||
def generate_coords( | ||
self, dims: tuple[str, ...] | None = None | ||
) -> dict[Hashable, Any]: | ||
"""Compute all coordinate labels at once.""" | ||
if dims is None: | ||
dims = self.dims | ||
|
||
positions = np.meshgrid( | ||
*[np.arange(self.dim_size[d]) for d in dims], | ||
indexing="ij", | ||
) | ||
dim_positions = {dim: positions[i] for i, dim in enumerate(dims)} | ||
|
||
return self.forward(dim_positions) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,7 +10,9 @@ | |
import pandas as pd | ||
|
||
from xarray.core import formatting, nputils, utils | ||
from xarray.core.coordinate_transform import CoordinateTransform | ||
from xarray.core.indexing import ( | ||
CoordinateTransformIndexingAdapter, | ||
IndexSelResult, | ||
PandasIndexingAdapter, | ||
PandasMultiIndexingAdapter, | ||
|
@@ -1377,6 +1379,125 @@ def rename(self, name_dict, dims_dict): | |
) | ||
|
||
|
||
class CoordinateTransformIndex(Index): | ||
"""Helper class for creating Xarray indexes based on coordinate transforms. | ||
|
||
EXPERIMENTAL (not ready for public use yet). | ||
|
||
- wraps a :py:class:`CoordinateTransform` instance | ||
- takes care of creating the index (lazy) coordinates | ||
- supports point-wise label-based selection | ||
- supports exact alignment only, by comparing indexes based on their transform | ||
(not on their explicit coordinate labels) | ||
|
||
""" | ||
|
||
transform: CoordinateTransform | ||
|
||
def __init__( | ||
self, | ||
transform: CoordinateTransform, | ||
): | ||
self.transform = transform | ||
|
||
def create_variables( | ||
self, variables: Mapping[Any, Variable] | None = None | ||
) -> IndexVars: | ||
from xarray.core.variable import Variable | ||
|
||
new_variables = {} | ||
|
||
for name in self.transform.coord_names: | ||
# copy attributes, if any | ||
attrs: Mapping[Hashable, Any] | None | ||
|
||
if variables is not None and name in variables: | ||
var = variables[name] | ||
attrs = var.attrs | ||
else: | ||
attrs = None | ||
|
||
data = CoordinateTransformIndexingAdapter(self.transform, name) | ||
new_variables[name] = Variable(self.transform.dims, data, attrs=attrs) | ||
|
||
return new_variables | ||
|
||
def isel( | ||
self, indexers: Mapping[Any, int | slice | np.ndarray | Variable] | ||
) -> Self | None: | ||
# TODO: support returning a new index (e.g., possible to re-calculate the | ||
# the transform or calculate another transform on a reduced dimension space) | ||
return None | ||
|
||
def sel( | ||
self, labels: dict[Any, Any], method=None, tolerance=None | ||
) -> IndexSelResult: | ||
from xarray.core.dataarray import DataArray | ||
from xarray.core.variable import Variable | ||
|
||
if method != "nearest": | ||
raise ValueError( | ||
"CoordinateTransformIndex only supports selection with method='nearest'" | ||
) | ||
|
||
labels_set = set(labels) | ||
coord_names_set = set(self.transform.coord_names) | ||
|
||
missing_labels = coord_names_set - labels_set | ||
if missing_labels: | ||
missing_labels_str = ",".join([f"{name}" for name in missing_labels]) | ||
raise ValueError(f"missing labels for coordinate(s): {missing_labels_str}.") | ||
|
||
label0_obj = next(iter(labels.values())) | ||
dim_size0 = getattr(label0_obj, "sizes", {}) | ||
|
||
is_xr_obj = [ | ||
isinstance(label, DataArray | Variable) for label in labels.values() | ||
] | ||
if not all(is_xr_obj): | ||
raise TypeError( | ||
"CoordinateTransformIndex only supports advanced (point-wise) indexing " | ||
"with either xarray.DataArray or xarray.Variable objects." | ||
) | ||
dim_size = [getattr(label, "sizes", {}) for label in labels.values()] | ||
if any(ds != dim_size0 for ds in dim_size): | ||
raise ValueError( | ||
"CoordinateTransformIndex only supports advanced (point-wise) indexing " | ||
"with xarray.DataArray or xarray.Variable objects of macthing dimensions." | ||
) | ||
|
||
coord_labels = { | ||
name: labels[name].values for name in self.transform.coord_names | ||
} | ||
dim_positions = self.transform.reverse(coord_labels) | ||
benbovy marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
results: dict[str, Variable | DataArray] = {} | ||
dims0 = tuple(dim_size0) | ||
for dim, pos in dim_positions.items(): | ||
# TODO: rounding the decimal positions is not always the behavior we expect | ||
# (there are different ways to represent implicit intervals) | ||
# we should probably make this customizable. | ||
pos = np.round(pos).astype("int") | ||
Comment on lines
+1477
to
+1480
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is important I think. If the coordinates values correspond to the physical values at the top/left pixel corners in the 2D case, we may rather want |
||
if isinstance(label0_obj, Variable): | ||
results[dim] = Variable(dims0, pos) | ||
else: | ||
# dataarray | ||
results[dim] = DataArray(pos, dims=dims0) | ||
|
||
return IndexSelResult(results) | ||
|
||
def equals(self, other: Self) -> bool: | ||
return self.transform.equals(other.transform) | ||
|
||
def rename( | ||
self, | ||
name_dict: Mapping[Any, Hashable], | ||
dims_dict: Mapping[Any, Hashable], | ||
) -> Self: | ||
# TODO: maybe update self.transform coord_names, dim_size and dims attributes | ||
return self | ||
|
||
|
||
def create_default_index_implicit( | ||
dim_variable: Variable, | ||
all_variables: Mapping | Iterable[Hashable] | None = None, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,218 @@ | ||
from collections.abc import Hashable | ||
from typing import Any | ||
|
||
import numpy as np | ||
import pytest | ||
|
||
import xarray as xr | ||
from xarray.core.coordinate_transform import CoordinateTransform | ||
from xarray.core.indexes import CoordinateTransformIndex | ||
from xarray.tests import assert_equal | ||
|
||
|
||
class SimpleCoordinateTransform(CoordinateTransform): | ||
"""Simple uniform scale transform in a 2D space (x/y coordinates).""" | ||
|
||
def __init__(self, shape: tuple[int, int], scale: float, dtype: Any = None): | ||
super().__init__(("x", "y"), {"x": shape[1], "y": shape[0]}, dtype=dtype) | ||
|
||
self.scale = scale | ||
|
||
# array dimensions in reverse order (y = rows, x = cols) | ||
self.xy_dims = tuple(self.dims) | ||
self.dims = (self.dims[1], self.dims[0]) | ||
|
||
def forward(self, dim_positions: dict[str, Any]) -> dict[Hashable, Any]: | ||
assert set(dim_positions) == set(self.dims) | ||
return {dim: dim_positions[dim] * self.scale for dim in self.xy_dims} | ||
|
||
def reverse(self, coord_labels: dict[Hashable, Any]) -> dict[str, Any]: | ||
return {dim: coord_labels[dim] / self.scale for dim in self.xy_dims} | ||
|
||
def equals(self, other: "CoordinateTransform") -> bool: | ||
if not isinstance(other, SimpleCoordinateTransform): | ||
return False | ||
return self.scale == other.scale | ||
|
||
def __repr__(self) -> str: | ||
return f"Scale({self.scale})" | ||
|
||
|
||
def test_abstract_coordinate_transform() -> None: | ||
tr = CoordinateTransform(["x"], {"x": 5}) | ||
|
||
with pytest.raises(NotImplementedError): | ||
tr.forward({"x": [1, 2]}) | ||
|
||
with pytest.raises(NotImplementedError): | ||
tr.reverse({"x": [3.0, 4.0]}) | ||
|
||
with pytest.raises(NotImplementedError): | ||
tr.equals(CoordinateTransform(["x"], {"x": 5})) | ||
|
||
|
||
def test_coordinate_transform_init() -> None: | ||
tr = SimpleCoordinateTransform((4, 4), 2.0) | ||
|
||
assert tr.coord_names == ("x", "y") | ||
# array dimensions in reverse order (y = rows, x = cols) | ||
assert tr.dims == ("y", "x") | ||
assert tr.dim_size == {"x": 4, "y": 4} | ||
assert tr.dtype == np.dtype(np.float64) | ||
|
||
tr2 = SimpleCoordinateTransform((4, 4), 2.0, dtype=np.int64) | ||
assert tr2.dtype == np.dtype(np.int64) | ||
|
||
|
||
@pytest.mark.parametrize("dims", [None, ("y", "x")]) | ||
def test_coordinate_transform_generate_coords(dims) -> None: | ||
tr = SimpleCoordinateTransform((2, 2), 2.0) | ||
|
||
actual = tr.generate_coords(dims) | ||
expected = {"x": [[0.0, 2.0], [0.0, 2.0]], "y": [[0.0, 0.0], [2.0, 2.0]]} | ||
assert set(actual) == set(expected) | ||
np.testing.assert_array_equal(actual["x"], expected["x"]) | ||
np.testing.assert_array_equal(actual["y"], expected["y"]) | ||
|
||
|
||
def create_coords(scale: float, shape: tuple[int, int]) -> xr.Coordinates: | ||
"""Create x/y Xarray coordinate variables from a simple coordinate transform.""" | ||
tr = SimpleCoordinateTransform(shape, scale) | ||
index = CoordinateTransformIndex(tr) | ||
return xr.Coordinates.from_xindex(index) | ||
|
||
|
||
def test_coordinate_transform_variable() -> None: | ||
coords = create_coords(scale=2.0, shape=(2, 2)) | ||
|
||
assert coords["x"].dtype == np.dtype(np.float64) | ||
assert coords["y"].dtype == np.dtype(np.float64) | ||
assert coords["x"].shape == (2, 2) | ||
assert coords["y"].shape == (2, 2) | ||
|
||
np.testing.assert_array_equal(np.array(coords["x"]), [[0.0, 2.0], [0.0, 2.0]]) | ||
np.testing.assert_array_equal(np.array(coords["y"]), [[0.0, 0.0], [2.0, 2.0]]) | ||
|
||
def assert_repr(var: xr.Variable): | ||
assert ( | ||
repr(var._data) | ||
== "CoordinateTransformIndexingAdapter(transform=Scale(2.0))" | ||
) | ||
|
||
assert_repr(coords["x"].variable) | ||
assert_repr(coords["y"].variable) | ||
|
||
|
||
def test_coordinate_transform_variable_repr_inline() -> None: | ||
var = create_coords(scale=2.0, shape=(2, 2))["x"].variable | ||
|
||
actual = var._data._repr_inline_(70) # type: ignore[union-attr] | ||
assert actual == "0.0 2.0 0.0 2.0" | ||
|
||
# truncated inline repr | ||
var2 = create_coords(scale=2.0, shape=(10, 10))["x"].variable | ||
|
||
actual2 = var2._data._repr_inline_(70) # type: ignore[union-attr] | ||
assert ( | ||
actual2 == "0.0 2.0 4.0 6.0 8.0 10.0 12.0 ... 6.0 8.0 10.0 12.0 14.0 16.0 18.0" | ||
) | ||
|
||
|
||
def test_coordinate_transform_variable_basic_outer_indexing() -> None: | ||
var = create_coords(scale=2.0, shape=(4, 4))["x"].variable | ||
|
||
assert var[0, 0] == 0.0 | ||
assert var[0, 1] == 2.0 | ||
assert var[0, -1] == 6.0 | ||
np.testing.assert_array_equal(var[:, 0:2], [[0.0, 2.0]] * 4) | ||
|
||
with pytest.raises(IndexError, match="out of bounds index"): | ||
var[5] | ||
|
||
with pytest.raises(IndexError, match="out of bounds index"): | ||
var[-5] | ||
|
||
|
||
def test_coordinate_transform_variable_vectorized_indexing() -> None: | ||
var = create_coords(scale=2.0, shape=(4, 4))["x"].variable | ||
|
||
actual = var[{"x": xr.Variable("z", [0]), "y": xr.Variable("z", [0])}] | ||
expected = xr.Variable("z", [0.0]) | ||
assert_equal(actual, expected) | ||
|
||
with pytest.raises(IndexError, match="out of bounds index"): | ||
var[{"x": xr.Variable("z", [5]), "y": xr.Variable("z", [5])}] | ||
|
||
|
||
def test_coordinate_transform_setitem_error() -> None: | ||
var = create_coords(scale=2.0, shape=(4, 4))["x"].variable | ||
|
||
# basic indexing | ||
with pytest.raises(TypeError, match="setting values is not supported"): | ||
var[0, 0] = 1.0 | ||
|
||
# outer indexing | ||
with pytest.raises(TypeError, match="setting values is not supported"): | ||
var[[0, 2], 0] = [1.0, 2.0] | ||
|
||
# vectorized indexing | ||
with pytest.raises(TypeError, match="setting values is not supported"): | ||
var[{"x": xr.Variable("z", [0]), "y": xr.Variable("z", [0])}] = 1.0 | ||
|
||
|
||
def test_coordinate_transform_transpose() -> None: | ||
coords = create_coords(scale=2.0, shape=(2, 2)) | ||
|
||
actual = coords["x"].transpose().values | ||
expected = [[0.0, 0.0], [2.0, 2.0]] | ||
np.testing.assert_array_equal(actual, expected) | ||
|
||
|
||
def test_coordinate_transform_equals() -> None: | ||
ds1 = create_coords(scale=2.0, shape=(2, 2)).to_dataset() | ||
ds2 = create_coords(scale=2.0, shape=(2, 2)).to_dataset() | ||
ds3 = create_coords(scale=4.0, shape=(2, 2)).to_dataset() | ||
|
||
# cannot use `assert_equal()` test utility function here yet | ||
# (indexes invariant check are still based on IndexVariable, which | ||
# doesn't work with coordinate transform index coordinate variables) | ||
assert ds1.equals(ds2) | ||
assert not ds1.equals(ds3) | ||
|
||
|
||
def test_coordinate_transform_sel() -> None: | ||
ds = create_coords(scale=2.0, shape=(4, 4)).to_dataset() | ||
|
||
data = [ | ||
[0.0, 1.0, 2.0, 3.0], | ||
[4.0, 5.0, 6.0, 7.0], | ||
[8.0, 9.0, 10.0, 11.0], | ||
[12.0, 13.0, 14.0, 15.0], | ||
] | ||
ds["data"] = (("y", "x"), data) | ||
|
||
actual = ds.sel( | ||
x=xr.Variable("z", [0.5, 5.5]), y=xr.Variable("z", [0.0, 0.5]), method="nearest" | ||
) | ||
expected = ds.isel(x=xr.Variable("z", [0, 3]), y=xr.Variable("z", [0, 0])) | ||
|
||
# cannot use `assert_equal()` test utility function here yet | ||
# (indexes invariant check are still based on IndexVariable, which | ||
# doesn't work with coordinate transform index coordinate variables) | ||
assert actual.equals(expected) | ||
|
||
with pytest.raises(ValueError, match=".*only supports selection.*nearest"): | ||
ds.sel(x=xr.Variable("z", [0.5, 5.5]), y=xr.Variable("z", [0.0, 0.5])) | ||
|
||
with pytest.raises(ValueError, match="missing labels for coordinate.*y"): | ||
ds.sel(x=[0.5, 5.5], method="nearest") | ||
|
||
with pytest.raises(TypeError, match=".*only supports advanced.*indexing"): | ||
ds.sel(x=[0.5, 5.5], y=[0.0, 0.5], method="nearest") | ||
|
||
with pytest.raises(ValueError, match=".*only supports advanced.*indexing"): | ||
ds.sel( | ||
x=xr.Variable("z", [0.5, 5.5]), | ||
y=xr.Variable("z", [0.0, 0.5, 1.5]), | ||
method="nearest", | ||
) |
Unchanged files with check annotations Beta
attrs: _AttrsLike = None, | ||
): | ||
self._data = data | ||
self._dims = self._parse_dimensions(dims) | ||
Check warning on line 264 in xarray/namedarray/core.py
|
||
self._attrs = dict(attrs) if attrs else None | ||
def __init_subclass__(cls, **kwargs: Any) -> None: |
xp = get_array_namespace(data) | ||
if xp == np: | ||
# numpy currently doesn't have a astype: | ||
return data.astype(dtype, **kwargs) | ||
Check warning on line 232 in xarray/core/duck_array_ops.py
|
||
return xp.astype(data, dtype, **kwargs) | ||
return data.astype(dtype, **kwargs) | ||
# otherwise numpy unsigned ints will silently cast to the signed counterpart | ||
fill_value = fill_value.item() | ||
# passes if provided fill value fits in encoded on-disk type | ||
new_fill = encoded_dtype.type(fill_value) | ||
Check warning on line 348 in xarray/coding/variables.py
|
||
except OverflowError: | ||
encoded_kind_str = "signed" if encoded_dtype.kind == "i" else "unsigned" | ||
warnings.warn( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How hard would it be to support
tolerance
in some form? This is a common and useful form of error checking.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pretty tricky to support it here I think, probably better to handle it on a per case basis.
For basic transformations I guess it could be possible to calculate a single, uniform tolerance value in decimal array index units and validate the selected elements using those units (cheap). In other cases we would need to compute the forward transformation of the extracted array indices and then validate the selected elements based on distances in physical units (more expensive).
Also, there may be cases where the coordinates of a same transform object don’t have all the same physical units (e.g., both degrees and radians coordinates in an Astropy WCS object). Unless we forbid that in
xarray.CoordinateTransform
, it doesn’t make much sense to pass a single tolerance value. Passing a dictionarytolerance={coord_name: value}
doesn’t look very nice either IMO. A{unit: value}
dict looks better but adding explicit support for units here might be opening a can of worms.