@@ -55,14 +55,15 @@ torch::Tensor decode_png(const torch::Tensor& data, ImageReadMode mode) {
55
55
56
56
png_uint_32 width, height;
57
57
int bit_depth, color_type;
58
+ int interlace_type;
58
59
auto retval = png_get_IHDR (
59
60
png_ptr,
60
61
info_ptr,
61
62
&width,
62
63
&height,
63
64
&bit_depth,
64
65
&color_type,
65
- nullptr ,
66
+ &interlace_type ,
66
67
nullptr ,
67
68
nullptr );
68
69
@@ -81,6 +82,13 @@ torch::Tensor decode_png(const torch::Tensor& data, ImageReadMode mode) {
81
82
if (color_type == PNG_COLOR_TYPE_GRAY && bit_depth < 8 )
82
83
png_set_expand_gray_1_2_4_to_8 (png_ptr);
83
84
85
+ int number_of_passes;
86
+ if (interlace_type == PNG_INTERLACE_ADAM7) {
87
+ number_of_passes = png_set_interlace_handling (png_ptr);
88
+ } else {
89
+ number_of_passes = 1 ;
90
+ }
91
+
84
92
if (mode != IMAGE_READ_MODE_UNCHANGED) {
85
93
// TODO: consider supporting PNG_INFO_tRNS
86
94
bool is_palette = (color_type & PNG_COLOR_MASK_PALETTE) != 0 ;
@@ -163,9 +171,12 @@ torch::Tensor decode_png(const torch::Tensor& data, ImageReadMode mode) {
163
171
auto tensor =
164
172
torch::empty ({int64_t (height), int64_t (width), channels}, torch::kU8 );
165
173
auto ptr = tensor.accessor <uint8_t , 3 >().data ();
166
- for (png_uint_32 i = 0 ; i < height; ++i) {
167
- png_read_row (png_ptr, ptr, nullptr );
168
- ptr += width * channels;
174
+ for (int pass = 0 ; pass < number_of_passes; pass++) {
175
+ for (png_uint_32 i = 0 ; i < height; ++i) {
176
+ png_read_row (png_ptr, ptr, nullptr );
177
+ ptr += width * channels;
178
+ }
179
+ ptr = tensor.accessor <uint8_t , 3 >().data ();
169
180
}
170
181
png_destroy_read_struct (&png_ptr, &info_ptr, nullptr );
171
182
return tensor.permute ({2 , 0 , 1 });
0 commit comments