Skip to content

Commit b47ea83

Browse files
authored
Adding multiweight suport on Quant ShuffleNetV2 (#4856)
1 parent 9109d8d commit b47ea83

File tree

2 files changed

+136
-0
lines changed

2 files changed

+136
-0
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
from .googlenet import *
22
from .inception import *
33
from .resnet import *
4+
from .shufflenetv2 import *
Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
import warnings
2+
from functools import partial
3+
from typing import Any, List, Optional, Union
4+
5+
from torchvision.transforms.functional import InterpolationMode
6+
7+
from ....models.quantization.shufflenetv2 import (
8+
QuantizableShuffleNetV2,
9+
_replace_relu,
10+
quantize_model,
11+
)
12+
from ...transforms.presets import ImageNetEval
13+
from .._api import Weights, WeightEntry
14+
from .._meta import _IMAGENET_CATEGORIES
15+
from ..shufflenetv2 import ShuffleNetV2_x0_5Weights, ShuffleNetV2_x1_0Weights
16+
17+
18+
__all__ = [
19+
"QuantizableShuffleNetV2",
20+
"QuantizedShuffleNetV2_x0_5Weights",
21+
"QuantizedShuffleNetV2_x1_0Weights",
22+
"shufflenet_v2_x0_5",
23+
"shufflenet_v2_x1_0",
24+
]
25+
26+
27+
def _shufflenetv2(
28+
stages_repeats: List[int],
29+
stages_out_channels: List[int],
30+
weights: Optional[Weights],
31+
progress: bool,
32+
quantize: bool,
33+
**kwargs: Any,
34+
) -> QuantizableShuffleNetV2:
35+
if weights is not None:
36+
kwargs["num_classes"] = len(weights.meta["categories"])
37+
if "backend" in weights.meta:
38+
kwargs["backend"] = weights.meta["backend"]
39+
backend = kwargs.pop("backend", "fbgemm")
40+
41+
model = QuantizableShuffleNetV2(stages_repeats, stages_out_channels, **kwargs)
42+
_replace_relu(model)
43+
if quantize:
44+
quantize_model(model, backend)
45+
46+
if weights is not None:
47+
model.load_state_dict(weights.state_dict(progress=progress))
48+
49+
return model
50+
51+
52+
_common_meta = {
53+
"size": (224, 224),
54+
"categories": _IMAGENET_CATEGORIES,
55+
"interpolation": InterpolationMode.BILINEAR,
56+
"backend": "fbgemm",
57+
"quantization": "ptq",
58+
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models",
59+
}
60+
61+
62+
class QuantizedShuffleNetV2_x0_5Weights(Weights):
63+
ImageNet1K_FBGEMM_Community = WeightEntry(
64+
url="https://download.pytorch.org/models/quantized/shufflenetv2_x0.5_fbgemm-00845098.pth",
65+
transforms=partial(ImageNetEval, crop_size=224),
66+
meta={
67+
**_common_meta,
68+
"unquantized": ShuffleNetV2_x0_5Weights.ImageNet1K_Community,
69+
"acc@1": 57.972,
70+
"acc@5": 79.780,
71+
},
72+
)
73+
74+
75+
class QuantizedShuffleNetV2_x1_0Weights(Weights):
76+
ImageNet1K_FBGEMM_Community = WeightEntry(
77+
url="https://download.pytorch.org/models/quantized/shufflenetv2_x1_fbgemm-db332c57.pth",
78+
transforms=partial(ImageNetEval, crop_size=224),
79+
meta={
80+
**_common_meta,
81+
"unquantized": ShuffleNetV2_x1_0Weights.ImageNet1K_Community,
82+
"acc@1": 68.360,
83+
"acc@5": 87.582,
84+
},
85+
)
86+
87+
88+
def shufflenet_v2_x0_5(
89+
weights: Optional[Union[QuantizedShuffleNetV2_x0_5Weights, ShuffleNetV2_x0_5Weights]] = None,
90+
progress: bool = True,
91+
quantize: bool = False,
92+
**kwargs: Any,
93+
) -> QuantizableShuffleNetV2:
94+
if "pretrained" in kwargs:
95+
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
96+
if kwargs.pop("pretrained"):
97+
weights = (
98+
QuantizedShuffleNetV2_x0_5Weights.ImageNet1K_FBGEMM_Community
99+
if quantize
100+
else ShuffleNetV2_x0_5Weights.ImageNet1K_Community
101+
)
102+
else:
103+
weights = None
104+
105+
if quantize:
106+
weights = QuantizedShuffleNetV2_x0_5Weights.verify(weights)
107+
else:
108+
weights = ShuffleNetV2_x0_5Weights.verify(weights)
109+
110+
return _shufflenetv2([4, 8, 4], [24, 48, 96, 192, 1024], weights, progress, quantize, **kwargs)
111+
112+
113+
def shufflenet_v2_x1_0(
114+
weights: Optional[Union[QuantizedShuffleNetV2_x1_0Weights, ShuffleNetV2_x1_0Weights]] = None,
115+
progress: bool = True,
116+
quantize: bool = False,
117+
**kwargs: Any,
118+
) -> QuantizableShuffleNetV2:
119+
if "pretrained" in kwargs:
120+
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
121+
if kwargs.pop("pretrained"):
122+
weights = (
123+
QuantizedShuffleNetV2_x1_0Weights.ImageNet1K_FBGEMM_Community
124+
if quantize
125+
else ShuffleNetV2_x1_0Weights.ImageNet1K_Community
126+
)
127+
else:
128+
weights = None
129+
130+
if quantize:
131+
weights = QuantizedShuffleNetV2_x1_0Weights.verify(weights)
132+
else:
133+
weights = ShuffleNetV2_x1_0Weights.verify(weights)
134+
135+
return _shufflenetv2([4, 8, 4], [24, 116, 232, 464, 1024], weights, progress, quantize, **kwargs)

0 commit comments

Comments
 (0)