Skip to content

Remove caching from MNIST and variants #3420

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Mar 16, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions test/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,8 @@ def test_imagefolder_empty(self):
)

@mock.patch('torchvision.datasets.mnist.download_and_extract_archive')
def test_mnist(self, mock_download_extract):
@mock.patch('torchvision.datasets.mnist.check_integrity', return_value=True)
def test_mnist(self, mock_download_extract, mock_check_integrity):
num_examples = 30
with mnist_root(num_examples, "MNIST") as root:
dataset = torchvision.datasets.MNIST(root, download=True)
Expand All @@ -129,7 +130,8 @@ def test_mnist(self, mock_download_extract):
self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)

@mock.patch('torchvision.datasets.mnist.download_and_extract_archive')
def test_kmnist(self, mock_download_extract):
@mock.patch('torchvision.datasets.mnist.check_integrity', return_value=True)
def test_kmnist(self, mock_download_extract, mock_check_integrity):
num_examples = 30
with mnist_root(num_examples, "KMNIST") as root:
dataset = torchvision.datasets.KMNIST(root, download=True)
Expand All @@ -138,7 +140,8 @@ def test_kmnist(self, mock_download_extract):
self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)

@mock.patch('torchvision.datasets.mnist.download_and_extract_archive')
def test_fashionmnist(self, mock_download_extract):
@mock.patch('torchvision.datasets.mnist.check_integrity', return_value=True)
def test_fashionmnist(self, mock_download_extract, mock_check_integrity):
num_examples = 30
with mnist_root(num_examples, "FashionMNIST") as root:
dataset = torchvision.datasets.FashionMNIST(root, download=True)
Expand Down
196 changes: 95 additions & 101 deletions torchvision/datasets/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,10 @@
import torch
import codecs
import string
import gzip
import lzma
from typing import Any, Callable, Dict, IO, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple
from urllib.error import URLError
from .utils import download_url, download_and_extract_archive, extract_archive, \
verify_str_arg
from .utils import download_and_extract_archive, extract_archive, verify_str_arg, check_integrity
import shutil


class MNIST(VisionDataset):
Expand Down Expand Up @@ -81,18 +79,42 @@ def __init__(
target_transform=target_transform)
self.train = train # training set or test set

if self._check_legacy_exist():
self.data, self.targets = self._load_legacy_data()
return

if download:
self.download()

if not self._check_exists():
raise RuntimeError('Dataset not found.' +
' You can use download=True to download it')

if self.train:
data_file = self.training_file
else:
data_file = self.test_file
self.data, self.targets = torch.load(os.path.join(self.processed_folder, data_file))
self.data, self.targets = self._load_data()

def _check_legacy_exist(self):
    """Return True if the legacy cached ``.pt`` files exist and pass the integrity check."""
    if not os.path.exists(self.processed_folder):
        return False

    legacy_files = (self.training_file, self.test_file)
    return all(
        check_integrity(os.path.join(self.processed_folder, name)) for name in legacy_files
    )

def _load_legacy_data(self):
    """Load the legacy cached tensors (backwards compatibility only).

    Newer versions no longer cache the data in a custom binary; they read the
    raw files directly instead.
    """
    name = self.training_file if self.train else self.test_file
    return torch.load(os.path.join(self.processed_folder, name))

def _load_data(self):
    """Read images and labels straight from the raw IDX files."""
    split = "train" if self.train else "t10k"

    data = read_image_file(os.path.join(self.raw_folder, f"{split}-images-idx3-ubyte"))
    targets = read_label_file(os.path.join(self.raw_folder, f"{split}-labels-idx1-ubyte"))

    return data, targets

def __getitem__(self, index: int) -> Tuple[Any, Any]:
"""
Expand Down Expand Up @@ -132,19 +154,18 @@ def class_to_idx(self) -> Dict[str, int]:
return {_class: i for i, _class in enumerate(self.classes)}

def _check_exists(self) -> bool:
return (os.path.exists(os.path.join(self.processed_folder,
self.training_file)) and
os.path.exists(os.path.join(self.processed_folder,
self.test_file)))
return all(
check_integrity(os.path.join(self.raw_folder, os.path.splitext(os.path.basename(url))[0]))
for url, _ in self.resources
)

def download(self) -> None:
"""Download the MNIST data if it doesn't exist in processed_folder already."""
"""Download the MNIST data if it doesn't exist already."""

if self._check_exists():
return

os.makedirs(self.raw_folder, exist_ok=True)
os.makedirs(self.processed_folder, exist_ok=True)

# download files
for filename, md5 in self.resources:
Expand All @@ -168,24 +189,6 @@ def download(self) -> None:
else:
raise RuntimeError("Error downloading {}".format(filename))

# process and save as torch files
print('Processing...')

