@@ -42,35 +42,35 @@ def __init__(self, root, train=True, transform=None, target_transform=None, down
42
42
+ ' You can use download=True to download it' )
43
43
44
44
# now load the picked numpy arrays
45
- self .train_data = []
46
- self .train_labels = []
47
- for fentry in self .train_list :
48
- f = fentry [0 ]
45
+ if self .train :
46
+ self .train_data = []
47
+ self .train_labels = []
48
+ for fentry in self .train_list :
49
+ f = fentry [0 ]
50
+ file = os .path .join (root , self .base_folder , f )
51
+ fo = open (file , 'rb' )
52
+ entry = pickle .load (fo )
53
+ self .train_data .append (entry ['data' ])
54
+ if 'labels' in entry :
55
+ self .train_labels += entry ['labels' ]
56
+ else :
57
+ self .train_labels += entry ['fine_labels' ]
58
+ fo .close ()
59
+
60
+ self .train_data = np .concatenate (self .train_data )
61
+ self .train_data = self .train_data .reshape ((50000 , 3 , 32 , 32 ))
62
+ else :
63
+ f = self .test_list [0 ][0 ]
49
64
file = os .path .join (root , self .base_folder , f )
50
65
fo = open (file , 'rb' )
51
66
entry = pickle .load (fo )
52
- self .train_data . append ( entry ['data' ])
67
+ self .test_data = entry ['data' ]
53
68
if 'labels' in entry :
54
- self .train_labels + = entry ['labels' ]
69
+ self .test_labels = entry ['labels' ]
55
70
else :
56
- self .train_labels + = entry ['fine_labels' ]
71
+ self .test_labels = entry ['fine_labels' ]
57
72
fo .close ()
58
-
59
- self .train_data = np .concatenate (self .train_data )
60
-
61
- f = self .test_list [0 ][0 ]
62
- file = os .path .join (root , self .base_folder , f )
63
- fo = open (file , 'rb' )
64
- entry = pickle .load (fo )
65
- self .test_data = entry ['data' ]
66
- if 'labels' in entry :
67
- self .test_labels = entry ['labels' ]
68
- else :
69
- self .test_labels = entry ['fine_labels' ]
70
- fo .close ()
71
-
72
- self .train_data = self .train_data .reshape ((50000 , 3 , 32 , 32 ))
73
- self .test_data = self .test_data .reshape ((10000 , 3 , 32 , 32 ))
73
+ self .test_data = self .test_data .reshape ((10000 , 3 , 32 , 32 ))
74
74
75
75
def __getitem__ (self , index ):
76
76
if self .train :
0 commit comments