Skip to content

Commit 79ba631

Browse files
committed
Improve test of backbone utils
1 parent a8bde78 commit 79ba631

File tree

1 file changed

+9
-3
lines changed

1 file changed

+9
-3
lines changed

test/test_backbone_utils.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from common_utils import set_rng_seed
88
from torchvision import models
99
from torchvision.models._utils import IntermediateLayerGetter
10-
from torchvision.models.detection.backbone_utils import mobilenet_backbone, resnet_fpn_backbone
10+
from torchvision.models.detection.backbone_utils import BackboneWithFPN, mobilenet_backbone, resnet_fpn_backbone
1111
from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names
1212

1313

@@ -16,10 +16,12 @@ def get_available_models():
1616
return [k for k, v in models.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"]
1717

1818

19-
@pytest.mark.parametrize("backbone_name", ("resnet18", "resnet50"))
19+
@pytest.mark.parametrize("backbone_name", ("resnet18", "resnet50", "resnet101"))
2020
def test_resnet_fpn_backbone(backbone_name):
2121
x = torch.rand(1, 3, 300, 300, dtype=torch.float32, device="cpu")
22-
y = resnet_fpn_backbone(backbone_name=backbone_name, pretrained=False)(x)
22+
model = resnet_fpn_backbone(backbone_name=backbone_name, pretrained=False)
23+
assert isinstance(model, BackboneWithFPN)
24+
y = model(x)
2325
assert list(y.keys()) == ["0", "1", "2", "3", "pool"]
2426

2527
with pytest.raises(ValueError, match=r"Trainable layers should be in the range"):
@@ -38,6 +40,10 @@ def test_mobilenet_backbone(backbone_name):
3840
mobilenet_backbone(backbone_name, False, fpn=True, returned_layers=[-1, 0, 1, 2])
3941
with pytest.raises(ValueError, match=r"Each returned layer should be in the range"):
4042
mobilenet_backbone(backbone_name, False, fpn=True, returned_layers=[3, 4, 5, 6])
43+
model_fpn = mobilenet_backbone(backbone_name, False, fpn=True)
44+
assert isinstance(model_fpn, BackboneWithFPN)
45+
model = mobilenet_backbone(backbone_name, False, fpn=False)
46+
assert isinstance(model, torch.nn.Sequential)
4147

4248

4349
# Needed by TestFxFeatureExtraction.test_leaf_module_and_function

0 commit comments

Comments
 (0)