Skip to content

Commit 6dc7cbe

Browse files
committed
Support new_axes= keyword in atop
Add new single-chunk dimensions with the ``new_axes=`` keyword, including the length of the new dimension. New dimensions will always be in a single chunk.

>>> def f(x):
...     return x[:, None] * np.ones((1, 5))

>>> z = atop(f, 'az', x, 'a', new_axes={'z': 5})
1 parent 3862b4d commit 6dc7cbe

File tree

3 files changed

+46
-2
lines changed

3 files changed

+46
-2
lines changed

dask/array/core.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,7 @@ def top(func, output, out_indices, *arrind_pairs, **kwargs):
355355
"""
356356
numblocks = kwargs.pop('numblocks')
357357
concatenate = kwargs.pop('concatenate', None)
358+
new_axes = kwargs.pop('new_axes', {})
358359
argpairs = list(partition(2, arrind_pairs))
359360

360361
assert set(numblocks) == set(pluck(0, argpairs))
@@ -364,6 +365,8 @@ def top(func, output, out_indices, *arrind_pairs, **kwargs):
364365

365366
# Dictionary mapping {i: 3, j: 4, ...} for i, j, ... the dimensions
366367
dims = broadcast_dimensions(argpairs, numblocks)
368+
for k in new_axes:
369+
dims[k] = 1
367370

368371
# (0, 0), (0, 1), (0, 2), (1, 0), ...
369372
keytups = list(product(*[range(dims[i]) for i in out_indices]))
@@ -1748,12 +1751,14 @@ def atop(func, out_ind, *args, **kwargs):
17481751
Function to apply to individual tuples of blocks
17491752
out_ind: iterable
17501753
Block pattern of the output, something like 'ijk' or (1, 2, 3)
1751-
concatenate: bool
1752-
If true concatenate arrays along dummy indices, else provide lists
17531754
*args: sequence of Array, index pairs
17541755
Sequence like (x, 'ij', y, 'jk', z, 'i')
17551756
**kwargs: dict
17561757
Extra keyword arguments to pass to function
1758+
concatenate: bool, keyword only
1759+
If true concatenate arrays along dummy indices, else provide lists
1760+
new_axes: dict, keyword only
1761+
Mapping from each new index to the length of its dimension, e.g. ``{'z': 5}``; every new dimension is placed in a single chunk
17571762
17581763
This is best explained through example. Consider the following examples:
17591764
@@ -1800,6 +1805,15 @@ def atop(func, out_ind, *args, **kwargs):
18001805
18011806
>>> z = atop(sequence_dot, '', x, 'i', y, 'i') # doctest: +SKIP
18021807
1808+
Add new single-chunk dimensions with the ``new_axes=`` keyword, including
1809+
the length of the new dimension. New dimensions will always be in a single
1810+
chunk.
1811+
1812+
>>> def f(x):
1813+
... return x[:, None] * np.ones((1, 5))
1814+
1815+
>>> z = atop(f, 'az', x, 'a', new_axes={'z': 5}) # doctest: +SKIP
1816+
18031817
Many dask.array operations are special cases of atop. These tensor
18041818
operations cover a broad subset of NumPy and this function has been battle
18051819
tested, supporting tricky concepts like broadcasting.
@@ -1811,8 +1825,11 @@ def atop(func, out_ind, *args, **kwargs):
18111825
out = kwargs.pop('name', None) # May be None at this point
18121826
token = kwargs.pop('token', None)
18131827
dtype = kwargs.pop('dtype', None)
1828+
new_axes = kwargs.get('new_axes', {})
18141829

18151830
chunkss, arrays = unify_chunks(*args)
1831+
for k, v in new_axes.items():
1832+
chunkss[k] = (v,)
18161833
arginds = list(zip(arrays, args[1::2]))
18171834

18181835
numblocks = dict([(a.name, a.numblocks) for a, _ in arginds])

dask/array/tests/test_array_core.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2044,6 +2044,31 @@ def test_atop_names():
20442044
assert y.name.startswith('add')
20452045

20462046

2047+
def test_atop_new_axes():
2048+
def f(x):
2049+
return x[:, None] * np.ones((1, 7))
2050+
x = da.ones(5, chunks=2)
2051+
y = atop(f, 'aq', x, 'a', new_axes={'q': 7}, concatenate=True)
2052+
assert y.chunks == ((2, 2, 1), (7,))
2053+
assert_eq(y, np.ones((5, 7)))
2054+
2055+
def f(x):
2056+
return x[None, :] * np.ones((7, 1))
2057+
x = da.ones(5, chunks=2)
2058+
y = atop(f, 'qa', x, 'a', new_axes={'q': 7}, concatenate=True)
2059+
assert y.chunks == ((7,), (2, 2, 1))
2060+
assert_eq(y, np.ones((7, 5)))
2061+
2062+
def f(x):
2063+
y = x.sum(axis=1)
2064+
return y[:, None] * np.ones((1, 5))
2065+
2066+
x = da.ones((4, 6), chunks=(2, 2))
2067+
y = atop(f, 'aq', x, 'ab', new_axes={'q': 5}, concatenate=True)
2068+
assert y.chunks == ((2, 2), (5,))
2069+
assert_eq(y, np.ones((4, 5)) * 6)
2070+
2071+
20472072
def test_atop_kwargs():
20482073
def f(a, b=0):
20492074
return a + b

docs/source/changelog.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ Array
2727
+++++
2828
- Add information about how the ``dask.array`` ``chunks`` argument works (:pr:`1504`)
2929
- Fix field access with non-scalar fields in ``dask.array`` (:pr:`1484`)
30+
- Add ``concatenate=`` keyword to ``atop`` to concatenate chunks of contracted dimensions
31+
- Add ``new_axes=`` keyword to ``atop`` to support adding new dimensions
3032

3133
Bag
3234
++++

0 commit comments

Comments
 (0)