Skip to content

Commit 24d7ea6

Browse files
committed
Add efficientnet_fpn_backbone
1 parent a784db4 commit 24d7ea6

File tree

2 files changed

+83
-4
lines changed

2 files changed

+83
-4
lines changed

test/test_backbone_utils.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,14 @@
66
import torch
77
from common_utils import set_rng_seed
88
from torchvision import models
9+
from torchvision.models import efficientnet, mobilenet, resnet
910
from torchvision.models._utils import IntermediateLayerGetter
10-
from torchvision.models.detection.backbone_utils import mobilenet_backbone, resnet_fpn_backbone
11+
from torchvision.models.detection.backbone_utils import (
12+
BackboneWithFPN,
13+
efficientnet_fpn_backbone,
14+
mobilenet_backbone,
15+
resnet_fpn_backbone,
16+
)
1117
from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names
1218

1319

@@ -16,7 +22,7 @@ def get_available_models():
1622
return [k for k, v in models.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"]
1723

1824

19-
@pytest.mark.parametrize("backbone_name", ("resnet18", "resnet50"))
25+
@pytest.mark.parametrize("backbone_name", resnet.__all__[1:])
2026
def test_resnet_fpn_backbone(backbone_name):
2127
x = torch.rand(1, 3, 300, 300, dtype=torch.float32, device="cpu")
2228
y = resnet_fpn_backbone(backbone_name=backbone_name, pretrained=False)(x)
@@ -28,16 +34,34 @@ def test_resnet_fpn_backbone(backbone_name):
2834
resnet_fpn_backbone(backbone_name, False, returned_layers=[0, 1, 2, 3])
2935
with pytest.raises(ValueError, match=r"Each returned layer should be in the range"):
3036
resnet_fpn_backbone(backbone_name, False, returned_layers=[2, 3, 4, 5])
37+
model = resnet_fpn_backbone(backbone_name, False)
38+
assert isinstance(model, BackboneWithFPN)
3139

3240

33-
@pytest.mark.parametrize("backbone_name", ("mobilenet_v2", "mobilenet_v3_large", "mobilenet_v3_small"))
41+
@pytest.mark.parametrize("backbone_name", mobilenet.mv2_all[1:] + mobilenet.mv3_all[1:])
3442
def test_mobilenet_backbone(backbone_name):
3543
with pytest.raises(ValueError, match=r"Trainable layers should be in the range"):
3644
mobilenet_backbone(backbone_name=backbone_name, pretrained=False, fpn=False, trainable_layers=-1)
3745
with pytest.raises(ValueError, match=r"Each returned layer should be in the range"):
3846
mobilenet_backbone(backbone_name, False, fpn=True, returned_layers=[-1, 0, 1, 2])
3947
with pytest.raises(ValueError, match=r"Each returned layer should be in the range"):
4048
mobilenet_backbone(backbone_name, False, fpn=True, returned_layers=[3, 4, 5, 6])
49+
model_fpn = mobilenet_backbone(backbone_name, False, fpn=True)
50+
assert isinstance(model_fpn, BackboneWithFPN)
51+
model = mobilenet_backbone(backbone_name, False, fpn=False)
52+
assert isinstance(model, torch.nn.Sequential)
53+
54+
55+
@pytest.mark.parametrize("backbone_name", efficientnet.__all__[1:])
56+
def test_efficientnet_fpn_backbone(backbone_name):
57+
with pytest.raises(ValueError, match=r"Trainable layers should be in the range"):
58+
efficientnet_fpn_backbone(backbone_name=backbone_name, pretrained=False, trainable_layers=-1)
59+
with pytest.raises(ValueError, match=r"Each returned layer should be in the range"):
60+
efficientnet_fpn_backbone(backbone_name, False, returned_layers=[-1, 0, 1, 2])
61+
with pytest.raises(ValueError, match=r"Each returned layer should be in the range"):
62+
efficientnet_fpn_backbone(backbone_name, False, returned_layers=[3, 4, 5, 6, 9])
63+
model = efficientnet_fpn_backbone(backbone_name, False)
64+
assert isinstance(model, BackboneWithFPN)
4165

4266

4367
# Needed by TestFxFeatureExtraction.test_leaf_module_and_function

torchvision/models/detection/backbone_utils.py

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from torchvision.ops import misc as misc_nn_ops
66
from torchvision.ops.feature_pyramid_network import ExtraFPNBlock, FeaturePyramidNetwork, LastLevelMaxPool
77

8-
from .. import mobilenet, resnet
8+
from .. import efficientnet, mobilenet, resnet
99
from .._utils import IntermediateLayerGetter
1010

1111

@@ -216,3 +216,58 @@ def _mobilenet_extractor(
216216
)
217217
m.out_channels = out_channels # type: ignore[assignment]
218218
return m
219+
220+
221+
def efficientnet_fpn_backbone(
222+
backbone_name: str,
223+
pretrained: bool,
224+
norm_layer: Callable[..., nn.Module] = misc_nn_ops.FrozenBatchNorm2d,
225+
trainable_layers: int = 2,
226+
returned_layers: Optional[List[int]] = None,
227+
extra_blocks: ExtraFPNBlock = LastLevelMaxPool(),
228+
) -> nn.Module:
229+
if backbone_name in [
230+
"efficientnet_b5",
231+
"efficientnet_b6",
232+
"efficientnet_b7",
233+
"efficientnet_v2_s",
234+
"efficientnet_v2_m",
235+
"efficientnet_v2_l",
236+
]:
237+
backbone = efficientnet.__dict__[backbone_name](pretrained=pretrained)
238+
else:
239+
backbone = efficientnet.__dict__[backbone_name](pretrained=pretrained, norm_layer=norm_layer)
240+
return _efficientnet_extractor(backbone, trainable_layers, returned_layers, extra_blocks)
241+
242+
243+
def _efficientnet_extractor(
244+
backbone: efficientnet.EfficientNet,
245+
trainable_layers: int,
246+
returned_layers: Optional[List[int]] = None,
247+
extra_blocks: ExtraFPNBlock = LastLevelMaxPool(),
248+
) -> nn.Module:
249+
backbone = backbone.features
250+
# Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
251+
# The first and last blocks are always included because they are the C0 (conv1) and Cn.
252+
stage_indices = [i for i, b in enumerate(backbone) if getattr(b[0], "out_channels", False)]
253+
num_stages = len(stage_indices)
254+
255+
# find the index of the layer from which we wont freeze
256+
if trainable_layers < 0 or trainable_layers > num_stages:
257+
raise ValueError(f"Trainable layers should be in the range [0,{num_stages}], got {trainable_layers} ")
258+
freeze_before = len(backbone) if trainable_layers == 0 else stage_indices[num_stages - trainable_layers]
259+
260+
for b in backbone[:freeze_before]:
261+
for parameter in b.parameters():
262+
parameter.requires_grad_(False)
263+
264+
out_channels = 256
265+
266+
if returned_layers is None:
267+
returned_layers = [num_stages - 2, num_stages - 1]
268+
if min(returned_layers) < 0 or max(returned_layers) >= num_stages:
269+
raise ValueError(f"Each returned layer should be in the range [0,{num_stages - 1}], got {returned_layers} ")
270+
return_layers = {f"{stage_indices[k]}": str(v) for v, k in enumerate(returned_layers)}
271+
272+
in_channels_list = [backbone[stage_indices[i]][0].out_channels for i in returned_layers]
273+
return BackboneWithFPN(backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks)

0 commit comments

Comments
 (0)