@@ -1,4 +1,6 @@
+from torch import Tensor
 import torch.nn as nn
+from typing import Tuple, Optional, Callable, List, Type, Any, Union

 from ..._internally_replaced_utils import load_state_dict_from_url

@@ -13,12 +15,14 @@


 class Conv3DSimple(nn.Conv3d):
-    def __init__(self,
-                 in_planes,
-                 out_planes,
-                 midplanes=None,
-                 stride=1,
-                 padding=1):
+    def __init__(
+        self,
+        in_planes: int,
+        out_planes: int,
+        midplanes: Optional[int] = None,
+        stride: int = 1,
+        padding: int = 1
+    ) -> None:

         super(Conv3DSimple, self).__init__(
             in_channels=in_planes,
@@ -29,18 +33,20 @@ def __init__(self,
             bias=False)

     @staticmethod
-    def get_downsample_stride(stride):
+    def get_downsample_stride(stride: int) -> Tuple[int, int, int]:
         return stride, stride, stride


 class Conv2Plus1D(nn.Sequential):

-    def __init__(self,
-                 in_planes,
-                 out_planes,
-                 midplanes,
-                 stride=1,
-                 padding=1):
+    def __init__(
+        self,
+        in_planes: int,
+        out_planes: int,
+        midplanes: int,
+        stride: int = 1,
+        padding: int = 1
+    ) -> None:
         super(Conv2Plus1D, self).__init__(
             nn.Conv3d(in_planes, midplanes, kernel_size=(1, 3, 3),
                       stride=(1, stride, stride), padding=(0, padding, padding),
@@ -52,18 +58,20 @@ def __init__(self,
                       bias=False))

     @staticmethod
-    def get_downsample_stride(stride):
+    def get_downsample_stride(stride: int) -> Tuple[int, int, int]:
         return stride, stride, stride


 class Conv3DNoTemporal(nn.Conv3d):

-    def __init__(self,
-                 in_planes,
-                 out_planes,
-                 midplanes=None,
-                 stride=1,
-                 padding=1):
+    def __init__(
+        self,
+        in_planes: int,
+        out_planes: int,
+        midplanes: Optional[int] = None,
+        stride: int = 1,
+        padding: int = 1
+    ) -> None:

         super(Conv3DNoTemporal, self).__init__(
             in_channels=in_planes,
@@ -74,15 +82,22 @@ def __init__(self,
             bias=False)

     @staticmethod
-    def get_downsample_stride(stride):
+    def get_downsample_stride(stride: int) -> Tuple[int, int, int]:
         return 1, stride, stride


 class BasicBlock(nn.Module):

     expansion = 1

-    def __init__(self, inplanes, planes, conv_builder, stride=1, downsample=None):
+    def __init__(
+        self,
+        inplanes: int,
+        planes: int,
+        conv_builder: Callable[..., nn.Module],
+        stride: int = 1,
+        downsample: Optional[nn.Module] = None,
+    ) -> None:
         midplanes = (inplanes * planes * 3 * 3 * 3) // (inplanes * 3 * 3 + 3 * planes)

         super(BasicBlock, self).__init__()
@@ -99,7 +114,7 @@ def __init__(self, inplanes, planes, conv_builder, stride=1, downsample=None):
         self.downsample = downsample
         self.stride = stride

-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         residual = x

         out = self.conv1(x)
@@ -116,7 +131,14 @@ def forward(self, x):
 class Bottleneck(nn.Module):
     expansion = 4

-    def __init__(self, inplanes, planes, conv_builder, stride=1, downsample=None):
+    def __init__(
+        self,
+        inplanes: int,
+        planes: int,
+        conv_builder: Callable[..., nn.Module],
+        stride: int = 1,
+        downsample: Optional[nn.Module] = None,
+    ) -> None:

         super(Bottleneck, self).__init__()
         midplanes = (inplanes * planes * 3 * 3 * 3) // (inplanes * 3 * 3 + 3 * planes)
@@ -143,7 +165,7 @@ def __init__(self, inplanes, planes, conv_builder, stride=1, downsample=None):
         self.downsample = downsample
         self.stride = stride

-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         residual = x

         out = self.conv1(x)
@@ -162,7 +184,7 @@ def forward(self, x):
 class BasicStem(nn.Sequential):
     """The default conv-batchnorm-relu stem
     """
-    def __init__(self):
+    def __init__(self) -> None:
         super(BasicStem, self).__init__(
             nn.Conv3d(3, 64, kernel_size=(3, 7, 7), stride=(1, 2, 2),
                       padding=(1, 3, 3), bias=False),
@@ -173,7 +195,7 @@ def __init__(self):
 class R2Plus1dStem(nn.Sequential):
     """R(2+1)D stem is different than the default one as it uses separated 3D convolution
     """
-    def __init__(self):
+    def __init__(self) -> None:
         super(R2Plus1dStem, self).__init__(
             nn.Conv3d(3, 45, kernel_size=(1, 7, 7),
                       stride=(1, 2, 2), padding=(0, 3, 3),
@@ -189,16 +211,23 @@ def __init__(self):

 class VideoResNet(nn.Module):

-    def __init__(self, block, conv_makers, layers,
-                 stem, num_classes=400,
-                 zero_init_residual=False):
+    def __init__(
+        self,
+        block: Type[Union[BasicBlock, Bottleneck]],
+        conv_makers: List[Type[Union[Conv3DSimple, Conv3DNoTemporal, Conv2Plus1D]]],
+        layers: List[int],
+        stem: Callable[..., nn.Module],
+        num_classes: int = 400,
+        zero_init_residual: bool = False,
+    ) -> None:
         """Generic resnet video generator.

         Args:
-            block (nn.Module): resnet building block
-            conv_makers (list(functions)): generator function for each layer
+            block (Type[Union[BasicBlock, Bottleneck]]): resnet building block
+            conv_makers (List[Type[Union[Conv3DSimple, Conv3DNoTemporal, Conv2Plus1D]]]): generator
+                function for each layer
             layers (List[int]): number of blocks per layer
-            stem (nn.Module, optional): Resnet stem, if None, defaults to conv-bn-relu. Defaults to None.
+            stem (Callable[..., nn.Module]): module specifying the ResNet stem.
             num_classes (int, optional): Dimension of the final FC layer. Defaults to 400.
             zero_init_residual (bool, optional): Zero init bottleneck residual BN. Defaults to False.
         """
@@ -221,9 +250,9 @@ def __init__(self, block, conv_makers, layers,
         if zero_init_residual:
             for m in self.modules():
                 if isinstance(m, Bottleneck):
-                    nn.init.constant_(m.bn3.weight, 0)
+                    nn.init.constant_(m.bn3.weight, 0)  # type: ignore[union-attr, arg-type]

-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         x = self.stem(x)

         x = self.layer1(x)
@@ -238,7 +267,14 @@ def forward(self, x):

         return x

-    def _make_layer(self, block, conv_builder, planes, blocks, stride=1):
+    def _make_layer(
+        self,
+        block: Type[Union[BasicBlock, Bottleneck]],
+        conv_builder: Type[Union[Conv3DSimple, Conv3DNoTemporal, Conv2Plus1D]],
+        planes: int,
+        blocks: int,
+        stride: int = 1
+    ) -> nn.Sequential:
         downsample = None

         if stride != 1 or self.inplanes != planes * block.expansion:
@@ -257,7 +293,7 @@ def _make_layer(self, block, conv_builder, planes, blocks, stride=1):

         return nn.Sequential(*layers)

-    def _initialize_weights(self):
+    def _initialize_weights(self) -> None:
         for m in self.modules():
             if isinstance(m, nn.Conv3d):
                 nn.init.kaiming_normal_(m.weight, mode='fan_out',
@@ -272,7 +308,7 @@ def _initialize_weights(self):
                 nn.init.constant_(m.bias, 0)


-def _video_resnet(arch, pretrained=False, progress=True, **kwargs):
+def _video_resnet(arch: str, pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VideoResNet:
     model = VideoResNet(**kwargs)

     if pretrained:
@@ -282,7 +318,7 @@ def _video_resnet(arch, pretrained=False, progress=True, **kwargs):
     return model


-def r3d_18(pretrained=False, progress=True, **kwargs):
+def r3d_18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VideoResNet:
     """Construct 18 layer Resnet3D model as in
     https://arxiv.org/abs/1711.11248

@@ -302,7 +338,7 @@ def r3d_18(pretrained=False, progress=True, **kwargs):
                          stem=BasicStem, **kwargs)


-def mc3_18(pretrained=False, progress=True, **kwargs):
+def mc3_18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VideoResNet:
     """Constructor for 18 layer Mixed Convolution network as in
     https://arxiv.org/abs/1711.11248

@@ -316,12 +352,12 @@ def mc3_18(pretrained=False, progress=True, **kwargs):
     return _video_resnet('mc3_18',
                          pretrained, progress,
                          block=BasicBlock,
-                         conv_makers=[Conv3DSimple] + [Conv3DNoTemporal] * 3,
+                         conv_makers=[Conv3DSimple] + [Conv3DNoTemporal] * 3,  # type: ignore[list-item]
                          layers=[2, 2, 2, 2],
                          stem=BasicStem, **kwargs)


-def r2plus1d_18(pretrained=False, progress=True, **kwargs):
+def r2plus1d_18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VideoResNet:
     """Constructor for the 18 layer deep R(2+1)D network as in
     https://arxiv.org/abs/1711.11248
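The hunks above only add annotations; runtime behaviour is unchanged. As a quick illustration (not part of the commit), here is a minimal usage sketch of how the typed pieces line up for a static checker: the model builders are declared to return `VideoResNet`, and `forward` is `Tensor -> Tensor`. It assumes the `torchvision.models.video` package this file lives in and the usual `(N, C, T, H, W)` clip layout.

```python
import torch

from torchvision.models.video import r3d_18
from torchvision.models.video.resnet import VideoResNet

# The builder is annotated to return VideoResNet, so this assignment
# type-checks without a cast.
model: VideoResNet = r3d_18(pretrained=False, num_classes=400)
model.eval()

# forward(self, x: Tensor) -> Tensor: a batch of clips in, class scores out.
clip = torch.randn(1, 3, 16, 112, 112)  # (N, C, T, H, W)
with torch.no_grad():
    scores = model(clip)
print(scores.shape)  # torch.Size([1, 400])
```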
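A note on the `# type: ignore[list-item]` added in `mc3_18`: `conv_makers` is declared as `List[Type[Union[Conv3DSimple, Conv3DNoTemporal, Conv2Plus1D]]]`, but mypy appears to check the right-hand operand of `[Conv3DSimple] + [Conv3DNoTemporal] * 3` against the element type it infers for the left operand (`Type[Conv3DSimple]`) rather than against the declared `Union`, so the heterogeneous concatenation is flagged. A hypothetical way to satisfy the checker without the ignore (not what the commit does) is to give the list its widened element type before concatenating:

```python
# Hypothetical alternative, assuming the same classes as above: declare the
# element type up front so the checker widens to the Union immediately.
makers: List[Type[Union[Conv3DSimple, Conv3DNoTemporal, Conv2Plus1D]]] = [Conv3DSimple]
makers += [Conv3DNoTemporal] * 3
```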
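Unrelated to the typing changes, the `midplanes` expression that appears as a context line in `BasicBlock` and `Bottleneck` is the R(2+1)D parameter-matching rule from https://arxiv.org/abs/1711.11248: the width of the intermediate layer in `Conv2Plus1D` is chosen so that the 1x3x3 spatial convolution plus the 3x1x1 temporal convolution use about the same number of weights as the full 3x3x3 convolution they replace. A small worked check with hypothetical widths:

```python
# Hypothetical sanity check of the midplanes formula (bias-free convolutions).
inplanes, planes = 64, 64
midplanes = (inplanes * planes * 3 * 3 * 3) // (inplanes * 3 * 3 + 3 * planes)

full_3d = inplanes * planes * 3 * 3 * 3                              # one 3x3x3 conv
factorized = inplanes * midplanes * 3 * 3 + midplanes * planes * 3   # 1x3x3 conv + 3x1x1 conv

print(midplanes)            # 144
print(full_3d, factorized)  # 110592 110592 -- the same parameter budget
```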