Skip to content

Commit 15848ed

Browse files
authored
Fix deprecation warning in nonzero (#2705)
Replace nonzero by where, now that it works with just a condition
1 parent 6a43a1f commit 15848ed

File tree

6 files changed

+16
-16
lines changed

6 files changed

+16
-16
lines changed

torchvision/models/detection/_utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,8 @@ def __call__(self, matched_idxs):
4141
pos_idx = []
4242
neg_idx = []
4343
for matched_idxs_per_image in matched_idxs:
44-
positive = torch.nonzero(matched_idxs_per_image >= 1).squeeze(1)
45-
negative = torch.nonzero(matched_idxs_per_image == 0).squeeze(1)
44+
positive = torch.where(matched_idxs_per_image >= 1)[0]
45+
negative = torch.where(matched_idxs_per_image == 0)[0]
4646

4747
num_pos = int(self.batch_size_per_image * self.positive_fraction)
4848
# protect against not enough positive examples
@@ -317,7 +317,7 @@ def set_low_quality_matches_(self, matches, all_matches, match_quality_matrix):
317317
# For each gt, find the prediction with which it has highest quality
318318
highest_quality_foreach_gt, _ = match_quality_matrix.max(dim=1)
319319
# Find highest quality match available, even if it is low, including ties
320-
gt_pred_pairs_of_highest_quality = torch.nonzero(
320+
gt_pred_pairs_of_highest_quality = torch.where(
321321
match_quality_matrix == highest_quality_foreach_gt[:, None]
322322
)
323323
# Example gt_pred_pairs_of_highest_quality:
@@ -334,7 +334,7 @@ def set_low_quality_matches_(self, matches, all_matches, match_quality_matrix):
334334
# Each row is a (gt index, prediction index)
335335
# Note how gt items 1, 2, 3, and 5 each have two ties
336336

337-
pred_inds_to_update = gt_pred_pairs_of_highest_quality[:, 1]
337+
pred_inds_to_update = gt_pred_pairs_of_highest_quality[1]
338338
matches[pred_inds_to_update] = all_matches[pred_inds_to_update]
339339

340340

torchvision/models/detection/generalized_rcnn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def forward(self, images, targets=None):
8787
degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
8888
if degenerate_boxes.any():
8989
# print the first degenerate box
90-
bb_idx = degenerate_boxes.any(dim=1).nonzero().view(-1)[0]
90+
bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0]
9191
degen_bb: List[float] = boxes[bb_idx].tolist()
9292
raise ValueError("All bounding boxes should have positive height and width."
9393
" Found invalid box {} for target at index {}."

torchvision/models/detection/roi_heads.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
3737
# get indices that correspond to the regression targets for
3838
# the corresponding ground truth labels, to be used with
3939
# advanced indexing
40-
sampled_pos_inds_subset = torch.nonzero(labels > 0).squeeze(1)
40+
sampled_pos_inds_subset = torch.where(labels > 0)[0]
4141
labels_pos = labels[sampled_pos_inds_subset]
4242
N, num_classes = class_logits.shape
4343
box_regression = box_regression.reshape(N, -1, 4)
@@ -296,7 +296,7 @@ def keypointrcnn_loss(keypoint_logits, proposals, gt_keypoints, keypoint_matched
296296

297297
keypoint_targets = torch.cat(heatmaps, dim=0)
298298
valid = torch.cat(valid, dim=0).to(dtype=torch.uint8)
299-
valid = torch.nonzero(valid).squeeze(1)
299+
valid = torch.where(valid)[0]
300300

301301
# torch.mean (in binary_cross_entropy_with_logits) does'nt
302302
# accept empty tensors, so handle it sepaartely
@@ -604,7 +604,7 @@ def subsample(self, labels):
604604
for img_idx, (pos_inds_img, neg_inds_img) in enumerate(
605605
zip(sampled_pos_inds, sampled_neg_inds)
606606
):
607-
img_sampled_inds = torch.nonzero(pos_inds_img | neg_inds_img).squeeze(1)
607+
img_sampled_inds = torch.where(pos_inds_img | neg_inds_img)[0]
608608
sampled_inds.append(img_sampled_inds)
609609
return sampled_inds
610610

@@ -700,7 +700,7 @@ def postprocess_detections(self,
700700
labels = labels.reshape(-1)
701701

702702
# remove low scoring boxes
703-
inds = torch.nonzero(scores > self.score_thresh).squeeze(1)
703+
inds = torch.where(scores > self.score_thresh)[0]
704704
boxes, scores, labels = boxes[inds], scores[inds], labels[inds]
705705

706706
# remove empty boxes
@@ -784,7 +784,7 @@ def forward(self,
784784
mask_proposals = []
785785
pos_matched_idxs = []
786786
for img_id in range(num_images):
787-
pos = torch.nonzero(labels[img_id] > 0).squeeze(1)
787+
pos = torch.where(labels[img_id] > 0)[0]
788788
mask_proposals.append(proposals[img_id][pos])
789789
pos_matched_idxs.append(matched_idxs[img_id][pos])
790790
else:
@@ -832,7 +832,7 @@ def forward(self,
832832
pos_matched_idxs = []
833833
assert matched_idxs is not None
834834
for img_id in range(num_images):
835-
pos = torch.nonzero(labels[img_id] > 0).squeeze(1)
835+
pos = torch.where(labels[img_id] > 0)[0]
836836
keypoint_proposals.append(proposals[img_id][pos])
837837
pos_matched_idxs.append(matched_idxs[img_id][pos])
838838
else:

torchvision/models/detection/rpn.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -430,8 +430,8 @@ def compute_loss(self, objectness, pred_bbox_deltas, labels, regression_targets)
430430
"""
431431

432432
sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels)
433-
sampled_pos_inds = torch.nonzero(torch.cat(sampled_pos_inds, dim=0)).squeeze(1)
434-
sampled_neg_inds = torch.nonzero(torch.cat(sampled_neg_inds, dim=0)).squeeze(1)
433+
sampled_pos_inds = torch.where(torch.cat(sampled_pos_inds, dim=0))[0]
434+
sampled_neg_inds = torch.where(torch.cat(sampled_neg_inds, dim=0))[0]
435435

436436
sampled_inds = torch.cat([sampled_pos_inds, sampled_neg_inds], dim=0)
437437

torchvision/ops/boxes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def remove_small_boxes(boxes: Tensor, min_size: float) -> Tensor:
100100
"""
101101
ws, hs = boxes[:, 2] - boxes[:, 0], boxes[:, 3] - boxes[:, 1]
102102
keep = (ws >= min_size) & (hs >= min_size)
103-
keep = keep.nonzero().squeeze(1)
103+
keep = torch.where(keep)[0]
104104
return keep
105105

106106

torchvision/ops/poolers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def _onnx_merge_levels(levels: Tensor, unmerged_results: List[Tensor]) -> Tensor
2424
first_result.size(2), first_result.size(3)),
2525
dtype=dtype, device=device)
2626
for level in range(len(unmerged_results)):
27-
index = (levels == level).nonzero().view(-1, 1, 1, 1)
27+
index = torch.where(levels == level)[0].view(-1, 1, 1, 1)
2828
index = index.expand(index.size(0),
2929
unmerged_results[level].size(1),
3030
unmerged_results[level].size(2),
@@ -234,7 +234,7 @@ def forward(
234234

235235
tracing_results = []
236236
for level, (per_level_feature, scale) in enumerate(zip(x_filtered, scales)):
237-
idx_in_level = torch.nonzero(levels == level).squeeze(1)
237+
idx_in_level = torch.where(levels == level)[0]
238238
rois_per_level = rois[idx_in_level]
239239

240240
result_idx_in_level = roi_align(

0 commit comments

Comments
 (0)