Commit 1cbd9cd

Adding multiweight support to Quantized InceptionV3 (#4850)

* Moving builder to the bottom to use proper typing.
* Renaming weights.
* Adding quantized inception builder.
* Correct meta info.
* Fix linter.
* Removing init_weights to avoid exposing it on the class.

1 parent 4ccef06 · commit 1cbd9cd

File tree

4 files changed: +158 −71 lines changed

torchvision/models/quantization/inception.py
torchvision/prototype/models/inception.py
torchvision/prototype/models/quantization/__init__.py
torchvision/prototype/models/quantization/inception.py

torchvision/models/quantization/inception.py

Lines changed: 65 additions & 66 deletions
@@ -24,72 +24,6 @@
 }
 
 
-def inception_v3(
-    pretrained: bool = False,
-    progress: bool = True,
-    quantize: bool = False,
-    **kwargs: Any,
-) -> "QuantizableInception3":
-
-    r"""Inception v3 model architecture from
-    `"Rethinking the Inception Architecture for Computer Vision" <http://arxiv.org/abs/1512.00567>`_.
-
-    .. note::
-        **Important**: In contrast to the other models the inception_v3 expects tensors with a size of
-        N x 3 x 299 x 299, so ensure your images are sized accordingly.
-
-    Note that quantize = True returns a quantized model with 8 bit
-    weights. Quantized models only support inference and run on CPUs.
-    GPU inference is not yet supported
-
-    Args:
-        pretrained (bool): If True, returns a model pre-trained on ImageNet
-        progress (bool): If True, displays a progress bar of the download to stderr
-        quantize (bool): If True, return a quantized version of the model
-        aux_logits (bool): If True, add an auxiliary branch that can improve training.
-            Default: *True*
-        transform_input (bool): If True, preprocesses the input according to the method with which it
-            was trained on ImageNet. Default: *False*
-    """
-    if pretrained:
-        if "transform_input" not in kwargs:
-            kwargs["transform_input"] = True
-        if "aux_logits" in kwargs:
-            original_aux_logits = kwargs["aux_logits"]
-            kwargs["aux_logits"] = True
-        else:
-            original_aux_logits = False
-
-    model = QuantizableInception3(**kwargs)
-    _replace_relu(model)
-
-    if quantize:
-        # TODO use pretrained as a string to specify the backend
-        backend = "fbgemm"
-        quantize_model(model, backend)
-    else:
-        assert pretrained in [True, False]
-
-    if pretrained:
-        if quantize:
-            if not original_aux_logits:
-                model.aux_logits = False
-                model.AuxLogits = None
-            model_url = quant_model_urls["inception_v3_google_" + backend]
-        else:
-            model_url = inception_module.model_urls["inception_v3_google"]
-
-        state_dict = load_state_dict_from_url(model_url, progress=progress)
-
-        model.load_state_dict(state_dict)
-
-        if not quantize:
-            if not original_aux_logits:
-                model.aux_logits = False
-                model.AuxLogits = None
-    return model
-
-
 class QuantizableBasicConv2d(inception_module.BasicConv2d):
     def __init__(self, *args: Any, **kwargs: Any) -> None:
         super().__init__(*args, **kwargs)
@@ -237,3 +171,68 @@ def fuse_model(self) -> None:
     for m in self.modules():
         if type(m) is QuantizableBasicConv2d:
             m.fuse_model()
+
+
+def inception_v3(
+    pretrained: bool = False,
+    progress: bool = True,
+    quantize: bool = False,
+    **kwargs: Any,
+) -> QuantizableInception3:
+    r"""Inception v3 model architecture from
+    `"Rethinking the Inception Architecture for Computer Vision" <http://arxiv.org/abs/1512.00567>`_.
+
+    .. note::
+        **Important**: In contrast to the other models the inception_v3 expects tensors with a size of
+        N x 3 x 299 x 299, so ensure your images are sized accordingly.
+
+    Note that quantize = True returns a quantized model with 8 bit
+    weights. Quantized models only support inference and run on CPUs.
+    GPU inference is not yet supported
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+        quantize (bool): If True, return a quantized version of the model
+        aux_logits (bool): If True, add an auxiliary branch that can improve training.
+            Default: *True*
+        transform_input (bool): If True, preprocesses the input according to the method with which it
+            was trained on ImageNet. Default: *False*
+    """
+    if pretrained:
+        if "transform_input" not in kwargs:
+            kwargs["transform_input"] = True
+        if "aux_logits" in kwargs:
+            original_aux_logits = kwargs["aux_logits"]
+            kwargs["aux_logits"] = True
+        else:
+            original_aux_logits = False
+
+    model = QuantizableInception3(**kwargs)
+    _replace_relu(model)
+
+    if quantize:
+        # TODO use pretrained as a string to specify the backend
+        backend = "fbgemm"
+        quantize_model(model, backend)
+    else:
+        assert pretrained in [True, False]
+
+    if pretrained:
+        if quantize:
+            if not original_aux_logits:
+                model.aux_logits = False
+                model.AuxLogits = None
+            model_url = quant_model_urls["inception_v3_google_" + backend]
+        else:
+            model_url = inception_module.model_urls["inception_v3_google"]
+
+        state_dict = load_state_dict_from_url(model_url, progress=progress)
+
+        model.load_state_dict(state_dict)
+
+        if not quantize:
+            if not original_aux_logits:
+                model.aux_logits = False
+                model.AuxLogits = None
+    return model
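
For context, the relocated builder keeps the legacy boolean interface unchanged; a minimal usage sketch (assuming a torchvision build containing this commit) follows:

    import torch
    from torchvision.models.quantization import inception_v3

    # Legacy flags preserved by the move: ImageNet weights plus fbgemm quantization.
    model = inception_v3(pretrained=True, quantize=True)
    model.eval()

    # As the docstring above notes: inputs are N x 3 x 299 x 299 and the
    # quantized model only supports inference on CPU.
    x = torch.rand(1, 3, 299, 299)
    with torch.no_grad():
        out = model(x)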

torchvision/prototype/models/inception.py

Lines changed: 5 additions & 5 deletions
@@ -10,10 +10,10 @@
 from ._meta import _IMAGENET_CATEGORIES
 
 
-__all__ = ["Inception3", "InceptionOutputs", "_InceptionOutputs", "Inception3Weights", "inception_v3"]
+__all__ = ["Inception3", "InceptionOutputs", "_InceptionOutputs", "InceptionV3Weights", "inception_v3"]
 
 
-class Inception3Weights(Weights):
+class InceptionV3Weights(Weights):
     ImageNet1K_TFV1 = WeightEntry(
         url="https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth",
         transforms=partial(ImageNetEval, crop_size=299, resize_size=342),
@@ -28,11 +28,11 @@ class Inception3Weights(Weights):
     )
 
 
-def inception_v3(weights: Optional[Inception3Weights] = None, progress: bool = True, **kwargs: Any) -> Inception3:
+def inception_v3(weights: Optional[InceptionV3Weights] = None, progress: bool = True, **kwargs: Any) -> Inception3:
     if "pretrained" in kwargs:
         warnings.warn("The argument pretrained is deprecated, please use weights instead.")
-        weights = Inception3Weights.ImageNet1K_TFV1 if kwargs.pop("pretrained") else None
-    weights = Inception3Weights.verify(weights)
+        weights = InceptionV3Weights.ImageNet1K_TFV1 if kwargs.pop("pretrained") else None
+    weights = InceptionV3Weights.verify(weights)
 
     original_aux_logits = kwargs.get("aux_logits", True)
     if weights is not None:
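
A minimal sketch of the renamed enum in use through the prototype builder; the names come straight from this diff, while the import path assumes the prototype package re-exports them:

    from torchvision.prototype.models import InceptionV3Weights, inception_v3

    # Float (unquantized) checkpoint ported from TensorFlow ("TFV1").
    weights = InceptionV3Weights.ImageNet1K_TFV1
    model = inception_v3(weights=weights)
    model.eval()

    # weights=None skips the download and returns a randomly initialized model.
    scratch = inception_v3(weights=None)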
torchvision/prototype/models/quantization/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -1,2 +1,3 @@
 from .googlenet import *
+from .inception import *
 from .resnet import *
torchvision/prototype/models/quantization/inception.py

Lines changed: 87 additions & 0 deletions

@@ -0,0 +1,87 @@
+import warnings
+from functools import partial
+from typing import Any, Optional, Union
+
+from torchvision.transforms.functional import InterpolationMode
+
+from ....models.quantization.inception import (
+    QuantizableInception3,
+    _replace_relu,
+    quantize_model,
+)
+from ...transforms.presets import ImageNetEval
+from .._api import Weights, WeightEntry
+from .._meta import _IMAGENET_CATEGORIES
+from ..inception import InceptionV3Weights
+
+
+__all__ = [
+    "QuantizableInception3",
+    "QuantizedInceptionV3Weights",
+    "inception_v3",
+]
+
+
+class QuantizedInceptionV3Weights(Weights):
+    ImageNet1K_FBGEMM_TFV1 = WeightEntry(
+        url="https://download.pytorch.org/models/quantized/inception_v3_google_fbgemm-71447a44.pth",
+        transforms=partial(ImageNetEval, crop_size=299, resize_size=342),
+        meta={
+            "size": (299, 299),
+            "categories": _IMAGENET_CATEGORIES,
+            "interpolation": InterpolationMode.BILINEAR,
+            "backend": "fbgemm",
+            "quantization": "ptq",
+            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models",
+            "unquantized": InceptionV3Weights.ImageNet1K_TFV1,
+            "acc@1": 77.176,
+            "acc@5": 93.354,
+        },
+    )
+
+
+def inception_v3(
+    weights: Optional[Union[QuantizedInceptionV3Weights, InceptionV3Weights]] = None,
+    progress: bool = True,
+    quantize: bool = False,
+    **kwargs: Any,
+) -> QuantizableInception3:
+    if "pretrained" in kwargs:
+        warnings.warn("The argument pretrained is deprecated, please use weights instead.")
+        if kwargs.pop("pretrained"):
+            weights = (
+                QuantizedInceptionV3Weights.ImageNet1K_FBGEMM_TFV1 if quantize else InceptionV3Weights.ImageNet1K_TFV1
+            )
+        else:
+            weights = None
+
+    if quantize:
+        weights = QuantizedInceptionV3Weights.verify(weights)
+    else:
+        weights = InceptionV3Weights.verify(weights)
+
+    original_aux_logits = kwargs.get("aux_logits", False)
+    if weights is not None:
+        if "transform_input" not in kwargs:
+            kwargs["transform_input"] = True
+        kwargs["aux_logits"] = True
+        kwargs["num_classes"] = len(weights.meta["categories"])
+        if "backend" in weights.meta:
+            kwargs["backend"] = weights.meta["backend"]
+    backend = kwargs.pop("backend", "fbgemm")
+
+    model = QuantizableInception3(**kwargs)
+    _replace_relu(model)
+    if quantize:
+        quantize_model(model, backend)
+
+    if weights is not None:
+        if quantize and not original_aux_logits:
+            model.aux_logits = False
+            model.AuxLogits = None
+        model.load_state_dict(weights.state_dict(progress=progress))
+        if not quantize and not original_aux_logits:
+            model.aux_logits = False
+            model.AuxLogits = None
+
+    return model
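
Finally, a small sketch of the new multi-weight quantized builder, based on the names introduced above; the import path assumes the star-import added to the quantization __init__ in this commit:

    from torchvision.prototype.models.quantization import (
        QuantizedInceptionV3Weights,
        inception_v3,
    )

    # Quantized path: the weight entry's meta pins the backend ("fbgemm") and
    # records the checkpoint's ImageNet accuracy (acc@1 77.176, acc@5 93.354).
    weights = QuantizedInceptionV3Weights.ImageNet1K_FBGEMM_TFV1
    model = inception_v3(weights=weights, quantize=True)
    model.eval()

    # The deprecated flag still works and maps to the matching entry, with a warning:
    legacy = inception_v3(pretrained=True, quantize=True)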
