Skip to content

Commit 7daad4f

Browse files
pums974Alexandre Poux
and
Alexandre Poux
authored
Implement interp for interpolating between chunks of data (dask) (#4155)
* Implement interp for interpolating between chunks of data (dask) * do not forget extra points at the end * add tests * add whats-new comment * fix isort / black * typo * update pull number * fix github pep8 warnigns * fix isort * clearer arguments in _dask_aware_interpnd * typo * fix for datetimelike index * chunked interpolation does not work for high order interpolation (quadratic or cubic) * fix whats new * remove a useless import * use Variable instead of InexVariable * avoid some list to tuple conversion * black fix * more comments to explain _compute_chunks * For orthogonal linear- and nearest-neighbor interpolation, the scalar interpolation can also be done sequentially * better detection of Advanced interpolation * implement support of unsorted interpolation destination * rework the tests * fix for datetime index (bug introduced with unsorted destination) * Variable is cheaber that DataArray * add warning if unsorted * simplify _compute_chunks * add ghosts point in order to make quadratic and cubic method work in a chunked direction * black * forgot to remove an exception in test_upsample_interpolate_dask * fix filtering out-of-order warning * use extrapolate to check external points * Revert "add ghosts point in order to make quadratic and cubic method work in a chunked direction" * Complete rewrite using blockwise * update whats-new.rst * reduce the diff * more decomposition of orthogonal interpolation * simplify _dask_aware_interpnd a little * fix dask interp when chunks are not aligned * continue simplifying _dask_aware_interpnd * update whats-new.rst * clean tests Co-authored-by: Alexandre Poux <[email protected]>
1 parent d514b12 commit 7daad4f

File tree

4 files changed

+272
-79
lines changed

4 files changed

+272
-79
lines changed

doc/whats-new.rst

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -176,9 +176,10 @@ New Features
176176
Enhancements
177177
~~~~~~~~~~~~
178178
- Performance improvement of :py:meth:`DataArray.interp` and :py:func:`Dataset.interp`
179-
For orthogonal linear- and nearest-neighbor interpolation, we do 1d-interpolation sequentially
180-
rather than interpolating in multidimensional space. (:issue:`2223`)
179+
We performs independant interpolation sequentially rather than interpolating in
180+
one large multidimensional space. (:issue:`2223`)
181181
By `Keisuke Fujii <https://github.com/fujiisoup>`_.
182+
- :py:meth:`DataArray.interp` now support interpolations over chunked dimensions (:pull:`4155`). By `Alexandre Poux <https://github.com/pums974>`_.
182183
- Major performance improvement for :py:meth:`Dataset.from_dataframe` when the
183184
dataframe has a MultiIndex (:pull:`4184`).
184185
By `Stephan Hoyer <https://github.com/shoyer>`_.

xarray/core/missing.py

Lines changed: 138 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -544,15 +544,6 @@ def _get_valid_fill_mask(arr, dim, limit):
544544
) <= limit
545545

546546

