Skip to content

Commit 4ca472e

Browse files
committed
Adding the best config along with its weights and documentation.
1 parent d4024cb commit 4ca472e

File tree

4 files changed

+43
-5
lines changed

4 files changed

+43
-5
lines changed

docs/source/models.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -427,6 +427,7 @@ Faster R-CNN MobileNetV3-Large FPN 32.8 - -
427427
Faster R-CNN MobileNetV3-Large 320 FPN 22.8 - -
428428
RetinaNet ResNet-50 FPN 36.4 - -
429429
SSD VGG16 25.1 - -
430+
SSDlite MobileNetV3-Large 21.3 - -
430431
Mask R-CNN ResNet-50 FPN 37.9 34.6 -
431432
====================================== ======= ======== ===========
432433

@@ -486,6 +487,7 @@ Faster R-CNN MobileNetV3-Large FPN 0.1020 0.0415
486487
Faster R-CNN MobileNetV3-Large 320 FPN 0.0978 0.0376 0.6
487488
RetinaNet ResNet-50 FPN 0.2514 0.0939 4.1
488489
SSD VGG16 0.2093 0.0744 1.5
490+
SSDlite MobileNetV3-Large 0.1773 0.0906 1.5
489491
Mask R-CNN ResNet-50 FPN 0.2728 0.0903 5.4
490492
Keypoint R-CNN ResNet-50 FPN 0.3789 0.1242 6.8
491493
====================================== =================== ================== ===========
@@ -511,6 +513,12 @@ SSD
511513
.. autofunction:: torchvision.models.detection.ssd300_vgg16
512514

513515

516+
SSDlite
517+
------------
518+
519+
.. autofunction:: torchvision.models.detection.ssdlite320_mobilenet_v3_large
520+
521+
514522
Mask R-CNN
515523
----------
516524

references/detection/README.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,14 @@ python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\
5656
--weight-decay 0.0005 --data-augmentation ssd
5757
```
5858

59+
### SSDlite MobileNetV3-Large
60+
```
61+
python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\
62+
--dataset coco --model ssdlite320_mobilenet_v3_large --epochs 660\
63+
--aspect-ratio-group-factor 3 --lr-scheduler cosineannealinglr --lr 0.15 --batch-size 24\
64+
--weight-decay 0.00004 --data-augmentation ssdlite
65+
```
66+
5967

6068
### Mask R-CNN
6169
```
Binary file not shown.

torchvision/models/detection/ssdlite.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@
1717
__all__ = ['ssdlite320_mobilenet_v3_large']
1818

1919
model_urls = {
20-
'ssd320_mobilenet_v3_large_coco': None # TODO: add weights
20+
'ssdlite320_mobilenet_v3_large_coco':
21+
'https://download.pytorch.org/models/ssdlite320_mobilenet_v3_large_coco-a79551df.pth'
2122
}
2223

2324

@@ -164,6 +165,27 @@ def ssdlite320_mobilenet_v3_large(pretrained: bool = False, progress: bool = Tru
164165
pretrained_backbone: bool = False, trainable_backbone_layers: Optional[int] = None,
165166
norm_layer: Optional[Callable[..., nn.Module]] = None,
166167
**kwargs: Any):
168+
"""
169+
Constructs an SSDlite model with a MobileNetV3 Large backbone. See `SSD` for more details.
170+
171+
Example:
172+
173+
>>> model = torchvision.models.detection.ssdlite320_mobilenet_v3_large(pretrained=True)
174+
>>> model.eval()
175+
>>> x = [torch.rand(3, 320, 320), torch.rand(3, 500, 400)]
176+
>>> predictions = model(x)
177+
178+
Args:
179+
norm_layer:
180+
**kwargs:
181+
pretrained (bool): If True, returns a model pre-trained on COCO train2017
182+
progress (bool): If True, displays a progress bar of the download to stderr
183+
num_classes (int): number of output classes of the model (including the background)
184+
pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet
185+
trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block.
186+
Valid values are between 0 and 6, with 6 meaning all backbone layers are trainable.
187+
norm_layer (callable, optional): Module specifying the normalization layer to use.
188+
"""
167189
trainable_backbone_layers = _validate_trainable_layers(
168190
pretrained or pretrained_backbone, trainable_backbone_layers, 6, 6)
169191

@@ -186,10 +208,10 @@ def ssdlite320_mobilenet_v3_large(pretrained: bool = False, progress: bool = Tru
186208
assert len(out_channels) == len(anchor_generator.aspect_ratios)
187209

188210
defaults = {
189-
"score_thresh": 1e-8,
190-
"nms_thresh": 0.6,
191-
"detections_per_img": 100,
192-
"topk_candidates": 100,
211+
"score_thresh": 0.001,
212+
"nms_thresh": 0.55,
213+
"detections_per_img": 300,
214+
"topk_candidates": 300,
193215
}
194216
kwargs = {**defaults, **kwargs}
195217
model = SSD(backbone, anchor_generator, size, num_classes,

0 commit comments

Comments
 (0)