Skip to content

Commit c6e81a4

Browse files
committed
remove fill blending for bilinear affine
1 parent f69eee6 commit c6e81a4

File tree

2 files changed

+24
-7
lines changed

2 files changed

+24
-7
lines changed

test/test_transforms_v2.py

+22
Original file line number | Diff line number | Diff line change
@@ -1393,6 +1393,28 @@ def test_transform_unknown_fill_error(self):
13931393
with pytest.raises(TypeError, match="Got inappropriate fill arg"):
13941394
transforms.RandomAffine(degrees=0, fill="fill")
13951395

1396+
@pytest.mark.parametrize("dtype", [torch.uint8, torch.float32])
1397+
def test_bilinear_no_blend_artifacts(self, dtype):
1398+
# Regression test for https://github.com/pytorch/vision/issues/8083
1399+
value = 0.5
1400+
if not dtype.is_floating_point:
1401+
value = int(value * get_max_value(dtype))
1402+
1403+
input = torch.full((1, 200, 200), value, dtype=dtype)
1404+
1405+
output = F.affine(
1406+
input,
1407+
angle=30.0,
1408+
translate=(0, 0),
1409+
scale=1.0,
1410+
shear=(0, 0),
1411+
interpolation=F.InterpolationMode.BILINEAR,
1412+
fill=value,
1413+
)
1414+
1415+
# Since the fill color is the same as the input, affine should be a no-op
1416+
assert_equal(output, input)
1417+
13961418

13971419
class TestVerticalFlip:
13981420
@pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])

torchvision/transforms/v2/functional/_geometry.py

+2 -7
Original file line number | Diff line number | Diff line change
@@ -585,13 +585,8 @@ def _apply_grid_transform(img: torch.Tensor, grid: torch.Tensor, mode: str, fill
585585
mask = mask.expand_as(float_img)
586586
fill_list = fill if isinstance(fill, (tuple, list)) else [float(fill)] # type: ignore[arg-type]
587587
fill_img = torch.tensor(fill_list, dtype=float_img.dtype, device=float_img.device).view(1, -1, 1, 1)
588-
if mode == "nearest":
589-
bool_mask = mask < 0.5
590-
float_img[bool_mask] = fill_img.expand_as(float_img)[bool_mask]
591-
else: # 'bilinear'
592-
# The following is mathematically equivalent to:
593-
# img * mask + (1.0 - mask) * fill = img * mask - fill * mask + fill = mask * (img - fill) + fill
594-
float_img = float_img.sub_(fill_img).mul_(mask).add_(fill_img)
588+
bool_mask = mask < 1
589+
float_img[bool_mask] = fill_img.expand_as(float_img)[bool_mask]
595590

596591
img = float_img.round_().to(img.dtype) if not fp else float_img
597592

0 commit comments

Comments
 (0)