Skip to content

Support specifying output channels in io.image.read_image #2988

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 16 commits into from
Nov 18, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added test/assets/fakedata/logos/cmyk_pytorch.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added test/assets/fakedata/logos/gray_pytorch.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added test/assets/fakedata/logos/rgb_pytorch.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file removed test/assets/grace_hopper_517x606.pth
Binary file not shown.
6 changes: 4 additions & 2 deletions test/test_cpp_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,17 @@ def process_model(model, tensor, func, name):


def read_image1():
image_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'assets', 'grace_hopper_517x606.jpg')
image_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'assets', 'encode_jpeg',
'grace_hopper_517x606.jpg')
image = Image.open(image_path)
image = image.resize((224, 224))
x = F.to_tensor(image)
return x.view(1, 3, 224, 224)


def read_image2():
image_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'assets', 'grace_hopper_517x606.jpg')
image_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'assets', 'encode_jpeg',
'grace_hopper_517x606.jpg')
image = Image.open(image_path)
image = image.resize((299, 299))
x = F.to_tensor(image)
Expand Down
2 changes: 1 addition & 1 deletion test/test_datasets_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@


TEST_FILE = get_file_path_2(
os.path.dirname(os.path.abspath(__file__)), 'assets', 'grace_hopper_517x606.jpg')
os.path.dirname(os.path.abspath(__file__)), 'assets', 'encode_jpeg', 'grace_hopper_517x606.jpg')


class Tester(unittest.TestCase):
Expand Down
87 changes: 55 additions & 32 deletions test/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
FAKEDATA_DIR = os.path.join(IMAGE_ROOT, "fakedata")
IMAGE_DIR = os.path.join(FAKEDATA_DIR, "imagefolder")
DAMAGED_JPEG = os.path.join(IMAGE_ROOT, 'damaged_jpeg')
ENCODE_JPEG = os.path.join(IMAGE_ROOT, "encode_jpeg")


def get_images(directory, img_ext):
Expand All @@ -33,14 +34,44 @@ def get_images(directory, img_ext):
yield os.path.join(root, fl)


def pil_read_image(img_path):
with Image.open(img_path) as img:
return torch.from_numpy(np.array(img))


def normalize_dimensions(img_pil):
if len(img_pil.shape) == 3:
img_pil = img_pil.permute(2, 0, 1)
else:
img_pil = img_pil.unsqueeze(0)
return img_pil


class ImageTester(unittest.TestCase):
def test_decode_jpeg(self):
conversion = [(None, 0), ("L", 1), ("RGB", 3)]
for img_path in get_images(IMAGE_ROOT, ".jpg"):
img_pil = torch.load(img_path.replace('jpg', 'pth'))
img_pil = img_pil.permute(2, 0, 1)
data = read_file(img_path)
img_ljpeg = decode_jpeg(data)
self.assertTrue(img_ljpeg.equal(img_pil))
for pil_mode, channels in conversion:
with Image.open(img_path) as img:
is_cmyk = img.mode == "CMYK"
if pil_mode is not None:
if is_cmyk:
# libjpeg does not support the conversion
continue
img = img.convert(pil_mode)
img_pil = torch.from_numpy(np.array(img))
if is_cmyk:
# flip the colors to match libjpeg
img_pil = 255 - img_pil

img_pil = normalize_dimensions(img_pil)
data = read_file(img_path)
img_ljpeg = decode_image(data, channels=channels)

# Permit a small variation on pixel values to account for implementation
# differences between Pillow and LibJPEG.
abs_mean_diff = (img_ljpeg.type(torch.float32) - img_pil).abs().mean().item()
self.assertTrue(abs_mean_diff < 2)

with self.assertRaisesRegex(RuntimeError, "Expected a non empty 1-dimensional tensor"):
decode_jpeg(torch.empty((100, 1), dtype=torch.uint8))
Expand Down Expand Up @@ -68,7 +99,7 @@ def test_damaged_images(self):
decode_jpeg(data)

def test_encode_jpeg(self):
for img_path in get_images(IMAGE_ROOT, ".jpg"):
for img_path in get_images(ENCODE_JPEG, ".jpg"):
dirname = os.path.dirname(img_path)
filename, _ = os.path.splitext(os.path.basename(img_path))
write_folder = os.path.join(dirname, 'jpeg_write')
Expand Down Expand Up @@ -111,7 +142,7 @@ def test_encode_jpeg(self):
encode_jpeg(torch.empty((100, 100), dtype=torch.uint8))

