Skip to content

Commit cb56d52

Browse files
datumbox authored and
facebook-github-bot committed
[fbsync] Fix write and encode jpeg tests (#3908)
Reviewed By: vincentqb, cpuhrsch Differential Revision: D28679967 fbshipit-source-id: 000bea1e2bc5fe7db14fbc36d80528300c7f7650
1 parent f87b024 commit cb56d52

File tree

1 file changed

+156
-73
lines changed

1 file changed

+156
-73
lines changed

test/test_image.py

Lines changed: 156 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,36 @@
11
import glob
22
import io
33
import os
4+
import sys
45
import unittest
6+
from pathlib import Path
57

68
import pytest
79
import numpy as np
810
import torch
911
from PIL import Image
10-
from common_utils import get_tmp_dir, needs_cuda
12+
import torchvision.transforms.functional as F
13+
from common_utils import get_tmp_dir, needs_cuda, cpu_only
1114
from _assert_utils import assert_equal
1215

1316
from torchvision.io.image import (
1417
decode_png, decode_jpeg, encode_jpeg, write_jpeg, decode_image, read_file,
15-
encode_png, write_png, write_file, ImageReadMode)
18+
encode_png, write_png, write_file, ImageReadMode, read_image)
1619

1720
# Fixture locations, all resolved relative to the directory containing this file.
IMAGE_ROOT = os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    "assets",
)
FAKEDATA_DIR = os.path.join(IMAGE_ROOT, "fakedata")
IMAGE_DIR = os.path.join(FAKEDATA_DIR, "imagefolder")
DAMAGED_JPEG = os.path.join(IMAGE_ROOT, "damaged_jpeg")
ENCODE_JPEG = os.path.join(IMAGE_ROOT, "encode_jpeg")

# True when the interpreter reports a Windows-like platform string.
IS_WINDOWS = sys.platform in ("win32", "cygwin")
26+
27+
28+
def _get_safe_image_name(name):
    """Return only the final path component of an image path.

    Used when we need to change the pytest "id" for an "image path" parameter.
    If we don't, the test id (i.e. its name) will contain the whole path to
    the image, which is machine-specific, and this creates issues when the
    test is running on a different machine than where it was collected
    (typically, in fb internal infra).

    Args:
        name: Path to an image file, as a string.

    Returns:
        The part of ``name`` after the last path separator.
    """
    # os.path.basename is the stdlib way to do this; on Windows it also
    # recognises "/" in addition to os.path.sep, so mixed-separator paths
    # still reduce to the bare file name.
    return os.path.basename(name)
2234

2335

2436
def get_images(directory, img_ext):
@@ -93,72 +105,6 @@ def test_damaged_images(self):
93105
with self.assertRaises(RuntimeError):
94106
decode_jpeg(data)
95107

96-
def test_encode_jpeg(self):
    """Compare ``encode_jpeg`` with stored PIL files, then check input validation."""
    for img_path in get_images(ENCODE_JPEG, ".jpg"):
        stem = os.path.splitext(os.path.basename(img_path))[0]
        expected_file = os.path.join(
            os.path.dirname(img_path), 'jpeg_write', '{0}_pil.jpg'.format(stem))
        decoded = decode_jpeg(read_file(img_path))

        with open(expected_file, 'rb') as fh:
            reference = torch.as_tensor(list(fh.read()), dtype=torch.uint8)

        for candidate in (decoded, decoded.contiguous()):
            # PIL sets jpeg quality to 75 by default
            assert_equal(encode_jpeg(candidate, quality=75), reference)

    # Each entry: (exception type, message regex, tensor shape, dtype, extra kwargs).
    quality_msg = "Image quality should be a positive number between 1 and 100"
    ndim_msg = "Input data should be a 3-dimensional tensor"
    invalid_calls = [
        (RuntimeError, "Input tensor dtype should be uint8", (3, 100, 100), torch.float32, {}),
        (ValueError, quality_msg, (3, 100, 100), torch.uint8, {"quality": -1}),
        (ValueError, quality_msg, (3, 100, 100), torch.uint8, {"quality": 101}),
        (RuntimeError, "The number of channels should be 1 or 3, got: 5", (5, 100, 100), torch.uint8, {}),
        (RuntimeError, ndim_msg, (1, 3, 100, 100), torch.uint8, {}),
        (RuntimeError, ndim_msg, (100, 100), torch.uint8, {}),
    ]
    for exc_type, pattern, shape, dtype, extra in invalid_calls:
        with self.assertRaisesRegex(exc_type, pattern):
            encode_jpeg(torch.empty(shape, dtype=dtype), **extra)
