Porting Quantized models #5614


Closed · wants to merge 29 commits
6 changes: 5 additions & 1 deletion test/test_backbone_utils.py
@@ -13,7 +13,11 @@

 def get_available_models():
     # TODO add a registration mechanism to torchvision.models
-    return [k for k, v in models.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"]
+    return [
+        k
+        for k, v in models.__dict__.items()
+        if callable(v) and k[0].lower() == k[0] and k[0] != "_" and k != "get_weight"
+    ]


 @pytest.mark.parametrize("backbone_name", ("resnet18", "resnet50"))
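Note: the extra `k != "get_weight"` guard is needed because this PR re-exports `get_weight` from `torchvision.models`; it is a public lowercase callable, so the old heuristic would have mistaken it for a model builder. A minimal sketch of the failure mode, assuming a build that includes this change:

```python
import torchvision.models as models

# Without the guard, the helper would report get_weight as a "model":
hits = [k for k, v in models.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"]
assert "get_weight" in hits  # why the explicit `k != "get_weight"` filter is needed
```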
40 changes: 13 additions & 27 deletions test/test_prototype_models.py
@@ -4,10 +4,11 @@
 import pytest
 import test_models as TM
 import torch
+import torchvision
 from common_utils import cpu_and_gpu, needs_cuda
+from torchvision.models._api import WeightsEnum, Weights
+from torchvision.models._utils import handle_legacy_interface
 from torchvision.prototype import models
-from torchvision.prototype.models._api import WeightsEnum, Weights
-from torchvision.prototype.models._utils import handle_legacy_interface

 run_if_test_with_prototype = pytest.mark.skipif(
     os.getenv("PYTORCH_TEST_WITH_PROTOTYPE") != "1",
@@ -54,27 +55,27 @@ def _build_model(fn, **kwargs):
 @pytest.mark.parametrize(
     "name, weight",
     [
-        ("ResNet50_Weights.IMAGENET1K_V1", models.ResNet50_Weights.IMAGENET1K_V1),
-        ("ResNet50_Weights.DEFAULT", models.ResNet50_Weights.IMAGENET1K_V2),
+        ("ResNet50_Weights.IMAGENET1K_V1", torchvision.models.ResNet50_Weights.IMAGENET1K_V1),
+        ("ResNet50_Weights.DEFAULT", torchvision.models.ResNet50_Weights.IMAGENET1K_V2),
         (
             "ResNet50_QuantizedWeights.DEFAULT",
-            models.quantization.ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V2,
+            torchvision.models.quantization.ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V2,
         ),
         (
             "ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V1",
-            models.quantization.ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V1,
+            torchvision.models.quantization.ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V1,
         ),
     ],
 )
 def test_get_weight(name, weight):
-    assert models.get_weight(name) == weight
+    assert torchvision.models.get_weight(name) == weight


 @pytest.mark.parametrize(
     "model_fn",
-    TM.get_models_from_module(models)
+    TM.get_models_from_module(torchvision.models)
     + TM.get_models_from_module(models.detection)
-    + TM.get_models_from_module(models.quantization)
+    + TM.get_models_from_module(torchvision.models.quantization)
     + TM.get_models_from_module(models.segmentation)
     + TM.get_models_from_module(models.video)
     + TM.get_models_from_module(models.optical_flow),
@@ -88,9 +89,9 @@ def test_naming_conventions(model_fn):

 @pytest.mark.parametrize(
     "model_fn",
-    TM.get_models_from_module(models)
+    TM.get_models_from_module(torchvision.models)
     + TM.get_models_from_module(models.detection)
-    + TM.get_models_from_module(models.quantization)
+    + TM.get_models_from_module(torchvision.models.quantization)
     + TM.get_models_from_module(models.segmentation)
     + TM.get_models_from_module(models.video)
     + TM.get_models_from_module(models.optical_flow),
@@ -142,26 +143,13 @@ def test_schema_meta_validation(model_fn):
     assert not bad_names


-@pytest.mark.parametrize("model_fn", TM.get_models_from_module(models))
-@pytest.mark.parametrize("dev", cpu_and_gpu())
-@run_if_test_with_prototype
-def test_classification_model(model_fn, dev):
-    TM.test_classification_model(model_fn, dev)
-
-
 @pytest.mark.parametrize("model_fn", TM.get_models_from_module(models.detection))
 @pytest.mark.parametrize("dev", cpu_and_gpu())
 @run_if_test_with_prototype
 def test_detection_model(model_fn, dev):
     TM.test_detection_model(model_fn, dev)


-@pytest.mark.parametrize("model_fn", TM.get_models_from_module(models.quantization))
-@run_if_test_with_prototype
-def test_quantized_classification_model(model_fn):
-    TM.test_quantized_classification_model(model_fn)
-
-
 @pytest.mark.parametrize("model_fn", TM.get_models_from_module(models.segmentation))
 @pytest.mark.parametrize("dev", cpu_and_gpu())
 @run_if_test_with_prototype
@@ -186,9 +174,7 @@ def test_raft(model_builder, scripted):

 @pytest.mark.parametrize(
     "model_fn",
-    TM.get_models_from_module(models)
-    + TM.get_models_from_module(models.detection)
-    + TM.get_models_from_module(models.quantization)
+    TM.get_models_from_module(models.detection)
     + TM.get_models_from_module(models.segmentation)
     + TM.get_models_from_module(models.video)
     + TM.get_models_from_module(models.optical_flow),
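The rewritten imports above reflect where the ported pieces now live: weight enums and `get_weight` resolve through `torchvision.models`, while everything still in the prototype area keeps going through `torchvision.prototype.models`. A minimal sketch of what the updated `test_get_weight` exercises, assuming a torchvision build that includes this PR:

```python
import torchvision

# The fully qualified enum name resolves through the stable namespace,
# including the quantized variants ported by this PR.
w = torchvision.models.get_weight("ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V2")
assert w == torchvision.models.quantization.ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V2
```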
16 changes: 8 additions & 8 deletions torchvision/models/__init__.py
@@ -1,20 +1,20 @@
 from .alexnet import *
 from .convnext import *
-from .resnet import *
-from .vgg import *
-from .squeezenet import *
-from .inception import *
 from .densenet import *
+from .efficientnet import *
 from .googlenet import *
-from .mobilenet import *
+from .inception import *
 from .mnasnet import *
-from .shufflenetv2 import *
-from .efficientnet import *
+from .mobilenet import *
 from .regnet import *
+from .resnet import *
+from .shufflenetv2 import *
+from .squeezenet import *
+from .vgg import *
 from .vision_transformer import *
 from . import detection
-from . import feature_extraction
 from . import optical_flow
 from . import quantization
 from . import segmentation
 from . import video
+from ._api import get_weight
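With `from ._api import get_weight` added to `__init__.py`, the lookup helper becomes importable straight from the stable package. A small usage sketch (again assuming a build containing this change):

```python
from torchvision.models import get_weight

# Resolve a weights enum member from its fully qualified string name.
weights = get_weight("ResNet50_Weights.IMAGENET1K_V2")
print(weights.url)  # each Weights entry records its checkpoint URL
```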
torchvision/prototype/models/_api.py → torchvision/models/_api.py
@@ -7,7 +7,7 @@

 from torchvision._utils import StrEnum

-from ..._internally_replaced_utils import load_state_dict_from_url
+from .._internally_replaced_utils import load_state_dict_from_url


 __all__ = ["WeightsEnum", "Weights", "get_weight"]
File renamed without changes.
163 changes: 162 additions & 1 deletion torchvision/models/_utils.py
@@ -1,8 +1,14 @@
+import functools
+import inspect
+import warnings
 from collections import OrderedDict
-from typing import Dict, Optional
+from typing import Any, Dict, Optional, TypeVar, Callable, Tuple, Union

 from torch import nn

+from .._utils import sequence_to_str
+from ._api import WeightsEnum
+

 class IntermediateLayerGetter(nn.ModuleDict):
     """
@@ -81,3 +87,158 @@ def _make_divisible(v: float, divisor: int, min_value: Optional[int] = None) ->
     if new_v < 0.9 * v:
         new_v += divisor
     return new_v
+
+
+D = TypeVar("D")
+
+
+def kwonly_to_pos_or_kw(fn: Callable[..., D]) -> Callable[..., D]:
+    """Decorates a function that uses keyword only parameters to also allow them being passed as positionals.
+
+    For example, consider the use case of changing the signature of ``old_fn`` into the one from ``new_fn``:
+
+    .. code::
+
+        def old_fn(foo, bar, baz=None):
+            ...
+
+        def new_fn(foo, *, bar, baz=None):
+            ...
+
+    Calling ``old_fn("foo", "bar", "baz")`` was valid, but the same call is no longer valid with ``new_fn``. To keep BC
+    and at the same time warn the user of the deprecation, this decorator can be used:
+
+    .. code::
+
+        @kwonly_to_pos_or_kw
+        def new_fn(foo, *, bar, baz=None):
+            ...
+
+        new_fn("foo", "bar", "baz")
+    """
+    params = inspect.signature(fn).parameters
+
+    try:
+        keyword_only_start_idx = next(
+            idx for idx, param in enumerate(params.values()) if param.kind == param.KEYWORD_ONLY
+        )
+    except StopIteration:
+        raise TypeError(f"Found no keyword-only parameter on function '{fn.__name__}'") from None
+
+    keyword_only_params = tuple(inspect.signature(fn).parameters)[keyword_only_start_idx:]
+
+    @functools.wraps(fn)
+    def wrapper(*args: Any, **kwargs: Any) -> D:
+        args, keyword_only_args = args[:keyword_only_start_idx], args[keyword_only_start_idx:]
+        if keyword_only_args:
+            keyword_only_kwargs = dict(zip(keyword_only_params, keyword_only_args))
+            warnings.warn(
+                f"Using {sequence_to_str(tuple(keyword_only_kwargs.keys()), separate_last='and ')} as positional "
+                f"parameter(s) is deprecated. Please use keyword parameter(s) instead."
+            )
+            kwargs.update(keyword_only_kwargs)
+
+        return fn(*args, **kwargs)
+
+    return wrapper
+
+
+W = TypeVar("W", bound=WeightsEnum)
+M = TypeVar("M", bound=nn.Module)
+V = TypeVar("V")
+
+
+def handle_legacy_interface(**weights: Tuple[str, Union[Optional[W], Callable[[Dict[str, Any]], Optional[W]]]]):
+    """Decorates a model builder with the new interface to make it compatible with the old.
+
+    In particular this handles two things:
+
+    1. Allows positional parameters again, but emits a deprecation warning in case they are used. See
+       :func:`torchvision.prototype.utils._internal.kwonly_to_pos_or_kw` for details.
+    2. Handles the default value change from ``pretrained=False`` to ``weights=None`` and ``pretrained=True`` to
+       ``weights=Weights`` and emits a deprecation warning with instructions for the new interface.
+
+    Args:
+        **weights (Tuple[str, Union[Optional[W], Callable[[Dict[str, Any]], Optional[W]]]]): Deprecated parameter
+            name and default value for the legacy ``pretrained=True``. The default value can be a callable in which
+            case it will be called with a dictionary of the keyword arguments. The only key that is guaranteed to be
+            in the dictionary is the deprecated parameter name passed as first element in the tuple. All other
+            parameters should be accessed with :meth:`~dict.get`.
+    """
+
+    def outer_wrapper(builder: Callable[..., M]) -> Callable[..., M]:
+        @kwonly_to_pos_or_kw
+        @functools.wraps(builder)
+        def inner_wrapper(*args: Any, **kwargs: Any) -> M:
+            for weights_param, (pretrained_param, default) in weights.items():  # type: ignore[union-attr]
+                # If neither the weights nor the pretrained parameter was passed, or the weights argument already
+                # uses the new style arguments, there is nothing to do. Note that we cannot use `None` as sentinel
+                # for the weights argument, since it is a valid value.
+                sentinel = object()
+                weights_arg = kwargs.get(weights_param, sentinel)
+                if (
+                    (weights_param not in kwargs and pretrained_param not in kwargs)
+                    or isinstance(weights_arg, WeightsEnum)
+                    or (isinstance(weights_arg, str) and weights_arg != "legacy")
+                    or weights_arg is None
+                ):
+                    continue
+
+                # If the pretrained parameter was passed as positional argument, it is now mapped to
+                # `kwargs[weights_param]`. This happens because the @kwonly_to_pos_or_kw decorator uses the
+                # current signature to infer the names of positionally passed arguments and thus has no knowledge
+                # that there used to be a pretrained parameter.
+                pretrained_positional = weights_arg is not sentinel
+                if pretrained_positional:
+                    # We put the pretrained argument under its legacy name in the keyword argument dictionary to
+                    # have a unified access to the value if the default value is a callable.
+                    kwargs[pretrained_param] = pretrained_arg = kwargs.pop(weights_param)
+                else:
+                    pretrained_arg = kwargs[pretrained_param]
+
+                if pretrained_arg:
+                    default_weights_arg = default(kwargs) if callable(default) else default
+                    if not isinstance(default_weights_arg, WeightsEnum):
+                        raise ValueError(f"No weights available for model {builder.__name__}")
+                else:
+                    default_weights_arg = None
+
+                if not pretrained_positional:
+                    warnings.warn(
+                        f"The parameter '{pretrained_param}' is deprecated, please use '{weights_param}' instead."
+                    )
+
+                msg = (
+                    f"Arguments other than a weight enum or `None` for '{weights_param}' are deprecated. "
+                    f"The current behavior is equivalent to passing `{weights_param}={default_weights_arg}`."
+                )
+                if pretrained_arg:
+                    msg = (
+                        f"{msg} You can also use `{weights_param}={type(default_weights_arg).__name__}.DEFAULT` "
+                        f"to get the most up-to-date weights."
+                    )
+                warnings.warn(msg)
+
+                del kwargs[pretrained_param]
+                kwargs[weights_param] = default_weights_arg
+
+            return builder(*args, **kwargs)
+
+        return inner_wrapper
+
+    return outer_wrapper
+
+
+def _ovewrite_named_param(kwargs: Dict[str, Any], param: str, new_value: V) -> None:
+    if param in kwargs:
+        if kwargs[param] != new_value:
+            raise ValueError(f"The parameter '{param}' expected value {new_value} but got {kwargs[param]} instead.")
+    else:
+        kwargs[param] = new_value
+
+
+def _ovewrite_value_param(param: Optional[V], new_value: V) -> V:
+    if param is not None:
+        if param != new_value:
+            raise ValueError(f"The parameter '{param}' expected value {new_value} but got {param} instead.")
+    return new_value
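Taken together, `kwonly_to_pos_or_kw` and `handle_legacy_interface` let a ported builder keep accepting the legacy `pretrained` flag while steering users toward the `weights` enum. Below is a sketch of the intended wiring, assuming a torchvision checkout that includes this PR; `FakeWeights`, `fake_model`, and their metadata are hypothetical stand-ins invented for illustration, not part of the diff:

```python
from torch import nn
from torchvision.models._api import Weights, WeightsEnum
from torchvision.models._utils import _ovewrite_named_param, handle_legacy_interface


class FakeWeights(WeightsEnum):
    # Hypothetical entry; a real enum points at a published checkpoint.
    DEFAULT = Weights(url="https://example.com/fake.pth", transforms=None, meta={"num_classes": 10})


@handle_legacy_interface(weights=("pretrained", FakeWeights.DEFAULT))
def fake_model(*, weights=None, progress=True, **kwargs):
    if weights is not None:
        # Pin num_classes to the checkpoint's value, raising on a conflicting caller value.
        _ovewrite_named_param(kwargs, "num_classes", weights.meta["num_classes"])
    return nn.Linear(4, kwargs.get("num_classes", 10))


fake_model(weights=FakeWeights.DEFAULT)  # new interface, no warning
fake_model(pretrained=True)              # warns, then maps to weights=FakeWeights.DEFAULT
fake_model(True)                         # legacy positional flag is also caught, with a warning
```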