
Commit 69b2857

Browse files
1e100 authored and fmassa committed
Implementation of the MNASNet family of models (#829)
* Add initial mnasnet impl
* Remove all type hints, comply with PyTorch overall style
* Expose models
* Remove avgpool from features() and add separately
* Fix python3-only stuff, replace subclasses with functions
* fix __all__
* Fix typo
* Remove conditional dropout
* Make dropout functional
* Addressing @fmassa's feedback, round 1
* Replaced adaptive avgpool with mean on H and W to prevent collapsing the batch dimension
* Partially address feedback
* YAPF
* Removed redundant class vars
* Update urls to releases
* Add information to models.rst
* Replace init with kaiming_normal_ in fan-out mode
* Use load_state_dict_from_url
1 parent 12fab3a commit 69b2857

File tree

3 files changed: +197 -0 lines changed


docs/source/models.rst

Lines changed: 13 additions & 0 deletions
@@ -24,6 +24,7 @@ architectures for image classification:
 - `ShuffleNet`_ v2
 - `MobileNet`_ v2
 - `ResNeXt`_
+- `MNASNet`_
 
 You can construct a model with random weights by calling its constructor:
 
@@ -40,6 +41,7 @@ You can construct a model with random weights by calling its constructor:
     shufflenet = models.shufflenet_v2_x1_0()
     mobilenet = models.mobilenet_v2()
     resnext50_32x4d = models.resnext50_32x4d()
+    mnasnet = models.mnasnet1_0()
 
 We provide pre-trained models, using the PyTorch :mod:`torch.utils.model_zoo`.
 These can be constructed by passing ``pretrained=True``:
@@ -57,6 +59,7 @@ These can be constructed by passing ``pretrained=True``:
     shufflenet = models.shufflenet_v2_x1_0(pretrained=True)
     mobilenet = models.mobilenet_v2(pretrained=True)
     resnext50_32x4d = models.resnext50_32x4d(pretrained=True)
+    mnasnet = models.mnasnet1_0(pretrained=True)
 
 Instancing a pre-trained model will download its weights to a cache directory.
 This directory can be set using the `TORCH_MODEL_ZOO` environment variable. See
@@ -111,6 +114,7 @@ ShuffleNet V2 30.64 11.68
 MobileNet V2                     28.12         9.71
 ResNeXt-50-32x4d                 22.38         6.30
 ResNeXt-101-32x8d                20.69         5.47
+MNASNet 1.0                      26.49         8.456
 ================================ ============= =============
 
 
@@ -124,6 +128,7 @@ ResNeXt-101-32x8d 20.69 5.47
 .. _ShuffleNet: https://arxiv.org/abs/1807.11164
 .. _MobileNet: https://arxiv.org/abs/1801.04381
 .. _ResNeXt: https://arxiv.org/abs/1611.05431
+.. _MNASNet: https://arxiv.org/abs/1807.11626
 
 .. currentmodule:: torchvision.models
 
@@ -197,6 +202,14 @@ ResNext
 .. autofunction:: resnext50_32x4d
 .. autofunction:: resnext101_32x8d
 
+MNASNet
+--------
+
+.. autofunction:: mnasnet0_5
+.. autofunction:: mnasnet0_75
+.. autofunction:: mnasnet1_0
+.. autofunction:: mnasnet1_3
+
 
 Semantic Segmentation
 =====================
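The documented usage pattern for the new model, end to end, looks like this (a minimal sketch following the snippets above; the dummy input, eval-mode call, and no_grad context are illustrative additions, not part of the docs):

import torch
from torchvision import models

# Random weights, as in the constructor example above.
mnasnet = models.mnasnet1_0()

# Pre-trained weights; the checkpoint is downloaded to the model zoo cache.
mnasnet_pretrained = models.mnasnet1_0(pretrained=True)
mnasnet_pretrained.eval()

# Forward pass on a dummy ImageNet-sized batch.
x = torch.rand(1, 3, 224, 224)
with torch.no_grad():
    logits = mnasnet_pretrained(x)
print(logits.shape)  # torch.Size([1, 1000])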

torchvision/models/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -6,6 +6,7 @@
 from .densenet import *
 from .googlenet import *
 from .mobilenet import *
+from .mnasnet import *
 from .shufflenetv2 import *
 from . import segmentation
 from . import detection

torchvision/models/mnasnet.py

Lines changed: 183 additions & 0 deletions
@@ -0,0 +1,183 @@
import math

import torch
import torch.nn as nn
from .utils import load_state_dict_from_url

__all__ = ['MNASNet', 'mnasnet0_5', 'mnasnet0_75', 'mnasnet1_0', 'mnasnet1_3']

_MODEL_URLS = {
    "mnasnet0_5":
    "https://github.com/1e100/mnasnet_trainer/releases/download/v0.1/mnasnet0.5_top1_67.592-7c6cb539b9.pth",
    "mnasnet0_75": None,
    "mnasnet1_0":
    "https://github.com/1e100/mnasnet_trainer/releases/download/v0.1/mnasnet1.0_top1_73.512-f206786ef8.pth",
    "mnasnet1_3": None
}

# Paper suggests 0.9997 momentum, for TensorFlow. Equivalent PyTorch momentum is
# 1.0 - tensorflow momentum.
_BN_MOMENTUM = 1 - 0.9997


class _InvertedResidual(nn.Module):

    def __init__(self, in_ch, out_ch, kernel_size, stride, expansion_factor,
                 bn_momentum=0.1):
        super(_InvertedResidual, self).__init__()
        assert stride in [1, 2]
        assert kernel_size in [3, 5]
        mid_ch = in_ch * expansion_factor
        self.apply_residual = (in_ch == out_ch and stride == 1)
        self.layers = nn.Sequential(
            # Pointwise
            nn.Conv2d(in_ch, mid_ch, 1, bias=False),
            nn.BatchNorm2d(mid_ch, momentum=bn_momentum),
            nn.ReLU(inplace=True),
            # Depthwise
            nn.Conv2d(mid_ch, mid_ch, kernel_size, padding=kernel_size // 2,
                      stride=stride, groups=mid_ch, bias=False),
            nn.BatchNorm2d(mid_ch, momentum=bn_momentum),
            nn.ReLU(inplace=True),
            # Linear pointwise. Note that there's no activation.
            nn.Conv2d(mid_ch, out_ch, 1, bias=False),
            nn.BatchNorm2d(out_ch, momentum=bn_momentum))

    def forward(self, input):
        if self.apply_residual:
            return self.layers(input) + input
        else:
            return self.layers(input)


def _stack(in_ch, out_ch, kernel_size, stride, exp_factor, repeats,
           bn_momentum):
    """ Creates a stack of inverted residuals. """
    assert repeats >= 1
    # First one has no skip, because feature map size changes.
    first = _InvertedResidual(in_ch, out_ch, kernel_size, stride, exp_factor,
                              bn_momentum=bn_momentum)
    remaining = []
    for _ in range(1, repeats):
        remaining.append(
            _InvertedResidual(out_ch, out_ch, kernel_size, 1, exp_factor,
                              bn_momentum=bn_momentum))
    return nn.Sequential(first, *remaining)


def _round_to_multiple_of(val, divisor, round_up_bias=0.9):
    """ Asymmetric rounding to make `val` divisible by `divisor`. With default
    bias, will round up, unless the number is no more than 10% greater than the
    smaller divisible value, i.e. (83, 8) -> 80, but (84, 8) -> 88. """
    assert 0.0 < round_up_bias < 1.0
    new_val = max(divisor, int(val + divisor / 2) // divisor * divisor)
    return new_val if new_val >= round_up_bias * val else new_val + divisor


def _scale_depths(depths, alpha):
    """ Scales tensor depths as in reference MobileNet code, prefers rounding up
    rather than down. """
    return [_round_to_multiple_of(depth * alpha, 8) for depth in depths]


class MNASNet(torch.nn.Module):
""" MNASNet, as described in https://arxiv.org/pdf/1807.11626.pdf.
85+
>>> model = MNASNet(1000, 1.0)
86+
>>> x = torch.rand(1, 3, 224, 224)
87+
>>> y = model(x)
88+
>>> y.dim()
89+
1
90+
>>> y.nelement()
91+
1000
92+
"""

    def __init__(self, alpha, num_classes=1000, dropout=0.2):
        super(MNASNet, self).__init__()
        depths = _scale_depths([24, 40, 80, 96, 192, 320], alpha)
        layers = [
            # First layer: regular conv.
            nn.Conv2d(3, 32, 3, padding=1, stride=2, bias=False),
            nn.BatchNorm2d(32, momentum=_BN_MOMENTUM),
            nn.ReLU(inplace=True),
            # Depthwise separable, no skip.
            nn.Conv2d(32, 32, 3, padding=1, stride=1, groups=32, bias=False),
            nn.BatchNorm2d(32, momentum=_BN_MOMENTUM),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 16, 1, padding=0, stride=1, bias=False),
            nn.BatchNorm2d(16, momentum=_BN_MOMENTUM),
            # MNASNet blocks: stacks of inverted residuals.
            _stack(16, depths[0], 3, 2, 3, 3, _BN_MOMENTUM),
            _stack(depths[0], depths[1], 5, 2, 3, 3, _BN_MOMENTUM),
            _stack(depths[1], depths[2], 5, 2, 6, 3, _BN_MOMENTUM),
            _stack(depths[2], depths[3], 3, 1, 6, 2, _BN_MOMENTUM),
            _stack(depths[3], depths[4], 5, 2, 6, 4, _BN_MOMENTUM),
            _stack(depths[4], depths[5], 3, 1, 6, 1, _BN_MOMENTUM),
            # Final mapping to classifier input.
            nn.Conv2d(depths[5], 1280, 1, padding=0, stride=1, bias=False),
            nn.BatchNorm2d(1280, momentum=_BN_MOMENTUM),
            nn.ReLU(inplace=True),
        ]
        self.layers = nn.Sequential(*layers)
        self.classifier = nn.Sequential(nn.Dropout(p=dropout, inplace=True),
                                        nn.Linear(1280, num_classes))
        self._initialize_weights()

    def forward(self, x):
        x = self.layers(x)
        # Equivalent to global avgpool and removing H and W dimensions.
        x = x.mean([2, 3])
        return self.classifier(x)

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out",
                                        nonlinearity="relu")
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
142+
nn.init.normal_(m.weight, 0.01)
143+
nn.init.zeros_(m.bias)
144+
145+
146+
def _load_pretrained(model_name, model):
147+
if model_name not in _MODEL_URLS or _MODEL_URLS[model_name] is None:
148+
raise ValueError(
149+
"No checkpoint is available for model type {}".format(model_name))
150+
checkpoint_url = _MODEL_URLS[model_name]
151+
model.load_state_dict(load_state_dict_from_url(checkpoint_url))
152+
153+
154+
def mnasnet0_5(pretrained=False, **kwargs):
155+
""" MNASNet with depth multiplier of 0.5. """
156+
model = MNASNet(0.5, **kwargs)
157+
if pretrained:
158+
_load_pretrained("mnasnet0_5", model)
159+
return model
160+
161+
162+
def mnasnet0_75(pretrained=False, **kwargs):
163+
""" MNASNet with depth multiplier of 0.75. """
164+
model = MNASNet(0.75, **kwargs)
165+
if pretrained:
166+
_load_pretrained("mnasnet0_75", model)
167+
return model
168+
169+
170+
def mnasnet1_0(pretrained=False, **kwargs):
171+
""" MNASNet with depth multiplier of 1.0. """
172+
model = MNASNet(1.0, **kwargs)
173+
if pretrained:
174+
_load_pretrained("mnasnet1_0", model)
175+
return model
176+
177+
178+
def mnasnet1_3(pretrained=False, **kwargs):
179+
""" MNASNet with depth multiplier of 1.3. """
180+
model = MNASNet(1.3, **kwargs)
181+
if pretrained:
182+
_load_pretrained("mnasnet1_3", model)
183+
return model
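As a quick illustration of how the pieces above fit together (a sketch, not part of the commit; it assumes the module's names, including the private helpers, are in scope, e.g. run from within torchvision/models/mnasnet.py):

# Asymmetric rounding from _round_to_multiple_of's docstring:
# 83 rounds down to 80, 84 rounds up to 88.
assert _round_to_multiple_of(83, 8) == 80
assert _round_to_multiple_of(84, 8) == 88

# _scale_depths applies the depth multiplier and keeps every width divisible by 8.
print(_scale_depths([24, 40, 80, 96, 192, 320], 0.5))  # [16, 24, 40, 48, 96, 160]

# The public builders are thin wrappers around the constructor; extra kwargs
# (e.g. num_classes) pass straight through.
model = mnasnet0_5(num_classes=10)
y = model(torch.rand(2, 3, 224, 224))
print(y.shape)  # torch.Size([2, 10])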
