Skip to content

Commit 240792d

Browse files
authored
New tests for ImageNet dataset (#3543)
1 parent 814c4f0 commit 240792d

File tree

3 files changed

+34
-82
lines changed

3 files changed

+34
-82
lines changed

test/datasets_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -312,7 +312,8 @@ def create_dataset(
312312
patch_checks = inject_fake_data
313313

314314
special_kwargs, other_kwargs = self._split_kwargs(kwargs)
315-
if "download" in self._HAS_SPECIAL_KWARG:
315+
if "download" in self._HAS_SPECIAL_KWARG and special_kwargs.get("download", False):
316+
# override download param to False param if its default is truthy
316317
special_kwargs["download"] = False
317318
config.update(other_kwargs)
318319

test/fakedata_generation.py

Lines changed: 0 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -143,76 +143,6 @@ def _make_meta_file(file, classes_key):
143143
yield root
144144

145145

146-
@contextlib.contextmanager
147-
def imagenet_root():
148-
import scipy.io as sio
149-
150-
WNID = 'n01234567'
151-
CLS = 'fakedata'
152-
153-
def _make_image(file):
154-
PIL.Image.fromarray(np.zeros((32, 32, 3), dtype=np.uint8)).save(file)
155-
156-
def _make_tar(archive, content, arcname=None, compress=False):
157-
mode = 'w:gz' if compress else 'w'
158-
if arcname is None:
159-
arcname = os.path.basename(content)
160-
with tarfile.open(archive, mode) as fh:
161-
fh.add(content, arcname=arcname)
162-
163-
def _make_train_archive(root):
164-
with get_tmp_dir() as tmp:
165-
wnid_dir = os.path.join(tmp, WNID)
166-
os.mkdir(wnid_dir)
167-
168-
_make_image(os.path.join(wnid_dir, WNID + '_1.JPEG'))
169-
170-
wnid_archive = wnid_dir + '.tar'
171-
_make_tar(wnid_archive, wnid_dir)
172-
173-
train_archive = os.path.join(root, 'ILSVRC2012_img_train.tar')
174-
_make_tar(train_archive, wnid_archive)
175-
176-
def _make_val_archive(root):
177-
with get_tmp_dir() as tmp:
178-
val_image = os.path.join(tmp, 'ILSVRC2012_val_00000001.JPEG')
179-
_make_image(val_image)
180-
181-
val_archive = os.path.join(root, 'ILSVRC2012_img_val.tar')
182-
_make_tar(val_archive, val_image)
183-
184-
def _make_devkit_archive(root):
185-
with get_tmp_dir() as tmp:
186-
data_dir = os.path.join(tmp, 'data')
187-
os.mkdir(data_dir)
188-
189-
meta_file = os.path.join(data_dir, 'meta.mat')
190-
synsets = np.core.records.fromarrays([
191-
(0.0, 1.0),
192-
(WNID, ''),
193-
(CLS, ''),
194-
('fakedata for the torchvision testsuite', ''),
195-
(0.0, 1.0),
196-
], names=['ILSVRC2012_ID', 'WNID', 'words', 'gloss', 'num_children'])
197-
sio.savemat(meta_file, {'synsets': synsets})
198-
199-
groundtruth_file = os.path.join(data_dir,
200-
'ILSVRC2012_validation_ground_truth.txt')
201-
with open(groundtruth_file, 'w') as fh:
202-
fh.write('0\n')
203-
204-
devkit_name = 'ILSVRC2012_devkit_t12'
205-
devkit_archive = os.path.join(root, devkit_name + '.tar.gz')
206-
_make_tar(devkit_archive, tmp, arcname=devkit_name, compress=True)
207-
208-
with get_tmp_dir() as root:
209-
_make_train_archive(root)
210-
_make_val_archive(root)
211-
_make_devkit_archive(root)
212-
213-
yield root
214-
215-
216146
@contextlib.contextmanager
217147
def widerface_root():
218148
"""

test/test_datasets.py

Lines changed: 32 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import torchvision
1111
from torchvision.datasets import utils
1212
from common_utils import get_tmp_dir
13-
from fakedata_generation import mnist_root, imagenet_root, \
13+
from fakedata_generation import mnist_root, \
1414
cityscapes_root, svhn_root, places365_root, widerface_root, stl10_root
1515
import xml.etree.ElementTree as ET
1616
from urllib.request import Request, urlopen
@@ -146,16 +146,6 @@ def test_fashionmnist(self, mock_download_extract):
146146
img, target = dataset[0]
147147
self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)
148148

149-
@mock.patch('torchvision.datasets.imagenet._verify_archive')
150-
@unittest.skipIf(not HAS_SCIPY, "scipy unavailable")
151-
def test_imagenet(self, mock_verify):
152-
with imagenet_root() as root:
153-
dataset = torchvision.datasets.ImageNet(root, split='train')
154-
self.generic_classification_dataset_test(dataset)
155-
156-
dataset = torchvision.datasets.ImageNet(root, split='val')
157-
self.generic_classification_dataset_test(dataset)
158-
159149
@mock.patch('torchvision.datasets.WIDERFace._check_integrity')
160150
@unittest.skipIf('win' in sys.platform, 'temporarily disabled on Windows')
161151
def test_widerface(self, mock_check_integrity):
@@ -490,6 +480,37 @@ def inject_fake_data(self, tmpdir, config):
490480
return num_images_per_category * len(categories)
491481

492482

483+
class ImageNetTestCase(datasets_utils.ImageDatasetTestCase):
484+
DATASET_CLASS = datasets.ImageNet
485+
REQUIRED_PACKAGES = ('scipy',)
486+
CONFIGS = datasets_utils.combinations_grid(split=('train', 'val'))
487+
488+
def inject_fake_data(self, tmpdir, config):
489+
tmpdir = pathlib.Path(tmpdir)
490+
491+
wnid = 'n01234567'
492+
if config['split'] == 'train':
493+
num_examples = 3
494+
datasets_utils.create_image_folder(
495+
root=tmpdir,
496+
name=tmpdir / 'train' / wnid / wnid,
497+
file_name_fn=lambda image_idx: f"{wnid}_{image_idx}.JPEG",
498+
num_examples=num_examples,
499+
)
500+
else:
501+
num_examples = 1
502+
datasets_utils.create_image_folder(
503+
root=tmpdir,
504+
name=tmpdir / 'val' / wnid,
505+
file_name_fn=lambda image_ifx: "ILSVRC2012_val_0000000{image_idx}.JPEG",
506+
num_examples=num_examples,
507+
)
508+
509+
wnid_to_classes = {wnid: [1]}
510+
torch.save((wnid_to_classes, None), tmpdir / 'meta.bin')
511+
return num_examples
512+
513+
493514
class CIFAR10TestCase(datasets_utils.ImageDatasetTestCase):
494515
DATASET_CLASS = datasets.CIFAR10
495516
CONFIGS = datasets_utils.combinations_grid(train=(True, False))

0 commit comments

Comments
 (0)