Skip to content

Commit edbe693

Browse files
committed
Implement MBConv.
1 parent e173b8f commit edbe693

File tree

1 file changed

+56
-11
lines changed

1 file changed

+56
-11
lines changed

torchvision/models/efficientnet.py

Lines changed: 56 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from torch import nn, Tensor
44
from torch.nn import functional as F
5-
from typing import Any, Optional
5+
from typing import Any, Callable, List, Optional
66

77
from .._internally_replaced_utils import load_state_dict_from_url
88

@@ -11,33 +11,78 @@
1111
from torchvision.models.mobilenetv3 import SqueezeExcitation
1212

1313

14-
__all__ = []
14+
__all__ = ["EfficientNet"]
1515

1616

17-
model_urls = {}
17+
# Per-model weight locations, keyed by model name. The single entry is an
# empty string until weights are published.
model_urls = dict(
    efficientnet_b0="",  # TODO: Add weights
)
20+
21+
22+
def stochastic_depth(x: Tensor, drop_rate: float) -> Tensor:
    """Randomly zero whole samples of ``x`` and rescale the survivors.

    One random draw is made per index along dimension 0 of ``x``. Samples whose
    draw falls below ``drop_rate`` are replaced by zeros; the remaining samples
    are divided by ``1 - drop_rate``.

    Args:
        x: Input tensor. Dimension 0 is treated as the per-sample dimension.
        drop_rate: Probability of dropping each sample, in ``[0, 1]``.

    Returns:
        A tensor with the same shape as ``x``.

    Raises:
        ValueError: If ``drop_rate`` lies outside ``[0, 1]``.
    """
    if drop_rate < 0.0 or drop_rate > 1.0:
        raise ValueError(
            "drop_rate must be in the range [0, 1], got {}".format(drop_rate)
        )
    if drop_rate == 1.0:
        # Every sample is dropped. Handled separately because the rescaling
        # below would divide by zero and yield NaN instead of zeros.
        return torch.zeros_like(x)

    survival_rate = 1.0 - drop_rate
    # Draw the mask directly in a broadcastable (N, 1, ..., 1) shape rather than
    # reshaping through Tensor.T, which is deprecated for tensors with ndim > 2.
    mask_shape = (x.size(0),) + (1,) * (x.ndim - 1)
    # ">=" keeps a sample with probability exactly 1 - drop_rate; torch.rand can
    # return 0.0, which a strict ">" would drop even when drop_rate is 0.
    keep = torch.rand(mask_shape, dtype=x.dtype, device=x.device) >= drop_rate
    return x / survival_rate * keep
1827

1928

2029
class MBConvConfig:
    """Settings for one MBConv block, with channel counts already width-scaled.

    ``kernel``, ``stride`` and ``dilation`` are stored as given. The three
    channel attributes are passed through :meth:`adjust_channels` before being
    stored.
    """

    def __init__(self,
                 kernel: int, stride: int, dilation: int,
                 input_channels: int, out_channels: int, expand_ratio: float,
                 width_mult: float) -> None:
        scale = self.adjust_channels

        # Channel counts: each is scaled by the width multiplier. The expanded
        # width starts from the *unscaled* input channel count and applies the
        # expansion ratio and width multiplier together.
        self.input_channels = scale(input_channels, width_mult)
        self.out_channels = scale(out_channels, width_mult)
        self.expanded_channels = scale(input_channels, expand_ratio * width_mult)

        # Spatial parameters are kept verbatim.
        self.kernel, self.stride, self.dilation = kernel, stride, dilation

    @staticmethod
    def adjust_channels(channels: int, width_mult: float, min_value: Optional[int] = None) -> int:
        """Scale ``channels`` by ``width_mult`` and round via ``_make_divisible``.

        The helper is called with ``8`` as its second argument and ``min_value``
        as its third (presumably the divisor and the lower bound; confirm
        against ``_make_divisible``).
        """
        scaled = channels * width_mult
        return _make_divisible(scaled, 8, min_value)
3744

3845

3946
class MBConv(nn.Module):
    """Block built from an optional 1x1 expansion, a depthwise conv, SE and a 1x1 projection.

    The layers are assembled into ``self.block`` in this order:

    1. a 1x1 ``ConvBNActivation`` expanding to ``cnf.expanded_channels``
       (skipped when that equals ``cnf.input_channels``);
    2. a depthwise ``ConvBNActivation`` (``groups == cnf.expanded_channels``);
    3. the squeeze-and-excitation layer produced by ``se_layer``;
    4. a 1x1 ``ConvBNActivation`` projecting to ``cnf.out_channels`` with an
       ``nn.Identity`` activation.

    A residual connection is used only when ``cnf.stride == 1`` and the input
    and output channel counts match.

    Args:
        cnf: Block configuration.
        norm_layer: Factory for the normalization layer, forwarded to every
            ``ConvBNActivation``.
        se_layer: Factory for the squeeze-and-excitation layer. It is called as
            ``se_layer(channels, min_value=1, activation_fn=F.sigmoid)``, so a
            replacement must accept those keyword arguments.

    Raises:
        ValueError: If ``cnf.stride`` is not 1 or 2.
    """

    def __init__(self, cnf: MBConvConfig, norm_layer: Callable[..., nn.Module],
                 se_layer: Callable[..., nn.Module] = SqueezeExcitation) -> None:
        super().__init__()
        if not (1 <= cnf.stride <= 2):
            raise ValueError('illegal stride value')

        self.use_res_connect = cnf.stride == 1 and cnf.input_channels == cnf.out_channels

        layers: List[nn.Module] = []
        activation_layer = nn.SiLU

        # expand
        if cnf.expanded_channels != cnf.input_channels:
            layers.append(ConvBNActivation(cnf.input_channels, cnf.expanded_channels, kernel_size=1,
                                           norm_layer=norm_layer, activation_layer=activation_layer))

        # depthwise: when a dilation is requested the stride is forced to 1
        stride = 1 if cnf.dilation > 1 else cnf.stride
        layers.append(ConvBNActivation(cnf.expanded_channels, cnf.expanded_channels, kernel_size=cnf.kernel,
                                       stride=stride, dilation=cnf.dilation, groups=cnf.expanded_channels,
                                       norm_layer=norm_layer, activation_layer=activation_layer))

        # squeeze and excitation
        layers.append(se_layer(cnf.expanded_channels, min_value=1, activation_fn=F.sigmoid))

        # project (no non-linearity after the projection)
        layers.append(ConvBNActivation(cnf.expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer,
                                       activation_layer=nn.Identity))

        self.block = nn.Sequential(*layers)
        self.out_channels = cnf.out_channels

    def forward(self, input: Tensor, drop_connect_rate: Optional[float] = None) -> Tensor:
        """Run the block, adding the residual connection when it is enabled.

        Args:
            input: Tensor fed to ``self.block``.
            drop_connect_rate: Per-sample drop probability applied to the block
                output before the residual addition. Only used in training mode
                on blocks with a residual connection; ``None`` or ``0`` disables it.

        Returns:
            The block output, plus ``input`` when the residual connection is used.
        """
        result = self.block(input)
        if self.use_res_connect:
            if self.training and drop_connect_rate:
                # The previous code called ``drop_connect``, which is not defined
                # in this module; ``stochastic_depth`` is the module-level helper.
                result = stochastic_depth(result, drop_connect_rate)
            result += input
        return result
4186

4287

4388
class EfficientNet(nn.Module):

0 commit comments

Comments
 (0)