diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 706ccbde8c4..19af7f6c3cd 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -2853,6 +2853,7 @@ def interp( method: str = "linear", assume_sorted: bool = False, kwargs: Mapping[str, Any] = None, + method_non_numeric: str = "nearest", **coords_kwargs: Any, ) -> "Dataset": """Multidimensional interpolation of Dataset. @@ -2877,6 +2878,9 @@ def interp( Additional keyword arguments passed to scipy's interpolator. Valid options and their behavior depend on if 1-dimensional or multi-dimensional interpolation is used. + method_non_numeric : {"nearest", "pad", "ffill", "backfill", "bfill"}, optional + Method for non-numeric types. Passed on to :py:meth:`Dataset.reindex`. + ``"nearest"`` is used by default. **coords_kwargs : {dim: coordinate, ...}, optional The keyword arguments form of ``coords``. One of coords or coords_kwargs must be provided. @@ -3034,6 +3038,7 @@ def _validate_interp_indexer(x, new_x): } variables: Dict[Hashable, Variable] = {} + to_reindex: Dict[Hashable, Variable] = {} for name, var in obj._variables.items(): if name in indexers: continue @@ -3043,20 +3048,45 @@ def _validate_interp_indexer(x, new_x): else: use_indexers = validated_indexers - if var.dtype.kind in "uifc": + dtype_kind = var.dtype.kind + if dtype_kind in "uifc": + # For normal number types do the interpolation: var_indexers = {k: v for k, v in use_indexers.items() if k in var.dims} variables[name] = missing.interp(var, var_indexers, method, **kwargs) + elif dtype_kind in "ObU" and (use_indexers.keys() & var.dims): + # For types that we do not understand do stepwise + # interpolation to avoid modifying the elements. 
+ # Use reindex_variables instead because it supports + # booleans and objects and retains the dtype. Calling it + # inside this loop might duplicate work and slow it down, + # so collect these variables and reindex them later: + to_reindex[name] = var elif all(d not in indexers for d in var.dims): - # keep unrelated object array + # For anything else we can only keep variables if they + # are not dependent on any coords that are being + # interpolated along: variables[name] = var + if to_reindex: + # Reindex variables: + variables_reindex = alignment.reindex_variables( + variables=to_reindex, + sizes=obj.sizes, + indexes=obj.xindexes, + indexers={k: v[-1] for k, v in validated_indexers.items()}, + method=method_non_numeric, + )[0] + variables.update(variables_reindex) + + # Get the coords that also exist in the variables: coord_names = obj._coord_names & variables.keys() + # Get the indexes that are not being interpolated along: indexes = {k: v for k, v in obj.xindexes.items() if k not in indexers} selected = self._replace_with_new_dims( variables.copy(), coord_names, indexes=indexes ) - # attach indexer as coordinate + # Attach indexer as coordinate variables.update(indexers) for k, v in indexers.items(): assert isinstance(v, Variable) @@ -3077,6 +3107,7 @@ def interp_like( method: str = "linear", assume_sorted: bool = False, kwargs: Mapping[str, Any] = None, + method_non_numeric: str = "nearest", ) -> "Dataset": """Interpolate this object onto the coordinates of another object, filling the out of range values with NaN. @@ -3098,6 +3129,9 @@ def interp_like( values. kwargs : dict, optional Additional keyword passed to scipy's interpolator. + method_non_numeric : {"nearest", "pad", "ffill", "backfill", "bfill"}, optional + Method for non-numeric types. Passed on to :py:meth:`Dataset.reindex`. + ``"nearest"`` is used by default. Returns ------- @@ -3133,7 +3167,13 @@ def interp_like( # We do not support interpolation along object coordinate. 
# reindex instead. ds = self.reindex(object_coords) - return ds.interp(numeric_coords, method, assume_sorted, kwargs) + return ds.interp( + coords=numeric_coords, + method=method, + assume_sorted=assume_sorted, + kwargs=kwargs, + method_non_numeric=method_non_numeric, + ) # Helper methods for rename() def _rename_vars(self, name_dict, dims_dict): diff --git a/xarray/tests/test_interp.py b/xarray/tests/test_interp.py index 97a6d236f0a..ab023dc1558 100644 --- a/xarray/tests/test_interp.py +++ b/xarray/tests/test_interp.py @@ -416,15 +416,19 @@ def test_errors(use_dask): @requires_scipy def test_dtype(): - ds = xr.Dataset( - {"var1": ("x", [0, 1, 2]), "var2": ("x", ["a", "b", "c"])}, - coords={"x": [0.1, 0.2, 0.3], "z": ("x", ["a", "b", "c"])}, - ) - actual = ds.interp(x=[0.15, 0.25]) - assert "var1" in actual - assert "var2" not in actual - # object array should be dropped - assert "z" not in actual.coords + data_vars = dict( + a=("time", np.array([1, 1.25, 2])), + b=("time", np.array([True, True, False], dtype=bool)), + c=("time", np.array(["start", "start", "end"], dtype=str)), + ) + time = np.array([0, 0.25, 1], dtype=float) + expected = xr.Dataset(data_vars, coords=dict(time=time)) + actual = xr.Dataset( + {k: (dim, arr[[0, -1]]) for k, (dim, arr) in data_vars.items()}, + coords=dict(time=time[[0, -1]]), + ) + actual = actual.interp(time=time, method="linear") + assert_identical(expected, actual) @requires_scipy