
Commit a3a01ed

Merge pull request #20 from y0ast/master
implement VAE
2 parents c88ba0a + b87368e commit a3a01ed

File tree

4 files changed: +154 -0 lines changed

.gitignore
VAE/README.md
VAE/main.py
VAE/requirements.txt

.gitignore

Lines changed: 1 addition & 0 deletions
```diff
@@ -1,2 +1,3 @@
 mnist/data
+VAE/data
 *.pyc
```

VAE/README.md

Lines changed: 13 additions & 0 deletions
# Basic VAE Example

This is an improved implementation of the paper [Stochastic Gradient VB and the Variational Auto-Encoder](http://arxiv.org/abs/1312.6114) by Kingma and Welling. It uses ReLUs and the Adam optimizer instead of sigmoids and Adagrad; these changes make the network converge much faster.
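
For reference, the objective maximized here is the variational lower bound (ELBO) from the paper, with the latent code drawn via the reparameterization trick:

```latex
\log p_\theta(x) \;\ge\;
\mathbb{E}_{q_\phi(z \mid x)}\!\left[ \log p_\theta(x \mid z) \right]
- D_{\mathrm{KL}}\!\left( q_\phi(z \mid x) \,\|\, p(z) \right),
\qquad z = \mu + \sigma \odot \epsilon,\quad \epsilon \sim \mathcal{N}(0, I)
```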

We reuse the data preparation script of the MNIST example:

```bash
pip install -r requirements.txt
python ../mnist/data.py
python main.py
```

VAE/main.py

Lines changed: 137 additions & 0 deletions
```python
from __future__ import print_function
import os
import torch
import torch.utils.data
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable

# Training settings
BATCH_SIZE = 150
TEST_BATCH_SIZE = 1000
NUM_EPOCHS = 2


cuda = torch.cuda.is_available()

print('====> Running with CUDA: {0}'.format(cuda))


assert os.path.exists('data/processed/training.pt'), \
    "Please run python ../mnist/data.py before starting the VAE."

# Data
print('====> Loading data')
with open('data/processed/training.pt', 'rb') as f:
    training_set = torch.load(f)
with open('data/processed/test.pt', 'rb') as f:
    test_set = torch.load(f)

# Flatten the 28x28 images to 784-vectors and rescale pixels to [0, 1].
training_data = training_set[0].view(-1, 784).div(255)
test_data = test_set[0].view(-1, 784).div(255)

del training_set
del test_set

if cuda:
    # .cuda() returns a copy rather than moving the tensor in place,
    # so assign the result back.
    training_data = training_data.cuda()
    test_data = test_data.cuda()

train_loader = torch.utils.data.DataLoader(training_data,
                                           batch_size=BATCH_SIZE,
                                           shuffle=True)

test_loader = torch.utils.data.DataLoader(test_data,
                                          batch_size=TEST_BATCH_SIZE)
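
# Note: the loaders index the data tensors directly, so each batch is a
# (batch_size, 784) tensor of flattened images; no labels are needed for
# this unsupervised model.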

# Model
print('====> Building model')


class VAE(nn.Container):
    def __init__(self):
        super(VAE, self).__init__()

        self.fc1 = nn.Linear(784, 400)
        self.fc21 = nn.Linear(400, 20)  # mean of q(z|x)
        self.fc22 = nn.Linear(400, 20)  # log-variance of q(z|x)
        self.fc3 = nn.Linear(20, 400)
        self.fc4 = nn.Linear(400, 784)

        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def encode(self, x):
        h1 = self.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def reparametrize(self, mu, logvar):
        std = logvar.mul(0.5).exp_()
        # Draw eps on the same device as std so sampling also works on GPU.
        eps = Variable(std.data.new(std.size()).normal_(),
                       requires_grad=False)
        return eps.mul(std).add_(mu)
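
    # The reparameterization trick above rewrites z ~ N(mu, sigma^2) as
    # z = mu + sigma * eps with eps ~ N(0, I); the randomness enters as an
    # input, so gradients can flow through mu and sigma back to the encoder.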

    def decode(self, z):
        h3 = self.relu(self.fc3(z))
        return self.sigmoid(self.fc4(h3))

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparametrize(mu, logvar)
        return self.decode(z), mu, logvar


model = VAE()
if cuda:
    model.cuda()

# Sum the binary cross-entropy over all pixels instead of averaging it.
reconstruction_function = nn.BCELoss(size_average=False)


def loss_function(recon_x, x, mu, logvar):
    BCE = reconstruction_function(recon_x, x)

    # Appendix B from VAE paper: 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)
    KLD = torch.sum(KLD_element).mul_(-0.5)

    return BCE + KLD
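
# With logvar = log(sigma^2), the KL term above is the closed form of
# KL(q(z|x) || p(z)) for the diagonal Gaussian posterior N(mu, sigma^2 I)
# against the N(0, I) prior:
#   KL = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)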


optimizer = optim.Adam(model.parameters(), lr=1e-3)


def train(epoch):
    model.train()
    train_loss = 0
    for batch in train_loader:
        batch = Variable(batch)

        optimizer.zero_grad()
        recon_batch, mu, logvar = model(batch)
        loss = loss_function(recon_batch, batch, mu, logvar)
        loss.backward()
        # Accumulate the Python number, not the Variable, so each batch's
        # graph can be freed.
        train_loss += loss.data[0]
        optimizer.step()

    print('====> Epoch: {} Loss: {:.4f}'.format(
        epoch,
        train_loss / training_data.size(0)))


def test(epoch):
    model.eval()
    test_loss = 0
    for batch in test_loader:
        batch = Variable(batch)

        recon_batch, mu, logvar = model(batch)
        test_loss += loss_function(recon_batch, batch, mu, logvar).data[0]

    test_loss /= test_data.size(0)
    print('====> Test set results: {:.4f}'.format(test_loss))


for epoch in range(1, NUM_EPOCHS + 1):
    train(epoch)
    test(epoch)
```
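
Once trained, the decoder doubles as a generator; a minimal sketch, assuming the `model` and `cuda` flag defined in the script above:

```python
# Draw latent codes from the N(0, I) prior; 20 matches the latent size above.
sample = Variable(torch.randn(64, 20))
if cuda:
    sample = sample.cuda()
generated = model.decode(sample)  # (64, 784) pixel intensities in [0, 1]
```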

VAE/requirements.txt

Lines changed: 3 additions & 0 deletions
```
torch
tqdm
six
```
