
apply_ufunc(dask='parallelized') won't accept scalar *args #1697


Closed
crusaderky opened this issue Nov 7, 2017 · 1 comment

Comments

@crusaderky
Contributor

As of xarray-0.10-rc1:

Works:

import xarray
import scipy.stats
a = xarray.DataArray([1,2], dims=['x'])

xarray.apply_ufunc(scipy.stats.norm.cdf, a, 0, 1)

<xarray.DataArray (x: 2)>
array([ 0.841345,  0.97725 ])
Dimensions without coordinates: x

Broken:

xarray.apply_ufunc(
    scipy.stats.norm.cdf, a.chunk(), 0, 1, dask='parallelized', output_dtypes=[a.dtype]
).compute()

IndexError                                Traceback (most recent call last)
<ipython-input-35-1d4025e1ebdb> in <module>()
----> 1 xarray.apply_ufunc(scipy.stats.norm.cdf, a.chunk(), 0, 1, dask='parallelized', output_dtypes=[a.dtype]).compute()

~/anaconda3/lib/python3.6/site-packages/xarray/core/computation.py in apply_ufunc(func, *args, **kwargs)
    913                                      join=join,
    914                                      exclude_dims=exclude_dims,
--> 915                                      keep_attrs=keep_attrs)
    916     elif any(isinstance(a, Variable) for a in args):
    917         return variables_ufunc(*args)

~/anaconda3/lib/python3.6/site-packages/xarray/core/computation.py in apply_dataarray_ufunc(func, *args, **kwargs)
    210 
    211     data_vars = [getattr(a, 'variable', a) for a in args]
--> 212     result_var = func(*data_vars)
    213 
    214     if signature.num_outputs > 1:

~/anaconda3/lib/python3.6/site-packages/xarray/core/computation.py in apply_variable_ufunc(func, *args, **kwargs)
    561             raise ValueError('unknown setting for dask array handling in '
    562                              'apply_ufunc: {}'.format(dask))
--> 563     result_data = func(*input_data)
    564 
    565     if signature.num_outputs > 1:

~/anaconda3/lib/python3.6/site-packages/xarray/core/computation.py in <lambda>(*arrays)
    555             func = lambda *arrays: _apply_with_dask_atop(
    556                 numpy_func, arrays, input_dims, output_dims, signature,
--> 557                 output_dtypes, output_sizes)
    558         elif dask == 'allowed':
    559             pass

~/anaconda3/lib/python3.6/site-packages/xarray/core/computation.py in _apply_with_dask_atop(func, args, input_dims, output_dims, signature, output_dtypes, output_sizes)
    624                  for element in (arg, dims[-getattr(arg, 'ndim', 0):])]
    625     return da.atop(func, out_ind, *atop_args, dtype=dtype, concatenate=True,
--> 626                    new_axes=output_sizes)
    627 
    628 

~/anaconda3/lib/python3.6/site-packages/dask/array/core.py in atop(func, out_ind, *args, **kwargs)
   2231         raise ValueError("Must specify dtype of output array")
   2232 
-> 2233     chunkss, arrays = unify_chunks(*args)
   2234     for k, v in new_axes.items():
   2235         chunkss[k] = (v,)

~/anaconda3/lib/python3.6/site-packages/dask/array/core.py in unify_chunks(*args, **kwargs)
   2117             chunks = tuple(chunkss[j] if a.shape[n] > 1 else a.shape[n]
   2118                            if not np.isnan(sum(chunkss[j])) else None
-> 2119                            for n, j in enumerate(i))
   2120             if chunks != a.chunks and all(a.chunks):
   2121                 arrays.append(a.rechunk(chunks))

~/anaconda3/lib/python3.6/site-packages/dask/array/core.py in <genexpr>(.0)
   2117             chunks = tuple(chunkss[j] if a.shape[n] > 1 else a.shape[n]
   2118                            if not np.isnan(sum(chunkss[j])) else None
-> 2119                            for n, j in enumerate(i))
   2120             if chunks != a.chunks and all(a.chunks):
   2121                 arrays.append(a.rechunk(chunks))

IndexError: tuple index out of range
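One plausible reading of the `_apply_with_dask_atop` frame: the index for each argument is taken as `dims[-getattr(arg, 'ndim', 0):]`. A plain scalar has no `ndim`, so the fallback is 0, and because `-0 == 0` the slice returns every dimension instead of none. Dask is then told the scalar is indexed by `x` and fails when it looks up `shape[0]` on a 0-d value. The pure-Python sketch below reproduces only that slicing step; `pair_with_indices` and `FakeArray` are hypothetical stand-ins, not xarray or dask code.

```python
def pair_with_indices(args, dims):
    # Hypothetical re-creation of the pairing done in _apply_with_dask_atop:
    # each argument is paired with the trailing slice of `dims` that is meant
    # to match its number of dimensions.
    pairs = []
    for arg in args:
        ndim = getattr(arg, "ndim", 0)  # plain Python scalars have no .ndim
        pairs.append((arg, dims[-ndim:]))  # -0 == 0, so ndim == 0 keeps ALL dims
    return pairs


class FakeArray:
    """Stand-in for a 1-d dask array; only .ndim matters for this demo."""
    ndim = 1


for arg, index in pair_with_indices([FakeArray(), 0, 1], ("x",)):
    print(type(arg).__name__, index)
# FakeArray ('x',)
# int ('x',)    <- the scalar is wrongly indexed by 'x'
# int ('x',)
```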

Workaround:

xarray.apply_ufunc(
    scipy.stats.norm.cdf, a.chunk(), kwargs={'loc': 0, 'scale': 1},
    dask='parallelized', output_dtypes=[a.dtype]
).compute()

<xarray.DataArray (x: 2)>
array([ 0.841345,  0.97725 ])
Dimensions without coordinates: x

shoyer added the bug label Nov 8, 2017
shoyer added this to the 0.10 milestone Nov 8, 2017
@shoyer
Member

shoyer commented Nov 8, 2017

Thanks for testing this out! This does seem like a bug: clearly these scalars should be treated as equivalent to 0-dimensional variables.
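To make that suggestion concrete: if scalars are treated as 0-dimensional, the index handed to dask for them should be empty. The sketch below assumes the fix belongs at the slicing step visible in the traceback (`dims[-getattr(arg, 'ndim', 0):]` in `_apply_with_dask_atop`); `core_index` and `FakeArray` are hypothetical names for illustration, not xarray code and not the patch that closed this issue.

```python
class FakeArray:
    """Stand-in for an array argument; only .ndim is consulted."""

    def __init__(self, ndim):
        self.ndim = ndim


def core_index(arg, dims):
    """Hypothetical helper: the trailing names in `dims` that belong to `arg`.

    Scalars have no .ndim, so they count as 0-dimensional and get ().
    """
    ndim = getattr(arg, "ndim", 0)
    if ndim > len(dims):
        raise ValueError(f"argument has {ndim} dims but only {len(dims)} names")
    # Unlike dims[-ndim:], slicing from len(dims) - ndim yields () when
    # ndim == 0, because dims[-0:] would keep every name.
    return tuple(dims[len(dims) - ndim:])


print(core_index(FakeArray(1), ("x",)))      # ('x',)
print(core_index(0, ("x",)))                 # ()
print(core_index(FakeArray(1), ("x", "y")))  # ('y',)
```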
