Remove download for ImageNet #1457

Merged: 15 commits, merged on Oct 21, 2019
8 changes: 4 additions & 4 deletions test/test_datasets.py

@@ -108,14 +108,14 @@ def test_fashionmnist(self, mock_download_extract):
         img, target = dataset[0]
         self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)
 
-    @mock.patch('torchvision.datasets.utils.download_url')
+    @mock.patch('torchvision.datasets.imagenet._verify_archive')
     @unittest.skipIf(not HAS_SCIPY, "scipy unavailable")
-    def test_imagenet(self, mock_download):
+    def test_imagenet(self, mock_verify):
        with imagenet_root() as root:
-            dataset = torchvision.datasets.ImageNet(root, split='train', download=True)
+            dataset = torchvision.datasets.ImageNet(root, split='train')
             self.generic_classification_dataset_test(dataset)
 
-            dataset = torchvision.datasets.ImageNet(root, split='val', download=True)
+            dataset = torchvision.datasets.ImageNet(root, split='val')
             self.generic_classification_dataset_test(dataset)
 
     @mock.patch('torchvision.datasets.cifar.check_integrity')
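For context, a rough sketch of the seam the updated test exercises: with _verify_archive patched out, constructing ImageNet performs no checksum or download work and only parses what already sits under the root. The path below is hypothetical; the real test builds a miniature root via its imagenet_root() helper.

from unittest import mock

import torchvision

# Hypothetical pre-populated root; the test suite's imagenet_root() context
# manager prepares something similar with tiny fake archives.
root = '/tmp/fake_imagenet'

with mock.patch('torchvision.datasets.imagenet._verify_archive'):
    # MD5 verification is stubbed out, so only local parsing happens.
    dataset = torchvision.datasets.ImageNet(root, split='val')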
247 changes: 147 additions & 100 deletions torchvision/datasets/imagenet.py
@@ -1,37 +1,27 @@
 from __future__ import print_function
+import warnings
+from contextlib import contextmanager
 import os
 import shutil
 import tempfile
 import torch
 from .folder import ImageFolder
-from .utils import check_integrity, download_and_extract_archive, extract_archive, \
-    verify_str_arg
+from .utils import check_integrity, extract_archive, verify_str_arg
 
-ARCHIVE_DICT = {
-    'train': {
-        'url': 'http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_img_train.tar',
-        'md5': '1d675b47d978889d74fa0da5fadfb00e',
-    },
-    'val': {
-        'url': 'http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_img_val.tar',
-        'md5': '29b22e2961454d5413ddabcf34fc5622',
-    },
-    'devkit': {
-        'url': 'http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_devkit_t12.tar.gz',
-        'md5': 'fa75699e90414af021442c21a62c3abf',
-    }
-}
+ARCHIVE_META = {
+    'train': ('ILSVRC2012_img_train.tar', '1d675b47d978889d74fa0da5fadfb00e'),
+    'val': ('ILSVRC2012_img_val.tar', '29b22e2961454d5413ddabcf34fc5622'),
+    'devkit': ('ILSVRC2012_devkit_t12.tar.gz', 'fa75699e90414af021442c21a62c3abf')
+}
+
+META_FILE = "meta.bin"
 
 
 class ImageNet(ImageFolder):
     """`ImageNet <http://image-net.org/>`_ 2012 Classification Dataset.
 
     Args:
         root (string): Root directory of the ImageNet Dataset.
         split (string, optional): The dataset split, supports ``train``, or ``val``.
-        download (bool, optional): If true, downloads the dataset from the internet and
-            puts it in root directory. If dataset is already downloaded, it is not
-            downloaded again.
         transform (callable, optional): A function/transform that takes in an PIL image
             and returns a transformed version. E.g, ``transforms.RandomCrop``
         target_transform (callable, optional): A function/transform that takes in the
@@ -47,13 +37,22 @@ class ImageNet(ImageFolder):
         targets (list): The class_index value for each image in the dataset
     """
 
-    def __init__(self, root, split='train', download=False, **kwargs):
+    def __init__(self, root, split='train', download=None, **kwargs):
+        if download is True:
+            msg = ("The dataset is no longer publicly accessible. You need to "
+                   "download the archives externally and place them in the root "
+                   "directory.")
+            raise RuntimeError(msg)
+        elif download is False:
+            msg = ("The use of the download flag is deprecated, since the dataset "
+                   "is no longer publicly accessible.")
+            warnings.warn(msg, RuntimeWarning)
+
         root = self.root = os.path.expanduser(root)
         self.split = verify_str_arg(split, "split", ("train", "val"))
 
-        if download:
-            self.download()
-        wnid_to_classes = self._load_meta_file()[0]
+        self.parse_archives()
+        wnid_to_classes = load_meta_file(self.root)[0]
 
         super(ImageNet, self).__init__(self.split_folder, **kwargs)
         self.root = root
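For readers skimming the diff, a minimal usage sketch of the new behaviour; the root path is hypothetical and must already contain the manually downloaded archives:

import torchvision

# Hypothetical root holding ILSVRC2012_img_train.tar, ILSVRC2012_img_val.tar
# and ILSVRC2012_devkit_t12.tar.gz, obtained outside torchvision.
root = '/data/imagenet'

# First use verifies and extracts the archives; no network access is made.
train_set = torchvision.datasets.ImageNet(root, split='train')

# download=True now fails fast instead of hitting the removed download URLs.
try:
    torchvision.datasets.ImageNet(root, split='train', download=True)
except RuntimeError as err:
    print(err)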
@@ -65,50 +64,15 @@ def __init__(self, root, split='train', download=False, **kwargs):
                              for idx, clss in enumerate(self.classes)
                              for cls in clss}
 
-    def download(self):
-        if not check_integrity(self.meta_file):
-            tmp_dir = tempfile.mkdtemp()
-
-            archive_dict = ARCHIVE_DICT['devkit']
-            download_and_extract_archive(archive_dict['url'], self.root,
-                                         extract_root=tmp_dir,
-                                         md5=archive_dict['md5'])
-            devkit_folder = _splitexts(os.path.basename(archive_dict['url']))[0]
-            meta = parse_devkit(os.path.join(tmp_dir, devkit_folder))
-            self._save_meta_file(*meta)
-
-            shutil.rmtree(tmp_dir)
+    def parse_archives(self):
+        if not check_integrity(os.path.join(self.root, META_FILE)):
+            parse_devkit_archive(self.root)
 
         if not os.path.isdir(self.split_folder):
-            archive_dict = ARCHIVE_DICT[self.split]
-            download_and_extract_archive(archive_dict['url'], self.root,
-                                         extract_root=self.split_folder,
-                                         md5=archive_dict['md5'])
-
             if self.split == 'train':
-                prepare_train_folder(self.split_folder)
+                parse_train_archive(self.root)
             elif self.split == 'val':
-                val_wnids = self._load_meta_file()[1]
-                prepare_val_folder(self.split_folder, val_wnids)
-        else:
-            msg = ("You set download=True, but a folder '{}' already exist in "
-                   "the root directory. If you want to re-download or re-extract the "
-                   "archive, delete the folder.")
-            print(msg.format(self.split))
-
-    @property
-    def meta_file(self):
-        return os.path.join(self.root, 'meta.bin')
-
-    def _load_meta_file(self):
-        if check_integrity(self.meta_file):
-            return torch.load(self.meta_file)
-        else:
-            raise RuntimeError("Meta file not found or corrupted.",
-                               "You can use download=True to create it.")
-
-    def _save_meta_file(self, wnid_to_class, val_wnids):
-        torch.save((wnid_to_class, val_wnids), self.meta_file)
+                parse_val_archive(self.root)
 
     @property
     def split_folder(self):
@@ -118,54 +82,137 @@ def extra_repr(self):
         return "Split: {split}".format(**self.__dict__)
 
 
-def parse_devkit(root):
-    idx_to_wnid, wnid_to_classes = parse_meta(root)
-    val_idcs = parse_val_groundtruth(root)
-    val_wnids = [idx_to_wnid[idx] for idx in val_idcs]
-    return wnid_to_classes, val_wnids
+def load_meta_file(root, file=None):
+    if file is None:
+        file = META_FILE
+    file = os.path.join(root, file)
+
+    if check_integrity(file):
+        return torch.load(file)
+    else:
+        msg = ("The meta file {} is not present in the root directory or is corrupted. "
+               "This file is automatically created by the ImageNet dataset.")
+        raise RuntimeError(msg.format(file, root))
+
+
+def _verify_archive(root, file, md5):
+    if not check_integrity(os.path.join(root, file), md5):
+        msg = ("The archive {} is not present in the root directory or is corrupted. "
+               "You need to download it externally and place it in {}.")
+        raise RuntimeError(msg.format(file, root))
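As a quick illustration of the meta file contract (the mapping values shown are illustrative examples, and the root path is hypothetical):

from torchvision.datasets.imagenet import load_meta_file

wnid_to_classes, val_wnids = load_meta_file('/data/imagenet')
# wnid_to_classes maps WNID -> tuple of class names, e.g.
# {'n01440764': ('tench', 'Tinca tinca'), ...}
# val_wnids lists one WNID per validation image, in ground-truth file order.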
 
 
-def parse_meta(devkit_root, path='data', filename='meta.mat'):
+def parse_devkit_archive(root, file=None):
+    """Parse the devkit archive of the ImageNet2012 classification dataset and save
+    the meta information in a binary file.
+
+    Args:
+        root (str): Root directory containing the devkit archive
+        file (str, optional): Name of devkit archive. Defaults to
+            'ILSVRC2012_devkit_t12.tar.gz'
+    """
     import scipy.io as sio
 
-    metafile = os.path.join(devkit_root, path, filename)
-    meta = sio.loadmat(metafile, squeeze_me=True)['synsets']
-    nums_children = list(zip(*meta))[4]
-    meta = [meta[idx] for idx, num_children in enumerate(nums_children)
-            if num_children == 0]
-    idcs, wnids, classes = list(zip(*meta))[:3]
-    classes = [tuple(clss.split(', ')) for clss in classes]
-    idx_to_wnid = {idx: wnid for idx, wnid in zip(idcs, wnids)}
-    wnid_to_classes = {wnid: clss for wnid, clss in zip(wnids, classes)}
-    return idx_to_wnid, wnid_to_classes
+    def parse_meta_mat(devkit_root):
+        metafile = os.path.join(devkit_root, "data", "meta.mat")
+        meta = sio.loadmat(metafile, squeeze_me=True)['synsets']
+        nums_children = list(zip(*meta))[4]
+        meta = [meta[idx] for idx, num_children in enumerate(nums_children)
+                if num_children == 0]
+        idcs, wnids, classes = list(zip(*meta))[:3]
+        classes = [tuple(clss.split(', ')) for clss in classes]
+        idx_to_wnid = {idx: wnid for idx, wnid in zip(idcs, wnids)}
+        wnid_to_classes = {wnid: clss for wnid, clss in zip(wnids, classes)}
+        return idx_to_wnid, wnid_to_classes
+
+    def parse_val_groundtruth_txt(devkit_root):
+        file = os.path.join(devkit_root, "data",
+                            "ILSVRC2012_validation_ground_truth.txt")
+        with open(file, 'r') as txtfh:
+            val_idcs = txtfh.readlines()
+        return [int(val_idx) for val_idx in val_idcs]
+
+    @contextmanager
+    def get_tmp_dir():
+        tmp_dir = tempfile.mkdtemp()
+        try:
+            yield tmp_dir
+        finally:
+            shutil.rmtree(tmp_dir)
+
+    archive_meta = ARCHIVE_META["devkit"]
+    if file is None:
+        file = archive_meta[0]
+    md5 = archive_meta[1]
+
+    _verify_archive(root, file, md5)
+
+    with get_tmp_dir() as tmp_dir:
+        extract_archive(os.path.join(root, file), tmp_dir)
+
+        devkit_root = os.path.join(tmp_dir, "ILSVRC2012_devkit_t12")
+        idx_to_wnid, wnid_to_classes = parse_meta_mat(devkit_root)
+        val_idcs = parse_val_groundtruth_txt(devkit_root)
+        val_wnids = [idx_to_wnid[idx] for idx in val_idcs]
+
+    torch.save((wnid_to_classes, val_wnids), os.path.join(root, META_FILE))
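Usage sketch for the devkit helper; the root path is hypothetical:

from torchvision.datasets.imagenet import parse_devkit_archive, load_meta_file

root = '/data/imagenet'  # must contain ILSVRC2012_devkit_t12.tar.gz
parse_devkit_archive(root)  # verifies the MD5, extracts to a temp dir, writes meta.bin
wnid_to_classes, val_wnids = load_meta_file(root)  # succeeds from here on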
 
 
-def parse_val_groundtruth(devkit_root, path='data',
-                          filename='ILSVRC2012_validation_ground_truth.txt'):
-    with open(os.path.join(devkit_root, path, filename), 'r') as txtfh:
-        val_idcs = txtfh.readlines()
-    return [int(val_idx) for val_idx in val_idcs]
-
-
-def prepare_train_folder(folder):
-    for archive in [os.path.join(folder, archive) for archive in os.listdir(folder)]:
+def parse_train_archive(root, file=None, folder="train"):
+    """Parse the train images archive of the ImageNet2012 classification dataset and
+    prepare it for usage with the ImageNet dataset.
+
+    Args:
+        root (str): Root directory containing the train images archive
+        file (str, optional): Name of train images archive. Defaults to
+            'ILSVRC2012_img_train.tar'
+        folder (str, optional): Optional name for train images folder. Defaults to
+            'train'
+    """
+    archive_meta = ARCHIVE_META["train"]
+    if file is None:
+        file = archive_meta[0]
+    md5 = archive_meta[1]
+
+    _verify_archive(root, file, md5)
+
+    train_root = os.path.join(root, folder)
+    extract_archive(os.path.join(root, file), train_root)
+
+    archives = [os.path.join(train_root, archive) for archive in os.listdir(train_root)]
+    for archive in archives:
         extract_archive(archive, os.path.splitext(archive)[0], remove_finished=True)
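Likewise for the train archive, which contains one sub-archive per class; each is extracted next to the tar and then removed. The root path is hypothetical:

from torchvision.datasets.imagenet import parse_train_archive

# Hypothetical root containing ILSVRC2012_img_train.tar.
parse_train_archive('/data/imagenet')
# /data/imagenet/train/ now holds one folder of images per WNID,
# e.g. train/n01440764/.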
 
 
-def prepare_val_folder(folder, wnids):
-    img_files = sorted([os.path.join(folder, file) for file in os.listdir(folder)])
-
-    for wnid in set(wnids):
-        os.mkdir(os.path.join(folder, wnid))
-
-    for wnid, img_file in zip(wnids, img_files):
-        shutil.move(img_file, os.path.join(folder, wnid, os.path.basename(img_file)))
-
-
-def _splitexts(root):
-    exts = []
-    ext = '.'
-    while ext:
-        root, ext = os.path.splitext(root)
-        exts.append(ext)
-    return root, ''.join(reversed(exts))
+def parse_val_archive(root, file=None, wnids=None, folder="val"):
+    """Parse the validation images archive of the ImageNet2012 classification dataset
+    and prepare it for usage with the ImageNet dataset.
+
+    Args:
+        root (str): Root directory containing the validation images archive
+        file (str, optional): Name of validation images archive. Defaults to
+            'ILSVRC2012_img_val.tar'
+        wnids (list, optional): List of WordNet IDs of the validation images. If None
+            is given, the IDs are loaded from the meta file in the root directory
+        folder (str, optional): Optional name for validation images folder. Defaults to
+            'val'
+    """
+    archive_meta = ARCHIVE_META["val"]
+    if file is None:
+        file = archive_meta[0]
+    md5 = archive_meta[1]
+    if wnids is None:
+        wnids = load_meta_file(root)[1]
+
+    _verify_archive(root, file, md5)
+
+    val_root = os.path.join(root, folder)
+    extract_archive(os.path.join(root, file), val_root)
+
+    images = sorted([os.path.join(val_root, image) for image in os.listdir(val_root)])
+
+    for wnid in set(wnids):
+        os.mkdir(os.path.join(val_root, wnid))
+
+    for wnid, img_file in zip(wnids, images):
+        shutil.move(img_file, os.path.join(val_root, wnid, os.path.basename(img_file)))
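And for the validation archive, which is flat on disk and gets sorted into per-WNID folders using the ground truth stored in meta.bin. The root path is hypothetical:

from torchvision.datasets.imagenet import parse_val_archive

# Hypothetical root containing ILSVRC2012_img_val.tar and meta.bin.
parse_val_archive('/data/imagenet')
# The validation images are moved into per-WNID subfolders of
# /data/imagenet/val/, the layout ImageFolder expects.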