diff --git a/mypy.ini b/mypy.ini
index dbcaab8770c..4f7271b4a1f 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -33,10 +33,6 @@ ignore_errors = True
 
 ignore_errors = True
 
-[mypy-torchvision.models.detection.roi_heads]
-
-ignore_errors = True
-
 [mypy-torchvision.models.detection.generalized_rcnn]
 
 ignore_errors = True
diff --git a/torchvision/models/detection/roi_heads.py b/torchvision/models/detection/roi_heads.py
index 35aee4b7d54..219774525f9 100644
--- a/torchvision/models/detection/roi_heads.py
+++ b/torchvision/models/detection/roi_heads.py
@@ -5,13 +5,14 @@
 import torchvision
 from torch import nn, Tensor
 from torchvision.ops import boxes as box_ops
-from torchvision.ops import roi_align
+from torchvision.ops.roi_align import roi_align
 
 from . import _utils as det_utils
 
 
-def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
-    # type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
+def fastrcnn_loss(
+    class_logits: Tensor, box_regression: Tensor, labels: List[Tensor], regression_targets: List[Tensor]
+) -> Tuple[Tensor, Tensor]:
     """
     Computes the loss for Faster R-CNN.
 
@@ -50,8 +51,7 @@ def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
     return classification_loss, box_loss
 
 
-def maskrcnn_inference(x, labels):
-    # type: (Tensor, List[Tensor]) -> List[Tensor]
+def maskrcnn_inference(x: Tensor, labels: List[Tensor]) -> List[Tensor]:
     """
     From the results of the CNN, post process the masks
     by taking the mask corresponding to the class with max
@@ -80,8 +80,7 @@ def maskrcnn_inference(x, labels):
     return mask_prob
 
 
-def project_masks_on_boxes(gt_masks, boxes, matched_idxs, M):
-    # type: (Tensor, Tensor, Tensor, int) -> Tensor
+def project_masks_on_boxes(gt_masks: Tensor, boxes: Tensor, matched_idxs: Tensor, M: int) -> Tensor:
     """
     Given segmentation masks and the bounding boxes corresponding
     to the location of the masks in the image, this function
@@ -95,15 +94,22 @@ def project_masks_on_boxes(gt_masks, boxes, matched_idxs, M):
     return roi_align(gt_masks, rois, (M, M), 1.0)[:, 0]
 
 
-def maskrcnn_loss(mask_logits, proposals, gt_masks, gt_labels, mask_matched_idxs):
-    # type: (Tensor, List[Tensor], List[Tensor], List[Tensor], List[Tensor]) -> Tensor
+def maskrcnn_loss(
+    mask_logits: Tensor,
+    proposals: List[Tensor],
+    gt_masks: List[Tensor],
+    gt_labels: List[Tensor],
+    mask_matched_idxs: List[Tensor],
+) -> Tensor:
     """
     Args:
-        proposals (list[BoxList])
         mask_logits (Tensor)
-        targets (list[BoxList])
+        proposals (list[Tensor])
+        gt_masks (list[Tensor])
+        gt_labels (list[Tensor])
+        mask_matched_idxs (list[Tensor])
 
-    Return:
+    Returns:
         mask_loss (Tensor): scalar tensor containing the loss
     """
 
@@ -127,8 +133,7 @@ def maskrcnn_loss(mask_logits, proposals, gt_masks, gt_labels, mask_matched_idxs
     return mask_loss
 
 
-def keypoints_to_heatmap(keypoints, rois, heatmap_size):
-    # type: (Tensor, Tensor, int) -> Tuple[Tensor, Tensor]
+def keypoints_to_heatmap(keypoints: Tensor, rois: Tensor, heatmap_size: int) -> Tuple[Tensor, Tensor]:
     offset_x = rois[:, 0]
     offset_y = rois[:, 1]
     scale_x = heatmap_size / (rois[:, 2] - rois[:, 0])
@@ -164,7 +169,14 @@ def keypoints_to_heatmap(keypoints, rois, heatmap_size):
 
 
 def _onnx_heatmaps_to_keypoints(
-    maps, maps_i, roi_map_width, roi_map_height, widths_i, heights_i, offset_x_i, offset_y_i
+    maps: Tensor,
+    maps_i: Tensor,
+    roi_map_width: Tensor,
+    roi_map_height: Tensor,
+    widths_i: Tensor,
+    heights_i: Tensor,
+    offset_x_i: Tensor,
+    offset_y_i: Tensor,
 ):
     num_keypoints = torch.scalar_tensor(maps.size(1), dtype=torch.int64)
 
@@ -202,7 +214,7 @@ def _onnx_heatmaps_to_keypoints(
 
     # TODO: simplify when indexing without rank will be supported by ONNX
     base = num_keypoints * num_keypoints + num_keypoints + 1
-    ind = torch.arange(num_keypoints)
+    ind = torch.arange(int(num_keypoints))
     ind = ind.to(dtype=torch.int64) * base
     end_scores_i = (
         roi_map.index_select(1, y_int.to(dtype=torch.int64))
@@ -216,8 +228,17 @@ def _onnx_heatmaps_to_keypoints(
 
 @torch.jit._script_if_tracing
 def _onnx_heatmaps_to_keypoints_loop(
-    maps, rois, widths_ceil, heights_ceil, widths, heights, offset_x, offset_y, num_keypoints
-):
+    maps: Tensor,
+    rois: Tensor,
+    widths_ceil: List[Tensor],
+    heights_ceil: List[Tensor],
+    widths: List[Tensor],
+    heights: List[Tensor],
+    offset_x: List[Tensor],
+    offset_y: List[Tensor],
+    num_keypoints: Tensor,
+) -> Tuple[Tensor, Tensor]:
+
     xy_preds = torch.zeros((0, 3, int(num_keypoints)), dtype=torch.float32, device=maps.device)
     end_scores = torch.zeros((0, int(num_keypoints)), dtype=torch.float32, device=maps.device)
 
@@ -232,7 +253,7 @@ def _onnx_heatmaps_to_keypoints_loop(
     return xy_preds, end_scores
 
 
-def heatmaps_to_keypoints(maps, rois):
+def heatmaps_to_keypoints(maps: Tensor, rois: Tensor) -> Tuple[Tensor, Tensor]:
     """Extract predicted keypoint locations from heatmaps. Output has shape
     (#rois, 4, #keypoints) with the 4 rows corresponding to (x, y, logit, prob)
     for each keypoint.
@@ -296,8 +317,10 @@ def heatmaps_to_keypoints(maps, rois):
     return xy_preds.permute(0, 2, 1), end_scores
 
 
-def keypointrcnn_loss(keypoint_logits, proposals, gt_keypoints, keypoint_matched_idxs):
-    # type: (Tensor, List[Tensor], List[Tensor], List[Tensor]) -> Tensor
+def keypointrcnn_loss(
+    keypoint_logits: Tensor, proposals: List[Tensor], gt_keypoints: List[Tensor], keypoint_matched_idxs: List[Tensor]
+) -> Tensor:
+
     N, K, H, W = keypoint_logits.shape
     assert H == W
     discretization_size = H
@@ -324,8 +347,7 @@ def keypointrcnn_loss(keypoint_logits, proposals, gt_keypoints, keypoint_matched
     return keypoint_loss
 
 
-def keypointrcnn_inference(x, boxes):
-    # type: (Tensor, List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
+def keypointrcnn_inference(x: Tensor, boxes: List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]:
     kp_probs = []
     kp_scores = []
 
@@ -340,8 +362,7 @@ def keypointrcnn_inference(x, boxes):
     return kp_probs, kp_scores
 
 
-def _onnx_expand_boxes(boxes, scale):
-    # type: (Tensor, float) -> Tensor
+def _onnx_expand_boxes(boxes: Tensor, scale: float) -> Tensor:
     w_half = (boxes[:, 2] - boxes[:, 0]) * 0.5
     h_half = (boxes[:, 3] - boxes[:, 1]) * 0.5
     x_c = (boxes[:, 2] + boxes[:, 0]) * 0.5
@@ -361,8 +382,7 @@ def _onnx_expand_boxes(boxes, scale):
 # the next two functions should be merged inside Masker
 # but are kept here for the moment while we need them
 # temporarily for paste_mask_in_image
-def expand_boxes(boxes, scale):
-    # type: (Tensor, float) -> Tensor
+def expand_boxes(boxes: Tensor, scale: float) -> Tensor:
     if torchvision._is_tracing():
         return _onnx_expand_boxes(boxes, scale)
     w_half = (boxes[:, 2] - boxes[:, 0]) * 0.5
@@ -382,13 +402,11 @@ def expand_boxes(boxes, scale):
 
 
 @torch.jit.unused
-def expand_masks_tracing_scale(M, padding):
-    # type: (int, int) -> float
-    return torch.tensor(M + 2 * padding).to(torch.float32) / torch.tensor(M).to(torch.float32)
+def expand_masks_tracing_scale(M: int, padding: int) -> float:
+    return torch.tensor(M + 2 * padding).to(torch.float32) / torch.tensor(M).to(torch.float32)  # type: ignore[return-value]
 
 
-def expand_masks(mask, padding):
-    # type: (Tensor, int) -> Tuple[Tensor, float]
+def expand_masks(mask: Tensor, padding: int) -> Tuple[Tensor, float]:
     M = mask.shape[-1]
     if torch._C._get_tracing_state():  # could not import is_tracing(), not sure why
         scale = expand_masks_tracing_scale(M, padding)
@@ -398,15 +416,15 @@ def expand_masks(mask, padding):
     return padded_mask, scale
 
 
-def paste_mask_in_image(mask, box, im_h, im_w):
-    # type: (Tensor, Tensor, int, int) -> Tensor
+def paste_mask_in_image(mask: Tensor, box: Tensor, im_h: int, im_w: int) -> Tensor:
     TO_REMOVE = 1
+    # Here Box is a single tensor E.g. box = torch.tensor([10, 20, 40, 60])
     w = int(box[2] - box[0] + TO_REMOVE)
     h = int(box[3] - box[1] + TO_REMOVE)
     w = max(w, 1)
     h = max(h, 1)
 
-    # Set shape to [batchxCxHxW]
+    # Set shape to [batch x C x H x W]
     mask = mask.expand((1, 1, -1, -1))
 
     # Resize mask
@@ -414,16 +432,18 @@ def paste_mask_in_image(mask, box, im_h, im_w):
     mask = mask[0][0]
 
     im_mask = torch.zeros((im_h, im_w), dtype=mask.dtype, device=mask.device)
-    x_0 = max(box[0], 0)
-    x_1 = min(box[2] + 1, im_w)
-    y_0 = max(box[1], 0)
-    y_1 = min(box[3] + 1, im_h)
-    im_mask[y_0:y_1, x_0:x_1] = mask[(y_0 - box[1]) : (y_1 - box[1]), (x_0 - box[0]) : (x_1 - box[0])]
+    # Since we are slicing, we need int
+    x_0 = int(max(box[0].item(), 0))
+    x_1 = int(min(box[2].item() + 1, im_w))
+    y_0 = int(max(box[1].item(), 0))
+    y_1 = int(min(box[3].item() + 1, im_h))
+
+    im_mask[y_0:y_1, x_0:x_1] = mask[int(y_0 - box[1]) : int(y_1 - box[1]), int(x_0 - box[0]) : int(x_1 - box[0])]
 
     return im_mask
 
 
-def _onnx_paste_mask_in_image(mask, box, im_h, im_w):
+def _onnx_paste_mask_in_image(mask: Tensor, box: Tensor, im_h: Tensor, im_w: Tensor) -> Tensor:
     one = torch.ones(1, dtype=torch.int64)
     zero = torch.zeros(1, dtype=torch.int64)
 
@@ -444,24 +464,24 @@ def _onnx_paste_mask_in_image(mask, box, im_h, im_w):
     y_0 = torch.max(torch.cat((box[1].unsqueeze(0), zero)))
     y_1 = torch.min(torch.cat((box[3].unsqueeze(0) + one, im_h.unsqueeze(0))))
 
-    unpaded_im_mask = mask[(y_0 - box[1]) : (y_1 - box[1]), (x_0 - box[0]) : (x_1 - box[0])]
+    unpaded_im_mask = mask[int(y_0 - box[1]) : int(y_1 - box[1]), int(x_0 - box[0]) : int(x_1 - box[0])]
 
     # TODO : replace below with a dynamic padding when support is added in ONNX
     # pad y
-    zeros_y0 = torch.zeros(y_0, unpaded_im_mask.size(1))
-    zeros_y1 = torch.zeros(im_h - y_1, unpaded_im_mask.size(1))
-    concat_0 = torch.cat((zeros_y0, unpaded_im_mask.to(dtype=torch.float32), zeros_y1), 0)[0:im_h, :]
+    zeros_y0 = torch.zeros(int(y_0), unpaded_im_mask.size(1))
+    zeros_y1 = torch.zeros(int(im_h - y_1), unpaded_im_mask.size(1))
+    concat_0 = torch.cat((zeros_y0, unpaded_im_mask.to(dtype=torch.float32), zeros_y1), 0)[0 : int(im_h), :]
 
     # pad x
-    zeros_x0 = torch.zeros(concat_0.size(0), x_0)
-    zeros_x1 = torch.zeros(concat_0.size(0), im_w - x_1)
-    im_mask = torch.cat((zeros_x0, concat_0, zeros_x1), 1)[:, :im_w]
+    zeros_x0 = torch.zeros(concat_0.size(0), int(x_0))
+    zeros_x1 = torch.zeros(concat_0.size(0), int(im_w - x_1))
+    im_mask = torch.cat((zeros_x0, concat_0, zeros_x1), 1)[:, : int(im_w)]
     return im_mask
 
 
 @torch.jit._script_if_tracing
-def _onnx_paste_masks_in_image_loop(masks, boxes, im_h, im_w):
-    res_append = torch.zeros(0, im_h, im_w)
+def _onnx_paste_masks_in_image_loop(masks: Tensor, boxes: Tensor, im_h: Tensor, im_w: Tensor) -> Tensor:
+    res_append = torch.zeros(0, int(im_h), int(im_w))
     for i in range(masks.size(0)):
         mask_res = _onnx_paste_mask_in_image(masks[i][0], boxes[i], im_h, im_w)
         mask_res = mask_res.unsqueeze(0)
@@ -469,8 +489,7 @@ def _onnx_paste_masks_in_image_loop(masks, boxes, im_h, im_w):
     return res_append
 
 
-def paste_masks_in_image(masks, boxes, img_shape, padding=1):
-    # type: (Tensor, Tensor, Tuple[int, int], int) -> Tensor
+def paste_masks_in_image(masks: Tensor, boxes: Tensor, img_shape: Tuple[int, int], padding: int = 1) -> Tensor:
     masks, scale = expand_masks(masks, padding=padding)
     boxes = expand_boxes(boxes, scale).to(dtype=torch.int64)
     im_h, im_w = img_shape
@@ -496,27 +515,27 @@ class RoIHeads(nn.Module):
 
     def __init__(
         self,
-        box_roi_pool,
-        box_head,
-        box_predictor,
+        box_roi_pool: nn.Module,
+        box_head: nn.Module,
+        box_predictor: nn.Module,
         # Faster R-CNN training
-        fg_iou_thresh,
-        bg_iou_thresh,
-        batch_size_per_image,
-        positive_fraction,
-        bbox_reg_weights,
+        fg_iou_thresh: float,
+        bg_iou_thresh: float,
+        batch_size_per_image: int,
+        positive_fraction: float,
+        bbox_reg_weights: Tuple[float, float, float, float],
         # Faster R-CNN inference
-        score_thresh,
-        nms_thresh,
-        detections_per_img,
+        score_thresh: float,
+        nms_thresh: float,
+        detections_per_img: int,
         # Mask
-        mask_roi_pool=None,
-        mask_head=None,
-        mask_predictor=None,
-        keypoint_roi_pool=None,
-        keypoint_head=None,
-        keypoint_predictor=None,
-    ):
+        mask_roi_pool: Optional[nn.Module] = None,
+        mask_head: Optional[nn.Module] = None,
+        mask_predictor: Optional[nn.Module] = None,
+        keypoint_roi_pool: Optional[nn.Module] = None,
+        keypoint_head: Optional[nn.Module] = None,
+        keypoint_predictor: Optional[nn.Module] = None,
+    ) -> None:
         super(RoIHeads, self).__init__()
 
         self.box_similarity = box_ops.box_iou
@@ -545,7 +564,7 @@ def __init__(
         self.keypoint_head = keypoint_head
         self.keypoint_predictor = keypoint_predictor
 
-    def has_mask(self):
+    def has_mask(self) -> bool:
         if self.mask_roi_pool is None:
             return False
         if self.mask_head is None:
@@ -554,7 +573,7 @@ def has_mask(self):
             return False
         return True
 
-    def has_keypoint(self):
+    def has_keypoint(self) -> bool:
         if self.keypoint_roi_pool is None:
             return False
         if self.keypoint_head is None:
@@ -563,8 +582,10 @@ def has_keypoint(self):
             return False
         return True
 
-    def assign_targets_to_proposals(self, proposals, gt_boxes, gt_labels):
-        # type: (List[Tensor], List[Tensor], List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
+    def assign_targets_to_proposals(
+        self, proposals: List[Tensor], gt_boxes: List[Tensor], gt_labels: List[Tensor]
+    ) -> Tuple[List[Tensor], List[Tensor]]:
+
         matched_idxs = []
         labels = []
         for proposals_in_image, gt_boxes_in_image, gt_labels_in_image in zip(proposals, gt_boxes, gt_labels):
@@ -598,8 +619,7 @@ def assign_targets_to_proposals(self, proposals, gt_boxes, gt_labels):
             labels.append(labels_in_image)
         return matched_idxs, labels
 
-    def subsample(self, labels):
-        # type: (List[Tensor]) -> List[Tensor]
+    def subsample(self, labels: List[Tensor]) -> List[Tensor]:
         sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels)
         sampled_inds = []
         for img_idx, (pos_inds_img, neg_inds_img) in enumerate(zip(sampled_pos_inds, sampled_neg_inds)):
@@ -607,14 +627,12 @@ def subsample(self, labels):
             sampled_inds.append(img_sampled_inds)
         return sampled_inds
 
-    def add_gt_proposals(self, proposals, gt_boxes):
-        # type: (List[Tensor], List[Tensor]) -> List[Tensor]
+    def add_gt_proposals(self, proposals: List[Tensor], gt_boxes: List[Tensor]) -> List[Tensor]:
         proposals = [torch.cat((proposal, gt_box)) for proposal, gt_box in zip(proposals, gt_boxes)]
 
         return proposals
 
-    def check_targets(self, targets):
-        # type: (Optional[List[Dict[str, Tensor]]]) -> None
+    def check_targets(self, targets: Optional[List[Dict[str, Tensor]]] = None) -> None:
         assert targets is not None
         assert all(["boxes" in t for t in targets])
         assert all(["labels" in t for t in targets])
@@ -623,10 +641,10 @@ def check_targets(self, targets):
 
     def select_training_samples(
         self,
-        proposals,  # type: List[Tensor]
-        targets,  # type: Optional[List[Dict[str, Tensor]]]
-    ):
-        # type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor], List[Tensor]]
+        proposals: List[Tensor],
+        targets: Optional[List[Dict[str, Tensor]]] = None,
+    ) -> Tuple[List[Tensor], List[Tensor], List[Tensor], List[Tensor]]:
+
         self.check_targets(targets)
         assert targets is not None
         dtype = proposals[0].dtype
@@ -660,12 +678,12 @@ def select_training_samples(
 
     def postprocess_detections(
         self,
-        class_logits,  # type: Tensor
-        box_regression,  # type: Tensor
-        proposals,  # type: List[Tensor]
-        image_shapes,  # type: List[Tuple[int, int]]
-    ):
-        # type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor]]
+        class_logits: Tensor,
+        box_regression: Tensor,
+        proposals: List[Tensor],
+        image_shapes: List[Tuple[int, int]],
+    ) -> Tuple[List[Tensor], List[Tensor], List[Tensor]]:
+
         device = class_logits.device
         num_classes = class_logits.shape[-1]
 
@@ -719,12 +737,11 @@ def postprocess_detections(
 
     def forward(
         self,
-        features,  # type: Dict[str, Tensor]
-        proposals,  # type: List[Tensor]
-        image_shapes,  # type: List[Tuple[int, int]]
-        targets=None,  # type: Optional[List[Dict[str, Tensor]]]
-    ):
-        # type: (...) -> Tuple[List[Dict[str, Tensor]], Dict[str, Tensor]]
+        features: Dict[str, Tensor],
+        proposals: List[Tensor],
+        image_shapes: List[Tuple[int, int]],
+        targets: Optional[List[Dict[str, Tensor]]] = None,
+    ) -> Tuple[List[Dict[str, Tensor]], Dict[str, Tensor]]:
         """
         Args:
             features (List[Tensor])
@@ -744,9 +761,9 @@ def forward(
         if self.training:
             proposals, matched_idxs, labels, regression_targets = self.select_training_samples(proposals, targets)
         else:
-            labels = None
-            regression_targets = None
-            matched_idxs = None
+            labels = None  # type: ignore[assignment]
+            regression_targets = None  # type: ignore[assignment]
+            matched_idxs = None  # type: ignore[assignment]
 
         box_features = self.box_roi_pool(features, proposals, image_shapes)
         box_features = self.box_head(box_features)
@@ -783,12 +800,12 @@ def forward(
                 mask_proposals.append(proposals[img_id][pos])
                 pos_matched_idxs.append(matched_idxs[img_id][pos])
             else:
-                pos_matched_idxs = None
+                pos_matched_idxs = None  # type: ignore[assignment]
 
             if self.mask_roi_pool is not None:
                 mask_features = self.mask_roi_pool(features, mask_proposals, image_shapes)
-                mask_features = self.mask_head(mask_features)
-                mask_logits = self.mask_predictor(mask_features)
+                mask_features = self.mask_head(mask_features)  # type: ignore[misc]
+                mask_logits = self.mask_predictor(mask_features)  # type: ignore[misc]
             else:
                 raise Exception("Expected mask_roi_pool to be not None")
 
@@ -829,7 +846,7 @@ def forward(
                 keypoint_proposals.append(proposals[img_id][pos])
                 pos_matched_idxs.append(matched_idxs[img_id][pos])
             else:
-                pos_matched_idxs = None
+                pos_matched_idxs = None  # type: ignore[assignment]
 
             keypoint_features = self.keypoint_roi_pool(features, keypoint_proposals, image_shapes)
             keypoint_features = self.keypoint_head(keypoint_features)
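
A quick smoke test of the new signatures on the free functions. This is a sketch, not part of the patch: it assumes the patched torchvision is importable, and all shapes and values below are invented for illustration.

import torch
from torchvision.models.detection.roi_heads import fastrcnn_loss

# Two images with four sampled proposals each, three classes (background + 2).
class_logits = torch.randn(8, 3)        # (total proposals, num_classes)
box_regression = torch.randn(8, 3 * 4)  # (total proposals, num_classes * 4)
labels = [torch.tensor([0, 1, 2, 0]), torch.tensor([1, 0, 0, 2])]  # List[Tensor], one entry per image
regression_targets = [torch.randn(4, 4), torch.randn(4, 4)]        # List[Tensor], one entry per image

classification_loss, box_loss = fastrcnn_loss(class_logits, box_regression, labels, regression_targets)
print(classification_loss, box_loss)  # two scalar tensors, matching the Tuple[Tensor, Tensor] annotation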
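On the int(...)/.item() casts in paste_mask_in_image and the _onnx_* helpers: slice bounds need to be plain Python ints for mypy (and TorchScript); a 0-dim Tensor only happens to work as a bound in eager mode. A standalone sketch of the same pattern, with made-up sizes:

import torch

box = torch.tensor([10, 20, 40, 60])  # xyxy, a single box as in paste_mask_in_image
im_h, im_w = 100, 100
im_mask = torch.zeros(im_h, im_w)

# Convert tensor coordinates to ints before slicing, as the patch does.
x_0 = int(max(box[0].item(), 0))
x_1 = int(min(box[2].item() + 1, im_w))
im_mask[:, x_0:x_1] = 1.0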
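These functions are compiled by TorchScript when the detection models are scripted, so the inline annotations must yield the same schemas the old `# type:` comments did. One plausible local check (again assuming the patched torchvision; not part of this diff). Since the mypy.ini hunk removes the per-module ignore_errors override, a plain `mypy --config-file mypy.ini` run should now cover roi_heads.py as well.

import torch
from torchvision.models.detection.roi_heads import fastrcnn_loss, maskrcnn_inference

scripted_loss = torch.jit.script(fastrcnn_loss)
scripted_infer = torch.jit.script(maskrcnn_inference)
print(scripted_loss.code)   # typed source should reflect the List[Tensor] / Tuple[...] annotations
print(scripted_infer.code)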