Skip to content

Commit 05fe726

Browse files
dcherianIllviljan
andauthored
Allow specifying output dtype (#131)
Co-authored-by: Illviljan <[email protected]>
1 parent 031979d commit 05fe726

File tree

6 files changed

+141
-17
lines changed

6 files changed

+141
-17
lines changed

flox/aggregations.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -55,10 +55,7 @@ def generic_aggregate(
5555

5656
def _normalize_dtype(dtype, array_dtype, fill_value=None):
5757
if dtype is None:
58-
if fill_value is not None and np.isnan(fill_value):
59-
dtype = np.floating
60-
else:
61-
dtype = array_dtype
58+
dtype = array_dtype
6259
if dtype is np.floating:
6360
# mean, std, var always result in floating
6461
# but we preserve the array's dtype if it is floating
@@ -68,6 +65,8 @@ def _normalize_dtype(dtype, array_dtype, fill_value=None):
6865
dtype = np.dtype("float64")
6966
elif not isinstance(dtype, np.dtype):
7067
dtype = np.dtype(dtype)
68+
if fill_value not in [None, dtypes.INF, dtypes.NINF, dtypes.NA]:
69+
dtype = np.result_type(dtype, fill_value)
7170
return dtype
7271

7372

@@ -465,6 +464,7 @@ def _zip_index(array_, idx_):
465464

466465
def _initialize_aggregation(
467466
func: str | Aggregation,
467+
dtype,
468468
array_dtype,
469469
fill_value,
470470
min_count: int | None,
@@ -484,10 +484,18 @@ def _initialize_aggregation(
484484
else:
485485
raise ValueError("Bad type for func. Expected str or Aggregation")
486486

487-
agg.dtype[func] = _normalize_dtype(agg.dtype[func], array_dtype, fill_value)
487+
# np.dtype(None) == np.dtype("float64")!!!
488+
# so check for not None
489+
if dtype is not None and not isinstance(dtype, np.dtype):
490+
dtype = np.dtype(dtype)
491+
492+
agg.dtype[func] = _normalize_dtype(dtype or agg.dtype[func], array_dtype, fill_value)
488493
agg.dtype["numpy"] = (agg.dtype[func],)
489494
agg.dtype["intermediate"] = [
490-
_normalize_dtype(dtype, array_dtype) for dtype in agg.dtype["intermediate"]
495+
_normalize_dtype(int_dtype, np.result_type(array_dtype, agg.dtype[func]), int_fv)
496+
if int_dtype is None
497+
else int_dtype
498+
for int_dtype, int_fv in zip(agg.dtype["intermediate"], agg.fill_value["intermediate"])
491499
]
492500

493501
# Replace sentinel fill values according to dtype

flox/core.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -789,6 +789,7 @@ def _finalize_results(
789789
else:
790790
finalized["groups"] = squeezed["groups"]
791791

792+
finalized[agg.name] = finalized[agg.name].astype(agg.dtype[agg.name], copy=False)
792793
return finalized
793794

794795

@@ -1411,6 +1412,7 @@ def groupby_reduce(
14111412
isbin: T_IsBins = False,
14121413
axis: T_AxesOpt = None,
14131414
fill_value=None,
1415+
dtype: np.typing.DTypeLike = None,
14141416
min_count: int | None = None,
14151417
split_out: int = 1,
14161418
method: T_Method = "map-reduce",
@@ -1444,6 +1446,8 @@ def groupby_reduce(
14441446
Negative integers are normalized using array.ndim
14451447
fill_value : Any
14461448
Value to assign when a label in ``expected_groups`` is not present.
1449+
dtype: data-type , optional
1450+
DType for the output. Can be anything that is accepted by ``np.dtype``.
14471451
min_count : int, default: None
14481452
The required number of valid values to perform the operation. If
14491453
fewer than min_count non-NA values are present the result will be
@@ -1621,7 +1625,7 @@ def groupby_reduce(
16211625
fill_value = np.nan
16221626

16231627
kwargs = dict(axis=axis_, fill_value=fill_value, engine=engine)
1624-
agg = _initialize_aggregation(func, array.dtype, fill_value, min_count, finalize_kwargs)
1628+
agg = _initialize_aggregation(func, dtype, array.dtype, fill_value, min_count, finalize_kwargs)
16251629

16261630
if not has_dask:
16271631
results = _reduce_blockwise(

flox/xarray.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ def xarray_reduce(
6363
dim: Dims | ellipsis = None,
6464
split_out: int = 1,
6565
fill_value=None,
66+
dtype: np.typing.DTypeLike = None,
6667
method: str = "map-reduce",
6768
engine: str = "numpy",
6869
keep_attrs: bool | None = True,
@@ -98,6 +99,8 @@ def xarray_reduce(
9899
fill_value
99100
Value used for missing groups in the output i.e. when one of the labels
100101
in ``expected_groups`` is not actually present in ``by``.
102+
dtype: data-type, optional
103+
DType for the output. Can be anything accepted by ``np.dtype``.
101104
method : {"map-reduce", "blockwise", "cohorts", "split-reduce"}, optional
102105
Strategy for reduction of dask arrays only:
103106
* ``"map-reduce"``:
@@ -387,7 +390,9 @@ def wrapper(array, *by, func, skipna, **kwargs):
387390
exclude_dims=set(dim_tuple),
388391
output_core_dims=[group_names],
389392
dask="allowed",
390-
dask_gufunc_kwargs=dict(output_sizes=group_sizes),
393+
dask_gufunc_kwargs=dict(
394+
output_sizes=group_sizes, output_dtypes=[dtype] if dtype is not None else None
395+
),
391396
keep_attrs=keep_attrs,
392397
kwargs={
393398
"func": func,
@@ -403,6 +408,7 @@ def wrapper(array, *by, func, skipna, **kwargs):
403408
"expected_groups": tuple(expected_groups),
404409
"isbin": isbins,
405410
"finalize_kwargs": finalize_kwargs,
411+
"dtype": dtype,
406412
},
407413
)
408414

tests/__init__.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -80,28 +80,39 @@ def raise_if_dask_computes(max_computes=0):
8080
return dask.config.set(scheduler=scheduler)
8181

8282

83-
def assert_equal(a, b):
83+
def assert_equal(a, b, tolerance=None):
8484
__tracebackhide__ = True
8585

8686
if isinstance(a, list):
8787
a = np.array(a)
8888
if isinstance(b, list):
8989
b = np.array(b)
90+
9091
if isinstance(a, pd_types) or isinstance(b, pd_types):
9192
pd.testing.assert_index_equal(a, b)
92-
elif has_xarray and isinstance(a, xr_types) or isinstance(b, xr_types):
93+
return
94+
if has_xarray and isinstance(a, xr_types) or isinstance(b, xr_types):
9395
xr.testing.assert_identical(a, b)
94-
elif has_dask and isinstance(a, dask_array_type) or isinstance(b, dask_array_type):
96+
return
97+
98+
if tolerance is None and (
99+
np.issubdtype(a.dtype, np.float64) | np.issubdtype(b.dtype, np.float64)
100+
):
101+
tolerance = {"atol": 1e-18, "rtol": 1e-15}
102+
else:
103+
tolerance = {}
104+
105+
if has_dask and isinstance(a, dask_array_type) or isinstance(b, dask_array_type):
95106
# sometimes it's nice to see values and shapes
96107
# rather than being dropped into some file in dask
97-
np.testing.assert_allclose(a, b)
108+
np.testing.assert_allclose(a, b, **tolerance)
98109
# does some validation of the dask graph
99110
da.utils.assert_eq(a, b, equal_nan=True)
100111
else:
101112
if a.dtype != b.dtype:
102113
raise AssertionError(f"a and b have different dtypes: (a: {a.dtype}, b: {b.dtype})")
103114

104-
np.testing.assert_allclose(a, b, equal_nan=True)
115+
np.testing.assert_allclose(a, b, equal_nan=True, **tolerance)
105116

106117

107118
@pytest.fixture(scope="module", params=["flox", "numpy", "numba"])

tests/test_core.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -185,8 +185,10 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
185185
if "var" in func or "std" in func:
186186
finalize_kwargs = finalize_kwargs + [{"ddof": 1}, {"ddof": 0}]
187187
fill_value = np.nan
188+
tolerance = {"rtol": 1e-14, "atol": 1e-16}
188189
else:
189190
fill_value = None
191+
tolerance = None
190192

191193
for kwargs in finalize_kwargs:
192194
flox_kwargs = dict(func=func, engine=engine, finalize_kwargs=kwargs, fill_value=fill_value)
@@ -207,7 +209,7 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
207209
assert_equal(actual_group, expect)
208210
if "arg" in func:
209211
assert actual.dtype.kind == "i"
210-
assert_equal(actual, expected)
212+
assert_equal(actual, expected, tolerance)
211213

212214
if not has_dask:
213215
continue
@@ -216,10 +218,10 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
216218
continue
217219
actual, *groups = groupby_reduce(array, *by, method=method, **flox_kwargs)
218220
for actual_group, expect in zip(groups, expected_groups):
219-
assert_equal(actual_group, expect)
221+
assert_equal(actual_group, expect, tolerance)
220222
if "arg" in func:
221223
assert actual.dtype.kind == "i"
222-
assert_equal(actual, expected)
224+
assert_equal(actual, expected, tolerance)
223225

224226

225227
@requires_dask
@@ -466,6 +468,11 @@ def test_groupby_reduce_axis_subset_against_numpy(func, axis, engine):
466468
fill_value = False
467469
else:
468470
fill_value = 123
471+
472+
if "var" in func or "std" in func:
473+
tolerance = {"rtol": 1e-14, "atol": 1e-16}
474+
else:
475+
tolerance = None
469476
# tests against the numpy output to make sure dask compute matches
470477
by = np.broadcast_to(labels2d, (3, *labels2d.shape))
471478
rng = np.random.default_rng(12345)
@@ -484,7 +491,7 @@ def test_groupby_reduce_axis_subset_against_numpy(func, axis, engine):
484491
kwargs.pop("engine")
485492
expected_npg, _ = groupby_reduce(array, by, **kwargs, engine="numpy")
486493
assert_equal(expected_npg, expected)
487-
assert_equal(actual, expected)
494+
assert_equal(actual, expected, tolerance)
488495

489496

490497
@pytest.mark.parametrize("chunks", [None, (2, 2, 3)])
@@ -1025,3 +1032,14 @@ def grouped_median(group_idx, array, *, axis=-1, size=None, fill_value=None, dty
10251032
method="blockwise",
10261033
)
10271034
assert_equal(expected, actual)
1035+
1036+
1037+
@pytest.mark.parametrize("func", ALL_FUNCS)
1038+
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
1039+
def test_dtype(func, dtype, engine):
1040+
if "arg" in func or func in ["any", "all"]:
1041+
pytest.skip()
1042+
arr = np.ones((4, 12), dtype=dtype)
1043+
labels = np.array(["a", "a", "c", "c", "c", "b", "b", "c", "c", "b", "b", "f"])
1044+
actual, _ = groupby_reduce(arr, labels, func=func, dtype=np.float64)
1045+
assert actual.dtype == np.dtype("float64")

tests/test_xarray.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,10 @@
2424
pass
2525

2626

27+
tolerance64 = {"rtol": 1e-15, "atol": 1e-18}
28+
np.random.seed(123)
29+
30+
2731
@pytest.mark.parametrize("reindex", [None, False, True])
2832
@pytest.mark.parametrize("min_count", [None, 1, 3])
2933
@pytest.mark.parametrize("add_nan", [True, False])
@@ -488,3 +492,76 @@ def test_mixed_grouping(chunk):
488492
fill_value=0,
489493
)
490494
assert (r.sel(v1=[3, 4, 5]) == 0).all().data
495+
496+
497+
@pytest.mark.parametrize("add_nan", [True, False])
498+
@pytest.mark.parametrize("dtype_out", [np.float64, "float64", np.dtype("float64")])
499+
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
500+
@pytest.mark.parametrize("chunk", (True, False))
501+
def test_dtype(add_nan, chunk, dtype, dtype_out, engine):
502+
if chunk and not has_dask:
503+
pytest.skip()
504+
505+
xp = dask.array if chunk else np
506+
data = xp.linspace(0, 1, 48, dtype=dtype).reshape((4, 12))
507+
508+
if add_nan:
509+
data[1, ...] = np.nan
510+
data[0, [0, 2]] = np.nan
511+
512+
arr = xr.DataArray(
513+
data,
514+
dims=("x", "t"),
515+
coords={
516+
"labels": ("t", np.array(["a", "a", "c", "c", "c", "b", "b", "c", "c", "b", "b", "f"]))
517+
},
518+
name="arr",
519+
)
520+
kwargs = dict(func="mean", dtype=dtype_out, engine=engine)
521+
actual = xarray_reduce(arr, "labels", **kwargs)
522+
expected = arr.groupby("labels").mean(dtype="float64")
523+
524+
assert actual.dtype == np.dtype("float64")
525+
assert actual.compute().dtype == np.dtype("float64")
526+
xr.testing.assert_allclose(expected, actual, **tolerance64)
527+
528+
actual = xarray_reduce(arr.to_dataset(), "labels", **kwargs)
529+
expected = arr.to_dataset().groupby("labels").mean(dtype="float64")
530+
531+
assert actual.arr.dtype == np.dtype("float64")
532+
assert actual.compute().arr.dtype == np.dtype("float64")
533+
xr.testing.assert_allclose(expected, actual.transpose("labels", ...), **tolerance64)
534+
535+
536+
@pytest.mark.parametrize("chunk", [True, False])
537+
@pytest.mark.parametrize("use_flox", [True, False])
538+
def test_dtype_accumulation(use_flox, chunk):
539+
if chunk and not has_dask:
540+
pytest.skip()
541+
542+
datetimes = pd.date_range("2010-01", "2015-01", freq="6H", inclusive="left")
543+
samples = 10 + np.cos(2 * np.pi * 0.001 * np.arange(len(datetimes))) * 1
544+
samples += np.random.randn(len(datetimes))
545+
samples = samples.astype("float32")
546+
547+
nan_indices = np.random.default_rng().integers(0, len(samples), size=5_000)
548+
samples[nan_indices] = np.nan
549+
550+
da = xr.DataArray(samples, dims=("time",), coords=[datetimes])
551+
if chunk:
552+
da = da.chunk(time=1024)
553+
554+
gb = da.groupby("time.month")
555+
556+
with xr.set_options(use_flox=use_flox):
557+
expected = gb.reduce(np.nanmean)
558+
actual = gb.mean()
559+
xr.testing.assert_allclose(expected, actual)
560+
assert np.issubdtype(actual.dtype, np.float32)
561+
assert np.issubdtype(actual.compute().dtype, np.float32)
562+
563+
expected = gb.reduce(np.nanmean, dtype="float64")
564+
actual = gb.mean(dtype="float64")
565+
assert np.issubdtype(actual.dtype, np.float64)
566+
assert np.issubdtype(actual.compute().dtype, np.float64)
567+
xr.testing.assert_allclose(expected, actual, **tolerance64)

0 commit comments

Comments
 (0)