Skip to content

Commit 038a01f

Browse files
authored
Merge pull request #12 from pytorch/cifar10fix
cifar10/100 only load train/test and not both
2 parents 2d55b9d + 71755da commit 038a01f

File tree

1 file changed

+23
-23
lines changed

1 file changed

+23
-23
lines changed

torchvision/datasets/cifar.py

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -42,35 +42,35 @@ def __init__(self, root, train=True, transform=None, target_transform=None, down
4242
+ ' You can use download=True to download it')
4343

4444
# now load the picked numpy arrays
45-
self.train_data = []
46-
self.train_labels = []
47-
for fentry in self.train_list:
48-
f = fentry[0]
45+
if self.train:
46+
self.train_data = []
47+
self.train_labels = []
48+
for fentry in self.train_list:
49+
f = fentry[0]
50+
file = os.path.join(root, self.base_folder, f)
51+
fo = open(file, 'rb')
52+
entry = pickle.load(fo)
53+
self.train_data.append(entry['data'])
54+
if 'labels' in entry:
55+
self.train_labels += entry['labels']
56+
else:
57+
self.train_labels += entry['fine_labels']
58+
fo.close()
59+
60+
self.train_data = np.concatenate(self.train_data)
61+
self.train_data = self.train_data.reshape((50000, 3, 32, 32))
62+
else:
63+
f = self.test_list[0][0]
4964
file = os.path.join(root, self.base_folder, f)
5065
fo = open(file, 'rb')
5166
entry = pickle.load(fo)
52-
self.train_data.append(entry['data'])
67+
self.test_data = entry['data']
5368
if 'labels' in entry:
54-
self.train_labels += entry['labels']
69+
self.test_labels = entry['labels']
5570
else:
56-
self.train_labels += entry['fine_labels']
71+
self.test_labels = entry['fine_labels']
5772
fo.close()
58-
59-
self.train_data = np.concatenate(self.train_data)
60-
61-
f = self.test_list[0][0]
62-
file = os.path.join(root, self.base_folder, f)
63-
fo = open(file, 'rb')
64-
entry = pickle.load(fo)
65-
self.test_data = entry['data']
66-
if 'labels' in entry:
67-
self.test_labels = entry['labels']
68-
else:
69-
self.test_labels = entry['fine_labels']
70-
fo.close()
71-
72-
self.train_data = self.train_data.reshape((50000, 3, 32, 32))
73-
self.test_data = self.test_data.reshape((10000, 3, 32, 32))
73+
self.test_data = self.test_data.reshape((10000, 3, 32, 32))
7474

7575
def __getitem__(self, index):
7676
if self.train:

0 commit comments

Comments
 (0)