Commit b43353e

Adding multiweight support to Quantized GoogLeNet (#4848)
* Reordering the builders to use proper typing.
* Adding additional meta-data on existing quantized models.
* Fixing meta on unquantized model.
* Adding quantized googlenet builder.
* Undo inception move.
* Adding recipe information.
1 parent 3300692 commit b43353e

File tree

7 files changed (+166, -66 lines)
references/classification/README.md

Lines changed: 4 additions & 0 deletions

@@ -31,6 +31,10 @@ Here `$MODEL` is one of `alexnet`, `vgg11`, `vgg13`, `vgg16` or `vgg19`. Note
 that `vgg11_bn`, `vgg13_bn`, `vgg16_bn`, and `vgg19_bn` include batch
 normalization and thus are trained with the default parameters.
 
+### GoogLeNet
+
+The weights of the GoogLeNet model are ported from the original paper rather than trained from scratch.
+
 ### Inception V3
 
 The weights of the Inception V3 model are ported from the original paper rather than trained from scratch.

test/test_prototype_models.py

Lines changed: 1 addition & 0 deletions

@@ -95,6 +95,7 @@ def test_old_vs_new_factory(model_fn, module_name, dev):
     },
     "quantization": {
         "input_shape": (1, 3, 224, 224),
+        "quantize": True,
     },
     "segmentation": {
         "input_shape": (1, 3, 520, 520),

torchvision/models/quantization/googlenet.py

Lines changed: 62 additions & 63 deletions
@@ -19,69 +19,6 @@
 }
 
 
-def googlenet(
-    pretrained: bool = False,
-    progress: bool = True,
-    quantize: bool = False,
-    **kwargs: Any,
-) -> "QuantizableGoogLeNet":
-
-    r"""GoogLeNet (Inception v1) model architecture from
-    `"Going Deeper with Convolutions" <http://arxiv.org/abs/1409.4842>`_.
-
-    Note that quantize = True returns a quantized model with 8 bit
-    weights. Quantized models only support inference and run on CPUs.
-    GPU inference is not yet supported
-
-    Args:
-        pretrained (bool): If True, returns a model pre-trained on ImageNet
-        progress (bool): If True, displays a progress bar of the download to stderr
-        quantize (bool): If True, return a quantized version of the model
-        aux_logits (bool): If True, adds two auxiliary branches that can improve training.
-            Default: *False* when pretrained is True otherwise *True*
-        transform_input (bool): If True, preprocesses the input according to the method with which it
-            was trained on ImageNet. Default: *False*
-    """
-    if pretrained:
-        if "transform_input" not in kwargs:
-            kwargs["transform_input"] = True
-        if "aux_logits" not in kwargs:
-            kwargs["aux_logits"] = False
-        if kwargs["aux_logits"]:
-            warnings.warn(
-                "auxiliary heads in the pretrained googlenet model are NOT pretrained, so make sure to train them"
-            )
-        original_aux_logits = kwargs["aux_logits"]
-        kwargs["aux_logits"] = True
-        kwargs["init_weights"] = False
-
-    model = QuantizableGoogLeNet(**kwargs)
-    _replace_relu(model)
-
-    if quantize:
-        # TODO use pretrained as a string to specify the backend
-        backend = "fbgemm"
-        quantize_model(model, backend)
-    else:
-        assert pretrained in [True, False]
-
-    if pretrained:
-        if quantize:
-            model_url = quant_model_urls["googlenet_" + backend]
-        else:
-            model_url = model_urls["googlenet"]
-
-        state_dict = load_state_dict_from_url(model_url, progress=progress)
-
-        model.load_state_dict(state_dict)
-
-        if not original_aux_logits:
-            model.aux_logits = False
-            model.aux1 = None  # type: ignore[assignment]
-            model.aux2 = None  # type: ignore[assignment]
-    return model
-
-
 class QuantizableBasicConv2d(BasicConv2d):
     def __init__(self, *args: Any, **kwargs: Any) -> None:
         super().__init__(*args, **kwargs)
@@ -164,3 +101,65 @@ def fuse_model(self) -> None:
         for m in self.modules():
             if type(m) is QuantizableBasicConv2d:
                 m.fuse_model()
+
+
+def googlenet(
+    pretrained: bool = False,
+    progress: bool = True,
+    quantize: bool = False,
+    **kwargs: Any,
+) -> QuantizableGoogLeNet:
+    r"""GoogLeNet (Inception v1) model architecture from
+    `"Going Deeper with Convolutions" <http://arxiv.org/abs/1409.4842>`_.
+
+    Note that quantize = True returns a quantized model with 8 bit
+    weights. Quantized models only support inference and run on CPUs.
+    GPU inference is not yet supported
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+        quantize (bool): If True, return a quantized version of the model
+        aux_logits (bool): If True, adds two auxiliary branches that can improve training.
+            Default: *False* when pretrained is True otherwise *True*
+        transform_input (bool): If True, preprocesses the input according to the method with which it
+            was trained on ImageNet. Default: *False*
+    """
+    if pretrained:
+        if "transform_input" not in kwargs:
+            kwargs["transform_input"] = True
+        if "aux_logits" not in kwargs:
+            kwargs["aux_logits"] = False
+        if kwargs["aux_logits"]:
+            warnings.warn(
+                "auxiliary heads in the pretrained googlenet model are NOT pretrained, so make sure to train them"
+            )
+        original_aux_logits = kwargs["aux_logits"]
+        kwargs["aux_logits"] = True
+        kwargs["init_weights"] = False
+
+    model = QuantizableGoogLeNet(**kwargs)
+    _replace_relu(model)
+
+    if quantize:
+        # TODO use pretrained as a string to specify the backend
+        backend = "fbgemm"
+        quantize_model(model, backend)
+    else:
+        assert pretrained in [True, False]
+
+    if pretrained:
+        if quantize:
+            model_url = quant_model_urls["googlenet_" + backend]
+        else:
+            model_url = model_urls["googlenet"]
+
+        state_dict = load_state_dict_from_url(model_url, progress=progress)
+
+        model.load_state_dict(state_dict)
+
+        if not original_aux_logits:
+            model.aux_logits = False
+            model.aux1 = None  # type: ignore[assignment]
+            model.aux2 = None  # type: ignore[assignment]
+    return model
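
As the docstring notes, quantize=True yields a model with 8-bit weights intended for CPU-only inference; a minimal usage sketch of the relocated builder:

import torch
from torchvision.models.quantization import googlenet

# Pretrained INT8 GoogLeNet quantized with the fbgemm backend; CPU inference only.
model = googlenet(pretrained=True, quantize=True)
model.eval()

with torch.no_grad():
    logits = model(torch.rand(1, 3, 224, 224))
print(logits.shape)  # torch.Size([1, 1000])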

torchvision/prototype/models/googlenet.py

Lines changed: 3 additions & 3 deletions
@@ -14,14 +14,14 @@
 
 
 class GoogLeNetWeights(Weights):
-    ImageNet1K_Community = WeightEntry(
+    ImageNet1K_TFV1 = WeightEntry(
         url="https://download.pytorch.org/models/googlenet-1378be20.pth",
         transforms=partial(ImageNetEval, crop_size=224),
         meta={
             "size": (224, 224),
             "categories": _IMAGENET_CATEGORIES,
             "interpolation": InterpolationMode.BILINEAR,
-            "recipe": "https://github.com/TheCodez/examples/blob/inception/imagenet/README.md#googlenet",
+            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#googlenet",
             "acc@1": 69.778,
             "acc@5": 89.530,
         },
@@ -31,7 +31,7 @@ class GoogLeNetWeights(Weights):
 def googlenet(weights: Optional[GoogLeNetWeights] = None, progress: bool = True, **kwargs: Any) -> GoogLeNet:
     if "pretrained" in kwargs:
         warnings.warn("The argument pretrained is deprecated, please use weights instead.")
-        weights = GoogLeNetWeights.ImageNet1K_Community if kwargs.pop("pretrained") else None
+        weights = GoogLeNetWeights.ImageNet1K_TFV1 if kwargs.pop("pretrained") else None
     weights = GoogLeNetWeights.verify(weights)
 
     original_aux_logits = kwargs.get("aux_logits", False)
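
In the prototype multi-weight API, the deprecated pretrained=True flag is mapped to the renamed ImageNet1K_TFV1 entry; a small sketch of the two equivalent calls, assuming GoogLeNetWeights is re-exported from torchvision.prototype.models:

from torchvision.prototype.models import GoogLeNetWeights, googlenet

# Preferred: select the checkpoint explicitly via the weights enum.
model_new = googlenet(weights=GoogLeNetWeights.ImageNet1K_TFV1)

# Deprecated: pretrained=True warns and resolves to the same entry.
model_old = googlenet(pretrained=True)

# Each entry carries its metadata, e.g. the reported ImageNet accuracy.
print(GoogLeNetWeights.ImageNet1K_TFV1.meta["acc@1"])  # 69.778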

torchvision/prototype/models/quantization/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -1 +1,2 @@
+from .googlenet import *
 from .resnet import *

torchvision/prototype/models/quantization/googlenet.py

Lines changed: 88 additions & 0 deletions

@@ -0,0 +1,88 @@
+import warnings
+from functools import partial
+from typing import Any, Optional, Union
+
+from torchvision.transforms.functional import InterpolationMode
+
+from ....models.quantization.googlenet import (
+    QuantizableGoogLeNet,
+    _replace_relu,
+    quantize_model,
+)
+from ...transforms.presets import ImageNetEval
+from .._api import Weights, WeightEntry
+from .._meta import _IMAGENET_CATEGORIES
+from ..googlenet import GoogLeNetWeights
+
+
+__all__ = [
+    "QuantizableGoogLeNet",
+    "QuantizedGoogLeNetWeights",
+    "googlenet",
+]
+
+
+class QuantizedGoogLeNetWeights(Weights):
+    ImageNet1K_FBGEMM_TFV1 = WeightEntry(
+        url="https://download.pytorch.org/models/quantized/googlenet_fbgemm-c00238cf.pth",
+        transforms=partial(ImageNetEval, crop_size=224),
+        meta={
+            "size": (224, 224),
+            "categories": _IMAGENET_CATEGORIES,
+            "interpolation": InterpolationMode.BILINEAR,
+            "backend": "fbgemm",
+            "quantization": "ptq",
+            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models",
+            "unquantized": GoogLeNetWeights.ImageNet1K_TFV1,
+            "acc@1": 69.826,
+            "acc@5": 89.404,
+        },
+    )
+
+
+def googlenet(
+    weights: Optional[Union[QuantizedGoogLeNetWeights, GoogLeNetWeights]] = None,
+    progress: bool = True,
+    quantize: bool = False,
+    **kwargs: Any,
+) -> QuantizableGoogLeNet:
+    if "pretrained" in kwargs:
+        warnings.warn("The argument pretrained is deprecated, please use weights instead.")
+        if kwargs.pop("pretrained"):
+            weights = QuantizedGoogLeNetWeights.ImageNet1K_FBGEMM_TFV1 if quantize else GoogLeNetWeights.ImageNet1K_TFV1
+        else:
+            weights = None
+
+    if quantize:
+        weights = QuantizedGoogLeNetWeights.verify(weights)
+    else:
+        weights = GoogLeNetWeights.verify(weights)
+
+    original_aux_logits = kwargs.get("aux_logits", False)
+    if weights is not None:
+        if "transform_input" not in kwargs:
+            kwargs["transform_input"] = True
+        if original_aux_logits:
+            warnings.warn(
+                "auxiliary heads in the pretrained googlenet model are NOT pretrained, so make sure to train them"
+            )
+        kwargs["aux_logits"] = True
+        kwargs["init_weights"] = False
+        kwargs["num_classes"] = len(weights.meta["categories"])
+        if "backend" in weights.meta:
+            kwargs["backend"] = weights.meta["backend"]
+    backend = kwargs.pop("backend", "fbgemm")
+
+    model = QuantizableGoogLeNet(**kwargs)
+    _replace_relu(model)
+    if quantize:
+        quantize_model(model, backend)
+
+    if weights is not None:
+        model.load_state_dict(weights.state_dict(progress=progress))
+        if not original_aux_logits:
+            model.aux_logits = False
+            model.aux1 = None  # type: ignore[assignment]
+            model.aux2 = None  # type: ignore[assignment]
+
+    return model
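
A minimal usage sketch of the new quantized builder (the import path assumes the __init__.py re-export added above); with quantize=True the weights are verified against QuantizedGoogLeNetWeights and the fbgemm backend is taken from the entry's metadata:

import torch
from torchvision.prototype.models.quantization import QuantizedGoogLeNetWeights, googlenet

weights = QuantizedGoogLeNetWeights.ImageNet1K_FBGEMM_TFV1

# Build the INT8 model and load the post-training-quantized checkpoint.
model = googlenet(weights=weights, quantize=True)
model.eval()

with torch.no_grad():
    logits = model(torch.rand(1, 3, 224, 224))

# The meta links back to the floating-point weights the checkpoint was derived from.
print(weights.meta["acc@1"], weights.meta["unquantized"])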

torchvision/prototype/models/quantization/resnet.py

Lines changed: 7 additions & 0 deletions
@@ -2,6 +2,8 @@
 from functools import partial
 from typing import Any, List, Optional, Type, Union
 
+from torchvision.transforms.functional import InterpolationMode
+
 from ....models.quantization.resnet import (
     QuantizableBasicBlock,
     QuantizableBottleneck,
@@ -54,7 +56,9 @@ def _resnet(
 _common_meta = {
     "size": (224, 224),
     "categories": _IMAGENET_CATEGORIES,
+    "interpolation": InterpolationMode.BILINEAR,
     "backend": "fbgemm",
+    "quantization": "ptq",
     "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models",
 }
 
@@ -65,6 +69,7 @@ class QuantizedResNet18Weights(Weights):
         transforms=partial(ImageNetEval, crop_size=224),
         meta={
             **_common_meta,
+            "unquantized": ResNet18Weights.ImageNet1K_RefV1,
             "acc@1": 69.494,
             "acc@5": 88.882,
         },
@@ -77,6 +82,7 @@ class QuantizedResNet50Weights(Weights):
         transforms=partial(ImageNetEval, crop_size=224),
         meta={
             **_common_meta,
+            "unquantized": ResNet50Weights.ImageNet1K_RefV1,
             "acc@1": 75.920,
             "acc@5": 92.814,
         },
@@ -89,6 +95,7 @@ class QuantizedResNeXt101_32x8dWeights(Weights):
         transforms=partial(ImageNetEval, crop_size=224),
         meta={
             **_common_meta,
+            "unquantized": ResNeXt101_32x8dWeights.ImageNet1K_RefV1,
             "acc@1": 78.986,
             "acc@5": 94.480,
         },
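
The new "unquantized" key ties each quantized entry back to its floating-point counterpart, which makes the accuracy cost of post-training quantization easy to report; a small sketch (the member name ImageNet1K_FBGEMM_RefV1 is an assumption, since the enum member lines are not part of these hunks):

from torchvision.prototype.models.quantization import QuantizedResNet50Weights

q = QuantizedResNet50Weights.ImageNet1K_FBGEMM_RefV1  # assumed member name
fp32 = q.meta["unquantized"]  # e.g. ResNet50Weights.ImageNet1K_RefV1

# Accuracy delta introduced by fbgemm post-training quantization ("ptq").
drop = fp32.meta["acc@1"] - q.meta["acc@1"]
print(f"{q.meta['backend']}/{q.meta['quantization']}: acc@1 drop = {drop:.3f}")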
