Add registration mechanism for models #6333


Merged

merged 33 commits from models/registration_mechanism into main on Aug 1, 2022

Commits
0d62aaf
Model registration mechanism.
datumbox Jul 28, 2022
0e7eb8a
Add overwrite options to the dataset prototype registration mechanism.
datumbox Jul 28, 2022
1520566
Adding example models.
datumbox Jul 28, 2022
2e16077
Fix module filtering
datumbox Jul 29, 2022
a02c124
Fix linter
datumbox Jul 29, 2022
eedf8df
Fix docs
datumbox Jul 29, 2022
a91a5b4
Make name optional if same as model builder
datumbox Jul 29, 2022
abbe23e
Apply updates from code-review.
datumbox Jul 29, 2022
1eb8159
fix minor bug
datumbox Jul 29, 2022
924388e
Adding getter for model weight enum
datumbox Jul 29, 2022
bd2327a
Support both strings and callables on get_model_weight.
datumbox Jul 29, 2022
a815a63
linter fixes
datumbox Jul 29, 2022
9e4e62c
Fixing mypy.
datumbox Jul 29, 2022
0209327
Renaming `get_model_weight` to `get_model_weights`
datumbox Jul 29, 2022
2a63dce
Registering all classification models.
datumbox Jul 29, 2022
976a93e
Registering all video models.
datumbox Jul 29, 2022
2b8dc89
Registering all detection models.
datumbox Jul 29, 2022
040ddfc
Registering all optical flow models.
datumbox Jul 29, 2022
2031bf7
Fixing mypy.
datumbox Jul 29, 2022
1f27788
Registering all segmentation models.
datumbox Jul 29, 2022
4d98f6c
Registering all quantization models.
datumbox Jul 29, 2022
ba0bb82
Fixing linter
datumbox Jul 29, 2022
648eb6e
Merge branch 'main' into models/registration_mechanism
datumbox Jul 29, 2022
0e2d120
Registering all prototype depth perception models.
datumbox Jul 29, 2022
2499c75
Adding tests and updating existing tests.
datumbox Jul 29, 2022
a75cfac
Fix linters
datumbox Jul 29, 2022
7e9c2e7
Fix tests.
datumbox Jul 30, 2022
81a7e3f
Add beta annotation on docs.
datumbox Jul 30, 2022
867f85f
Fix tests.
datumbox Jul 30, 2022
7ca12b9
Apply changes from code-review.
datumbox Aug 1, 2022
7c3a0ba
Adding documentation.
datumbox Aug 1, 2022
2eae177
Fix docs.
datumbox Aug 1, 2022
efe1bd9
Merge branch 'main' into models/registration_mechanism
datumbox Aug 1, 2022
40 changes: 40 additions & 0 deletions docs/source/models.rst
@@ -120,6 +120,46 @@ behavior, such as batch normalization. To switch between these modes, use
# Set model to eval mode
model.eval()

Model Registration Mechanism
----------------------------

.. betastatus:: registration mechanism

As of v0.14, TorchVision offers a new model registration mechanism that allows retrieving models
and weights by their names. Here are a few examples of how to use them:

.. code:: python

# List available models
all_models = list_models()
classification_models = list_models(module=torchvision.models)

# Initialize models
m1 = get_model("mobilenet_v3_large", weights=None)
m2 = get_model("quantized_mobilenet_v3_large", weights="DEFAULT")

# Fetch weights
weights = get_weight("MobileNet_V3_Large_QuantizedWeights.DEFAULT")
assert weights == MobileNet_V3_Large_QuantizedWeights.DEFAULT

weights_enum = get_model_weights("quantized_mobilenet_v3_large")
assert weights_enum == MobileNet_V3_Large_QuantizedWeights

weights_enum2 = get_model_weights(torchvision.models.quantization.mobilenet_v3_large)
assert weights_enum == weights_enum2

Here are the available public functions of the model registration mechanism:

.. currentmodule:: torchvision.models
.. autosummary::
:toctree: generated/
:template: function.rst

get_model
get_model_weights
get_weight
list_models

Using models from Hub
---------------------

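For readers curious how such a registry can be wired up, here is a minimal sketch of a name-to-builder mapping with a registration decorator. This is an illustration of the general pattern only, not the actual implementation in `torchvision/models/_api.py` (which is not shown in this diff); the names `BUILTIN_MODELS`, `register_model`, and the lowercasing behavior are assumptions for the sketch.

```python
from typing import Any, Callable, Dict, List, Optional

# Hypothetical registry: maps a model name to its builder function.
BUILTIN_MODELS: Dict[str, Callable[..., Any]] = {}


def register_model(name: Optional[str] = None) -> Callable[[Callable], Callable]:
    # Decorator that records a builder under `name`, defaulting to the
    # builder's own function name (cf. commit a91a5b4).
    def wrapper(fn: Callable) -> Callable:
        key = name if name is not None else fn.__name__
        if key in BUILTIN_MODELS:
            raise ValueError(f"An entry is already registered under the name '{key}'.")
        BUILTIN_MODELS[key] = fn
        return fn

    return wrapper


def list_models() -> List[str]:
    # All registered names, sorted for stable output.
    return sorted(BUILTIN_MODELS)


def get_model(name: str, **config: Any) -> Any:
    # Resolve the builder by name and instantiate it with the given config.
    try:
        fn = BUILTIN_MODELS[name.lower()]
    except KeyError:
        raise ValueError(f"Unknown model {name}")
    return fn(**config)
```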
21 changes: 6 additions & 15 deletions test/test_backbone_utils.py
@@ -11,15 +11,6 @@
from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names


def get_available_models():
# TODO add a registration mechanism to torchvision.models
return [
k
for k, v in models.__dict__.items()
if callable(v) and k[0].lower() == k[0] and k[0] != "_" and k != "get_weight"
]


@pytest.mark.parametrize("backbone_name", ("resnet18", "resnet50"))
def test_resnet_fpn_backbone(backbone_name):
x = torch.rand(1, 3, 300, 300, dtype=torch.float32, device="cpu")
@@ -135,10 +126,10 @@ def _get_return_nodes(self, model):
eval_nodes = [n for n in eval_nodes if not any(x in n for x in exclude_nodes_filter)]
return random.sample(train_nodes, 10), random.sample(eval_nodes, 10)

@pytest.mark.parametrize("model_name", get_available_models())
@pytest.mark.parametrize("model_name", models.list_models(models))
def test_build_fx_feature_extractor(self, model_name):
set_rng_seed(0)
model = models.__dict__[model_name](**self.model_defaults).eval()
model = models.get_model(model_name, **self.model_defaults).eval()
train_return_nodes, eval_return_nodes = self._get_return_nodes(model)
# Check that it works with both a list and dict for return nodes
self._create_feature_extractor(
@@ -172,9 +163,9 @@ def test_node_name_conventions(self):
train_nodes, _ = get_graph_node_names(model)
assert all(a == b for a, b in zip(train_nodes, test_module_nodes))

@pytest.mark.parametrize("model_name", get_available_models())
@pytest.mark.parametrize("model_name", models.list_models(models))
def test_forward_backward(self, model_name):
model = models.__dict__[model_name](**self.model_defaults).train()
model = models.get_model(model_name, **self.model_defaults).train()
train_return_nodes, eval_return_nodes = self._get_return_nodes(model)
model = self._create_feature_extractor(
model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes
@@ -211,10 +202,10 @@ def test_feature_extraction_methods_equivalence(self):
for k in ilg_out.keys():
assert ilg_out[k].equal(fgn_out[k])

@pytest.mark.parametrize("model_name", get_available_models())
@pytest.mark.parametrize("model_name", models.list_models(models))
def test_jit_forward_backward(self, model_name):
set_rng_seed(0)
model = models.__dict__[model_name](**self.model_defaults).train()
model = models.get_model(model_name, **self.model_defaults).train()
train_return_nodes, eval_return_nodes = self._get_return_nodes(model)
model = self._create_feature_extractor(
model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes
102 changes: 65 additions & 37 deletions test/test_extended_models.py
@@ -1,11 +1,10 @@
import importlib
import os

import pytest
import test_models as TM
import torch
from torchvision import models
from torchvision.models._api import Weights, WeightsEnum
from torchvision.models._api import get_model_weights, Weights, WeightsEnum
from torchvision.models._utils import handle_legacy_interface


@@ -15,23 +14,52 @@
)


def _get_parent_module(model_fn):
parent_module_name = ".".join(model_fn.__module__.split(".")[:-1])
module = importlib.import_module(parent_module_name)
return module
@pytest.mark.parametrize(
"name, model_class",
[
("resnet50", models.ResNet),
("retinanet_resnet50_fpn_v2", models.detection.RetinaNet),
("raft_large", models.optical_flow.RAFT),
("quantized_resnet50", models.quantization.QuantizableResNet),
("lraspp_mobilenet_v3_large", models.segmentation.LRASPP),
("mvit_v1_b", models.video.MViT),
],
)
def test_get_model(name, model_class):
assert isinstance(models.get_model(name), model_class)


@pytest.mark.parametrize(
"name, weight",
[
("resnet50", models.ResNet50_Weights),
("retinanet_resnet50_fpn_v2", models.detection.RetinaNet_ResNet50_FPN_V2_Weights),
("raft_large", models.optical_flow.Raft_Large_Weights),
("quantized_resnet50", models.quantization.ResNet50_QuantizedWeights),
("lraspp_mobilenet_v3_large", models.segmentation.LRASPP_MobileNet_V3_Large_Weights),
("mvit_v1_b", models.video.MViT_V1_B_Weights),
],
)
def test_get_model_weights(name, weight):
assert models.get_model_weights(name) == weight


def _get_model_weights(model_fn):
module = _get_parent_module(model_fn)
weights_name = "_QuantizedWeights" if module.__name__.split(".")[-1] == "quantization" else "_Weights"
try:
return next(
v
@pytest.mark.parametrize(
"module", [models, models.detection, models.quantization, models.segmentation, models.video, models.optical_flow]
)
def test_list_models(module):
def get_models_from_module(module):
return [
v.__name__
for k, v in module.__dict__.items()
if k.endswith(weights_name) and k.replace(weights_name, "").lower() == model_fn.__name__
)
except StopIteration:
return None
if callable(v) and k[0].islower() and k[0] != "_" and k not in models._api.__all__
]

a = set(get_models_from_module(module))
b = set(x.replace("quantized_", "") for x in models.list_models(module))

assert len(b) > 0
assert a == b


@pytest.mark.parametrize(
Expand All @@ -55,27 +83,27 @@ def test_get_weight(name, weight):

@pytest.mark.parametrize(
"model_fn",
TM.get_models_from_module(models)
+ TM.get_models_from_module(models.detection)
+ TM.get_models_from_module(models.quantization)
+ TM.get_models_from_module(models.segmentation)
+ TM.get_models_from_module(models.video)
+ TM.get_models_from_module(models.optical_flow),
TM.list_model_fns(models)
+ TM.list_model_fns(models.detection)
+ TM.list_model_fns(models.quantization)
+ TM.list_model_fns(models.segmentation)
+ TM.list_model_fns(models.video)
+ TM.list_model_fns(models.optical_flow),
)
def test_naming_conventions(model_fn):
weights_enum = _get_model_weights(model_fn)
weights_enum = get_model_weights(model_fn)
assert weights_enum is not None
assert len(weights_enum) == 0 or hasattr(weights_enum, "DEFAULT")


@pytest.mark.parametrize(
"model_fn",
TM.get_models_from_module(models)
+ TM.get_models_from_module(models.detection)
+ TM.get_models_from_module(models.quantization)
+ TM.get_models_from_module(models.segmentation)
+ TM.get_models_from_module(models.video)
+ TM.get_models_from_module(models.optical_flow),
TM.list_model_fns(models)
+ TM.list_model_fns(models.detection)
+ TM.list_model_fns(models.quantization)
+ TM.list_model_fns(models.segmentation)
+ TM.list_model_fns(models.video)
+ TM.list_model_fns(models.optical_flow),
)
@run_if_test_with_extended
def test_schema_meta_validation(model_fn):
@@ -112,7 +140,7 @@ def test_schema_meta_validation(model_fn):
module_name = model_fn.__module__.split(".")[-2]
expected_fields = defaults["all"] | defaults[module_name]

weights_enum = _get_model_weights(model_fn)
weights_enum = get_model_weights(model_fn)
if len(weights_enum) == 0:
pytest.skip(f"Model '{model_name}' doesn't have any pre-trained weights.")

@@ -153,17 +181,17 @@ def test_schema_meta_validation(model_fn):

@pytest.mark.parametrize(
"model_fn",
TM.get_models_from_module(models)
+ TM.get_models_from_module(models.detection)
+ TM.get_models_from_module(models.quantization)
+ TM.get_models_from_module(models.segmentation)
+ TM.get_models_from_module(models.video)
+ TM.get_models_from_module(models.optical_flow),
TM.list_model_fns(models)
+ TM.list_model_fns(models.detection)
+ TM.list_model_fns(models.quantization)
+ TM.list_model_fns(models.segmentation)
+ TM.list_model_fns(models.video)
+ TM.list_model_fns(models.optical_flow),
)
@run_if_test_with_extended
def test_transforms_jit(model_fn):
model_name = model_fn.__name__
weights_enum = _get_model_weights(model_fn)
weights_enum = get_model_weights(model_fn)
if len(weights_enum) == 0:
pytest.skip(f"Model '{model_name}' doesn't have any pre-trained weights.")

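A note on `test_list_models` above: it relies on the convention that quantization builders are registered under a `quantized_` prefix, while the plain builder names live in the module itself. A minimal usage sketch of that check, assuming torchvision v0.14+ and that a quantized `resnet50` variant is registered:

```python
from torchvision import models

# Registered names in the quantization module carry a "quantized_" prefix.
names = models.list_models(models.quantization)
assert "quantized_resnet50" in names  # assumption: a quantized resnet50 is registered

# Stripping the prefix recovers the plain builder names, mirroring test_list_models.
plain = {n.replace("quantized_", "") for n in names}
assert "resnet50" in plain
```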
33 changes: 15 additions & 18 deletions test/test_models.py
@@ -16,18 +16,15 @@
from _utils_internal import get_relative_path
from common_utils import cpu_and_gpu, freeze_rng_state, map_nested_tensor_object, needs_cuda, set_rng_seed
from torchvision import models
from torchvision.models._api import find_model, list_models


ACCEPT = os.getenv("EXPECTTEST_ACCEPT", "0") == "1"
SKIP_BIG_MODEL = os.getenv("SKIP_BIG_MODEL", "1") == "1"


def get_models_from_module(module):
# TODO add a registration mechanism to torchvision.models
return [
v
for k, v in module.__dict__.items()
if callable(v) and k[0].lower() == k[0] and k[0] != "_" and k != "get_weight"
]
def list_model_fns(module):
return [find_model(name) for name in list_models(module)]


@pytest.fixture
@@ -597,7 +594,7 @@ def test_vitc_models(model_fn, dev):
test_classification_model(model_fn, dev)


@pytest.mark.parametrize("model_fn", get_models_from_module(models))
@pytest.mark.parametrize("model_fn", list_model_fns(models))
@pytest.mark.parametrize("dev", cpu_and_gpu())
def test_classification_model(model_fn, dev):
set_rng_seed(0)
@@ -633,7 +630,7 @@ def test_classification_model(model_fn, dev):
_check_input_backprop(model, x)


@pytest.mark.parametrize("model_fn", get_models_from_module(models.segmentation))
@pytest.mark.parametrize("model_fn", list_model_fns(models.segmentation))
@pytest.mark.parametrize("dev", cpu_and_gpu())
def test_segmentation_model(model_fn, dev):
set_rng_seed(0)
@@ -695,7 +692,7 @@ def check_out(out):
_check_input_backprop(model, x)


@pytest.mark.parametrize("model_fn", get_models_from_module(models.detection))
@pytest.mark.parametrize("model_fn", list_model_fns(models.detection))
@pytest.mark.parametrize("dev", cpu_and_gpu())
def test_detection_model(model_fn, dev):
set_rng_seed(0)
@@ -793,7 +790,7 @@ def compute_mean_std(tensor):
_check_input_backprop(model, model_input)


@pytest.mark.parametrize("model_fn", get_models_from_module(models.detection))
@pytest.mark.parametrize("model_fn", list_model_fns(models.detection))
def test_detection_model_validation(model_fn):
set_rng_seed(0)
model = model_fn(num_classes=50, weights=None, weights_backbone=None)
@@ -822,7 +819,7 @@ def test_detection_model_validation(model_fn):
model(x, targets=targets)


@pytest.mark.parametrize("model_fn", get_models_from_module(models.video))
@pytest.mark.parametrize("model_fn", list_model_fns(models.video))
@pytest.mark.parametrize("dev", cpu_and_gpu())
def test_video_model(model_fn, dev):
set_rng_seed(0)
@@ -868,7 +865,7 @@ def test_video_model(model_fn, dev):
),
reason="This Pytorch Build has not been built with fbgemm and qnnpack",
)
@pytest.mark.parametrize("model_fn", get_models_from_module(models.quantization))
@pytest.mark.parametrize("model_fn", list_model_fns(models.quantization))
def test_quantized_classification_model(model_fn):
set_rng_seed(0)
defaults = {
@@ -917,7 +914,7 @@ def test_quantized_classification_model(model_fn):
torch.ao.quantization.convert(model, inplace=True)


@pytest.mark.parametrize("model_fn", get_models_from_module(models.detection))
@pytest.mark.parametrize("model_fn", list_model_fns(models.detection))
def test_detection_model_trainable_backbone_layers(model_fn, disable_weight_loading):
model_name = model_fn.__name__
max_trainable = _model_tests_values[model_name]["max_trainable"]
@@ -930,9 +927,9 @@


@needs_cuda
@pytest.mark.parametrize("model_builder", (models.optical_flow.raft_large, models.optical_flow.raft_small))
@pytest.mark.parametrize("model_fn", list_model_fns(models.optical_flow))
@pytest.mark.parametrize("scripted", (False, True))
def test_raft(model_builder, scripted):
def test_raft(model_fn, scripted):

torch.manual_seed(0)

@@ -942,7 +939,7 @@
# reduced to 1)
corr_block = models.optical_flow.raft.CorrBlock(num_levels=2, radius=2)

model = model_builder(corr_block=corr_block).eval().to("cuda")
model = model_fn(corr_block=corr_block).eval().to("cuda")
if scripted:
model = torch.jit.script(model)

@@ -954,7 +951,7 @@
flow_pred = preds[-1]
# Tolerance is fairly high, but there are 2 * H * W outputs to check
# The .pkl files were generated on the AWS cluster; on the CI it looks like the results are slightly different
_assert_expected(flow_pred, name=model_builder.__name__, atol=1e-2, rtol=1)
_assert_expected(flow_pred, name=model_fn.__name__, atol=1e-2, rtol=1)


if __name__ == "__main__":
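The `list_model_fns` helper above leans on `find_model`, which maps a registered name back to its builder callable. A short usage sketch follows; note that `find_model` is imported from the private `torchvision.models._api` module in this diff, so treat it as an internal detail rather than a stable public API:

```python
from torchvision.models import get_model, get_model_weights
from torchvision.models._api import find_model, list_models

# find_model resolves a registered name to the builder callable itself.
builder = find_model("resnet18")
model = builder(weights=None)  # same as get_model("resnet18", weights=None)

# get_model_weights accepts either the name or the callable (see the docs diff above).
assert get_model_weights("resnet18") == get_model_weights(builder)
```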
4 changes: 2 additions & 2 deletions test/test_models_detection_negative_samples.py
@@ -99,8 +99,8 @@ def test_assign_targets_to_proposals(self):
],
)
def test_forward_negative_sample_frcnn(self, name):
model = torchvision.models.detection.__dict__[name](
weights=None, weights_backbone=None, num_classes=2, min_size=100, max_size=100
model = torchvision.models.get_model(
name, weights=None, weights_backbone=None, num_classes=2, min_size=100, max_size=100
)

images, targets = self._make_empty_sample()