diff --git a/docs/source/models.rst b/docs/source/models.rst index 007fd5ea229..66ebf0e211d 100644 --- a/docs/source/models.rst +++ b/docs/source/models.rst @@ -350,6 +350,7 @@ the instances set of COCO train2017 and evaluated on COCO val2017. Network box AP mask AP keypoint AP ================================ ======= ======== =========== Faster R-CNN ResNet-50 FPN 37.0 - - +RetinaNet ResNet-50 FPN 36.4 - - Mask R-CNN ResNet-50 FPN 37.9 34.6 - ================================ ======= ======== =========== @@ -405,6 +406,7 @@ precision-recall. Network train time (s / it) test time (s / it) memory (GB) ============================== =================== ================== =========== Faster R-CNN ResNet-50 FPN 0.2288 0.0590 5.2 +RetinaNet ResNet-50 FPN 0.2514 0.0939 4.1 Mask R-CNN ResNet-50 FPN 0.2728 0.0903 5.4 Keypoint R-CNN ResNet-50 FPN 0.3789 0.1242 6.8 ============================== =================== ================== =========== @@ -416,6 +418,12 @@ Faster R-CNN .. autofunction:: torchvision.models.detection.fasterrcnn_resnet50_fpn +RetinaNet +------------ + +.. autofunction:: torchvision.models.detection.retinanet_resnet50_fpn + + Mask R-CNN ---------- diff --git a/references/detection/README.md b/references/detection/README.md index 280d2b26f43..f89e8149a71 100644 --- a/references/detection/README.md +++ b/references/detection/README.md @@ -27,6 +27,13 @@ python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\ --lr-steps 16 22 --aspect-ratio-group-factor 3 ``` +### RetinaNet +``` +python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\ + --dataset coco --model retinanet_resnet50_fpn --epochs 26\ + --lr-steps 16 22 --aspect-ratio-group-factor 3 --lr 0.01 +``` + ### Mask R-CNN ``` diff --git a/test/expect/ModelTester.test_retinanet_resnet50_fpn_expect.pkl b/test/expect/ModelTester.test_retinanet_resnet50_fpn_expect.pkl new file mode 100644 index 00000000000..aed2976cb6c Binary files /dev/null and b/test/expect/ModelTester.test_retinanet_resnet50_fpn_expect.pkl differ diff --git a/torchvision/models/detection/__init__.py b/torchvision/models/detection/__init__.py index 3b2683ae328..78ddc61c144 100644 --- a/torchvision/models/detection/__init__.py +++ b/torchvision/models/detection/__init__.py @@ -1,3 +1,4 @@ from .faster_rcnn import * from .mask_rcnn import * from .keypoint_rcnn import * +from .retinanet import * diff --git a/torchvision/models/detection/anchor_utils.py b/torchvision/models/detection/anchor_utils.py new file mode 100644 index 00000000000..94bb006c5dd --- /dev/null +++ b/torchvision/models/detection/anchor_utils.py @@ -0,0 +1,159 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +import torch +from torch import nn + +from torch.jit.annotations import List, Optional, Dict +from .image_list import ImageList + + +class AnchorGenerator(nn.Module): + """ + Module that generates anchors for a set of feature maps and + image sizes. + + The module support computing anchors at multiple sizes and aspect ratios + per feature map. This module assumes aspect ratio = height / width for + each anchor. + + sizes and aspect_ratios should have the same number of elements, and it should + correspond to the number of feature maps. + + sizes[i] and aspect_ratios[i] can have an arbitrary number of elements, + and AnchorGenerator will output a set of sizes[i] * aspect_ratios[i] anchors + per spatial location for feature map i. 
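+
+    As an illustrative sketch (the values below are just the constructor defaults,
+    shown for demonstration), a generator producing 3 sizes x 3 aspect ratios = 9
+    anchors per spatial location on a single feature map could be built as:
+
+        >>> from torchvision.models.detection.anchor_utils import AnchorGenerator
+        >>> anchor_generator = AnchorGenerator(
+        >>>     sizes=((128, 256, 512),),
+        >>>     aspect_ratios=((0.5, 1.0, 2.0),)
+        >>> )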
+ + Arguments: + sizes (Tuple[Tuple[int]]): + aspect_ratios (Tuple[Tuple[float]]): + """ + + __annotations__ = { + "cell_anchors": Optional[List[torch.Tensor]], + "_cache": Dict[str, List[torch.Tensor]] + } + + def __init__( + self, + sizes=((128, 256, 512),), + aspect_ratios=((0.5, 1.0, 2.0),), + ): + super(AnchorGenerator, self).__init__() + + if not isinstance(sizes[0], (list, tuple)): + # TODO change this + sizes = tuple((s,) for s in sizes) + if not isinstance(aspect_ratios[0], (list, tuple)): + aspect_ratios = (aspect_ratios,) * len(sizes) + + assert len(sizes) == len(aspect_ratios) + + self.sizes = sizes + self.aspect_ratios = aspect_ratios + self.cell_anchors = None + self._cache = {} + + # TODO: https://github.com/pytorch/pytorch/issues/26792 + # For every (aspect_ratios, scales) combination, output a zero-centered anchor with those values. + # (scales, aspect_ratios) are usually an element of zip(self.scales, self.aspect_ratios) + # This method assumes aspect ratio = height / width for an anchor. + def generate_anchors(self, scales, aspect_ratios, dtype=torch.float32, device="cpu"): + # type: (List[int], List[float], int, Device) -> Tensor # noqa: F821 + scales = torch.as_tensor(scales, dtype=dtype, device=device) + aspect_ratios = torch.as_tensor(aspect_ratios, dtype=dtype, device=device) + h_ratios = torch.sqrt(aspect_ratios) + w_ratios = 1 / h_ratios + + ws = (w_ratios[:, None] * scales[None, :]).view(-1) + hs = (h_ratios[:, None] * scales[None, :]).view(-1) + + base_anchors = torch.stack([-ws, -hs, ws, hs], dim=1) / 2 + return base_anchors.round() + + def set_cell_anchors(self, dtype, device): + # type: (int, Device) -> None # noqa: F821 + if self.cell_anchors is not None: + cell_anchors = self.cell_anchors + assert cell_anchors is not None + # suppose that all anchors have the same device + # which is a valid assumption in the current state of the codebase + if cell_anchors[0].device == device: + return + + cell_anchors = [ + self.generate_anchors( + sizes, + aspect_ratios, + dtype, + device + ) + for sizes, aspect_ratios in zip(self.sizes, self.aspect_ratios) + ] + self.cell_anchors = cell_anchors + + def num_anchors_per_location(self): + return [len(s) * len(a) for s, a in zip(self.sizes, self.aspect_ratios)] + + # For every combination of (a, (g, s), i) in (self.cell_anchors, zip(grid_sizes, strides), 0:2), + # output g[i] anchors that are s[i] distance apart in direction i, with the same dimensions as a. + def grid_anchors(self, grid_sizes, strides): + # type: (List[List[int]], List[List[Tensor]]) -> List[Tensor] + anchors = [] + cell_anchors = self.cell_anchors + assert cell_anchors is not None + assert len(grid_sizes) == len(strides) == len(cell_anchors) + + for size, stride, base_anchors in zip( + grid_sizes, strides, cell_anchors + ): + grid_height, grid_width = size + stride_height, stride_width = stride + device = base_anchors.device + + # For output anchor, compute [x_center, y_center, x_center, y_center] + shifts_x = torch.arange( + 0, grid_width, dtype=torch.float32, device=device + ) * stride_width + shifts_y = torch.arange( + 0, grid_height, dtype=torch.float32, device=device + ) * stride_height + shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x) + shift_x = shift_x.reshape(-1) + shift_y = shift_y.reshape(-1) + shifts = torch.stack((shift_x, shift_y, shift_x, shift_y), dim=1) + + # For every (base anchor, output anchor) pair, + # offset each zero-centered base anchor by the center of the output anchor. 
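+            # shifts has shape (grid_height * grid_width, 4) and base_anchors has
+            # shape (A, 4); the broadcast sum below yields a tensor of shape
+            # (grid_height * grid_width, A, 4), flattened to (grid_height * grid_width * A, 4).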
+ anchors.append( + (shifts.view(-1, 1, 4) + base_anchors.view(1, -1, 4)).reshape(-1, 4) + ) + + return anchors + + def cached_grid_anchors(self, grid_sizes, strides): + # type: (List[List[int]], List[List[Tensor]]) -> List[Tensor] + key = str(grid_sizes) + str(strides) + if key in self._cache: + return self._cache[key] + anchors = self.grid_anchors(grid_sizes, strides) + self._cache[key] = anchors + return anchors + + def forward(self, image_list, feature_maps): + # type: (ImageList, List[Tensor]) -> List[Tensor] + grid_sizes = list([feature_map.shape[-2:] for feature_map in feature_maps]) + image_size = image_list.tensors.shape[-2:] + dtype, device = feature_maps[0].dtype, feature_maps[0].device + strides = [[torch.tensor(image_size[0] // g[0], dtype=torch.int64, device=device), + torch.tensor(image_size[1] // g[1], dtype=torch.int64, device=device)] for g in grid_sizes] + self.set_cell_anchors(dtype, device) + anchors_over_all_feature_maps = self.cached_grid_anchors(grid_sizes, strides) + anchors = torch.jit.annotate(List[List[torch.Tensor]], []) + for i, (image_height, image_width) in enumerate(image_list.image_sizes): + anchors_in_image = [] + for anchors_per_feature_map in anchors_over_all_feature_maps: + anchors_in_image.append(anchors_per_feature_map) + anchors.append(anchors_in_image) + anchors = [torch.cat(anchors_per_image) for anchors_per_image in anchors] + # Clear the cache in case that memory leaks. + self._cache.clear() + return anchors diff --git a/torchvision/models/detection/backbone_utils.py b/torchvision/models/detection/backbone_utils.py index 8d5d56404e8..c0527e544b3 100644 --- a/torchvision/models/detection/backbone_utils.py +++ b/torchvision/models/detection/backbone_utils.py @@ -25,13 +25,17 @@ class BackboneWithFPN(nn.Module): Attributes: out_channels (int): the number of channels in the FPN """ - def __init__(self, backbone, return_layers, in_channels_list, out_channels): + def __init__(self, backbone, return_layers, in_channels_list, out_channels, extra_blocks=None): super(BackboneWithFPN, self).__init__() + + if extra_blocks is None: + extra_blocks = LastLevelMaxPool() + self.body = IntermediateLayerGetter(backbone, return_layers=return_layers) self.fpn = FeaturePyramidNetwork( in_channels_list=in_channels_list, out_channels=out_channels, - extra_blocks=LastLevelMaxPool(), + extra_blocks=extra_blocks, ) self.out_channels = out_channels @@ -41,7 +45,14 @@ def forward(self, x): return x -def resnet_fpn_backbone(backbone_name, pretrained, norm_layer=misc_nn_ops.FrozenBatchNorm2d, trainable_layers=3): +def resnet_fpn_backbone( + backbone_name, + pretrained, + norm_layer=misc_nn_ops.FrozenBatchNorm2d, + trainable_layers=3, + returned_layers=None, + extra_blocks=None +): """ Constructs a specified ResNet backbone with FPN on top. Freezes the specified number of layers in the backbone. 
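+
+    As an illustrative sketch of the new ``returned_layers`` and ``extra_blocks``
+    arguments (mirroring the call made by the RetinaNet builder in this change;
+    the specific values are only for demonstration):
+
+        >>> from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
+        >>> from torchvision.ops.feature_pyramid_network import LastLevelP6P7
+        >>> backbone = resnet_fpn_backbone('resnet50', pretrained=False,
+        >>>                                returned_layers=[2, 3, 4],
+        >>>                                extra_blocks=LastLevelP6P7(256, 256))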
@@ -82,14 +93,15 @@ def resnet_fpn_backbone(backbone_name, pretrained, norm_layer=misc_nn_ops.Frozen if all([not name.startswith(layer) for layer in layers_to_train]): parameter.requires_grad_(False) - return_layers = {'layer1': '0', 'layer2': '1', 'layer3': '2', 'layer4': '3'} + if extra_blocks is None: + extra_blocks = LastLevelMaxPool() + + if returned_layers is None: + returned_layers = [1, 2, 3, 4] + assert min(returned_layers) > 0 and max(returned_layers) < 5 + return_layers = {f'layer{k}': str(v) for v, k in enumerate(returned_layers)} in_channels_stage2 = backbone.inplanes // 8 - in_channels_list = [ - in_channels_stage2, - in_channels_stage2 * 2, - in_channels_stage2 * 4, - in_channels_stage2 * 8, - ] + in_channels_list = [in_channels_stage2 * 2 ** (i - 1) for i in returned_layers] out_channels = 256 - return BackboneWithFPN(backbone, return_layers, in_channels_list, out_channels) + return BackboneWithFPN(backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks) diff --git a/torchvision/models/detection/faster_rcnn.py b/torchvision/models/detection/faster_rcnn.py index c7e6c6d12db..117d985eca6 100644 --- a/torchvision/models/detection/faster_rcnn.py +++ b/torchvision/models/detection/faster_rcnn.py @@ -9,8 +9,9 @@ from ..utils import load_state_dict_from_url +from .anchor_utils import AnchorGenerator from .generalized_rcnn import GeneralizedRCNN -from .rpn import AnchorGenerator, RPNHead, RegionProposalNetwork +from .rpn import RPNHead, RegionProposalNetwork from .roi_heads import RoIHeads from .transform import GeneralizedRCNNTransform from .backbone_utils import resnet_fpn_backbone diff --git a/torchvision/models/detection/keypoint_rcnn.py b/torchvision/models/detection/keypoint_rcnn.py index 438ef225c91..f1f4ad26680 100644 --- a/torchvision/models/detection/keypoint_rcnn.py +++ b/torchvision/models/detection/keypoint_rcnn.py @@ -103,7 +103,7 @@ class KeypointRCNN(FasterRCNN): >>> import torch >>> import torchvision >>> from torchvision.models.detection import KeypointRCNN - >>> from torchvision.models.detection.rpn import AnchorGenerator + >>> from torchvision.models.detection.anchor_utils import AnchorGenerator >>> >>> # load a pre-trained model for classification and return >>> # only the features diff --git a/torchvision/models/detection/mask_rcnn.py b/torchvision/models/detection/mask_rcnn.py index af6ac8fb17c..668d8ab8122 100644 --- a/torchvision/models/detection/mask_rcnn.py +++ b/torchvision/models/detection/mask_rcnn.py @@ -107,7 +107,7 @@ class MaskRCNN(FasterRCNN): >>> import torch >>> import torchvision >>> from torchvision.models.detection import MaskRCNN - >>> from torchvision.models.detection.rpn import AnchorGenerator + >>> from torchvision.models.detection.anchor_utils import AnchorGenerator >>> >>> # load a pre-trained model for classification and return >>> # only the features diff --git a/torchvision/models/detection/retinanet.py b/torchvision/models/detection/retinanet.py new file mode 100644 index 00000000000..c124bed79c8 --- /dev/null +++ b/torchvision/models/detection/retinanet.py @@ -0,0 +1,627 @@ +import math +from collections import OrderedDict +import warnings + +import torch +import torch.nn as nn +from torch import Tensor +from torch.jit.annotations import Dict, List, Tuple + +from ..utils import load_state_dict_from_url + +from . 
import _utils as det_utils +from .anchor_utils import AnchorGenerator +from .transform import GeneralizedRCNNTransform +from .backbone_utils import resnet_fpn_backbone +from ...ops.feature_pyramid_network import LastLevelP6P7 +from ...ops import sigmoid_focal_loss +from ...ops import boxes as box_ops + + +__all__ = [ + "RetinaNet", "retinanet_resnet50_fpn", +] + + +def _sum(x: List[Tensor]) -> Tensor: + res = x[0] + for i in x[1:]: + res = res + i + return res + + +class RetinaNetHead(nn.Module): + """ + A regression and classification head for use in RetinaNet. + + Arguments: + in_channels (int): number of channels of the input feature + num_anchors (int): number of anchors to be predicted + num_classes (int): number of classes to be predicted + """ + + def __init__(self, in_channels, num_anchors, num_classes): + super().__init__() + self.classification_head = RetinaNetClassificationHead(in_channels, num_anchors, num_classes) + self.regression_head = RetinaNetRegressionHead(in_channels, num_anchors) + + def compute_loss(self, targets, head_outputs, anchors, matched_idxs): + # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor], List[Tensor]) -> Dict[str, Tensor] + return { + 'classification': self.classification_head.compute_loss(targets, head_outputs, matched_idxs), + 'bbox_regression': self.regression_head.compute_loss(targets, head_outputs, anchors, matched_idxs), + } + + def forward(self, x): + # type: (List[Tensor]) -> Dict[str, Tensor] + return { + 'cls_logits': self.classification_head(x), + 'bbox_regression': self.regression_head(x) + } + + +class RetinaNetClassificationHead(nn.Module): + """ + A classification head for use in RetinaNet. + + Arguments: + in_channels (int): number of channels of the input feature + num_anchors (int): number of anchors to be predicted + num_classes (int): number of classes to be predicted + """ + + def __init__(self, in_channels, num_anchors, num_classes, prior_probability=0.01): + super().__init__() + + conv = [] + for _ in range(4): + conv.append(nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)) + conv.append(nn.ReLU()) + self.conv = nn.Sequential(*conv) + + for layer in self.conv.children(): + if isinstance(layer, nn.Conv2d): + torch.nn.init.normal_(layer.weight, std=0.01) + torch.nn.init.constant_(layer.bias, 0) + + self.cls_logits = nn.Conv2d(in_channels, num_anchors * num_classes, kernel_size=3, stride=1, padding=1) + torch.nn.init.normal_(self.cls_logits.weight, std=0.01) + torch.nn.init.constant_(self.cls_logits.bias, -math.log((1 - prior_probability) / prior_probability)) + + self.num_classes = num_classes + self.num_anchors = num_anchors + + # This is to fix using det_utils.Matcher.BETWEEN_THRESHOLDS in TorchScript. + # TorchScript doesn't support class attributes. 
+ # https://github.com/pytorch/vision/pull/1697#issuecomment-630255584 + self.BETWEEN_THRESHOLDS = det_utils.Matcher.BETWEEN_THRESHOLDS + + def compute_loss(self, targets, head_outputs, matched_idxs): + # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor]) -> Tensor + losses = [] + + cls_logits = head_outputs['cls_logits'] + + for targets_per_image, cls_logits_per_image, matched_idxs_per_image in zip(targets, cls_logits, matched_idxs): + # determine only the foreground + foreground_idxs_per_image = matched_idxs_per_image >= 0 + num_foreground = foreground_idxs_per_image.sum() + # no matched_idxs means there were no annotations in this image + # TODO: enable support for images without annotations that works on distributed + if False: # matched_idxs_per_image.numel() == 0: + gt_classes_target = torch.zeros_like(cls_logits_per_image) + valid_idxs_per_image = torch.arange(cls_logits_per_image.shape[0]) + else: + # create the target classification + gt_classes_target = torch.zeros_like(cls_logits_per_image) + gt_classes_target[ + foreground_idxs_per_image, + targets_per_image['labels'][matched_idxs_per_image[foreground_idxs_per_image]] + ] = 1.0 + + # find indices for which anchors should be ignored + valid_idxs_per_image = matched_idxs_per_image != self.BETWEEN_THRESHOLDS + + # compute the classification loss + losses.append(sigmoid_focal_loss( + cls_logits_per_image[valid_idxs_per_image], + gt_classes_target[valid_idxs_per_image], + reduction='sum', + ) / max(1, num_foreground)) + + return _sum(losses) / len(targets) + + def forward(self, x): + # type: (List[Tensor]) -> Tensor + all_cls_logits = [] + + for features in x: + cls_logits = self.conv(features) + cls_logits = self.cls_logits(cls_logits) + + # Permute classification output from (N, A * K, H, W) to (N, HWA, K). + N, _, H, W = cls_logits.shape + cls_logits = cls_logits.view(N, -1, self.num_classes, H, W) + cls_logits = cls_logits.permute(0, 3, 4, 1, 2) + cls_logits = cls_logits.reshape(N, -1, self.num_classes) # Size=(N, HWA, 4) + + all_cls_logits.append(cls_logits) + + return torch.cat(all_cls_logits, dim=1) + + +class RetinaNetRegressionHead(nn.Module): + """ + A regression head for use in RetinaNet. 
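+    For each spatial location of every feature map it predicts ``num_anchors * 4``
+    box regression deltas, which ``forward`` returns reshaped to ``(N, HWA, 4)``.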
+ + Arguments: + in_channels (int): number of channels of the input feature + num_anchors (int): number of anchors to be predicted + """ + __annotations__ = { + 'box_coder': det_utils.BoxCoder, + } + + def __init__(self, in_channels, num_anchors): + super().__init__() + + conv = [] + for _ in range(4): + conv.append(nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)) + conv.append(nn.ReLU()) + self.conv = nn.Sequential(*conv) + + self.bbox_reg = nn.Conv2d(in_channels, num_anchors * 4, kernel_size=3, stride=1, padding=1) + torch.nn.init.normal_(self.bbox_reg.weight, std=0.01) + torch.nn.init.zeros_(self.bbox_reg.bias) + + for layer in self.conv.children(): + if isinstance(layer, nn.Conv2d): + torch.nn.init.normal_(layer.weight, std=0.01) + torch.nn.init.zeros_(layer.bias) + + self.box_coder = det_utils.BoxCoder(weights=(1.0, 1.0, 1.0, 1.0)) + + def compute_loss(self, targets, head_outputs, anchors, matched_idxs): + # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor], List[Tensor]) -> Tensor + losses = [] + + bbox_regression = head_outputs['bbox_regression'] + + for targets_per_image, bbox_regression_per_image, anchors_per_image, matched_idxs_per_image in \ + zip(targets, bbox_regression, anchors, matched_idxs): + # no matched_idxs means there were no annotations in this image + # TODO enable support for images without annotations with distributed support + # if matched_idxs_per_image.numel() == 0: + # continue + + # get the targets corresponding GT for each proposal + # NB: need to clamp the indices because we can have a single + # GT in the image, and matched_idxs can be -2, which goes + # out of bounds + matched_gt_boxes_per_image = targets_per_image['boxes'][matched_idxs_per_image.clamp(min=0)] + + # determine only the foreground indices, ignore the rest + foreground_idxs_per_image = matched_idxs_per_image >= 0 + num_foreground = foreground_idxs_per_image.sum() + + # select only the foreground boxes + matched_gt_boxes_per_image = matched_gt_boxes_per_image[foreground_idxs_per_image, :] + bbox_regression_per_image = bbox_regression_per_image[foreground_idxs_per_image, :] + anchors_per_image = anchors_per_image[foreground_idxs_per_image, :] + + # compute the regression targets + target_regression = self.box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image) + + # compute the loss + losses.append(torch.nn.functional.l1_loss( + bbox_regression_per_image, + target_regression, + size_average=False + ) / max(1, num_foreground)) + + return _sum(losses) / max(1, len(targets)) + + def forward(self, x): + # type: (List[Tensor]) -> Tensor + all_bbox_regression = [] + + for features in x: + bbox_regression = self.conv(features) + bbox_regression = self.bbox_reg(bbox_regression) + + # Permute bbox regression output from (N, 4 * A, H, W) to (N, HWA, 4). + N, _, H, W = bbox_regression.shape + bbox_regression = bbox_regression.view(N, -1, 4, H, W) + bbox_regression = bbox_regression.permute(0, 3, 4, 1, 2) + bbox_regression = bbox_regression.reshape(N, -1, 4) # Size=(N, HWA, 4) + + all_bbox_regression.append(bbox_regression) + + return torch.cat(all_bbox_regression, dim=1) + + +class RetinaNet(nn.Module): + """ + Implements RetinaNet. + + The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each + image, and should be in 0-1 range. Different images can have different sizes. + + The behavior of the model changes depending if it is in training or evaluation mode. 
+ + During training, the model expects both the input tensors, as well as a targets (list of dictionary), + containing: + - boxes (FloatTensor[N, 4]): the ground-truth boxes in [x1, y1, x2, y2] format, with values + between 0 and H and 0 and W + - labels (Int64Tensor[N]): the class label for each ground-truth box + + The model returns a Dict[Tensor] during training, containing the classification and regression + losses. + + During inference, the model requires only the input tensors, and returns the post-processed + predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as + follows: + - boxes (FloatTensor[N, 4]): the predicted boxes in [x1, y1, x2, y2] format, with values between + 0 and H and 0 and W + - labels (Int64Tensor[N]): the predicted labels for each image + - scores (Tensor[N]): the scores for each prediction + + Arguments: + backbone (nn.Module): the network used to compute the features for the model. + It should contain an out_channels attribute, which indicates the number of output + channels that each feature map has (and it should be the same for all feature maps). + The backbone should return a single Tensor or an OrderedDict[Tensor]. + num_classes (int): number of output classes of the model (excluding the background). + min_size (int): minimum size of the image to be rescaled before feeding it to the backbone + max_size (int): maximum size of the image to be rescaled before feeding it to the backbone + image_mean (Tuple[float, float, float]): mean values used for input normalization. + They are generally the mean values of the dataset on which the backbone has been trained + on + image_std (Tuple[float, float, float]): std values used for input normalization. + They are generally the std values of the dataset on which the backbone has been trained on + anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature + maps. + head (nn.Module): Module run on top of the feature pyramid. + Defaults to a module containing a classification and regression module. + score_thresh (float): Score threshold used for postprocessing the detections. + nms_thresh (float): NMS threshold used for postprocessing the detections. + detections_per_img (int): Number of best detections to keep after NMS. + fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be + considered as positive during training. + bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be + considered as negative during training. + + Example: + + >>> import torch + >>> import torchvision + >>> from torchvision.models.detection import RetinaNet + >>> from torchvision.models.detection.anchor_utils import AnchorGenerator + >>> # load a pre-trained model for classification and return + >>> # only the features + >>> backbone = torchvision.models.mobilenet_v2(pretrained=True).features + >>> # RetinaNet needs to know the number of + >>> # output channels in a backbone. For mobilenet_v2, it's 1280 + >>> # so we need to add it here + >>> backbone.out_channels = 1280 + >>> + >>> # let's make the network generate 5 x 3 anchors per spatial + >>> # location, with 5 different sizes and 3 different aspect + >>> # ratios. 
We have a Tuple[Tuple[int]] because each feature + >>> # map could potentially have different sizes and + >>> # aspect ratios + >>> anchor_generator = AnchorGenerator( + >>> sizes=tuple((x, int(x * 2 ** (1.0 / 3)), int(x * 2 ** (2.0 / 3))) for x in [32, 64, 128, 256, 512]), + >>> aspect_ratios=((0.5, 1.0, 2.0),) * 5 + >>> ) + >>> + >>> # put the pieces together inside a RetinaNet model + >>> model = RetinaNet(backbone, + >>> num_classes=2, + >>> anchor_generator=anchor_generator) + >>> model.eval() + >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)] + >>> predictions = model(x) + """ + __annotations__ = { + 'box_coder': det_utils.BoxCoder, + 'proposal_matcher': det_utils.Matcher, + } + + def __init__(self, backbone, num_classes, + # transform parameters + min_size=800, max_size=1333, + image_mean=None, image_std=None, + # Anchor parameters + anchor_generator=None, head=None, + proposal_matcher=None, + score_thresh=0.05, + nms_thresh=0.5, + detections_per_img=300, + fg_iou_thresh=0.5, bg_iou_thresh=0.4): + super().__init__() + + if not hasattr(backbone, "out_channels"): + raise ValueError( + "backbone should contain an attribute out_channels " + "specifying the number of output channels (assumed to be the " + "same for all the levels)") + self.backbone = backbone + + assert isinstance(anchor_generator, (AnchorGenerator, type(None))) + + if anchor_generator is None: + anchor_sizes = tuple((x, int(x * 2 ** (1.0 / 3)), int(x * 2 ** (2.0 / 3))) for x in [32, 64, 128, 256, 512]) + aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes) + anchor_generator = AnchorGenerator( + anchor_sizes, aspect_ratios + ) + self.anchor_generator = anchor_generator + + if head is None: + head = RetinaNetHead(backbone.out_channels, anchor_generator.num_anchors_per_location()[0], num_classes) + self.head = head + + if proposal_matcher is None: + proposal_matcher = det_utils.Matcher( + fg_iou_thresh, + bg_iou_thresh, + allow_low_quality_matches=True, + ) + self.proposal_matcher = proposal_matcher + + self.box_coder = det_utils.BoxCoder(weights=(1.0, 1.0, 1.0, 1.0)) + + if image_mean is None: + image_mean = [0.485, 0.456, 0.406] + if image_std is None: + image_std = [0.229, 0.224, 0.225] + self.transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std) + + self.score_thresh = score_thresh + self.nms_thresh = nms_thresh + self.detections_per_img = detections_per_img + + # used only on torchscript mode + self._has_warned = False + + @torch.jit.unused + def eager_outputs(self, losses, detections): + # type: (Dict[str, Tensor], List[Dict[str, Tensor]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]] + if self.training: + return losses + + return detections + + def compute_loss(self, targets, head_outputs, anchors): + # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor]) -> Dict[str, Tensor] + matched_idxs = [] + for anchors_per_image, targets_per_image in zip(anchors, targets): + if targets_per_image['boxes'].numel() == 0: + matched_idxs.append(torch.empty((0,), dtype=torch.int32)) + continue + + match_quality_matrix = box_ops.box_iou(targets_per_image['boxes'], anchors_per_image) + matched_idxs.append(self.proposal_matcher(match_quality_matrix)) + + return self.head.compute_loss(targets, head_outputs, anchors, matched_idxs) + + def postprocess_detections(self, head_outputs, anchors, image_shapes): + # type: (Dict[str, Tensor], List[Tensor], List[Tuple[int, int]]) -> List[Dict[str, Tensor]] + # TODO: Merge this with roi_heads.RoIHeads.postprocess_detections ? 
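+        # Decode the per-anchor regression outputs into boxes, filter out low-scoring
+        # and tiny boxes, then run NMS independently per class, keeping the top
+        # detections_per_img scoring candidates for each class.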
+ + class_logits = head_outputs.pop('cls_logits') + box_regression = head_outputs.pop('bbox_regression') + other_outputs = head_outputs + + device = class_logits.device + num_classes = class_logits.shape[-1] + + scores = torch.sigmoid(class_logits) + + # create labels for each score + labels = torch.arange(num_classes, device=device) + labels = labels.view(1, -1).expand_as(scores) + + detections = torch.jit.annotate(List[Dict[str, Tensor]], []) + + for index, (box_regression_per_image, scores_per_image, labels_per_image, anchors_per_image, image_shape) in \ + enumerate(zip(box_regression, scores, labels, anchors, image_shapes)): + + boxes_per_image = self.box_coder.decode_single(box_regression_per_image, anchors_per_image) + boxes_per_image = box_ops.clip_boxes_to_image(boxes_per_image, image_shape) + + other_outputs_per_image = [(k, v[index]) for k, v in other_outputs.items()] + + image_boxes = [] + image_scores = [] + image_labels = [] + image_other_outputs = torch.jit.annotate(Dict[str, List[Tensor]], {}) + + for class_index in range(num_classes): + # remove low scoring boxes + inds = torch.gt(scores_per_image[:, class_index], self.score_thresh) + boxes_per_class, scores_per_class, labels_per_class = \ + boxes_per_image[inds], scores_per_image[inds, class_index], labels_per_image[inds, class_index] + other_outputs_per_class = [(k, v[inds]) for k, v in other_outputs_per_image] + + # remove empty boxes + keep = box_ops.remove_small_boxes(boxes_per_class, min_size=1e-2) + boxes_per_class, scores_per_class, labels_per_class = \ + boxes_per_class[keep], scores_per_class[keep], labels_per_class[keep] + other_outputs_per_class = [(k, v[keep]) for k, v in other_outputs_per_class] + + # non-maximum suppression, independently done per class + keep = box_ops.nms(boxes_per_class, scores_per_class, self.nms_thresh) + + # keep only topk scoring predictions + keep = keep[:self.detections_per_img] + boxes_per_class, scores_per_class, labels_per_class = \ + boxes_per_class[keep], scores_per_class[keep], labels_per_class[keep] + other_outputs_per_class = [(k, v[keep]) for k, v in other_outputs_per_class] + + image_boxes.append(boxes_per_class) + image_scores.append(scores_per_class) + image_labels.append(labels_per_class) + + for k, v in other_outputs_per_class: + if k not in image_other_outputs: + image_other_outputs[k] = [] + image_other_outputs[k].append(v) + + detections.append({ + 'boxes': torch.cat(image_boxes, dim=0), + 'scores': torch.cat(image_scores, dim=0), + 'labels': torch.cat(image_labels, dim=0), + }) + + for k, v in image_other_outputs.items(): + detections[-1].update({k: torch.cat(v, dim=0)}) + + return detections + + def forward(self, images, targets=None): + # type: (List[Tensor], Optional[List[Dict[str, Tensor]]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]] + """ + Arguments: + images (list[Tensor]): images to be processed + targets (list[Dict[Tensor]]): ground-truth boxes present in the image (optional) + + Returns: + result (list[BoxList] or dict[Tensor]): the output from the model. + During training, it returns a dict[Tensor] which contains the losses. + During testing, it returns list[BoxList] contains additional fields + like `scores`, `labels` and `mask` (for Mask R-CNN models). 
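+                For RetinaNet, each returned dict contains the fields ``boxes``,
+                ``scores`` and ``labels``.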
+ + """ + if self.training and targets is None: + raise ValueError("In training mode, targets should be passed") + + if self.training: + assert targets is not None + for target in targets: + boxes = target["boxes"] + if isinstance(boxes, torch.Tensor): + if len(boxes.shape) != 2 or boxes.shape[-1] != 4: + raise ValueError("Expected target boxes to be a tensor" + "of shape [N, 4], got {:}.".format( + boxes.shape)) + else: + raise ValueError("Expected target boxes to be of type " + "Tensor, got {:}.".format(type(boxes))) + + # get the original image sizes + original_image_sizes = torch.jit.annotate(List[Tuple[int, int]], []) + for img in images: + val = img.shape[-2:] + assert len(val) == 2 + original_image_sizes.append((val[0], val[1])) + + # transform the input + images, targets = self.transform(images, targets) + + # Check for degenerate boxes + # TODO: Move this to a function + if targets is not None: + for target_idx, target in enumerate(targets): + boxes = target["boxes"] + degenerate_boxes = boxes[:, 2:] <= boxes[:, :2] + if degenerate_boxes.any(): + # print the first degenerate box + bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0] + degen_bb: List[float] = boxes[bb_idx].tolist() + raise ValueError("All bounding boxes should have positive height and width." + " Found invalid box {} for target at index {}." + .format(degen_bb, target_idx)) + + # get the features from the backbone + features = self.backbone(images.tensors) + if isinstance(features, torch.Tensor): + features = OrderedDict([('0', features)]) + + # TODO: Do we want a list or a dict? + features = list(features.values()) + + # compute the retinanet heads outputs using the features + head_outputs = self.head(features) + + # create the set of anchors + anchors = self.anchor_generator(images, features) + + losses = {} + detections = torch.jit.annotate(List[Dict[str, Tensor]], []) + if self.training: + assert targets is not None + + # compute the losses + losses = self.compute_loss(targets, head_outputs, anchors) + else: + # compute the detections + detections = self.postprocess_detections(head_outputs, anchors, images.image_sizes) + detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes) + + if torch.jit.is_scripting(): + if not self._has_warned: + warnings.warn("RetinaNet always returns a (Losses, Detections) tuple in scripting") + self._has_warned = True + return (losses, detections) + return self.eager_outputs(losses, detections) + + +model_urls = { + 'retinanet_resnet50_fpn_coco': + 'https://download.pytorch.org/models/retinanet_resnet50_fpn_coco-eeacb38b.pth', +} + + +def retinanet_resnet50_fpn(pretrained=False, progress=True, + num_classes=91, pretrained_backbone=True, **kwargs): + """ + Constructs a RetinaNet model with a ResNet-50-FPN backbone. + + The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each + image, and should be in ``0-1`` range. Different images can have different sizes. + + The behavior of the model changes depending if it is in training or evaluation mode. + + During training, the model expects both the input tensors, as well as a targets (list of dictionary), + containing: + - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with values + between ``0`` and ``H`` and ``0`` and ``W`` + - labels (``Int64Tensor[N]``): the class label for each ground-truth box + + The model returns a ``Dict[Tensor]`` during training, containing the classification and regression + losses. 
+ + During inference, the model requires only the input tensors, and returns the post-processed + predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as + follows: + - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with values between + ``0`` and ``H`` and ``0`` and ``W`` + - labels (``Int64Tensor[N]``): the predicted labels for each image + - scores (``Tensor[N]``): the scores or each prediction + + Example:: + + >>> model = torchvision.models.detection.retinanet_resnet50_fpn(pretrained=True) + >>> model.eval() + >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)] + >>> predictions = model(x) + + Arguments: + pretrained (bool): If True, returns a model pre-trained on COCO train2017 + progress (bool): If True, displays a progress bar of the download to stderr + """ + if pretrained: + # no need to download the backbone if pretrained is set + pretrained_backbone = False + # skip P2 because it generates too many anchors (according to their paper) + backbone = resnet_fpn_backbone('resnet50', pretrained_backbone, + returned_layers=[2, 3, 4], extra_blocks=LastLevelP6P7(256, 256)) + model = RetinaNet(backbone, num_classes, **kwargs) + if pretrained: + state_dict = load_state_dict_from_url(model_urls['retinanet_resnet50_fpn_coco'], + progress=progress) + model.load_state_dict(state_dict) + return model diff --git a/torchvision/models/detection/rpn.py b/torchvision/models/detection/rpn.py index 1d814ef232c..e4c4478eb54 100644 --- a/torchvision/models/detection/rpn.py +++ b/torchvision/models/detection/rpn.py @@ -11,6 +11,9 @@ from torch.jit.annotations import List, Optional, Dict, Tuple +# Import AnchorGenerator to keep compatibility. +from .anchor_utils import AnchorGenerator + @torch.jit.unused def _onnx_get_num_anchors_and_pre_nms_top_n(ob, orig_pre_nms_top_n): @@ -24,159 +27,6 @@ def _onnx_get_num_anchors_and_pre_nms_top_n(ob, orig_pre_nms_top_n): return num_anchors, pre_nms_top_n -class AnchorGenerator(nn.Module): - __annotations__ = { - "cell_anchors": Optional[List[torch.Tensor]], - "_cache": Dict[str, List[torch.Tensor]] - } - - """ - Module that generates anchors for a set of feature maps and - image sizes. - - The module support computing anchors at multiple sizes and aspect ratios - per feature map. This module assumes aspect ratio = height / width for - each anchor. - - sizes and aspect_ratios should have the same number of elements, and it should - correspond to the number of feature maps. - - sizes[i] and aspect_ratios[i] can have an arbitrary number of elements, - and AnchorGenerator will output a set of sizes[i] * aspect_ratios[i] anchors - per spatial location for feature map i. - - Arguments: - sizes (Tuple[Tuple[int]]): - aspect_ratios (Tuple[Tuple[float]]): - """ - - def __init__( - self, - sizes=((128, 256, 512),), - aspect_ratios=((0.5, 1.0, 2.0),), - ): - super(AnchorGenerator, self).__init__() - - if not isinstance(sizes[0], (list, tuple)): - # TODO change this - sizes = tuple((s,) for s in sizes) - if not isinstance(aspect_ratios[0], (list, tuple)): - aspect_ratios = (aspect_ratios,) * len(sizes) - - assert len(sizes) == len(aspect_ratios) - - self.sizes = sizes - self.aspect_ratios = aspect_ratios - self.cell_anchors = None - self._cache = {} - - # TODO: https://github.com/pytorch/pytorch/issues/26792 - # For every (aspect_ratios, scales) combination, output a zero-centered anchor with those values. 
- # (scales, aspect_ratios) are usually an element of zip(self.scales, self.aspect_ratios) - # This method assumes aspect ratio = height / width for an anchor. - def generate_anchors(self, scales, aspect_ratios, dtype=torch.float32, device="cpu"): - # type: (List[int], List[float], int, Device) -> Tensor # noqa: F821 - scales = torch.as_tensor(scales, dtype=dtype, device=device) - aspect_ratios = torch.as_tensor(aspect_ratios, dtype=dtype, device=device) - h_ratios = torch.sqrt(aspect_ratios) - w_ratios = 1 / h_ratios - - ws = (w_ratios[:, None] * scales[None, :]).view(-1) - hs = (h_ratios[:, None] * scales[None, :]).view(-1) - - base_anchors = torch.stack([-ws, -hs, ws, hs], dim=1) / 2 - return base_anchors.round() - - def set_cell_anchors(self, dtype, device): - # type: (int, Device) -> None # noqa: F821 - if self.cell_anchors is not None: - cell_anchors = self.cell_anchors - assert cell_anchors is not None - # suppose that all anchors have the same device - # which is a valid assumption in the current state of the codebase - if cell_anchors[0].device == device: - return - - cell_anchors = [ - self.generate_anchors( - sizes, - aspect_ratios, - dtype, - device - ) - for sizes, aspect_ratios in zip(self.sizes, self.aspect_ratios) - ] - self.cell_anchors = cell_anchors - - def num_anchors_per_location(self): - return [len(s) * len(a) for s, a in zip(self.sizes, self.aspect_ratios)] - - # For every combination of (a, (g, s), i) in (self.cell_anchors, zip(grid_sizes, strides), 0:2), - # output g[i] anchors that are s[i] distance apart in direction i, with the same dimensions as a. - def grid_anchors(self, grid_sizes, strides): - # type: (List[List[int]], List[List[Tensor]]) -> List[Tensor] - anchors = [] - cell_anchors = self.cell_anchors - assert cell_anchors is not None - assert len(grid_sizes) == len(strides) == len(cell_anchors) - - for size, stride, base_anchors in zip( - grid_sizes, strides, cell_anchors - ): - grid_height, grid_width = size - stride_height, stride_width = stride - device = base_anchors.device - - # For output anchor, compute [x_center, y_center, x_center, y_center] - shifts_x = torch.arange( - 0, grid_width, dtype=torch.float32, device=device - ) * stride_width - shifts_y = torch.arange( - 0, grid_height, dtype=torch.float32, device=device - ) * stride_height - shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x) - shift_x = shift_x.reshape(-1) - shift_y = shift_y.reshape(-1) - shifts = torch.stack((shift_x, shift_y, shift_x, shift_y), dim=1) - - # For every (base anchor, output anchor) pair, - # offset each zero-centered base anchor by the center of the output anchor. 
- anchors.append( - (shifts.view(-1, 1, 4) + base_anchors.view(1, -1, 4)).reshape(-1, 4) - ) - - return anchors - - def cached_grid_anchors(self, grid_sizes, strides): - # type: (List[List[int]], List[List[Tensor]]) -> List[Tensor] - key = str(grid_sizes) + str(strides) - if key in self._cache: - return self._cache[key] - anchors = self.grid_anchors(grid_sizes, strides) - self._cache[key] = anchors - return anchors - - def forward(self, image_list, feature_maps): - # type: (ImageList, List[Tensor]) -> List[Tensor] - grid_sizes = list([feature_map.shape[-2:] for feature_map in feature_maps]) - image_size = image_list.tensors.shape[-2:] - dtype, device = feature_maps[0].dtype, feature_maps[0].device - strides = [[torch.tensor(image_size[0] // g[0], dtype=torch.int64, device=device), - torch.tensor(image_size[1] // g[1], dtype=torch.int64, device=device)] for g in grid_sizes] - self.set_cell_anchors(dtype, device) - anchors_over_all_feature_maps = self.cached_grid_anchors(grid_sizes, strides) - anchors = torch.jit.annotate(List[List[torch.Tensor]], []) - for i, (image_height, image_width) in enumerate(image_list.image_sizes): - anchors_in_image = [] - for anchors_per_feature_map in anchors_over_all_feature_maps: - anchors_in_image.append(anchors_per_feature_map) - anchors.append(anchors_in_image) - anchors = [torch.cat(anchors_per_image) for anchors_per_image in anchors] - # Clear the cache in case that memory leaks. - self._cache.clear() - return anchors - - class RPNHead(nn.Module): """ Adds a simple RPN Head with classification and regression heads @@ -338,7 +188,7 @@ def assign_targets_to_anchors(self, anchors, targets): matched_gt_boxes_per_image = torch.zeros(anchors_per_image.shape, dtype=torch.float32, device=device) labels_per_image = torch.zeros((anchors_per_image.shape[0],), dtype=torch.float32, device=device) else: - match_quality_matrix = box_ops.box_iou(gt_boxes, anchors_per_image) + match_quality_matrix = self.box_similarity(gt_boxes, anchors_per_image) matched_idxs = self.proposal_matcher(match_quality_matrix) # get the targets corresponding GT for each proposal # NB: need to clamp the indices because we can have a single diff --git a/torchvision/ops/__init__.py b/torchvision/ops/__init__.py index 4f94ac447c8..bd82a0d5ed9 100644 --- a/torchvision/ops/__init__.py +++ b/torchvision/ops/__init__.py @@ -8,6 +8,7 @@ from .ps_roi_pool import ps_roi_pool, PSRoIPool from .poolers import MultiScaleRoIAlign from .feature_pyramid_network import FeaturePyramidNetwork +from .focal_loss import sigmoid_focal_loss from ._register_onnx_ops import _register_custom_op @@ -19,5 +20,6 @@ 'clip_boxes_to_image', 'box_convert', 'box_area', 'box_iou', 'generalized_box_iou', 'roi_align', 'RoIAlign', 'roi_pool', 'RoIPool', '_new_empty_tensor', 'ps_roi_align', 'PSRoIAlign', 'ps_roi_pool', - 'PSRoIPool', 'MultiScaleRoIAlign', 'FeaturePyramidNetwork' + 'PSRoIPool', 'MultiScaleRoIAlign', 'FeaturePyramidNetwork', + 'sigmoid_focal_loss' ] diff --git a/torchvision/ops/focal_loss.py b/torchvision/ops/focal_loss.py new file mode 100644 index 00000000000..1114d80ed47 --- /dev/null +++ b/torchvision/ops/focal_loss.py @@ -0,0 +1,48 @@ +import torch +import torch.nn.functional as F + + +def sigmoid_focal_loss( + inputs, + targets, + alpha: float = 0.25, + gamma: float = 2, + reduction: str = "none", +): + """ + Original implementation from https://github.com/facebookresearch/fvcore/blob/master/fvcore/nn/focal_loss.py . + Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002. 
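+
+    Illustrative usage sketch (shapes and values below are arbitrary, chosen only
+    for demonstration):
+
+        >>> import torch
+        >>> inputs = torch.randn(8, 91)                      # raw, unnormalized logits
+        >>> targets = torch.randint(0, 2, (8, 91)).float()   # binary 0/1 labels
+        >>> loss = sigmoid_focal_loss(inputs, targets, reduction="mean")
+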
+ Args: + inputs: A float tensor of arbitrary shape. + The predictions for each example. + targets: A float tensor with the same shape as inputs. Stores the binary + classification label for each element in inputs + (0 for the negative class and 1 for the positive class). + alpha: (optional) Weighting factor in range (0,1) to balance + positive vs negative examples or -1 for ignore. Default = 0.25 + gamma: Exponent of the modulating factor (1 - p_t) to + balance easy vs hard examples. + reduction: 'none' | 'mean' | 'sum' + 'none': No reduction will be applied to the output. + 'mean': The output will be averaged. + 'sum': The output will be summed. + Returns: + Loss tensor with the reduction option applied. + """ + p = torch.sigmoid(inputs) + ce_loss = F.binary_cross_entropy_with_logits( + inputs, targets, reduction="none" + ) + p_t = p * targets + (1 - p) * (1 - targets) + loss = ce_loss * ((1 - p_t) ** gamma) + + if alpha >= 0: + alpha_t = alpha * targets + (1 - alpha) * (1 - targets) + loss = alpha_t * loss + + if reduction == "mean": + loss = loss.mean() + elif reduction == "sum": + loss = loss.sum() + + return loss