diff --git a/test/test_datasets.py b/test/test_datasets.py
index ce99412f26f..a947df16c4b 100644
--- a/test/test_datasets.py
+++ b/test/test_datasets.py
@@ -120,7 +120,8 @@ def test_imagefolder_empty(self):
         )
 
     @mock.patch('torchvision.datasets.mnist.download_and_extract_archive')
-    def test_mnist(self, mock_download_extract):
+    @mock.patch('torchvision.datasets.mnist.check_integrity', return_value=True)
+    def test_mnist(self, mock_download_extract, mock_check_integrity):
         num_examples = 30
         with mnist_root(num_examples, "MNIST") as root:
             dataset = torchvision.datasets.MNIST(root, download=True)
@@ -129,7 +130,8 @@ def test_mnist(self, mock_download_extract):
             self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)
 
     @mock.patch('torchvision.datasets.mnist.download_and_extract_archive')
-    def test_kmnist(self, mock_download_extract):
+    @mock.patch('torchvision.datasets.mnist.check_integrity', return_value=True)
+    def test_kmnist(self, mock_download_extract, mock_check_integrity):
         num_examples = 30
         with mnist_root(num_examples, "KMNIST") as root:
             dataset = torchvision.datasets.KMNIST(root, download=True)
@@ -138,7 +140,8 @@ def test_kmnist(self, mock_download_extract):
             self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)
 
     @mock.patch('torchvision.datasets.mnist.download_and_extract_archive')
-    def test_fashionmnist(self, mock_download_extract):
+    @mock.patch('torchvision.datasets.mnist.check_integrity', return_value=True)
+    def test_fashionmnist(self, mock_download_extract, mock_check_integrity):
         num_examples = 30
         with mnist_root(num_examples, "FashionMNIST") as root:
             dataset = torchvision.datasets.FashionMNIST(root, download=True)
diff --git a/torchvision/datasets/mnist.py b/torchvision/datasets/mnist.py
index 123ce40cb8e..e356f17dd1b 100644
--- a/torchvision/datasets/mnist.py
+++ b/torchvision/datasets/mnist.py
@@ -7,12 +7,10 @@
 import torch
 import codecs
 import string
-import gzip
-import lzma
-from typing import Any, Callable, Dict, IO, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple
 from urllib.error import URLError
-from .utils import download_url, download_and_extract_archive, extract_archive, \
-    verify_str_arg
+from .utils import download_and_extract_archive, extract_archive, verify_str_arg, check_integrity
+import shutil
 
 
 class MNIST(VisionDataset):
@@ -81,6 +79,10 @@ def __init__(
             target_transform=target_transform)
         self.train = train  # training set or test set
 
+        if self._check_legacy_exist():
+            self.data, self.targets = self._load_legacy_data()
+            return
+
         if download:
             self.download()
 
