Skip to content

Commit 4176556

Browse files
authored
Add weight for mnasnet0_75 and mnasnet1_3 (#6019)
* Add weight for mnasnet0_75 and mnasnet1_3 * Fix missing comma * Add PR url as recipe, and update the metrics * Add weights to legacy handler * Update docs to specify there are weights available
1 parent 9e78871 commit 4176556

File tree

1 file changed

+38
-10
lines changed

1 file changed

+38
-10
lines changed

torchvision/models/mnasnet.py

+38-10
Original file line numberDiff line numberDiff line change
@@ -235,8 +235,20 @@ class MNASNet0_5_Weights(WeightsEnum):
235235

236236

237237
class MNASNet0_75_Weights(WeightsEnum):
238-
# If a default model is added here the corresponding changes need to be done in mnasnet0_75
239-
pass
238+
IMAGENET1K_V1 = Weights(
239+
url="https://download.pytorch.org/models/mnasnet0_75-7090bc5f.pth",
240+
transforms=partial(ImageClassification, crop_size=224, resize_size=232),
241+
meta={
242+
**_COMMON_META,
243+
"recipe": "https://github.com/pytorch/vision/pull/6019",
244+
"num_params": 3170208,
245+
"metrics": {
246+
"acc@1": 71.180,
247+
"acc@5": 90.496,
248+
},
249+
},
250+
)
251+
DEFAULT = IMAGENET1K_V1
240252

241253

242254
class MNASNet1_0_Weights(WeightsEnum):
@@ -256,8 +268,20 @@ class MNASNet1_0_Weights(WeightsEnum):
256268

257269

258270
class MNASNet1_3_Weights(WeightsEnum):
259-
# If a default model is added here the corresponding changes need to be done in mnasnet1_3
260-
pass
271+
IMAGENET1K_V1 = Weights(
272+
url="https://download.pytorch.org/models/mnasnet1_3-a4c69d6f.pth",
273+
transforms=partial(ImageClassification, crop_size=224, resize_size=232),
274+
meta={
275+
**_COMMON_META,
276+
"recipe": "https://github.com/pytorch/vision/pull/6019",
277+
"num_params": 6282256,
278+
"metrics": {
279+
"acc@1": 76.506,
280+
"acc@5": 93.522,
281+
},
282+
},
283+
)
284+
DEFAULT = IMAGENET1K_V1
261285

262286

263287
def _mnasnet(alpha: float, weights: Optional[WeightsEnum], progress: bool, **kwargs: Any) -> MNASNet:
@@ -299,15 +323,17 @@ def mnasnet0_5(*, weights: Optional[MNASNet0_5_Weights] = None, progress: bool =
299323
return _mnasnet(0.5, weights, progress, **kwargs)
300324

301325

302-
@handle_legacy_interface(weights=("pretrained", None))
326+
@handle_legacy_interface(weights=("pretrained", MNASNet0_75_Weights.IMAGENET1K_V1))
303327
def mnasnet0_75(*, weights: Optional[MNASNet0_75_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
304328
"""MNASNet with depth multiplier of 0.75 from
305329
`MnasNet: Platform-Aware Neural Architecture Search for Mobile
306330
<https://arxiv.org/pdf/1807.11626.pdf>`_ paper.
307331
308332
Args:
309-
weights (:class:`~torchvision.models.MNASNet0_75_Weights`, optional): Currently
310-
no pre-trained weights are available and by default no pre-trained
333+
weights (:class:`~torchvision.models.MNASNet0_75_Weights`, optional): The
334+
pretrained weights to use. See
335+
:class:`~torchvision.models.MNASNet0_75_Weights` below for
336+
more details, and possible values. By default, no pre-trained
311337
weights are used.
312338
progress (bool, optional): If True, displays a progress bar of the
313339
download to stderr. Default is True.
@@ -351,15 +377,17 @@ def mnasnet1_0(*, weights: Optional[MNASNet1_0_Weights] = None, progress: bool =
351377
return _mnasnet(1.0, weights, progress, **kwargs)
352378

353379

354-
@handle_legacy_interface(weights=("pretrained", None))
380+
@handle_legacy_interface(weights=("pretrained", MNASNet1_3_Weights.IMAGENET1K_V1))
355381
def mnasnet1_3(*, weights: Optional[MNASNet1_3_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
356382
"""MNASNet with depth multiplier of 1.3 from
357383
`MnasNet: Platform-Aware Neural Architecture Search for Mobile
358384
<https://arxiv.org/pdf/1807.11626.pdf>`_ paper.
359385
360386
Args:
361-
weights (:class:`~torchvision.models.MNASNet1_3_Weights`, optional): Currently
362-
no pre-trained weights are available and by default no pre-trained
387+
weights (:class:`~torchvision.models.MNASNet1_3_Weights`, optional): The
388+
pretrained weights to use. See
389+
:class:`~torchvision.models.MNASNet1_3_Weights` below for
390+
more details, and possible values. By default, no pre-trained
363391
weights are used.
364392
progress (bool, optional): If True, displays a progress bar of the
365393
download to stderr. Default is True.

0 commit comments

Comments
 (0)