diff --git a/torchvision/prototype/models/mobilenetv2.py b/torchvision/prototype/models/mobilenetv2.py
index 3531394acc8..7176252111a 100644
--- a/torchvision/prototype/models/mobilenetv2.py
+++ b/torchvision/prototype/models/mobilenetv2.py
@@ -13,25 +13,40 @@
 __all__ = ["MobileNetV2", "MobileNet_V2_Weights", "mobilenet_v2"]
 
 
+_COMMON_META = {
+    "task": "image_classification",
+    "architecture": "MobileNetV2",
+    "publication_year": 2018,
+    "num_params": 3504872,
+    "size": (224, 224),
+    "min_size": (1, 1),
+    "categories": _IMAGENET_CATEGORIES,
+    "interpolation": InterpolationMode.BILINEAR,
+}
+
+
 class MobileNet_V2_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/mobilenet_v2-b0353104.pth",
         transforms=partial(ImageNetEval, crop_size=224),
         meta={
-            "task": "image_classification",
-            "architecture": "MobileNetV2",
-            "publication_year": 2018,
-            "num_params": 3504872,
-            "size": (224, 224),
-            "min_size": (1, 1),
-            "categories": _IMAGENET_CATEGORIES,
-            "interpolation": InterpolationMode.BILINEAR,
+            **_COMMON_META,
             "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#mobilenetv2",
             "acc@1": 71.878,
             "acc@5": 90.286,
         },
     )
-    DEFAULT = IMAGENET1K_V1
+    IMAGENET1K_V2 = Weights(
+        url="https://download.pytorch.org/models/mobilenet_v2-7ebf99e0.pth",
+        transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
+        meta={
+            **_COMMON_META,
+            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-reg-tuning",
+            "acc@1": 72.154,
+            "acc@5": 90.822,
+        },
+    )
+    DEFAULT = IMAGENET1K_V2
 
 
 @handle_legacy_interface(weights=("pretrained", MobileNet_V2_Weights.IMAGENET1K_V1))
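
Not part of the patch: a minimal usage sketch of how the change above behaves, assuming the torchvision prototype multi-weight API that this file belongs to. `DEFAULT` now resolves to the new `IMAGENET1K_V2` entry, while the original recipe stays addressable by name.

# Usage sketch (assumption: prototype builders accept the `weights=` argument as on this branch).
from torchvision.prototype.models import MobileNet_V2_Weights, mobilenet_v2

model_new = mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT)        # resolves to IMAGENET1K_V2
model_old = mobilenet_v2(weights=MobileNet_V2_Weights.IMAGENET1K_V1)  # pin the original V1 recipe
preprocess = MobileNet_V2_Weights.DEFAULT.transforms()                # builds the matching ImageNetEval transform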