138-
139-
def test_write_jpeg(self):
    """Write each reference image with ``write_jpeg`` and compare with the stored PIL file."""
    with get_tmp_dir() as tmp_dir:
        for img_path in get_images(ENCODE_JPEG, ".jpg"):
            decoded = decode_jpeg(read_file(img_path))

            stem = os.path.splitext(os.path.basename(img_path))[0]
            written_path = os.path.join(tmp_dir, '{0}_torch.jpg'.format(stem))
            reference_path = os.path.join(
                os.path.dirname(img_path), 'jpeg_write', '{0}_pil.jpg'.format(stem))

            write_jpeg(decoded, written_path, quality=75)

            with open(written_path, 'rb') as fh:
                written = fh.read()
            with open(reference_path, 'rb') as fh:
                reference = fh.read()

            self.assertEqual(written, reference)
161-
162108
def test_decode_png(self):
163109
conversion = [(None, ImageReadMode.UNCHANGED), ("L", ImageReadMode.GRAY), ("LA", ImageReadMode.GRAY_ALPHA),
164110
("RGB", ImageReadMode.RGB), ("RGBA", ImageReadMode.RGB_ALPHA)]
@@ -282,11 +228,7 @@ def test_write_file_non_ascii(self):
282228

283229
@needs_cuda
284230
@pytest.mark.parametrize('img_path', [
285-
# We need to change the "id" for that parameter.
286-
# If we don't, the test id (i.e. its name) will contain the whole path to the image which is machine-specific,
287-
# and this creates issues when the test is running in a different machine than where it was collected
288-
# (typically, in fb internal infra)
289-
pytest.param(jpeg_path, id=jpeg_path.split('/')[-1])
231+
pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path))
290232
for jpeg_path in get_images(IMAGE_ROOT, ".jpg")
291233
])
292234
@pytest.mark.parametrize('mode', [ImageReadMode.UNCHANGED, ImageReadMode.GRAY, ImageReadMode.RGB])
@@ -325,5 +267,146 @@ def test_decode_jpeg_cuda_errors():
325267
torch.ops.image.decode_jpeg_cuda(data, ImageReadMode.UNCHANGED.value, 'cpu')
326268

327269

270+
@cpu_only
def test_encode_jpeg_errors():
    """Check that ``encode_jpeg`` rejects invalid inputs with the expected errors."""
    quality_msg = "Image quality should be a positive number between 1 and 100"
    ndim_msg = "Input data should be a 3-dimensional tensor"

    # Each entry: (exception type, message regex, tensor shape, dtype, extra kwargs).
    invalid_calls = [
        (RuntimeError, "Input tensor dtype should be uint8", (3, 100, 100), torch.float32, {}),
        (ValueError, quality_msg, (3, 100, 100), torch.uint8, {"quality": -1}),
        (ValueError, quality_msg, (3, 100, 100), torch.uint8, {"quality": 101}),
        (RuntimeError, "The number of channels should be 1 or 3, got: 5", (5, 100, 100), torch.uint8, {}),
        (RuntimeError, ndim_msg, (1, 3, 100, 100), torch.uint8, {}),
        (RuntimeError, ndim_msg, (100, 100), torch.uint8, {}),
    ]

    for exc_type, pattern, shape, dtype, extra in invalid_calls:
        with pytest.raises(exc_type, match=pattern):
            encode_jpeg(torch.empty(shape, dtype=dtype), **extra)
292+
293+
294+
def _collect_if(cond):
    """Build a decorator that keeps a test as-is only when ``cond`` is truthy.

    When ``cond`` is falsy the test function is instead passed through
    ``pytest.mark.dont_collect``.
    NOTE(review): presumably the project's pytest configuration skips
    collection of tests carrying that mark - confirm in conftest.
    """
    # TODO: remove this once test_encode_jpeg_windows and test_write_jpeg_windows
    # are removed
    def _inner(test_func):
        return test_func if cond else pytest.mark.dont_collect(test_func)

    return _inner
