From 50f822cc57515ba3944607fb02016ae9d71b373e Mon Sep 17 00:00:00 2001 From: Hans Gaiser Date: Fri, 20 Dec 2019 10:43:19 +0100 Subject: [PATCH 01/31] Add rough implementation of RetinaNet. --- torchvision/models/detection/retinanet.py | 389 ++++++++++++++++++++++ 1 file changed, 389 insertions(+) create mode 100644 torchvision/models/detection/retinanet.py diff --git a/torchvision/models/detection/retinanet.py b/torchvision/models/detection/retinanet.py new file mode 100644 index 00000000000..ac320177501 --- /dev/null +++ b/torchvision/models/detection/retinanet.py @@ -0,0 +1,389 @@ +from collections import OrderedDict + +import torch +from torch import nn +import torch.nn.functional as F + +from ..utils import load_state_dict_from_url + +from .rpn import AnchorGenerator +from .transform import GeneralizedRCNNTransform +from .backbone_utils import resnet_fpn_backbone + + +__all__ = [ + "RetinaNet", "retinanet_resnet50_fpn", +] + + +class RetinaNetHead(nn.Module): + """ + A regression and classification head for use in RetinaNet. + + Arguments: + in_channels (int): number of channels of the input feature + num_anchors (int): number of anchors to be predicted + num_classes (int): number of classes to be predicted + """ + + def __init__(self, in_channels, num_anchors, num_classes): + super(RPNHead, self).__init__() + self.classification_head = RetinaNetClassificationHead(in_channels, num_anchors, num_classes) + self.regression_head = RetinaNetRegressionHead(in_channels, num_anchors) + + def compute_loss(self, outputs, labels, matched_gt_boxes): + return { + 'classification': self.classification_head.compute_loss(outputs, targets, anchor_state), + 'regression': self.regression_head.compute_loss(outputs, targets, anchor_state), + } + + def forward(self, x): + logits = [self.classification_head(feature, targets) for feature in x] + bbox_reg = [self.regression_head(feature, targets) for feature in x] + return dict(logits=logits, bbox_reg=bbox_reg) + + +class RetinaNetClassificationHead(nn.Module): + """ + A classification head for use in RetinaNet. + + Arguments: + in_channels (int): number of channels of the input feature + num_anchors (int): number of anchors to be predicted + num_classes (int): number of classes to be predicted + """ + + def __init__(self, in_channels, num_anchors, num_classes, prior_probability=0.01): + super(RetinaNetClassificationHead, self).__init__() + + conv = [] + for _ in range(4): + conv.append(nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)) + conv.append(nn.ReLU()) + self.conv = nn.Sequential(*conv) + + for l in self.children(): + torch.nn.init.normal_(l.weight, std=0.01) + torch.nn.init.constant_(l.bias, 0) + + def compute_loss(self, outputs, labels, matched_gt_boxes): + # TODO Implement focal loss, is there an existing function for this? + return 0 + + def forward(self, x): + all_cls_logits = [] + + for features in x: + cls_logits = self.conv(features) + cls_logits = self.cls_logits(cls_logits) + + # Permute classification output from (N, A * K, H, W) to (N, HWA, K). + N, _, H, W = cls_logits.shape + cls_logits = cls_logits.view(N, -1, self.num_classes, H, W) + cls_logits = cls_logits.permute(0, 3, 4, 1, 2) + cls_logits = cls_logits.reshape(N, -1, self.num_classes) # Size=(N, HWA, 4) + + all_cls_logits.append(cls_logits) + + return torch.cat(all_cls_logits, dim=1) + + +class RetinaNetRegressionHead(nn.Module): + """ + A regression head for use in RetinaNet. 
+ + Arguments: + in_channels (int): number of channels of the input feature + num_anchors (int): number of anchors to be predicted + """ + + def __init__(self, in_channels, num_anchors): + super(RetinaNetRegressionHead, self).__init__() + + conv = [] + for _ in range(4): + conv.append(nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)) + conv.append(nn.ReLU()) + self.conv = nn.Sequential(*conv) + + self.bbox_reg = nn.Conv2d(in_channels, num_anchors * 4, kernel_size=3, stride=1, padding=1) + + for l in self.children(): + torch.nn.init.normal_(l.weight, std=0.01) + torch.nn.init.constant_(l.bias, 0) + + def compute_loss(self, outputs, labels, matched_gt_boxes): + # TODO Use SmoothL1 loss for regression, or just L1 like in rpn.py ? + return 0 + + def forward(self, x): + all_bbox_regression = [] + + for features in x: + bbox_regression = self.conv(features) + bbox_regression = self.bbox_reg(bbox_regression) + + # Permute bbox regression output from (N, 4 * A, H, W) to (N, HWA, 4). + N, _, H, W = bbox_regression.shape + bbox_regression = bbox_regression.view(N, -1, 4, H, W) + bbox_regression = bbox_regression.permute(0, 3, 4, 1, 2) + bbox_regression = bbox_regression.reshape(N, -1, 4) # Size=(N, HWA, 4) + + all_bbox_regression.append(bbox_regression) + + return torch.cat(all_bbox_regression, dim=1) + + +class RetinaNet(nn.Module): + """ + Implements RetinaNet. + + The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each + image, and should be in 0-1 range. Different images can have different sizes. + + The behavior of the model changes depending if it is in training or evaluation mode. + + During training, the model expects both the input tensors, as well as a targets (list of dictionary), + containing: + - boxes (FloatTensor[N, 4]): the ground-truth boxes in [x1, y1, x2, y2] format, with values + between 0 and H and 0 and W + - labels (Int64Tensor[N]): the class label for each ground-truth box + + The model returns a Dict[Tensor] during training, containing the classification and regression + losses. + + During inference, the model requires only the input tensors, and returns the post-processed + predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as + follows: + - boxes (FloatTensor[N, 4]): the predicted boxes in [x1, y1, x2, y2] format, with values between + 0 and H and 0 and W + - labels (Int64Tensor[N]): the predicted labels for each image + - scores (Tensor[N]): the scores for each prediction + + Arguments: + backbone (nn.Module): the network used to compute the features for the model. + It should contain an out_channels attribute, which indicates the number of output + channels that each feature map has (and it should be the same for all feature maps). + The backbone should return a single Tensor or an OrderedDict[Tensor]. + num_classes (int): number of output classes of the model (excluding the background). + min_size (int): minimum size of the image to be rescaled before feeding it to the backbone + max_size (int): maximum size of the image to be rescaled before feeding it to the backbone + image_mean (Tuple[float, float, float]): mean values used for input normalization. + They are generally the mean values of the dataset on which the backbone has been trained + on + image_std (Tuple[float, float, float]): std values used for input normalization. 
+ They are generally the std values of the dataset on which the backbone has been trained on + anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature + maps. + head (nn.Module): Module run on top of the feature pyramid. + Defaults to a module containing a classification and regression module. + pre_nms_top_n (int): number of proposals to keep before applying NMS during testing. + post_nms_top_n (int): number of proposals to keep after applying NMS during testing. + nms_thresh (float): NMS threshold used for postprocessing the detections. + fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be + considered as positive during training. + bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be + considered as negative during training. + + Example: + + >>> import torch + >>> import torchvision + >>> from torchvision.models.detection import RetinaNet + >>> from torchvision.models.detection.rpn import AnchorGenerator + >>> # load a pre-trained model for classification and return + >>> # only the features + >>> backbone = torchvision.models.mobilenet_v2(pretrained=True).features + >>> # RetinaNet needs to know the number of + >>> # output channels in a backbone. For mobilenet_v2, it's 1280 + >>> # so we need to add it here + >>> backbone.out_channels = 1280 + >>> + >>> # let's make the network generate 5 x 3 anchors per spatial + >>> # location, with 5 different sizes and 3 different aspect + >>> # ratios. We have a Tuple[Tuple[int]] because each feature + >>> # map could potentially have different sizes and + >>> # aspect ratios + >>> anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256, 512),), + >>> aspect_ratios=((0.5, 1.0, 2.0),)) + >>> + >>> # put the pieces together inside a RetinaNet model + >>> model = RetinaNet(backbone, + >>> num_classes=2, + >>> anchor_generator=anchor_generator) + >>> model.eval() + >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)] + >>> predictions = model(x) + """ + + def __init__(self, backbone, num_classes, + # transform parameters + min_size=800, max_size=1333, + image_mean=None, image_std=None, + # Anchor parameters + anchor_generator=None, head=None, + pre_nms_top_n=1000, post_nms_top_n=1000, + nms_thresh=0.5, + fg_iou_thresh=0.5, bg_iou_thresh=0.4): + + if not hasattr(backbone, "out_channels"): + raise ValueError( + "backbone should contain an attribute out_channels " + "specifying the number of output channels (assumed to be the " + "same for all the levels)") + + assert isinstance(anchor_generator, (AnchorGenerator, type(None))) + + if anchor_generator is None: + anchor_sizes = ((32,), (64,), (128,), (256,), (512,)) + aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes) + self.anchor_generator = AnchorGenerator( + anchor_sizes, aspect_ratios + ) + + if head is None: + head = RetinaNetHead(backbone.out_channels, num_classes, anchor_generator.num_anchors_per_location()) + self.head = head + + if image_mean is None: + image_mean = [0.485, 0.456, 0.406] + if image_std is None: + image_std = [0.229, 0.224, 0.225] + self.transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std) + + @torch.jit.unused + def eager_outputs(self, losses, detections): + # type: (Dict[str, Tensor], List[Dict[str, Tensor]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]] + if self.training: + return losses + + return detections + + def forward(self, images, targets=None): + # type: (List[Tensor], Optional[List[Dict[str, Tensor]]]) + """ + 
Arguments: + images (list[Tensor]): images to be processed + targets (list[Dict[Tensor]]): ground-truth boxes present in the image (optional) + + Returns: + result (list[BoxList] or dict[Tensor]): the output from the model. + During training, it returns a dict[Tensor] which contains the losses. + During testing, it returns list[BoxList] contains additional fields + like `scores`, `labels` and `mask` (for Mask R-CNN models). + + """ + if self.training and targets is None: + raise ValueError("In training mode, targets should be passed") + + # get the original image sizes + original_image_sizes = torch.jit.annotate(List[Tuple[int, int]], []) + for img in images: + val = img.shape[-2:] + assert len(val) == 2 + original_image_sizes.append((val[0], val[1])) + + # transform the input + images, targets = self.transform(images, targets) + + # get the features from the backbone + features = self.backbone(images.tensors) + if isinstance(features, torch.Tensor): + features = OrderedDict([('0', features)]) + + # compute the retinanet heads outputs using the features + head_outputs = self.head(images, features, targets) + + # create the set of anchors + anchors = self.anchor_generator(images, features) + + losses = {} + detections = torch.jit.annotate(List[Dict[str, torch.Tensor]], []) + if self.training: + assert targets is not None + + # compute the losses + # TODO: Move necessary functions out of rpn.RegionProposalNetwork to a class or function + # so that we can use it here and in rpn.RegionProposalNetwork + labels, matched_gt_boxes = self.assign_targets_to_anchors(anchors, targets) + regression_targets = self.box_coder.encode(matched_gt_boxes, anchors) + losses = self.head.compute_loss(head_outputs, labels, matched_gt_boxes) + else: + # compute the detections + # TODO: Implement postprocess_detections + boxes, scores, labels = self.postprocess_detections(class_logits, box_regression, anchors) + num_images = len(images) + for i in range(num_images): + detections.append( + { + "boxes": boxes[i], + "labels": labels[i], + "scores": scores[i], + } + ) + + detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes) + + if torch.jit.is_scripting(): + if not self._has_warned: + warnings.warn("RetinaNet always returns a (Losses, Detections) tuple in scripting") + self._has_warned = True + return (losses, detections) + else: + return self.eager_outputs(losses, detections) + + +model_urls = { + 'retinanet_resnet50_fpn_coco': + '#TODO', +} + + +def retinanet_resnet50_fpn(pretrained=False, progress=True, + num_classes=91, pretrained_backbone=True, **kwargs): + """ + Constructs a RetinaNet model with a ResNet-50-FPN backbone. + + The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each + image, and should be in ``0-1`` range. Different images can have different sizes. + + The behavior of the model changes depending if it is in training or evaluation mode. + + During training, the model expects both the input tensors, as well as a targets (list of dictionary), + containing: + - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with values + between ``0`` and ``H`` and ``0`` and ``W`` + - labels (``Int64Tensor[N]``): the class label for each ground-truth box + + The model returns a ``Dict[Tensor]`` during training, containing the classification and regression + losses. 
+ + During inference, the model requires only the input tensors, and returns the post-processed + predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as + follows: + - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with values between + ``0`` and ``H`` and ``0`` and ``W`` + - labels (``Int64Tensor[N]``): the predicted labels for each image + - scores (``Tensor[N]``): the scores or each prediction + + Example:: + + >>> model = torchvision.models.detection.retinanet_resnet50_fpn(pretrained=True) + >>> model.eval() + >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)] + >>> predictions = model(x) + + Arguments: + pretrained (bool): If True, returns a model pre-trained on COCO train2017 + progress (bool): If True, displays a progress bar of the download to stderr + """ + if pretrained: + # no need to download the backbone if pretrained is set + pretrained_backbone = False + backbone = resnet_fpn_backbone('resnet50', pretrained_backbone) + model = RetinaNet(backbone, num_classes, **kwargs) + if pretrained: + state_dict = load_state_dict_from_url(model_urls['retinanet_resnet50_fpn_coco'], + progress=progress) + model.load_state_dict(state_dict) + return model From 022f8e14c82d85edb583848208d5d6243484f656 Mon Sep 17 00:00:00 2001 From: Hans Gaiser Date: Fri, 22 May 2020 09:34:54 +0200 Subject: [PATCH 02/31] Move AnchorGenerator to a seperate file. --- torchvision/models/detection/anchor_utils.py | 156 ++++++++++++++++++ torchvision/models/detection/faster_rcnn.py | 3 +- torchvision/models/detection/keypoint_rcnn.py | 2 +- torchvision/models/detection/mask_rcnn.py | 2 +- torchvision/models/detection/rpn.py | 156 +----------------- 5 files changed, 163 insertions(+), 156 deletions(-) create mode 100644 torchvision/models/detection/anchor_utils.py diff --git a/torchvision/models/detection/anchor_utils.py b/torchvision/models/detection/anchor_utils.py new file mode 100644 index 00000000000..80171ef20d4 --- /dev/null +++ b/torchvision/models/detection/anchor_utils.py @@ -0,0 +1,156 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +import torch +from torch import nn + + +class AnchorGenerator(nn.Module): + __annotations__ = { + "cell_anchors": Optional[List[torch.Tensor]], + "_cache": Dict[str, List[torch.Tensor]] + } + + """ + Module that generates anchors for a set of feature maps and + image sizes. + + The module support computing anchors at multiple sizes and aspect ratios + per feature map. This module assumes aspect ratio = height / width for + each anchor. + + sizes and aspect_ratios should have the same number of elements, and it should + correspond to the number of feature maps. + + sizes[i] and aspect_ratios[i] can have an arbitrary number of elements, + and AnchorGenerator will output a set of sizes[i] * aspect_ratios[i] anchors + per spatial location for feature map i. 
+ + Arguments: + sizes (Tuple[Tuple[int]]): + aspect_ratios (Tuple[Tuple[float]]): + """ + + def __init__( + self, + sizes=((128, 256, 512),), + aspect_ratios=((0.5, 1.0, 2.0),), + ): + super(AnchorGenerator, self).__init__() + + if not isinstance(sizes[0], (list, tuple)): + # TODO change this + sizes = tuple((s,) for s in sizes) + if not isinstance(aspect_ratios[0], (list, tuple)): + aspect_ratios = (aspect_ratios,) * len(sizes) + + assert len(sizes) == len(aspect_ratios) + + self.sizes = sizes + self.aspect_ratios = aspect_ratios + self.cell_anchors = None + self._cache = {} + + # TODO: https://github.com/pytorch/pytorch/issues/26792 + # For every (aspect_ratios, scales) combination, output a zero-centered anchor with those values. + # (scales, aspect_ratios) are usually an element of zip(self.scales, self.aspect_ratios) + # This method assumes aspect ratio = height / width for an anchor. + def generate_anchors(self, scales, aspect_ratios, dtype=torch.float32, device="cpu"): + # type: (List[int], List[float], int, Device) -> Tensor # noqa: F821 + scales = torch.as_tensor(scales, dtype=dtype, device=device) + aspect_ratios = torch.as_tensor(aspect_ratios, dtype=dtype, device=device) + h_ratios = torch.sqrt(aspect_ratios) + w_ratios = 1 / h_ratios + + ws = (w_ratios[:, None] * scales[None, :]).view(-1) + hs = (h_ratios[:, None] * scales[None, :]).view(-1) + + base_anchors = torch.stack([-ws, -hs, ws, hs], dim=1) / 2 + return base_anchors.round() + + def set_cell_anchors(self, dtype, device): + # type: (int, Device) -> None # noqa: F821 + if self.cell_anchors is not None: + cell_anchors = self.cell_anchors + assert cell_anchors is not None + # suppose that all anchors have the same device + # which is a valid assumption in the current state of the codebase + if cell_anchors[0].device == device: + return + + cell_anchors = [ + self.generate_anchors( + sizes, + aspect_ratios, + dtype, + device + ) + for sizes, aspect_ratios in zip(self.sizes, self.aspect_ratios) + ] + self.cell_anchors = cell_anchors + + def num_anchors_per_location(self): + return [len(s) * len(a) for s, a in zip(self.sizes, self.aspect_ratios)] + + # For every combination of (a, (g, s), i) in (self.cell_anchors, zip(grid_sizes, strides), 0:2), + # output g[i] anchors that are s[i] distance apart in direction i, with the same dimensions as a. + def grid_anchors(self, grid_sizes, strides): + # type: (List[List[int]], List[List[Tensor]]) -> List[Tensor] + anchors = [] + cell_anchors = self.cell_anchors + assert cell_anchors is not None + assert len(grid_sizes) == len(strides) == len(cell_anchors) + + for size, stride, base_anchors in zip( + grid_sizes, strides, cell_anchors + ): + grid_height, grid_width = size + stride_height, stride_width = stride + device = base_anchors.device + + # For output anchor, compute [x_center, y_center, x_center, y_center] + shifts_x = torch.arange( + 0, grid_width, dtype=torch.float32, device=device + ) * stride_width + shifts_y = torch.arange( + 0, grid_height, dtype=torch.float32, device=device + ) * stride_height + shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x) + shift_x = shift_x.reshape(-1) + shift_y = shift_y.reshape(-1) + shifts = torch.stack((shift_x, shift_y, shift_x, shift_y), dim=1) + + # For every (base anchor, output anchor) pair, + # offset each zero-centered base anchor by the center of the output anchor. 
+ anchors.append( + (shifts.view(-1, 1, 4) + base_anchors.view(1, -1, 4)).reshape(-1, 4) + ) + + return anchors + + def cached_grid_anchors(self, grid_sizes, strides): + # type: (List[List[int]], List[List[Tensor]]) -> List[Tensor] + key = str(grid_sizes) + str(strides) + if key in self._cache: + return self._cache[key] + anchors = self.grid_anchors(grid_sizes, strides) + self._cache[key] = anchors + return anchors + + def forward(self, image_list, feature_maps): + # type: (ImageList, List[Tensor]) -> List[Tensor] + grid_sizes = list([feature_map.shape[-2:] for feature_map in feature_maps]) + image_size = image_list.tensors.shape[-2:] + dtype, device = feature_maps[0].dtype, feature_maps[0].device + strides = [[torch.tensor(image_size[0] // g[0], dtype=torch.int64, device=device), + torch.tensor(image_size[1] // g[1], dtype=torch.int64, device=device)] for g in grid_sizes] + self.set_cell_anchors(dtype, device) + anchors_over_all_feature_maps = self.cached_grid_anchors(grid_sizes, strides) + anchors = torch.jit.annotate(List[List[torch.Tensor]], []) + for i, (image_height, image_width) in enumerate(image_list.image_sizes): + anchors_in_image = [] + for anchors_per_feature_map in anchors_over_all_feature_maps: + anchors_in_image.append(anchors_per_feature_map) + anchors.append(anchors_in_image) + anchors = [torch.cat(anchors_per_image) for anchors_per_image in anchors] + # Clear the cache in case that memory leaks. + self._cache.clear() + return anchors diff --git a/torchvision/models/detection/faster_rcnn.py b/torchvision/models/detection/faster_rcnn.py index c7e6c6d12db..117d985eca6 100644 --- a/torchvision/models/detection/faster_rcnn.py +++ b/torchvision/models/detection/faster_rcnn.py @@ -9,8 +9,9 @@ from ..utils import load_state_dict_from_url +from .anchor_utils import AnchorGenerator from .generalized_rcnn import GeneralizedRCNN -from .rpn import AnchorGenerator, RPNHead, RegionProposalNetwork +from .rpn import RPNHead, RegionProposalNetwork from .roi_heads import RoIHeads from .transform import GeneralizedRCNNTransform from .backbone_utils import resnet_fpn_backbone diff --git a/torchvision/models/detection/keypoint_rcnn.py b/torchvision/models/detection/keypoint_rcnn.py index 438ef225c91..f1f4ad26680 100644 --- a/torchvision/models/detection/keypoint_rcnn.py +++ b/torchvision/models/detection/keypoint_rcnn.py @@ -103,7 +103,7 @@ class KeypointRCNN(FasterRCNN): >>> import torch >>> import torchvision >>> from torchvision.models.detection import KeypointRCNN - >>> from torchvision.models.detection.rpn import AnchorGenerator + >>> from torchvision.models.detection.anchor_utils import AnchorGenerator >>> >>> # load a pre-trained model for classification and return >>> # only the features diff --git a/torchvision/models/detection/mask_rcnn.py b/torchvision/models/detection/mask_rcnn.py index af6ac8fb17c..668d8ab8122 100644 --- a/torchvision/models/detection/mask_rcnn.py +++ b/torchvision/models/detection/mask_rcnn.py @@ -107,7 +107,7 @@ class MaskRCNN(FasterRCNN): >>> import torch >>> import torchvision >>> from torchvision.models.detection import MaskRCNN - >>> from torchvision.models.detection.rpn import AnchorGenerator + >>> from torchvision.models.detection.anchor_utils import AnchorGenerator >>> >>> # load a pre-trained model for classification and return >>> # only the features diff --git a/torchvision/models/detection/rpn.py b/torchvision/models/detection/rpn.py index 1d814ef232c..52cb28d48a8 100644 --- a/torchvision/models/detection/rpn.py +++ 
b/torchvision/models/detection/rpn.py @@ -11,6 +11,9 @@ from torch.jit.annotations import List, Optional, Dict, Tuple +# Import AnchorGenerator to keep compatibility. +from .anchor_utils import AnchorGenerator + @torch.jit.unused def _onnx_get_num_anchors_and_pre_nms_top_n(ob, orig_pre_nms_top_n): @@ -24,159 +27,6 @@ def _onnx_get_num_anchors_and_pre_nms_top_n(ob, orig_pre_nms_top_n): return num_anchors, pre_nms_top_n -class AnchorGenerator(nn.Module): - __annotations__ = { - "cell_anchors": Optional[List[torch.Tensor]], - "_cache": Dict[str, List[torch.Tensor]] - } - - """ - Module that generates anchors for a set of feature maps and - image sizes. - - The module support computing anchors at multiple sizes and aspect ratios - per feature map. This module assumes aspect ratio = height / width for - each anchor. - - sizes and aspect_ratios should have the same number of elements, and it should - correspond to the number of feature maps. - - sizes[i] and aspect_ratios[i] can have an arbitrary number of elements, - and AnchorGenerator will output a set of sizes[i] * aspect_ratios[i] anchors - per spatial location for feature map i. - - Arguments: - sizes (Tuple[Tuple[int]]): - aspect_ratios (Tuple[Tuple[float]]): - """ - - def __init__( - self, - sizes=((128, 256, 512),), - aspect_ratios=((0.5, 1.0, 2.0),), - ): - super(AnchorGenerator, self).__init__() - - if not isinstance(sizes[0], (list, tuple)): - # TODO change this - sizes = tuple((s,) for s in sizes) - if not isinstance(aspect_ratios[0], (list, tuple)): - aspect_ratios = (aspect_ratios,) * len(sizes) - - assert len(sizes) == len(aspect_ratios) - - self.sizes = sizes - self.aspect_ratios = aspect_ratios - self.cell_anchors = None - self._cache = {} - - # TODO: https://github.com/pytorch/pytorch/issues/26792 - # For every (aspect_ratios, scales) combination, output a zero-centered anchor with those values. - # (scales, aspect_ratios) are usually an element of zip(self.scales, self.aspect_ratios) - # This method assumes aspect ratio = height / width for an anchor. - def generate_anchors(self, scales, aspect_ratios, dtype=torch.float32, device="cpu"): - # type: (List[int], List[float], int, Device) -> Tensor # noqa: F821 - scales = torch.as_tensor(scales, dtype=dtype, device=device) - aspect_ratios = torch.as_tensor(aspect_ratios, dtype=dtype, device=device) - h_ratios = torch.sqrt(aspect_ratios) - w_ratios = 1 / h_ratios - - ws = (w_ratios[:, None] * scales[None, :]).view(-1) - hs = (h_ratios[:, None] * scales[None, :]).view(-1) - - base_anchors = torch.stack([-ws, -hs, ws, hs], dim=1) / 2 - return base_anchors.round() - - def set_cell_anchors(self, dtype, device): - # type: (int, Device) -> None # noqa: F821 - if self.cell_anchors is not None: - cell_anchors = self.cell_anchors - assert cell_anchors is not None - # suppose that all anchors have the same device - # which is a valid assumption in the current state of the codebase - if cell_anchors[0].device == device: - return - - cell_anchors = [ - self.generate_anchors( - sizes, - aspect_ratios, - dtype, - device - ) - for sizes, aspect_ratios in zip(self.sizes, self.aspect_ratios) - ] - self.cell_anchors = cell_anchors - - def num_anchors_per_location(self): - return [len(s) * len(a) for s, a in zip(self.sizes, self.aspect_ratios)] - - # For every combination of (a, (g, s), i) in (self.cell_anchors, zip(grid_sizes, strides), 0:2), - # output g[i] anchors that are s[i] distance apart in direction i, with the same dimensions as a. 
- def grid_anchors(self, grid_sizes, strides): - # type: (List[List[int]], List[List[Tensor]]) -> List[Tensor] - anchors = [] - cell_anchors = self.cell_anchors - assert cell_anchors is not None - assert len(grid_sizes) == len(strides) == len(cell_anchors) - - for size, stride, base_anchors in zip( - grid_sizes, strides, cell_anchors - ): - grid_height, grid_width = size - stride_height, stride_width = stride - device = base_anchors.device - - # For output anchor, compute [x_center, y_center, x_center, y_center] - shifts_x = torch.arange( - 0, grid_width, dtype=torch.float32, device=device - ) * stride_width - shifts_y = torch.arange( - 0, grid_height, dtype=torch.float32, device=device - ) * stride_height - shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x) - shift_x = shift_x.reshape(-1) - shift_y = shift_y.reshape(-1) - shifts = torch.stack((shift_x, shift_y, shift_x, shift_y), dim=1) - - # For every (base anchor, output anchor) pair, - # offset each zero-centered base anchor by the center of the output anchor. - anchors.append( - (shifts.view(-1, 1, 4) + base_anchors.view(1, -1, 4)).reshape(-1, 4) - ) - - return anchors - - def cached_grid_anchors(self, grid_sizes, strides): - # type: (List[List[int]], List[List[Tensor]]) -> List[Tensor] - key = str(grid_sizes) + str(strides) - if key in self._cache: - return self._cache[key] - anchors = self.grid_anchors(grid_sizes, strides) - self._cache[key] = anchors - return anchors - - def forward(self, image_list, feature_maps): - # type: (ImageList, List[Tensor]) -> List[Tensor] - grid_sizes = list([feature_map.shape[-2:] for feature_map in feature_maps]) - image_size = image_list.tensors.shape[-2:] - dtype, device = feature_maps[0].dtype, feature_maps[0].device - strides = [[torch.tensor(image_size[0] // g[0], dtype=torch.int64, device=device), - torch.tensor(image_size[1] // g[1], dtype=torch.int64, device=device)] for g in grid_sizes] - self.set_cell_anchors(dtype, device) - anchors_over_all_feature_maps = self.cached_grid_anchors(grid_sizes, strides) - anchors = torch.jit.annotate(List[List[torch.Tensor]], []) - for i, (image_height, image_width) in enumerate(image_list.image_sizes): - anchors_in_image = [] - for anchors_per_feature_map in anchors_over_all_feature_maps: - anchors_in_image.append(anchors_per_feature_map) - anchors.append(anchors_in_image) - anchors = [torch.cat(anchors_per_image) for anchors_per_image in anchors] - # Clear the cache in case that memory leaks. - self._cache.clear() - return anchors - - class RPNHead(nn.Module): """ Adds a simple RPN Head with classification and regression heads From 8e0804d940978154d3087d8642d3a8b726fc593c Mon Sep 17 00:00:00 2001 From: Hans Gaiser Date: Fri, 24 Jan 2020 09:44:52 +0100 Subject: [PATCH 03/31] Move box similarity to Matcher. 
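
Note (not part of the diff): with this change the Matcher computes the box
similarity itself, so call sites hand it raw boxes instead of a precomputed
IoU matrix. A rough sketch of the intended call, using only names that appear
in this patch (the threshold values are illustrative):

    import torch
    from torchvision.models.detection import _utils as det_utils

    matcher = det_utils.Matcher(high_threshold=0.5, low_threshold=0.4,
                                allow_low_quality_matches=True)

    gt_boxes = torch.tensor([[0., 0., 10., 10.]])      # M x 4 ground-truth boxes
    anchors = torch.tensor([[0., 0., 9., 9.],          # N x 4 anchors
                            [50., 50., 60., 60.]])

    # before this patch: matcher(box_ops.box_iou(gt_boxes, anchors))
    # after this patch the IoU matrix is computed inside __call__:
    matched_idxs = matcher(gt_boxes, anchors)          # shape (N,); -1/-2 mark unmatched anchors
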
--- torchvision/models/detection/_utils.py | 16 ++++++++++++++-- torchvision/models/detection/rpn.py | 5 +---- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/torchvision/models/detection/_utils.py b/torchvision/models/detection/_utils.py index cf4567daa8d..ca9ad28982b 100644 --- a/torchvision/models/detection/_utils.py +++ b/torchvision/models/detection/_utils.py @@ -5,6 +5,8 @@ from torch import Tensor import torchvision +from torchvision.ops import boxes as box_ops + class BalancedPositiveNegativeSampler(object): """ @@ -240,7 +242,11 @@ class Matcher(object): 'BETWEEN_THRESHOLDS': int, } - def __init__(self, high_threshold, low_threshold, allow_low_quality_matches=False): + def __init__(self, + high_threshold, + low_threshold, + allow_low_quality_matches=False, + box_similarity=None): # type: (float, float, bool) -> None """ Args: @@ -262,7 +268,11 @@ def __init__(self, high_threshold, low_threshold, allow_low_quality_matches=Fals self.low_threshold = low_threshold self.allow_low_quality_matches = allow_low_quality_matches - def __call__(self, match_quality_matrix): + if box_similarity is None: + box_similarity = box_ops.box_iou + self.box_similarity = box_similarity + + def __call__(self, gt_boxes, anchors_per_image): """ Args: match_quality_matrix (Tensor[float]): an MxN tensor, containing the @@ -273,6 +283,8 @@ def __call__(self, match_quality_matrix): [0, M - 1] or a negative value indicating that prediction i could not be matched. """ + match_quality_matrix = self.box_similarity(gt_boxes, anchors_per_image) + if match_quality_matrix.numel() == 0: # empty targets or proposals not supported during training if match_quality_matrix.shape[0] == 0: diff --git a/torchvision/models/detection/rpn.py b/torchvision/models/detection/rpn.py index 52cb28d48a8..a08b778d076 100644 --- a/torchvision/models/detection/rpn.py +++ b/torchvision/models/detection/rpn.py @@ -147,9 +147,6 @@ def __init__(self, self.head = head self.box_coder = det_utils.BoxCoder(weights=(1.0, 1.0, 1.0, 1.0)) - # used during training - self.box_similarity = box_ops.box_iou - self.proposal_matcher = det_utils.Matcher( fg_iou_thresh, bg_iou_thresh, @@ -189,7 +186,7 @@ def assign_targets_to_anchors(self, anchors, targets): labels_per_image = torch.zeros((anchors_per_image.shape[0],), dtype=torch.float32, device=device) else: match_quality_matrix = box_ops.box_iou(gt_boxes, anchors_per_image) - matched_idxs = self.proposal_matcher(match_quality_matrix) + matched_idxs = self.proposal_matcher(gt_boxes, anchors_per_image) # get the targets corresponding GT for each proposal # NB: need to clamp the indices because we can have a single # GT in the image, and matched_idxs can be -2, which goes From ad531941dbeef642fc634b9077397db5871e17d1 Mon Sep 17 00:00:00 2001 From: Hans Gaiser Date: Fri, 24 Jan 2020 09:45:27 +0100 Subject: [PATCH 04/31] Expose extra blocks in FPN. 
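
Usage note (not part of the diff): resnet_fpn_backbone and BackboneWithFPN now
accept an optional extra_blocks argument; when it is omitted, the previous
behaviour (appending a LastLevelMaxPool block) is kept. A minimal sketch,
assuming LastLevelP6P7 from torchvision.ops.feature_pyramid_network as used
later in this series:

    from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
    from torchvision.ops.feature_pyramid_network import LastLevelP6P7

    # default: unchanged, a LastLevelMaxPool level is appended to the FPN
    backbone = resnet_fpn_backbone('resnet50', pretrained=False)

    # RetinaNet-style: replace that level with P6/P7 convolutions instead
    backbone = resnet_fpn_backbone('resnet50', pretrained=False,
                                   extra_blocks=LastLevelP6P7(256, 256))
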
--- .../models/detection/backbone_utils.py | 24 +++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/torchvision/models/detection/backbone_utils.py b/torchvision/models/detection/backbone_utils.py index 8d5d56404e8..b64b1eb030f 100644 --- a/torchvision/models/detection/backbone_utils.py +++ b/torchvision/models/detection/backbone_utils.py @@ -25,13 +25,17 @@ class BackboneWithFPN(nn.Module): Attributes: out_channels (int): the number of channels in the FPN """ - def __init__(self, backbone, return_layers, in_channels_list, out_channels): + def __init__(self, backbone, return_layers, in_channels_list, out_channels, extra_blocks=None): super(BackboneWithFPN, self).__init__() + + if extra_blocks is None: + extra_blocks = LastLevelMaxPool() + self.body = IntermediateLayerGetter(backbone, return_layers=return_layers) self.fpn = FeaturePyramidNetwork( in_channels_list=in_channels_list, out_channels=out_channels, - extra_blocks=LastLevelMaxPool(), + extra_blocks=extra_blocks, ) self.out_channels = out_channels @@ -41,7 +45,16 @@ def forward(self, x): return x -def resnet_fpn_backbone(backbone_name, pretrained, norm_layer=misc_nn_ops.FrozenBatchNorm2d, trainable_layers=3): +def resnet_fpn_backbone( + backbone_name, + pretrained, + norm_layer=misc_nn_ops.FrozenBatchNorm2d, + trainable_layers=3, + extra_blocks=None +): + backbone = resnet.__dict__[backbone_name]( + pretrained=pretrained, + norm_layer=norm_layer) """ Constructs a specified ResNet backbone with FPN on top. Freezes the specified number of layers in the backbone. @@ -82,6 +95,9 @@ def resnet_fpn_backbone(backbone_name, pretrained, norm_layer=misc_nn_ops.Frozen if all([not name.startswith(layer) for layer in layers_to_train]): parameter.requires_grad_(False) + if extra_blocks is None: + extra_blocks = LastLevelMaxPool() + return_layers = {'layer1': '0', 'layer2': '1', 'layer3': '2', 'layer4': '3'} in_channels_stage2 = backbone.inplanes // 8 @@ -92,4 +108,4 @@ def resnet_fpn_backbone(backbone_name, pretrained, norm_layer=misc_nn_ops.Frozen in_channels_stage2 * 8, ] out_channels = 256 - return BackboneWithFPN(backbone, return_layers, in_channels_list, out_channels) + return BackboneWithFPN(backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks) From 2a5a5be11258ef27ebf3442bb7d51039953cd301 Mon Sep 17 00:00:00 2001 From: Hans Gaiser Date: Fri, 24 Jan 2020 09:47:38 +0100 Subject: [PATCH 05/31] Expose retinanet in __init__.py. --- torchvision/models/detection/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchvision/models/detection/__init__.py b/torchvision/models/detection/__init__.py index 3b2683ae328..78ddc61c144 100644 --- a/torchvision/models/detection/__init__.py +++ b/torchvision/models/detection/__init__.py @@ -1,3 +1,4 @@ from .faster_rcnn import * from .mask_rcnn import * from .keypoint_rcnn import * +from .retinanet import * From 49e990cfea39ffc2983dc1acb16c7c95257e373c Mon Sep 17 00:00:00 2001 From: Hans Gaiser Date: Fri, 24 Jan 2020 09:48:19 +0100 Subject: [PATCH 06/31] Use P6 and P7 in FPN for retinanet. 
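
Background note (not part of the diff): P6 and P7 are the two extra pyramid
levels from the RetinaNet paper, produced by stride-2 convolutions on top of
the FPN output, giving feature maps at strides 64 and 128. A quick, hedged
check of the resulting pyramid (shapes assume a 512x512 input and the backbone
built as in this patch):

    import torch
    from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
    from torchvision.ops.feature_pyramid_network import LastLevelP6P7

    backbone = resnet_fpn_backbone('resnet50', pretrained=False,
                                   extra_blocks=LastLevelP6P7(256, 256))
    features = backbone(torch.rand(1, 3, 512, 512))
    for name, feat in features.items():
        print(name, tuple(feat.shape))
    # six maps at strides 4, 8, 16, 32, 64 and 128, i.e. spatial sizes
    # 128, 64, 32, 16, 8 and 4 -- the model later drops the stride-4 level.
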
--- torchvision/models/detection/retinanet.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchvision/models/detection/retinanet.py b/torchvision/models/detection/retinanet.py index ac320177501..24080d2ed08 100644 --- a/torchvision/models/detection/retinanet.py +++ b/torchvision/models/detection/retinanet.py @@ -9,6 +9,7 @@ from .rpn import AnchorGenerator from .transform import GeneralizedRCNNTransform from .backbone_utils import resnet_fpn_backbone +from ...ops.feature_pyramid_network import LastLevelP6P7 __all__ = [ @@ -380,7 +381,7 @@ def retinanet_resnet50_fpn(pretrained=False, progress=True, if pretrained: # no need to download the backbone if pretrained is set pretrained_backbone = False - backbone = resnet_fpn_backbone('resnet50', pretrained_backbone) + backbone = resnet_fpn_backbone('resnet50', pretrained_backbone, extra_blocks=LastLevelP6P7(256, 256)) model = RetinaNet(backbone, num_classes, **kwargs) if pretrained: state_dict = load_state_dict_from_url(model_urls['retinanet_resnet50_fpn_coco'], From b5966eb8054365756d52d9c8ebb3569d6f354ca4 Mon Sep 17 00:00:00 2001 From: Hans Gaiser Date: Fri, 24 Jan 2020 09:50:32 +0100 Subject: [PATCH 07/31] Use parameters from retinanet for anchor generation. --- torchvision/models/detection/retinanet.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/torchvision/models/detection/retinanet.py b/torchvision/models/detection/retinanet.py index 24080d2ed08..2e40aa88b34 100644 --- a/torchvision/models/detection/retinanet.py +++ b/torchvision/models/detection/retinanet.py @@ -6,7 +6,8 @@ from ..utils import load_state_dict_from_url -from .rpn import AnchorGenerator +from . import _utils as det_utils +from .anchor_utils import AnchorGenerator from .transform import GeneralizedRCNNTransform from .backbone_utils import resnet_fpn_backbone from ...ops.feature_pyramid_network import LastLevelP6P7 @@ -191,7 +192,7 @@ class RetinaNet(nn.Module): >>> import torch >>> import torchvision >>> from torchvision.models.detection import RetinaNet - >>> from torchvision.models.detection.rpn import AnchorGenerator + >>> from torchvision.models.detection.anchor_utils import AnchorGenerator >>> # load a pre-trained model for classification and return >>> # only the features >>> backbone = torchvision.models.mobilenet_v2(pretrained=True).features @@ -205,8 +206,8 @@ class RetinaNet(nn.Module): >>> # ratios. 
We have a Tuple[Tuple[int]] because each feature >>> # map could potentially have different sizes and >>> # aspect ratios - >>> anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256, 512),), - >>> aspect_ratios=((0.5, 1.0, 2.0),)) + >>> anchor_generator = AnchorGenerator(sizes=[[x, x * 2 ** (1.0 / 3), x * 2 ** (2.0 / 3)] for x in [32, 64, 128, 256, 512]], + >>> aspect_ratios=[[0.5, 1.0, 2.0]] * 5) >>> >>> # put the pieces together inside a RetinaNet model >>> model = RetinaNet(backbone, @@ -236,14 +237,16 @@ def __init__(self, backbone, num_classes, assert isinstance(anchor_generator, (AnchorGenerator, type(None))) if anchor_generator is None: - anchor_sizes = ((32,), (64,), (128,), (256,), (512,)) + # TODO: Set correct default values + anchor_sizes = [[x, x * 2 ** (1.0 / 3), x * 2 ** (2.0 / 3)] for x in [32, 64, 128, 256, 512]] aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes) - self.anchor_generator = AnchorGenerator( + anchor_generator = AnchorGenerator( anchor_sizes, aspect_ratios ) + self.anchor_generator = anchor_generator if head is None: - head = RetinaNetHead(backbone.out_channels, num_classes, anchor_generator.num_anchors_per_location()) + head = RetinaNetHead(backbone.out_channels, num_classes, anchor_generator.num_anchors_per_location()[0]) self.head = head if image_mean is None: From aab1b2834726c2ca3c5d2ddfb5b4bd467b1f6c05 Mon Sep 17 00:00:00 2001 From: Hans Gaiser Date: Fri, 24 Jan 2020 09:51:34 +0100 Subject: [PATCH 08/31] General fixes for retinanet model. --- torchvision/models/detection/retinanet.py | 80 ++++++++++++++++++----- 1 file changed, 64 insertions(+), 16 deletions(-) diff --git a/torchvision/models/detection/retinanet.py b/torchvision/models/detection/retinanet.py index 2e40aa88b34..d0f011d697a 100644 --- a/torchvision/models/detection/retinanet.py +++ b/torchvision/models/detection/retinanet.py @@ -29,19 +29,19 @@ class RetinaNetHead(nn.Module): """ def __init__(self, in_channels, num_anchors, num_classes): - super(RPNHead, self).__init__() + super(RetinaNetHead, self).__init__() self.classification_head = RetinaNetClassificationHead(in_channels, num_anchors, num_classes) self.regression_head = RetinaNetRegressionHead(in_channels, num_anchors) - def compute_loss(self, outputs, labels, matched_gt_boxes): + def compute_loss(self, targets, head_outputs, anchors, matched_idxs): return { - 'classification': self.classification_head.compute_loss(outputs, targets, anchor_state), - 'regression': self.regression_head.compute_loss(outputs, targets, anchor_state), + 'classification': self.classification_head.compute_loss(targets, head_outputs, anchors, matched_idxs), + 'bbox_reg': self.regression_head.compute_loss(targets, head_outputs, anchors, matched_idxs), } def forward(self, x): - logits = [self.classification_head(feature, targets) for feature in x] - bbox_reg = [self.regression_head(feature, targets) for feature in x] + logits = [self.classification_head(feature) for feature in x] + bbox_reg = [self.regression_head(feature) for feature in x] return dict(logits=logits, bbox_reg=bbox_reg) @@ -68,7 +68,7 @@ def __init__(self, in_channels, num_anchors, num_classes, prior_probability=0.01 torch.nn.init.normal_(l.weight, std=0.01) torch.nn.init.constant_(l.bias, 0) - def compute_loss(self, outputs, labels, matched_gt_boxes): + def compute_loss(self, targets, head_outputs, anchors, matched_idxs): # TODO Implement focal loss, is there an existing function for this? 
return 0 @@ -114,9 +114,35 @@ def __init__(self, in_channels, num_anchors): torch.nn.init.normal_(l.weight, std=0.01) torch.nn.init.constant_(l.bias, 0) - def compute_loss(self, outputs, labels, matched_gt_boxes): - # TODO Use SmoothL1 loss for regression, or just L1 like in rpn.py ? - return 0 + self.box_coder = det_utils.BoxCoder(weights=(1.0, 1.0, 1.0, 1.0)) + + def compute_loss(self, targets, head_outputs, anchors, matched_idxs): + loss = [] + + predicted_regression = head_outputs['bbox_reg'][0] + for targets_per_image, predicted_regression_per_image, anchors_per_image, matched_idxs_per_image in zip(targets, predicted_regression, anchors, matched_idxs): + # get the targets corresponding GT for each proposal + # NB: need to clamp the indices because we can have a single + # GT in the image, and matched_idxs can be -2, which goes + # out of bounds + matched_gt_boxes_per_image = targets_per_image['boxes'][matched_idxs_per_image.clamp(min=0)] + + # determine only the foreground indices, ignore the rest + foreground_idxs_per_image = matched_idxs_per_image >= 0 + + # select only the foreground boxes + matched_gt_boxes_per_image = matched_gt_boxes_per_image[foreground_idxs_per_image, :] + print(predicted_regression_per_image.shape) + predicted_regression_per_image = predicted_regression_per_image['bbox_reg'][foreground_idxs_per_image, :] + anchors_per_image = anchors_per_image[foreground_idxs_per_image, :] + + # compute the regression targets + target_regression = self.box_coder.encode(matched_gt_boxes_per_image, anchors_per_image) + + # compute the loss + loss.append(torch.nn.SmoothL1Loss()(predicted_regression_per_image, target_regression)) + + return sum(loss) / len(loss) def forward(self, x): all_bbox_regression = [] @@ -224,15 +250,18 @@ def __init__(self, backbone, num_classes, image_mean=None, image_std=None, # Anchor parameters anchor_generator=None, head=None, + proposal_matcher=None, pre_nms_top_n=1000, post_nms_top_n=1000, nms_thresh=0.5, fg_iou_thresh=0.5, bg_iou_thresh=0.4): + super(RetinaNet, self).__init__() if not hasattr(backbone, "out_channels"): raise ValueError( "backbone should contain an attribute out_channels " "specifying the number of output channels (assumed to be the " "same for all the levels)") + self.backbone = backbone assert isinstance(anchor_generator, (AnchorGenerator, type(None))) @@ -249,6 +278,14 @@ def __init__(self, backbone, num_classes, head = RetinaNetHead(backbone.out_channels, num_classes, anchor_generator.num_anchors_per_location()[0]) self.head = head + if proposal_matcher is None: + proposal_matcher = det_utils.Matcher( + fg_iou_thresh, + bg_iou_thresh, + allow_low_quality_matches=True, + ) + self.proposal_matcher = proposal_matcher + if image_mean is None: image_mean = [0.485, 0.456, 0.406] if image_std is None: @@ -263,6 +300,13 @@ def eager_outputs(self, losses, detections): return detections + def compute_loss(self, targets, head_outputs, anchors): + matched_idxs = [] + for anchors_per_image, targets_per_image in zip(anchors, targets): + matched_idxs.append(self.proposal_matcher(targets_per_image["boxes"], anchors_per_image)) + + return self.head.compute_loss(targets, head_outputs, anchors, matched_idxs) + def forward(self, images, targets=None): # type: (List[Tensor], Optional[List[Dict[str, Tensor]]]) """ @@ -295,8 +339,16 @@ def forward(self, images, targets=None): if isinstance(features, torch.Tensor): features = OrderedDict([('0', features)]) + # TODO: Do we want a list or a dict? 
+ features = list(features.values()) + + # TODO: Is there a better way to check for [P3, P4, P5, P6, P7]? + if len(features) == 6: + # skip P2 because it generates too many anchors + features = features[1:] + # compute the retinanet heads outputs using the features - head_outputs = self.head(images, features, targets) + head_outputs = self.head(features) # create the set of anchors anchors = self.anchor_generator(images, features) @@ -307,11 +359,7 @@ def forward(self, images, targets=None): assert targets is not None # compute the losses - # TODO: Move necessary functions out of rpn.RegionProposalNetwork to a class or function - # so that we can use it here and in rpn.RegionProposalNetwork - labels, matched_gt_boxes = self.assign_targets_to_anchors(anchors, targets) - regression_targets = self.box_coder.encode(matched_gt_boxes, anchors) - losses = self.head.compute_loss(head_outputs, labels, matched_gt_boxes) + losses = self.compute_loss(targets, head_outputs, anchors) else: # compute the detections # TODO: Implement postprocess_detections From c07811434093dba966530f61bf3957cb2384d366 Mon Sep 17 00:00:00 2001 From: Hans Gaiser Date: Fri, 31 Jan 2020 15:52:51 +0100 Subject: [PATCH 09/31] Implement loss for retinanet heads. --- torchvision/models/detection/_utils.py | 19 ++--- torchvision/models/detection/anchor_utils.py | 3 + torchvision/models/detection/retinanet.py | 79 ++++++++++++++++---- torchvision/ops/__init__.py | 4 +- torchvision/ops/focal_loss.py | 47 ++++++++++++ 5 files changed, 129 insertions(+), 23 deletions(-) create mode 100644 torchvision/ops/focal_loss.py diff --git a/torchvision/models/detection/_utils.py b/torchvision/models/detection/_utils.py index ca9ad28982b..36ba47e2545 100644 --- a/torchvision/models/detection/_utils.py +++ b/torchvision/models/detection/_utils.py @@ -1,7 +1,7 @@ import math import torch -from torch.jit.annotations import List, Tuple +from torch.jit.annotations import List, Tuple, Optional from torch import Tensor import torchvision @@ -245,8 +245,7 @@ class Matcher(object): def __init__(self, high_threshold, low_threshold, - allow_low_quality_matches=False, - box_similarity=None): + allow_low_quality_matches=False): # type: (float, float, bool) -> None """ Args: @@ -268,22 +267,24 @@ def __init__(self, self.low_threshold = low_threshold self.allow_low_quality_matches = allow_low_quality_matches - if box_similarity is None: - box_similarity = box_ops.box_iou - self.box_similarity = box_similarity + # if box_similarity is None: + # box_similarity = box_ops.box_iou + # self.box_similarity = box_similarity def __call__(self, gt_boxes, anchors_per_image): """ Args: - match_quality_matrix (Tensor[float]): an MxN tensor, containing the - pairwise quality between M ground-truth elements and N predicted elements. + gt_boxes (Tensor[float]): an Mx4 tensor, containing M detections. + + anchors_per_image (Tensor[float]): an Mx4 tensor, containing + the anchors for a specific image. Returns: matches (Tensor[int64]): an N tensor where N[i] is a matched gt in [0, M - 1] or a negative value indicating that prediction i could not be matched. 
""" - match_quality_matrix = self.box_similarity(gt_boxes, anchors_per_image) + match_quality_matrix = box_ops.box_iou(gt_boxes, anchors_per_image) # self.box_similarity(gt_boxes, anchors_per_image) if match_quality_matrix.numel() == 0: # empty targets or proposals not supported during training diff --git a/torchvision/models/detection/anchor_utils.py b/torchvision/models/detection/anchor_utils.py index 80171ef20d4..95418fca34a 100644 --- a/torchvision/models/detection/anchor_utils.py +++ b/torchvision/models/detection/anchor_utils.py @@ -1,7 +1,10 @@ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. import torch +import torchvision from torch import nn +from torch.jit.annotations import List, Optional, Dict + class AnchorGenerator(nn.Module): __annotations__ = { diff --git a/torchvision/models/detection/retinanet.py b/torchvision/models/detection/retinanet.py index d0f011d697a..d94648de6f9 100644 --- a/torchvision/models/detection/retinanet.py +++ b/torchvision/models/detection/retinanet.py @@ -1,8 +1,10 @@ +import math from collections import OrderedDict import torch from torch import nn import torch.nn.functional as F +from torch.jit.annotations import Dict, List, Tuple from ..utils import load_state_dict_from_url @@ -11,6 +13,7 @@ from .transform import GeneralizedRCNNTransform from .backbone_utils import resnet_fpn_backbone from ...ops.feature_pyramid_network import LastLevelP6P7 +from ...ops import sigmoid_focal_loss __all__ = [ @@ -36,13 +39,13 @@ def __init__(self, in_channels, num_anchors, num_classes): def compute_loss(self, targets, head_outputs, anchors, matched_idxs): return { 'classification': self.classification_head.compute_loss(targets, head_outputs, anchors, matched_idxs), - 'bbox_reg': self.regression_head.compute_loss(targets, head_outputs, anchors, matched_idxs), + 'bbox_regression': self.regression_head.compute_loss(targets, head_outputs, anchors, matched_idxs), } def forward(self, x): - logits = [self.classification_head(feature) for feature in x] + cls_logits = [self.classification_head(feature) for feature in x] bbox_reg = [self.regression_head(feature) for feature in x] - return dict(logits=logits, bbox_reg=bbox_reg) + return dict(cls_logits=cls_logits, bbox_reg=bbox_reg) class RetinaNetClassificationHead(nn.Module): @@ -68,9 +71,48 @@ def __init__(self, in_channels, num_anchors, num_classes, prior_probability=0.01 torch.nn.init.normal_(l.weight, std=0.01) torch.nn.init.constant_(l.bias, 0) + self.cls_logits = nn.Conv2d(in_channels, num_anchors * num_classes, kernel_size=3, stride=1, padding=1) + torch.nn.init.normal_(self.cls_logits.weight, std=0.01) + torch.nn.init.constant_(self.cls_logits.bias, -math.log((1 - prior_probability) / prior_probability)) + + self.num_classes = num_classes + self.num_anchors = num_anchors + def compute_loss(self, targets, head_outputs, anchors, matched_idxs): - # TODO Implement focal loss, is there an existing function for this? - return 0 + loss = [] + + def permute_classification(tensor): + """ Permute classification output from (N, A * K, H, W) to (N, HWA, K). 
""" + N, _, H, W = tensor.shape + tensor = tensor.view(N, -1, self.num_classes, H, W) + tensor = tensor.permute(0, 3, 4, 1, 2) + tensor = tensor.reshape(N, -1, self.num_classes) # Size=(N, HWA, 4) + return tensor + + predicted_classification = head_outputs['cls_logits'] + predicted_classification = [permute_classification(cls) for cls in predicted_classification] + predicted_classification = torch.cat(predicted_classification, dim=1) + + for targets_per_image, predicted_classification_per_image, anchors_per_image, matched_idxs_per_image in zip(targets, predicted_classification, anchors, matched_idxs): + # determine only the foreground + foreground_idxs_per_image = matched_idxs_per_image >= 0 + num_foreground = foreground_idxs_per_image.sum() + + # create the target classification + gt_classes_target = torch.zeros_like(predicted_classification_per_image) + gt_classes_target[foreground_idxs_per_image, targets_per_image['labels'][matched_idxs_per_image[foreground_idxs_per_image]]] = 1 + + # find indices for which anchors should be ignored + valid_idxs_per_image = matched_idxs_per_image != det_utils.Matcher.BETWEEN_THRESHOLDS + + # compute the classification loss + loss.append(sigmoid_focal_loss_jit( + predicted_classification_per_image[valid_idxs_per_image], + gt_classes_target[valid_idxs_per_image], + reduction='sum', + ) / max(1, num_foreground)) + + return sum(loss) / len(loss) def forward(self, x): all_cls_logits = [] @@ -112,14 +154,25 @@ def __init__(self, in_channels, num_anchors): for l in self.children(): torch.nn.init.normal_(l.weight, std=0.01) - torch.nn.init.constant_(l.bias, 0) + torch.nn.init.zeros_(l.bias) self.box_coder = det_utils.BoxCoder(weights=(1.0, 1.0, 1.0, 1.0)) def compute_loss(self, targets, head_outputs, anchors, matched_idxs): loss = [] - predicted_regression = head_outputs['bbox_reg'][0] + def permute_bbox_reg(tensor): + """ Permute bbox regression output from (N, 4 * A, H, W) to (N, HWA, 4). 
""" + N, _, H, W = tensor.shape + tensor = tensor.view(N, -1, 4, H, W) + tensor = tensor.permute(0, 3, 4, 1, 2) + tensor = tensor.reshape(N, -1, 4) # Size=(N, HWA, 4) + return tensor + + predicted_regression = head_outputs['bbox_reg'] + predicted_regression = [permute_bbox_reg(reg) for reg in predicted_regression] + predicted_regression = torch.cat(predicted_regression, dim=1) + for targets_per_image, predicted_regression_per_image, anchors_per_image, matched_idxs_per_image in zip(targets, predicted_regression, anchors, matched_idxs): # get the targets corresponding GT for each proposal # NB: need to clamp the indices because we can have a single @@ -129,20 +182,20 @@ def compute_loss(self, targets, head_outputs, anchors, matched_idxs): # determine only the foreground indices, ignore the rest foreground_idxs_per_image = matched_idxs_per_image >= 0 + num_foreground = foreground_idxs_per_image.sum() # select only the foreground boxes matched_gt_boxes_per_image = matched_gt_boxes_per_image[foreground_idxs_per_image, :] - print(predicted_regression_per_image.shape) - predicted_regression_per_image = predicted_regression_per_image['bbox_reg'][foreground_idxs_per_image, :] + predicted_regression_per_image = predicted_regression_per_image[foreground_idxs_per_image, :] anchors_per_image = anchors_per_image[foreground_idxs_per_image, :] # compute the regression targets - target_regression = self.box_coder.encode(matched_gt_boxes_per_image, anchors_per_image) + target_regression = self.box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image) # compute the loss - loss.append(torch.nn.SmoothL1Loss()(predicted_regression_per_image, target_regression)) + loss.append(F.smooth_l1_loss((bbox_regression_per_image, target_regression) / max(1, num_foreground), reduction='sum') - return sum(loss) / len(loss) + return sum(loss) / max(1, len(loss)) def forward(self, x): all_bbox_regression = [] @@ -275,7 +328,7 @@ def __init__(self, backbone, num_classes, self.anchor_generator = anchor_generator if head is None: - head = RetinaNetHead(backbone.out_channels, num_classes, anchor_generator.num_anchors_per_location()[0]) + head = RetinaNetHead(backbone.out_channels, anchor_generator.num_anchors_per_location()[0], num_classes) self.head = head if proposal_matcher is None: diff --git a/torchvision/ops/__init__.py b/torchvision/ops/__init__.py index 4f94ac447c8..bd82a0d5ed9 100644 --- a/torchvision/ops/__init__.py +++ b/torchvision/ops/__init__.py @@ -8,6 +8,7 @@ from .ps_roi_pool import ps_roi_pool, PSRoIPool from .poolers import MultiScaleRoIAlign from .feature_pyramid_network import FeaturePyramidNetwork +from .focal_loss import sigmoid_focal_loss from ._register_onnx_ops import _register_custom_op @@ -19,5 +20,6 @@ 'clip_boxes_to_image', 'box_convert', 'box_area', 'box_iou', 'generalized_box_iou', 'roi_align', 'RoIAlign', 'roi_pool', 'RoIPool', '_new_empty_tensor', 'ps_roi_align', 'PSRoIAlign', 'ps_roi_pool', - 'PSRoIPool', 'MultiScaleRoIAlign', 'FeaturePyramidNetwork' + 'PSRoIPool', 'MultiScaleRoIAlign', 'FeaturePyramidNetwork', + 'sigmoid_focal_loss' ] diff --git a/torchvision/ops/focal_loss.py b/torchvision/ops/focal_loss.py new file mode 100644 index 00000000000..f19bec3c2f1 --- /dev/null +++ b/torchvision/ops/focal_loss.py @@ -0,0 +1,47 @@ +import torch +import torch.nn.functional as F + +def sigmoid_focal_loss( + inputs, + targets, + alpha: float = 0.25, + gamma: float = 2, + reduction: str = "none", +): + """ + Original implementation from 
https://github.com/facebookresearch/fvcore/blob/master/fvcore/nn/focal_loss.py . + Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002. + Args: + inputs: A float tensor of arbitrary shape. + The predictions for each example. + targets: A float tensor with the same shape as inputs. Stores the binary + classification label for each element in inputs + (0 for the negative class and 1 for the positive class). + alpha: (optional) Weighting factor in range (0,1) to balance + positive vs negative examples or -1 for ignore. Default = 0.25 + gamma: Exponent of the modulating factor (1 - p_t) to + balance easy vs hard examples. + reduction: 'none' | 'mean' | 'sum' + 'none': No reduction will be applied to the output. + 'mean': The output will be averaged. + 'sum': The output will be summed. + Returns: + Loss tensor with the reduction option applied. + """ + p = torch.sigmoid(inputs) + ce_loss = F.binary_cross_entropy_with_logits( + inputs, targets, reduction="none" + ) + p_t = p * targets + (1 - p) * (1 - targets) + loss = ce_loss * ((1 - p_t) ** gamma) + + if alpha >= 0: + alpha_t = alpha * targets + (1 - alpha) * (1 - targets) + loss = alpha_t * loss + + if reduction == "mean": + loss = loss.mean() + elif reduction == "sum": + loss = loss.sum() + + return loss From eae4ee5ac67117663e718bd078e3abfbb5340aa1 Mon Sep 17 00:00:00 2001 From: Hans Gaiser Date: Sun, 2 Feb 2020 11:55:17 +0100 Subject: [PATCH 10/31] Output reshaped outputs from retinanet heads. --- torchvision/models/detection/retinanet.py | 43 +++++++---------------- 1 file changed, 12 insertions(+), 31 deletions(-) diff --git a/torchvision/models/detection/retinanet.py b/torchvision/models/detection/retinanet.py index d94648de6f9..6cb59feb26a 100644 --- a/torchvision/models/detection/retinanet.py +++ b/torchvision/models/detection/retinanet.py @@ -43,9 +43,10 @@ def compute_loss(self, targets, head_outputs, anchors, matched_idxs): } def forward(self, x): - cls_logits = [self.classification_head(feature) for feature in x] - bbox_reg = [self.regression_head(feature) for feature in x] - return dict(cls_logits=cls_logits, bbox_reg=bbox_reg) + return { + 'cls_logits': self.classification_head(x), + 'bbox_regression': self.regression_head(x) + } class RetinaNetClassificationHead(nn.Module): @@ -81,33 +82,23 @@ def __init__(self, in_channels, num_anchors, num_classes, prior_probability=0.01 def compute_loss(self, targets, head_outputs, anchors, matched_idxs): loss = [] - def permute_classification(tensor): - """ Permute classification output from (N, A * K, H, W) to (N, HWA, K). 
""" - N, _, H, W = tensor.shape - tensor = tensor.view(N, -1, self.num_classes, H, W) - tensor = tensor.permute(0, 3, 4, 1, 2) - tensor = tensor.reshape(N, -1, self.num_classes) # Size=(N, HWA, 4) - return tensor + cls_logits = head_outputs['cls_logits'] - predicted_classification = head_outputs['cls_logits'] - predicted_classification = [permute_classification(cls) for cls in predicted_classification] - predicted_classification = torch.cat(predicted_classification, dim=1) - - for targets_per_image, predicted_classification_per_image, anchors_per_image, matched_idxs_per_image in zip(targets, predicted_classification, anchors, matched_idxs): + for targets_per_image, cls_logits_per_image, anchors_per_image, matched_idxs_per_image in zip(targets, cls_logits, anchors, matched_idxs): # determine only the foreground foreground_idxs_per_image = matched_idxs_per_image >= 0 num_foreground = foreground_idxs_per_image.sum() # create the target classification - gt_classes_target = torch.zeros_like(predicted_classification_per_image) + gt_classes_target = torch.zeros_like(cls_logits_per_image) gt_classes_target[foreground_idxs_per_image, targets_per_image['labels'][matched_idxs_per_image[foreground_idxs_per_image]]] = 1 # find indices for which anchors should be ignored valid_idxs_per_image = matched_idxs_per_image != det_utils.Matcher.BETWEEN_THRESHOLDS # compute the classification loss - loss.append(sigmoid_focal_loss_jit( - predicted_classification_per_image[valid_idxs_per_image], + loss.append(sigmoid_focal_loss( + cls_logits_per_image[valid_idxs_per_image], gt_classes_target[valid_idxs_per_image], reduction='sum', ) / max(1, num_foreground)) @@ -161,19 +152,9 @@ def __init__(self, in_channels, num_anchors): def compute_loss(self, targets, head_outputs, anchors, matched_idxs): loss = [] - def permute_bbox_reg(tensor): - """ Permute bbox regression output from (N, 4 * A, H, W) to (N, HWA, 4). """ - N, _, H, W = tensor.shape - tensor = tensor.view(N, -1, 4, H, W) - tensor = tensor.permute(0, 3, 4, 1, 2) - tensor = tensor.reshape(N, -1, 4) # Size=(N, HWA, 4) - return tensor - - predicted_regression = head_outputs['bbox_reg'] - predicted_regression = [permute_bbox_reg(reg) for reg in predicted_regression] - predicted_regression = torch.cat(predicted_regression, dim=1) + bbox_regression = head_outputs['bbox_regression'] - for targets_per_image, predicted_regression_per_image, anchors_per_image, matched_idxs_per_image in zip(targets, predicted_regression, anchors, matched_idxs): + for targets_per_image, bbox_regression_per_image, anchors_per_image, matched_idxs_per_image in zip(targets, bbox_regression, anchors, matched_idxs): # get the targets corresponding GT for each proposal # NB: need to clamp the indices because we can have a single # GT in the image, and matched_idxs can be -2, which goes @@ -186,7 +167,7 @@ def permute_bbox_reg(tensor): # select only the foreground boxes matched_gt_boxes_per_image = matched_gt_boxes_per_image[foreground_idxs_per_image, :] - predicted_regression_per_image = predicted_regression_per_image[foreground_idxs_per_image, :] + bbox_regression_per_image = bbox_regression_per_image[foreground_idxs_per_image, :] anchors_per_image = anchors_per_image[foreground_idxs_per_image, :] # compute the regression targets From 3dac47705a2cb05912bded27831349d37045652a Mon Sep 17 00:00:00 2001 From: Hans Gaiser Date: Fri, 7 Feb 2020 13:55:53 +0100 Subject: [PATCH 11/31] Add postprocessing of detections. 
--- torchvision/models/detection/retinanet.py | 80 ++++++++++++++++++----- 1 file changed, 63 insertions(+), 17 deletions(-) diff --git a/torchvision/models/detection/retinanet.py b/torchvision/models/detection/retinanet.py index 6cb59feb26a..0c5da1fb6bf 100644 --- a/torchvision/models/detection/retinanet.py +++ b/torchvision/models/detection/retinanet.py @@ -239,9 +239,9 @@ class RetinaNet(nn.Module): maps. head (nn.Module): Module run on top of the feature pyramid. Defaults to a module containing a classification and regression module. - pre_nms_top_n (int): number of proposals to keep before applying NMS during testing. - post_nms_top_n (int): number of proposals to keep after applying NMS during testing. + score_thresh (float): Score threshold used for postprocessing the detections. nms_thresh (float): NMS threshold used for postprocessing the detections. + detections_per_img (int): Number of best detections to keep after NMS. fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be considered as positive during training. bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be @@ -285,8 +285,9 @@ def __init__(self, backbone, num_classes, # Anchor parameters anchor_generator=None, head=None, proposal_matcher=None, - pre_nms_top_n=1000, post_nms_top_n=1000, + score_thresh=0.5, nms_thresh=0.5, + detections_per_img=300, fg_iou_thresh=0.5, bg_iou_thresh=0.4): super(RetinaNet, self).__init__() @@ -300,7 +301,6 @@ def __init__(self, backbone, num_classes, assert isinstance(anchor_generator, (AnchorGenerator, type(None))) if anchor_generator is None: - # TODO: Set correct default values anchor_sizes = [[x, x * 2 ** (1.0 / 3), x * 2 ** (2.0 / 3)] for x in [32, 64, 128, 256, 512]] aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes) anchor_generator = AnchorGenerator( @@ -320,12 +320,18 @@ def __init__(self, backbone, num_classes, ) self.proposal_matcher = proposal_matcher + self.box_coder = det_utils.BoxCoder(weights=(1.0, 1.0, 1.0, 1.0)) + if image_mean is None: image_mean = [0.485, 0.456, 0.406] if image_std is None: image_std = [0.229, 0.224, 0.225] self.transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std) + self.score_thresh = score_thresh + self.nms_thresh = nms_thresh + self.detections_per_img = detections_per_img + @torch.jit.unused def eager_outputs(self, losses, detections): # type: (Dict[str, Tensor], List[Dict[str, Tensor]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]] @@ -341,6 +347,57 @@ def compute_loss(self, targets, head_outputs, anchors): return self.head.compute_loss(targets, head_outputs, anchors, matched_idxs) + def postprocess_detections(self, class_logits, box_regression, anchors, image_shapes): + # type: (Tensor, Tensor, List[Tensor], List[Tuple[int, int]]) + # TODO: Merge this with roi_heads.RoIHeads.postprocess_detections ? 
+ device = class_logits.device + num_classes = class_logits.shape[-1] + + scores = torch.sigmoid(class_logits) + + # create labels for each score + # the +1 is to make the labels identical to other object detection algorithms that treat background as label 0 + labels = torch.arange(num_classes, device=device) + 1 + labels = labels.view(1, -1).expand_as(scores) + + detections = [] + + for box_regression_per_image, scores_per_image, labels_per_image, anchors_per_image, image_shape in zip(box_regression, scores, labels, anchors, image_shapes): + boxes_per_image = self.box_coder.decode_single(box_regression_per_image, anchors_per_image) + boxes_per_image = box_ops.clip_boxes_to_image(boxes_per_image, image_shape) + + image_boxes = [] + image_scores = [] + image_labels = [] + + for class_index in range(num_classes): + # remove low scoring boxes + inds = torch.nonzero(scores_per_image[:, class_index] > self.score_thresh).squeeze(1) + boxes_per_class, scores_per_class, labels_per_class = boxes_per_image[inds], scores_per_image[inds, class_index], labels_per_image[inds, class_index] + + # remove empty boxes + keep = box_ops.remove_small_boxes(boxes_per_class, min_size=1e-2) + boxes_per_class, scores_per_class, labels_per_class = boxes_per_class[keep], scores_per_class[keep], labels_per_class[keep] + + # non-maximum suppression, independently done per class + keep = box_ops.nms(boxes_per_class, scores_per_class, self.nms_thresh) + + # keep only topk scoring predictions + keep = keep[:self.detections_per_img] + boxes_per_class, scores_per_class, labels_per_class = boxes_per_class[keep], scores_per_class[keep], labels_per_class[keep] + + image_boxes.append(boxes_per_class) + image_scores.append(scores_per_class) + image_labels.append(labels_per_class) + + detections.append({ + 'boxes': torch.cat(image_boxes, dim=0), + 'scores': torch.cat(image_scores, dim=0), + 'labels': torch.cat(image_labels, dim=0), + }) + + return detections + def forward(self, images, targets=None): # type: (List[Tensor], Optional[List[Dict[str, Tensor]]]) """ @@ -396,19 +453,8 @@ def forward(self, images, targets=None): losses = self.compute_loss(targets, head_outputs, anchors) else: # compute the detections - # TODO: Implement postprocess_detections - boxes, scores, labels = self.postprocess_detections(class_logits, box_regression, anchors) - num_images = len(images) - for i in range(num_images): - detections.append( - { - "boxes": boxes[i], - "labels": labels[i], - "scores": scores[i], - } - ) - - detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes) + detections = self.postprocess_detections(head_outputs['cls_logits'], head_outputs['bbox_regression'], anchors, images.image_sizes) + detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes) if torch.jit.is_scripting(): if not self._has_warned: From 9981a3c5f4019fd2738a9ec1b005ff53b4f959d1 Mon Sep 17 00:00:00 2001 From: Hans Gaiser Date: Fri, 20 Mar 2020 09:18:34 +0100 Subject: [PATCH 12/31] Small fixes. 
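The postprocessing above follows the usual dense-detector recipe, applied per class and per image: threshold the sigmoid scores, drop degenerate boxes, run NMS, keep the top-k. A condensed single-class sketch (names are illustrative; `boxes` is assumed to be already decoded via the BoxCoder and clipped to the image):

    import torch
    from torchvision.ops import boxes as box_ops

    def keep_best(boxes, scores, score_thresh=0.5, nms_thresh=0.5, topk=300):
        # boxes: [N, 4] decoded, clipped boxes for one class; scores: [N] sigmoid scores
        keep = scores > score_thresh                                # drop low-confidence anchors
        boxes, scores = boxes[keep], scores[keep]

        keep = box_ops.remove_small_boxes(boxes, min_size=1e-2)     # drop near-empty boxes
        boxes, scores = boxes[keep], scores[keep]

        keep = box_ops.nms(boxes, scores, nms_thresh)[:topk]        # per-class NMS, then top-k
        return boxes[keep], scores[keep]

The real method loops this over every class and concatenates the per-class results into the final boxes/scores/labels dict for each image.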
--- torchvision/models/detection/_utils.py | 23 +++++------------------ torchvision/models/detection/retinanet.py | 22 +++++++++++++--------- torchvision/models/detection/rpn.py | 7 +++++-- 3 files changed, 23 insertions(+), 29 deletions(-) diff --git a/torchvision/models/detection/_utils.py b/torchvision/models/detection/_utils.py index 36ba47e2545..cf4567daa8d 100644 --- a/torchvision/models/detection/_utils.py +++ b/torchvision/models/detection/_utils.py @@ -1,12 +1,10 @@ import math import torch -from torch.jit.annotations import List, Tuple, Optional +from torch.jit.annotations import List, Tuple from torch import Tensor import torchvision -from torchvision.ops import boxes as box_ops - class BalancedPositiveNegativeSampler(object): """ @@ -242,10 +240,7 @@ class Matcher(object): 'BETWEEN_THRESHOLDS': int, } - def __init__(self, - high_threshold, - low_threshold, - allow_low_quality_matches=False): + def __init__(self, high_threshold, low_threshold, allow_low_quality_matches=False): # type: (float, float, bool) -> None """ Args: @@ -267,25 +262,17 @@ def __init__(self, self.low_threshold = low_threshold self.allow_low_quality_matches = allow_low_quality_matches - # if box_similarity is None: - # box_similarity = box_ops.box_iou - # self.box_similarity = box_similarity - - def __call__(self, gt_boxes, anchors_per_image): + def __call__(self, match_quality_matrix): """ Args: - gt_boxes (Tensor[float]): an Mx4 tensor, containing M detections. - - anchors_per_image (Tensor[float]): an Mx4 tensor, containing - the anchors for a specific image. + match_quality_matrix (Tensor[float]): an MxN tensor, containing the + pairwise quality between M ground-truth elements and N predicted elements. Returns: matches (Tensor[int64]): an N tensor where N[i] is a matched gt in [0, M - 1] or a negative value indicating that prediction i could not be matched. 
""" - match_quality_matrix = box_ops.box_iou(gt_boxes, anchors_per_image) # self.box_similarity(gt_boxes, anchors_per_image) - if match_quality_matrix.numel() == 0: # empty targets or proposals not supported during training if match_quality_matrix.shape[0] == 0: diff --git a/torchvision/models/detection/retinanet.py b/torchvision/models/detection/retinanet.py index 0c5da1fb6bf..01839f79949 100644 --- a/torchvision/models/detection/retinanet.py +++ b/torchvision/models/detection/retinanet.py @@ -14,6 +14,7 @@ from .backbone_utils import resnet_fpn_backbone from ...ops.feature_pyramid_network import LastLevelP6P7 from ...ops import sigmoid_focal_loss +from ...ops import boxes as box_ops __all__ = [ @@ -68,9 +69,10 @@ def __init__(self, in_channels, num_anchors, num_classes, prior_probability=0.01 conv.append(nn.ReLU()) self.conv = nn.Sequential(*conv) - for l in self.children(): - torch.nn.init.normal_(l.weight, std=0.01) - torch.nn.init.constant_(l.bias, 0) + for l in self.conv.children(): + if isinstance(l, nn.Conv2d): + torch.nn.init.normal_(l.weight, std=0.01) + torch.nn.init.constant_(l.bias, 0) self.cls_logits = nn.Conv2d(in_channels, num_anchors * num_classes, kernel_size=3, stride=1, padding=1) torch.nn.init.normal_(self.cls_logits.weight, std=0.01) @@ -143,9 +145,10 @@ def __init__(self, in_channels, num_anchors): self.bbox_reg = nn.Conv2d(in_channels, num_anchors * 4, kernel_size=3, stride=1, padding=1) - for l in self.children(): - torch.nn.init.normal_(l.weight, std=0.01) - torch.nn.init.zeros_(l.bias) + for l in self.conv.children(): + if isinstance(l, nn.Conv2d): + torch.nn.init.normal_(l.weight, std=0.01) + torch.nn.init.zeros_(l.bias) self.box_coder = det_utils.BoxCoder(weights=(1.0, 1.0, 1.0, 1.0)) @@ -174,7 +177,7 @@ def compute_loss(self, targets, head_outputs, anchors, matched_idxs): target_regression = self.box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image) # compute the loss - loss.append(F.smooth_l1_loss((bbox_regression_per_image, target_regression) / max(1, num_foreground), reduction='sum') + loss.append(F.smooth_l1_loss(bbox_regression_per_image, target_regression, reduction='sum') / max(1, num_foreground)) return sum(loss) / max(1, len(loss)) @@ -343,7 +346,8 @@ def eager_outputs(self, losses, detections): def compute_loss(self, targets, head_outputs, anchors): matched_idxs = [] for anchors_per_image, targets_per_image in zip(anchors, targets): - matched_idxs.append(self.proposal_matcher(targets_per_image["boxes"], anchors_per_image)) + match_quality_matrix = box_ops.box_iou(targets_per_image["boxes"], anchors_per_image) + matched_idxs.append(self.proposal_matcher(match_quality_matrix)) return self.head.compute_loss(targets, head_outputs, anchors, matched_idxs) @@ -435,7 +439,7 @@ def forward(self, images, targets=None): # TODO: Is there a better way to check for [P3, P4, P5, P6, P7]? 
if len(features) == 6: - # skip P2 because it generates too many anchors + # skip P2 because it generates too many anchors (according to their paper) features = features[1:] # compute the retinanet heads outputs using the features diff --git a/torchvision/models/detection/rpn.py b/torchvision/models/detection/rpn.py index a08b778d076..e4c4478eb54 100644 --- a/torchvision/models/detection/rpn.py +++ b/torchvision/models/detection/rpn.py @@ -147,6 +147,9 @@ def __init__(self, self.head = head self.box_coder = det_utils.BoxCoder(weights=(1.0, 1.0, 1.0, 1.0)) + # used during training + self.box_similarity = box_ops.box_iou + self.proposal_matcher = det_utils.Matcher( fg_iou_thresh, bg_iou_thresh, @@ -185,8 +188,8 @@ def assign_targets_to_anchors(self, anchors, targets): matched_gt_boxes_per_image = torch.zeros(anchors_per_image.shape, dtype=torch.float32, device=device) labels_per_image = torch.zeros((anchors_per_image.shape[0],), dtype=torch.float32, device=device) else: - match_quality_matrix = box_ops.box_iou(gt_boxes, anchors_per_image) - matched_idxs = self.proposal_matcher(gt_boxes, anchors_per_image) + match_quality_matrix = self.box_similarity(gt_boxes, anchors_per_image) + matched_idxs = self.proposal_matcher(match_quality_matrix) # get the targets corresponding GT for each proposal # NB: need to clamp the indices because we can have a single # GT in the image, and matched_idxs can be -2, which goes From 5571dfef359f8f1c1f7f0106825ebbd84c42b378 Mon Sep 17 00:00:00 2001 From: Hans Gaiser Date: Fri, 3 Apr 2020 22:29:17 +0200 Subject: [PATCH 13/31] Remove unused argument. --- torchvision/models/detection/retinanet.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchvision/models/detection/retinanet.py b/torchvision/models/detection/retinanet.py index 01839f79949..f5f37a918c7 100644 --- a/torchvision/models/detection/retinanet.py +++ b/torchvision/models/detection/retinanet.py @@ -39,7 +39,7 @@ def __init__(self, in_channels, num_anchors, num_classes): def compute_loss(self, targets, head_outputs, anchors, matched_idxs): return { - 'classification': self.classification_head.compute_loss(targets, head_outputs, anchors, matched_idxs), + 'classification': self.classification_head.compute_loss(targets, head_outputs, matched_idxs), 'bbox_regression': self.regression_head.compute_loss(targets, head_outputs, anchors, matched_idxs), } @@ -81,12 +81,12 @@ def __init__(self, in_channels, num_anchors, num_classes, prior_probability=0.01 self.num_classes = num_classes self.num_anchors = num_anchors - def compute_loss(self, targets, head_outputs, anchors, matched_idxs): + def compute_loss(self, targets, head_outputs, matched_idxs): loss = [] cls_logits = head_outputs['cls_logits'] - for targets_per_image, cls_logits_per_image, anchors_per_image, matched_idxs_per_image in zip(targets, cls_logits, anchors, matched_idxs): + for targets_per_image, cls_logits_per_image, matched_idxs_per_image in zip(targets, cls_logits, matched_idxs): # determine only the foreground foreground_idxs_per_image = matched_idxs_per_image >= 0 num_foreground = foreground_idxs_per_image.sum() From fc7751b1c3f5ef10cc9e797da220df814f429607 Mon Sep 17 00:00:00 2001 From: Hans Gaiser Date: Sat, 4 Apr 2020 15:17:56 +0200 Subject: [PATCH 14/31] Remove python2 invocation of super. 
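With this change the Matcher no longer computes IoU itself; callers build the match quality matrix (here with box_ops.box_iou) and pass it in. A small standalone sketch of the new calling convention (the boxes are made up; the thresholds are the RetinaNet fg/bg defaults from this file):

    import torch
    from torchvision.ops import boxes as box_ops
    from torchvision.models.detection import _utils as det_utils

    gt_boxes = torch.tensor([[0., 0., 100., 100.], [50., 50., 200., 200.]])
    anchors = torch.tensor([[0., 0., 90., 90.], [300., 300., 400., 400.]])

    matcher = det_utils.Matcher(high_threshold=0.5, low_threshold=0.4)

    # the caller is now responsible for the MxN pairwise IoU matrix
    match_quality_matrix = box_ops.box_iou(gt_boxes, anchors)
    matched_idxs = matcher(match_quality_matrix)

    print(matched_idxs)   # tensor([ 0, -1]): anchor 0 matches gt 0, anchor 1 is background
    # >= 0 is the index of the matched ground-truth box; negative values mark anchors
    # treated as background (below low_threshold) or ignored (between the two thresholds)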
--- torchvision/models/detection/retinanet.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torchvision/models/detection/retinanet.py b/torchvision/models/detection/retinanet.py index f5f37a918c7..85dc7a59916 100644 --- a/torchvision/models/detection/retinanet.py +++ b/torchvision/models/detection/retinanet.py @@ -33,7 +33,7 @@ class RetinaNetHead(nn.Module): """ def __init__(self, in_channels, num_anchors, num_classes): - super(RetinaNetHead, self).__init__() + super().__init__() self.classification_head = RetinaNetClassificationHead(in_channels, num_anchors, num_classes) self.regression_head = RetinaNetRegressionHead(in_channels, num_anchors) @@ -61,7 +61,7 @@ class RetinaNetClassificationHead(nn.Module): """ def __init__(self, in_channels, num_anchors, num_classes, prior_probability=0.01): - super(RetinaNetClassificationHead, self).__init__() + super().__init__() conv = [] for _ in range(4): @@ -135,7 +135,7 @@ class RetinaNetRegressionHead(nn.Module): """ def __init__(self, in_channels, num_anchors): - super(RetinaNetRegressionHead, self).__init__() + super().__init__() conv = [] for _ in range(4): @@ -292,7 +292,7 @@ def __init__(self, backbone, num_classes, nms_thresh=0.5, detections_per_img=300, fg_iou_thresh=0.5, bg_iou_thresh=0.4): - super(RetinaNet, self).__init__() + super().__init__() if not hasattr(backbone, "out_channels"): raise ValueError( From b942648133fc095bd391e6a9be74c3aa5ccf732f Mon Sep 17 00:00:00 2001 From: Hans Gaiser Date: Sun, 5 Apr 2020 00:32:29 +0200 Subject: [PATCH 15/31] Add postprocessing for additional outputs. --- torchvision/models/detection/retinanet.py | 25 +++++++++++++++++++---- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/torchvision/models/detection/retinanet.py b/torchvision/models/detection/retinanet.py index 85dc7a59916..d01738f3745 100644 --- a/torchvision/models/detection/retinanet.py +++ b/torchvision/models/detection/retinanet.py @@ -351,9 +351,14 @@ def compute_loss(self, targets, head_outputs, anchors): return self.head.compute_loss(targets, head_outputs, anchors, matched_idxs) - def postprocess_detections(self, class_logits, box_regression, anchors, image_shapes): + def postprocess_detections(self, head_outputs, anchors, image_shapes): # type: (Tensor, Tensor, List[Tensor], List[Tuple[int, int]]) # TODO: Merge this with roi_heads.RoIHeads.postprocess_detections ? 
+ + class_logits = head_outputs.pop('cls_logits') + box_regression = head_outputs.pop('bbox_regression') + other_outputs = head_outputs + device = class_logits.device num_classes = class_logits.shape[-1] @@ -366,22 +371,27 @@ def postprocess_detections(self, class_logits, box_regression, anchors, image_sh detections = [] - for box_regression_per_image, scores_per_image, labels_per_image, anchors_per_image, image_shape in zip(box_regression, scores, labels, anchors, image_shapes): + for image_index, (box_regression_per_image, scores_per_image, labels_per_image, anchors_per_image, image_shape) in enumerate(zip(box_regression, scores, labels, anchors, image_shapes)): boxes_per_image = self.box_coder.decode_single(box_regression_per_image, anchors_per_image) boxes_per_image = box_ops.clip_boxes_to_image(boxes_per_image, image_shape) + other_outputs_per_image = [(k, v[image_index]) for k, v in other_outputs.items()] + image_boxes = [] image_scores = [] image_labels = [] + image_other_outputs = {k: [] for k in other_outputs.keys()} for class_index in range(num_classes): # remove low scoring boxes - inds = torch.nonzero(scores_per_image[:, class_index] > self.score_thresh).squeeze(1) + inds = torch.gt(scores_per_image[:, class_index], self.score_thresh) boxes_per_class, scores_per_class, labels_per_class = boxes_per_image[inds], scores_per_image[inds, class_index], labels_per_image[inds, class_index] + other_outputs_per_class = [(k, v[inds]) for k, v in other_outputs_per_image] # remove empty boxes keep = box_ops.remove_small_boxes(boxes_per_class, min_size=1e-2) boxes_per_class, scores_per_class, labels_per_class = boxes_per_class[keep], scores_per_class[keep], labels_per_class[keep] + other_outputs_per_class = [(k, v[keep]) for k, v in other_outputs_per_class] # non-maximum suppression, independently done per class keep = box_ops.nms(boxes_per_class, scores_per_class, self.nms_thresh) @@ -389,17 +399,24 @@ def postprocess_detections(self, class_logits, box_regression, anchors, image_sh # keep only topk scoring predictions keep = keep[:self.detections_per_img] boxes_per_class, scores_per_class, labels_per_class = boxes_per_class[keep], scores_per_class[keep], labels_per_class[keep] + other_outputs_per_class = [(k, v[keep]) for k, v in other_outputs_per_class] image_boxes.append(boxes_per_class) image_scores.append(scores_per_class) image_labels.append(labels_per_class) + for k, v in other_outputs_per_class: + image_other_outputs[k].append(v) + detections.append({ 'boxes': torch.cat(image_boxes, dim=0), 'scores': torch.cat(image_scores, dim=0), 'labels': torch.cat(image_labels, dim=0), }) + for k, v in image_other_outputs.items(): + detections[-1].update({k: torch.cat(v, dim=0)}) + return detections def forward(self, images, targets=None): @@ -457,7 +474,7 @@ def forward(self, images, targets=None): losses = self.compute_loss(targets, head_outputs, anchors) else: # compute the detections - detections = self.postprocess_detections(head_outputs['cls_logits'], head_outputs['bbox_regression'], anchors, images.image_sizes) + detections = self.postprocess_detections(head_outputs, anchors, images.image_sizes) detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes) if torch.jit.is_scripting(): From b619936354e2ef9e9f2b5b31ee3bb072352d72b4 Mon Sep 17 00:00:00 2001 From: Hans Gaiser Date: Fri, 17 Apr 2020 10:00:36 +0200 Subject: [PATCH 16/31] Add missing import of ImageList. 
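The point of threading `other_outputs` through postprocess_detections is that any extra per-anchor tensor a custom head returns (everything in head_outputs besides 'cls_logits' and 'bbox_regression') gets filtered with exactly the same keep indices as the boxes, so the final detection dicts stay aligned. A tiny sketch of the pattern (the 'embedding' key is hypothetical):

    import torch

    # hypothetical extra per-anchor output from a subclassed head, for one image
    other_outputs_per_image = {'embedding': torch.randn(6, 16)}

    scores = torch.rand(6)
    keep = torch.gt(scores, 0.5)

    filtered = {k: v[keep] for k, v in other_outputs_per_image.items()}
    # the same `keep` used for boxes/scores is applied to every extra key,
    # so detections[i]['embedding'][j] still corresponds to detections[i]['boxes'][j]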
--- torchvision/models/detection/anchor_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchvision/models/detection/anchor_utils.py b/torchvision/models/detection/anchor_utils.py index 95418fca34a..d8f90c5ce88 100644 --- a/torchvision/models/detection/anchor_utils.py +++ b/torchvision/models/detection/anchor_utils.py @@ -4,6 +4,7 @@ from torch import nn from torch.jit.annotations import List, Optional, Dict +from .image_list import ImageList class AnchorGenerator(nn.Module): From 8c86588485f013176415aefb9b771b92b6bab678 Mon Sep 17 00:00:00 2001 From: Hans Gaiser Date: Fri, 17 Apr 2020 10:13:34 +0200 Subject: [PATCH 17/31] Remove redundant import. --- torchvision/models/detection/anchor_utils.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/torchvision/models/detection/anchor_utils.py b/torchvision/models/detection/anchor_utils.py index d8f90c5ce88..94bb006c5dd 100644 --- a/torchvision/models/detection/anchor_utils.py +++ b/torchvision/models/detection/anchor_utils.py @@ -1,6 +1,5 @@ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. import torch -import torchvision from torch import nn from torch.jit.annotations import List, Optional, Dict @@ -8,11 +7,6 @@ class AnchorGenerator(nn.Module): - __annotations__ = { - "cell_anchors": Optional[List[torch.Tensor]], - "_cache": Dict[str, List[torch.Tensor]] - } - """ Module that generates anchors for a set of feature maps and image sizes. @@ -33,6 +27,11 @@ class AnchorGenerator(nn.Module): aspect_ratios (Tuple[Tuple[float]]): """ + __annotations__ = { + "cell_anchors": Optional[List[torch.Tensor]], + "_cache": Dict[str, List[torch.Tensor]] + } + def __init__( self, sizes=((128, 256, 512),), From 2934f0db334ef3e701d7768f78b21cd9f248280d Mon Sep 17 00:00:00 2001 From: Hans Gaiser Date: Fri, 17 Apr 2020 11:11:45 +0200 Subject: [PATCH 18/31] Simplify class correction. --- torchvision/models/detection/retinanet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/models/detection/retinanet.py b/torchvision/models/detection/retinanet.py index d01738f3745..4f8529a8e1d 100644 --- a/torchvision/models/detection/retinanet.py +++ b/torchvision/models/detection/retinanet.py @@ -366,7 +366,7 @@ def postprocess_detections(self, head_outputs, anchors, image_shapes): # create labels for each score # the +1 is to make the labels identical to other object detection algorithms that treat background as label 0 - labels = torch.arange(num_classes, device=device) + 1 + labels = torch.arange(1, num_classes + 1, device=device) labels = labels.view(1, -1).expand_as(scores) detections = [] From 32b8e772e8ff8dfdead1c602aaada6f83bbbb202 Mon Sep 17 00:00:00 2001 From: Hans Gaiser Date: Fri, 17 Apr 2020 12:23:19 +0200 Subject: [PATCH 19/31] Fix pylint warnings. 
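For context on the simplification above: the label grid built in postprocess_detections is just the class index broadcast across anchors, and torch.arange(1, num_classes + 1) produces the same values as torch.arange(num_classes) + 1 in one step. A tiny illustration (4 anchors and 3 classes are arbitrary):

    import torch

    num_classes = 3
    scores = torch.rand(4, num_classes)          # 4 anchors, one sigmoid score per class

    labels = torch.arange(1, num_classes + 1)    # [1, 2, 3], identical to arange(num_classes) + 1
    labels = labels.view(1, -1).expand_as(scores)
    # labels[i, j] == j + 1, so indexing `labels` with the same mask used on `scores`
    # tags every surviving (anchor, class) pair with its class id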
--- torchvision/models/detection/retinanet.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/torchvision/models/detection/retinanet.py b/torchvision/models/detection/retinanet.py index 4f8529a8e1d..b42483f8130 100644 --- a/torchvision/models/detection/retinanet.py +++ b/torchvision/models/detection/retinanet.py @@ -1,5 +1,6 @@ import math from collections import OrderedDict +import warnings import torch from torch import nn @@ -335,6 +336,9 @@ def __init__(self, backbone, num_classes, self.nms_thresh = nms_thresh self.detections_per_img = detections_per_img + # used only on torchscript mode + self._has_warned = False + @torch.jit.unused def eager_outputs(self, losses, detections): # type: (Dict[str, Tensor], List[Dict[str, Tensor]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]] @@ -482,8 +486,7 @@ def forward(self, images, targets=None): warnings.warn("RetinaNet always returns a (Losses, Detections) tuple in scripting") self._has_warned = True return (losses, detections) - else: - return self.eager_outputs(losses, detections) + return self.eager_outputs(losses, detections) model_urls = { From 437bfe9005c05f6e75c81c1ab4317721cf455be0 Mon Sep 17 00:00:00 2001 From: Hans Gaiser Date: Fri, 17 Apr 2020 23:58:47 +0200 Subject: [PATCH 20/31] Remove the label adjustment for background class. --- torchvision/models/detection/retinanet.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torchvision/models/detection/retinanet.py b/torchvision/models/detection/retinanet.py index b42483f8130..5775878709a 100644 --- a/torchvision/models/detection/retinanet.py +++ b/torchvision/models/detection/retinanet.py @@ -369,8 +369,7 @@ def postprocess_detections(self, head_outputs, anchors, image_shapes): scores = torch.sigmoid(class_logits) # create labels for each score - # the +1 is to make the labels identical to other object detection algorithms that treat background as label 0 - labels = torch.arange(1, num_classes + 1, device=device) + labels = torch.arange(num_classes, device=device) labels = labels.view(1, -1).expand_as(scores) detections = [] From 9e810d6ec54ac21c76390cc1a18ba9eaa25a0a55 Mon Sep 17 00:00:00 2001 From: Hans Gaiser Date: Sat, 18 Apr 2020 01:05:38 +0200 Subject: [PATCH 21/31] Set default score threshold to 0.05. --- torchvision/models/detection/retinanet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/models/detection/retinanet.py b/torchvision/models/detection/retinanet.py index 5775878709a..363cfc115be 100644 --- a/torchvision/models/detection/retinanet.py +++ b/torchvision/models/detection/retinanet.py @@ -289,7 +289,7 @@ def __init__(self, backbone, num_classes, # Anchor parameters anchor_generator=None, head=None, proposal_matcher=None, - score_thresh=0.5, + score_thresh=0.05, nms_thresh=0.5, detections_per_img=300, fg_iou_thresh=0.5, bg_iou_thresh=0.4): From f7d8c2e82ccc4817939497cac488f7bfc42dab41 Mon Sep 17 00:00:00 2001 From: Hans Gaiser Date: Fri, 24 Apr 2020 15:20:46 +0200 Subject: [PATCH 22/31] Add weight initialization for regression layer. 
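This gives the bbox_reg output layer the same N(0, 0.01) weights and zero bias already used for the tower convolutions; the isinstance check in the tower loop is what keeps the init from touching the interleaved ReLU modules, which have no parameters. A condensed sketch (256 channels and 9 anchors per location are illustrative, matching the FPN out_channels and the 3 sizes x 3 aspect ratios used here):

    import torch
    from torch import nn

    in_channels, num_anchors = 256, 9    # illustrative: FPN out_channels, 3 sizes x 3 ratios

    conv = []
    for _ in range(4):
        conv.append(nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1))
        conv.append(nn.ReLU())
    conv = nn.Sequential(*conv)

    bbox_reg = nn.Conv2d(in_channels, num_anchors * 4, kernel_size=3, stride=1, padding=1)
    torch.nn.init.normal_(bbox_reg.weight, std=0.01)
    torch.nn.init.zeros_(bbox_reg.bias)

    # the tower interleaves ReLU modules, which have no weight/bias, hence the isinstance check
    for layer in conv.children():
        if isinstance(layer, nn.Conv2d):
            torch.nn.init.normal_(layer.weight, std=0.01)
            torch.nn.init.zeros_(layer.bias)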
--- torchvision/models/detection/retinanet.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torchvision/models/detection/retinanet.py b/torchvision/models/detection/retinanet.py index 363cfc115be..8943cf5c403 100644 --- a/torchvision/models/detection/retinanet.py +++ b/torchvision/models/detection/retinanet.py @@ -145,6 +145,8 @@ def __init__(self, in_channels, num_anchors): self.conv = nn.Sequential(*conv) self.bbox_reg = nn.Conv2d(in_channels, num_anchors * 4, kernel_size=3, stride=1, padding=1) + torch.nn.init.normal_(self.bbox_reg.weight, std=0.01) + torch.nn.init.zeros_(self.bbox_reg.bias) for l in self.conv.children(): if isinstance(l, nn.Conv2d): From d86c437094fea3d71ea32a67aeec84b96bd4ceb4 Mon Sep 17 00:00:00 2001 From: Hans Gaiser Date: Mon, 27 Apr 2020 15:22:56 +0200 Subject: [PATCH 23/31] Allow training on images with no annotations. --- torchvision/models/detection/retinanet.py | 37 ++++++++++++++++------- 1 file changed, 26 insertions(+), 11 deletions(-) diff --git a/torchvision/models/detection/retinanet.py b/torchvision/models/detection/retinanet.py index 8943cf5c403..c4c2b3deffd 100644 --- a/torchvision/models/detection/retinanet.py +++ b/torchvision/models/detection/retinanet.py @@ -88,16 +88,22 @@ def compute_loss(self, targets, head_outputs, matched_idxs): cls_logits = head_outputs['cls_logits'] for targets_per_image, cls_logits_per_image, matched_idxs_per_image in zip(targets, cls_logits, matched_idxs): - # determine only the foreground - foreground_idxs_per_image = matched_idxs_per_image >= 0 - num_foreground = foreground_idxs_per_image.sum() - - # create the target classification - gt_classes_target = torch.zeros_like(cls_logits_per_image) - gt_classes_target[foreground_idxs_per_image, targets_per_image['labels'][matched_idxs_per_image[foreground_idxs_per_image]]] = 1 - - # find indices for which anchors should be ignored - valid_idxs_per_image = matched_idxs_per_image != det_utils.Matcher.BETWEEN_THRESHOLDS + # no matched_idxs means there were no annotations in this image + if matched_idxs_per_image is None: + gt_classes_target = torch.zeros_like(cls_logits_per_image) + valid_idxs_per_image = torch.arange(cls_logits_per_image.shape[0]) + num_foreground = 0 + else: + # determine only the foreground + foreground_idxs_per_image = matched_idxs_per_image >= 0 + num_foreground = foreground_idxs_per_image.sum() + + # create the target classification + gt_classes_target = torch.zeros_like(cls_logits_per_image) + gt_classes_target[foreground_idxs_per_image, targets_per_image['labels'][matched_idxs_per_image[foreground_idxs_per_image]]] = 1 + + # find indices for which anchors should be ignored + valid_idxs_per_image = matched_idxs_per_image != det_utils.Matcher.BETWEEN_THRESHOLDS # compute the classification loss loss.append(sigmoid_focal_loss( @@ -161,6 +167,11 @@ def compute_loss(self, targets, head_outputs, anchors, matched_idxs): bbox_regression = head_outputs['bbox_regression'] for targets_per_image, bbox_regression_per_image, anchors_per_image, matched_idxs_per_image in zip(targets, bbox_regression, anchors, matched_idxs): + # no matched_idxs means there were no annotations in this image + if matched_idxs_per_image is None: + loss.append(0) + continue + # get the targets corresponding GT for each proposal # NB: need to clamp the indices because we can have a single # GT in the image, and matched_idxs can be -2, which goes @@ -352,7 +363,11 @@ def eager_outputs(self, losses, detections): def compute_loss(self, targets, head_outputs, anchors): matched_idxs = [] for 
anchors_per_image, targets_per_image in zip(anchors, targets): - match_quality_matrix = box_ops.box_iou(targets_per_image["boxes"], anchors_per_image) + if targets_per_image['boxes'].numel() == 0: + matched_idxs.append(None) + continue + + match_quality_matrix = box_ops.box_iou(targets_per_image['boxes'], anchors_per_image) matched_idxs.append(self.proposal_matcher(match_quality_matrix)) return self.head.compute_loss(targets, head_outputs, anchors, matched_idxs) From 72e46f26ae26e9c7e55a6cc5077323b2e7b90645 Mon Sep 17 00:00:00 2001 From: Hans Gaiser Date: Mon, 27 Apr 2020 17:13:46 +0200 Subject: [PATCH 24/31] Use smooth_l1_loss with beta value. --- torchvision/models/detection/retinanet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/models/detection/retinanet.py b/torchvision/models/detection/retinanet.py index c4c2b3deffd..95aba4fb5a8 100644 --- a/torchvision/models/detection/retinanet.py +++ b/torchvision/models/detection/retinanet.py @@ -191,7 +191,7 @@ def compute_loss(self, targets, head_outputs, anchors, matched_idxs): target_regression = self.box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image) # compute the loss - loss.append(F.smooth_l1_loss(bbox_regression_per_image, target_regression, reduction='sum') / max(1, num_foreground)) + loss.append(det_utils.smooth_l1_loss(bbox_regression_per_image, target_regression, size_average=False) / max(1, num_foreground)) return sum(loss) / max(1, len(loss)) From 41c90fa015cc94d291deec21be2bc45c87179f74 Mon Sep 17 00:00:00 2001 From: Hans Gaiser Date: Fri, 15 May 2020 20:32:10 +0200 Subject: [PATCH 25/31] Add more typehints for TorchScript conversions. --- torchvision/models/detection/retinanet.py | 24 ++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/torchvision/models/detection/retinanet.py b/torchvision/models/detection/retinanet.py index 95aba4fb5a8..dadab01086b 100644 --- a/torchvision/models/detection/retinanet.py +++ b/torchvision/models/detection/retinanet.py @@ -39,12 +39,14 @@ def __init__(self, in_channels, num_anchors, num_classes): self.regression_head = RetinaNetRegressionHead(in_channels, num_anchors) def compute_loss(self, targets, head_outputs, anchors, matched_idxs): + # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor], List[Tensor]) -> Tensor return { 'classification': self.classification_head.compute_loss(targets, head_outputs, matched_idxs), 'bbox_regression': self.regression_head.compute_loss(targets, head_outputs, anchors, matched_idxs), } def forward(self, x): + # type: (List[Tensor]) -> Dict[str, Tensor] return { 'cls_logits': self.classification_head(x), 'bbox_regression': self.regression_head(x) @@ -83,13 +85,14 @@ def __init__(self, in_channels, num_anchors, num_classes, prior_probability=0.01 self.num_anchors = num_anchors def compute_loss(self, targets, head_outputs, matched_idxs): + # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor]) -> Tensor loss = [] cls_logits = head_outputs['cls_logits'] for targets_per_image, cls_logits_per_image, matched_idxs_per_image in zip(targets, cls_logits, matched_idxs): # no matched_idxs means there were no annotations in this image - if matched_idxs_per_image is None: + if matched_idxs_per_image.numel() == 0: gt_classes_target = torch.zeros_like(cls_logits_per_image) valid_idxs_per_image = torch.arange(cls_logits_per_image.shape[0]) num_foreground = 0 @@ -100,7 +103,10 @@ def compute_loss(self, targets, head_outputs, matched_idxs): # create the target classification 
gt_classes_target = torch.zeros_like(cls_logits_per_image) - gt_classes_target[foreground_idxs_per_image, targets_per_image['labels'][matched_idxs_per_image[foreground_idxs_per_image]]] = 1 + gt_classes_target[ + foreground_idxs_per_image, + targets_per_image['labels'][matched_idxs_per_image[foreground_idxs_per_image]] + ] = torch.tensor(1.0) # find indices for which anchors should be ignored valid_idxs_per_image = matched_idxs_per_image != det_utils.Matcher.BETWEEN_THRESHOLDS @@ -115,6 +121,7 @@ def compute_loss(self, targets, head_outputs, matched_idxs): return sum(loss) / len(loss) def forward(self, x): + # type: (List[Tensor]) -> Tensor all_cls_logits = [] for features in x: @@ -162,6 +169,7 @@ def __init__(self, in_channels, num_anchors): self.box_coder = det_utils.BoxCoder(weights=(1.0, 1.0, 1.0, 1.0)) def compute_loss(self, targets, head_outputs, anchors, matched_idxs): + # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor], List[Tensor]) -> Tensor loss = [] bbox_regression = head_outputs['bbox_regression'] @@ -196,6 +204,7 @@ def compute_loss(self, targets, head_outputs, anchors, matched_idxs): return sum(loss) / max(1, len(loss)) def forward(self, x): + # type: (List[Tensor]) -> Tensor all_bbox_regression = [] for features in x: @@ -283,8 +292,8 @@ class RetinaNet(nn.Module): >>> # ratios. We have a Tuple[Tuple[int]] because each feature >>> # map could potentially have different sizes and >>> # aspect ratios - >>> anchor_generator = AnchorGenerator(sizes=[[x, x * 2 ** (1.0 / 3), x * 2 ** (2.0 / 3)] for x in [32, 64, 128, 256, 512]], - >>> aspect_ratios=[[0.5, 1.0, 2.0]] * 5) + >>> anchor_generator = AnchorGenerator(sizes=tuple((x, int(x * 2 ** (1.0 / 3)), int(x * 2 ** (2.0 / 3))) for x in [32, 64, 128, 256, 512]), + >>> aspect_ratios=((0.5, 1.0, 2.0),) * 5) >>> >>> # put the pieces together inside a RetinaNet model >>> model = RetinaNet(backbone, @@ -318,7 +327,7 @@ def __init__(self, backbone, num_classes, assert isinstance(anchor_generator, (AnchorGenerator, type(None))) if anchor_generator is None: - anchor_sizes = [[x, x * 2 ** (1.0 / 3), x * 2 ** (2.0 / 3)] for x in [32, 64, 128, 256, 512]] + anchor_sizes = tuple((x, int(x * 2 ** (1.0 / 3)), int(x * 2 ** (2.0 / 3))) for x in [32, 64, 128, 256, 512]) aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes) anchor_generator = AnchorGenerator( anchor_sizes, aspect_ratios @@ -361,10 +370,11 @@ def eager_outputs(self, losses, detections): return detections def compute_loss(self, targets, head_outputs, anchors): + # type: (List[Dict[str, Tensor]], List[Tensor], List[Tensor]) -> Dict[str, Tensor] matched_idxs = [] for anchors_per_image, targets_per_image in zip(anchors, targets): if targets_per_image['boxes'].numel() == 0: - matched_idxs.append(None) + matched_idxs.append(torch.empty((0,))) continue match_quality_matrix = box_ops.box_iou(targets_per_image['boxes'], anchors_per_image) @@ -440,7 +450,7 @@ def postprocess_detections(self, head_outputs, anchors, image_shapes): return detections def forward(self, images, targets=None): - # type: (List[Tensor], Optional[List[Dict[str, Tensor]]]) + # type: (List[Tensor], Optional[List[Dict[str, Tensor]]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]] """ Arguments: images (list[Tensor]): images to be processed From b9daa864e899706aa8f906703249cc69c5b82c70 Mon Sep 17 00:00:00 2001 From: Hans Gaiser Date: Fri, 15 May 2020 20:47:12 +0200 Subject: [PATCH 26/31] Fix linting issues. 
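The anchor configuration above is now written as plain tuples of ints, which TorchScript can digest; concretely, the default sizes evaluate to three scales per pyramid level spaced by a factor of 2**(1/3). A pure-Python check:

    sizes = tuple((x, int(x * 2 ** (1.0 / 3)), int(x * 2 ** (2.0 / 3))) for x in [32, 64, 128, 256, 512])
    aspect_ratios = ((0.5, 1.0, 2.0),) * len(sizes)

    print(sizes)
    # ((32, 40, 50), (64, 80, 101), (128, 161, 203), (256, 322, 406), (512, 645, 812))
    # combined with the 3 aspect ratios this yields 9 anchors per feature-map location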
--- torchvision/models/detection/retinanet.py | 48 +++++++++++++++-------- torchvision/ops/focal_loss.py | 1 + 2 files changed, 32 insertions(+), 17 deletions(-) diff --git a/torchvision/models/detection/retinanet.py b/torchvision/models/detection/retinanet.py index dadab01086b..dde4ee4363d 100644 --- a/torchvision/models/detection/retinanet.py +++ b/torchvision/models/detection/retinanet.py @@ -72,10 +72,10 @@ def __init__(self, in_channels, num_anchors, num_classes, prior_probability=0.01 conv.append(nn.ReLU()) self.conv = nn.Sequential(*conv) - for l in self.conv.children(): - if isinstance(l, nn.Conv2d): - torch.nn.init.normal_(l.weight, std=0.01) - torch.nn.init.constant_(l.bias, 0) + for layer in self.conv.children(): + if isinstance(layer, nn.Conv2d): + torch.nn.init.normal_(layer.weight, std=0.01) + torch.nn.init.constant_(layer.bias, 0) self.cls_logits = nn.Conv2d(in_channels, num_anchors * num_classes, kernel_size=3, stride=1, padding=1) torch.nn.init.normal_(self.cls_logits.weight, std=0.01) @@ -161,10 +161,10 @@ def __init__(self, in_channels, num_anchors): torch.nn.init.normal_(self.bbox_reg.weight, std=0.01) torch.nn.init.zeros_(self.bbox_reg.bias) - for l in self.conv.children(): - if isinstance(l, nn.Conv2d): - torch.nn.init.normal_(l.weight, std=0.01) - torch.nn.init.zeros_(l.bias) + for layer in self.conv.children(): + if isinstance(layer, nn.Conv2d): + torch.nn.init.normal_(layer.weight, std=0.01) + torch.nn.init.zeros_(layer.bias) self.box_coder = det_utils.BoxCoder(weights=(1.0, 1.0, 1.0, 1.0)) @@ -174,7 +174,8 @@ def compute_loss(self, targets, head_outputs, anchors, matched_idxs): bbox_regression = head_outputs['bbox_regression'] - for targets_per_image, bbox_regression_per_image, anchors_per_image, matched_idxs_per_image in zip(targets, bbox_regression, anchors, matched_idxs): + for targets_per_image, bbox_regression_per_image, anchors_per_image, matched_idxs_per_image in \ + zip(targets, bbox_regression, anchors, matched_idxs): # no matched_idxs means there were no annotations in this image if matched_idxs_per_image is None: loss.append(0) @@ -199,7 +200,13 @@ def compute_loss(self, targets, head_outputs, anchors, matched_idxs): target_regression = self.box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image) # compute the loss - loss.append(det_utils.smooth_l1_loss(bbox_regression_per_image, target_regression, size_average=False) / max(1, num_foreground)) + loss.append( + det_utils.smooth_l1_loss( + bbox_regression_per_image, + target_regression, + size_average=False + ) / max(1, num_foreground) + ) return sum(loss) / max(1, len(loss)) @@ -292,8 +299,10 @@ class RetinaNet(nn.Module): >>> # ratios. 
We have a Tuple[Tuple[int]] because each feature >>> # map could potentially have different sizes and >>> # aspect ratios - >>> anchor_generator = AnchorGenerator(sizes=tuple((x, int(x * 2 ** (1.0 / 3)), int(x * 2 ** (2.0 / 3))) for x in [32, 64, 128, 256, 512]), - >>> aspect_ratios=((0.5, 1.0, 2.0),) * 5) + >>> anchor_generator = AnchorGenerator( + >>> sizes=tuple((x, int(x * 2 ** (1.0 / 3)), int(x * 2 ** (2.0 / 3))) for x in [32, 64, 128, 256, 512]), + >>> aspect_ratios=((0.5, 1.0, 2.0),) * 5 + >>> ) >>> >>> # put the pieces together inside a RetinaNet model >>> model = RetinaNet(backbone, @@ -401,11 +410,13 @@ def postprocess_detections(self, head_outputs, anchors, image_shapes): detections = [] - for image_index, (box_regression_per_image, scores_per_image, labels_per_image, anchors_per_image, image_shape) in enumerate(zip(box_regression, scores, labels, anchors, image_shapes)): + for index, (box_regression_per_image, scores_per_image, labels_per_image, anchors_per_image, image_shape) in \ + enumerate(zip(box_regression, scores, labels, anchors, image_shapes)): + boxes_per_image = self.box_coder.decode_single(box_regression_per_image, anchors_per_image) boxes_per_image = box_ops.clip_boxes_to_image(boxes_per_image, image_shape) - other_outputs_per_image = [(k, v[image_index]) for k, v in other_outputs.items()] + other_outputs_per_image = [(k, v[index]) for k, v in other_outputs.items()] image_boxes = [] image_scores = [] @@ -415,12 +426,14 @@ def postprocess_detections(self, head_outputs, anchors, image_shapes): for class_index in range(num_classes): # remove low scoring boxes inds = torch.gt(scores_per_image[:, class_index], self.score_thresh) - boxes_per_class, scores_per_class, labels_per_class = boxes_per_image[inds], scores_per_image[inds, class_index], labels_per_image[inds, class_index] + boxes_per_class, scores_per_class, labels_per_class = \ + boxes_per_image[inds], scores_per_image[inds, class_index], labels_per_image[inds, class_index] other_outputs_per_class = [(k, v[inds]) for k, v in other_outputs_per_image] # remove empty boxes keep = box_ops.remove_small_boxes(boxes_per_class, min_size=1e-2) - boxes_per_class, scores_per_class, labels_per_class = boxes_per_class[keep], scores_per_class[keep], labels_per_class[keep] + boxes_per_class, scores_per_class, labels_per_class = \ + boxes_per_class[keep], scores_per_class[keep], labels_per_class[keep] other_outputs_per_class = [(k, v[keep]) for k, v in other_outputs_per_class] # non-maximum suppression, independently done per class @@ -428,7 +441,8 @@ def postprocess_detections(self, head_outputs, anchors, image_shapes): # keep only topk scoring predictions keep = keep[:self.detections_per_img] - boxes_per_class, scores_per_class, labels_per_class = boxes_per_class[keep], scores_per_class[keep], labels_per_class[keep] + boxes_per_class, scores_per_class, labels_per_class = \ + boxes_per_class[keep], scores_per_class[keep], labels_per_class[keep] other_outputs_per_class = [(k, v[keep]) for k, v in other_outputs_per_class] image_boxes.append(boxes_per_class) diff --git a/torchvision/ops/focal_loss.py b/torchvision/ops/focal_loss.py index f19bec3c2f1..1114d80ed47 100644 --- a/torchvision/ops/focal_loss.py +++ b/torchvision/ops/focal_loss.py @@ -1,6 +1,7 @@ import torch import torch.nn.functional as F + def sigmoid_focal_loss( inputs, targets, From 97d63b656d41ae3ec8d57355c1707cc2ffd28968 Mon Sep 17 00:00:00 2001 From: Hans Gaiser Date: Fri, 15 May 2020 21:06:09 +0200 Subject: [PATCH 27/31] Fix type hints in postprocess_detections. 
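A small note on the score filtering reflowed above: the boolean-mask form torch.gt(scores, thresh) selects exactly the same rows as the older torch.nonzero(scores > thresh).squeeze(1) indexing, just without the intermediate index tensor:

    import torch

    scores = torch.tensor([0.9, 0.01, 0.3, 0.7])
    thresh = 0.05

    mask = torch.gt(scores, thresh)                      # boolean mask, shape [4]
    idx = torch.nonzero(scores > thresh).squeeze(1)      # integer indices, shape [3]

    assert torch.equal(scores[mask], scores[idx])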
--- torchvision/models/detection/retinanet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/models/detection/retinanet.py b/torchvision/models/detection/retinanet.py index dde4ee4363d..820199df046 100644 --- a/torchvision/models/detection/retinanet.py +++ b/torchvision/models/detection/retinanet.py @@ -392,7 +392,7 @@ def compute_loss(self, targets, head_outputs, anchors): return self.head.compute_loss(targets, head_outputs, anchors, matched_idxs) def postprocess_detections(self, head_outputs, anchors, image_shapes): - # type: (Tensor, Tensor, List[Tensor], List[Tuple[int, int]]) + # type: (Dict[str, Tensor], List[Tensor], List[Tuple[int, int]]) -> Dict[str, Tensor] # TODO: Merge this with roi_heads.RoIHeads.postprocess_detections ? class_logits = head_outputs.pop('cls_logits') From eba7e16b676e5fc65fddced32888833a36177168 Mon Sep 17 00:00:00 2001 From: Hans Gaiser Date: Mon, 18 May 2020 18:43:34 +0200 Subject: [PATCH 28/31] Fix type annotations for TorchScript. --- torchvision/models/detection/retinanet.py | 51 +++++++++++++---------- 1 file changed, 28 insertions(+), 23 deletions(-) diff --git a/torchvision/models/detection/retinanet.py b/torchvision/models/detection/retinanet.py index 820199df046..2821c4913b9 100644 --- a/torchvision/models/detection/retinanet.py +++ b/torchvision/models/detection/retinanet.py @@ -3,8 +3,8 @@ import warnings import torch -from torch import nn -import torch.nn.functional as F +import torch.nn as nn +from torch import Tensor from torch.jit.annotations import Dict, List, Tuple from ..utils import load_state_dict_from_url @@ -39,7 +39,7 @@ def __init__(self, in_channels, num_anchors, num_classes): self.regression_head = RetinaNetRegressionHead(in_channels, num_anchors) def compute_loss(self, targets, head_outputs, anchors, matched_idxs): - # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor], List[Tensor]) -> Tensor + # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor], List[Tensor]) -> Dict[str, Tensor] return { 'classification': self.classification_head.compute_loss(targets, head_outputs, matched_idxs), 'bbox_regression': self.regression_head.compute_loss(targets, head_outputs, anchors, matched_idxs), @@ -84,9 +84,14 @@ def __init__(self, in_channels, num_anchors, num_classes, prior_probability=0.01 self.num_classes = num_classes self.num_anchors = num_anchors + # This is to fix using det_utils.Matcher.BETWEEN_THRESHOLDS in TorchScript. + # TorchScript doesn't support class attributes. 
+ # https://github.com/pytorch/vision/pull/1697#issuecomment-630255584 + self.BETWEEN_THRESHOLDS = det_utils.Matcher.BETWEEN_THRESHOLDS + def compute_loss(self, targets, head_outputs, matched_idxs): # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor]) -> Tensor - loss = [] + loss = torch.tensor(0.0) cls_logits = head_outputs['cls_logits'] @@ -95,7 +100,7 @@ def compute_loss(self, targets, head_outputs, matched_idxs): if matched_idxs_per_image.numel() == 0: gt_classes_target = torch.zeros_like(cls_logits_per_image) valid_idxs_per_image = torch.arange(cls_logits_per_image.shape[0]) - num_foreground = 0 + num_foreground = torch.tensor(0.0) else: # determine only the foreground foreground_idxs_per_image = matched_idxs_per_image >= 0 @@ -109,16 +114,16 @@ def compute_loss(self, targets, head_outputs, matched_idxs): ] = torch.tensor(1.0) # find indices for which anchors should be ignored - valid_idxs_per_image = matched_idxs_per_image != det_utils.Matcher.BETWEEN_THRESHOLDS + valid_idxs_per_image = matched_idxs_per_image != self.BETWEEN_THRESHOLDS # compute the classification loss - loss.append(sigmoid_focal_loss( + loss += sigmoid_focal_loss( cls_logits_per_image[valid_idxs_per_image], gt_classes_target[valid_idxs_per_image], reduction='sum', - ) / max(1, num_foreground)) + ) / max(1, num_foreground) - return sum(loss) / len(loss) + return loss / len(targets) def forward(self, x): # type: (List[Tensor]) -> Tensor @@ -170,7 +175,7 @@ def __init__(self, in_channels, num_anchors): def compute_loss(self, targets, head_outputs, anchors, matched_idxs): # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor], List[Tensor]) -> Tensor - loss = [] + loss = torch.tensor(0.0) bbox_regression = head_outputs['bbox_regression'] @@ -200,15 +205,13 @@ def compute_loss(self, targets, head_outputs, anchors, matched_idxs): target_regression = self.box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image) # compute the loss - loss.append( - det_utils.smooth_l1_loss( - bbox_regression_per_image, - target_regression, - size_average=False - ) / max(1, num_foreground) - ) + loss += det_utils.smooth_l1_loss( + bbox_regression_per_image, + target_regression, + size_average=False + ) / max(1, num_foreground) - return sum(loss) / max(1, len(loss)) + return loss / max(1, len(targets)) def forward(self, x): # type: (List[Tensor]) -> Tensor @@ -379,7 +382,7 @@ def eager_outputs(self, losses, detections): return detections def compute_loss(self, targets, head_outputs, anchors): - # type: (List[Dict[str, Tensor]], List[Tensor], List[Tensor]) -> Dict[str, Tensor] + # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor]) -> Dict[str, Tensor] matched_idxs = [] for anchors_per_image, targets_per_image in zip(anchors, targets): if targets_per_image['boxes'].numel() == 0: @@ -392,7 +395,7 @@ def compute_loss(self, targets, head_outputs, anchors): return self.head.compute_loss(targets, head_outputs, anchors, matched_idxs) def postprocess_detections(self, head_outputs, anchors, image_shapes): - # type: (Dict[str, Tensor], List[Tensor], List[Tuple[int, int]]) -> Dict[str, Tensor] + # type: (Dict[str, Tensor], List[Tensor], List[Tuple[int, int]]) -> List[Dict[str, Tensor]] # TODO: Merge this with roi_heads.RoIHeads.postprocess_detections ? 
class_logits = head_outputs.pop('cls_logits') @@ -408,7 +411,7 @@ def postprocess_detections(self, head_outputs, anchors, image_shapes): labels = torch.arange(num_classes, device=device) labels = labels.view(1, -1).expand_as(scores) - detections = [] + detections = torch.jit.annotate(List[Dict[str, Tensor]], []) for index, (box_regression_per_image, scores_per_image, labels_per_image, anchors_per_image, image_shape) in \ enumerate(zip(box_regression, scores, labels, anchors, image_shapes)): @@ -421,7 +424,7 @@ def postprocess_detections(self, head_outputs, anchors, image_shapes): image_boxes = [] image_scores = [] image_labels = [] - image_other_outputs = {k: [] for k in other_outputs.keys()} + image_other_outputs = torch.jit.annotate(Dict[str, List[Tensor]], {}) for class_index in range(num_classes): # remove low scoring boxes @@ -450,6 +453,8 @@ def postprocess_detections(self, head_outputs, anchors, image_shapes): image_labels.append(labels_per_class) for k, v in other_outputs_per_class: + if k not in image_other_outputs: + image_other_outputs[k] = [] image_other_outputs[k].append(v) detections.append({ @@ -510,7 +515,7 @@ def forward(self, images, targets=None): anchors = self.anchor_generator(images, features) losses = {} - detections = torch.jit.annotate(List[Dict[str, torch.Tensor]], []) + detections = torch.jit.annotate(List[Dict[str, Tensor]], []) if self.training: assert targets is not None From 954505990df0298214e9aee8279a807f1ec9312c Mon Sep 17 00:00:00 2001 From: Hans Gaiser Date: Fri, 22 May 2020 10:15:57 +0200 Subject: [PATCH 29/31] Fix inconsistency with matched_idxs. --- torchvision/models/detection/retinanet.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/torchvision/models/detection/retinanet.py b/torchvision/models/detection/retinanet.py index 2821c4913b9..9d247d84f1f 100644 --- a/torchvision/models/detection/retinanet.py +++ b/torchvision/models/detection/retinanet.py @@ -182,8 +182,7 @@ def compute_loss(self, targets, head_outputs, anchors, matched_idxs): for targets_per_image, bbox_regression_per_image, anchors_per_image, matched_idxs_per_image in \ zip(targets, bbox_regression, anchors, matched_idxs): # no matched_idxs means there were no annotations in this image - if matched_idxs_per_image is None: - loss.append(0) + if matched_idxs_per_image.numel() == 0: continue # get the targets corresponding GT for each proposal @@ -386,7 +385,7 @@ def compute_loss(self, targets, head_outputs, anchors): matched_idxs = [] for anchors_per_image, targets_per_image in zip(anchors, targets): if targets_per_image['boxes'].numel() == 0: - matched_idxs.append(torch.empty((0,))) + matched_idxs.append(torch.empty((0,), dtype=torch.int32)) continue match_quality_matrix = box_ops.box_iou(targets_per_image['boxes'], anchors_per_image) From 48659523c85eea16d8be9c461de43dde7b44a69d Mon Sep 17 00:00:00 2001 From: Hans Gaiser Date: Tue, 26 May 2020 18:11:45 +0200 Subject: [PATCH 30/31] Add retinanet model test. 
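On the TorchScript changes above: the scripting compiler types an untyped empty list literal as List[Tensor], so containers that will hold per-image dicts have to be annotated explicitly, which is what torch.jit.annotate does for `detections` and `image_other_outputs`. A minimal standalone sketch of the pattern (function name and inputs are illustrative):

    import torch
    from torch import Tensor
    from typing import Dict, List

    @torch.jit.script
    def as_detections(per_image_scores: List[Tensor]) -> List[Dict[str, Tensor]]:
        # without the annotation the empty list would be inferred as List[Tensor]
        # and appending a dict would fail to compile
        detections = torch.jit.annotate(List[Dict[str, Tensor]], [])
        for scores in per_image_scores:
            detections.append({'scores': scores})
        return detections

    print(as_detections([torch.rand(3), torch.rand(5)]))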
--- ...lTester.test_retinanet_resnet50_fpn_expect.pkl | Bin 0 -> 592 bytes 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 test/expect/ModelTester.test_retinanet_resnet50_fpn_expect.pkl diff --git a/test/expect/ModelTester.test_retinanet_resnet50_fpn_expect.pkl b/test/expect/ModelTester.test_retinanet_resnet50_fpn_expect.pkl new file mode 100644 index 0000000000000000000000000000000000000000..aed2976cb6ce0b9198d495ceb8c3a4daa7a75903 GIT binary patch literal 592 zcmZvY%}(1u6ovhlP^Jyg0&Qt&Vd-w75KJ8J5gWJsxSQD^%dtnXtC1&nJVtG$NNhl9 zpP>)XSK|?oSa5BJC<{E2r8#r&cka2@#$Uo8{+-h7tb~lyZdj^}7bKKyRAxL070U}D!>c|4 z)3)hK<$;yan1R)_kP|K<^}P`ZwCEbvc#h-soP#5`ci8s^(5}acCnZjxoho>7;cIxT zTUhliU%v%1i!nK8vBKa5Ig<&KEIH0%Md0O^0A4BR1h7Wim47TGPQ0GtK&h2342Rxe z=z4D7IT*l3+8DjAjc(wLuiwDtEUj2u3bq4yTeEmHV!VKN4Qn3gE3Rsh=lxS2#2>ou xaL{*purtShnBzZI_}xD>{*!{w0qm7gvuFFcetk(?%xpY~%bPP<7b8FWJp$d3nMD8q literal 0 HcmV?d00001 From 6e065be6f1857cdf91d2eb674091ca83ae7e473f Mon Sep 17 00:00:00 2001 From: Hans Gaiser Date: Fri, 25 Sep 2020 12:09:37 +0200 Subject: [PATCH 31/31] Add missing JIT annotations. --- torchvision/models/detection/retinanet.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/torchvision/models/detection/retinanet.py b/torchvision/models/detection/retinanet.py index 9d247d84f1f..3e20dd04e9e 100644 --- a/torchvision/models/detection/retinanet.py +++ b/torchvision/models/detection/retinanet.py @@ -152,6 +152,9 @@ class RetinaNetRegressionHead(nn.Module): in_channels (int): number of channels of the input feature num_anchors (int): number of anchors to be predicted """ + __annotations__ = { + 'box_coder': det_utils.BoxCoder, + } def __init__(self, in_channels, num_anchors): super().__init__() @@ -314,6 +317,10 @@ class RetinaNet(nn.Module): >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)] >>> predictions = model(x) """ + __annotations__ = { + 'box_coder': det_utils.BoxCoder, + 'proposal_matcher': det_utils.Matcher, + } def __init__(self, backbone, num_classes, # transform parameters