Skip to content

Commit a884cb7

Browse files
authored
Add support of mode and remove channels (#3024)
* Add support of mode and remove channels. * Replacing integer mode with define constants.
1 parent 1706921 commit a884cb7

File tree

9 files changed

+151
-124
lines changed

9 files changed

+151
-124
lines changed

test/test_image.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,12 @@
22
import io
33
import glob
44
import unittest
5-
import sys
65

76
import torch
8-
import torchvision
97
from PIL import Image
108
from torchvision.io.image import (
119
decode_png, decode_jpeg, encode_jpeg, write_jpeg, decode_image, read_file,
12-
encode_png, write_png, write_file)
10+
encode_png, write_png, write_file, ImageReadMode)
1311
import numpy as np
1412

1513
from common_utils import get_tmp_dir
@@ -49,9 +47,9 @@ def normalize_dimensions(img_pil):
4947

5048
class ImageTester(unittest.TestCase):
5149
def test_decode_jpeg(self):
52-
conversion = [(None, 0), ("L", 1), ("RGB", 3)]
50+
conversion = [(None, ImageReadMode.UNCHANGED), ("L", ImageReadMode.GRAY), ("RGB", ImageReadMode.RGB)]
5351
for img_path in get_images(IMAGE_ROOT, ".jpg"):
54-
for pil_mode, channels in conversion:
52+
for pil_mode, mode in conversion:
5553
with Image.open(img_path) as img:
5654
is_cmyk = img.mode == "CMYK"
5755
if pil_mode is not None:
@@ -66,7 +64,7 @@ def test_decode_jpeg(self):
6664

6765
img_pil = normalize_dimensions(img_pil)
6866
data = read_file(img_path)
69-
img_ljpeg = decode_image(data, channels=channels)
67+
img_ljpeg = decode_image(data, mode=mode)
7068

7169
# Permit a small variation on pixel values to account for implementation
7270
# differences between Pillow and LibJPEG.
@@ -165,17 +163,18 @@ def test_write_jpeg(self):
165163
self.assertEqual(torch_bytes, pil_bytes)
166164

167165
def test_decode_png(self):
168-
conversion = [(None, 0), ("L", 1), ("LA", 2), ("RGB", 3), ("RGBA", 4)]
166+
conversion = [(None, ImageReadMode.UNCHANGED), ("L", ImageReadMode.GRAY), ("LA", ImageReadMode.GRAY_ALPHA),
167+
("RGB", ImageReadMode.RGB), ("RGBA", ImageReadMode.RGB_ALPHA)]
169168
for img_path in get_images(FAKEDATA_DIR, ".png"):
170-
for pil_mode, channels in conversion:
169+
for pil_mode, mode in conversion:
171170
with Image.open(img_path) as img:
172171
if pil_mode is not None:
173172
img = img.convert(pil_mode)
174173
img_pil = torch.from_numpy(np.array(img))
175174

176175
img_pil = normalize_dimensions(img_pil)
177176
data = read_file(img_path)
178-
img_lpng = decode_image(data, channels=channels)
177+
img_lpng = decode_image(data, mode=mode)
179178

180179
tol = 0 if conversion is None else 1
181180
self.assertTrue(img_lpng.allclose(img_pil, atol=tol))
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
#pragma once
2+
3+
/* Should be kept in-sync with Python ImageReadMode enum */
4+
using ImageReadMode = int64_t;
5+
#define IMAGE_READ_MODE_UNCHANGED 0
6+
#define IMAGE_READ_MODE_GRAY 1
7+
#define IMAGE_READ_MODE_GRAY_ALPHA 2
8+
#define IMAGE_READ_MODE_RGB 3
9+
#define IMAGE_READ_MODE_RGB_ALPHA 4

torchvision/csrc/cpu/image/read_image_cpu.cpp

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,24 @@
11
#include "read_image_cpu.h"
2-
#include <cstring>
2+
#include "readjpeg_cpu.h"
3+
#include "readpng_cpu.h"
34

