diff --git a/test/expect/ModelTester.test_retinanet_resnet50_fpn_expect.pkl b/test/expect/ModelTester.test_retinanet_resnet50_fpn_expect.pkl
index b8fd97cbfe1..7fb8d66b080 100644
Binary files a/test/expect/ModelTester.test_retinanet_resnet50_fpn_expect.pkl and b/test/expect/ModelTester.test_retinanet_resnet50_fpn_expect.pkl differ
diff --git a/torchvision/models/detection/backbone_utils.py b/torchvision/models/detection/backbone_utils.py
index e697e06929a..1c2bceacda0 100644
--- a/torchvision/models/detection/backbone_utils.py
+++ b/torchvision/models/detection/backbone_utils.py
@@ -1,5 +1,5 @@
 import warnings
-from typing import Callable, Dict, Optional, List
+from typing import Callable, Dict, Optional, List, Union
 
 from torch import nn, Tensor
 from torchvision.ops import misc as misc_nn_ops
@@ -100,14 +100,14 @@ def resnet_fpn_backbone(
             default a ``LastLevelMaxPool`` is used.
     """
     backbone = resnet.__dict__[backbone_name](pretrained=pretrained, norm_layer=norm_layer)
-    return _resnet_backbone_config(backbone, trainable_layers, returned_layers, extra_blocks)
+    return _resnet_fpn_extractor(backbone, trainable_layers, returned_layers, extra_blocks)
 
 
-def _resnet_backbone_config(
+def _resnet_fpn_extractor(
     backbone: resnet.ResNet,
     trainable_layers: int,
-    returned_layers: Optional[List[int]],
-    extra_blocks: Optional[ExtraFPNBlock],
+    returned_layers: Optional[List[int]] = None,
+    extra_blocks: Optional[ExtraFPNBlock] = None,
 ) -> BackboneWithFPN:
 
     # select layers that wont be frozen
@@ -165,9 +165,18 @@ def mobilenet_backbone(
     returned_layers: Optional[List[int]] = None,
     extra_blocks: Optional[ExtraFPNBlock] = None,
 ) -> nn.Module:
+    backbone = mobilenet.__dict__[backbone_name](pretrained=pretrained, norm_layer=norm_layer)
+    return _mobilenet_extractor(backbone, fpn, trainable_layers, returned_layers, extra_blocks)
 
-    backbone = mobilenet.__dict__[backbone_name](pretrained=pretrained, norm_layer=norm_layer).features
 
+def _mobilenet_extractor(
+    backbone: Union[mobilenet.MobileNetV2, mobilenet.MobileNetV3],
+    fpn: bool,
+    trainable_layers,
+    returned_layers: Optional[List[int]] = None,
+    extra_blocks: Optional[ExtraFPNBlock] = None,
+) -> nn.Module:
+    backbone = backbone.features
     # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
     # The first and last blocks are always included because they are the C0 (conv1) and Cn.
     stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + [len(backbone) - 1]
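As a calling-convention note for the two helpers above: a minimal sketch (illustrative only, not part of the patch; `_resnet_fpn_extractor` and `_mobilenet_extractor` are private helpers whose signatures may change) of how an already built trunk is handed to them after this change:

    from torchvision.models.detection.backbone_utils import _mobilenet_extractor, _resnet_fpn_extractor
    from torchvision.models.mobilenetv3 import mobilenet_v3_large
    from torchvision.models.resnet import resnet50
    from torchvision.ops import misc as misc_nn_ops

    # Build the trunk first; weights and norm_layer are now the caller's responsibility.
    resnet_trunk = resnet50(pretrained=False, norm_layer=misc_nn_ops.FrozenBatchNorm2d)
    # Wrap it with the FPN extractor, leaving the 3 deepest stages trainable.
    fpn_backbone = _resnet_fpn_extractor(resnet_trunk, trainable_layers=3)

    # The MobileNet path is analogous; the second argument selects the FPN variant.
    mnet_trunk = mobilenet_v3_large(pretrained=False, norm_layer=misc_nn_ops.FrozenBatchNorm2d)
    mnet_backbone = _mobilenet_extractor(mnet_trunk, True, 3)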
diff --git a/torchvision/models/detection/faster_rcnn.py b/torchvision/models/detection/faster_rcnn.py
index 02da39e8c73..30a22529a00 100644
--- a/torchvision/models/detection/faster_rcnn.py
+++ b/torchvision/models/detection/faster_rcnn.py
@@ -3,9 +3,12 @@
 from torchvision.ops import MultiScaleRoIAlign
 
 from ..._internally_replaced_utils import load_state_dict_from_url
+from ...ops import misc as misc_nn_ops
+from ..mobilenetv3 import mobilenet_v3_large
+from ..resnet import resnet50
 from ._utils import overwrite_eps
 from .anchor_utils import AnchorGenerator
-from .backbone_utils import resnet_fpn_backbone, _validate_trainable_layers, mobilenet_backbone
+from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers, _mobilenet_extractor
 from .generalized_rcnn import GeneralizedRCNN
 from .roi_heads import RoIHeads
 from .rpn import RPNHead, RegionProposalNetwork
@@ -385,7 +388,9 @@ def fasterrcnn_resnet50_fpn(
     if pretrained:
         # no need to download the backbone if pretrained is set
         pretrained_backbone = False
-    backbone = resnet_fpn_backbone("resnet50", pretrained_backbone, trainable_layers=trainable_backbone_layers)
+
+    backbone = resnet50(pretrained=pretrained_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d)
+    backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
     model = FasterRCNN(backbone, num_classes, **kwargs)
     if pretrained:
         state_dict = load_state_dict_from_url(model_urls["fasterrcnn_resnet50_fpn_coco"], progress=progress)
@@ -409,9 +414,11 @@ def _fasterrcnn_mobilenet_v3_large_fpn(
 
     if pretrained:
         pretrained_backbone = False
-    backbone = mobilenet_backbone(
-        "mobilenet_v3_large", pretrained_backbone, True, trainable_layers=trainable_backbone_layers
+
+    backbone = mobilenet_v3_large(
+        pretrained=pretrained_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d
     )
+    backbone = _mobilenet_extractor(backbone, True, trainable_backbone_layers)
 
     anchor_sizes = (
         (
diff --git a/torchvision/models/detection/keypoint_rcnn.py b/torchvision/models/detection/keypoint_rcnn.py
index 7cd975ea6a0..d16d1600c77 100644
--- a/torchvision/models/detection/keypoint_rcnn.py
+++ b/torchvision/models/detection/keypoint_rcnn.py
@@ -3,8 +3,10 @@
 from torchvision.ops import MultiScaleRoIAlign
 
 from ..._internally_replaced_utils import load_state_dict_from_url
+from ...ops import misc as misc_nn_ops
+from ..resnet import resnet50
 from ._utils import overwrite_eps
-from .backbone_utils import resnet_fpn_backbone, _validate_trainable_layers
+from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
 from .faster_rcnn import FasterRCNN
 
 
@@ -367,7 +369,9 @@ def keypointrcnn_resnet50_fpn(
     if pretrained:
         # no need to download the backbone if pretrained is set
         pretrained_backbone = False
-    backbone = resnet_fpn_backbone("resnet50", pretrained_backbone, trainable_layers=trainable_backbone_layers)
+
+    backbone = resnet50(pretrained=pretrained_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d)
+    backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
     model = KeypointRCNN(backbone, num_classes, num_keypoints=num_keypoints, **kwargs)
     if pretrained:
         key = "keypointrcnn_resnet50_fpn_coco"
diff --git a/torchvision/models/detection/mask_rcnn.py b/torchvision/models/detection/mask_rcnn.py
index 6b8208b19d8..ca2cbc114f0 100644
--- a/torchvision/models/detection/mask_rcnn.py
+++ b/torchvision/models/detection/mask_rcnn.py
@@ -4,8 +4,10 @@
 from torchvision.ops import MultiScaleRoIAlign
 
 from ..._internally_replaced_utils import load_state_dict_from_url
+from ...ops import misc as misc_nn_ops
+from ..resnet import resnet50
 from ._utils import overwrite_eps
-from .backbone_utils import resnet_fpn_backbone, _validate_trainable_layers
+from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
 from .faster_rcnn import FasterRCNN
 
 __all__ = [
@@ -364,7 +366,9 @@ def maskrcnn_resnet50_fpn(
     if pretrained:
         # no need to download the backbone if pretrained is set
         pretrained_backbone = False
-    backbone = resnet_fpn_backbone("resnet50", pretrained_backbone, trainable_layers=trainable_backbone_layers)
+
+    backbone = resnet50(pretrained=pretrained_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d)
+    backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
     model = MaskRCNN(backbone, num_classes, **kwargs)
     if pretrained:
         state_dict = load_state_dict_from_url(model_urls["maskrcnn_resnet50_fpn_coco"], progress=progress)
diff --git a/torchvision/models/detection/retinanet.py b/torchvision/models/detection/retinanet.py
index eb05144cb0c..4b16d7edc7f 100644
--- a/torchvision/models/detection/retinanet.py
+++ b/torchvision/models/detection/retinanet.py
@@ -9,11 +9,13 @@
 from ..._internally_replaced_utils import load_state_dict_from_url
 from ...ops import sigmoid_focal_loss
 from ...ops import boxes as box_ops
+from ...ops import misc as misc_nn_ops
 from ...ops.feature_pyramid_network import LastLevelP6P7
+from ..resnet import resnet50
 from . import _utils as det_utils
 from ._utils import overwrite_eps
 from .anchor_utils import AnchorGenerator
-from .backbone_utils import resnet_fpn_backbone, _validate_trainable_layers
+from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
 from .transform import GeneralizedRCNNTransform
 
 
@@ -630,13 +632,11 @@ def retinanet_resnet50_fpn(
     if pretrained:
         # no need to download the backbone if pretrained is set
         pretrained_backbone = False
+
+    backbone = resnet50(pretrained=pretrained_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d)
     # skip P2 because it generates too many anchors (according to their paper)
-    backbone = resnet_fpn_backbone(
-        "resnet50",
-        pretrained_backbone,
-        returned_layers=[2, 3, 4],
-        extra_blocks=LastLevelP6P7(256, 256),
-        trainable_layers=trainable_backbone_layers,
+    backbone = _resnet_fpn_extractor(
+        backbone, trainable_backbone_layers, returned_layers=[2, 3, 4], extra_blocks=LastLevelP6P7(256, 256)
     )
     model = RetinaNet(backbone, num_classes, **kwargs)
     if pretrained:
diff --git a/torchvision/models/detection/ssd.py b/torchvision/models/detection/ssd.py
index c7d74b4e1af..5a068a0f0cc 100644
--- a/torchvision/models/detection/ssd.py
+++ b/torchvision/models/detection/ssd.py
@@ -23,7 +23,8 @@
 backbone_urls = {
     # We port the features of a VGG16 backbone trained by amdegroot because unlike the one on TorchVision, it uses the
     # same input standardization method as the paper. Ref: https://s3.amazonaws.com/amdegroot-models/vgg16_reducedfc.pth
-    "vgg16_features": "https://download.pytorch.org/models/vgg16_features-amdegroot.pth"
+    # Only the `features` weights have proper values, those on the `classifier` module are filled with nans.
+    "vgg16_features": "https://download.pytorch.org/models/vgg16_features-amdegroot-88682ab5.pth"
 }
 
 
@@ -519,18 +520,8 @@ def forward(self, x: Tensor) -> Dict[str, Tensor]:
         return OrderedDict([(str(i), v) for i, v in enumerate(output)])
 
 
-def _vgg_extractor(backbone_name: str, highres: bool, progress: bool, pretrained: bool, trainable_layers: int):
-    if backbone_name in backbone_urls:
-        # Use custom backbones more appropriate for SSD
-        arch = backbone_name.split("_")[0]
-        backbone = vgg.__dict__[arch](pretrained=False, progress=progress).features
-        if pretrained:
-            state_dict = load_state_dict_from_url(backbone_urls[backbone_name], progress=progress)
-            backbone.load_state_dict(state_dict)
-    else:
-        # Use standard backbones from TorchVision
-        backbone = vgg.__dict__[backbone_name](pretrained=pretrained, progress=progress).features
-
+def _vgg_extractor(backbone: vgg.VGG, highres: bool, trainable_layers: int):
+    backbone = backbone.features
     # Gather the indices of maxpools. These are the locations of output blocks.
     stage_indices = [0] + [i for i, b in enumerate(backbone) if isinstance(b, nn.MaxPool2d)][:-1]
     num_stages = len(stage_indices)
@@ -609,7 +600,13 @@ def ssd300_vgg16(
         # no need to download the backbone if pretrained is set
         pretrained_backbone = False
 
-    backbone = _vgg_extractor("vgg16_features", False, progress, pretrained_backbone, trainable_backbone_layers)
+    # Use custom backbones more appropriate for SSD
+    backbone = vgg.vgg16(pretrained=False, progress=progress)
+    if pretrained_backbone:
+        state_dict = load_state_dict_from_url(backbone_urls["vgg16_features"], progress=progress)
+        backbone.load_state_dict(state_dict)
+
+    backbone = _vgg_extractor(backbone, False, trainable_backbone_layers)
     anchor_generator = DefaultBoxGenerator(
         [[2], [2, 3], [2, 3], [2, 3], [2], [2]],
         scales=[0.07, 0.15, 0.33, 0.51, 0.69, 0.87, 1.05],
diff --git a/torchvision/models/detection/ssdlite.py b/torchvision/models/detection/ssdlite.py
index 503b69d7380..e32aebcf839 100644
--- a/torchvision/models/detection/ssdlite.py
+++ b/torchvision/models/detection/ssdlite.py
@@ -1,7 +1,7 @@
 import warnings
 from collections import OrderedDict
 from functools import partial
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional, Union
 
 import torch
 from torch import nn, Tensor
@@ -117,7 +117,6 @@ def __init__(
         norm_layer: Callable[..., nn.Module],
         width_mult: float = 1.0,
         min_depth: int = 16,
-        **kwargs: Any,
     ):
         super().__init__()
 
@@ -156,20 +155,11 @@ def forward(self, x: Tensor) -> Dict[str, Tensor]:
 
 
 def _mobilenet_extractor(
-    backbone_name: str,
-    progress: bool,
-    pretrained: bool,
+    backbone: Union[mobilenet.MobileNetV2, mobilenet.MobileNetV3],
     trainable_layers: int,
     norm_layer: Callable[..., nn.Module],
-    **kwargs: Any,
 ):
-    backbone = mobilenet.__dict__[backbone_name](
-        pretrained=pretrained, progress=progress, norm_layer=norm_layer, **kwargs
-    ).features
-    if not pretrained:
-        # Change the default initialization scheme if not pretrained
-        _normal_init(backbone)
-
+    backbone = backbone.features
     # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
     # The first and last blocks are always included because they are the C0 (conv1) and Cn.
     stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + [len(backbone) - 1]
@@ -183,7 +173,7 @@ def _mobilenet_extractor(
         for parameter in b.parameters():
             parameter.requires_grad_(False)
 
-    return SSDLiteFeatureExtractorMobileNet(backbone, stage_indices[-2], norm_layer, **kwargs)
+    return SSDLiteFeatureExtractorMobileNet(backbone, stage_indices[-2], norm_layer)
 
 
 def ssdlite320_mobilenet_v3_large(
@@ -235,14 +225,16 @@ def ssdlite320_mobilenet_v3_large(
     if norm_layer is None:
         norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.03)
 
+    backbone = mobilenet.mobilenet_v3_large(
+        pretrained=pretrained_backbone, progress=progress, norm_layer=norm_layer, reduced_tail=reduce_tail, **kwargs
+    )
+    if not pretrained_backbone:
+        # Change the default initialization scheme if not pretrained
+        _normal_init(backbone)
     backbone = _mobilenet_extractor(
-        "mobilenet_v3_large",
-        progress,
-        pretrained_backbone,
+        backbone,
         trainable_backbone_layers,
         norm_layer,
-        reduced_tail=reduce_tail,
-        **kwargs,
     )
 
     size = (320, 320)
diff --git a/torchvision/prototype/models/detection/backbone_utils.py b/torchvision/prototype/models/detection/backbone_utils.py
deleted file mode 100644
index d95b1ab52f0..00000000000
--- a/torchvision/prototype/models/detection/backbone_utils.py
+++ /dev/null
@@ -1,20 +0,0 @@
-from typing import Callable, Optional, List
-
-from torch import nn
-
-from ....models.detection.backbone_utils import misc_nn_ops, _resnet_backbone_config, BackboneWithFPN, ExtraFPNBlock
-from .. import resnet
-from .._api import Weights
-
-
-def resnet_fpn_backbone(
-    backbone_name: str,
-    weights: Optional[Weights],
-    norm_layer: Callable[..., nn.Module] = misc_nn_ops.FrozenBatchNorm2d,
-    trainable_layers: int = 3,
-    returned_layers: Optional[List[int]] = None,
-    extra_blocks: Optional[ExtraFPNBlock] = None,
-) -> BackboneWithFPN:
-
-    backbone = resnet.__dict__[backbone_name](weights=weights, norm_layer=norm_layer)
-    return _resnet_backbone_config(backbone, trainable_layers, returned_layers, extra_blocks)
diff --git a/torchvision/prototype/models/detection/faster_rcnn.py b/torchvision/prototype/models/detection/faster_rcnn.py
index 0b27eb50a37..bb3817c6b45 100644
--- a/torchvision/prototype/models/detection/faster_rcnn.py
+++ b/torchvision/prototype/models/detection/faster_rcnn.py
@@ -1,12 +1,17 @@
 import warnings
 from typing import Any, Optional
 
-from ....models.detection.faster_rcnn import FasterRCNN, overwrite_eps, _validate_trainable_layers
+from ....models.detection.faster_rcnn import (
+    _validate_trainable_layers,
+    _resnet_fpn_extractor,
+    FasterRCNN,
+    misc_nn_ops,
+    overwrite_eps,
+)
 from ...transforms.presets import CocoEval
 from .._api import Weights, WeightEntry
 from .._meta import _COCO_CATEGORIES
-from ..resnet import ResNet50Weights
-from .backbone_utils import resnet_fpn_backbone
+from ..resnet import ResNet50Weights, resnet50
 
 
 __all__ = ["FasterRCNN", "FasterRCNNResNet50FPNWeights", "fasterrcnn_resnet50_fpn"]
@@ -49,7 +54,8 @@ def fasterrcnn_resnet50_fpn(
         weights is not None or weights_backbone is not None, trainable_backbone_layers, 5, 3
     )
 
-    backbone = resnet_fpn_backbone("resnet50", weights_backbone, trainable_layers=trainable_backbone_layers)
+    backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d)
+    backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
     model = FasterRCNN(backbone, num_classes=num_classes, **kwargs)
 
     if weights is not None:
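Taken together, the refactored builders all follow the same pattern: construct the trunk, wrap it with the matching extractor, then pass the result to the detection head. A minimal end-to-end sketch (illustrative only; it mirrors the refactored fasterrcnn_resnet50_fpn above and relies on a private helper that may change):

    import torch
    from torchvision.models.detection.backbone_utils import _resnet_fpn_extractor
    from torchvision.models.detection.faster_rcnn import FasterRCNN
    from torchvision.models.resnet import resnet50
    from torchvision.ops import misc as misc_nn_ops

    # Trunk -> FPN-wrapped backbone -> detection model, as in the refactored builder.
    trunk = resnet50(pretrained=False, norm_layer=misc_nn_ops.FrozenBatchNorm2d)
    backbone = _resnet_fpn_extractor(trunk, trainable_layers=3)
    model = FasterRCNN(backbone, num_classes=91)

    model.eval()
    with torch.no_grad():
        # One dict per input image, with "boxes", "labels" and "scores".
        detections = model([torch.rand(3, 320, 320)])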