diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py
index 4f8d0027bd6..23f75accb52 100644
--- a/test/test_transforms_v2.py
+++ b/test/test_transforms_v2.py
@@ -1393,6 +1393,28 @@ def test_transform_unknown_fill_error(self):
         with pytest.raises(TypeError, match="Got inappropriate fill arg"):
             transforms.RandomAffine(degrees=0, fill="fill")
 
+    @pytest.mark.parametrize("dtype", [torch.uint8, torch.float32])
+    def test_bilinear_no_blend_artifacts(self, dtype):
+        # Regression test for https://github.com/pytorch/vision/issues/8083
+        value = 0.5
+        if not dtype.is_floating_point:
+            value = int(value * get_max_value(dtype))
+
+        input = torch.full((1, 200, 200), value, dtype=dtype)
+
+        output = F.affine(
+            input,
+            angle=30.0,
+            translate=(0, 0),
+            scale=1.0,
+            shear=(0, 0),
+            interpolation=F.InterpolationMode.BILINEAR,
+            fill=value,
+        )
+
+        # Since the fill color is the same as the input, affine should be a no-op
+        assert_equal(output, input)
+
 
 class TestVerticalFlip:
     @pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])
diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py
index d6d42344fcb..dd71049289a 100644
--- a/torchvision/transforms/v2/functional/_geometry.py
+++ b/torchvision/transforms/v2/functional/_geometry.py
@@ -585,13 +585,8 @@ def _apply_grid_transform(img: torch.Tensor, grid: torch.Tensor, mode: str, fill
         mask = mask.expand_as(float_img)
         fill_list = fill if isinstance(fill, (tuple, list)) else [float(fill)]  # type: ignore[arg-type]
         fill_img = torch.tensor(fill_list, dtype=float_img.dtype, device=float_img.device).view(1, -1, 1, 1)
-        if mode == "nearest":
-            bool_mask = mask < 0.5
-            float_img[bool_mask] = fill_img.expand_as(float_img)[bool_mask]
-        else:  # 'bilinear'
-            # The following is mathematically equivalent to:
-            # img * mask + (1.0 - mask) * fill = img * mask - fill * mask + fill = mask * (img - fill) + fill
-            float_img = float_img.sub_(fill_img).mul_(mask).add_(fill_img)
+        bool_mask = mask < 1
+        float_img[bool_mask] = fill_img.expand_as(float_img)[bool_mask]
 
     img = float_img.round_().to(img.dtype) if not fp else float_img
 