Commit e5150c9

Add proof of concept dask-friendly datetime encoding

1 parent 03ec3cb commit e5150c9

3 files changed: +233 -2 lines changed


xarray/coding/times.py

Lines changed: 86 additions & 2 deletions
@@ -22,6 +22,7 @@
 )
 from xarray.core import indexing
 from xarray.core.common import contains_cftime_datetimes, is_np_datetime_like
+from xarray.core.duck_array_ops import asarray
 from xarray.core.formatting import first_n_items, format_timestamp, last_item
 from xarray.core.pdcompat import nanosecond_precision_timestamp
 from xarray.core.pycompat import is_duck_dask_array
@@ -672,6 +673,43 @@ def encode_cf_datetime(
     units: str | None = None,
     calendar: str | None = None,
     dtype: np.dtype | None = None,
+):
+    dates = asarray(dates)
+    if isinstance(dates, np.ndarray):
+        return _eagerly_encode_cf_datetime(dates, units, calendar, dtype)
+    elif is_duck_dask_array(dates):
+        return _lazily_encode_cf_datetime(dates, units, calendar, dtype)
+
+
+def _cast_to_dtype_safe(num, dtype) -> np.ndarray:
+    cast_num = np.asarray(num, dtype=dtype)
+
+    if np.issubdtype(dtype, np.integer):
+        if not (num == cast_num).all():
+            raise ValueError(
+                f"Not possible to cast all encoded times from dtype {num.dtype!r} "
+                f"to dtype {dtype!r} without changing any of their values. "
+                f"Consider removing the dtype encoding or explicitly switching to "
+                f"a dtype encoding with a higher precision."
+            )
+    else:
+        if np.isinf(cast_num).any():
+            raise OverflowError(
+                f"Not possible to cast encoded times from dtype {num.dtype!r} "
+                f"to dtype {dtype!r} without overflow. Consider removing the "
+                f"dtype encoding or explicitly switching to a dtype encoding "
+                f"with a higher precision."
+            )
+
+    return cast_num
+
+
+def _eagerly_encode_cf_datetime(
+    dates,
+    units: str | None = None,
+    calendar: str | None = None,
+    dtype: np.dtype | None = None,
+    called_via_map_blocks: bool = False,
 ) -> tuple[np.ndarray, str, str]:
     """Given an array of datetime objects, returns the tuple `(num, units,
     calendar)` suitable for a CF compliant time variable.
@@ -731,7 +769,7 @@ def encode_cf_datetime(
                     f"Set encoding['dtype'] to integer dtype to serialize to int64. "
                     f"Set encoding['dtype'] to floating point dtype to silence this warning."
                 )
-            elif np.issubdtype(dtype, np.integer):
+            elif np.issubdtype(dtype, np.integer) and not called_via_map_blocks:
                 new_units = f"{needed_units} since {format_timestamp(ref_date)}"
                 emit_user_level_warning(
                     f"Times can't be serialized faithfully to int64 with requested units {units!r}. "
@@ -752,7 +790,53 @@ def encode_cf_datetime(
         # we already covered for this in pandas-based flow
         num = cast_to_int_if_safe(num)
 
-    return (num, units, calendar)
+    if dtype is not None:
+        num = _cast_to_dtype_safe(num, dtype)
+
+    if called_via_map_blocks:
+        return num
+    else:
+        return (num, units, calendar)
+
+
+def _lazily_encode_cf_datetime(
+    dates,
+    units: str | None = None,
+    calendar: str | None = None,
+    dtype: np.dtype | None = None,
+):
+    import dask.array
+
+    if calendar is None:
+        # This will only trigger minor compute if dates is an object dtype array.
+        calendar = infer_calendar_name(dates)
+
+    if units is None and dtype is None:
+        if dates.dtype == "O":
+            units = "microseconds since 1970-01-01"
+            dtype = np.dtype("int64")
+        else:
+            units = "nanoseconds since 1970-01-01"
+            dtype = np.dtype("int64")
+
+    if units is None or dtype is None:
+        raise ValueError(
+            f"When encoding chunked arrays of datetime values, both the units and "
+            f"dtype must be prescribed or both must be unprescribed. Prescribing "
+            f"only one or the other is not currently supported. Got a units "
+            f"encoding of {units} and a dtype encoding of {dtype}."
+        )
+
+    num = dask.array.map_blocks(
+        _eagerly_encode_cf_datetime,
+        dates,
+        units,
+        calendar,
+        dtype,
+        called_via_map_blocks=True,
+        dtype=dtype,
+    )
+    return num, units, calendar
 
 
 def encode_cf_timedelta(
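
For orientation, here is a rough usage sketch of the dispatch added above, pieced together from the tests in this commit: a dask-backed datetime array is routed to _lazily_encode_cf_datetime, which wraps the eager encoder in dask.array.map_blocks so each chunk is encoded, and safely cast, only on compute. The sample values, units, and dtype below are illustrative, not part of the commit.

# Illustrative sketch only, not part of this commit. Assumes xarray with this
# change applied, plus numpy, pandas, and dask installed.
import dask.array
import numpy as np
import pandas as pd

from xarray.coding.times import encode_cf_datetime

# A chunked datetime64 array; with prescribed units and dtype the result stays lazy.
times = dask.array.from_array(pd.date_range("2000-01-01", periods=4), chunks=2)
num, units, calendar = encode_cf_datetime(
    times, units="days since 2000-01-01", dtype=np.dtype("int32")
)

print(units, calendar)  # days since 2000-01-01 proleptic_gregorian
print(num.compute())    # [0 1 2 3], cast to int32 chunk-by-chunk via map_blocks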

xarray/tests/test_backends.py

Lines changed: 9 additions & 0 deletions
@@ -2809,6 +2809,15 @@ def test_attributes(self, obj) -> None:
         with pytest.raises(TypeError, match=r"Invalid attribute in Dataset.attrs."):
             ds.to_zarr(store_target, **self.version_kwargs)
 
+    @requires_dask
+    def test_chunked_datetime64(self) -> None:
+        # Copied from @malmans2's PR #8253
+        original = create_test_data().astype("datetime64[ns]").chunk(1)
+        with self.roundtrip(original, open_kwargs={"chunks": {}}) as actual:
+            for name, actual_var in actual.variables.items():
+                assert original[name].chunks == actual_var.chunks
+            assert original.chunks == actual.chunks
+
     def test_vectorized_indexing_negative_step(self) -> None:
         if not has_dask:
             pytest.xfail(
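
At the user level this is roughly the behaviour the zarr test above exercises: a chunked datetime64 variable survives a write/read roundtrip with its chunking intact, and the time encoding is now applied per chunk rather than by computing the whole array. A hedged end-to-end sketch; the store path and variable name are made up, and zarr plus dask are assumed to be installed.

# Illustrative sketch only; mirrors the intent of test_chunked_datetime64.
import pandas as pd
import xarray as xr

ds = xr.Dataset({"time": ("x", pd.date_range("2000-01-01", periods=4))}).chunk({"x": 2})
ds.to_zarr("example-times.zarr", mode="w")  # datetimes encoded lazily, chunk by chunk

roundtripped = xr.open_zarr("example-times.zarr", chunks={})
assert roundtripped["time"].chunks == ds["time"].chunks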

xarray/tests/test_coding_times.py

Lines changed: 138 additions & 0 deletions
@@ -16,6 +16,7 @@
     cftime_range,
     coding,
     conventions,
+    date_range,
     decode_cf,
 )
 from xarray.coding.times import (
@@ -30,6 +31,7 @@
 from xarray.coding.variables import SerializationWarning
 from xarray.conventions import _update_bounds_attributes, cf_encoder
 from xarray.core.common import contains_cftime_datetimes
+from xarray.core.pycompat import is_duck_dask_array
 from xarray.testing import assert_equal, assert_identical
 from xarray.tests import (
     FirstElementAccessibleArray,
@@ -1387,3 +1389,139 @@ def test_roundtrip_float_times() -> None:
     assert_identical(var, decoded_var)
     assert decoded_var.encoding["units"] == units
     assert decoded_var.encoding["_FillValue"] == fill_value
+
+
+ENCODE_DATETIME64_VIA_DASK_TESTS = {
+    "pandas-encoding-with-prescribed-units-and-dtype": (
+        "D",
+        "days since 1700-01-01",
+        np.dtype("int32"),
+    ),
+    "mixed-cftime-pandas-encoding-with-prescribed-units-and-dtype": (
+        "252YS",
+        "days since 1700-01-01",
+        np.dtype("int32"),
+    ),
+    "pandas-encoding-with-default-units-and-dtype": ("252YS", None, None),
+}
+
+
+@requires_dask
+@pytest.mark.parametrize(
+    ("freq", "units", "dtype"),
+    ENCODE_DATETIME64_VIA_DASK_TESTS.values(),
+    ids=ENCODE_DATETIME64_VIA_DASK_TESTS.keys(),
+)
+def test_encode_cf_datetime_datetime64_via_dask(freq, units, dtype):
+    import dask.array
+
+    times = pd.date_range(start="1700", freq=freq, periods=3)
+    times = dask.array.from_array(times, chunks=1)
+    encoded_times, encoding_units, encoding_calendar = encode_cf_datetime(
+        times, units, None, dtype
+    )
+
+    assert is_duck_dask_array(encoded_times)
+    assert encoded_times.chunks == times.chunks
+
+    if units is not None and dtype is not None:
+        assert encoding_units == units
+        assert encoded_times.dtype == dtype
+    else:
+        assert encoding_units == "nanoseconds since 1970-01-01"
+        assert encoded_times.dtype == np.dtype("int64")
+
+    assert encoding_calendar == "proleptic_gregorian"
+
+    decoded_times = decode_cf_datetime(encoded_times, encoding_units, encoding_calendar)
+    np.testing.assert_equal(decoded_times, times)
+
+
+@requires_dask
+@pytest.mark.parametrize(
+    ("units", "dtype"), [(None, np.dtype("int32")), ("2000-01-01", None)]
+)
+def test_encode_cf_datetime_via_dask_error(units, dtype):
+    import dask.array
+
+    times = pd.date_range(start="1700", freq="D", periods=3)
+    times = dask.array.from_array(times, chunks=1)
+
+    with pytest.raises(ValueError, match="When encoding chunked arrays"):
+        encode_cf_datetime(times, units, None, dtype)
+
+
+ENCODE_CFTIME_DATETIME_VIA_DASK_TESTS = {
+    "prescribed-units-and-dtype": ("D", "days since 1700-01-01", np.dtype("int32")),
+    "default-units-and-dtype": ("252YS", None, None),
+}
+
+
+@requires_cftime
+@requires_dask
+@pytest.mark.parametrize(
+    "calendar",
+    ["standard", "proleptic_gregorian", "julian", "noleap", "all_leap", "360_day"],
+)
+@pytest.mark.parametrize(
+    ("freq", "units", "dtype"),
+    ENCODE_CFTIME_DATETIME_VIA_DASK_TESTS.values(),
+    ids=ENCODE_CFTIME_DATETIME_VIA_DASK_TESTS.keys(),
+)
+def test_encode_cf_datetime_cftime_datetime_via_dask(calendar, freq, units, dtype):
+    import dask.array
+
+    times = cftime_range(start="1700", freq=freq, periods=3, calendar=calendar)
+    times = dask.array.from_array(times, chunks=1)
+    encoded_times, encoding_units, encoding_calendar = encode_cf_datetime(
+        times, units, None, dtype
+    )
+
+    assert is_duck_dask_array(encoded_times)
+    assert encoded_times.chunks == times.chunks
+
+    if units is not None and dtype is not None:
+        assert encoding_units == units
+        assert encoded_times.dtype == dtype
+    else:
+        assert encoding_units == "microseconds since 1970-01-01"
+        assert encoded_times.dtype == np.int64
+
+    assert encoding_calendar == calendar
+
+    decoded_times = decode_cf_datetime(
+        encoded_times, encoding_units, encoding_calendar, use_cftime=True
+    )
+    np.testing.assert_equal(decoded_times, times)
+
+
+@requires_dask
+@pytest.mark.parametrize(
+    "use_cftime", [False, pytest.param(True, marks=requires_cftime)]
+)
+def test_encode_cf_datetime_via_dask_casting_value_error(use_cftime):
+    import dask.array
+
+    times = date_range(start="2000", freq="12h", periods=3, use_cftime=use_cftime)
+    times = dask.array.from_array(times, chunks=1)
+    units = "days since 2000-01-01"
+    dtype = np.int64
+    encoded_times, *_ = encode_cf_datetime(times, units, None, dtype)
+    with pytest.raises(ValueError, match="Not possible"):
+        encoded_times.compute()
+
+
+@requires_dask
+@pytest.mark.parametrize(
+    "use_cftime", [False, pytest.param(True, marks=requires_cftime)]
+)
def test_encode_cf_datetime_via_dask_casting_overflow_error(use_cftime):
+    import dask.array
+
+    times = date_range(start="1700", freq="252YS", periods=3, use_cftime=use_cftime)
+    times = dask.array.from_array(times, chunks=1)
+    units = "days since 1700-01-01"
+    dtype = np.dtype("float16")
+    encoded_times, *_ = encode_cf_datetime(times, units, None, dtype)
+    with pytest.raises(OverflowError, match="Not possible"):
+        encoded_times.compute()
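
A note on the last two tests: because the dtype-safety checks live in _eagerly_encode_cf_datetime, which dask.array.map_blocks applies per chunk, the ValueError or OverflowError only surfaces when the encoded array is computed, not when encode_cf_datetime builds the graph. Below is a standalone sketch of that deferral pattern using plain dask and a toy checked cast; the checked_cast helper is illustrative and not part of xarray.

# Illustrative sketch only, not xarray code.
import dask.array
import numpy as np


def checked_cast(block, dtype):
    # Toy per-chunk safety check: refuse integer casts that change values.
    cast = block.astype(dtype)
    if np.issubdtype(dtype, np.integer) and not (block == cast).all():
        raise ValueError("Not possible to cast without changing values")
    return cast


x = dask.array.from_array(np.array([0.0, 0.5, 1.0]), chunks=1)
y = dask.array.map_blocks(checked_cast, x, np.dtype("int64"), dtype=np.dtype("int64"))

# Building the graph succeeds; the error only appears once a chunk is computed.
try:
    y.compute()
except ValueError as err:
    print(err)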
