Skip to content

Commit fb169a6

Browse files
committed
Make generalized_box_iou and box_iou share common code.
1 parent 059b19b commit fb169a6

File tree

1 file changed

+18
-21
lines changed

1 file changed

+18
-21
lines changed

torchvision/ops/boxes.py

Lines changed: 18 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,21 @@ def box_area(boxes: Tensor) -> Tensor:
187187

188188
# implementation from https://github.com/kuangliu/torchcv/blob/master/torchcv/utils/box.py
189189
# with slight modifications
190+
def _box_inter_union(boxes1: Tensor, boxes2: Tensor) -> Tuple[Tensor, Tensor]:
191+
area1 = box_area(boxes1)
192+
area2 = box_area(boxes2)
193+
194+
lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
195+
rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
196+
197+
wh = (rb - lt).clamp(min=0) # [N,M,2]
198+
inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]
199+
200+
union = area1[:, None] + area2 - inter
201+
202+
return inter, union
203+
204+
190205
def box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
191206
"""
192207
Return intersection-over-union (Jaccard index) of boxes.
@@ -200,16 +215,8 @@ def box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
200215
Returns:
201216
iou (Tensor[N, M]): the NxM matrix containing the pairwise IoU values for every element in boxes1 and boxes2
202217
"""
203-
area1 = box_area(boxes1)
204-
area2 = box_area(boxes2)
205-
206-
lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
207-
rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
208-
209-
wh = (rb - lt).clamp(min=0) # [N,M,2]
210-
inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]
211-
212-
iou = inter / (area1[:, None] + area2 - inter)
218+
inter, union = _box_inter_union(boxes1, boxes2)
219+
iou = inter / union
213220
return iou
214221

215222

@@ -234,17 +241,7 @@ def generalized_box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
234241
assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
235242
assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
236243

237-
area1 = box_area(boxes1)
238-
area2 = box_area(boxes2)
239-
240-
lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
241-
rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
242-
243-
wh = (rb - lt).clamp(min=0) # [N,M,2]
244-
inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]
245-
246-
union = area1[:, None] + area2 - inter
247-
244+
inter, union = _box_inter_union(boxes1, boxes2)
248245
iou = inter / union
249246

250247
lti = torch.min(boxes1[:, None, :2], boxes2[:, :2])

0 commit comments

Comments
 (0)