Skip to content

Porting Detection models #5617

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Mar 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 3 additions & 11 deletions test/test_prototype_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def test_get_weight(name, weight):
@pytest.mark.parametrize(
"model_fn",
TM.get_models_from_module(torchvision.models)
+ TM.get_models_from_module(models.detection)
+ TM.get_models_from_module(torchvision.models.detection)
+ TM.get_models_from_module(torchvision.models.quantization)
+ TM.get_models_from_module(models.segmentation)
+ TM.get_models_from_module(models.video)
Expand All @@ -90,7 +90,7 @@ def test_naming_conventions(model_fn):
@pytest.mark.parametrize(
"model_fn",
TM.get_models_from_module(torchvision.models)
+ TM.get_models_from_module(models.detection)
+ TM.get_models_from_module(torchvision.models.detection)
+ TM.get_models_from_module(torchvision.models.quantization)
+ TM.get_models_from_module(models.segmentation)
+ TM.get_models_from_module(models.video)
Expand Down Expand Up @@ -143,13 +143,6 @@ def test_schema_meta_validation(model_fn):
assert not bad_names


@pytest.mark.parametrize("model_fn", TM.get_models_from_module(models.detection))
@pytest.mark.parametrize("dev", cpu_and_gpu())
@run_if_test_with_prototype
def test_detection_model(model_fn, dev):
TM.test_detection_model(model_fn, dev)


@pytest.mark.parametrize("model_fn", TM.get_models_from_module(models.segmentation))
@pytest.mark.parametrize("dev", cpu_and_gpu())
@run_if_test_with_prototype
Expand All @@ -174,8 +167,7 @@ def test_raft(model_builder, scripted):

@pytest.mark.parametrize(
"model_fn",
TM.get_models_from_module(models.detection)
+ TM.get_models_from_module(models.segmentation)
TM.get_models_from_module(models.segmentation)
+ TM.get_models_from_module(models.video)
+ TM.get_models_from_module(models.optical_flow),
)
Expand Down
4 changes: 2 additions & 2 deletions torchvision/models/detection/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from .faster_rcnn import *
from .mask_rcnn import *
from .fcos import *
from .keypoint_rcnn import *
from .mask_rcnn import *
from .retinanet import *
from .ssd import *
from .ssdlite import *
from .fcos import *
2 changes: 1 addition & 1 deletion torchvision/models/detection/backbone_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def resnet_fpn_backbone(
pretrained (bool): If True, returns a model with backbone pre-trained on Imagenet
norm_layer (callable): it is recommended to use the default value. For details visit:
(https://github.com/facebookresearch/maskrcnn-benchmark/issues/267)
trainable_layers (int): number of trainable (not frozen) resnet layers starting from final block.
trainable_layers (int): number of trainable (not frozen) layers starting from final block.
Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
returned_layers (list of int): The layers of the network to return. Each entry must be in ``[1, 4]``.
By default all layers are returned.
Expand Down
221 changes: 156 additions & 65 deletions torchvision/models/detection/faster_rcnn.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
from typing import Any, Optional, Union

import torch.nn.functional as F
from torch import nn
from torchvision.ops import MultiScaleRoIAlign

from ..._internally_replaced_utils import load_state_dict_from_url
from ...ops import misc as misc_nn_ops
from ..mobilenetv3 import mobilenet_v3_large
from ..resnet import resnet50
from ...transforms import ObjectDetectionEval, InterpolationMode
from .._api import WeightsEnum, Weights
from .._meta import _COCO_CATEGORIES
from .._utils import handle_legacy_interface, _ovewrite_value_param
from ..mobilenetv3 import MobileNet_V3_Large_Weights, mobilenet_v3_large
from ..resnet import ResNet50_Weights, resnet50
from ._utils import overwrite_eps
from .anchor_utils import AnchorGenerator
from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers, _mobilenet_extractor
Expand All @@ -17,9 +22,12 @@

__all__ = [
"FasterRCNN",
"FasterRCNN_ResNet50_FPN_Weights",
"FasterRCNN_MobileNet_V3_Large_FPN_Weights",
"FasterRCNN_MobileNet_V3_Large_320_FPN_Weights",
"fasterrcnn_resnet50_fpn",
"fasterrcnn_mobilenet_v3_large_320_fpn",
"fasterrcnn_mobilenet_v3_large_fpn",
"fasterrcnn_mobilenet_v3_large_320_fpn",
]


Expand Down Expand Up @@ -307,16 +315,70 @@ def forward(self, x):
return scores, bbox_deltas


model_urls = {
"fasterrcnn_resnet50_fpn_coco": "https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth",
"fasterrcnn_mobilenet_v3_large_320_fpn_coco": "https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_320_fpn-907ea3f9.pth",
"fasterrcnn_mobilenet_v3_large_fpn_coco": "https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth",
# Metadata entries shared by the Faster R-CNN weight definitions in this module:
# each ``Weights(...)`` entry below spreads this dict into its own ``meta`` and
# then adds checkpoint-specific keys on top of it.
# NOTE(review): "publication_year" presumably refers to the Faster R-CNN paper and
# "interpolation" to the resize mode used by the eval preprocessing — confirm.
_COMMON_META = {
"task": "image_object_detection",
"architecture": "FasterRCNN",
"publication_year": 2015,
"categories": _COCO_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
}


# Weight definitions accepted by ``fasterrcnn_resnet50_fpn(weights=...)``.
# The legacy ``pretrained=True`` flag of that builder is mapped to ``COCO_V1``
# by its ``handle_legacy_interface`` decorator.
class FasterRCNN_ResNet50_FPN_Weights(WeightsEnum):
# Single checkpoint entry: download URL, the ObjectDetectionEval transforms,
# and a ``meta`` dict built from _COMMON_META plus checkpoint-specific keys.
# The builder additionally calls ``overwrite_eps(model, 0.0)`` after loading
# this particular entry.
COCO_V1 = Weights(
url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth",
transforms=ObjectDetectionEval,
meta={
**_COMMON_META,
"num_params": 41755286,
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-resnet-50-fpn",
# NOTE(review): presumably the COCO box mAP reported by the linked recipe — confirm.
"map": 37.0,
},
)
# DEFAULT is an alias of COCO_V1, the only entry defined in this enum.
DEFAULT = COCO_V1


# Weight definitions accepted by ``fasterrcnn_mobilenet_v3_large_fpn(weights=...)``.
# The legacy ``pretrained=True`` flag of that builder is mapped to ``COCO_V1``
# by its ``handle_legacy_interface`` decorator.
class FasterRCNN_MobileNet_V3_Large_FPN_Weights(WeightsEnum):
# Single checkpoint entry: download URL, the ObjectDetectionEval transforms,
# and a ``meta`` dict built from _COMMON_META plus checkpoint-specific keys.
COCO_V1 = Weights(
url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth",
transforms=ObjectDetectionEval,
meta={
**_COMMON_META,
# Same parameter count as the 320 variant's entry in this module.
"num_params": 19386354,
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-fpn",
# NOTE(review): presumably the COCO box mAP reported by the linked recipe — confirm.
"map": 32.8,
},
)
# DEFAULT is an alias of COCO_V1, the only entry defined in this enum.
DEFAULT = COCO_V1


# Weight definitions accepted by ``fasterrcnn_mobilenet_v3_large_320_fpn(weights=...)``.
# The legacy ``pretrained=True`` flag of that builder is mapped to ``COCO_V1``
# by its ``handle_legacy_interface`` decorator.
class FasterRCNN_MobileNet_V3_Large_320_FPN_Weights(WeightsEnum):
# Single checkpoint entry: download URL, the ObjectDetectionEval transforms,
# and a ``meta`` dict built from _COMMON_META plus checkpoint-specific keys.
COCO_V1 = Weights(
url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_320_fpn-907ea3f9.pth",
transforms=ObjectDetectionEval,
meta={
**_COMMON_META,
# Same parameter count as the non-320 MobileNetV3-Large FPN entry in this module.
"num_params": 19386354,
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-320-fpn",
# NOTE(review): presumably the COCO box mAP reported by the linked recipe — confirm.
"map": 22.8,
},
)
# DEFAULT is an alias of COCO_V1, the only entry defined in this enum.
DEFAULT = COCO_V1


@handle_legacy_interface(
weights=("pretrained", FasterRCNN_ResNet50_FPN_Weights.COCO_V1),
weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
)
def fasterrcnn_resnet50_fpn(
pretrained=False, progress=True, num_classes=91, pretrained_backbone=True, trainable_backbone_layers=None, **kwargs
):
*,
weights: Optional[FasterRCNN_ResNet50_FPN_Weights] = None,
progress: bool = True,
num_classes: Optional[int] = None,
weights_backbone: Optional[ResNet50_Weights] = None,
trainable_backbone_layers: Optional[int] = None,
**kwargs: Any,
) -> FasterRCNN:
"""
Constructs a Faster R-CNN model with a ResNet-50-FPN backbone.

Expand Down Expand Up @@ -375,51 +437,60 @@ def fasterrcnn_resnet50_fpn(
>>> torch.onnx.export(model, x, "faster_rcnn.onnx", opset_version = 11)

Args:
pretrained (bool): If True, returns a model pre-trained on COCO train2017
weights (FasterRCNN_ResNet50_FPN_Weights, optional): The pretrained weights for the model
progress (bool): If True, displays a progress bar of the download to stderr
num_classes (int): number of output classes of the model (including the background)
pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet
trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block.
num_classes (int, optional): number of output classes of the model (including the background)
weights_backbone (ResNet50_Weights, optional): The pretrained weights for the backbone
trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block.
Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is
passed (the default) this value is set to 3.
"""
is_trained = pretrained or pretrained_backbone
weights = FasterRCNN_ResNet50_FPN_Weights.verify(weights)
weights_backbone = ResNet50_Weights.verify(weights_backbone)

if weights is not None:
weights_backbone = None
num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
elif num_classes is None:
num_classes = 91

is_trained = weights is not None or weights_backbone is not None
trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d

if pretrained:
# no need to download the backbone if pretrained is set
pretrained_backbone = False

backbone = resnet50(pretrained=pretrained_backbone, progress=progress, norm_layer=norm_layer)
backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
model = FasterRCNN(backbone, num_classes, **kwargs)
if pretrained:
state_dict = load_state_dict_from_url(model_urls["fasterrcnn_resnet50_fpn_coco"], progress=progress)
model.load_state_dict(state_dict)
overwrite_eps(model, 0.0)
model = FasterRCNN(backbone, num_classes=num_classes, **kwargs)

if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
if weights == FasterRCNN_ResNet50_FPN_Weights.COCO_V1:
overwrite_eps(model, 0.0)

return model


def _fasterrcnn_mobilenet_v3_large_fpn(
weights_name,
pretrained=False,
progress=True,
num_classes=91,
pretrained_backbone=True,
trainable_backbone_layers=None,
**kwargs,
):
is_trained = pretrained or pretrained_backbone
*,
weights: Optional[Union[FasterRCNN_MobileNet_V3_Large_FPN_Weights, FasterRCNN_MobileNet_V3_Large_320_FPN_Weights]],
progress: bool,
num_classes: Optional[int],
weights_backbone: Optional[MobileNet_V3_Large_Weights],
trainable_backbone_layers: Optional[int],
**kwargs: Any,
) -> FasterRCNN:
if weights is not None:
weights_backbone = None
num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
elif num_classes is None:
num_classes = 91

is_trained = weights is not None or weights_backbone is not None
trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 6, 3)
norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d

if pretrained:
pretrained_backbone = False

backbone = mobilenet_v3_large(pretrained=pretrained_backbone, progress=progress, norm_layer=norm_layer)
backbone = mobilenet_v3_large(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
backbone = _mobilenet_extractor(backbone, True, trainable_backbone_layers)

anchor_sizes = (
(
32,
Expand All @@ -430,21 +501,29 @@ def _fasterrcnn_mobilenet_v3_large_fpn(
),
) * 3
aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)

model = FasterRCNN(
backbone, num_classes, rpn_anchor_generator=AnchorGenerator(anchor_sizes, aspect_ratios), **kwargs
)
if pretrained:
if model_urls.get(weights_name, None) is None:
raise ValueError(f"No checkpoint is available for model {weights_name}")
state_dict = load_state_dict_from_url(model_urls[weights_name], progress=progress)
model.load_state_dict(state_dict)

if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))

return model


@handle_legacy_interface(
weights=("pretrained", FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.COCO_V1),
weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1),
)
def fasterrcnn_mobilenet_v3_large_320_fpn(
pretrained=False, progress=True, num_classes=91, pretrained_backbone=True, trainable_backbone_layers=None, **kwargs
):
*,
weights: Optional[FasterRCNN_MobileNet_V3_Large_320_FPN_Weights] = None,
progress: bool = True,
num_classes: Optional[int] = None,
weights_backbone: Optional[MobileNet_V3_Large_Weights] = None,
trainable_backbone_layers: Optional[int] = None,
**kwargs: Any,
) -> FasterRCNN:
"""
Constructs a low resolution Faster R-CNN model with a MobileNetV3-Large FPN backbone tuned for mobile use-cases.
It works similarly to Faster R-CNN with ResNet-50 FPN backbone. See
Expand All @@ -459,15 +538,17 @@ def fasterrcnn_mobilenet_v3_large_320_fpn(
>>> predictions = model(x)

Args:
pretrained (bool): If True, returns a model pre-trained on COCO train2017
weights (FasterRCNN_MobileNet_V3_Large_320_FPN_Weights, optional): The pretrained weights for the model
progress (bool): If True, displays a progress bar of the download to stderr
num_classes (int): number of output classes of the model (including the background)
pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet
trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block.
num_classes (int, optional): number of output classes of the model (including the background)
weights_backbone (MobileNet_V3_Large_Weights, optional): The pretrained weights for the backbone
trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block.
Valid values are between 0 and 6, with 6 meaning all backbone layers are trainable. If ``None`` is
passed (the default) this value is set to 3.
"""
weights_name = "fasterrcnn_mobilenet_v3_large_320_fpn_coco"
weights = FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.verify(weights)
weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone)

