-
Notifications
You must be signed in to change notification settings - Fork 7.1k
Refactor the backbone builders of detection #4656
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
ef1c220
ed319d7
9ff8df5
33f6db1
893ff41
6d60c5f
ab71f1d
1526680
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,5 @@ | ||
import warnings | ||
from typing import Callable, Dict, Optional, List | ||
from typing import Callable, Dict, Optional, List, Union | ||
|
||
from torch import nn, Tensor | ||
from torchvision.ops import misc as misc_nn_ops | ||
|
@@ -100,14 +100,14 @@ def resnet_fpn_backbone( | |
default a ``LastLevelMaxPool`` is used. | ||
""" | ||
backbone = resnet.__dict__[backbone_name](pretrained=pretrained, norm_layer=norm_layer) | ||
return _resnet_backbone_config(backbone, trainable_layers, returned_layers, extra_blocks) | ||
return _resnet_fpn_extractor(backbone, trainable_layers, returned_layers, extra_blocks) | ||
|
||
|
||
def _resnet_backbone_config( | ||
def _resnet_fpn_extractor( | ||
backbone: resnet.ResNet, | ||
trainable_layers: int, | ||
returned_layers: Optional[List[int]], | ||
extra_blocks: Optional[ExtraFPNBlock], | ||
returned_layers: Optional[List[int]] = None, | ||
extra_blocks: Optional[ExtraFPNBlock] = None, | ||
) -> BackboneWithFPN: | ||
|
||
# select layers that won't be frozen | ||
|
@@ -165,9 +165,18 @@ def mobilenet_backbone( | |
returned_layers: Optional[List[int]] = None, | ||
extra_blocks: Optional[ExtraFPNBlock] = None, | ||
) -> nn.Module: | ||
backbone = mobilenet.__dict__[backbone_name](pretrained=pretrained, norm_layer=norm_layer) | ||
return _mobilenet_extractor(backbone, fpn, trainable_layers, returned_layers, extra_blocks) | ||
|
||
backbone = mobilenet.__dict__[backbone_name](pretrained=pretrained, norm_layer=norm_layer).features | ||
|
||
def _mobilenet_extractor( | ||
backbone: Union[mobilenet.MobileNetV2, mobilenet.MobileNetV3], | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Pass a full, pre-initialized MobileNetV* backbone instead of a `backbone_name`. |
||
fpn: bool, | ||
trainable_layers, | ||
returned_layers: Optional[List[int]] = None, | ||
extra_blocks: Optional[ExtraFPNBlock] = None, | ||
) -> nn.Module: | ||
backbone = backbone.features | ||
# Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks. | ||
# The first and last blocks are always included because they are the C0 (conv1) and Cn. | ||
stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + [len(backbone) - 1] | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,9 +3,12 @@ | |
from torchvision.ops import MultiScaleRoIAlign | ||
|
||
from ..._internally_replaced_utils import load_state_dict_from_url | ||
from ...ops import misc as misc_nn_ops | ||
from ..mobilenetv3 import mobilenet_v3_large | ||
from ..resnet import resnet50 | ||
from ._utils import overwrite_eps | ||
from .anchor_utils import AnchorGenerator | ||
from .backbone_utils import resnet_fpn_backbone, _validate_trainable_layers, mobilenet_backbone | ||
from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers, _mobilenet_extractor | ||
from .generalized_rcnn import GeneralizedRCNN | ||
from .roi_heads import RoIHeads | ||
from .rpn import RPNHead, RegionProposalNetwork | ||
|
@@ -385,7 +388,9 @@ def fasterrcnn_resnet50_fpn( | |
if pretrained: | ||
# no need to download the backbone if pretrained is set | ||
pretrained_backbone = False | ||
backbone = resnet_fpn_backbone("resnet50", pretrained_backbone, trainable_layers=trainable_backbone_layers) | ||
|
||
backbone = resnet50(pretrained=pretrained_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The pretrained backbone is always initialized beforehand in the builder. |
||
backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The initialized model is then passed to the `*_extractor` method, which builds the proper backbone. |
||
model = FasterRCNN(backbone, num_classes, **kwargs) | ||
if pretrained: | ||
state_dict = load_state_dict_from_url(model_urls["fasterrcnn_resnet50_fpn_coco"], progress=progress) | ||
|
@@ -409,9 +414,11 @@ def _fasterrcnn_mobilenet_v3_large_fpn( | |
|
||
if pretrained: | ||
pretrained_backbone = False | ||
backbone = mobilenet_backbone( | ||
"mobilenet_v3_large", pretrained_backbone, True, trainable_layers=trainable_backbone_layers | ||
|
||
backbone = mobilenet_v3_large( | ||
pretrained=pretrained_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d | ||
) | ||
backbone = _mobilenet_extractor(backbone, True, trainable_backbone_layers) | ||
|
||
anchor_sizes = ( | ||
( | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,11 +9,13 @@ | |
from ..._internally_replaced_utils import load_state_dict_from_url | ||
from ...ops import sigmoid_focal_loss | ||
from ...ops import boxes as box_ops | ||
from ...ops import misc as misc_nn_ops | ||
from ...ops.feature_pyramid_network import LastLevelP6P7 | ||
from ..resnet import resnet50 | ||
from . import _utils as det_utils | ||
from ._utils import overwrite_eps | ||
from .anchor_utils import AnchorGenerator | ||
from .backbone_utils import resnet_fpn_backbone, _validate_trainable_layers | ||
from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers | ||
from .transform import GeneralizedRCNNTransform | ||
|
||
|
||
|
@@ -630,13 +632,11 @@ def retinanet_resnet50_fpn( | |
if pretrained: | ||
# no need to download the backbone if pretrained is set | ||
pretrained_backbone = False | ||
|
||
backbone = resnet50(pretrained=pretrained_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d) | ||
# skip P2 because it generates too many anchors (according to their paper) | ||
backbone = resnet_fpn_backbone( | ||
"resnet50", | ||
pretrained_backbone, | ||
returned_layers=[2, 3, 4], | ||
extra_blocks=LastLevelP6P7(256, 256), | ||
trainable_layers=trainable_backbone_layers, | ||
backbone = _resnet_fpn_extractor( | ||
backbone, trainable_backbone_layers, returned_layers=[2, 3, 4], extra_blocks=LastLevelP6P7(256, 256) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As commented earlier, because the initialization of the |
||
) | ||
model = RetinaNet(backbone, num_classes, **kwargs) | ||
if pretrained: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -23,7 +23,8 @@ | |
backbone_urls = { | ||
# We port the features of a VGG16 backbone trained by amdegroot because unlike the one on TorchVision, it uses the | ||
# same input standardization method as the paper. Ref: https://s3.amazonaws.com/amdegroot-models/vgg16_reducedfc.pth | ||
"vgg16_features": "https://download.pytorch.org/models/vgg16_features-amdegroot.pth" | ||
# Only the `features` weights have proper values; those on the `classifier` module are filled with NaNs. | ||
"vgg16_features": "https://download.pytorch.org/models/vgg16_features-amdegroot-88682ab5.pth" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The previous weights only contained the values of |
||
} | ||
|
||
|
||
|
@@ -519,18 +520,8 @@ def forward(self, x: Tensor) -> Dict[str, Tensor]: | |
return OrderedDict([(str(i), v) for i, v in enumerate(output)]) | ||
|
||
|
||
def _vgg_extractor(backbone_name: str, highres: bool, progress: bool, pretrained: bool, trainable_layers: int): | ||
if backbone_name in backbone_urls: | ||
# Use custom backbones more appropriate for SSD | ||
arch = backbone_name.split("_")[0] | ||
backbone = vgg.__dict__[arch](pretrained=False, progress=progress).features | ||
if pretrained: | ||
state_dict = load_state_dict_from_url(backbone_urls[backbone_name], progress=progress) | ||
backbone.load_state_dict(state_dict) | ||
else: | ||
# Use standard backbones from TorchVision | ||
backbone = vgg.__dict__[backbone_name](pretrained=pretrained, progress=progress).features | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This complex logic can go away. |
||
def _vgg_extractor(backbone: vgg.VGG, highres: bool, trainable_layers: int): | ||
backbone = backbone.features | ||
# Gather the indices of maxpools. These are the locations of output blocks. | ||
stage_indices = [0] + [i for i, b in enumerate(backbone) if isinstance(b, nn.MaxPool2d)][:-1] | ||
num_stages = len(stage_indices) | ||
|
@@ -609,7 +600,13 @@ def ssd300_vgg16( | |
# no need to download the backbone if pretrained is set | ||
pretrained_backbone = False | ||
|
||
backbone = _vgg_extractor("vgg16_features", False, progress, pretrained_backbone, trainable_backbone_layers) | ||
# Use custom backbones more appropriate for SSD | ||
backbone = vgg.vgg16(pretrained=False, progress=progress) | ||
if pretrained_backbone: | ||
state_dict = load_state_dict_from_url(backbone_urls["vgg16_features"], progress=progress) | ||
backbone.load_state_dict(state_dict) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Initializing the backbone in the builder reduces the complexity. Moreover, the multi-pretrained weights project will allow us to remove the if statement and pass the exact weights we want to load directly to vgg16. |
||
|
||
backbone = _vgg_extractor(backbone, False, trainable_backbone_layers) | ||
anchor_generator = DefaultBoxGenerator( | ||
[[2], [2, 3], [2, 3], [2, 3], [2], [2]], | ||
scales=[0.07, 0.15, 0.33, 0.51, 0.69, 0.87, 1.05], | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,7 @@ | ||
import warnings | ||
from collections import OrderedDict | ||
from functools import partial | ||
from typing import Any, Callable, Dict, List, Optional | ||
from typing import Any, Callable, Dict, List, Optional, Union | ||
|
||
import torch | ||
from torch import nn, Tensor | ||
|
@@ -117,7 +117,6 @@ def __init__( | |
norm_layer: Callable[..., nn.Module], | ||
width_mult: float = 1.0, | ||
min_depth: int = 16, | ||
**kwargs: Any, | ||
): | ||
super().__init__() | ||
|
||
|
@@ -156,20 +155,11 @@ def forward(self, x: Tensor) -> Dict[str, Tensor]: | |
|
||
|
||
def _mobilenet_extractor( | ||
backbone_name: str, | ||
progress: bool, | ||
pretrained: bool, | ||
backbone: Union[mobilenet.MobileNetV2, mobilenet.MobileNetV3], | ||
trainable_layers: int, | ||
norm_layer: Callable[..., nn.Module], | ||
**kwargs: Any, | ||
): | ||
backbone = mobilenet.__dict__[backbone_name]( | ||
pretrained=pretrained, progress=progress, norm_layer=norm_layer, **kwargs | ||
).features | ||
if not pretrained: | ||
# Change the default initialization scheme if not pretrained | ||
_normal_init(backbone) | ||
|
||
backbone = backbone.features | ||
# Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks. | ||
# The first and last blocks are always included because they are the C0 (conv1) and Cn. | ||
stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + [len(backbone) - 1] | ||
|
@@ -183,7 +173,7 @@ def _mobilenet_extractor( | |
for parameter in b.parameters(): | ||
parameter.requires_grad_(False) | ||
|
||
return SSDLiteFeatureExtractorMobileNet(backbone, stage_indices[-2], norm_layer, **kwargs) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Safe. No kwargs were used. |
||
return SSDLiteFeatureExtractorMobileNet(backbone, stage_indices[-2], norm_layer) | ||
|
||
|
||
def ssdlite320_mobilenet_v3_large( | ||
|
@@ -235,14 +225,16 @@ def ssdlite320_mobilenet_v3_large( | |
if norm_layer is None: | ||
norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.03) | ||
|
||
backbone = mobilenet.mobilenet_v3_large( | ||
pretrained=pretrained_backbone, progress=progress, norm_layer=norm_layer, reduced_tail=reduce_tail, **kwargs | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The |
||
) | ||
if not pretrained_backbone: | ||
# Change the default initialization scheme if not pretrained | ||
_normal_init(backbone) | ||
backbone = _mobilenet_extractor( | ||
"mobilenet_v3_large", | ||
progress, | ||
pretrained_backbone, | ||
backbone, | ||
trainable_backbone_layers, | ||
norm_layer, | ||
reduced_tail=reduce_tail, | ||
**kwargs, | ||
) | ||
|
||
size = (320, 320) | ||
|
This file was deleted.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Rename for consistency.