def test_write_jpeg(self):
for img_path in get_images(IMAGE_ROOT, ".jpg"):
for img_path in get_images(ENCODE_JPEG, ".jpg"):
data = read_file(img_path)
img = decode_jpeg(data)

Expand All @@ -134,20 +165,25 @@ def test_write_jpeg(self):
self.assertEqual(torch_bytes, pil_bytes)

def test_decode_png(self):
conversion = [(None, 0), ("L", 1), ("LA", 2), ("RGB", 3), ("RGBA", 4)]
for img_path in get_images(FAKEDATA_DIR, ".png"):
img_pil = torch.from_numpy(np.array(Image.open(img_path)))
if len(img_pil.shape) == 3:
img_pil = img_pil.permute(2, 0, 1)
else:
img_pil = img_pil.unsqueeze(0)
data = read_file(img_path)
img_lpng = decode_png(data)
self.assertTrue(img_lpng.equal(img_pil))
for pil_mode, channels in conversion:
with Image.open(img_path) as img:
if pil_mode is not None:
img = img.convert(pil_mode)
img_pil = torch.from_numpy(np.array(img))

with self.assertRaises(RuntimeError):
decode_png(torch.empty((), dtype=torch.uint8))
with self.assertRaises(RuntimeError):
decode_png(torch.randint(3, 5, (300,), dtype=torch.uint8))
img_pil = normalize_dimensions(img_pil)
data = read_file(img_path)
img_lpng = decode_image(data, channels=channels)

tol = 0 if conversion is None else 1
self.assertTrue(img_lpng.allclose(img_pil, atol=tol))

with self.assertRaises(RuntimeError):
decode_png(torch.empty((), dtype=torch.uint8))
with self.assertRaises(RuntimeError):
decode_png(torch.randint(3, 5, (300,), dtype=torch.uint8))

def test_encode_png(self):
for img_path in get_images(IMAGE_DIR, '.png'):
Expand Down Expand Up @@ -196,19 +232,6 @@ def test_write_png(self):

self.assertTrue(img_pil.equal(saved_image))

def test_decode_image(self):
for img_path in get_images(IMAGE_ROOT, ".jpg"):
img_pil = torch.load(img_path.replace('jpg', 'pth'))
img_pil = img_pil.permute(2, 0, 1)
img_ljpeg = decode_image(read_file(img_path))
self.assertTrue(img_ljpeg.equal(img_pil))

for img_path in get_images(IMAGE_DIR, ".png"):
img_pil = torch.from_numpy(np.array(Image.open(img_path)))
img_pil = img_pil.permute(2, 0, 1)
img_lpng = decode_image(read_file(img_path))
self.assertTrue(img_lpng.equal(img_pil))

def test_read_file(self):
with get_tmp_dir() as d:
fname, content = 'test1.bin', b'TorchVision\211\n'
Expand Down
2 changes: 1 addition & 1 deletion test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@


GRACE_HOPPER = get_file_path_2(
os.path.dirname(os.path.abspath(__file__)), 'assets', 'grace_hopper_517x606.jpg')
os.path.dirname(os.path.abspath(__file__)), 'assets', 'encode_jpeg', 'grace_hopper_517x606.jpg')


class Tester(unittest.TestCase):
Expand Down
8 changes: 5 additions & 3 deletions torchvision/csrc/cpu/image/read_image_cpu.cpp
Original file line number Diff line number Diff line change
@@ -1,23 +1,25 @@
#include "read_image_cpu.h"
#include <cstring>