training_set = (
read_image_file(os.path.join(self.raw_folder, 'train-images-idx3-ubyte')),
read_label_file(os.path.join(self.raw_folder, 'train-labels-idx1-ubyte'))
)
test_set = (
read_image_file(os.path.join(self.raw_folder, 't10k-images-idx3-ubyte')),
read_label_file(os.path.join(self.raw_folder, 't10k-labels-idx1-ubyte'))
)
with open(os.path.join(self.processed_folder, self.training_file), 'wb') as f:
torch.save(training_set, f)
with open(os.path.join(self.processed_folder, self.test_file), 'wb') as f:
torch.save(test_set, f)

print('Done!')

def extra_repr(self) -> str:
return "Split: {}".format("Train" if self.train is True else "Test")

Expand Down Expand Up @@ -298,44 +301,39 @@ def _training_file(split) -> str:
def _test_file(split) -> str:
return 'test_{}.pt'.format(split)

@property
def _file_prefix(self) -> str:
    """Common prefix of the raw EMNIST files for the current split and train/test flag."""
    section = "train" if self.train else "test"
    return f"emnist-{self.split}-{section}"

@property
def images_file(self) -> str:
    """Full path of the raw EMNIST image file."""
    return os.path.join(self.raw_folder, self._file_prefix + "-images-idx3-ubyte")

@property
def labels_file(self) -> str:
    """Full path of the raw EMNIST label file."""
    return os.path.join(self.raw_folder, self._file_prefix + "-labels-idx1-ubyte")

def _load_data(self):
    """Read images and labels from the raw EMNIST files."""
    images = read_image_file(self.images_file)
    labels = read_label_file(self.labels_file)
    return images, labels

def _check_exists(self) -> bool:
    """Return True if both raw files are present and pass the integrity check."""
    required = (self.images_file, self.labels_file)
    return all(check_integrity(path) for path in required)

def download(self) -> None:
"""Download the EMNIST data if it doesn't exist in processed_folder already."""
import shutil
"""Download the EMNIST data if it doesn't exist already."""

if self._check_exists():
return

os.makedirs(self.raw_folder, exist_ok=True)
os.makedirs(self.processed_folder, exist_ok=True)

# download files
print('Downloading and extracting zip archive')
download_and_extract_archive(self.url, download_root=self.raw_folder, filename="emnist.zip",
remove_finished=True, md5=self.md5)
download_and_extract_archive(self.url, download_root=self.raw_folder, md5=self.md5)
gzip_folder = os.path.join(self.raw_folder, 'gzip')
for gzip_file in os.listdir(gzip_folder):
if gzip_file.endswith('.gz'):
extract_archive(os.path.join(gzip_folder, gzip_file), gzip_folder)

# process and save as torch files
for split in self.splits:
print('Processing ' + split)
training_set = (
read_image_file(os.path.join(gzip_folder, 'emnist-{}-train-images-idx3-ubyte'.format(split))),
read_label_file(os.path.join(gzip_folder, 'emnist-{}-train-labels-idx1-ubyte'.format(split)))
)
test_set = (
read_image_file(os.path.join(gzip_folder, 'emnist-{}-test-images-idx3-ubyte'.format(split))),
read_label_file(os.path.join(gzip_folder, 'emnist-{}-test-labels-idx1-ubyte'.format(split)))
)
with open(os.path.join(self.processed_folder, self._training_file(split)), 'wb') as f:
torch.save(training_set, f)
with open(os.path.join(self.processed_folder, self._test_file(split)), 'wb') as f:
torch.save(test_set, f)
extract_archive(os.path.join(gzip_folder, gzip_file), self.raw_folder)
shutil.rmtree(gzip_folder)

print('Done!')


