
Commit 5bf1da1

kazhang authored and facebook-github-bot committed
[fbsync] disable weight download and state dict loading for model tests (#4867)
Summary:
* disable weight download and state dict loading for model tests
* fix indent
* debug
* nuclear option
* revert unrelated change
* cleanup
* add explanation
* typo

Reviewed By: datumbox

Differential Revision: D32298972

fbshipit-source-id: 6c73192e0de442a691a993f2a4782811c5db2c32
1 parent 0a9ff12 commit 5bf1da1
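
For context, here is a hand-rolled sketch (not part of the commit) of what the new fixture automates for a single model family. It assumes pytest-mock is installed and that torchvision's resnet module still imports `load_state_dict_from_url` at module level, as shown in this commit; the test name is hypothetical.

from torchvision import models


def test_resnet50_without_weight_io(mocker):
    # Patch the module-level reference resnet.py holds to the download helper, plus the
    # generic state-dict loader, so pretrained=True triggers neither network nor disk I/O.
    mocker.patch("torchvision.models.resnet.load_state_dict_from_url")
    mocker.patch("torch.nn.Module.load_state_dict")
    model = models.resnet50(pretrained=True)  # built with random weights
    model.eval()

The fixture introduced below generalizes this idea by walking every submodule of torchvision.models and collecting the patch targets automatically.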

File tree

1 file changed: +49 −2 lines changed

test/test_models.py

Lines changed: 49 additions & 2 deletions
@@ -1,7 +1,10 @@
+import contextlib
 import functools
 import io
 import operator
 import os
+import pkgutil
+import sys
 import traceback
 import warnings
 from collections import OrderedDict
@@ -14,7 +17,6 @@
 from common_utils import map_nested_tensor_object, freeze_rng_state, set_rng_seed, cpu_and_gpu, needs_cuda
 from torchvision import models
 
-
 ACCEPT = os.getenv("EXPECTTEST_ACCEPT", "0") == "1"
 
 
@@ -23,6 +25,51 @@ def get_models_from_module(module):
     return [v for k, v in module.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"]
 
 
+@pytest.fixture
+def disable_weight_loading(mocker):
+    """When testing models, the two slowest operations are downloading the weights to a file and loading them
+    into the model. Unless you want to test against specific weights, these steps can be disabled without any
+    drawbacks.
+
+    Including this fixture in the signature of your test, i.e. `test_foo(disable_weight_loading)`, will recurse
+    through all models in `torchvision.models` and will patch all occurrences of the function
+    `load_state_dict_from_url` as well as the method `load_state_dict` on all subclasses of `nn.Module` to be
+    no-ops.
+
+    .. warning::
+
+        Loaded models are still executable as normal, but will always have random weights. Make sure not to use
+        this fixture if you want to compare the model output against reference values.
+
+    """
+    starting_point = models
+    function_name = "load_state_dict_from_url"
+    method_name = "load_state_dict"
+
+    module_names = {info.name for info in pkgutil.walk_packages(starting_point.__path__, f"{starting_point.__name__}.")}
+    targets = {f"torchvision._internally_replaced_utils.{function_name}", f"torch.nn.Module.{method_name}"}
+    for name in module_names:
+        module = sys.modules.get(name)
+        if not module:
+            continue
+
+        if function_name in module.__dict__:
+            targets.add(f"{module.__name__}.{function_name}")
+
+        targets.update(
+            {
+                f"{module.__name__}.{obj.__name__}.{method_name}"
+                for obj in module.__dict__.values()
+                if isinstance(obj, type) and issubclass(obj, nn.Module) and method_name in obj.__dict__
+            }
+        )
+
+    for target in targets:
+        # See https://github.com/pytorch/vision/pull/4867#discussion_r743677802 for details
+        with contextlib.suppress(AttributeError):
+            mocker.patch(target)
+
+
 def _get_expected_file(name=None):
     # Determine expected file based on environment
     expected_file_base = get_relative_path(os.path.realpath(__file__), "expect")
@@ -762,7 +809,7 @@ def test_quantized_classification_model(model_fn):
 
 
 @pytest.mark.parametrize("model_fn", get_models_from_module(models.detection))
-def test_detection_model_trainable_backbone_layers(model_fn):
+def test_detection_model_trainable_backbone_layers(model_fn, disable_weight_loading):
     model_name = model_fn.__name__
     max_trainable = _model_tests_values[model_name]["max_trainable"]
     n_trainable_params = []
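
For illustration only (not part of the commit), a minimal sketch of how another test in test/test_models.py could opt in: `test_smoke` is a hypothetical name, the file's existing imports (`pytest`, `models`, `get_models_from_module`) are assumed to be in scope, and pytest-mock must be installed because the fixture requests `mocker`.

@pytest.mark.parametrize("model_fn", get_models_from_module(models))
def test_smoke(model_fn, disable_weight_loading):
    # With the fixture active, pretrained=True neither downloads nor loads weights,
    # so every builder is exercised quickly with random weights.
    model = model_fn(pretrained=True)
    model.eval()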

0 commit comments