Skip to content

Commit cfd8aed

Browse files
committed
Allow specifying output dtype
Closes pydata/xarray#6902
1 parent e405517 commit cfd8aed

File tree

4 files changed

+42
-2
lines changed

4 files changed

+42
-2
lines changed

flox/core.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1364,6 +1364,7 @@ def groupby_reduce(
13641364
isbin: bool = False,
13651365
axis=None,
13661366
fill_value=None,
1367+
dtype=None,
13671368
min_count: int | None = None,
13681369
split_out: int = 1,
13691370
method: str = "map-reduce",
@@ -1566,8 +1567,13 @@ def groupby_reduce(
15661567
# overwrite than when min_count is set
15671568
fill_value = np.nan
15681569

1570+
if dtype is not None and not isinstance(dtype, np.dtype):
1571+
dtype = np.dtype(dtype)
1572+
15691573
kwargs = dict(axis=axis, fill_value=fill_value, engine=engine)
1570-
agg = _initialize_aggregation(func, array.dtype, fill_value, min_count, finalize_kwargs)
1574+
agg = _initialize_aggregation(
1575+
func, array.dtype if dtype is None else dtype, fill_value, min_count, finalize_kwargs
1576+
)
15711577

15721578
if not has_dask:
15731579
results = _reduce_blockwise(

flox/xarray.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ def xarray_reduce(
6060
dim: Hashable = None,
6161
split_out: int = 1,
6262
fill_value=None,
63+
dtype=None,
6364
method: str = "map-reduce",
6465
engine: str = "flox",
6566
keep_attrs: bool | None = True,
@@ -95,6 +96,8 @@ def xarray_reduce(
9596
fill_value
9697
Value used for missing groups in the output i.e. when one of the labels
9798
in ``expected_groups`` is not actually present in ``by``.
99+
dtype : str or numpy.dtype, optional
100+
DType for the output (DataArray only).
98101
method : {"map-reduce", "blockwise", "cohorts", "split-reduce"}, optional
99102
Strategy for reduction of dask arrays only:
100103
* ``"map-reduce"``:
@@ -341,7 +344,9 @@ def wrapper(array, *by, func, skipna, **kwargs):
341344
exclude_dims=set(dim),
342345
output_core_dims=[group_names],
343346
dask="allowed",
344-
dask_gufunc_kwargs=dict(output_sizes=group_sizes),
347+
dask_gufunc_kwargs=dict(
348+
output_sizes=group_sizes, output_dtypes=[dtype] if dtype is not None else None
349+
),
345350
keep_attrs=keep_attrs,
346351
kwargs={
347352
"func": func,
@@ -357,6 +362,7 @@ def wrapper(array, *by, func, skipna, **kwargs):
357362
"expected_groups": tuple(expected_groups),
358363
"isbin": isbin,
359364
"finalize_kwargs": finalize_kwargs,
365+
"dtype": dtype,
360366
},
361367
)
362368

tests/test_core.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1009,3 +1009,14 @@ def grouped_median(group_idx, array, *, axis=-1, size=None, fill_value=None, dty
10091009
method="blockwise",
10101010
)
10111011
assert_equal(expected, actual)
1012+
1013+
1014+
@pytest.mark.parametrize("func", ALL_FUNCS)
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_dtype(func, dtype, engine):
    """Check that ``groupby_reduce`` honours an explicitly requested dtype.

    The input array is created with the parametrized ``dtype`` while the
    reduction is always asked for ``float64`` output; the result must come
    back as ``float64`` in both cases.
    """
    # arg* reductions and any/all are not exercised by this test.
    if "arg" in func or func in ["any", "all"]:
        pytest.skip()
    arr = np.ones((4, 12), dtype=dtype)
    labels = np.array(["a", "a", "c", "c", "c", "b", "b", "c", "c", "b", "b", "f"])
    # Forward the ``engine`` fixture: without it every engine parametrization
    # exercises the same code path (whatever groupby_reduce uses by default).
    actual, _ = groupby_reduce(arr, labels, func=func, dtype=np.float64, engine=engine)
    assert actual.dtype == np.dtype("float64")

tests/test_xarray.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -485,3 +485,20 @@ def test_mixed_grouping(chunk):
485485
fill_value=0,
486486
)
487487
assert (r.sel(v1=[3, 4, 5]) == 0).all().data
488+
489+
490+
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_dtype(dtype, engine):
    """Check that ``xarray_reduce`` honours an explicitly requested dtype.

    The DataArray is created with the parametrized ``dtype`` and reduced with
    ``func="mean"`` while asking for ``float64`` output, first on the
    in-memory array and then on a chunked copy.
    """
    arr = xr.DataArray(
        data=np.ones((4, 12), dtype=dtype),
        dims=("x", "t"),
        coords={
            "labels": ("t", np.array(["a", "a", "c", "c", "c", "b", "b", "c", "c", "b", "b", "f"]))
        },
    )
    # Forward the ``engine`` fixture so each parametrization actually runs
    # against the engine it is named after.
    actual = xarray_reduce(arr, "labels", func="mean", dtype=np.float64, engine=engine)
    assert actual.dtype == np.dtype("float64")

    # Chunked input: the dtype is checked on the returned object and again
    # after ``compute()``.
    actual = xarray_reduce(
        arr.chunk({"x": 1}), arr.labels, func="mean", dtype=np.float64, engine=engine
    )
    assert actual.dtype == np.dtype("float64")
    assert actual.compute().dtype == np.dtype("float64")

0 commit comments

Comments
 (0)