diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py index 376021f0974..d067559f26a 100644 --- a/torchvision/prototype/transforms/functional/_color.py +++ b/torchvision/prototype/transforms/functional/_color.py @@ -388,7 +388,7 @@ def autocontrast_image_tensor(image: torch.Tensor) -> torch.Tensor: maximum = float_image.amax(dim=(-2, -1), keepdim=True) eq_idxs = maximum == minimum - inv_scale = maximum.sub_(minimum).div_(bound) + inv_scale = maximum.sub_(minimum).mul_(1.0 / bound) minimum[eq_idxs] = 0.0 inv_scale[eq_idxs] = 1.0 diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index 909b43739ee..ce97ce0575d 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -390,7 +390,7 @@ def _affine_bounding_box_xyxy( device=device, ) new_points = torch.matmul(points, transposed_affine_matrix) - tr, _ = torch.min(new_points, dim=0, keepdim=True) + tr = torch.amin(new_points, dim=0, keepdim=True) # Translate bounding boxes out_bboxes.sub_(tr.repeat((1, 2))) # Estimate meta-data for image with inverted=True and with center=[0,0] @@ -701,7 +701,7 @@ def pad_image_tensor( # internally. torch_padding = _parse_pad_padding(padding) - if padding_mode not in ["constant", "edge", "reflect", "symmetric"]: + if padding_mode not in ("constant", "edge", "reflect", "symmetric"): raise ValueError( f"`padding_mode` should be either `'constant'`, `'edge'`, `'reflect'` or `'symmetric'`, " f"but got `'{padding_mode}'`." @@ -917,7 +917,7 @@ def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype, # x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1) # y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1) # - + # TODO: should we define them transposed? theta1 = torch.tensor( [[[coeffs[0], coeffs[1], coeffs[2]], [coeffs[3], coeffs[4], coeffs[5]]]], dtype=dtype, device=device ) @@ -925,9 +925,9 @@ def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype, d = 0.5 base_grid = torch.empty(1, oh, ow, 3, dtype=dtype, device=device) - x_grid = torch.linspace(d, ow * 1.0 + d - 1.0, steps=ow, device=device) + x_grid = torch.linspace(d, ow + d - 1.0, steps=ow, device=device) base_grid[..., 0].copy_(x_grid) - y_grid = torch.linspace(d, oh * 1.0 + d - 1.0, steps=oh, device=device).unsqueeze_(-1) + y_grid = torch.linspace(d, oh + d - 1.0, steps=oh, device=device).unsqueeze_(-1) base_grid[..., 1].copy_(y_grid) base_grid[..., 2].fill_(1) @@ -1059,6 +1059,7 @@ def perspective_bounding_box( (-perspective_coeffs[0] * perspective_coeffs[7] + perspective_coeffs[1] * perspective_coeffs[6]) / denom, ] + # TODO: should we define them transposed? theta1 = torch.tensor( [[inv_coeffs[0], inv_coeffs[1], inv_coeffs[2]], [inv_coeffs[3], inv_coeffs[4], inv_coeffs[5]]], dtype=dtype, @@ -1165,6 +1166,7 @@ def elastic_image_tensor( return image shape = image.shape + device = image.device if image.ndim > 4: image = image.reshape((-1,) + shape[-3:]) @@ -1172,7 +1174,9 @@ def elastic_image_tensor( else: needs_unsquash = False - output = _FT.elastic_transform(image, displacement, interpolation=interpolation.value, fill=fill) + image_height, image_width = shape[-2:] + grid = _create_identity_grid((image_height, image_width), device=device).add_(displacement.to(device)) + output = _FT._apply_grid_transform(image, grid, interpolation.value, fill) if needs_unsquash: output = output.reshape(shape) @@ -1505,8 +1509,7 @@ def five_crop_image_tensor( image_height, image_width = image.shape[-2:] if crop_width > image_width or crop_height > image_height: - msg = "Requested crop size {} is bigger than input size {}" - raise ValueError(msg.format(size, (image_height, image_width))) + raise ValueError(f"Requested crop size {size} is bigger than input size {(image_height, image_width)}") tl = crop_image_tensor(image, 0, 0, crop_height, crop_width) tr = crop_image_tensor(image, 0, image_width - crop_width, crop_height, crop_width) @@ -1525,8 +1528,7 @@ def five_crop_image_pil( image_height, image_width = get_spatial_size_image_pil(image) if crop_width > image_width or crop_height > image_height: - msg = "Requested crop size {} is bigger than input size {}" - raise ValueError(msg.format(size, (image_height, image_width))) + raise ValueError(f"Requested crop size {size} is bigger than input size {(image_height, image_width)}") tl = crop_image_pil(image, 0, 0, crop_height, crop_width) tr = crop_image_pil(image, 0, image_width - crop_width, crop_height, crop_width)