16
16
17
17
18
18
def fastrcnn_loss (class_logits , box_regression , labels , regression_targets ):
19
- # type: (Tensor, Tensor, List[Tensor], List[Tensor])
19
+ # type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> (Tensor, Tensor)
20
20
"""
21
21
Computes the loss for Faster R-CNN.
22
22
@@ -55,7 +55,7 @@ def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
55
55
56
56
57
57
def maskrcnn_inference (x , labels ):
58
- # type: (Tensor, List[Tensor])
58
+ # type: (Tensor, List[Tensor]) -> List[Tensor]
59
59
"""
60
60
From the results of the CNN, post process the masks
61
61
by taking the mask corresponding to the class with max
@@ -91,7 +91,7 @@ def maskrcnn_inference(x, labels):
91
91
92
92
93
93
def project_masks_on_boxes (gt_masks , boxes , matched_idxs , M ):
94
- # type: (Tensor, Tensor, Tensor, int)
94
+ # type: (Tensor, Tensor, Tensor, int) -> Tensor
95
95
"""
96
96
Given segmentation masks and the bounding boxes corresponding
97
97
to the location of the masks in the image, this function
@@ -106,7 +106,7 @@ def project_masks_on_boxes(gt_masks, boxes, matched_idxs, M):
106
106
107
107
108
108
def maskrcnn_loss (mask_logits , proposals , gt_masks , gt_labels , mask_matched_idxs ):
109
- # type: (Tensor, List[Tensor], List[Tensor], List[Tensor], List[Tensor])
109
+ # type: (Tensor, List[Tensor], List[Tensor], List[Tensor], List[Tensor]) -> Tensor
110
110
"""
111
111
Arguments:
112
112
proposals (list[BoxList])
@@ -139,7 +139,7 @@ def maskrcnn_loss(mask_logits, proposals, gt_masks, gt_labels, mask_matched_idxs
139
139
140
140
141
141
def keypoints_to_heatmap (keypoints , rois , heatmap_size ):
142
- # type: (Tensor, Tensor, int)
142
+ # type: (Tensor, Tensor, int) -> (Tensor, Tensor)
143
143
offset_x = rois [:, 0 ]
144
144
offset_y = rois [:, 1 ]
145
145
scale_x = heatmap_size / (rois [:, 2 ] - rois [:, 0 ])
@@ -283,7 +283,7 @@ def heatmaps_to_keypoints(maps, rois):
283
283
284
284
285
285
def keypointrcnn_loss (keypoint_logits , proposals , gt_keypoints , keypoint_matched_idxs ):
286
- # type: (Tensor, List[Tensor], List[Tensor], List[Tensor])
286
+ # type: (Tensor, List[Tensor], List[Tensor], List[Tensor]) -> Tensor
287
287
N , K , H , W = keypoint_logits .shape
288
288
assert H == W
289
289
discretization_size = H
@@ -313,7 +313,7 @@ def keypointrcnn_loss(keypoint_logits, proposals, gt_keypoints, keypoint_matched
313
313
314
314
315
315
def keypointrcnn_inference (x , boxes ):
316
- # type: (Tensor, List[Tensor])
316
+ # type: (Tensor, List[Tensor]) -> (List[Tensor], List[Tensor])
317
317
kp_probs = []
318
318
kp_scores = []
319
319
@@ -335,7 +335,7 @@ def keypointrcnn_inference(x, boxes):
335
335
336
336
337
337
def _onnx_expand_boxes (boxes , scale ):
338
- # type: (Tensor, float)
338
+ # type: (Tensor, float) -> Tensor
339
339
w_half = (boxes [:, 2 ] - boxes [:, 0 ]) * .5
340
340
h_half = (boxes [:, 3 ] - boxes [:, 1 ]) * .5
341
341
x_c = (boxes [:, 2 ] + boxes [:, 0 ]) * .5
@@ -356,7 +356,7 @@ def _onnx_expand_boxes(boxes, scale):
356
356
# but are kept here for the moment while we need them
357
357
# temporarily for paste_mask_in_image
358
358
def expand_boxes (boxes , scale ):
359
- # type: (Tensor, float)
359
+ # type: (Tensor, float) -> Tensor
360
360
if torchvision ._is_tracing ():
361
361
return _onnx_expand_boxes (boxes , scale )
362
362
w_half = (boxes [:, 2 ] - boxes [:, 0 ]) * .5
@@ -382,7 +382,7 @@ def expand_masks_tracing_scale(M, padding):
382
382
383
383
384
384
def expand_masks (mask , padding ):
385
- # type: (Tensor, int)
385
+ # type: (Tensor, int) -> (Tensor, float)
386
386
M = mask .shape [- 1 ]
387
387
if torch ._C ._get_tracing_state (): # could not import is_tracing(), not sure why
388
388
scale = expand_masks_tracing_scale (M , padding )
@@ -393,7 +393,7 @@ def expand_masks(mask, padding):
393
393
394
394
395
395
def paste_mask_in_image (mask , box , im_h , im_w ):
396
- # type: (Tensor, Tensor, int, int)
396
+ # type: (Tensor, Tensor, int, int) -> Tensor
397
397
TO_REMOVE = 1
398
398
w = int (box [2 ] - box [0 ] + TO_REMOVE )
399
399
h = int (box [3 ] - box [1 ] + TO_REMOVE )
@@ -471,7 +471,7 @@ def _onnx_paste_masks_in_image_loop(masks, boxes, im_h, im_w):
471
471
472
472
473
473
def paste_masks_in_image (masks , boxes , img_shape , padding = 1 ):
474
- # type: (Tensor, Tensor, Tuple[int, int], int)
474
+ # type: (Tensor, Tensor, Tuple[int, int], int) -> Tensor
475
475
masks , scale = expand_masks (masks , padding = padding )
476
476
boxes = expand_boxes (boxes , scale ).to (dtype = torch .int64 )
477
477
im_h , im_w = img_shape
@@ -570,7 +570,7 @@ def has_keypoint(self):
570
570
return True
571
571
572
572
def assign_targets_to_proposals (self , proposals , gt_boxes , gt_labels ):
573
- # type: (List[Tensor], List[Tensor], List[Tensor])
573
+ # type: (List[Tensor], List[Tensor], List[Tensor]) -> (List[Tensor], List[Tensor])
574
574
matched_idxs = []
575
575
labels = []
576
576
for proposals_in_image , gt_boxes_in_image , gt_labels_in_image in zip (proposals , gt_boxes , gt_labels ):
@@ -607,7 +607,7 @@ def assign_targets_to_proposals(self, proposals, gt_boxes, gt_labels):
607
607
return matched_idxs , labels
608
608
609
609
def subsample (self , labels ):
610
- # type: (List[Tensor])
610
+ # type: (List[Tensor]) -> List[Tensor]
611
611
sampled_pos_inds , sampled_neg_inds = self .fg_bg_sampler (labels )
612
612
sampled_inds = []
613
613
for img_idx , (pos_inds_img , neg_inds_img ) in enumerate (
@@ -618,7 +618,7 @@ def subsample(self, labels):
618
618
return sampled_inds
619
619
620
620
def add_gt_proposals (self , proposals , gt_boxes ):
621
- # type: (List[Tensor], List[Tensor])
621
+ # type: (List[Tensor], List[Tensor]) -> List[Tensor]
622
622
proposals = [
623
623
torch .cat ((proposal , gt_box ))
624
624
for proposal , gt_box in zip (proposals , gt_boxes )
@@ -627,22 +627,25 @@ def add_gt_proposals(self, proposals, gt_boxes):
627
627
return proposals
628
628
629
629
def DELTEME_all (self , the_list ):
630
- # type: (List[bool])
630
+ # type: (List[bool]) -> bool
631
631
for i in the_list :
632
632
if not i :
633
633
return False
634
634
return True
635
635
636
636
def check_targets (self , targets ):
637
- # type: (Optional[List[Dict[str, Tensor]]])
637
+ # type: (Optional[List[Dict[str, Tensor]]]) -> None
638
638
assert targets is not None
639
639
assert self .DELTEME_all (["boxes" in t for t in targets ])
640
640
assert self .DELTEME_all (["labels" in t for t in targets ])
641
641
if self .has_mask ():
642
642
assert self .DELTEME_all (["masks" in t for t in targets ])
643
643
644
- def select_training_samples (self , proposals , targets ):
645
- # type: (List[Tensor], Optional[List[Dict[str, Tensor]]])
644
+ def select_training_samples (self ,
645
+ proposals , # type: List[Tensor]
646
+ targets # type: Optional[List[Dict[str, Tensor]]]
647
+ ):
648
+ # type: (...) -> (List[Tensor], List[Tensor], List[Tensor], List[Tensor])
646
649
self .check_targets (targets )
647
650
assert targets is not None
648
651
dtype = proposals [0 ].dtype
@@ -674,8 +677,13 @@ def select_training_samples(self, proposals, targets):
674
677
regression_targets = self .box_coder .encode (matched_gt_boxes , proposals )
675
678
return proposals , matched_idxs , labels , regression_targets
676
679
677
- def postprocess_detections (self , class_logits , box_regression , proposals , image_shapes ):
678
- # type: (Tensor, Tensor, List[Tensor], List[Tuple[int, int]])
680
+ def postprocess_detections (self ,
681
+ class_logits , # type: Tensor
682
+ box_regression , # type: Tensor
683
+ proposals , # type: List[Tensor]
684
+ image_shapes # type: List[Tuple[int, int]]
685
+ ):
686
+ # type: (...) -> (List[Tensor], List[Tensor], List[Tensor])
679
687
device = class_logits .device
680
688
num_classes = class_logits .shape [- 1 ]
681
689
@@ -734,8 +742,13 @@ def postprocess_detections(self, class_logits, box_regression, proposals, image_
734
742
735
743
return all_boxes , all_scores , all_labels
736
744
737
- def forward (self , features , proposals , image_shapes , targets = None ):
738
- # type: (Dict[str, Tensor], List[Tensor], List[Tuple[int, int]], Optional[List[Dict[str, Tensor]]])
745
+ def forward (self ,
746
+ features , # type: Dict[str, Tensor]
747
+ proposals , # type: List[Tensor]
748
+ image_shapes , # type: List[Tuple[int, int]]
749
+ targets = None # type: Optional[List[Dict[str, Tensor]]]
750
+ ):
751
+ # type: (...) -> (List[Dict[str, Tensor]], Dict[str, Tensor])
739
752
"""
740
753
Arguments:
741
754
features (List[Tensor])
0 commit comments