Skip to content

Commit 03e2573

Browse files
authored
Update URL and add progress option (#1043)
1 parent 3254560 commit 03e2573

File tree

1 file changed

+12
-12
lines changed

1 file changed

+12
-12
lines changed

torchvision/models/mnasnet.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@
88

99
_MODEL_URLS = {
1010
"mnasnet0_5":
11-
"https://github.com/1e100/mnasnet_trainer/releases/download/v0.1/mnasnet0.5_top1_67.592-7c6cb539b9.pth",
11+
"https://download.pytorch.org/models/mnasnet0.5_top1_67.592-7c6cb539b9.pth",
1212
"mnasnet0_75": None,
1313
"mnasnet1_0":
14-
"https://github.com/1e100/mnasnet_trainer/releases/download/v0.1/mnasnet1.0_top1_73.512-f206786ef8.pth",
14+
"https://download.pytorch.org/models/mnasnet1.0_top1_73.512-f206786ef8.pth",
1515
"mnasnet1_3": None
1616
}
1717

@@ -143,41 +143,41 @@ def _initialize_weights(self):
143143
nn.init.zeros_(m.bias)
144144

145145

146-
def _load_pretrained(model_name, model):
146+
def _load_pretrained(model_name, model, progress):
147147
if model_name not in _MODEL_URLS or _MODEL_URLS[model_name] is None:
148148
raise ValueError(
149149
"No checkpoint is available for model type {}".format(model_name))
150150
checkpoint_url = _MODEL_URLS[model_name]
151-
model.load_state_dict(load_state_dict_from_url(checkpoint_url))
151+
model.load_state_dict(load_state_dict_from_url(checkpoint_url, progress=progress))
152152

153153

154-
def mnasnet0_5(pretrained=False, **kwargs):
154+
def mnasnet0_5(pretrained=False, progress=True, **kwargs):
155155
""" MNASNet with depth multiplier of 0.5. """
156156
model = MNASNet(0.5, **kwargs)
157157
if pretrained:
158-
_load_pretrained("mnasnet0_5", model)
158+
_load_pretrained("mnasnet0_5", model, progress)
159159
return model
160160

161161

162-
def mnasnet0_75(pretrained=False, **kwargs):
162+
def mnasnet0_75(pretrained=False, progress=True, **kwargs):
163163
""" MNASNet with depth multiplier of 0.75. """
164164
model = MNASNet(0.75, **kwargs)
165165
if pretrained:
166-
_load_pretrained("mnasnet0_75", model)
166+
_load_pretrained("mnasnet0_75", model, progress)
167167
return model
168168

169169

170-
def mnasnet1_0(pretrained=False, **kwargs):
170+
def mnasnet1_0(pretrained=False, progress=True, **kwargs):
171171
""" MNASNet with depth multiplier of 1.0. """
172172
model = MNASNet(1.0, **kwargs)
173173
if pretrained:
174-
_load_pretrained("mnasnet1_0", model)
174+
_load_pretrained("mnasnet1_0", model, progress)
175175
return model
176176

177177

178-
def mnasnet1_3(pretrained=False, **kwargs):
178+
def mnasnet1_3(pretrained=False, progress=True, **kwargs):
179179
""" MNASNet with depth multiplier of 1.3. """
180180
model = MNASNet(1.3, **kwargs)
181181
if pretrained:
182-
_load_pretrained("mnasnet1_3", model)
182+
_load_pretrained("mnasnet1_3", model, progress)
183183
return model

0 commit comments

Comments
 (0)