
Commit 1174146

Joost van Amersfoort (y0ast) authored and committed

implement VAE

1 parent 4ba9ae6 · commit 1174146

File tree

4 files changed: 151 additions, 0 deletions

- .gitignore
- VAE/README.md
- VAE/main.py
- VAE/requirements.txt


.gitignore

1 addition, 0 deletions

```diff
@@ -1,2 +1,3 @@
 mnist/data
+VAE/data
 *.pyc
```

VAE/README.md

13 additions, 0 deletions (new file)

# Basic VAE Example

This is an improved implementation of the paper [Stochastic Gradient VB and the Variational Auto-Encoder](http://arxiv.org/abs/1312.6114) by Kingma and Welling. It uses ReLUs and the Adam optimizer instead of sigmoids and Adagrad; these changes make the network converge much faster.

We reuse the data preparation script of the MNIST example:

```bash
pip install -r requirements.txt
python ../mnist/data.py
python main.py
```
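For context, the objective from Kingma and Welling is the evidence lower bound (ELBO); the loss in `main.py` below is its negative, a reconstruction term plus a KL regularizer:

$$
\mathcal{L}(\theta,\phi;x) \;=\; \mathbb{E}_{q_\phi(z\mid x)}\!\left[\log p_\theta(x\mid z)\right] \;-\; D_{\mathrm{KL}}\!\left(q_\phi(z\mid x)\,\|\,p(z)\right)
$$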

VAE/main.py

134 additions, 0 deletions (new file)
Setup, data loading, and the model. The encoder maps each flattened 784-pixel image to a 20-dimensional latent Gaussian (`fc21` produces the mean, `fc22` the log-variance); the decoder maps latent codes back to pixel space:

```python
from __future__ import print_function
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable

cuda = torch.cuda.is_available()

print('Running with CUDA: {0}'.format(cuda))


def print_header(msg):
    print('===>', msg)


assert os.path.exists('data/processed/training.pt'), \
    "Please run python ../mnist/data.py before starting the VAE."

# Data
print_header('Loading data')
with open('data/processed/training.pt', 'rb') as f:
    training_set = torch.load(f)
with open('data/processed/test.pt', 'rb') as f:
    test_set = torch.load(f)

# Flatten the images into vectors and rescale pixel values to [0, 1].
training_data = training_set[0].view(-1, 784).div(255)
test_data = test_set[0].view(-1, 784).div(255)

del training_set
del test_set

# Model
print_header('Building model')


class VAE(nn.Container):
    def __init__(self):
        super(VAE, self).__init__()  # explicit form for Python 2 compatibility

        self.fc1 = nn.Linear(784, 400)
        self.relu = nn.ReLU()
        self.fc21 = nn.Linear(400, 20)  # mean of q(z|x)
        self.fc22 = nn.Linear(400, 20)  # log-variance of q(z|x)
        self.fc3 = nn.Linear(20, 400)
        self.fc4 = nn.Linear(400, 784)
        self.sigmoid = nn.Sigmoid()

    def encode(self, x):
        h1 = self.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def reparametrize(self, mu, logvar):
        std = logvar.mul(0.5).exp_()
        # Draw eps with the same tensor type as std so this also works on CUDA.
        eps = Variable(std.data.new(std.size()).normal_(), requires_grad=False)
        return eps.mul(std).add_(mu)

    def decode(self, z):
        h3 = self.relu(self.fc3(z))
        return self.sigmoid(self.fc4(h3))

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparametrize(mu, logvar)
        return self.decode(z), mu, logvar
```
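`reparametrize` implements the reparameterization trick from the paper: rather than sampling $z \sim \mathcal{N}(\mu, \sigma^2)$ directly, which would block gradients, it samples $\epsilon \sim \mathcal{N}(0, I)$ and computes

$$
z = \mu + \sigma \odot \epsilon, \qquad \sigma = \exp\!\left(\tfrac{1}{2}\log\sigma^2\right),
$$

so $z$ remains differentiable with respect to $\mu$ and $\log\sigma^2$.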
The loss sums a binary cross-entropy reconstruction term and the KL divergence of the approximate posterior from the standard-normal prior:

```python
model = VAE()
if cuda:
    model.cuda()

# Sum the binary cross-entropy over the batch rather than averaging it.
reconstruction_function = nn.BCELoss()
reconstruction_function.size_average = False


def loss_function(recon_x, x, mu, logvar):
    BCE = reconstruction_function(recon_x, x)

    # Appendix B of the VAE paper: -KLD = 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)
    KLD = torch.sum(KLD_element).mul_(-0.5)

    return BCE + KLD
```
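For a diagonal Gaussian $q(z\mid x)=\mathcal{N}(\mu,\sigma^2)$ against the standard-normal prior, the KL term has the closed form from Appendix B of the paper, which is exactly what `KLD` computes element-wise before summing:

$$
D_{\mathrm{KL}}\!\left(\mathcal{N}(\mu,\sigma^2)\,\|\,\mathcal{N}(0,I)\right) \;=\; -\frac{1}{2}\sum_{j}\left(1+\log\sigma_j^2-\mu_j^2-\sigma_j^2\right).
$$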
Training and evaluation. Each batch is copied into a pre-allocated buffer, and the test pass marks its input `volatile=True` so no autograd graph is built:

```python
# Training settings
BATCH_SIZE = 150
TEST_BATCH_SIZE = 1000
NUM_EPOCHS = 2

optimizer = optim.Adam(model.parameters(), lr=1e-3)


def train(epoch):
    batch_data_t = torch.FloatTensor(BATCH_SIZE, 784)
    if cuda:
        batch_data_t = batch_data_t.cuda()
    batch_data = Variable(batch_data_t, requires_grad=False)
    for i in range(0, training_data.size(0), BATCH_SIZE):
        optimizer.zero_grad()
        batch_data.data[:] = training_data[i:i + BATCH_SIZE]
        recon_batch_data, mu, logvar = model(batch_data)
        loss = loss_function(recon_batch_data, batch_data, mu, logvar)
        loss.backward()
        loss = loss.data[0]
        optimizer.step()
        print('Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.4f}'.format(
            epoch,
            i + BATCH_SIZE, training_data.size(0),
            float(i + BATCH_SIZE) / training_data.size(0) * 100,
            loss / BATCH_SIZE))


def test(epoch):
    test_loss = 0
    batch_data_t = torch.FloatTensor(TEST_BATCH_SIZE, 784)
    if cuda:
        batch_data_t = batch_data_t.cuda()
    batch_data = Variable(batch_data_t, volatile=True)
    for i in range(0, test_data.size(0), TEST_BATCH_SIZE):
        print('Testing model: {}/{}'.format(i, test_data.size(0)), end='\r')
        batch_data.data[:] = test_data[i:i + TEST_BATCH_SIZE]
        recon_batch_data, mu, logvar = model(batch_data)
        test_loss += loss_function(recon_batch_data, batch_data, mu, logvar)

    test_loss = test_loss.data[0] / test_data.size(0)
    print('TEST SET RESULTS:' + ' ' * 20)
    print('Average loss: {:.4f}'.format(test_loss))


for epoch in range(1, NUM_EPOCHS + 1):
    train(epoch)
    test(epoch)
```
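Not part of the commit, but a natural follow-up: since the prior over the latent space is $\mathcal{N}(0, I)$, new digits can be generated by decoding random latent vectors. A minimal sketch, assuming the script above has already run (it reuses `model`, `cuda`, and the same legacy `Variable` API):

```python
# Sketch: sample new digits from the trained decoder (assumes main.py above ran).
sample_z = Variable(torch.randn(64, 20), volatile=True)  # z ~ N(0, I); 20-d latent
if cuda:
    sample_z = sample_z.cuda()
generated = model.decode(sample_z)        # (64, 784) Bernoulli means in [0, 1]
images = generated.data.view(64, 28, 28)  # reshape to 28x28 MNIST-sized images
print(images.size())
```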

VAE/requirements.txt

3 additions, 0 deletions (new file)

```
torch
tqdm
six
```
