1
1
import torch
2
2
import torch .nn as nn
3
3
from .utils import load_state_dict_from_url
4
+ from typing import Union , List , Dict , Any , cast
4
5
5
6
6
7
__all__ = [
23
24
24
25
class VGG (nn .Module ):
25
26
26
- def __init__ (self , features , num_classes = 1000 , init_weights = True ):
27
+ def __init__ (
28
+ self ,
29
+ features : nn .Module ,
30
+ num_classes : int = 1000 ,
31
+ init_weights : bool = True
32
+ ) -> None :
27
33
super (VGG , self ).__init__ ()
28
34
self .features = features
29
35
self .avgpool = nn .AdaptiveAvgPool2d ((7 , 7 ))
@@ -39,14 +45,14 @@ def __init__(self, features, num_classes=1000, init_weights=True):
39
45
if init_weights :
40
46
self ._initialize_weights ()
41
47
42
- def forward (self , x ) :
48
+ def forward (self , x : torch . Tensor ) -> torch . Tensor :
43
49
x = self .features (x )
44
50
x = self .avgpool (x )
45
51
x = torch .flatten (x , 1 )
46
52
x = self .classifier (x )
47
53
return x
48
54
49
- def _initialize_weights (self ):
55
+ def _initialize_weights (self ) -> None :
50
56
for m in self .modules ():
51
57
if isinstance (m , nn .Conv2d ):
52
58
nn .init .kaiming_normal_ (m .weight , mode = 'fan_out' , nonlinearity = 'relu' )
@@ -60,13 +66,14 @@ def _initialize_weights(self):
60
66
nn .init .constant_ (m .bias , 0 )
61
67
62
68
63
- def make_layers (cfg , batch_norm = False ):
64
- layers = []
69
+ def make_layers (cfg : List [ Union [ str , int ]], batch_norm : bool = False ) -> nn . Sequential :
70
+ layers : List [ nn . Module ] = []
65
71
in_channels = 3
66
72
for v in cfg :
67
73
if v == 'M' :
68
74
layers += [nn .MaxPool2d (kernel_size = 2 , stride = 2 )]
69
75
else :
76
+ v = cast (int , v )
70
77
conv2d = nn .Conv2d (in_channels , v , kernel_size = 3 , padding = 1 )
71
78
if batch_norm :
72
79
layers += [conv2d , nn .BatchNorm2d (v ), nn .ReLU (inplace = True )]
@@ -76,15 +83,15 @@ def make_layers(cfg, batch_norm=False):
76
83
return nn .Sequential (* layers )
77
84
78
85
79
- cfgs = {
86
+ cfgs : Dict [ str , List [ Union [ str , int ]]] = {
80
87
'A' : [64 , 'M' , 128 , 'M' , 256 , 256 , 'M' , 512 , 512 , 'M' , 512 , 512 , 'M' ],
81
88
'B' : [64 , 64 , 'M' , 128 , 128 , 'M' , 256 , 256 , 'M' , 512 , 512 , 'M' , 512 , 512 , 'M' ],
82
89
'D' : [64 , 64 , 'M' , 128 , 128 , 'M' , 256 , 256 , 256 , 'M' , 512 , 512 , 512 , 'M' , 512 , 512 , 512 , 'M' ],
83
90
'E' : [64 , 64 , 'M' , 128 , 128 , 'M' , 256 , 256 , 256 , 256 , 'M' , 512 , 512 , 512 , 512 , 'M' , 512 , 512 , 512 , 512 , 'M' ],
84
91
}
85
92
86
93
87
- def _vgg (arch , cfg , batch_norm , pretrained , progress , ** kwargs ) :
94
+ def _vgg (arch : str , cfg : str , batch_norm : bool , pretrained : bool , progress : bool , ** kwargs : Any ) -> VGG :
88
95
if pretrained :
89
96
kwargs ['init_weights' ] = False
90
97
model = VGG (make_layers (cfgs [cfg ], batch_norm = batch_norm ), ** kwargs )
@@ -95,7 +102,7 @@ def _vgg(arch, cfg, batch_norm, pretrained, progress, **kwargs):
95
102
return model
96
103
97
104
98
- def vgg11 (pretrained = False , progress = True , ** kwargs ) :
105
+ def vgg11 (pretrained : bool = False , progress : bool = True , ** kwargs : Any ) -> VGG :
99
106
r"""VGG 11-layer model (configuration "A") from
100
107
`"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_
101
108
@@ -106,7 +113,7 @@ def vgg11(pretrained=False, progress=True, **kwargs):
106
113
return _vgg ('vgg11' , 'A' , False , pretrained , progress , ** kwargs )
107
114
108
115
109
- def vgg11_bn (pretrained = False , progress = True , ** kwargs ) :
116
+ def vgg11_bn (pretrained : bool = False , progress : bool = True , ** kwargs : Any ) -> VGG :
110
117
r"""VGG 11-layer model (configuration "A") with batch normalization
111
118
`"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_
112
119
@@ -117,7 +124,7 @@ def vgg11_bn(pretrained=False, progress=True, **kwargs):
117
124
return _vgg ('vgg11_bn' , 'A' , True , pretrained , progress , ** kwargs )
118
125
119
126
120
- def vgg13 (pretrained = False , progress = True , ** kwargs ) :
127
+ def vgg13 (pretrained : bool = False , progress : bool = True , ** kwargs : Any ) -> VGG :
121
128
r"""VGG 13-layer model (configuration "B")
122
129
`"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_
123
130
@@ -128,7 +135,7 @@ def vgg13(pretrained=False, progress=True, **kwargs):
128
135
return _vgg ('vgg13' , 'B' , False , pretrained , progress , ** kwargs )
129
136
130
137
131
- def vgg13_bn (pretrained = False , progress = True , ** kwargs ) :
138
+ def vgg13_bn (pretrained : bool = False , progress : bool = True , ** kwargs : Any ) -> VGG :
132
139
r"""VGG 13-layer model (configuration "B") with batch normalization
133
140
`"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_
134
141
@@ -139,7 +146,7 @@ def vgg13_bn(pretrained=False, progress=True, **kwargs):
139
146
return _vgg ('vgg13_bn' , 'B' , True , pretrained , progress , ** kwargs )
140
147
141
148
142
- def vgg16 (pretrained = False , progress = True , ** kwargs ) :
149
+ def vgg16 (pretrained : bool = False , progress : bool = True , ** kwargs : Any ) -> VGG :
143
150
r"""VGG 16-layer model (configuration "D")
144
151
`"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_
145
152
@@ -150,7 +157,7 @@ def vgg16(pretrained=False, progress=True, **kwargs):
150
157
return _vgg ('vgg16' , 'D' , False , pretrained , progress , ** kwargs )
151
158
152
159
153
- def vgg16_bn (pretrained = False , progress = True , ** kwargs ) :
160
+ def vgg16_bn (pretrained : bool = False , progress : bool = True , ** kwargs : Any ) -> VGG :
154
161
r"""VGG 16-layer model (configuration "D") with batch normalization
155
162
`"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_
156
163
@@ -161,7 +168,7 @@ def vgg16_bn(pretrained=False, progress=True, **kwargs):
161
168
return _vgg ('vgg16_bn' , 'D' , True , pretrained , progress , ** kwargs )
162
169
163
170
164
- def vgg19 (pretrained = False , progress = True , ** kwargs ) :
171
+ def vgg19 (pretrained : bool = False , progress : bool = True , ** kwargs : Any ) -> VGG :
165
172
r"""VGG 19-layer model (configuration "E")
166
173
`"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_
167
174
@@ -172,7 +179,7 @@ def vgg19(pretrained=False, progress=True, **kwargs):
172
179
return _vgg ('vgg19' , 'E' , False , pretrained , progress , ** kwargs )
173
180
174
181
175
- def vgg19_bn (pretrained = False , progress = True , ** kwargs ) :
182
+ def vgg19_bn (pretrained : bool = False , progress : bool = True , ** kwargs : Any ) -> VGG :
176
183
r"""VGG 19-layer model (configuration 'E') with batch normalization
177
184
`"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_
178
185
0 commit comments