diff --git a/mypy.ini b/mypy.ini index 5cfd4056ba0..52ddce8ec51 100644 --- a/mypy.ini +++ b/mypy.ini @@ -36,10 +36,6 @@ ignore_errors = True ignore_errors = True -[mypy-torchvision.transforms.autoaugment.*] - -ignore_errors = True - [mypy-PIL.*] ignore_missing_imports = True diff --git a/torchvision/transforms/autoaugment.py b/torchvision/transforms/autoaugment.py index 97522945d2e..3b6c927a4eb 100644 --- a/torchvision/transforms/autoaugment.py +++ b/torchvision/transforms/autoaugment.py @@ -3,7 +3,7 @@ from enum import Enum from torch import Tensor -from typing import List, Tuple, Optional +from typing import List, Tuple, Optional, Dict from . import functional as F, InterpolationMode @@ -19,7 +19,9 @@ class AutoAugmentPolicy(Enum): SVHN = "svhn" -def _get_transforms(policy: AutoAugmentPolicy): +def _get_transforms( # type: ignore[return] + policy: AutoAugmentPolicy +) -> List[Tuple[Tuple[str, float, Optional[int]], Tuple[str, float, Optional[int]]]]: if policy == AutoAugmentPolicy.IMAGENET: return [ (("Posterize", 0.4, 8), ("Rotate", 0.6, 9)), @@ -106,7 +108,7 @@ def _get_transforms(policy: AutoAugmentPolicy): ] -def _get_magnitudes(): +def _get_magnitudes() -> Dict[str, Tuple[Optional[Tensor], Optional[bool]]]: _BINS = 10 return { # name: (magnitudes, signed) @@ -144,8 +146,12 @@ class AutoAugment(torch.nn.Module): image. If given a number, the value is used for all bands respectively. """ - def __init__(self, policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET, - interpolation: InterpolationMode = InterpolationMode.NEAREST, fill: Optional[List[float]] = None): + def __init__( + self, + policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET, + interpolation: InterpolationMode = InterpolationMode.NEAREST, + fill: Optional[List[float]] = None + ) -> None: super().__init__() self.policy = policy self.interpolation = interpolation @@ -163,7 +169,7 @@ def get_params(transform_num: int) -> Tuple[int, Tensor, Tensor]: Returns: params required by the autoaugment transformation """ - policy_id = torch.randint(transform_num, (1,)).item() + policy_id = int(torch.randint(transform_num, (1,)).item()) probs = torch.rand((2,)) signs = torch.randint(2, (2,)) @@ -172,7 +178,7 @@ def get_params(transform_num: int) -> Tuple[int, Tensor, Tensor]: def _get_op_meta(self, name: str) -> Tuple[Optional[Tensor], Optional[bool]]: return self._op_meta[name] - def forward(self, img: Tensor): + def forward(self, img: Tensor) -> Tensor: """ img (PIL Image or Tensor): Image to be transformed. @@ -233,5 +239,5 @@ def forward(self, img: Tensor): return img - def __repr__(self): + def __repr__(self) -> str: return self.__class__.__name__ + '(policy={}, fill={})'.format(self.policy, self.fill)