547-
def _assert_single_chunk(var, axes):
548-
for axis in axes:
549-
if len(var.chunks[axis]) > 1 or var.chunks[axis][0] < var.shape[axis]:
550-
raise NotImplementedError(
551-
"Chunking along the dimension to be interpolated "
552-
"({}) is not yet supported.".format(axis)
553-
)
554-
555-
556547
def _localize(var, indexes_coords):
557548
""" Speed up for linear and nearest neighbor method.
558549
Only consider a subspace that is needed for the interpolation
@@ -617,49 +608,42 @@ def interp(var, indexes_coords, method, **kwargs):
617608
if not indexes_coords:
618609
return var.copy()
619610

620-
# simple speed up for the local interpolation
621-
if method in ["linear", "nearest"]:
622-
var, indexes_coords = _localize(var, indexes_coords)
623-
624611
# default behavior
625612
kwargs["bounds_error"] = kwargs.get("bounds_error", False)
626613

627-
# check if the interpolation can be done in orthogonal manner
628-
if (
629-
len(indexes_coords) > 1
630-
and method in ["linear", "nearest"]
631-
and all(dest[1].ndim == 1 for dest in indexes_coords.values())
632-
and len(set([d[1].dims[0] for d in indexes_coords.values()]))
633-
== len(indexes_coords)
634-
):
635-
# interpolate sequentially
636-
for dim, dest in indexes_coords.items():
637-
var = interp(var, {dim: dest}, method, **kwargs)
638-
return var
639-
640-
# target dimensions
641-
dims = list(indexes_coords)
642-
x, new_x = zip(*[indexes_coords[d] for d in dims])
643-
destination = broadcast_variables(*new_x)
644-
645-
# transpose to make the interpolated axis to the last position
646-
broadcast_dims = [d for d in var.dims if d not in dims]
647-
original_dims = broadcast_dims + dims
648-
new_dims = broadcast_dims + list(destination[0].dims)
649-
interped = interp_func(
650-
var.transpose(*original_dims).data, x, destination, method, kwargs
651-
)
614+
result = var
615+
# decompose the interpolation into a succession of independant interpolation
616+
for indexes_coords in decompose_interp(indexes_coords):
617+
var = result
618+
619+
# simple speed up for the local interpolation
620+
if method in ["linear", "nearest"]:
621+
var, indexes_coords = _localize(var, indexes_coords)
622+
623+
# target dimensions
624+
dims = list(indexes_coords)
625+
x, new_x = zip(*[indexes_coords[d] for d in dims])
626+
destination = broadcast_variables(*new_x)
627+
628+
# transpose to make the interpolated axis to the last position
629+
broadcast_dims = [d for d in var.dims if d not in dims]
630+
original_dims = broadcast_dims + dims
631+
new_dims = broadcast_dims + list(destination[0].dims)
632+
interped = interp_func(
633+
var.transpose(*original_dims).data, x, destination, method, kwargs
634+
)
652635

653-
result = Variable(new_dims, interped, attrs=var.attrs)
636+
result = Variable(new_dims, interped, attrs=var.attrs)
654637

655-
# dimension of the output array
656-
out_dims = OrderedSet()
657-
for d in var.dims:
658-
if d in dims:
659-
out_dims.update(indexes_coords[d][1].dims)
660-
else:
661-
out_dims.add(d)
662-
return result.transpose(*tuple(out_dims))
638+
# dimension of the output array
639+
out_dims = OrderedSet()
640+
for d in var.dims:
641+
if d in dims:
642+
out_dims.update(indexes_coords[d][1].dims)
643+
else:
644+
out_dims.add(d)
645+
result = result.transpose(*tuple(out_dims))
646+
return result
663647

664648

665649
def interp_func(var, x, new_x, method, kwargs):
@@ -706,21 +690,51 @@ def interp_func(var, x, new_x, method, kwargs):
706690
if isinstance(var, dask_array_type):
707691
import dask.array as da
708692

709-
_assert_single_chunk(var, range(var.ndim - len(x), var.ndim))
710-
chunks = var.chunks[: -len(x)] + new_x[0].shape
711-
drop_axis = range(var.ndim - len(x), var.ndim)
712-
new_axis = range(var.ndim - len(x), var.ndim - len(x) + new_x[0].ndim)
713-
return da.map_blocks(
714-
_interpnd,
693+
nconst = var.ndim - len(x)
694+
695+
out_ind = list(range(nconst)) + list(range(var.ndim, var.ndim + new_x[0].ndim))
696+
697+
# blockwise args format
698+
x_arginds = [[_x, (nconst + index,)] for index, _x in enumerate(x)]
699+
x_arginds = [item for pair in x_arginds for item in pair]
700+
new_x_arginds = [
701+
[_x, [var.ndim + index for index in range(_x.ndim)]] for _x in new_x
702+
]
703+
new_x_arginds = [item for pair in new_x_arginds for item in pair]
704+
705+
args = (
715706
var,
716-
x,
717-
new_x,
718-
func,
719-
kwargs,
707+
range(var.ndim),
708+
*x_arginds,
709+
*new_x_arginds,
710+
)
711+
712+
_, rechunked = da.unify_chunks(*args)
713+
714+
args = tuple([elem for pair in zip(rechunked, args[1::2]) for elem in pair])
715+
716+
new_x = rechunked[1 + (len(rechunked) - 1) // 2 :]
717+
718+
new_axes = {
719+
var.ndim + i: new_x[0].chunks[i]
720+
if new_x[0].chunks is not None
721+
else new_x[0].shape[i]
722+
for i in range(new_x[0].ndim)
723+
}
724+
725+
# if usefull, re-use localize for each chunk of new_x
726+
localize = (method in ["linear", "nearest"]) and (new_x[0].chunks is not None)
727+
728+
return da.blockwise(
729+
_dask_aware_interpnd,
730+
out_ind,
731+
*args,
732+
interp_func=func,
733+
interp_kwargs=kwargs,
734+
localize=localize,
735+
concatenate=True,
720736
dtype=var.dtype,
721-
chunks=chunks,
722-
new_axis=new_axis,
723-
drop_axis=drop_axis,
737+
new_axes=new_axes,
724738
)
725739

726740
return _interpnd(var, x, new_x, func, kwargs)
@@ -751,3 +765,67 @@ def _interpnd(var, x, new_x, func, kwargs):
751765
# move back the interpolation axes to the last position
752766
rslt = rslt.transpose(range(-rslt.ndim + 1, 1))
753767
return rslt.reshape(rslt.shape[:-1] + new_x[0].shape)
768+
769+
770+
def _dask_aware_interpnd(var, *coords, interp_func, interp_kwargs, localize=True):
771+
"""Wrapper for `_interpnd` through `blockwise`
772+
773+
The first half arrays in `coords` are original coordinates,
774+
the other half are destination coordinates
775+
"""
776+
n_x = len(coords) // 2
777+
nconst = len(var.shape) - n_x
778+
779+
# _interpnd expect coords to be Variables
780+
x = [Variable([f"dim_{nconst + dim}"], _x) for dim, _x in enumerate(coords[:n_x])]
781+
new_x = [
782+
Variable([f"dim_{len(var.shape) + dim}" for dim in range(len(_x.shape))], _x)
783+
for _x in coords[n_x:]
784+
]
785+
786+
if localize:
787+
# _localize expect var to be a Variable
788+
var = Variable([f"dim_{dim}" for dim in range(len(var.shape))], var)
789+
790+
indexes_coords = {_x.dims[0]: (_x, _new_x) for _x, _new_x in zip(x, new_x)}
791+
792+
# simple speed up for the local interpolation
793+
var, indexes_coords = _localize(var, indexes_coords)
794+
x, new_x = zip(*[indexes_coords[d] for d in indexes_coords])
795+
796+
# put var back as a ndarray
797+
var = var.data
798+
799+
return _interpnd(var, x, new_x, interp_func, interp_kwargs)
800+
801+
802+
def decompose_interp(indexes_coords):
803+
"""Decompose the interpolation into a succession of independant interpolation keeping the order"""
804+
805+
dest_dims = [
806+
dest[1].dims if dest[1].ndim > 0 else [dim]
807+
for dim, dest in indexes_coords.items()
808+
]
809+
partial_dest_dims = []
810+
partial_indexes_coords = {}
811+
for i, index_coords in enumerate(indexes_coords.items()):
812+
partial_indexes_coords.update([index_coords])
813+
814+
if i == len(dest_dims) - 1:
815+
break
816+
817+
partial_dest_dims += [dest_dims[i]]
818+
other_dims = dest_dims[i + 1 :]
819+
820+
s_partial_dest_dims = {dim for dims in partial_dest_dims for dim in dims}
821+
s_other_dims = {dim for dims in other_dims for dim in dims}
822+
823+
if not s_partial_dest_dims.intersection(s_other_dims):
824+
# this interpolation is orthogonal to the rest
825+
826+
yield partial_indexes_coords
827+
828+
partial_dest_dims = []
829+
partial_indexes_coords = {}
830+
831+
yield partial_indexes_coords

xarray/tests/test_dataarray.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3147,7 +3147,8 @@ def test_upsample_interpolate_regression_1605(self):
31473147

31483148
@requires_dask
31493149
@requires_scipy
3150-
def test_upsample_interpolate_dask(self):
3150+
@pytest.mark.parametrize("chunked_time", [True, False])
3151+
def test_upsample_interpolate_dask(self, chunked_time):
31513152
from scipy.interpolate import interp1d
31523153

31533154
xs = np.arange(6)
@@ -3158,6 +3159,8 @@ def test_upsample_interpolate_dask(self):
31583159
data = np.tile(z, (6, 3, 1))
31593160
array = DataArray(data, {"time": times, "x": xs, "y": ys}, ("x", "y", "time"))
31603161
chunks = {"x": 2, "y": 1}
3162+
if chunked_time:
3163+
chunks["time"] = 3
31613164

31623165
expected_times = times.to_series().resample("1H").asfreq().index
31633166
# Split the times into equal sub-intervals to simulate the 6 hour
@@ -3185,13 +3188,6 @@ def test_upsample_interpolate_dask(self):
31853188
# done here due to floating point arithmetic
31863189
assert_allclose(expected, actual, rtol=1e-16)
31873190

3188-
# Check that an error is raised if an attempt is made to interpolate
3189-
# over a chunked dimension
3190-
with raises_regex(
3191-
NotImplementedError, "Chunking along the dimension to be interpolated"
3192-
):
3193-
array.chunk({"time": 1}).resample(time="1H").interpolate("linear")
3194-
31953191
def test_align(self):
31963192
array = DataArray(
31973193
np.random.random((6, 8)), coords={"x": list("abcdef")}, dims=["x", "y"]

0 commit comments

Comments
 (0)