Commit ad5a240

add efficientnetv2
1 parent 348f75c commit ad5a240

1 file changed: +324 -0 lines changed

torchvision/models/efficientnetv2.py

Lines changed: 324 additions & 0 deletions
@@ -0,0 +1,324 @@
import copy
import math
from functools import partial
from typing import Any, Callable, Optional, List, Sequence

import torch
from torch import nn, Tensor

from ._utils import _make_divisible
from .._internally_replaced_utils import load_state_dict_from_url
from ..ops import StochasticDepth
from .efficientnet import MBConv, ConvNormActivation


__all__ = [
    "EfficientNetV2",
    "efficientnet_v2_s",  # 384
    "efficientnet_v2_m",  # 480
    "efficientnet_v2_l",  # 480
]


model_urls = {
    # Weights ported from https://github.com/rwightman/pytorch-image-models/
    "efficientnet_v2_s": "",
    "efficientnet_v2_m": "",
    "efficientnet_v2_l": "",
}


class MBConvConfig:
    # Stores the per-stage block configuration listed in the architecture tables
    # of the EfficientNetV2 paper
    def __init__(
        self,
        block_type: str,
        expand_ratio: float,
        kernel: int,
        stride: int,
        input_channels: int,
        out_channels: int,
        num_layers: int,
    ) -> None:
        self.block_type = block_type
        self.expand_ratio = expand_ratio
        self.kernel = kernel
        self.stride = stride
        self.input_channels = input_channels
        self.out_channels = out_channels
        self.num_layers = num_layers

    def __repr__(self) -> str:
        s = self.__class__.__name__ + "("
        s += "block_type={block_type}"
        s += ", expand_ratio={expand_ratio}"
        s += ", kernel={kernel}"
        s += ", stride={stride}"
        s += ", input_channels={input_channels}"
        s += ", out_channels={out_channels}"
        s += ", num_layers={num_layers}"
        s += ")"
        return s.format(**self.__dict__)

    @staticmethod
    def adjust_channels(channels: int, width_mult: float, min_value: Optional[int] = None) -> int:
        return _make_divisible(channels * width_mult, 8, min_value)

    @staticmethod
    def adjust_depth(num_layers: int, depth_mult: float) -> int:
        return int(math.ceil(num_layers * depth_mult))


class FusedMBConv(nn.Module):
    # Fused-MBConv block: replaces the 1x1 expansion conv + KxK depthwise conv
    # of MBConv with a single regular KxK convolution
    def __init__(
        self,
        cnf: MBConvConfig,
        stochastic_depth_prob: float,
        norm_layer: Callable[..., nn.Module],
        se_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()

        if not (1 <= cnf.stride <= 2):
            raise ValueError("illegal stride value")

        self.use_res_connect = cnf.stride == 1 and cnf.input_channels == cnf.out_channels

        layers: List[nn.Module] = []
        activation_layer = nn.SiLU

        # expand (fused: a regular KxK convolution instead of a 1x1 expansion)
        expanded_channels = cnf.adjust_channels(cnf.input_channels, cnf.expand_ratio)
        if expanded_channels != cnf.input_channels:
            layers.append(
                ConvNormActivation(
                    cnf.input_channels,
                    expanded_channels,
                    kernel_size=cnf.kernel,
                    stride=cnf.stride,
                    norm_layer=norm_layer,
                    activation_layer=activation_layer,
                )
            )

        if se_layer:
            # squeeze and excitation
            squeeze_channels = max(1, cnf.input_channels // 4)
            layers.append(se_layer(expanded_channels, squeeze_channels, activation=partial(nn.SiLU, inplace=True)))

        # project; when there is no expansion this is the block's only convolution,
        # so it keeps the configured kernel, stride and activation
        layers.append(
            ConvNormActivation(
                expanded_channels,
                cnf.out_channels,
                kernel_size=1 if expanded_channels != cnf.input_channels else cnf.kernel,
                stride=1 if expanded_channels != cnf.input_channels else cnf.stride,
                norm_layer=norm_layer,
                activation_layer=None if expanded_channels != cnf.input_channels else activation_layer,
            )
        )

        self.block = nn.Sequential(*layers)
        self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")
        self.out_channels = cnf.out_channels

    def forward(self, input: Tensor) -> Tensor:
        result = self.block(input)
        if self.use_res_connect:
            result = self.stochastic_depth(result)
            result += input
        return result


class EfficientNetV2(nn.Module):
    def __init__(
        self,
        block_setting: List[MBConvConfig],
        dropout: float,
        lastconv_output_channels: Optional[int] = 1280,
        stochastic_depth_prob: float = 0.2,
        num_classes: int = 1000,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
        **kwargs: Any,
    ) -> None:
        """
        EfficientNetV2 main class

        Args:
            block_setting (List[MBConvConfig]): Network structure
            dropout (float): The dropout probability
            lastconv_output_channels (Optional[int]): The output channels of the last conv layer;
                if None, defaults to 4x its input channels
            stochastic_depth_prob (float): The stochastic depth probability
            num_classes (int): Number of classes
            norm_layer (Optional[Callable[..., nn.Module]]): Module specifying the normalization layer to use
        """
        super().__init__()

        if not block_setting:
            raise ValueError("The block_setting should not be empty")
        elif not (
            isinstance(block_setting, Sequence)
            and all(isinstance(s, MBConvConfig) for s in block_setting)
        ):
            raise TypeError("The block_setting should be List[MBConvConfig]")

        if norm_layer is None:
            norm_layer = nn.BatchNorm2d

        layers: List[nn.Module] = []

        # building first layer
        firstconv_output_channels = block_setting[0].input_channels
        layers.append(
            ConvNormActivation(
                3, firstconv_output_channels, kernel_size=3, stride=2, norm_layer=norm_layer, activation_layer=nn.SiLU
            )
        )

        # building blocks
        total_stage_blocks = sum(cnf.num_layers for cnf in block_setting)
        stage_block_id = 0
        for cnf in block_setting:
            block = MBConv if cnf.block_type == "MB" else FusedMBConv
            stage: List[nn.Module] = []
            for _ in range(cnf.num_layers):
                # copy to avoid modifications. shallow copy is enough
                block_cnf = copy.copy(cnf)

                # overwrite info if not the first conv in the stage
                if stage:
                    block_cnf.input_channels = block_cnf.out_channels
                    block_cnf.stride = 1

                # adjust stochastic depth probability based on the depth of the stage block:
                # it ramps linearly from 0 at the first block to stochastic_depth_prob at the last
                sd_prob = stochastic_depth_prob * float(stage_block_id) / total_stage_blocks

                stage.append(block(block_cnf, sd_prob, norm_layer))
                stage_block_id += 1

            layers.append(nn.Sequential(*stage))

        # building last several layers
        lastconv_input_channels = block_setting[-1].out_channels
        if lastconv_output_channels is None:
            lastconv_output_channels = 4 * lastconv_input_channels
        layers.append(
            ConvNormActivation(
                lastconv_input_channels,
                lastconv_output_channels,
                kernel_size=1,
                norm_layer=norm_layer,
                activation_layer=nn.SiLU,
            )
        )

        self.features = nn.Sequential(*layers)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Sequential(
            nn.Dropout(p=dropout, inplace=True),
            nn.Linear(lastconv_output_channels, num_classes),
        )

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out")
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                init_range = 1.0 / math.sqrt(m.out_features)
                nn.init.uniform_(m.weight, -init_range, init_range)
                nn.init.zeros_(m.bias)

    def _forward_impl(self, x: Tensor) -> Tensor:
        x = self.features(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)

        x = self.classifier(x)

        return x

    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)


def _efficientnet_v2(
    arch: str,
    block_setting: List[MBConvConfig],
    dropout: float,
    lastconv_output_channels: int,
    pretrained: bool,
    progress: bool,
    **kwargs: Any,
) -> EfficientNetV2:
    model = EfficientNetV2(block_setting, dropout, lastconv_output_channels=lastconv_output_channels, **kwargs)
    if pretrained:
        # empty-string placeholders in model_urls also count as "no checkpoint"
        if not model_urls.get(arch):
            raise ValueError(f"No checkpoint is available for model type {arch}")
        state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
        model.load_state_dict(state_dict)
    return model

def efficientnet_v2_s(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNetV2:
    """
    Constructs an EfficientNetV2-S architecture from
    `"EfficientNetV2: Smaller Models and Faster Training" <https://arxiv.org/abs/2104.00298>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    block_setting = [
        MBConvConfig("FusedMB", 1, 3, 1, 24, 24, 2),
        MBConvConfig("FusedMB", 4, 3, 2, 24, 48, 4),
        MBConvConfig("FusedMB", 4, 3, 2, 48, 64, 4),
        MBConvConfig("MB", 4, 3, 2, 64, 128, 6),
        MBConvConfig("MB", 6, 3, 1, 128, 160, 9),
        MBConvConfig("MB", 6, 3, 2, 160, 256, 15),
    ]
    return _efficientnet_v2("efficientnet_v2_s", block_setting, 0.0, 1280, pretrained, progress, **kwargs)


def efficientnet_v2_m(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNetV2:
    """
    Constructs an EfficientNetV2-M architecture from
    `"EfficientNetV2: Smaller Models and Faster Training" <https://arxiv.org/abs/2104.00298>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    block_setting = [
        MBConvConfig("FusedMB", 1, 3, 1, 24, 24, 3),
        MBConvConfig("FusedMB", 4, 3, 2, 24, 48, 5),
        MBConvConfig("FusedMB", 4, 3, 2, 48, 80, 5),
        MBConvConfig("MB", 4, 3, 2, 80, 160, 7),
        MBConvConfig("MB", 6, 3, 1, 160, 176, 14),
        MBConvConfig("MB", 6, 3, 2, 176, 304, 18),
        MBConvConfig("MB", 6, 3, 1, 304, 512, 5),
    ]
    return _efficientnet_v2("efficientnet_v2_m", block_setting, 0.2, 1280, pretrained, progress, **kwargs)


def efficientnet_v2_l(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNetV2:
    """
    Constructs an EfficientNetV2-L architecture from
    `"EfficientNetV2: Smaller Models and Faster Training" <https://arxiv.org/abs/2104.00298>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    block_setting = [
        MBConvConfig("FusedMB", 1, 3, 1, 32, 32, 4),
        MBConvConfig("FusedMB", 4, 3, 2, 32, 64, 7),
        MBConvConfig("FusedMB", 4, 3, 2, 64, 96, 7),
        MBConvConfig("MB", 4, 3, 2, 96, 192, 10),
        MBConvConfig("MB", 6, 3, 1, 192, 224, 19),
        MBConvConfig("MB", 6, 3, 2, 224, 384, 25),
        MBConvConfig("MB", 6, 3, 1, 384, 640, 7),
    ]
    return _efficientnet_v2("efficientnet_v2_l", block_setting, 0.5, 1280, pretrained, progress, **kwargs)
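
For reference, a minimal usage sketch (not part of the commit), assuming this module is importable from a torchvision build that includes it. Since the model_urls entries are still empty placeholders, pretrained must stay False for now:

import torch
from torchvision.models.efficientnetv2 import efficientnet_v2_s

# No checkpoints are published yet (model_urls holds empty placeholders),
# so construct the model with randomly initialized weights.
model = efficientnet_v2_s(pretrained=False)
model.eval()

# The "# 384" comment next to __all__ marks the intended input resolution for S.
x = torch.randn(1, 3, 384, 384)
with torch.no_grad():
    logits = model(x)
print(logits.shape)  # torch.Size([1, 1000])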
