Skip to content

Commit a209300

Browse files
TheCodezfmassa
authored andcommitted
Add GoogLeNet (Inception v1) (#678)
* Add GoogLeNet (Inception v1) * Fix missing padding * Add missing ReLu to aux classifier * Add Batch normalized version of GoogLeNet * Use ceil_mode instead of padding and initialize weights using "xavier" * Match BVLC GoogLeNet zero initialization of classifier * Small cleanup * use adaptive avg pool * adjust network to match TensorFlow * Update url of pre-trained model and add classification results on ImageNet * Bugfix that improves performance by 1 point
1 parent a7d8898 commit a209300

File tree

3 files changed

+194
-0
lines changed

3 files changed

+194
-0
lines changed

docs/source/models.rst

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ architectures:
1010
- `SqueezeNet`_
1111
- `DenseNet`_
1212
- `Inception`_ v3
13+
- `GoogLeNet`_
1314

1415
You can construct a model with random weights by calling its constructor:
1516

@@ -22,6 +23,7 @@ You can construct a model with random weights by calling its constructor:
2223
squeezenet = models.squeezenet1_0()
2324
densenet = models.densenet161()
2425
inception = models.inception_v3()
26+
googlenet = models.googlenet()
2527
2628
We provide pre-trained models, using the PyTorch :mod:`torch.utils.model_zoo`.
2729
These can be constructed by passing ``pretrained=True``:
@@ -35,6 +37,7 @@ These can be constructed by passing ``pretrained=True``:
3537
vgg16 = models.vgg16(pretrained=True)
3638
densenet = models.densenet161(pretrained=True)
3739
inception = models.inception_v3(pretrained=True)
40+
googlenet = models.googlenet(pretrained=True)
3841
3942
Instancing a pre-trained model will download its weights to a cache directory.
4043
This directory can be set using the `TORCH_MODEL_ZOO` environment variable. See
@@ -84,6 +87,7 @@ Densenet-169 24.00 7.00
8487
Densenet-201 22.80 6.43
8588
Densenet-161 22.35 6.20
8689
Inception v3 22.55 6.44
90+
GoogleNet 30.22 10.47
8791
================================ ============= =============
8892

8993

@@ -93,6 +97,7 @@ Inception v3 22.55 6.44
9397
.. _SqueezeNet: https://arxiv.org/abs/1602.07360
9498
.. _DenseNet: https://arxiv.org/abs/1608.06993
9599
.. _Inception: https://arxiv.org/abs/1512.00567
100+
.. _GoogLeNet: https://arxiv.org/abs/1409.4842
96101

97102
.. currentmodule:: torchvision.models
98103

@@ -142,3 +147,8 @@ Inception v3
142147

143148
.. autofunction:: inception_v3
144149

150+
GoogLeNet
151+
------------
152+
153+
.. autofunction:: googlenet
154+

torchvision/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@
44
from .squeezenet import *
55
from .inception import *
66
from .densenet import *
7+
from .googlenet import *

torchvision/models/googlenet.py

Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
import torch
2+
import torch.nn as nn
3+
import torch.nn.functional as F
4+
from torch.utils import model_zoo
5+
6+
__all__ = ['GoogLeNet', 'googlenet']
7+
8+
model_urls = {
9+
# GoogLeNet ported from TensorFlow
10+
'googlenet': 'https://download.pytorch.org/models/googlenet-1378be20.pth',
11+
}
12+
13+
14+
def googlenet(pretrained=False, **kwargs):
15+
r"""GoogLeNet (Inception v1) model architecture from
16+
`"Going Deeper with Convolutions" <http://arxiv.org/abs/1409.4842>`_.
17+
Args:
18+
pretrained (bool): If True, returns a model pre-trained on ImageNet
19+
"""
20+
if pretrained:
21+
if 'transform_input' not in kwargs:
22+
kwargs['transform_input'] = True
23+
kwargs['init_weights'] = False
24+
model = GoogLeNet(**kwargs)
25+
model.load_state_dict(model_zoo.load_url(model_urls['googlenet']))
26+
return model
27+
28+
return GoogLeNet(**kwargs)
29+
30+
31+
class GoogLeNet(nn.Module):
32+
33+
def __init__(self, num_classes=1000, aux_logits=True, transform_input=False, init_weights=True):
34+
super(GoogLeNet, self).__init__()
35+
self.aux_logits = aux_logits
36+
self.transform_input = transform_input
37+
38+
self.conv1 = BasicConv2d(3, 64, kernel_size=7, stride=2, padding=3)
39+
self.maxpool1 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
40+
self.conv2 = BasicConv2d(64, 64, kernel_size=1)
41+
self.conv3 = BasicConv2d(64, 192, kernel_size=3, padding=1)
42+
self.maxpool2 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
43+
44+
self.inception3a = Inception(192, 64, 96, 128, 16, 32, 32)
45+
self.inception3b = Inception(256, 128, 128, 192, 32, 96, 64)
46+
self.maxpool3 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
47+
48+
self.inception4a = Inception(480, 192, 96, 208, 16, 48, 64)
49+
self.inception4b = Inception(512, 160, 112, 224, 24, 64, 64)
50+
self.inception4c = Inception(512, 128, 128, 256, 24, 64, 64)
51+
self.inception4d = Inception(512, 112, 144, 288, 32, 64, 64)
52+
self.inception4e = Inception(528, 256, 160, 320, 32, 128, 128)
53+
self.maxpool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
54+
55+
self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128)
56+
self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128)
57+
if aux_logits:
58+
self.aux1 = InceptionAux(512, num_classes)
59+
self.aux2 = InceptionAux(528, num_classes)
60+
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
61+
self.dropout = nn.Dropout(0.4)
62+
self.fc = nn.Linear(1024, num_classes)
63+
64+
if init_weights:
65+
self._initialize_weights()
66+
67+
def _initialize_weights(self):
68+
for m in self.modules():
69+
if isinstance(m, nn.Conv2d):
70+
nn.init.xavier_uniform_(m.weight)
71+
if m.bias is not None:
72+
nn.init.constant_(m.bias, 0.2)
73+
elif isinstance(m, nn.Linear):
74+
nn.init.xavier_uniform_(m.weight)
75+
nn.init.constant_(m.bias, 0)
76+
elif isinstance(m, nn.BatchNorm2d):
77+
nn.init.constant_(m.weight, 1)
78+
nn.init.constant_(m.bias, 0)
79+
80+
def forward(self, x):
81+
if self.transform_input:
82+
x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
83+
x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
84+
x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
85+
x = torch.cat((x_ch0, x_ch1, x_ch2), 1)
86+
87+
x = self.conv1(x)
88+
x = self.maxpool1(x)
89+
x = self.conv2(x)
90+
x = self.conv3(x)
91+
x = self.maxpool2(x)
92+
93+
x = self.inception3a(x)
94+
x = self.inception3b(x)
95+
x = self.maxpool3(x)
96+
x = self.inception4a(x)
97+
if self.training and self.aux_logits:
98+
aux1 = self.aux1(x)
99+
100+
x = self.inception4b(x)
101+
x = self.inception4c(x)
102+
x = self.inception4d(x)
103+
if self.training and self.aux_logits:
104+
aux2 = self.aux2(x)
105+
106+
x = self.inception4e(x)
107+
x = self.maxpool4(x)
108+
x = self.inception5a(x)
109+
x = self.inception5b(x)
110+
111+
x = self.avgpool(x)
112+
x = x.view(x.size(0), -1)
113+
x = self.dropout(x)
114+
x = self.fc(x)
115+
if self.training and self.aux_logits:
116+
return aux1, aux2, x
117+
return x
118+
119+
120+
class Inception(nn.Module):
121+
122+
def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj):
123+
super(Inception, self).__init__()
124+
125+
self.branch1 = BasicConv2d(in_channels, ch1x1, kernel_size=1)
126+
127+
self.branch2 = nn.Sequential(
128+
BasicConv2d(in_channels, ch3x3red, kernel_size=1),
129+
BasicConv2d(ch3x3red, ch3x3, kernel_size=3, padding=1)
130+
)
131+
132+
self.branch3 = nn.Sequential(
133+
BasicConv2d(in_channels, ch5x5red, kernel_size=1),
134+
BasicConv2d(ch5x5red, ch5x5, kernel_size=3, padding=1)
135+
)
136+
137+
self.branch4 = nn.Sequential(
138+
nn.MaxPool2d(kernel_size=3, stride=1, padding=1, ceil_mode=True),
139+
BasicConv2d(in_channels, pool_proj, kernel_size=1)
140+
)
141+
142+
def forward(self, x):
143+
branch1 = self.branch1(x)
144+
branch2 = self.branch2(x)
145+
branch3 = self.branch3(x)
146+
branch4 = self.branch4(x)
147+
148+
outputs = [branch1, branch2, branch3, branch4]
149+
return torch.cat(outputs, 1)
150+
151+
152+
class InceptionAux(nn.Module):
153+
154+
def __init__(self, in_channels, num_classes):
155+
super(InceptionAux, self).__init__()
156+
self.conv = BasicConv2d(in_channels, 128, kernel_size=1)
157+
158+
self.fc1 = nn.Linear(2048, 1024)
159+
self.fc2 = nn.Linear(1024, num_classes)
160+
161+
def forward(self, x):
162+
x = F.adaptive_avg_pool2d(x, (4, 4))
163+
164+
x = self.conv(x)
165+
x = x.view(x.size(0), -1)
166+
x = F.relu(self.fc1(x), inplace=True)
167+
x = F.dropout(x, 0.7, training=self.training)
168+
x = self.fc2(x)
169+
170+
return x
171+
172+
173+
class BasicConv2d(nn.Module):
174+
175+
def __init__(self, in_channels, out_channels, **kwargs):
176+
super(BasicConv2d, self).__init__()
177+
self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
178+
self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
179+
180+
def forward(self, x):
181+
x = self.conv(x)
182+
x = self.bn(x)
183+
return F.relu(x, inplace=True)

0 commit comments

Comments
 (0)