Skip to content

Commit d68300a

Browse files
committed
ENH: numerically stable roll_skew and roll_kurt
1 parent d6a7700 commit d68300a

File tree

3 files changed

+149
-140
lines changed

3 files changed

+149
-140
lines changed

pandas/algos.pyx

+83-122
Original file line numberDiff line numberDiff line change
@@ -1331,144 +1331,105 @@ def roll_var(ndarray[double_t] input, int win, int minp, int ddof=1):
13311331

13321332
return output
13331333

1334+
#----------------------------------------------------------------------
1335+
# Rolling skewness and kurtosis
13341336

1335-
#-------------------------------------------------------------------------------
1336-
# Rolling skewness
13371337
@cython.boundscheck(False)
13381338
@cython.wraparound(False)
1339-
def roll_skew(ndarray[double_t] input, int win, int minp):
1339+
def roll_higher_moment(ndarray[double_t] input, int win, int minp, bint kurt):
1340+
"""
1341+
Numerically stable implementation of skewness and kurtosis using a
1342+
Welford-like method. If `kurt` is True, rolling kurtosis is computed,
1343+
if False, rolling skewness.
1344+
"""
13401345
cdef double val, prev
1341-
cdef double x = 0, xx = 0, xxx = 0
1342-
cdef Py_ssize_t nobs = 0, i
1343-
cdef Py_ssize_t N = len(input)
1346+
cdef double mean_x = 0, s2dm_x = 0, s3dm_x = 0, s4dm_x = 0, rep = NaN
1347+
cdef double delta, delta_n, tmp
1348+
cdef Py_ssize_t i, nobs = 0, nrep = 0, N = len(input)
13441349

13451350
cdef ndarray[double_t] output = np.empty(N, dtype=float)
13461351

1347-
# 3 components of the skewness equation
1348-
cdef double A, B, C, R
1349-
13501352
minp = _check_minp(win, minp, N)
1351-
with nogil:
1352-
for i from 0 <= i < minp - 1:
1353-
val = input[i]
1353+
minobs = max(minp, 4 if kurt else 3)
13541354

1355-
# Not NaN
1356-
if val == val:
1357-
nobs += 1
1358-
x += val
1359-
xx += val * val
1360-
xxx += val * val * val
1361-
1362-
output[i] = NaN
1363-
1364-
for i from minp - 1 <= i < N:
1365-
val = input[i]
1366-
1367-
if val == val:
1368-
nobs += 1
1369-
x += val
1370-
xx += val * val
1371-
xxx += val * val * val
1372-
1373-
if i > win - 1:
1374-
prev = input[i - win]
1375-
if prev == prev:
1376-
x -= prev
1377-
xx -= prev * prev
1378-
xxx -= prev * prev * prev
1379-
1380-
nobs -= 1
1381-
if nobs >= minp:
1382-
A = x / nobs
1383-
B = xx / nobs - A * A
1384-
C = xxx / nobs - A * A * A - 3 * A * B
1385-
if B <= 0 or nobs < 3:
1386-
output[i] = NaN
1355+
for i from 0 <= i < N:
1356+
val = input[i]
1357+
prev = NaN if i < win else input[i - win]
1358+
1359+
if prev == prev:
1360+
# prev is not NaN, remove an observation...
1361+
nobs -= 1
1362+
if nobs < nrep:
1363+
# ...all non-NaN values were identical, remove a repeat
1364+
nrep -= 1
1365+
if nobs == nrep:
1366+
# We can get here both if all non-NaN were already identical
1367+
# or if nobs == 1 after removing the observation
1368+
if nrep == 0:
1369+
rep = NaN
1370+
mean_x = 0
13871371
else:
1388-
R = sqrt(B)
1389-
output[i] = ((sqrt(nobs * (nobs - 1.)) * C) /
1390-
((nobs-2) * R * R * R))
1372+
mean_x = rep
1373+
# This is redundant most of the time
1374+
s2dm_x = s3dm_x = s4dm_x = 0
13911375
else:
1392-
output[i] = NaN
1393-
1394-
return output
1395-
1396-
#-------------------------------------------------------------------------------
1397-
# Rolling kurtosis
1398-
@cython.boundscheck(False)
1399-
@cython.wraparound(False)
1400-
def roll_kurt(ndarray[double_t] input,
1401-
int win, int minp):
1402-
cdef double val, prev
1403-
cdef double x = 0, xx = 0, xxx = 0, xxxx = 0
1404-
cdef Py_ssize_t nobs = 0, i
1405-
cdef Py_ssize_t N = len(input)
1406-
1407-
cdef ndarray[double_t] output = np.empty(N, dtype=float)
1408-
1409-
# 5 components of the kurtosis equation
1410-
cdef double A, B, C, D, R, K
1411-
1412-
minp = _check_minp(win, minp, N)
1413-
with nogil:
1414-
for i from 0 <= i < minp - 1:
1415-
val = input[i]
1416-
1417-
# Not NaN
1418-
if val == val:
1419-
nobs += 1
1420-
1421-
# seriously don't ask me why this is faster
1422-
x += val
1423-
xx += val * val
1424-
xxx += val * val * val
1425-
xxxx += val * val * val * val
1426-
1427-
output[i] = NaN
1428-
1429-
for i from minp - 1 <= i < N:
1430-
val = input[i]
1431-
1432-
if val == val:
1433-
nobs += 1
1434-
x += val
1435-
xx += val * val
1436-
xxx += val * val * val
1437-
xxxx += val * val * val * val
1438-
1439-
if i > win - 1:
1440-
prev = input[i - win]
1441-
if prev == prev:
1442-
x -= prev
1443-
xx -= prev * prev
1444-
xxx -= prev * prev * prev
1445-
xxxx -= prev * prev * prev * prev
1446-
1447-
nobs -= 1
1448-
1449-
if nobs >= minp:
1450-
A = x / nobs
1451-
R = A * A
1452-
B = xx / nobs - R
1453-
R = R * A
1454-
C = xxx / nobs - R - 3 * A * B
1455-
R = R * A
1456-
D = xxxx / nobs - R - 6*B*A*A - 4*C*A
1457-
1458-
if B == 0 or nobs < 4:
1459-
output[i] = NaN
1460-
1461-
else:
1462-
K = (nobs * nobs - 1.)*D/(B*B) - 3*((nobs-1.)**2)
1463-
K = K / ((nobs - 2.)*(nobs-3.))
1464-
1465-
output[i] = K
1376+
# ...update mean and sums of raised differences from mean
1377+
delta = prev - mean_x
1378+
delta_n = delta / nobs
1379+
tmp = delta * delta_n * (nobs + 1)
1380+
if kurt:
1381+
s4dm_x -= ((tmp * ((nobs + 3) * nobs + 3) -
1382+
6 * s2dm_x) * delta_n - 4 * s3dm_x) * delta_n
1383+
s3dm_x -= (tmp * (nobs + 2) - 3 * s2dm_x) * delta_n
1384+
s2dm_x -= tmp
1385+
mean_x -= delta_n
14661386

