Skip to content

Commit e8d5822

Browse files
committed
Ensuring gradient propagates to RegressionHead.
1 parent 2bed114 commit e8d5822

File tree

1 file changed: +11 additions, −14 deletions

torchvision/models/detection/retinanet.py

Lines changed: 11 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -192,28 +192,25 @@ def compute_loss(self, targets, head_outputs, anchors, matched_idxs):
192192
zip(targets, bbox_regression, anchors, matched_idxs):
193193
# no matched_idxs means there were no annotations in this image
194194
if matched_idxs_per_image.numel() == 0:
195-
device = targets_per_image['boxes'].device
196-
bbox_regression_per_image = torch.zeros_like(targets_per_image['boxes'], device=device)
197-
target_regression = torch.zeros_like(targets_per_image['boxes'], device=device)
198-
num_foreground = torch.tensor(0, dtype=torch.int64, device=device)
195+
matched_gt_boxes_per_image = torch.zeros_like(bbox_regression_per_image)
199196
else:
200197
# get the targets corresponding GT for each proposal
201198
# NB: need to clamp the indices because we can have a single
202199
# GT in the image, and matched_idxs can be -2, which goes
203200
# out of bounds
204201
matched_gt_boxes_per_image = targets_per_image['boxes'][matched_idxs_per_image.clamp(min=0)]
205202

206-
# determine only the foreground indices, ignore the rest
207-
foreground_idxs_per_image = matched_idxs_per_image >= 0
208-
num_foreground = foreground_idxs_per_image.sum()
203+
# determine only the foreground indices, ignore the rest
204+
foreground_idxs_per_image = torch.where(matched_idxs_per_image >= 0)[0]
205+
num_foreground = foreground_idxs_per_image.numel()
209206

210-
# select only the foreground boxes
211-
matched_gt_boxes_per_image = matched_gt_boxes_per_image[foreground_idxs_per_image, :]
212-
bbox_regression_per_image = bbox_regression_per_image[foreground_idxs_per_image, :]
213-
anchors_per_image = anchors_per_image[foreground_idxs_per_image, :]
207+
# select only the foreground boxes
208+
matched_gt_boxes_per_image = matched_gt_boxes_per_image[foreground_idxs_per_image, :]
209+
bbox_regression_per_image = bbox_regression_per_image[foreground_idxs_per_image, :]
210+
anchors_per_image = anchors_per_image[foreground_idxs_per_image, :]
214211

215-
# compute the regression targets
216-
target_regression = self.box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image)
212+
# compute the regression targets
213+
target_regression = self.box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image)
217214

218215
# compute the loss
219216
losses.append(torch.nn.functional.l1_loss(
@@ -404,7 +401,7 @@ def compute_loss(self, targets, head_outputs, anchors):
404401
matched_idxs = []
405402
for anchors_per_image, targets_per_image in zip(anchors, targets):
406403
if targets_per_image['boxes'].numel() == 0:
407-
matched_idxs.append(torch.empty((0,), dtype=torch.int32))
404+
matched_idxs.append(torch.empty((0,), dtype=torch.int64))
408405
continue
409406

410407
match_quality_matrix = box_ops.box_iou(targets_per_image['boxes'], anchors_per_image)

0 commit comments

Comments (0)