Skip to content

Commit 61d7e23

Browse files
author
Philip Meier
committed
fix test
1 parent 4f13190 commit 61d7e23

File tree

2 files changed

+17
-17
lines changed

2 files changed

+17
-17
lines changed

test/test_datasets.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -108,14 +108,14 @@ def test_fashionmnist(self, mock_download_extract):
108108
img, target = dataset[0]
109109
self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)
110110

111-
@mock.patch('torchvision.datasets.utils.download_url')
111+
@mock.patch('torchvision.datasets.imagenet.ImageNet._verify_archive')
112112
@unittest.skipIf(not HAS_SCIPY, "scipy unavailable")
113-
def test_imagenet(self, mock_download):
113+
def test_imagenet(self, mock_check):
114114
with imagenet_root() as root:
115-
dataset = torchvision.datasets.ImageNet(root, split='train', download=True)
115+
dataset = torchvision.datasets.ImageNet(root, split='train')
116116
self.generic_classification_dataset_test(dataset)
117117

118-
dataset = torchvision.datasets.ImageNet(root, split='val', download=True)
118+
dataset = torchvision.datasets.ImageNet(root, split='val')
119119
self.generic_classification_dataset_test(dataset)
120120

121121
@mock.patch('torchvision.datasets.cifar.check_integrity')

torchvision/datasets/imagenet.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -73,29 +73,29 @@ def __init__(self, root, split='train', download=None, **kwargs):
7373
for cls in clss}
7474

7575
def extract_archives(self):
76-
def check_archive(archive_dict):
77-
file = archive_dict["file"]
78-
md5 = archive_dict["md5"]
79-
archive = os.path.join(self.root, file)
80-
if not check_integrity(archive, md5):
81-
msg = ("The file {} is not present in the root directory. You need to "
82-
"download it externally and place it in {}.")
83-
raise RuntimeError(msg.format(file, self.root))
84-
85-
return archive
86-
8776
if not check_integrity(self.meta_file):
88-
archive = check_archive(ARCHIVE_DICT['devkit'])
77+
archive_dict = ARCHIVE_DICT['devkit']
78+
archive = os.path.join(self.root, archive_dict["file"])
79+
self._verify_archive(archive, archive_dict["md5"])
80+
8981
parse_devkit_archive(archive)
9082

9183
if not os.path.isdir(self.split_folder):
92-
archive = check_archive(ARCHIVE_DICT[self.split])
84+
archive_dict = ARCHIVE_DICT[self.split]
85+
archive = os.path.join(self.root, archive_dict["file"])
86+
self._verify_archive(archive, archive_dict["md5"])
9387

9488
if self.split == 'train':
9589
parse_train_archive(archive)
9690
elif self.split == 'val':
9791
parse_val_archive(archive)
9892

93+
def _verify_archive(self, archive, md5):
94+
if not check_integrity(archive, md5):
95+
msg = ("The file {} is not present in the root directory or corrupted. "
96+
"You need to download it externally and place it in {}.")
97+
raise RuntimeError(msg.format(os.path.basename(archive), self.root))
98+
9999
@property
100100
def meta_file(self):
101101
return os.path.join(self.root, META_FILE_NAME)

0 commit comments

Comments
 (0)