Skip to content

Commit 83d6f0e

Browse files
datumbox and fmassa
authored and committed
[fbsync] Add support for 16 bits png images (#4657)
Summary: * WIP * cleaner code * Add tests * Add docs * Assert dtype * put back check * Address comments Reviewed By: NicolasHug Differential Revision: D31916334 fbshipit-source-id: 8877266f6e533e8c45c5f202e535944a9a939376 Co-authored-by: Francisco Massa <[email protected]>
1 parent be3ef03 commit 83d6f0e

File tree

5 files changed

+61
-13
lines changed

5 files changed

+61
-13
lines changed
68.2 KB
Loading
Loading

test/test_image.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,11 @@ def test_decode_png(img_path, pil_mode, mode):
168168
# TODO: remove once fix is released in PIL. Should be > 8.3.1.
169169
img_lpng, img_pil = img_lpng[0], img_pil[0]
170170

171+
if "16" in img_path:
172+
# PIL converts 16 bits pngs in uint8
173+
assert img_lpng.dtype == torch.int32
174+
img_lpng = torch.round(img_lpng / (2 ** 16 - 1) * 255).to(torch.uint8)
175+
171176
torch.testing.assert_close(img_lpng, img_pil, atol=tol, rtol=0)
172177

173178

torchvision/csrc/io/image/cpu/decode_png.cpp

Lines changed: 46 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,11 @@ torch::Tensor decode_png(const torch::Tensor& data, ImageReadMode mode) {
1111
}
1212
#else
1313

14+
// Reports whether the host stores multi-byte integers with the
// least-significant byte at the lowest address.
//
// Returns true on little-endian hosts and false otherwise.
bool is_little_endian() {
  const uint32_t probe = 1;
  // Reading an object's storage through `unsigned char` is allowed by the
  // aliasing rules. On a little-endian host the first byte holds the low-order
  // byte of `probe` (1); on a big-endian host it holds 0.
  return *reinterpret_cast<const unsigned char*>(&probe) == 1;
}
18+
1419
torch::Tensor decode_png(const torch::Tensor& data, ImageReadMode mode) {
1520
// Check that the input tensor dtype is uint8
1621
TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor");
@@ -72,9 +77,9 @@ torch::Tensor decode_png(const torch::Tensor& data, ImageReadMode mode) {
7277
TORCH_CHECK(retval == 1, "Could read image metadata from content.")
7378
}
7479

75-
if (bit_depth > 8) {
80+
if (bit_depth > 16) {
7681
png_destroy_read_struct(&png_ptr, &info_ptr, nullptr);
77-
TORCH_CHECK(false, "At most 8-bit PNG images are supported currently.")
82+
TORCH_CHECK(false, "At most 16-bit PNG images are supported currently.")
7883
}
7984

8085
int channels = png_get_channels(png_ptr, info_ptr);
@@ -168,15 +173,46 @@ torch::Tensor decode_png(const torch::Tensor& data, ImageReadMode mode) {
168173
png_read_update_info(png_ptr, info_ptr);
169174
}
170175

171-
auto tensor =
172-
torch::empty({int64_t(height), int64_t(width), channels}, torch::kU8);
173-
auto ptr = tensor.accessor<uint8_t, 3>().data();
174-
for (int pass = 0; pass < number_of_passes; pass++) {
175-
for (png_uint_32 i = 0; i < height; ++i) {
176-
png_read_row(png_ptr, ptr, nullptr);
177-
ptr += width * channels;
176+
auto num_pixels_per_row = width * channels;
177+
auto tensor = torch::empty(
178+
{int64_t(height), int64_t(width), channels},
179+
bit_depth <= 8 ? torch::kU8 : torch::kI32);
180+
181+
if (bit_depth <= 8) {
182+
auto t_ptr = tensor.accessor<uint8_t, 3>().data();
183+
for (int pass = 0; pass < number_of_passes; pass++) {
184+
for (png_uint_32 i = 0; i < height; ++i) {
185+
png_read_row(png_ptr, t_ptr, nullptr);
186+
t_ptr += num_pixels_per_row;
187+
}
188+
t_ptr = tensor.accessor<uint8_t, 3>().data();
189+
}
190+
} else {
191+
// We're reading a 16bits png, but pytorch doesn't support uint16.
192+
// So we read each row in a 16bits tmp_buffer which we then cast into
193+
// a int32 tensor instead.
194+
if (is_little_endian()) {
195+
png_set_swap(png_ptr);
196+
}
197+
int32_t* t_ptr = tensor.accessor<int32_t, 3>().data();
198+
199+
// We create a tensor instead of malloc-ing for automatic memory management
200+
auto tmp_buffer_tensor = torch::empty(
201+
{int64_t(num_pixels_per_row * sizeof(uint16_t))}, torch::kU8);
202+
uint16_t* tmp_buffer =
203+
(uint16_t*)tmp_buffer_tensor.accessor<uint8_t, 1>().data();
204+
205+
for (int pass = 0; pass < number_of_passes; pass++) {
206+
for (png_uint_32 i = 0; i < height; ++i) {
207+
png_read_row(png_ptr, (uint8_t*)tmp_buffer, nullptr);
208+
// Now we copy the uint16 values into the int32 tensor.
209+
for (size_t j = 0; j < num_pixels_per_row; ++j) {
210+
t_ptr[j] = (int32_t)tmp_buffer[j];
211+
}
212+
t_ptr += num_pixels_per_row;
213+
}
214+
t_ptr = tensor.accessor<int32_t, 3>().data();
178215
}
179-
ptr = tensor.accessor<uint8_t, 3>().data();
180216
}
181217
png_destroy_read_struct(&png_ptr, &info_ptr, nullptr);
182218
return tensor.permute({2, 0, 1});

torchvision/io/image.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,12 @@ def decode_png(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGE
6161
"""
6262
Decodes a PNG image into a 3 dimensional RGB or grayscale Tensor.
6363
Optionally converts the image to the desired format.
64-
The values of the output tensor are uint8 between 0 and 255.
64+
The values of the output tensor are uint8 in [0, 255], except for
65+
16-bits pngs which are int32 tensors in [0, 65535].
66+
67+
.. warning::
68+
Should pytorch ever support the uint16 dtype natively, the dtype of the
69+
output for 16-bits pngs will be updated from int32 to uint16.
6570
6671
Args:
6772
input (Tensor[1]): a one dimensional uint8 tensor containing
@@ -188,7 +193,8 @@ def decode_image(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHAN
188193
operation to decode the image into a 3 dimensional RGB or grayscale Tensor.
189194
190195
Optionally converts the image to the desired format.
191-
The values of the output tensor are uint8 between 0 and 255.
196+
The values of the output tensor are uint8 in [0, 255], except for
197+
16-bits pngs which are int32 tensors in [0, 65535].
192198
193199
Args:
194200
input (Tensor): a one dimensional uint8 tensor containing the raw bytes of the
@@ -209,7 +215,8 @@ def read_image(path: str, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torc
209215
"""
210216
Reads a JPEG or PNG image into a 3 dimensional RGB or grayscale Tensor.
211217
Optionally converts the image to the desired format.
212-
The values of the output tensor are uint8 between 0 and 255.
218+
The values of the output tensor are uint8 in [0, 255], except for
219+
16-bits pngs which are int32 tensors in [0, 65535].
213220
214221
Args:
215222
path (str): path of the JPEG or PNG image.

0 commit comments

Comments
 (0)