From 4d9d109151136d31737e68090bda13a8b1a437ba Mon Sep 17 00:00:00 2001 From: Tal Regev Date: Fri, 4 Mar 2022 16:51:54 +0200 Subject: [PATCH] Improve test of backbone utils --- test/test_backbone_utils.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/test/test_backbone_utils.py b/test/test_backbone_utils.py index a55929d4b36..ed9b52d0499 100644 --- a/test/test_backbone_utils.py +++ b/test/test_backbone_utils.py @@ -7,7 +7,7 @@ from common_utils import set_rng_seed from torchvision import models from torchvision.models._utils import IntermediateLayerGetter -from torchvision.models.detection.backbone_utils import mobilenet_backbone, resnet_fpn_backbone +from torchvision.models.detection.backbone_utils import BackboneWithFPN, mobilenet_backbone, resnet_fpn_backbone from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names @@ -19,7 +19,9 @@ def get_available_models(): @pytest.mark.parametrize("backbone_name", ("resnet18", "resnet50")) def test_resnet_fpn_backbone(backbone_name): x = torch.rand(1, 3, 300, 300, dtype=torch.float32, device="cpu") - y = resnet_fpn_backbone(backbone_name=backbone_name, pretrained=False)(x) + model = resnet_fpn_backbone(backbone_name=backbone_name, pretrained=False) + assert isinstance(model, BackboneWithFPN) + y = model(x) assert list(y.keys()) == ["0", "1", "2", "3", "pool"] with pytest.raises(ValueError, match=r"Trainable layers should be in the range"): @@ -38,6 +40,10 @@ def test_mobilenet_backbone(backbone_name): mobilenet_backbone(backbone_name, False, fpn=True, returned_layers=[-1, 0, 1, 2]) with pytest.raises(ValueError, match=r"Each returned layer should be in the range"): mobilenet_backbone(backbone_name, False, fpn=True, returned_layers=[3, 4, 5, 6]) + model_fpn = mobilenet_backbone(backbone_name, False, fpn=True) + assert isinstance(model_fpn, BackboneWithFPN) + model = mobilenet_backbone(backbone_name, False, fpn=False) + assert isinstance(model, torch.nn.Sequential) # Needed by TestFxFeatureExtraction.test_leaf_module_and_function