
Commit e35793a

Cherry-picking cleanups for SSD and SSDlite. (#3818)
1 parent 6374cff commit e35793a

File tree

4 files changed (+41, -38 lines)


docs/source/models.rst

Lines changed: 7 additions & 7 deletions
@@ -426,8 +426,8 @@ Faster R-CNN ResNet-50 FPN 37.0 - -
 Faster R-CNN MobileNetV3-Large FPN 32.8 - -
 Faster R-CNN MobileNetV3-Large 320 FPN 22.8 - -
 RetinaNet ResNet-50 FPN 36.4 - -
-SSD VGG16 25.1 - -
-SSDlite MobileNetV3-Large 21.3 - -
+SSD300 VGG16 25.1 - -
+SSDlite320 MobileNetV3-Large 21.3 - -
 Mask R-CNN ResNet-50 FPN 37.9 34.6 -
 ====================================== ======= ======== ===========

@@ -486,8 +486,8 @@ Faster R-CNN ResNet-50 FPN 0.2288 0.0590
 Faster R-CNN MobileNetV3-Large FPN 0.1020 0.0415 1.0
 Faster R-CNN MobileNetV3-Large 320 FPN 0.0978 0.0376 0.6
 RetinaNet ResNet-50 FPN 0.2514 0.0939 4.1
-SSD VGG16 0.2093 0.0744 1.5
-SSDlite MobileNetV3-Large 0.1773 0.0906 1.5
+SSD300 VGG16 0.2093 0.0744 1.5
+SSDlite320 MobileNetV3-Large 0.1773 0.0906 1.5
 Mask R-CNN ResNet-50 FPN 0.2728 0.0903 5.4
 Keypoint R-CNN ResNet-50 FPN 0.3789 0.1242 6.8
 ====================================== =================== ================== ===========

@@ -502,19 +502,19 @@ Faster R-CNN


 RetinaNet
------------
+---------

 .. autofunction:: torchvision.models.detection.retinanet_resnet50_fpn


 SSD
------------
+---

 .. autofunction:: torchvision.models.detection.ssd300_vgg16


 SSDlite
------------
+-------

 .. autofunction:: torchvision.models.detection.ssdlite320_mobilenet_v3_large


references/detection/README.md

Lines changed: 2 additions & 2 deletions
@@ -48,15 +48,15 @@ python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\
     --lr-steps 16 22 --aspect-ratio-group-factor 3 --lr 0.01
 ```
 
-### SSD VGG16
+### SSD300 VGG16
 ```
 python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\
     --dataset coco --model ssd300_vgg16 --epochs 120\
     --lr-steps 80 110 --aspect-ratio-group-factor 3 --lr 0.002 --batch-size 4\
     --weight-decay 0.0005 --data-augmentation ssd
 ```
 
-### SSDlite MobileNetV3-Large
+### SSDlite320 MobileNetV3-Large
 ```
 python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\
     --dataset coco --model ssdlite320_mobilenet_v3_large --epochs 660\
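For reference, a minimal sketch (not part of this diff) of how the two renamed models are used for inference; it assumes the standard torchvision detection API, where the model takes a list of image tensors and returns one prediction dict per image.

```python
import torch
import torchvision

# The builders that the renamed README headings refer to.
ssd = torchvision.models.detection.ssd300_vgg16(pretrained=False, num_classes=91)
ssdlite = torchvision.models.detection.ssdlite320_mobilenet_v3_large(pretrained=False, num_classes=91)
ssd.eval()
ssdlite.eval()

with torch.no_grad():
    # Inputs are resized internally to the fixed sizes the names now advertise (300x300 / 320x320).
    ssd_out = ssd([torch.rand(3, 300, 300)])
    ssdlite_out = ssdlite([torch.rand(3, 320, 320)])

# Each prediction is a dict with 'boxes', 'scores' and 'labels'.
print(ssd_out[0].keys(), ssdlite_out[0].keys())
```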

torchvision/models/detection/ssd.py

Lines changed: 17 additions & 14 deletions
@@ -410,7 +410,7 @@ def postprocess_detections(self, head_outputs: Dict[str, Tensor], image_anchors:


 class SSDFeatureExtractorVGG(nn.Module):
-    def __init__(self, backbone: nn.Module, highres: bool, rescaling: bool):
+    def __init__(self, backbone: nn.Module, highres: bool):
         super().__init__()

         _, _, maxpool3_pos, maxpool4_pos, _ = (i for i, layer in enumerate(backbone) if isinstance(layer, nn.MaxPool2d))

@@ -476,13 +476,8 @@ def __init__(self, backbone: nn.Module, highres: bool, rescaling: bool)
             fc,
         ))
         self.extra = extra
-        self.rescaling = rescaling

     def forward(self, x: Tensor) -> Dict[str, Tensor]:
-        # Undo the 0-1 scaling of toTensor. Necessary for some backbones.
-        if self.rescaling:
-            x *= 255
-
         # L2 regularization + Rescaling of 1st block's feature map
         x = self.features(x)
         rescaled = self.scale_weight.view(1, -1, 1, 1) * F.normalize(x)

@@ -496,8 +491,7 @@ def forward(self, x: Tensor) -> Dict[str, Tensor]:
         return OrderedDict([(str(i), v) for i, v in enumerate(output)])


-def _vgg_extractor(backbone_name: str, highres: bool, progress: bool, pretrained: bool, trainable_layers: int,
-                   rescaling: bool):
+def _vgg_extractor(backbone_name: str, highres: bool, progress: bool, pretrained: bool, trainable_layers: int):
     if backbone_name in backbone_urls:
         # Use custom backbones more appropriate for SSD
         arch = backbone_name.split('_')[0]

@@ -521,19 +515,19 @@ def _vgg_extractor(backbone_name: str, highres: bool, progress: bool, pretrained
         for parameter in b.parameters():
             parameter.requires_grad_(False)

-    return SSDFeatureExtractorVGG(backbone, highres, rescaling)
+    return SSDFeatureExtractorVGG(backbone, highres)


 def ssd300_vgg16(pretrained: bool = False, progress: bool = True, num_classes: int = 91,
                  pretrained_backbone: bool = True, trainable_backbone_layers: Optional[int] = None, **kwargs: Any):
     """
-    Constructs an SSD model with a VGG16 backbone. See `SSD` for more details.
+    Constructs an SSD model with input size 300x300 and a VGG16 backbone. See `SSD` for more details.

     Example:

         >>> model = torchvision.models.detection.ssd300_vgg16(pretrained=True)
         >>> model.eval()
-        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
+        >>> x = [torch.rand(3, 300, 300), torch.rand(3, 500, 400)]
         >>> predictions = model(x)

     Args:

@@ -544,19 +538,28 @@ def ssd300_vgg16(pretrained: bool = False, progress: bool = True, num_classes: i
         trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block.
             Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
     """
+    if "size" in kwargs:
+        warnings.warn("The size of the model is already fixed; ignoring the argument.")
+
     trainable_backbone_layers = _validate_trainable_layers(
         pretrained or pretrained_backbone, trainable_backbone_layers, 5, 5)

     if pretrained:
         # no need to download the backbone if pretrained is set
         pretrained_backbone = False

-    backbone = _vgg_extractor("vgg16_features", False, progress, pretrained_backbone, trainable_backbone_layers, True)
+    backbone = _vgg_extractor("vgg16_features", False, progress, pretrained_backbone, trainable_backbone_layers)
     anchor_generator = DefaultBoxGenerator([[2], [2, 3], [2, 3], [2, 3], [2], [2]],
                                            scales=[0.07, 0.15, 0.33, 0.51, 0.69, 0.87, 1.05],
                                            steps=[8, 16, 32, 64, 100, 300])
-    model = SSD(backbone, anchor_generator, (300, 300), num_classes,
-                image_mean=[0.48235, 0.45882, 0.40784], image_std=[1., 1., 1.], **kwargs)
+
+    defaults = {
+        # Rescale the input in a way compatible to the backbone
+        "image_mean": [0.48235, 0.45882, 0.40784],
+        "image_std": [1.0 / 255.0, 1.0 / 255.0, 1.0 / 255.0],  # undo the 0-1 scaling of toTensor
+    }
+    kwargs = {**defaults, **kwargs}
+    model = SSD(backbone, anchor_generator, (300, 300), num_classes, **kwargs)
     if pretrained:
         weights_name = 'ssd300_vgg16_coco'
         if model_urls.get(weights_name, None) is None:
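The deleted `rescaling` branch multiplied the already-normalized input by 255 inside the backbone; the new `image_std` default of 1/255 makes the detection transform produce the same tensor before it reaches the backbone. A small sketch of the arithmetic (not part of the commit), assuming the transform normalizes as `(x - image_mean) / image_std`:

```python
import torch

x = torch.rand(3, 300, 300)  # a toTensor-style input in [0, 1]
mean = torch.tensor([0.48235, 0.45882, 0.40784]).view(-1, 1, 1)

# Old behaviour: the transform subtracted the mean (std = 1), then the backbone ran `x *= 255`.
old = (x - mean) * 255

# New behaviour: the transform alone, with image_std = 1/255 and no rescaling in the backbone.
new = (x - mean) / (1.0 / 255.0)

assert torch.allclose(old, new)
```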

torchvision/models/detection/ssdlite.py

Lines changed: 15 additions & 15 deletions
@@ -1,4 +1,5 @@
 import torch
+import warnings

 from collections import OrderedDict
 from functools import partial

@@ -94,8 +95,7 @@ def __init__(self, in_channels: List[int], num_anchors: List[int], norm_layer: C


 class SSDLiteFeatureExtractorMobileNet(nn.Module):
-    def __init__(self, backbone: nn.Module, c4_pos: int, norm_layer: Callable[..., nn.Module], rescaling: bool,
-                 **kwargs: Any):
+    def __init__(self, backbone: nn.Module, c4_pos: int, norm_layer: Callable[..., nn.Module], **kwargs: Any):
         super().__init__()
         # non-public config parameters
         min_depth = kwargs.pop('_min_depth', 16)

@@ -117,13 +117,8 @@ def __init__(self, backbone: nn.Module, c4_pos: int, norm_layer: Callable[..., n
         _normal_init(extra)

         self.extra = extra
-        self.rescaling = rescaling

     def forward(self, x: Tensor) -> Dict[str, Tensor]:
-        # Rescale from [0, 1] to [-1, 1]
-        if self.rescaling:
-            x = 2.0 * x - 1.0
-
         # Get feature maps from backbone and extra. Can't be refactored due to JIT limitations.
         output = []
         for block in self.features:

@@ -138,7 +133,7 @@ def forward(self, x: Tensor) -> Dict[str, Tensor]:


 def _mobilenet_extractor(backbone_name: str, progress: bool, pretrained: bool, trainable_layers: int,
-                         norm_layer: Callable[..., nn.Module], rescaling: bool, **kwargs: Any):
+                         norm_layer: Callable[..., nn.Module], **kwargs: Any):
     backbone = mobilenet.__dict__[backbone_name](pretrained=pretrained, progress=progress,
                                                  norm_layer=norm_layer, **kwargs).features
     if not pretrained:

@@ -158,15 +153,15 @@ def _mobilenet_extractor(backbone_name: str, progress: bool, pretrained: bool, t
         for parameter in b.parameters():
             parameter.requires_grad_(False)

-    return SSDLiteFeatureExtractorMobileNet(backbone, stage_indices[-2], norm_layer, rescaling, **kwargs)
+    return SSDLiteFeatureExtractorMobileNet(backbone, stage_indices[-2], norm_layer, **kwargs)


 def ssdlite320_mobilenet_v3_large(pretrained: bool = False, progress: bool = True, num_classes: int = 91,
                                   pretrained_backbone: bool = False, trainable_backbone_layers: Optional[int] = None,
                                   norm_layer: Optional[Callable[..., nn.Module]] = None,
                                   **kwargs: Any):
     """
-    Constructs an SSDlite model with a MobileNetV3 Large backbone. See `SSD` for more details.
+    Constructs an SSDlite model with input size 320x320 and a MobileNetV3 Large backbone. See `SSD` for more details.

     Example:


@@ -186,20 +181,23 @@ def ssdlite320_mobilenet_v3_large(pretrained: bool = False, progress: bool = Tru
             Valid values are between 0 and 6, with 6 meaning all backbone layers are trainable.
         norm_layer (callable, optional): Module specifying the normalization layer to use.
     """
+    if "size" in kwargs:
+        warnings.warn("The size of the model is already fixed; ignoring the argument.")
+
     trainable_backbone_layers = _validate_trainable_layers(
         pretrained or pretrained_backbone, trainable_backbone_layers, 6, 6)

     if pretrained:
         pretrained_backbone = False

-    # Enable [-1, 1] rescaling and reduced tail if no pretrained backbone is selected
-    rescaling = reduce_tail = not pretrained_backbone
+    # Enable reduced tail if no pretrained backbone is selected
+    reduce_tail = not pretrained_backbone

     if norm_layer is None:
         norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.03)

     backbone = _mobilenet_extractor("mobilenet_v3_large", progress, pretrained_backbone, trainable_backbone_layers,
-                                    norm_layer, rescaling, _reduced_tail=reduce_tail, _width_mult=1.0)
+                                    norm_layer, _reduced_tail=reduce_tail, _width_mult=1.0)

     size = (320, 320)
     anchor_generator = DefaultBoxGenerator([[2, 3] for _ in range(6)], min_ratio=0.2, max_ratio=0.95)

@@ -212,8 +210,10 @@ def ssdlite320_mobilenet_v3_large(pretrained: bool = False, progress: bool = Tru
         "nms_thresh": 0.55,
         "detections_per_img": 300,
         "topk_candidates": 300,
-        "image_mean": [0., 0., 0.],
-        "image_std": [1., 1., 1.],
+        # Rescale the input in a way compatible to the backbone:
+        # The following mean/std rescale the data from [0, 1] to [-1, 1]
+        "image_mean": [0.5, 0.5, 0.5],
+        "image_std": [0.5, 0.5, 0.5],
     }
     kwargs = {**defaults, **kwargs}
     model = SSD(backbone, anchor_generator, size, num_classes,
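Same idea for SSDlite: the removed `x = 2.0 * x - 1.0` branch is replaced by `image_mean = image_std = 0.5`, which maps a `[0, 1]` input to `[-1, 1]` in the transform instead of in the backbone. A quick check (not part of the commit), again assuming the transform computes `(x - image_mean) / image_std`:

```python
import torch

x = torch.rand(3, 320, 320)   # a toTensor-style input in [0, 1]

old = 2.0 * x - 1.0           # previous in-backbone rescaling to [-1, 1]
new = (x - 0.5) / 0.5         # new transform defaults: image_mean = image_std = 0.5

assert torch.allclose(old, new)
```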
