Skip to content

Commit b67b3d5

Browse files
datumboxvfdev-5
authored andcommitted
Support specifying output channels in io.image.read_image (pytorch#2988)
* Adding output channels implementation for pngs. * Adding tests for png. * Adding channels in the API and documentation. * Fixing formatting. * Refactoring test_image.py to remove huge grace_hopper_517x606.pth file from assets and reduce duplicate code. Moving jpeg assets used by encode and write unit-tests on their separate folders. * Adding output channels implementation for jpegs. Fix asset locations. * Add tests for JPEG, adding the channels in the API and documentation and adding checks for inputs. * Changing folder for unit-test. * Fixing windows flakiness, removing duplicate test. * Replacing components to channels. * Adding reference for supporting CMYK. * Minor changes: num_components to output_components, adding comments, fixing variable name etc. * Reverting output_components to num_components. * Replacing decoding with generic method on tests. * Palette converted to Gray.
1 parent cfd15fe commit b67b3d5

17 files changed

+223
-88
lines changed
3.45 KB
Loading
1.4 KB
Loading
2.08 KB
Loading

test/assets/grace_hopper_517x606.pth

-919 KB
Binary file not shown.

test/test_cpp_models.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,15 +25,17 @@ def process_model(model, tensor, func, name):
2525

2626

2727
def read_image1():
28-
image_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'assets', 'grace_hopper_517x606.jpg')
28+
image_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'assets', 'encode_jpeg',
29+
'grace_hopper_517x606.jpg')
2930
image = Image.open(image_path)
3031
image = image.resize((224, 224))
3132
x = F.to_tensor(image)
3233
return x.view(1, 3, 224, 224)
3334

3435

3536
def read_image2():
36-
image_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'assets', 'grace_hopper_517x606.jpg')
37+
image_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'assets', 'encode_jpeg',
38+
'grace_hopper_517x606.jpg')
3739
image = Image.open(image_path)
3840
image = image.resize((299, 299))
3941
x = F.to_tensor(image)

test/test_datasets_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515

1616
TEST_FILE = get_file_path_2(
17-
os.path.dirname(os.path.abspath(__file__)), 'assets', 'grace_hopper_517x606.jpg')
17+
os.path.dirname(os.path.abspath(__file__)), 'assets', 'encode_jpeg', 'grace_hopper_517x606.jpg')
1818

1919

2020
class Tester(unittest.TestCase):

test/test_image.py

Lines changed: 55 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
FAKEDATA_DIR = os.path.join(IMAGE_ROOT, "fakedata")
2020
IMAGE_DIR = os.path.join(FAKEDATA_DIR, "imagefolder")
2121
DAMAGED_JPEG = os.path.join(IMAGE_ROOT, 'damaged_jpeg')
22+
ENCODE_JPEG = os.path.join(IMAGE_ROOT, "encode_jpeg")
2223

2324

2425
def get_images(directory, img_ext):
@@ -33,14 +34,44 @@ def get_images(directory, img_ext):
3334
yield os.path.join(root, fl)
3435

3536

37+
def pil_read_image(img_path):
38+
with Image.open(img_path) as img:
39+
return torch.from_numpy(np.array(img))
40+
41+
42+
def normalize_dimensions(img_pil):
43+
if len(img_pil.shape) == 3:
44+
img_pil = img_pil.permute(2, 0, 1)
45+
else:
46+
img_pil = img_pil.unsqueeze(0)
47+
return img_pil
48+
49+
3650
class ImageTester(unittest.TestCase):
3751
def test_decode_jpeg(self):
52+
conversion = [(None, 0), ("L", 1), ("RGB", 3)]
3853
for img_path in get_images(IMAGE_ROOT, ".jpg"):
39-
img_pil = torch.load(img_path.replace('jpg', 'pth'))
40-
img_pil = img_pil.permute(2, 0, 1)
41-
data = read_file(img_path)
42-
img_ljpeg = decode_jpeg(data)
43-
self.assertTrue(img_ljpeg.equal(img_pil))
54+
for pil_mode, channels in conversion:
55+
with Image.open(img_path) as img:
56+
is_cmyk = img.mode == "CMYK"
57+
if pil_mode is not None:
58+
if is_cmyk:
59+
# libjpeg does not support the conversion
60+
continue
61+
img = img.convert(pil_mode)
62+
img_pil = torch.from_numpy(np.array(img))
63+
if is_cmyk:
64+
# flip the colors to match libjpeg
65+
img_pil = 255 - img_pil
66+
67+
img_pil = normalize_dimensions(img_pil)
68+
data = read_file(img_path)
69+
img_ljpeg = decode_image(data, channels=channels)
70+
71+
# Permit a small variation on pixel values to account for implementation
72+
# differences between Pillow and LibJPEG.
73+
abs_mean_diff = (img_ljpeg.type(torch.float32) - img_pil).abs().mean().item()
74+
self.assertTrue(abs_mean_diff < 2)
4475

