diff --git a/docs/source/models.rst b/docs/source/models.rst
index c70cd07979f..f9fb793ed36 100644
--- a/docs/source/models.rst
+++ b/docs/source/models.rst
@@ -427,6 +427,7 @@ Faster R-CNN MobileNetV3-Large FPN     32.8    -        -
 Faster R-CNN MobileNetV3-Large 320 FPN 22.8    -        -
 RetinaNet ResNet-50 FPN                36.4    -        -
 SSD VGG16                              25.1    -        -
+SSDlite MobileNetV3-Large              21.3    -        -
 Mask R-CNN ResNet-50 FPN               37.9    34.6     -
 ====================================== ======= ======== ===========
@@ -486,6 +487,7 @@ Faster R-CNN MobileNetV3-Large FPN     0.1020              0.0415
 Faster R-CNN MobileNetV3-Large 320 FPN 0.0978              0.0376             0.6
 RetinaNet ResNet-50 FPN                0.2514              0.0939             4.1
 SSD VGG16                              0.2093              0.0744             1.5
+SSDlite MobileNetV3-Large              0.1773              0.0906             1.5
 Mask R-CNN ResNet-50 FPN               0.2728              0.0903             5.4
 Keypoint R-CNN ResNet-50 FPN           0.3789              0.1242             6.8
 ====================================== =================== ================== ===========
@@ -511,6 +513,12 @@ SSD
 
 .. autofunction:: torchvision.models.detection.ssd300_vgg16
 
+SSDlite
+------------
+
+.. autofunction:: torchvision.models.detection.ssdlite320_mobilenet_v3_large
+
+
 Mask R-CNN
 ----------
 
diff --git a/references/detection/README.md b/references/detection/README.md
index e4d52869d35..2fb0b658aa7 100644
--- a/references/detection/README.md
+++ b/references/detection/README.md
@@ -56,6 +56,14 @@ python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\
     --weight-decay 0.0005 --data-augmentation ssd
 ```
 
+### SSDlite MobileNetV3-Large
+```
+python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\
+    --dataset coco --model ssdlite320_mobilenet_v3_large --epochs 660\
+    --aspect-ratio-group-factor 3 --lr-scheduler cosineannealinglr --lr 0.15 --batch-size 24\
+    --weight-decay 0.00004 --data-augmentation ssdlite
+```
+
 ### Mask R-CNN
 ```
diff --git a/references/detection/presets.py b/references/detection/presets.py
index 22937cf9576..1fac69ae356 100644
--- a/references/detection/presets.py
+++ b/references/detection/presets.py
@@ -16,6 +16,12 @@ def __init__(self, data_augmentation, hflip_prob=0.5, mean=(123., 117., 104.)):
                 T.RandomHorizontalFlip(p=hflip_prob),
                 T.ToTensor(),
             ])
+        elif data_augmentation == 'ssdlite':
+            self.transforms = T.Compose([
+                T.RandomIoUCrop(),
+                T.RandomHorizontalFlip(p=hflip_prob),
+                T.ToTensor(),
+            ])
         else:
             raise ValueError(f'Unknown data augmentation policy "{data_augmentation}"')
diff --git a/references/detection/train.py b/references/detection/train.py
index 4eb39bf17f5..cd4148e9bf7 100644
--- a/references/detection/train.py
+++ b/references/detection/train.py
@@ -73,9 +73,13 @@ def get_args_parser(add_help=True):
     parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
                         metavar='W', help='weight decay (default: 1e-4)',
                         dest='weight_decay')
-    parser.add_argument('--lr-step-size', default=8, type=int, help='decrease lr every step-size epochs')
-    parser.add_argument('--lr-steps', default=[16, 22], nargs='+', type=int, help='decrease lr every step-size epochs')
-    parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma')
+    parser.add_argument('--lr-scheduler', default="multisteplr", help='the lr scheduler (default: multisteplr)')
+    parser.add_argument('--lr-step-size', default=8, type=int,
+                        help='decrease lr every step-size epochs (multisteplr scheduler only)')
+    parser.add_argument('--lr-steps', default=[16, 22], nargs='+', type=int,
+                        help='epochs at which to decrease lr (multisteplr scheduler only)')
+    parser.add_argument('--lr-gamma', default=0.1, type=float,
+                        help='decrease lr by a factor of lr-gamma (multisteplr scheduler only)')
     parser.add_argument('--print-freq', default=20, type=int, help='print frequency')
     parser.add_argument('--output-dir', default='.', help='path where to save')
     parser.add_argument('--resume', default='', help='resume from checkpoint')
@@ -85,6 +89,12 @@ def get_args_parser(add_help=True):
     parser.add_argument('--trainable-backbone-layers', default=None, type=int,
                         help='number of trainable layers of backbone')
     parser.add_argument('--data-augmentation', default="hflip", help='data augmentation policy (default: hflip)')
