From 035edb84528e50998bd5efb0c2df8f0333f59127 Mon Sep 17 00:00:00 2001 From: frgfm Date: Thu, 29 Jul 2021 18:32:44 +0200 Subject: [PATCH 1/5] style: Added typing annotations --- torchvision/transforms/autoaugment.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/torchvision/transforms/autoaugment.py b/torchvision/transforms/autoaugment.py index 97522945d2e..ef484944b44 100644 --- a/torchvision/transforms/autoaugment.py +++ b/torchvision/transforms/autoaugment.py @@ -3,7 +3,7 @@ from enum import Enum from torch import Tensor -from typing import List, Tuple, Optional +from typing import List, Tuple, Optional, Dict, Any from . import functional as F, InterpolationMode @@ -19,7 +19,9 @@ class AutoAugmentPolicy(Enum): SVHN = "svhn" -def _get_transforms(policy: AutoAugmentPolicy): +def _get_transforms( + policy: AutoAugmentPolicy +) -> List[Tuple[Tuple[str, float, Optional[int]], Tuple[str, float, Optional[int]]]]: if policy == AutoAugmentPolicy.IMAGENET: return [ (("Posterize", 0.4, 8), ("Rotate", 0.6, 9)), @@ -106,7 +108,7 @@ def _get_transforms(policy: AutoAugmentPolicy): ] -def _get_magnitudes(): +def _get_magnitudes() -> Dict[str, Tuple[Optional[Tensor], Optional[bool]]]: _BINS = 10 return { # name: (magnitudes, signed) @@ -144,8 +146,12 @@ class AutoAugment(torch.nn.Module): image. If given a number, the value is used for all bands respectively. """ - def __init__(self, policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET, - interpolation: InterpolationMode = InterpolationMode.NEAREST, fill: Optional[List[float]] = None): + def __init__( + self, + policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET, + interpolation: InterpolationMode = InterpolationMode.NEAREST, + fill: Optional[List[float]] = None + ) -> None: super().__init__() self.policy = policy self.interpolation = interpolation @@ -172,7 +178,7 @@ def get_params(transform_num: int) -> Tuple[int, Tensor, Tensor]: def _get_op_meta(self, name: str) -> Tuple[Optional[Tensor], Optional[bool]]: return self._op_meta[name] - def forward(self, img: Tensor): + def forward(self, img: Tensor) -> Tensor: """ img (PIL Image or Tensor): Image to be transformed. @@ -233,5 +239,5 @@ def forward(self, img: Tensor): return img - def __repr__(self): + def __repr__(self) -> str: return self.__class__.__name__ + '(policy={}, fill={})'.format(self.policy, self.fill) From 34af75ef7b4726434abd18f4667816cef6156e5b Mon Sep 17 00:00:00 2001 From: frgfm Date: Sat, 31 Jul 2021 15:53:30 +0200 Subject: [PATCH 2/5] style: Fixed typing --- torchvision/transforms/autoaugment.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchvision/transforms/autoaugment.py b/torchvision/transforms/autoaugment.py index ef484944b44..9b125c8a575 100644 --- a/torchvision/transforms/autoaugment.py +++ b/torchvision/transforms/autoaugment.py @@ -3,7 +3,7 @@ from enum import Enum from torch import Tensor -from typing import List, Tuple, Optional, Dict, Any +from typing import List, Tuple, Optional, Dict, Any, cast from . import functional as F, InterpolationMode @@ -19,7 +19,7 @@ class AutoAugmentPolicy(Enum): SVHN = "svhn" -def _get_transforms( +def _get_transforms( # type: ignore[return] policy: AutoAugmentPolicy ) -> List[Tuple[Tuple[str, float, Optional[int]], Tuple[str, float, Optional[int]]]]: if policy == AutoAugmentPolicy.IMAGENET: @@ -169,7 +169,7 @@ def get_params(transform_num: int) -> Tuple[int, Tensor, Tensor]: Returns: params required by the autoaugment transformation """ - policy_id = torch.randint(transform_num, (1,)).item() + policy_id = cast(int, torch.randint(transform_num, (1,)).item()) probs = torch.rand((2,)) signs = torch.randint(2, (2,)) From 1f0dcd6127b83319578984fd33284da111b84ae6 Mon Sep 17 00:00:00 2001 From: frgfm Date: Mon, 2 Aug 2021 13:09:56 +0200 Subject: [PATCH 3/5] style: Fixed typing --- torchvision/transforms/autoaugment.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/transforms/autoaugment.py b/torchvision/transforms/autoaugment.py index 9b125c8a575..a0b0aa65c1a 100644 --- a/torchvision/transforms/autoaugment.py +++ b/torchvision/transforms/autoaugment.py @@ -3,7 +3,7 @@ from enum import Enum from torch import Tensor -from typing import List, Tuple, Optional, Dict, Any, cast +from typing import List, Tuple, Optional, Dict, Any from . import functional as F, InterpolationMode @@ -169,7 +169,7 @@ def get_params(transform_num: int) -> Tuple[int, Tensor, Tensor]: Returns: params required by the autoaugment transformation """ - policy_id = cast(int, torch.randint(transform_num, (1,)).item()) + policy_id = int(torch.randint(transform_num, (1,)).item()) probs = torch.rand((2,)) signs = torch.randint(2, (2,)) From 5991ef78ae3fb66fec38d653ff5d874c0c943b5f Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 16 Aug 2021 12:50:48 +0100 Subject: [PATCH 4/5] Remove unnecessary any. --- torchvision/transforms/autoaugment.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/transforms/autoaugment.py b/torchvision/transforms/autoaugment.py index a0b0aa65c1a..3b6c927a4eb 100644 --- a/torchvision/transforms/autoaugment.py +++ b/torchvision/transforms/autoaugment.py @@ -3,7 +3,7 @@ from enum import Enum from torch import Tensor -from typing import List, Tuple, Optional, Dict, Any +from typing import List, Tuple, Optional, Dict from . import functional as F, InterpolationMode From 8d065d6ab52e5609f221361572a0b496bbb91067 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 16 Aug 2021 12:59:22 +0100 Subject: [PATCH 5/5] Update mypy.ini --- mypy.ini | 4 ---- 1 file changed, 4 deletions(-) diff --git a/mypy.ini b/mypy.ini index 5cfd4056ba0..52ddce8ec51 100644 --- a/mypy.ini +++ b/mypy.ini @@ -36,10 +36,6 @@ ignore_errors = True ignore_errors = True -[mypy-torchvision.transforms.autoaugment.*] - -ignore_errors = True - [mypy-PIL.*] ignore_missing_imports = True