-
Notifications
You must be signed in to change notification settings - Fork 7.1k
disable weight download and state dict loading for model tests #4867
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
f1e03dd
67679ed
78329ba
94af628
5742f3b
70976a1
1036ef2
5eca75c
e9a4d83
b6e6569
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,10 @@ | ||
import contextlib | ||
import functools | ||
import io | ||
import operator | ||
import os | ||
import pkgutil | ||
import sys | ||
import traceback | ||
import warnings | ||
from collections import OrderedDict | ||
|
@@ -14,7 +17,6 @@ | |
from common_utils import map_nested_tensor_object, freeze_rng_state, set_rng_seed, cpu_and_gpu, needs_cuda | ||
from torchvision import models | ||
|
||
|
||
ACCEPT = os.getenv("EXPECTTEST_ACCEPT", "0") == "1" | ||
|
||
|
||
|
@@ -23,6 +25,51 @@ def get_models_from_module(module): | |
return [v for k, v in module.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"] | ||
|
||
|
||
@pytest.fixture | ||
def disable_weight_loading(mocker): | ||
"""When testing models, the two slowest operations are the downloading of the weights to a file and loading them | ||
into the model. Unless, you want to test against specific weights, these steps can be disabled without any | ||
drawbacks. | ||
|
||
Including this fixture into the signature of your test, i.e. `test_foo(disable_weight_loading)`, will recurse | ||
through all models in `torchvision.models` and will patch all occurrences of the function | ||
`download_state_dict_from_url` as well as the method `load_state_dict` on all subclasses of `nn.Module` to be | ||
no-ops. | ||
|
||
.. warning: | ||
|
||
Loaded models are still executable as normal, but will always have random weights. Make sure to not use this | ||
fixture if you want to compare the model output against reference values. | ||
|
||
""" | ||
starting_point = models | ||
function_name = "load_state_dict_from_url" | ||
method_name = "load_state_dict" | ||
|
||
module_names = {info.name for info in pkgutil.walk_packages(starting_point.__path__, f"{starting_point.__name__}.")} | ||
targets = {f"torchvision._internally_replaced_utils.{function_name}", f"torch.nn.Module.{method_name}"} | ||
for name in module_names: | ||
module = sys.modules.get(name) | ||
if not module: | ||
continue | ||
|
||
if function_name in module.__dict__: | ||
targets.add(f"{module.__name__}.{function_name}") | ||
|
||
targets.update( | ||
{ | ||
f"{module.__name__}.{obj.__name__}.{method_name}" | ||
for obj in module.__dict__.values() | ||
if isinstance(obj, type) and issubclass(obj, nn.Module) and method_name in obj.__dict__ | ||
} | ||
) | ||
|
||
for target in targets: | ||
# See https://github.com/pytorch/vision/pull/4867#discussion_r743677802 for details | ||
with contextlib.suppress(AttributeError): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There is (atleast) one situation where this patching doesn't work. We have the module This is an anti-pattern and should be avoided. Given that we expose everything under There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeap, it is what it is. Not sure many other models have the same issue but I don't think it's worth breaking BC over it. Alexnet is in for historical reasons and unlikely to be used as a backbone for any modern architecture. |
||
mocker.patch(target) | ||
|
||
|
||
def _get_expected_file(name=None): | ||
# Determine expected file based on environment | ||
expected_file_base = get_relative_path(os.path.realpath(__file__), "expect") | ||
|
@@ -762,7 +809,7 @@ def test_quantized_classification_model(model_fn): | |
|
||
|
||
@pytest.mark.parametrize("model_fn", get_models_from_module(models.detection)) | ||
def test_detection_model_trainable_backbone_layers(model_fn): | ||
def test_detection_model_trainable_backbone_layers(model_fn, disable_weight_loading): | ||
model_name = model_fn.__name__ | ||
max_trainable = _model_tests_values[model_name]["max_trainable"] | ||
n_trainable_params = [] | ||
|
Uh oh!
There was an error while loading. Please reload this page.