+    parser.add_argument(
+        "--sync-bn",
+        dest="sync_bn",
+        help="Use sync batch norm",
+        action="store_true",
+    )
     parser.add_argument(
         "--test-only",
         dest="test_only",
@@ -156,6 +166,8 @@ def main(args):
     model = torchvision.models.detection.__dict__[args.model](num_classes=num_classes, pretrained=args.pretrained,
                                                               **kwargs)
     model.to(device)
+    if args.distributed and args.sync_bn:
+        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
 
     model_without_ddp = model
     if args.distributed:
@@ -166,8 +178,14 @@ def main(args):
     optimizer = torch.optim.SGD(
         params, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
 
-    # lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
-    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_steps, gamma=args.lr_gamma)
+    args.lr_scheduler = args.lr_scheduler.lower()
+    if args.lr_scheduler == 'multisteplr':
+        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_steps, gamma=args.lr_gamma)
+    elif args.lr_scheduler == 'cosineannealinglr':
+        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
+    else:
+        raise RuntimeError("Invalid lr scheduler '{}'. Only MultiStepLR and CosineAnnealingLR "
+                           "are supported.".format(args.lr_scheduler))
 
     if args.resume:
         checkpoint = torch.load(args.resume, map_location='cpu')
diff --git a/test/expect/ModelTester.test_ssdlite320_mobilenet_v3_large_expect.pkl b/test/expect/ModelTester.test_ssdlite320_mobilenet_v3_large_expect.pkl
new file mode 100644
index 00000000000..f314346af01
Binary files /dev/null and b/test/expect/ModelTester.test_ssdlite320_mobilenet_v3_large_expect.pkl differ
diff --git a/test/test_models.py b/test/test_models.py
index 157288a2c32..401c4175ccf 100644
--- a/test/test_models.py
+++ b/test/test_models.py
@@ -46,6 +46,7 @@ def get_available_video_models():
     "keypointrcnn_resnet50_fpn": lambda x: x[1],
     "retinanet_resnet50_fpn": lambda x: x[1],
     "ssd300_vgg16": lambda x: x[1],
+    "ssdlite320_mobilenet_v3_large": lambda x: x[1],
 }
diff --git a/torchvision/models/detection/__init__.py b/torchvision/models/detection/__init__.py
index 34e2b6d8dcf..4772415b3b1 100644
--- a/torchvision/models/detection/__init__.py
+++ b/torchvision/models/detection/__init__.py
@@ -3,3 +3,4 @@
 from .keypoint_rcnn import *
 from .retinanet import *
 from .ssd import *
