Skip to content

Commit 0ce4670

Browse files
committed
Adding quantized inception builder.
1 parent b2c71c9 commit 0ce4670

File tree

2 files changed

+87
-0
lines changed

2 files changed

+87
-0
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
from .googlenet import *
2+
from .inception import *
23
from .resnet import *
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
import warnings
2+
from functools import partial
3+
from typing import Any, Optional, Union
4+
5+
from torchvision.transforms.functional import InterpolationMode
6+
7+
from ....models.quantization.inception import (
8+
QuantizableInception3,
9+
_replace_relu,
10+
quantize_model,
11+
)
12+
from ...transforms.presets import ImageNetEval
13+
from .._api import Weights, WeightEntry
14+
from .._meta import _IMAGENET_CATEGORIES
15+
from ..inception import InceptionV3Weights
16+
17+
18+
# Public API of this module: the quantizable model class, its weight enum,
# and the builder function.
__all__ = ["QuantizableInception3", "QuantizedInceptionV3Weights", "inception_v3"]
23+
24+
25+
class QuantizedInceptionV3Weights(Weights):
    """Weight enum for the quantized InceptionV3 model.

    Each entry bundles the checkpoint URL, the evaluation-time preprocessing
    transform, and metadata (categories, interpolation, quantization backend,
    training recipe, reference accuracies, and a link back to the unquantized
    weights the checkpoint was derived from).
    """

    # Post-training-quantized (fbgemm) weights derived from the TF-ported
    # unquantized checkpoint.
    ImageNet1K_FBGEMM_TFV1 = WeightEntry(
        url="https://download.pytorch.org/models/quantized/inception_v3_google_fbgemm-71447a44.pth",
        # FIX: InceptionV3 is evaluated at 299x299 crops (resize 342), not the
        # 224x224 used by most other ImageNet models; the previous values were
        # copy-pasted from the quantized GoogLeNet entry.
        transforms=partial(ImageNetEval, crop_size=299, resize_size=342),
        meta={
            "size": (299, 299),
            "categories": _IMAGENET_CATEGORIES,
            "interpolation": InterpolationMode.BILINEAR,
            "backend": "fbgemm",
            "quantization": "ptq",  # post-training quantization
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models",
            "unquantized": InceptionV3Weights.ImageNet1K_TFV1,
            # FIX: reference accuracies of the quantized inception_v3 fbgemm
            # checkpoint; 69.826/89.404 were the quantized GoogLeNet numbers.
            "acc@1": 77.176,
            "acc@5": 93.354,
        },
    )
41+
42+
43+
def inception_v3(
    weights: Optional[Union[QuantizedInceptionV3Weights, InceptionV3Weights]] = None,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,
) -> QuantizableInception3:
    """Build an InceptionV3 model, optionally quantized and/or pretrained.

    Args:
        weights: Pretrained weights to load. Quantized weights when
            ``quantize=True``, float weights otherwise; ``None`` for a
            randomly-initialized model.
        progress: Show a download progress bar when fetching weights.
        quantize: Return a quantized version of the model.
        **kwargs: Forwarded to ``QuantizableInception3``. Also accepts the
            deprecated ``pretrained`` flag for backward compatibility.

    Returns:
        The constructed (and possibly quantized / pretrained) model.
    """
    # Backward compatibility: translate the deprecated `pretrained` flag into
    # the matching weight entry.
    if "pretrained" in kwargs:
        warnings.warn("The argument pretrained is deprecated, please use weights instead.")
        if kwargs.pop("pretrained"):
            weights = (
                QuantizedInceptionV3Weights.ImageNet1K_FBGEMM_TFV1
                if quantize
                else InceptionV3Weights.ImageNet1K_TFV1
            )
        else:
            weights = None

    # Validate the weights against the enum matching the requested mode.
    weight_enum = QuantizedInceptionV3Weights if quantize else InceptionV3Weights
    weights = weight_enum.verify(weights)

    # Remember what the caller asked for; pretrained checkpoints require the
    # aux classifier to exist while loading, so it is forced on below and
    # stripped again afterwards when the caller did not request it.
    original_aux_logits = kwargs.get("aux_logits", False)
    if weights is not None:
        kwargs.setdefault("transform_input", True)
        kwargs["aux_logits"] = True
        kwargs["init_weights"] = False  # weights come from the checkpoint
        kwargs["num_classes"] = len(weights.meta["categories"])
        if "backend" in weights.meta:
            kwargs["backend"] = weights.meta["backend"]
    backend = kwargs.pop("backend", "fbgemm")

    model = QuantizableInception3(**kwargs)
    _replace_relu(model)
    if quantize:
        quantize_model(model, backend)

    if weights is not None:
        if quantize and not original_aux_logits:
            # Quantized checkpoints contain no aux-classifier weights: drop
            # the branch BEFORE loading the state dict.
            model.aux_logits = False
            model.AuxLogits = None
        model.load_state_dict(weights.state_dict(progress=progress))
        if not quantize and not original_aux_logits:
            # Float checkpoints include aux weights: load first, then drop
            # the unwanted branch.
            model.aux_logits = False
            model.AuxLogits = None

    return model

0 commit comments

Comments
 (0)