|
1 | 1 | import glob
|
2 | 2 | import io
|
3 | 3 | import os
|
| 4 | +import sys |
4 | 5 | import unittest
|
| 6 | +from pathlib import Path |
5 | 7 |
|
6 | 8 | import pytest
|
7 | 9 | import numpy as np
|
8 | 10 | import torch
|
9 | 11 | from PIL import Image
|
10 |
| -from common_utils import get_tmp_dir, needs_cuda |
| 12 | +import torchvision.transforms.functional as F |
| 13 | +from common_utils import get_tmp_dir, needs_cuda, cpu_only |
11 | 14 | from _assert_utils import assert_equal
|
12 | 15 |
|
13 | 16 | from torchvision.io.image import (
|
14 | 17 | decode_png, decode_jpeg, encode_jpeg, write_jpeg, decode_image, read_file,
|
15 |
| - encode_png, write_png, write_file, ImageReadMode) |
| 18 | + encode_png, write_png, write_file, ImageReadMode, read_image) |
16 | 19 |
|
17 | 20 | IMAGE_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")
|
18 | 21 | FAKEDATA_DIR = os.path.join(IMAGE_ROOT, "fakedata")
|
19 | 22 | IMAGE_DIR = os.path.join(FAKEDATA_DIR, "imagefolder")
|
20 | 23 | DAMAGED_JPEG = os.path.join(IMAGE_ROOT, 'damaged_jpeg')
|
21 | 24 | ENCODE_JPEG = os.path.join(IMAGE_ROOT, "encode_jpeg")
|
| 25 | +IS_WINDOWS = sys.platform in ('win32', 'cygwin') |
| 26 | + |
| 27 | + |
| 28 | +def _get_safe_image_name(name): |
| 29 | + # Used when we need to change the pytest "id" for an "image path" parameter. |
| 30 | + # If we don't, the test id (i.e. its name) will contain the whole path to the image, which is machine-specific, |
| 31 | + # and this creates issues when the test is running on a different machine than the one where it was collected |
| 32 | + # (typically, in fb internal infra) |
| 33 | + return name.split(os.path.sep)[-1] |
22 | 34 |
|
23 | 35 |
|
24 | 36 | def get_images(directory, img_ext):
|
@@ -93,72 +105,6 @@ def test_damaged_images(self):
|
93 | 105 | with self.assertRaises(RuntimeError):
|
94 | 106 | decode_jpeg(data)
|
95 | 107 |
|
96 |
| - def test_encode_jpeg(self): |
97 |
| - for img_path in get_images(ENCODE_JPEG, ".jpg"): |
98 |
| - dirname = os.path.dirname(img_path) |
99 |
| - filename, _ = os.path.splitext(os.path.basename(img_path)) |
100 |
| - write_folder = os.path.join(dirname, 'jpeg_write') |
101 |
| - expected_file = os.path.join( |
102 |
| - write_folder, '{0}_pil.jpg'.format(filename)) |
103 |
| - img = decode_jpeg(read_file(img_path)) |
104 |
| - |
105 |
| - with open(expected_file, 'rb') as f: |
106 |
| - pil_bytes = f.read() |
107 |
| - pil_bytes = torch.as_tensor(list(pil_bytes), dtype=torch.uint8) |
108 |
| - for src_img in [img, img.contiguous()]: |
109 |
| - # PIL sets jpeg quality to 75 by default |
110 |
| - jpeg_bytes = encode_jpeg(src_img, quality=75) |
111 |
| - assert_equal(jpeg_bytes, pil_bytes) |
112 |
| - |
113 |
| - with self.assertRaisesRegex( |
114 |
| - RuntimeError, "Input tensor dtype should be uint8"): |
115 |
| - encode_jpeg(torch.empty((3, 100, 100), dtype=torch.float32)) |
116 |
| - |
117 |
| - with self.assertRaisesRegex( |
118 |
| - ValueError, "Image quality should be a positive number " |
119 |
| - "between 1 and 100"): |
120 |
| - encode_jpeg(torch.empty((3, 100, 100), dtype=torch.uint8), quality=-1) |
121 |
| - |
122 |
| - with self.assertRaisesRegex( |
123 |
| - ValueError, "Image quality should be a positive number " |
124 |
| - "between 1 and 100"): |
125 |
| - encode_jpeg(torch.empty((3, 100, 100), dtype=torch.uint8), quality=101) |
126 |
| - |
127 |
| - with self.assertRaisesRegex( |
128 |
| - RuntimeError, "The number of channels should be 1 or 3, got: 5"): |
129 |
| - encode_jpeg(torch.empty((5, 100, 100), dtype=torch.uint8)) |
130 |
| - |
131 |
| - with self.assertRaisesRegex( |
132 |
| - RuntimeError, "Input data should be a 3-dimensional tensor"): |
133 |
| - encode_jpeg(torch.empty((1, 3, 100, 100), dtype=torch.uint8)) |
134 |
| - |
135 |
| - with self.assertRaisesRegex( |
136 |
| - RuntimeError, "Input data should be a 3-dimensional tensor"): |
137 |
| - encode_jpeg(torch.empty((100, 100), dtype=torch.uint8)) |
138 |
| - |
139 |
| - def test_write_jpeg(self): |
140 |
| - with get_tmp_dir() as d: |
141 |
| - for img_path in get_images(ENCODE_JPEG, ".jpg"): |
142 |
| - data = read_file(img_path) |
143 |
| - img = decode_jpeg(data) |
144 |
| - |
145 |
| - basedir = os.path.dirname(img_path) |
146 |
| - filename, _ = os.path.splitext(os.path.basename(img_path)) |
147 |
| - torch_jpeg = os.path.join( |
148 |
| - d, '{0}_torch.jpg'.format(filename)) |
149 |
| - pil_jpeg = os.path.join( |
150 |
| - basedir, 'jpeg_write', '{0}_pil.jpg'.format(filename)) |
151 |
| - |
152 |
| - write_jpeg(img, torch_jpeg, quality=75) |
153 |
| - |
154 |
| - with open(torch_jpeg, 'rb') as f: |
155 |
| - torch_bytes = f.read() |
156 |
| - |
157 |
| - with open(pil_jpeg, 'rb') as f: |
158 |
| - pil_bytes = f.read() |
159 |
| - |
160 |
| - self.assertEqual(torch_bytes, pil_bytes) |
161 |
| - |
162 | 108 | def test_decode_png(self):
|
163 | 109 | conversion = [(None, ImageReadMode.UNCHANGED), ("L", ImageReadMode.GRAY), ("LA", ImageReadMode.GRAY_ALPHA),
|
164 | 110 | ("RGB", ImageReadMode.RGB), ("RGBA", ImageReadMode.RGB_ALPHA)]
|
@@ -282,11 +228,7 @@ def test_write_file_non_ascii(self):
|
282 | 228 |
|
283 | 229 | @needs_cuda
|
284 | 230 | @pytest.mark.parametrize('img_path', [
|
285 |
| - # We need to change the "id" for that parameter. |
286 |
| - # If we don't, the test id (i.e. its name) will contain the whole path to the image which is machine-specific, |
287 |
| - # and this creates issues when the test is running in a different machine than where it was collected |
288 |
| - # (typically, in fb internal infra) |
289 |
| - pytest.param(jpeg_path, id=jpeg_path.split('/')[-1]) |
| 231 | + pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) |
290 | 232 | for jpeg_path in get_images(IMAGE_ROOT, ".jpg")
|
291 | 233 | ])
|
292 | 234 | @pytest.mark.parametrize('mode', [ImageReadMode.UNCHANGED, ImageReadMode.GRAY, ImageReadMode.RGB])
|
@@ -325,5 +267,146 @@ def test_decode_jpeg_cuda_errors():
|
325 | 267 | torch.ops.image.decode_jpeg_cuda(data, ImageReadMode.UNCHANGED.value, 'cpu')
|
326 | 268 |
|
327 | 269 |
|
| 270 | +@cpu_only |
| 271 | +def test_encode_jpeg_errors(): |
| 272 | + |
| 273 | + with pytest.raises(RuntimeError, match="Input tensor dtype should be uint8"): |
| 274 | + encode_jpeg(torch.empty((3, 100, 100), dtype=torch.float32)) |
| 275 | + |
| 276 | + with pytest.raises(ValueError, match="Image quality should be a positive number " |
| 277 | + "between 1 and 100"): |
| 278 | + encode_jpeg(torch.empty((3, 100, 100), dtype=torch.uint8), quality=-1) |
| 279 | + |
| 280 | + with pytest.raises(ValueError, match="Image quality should be a positive number " |
| 281 | + "between 1 and 100"): |
| 282 | + encode_jpeg(torch.empty((3, 100, 100), dtype=torch.uint8), quality=101) |
| 283 | + |
| 284 | + with pytest.raises(RuntimeError, match="The number of channels should be 1 or 3, got: 5"): |
| 285 | + encode_jpeg(torch.empty((5, 100, 100), dtype=torch.uint8)) |
| 286 | + |
| 287 | + with pytest.raises(RuntimeError, match="Input data should be a 3-dimensional tensor"): |
| 288 | + encode_jpeg(torch.empty((1, 3, 100, 100), dtype=torch.uint8)) |
| 289 | + |
| 290 | + with pytest.raises(RuntimeError, match="Input data should be a 3-dimensional tensor"): |
| 291 | + encode_jpeg(torch.empty((100, 100), dtype=torch.uint8)) |
| 292 | + |
| 293 | + |
| 294 | +def _collect_if(cond): |
| 295 | + # TODO: remove this once test_encode_jpeg_windows and test_write_jpeg_windows |
| 296 | + # are removed |
| 297 | + def _inner(test_func): |
| 298 | + if cond: |
| 299 | + return test_func |
| 300 | + else: |
| 301 | + return pytest.mark.dont_collect(test_func) |
| 302 | + return _inner |
| 303 | + |
| 304 | + |
| 305 | +@cpu_only |
| 306 | +@_collect_if(cond=IS_WINDOWS) |
| 307 | +def test_encode_jpeg_windows(): |
| 308 | + # This test is *wrong*. |
| 309 | + # It compares a torchvision-encoded jpeg with a PIL-encoded jpeg, but it |
| 310 | + # starts encoding the torchvision version from an image that comes from |
| 311 | + # decode_jpeg, which can yield different results from pil.decode (see |
| 312 | + # test_decode... which uses a high tolerance). |
| 313 | + # Instead, we should start encoding from the exact same decoded image, for a |
| 314 | + # valid comparison. This is done in test_encode_jpeg, but unfortunately |
| 315 | + # these more correct tests fail on Windows (probably because of a difference |
| 316 | + # in libjpeg between torchvision and PIL). |
| 317 | + # FIXME: make the correct tests pass on windows and remove this. |
| 318 | + for img_path in get_images(ENCODE_JPEG, ".jpg"): |
| 319 | + dirname = os.path.dirname(img_path) |
| 320 | + filename, _ = os.path.splitext(os.path.basename(img_path)) |
| 321 | + write_folder = os.path.join(dirname, 'jpeg_write') |
| 322 | + expected_file = os.path.join( |
| 323 | + write_folder, '{0}_pil.jpg'.format(filename)) |
| 324 | + img = decode_jpeg(read_file(img_path)) |
| 325 | + |
| 326 | + with open(expected_file, 'rb') as f: |
| 327 | + pil_bytes = f.read() |
| 328 | + pil_bytes = torch.as_tensor(list(pil_bytes), dtype=torch.uint8) |
| 329 | + for src_img in [img, img.contiguous()]: |
| 330 | + # PIL sets jpeg quality to 75 by default |
| 331 | + jpeg_bytes = encode_jpeg(src_img, quality=75) |
| 332 | + assert_equal(jpeg_bytes, pil_bytes) |
| 333 | + |
| 334 | + |
| 335 | +@cpu_only |
| 336 | +@_collect_if(cond=IS_WINDOWS) |
| 337 | +def test_write_jpeg_windows(): |
| 338 | + # FIXME: Remove this eventually, see test_encode_jpeg_windows |
| 339 | + with get_tmp_dir() as d: |
| 340 | + for img_path in get_images(ENCODE_JPEG, ".jpg"): |
| 341 | + data = read_file(img_path) |
| 342 | + img = decode_jpeg(data) |
| 343 | + |
| 344 | + basedir = os.path.dirname(img_path) |
| 345 | + filename, _ = os.path.splitext(os.path.basename(img_path)) |
| 346 | + torch_jpeg = os.path.join( |
| 347 | + d, '{0}_torch.jpg'.format(filename)) |
| 348 | + pil_jpeg = os.path.join( |
| 349 | + basedir, 'jpeg_write', '{0}_pil.jpg'.format(filename)) |
| 350 | + |
| 351 | + write_jpeg(img, torch_jpeg, quality=75) |
| 352 | + |
| 353 | + with open(torch_jpeg, 'rb') as f: |
| 354 | + torch_bytes = f.read() |
| 355 | + |
| 356 | + with open(pil_jpeg, 'rb') as f: |
| 357 | + pil_bytes = f.read() |
| 358 | + |
| 359 | + assert_equal(torch_bytes, pil_bytes) |
| 360 | + |
| 361 | + |
| 362 | +@cpu_only |
| 363 | +@_collect_if(cond=not IS_WINDOWS) |
| 364 | +@pytest.mark.parametrize('img_path', [ |
| 365 | + pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) |
| 366 | + for jpeg_path in get_images(ENCODE_JPEG, ".jpg") |
| 367 | +]) |
| 368 | +def test_encode_jpeg(img_path): |
| 369 | + img = read_image(img_path) |
| 370 | + |
| 371 | + pil_img = F.to_pil_image(img) |
| 372 | + buf = io.BytesIO() |
| 373 | + pil_img.save(buf, format='JPEG', quality=75) |
| 374 | + |
| 375 | + # pytorch can't read from raw bytes so we go through numpy |
| 376 | + pil_bytes = np.frombuffer(buf.getvalue(), dtype=np.uint8) |
| 377 | + encoded_jpeg_pil = torch.as_tensor(pil_bytes) |
| 378 | + |
| 379 | + for src_img in [img, img.contiguous()]: |
| 380 | + encoded_jpeg_torch = encode_jpeg(src_img, quality=75) |
| 381 | + assert_equal(encoded_jpeg_torch, encoded_jpeg_pil) |
| 382 | + |
| 383 | + |
| 384 | +@cpu_only |
| 385 | +@_collect_if(cond=not IS_WINDOWS) |
| 386 | +@pytest.mark.parametrize('img_path', [ |
| 387 | + pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) |
| 388 | + for jpeg_path in get_images(ENCODE_JPEG, ".jpg") |
| 389 | +]) |
| 390 | +def test_write_jpeg(img_path): |
| 391 | + with get_tmp_dir() as d: |
| 392 | + d = Path(d) |
| 393 | + img = read_image(img_path) |
| 394 | + pil_img = F.to_pil_image(img) |
| 395 | + |
| 396 | + torch_jpeg = str(d / 'torch.jpg') |
| 397 | + pil_jpeg = str(d / 'pil.jpg') |
| 398 | + |
| 399 | + write_jpeg(img, torch_jpeg, quality=75) |
| 400 | + pil_img.save(pil_jpeg, quality=75) |
| 401 | + |
| 402 | + with open(torch_jpeg, 'rb') as f: |
| 403 | + torch_bytes = f.read() |
| 404 | + |
| 405 | + with open(pil_jpeg, 'rb') as f: |
| 406 | + pil_bytes = f.read() |
| 407 | + |
| 408 | + assert_equal(torch_bytes, pil_bytes) |
| 409 | + |
| 410 | + |
328 | 411 | if __name__ == '__main__':
|
329 | 412 | unittest.main()
|
0 commit comments