+from .ssdlite import *
diff --git a/torchvision/models/detection/anchor_utils.py b/torchvision/models/detection/anchor_utils.py
index 3e0740036c8..06ecc551442 100644
--- a/torchvision/models/detection/anchor_utils.py
+++ b/torchvision/models/detection/anchor_utils.py
@@ -206,8 +206,8 @@ def _grid_default_boxes(self, grid_sizes: List[List[int]], image_size: List[int]
         else:
             y_f_k, x_f_k = f_k
 
-        shifts_x = (torch.arange(0, f_k[1], dtype=dtype) + 0.5) / x_f_k
-        shifts_y = (torch.arange(0, f_k[0], dtype=dtype) + 0.5) / y_f_k
+        shifts_x = ((torch.arange(0, f_k[1]) + 0.5) / x_f_k).to(dtype=dtype)
+        shifts_y = ((torch.arange(0, f_k[0]) + 0.5) / y_f_k).to(dtype=dtype)
         shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
         shift_x = shift_x.reshape(-1)
         shift_y = shift_y.reshape(-1)
diff --git a/torchvision/models/detection/ssdlite.py b/torchvision/models/detection/ssdlite.py
new file mode 100644
index 00000000000..412434dabd7
--- /dev/null
+++ b/torchvision/models/detection/ssdlite.py
@@ -0,0 +1,228 @@
+import torch
+
+from collections import OrderedDict
+from functools import partial
+from torch import nn, Tensor
+from typing import Any, Callable, Dict, List, Optional, Tuple
+
+from . import _utils as det_utils
+from .ssd import SSD, SSDScoringHead
+from .anchor_utils import DefaultBoxGenerator
+from .backbone_utils import _validate_trainable_layers
+from .. import mobilenet
+from ..mobilenetv3 import ConvBNActivation
+from ..utils import load_state_dict_from_url
+
+
+__all__ = ['ssdlite320_mobilenet_v3_large']
+
+model_urls = {
+    'ssdlite320_mobilenet_v3_large_coco':
+        'https://download.pytorch.org/models/ssdlite320_mobilenet_v3_large_coco-a79551df.pth'
+}
+
+
+def _prediction_block(in_channels: int, out_channels: int, kernel_size: int,
+                      norm_layer: Callable[..., nn.Module]) -> nn.Sequential:
+    return nn.Sequential(
+        # 3x3 depthwise with stride 1 and padding 1
+        ConvBNActivation(in_channels, in_channels, kernel_size=kernel_size, groups=in_channels,
+                         norm_layer=norm_layer, activation_layer=nn.ReLU6),
+
+        # 1x1 projection to output channels
+        nn.Conv2d(in_channels, out_channels, 1)
+    )
+
+
+def _extra_block(in_channels: int, out_channels: int, norm_layer: Callable[..., nn.Module]) -> nn.Sequential:
+    activation = nn.ReLU6
+    intermediate_channels = out_channels // 2
+    return nn.Sequential(
+        # 1x1 projection to half output channels
+        ConvBNActivation(in_channels, intermediate_channels, kernel_size=1,
+                         norm_layer=norm_layer, activation_layer=activation),
+
+        # 3x3 depthwise with stride 2 and padding 1
+        ConvBNActivation(intermediate_channels, intermediate_channels, kernel_size=3, stride=2,
+                         groups=intermediate_channels, norm_layer=norm_layer, activation_layer=activation),
+
+        # 1x1 projection to output channels
+        ConvBNActivation(intermediate_channels, out_channels, kernel_size=1,
+                         norm_layer=norm_layer, activation_layer=activation),
+    )
+
+
+def _normal_init(conv: nn.Module):
+    for layer in conv.modules():
+        if isinstance(layer, nn.Conv2d):
+            torch.nn.init.normal_(layer.weight, mean=0.0, std=0.03)
+            if layer.bias is not None:
+                torch.nn.init.constant_(layer.bias, 0.0)
+
+
+class SSDLiteHead(nn.Module):
+    def __init__(self, in_channels: List[int], num_anchors: List[int], num_classes: int,
+                 norm_layer: Callable[..., nn.Module]):
+        super().__init__()
+        self.classification_head = SSDLiteClassificationHead(in_channels, num_anchors, num_classes, norm_layer)
+        self.regression_head = SSDLiteRegressionHead(in_channels, num_anchors, norm_layer)
+
+    def forward(self, x: List[Tensor]) -> Dict[str, Tensor]:
+        return {
+            'bbox_regression': self.regression_head(x),
+            'cls_logits': self.classification_head(x),
+        }
+
+
+class SSDLiteClassificationHead(SSDScoringHead):
+    def __init__(self, in_channels: List[int], num_anchors: List[int], num_classes: int,
+                 norm_layer: Callable[..., nn.Module]):
+        cls_logits = nn.ModuleList()
+        for channels, anchors in zip(in_channels, num_anchors):
+            cls_logits.append(_prediction_block(channels, num_classes * anchors, 3, norm_layer))
+        _normal_init(cls_logits)
+        super().__init__(cls_logits, num_classes)
+
+
+class SSDLiteRegressionHead(SSDScoringHead):
+    def __init__(self, in_channels: List[int], num_anchors: List[int], norm_layer: Callable[..., nn.Module]):
+        bbox_reg = nn.ModuleList()
+        for channels, anchors in zip(in_channels, num_anchors):
+            bbox_reg.append(_prediction_block(channels, 4 * anchors, 3, norm_layer))
+        _normal_init(bbox_reg)
+        super().__init__(bbox_reg, 4)
+
+
+class SSDLiteFeatureExtractorMobileNet(nn.Module):
+    def __init__(self, backbone: nn.Module, c4_pos: int, norm_layer: Callable[..., nn.Module], rescaling: bool,
+                 **kwargs: Any):
+        super().__init__()
+        # non-public config parameters
+        min_depth = kwargs.pop('_min_depth', 16)
+        width_mult = kwargs.pop('_width_mult', 1.0)
+
+        assert not backbone[c4_pos].use_res_connect
+        self.features = nn.Sequential(
+            nn.Sequential(*backbone[:c4_pos], backbone[c4_pos].block[0]),  # from start until C4 expansion layer
+            nn.Sequential(backbone[c4_pos].block[1:], *backbone[c4_pos + 1:]),  # from C4 depthwise until end
+        )
+
+        get_depth = lambda d: max(min_depth, int(d * width_mult))  # noqa: E731
+        extra = nn.ModuleList([
+            _extra_block(backbone[-1].out_channels, get_depth(512), norm_layer),
+            _extra_block(get_depth(512), get_depth(256), norm_layer),
+            _extra_block(get_depth(256), get_depth(256), norm_layer),
+            _extra_block(get_depth(256), get_depth(128), norm_layer),
+        ])
+        _normal_init(extra)
+
+        self.extra = extra
+        self.rescaling = rescaling
+
+    def forward(self, x: Tensor) -> Dict[str, Tensor]:
+        # Rescale from [0, 1] to [-1, 1]
+        if self.rescaling:
+            x = 2.0 * x - 1.0
+
+        # Get feature maps from backbone and extra. Can't be refactored due to JIT limitations.
+        output = []
+        for block in self.features:
+            x = block(x)
+            output.append(x)
+
+        for block in self.extra:
+            x = block(x)
+            output.append(x)
+
+        return OrderedDict([(str(i), v) for i, v in enumerate(output)])
+
+
+def _mobilenet_extractor(backbone_name: str, progress: bool, pretrained: bool, trainable_layers: int,
+                         norm_layer: Callable[..., nn.Module], rescaling: bool, **kwargs: Any):
+    backbone = mobilenet.__dict__[backbone_name](pretrained=pretrained, progress=progress,
+                                                 norm_layer=norm_layer, **kwargs).features
+    if not pretrained:
+        # Change the default initialization scheme if not pretrained
+        _normal_init(backbone)
+
+    # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
+    # The first and last blocks are always included because they are the C0 (conv1) and Cn.
+    stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + [len(backbone) - 1]
+    num_stages = len(stage_indices)
+
+    # find the index of the layer from which we won't freeze
+    assert 0 <= trainable_layers <= num_stages
+    freeze_before = num_stages if trainable_layers == 0 else stage_indices[num_stages - trainable_layers]
+
+    for b in backbone[:freeze_before]:
+        for parameter in b.parameters():
+            parameter.requires_grad_(False)
+
+    return SSDLiteFeatureExtractorMobileNet(backbone, stage_indices[-2], norm_layer, rescaling, **kwargs)
+
+
+def ssdlite320_mobilenet_v3_large(pretrained: bool = False, progress: bool = True, num_classes: int = 91,
+                                  pretrained_backbone: bool = False, trainable_backbone_layers: Optional[int] = None,
+                                  norm_layer: Optional[Callable[..., nn.Module]] = None,
+                                  **kwargs: Any):
+    """
+    Constructs an SSDlite model with a MobileNetV3 Large backbone. See `SSD` for more details.
+
+    Example:
+
+        >>> model = torchvision.models.detection.ssdlite320_mobilenet_v3_large(pretrained=True)
+        >>> model.eval()
+        >>> x = [torch.rand(3, 320, 320), torch.rand(3, 500, 400)]
+        >>> predictions = model(x)
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on COCO train2017
+        progress (bool): If True, displays a progress bar of the download to stderr
+        num_classes (int): number of output classes of the model (including the background)
+        pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet
+        trainable_backbone_layers (int): number of trainable (not frozen) backbone layers starting from final block.
+            Valid values are between 0 and 6, with 6 meaning all backbone layers are trainable.
+        norm_layer (callable, optional): Module specifying the normalization layer to use.
+    """
+    trainable_backbone_layers = _validate_trainable_layers(
+        pretrained or pretrained_backbone, trainable_backbone_layers, 6, 6)
+
+    if pretrained:
+        pretrained_backbone = False
+
+    # Enable [-1, 1] rescaling and reduced tail if no pretrained backbone is selected
+    rescaling = reduce_tail = not pretrained_backbone
+
+    if norm_layer is None:
+        norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.03)
+
+    backbone = _mobilenet_extractor("mobilenet_v3_large", progress, pretrained_backbone, trainable_backbone_layers,
+                                    norm_layer, rescaling, _reduced_tail=reduce_tail, _width_mult=1.0)
+
+    size = (320, 320)
+    anchor_generator = DefaultBoxGenerator([[2, 3] for _ in range(6)], min_ratio=0.2, max_ratio=0.95)
+    out_channels = det_utils.retrieve_out_channels(backbone, size)
+    num_anchors = anchor_generator.num_anchors_per_location()
+    assert len(out_channels) == len(anchor_generator.aspect_ratios)
+
+    defaults = {
+        "score_thresh": 0.001,
+        "nms_thresh": 0.55,
+        "detections_per_img": 300,
+        "topk_candidates": 300,
+        "image_mean": [0., 0., 0.],
+        "image_std": [1., 1., 1.],
+    }
+    kwargs = {**defaults, **kwargs}
+    model = SSD(backbone, anchor_generator, size, num_classes,
+                head=SSDLiteHead(out_channels, num_anchors, num_classes, norm_layer), **kwargs)
+
+    if pretrained:
+        weights_name = 'ssdlite320_mobilenet_v3_large_coco'
+        if model_urls.get(weights_name, None) is None:
+            raise ValueError("No checkpoint is available for model {}".format(weights_name))
+        state_dict = load_state_dict_from_url(model_urls[weights_name], progress=progress)
+        model.load_state_dict(state_dict)
+    return model