disable weight download and state dict loading for model tests #4867

Merged: 10 commits, merged on Nov 6, 2021
51 changes: 49 additions & 2 deletions test/test_models.py
@@ -1,7 +1,10 @@
import contextlib
import functools
import io
import operator
import os
import pkgutil
import sys
import traceback
import warnings
from collections import OrderedDict
@@ -14,7 +17,6 @@
from common_utils import map_nested_tensor_object, freeze_rng_state, set_rng_seed, cpu_and_gpu, needs_cuda
from torchvision import models


ACCEPT = os.getenv("EXPECTTEST_ACCEPT", "0") == "1"


@@ -23,6 +25,51 @@ def get_models_from_module(module):
    return [v for k, v in module.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"]
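
For illustration, a hypothetical check of what this helper selects (an assumption, not part of the diff): callables whose name starts with a lowercase, non-underscore character, i.e. factory functions such as resnet18, while classes like ResNet are skipped:

    # Hypothetical usage of get_models_from_module; only public factory
    # functions survive the lowercase-first-character filter.
    factories = get_models_from_module(models)
    assert any(f.__name__ == "resnet18" for f in factories)
    assert all(not f.__name__.startswith("_") for f in factories)
    assert not any(f.__name__ == "ResNet" for f in factories)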


@pytest.fixture
def disable_weight_loading(mocker):
"""When testing models, the two slowest operations are the downloading of the weights to a file and loading them
into the model. Unless, you want to test against specific weights, these steps can be disabled without any
drawbacks.

Including this fixture into the signature of your test, i.e. `test_foo(disable_weight_loading)`, will recurse
through all models in `torchvision.models` and will patch all occurrences of the function
`download_state_dict_from_url` as well as the method `load_state_dict` on all subclasses of `nn.Module` to be
no-ops.

.. warning:

Loaded models are still executable as normal, but will always have random weights. Make sure to not use this
fixture if you want to compare the model output against reference values.

"""
    starting_point = models
    function_name = "load_state_dict_from_url"
    method_name = "load_state_dict"

    module_names = {info.name for info in pkgutil.walk_packages(starting_point.__path__, f"{starting_point.__name__}.")}
    targets = {f"torchvision._internally_replaced_utils.{function_name}", f"torch.nn.Module.{method_name}"}
    for name in module_names:
        module = sys.modules.get(name)
        if not module:
            continue

        if function_name in module.__dict__:
            targets.add(f"{module.__name__}.{function_name}")

        targets.update(
            {
                f"{module.__name__}.{obj.__name__}.{method_name}"
                for obj in module.__dict__.values()
                if isinstance(obj, type) and issubclass(obj, nn.Module) and method_name in obj.__dict__
            }
        )

    for target in targets:
        # See https://github.com/pytorch/vision/pull/4867#discussion_r743677802 for details
        with contextlib.suppress(AttributeError):
            mocker.patch(target)


Inline review thread on the contextlib.suppress(AttributeError) workaround (discussion_r743677802), pmeier (Collaborator, Author), Nov 5, 2021:

There is (at least) one situation where this patching doesn't work. We have the module torchvision.models.alexnet, which defines the function alexnet; that function is then imported into the torchvision.models namespace. Thus, after the import is complete, we can no longer access the module through torchvision.models.alexnet.

This is an anti-pattern and should be avoided. Given that we expose everything under torchvision.models, I suggest making all the other modules private. This avoids the naming issue and gives us the freedom to change the underlying structure however we want without needing to keep BC.

Reply, Contributor:

Yep, it is what it is. I'm not sure many other models have the same issue, but I don't think it's worth breaking BC over it. AlexNet is in for historical reasons and is unlikely to be used as a backbone for any modern architecture.
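
A minimal sketch of the shadowing described in the thread above, using a hypothetical two-file package pkg (all names here are assumptions for illustration, not torchvision code):

    # pkg/__init__.py contains:  from .alexnet import alexnet
    # pkg/alexnet.py contains:   def alexnet(): ...
    import sys
    import pkg

    print(pkg.alexnet)                 # <function alexnet ...>: the function shadows the submodule
    print(sys.modules["pkg.alexnet"])  # <module 'pkg.alexnet' ...>: still reachable via sys.modules

This is why the fixture collects module names with pkgutil.walk_packages and looks them up in sys.modules rather than through attribute access.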


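For context, a minimal sketch of how a test opts in (a hypothetical test, not part of this diff; the mocker fixture comes from the pytest-mock plugin, and mocker.patch swaps each target for a MagicMock for the duration of the test):

    import torch

    def test_resnet18_random_weights(disable_weight_loading):
        # pretrained=True would normally download and load a state dict; with
        # the fixture active both steps are no-ops, so the model keeps its
        # random initialization.
        model = models.resnet18(pretrained=True)
        out = model(torch.rand(1, 3, 224, 224))
        assert out.shape == (1, 1000)
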
def _get_expected_file(name=None):
    # Determine expected file based on environment
    expected_file_base = get_relative_path(os.path.realpath(__file__), "expect")
@@ -762,7 +809,7 @@ def test_quantized_classification_model(model_fn):


@pytest.mark.parametrize("model_fn", get_models_from_module(models.detection))
-def test_detection_model_trainable_backbone_layers(model_fn):
+def test_detection_model_trainable_backbone_layers(model_fn, disable_weight_loading):
    model_name = model_fn.__name__
    max_trainable = _model_tests_values[model_name]["max_trainable"]
    n_trainable_params = []
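
For reference, a sketch of the failure mode behind the contextlib.suppress(AttributeError) in the fixture (behavior assumed as of this PR's torchvision: mock.patch resolves dotted targets via attribute access from the top-level package, so the alexnet function shadows the alexnet module):

    from unittest import mock

    # "torchvision.models.alexnet" resolves to the function alexnet rather
    # than the module, so resolving the final attribute to patch raises
    # AttributeError when the patcher starts.
    patcher = mock.patch("torchvision.models.alexnet.load_state_dict_from_url")
    try:
        patcher.start()
        patcher.stop()
    except AttributeError:
        print("function shadows module; the fixture suppresses this and moves on")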