
Commit b671001

vincentqb authored and facebook-github-bot committed
[fbsync] Remove caching from MNIST and variants (#3420)
Summary:
* remove caching from (Fashion|K)?MNIST
* remove unnecessary lazy import
* remove false check of binaries against the md5 of archives
* remove caching from EMNIST
* remove caching from QMNIST
* lint
* fix EMNIST
* streamline QMNIST download

Reviewed By: fmassa

Differential Revision: D27127995

fbshipit-source-id: 3f53be72b5e7c8abe191edb1e4467e3ef33741dd
1 parent cb77aa3 commit b671001
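
In practice, the datasets now decode the raw IDX files on every instantiation instead of loading cached training.pt/test.pt binaries (old processed/ caches are still read for backward compatibility). A minimal usage sketch; the root path is a placeholder, not part of this commit:

import torchvision

# Reads <root>/MNIST/raw/*-ubyte directly; download=True fetches the archives
# into raw/, and no processed/ cache is written anymore.
dataset = torchvision.datasets.MNIST(root="./data", train=True, download=True)
img, target = dataset[0]  # PIL.Image.Image and int label decoded from the raw files
print(len(dataset), target)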


2 files changed: +101 -104 lines changed


test/test_datasets.py

Lines changed: 6 additions & 3 deletions
@@ -120,7 +120,8 @@ def test_imagefolder_empty(self):
                 )
 
     @mock.patch('torchvision.datasets.mnist.download_and_extract_archive')
-    def test_mnist(self, mock_download_extract):
+    @mock.patch('torchvision.datasets.mnist.check_integrity', return_value=True)
+    def test_mnist(self, mock_download_extract, mock_check_integrity):
         num_examples = 30
         with mnist_root(num_examples, "MNIST") as root:
             dataset = torchvision.datasets.MNIST(root, download=True)
@@ -129,7 +130,8 @@ def test_mnist(self, mock_download_extract):
             self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)
 
     @mock.patch('torchvision.datasets.mnist.download_and_extract_archive')
-    def test_kmnist(self, mock_download_extract):
+    @mock.patch('torchvision.datasets.mnist.check_integrity', return_value=True)
+    def test_kmnist(self, mock_download_extract, mock_check_integrity):
         num_examples = 30
         with mnist_root(num_examples, "KMNIST") as root:
             dataset = torchvision.datasets.KMNIST(root, download=True)
@@ -138,7 +140,8 @@ def test_kmnist(self, mock_download_extract):
             self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)
 
     @mock.patch('torchvision.datasets.mnist.download_and_extract_archive')
-    def test_fashionmnist(self, mock_download_extract):
+    @mock.patch('torchvision.datasets.mnist.check_integrity', return_value=True)
+    def test_fashionmnist(self, mock_download_extract, mock_check_integrity):
         num_examples = 30
         with mnist_root(num_examples, "FashionMNIST") as root:
             dataset = torchvision.datasets.FashionMNIST(root, download=True)
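
The extra check_integrity patch above is needed because _check_exists now probes the raw IDX files via check_integrity instead of looking for processed/*.pt, so a test that stubs out the download must stub the integrity check as well. A rough sketch of the pattern, assuming the fake raw files have already been written (e.g. by the suite's mnist_root fixture); the helper name is illustrative, not from the repo:

from unittest import mock

import torchvision

def load_fake_mnist(root):
    # Sketch only: `root` must already contain fake raw IDX files under MNIST/raw/.
    # Both patches keep the test offline: no archive download, no integrity probing.
    with mock.patch('torchvision.datasets.mnist.download_and_extract_archive'), \
         mock.patch('torchvision.datasets.mnist.check_integrity', return_value=True):
        return torchvision.datasets.MNIST(root, download=True)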

torchvision/datasets/mnist.py

Lines changed: 95 additions & 101 deletions
@@ -7,12 +7,10 @@
 import torch
 import codecs
 import string
-import gzip
-import lzma
-from typing import Any, Callable, Dict, IO, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple
 from urllib.error import URLError
-from .utils import download_url, download_and_extract_archive, extract_archive, \
-    verify_str_arg
+from .utils import download_and_extract_archive, extract_archive, verify_str_arg, check_integrity
+import shutil
 
 
 class MNIST(VisionDataset):
@@ -81,18 +79,42 @@ def __init__(
                                     target_transform=target_transform)
         self.train = train  # training set or test set
 
+        if self._check_legacy_exist():
+            self.data, self.targets = self._load_legacy_data()
+            return
+
         if download:
             self.download()
 
         if not self._check_exists():
             raise RuntimeError('Dataset not found.' +
                                ' You can use download=True to download it')
 
-        if self.train:
-            data_file = self.training_file
-        else:
-            data_file = self.test_file
-        self.data, self.targets = torch.load(os.path.join(self.processed_folder, data_file))
+        self.data, self.targets = self._load_data()
+
+    def _check_legacy_exist(self):
+        processed_folder_exists = os.path.exists(self.processed_folder)
+        if not processed_folder_exists:
+            return False
+
+        return all(
+            check_integrity(os.path.join(self.processed_folder, file)) for file in (self.training_file, self.test_file)
+        )
+
+    def _load_legacy_data(self):
+        # This is for BC only. We no longer cache the data in a custom binary, but simply read from the raw data
+        # directly.
+        data_file = self.training_file if self.train else self.test_file
+        return torch.load(os.path.join(self.processed_folder, data_file))
+
+    def _load_data(self):
+        image_file = f"{'train' if self.train else 't10k'}-images-idx3-ubyte"
+        data = read_image_file(os.path.join(self.raw_folder, image_file))
+
+        label_file = f"{'train' if self.train else 't10k'}-labels-idx1-ubyte"
+        targets = read_label_file(os.path.join(self.raw_folder, label_file))
+
+        return data, targets
 
     def __getitem__(self, index: int) -> Tuple[Any, Any]:
         """
@@ -132,19 +154,18 @@ def class_to_idx(self) -> Dict[str, int]:
         return {_class: i for i, _class in enumerate(self.classes)}
 
     def _check_exists(self) -> bool:
-        return (os.path.exists(os.path.join(self.processed_folder,
-                                            self.training_file)) and
-                os.path.exists(os.path.join(self.processed_folder,
-                                            self.test_file)))
+        return all(
+            check_integrity(os.path.join(self.raw_folder, os.path.splitext(os.path.basename(url))[0]))
+            for url, _ in self.resources
+        )
 
     def download(self) -> None:
-        """Download the MNIST data if it doesn't exist in processed_folder already."""
+        """Download the MNIST data if it doesn't exist already."""
 
         if self._check_exists():
             return
 
         os.makedirs(self.raw_folder, exist_ok=True)
-        os.makedirs(self.processed_folder, exist_ok=True)
 
         # download files
         for filename, md5 in self.resources:
@@ -168,24 +189,6 @@ def download(self) -> None:
             else:
                 raise RuntimeError("Error downloading {}".format(filename))
 
-        # process and save as torch files
-        print('Processing...')
-
-        training_set = (
-            read_image_file(os.path.join(self.raw_folder, 'train-images-idx3-ubyte')),
-            read_label_file(os.path.join(self.raw_folder, 'train-labels-idx1-ubyte'))
-        )
-        test_set = (
-            read_image_file(os.path.join(self.raw_folder, 't10k-images-idx3-ubyte')),
-            read_label_file(os.path.join(self.raw_folder, 't10k-labels-idx1-ubyte'))
-        )
-        with open(os.path.join(self.processed_folder, self.training_file), 'wb') as f:
-            torch.save(training_set, f)
-        with open(os.path.join(self.processed_folder, self.test_file), 'wb') as f:
-            torch.save(test_set, f)
-
-        print('Done!')
-
     def extra_repr(self) -> str:
         return "Split: {}".format("Train" if self.train is True else "Test")
 
@@ -298,44 +301,39 @@ def _training_file(split) -> str:
     def _test_file(split) -> str:
         return 'test_{}.pt'.format(split)
 
+    @property
+    def _file_prefix(self) -> str:
+        return f"emnist-{self.split}-{'train' if self.train else 'test'}"
+
+    @property
+    def images_file(self) -> str:
+        return os.path.join(self.raw_folder, f"{self._file_prefix}-images-idx3-ubyte")
+
+    @property
+    def labels_file(self) -> str:
+        return os.path.join(self.raw_folder, f"{self._file_prefix}-labels-idx1-ubyte")
+
+    def _load_data(self):
+        return read_image_file(self.images_file), read_label_file(self.labels_file)
+
+    def _check_exists(self) -> bool:
+        return all(check_integrity(file) for file in (self.images_file, self.labels_file))
+
     def download(self) -> None:
-        """Download the EMNIST data if it doesn't exist in processed_folder already."""
-        import shutil
+        """Download the EMNIST data if it doesn't exist already."""
 
         if self._check_exists():
             return
 
         os.makedirs(self.raw_folder, exist_ok=True)
-        os.makedirs(self.processed_folder, exist_ok=True)
 
-        # download files
-        print('Downloading and extracting zip archive')
-        download_and_extract_archive(self.url, download_root=self.raw_folder, filename="emnist.zip",
-                                     remove_finished=True, md5=self.md5)
+        download_and_extract_archive(self.url, download_root=self.raw_folder, md5=self.md5)
         gzip_folder = os.path.join(self.raw_folder, 'gzip')
         for gzip_file in os.listdir(gzip_folder):
             if gzip_file.endswith('.gz'):
-                extract_archive(os.path.join(gzip_folder, gzip_file), gzip_folder)
-
-        # process and save as torch files
-        for split in self.splits:
-            print('Processing ' + split)
-            training_set = (
-                read_image_file(os.path.join(gzip_folder, 'emnist-{}-train-images-idx3-ubyte'.format(split))),
-                read_label_file(os.path.join(gzip_folder, 'emnist-{}-train-labels-idx1-ubyte'.format(split)))
-            )
-            test_set = (
-                read_image_file(os.path.join(gzip_folder, 'emnist-{}-test-images-idx3-ubyte'.format(split))),
-                read_label_file(os.path.join(gzip_folder, 'emnist-{}-test-labels-idx1-ubyte'.format(split)))
-            )
-            with open(os.path.join(self.processed_folder, self._training_file(split)), 'wb') as f:
-                torch.save(training_set, f)
-            with open(os.path.join(self.processed_folder, self._test_file(split)), 'wb') as f:
-                torch.save(test_set, f)
+                extract_archive(os.path.join(gzip_folder, gzip_file), self.raw_folder)
         shutil.rmtree(gzip_folder)
 
-        print('Done!')
-
 
 class QMNIST(MNIST):
     """`QMNIST <https://github.com/facebookresearch/qmnist>`_ Dataset.
@@ -404,40 +402,51 @@ def __init__(
         self.test_file = self.data_file
         super(QMNIST, self).__init__(root, train, **kwargs)
 
+    @property
+    def images_file(self) -> str:
+        (url, _), _ = self.resources[self.subsets[self.what]]
+        return os.path.join(self.raw_folder, os.path.splitext(os.path.basename(url))[0])
+
+    @property
+    def labels_file(self) -> str:
+        _, (url, _) = self.resources[self.subsets[self.what]]
+        return os.path.join(self.raw_folder, os.path.splitext(os.path.basename(url))[0])
+
+    def _check_exists(self) -> bool:
+        return all(check_integrity(file) for file in (self.images_file, self.labels_file))
+
+    def _load_data(self):
+        data = read_sn3_pascalvincent_tensor(self.images_file)
+        assert (data.dtype == torch.uint8)
+        assert (data.ndimension() == 3)
+
+        targets = read_sn3_pascalvincent_tensor(self.labels_file).long()
+        assert (targets.ndimension() == 2)
+
+        if self.what == 'test10k':
+            data = data[0:10000, :, :].clone()
+            targets = targets[0:10000, :].clone()
+        elif self.what == 'test50k':
+            data = data[10000:, :, :].clone()
+            targets = targets[10000:, :].clone()
+
+        return data, targets
+
     def download(self) -> None:
-        """Download the QMNIST data if it doesn't exist in processed_folder already.
+        """Download the QMNIST data if it doesn't exist already.
            Note that we only download what has been asked for (argument 'what').
         """
         if self._check_exists():
             return
+
         os.makedirs(self.raw_folder, exist_ok=True)
-        os.makedirs(self.processed_folder, exist_ok=True)
         split = self.resources[self.subsets[self.what]]
-        files = []
 
-        # download data files if not already there
         for url, md5 in split:
             filename = url.rpartition('/')[2]
             file_path = os.path.join(self.raw_folder, filename)
             if not os.path.isfile(file_path):
-                download_url(url, root=self.raw_folder, filename=filename, md5=md5)
-            files.append(file_path)
-
-        # process and save as torch files
-        print('Processing...')
-        data = read_sn3_pascalvincent_tensor(files[0])
-        assert(data.dtype == torch.uint8)
-        assert(data.ndimension() == 3)
-        targets = read_sn3_pascalvincent_tensor(files[1]).long()
-        assert(targets.ndimension() == 2)
-        if self.what == 'test10k':
-            data = data[0:10000, :, :].clone()
-            targets = targets[0:10000, :].clone()
-        if self.what == 'test50k':
-            data = data[10000:, :, :].clone()
-            targets = targets[10000:, :].clone()
-        with open(os.path.join(self.processed_folder, self.data_file), 'wb') as f:
-            torch.save((data, targets), f)
+                download_and_extract_archive(url, self.raw_folder, filename=filename, md5=md5)
 
     def __getitem__(self, index: int) -> Tuple[Any, Any]:
         # redefined to handle the compat flag
@@ -459,19 +468,6 @@ def get_int(b: bytes) -> int:
     return int(codecs.encode(b, 'hex'), 16)
 
 
-def open_maybe_compressed_file(path: Union[str, IO]) -> Union[IO, gzip.GzipFile]:
-    """Return a file object that possibly decompresses 'path' on the fly.
-       Decompression occurs when argument `path` is a string and ends with '.gz' or '.xz'.
-    """
-    if not isinstance(path, torch._six.string_classes):
-        return path
-    if path.endswith('.gz'):
-        return gzip.open(path, 'rb')
-    if path.endswith('.xz'):
-        return lzma.open(path, 'rb')
-    return open(path, 'rb')
-
-
 SN3_PASCALVINCENT_TYPEMAP = {
     8: (torch.uint8, np.uint8, np.uint8),
     9: (torch.int8, np.int8, np.int8),
@@ -482,12 +478,12 @@ def open_maybe_compressed_file(path: Union[str, IO]) -> Union[IO, gzip.GzipFile]
 }
 
 
-def read_sn3_pascalvincent_tensor(path: Union[str, IO], strict: bool = True) -> torch.Tensor:
+def read_sn3_pascalvincent_tensor(path: str, strict: bool = True) -> torch.Tensor:
     """Read a SN3 file in "Pascal Vincent" format (Lush file 'libidx/idx-io.lsh').
        Argument may be a filename, compressed filename, or file object.
     """
     # read
-    with open_maybe_compressed_file(path) as f:
+    with open(path, "rb") as f:
         data = f.read()
     # parse
     magic = get_int(data[0:4])
@@ -503,16 +499,14 @@ def read_sn3_pascalvincent_tensor(path: Union[str, IO], strict: bool = True) ->
 
 
 def read_label_file(path: str) -> torch.Tensor:
-    with open(path, 'rb') as f:
-        x = read_sn3_pascalvincent_tensor(f, strict=False)
+    x = read_sn3_pascalvincent_tensor(path, strict=False)
     assert(x.dtype == torch.uint8)
     assert(x.ndimension() == 1)
     return x.long()
 
 
 def read_image_file(path: str) -> torch.Tensor:
-    with open(path, 'rb') as f:
-        x = read_sn3_pascalvincent_tensor(f, strict=False)
+    x = read_sn3_pascalvincent_tensor(path, strict=False)
     assert(x.dtype == torch.uint8)
     assert(x.ndimension() == 3)
     return x
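
With open_maybe_compressed_file removed, read_sn3_pascalvincent_tensor and its wrappers now expect already-extracted file paths rather than file objects or compressed files. A small sketch of reading the raw files directly; the paths are placeholders:

from torchvision.datasets.mnist import read_image_file, read_label_file

# Placeholder paths: point these at the extracted IDX files under <root>/MNIST/raw/.
images = read_image_file("data/MNIST/raw/train-images-idx3-ubyte")  # uint8 tensor, N x 28 x 28
labels = read_label_file("data/MNIST/raw/train-labels-idx1-ubyte")  # int64 tensor of length N
print(images.shape, labels.shape)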
