@@ -46,6 +46,10 @@ class CIFAR10(data.Dataset):
46
46
['test_batch' , '40351d587109b95175f43aff81a1287e' ],
47
47
]
48
48
49
+ meta_list = [
50
+ ['batches.meta' , '5ff9c542aee3614f3951f8cda6e48888' ],
51
+ ]
52
+
49
53
def __init__ (self , root , train = True ,
50
54
transform = None , target_transform = None ,
51
55
download = False ):
@@ -100,6 +104,16 @@ def __init__(self, root, train=True,
100
104
self .test_data = self .test_data .reshape ((10000 , 3 , 32 , 32 ))
101
105
self .test_data = self .test_data .transpose ((0 , 2 , 3 , 1 )) # convert to HWC
102
106
107
+ f = self .meta_list [0 ][0 ]
108
+ file = os .path .join (self .root , self .base_folder , f )
109
+ fo = open (file , 'rb' )
110
+ if sys .version_info [0 ] == 2 :
111
+ entry = pickle .load (fo )
112
+ else :
113
+ entry = pickle .load (fo , encoding = 'latin1' )
114
+ fo .close ()
115
+ self .meta = entry
116
+
103
117
def __getitem__ (self , index ):
104
118
"""
105
119
Args:
@@ -133,7 +147,7 @@ def __len__(self):
133
147
134
148
def _check_integrity (self ):
135
149
root = self .root
136
- for fentry in (self .train_list + self .test_list ):
150
+ for fentry in (self .train_list + self .test_list + self . meta_list ):
137
151
filename , md5 = fentry [0 ], fentry [1 ]
138
152
fpath = os .path .join (root , self .base_folder , filename )
139
153
if not check_integrity (fpath , md5 ):
@@ -187,3 +201,7 @@ class CIFAR100(CIFAR10):
187
201
test_list = [
188
202
['test' , 'f0ef6b0ae62326f3e7ffdfab6717acfc' ],
189
203
]
204
+
205
+ meta_list = [
206
+ ['meta' , '7973b15100ade9c7d40fb424638fde48' ],
207
+ ]
0 commit comments