diff --git a/torchvision/datasets/cifar.py b/torchvision/datasets/cifar.py index ac7a7269af3..5cc651f1e41 100644 --- a/torchvision/datasets/cifar.py +++ b/torchvision/datasets/cifar.py @@ -46,6 +46,10 @@ class CIFAR10(data.Dataset): ['test_batch', '40351d587109b95175f43aff81a1287e'], ] + meta_list = [ + ['batches.meta', '5ff9c542aee3614f3951f8cda6e48888'], + ] + def __init__(self, root, train=True, transform=None, target_transform=None, download=False): @@ -100,6 +104,16 @@ def __init__(self, root, train=True, self.test_data = self.test_data.reshape((10000, 3, 32, 32)) self.test_data = self.test_data.transpose((0, 2, 3, 1)) # convert to HWC + f = self.meta_list[0][0] + file = os.path.join(self.root, self.base_folder, f) + fo = open(file, 'rb') + if sys.version_info[0] == 2: + entry = pickle.load(fo) + else: + entry = pickle.load(fo, encoding='latin1') + fo.close() + self.meta = entry + def __getitem__(self, index): """ Args: @@ -133,7 +147,7 @@ def __len__(self): def _check_integrity(self): root = self.root - for fentry in (self.train_list + self.test_list): + for fentry in (self.train_list + self.test_list + self.meta_list): filename, md5 = fentry[0], fentry[1] fpath = os.path.join(root, self.base_folder, filename) if not check_integrity(fpath, md5): @@ -187,3 +201,7 @@ class CIFAR100(CIFAR10): test_list = [ ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'], ] + + meta_list = [ + ['meta', '7973b15100ade9c7d40fb424638fde48'], + ]