Commit 0df596e

Author: Louis Stenger (committed)

    Merge remote-tracking branch 'upstream/main' into feature/fix-interp-docs

    Fixes merge with pydata#6637

2 parents 2030c80 + 4615074, commit 0df596e

14 files changed: +1213 -951 lines changed

.gitignore

Lines changed: 2 additions & 0 deletions

@@ -1,5 +1,7 @@
 *.py[cod]
 __pycache__
+.env
+.venv
 
 # example caches from Hypothesis
 .hypothesis/

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion

@@ -46,7 +46,7 @@ repos:
 #      - id: velin
 #        args: ["--write", "--compact"]
   - repo: https://github.com/pre-commit/mirrors-mypy
-    rev: v0.950
+    rev: v0.960
     hooks:
       - id: mypy
         # Copied from setup.cfg

HOW_TO_RELEASE.md

Lines changed: 1 addition & 1 deletion

@@ -111,4 +111,4 @@ upstream https://github.com/pydata/xarray (push)
 
 As of 2022.03.0, we utilize the [CALVER](https://calver.org/) version system.
 Specifically, we have adopted the pattern `YYYY.MM.X`, where `YYYY` is a 4-digit
-year (e.g. `2022`), `MM` is a 2-digit zero-padded month (e.g. `01` for January), and `X` is the release number (starting at zero at the start of each month and incremented once for each additional release).
+year (e.g. `2022`), `0M` is a 2-digit zero-padded month (e.g. `01` for January), and `X` is the release number (starting at zero at the start of each month and incremented once for each additional release).
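For illustration, a minimal sketch (not part of the diff) of how a `YYYY.0M.X` version string is assembled; the helper name `calver_tag` is hypothetical:

from datetime import date

def calver_tag(release_number: int, today: date) -> str:
    # YYYY: 4-digit year; 0M: 2-digit zero-padded month; X: per-month release counter
    return f"{today.year}.{today.month:02d}.{release_number}"

assert calver_tag(0, date(2022, 3, 2)) == "2022.03.0"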

xarray/backends/api.py

Lines changed: 9 additions & 4 deletions

@@ -56,6 +56,8 @@
     T_NetcdfEngine,
     Literal["pydap", "pynio", "pseudonetcdf", "cfgrib", "zarr"],
     Type[BackendEntrypoint],
+    str,  # no nice typing support for custom backends
+    None,
 ]
 T_Chunks = Union[int, dict[Any, Any], Literal["auto"], None]
 T_NetcdfTypes = Literal[
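Pieced together from this hunk's context, the widened engine alias reads roughly as below. This is a sketch: the alias name T_Engine and the exact shape of T_NetcdfEngine are assumptions, since the hunk shows only the union members.

from typing import Literal, Type, Union
from xarray.backends import BackendEntrypoint

T_NetcdfEngine = Literal["netcdf4", "scipy", "h5netcdf"]  # assumed definition

T_Engine = Union[  # alias name assumed from context
    T_NetcdfEngine,
    Literal["pydap", "pynio", "pseudonetcdf", "cfgrib", "zarr"],
    Type[BackendEntrypoint],
    str,  # no nice typing support for custom backends
    None,
]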
@@ -392,7 +394,8 @@ def open_dataset(
         scipy.io.netcdf (only netCDF3 supported). Byte-strings or file-like
         objects are opened by scipy.io.netcdf (netCDF3) or h5py (netCDF4/HDF).
     engine : {"netcdf4", "scipy", "pydap", "h5netcdf", "pynio", "cfgrib", \
-        "pseudonetcdf", "zarr"} or subclass of xarray.backends.BackendEntrypoint, optional
+        "pseudonetcdf", "zarr", None}, installed backend \
+        or subclass of xarray.backends.BackendEntrypoint, optional
         Engine to use when reading files. If not provided, the default engine
         is chosen based on available dependencies, with a preference for
         "netcdf4". A custom backend class (a subclass of ``BackendEntrypoint``)
@@ -579,7 +582,8 @@ def open_dataarray(
         scipy.io.netcdf (only netCDF3 supported). Byte-strings or file-like
         objects are opened by scipy.io.netcdf (netCDF3) or h5py (netCDF4/HDF).
     engine : {"netcdf4", "scipy", "pydap", "h5netcdf", "pynio", "cfgrib", \
-        "pseudonetcdf", "zarr"}, optional
+        "pseudonetcdf", "zarr", None}, installed backend \
+        or subclass of xarray.backends.BackendEntrypoint, optional
         Engine to use when reading files. If not provided, the default engine
         is chosen based on available dependencies, with a preference for
         "netcdf4".
@@ -804,8 +808,9 @@ def open_mfdataset(
         If provided, call this function on each dataset prior to concatenation.
         You can find the file-name from which each dataset was loaded in
         ``ds.encoding["source"]``.
-    engine : {"netcdf4", "scipy", "pydap", "h5netcdf", "pynio", "cfgrib", "zarr"}, \
-        optional
+    engine : {"netcdf4", "scipy", "pydap", "h5netcdf", "pynio", "cfgrib", \
+        "pseudonetcdf", "zarr", None}, installed backend \
+        or subclass of xarray.backends.BackendEntrypoint, optional
         Engine to use when reading files. If not provided, the default engine
         is chosen based on available dependencies, with a preference for
         "netcdf4".

xarray/core/alignment.py

Lines changed: 17 additions & 11 deletions

@@ -16,6 +16,7 @@
     Tuple,
     Type,
     TypeVar,
+    cast,
 )
 
 import numpy as np
@@ -30,7 +31,7 @@
 if TYPE_CHECKING:
     from .dataarray import DataArray
     from .dataset import Dataset
-    from .types import JoinOptions
+    from .types import JoinOptions, T_DataArray, T_DataArrayOrSet, T_Dataset
 
 DataAlignable = TypeVar("DataAlignable", bound=DataWithCoords)
@@ -559,7 +560,7 @@ def align(self) -> None:
 def align(
     *objects: DataAlignable,
     join: JoinOptions = "inner",
-    copy=True,
+    copy: bool = True,
     indexes=None,
     exclude=frozenset(),
     fill_value=dtypes.NA,
@@ -592,7 +593,7 @@ def align(
         those of the first object with that dimension. Indexes for the same
         dimension must have the same size in all objects.
 
-    copy : bool, optional
+    copy : bool, default: True
         If ``copy=True``, data in the return values is always copied. If
         ``copy=False`` and reindexing is unnecessary, or can be performed with
         only slice operations, then the output may share memory with the input.
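A short usage sketch of the behavior documented here, with hypothetical coordinates: with copy=True the results never share memory with the inputs, and, as the next hunk's Returns fix spells out, the result is a tuple with one aligned object per input.

import xarray as xr

a = xr.DataArray([1, 2, 3], dims="x", coords={"x": [10, 20, 30]})
b = xr.DataArray([4, 5, 6], dims="x", coords={"x": [20, 30, 40]})

# join="inner" keeps only the shared x labels; the result is a tuple.
a2, b2 = xr.align(a, b, join="inner", copy=True)
assert list(a2.x.values) == [20, 30]
assert list(b2.x.values) == [20, 30]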
@@ -609,7 +610,7 @@ def align(
 
     Returns
     -------
-    aligned : DataArray or Dataset
+    aligned : tuple of DataArray or Dataset
         Tuple of objects with the same type as `*objects` with aligned
         coordinates.
@@ -935,7 +936,9 @@ def _get_broadcast_dims_map_common_coords(args, exclude):
     return dims_map, common_coords
 
 
-def _broadcast_helper(arg, exclude, dims_map, common_coords):
+def _broadcast_helper(
+    arg: T_DataArrayOrSet, exclude, dims_map, common_coords
+) -> T_DataArrayOrSet:
 
     from .dataarray import DataArray
     from .dataset import Dataset
@@ -950,22 +953,25 @@ def _set_dims(var):
 
         return var.set_dims(var_dims_map)
 
-    def _broadcast_array(array):
+    def _broadcast_array(array: T_DataArray) -> T_DataArray:
         data = _set_dims(array.variable)
         coords = dict(array.coords)
         coords.update(common_coords)
-        return DataArray(data, coords, data.dims, name=array.name, attrs=array.attrs)
+        return array.__class__(
+            data, coords, data.dims, name=array.name, attrs=array.attrs
+        )
 
-    def _broadcast_dataset(ds):
+    def _broadcast_dataset(ds: T_Dataset) -> T_Dataset:
         data_vars = {k: _set_dims(ds.variables[k]) for k in ds.data_vars}
         coords = dict(ds.coords)
         coords.update(common_coords)
-        return Dataset(data_vars, coords, ds.attrs)
+        return ds.__class__(data_vars, coords, ds.attrs)
 
+    # remove casts once https://github.com/python/mypy/issues/12800 is resolved
     if isinstance(arg, DataArray):
-        return _broadcast_array(arg)
+        return cast("T_DataArrayOrSet", _broadcast_array(arg))
     elif isinstance(arg, Dataset):
-        return _broadcast_dataset(arg)
+        return cast("T_DataArrayOrSet", _broadcast_dataset(arg))
     else:
         raise ValueError("all input must be Dataset or DataArray objects")
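The casts above work around python/mypy#12800: isinstance narrows arg to DataArray or Dataset, but mypy cannot relate the narrowed branch type back to the TypeVar-typed return. A minimal self-contained sketch of the same pattern, with hypothetical classes A and B standing in for DataArray and Dataset:

from typing import TypeVar, Union, cast

class A: ...
class B: ...

T_AOrB = TypeVar("T_AOrB", bound=Union[A, B])  # mirrors T_DataArrayOrSet

def helper(arg: T_AOrB) -> T_AOrB:
    def handle_a(a: A) -> A:
        return a.__class__()  # preserves subclasses, like array.__class__ above

    def handle_b(b: B) -> B:
        return b.__class__()

    # isinstance narrows arg, but mypy cannot connect the branch type back
    # to T_AOrB; cast restores the declared return type (see mypy#12800).
    if isinstance(arg, A):
        return cast("T_AOrB", handle_a(arg))
    elif isinstance(arg, B):
        return cast("T_AOrB", handle_b(arg))
    else:
        raise ValueError("arg must be an A or B instance")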
