@@ -267,6 +267,43 @@ def _check_input_backprop(model, inputs):
267
267
}
268
268
269
269
270
+ # The following contains configuration and expected values to be used tests that are model specific
271
+ _model_tests_values = {
272
+ "retinanet_resnet50_fpn" : {
273
+ "max_trainable" : 5 ,
274
+ "n_trn_params_per_layer" : [36 , 46 , 65 , 78 , 88 , 89 ],
275
+ },
276
+ "keypointrcnn_resnet50_fpn" : {
277
+ "max_trainable" : 5 ,
278
+ "n_trn_params_per_layer" : [48 , 58 , 77 , 90 , 100 , 101 ],
279
+ },
280
+ "fasterrcnn_resnet50_fpn" : {
281
+ "max_trainable" : 5 ,
282
+ "n_trn_params_per_layer" : [30 , 40 , 59 , 72 , 82 , 83 ],
283
+ },
284
+ "maskrcnn_resnet50_fpn" : {
285
+ "max_trainable" : 5 ,
286
+ "n_trn_params_per_layer" : [42 , 52 , 71 , 84 , 94 , 95 ],
287
+ },
288
+ "fasterrcnn_mobilenet_v3_large_fpn" : {
289
+ "max_trainable" : 6 ,
290
+ "n_trn_params_per_layer" : [22 , 23 , 44 , 70 , 91 , 97 , 100 ],
291
+ },
292
+ "fasterrcnn_mobilenet_v3_large_320_fpn" : {
293
+ "max_trainable" : 6 ,
294
+ "n_trn_params_per_layer" : [22 , 23 , 44 , 70 , 91 , 97 , 100 ],
295
+ },
296
+ "ssd300_vgg16" : {
297
+ "max_trainable" : 5 ,
298
+ "n_trn_params_per_layer" : [45 , 51 , 57 , 63 , 67 , 71 ],
299
+ },
300
+ "ssdlite320_mobilenet_v3_large" : {
301
+ "max_trainable" : 6 ,
302
+ "n_trn_params_per_layer" : [96 , 99 , 138 , 200 , 239 , 257 , 266 ],
303
+ },
304
+ }
305
+
306
+
270
307
def _make_sliced_model (model , stop_layer ):
271
308
layers = OrderedDict ()
272
309
for name , layer in model .named_children ():
@@ -740,5 +777,18 @@ def test_quantized_classification_model(model_name):
740
777
raise AssertionError (f"model cannot be scripted. Traceback = { str (tb )} " ) from e
741
778
742
779
780
+ @pytest .mark .parametrize ("model_name" , get_available_detection_models ())
781
+ def test_detection_model_trainable_backbone_layers (model_name ):
782
+ max_trainable = _model_tests_values [model_name ]["max_trainable" ]
783
+ n_trainable_params = []
784
+ for trainable_layers in range (0 , max_trainable + 1 ):
785
+ model = torchvision .models .detection .__dict__ [model_name ](
786
+ pretrained = False , pretrained_backbone = True , trainable_backbone_layers = trainable_layers
787
+ )
788
+
789
+ n_trainable_params .append (len ([p for p in model .parameters () if p .requires_grad ]))
790
+ assert n_trainable_params == _model_tests_values [model_name ]["n_trn_params_per_layer" ]
791
+
792
+
743
793
if __name__ == "__main__" :
744
794
pytest .main ([__file__ ])
0 commit comments