import dask
import dask.array  # `import dask` alone does not load the array submodule needed for dask.array.* below
import numpy as np
import pandas as pd

import flox


class Cohorts:
    """Time the cohort-detection and graph-construction steps of the "cohorts" method."""

    def setup(self, *args, **kwargs):
        # Subclasses must define self.array, self.by, and self.axis.
        raise NotImplementedError

    def time_find_group_cohorts(self):
        flox.core.find_group_cohorts(self.by, self.array.chunks)
        # Clear the memoized result so every benchmark iteration does real work.
        flox.cache.cache.clear()

    def time_graph_construct(self):
        # groupby_reduce on a dask array is lazy, so this times graph construction only.
        flox.groupby_reduce(self.array, self.by, func="sum", axis=self.axis, method="cohorts")


class NWMMidwest(Cohorts):
    """2D labels, irregular w.r.t. chunk size.
    Mimics National Weather Model, Midwest county groupby."""

    def setup(self, *args, **kwargs):
        x = np.repeat(np.arange(30), 150)
        y = np.repeat(np.arange(30), 60)
        # (1800, 4500) integer labels whose repeats do not line up with the chunk boundaries.
        self.by = x[np.newaxis, :] * y[:, np.newaxis]

        self.array = dask.array.ones(self.by.shape, chunks=(350, 350))
        self.axis = (-2, -1)


class ERA5(Cohorts):
    """ERA5: hourly data on a 721 x 1440 grid, grouped by day of year."""

    def setup(self, *args, **kwargs):
        time = pd.Series(pd.date_range("2016-01-01", "2018-12-31 23:59", freq="H"))

        self.by = time.dt.dayofyear.values
        self.axis = (-1,)

        array = dask.array.random.random((721, 1440, len(time)), chunks=(-1, -1, 48))
        # Rechunk along time so each day-of-year cohort starts at a chunk boundary (new chunk forced at label 1).
        self.array = flox.core.rechunk_for_cohorts(
            array, -1, self.by, force_new_chunk_at=[1], chunksize=48, ignore_old_chunks=True
        )
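These classes appear to follow the airspeed velocity (asv) convention of a setup hook plus time_* methods, so a benchmark runner would instantiate and time them automatically. As a quick sanity check outside such a runner, here is a minimal sketch of driving one benchmark by hand; the class and method names come from the file above, while the timing and printing scaffolding is illustrative and not part of the benchmark suite.

# Illustrative only: drive one benchmark class by hand (assumes flox, dask, numpy, pandas installed).
from time import perf_counter

bench = NWMMidwest()
bench.setup()

start = perf_counter()
bench.time_graph_construct()     # lazy: builds the dask graph, no compute
print(f"graph construction: {perf_counter() - start:.3f} s")

start = perf_counter()
bench.time_find_group_cohorts()  # detects cohorts, then clears flox's cache
print(f"find_group_cohorts: {perf_counter() - start:.3f} s")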