import torch
import torch.nn as nn
from .utils import load_state_dict_from_url

__all__ = ['MNASNet', 'mnasnet0_5', 'mnasnet0_75', 'mnasnet1_0', 'mnasnet1_3']

_MODEL_URLS = {
    "mnasnet0_5":
    "https://github.com/1e100/mnasnet_trainer/releases/download/v0.1/mnasnet0.5_top1_67.592-7c6cb539b9.pth",
    "mnasnet0_75": None,
    "mnasnet1_0":
    "https://github.com/1e100/mnasnet_trainer/releases/download/v0.1/mnasnet1.0_top1_73.512-f206786ef8.pth",
    "mnasnet1_3": None
}

# The paper suggests a BatchNorm momentum of 0.9997, stated in the TensorFlow
# convention; the equivalent PyTorch momentum is 1 - 0.9997.
_BN_MOMENTUM = 1 - 0.9997


class _InvertedResidual(nn.Module):

    def __init__(self, in_ch, out_ch, kernel_size, stride, expansion_factor,
                 bn_momentum=0.1):
        super(_InvertedResidual, self).__init__()
        assert stride in [1, 2]
        assert kernel_size in [3, 5]
        mid_ch = in_ch * expansion_factor
        self.apply_residual = (in_ch == out_ch and stride == 1)
        self.layers = nn.Sequential(
            # Pointwise
            nn.Conv2d(in_ch, mid_ch, 1, bias=False),
            nn.BatchNorm2d(mid_ch, momentum=bn_momentum),
            nn.ReLU(inplace=True),
            # Depthwise
            nn.Conv2d(mid_ch, mid_ch, kernel_size, padding=kernel_size // 2,
                      stride=stride, groups=mid_ch, bias=False),
            nn.BatchNorm2d(mid_ch, momentum=bn_momentum),
            nn.ReLU(inplace=True),
            # Linear pointwise. Note that there's no activation.
            nn.Conv2d(mid_ch, out_ch, 1, bias=False),
            nn.BatchNorm2d(out_ch, momentum=bn_momentum))

    def forward(self, x):
        if self.apply_residual:
            return self.layers(x) + x
        else:
            return self.layers(x)
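
# Note: with in_ch == out_ch and stride == 1 the block above is
# shape-preserving, e.g. _InvertedResidual(24, 24, 3, 1, 3) maps
# (N, 24, H, W) to (N, 24, H, W), which is what makes the residual
# sum in forward() well-defined.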


def _stack(in_ch, out_ch, kernel_size, stride, exp_factor, repeats,
           bn_momentum):
    """ Creates a stack of inverted residuals. """
    assert repeats >= 1
    # First one has no skip, because feature map size changes.
    first = _InvertedResidual(in_ch, out_ch, kernel_size, stride, exp_factor,
                              bn_momentum=bn_momentum)
    remaining = []
    for _ in range(1, repeats):
        remaining.append(
            _InvertedResidual(out_ch, out_ch, kernel_size, 1, exp_factor,
                              bn_momentum=bn_momentum))
    return nn.Sequential(first, *remaining)
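
# For example, _stack(16, 24, 3, 2, 3, 3, _BN_MOMENTUM) (the first MNASNet
# block below, at alpha=1.0) yields three inverted residuals: the first
# widens 16 -> 24 channels at stride 2, and the two that follow are
# shape-preserving, so only they apply the residual skip.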


def _round_to_multiple_of(val, divisor, round_up_bias=0.9):
    """ Asymmetric rounding to make `val` divisible by `divisor`. With default
    bias, will round up, unless the number is no more than 10% greater than the
    smaller divisible value, i.e. (83, 8) -> 80, but (84, 8) -> 88. """
    assert 0.0 < round_up_bias < 1.0
    new_val = max(divisor, int(val + divisor / 2) // divisor * divisor)
    return new_val if new_val >= round_up_bias * val else new_val + divisor


def _scale_depths(depths, alpha):
    """ Scales tensor depths as in reference MobileNet code, preferring to
    round up rather than down. """
    return [_round_to_multiple_of(depth * alpha, 8) for depth in depths]
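
# For example, with alpha=0.5 the base depths scale as
# _scale_depths([24, 40, 80, 96, 192, 320], 0.5) -> [16, 24, 40, 48, 96, 160]
# (each product rounded to a multiple of 8 with the upward bias above).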


class MNASNet(torch.nn.Module):
    """ MNASNet, as described in https://arxiv.org/pdf/1807.11626.pdf.
    >>> model = MNASNet(1.0, num_classes=1000)
    >>> x = torch.rand(1, 3, 224, 224)
    >>> y = model(x)
    >>> y.dim()
    2
    >>> y.nelement()
    1000
    """

    def __init__(self, alpha, num_classes=1000, dropout=0.2):
        super(MNASNet, self).__init__()
        depths = _scale_depths([24, 40, 80, 96, 192, 320], alpha)
        layers = [
            # First layer: regular conv.
            nn.Conv2d(3, 32, 3, padding=1, stride=2, bias=False),
            nn.BatchNorm2d(32, momentum=_BN_MOMENTUM),
            nn.ReLU(inplace=True),
            # Depthwise separable, no skip.
            nn.Conv2d(32, 32, 3, padding=1, stride=1, groups=32, bias=False),
            nn.BatchNorm2d(32, momentum=_BN_MOMENTUM),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 16, 1, padding=0, stride=1, bias=False),
            nn.BatchNorm2d(16, momentum=_BN_MOMENTUM),
            # MNASNet blocks: stacks of inverted residuals.
            _stack(16, depths[0], 3, 2, 3, 3, _BN_MOMENTUM),
            _stack(depths[0], depths[1], 5, 2, 3, 3, _BN_MOMENTUM),
            _stack(depths[1], depths[2], 5, 2, 6, 3, _BN_MOMENTUM),
            _stack(depths[2], depths[3], 3, 1, 6, 2, _BN_MOMENTUM),
            _stack(depths[3], depths[4], 5, 2, 6, 4, _BN_MOMENTUM),
            _stack(depths[4], depths[5], 3, 1, 6, 1, _BN_MOMENTUM),
            # Final mapping to classifier input.
            nn.Conv2d(depths[5], 1280, 1, padding=0, stride=1, bias=False),
            nn.BatchNorm2d(1280, momentum=_BN_MOMENTUM),
            nn.ReLU(inplace=True),
        ]
        self.layers = nn.Sequential(*layers)
        self.classifier = nn.Sequential(nn.Dropout(p=dropout, inplace=True),
                                        nn.Linear(1280, num_classes))
        self._initialize_weights()

    def forward(self, x):
        x = self.layers(x)
        # Equivalent to global avgpool and removing H and W dimensions.
        x = x.mean([2, 3])
        return self.classifier(x)

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out",
                                        nonlinearity="relu")
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                # Zero-mean init with std 0.01. Note the original positional
                # call normal_(m.weight, 0.01) passed 0.01 as the mean and
                # left std at its default of 1.
                nn.init.normal_(m.weight, 0.0, 0.01)
                nn.init.zeros_(m.bias)


def _load_pretrained(model_name, model):
    if model_name not in _MODEL_URLS or _MODEL_URLS[model_name] is None:
        raise ValueError(
            "No checkpoint is available for model type {}".format(model_name))
    checkpoint_url = _MODEL_URLS[model_name]
    model.load_state_dict(load_state_dict_from_url(checkpoint_url))


def mnasnet0_5(pretrained=False, **kwargs):
    """ MNASNet with depth multiplier of 0.5. """
    model = MNASNet(0.5, **kwargs)
    if pretrained:
        _load_pretrained("mnasnet0_5", model)
    return model


def mnasnet0_75(pretrained=False, **kwargs):
    """ MNASNet with depth multiplier of 0.75. """
    model = MNASNet(0.75, **kwargs)
    if pretrained:
        _load_pretrained("mnasnet0_75", model)
    return model


def mnasnet1_0(pretrained=False, **kwargs):
    """ MNASNet with depth multiplier of 1.0. """
    model = MNASNet(1.0, **kwargs)
    if pretrained:
        _load_pretrained("mnasnet1_0", model)
    return model


def mnasnet1_3(pretrained=False, **kwargs):
    """ MNASNet with depth multiplier of 1.3. """
    model = MNASNet(1.3, **kwargs)
    if pretrained:
        _load_pretrained("mnasnet1_3", model)
    return model
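

if __name__ == "__main__":
    # Illustrative smoke test, not part of the torchvision API. Because of the
    # relative import at the top of this file, run it as a module, e.g.
    #   python -m torchvision.models.mnasnet
    net = mnasnet1_0(pretrained=False)
    net.eval()
    with torch.no_grad():
        out = net(torch.rand(2, 3, 224, 224))
    print(out.shape)  # expected: torch.Size([2, 1000])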