1387+
if val == val:
1388+
# val is not NaN, adding an observation...
1389+
nobs += 1
1390+
if val == rep:
1391+
# ...and adding a repeat
1392+
nrep += 1
14671393
else:
1468-
output[i] = NaN
1394+
# ...and resetting repeats
1395+
nrep = 1
1396+
rep = val
1397+
if nobs == nrep:
1398+
# ...all non-NaN values are identical
1399+
mean_x = rep
1400+
s2dm_x = s3dm_x = s4dm_x = 0
1401+
else:
1402+
# ...update mean and sums of raised differences from mean
1403+
delta = val - mean_x
1404+
delta_n = delta / nobs
1405+
tmp = delta * delta_n * (nobs - 1)
1406+
if kurt:
1407+
s4dm_x += ((tmp * ((nobs - 3) * nobs + 3) +
1408+
6 * s2dm_x) * delta_n - 4 * s3dm_x) * delta_n
1409+
s3dm_x += (tmp * (nobs - 2) - 3 * s2dm_x) * delta_n
1410+
s2dm_x += tmp
1411+
mean_x += delta_n
1412+
1413+
# Sums of even powers must be non-negative
1414+
if s2dm_x < 0 or s4dm_x < 0:
1415+
# Reset every accumulator, including the fourth-moment sum `s4dm_x`
s2dm_x = s3dm_x = s4dm_x = 0
1416+
1417+
if nobs < minobs or s2dm_x == 0:
1418+
output[i] = NaN
1419+
elif kurt:
1420+
# multiplications are cheap, divisions are not
1421+
tmp = s2dm_x * s2dm_x
1422+
output[i] = (nobs - 1) * (nobs * (nobs + 1) * s4dm_x -
1423+
3 * (nobs - 1) * tmp)
1424+
output[i] /= tmp * (nobs - 2) * (nobs - 3)
1425+
else:
1426+
# multiplications are cheap, divisions and square roots are not
1427+
tmp = (nobs - 2) * (nobs - 2) * s2dm_x * s2dm_x * s2dm_x
1428+
output[i] = s3dm_x * nobs * sqrt((nobs - 1) / tmp)
14691429

14701430
return output
14711431

1432+
14721433
#-------------------------------------------------------------------------------
14731434
# Rolling median, min, max
14741435

pandas/stats/moments.py

+38-14
Original file line numberDiff line numberDiff line change
@@ -355,7 +355,8 @@ def rolling_corr_pairwise(df1, df2=None, window=None, min_periods=None,
355355

356356

357357
def _rolling_moment(arg, window, func, minp, axis=0, freq=None, center=False,
358-
how=None, args=(), kwargs={}, **kwds):
358+
how=None, args=(), kwargs={}, center_data=False,
359+
norm_data=False, **kwds):
359360
"""
360361
Rolling statistical measure using supplied function. Designed to be
361362
used with passed-in Cython array-based functions.
@@ -378,15 +379,21 @@ def _rolling_moment(arg, window, func, minp, axis=0, freq=None, center=False,
378379
Passed on to func
379380
kwargs : dict
380381
Passed on to func
382+
center_data : bool
383+
If True, subtract the mean of the data from the values
384+
norm_data : bool
385+
If True, subtract the mean of the data from the values, and divide
386+
by their standard deviation.
381387
382388
Returns
383389
-------
384390
y : type of input
385391
"""
386392
arg = _conv_timerule(arg, freq, how)
387393

