diff --git a/torchvision/models/mnasnet.py b/torchvision/models/mnasnet.py index 5deb87c2ad1..ba815206be2 100644 --- a/torchvision/models/mnasnet.py +++ b/torchvision/models/mnasnet.py @@ -8,10 +8,10 @@ _MODEL_URLS = { "mnasnet0_5": - "https://github.com/1e100/mnasnet_trainer/releases/download/v0.1/mnasnet0.5_top1_67.592-7c6cb539b9.pth", + "https://download.pytorch.org/models/mnasnet0.5_top1_67.592-7c6cb539b9.pth", "mnasnet0_75": None, "mnasnet1_0": - "https://github.com/1e100/mnasnet_trainer/releases/download/v0.1/mnasnet1.0_top1_73.512-f206786ef8.pth", + "https://download.pytorch.org/models/mnasnet1.0_top1_73.512-f206786ef8.pth", "mnasnet1_3": None } @@ -143,41 +143,41 @@ def _initialize_weights(self): nn.init.zeros_(m.bias) -def _load_pretrained(model_name, model): +def _load_pretrained(model_name, model, progress): if model_name not in _MODEL_URLS or _MODEL_URLS[model_name] is None: raise ValueError( "No checkpoint is available for model type {}".format(model_name)) checkpoint_url = _MODEL_URLS[model_name] - model.load_state_dict(load_state_dict_from_url(checkpoint_url)) + model.load_state_dict(load_state_dict_from_url(checkpoint_url, progress=progress)) -def mnasnet0_5(pretrained=False, **kwargs): +def mnasnet0_5(pretrained=False, progress=True, **kwargs): """ MNASNet with depth multiplier of 0.5. """ model = MNASNet(0.5, **kwargs) if pretrained: - _load_pretrained("mnasnet0_5", model) + _load_pretrained("mnasnet0_5", model, progress) return model -def mnasnet0_75(pretrained=False, **kwargs): +def mnasnet0_75(pretrained=False, progress=True, **kwargs): """ MNASNet with depth multiplier of 0.75. """ model = MNASNet(0.75, **kwargs) if pretrained: - _load_pretrained("mnasnet0_75", model) + _load_pretrained("mnasnet0_75", model, progress) return model -def mnasnet1_0(pretrained=False, **kwargs): +def mnasnet1_0(pretrained=False, progress=True, **kwargs): """ MNASNet with depth multiplier of 1.0. """ model = MNASNet(1.0, **kwargs) if pretrained: - _load_pretrained("mnasnet1_0", model) + _load_pretrained("mnasnet1_0", model, progress) return model -def mnasnet1_3(pretrained=False, **kwargs): +def mnasnet1_3(pretrained=False, progress=True, **kwargs): """ MNASNet with depth multiplier of 1.3. """ model = MNASNet(1.3, **kwargs) if pretrained: - _load_pretrained("mnasnet1_3", model) + _load_pretrained("mnasnet1_3", model, progress) return model