Skip to content

Commit 63dabca

Browse files
authored
Merge pull request #3 from pytorch/cifar
cifar 10 and 100
2 parents e37323d + 754d526 commit 63dabca

File tree

5 files changed

+191
-1
lines changed

5 files changed

+191
-1
lines changed

.gitignore

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
build/
2+
dist/
3+
torchvision.egg-info/
4+
*/**/__pycache__
5+
*/**/*.pyc
6+
*/**/*~
7+
*~

README.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ The following dataset loaders are available:
2929
- [LSUN Classification](#lsun)
3030
- [ImageFolder](#imagefolder)
3131
- [Imagenet-12](#imagenet-12)
32+
- [CIFAR10 and CIFAR100](#cifar)
3233

3334
Datasets have the API:
3435
- `__getitem__`
@@ -97,6 +98,15 @@ u'A mountain view with a plume of smoke in the background']
9798
- ['bedroom_train', 'church_train', ...] : a list of categories to load
9899

99100

101+
### CIFAR
102+
103+
`dset.CIFAR10(root, train=True, transform=None, target_transform=None, download=False)`
104+
`dset.CIFAR100(root, train=True, transform=None, target_transform=None, download=False)`
105+
106+
- `root` : root directory of dataset where there is folder `cifar-10-batches-py`
107+
- `train` : `True` = Training set, `False` = Test set
108+
- `download` : `True` = downloads the dataset from the internet and puts it in root directory. If dataset already downloaded, does not do anything.
109+
100110
### ImageFolder
101111

102112
A generic data loader where the images are arranged in this way:

test/cifar.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
import torch
2+
import torchvision.datasets as dset
3+
4+
print('\n\nCifar 10')
5+
a = dset.CIFAR10(root="abc/def/ghi", download=True)
6+
7+
print(a[3])
8+
9+
print('\n\nCifar 100')
10+
a = dset.CIFAR100(root="abc/def/ghi", download=True)
11+
12+
print(a[3])

torchvision/datasets/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
from .lsun import LSUN, LSUNClass
22
from .folder import ImageFolder
33
from .coco import CocoCaptions, CocoDetection
4+
from .cifar import CIFAR10, CIFAR100
45

56
__all__ = ('LSUN', 'LSUNClass',
67
'ImageFolder',
7-
'CocoCaptions', 'CocoDetection')
8+
'CocoCaptions', 'CocoDetection',
9+
'CIFAR10', 'CIFAR100')

torchvision/datasets/cifar.py

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
from __future__ import print_function
2+
import torch.utils.data as data
3+
from PIL import Image
4+
import os
5+
import os.path
6+
import errno
7+
import numpy as np
8+
import sys
9+
if sys.version_info[0] == 2:
10+
import cPickle as pickle
11+
else:
12+
import pickle
13+
14+
class CIFAR10(data.Dataset):
15+
base_folder = 'cifar-10-batches-py'
16+
url = "http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
17+
filename = "cifar-10-python.tar.gz"
18+
tgz_mdf = 'c58f30108f718f92721af3b95e74349a'
19+
train_list = [
20+
['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
21+
['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'],
22+
['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'],
23+
['data_batch_4', '634d18415352ddfa80567beed471001a'],
24+
['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'],
25+
]
26+
27+
test_list = [
28+
['test_batch', '40351d587109b95175f43aff81a1287e'],
29+
]
30+
31+
def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
32+
self.root = root
33+
self.transform = transform
34+
self.target_transform = target_transform
35+
self.train = train # training set or test set
36+
37+
if download:
38+
self.download()
39+
40+
if not self._check_integrity():
41+
raise RuntimeError('Dataset not found or corrupted.'
42+
+ ' You can use download=True to download it')
43+
44+
# now load the picked numpy arrays
45+
self.train_data = []
46+
self.train_labels = []
47+
for fentry in self.train_list:
48+
f = fentry[0]
49+
file = os.path.join(root, self.base_folder, f)
50+
fo = open(file, 'rb')
51+
entry = pickle.load(fo)
52+
self.train_data.append(entry['data'])
53+
if 'labels' in entry:
54+
self.train_labels += entry['labels']
55+
else:
56+
self.train_labels += entry['fine_labels']
57+
fo.close()
58+
59+
self.train_data = np.concatenate(self.train_data)
60+
61+
f = self.test_list[0][0]
62+
file = os.path.join(root, self.base_folder, f)
63+
fo = open(file, 'rb')
64+
entry = pickle.load(fo)
65+
self.test_data = entry['data']
66+
if 'labels' in entry:
67+
self.test_labels = entry['labels']
68+
else:
69+
self.test_labels = entry['fine_labels']
70+
fo.close()
71+
72+
self.train_data = self.train_data.reshape((50000, 3, 32, 32))
73+
self.test_data = self.test_data.reshape((10000, 3, 32, 32))
74+
75+
def __getitem__(self, index):
76+
if self.train:
77+
img, target = self.train_data[index], self.train_labels[index]
78+
else:
79+
img, target = self.test_data[index], self.test_labels[index]
80+
81+
if self.transform is not None:
82+
img = self.transform(img)
83+
84+
if self.target_transform is not None:
85+
target = self.target_transform(target)
86+
87+
return img, target
88+
89+
def __len__(self):
90+
if self.train:
91+
return 50000
92+
else:
93+
return 10000
94+
95+
def _check_integrity(self):
96+
import hashlib
97+
root = self.root
98+
for fentry in (self.train_list + self.test_list):
99+
filename, md5 = fentry[0], fentry[1]
100+
fpath = os.path.join(root, self.base_folder, filename)
101+
if not os.path.isfile(fpath):
102+
return False
103+
md5c = hashlib.md5(open(fpath, 'rb').read()).hexdigest()
104+
if md5c != md5:
105+
return False
106+
return True
107+
108+
def download(self):
109+
from six.moves import urllib
110+
import tarfile
111+
import hashlib
112+
113+
root = self.root
114+
fpath = os.path.join(root, self.filename)
115+
116+
try:
117+
os.makedirs(root)
118+
except OSError as e:
119+
if e.errno == errno.EEXIST:
120+
pass
121+
else:
122+
raise
123+
124+
if self._check_integrity():
125+
print('Files already downloaded and verified')
126+
return
127+
128+
# downloads file
129+
if os.path.isfile(fpath) and \
130+
hashlib.md5(open(fpath, 'rb').read()).hexdigest() == self.tgz_md5:
131+
print('Using downloaded file: ' + fpath)
132+
else:
133+
print('Downloading ' + self.url + ' to ' + fpath)
134+
urllib.request.urlretrieve(self.url, fpath)
135+
136+
# extract file
137+
cwd = os.getcwd()
138+
print('Extracting tar file')
139+
tar = tarfile.open(fpath, "r:gz")
140+
os.chdir(root)
141+
tar.extractall()
142+
tar.close()
143+
os.chdir(cwd)
144+
print('Done!')
145+
146+
147+
class CIFAR100(CIFAR10):
148+
base_folder = 'cifar-100-python'
149+
url = "http://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
150+
filename = "cifar-100-python.tar.gz"
151+
tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85'
152+
train_list = [
153+
['train', '16019d7e3df5f24257cddd939b257f8d'],
154+
]
155+
156+
test_list = [
157+
['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'],
158+
]
159+

0 commit comments

Comments
 (0)