Skip to content

Commit 6c5840e

Browse files
Improve performance for backend datetime handling (#7374)
* Add typing to conventions.py use fastpath on recreated variables * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add fastpath * Add fastpath * add typing * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update times.py * Think mypy found an error here. * Update variables.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update times.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Tuple for mypy38 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Variable is already input to this function Function is about 18% faster without this check. * Don't import DataArray until necessary. Reduces time from 450ms -> 290ms from my open_dataset testing. * Update conventions.py * Only create a Variable if a change has been made. * Don't recreate a unmodified variable * Add ASV test * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update dataset_io.py * Update dataset_io.py * Update dataset_io.py * return early instead of new variables * Update conventions.py * Update conventions.py Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 4f3128b commit 6c5840e

File tree

4 files changed

+214
-149
lines changed

4 files changed

+214
-149
lines changed

xarray/coding/times.py

+74-51
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import warnings
55
from datetime import datetime, timedelta
66
from functools import partial
7-
from typing import TYPE_CHECKING
7+
from typing import TYPE_CHECKING, Callable, Hashable, Union
88

99
import numpy as np
1010
import pandas as pd
@@ -33,6 +33,8 @@
3333
if TYPE_CHECKING:
3434
from xarray.core.types import CFCalendar
3535

36+
T_Name = Union[Hashable, None]
37+
3638
# standard calendars recognized by cftime
3739
_STANDARD_CALENDARS = {"standard", "gregorian", "proleptic_gregorian"}
3840

@@ -79,7 +81,7 @@
7981
)
8082

8183

82-
def _is_standard_calendar(calendar):
84+
def _is_standard_calendar(calendar: str) -> bool:
8385
return calendar.lower() in _STANDARD_CALENDARS
8486

8587

@@ -103,7 +105,7 @@ def _is_numpy_compatible_time_range(times):
103105
return True
104106

105107

106-
def _netcdf_to_numpy_timeunit(units):
108+
def _netcdf_to_numpy_timeunit(units: str) -> str:
107109
units = units.lower()
108110
if not units.endswith("s"):
109111
units = f"{units}s"
@@ -118,7 +120,7 @@ def _netcdf_to_numpy_timeunit(units):
118120
}[units]
119121

120122

121-
def _ensure_padded_year(ref_date):
123+
def _ensure_padded_year(ref_date: str) -> str:
122124
# Reference dates without a padded year (e.g. since 1-1-1 or since 2-3-4)
123125
# are ambiguous (is it YMD or DMY?). This can lead to some very odd
124126
# behaviour e.g. pandas (via dateutil) passes '1-1-1 00:00:0.0' as
@@ -152,7 +154,7 @@ def _ensure_padded_year(ref_date):
152154
return ref_date_padded
153155

154156

155-
def _unpack_netcdf_time_units(units):
157+
def _unpack_netcdf_time_units(units: str) -> tuple[str, str]:
156158
# CF datetime units follow the format: "UNIT since DATE"
157159
# this parses out the unit and date allowing for extraneous
158160
# whitespace. It also ensures that the year is padded with zeros
@@ -167,7 +169,9 @@ def _unpack_netcdf_time_units(units):
167169
return delta_units, ref_date
168170

169171

170-
def _decode_cf_datetime_dtype(data, units, calendar, use_cftime):
172+
def _decode_cf_datetime_dtype(
173+
data, units: str, calendar: str, use_cftime: bool | None
174+
) -> np.dtype:
171175
# Verify that at least the first and last date can be decoded
172176
# successfully. Otherwise, tracebacks end up swallowed by
173177
# Dataset.__repr__ when users try to view their lazily decoded array.
@@ -194,7 +198,9 @@ def _decode_cf_datetime_dtype(data, units, calendar, use_cftime):
194198
return dtype
195199

196200

197-
def _decode_datetime_with_cftime(num_dates, units, calendar):
201+
def _decode_datetime_with_cftime(
202+
num_dates: np.ndarray, units: str, calendar: str
203+
) -> np.ndarray:
198204
if cftime is None:
199205
raise ModuleNotFoundError("No module named 'cftime'")
200206
if num_dates.size > 0:
@@ -205,7 +211,9 @@ def _decode_datetime_with_cftime(num_dates, units, calendar):
205211
return np.array([], dtype=object)
206212

207213

