-
Notifications
You must be signed in to change notification settings - Fork 7.1k
[WIP] Add tests for datasets #966
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
6a73c83
bce0176
c3c708b
d8c0d9f
ce05448
05e64b0
b51e0b3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
import PIL | ||
import shutil | ||
import tempfile | ||
import unittest | ||
|
||
import torchvision | ||
|
||
|
||
class Tester(unittest.TestCase):
    """Smoke tests that download MNIST-style datasets and inspect one sample.

    Every test downloads its dataset into a fresh temporary directory and
    checks that the first item is a ``(PIL image, int)`` pair.
    """

    def setUp(self):
        # Register the removal as a cleanup instead of calling rmtree at the
        # end of each test: cleanups also run when the download or an
        # assertion fails, so the temporary directory is never leaked.
        self.tmp_dir = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, self.tmp_dir)

    def _check_first_sample(self, dataset):
        """Assert that ``dataset[0]`` unpacks to a PIL image and an int."""
        img, target = dataset[0]
        self.assertIsInstance(img, PIL.Image.Image)
        self.assertIsInstance(target, int)

    def test_mnist(self):
        dataset = torchvision.datasets.MNIST(self.tmp_dir, download=True)
        self.assertEqual(len(dataset), 60000)
        self._check_first_sample(dataset)

    def test_kmnist(self):
        dataset = torchvision.datasets.KMNIST(self.tmp_dir, download=True)
        self._check_first_sample(dataset)

    def test_fashionmnist(self):
        dataset = torchvision.datasets.FashionMNIST(self.tmp_dir, download=True)
        self._check_first_sample(dataset)
|
||
|
||
if __name__ == '__main__':
    # Hand control to unittest's runner when this file is executed directly.
    unittest.main()
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,11 @@ | ||
import os | ||
import os.path | ||
import hashlib | ||
import gzip | ||
import errno | ||
import tarfile | ||
import zipfile | ||
|
||
from torch.utils.model_zoo import tqdm | ||
|
||
|
||
|
@@ -189,3 +193,46 @@ def _save_response_content(response, destination, chunk_size=32768): | |
progress += len(chunk) | ||
pbar.update(progress - pbar.n) | ||
pbar.close() | ||
|
||
|
||
def _is_tar(filename): | ||
return filename.endswith(".tar") | ||
|
||
|
||
def _is_targz(filename): | ||
return filename.endswith(".tar.gz") | ||
|
||
|
||
def _is_gzip(filename): | ||
return filename.endswith(".gz") and not filename.endswith(".tar.gz") | ||
|
||
|
||
def _is_zip(filename): | ||
return filename.endswith(".zip") | ||
|
||
|
||
def extract_file(from_path, to_path, remove_finished=False):
    """Extract the archive or compressed file ``from_path`` into ``to_path``.

    The format is chosen from the file name suffix alone: ``.tar``,
    ``.tar.gz``, ``.gz`` (a single gzip-compressed file) and ``.zip``.

    Args:
        from_path: Path of the archive or compressed file to extract.
        to_path: Directory that receives the extracted content. For a plain
            ``.gz`` file the output is written inside this directory under
            the base name of ``from_path`` with its final extension removed.
        remove_finished: If True, delete ``from_path`` after extraction.

    Raises:
        ValueError: If the suffix of ``from_path`` is not a supported format.
    """
    # NOTE(security): extractall() writes wherever the archive's member names
    # point, so an archive from an untrusted source may place files outside
    # ``to_path``. Inspect the members first if the source is not trusted.
    if _is_tar(from_path):
        with tarfile.open(from_path, 'r:') as tar:
            tar.extractall(path=to_path)
    elif _is_targz(from_path):
        with tarfile.open(from_path, 'r:gz') as tar:
            tar.extractall(path=to_path)
    elif _is_gzip(from_path):
        import shutil

        out_name = os.path.splitext(os.path.basename(from_path))[0]
        out_path = os.path.join(to_path, out_name)
        # Open the source first so that a missing input does not leave an
        # empty output file behind, then copy in chunks rather than holding
        # the whole decompressed payload in memory.
        with gzip.GzipFile(from_path) as zip_f, open(out_path, "wb") as out_f:
            shutil.copyfileobj(zip_f, out_f)
    elif _is_zip(from_path):
        with zipfile.ZipFile(from_path, 'r') as z:
            z.extractall(to_path)
    else:
        raise ValueError("Extraction of {} not supported".format(from_path))

    if remove_finished:
        os.unlink(from_path)
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Pure curiosity: Why did you use […]? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. No other reason than just because it was what was used in […]. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Fair enough. |
||
|
||
|
||
def download_and_extract(url, root, filename, md5=None, remove_finished=False):
    """Download ``url`` into ``root`` and extract the downloaded file there.

    Args:
        url: Location of the archive to download.
        root: Directory that receives both the download and the extracted
            content. A leading ``~`` is expanded.
        filename: Name to store the download under inside ``root``. When it
            is empty or None, the last path component of ``url`` is used.
        md5: Optional checksum, passed through to ``download_url``.
        remove_finished: If True, the downloaded file is deleted once it has
            been extracted.
    """
    # Expand "~" once, up front, so the path given to extract_file() is built
    # from the same directory that is handed to download_url().
    # NOTE(review): download_url is defined outside this view; presumably it
    # expands the user directory itself - confirm.
    root = os.path.expanduser(root)
    # os.path.join() fails on a None filename, so settle on a concrete name
    # before any path is built and pass that same name to the download step.
    if not filename:
        filename = os.path.basename(url)

    download_url(url, root, filename, md5)

    archive = os.path.join(root, filename)
    print("Extracting {} to {}".format(archive, root))
    extract_file(archive, root, remove_finished)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Shouldn't we do a lazy import in `extract_file()`? AFAIK we do this now to prevent becoming dependent on these packages at import.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I believe `tarfile`, `zipfile` and `gzip` are part of the Python standard library, so I think this should be OK.