Skip to content

Commit 6e10e3f

Browse files
authored
Adds Generalized IOU (#2642)
* tries adding genaralized_iou * fixes linting * Adds docs for giou, iou and box area * fixes lint * removes docs to fixup in other PR * linter fix * Cleans comments * Adds tests for box area, iou and giou * typo fix for testCase * fixes typo * fixes box area test * fixes implementation * updates tests to tolerance
1 parent 15848ed commit 6e10e3f

File tree

4 files changed

+93
-4
lines changed

4 files changed

+93
-4
lines changed

docs/source/ops.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ torchvision.ops
1515
.. autofunction:: clip_boxes_to_image
1616
.. autofunction:: box_area
1717
.. autofunction:: box_iou
18+
.. autofunction:: generalized_box_iou
1819
.. autofunction:: roi_align
1920
.. autofunction:: ps_roi_align
2021
.. autofunction:: roi_pool

test/test_ops.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -647,5 +647,51 @@ def test_convert_boxes_to_roi_format(self):
647647
self.assertTrue(torch.equal(ref_tensor, ops._utils.convert_boxes_to_roi_format(box_sequence)))
648648

649649

650+
class BoxAreaTester(unittest.TestCase):
651+
def test_box_area(self):
652+
# A bounding box of area 10000 and a degenerate case
653+
box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0]], dtype=torch.float)
654+
expected = torch.tensor([10000, 0])
655+
calc_area = ops.box_area(box_tensor)
656+
assert calc_area.size() == torch.Size([2])
657+
assert calc_area.dtype == box_tensor.dtype
658+
assert torch.all(torch.eq(calc_area, expected)).item() is True
659+
660+
661+
class BoxIouTester(unittest.TestCase):
662+
def test_iou(self):
663+
# Boxes to test Iou
664+
boxes1 = torch.tensor([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]], dtype=torch.float)
665+
boxes2 = torch.tensor([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]], dtype=torch.float)
666+
667+
# Expected IoU matrix for these boxes
668+
expected = torch.tensor([[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]])
669+
670+
out = ops.box_iou(boxes1, boxes2)
671+
672+
# Check if all elements of tensor are as expected.
673+
assert out.size() == torch.Size([3, 3])
674+
tolerance = 1e-4
675+
assert ((out - expected).abs().max() < tolerance).item() is True
676+
677+
678+
class GenBoxIouTester(unittest.TestCase):
679+
def test_gen_iou(self):
680+
# Test Generalized IoU
681+
boxes1 = torch.tensor([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]], dtype=torch.float)
682+
boxes2 = torch.tensor([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]], dtype=torch.float)
683+
684+
# Expected gIoU matrix for these boxes
685+
expected = torch.tensor([[1.0, 0.25, -0.7778], [0.25, 1.0, -0.8611],
686+
[-0.7778, -0.8611, 1.0]])
687+
688+
out = ops.generalized_box_iou(boxes1, boxes2)
689+
690+
# Check if all elements of tensor are as expected.
691+
assert out.size() == torch.Size([3, 3])
692+
tolerance = 1e-4
693+
assert ((out - expected).abs().max() < tolerance).item() is True
694+
695+
650696
if __name__ == '__main__':
651697
unittest.main()

torchvision/ops/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from .boxes import nms, batched_nms, remove_small_boxes, clip_boxes_to_image, box_area, box_iou
1+
from .boxes import nms, batched_nms, remove_small_boxes, clip_boxes_to_image, box_area, box_iou, generalized_box_iou
22
from .new_empty_tensor import _new_empty_tensor
33
from .deform_conv import deform_conv2d, DeformConv2d
44
from .roi_align import roi_align, RoIAlign
@@ -15,7 +15,7 @@
1515

1616
__all__ = [
1717
'deform_conv2d', 'DeformConv2d', 'nms', 'batched_nms', 'remove_small_boxes',
18-
'clip_boxes_to_image', 'box_area', 'box_iou', 'roi_align', 'RoIAlign', 'roi_pool',
18+
'clip_boxes_to_image', 'box_area', 'box_iou', 'generalized_box_iou', 'roi_align', 'RoIAlign', 'roi_pool',
1919
'RoIPool', '_new_empty_tensor', 'ps_roi_align', 'PSRoIAlign', 'ps_roi_pool',
2020
'PSRoIPool', 'MultiScaleRoIAlign', 'FeaturePyramidNetwork'
2121
]

torchvision/ops/boxes.py

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -161,8 +161,7 @@ def box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
161161
boxes2 (Tensor[M, 4])
162162
163163
Returns:
164-
iou (Tensor[N, M]): the NxM matrix containing the pairwise
165-
IoU values for every element in boxes1 and boxes2
164+
iou (Tensor[N, M]): the NxM matrix containing the pairwise IoU values for every element in boxes1 and boxes2
166165
"""
167166
area1 = box_area(boxes1)
168167
area2 = box_area(boxes2)
@@ -175,3 +174,46 @@ def box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
175174

176175
iou = inter / (area1[:, None] + area2 - inter)
177176
return iou
177+
178+
179+
# Implementation adapted from https://github.com/facebookresearch/detr/blob/master/util/box_ops.py
180+
def generalized_box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
181+
"""
182+
Return generalized intersection-over-union (Jaccard index) of boxes.
183+
184+
Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
185+
186+
Arguments:
187+
boxes1 (Tensor[N, 4])
188+
boxes2 (Tensor[M, 4])
189+
190+
Returns:
191+
generalized_iou (Tensor[N, M]): the NxM matrix containing the pairwise generalized_IoU values
192+
for every element in boxes1 and boxes2
193+
"""
194+
195+
# degenerate boxes gives inf / nan results
196+
# so do an early check
197+
assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
198+
assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
199+
200+
area1 = box_area(boxes1)
201+
area2 = box_area(boxes2)
202+
203+
lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
204+
rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
205+
206+
wh = (rb - lt).clamp(min=0) # [N,M,2]
207+
inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]
208+
209+
union = area1[:, None] + area2 - inter
210+
211+
iou = inter / union
212+
213+
lti = torch.min(boxes1[:, None, :2], boxes2[:, :2])
214+
rbi = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
215+
216+
whi = (rbi - lti).clamp(min=0) # [N,M,2]
217+
areai = whi[:, :, 0] * whi[:, :, 1]
218+
219+
return iou - (areai - union) / areai

0 commit comments

Comments
 (0)