Skip to content

Commit adb8202

Browse files
committed
WIP
1 parent 1eae283 commit adb8202

File tree

1 file changed

+79
-17
lines changed

1 file changed

+79
-17
lines changed

flox/core.py

Lines changed: 79 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def _get_expected_groups(by, raise_if_dask=True) -> pd.Index | None:
6666
return None
6767
flatby = by.ravel()
6868
expected = np.unique(flatby[~isnull(flatby)])
69-
return _convert_expected_groups_to_index(expected, isbin=False)
69+
return _convert_expected_groups_to_index((expected,), isbin=(False,))[0]
7070

7171

7272
def _get_chunk_reduction(reduction_type: str) -> Callable:
@@ -385,7 +385,7 @@ def offset_labels(labels: np.ndarray, ngroups: int) -> tuple[np.ndarray, int]:
385385
return offset, size
386386

387387

388-
def factorize_(by: tuple, axis, expected_groups: tuple[pd.Index, ...] = None, reindex=False):
388+
def factorize_(by: tuple, axis, expected_groups: tuple[pd.Index, ...] = None, reindex=False, fastpath=False):
389389
if not isinstance(by, tuple):
390390
raise ValueError(f"Expected `by` to be a tuple. Received {type(by)} instead")
391391

@@ -420,10 +420,13 @@ def factorize_(by: tuple, axis, expected_groups: tuple[pd.Index, ...] = None, re
420420
grp_shape = tuple(len(grp) for grp in found_groups)
421421
ngroups = np.prod(grp_shape)
422422
if len(by) > 1:
423-
group_idx = np.ravel_multi_index(factorized, grp_shape).reshape(by[0].shape)
423+
group_idx = np.ravel_multi_index(factorized, grp_shape)
424424
else:
425425
group_idx = factorized[0]
426426

427+
if fastpath:
428+
return group_idx, found_groups, grp_shape
429+
427430
if np.isscalar(axis) and groupvar.ndim > 1:
428431
# Not reducing along all dimensions of by
429432
# this is OK because for 3D by and axis=(1,2),
@@ -1243,23 +1246,60 @@ def _assert_by_is_aligned(shape, by):
12431246
)
12441247

12451248

1246-
def _convert_expected_groups_to_index(expected_groups, isbin: bool) -> pd.Index | None:
1247-
if isinstance(expected_groups, pd.IntervalIndex) or (
1248-
isinstance(expected_groups, pd.Index) and not isbin
1249-
):
1250-
return expected_groups
1251-
if isbin:
1252-
return pd.IntervalIndex.from_arrays(expected_groups[:-1], expected_groups[1:])
1253-
elif expected_groups is not None:
1254-
return pd.Index(expected_groups)
1255-
return expected_groups
1249+
def _convert_expected_groups_to_index(expected_groups: tuple, isbin: bool) -> pd.Index | None:
1250+
out = []
1251+
for ex, isbin_ in zip(expected_groups, isbin):
1252+
if isinstance(ex, pd.IntervalIndex) or (isinstance(ex, pd.Index) and not isbin):
1253+
out.append(expected_groups)
1254+
elif ex is not None:
1255+
if isbin_:
1256+
out.append(pd.IntervalIndex.from_arrays(ex[:-1], ex[1:]))
1257+
else:
1258+
out.append(pd.Index(ex))
1259+
else:
1260+
assert ex is None
1261+
out.append(None)
1262+
return tuple(out)
1263+
1264+
1265+
def _lazy_factorize_wrapper(*by, **kwargs):
    """Block-wise wrapper around ``factorize_`` for ``dask.array.map_blocks``.

    ``map_blocks`` passes each grouping variable as a separate positional
    block; re-pack them into the tuple ``factorize_`` expects and return
    only the integer group codes for this block.
    """
    # BUGFIX: with fastpath=True (always set by _factorize_multiple),
    # factorize_ returns three values (group_idx, found_groups, grp_shape);
    # unpacking into two raised ValueError. Keep only the codes.
    group_idx, *_ = factorize_(by, **kwargs)
    return group_idx
