# Adding multiweight support to Quantized ResNet #4827

Merged (6 commits, Nov 2, 2021)

`references/classification/README.md` (8 additions, 3 deletions)

@@ -151,9 +151,9 @@ torchrun --nproc_per_node=8 train.py\

## Quantized

-### Parameters used for generating quantized models:
+### Post training quantized models

-For all post training quantized models (All quantized models except mobilenet-v2), the settings are:
+For all post training quantized models, the settings are:

1. num_calibration_batches: 32
2. num_workers: 16
@@ -162,8 +162,11 @@ For all post training quantized models (All quantized models except mobilenet-v2
5. backend: 'fbgemm'

```
-python train_quantization.py --device='cpu' --post-training-quantize --backend='fbgemm' --model='<model_name>'
+python train_quantization.py --device='cpu' --post-training-quantize --backend='fbgemm' --model='$MODEL'
```
+Here `$MODEL` is one of `googlenet`, `inception_v3`, `resnet18`, `resnet50`, `resnext101_32x8d` and `shufflenet_v2_x1_0`.
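
The `--post-training-quantize` path implements standard eager-mode post-training static quantization. A minimal sketch of what the script does with the settings above, assuming the `torch.quantization` API of this era (the real script also handles ImageNet data loading, evaluation and checkpointing):

```
import torch
import torchvision

# Post-training static quantization sketch (eager mode, fbgemm backend).
model = torchvision.models.quantization.resnet18(pretrained=True, quantize=False)
model.eval()
model.fuse_model()  # fuse Conv+BN+ReLU blocks before inserting observers

model.qconfig = torch.quantization.get_default_qconfig("fbgemm")
torch.quantization.prepare(model, inplace=True)

# Calibrate the observers on representative batches (num_calibration_batches: 32
# above); random data stands in for real ImageNet batches here.
with torch.no_grad():
    for _ in range(32):
        model(torch.randn(32, 3, 224, 224))

torch.quantization.convert(model, inplace=True)  # produce the INT8 model
```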

+### QAT MobileNetV2

For Mobilenet-v2, the model was trained with quantization aware training; the settings used are:
1. num_workers: 16
@@ -185,6 +188,8 @@ torchrun --nproc_per_node=8 train_quantization.py --model='mobilenet_v2'

Training converges at about 10 epochs.
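
By contrast with the post-training flow above, quantization-aware training keeps fake-quantization modules in the graph while the model trains. A minimal sketch, again assuming the eager-mode `torch.quantization` API (the reference script adds distributed training, the LR schedule, and observer/batch-norm freezing):

```
import torch
import torchvision

# Quantization-aware training sketch (eager mode, fbgemm backend).
model = torchvision.models.quantization.mobilenet_v2(pretrained=True, quantize=False)
model.fuse_model()
model.qconfig = torch.quantization.get_default_qat_qconfig("fbgemm")
torch.quantization.prepare_qat(model, inplace=True)

# ... run the usual training loop here; fake-quant modules learn the ranges
# (converges at about 10 epochs, per the note above) ...

model.eval()
torch.quantization.convert(model, inplace=True)  # final INT8 model
```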

+### QAT MobileNetV3

For Mobilenet-v3 Large, the model was trained with quantization aware training; the settings used are:
1. num_workers: 16
2. batch_size: 32

`references/classification/train.py` (1 addition, 1 deletion)

@@ -153,7 +153,7 @@ def load_data(traindir, valdir, args):
crop_size=val_crop_size, resize_size=val_resize_size, interpolation=interpolation
)
else:
-        fn = PM.__dict__[args.model]
+        fn = PM.quantization.__dict__[args.model] if hasattr(args, "backend") else PM.__dict__[args.model]
weights = PM._api.get_weight(fn, args.weights)
preprocessing = weights.transforms()
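
Only `train_quantization.py` registers a `--backend` argument, so the `hasattr(args, "backend")` check is what routes quantized runs to the `PM.quantization` namespace here. A small sketch of the name-to-enum resolution this relies on, mirroring the `test_get_weight` case added further down:

```
from torchvision.prototype import models as PM

# get_weight maps a builder plus a weight-enum member's name to the matching
# entry, exactly as load_data() does with args.weights.
fn = PM.quantization.resnet50
weights = PM._api.get_weight(fn, "ImageNet1K_FBGEMM_RefV1")
assert weights == PM.quantization.QuantizedResNet50Weights.ImageNet1K_FBGEMM_RefV1
preprocessing = weights.transforms()  # ImageNetEval with crop_size=224
```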

`references/classification/train_quantization.py` (15 additions, 1 deletion)

@@ -12,6 +12,12 @@
from train import train_one_epoch, evaluate, load_data


+try:
+    from torchvision.prototype import models as PM
+except ImportError:
+    PM = None


def main(args):
if args.output_dir:
utils.mkdir(args.output_dir)
@@ -46,7 +52,12 @@ def main(args):

print("Creating model", args.model)
# when training quantized models, we always start from a pre-trained fp32 reference model
-    model = torchvision.models.quantization.__dict__[args.model](pretrained=True, quantize=args.test_only)
+    if not args.weights:
+        model = torchvision.models.quantization.__dict__[args.model](pretrained=True, quantize=args.test_only)
+    else:
+        if PM is None:
+            raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
+        model = PM.quantization.__dict__[args.model](weights=args.weights, quantize=args.test_only)
model.to(device)

if not (args.test_only or args.post_training_quantize):
@@ -251,6 +262,9 @@ def get_args_parser(add_help=True):
"--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)"
)

+    # Prototype models only
+    parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")

return parser
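
With the flag in place, a hypothetical evaluation run of the new quantized weights (all other flags already exist in this script) could look like `python train_quantization.py --device='cpu' --test-only --backend='fbgemm' --model='resnet50' --weights='ImageNet1K_FBGEMM_RefV1'`.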


`test/test_prototype_models.py` (19 additions, 4 deletions)

@@ -30,10 +30,15 @@ def get_models_with_module_names(module):
return [(fn, module_name) for fn in TM.get_models_from_module(module)]


-def test_get_weight():
-    fn = models.resnet50
-    weight_name = "ImageNet1K_RefV2"
-    assert models._api.get_weight(fn, weight_name) == models.ResNet50Weights.ImageNet1K_RefV2
+@pytest.mark.parametrize(
+    "model_fn, weight",
+    [
+        (models.resnet50, models.ResNet50Weights.ImageNet1K_RefV2),
+        (models.quantization.resnet50, models.quantization.QuantizedResNet50Weights.ImageNet1K_FBGEMM_RefV1),
+    ],
+)
+def test_get_weight(model_fn, weight):
+    assert models._api.get_weight(model_fn, weight.name) == weight


@pytest.mark.parametrize("model_fn", TM.get_models_from_module(models))
@@ -43,6 +48,12 @@ def test_classification_model(model_fn, dev):
TM.test_classification_model(model_fn, dev)


+@pytest.mark.parametrize("model_fn", TM.get_models_from_module(models.quantization))
+@pytest.mark.skipif(os.getenv("PYTORCH_TEST_WITH_PROTOTYPE", "0") == "0", reason="Prototype code tests are disabled")
+def test_quantized_classification_model(model_fn):
+    TM.test_quantized_classification_model(model_fn)


@pytest.mark.parametrize("model_fn", TM.get_models_from_module(models.segmentation))
@pytest.mark.parametrize("dev", cpu_and_gpu())
@pytest.mark.skipif(os.getenv("PYTORCH_TEST_WITH_PROTOTYPE", "0") == "0", reason="Prototype code tests are disabled")
@@ -60,6 +71,7 @@ def test_video_model(model_fn, dev):
@pytest.mark.parametrize(
"model_fn, module_name",
get_models_with_module_names(models)
+    + get_models_with_module_names(models.quantization)
+ get_models_with_module_names(models.segmentation)
+ get_models_with_module_names(models.video),
)
@@ -70,6 +82,9 @@ def test_old_vs_new_factory(model_fn, module_name, dev):
"models": {
"input_shape": (1, 3, 224, 224),
},
"quantization": {
"input_shape": (1, 3, 224, 224),
},
"segmentation": {
"input_shape": (1, 3, 520, 520),
},

`torchvision/prototype/models/quantization/resnet.py` (83 additions, 3 deletions)

@@ -12,10 +12,18 @@
from ...transforms.presets import ImageNetEval
from .._api import Weights, WeightEntry
from .._meta import _IMAGENET_CATEGORIES
-from ..resnet import ResNet50Weights
+from ..resnet import ResNet18Weights, ResNet50Weights, ResNeXt101_32x8dWeights


__all__ = ["QuantizableResNet", "QuantizedResNet50Weights", "resnet50"]
__all__ = [
"QuantizableResNet",
"QuantizedResNet18Weights",
"QuantizedResNet50Weights",
"QuantizedResNeXt101_32x8dWeights",
"resnet18",
"resnet50",
"resnext101_32x8d",
]


def _resnet(
@@ -47,22 +55,67 @@ def _resnet(
"size": (224, 224),
"categories": _IMAGENET_CATEGORIES,
"backend": "fbgemm",
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models",
}


+class QuantizedResNet18Weights(Weights):
+    ImageNet1K_FBGEMM_RefV1 = WeightEntry(
+        url="https://download.pytorch.org/models/quantized/resnet18_fbgemm_16fa66dd.pth",
+        transforms=partial(ImageNetEval, crop_size=224),
+        meta={
+            **_common_meta,
+            "acc@1": 69.494,
+            "acc@5": 88.882,
+        },
+    )


class QuantizedResNet50Weights(Weights):
    ImageNet1K_FBGEMM_RefV1 = WeightEntry(
        url="https://download.pytorch.org/models/quantized/resnet50_fbgemm_bf931d71.pth",
        transforms=partial(ImageNetEval, crop_size=224),
        meta={
            **_common_meta,
-            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#quantized",
            "acc@1": 75.920,
            "acc@5": 92.814,
        },
    )


+class QuantizedResNeXt101_32x8dWeights(Weights):
+    ImageNet1K_FBGEMM_RefV1 = WeightEntry(
+        url="https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm_09835ccf.pth",
+        transforms=partial(ImageNetEval, crop_size=224),
+        meta={
+            **_common_meta,
+            "acc@1": 78.986,
+            "acc@5": 94.480,
+        },
+    )


+def resnet18(
+    weights: Optional[Union[QuantizedResNet18Weights, ResNet18Weights]] = None,
+    progress: bool = True,
+    quantize: bool = False,
+    **kwargs: Any,
+) -> QuantizableResNet:
+    if "pretrained" in kwargs:
+        warnings.warn("The argument pretrained is deprecated, please use weights instead.")
+        if kwargs.pop("pretrained"):
+            weights = QuantizedResNet18Weights.ImageNet1K_FBGEMM_RefV1 if quantize else ResNet18Weights.ImageNet1K_RefV1
+        else:
+            weights = None
+
+    if quantize:
+        weights = QuantizedResNet18Weights.verify(weights)
+    else:
+        weights = ResNet18Weights.verify(weights)
+
+    return _resnet(QuantizableBasicBlock, [2, 2, 2, 2], weights, progress, quantize, **kwargs)


def resnet50(
weights: Optional[Union[QuantizedResNet50Weights, ResNet50Weights]] = None,
progress: bool = True,
@@ -82,3 +135,30 @@ def resnet50(
weights = ResNet50Weights.verify(weights)

return _resnet(QuantizableBottleneck, [3, 4, 6, 3], weights, progress, quantize, **kwargs)


+def resnext101_32x8d(
+    weights: Optional[Union[QuantizedResNeXt101_32x8dWeights, ResNeXt101_32x8dWeights]] = None,
+    progress: bool = True,
+    quantize: bool = False,
+    **kwargs: Any,
+) -> QuantizableResNet:
+    if "pretrained" in kwargs:
+        warnings.warn("The argument pretrained is deprecated, please use weights instead.")
+        if kwargs.pop("pretrained"):
+            weights = (
+                QuantizedResNeXt101_32x8dWeights.ImageNet1K_FBGEMM_RefV1
+                if quantize
+                else ResNeXt101_32x8dWeights.ImageNet1K_RefV1
+            )
+        else:
+            weights = None
+
+    if quantize:
+        weights = QuantizedResNeXt101_32x8dWeights.verify(weights)
+    else:
+        weights = ResNeXt101_32x8dWeights.verify(weights)
+
+    kwargs["groups"] = 32
+    kwargs["width_per_group"] = 8
+    return _resnet(QuantizableBottleneck, [3, 4, 23, 3], weights, progress, quantize, **kwargs)
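
Taken together, the new enums and builders make the weight choice explicit at the call site. A minimal usage sketch of the prototype API added in this file (input shape per the test configuration above; the random tensor is a stand-in for a real decoded image):

```
import torch
from torchvision.prototype import models as PM

# Build the INT8 model from an explicit weight entry (new multiweight API).
weights = PM.quantization.QuantizedResNet18Weights.ImageNet1K_FBGEMM_RefV1
model = PM.quantization.resnet18(weights=weights, quantize=True)
model.eval()

# The weight entry carries its own preprocessing (ImageNetEval, crop_size=224).
preprocess = weights.transforms()
img = torch.rand(3, 256, 256)         # stand-in for a decoded image tensor
batch = preprocess(img).unsqueeze(0)  # resize/crop/normalize to (1, 3, 224, 224)
with torch.no_grad():
    scores = model(batch)             # shape (1, 1000), the ImageNet categories
```

Passing `quantize=True` with a `QuantizedResNet18Weights` entry loads the fbgemm INT8 checkpoint; passing a plain `ResNet18Weights` entry with `quantize=False` falls back to the fp32 reference model.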