diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py index 3bdf0cfe34e..f05112ee498 100644 --- a/test/test_functional_tensor.py +++ b/test/test_functional_tensor.py @@ -67,7 +67,7 @@ class TestRotate: IMG_W = 26 @pytest.mark.parametrize("device", cpu_and_gpu()) - @pytest.mark.parametrize("height, width", [(26, IMG_W), (32, IMG_W)]) + @pytest.mark.parametrize("height, width", [(7, 33), (26, IMG_W), (32, IMG_W)]) @pytest.mark.parametrize( "center", [ @@ -77,7 +77,7 @@ class TestRotate: ], ) @pytest.mark.parametrize("dt", ALL_DTYPES) - @pytest.mark.parametrize("angle", range(-180, 180, 17)) + @pytest.mark.parametrize("angle", range(-180, 180, 34)) @pytest.mark.parametrize("expand", [True, False]) @pytest.mark.parametrize( "fill", diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index 6bcd1ea85da..da7acef3e7b 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -650,6 +650,8 @@ def _compute_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int] # https://github.com/python-pillow/Pillow/blob/11de3318867e4398057373ee9f12dcb33db7335c/src/PIL/Image.py#L2054 # pts are Top-Left, Top-Right, Bottom-Left, Bottom-Right points. + # Points are shifted due to affine matrix torch convention about + # the center point. Center is (0, 0) for image center pivot point (w * 0.5, h * 0.5) pts = torch.tensor( [ [-0.5 * w, -0.5 * h, 1.0], @@ -658,11 +660,15 @@ def _compute_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int] [0.5 * w, -0.5 * h, 1.0], ] ) - theta = torch.tensor(matrix, dtype=torch.float).reshape(1, 2, 3) - new_pts = pts.view(1, 4, 3).bmm(theta.transpose(1, 2)).view(4, 2) + theta = torch.tensor(matrix, dtype=torch.float).view(2, 3) + new_pts = torch.matmul(pts, theta.T) min_vals, _ = new_pts.min(dim=0) max_vals, _ = new_pts.max(dim=0) + # shift points to [0, w] and [0, h] interval to match PIL results + min_vals += torch.tensor((w * 0.5, h * 0.5)) + max_vals += torch.tensor((w * 0.5, h * 0.5)) + # Truncate precision to 1e-4 to avoid ceil of Xe-15 to 1.0 tol = 1e-4 cmax = torch.ceil((max_vals / tol).trunc_() * tol)