
Implementation of the MNASNet family of models #829

Merged · 23 commits · Jun 24, 2019
Changes from 5 commits
1 change: 1 addition & 0 deletions torchvision/models/__init__.py
@@ -6,3 +6,4 @@
from .densenet import *
from .googlenet import *
from .mobilenet import *
from .mnasnet import *
176 changes: 176 additions & 0 deletions torchvision/models/mnasnet.py
@@ -0,0 +1,176 @@
import math

import torch
import torch.nn as nn


__all__ = ['MNASNet', 'MNASNet0_5', 'MNASNet0_75', 'MNASNet1_0', 'MNASNet1_3']

# Paper suggests 0.9997 momentum, for TensorFlow. The equivalent PyTorch
# momentum is 1 - tensorflow_momentum, i.e. 1 - 0.9997 = 0.0003 here.
_BN_MOMENTUM = 1 - 0.9997


class _InvertedResidual(nn.Module):
""" Inverted residual block from MobileNetV2 and MNASNet papers. This can
be used to implement MobileNet V2, if ReLU is replaced with ReLU6. """

def __init__(self, in_ch, out_ch, kernel_size, stride, expansion_factor,
bn_momentum=0.1):
super().__init__()
assert stride in [1, 2]
assert kernel_size in [3, 5]
mid_ch = in_ch * expansion_factor
self.apply_residual = (in_ch == out_ch and stride == 1)
self.layers = nn.Sequential(
# Pointwise
nn.Conv2d(in_ch, mid_ch, 1, bias=False),
nn.BatchNorm2d(mid_ch, momentum=bn_momentum),
nn.ReLU(inplace=True),
# Depthwise
nn.Conv2d(mid_ch, mid_ch, kernel_size, padding=kernel_size // 2,
stride=stride, groups=mid_ch, bias=False),
nn.BatchNorm2d(mid_ch, momentum=bn_momentum),
nn.ReLU(inplace=True),
# Linear pointwise. Note that there's no activation.
nn.Conv2d(mid_ch, out_ch, 1, bias=False),
nn.BatchNorm2d(out_ch, momentum=bn_momentum))

    def forward(self, input):
        if self.apply_residual:
            return self.layers(input) + input
        else:
            return self.layers(input)
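
# Illustrative usage (not part of the PR diff): with stride 1 and
# in_ch == out_ch the block adds a residual connection and preserves the
# input shape, e.g.
#
#   block = _InvertedResidual(16, 16, kernel_size=3, stride=1,
#                             expansion_factor=3)
#   block(torch.rand(1, 16, 32, 32)).shape  # torch.Size([1, 16, 32, 32])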


def _stack(in_ch, out_ch, kernel_size, stride, exp_factor, repeats,
bn_momentum):
""" Creates a stack of inverted residuals as seen in e.g. MobileNetV2 or
MNASNet. """
assert repeats >= 1
# First one has no skip, because feature map size changes.
first = _InvertedResidual(in_ch, out_ch, kernel_size, stride, exp_factor,
bn_momentum=bn_momentum)
remaining = []
for _ in range(1, repeats):
remaining.append(
_InvertedResidual(out_ch, out_ch, kernel_size, 1, exp_factor,
bn_momentum=bn_momentum))
return nn.Sequential(first, *remaining)
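
# Illustrative (not part of the PR diff): _stack(16, 24, 3, 2, 3, 3,
# _BN_MOMENTUM) builds an nn.Sequential of three blocks; only the first runs
# at stride 2 (16 -> 24 channels, hence no residual), the remaining two run
# at stride 1 with 24 -> 24 channels and residual connections.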


def _round_to_multiple_of(val, divisor, round_up_bias=0.9):
""" Asymmetric rounding to make `val` divisible by `divisor`. With default
bias, will round up, unless the number is no more than 10% greater than the
smaller divisible value, i.e. (83, 8) -> 80, but (84, 8) -> 88. """
assert 0.0 < round_up_bias < 1.0
new_val = max(divisor, int(val + divisor / 2) // divisor * divisor)
return new_val if new_val >= round_up_bias * val else new_val + divisor
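
# Worked examples (illustrative, not part of the PR diff):
# (83, 8) -> 80, since 80 >= 0.9 * 83; (84, 8) -> 88, since 84 rounds
# half-up to 88; (11, 8) -> 16, since rounding down to 8 would lose more
# than 10% of the value.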


def _scale_depths(depths, alpha):
""" Scales tensor depths as in reference MobileNet code, prefers rouding up
rather than down. """
return [_round_to_multiple_of(depth * alpha, 8) for depth in depths]
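
# Illustrative (not part of the PR diff):
# _scale_depths([24, 40, 80, 96, 192, 320], 0.5) returns
# [16, 24, 40, 48, 96, 160] -- every entry a multiple of 8.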


class MNASNet(torch.nn.Module):
""" MNASNet, as described in https://arxiv.org/pdf/1807.11626.pdf.
>>> model = MNASNet(1000, 1.0)
>>> x = torch.rand(1, 3, 224, 224)
    >>> y = model(x)
>>> y.dim()
1
>>> y.nelement()
1000
"""

def __init__(self, num_classes, alpha, dropout=0.2):
super().__init__()
Member: this is Python3-only. Can you replace instead by super(MNASNet, self).__init__()?

Contributor Author: Done

self.alpha = alpha
self.num_classes = num_classes
depths = _scale_depths([24, 40, 80, 96, 192, 320], alpha)
layers = [
# First layer: regular conv.
nn.Conv2d(3, 32, 3, padding=1, stride=2, bias=False),
nn.BatchNorm2d(32, momentum=_BN_MOMENTUM),
nn.ReLU(inplace=True),
# Depthwise separable, no skip.
nn.Conv2d(32, 32, 3, padding=1, stride=1, groups=32, bias=False),
nn.BatchNorm2d(32, momentum=_BN_MOMENTUM),
nn.ReLU(inplace=True),
nn.Conv2d(32, 16, 1, padding=0, stride=1, bias=False),
nn.BatchNorm2d(16, momentum=_BN_MOMENTUM),
# MNASNet blocks: stacks of inverted residuals.
_stack(16, depths[0], 3, 2, 3, 3, _BN_MOMENTUM),
_stack(depths[0], depths[1], 5, 2, 3, 3, _BN_MOMENTUM),
_stack(depths[1], depths[2], 5, 2, 6, 3, _BN_MOMENTUM),
_stack(depths[2], depths[3], 3, 1, 6, 2, _BN_MOMENTUM),
_stack(depths[3], depths[4], 5, 2, 6, 4, _BN_MOMENTUM),
_stack(depths[4], depths[5], 3, 1, 6, 1, _BN_MOMENTUM),
# Final mapping to classifier input.
nn.Conv2d(depths[5], 1280, 1, padding=0, stride=1, bias=False),
nn.BatchNorm2d(1280, momentum=_BN_MOMENTUM),
nn.ReLU(inplace=True),
]
self.layers = nn.Sequential(*layers)
self.avgpool = nn.AdaptiveAvgPool2d(1)
if dropout > 0.0:
self.classifier = nn.Sequential(
                nn.Dropout(p=dropout, inplace=True),
nn.Linear(1280, self.num_classes))
else:
self.classifier = nn.Linear(1280, self.num_classes)

self._initialize_weights()

    def features(self, x):
        return self.layers(x)

def forward(self, x):
x = self.features(x)
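        # Note: squeeze() drops all singleton dims, so a batch of one also
        # loses its batch dimension (hence y.dim() == 1 in the docstring).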
x = self.avgpool(x).squeeze()
return self.classifier(x)

def _initialize_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
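                # Kaiming (He) normal initialization with fan-out:
                # std = sqrt(2 / (k_h * k_w * out_channels)).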
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2.0 / n))
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1.0)
m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()


class MNASNet0_5(MNASNet):
Member: What about having mnasnet0_5 etc as functions, instead of classes, in the same way as what we have for the other models?

Contributor Author: Done

""" MNASNet with depth multiplier of 0.5. """

def __init__(self, num_classes):
super().__init__(num_classes, 0.5)


class MNASNet0_75(MNASNet):
""" MNASNet with depth multiplier of 0.75. """

def __init__(self, num_classes):
super().__init__(num_classes, 0.75)


class MNASNet1_0(MNASNet):
""" MNASNet with depth multiplier of 1.0. """

def __init__(self, num_classes):
super().__init__(num_classes, 1.0)


class MNASNet1_3(MNASNet):
""" MNASNet with depth multiplier of 1.3. """

def __init__(self, num_classes):
super().__init__(num_classes, 1.3)
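
A quick usage sketch (illustrative, not part of the PR diff; it assumes the class-based API of this revision, before the reviewer-requested switch to factory functions):

    import torch
    from torchvision.models import MNASNet1_0  # exported via the __init__.py change above

    model = MNASNet1_0(num_classes=1000)
    model.eval()
    x = torch.rand(8, 3, 224, 224)
    with torch.no_grad():
        y = model(x)
    print(y.shape)  # torch.Size([8, 1000]); for a batch of one, the squeeze()
                    # in forward() also drops the batch dim -> torch.Size([1000])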