|
8 | 8 |
|
9 | 9 | _MODEL_URLS = {
|
10 | 10 | "mnasnet0_5":
|
11 |
| - "https://github.com/1e100/mnasnet_trainer/releases/download/v0.1/mnasnet0.5_top1_67.592-7c6cb539b9.pth", |
| 11 | + "https://download.pytorch.org/models/mnasnet0.5_top1_67.592-7c6cb539b9.pth", |
12 | 12 | "mnasnet0_75": None,
|
13 | 13 | "mnasnet1_0":
|
14 |
| - "https://github.com/1e100/mnasnet_trainer/releases/download/v0.1/mnasnet1.0_top1_73.512-f206786ef8.pth", |
| 14 | + "https://download.pytorch.org/models/mnasnet1.0_top1_73.512-f206786ef8.pth", |
15 | 15 | "mnasnet1_3": None
|
16 | 16 | }
|
17 | 17 |
|
@@ -143,41 +143,41 @@ def _initialize_weights(self):
|
143 | 143 | nn.init.zeros_(m.bias)
|
144 | 144 |
|
145 | 145 |
|
146 |
| -def _load_pretrained(model_name, model): |
| 146 | +def _load_pretrained(model_name, model, progress): |
147 | 147 | if model_name not in _MODEL_URLS or _MODEL_URLS[model_name] is None:
|
148 | 148 | raise ValueError(
|
149 | 149 | "No checkpoint is available for model type {}".format(model_name))
|
150 | 150 | checkpoint_url = _MODEL_URLS[model_name]
|
151 |
| - model.load_state_dict(load_state_dict_from_url(checkpoint_url)) |
| 151 | + model.load_state_dict(load_state_dict_from_url(checkpoint_url, progress=progress)) |
152 | 152 |
|
153 | 153 |
|
154 |
| -def mnasnet0_5(pretrained=False, **kwargs): |
| 154 | +def mnasnet0_5(pretrained=False, progress=True, **kwargs): |
155 | 155 | """ MNASNet with depth multiplier of 0.5. """
|
156 | 156 | model = MNASNet(0.5, **kwargs)
|
157 | 157 | if pretrained:
|
158 |
| - _load_pretrained("mnasnet0_5", model) |
| 158 | + _load_pretrained("mnasnet0_5", model, progress) |
159 | 159 | return model
|
160 | 160 |
|
161 | 161 |
|
162 |
| -def mnasnet0_75(pretrained=False, **kwargs): |
| 162 | +def mnasnet0_75(pretrained=False, progress=True, **kwargs): |
163 | 163 | """ MNASNet with depth multiplier of 0.75. """
|
164 | 164 | model = MNASNet(0.75, **kwargs)
|
165 | 165 | if pretrained:
|
166 |
| - _load_pretrained("mnasnet0_75", model) |
| 166 | + _load_pretrained("mnasnet0_75", model, progress) |
167 | 167 | return model
|
168 | 168 |
|
169 | 169 |
|
170 |
| -def mnasnet1_0(pretrained=False, **kwargs): |
| 170 | +def mnasnet1_0(pretrained=False, progress=True, **kwargs): |
171 | 171 | """ MNASNet with depth multiplier of 1.0. """
|
172 | 172 | model = MNASNet(1.0, **kwargs)
|
173 | 173 | if pretrained:
|
174 |
| - _load_pretrained("mnasnet1_0", model) |
| 174 | + _load_pretrained("mnasnet1_0", model, progress) |
175 | 175 | return model
|
176 | 176 |
|
177 | 177 |
|
178 |
| -def mnasnet1_3(pretrained=False, **kwargs): |
| 178 | +def mnasnet1_3(pretrained=False, progress=True, **kwargs): |
179 | 179 | """ MNASNet with depth multiplier of 1.3. """
|
180 | 180 | model = MNASNet(1.3, **kwargs)
|
181 | 181 | if pretrained:
|
182 |
| - _load_pretrained("mnasnet1_3", model) |
| 182 | + _load_pretrained("mnasnet1_3", model, progress) |
183 | 183 | return model
|
0 commit comments