4-
torch::Tensor decode_image(const torch::Tensor& data, int64_t channels) {
5+
torch::Tensor decode_image(const torch::Tensor& data, ImageReadMode mode) {
56
// Check that the input tensor dtype is uint8
67
TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor");
78
// Check that the input tensor is 1-dimensional
89
TORCH_CHECK(
910
data.dim() == 1 && data.numel() > 0,
1011
"Expected a non empty 1-dimensional tensor");
11-
TORCH_CHECK(
12-
channels >= 0 && channels <= 4, "Number of channels not supported");
1312

1413
auto datap = data.data_ptr<uint8_t>();
1514

1615
const uint8_t jpeg_signature[3] = {255, 216, 255}; // == "\xFF\xD8\xFF"
1716
const uint8_t png_signature[4] = {137, 80, 78, 71}; // == "\211PNG"
1817

1918
if (memcmp(jpeg_signature, datap, 3) == 0) {
20-
return decodeJPEG(data, channels);
19+
return decodeJPEG(data, mode);
2120
} else if (memcmp(png_signature, datap, 4) == 0) {
22-
return decodePNG(data, channels);
21+
return decodePNG(data, mode);
2322
} else {
2423
TORCH_CHECK(
2524
false,
Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
#pragma once
22

3-
#include "readjpeg_cpu.h"
4-
#include "readpng_cpu.h"
3+
#include <torch/torch.h>
4+
#include "image_read_mode.h"
55

66
C10_EXPORT torch::Tensor decode_image(
77
const torch::Tensor& data,
8-
int64_t channels = 0);
8+
ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED);

torchvision/csrc/cpu/image/readjpeg_cpu.cpp

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
#include "readjpeg_cpu.h"
22

33
#include <ATen/ATen.h>
4-
#include <string>
54

65
#if !JPEG_FOUND
7-
torch::Tensor decodeJPEG(const torch::Tensor& data, int64_t channels) {
6+
torch::Tensor decodeJPEG(const torch::Tensor& data, ImageReadMode mode) {
87
TORCH_CHECK(
98
false, "decodeJPEG: torchvision not compiled with libjpeg support");
109
}
@@ -69,16 +68,13 @@ static void torch_jpeg_set_source_mgr(
6968
src->pub.next_input_byte = src->data;
7069
}
7170

72-
torch::Tensor decodeJPEG(const torch::Tensor& data, int64_t channels) {
71+
torch::Tensor decodeJPEG(const torch::Tensor& data, ImageReadMode mode) {
7372
// Check that the input tensor dtype is uint8
7473
TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor");
7574
// Check that the input tensor is 1-dimensional
7675
TORCH_CHECK(
7776
data.dim() == 1 && data.numel() > 0,
7877
"Expected a non empty 1-dimensional tensor");
79-
TORCH_CHECK(
80-
channels == 0 || channels == 1 || channels == 3,
81-
"Number of channels not supported");
8278

8379
struct jpeg_decompress_struct cinfo;
8480
struct torch_jpeg_error_mgr jerr;
@@ -102,30 +98,33 @@ torch::Tensor decodeJPEG(const torch::Tensor& data, int64_t channels) {
10298
// read info from header.
10399
jpeg_read_header(&cinfo, TRUE);
104100

105-
int current_channels = cinfo.num_components;
101+
int channels = cinfo.num_components;
106102

107-
if (channels > 0 && channels != current_channels) {
108-
switch (channels) {
109-
case 1: // Gray
110-
cinfo.out_color_space = JCS_GRAYSCALE;
103+
if (mode != IMAGE_READ_MODE_UNCHANGED) {
104+
switch (mode) {
105+
case IMAGE_READ_MODE_GRAY:
106+
if (cinfo.jpeg_color_space != JCS_GRAYSCALE) {
107+
cinfo.out_color_space = JCS_GRAYSCALE;
108+
channels = 1;
109+
}
111110
break;
112-
case 3: // RGB
113-
cinfo.out_color_space = JCS_RGB;
111+
case IMAGE_READ_MODE_RGB:
112+
if (cinfo.jpeg_color_space != JCS_RGB) {
113+
cinfo.out_color_space = JCS_RGB;
114+
channels = 3;
115+
}
114116
break;
115117
/*
116118
* Libjpeg does not support converting from CMYK to grayscale etc. There
117119
* is a way to do this but it involves converting it manually to RGB:
118120
* https://github.com/tensorflow/tensorflow/blob/86871065265b04e0db8ca360c046421efb2bdeb4/tensorflow/core/lib/jpeg/jpeg_mem.cc#L284-L313
119-
*
120121
*/
121122
default:
122123
jpeg_destroy_decompress(&cinfo);
123-
TORCH_CHECK(false, "Invalid number of output channels.");
124+
TORCH_CHECK(false, "Provided mode not supported");
124125
}
125126

126127
jpeg_calc_output_dimensions(&cinfo);
127-
} else {
128-
channels = current_channels;
129128
}
130129

131130
jpeg_start_decompress(&cinfo);
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
#pragma once
22

33
#include <torch/torch.h>
4+
#include "image_read_mode.h"
45

56
C10_EXPORT torch::Tensor decodeJPEG(
67
const torch::Tensor& data,
7-
int64_t channels = 0);
8+
ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED);

torchvision/csrc/cpu/image/readpng_cpu.cpp

Lines changed: 62 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,22 @@
11
#include "readpng_cpu.h"
22

3-
// Comment
43
#include <ATen/ATen.h>
5-
#include <string>
64

75
#if !PNG_FOUND
8-
torch::Tensor decodePNG(const torch::Tensor& data, int64_t channels) {
6+
torch::Tensor decodePNG(const torch::Tensor& data, ImageReadMode mode) {
97
TORCH_CHECK(false, "decodePNG: torchvision not compiled with libPNG support");
108
}
119
#else
1210
#include <png.h>
1311
#include <setjmp.h>
1412

15-
torch::Tensor decodePNG(const torch::Tensor& data, int64_t channels) {
13+
torch::Tensor decodePNG(const torch::Tensor& data, ImageReadMode mode) {
1614
// Check that the input tensor dtype is uint8
1715
TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor");
1816
// Check that the input tensor is 1-dimensional
1917
TORCH_CHECK(
2018
data.dim() == 1 && data.numel() > 0,
2119
"Expected a non empty 1-dimensional tensor");
22-
TORCH_CHECK(
23-
channels >= 0 && channels <= 4, "Number of channels not supported");
2420

2521
auto png_ptr =
2622
png_create_read_struct(PNG_LIBPNG_VER_STRING, nullptr, nullptr, nullptr);
@@ -74,75 +70,85 @@ torch::Tensor decodePNG(const torch::Tensor& data, int64_t channels) {
7470
TORCH_CHECK(retval == 1, "Could read image metadata from content.")
7571
}
7672

77-
int current_channels = png_get_channels(png_ptr, info_ptr);
73+
int channels = png_get_channels(png_ptr, info_ptr);
7874

79-
if (channels > 0) {
75+
if (mode != IMAGE_READ_MODE_UNCHANGED) {
8076
// TODO: consider supporting PNG_INFO_tRNS
8177
bool is_palette = (color_type & PNG_COLOR_MASK_PALETTE) != 0;
8278
bool has_color = (color_type & PNG_COLOR_MASK_COLOR) != 0;
8379
bool has_alpha = (color_type & PNG_COLOR_MASK_ALPHA) != 0;
8480

85-
switch (channels) {
86-
case 1: // Gray
87-
if (is_palette) {
88-
png_set_palette_to_rgb(png_ptr);
89-
has_alpha = true;
90-
}
91-
92-
if (has_alpha) {
93-
png_set_strip_alpha(png_ptr);
94-
}
95-
96-
if (has_color) {
97-
png_set_rgb_to_gray(png_ptr, 1, 0.2989, 0.587);
81+
switch (mode) {
82+
case IMAGE_READ_MODE_GRAY:
83+
if (color_type != PNG_COLOR_TYPE_GRAY) {
84+
if (is_palette) {
85+
png_set_palette_to_rgb(png_ptr);
86+
has_alpha = true;
87+
}
88+
89+
if (has_alpha) {
90+
png_set_strip_alpha(png_ptr);
91+
}
92+
93+
if (has_color) {
94+
png_set_rgb_to_gray(png_ptr, 1, 0.2989, 0.587);
95+
}
96+
channels = 1;
9897
}
9998
break;
100-
case 2: // Gray + Alpha
101-
if (is_palette) {
102-
png_set_palette_to_rgb(png_ptr);
103-
has_alpha = true;
104-
}
105-
106-
if (!has_alpha) {
107-
png_set_add_alpha(png_ptr, (1 << bit_depth) - 1, PNG_FILLER_AFTER);
108-
}
109-
110-
if (has_color) {
111-
png_set_rgb_to_gray(png_ptr, 1, 0.2989, 0.587);
99+
case IMAGE_READ_MODE_GRAY_ALPHA:
100+
if (color_type != PNG_COLOR_TYPE_GRAY_ALPHA) {
101+
if (is_palette) {
102+
png_set_palette_to_rgb(png_ptr);
103+
has_alpha = true;
104+
}
105+
106+
if (!has_alpha) {
107+
png_set_add_alpha(png_ptr, (1 << bit_depth) - 1, PNG_FILLER_AFTER);
108+
}
109+
110+
if (has_color) {
111+
png_set_rgb_to_gray(png_ptr, 1, 0.2989, 0.587);
112+
}
113+
channels = 2;
112114
}
113115
break;
114-
case 3:
115-
if (is_palette) {
116-
png_set_palette_to_rgb(png_ptr);
117-
has_alpha = true;
118-
} else if (!has_color) {
119-
png_set_gray_to_rgb(png_ptr);
120-
}
121-
122-
if (has_alpha) {
123-
png_set_strip_alpha(png_ptr);
116+
case IMAGE_READ_MODE_RGB:
117+
if (color_type != PNG_COLOR_TYPE_RGB) {
118+
if (is_palette) {
119+
png_set_palette_to_rgb(png_ptr);
120+
has_alpha = true;
121+
} else if (!has_color) {
122+
png_set_gray_to_rgb(png_ptr);
123+
}
124+
125+
if (has_alpha) {
126+
png_set_strip_alpha(png_ptr);
127+
}
128+
channels = 3;
124129
}
125130
break;
126-
case 4:
127-
if (is_palette) {
128-
png_set_palette_to_rgb(png_ptr);
129-
has_alpha = true;
130-
} else if (!has_color) {
131-
png_set_gray_to_rgb(png_ptr);
132-
}
133-
134-
if (!has_alpha) {
135-
png_set_add_alpha(png_ptr, (1 << bit_depth) - 1, PNG_FILLER_AFTER);
131+
case IMAGE_READ_MODE_RGB_ALPHA:
132+
if (color_type != PNG_COLOR_TYPE_RGB_ALPHA) {
133+
if (is_palette) {
134+
png_set_palette_to_rgb(png_ptr);
135+
has_alpha = true;
136+
} else if (!has_color) {
137+
png_set_gray_to_rgb(png_ptr);
138+
}
139+
140+
if (!has_alpha) {
141+
png_set_add_alpha(png_ptr, (1 << bit_depth) - 1, PNG_FILLER_AFTER);
142+
}
143+
channels = 4;
136144
}
137145
break;
138146
default:
139147
png_destroy_read_struct(&png_ptr, &info_ptr, nullptr);
140-
TORCH_CHECK(false, "Invalid number of output channels.");
148+
TORCH_CHECK(false, "Provided mode not supported");
141149
}
142150

143151
png_read_update_info(png_ptr, info_ptr);
144-
} else {
145-
channels = current_channels;
146152
}
147153

148154
auto tensor =
Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
#pragma once
22

3-
// Comment
43
#include <torch/torch.h>
5-
#include <string>
4+
#include "image_read_mode.h"
65

76
C10_EXPORT torch::Tensor decodePNG(
87
const torch::Tensor& data,
9-
int64_t channels = 0);
8+
ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED);

0 commit comments

Comments
 (0)