Skip to content

Commit 3f84497

Browse files
author
Philip Meier
committed
cast images to PIL at instantiation instead of runtime
1 parent 6c56029 commit 3f84497

File tree

2 files changed

+14
-18
lines changed

2 files changed

+14
-18
lines changed

torchvision/datasets/cifar.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -94,6 +94,10 @@ def __init__(self, root, train=True,
9494
self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
9595
self.data = self.data.transpose((0, 2, 3, 1)) # convert to HWC
9696

97+
# doing this so that it is consistent with all other datasets
98+
# to return a PIL Image
99+
self.data = [Image.fromarray(image) for image in self.data]
100+
97101
self._load_meta()
98102

99103
def _load_meta(self):
@@ -119,10 +123,6 @@ def __getitem__(self, index):
119123
"""
120124
img, target = self.data[index], self.targets[index]
121125

122-
# doing this so that it is consistent with all other datasets
123-
# to return a PIL Image
124-
img = Image.fromarray(img)
125-
126126
if self.transform is not None:
127127
img = self.transform(img)
128128

torchvision/datasets/svhn.py

Lines changed: 10 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -66,20 +66,20 @@ def __init__(self, root, split='train',
6666
import scipy.io as sio
6767

6868
# reading(loading) mat file as array
69-
loaded_mat = sio.loadmat(os.path.join(self.root, self.filename))
69+
loaded_mat = sio.loadmat(os.path.join(self.root, self.filename),
70+
squeeze_me=True)
71+
data, targets = loaded_mat['X'], loaded_mat['y']
7072

71-
self.data = loaded_mat['X']
72-
# loading from the .mat file gives an np array of type np.uint8
73-
# converting to np.int64, so that we have a LongTensor after
74-
# the conversion from the numpy array
75-
# the squeeze is needed to obtain a 1D tensor
76-
self.labels = loaded_mat['y'].astype(np.int64).squeeze()
73+
# doing this so that it is consistent with all other datasets
74+
# to return a PIL Image
75+
self.data = [Image.fromarray(image.squeeze(3))
76+
for image in np.split(data, len(targets), axis=3)]
7777

7878
# the svhn dataset assigns the class label "10" to the digit 0
7979
# this makes it inconsistent with several loss functions
8080
# which expect the class labels to be in the range [0, C-1]
81-
np.place(self.labels, self.labels == 10, 0)
82-
self.data = np.transpose(self.data, (3, 2, 0, 1))
81+
np.place(targets, targets == 10, 0)
82+
self.targets = [int(target) for target in targets]
8383

8484
def __getitem__(self, index):
8585
"""
@@ -89,11 +89,7 @@ def __getitem__(self, index):
8989
Returns:
9090
tuple: (image, target) where target is index of the target class.
9191
"""
92-
img, target = self.data[index], int(self.labels[index])
93-
94-
# doing this so that it is consistent with all other datasets
95-
# to return a PIL Image
96-
img = Image.fromarray(np.transpose(img, (1, 2, 0)))
92+
img, target = self.data[index], self.targets[index]
9793

9894
if self.transform is not None:
9995
img = self.transform(img)

0 commit comments

Comments
 (0)