Allow specifying output dtype #131
Conversation
I don't really ever specify dtypes, so a close review would be very valuable if you have the time.
Thanks for the super quick fix!!! I will try this and get back to you; I need to read the developer docs first for using this branch.
The easiest way would be to clone this repo and then use the GitHub CLI (for example, `gh pr checkout 131` should fetch this branch).
It seems to be working fine, except I would expect A == B below: the output dtype changes, and the values don't match the ground truth. The rest of them match, which is excellent!

```python
ground_truth_fp64 = da.groupby("time.month").apply(np.mean, dtype='float64').compute()
ground_truth_fp32 = da.groupby("time.month").apply(np.mean).compute()  # (A)
flox_as_is = da.groupby("time.month").mean().compute()  # (B)
dtype_in_mean = da.groupby("time.month").mean(dtype="float64").compute()
cast_input_dtype = da.astype("float64").groupby("time.month").mean().compute()
cast_and_in_mean = da.astype("float64").groupby("time.month").mean(dtype="float64").compute()
```

**ground truth fp64**
```
input dtype: float32 output_dtype float64
[6.5968835698 7.7664068452 7.1014096086 8.2660835249 8.4281900699
8.3720580059 6.9553840001 7.0804730689 7.0591219326 6.8767603768
6.9561013470 7.1530014089]
```

**ground truth fp32** (A)
```
input dtype: float32 output_dtype float32
[6.5968837738 7.7664070129 7.1014094353 8.2660837173 8.4281892776
8.3720579147 6.9553837776 7.0804729462 7.0591216087 6.8767600060
6.9561014175 7.1530008316]
```

**As is Flox groupby mean()** (B)
```
input dtype: float32 output_dtype float64
[6.5968835755 7.7664066341 7.1014098027 8.2660843461 8.4281902792
8.3720576534 6.9553839571 7.0804727753 7.0591213650 6.8767606127
6.9561012551 7.1530014440]
```

**Flox with dtype in mean()**
```
input dtype: float32 output_dtype float64
[6.5968835698 7.7664068452 7.1014096086 8.2660835249 8.4281900699
8.3720580059 6.9553840001 7.0804730689 7.0591219326 6.8767603768
6.9561013470 7.1530014089]
```

**Flox with input cast to float64, mean() as is (also ground truth)**
```
input dtype: float64 output_dtype float64
[6.5968835698 7.7664068452 7.1014096086 8.2660835249 8.4281900699
8.3720580059 6.9553840001 7.0804730689 7.0591219326 6.8767603768
6.9561013470 7.1530014089]
```

**Flox with input cast to float64, mean() also with dtype**
```
input dtype: float64 output_dtype float64
[6.5968835698 7.7664068452 7.1014096086 8.2660835249 8.4281900699
8.3720580059 6.9553840001 7.0804730689 7.0591219326 6.8767603768
6.9561013470 7.1530014089]
```
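For reference, here is a minimal sketch of the behaviour being asked for above: a grouped mean that accumulates in float64 for accuracy but returns the input dtype unless the caller explicitly requests another one. This is a hypothetical illustration (the function name and signature are made up), not flox's implementation.

```python
import numpy as np

def grouped_mean(values, groups, dtype=None):
    # Hypothetical sketch (not flox's code): accumulate in float64 for accuracy,
    # but return the input dtype unless an explicit output dtype is requested.
    out_dtype = np.dtype(dtype) if dtype is not None else values.dtype
    labels = np.unique(groups)
    result = np.empty(len(labels), dtype=out_dtype)
    for i, label in enumerate(labels):
        result[i] = np.nanmean(values[groups == label], dtype="float64")
    return result

vals = np.array([1.0, 2.0, np.nan, 4.0], dtype="float32")
grps = np.array([0, 0, 1, 1])
print(grouped_mean(vals, grps).dtype)                   # float32 in -> float32 out
print(grouped_mean(vals, grps, dtype="float64").dtype)  # opt in to a wider result
```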
Mostly a few typing suggestions. I don't have a lot of experience either with what should happen here, beyond the output being the correct dtype.
Thanks for all your suggestions. Now tests are failing only for `engine="flox"` and input dtype `float64`.

I'm not really sure why; probably a small roundoff difference. @tasansal does your sample dataset contain NaNs? The promotion of the output to float64 in (B) is funny.
Yes, it has NaNs. The normal Xarray mean drops them and averages to float32. I also have Bottleneck and numpy_groupies installed; could that have anything to do with it? The precision is a bit strange, I would expect it to match the numpy and `flox=False` versions of Xarray. Let me know if I can help any further!
Maybe `assert_equal` is too accurate? `np.testing.assert_allclose(actual, expected, rtol=0, atol=np.finfo(actual.dtype).eps)`
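To illustrate the suggestion (a standalone sketch, not flox's test code): `np.finfo(dtype).eps` scales the absolute tolerance with the result's precision, so float64 comparisons stay much stricter than float32 ones while still tolerating last-digit roundoff.

```python
import numpy as np

# Machine epsilon depends on the dtype, so the tolerance adapts to the result:
print(np.finfo(np.float32).eps)  # ~1.19e-07
print(np.finfo(np.float64).eps)  # ~2.22e-16

actual = np.float64(0.1) + np.float64(0.2)   # 0.30000000000000004
expected = np.float64(0.3)
# Exact equality (assert_equal / np.equal) fails here, but the dtype-scaled
# absolute tolerance accepts the one-ULP roundoff difference.
np.testing.assert_allclose(actual, expected, rtol=0, atol=np.finfo(actual.dtype).eps)
```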
Yeah, allclose is what's needed. @tasansal I can't reproduce your results even after adding NaNs. Can you try to construct a synthetic example that shows the problem? Does passing
@dcherian here you go. I am calling Xarray without Flox "Vanilla Xarray," FYI. Conclusions:
Step 1:
```python
import xarray as xr
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

np.set_printoptions(precision=18)

# Tolerances for fp64
# fp64 can have between 15-18 decimal digits of accuracy
RTOL = 1e-15
ATOL = 1e-18
```

Step 2: Make synthetic data with NaNs. The synthetic data is cast to float32, as in the real-world example.
```python
datetimes = pd.date_range("2010-01", "2015-01", freq="6H", inclusive="left")
samples = 10 + np.cos(2 * np.pi * 0.001 * np.arange(len(datetimes))) * 1
samples += np.random.randn(len(datetimes))
samples = samples.astype('float32')

nan_indices = np.random.default_rng().integers(0, len(samples), size=5_000)
samples[nan_indices] = np.nan
```

Step 3: Plot
```python
plt.figure(figsize=(15, 3))
line = plt.plot(datetimes, samples)
```

Step 4: Calculate ground truth with NumPy
```python
ground_truth_fp32 = []
ground_truth_fp64 = []
cast_before_fp64 = []

for month in range(1, 13):
    month_mask = datetimes.month == month
    month_data = samples[month_mask]
    ground_truth_fp32.append(np.nanmean(month_data))
    ground_truth_fp64.append(np.nanmean(month_data, dtype="float64"))
    cast_before_fp64.append(np.nanmean(month_data.astype("float64")))

ground_truth_fp32 = np.asarray(ground_truth_fp32)
ground_truth_fp64 = np.asarray(ground_truth_fp64)
cast_before_fp64 = np.asarray(cast_before_fp64)

print("fp32 ground truth:\n", ground_truth_fp32)
print("fp64 ground truth:\n", ground_truth_fp64)
print("fp64 casted ground truth:\n", cast_before_fp64)
print("cast equals dtype arg:", np.equal(ground_truth_fp64, cast_before_fp64).all())
```

```
fp32 ground truth:
[10.1700325 10.009565 10.092217 9.943765 9.769657 9.922329
9.994079 9.99988 10.225714 10.153132 10.065705 9.93171 ]
fp64 ground truth:
[10.170032398493728 10.009565626110946 10.092218036775465
9.94376481957987 9.769656868304237 9.922329294098008
9.994078634636194 9.999879809414468 10.225713932633004
10.153132471137742 10.065704485859733 9.93170997271171 ]
fp64 casted ground truth:
[10.170032398493728 10.009565626110946 10.092218036775465
9.94376481957987 9.769656868304237 9.922329294098008
9.994078634636194 9.999879809414468 10.225713932633004
10.153132471137742 10.065704485859733 9.93170997271171 ]
cast equals dtype arg: True
```

Step 5: Pandas groupby mean for reference
```python
series = pd.Series(samples, index=datetimes)
pd_mean_fp32 = series.groupby(series.index.month).mean().to_numpy()
pd_mean_fp64 = series.astype("float64").groupby(series.index.month).mean().to_numpy()
print("pandas fp32 all equal:", np.equal(pd_mean_fp32, ground_truth_fp32).all())
print("pandas fp32 all close:", np.allclose(pd_mean_fp32, ground_truth_fp32))
print("pandas fp64 all equal:", np.equal(pd_mean_fp64, ground_truth_fp64).all())
print("pandas fp64 all close:", np.allclose(pd_mean_fp64, ground_truth_fp64, rtol=RTOL, atol=ATOL))
print("fp32 out dtype:", pd_mean_fp32.dtype)
print("fp64 out dtype:", pd_mean_fp64.dtype)
```

```
pandas fp32 all equal: False
pandas fp32 all close: True
pandas fp64 all equal: True
pandas fp64 all close: True
fp32 out dtype: float32
fp64 out dtype: float64
```

Step 6: Xarray + Flox + No Dask
```python
da = xr.DataArray(samples, dims=("time",), coords=[datetimes])
da_mean_nodask_flox_fp32 = da.groupby("time.month").mean().values
da_mean_nodask_flox_fp64 = da.groupby("time.month").mean(dtype="float64").values
print("flox_nodask fp32 all equal:", np.equal(da_mean_nodask_flox_fp32, ground_truth_fp32).all())
print("flox_nodask fp32 all close:", np.allclose(da_mean_nodask_flox_fp32, ground_truth_fp32))
print("flox_nodask fp64 all equal:", np.equal(da_mean_nodask_flox_fp64, ground_truth_fp64).all())
print("flox_nodask fp64 all close:", np.allclose(da_mean_nodask_flox_fp64, ground_truth_fp64, rtol=RTOL, atol=ATOL))
print("fp32 out dtype:", da_mean_nodask_flox_fp32.dtype)
print("fp64 out dtype:", da_mean_nodask_flox_fp64.dtype)
```

```
flox_nodask fp32 all equal: False
flox_nodask fp32 all close: True
flox_nodask fp64 all equal: True
flox_nodask fp64 all close: True
fp32 out dtype: float32
fp64 out dtype: float64
```

Step 7: Xarray + Flox + With Dask (Chunked)
```python
da_mean_dask_flox_fp32 = da.chunk(time=1024).groupby("time.month").mean().values
da_mean_dask_flox_fp64 = da.chunk(time=1024).groupby("time.month").mean(dtype="float64").values
print("flox dask fp32 all equal:", np.equal(da_mean_dask_flox_fp32, ground_truth_fp32).all())
print("flox dask fp32 all close:", np.allclose(da_mean_dask_flox_fp32, ground_truth_fp32))
print("flox dask fp64 all equal:", np.equal(da_mean_dask_flox_fp64, ground_truth_fp64).all())
print("flox dask fp64 all close:", np.allclose(da_mean_dask_flox_fp64, ground_truth_fp64, rtol=RTOL, atol=ATOL))
print("fp32 out dtype:", da_mean_dask_flox_fp32.dtype)
print("fp64 out dtype:", da_mean_dask_flox_fp64.dtype)
```

```
flox dask fp32 all equal: False
flox dask fp32 all close: True
flox dask fp64 all equal: True
flox dask fp64 all close: True
fp32 out dtype: float64 # HERE
fp64 out dtype: float64
```

Now turn off Flox with `xr.set_options(use_flox=False)`.

Step 8: Xarray + No Flox + No Dask
```python
da_mean_nodask_noflox_fp32 = da.groupby("time.month").mean().values
da_mean_nodask_noflox_fp64 = da.groupby("time.month").mean(dtype="float64").values
print("no flox nodask fp32 all equal:", np.equal(da_mean_nodask_noflox_fp32, ground_truth_fp32).all())
print("no flox nodask fp32 all close:", np.allclose(da_mean_nodask_noflox_fp32, ground_truth_fp32))
print("no flox nodask fp64 all equal:", np.equal(da_mean_nodask_noflox_fp64, ground_truth_fp64).all())
print("no flox nodask fp64 all close:", np.allclose(da_mean_nodask_noflox_fp64, ground_truth_fp64, rtol=RTOL, atol=ATOL))
print("fp32 out dtype:", da_mean_nodask_noflox_fp32.dtype)
print("fp64 out dtype:", da_mean_nodask_noflox_fp64.dtype)
```

```
no flox nodask fp32 all equal: True
no flox nodask fp32 all close: True
no flox nodask fp64 all equal: True
no flox nodask fp64 all close: True
fp32 out dtype: float32
fp64 out dtype: float64
```

Step 9: Xarray + No Flox + With Dask (Chunked)
```python
da_mean_dask_noflox_fp32 = da.chunk(time=1024).groupby("time.month").mean().values
da_mean_dask_noflox_fp64 = da.chunk(time=1024).groupby("time.month").mean(dtype="float64").values
print("no flox nodask fp32 all equal:", np.equal(da_mean_dask_noflox_fp32, ground_truth_fp32).all())
print("no flox nodask fp32 all close:", np.allclose(da_mean_dask_noflox_fp32, ground_truth_fp32))
print("no flox nodask fp64 all equal:", np.equal(da_mean_dask_noflox_fp64, ground_truth_fp64).all())
print("no flox nodask fp64 all close:", np.allclose(da_mean_dask_noflox_fp64, ground_truth_fp64, rtol=RTOL, atol=ATOL))
print("fp32 out dtype:", da_mean_dask_noflox_fp32.dtype)
print("fp64 out dtype:", da_mean_dask_noflox_fp64.dtype)
```

```
no flox dask fp32 all equal: False
no flox dask fp32 all close: True
no flox dask fp64 all equal: True
no flox dask fp64 all close: True
fp32 out dtype: float32
fp64 out dtype: float64
```
Thanks @tasansal, this is very valuable! Would you mind writing up your post as a test and sending in a PR? It'd be nice to give you credit, and it would immensely help with tracking down the fix. We can
Sure thing. Do you think the test should be upstreamed to Xarray, or does it belong in Flox?
Good call, Xarray's

One thing to note is that we don't expect dask + Xarray to "match" numpy + Xarray in general (it depends on what you mean by "match"): because dask accumulates over chunks first, you'll get different roundoff error.
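To illustrate that point (a standalone sketch using plain NumPy, not how dask actually schedules the computation): summing each chunk and then combining the partial results rounds differently than a single pass over the data, so the two means typically differ by a few ULPs even though both are valid.

```python
import numpy as np

rng = np.random.default_rng(42)
x = (10 + rng.standard_normal(1_000_000)).astype("float32")

# Single pass over all values, as a purely in-memory reduction would do.
single_pass = x.mean()

# Chunk-wise accumulation: reduce each block, then combine the block results,
# roughly what a chunked (dask-style) reduction does.
chunked = np.mean([chunk.mean() for chunk in np.array_split(x, 16)], dtype="float32")

print(single_pass, chunked)               # usually differ in the last digits
print(np.allclose(single_pass, chunked))  # True: close, but not bitwise equal
```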
Co-authored-by: Illviljan <[email protected]>
This reverts commit 4dab89a.
Ok, this didn't fix the particular dtype issues I'm having. Maybe it's time to merge this then? These tests failed when I added a dtype check, 4dab89a:
From #131 (comment)
@dcherian and @Illviljan this one is still inconsistent. Let me elaborate.

With this PR, most examples in the comment have consistent behavior when you run a mean operation on float32 data. However, the Xarray + Flox + Dask (must be chunked) edge case still returns a float64 result. Note that the inconsistency is in the return dtype, not the values. I suggest this should be fixed before we merge, for consistency across the board.
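A regression test for the dtype part of this might look like the sketch below (hypothetical code, assuming dask is installed and flox is enabled in xarray; it is not the test that was actually added):

```python
import numpy as np
import pandas as pd
import xarray as xr

def test_groupby_mean_preserves_float32_dtype():
    # The grouped mean of float32 data should stay float32, with and without
    # dask chunking, unless an explicit dtype= is requested.
    times = pd.date_range("2010-01", "2011-01", freq="6H", inclusive="left")
    data = np.random.randn(len(times)).astype("float32")
    da = xr.DataArray(data, dims="time", coords={"time": times})

    eager = da.groupby("time.month").mean()
    chunked = da.chunk(time=256).groupby("time.month").mean().compute()

    assert eager.dtype == np.dtype("float32")
    assert chunked.dtype == eager.dtype
```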
* main:
  Update ci-additional.yaml (#167)
  Refactor before redoing cohorts (#164)
  Fix mypy errors in core.py (#150)
  Add link to numpy_groupies (#160)
  Bump codecov/codecov-action from 3.1.0 to 3.1.1 (#159)
  Use math.prod instead of np.prod (#157)
  Remove None output from _get_expected_groups (#152)
  Fix mypy errors in xarray.py, xrutils.py, cache.py (#144)
  Raise error if multiple by's are used with Ellipsis (#149)
  pre-commit autoupdate (#148)
  Add mypy ignores (#146)
  Get pre commit bot to update (#145)
  Remove duplicate examples headers (#147)
  Add ci additional (#143)
  Bump mamba-org/provision-with-micromamba from 12 to 13 (#141)
  Add ASV benchmark CI workflow (#139)
  Fix func count for dtype O with numpy and numba (#138)
* main: Add a dtype check for numpy arrays in assert_equal (#158)
@tasansal should be all good here. Would you mind running your internal test suite and confirming? |
@dcherian, thanks for the awesome fixes! Sorry, just catching up; I will test it and let you know if anything acts up in a few days. Will this be in 0.5.11?
Yes |
* main: (29 commits)
  Major fix to subset_to_blocks (#173)
  Performance improvements for cohorts detection (#172)
  Remove split_out (#170)
  Deprecate resample_reduce (#169)
  More efficient cohorts. (#165)
  Allow specifying output dtype (#131)
  Add a dtype check for numpy arrays in assert_equal (#158)
  Update ci-additional.yaml (#167)
  Refactor before redoing cohorts (#164)
  Fix mypy errors in core.py (#150)
  Add link to numpy_groupies (#160)
  Bump codecov/codecov-action from 3.1.0 to 3.1.1 (#159)
  Use math.prod instead of np.prod (#157)
  Remove None output from _get_expected_groups (#152)
  Fix mypy errors in xarray.py, xrutils.py, cache.py (#144)
  Raise error if multiple by's are used with Ellipsis (#149)
  pre-commit autoupdate (#148)
  Add mypy ignores (#146)
  Get pre commit bot to update (#145)
  Remove duplicate examples headers (#147)
  ...
Closes pydata/xarray#6902
cc @Illviljan @tasansal