diff --git a/torchvision/datasets/__init__.py b/torchvision/datasets/__init__.py
index e2d2801216a..2cb623d84c7 100644
--- a/torchvision/datasets/__init__.py
+++ b/torchvision/datasets/__init__.py
@@ -9,10 +9,11 @@
from .fakedata import FakeData
from .semeion import SEMEION
from .omniglot import Omniglot
+from .smallnorb import SmallNORB
__all__ = ('LSUN', 'LSUNClass',
'ImageFolder', 'DatasetFolder', 'FakeData',
'CocoCaptions', 'CocoDetection',
'CIFAR10', 'CIFAR100', 'EMNIST', 'FashionMNIST',
'MNIST', 'STL10', 'SVHN', 'PhotoTour', 'SEMEION',
- 'Omniglot')
+ 'Omniglot', 'SmallNORB')
diff --git a/torchvision/datasets/smallnorb.py b/torchvision/datasets/smallnorb.py
new file mode 100644
index 00000000000..2f33218c0de
--- /dev/null
+++ b/torchvision/datasets/smallnorb.py
@@ -0,0 +1,345 @@
+from __future__ import print_function
+import os
+import errno
+import struct
+
+import torch
+import torch.utils.data as data
+import numpy as np
+from PIL import Image
+from .utils import download_url, check_integrity
+
+
+class SmallNORB(data.Dataset):
+    """`small NORB <https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/>`_ Dataset.
+
+ Args:
+        root (string): Root directory of dataset where the processed folder
+            and the raw folder exist.
+ train (bool, optional): If True, creates dataset from the training files,
+ otherwise from the test files.
+ download (bool, optional): If true, downloads the dataset from the internet and
+ puts it in root directory. If the dataset is already processed, it is not processed
+ and downloaded again. If dataset is only already downloaded, it is not
+ downloaded again.
+ transform (callable, optional): A function/transform that takes in an PIL image
+ and returns a transformed version. E.g, ``transforms.RandomCrop``
+ target_transform (callable, optional): A function/transform that takes in the
+ target and transforms it.
+ info_transform (callable, optional): A function/transform that takes in the
+ info and transforms it.
+ mode (string, optional): Denotes how the images in the data files are returned. Possible values:
+ - all (default): both left and right are included separately.
+ - stereo: left and right images are included as corresponding pairs.
+ - left: only the left images are included.
+ - right: only the right images are included.
+ """
+
    # Base URL of the official small NORB distribution (gzipped .mat files).
    dataset_root = "https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/"
    # Source files per split: 'dat' holds the stereo image pairs, 'info' the
    # per-sample metadata, 'cat' the category labels.  "md5_gz" is the checksum
    # of the downloaded .gz archive, "md5" that of the extracted .mat file.
    # NOTE(review): train and test 'cat' entries carry the same "md5" value —
    # looks like a copy/paste slip; verify against the real files.
    data_files = {
        'train': {
            'dat': {
                "name": 'smallnorb-5x46789x9x18x6x2x96x96-training-dat.mat',
                "md5_gz": "66054832f9accfe74a0f4c36a75bc0a2",
                "md5": "8138a0902307b32dfa0025a36dfa45ec"
            },
            'info': {
                "name": 'smallnorb-5x46789x9x18x6x2x96x96-training-info.mat',
                "md5_gz": "51dee1210a742582ff607dfd94e332e3",
                "md5": "19faee774120001fc7e17980d6960451"
            },
            'cat': {
                "name": 'smallnorb-5x46789x9x18x6x2x96x96-training-cat.mat',
                "md5_gz": "23c8b86101fbf0904a000b43d3ed2fd9",
                "md5": "fd5120d3f770ad57ebe620eb61a0b633"
            },
        },
        'test': {
            'dat': {
                "name": 'smallnorb-5x01235x9x18x6x2x96x96-testing-dat.mat',
                "md5_gz": "e4ad715691ed5a3a5f138751a4ceb071",
                "md5": "e9920b7f7b2869a8f1a12e945b2c166c"
            },
            'info': {
                "name": 'smallnorb-5x01235x9x18x6x2x96x96-testing-info.mat',
                "md5_gz": "a9454f3864d7fd4bb3ea7fc3eb84924e",
                "md5": "7c5b871cc69dcadec1bf6a18141f5edc"
            },
            'cat': {
                "name": 'smallnorb-5x01235x9x18x6x2x96x96-testing-cat.mat',
                "md5_gz": "5aa791cd7e6016cf957ce9bdb93b8603",
                "md5": "fd5120d3f770ad57ebe620eb61a0b633"
            },
        },
    }

    # On-disk layout: downloads land in ``raw``, torch tensors in ``processed``.
    raw_folder = 'raw'
    processed_folder = 'processed'
    # Stems of the processed files; image stems get a "_left"/"_right" suffix
    # appended before ``extension`` (see _check_exists / download).
    train_image_file = 'train_img'
    train_label_file = 'train_label'
    train_info_file = 'train_info'
    test_image_file = 'test_img'
    test_label_file = 'test_label'
    test_info_file = 'test_info'
    extension = '.pt'
