Implementation of the MNASNet family of models #829
`torchvision/models/__init__.py`:

```diff
@@ -6,3 +6,4 @@
 from .densenet import *
 from .googlenet import *
 from .mobilenet import *
+from .mnasnet import *
```
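With this export in place, the new models become reachable through the usual `torchvision.models` namespace. A minimal usage sketch (illustrative, not part of the diff; assumes the PR is installed and `MNASNet1_0` is defined in the new module below):

```python
import torch
from torchvision.models.mnasnet import MNASNet1_0

model = MNASNet1_0(num_classes=1000)
logits = model(torch.rand(2, 3, 224, 224))  # -> shape (2, 1000)
```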
`torchvision/models/mnasnet.py` (new file, `@@ -0,0 +1,176 @@`):

```python
import math

import torch
import torch.nn as nn


__all__ = ['MNASNet', 'MNASNet0_5', 'MNASNet0_75', 'MNASNet1_0', 'MNASNet1_3']

# Paper suggests 0.9997 momentum, for TensorFlow. The equivalent PyTorch
# momentum is 1 - tensorflow_momentum.
_BN_MOMENTUM = 1 - 0.9997
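# Illustrative note (added for clarity, not in the PR): PyTorch BatchNorm
# updates running statistics as
#     running = (1 - momentum) * running + momentum * batch_stat
# while the TensorFlow convention referenced in the paper is
#     running = decay * running + (1 - decay) * batch_stat
# so a TF decay of 0.9997 corresponds to a PyTorch momentum of 1 - 0.9997.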


class _InvertedResidual(nn.Module):
    """ Inverted residual block from MobileNetV2 and MNASNet papers. This can
    be used to implement MobileNet V2, if ReLU is replaced with ReLU6. """

    def __init__(self, in_ch, out_ch, kernel_size, stride, expansion_factor,
                 bn_momentum=0.1):
        super().__init__()
        assert stride in [1, 2]
        assert kernel_size in [3, 5]
        mid_ch = in_ch * expansion_factor
        self.apply_residual = (in_ch == out_ch and stride == 1)
        self.layers = nn.Sequential(
            # Pointwise
            nn.Conv2d(in_ch, mid_ch, 1, bias=False),
            nn.BatchNorm2d(mid_ch, momentum=bn_momentum),
            nn.ReLU(inplace=True),
            # Depthwise
            nn.Conv2d(mid_ch, mid_ch, kernel_size, padding=kernel_size // 2,
                      stride=stride, groups=mid_ch, bias=False),
            nn.BatchNorm2d(mid_ch, momentum=bn_momentum),
            nn.ReLU(inplace=True),
            # Linear pointwise. Note that there's no activation.
            nn.Conv2d(mid_ch, out_ch, 1, bias=False),
            nn.BatchNorm2d(out_ch, momentum=bn_momentum))

    def forward(self, input):
        if self.apply_residual:
            return self.layers(input) + input
        else:
            return self.layers(input)


def _stack(in_ch, out_ch, kernel_size, stride, exp_factor, repeats,
           bn_momentum):
    """ Creates a stack of inverted residuals as seen in e.g. MobileNetV2 or
    MNASNet. """
    assert repeats >= 1
    # First one has no skip, because feature map size changes.
    first = _InvertedResidual(in_ch, out_ch, kernel_size, stride, exp_factor,
                              bn_momentum=bn_momentum)
    remaining = []
    for _ in range(1, repeats):
        remaining.append(
            _InvertedResidual(out_ch, out_ch, kernel_size, 1, exp_factor,
                              bn_momentum=bn_momentum))
    return nn.Sequential(first, *remaining)
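# Illustrative example (added, not in the PR): _stack(16, 24, 3, 2, 3, 3, m)
# builds one stride-2 block mapping 16 -> 24 channels, followed by two
# stride-1 residual blocks mapping 24 -> 24 channels.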


def _round_to_multiple_of(val, divisor, round_up_bias=0.9):
    """ Asymmetric rounding to make `val` divisible by `divisor`. With default
    bias, will round up, unless the number is no more than 10% greater than the
    smaller divisible value, i.e. (83, 8) -> 80, but (84, 8) -> 88. """
    assert 0.0 < round_up_bias < 1.0
    new_val = max(divisor, int(val + divisor / 2) // divisor * divisor)
    return new_val if new_val >= round_up_bias * val else new_val + divisor


def _scale_depths(depths, alpha):
    """ Scales tensor depths as in reference MobileNet code, prefers rounding
    up rather than down. """
    return [_round_to_multiple_of(depth * alpha, 8) for depth in depths]
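# Worked example (added, not in the PR): with alpha = 0.5, the backbone
# depths [24, 40, 80, 96, 192, 320] scale to [16, 24, 40, 48, 96, 160].
# Each value is halved, then rounded to a multiple of 8 under the round-up
# bias above, e.g. 24 * 0.5 = 12 rounds up to 16.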


class MNASNet(torch.nn.Module):
    """ MNASNet, as described in https://arxiv.org/pdf/1807.11626.pdf.
    >>> model = MNASNet(1000, 1.0)
    >>> x = torch.rand(1, 3, 224, 224)
    >>> y = model(x)
    >>> y.dim()
    1
    >>> y.nelement()
    1000
    """

    def __init__(self, num_classes, alpha, dropout=0.2):
        super().__init__()
        self.alpha = alpha
        self.num_classes = num_classes
        depths = _scale_depths([24, 40, 80, 96, 192, 320], alpha)
        layers = [
            # First layer: regular conv.
            nn.Conv2d(3, 32, 3, padding=1, stride=2, bias=False),
            nn.BatchNorm2d(32, momentum=_BN_MOMENTUM),
            nn.ReLU(inplace=True),
            # Depthwise separable, no skip.
            nn.Conv2d(32, 32, 3, padding=1, stride=1, groups=32, bias=False),
            nn.BatchNorm2d(32, momentum=_BN_MOMENTUM),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 16, 1, padding=0, stride=1, bias=False),
            nn.BatchNorm2d(16, momentum=_BN_MOMENTUM),
            # MNASNet blocks: stacks of inverted residuals.
            _stack(16, depths[0], 3, 2, 3, 3, _BN_MOMENTUM),
            _stack(depths[0], depths[1], 5, 2, 3, 3, _BN_MOMENTUM),
            _stack(depths[1], depths[2], 5, 2, 6, 3, _BN_MOMENTUM),
            _stack(depths[2], depths[3], 3, 1, 6, 2, _BN_MOMENTUM),
            _stack(depths[3], depths[4], 5, 2, 6, 4, _BN_MOMENTUM),
            _stack(depths[4], depths[5], 3, 1, 6, 1, _BN_MOMENTUM),
            # Final mapping to classifier input.
            nn.Conv2d(depths[5], 1280, 1, padding=0, stride=1, bias=False),
            nn.BatchNorm2d(1280, momentum=_BN_MOMENTUM),
            nn.ReLU(inplace=True),
        ]
        self.layers = nn.Sequential(*layers)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        if dropout > 0.0:
            self.classifier = nn.Sequential(
                nn.Dropout(inplace=True, p=dropout),
                nn.Linear(1280, self.num_classes))
        else:
            self.classifier = nn.Linear(1280, self.num_classes)

        self._initialize_weights()

    def features(self, x):
        return self.layers(x)

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x).squeeze()
        return self.classifier(x)
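    # Shape walkthrough for a 224x224 input (added for clarity, not in the
    # PR): (N, 3, 224, 224) -> features -> (N, 1280, 7, 7) -> avgpool ->
    # (N, 1280, 1, 1) -> squeeze -> (N, 1280) -> classifier ->
    # (N, num_classes). For N == 1, squeeze() also drops the batch dim,
    # which is why the docstring example yields a 1-D tensor of 1000 elements.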
    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2.0 / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1.0)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()


class MNASNet0_5(MNASNet):
    """ MNASNet with depth multiplier of 0.5. """

    def __init__(self, num_classes):
        super().__init__(num_classes, 0.5)


class MNASNet0_75(MNASNet):
    """ MNASNet with depth multiplier of 0.75. """

    def __init__(self, num_classes):
        super().__init__(num_classes, 0.75)


class MNASNet1_0(MNASNet):
    """ MNASNet with depth multiplier of 1.0. """

    def __init__(self, num_classes):
        super().__init__(num_classes, 1.0)


class MNASNet1_3(MNASNet):
    """ MNASNet with depth multiplier of 1.3. """

    def __init__(self, num_classes):
        super().__init__(num_classes, 1.3)
```
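A quick smoke test of the four depth-multiplier variants, as a hedged sketch (not part of the diff; assumes the classes above are in scope):

```python
import torch

# Build each variant and verify the classifier output size.
for cls in (MNASNet0_5, MNASNet0_75, MNASNet1_0, MNASNet1_3):
    model = cls(num_classes=10).eval()
    with torch.no_grad():
        out = model(torch.rand(2, 3, 224, 224))
    assert out.shape == (2, 10)  # (batch, num_classes)
```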