This repository was archived by the owner on Jul 1, 2024. It is now read-only.

Commit ea6a56b

mannatsingh authored and facebook-github-bot committed
Add Squeeze and Excitation to ResNeXt models (#426)
Summary: Pull Request resolved: #426

Added a `SqueezeAndExcitationLayer` to a new sub-package, `models.common` (open to other suggestions, I didn't want to have a `generic.py` or `util.py` as that is too vague and broad). Plugged the layer into the `ResNeXt` models.

Differential Revision: D20283172

fbshipit-source-id: 21d5183a61d7aa13fca094afe95ecb0aa18f1632
1 parent 636740b commit ea6a56b
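
The net effect of the commit is that squeeze-and-excitation can be toggled from a model config. A minimal sketch of driving the new options through the config path, assuming the `build_model` entry point visible in the `__init__.py` hunk headers below; `use_se` and `se_reduction_ratio` are the keys this commit adds, and the other values are arbitrary examples:

    from classy_vision.models import build_model

    # _ResNeXt.from_config (added below) pops "name" and forwards the remaining
    # keys to the preset constructor as keyword arguments.
    config = {
        "name": "resnext50_32x4d",
        "use_se": True,            # insert a SqueezeAndExcitationLayer in each block
        "se_reduction_ratio": 16,  # default reduction ratio used in the diff
    }
    model = build_model(config)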

File tree

4 files changed: +150 -59 lines changed

classy_vision/models/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -93,6 +93,7 @@ def build_model(config):
 from .resnet import ResNet  # isort:skip
 from .resnext import ResNeXt  # isort:skip
 from .resnext3d import ResNeXt3D  # isort:skip
+from .squeeze_and_excitation_layer import SqueezeAndExcitationLayer  # isort:skip


 __all__ = [
@@ -107,4 +108,5 @@ def build_model(config):
     "ResNet",
     "ResNeXt",
     "ResNeXt3D",
+    "SqueezeAndExcitationLayer",
 ]
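
With the new `__all__` entry, the layer is importable straight from the package. A two-line sketch (the sizes are arbitrary examples):

    from classy_vision.models import SqueezeAndExcitationLayer

    se = SqueezeAndExcitationLayer(in_planes=256, reduction_ratio=16)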

classy_vision/models/resnext.py

Lines changed: 89 additions & 58 deletions
@@ -8,6 +8,7 @@
 Implementation of ResNeXt (https://arxiv.org/pdf/1611.05431.pdf)
 """

+import copy
 import math
 from typing import Any, Dict, List, Optional, Tuple, Union

@@ -16,6 +17,7 @@

 from . import register_model
 from .classy_model import ClassyModel
+from .squeeze_and_excitation_layer import SqueezeAndExcitationLayer


 # global setting for in-place ReLU:
@@ -55,6 +57,8 @@ def __init__(
         mid_planes_and_cardinality=None,
         reduction=4,
         final_bn_relu=True,
+        use_se=False,
+        se_reduction_ratio=16,
     ):

         # assertions on inputs:
@@ -79,6 +83,12 @@
                 nn.BatchNorm2d(out_planes),
             )

+        self.se = (
+            SqueezeAndExcitationLayer(out_planes, reduction_ratio=se_reduction_ratio)
+            if use_se
+            else None
+        )
+
     def forward(self, x):

         # if required, perform downsampling along shortcut connection:
@@ -92,6 +102,10 @@ def forward(self, x):

         if self.final_bn_relu:
             out = self.bn(out)
+
+        if self.se is not None:
+            out = self.se(out)
+
         # add residual connection, perform rely + batchnorm, and return result:
         out += residual
         if self.final_bn_relu:
@@ -101,7 +115,7 @@ def forward(self, x):

 class BasicLayer(GenericLayer):
     """
-    ResNeXt bottleneck layer with `in_planes` input planes and `out_planes`
+    ResNeXt layer with `in_planes` input planes and `out_planes`
     output planes.
     """

@@ -113,6 +127,8 @@ def __init__(
         mid_planes_and_cardinality=None,
         reduction=4,
         final_bn_relu=True,
+        use_se=False,
+        se_reduction_ratio=16,
     ):

         # assertions on inputs:
@@ -128,13 +144,15 @@
         )

         # call constructor of generic layer:
-        super(BasicLayer, self).__init__(
+        super().__init__(
             convolutional_block,
             in_planes,
             out_planes,
             stride=stride,
             reduction=reduction,
             final_bn_relu=final_bn_relu,
+            use_se=use_se,
+            se_reduction_ratio=se_reduction_ratio,
         )


@@ -152,6 +170,8 @@ def __init__(
         mid_planes_and_cardinality=None,
         reduction=4,
         final_bn_relu=True,
+        use_se=False,
+        se_reduction_ratio=16,
     ):

         # assertions on inputs:
@@ -185,6 +205,8 @@
             stride=stride,
             reduction=reduction,
             final_bn_relu=final_bn_relu,
+            use_se=use_se,
+            se_reduction_ratio=se_reduction_ratio,
         )


@@ -236,14 +258,20 @@ def __init__(
         basic_layer: bool = False,
         final_bn_relu: bool = True,
         bn_weight_decay: Optional[bool] = False,
+        use_se: bool = False,
+        se_reduction_ratio: int = 16,
     ):
         """
         Implementation of `ResNeXt <https://arxiv.org/pdf/1611.05431.pdf>`_.

-        Set ``small_input`` to `True` for 32x32 sized image inputs.
-
-        Set ``final_bn_relu`` to `False` to exclude the final batchnorm and
-        ReLU layers. These settings are useful when training Siamese networks.
+        Args:
+            small_input: set to `True` for 32x32 sized image inputs.
+            final_bn_relu: set to `False` to exclude the final batchnorm and
+                ReLU layers. These settings are useful when training Siamese
+                networks.
+            use_se: Enable squeeze and excitation
+            se_reduction_ratio: The reduction ratio to apply in the excitation
+                stage. Only used if `use_se` is `True`.
         """
         super().__init__()

@@ -263,6 +291,7 @@
             and is_pos_int(base_width_and_cardinality[0])
             and is_pos_int(base_width_and_cardinality[1])
         )
+        assert isinstance(use_se, bool), "use_se has to be a boolean"

         # Chooses whether to apply weight decay to batch norm
         # parameters. This improves results in some situations,
@@ -295,6 +324,8 @@
                 mid_planes_and_cardinality=mid_planes_and_cardinality,
                 reduction=reduction,
                 final_bn_relu=final_bn_relu or (idx != (len(out_planes) - 1)),
+                use_se=use_se,
+                se_reduction_ratio=se_reduction_ratio,
             )
             blocks.append(nn.Sequential(*new_block))
         self.blocks = nn.Sequential(*blocks)
@@ -337,6 +368,8 @@ def _make_resolution_block(
         mid_planes_and_cardinality=None,
         reduction=4,
         final_bn_relu=True,
+        use_se=False,
+        se_reduction_ratio=16,
     ):

         # add the desired number of residual blocks:
@@ -352,6 +385,8 @@
                         mid_planes_and_cardinality=mid_planes_and_cardinality,
                         reduction=reduction,
                         final_bn_relu=final_bn_relu or (idx != (num_blocks - 1)),
+                        use_se=use_se,
+                        se_reduction_ratio=se_reduction_ratio,
                     ),
                 )
             )
@@ -379,6 +414,8 @@ def from_config(cls, config: Dict[str, Any]) -> "ResNeXt":
             "final_bn_relu": config.get("final_bn_relu", True),
             "zero_init_bn_residuals": config.get("zero_init_bn_residuals", False),
             "bn_weight_decay": config.get("bn_weight_decay", False),
+            "use_se": config.get("use_se", False),
+            "se_reduction_ratio": config.get("se_reduction_ratio", 16),
         }
         return cls(**config)

@@ -421,65 +458,68 @@ def model_depth(self):
         return sum(self.num_blocks)


+class _ResNeXt(ResNeXt):
+    @classmethod
+    def from_config(cls, config: Dict[str, Any]) -> "ResNeXt":
+        config = copy.deepcopy(config)
+        config.pop("name")
+        return cls(**config)
+
+
 @register_model("resnet18")
-class ResNet18(ResNeXt):
-    def __init__(self):
+class ResNet18(_ResNeXt):
+    def __init__(self, **kwargs):
         super().__init__(
-            num_blocks=[2, 2, 2, 2], basic_layer=True, zero_init_bn_residuals=True
+            num_blocks=[2, 2, 2, 2],
+            basic_layer=True,
+            zero_init_bn_residuals=True,
+            **kwargs,
         )

-    @classmethod
-    def from_config(cls, config: Dict[str, Any]) -> "ResNeXt":
-        return cls()
-

 @register_model("resnet34")
 class ResNet34(ResNeXt):
-    def __init__(self):
+    def __init__(self, **kwargs):
         super().__init__(
-            num_blocks=[3, 4, 6, 3], basic_layer=True, zero_init_bn_residuals=True
+            num_blocks=[3, 4, 6, 3],
+            basic_layer=True,
+            zero_init_bn_residuals=True,
+            **kwargs,
         )

-    @classmethod
-    def from_config(cls, config: Dict[str, Any]) -> "ResNeXt":
-        return cls()
-

 @register_model("resnet50")
-class ResNet50(ResNeXt):
-    def __init__(self):
+class ResNet50(_ResNeXt):
+    def __init__(self, **kwargs):
         super().__init__(
-            num_blocks=[3, 4, 6, 3], basic_layer=False, zero_init_bn_residuals=True
+            num_blocks=[3, 4, 6, 3],
+            basic_layer=False,
+            zero_init_bn_residuals=True,
+            **kwargs,
         )

-    @classmethod
-    def from_config(cls, config: Dict[str, Any]) -> "ResNeXt":
-        return cls()
-

 @register_model("resnet101")
-class ResNet101(ResNeXt):
-    def __init__(self):
+class ResNet101(_ResNeXt):
+    def __init__(self, **kwargs):
         super().__init__(
-            num_blocks=[3, 4, 23, 3], basic_layer=False, zero_init_bn_residuals=True
+            num_blocks=[3, 4, 23, 3],
+            basic_layer=False,
+            zero_init_bn_residuals=True,
+            **kwargs,
        )

-    @classmethod
-    def from_config(cls, config: Dict[str, Any]) -> "ResNeXt":
-        return cls()
-

 @register_model("resnet152")
-class ResNet152(ResNeXt):
-    def __init__(self):
+class ResNet152(_ResNeXt):
+    def __init__(self, **kwargs):
         super().__init__(
-            num_blocks=[3, 8, 36, 3], basic_layer=False, zero_init_bn_residuals=True
+            num_blocks=[3, 8, 36, 3],
+            basic_layer=False,
+            zero_init_bn_residuals=True,
+            **kwargs,
        )

-    @classmethod
-    def from_config(cls, config: Dict[str, Any]) -> "ResNeXt":
-        return cls()
-

 # Note, the ResNeXt models all have weight decay enabled for the batch
 # norm parameters. We have found empirically that this gives better
@@ -488,48 +528,39 @@ def from_config(cls, config: Dict[str, Any]) -> "ResNeXt":
 # training on other datasets, we have observed losses in accuracy (for
 # example, the dataset used in https://arxiv.org/abs/1805.00932).
 @register_model("resnext50_32x4d")
-class ResNeXt50(ResNeXt):
-    def __init__(self):
+class ResNeXt50(_ResNeXt):
+    def __init__(self, **kwargs):
         super().__init__(
             num_blocks=[3, 4, 6, 3],
             basic_layer=False,
             zero_init_bn_residuals=True,
             base_width_and_cardinality=(4, 32),
             bn_weight_decay=True,
+            **kwargs,
         )

-    @classmethod
-    def from_config(cls, config: Dict[str, Any]) -> "ResNeXt":
-        return cls()
-

 @register_model("resnext101_32x4d")
-class ResNeXt101(ResNeXt):
-    def __init__(self):
+class ResNeXt101(_ResNeXt):
+    def __init__(self, **kwargs):
         super().__init__(
             num_blocks=[3, 4, 23, 3],
             basic_layer=False,
             zero_init_bn_residuals=True,
             base_width_and_cardinality=(4, 32),
             bn_weight_decay=True,
+            **kwargs,
        )

-    @classmethod
-    def from_config(cls, config: Dict[str, Any]) -> "ResNeXt":
-        return cls()
-

 @register_model("resnext152_32x4d")
-class ResNeXt152(ResNeXt):
-    def __init__(self):
+class ResNeXt152(_ResNeXt):
+    def __init__(self, **kwargs):
         super().__init__(
             num_blocks=[3, 8, 36, 3],
             basic_layer=False,
             zero_init_bn_residuals=True,
             base_width_and_cardinality=(4, 32),
             bn_weight_decay=True,
+            **kwargs,
         )
-
-    @classmethod
-    def from_config(cls, config: Dict[str, Any]) -> "ResNeXt":
-        return cls()
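
Because the presets now accept `**kwargs` and forward them to `ResNeXt.__init__`, SE can also be enabled when instantiating a preset directly in code. A sketch under the same assumptions (the reduction ratio of 8 is an arbitrary example value):

    from classy_vision.models.resnext import ResNeXt50

    # **kwargs flow through to ResNeXt.__init__, so the new SE options apply here too.
    model = ResNeXt50(use_se=True, se_reduction_ratio=8)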
classy_vision/models/squeeze_and_excitation_layer.py

Lines changed: 29 additions & 0 deletions

@@ -0,0 +1,29 @@
+#!/usr/bin/env python3
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import torch.nn as nn
+
+
+class SqueezeAndExcitationLayer(nn.Module):
+    """Squeeze and excitation layer, as per https://arxiv.org/pdf/1709.01507.pdf"""
+
+    def __init__(self, in_planes, reduction_ratio=16):
+        super().__init__()
+        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+        reduced_planes = in_planes // reduction_ratio
+        self.excitation = nn.Sequential(
+            nn.Conv2d(in_planes, reduced_planes, kernel_size=1, stride=1, bias=True),
+            nn.ReLU(),
+            nn.Conv2d(reduced_planes, in_planes, kernel_size=1, stride=1, bias=True),
+            nn.Sigmoid(),
+        )
+
+    def forward(self, x):
+        x_squeezed = self.avgpool(x)
+        x_excited = self.excitation(x_squeezed)
+        x_scaled = x * x_excited
+        return x_scaled
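
For reference, a quick shape check of the layer as committed; the channel count and spatial size are arbitrary example values:

    import torch

    from classy_vision.models import SqueezeAndExcitationLayer

    se = SqueezeAndExcitationLayer(in_planes=64, reduction_ratio=16)
    x = torch.randn(2, 64, 32, 32)  # (batch, channels, height, width)
    out = se(x)
    assert out.shape == x.shape     # SE rescales channels; the shape is unchanged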
