@@ -18,6 +18,7 @@ def googlenet(pretrained=False, **kwargs):
18
18
pretrained (bool): If True, returns a model pre-trained on ImageNet
19
19
"""
20
20
if pretrained :
21
+ kwargs ['init_weights' ] = False
21
22
model = GoogLeNet (** kwargs )
22
23
model .load_state_dict (model_zoo .load_url (model_urls ['googlenet' ]))
23
24
return model
@@ -32,6 +33,7 @@ def googlenet_bn(pretrained=False, **kwargs):
32
33
pretrained (bool): If True, returns a model pre-trained on ImageNet
33
34
"""
34
35
if pretrained :
36
+ kwargs ['init_weights' ] = False
35
37
model = GoogLeNet (batch_norm = True , ** kwargs )
36
38
model .load_state_dict (model_zoo .load_url (model_urls ['googlenet_bn' ]))
37
39
return model
@@ -41,7 +43,7 @@ def googlenet_bn(pretrained=False, **kwargs):
41
43
42
44
class GoogLeNet (nn .Module ):
43
45
44
- def __init__ (self , num_classes = 1000 , aux_logits = True , batch_norm = False ):
46
+ def __init__ (self , num_classes = 1000 , aux_logits = True , batch_norm = False , init_weights = True ):
45
47
super (GoogLeNet , self ).__init__ ()
46
48
self .aux_logits = aux_logits
47
49
@@ -73,11 +75,25 @@ def __init__(self, num_classes=1000, aux_logits=True, batch_norm=False):
73
75
self .dropout = nn .Dropout (0.4 )
74
76
self .fc = nn .Linear (1024 , num_classes )
75
77
78
+ if init_weights :
79
+ self ._initialize_weights ()
80
+
81
+ def _initialize_weights (self ):
76
82
for m in self .modules ():
77
83
if isinstance (m , nn .Conv2d ) or isinstance (m , nn .Linear ):
78
84
nn .init .xavier_uniform_ (m .weight )
79
85
if m .bias is not None :
80
86
nn .init .constant_ (m .bias , 0.2 )
87
+ elif isinstance (m , nn .BatchNorm2d ):
88
+ nn .init .constant_ (m .weight , 1 )
89
+ nn .init .constant_ (m .bias , 0 )
90
+
91
+ # zero init classifier
92
+ for m in self .modules ():
93
+ if isinstance (m , InceptionAux ):
94
+ nn .init .zeros_ (m .fc2 .bias )
95
+ elif m == self .fc :
96
+ nn .init .zeros_ (m .bias )
81
97
82
98
def forward (self , x ):
83
99
x = self .conv1 (x )
0 commit comments