Multi-pretrained weight support - Quantized ResNet50 #4627

Merged: 4 commits, Oct 18, 2021
2 changes: 1 addition & 1 deletion torchvision/models/quantization/resnet.py
@@ -110,7 +110,7 @@ def fuse_model(self) -> None:

def _resnet(
arch: str,
block: Type[Union[BasicBlock, Bottleneck]],
block: Type[Union[QuantizableBasicBlock, QuantizableBottleneck]],
Contributor Author: Fixing in place a previous "bug" in our typing. The correct types here are the more specific Quantizable versions.

layers: List[int],
pretrained: bool,
progress: bool,
1 change: 1 addition & 0 deletions torchvision/prototype/models/__init__.py
@@ -1,2 +1,3 @@
from .resnet import *
from . import detection
from . import quantization
1 change: 1 addition & 0 deletions torchvision/prototype/models/quantization/__init__.py
@@ -0,0 +1 @@
from .resnet import *
84 changes: 84 additions & 0 deletions torchvision/prototype/models/quantization/resnet.py
@@ -0,0 +1,84 @@
import warnings
from functools import partial
from typing import Any, List, Optional, Type, Union

from ....models.quantization.resnet import (
QuantizableBasicBlock,
QuantizableBottleneck,
QuantizableResNet,
_replace_relu,
quantize_model,
)
from ...transforms.presets import ImageNetEval
from .._api import Weights, WeightEntry
from .._meta import _IMAGENET_CATEGORIES
from ..resnet import ResNet50Weights


__all__ = ["QuantizableResNet", "QuantizedResNet50Weights", "resnet50"]


def _resnet(
block: Type[Union[QuantizableBasicBlock, QuantizableBottleneck]],
layers: List[int],
weights: Optional[Weights],
progress: bool,
quantize: bool,
**kwargs: Any,
) -> QuantizableResNet:
if weights is not None:
kwargs["num_classes"] = len(weights.meta["categories"])
Member: Flagging again #4613 (comment), but we can discuss in a follow-up.

if "backend" in weights.meta:
kwargs["backend"] = weights.meta["backend"]
Contributor Author (@datumbox, Oct 15, 2021): Not all weights are expected to have the "backend" meta, for example when quantize=False is passed. Let's ignore the silent overwriting of parameters (see #4613 (comment)); this can be discussed separately and applied everywhere in a follow-up.

backend = kwargs.pop("backend", "fbgemm")
Contributor Author: The backend is allowed to be a "hidden" kwargs argument. We can decide to make it public in future iterations.


model = QuantizableResNet(block, layers, **kwargs)
_replace_relu(model)
if quantize:
quantize_model(model, backend)

if weights is not None:
model.load_state_dict(weights.state_dict(progress=progress))

return model


_common_meta = {
"size": (224, 224),
"categories": _IMAGENET_CATEGORIES,
"backend": "fbgemm",
}


class QuantizedResNet50Weights(Weights):
Member: Throwing an idea in the wild: as of today (and I think this will be the case for all models), all quantized weights originate from an unquantized weight. Do we want to keep this link somehow in the Weights structure? Do we want the quantized weights to be magically picked from ResNet50Weights if we pass the quantize flag? I think it might be important to keep this relationship somehow.

Contributor Author: We could keep this in the meta-data or, if we are willing to make a commitment, pass it as a proper field of the Weights data class. The reason we don't is that quantization is not very mature at the moment: only classification models have been quantized, and we are examining alternative APIs/approaches (such as FX) to achieve it. For these reasons, I would be in favour of not introducing a direct link and revisiting this decision in the near future.
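The "keep it in the meta-data" option discussed above could look like the following sketch. The `unquantized` key and `origin_of` helper are hypothetical, not part of the real API:

```python
# Hedged sketch: record the unquantized origin of a quantized weights entry as
# plain meta-data (provenance) rather than a hard field on the Weights class.
# "unquantized" and origin_of are illustrative names, not torchvision APIs.
from typing import Any, Dict

UNQUANTIZED_KEY = "unquantized"

quantized_entry_meta: Dict[str, Any] = {
    "backend": "fbgemm",
    UNQUANTIZED_KEY: "ResNet50Weights.ImageNet1K_RefV1",  # provenance only
}


def origin_of(meta: Dict[str, Any]) -> str:
    # Returns the unquantized weights this entry was derived from, if recorded.
    return meta.get(UNQUANTIZED_KEY, "<unknown>")
```

Keeping the link in meta-data leaves the Weights data class unchanged, so the decision can be revisited without a breaking change.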

ImageNet1K_FBGEMM_RefV1 = WeightEntry(
url="https://download.pytorch.org/models/quantized/resnet50_fbgemm_bf931d71.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#quantized",
"acc@1": 75.920,
"acc@5": 92.814,
},
)


def resnet50(
weights: Optional[Union[QuantizedResNet50Weights, ResNet50Weights]] = None,
Contributor Author (@datumbox, Oct 15, 2021): We allow passing both Quantized and normal weights. This is aligned with past behaviour, where different URLs were loaded depending on the value of quantize.

progress: bool = True,
quantize: bool = False,
**kwargs: Any,
) -> QuantizableResNet:
if "pretrained" in kwargs:
Member: nit: you could probably simplify the code a tiny bit by doing something like

if quantize:
    weights_table = QuantizedResNet50Weights
else:
    weights_table = ResNet50Weights
...
weights = weights_table.ImageNet1K_RefV1  # different naming convention than now
weights = weights_table.verify(weights)

In some sense, it's a bit annoying to have to carry those two Weights inside every model builder.

Contributor Author: I decided not to connect the two using the same name because the reference/recipe version might not be the same; a RefV1 for quantized is something different than for non-quantized. For example, say you use a different recipe/config to achieve an even better quantization of the same weights; you would then need a different version of the enum. Moreover, there might be multiple quantized weights enums for the same unquantized weights (for example, if multiple backends should be supported). Some of these points will become clearer in the near future and we can revisit them.
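The reviewer's "weights_table" idea combined with the verify pattern can be sketched as below. The enum members, their values, and this `verify` implementation are illustrative, not the real torchvision prototype code:

```python
# Sketch of a per-mode weights table plus a verify() that rejects weights from
# the wrong enum. Member names/values and verify() are illustrative only.
from enum import Enum


class WeightsEnum(Enum):
    @classmethod
    def verify(cls, obj):
        # Accept None, a member of this enum, or a member name as a string.
        if obj is None:
            return None
        if isinstance(obj, cls):
            return obj
        if isinstance(obj, str) and obj in cls.__members__:
            return cls[obj]
        raise ValueError(f"{obj!r} is not a valid {cls.__name__}")


class ResNet50Weights(WeightsEnum):
    ImageNet1K_RefV1 = "resnet50-v1.pth"


class QuantizedResNet50Weights(WeightsEnum):
    ImageNet1K_FBGEMM_RefV1 = "resnet50-fbgemm.pth"


def pick_table(quantize: bool):
    # The quantize flag selects which enum the builder validates against.
    return QuantizedResNet50Weights if quantize else ResNet50Weights
```

Passing a quantized entry to `ResNet50Weights.verify` raises, which models the "wrong combination throws an error" behaviour of the builder.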

warnings.warn("The argument pretrained is deprecated, please use weights instead.")
if kwargs.pop("pretrained"):
weights = QuantizedResNet50Weights.ImageNet1K_FBGEMM_RefV1 if quantize else ResNet50Weights.ImageNet1K_RefV1
Contributor Author: Which weights we load by default depends on whether we quantize or not.

else:
weights = None

if quantize:
Contributor Author: The validation of the weights also depends on the value of quantize; passing the wrong combination will throw an error.

weights = QuantizedResNet50Weights.verify(weights)
else:
weights = ResNet50Weights.verify(weights)

return _resnet(QuantizableBottleneck, [3, 4, 6, 3], weights, progress, quantize, **kwargs)
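The deprecation path for the legacy `pretrained` flag can be modelled with a self-contained sketch. `resolve_weights` is a hypothetical stand-in, and the strings stand in for the real enum members:

```python
# Sketch of the builder's dispatch logic: the default weights depend on
# quantize, and the deprecated pretrained flag maps onto weights.
# resolve_weights and the string values are illustrative stand-ins.
import warnings
from typing import Any, Optional

_DEFAULT_WEIGHTS = {
    True: "QuantizedResNet50Weights.ImageNet1K_FBGEMM_RefV1",
    False: "ResNet50Weights.ImageNet1K_RefV1",
}


def resolve_weights(weights: Optional[str] = None, quantize: bool = False, **kwargs: Any) -> Optional[str]:
    if "pretrained" in kwargs:
        warnings.warn("The argument pretrained is deprecated, please use weights instead.")
        # pretrained=True picks the default entry for the chosen mode; False means no weights.
        weights = _DEFAULT_WEIGHTS[quantize] if kwargs.pop("pretrained") else None
    return weights
```

Explicitly passed `weights` are returned unchanged, so the legacy flag only matters when the new argument is absent.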