@@ -1,7 +1,10 @@
+import contextlib
 import functools
 import io
 import operator
 import os
+import pkgutil
+import sys
 import traceback
 import warnings
 from collections import OrderedDict
@@ -14,7 +17,6 @@
 from common_utils import map_nested_tensor_object, freeze_rng_state, set_rng_seed, cpu_and_gpu, needs_cuda
 from torchvision import models
 
-
 ACCEPT = os.getenv("EXPECTTEST_ACCEPT", "0") == "1"
 
 
@@ -23,6 +25,51 @@ def get_models_from_module(module):
     return [v for k, v in module.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"]
 
 
+@pytest.fixture
+def disable_weight_loading(mocker):
+    """When testing models, the two slowest operations are downloading the weights to a file and loading them
+    into the model. Unless you want to test against specific weights, these steps can be disabled without any
+    drawbacks.
+
+    Including this fixture in the signature of your test, i.e. `test_foo(disable_weight_loading)`, will recurse
+    through all models in `torchvision.models` and will patch all occurrences of the function
+    `load_state_dict_from_url` as well as the method `load_state_dict` on all subclasses of `nn.Module` to be
+    no-ops.
+
+    .. warning::
+
+        Loaded models are still executable as normal, but will always have random weights. Make sure not to use this
+        fixture if you want to compare the model output against reference values.
+
+    """
+    starting_point = models
+    function_name = "load_state_dict_from_url"
+    method_name = "load_state_dict"
+
+    module_names = {info.name for info in pkgutil.walk_packages(starting_point.__path__, f"{starting_point.__name__}.")}
+    targets = {f"torchvision._internally_replaced_utils.{function_name}", f"torch.nn.Module.{method_name}"}
+    for name in module_names:
+        module = sys.modules.get(name)
+        if not module:
+            continue
+
+        if function_name in module.__dict__:
+            targets.add(f"{module.__name__}.{function_name}")
+
+        targets.update(
+            {
+                f"{module.__name__}.{obj.__name__}.{method_name}"
+                for obj in module.__dict__.values()
+                if isinstance(obj, type) and issubclass(obj, nn.Module) and method_name in obj.__dict__
+            }
+        )
+
+    for target in targets:
+        # See https://github.com/pytorch/vision/pull/4867#discussion_r743677802 for details
+        with contextlib.suppress(AttributeError):
+            mocker.patch(target)
+
+
 def _get_expected_file(name=None):
     # Determine expected file based on environment
     expected_file_base = get_relative_path(os.path.realpath(__file__), "expect")
@@ -762,7 +809,7 @@ def test_quantized_classification_model(model_fn):
 
 
 @pytest.mark.parametrize("model_fn", get_models_from_module(models.detection))
-def test_detection_model_trainable_backbone_layers(model_fn):
+def test_detection_model_trainable_backbone_layers(model_fn, disable_weight_loading):
     model_name = model_fn.__name__
     max_trainable = _model_tests_values[model_name]["max_trainable"]
     n_trainable_params = []
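
For reference, a minimal usage sketch of the fixture, not part of the change above. It assumes `disable_weight_loading` is visible to the test (same module or a `conftest.py`), that `pytest-mock` supplies the `mocker` fixture it relies on, and the `pretrained=True` flag used by torchvision at the time of this change; the test name and model choices are hypothetical.

import pytest
import torch

from torchvision import models


@pytest.mark.parametrize("model_fn", [models.resnet18, models.mobilenet_v2])
def test_classification_model_smoke(model_fn, disable_weight_loading):
    # The fixture patches load_state_dict_from_url and nn.Module.load_state_dict to no-ops,
    # so pretrained=True builds the model quickly with random weights instead of downloading.
    model = model_fn(pretrained=True).eval()
    out = model(torch.rand(1, 3, 224, 224))
    assert out.shape == (1, 1000)

Because the weights are random, such a test can only check shapes and executability, not output values.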