303+
304+
305+
@cpu_only
@_collect_if(cond=IS_WINDOWS)
def test_encode_jpeg_windows():
    # This test is *wrong*.
    # It compares a torchvision-encoded jpeg with a PIL-encoded jpeg, but the
    # torchvision side is encoded from the output of decode_jpeg, which can
    # differ from what PIL decodes (see test_decode..., which uses a high
    # tolerance).
    # For a valid comparison both encoders should start from the exact same
    # decoded image. test_encode_jpeg does that, but unfortunately those more
    # correct tests fail on windows (probably because of a difference in
    # libjpeg between torchvision and PIL).
    # FIXME: make the correct tests pass on windows and remove this.
    for img_path in get_images(ENCODE_JPEG, ".jpg"):
        stem = os.path.splitext(os.path.basename(img_path))[0]
        expected_file = os.path.join(
            os.path.dirname(img_path), 'jpeg_write', '{0}_pil.jpg'.format(stem))
        decoded = decode_jpeg(read_file(img_path))

        with open(expected_file, 'rb') as fh:
            reference = torch.as_tensor(list(fh.read()), dtype=torch.uint8)

        for candidate in (decoded, decoded.contiguous()):
            # PIL sets jpeg quality to 75 by default
            assert_equal(encode_jpeg(candidate, quality=75), reference)
333+
334+
335+
@cpu_only
@_collect_if(cond=IS_WINDOWS)
def test_write_jpeg_windows():
    """Write each reference image with ``write_jpeg`` and compare it,
    byte for byte, with the stored PIL-encoded file.
    """
    # FIXME: Remove this eventually, see test_encode_jpeg_windows
    with get_tmp_dir() as d:
        for img_path in get_images(ENCODE_JPEG, ".jpg"):
            data = read_file(img_path)
            img = decode_jpeg(data)

            basedir = os.path.dirname(img_path)
            filename, _ = os.path.splitext(os.path.basename(img_path))
            torch_jpeg = os.path.join(
                d, '{0}_torch.jpg'.format(filename))
            pil_jpeg = os.path.join(
                basedir, 'jpeg_write', '{0}_pil.jpg'.format(filename))

            write_jpeg(img, torch_jpeg, quality=75)

            with open(torch_jpeg, 'rb') as f:
                torch_bytes = f.read()

            with open(pil_jpeg, 'rb') as f:
                pil_bytes = f.read()

            # Both operands are raw ``bytes`` objects, so compare them
            # directly: this is an exact check and does not depend on how the
            # tensor-oriented assert_equal helper treats non-tensor sequences.
            assert torch_bytes == pil_bytes
360+
361+
362+
@cpu_only
@_collect_if(cond=not IS_WINDOWS)
@pytest.mark.parametrize('img_path', [
    pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path))
    for jpeg_path in get_images(ENCODE_JPEG, ".jpg")
])
def test_encode_jpeg(img_path):
    """Encode the same decoded image with PIL and with ``encode_jpeg`` at
    quality 75 and require the two encodings to be equal.
    """
    source = read_image(img_path)

    # Reference encoding: let PIL write the JPEG into an in-memory stream.
    reference_stream = io.BytesIO()
    F.to_pil_image(source).save(reference_stream, format='JPEG', quality=75)

    # pytorch can't read from raw bytes so we go through numpy
    expected = torch.as_tensor(
        np.frombuffer(reference_stream.getvalue(), dtype=np.uint8))

    for candidate in (source, source.contiguous()):
        assert_equal(encode_jpeg(candidate, quality=75), expected)
382+
383+
384+
@cpu_only
@_collect_if(cond=not IS_WINDOWS)
@pytest.mark.parametrize('img_path', [
    pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path))
    for jpeg_path in get_images(ENCODE_JPEG, ".jpg")
])
def test_write_jpeg(img_path):
    """Write the same decoded image with ``write_jpeg`` and with PIL at
    quality 75 and require the two files to be byte-identical.
    """
    with get_tmp_dir() as d:
        d = Path(d)
        img = read_image(img_path)
        pil_img = F.to_pil_image(img)

        torch_jpeg = str(d / 'torch.jpg')
        pil_jpeg = str(d / 'pil.jpg')

        write_jpeg(img, torch_jpeg, quality=75)
        pil_img.save(pil_jpeg, quality=75)

        with open(torch_jpeg, 'rb') as f:
            torch_bytes = f.read()

        with open(pil_jpeg, 'rb') as f:
            pil_bytes = f.read()

        # Both operands are raw ``bytes`` objects, so compare them directly:
        # this is an exact check and does not depend on how the
        # tensor-oriented assert_equal helper treats non-tensor sequences.
        assert torch_bytes == pil_bytes
409+
410+
328411
# Allow running this module directly through the unittest runner.
if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)