Commit 1fb9050

Fix some mypy issues (pydata#6531)
* Fix some mypy issues. Unfortunately these have crept back in; I'll add a workflow job, since the pre-commit hook is not covering everything on its own.
* [pre-commit.ci] auto fixes from pre-commit.com hooks. For more information, see https://pre-commit.ci
* Add mypy workflow
* [pre-commit.ci] auto fixes from pre-commit.com hooks. For more information, see https://pre-commit.ci
* [pre-commit.ci] auto fixes from pre-commit.com hooks. For more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 73a9932 commit 1fb9050

File tree

7 files changed (+84, -57 lines):

.github/workflows/ci-additional.yaml
xarray/core/dataarray.py
xarray/core/dataset.py
xarray/core/indexing.py
xarray/core/utils.py
xarray/tests/test_coding_times.py
xarray/tests/test_formatting.py

.github/workflows/ci-additional.yaml

Lines changed: 35 additions & 0 deletions
@@ -151,6 +151,41 @@ jobs:
         run: |
           python -m pytest --doctest-modules xarray --ignore xarray/tests

+  mypy:
+    name: Mypy
+    runs-on: "ubuntu-latest"
+    if: needs.detect-ci-trigger.outputs.triggered == 'false'
+    defaults:
+      run:
+        shell: bash -l {0}
+
+    steps:
+      - uses: actions/checkout@v3
+        with:
+          fetch-depth: 0 # Fetch all history for all branches and tags.
+      - uses: conda-incubator/setup-miniconda@v2
+        with:
+          channels: conda-forge
+          channel-priority: strict
+          mamba-version: "*"
+          activate-environment: xarray-tests
+          auto-update-conda: false
+          python-version: "3.9"
+
+      - name: Install conda dependencies
+        run: |
+          mamba env update -f ci/requirements/environment.yml
+      - name: Install xarray
+        run: |
+          python -m pip install --no-deps -e .
+      - name: Version info
+        run: |
+          conda info -a
+          conda list
+          python xarray/util/print_versions.py
+      - name: Run mypy
+        run: mypy
+
   min-version-policy:
     name: Minimum Version Policy
     runs-on: "ubuntu-latest"

xarray/core/dataarray.py

Lines changed: 3 additions & 2 deletions
@@ -1154,7 +1154,8 @@ def chunk(
             chunks = {}

         if isinstance(chunks, (float, str, int)):
-            chunks = dict.fromkeys(self.dims, chunks)
+            # ignoring type; unclear why it won't accept a Literal into the value.
+            chunks = dict.fromkeys(self.dims, chunks)  # type: ignore
         elif isinstance(chunks, (tuple, list)):
             chunks = dict(zip(self.dims, chunks))
         else:
@@ -4735,7 +4736,7 @@ def curvefit(

     def drop_duplicates(
         self,
-        dim: Hashable | Iterable[Hashable] | ...,
+        dim: Hashable | Iterable[Hashable],
         keep: Literal["first", "last"] | Literal[False] = "first",
     ):
         """Returns a new DataArray with duplicate dimension values removed.

xarray/core/dataset.py

Lines changed: 4 additions & 2 deletions
@@ -7981,7 +7981,7 @@ def _wrapper(Y, *coords_, **kwargs):

     def drop_duplicates(
         self,
-        dim: Hashable | Iterable[Hashable] | ...,
+        dim: Hashable | Iterable[Hashable],
         keep: Literal["first", "last"] | Literal[False] = "first",
     ):
         """Returns a new Dataset with duplicate dimension values removed.
@@ -8005,9 +8005,11 @@ def drop_duplicates(
         DataArray.drop_duplicates
         """
         if isinstance(dim, str):
-            dims = (dim,)
+            dims: Iterable = (dim,)
         elif dim is ...:
             dims = self.dims
+        elif not isinstance(dim, Iterable):
+            dims = [dim]
         else:
             dims = dim
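
The new `elif not isinstance(dim, Iterable)` branch exists because `str` is itself iterable, so a single string name must be caught first, while any other non-iterable Hashable (an `int` dimension name, say) now gets wrapped in a list instead of being iterated. A standalone sketch of the same normalization (`_normalize_dims` is a hypothetical helper, not xarray code):

    from typing import Hashable, Iterable

    def _normalize_dims(dim, all_dims: Iterable[Hashable]) -> Iterable[Hashable]:
        if isinstance(dim, str):
            dims: Iterable = (dim,)  # one named dimension
        elif dim is ...:
            dims = all_dims          # Ellipsis selects every dimension
        elif not isinstance(dim, Iterable):
            dims = [dim]             # non-iterable Hashable, e.g. an int
        else:
            dims = dim               # already an iterable of names
        return dims

    print(list(_normalize_dims("time", ("x", "time"))))  # ['time']
    print(list(_normalize_dims(..., ("x", "time"))))     # ['x', 'time']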

xarray/core/indexing.py

Lines changed: 33 additions & 43 deletions
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import enum
 import functools
 import operator
@@ -6,19 +8,7 @@
 from dataclasses import dataclass, field
 from datetime import timedelta
 from html import escape
-from typing import (
-    TYPE_CHECKING,
-    Any,
-    Callable,
-    Dict,
-    Hashable,
-    Iterable,
-    List,
-    Mapping,
-    Optional,
-    Tuple,
-    Union,
-)
+from typing import TYPE_CHECKING, Any, Callable, Hashable, Iterable, Mapping

 import numpy as np
 import pandas as pd
@@ -59,12 +49,12 @@ class IndexSelResult:
     """

-    dim_indexers: Dict[Any, Any]
-    indexes: Dict[Any, "Index"] = field(default_factory=dict)
-    variables: Dict[Any, "Variable"] = field(default_factory=dict)
-    drop_coords: List[Hashable] = field(default_factory=list)
-    drop_indexes: List[Hashable] = field(default_factory=list)
-    rename_dims: Dict[Any, Hashable] = field(default_factory=dict)
+    dim_indexers: dict[Any, Any]
+    indexes: dict[Any, Index] = field(default_factory=dict)
+    variables: dict[Any, Variable] = field(default_factory=dict)
+    drop_coords: list[Hashable] = field(default_factory=list)
+    drop_indexes: list[Hashable] = field(default_factory=list)
+    rename_dims: dict[Any, Hashable] = field(default_factory=dict)

     def as_tuple(self):
         """Unlike ``dataclasses.astuple``, return a shallow copy.
@@ -82,7 +72,7 @@ def as_tuple(self):
         )


-def merge_sel_results(results: List[IndexSelResult]) -> IndexSelResult:
+def merge_sel_results(results: list[IndexSelResult]) -> IndexSelResult:
     all_dims_count = Counter([dim for res in results for dim in res.dim_indexers])
     duplicate_dims = {k: v for k, v in all_dims_count.items() if v > 1}

@@ -124,13 +114,13 @@ def group_indexers_by_index(
     obj: T_Xarray,
     indexers: Mapping[Any, Any],
     options: Mapping[str, Any],
-) -> List[Tuple["Index", Dict[Any, Any]]]:
+) -> list[tuple[Index, dict[Any, Any]]]:
     """Returns a list of unique indexes and their corresponding indexers."""
     unique_indexes = {}
-    grouped_indexers: Mapping[Union[int, None], Dict] = defaultdict(dict)
+    grouped_indexers: Mapping[int | None, dict] = defaultdict(dict)

     for key, label in indexers.items():
-        index: "Index" = obj.xindexes.get(key, None)
+        index: Index = obj.xindexes.get(key, None)

         if index is not None:
             index_id = id(index)
@@ -787,7 +777,7 @@ class IndexingSupport(enum.Enum):

 def explicit_indexing_adapter(
     key: ExplicitIndexer,
-    shape: Tuple[int, ...],
+    shape: tuple[int, ...],
     indexing_support: IndexingSupport,
     raw_indexing_method: Callable,
 ) -> Any:
@@ -821,8 +811,8 @@ def explicit_indexing_adapter(


 def decompose_indexer(
-    indexer: ExplicitIndexer, shape: Tuple[int, ...], indexing_support: IndexingSupport
-) -> Tuple[ExplicitIndexer, ExplicitIndexer]:
+    indexer: ExplicitIndexer, shape: tuple[int, ...], indexing_support: IndexingSupport
+) -> tuple[ExplicitIndexer, ExplicitIndexer]:
     if isinstance(indexer, VectorizedIndexer):
         return _decompose_vectorized_indexer(indexer, shape, indexing_support)
     if isinstance(indexer, (BasicIndexer, OuterIndexer)):
@@ -848,9 +838,9 @@ def _decompose_slice(key, size):

 def _decompose_vectorized_indexer(
     indexer: VectorizedIndexer,
-    shape: Tuple[int, ...],
+    shape: tuple[int, ...],
     indexing_support: IndexingSupport,
-) -> Tuple[ExplicitIndexer, ExplicitIndexer]:
+) -> tuple[ExplicitIndexer, ExplicitIndexer]:
     """
     Decompose vectorized indexer to the successive two indexers, where the
     first indexer will be used to index backend arrays, while the second one
@@ -929,10 +919,10 @@ def _decompose_vectorized_indexer(


 def _decompose_outer_indexer(
-    indexer: Union[BasicIndexer, OuterIndexer],
-    shape: Tuple[int, ...],
+    indexer: BasicIndexer | OuterIndexer,
+    shape: tuple[int, ...],
     indexing_support: IndexingSupport,
-) -> Tuple[ExplicitIndexer, ExplicitIndexer]:
+) -> tuple[ExplicitIndexer, ExplicitIndexer]:
     """
     Decompose outer indexer to the successive two indexers, where the
     first indexer will be used to index backend arrays, while the second one
@@ -973,7 +963,7 @@ def _decompose_outer_indexer(
         return indexer, BasicIndexer(())
     assert isinstance(indexer, (OuterIndexer, BasicIndexer))

-    backend_indexer: List[Any] = []
+    backend_indexer: list[Any] = []
     np_indexer = []
     # make indexer positive
     pos_indexer: list[np.ndarray | int | np.number] = []
@@ -1395,7 +1385,7 @@ def __array__(self, dtype: DTypeLike = None) -> np.ndarray:
         return np.asarray(array.values, dtype=dtype)

     @property
-    def shape(self) -> Tuple[int]:
+    def shape(self) -> tuple[int]:
         return (len(self.array),)

     def _convert_scalar(self, item):
@@ -1420,13 +1410,13 @@ def _convert_scalar(self, item):

     def __getitem__(
         self, indexer
-    ) -> Union[
-        "PandasIndexingAdapter",
-        NumpyIndexingAdapter,
-        np.ndarray,
-        np.datetime64,
-        np.timedelta64,
-    ]:
+    ) -> (
+        PandasIndexingAdapter
+        | NumpyIndexingAdapter
+        | np.ndarray
+        | np.datetime64
+        | np.timedelta64
+    ):
         key = indexer.tuple
         if isinstance(key, tuple) and len(key) == 1:
             # unpack key so it can index a pandas.Index object (pandas.Index
@@ -1449,7 +1439,7 @@ def transpose(self, order) -> pd.Index:
     def __repr__(self) -> str:
         return f"{type(self).__name__}(array={self.array!r}, dtype={self.dtype!r})"

-    def copy(self, deep: bool = True) -> "PandasIndexingAdapter":
+    def copy(self, deep: bool = True) -> PandasIndexingAdapter:
         # Not the same as just writing `self.array.copy(deep=deep)`, as
         # shallow copies of the underlying numpy.ndarrays become deep ones
         # upon pickling
@@ -1476,7 +1466,7 @@ def __init__(
         self,
         array: pd.MultiIndex,
         dtype: DTypeLike = None,
-        level: Optional[str] = None,
+        level: str | None = None,
     ):
         super().__init__(array, dtype)
         self.level = level
@@ -1535,7 +1525,7 @@ def _repr_html_(self) -> str:
         array_repr = short_numpy_repr(self._get_array_subset())
         return f"<pre>{escape(array_repr)}</pre>"

-    def copy(self, deep: bool = True) -> "PandasMultiIndexingAdapter":
+    def copy(self, deep: bool = True) -> PandasMultiIndexingAdapter:
         # see PandasIndexingAdapter.copy
         array = self.array.copy(deep=True) if deep else self.array
         return type(self)(array, self._dtype, self.level)
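
Most of the churn above is unlocked by the new `from __future__ import annotations`: under PEP 563 annotations are never evaluated at runtime, so the PEP 585 built-in generics (`dict[...]`, `list[...]`, `tuple[...]`) and PEP 604 unions (`X | Y`) can replace `Dict`, `List`, `Tuple`, `Optional`, and `Union` even on Python versions older than 3.9/3.10, and forward references such as `"PandasIndexingAdapter"` no longer need quoting. A minimal demonstration:

    from __future__ import annotations  # PEP 563: annotations stay unevaluated

    def shape_of(sizes: dict[str, int]) -> tuple[int, ...]:  # PEP 585 generics
        return tuple(sizes.values())

    def pick(value: int | None = None) -> int | None:  # PEP 604 union syntax
        return value

    # Without the __future__ import, these annotations would raise TypeError
    # at import time on Python 3.8, where builtins are not subscriptable.
    print(shape_of({"x": 2, "y": 3}))  # (2, 3)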

xarray/core/utils.py

Lines changed: 2 additions & 1 deletion
@@ -237,7 +237,8 @@ def remove_incompatible_items(
         del first_dict[k]


-def is_dict_like(value: Any) -> bool:
+# It's probably OK to give this as a TypeGuard; though it's not perfectly robust.
+def is_dict_like(value: Any) -> TypeGuard[dict]:
     return hasattr(value, "keys") and hasattr(value, "__getitem__")

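The point of `TypeGuard[dict]` over `bool` is narrowing: when the call returns True, mypy treats the checked value as a `dict` in that branch, so mapping operations type-check without casts. The hedge in the new comment is that duck-typed objects with `keys`/`__getitem__` are not literally `dict`s, so the guard slightly over-promises. A sketch (note `TypeGuard` lives in `typing` only from Python 3.10; earlier versions take it from `typing_extensions`):

    from typing import Any

    from typing_extensions import TypeGuard  # plain `typing` on Python >= 3.10

    def is_dict_like(value: Any) -> TypeGuard[dict]:
        return hasattr(value, "keys") and hasattr(value, "__getitem__")

    def describe(value: Any) -> str:
        if is_dict_like(value):
            # mypy narrows `value` to dict here, so .keys() type-checks.
            return f"dict-like with keys {list(value.keys())}"
        return "not dict-like"

    print(describe({"a": 1}))  # dict-like with keys ['a']
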
xarray/tests/test_coding_times.py

Lines changed: 4 additions & 6 deletions
@@ -1007,20 +1007,18 @@ def test_decode_ambiguous_time_warns(calendar) -> None:
     units = "days since 1-1-1"
     expected = num2date(dates, units, calendar=calendar, only_use_cftime_datetimes=True)

-    exp_warn_type = SerializationWarning if is_standard_calendar else None
-
-    with pytest.warns(exp_warn_type) as record:
-        result = decode_cf_datetime(dates, units, calendar=calendar)
-
     if is_standard_calendar:
+        with pytest.warns(SerializationWarning) as record:
+            result = decode_cf_datetime(dates, units, calendar=calendar)
         relevant_warnings = [
             r
             for r in record.list
             if str(r.message).startswith("Ambiguous reference date string: 1-1-1")
         ]
         assert len(relevant_warnings) == 1
     else:
-        assert not record
+        with assert_no_warnings():
+            result = decode_cf_datetime(dates, units, calendar=calendar)

     np.testing.assert_array_equal(result, expected)
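
Splitting the warning check across the `if` removes the old `exp_warn_type`, which was `None` on the no-warning path; `pytest.warns(None)` is deprecated in modern pytest and, presumably the trigger here, does not type-check against pytest's annotations. A reduced sketch of the new pattern, with `maybe_warn` standing in for `decode_cf_datetime`:

    import warnings

    import pytest

    def maybe_warn(flag: bool) -> int:
        if flag:
            warnings.warn("ambiguous", UserWarning)
        return 1

    def check(flag: bool) -> None:
        if flag:
            with pytest.warns(UserWarning, match="ambiguous"):
                result = maybe_warn(flag)
        else:
            with warnings.catch_warnings():
                # xarray's assert_no_warnings() captures this idea: escalate
                # any warning to an error so the test fails if one fires.
                warnings.simplefilter("error")
                result = maybe_warn(flag)
        assert result == 1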

xarray/tests/test_formatting.py

Lines changed: 3 additions & 3 deletions
@@ -480,10 +480,10 @@ def test_short_numpy_repr() -> None:
     assert num_lines < 30

     # threshold option (default: 200)
-    array = np.arange(100)
-    assert "..." not in formatting.short_numpy_repr(array)
+    array2 = np.arange(100)
+    assert "..." not in formatting.short_numpy_repr(array2)
     with xr.set_options(display_values_threshold=10):
-        assert "..." in formatting.short_numpy_repr(array)
+        assert "..." in formatting.short_numpy_repr(array2)


 def test_large_array_repr_length() -> None:
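
The `array` → `array2` rename is presumably for mypy's redefinition rule: a variable keeps the type inferred at its first binding, and re-binding the name to a value of an incompatible type is an error, so a fresh name for the second test case sidesteps the complaint. A generic illustration (not the exact types in this test):

    x = "hello"  # mypy infers x: str
    x = 123      # error: Incompatible types in assignment (int, variable is str)

    y = 123      # a fresh name gets its own inferred type; no complaint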
