Skip to content

Commit 53d9ef2

Browse files
committed
Fix d/c IoU for different batch sizes (pytorch#6338)
1 parent 053feed commit 53d9ef2

File tree

2 files changed

+10
-8
lines changed

2 files changed

+10
-8
lines changed

torchvision/ops/boxes.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -325,13 +325,13 @@ def complete_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tenso
325325

326326
diou, iou = _box_diou_iou(boxes1, boxes2, eps)
327327

328-
w_pred = boxes1[:, 2] - boxes1[:, 0]
329-
h_pred = boxes1[:, 3] - boxes1[:, 1]
328+
w_pred = boxes1[:, None, 2] - boxes1[:, None, 0]
329+
h_pred = boxes1[:, None, 3] - boxes1[:, None, 1]
330330

331331
w_gt = boxes2[:, 2] - boxes2[:, 0]
332332
h_gt = boxes2[:, 3] - boxes2[:, 1]
333333

334-
v = (4 / (torch.pi ** 2)) * torch.pow((torch.atan(w_gt / h_gt) - torch.atan(w_pred / h_pred)), 2)
334+
v = (4 / (torch.pi**2)) * torch.pow(torch.atan(w_pred / h_pred) - torch.atan(w_gt / h_gt), 2)
335335
with torch.no_grad():
336336
alpha = v / (1 - iou + v + eps)
337337
return diou - alpha * v
@@ -358,7 +358,7 @@ def distance_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tenso
358358

359359
boxes1 = _upcast(boxes1)
360360
boxes2 = _upcast(boxes2)
361-
diou, _ = _box_diou_iou(boxes1, boxes2)
361+
diou, _ = _box_diou_iou(boxes1, boxes2, eps=eps)
362362
return diou
363363

364364

@@ -375,7 +375,9 @@ def _box_diou_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tuple[Te
375375
x_g = (boxes2[:, 0] + boxes2[:, 2]) / 2
376376
y_g = (boxes2[:, 1] + boxes2[:, 3]) / 2
377377
# The distance between boxes' centers squared.
378-
centers_distance_squared = (_upcast(x_p - x_g) ** 2) + (_upcast(y_p - y_g) ** 2)
378+
centers_distance_squared = (_upcast((x_p[:, None] - x_g[None, :])) ** 2) + (
379+
_upcast((y_p[:, None] - y_g[None, :])) ** 2
380+
)
379381
# The distance IoU is the IoU penalized by a normalized
380382
# distance between boxes' centers squared.
381383
return iou - (centers_distance_squared / diagonal_distance_squared), iou

torchvision/ops/ciou_loss.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@ def complete_box_iou_loss(
1414

1515
"""
1616
Gradient-friendly IoU loss with an additional penalty that is non-zero when the
17-
boxes do not overlap overlap area, This loss function considers important geometrical
18-
factors such as overlap area, normalized central point distance and aspect ratio.
17+
boxes do not overlap. This loss function considers important geometrical
18+
factors such as overlap area, normalized central point distance and aspect ratio.
1919
This loss is symmetric, so the boxes1 and boxes2 arguments are interchangeable.
2020
2121
Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with
@@ -35,7 +35,7 @@ def complete_box_iou_loss(
3535
Tensor: Loss tensor with the reduction option applied.
3636
3737
Reference:
38-
Zhaohui Zheng et. al: Complete Intersection over Union Loss:
38+
Zhaohui Zheng et al.: Complete Intersection over Union Loss:
3939
https://arxiv.org/abs/1911.08287
4040
4141
"""

0 commit comments

Comments
 (0)