208-
def _decode_datetime_with_pandas(flat_num_dates, units, calendar):
214+
def _decode_datetime_with_pandas(
215+
flat_num_dates: np.ndarray, units: str, calendar: str
216+
) -> np.ndarray:
209217
if not _is_standard_calendar(calendar):
210218
raise OutOfBoundsDatetime(
211219
"Cannot decode times from a non-standard calendar, {!r}, using "
@@ -250,7 +258,9 @@ def _decode_datetime_with_pandas(flat_num_dates, units, calendar):
250258
return (pd.to_timedelta(flat_num_dates_ns_int, "ns") + ref_date).values
251259

252260

253-
def decode_cf_datetime(num_dates, units, calendar=None, use_cftime=None):
261+
def decode_cf_datetime(
262+
num_dates, units: str, calendar: str | None = None, use_cftime: bool | None = None
263+
) -> np.ndarray:
254264
"""Given an array of numeric dates in netCDF format, convert it into a
255265
numpy array of date time objects.
256266
@@ -314,7 +324,7 @@ def to_datetime_unboxed(value, **kwargs):
314324
return result
315325

316326

317-
def decode_cf_timedelta(num_timedeltas, units):
327+
def decode_cf_timedelta(num_timedeltas, units: str) -> np.ndarray:
318328
"""Given an array of numeric timedeltas in netCDF format, convert it into a
319329
numpy timedelta64[ns] array.
320330
"""
@@ -324,16 +334,18 @@ def decode_cf_timedelta(num_timedeltas, units):
324334
return result.reshape(num_timedeltas.shape)
325335

326336

327-
def _unit_timedelta_cftime(units):
337+
def _unit_timedelta_cftime(units: str) -> timedelta:
328338
return timedelta(microseconds=_US_PER_TIME_DELTA[units])
329339

330340

331-
def _unit_timedelta_numpy(units):
341+
def _unit_timedelta_numpy(units: str) -> np.timedelta64:
332342
numpy_units = _netcdf_to_numpy_timeunit(units)
333343
return np.timedelta64(_NS_PER_TIME_DELTA[numpy_units], "ns")
334344

335345

336-
def _infer_time_units_from_diff(unique_timedeltas):
346+
def _infer_time_units_from_diff(unique_timedeltas) -> str:
347+
unit_timedelta: Callable[[str], timedelta] | Callable[[str], np.timedelta64]
348+
zero_timedelta: timedelta | np.timedelta64
337349
if unique_timedeltas.dtype == np.dtype("O"):
338350
time_units = _NETCDF_TIME_UNITS_CFTIME
339351
unit_timedelta = _unit_timedelta_cftime
@@ -374,7 +386,7 @@ def infer_calendar_name(dates) -> CFCalendar:
374386
raise ValueError("Array does not contain datetime objects.")
375387

376388

377-
def infer_datetime_units(dates):
389+
def infer_datetime_units(dates) -> str:
378390
"""Given an array of datetimes, returns a CF compatible time-unit string of
379391
the form "{time_unit} since {date[0]}", where `time_unit` is 'days',
380392
'hours', 'minutes' or 'seconds' (the first one that can evenly divide all
@@ -394,7 +406,7 @@ def infer_datetime_units(dates):
394406
return f"{units} since {reference_date}"
395407

396408

397-
def format_cftime_datetime(date):
409+
def format_cftime_datetime(date) -> str:
398410
"""Converts a cftime.datetime object to a string with the format:
399411
YYYY-MM-DD HH:MM:SS.UUUUUU
400412
"""
@@ -409,7 +421,7 @@ def format_cftime_datetime(date):
409421
)
410422

411423

412-
def infer_timedelta_units(deltas):
424+
def infer_timedelta_units(deltas) -> str:
413425
"""Given an array of timedeltas, returns a CF compatible time-unit from
414426
{'days', 'hours', 'minutes' 'seconds'} (the first one that can evenly
415427
divide all unique time deltas in `deltas`)
@@ -419,7 +431,7 @@ def infer_timedelta_units(deltas):
419431
return _infer_time_units_from_diff(unique_timedeltas)
420432

421433

422-
def cftime_to_nptime(times, raise_on_invalid=True):
434+
def cftime_to_nptime(times, raise_on_invalid: bool = True) -> np.ndarray:
423435
"""Given an array of cftime.datetime objects, return an array of
424436
numpy.datetime64 objects of the same size
425437
@@ -448,7 +460,7 @@ def cftime_to_nptime(times, raise_on_invalid=True):
448460
return new
449461

450462

451-
def convert_times(times, date_type, raise_on_invalid=True):
463+
def convert_times(times, date_type, raise_on_invalid: bool = True) -> np.ndarray:
452464
"""Given an array of datetimes, return the same dates in another cftime or numpy date type.
453465
454466
Useful to convert between calendars in numpy and cftime or between cftime calendars.
@@ -529,7 +541,9 @@ def convert_time_or_go_back(date, date_type):
529541
)
530542

531543

532-
def _should_cftime_be_used(source, target_calendar, use_cftime):
544+
def _should_cftime_be_used(
545+
source, target_calendar: str, use_cftime: bool | None
546+
) -> bool:
533547
"""Return whether conversion of the source to the target calendar should
534548
result in a cftime-backed array.
535549
@@ -542,7 +556,7 @@ def _should_cftime_be_used(source, target_calendar, use_cftime):
542556
if _is_standard_calendar(target_calendar):
543557
if _is_numpy_compatible_time_range(source):
544558
# Conversion is possible with pandas, force False if it was None
545-
use_cftime = False
559+
return False
546560
elif use_cftime is False:
547561
raise ValueError(
548562
"Source time range is not valid for numpy datetimes. Try using `use_cftime=True`."
@@ -551,12 +565,10 @@ def _should_cftime_be_used(source, target_calendar, use_cftime):
551565
raise ValueError(
552566
f"Calendar '{target_calendar}' is only valid with cftime. Try using `use_cftime=True`."
553567
)
554-
else:
555-
use_cftime = True
556-
return use_cftime
568+
return True
557569

558570

559-
def _cleanup_netcdf_time_units(units):
571+
def _cleanup_netcdf_time_units(units: str) -> str:
560572
delta, ref_date = _unpack_netcdf_time_units(units)
561573
try:
562574
units = f"{delta} since {format_timestamp(ref_date)}"
@@ -567,7 +579,7 @@ def _cleanup_netcdf_time_units(units):
567579
return units
568580

569581

570-
def _encode_datetime_with_cftime(dates, units, calendar):
582+
def _encode_datetime_with_cftime(dates, units: str, calendar: str) -> np.ndarray:
571583
"""Fallback method for encoding dates using cftime.
572584
573585
This method is more flexible than xarray's parsing using datetime64[ns]
@@ -597,14 +609,16 @@ def encode_datetime(d):
597609
return np.array([encode_datetime(d) for d in dates.ravel()]).reshape(dates.shape)
598610

599611

600-
def cast_to_int_if_safe(num):
612+
def cast_to_int_if_safe(num) -> np.ndarray:
601613
int_num = np.asarray(num, dtype=np.int64)
602614
if (num == int_num).all():
603615
num = int_num
604616
return num
605617

606618

607-
def encode_cf_datetime(dates, units=None, calendar=None):
619+
def encode_cf_datetime(
620+
dates, units: str | None = None, calendar: str | None = None
621+
) -> tuple[np.ndarray, str, str]:
608622
"""Given an array of datetime objects, returns the tuple `(num, units,
609623
calendar)` suitable for a CF compliant time variable.
610624
@@ -624,7 +638,7 @@ def encode_cf_datetime(dates, units=None, calendar=None):
624638
if calendar is None:
625639
calendar = infer_calendar_name(dates)
626640

627-
delta, ref_date = _unpack_netcdf_time_units(units)
641+
delta, _ref_date = _unpack_netcdf_time_units(units)
628642
try:
629643
if not _is_standard_calendar(calendar) or dates.dtype.kind == "O":
630644
# parse with cftime instead
@@ -633,7 +647,7 @@ def encode_cf_datetime(dates, units=None, calendar=None):
633647

634648
delta_units = _netcdf_to_numpy_timeunit(delta)
635649
time_delta = np.timedelta64(1, delta_units).astype("timedelta64[ns]")
636-
ref_date = pd.Timestamp(ref_date)
650+
ref_date = pd.Timestamp(_ref_date)
637651

638652
# If the ref_date Timestamp is timezone-aware, convert to UTC and
639653
# make it timezone-naive (GH 2649).
@@ -661,7 +675,7 @@ def encode_cf_datetime(dates, units=None, calendar=None):
661675
return (num, units, calendar)
662676

663677

664-
def encode_cf_timedelta(timedeltas, units=None):
678+
def encode_cf_timedelta(timedeltas, units: str | None = None) -> tuple[np.ndarray, str]:
665679
if units is None:
666680
units = infer_timedelta_units(timedeltas)
667681

@@ -673,27 +687,30 @@ def encode_cf_timedelta(timedeltas, units=None):
673687

674688

675689
class CFDatetimeCoder(VariableCoder):
676-
def __init__(self, use_cftime=None):
690+
def __init__(self, use_cftime: bool | None = None) -> None:
677691
self.use_cftime = use_cftime
678692

679-
def encode(self, variable, name=None):
680-
dims, data, attrs, encoding = unpack_for_encoding(variable)
681-
if np.issubdtype(data.dtype, np.datetime64) or contains_cftime_datetimes(
682-
variable
683-
):
693+
def encode(self, variable: Variable, name: T_Name = None) -> Variable:
694+
if np.issubdtype(
695+
variable.data.dtype, np.datetime64
696+
) or contains_cftime_datetimes(variable):
697+
dims, data, attrs, encoding = unpack_for_encoding(variable)
698+
684699
(data, units, calendar) = encode_cf_datetime(
685700
data, encoding.pop("units", None), encoding.pop("calendar", None)
686701
)
687702
safe_setitem(attrs, "units", units, name=name)
688703
safe_setitem(attrs, "calendar", calendar, name=name)
689704

690-
return Variable(dims, data, attrs, encoding)
691-
692-
def decode(self, variable, name=None):
693-
dims, data, attrs, encoding = unpack_for_decoding(variable)
705+
return Variable(dims, data, attrs, encoding, fastpath=True)
706+
else:
707+
return variable
694708

695-
units = attrs.get("units")
709+
def decode(self, variable: Variable, name: T_Name = None) -> Variable:
710+
units = variable.attrs.get("units", None)
696711
if isinstance(units, str) and "since" in units:
712+
dims, data, attrs, encoding = unpack_for_decoding(variable)
713+
697714
units = pop_to(attrs, encoding, "units")
698715
calendar = pop_to(attrs, encoding, "calendar")
699716
dtype = _decode_cf_datetime_dtype(data, units, calendar, self.use_cftime)
@@ -705,27 +722,33 @@ def decode(self, variable, name=None):
705722
)
706723
data = lazy_elemwise_func(data, transform, dtype)
707724

708-
return Variable(dims, data, attrs, encoding)
725+
return Variable(dims, data, attrs, encoding, fastpath=True)
726+
else:
727+
return variable
709728

710729

711730
class CFTimedeltaCoder(VariableCoder):
712-
def encode(self, variable, name=None):
713-
dims, data, attrs, encoding = unpack_for_encoding(variable)
731+
def encode(self, variable: Variable, name: T_Name = None) -> Variable:
732+
if np.issubdtype(variable.data.dtype, np.timedelta64):
733+
dims, data, attrs, encoding = unpack_for_encoding(variable)
714734

715-
if np.issubdtype(data.dtype, np.timedelta64):
716735
data, units = encode_cf_timedelta(data, encoding.pop("units", None))
717736
safe_setitem(attrs, "units", units, name=name)
718737

719-
return Variable(dims, data, attrs, encoding)
720-
721-
def decode(self, variable, name=None):
722-
dims, data, attrs, encoding = unpack_for_decoding(variable)
738+
return Variable(dims, data, attrs, encoding, fastpath=True)
739+
else:
740+
return variable
723741

724-
units = attrs.get("units")
742+
def decode(self, variable: Variable, name: T_Name = None) -> Variable:
743+
units = variable.attrs.get("units", None)
725744
if isinstance(units, str) and units in TIME_UNITS:
745+
dims, data, attrs, encoding = unpack_for_decoding(variable)
746+
726747
units = pop_to(attrs, encoding, "units")
727748
transform = partial(decode_cf_timedelta, units=units)
728749
dtype = np.dtype("timedelta64[ns]")
729750
data = lazy_elemwise_func(data, transform, dtype=dtype)
730751

731-
return Variable(dims, data, attrs, encoding)
752+
return Variable(dims, data, attrs, encoding, fastpath=True)
753+
else:
754+
return variable

0 commit comments

Comments
 (0)