@@ -187,6 +187,21 @@ def box_area(boxes: Tensor) -> Tensor:
187
187
188
188
# implementation from https://github.com/kuangliu/torchcv/blob/master/torchcv/utils/box.py
189
189
# with slight modifications
190
+ def _box_inter_union (boxes1 : Tensor , boxes2 : Tensor ) -> Tuple [Tensor , Tensor ]:
191
+ area1 = box_area (boxes1 )
192
+ area2 = box_area (boxes2 )
193
+
194
+ lt = torch .max (boxes1 [:, None , :2 ], boxes2 [:, :2 ]) # [N,M,2]
195
+ rb = torch .min (boxes1 [:, None , 2 :], boxes2 [:, 2 :]) # [N,M,2]
196
+
197
+ wh = (rb - lt ).clamp (min = 0 ) # [N,M,2]
198
+ inter = wh [:, :, 0 ] * wh [:, :, 1 ] # [N,M]
199
+
200
+ union = area1 [:, None ] + area2 - inter
201
+
202
+ return inter , union
203
+
204
+
190
205
def box_iou (boxes1 : Tensor , boxes2 : Tensor ) -> Tensor :
191
206
"""
192
207
Return intersection-over-union (Jaccard index) of boxes.
@@ -200,16 +215,8 @@ def box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
200
215
Returns:
201
216
iou (Tensor[N, M]): the NxM matrix containing the pairwise IoU values for every element in boxes1 and boxes2
202
217
"""
203
- area1 = box_area (boxes1 )
204
- area2 = box_area (boxes2 )
205
-
206
- lt = torch .max (boxes1 [:, None , :2 ], boxes2 [:, :2 ]) # [N,M,2]
207
- rb = torch .min (boxes1 [:, None , 2 :], boxes2 [:, 2 :]) # [N,M,2]
208
-
209
- wh = (rb - lt ).clamp (min = 0 ) # [N,M,2]
210
- inter = wh [:, :, 0 ] * wh [:, :, 1 ] # [N,M]
211
-
212
- iou = inter / (area1 [:, None ] + area2 - inter )
218
+ inter , union = _box_inter_union (boxes1 , boxes2 )
219
+ iou = inter / union
213
220
return iou
214
221
215
222
@@ -234,17 +241,7 @@ def generalized_box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
234
241
assert (boxes1 [:, 2 :] >= boxes1 [:, :2 ]).all ()
235
242
assert (boxes2 [:, 2 :] >= boxes2 [:, :2 ]).all ()
236
243
237
- area1 = box_area (boxes1 )
238
- area2 = box_area (boxes2 )
239
-
240
- lt = torch .max (boxes1 [:, None , :2 ], boxes2 [:, :2 ]) # [N,M,2]
241
- rb = torch .min (boxes1 [:, None , 2 :], boxes2 [:, 2 :]) # [N,M,2]
242
-
243
- wh = (rb - lt ).clamp (min = 0 ) # [N,M,2]
244
- inter = wh [:, :, 0 ] * wh [:, :, 1 ] # [N,M]
245
-
246
- union = area1 [:, None ] + area2 - inter
247
-
244
+ inter , union = _box_inter_union (boxes1 , boxes2 )
248
245
iou = inter / union
249
246
250
247
lti = torch .min (boxes1 [:, None , :2 ], boxes2 [:, :2 ])
0 commit comments