diff --git a/torchvision/prototype/models/__init__.py b/torchvision/prototype/models/__init__.py index c77d27e8009..70cab0ab59c 100644 --- a/torchvision/prototype/models/__init__.py +++ b/torchvision/prototype/models/__init__.py @@ -1,3 +1,4 @@ from .resnet import * +from .alexnet import * from . import detection from . import quantization diff --git a/torchvision/prototype/models/alexnet.py b/torchvision/prototype/models/alexnet.py new file mode 100644 index 00000000000..9e19a6ced08 --- /dev/null +++ b/torchvision/prototype/models/alexnet.py @@ -0,0 +1,46 @@ +import warnings +from functools import partial +from typing import Any, Optional + +from ...models.alexnet import AlexNet +from ..transforms.presets import ImageNetEval +from ._api import Weights, WeightEntry +from ._meta import _IMAGENET_CATEGORIES + + +__all__ = ["AlexNet", "AlexNetWeights", "alexnet"] + + +_common_meta = { + "size": (224, 224), + "categories": _IMAGENET_CATEGORIES, +} + + +class AlexNetWeights(Weights): + ImageNet1K_RefV1 = WeightEntry( + url="https://download.pytorch.org/models/alexnet-owt-7be5be79.pth", + transforms=partial(ImageNetEval, crop_size=224), + meta={ + **_common_meta, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#alexnet-and-vgg", + "acc@1": 56.522, + "acc@5": 79.066, + }, + ) + + +def alexnet(weights: Optional[AlexNetWeights] = None, progress: bool = True, **kwargs: Any) -> AlexNet: + if "pretrained" in kwargs: + warnings.warn("The argument pretrained is deprecated, please use weights instead.") + weights = AlexNetWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None + weights = AlexNetWeights.verify(weights) + if weights is not None: + kwargs["num_classes"] = len(weights.meta["categories"]) + + model = AlexNet(**kwargs) + + if weights is not None: + model.load_state_dict(weights.state_dict(progress=progress)) + + return model