From d3d2f792466d52d66df5b9eed090449902f95ecd Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 3 Apr 2024 12:25:57 +0100 Subject: [PATCH 1/2] Let to_tensor return torch.uint16 tensors --- docs/source/transforms.rst | 3 ++- test/test_transforms_v2.py | 12 ++++++++++++ torchvision/transforms/functional.py | 2 +- 3 files changed, 15 insertions(+), 2 deletions(-) diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 54ed18394cd..26d52f22e93 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -91,7 +91,8 @@ the tensor dtype. Tensor images with a float dtype are expected to have values in ``[0, 1]``. Tensor images with an integer dtype are expected to have values in ``[0, MAX_DTYPE]`` where ``MAX_DTYPE`` is the largest value that can be represented in that dtype. Typically, images of dtype -``torch.uint8`` are expected to have values in ``[0, 255]``. +``torch.uint8`` are expected to have values in ``[0, 255]``. Note that dtypes +like ``torch.uint16`` or ``torch.uint32`` aren't fully supported. Use :class:`~torchvision.transforms.v2.ToDtype` to convert both the dtype and range of the inputs. diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 3724a23af74..591dd38e942 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -5327,6 +5327,18 @@ def test_functional_error(self): F.pil_to_tensor(object()) +@pytest.mark.parametrize("f", [F.to_tensor, F.pil_to_tensor]) +def test_I16_to_tensor(f): + # See https://github.com/pytorch/vision/issues/8359 + I16_pil_img = PIL.Image.fromarray(np.random.randint(0, 2 ** 16, (10, 10), dtype=np.uint16)) + assert I16_pil_img.mode == "I;16" + + cm = pytest.warns(UserWarning, match="deprecated") if f is F.to_tensor else contextlib.nullcontext() + with cm: + out = f(I16_pil_img) + assert out.dtype == torch.uint16 + + class TestLambda: @pytest.mark.parametrize("input", [object(), torch.empty(()), np.empty(()), "string", 1, 0.0]) @pytest.mark.parametrize("types", [(), (torch.Tensor, np.ndarray)]) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 124d1da5f4f..a364bdc9d2a 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -164,7 +164,7 @@ def to_tensor(pic: Union[PILImage, np.ndarray]) -> Tensor: return torch.from_numpy(nppic).to(dtype=default_float_dtype) # handle PIL Image - mode_to_nptype = {"I": np.int32, "I;16" if sys.byteorder == "little" else "I;16B": np.int16, "F": np.float32} + mode_to_nptype = {"I": np.int32, "I;16" if sys.byteorder == "little" else "I;16B": np.uint16, "F": np.float32} img = torch.from_numpy(np.array(pic, mode_to_nptype.get(pic.mode, np.uint8), copy=True)) if pic.mode == "1": From ccdc0700e32f4db7c1aafca833ca1563147f54f6 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 3 Apr 2024 12:51:19 +0100 Subject: [PATCH 2/2] lint --- test/test_transforms_v2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 591dd38e942..83c4969dfc0 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -5330,7 +5330,7 @@ def test_functional_error(self): @pytest.mark.parametrize("f", [F.to_tensor, F.pil_to_tensor]) def test_I16_to_tensor(f): # See https://github.com/pytorch/vision/issues/8359 - I16_pil_img = PIL.Image.fromarray(np.random.randint(0, 2 ** 16, (10, 10), dtype=np.uint16)) + I16_pil_img = PIL.Image.fromarray(np.random.randint(0, 2**16, (10, 10), dtype=np.uint16)) assert I16_pil_img.mode == "I;16" cm = pytest.warns(UserWarning, match="deprecated") if f is F.to_tensor else contextlib.nullcontext()