Skip to content

Commit 0bfb10e

Browse files
committed
Match BVLC GoogLeNet zero initialization of classifier
1 parent 2c8caab commit 0bfb10e

File tree

1 file changed

+17
-1
lines changed

1 file changed

+17
-1
lines changed

torchvision/models/googlenet.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ def googlenet(pretrained=False, **kwargs):
1818
pretrained (bool): If True, returns a model pre-trained on ImageNet
1919
"""
2020
if pretrained:
21+
kwargs['init_weights'] = False
2122
model = GoogLeNet(**kwargs)
2223
model.load_state_dict(model_zoo.load_url(model_urls['googlenet']))
2324
return model
@@ -32,6 +33,7 @@ def googlenet_bn(pretrained=False, **kwargs):
3233
pretrained (bool): If True, returns a model pre-trained on ImageNet
3334
"""
3435
if pretrained:
36+
kwargs['init_weights'] = False
3537
model = GoogLeNet(batch_norm=True, **kwargs)
3638
model.load_state_dict(model_zoo.load_url(model_urls['googlenet_bn']))
3739
return model
@@ -41,7 +43,7 @@ def googlenet_bn(pretrained=False, **kwargs):
4143

4244
class GoogLeNet(nn.Module):
4345

44-
def __init__(self, num_classes=1000, aux_logits=True, batch_norm=False):
46+
def __init__(self, num_classes=1000, aux_logits=True, batch_norm=False, init_weights=True):
4547
super(GoogLeNet, self).__init__()
4648
self.aux_logits = aux_logits
4749

@@ -73,11 +75,25 @@ def __init__(self, num_classes=1000, aux_logits=True, batch_norm=False):
7375
self.dropout = nn.Dropout(0.4)
7476
self.fc = nn.Linear(1024, num_classes)
7577

78+
if init_weights:
79+
self._initialize_weights()
80+
81+
def _initialize_weights(self):
7682
for m in self.modules():
7783
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
7884
nn.init.xavier_uniform_(m.weight)
7985
if m.bias is not None:
8086
nn.init.constant_(m.bias, 0.2)
87+
elif isinstance(m, nn.BatchNorm2d):
88+
nn.init.constant_(m.weight, 1)
89+
nn.init.constant_(m.bias, 0)
90+
91+
# zero init classifier
92+
for m in self.modules():
93+
if isinstance(m, InceptionAux):
94+
nn.init.zeros_(m.fc2.bias)
95+
elif m == self.fc:
96+
nn.init.zeros_(m.bias)
8197

8298
def forward(self, x):
8399
x = self.conv1(x)

0 commit comments

Comments
 (0)