torch::Tensor decode_image(const torch::Tensor& data) {
torch::Tensor decode_image(const torch::Tensor& data, int64_t channels) {
// Check that the input tensor dtype is uint8
TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor");
// Check that the input tensor is 1-dimensional
TORCH_CHECK(
data.dim() == 1 && data.numel() > 0,
"Expected a non empty 1-dimensional tensor");
TORCH_CHECK(
channels >= 0 && channels <= 4, "Number of channels not supported");

auto datap = data.data_ptr<uint8_t>();

const uint8_t jpeg_signature[3] = {255, 216, 255}; // == "\xFF\xD8\xFF"
const uint8_t png_signature[4] = {137, 80, 78, 71}; // == "\211PNG"

if (memcmp(jpeg_signature, datap, 3) == 0) {
return decodeJPEG(data);
return decodeJPEG(data, channels);
} else if (memcmp(png_signature, datap, 4) == 0) {
return decodePNG(data);
return decodePNG(data, channels);
} else {
TORCH_CHECK(
false,
Expand Down
4 changes: 3 additions & 1 deletion torchvision/csrc/cpu/image/read_image_cpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,6 @@
#include "readjpeg_cpu.h"
#include "readpng_cpu.h"

C10_EXPORT torch::Tensor decode_image(const torch::Tensor& data);
C10_EXPORT torch::Tensor decode_image(
const torch::Tensor& data,
int64_t channels = 0);
45 changes: 36 additions & 9 deletions torchvision/csrc/cpu/image/readjpeg_cpu.cpp
Original file line number Diff line number Diff line change
@@ -1,18 +1,16 @@
#include "readjpeg_cpu.h"

#include <ATen/ATen.h>
#include <setjmp.h>
#include <string>

#if !JPEG_FOUND

torch::Tensor decodeJPEG(const torch::Tensor& data) {
torch::Tensor decodeJPEG(const torch::Tensor& data, int64_t channels) {
TORCH_CHECK(
false, "decodeJPEG: torchvision not compiled with libjpeg support");
}

#else
#include <jpeglib.h>
#include <setjmp.h>
#include "jpegcommon.h"

struct torch_jpeg_mgr {
Expand Down Expand Up @@ -71,13 +69,16 @@ static void torch_jpeg_set_source_mgr(
src->pub.next_input_byte = src->data;
}

torch::Tensor decodeJPEG(const torch::Tensor& data) {
torch::Tensor decodeJPEG(const torch::Tensor& data, int64_t channels) {
// Check that the input tensor dtype is uint8
TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor");
// Check that the input tensor is 1-dimensional
TORCH_CHECK(
data.dim() == 1 && data.numel() > 0,
"Expected a non empty 1-dimensional tensor");
TORCH_CHECK(
channels == 0 || channels == 1 || channels == 3,
"Number of channels not supported");

struct jpeg_decompress_struct cinfo;
struct torch_jpeg_error_mgr jerr;
Expand All @@ -100,15 +101,41 @@ torch::Tensor decodeJPEG(const torch::Tensor& data) {

// read info from header.
jpeg_read_header(&cinfo, TRUE);

int current_channels = cinfo.num_components;

if (channels > 0 && channels != current_channels) {
switch (channels) {
case 1: // Gray
cinfo.out_color_space = JCS_GRAYSCALE;
break;
case 3: // RGB
cinfo.out_color_space = JCS_RGB;
break;
/*
* Libjpeg does not support converting from CMYK to grayscale etc. There
* is a way to do this but it involves converting it manually to RGB:
* https://github.com/tensorflow/tensorflow/blob/86871065265b04e0db8ca360c046421efb2bdeb4/tensorflow/core/lib/jpeg/jpeg_mem.cc#L284-L313
*
*/
default:
jpeg_destroy_decompress(&cinfo);
TORCH_CHECK(false, "Invalid number of output channels.");
}

jpeg_calc_output_dimensions(&cinfo);
} else {
channels = current_channels;
}

jpeg_start_decompress(&cinfo);

int height = cinfo.output_height;
int width = cinfo.output_width;
int components = cinfo.output_components;

auto stride = width * components;
auto tensor = torch::empty(
{int64_t(height), int64_t(width), int64_t(components)}, torch::kU8);
int stride = width * channels;
auto tensor =
torch::empty({int64_t(height), int64_t(width), channels}, torch::kU8);
auto ptr = tensor.data_ptr<uint8_t>();
while (cinfo.output_scanline < cinfo.output_height) {
/* jpeg_read_scanlines expects an array of pointers to scanlines.
Expand Down
4 changes: 3 additions & 1 deletion torchvision/csrc/cpu/image/readjpeg_cpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,6 @@

#include <torch/torch.h>

C10_EXPORT torch::Tensor decodeJPEG(const torch::Tensor& data);
C10_EXPORT torch::Tensor decodeJPEG(
const torch::Tensor& data,
int64_t channels = 0);
Loading