Commit ccea7a5

datumbox authored and facebook-github-bot committed
[fbsync] Moving common layers to ops (#4504)
Summary:

* Moving _make_divisible to utils.
* Replace the old ConvBNReLU and ConvBNActivation layers
* Fix minor bug.
* Moving SE layer to ops.
* Adding deprecation warnings on old layers.
* Apply changes to regnets.

Reviewed By: prabhat00155, NicolasHug

Differential Revision: D31309549

fbshipit-source-id: 2780783ddfeb58974829607ac90f122b915f7366
1 parent bc32497 commit ccea7a5
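
For downstream code, the practical effect of this change is an import move: the building blocks that used to live inside individual model files now live in torchvision.ops.misc (and _make_divisible in torchvision.models._utils). A minimal sketch of the new-style usage at this revision (the channel numbers are illustrative, not from the commit):

from torch import nn
from torchvision.ops.misc import ConvNormActivation, SqueezeExcitation
from torchvision.models._utils import _make_divisible

# Conv2d -> BatchNorm2d -> ReLU6, the pattern MobileNetV2 now builds explicitly:
block = ConvNormActivation(32, 64, kernel_size=3, stride=2,
                           norm_layer=nn.BatchNorm2d, activation_layer=nn.ReLU6)
# Squeeze-and-Excitation with the squeeze width rounded to a multiple of 8:
se = SqueezeExcitation(64, _make_divisible(64 // 4, 8))

The old names (ConvBNReLU, ConvBNActivation, and the per-model SqueezeExcitation) keep working but emit a FutureWarning, as shown in the diffs below.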

9 files changed, +153 -129 lines changed

torchvision/models/_utils.py

Lines changed: 17 additions & 1 deletion
@@ -1,7 +1,7 @@
 from collections import OrderedDict

 from torch import nn
-from typing import Dict
+from typing import Dict, Optional


 class IntermediateLayerGetter(nn.ModuleDict):
@@ -64,3 +64,19 @@ def forward(self, x):
                 out_name = self.return_layers[name]
                 out[out_name] = x
         return out
+
+
+def _make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> int:
+    """
+    This function is taken from the original tf repo.
+    It ensures that all layers have a channel number that is divisible by 8
+    It can be seen here:
+    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
+    """
+    if min_value is None:
+        min_value = divisor
+    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
+    # Make sure that round down does not go down by more than 10%.
+    if new_v < 0.9 * v:
+        new_v += divisor
+    return new_v
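
For intuition, _make_divisible rounds a (possibly scaled) channel count to the nearest multiple of divisor, never below min_value, and bumps the result up by one divisor if plain rounding would lose more than 10% of the original value. A few illustrative values (computed from the function above, not part of the commit):

from torchvision.models._utils import _make_divisible

assert _make_divisible(37.5, 8) == 40             # nearest multiple of 8
assert _make_divisible(27, 8) == 32               # 24 would be >10% below 27, so bump by one divisor
assert _make_divisible(8, 8, min_value=16) == 16  # never below min_value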

torchvision/models/detection/ssdlite.py

Lines changed: 9 additions & 9 deletions
@@ -11,8 +11,8 @@
 from .anchor_utils import DefaultBoxGenerator
 from .backbone_utils import _validate_trainable_layers
 from .. import mobilenet
-from ..mobilenetv3 import ConvBNActivation
 from ..._internally_replaced_utils import load_state_dict_from_url
+from ...ops.misc import ConvNormActivation


 __all__ = ['ssdlite320_mobilenet_v3_large']
@@ -28,8 +28,8 @@ def _prediction_block(in_channels: int, out_channels: int, kernel_size: int,
                       norm_layer: Callable[..., nn.Module]) -> nn.Sequential:
     return nn.Sequential(
         # 3x3 depthwise with stride 1 and padding 1
-        ConvBNActivation(in_channels, in_channels, kernel_size=kernel_size, groups=in_channels,
-                         norm_layer=norm_layer, activation_layer=nn.ReLU6),
+        ConvNormActivation(in_channels, in_channels, kernel_size=kernel_size, groups=in_channels,
+                           norm_layer=norm_layer, activation_layer=nn.ReLU6),

         # 1x1 projetion to output channels
         nn.Conv2d(in_channels, out_channels, 1)
@@ -41,16 +41,16 @@ def _extra_block(in_channels: int, out_channels: int, norm_layer: Callable[...,
     intermediate_channels = out_channels // 2
     return nn.Sequential(
         # 1x1 projection to half output channels
-        ConvBNActivation(in_channels, intermediate_channels, kernel_size=1,
-                         norm_layer=norm_layer, activation_layer=activation),
+        ConvNormActivation(in_channels, intermediate_channels, kernel_size=1,
+                           norm_layer=norm_layer, activation_layer=activation),

         # 3x3 depthwise with stride 2 and padding 1
-        ConvBNActivation(intermediate_channels, intermediate_channels, kernel_size=3, stride=2,
-                         groups=intermediate_channels, norm_layer=norm_layer, activation_layer=activation),
+        ConvNormActivation(intermediate_channels, intermediate_channels, kernel_size=3, stride=2,
+                           groups=intermediate_channels, norm_layer=norm_layer, activation_layer=activation),

         # 1x1 projetion to output channels
-        ConvBNActivation(intermediate_channels, out_channels, kernel_size=1,
-                         norm_layer=norm_layer, activation_layer=activation),
+        ConvNormActivation(intermediate_channels, out_channels, kernel_size=1,
+                           norm_layer=norm_layer, activation_layer=activation),
     )
torchvision/models/efficientnet.py

Lines changed: 13 additions & 40 deletions
@@ -4,14 +4,13 @@

 from functools import partial
 from torch import nn, Tensor
-from torch.nn import functional as F
 from typing import Any, Callable, List, Optional, Sequence

 from .._internally_replaced_utils import load_state_dict_from_url
+from ..ops.misc import ConvNormActivation, SqueezeExcitation
+from ._utils import _make_divisible
 from torchvision.ops import StochasticDepth

-from torchvision.models.mobilenetv2 import ConvBNActivation, _make_divisible
-

 __all__ = ["EfficientNet", "efficientnet_b0", "efficientnet_b1", "efficientnet_b2", "efficientnet_b3",
            "efficientnet_b4", "efficientnet_b5", "efficientnet_b6", "efficientnet_b7"]
@@ -31,32 +30,6 @@
 }


-class SqueezeExcitation(nn.Module):
-    def __init__(
-        self,
-        input_channels: int,
-        squeeze_channels: int,
-        activation: Callable[..., nn.Module] = nn.ReLU,
-        scale_activation: Callable[..., nn.Module] = nn.Sigmoid,
-    ) -> None:
-        super().__init__()
-        self.fc1 = nn.Conv2d(input_channels, squeeze_channels, 1)
-        self.fc2 = nn.Conv2d(squeeze_channels, input_channels, 1)
-        self.activation = activation()
-        self.scale_activation = scale_activation()
-
-    def _scale(self, input: Tensor) -> Tensor:
-        scale = F.adaptive_avg_pool2d(input, 1)
-        scale = self.fc1(scale)
-        scale = self.activation(scale)
-        scale = self.fc2(scale)
-        return self.scale_activation(scale)
-
-    def forward(self, input: Tensor) -> Tensor:
-        scale = self._scale(input)
-        return scale * input
-
-
 class MBConvConfig:
     # Stores information listed at Table 1 of the EfficientNet paper
     def __init__(self,
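
The class removed in the hunk above is the one that now lives in torchvision.ops.misc (see the new import earlier in this file); a brief usage sketch, assuming that location and illustrative tensor sizes:

import torch
from torchvision.ops.misc import SqueezeExcitation

se = SqueezeExcitation(input_channels=64, squeeze_channels=16)
x = torch.randn(1, 64, 8, 8)
out = se(x)                   # per-channel rescaling of the input
assert out.shape == x.shape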
@@ -106,21 +79,21 @@ def __init__(self, cnf: MBConvConfig, stochastic_depth_prob: float, norm_layer:
         # expand
         expanded_channels = cnf.adjust_channels(cnf.input_channels, cnf.expand_ratio)
         if expanded_channels != cnf.input_channels:
-            layers.append(ConvBNActivation(cnf.input_channels, expanded_channels, kernel_size=1,
-                                           norm_layer=norm_layer, activation_layer=activation_layer))
+            layers.append(ConvNormActivation(cnf.input_channels, expanded_channels, kernel_size=1,
+                                             norm_layer=norm_layer, activation_layer=activation_layer))

         # depthwise
-        layers.append(ConvBNActivation(expanded_channels, expanded_channels, kernel_size=cnf.kernel,
-                                       stride=cnf.stride, groups=expanded_channels,
-                                       norm_layer=norm_layer, activation_layer=activation_layer))
+        layers.append(ConvNormActivation(expanded_channels, expanded_channels, kernel_size=cnf.kernel,
+                                         stride=cnf.stride, groups=expanded_channels,
+                                         norm_layer=norm_layer, activation_layer=activation_layer))

         # squeeze and excitation
         squeeze_channels = max(1, cnf.input_channels // 4)
         layers.append(se_layer(expanded_channels, squeeze_channels, activation=partial(nn.SiLU, inplace=True)))

         # project
-        layers.append(ConvBNActivation(expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer,
-                                       activation_layer=nn.Identity))
+        layers.append(ConvNormActivation(expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer,
+                                         activation_layer=None))

         self.block = nn.Sequential(*layers)
         self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")
@@ -174,8 +147,8 @@ def __init__(

         # building first layer
         firstconv_output_channels = inverted_residual_setting[0].input_channels
-        layers.append(ConvBNActivation(3, firstconv_output_channels, kernel_size=3, stride=2, norm_layer=norm_layer,
-                                       activation_layer=nn.SiLU))
+        layers.append(ConvNormActivation(3, firstconv_output_channels, kernel_size=3, stride=2, norm_layer=norm_layer,
+                                         activation_layer=nn.SiLU))

         # building inverted residual blocks
         total_stage_blocks = sum([cnf.num_layers for cnf in inverted_residual_setting])
@@ -202,8 +175,8 @@ def __init__(
         # building last several layers
         lastconv_input_channels = inverted_residual_setting[-1].out_channels
         lastconv_output_channels = 4 * lastconv_input_channels
-        layers.append(ConvBNActivation(lastconv_input_channels, lastconv_output_channels, kernel_size=1,
-                                       norm_layer=norm_layer, activation_layer=nn.SiLU))
+        layers.append(ConvNormActivation(lastconv_input_channels, lastconv_output_channels, kernel_size=1,
+                                         norm_layer=norm_layer, activation_layer=nn.SiLU))

         self.features = nn.Sequential(*layers)
         self.avgpool = nn.AdaptiveAvgPool2d(1)

torchvision/models/mobilenetv2.py

Lines changed: 26 additions & 46 deletions
@@ -1,7 +1,12 @@
 import torch
+import warnings
+
+from functools import partial
 from torch import nn
 from torch import Tensor
 from .._internally_replaced_utils import load_state_dict_from_url
+from ..ops.misc import ConvNormActivation
+from ._utils import _make_divisible
 from typing import Callable, Any, Optional, List

@@ -13,50 +18,21 @@
 }


-def _make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> int:
-    """
-    This function is taken from the original tf repo.
-    It ensures that all layers have a channel number that is divisible by 8
-    It can be seen here:
-    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
-    """
-    if min_value is None:
-        min_value = divisor
-    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
-    # Make sure that round down does not go down by more than 10%.
-    if new_v < 0.9 * v:
-        new_v += divisor
-    return new_v
-
-
-class ConvBNActivation(nn.Sequential):
-    def __init__(
-        self,
-        in_planes: int,
-        out_planes: int,
-        kernel_size: int = 3,
-        stride: int = 1,
-        groups: int = 1,
-        norm_layer: Optional[Callable[..., nn.Module]] = None,
-        activation_layer: Optional[Callable[..., nn.Module]] = None,
-        dilation: int = 1,
-    ) -> None:
-        padding = (kernel_size - 1) // 2 * dilation
-        if norm_layer is None:
-            norm_layer = nn.BatchNorm2d
-        if activation_layer is None:
-            activation_layer = nn.ReLU6
-        super().__init__(
-            nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, dilation=dilation, groups=groups,
-                      bias=False),
-            norm_layer(out_planes),
-            activation_layer(inplace=True)
-        )
-        self.out_channels = out_planes
+# necessary for backwards compatibility
+class _DeprecatedConvBNAct(ConvNormActivation):
+    def __init__(self, *args, **kwargs):
+        warnings.warn(
+            "The ConvBNReLU/ConvBNActivation classes are deprecated and will be removed in future versions. "
+            "Use torchvision.ops.misc.ConvNormActivation instead.", FutureWarning)
+        if kwargs.get("norm_layer", None) is None:
+            kwargs["norm_layer"] = nn.BatchNorm2d
+        if kwargs.get("activation_layer", None) is None:
+            kwargs["activation_layer"] = nn.ReLU6
+        super().__init__(*args, **kwargs)


-# necessary for backwards compatibility
-ConvBNReLU = ConvBNActivation
+ConvBNReLU = _DeprecatedConvBNAct
+ConvBNActivation = _DeprecatedConvBNAct


 class InvertedResidual(nn.Module):
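
A hedged sketch of what the shim above means for callers that still import the old names (the layer still builds a Conv2d -> BatchNorm2d -> ReLU6 stack, but now warns; the channel numbers are illustrative):

import warnings
from torchvision.models.mobilenetv2 import ConvBNReLU  # now an alias of _DeprecatedConvBNAct

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    block = ConvBNReLU(16, 32, kernel_size=3, stride=2)  # defaults filled in: BatchNorm2d + ReLU6
assert any(issubclass(w.category, FutureWarning) for w in caught)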
@@ -81,10 +57,12 @@ def __init__(
         layers: List[nn.Module] = []
         if expand_ratio != 1:
             # pw
-            layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer))
+            layers.append(ConvNormActivation(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer,
+                                             activation_layer=nn.ReLU6))
         layers.extend([
             # dw
-            ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim, norm_layer=norm_layer),
+            ConvNormActivation(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim, norm_layer=norm_layer,
+                               activation_layer=nn.ReLU6),
             # pw-linear
             nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
             norm_layer(oup),
@@ -154,7 +132,8 @@ def __init__(
         # building first layer
         input_channel = _make_divisible(input_channel * width_mult, round_nearest)
         self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
-        features: List[nn.Module] = [ConvBNReLU(3, input_channel, stride=2, norm_layer=norm_layer)]
+        features: List[nn.Module] = [ConvNormActivation(3, input_channel, stride=2, norm_layer=norm_layer,
+                                                        activation_layer=nn.ReLU6)]
         # building inverted residual blocks
         for t, c, n, s in inverted_residual_setting:
             output_channel = _make_divisible(c * width_mult, round_nearest)
@@ -163,7 +142,8 @@ def __init__(
             features.append(block(input_channel, output_channel, stride, expand_ratio=t, norm_layer=norm_layer))
             input_channel = output_channel
         # building last several layers
-        features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1, norm_layer=norm_layer))
+        features.append(ConvNormActivation(input_channel, self.last_channel, kernel_size=1, norm_layer=norm_layer,
+                                           activation_layer=nn.ReLU6))
         # make it nn.Sequential
         self.features = nn.Sequential(*features)

torchvision/models/mobilenetv3.py

Lines changed: 15 additions & 14 deletions
@@ -6,8 +6,8 @@
 from typing import Any, Callable, List, Optional, Sequence

 from .._internally_replaced_utils import load_state_dict_from_url
-from .efficientnet import SqueezeExcitation as SElayer
-from .mobilenetv2 import _make_divisible, ConvBNActivation
+from ..ops.misc import ConvNormActivation, SqueezeExcitation as SElayer
+from ._utils import _make_divisible


 __all__ = ["MobileNetV3", "mobilenet_v3_large", "mobilenet_v3_small"]
@@ -28,7 +28,8 @@ def __init__(self, input_channels: int, squeeze_factor: int = 4):
         self.relu = self.activation
         delattr(self, 'activation')
         warnings.warn(
-            "This SqueezeExcitation class is deprecated and will be removed in future versions.", FutureWarning)
+            "This SqueezeExcitation class is deprecated and will be removed in future versions. "
+            "Use torchvision.ops.misc.SqueezeExcitation instead.", FutureWarning)


 class InvertedResidualConfig:
@@ -64,21 +65,21 @@ def __init__(self, cnf: InvertedResidualConfig, norm_layer: Callable[..., nn.Mod

         # expand
         if cnf.expanded_channels != cnf.input_channels:
-            layers.append(ConvBNActivation(cnf.input_channels, cnf.expanded_channels, kernel_size=1,
-                                           norm_layer=norm_layer, activation_layer=activation_layer))
+            layers.append(ConvNormActivation(cnf.input_channels, cnf.expanded_channels, kernel_size=1,
+                                             norm_layer=norm_layer, activation_layer=activation_layer))

         # depthwise
         stride = 1 if cnf.dilation > 1 else cnf.stride
-        layers.append(ConvBNActivation(cnf.expanded_channels, cnf.expanded_channels, kernel_size=cnf.kernel,
-                                       stride=stride, dilation=cnf.dilation, groups=cnf.expanded_channels,
-                                       norm_layer=norm_layer, activation_layer=activation_layer))
+        layers.append(ConvNormActivation(cnf.expanded_channels, cnf.expanded_channels, kernel_size=cnf.kernel,
+                                         stride=stride, dilation=cnf.dilation, groups=cnf.expanded_channels,
+                                         norm_layer=norm_layer, activation_layer=activation_layer))
         if cnf.use_se:
             squeeze_channels = _make_divisible(cnf.expanded_channels // 4, 8)
             layers.append(se_layer(cnf.expanded_channels, squeeze_channels))

         # project
-        layers.append(ConvBNActivation(cnf.expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer,
-                                       activation_layer=nn.Identity))
+        layers.append(ConvNormActivation(cnf.expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer,
+                                         activation_layer=None))

         self.block = nn.Sequential(*layers)
         self.out_channels = cnf.out_channels
@@ -130,8 +131,8 @@ def __init__(

         # building first layer
         firstconv_output_channels = inverted_residual_setting[0].input_channels
-        layers.append(ConvBNActivation(3, firstconv_output_channels, kernel_size=3, stride=2, norm_layer=norm_layer,
-                                       activation_layer=nn.Hardswish))
+        layers.append(ConvNormActivation(3, firstconv_output_channels, kernel_size=3, stride=2, norm_layer=norm_layer,
+                                         activation_layer=nn.Hardswish))

         # building inverted residual blocks
         for cnf in inverted_residual_setting:
@@ -140,8 +141,8 @@ def __init__(
         # building last several layers
         lastconv_input_channels = inverted_residual_setting[-1].out_channels
         lastconv_output_channels = 6 * lastconv_input_channels
-        layers.append(ConvBNActivation(lastconv_input_channels, lastconv_output_channels, kernel_size=1,
-                                       norm_layer=norm_layer, activation_layer=nn.Hardswish))
+        layers.append(ConvNormActivation(lastconv_input_channels, lastconv_output_channels, kernel_size=1,
+                                         norm_layer=norm_layer, activation_layer=nn.Hardswish))

         self.features = nn.Sequential(*layers)
         self.avgpool = nn.AdaptiveAvgPool2d(1)

torchvision/models/quantization/mobilenetv2.py

Lines changed: 3 additions & 2 deletions
@@ -5,9 +5,10 @@

 from typing import Any

-from torchvision.models.mobilenetv2 import InvertedResidual, ConvBNReLU, MobileNetV2, model_urls
+from torchvision.models.mobilenetv2 import InvertedResidual, MobileNetV2, model_urls
 from torch.quantization import QuantStub, DeQuantStub, fuse_modules
 from .utils import _replace_relu, quantize_model
+from ...ops.misc import ConvNormActivation


 __all__ = ['QuantizableMobileNetV2', 'mobilenet_v2']
@@ -55,7 +56,7 @@ def forward(self, x: Tensor) -> Tensor:

     def fuse_model(self) -> None:
         for m in self.modules():
-            if type(m) == ConvBNReLU:
+            if type(m) == ConvNormActivation:
                 fuse_modules(m, ['0', '1', '2'], inplace=True)
             if type(m) == QuantizableInvertedResidual:
                 m.fuse_model()
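
The fusion indices ['0', '1', '2'] work because ConvNormActivation is an nn.Sequential whose children are, in order, the convolution, the norm layer, and the activation. A small eager-mode sketch (the nn.ReLU here stands in for the ReLU6 that _replace_relu swaps out before fusion; channel sizes are illustrative):

import torch
from torch import nn
from torch.quantization import fuse_modules
from torchvision.ops.misc import ConvNormActivation

m = ConvNormActivation(3, 16, kernel_size=3, norm_layer=nn.BatchNorm2d, activation_layer=nn.ReLU)
m.eval()                                        # eager fusion here operates on an eval-mode module
fuse_modules(m, ['0', '1', '2'], inplace=True)  # Conv2d + BatchNorm2d + ReLU folded into one module
print(m[0])                                     # the fused module now sits in the convolution slot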