defaults = {
"min_size": 320,
"max_size": 640,
Expand All @@ -478,19 +559,28 @@ def fasterrcnn_mobilenet_v3_large_320_fpn(

kwargs = {**defaults, **kwargs}
return _fasterrcnn_mobilenet_v3_large_fpn(
weights_name,
pretrained=pretrained,
weights=weights,
progress=progress,
num_classes=num_classes,
pretrained_backbone=pretrained_backbone,
weights_backbone=weights_backbone,
trainable_backbone_layers=trainable_backbone_layers,
**kwargs,
)


@handle_legacy_interface(
weights=("pretrained", FasterRCNN_MobileNet_V3_Large_FPN_Weights.COCO_V1),
weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1),
)
def fasterrcnn_mobilenet_v3_large_fpn(
pretrained=False, progress=True, num_classes=91, pretrained_backbone=True, trainable_backbone_layers=None, **kwargs
):
*,
weights: Optional[FasterRCNN_MobileNet_V3_Large_FPN_Weights] = None,
progress: bool = True,
num_classes: Optional[int] = None,
weights_backbone: Optional[MobileNet_V3_Large_Weights] = None,
trainable_backbone_layers: Optional[int] = None,
**kwargs: Any,
) -> FasterRCNN:
"""
Constructs a high resolution Faster R-CNN model with a MobileNetV3-Large FPN backbone.
It works similarly to Faster R-CNN with ResNet-50 FPN backbone. See
Expand All @@ -505,26 +595,27 @@ def fasterrcnn_mobilenet_v3_large_fpn(
>>> predictions = model(x)

Args:
pretrained (bool): If True, returns a model pre-trained on COCO train2017
weights (FasterRCNN_MobileNet_V3_Large_FPN_Weights, optional): The pretrained weights for the model
progress (bool): If True, displays a progress bar of the download to stderr
num_classes (int): number of output classes of the model (including the background)
pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet
trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block.
num_classes (int, optional): number of output classes of the model (including the background)
weights_backbone (MobileNet_V3_Large_Weights, optional): The pretrained weights for the backbone
trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block.
Valid values are between 0 and 6, with 6 meaning all backbone layers are trainable. If ``None`` is
passed (the default) this value is set to 3.
"""
weights_name = "fasterrcnn_mobilenet_v3_large_fpn_coco"
weights = FasterRCNN_MobileNet_V3_Large_FPN_Weights.verify(weights)
weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone)

defaults = {
"rpn_score_thresh": 0.05,
}

kwargs = {**defaults, **kwargs}
return _fasterrcnn_mobilenet_v3_large_fpn(
weights_name,
pretrained=pretrained,
weights=weights,
progress=progress,
num_classes=num_classes,
pretrained_backbone=pretrained_backbone,
weights_backbone=weights_backbone,
trainable_backbone_layers=trainable_backbone_layers,
**kwargs,
)
Loading