from __future__ import print_function
import os
import torch
+import torch.utils.data
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable

-cuda = torch.cuda.is_available()
+# Training settings
+BATCH_SIZE = 150
+TEST_BATCH_SIZE = 1000
+NUM_EPOCHS = 2

-print('Running with CUDA: {0}'.format(cuda))
+
+cuda = torch.cuda.is_available()

-def print_header(msg):
-    print('===>', msg)
+print('====> Running with CUDA: {0}'.format(cuda))


assert os.path.exists('data/processed/training.pt'), \
    "Please run python ../mnist/data.py before starting the VAE."

# Data
-print_header('Loading data')
+print('====> Loading data')
with open('data/processed/training.pt', 'rb') as f:
    training_set = torch.load(f)
with open('data/processed/test.pt', 'rb') as f:
@@ -30,20 +33,32 @@ def print_header(msg):
del training_set
del test_set

+if cuda:
+    # .cuda() returns a copy rather than moving the tensor in place,
+    # so assign the result back instead of discarding it
+    training_data = training_data.cuda()
+    test_data = test_data.cuda()
+
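+# DataLoader slices the (N, 784) tensors into mini-batches;
+# shuffle=True re-shuffles the training set at the start of every epoch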
+train_loader = torch.utils.data.DataLoader(training_data,
+                                           batch_size=BATCH_SIZE,
+                                           shuffle=True)
+
+test_loader = torch.utils.data.DataLoader(test_data,
+                                          batch_size=TEST_BATCH_SIZE)
+
# Model
-print_header('Building model')
+print('====> Building model')


class VAE(nn.Container):
    def __init__(self):
-        super().__init__()
+        super(VAE, self).__init__()

        self.fc1 = nn.Linear(784, 400)
-        self.relu = nn.ReLU()
        self.fc21 = nn.Linear(400, 20)
        self.fc22 = nn.Linear(400, 20)
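+        # fc21 and fc22 are the two encoder heads: they produce the mean
+        # (mu) and log-variance (logvar) of the 20-d latent Gaussian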
        self.fc3 = nn.Linear(20, 400)
        self.fc4 = nn.Linear(400, 784)
+
+        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def encode(self, x):
@@ -83,50 +98,38 @@ def loss_function(recon_x, x, mu, logvar):
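+    # BCE is the binary cross-entropy reconstruction term; KLD is the
+    # closed-form KL divergence pulling N(mu, exp(logvar)) toward N(0, I)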
    return BCE + KLD


-# Training settings
-BATCH_SIZE = 150
-TEST_BATCH_SIZE = 1000
-NUM_EPOCHS = 2
-
optimizer = optim.Adam(model.parameters(), lr=1e-3)


def train(epoch):
-    batch_data_t = torch.FloatTensor(BATCH_SIZE, 784)
-    if cuda:
-        batch_data_t = batch_data_t.cuda()
-    batch_data = Variable(batch_data_t, requires_grad=False)
-    for i in range(0, training_data.size(0), BATCH_SIZE):
+    model.train()
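+    # training mode; test() below switches back with model.eval()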
+    train_loss = 0
+    for batch in train_loader:
+        batch = Variable(batch)
+
        optimizer.zero_grad()
-        batch_data.data[:] = training_data[i:i + BATCH_SIZE]
-        recon_batch_data, mu, logvar = model(batch_data)
-        loss = loss_function(recon_batch_data, batch_data, mu, logvar)
+        recon_batch, mu, logvar = model(batch)
+        loss = loss_function(recon_batch, batch, mu, logvar)
        loss.backward()
-        loss = loss.data[0]
+        train_loss += loss
        optimizer.step()
-        if i % 10 == 0:
-            print('Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.4f}'.format(
-                epoch,
-                i + BATCH_SIZE, training_data.size(0),
-                float(i + BATCH_SIZE) / training_data.size(0) * 100,
-                loss / BATCH_SIZE))
+
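+    # report the epoch loss averaged over all training samples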
+    print('====> Epoch: {} Loss: {:.4f}'.format(
+        epoch,
+        train_loss.data[0] / training_data.size(0)))


def test(epoch):
+    model.eval()
    test_loss = 0
-    batch_data_t = torch.FloatTensor(TEST_BATCH_SIZE, 784)
-    if cuda:
-        batch_data_t = batch_data_t.cuda()
-    batch_data = Variable(batch_data_t, volatile=True)
-    for i in range(0, test_data.size(0), TEST_BATCH_SIZE):
-        print('Testing model: {}/{}'.format(i, test_data.size(0)), end='\r')
-        batch_data.data[:] = test_data[i:i + TEST_BATCH_SIZE]
-        recon_batch_data, mu, logvar = model(batch_data)
-        test_loss += loss_function(recon_batch_data, batch_data, mu, logvar)
+    for batch in test_loader:
+        # keep volatile=True from the replaced code: it skips building the
+        # autograd graph during evaluation
+        batch = Variable(batch, volatile=True)
+
+        recon_batch, mu, logvar = model(batch)
+        test_loss += loss_function(recon_batch, batch, mu, logvar)
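+    # average the summed batch losses over the number of test samples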
    test_loss = test_loss.data[0] / test_data.size(0)
-    print('TEST SET RESULTS:' + ' ' * 20)
-    print('Average loss: {:.4f}'.format(test_loss))
+    print('====> Test set results: {:.4f}'.format(test_loss))


for epoch in range(1, NUM_EPOCHS + 1):