@@ -355,6 +355,7 @@ def top(func, output, out_indices, *arrind_pairs, **kwargs):
355
355
"""
356
356
numblocks = kwargs .pop ('numblocks' )
357
357
concatenate = kwargs .pop ('concatenate' , None )
358
+ new_axes = kwargs .pop ('new_axes' , {})
358
359
argpairs = list (partition (2 , arrind_pairs ))
359
360
360
361
assert set (numblocks ) == set (pluck (0 , argpairs ))
@@ -364,6 +365,8 @@ def top(func, output, out_indices, *arrind_pairs, **kwargs):
364
365
365
366
# Dictionary mapping {i: 3, j: 4, ...} for i, j, ... the dimensions
366
367
dims = broadcast_dimensions (argpairs , numblocks )
368
+ for k in new_axes :
369
+ dims [k ] = 1
367
370
368
371
# (0, 0), (0, 1), (0, 2), (1, 0), ...
369
372
keytups = list (product (* [range (dims [i ]) for i in out_indices ]))
@@ -1748,12 +1751,14 @@ def atop(func, out_ind, *args, **kwargs):
1748
1751
Function to apply to individual tuples of blocks
1749
1752
out_ind: iterable
1750
1753
Block pattern of the output, something like 'ijk' or (1, 2, 3)
1751
- concatenate: bool
1752
- If true concatenate arrays along dummy indices, else provide lists
1753
1754
*args: sequence of Array, index pairs
1754
1755
Sequence like (x, 'ij', y, 'jk', z, 'i')
1755
1756
**kwargs: dict
1756
1757
Extra keyword arguments to pass to function
1758
+ concatenate: bool, keyword only
1759
+ If true concatenate arrays along dummy indices, else provide lists
1760
+ new_axes: dict, keyword only
1761
+ New indexes and their dimension lengths
1757
1762
1758
1763
This is best explained through example. Consider the following examples:
1759
1764
@@ -1800,6 +1805,15 @@ def atop(func, out_ind, *args, **kwargs):
1800
1805
1801
1806
>>> z = atop(sequence_dot, '', x, 'i', y, 'i') # doctest: +SKIP
1802
1807
1808
+ Add new single-chunk dimensions with the ``new_axes=`` keyword, including
1809
+ the length of the new dimension. New dimensions will always be in a single
1810
+ chunk.
1811
+
1812
+ >>> def f(x):
1813
+ ... return x[:, None] * np.ones((1, 5))
1814
+
1815
+ >>> z = atop(f, 'az', x, 'a', new_axes={'z': 5}) # doctest: +SKIP
1816
+
1803
1817
Many dask.array operations are special cases of atop. These tensor
1804
1818
operations cover a broad subset of NumPy and this function has been battle
1805
1819
tested, supporting tricky concepts like broadcasting.
@@ -1811,8 +1825,11 @@ def atop(func, out_ind, *args, **kwargs):
1811
1825
out = kwargs .pop ('name' , None ) # May be None at this point
1812
1826
token = kwargs .pop ('token' , None )
1813
1827
dtype = kwargs .pop ('dtype' , None )
1828
+ new_axes = kwargs .get ('new_axes' , {})
1814
1829
1815
1830
chunkss , arrays = unify_chunks (* args )
1831
+ for k , v in new_axes .items ():
1832
+ chunkss [k ] = (v ,)
1816
1833
arginds = list (zip (arrays , args [1 ::2 ]))
1817
1834
1818
1835
numblocks = dict ([(a .name , a .numblocks ) for a , _ in arginds ])
0 commit comments