diff --git a/test/assets/grace_hopper_517x606.jpg b/test/assets/encode_jpeg/grace_hopper_517x606.jpg similarity index 100% rename from test/assets/grace_hopper_517x606.jpg rename to test/assets/encode_jpeg/grace_hopper_517x606.jpg diff --git a/test/assets/jpeg_write/grace_hopper_517x606_pil.jpg b/test/assets/encode_jpeg/jpeg_write/grace_hopper_517x606_pil.jpg similarity index 100% rename from test/assets/jpeg_write/grace_hopper_517x606_pil.jpg rename to test/assets/encode_jpeg/jpeg_write/grace_hopper_517x606_pil.jpg diff --git a/test/assets/fakedata/logos/cmyk_pytorch.jpg b/test/assets/fakedata/logos/cmyk_pytorch.jpg new file mode 100644 index 00000000000..16ee8b2b4bc Binary files /dev/null and b/test/assets/fakedata/logos/cmyk_pytorch.jpg differ diff --git a/test/assets/fakedata/logos/gray_pytorch.jpg b/test/assets/fakedata/logos/gray_pytorch.jpg new file mode 100644 index 00000000000..60c9c7cf705 Binary files /dev/null and b/test/assets/fakedata/logos/gray_pytorch.jpg differ diff --git a/test/assets/fakedata/logos/rgb_pytorch.jpg b/test/assets/fakedata/logos/rgb_pytorch.jpg new file mode 100644 index 00000000000..d49e658b94f Binary files /dev/null and b/test/assets/fakedata/logos/rgb_pytorch.jpg differ diff --git a/test/assets/grace_hopper_517x606.pth b/test/assets/grace_hopper_517x606.pth deleted file mode 100644 index 54b39dc0cd7..00000000000 Binary files a/test/assets/grace_hopper_517x606.pth and /dev/null differ diff --git a/test/test_cpp_models.py b/test/test_cpp_models.py index b6654a0278d..6deb5d79739 100644 --- a/test/test_cpp_models.py +++ b/test/test_cpp_models.py @@ -25,7 +25,8 @@ def process_model(model, tensor, func, name): def read_image1(): - image_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'assets', 'grace_hopper_517x606.jpg') + image_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'assets', 'encode_jpeg', + 'grace_hopper_517x606.jpg') image = Image.open(image_path) image = image.resize((224, 224)) x = 
F.to_tensor(image) @@ -33,7 +34,8 @@ def read_image1(): def read_image2(): - image_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'assets', 'grace_hopper_517x606.jpg') + image_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'assets', 'encode_jpeg', + 'grace_hopper_517x606.jpg') image = Image.open(image_path) image = image.resize((299, 299)) x = F.to_tensor(image) diff --git a/test/test_datasets_utils.py b/test/test_datasets_utils.py index cab0f90c51b..2c6599ce497 100644 --- a/test/test_datasets_utils.py +++ b/test/test_datasets_utils.py @@ -14,7 +14,7 @@ TEST_FILE = get_file_path_2( - os.path.dirname(os.path.abspath(__file__)), 'assets', 'grace_hopper_517x606.jpg') + os.path.dirname(os.path.abspath(__file__)), 'assets', 'encode_jpeg', 'grace_hopper_517x606.jpg') class Tester(unittest.TestCase): diff --git a/test/test_image.py b/test/test_image.py index 45a4258816e..b3ab0b2364a 100644 --- a/test/test_image.py +++ b/test/test_image.py @@ -19,6 +19,7 @@ FAKEDATA_DIR = os.path.join(IMAGE_ROOT, "fakedata") IMAGE_DIR = os.path.join(FAKEDATA_DIR, "imagefolder") DAMAGED_JPEG = os.path.join(IMAGE_ROOT, 'damaged_jpeg') +ENCODE_JPEG = os.path.join(IMAGE_ROOT, "encode_jpeg") def get_images(directory, img_ext): @@ -33,14 +34,44 @@ def get_images(directory, img_ext): yield os.path.join(root, fl) +def pil_read_image(img_path): + with Image.open(img_path) as img: + return torch.from_numpy(np.array(img)) + + +def normalize_dimensions(img_pil): + if len(img_pil.shape) == 3: + img_pil = img_pil.permute(2, 0, 1) + else: + img_pil = img_pil.unsqueeze(0) + return img_pil + + class ImageTester(unittest.TestCase): def test_decode_jpeg(self): + conversion = [(None, 0), ("L", 1), ("RGB", 3)] for img_path in get_images(IMAGE_ROOT, ".jpg"): - img_pil = torch.load(img_path.replace('jpg', 'pth')) - img_pil = img_pil.permute(2, 0, 1) - data = read_file(img_path) - img_ljpeg = decode_jpeg(data) - self.assertTrue(img_ljpeg.equal(img_pil)) + for pil_mode, channels 
in conversion: + with Image.open(img_path) as img: + is_cmyk = img.mode == "CMYK" + if pil_mode is not None: + if is_cmyk: + # libjpeg does not support the conversion + continue + img = img.convert(pil_mode) + img_pil = torch.from_numpy(np.array(img)) + if is_cmyk: + # flip the colors to match libjpeg + img_pil = 255 - img_pil + + img_pil = normalize_dimensions(img_pil) + data = read_file(img_path) + img_ljpeg = decode_image(data, channels=channels) + + # Permit a small variation on pixel values to account for implementation + # differences between Pillow and LibJPEG. + abs_mean_diff = (img_ljpeg.type(torch.float32) - img_pil).abs().mean().item() + self.assertTrue(abs_mean_diff < 2) with self.assertRaisesRegex(RuntimeError, "Expected a non empty 1-dimensional tensor"): decode_jpeg(torch.empty((100, 1), dtype=torch.uint8)) @@ -68,7 +99,7 @@ def test_damaged_images(self): decode_jpeg(data) def test_encode_jpeg(self): - for img_path in get_images(IMAGE_ROOT, ".jpg"): + for img_path in get_images(ENCODE_JPEG, ".jpg"): dirname = os.path.dirname(img_path) filename, _ = os.path.splitext(os.path.basename(img_path)) write_folder = os.path.join(dirname, 'jpeg_write') @@ -111,7 +142,7 @@ def test_encode_jpeg(self): encode_jpeg(torch.empty((100, 100), dtype=torch.uint8)) def test_write_jpeg(self): - for img_path in get_images(IMAGE_ROOT, ".jpg"): + for img_path in get_images(ENCODE_JPEG, ".jpg"): data = read_file(img_path) img = decode_jpeg(data) @@ -134,20 +165,25 @@ def test_write_jpeg(self): self.assertEqual(torch_bytes, pil_bytes) def test_decode_png(self): + conversion = [(None, 0), ("L", 1), ("LA", 2), ("RGB", 3), ("RGBA", 4)] for img_path in get_images(FAKEDATA_DIR, ".png"): - img_pil = torch.from_numpy(np.array(Image.open(img_path))) - if len(img_pil.shape) == 3: - img_pil = img_pil.permute(2, 0, 1) - else: - img_pil = img_pil.unsqueeze(0) - data = read_file(img_path) - img_lpng = decode_png(data) - self.assertTrue(img_lpng.equal(img_pil)) + for pil_mode, channels in 
conversion: + with Image.open(img_path) as img: + if pil_mode is not None: + img = img.convert(pil_mode) + img_pil = torch.from_numpy(np.array(img)) - with self.assertRaises(RuntimeError): - decode_png(torch.empty((), dtype=torch.uint8)) - with self.assertRaises(RuntimeError): - decode_png(torch.randint(3, 5, (300,), dtype=torch.uint8)) + img_pil = normalize_dimensions(img_pil) + data = read_file(img_path) + img_lpng = decode_image(data, channels=channels) + + tol = 0 if pil_mode is None else 1  # exact match expected when no conversion is requested + self.assertTrue(img_lpng.allclose(img_pil, atol=tol)) + + with self.assertRaises(RuntimeError): + decode_png(torch.empty((), dtype=torch.uint8)) + with self.assertRaises(RuntimeError): + decode_png(torch.randint(3, 5, (300,), dtype=torch.uint8)) def test_encode_png(self): for img_path in get_images(IMAGE_DIR, '.png'): @@ -196,19 +232,6 @@ def test_write_png(self): self.assertTrue(img_pil.equal(saved_image)) - def test_decode_image(self): - for img_path in get_images(IMAGE_ROOT, ".jpg"): - img_pil = torch.load(img_path.replace('jpg', 'pth')) - img_pil = img_pil.permute(2, 0, 1) - img_ljpeg = decode_image(read_file(img_path)) - self.assertTrue(img_ljpeg.equal(img_pil)) - - for img_path in get_images(IMAGE_DIR, ".png"): - img_pil = torch.from_numpy(np.array(Image.open(img_path))) - img_pil = img_pil.permute(2, 0, 1) - img_lpng = decode_image(read_file(img_path)) - self.assertTrue(img_lpng.equal(img_pil)) - def test_read_file(self): with get_tmp_dir() as d: fname, content = 'test1.bin', b'TorchVision\211\n' diff --git a/test/test_transforms.py b/test/test_transforms.py index f9add6d1b57..72100d0feac 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -24,7 +24,7 @@ GRACE_HOPPER = get_file_path_2( - os.path.dirname(os.path.abspath(__file__)), 'assets', 'grace_hopper_517x606.jpg') + os.path.dirname(os.path.abspath(__file__)), 'assets', 'encode_jpeg', 'grace_hopper_517x606.jpg') class Tester(unittest.TestCase): diff --git 
a/torchvision/csrc/cpu/image/read_image_cpu.cpp b/torchvision/csrc/cpu/image/read_image_cpu.cpp index ad11ee666b9..5839017d3d7 100644 --- a/torchvision/csrc/cpu/image/read_image_cpu.cpp +++ b/torchvision/csrc/cpu/image/read_image_cpu.cpp @@ -1,13 +1,15 @@ #include "read_image_cpu.h" #include -torch::Tensor decode_image(const torch::Tensor& data) { +torch::Tensor decode_image(const torch::Tensor& data, int64_t channels) { // Check that the input tensor dtype is uint8 TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor"); // Check that the input tensor is 1-dimensional TORCH_CHECK( data.dim() == 1 && data.numel() > 0, "Expected a non empty 1-dimensional tensor"); + TORCH_CHECK( + channels >= 0 && channels <= 4, "Number of channels not supported"); auto datap = data.data_ptr(); @@ -15,9 +17,9 @@ torch::Tensor decode_image(const torch::Tensor& data) { const uint8_t png_signature[4] = {137, 80, 78, 71}; // == "\211PNG" if (memcmp(jpeg_signature, datap, 3) == 0) { - return decodeJPEG(data); + return decodeJPEG(data, channels); } else if (memcmp(png_signature, datap, 4) == 0) { - return decodePNG(data); + return decodePNG(data, channels); } else { TORCH_CHECK( false, diff --git a/torchvision/csrc/cpu/image/read_image_cpu.h b/torchvision/csrc/cpu/image/read_image_cpu.h index c8538cc88c6..e926a8474da 100644 --- a/torchvision/csrc/cpu/image/read_image_cpu.h +++ b/torchvision/csrc/cpu/image/read_image_cpu.h @@ -3,4 +3,6 @@ #include "readjpeg_cpu.h" #include "readpng_cpu.h" -C10_EXPORT torch::Tensor decode_image(const torch::Tensor& data); +C10_EXPORT torch::Tensor decode_image( + const torch::Tensor& data, + int64_t channels = 0); diff --git a/torchvision/csrc/cpu/image/readjpeg_cpu.cpp b/torchvision/csrc/cpu/image/readjpeg_cpu.cpp index dd2354e4467..d093dca0963 100644 --- a/torchvision/csrc/cpu/image/readjpeg_cpu.cpp +++ b/torchvision/csrc/cpu/image/readjpeg_cpu.cpp @@ -1,18 +1,16 @@ #include "readjpeg_cpu.h" #include -#include #include #if !JPEG_FOUND - 
-torch::Tensor decodeJPEG(const torch::Tensor& data) { +torch::Tensor decodeJPEG(const torch::Tensor& data, int64_t channels) { TORCH_CHECK( false, "decodeJPEG: torchvision not compiled with libjpeg support"); } - #else #include +#include #include "jpegcommon.h" struct torch_jpeg_mgr { @@ -71,13 +69,16 @@ static void torch_jpeg_set_source_mgr( src->pub.next_input_byte = src->data; } -torch::Tensor decodeJPEG(const torch::Tensor& data) { +torch::Tensor decodeJPEG(const torch::Tensor& data, int64_t channels) { // Check that the input tensor dtype is uint8 TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor"); // Check that the input tensor is 1-dimensional TORCH_CHECK( data.dim() == 1 && data.numel() > 0, "Expected a non empty 1-dimensional tensor"); + TORCH_CHECK( + channels == 0 || channels == 1 || channels == 3, + "Number of channels not supported"); struct jpeg_decompress_struct cinfo; struct torch_jpeg_error_mgr jerr; @@ -100,15 +101,41 @@ torch::Tensor decodeJPEG(const torch::Tensor& data) { // read info from header. jpeg_read_header(&cinfo, TRUE); + + int current_channels = cinfo.num_components; + + if (channels > 0 && channels != current_channels) { + switch (channels) { + case 1: // Gray + cinfo.out_color_space = JCS_GRAYSCALE; + break; + case 3: // RGB + cinfo.out_color_space = JCS_RGB; + break; + /* + * Libjpeg does not support converting from CMYK to grayscale etc. 
There + * is a way to do this but it involves converting it manually to RGB: + * https://github.com/tensorflow/tensorflow/blob/86871065265b04e0db8ca360c046421efb2bdeb4/tensorflow/core/lib/jpeg/jpeg_mem.cc#L284-L313 + * + */ + default: + jpeg_destroy_decompress(&cinfo); + TORCH_CHECK(false, "Invalid number of output channels."); + } + + jpeg_calc_output_dimensions(&cinfo); + } else { + channels = current_channels; + } + jpeg_start_decompress(&cinfo); int height = cinfo.output_height; int width = cinfo.output_width; - int components = cinfo.output_components; - auto stride = width * components; - auto tensor = torch::empty( - {int64_t(height), int64_t(width), int64_t(components)}, torch::kU8); + int stride = width * channels; + auto tensor = + torch::empty({int64_t(height), int64_t(width), channels}, torch::kU8); auto ptr = tensor.data_ptr(); while (cinfo.output_scanline < cinfo.output_height) { /* jpeg_read_scanlines expects an array of pointers to scanlines. diff --git a/torchvision/csrc/cpu/image/readjpeg_cpu.h b/torchvision/csrc/cpu/image/readjpeg_cpu.h index 70482caa81f..0e7bb137d12 100644 --- a/torchvision/csrc/cpu/image/readjpeg_cpu.h +++ b/torchvision/csrc/cpu/image/readjpeg_cpu.h @@ -2,4 +2,6 @@ #include -C10_EXPORT torch::Tensor decodeJPEG(const torch::Tensor& data); +C10_EXPORT torch::Tensor decodeJPEG( + const torch::Tensor& data, + int64_t channels = 0); diff --git a/torchvision/csrc/cpu/image/readpng_cpu.cpp b/torchvision/csrc/cpu/image/readpng_cpu.cpp index 6fbe04ac033..fbca228b436 100644 --- a/torchvision/csrc/cpu/image/readpng_cpu.cpp +++ b/torchvision/csrc/cpu/image/readpng_cpu.cpp @@ -2,23 +2,26 @@ // Comment #include -#include #include +// PNG_FOUND must come from the build configuration; never force it here. #if !PNG_FOUND -torch::Tensor decodePNG(const torch::Tensor& data) { +torch::Tensor decodePNG(const torch::Tensor& data, int64_t channels) { TORCH_CHECK(false, "decodePNG: torchvision not compiled with libPNG support"); } #else #include +#include -torch::Tensor decodePNG(const torch::Tensor& 
data) { +torch::Tensor decodePNG(const torch::Tensor& data, int64_t channels) { // Check that the input tensor dtype is uint8 TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor"); // Check that the input tensor is 1-dimensional TORCH_CHECK( data.dim() == 1 && data.numel() > 0, "Expected a non empty 1-dimensional tensor"); + TORCH_CHECK( + channels >= 0 && channels <= 4, "Number of channels not supported"); auto png_ptr = png_create_read_struct(PNG_LIBPNG_VER_STRING, nullptr, nullptr, nullptr); @@ -72,30 +75,79 @@ torch::Tensor decodePNG(const torch::Tensor& data) { TORCH_CHECK(retval == 1, "Could read image metadata from content.") } - int channels; - switch (color_type) { - case PNG_COLOR_TYPE_RGB: - channels = 3; - break; - case PNG_COLOR_TYPE_RGB_ALPHA: - channels = 4; - break; - case PNG_COLOR_TYPE_GRAY: - channels = 1; - break; - case PNG_COLOR_TYPE_GRAY_ALPHA: - channels = 2; - break; - case PNG_COLOR_TYPE_PALETTE: - channels = 1; - break; - default: - png_destroy_read_struct(&png_ptr, &info_ptr, nullptr); - TORCH_CHECK(false, "Image color type is not supported."); + int current_channels = png_get_channels(png_ptr, info_ptr); + + if (channels > 0) { + // TODO: consider supporting PNG_INFO_tRNS + bool is_palette = (color_type & PNG_COLOR_MASK_PALETTE) != 0; + bool has_color = (color_type & PNG_COLOR_MASK_COLOR) != 0; + bool has_alpha = (color_type & PNG_COLOR_MASK_ALPHA) != 0; + + switch (channels) { + case 1: // Gray + if (is_palette) { + png_set_palette_to_rgb(png_ptr); + has_alpha = true; + } + + if (has_alpha) { + png_set_strip_alpha(png_ptr); + } + + if (has_color) { + png_set_rgb_to_gray(png_ptr, 1, 0.2989, 0.587); + } + break; + case 2: // Gray + Alpha + if (is_palette) { + png_set_palette_to_rgb(png_ptr); + has_alpha = true; + } + + if (!has_alpha) { + png_set_add_alpha(png_ptr, (1 << bit_depth) - 1, PNG_FILLER_AFTER); + } + + if (has_color) { + png_set_rgb_to_gray(png_ptr, 1, 0.2989, 0.587); + } + break; + case 3: + if (is_palette) 
{ + png_set_palette_to_rgb(png_ptr); + has_alpha = true; + } else if (!has_color) { + png_set_gray_to_rgb(png_ptr); + } + + if (has_alpha) { + png_set_strip_alpha(png_ptr); + } + break; + case 4: + if (is_palette) { + png_set_palette_to_rgb(png_ptr); + has_alpha = true; + } else if (!has_color) { + png_set_gray_to_rgb(png_ptr); + } + + if (!has_alpha) { + png_set_add_alpha(png_ptr, (1 << bit_depth) - 1, PNG_FILLER_AFTER); + } + break; + default: + png_destroy_read_struct(&png_ptr, &info_ptr, nullptr); + TORCH_CHECK(false, "Invalid number of output channels."); + } + + png_read_update_info(png_ptr, info_ptr); + } else { + channels = current_channels; } - auto tensor = torch::empty( - {int64_t(height), int64_t(width), int64_t(channels)}, torch::kU8); + auto tensor = + torch::empty({int64_t(height), int64_t(width), channels}, torch::kU8); auto ptr = tensor.accessor().data(); auto bytes = png_get_rowbytes(png_ptr, info_ptr); for (png_uint_32 i = 0; i < height; ++i) { diff --git a/torchvision/csrc/cpu/image/readpng_cpu.h b/torchvision/csrc/cpu/image/readpng_cpu.h index f84fd99fa92..a36032ddb25 100644 --- a/torchvision/csrc/cpu/image/readpng_cpu.h +++ b/torchvision/csrc/cpu/image/readpng_cpu.h @@ -4,4 +4,6 @@ #include #include -C10_EXPORT torch::Tensor decodePNG(const torch::Tensor& data); +C10_EXPORT torch::Tensor decodePNG( + const torch::Tensor& data, + int64_t channels = 0); diff --git a/torchvision/io/image.py b/torchvision/io/image.py index 2279be3ad10..01e1e1e5ca0 100644 --- a/torchvision/io/image.py +++ b/torchvision/io/image.py @@ -74,19 +74,24 @@ def write_file(filename: str, data: torch.Tensor) -> None: torch.ops.image.write_file(filename, data) -def decode_png(input: torch.Tensor) -> torch.Tensor: +def decode_png(input: torch.Tensor, channels: int = 0) -> torch.Tensor: """ Decodes a PNG image into a 3 dimensional RGB Tensor. + Optionally converts the image to the desired number of color channels. The values of the output tensor are uint8 between 0 and 255. 
Arguments: input (Tensor[1]): a one dimensional uint8 tensor containing the raw bytes of the PNG image. + channels (int): the number of output channels for the decoded + image. 0 keeps the original number of channels, 1 converts to Grayscale, + 2 converts to Grayscale with Alpha, 3 converts to RGB and 4 converts to + RGB with Alpha. Default: 0 Returns: - output (Tensor[3, image_height, image_width]) + output (Tensor[image_channels, image_height, image_width]) """ - output = torch.ops.image.decode_png(input) + output = torch.ops.image.decode_png(input, channels) return output @@ -132,17 +137,23 @@ def write_png(input: torch.Tensor, filename: str, compression_level: int = 6): write_file(filename, output) -def decode_jpeg(input: torch.Tensor) -> torch.Tensor: +def decode_jpeg(input: torch.Tensor, channels: int = 0) -> torch.Tensor: """ Decodes a JPEG image into a 3 dimensional RGB Tensor. + Optionally converts the image to the desired number of color channels. The values of the output tensor are uint8 between 0 and 255. + Arguments: input (Tensor[1]): a one dimensional uint8 tensor containing the raw bytes of the JPEG image. + channels (int): the number of output channels for the decoded + image. 0 keeps the original number of channels, 1 converts to Grayscale + and 3 converts to RGB. Default: 0 + Returns: - output (Tensor[3, image_height, image_width]) + output (Tensor[image_channels, image_height, image_width]) """ - output = torch.ops.image.decode_jpeg(input) + output = torch.ops.image.decode_jpeg(input, channels) return output @@ -191,11 +202,12 @@ def write_jpeg(input: torch.Tensor, filename: str, quality: int = 75): write_file(filename, output) -def decode_image(input: torch.Tensor) -> torch.Tensor: +def decode_image(input: torch.Tensor, channels: int = 0) -> torch.Tensor: """ Detects whether an image is a JPEG or PNG and performs the appropriate operation to decode the image into a 3 dimensional RGB Tensor. 
+ Optionally converts the image to the desired number of color channels. The values of the output tensor are uint8 between 0 and 255. Parameters @@ -203,28 +215,39 @@ def decode_image(input: torch.Tensor) -> torch.Tensor: input: Tensor a one dimensional uint8 tensor containing the raw bytes of the PNG or JPEG image. + channels: int + the number of output channels of the decoded image. JPEG and PNG images + have different permitted values. The default value is 0 and it keeps + the original number of channels. See `decode_jpeg()` and `decode_png()` + for more information. Default: 0 Returns ------- - output: Tensor[3, image_height, image_width] + output: Tensor[image_channels, image_height, image_width] """ - output = torch.ops.image.decode_image(input) + output = torch.ops.image.decode_image(input, channels) return output -def read_image(path: str) -> torch.Tensor: +def read_image(path: str, channels: int = 0) -> torch.Tensor: """ Reads a JPEG or PNG image into a 3 dimensional RGB Tensor. + Optionally converts the image to the desired number of color channels. The values of the output tensor are uint8 between 0 and 255. Parameters ---------- path: str path of the JPEG or PNG image. + channels: int + the number of output channels of the decoded image. JPEG and PNG images + have different permitted values. The default value is 0 and it keeps + the original number of channels. See `decode_jpeg()` and `decode_png()` + for more information. Default: 0 Returns ------- - output: Tensor[3, image_height, image_width] + output: Tensor[image_channels, image_height, image_width] """ data = read_file(path) - return decode_image(data) + return decode_image(data, channels)