4576
with self.assertRaisesRegex(RuntimeError, "Expected a non empty 1-dimensional tensor"):
4677
decode_jpeg(torch.empty((100, 1), dtype=torch.uint8))
@@ -68,7 +99,7 @@ def test_damaged_images(self):
6899
decode_jpeg(data)
69100

70101
def test_encode_jpeg(self):
71-
for img_path in get_images(IMAGE_ROOT, ".jpg"):
102+
for img_path in get_images(ENCODE_JPEG, ".jpg"):
72103
dirname = os.path.dirname(img_path)
73104
filename, _ = os.path.splitext(os.path.basename(img_path))
74105
write_folder = os.path.join(dirname, 'jpeg_write')
@@ -111,7 +142,7 @@ def test_encode_jpeg(self):
111142
encode_jpeg(torch.empty((100, 100), dtype=torch.uint8))
112143

113144
def test_write_jpeg(self):
114-
for img_path in get_images(IMAGE_ROOT, ".jpg"):
145+
for img_path in get_images(ENCODE_JPEG, ".jpg"):
115146
data = read_file(img_path)
116147
img = decode_jpeg(data)
117148

@@ -134,20 +165,25 @@ def test_write_jpeg(self):
134165
self.assertEqual(torch_bytes, pil_bytes)
135166

136167
def test_decode_png(self):
168+
conversion = [(None, 0), ("L", 1), ("LA", 2), ("RGB", 3), ("RGBA", 4)]
137169
for img_path in get_images(FAKEDATA_DIR, ".png"):
138-
img_pil = torch.from_numpy(np.array(Image.open(img_path)))
139-
if len(img_pil.shape) == 3:
140-
img_pil = img_pil.permute(2, 0, 1)
141-
else:
142-
img_pil = img_pil.unsqueeze(0)
143-
data = read_file(img_path)
144-
img_lpng = decode_png(data)
145-
self.assertTrue(img_lpng.equal(img_pil))
170+
for pil_mode, channels in conversion:
171+
with Image.open(img_path) as img:
172+
if pil_mode is not None:
173+
img = img.convert(pil_mode)
174+
img_pil = torch.from_numpy(np.array(img))
146175

147-
with self.assertRaises(RuntimeError):
148-
decode_png(torch.empty((), dtype=torch.uint8))
149-
with self.assertRaises(RuntimeError):
150-
decode_png(torch.randint(3, 5, (300,), dtype=torch.uint8))
176+
img_pil = normalize_dimensions(img_pil)
177+
data = read_file(img_path)
178+
img_lpng = decode_image(data, channels=channels)
179+
180+
tol = 0 if conversion is None else 1
181+
self.assertTrue(img_lpng.allclose(img_pil, atol=tol))
182+
183+
with self.assertRaises(RuntimeError):
184+
decode_png(torch.empty((), dtype=torch.uint8))
185+
with self.assertRaises(RuntimeError):
186+
decode_png(torch.randint(3, 5, (300,), dtype=torch.uint8))
151187

152188
def test_encode_png(self):
153189
for img_path in get_images(IMAGE_DIR, '.png'):
@@ -196,19 +232,6 @@ def test_write_png(self):
196232

197233
self.assertTrue(img_pil.equal(saved_image))
198234

199-
def test_decode_image(self):
200-
for img_path in get_images(IMAGE_ROOT, ".jpg"):
201-
img_pil = torch.load(img_path.replace('jpg', 'pth'))
202-
img_pil = img_pil.permute(2, 0, 1)
203-
img_ljpeg = decode_image(read_file(img_path))
204-
self.assertTrue(img_ljpeg.equal(img_pil))
205-
206-
for img_path in get_images(IMAGE_DIR, ".png"):
207-
img_pil = torch.from_numpy(np.array(Image.open(img_path)))
208-
img_pil = img_pil.permute(2, 0, 1)
209-
img_lpng = decode_image(read_file(img_path))
210-
self.assertTrue(img_lpng.equal(img_pil))
211-
212235
def test_read_file(self):
213236
with get_tmp_dir() as d:
214237
fname, content = 'test1.bin', b'TorchVision\211\n'

test/test_transforms.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525

2626
GRACE_HOPPER = get_file_path_2(
27-
os.path.dirname(os.path.abspath(__file__)), 'assets', 'grace_hopper_517x606.jpg')
27+
os.path.dirname(os.path.abspath(__file__)), 'assets', 'encode_jpeg', 'grace_hopper_517x606.jpg')
2828

2929

3030
class Tester(unittest.TestCase):

torchvision/csrc/cpu/image/read_image_cpu.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,25 @@
11
#include "read_image_cpu.h"
22
#include <cstring>
33

