Skip to content

Commit a26534c

Browse files
authored
Fixed rotate with expand inconsistency (#5677)
* Fixed rotate with expand inconsistency between torch and PIL on odd-sized images
* Update functional_tensor.py
1 parent 71907be commit a26534c

File tree

2 files changed

+10
-4
lines changed

2 files changed

+10
-4
lines changed

test/test_functional_tensor.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -67,7 +67,7 @@ class TestRotate:
6767
IMG_W = 26
6868

6969
@pytest.mark.parametrize("device", cpu_and_gpu())
70-
@pytest.mark.parametrize("height, width", [(26, IMG_W), (32, IMG_W)])
70+
@pytest.mark.parametrize("height, width", [(7, 33), (26, IMG_W), (32, IMG_W)])
7171
@pytest.mark.parametrize(
7272
"center",
7373
[
@@ -77,7 +77,7 @@ class TestRotate:
7777
],
7878
)
7979
@pytest.mark.parametrize("dt", ALL_DTYPES)
80-
@pytest.mark.parametrize("angle", range(-180, 180, 17))
80+
@pytest.mark.parametrize("angle", range(-180, 180, 34))
8181
@pytest.mark.parametrize("expand", [True, False])
8282
@pytest.mark.parametrize(
8383
"fill",

torchvision/transforms/functional_tensor.py

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -650,6 +650,8 @@ def _compute_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int]
650650
# https://github.com/python-pillow/Pillow/blob/11de3318867e4398057373ee9f12dcb33db7335c/src/PIL/Image.py#L2054
651651

652652
# pts are Top-Left, Top-Right, Bottom-Left, Bottom-Right points.
653+
# Points are expressed relative to the image center, following the torch affine
654+
# matrix convention: the origin (0, 0) is the image center pivot point (w * 0.5, h * 0.5).
653655
pts = torch.tensor(
654656
[
655657
[-0.5 * w, -0.5 * h, 1.0],
@@ -658,11 +660,15 @@ def _compute_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int]
658660
[0.5 * w, -0.5 * h, 1.0],
659661
]
660662
)
661-
theta = torch.tensor(matrix, dtype=torch.float).reshape(1, 2, 3)
662-
new_pts = pts.view(1, 4, 3).bmm(theta.transpose(1, 2)).view(4, 2)
663+
theta = torch.tensor(matrix, dtype=torch.float).view(2, 3)
664+
new_pts = torch.matmul(pts, theta.T)
663665
min_vals, _ = new_pts.min(dim=0)
664666
max_vals, _ = new_pts.max(dim=0)
665667

668+
# shift points to the [0, w] and [0, h] intervals to match PIL results
669+
min_vals += torch.tensor((w * 0.5, h * 0.5))
670+
max_vals += torch.tensor((w * 0.5, h * 0.5))
671+
666672
# Truncate precision to 1e-4 to avoid ceil of Xe-15 to 1.0
667673
tol = 1e-4
668674
cmax = torch.ceil((max_vals / tol).trunc_() * tol)

0 commit comments

Comments
 (0)