Skip to content

Commit 6530546

Browse files
jdsgomesdatumbox
andauthored
adding test for trainable paramters in detection models (#4632)
* adding test for trainable paramters in detection models * modifying range of trainable layers Co-authored-by: Vasilis Vryniotis <[email protected]>
1 parent 4628481 commit 6530546

File tree

1 file changed

+50
-0
lines changed

1 file changed

+50
-0
lines changed

test/test_models.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,43 @@ def _check_input_backprop(model, inputs):
267267
}
268268

269269

270+
# The following contains configuration and expected values to be used tests that are model specific
271+
_model_tests_values = {
272+
"retinanet_resnet50_fpn": {
273+
"max_trainable": 5,
274+
"n_trn_params_per_layer": [36, 46, 65, 78, 88, 89],
275+
},
276+
"keypointrcnn_resnet50_fpn": {
277+
"max_trainable": 5,
278+
"n_trn_params_per_layer": [48, 58, 77, 90, 100, 101],
279+
},
280+
"fasterrcnn_resnet50_fpn": {
281+
"max_trainable": 5,
282+
"n_trn_params_per_layer": [30, 40, 59, 72, 82, 83],
283+
},
284+
"maskrcnn_resnet50_fpn": {
285+
"max_trainable": 5,
286+
"n_trn_params_per_layer": [42, 52, 71, 84, 94, 95],
287+
},
288+
"fasterrcnn_mobilenet_v3_large_fpn": {
289+
"max_trainable": 6,
290+
"n_trn_params_per_layer": [22, 23, 44, 70, 91, 97, 100],
291+
},
292+
"fasterrcnn_mobilenet_v3_large_320_fpn": {
293+
"max_trainable": 6,
294+
"n_trn_params_per_layer": [22, 23, 44, 70, 91, 97, 100],
295+
},
296+
"ssd300_vgg16": {
297+
"max_trainable": 5,
298+
"n_trn_params_per_layer": [45, 51, 57, 63, 67, 71],
299+
},
300+
"ssdlite320_mobilenet_v3_large": {
301+
"max_trainable": 6,
302+
"n_trn_params_per_layer": [96, 99, 138, 200, 239, 257, 266],
303+
},
304+
}
305+
306+
270307
def _make_sliced_model(model, stop_layer):
271308
layers = OrderedDict()
272309
for name, layer in model.named_children():
@@ -740,5 +777,18 @@ def test_quantized_classification_model(model_name):
740777
raise AssertionError(f"model cannot be scripted. Traceback = {str(tb)}") from e
741778

742779

780+
@pytest.mark.parametrize("model_name", get_available_detection_models())
781+
def test_detection_model_trainable_backbone_layers(model_name):
782+
max_trainable = _model_tests_values[model_name]["max_trainable"]
783+
n_trainable_params = []
784+
for trainable_layers in range(0, max_trainable + 1):
785+
model = torchvision.models.detection.__dict__[model_name](
786+
pretrained=False, pretrained_backbone=True, trainable_backbone_layers=trainable_layers
787+
)
788+
789+
n_trainable_params.append(len([p for p in model.parameters() if p.requires_grad]))
790+
assert n_trainable_params == _model_tests_values[model_name]["n_trn_params_per_layer"]
791+
792+
743793
if __name__ == "__main__":
744794
pytest.main([__file__])

0 commit comments

Comments
 (0)