+
+ def __init__(self, root, train=True, transform=None, target_transform=None, info_transform=None, download=False,
+ mode="all"):
+
+ self.root = os.path.expanduser(root)
+ self.transform = transform
+ self.target_transform = target_transform
+ self.info_transform = info_transform
+ self.train = train # training set or test set
+ self.mode = mode
+
+ if download:
+ self.download()
+
+ if not self._check_exists():
+ raise RuntimeError('Dataset not found or corrupted.' +
+ ' You can use download=True to download it')
+
+ # load test or train set
+ image_file = self.train_image_file if self.train else self.test_image_file
+ label_file = self.train_label_file if self.train else self.test_label_file
+ info_file = self.train_info_file if self.train else self.test_info_file
+
+ # load labels
+ self.labels = self._load(label_file)
+
+ # load info files
+ self.infos = self._load(info_file)
+
+ # load right set
+ if self.mode == "left":
+ self.data = self._load("{}_left".format(image_file))
+
+ # load left set
+ elif self.mode == "right":
+ self.data = self._load("{}_right".format(image_file))
+
+ elif self.mode == "all" or self.mode == "stereo":
+ left_data = self._load("{}_left".format(image_file))
+ right_data = self._load("{}_right".format(image_file))
+
+ # load stereo
+ if self.mode == "stereo":
+ self.data = torch.stack((left_data, right_data), dim=1)
+
+ # load all
+ else:
+ self.data = torch.cat((left_data, right_data), dim=0)
+
+ def __getitem__(self, index):
+ """
+ Args:
+ index (int): Index
+
+ Returns:
+ mode ``all'', ``left'', ``right'':
+ tuple: (image, target, info)
+ mode ``stereo'':
+ tuple: (image left, image right, target, info)
+ """
+ target = self.labels[index % 24300] if self.mode is "all" else self.labels[index]
+ if self.target_transform is not None:
+ target = self.target_transform(target)
+
+ info = self.infos[index % 24300] if self.mode is "all" else self.infos[index]
+ if self.info_transform is not None:
+ info = self.info_transform(info)
+
+ if self.mode == "stereo":
+ img_left = self._transform(self.data[index, 0])
+ img_right = self._transform(self.data[index, 1])
+ return img_left, img_right, target, info
+
+ img = self._transform(self.data[index])
+ return img, target, info
+
+ def __len__(self):
+ return len(self.data)
+
+ def _transform(self, img):
+ # doing this so that it is consistent with all other data sets
+ # to return a PIL Image
+ img = Image.fromarray(img.numpy(), mode='L')
+
+ if self.transform is not None:
+ img = self.transform(img)
+ return img
+
+ def _load(self, file_name):
+ return torch.load(os.path.join(self.root, self.processed_folder, file_name + self.extension))
+
+ def _save(self, file, file_name):
+ with open(os.path.join(self.root, self.processed_folder, file_name + self.extension), 'wb') as f:
+ torch.save(file, f)
+
+ def _check_exists(self):
+ """ Check if processed files exists."""
+ files = (
+ "{}_left".format(self.train_image_file),
+ "{}_right".format(self.train_image_file),
+ "{}_left".format(self.test_image_file),
+ "{}_right".format(self.test_image_file),
+ self.test_label_file,
+ self.train_label_file
+ )
+ fpaths = [os.path.exists(os.path.join(self.root, self.processed_folder, f + self.extension)) for f in files]
+ return False not in fpaths
+
+ def _flat_data_files(self):
+ return [j for i in self.data_files.values() for j in list(i.values())]
+
+ def _check_integrity(self):
+ """Check if unpacked files have correct md5 sum."""
+ root = self.root
+ for file_dict in self._flat_data_files():
+ filename = file_dict["name"]
+ md5 = file_dict["md5"]
+ fpath = os.path.join(root, self.raw_folder, filename)
+ if not check_integrity(fpath, md5):
+ return False
+ return True
+
    def download(self):
        """Download the SmallNORB data if it doesn't exist in processed_folder already.

        Pipeline: fetch + verify each .gz archive, extract it into the raw
        folder, then parse the .mat files and save them as torch tensors in
        the processed folder.
        """
        import gzip

        # processed tensors already present — nothing to do
        if self._check_exists():
            return

        # check if already extracted and verified
        if self._check_integrity():
            print('Files already downloaded and verified')
        else:
            # download and extract
            for file_dict in self._flat_data_files():
                url = self.dataset_root + file_dict["name"] + '.gz'
                filename = file_dict["name"]
                gz_filename = filename + '.gz'
                md5 = file_dict["md5_gz"]  # checksum of the compressed archive
                fpath = os.path.join(self.root, self.raw_folder, filename)
                gz_fpath = fpath + '.gz'

                # download if compressed file not exists and verified
                # NOTE(review): assumes download_url creates the raw folder and
                # skips re-downloading when the md5 already matches — confirm
                # against torchvision.datasets.utils.download_url.
                download_url(url, os.path.join(self.root, self.raw_folder), gz_filename, md5)

                print('# Extracting data {}\n'.format(filename))

                # decompress into the raw folder, then drop the archive
                with open(fpath, 'wb') as out_f, \
                        gzip.GzipFile(gz_fpath) as zip_f:
                    out_f.write(zip_f.read())

                os.unlink(gz_fpath)

        # process and save as torch files
        print('Processing...')

        # create processed folder (tolerate it already existing)
        try:
            os.makedirs(os.path.join(self.root, self.processed_folder))
        except OSError as e:
            if e.errno == errno.EEXIST:
                pass
            else:
                raise

        # read train files
        left_train_img, right_train_img = self._read_image_file(self.data_files["train"]["dat"]["name"])
        train_info = self._read_info_file(self.data_files["train"]["info"]["name"])
        train_label = self._read_label_file(self.data_files["train"]["cat"]["name"])

        # read test files
        left_test_img, right_test_img = self._read_image_file(self.data_files["test"]["dat"]["name"])
        test_info = self._read_info_file(self.data_files["test"]["info"]["name"])
        test_label = self._read_label_file(self.data_files["test"]["cat"]["name"])

        # save training files
        self._save(left_train_img, "{}_left".format(self.train_image_file))
        self._save(right_train_img, "{}_right".format(self.train_image_file))
        self._save(train_label, self.train_label_file)
        self._save(train_info, self.train_info_file)

        # save test files
        self._save(left_test_img, "{}_left".format(self.test_image_file))
        self._save(right_test_img, "{}_right".format(self.test_image_file))
        self._save(test_label, self.test_label_file)
        self._save(test_info, self.test_info_file)

        print('Done!')
+
+ @staticmethod
+ def _parse_header(file_pointer):
+ # Read magic number and ignore
+ struct.unpack('