388-
return_hook, values = _process_data_structure(arg)
389-
394+
return_hook, values = _process_data_structure(arg,
395+
center_data=center_data,
396+
norm_data=norm_data)
390397
if values.size == 0:
391398
result = values.copy()
392399
else:
@@ -423,7 +430,8 @@ def _center_window(rs, window, axis):
423430
return rs
424431

425432

426-
def _process_data_structure(arg, kill_inf=True):
433+
def _process_data_structure(arg, kill_inf=True, center_data=False,
434+
norm_data=False):
427435
if isinstance(arg, DataFrame):
428436
return_hook = lambda v: type(arg)(v, index=arg.index,
429437
columns=arg.columns)
@@ -438,9 +446,15 @@ def _process_data_structure(arg, kill_inf=True):
438446
if not issubclass(values.dtype.type, float):
439447
values = values.astype(float)
440448

441-
if kill_inf:
449+
if kill_inf or center_data or norm_data:
442450
values = values.copy()
443-
values[np.isinf(values)] = np.NaN
451+
mask = np.isfinite(values)
452+
if kill_inf:
453+
values[~mask] = np.NaN
454+
if center_data or norm_data:
455+
values -= np.mean(values[mask])
456+
if norm_data:
457+
values /= np.std(values[mask])
444458

445459
return return_hook, values
446460

@@ -629,7 +643,8 @@ def _use_window(minp, window):
629643
return minp
630644

631645

632-
def _rolling_func(func, desc, check_minp=_use_window, how=None, additional_kw=''):
646+
def _rolling_func(func, desc, check_minp=_use_window, how=None,
647+
additional_kw='', center_data=False, norm_data=False):
633648
if how is None:
634649
how_arg_str = 'None'
635650
else:
@@ -645,7 +660,8 @@ def call_cython(arg, window, minp, args=(), kwargs={}, **kwds):
645660
minp = check_minp(minp, window)
646661
return func(arg, window, minp, **kwds)
647662
return _rolling_moment(arg, window, call_cython, min_periods, freq=freq,
648-
center=center, how=how, **kwargs)
663+
center=center, how=how, center_data=center_data,
664+
norm_data=norm_data, **kwargs)
649665

650666
return f
651667

@@ -657,16 +673,24 @@ def call_cython(arg, window, minp, args=(), kwargs={}, **kwds):
657673
how='median')
658674

659675
_ts_std = lambda *a, **kw: _zsqrt(algos.roll_var(*a, **kw))
676+
def _roll_skew(*args, **kwargs):
677+
kwargs['kurt'] = False
678+
return algos.roll_higher_moment(*args, **kwargs)
679+
def _roll_kurt(*args, **kwargs):
680+
kwargs['kurt'] = True
681+
return algos.roll_higher_moment(*args, **kwargs)
660682
rolling_std = _rolling_func(_ts_std, 'Moving standard deviation.',
661683
check_minp=_require_min_periods(1),
662684
additional_kw=_ddof_kw)
663685
rolling_var = _rolling_func(algos.roll_var, 'Moving variance.',
664686
check_minp=_require_min_periods(1),
665687
additional_kw=_ddof_kw)
666-
rolling_skew = _rolling_func(algos.roll_skew, 'Unbiased moving skewness.',
667-
check_minp=_require_min_periods(3))
668-
rolling_kurt = _rolling_func(algos.roll_kurt, 'Unbiased moving kurtosis.',
669-
check_minp=_require_min_periods(4))
688+
rolling_skew = _rolling_func(_roll_skew, 'Unbiased moving skewness.',
689+
check_minp=_require_min_periods(3),
690+
center_data=True, norm_data=False)
691+
rolling_kurt = _rolling_func(_roll_kurt, 'Unbiased moving kurtosis.',
692+
check_minp=_require_min_periods(4),
693+
center_data=True, norm_data=True)
670694

671695

672696
def rolling_quantile(arg, window, quantile, min_periods=None, freq=None,
@@ -903,9 +927,9 @@ def call_cython(arg, window, minp, args=(), kwargs={}, **kwds):
903927
expanding_var = _expanding_func(algos.roll_var, 'Expanding variance.',
904928
check_minp=_require_min_periods(1),
905929
additional_kw=_ddof_kw)
906-
expanding_skew = _expanding_func(algos.roll_skew, 'Unbiased expanding skewness.',
930+
expanding_skew = _expanding_func(_roll_skew, 'Unbiased expanding skewness.',
907931
check_minp=_require_min_periods(3))
908-
expanding_kurt = _expanding_func(algos.roll_kurt, 'Unbiased expanding kurtosis.',
932+
expanding_kurt = _expanding_func(_roll_kurt, 'Unbiased expanding kurtosis.',
909933
check_minp=_require_min_periods(4))
910934

911935

0 commit comments

Comments
 (0)