From f1e03dd48571619b4c1611272b286dd1ab3e6687 Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Fri, 5 Nov 2021 11:16:53 +0100
Subject: [PATCH 1/8] disable weight download and state dict loading for model tests

---
 test/test_models.py | 27 ++++++++++++++++++++++++---
 1 file changed, 24 insertions(+), 3 deletions(-)

diff --git a/test/test_models.py b/test/test_models.py
index da70f57f902..66ec757bf05 100644
--- a/test/test_models.py
+++ b/test/test_models.py
@@ -1,8 +1,11 @@
+import contextlib
 import functools
 import io
 import operator
 import os
+import sys
 import traceback
+import unittest.mock
 import warnings
 from collections import OrderedDict
 
@@ -14,7 +17,6 @@
 from common_utils import map_nested_tensor_object, freeze_rng_state, set_rng_seed, cpu_and_gpu, needs_cuda
 from torchvision import models
 
-
 ACCEPT = os.getenv("EXPECTTEST_ACCEPT", "0") == "1"
 
 
@@ -23,6 +25,24 @@ def get_models_from_module(module):
     return [v for k, v in module.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"]
 
 
+@contextlib.contextmanager
+def disable_weight_loading(model_fn):
+    model_type = model_fn.__annotations__.get("return")
+    module_name = model_fn.__module__
+    module = sys.modules[module_name]
+
+    with contextlib.ExitStack() as stack:
+        function_name = "load_state_dict_from_url"
+        module_target = module_name if function_name in dir(module) else "torchvision._internally_replaced_utils"
+        download_mock = stack.enter_context(unittest.mock.patch(f"{module_target}.{function_name}"))
+
+        method_name = "load_state_dict"
+        type_target = f"{model_type.__module__}.{model_type.__name__}" if model_type else "torch.nn.Module"
+        load_mock = stack.enter_context(unittest.mock.patch(f"{type_target}.{method_name}"))
+
+        yield download_mock, load_mock
+
+
 def _get_expected_file(name=None):
     # Determine expected file based on environment
     expected_file_base = get_relative_path(os.path.realpath(__file__), "expect")
@@ -766,8 +786,9 @@ def test_detection_model_trainable_backbone_layers(model_fn):
     model_name = model_fn.__name__
     max_trainable = _model_tests_values[model_name]["max_trainable"]
     n_trainable_params = []
-    for trainable_layers in range(0, max_trainable + 1):
-        model = model_fn(pretrained=False, pretrained_backbone=True, trainable_backbone_layers=trainable_layers)
+    with disable_weight_loading(model_fn):
+        for trainable_layers in range(0, max_trainable + 1):
+            model = model_fn(pretrained=False, pretrained_backbone=True, trainable_backbone_layers=trainable_layers)
 
         n_trainable_params.append(len([p for p in model.parameters() if p.requires_grad]))
     assert n_trainable_params == _model_tests_values[model_name]["n_trn_params_per_layer"]

From 67679ed9e9f385399ac7a676bdc8d7d8e29f37dd Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Fri, 5 Nov 2021 11:25:04 +0100
Subject: [PATCH 2/8] fix indent

---
 test/test_models.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/test_models.py b/test/test_models.py
index 66ec757bf05..8aef11eae05 100644
--- a/test/test_models.py
+++ b/test/test_models.py
@@ -790,7 +790,7 @@ def test_detection_model_trainable_backbone_layers(model_fn):
         for trainable_layers in range(0, max_trainable + 1):
             model = model_fn(pretrained=False, pretrained_backbone=True, trainable_backbone_layers=trainable_layers)
 
-        n_trainable_params.append(len([p for p in model.parameters() if p.requires_grad]))
+            n_trainable_params.append(len([p for p in model.parameters() if p.requires_grad]))
     assert n_trainable_params == _model_tests_values[model_name]["n_trn_params_per_layer"]
 
 

From 78329ba92892d347953195821d99e0c9fb94fa1e Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Fri, 5 Nov 2021 11:52:46 +0100
Subject: [PATCH 3/8] debug

---
 test/test_models.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/test/test_models.py b/test/test_models.py
index 8aef11eae05..369b3d1cd1e 100644
--- a/test/test_models.py
+++ b/test/test_models.py
@@ -786,11 +786,14 @@ def test_detection_model_trainable_backbone_layers(model_fn):
     model_name = model_fn.__name__
     max_trainable = _model_tests_values[model_name]["max_trainable"]
     n_trainable_params = []
-    with disable_weight_loading(model_fn):
+    with disable_weight_loading(model_fn) as (download_mock, load_mock):
         for trainable_layers in range(0, max_trainable + 1):
             model = model_fn(pretrained=False, pretrained_backbone=True, trainable_backbone_layers=trainable_layers)
 
             n_trainable_params.append(len([p for p in model.parameters() if p.requires_grad]))
+
+    download_mock.assert_called()
+    load_mock.assert_called()
     assert n_trainable_params == _model_tests_values[model_name]["n_trn_params_per_layer"]
 
 

From 94af62840b0af7db46e79886ca12ef109fdcc835 Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Fri, 5 Nov 2021 14:40:26 +0100
Subject: [PATCH 4/8] nuclear option

---
 test/test_models.py | 54 +++++++++++++++++++++++++++------------------
 1 file changed, 32 insertions(+), 22 deletions(-)

diff --git a/test/test_models.py b/test/test_models.py
index 369b3d1cd1e..1a9c86a0448 100644
--- a/test/test_models.py
+++ b/test/test_models.py
@@ -3,9 +3,9 @@
 import io
 import operator
 import os
+import pkgutil
 import sys
 import traceback
-import unittest.mock
 import warnings
 from collections import OrderedDict
 
@@ -25,22 +25,35 @@ def get_models_from_module(module):
     return [v for k, v in module.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"]
 
 
-@contextlib.contextmanager
-def disable_weight_loading(model_fn):
-    model_type = model_fn.__annotations__.get("return")
-    module_name = model_fn.__module__
-    module = sys.modules[module_name]
-
-    with contextlib.ExitStack() as stack:
-        function_name = "load_state_dict_from_url"
-        module_target = module_name if function_name in dir(module) else "torchvision._internally_replaced_utils"
-        download_mock = stack.enter_context(unittest.mock.patch(f"{module_target}.{function_name}"))
+@pytest.fixture
+def disable_weight_loading(mocker):
+    starting_point = models
+    python_module_names = {
+        info.name for info in pkgutil.walk_packages(starting_point.__path__, f"{starting_point.__name__}.")
+    }
 
-        method_name = "load_state_dict"
-        type_target = f"{model_type.__module__}.{model_type.__name__}" if model_type else "torch.nn.Module"
-        load_mock = stack.enter_context(unittest.mock.patch(f"{type_target}.{method_name}"))
+    function_name = "load_state_dict_from_url"
+    method_name = "load_state_dict"
+    targets = {f"torchvision._internally_replaced_utils.{function_name}", f"torch.nn.Module.{method_name}"}
+    for name in python_module_names:
+        python_module = sys.modules.get(name)
+        if not python_module:
+            continue
+
+        if function_name in python_module.__dict__:
+            targets.add(f"{python_module.__name__}.{function_name}")
+
+        targets.update(
+            {
+                f"{python_module.__name__}.{obj.__name__}.{method_name}"
+                for obj in python_module.__dict__.values()
+                if isinstance(obj, type) and issubclass(obj, nn.Module) and method_name in obj.__dict__
+            }
+        )
 
-        yield download_mock, load_mock
+    for target in targets:
+        with contextlib.suppress(AttributeError):
+            mocker.patch(target)
 
 
 def _get_expected_file(name=None):
@@ -782,18 +795,15 @@ def test_quantized_classification_model(model_fn):
 
 
 @pytest.mark.parametrize("model_fn", get_models_from_module(models.detection))
-def test_detection_model_trainable_backbone_layers(model_fn):
+def test_detection_model_trainable_backbone_layers(model_fn, disable_weight_loading):
     model_name = model_fn.__name__
     max_trainable = _model_tests_values[model_name]["max_trainable"]
     n_trainable_params = []
-    with disable_weight_loading(model_fn) as (download_mock, load_mock):
-        for trainable_layers in range(0, max_trainable + 1):
-            model = model_fn(pretrained=False, pretrained_backbone=True, trainable_backbone_layers=trainable_layers)
+    for trainable_layers in range(0, max_trainable + 1):
+        model = model_fn(pretrained=False, pretrained_backbone=True, trainable_backbone_layers=trainable_layers)
 
-            n_trainable_params.append(len([p for p in model.parameters() if p.requires_grad]))
+        n_trainable_params.append(len([p for p in model.parameters() if p.requires_grad]))
 
-    download_mock.assert_called()
-    load_mock.assert_called()
     assert n_trainable_params == _model_tests_values[model_name]["n_trn_params_per_layer"]
 
 

From 5742f3bfe7d7a32a7ca9a8f5ba20c3c4c070a644 Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Fri, 5 Nov 2021 14:45:51 +0100
Subject: [PATCH 5/8] revert unrelated change

---
 test/test_models.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/test/test_models.py b/test/test_models.py
index 1a9c86a0448..a5cc4c4d9c6 100644
--- a/test/test_models.py
+++ b/test/test_models.py
@@ -803,7 +803,6 @@ def test_detection_model_trainable_backbone_layers(model_fn, disable_weight_load
         model = model_fn(pretrained=False, pretrained_backbone=True, trainable_backbone_layers=trainable_layers)
 
         n_trainable_params.append(len([p for p in model.parameters() if p.requires_grad]))
-
     assert n_trainable_params == _model_tests_values[model_name]["n_trn_params_per_layer"]
 
 

From 1036ef2f356567e8920130da025aac7b32b788da Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Fri, 5 Nov 2021 15:18:25 +0100
Subject: [PATCH 6/8] cleanup

---
 test/test_models.py | 19 +++++++++----------
 1 file changed, 9 insertions(+), 10 deletions(-)

diff --git a/test/test_models.py b/test/test_models.py
index a5cc4c4d9c6..ab6a1cc7e3f 100644
--- a/test/test_models.py
+++ b/test/test_models.py
@@ -28,30 +28,29 @@ def get_models_from_module(module):
 @pytest.fixture
 def disable_weight_loading(mocker):
     starting_point = models
-    python_module_names = {
-        info.name for info in pkgutil.walk_packages(starting_point.__path__, f"{starting_point.__name__}.")
-    }
+    module_names = {info.name for info in pkgutil.walk_packages(starting_point.__path__, f"{starting_point.__name__}.")}
 
     function_name = "load_state_dict_from_url"
     method_name = "load_state_dict"
     targets = {f"torchvision._internally_replaced_utils.{function_name}", f"torch.nn.Module.{method_name}"}
-    for name in python_module_names:
-        python_module = sys.modules.get(name)
-        if not python_module:
+    for name in module_names:
+        module = sys.modules.get(name)
+        if not module:
             continue
 
-        if function_name in python_module.__dict__:
-            targets.add(f"{python_module.__name__}.{function_name}")
+        if function_name in module.__dict__:
+            targets.add(f"{module.__name__}.{function_name}")
 
         targets.update(
             {
-                f"{python_module.__name__}.{obj.__name__}.{method_name}"
-                for obj in python_module.__dict__.values()
+                f"{module.__name__}.{obj.__name__}.{method_name}"
+                for obj in module.__dict__.values()
                 if isinstance(obj, type) and issubclass(obj, nn.Module) and method_name in obj.__dict__
             }
         )
 
     for target in targets:
+        # See https://github.com/pytorch/vision/pull/4867#discussion_r743677802 for details
         with contextlib.suppress(AttributeError):
             mocker.patch(target)
 

From 5eca75cd0bdaf03a6fbc612db0c1dc5b9f136ee2 Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Fri, 5 Nov 2021 15:46:52 +0100
Subject: [PATCH 7/8] add explanation

---
 test/test_models.py | 19 +++++++++++++++++--
 1 file changed, 17 insertions(+), 2 deletions(-)

diff --git a/test/test_models.py b/test/test_models.py
index ab6a1cc7e3f..e109bf4197d 100644
--- a/test/test_models.py
+++ b/test/test_models.py
@@ -27,11 +27,26 @@ def get_models_from_module(module):
 
 @pytest.fixture
 def disable_weight_loading(mocker):
-    starting_point = models
-    module_names = {info.name for info in pkgutil.walk_packages(starting_point.__path__, f"{starting_point.__name__}.")}
+    """When testing models, the two slowest operations are the downloading of the weights to a file and loading them
+    into the model. Unless, you want to test against specific weights, these steps can be disabled without any
+    drawbacks.
+
+    Including this fixture into the signature of your test, i.e. `test_foo(disable_weight_loading)`, will recurse
+    through all modules in `torchvision.modules` and will patch all occurrences of the function
+    `download_state_dict_from_url` as well as the method `load_state_dict` on all subclasses of `nn.Module` to be
+    no-ops.
+
+    .. warning:
+
+        Loaded models are still executable as normal, but will always have random weights. Make sure to not use this
+        fixture if you want to compare the model output against reference values.
+    """
+    starting_point = models
 
     function_name = "load_state_dict_from_url"
     method_name = "load_state_dict"
+
+    module_names = {info.name for info in pkgutil.walk_packages(starting_point.__path__, f"{starting_point.__name__}.")}
     targets = {f"torchvision._internally_replaced_utils.{function_name}", f"torch.nn.Module.{method_name}"}
     for name in module_names:
         module = sys.modules.get(name)

From e9a4d8353731a2018e35625ab052080d2f77a2be Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Fri, 5 Nov 2021 15:48:18 +0100
Subject: [PATCH 8/8] typo

---
 test/test_models.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/test_models.py b/test/test_models.py
index e109bf4197d..150b813b0cb 100644
--- a/test/test_models.py
+++ b/test/test_models.py
@@ -32,7 +32,7 @@ def disable_weight_loading(mocker):
     drawbacks.
 
     Including this fixture into the signature of your test, i.e. `test_foo(disable_weight_loading)`, will recurse
-    through all modules in `torchvision.modules` and will patch all occurrences of the function
+    through all models in `torchvision.models` and will patch all occurrences of the function
     `download_state_dict_from_url` as well as the method `load_state_dict` on all subclasses of `nn.Module` to be
     no-ops.
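Usage note appended after the series (not part of any commit above): the final `disable_weight_loading` fixture is opt-in and is activated through pytest's regular fixture injection, simply by naming it as a test parameter. A minimal sketch, assuming the test is added to test/test_models.py next to the fixture, so `get_models_from_module` and the `models` import are already available; `test_detection_backbone_smoke` is a hypothetical name:

    @pytest.mark.parametrize("model_fn", get_models_from_module(models.detection))
    def test_detection_backbone_smoke(model_fn, disable_weight_loading):
        # Requesting the fixture patches torchvision._internally_replaced_utils.load_state_dict_from_url,
        # torch.nn.Module.load_state_dict, and the load_state_dict overrides found under torchvision.models
        # to no-ops, so asking for a pretrained backbone neither downloads nor loads any weights.
        model = model_fn(pretrained=False, pretrained_backbone=True)
        assert any(p.requires_grad for p in model.parameters())

The `mocker` argument used inside the fixture comes from the pytest-mock plugin, which automatically undoes all patches when each test finishes.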