Skip to content

Commit a39e60e

Browse files
authored
Porting Detection models (#5617)
* fix inits
* fix docs
* Port faster_rcnn
* Port fcos
* Port keypoint_rcnn
* Port mask_rcnn
* Port retinanet
* Port ssd
* Port ssdlite
* Fix linter
* Fixing tests
* Fixing tests
* Fixing vgg test
1 parent 66d7642 commit a39e60e

21 files changed

+550
-1051
lines changed

test/test_prototype_models.py

+3-11
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def test_get_weight(name, weight):
7474
@pytest.mark.parametrize(
7575
"model_fn",
7676
TM.get_models_from_module(torchvision.models)
77-
+ TM.get_models_from_module(models.detection)
77+
+ TM.get_models_from_module(torchvision.models.detection)
7878
+ TM.get_models_from_module(torchvision.models.quantization)
7979
+ TM.get_models_from_module(models.segmentation)
8080
+ TM.get_models_from_module(models.video)
@@ -90,7 +90,7 @@ def test_naming_conventions(model_fn):
9090
@pytest.mark.parametrize(
9191
"model_fn",
9292
TM.get_models_from_module(torchvision.models)
93-
+ TM.get_models_from_module(models.detection)
93+
+ TM.get_models_from_module(torchvision.models.detection)
9494
+ TM.get_models_from_module(torchvision.models.quantization)
9595
+ TM.get_models_from_module(models.segmentation)
9696
+ TM.get_models_from_module(models.video)
@@ -143,13 +143,6 @@ def test_schema_meta_validation(model_fn):
143143
assert not bad_names
144144

145145

146-
@pytest.mark.parametrize("model_fn", TM.get_models_from_module(models.detection))
147-
@pytest.mark.parametrize("dev", cpu_and_gpu())
148-
@run_if_test_with_prototype
149-
def test_detection_model(model_fn, dev):
150-
TM.test_detection_model(model_fn, dev)
151-
152-
153146
@pytest.mark.parametrize("model_fn", TM.get_models_from_module(models.segmentation))
154147
@pytest.mark.parametrize("dev", cpu_and_gpu())
155148
@run_if_test_with_prototype
@@ -174,8 +167,7 @@ def test_raft(model_builder, scripted):
174167

175168
@pytest.mark.parametrize(
176169
"model_fn",
177-
TM.get_models_from_module(models.detection)
178-
+ TM.get_models_from_module(models.segmentation)
170+
TM.get_models_from_module(models.segmentation)
179171
+ TM.get_models_from_module(models.video)
180172
+ TM.get_models_from_module(models.optical_flow),
181173
)
torchvision/models/detection/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from .faster_rcnn import *
2-
from .mask_rcnn import *
2+
from .fcos import *
33
from .keypoint_rcnn import *
4+
from .mask_rcnn import *
45
from .retinanet import *
56
from .ssd import *
67
from .ssdlite import *
7-
from .fcos import *

torchvision/models/detection/backbone_utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def resnet_fpn_backbone(
8888
pretrained (bool): If True, returns a model with backbone pre-trained on Imagenet
8989
norm_layer (callable): it is recommended to use the default value. For details visit:
9090
(https://github.com/facebookresearch/maskrcnn-benchmark/issues/267)
91-
trainable_layers (int): number of trainable (not frozen) resnet layers starting from final block.
91+
trainable_layers (int): number of trainable (not frozen) layers starting from final block.
9292
Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
9393
returned_layers (list of int): The layers of the network to return. Each entry must be in ``[1, 4]``.
9494
By default all layers are returned.

torchvision/models/detection/faster_rcnn.py

+156-65
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,16 @@
1+
from typing import Any, Optional, Union
2+
13
import torch.nn.functional as F
24
from torch import nn
35
from torchvision.ops import MultiScaleRoIAlign
46

5-
from ..._internally_replaced_utils import load_state_dict_from_url
67
from ...ops import misc as misc_nn_ops
7-
from ..mobilenetv3 import mobilenet_v3_large
8-
from ..resnet import resnet50
8+
from ...transforms import ObjectDetectionEval, InterpolationMode
9+
from .._api import WeightsEnum, Weights
10+
from .._meta import _COCO_CATEGORIES
11+
from .._utils import handle_legacy_interface, _ovewrite_value_param
12+
from ..mobilenetv3 import MobileNet_V3_Large_Weights, mobilenet_v3_large
13+
from ..resnet import ResNet50_Weights, resnet50
914
from ._utils import overwrite_eps
1015
from .anchor_utils import AnchorGenerator
1116
from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers, _mobilenet_extractor
@@ -17,9 +22,12 @@
1722

1823
__all__ = [
1924
"FasterRCNN",
25+
"FasterRCNN_ResNet50_FPN_Weights",
26+
"FasterRCNN_MobileNet_V3_Large_FPN_Weights",
27+
"FasterRCNN_MobileNet_V3_Large_320_FPN_Weights",
2028
"fasterrcnn_resnet50_fpn",
21-
"fasterrcnn_mobilenet_v3_large_320_fpn",
2229
"fasterrcnn_mobilenet_v3_large_fpn",
30+
"fasterrcnn_mobilenet_v3_large_320_fpn",
2331
]
2432

2533

@@ -307,16 +315,70 @@ def forward(self, x):
307315
return scores, bbox_deltas
308316

309317

310-
model_urls = {
311-
"fasterrcnn_resnet50_fpn_coco": "https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth",
312-
"fasterrcnn_mobilenet_v3_large_320_fpn_coco": "https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_320_fpn-907ea3f9.pth",
313-
"fasterrcnn_mobilenet_v3_large_fpn_coco": "https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth",
318+
_COMMON_META = {
319+
"task": "image_object_detection",
320+
"architecture": "FasterRCNN",
321+
"publication_year": 2015,
322+
"categories": _COCO_CATEGORIES,
323+
"interpolation": InterpolationMode.BILINEAR,
314324
}
315325

316326

327+
class FasterRCNN_ResNet50_FPN_Weights(WeightsEnum):
328+
COCO_V1 = Weights(
329+
url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth",
330+
transforms=ObjectDetectionEval,
331+
meta={
332+
**_COMMON_META,
333+
"num_params": 41755286,
334+
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-resnet-50-fpn",
335+
"map": 37.0,
336+
},
337+
)
338+
DEFAULT = COCO_V1
339+
340+
341+
class FasterRCNN_MobileNet_V3_Large_FPN_Weights(WeightsEnum):
342+
COCO_V1 = Weights(
343+
url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth",
344+
transforms=ObjectDetectionEval,
345+
meta={
346+
**_COMMON_META,
347+
"num_params": 19386354,
348+
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-fpn",
349+
"map": 32.8,
350+
},
351+
)
352+
DEFAULT = COCO_V1
353+
354+
355+
class FasterRCNN_MobileNet_V3_Large_320_FPN_Weights(WeightsEnum):
356+
COCO_V1 = Weights(
357+
url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_320_fpn-907ea3f9.pth",
358+
transforms=ObjectDetectionEval,
359+
meta={
360+
**_COMMON_META,
361+
"num_params": 19386354,
362+
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-320-fpn",
363+
"map": 22.8,
364+
},
365+
)
366+
DEFAULT = COCO_V1
367+
368+
369+
@handle_legacy_interface(
370+
weights=("pretrained", FasterRCNN_ResNet50_FPN_Weights.COCO_V1),
371+
weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
372+
)
317373
def fasterrcnn_resnet50_fpn(
318-
pretrained=False, progress=True, num_classes=91, pretrained_backbone=True, trainable_backbone_layers=None, **kwargs
319-
):
374+
*,
375+
weights: Optional[FasterRCNN_ResNet50_FPN_Weights] = None,
376+
progress: bool = True,
377+
num_classes: Optional[int] = None,
378+
weights_backbone: Optional[ResNet50_Weights] = None,
379+
trainable_backbone_layers: Optional[int] = None,
380+
**kwargs: Any,
381+
) -> FasterRCNN:
320382
"""
321383
Constructs a Faster R-CNN model with a ResNet-50-FPN backbone.
322384
@@ -375,51 +437,60 @@ def fasterrcnn_resnet50_fpn(
375437
>>> torch.onnx.export(model, x, "faster_rcnn.onnx", opset_version = 11)
376438
377439
Args:
378-
pretrained (bool): If True, returns a model pre-trained on COCO train2017
440+
weights (FasterRCNN_ResNet50_FPN_Weights, optional): The pretrained weights for the model
379441
progress (bool): If True, displays a progress bar of the download to stderr
380-
num_classes (int): number of output classes of the model (including the background)
381-
pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet
382-
trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block.
442+
num_classes (int, optional): number of output classes of the model (including the background)
443+
weights_backbone (ResNet50_Weights, optional): The pretrained weights for the backbone
444+
trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block.
383445
Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is
384446
passed (the default) this value is set to 3.
385447
"""
386-
is_trained = pretrained or pretrained_backbone
448+
weights = FasterRCNN_ResNet50_FPN_Weights.verify(weights)
449+
weights_backbone = ResNet50_Weights.verify(weights_backbone)
450+
451+
if weights is not None:
452+
weights_backbone = None
453+
num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
454+
elif num_classes is None:
455+
num_classes = 91
456+
457+
is_trained = weights is not None or weights_backbone is not None
387458
trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
388459
norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
389460

390-
if pretrained:
391-
# no need to download the backbone if pretrained is set
392-
pretrained_backbone = False
393-
394-
backbone = resnet50(pretrained=pretrained_backbone, progress=progress, norm_layer=norm_layer)
461+
backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
395462
backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
396-
model = FasterRCNN(backbone, num_classes, **kwargs)
397-
if pretrained:
398-
state_dict = load_state_dict_from_url(model_urls["fasterrcnn_resnet50_fpn_coco"], progress=progress)
399-
model.load_state_dict(state_dict)
400-
overwrite_eps(model, 0.0)
463+
model = FasterRCNN(backbone, num_classes=num_classes, **kwargs)
464+
465+
if weights is not None:
466+
model.load_state_dict(weights.get_state_dict(progress=progress))
467+
if weights == FasterRCNN_ResNet50_FPN_Weights.COCO_V1:
468+
overwrite_eps(model, 0.0)
469+
401470
return model
402471

403472

404473
def _fasterrcnn_mobilenet_v3_large_fpn(
405-
weights_name,
406-
pretrained=False,
407-
progress=True,
408-
num_classes=91,
409-
pretrained_backbone=True,
410-
trainable_backbone_layers=None,
411-
**kwargs,
412-
):
413-
is_trained = pretrained or pretrained_backbone
474+
*,
475+
weights: Optional[Union[FasterRCNN_MobileNet_V3_Large_FPN_Weights, FasterRCNN_MobileNet_V3_Large_320_FPN_Weights]],
476+
progress: bool,
477+
num_classes: Optional[int],
478+
weights_backbone: Optional[MobileNet_V3_Large_Weights],
479+
trainable_backbone_layers: Optional[int],
480+
**kwargs: Any,
481+
) -> FasterRCNN:
482+
if weights is not None:
483+
weights_backbone = None
484+
num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
485+
elif num_classes is None:
486+
num_classes = 91
487+
488+
is_trained = weights is not None or weights_backbone is not None
414489
trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 6, 3)
415490
norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
416491

417-
if pretrained:
418-
pretrained_backbone = False
419-
420-
backbone = mobilenet_v3_large(pretrained=pretrained_backbone, progress=progress, norm_layer=norm_layer)
492+
backbone = mobilenet_v3_large(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
421493
backbone = _mobilenet_extractor(backbone, True, trainable_backbone_layers)
422-
423494
anchor_sizes = (
424495
(
425496
32,
@@ -430,21 +501,29 @@ def _fasterrcnn_mobilenet_v3_large_fpn(
430501
),
431502
) * 3
432503
aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
433-
434504
model = FasterRCNN(
435505
backbone, num_classes, rpn_anchor_generator=AnchorGenerator(anchor_sizes, aspect_ratios), **kwargs
436506
)
437-
if pretrained:
438-
if model_urls.get(weights_name, None) is None:
439-
raise ValueError(f"No checkpoint is available for model {weights_name}")
440-
state_dict = load_state_dict_from_url(model_urls[weights_name], progress=progress)
441-
model.load_state_dict(state_dict)
507+
508+
if weights is not None:
509+
model.load_state_dict(weights.get_state_dict(progress=progress))
510+
442511
return model
443512

444513

514+
@handle_legacy_interface(
515+
weights=("pretrained", FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.COCO_V1),
516+
weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1),
517+
)
445518
def fasterrcnn_mobilenet_v3_large_320_fpn(
446-
pretrained=False, progress=True, num_classes=91, pretrained_backbone=True, trainable_backbone_layers=None, **kwargs
447-
):
519+
*,
520+
weights: Optional[FasterRCNN_MobileNet_V3_Large_320_FPN_Weights] = None,
521+
progress: bool = True,
522+
num_classes: Optional[int] = None,
523+
weights_backbone: Optional[MobileNet_V3_Large_Weights] = None,
524+
trainable_backbone_layers: Optional[int] = None,
525+
**kwargs: Any,
526+
) -> FasterRCNN:
448527
"""
449528
Constructs a low resolution Faster R-CNN model with a MobileNetV3-Large FPN backbone tuned for mobile use-cases.
450529
It works similarly to Faster R-CNN with ResNet-50 FPN backbone. See
@@ -459,15 +538,17 @@ def fasterrcnn_mobilenet_v3_large_320_fpn(
459538
>>> predictions = model(x)
460539
461540
Args:
462-
pretrained (bool): If True, returns a model pre-trained on COCO train2017
541+
weights (FasterRCNN_MobileNet_V3_Large_320_FPN_Weights, optional): The pretrained weights for the model
463542
progress (bool): If True, displays a progress bar of the download to stderr
464-
num_classes (int): number of output classes of the model (including the background)
465-
pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet
466-
trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block.
543+
num_classes (int, optional): number of output classes of the model (including the background)
544+
weights_backbone (MobileNet_V3_Large_Weights, optional): The pretrained weights for the backbone
545+
trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block.
467546
Valid values are between 0 and 6, with 6 meaning all backbone layers are trainable. If ``None`` is
468547
passed (the default) this value is set to 3.
469548
"""
470-
weights_name = "fasterrcnn_mobilenet_v3_large_320_fpn_coco"
549+
weights = FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.verify(weights)
550+
weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone)
551+
471552
defaults = {
472553
"min_size": 320,
473554
"max_size": 640,
@@ -478,19 +559,28 @@ def fasterrcnn_mobilenet_v3_large_320_fpn(
478559

479560
kwargs = {**defaults, **kwargs}
480561
return _fasterrcnn_mobilenet_v3_large_fpn(
481-
weights_name,
482-
pretrained=pretrained,
562+
weights=weights,
483563
progress=progress,
484564
num_classes=num_classes,
485-
pretrained_backbone=pretrained_backbone,
565+
weights_backbone=weights_backbone,
486566
trainable_backbone_layers=trainable_backbone_layers,
487567
**kwargs,
488568
)
489569

490570

571+
@handle_legacy_interface(
572+
weights=("pretrained", FasterRCNN_MobileNet_V3_Large_FPN_Weights.COCO_V1),
573+
weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1),
574+
)
491575
def fasterrcnn_mobilenet_v3_large_fpn(
492-
pretrained=False, progress=True, num_classes=91, pretrained_backbone=True, trainable_backbone_layers=None, **kwargs
493-
):
576+
*,
577+
weights: Optional[FasterRCNN_MobileNet_V3_Large_FPN_Weights] = None,
578+
progress: bool = True,
579+
num_classes: Optional[int] = None,
580+
weights_backbone: Optional[MobileNet_V3_Large_Weights] = None,
581+
trainable_backbone_layers: Optional[int] = None,
582+
**kwargs: Any,
583+
) -> FasterRCNN:
494584
"""
495585
Constructs a high resolution Faster R-CNN model with a MobileNetV3-Large FPN backbone.
496586
It works similarly to Faster R-CNN with ResNet-50 FPN backbone. See
@@ -505,26 +595,27 @@ def fasterrcnn_mobilenet_v3_large_fpn(
505595
>>> predictions = model(x)
506596
507597
Args:
508-
pretrained (bool): If True, returns a model pre-trained on COCO train2017
598+
weights (FasterRCNN_MobileNet_V3_Large_FPN_Weights, optional): The pretrained weights for the model
509599
progress (bool): If True, displays a progress bar of the download to stderr
510-
num_classes (int): number of output classes of the model (including the background)
511-
pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet
512-
trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block.
600+
num_classes (int, optional): number of output classes of the model (including the background)
601+
weights_backbone (MobileNet_V3_Large_Weights, optional): The pretrained weights for the backbone
602+
trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block.
513603
Valid values are between 0 and 6, with 6 meaning all backbone layers are trainable. If ``None`` is
514604
passed (the default) this value is set to 3.
515605
"""
516-
weights_name = "fasterrcnn_mobilenet_v3_large_fpn_coco"
606+
weights = FasterRCNN_MobileNet_V3_Large_FPN_Weights.verify(weights)
607+
weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone)
608+
517609
defaults = {
518610
"rpn_score_thresh": 0.05,
519611
}
520612

521613
kwargs = {**defaults, **kwargs}
522614
return _fasterrcnn_mobilenet_v3_large_fpn(
523-
weights_name,
524-
pretrained=pretrained,
615+
weights=weights,
525616
progress=progress,
526617
num_classes=num_classes,
527-
pretrained_backbone=pretrained_backbone,
618+
weights_backbone=weights_backbone,
528619
trainable_backbone_layers=trainable_backbone_layers,
529620
**kwargs,
530621
)

0 commit comments

Comments
 (0)