diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py
index 8f923475664..ba74aa5858e 100644
--- a/test/test_functional_tensor.py
+++ b/test/test_functional_tensor.py
@@ -154,6 +154,29 @@ def test_rotate_interpolation_type(self):
         res2 = F.rotate(tensor, 45, interpolation=BILINEAR)
         assert_equal(res1, res2)
 
+    @pytest.mark.parametrize("fn", [F.rotate, scripted_rotate])
+    @pytest.mark.parametrize("center", [None, torch.tensor([0.1, 0.2], requires_grad=True)])
+    def test_differentiable_rotate(self, fn, center):
+        alpha = torch.tensor(1.0, requires_grad=True)
+        x = torch.zeros(1, 3, 10, 10)
+        x[0, :, 2:5, 2:5] = 1.0
+
+        y = fn(x, alpha, interpolation=BILINEAR, center=center)
+        assert y.requires_grad
+        y.mean().backward()
+        assert alpha.grad is not None
+        if center is not None:
+            assert center.grad is not None
+
+    @pytest.mark.parametrize("center", [None, torch.tensor([0.1, 0.2], requires_grad=True)])
+    def test_differentiable_rotate_nonfloat(self, center):
+        alpha = torch.tensor(1.0, requires_grad=True)
+        x = torch.zeros(1, 3, 10, 10, dtype=torch.long)
+        x[0, :, 2:5, 2:5] = 1
+
+        with pytest.raises(ValueError, match=r"input should be float tensor"):
+            F.rotate(x, alpha, interpolation=BILINEAR, center=center)
+
 
 class TestAffine:
 
@@ -379,6 +402,37 @@ def test_warnings(self, device):
         # we convert the PIL images to numpy as assert_equal doesn't work on PIL images.
         assert_equal(np.asarray(res1), np.asarray(res2))
 
+    @pytest.mark.parametrize("fn", [F.affine, scripted_affine])
+    @pytest.mark.parametrize("translate", [[0, 0], torch.tensor([1.0, 2.0], requires_grad=True)])
+    @pytest.mark.parametrize("scale", [1.0, torch.tensor(1.0, requires_grad=True)])
+    @pytest.mark.parametrize("shear", [[1.0, 1.0], torch.tensor([1.0, 1.0], requires_grad=True)])
+    def test_differentiable_affine(self, fn, translate, scale, shear):
+        alpha = torch.tensor(1.0, requires_grad=True)
+        x = torch.zeros(1, 3, 10, 10)
+        x[0, :, 2:5, 2:5] = 1.0
+
+        y = fn(x, alpha, translate, scale, shear, interpolation=BILINEAR)
+        assert y.requires_grad
+        y.mean().backward()
+        assert alpha.grad is not None
+        if isinstance(translate, torch.Tensor):
+            assert translate.grad is not None
+        if isinstance(scale, torch.Tensor):
+            assert scale.grad is not None
+        if isinstance(shear, torch.Tensor):
+            assert shear.grad is not None
+
+    @pytest.mark.parametrize("translate", [[0, 0], torch.tensor([1.0, 2.0], requires_grad=True)])
+    @pytest.mark.parametrize("scale", [1.0, torch.tensor(1.0, requires_grad=True)])
+    @pytest.mark.parametrize("shear", [[1.0, 1.0], torch.tensor([1.0, 1.0], requires_grad=True)])
+    def test_differentiable_affine_nonfloat(self, translate, scale, shear):
+        alpha = torch.tensor(1.0, requires_grad=True)
+        x = torch.zeros(1, 3, 10, 10, dtype=torch.long)
+        x[0, :, 2:5, 2:5] = 1
+
+        with pytest.raises(ValueError, match=r"input should be float tensor"):
+            F.affine(x, alpha, translate, scale, shear, interpolation=BILINEAR)
+
 
 def _get_data_dims_and_points_for_perspective():
     # Ideally we would parametrize independently over data dims and points, but
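The new tests only assert that gradients exist. A minimal sketch of what they enable, assuming the patched F.rotate (the target angle, learning rate, and step count below are illustrative, not part of the patch):

    import torch
    import torchvision.transforms.functional as F
    from torchvision.transforms import InterpolationMode

    x = torch.zeros(1, 3, 32, 32)
    x[0, :, 8:16, 8:16] = 1.0
    # reference image rotated by the angle we want to recover
    target = F.rotate(x, 20.0, interpolation=InterpolationMode.BILINEAR)

    alpha = torch.tensor(0.0, requires_grad=True)  # learnable angle in degrees
    optimizer = torch.optim.SGD([alpha], lr=100.0)
    for _ in range(200):
        optimizer.zero_grad()
        y = F.rotate(x, alpha, interpolation=InterpolationMode.BILINEAR)
        loss = torch.nn.functional.mse_loss(y, target)
        loss.backward()
        optimizer.step()
    # alpha should drift toward 20; convergence depends on the (guessed) lr

BILINEAR interpolation matters here: nearest-neighbour sampling yields zero gradients almost everywhere, which is why the differentiability tests pin interpolation=BILINEAR.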
diff --git a/test/test_transforms.py b/test/test_transforms.py
index 512a343ee59..5e3d1cf45be 100644
--- a/test/test_transforms.py
+++ b/test/test_transforms.py
@@ -2015,7 +2015,13 @@ def _test_transformation(self, angle, translate, scale, shear, pil_image, input_
         true_matrix = np.matmul(T, np.matmul(C, np.matmul(RSS, Cinv)))
         result_matrix = self._to_3x3_inv(
-            F._get_inverse_affine_matrix(center=cnt, angle=angle, translate=translate, scale=scale, shear=shear)
+            F._get_inverse_affine_matrix_tensor(
+                center=torch.tensor(cnt, dtype=torch.float64),  # using double to match true_matrix precision
+                angle=torch.tensor(angle, dtype=torch.float64),
+                translate=torch.tensor(translate, dtype=torch.float64),
+                scale=torch.tensor(scale, dtype=torch.float64),
+                shear=torch.tensor(shear, dtype=torch.float64),
+            )
         )
         assert np.sum(np.abs(true_matrix - result_matrix)) < 1e-10
         # 2) Perform inverse mapping:
diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py
index fb07112bb83..09c7d46fa4c 100644
--- a/test/test_transforms_tensor.py
+++ b/test/test_transforms_tensor.py
@@ -464,6 +464,7 @@ def test_resized_crop_save(self, tmpdir):
 
 
 def _test_random_affine_helper(device, **kwargs):
+    torch.manual_seed(12)
    tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device)
    batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device)
    transform = T.RandomAffine(**kwargs)
@@ -482,7 +483,7 @@ def test_random_affine(device, tmpdir):
 
 @pytest.mark.parametrize("device", cpu_and_gpu())
 @pytest.mark.parametrize("interpolation", [NEAREST, BILINEAR])
-@pytest.mark.parametrize("shear", [15, 10.0, (5.0, 10.0), [-15, 15], [-10.0, 10.0, -11.0, 11.0]])
+@pytest.mark.parametrize("shear", [15, 10.0, (5.0, 11.0), [-15, 15], [-10.0, 10.0, -11.0, 11.0]])
 def test_random_affine_shear(device, interpolation, shear):
     _test_random_affine_helper(device, degrees=0.0, interpolation=interpolation, shear=shear)
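For reference, the decomposition that both the removed list-based helper and its tensor replacement below implement (this restates the comment block the patch deletes from functional.py):

\[
M = T \cdot C \cdot RSS \cdot C^{-1},
\qquad
M^{-1} = C \cdot (RSS)^{-1} \cdot C^{-1} \cdot T^{-1}
\]

where T is the translation, C recenters the rotation, and RSS is the combined rotation-scale-shear matrix. The tensor version fills in the entries of (RSS)^{-1} / s (the a, b, c, d terms) and chains the translation matrices with torch.linalg.multi_dot, so every step stays inside the autograd graph.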
diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py
index 7d7d5382291..98485a879e1 100644
--- a/torchvision/transforms/functional.py
+++ b/torchvision/transforms/functional.py
@@ -1,8 +1,7 @@
-import math
 import numbers
 import warnings
 from enum import Enum
-from typing import List, Tuple, Any, Optional
+from typing import List, Tuple, Any, Optional, Union
 
 import numpy as np
 import torch
@@ -944,63 +943,129 @@ def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
     return F_t.adjust_gamma(img, gamma, gain)
 
 
-def _get_inverse_affine_matrix(
-    center: List[float], angle: float, translate: List[float], scale: float, shear: List[float]
-) -> List[float]:
-    # Helper method to compute inverse matrix for affine transformation
-
-    # As it is explained in PIL.Image.rotate
-    # We need compute INVERSE of affine transformation matrix: M = T * C * RSS * C^-1
-    # where T is translation matrix: [1, 0, tx | 0, 1, ty | 0, 0, 1]
-    # C is translation matrix to keep center: [1, 0, cx | 0, 1, cy | 0, 0, 1]
-    # RSS is rotation with scale and shear matrix
-    # RSS(a, s, (sx, sy)) =
-    # = R(a) * S(s) * SHy(sy) * SHx(sx)
-    # = [ s*cos(a - sy)/cos(sy), s*(-cos(a - sy)*tan(x)/cos(y) - sin(a)), 0 ]
-    #   [ s*sin(a + sy)/cos(sy), s*(-sin(a - sy)*tan(x)/cos(y) + cos(a)), 0 ]
-    #   [ 0                    , 0                                      , 1 ]
-    #
-    # where R is a rotation matrix, S is a scaling matrix, and SHx and SHy are the shears:
-    # SHx(s) = [1, -tan(s)] and SHy(s) = [1      , 0]
-    #          [0, 1      ]             [-tan(s), 1]
-    #
-    # Thus, the inverse is M^-1 = C * RSS^-1 * C^-1 * T^-1
-
-    rot = math.radians(angle)
-    sx = math.radians(shear[0])
-    sy = math.radians(shear[1])
-
-    cx, cy = center
-    tx, ty = translate
+def _get_inverse_affine_matrix_tensor(
+    center: Tensor, angle: Tensor, translate: Tensor, scale: Tensor, shear: Tensor
+) -> Tensor:
+    output = torch.zeros(3, 3, dtype=angle.dtype)
+
+    rot = angle * torch.pi / 180.0
+    shear_rad = shear * torch.pi / 180.0
+
+    m_center = torch.eye(3, 3, dtype=angle.dtype)
+    m_center[:2, 2] = center
+
+    i_m_center = torch.eye(3, 3, dtype=angle.dtype)
+    i_m_center[:2, 2] = -center
+
+    i_m_translate = torch.eye(3, 3, dtype=angle.dtype)
+    i_m_translate[:2, 2] = -translate
 
     # RSS without scaling
-    a = math.cos(rot - sy) / math.cos(sy)
-    b = -math.cos(rot - sy) * math.tan(sx) / math.cos(sy) - math.sin(rot)
-    c = math.sin(rot - sy) / math.cos(sy)
-    d = -math.sin(rot - sy) * math.tan(sx) / math.cos(sy) + math.cos(rot)
+    sx, sy = shear_rad[0], shear_rad[1]
+    a = torch.cos(rot - sy) / torch.cos(sy)
+    b = torch.cos(rot - sy) * torch.tan(sx) / torch.cos(sy) + torch.sin(rot)
+    c = -torch.sin(rot - sy) / torch.cos(sy)
+    d = -torch.sin(rot - sy) * torch.tan(sx) / torch.cos(sy) + torch.cos(rot)
+
+    output[0, 0] = d
+    output[0, 1] = b
+    output[1, 0] = c
+    output[1, 1] = a
+    output = output / scale
+    output[2, 2] = 1.0
+
+    output = torch.linalg.multi_dot([m_center, output, i_m_center, i_m_translate])
+    output = output[:2, :]
+    return output
+
+
+def _from_arg_scalar_or_tensor(arg: Union[float, int, Tensor], name: str) -> Tensor:
+    if not isinstance(arg, (int, float, Tensor)):
+        raise TypeError(f"Argument {name} should be int or float or a tensor")
+
+    if isinstance(arg, Tensor):
+        return arg
+
+    if isinstance(arg, int):
+        arg = float(arg)
+
+    return torch.tensor(arg)
+
+
+def _from_arg_seq_or_tensor(arg: Union[List[float], Tuple[float, float], Tensor], name: str) -> Tensor:
+    if not isinstance(arg, (list, tuple, Tensor)):
+        raise TypeError(f"Argument {name} should be a sequence of two values or a tensor")
+
+    if isinstance(arg, Tensor):
+        if arg.numel() != 2:
+            raise ValueError(f"Tensor should have 2 values, got {arg.numel()}")
+        return arg
+
+    # https://github.com/pytorch/pytorch/issues/70240
+    # if len(arg) != 2:
+    #     raise ValueError(f"Argument {name} should be a sequence of length 2, got {len(arg)}, {arg}")
+
+    if isinstance(arg, tuple):
+        arg = list(arg)
+
+    return torch.tensor([float(arg[0]), float(arg[1])])
+
+
+def _from_arg_intseq_or_tensor(arg: Union[List[int], Tuple[int, int], Tensor], name: str) -> Tensor:
+    if not isinstance(arg, (list, tuple, Tensor)):
+        raise TypeError(f"Argument {name} should be a sequence of two values or a tensor")
+
+    if isinstance(arg, Tensor):
+        if arg.numel() != 2:
+            raise ValueError(f"Tensor should have 2 values, got {arg.numel()}")
+        return arg
+
+    # https://github.com/pytorch/pytorch/issues/70240
+    # if len(arg) != 2:
+    #     raise ValueError(f"Argument {name} should be a sequence of length 2, got {len(arg)}, {arg}")
+
+    if isinstance(arg, tuple):
+        arg = list(arg)
+
+    return torch.tensor([float(arg[0]), float(arg[1])])
+
+
+def _from_arg_scalar_or_seq_or_tensor(
+    arg: Union[int, float, List[float], Tuple[float, float], Tensor], name: str
+) -> Tensor:
+    # Assumption is that len(arg) == 2
+    if not isinstance(arg, (int, float, list, tuple, Tensor)):
+        raise TypeError(f"Argument {name} should be a single value or a sequence of two values or a tensor")
+
+    if isinstance(arg, Tensor):
+        if arg.numel() != 2:
+            raise ValueError(f"Tensor should have 2 values, got {arg.numel()}")
+        return arg
+
+    if isinstance(arg, (int, float)):
+        return torch.tensor([float(arg), 0.0])
+
+    if isinstance(arg, tuple):
+        arg = list(arg)
 
-    # Inverted rotation matrix with scale and shear
-    # det([[a, b], [c, d]]) == 1, since det(rotation) = 1 and det(shear) = 1
-    matrix = [d, -b, 0.0, -c, a, 0.0]
-    matrix = [x / scale for x in matrix]
+    if isinstance(arg, list) and len(arg) == 1:
+        return torch.tensor([float(arg[0]), float(arg[0])])
 
-    # Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1
-    matrix[2] += matrix[0] * (-cx - tx) + matrix[1] * (-cy - ty)
-    matrix[5] += matrix[3] * (-cx - tx) + matrix[4] * (-cy - ty)
+    return _from_arg_seq_or_tensor(arg, name)
 
-    # Apply center translation: C * RSS^-1 * C^-1 * T^-1
-    matrix[2] += cx
-    matrix[5] += cy
-
-    return matrix
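A rough illustration of the contract the coercion helpers above establish (they are private, so importing one as below is for demonstration only): plain numbers become constant tensors, while tensors pass through untouched, which is how a requires_grad flag survives into the matrix computation.

    import torch
    from torchvision.transforms.functional import _from_arg_scalar_or_tensor

    a = _from_arg_scalar_or_tensor(45, "angle")
    print(a, a.requires_grad)  # tensor(45.) False -- a constant leaf

    b = torch.tensor(45.0, requires_grad=True)
    print(_from_arg_scalar_or_tensor(b, "angle") is b)  # True -- passed through, grad intact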
+def _ensure_float_input_for_learnable_params(inpt: Tensor, params: List[Tensor]) -> None:
+    if any([isinstance(p, torch.Tensor) and p.requires_grad for p in params]):
+        if not inpt.is_floating_point():
+            raise ValueError("If any parameter is a tensor that requires grad, then input should be float tensor")
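The guard above exists because the tensor code paths cast a non-float image to float for grid_sample and back to its original integer dtype afterwards, and a cast to an integer dtype silently detaches the autograd graph; the guard presumably turns those silently dead gradients into a loud error. A rough illustration of the underlying PyTorch behaviour (not part of the patch):

    import torch

    theta = torch.tensor(1.0, requires_grad=True)
    y = (theta * torch.ones(2, 2)).to(torch.long)  # cast back to an integer dtype
    print(y.requires_grad)  # False -- gradients cannot flow out of an integer tensor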
 
 
 def rotate(
     img: Tensor,
-    angle: float,
+    angle: Union[float, int, Tensor],
     interpolation: InterpolationMode = InterpolationMode.NEAREST,
     expand: bool = False,
-    center: Optional[List[int]] = None,
+    center: Optional[Union[List[int], Tuple[int, int], Tensor]] = None,
     fill: Optional[List[float]] = None,
     resample: Optional[int] = None,
 ) -> Tensor:
@@ -1010,7 +1075,7 @@ def rotate(
 
     Args:
         img (PIL Image or Tensor): image to be rotated.
-        angle (number): rotation angle value in degrees, counter-clockwise.
+        angle (number or Tensor): rotation angle value in degrees, counter-clockwise.
         interpolation (InterpolationMode): Desired interpolation enum defined by
             :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
             If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
@@ -1019,7 +1084,7 @@ def rotate(
             If true, expands the output image to make it large enough to hold the entire rotated image.
             If false or omitted, make the output image the same size as the input image.
             Note that the expand flag assumes rotation around the center and no translation.
-        center (sequence, optional): Optional center of rotation. Origin is the upper left corner.
+        center (sequence or Tensor, optional): Optional center of rotation. Origin is the upper left corner.
             Default is the center of the image.
         fill (sequence or number, optional): Pixel fill value for the area outside the transformed
             image. If given a number, the value is used for all bands respectively.
@@ -1049,38 +1114,51 @@ def rotate(
             "Please, use InterpolationMode enum."
         )
         interpolation = _interpolation_modes_from_int(interpolation)
 
-    if not isinstance(angle, (int, float)):
-        raise TypeError("Argument angle should be int or float")
-
-    if center is not None and not isinstance(center, (list, tuple)):
-        raise TypeError("Argument center should be a sequence")
-
     if not isinstance(interpolation, InterpolationMode):
         raise TypeError("Argument interpolation should be a InterpolationMode")
 
+    angle = _from_arg_scalar_or_tensor(angle, "angle")
+
     if not isinstance(img, torch.Tensor):
+        if center is not None and not isinstance(center, (list, tuple)):
+            raise TypeError("Argument center should be a sequence")
+
         pil_interpolation = pil_modes_mapping[interpolation]
-        return F_pil.rotate(img, angle=angle, interpolation=pil_interpolation, expand=expand, center=center, fill=fill)
+        return F_pil.rotate(
+            img, angle=angle.item(), interpolation=pil_interpolation, expand=expand, center=center, fill=fill
+        )
 
-    center_f = [0.0, 0.0]
-    if center is not None:
-        img_size = get_image_size(img)
+    if not torch.jit.is_scripting():
+        # torch.jit.script crashes with a segmentation fault (core dumped) on the call below
+        # unless it is guarded by torch.jit.is_scripting()
+        _ensure_float_input_for_learnable_params(img, [angle, center])
+
+    do_recenter = True
+    if center is None:
+        center = torch.tensor([0.0, 0.0])
+        do_recenter = False
+    else:
+        center = _from_arg_intseq_or_tensor(center, "center")
+
+    if do_recenter:
+        img_size = torch.tensor(get_image_size(img), dtype=torch.float)
         # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
-        center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, img_size)]
+        center = center - img_size * 0.5
 
     # due to current incoherence of rotation angle direction between affine and rotate implementations
     # we need to set -angle.
-    matrix = _get_inverse_affine_matrix(center_f, -angle, [0.0, 0.0], 1.0, [0.0, 0.0])
+    matrix = _get_inverse_affine_matrix_tensor(
+        center, -angle, torch.tensor([0.0, 0.0]), torch.tensor(1.0), torch.tensor([0.0, 0.0])
+    )
     return F_t.rotate(img, matrix=matrix, interpolation=interpolation.value, expand=expand, fill=fill)
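The recentering in rotate reproduces the old center_f list arithmetic in tensor form: user-supplied centers are pixel coordinates with the origin at the upper left, while the affine grid expects offsets from the image center. A quick numeric check (sizes illustrative):

    import torch

    img_size = torch.tensor([10.0, 10.0])  # (width, height), as from get_image_size
    center = torch.tensor([5.0, 5.0])      # geometric center in pixel coordinates
    print(center - img_size * 0.5)         # tensor([0., 0.]) -- no offset
    print(torch.tensor([0.0, 2.0]) - img_size * 0.5)  # tensor([-5., -3.]) for an off-center pivot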
 
 
 def affine(
     img: Tensor,
-    angle: float,
-    translate: List[int],
-    scale: float,
-    shear: List[float],
+    angle: Union[float, int, Tensor],
+    translate: Union[List[int], Tuple[int, int], Tensor],
+    scale: Union[float, Tensor],
+    shear: Union[List[float], Tuple[float, float], Tensor],
     interpolation: InterpolationMode = InterpolationMode.NEAREST,
     fill: Optional[List[float]] = None,
     resample: Optional[int] = None,
@@ -1092,10 +1170,10 @@ def affine(
 
     Args:
         img (PIL Image or Tensor): image to transform.
-        angle (number): rotation angle in degrees between -180 and 180, clockwise direction.
-        translate (sequence of integers): horizontal and vertical translations (post-rotation translation)
-        scale (float): overall scale
-        shear (float or sequence): shear angle value in degrees between -180 to 180, clockwise direction.
+        angle (number or Tensor): rotation angle in degrees between -180 and 180, clockwise direction.
+        translate (sequence of integers or Tensor): horizontal and vertical translations (post-rotation translation)
+        scale (float or Tensor): overall scale
+        shear (float or sequence or Tensor): shear angle value in degrees between -180 to 180, clockwise direction.
             If a sequence is specified, the first value corresponds to a shear parallel to the x axis, while
             the second value corresponds to a shear parallel to the y axis.
         interpolation (InterpolationMode): Desired interpolation enum defined by
@@ -1132,58 +1210,37 @@ def affine(
         )
         interpolation = _interpolation_modes_from_int(interpolation)
 
+    if not isinstance(interpolation, InterpolationMode):
+        raise TypeError("Argument interpolation should be a InterpolationMode")
+
     if fillcolor is not None:
         warnings.warn("Argument fillcolor is deprecated and will be removed since v0.10.0. Please, use fill instead")
         fill = fillcolor
 
-    if not isinstance(angle, (int, float)):
-        raise TypeError("Argument angle should be int or float")
-
-    if not isinstance(translate, (list, tuple)):
-        raise TypeError("Argument translate should be a sequence")
-
-    if len(translate) != 2:
-        raise ValueError("Argument translate should be a sequence of length 2")
-
+    angle = _from_arg_scalar_or_tensor(angle, "angle")
+    translate = _from_arg_intseq_or_tensor(translate, "translate")
+    scale = _from_arg_scalar_or_tensor(scale, "scale")
     if scale <= 0.0:
         raise ValueError("Argument scale should be positive")
+    shear = _from_arg_scalar_or_seq_or_tensor(shear, "shear")
 
-    if not isinstance(shear, (numbers.Number, (list, tuple))):
-        raise TypeError("Shear should be either a single value or a sequence of two values")
-
-    if not isinstance(interpolation, InterpolationMode):
-        raise TypeError("Argument interpolation should be a InterpolationMode")
-
-    if isinstance(angle, int):
-        angle = float(angle)
-
-    if isinstance(translate, tuple):
-        translate = list(translate)
-
-    if isinstance(shear, numbers.Number):
-        shear = [shear, 0.0]
-
-    if isinstance(shear, tuple):
-        shear = list(shear)
-
-    if len(shear) == 1:
-        shear = [shear[0], shear[0]]
-
-    if len(shear) != 2:
-        raise ValueError(f"Shear should be a sequence containing two values. Got {shear}")
-
-    img_size = get_image_size(img)
     if not isinstance(img, torch.Tensor):
+        img_size = get_image_size(img)
         # center = (img_size[0] * 0.5 + 0.5, img_size[1] * 0.5 + 0.5)
         # it is visually better to estimate the center without 0.5 offset
         # otherwise image rotated by 90 degrees is shifted vs output image of torch.rot90 or F_t.affine
-        center = [img_size[0] * 0.5, img_size[1] * 0.5]
-        matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear)
+        center = torch.tensor([img_size[0] * 0.5, img_size[1] * 0.5])
+        matrix = _get_inverse_affine_matrix_tensor(center, angle, translate, scale, shear)
         pil_interpolation = pil_modes_mapping[interpolation]
-        return F_pil.affine(img, matrix=matrix, interpolation=pil_interpolation, fill=fill)
+        return F_pil.affine(img, matrix=matrix.view(-1).tolist(), interpolation=pil_interpolation, fill=fill)
+
+    if not torch.jit.is_scripting():
+        # torch.jit.script crashes with a segmentation fault (core dumped) on the call below
+        # unless it is guarded by torch.jit.is_scripting()
+        _ensure_float_input_for_learnable_params(img, [angle, translate, scale, shear])
 
-    translate_f = [1.0 * t for t in translate]
-    matrix = _get_inverse_affine_matrix([0.0, 0.0], angle, translate_f, scale, shear)
+    center = torch.tensor([0.0, 0.0])
+    matrix = _get_inverse_affine_matrix_tensor(center, angle, translate, scale, shear)
     return F_t.affine(img, matrix=matrix, interpolation=interpolation.value, fill=fill)
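Putting the pieces together, a minimal sketch of a fully learnable affine transform, assuming the patched F.affine (parameter values are illustrative):

    import torch
    import torchvision.transforms.functional as F
    from torchvision.transforms import InterpolationMode

    angle = torch.tensor(0.0, requires_grad=True)
    translate = torch.tensor([0.0, 0.0], requires_grad=True)
    scale = torch.tensor(1.0, requires_grad=True)
    shear = torch.tensor([0.0, 0.0], requires_grad=True)
    params = [angle, translate, scale, shear]

    x = torch.rand(1, 3, 32, 32)
    y = F.affine(x, angle, translate, scale, shear, interpolation=InterpolationMode.BILINEAR)
    y.mean().backward()
    print([p.grad is not None for p in params])  # [True, True, True, True]

All four parameters could equally be registered on an nn.Module and trained with any optimizer.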
diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py
index 4e20c19e45f..ee4a5b0c381 100644
--- a/torchvision/transforms/functional_tensor.py
+++ b/torchvision/transforms/functional_tensor.py
@@ -568,7 +568,7 @@ def resize(
 
 def _assert_grid_transform_inputs(
     img: Tensor,
-    matrix: Optional[List[float]],
+    matrix: Optional[Tensor],
     interpolation: str,
     fill: Optional[List[float]],
     supported_interpolation_modes: List[str],
@@ -580,11 +580,11 @@ def _assert_grid_transform_inputs(
 
     _assert_image_tensor(img)
 
-    if matrix is not None and not isinstance(matrix, list):
-        raise TypeError("Argument matrix should be a list")
+    if matrix is not None and not isinstance(matrix, Tensor):
+        raise TypeError("Argument matrix should be a Tensor")
 
-    if matrix is not None and len(matrix) != 6:
-        raise ValueError("Argument matrix should have 6 float values")
+    if matrix is not None and list(matrix.shape) != [2, 3]:
+        raise ValueError("Argument matrix should have shape [2, 3]")
 
     if coeffs is not None and len(coeffs) != 8:
         raise ValueError("Argument coeffs should have 8 float values")
@@ -697,20 +697,18 @@ def _gen_affine_grid(
     return output_grid.view(1, oh, ow, 2)
 
 
-def affine(
-    img: Tensor, matrix: List[float], interpolation: str = "nearest", fill: Optional[List[float]] = None
-) -> Tensor:
+def affine(img: Tensor, matrix: Tensor, interpolation: str = "nearest", fill: Optional[List[float]] = None) -> Tensor:
     _assert_grid_transform_inputs(img, matrix, interpolation, fill, ["nearest", "bilinear"])
-
+    matrix = matrix.unsqueeze(0)
     dtype = img.dtype if torch.is_floating_point(img) else torch.float32
-    theta = torch.tensor(matrix, dtype=dtype, device=img.device).reshape(1, 2, 3)
+    theta = matrix.to(dtype=dtype, device=img.device)
     shape = img.shape
     # grid will be generated on the same device as theta and img
     grid = _gen_affine_grid(theta, w=shape[-1], h=shape[-2], ow=shape[-1], oh=shape[-2])
     return _apply_grid_transform(img, grid, interpolation, fill=fill)
 
 
-def _compute_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int]:
+def _compute_output_size(theta: Tensor, w: int, h: int) -> Tuple[int, int]:
 
     # Inspired of PIL implementation:
     # https://github.com/python-pillow/Pillow/blob/11de3318867e4398057373ee9f12dcb33db7335c/src/PIL/Image.py#L2054
@@ -724,7 +722,6 @@ def _compute_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int]
             [0.5 * w, -0.5 * h, 1.0],
         ]
     )
-    theta = torch.tensor(matrix, dtype=torch.float).reshape(1, 2, 3)
     new_pts = pts.view(1, 4, 3).bmm(theta.transpose(1, 2)).view(4, 2)
     min_vals, _ = new_pts.min(dim=0)
     max_vals, _ = new_pts.max(dim=0)
@@ -739,16 +736,17 @@ def _compute_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int]
 
 def rotate(
     img: Tensor,
-    matrix: List[float],
+    matrix: Tensor,
     interpolation: str = "nearest",
     expand: bool = False,
     fill: Optional[List[float]] = None,
 ) -> Tensor:
     _assert_grid_transform_inputs(img, matrix, interpolation, fill, ["nearest", "bilinear"])
+    matrix = matrix.unsqueeze(0)
     w, h = img.shape[-1], img.shape[-2]
-    ow, oh = _compute_output_size(matrix, w, h) if expand else (w, h)
+    ow, oh = _compute_output_size(matrix.detach(), w, h) if expand else (w, h)
     dtype = img.dtype if torch.is_floating_point(img) else torch.float32
-    theta = torch.tensor(matrix, dtype=dtype, device=img.device).reshape(1, 2, 3)
+    theta = matrix.to(dtype=dtype, device=img.device)
     # grid will be generated on the same device as theta and img
     grid = _gen_affine_grid(theta, w=w, h=h, ow=ow, oh=oh)
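One subtlety in rotate above: with expand=True the output canvas size comes from _compute_output_size, whose result is converted to plain Python ints and therefore cannot carry gradients in any case; detaching the matrix first keeps that dead branch out of the autograd graph. A condensed restatement of the size computation to show the pattern (not the patch's exact code):

    import torch

    theta = torch.rand(1, 2, 3, requires_grad=True)
    w, h = 56.0, 44.0
    pts = torch.tensor(
        [
            [-0.5 * w, -0.5 * h, 1.0],
            [-0.5 * w, 0.5 * h, 1.0],
            [0.5 * w, 0.5 * h, 1.0],
            [0.5 * w, -0.5 * h, 1.0],
        ]
    )
    new_pts = pts.view(1, 4, 3).bmm(theta.detach().transpose(1, 2)).view(4, 2)
    size = new_pts.max(dim=0).values - new_pts.min(dim=0).values
    ow, oh = int(size[0].item()), int(size[1].item())  # plain ints: the graph ends here regardless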