3
3
from functools import partial
4
4
from torch import nn , Tensor
5
5
from torch .nn import functional as F
6
- from typing import Any , Callable , List , Optional , Sequence
6
+ from typing import Any , Callable , Dict , List , Optional , Sequence
7
7
8
8
from torchvision .models .utils import load_state_dict_from_url
9
9
from torchvision .models .mobilenetv2 import _make_divisible , ConvBNActivation
def __init__(self, input_channels: int, squeeze_factor: int = 4):
    """Build the two 1x1-conv squeeze/excite layers.

    Args:
        input_channels (int): number of channels entering the block.
        squeeze_factor (int): reduction ratio for the squeeze step; the
            squeezed width is rounded via ``_make_divisible(..., 8)``.
    """
    super().__init__()
    # Squeeze width: input // factor, rounded to a multiple of 8.
    squeezed = _make_divisible(input_channels // squeeze_factor, 8)
    self.fc1 = nn.Conv2d(input_channels, squeezed, 1)
    self.relu = nn.ReLU(inplace=True)
    self.fc2 = nn.Conv2d(squeezed, input_channels, 1)
29
- def forward (self , input : Tensor ) -> Tensor :
30
+ def _scale (self , input : Tensor , inplace : bool ) -> Tensor :
30
31
scale = F .adaptive_avg_pool2d (input , 1 )
31
32
scale = self .fc1 (scale )
32
- scale = F .relu (scale , inplace = True )
33
+ scale = self .relu (scale )
33
34
scale = self .fc2 (scale )
34
- scale = F .hardsigmoid (scale , inplace = True )
35
+ return F .hardsigmoid (scale , inplace = inplace )
36
+
37
def forward(self, input: Tensor) -> Tensor:
    """Reweight each channel of ``input`` by its squeeze-excitation gate."""
    # The gate is broadcast from (N, C, 1, 1) over the spatial dims.
    return self._scale(input, True) * input
36
40
37
41
@@ -55,7 +59,8 @@ def adjust_channels(channels: int, width_mult: float):
55
59
56
60
class InvertedResidual (nn .Module ):
57
61
58
- def __init__ (self , cnf : InvertedResidualConfig , norm_layer : Callable [..., nn .Module ]):
62
+ def __init__ (self , cnf : InvertedResidualConfig , norm_layer : Callable [..., nn .Module ],
63
+ se_layer : Callable [..., nn .Module ] = SqueezeExcitation ):
59
64
super ().__init__ ()
60
65
if not (1 <= cnf .stride <= 2 ):
61
66
raise ValueError ('illegal stride value' )
@@ -76,7 +81,7 @@ def __init__(self, cnf: InvertedResidualConfig, norm_layer: Callable[..., nn.Mod
76
81
stride = stride , dilation = cnf .dilation , groups = cnf .expanded_channels ,
77
82
norm_layer = norm_layer , activation_layer = activation_layer ))
78
83
if cnf .use_se :
79
- layers .append (SqueezeExcitation (cnf .expanded_channels ))
84
+ layers .append (se_layer (cnf .expanded_channels ))
80
85
81
86
# project
82
87
layers .append (ConvBNActivation (cnf .expanded_channels , cnf .out_channels , kernel_size = 1 , norm_layer = norm_layer ,
@@ -179,7 +184,56 @@ def forward(self, x: Tensor) -> Tensor:
179
184
return self ._forward_impl (x )
180
185
181
186
182
- def _mobilenet_v3 (
187
+ def _mobilenet_v3_conf (arch : str , params : Dict [str , Any ]):
188
+ # non-public config parameters
189
+ reduce_divider = 2 if params .pop ('_reduced_tail' , False ) else 1
190
+ dilation = 2 if params .pop ('_dilated' , False ) else 1
191
+ width_mult = params .pop ('_width_mult' , 1.0 )
192
+
193
+ bneck_conf = partial (InvertedResidualConfig , width_mult = width_mult )
194
+ adjust_channels = partial (InvertedResidualConfig .adjust_channels , width_mult = width_mult )
195
+
196
+ if arch == "mobilenet_v3_large" :
197
+ inverted_residual_setting = [
198
+ bneck_conf (16 , 3 , 16 , 16 , False , "RE" , 1 , 1 ),
199
+ bneck_conf (16 , 3 , 64 , 24 , False , "RE" , 2 , 1 ), # C1
200
+ bneck_conf (24 , 3 , 72 , 24 , False , "RE" , 1 , 1 ),
201
+ bneck_conf (24 , 5 , 72 , 40 , True , "RE" , 2 , 1 ), # C2
202
+ bneck_conf (40 , 5 , 120 , 40 , True , "RE" , 1 , 1 ),
203
+ bneck_conf (40 , 5 , 120 , 40 , True , "RE" , 1 , 1 ),
204
+ bneck_conf (40 , 3 , 240 , 80 , False , "HS" , 2 , 1 ), # C3
205
+ bneck_conf (80 , 3 , 200 , 80 , False , "HS" , 1 , 1 ),
206
+ bneck_conf (80 , 3 , 184 , 80 , False , "HS" , 1 , 1 ),
207
+ bneck_conf (80 , 3 , 184 , 80 , False , "HS" , 1 , 1 ),
208
+ bneck_conf (80 , 3 , 480 , 112 , True , "HS" , 1 , 1 ),
209
+ bneck_conf (112 , 3 , 672 , 112 , True , "HS" , 1 , 1 ),
210
+ bneck_conf (112 , 5 , 672 , 160 // reduce_divider , True , "HS" , 2 , dilation ), # C4
211
+ bneck_conf (160 // reduce_divider , 5 , 960 // reduce_divider , 160 // reduce_divider , True , "HS" , 1 , dilation ),
212
+ bneck_conf (160 // reduce_divider , 5 , 960 // reduce_divider , 160 // reduce_divider , True , "HS" , 1 , dilation ),
213
+ ]
214
+ last_channel = adjust_channels (1280 // reduce_divider ) # C5
215
+ elif arch == "mobilenet_v3_small" :
216
+ inverted_residual_setting = [
217
+ bneck_conf (16 , 3 , 16 , 16 , True , "RE" , 2 , 1 ), # C1
218
+ bneck_conf (16 , 3 , 72 , 24 , False , "RE" , 2 , 1 ), # C2
219
+ bneck_conf (24 , 3 , 88 , 24 , False , "RE" , 1 , 1 ),
220
+ bneck_conf (24 , 5 , 96 , 40 , True , "HS" , 2 , 1 ), # C3
221
+ bneck_conf (40 , 5 , 240 , 40 , True , "HS" , 1 , 1 ),
222
+ bneck_conf (40 , 5 , 240 , 40 , True , "HS" , 1 , 1 ),
223
+ bneck_conf (40 , 5 , 120 , 48 , True , "HS" , 1 , 1 ),
224
+ bneck_conf (48 , 5 , 144 , 48 , True , "HS" , 1 , 1 ),
225
+ bneck_conf (48 , 5 , 288 , 96 // reduce_divider , True , "HS" , 2 , dilation ), # C4
226
+ bneck_conf (96 // reduce_divider , 5 , 576 // reduce_divider , 96 // reduce_divider , True , "HS" , 1 , dilation ),
227
+ bneck_conf (96 // reduce_divider , 5 , 576 // reduce_divider , 96 // reduce_divider , True , "HS" , 1 , dilation ),
228
+ ]
229
+ last_channel = adjust_channels (1024 // reduce_divider ) # C5
230
+ else :
231
+ raise ValueError ("Unsupported model type {}" .format (arch ))
232
+
233
+ return inverted_residual_setting , last_channel
234
+
235
+
236
+ def _mobilenet_v3_model (
183
237
arch : str ,
184
238
inverted_residual_setting : List [InvertedResidualConfig ],
185
239
last_channel : int ,
def mobilenet_v3_large(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MobileNetV3:
    """
    Constructs a large MobileNetV3 architecture from
    `"Searching for MobileNetV3" <https://arxiv.org/abs/1905.02244>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    arch = "mobilenet_v3_large"
    # _mobilenet_v3_conf consumes the private keys from kwargs before
    # the remainder is forwarded to the model constructor.
    inverted_residual_setting, last_channel = _mobilenet_v3_conf(arch, kwargs)
    return _mobilenet_v3_model(arch, inverted_residual_setting, last_channel, pretrained, progress, **kwargs)
236
265
237
266
238
267
def mobilenet_v3_small(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MobileNetV3:
    """
    Constructs a small MobileNetV3 architecture from
    `"Searching for MobileNetV3" <https://arxiv.org/abs/1905.02244>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    arch = "mobilenet_v3_small"
    # _mobilenet_v3_conf consumes the private keys from kwargs before
    # the remainder is forwarded to the model constructor.
    inverted_residual_setting, last_channel = _mobilenet_v3_conf(arch, kwargs)
    return _mobilenet_v3_model(arch, inverted_residual_setting, last_channel, pretrained, progress, **kwargs)
0 commit comments