From f54049fcb49339a5ae092de21df7b8b3cbd7fe7a Mon Sep 17 00:00:00 2001 From: frgfm Date: Wed, 21 Oct 2020 17:41:45 +0200 Subject: [PATCH 1/8] style: Added annotation typing for resnet --- torchvision/models/resnet.py | 79 +++++++++++++++++++++++++----------- 1 file changed, 55 insertions(+), 24 deletions(-) diff --git a/torchvision/models/resnet.py b/torchvision/models/resnet.py index 797f459f5cb..3c468702218 100644 --- a/torchvision/models/resnet.py +++ b/torchvision/models/resnet.py @@ -1,6 +1,9 @@ import torch +from torch import Tensor import torch.nn as nn from .utils import load_state_dict_from_url +from typing import Callable, Any +from torch.jit.annotations import List, Optional __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', @@ -21,13 +24,13 @@ } -def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): +def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d: """3x3 convolution with padding""" return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=dilation, groups=groups, bias=False, dilation=dilation) -def conv1x1(in_planes, out_planes, stride=1): +def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: """1x1 convolution""" return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) @@ -35,8 +38,17 @@ def conv1x1(in_planes, out_planes, stride=1): class BasicBlock(nn.Module): expansion = 1 - def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, - base_width=64, dilation=1, norm_layer=None): + def __init__( + self, + inplanes: int, + planes: int, + stride: int = 1, + downsample: Optional[nn.Module] = None, + groups: int = 1, + base_width: int = 64, + dilation: int = 1, + norm_layer: Optional[Callable[..., nn.Module]] = None + ): super(BasicBlock, self).__init__() if norm_layer is None: norm_layer = nn.BatchNorm2d @@ -53,7 +65,7 @@ def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, self.downsample = downsample self.stride = stride - def forward(self, x): + def forward(self, x: Tensor) -> Tensor: identity = x out = self.conv1(x) @@ -81,8 +93,17 @@ class Bottleneck(nn.Module): expansion = 4 - def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, - base_width=64, dilation=1, norm_layer=None): + def __init__( + self, + inplanes: int, + planes: int, + stride: int = 1, + downsample: Optional[nn.Module] = None, + groups: int = 1, + base_width: int = 64, + dilation: int = 1, + norm_layer: Optional[Callable[..., nn.Module]] = None + ): super(Bottleneck, self).__init__() if norm_layer is None: norm_layer = nn.BatchNorm2d @@ -98,7 +119,7 @@ def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, self.downsample = downsample self.stride = stride - def forward(self, x): + def forward(self, x: Tensor) -> Tensor: identity = x out = self.conv1(x) @@ -123,9 +144,17 @@ def forward(self, x): class ResNet(nn.Module): - def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, - groups=1, width_per_group=64, replace_stride_with_dilation=None, - norm_layer=None): + def __init__( + self, + block: Callable[..., nn.Module], + layers: List[int], + num_classes: int = 1000, + zero_init_residual: bool = False, + groups: int = 1, + width_per_group: int = 64, + replace_stride_with_dilation: Optional[List[bool]] = None, + norm_layer: Optional[Callable[..., nn.Module]] = None + ): super(ResNet, self).__init__() if norm_layer is None: norm_layer = nn.BatchNorm2d @@ -174,7 +203,8 @@ def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, elif isinstance(m, BasicBlock): nn.init.constant_(m.bn2.weight, 0) - def _make_layer(self, block, planes, blocks, stride=1, dilate=False): + def _make_layer(self, block: Callable[..., nn.Module], planes: int, blocks: int, + stride: int = 1, dilate: bool = False) -> nn.Sequential: norm_layer = self._norm_layer downsample = None previous_dilation = self.dilation @@ -198,7 +228,7 @@ def _make_layer(self, block, planes, blocks, stride=1, dilate=False): return nn.Sequential(*layers) - def _forward_impl(self, x): + def _forward_impl(self, x: Tensor) -> Tensor: # See note [TorchScript super()] x = self.conv1(x) x = self.bn1(x) @@ -216,11 +246,12 @@ def _forward_impl(self, x): return x - def forward(self, x): + def forward(self, x: Tensor) -> Tensor: return self._forward_impl(x) -def _resnet(arch, block, layers, pretrained, progress, **kwargs): +def _resnet(arch: str, block: Callable[..., nn.Module], layers: List[int], pretrained: bool, progress: bool, + **kwargs: Any) -> ResNet: model = ResNet(block, layers, **kwargs) if pretrained: state_dict = load_state_dict_from_url(model_urls[arch], @@ -229,7 +260,7 @@ def _resnet(arch, block, layers, pretrained, progress, **kwargs): return model -def resnet18(pretrained=False, progress=True, **kwargs): +def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: r"""ResNet-18 model from `"Deep Residual Learning for Image Recognition" `_ @@ -241,7 +272,7 @@ def resnet18(pretrained=False, progress=True, **kwargs): **kwargs) -def resnet34(pretrained=False, progress=True, **kwargs): +def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: r"""ResNet-34 model from `"Deep Residual Learning for Image Recognition" `_ @@ -253,7 +284,7 @@ def resnet34(pretrained=False, progress=True, **kwargs): **kwargs) -def resnet50(pretrained=False, progress=True, **kwargs): +def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: r"""ResNet-50 model from `"Deep Residual Learning for Image Recognition" `_ @@ -265,7 +296,7 @@ def resnet50(pretrained=False, progress=True, **kwargs): **kwargs) -def resnet101(pretrained=False, progress=True, **kwargs): +def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: r"""ResNet-101 model from `"Deep Residual Learning for Image Recognition" `_ @@ -277,7 +308,7 @@ def resnet101(pretrained=False, progress=True, **kwargs): **kwargs) -def resnet152(pretrained=False, progress=True, **kwargs): +def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: r"""ResNet-152 model from `"Deep Residual Learning for Image Recognition" `_ @@ -289,7 +320,7 @@ def resnet152(pretrained=False, progress=True, **kwargs): **kwargs) -def resnext50_32x4d(pretrained=False, progress=True, **kwargs): +def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: r"""ResNeXt-50 32x4d model from `"Aggregated Residual Transformation for Deep Neural Networks" `_ @@ -303,7 +334,7 @@ def resnext50_32x4d(pretrained=False, progress=True, **kwargs): pretrained, progress, **kwargs) -def resnext101_32x8d(pretrained=False, progress=True, **kwargs): +def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: r"""ResNeXt-101 32x8d model from `"Aggregated Residual Transformation for Deep Neural Networks" `_ @@ -317,7 +348,7 @@ def resnext101_32x8d(pretrained=False, progress=True, **kwargs): pretrained, progress, **kwargs) -def wide_resnet50_2(pretrained=False, progress=True, **kwargs): +def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: r"""Wide ResNet-50-2 model from `"Wide Residual Networks" `_ @@ -335,7 +366,7 @@ def wide_resnet50_2(pretrained=False, progress=True, **kwargs): pretrained, progress, **kwargs) -def wide_resnet101_2(pretrained=False, progress=True, **kwargs): +def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: r"""Wide ResNet-101-2 model from `"Wide Residual Networks" `_ From b83e5513cc2244afe8c0c6e0dad5a3efec2121ab Mon Sep 17 00:00:00 2001 From: frgfm Date: Wed, 21 Oct 2020 18:16:39 +0200 Subject: [PATCH 2/8] fix: Fixed annotation to pass classes --- torchvision/models/resnet.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/torchvision/models/resnet.py b/torchvision/models/resnet.py index 3c468702218..42223eaf304 100644 --- a/torchvision/models/resnet.py +++ b/torchvision/models/resnet.py @@ -2,7 +2,7 @@ from torch import Tensor import torch.nn as nn from .utils import load_state_dict_from_url -from typing import Callable, Any +from typing import Type, Any from torch.jit.annotations import List, Optional @@ -47,7 +47,7 @@ def __init__( groups: int = 1, base_width: int = 64, dilation: int = 1, - norm_layer: Optional[Callable[..., nn.Module]] = None + norm_layer: Optional[Type[nn.Module]] = None ): super(BasicBlock, self).__init__() if norm_layer is None: @@ -102,7 +102,7 @@ def __init__( groups: int = 1, base_width: int = 64, dilation: int = 1, - norm_layer: Optional[Callable[..., nn.Module]] = None + norm_layer: Optional[Type[nn.Module]] = None ): super(Bottleneck, self).__init__() if norm_layer is None: @@ -146,14 +146,14 @@ class ResNet(nn.Module): def __init__( self, - block: Callable[..., nn.Module], + block: Type[nn.Module], layers: List[int], num_classes: int = 1000, zero_init_residual: bool = False, groups: int = 1, width_per_group: int = 64, replace_stride_with_dilation: Optional[List[bool]] = None, - norm_layer: Optional[Callable[..., nn.Module]] = None + norm_layer: Optional[Type[nn.Module]] = None ): super(ResNet, self).__init__() if norm_layer is None: @@ -203,7 +203,7 @@ def __init__( elif isinstance(m, BasicBlock): nn.init.constant_(m.bn2.weight, 0) - def _make_layer(self, block: Callable[..., nn.Module], planes: int, blocks: int, + def _make_layer(self, block: Type[nn.Module], planes: int, blocks: int, stride: int = 1, dilate: bool = False) -> nn.Sequential: norm_layer = self._norm_layer downsample = None @@ -250,7 +250,7 @@ def forward(self, x: Tensor) -> Tensor: return self._forward_impl(x) -def _resnet(arch: str, block: Callable[..., nn.Module], layers: List[int], pretrained: bool, progress: bool, +def _resnet(arch: str, block: Type[nn.Module], layers: List[int], pretrained: bool, progress: bool, **kwargs: Any) -> ResNet: model = ResNet(block, layers, **kwargs) if pretrained: From a12cfa451e785e07115e9e4a999ca50fc5e9d4f4 Mon Sep 17 00:00:00 2001 From: frgfm Date: Wed, 21 Oct 2020 23:05:36 +0200 Subject: [PATCH 3/8] fix: Fixed annotation typing --- torchvision/models/resnet.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/torchvision/models/resnet.py b/torchvision/models/resnet.py index 42223eaf304..50fb2207358 100644 --- a/torchvision/models/resnet.py +++ b/torchvision/models/resnet.py @@ -35,8 +35,12 @@ def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) -class BasicBlock(nn.Module): - expansion = 1 +class ResBlock(nn.Module): + expansion: int = 1 + + +class BasicBlock(ResBlock): + expansion: int = 1 def __init__( self, @@ -84,14 +88,14 @@ def forward(self, x: Tensor) -> Tensor: return out -class Bottleneck(nn.Module): +class Bottleneck(ResBlock): # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) # while original implementation places the stride at the first 1x1 convolution(self.conv1) # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. # This variant is also known as ResNet V1.5 and improves accuracy according to # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. - expansion = 4 + expansion: int = 4 def __init__( self, @@ -146,7 +150,7 @@ class ResNet(nn.Module): def __init__( self, - block: Type[nn.Module], + block: Type[ResBlock], layers: List[int], num_classes: int = 1000, zero_init_residual: bool = False, @@ -203,7 +207,7 @@ def __init__( elif isinstance(m, BasicBlock): nn.init.constant_(m.bn2.weight, 0) - def _make_layer(self, block: Type[nn.Module], planes: int, blocks: int, + def _make_layer(self, block: Type[ResBlock], planes: int, blocks: int, stride: int = 1, dilate: bool = False) -> nn.Sequential: norm_layer = self._norm_layer downsample = None @@ -250,7 +254,7 @@ def forward(self, x: Tensor) -> Tensor: return self._forward_impl(x) -def _resnet(arch: str, block: Type[nn.Module], layers: List[int], pretrained: bool, progress: bool, +def _resnet(arch: str, block: Type[ResBlock], layers: List[int], pretrained: bool, progress: bool, **kwargs: Any) -> ResNet: model = ResNet(block, layers, **kwargs) if pretrained: From 3a21e0262042e1dd0441483a4a4da66d169d2b3b Mon Sep 17 00:00:00 2001 From: frgfm Date: Wed, 21 Oct 2020 23:53:47 +0200 Subject: [PATCH 4/8] fix: Fixed annotation typing --- torchvision/models/resnet.py | 22 +++++++++------------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/torchvision/models/resnet.py b/torchvision/models/resnet.py index 50fb2207358..7deebd85a38 100644 --- a/torchvision/models/resnet.py +++ b/torchvision/models/resnet.py @@ -2,7 +2,7 @@ from torch import Tensor import torch.nn as nn from .utils import load_state_dict_from_url -from typing import Type, Any +from typing import Type, Any, Callable, Union from torch.jit.annotations import List, Optional @@ -35,11 +35,7 @@ def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) -class ResBlock(nn.Module): - expansion: int = 1 - - -class BasicBlock(ResBlock): +class BasicBlock(nn.Module): expansion: int = 1 def __init__( @@ -51,7 +47,7 @@ def __init__( groups: int = 1, base_width: int = 64, dilation: int = 1, - norm_layer: Optional[Type[nn.Module]] = None + norm_layer: Optional[Callable[..., nn.Module]] = None ): super(BasicBlock, self).__init__() if norm_layer is None: @@ -88,7 +84,7 @@ def forward(self, x: Tensor) -> Tensor: return out -class Bottleneck(ResBlock): +class Bottleneck(nn.Module): # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) # while original implementation places the stride at the first 1x1 convolution(self.conv1) # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. @@ -106,7 +102,7 @@ def __init__( groups: int = 1, base_width: int = 64, dilation: int = 1, - norm_layer: Optional[Type[nn.Module]] = None + norm_layer: Optional[Callable[..., nn.Module]] = None ): super(Bottleneck, self).__init__() if norm_layer is None: @@ -150,14 +146,14 @@ class ResNet(nn.Module): def __init__( self, - block: Type[ResBlock], + block: Type[Union[BasicBlock, Bottleneck]], layers: List[int], num_classes: int = 1000, zero_init_residual: bool = False, groups: int = 1, width_per_group: int = 64, replace_stride_with_dilation: Optional[List[bool]] = None, - norm_layer: Optional[Type[nn.Module]] = None + norm_layer: Optional[Callable[..., nn.Module]] = None ): super(ResNet, self).__init__() if norm_layer is None: @@ -207,7 +203,7 @@ def __init__( elif isinstance(m, BasicBlock): nn.init.constant_(m.bn2.weight, 0) - def _make_layer(self, block: Type[ResBlock], planes: int, blocks: int, + def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int, stride: int = 1, dilate: bool = False) -> nn.Sequential: norm_layer = self._norm_layer downsample = None @@ -254,7 +250,7 @@ def forward(self, x: Tensor) -> Tensor: return self._forward_impl(x) -def _resnet(arch: str, block: Type[ResBlock], layers: List[int], pretrained: bool, progress: bool, +def _resnet(arch: str, block: Type[Union[BasicBlock, Bottleneck]], layers: List[int], pretrained: bool, progress: bool, **kwargs: Any) -> ResNet: model = ResNet(block, layers, **kwargs) if pretrained: From 818cf6ec8f4f185d7e9ffa5d34ae39dc763beb30 Mon Sep 17 00:00:00 2001 From: frgfm Date: Wed, 21 Oct 2020 23:58:16 +0200 Subject: [PATCH 5/8] fix: Fixed annotation typing for resnet --- torchvision/models/resnet.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/models/resnet.py b/torchvision/models/resnet.py index 7deebd85a38..ce49240751d 100644 --- a/torchvision/models/resnet.py +++ b/torchvision/models/resnet.py @@ -199,9 +199,9 @@ def __init__( if zero_init_residual: for m in self.modules(): if isinstance(m, Bottleneck): - nn.init.constant_(m.bn3.weight, 0) + nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type] elif isinstance(m, BasicBlock): - nn.init.constant_(m.bn2.weight, 0) + nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type] def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int, stride: int = 1, dilate: bool = False) -> nn.Sequential: From 2fb21bdc01ddc0e0f3c84ccd11e446b97f85aef0 Mon Sep 17 00:00:00 2001 From: frgfm Date: Thu, 22 Oct 2020 12:05:16 +0200 Subject: [PATCH 6/8] refactor: Removed un-necessary import --- torchvision/models/resnet.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torchvision/models/resnet.py b/torchvision/models/resnet.py index ce49240751d..feb162a0adc 100644 --- a/torchvision/models/resnet.py +++ b/torchvision/models/resnet.py @@ -2,8 +2,7 @@ from torch import Tensor import torch.nn as nn from .utils import load_state_dict_from_url -from typing import Type, Any, Callable, Union -from torch.jit.annotations import List, Optional +from typing import Type, Any, Callable, Union, List, Optional __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', From 3277c47949c2a78fc9e6edb41ed72841a61398ef Mon Sep 17 00:00:00 2001 From: frgfm Date: Thu, 22 Oct 2020 21:18:24 +0200 Subject: [PATCH 7/8] fix: Fixed constructor typing --- torchvision/models/resnet.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchvision/models/resnet.py b/torchvision/models/resnet.py index feb162a0adc..0ff8332ade9 100644 --- a/torchvision/models/resnet.py +++ b/torchvision/models/resnet.py @@ -47,7 +47,7 @@ def __init__( base_width: int = 64, dilation: int = 1, norm_layer: Optional[Callable[..., nn.Module]] = None - ): + ) -> None: super(BasicBlock, self).__init__() if norm_layer is None: norm_layer = nn.BatchNorm2d @@ -102,7 +102,7 @@ def __init__( base_width: int = 64, dilation: int = 1, norm_layer: Optional[Callable[..., nn.Module]] = None - ): + ) -> None: super(Bottleneck, self).__init__() if norm_layer is None: norm_layer = nn.BatchNorm2d @@ -153,7 +153,7 @@ def __init__( width_per_group: int = 64, replace_stride_with_dilation: Optional[List[bool]] = None, norm_layer: Optional[Callable[..., nn.Module]] = None - ): + ) -> None: super(ResNet, self).__init__() if norm_layer is None: norm_layer = nn.BatchNorm2d From 4584660db3ddb44866661b8315401925453fda0a Mon Sep 17 00:00:00 2001 From: frgfm Date: Thu, 22 Oct 2020 21:19:29 +0200 Subject: [PATCH 8/8] style: Added black formatting on _resnet --- torchvision/models/resnet.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/torchvision/models/resnet.py b/torchvision/models/resnet.py index 0ff8332ade9..f3f86c25cd2 100644 --- a/torchvision/models/resnet.py +++ b/torchvision/models/resnet.py @@ -249,8 +249,14 @@ def forward(self, x: Tensor) -> Tensor: return self._forward_impl(x) -def _resnet(arch: str, block: Type[Union[BasicBlock, Bottleneck]], layers: List[int], pretrained: bool, progress: bool, - **kwargs: Any) -> ResNet: +def _resnet( + arch: str, + block: Type[Union[BasicBlock, Bottleneck]], + layers: List[int], + pretrained: bool, + progress: bool, + **kwargs: Any +) -> ResNet: model = ResNet(block, layers, **kwargs) if pretrained: state_dict = load_state_dict_from_url(model_urls[arch],