
Commit c59f047

[WIP] Add tests for datasets (#966)
* WIP
* WIP: minor improvements
* Add tests
* Fix typo
* Use download_and_extract on caltech, cifar and omniglot
* Add a print message during extraction
* Remove EMNIST from test
1 parent 2b3a1b6 commit c59f047

7 files changed (+154 lines, -68 lines)


test/test_datasets.py

Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
+import PIL
+import shutil
+import tempfile
+import unittest
+
+import torchvision
+
+
+class Tester(unittest.TestCase):
+
+    def test_mnist(self):
+        tmp_dir = tempfile.mkdtemp()
+        dataset = torchvision.datasets.MNIST(tmp_dir, download=True)
+        self.assertEqual(len(dataset), 60000)
+        img, target = dataset[0]
+        self.assertTrue(isinstance(img, PIL.Image.Image))
+        self.assertTrue(isinstance(target, int))
+        shutil.rmtree(tmp_dir)
+
+    def test_kmnist(self):
+        tmp_dir = tempfile.mkdtemp()
+        dataset = torchvision.datasets.KMNIST(tmp_dir, download=True)
+        img, target = dataset[0]
+        self.assertTrue(isinstance(img, PIL.Image.Image))
+        self.assertTrue(isinstance(target, int))
+        shutil.rmtree(tmp_dir)
+
+    def test_fashionmnist(self):
+        tmp_dir = tempfile.mkdtemp()
+        dataset = torchvision.datasets.FashionMNIST(tmp_dir, download=True)
+        img, target = dataset[0]
+        self.assertTrue(isinstance(img, PIL.Image.Image))
+        self.assertTrue(isinstance(target, int))
+        shutil.rmtree(tmp_dir)
+
+
+if __name__ == '__main__':
+    unittest.main()
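The three tests above share one recipe: download the dataset into a throwaway directory, inspect the first sample, then delete the directory. A minimal sketch of that recipe factored into a reusable helper (the check_dataset name and the __main__ guard are illustrative additions, not part of the commit):

    import shutil
    import tempfile

    import PIL.Image
    import torchvision


    def check_dataset(dataset_cls, expected_len=None):
        # Download into a temporary directory, check the first sample, then clean up.
        tmp_dir = tempfile.mkdtemp()
        try:
            dataset = dataset_cls(tmp_dir, download=True)
            if expected_len is not None:
                assert len(dataset) == expected_len
            img, target = dataset[0]
            assert isinstance(img, PIL.Image.Image)
            assert isinstance(target, int)
        finally:
            shutil.rmtree(tmp_dir)


    if __name__ == '__main__':
        check_dataset(torchvision.datasets.MNIST, expected_len=60000)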

test/test_datasets_utils.py

Lines changed: 44 additions & 0 deletions
@@ -3,6 +3,9 @@
 import tempfile
 import torchvision.datasets.utils as utils
 import unittest
+import zipfile
+import tarfile
+import gzip

 TEST_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                          'assets', 'grace_hopper_517x606.jpg')
@@ -41,6 +44,47 @@ def test_download_url_retry_http(self):
         assert not len(os.listdir(temp_dir)) == 0, 'The downloaded root directory is empty after download.'
         shutil.rmtree(temp_dir)

+    def test_extract_zip(self):
+        temp_dir = tempfile.mkdtemp()
+        with tempfile.NamedTemporaryFile(suffix='.zip') as f:
+            with zipfile.ZipFile(f, 'w') as zf:
+                zf.writestr('file.tst', 'this is the content')
+            utils.extract_file(f.name, temp_dir)
+            assert os.path.exists(os.path.join(temp_dir, 'file.tst'))
+            with open(os.path.join(temp_dir, 'file.tst'), 'r') as nf:
+                data = nf.read()
+            assert data == 'this is the content'
+        shutil.rmtree(temp_dir)
+
+    def test_extract_tar(self):
+        for ext, mode in zip(['.tar', '.tar.gz'], ['w', 'w:gz']):
+            temp_dir = tempfile.mkdtemp()
+            with tempfile.NamedTemporaryFile() as bf:
+                bf.write("this is the content".encode())
+                bf.seek(0)
+                with tempfile.NamedTemporaryFile(suffix=ext) as f:
+                    with tarfile.open(f.name, mode=mode) as zf:
+                        zf.add(bf.name, arcname='file.tst')
+                    utils.extract_file(f.name, temp_dir)
+                    assert os.path.exists(os.path.join(temp_dir, 'file.tst'))
+                    with open(os.path.join(temp_dir, 'file.tst'), 'r') as nf:
+                        data = nf.read()
+                    assert data == 'this is the content', data
+            shutil.rmtree(temp_dir)
+
+    def test_extract_gzip(self):
+        temp_dir = tempfile.mkdtemp()
+        with tempfile.NamedTemporaryFile(suffix='.gz') as f:
+            with gzip.GzipFile(f.name, 'wb') as zf:
+                zf.write('this is the content'.encode())
+            utils.extract_file(f.name, temp_dir)
+            f_name = os.path.join(temp_dir, os.path.splitext(os.path.basename(f.name))[0])
+            assert os.path.exists(f_name)
+            with open(os.path.join(f_name), 'r') as nf:
+                data = nf.read()
+            assert data == 'this is the content', data
+        shutil.rmtree(temp_dir)
+

 if __name__ == '__main__':
     unittest.main()
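One detail the gzip test pins down: unlike tar and zip archives, a bare .gz file holds a single stream, so extract_file (exercised here and defined in torchvision/datasets/utils.py below) writes it to a file named after the archive minus its .gz suffix. A small sketch of just that path derivation, with illustrative file names:

    import os

    from_path = '/tmp/train-labels.csv.gz'   # hypothetical downloaded archive
    to_path = '/data'

    # Same expression the gzip branch and the test use to build the output path.
    out = os.path.join(to_path, os.path.splitext(os.path.basename(from_path))[0])
    assert out == '/data/train-labels.csv'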

torchvision/datasets/caltech.py

Lines changed: 16 additions & 28 deletions
@@ -4,7 +4,7 @@
 import os.path

 from .vision import VisionDataset
-from .utils import download_url, makedir_exist_ok
+from .utils import download_and_extract, makedir_exist_ok


 class Caltech101(VisionDataset):
@@ -109,27 +109,20 @@ def __len__(self):
         return len(self.index)

     def download(self):
-        import tarfile
-
         if self._check_integrity():
             print('Files already downloaded and verified')
             return

-        download_url("http://www.vision.caltech.edu/Image_Datasets/Caltech101/101_ObjectCategories.tar.gz",
-                     self.root,
-                     "101_ObjectCategories.tar.gz",
-                     "b224c7392d521a49829488ab0f1120d9")
-        download_url("http://www.vision.caltech.edu/Image_Datasets/Caltech101/Annotations.tar",
-                     self.root,
-                     "101_Annotations.tar",
-                     "6f83eeb1f24d99cab4eb377263132c91")
-
-        # extract file
-        with tarfile.open(os.path.join(self.root, "101_ObjectCategories.tar.gz"), "r:gz") as tar:
-            tar.extractall(path=self.root)
-
-        with tarfile.open(os.path.join(self.root, "101_Annotations.tar"), "r:") as tar:
-            tar.extractall(path=self.root)
+        download_and_extract(
+            "http://www.vision.caltech.edu/Image_Datasets/Caltech101/101_ObjectCategories.tar.gz",
+            self.root,
+            "101_ObjectCategories.tar.gz",
+            "b224c7392d521a49829488ab0f1120d9")
+        download_and_extract(
+            "http://www.vision.caltech.edu/Image_Datasets/Caltech101/Annotations.tar",
+            self.root,
+            "101_Annotations.tar",
+            "6f83eeb1f24d99cab4eb377263132c91")

     def extra_repr(self):
         return "Target type: {target_type}".format(**self.__dict__)
@@ -204,17 +197,12 @@ def __len__(self):
         return len(self.index)

     def download(self):
-        import tarfile
-
         if self._check_integrity():
             print('Files already downloaded and verified')
             return

-        download_url("http://www.vision.caltech.edu/Image_Datasets/Caltech256/256_ObjectCategories.tar",
-                     self.root,
-                     "256_ObjectCategories.tar",
-                     "67b4f42ca05d46448c6bb8ecd2220f6d")
-
-        # extract file
-        with tarfile.open(os.path.join(self.root, "256_ObjectCategories.tar"), "r:") as tar:
-            tar.extractall(path=self.root)
+        download_and_extract(
+            "http://www.vision.caltech.edu/Image_Datasets/Caltech256/256_ObjectCategories.tar",
+            self.root,
+            "256_ObjectCategories.tar",
+            "67b4f42ca05d46448c6bb8ecd2220f6d")

torchvision/datasets/cifar.py

Lines changed: 2 additions & 9 deletions
@@ -11,7 +11,7 @@
 import pickle

 from .vision import VisionDataset
-from .utils import download_url, check_integrity
+from .utils import check_integrity, download_and_extract


 class CIFAR10(VisionDataset):
@@ -144,17 +144,10 @@ def _check_integrity(self):
         return True

     def download(self):
-        import tarfile
-
         if self._check_integrity():
             print('Files already downloaded and verified')
             return
-
-        download_url(self.url, self.root, self.filename, self.tgz_md5)
-
-        # extract file
-        with tarfile.open(os.path.join(self.root, self.filename), "r:gz") as tar:
-            tar.extractall(path=self.root)
+        download_and_extract(self.url, self.root, self.filename, self.tgz_md5)

     def extra_repr(self):
         return "Split: {}".format("Train" if self.train is True else "Test")

torchvision/datasets/mnist.py

Lines changed: 5 additions & 24 deletions
@@ -4,11 +4,10 @@
 from PIL import Image
 import os
 import os.path
-import gzip
 import numpy as np
 import torch
 import codecs
-from .utils import download_url, makedir_exist_ok
+from .utils import download_and_extract, extract_file, makedir_exist_ok


 class MNIST(VisionDataset):
@@ -120,15 +119,6 @@ def _check_exists(self):
                 os.path.exists(os.path.join(self.processed_folder,
                                             self.test_file)))

-    @staticmethod
-    def extract_gzip(gzip_path, remove_finished=False):
-        print('Extracting {}'.format(gzip_path))
-        with open(gzip_path.replace('.gz', ''), 'wb') as out_f, \
-                gzip.GzipFile(gzip_path) as zip_f:
-            out_f.write(zip_f.read())
-        if remove_finished:
-            os.unlink(gzip_path)
-
     def download(self):
         """Download the MNIST data if it doesn't exist in processed_folder already."""

@@ -141,9 +131,7 @@ def download(self):
         # download files
         for url in self.urls:
             filename = url.rpartition('/')[2]
-            file_path = os.path.join(self.raw_folder, filename)
-            download_url(url, root=self.raw_folder, filename=filename, md5=None)
-            self.extract_gzip(gzip_path=file_path, remove_finished=True)
+            download_and_extract(url, root=self.raw_folder, filename=filename)

         # process and save as torch files
         print('Processing...')
@@ -262,7 +250,6 @@ def _test_file(split):
     def download(self):
         """Download the EMNIST data if it doesn't exist in processed_folder already."""
         import shutil
-        import zipfile

         if self._check_exists():
             return
@@ -271,18 +258,12 @@ def download(self):
         makedir_exist_ok(self.processed_folder)

         # download files
-        filename = self.url.rpartition('/')[2]
-        file_path = os.path.join(self.raw_folder, filename)
-        download_url(self.url, root=self.raw_folder, filename=filename, md5=None)
-
-        print('Extracting zip archive')
-        with zipfile.ZipFile(file_path) as zip_f:
-            zip_f.extractall(self.raw_folder)
-        os.unlink(file_path)
+        print('Downloading and extracting zip archive')
+        download_and_extract(self.url, root=self.raw_folder, filename="emnist.zip", remove_finished=True)
         gzip_folder = os.path.join(self.raw_folder, 'gzip')
         for gzip_file in os.listdir(gzip_folder):
             if gzip_file.endswith('.gz'):
-                self.extract_gzip(gzip_path=os.path.join(gzip_folder, gzip_file))
+                extract_file(os.path.join(gzip_folder, gzip_file), gzip_folder)

         # process and save as torch files
         for split in self.splits:

torchvision/datasets/omniglot.py

Lines changed: 2 additions & 7 deletions
@@ -3,7 +3,7 @@
 from os.path import join
 import os
 from .vision import VisionDataset
-from .utils import download_url, check_integrity, list_dir, list_files
+from .utils import download_and_extract, check_integrity, list_dir, list_files


 class Omniglot(VisionDataset):
@@ -81,19 +81,14 @@ def _check_integrity(self):
         return True

     def download(self):
-        import zipfile
-
         if self._check_integrity():
             print('Files already downloaded and verified')
             return

         filename = self._get_target_folder()
         zip_filename = filename + '.zip'
         url = self.download_url_prefix + '/' + zip_filename
-        download_url(url, self.root, zip_filename, self.zips_md5[filename])
-        print('Extracting downloaded file: ' + join(self.root, zip_filename))
-        with zipfile.ZipFile(join(self.root, zip_filename), 'r') as zip_file:
-            zip_file.extractall(self.root)
+        download_and_extract(url, self.root, zip_filename, self.zips_md5[filename])

     def _get_target_folder(self):
         return 'images_background' if self.background else 'images_evaluation'

torchvision/datasets/utils.py

Lines changed: 47 additions & 0 deletions
@@ -1,7 +1,11 @@
 import os
 import os.path
 import hashlib
+import gzip
 import errno
+import tarfile
+import zipfile
+

 from torch.utils.model_zoo import tqdm

@@ -189,3 +193,46 @@ def _save_response_content(response, destination, chunk_size=32768):
                 progress += len(chunk)
                 pbar.update(progress - pbar.n)
         pbar.close()
+
+
+def _is_tar(filename):
+    return filename.endswith(".tar")
+
+
+def _is_targz(filename):
+    return filename.endswith(".tar.gz")
+
+
+def _is_gzip(filename):
+    return filename.endswith(".gz") and not filename.endswith(".tar.gz")
+
+
+def _is_zip(filename):
+    return filename.endswith(".zip")
+
+
+def extract_file(from_path, to_path, remove_finished=False):
+    if _is_tar(from_path):
+        with tarfile.open(from_path, 'r:') as tar:
+            tar.extractall(path=to_path)
+    elif _is_targz(from_path):
+        with tarfile.open(from_path, 'r:gz') as tar:
+            tar.extractall(path=to_path)
+    elif _is_gzip(from_path):
+        to_path = os.path.join(to_path, os.path.splitext(os.path.basename(from_path))[0])
+        with open(to_path, "wb") as out_f, gzip.GzipFile(from_path) as zip_f:
+            out_f.write(zip_f.read())
+    elif _is_zip(from_path):
+        with zipfile.ZipFile(from_path, 'r') as z:
+            z.extractall(to_path)
+    else:
+        raise ValueError("Extraction of {} not supported".format(from_path))
+
+    if remove_finished:
+        os.unlink(from_path)
+
+
+def download_and_extract(url, root, filename, md5=None, remove_finished=False):
+    download_url(url, root, filename, md5)
+    print("Extracting {} to {}".format(os.path.join(root, filename), root))
+    extract_file(os.path.join(root, filename), root, remove_finished)
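Taken together, extract_file picks an extractor from the file suffix and download_and_extract chains download_url with it. A hedged usage sketch against the new helpers (paths, file names and the URL are placeholders, and a torchvision checkout containing this commit is assumed):

    from torchvision.datasets.utils import download_and_extract, extract_file

    # Dispatch is purely by suffix: .tar, .tar.gz, bare .gz, or .zip.
    extract_file("./data/images.tar.gz", "./data")    # tarfile, mode 'r:gz'
    extract_file("./data/archive.zip", "./data")      # zipfile
    extract_file("./data/labels.csv.gz", "./data")    # bare gzip -> ./data/labels.csv
    # Any other suffix raises ValueError; remove_finished=True deletes the archive afterwards.

    # Download into ./data, optionally verify an MD5, then extract next to the archive.
    download_and_extract("https://example.org/example.tar.gz",
                         "./data", "example.tar.gz", md5=None)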