1268+
1269+
1270+
def _factorize_multiple(by, expected_groups, by_is_dask):
    """Factorize multiple grouping variables down to one group-index array.

    Parameters
    ----------
    by : tuple of array-like
        Grouping variables (numpy or dask arrays).
    expected_groups : tuple
        One ``pd.Index`` (or ``None``) per grouping variable.
    by_is_dask : bool
        Whether any element of ``by`` is a dask array.

    Returns
    -------
    (group_idx,), final_groups, grp_shape
        Single-element tuple of flat group codes, the per-variable group
        labels, and the shape of the implied group grid.

    Raises
    ------
    ValueError
        If group labels cannot be determined for some variable (dask input
        without ``expected_groups``).
    """
    kwargs = dict(
        expected_groups=expected_groups,
        axis=None,  # always None, we offset later if necessary.
        fastpath=True,
    )
    if by_is_dask:
        import dask.array

        # Factorize lazily, one block at a time.
        group_idx = dask.array.map_blocks(
            _lazy_factorize_wrapper,
            *np.broadcast_arrays(*by),
            meta=np.array((), dtype=np.int64),
            **kwargs,
        )
        # Labels of a dask variable are unknowable without computing it.
        found_groups = tuple(None if is_duck_dask_array(b) else np.unique(b) for b in by)
    else:
        group_idx, found_groups, grp_shape = factorize_(by, **kwargs)

    # Prefer user-provided expected_groups; fall back to what we found.
    # BUGFIX: guard ``found is None`` so we never call pd.Index(None).
    final_groups = tuple(
        expect if expect is not None else (None if found is None else pd.Index(found))
        for found, expect in zip(found_groups, expected_groups)
    )

    # BUGFIX: a bare ``raise`` outside an except clause is a RuntimeError;
    # raise an informative error instead.
    if any(grp is None for grp in final_groups):
        raise ValueError("Please provide expected_groups when grouping by a dask array.")

    # BUGFIX: grp_shape was never assigned on the dask branch (NameError at
    # return); derive it from the now fully-known group labels.
    if by_is_dask:
        grp_shape = tuple(len(grp) for grp in final_groups)
    return (group_idx,), final_groups, grp_shape
12561297

12571298

12581299
def groupby_reduce(
12591300
array: np.ndarray | DaskArray,
1260-
by: np.ndarray | DaskArray,
1301+
*by: np.ndarray | DaskArray,
12611302
func: str | Aggregation,
1262-
*,
12631303
expected_groups: Sequence | np.ndarray | None = None,
12641304
sort: bool = True,
12651305
isbin: bool = False,
@@ -1373,6 +1413,7 @@ def groupby_reduce(
13731413
reindex = _validate_reindex(reindex, func, method, expected_groups)
13741414

13751415
by: tuple = tuple(np.asarray(b) if not is_duck_array(b) else b for b in by)
1416+
nby = len(by)
13761417
by_is_dask = any(is_duck_dask_array(b) for b in by)
13771418
if not is_duck_array(array):
13781419
array = np.asarray(array)
@@ -1392,6 +1433,20 @@ def groupby_reduce(
13921433
# (pd.IntervalIndex or not)
13931434
expected_groups = _convert_expected_groups_to_index(expected_groups, isbin)
13941435

1436+
# when grouping by multiple variables, we factorize early.
1437+
# TODO: could restrict this to dask-only
1438+
if nby > 1:
1439+
by, final_groups, grp_shape = _factorize_multiple(
1440+
by, expected_groups, by_is_dask=by_is_dask
1441+
)
1442+
expected_groups = (pd.RangeIndex(np.prod(grp_shape)),)
1443+
else:
1444+
final_groups = expected_groups
1445+
1446+
assert len(by) == 1
1447+
by = by[0]
1448+
expected_groups = expected_groups[0]
1449+
13951450
if axis is None:
13961451
axis = tuple(array.ndim + np.arange(-by.ndim, 0))
13971452
else:
@@ -1404,7 +1459,7 @@ def groupby_reduce(
14041459

14051460
# TODO: make sure expected_groups is unique
14061461
# TODO: hack
1407-
if len(axis) == 1 and by_ndim > 1 and expected_groups[0] is None:
1462+
if len(axis) == 1 and by.ndim > 1 and expected_groups is None:
14081463
# TODO: hack
14091464
if not by_is_dask:
14101465
expected_groups = _get_expected_groups(by)
@@ -1509,7 +1564,11 @@ def groupby_reduce(
15091564
array = rechunk_for_blockwise(array, axis=-1, labels=by)
15101565

15111566
result, *groups = partial_agg(
1512-
array, by, expected_groups=expected_groups, reindex=reindex, method=method
1567+
array,
1568+
by,
1569+
expected_groups=None if method == "blockwise" else expected_groups,
1570+
reindex=reindex,
1571+
method=method,
15131572
)
15141573

15151574
if sort and method != "map-reduce":
@@ -1518,4 +1577,7 @@ def groupby_reduce(
15181577
result = result[..., sorted_idx]
15191578
groups = (groups[0][sorted_idx],)
15201579

1580+
if nby > 1:
1581+
groups = final_groups
1582+
result = result.reshape(result.shape[:-1] + grp_shape)
15211583
return (result, *groups)

0 commit comments

Comments
 (0)