@@ -208,15 +208,20 @@ def find_group_cohorts(labels, chunks, merge: bool = True) -> dict:
208
208
# 1. First subset the array appropriately
209
209
axis = range (- labels .ndim , 0 )
210
210
# Easier to create a dask array and use the .blocks property
211
- array = dask .array .ones (tuple (sum (c ) for c in chunks ), chunks = chunks )
211
+ array = dask .array .empty (tuple (sum (c ) for c in chunks ), chunks = chunks )
212
212
labels = np .broadcast_to (labels , array .shape [- labels .ndim :])
213
213
214
214
# Iterate over each block and create a new block of same shape with "chunk number"
215
215
shape = tuple (array .blocks .shape [ax ] for ax in axis )
216
- blocks = np .empty (math .prod (shape ), dtype = object )
217
- for idx , block in enumerate (array .blocks .ravel ()):
218
- blocks [idx ] = np .full (tuple (block .shape [ax ] for ax in axis ), idx )
219
- which_chunk = np .block (blocks .reshape (shape ).tolist ()).reshape (- 1 )
216
+ # Use a numpy object array to enable assignment in the loop
217
+ # TODO: is it possible to just use a nested list?
218
+ # That is what we need for `np.block`
219
+ blocks = np .empty (shape , dtype = object )
220
+ array_chunks = tuple (np .array (c ) for c in array .chunks )
221
+ for idx , blockindex in enumerate (np .ndindex (array .numblocks )):
222
+ chunkshape = tuple (c [i ] for c , i in zip (array_chunks , blockindex ))
223
+ blocks [blockindex ] = np .full (chunkshape , idx )
224
+ which_chunk = np .block (blocks .tolist ()).reshape (- 1 )
220
225
221
226
raveled = labels .reshape (- 1 )
222
227
# these are chunks where a label is present
@@ -229,7 +234,11 @@ def invert(x) -> tuple[np.ndarray, ...]:
229
234
230
235
chunks_cohorts = tlz .groupby (invert , label_chunks .keys ())
231
236
232
- if merge :
237
+ # If every chunk has size one along every chunked axis,
238
+ # then no merging is possible.
239
+ single_chunks = all ((ac == 1 ).all () for ac in array_chunks )
240
+
241
+ if merge and not single_chunks :
233
242
# First sort by number of chunks occupied by cohort
234
243
sorted_chunks_cohorts = dict (
235
244
sorted (chunks_cohorts .items (), key = lambda kv : len (kv [0 ]), reverse = True )
0 commit comments