
Commit 693a671

Fix mypy type annotations
1 parent ae228fe commit 693a671


9 files changed: 94 additions, 70 deletions

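The pattern applied across all nine files is the same: PEP 484 comment-style annotations gain an explicit "-> ReturnType" part, which mypy expects in a function type comment. A minimal sketch of the idea, using a toy helper that is not part of this commit:

from typing import List

from torch import Tensor


def clamp_all(tensors, max_value):
    # type: (List[Tensor], float) -> List[Tensor]
    # Before this commit the comment would read "# type: (List[Tensor], float)";
    # spelling out the return type lets mypy check both the body and the callers.
    return [t.clamp(max=max_value) for t in tensors]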

torchvision/models/detection/_utils.py

Lines changed: 6 additions & 6 deletions
@@ -22,7 +22,7 @@ class BalancedPositiveNegativeSampler(object):
     """
 
     def __init__(self, batch_size_per_image, positive_fraction):
-        # type: (int, float)
+        # type: (int, float) -> None
         """
         Arguments:
             batch_size_per_image (int): number of elements to be selected per image

@@ -32,7 +32,7 @@ def __init__(self, batch_size_per_image, positive_fraction):
         self.positive_fraction = positive_fraction
 
     def __call__(self, matched_idxs):
-        # type: (List[Tensor])
+        # type: (List[Tensor]) -> (List[Tensor], List[Tensor])
         """
         Arguments:
             matched idxs: list of tensors containing -1, 0 or positive values.

@@ -141,7 +141,7 @@ class BoxCoder(object):
     """
 
     def __init__(self, weights, bbox_xform_clip=math.log(1000. / 16)):
-        # type: (Tuple[float, float, float, float], float)
+        # type: (Tuple[float, float, float, float], float) -> None
         """
         Arguments:
             weights (4-element tuple)

@@ -151,7 +151,7 @@ def __init__(self, weights, bbox_xform_clip=math.log(1000. / 16)):
         self.bbox_xform_clip = bbox_xform_clip
 
     def encode(self, reference_boxes, proposals):
-        # type: (List[Tensor], List[Tensor])
+        # type: (List[Tensor], List[Tensor]) -> List[Tensor]
         boxes_per_image = [len(b) for b in reference_boxes]
         reference_boxes = torch.cat(reference_boxes, dim=0)
         proposals = torch.cat(proposals, dim=0)

@@ -175,7 +175,7 @@ def encode_single(self, reference_boxes, proposals):
         return targets
 
     def decode(self, rel_codes, boxes):
-        # type: (Tensor, List[Tensor])
+        # type: (Tensor, List[Tensor]) -> Tensor
         assert isinstance(boxes, (list, tuple))
         assert isinstance(rel_codes, torch.Tensor)
         boxes_per_image = [b.size(0) for b in boxes]

@@ -253,7 +253,7 @@ class Matcher(object):
     }
 
     def __init__(self, high_threshold, low_threshold, allow_low_quality_matches=False):
-        # type: (float, float, bool)
+        # type: (float, float, bool) -> None
         """
         Args:
             high_threshold (float): quality values greater than or equal to

torchvision/models/detection/generalized_rcnn.py

Lines changed: 1 addition & 1 deletion
@@ -42,7 +42,7 @@ def eager_outputs(self, losses, detections):
         return detections
 
     def forward(self, images, targets=None):
-        # type: (List[Tensor], Optional[List[Dict[str, Tensor]]])
+        # type: (List[Tensor], Optional[List[Dict[str, Tensor]]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
         """
         Arguments:
             images (list[Tensor]): images to be processed
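For an argument that defaults to None, the comment style spells the parameter as Optional[...] and still ends with the return type. A hedged sketch with a hypothetical helper that mirrors the forward signature above (not code from this commit):

from typing import Dict, List, Optional, Tuple

from torch import Tensor


def count_inputs(images, targets=None):
    # type: (List[Tensor], Optional[List[Dict[str, Tensor]]]) -> Tuple[int, int]
    # Hypothetical helper: returns (number of images, number of targets).
    num_targets = 0 if targets is None else len(targets)
    return len(images), num_targets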

torchvision/models/detection/image_list.py

Lines changed: 2 additions & 2 deletions
@@ -16,7 +16,7 @@ class ImageList(object):
     """
 
     def __init__(self, tensors, image_sizes):
-        # type: (Tensor, List[Tuple[int, int]])
+        # type: (Tensor, List[Tuple[int, int]]) -> None
         """
         Arguments:
             tensors (tensor)

@@ -26,6 +26,6 @@ def __init__(self, tensors, image_sizes):
         self.image_sizes = image_sizes
 
     def to(self, device):
-        # type: (Device) # noqa
+        # type: (Device) -> ImageList # noqa
         cast_tensor = self.tensors.to(device)
         return ImageList(cast_tensor, self.image_sizes)
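The second hunk annotates to() as returning the class it is defined on; inside a type comment the class name can be used directly even while the class is still being defined, because the comment is never evaluated at runtime. A minimal sketch with a hypothetical stand-in class (not the real ImageList):

from typing import List, Tuple

import torch
from torch import Tensor


class TinyImageList(object):
    # Hypothetical stand-in that mirrors the ImageList pattern above.

    def __init__(self, tensors, image_sizes):
        # type: (Tensor, List[Tuple[int, int]]) -> None
        self.tensors = tensors
        self.image_sizes = image_sizes

    def to(self, device):
        # type: (torch.device) -> TinyImageList
        return TinyImageList(self.tensors.to(device), self.image_sizes)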

torchvision/models/detection/roi_heads.py

Lines changed: 36 additions & 23 deletions
@@ -16,7 +16,7 @@
 
 
 def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
-    # type: (Tensor, Tensor, List[Tensor], List[Tensor])
+    # type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> (Tensor, Tensor)
     """
     Computes the loss for Faster R-CNN.
 

@@ -55,7 +55,7 @@ def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
 
 
 def maskrcnn_inference(x, labels):
-    # type: (Tensor, List[Tensor])
+    # type: (Tensor, List[Tensor]) -> List[Tensor]
     """
     From the results of the CNN, post process the masks
     by taking the mask corresponding to the class with max

@@ -91,7 +91,7 @@ def maskrcnn_inference(x, labels):
 
 
 def project_masks_on_boxes(gt_masks, boxes, matched_idxs, M):
-    # type: (Tensor, Tensor, Tensor, int)
+    # type: (Tensor, Tensor, Tensor, int) -> Tensor
     """
     Given segmentation masks and the bounding boxes corresponding
     to the location of the masks in the image, this function

@@ -106,7 +106,7 @@ def project_masks_on_boxes(gt_masks, boxes, matched_idxs, M):
 
 
 def maskrcnn_loss(mask_logits, proposals, gt_masks, gt_labels, mask_matched_idxs):
-    # type: (Tensor, List[Tensor], List[Tensor], List[Tensor], List[Tensor])
+    # type: (Tensor, List[Tensor], List[Tensor], List[Tensor], List[Tensor]) -> Tensor
     """
     Arguments:
         proposals (list[BoxList])

@@ -139,7 +139,7 @@ def maskrcnn_loss(mask_logits, proposals, gt_masks, gt_labels, mask_matched_idxs
 
 
 def keypoints_to_heatmap(keypoints, rois, heatmap_size):
-    # type: (Tensor, Tensor, int)
+    # type: (Tensor, Tensor, int) -> (Tensor, Tensor)
     offset_x = rois[:, 0]
     offset_y = rois[:, 1]
     scale_x = heatmap_size / (rois[:, 2] - rois[:, 0])

@@ -283,7 +283,7 @@ def heatmaps_to_keypoints(maps, rois):
 
 
 def keypointrcnn_loss(keypoint_logits, proposals, gt_keypoints, keypoint_matched_idxs):
-    # type: (Tensor, List[Tensor], List[Tensor], List[Tensor])
+    # type: (Tensor, List[Tensor], List[Tensor], List[Tensor]) -> Tensor
     N, K, H, W = keypoint_logits.shape
     assert H == W
     discretization_size = H

@@ -313,7 +313,7 @@ def keypointrcnn_loss(keypoint_logits, proposals, gt_keypoints, keypoint_matched
 
 
 def keypointrcnn_inference(x, boxes):
-    # type: (Tensor, List[Tensor])
+    # type: (Tensor, List[Tensor]) -> (List[Tensor], List[Tensor])
     kp_probs = []
     kp_scores = []
 

@@ -335,7 +335,7 @@ def keypointrcnn_inference(x, boxes):
 
 
 def _onnx_expand_boxes(boxes, scale):
-    # type: (Tensor, float)
+    # type: (Tensor, float) -> Tensor
     w_half = (boxes[:, 2] - boxes[:, 0]) * .5
     h_half = (boxes[:, 3] - boxes[:, 1]) * .5
     x_c = (boxes[:, 2] + boxes[:, 0]) * .5

@@ -356,7 +356,7 @@ def _onnx_expand_boxes(boxes, scale):
 # but are kept here for the moment while we need them
 # temporarily for paste_mask_in_image
 def expand_boxes(boxes, scale):
-    # type: (Tensor, float)
+    # type: (Tensor, float) -> Tensor
     if torchvision._is_tracing():
         return _onnx_expand_boxes(boxes, scale)
     w_half = (boxes[:, 2] - boxes[:, 0]) * .5

@@ -382,7 +382,7 @@ def expand_masks_tracing_scale(M, padding):
 
 
 def expand_masks(mask, padding):
-    # type: (Tensor, int)
+    # type: (Tensor, int) -> (Tensor, float)
     M = mask.shape[-1]
     if torch._C._get_tracing_state():  # could not import is_tracing(), not sure why
         scale = expand_masks_tracing_scale(M, padding)

@@ -393,7 +393,7 @@ def expand_masks(mask, padding):
 
 
 def paste_mask_in_image(mask, box, im_h, im_w):
-    # type: (Tensor, Tensor, int, int)
+    # type: (Tensor, Tensor, int, int) -> Tensor
     TO_REMOVE = 1
     w = int(box[2] - box[0] + TO_REMOVE)
     h = int(box[3] - box[1] + TO_REMOVE)

@@ -471,7 +471,7 @@ def _onnx_paste_masks_in_image_loop(masks, boxes, im_h, im_w):
 
 
 def paste_masks_in_image(masks, boxes, img_shape, padding=1):
-    # type: (Tensor, Tensor, Tuple[int, int], int)
+    # type: (Tensor, Tensor, Tuple[int, int], int) -> Tensor
     masks, scale = expand_masks(masks, padding=padding)
     boxes = expand_boxes(boxes, scale).to(dtype=torch.int64)
     im_h, im_w = img_shape

@@ -570,7 +570,7 @@ def has_keypoint(self):
         return True
 
     def assign_targets_to_proposals(self, proposals, gt_boxes, gt_labels):
-        # type: (List[Tensor], List[Tensor], List[Tensor])
+        # type: (List[Tensor], List[Tensor], List[Tensor]) -> (List[Tensor], List[Tensor])
         matched_idxs = []
         labels = []
         for proposals_in_image, gt_boxes_in_image, gt_labels_in_image in zip(proposals, gt_boxes, gt_labels):

@@ -607,7 +607,7 @@ def assign_targets_to_proposals(self, proposals, gt_boxes, gt_labels):
         return matched_idxs, labels
 
     def subsample(self, labels):
-        # type: (List[Tensor])
+        # type: (List[Tensor]) -> List[Tensor]
         sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels)
         sampled_inds = []
         for img_idx, (pos_inds_img, neg_inds_img) in enumerate(

@@ -618,7 +618,7 @@ def subsample(self, labels):
         return sampled_inds
 
     def add_gt_proposals(self, proposals, gt_boxes):
-        # type: (List[Tensor], List[Tensor])
+        # type: (List[Tensor], List[Tensor]) -> List[Tensor]
         proposals = [
             torch.cat((proposal, gt_box))
             for proposal, gt_box in zip(proposals, gt_boxes)

@@ -627,22 +627,25 @@ def add_gt_proposals(self, proposals, gt_boxes):
         return proposals
 
     def DELTEME_all(self, the_list):
-        # type: (List[bool])
+        # type: (List[bool]) -> bool
         for i in the_list:
             if not i:
                 return False
         return True
 
     def check_targets(self, targets):
-        # type: (Optional[List[Dict[str, Tensor]]])
+        # type: (Optional[List[Dict[str, Tensor]]]) -> None
         assert targets is not None
         assert self.DELTEME_all(["boxes" in t for t in targets])
         assert self.DELTEME_all(["labels" in t for t in targets])
         if self.has_mask():
             assert self.DELTEME_all(["masks" in t for t in targets])
 
-    def select_training_samples(self, proposals, targets):
-        # type: (List[Tensor], Optional[List[Dict[str, Tensor]]])
+    def select_training_samples(self,
+                                proposals,  # type: List[Tensor]
+                                targets     # type: Optional[List[Dict[str, Tensor]]]
+                                ):
+        # type: (...) -> (List[Tensor], List[Tensor], List[Tensor], List[Tensor])
         self.check_targets(targets)
         assert targets is not None
         dtype = proposals[0].dtype

@@ -674,8 +677,13 @@ def select_training_samples(self, proposals, targets):
         regression_targets = self.box_coder.encode(matched_gt_boxes, proposals)
         return proposals, matched_idxs, labels, regression_targets
 
-    def postprocess_detections(self, class_logits, box_regression, proposals, image_shapes):
-        # type: (Tensor, Tensor, List[Tensor], List[Tuple[int, int]])
+    def postprocess_detections(self,
+                               class_logits,    # type: Tensor
+                               box_regression,  # type: Tensor
+                               proposals,       # type: List[Tensor]
+                               image_shapes     # type: List[Tuple[int, int]]
+                               ):
+        # type: (...) -> (List[Tensor], List[Tensor], List[Tensor])
         device = class_logits.device
         num_classes = class_logits.shape[-1]
 

@@ -734,8 +742,13 @@ def postprocess_detections(self, class_logits, box_regression, proposals, image_
 
         return all_boxes, all_scores, all_labels
 
-    def forward(self, features, proposals, image_shapes, targets=None):
-        # type: (Dict[str, Tensor], List[Tensor], List[Tuple[int, int]], Optional[List[Dict[str, Tensor]]])
+    def forward(self,
+                features,      # type: Dict[str, Tensor]
+                proposals,     # type: List[Tensor]
+                image_shapes,  # type: List[Tuple[int, int]]
+                targets=None   # type: Optional[List[Dict[str, Tensor]]]
+                ):
+        # type: (...) -> (List[Dict[str, Tensor]], Dict[str, Tensor])
         """
         Arguments:
             features (List[Tensor])
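The last three hunks in this file switch long signatures to the per-argument comment style: each parameter carries its own "# type:" comment and the function-level comment becomes "# type: (...) -> ...", with the argument types read from the per-parameter comments. A short sketch of that layout under the standard PEP 484 rules (toy function, not from this diff; the return is written as Tuple[...], the spelling mypy understands for a tuple return):

from typing import Dict, List, Optional, Tuple

from torch import Tensor


def select_samples(proposals,     # type: List[Tensor]
                   targets=None   # type: Optional[List[Dict[str, Tensor]]]
                   ):
    # type: (...) -> Tuple[List[Tensor], int]
    # The "(...)" placeholder tells the checker to take the argument types
    # from the trailing comments on each parameter line above.
    kept = proposals if targets is None else proposals[:len(targets)]
    return kept, len(kept)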

torchvision/models/detection/rpn.py

Lines changed: 18 additions & 14 deletions
@@ -77,7 +77,7 @@ def __init__(
     # For every (aspect_ratios, scales) combination, output a zero-centered anchor with those values.
     # (scales, aspect_ratios) are usually an element of zip(self.scales, self.aspect_ratios)
     def generate_anchors(self, scales, aspect_ratios, dtype=torch.float32, device="cpu"):
-        # type: (List[int], List[float], int, Device) # noqa: F821
+        # type: (List[int], List[float], int, Device) -> Tensor # noqa: F821
         scales = torch.as_tensor(scales, dtype=dtype, device=device)
         aspect_ratios = torch.as_tensor(aspect_ratios, dtype=dtype, device=device)
         h_ratios = torch.sqrt(aspect_ratios)

@@ -90,7 +90,7 @@ def generate_anchors(self, scales, aspect_ratios, dtype=torch.float32, device="c
         return base_anchors.round()
 
     def set_cell_anchors(self, dtype, device):
-        # type: (int, Device) -> None # noqa: F821
+        # type: (int, Device) -> None # noqa: F821
         if self.cell_anchors is not None:
             cell_anchors = self.cell_anchors
             assert cell_anchors is not None

@@ -116,7 +116,7 @@ def num_anchors_per_location(self):
     # For every combination of (a, (g, s), i) in (self.cell_anchors, zip(grid_sizes, strides), 0:2),
     # output g[i] anchors that are s[i] distance apart in direction i, with the same dimensions as a.
     def grid_anchors(self, grid_sizes, strides):
-        # type: (List[List[int]], List[List[int]])
+        # type: (List[List[int]], List[List[int]]) -> List[Tensor]
         anchors = []
         cell_anchors = self.cell_anchors
         assert cell_anchors is not None

@@ -153,7 +153,7 @@ def grid_anchors(self, grid_sizes, strides):
         return anchors
 
     def cached_grid_anchors(self, grid_sizes, strides):
-        # type: (List[List[int]], List[List[int]])
+        # type: (List[List[int]], List[List[int]]) -> List[Tensor]
         key = str(grid_sizes + strides)
         if key in self._cache:
             return self._cache[key]

@@ -162,7 +162,7 @@ def cached_grid_anchors(self, grid_sizes, strides):
         return anchors
 
     def forward(self, image_list, feature_maps):
-        # type: (ImageList, List[Tensor])
+        # type: (ImageList, List[Tensor]) -> List[List[Tensor]]
         grid_sizes = list([feature_map.shape[-2:] for feature_map in feature_maps])
         image_size = image_list.tensors.shape[-2:]
         strides = [[int(image_size[0] / g[0]), int(image_size[1] / g[1])] for g in grid_sizes]

@@ -206,7 +206,7 @@ def __init__(self, in_channels, num_anchors):
             torch.nn.init.constant_(l.bias, 0)
 
     def forward(self, x):
-        # type: (List[Tensor])
+        # type: (List[Tensor]) -> (List[Tensor], List[Tensor])
         logits = []
         bbox_reg = []
         for feature in x:

@@ -217,15 +217,15 @@ def forward(self, x):
 
 
 def permute_and_flatten(layer, N, A, C, H, W):
-    # type: (Tensor, int, int, int, int, int)
+    # type: (Tensor, int, int, int, int, int) -> Tensor
     layer = layer.view(N, -1, C, H, W)
     layer = layer.permute(0, 3, 4, 1, 2)
     layer = layer.reshape(N, -1, C)
     return layer
 
 
 def concat_box_prediction_layers(box_cls, box_regression):
-    # type: (List[Tensor], List[Tensor])
+    # type: (List[Tensor], List[Tensor]) -> (Tensor, Tensor)
     box_cls_flattened = []
     box_regression_flattened = []
     # for each feature level, permute the outputs to make them be in the

@@ -331,7 +331,7 @@ def post_nms_top_n(self):
         return self._post_nms_top_n['testing']
 
     def assign_targets_to_anchors(self, anchors, targets):
-        # type: (List[Tensor], List[Dict[str, Tensor]])
+        # type: (List[Tensor], List[Dict[str, Tensor]]) -> (List[Tensor], List[Tensor])
         labels = []
         matched_gt_boxes = []
         for anchors_per_image, targets_per_image in zip(anchors, targets):

@@ -367,7 +367,7 @@ def assign_targets_to_anchors(self, anchors, targets):
         return labels, matched_gt_boxes
 
     def _get_top_n_idx(self, objectness, num_anchors_per_level):
-        # type: (Tensor, List[int])
+        # type: (Tensor, List[int]) -> Tensor
         r = []
         offset = 0
         for ob in objectness.split(num_anchors_per_level, 1):

@@ -382,7 +382,7 @@ def _get_top_n_idx(self, objectness, num_anchors_per_level):
         return torch.cat(r, dim=1)
 
     def filter_proposals(self, proposals, objectness, image_shapes, num_anchors_per_level):
-        # type: (Tensor, Tensor, List[Tuple[int, int]], List[int])
+        # type: (Tensor, Tensor, List[Tuple[int, int]], List[int]) -> (List[Tensor], List[Tensor])
         num_images = proposals.shape[0]
         device = proposals.device
         # do not backprop throught objectness

@@ -422,7 +422,7 @@ def filter_proposals(self, proposals, objectness, image_shapes, num_anchors_per_
         return final_boxes, final_scores
 
     def compute_loss(self, objectness, pred_bbox_deltas, labels, regression_targets):
-        # type: (Tensor, Tensor, List[Tensor], List[Tensor])
+        # type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> (Tensor, Tensor)
         """
         Arguments:
             objectness (Tensor)

@@ -458,8 +458,12 @@ def compute_loss(self, objectness, pred_bbox_deltas, labels, regression_targets)
 
         return objectness_loss, box_loss
 
-    def forward(self, images, features, targets=None):
-        # type: (ImageList, Dict[str, Tensor], Optional[List[Dict[str, Tensor]]])
+    def forward(self,
+                images,       # type: ImageList
+                features,     # type: Dict[str, Tensor]
+                targets=None  # type: Optional[List[Dict[str, Tensor]]]
+                ):
+        # type: (...) -> (List[Tensor], Dict[str, Tensor])
         """
         Arguments:
             images (ImageList): images for which we want to compute the predictions
