Fix overshoot issue in F.to_pil_image #3610

Closed
wants to merge 4 commits

32 changes: 26 additions & 6 deletions test/test_transforms.py
@@ -269,7 +269,7 @@ def test_pil_to_tensor(self, channels):
input_data = torch.as_tensor(np.random.rand(channels, height, width).astype(np.float32))
img = transforms.ToPILImage()(input_data) # CHW -> HWC and (* 255).byte()
output = trans(img) # HWC -> CHW
- expected_output = (input_data * 255).byte()
+ expected_output = (input_data * 255).round().byte()
torch.testing.assert_close(output, expected_output, check_stride=False)

# separate test for mode '1' PIL images
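For context, the updated expectation rounds to the nearest integer instead of truncating, matching the quantization ToPILImage now applies. A minimal standalone sketch (not part of the PR) of the difference:

import torch

x = torch.tensor([0.004, 0.996])             # float pixels in [0, 1]
truncated = x.mul(255).int()                 # 1.02 -> 1, 253.98 -> 253
rounded = x.mul(255).round().int()           # 1.02 -> 1, 253.98 -> 254
print(truncated.tolist(), rounded.tolist())  # [1, 253] [1, 254]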
@@ -502,6 +502,7 @@ def test_pad_with_mode_F_images(self):

img = Image.new("F", (10, 10))
padded_img = transform(img)

assert_equal(padded_img.size, [edge_size + 2 * pad for edge_size in img.size], check_stride=False)


@@ -539,7 +540,7 @@ class TestToPil:

def _get_1_channel_tensor_various_types():
img_data_float = torch.Tensor(1, 4, 4).uniform_()
- expected_output = img_data_float.mul(255).int().float().div(255).numpy()
+ expected_output = img_data_float.mul(255).round().div(255).numpy()
yield img_data_float, expected_output, 'L'

img_data_byte = torch.ByteTensor(1, 4, 4).random_(0, 255)
@@ -556,7 +557,7 @@ def _get_1_channel_tensor_various_types():

def _get_2d_tensor_various_types():
img_data_float = torch.Tensor(4, 4).uniform_()
- expected_output = img_data_float.mul(255).int().float().div(255).numpy()
+ expected_output = img_data_float.mul(255).round().div(255).numpy()
yield img_data_float, expected_output, 'L'

img_data_byte = torch.ByteTensor(4, 4).random_(0, 255)
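The expected_output formula used in these helpers mirrors the quantization the fixed conversion applies: scale to [0, 255], round to the nearest integer, then scale back to [0, 1]. A standalone sketch (not part of the PR) under that reading:

import torch

img = torch.rand(1, 4, 4)                  # float input in [0, 1]
stored = img.mul(255).round().byte()       # bytes the PIL image holds after the fix
reference = img.mul(255).round().div(255)  # the tests' quantized reference, back in [0, 1]
assert torch.equal(stored.float().div(255), reference)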
@@ -634,7 +635,7 @@ def test_2_channel_ndarray_to_pil_image_error(self):
@pytest.mark.parametrize('expected_mode', [None, 'LA'])
def test_2_channel_tensor_to_pil_image(self, expected_mode):
img_data = torch.Tensor(2, 4, 4).uniform_()
- expected_output = img_data.mul(255).int().float().div(255)
+ expected_output = img_data.mul(255).round().div(255)
if expected_mode is None:
img = transforms.ToPILImage()(img_data)
assert img.mode == 'LA' # default should assume LA
@@ -683,7 +684,7 @@ def test_2d_ndarray_to_pil_image(self, with_mode, img_data, expected_mode):
@pytest.mark.parametrize('expected_mode', [None, 'RGB', 'HSV', 'YCbCr'])
def test_3_channel_tensor_to_pil_image(self, expected_mode):
img_data = torch.Tensor(3, 4, 4).uniform_()
- expected_output = img_data.mul(255).int().float().div(255)
+ expected_output = img_data.mul(255).round().div(255)

if expected_mode is None:
img = transforms.ToPILImage()(img_data)
@@ -741,7 +742,7 @@ def test_3_channel_ndarray_to_pil_image_error(self):
@pytest.mark.parametrize('expected_mode', [None, 'RGBA', 'CMYK', 'RGBX'])
def test_4_channel_tensor_to_pil_image(self, expected_mode):
img_data = torch.Tensor(4, 4, 4).uniform_()
- expected_output = img_data.mul(255).int().float().div(255)
+ expected_output = img_data.mul(255).round().div(255)

if expected_mode is None:
img = transforms.ToPILImage()(img_data)
@@ -815,6 +816,25 @@ def test_tensor_bad_types_to_pil_image(self):
with pytest.raises(ValueError, match=r'pic should not have > 4 channels. Got \d+ channels.'):
transforms.ToPILImage()(torch.ones(6, 4, 4))

@pytest.mark.parametrize('input_img, expected_output', [
    (torch.full((4, 4), 1.01), np.full((4, 4), 255)),
    (torch.full((4, 4), -0.01), np.full((4, 4), 0))
])
def test_tensor_to_pil_no_overshoot(self, input_img, expected_output):
    transform = transforms.ToPILImage()
    img = transform(input_img)
    torch.testing.assert_close(expected_output, np.array(img), check_dtype=False)

def test_tensor_to_pil_robust(self):
    # Round-tripping through ToPILImage/ToTensor should be robust to tiny
    # perturbations around exact uint8 levels.
    to_pil = transforms.ToPILImage()
    to_tensor = transforms.ToTensor()
    input_img = torch.ByteTensor(1, 20, 20).random_(0, 255).float().div_(255)
    eps = 1.0e-5
    img_sub = to_tensor(to_pil(input_img - eps))
    torch.testing.assert_close(input_img, img_sub, check_dtype=False)
    img_add = to_tensor(to_pil(input_img + eps))
    torch.testing.assert_close(input_img, img_add, check_dtype=False)
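To illustrate what the new tests guard against, here is a standalone sketch (not part of the PR): a pixel slightly above 1.0, as interpolation can produce, should saturate at 255 rather than overflow the uint8 range.

import numpy as np
import torch
from torchvision import transforms

overshoot = torch.full((4, 4), 1.01)      # slightly out-of-range float image
img = transforms.ToPILImage()(overshoot)
print(np.array(img)[0, 0])                # 255 once values are clamped before the byte cast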


def test_adjust_brightness():
x_shape = [2, 2, 3]
2 changes: 1 addition & 1 deletion torchvision/transforms/functional.py
@@ -241,7 +241,7 @@ def to_pil_image(pic, mode=None):
npimg = pic
if isinstance(pic, torch.Tensor):
if pic.is_floating_point() and mode != 'F':
- pic = pic.mul(255).byte()
+ pic = pic.clamp(0, 1).mul(255).round().byte()
npimg = np.transpose(pic.cpu().numpy(), (1, 2, 0))

if not isinstance(npimg, np.ndarray):
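A quick worked example of the one-line change, with illustrative values (standalone, not part of the PR):

import torch

pix = torch.tensor([-0.01, 0.999, 1.01])       # two values slightly outside [0, 1]
old = pix.mul(255).byte()                      # truncates (254.745 -> 254) and lets out-of-range values overflow uint8
new = pix.clamp(0, 1).mul(255).round().byte()  # saturates at 0/255 and rounds to nearest
print(new.tolist())                            # [0, 255, 255]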