
Commit b87368e

Author: Joost van Amersfoort (committed)
add python2 compatibility, use dataloader, explicit train and eval call on model, general clean up
1 parent 1174146 commit b87368e

File tree

1 file changed: +42 -39 lines changed

VAE/main.py

Lines changed: 42 additions & 39 deletions
@@ -1,24 +1,27 @@
 from __future__ import print_function
 import os
 import torch
+import torch.utils.data
 import torch.nn as nn
 import torch.optim as optim
 from torch.autograd import Variable
 
-cuda = torch.cuda.is_available()
+# Training settings
+BATCH_SIZE = 150
+TEST_BATCH_SIZE = 1000
+NUM_EPOCHS = 2
 
-print('Running with CUDA: {0}'.format(cuda))
 
+cuda = torch.cuda.is_available()
 
-def print_header(msg):
-    print('===>', msg)
+print('====> Running with CUDA: {0}'.format(cuda))
 
 
 assert os.path.exists('data/processed/training.pt'), \
     "Please run python ../mnist/data.py before starting the VAE."
 
 # Data
-print_header('Loading data')
+print('====> Loading data')
 with open('data/processed/training.pt', 'rb') as f:
     training_set = torch.load(f)
 with open('data/processed/test.pt', 'rb') as f:
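The `from __future__ import print_function` line kept at the top of this hunk, together with the `super(VAE, self).__init__()` spelling in the next hunk, is what delivers the "python2 compatibility" named in the commit message. A minimal standalone sketch of both idioms (the `Base`/`Child` classes here are illustrative, not part of main.py):

```python
from __future__ import print_function  # makes print a function on Python 2


class Base(object):
    def __init__(self):
        print('Base initialised')


class Child(Base):
    def __init__(self):
        # super().__init__() is Python 3 only; the explicit
        # two-argument form works on both Python 2 and 3.
        super(Child, self).__init__()


Child()  # prints 'Base initialised' on either interpreter
```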
@@ -30,20 +33,32 @@ def print_header(msg):
 del training_set
 del test_set
 
+if cuda:
+    training_data.cuda()
+    test_data.cuda()
+
+train_loader = torch.utils.data.DataLoader(training_data,
+                                           batch_size=BATCH_SIZE,
+                                           shuffle=True)
+
+test_loader = torch.utils.data.DataLoader(test_data,
+                                          batch_size=TEST_BATCH_SIZE)
+
 # Model
-print_header('Building model')
+print('====> Building model')
 
 
 class VAE(nn.Container):
     def __init__(self):
-        super().__init__()
+        super(VAE, self).__init__()
 
         self.fc1 = nn.Linear(784, 400)
-        self.relu = nn.ReLU()
         self.fc21 = nn.Linear(400, 20)
         self.fc22 = nn.Linear(400, 20)
         self.fc3 = nn.Linear(20, 400)
         self.fc4 = nn.Linear(400, 784)
+
+        self.relu = nn.ReLU()
         self.sigmoid = nn.Sigmoid()
 
     def encode(self, x):
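The two loaders added here replace the hand-rolled slicing that the final hunk deletes from train() and test(). `torch.utils.data.DataLoader` accepts any indexable object with a length, so a plain tensor serves as the dataset and each iteration yields one stacked batch. A minimal sketch of the pattern, using a random stand-in for `training_data`:

```python
import torch
import torch.utils.data

# Random stand-in for training_data: 600 flattened 28x28 images.
data = torch.randn(600, 784)

# Each iteration stacks BATCH_SIZE rows into one [150, 784] batch;
# shuffle=True re-permutes the rows every epoch.
loader = torch.utils.data.DataLoader(data, batch_size=150, shuffle=True)

for batch in loader:
    print(batch.size())  # torch.Size([150, 784])
```

One caveat worth flagging in the hunk itself: `Tensor.cuda()` returns a CUDA copy rather than moving the tensor in place, so the bare `training_data.cuda()` and `test_data.cuda()` calls discard their result; the transfer only sticks if it is assigned back, e.g. `training_data = training_data.cuda()`.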
@@ -83,50 +98,38 @@ def loss_function(recon_x, x, mu, logvar):
     return BCE + KLD
 
 
-# Training settings
-BATCH_SIZE = 150
-TEST_BATCH_SIZE = 1000
-NUM_EPOCHS = 2
-
 optimizer = optim.Adam(model.parameters(), lr=1e-3)
 
 
 def train(epoch):
-    batch_data_t = torch.FloatTensor(BATCH_SIZE, 784)
-    if cuda:
-        batch_data_t = batch_data_t.cuda()
-    batch_data = Variable(batch_data_t, requires_grad=False)
-    for i in range(0, training_data.size(0), BATCH_SIZE):
+    model.train()
+    train_loss = 0
+    for batch in train_loader:
+        batch = Variable(batch)
+
         optimizer.zero_grad()
-        batch_data.data[:] = training_data[i:i + BATCH_SIZE]
-        recon_batch_data, mu, logvar = model(batch_data)
-        loss = loss_function(recon_batch_data, batch_data, mu, logvar)
+        recon_batch, mu, logvar = model(batch)
+        loss = loss_function(recon_batch, batch, mu, logvar)
         loss.backward()
-        loss = loss.data[0]
+        train_loss += loss
        optimizer.step()
-        if i % 10 == 0:
-            print('Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.4f}'.format(
-                epoch,
-                i + BATCH_SIZE, training_data.size(0),
-                float(i + BATCH_SIZE) / training_data.size(0) * 100,
-                loss / BATCH_SIZE))
+
+    print('====> Epoch: {} Loss: {:.4f}'.format(
+        epoch,
+        train_loss.data[0] / training_data.size(0)))
 
 
 def test(epoch):
+    model.eval()
     test_loss = 0
-    batch_data_t = torch.FloatTensor(TEST_BATCH_SIZE, 784)
-    if cuda:
-        batch_data_t = batch_data_t.cuda()
-    batch_data = Variable(batch_data_t, volatile=True)
-    for i in range(0, test_data.size(0), TEST_BATCH_SIZE):
-        print('Testing model: {}/{}'.format(i, test_data.size(0)), end='\r')
-        batch_data.data[:] = test_data[i:i + TEST_BATCH_SIZE]
-        recon_batch_data, mu, logvar = model(batch_data)
-        test_loss += loss_function(recon_batch_data, batch_data, mu, logvar)
+    for batch in test_loader:
+        batch = Variable(batch)
+
+        recon_batch, mu, logvar = model(batch)
+        test_loss += loss_function(recon_batch, batch, mu, logvar)
 
     test_loss = test_loss.data[0] / test_data.size(0)
-    print('TEST SET RESULTS:' + ' ' * 20)
-    print('Average loss: {:.4f}'.format(test_loss))
+    print('====> Test set results: {:.4f}'.format(test_loss))
 
 
 for epoch in range(1, NUM_EPOCHS + 1):
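For reference, the `BCE + KLD` returned at the top of this hunk is the standard negative evidence lower bound from Kingma & Welling's "Auto-Encoding Variational Bayes" (2013): a binary cross-entropy reconstruction term plus the closed-form KL divergence between the approximate posterior N(mu, exp(logvar)) and the standard normal prior, summed over the 20 latent dimensions:

    KLD = -0.5 * sum_j (1 + logvar_j - mu_j^2 - exp(logvar_j))

The body of `loss_function` sits outside this diff, so the formula above is quoted from the paper rather than from the file. The `model.train()` and `model.eval()` calls named in the commit message toggle training-specific layer behaviour such as dropout and batch normalisation; this model contains neither, so here they are mainly forward-looking hygiene that keeps train() and test() idiomatic.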
