Skip to content

Support nanfirst, nanlast with simple combine algo #240

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions flox/aggregations.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,8 @@ def __init__(
self.chunk: FuncTuple = _atleast_1d(chunk)
# how to aggregate results after first round of reduction
self.combine: FuncTuple = _atleast_1d(combine)
# simpler reductions used with the "simple combine" algorithm
self.simple_combine = None
# final aggregation
self.aggregate: Callable | str = aggregate if aggregate else self.combine[0]
# finalize results (see mean)
Expand Down Expand Up @@ -577,4 +579,16 @@ def _initialize_aggregation(
else:
agg.min_count = 0

simple_combine = []
for combine in agg.combine:
if isinstance(combine, str):
if combine in ["nanfirst", "nanlast"]:
simple_combine.append(getattr(xrutils, combine))
else:
simple_combine.append(getattr(np, combine))
else:
simple_combine.append(combine)

agg.simple_combine = tuple(simple_combine)

return agg
36 changes: 30 additions & 6 deletions flox/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,10 @@ def _is_minmax_reduction(func: T_Agg) -> bool:
)


def _is_first_last_reduction(func: T_Agg) -> bool:
    """Report whether ``func`` names a first/last style reduction.

    Only string reduction names can match; any non-string ``func``
    yields ``False``. The recognised names are ``"nanfirst"``,
    ``"nanlast"``, ``"first"`` and ``"last"``.
    """
    if not isinstance(func, str):
        return False
    return func in ("nanfirst", "nanlast", "first", "last")


def _get_expected_groups(by: T_By, sort: bool) -> pd.Index:
if is_duck_dask_array(by):
raise ValueError("Please provide expected_groups if not grouping by a numpy array.")
Expand Down Expand Up @@ -954,13 +958,13 @@ def _simple_combine(
results: IntermediateDict = {"groups": unique_groups}
results["intermediates"] = []
axis_ = axis[:-1] + (DUMMY_AXIS,)
for idx, combine in enumerate(agg.combine):
for idx, combine in enumerate(agg.simple_combine):
array = _conc2(x_chunk, key1="intermediates", key2=idx, axis=axis_)
assert array.ndim >= 2
with warnings.catch_warnings():
warnings.filterwarnings("ignore", r"All-NaN (slice|axis) encountered")
assert isinstance(combine, str)
result = getattr(np, combine)(array, axis=axis_, keepdims=True)
assert callable(combine)
result = combine(array, axis=axis_, keepdims=True)
if is_aggregate:
# squeeze out DUMMY_AXIS if this is the last step i.e. called from _aggregate
result = result.squeeze(axis=DUMMY_AXIS)
Expand Down Expand Up @@ -1534,11 +1538,17 @@ def _validate_reindex(
raise ValueError(
"reindex=True is not a valid choice for method='blockwise' or method='cohorts'."
)
if func in ["first", "last"]:
raise ValueError("reindex must be None or False when func is 'first' or 'last.")

if reindex is None:
if all_numpy:
return True

if func in ["first", "last"]:
# have to do the grouped_combine since there's no good fill_value
reindex = False

if method == "blockwise" or _is_arg_reduction(func):
reindex = False

Expand All @@ -1552,6 +1562,7 @@ def _validate_reindex(
reindex = True

assert isinstance(reindex, bool)

return reindex


Expand Down Expand Up @@ -1875,6 +1886,21 @@ def groupby_reduce(
axis_ = np.core.numeric.normalize_axis_tuple(axis, array.ndim) # type: ignore
nax = len(axis_)

has_dask = is_duck_dask_array(array) or is_duck_dask_array(by_)

if _is_first_last_reduction(func):
if has_dask and nax != 1:
raise ValueError(
"For dask arrays: first, last, nanfirst, nanlast reductions are "
"only supported along a single axis. Please reshape appropriately."
)

elif nax not in [1, by_.ndim]:
raise ValueError(
"first, last, nanfirst, nanlast reductions are only supported "
"along a single axis or when reducing across all dimensions of `by`."
)

# TODO: make sure expected_groups is unique
if nax == 1 and by_.ndim > 1 and expected_groups is None:
if not any_by_dask:
Expand All @@ -1898,8 +1924,6 @@ def groupby_reduce(
axis_ = tuple(array.ndim + np.arange(-nax, 0))
nax = len(axis_)

has_dask = is_duck_dask_array(array) or is_duck_dask_array(by_)

# When axis is a subset of possible values; then npg will
# apply it to groups that don't exist along a particular axis (for e.g.)
# since these count as a group that is absent. thoo!
Expand Down Expand Up @@ -1986,6 +2010,6 @@ def groupby_reduce(
).reshape(result.shape[:-1] + grp_shape)
groups = final_groups

if _is_minmax_reduction(func) and is_bool_array:
if is_bool_array and (_is_minmax_reduction(func) or _is_first_last_reduction(func)):
result = result.astype(bool)
return (result, *groups) # type: ignore[return-value] # Unpack not in mypy yet
35 changes: 34 additions & 1 deletion flox/xrutils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
# The functions defined here were copied or adapted from the xarray
# source code.


import datetime
from typing import Any, Iterable

import numpy as np
import pandas as pd
from numpy.core.multiarray import normalize_axis_index # type: ignore[attr-defined]

try:
import cftime
Expand Down Expand Up @@ -283,3 +283,36 @@ def _contains_cftime_datetimes(array) -> bool:
return isinstance(sample, cftime.datetime)
else:
return False


def _select_along_axis(values, idx, axis):
    """Pick one element per lane of ``values`` along ``axis``.

    ``idx`` holds, for every position in the remaining (non-``axis``)
    dimensions, the index to take along ``axis``. An open mesh built from
    ``idx.shape`` supplies the indices for those remaining dimensions, and
    ``idx`` is spliced in at position ``axis`` to form one fancy-index tuple.

    Returns ``values`` indexed by that tuple.
    """
    # One broadcastable arange per dimension of ``idx``.
    grid = np.ix_(*(np.arange(extent) for extent in idx.shape))
    leading = grid[:axis]
    trailing = grid[axis:]
    selector = leading + (idx,) + trailing
    return values[selector]


def nanfirst(values, axis, keepdims=False):
    """Return the first non-null entry of ``values`` along ``axis``.

    Parameters
    ----------
    values : array-like
        Input data; converted with ``np.asarray``.
    axis : int or 1-tuple of int
        Axis to reduce over. A tuple is unpacked and must hold exactly one
        entry (unpacking a longer or empty tuple raises ``ValueError``).
    keepdims : bool, optional
        If true, the reduced axis is kept with length one.

    Notes
    -----
    The position is found with ``argmax`` over the "not null" mask, so a
    lane with no valid entries selects index 0 (itself a null value).
    """
    if isinstance(axis, tuple):
        (axis,) = axis
    data = np.asarray(values)
    ax = normalize_axis_index(axis, data.ndim)
    is_valid = ~pd.isnull(data)
    first_valid = np.argmax(is_valid, axis=ax)
    picked = _select_along_axis(data, first_valid, ax)
    return np.expand_dims(picked, axis=ax) if keepdims else picked


def nanlast(values, axis, keepdims=False):
    """Return the last non-null entry of ``values`` along ``axis``.

    Parameters
    ----------
    values : array-like
        Input data; converted with ``np.asarray``.
    axis : int or 1-tuple of int
        Axis to reduce over. A tuple is unpacked and must hold exactly one
        entry (unpacking a longer or empty tuple raises ``ValueError``).
    keepdims : bool, optional
        If true, the reduced axis is kept with length one.

    Notes
    -----
    The "not null" mask is reversed along ``axis`` and searched with
    ``argmax``; the hit is then converted to a negative index counted from
    the end. A lane with no valid entries selects index -1.
    """
    if isinstance(axis, tuple):
        (axis,) = axis
    data = np.asarray(values)
    ax = normalize_axis_index(axis, data.ndim)
    # Full slices for the leading axes, then a reversing slice on ``ax``.
    flip = (slice(None),) * ax + (slice(None, None, -1),)
    valid_reversed = ~pd.isnull(data)[flip]
    steps_from_end = np.argmax(valid_reversed, axis=ax)
    last_valid = -1 - steps_from_end
    picked = _select_along_axis(data, last_valid, ax)
    return np.expand_dims(picked, axis=ax) if keepdims else picked
113 changes: 93 additions & 20 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@
import itertools
import warnings
from functools import partial, reduce
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Callable

import numpy as np
import pandas as pd
import pytest
from numpy_groupies.aggregate_numpy import aggregate

from flox import xrutils
from flox.aggregations import Aggregation
from flox.core import (
_convert_expected_groups_to_index,
Expand Down Expand Up @@ -53,6 +54,7 @@ def dask_array_ones(*args):
"sum",
"nansum",
"argmax",
"nanfirst",
pytest.param("nanargmax", marks=(pytest.mark.skip,)),
"prod",
"nanprod",
Expand All @@ -70,6 +72,7 @@ def dask_array_ones(*args):
pytest.param("nanargmin", marks=(pytest.mark.skip,)),
"any",
"all",
"nanlast",
pytest.param("median", marks=(pytest.mark.skip,)),
pytest.param("nanmedian", marks=(pytest.mark.skip,)),
)
Expand All @@ -78,6 +81,21 @@ def dask_array_ones(*args):
from flox.core import T_Engine, T_ExpectedGroupsOpt, T_Func2


def _get_array_func(func: str) -> Callable:
    """Return a reference array function for the reduction named ``func``.

    The lookup is:

    - ``"count"``: a local function counting the non-NaN entries.
    - ``"nanfirst"`` / ``"nanlast"``: the function of that name in ``xrutils``.
    - anything else: the attribute of that name on ``numpy``.

    Every returned callable can be invoked as ``npfunc(x, axis=...)``.
    """
    if func == "count":

        def npfunc(x, axis=None, **kwargs):
            # Tests in this module call every reference function with
            # ``axis=0``; accept ``axis`` (and other reduction keywords) so
            # "count" does not raise TypeError. ``axis=None`` keeps the
            # previous count-everything behaviour.
            x = np.asarray(x)
            return (~np.isnan(x)).sum(axis=axis, **kwargs)

    elif func in ["nanfirst", "nanlast"]:
        npfunc = getattr(xrutils, func)
    else:
        npfunc = getattr(np, func)

    return npfunc


def test_alignment_error():
da = np.ones((12,))
labels = np.ones((5,))
Expand Down Expand Up @@ -217,6 +235,10 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
if "arg" in func and add_nan_by:
array_[..., nanmask] = np.nan
expected = getattr(np, "nan" + func)(array_, axis=-1, **kwargs)
# elif func in ["first", "last"]:
# expected = getattr(xrutils, f"nan{func}")(array_[..., ~nanmask], axis=-1, **kwargs)
elif func in ["nanfirst", "nanlast"]:
expected = getattr(xrutils, func)(array_[..., ~nanmask], axis=-1, **kwargs)
else:
expected = getattr(np, func)(array_[..., ~nanmask], axis=-1, **kwargs)
for _ in range(nby):
Expand All @@ -241,7 +263,7 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
call = partial(
groupby_reduce, array, *by, method=method, reindex=reindex, **flox_kwargs
)
if "arg" in func and reindex is True:
if ("arg" in func or func in ["first", "last"]) and reindex is True:
# simple_combine with argreductions not supported right now
with pytest.raises(NotImplementedError):
call()
Expand Down Expand Up @@ -486,6 +508,28 @@ def test_dask_reduce_axis_subset():
)


@pytest.mark.parametrize("func", ["first", "last", "nanfirst", "nanlast"])
@pytest.mark.parametrize("axis", [(0, 1)])
def test_first_last_disallowed(axis, func):
    """First/last style reductions over two of three axes must raise ValueError."""
    shape = (2, 3, 2)
    array = np.empty(shape)
    labels = np.ones(shape)
    with pytest.raises(ValueError):
        groupby_reduce(array, labels, func=func, axis=axis)


@requires_dask
@pytest.mark.parametrize("func", ["nanfirst", "nanlast"])
@pytest.mark.parametrize("axis", [None, (0, 1, 2)])
def test_nanfirst_nanlast_disallowed_dask(axis, func):
    """With a dask array, nanfirst/nanlast over all axes must raise ValueError."""
    shape = (2, 3, 2)
    array = dask.array.empty(shape)
    labels = np.ones(shape)
    with pytest.raises(ValueError):
        groupby_reduce(array, labels, func=func, axis=axis)


@requires_dask
@pytest.mark.parametrize("func", ["first", "last"])
def test_first_last_disallowed_dask(func):
    """With a dask array, ``first``/``last`` along the last axis must raise NotImplementedError."""
    shape = (2, 3, 2)
    array = dask.array.empty(shape)
    labels = np.ones(shape)
    with pytest.raises(NotImplementedError):
        groupby_reduce(array, labels, func=func, axis=-1)


@requires_dask
@pytest.mark.parametrize("func", ALL_FUNCS)
@pytest.mark.parametrize(
Expand All @@ -495,8 +539,34 @@ def test_groupby_reduce_axis_subset_against_numpy(func, axis, engine):
if "arg" in func and engine == "flox":
pytest.skip()

if not isinstance(axis, int) and "arg" in func and (axis is None or len(axis) > 1):
pytest.skip()
if not isinstance(axis, int):
if "arg" in func and (axis is None or len(axis) > 1):
pytest.skip()
if ("first" in func or "last" in func) and (axis is not None and len(axis) not in [1, 3]):
pytest.skip()

if func in ["all", "any"]:
fill_value = False
else:
fill_value = 123

if "var" in func or "std" in func:
tolerance = {"rtol": 1e-14, "atol": 1e-16}
else:
tolerance = None
# tests against the numpy output to make sure dask compute matches
by = np.broadcast_to(labels2d, (3, *labels2d.shape))
rng = np.random.default_rng(12345)
array = rng.random(by.shape)
kwargs = dict(
func=func, axis=axis, expected_groups=[0, 2], fill_value=fill_value, engine=engine
)
expected, _ = groupby_reduce(array, by, **kwargs)
if engine == "flox":
kwargs.pop("engine")
expected_npg, _ = groupby_reduce(array, by, **kwargs, engine="numpy")
assert_equal(expected_npg, expected)

if func in ["all", "any"]:
fill_value = False
else:
Expand All @@ -513,17 +583,23 @@ def test_groupby_reduce_axis_subset_against_numpy(func, axis, engine):
kwargs = dict(
func=func, axis=axis, expected_groups=[0, 2], fill_value=fill_value, engine=engine
)
expected, _ = groupby_reduce(array, by, **kwargs)
if engine == "flox":
kwargs.pop("engine")
expected_npg, _ = groupby_reduce(array, by, **kwargs, engine="numpy")
assert_equal(expected_npg, expected)

if ("first" in func or "last" in func) and (
axis is None or (not isinstance(axis, int) and len(axis) != 1)
):
return

with raise_if_dask_computes():
actual, _ = groupby_reduce(
da.from_array(array, chunks=(-1, 2, 3)),
da.from_array(by, chunks=(-1, 2, 2)),
**kwargs,
)
expected, _ = groupby_reduce(array, by, **kwargs)
if engine == "flox":
kwargs.pop("engine")
expected_npg, _ = groupby_reduce(array, by, **kwargs, engine="numpy")
assert_equal(expected_npg, expected)
assert_equal(actual, expected, tolerance)


Expand Down Expand Up @@ -751,23 +827,17 @@ def test_fill_value_behaviour(func, chunks, fill_value, engine):
if chunks is not None and not has_dask:
pytest.skip()

if func == "count":

def npfunc(x):
x = np.asarray(x)
return (~np.isnan(x)).sum()

else:
npfunc = getattr(np, func)

npfunc = _get_array_func(func)
by = np.array([1, 2, 3, 1, 2, 3])
array = np.array([np.nan, 1, 1, np.nan, 1, 1])
if chunks:
array = dask.array.from_array(array, chunks)
actual, _ = groupby_reduce(
array, by, func=func, engine=engine, fill_value=fill_value, expected_groups=[0, 1, 2, 3]
)
expected = np.array([fill_value, fill_value, npfunc([1.0, 1.0]), npfunc([1.0, 1.0])])
expected = np.array(
[fill_value, fill_value, npfunc([1.0, 1.0], axis=0), npfunc([1.0, 1.0], axis=0)]
)
assert_equal(actual, expected)


Expand Down Expand Up @@ -832,6 +902,8 @@ def test_cohorts_nd_by(func, method, axis, engine):

if axis is not None and method != "map-reduce":
pytest.xfail()
if axis is None and ("first" in func or "last" in func):
pytest.skip()

kwargs = dict(func=func, engine=engine, method=method, axis=axis, fill_value=fill_value)
actual, groups = groupby_reduce(array, by, **kwargs)
Expand Down Expand Up @@ -897,7 +969,8 @@ def test_bool_reductions(func, engine):
pytest.skip()
groups = np.array([1, 1, 1])
data = np.array([True, True, False])
expected = np.expand_dims(getattr(np, func)(data), -1)
npfunc = _get_array_func(func)
expected = np.expand_dims(npfunc(data, axis=0), -1)
actual, _ = groupby_reduce(data, groups, func=func, engine=engine)
assert_equal(expected, actual)

Expand Down