diff --git a/test/test_transforms.py b/test/test_transforms.py index e86b6959517..f12678e1d5f 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -269,7 +269,7 @@ def test_pil_to_tensor(self, channels): input_data = torch.as_tensor(np.random.rand(channels, height, width).astype(np.float32)) img = transforms.ToPILImage()(input_data) # CHW -> HWC and (* 255).byte() output = trans(img) # HWC -> CHW - expected_output = (input_data * 255).byte() + expected_output = (input_data * 255).round().byte() torch.testing.assert_close(output, expected_output, check_stride=False) # separate test for mode '1' PIL images @@ -502,6 +502,7 @@ def test_pad_with_mode_F_images(self): img = Image.new("F", (10, 10)) padded_img = transform(img) + assert_equal(padded_img.size, [edge_size + 2 * pad for edge_size in img.size], check_stride=False) @@ -539,7 +540,7 @@ class TestToPil: def _get_1_channel_tensor_various_types(): img_data_float = torch.Tensor(1, 4, 4).uniform_() - expected_output = img_data_float.mul(255).int().float().div(255).numpy() + expected_output = img_data_float.mul(255).round().div(255).numpy() yield img_data_float, expected_output, 'L' img_data_byte = torch.ByteTensor(1, 4, 4).random_(0, 255) @@ -556,7 +557,7 @@ def _get_1_channel_tensor_various_types(): def _get_2d_tensor_various_types(): img_data_float = torch.Tensor(4, 4).uniform_() - expected_output = img_data_float.mul(255).int().float().div(255).numpy() + expected_output = img_data_float.mul(255).round().div(255).numpy() yield img_data_float, expected_output, 'L' img_data_byte = torch.ByteTensor(4, 4).random_(0, 255) @@ -634,7 +635,7 @@ def test_2_channel_ndarray_to_pil_image_error(self): @pytest.mark.parametrize('expected_mode', [None, 'LA']) def test_2_channel_tensor_to_pil_image(self, expected_mode): img_data = torch.Tensor(2, 4, 4).uniform_() - expected_output = img_data.mul(255).int().float().div(255) + expected_output = img_data.mul(255).round().div(255) if expected_mode is None: img = 
transforms.ToPILImage()(img_data) assert img.mode == 'LA' # default should assume LA @@ -683,7 +684,7 @@ def test_2d_ndarray_to_pil_image(self, with_mode, img_data, expected_mode): @pytest.mark.parametrize('expected_mode', [None, 'RGB', 'HSV', 'YCbCr']) def test_3_channel_tensor_to_pil_image(self, expected_mode): img_data = torch.Tensor(3, 4, 4).uniform_() - expected_output = img_data.mul(255).int().float().div(255) + expected_output = img_data.mul(255).round().div(255) if expected_mode is None: img = transforms.ToPILImage()(img_data) @@ -741,7 +742,7 @@ def test_3_channel_ndarray_to_pil_image_error(self): @pytest.mark.parametrize('expected_mode', [None, 'RGBA', 'CMYK', 'RGBX']) def test_4_channel_tensor_to_pil_image(self, expected_mode): img_data = torch.Tensor(4, 4, 4).uniform_() - expected_output = img_data.mul(255).int().float().div(255) + expected_output = img_data.mul(255).round().div(255) if expected_mode is None: img = transforms.ToPILImage()(img_data) @@ -815,6 +816,25 @@ def test_tensor_bad_types_to_pil_image(self): with pytest.raises(ValueError, match=r'pic should not have > 4 channels. 
Got \d+ channels.'): transforms.ToPILImage()(torch.ones(6, 4, 4)) + @pytest.mark.parametrize('input_img, expected_output', [ + (torch.full((4, 4), 1.01), np.full((4, 4), 255)), + (torch.full((4, 4), -0.01), np.full((4, 4), 0)) + ]) + def test_tensor_to_pil_no_overshoot(self, input_img, expected_output): + transform = transforms.ToPILImage() + img = transform(input_img) + torch.testing.assert_close(expected_output, np.array(img), check_dtype=False) + + def test_tensor_to_pil_robust(self): + to_pil = transforms.ToPILImage() + to_tensor = transforms.ToTensor() + input_img = torch.ByteTensor(1, 20, 20).random_(0, 255).float().div(255) + eps = 1.0e-5 + img_sub = to_tensor(to_pil(input_img - eps)) + torch.testing.assert_close(input_img, img_sub) + img_add = to_tensor(to_pil(input_img + eps)) + torch.testing.assert_close(input_img, img_add) + def test_adjust_brightness(): x_shape = [2, 2, 3] diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 21f5c654f99..d9e68f51c72 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -241,7 +241,7 @@ def to_pil_image(pic, mode=None): npimg = pic if isinstance(pic, torch.Tensor): if pic.is_floating_point() and mode != 'F': - pic = pic.mul(255).byte() + pic = pic.clamp(0, 1).mul(255).round().byte() npimg = np.transpose(pic.cpu().numpy(), (1, 2, 0)) if not isinstance(npimg, np.ndarray):