Skip to content

Commit ac7ad5f

Browse files
Fix rotated box format conversion from XYXYXYXY to XYWHR (#9019)
1 parent d84aa89 commit ac7ad5f

File tree

3 files changed

+17
-10
lines changed

3 files changed

+17
-10
lines changed

test/test_ops.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1384,6 +1384,10 @@ def test_bbox_cxcywhr_to_xyxyxyxy(self):
13841384
box_xyxyxyxy = ops.box_convert(box_tensor, in_fmt="cxcywhr", out_fmt="xyxyxyxy")
13851385
torch.testing.assert_close(box_xyxyxyxy, exp_xyxyxyxy)
13861386

1387+
# Reverse conversion
1388+
box_cxcywhr = ops.box_convert(box_xyxyxyxy, in_fmt="xyxyxyxy", out_fmt="cxcywhr")
1389+
torch.testing.assert_close(box_cxcywhr, box_tensor)
1390+
13871391
def test_bbox_xywhr_to_xyxyxyxy(self):
13881392
box_tensor = torch.tensor([[4, 5, 4, 2, 90]], dtype=torch.float)
13891393
exp_xyxyxyxy = torch.tensor([[4, 5, 4, 1, 6, 1, 6, 5]], dtype=torch.float)
@@ -1392,6 +1396,10 @@ def test_bbox_xywhr_to_xyxyxyxy(self):
13921396
box_xyxyxyxy = ops.box_convert(box_tensor, in_fmt="xywhr", out_fmt="xyxyxyxy")
13931397
torch.testing.assert_close(box_xyxyxyxy, exp_xyxyxyxy)
13941398

1399+
# Reverse conversion
1400+
box_xywhr = ops.box_convert(box_xyxyxyxy, in_fmt="xyxyxyxy", out_fmt="xywhr")
1401+
torch.testing.assert_close(box_xywhr, box_tensor)
1402+
13951403
@pytest.mark.parametrize("inv_infmt", ["xwyh", "cxwyh", "xwyhr", "cxwyhr", "xxxxyyyy"])
13961404
@pytest.mark.parametrize("inv_outfmt", ["xwcx", "xhwcy", "xwcxr", "xhwcyr", "xyxyxxyy"])
13971405
def test_bbox_invalid(self, inv_infmt, inv_outfmt):

torchvision/ops/_box_convert.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -178,10 +178,9 @@ def _box_xyxyxyxy_to_xywhr(boxes: Tensor) -> Tensor:
178178
x1, y1, x3, y3, x2, y2, x4, y4 = boxes.unbind(-1)
179179
r_rad = torch.atan2(y1 - y3, x3 - x1)
180180
r = r_rad * 180 / torch.pi
181-
cos, sin = torch.cos(r_rad), torch.sin(r_rad)
182181

183-
w = (x2 - x1) * cos + (y1 - y2) * sin
184-
h = (x2 - x1) * sin + (y2 - y1) * cos
182+
w = ((x3 - x1) ** 2 + (y1 - y3) ** 2).sqrt()
183+
h = ((x3 - x2) ** 2 + (y3 - y2) ** 2).sqrt()
185184

186185
boxes = torch.stack((x1, y1, w, h, r), dim=-1)
187186

torchvision/transforms/v2/functional/_meta.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -252,13 +252,13 @@ def _xyxyxyxy_to_xywhr(xyxyxyxy: torch.Tensor, inplace: bool) -> torch.Tensor:
252252
xyxyxyxy = xyxyxyxy.float()
253253

254254
r_rad = torch.atan2(xyxyxyxy[..., 1].sub(xyxyxyxy[..., 3]), xyxyxyxy[..., 2].sub(xyxyxyxy[..., 0]))
255-
cos, sin = r_rad.cos(), r_rad.sin()
256-
# x1, y1, x3, y3, (x2 - x1), (y2 - y1) x4, y4
257-
xyxyxyxy[..., 4:6].sub_(xyxyxyxy[..., :2])
258-
# (x2 - x1) * cos + (y1 - y2) * sin = w
259-
xyxyxyxy[..., 2] = xyxyxyxy[..., 4].mul(cos).sub(xyxyxyxy[..., 5].mul(sin))
260-
# (x2 - x1) * sin + (y2 - y1) * cos = h
261-
xyxyxyxy[..., 3] = xyxyxyxy[..., 5].mul(cos).add(xyxyxyxy[..., 4].mul(sin))
255+
# x1, y1, (x3 - x1), (y3 - y1), (x2 - x3), (y2 - y3) x4, y4
256+
xyxyxyxy[..., 4:6].sub_(xyxyxyxy[..., 2:4])
257+
xyxyxyxy[..., 2:4].sub_(xyxyxyxy[..., :2])
258+
# sqrt((x3 - x1) ** 2 + (y1 - y3) ** 2) = w
259+
xyxyxyxy[..., 2] = xyxyxyxy[..., 2].pow(2).add(xyxyxyxy[..., 3].pow(2)).sqrt()
260+
# sqrt((x3 - x2) ** 2 + (y3 - y2) ** 2) = h
261+
xyxyxyxy[..., 3] = xyxyxyxy[..., 4].pow(2).add(xyxyxyxy[..., 5].pow(2)).sqrt()
262262
xyxyxyxy[..., 4] = r_rad.div_(torch.pi).mul_(180.0)
263263
return xyxyxyxy[..., :5].to(dtype)
264264

0 commit comments

Comments
 (0)