diff --git a/test/common_utils.py b/test/common_utils.py index 5936ae1f713..79fce27110d 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -15,6 +15,7 @@ from numbers import Number from torch._six import string_classes from collections import OrderedDict +from torchvision import io import numpy as np from PIL import Image @@ -147,6 +148,25 @@ def _create_data_batch(height=3, width=3, channels=3, num_samples=4, device="cpu assert_equal = functools.partial(torch.testing.assert_close, rtol=0, atol=0) +def get_list_of_videos(tmpdir, num_videos=5, sizes=None, fps=None): + names = [] + for i in range(num_videos): + if sizes is None: + size = 5 * (i + 1) + else: + size = sizes[i] + if fps is None: + f = 5 + else: + f = fps[i] + data = torch.randint(0, 256, (size, 300, 400, 3), dtype=torch.uint8) + name = os.path.join(tmpdir, "{}.mp4".format(i)) + names.append(name) + io.write_video(name, data, fps=f) + + return names + + def _assert_equal_tensor_to_pil(tensor, pil_image, msg=None): np_pil_image = np.array(pil_image) if np_pil_image.ndim == 2: diff --git a/test/test_datasets_download.py b/test/test_datasets_download.py index 8c2d575e01d..0cf86918575 100644 --- a/test/test_datasets_download.py +++ b/test/test_datasets_download.py @@ -22,8 +22,6 @@ USER_AGENT, ) -from common_utils import get_tmp_dir - def limit_requests_per_time(min_secs_between_requests=2.0): last_requests = {} @@ -166,16 +164,15 @@ def assert_url_is_accessible(url, timeout=5.0): urlopen(request, timeout=timeout) -def assert_file_downloads_correctly(url, md5, timeout=5.0): - with get_tmp_dir() as root: - file = path.join(root, path.basename(url)) - with assert_server_response_ok(): - with open(file, "wb") as fh: - request = Request(url, headers={"User-Agent": USER_AGENT}) - response = urlopen(request, timeout=timeout) - fh.write(response.read()) +def assert_file_downloads_correctly(url, md5, tmpdir, timeout=5.0): + file = path.join(tmpdir, path.basename(url)) + with assert_server_response_ok(): + with open(file, "wb") as fh: + request = Request(url, headers={"User-Agent": USER_AGENT}) + response = urlopen(request, timeout=timeout) + fh.write(response.read()) - assert check_integrity(file, md5=md5), "The MD5 checksums mismatch" + assert check_integrity(file, md5=md5), "The MD5 checksums mismatch" class DownloadConfig: diff --git a/test/test_datasets_samplers.py b/test/test_datasets_samplers.py index 7754c1a98e8..c76fd1849fc 100644 --- a/test/test_datasets_samplers.py +++ b/test/test_datasets_samplers.py @@ -13,104 +13,83 @@ from torchvision.datasets.video_utils import VideoClips, unfold from torchvision import get_video_backend -from common_utils import get_tmp_dir, assert_equal - - -@contextlib.contextmanager -def get_list_of_videos(num_videos=5, sizes=None, fps=None): - with get_tmp_dir() as tmp_dir: - names = [] - for i in range(num_videos): - if sizes is None: - size = 5 * (i + 1) - else: - size = sizes[i] - if fps is None: - f = 5 - else: - f = fps[i] - data = torch.randint(0, 256, (size, 300, 400, 3), dtype=torch.uint8) - name = os.path.join(tmp_dir, "{}.mp4".format(i)) - names.append(name) - io.write_video(name, data, fps=f) - - yield names +from common_utils import get_list_of_videos, assert_equal @pytest.mark.skipif(not io.video._av_available(), reason="this test requires av") class TestDatasetsSamplers: - def test_random_clip_sampler(self): - with get_list_of_videos(num_videos=3, sizes=[25, 25, 25]) as video_list: - video_clips = VideoClips(video_list, 5, 5) - sampler = RandomClipSampler(video_clips, 3) - 
assert len(sampler) == 3 * 3 - indices = torch.tensor(list(iter(sampler))) - videos = torch.div(indices, 5, rounding_mode='floor') - v_idxs, count = torch.unique(videos, return_counts=True) - assert_equal(v_idxs, torch.tensor([0, 1, 2])) - assert_equal(count, torch.tensor([3, 3, 3])) + def test_random_clip_sampler(self, tmpdir): + video_list = get_list_of_videos(tmpdir, num_videos=3, sizes=[25, 25, 25]) + video_clips = VideoClips(video_list, 5, 5) + sampler = RandomClipSampler(video_clips, 3) + assert len(sampler) == 3 * 3 + indices = torch.tensor(list(iter(sampler))) + videos = torch.div(indices, 5, rounding_mode='floor') + v_idxs, count = torch.unique(videos, return_counts=True) + assert_equal(v_idxs, torch.tensor([0, 1, 2])) + assert_equal(count, torch.tensor([3, 3, 3])) - def test_random_clip_sampler_unequal(self): - with get_list_of_videos(num_videos=3, sizes=[10, 25, 25]) as video_list: - video_clips = VideoClips(video_list, 5, 5) - sampler = RandomClipSampler(video_clips, 3) - assert len(sampler) == 2 + 3 + 3 - indices = list(iter(sampler)) - assert 0 in indices - assert 1 in indices - # remove elements of the first video, to simplify testing - indices.remove(0) - indices.remove(1) - indices = torch.tensor(indices) - 2 - videos = torch.div(indices, 5, rounding_mode='floor') - v_idxs, count = torch.unique(videos, return_counts=True) - assert_equal(v_idxs, torch.tensor([0, 1])) - assert_equal(count, torch.tensor([3, 3])) + def test_random_clip_sampler_unequal(self, tmpdir): + video_list = get_list_of_videos(tmpdir, num_videos=3, sizes=[10, 25, 25]) + video_clips = VideoClips(video_list, 5, 5) + sampler = RandomClipSampler(video_clips, 3) + assert len(sampler) == 2 + 3 + 3 + indices = list(iter(sampler)) + assert 0 in indices + assert 1 in indices + # remove elements of the first video, to simplify testing + indices.remove(0) + indices.remove(1) + indices = torch.tensor(indices) - 2 + videos = torch.div(indices, 5, rounding_mode='floor') + v_idxs, count = torch.unique(videos, return_counts=True) + assert_equal(v_idxs, torch.tensor([0, 1])) + assert_equal(count, torch.tensor([3, 3])) - def test_uniform_clip_sampler(self): - with get_list_of_videos(num_videos=3, sizes=[25, 25, 25]) as video_list: - video_clips = VideoClips(video_list, 5, 5) - sampler = UniformClipSampler(video_clips, 3) - assert len(sampler) == 3 * 3 - indices = torch.tensor(list(iter(sampler))) - videos = torch.div(indices, 5, rounding_mode='floor') - v_idxs, count = torch.unique(videos, return_counts=True) - assert_equal(v_idxs, torch.tensor([0, 1, 2])) - assert_equal(count, torch.tensor([3, 3, 3])) - assert_equal(indices, torch.tensor([0, 2, 4, 5, 7, 9, 10, 12, 14])) + def test_uniform_clip_sampler(self, tmpdir): + video_list = get_list_of_videos(tmpdir, num_videos=3, sizes=[25, 25, 25]) + video_clips = VideoClips(video_list, 5, 5) + sampler = UniformClipSampler(video_clips, 3) + assert len(sampler) == 3 * 3 + indices = torch.tensor(list(iter(sampler))) + videos = torch.div(indices, 5, rounding_mode='floor') + v_idxs, count = torch.unique(videos, return_counts=True) + assert_equal(v_idxs, torch.tensor([0, 1, 2])) + assert_equal(count, torch.tensor([3, 3, 3])) + assert_equal(indices, torch.tensor([0, 2, 4, 5, 7, 9, 10, 12, 14])) - def test_uniform_clip_sampler_insufficient_clips(self): - with get_list_of_videos(num_videos=3, sizes=[10, 25, 25]) as video_list: - video_clips = VideoClips(video_list, 5, 5) - sampler = UniformClipSampler(video_clips, 3) - assert len(sampler) == 3 * 3 - indices = 
torch.tensor(list(iter(sampler))) - assert_equal(indices, torch.tensor([0, 0, 1, 2, 4, 6, 7, 9, 11])) + def test_uniform_clip_sampler_insufficient_clips(self, tmpdir): + video_list = get_list_of_videos(tmpdir, num_videos=3, sizes=[10, 25, 25]) + video_clips = VideoClips(video_list, 5, 5) + sampler = UniformClipSampler(video_clips, 3) + assert len(sampler) == 3 * 3 + indices = torch.tensor(list(iter(sampler))) + assert_equal(indices, torch.tensor([0, 0, 1, 2, 4, 6, 7, 9, 11])) - def test_distributed_sampler_and_uniform_clip_sampler(self): - with get_list_of_videos(num_videos=3, sizes=[25, 25, 25]) as video_list: - video_clips = VideoClips(video_list, 5, 5) - clip_sampler = UniformClipSampler(video_clips, 3) + def test_distributed_sampler_and_uniform_clip_sampler(self, tmpdir): + video_list = get_list_of_videos(tmpdir, num_videos=3, sizes=[25, 25, 25]) + video_clips = VideoClips(video_list, 5, 5) + clip_sampler = UniformClipSampler(video_clips, 3) - distributed_sampler_rank0 = DistributedSampler( - clip_sampler, - num_replicas=2, - rank=0, - group_size=3, - ) - indices = torch.tensor(list(iter(distributed_sampler_rank0))) - assert len(distributed_sampler_rank0) == 6 - assert_equal(indices, torch.tensor([0, 2, 4, 10, 12, 14])) + distributed_sampler_rank0 = DistributedSampler( + clip_sampler, + num_replicas=2, + rank=0, + group_size=3, + ) + indices = torch.tensor(list(iter(distributed_sampler_rank0))) + assert len(distributed_sampler_rank0) == 6 + assert_equal(indices, torch.tensor([0, 2, 4, 10, 12, 14])) - distributed_sampler_rank1 = DistributedSampler( - clip_sampler, - num_replicas=2, - rank=1, - group_size=3, - ) - indices = torch.tensor(list(iter(distributed_sampler_rank1))) - assert len(distributed_sampler_rank1) == 6 - assert_equal(indices, torch.tensor([5, 7, 9, 0, 2, 4])) + distributed_sampler_rank1 = DistributedSampler( + clip_sampler, + num_replicas=2, + rank=1, + group_size=3, + ) + indices = torch.tensor(list(iter(distributed_sampler_rank1))) + assert len(distributed_sampler_rank1) == 6 + assert_equal(indices, torch.tensor([5, 7, 9, 0, 2, 4])) if __name__ == '__main__': diff --git a/test/test_datasets_utils.py b/test/test_datasets_utils.py index 0c2dc5260de..3d147608a59 100644 --- a/test/test_datasets_utils.py +++ b/test/test_datasets_utils.py @@ -12,7 +12,6 @@ import lzma import contextlib -from common_utils import get_tmp_dir from torchvision.datasets.utils import _COMPRESSED_FILE_OPENERS @@ -113,7 +112,7 @@ def test_detect_file_type_incompatible(self, file): utils._detect_file_type(file) @pytest.mark.parametrize('extension', [".bz2", ".gz", ".xz"]) - def test_decompress(self, extension): + def test_decompress(self, extension, tmpdir): def create_compressed(root, content="this is the content"): file = os.path.join(root, "file") compressed = f"{file}{extension}" @@ -124,21 +123,20 @@ def create_compressed(root, content="this is the content"): return compressed, file, content - with get_tmp_dir() as temp_dir: - compressed, file, content = create_compressed(temp_dir) + compressed, file, content = create_compressed(tmpdir) - utils._decompress(compressed) + utils._decompress(compressed) - assert os.path.exists(file) + assert os.path.exists(file) - with open(file, "r") as fh: - assert fh.read() == content + with open(file, "r") as fh: + assert fh.read() == content def test_decompress_no_compression(self): with pytest.raises(RuntimeError): utils._decompress("foo.tar") - def test_decompress_remove_finished(self): + def test_decompress_remove_finished(self, tmpdir): def 
create_compressed(root, content="this is the content"): file = os.path.join(root, "file") compressed = f"{file}.gz" @@ -148,12 +146,11 @@ def create_compressed(root, content="this is the content"): return compressed, file, content - with get_tmp_dir() as temp_dir: - compressed, file, content = create_compressed(temp_dir) + compressed, file, content = create_compressed(tmpdir) - utils.extract_archive(compressed, temp_dir, remove_finished=True) + utils.extract_archive(compressed, tmpdir, remove_finished=True) - assert not os.path.exists(compressed) + assert not os.path.exists(compressed) @pytest.mark.parametrize('extension', [".gz", ".xz"]) @pytest.mark.parametrize('remove_finished', [True, False]) @@ -166,7 +163,7 @@ def test_extract_archive_defer_to_decompress(self, extension, remove_finished, m mocked.assert_called_once_with(file, filename, remove_finished=remove_finished) - def test_extract_zip(self): + def test_extract_zip(self, tmpdir): def create_archive(root, content="this is the content"): file = os.path.join(root, "dst.txt") archive = os.path.join(root, "archive.zip") @@ -176,19 +173,18 @@ def create_archive(root, content="this is the content"): return archive, file, content - with get_tmp_dir() as temp_dir: - archive, file, content = create_archive(temp_dir) + archive, file, content = create_archive(tmpdir) - utils.extract_archive(archive, temp_dir) + utils.extract_archive(archive, tmpdir) - assert os.path.exists(file) + assert os.path.exists(file) - with open(file, "r") as fh: - assert fh.read() == content + with open(file, "r") as fh: + assert fh.read() == content @pytest.mark.parametrize('extension, mode', [ ('.tar', 'w'), ('.tar.gz', 'w:gz'), ('.tgz', 'w:gz'), ('.tar.xz', 'w:xz')]) - def test_extract_tar(self, extension, mode): + def test_extract_tar(self, extension, mode, tmpdir): def create_archive(root, extension, mode, content="this is the content"): src = os.path.join(root, "src.txt") dst = os.path.join(root, "dst.txt") @@ -202,15 +198,14 @@ def create_archive(root, extension, mode, content="this is the content"): return archive, dst, content - with get_tmp_dir() as temp_dir: - archive, file, content = create_archive(temp_dir, extension, mode) + archive, file, content = create_archive(tmpdir, extension, mode) - utils.extract_archive(archive, temp_dir) + utils.extract_archive(archive, tmpdir) - assert os.path.exists(file) + assert os.path.exists(file) - with open(file, "r") as fh: - assert fh.read() == content + with open(file, "r") as fh: + assert fh.read() == content def test_verify_str_arg(self): assert "a" == utils.verify_str_arg("a", "arg", ("a",)) diff --git a/test/test_datasets_video_utils.py b/test/test_datasets_video_utils.py index 00db0aad127..9671d1d8f4c 100644 --- a/test/test_datasets_video_utils.py +++ b/test/test_datasets_video_utils.py @@ -6,28 +6,7 @@ from torchvision import io from torchvision.datasets.video_utils import VideoClips, unfold -from common_utils import get_tmp_dir, assert_equal - - -@contextlib.contextmanager -def get_list_of_videos(num_videos=5, sizes=None, fps=None): - with get_tmp_dir() as tmp_dir: - names = [] - for i in range(num_videos): - if sizes is None: - size = 5 * (i + 1) - else: - size = sizes[i] - if fps is None: - f = 5 - else: - f = fps[i] - data = torch.randint(0, 256, (size, 300, 400, 3), dtype=torch.uint8) - name = os.path.join(tmp_dir, "{}.mp4".format(i)) - names.append(name) - io.write_video(name, data, fps=f) - - yield names +from common_utils import get_list_of_videos, assert_equal class TestVideo: @@ -58,40 +37,40 @@ 
def test_unfold(self): assert_equal(r, expected) @pytest.mark.skipif(not io.video._av_available(), reason="this test requires av") - def test_video_clips(self): - with get_list_of_videos(num_videos=3) as video_list: - video_clips = VideoClips(video_list, 5, 5, num_workers=2) - assert video_clips.num_clips() == 1 + 2 + 3 - for i, (v_idx, c_idx) in enumerate([(0, 0), (1, 0), (1, 1), (2, 0), (2, 1), (2, 2)]): - video_idx, clip_idx = video_clips.get_clip_location(i) - assert video_idx == v_idx - assert clip_idx == c_idx - - video_clips = VideoClips(video_list, 6, 6) - assert video_clips.num_clips() == 0 + 1 + 2 - for i, (v_idx, c_idx) in enumerate([(1, 0), (2, 0), (2, 1)]): - video_idx, clip_idx = video_clips.get_clip_location(i) - assert video_idx == v_idx - assert clip_idx == c_idx - - video_clips = VideoClips(video_list, 6, 1) - assert video_clips.num_clips() == 0 + (10 - 6 + 1) + (15 - 6 + 1) - for i, v_idx, c_idx in [(0, 1, 0), (4, 1, 4), (5, 2, 0), (6, 2, 1)]: - video_idx, clip_idx = video_clips.get_clip_location(i) - assert video_idx == v_idx - assert clip_idx == c_idx + def test_video_clips(self, tmpdir): + video_list = get_list_of_videos(tmpdir, num_videos=3) + video_clips = VideoClips(video_list, 5, 5, num_workers=2) + assert video_clips.num_clips() == 1 + 2 + 3 + for i, (v_idx, c_idx) in enumerate([(0, 0), (1, 0), (1, 1), (2, 0), (2, 1), (2, 2)]): + video_idx, clip_idx = video_clips.get_clip_location(i) + assert video_idx == v_idx + assert clip_idx == c_idx + + video_clips = VideoClips(video_list, 6, 6) + assert video_clips.num_clips() == 0 + 1 + 2 + for i, (v_idx, c_idx) in enumerate([(1, 0), (2, 0), (2, 1)]): + video_idx, clip_idx = video_clips.get_clip_location(i) + assert video_idx == v_idx + assert clip_idx == c_idx + + video_clips = VideoClips(video_list, 6, 1) + assert video_clips.num_clips() == 0 + (10 - 6 + 1) + (15 - 6 + 1) + for i, v_idx, c_idx in [(0, 1, 0), (4, 1, 4), (5, 2, 0), (6, 2, 1)]: + video_idx, clip_idx = video_clips.get_clip_location(i) + assert video_idx == v_idx + assert clip_idx == c_idx @pytest.mark.skipif(not io.video._av_available(), reason="this test requires av") - def test_video_clips_custom_fps(self): - with get_list_of_videos(num_videos=3, sizes=[12, 12, 12], fps=[3, 4, 6]) as video_list: - num_frames = 4 - for fps in [1, 3, 4, 10]: - video_clips = VideoClips(video_list, num_frames, num_frames, fps, num_workers=2) - for i in range(video_clips.num_clips()): - video, audio, info, video_idx = video_clips.get_clip(i) - assert video.shape[0] == num_frames - assert info["video_fps"] == fps - # TODO add tests checking that the content is right + def test_video_clips_custom_fps(self, tmpdir): + video_list = get_list_of_videos(tmpdir, num_videos=3, sizes=[12, 12, 12], fps=[3, 4, 6]) + num_frames = 4 + for fps in [1, 3, 4, 10]: + video_clips = VideoClips(video_list, num_frames, num_frames, fps, num_workers=2) + for i in range(video_clips.num_clips()): + video, audio, info, video_idx = video_clips.get_clip(i) + assert video.shape[0] == num_frames + assert info["video_fps"] == fps + # TODO add tests checking that the content is right def test_compute_clips_for_video(self): video_pts = torch.arange(30) diff --git a/test/test_image.py b/test/test_image.py index 7c6764dce64..5630d5d8226 100644 --- a/test/test_image.py +++ b/test/test_image.py @@ -9,7 +9,7 @@ import torch from PIL import Image, __version__ as PILLOW_VERSION import torchvision.transforms.functional as F -from common_utils import get_tmp_dir, needs_cuda, assert_equal +from common_utils import 
needs_cuda, assert_equal from torchvision.io.image import ( decode_png, decode_jpeg, encode_jpeg, write_jpeg, decode_image, read_file, @@ -197,74 +197,69 @@ def test_encode_png_errors(): pytest.param(png_path, id=_get_safe_image_name(png_path)) for png_path in get_images(IMAGE_DIR, ".png") ]) -def test_write_png(img_path): - with get_tmp_dir() as d: - pil_image = Image.open(img_path) - img_pil = torch.from_numpy(np.array(pil_image)) - img_pil = img_pil.permute(2, 0, 1) +def test_write_png(img_path, tmpdir): + pil_image = Image.open(img_path) + img_pil = torch.from_numpy(np.array(pil_image)) + img_pil = img_pil.permute(2, 0, 1) - filename, _ = os.path.splitext(os.path.basename(img_path)) - torch_png = os.path.join(d, '{0}_torch.png'.format(filename)) - write_png(img_pil, torch_png, compression_level=6) - saved_image = torch.from_numpy(np.array(Image.open(torch_png))) - saved_image = saved_image.permute(2, 0, 1) + filename, _ = os.path.splitext(os.path.basename(img_path)) + torch_png = os.path.join(tmpdir, '{0}_torch.png'.format(filename)) + write_png(img_pil, torch_png, compression_level=6) + saved_image = torch.from_numpy(np.array(Image.open(torch_png))) + saved_image = saved_image.permute(2, 0, 1) - assert_equal(img_pil, saved_image) + assert_equal(img_pil, saved_image) -def test_read_file(): - with get_tmp_dir() as d: - fname, content = 'test1.bin', b'TorchVision\211\n' - fpath = os.path.join(d, fname) - with open(fpath, 'wb') as f: - f.write(content) +def test_read_file(tmpdir): + fname, content = 'test1.bin', b'TorchVision\211\n' + fpath = os.path.join(tmpdir, fname) + with open(fpath, 'wb') as f: + f.write(content) - data = read_file(fpath) - expected = torch.tensor(list(content), dtype=torch.uint8) - os.unlink(fpath) - assert_equal(data, expected) + data = read_file(fpath) + expected = torch.tensor(list(content), dtype=torch.uint8) + os.unlink(fpath) + assert_equal(data, expected) with pytest.raises(RuntimeError, match="No such file or directory: 'tst'"): read_file('tst') -def test_read_file_non_ascii(): - with get_tmp_dir() as d: - fname, content = '日本語(Japanese).bin', b'TorchVision\211\n' - fpath = os.path.join(d, fname) - with open(fpath, 'wb') as f: - f.write(content) +def test_read_file_non_ascii(tmpdir): + fname, content = '日本語(Japanese).bin', b'TorchVision\211\n' + fpath = os.path.join(tmpdir, fname) + with open(fpath, 'wb') as f: + f.write(content) - data = read_file(fpath) - expected = torch.tensor(list(content), dtype=torch.uint8) - os.unlink(fpath) - assert_equal(data, expected) + data = read_file(fpath) + expected = torch.tensor(list(content), dtype=torch.uint8) + os.unlink(fpath) + assert_equal(data, expected) -def test_write_file(): - with get_tmp_dir() as d: - fname, content = 'test1.bin', b'TorchVision\211\n' - fpath = os.path.join(d, fname) - content_tensor = torch.tensor(list(content), dtype=torch.uint8) - write_file(fpath, content_tensor) +def test_write_file(tmpdir): + fname, content = 'test1.bin', b'TorchVision\211\n' + fpath = os.path.join(tmpdir, fname) + content_tensor = torch.tensor(list(content), dtype=torch.uint8) + write_file(fpath, content_tensor) - with open(fpath, 'rb') as f: - saved_content = f.read() - os.unlink(fpath) - assert content == saved_content + with open(fpath, 'rb') as f: + saved_content = f.read() + os.unlink(fpath) + assert content == saved_content -def test_write_file_non_ascii(): - with get_tmp_dir() as d: - fname, content = '日本語(Japanese).bin', b'TorchVision\211\n' - fpath = os.path.join(d, fname) - content_tensor = 
torch.tensor(list(content), dtype=torch.uint8) - write_file(fpath, content_tensor) +def test_write_file_non_ascii(tmpdir): + fname, content = '日本語(Japanese).bin', b'TorchVision\211\n' + fpath = os.path.join(tmpdir, fname) + content_tensor = torch.tensor(list(content), dtype=torch.uint8) + write_file(fpath, content_tensor) - with open(fpath, 'rb') as f: - saved_content = f.read() - os.unlink(fpath) - assert content == saved_content + with open(fpath, 'rb') as f: + saved_content = f.read() + os.unlink(fpath) + assert content == saved_content @pytest.mark.parametrize('shape', [ @@ -272,16 +267,15 @@ def test_write_file_non_ascii(): (60, 60), (105, 105), ]) -def test_read_1_bit_png(shape): +def test_read_1_bit_png(shape, tmpdir): np_rng = np.random.RandomState(0) - with get_tmp_dir() as root: - image_path = os.path.join(root, f'test_{shape}.png') - pixels = np_rng.rand(*shape) > 0.5 - img = Image.fromarray(pixels) - img.save(image_path) - img1 = read_image(image_path) - img2 = normalize_dimensions(torch.as_tensor(pixels * 255, dtype=torch.uint8)) - assert_equal(img1, img2) + image_path = os.path.join(tmpdir, f'test_{shape}.png') + pixels = np_rng.rand(*shape) > 0.5 + img = Image.fromarray(pixels) + img.save(image_path) + img1 = read_image(image_path) + img2 = normalize_dimensions(torch.as_tensor(pixels * 255, dtype=torch.uint8)) + assert_equal(img1, img2) @pytest.mark.parametrize('shape', [ @@ -293,16 +287,15 @@ def test_read_1_bit_png(shape): ImageReadMode.UNCHANGED, ImageReadMode.GRAY, ]) -def test_read_1_bit_png_consistency(shape, mode): +def test_read_1_bit_png_consistency(shape, mode, tmpdir): np_rng = np.random.RandomState(0) - with get_tmp_dir() as root: - image_path = os.path.join(root, f'test_{shape}.png') - pixels = np_rng.rand(*shape) > 0.5 - img = Image.fromarray(pixels) - img.save(image_path) - img1 = read_image(image_path, mode) - img2 = read_image(image_path, mode) - assert_equal(img1, img2) + image_path = os.path.join(tmpdir, f'test_{shape}.png') + pixels = np_rng.rand(*shape) > 0.5 + img = Image.fromarray(pixels) + img.save(image_path) + img1 = read_image(image_path, mode) + img2 = read_image(image_path, mode) + assert_equal(img1, img2) def test_read_interlaced_png(): @@ -427,28 +420,27 @@ def test_encode_jpeg_reference(img_path): pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(ENCODE_JPEG, ".jpg") ]) -def test_write_jpeg_reference(img_path): +def test_write_jpeg_reference(img_path, tmpdir): # FIXME: Remove this eventually, see test_encode_jpeg_reference - with get_tmp_dir() as d: - data = read_file(img_path) - img = decode_jpeg(data) + data = read_file(img_path) + img = decode_jpeg(data) - basedir = os.path.dirname(img_path) - filename, _ = os.path.splitext(os.path.basename(img_path)) - torch_jpeg = os.path.join( - d, '{0}_torch.jpg'.format(filename)) - pil_jpeg = os.path.join( - basedir, 'jpeg_write', '{0}_pil.jpg'.format(filename)) + basedir = os.path.dirname(img_path) + filename, _ = os.path.splitext(os.path.basename(img_path)) + torch_jpeg = os.path.join( + tmpdir, '{0}_torch.jpg'.format(filename)) + pil_jpeg = os.path.join( + basedir, 'jpeg_write', '{0}_pil.jpg'.format(filename)) - write_jpeg(img, torch_jpeg, quality=75) + write_jpeg(img, torch_jpeg, quality=75) - with open(torch_jpeg, 'rb') as f: - torch_bytes = f.read() + with open(torch_jpeg, 'rb') as f: + torch_bytes = f.read() - with open(pil_jpeg, 'rb') as f: - pil_bytes = f.read() + with open(pil_jpeg, 'rb') as f: + pil_bytes = f.read() - assert_equal(torch_bytes, pil_bytes) 
+ assert_equal(torch_bytes, pil_bytes) @pytest.mark.skipif(IS_WINDOWS, reason=( @@ -481,25 +473,24 @@ def test_encode_jpeg(img_path): pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(ENCODE_JPEG, ".jpg") ]) -def test_write_jpeg(img_path): - with get_tmp_dir() as d: - d = Path(d) - img = read_image(img_path) - pil_img = F.to_pil_image(img) +def test_write_jpeg(img_path, tmpdir): + tmpdir = Path(tmpdir) + img = read_image(img_path) + pil_img = F.to_pil_image(img) - torch_jpeg = str(d / 'torch.jpg') - pil_jpeg = str(d / 'pil.jpg') + torch_jpeg = str(tmpdir / 'torch.jpg') + pil_jpeg = str(tmpdir / 'pil.jpg') - write_jpeg(img, torch_jpeg, quality=75) - pil_img.save(pil_jpeg, quality=75) + write_jpeg(img, torch_jpeg, quality=75) + pil_img.save(pil_jpeg, quality=75) - with open(torch_jpeg, 'rb') as f: - torch_bytes = f.read() + with open(torch_jpeg, 'rb') as f: + torch_bytes = f.read() - with open(pil_jpeg, 'rb') as f: - pil_bytes = f.read() + with open(pil_jpeg, 'rb') as f: + pil_bytes = f.read() - assert_equal(torch_bytes, pil_bytes) + assert_equal(torch_bytes, pil_bytes) if __name__ == "__main__": diff --git a/test/test_internet.py b/test/test_internet.py index 8b1678f7b58..fd552961714 100644 --- a/test/test_internet.py +++ b/test/test_internet.py @@ -11,35 +11,31 @@ from urllib.error import URLError import torchvision.datasets.utils as utils -from common_utils import get_tmp_dir class TestDatasetUtils: - def test_download_url(self): - with get_tmp_dir() as temp_dir: - url = "http://github.com/pytorch/vision/archive/master.zip" - try: - utils.download_url(url, temp_dir) - assert len(os.listdir(temp_dir)) != 0 - except URLError: - pytest.skip(f"could not download test file '{url}'") - - def test_download_url_retry_http(self): - with get_tmp_dir() as temp_dir: - url = "https://github.com/pytorch/vision/archive/master.zip" - try: - utils.download_url(url, temp_dir) - assert len(os.listdir(temp_dir)) != 0 - except URLError: - pytest.skip(f"could not download test file '{url}'") - - def test_download_url_dont_exist(self): - with get_tmp_dir() as temp_dir: - url = "http://github.com/pytorch/vision/archive/this_doesnt_exist.zip" - with pytest.raises(URLError): - utils.download_url(url, temp_dir) - - def test_download_url_dispatch_download_from_google_drive(self, mocker): + def test_download_url(self, tmpdir): + url = "http://github.com/pytorch/vision/archive/master.zip" + try: + utils.download_url(url, tmpdir) + assert len(os.listdir(tmpdir)) != 0 + except URLError: + pytest.skip(f"could not download test file '{url}'") + + def test_download_url_retry_http(self, tmpdir): + url = "https://github.com/pytorch/vision/archive/master.zip" + try: + utils.download_url(url, tmpdir) + assert len(os.listdir(tmpdir)) != 0 + except URLError: + pytest.skip(f"could not download test file '{url}'") + + def test_download_url_dont_exist(self, tmpdir): + url = "http://github.com/pytorch/vision/archive/this_doesnt_exist.zip" + with pytest.raises(URLError): + utils.download_url(url, tmpdir) + + def test_download_url_dispatch_download_from_google_drive(self, mocker, tmpdir): url = "https://drive.google.com/file/d/1hbzc_P1FuxMkcabkgn9ZKinBwW683j45/view" id = "1hbzc_P1FuxMkcabkgn9ZKinBwW683j45" @@ -47,10 +43,9 @@ def test_download_url_dispatch_download_from_google_drive(self, mocker): md5 = "md5" mocked = mocker.patch('torchvision.datasets.utils.download_file_from_google_drive') - with get_tmp_dir() as root: - utils.download_url(url, root, filename, md5) + utils.download_url(url, 
tmpdir, filename, md5) - mocked.assert_called_once_with(id, root, filename, md5) + mocked.assert_called_once_with(id, tmpdir, filename, md5) if __name__ == '__main__': diff --git a/test/test_io.py b/test/test_io.py index 56cd0af5fd8..150d66f0814 100644 --- a/test/test_io.py +++ b/test/test_io.py @@ -9,7 +9,7 @@ import warnings from urllib.error import URLError -from common_utils import get_tmp_dir, assert_equal +from common_utils import assert_equal try: @@ -255,37 +255,36 @@ def test_read_video_partially_corrupted_file(self): assert_equal(video, data) @pytest.mark.skipif(sys.platform == 'win32', reason='temporarily disabled on Windows') - def test_write_video_with_audio(self): + def test_write_video_with_audio(self, tmpdir): f_name = os.path.join(VIDEO_DIR, "R6llTwEh07w.mp4") video_tensor, audio_tensor, info = io.read_video(f_name, pts_unit="sec") - with get_tmp_dir() as tmpdir: - out_f_name = os.path.join(tmpdir, "testing.mp4") - io.video.write_video( - out_f_name, - video_tensor, - round(info["video_fps"]), - video_codec="libx264rgb", - options={'crf': '0'}, - audio_array=audio_tensor, - audio_fps=info["audio_fps"], - audio_codec="aac", - ) - - out_video_tensor, out_audio_tensor, out_info = io.read_video( - out_f_name, pts_unit="sec" - ) - - assert info["video_fps"] == out_info["video_fps"] - assert_equal(video_tensor, out_video_tensor) - - audio_stream = av.open(f_name).streams.audio[0] - out_audio_stream = av.open(out_f_name).streams.audio[0] - - assert info["audio_fps"] == out_info["audio_fps"] - assert audio_stream.rate == out_audio_stream.rate - assert pytest.approx(out_audio_stream.frames, rel=0.0, abs=1) == audio_stream.frames - assert audio_stream.frame_size == out_audio_stream.frame_size + out_f_name = os.path.join(tmpdir, "testing.mp4") + io.video.write_video( + out_f_name, + video_tensor, + round(info["video_fps"]), + video_codec="libx264rgb", + options={'crf': '0'}, + audio_array=audio_tensor, + audio_fps=info["audio_fps"], + audio_codec="aac", + ) + + out_video_tensor, out_audio_tensor, out_info = io.read_video( + out_f_name, pts_unit="sec" + ) + + assert info["video_fps"] == out_info["video_fps"] + assert_equal(video_tensor, out_video_tensor) + + audio_stream = av.open(f_name).streams.audio[0] + out_audio_stream = av.open(out_f_name).streams.audio[0] + + assert info["audio_fps"] == out_info["audio_fps"] + assert audio_stream.rate == out_audio_stream.rate + assert pytest.approx(out_audio_stream.frames, rel=0.0, abs=1) == audio_stream.frames + assert audio_stream.frame_size == out_audio_stream.frame_size # TODO add tests for audio diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py index 0bf5d77716f..5081626fec4 100644 --- a/test/test_transforms_tensor.py +++ b/test/test_transforms_tensor.py @@ -230,7 +230,7 @@ def test_crop_pad(size, padding_config, device): @pytest.mark.parametrize('device', cpu_and_gpu()) -def test_center_crop(device): +def test_center_crop(device, tmpdir): fn_kwargs = {"output_size": (4, 5)} meth_kwargs = {"size": (4, 5), } _test_op( @@ -259,8 +259,7 @@ def test_center_crop(device): scripted_fn = torch.jit.script(f) scripted_fn(tensor) - with get_tmp_dir() as tmp_dir: - scripted_fn.save(os.path.join(tmp_dir, "t_center_crop.pt")) + scripted_fn.save(os.path.join(tmpdir, "t_center_crop.pt")) @pytest.mark.parametrize('device', cpu_and_gpu()) @@ -309,11 +308,10 @@ def test_x_crop(fn, method, out_length, size, device): @pytest.mark.parametrize('method', ["FiveCrop", "TenCrop"]) -def test_x_crop_save(method): +def 
test_x_crop_save(method, tmpdir): fn = getattr(T, method)(size=[5, ]) scripted_fn = torch.jit.script(fn) - with get_tmp_dir() as tmp_dir: - scripted_fn.save(os.path.join(tmp_dir, "t_op_list_{}.pt".format(method))) + scripted_fn.save(os.path.join(tmpdir, "t_op_list_{}.pt".format(method))) class TestResize: @@ -349,11 +347,10 @@ def test_resize_scripted(self, dt, size, max_size, interpolation, device): _test_transform_vs_scripted(transform, s_transform, tensor) _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors) - def test_resize_save(self): + def test_resize_save(self, tmpdir): transform = T.Resize(size=[32, ]) s_transform = torch.jit.script(transform) - with get_tmp_dir() as tmp_dir: - s_transform.save(os.path.join(tmp_dir, "t_resize.pt")) + s_transform.save(os.path.join(tmpdir, "t_resize.pt")) @pytest.mark.parametrize('device', cpu_and_gpu()) @pytest.mark.parametrize('scale', [(0.7, 1.2), [0.7, 1.2]]) @@ -368,11 +365,10 @@ def test_resized_crop(self, scale, ratio, size, interpolation, device): _test_transform_vs_scripted(transform, s_transform, tensor) _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors) - def test_resized_crop_save(self): + def test_resized_crop_save(self, tmpdir): transform = T.RandomResizedCrop(size=[32, ]) s_transform = torch.jit.script(transform) - with get_tmp_dir() as tmp_dir: - s_transform.save(os.path.join(tmp_dir, "t_resized_crop.pt")) + s_transform.save(os.path.join(tmpdir, "t_resized_crop.pt")) def _test_random_affine_helper(device, **kwargs): @@ -386,11 +382,10 @@ def _test_random_affine_helper(device, **kwargs): @pytest.mark.parametrize('device', cpu_and_gpu()) -def test_random_affine(device): +def test_random_affine(device, tmpdir): transform = T.RandomAffine(degrees=45.0) s_transform = torch.jit.script(transform) - with get_tmp_dir() as tmp_dir: - s_transform.save(os.path.join(tmp_dir, "t_random_affine.pt")) + s_transform.save(os.path.join(tmpdir, "t_random_affine.pt")) @pytest.mark.parametrize('device', cpu_and_gpu()) @@ -447,11 +442,10 @@ def test_random_rotate(device, center, expand, degrees, interpolation, fill): _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors) -def test_random_rotate_save(): +def test_random_rotate_save(tmpdir): transform = T.RandomRotation(degrees=45.0) s_transform = torch.jit.script(transform) - with get_tmp_dir() as tmp_dir: - s_transform.save(os.path.join(tmp_dir, "t_random_rotate.pt")) + s_transform.save(os.path.join(tmpdir, "t_random_rotate.pt")) @pytest.mark.parametrize('device', cpu_and_gpu()) @@ -473,11 +467,10 @@ def test_random_perspective(device, distortion_scale, interpolation, fill): _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors) -def test_random_perspective_save(): +def test_random_perspective_save(tmpdir): transform = T.RandomPerspective() s_transform = torch.jit.script(transform) - with get_tmp_dir() as tmp_dir: - s_transform.save(os.path.join(tmp_dir, "t_perspective.pt")) + s_transform.save(os.path.join(tmpdir, "t_perspective.pt")) @pytest.mark.parametrize('device', cpu_and_gpu()) @@ -519,11 +512,10 @@ def test_convert_image_dtype(device, in_dtype, out_dtype): _test_transform_vs_scripted_on_batch(fn, scripted_fn, in_batch_tensors) -def test_convert_image_dtype_save(): +def test_convert_image_dtype_save(tmpdir): fn = T.ConvertImageDtype(dtype=torch.uint8) scripted_fn = torch.jit.script(fn) - with get_tmp_dir() as tmp_dir: - scripted_fn.save(os.path.join(tmp_dir, "t_convert_dtype.pt")) + 
scripted_fn.save(os.path.join(tmpdir, "t_convert_dtype.pt")) @pytest.mark.parametrize('device', cpu_and_gpu()) @@ -541,11 +533,10 @@ def test_autoaugment(device, policy, fill): _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors) -def test_autoaugment_save(): +def test_autoaugment_save(tmpdir): transform = T.AutoAugment() s_transform = torch.jit.script(transform) - with get_tmp_dir() as tmp_dir: - s_transform.save(os.path.join(tmp_dir, "t_autoaugment.pt")) + s_transform.save(os.path.join(tmpdir, "t_autoaugment.pt")) @pytest.mark.parametrize('device', cpu_and_gpu()) @@ -567,11 +558,10 @@ def test_random_erasing(device, config): _test_transform_vs_scripted_on_batch(fn, scripted_fn, batch_tensors) -def test_random_erasing_save(): +def test_random_erasing_save(tmpdir): fn = T.RandomErasing(value=0.2) scripted_fn = torch.jit.script(fn) - with get_tmp_dir() as tmp_dir: - scripted_fn.save(os.path.join(tmp_dir, "t_random_erasing.pt")) + scripted_fn.save(os.path.join(tmpdir, "t_random_erasing.pt")) def test_random_erasing_with_invalid_data(): @@ -583,7 +573,7 @@ def test_random_erasing_with_invalid_data(): @pytest.mark.parametrize('device', cpu_and_gpu()) -def test_normalize(device): +def test_normalize(device, tmpdir): fn = T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) tensor, _ = _create_data(26, 34, device=device) @@ -598,12 +588,11 @@ def test_normalize(device): _test_transform_vs_scripted(fn, scripted_fn, tensor) _test_transform_vs_scripted_on_batch(fn, scripted_fn, batch_tensors) - with get_tmp_dir() as tmp_dir: - scripted_fn.save(os.path.join(tmp_dir, "t_norm.pt")) + scripted_fn.save(os.path.join(tmpdir, "t_norm.pt")) @pytest.mark.parametrize('device', cpu_and_gpu()) -def test_linear_transformation(device): +def test_linear_transformation(device, tmpdir): c, h, w = 3, 24, 32 tensor, _ = _create_data(h, w, channels=c, device=device) @@ -625,8 +614,7 @@ def test_linear_transformation(device): s_transformed_batch = scripted_fn(batch_tensors) assert_equal(transformed_batch, s_transformed_batch) - with get_tmp_dir() as tmp_dir: - scripted_fn.save(os.path.join(tmp_dir, "t_norm.pt")) + scripted_fn.save(os.path.join(tmpdir, "t_norm.pt")) @pytest.mark.parametrize('device', cpu_and_gpu())
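
Note: every hunk above follows the same mechanical pattern: the hand-rolled `get_tmp_dir()` context manager is dropped in favour of pytest's built-in `tmpdir` fixture, and the duplicated `get_list_of_videos` helper is consolidated into `test/common_utils.py`. A minimal sketch of how a new test could be written against the shared helper follows; the test name `test_video_clips_smoke` and the asserted clip count are illustrative assumptions, not part of this patch, and such a test would still need PyAV available (matching the existing `skipif` guards).

    import pytest
    from torchvision import io
    from torchvision.datasets.video_utils import VideoClips

    from common_utils import get_list_of_videos  # helper added to common_utils.py by this patch


    @pytest.mark.skipif(not io.video._av_available(), reason="this test requires av")
    def test_video_clips_smoke(tmpdir):  # hypothetical test, for illustration only
        # pytest creates `tmpdir` per test and removes it afterwards; no context manager needed.
        video_list = get_list_of_videos(tmpdir, num_videos=2, sizes=[10, 10])
        video_clips = VideoClips(video_list, clip_length_in_frames=5, frames_between_clips=5)
        # two 10-frame videos, non-overlapping 5-frame clips -> 2 clips per video
        assert video_clips.num_clips() == 2 + 2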