from __future__ import print_function
- import os
+ import os, argparse
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable

cuda = torch.cuda.is_available()

- def print_header(msg):
-     print('===>', msg)
+ # Training settings
+ parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
+ parser.add_argument('--batchSize', type=int, default=64, help='input batch size')
+ parser.add_argument('--testBatchSize', type=int, default=1000, help='input batch size for testing')
+ parser.add_argument('--trainSize', type=int, default=1000, help='Train dataset size (max=60000). Default: 1000')
+ parser.add_argument('--nEpochs', type=int, default=2, help='number of epochs to train')
+ parser.add_argument('--lr', type=float, default=0.01, help='Learning Rate. Default=0.01')
+ parser.add_argument('--momentum', type=float, default=0.5, help='SGD momentum. Default=0.5')
+ parser.add_argument('--seed', type=int, default=123, help='Random Seed to use. Default=123')
+ opt = parser.parse_args()
+ print(opt)
+
+ torch.manual_seed(opt.seed)
+ if cuda:
+     torch.cuda.manual_seed(opt.seed)

if not os.path.exists('data/processed/training.pt'):
    import data

# Data
- print_header('Loading data')
+ print('===> Loading data')
with open('data/processed/training.pt', 'rb') as f:
    training_set = torch.load(f)
with open('data/processed/test.pt', 'rb') as f:
    test_set = torch.load(f)

training_data = training_set[0].view(-1, 1, 28, 28).div(255)
+ training_data = training_data[:opt.trainSize]
training_labels = training_set[1]
test_data = test_set[0].view(-1, 1, 28, 28).div(255)
test_labels = test_set[1]

del training_set
del test_set

- # Model
- print_header('Building model')
+ print('===> Building model')
class Net(nn.Container):
    def __init__(self):
-         super(Net, self).__init__(
-             conv1=nn.Conv2d(1, 20, 5),
-             pool1=nn.MaxPool2d(2, 2),
-             conv2=nn.Conv2d(20, 50, 5),
-             pool2=nn.MaxPool2d(2, 2),
-             fc1=nn.Linear(800, 500),
-             fc2=nn.Linear(500, 10),
-             relu=nn.ReLU(),
-             softmax=nn.LogSoftmax(),
-         )
+         super(Net, self).__init__()
+         self.conv1 = nn.Conv2d(1, 10, 5)
+         self.pool1 = nn.MaxPool2d(2, 2)
+         self.conv2 = nn.Conv2d(10, 20, 5)
+         self.pool2 = nn.MaxPool2d(2, 2)
+         self.fc1 = nn.Linear(320, 50)
+         self.fc2 = nn.Linear(50, 10)
+         self.relu = nn.ReLU()
+         self.softmax = nn.LogSoftmax()

    def forward(self, x):
        x = self.relu(self.pool1(self.conv1(x)))
        x = self.relu(self.pool2(self.conv2(x)))
-         x = x.view(-1, 800)
+         x = x.view(-1, 320)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        return self.softmax(x)
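For reference, the new flatten size follows from the layer shapes above (5x5 kernels, no padding, stride 1, 2x2 max-pooling); a quick walk-through for a 1x28x28 MNIST input:

    # conv1 (5x5): 28 - 5 + 1 = 24  -> 10 x 24 x 24
    # pool1 (2x2): 24 / 2     = 12  -> 10 x 12 x 12
    # conv2 (5x5): 12 - 5 + 1 = 8   -> 20 x 8 x 8
    # pool2 (2x2): 8 / 2      = 4   -> 20 x 4 x 4
    # flatten:     20 * 4 * 4 = 320 -> matches x.view(-1, 320)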
@@ -56,60 +68,59 @@ def forward(self, x):
    model.cuda()

criterion = nn.NLLLoss()
-
- # Training settings
- BATCH_SIZE = 150
- TEST_BATCH_SIZE = 1000
- NUM_EPOCHS = 2
-
- optimizer = optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)
+ optimizer = optim.SGD(model.parameters(), lr=opt.lr, momentum=opt.momentum)

def train(epoch):
-     batch_data_t = torch.FloatTensor(BATCH_SIZE, 1, 28, 28)
-     batch_targets_t = torch.LongTensor(BATCH_SIZE)
+     # create buffers for mini-batch
+     batch_data = torch.FloatTensor(opt.batchSize, 1, 28, 28)
+     batch_targets = torch.LongTensor(opt.batchSize)
    if cuda:
-         batch_data_t = batch_data_t.cuda()
-         batch_targets_t = batch_targets_t.cuda()
-     batch_data = Variable(batch_data_t, requires_grad=False)
-     batch_targets = Variable(batch_targets_t, requires_grad=False)
-     for i in range(0, training_data.size(0), BATCH_SIZE):
+         batch_data, batch_targets = batch_data.cuda(), batch_targets.cuda()
+
+     # create autograd Variables over these buffers
+     batch_data, batch_targets = Variable(batch_data), Variable(batch_targets)
+
+     for i in range(0, training_data.size(0) - opt.batchSize + 1, opt.batchSize):
+         start, end = i, i + opt.batchSize
        optimizer.zero_grad()
-         batch_data.data[:] = training_data[i:i+BATCH_SIZE]
-         batch_targets.data[:] = training_labels[i:i+BATCH_SIZE]
-         loss = criterion(model(batch_data), batch_targets)
+         batch_data.data[:] = training_data[start:end]
+         batch_targets.data[:] = training_labels[start:end]
+         output = model(batch_data)
+         loss = criterion(output, batch_targets)
        loss.backward()
        loss = loss.data[0]
        optimizer.step()
-         print('Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.4f}'.format(epoch,
-             i+BATCH_SIZE, training_data.size(0),
-             float(i+BATCH_SIZE)/training_data.size(0)*100, loss))
+         print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.4f}'
+               .format(epoch, end, opt.trainSize, float(end)/opt.trainSize*100, loss))

def test(epoch):
-     test_loss = 0
-     batch_data_t = torch.FloatTensor(TEST_BATCH_SIZE, 1, 28, 28)
-     batch_targets_t = torch.LongTensor(TEST_BATCH_SIZE)
+     # create buffers for mini-batch
+     batch_data = torch.FloatTensor(opt.testBatchSize, 1, 28, 28)
+     batch_targets = torch.LongTensor(opt.testBatchSize)
    if cuda:
-         batch_data_t = batch_data_t.cuda()
-         batch_targets_t = batch_targets_t.cuda()
-     batch_data = Variable(batch_data_t, volatile=True)
-     batch_targets = Variable(batch_targets_t, volatile=True)
+         batch_data, batch_targets = batch_data.cuda(), batch_targets.cuda()
+
+     # create autograd Variables over these buffers
+     batch_data = Variable(batch_data, volatile=True)
+     batch_targets = Variable(batch_targets, volatile=True)
+
+     test_loss = 0
    correct = 0
-     for i in range(0, test_data.size(0), TEST_BATCH_SIZE):
-         print('Testing model: {}/{}'.format(i, test_data.size(0)), end='\r')
-         batch_data.data[:] = test_data[i:i+TEST_BATCH_SIZE]
-         batch_targets.data[:] = test_labels[i:i+TEST_BATCH_SIZE]
+
+     for i in range(0, test_data.size(0), opt.testBatchSize):
+         batch_data.data[:] = test_data[i:i+opt.testBatchSize]
+         batch_targets.data[:] = test_labels[i:i+opt.testBatchSize]
        output = model(batch_data)
        test_loss += criterion(output, batch_targets)
-         pred = output.data.max(1)[1]
+         pred = output.data.max(1)[1]  # get the index of the max log-probability
        correct += pred.long().eq(batch_targets.data.long()).cpu().sum()

    test_loss = test_loss.data[0]
-     test_loss /= (test_data.size(0) / TEST_BATCH_SIZE)  # criterion averages over batch size
-     print('TEST SET RESULTS:' + ' ' * 20)
-     print('Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
+     test_loss /= (test_data.size(0) / opt.testBatchSize)  # criterion averages over batch size
+     print('\nTest Set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, test_data.size(0),
        float(correct)/test_data.size(0)*100))

- for epoch in range(1, NUM_EPOCHS + 1):
+ for epoch in range(1, opt.nEpochs + 1):
    train(epoch)
    test(epoch)
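With the hard-coded constants replaced by argparse flags, a run can now be configured from the command line. A sketch of an invocation, assuming the script is saved as main.py (the filename is not shown in the diff):

    python main.py --batchSize 64 --testBatchSize 1000 --trainSize 5000 --nEpochs 2 --lr 0.01 --momentum 0.5 --seed 123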