1
+ from typing import Any , Optional , Union
2
+
1
3
import torch .nn .functional as F
2
4
from torch import nn
3
5
from torchvision .ops import MultiScaleRoIAlign
4
6
5
- from ..._internally_replaced_utils import load_state_dict_from_url
6
7
from ...ops import misc as misc_nn_ops
7
- from ..mobilenetv3 import mobilenet_v3_large
8
- from ..resnet import resnet50
8
+ from ...transforms import ObjectDetectionEval , InterpolationMode
9
+ from .._api import WeightsEnum , Weights
10
+ from .._meta import _COCO_CATEGORIES
11
+ from .._utils import handle_legacy_interface , _ovewrite_value_param
12
+ from ..mobilenetv3 import MobileNet_V3_Large_Weights , mobilenet_v3_large
13
+ from ..resnet import ResNet50_Weights , resnet50
9
14
from ._utils import overwrite_eps
10
15
from .anchor_utils import AnchorGenerator
11
16
from .backbone_utils import _resnet_fpn_extractor , _validate_trainable_layers , _mobilenet_extractor
17
22
18
23
__all__ = [
19
24
"FasterRCNN" ,
25
+ "FasterRCNN_ResNet50_FPN_Weights" ,
26
+ "FasterRCNN_MobileNet_V3_Large_FPN_Weights" ,
27
+ "FasterRCNN_MobileNet_V3_Large_320_FPN_Weights" ,
20
28
"fasterrcnn_resnet50_fpn" ,
21
- "fasterrcnn_mobilenet_v3_large_320_fpn" ,
22
29
"fasterrcnn_mobilenet_v3_large_fpn" ,
30
+ "fasterrcnn_mobilenet_v3_large_320_fpn" ,
23
31
]
24
32
25
33
@@ -307,16 +315,70 @@ def forward(self, x):
307
315
return scores , bbox_deltas
308
316
309
317
310
- model_urls = {
311
- "fasterrcnn_resnet50_fpn_coco" : "https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth" ,
312
- "fasterrcnn_mobilenet_v3_large_320_fpn_coco" : "https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_320_fpn-907ea3f9.pth" ,
313
- "fasterrcnn_mobilenet_v3_large_fpn_coco" : "https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth" ,
318
+ _COMMON_META = {
319
+ "task" : "image_object_detection" ,
320
+ "architecture" : "FasterRCNN" ,
321
+ "publication_year" : 2015 ,
322
+ "categories" : _COCO_CATEGORIES ,
323
+ "interpolation" : InterpolationMode .BILINEAR ,
314
324
}
315
325
316
326
327
+ class FasterRCNN_ResNet50_FPN_Weights (WeightsEnum ):
328
+ COCO_V1 = Weights (
329
+ url = "https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth" ,
330
+ transforms = ObjectDetectionEval ,
331
+ meta = {
332
+ ** _COMMON_META ,
333
+ "num_params" : 41755286 ,
334
+ "recipe" : "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-resnet-50-fpn" ,
335
+ "map" : 37.0 ,
336
+ },
337
+ )
338
+ DEFAULT = COCO_V1
339
+
340
+
341
+ class FasterRCNN_MobileNet_V3_Large_FPN_Weights (WeightsEnum ):
342
+ COCO_V1 = Weights (
343
+ url = "https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth" ,
344
+ transforms = ObjectDetectionEval ,
345
+ meta = {
346
+ ** _COMMON_META ,
347
+ "num_params" : 19386354 ,
348
+ "recipe" : "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-fpn" ,
349
+ "map" : 32.8 ,
350
+ },
351
+ )
352
+ DEFAULT = COCO_V1
353
+
354
+
355
+ class FasterRCNN_MobileNet_V3_Large_320_FPN_Weights (WeightsEnum ):
356
+ COCO_V1 = Weights (
357
+ url = "https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_320_fpn-907ea3f9.pth" ,
358
+ transforms = ObjectDetectionEval ,
359
+ meta = {
360
+ ** _COMMON_META ,
361
+ "num_params" : 19386354 ,
362
+ "recipe" : "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-320-fpn" ,
363
+ "map" : 22.8 ,
364
+ },
365
+ )
366
+ DEFAULT = COCO_V1
367
+
368
+
369
+ @handle_legacy_interface (
370
+ weights = ("pretrained" , FasterRCNN_ResNet50_FPN_Weights .COCO_V1 ),
371
+ weights_backbone = ("pretrained_backbone" , ResNet50_Weights .IMAGENET1K_V1 ),
372
+ )
317
373
def fasterrcnn_resnet50_fpn (
318
- pretrained = False , progress = True , num_classes = 91 , pretrained_backbone = True , trainable_backbone_layers = None , ** kwargs
319
- ):
374
+ * ,
375
+ weights : Optional [FasterRCNN_ResNet50_FPN_Weights ] = None ,
376
+ progress : bool = True ,
377
+ num_classes : Optional [int ] = None ,
378
+ weights_backbone : Optional [ResNet50_Weights ] = None ,
379
+ trainable_backbone_layers : Optional [int ] = None ,
380
+ ** kwargs : Any ,
381
+ ) -> FasterRCNN :
320
382
"""
321
383
Constructs a Faster R-CNN model with a ResNet-50-FPN backbone.
322
384
@@ -375,51 +437,60 @@ def fasterrcnn_resnet50_fpn(
375
437
>>> torch.onnx.export(model, x, "faster_rcnn.onnx", opset_version = 11)
376
438
377
439
Args:
378
- pretrained (bool ): If True, returns a model pre-trained on COCO train2017
440
+ weights (FasterRCNN_ResNet50_FPN_Weights, optional ): The pretrained weights for the model
379
441
progress (bool): If True, displays a progress bar of the download to stderr
380
- num_classes (int): number of output classes of the model (including the background)
381
- pretrained_backbone (bool ): If True, returns a model with backbone pre-trained on Imagenet
382
- trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block.
442
+ num_classes (int, optional ): number of output classes of the model (including the background)
443
+ weights_backbone (ResNet50_Weights, optional ): The pretrained weights for the backbone
444
+ trainable_backbone_layers (int, optional ): number of trainable (not frozen) layers starting from final block.
383
445
Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is
384
446
passed (the default) this value is set to 3.
385
447
"""
386
- is_trained = pretrained or pretrained_backbone
448
+ weights = FasterRCNN_ResNet50_FPN_Weights .verify (weights )
449
+ weights_backbone = ResNet50_Weights .verify (weights_backbone )
450
+
451
+ if weights is not None :
452
+ weights_backbone = None
453
+ num_classes = _ovewrite_value_param (num_classes , len (weights .meta ["categories" ]))
454
+ elif num_classes is None :
455
+ num_classes = 91
456
+
457
+ is_trained = weights is not None or weights_backbone is not None
387
458
trainable_backbone_layers = _validate_trainable_layers (is_trained , trainable_backbone_layers , 5 , 3 )
388
459
norm_layer = misc_nn_ops .FrozenBatchNorm2d if is_trained else nn .BatchNorm2d
389
460
390
- if pretrained :
391
- # no need to download the backbone if pretrained is set
392
- pretrained_backbone = False
393
-
394
- backbone = resnet50 (pretrained = pretrained_backbone , progress = progress , norm_layer = norm_layer )
461
+ backbone = resnet50 (weights = weights_backbone , progress = progress , norm_layer = norm_layer )
395
462
backbone = _resnet_fpn_extractor (backbone , trainable_backbone_layers )
396
- model = FasterRCNN (backbone , num_classes , ** kwargs )
397
- if pretrained :
398
- state_dict = load_state_dict_from_url (model_urls ["fasterrcnn_resnet50_fpn_coco" ], progress = progress )
399
- model .load_state_dict (state_dict )
400
- overwrite_eps (model , 0.0 )
463
+ model = FasterRCNN (backbone , num_classes = num_classes , ** kwargs )
464
+
465
+ if weights is not None :
466
+ model .load_state_dict (weights .get_state_dict (progress = progress ))
467
+ if weights == FasterRCNN_ResNet50_FPN_Weights .COCO_V1 :
468
+ overwrite_eps (model , 0.0 )
469
+
401
470
return model
402
471
403
472
404
473
def _fasterrcnn_mobilenet_v3_large_fpn (
405
- weights_name ,
406
- pretrained = False ,
407
- progress = True ,
408
- num_classes = 91 ,
409
- pretrained_backbone = True ,
410
- trainable_backbone_layers = None ,
411
- ** kwargs ,
412
- ):
413
- is_trained = pretrained or pretrained_backbone
474
+ * ,
475
+ weights : Optional [Union [FasterRCNN_MobileNet_V3_Large_FPN_Weights , FasterRCNN_MobileNet_V3_Large_320_FPN_Weights ]],
476
+ progress : bool ,
477
+ num_classes : Optional [int ],
478
+ weights_backbone : Optional [MobileNet_V3_Large_Weights ],
479
+ trainable_backbone_layers : Optional [int ],
480
+ ** kwargs : Any ,
481
+ ) -> FasterRCNN :
482
+ if weights is not None :
483
+ weights_backbone = None
484
+ num_classes = _ovewrite_value_param (num_classes , len (weights .meta ["categories" ]))
485
+ elif num_classes is None :
486
+ num_classes = 91
487
+
488
+ is_trained = weights is not None or weights_backbone is not None
414
489
trainable_backbone_layers = _validate_trainable_layers (is_trained , trainable_backbone_layers , 6 , 3 )
415
490
norm_layer = misc_nn_ops .FrozenBatchNorm2d if is_trained else nn .BatchNorm2d
416
491
417
- if pretrained :
418
- pretrained_backbone = False
419
-
420
- backbone = mobilenet_v3_large (pretrained = pretrained_backbone , progress = progress , norm_layer = norm_layer )
492
+ backbone = mobilenet_v3_large (weights = weights_backbone , progress = progress , norm_layer = norm_layer )
421
493
backbone = _mobilenet_extractor (backbone , True , trainable_backbone_layers )
422
-
423
494
anchor_sizes = (
424
495
(
425
496
32 ,
@@ -430,21 +501,29 @@ def _fasterrcnn_mobilenet_v3_large_fpn(
430
501
),
431
502
) * 3
432
503
aspect_ratios = ((0.5 , 1.0 , 2.0 ),) * len (anchor_sizes )
433
-
434
504
model = FasterRCNN (
435
505
backbone , num_classes , rpn_anchor_generator = AnchorGenerator (anchor_sizes , aspect_ratios ), ** kwargs
436
506
)
437
- if pretrained :
438
- if model_urls .get (weights_name , None ) is None :
439
- raise ValueError (f"No checkpoint is available for model { weights_name } " )
440
- state_dict = load_state_dict_from_url (model_urls [weights_name ], progress = progress )
441
- model .load_state_dict (state_dict )
507
+
508
+ if weights is not None :
509
+ model .load_state_dict (weights .get_state_dict (progress = progress ))
510
+
442
511
return model
443
512
444
513
514
+ @handle_legacy_interface (
515
+ weights = ("pretrained" , FasterRCNN_MobileNet_V3_Large_320_FPN_Weights .COCO_V1 ),
516
+ weights_backbone = ("pretrained_backbone" , MobileNet_V3_Large_Weights .IMAGENET1K_V1 ),
517
+ )
445
518
def fasterrcnn_mobilenet_v3_large_320_fpn (
446
- pretrained = False , progress = True , num_classes = 91 , pretrained_backbone = True , trainable_backbone_layers = None , ** kwargs
447
- ):
519
+ * ,
520
+ weights : Optional [FasterRCNN_MobileNet_V3_Large_320_FPN_Weights ] = None ,
521
+ progress : bool = True ,
522
+ num_classes : Optional [int ] = None ,
523
+ weights_backbone : Optional [MobileNet_V3_Large_Weights ] = None ,
524
+ trainable_backbone_layers : Optional [int ] = None ,
525
+ ** kwargs : Any ,
526
+ ) -> FasterRCNN :
448
527
"""
449
528
Constructs a low resolution Faster R-CNN model with a MobileNetV3-Large FPN backbone tunned for mobile use-cases.
450
529
It works similarly to Faster R-CNN with ResNet-50 FPN backbone. See
@@ -459,15 +538,17 @@ def fasterrcnn_mobilenet_v3_large_320_fpn(
459
538
>>> predictions = model(x)
460
539
461
540
Args:
462
- pretrained (bool ): If True, returns a model pre-trained on COCO train2017
541
+ weights (FasterRCNN_MobileNet_V3_Large_320_FPN_Weights, optional ): The pretrained weights for the model
463
542
progress (bool): If True, displays a progress bar of the download to stderr
464
- num_classes (int): number of output classes of the model (including the background)
465
- pretrained_backbone (bool ): If True, returns a model with backbone pre-trained on Imagenet
466
- trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block.
543
+ num_classes (int, optional ): number of output classes of the model (including the background)
544
+ weights_backbone (MobileNet_V3_Large_Weights, optional ): The pretrained weights for the backbone
545
+ trainable_backbone_layers (int, optional ): number of trainable (not frozen) layers starting from final block.
467
546
Valid values are between 0 and 6, with 6 meaning all backbone layers are trainable. If ``None`` is
468
547
passed (the default) this value is set to 3.
469
548
"""
470
- weights_name = "fasterrcnn_mobilenet_v3_large_320_fpn_coco"
549
+ weights = FasterRCNN_MobileNet_V3_Large_320_FPN_Weights .verify (weights )
550
+ weights_backbone = MobileNet_V3_Large_Weights .verify (weights_backbone )
551
+
471
552
defaults = {
472
553
"min_size" : 320 ,
473
554
"max_size" : 640 ,
@@ -478,19 +559,28 @@ def fasterrcnn_mobilenet_v3_large_320_fpn(
478
559
479
560
kwargs = {** defaults , ** kwargs }
480
561
return _fasterrcnn_mobilenet_v3_large_fpn (
481
- weights_name ,
482
- pretrained = pretrained ,
562
+ weights = weights ,
483
563
progress = progress ,
484
564
num_classes = num_classes ,
485
- pretrained_backbone = pretrained_backbone ,
565
+ weights_backbone = weights_backbone ,
486
566
trainable_backbone_layers = trainable_backbone_layers ,
487
567
** kwargs ,
488
568
)
489
569
490
570
571
+ @handle_legacy_interface (
572
+ weights = ("pretrained" , FasterRCNN_MobileNet_V3_Large_FPN_Weights .COCO_V1 ),
573
+ weights_backbone = ("pretrained_backbone" , MobileNet_V3_Large_Weights .IMAGENET1K_V1 ),
574
+ )
491
575
def fasterrcnn_mobilenet_v3_large_fpn (
492
- pretrained = False , progress = True , num_classes = 91 , pretrained_backbone = True , trainable_backbone_layers = None , ** kwargs
493
- ):
576
+ * ,
577
+ weights : Optional [FasterRCNN_MobileNet_V3_Large_FPN_Weights ] = None ,
578
+ progress : bool = True ,
579
+ num_classes : Optional [int ] = None ,
580
+ weights_backbone : Optional [MobileNet_V3_Large_Weights ] = None ,
581
+ trainable_backbone_layers : Optional [int ] = None ,
582
+ ** kwargs : Any ,
583
+ ) -> FasterRCNN :
494
584
"""
495
585
Constructs a high resolution Faster R-CNN model with a MobileNetV3-Large FPN backbone.
496
586
It works similarly to Faster R-CNN with ResNet-50 FPN backbone. See
@@ -505,26 +595,27 @@ def fasterrcnn_mobilenet_v3_large_fpn(
505
595
>>> predictions = model(x)
506
596
507
597
Args:
508
- pretrained (bool ): If True, returns a model pre-trained on COCO train2017
598
+ weights (FasterRCNN_MobileNet_V3_Large_FPN_Weights, optional ): The pretrained weights for the model
509
599
progress (bool): If True, displays a progress bar of the download to stderr
510
- num_classes (int): number of output classes of the model (including the background)
511
- pretrained_backbone (bool ): If True, returns a model with backbone pre-trained on Imagenet
512
- trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block.
600
+ num_classes (int, optional ): number of output classes of the model (including the background)
601
+ weights_backbone (MobileNet_V3_Large_Weights, optional ): The pretrained weights for the backbone
602
+ trainable_backbone_layers (int, optional ): number of trainable (not frozen) layers starting from final block.
513
603
Valid values are between 0 and 6, with 6 meaning all backbone layers are trainable. If ``None`` is
514
604
passed (the default) this value is set to 3.
515
605
"""
516
- weights_name = "fasterrcnn_mobilenet_v3_large_fpn_coco"
606
+ weights = FasterRCNN_MobileNet_V3_Large_FPN_Weights .verify (weights )
607
+ weights_backbone = MobileNet_V3_Large_Weights .verify (weights_backbone )
608
+
517
609
defaults = {
518
610
"rpn_score_thresh" : 0.05 ,
519
611
}
520
612
521
613
kwargs = {** defaults , ** kwargs }
522
614
return _fasterrcnn_mobilenet_v3_large_fpn (
523
- weights_name ,
524
- pretrained = pretrained ,
615
+ weights = weights ,
525
616
progress = progress ,
526
617
num_classes = num_classes ,
527
- pretrained_backbone = pretrained_backbone ,
618
+ weights_backbone = weights_backbone ,
528
619
trainable_backbone_layers = trainable_backbone_layers ,
529
620
** kwargs ,
530
621
)
0 commit comments