Skip to content

Commit 2c235ff

Browse files
committed
Match BVLC GoogLeNet zero initialization of classifier
1 parent 2c8caab commit 2c235ff

File tree

1 file changed

+14
-2
lines changed

1 file changed

+14
-2
lines changed

torchvision/models/googlenet.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ def googlenet(pretrained=False, **kwargs):
1818
pretrained (bool): If True, returns a model pre-trained on ImageNet
1919
"""
2020
if pretrained:
21+
kwargs['init_weights'] = False
2122
model = GoogLeNet(**kwargs)
2223
model.load_state_dict(model_zoo.load_url(model_urls['googlenet']))
2324
return model
@@ -32,6 +33,7 @@ def googlenet_bn(pretrained=False, **kwargs):
3233
pretrained (bool): If True, returns a model pre-trained on ImageNet
3334
"""
3435
if pretrained:
36+
kwargs['init_weights'] = False
3537
model = GoogLeNet(batch_norm=True, **kwargs)
3638
model.load_state_dict(model_zoo.load_url(model_urls['googlenet_bn']))
3739
return model
@@ -41,7 +43,7 @@ def googlenet_bn(pretrained=False, **kwargs):
4143

4244
class GoogLeNet(nn.Module):
4345

44-
def __init__(self, num_classes=1000, aux_logits=True, batch_norm=False):
46+
def __init__(self, num_classes=1000, aux_logits=True, batch_norm=False, init_weights=True):
4547
super(GoogLeNet, self).__init__()
4648
self.aux_logits = aux_logits
4749

@@ -73,11 +75,21 @@ def __init__(self, num_classes=1000, aux_logits=True, batch_norm=False):
7375
self.dropout = nn.Dropout(0.4)
7476
self.fc = nn.Linear(1024, num_classes)
7577

78+
if init_weights:
79+
self._initialize_weights()
80+
81+
def _initialize_weights(self):
7682
for m in self.modules():
77-
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
83+
if isinstance(m, nn.Conv2d):
7884
nn.init.xavier_uniform_(m.weight)
7985
if m.bias is not None:
8086
nn.init.constant_(m.bias, 0.2)
87+
elif isinstance(m, nn.Linear):
88+
nn.init.xavier_uniform_(m.weight)
89+
nn.init.constant_(m.bias, 0)
90+
elif isinstance(m, nn.BatchNorm2d):
91+
nn.init.constant_(m.weight, 1)
92+
nn.init.constant_(m.bias, 0)
8193

8294
def forward(self, x):
8395
x = self.conv1(x)

0 commit comments

Comments
 (0)