diff --git a/torchvision/datasets/__init__.py b/torchvision/datasets/__init__.py
index e2d2801216a..2cb623d84c7 100644
--- a/torchvision/datasets/__init__.py
+++ b/torchvision/datasets/__init__.py
@@ -9,10 +9,11 @@
 from .fakedata import FakeData
 from .semeion import SEMEION
 from .omniglot import Omniglot
+from .smallnorb import SmallNORB
 
 __all__ = ('LSUN', 'LSUNClass', 'ImageFolder', 'DatasetFolder', 'FakeData',
            'CocoCaptions', 'CocoDetection', 'CIFAR10', 'CIFAR100', 'EMNIST',
            'FashionMNIST', 'MNIST', 'STL10', 'SVHN', 'PhotoTour', 'SEMEION',
-           'Omniglot')
+           'Omniglot', 'SmallNORB')
diff --git a/torchvision/datasets/smallnorb.py b/torchvision/datasets/smallnorb.py
new file mode 100644
index 00000000000..2f33218c0de
--- /dev/null
+++ b/torchvision/datasets/smallnorb.py
@@ -0,0 +1,345 @@
+from __future__ import print_function
+import os
+import errno
+import struct
+
+import torch
+import torch.utils.data as data
+import numpy as np
+from PIL import Image
+from .utils import download_url, check_integrity
+
+
+class SmallNORB(data.Dataset):
+    """`SmallNORB <https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/>`_ Dataset.
+
+    Args:
+        root (string): Root directory of dataset where processed folder and
+            and raw folder exist.
+        train (bool, optional): If True, creates dataset from the training files,
+            otherwise from the test files.
+        download (bool, optional): If true, downloads the dataset from the internet and
+            puts it in root directory. If the dataset is already processed, it is not processed
+            and downloaded again. If dataset is only already downloaded, it is not
+            downloaded again.
+        transform (callable, optional): A function/transform that takes in an PIL image
+            and returns a transformed version. E.g, ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        info_transform (callable, optional): A function/transform that takes in the
+            info and transforms it.
+        mode (string, optional): Denotes how the images in the data files are returned.
+            Possible values:
+            - all (default): both left and right are included separately.
+            - stereo: left and right images are included as corresponding pairs.
+            - left: only the left images are included.
+            - right: only the right images are included.
+    """
+
+    dataset_root = "https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/"
+    data_files = {
+        'train': {
+            'dat': {
+                "name": 'smallnorb-5x46789x9x18x6x2x96x96-training-dat.mat',
+                "md5_gz": "66054832f9accfe74a0f4c36a75bc0a2",
+                "md5": "8138a0902307b32dfa0025a36dfa45ec"
+            },
+            'info': {
+                "name": 'smallnorb-5x46789x9x18x6x2x96x96-training-info.mat',
+                "md5_gz": "51dee1210a742582ff607dfd94e332e3",
+                "md5": "19faee774120001fc7e17980d6960451"
+            },
+            'cat': {
+                "name": 'smallnorb-5x46789x9x18x6x2x96x96-training-cat.mat',
+                "md5_gz": "23c8b86101fbf0904a000b43d3ed2fd9",
+                "md5": "fd5120d3f770ad57ebe620eb61a0b633"
+            },
+        },
+        'test': {
+            'dat': {
+                "name": 'smallnorb-5x01235x9x18x6x2x96x96-testing-dat.mat',
+                "md5_gz": "e4ad715691ed5a3a5f138751a4ceb071",
+                "md5": "e9920b7f7b2869a8f1a12e945b2c166c"
+            },
+            'info': {
+                "name": 'smallnorb-5x01235x9x18x6x2x96x96-testing-info.mat',
+                "md5_gz": "a9454f3864d7fd4bb3ea7fc3eb84924e",
+                "md5": "7c5b871cc69dcadec1bf6a18141f5edc"
+            },
+            'cat': {
+                "name": 'smallnorb-5x01235x9x18x6x2x96x96-testing-cat.mat',
+                "md5_gz": "5aa791cd7e6016cf957ce9bdb93b8603",
+                "md5": "fd5120d3f770ad57ebe620eb61a0b633"  # NOTE(review): identical to the training cat md5 -- verify against the actual file
+            },
+        },
+    }
+
+    raw_folder = 'raw'
+    processed_folder = 'processed'
+    train_image_file = 'train_img'
+    train_label_file = 'train_label'
+    train_info_file = 'train_info'
+    test_image_file = 'test_img'
+    test_label_file = 'test_label'
+    test_info_file = 'test_info'
+    extension = '.pt'
+
+    def __init__(self, root, train=True, transform=None, target_transform=None, info_transform=None, download=False,
+                 mode="all"):
+
+        self.root = os.path.expanduser(root)
+        self.transform = transform
+        self.target_transform = target_transform
+        self.info_transform = info_transform
+        self.train = train  # training set or test set
+        self.mode = mode
+
+        if download:
+            self.download()
+
+        if not self._check_exists():
+            raise RuntimeError('Dataset not found or corrupted.' +
+                               ' You can use download=True to download it')
+
+        # load test or train set
+        image_file = self.train_image_file if self.train else self.test_image_file
+        label_file = self.train_label_file if self.train else self.test_label_file
+        info_file = self.train_info_file if self.train else self.test_info_file
+
+        # load labels
+        self.labels = self._load(label_file)
+
+        # load info files
+        self.infos = self._load(info_file)
+
+        # load left set
+        if self.mode == "left":
+            self.data = self._load("{}_left".format(image_file))
+
+        # load right set
+        elif self.mode == "right":
+            self.data = self._load("{}_right".format(image_file))
+
+        elif self.mode == "all" or self.mode == "stereo":
+            left_data = self._load("{}_left".format(image_file))
+            right_data = self._load("{}_right".format(image_file))
+
+            # load stereo
+            if self.mode == "stereo":
+                self.data = torch.stack((left_data, right_data), dim=1)
+
+            # load all
+            else:
+                self.data = torch.cat((left_data, right_data), dim=0)
+
+    def __getitem__(self, index):
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            mode ``all'', ``left'', ``right'':
+                tuple: (image, target, info)
+            mode ``stereo'':
+                tuple: (image left, image right, target, info)
+        """
+        # in "all" mode left and right images are concatenated, so labels/infos
+        # repeat with period 24300 (the per-eye set size)
+        target = self.labels[index % 24300] if self.mode == "all" else self.labels[index]
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        info = self.infos[index % 24300] if self.mode == "all" else self.infos[index]
+        if self.info_transform is not None:
+            info = self.info_transform(info)
+
+        if self.mode == "stereo":
+            img_left = self._transform(self.data[index, 0])
+            img_right = self._transform(self.data[index, 1])
+            return img_left, img_right, target, info
+
+        img = self._transform(self.data[index])
+        return img, target, info
+
+    def __len__(self):
+        return len(self.data)
+
+    def _transform(self, img):
+        # doing this so that it is consistent with all other data sets
+        # to return a PIL Image
+        img = Image.fromarray(img.numpy(), mode='L')
+
+        if self.transform is not None:
+            img = self.transform(img)
+        return img
+
+    def _load(self, file_name):
+        return torch.load(os.path.join(self.root, self.processed_folder, file_name + self.extension))
+
+    def _save(self, file, file_name):
+        with open(os.path.join(self.root, self.processed_folder, file_name + self.extension), 'wb') as f:
+            torch.save(file, f)
+
+    def _check_exists(self):
+        """ Check if processed files exists."""
+        files = (
+            "{}_left".format(self.train_image_file),
+            "{}_right".format(self.train_image_file),
+            "{}_left".format(self.test_image_file),
+            "{}_right".format(self.test_image_file),
+            self.test_label_file,
+            self.train_label_file
+        )
+        fpaths = [os.path.exists(os.path.join(self.root, self.processed_folder, f + self.extension)) for f in files]
+        return all(fpaths)
+
+    def _flat_data_files(self):
+        return [j for i in self.data_files.values() for j in list(i.values())]
+
+    def _check_integrity(self):
+        """Check if unpacked files have correct md5 sum."""
+        root = self.root
+        for file_dict in self._flat_data_files():
+            filename = file_dict["name"]
+            md5 = file_dict["md5"]
+            fpath = os.path.join(root, self.raw_folder, filename)
+            if not check_integrity(fpath, md5):
+                return False
+        return True
+
+    def download(self):
+        """Download the SmallNORB data if it doesn't exist in processed_folder already."""
+        import gzip
+
+        if self._check_exists():
+            return
+
+        # check if already extracted and verified
+        if self._check_integrity():
+            print('Files already downloaded and verified')
+        else:
+            # download and extract
+            for file_dict in self._flat_data_files():
+                url = self.dataset_root + file_dict["name"] + '.gz'
+                filename = file_dict["name"]
+                gz_filename = filename + '.gz'
+                md5 = file_dict["md5_gz"]
+                fpath = os.path.join(self.root, self.raw_folder, filename)
+                gz_fpath = fpath + '.gz'
+
+                # download if compressed file not exists and verified
+                download_url(url, os.path.join(self.root, self.raw_folder), gz_filename, md5)
+
+                print('# Extracting data {}\n'.format(filename))
+
+                with open(fpath, 'wb') as out_f, \
+                        gzip.GzipFile(gz_fpath) as zip_f:
+                    out_f.write(zip_f.read())
+
+                os.unlink(gz_fpath)
+
+        # process and save as torch files
+        print('Processing...')
+
+        # create processed folder
+        try:
+            os.makedirs(os.path.join(self.root, self.processed_folder))
+        except OSError as e:
+            if e.errno == errno.EEXIST:
+                pass
+            else:
+                raise
+
+        # read train files
+        left_train_img, right_train_img = self._read_image_file(self.data_files["train"]["dat"]["name"])
+        train_info = self._read_info_file(self.data_files["train"]["info"]["name"])
+        train_label = self._read_label_file(self.data_files["train"]["cat"]["name"])
+
+        # read test files
+        left_test_img, right_test_img = self._read_image_file(self.data_files["test"]["dat"]["name"])
+        test_info = self._read_info_file(self.data_files["test"]["info"]["name"])
+        test_label = self._read_label_file(self.data_files["test"]["cat"]["name"])
+
+        # save training files
+        self._save(left_train_img, "{}_left".format(self.train_image_file))
+        self._save(right_train_img, "{}_right".format(self.train_image_file))
+        self._save(train_label, self.train_label_file)
+        self._save(train_info, self.train_info_file)
+
+        # save test files
+        self._save(left_test_img, "{}_left".format(self.test_image_file))
+        self._save(right_test_img, "{}_right".format(self.test_image_file))
+        self._save(test_label, self.test_label_file)
+        self._save(test_info, self.test_info_file)
+
+        print('Done!')
+
+    @staticmethod
+    def _parse_header(file_pointer):
+        # Read magic number and ignore
+        struct.unpack('