diff --git a/xarray/backends/api.py b/xarray/backends/api.py
index 9f45474e7e7..c2c655b0893 100644
--- a/xarray/backends/api.py
+++ b/xarray/backends/api.py
@@ -736,6 +736,7 @@ def open_mfdataset(
     parallel=False,
     join="outer",
     attrs_file=None,
+    tolerance=0,
     **kwargs,
 ):
     """Open multiple files as a single dataset.
@@ -849,6 +850,10 @@ def open_mfdataset(
         Path of the file used to read global attributes from.
         By default global attributes are read from the first file provided,
         with wildcard matches sorted by filename.
+    tolerance : float or dict of float, optional
+        Absolute tolerance used when comparing numeric dimension coordinates
+        with the same name across datasets. If a dict, maps coordinate names
+        to tolerance values. Defaults to 0 (exact equality).
     **kwargs : optional
         Additional arguments passed on to :py:func:`xarray.open_dataset`.
 
@@ -950,6 +955,7 @@ def open_mfdataset(
             coords=coords,
             join=join,
             combine_attrs="drop",
+            tolerance=tolerance,
         )
     else:
         raise ValueError(
diff --git a/xarray/core/combine.py b/xarray/core/combine.py
index 5b3a8bef6a5..6f9fa4a6384 100644
--- a/xarray/core/combine.py
+++ b/xarray/core/combine.py
@@ -1,6 +1,7 @@
 import itertools
 from collections import Counter
 
+import numpy as np
 import pandas as pd
 
 from . import dtypes
@@ -44,7 +45,7 @@ def _infer_tile_ids_from_nested_list(entry, current_pos):
         yield current_pos, entry
 
 
-def _infer_concat_order_from_coords(datasets):
+def _infer_concat_order_from_coords(datasets, tolerance=0):
 
     concat_dims = []
     tile_ids = [() for ds in datasets]
@@ -64,9 +65,26 @@ def _infer_concat_order_from_coords(datasets):
                     "inferring concatenation order"
                 )
 
+            # tolerance can be a dict with dims as key or a constant value
+            if isinstance(tolerance, dict):
+                atol = tolerance.get(dim, 0)
+            else:
+                atol = tolerance
+
             # If dimension coordinate values are same on every dataset then
             # should be leaving this dimension alone (it's just a "bystander")
-            if not all(index.equals(indexes[0]) for index in indexes[1:]):
+            if not (
+                all(index.equals(indexes[0]) for index in indexes[1:])
+                or (
+                    atol > 0
+                    and all(index.is_numeric() for index in indexes)
+                    and all(
+                        len(index) == len(indexes[0])
+                        and np.allclose(index, indexes[0], atol=atol, rtol=0)
+                        for index in indexes[1:]
+                    )
+                )
+            ):
 
                 # Infer order datasets should be arranged in along this dim
                 concat_dims.append(dim)
@@ -547,6 +565,7 @@ def combine_by_coords(
     fill_value=dtypes.NA,
     join="outer",
     combine_attrs="no_conflicts",
+    tolerance=0,
 ):
     """
     Attempt to auto-magically combine the given datasets into one by using
@@ -634,6 +653,10 @@ def combine_by_coords(
         the same name must also have the same value.
         - "override": skip comparing and copy attrs from the first dataset to
         the result.
+    tolerance : float or dict of float, optional
+        Absolute tolerance used when comparing numeric dimension coordinates
+        with the same name across datasets. If a dict, maps coordinate names
+        to tolerance values. Defaults to 0 (exact equality).
 
     Returns
     -------
@@ -757,7 +780,7 @@ def combine_by_coords(
     concatenated_grouped_by_data_vars = []
     for vars, datasets_with_same_vars in grouped_by_vars:
         combined_ids, concat_dims = _infer_concat_order_from_coords(
-            list(datasets_with_same_vars)
+            list(datasets_with_same_vars), tolerance
         )
 
         if fill_value is None:
diff --git a/xarray/tests/test_combine.py b/xarray/tests/test_combine.py
index fa1b749b1a7..53f9888a5f6 100644
--- a/xarray/tests/test_combine.py
+++ b/xarray/tests/test_combine.py
@@ -818,6 +818,31 @@ def test_combine_by_coords_incomplete_hypercube(self):
         with pytest.raises(ValueError):
             combine_by_coords([x1, x2, x3], fill_value=None)
 
+    def test_combine_by_coords_with_tolerance(self):
+        x = [0, 1, 2]
+        tol = 1e-7
+
+        x1 = x + tol * np.array([-0.5, 0, 0.5])
+        ds1 = Dataset(
+            {"a": (("time", "x"), [[9, 0, 2]])}, coords={"x": x1, "time": [0]}
+        )
+
+        x2 = x + tol * np.array([+0.4, 0.3, 0.1])
+        ds2 = Dataset(
+            {"a": (("time", "x"), [[6, 8, 3]])}, coords={"x": x2, "time": [1]}
+        )
+
+        # fail if tolerance is not properly implemented
+        combined = combine_by_coords([ds1, ds2], tolerance=tol)
+        # check that the coordinates of the first dataset were chosen
+        assert np.all(combined.x == x1)
+        assert len(combined.time) == 2
+
+        # a per-coordinate mapping must behave like the scalar form
+        combined = combine_by_coords([ds1, ds2], tolerance={"x": tol})
+        assert np.all(combined.x == x1)
+        assert len(combined.time) == 2
+
 
 @requires_cftime
 def test_combine_by_coords_distant_cftime_dates():