|
10 | 10 | import pandas as pd
|
11 | 11 |
|
12 | 12 | from xarray.core import formatting, nputils, utils
|
| 13 | +from xarray.core.coordinate_transform import CoordinateTransform |
13 | 14 | from xarray.core.indexing import (
|
| 15 | + CoordinateTransformIndexingAdapter, |
14 | 16 | IndexSelResult,
|
15 | 17 | PandasIndexingAdapter,
|
16 | 18 | PandasMultiIndexingAdapter,
|
@@ -1377,6 +1379,125 @@ def rename(self, name_dict, dims_dict):
|
1377 | 1379 | )
|
1378 | 1380 |
|
1379 | 1381 |
|
| 1382 | +class CoordinateTransformIndex(Index): |
| 1383 | + """Helper class for creating Xarray indexes based on coordinate transforms. |
| 1384 | +
|
| 1385 | + EXPERIMENTAL (not ready for public use yet). |
| 1386 | +
|
| 1387 | + - wraps a :py:class:`CoordinateTransform` instance |
| 1388 | + - takes care of creating the index (lazy) coordinates |
| 1389 | + - supports point-wise label-based selection |
| 1390 | + - supports exact alignment only, by comparing indexes based on their transform |
| 1391 | + (not on their explicit coordinate labels) |
| 1392 | +
|
| 1393 | + """ |
| 1394 | + |
| 1395 | + transform: CoordinateTransform |
| 1396 | + |
| 1397 | + def __init__( |
| 1398 | + self, |
| 1399 | + transform: CoordinateTransform, |
| 1400 | + ): |
| 1401 | + self.transform = transform |
| 1402 | + |
| 1403 | + def create_variables( |
| 1404 | + self, variables: Mapping[Any, Variable] | None = None |
| 1405 | + ) -> IndexVars: |
| 1406 | + from xarray.core.variable import Variable |
| 1407 | + |
| 1408 | + new_variables = {} |
| 1409 | + |
| 1410 | + for name in self.transform.coord_names: |
| 1411 | + # copy attributes, if any |
| 1412 | + attrs: Mapping[Hashable, Any] | None |
| 1413 | + |
| 1414 | + if variables is not None and name in variables: |
| 1415 | + var = variables[name] |
| 1416 | + attrs = var.attrs |
| 1417 | + else: |
| 1418 | + attrs = None |
| 1419 | + |
| 1420 | + data = CoordinateTransformIndexingAdapter(self.transform, name) |
| 1421 | + new_variables[name] = Variable(self.transform.dims, data, attrs=attrs) |
| 1422 | + |
| 1423 | + return new_variables |
| 1424 | + |
| 1425 | + def isel( |
| 1426 | + self, indexers: Mapping[Any, int | slice | np.ndarray | Variable] |
| 1427 | + ) -> Self | None: |
| 1428 | + # TODO: support returning a new index (e.g., possible to re-calculate the |
| 1429 | + # the transform or calculate another transform on a reduced dimension space) |
| 1430 | + return None |
| 1431 | + |
| 1432 | + def sel( |
| 1433 | + self, labels: dict[Any, Any], method=None, tolerance=None |
| 1434 | + ) -> IndexSelResult: |
| 1435 | + from xarray.core.dataarray import DataArray |
| 1436 | + from xarray.core.variable import Variable |
| 1437 | + |
| 1438 | + if method != "nearest": |
| 1439 | + raise ValueError( |
| 1440 | + "CoordinateTransformIndex only supports selection with method='nearest'" |
| 1441 | + ) |
| 1442 | + |
| 1443 | + labels_set = set(labels) |
| 1444 | + coord_names_set = set(self.transform.coord_names) |
| 1445 | + |
| 1446 | + missing_labels = coord_names_set - labels_set |
| 1447 | + if missing_labels: |
| 1448 | + missing_labels_str = ",".join([f"{name}" for name in missing_labels]) |
| 1449 | + raise ValueError(f"missing labels for coordinate(s): {missing_labels_str}.") |
| 1450 | + |
| 1451 | + label0_obj = next(iter(labels.values())) |
| 1452 | + dim_size0 = getattr(label0_obj, "sizes", {}) |
| 1453 | + |
| 1454 | + is_xr_obj = [ |
| 1455 | + isinstance(label, DataArray | Variable) for label in labels.values() |
| 1456 | + ] |
| 1457 | + if not all(is_xr_obj): |
| 1458 | + raise TypeError( |
| 1459 | + "CoordinateTransformIndex only supports advanced (point-wise) indexing " |
| 1460 | + "with either xarray.DataArray or xarray.Variable objects." |
| 1461 | + ) |
| 1462 | + dim_size = [getattr(label, "sizes", {}) for label in labels.values()] |
| 1463 | + if any(ds != dim_size0 for ds in dim_size): |
| 1464 | + raise ValueError( |
| 1465 | + "CoordinateTransformIndex only supports advanced (point-wise) indexing " |
| 1466 | + "with xarray.DataArray or xarray.Variable objects of macthing dimensions." |
| 1467 | + ) |
| 1468 | + |
| 1469 | + coord_labels = { |
| 1470 | + name: labels[name].values for name in self.transform.coord_names |
| 1471 | + } |
| 1472 | + dim_positions = self.transform.reverse(coord_labels) |
| 1473 | + |
| 1474 | + results: dict[str, Variable | DataArray] = {} |
| 1475 | + dims0 = tuple(dim_size0) |
| 1476 | + for dim, pos in dim_positions.items(): |
| 1477 | + # TODO: rounding the decimal positions is not always the behavior we expect |
| 1478 | + # (there are different ways to represent implicit intervals) |
| 1479 | + # we should probably make this customizable. |
| 1480 | + pos = np.round(pos).astype("int") |
| 1481 | + if isinstance(label0_obj, Variable): |
| 1482 | + results[dim] = Variable(dims0, pos) |
| 1483 | + else: |
| 1484 | + # dataarray |
| 1485 | + results[dim] = DataArray(pos, dims=dims0) |
| 1486 | + |
| 1487 | + return IndexSelResult(results) |
| 1488 | + |
| 1489 | + def equals(self, other: Self) -> bool: |
| 1490 | + return self.transform.equals(other.transform) |
| 1491 | + |
| 1492 | + def rename( |
| 1493 | + self, |
| 1494 | + name_dict: Mapping[Any, Hashable], |
| 1495 | + dims_dict: Mapping[Any, Hashable], |
| 1496 | + ) -> Self: |
| 1497 | + # TODO: maybe update self.transform coord_names, dim_size and dims attributes |
| 1498 | + return self |
| 1499 | + |
| 1500 | + |
1380 | 1501 | def create_default_index_implicit(
|
1381 | 1502 | dim_variable: Variable,
|
1382 | 1503 | all_variables: Mapping | Iterable[Hashable] | None = None,
|
|
0 commit comments