class QMNIST(MNIST):
"""`QMNIST <https://github.com/facebookresearch/qmnist>`_ Dataset.
Expand Down Expand Up @@ -404,40 +402,51 @@ def __init__(
self.test_file = self.data_file
super(QMNIST, self).__init__(root, train, **kwargs)

@property
def images_file(self) -> str:
    """Path of the raw (decompressed) image file for the selected subset."""
    image_resource = self.resources[self.subsets[self.what]][0]
    # Strip the compression extension from the archive name in the URL.
    stem, _ = os.path.splitext(os.path.basename(image_resource[0]))
    return os.path.join(self.raw_folder, stem)

@property
def labels_file(self) -> str:
    """Path of the raw (decompressed) label file for the selected subset."""
    label_resource = self.resources[self.subsets[self.what]][1]
    # Strip the compression extension from the archive name in the URL.
    stem, _ = os.path.splitext(os.path.basename(label_resource[0]))
    return os.path.join(self.raw_folder, stem)

def _check_exists(self) -> bool:
    """Return True if both raw files are present and pass the integrity check."""
    required = (self.images_file, self.labels_file)
    return all(check_integrity(path) for path in required)

def _load_data(self):
    """Read images and labels, slicing out the requested portion for the
    'test10k'/'test50k' pseudo-subsets (both carved from the extended test set)."""
    data = read_sn3_pascalvincent_tensor(self.images_file)
    # Sanity checks on the decoded raw data.
    assert data.dtype == torch.uint8
    assert data.ndimension() == 3

    targets = read_sn3_pascalvincent_tensor(self.labels_file).long()
    assert targets.ndimension() == 2

    if self.what == 'test10k':
        data, targets = data[:10000].clone(), targets[:10000].clone()
    elif self.what == 'test50k':
        data, targets = data[10000:].clone(), targets[10000:].clone()

    return data, targets

def download(self) -> None:
"""Download the QMNIST data if it doesn't exist in processed_folder already.
"""Download the QMNIST data if it doesn't exist already.
Note that we only download what has been asked for (argument 'what').
"""
if self._check_exists():
return

os.makedirs(self.raw_folder, exist_ok=True)
os.makedirs(self.processed_folder, exist_ok=True)
split = self.resources[self.subsets[self.what]]
files = []

# download data files if not already there
for url, md5 in split:
filename = url.rpartition('/')[2]
file_path = os.path.join(self.raw_folder, filename)
if not os.path.isfile(file_path):
download_url(url, root=self.raw_folder, filename=filename, md5=md5)
files.append(file_path)

# process and save as torch files
print('Processing...')
data = read_sn3_pascalvincent_tensor(files[0])
assert(data.dtype == torch.uint8)
assert(data.ndimension() == 3)
targets = read_sn3_pascalvincent_tensor(files[1]).long()
assert(targets.ndimension() == 2)
if self.what == 'test10k':
data = data[0:10000, :, :].clone()
targets = targets[0:10000, :].clone()
if self.what == 'test50k':
data = data[10000:, :, :].clone()
targets = targets[10000:, :].clone()
with open(os.path.join(self.processed_folder, self.data_file), 'wb') as f:
torch.save((data, targets), f)
download_and_extract_archive(url, self.raw_folder, filename=filename, md5=md5)

def __getitem__(self, index: int) -> Tuple[Any, Any]:
# redefined to handle the compat flag
Expand All @@ -459,19 +468,6 @@ def get_int(b: bytes) -> int:
return int(codecs.encode(b, 'hex'), 16)


def open_maybe_compressed_file(path: Union[str, IO]) -> Union[IO, gzip.GzipFile]:
    """Return a file object that possibly decompresses 'path' on the fly.
    Decompression occurs when argument `path` is a string and ends with '.gz' or '.xz'.
    """
    # Already a file-like object: hand it back unchanged.
    if not isinstance(path, torch._six.string_classes):
        return path
    if path.endswith('.gz'):
        return gzip.open(path, 'rb')
    if path.endswith('.xz'):
        return lzma.open(path, 'rb')
    # Plain, uncompressed file.
    return open(path, 'rb')


SN3_PASCALVINCENT_TYPEMAP = {
8: (torch.uint8, np.uint8, np.uint8),
9: (torch.int8, np.int8, np.int8),
Expand All @@ -482,12 +478,12 @@ def open_maybe_compressed_file(path: Union[str, IO]) -> Union[IO, gzip.GzipFile]
}


def read_sn3_pascalvincent_tensor(path: Union[str, IO], strict: bool = True) -> torch.Tensor:
def read_sn3_pascalvincent_tensor(path: str, strict: bool = True) -> torch.Tensor:
"""Read a SN3 file in "Pascal Vincent" format (Lush file 'libidx/idx-io.lsh').
Argument may be a filename, compressed filename, or file object.
"""
# read
with open_maybe_compressed_file(path) as f:
with open(path, "rb") as f:
data = f.read()
# parse
magic = get_int(data[0:4])
Expand All @@ -503,16 +499,14 @@ def read_sn3_pascalvincent_tensor(path: Union[str, IO], strict: bool = True) ->


def read_label_file(path: str) -> torch.Tensor:
with open(path, 'rb') as f:
x = read_sn3_pascalvincent_tensor(f, strict=False)
x = read_sn3_pascalvincent_tensor(path, strict=False)
assert(x.dtype == torch.uint8)
assert(x.ndimension() == 1)
return x.long()


def read_image_file(path: str) -> torch.Tensor:
with open(path, 'rb') as f:
x = read_sn3_pascalvincent_tensor(f, strict=False)
x = read_sn3_pascalvincent_tensor(path, strict=False)
assert(x.dtype == torch.uint8)
assert(x.ndimension() == 3)
return x