Commit a897034

Significantly faster cohorts detection. (#272)

* Significantly faster cohorts detection.
* cleanup
* add benchmark
* update types
* fix
* single chunk optimization
* more optimization
* more optimization
* add comment
* fix benchmark
1 parent 9f82e19 commit a897034

File tree

2 files changed, +28 -7 lines

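In flox, "cohorts" are groups of labels that occupy the same set of chunks and so can be reduced together. Below is a minimal standalone sketch of that grouping idea, assuming a flat `labels` array and a precomputed `which_chunk` array; `cohorts_sketch` and the toy inputs are invented for illustration and are not the library's actual code.

import numpy as np
import toolz as tlz  # flox imports toolz as `tlz`


def cohorts_sketch(labels, which_chunk):
    # For every label, find the sorted tuple of chunks it appears in.
    label_chunks = {
        lab: tuple(np.unique(which_chunk[labels == lab]).tolist())
        for lab in np.unique(labels).tolist()
    }
    # Labels occupying identical chunk sets are grouped into one cohort.
    return tlz.groupby(label_chunks.get, label_chunks.keys())


# One chunk holding the first four elements, a second holding the rest.
labels = np.array([0, 1, 0, 1, 2, 2])
which_chunk = np.array([0, 0, 0, 0, 1, 1])
print(cohorts_sketch(labels, which_chunk))
# {(0,): [0, 1], (1,): [2]}: labels 0 and 1 form a cohort in chunk 0.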

asv_bench/benchmarks/cohorts.py

Lines changed: 13 additions & 1 deletion
@@ -1,8 +1,10 @@
 import dask
 import numpy as np
 import pandas as pd
+import xarray as xr

 import flox
+from flox.xarray import xarray_reduce


 class Cohorts:
@@ -12,7 +14,7 @@ def setup(self, *args, **kwargs):
         raise NotImplementedError

     def time_find_group_cohorts(self):
-        flox.core.find_group_cohorts(self.by, self.array.chunks)
+        flox.core.find_group_cohorts(self.by, [self.array.chunks[ax] for ax in self.axis])
         # The cache clear fails dependably in CI
         # Not sure why
         try:
@@ -125,3 +127,13 @@ class PerfectMonthlyRechunked(PerfectMonthly):
     def setup(self, *args, **kwargs):
         super().setup()
         super().rechunk()
+
+
+def time_cohorts_era5_single():
+    TIME = 900  # 92044 in Google ARCO ERA5
+    da = xr.DataArray(
+        dask.array.ones((TIME, 721, 1440), chunks=(1, -1, -1)),
+        dims=("time", "lat", "lon"),
+        coords=dict(time=pd.date_range("1959-01-01", freq="6H", periods=TIME)),
+    )
+    xarray_reduce(da, da.time.dt.day, method="cohorts", func="any")
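Note the change in `time_find_group_cohorts`: only the chunk tuples along the reduced axes are passed, rather than `self.array.chunks` wholesale. A hedged sketch of that calling convention, using a made-up one-dimensional labels array and chunking rather than the benchmark's own data:

import dask.array
import numpy as np

import flox.core

# Hypothetical data: 120 time steps with 12 repeating "month" labels,
# chunked so each block of 10 holds exactly one label.
labels = np.repeat(np.arange(12), 10)
array = dask.array.ones((120,), chunks=(10,))

# Pass a list with one chunk tuple per reduced axis.
cohorts = flox.core.find_group_cohorts(labels, [array.chunks[0]])

Because every label here occupies exactly one chunk, each cohort is a singleton and each block can be reduced independently.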

flox/core.py

Lines changed: 15 additions & 6 deletions
@@ -208,15 +208,20 @@ def find_group_cohorts(labels, chunks, merge: bool = True) -> dict:
     # 1. First subset the array appropriately
     axis = range(-labels.ndim, 0)
     # Easier to create a dask array and use the .blocks property
-    array = dask.array.ones(tuple(sum(c) for c in chunks), chunks=chunks)
+    array = dask.array.empty(tuple(sum(c) for c in chunks), chunks=chunks)
     labels = np.broadcast_to(labels, array.shape[-labels.ndim :])

     # Iterate over each block and create a new block of same shape with "chunk number"
     shape = tuple(array.blocks.shape[ax] for ax in axis)
-    blocks = np.empty(math.prod(shape), dtype=object)
-    for idx, block in enumerate(array.blocks.ravel()):
-        blocks[idx] = np.full(tuple(block.shape[ax] for ax in axis), idx)
-    which_chunk = np.block(blocks.reshape(shape).tolist()).reshape(-1)
+    # Use a numpy object array to enable assignment in the loop
+    # TODO: is it possible to just use a nested list?
+    # That is what we need for `np.block`
+    blocks = np.empty(shape, dtype=object)
+    array_chunks = tuple(np.array(c) for c in array.chunks)
+    for idx, blockindex in enumerate(np.ndindex(array.numblocks)):
+        chunkshape = tuple(c[i] for c, i in zip(array_chunks, blockindex))
+        blocks[blockindex] = np.full(chunkshape, idx)
+    which_chunk = np.block(blocks.tolist()).reshape(-1)

     raveled = labels.reshape(-1)
     # these are chunks where a label is present
@@ -229,7 +234,11 @@ def invert(x) -> tuple[np.ndarray, ...]:

     chunks_cohorts = tlz.groupby(invert, label_chunks.keys())

-    if merge:
+    # If our dataset has chunksize one along the axis,
+    # then no merging is possible.
+    single_chunks = all((ac == 1).all() for ac in array_chunks)
+
+    if merge and not single_chunks:
         # First sort by number of chunks occupied by cohort
         sorted_chunks_cohorts = dict(
             sorted(chunks_cohorts.items(), key=lambda kv: len(kv[0]), reverse=True)
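The key change above replaces the loop over `array.blocks.ravel()`, which constructs a small dask collection per block, with a plain `np.ndindex` walk over block indices whose chunk shapes are read directly from `array.chunks`. Here is a standalone sketch of the same `which_chunk` construction, with dask's `numblocks` and `chunks` attributes swapped for hand-written values so it runs without dask:

import numpy as np

# Stand-ins for dask's array.numblocks and array.chunks:
# a 4x4 array split into 2x2 blocks with chunks ((2, 2), (3, 1)).
numblocks = (2, 2)
array_chunks = (np.array([2, 2]), np.array([3, 1]))

blocks = np.empty(numblocks, dtype=object)
for idx, blockindex in enumerate(np.ndindex(numblocks)):
    # Shape of this block, looked up from the per-axis chunk sizes.
    chunkshape = tuple(c[i] for c, i in zip(array_chunks, blockindex))
    # A constant array holding the block's flat index ("chunk number").
    blocks[blockindex] = np.full(chunkshape, idx)

# np.block stitches the constant blocks back into the full array shape.
which_chunk = np.block(blocks.tolist()).reshape(-1)
print(which_chunk.reshape(4, 4))
# [[0 0 0 1]
#  [0 0 0 1]
#  [2 2 2 3]
#  [2 2 2 3]]

The new `single_chunks` flag then encodes the comment in the diff: when every chunk has size one along the grouped axes, as in the ERA5 benchmark's `chunks=(1, -1, -1)`, no merging is possible, so the comparatively expensive merge loop is skipped.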