@@ -88,11 +90,31 @@ def __init__(
             raise RuntimeError('Dataset not found.' +
                                ' You can use download=True to download it')
 
-        if self.train:
-            data_file = self.training_file
-        else:
-            data_file = self.test_file
-        self.data, self.targets = torch.load(os.path.join(self.processed_folder, data_file))
+        self.data, self.targets = self._load_data()
+
+    def _check_legacy_exist(self):
+        processed_folder_exists = os.path.exists(self.processed_folder)
+        if not processed_folder_exists:
+            return False
+
+        return all(
+            check_integrity(os.path.join(self.processed_folder, file)) for file in (self.training_file, self.test_file)
+        )
+
+    def _load_legacy_data(self):
+        # This is for BC only. We no longer cache the data in a custom binary, but simply read from the raw data
+        # directly.
+        data_file = self.training_file if self.train else self.test_file
+        return torch.load(os.path.join(self.processed_folder, data_file))
+
+    def _load_data(self):
+        image_file = f"{'train' if self.train else 't10k'}-images-idx3-ubyte"
+        data = read_image_file(os.path.join(self.raw_folder, image_file))
+
+        label_file = f"{'train' if self.train else 't10k'}-labels-idx1-ubyte"
+        targets = read_label_file(os.path.join(self.raw_folder, label_file))
+
+        return data, targets
 
     def __getitem__(self, index: int) -> Tuple[Any, Any]:
         """
@@ -132,19 +154,18 @@ def class_to_idx(self) -> Dict[str, int]:
         return {_class: i for i, _class in enumerate(self.classes)}
 
     def _check_exists(self) -> bool:
-        return (os.path.exists(os.path.join(self.processed_folder,
-                                            self.training_file)) and
-                os.path.exists(os.path.join(self.processed_folder,
-                                            self.test_file)))
+        return all(
+            check_integrity(os.path.join(self.raw_folder, os.path.splitext(os.path.basename(url))[0]))
+            for url, _ in self.resources
+        )
 
     def download(self) -> None:
-        """Download the MNIST data if it doesn't exist in processed_folder already."""
+        """Download the MNIST data if it doesn't exist already."""
 
         if self._check_exists():
             return
 
         os.makedirs(self.raw_folder, exist_ok=True)
-        os.makedirs(self.processed_folder, exist_ok=True)
 
         # download files
         for filename, md5 in self.resources:
@@ -168,24 +189,6 @@ def download(self) -> None:
             else:
                 raise RuntimeError("Error downloading {}".format(filename))
 
-        # process and save as torch files
-        print('Processing...')
-
-        training_set = (
-            read_image_file(os.path.join(self.raw_folder, 'train-images-idx3-ubyte')),
-            read_label_file(os.path.join(self.raw_folder, 'train-labels-idx1-ubyte'))
-        )
-        test_set = (
-            read_image_file(os.path.join(self.raw_folder, 't10k-images-idx3-ubyte')),
-            read_label_file(os.path.join(self.raw_folder, 't10k-labels-idx1-ubyte'))
-        )
-        with open(os.path.join(self.processed_folder, self.training_file), 'wb') as f:
-            torch.save(training_set, f)
-        with open(os.path.join(self.processed_folder, self.test_file), 'wb') as f:
-            torch.save(test_set, f)
-
-        print('Done!')
-
     def extra_repr(self) -> str:
         return "Split: {}".format("Train" if self.train is True else "Test")
 
@@ -298,44 +301,39 @@ def _training_file(split) -> str:
     def _test_file(split) -> str:
         return 'test_{}.pt'.format(split)
 
+    @property
+    def _file_prefix(self) -> str:
+        return f"emnist-{self.split}-{'train' if self.train else 'test'}"
+
+    @property
+    def images_file(self) -> str:
+        return os.path.join(self.raw_folder, f"{self._file_prefix}-images-idx3-ubyte")
+
+    @property
+    def labels_file(self) -> str:
+        return os.path.join(self.raw_folder, f"{self._file_prefix}-labels-idx1-ubyte")
+
+    def _load_data(self):
+        return read_image_file(self.images_file), read_label_file(self.labels_file)
+
+    def _check_exists(self) -> bool:
+        return all(check_integrity(file) for file in (self.images_file, self.labels_file))
+
     def download(self) -> None:
-        """Download the EMNIST data if it doesn't exist in processed_folder already."""
-        import shutil
+        """Download the EMNIST data if it doesn't exist already."""
 
         if self._check_exists():
             return
 
         os.makedirs(self.raw_folder, exist_ok=True)
-        os.makedirs(self.processed_folder, exist_ok=True)
 
-        # download files
-        print('Downloading and extracting zip archive')
-        download_and_extract_archive(self.url, download_root=self.raw_folder, filename="emnist.zip",
-                                     remove_finished=True, md5=self.md5)
+        download_and_extract_archive(self.url, download_root=self.raw_folder, md5=self.md5)
         gzip_folder = os.path.join(self.raw_folder, 'gzip')
         for gzip_file in os.listdir(gzip_folder):
             if gzip_file.endswith('.gz'):
-                extract_archive(os.path.join(gzip_folder, gzip_file), gzip_folder)
-
-        # process and save as torch files
-        for split in self.splits:
-            print('Processing ' + split)
-            training_set = (
-                read_image_file(os.path.join(gzip_folder, 'emnist-{}-train-images-idx3-ubyte'.format(split))),
-                read_label_file(os.path.join(gzip_folder, 'emnist-{}-train-labels-idx1-ubyte'.format(split)))
-            )
-            test_set = (
-                read_image_file(os.path.join(gzip_folder, 'emnist-{}-test-images-idx3-ubyte'.format(split))),
-                read_label_file(os.path.join(gzip_folder, 'emnist-{}-test-labels-idx1-ubyte'.format(split)))
-            )
-            with open(os.path.join(self.processed_folder, self._training_file(split)), 'wb') as f:
-                torch.save(training_set, f)
-            with open(os.path.join(self.processed_folder, self._test_file(split)), 'wb') as f:
-                torch.save(test_set, f)
+                extract_archive(os.path.join(gzip_folder, gzip_file), self.raw_folder)
         shutil.rmtree(gzip_folder)
 
-        print('Done!')
-
 
 class QMNIST(MNIST):
     """`QMNIST <https://github.com/facebookresearch/qmnist>`_ Dataset.
@@ -404,40 +402,51 @@ def __init__(
             self.test_file = self.data_file
         super(QMNIST, self).__init__(root, train, **kwargs)
 
+    @property
+    def images_file(self) -> str:
+        (url, _), _ = self.resources[self.subsets[self.what]]
+        return os.path.join(self.raw_folder, os.path.splitext(os.path.basename(url))[0])
+
+    @property
+    def labels_file(self) -> str:
+        _, (url, _) = self.resources[self.subsets[self.what]]
+        return os.path.join(self.raw_folder, os.path.splitext(os.path.basename(url))[0])
+
+    def _check_exists(self) -> bool:
+        return all(check_integrity(file) for file in (self.images_file, self.labels_file))
+
+    def _load_data(self):
+        data = read_sn3_pascalvincent_tensor(self.images_file)
+        assert (data.dtype == torch.uint8)
+        assert (data.ndimension() == 3)
+
+        targets = read_sn3_pascalvincent_tensor(self.labels_file).long()
+        assert (targets.ndimension() == 2)
+
+        if self.what == 'test10k':
+            data = data[0:10000, :, :].clone()
+            targets = targets[0:10000, :].clone()
+        elif self.what == 'test50k':
+            data = data[10000:, :, :].clone()
+            targets = targets[10000:, :].clone()
+
+        return data, targets
+
     def download(self) -> None:
-        """Download the QMNIST data if it doesn't exist in processed_folder already.
+        """Download the QMNIST data if it doesn't exist already.
            Note that we only download what has been asked for (argument 'what').
         """
         if self._check_exists():
             return
+
         os.makedirs(self.raw_folder, exist_ok=True)
-        os.makedirs(self.processed_folder, exist_ok=True)
         split = self.resources[self.subsets[self.what]]
-        files = []
-
-        # download data files if not already there
         for url, md5 in split:
             filename = url.rpartition('/')[2]
             file_path = os.path.join(self.raw_folder, filename)
             if not os.path.isfile(file_path):
-                download_url(url, root=self.raw_folder, filename=filename, md5=md5)
-            files.append(file_path)
-
-        # process and save as torch files
-        print('Processing...')
-        data = read_sn3_pascalvincent_tensor(files[0])
-        assert(data.dtype == torch.uint8)
-        assert(data.ndimension() == 3)
-        targets = read_sn3_pascalvincent_tensor(files[1]).long()
-        assert(targets.ndimension() == 2)
-        if self.what == 'test10k':
-            data = data[0:10000, :, :].clone()
-            targets = targets[0:10000, :].clone()
-        if self.what == 'test50k':
-            data = data[10000:, :, :].clone()
-            targets = targets[10000:, :].clone()
-        with open(os.path.join(self.processed_folder, self.data_file), 'wb') as f:
-            torch.save((data, targets), f)
+                download_and_extract_archive(url, self.raw_folder, filename=filename, md5=md5)
 
     def __getitem__(self, index: int) -> Tuple[Any, Any]:
         # redefined to handle the compat flag
@@ -459,19 +468,6 @@ def get_int(b: bytes) -> int:
     return int(codecs.encode(b, 'hex'), 16)
 
 
-def open_maybe_compressed_file(path: Union[str, IO]) -> Union[IO, gzip.GzipFile]:
-    """Return a file object that possibly decompresses 'path' on the fly.
-       Decompression occurs when argument `path` is a string and ends with '.gz' or '.xz'.
-    """
-    if not isinstance(path, torch._six.string_classes):
-        return path
-    if path.endswith('.gz'):
-        return gzip.open(path, 'rb')
-    if path.endswith('.xz'):
-        return lzma.open(path, 'rb')
-    return open(path, 'rb')
-
-
 SN3_PASCALVINCENT_TYPEMAP = {
     8: (torch.uint8, np.uint8, np.uint8),
     9: (torch.int8, np.int8, np.int8),
@@ -482,12 +478,12 @@ def open_maybe_compressed_file(path: Union[str, IO]) -> Union[IO, gzip.GzipFile
 }
 
 
-def read_sn3_pascalvincent_tensor(path: Union[str, IO], strict: bool = True) -> torch.Tensor:
+def read_sn3_pascalvincent_tensor(path: str, strict: bool = True) -> torch.Tensor:
     """Read a SN3 file in "Pascal Vincent" format (Lush file 'libidx/idx-io.lsh').
        Argument may be a filename, compressed filename, or file object.
     """
     # read
-    with open_maybe_compressed_file(path) as f:
+    with open(path, "rb") as f:
         data = f.read()
     # parse
     magic = get_int(data[0:4])
@@ -503,16 +499,14 @@ def read_sn3_pascalvincent_tensor(path: Union[str, IO], strict: bool = True) ->
 
 
 def read_label_file(path: str) -> torch.Tensor:
-    with open(path, 'rb') as f:
-        x = read_sn3_pascalvincent_tensor(f, strict=False)
+    x = read_sn3_pascalvincent_tensor(path, strict=False)
     assert(x.dtype == torch.uint8)
     assert(x.ndimension() == 1)
     return x.long()
 
 
 def read_image_file(path: str) -> torch.Tensor:
-    with open(path, 'rb') as f:
-        x = read_sn3_pascalvincent_tensor(f, strict=False)
+    x = read_sn3_pascalvincent_tensor(path, strict=False)
    assert(x.dtype == torch.uint8)
    assert(x.ndimension() == 3)
    return x