diff --git a/mypy.ini b/mypy.ini index 07b9c75c516..f58ab3d9919 100644 --- a/mypy.ini +++ b/mypy.ini @@ -47,10 +47,6 @@ ignore_errors = True ignore_errors = True -[mypy-torchvision.models.densenet.*] - -ignore_errors=True - [mypy-torchvision.models.detection.anchor_utils] ignore_errors = True diff --git a/torchvision/models/densenet.py b/torchvision/models/densenet.py index 3b42807cc96..569ed0ad3d0 100644 --- a/torchvision/models/densenet.py +++ b/torchvision/models/densenet.py @@ -1,7 +1,7 @@ import re from collections import OrderedDict from functools import partial -from typing import Any, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, cast import torch import torch.nn as nn @@ -116,7 +116,7 @@ def __init__( ) self.add_module("denselayer%d" % (i + 1), layer) - def forward(self, init_features: Tensor) -> Tensor: + def forward(self, init_features: Tensor) -> Tensor: # type: ignore[override] features = [init_features] for name, layer in self.items(): new_features = layer(features) @@ -227,7 +227,7 @@ def _load_state_dict(model: nn.Module, weights: WeightsEnum, progress: bool) -> r"^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$" ) - state_dict = weights.get_state_dict(progress=progress, check_hash=True) + state_dict = cast(Dict[str, Tensor], weights.get_state_dict(progress=progress, check_hash=True)) for key in list(state_dict.keys()): res = pattern.match(key) if res: