@@ -11,6 +11,11 @@ torch::Tensor decode_png(const torch::Tensor& data, ImageReadMode mode) {
11
11
}
12
12
#else
13
13
14
/// Reports whether the host stores multi-byte integers with the
/// least-significant byte at the lowest address.
///
/// @return true on a little-endian host, false otherwise.
bool is_little_endian() {
  const uint32_t probe = 1;
  // Look at the lowest-addressed byte of the probe: it is non-zero only when
  // the least-significant byte is stored first. Reading an object through
  // `unsigned char` is always allowed by the aliasing rules, which is not
  // strictly guaranteed for `uint8_t`.
  return *reinterpret_cast<const unsigned char*>(&probe) != 0;
}
18
+
14
19
torch::Tensor decode_png (const torch::Tensor& data, ImageReadMode mode) {
15
20
// Check that the input tensor dtype is uint8
16
21
TORCH_CHECK (data.dtype () == torch::kU8 , " Expected a torch.uint8 tensor" );
@@ -72,9 +77,9 @@ torch::Tensor decode_png(const torch::Tensor& data, ImageReadMode mode) {
72
77
TORCH_CHECK (retval == 1 , " Could read image metadata from content." )
73
78
}
74
79
75
- if (bit_depth > 8 ) {
80
+ if (bit_depth > 16 ) {
76
81
png_destroy_read_struct (&png_ptr, &info_ptr, nullptr );
77
- TORCH_CHECK (false , " At most 8 -bit PNG images are supported currently." )
82
+ TORCH_CHECK (false , " At most 16 -bit PNG images are supported currently." )
78
83
}
79
84
80
85
int channels = png_get_channels (png_ptr, info_ptr);
@@ -168,15 +173,46 @@ torch::Tensor decode_png(const torch::Tensor& data, ImageReadMode mode) {
168
173
png_read_update_info (png_ptr, info_ptr);
169
174
}
170
175
171
- auto tensor =
172
- torch::empty ({int64_t (height), int64_t (width), channels}, torch::kU8 );
173
- auto ptr = tensor.accessor <uint8_t , 3 >().data ();
174
- for (int pass = 0 ; pass < number_of_passes; pass++) {
175
- for (png_uint_32 i = 0 ; i < height; ++i) {
176
- png_read_row (png_ptr, ptr, nullptr );
177
- ptr += width * channels;
176
+ auto num_pixels_per_row = width * channels;
177
+ auto tensor = torch::empty (
178
+ {int64_t (height), int64_t (width), channels},
179
+ bit_depth <= 8 ? torch::kU8 : torch::kI32 );
180
+
181
+ if (bit_depth <= 8 ) {
182
+ auto t_ptr = tensor.accessor <uint8_t , 3 >().data ();
183
+ for (int pass = 0 ; pass < number_of_passes; pass++) {
184
+ for (png_uint_32 i = 0 ; i < height; ++i) {
185
+ png_read_row (png_ptr, t_ptr, nullptr );
186
+ t_ptr += num_pixels_per_row;
187
+ }
188
+ t_ptr = tensor.accessor <uint8_t , 3 >().data ();
189
+ }
190
+ } else {
191
+ // We're reading a 16bits png, but pytorch doesn't support uint16.
192
+ // So we read each row in a 16bits tmp_buffer which we then cast into
193
+ // a int32 tensor instead.
194
+ if (is_little_endian ()) {
195
+ png_set_swap (png_ptr);
196
+ }
197
+ int32_t * t_ptr = tensor.accessor <int32_t , 3 >().data ();
198
+
199
+ // We create a tensor instead of malloc-ing for automatic memory management
200
+ auto tmp_buffer_tensor = torch::empty (
201
+ {int64_t (num_pixels_per_row * sizeof (uint16_t ))}, torch::kU8 );
202
+ uint16_t * tmp_buffer =
203
+ (uint16_t *)tmp_buffer_tensor.accessor <uint8_t , 1 >().data ();
204
+
205
+ for (int pass = 0 ; pass < number_of_passes; pass++) {
206
+ for (png_uint_32 i = 0 ; i < height; ++i) {
207
+ png_read_row (png_ptr, (uint8_t *)tmp_buffer, nullptr );
208
+ // Now we copy the uint16 values into the int32 tensor.
209
+ for (size_t j = 0 ; j < num_pixels_per_row; ++j) {
210
+ t_ptr[j] = (int32_t )tmp_buffer[j];
211
+ }
212
+ t_ptr += num_pixels_per_row;
213
+ }
214
+ t_ptr = tensor.accessor <int32_t , 3 >().data ();
178
215
}
179
- ptr = tensor.accessor <uint8_t , 3 >().data ();
180
216
}
181
217
png_destroy_read_struct (&png_ptr, &info_ptr, nullptr );
182
218
return tensor.permute ({2 , 0 , 1 });
0 commit comments