@@ -650,6 +650,8 @@ def _compute_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int]
650
650
# https://github.com/python-pillow/Pillow/blob/11de3318867e4398057373ee9f12dcb33db7335c/src/PIL/Image.py#L2054
651
651
652
652
# pts are Top-Left, Top-Right, Bottom-Left, Bottom-Right points.
653
+ # Points are shifted due to affine matrix torch convention about
654
+ # the center point. Center is (0, 0) for image center pivot point (w * 0.5, h * 0.5)
653
655
pts = torch .tensor (
654
656
[
655
657
[- 0.5 * w , - 0.5 * h , 1.0 ],
@@ -658,11 +660,15 @@ def _compute_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int]
658
660
[0.5 * w , - 0.5 * h , 1.0 ],
659
661
]
660
662
)
661
- theta = torch .tensor (matrix , dtype = torch .float ).reshape ( 1 , 2 , 3 )
662
- new_pts = pts . view ( 1 , 4 , 3 ). bmm ( theta .transpose ( 1 , 2 )). view ( 4 , 2 )
663
+ theta = torch .tensor (matrix , dtype = torch .float ).view ( 2 , 3 )
664
+ new_pts = torch . matmul ( pts , theta .T )
663
665
min_vals , _ = new_pts .min (dim = 0 )
664
666
max_vals , _ = new_pts .max (dim = 0 )
665
667
668
+ # shift points to [0, w] and [0, h] interval to match PIL results
669
+ min_vals += torch .tensor ((w * 0.5 , h * 0.5 ))
670
+ max_vals += torch .tensor ((w * 0.5 , h * 0.5 ))
671
+
666
672
# Truncate precision to 1e-4 to avoid ceil of Xe-15 to 1.0
667
673
tol = 1e-4
668
674
cmax = torch .ceil ((max_vals / tol ).trunc_ () * tol )
0 commit comments