4-
torch::Tensor decode_image(const torch::Tensor& data) {
4+
torch::Tensor decode_image(const torch::Tensor& data, int64_t channels) {
55
// Check that the input tensor dtype is uint8
66
TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor");
77
// Check that the input tensor is 1-dimensional
88
TORCH_CHECK(
99
data.dim() == 1 && data.numel() > 0,
1010
"Expected a non empty 1-dimensional tensor");
11+
TORCH_CHECK(
12+
channels >= 0 && channels <= 4, "Number of channels not supported");
1113

1214
auto datap = data.data_ptr<uint8_t>();
1315

1416
const uint8_t jpeg_signature[3] = {255, 216, 255}; // == "\xFF\xD8\xFF"
1517
const uint8_t png_signature[4] = {137, 80, 78, 71}; // == "\211PNG"
1618

1719
if (memcmp(jpeg_signature, datap, 3) == 0) {
18-
return decodeJPEG(data);
20+
return decodeJPEG(data, channels);
1921
} else if (memcmp(png_signature, datap, 4) == 0) {
20-
return decodePNG(data);
22+
return decodePNG(data, channels);
2123
} else {
2224
TORCH_CHECK(
2325
false,

torchvision/csrc/cpu/image/read_image_cpu.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,6 @@
33
#include "readjpeg_cpu.h"
44
#include "readpng_cpu.h"
55

6-
C10_EXPORT torch::Tensor decode_image(const torch::Tensor& data);
6+
C10_EXPORT torch::Tensor decode_image(
7+
const torch::Tensor& data,
8+
int64_t channels = 0);

torchvision/csrc/cpu/image/readjpeg_cpu.cpp

Lines changed: 36 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,16 @@
11
#include "readjpeg_cpu.h"
22

33
#include <ATen/ATen.h>
4-
#include <setjmp.h>
54
#include <string>
65

76
#if !JPEG_FOUND
8-
9-
torch::Tensor decodeJPEG(const torch::Tensor& data) {
7+
torch::Tensor decodeJPEG(const torch::Tensor& data, int64_t channels) {
108
TORCH_CHECK(
119
false, "decodeJPEG: torchvision not compiled with libjpeg support");
1210
}
13-
1411
#else
1512
#include <jpeglib.h>
13+
#include <setjmp.h>
1614
#include "jpegcommon.h"
1715

1816
struct torch_jpeg_mgr {
@@ -71,13 +69,16 @@ static void torch_jpeg_set_source_mgr(
7169
src->pub.next_input_byte = src->data;
7270
}
7371

74-
torch::Tensor decodeJPEG(const torch::Tensor& data) {
72+
torch::Tensor decodeJPEG(const torch::Tensor& data, int64_t channels) {
7573
// Check that the input tensor dtype is uint8
7674
TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor");
7775
// Check that the input tensor is 1-dimensional
7876
TORCH_CHECK(
7977
data.dim() == 1 && data.numel() > 0,
8078
"Expected a non empty 1-dimensional tensor");
79+
TORCH_CHECK(
80+
channels == 0 || channels == 1 || channels == 3,
81+
"Number of channels not supported");
8182

8283
struct jpeg_decompress_struct cinfo;
8384
struct torch_jpeg_error_mgr jerr;
@@ -100,15 +101,41 @@ torch::Tensor decodeJPEG(const torch::Tensor& data) {
100101

101102
// read info from header.
102103
jpeg_read_header(&cinfo, TRUE);
104+
105+
int current_channels = cinfo.num_components;
106+
107+
if (channels > 0 && channels != current_channels) {
108+
switch (channels) {
109+
case 1: // Gray
110+
cinfo.out_color_space = JCS_GRAYSCALE;
111+
break;
112+
case 3: // RGB
113+
cinfo.out_color_space = JCS_RGB;
114+
break;
115+
/*
116+
* Libjpeg does not support converting from CMYK to grayscale etc. There
117+
* is a way to do this but it involves converting it manually to RGB:
118+
* https://github.com/tensorflow/tensorflow/blob/86871065265b04e0db8ca360c046421efb2bdeb4/tensorflow/core/lib/jpeg/jpeg_mem.cc#L284-L313
119+
*
120+
*/
121+
default:
122+
jpeg_destroy_decompress(&cinfo);
123+
TORCH_CHECK(false, "Invalid number of output channels.");
124+
}
125+
126+
jpeg_calc_output_dimensions(&cinfo);
127+
} else {
128+
channels = current_channels;
129+
}
130+
103131
jpeg_start_decompress(&cinfo);
104132

105133
int height = cinfo.output_height;
106134
int width = cinfo.output_width;
107-
int components = cinfo.output_components;
108135

109-
auto stride = width * components;
110-
auto tensor = torch::empty(
111-
{int64_t(height), int64_t(width), int64_t(components)}, torch::kU8);
136+
int stride = width * channels;
137+
auto tensor =
138+
torch::empty({int64_t(height), int64_t(width), channels}, torch::kU8);
112139
auto ptr = tensor.data_ptr<uint8_t>();
113140
while (cinfo.output_scanline < cinfo.output_height) {
114141
/* jpeg_read_scanlines expects an array of pointers to scanlines.

torchvision/csrc/cpu/image/readjpeg_cpu.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,6 @@
22

33
#include <torch/torch.h>
44

5-
C10_EXPORT torch::Tensor decodeJPEG(const torch::Tensor& data);
5+
C10_EXPORT torch::Tensor decodeJPEG(
6+
const torch::Tensor& data,
7+
int64_t channels = 0);

0 commit comments

Comments
 (0)