Multi-pretrained weight support - Quantized ResNet50 #4627
Changes from all commits: 60858f8, 81e1f14, 3cd6573, d237cad
@@ -1,2 +1,3 @@

```python
from .resnet import *
from . import detection
from . import quantization
```
@@ -0,0 +1 @@

```python
from .resnet import *
```
@@ -0,0 +1,84 @@

```python
import warnings
from functools import partial
from typing import Any, List, Optional, Type, Union

from ....models.quantization.resnet import (
    QuantizableBasicBlock,
    QuantizableBottleneck,
    QuantizableResNet,
    _replace_relu,
    quantize_model,
)
from ...transforms.presets import ImageNetEval
from .._api import Weights, WeightEntry
from .._meta import _IMAGENET_CATEGORIES
from ..resnet import ResNet50Weights


__all__ = ["QuantizableResNet", "QuantizedResNet50Weights", "resnet50"]


def _resnet(
    block: Type[Union[QuantizableBasicBlock, QuantizableBottleneck]],
    layers: List[int],
    weights: Optional[Weights],
    progress: bool,
    quantize: bool,
    **kwargs: Any,
) -> QuantizableResNet:
    if weights is not None:
        kwargs["num_classes"] = len(weights.meta["categories"])
```
> **Review comment:** Flagging again #4613 (comment), but we can discuss in a follow-up.

```python
        if "backend" in weights.meta:
            kwargs["backend"] = weights.meta["backend"]
```

> **Review comment:** Not all weights are expected to have the `"backend"` meta. Let's ignore the silent overwriting of parameters (see #4613 (comment)); this can be discussed separately and applied everywhere in a follow-up.

```python
    backend = kwargs.pop("backend", "fbgemm")
```

> **Review comment:** The backend is allowed to be a "hidden" kwargs argument. We can decide on making it public in future iterations.
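The backend resolution discussed above can be exercised in isolation. The sketch below is a standalone toy, not torchvision code; `resolve_backend` and its parameters are hypothetical names that only mirror the precedence order in the PR:

```python
from typing import Any, Dict, Optional


def resolve_backend(weights_meta: Optional[Dict[str, Any]] = None, **kwargs: Any) -> str:
    # Metadata attached to the weights wins, if present ...
    if weights_meta is not None and "backend" in weights_meta:
        kwargs["backend"] = weights_meta["backend"]
    # ... otherwise fall back to a "hidden" kwarg, then to the "fbgemm" default.
    return kwargs.pop("backend", "fbgemm")


print(resolve_backend({"backend": "qnnpack"}))    # → qnnpack (meta takes precedence)
print(resolve_backend(None, backend="qnnpack"))   # → qnnpack (hidden kwarg)
print(resolve_backend(None))                      # → fbgemm (default)
```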
```python
    model = QuantizableResNet(block, layers, **kwargs)
    _replace_relu(model)
    if quantize:
        quantize_model(model, backend)

    if weights is not None:
        model.load_state_dict(weights.state_dict(progress=progress))

    return model


_common_meta = {
    "size": (224, 224),
    "categories": _IMAGENET_CATEGORIES,
    "backend": "fbgemm",
}
```
```python
class QuantizedResNet50Weights(Weights):
```

> **Review comment:** Throwing an idea in the wild: as of today (and I think this will be the case for all models), all quantized weights originate from an unquantized weight. Do we want to keep this link somehow in the `Weights`? I think it might be important to keep this relationship somehow.

> **Reply:** We could keep this in the metadata, or, if we are willing to make a commitment, pass it as a proper field of the `Weights` data class. The reason we don't is that quantization is not too mature at the moment: only classification models have been quantized, and we are examining alternative APIs/approaches (such as FX) to achieve it. For these reasons, I would be in favour of not introducing a direct link and reviewing this decision in the near future.

```python
    ImageNet1K_FBGEMM_RefV1 = WeightEntry(
        url="https://download.pytorch.org/models/quantized/resnet50_fbgemm_bf931d71.pth",
        transforms=partial(ImageNetEval, crop_size=224),
        meta={
            **_common_meta,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#quantized",
            "acc@1": 75.920,
            "acc@5": 92.814,
        },
    )
```
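The reviewer's idea of recording the unquantized origin could look roughly like the following standalone sketch. The `WeightEntry` here is a simplified stand-in for the PR's class, and the `"unquantized"` meta key and the URLs are hypothetical names, chosen only to illustrate keeping the link in `meta` rather than as a dataclass field:

```python
from dataclasses import dataclass, field
from typing import Any, Dict


@dataclass
class WeightEntry:
    # Simplified stand-in for the PR's WeightEntry: just a URL plus free-form meta.
    url: str
    meta: Dict[str, Any] = field(default_factory=dict)


float_weights = WeightEntry(url="https://example.com/resnet50_float.pth")
quantized_weights = WeightEntry(
    url="https://example.com/resnet50_fbgemm.pth",
    # The link back to the unquantized origin lives in meta, so no commitment
    # is made in the data class itself (the PR deliberately avoids that).
    meta={"backend": "fbgemm", "unquantized": float_weights},
)

print(quantized_weights.meta["unquantized"].url)  # → https://example.com/resnet50_float.pth
```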
```python
def resnet50(
    weights: Optional[Union[QuantizedResNet50Weights, ResNet50Weights]] = None,
```

> **Review comment:** We allow passing both quantized and normal weights. This is aligned with the past behaviour, where different URLs were loaded depending on the value of `quantize`.
```python
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,
) -> QuantizableResNet:
    if "pretrained" in kwargs:
```

> **Review comment (nit):** you could probably simplify the code a tiny bit by doing something like:

```python
if quantize:
    weights_table = QuantizedResNet50Weights
else:
    weights_table = ResNet50Weights
...
weights = weights_table.ImageNet1K_RefV1  # different naming convention than now
weights = weights_table.verify(weights)
```

> In some sense, it's a bit annoying to have to carry those two around.

> **Reply:** I decided not to connect the two using the same name because the reference/recipe version might not be the same; i.e. RefV1 for quantized is something different than RefV1 for non-quantized. For example, say you use a different recipe/config to achieve even better quantization of the same weights; now you need a different version of the enum. Moreover, there might be multiple quantized weight enums for the same unquantized weights (for example, if multiple backends should be supported). Some of these points will become clearer in the near future and we can revisit them.
```python
        warnings.warn("The argument pretrained is deprecated, please use weights instead.")
        if kwargs.pop("pretrained"):
            weights = QuantizedResNet50Weights.ImageNet1K_FBGEMM_RefV1 if quantize else ResNet50Weights.ImageNet1K_RefV1
```

> **Review comment:** Which default weights we load depends on whether we quantize or not.

```python
        else:
            weights = None
```
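The `pretrained` deprecation path above can be exercised in isolation with a minimal stand-in. Only the control flow mirrors the PR; `pick_weights` and the string stand-ins for the weight enums are hypothetical:

```python
import warnings


def pick_weights(quantize: bool = False, **kwargs):
    # Mirrors the deprecation shim: pretrained=True maps to the default
    # weights for the chosen mode, pretrained=False to no weights.
    weights = kwargs.get("weights")
    if "pretrained" in kwargs:
        warnings.warn("The argument pretrained is deprecated, please use weights instead.")
        if kwargs.pop("pretrained"):
            weights = "quantized_default" if quantize else "float_default"
        else:
            weights = None
    return weights


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    print(pick_weights(quantize=True, pretrained=True))   # → quantized_default
    print(pick_weights(quantize=False, pretrained=True))  # → float_default
    print(len(caught))                                    # → 2 (one warning per call)
```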
```python
    if quantize:
```

> **Review comment:** The validation of the weights also depends on the value of `quantize`. Passing the wrong combination will throw an error.

```python
        weights = QuantizedResNet50Weights.verify(weights)
    else:
        weights = ResNet50Weights.verify(weights)

    return _resnet(QuantizableBottleneck, [3, 4, 6, 3], weights, progress, quantize, **kwargs)
```
> **Review comment:** Fixing in place a previous "bug" in our typing. The correct types here are the more specific Quantizable versions.
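To illustrate the verification behaviour the comments describe, here is a standalone toy. The base class, enum members, and `verify` below are simplified assumptions modelled on the snippets above, not the prototype's actual implementation:

```python
from enum import Enum


class Weights(Enum):
    # Toy version of the prototype's verify(): reject weights from the wrong table.
    @classmethod
    def verify(cls, obj):
        if obj is not None and not isinstance(obj, cls):
            raise TypeError(
                f"Invalid weights class {type(obj).__name__}; expected {cls.__name__}"
            )
        return obj


class ResNet50Weights(Weights):
    ImageNet1K_RefV1 = "resnet50_imagenet1k_refv1"


class QuantizedResNet50Weights(Weights):
    ImageNet1K_FBGEMM_RefV1 = "resnet50_fbgemm_refv1"


def pick_and_verify(weights, quantize: bool):
    # The weights table used for validation depends on `quantize`,
    # as in the resnet50() builder above.
    if quantize:
        return QuantizedResNet50Weights.verify(weights)
    return ResNet50Weights.verify(weights)


# Correct combination passes through unchanged:
w = pick_and_verify(QuantizedResNet50Weights.ImageNet1K_FBGEMM_RefV1, quantize=True)

# Wrong combination raises:
try:
    pick_and_verify(QuantizedResNet50Weights.ImageNet1K_FBGEMM_RefV1, quantize=False)
except TypeError:
    print("mismatch rejected")  # → mismatch rejected
```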