Skip to content

Commit 1187d36

Browse files
jdsgomesdatumbox
andauthored
adding alexnet prototype model (#4670)
* adding alexnet prototype model * adding recipe reference * fixing lint Co-authored-by: Vasilis Vryniotis <[email protected]>
1 parent 3f6ff20 commit 1187d36

File tree

2 files changed

+47
-0
lines changed

2 files changed

+47
-0
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
from .resnet import *
2+
from .alexnet import *
23
from . import detection
34
from . import quantization
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import warnings
2+
from functools import partial
3+
from typing import Any, Optional
4+
5+
from ...models.alexnet import AlexNet
6+
from ..transforms.presets import ImageNetEval
7+
from ._api import Weights, WeightEntry
8+
from ._meta import _IMAGENET_CATEGORIES
9+
10+
11+
__all__ = ["AlexNet", "AlexNetWeights", "alexnet"]
12+
13+
14+
_common_meta = {
15+
"size": (224, 224),
16+
"categories": _IMAGENET_CATEGORIES,
17+
}
18+
19+
20+
class AlexNetWeights(Weights):
21+
ImageNet1K_RefV1 = WeightEntry(
22+
url="https://download.pytorch.org/models/alexnet-owt-7be5be79.pth",
23+
transforms=partial(ImageNetEval, crop_size=224),
24+
meta={
25+
**_common_meta,
26+
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#alexnet-and-vgg",
27+
"acc@1": 56.522,
28+
"acc@5": 79.066,
29+
},
30+
)
31+
32+
33+
def alexnet(weights: Optional[AlexNetWeights] = None, progress: bool = True, **kwargs: Any) -> AlexNet:
34+
if "pretrained" in kwargs:
35+
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
36+
weights = AlexNetWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
37+
weights = AlexNetWeights.verify(weights)
38+
if weights is not None:
39+
kwargs["num_classes"] = len(weights.meta["categories"])
40+
41+
model = AlexNet(**kwargs)
42+
43+
if weights is not None:
44+
model.load_state_dict(weights.state_dict(progress=progress))
45+
46+
return model

0 commit comments

Comments
 (0)