Skip to content

Commit d73eb49

Browse files
committed
Add cohorts benchmark
1 parent a46078d commit d73eb49

File tree

1 file changed

+47
-0
lines changed

1 file changed

+47
-0
lines changed

asv_bench/benchmarks/cohorts.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
import dask
2+
import numpy as np
3+
import pandas as pd
4+
5+
import flox
6+
7+
8+
class Cohorts:
    """Base class timing the core "cohorts" machinery.

    Subclasses must define ``setup`` and populate ``self.array`` (a dask
    array), ``self.by`` (group labels), and ``self.axis``.
    """

    def setup(self, *args, **kwargs):
        # Abstract: concrete benchmark classes build their data here.
        raise NotImplementedError

    def time_find_group_cohorts(self):
        """Time detecting cohorts from the labels and chunk structure."""
        labels, chunks = self.by, self.array.chunks
        flox.core.find_group_cohorts(labels, chunks)
        # find_group_cohorts results are cached; clear the cache so every
        # benchmark repeat does real work instead of a cache hit.
        flox.cache.cache.clear()

    def time_graph_construct(self):
        """Time constructing the task graph for a cohorts-method reduction."""
        flox.groupby_reduce(self.array, self.by, func="sum", axis=self.axis, method="cohorts")
21+
22+
class NWMMidwest(Cohorts):
    """2D labels, irregular w.r.t chunk size.

    Mimics National Weather Model, Midwest county groupby."""

    def setup(self, *args, **kwargs):
        # Build 2D integer labels as the outer product of two repeated
        # ranges; label regions deliberately do not align with the chunks.
        row_labels = np.repeat(np.arange(30), 60)  # length 1800
        col_labels = np.repeat(np.arange(30), 150)  # length 4500
        self.by = np.outer(row_labels, col_labels)

        self.array = dask.array.ones(self.by.shape, chunks=(350, 350))
        self.axis = (-2, -1)
33+
34+
35+
class ERA5(Cohorts):
    """ERA5-like hourly time axis, grouped by day of year."""

    def setup(self, *args, **kwargs):
        timestamps = pd.Series(pd.date_range("2016-01-01", "2018-12-31 23:59", freq="H"))

        self.by = timestamps.dt.dayofyear.values
        self.axis = (-1,)

        raw = dask.array.random.random((721, 1440, len(timestamps)), chunks=(-1, -1, 48))
        # Rechunk along time so that each day-of-year cohort occupies
        # whole chunks rather than straddling chunk boundaries.
        self.array = flox.core.rechunk_for_cohorts(
            raw, -1, self.by, force_new_chunk_at=[1], chunksize=48, ignore_old_chunks=True
        )

0 commit comments

Comments
 (0)