diff --git a/flox/xarray.py b/flox/xarray.py index 15e2ff148..358b57abd 100644 --- a/flox/xarray.py +++ b/flox/xarray.py @@ -366,7 +366,7 @@ def wrapper(array, *by, func, skipna, **kwargs): if all(d not in ds[var].dims for d in dim): actual[var] = ds[var] - for name, expect in zip(group_names, expected_groups): + for name, expect, by_ in zip(group_names, expected_groups, by): # Can't remove this till xarray handles IntervalIndex if isinstance(expect, pd.IntervalIndex): expect = expect.to_numpy() @@ -382,6 +382,8 @@ def wrapper(array, *by, func, skipna, **kwargs): actual = actual.set_coords(levelnames) else: actual[name] = expect + if keep_attrs: + actual[name].attrs = by_.attrs if unindexed_dims: actual = actual.drop_vars(unindexed_dims)