Skip to content

Commit e1f22ed

Browse files
vmoensVincent Moens
and
Vincent Moens
authored
deinterlacing PNG images with read_image (#4268)
* interlaced png images Co-authored-by: Vincent Moens <[email protected]>
1 parent 7de6265 commit e1f22ed

File tree

4 files changed

+25
-4
lines changed

4 files changed

+25
-4
lines changed
Loading
168 KB
Loading

test/test_image.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
IMAGE_DIR = os.path.join(FAKEDATA_DIR, "imagefolder")
2121
DAMAGED_JPEG = os.path.join(IMAGE_ROOT, 'damaged_jpeg')
2222
ENCODE_JPEG = os.path.join(IMAGE_ROOT, "encode_jpeg")
23+
INTERLACED_PNG = os.path.join(IMAGE_ROOT, "interlaced_png")
2324
IS_WINDOWS = sys.platform in ('win32', 'cygwin')
2425
PILLOW_VERSION = tuple(int(x) for x in PILLOW_VERSION.split('.'))
2526

@@ -304,6 +305,15 @@ def test_read_1_bit_png_consistency(shape, mode):
304305
assert_equal(img1, img2)
305306

306307

308+
def test_read_interlaced_png():
309+
imgs = list(get_images(INTERLACED_PNG, ".png"))
310+
with Image.open(imgs[0]) as im1, Image.open(imgs[1]) as im2:
311+
assert not (im1.info.get("interlace") is im2.info.get("interlace"))
312+
img1 = read_image(imgs[0])
313+
img2 = read_image(imgs[1])
314+
assert_equal(img1, img2)
315+
316+
307317
@needs_cuda
308318
@pytest.mark.parametrize('img_path', [
309319
pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path))

torchvision/csrc/io/image/cpu/decode_png.cpp

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,14 +55,15 @@ torch::Tensor decode_png(const torch::Tensor& data, ImageReadMode mode) {
5555

5656
png_uint_32 width, height;
5757
int bit_depth, color_type;
58+
int interlace_type;
5859
auto retval = png_get_IHDR(
5960
png_ptr,
6061
info_ptr,
6162
&width,
6263
&height,
6364
&bit_depth,
6465
&color_type,
65-
nullptr,
66+
&interlace_type,
6667
nullptr,
6768
nullptr);
6869

@@ -81,6 +82,13 @@ torch::Tensor decode_png(const torch::Tensor& data, ImageReadMode mode) {
8182
if (color_type == PNG_COLOR_TYPE_GRAY && bit_depth < 8)
8283
png_set_expand_gray_1_2_4_to_8(png_ptr);
8384

85+
int number_of_passes;
86+
if (interlace_type == PNG_INTERLACE_ADAM7) {
87+
number_of_passes = png_set_interlace_handling(png_ptr);
88+
} else {
89+
number_of_passes = 1;
90+
}
91+
8492
if (mode != IMAGE_READ_MODE_UNCHANGED) {
8593
// TODO: consider supporting PNG_INFO_tRNS
8694
bool is_palette = (color_type & PNG_COLOR_MASK_PALETTE) != 0;
@@ -163,9 +171,12 @@ torch::Tensor decode_png(const torch::Tensor& data, ImageReadMode mode) {
163171
auto tensor =
164172
torch::empty({int64_t(height), int64_t(width), channels}, torch::kU8);
165173
auto ptr = tensor.accessor<uint8_t, 3>().data();
166-
for (png_uint_32 i = 0; i < height; ++i) {
167-
png_read_row(png_ptr, ptr, nullptr);
168-
ptr += width * channels;
174+
for (int pass = 0; pass < number_of_passes; pass++) {
175+
for (png_uint_32 i = 0; i < height; ++i) {
176+
png_read_row(png_ptr, ptr, nullptr);
177+
ptr += width * channels;
178+
}
179+
ptr = tensor.accessor<uint8_t, 3>().data();
169180
}
170181
png_destroy_read_struct(&png_ptr, &info_ptr, nullptr);
171182
return tensor.permute({2, 0, 1});

0 commit comments

Comments
 (0)