Skip to content

Commit 4bbab48

Browse files
benbovydcherianmaxrjones
authored
Flexible coordinate transform (#9543)
* Add coordinate transform classes from prototype * lint, public API and docstrings * missing import * sel: convert inverse transform results to ints * sel: add todo note about rounding decimal pos * rename create_coordinates -> create_coords More consistent with the rest of Xarray API where `coords` is used everywhere. * add a Coordinates.from_transform convenient method * fix repr (extract subset values of any n-d array) * Apply suggestions from code review Co-authored-by: Max Jones <[email protected]> * remove specific create coordinates methods In favor of the more generic `Coordinates.from_xindex()`. * fix more typing issues * remove public imports: not ready yet for public use * add experimental notice in docstrings * add coordinate transform tests * typing fixes * update what's new --------- Co-authored-by: Deepak Cherian <[email protected]> Co-authored-by: Max Jones <[email protected]>
1 parent 84e81bc commit 4bbab48

File tree

7 files changed

+579
-2
lines changed

7 files changed

+579
-2
lines changed

doc/whats-new.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ New Features
2828
By `Kai Mühlbauer <https://github.com/kmuehlbauer>`_.
2929
- support python 3.13 (no free-threading) (:issue:`9664`, :pull:`9681`)
3030
By `Justus Magin <https://github.com/keewis>`_.
31+
- Added experimental support for coordinate transforms (not ready for public use yet!) (:pull:`9543`)
32+
By `Benoit Bovy <https://github.com/benbovy>`_.
3133

3234
Breaking changes
3335
~~~~~~~~~~~~~~~~

xarray/core/coordinate_transform.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
from collections.abc import Hashable, Iterable, Mapping
2+
from typing import Any
3+
4+
import numpy as np
5+
6+
7+
class CoordinateTransform:
8+
"""Abstract coordinate transform with dimension & coordinate names.
9+
10+
EXPERIMENTAL (not ready for public use yet).
11+
12+
"""
13+
14+
coord_names: tuple[Hashable, ...]
15+
dims: tuple[str, ...]
16+
dim_size: dict[str, int]
17+
dtype: Any
18+
19+
def __init__(
20+
self,
21+
coord_names: Iterable[Hashable],
22+
dim_size: Mapping[str, int],
23+
dtype: Any = None,
24+
):
25+
self.coord_names = tuple(coord_names)
26+
self.dims = tuple(dim_size)
27+
self.dim_size = dict(dim_size)
28+
29+
if dtype is None:
30+
dtype = np.dtype(np.float64)
31+
self.dtype = dtype
32+
33+
def forward(self, dim_positions: dict[str, Any]) -> dict[Hashable, Any]:
34+
"""Perform grid -> world coordinate transformation.
35+
36+
Parameters
37+
----------
38+
dim_positions : dict
39+
Grid location(s) along each dimension (axis).
40+
41+
Returns
42+
-------
43+
coord_labels : dict
44+
World coordinate labels.
45+
46+
"""
47+
# TODO: cache the results in order to avoid re-computing
48+
# all labels when accessing the values of each coordinate one at a time
49+
raise NotImplementedError
50+
51+
def reverse(self, coord_labels: dict[Hashable, Any]) -> dict[str, Any]:
52+
"""Perform world -> grid coordinate reverse transformation.
53+
54+
Parameters
55+
----------
56+
labels : dict
57+
World coordinate labels.
58+
59+
Returns
60+
-------
61+
dim_positions : dict
62+
Grid relative location(s) along each dimension (axis).
63+
64+
"""
65+
raise NotImplementedError
66+
67+
def equals(self, other: "CoordinateTransform") -> bool:
68+
"""Check equality with another CoordinateTransform of the same kind."""
69+
raise NotImplementedError
70+
71+
def generate_coords(
72+
self, dims: tuple[str, ...] | None = None
73+
) -> dict[Hashable, Any]:
74+
"""Compute all coordinate labels at once."""
75+
if dims is None:
76+
dims = self.dims
77+
78+
positions = np.meshgrid(
79+
*[np.arange(self.dim_size[d]) for d in dims],
80+
indexing="ij",
81+
)
82+
dim_positions = {dim: positions[i] for i, dim in enumerate(dims)}
83+
84+
return self.forward(dim_positions)

xarray/core/coordinates.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -392,7 +392,7 @@ def from_xindex(cls, index: Index) -> Self:
392392
def from_pandas_multiindex(cls, midx: pd.MultiIndex, dim: Hashable) -> Self:
393393
"""Wrap a pandas multi-index as Xarray coordinates (dimension + levels).
394394
395-
The returned coordinates can be directly assigned to a
395+
The returned coordinate variables can be directly assigned to a
396396
:py:class:`~xarray.Dataset` or :py:class:`~xarray.DataArray` via the
397397
``coords`` argument of their constructor.
398398

xarray/core/indexes.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@
1010
import pandas as pd
1111

1212
from xarray.core import formatting, nputils, utils
13+
from xarray.core.coordinate_transform import CoordinateTransform
1314
from xarray.core.indexing import (
15+
CoordinateTransformIndexingAdapter,
1416
IndexSelResult,
1517
PandasIndexingAdapter,
1618
PandasMultiIndexingAdapter,
@@ -1377,6 +1379,125 @@ def rename(self, name_dict, dims_dict):
13771379
)
13781380

13791381

1382+
class CoordinateTransformIndex(Index):
    """Helper class for creating Xarray indexes based on coordinate transforms.

    EXPERIMENTAL (not ready for public use yet).

    - wraps a :py:class:`CoordinateTransform` instance
    - takes care of creating the index (lazy) coordinates
    - supports point-wise label-based selection
    - supports exact alignment only, by comparing indexes based on their transform
      (not on their explicit coordinate labels)

    """

    transform: CoordinateTransform

    def __init__(
        self,
        transform: CoordinateTransform,
    ):
        self.transform = transform

    def create_variables(
        self, variables: Mapping[Any, Variable] | None = None
    ) -> IndexVars:
        """Create one variable per coordinate name of the wrapped transform.

        Each variable wraps a ``CoordinateTransformIndexingAdapter`` and has
        the dimensions of the transform. Attributes are copied from the
        same-named variable in ``variables``, if any.
        """
        from xarray.core.variable import Variable

        new_variables = {}

        for name in self.transform.coord_names:
            # copy attributes, if any
            attrs: Mapping[Hashable, Any] | None

            if variables is not None and name in variables:
                attrs = variables[name].attrs
            else:
                attrs = None

            data = CoordinateTransformIndexingAdapter(self.transform, name)
            new_variables[name] = Variable(self.transform.dims, data, attrs=attrs)

        return new_variables

    def isel(
        self, indexers: Mapping[Any, int | slice | np.ndarray | Variable]
    ) -> Self | None:
        """Return None: positional indexing currently always drops the index."""
        # TODO: support returning a new index (e.g., possible to re-calculate
        # the transform or calculate another transform on a reduced dimension space)
        return None

    def sel(
        self, labels: dict[Any, Any], method=None, tolerance=None
    ) -> IndexSelResult:
        """Point-wise label-based selection via the reverse transform.

        Parameters
        ----------
        labels : dict
            Maps coordinate names to label objects. A label must be given
            for every coordinate of the transform, all labels must be
            ``xarray.DataArray`` or ``xarray.Variable`` objects and they must
            all have the same ``sizes``.
        method : str
            Must be ``"nearest"``.
        tolerance : optional
            Not used by this implementation.

        Returns
        -------
        IndexSelResult
            Integer positions along each dimension returned by the
            transform's ``reverse`` method, wrapped in the same kind of
            object as the first label.

        Raises
        ------
        ValueError
            If ``method`` is not ``"nearest"``, if a coordinate label is
            missing, or if the labels do not have matching dimensions.
        TypeError
            If any label is not a ``DataArray`` or a ``Variable``.
        """
        from xarray.core.dataarray import DataArray
        from xarray.core.variable import Variable

        if method != "nearest":
            raise ValueError(
                "CoordinateTransformIndex only supports selection with method='nearest'"
            )

        missing_labels = set(self.transform.coord_names) - set(labels)
        if missing_labels:
            # sort the names so that the message is deterministic
            # (set iteration order is not)
            missing_labels_str = ",".join(sorted(str(name) for name in missing_labels))
            raise ValueError(f"missing labels for coordinate(s): {missing_labels_str}.")

        # the first label is the reference for both the expected sizes and
        # the type of the returned objects
        label0_obj = next(iter(labels.values()))
        dim_size0 = getattr(label0_obj, "sizes", {})

        if not all(
            isinstance(label, DataArray | Variable) for label in labels.values()
        ):
            raise TypeError(
                "CoordinateTransformIndex only supports advanced (point-wise) indexing "
                "with either xarray.DataArray or xarray.Variable objects."
            )
        if any(getattr(label, "sizes", {}) != dim_size0 for label in labels.values()):
            raise ValueError(
                "CoordinateTransformIndex only supports advanced (point-wise) indexing "
                "with xarray.DataArray or xarray.Variable objects of matching dimensions."
            )

        coord_labels = {
            name: labels[name].values for name in self.transform.coord_names
        }
        dim_positions = self.transform.reverse(coord_labels)

        results: dict[str, Variable | DataArray] = {}
        dims0 = tuple(dim_size0)
        for dim, pos in dim_positions.items():
            # TODO: rounding the decimal positions is not always the behavior we expect
            # (there are different ways to represent implicit intervals)
            # we should probably make this customizable.
            pos = np.round(pos).astype("int")
            if isinstance(label0_obj, Variable):
                results[dim] = Variable(dims0, pos)
            else:
                # dataarray
                results[dim] = DataArray(pos, dims=dims0)

        return IndexSelResult(results)

    def equals(self, other: Self) -> bool:
        """Return True if ``other`` wraps an equal transform.

        An object that is not a ``CoordinateTransformIndex`` is never equal
        (it has no ``transform`` attribute to compare against).
        """
        if not isinstance(other, CoordinateTransformIndex):
            return False
        return self.transform.equals(other.transform)

    def rename(
        self,
        name_dict: Mapping[Any, Hashable],
        dims_dict: Mapping[Any, Hashable],
    ) -> Self:
        """Return this index unchanged (renaming is not propagated yet)."""
        # TODO: maybe update self.transform coord_names, dim_size and dims attributes
        return self
1499+
1500+
13801501
def create_default_index_implicit(
13811502
dim_variable: Variable,
13821503
all_variables: Mapping | Iterable[Hashable] | None = None,

0 commit comments

Comments
 (0)