@@ -66,20 +66,20 @@ def __init__(self, root, split='train',
66
66
import scipy .io as sio
67
67
68
68
# reading(loading) mat file as array
69
- loaded_mat = sio .loadmat (os .path .join (self .root , self .filename ))
69
+ loaded_mat = sio .loadmat (os .path .join (self .root , self .filename ),
70
+ squeeze_me = True )
71
+ data , targets = loaded_mat ['X' ], loaded_mat ['y' ]
70
72
71
- self .data = loaded_mat ['X' ]
72
- # loading from the .mat file gives an np array of type np.uint8
73
- # converting to np.int64, so that we have a LongTensor after
74
- # the conversion from the numpy array
75
- # the squeeze is needed to obtain a 1D tensor
76
- self .labels = loaded_mat ['y' ].astype (np .int64 ).squeeze ()
73
+ # doing this so that it is consistent with all other datasets
74
+ # to return a PIL Image
75
+ self .data = [Image .fromarray (image .squeeze (3 ))
76
+ for image in np .split (data , len (targets ), axis = 3 )]
77
77
78
78
# the svhn dataset assigns the class label "10" to the digit 0
79
79
# this makes it inconsistent with several loss functions
80
80
# which expect the class labels to be in the range [0, C-1]
81
- np .place (self . labels , self . labels == 10 , 0 )
82
- self .data = np . transpose ( self . data , ( 3 , 2 , 0 , 1 ))
81
+ np .place (targets , targets == 10 , 0 )
82
+ self .targets = [ int ( target ) for target in targets ]
83
83
84
84
def __getitem__ (self , index ):
85
85
"""
@@ -89,11 +89,7 @@ def __getitem__(self, index):
89
89
Returns:
90
90
tuple: (image, target) where target is index of the target class.
91
91
"""
92
- img , target = self .data [index ], int (self .labels [index ])
93
-
94
- # doing this so that it is consistent with all other datasets
95
- # to return a PIL Image
96
- img = Image .fromarray (np .transpose (img , (1 , 2 , 0 )))
92
+ img , target = self .data [index ], self .targets [index ]
97
93
98
94
if self .transform is not None :
99
95
img = self .transform (img )
0 commit comments