|
2 | 2 |
|
3 | 3 | from torch import nn, Tensor
|
4 | 4 | from torch.nn import functional as F
|
5 |
| -from typing import Any, Optional |
| 5 | +from typing import Any, Callable, List, Optional |
6 | 6 |
|
7 | 7 | from .._internally_replaced_utils import load_state_dict_from_url
|
8 | 8 |
|
|
11 | 11 | from torchvision.models.mobilenetv3 import SqueezeExcitation
|
12 | 12 |
|
13 | 13 |
|
14 |
| -__all__ = [] |
| 14 | +__all__ = ["EfficientNet"] |
15 | 15 |
|
16 | 16 |
|
17 |
| -model_urls = {} |
| 17 | +model_urls = { |
| 18 | + "efficientnet_b0": "", # TODO: Add weights |
| 19 | +} |
| 20 | + |
| 21 | + |
| 22 | +def stochastic_depth(x: Tensor, drop_rate: float) -> Tensor: |
| 23 | + survival_rate = 1.0 - drop_rate |
| 24 | + keep = torch.rand(size=(x.size(0), ), dtype=x.dtype, device=x.device) > drop_rate |
| 25 | + keep = keep[(None, ) * (x.ndim - 1)].T |
| 26 | + return x / survival_rate * keep |
18 | 27 |
|
19 | 28 |
|
20 | 29 | class MBConvConfig:
|
21 |
| - # TODO: Add dilation for supporting detection and segmentation pipelines |
22 | 30 | def __init__(self,
|
23 |
| - kernel: int, stride: int, |
24 |
| - input_channels: int, out_channels: int, expand_ratio: float, se_ratio: float, |
25 |
| - skip: bool, width_mult: float): |
| 31 | + kernel: int, stride: int, dilation: int, |
| 32 | + input_channels: int, out_channels: int, expand_ratio: float, |
| 33 | + width_mult: float) -> None: |
26 | 34 | self.kernel = kernel
|
27 | 35 | self.stride = stride
|
| 36 | + self.dilation = dilation |
28 | 37 | self.input_channels = self.adjust_channels(input_channels, width_mult)
|
29 | 38 | self.out_channels = self.adjust_channels(out_channels, width_mult)
|
30 | 39 | self.expanded_channels = self.adjust_channels(input_channels, expand_ratio * width_mult)
|
31 |
| - self.se_channels = self.adjust_channels(input_channels, se_ratio * width_mult, 1) |
32 |
| - self.skip = skip |
33 | 40 |
|
34 | 41 | @staticmethod
|
35 |
| - def adjust_channels(channels: int, width_mult: float, min_value: Optional[int] = None): |
| 42 | + def adjust_channels(channels: int, width_mult: float, min_value: Optional[int] = None) -> int: |
36 | 43 | return _make_divisible(channels * width_mult, 8, min_value)
|
37 | 44 |
|
38 | 45 |
|
39 | 46 | class MBConv(nn.Module):
|
40 |
| - pass |
| 47 | + def __init__(self, cnf: MBConvConfig, norm_layer: Callable[..., nn.Module], |
| 48 | + se_layer: Callable[..., nn.Module] = SqueezeExcitation) -> None: |
| 49 | + super().__init__() |
| 50 | + if not (1 <= cnf.stride <= 2): |
| 51 | + raise ValueError('illegal stride value') |
| 52 | + |
| 53 | + self.use_res_connect = cnf.stride == 1 and cnf.input_channels == cnf.out_channels |
| 54 | + |
| 55 | + layers: List[nn.Module] = [] |
| 56 | + activation_layer = nn.SiLU |
| 57 | + |
| 58 | + # expand |
| 59 | + if cnf.expanded_channels != cnf.input_channels: |
| 60 | + layers.append(ConvBNActivation(cnf.input_channels, cnf.expanded_channels, kernel_size=1, |
| 61 | + norm_layer=norm_layer, activation_layer=activation_layer)) |
| 62 | + |
| 63 | + # depthwise |
| 64 | + stride = 1 if cnf.dilation > 1 else cnf.stride |
| 65 | + layers.append(ConvBNActivation(cnf.expanded_channels, cnf.expanded_channels, kernel_size=cnf.kernel, |
| 66 | + stride=stride, dilation=cnf.dilation, groups=cnf.expanded_channels, |
| 67 | + norm_layer=norm_layer, activation_layer=activation_layer)) |
| 68 | + |
| 69 | + # squeeze and excitation |
| 70 | + layers.append(se_layer(cnf.expanded_channels, min_value=1, activation_fn=F.sigmoid)) |
| 71 | + |
| 72 | + # project |
| 73 | + layers.append(ConvBNActivation(cnf.expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer, |
| 74 | + activation_layer=nn.Identity)) |
| 75 | + |
| 76 | + self.block = nn.Sequential(*layers) |
| 77 | + self.out_channels = cnf.out_channels |
| 78 | + |
| 79 | + def forward(self, input: Tensor, drop_connect_rate: Optional[float] = None) -> Tensor: |
| 80 | + result = self.block(input) |
| 81 | + if self.use_res_connect: |
| 82 | + if self.training and drop_connect_rate: |
| 83 | + result = drop_connect(result, drop_connect_rate) |
| 84 | + result += input |
| 85 | + return result |
41 | 86 |
|
42 | 87 |
|
43 | 88 | class EfficientNet(nn.Module):
|
|
0 commit comments