
apply_ufunc(dask='parallelized'): mix of chunked and non-chunked *args results in shape mismatch #2817

Closed
christophrenkl opened this issue Mar 18, 2019 · 2 comments

Comments

@christophrenkl

I have an xr.DataArray() with dimensions (time, latitude, longitude) that is backed by a dask.Array(), and I would like to apply a bandpass filter along the time dimension using scipy.signal.filtfilt(b, a, x, axis=0), where b and a are the filter coefficients and x is the data to be filtered (my xr.DataArray()). Since my array is big and this function only operates along one dimension, I want to apply it as a ufunc using dask='parallelized':

Here is an example:

Code Sample

import numpy as np
import pandas as pd
from scipy.signal import butter, filtfilt
import xarray as xr

# create example xarray.DataArray() - note that my actual array is much bigger.
data = np.random.rand(100, 10, 10)
times = pd.date_range('1982-01-01', periods=100)
lons = np.arange(10)
lats = np.arange(10)

arr = xr.DataArray(data, coords=[times, lats, lons], dims=['time', 'lat', 'lon'])
# <xarray.DataArray (time: 100, lat: 10, lon: 10)>
# array([[[0.922799, 0.533868, ..., 0.144572, 0.894127],
#         [0.314781, 0.74535 , ..., 0.644554, 0.522697],
#         ...,
#         [0.072118, 0.347918, ..., 0.276387, 0.748218],
#         [0.97935 , 0.600887, ..., 0.213457, 0.904182]]])
# Coordinates:
#   * time     (time) datetime64[ns] 1982-01-01 1982-01-02 ... 1982-04-10
#   * lat      (lat) int64 0 1 2 3 4 5 6 7 8 9
#   * lon      (lon) int64 0 1 2 3 4 5 6 7 8 9

# construct Butterworth filter
b, a = butter(3, [2/50, 2/5], 'bandpass')
# b = array([ 0.07674591,  0.        , -0.23023772,  0.        ,  0.23023772,
#             0.        , -0.07674591])
# a = array([ 1.        , -3.47676086,  5.08018486, -4.23100528,  2.23928617,
#            -0.69437338,  0.08427357])

# apply filtfilt along time dimension (axis=0)
filtered = xr.apply_ufunc(filtfilt,
                          b, a, arr.chunk(),
                          dask='parallelized',
                          output_dtypes=[arr.dtype],
                          kwargs={'axis': 0}).compute()

Problem description

I am getting the error message below. As far as I understand it, the numpy arrays a and b have shape (7,), which does not match the shape of the chunks of arr, and this mismatch causes the error.

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-18-0dcd81577f52> in <module>
      3                           dask='parallelized',
      4                           output_dtypes=[arr.dtype],
----> 5                           kwargs={'axis': 0})
      6 
      7 filtered

~/Software/anaconda3/lib/python3.6/site-packages/xarray/core/computation.py in apply_ufunc(func, *args, **kwargs)
    985                                      signature=signature,
    986                                      join=join,
--> 987                                      exclude_dims=exclude_dims)
    988     elif any(isinstance(a, Variable) for a in args):
    989         return variables_ufunc(*args)

~/Software/anaconda3/lib/python3.6/site-packages/xarray/core/computation.py in apply_dataarray_ufunc(func, *args, **kwargs)
    209 
    210     data_vars = [getattr(a, 'variable', a) for a in args]
--> 211     result_var = func(*data_vars)
    212 
    213     if signature.num_outputs > 1:

~/Software/anaconda3/lib/python3.6/site-packages/xarray/core/computation.py in apply_variable_ufunc(func, *args, **kwargs)
    559             raise ValueError('unknown setting for dask array handling in '
    560                              'apply_ufunc: {}'.format(dask))
--> 561     result_data = func(*input_data)
    562 
    563     if signature.num_outputs == 1:

~/Software/anaconda3/lib/python3.6/site-packages/xarray/core/computation.py in func(*arrays)
    553                 return _apply_with_dask_atop(
    554                     numpy_func, arrays, input_dims, output_dims,
--> 555                     signature, output_dtypes, output_sizes)
    556         elif dask == 'allowed':
    557             pass

~/Software/anaconda3/lib/python3.6/site-packages/xarray/core/computation.py in _apply_with_dask_atop(func, args, input_dims, output_dims, signature, output_dtypes, output_sizes)
    654 
    655     return da.atop(func, out_ind, *atop_args, dtype=dtype, concatenate=True,
--> 656                    new_axes=output_sizes)
    657 
    658 

~/Software/anaconda3/lib/python3.6/site-packages/dask/array/top.py in atop(func, out_ind, *args, **kwargs)
    471         raise ValueError("Must specify dtype of output array")
    472 
--> 473     chunkss, arrays = unify_chunks(*args)
    474     for k, v in new_axes.items():
    475         chunkss[k] = (v,)

~/Software/anaconda3/lib/python3.6/site-packages/dask/array/core.py in unify_chunks(*args, **kwargs)
   2568                            for n, j in enumerate(i))
   2569             if chunks != a.chunks and all(a.chunks):
-> 2570                 arrays.append(a.rechunk(chunks))
   2571             else:
   2572                 arrays.append(a)

~/Software/anaconda3/lib/python3.6/site-packages/dask/array/core.py in rechunk(self, chunks, threshold, block_size_limit)
   1767         """ See da.rechunk for docstring """
   1768         from . import rechunk   # avoid circular import
-> 1769         return rechunk(self, chunks, threshold, block_size_limit)
   1770 
   1771     @property

~/Software/anaconda3/lib/python3.6/site-packages/dask/array/rechunk.py in rechunk(x, chunks, threshold, block_size_limit)
    223                        for lc, rc in zip(chunks, x.chunks))
    224     chunks = normalize_chunks(chunks, x.shape, limit=block_size_limit,
--> 225                               dtype=x.dtype, previous_chunks=x.chunks)
    226 
    227     if chunks == x.chunks:

~/Software/anaconda3/lib/python3.6/site-packages/dask/array/core.py in normalize_chunks(chunks, shape, limit, dtype, previous_chunks)
   2013                    for c, s in zip(map(sum, chunks), shape)):
   2014             raise ValueError("Chunks do not add up to shape. "
-> 2015                              "Got chunks=%s, shape=%s" % (chunks, shape))
   2016 
   2017     return tuple(tuple(int(x) if not math.isnan(x) else x for x in c) for c in chunks)

ValueError: Chunks do not add up to shape. Got chunks=((10,),), shape=(7,)

Expected Output

I would expect a and b simply to be passed along with every chunk. I found a similar issue, #1697, but the workaround proposed there (passing non-chunked objects as kwargs) does not work for filtfilt: apply_ufunc always passes the array arguments positionally, and in filtfilt's signature a and b come before x.
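The argument-order problem can be made concrete with plain SciPy, and it also suggests a second workaround besides binding the coefficients: a small wrapper that takes the data first, so that b and a can travel as keyword arguments. This is a minimal sketch; `filtfilt_data_first` is a hypothetical helper, not part of SciPy or xarray, and the `func(*arrays, **kwargs)` calling convention is the assumption being illustrated.

```python
import numpy as np
from scipy.signal import butter, filtfilt

b, a = butter(3, [2/50, 2/5], 'bandpass')
x = np.random.rand(100, 10, 10)

# apply_ufunc calls func(*arrays, **kwargs), so the data always arrives as
# the first positional argument. For filtfilt that slot is b, so passing
# b and a as keywords collides with it and raises TypeError.
kwargs_only_failed = False
try:
    filtfilt(x, b=b, a=a, axis=0)
except TypeError:
    kwargs_only_failed = True


def filtfilt_data_first(x, b, a, axis=-1):
    """Hypothetical wrapper: scipy's filtfilt with the data argument first."""
    return filtfilt(b, a, x, axis=axis)


# With the data first, b and a can be supplied as keyword arguments.
out = filtfilt_data_first(x, b=b, a=a, axis=0)
print(kwargs_only_failed)                            # True
print(np.allclose(out, filtfilt(b, a, x, axis=0)))  # True
```

With xarray, the wrapper would then be passed as `xr.apply_ufunc(filtfilt_data_first, arr.chunk(), dask='parallelized', output_dtypes=[arr.dtype], kwargs={'b': b, 'a': a, 'axis': 0})`, which has the same effect as binding b and a with `functools.partial`.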

Output of xr.show_versions()


INSTALLED VERSIONS
------------------
commit: None
python: 3.6.8 |Anaconda custom (64-bit)| (default, Dec 30 2018, 01:22:34) 
[GCC 7.3.0]
python-bits: 64
OS: Linux
OS-release: 4.15.0-46-generic
machine: x86_64
processor: x86_64
byteorder: little
LC_ALL: None
LANG: en_CA.UTF-8
LOCALE: en_CA.UTF-8
libhdf5: 1.10.2
libnetcdf: 4.6.1

xarray: 0.11.3
pandas: 0.23.4
numpy: 1.15.2
scipy: 1.1.0
netCDF4: 1.4.1
pydap: None
h5netcdf: None
h5py: None
Nio: None
zarr: None
cftime: 1.0.3.4
PseudonetCDF: None
rasterio: None
cfgrib: None
iris: None
bottleneck: 1.2.1
cyordereddict: None
dask: 1.0.0
distributed: 1.25.2
matplotlib: 3.0.2
cartopy: 0.17.0
seaborn: 0.9.0
setuptools: 40.6.3
pip: 18.1
conda: 4.6.8
pytest: 4.1.0
IPython: 7.2.0
sphinx: 1.8.3

@shoyer
Member

shoyer commented Mar 18, 2019

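The problem is that a and b don't have shapes that align with arr: apply_ufunc treats every positional argument as an array to be broadcast against the others, so the 1-D coefficient arrays (length 7) get matched against the last dimension of arr (lon, length 10). Basically, this function doesn't fit the "ufunc" model.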
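The mismatch can be illustrated with plain NumPy broadcasting rules. This is an analogy, not xarray's internal code path: the traceback shows dask trying to fit the length-7 coefficients into the chunks of the trailing `lon` dimension, which is exactly where NumPy would line up a 1-D array. The sketch assumes NumPy >= 1.20 (for `np.broadcast_shapes`); `can_broadcast` is a hypothetical helper.

```python
import numpy as np

# Shapes from the example in the issue: arr is (time, lat, lon) and the
# 3rd-order Butterworth bandpass yields 7 coefficients in each of b and a.
DATA_SHAPE = (100, 10, 10)
COEFF_SHAPE = (7,)


def can_broadcast(*shapes):
    """Return True if NumPy broadcasting rules accept these shapes."""
    try:
        np.broadcast_shapes(*shapes)  # requires NumPy >= 1.20
    except ValueError:
        return False
    return True


# A 1-D argument lines up with the *last* dimension of arr (lon, size 10),
# so the 7 filter coefficients cannot be matched against it.
print(can_broadcast(COEFF_SHAPE, DATA_SHAPE))  # False
# A length-10 (or length-1) vector would line up with lon and broadcast fine.
print(can_broadcast((10,), DATA_SHAPE))        # True
```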

Something like this would probably work:

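from functools import partial

# Bind b and a up front so that arr is the only array argument apply_ufunc sees
filtered = xr.apply_ufunc(partial(filtfilt, b, a),
                          arr.chunk(),  # a single dask chunk covering the whole array
                          dask='parallelized',
                          output_dtypes=[arr.dtype],
                          kwargs={'axis': 0}).compute()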
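The version above relies on `arr.chunk()` producing a single chunk, so `axis=0` always sees complete time series. If `time` were split across several chunks, each block would be filtered on its own and the results would silently differ at chunk boundaries. A sketch that declares `time` as a core dimension instead, assuming a recent xarray with dask installed: apply_ufunc moves core dimensions to the last axis (hence `axis=-1`), and with `dask='parallelized'` it raises an error if a core dimension spans multiple chunks, so `time` is kept in one chunk while `lat`/`lon` are split. The chunk sizes are arbitrary choices for illustration.

```python
from functools import partial

import numpy as np
import pandas as pd
import xarray as xr
from scipy.signal import butter, filtfilt

# Same toy data and filter as in the issue.
data = np.random.rand(100, 10, 10)
times = pd.date_range('1982-01-01', periods=100)
arr = xr.DataArray(data, coords=[times, np.arange(10), np.arange(10)],
                   dims=['time', 'lat', 'lon'])
b, a = butter(3, [2/50, 2/5], 'bandpass')

# Keep 'time' in one chunk (-1) so every dask block holds complete time
# series; the spatial dimensions can be split freely.
chunked = arr.chunk({'time': -1, 'lat': 5, 'lon': 5})

filtered = xr.apply_ufunc(
    partial(filtfilt, b, a),
    chunked,
    input_core_dims=[['time']],   # 'time' is moved to the last axis...
    output_core_dims=[['time']],  # ...and comes back as the last axis
    dask='parallelized',
    output_dtypes=[arr.dtype],
    kwargs={'axis': -1},          # so filter along the last axis
).transpose('time', 'lat', 'lon').compute()
```

Declaring the core dimension makes the "filter along time" intent explicit, so the result no longer depends on which axis `time` happens to occupy or on how the array was chunked.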

@christophrenkl
Author

Thanks, @shoyer, the suggested solution works like a charm. I am just referencing #2808, in case this can serve as an example in an apply_ufunc() tutorial.
