Skip to content

Commit f9e31a6

Browse files
authored
Added annotation typing to vgg (#2861)
* style: Added annotation typing for vgg * fix: Fixed annotation typing * refactor: Removed un-necessary import * fix: Added missing annotation for kwargs * fix: Fixed constructor typing * refactor: Refactored typing to minize changes * refactor: Refactored typing cast * fix: Fixed module list typing
1 parent d559ad8 commit f9e31a6

File tree

1 file changed

+22
-15
lines changed

1 file changed

+22
-15
lines changed

torchvision/models/vgg.py

+22-15
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import torch
22
import torch.nn as nn
33
from .utils import load_state_dict_from_url
4+
from typing import Union, List, Dict, Any, cast
45

56

67
__all__ = [
@@ -23,7 +24,12 @@
2324

2425
class VGG(nn.Module):
2526

26-
def __init__(self, features, num_classes=1000, init_weights=True):
27+
def __init__(
28+
self,
29+
features: nn.Module,
30+
num_classes: int = 1000,
31+
init_weights: bool = True
32+
) -> None:
2733
super(VGG, self).__init__()
2834
self.features = features
2935
self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
@@ -39,14 +45,14 @@ def __init__(self, features, num_classes=1000, init_weights=True):
3945
if init_weights:
4046
self._initialize_weights()
4147

42-
def forward(self, x):
48+
def forward(self, x: torch.Tensor) -> torch.Tensor:
4349
x = self.features(x)
4450
x = self.avgpool(x)
4551
x = torch.flatten(x, 1)
4652
x = self.classifier(x)
4753
return x
4854

49-
def _initialize_weights(self):
55+
def _initialize_weights(self) -> None:
5056
for m in self.modules():
5157
if isinstance(m, nn.Conv2d):
5258
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
@@ -60,13 +66,14 @@ def _initialize_weights(self):
6066
nn.init.constant_(m.bias, 0)
6167

6268

63-
def make_layers(cfg, batch_norm=False):
64-
layers = []
69+
def make_layers(cfg: List[Union[str, int]], batch_norm: bool = False) -> nn.Sequential:
70+
layers: List[nn.Module] = []
6571
in_channels = 3
6672
for v in cfg:
6773
if v == 'M':
6874
layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
6975
else:
76+
v = cast(int, v)
7077
conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
7178
if batch_norm:
7279
layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
@@ -76,15 +83,15 @@ def make_layers(cfg, batch_norm=False):
7683
return nn.Sequential(*layers)
7784

7885

79-
cfgs = {
86+
cfgs: Dict[str, List[Union[str, int]]] = {
8087
'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
8188
'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
8289
'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
8390
'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
8491
}
8592

8693

87-
def _vgg(arch, cfg, batch_norm, pretrained, progress, **kwargs):
94+
def _vgg(arch: str, cfg: str, batch_norm: bool, pretrained: bool, progress: bool, **kwargs: Any) -> VGG:
8895
if pretrained:
8996
kwargs['init_weights'] = False
9097
model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs)
@@ -95,7 +102,7 @@ def _vgg(arch, cfg, batch_norm, pretrained, progress, **kwargs):
95102
return model
96103

97104

98-
def vgg11(pretrained=False, progress=True, **kwargs):
105+
def vgg11(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
99106
r"""VGG 11-layer model (configuration "A") from
100107
`"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_
101108
@@ -106,7 +113,7 @@ def vgg11(pretrained=False, progress=True, **kwargs):
106113
return _vgg('vgg11', 'A', False, pretrained, progress, **kwargs)
107114

108115

109-
def vgg11_bn(pretrained=False, progress=True, **kwargs):
116+
def vgg11_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
110117
r"""VGG 11-layer model (configuration "A") with batch normalization
111118
`"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_
112119
@@ -117,7 +124,7 @@ def vgg11_bn(pretrained=False, progress=True, **kwargs):
117124
return _vgg('vgg11_bn', 'A', True, pretrained, progress, **kwargs)
118125

119126

120-
def vgg13(pretrained=False, progress=True, **kwargs):
127+
def vgg13(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
121128
r"""VGG 13-layer model (configuration "B")
122129
`"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_
123130
@@ -128,7 +135,7 @@ def vgg13(pretrained=False, progress=True, **kwargs):
128135
return _vgg('vgg13', 'B', False, pretrained, progress, **kwargs)
129136

130137

131-
def vgg13_bn(pretrained=False, progress=True, **kwargs):
138+
def vgg13_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
132139
r"""VGG 13-layer model (configuration "B") with batch normalization
133140
`"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_
134141
@@ -139,7 +146,7 @@ def vgg13_bn(pretrained=False, progress=True, **kwargs):
139146
return _vgg('vgg13_bn', 'B', True, pretrained, progress, **kwargs)
140147

141148

142-
def vgg16(pretrained=False, progress=True, **kwargs):
149+
def vgg16(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
143150
r"""VGG 16-layer model (configuration "D")
144151
`"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_
145152
@@ -150,7 +157,7 @@ def vgg16(pretrained=False, progress=True, **kwargs):
150157
return _vgg('vgg16', 'D', False, pretrained, progress, **kwargs)
151158

152159

153-
def vgg16_bn(pretrained=False, progress=True, **kwargs):
160+
def vgg16_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
154161
r"""VGG 16-layer model (configuration "D") with batch normalization
155162
`"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_
156163
@@ -161,7 +168,7 @@ def vgg16_bn(pretrained=False, progress=True, **kwargs):
161168
return _vgg('vgg16_bn', 'D', True, pretrained, progress, **kwargs)
162169

163170

164-
def vgg19(pretrained=False, progress=True, **kwargs):
171+
def vgg19(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
165172
r"""VGG 19-layer model (configuration "E")
166173
`"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_
167174
@@ -172,7 +179,7 @@ def vgg19(pretrained=False, progress=True, **kwargs):
172179
return _vgg('vgg19', 'E', False, pretrained, progress, **kwargs)
173180

174181

175-
def vgg19_bn(pretrained=False, progress=True, **kwargs):
182+
def vgg19_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
176183
r"""VGG 19-layer model (configuration 'E') with batch normalization
177184
`"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_
178185

0 commit comments

Comments
 (0)