Skip to content

Commit 66f152b

Browse files
Illviljan, pre-commit-ci[bot], and dcherian
authored
typing fixes (#235)
* Update core.py * Update xarray.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * avoid renaming * Update xarray.py * Update xarray.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update xarray.py * Update xarray.py * Update xarray.py * Update xarray.py * Update xarray.py * split to optional * Update xarray.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update xarray.py * convert to pd.Index instead of ndarray * Handled different slicer types? * not supported instead? * specify type for simple_combine * Handle None in agg.min_count * Update core.py * Update core.py * Update core.py * Update core.py * Update core.py * Update core.py * Update core.py * Update core.py * Update core.py * Update core.py * Update core.py * Update core.py * Update core.py * Update core.py * Update core.py * add overloads and rename * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * more overloads * ignore * Update core.py * Update xarray.py * Update core.py * Update core.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update core.py * Update core.py * Update core.py * Update core.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update core.py * Update core.py * Update core.py * Update core.py * Update core.py * Update flox/core.py * Update flox/core.py * Update core.py * Update core.py * Update core.py * Update core.py * Update core.py * Update core.py * Update core.py * Update core.py * Update core.py * Update core.py * Update core.py * Update core.py * Update core.py * Update core.py * Update core.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update xarray.py * Have to add 
another type here because of xarray not supporting IntervalIndex * Update xarray.py * test ex instead of e * Revert "test ex instead of e" This reverts commit 8e55d3a. * check reveal_type * without e * try no redefinition * IF redefining ex, mypy always takes the first definition of ex. even if it has been narrowed down. * test min_count=0 * test min_count=0 * test min_count=0 * test min_count=0 * test min_count = 0 * test min_count=0 * test min_count=0 * test min_count=0 * test min_count=0 * test min_count=0 * test min_count=0 * test min_count=0 * test min_count=0 * Update asv_bench/benchmarks/combine.py --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Deepak Cherian <[email protected]>
1 parent 6cf315a commit 66f152b

File tree

4 files changed

+105
-59
lines changed

4 files changed

+105
-59
lines changed

asv_bench/benchmarks/combine.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from functools import partial
2+
from typing import Any
23

34
import numpy as np
45

@@ -43,8 +44,8 @@ class Combine1d(Combine):
4344
this is for reducting along a single dimension
4445
"""
4546

46-
def setup(self, *args, **kwargs):
47-
def construct_member(groups):
47+
def setup(self, *args, **kwargs) -> None:
48+
def construct_member(groups) -> dict[str, Any]:
4849
return {
4950
"groups": groups,
5051
"intermediates": [

flox/aggregations.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ def __init__(
185185
# how to aggregate results after first round of reduction
186186
self.combine: FuncTuple = _atleast_1d(combine)
187187
# simpler reductions used with the "simple combine" algorithm
188-
self.simple_combine = None
188+
self.simple_combine: tuple[Callable, ...] = ()
189189
# final aggregation
190190
self.aggregate: Callable | str = aggregate if aggregate else self.combine[0]
191191
# finalize results (see mean)
@@ -207,7 +207,7 @@ def __init__(
207207

208208
# The following are set by _initialize_aggregation
209209
self.finalize_kwargs: dict[Any, Any] = {}
210-
self.min_count: int | None = None
210+
self.min_count: int = 0
211211

212212
def _normalize_dtype_fill_value(self, value, name):
213213
value = _atleast_1d(value)
@@ -504,7 +504,7 @@ def _initialize_aggregation(
504504
dtype,
505505
array_dtype,
506506
fill_value,
507-
min_count: int | None,
507+
min_count: int,
508508
finalize_kwargs: dict[Any, Any] | None,
509509
) -> Aggregation:
510510
if not isinstance(func, Aggregation):
@@ -559,9 +559,6 @@ def _initialize_aggregation(
559559
assert isinstance(finalize_kwargs, dict)
560560
agg.finalize_kwargs = finalize_kwargs
561561

562-
if min_count is None:
563-
min_count = 0
564-
565562
# This is needed for the dask pathway.
566563
# Because we use intermediate fill_value since a group could be
567564
# absent in one block, but present in another block
@@ -579,7 +576,7 @@ def _initialize_aggregation(
579576
else:
580577
agg.min_count = 0
581578

582-
simple_combine = []
579+
simple_combine: list[Callable] = []
583580
for combine in agg.combine:
584581
if isinstance(combine, str):
585582
if combine in ["nanfirst", "nanlast"]:

flox/core.py

Lines changed: 69 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,15 @@
5151
T_DuckArray = Union[np.ndarray, DaskArray] # Any ?
5252
T_By = T_DuckArray
5353
T_Bys = tuple[T_By, ...]
54-
T_ExpectIndex = Union[pd.Index, None]
55-
T_Expect = Union[Sequence, np.ndarray, T_ExpectIndex]
54+
T_ExpectIndex = Union[pd.Index]
5655
T_ExpectIndexTuple = tuple[T_ExpectIndex, ...]
56+
T_ExpectIndexOpt = Union[T_ExpectIndex, None]
57+
T_ExpectIndexOptTuple = tuple[T_ExpectIndexOpt, ...]
58+
T_Expect = Union[Sequence, np.ndarray, T_ExpectIndex]
5759
T_ExpectTuple = tuple[T_Expect, ...]
58-
T_ExpectedGroups = Union[T_Expect, T_ExpectTuple]
60+
T_ExpectOpt = Union[Sequence, np.ndarray, T_ExpectIndexOpt]
61+
T_ExpectOptTuple = tuple[T_ExpectOpt, ...]
62+
T_ExpectedGroups = Union[T_Expect, T_ExpectOptTuple]
5963
T_ExpectedGroupsOpt = Union[T_ExpectedGroups, None]
6064
T_Func = Union[str, Callable]
6165
T_Funcs = Union[T_Func, Sequence[T_Func]]
@@ -98,7 +102,7 @@ def _is_first_last_reduction(func: T_Agg) -> bool:
98102
return isinstance(func, str) and func in ["nanfirst", "nanlast", "first", "last"]
99103

100104

101-
def _get_expected_groups(by: T_By, sort: bool) -> pd.Index:
105+
def _get_expected_groups(by: T_By, sort: bool) -> T_ExpectIndex:
102106
if is_duck_dask_array(by):
103107
raise ValueError("Please provide expected_groups if not grouping by a numpy array.")
104108
flatby = by.reshape(-1)
@@ -219,8 +223,13 @@ def find_group_cohorts(labels, chunks, merge: bool = True) -> dict:
219223
raveled = labels.reshape(-1)
220224
# these are chunks where a label is present
221225
label_chunks = pd.Series(which_chunk).groupby(raveled).unique()
226+
222227
# These invert the label_chunks mapping so we know which labels occur together.
223-
chunks_cohorts = tlz.groupby(lambda x: tuple(label_chunks.get(x)), label_chunks.keys())
228+
def invert(x) -> tuple[np.ndarray, ...]:
229+
arr = label_chunks.get(x)
230+
return tuple(arr) # type: ignore [arg-type] # pandas issue?
231+
232+
chunks_cohorts = tlz.groupby(invert, label_chunks.keys())
224233

225234
if merge:
226235
# First sort by number of chunks occupied by cohort
@@ -459,7 +468,7 @@ def factorize_(
459468
axes: T_Axes,
460469
*,
461470
fastpath: Literal[True],
462-
expected_groups: tuple[pd.Index, ...] | None = None,
471+
expected_groups: T_ExpectIndexOptTuple | None = None,
463472
reindex: bool = False,
464473
sort: bool = True,
465474
) -> tuple[np.ndarray, tuple[np.ndarray, ...], tuple[int, ...], int, int, None]:
@@ -471,7 +480,7 @@ def factorize_(
471480
by: T_Bys,
472481
axes: T_Axes,
473482
*,
474-
expected_groups: tuple[pd.Index, ...] | None = None,
483+
expected_groups: T_ExpectIndexOptTuple | None = None,
475484
reindex: bool = False,
476485
sort: bool = True,
477486
fastpath: Literal[False] = False,
@@ -484,7 +493,7 @@ def factorize_(
484493
by: T_Bys,
485494
axes: T_Axes,
486495
*,
487-
expected_groups: tuple[pd.Index, ...] | None = None,
496+
expected_groups: T_ExpectIndexOptTuple | None = None,
488497
reindex: bool = False,
489498
sort: bool = True,
490499
fastpath: bool = False,
@@ -496,7 +505,7 @@ def factorize_(
496505
by: T_Bys,
497506
axes: T_Axes,
498507
*,
499-
expected_groups: tuple[pd.Index, ...] | None = None,
508+
expected_groups: T_ExpectIndexOptTuple | None = None,
500509
reindex: bool = False,
501510
sort: bool = True,
502511
fastpath: bool = False,
@@ -546,7 +555,7 @@ def factorize_(
546555
else:
547556
idx = np.zeros_like(flat, dtype=np.intp) - 1
548557

549-
found_groups.append(expect)
558+
found_groups.append(np.array(expect))
550559
else:
551560
if expect is not None and reindex:
552561
sorter = np.argsort(expect)
@@ -560,7 +569,7 @@ def factorize_(
560569
idx = sorter[(idx,)]
561570
idx[mask] = -1
562571
else:
563-
idx, groups = pd.factorize(flat, sort=sort)
572+
idx, groups = pd.factorize(flat, sort=sort) # type: ignore # pandas issue?
564573

565574
found_groups.append(np.array(groups))
566575
factorized.append(idx.reshape(groupvar.shape))
@@ -853,7 +862,8 @@ def _finalize_results(
853862
"""
854863
squeezed = _squeeze_results(results, axis)
855864

856-
if agg.min_count > 0:
865+
min_count = agg.min_count
866+
if min_count > 0:
857867
counts = squeezed["intermediates"][-1]
858868
squeezed["intermediates"] = squeezed["intermediates"][:-1]
859869

@@ -864,8 +874,8 @@ def _finalize_results(
864874
else:
865875
finalized[agg.name] = agg.finalize(*squeezed["intermediates"], **agg.finalize_kwargs)
866876

867-
if agg.min_count > 0:
868-
count_mask = counts < agg.min_count
877+
if min_count > 0:
878+
count_mask = counts < min_count
869879
if count_mask.any():
870880
# For one count_mask.any() prevents promoting bool to dtype(fill_value) unless
871881
# necessary
@@ -1283,7 +1293,7 @@ def dask_groupby_agg(
12831293
array: DaskArray,
12841294
by: T_By,
12851295
agg: Aggregation,
1286-
expected_groups: pd.Index | None,
1296+
expected_groups: T_ExpectIndexOpt,
12871297
axis: T_Axes = (),
12881298
fill_value: Any = None,
12891299
method: T_Method = "map-reduce",
@@ -1423,9 +1433,11 @@ def dask_groupby_agg(
14231433
group_chunks = ((np.nan,),)
14241434
else:
14251435
if expected_groups is None:
1426-
expected_groups = _get_expected_groups(by_input, sort=sort)
1427-
groups = (expected_groups.to_numpy(),)
1428-
group_chunks = ((len(expected_groups),),)
1436+
expected_groups_ = _get_expected_groups(by_input, sort=sort)
1437+
else:
1438+
expected_groups_ = expected_groups
1439+
groups = (expected_groups_.to_numpy(),)
1440+
group_chunks = ((len(expected_groups_),),)
14291441

14301442
elif method == "cohorts":
14311443
chunks_cohorts = find_group_cohorts(
@@ -1569,7 +1581,7 @@ def _validate_reindex(
15691581
return reindex
15701582

15711583

1572-
def _assert_by_is_aligned(shape: tuple[int, ...], by: T_Bys):
1584+
def _assert_by_is_aligned(shape: tuple[int, ...], by: T_Bys) -> None:
15731585
assert all(b.ndim == by[0].ndim for b in by[1:])
15741586
for idx, b in enumerate(by):
15751587
if not all(j in [i, 1] for i, j in zip(shape[-b.ndim :], b.shape)):
@@ -1584,18 +1596,33 @@ def _assert_by_is_aligned(shape: tuple[int, ...], by: T_Bys):
15841596
)
15851597

15861598

1599+
@overload
1600+
def _convert_expected_groups_to_index(
1601+
expected_groups: tuple[None, ...], isbin: Sequence[bool], sort: bool
1602+
) -> tuple[None, ...]:
1603+
...
1604+
1605+
1606+
@overload
15871607
def _convert_expected_groups_to_index(
15881608
expected_groups: T_ExpectTuple, isbin: Sequence[bool], sort: bool
15891609
) -> T_ExpectIndexTuple:
1590-
out: list[pd.Index | None] = []
1610+
...
1611+
1612+
1613+
def _convert_expected_groups_to_index(
1614+
expected_groups: T_ExpectOptTuple, isbin: Sequence[bool], sort: bool
1615+
) -> T_ExpectIndexOptTuple:
1616+
out: list[T_ExpectIndexOpt] = []
15911617
for ex, isbin_ in zip(expected_groups, isbin):
15921618
if isinstance(ex, pd.IntervalIndex) or (isinstance(ex, pd.Index) and not isbin_):
15931619
if sort:
1594-
ex = ex.sort_values()
1595-
out.append(ex)
1620+
out.append(ex.sort_values())
1621+
else:
1622+
out.append(ex)
15961623
elif ex is not None:
15971624
if isbin_:
1598-
out.append(pd.IntervalIndex.from_breaks(ex))
1625+
out.append(pd.IntervalIndex.from_breaks(ex)) # type: ignore [arg-type] # TODO: what do we want here?
15991626
else:
16001627
if sort:
16011628
ex = np.sort(ex)
@@ -1613,7 +1640,7 @@ def _lazy_factorize_wrapper(*by: T_By, **kwargs) -> np.ndarray:
16131640

16141641
def _factorize_multiple(
16151642
by: T_Bys,
1616-
expected_groups: T_ExpectIndexTuple,
1643+
expected_groups: T_ExpectIndexOptTuple,
16171644
any_by_dask: bool,
16181645
reindex: bool,
16191646
sort: bool = True,
@@ -1668,7 +1695,17 @@ def _factorize_multiple(
16681695
return (group_idx,), found_groups, grp_shape
16691696

16701697

1671-
def _validate_expected_groups(nby: int, expected_groups: T_ExpectedGroupsOpt) -> T_ExpectTuple:
1698+
@overload
1699+
def _validate_expected_groups(nby: int, expected_groups: None) -> tuple[None, ...]:
1700+
...
1701+
1702+
1703+
@overload
1704+
def _validate_expected_groups(nby: int, expected_groups: T_ExpectedGroups) -> T_ExpectTuple:
1705+
...
1706+
1707+
1708+
def _validate_expected_groups(nby: int, expected_groups: T_ExpectedGroupsOpt) -> T_ExpectOptTuple:
16721709
if expected_groups is None:
16731710
return (None,) * nby
16741711

@@ -1935,21 +1972,20 @@ def groupby_reduce(
19351972
# Consider np.sum([np.nan]) = np.nan, np.nansum([np.nan]) = 0
19361973
if min_count is None:
19371974
if nax < by_.ndim or fill_value is not None:
1938-
min_count = 1
1975+
min_count_: int = 1
1976+
else:
1977+
min_count_ = 0
1978+
else:
1979+
min_count_ = min_count
19391980

19401981
# TODO: set in xarray?
1941-
if (
1942-
min_count is not None
1943-
and min_count > 0
1944-
and func in ["nansum", "nanprod"]
1945-
and fill_value is None
1946-
):
1982+
if min_count_ > 0 and func in ["nansum", "nanprod"] and fill_value is None:
19471983
# nansum, nanprod have fill_value=0, 1
19481984
# overwrite than when min_count is set
19491985
fill_value = np.nan
19501986

19511987
kwargs = dict(axis=axis_, fill_value=fill_value, engine=engine)
1952-
agg = _initialize_aggregation(func, dtype, array.dtype, fill_value, min_count, finalize_kwargs)
1988+
agg = _initialize_aggregation(func, dtype, array.dtype, fill_value, min_count_, finalize_kwargs)
19531989

19541990
groups: tuple[np.ndarray | DaskArray, ...]
19551991
if not has_dask:

0 commit comments

Comments (0)