We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent c62ec9d commit 6a46040Copy full SHA for 6a46040
torchvision/models/detection/roi_heads.py
@@ -40,7 +40,7 @@ def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
40
sampled_pos_inds_subset = torch.where(labels > 0)[0]
41
labels_pos = labels[sampled_pos_inds_subset]
42
N, num_classes = class_logits.shape
43
- box_regression = box_regression.reshape(N, -1, 4)
+ box_regression = box_regression.reshape(N, box_regression.size(-1) // 4, 4)
44
45
box_loss = det_utils.smooth_l1_loss(
46
box_regression[sampled_pos_inds_subset, labels_pos],
0 commit comments