Skip to content

Commit e13428b

Browse files
ENH: Enabled skipna argument on groupby reduction ops (pandas-dev#15675)
Added a skipna argument to the groupby reduction ops: sum, prod, min, max, mean, median, var, std and sem. Added relevant tests. Updated whatsnew to reflect changes. Co-authored-by: Tiago Firmino <[email protected]>
1 parent b162331 commit e13428b

File tree

11 files changed

+428
-109
lines changed

11 files changed

+428
-109
lines changed

doc/source/whatsnew/v3.0.0.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,13 +39,13 @@ Other enhancements
3939
- Users can globally disable any ``PerformanceWarning`` by setting the option ``mode.performance_warnings`` to ``False`` (:issue:`56920`)
4040
- :meth:`Styler.format_index_names` can now be used to format the index and column names (:issue:`48936` and :issue:`47489`)
4141
- :class:`.errors.DtypeWarning` improved to include column names when mixed data types are detected (:issue:`58174`)
42+
- :meth:`.DataFrameGroupBy.sum`, :meth:`.DataFrameGroupBy.prod`, :meth:`.DataFrameGroupBy.min`, :meth:`.DataFrameGroupBy.max`, :meth:`.DataFrameGroupBy.mean`, :meth:`.DataFrameGroupBy.median`, :meth:`.DataFrameGroupBy.sem`, :meth:`.DataFrameGroupBy.std` and :meth:`.DataFrameGroupBy.var` now accept a skipna argument. (:issue:`15675`)
4243
- :meth:`DataFrame.corrwith` now accepts ``min_periods`` as optional arguments, as in :meth:`DataFrame.corr` and :meth:`Series.corr` (:issue:`9490`)
4344
- :meth:`DataFrame.cummin`, :meth:`DataFrame.cummax`, :meth:`DataFrame.cumprod` and :meth:`DataFrame.cumsum` methods now have a ``numeric_only`` parameter (:issue:`53072`)
4445
- :meth:`DataFrame.fillna` and :meth:`Series.fillna` can now accept ``value=None``; for non-object dtype the corresponding NA value will be used (:issue:`57723`)
4546
- :meth:`Series.cummin` and :meth:`Series.cummax` now supports :class:`CategoricalDtype` (:issue:`52335`)
4647
- :meth:`Series.plot` now correctly handle the ``ylabel`` parameter for pie charts, allowing for explicit control over the y-axis label (:issue:`58239`)
4748
- Support reading Stata 110-format (Stata 7) dta files (:issue:`47176`)
48-
-
4949

5050
.. ---------------------------------------------------------------------------
5151
.. _whatsnew_300.notable_bug_fixes:

pandas/_libs/groupby.pyx

Lines changed: 73 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,8 @@ cdef float64_t median_linear_mask(float64_t* a, int n, uint8_t* mask) noexcept n
104104
cdef float64_t median_linear(
105105
float64_t* a,
106106
int n,
107-
bint is_datetimelike=False
107+
bint is_datetimelike=False,
108+
bint skipna=True
108109
) noexcept nogil:
109110
cdef:
110111
int i, j, na_count = 0
@@ -118,10 +119,14 @@ cdef float64_t median_linear(
118119
if is_datetimelike:
119120
for i in range(n):
120121
if a[i] == NPY_NAT:
122+
if not skipna:
123+
return NaN
121124
na_count += 1
122125
else:
123126
for i in range(n):
124127
if a[i] != a[i]:
128+
if not skipna:
129+
return NaN
125130
na_count += 1
126131

127132
if na_count:
@@ -186,6 +191,7 @@ def group_median_float64(
186191
const uint8_t[:, :] mask=None,
187192
uint8_t[:, ::1] result_mask=None,
188193
bint is_datetimelike=False,
194+
bint skipna=True,
189195
) -> None:
190196
"""
191197
Only aggregates on axis=0
@@ -244,7 +250,7 @@ def group_median_float64(
244250
ptr += _counts[0]
245251
for j in range(ngroups):
246252
size = _counts[j + 1]
247-
out[j, i] = median_linear(ptr, size, is_datetimelike)
253+
out[j, i] = median_linear(ptr, size, is_datetimelike, skipna)
248254
ptr += size
249255

250256

@@ -694,6 +700,7 @@ def group_sum(
694700
uint8_t[:, ::1] result_mask=None,
695701
Py_ssize_t min_count=0,
696702
bint is_datetimelike=False,
703+
bint skipna=True,
697704
) -> None:
698705
"""
699706
Only aggregates on axis=0 using Kahan summation
@@ -733,32 +740,39 @@ def group_sum(
733740
else:
734741
isna_entry = _treat_as_na(val, is_datetimelike)
735742

736-
if not isna_entry:
737-
nobs[lab, j] += 1
743+
if isna_entry:
744+
if skipna:
745+
continue
746+
else:
747+
sumx[lab, j] = val
748+
compensation[lab, j] = 0
749+
break
738750

739-
if sum_t is object:
740-
# NB: this does not use 'compensation' like the non-object
741-
# track does.
742-
if nobs[lab, j] == 1:
743-
# i.e. we haven't added anything yet; avoid TypeError
744-
# if e.g. val is a str and sumx[lab, j] is 0
745-
t = val
746-
else:
747-
t = sumx[lab, j] + val
748-
sumx[lab, j] = t
751+
nobs[lab, j] += 1
749752

753+
if sum_t is object:
754+
# NB: this does not use 'compensation' like the non-object
755+
# track does.
756+
if nobs[lab, j] == 1:
757+
# i.e. we haven't added anything yet; avoid TypeError
758+
# if e.g. val is a str and sumx[lab, j] is 0
759+
t = val
750760
else:
751-
y = val - compensation[lab, j]
752-
t = sumx[lab, j] + y
753-
compensation[lab, j] = t - sumx[lab, j] - y
754-
if compensation[lab, j] != compensation[lab, j]:
755-
# GH#53606
756-
# If val is +/- infinity compensation is NaN
757-
# which would lead to results being NaN instead
758-
# of +/- infinity. We cannot use util.is_nan
759-
# because of no gil
760-
compensation[lab, j] = 0
761-
sumx[lab, j] = t
761+
t = sumx[lab, j] + val
762+
sumx[lab, j] = t
763+
764+
else:
765+
y = val - compensation[lab, j]
766+
t = sumx[lab, j] + y
767+
compensation[lab, j] = t - sumx[lab, j] - y
768+
if compensation[lab, j] != compensation[lab, j]:
769+
# GH#53606
770+
# If val is +/- infinity compensation is NaN
771+
# which would lead to results being NaN instead
772+
# of +/- infinity. We cannot use util.is_nan
773+
# because of no gil
774+
compensation[lab, j] = 0
775+
sumx[lab, j] = t
762776

763777
_check_below_mincount(
764778
out, uses_mask, result_mask, ncounts, K, nobs, min_count, sumx
@@ -775,6 +789,7 @@ def group_prod(
775789
const uint8_t[:, ::1] mask,
776790
uint8_t[:, ::1] result_mask=None,
777791
Py_ssize_t min_count=0,
792+
bint skipna=True,
778793
) -> None:
779794
"""
780795
Only aggregates on axis=0
@@ -813,6 +828,10 @@ def group_prod(
813828
if not isna_entry:
814829
nobs[lab, j] += 1
815830
prodx[lab, j] *= val
831+
elif not skipna:
832+
prodx[lab, j] = val
833+
nobs[lab, j] = 0
834+
break
816835

817836
_check_below_mincount(
818837
out, uses_mask, result_mask, ncounts, K, nobs, min_count, prodx
@@ -832,6 +851,7 @@ def group_var(
832851
const uint8_t[:, ::1] mask=None,
833852
uint8_t[:, ::1] result_mask=None,
834853
bint is_datetimelike=False,
854+
bint skipna=True,
835855
str name="var",
836856
) -> None:
837857
cdef:
@@ -877,7 +897,12 @@ def group_var(
877897
else:
878898
isna_entry = _treat_as_na(val, is_datetimelike)
879899

880-
if not isna_entry:
900+
if not skipna and isna_entry:
901+
out[lab, j] = val
902+
nobs[lab, j] = 0
903+
break
904+
905+
elif not isna_entry:
881906
nobs[lab, j] += 1
882907
oldmean = mean[lab, j]
883908
mean[lab, j] += (val - oldmean) / nobs[lab, j]
@@ -998,6 +1023,7 @@ def group_mean(
9981023
const intp_t[::1] labels,
9991024
Py_ssize_t min_count=-1,
10001025
bint is_datetimelike=False,
1026+
bint skipna=True,
10011027
const uint8_t[:, ::1] mask=None,
10021028
uint8_t[:, ::1] result_mask=None,
10031029
) -> None:
@@ -1021,6 +1047,8 @@ def group_mean(
10211047
Only used in sum and prod. Always -1.
10221048
is_datetimelike : bool
10231049
True if `values` contains datetime-like entries.
1050+
skipna : bool, default True
1051+
Exclude NA/null values when computing the result.
10241052
mask : ndarray[bool, ndim=2], optional
10251053
Mask of the input values.
10261054
result_mask : ndarray[bool, ndim=2], optional
@@ -1078,7 +1106,12 @@ def group_mean(
10781106
else:
10791107
isna_entry = _treat_as_na(val, is_datetimelike)
10801108

1081-
if not isna_entry:
1109+
if not skipna and isna_entry:
1110+
sumx[lab, j] = nan_val
1111+
nobs[lab, j] = 0
1112+
break
1113+
1114+
elif not isna_entry:
10821115
nobs[lab, j] += 1
10831116
y = val - compensation[lab, j]
10841117
t = sumx[lab, j] + y
@@ -1096,12 +1129,10 @@ def group_mean(
10961129
for j in range(K):
10971130
count = nobs[i, j]
10981131
if nobs[i, j] == 0:
1099-
11001132
if uses_mask:
11011133
result_mask[i, j] = True
11021134
else:
11031135
out[i, j] = nan_val
1104-
11051136
else:
11061137
out[i, j] = sumx[i, j] / count
11071138

@@ -1660,6 +1691,7 @@ cdef group_min_max(
16601691
Py_ssize_t min_count=-1,
16611692
bint is_datetimelike=False,
16621693
bint compute_max=True,
1694+
bint skipna=True,
16631695
const uint8_t[:, ::1] mask=None,
16641696
uint8_t[:, ::1] result_mask=None,
16651697
):
@@ -1683,6 +1715,8 @@ cdef group_min_max(
16831715
True if `values` contains datetime-like entries.
16841716
compute_max : bint, default True
16851717
True to compute group-wise max, False to compute min
1718+
skipna : bool, default True
1719+
Exclude NA/null values when computing the result.
16861720
mask : ndarray[bool, ndim=2], optional
16871721
If not None, indices represent missing values,
16881722
otherwise the mask will not be used
@@ -1729,7 +1763,12 @@ cdef group_min_max(
17291763
else:
17301764
isna_entry = _treat_as_na(val, is_datetimelike)
17311765

1732-
if not isna_entry:
1766+
if not skipna and isna_entry:
1767+
group_min_or_max[lab, j] = val
1768+
nobs[lab, j] = 0
1769+
break
1770+
1771+
elif not isna_entry:
17331772
nobs[lab, j] += 1
17341773
if compute_max:
17351774
if val > group_min_or_max[lab, j]:
@@ -1866,6 +1905,7 @@ def group_max(
18661905
const intp_t[::1] labels,
18671906
Py_ssize_t min_count=-1,
18681907
bint is_datetimelike=False,
1908+
bint skipna=True,
18691909
const uint8_t[:, ::1] mask=None,
18701910
uint8_t[:, ::1] result_mask=None,
18711911
) -> None:
@@ -1880,6 +1920,7 @@ def group_max(
18801920
compute_max=True,
18811921
mask=mask,
18821922
result_mask=result_mask,
1923+
skipna=skipna,
18831924
)
18841925

18851926

@@ -1892,6 +1933,7 @@ def group_min(
18921933
const intp_t[::1] labels,
18931934
Py_ssize_t min_count=-1,
18941935
bint is_datetimelike=False,
1936+
bint skipna=True,
18951937
const uint8_t[:, ::1] mask=None,
18961938
uint8_t[:, ::1] result_mask=None,
18971939
) -> None:
@@ -1906,6 +1948,7 @@ def group_min(
19061948
compute_max=False,
19071949
mask=mask,
19081950
result_mask=result_mask,
1951+
skipna=skipna,
19091952
)
19101953

19111954

pandas/core/_numba/executor.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -69,13 +69,20 @@ def column_looper(
6969
labels: np.ndarray,
7070
ngroups: int,
7171
min_periods: int,
72+
skipna: bool = True,
7273
*args,
7374
):
7475
result = np.empty((values.shape[0], ngroups), dtype=result_dtype)
7576
na_positions = {}
7677
for i in numba.prange(values.shape[0]):
7778
output, na_pos = func(
78-
values[i], result_dtype, labels, ngroups, min_periods, *args
79+
values[i],
80+
result_dtype,
81+
labels,
82+
ngroups,
83+
min_periods,
84+
*args,
85+
skipna,
7986
)
8087
result[i] = output
8188
if len(na_pos) > 0:
@@ -162,6 +169,7 @@ def generate_shared_aggregator(
162169
nopython: bool,
163170
nogil: bool,
164171
parallel: bool,
172+
skipna: bool = True,
165173
):
166174
"""
167175
Generate a Numba function that loops over the columns 2D object and applies
@@ -190,7 +198,6 @@ def generate_shared_aggregator(
190198
-------
191199
Numba function
192200
"""
193-
194201
# A wrapper around the looper function,
195202
# to dispatch based on dtype since numba is unable to do that in nopython mode
196203

@@ -214,11 +221,11 @@ def looper_wrapper(
214221
# Need to unpack kwargs since numba only supports *args
215222
if is_grouped_kernel:
216223
result, na_positions = column_looper(
217-
values, labels, ngroups, min_periods, *kwargs.values()
224+
values, labels, ngroups, min_periods, skipna, *kwargs.values()
218225
)
219226
else:
220227
result, na_positions = column_looper(
221-
values, start, end, min_periods, *kwargs.values()
228+
values, start, end, min_periods, skipna, *kwargs.values()
222229
)
223230
if result.dtype.kind == "i":
224231
# Look if na_positions is not empty

pandas/core/_numba/kernels/mean_.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -169,9 +169,10 @@ def grouped_mean(
169169
labels: npt.NDArray[np.intp],
170170
ngroups: int,
171171
min_periods: int,
172+
skipna: bool = True,
172173
) -> tuple[np.ndarray, list[int]]:
173174
output, nobs_arr, comp_arr, consecutive_counts, prev_vals = grouped_kahan_sum(
174-
values, result_dtype, labels, ngroups
175+
values, result_dtype, labels, ngroups, skipna
175176
)
176177

177178
# Post-processing, replace sums that don't satisfy min_periods
@@ -187,7 +188,8 @@ def grouped_mean(
187188
result = sum_x
188189
else:
189190
result = np.nan
190-
result /= nobs
191+
if nobs != 0:
192+
result /= nobs
191193
output[lab] = result
192194

193195
# na_position is empty list since float64 can already hold nans

pandas/core/_numba/kernels/min_max_.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ def grouped_min_max(
8888
ngroups: int,
8989
min_periods: int,
9090
is_max: bool,
91+
skipna: bool = True,
9192
) -> tuple[np.ndarray, list[int]]:
9293
N = len(labels)
9394
nobs = np.zeros(ngroups, dtype=np.int64)
@@ -102,6 +103,9 @@ def grouped_min_max(
102103

103104
if values.dtype.kind == "i" or not np.isnan(val):
104105
nobs[lab] += 1
106+
elif not skipna and np.isnan(val):
107+
output[lab] = np.nan
108+
continue
105109
else:
106110
# NaN value cannot be a min/max value
107111
continue

0 commit comments

Comments
 (0)