
Refactor the backbone builders of detection #4656


Merged: 8 commits, Oct 20, 2021
Binary file modified test/expect/ModelTester.test_retinanet_resnet50_fpn_expect.pkl
Binary file not shown.
21 changes: 15 additions & 6 deletions torchvision/models/detection/backbone_utils.py
@@ -1,5 +1,5 @@
import warnings
from typing import Callable, Dict, Optional, List
from typing import Callable, Dict, Optional, List, Union

from torch import nn, Tensor
from torchvision.ops import misc as misc_nn_ops
@@ -100,14 +100,14 @@ def resnet_fpn_backbone(
default a ``LastLevelMaxPool`` is used.
"""
backbone = resnet.__dict__[backbone_name](pretrained=pretrained, norm_layer=norm_layer)
return _resnet_backbone_config(backbone, trainable_layers, returned_layers, extra_blocks)
return _resnet_fpn_extractor(backbone, trainable_layers, returned_layers, extra_blocks)


def _resnet_backbone_config(
def _resnet_fpn_extractor(
Contributor Author: Rename for consistency.

backbone: resnet.ResNet,
trainable_layers: int,
returned_layers: Optional[List[int]],
extra_blocks: Optional[ExtraFPNBlock],
returned_layers: Optional[List[int]] = None,
extra_blocks: Optional[ExtraFPNBlock] = None,
) -> BackboneWithFPN:

# select layers that wont be frozen
@@ -165,9 +165,18 @@ def mobilenet_backbone(
returned_layers: Optional[List[int]] = None,
extra_blocks: Optional[ExtraFPNBlock] = None,
) -> nn.Module:
backbone = mobilenet.__dict__[backbone_name](pretrained=pretrained, norm_layer=norm_layer)
return _mobilenet_extractor(backbone, fpn, trainable_layers, returned_layers, extra_blocks)

backbone = mobilenet.__dict__[backbone_name](pretrained=pretrained, norm_layer=norm_layer).features

def _mobilenet_extractor(
backbone: Union[mobilenet.MobileNetV2, mobilenet.MobileNetV3],
Contributor Author: Pass a full pre-initialized MobileNetV* backbone instead of a backbone_name.

fpn: bool,
trainable_layers,
returned_layers: Optional[List[int]] = None,
extra_blocks: Optional[ExtraFPNBlock] = None,
) -> nn.Module:
backbone = backbone.features
# Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
# The first and last blocks are always included because they are the C0 (conv1) and Cn.
stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + [len(backbone) - 1]
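
To make the stage-selection logic above concrete, here is a minimal, hedged sketch (not the torchvision code verbatim; the `trainable_layers` value is an arbitrary assumption) of how blocks flagged with `_is_cn` become stage boundaries and how the non-trainable stages get frozen:

```python
from torchvision.models import mobilenet_v3_large

# Plain classification backbone; only its feature extractor is used.
backbone = mobilenet_v3_large(pretrained=False).features

# Blocks flagged with `_is_cn` start a new stage (stride > 1); the first and
# last blocks are always kept as boundaries (C0/conv1 and Cn).
stage_indices = (
    [0]
    + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)]
    + [len(backbone) - 1]
)
num_stages = len(stage_indices)

# Freeze everything before the last `trainable_layers` stages (assumed value).
trainable_layers = 3
freeze_before = len(backbone) if trainable_layers == 0 else stage_indices[num_stages - trainable_layers]
for b in backbone[:freeze_before]:
    for parameter in b.parameters():
        parameter.requires_grad_(False)

print(stage_indices, freeze_before)
```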
15 changes: 11 additions & 4 deletions torchvision/models/detection/faster_rcnn.py
@@ -3,9 +3,12 @@
from torchvision.ops import MultiScaleRoIAlign

from ..._internally_replaced_utils import load_state_dict_from_url
from ...ops import misc as misc_nn_ops
from ..mobilenetv3 import mobilenet_v3_large
from ..resnet import resnet50
from ._utils import overwrite_eps
from .anchor_utils import AnchorGenerator
from .backbone_utils import resnet_fpn_backbone, _validate_trainable_layers, mobilenet_backbone
from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers, _mobilenet_extractor
from .generalized_rcnn import GeneralizedRCNN
from .roi_heads import RoIHeads
from .rpn import RPNHead, RegionProposalNetwork
@@ -385,7 +388,9 @@ def fasterrcnn_resnet50_fpn(
if pretrained:
# no need to download the backbone if pretrained is set
pretrained_backbone = False
backbone = resnet_fpn_backbone("resnet50", pretrained_backbone, trainable_layers=trainable_backbone_layers)

backbone = resnet50(pretrained=pretrained_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d)
Contributor Author: The pretrained backbone is always initialized beforehand in the builder.

backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
Contributor Author: Then passed to the *_extractor method which builds the proper backbone.

model = FasterRCNN(backbone, num_classes, **kwargs)
if pretrained:
state_dict = load_state_dict_from_url(model_urls["fasterrcnn_resnet50_fpn_coco"], progress=progress)
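
Putting the two comments above together, the new construction pattern looks roughly like the following sketch (a hedged illustration, not the exact builder; the trainable-layer count and input size are arbitrary assumptions):

```python
import torch
from torchvision.models.detection.backbone_utils import _resnet_fpn_extractor
from torchvision.models.detection.faster_rcnn import FasterRCNN
from torchvision.models.resnet import resnet50
from torchvision.ops import misc as misc_nn_ops

# 1. Build the classification backbone first (randomly initialized here),
#    with frozen batch norm as is customary for detection.
backbone = resnet50(pretrained=False, norm_layer=misc_nn_ops.FrozenBatchNorm2d)

# 2. The extractor wraps it with an FPN and freezes all but the last 3 stages.
backbone = _resnet_fpn_extractor(backbone, trainable_layers=3)

# 3. The detector consumes the ready-made backbone.
model = FasterRCNN(backbone, num_classes=91)
model.eval()
with torch.no_grad():
    predictions = model([torch.rand(3, 320, 320)])
```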
@@ -409,9 +414,11 @@ def _fasterrcnn_mobilenet_v3_large_fpn(

if pretrained:
pretrained_backbone = False
backbone = mobilenet_backbone(
"mobilenet_v3_large", pretrained_backbone, True, trainable_layers=trainable_backbone_layers

backbone = mobilenet_v3_large(
pretrained=pretrained_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d
)
backbone = _mobilenet_extractor(backbone, True, trainable_backbone_layers)

anchor_sizes = (
(
8 changes: 6 additions & 2 deletions torchvision/models/detection/keypoint_rcnn.py
@@ -3,8 +3,10 @@
from torchvision.ops import MultiScaleRoIAlign

from ..._internally_replaced_utils import load_state_dict_from_url
from ...ops import misc as misc_nn_ops
from ..resnet import resnet50
from ._utils import overwrite_eps
from .backbone_utils import resnet_fpn_backbone, _validate_trainable_layers
from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
from .faster_rcnn import FasterRCNN


@@ -367,7 +369,9 @@ def keypointrcnn_resnet50_fpn(
if pretrained:
# no need to download the backbone if pretrained is set
pretrained_backbone = False
backbone = resnet_fpn_backbone("resnet50", pretrained_backbone, trainable_layers=trainable_backbone_layers)

backbone = resnet50(pretrained=pretrained_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d)
backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
model = KeypointRCNN(backbone, num_classes, num_keypoints=num_keypoints, **kwargs)
if pretrained:
key = "keypointrcnn_resnet50_fpn_coco"
8 changes: 6 additions & 2 deletions torchvision/models/detection/mask_rcnn.py
@@ -4,8 +4,10 @@
from torchvision.ops import MultiScaleRoIAlign

from ..._internally_replaced_utils import load_state_dict_from_url
from ...ops import misc as misc_nn_ops
from ..resnet import resnet50
from ._utils import overwrite_eps
from .backbone_utils import resnet_fpn_backbone, _validate_trainable_layers
from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
from .faster_rcnn import FasterRCNN

__all__ = [
@@ -364,7 +366,9 @@ def maskrcnn_resnet50_fpn(
if pretrained:
# no need to download the backbone if pretrained is set
pretrained_backbone = False
backbone = resnet_fpn_backbone("resnet50", pretrained_backbone, trainable_layers=trainable_backbone_layers)

backbone = resnet50(pretrained=pretrained_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d)
backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
model = MaskRCNN(backbone, num_classes, **kwargs)
if pretrained:
state_dict = load_state_dict_from_url(model_urls["maskrcnn_resnet50_fpn_coco"], progress=progress)
14 changes: 7 additions & 7 deletions torchvision/models/detection/retinanet.py
@@ -9,11 +9,13 @@
from ..._internally_replaced_utils import load_state_dict_from_url
from ...ops import sigmoid_focal_loss
from ...ops import boxes as box_ops
from ...ops import misc as misc_nn_ops
from ...ops.feature_pyramid_network import LastLevelP6P7
from ..resnet import resnet50
from . import _utils as det_utils
from ._utils import overwrite_eps
from .anchor_utils import AnchorGenerator
from .backbone_utils import resnet_fpn_backbone, _validate_trainable_layers
from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
from .transform import GeneralizedRCNNTransform


@@ -630,13 +632,11 @@ def retinanet_resnet50_fpn(
if pretrained:
# no need to download the backbone if pretrained is set
pretrained_backbone = False

backbone = resnet50(pretrained=pretrained_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d)
# skip P2 because it generates too many anchors (according to their paper)
backbone = resnet_fpn_backbone(
"resnet50",
pretrained_backbone,
returned_layers=[2, 3, 4],
extra_blocks=LastLevelP6P7(256, 256),
trainable_layers=trainable_backbone_layers,
backbone = _resnet_fpn_extractor(
backbone, trainable_backbone_layers, returned_layers=[2, 3, 4], extra_blocks=LastLevelP6P7(256, 256)
Contributor Author: As commented earlier, because the initialization of the backbone now happens before that of the LastLevelP6P7, the expected values of the test need to change. I tried moving the initialization earlier and it works with the previous expected file, so all good.

)
model = RetinaNet(backbone, num_classes, **kwargs)
if pretrained:
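
The test-expectation change is a consequence of PyTorch's global RNG: constructing randomly initialized modules in a different order consumes the RNG stream differently, so weights differ even under the same seed. A small, self-contained sketch (unrelated to the torchvision internals) illustrating the effect:

```python
import torch
from torch import nn

def build(backbone_first: bool):
    torch.manual_seed(0)
    if backbone_first:
        trunk = nn.Conv2d(3, 8, 3)   # stands in for the ResNet trunk
        extra = nn.Conv2d(8, 8, 3)   # stands in for the LastLevelP6P7-style extra block
    else:
        extra = nn.Conv2d(8, 8, 3)
        trunk = nn.Conv2d(3, 8, 3)
    return trunk, extra

t1, _ = build(backbone_first=True)
t2, _ = build(backbone_first=False)
print(torch.equal(t1.weight, t2.weight))  # False: same seed, different construction order
```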
25 changes: 11 additions & 14 deletions torchvision/models/detection/ssd.py
@@ -23,7 +23,8 @@
backbone_urls = {
# We port the features of a VGG16 backbone trained by amdegroot because unlike the one on TorchVision, it uses the
# same input standardization method as the paper. Ref: https://s3.amazonaws.com/amdegroot-models/vgg16_reducedfc.pth
"vgg16_features": "https://download.pytorch.org/models/vgg16_features-amdegroot.pth"
# Only the `features` weights have proper values, those on the `classifier` module are filled with nans.
"vgg16_features": "https://download.pytorch.org/models/vgg16_features-amdegroot-88682ab5.pth"
Contributor Author: The previous weights only contained the values of vgg.features; the new ones contain the values of the entire vgg model. Though the size increases, this allows us to load the weights with the standard mechanism instead of relying on hacks (this will become possible in the multi-pretrained-weights PR).

}


@@ -519,18 +520,8 @@ def forward(self, x: Tensor) -> Dict[str, Tensor]:
return OrderedDict([(str(i), v) for i, v in enumerate(output)])


def _vgg_extractor(backbone_name: str, highres: bool, progress: bool, pretrained: bool, trainable_layers: int):
if backbone_name in backbone_urls:
# Use custom backbones more appropriate for SSD
arch = backbone_name.split("_")[0]
backbone = vgg.__dict__[arch](pretrained=False, progress=progress).features
if pretrained:
state_dict = load_state_dict_from_url(backbone_urls[backbone_name], progress=progress)
backbone.load_state_dict(state_dict)
else:
# Use standard backbones from TorchVision
backbone = vgg.__dict__[backbone_name](pretrained=pretrained, progress=progress).features

Contributor Author: This complex logic can go away.

def _vgg_extractor(backbone: vgg.VGG, highres: bool, trainable_layers: int):
backbone = backbone.features
# Gather the indices of maxpools. These are the locations of output blocks.
stage_indices = [0] + [i for i, b in enumerate(backbone) if isinstance(b, nn.MaxPool2d)][:-1]
num_stages = len(stage_indices)
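
For comparison with the MobileNet case, a hedged sketch of the stage boundaries the VGG extractor derives from its maxpool layers (illustrative only, computed on a randomly initialized vgg16):

```python
from torch import nn
from torchvision.models import vgg16

features = vgg16(pretrained=False).features

# Every MaxPool2d closes a stage; the last one is dropped so the final conv
# block stays attached to its stage.
stage_indices = [0] + [i for i, b in enumerate(features) if isinstance(b, nn.MaxPool2d)][:-1]
num_stages = len(stage_indices)
print(stage_indices, num_stages)  # [0, 4, 9, 16, 23], 5 for vgg16
```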
@@ -609,7 +600,13 @@ def ssd300_vgg16(
# no need to download the backbone if pretrained is set
pretrained_backbone = False

backbone = _vgg_extractor("vgg16_features", False, progress, pretrained_backbone, trainable_backbone_layers)
# Use custom backbones more appropriate for SSD
backbone = vgg.vgg16(pretrained=False, progress=progress)
if pretrained_backbone:
state_dict = load_state_dict_from_url(backbone_urls["vgg16_features"], progress=progress)
backbone.load_state_dict(state_dict)
Contributor Author: By initializing the backbone in the builder, the complexity is reduced. Moreover, the multi-pretrained-weights project will allow us to remove the if statement and pass the exact weights we want to load directly to vgg16.


backbone = _vgg_extractor(backbone, False, trainable_backbone_layers)
anchor_generator = DefaultBoxGenerator(
[[2], [2, 3], [2, 3], [2, 3], [2], [2]],
scales=[0.07, 0.15, 0.33, 0.51, 0.69, 0.87, 1.05],
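
A hedged sketch of the simplified flow described in the comments above (the URL comes from this diff; downloading the checkpoint is only needed when a pretrained backbone is requested):

```python
import torch
from torch.hub import load_state_dict_from_url
from torchvision.models import vgg

# Full VGG16, randomly initialized; the new checkpoint covers the whole model,
# so the standard load_state_dict path works without slicing hacks.
backbone = vgg.vgg16(pretrained=False)
state_dict = load_state_dict_from_url(
    "https://download.pytorch.org/models/vgg16_features-amdegroot-88682ab5.pth"
)
backbone.load_state_dict(state_dict)

# SSD only uses the convolutional trunk; per the comment above, only `features`
# carries meaningful values (the classifier entries are nan placeholders).
features = backbone.features
print(any(torch.isnan(p).any() for p in features.parameters()))  # expected: False
```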
30 changes: 11 additions & 19 deletions torchvision/models/detection/ssdlite.py
@@ -1,7 +1,7 @@
import warnings
from collections import OrderedDict
from functools import partial
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Callable, Dict, List, Optional, Union

import torch
from torch import nn, Tensor
@@ -117,7 +117,6 @@ def __init__(
norm_layer: Callable[..., nn.Module],
width_mult: float = 1.0,
min_depth: int = 16,
**kwargs: Any,
):
super().__init__()

@@ -156,20 +155,11 @@ def forward(self, x: Tensor) -> Dict[str, Tensor]:


def _mobilenet_extractor(
backbone_name: str,
progress: bool,
pretrained: bool,
backbone: Union[mobilenet.MobileNetV2, mobilenet.MobileNetV3],
trainable_layers: int,
norm_layer: Callable[..., nn.Module],
**kwargs: Any,
):
backbone = mobilenet.__dict__[backbone_name](
pretrained=pretrained, progress=progress, norm_layer=norm_layer, **kwargs
).features
if not pretrained:
# Change the default initialization scheme if not pretrained
_normal_init(backbone)

backbone = backbone.features
# Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
# The first and last blocks are always included because they are the C0 (conv1) and Cn.
stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + [len(backbone) - 1]
@@ -183,7 +173,7 @@ def _mobilenet_extractor(
for parameter in b.parameters():
parameter.requires_grad_(False)

return SSDLiteFeatureExtractorMobileNet(backbone, stage_indices[-2], norm_layer, **kwargs)
Contributor Author: Safe. No kwargs were used.

return SSDLiteFeatureExtractorMobileNet(backbone, stage_indices[-2], norm_layer)


def ssdlite320_mobilenet_v3_large(
Expand Down Expand Up @@ -235,14 +225,16 @@ def ssdlite320_mobilenet_v3_large(
if norm_layer is None:
norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.03)

backbone = mobilenet.mobilenet_v3_large(
pretrained=pretrained_backbone, progress=progress, norm_layer=norm_layer, reduced_tail=reduce_tail, **kwargs
Contributor Author: The reduced_tail and kwargs were supposed to be passed to the mobilenet_v3_large builder.

)
if not pretrained_backbone:
# Change the default initialization scheme if not pretrained
_normal_init(backbone)
backbone = _mobilenet_extractor(
"mobilenet_v3_large",
progress,
pretrained_backbone,
backbone,
trainable_backbone_layers,
norm_layer,
reduced_tail=reduce_tail,
**kwargs,
)

size = (320, 320)
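
A hedged sketch of the corrected flow for SSDlite (private helpers imported for illustration only; the trainable-layer count is an assumption matching the builder's default):

```python
from functools import partial

from torch import nn
from torchvision.models import mobilenet
from torchvision.models.detection.ssdlite import _mobilenet_extractor, _normal_init

norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.03)

# reduced_tail (and any extra kwargs) now reach the MobileNetV3 builder itself,
# as they were always meant to.
backbone = mobilenet.mobilenet_v3_large(pretrained=False, norm_layer=norm_layer, reduced_tail=False)

# With no pretrained weights, the detection builder re-initializes the backbone.
_normal_init(backbone)

# The extractor now only trims, splits and freezes the already-built backbone.
backbone = _mobilenet_extractor(backbone, trainable_layers=6, norm_layer=norm_layer)
```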
20 changes: 0 additions & 20 deletions torchvision/prototype/models/detection/backbone_utils.py

This file was deleted.

14 changes: 10 additions & 4 deletions torchvision/prototype/models/detection/faster_rcnn.py
@@ -1,12 +1,17 @@
import warnings
from typing import Any, Optional

from ....models.detection.faster_rcnn import FasterRCNN, overwrite_eps, _validate_trainable_layers
from ....models.detection.faster_rcnn import (
_validate_trainable_layers,
_resnet_fpn_extractor,
FasterRCNN,
misc_nn_ops,
overwrite_eps,
)
from ...transforms.presets import CocoEval
from .._api import Weights, WeightEntry
from .._meta import _COCO_CATEGORIES
from ..resnet import ResNet50Weights
from .backbone_utils import resnet_fpn_backbone
from ..resnet import ResNet50Weights, resnet50


__all__ = ["FasterRCNN", "FasterRCNNResNet50FPNWeights", "fasterrcnn_resnet50_fpn"]
@@ -49,7 +54,8 @@ def fasterrcnn_resnet50_fpn(
weights is not None or weights_backbone is not None, trainable_backbone_layers, 5, 3
)

backbone = resnet_fpn_backbone("resnet50", weights_backbone, trainable_layers=trainable_backbone_layers)
backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d)
backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
model = FasterRCNN(backbone, num_classes=num_classes, **kwargs)

if weights is not None:
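
For the prototype path, a rough usage sketch under stated assumptions: the builder and class names below appear in this diff, but the exact weight-enum entries and default arguments are not shown here, so everything is passed explicitly and left at `None`:

```python
from torchvision.prototype.models.detection.faster_rcnn import (
    FasterRCNNResNet50FPNWeights,  # for `weights=` once a specific entry is chosen
    fasterrcnn_resnet50_fpn,
)
from torchvision.prototype.models.resnet import ResNet50Weights  # for `weights_backbone=`

# Randomly initialized model; substituting enum entries for None would load
# the corresponding detector / backbone weights.
model = fasterrcnn_resnet50_fpn(
    weights=None,
    weights_backbone=None,
    num_classes=91,
    trainable_backbone_layers=3,
)
```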