@@ -8,14 +8,16 @@
 Implementation of ResNeXt (https://arxiv.org/pdf/1611.05431.pdf)
 """
 
+import copy
 import math
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch.nn as nn
 from classy_vision.generic.util import is_pos_int
 
 from . import register_model
 from .classy_model import ClassyModel
+from .squeeze_and_excitation_layer import SqueezeAndExcitationLayer
 
 
 # global setting for in-place ReLU:
@@ -55,6 +57,8 @@ def __init__(
         mid_planes_and_cardinality=None,
         reduction=4,
         final_bn_relu=True,
+        use_se=False,
+        se_reduction_ratio=16,
     ):
 
         # assertions on inputs:
@@ -79,6 +83,12 @@ def __init__(
             nn.BatchNorm2d(out_planes),
         )
 
+        self.se = (
+            SqueezeAndExcitationLayer(out_planes, reduction_ratio=se_reduction_ratio)
+            if use_se
+            else None
+        )
+
     def forward(self, x):
 
         # if required, perform downsampling along shortcut connection:
@@ -92,6 +102,10 @@ def forward(self, x):
 
         if self.final_bn_relu:
             out = self.bn(out)
+
+        if self.se is not None:
+            out = self.se(out)
+
         # add residual connection, perform relu + batchnorm, and return result:
         out += residual
         if self.final_bn_relu:
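
Note: the SqueezeAndExcitationLayer used above is defined in squeeze_and_excitation_layer.py and is not part of this diff. For reference, a minimal sketch of such a block in the spirit of Hu et al. (https://arxiv.org/abs/1709.01507); only the constructor signature is taken from the call above, the internals are an assumption and may differ from the actual module:

import torch.nn as nn


class SqueezeAndExcitationLayer(nn.Module):
    """Sketch of a squeeze-and-excitation block: globally pool each
    channel, pass the descriptor through a bottleneck MLP, and use the
    resulting per-channel gates to rescale the input feature map."""

    def __init__(self, in_planes, reduction_ratio=16):
        super().__init__()
        reduced_planes = in_planes // reduction_ratio
        # "squeeze": collapse each channel to a single value
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        # "excitation": bottleneck MLP emitting gates in (0, 1)
        self.excitation = nn.Sequential(
            nn.Conv2d(in_planes, reduced_planes, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(reduced_planes, in_planes, kernel_size=1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        # x: (N, C, H, W); gates: (N, C, 1, 1), broadcast on multiply
        return x * self.excitation(self.avgpool(x))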
@@ -101,7 +115,7 @@ def forward(self, x):
 
 class BasicLayer(GenericLayer):
     """
-    ResNeXt bottleneck layer with `in_planes` input planes and `out_planes`
+    ResNeXt layer with `in_planes` input planes and `out_planes`
     output planes.
     """
 
@@ -113,6 +127,8 @@ def __init__(
         mid_planes_and_cardinality=None,
         reduction=4,
         final_bn_relu=True,
+        use_se=False,
+        se_reduction_ratio=16,
     ):
 
         # assertions on inputs:
@@ -128,13 +144,15 @@ def __init__(
         )
 
         # call constructor of generic layer:
-        super(BasicLayer, self).__init__(
+        super().__init__(
             convolutional_block,
             in_planes,
             out_planes,
             stride=stride,
             reduction=reduction,
             final_bn_relu=final_bn_relu,
+            use_se=use_se,
+            se_reduction_ratio=se_reduction_ratio,
         )
 
 
@@ -152,6 +170,8 @@ def __init__(
         mid_planes_and_cardinality=None,
         reduction=4,
         final_bn_relu=True,
+        use_se=False,
+        se_reduction_ratio=16,
     ):
 
         # assertions on inputs:
@@ -185,6 +205,8 @@ def __init__(
             stride=stride,
             reduction=reduction,
             final_bn_relu=final_bn_relu,
+            use_se=use_se,
+            se_reduction_ratio=se_reduction_ratio,
         )
 
 
@@ -236,14 +258,20 @@ def __init__(
         basic_layer: bool = False,
         final_bn_relu: bool = True,
         bn_weight_decay: Optional[bool] = False,
+        use_se: bool = False,
+        se_reduction_ratio: int = 16,
     ):
         """
         Implementation of `ResNeXt <https://arxiv.org/pdf/1611.05431.pdf>`_.
 
-        Set ``small_input`` to `True` for 32x32 sized image inputs.
-
-        Set ``final_bn_relu`` to `False` to exclude the final batchnorm and
-        ReLU layers. These settings are useful when training Siamese networks.
+        Args:
+            small_input: set to `True` for 32x32 sized image inputs.
+            final_bn_relu: set to `False` to exclude the final batchnorm and
+                ReLU layers. These settings are useful when training Siamese
+                networks.
+            use_se: Enable squeeze and excitation
+            se_reduction_ratio: The reduction ratio to apply in the excitation
+                stage. Only used if `use_se` is `True`.
         """
         super().__init__()
 
@@ -263,6 +291,7 @@ def __init__(
             and is_pos_int(base_width_and_cardinality[0])
             and is_pos_int(base_width_and_cardinality[1])
         )
+        assert isinstance(use_se, bool), "use_se has to be a boolean"
 
         # Chooses whether to apply weight decay to batch norm
        # parameters. This improves results in some situations,
@@ -295,6 +324,8 @@ def __init__(
                 mid_planes_and_cardinality=mid_planes_and_cardinality,
                 reduction=reduction,
                 final_bn_relu=final_bn_relu or (idx != (len(out_planes) - 1)),
+                use_se=use_se,
+                se_reduction_ratio=se_reduction_ratio,
             )
             blocks.append(nn.Sequential(*new_block))
         self.blocks = nn.Sequential(*blocks)
@@ -337,6 +368,8 @@ def _make_resolution_block(
         mid_planes_and_cardinality=None,
         reduction=4,
         final_bn_relu=True,
+        use_se=False,
+        se_reduction_ratio=16,
     ):
 
         # add the desired number of residual blocks:
@@ -352,6 +385,8 @@ def _make_resolution_block(
                     mid_planes_and_cardinality=mid_planes_and_cardinality,
                     reduction=reduction,
                     final_bn_relu=final_bn_relu or (idx != (num_blocks - 1)),
+                    use_se=use_se,
+                    se_reduction_ratio=se_reduction_ratio,
                 ),
             )
         )
@@ -379,6 +414,8 @@ def from_config(cls, config: Dict[str, Any]) -> "ResNeXt":
             "final_bn_relu": config.get("final_bn_relu", True),
             "zero_init_bn_residuals": config.get("zero_init_bn_residuals", False),
             "bn_weight_decay": config.get("bn_weight_decay", False),
+            "use_se": config.get("use_se", False),
+            "se_reduction_ratio": config.get("se_reduction_ratio", 16),
         }
         return cls(**config)
 
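
Note: with the two keys read here, switching SE on from a model config needs nothing else. A hypothetical snippet, assuming the usual classy_vision.models.build_model registry entry point; both SE keys are optional:

from classy_vision.models import build_model

# "name" picks the registered model; the SE keys default to
# use_se=False and se_reduction_ratio=16 when omitted.
model = build_model(
    {"name": "resnext50_32x4d", "use_se": True, "se_reduction_ratio": 16}
)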
@@ -421,65 +458,68 @@ def model_depth(self):
         return sum(self.num_blocks)
 
 
+class _ResNeXt(ResNeXt):
+    @classmethod
+    def from_config(cls, config: Dict[str, Any]) -> "ResNeXt":
+        config = copy.deepcopy(config)
+        config.pop("name")
+        return cls(**config)
+
+
 @register_model("resnet18")
-class ResNet18(ResNeXt):
-    def __init__(self):
+class ResNet18(_ResNeXt):
+    def __init__(self, **kwargs):
         super().__init__(
-            num_blocks=[2, 2, 2, 2], basic_layer=True, zero_init_bn_residuals=True
+            num_blocks=[2, 2, 2, 2],
+            basic_layer=True,
+            zero_init_bn_residuals=True,
+            **kwargs,
         )
 
-    @classmethod
-    def from_config(cls, config: Dict[str, Any]) -> "ResNeXt":
-        return cls()
-
 
 @register_model("resnet34")
 class ResNet34(ResNeXt):
-    def __init__(self):
+    def __init__(self, **kwargs):
         super().__init__(
-            num_blocks=[3, 4, 6, 3], basic_layer=True, zero_init_bn_residuals=True
+            num_blocks=[3, 4, 6, 3],
+            basic_layer=True,
+            zero_init_bn_residuals=True,
+            **kwargs,
         )
 
-    @classmethod
-    def from_config(cls, config: Dict[str, Any]) -> "ResNeXt":
-        return cls()
-
 
 @register_model("resnet50")
-class ResNet50(ResNeXt):
-    def __init__(self):
+class ResNet50(_ResNeXt):
+    def __init__(self, **kwargs):
         super().__init__(
-            num_blocks=[3, 4, 6, 3], basic_layer=False, zero_init_bn_residuals=True
+            num_blocks=[3, 4, 6, 3],
+            basic_layer=False,
+            zero_init_bn_residuals=True,
+            **kwargs,
        )
 
-    @classmethod
-    def from_config(cls, config: Dict[str, Any]) -> "ResNeXt":
-        return cls()
-
 
 @register_model("resnet101")
-class ResNet101(ResNeXt):
-    def __init__(self):
+class ResNet101(_ResNeXt):
+    def __init__(self, **kwargs):
         super().__init__(
-            num_blocks=[3, 4, 23, 3], basic_layer=False, zero_init_bn_residuals=True
+            num_blocks=[3, 4, 23, 3],
+            basic_layer=False,
+            zero_init_bn_residuals=True,
+            **kwargs,
        )
 
-    @classmethod
-    def from_config(cls, config: Dict[str, Any]) -> "ResNeXt":
-        return cls()
-
 
 @register_model("resnet152")
-class ResNet152(ResNeXt):
-    def __init__(self):
+class ResNet152(_ResNeXt):
+    def __init__(self, **kwargs):
         super().__init__(
-            num_blocks=[3, 8, 36, 3], basic_layer=False, zero_init_bn_residuals=True
+            num_blocks=[3, 8, 36, 3],
+            basic_layer=False,
+            zero_init_bn_residuals=True,
+            **kwargs,
        )
 
-    @classmethod
-    def from_config(cls, config: Dict[str, Any]) -> "ResNeXt":
-        return cls()
-
 
 # Note, the ResNeXt models all have weight decay enabled for the batch
 # norm parameters. We have found empirically that this gives better
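
Note: the _ResNeXt helper replaces the copy-pasted from_config overrides deleted above: it deep-copies the config, pops the registry's "name" key, and forwards everything else as constructor keyword arguments. (ResNet34 keeps its direct ResNeXt base in this hunk, so the example below uses ResNet50.) A sketch of the resulting equivalence:

# These two lines construct the same SE-enabled model after this change.
model_a = ResNet50.from_config({"name": "resnet50", "use_se": True})
model_b = ResNet50(use_se=True)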
@@ -488,48 +528,39 @@ def from_config(cls, config: Dict[str, Any]) -> "ResNeXt":
 # training on other datasets, we have observed losses in accuracy (for
 # example, the dataset used in https://arxiv.org/abs/1805.00932).
 @register_model("resnext50_32x4d")
-class ResNeXt50(ResNeXt):
-    def __init__(self):
+class ResNeXt50(_ResNeXt):
+    def __init__(self, **kwargs):
         super().__init__(
             num_blocks=[3, 4, 6, 3],
             basic_layer=False,
             zero_init_bn_residuals=True,
             base_width_and_cardinality=(4, 32),
             bn_weight_decay=True,
+            **kwargs,
         )
 
-    @classmethod
-    def from_config(cls, config: Dict[str, Any]) -> "ResNeXt":
-        return cls()
-
 
 @register_model("resnext101_32x4d")
-class ResNeXt101(ResNeXt):
-    def __init__(self):
+class ResNeXt101(_ResNeXt):
+    def __init__(self, **kwargs):
         super().__init__(
             num_blocks=[3, 4, 23, 3],
             basic_layer=False,
             zero_init_bn_residuals=True,
             base_width_and_cardinality=(4, 32),
             bn_weight_decay=True,
+            **kwargs,
         )
 
-    @classmethod
-    def from_config(cls, config: Dict[str, Any]) -> "ResNeXt":
-        return cls()
-
 
 @register_model("resnext152_32x4d")
-class ResNeXt152(ResNeXt):
-    def __init__(self):
+class ResNeXt152(_ResNeXt):
+    def __init__(self, **kwargs):
         super().__init__(
             num_blocks=[3, 8, 36, 3],
             basic_layer=False,
             zero_init_bn_residuals=True,
             base_width_and_cardinality=(4, 32),
             bn_weight_decay=True,
+            **kwargs,
         )
-
-    @classmethod
-    def from_config(cls, config: Dict[str, Any]) -> "ResNeXt":
-        return cls()
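
Note: as a quick end-to-end sanity check, a hypothetical forward pass with a dummy batch; the exact output depends on which heads, if any, are attached to the ClassyModel:

import torch

# SE-enabled ResNeXt-50 on an ImageNet-sized dummy batch.
model = ResNeXt50(use_se=True)
output = model(torch.randn(2, 3, 224, 224))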