Skip to content

Commit 6937fce

Browse files
committed
dcgan
1 parent 5c32625 commit 6937fce

File tree

3 files changed

+280
-0
lines changed

3 files changed

+280
-0
lines changed

dcgan/README.md

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# Deep Convolution Generative Adversarial Networks
2+
3+
This example implements the paper [Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks](http://arxiv.org/abs/1511.06434)
4+
5+
The implementation is very close to the Torch implementation [dcgan.torch](https://github.com/soumith/dcgan.torch)
6+
7+
After every 100 training iterations, the files `real_samples.png` and `fake_samples.png` are written to disk
8+
with the samples from the generative model.
9+
10+
After every epoch, models are saved to: `netG_epoch_%d.pth` and `netD_epoch_%d.pth`
11+
12+
```
13+
usage: main.py [-h] --dataset DATASET --dataroot DATAROOT [--workers WORKERS]
14+
[--batchSize BATCHSIZE] [--imageSize IMAGESIZE] [--nz NZ]
15+
[--ngf NGF] [--ndf NDF] [--niter NITER] [--lr LR]
16+
[--beta1 BETA1] [--cuda] [--netG NETG] [--netD NETD]
17+
18+
optional arguments:
19+
-h, --help show this help message and exit
20+
--dataset DATASET cifar10 | lsun | imagenet | folder | lfw
21+
--dataroot DATAROOT path to dataset
22+
--workers WORKERS number of data loading workers
23+
--batchSize BATCHSIZE
24+
input batch size
25+
--imageSize IMAGESIZE
26+
the height / width of the input image to network
27+
--nz NZ size of the latent z vector
28+
--ngf NGF
29+
--ndf NDF
30+
--niter NITER number of epochs to train for
31+
--lr LR learning rate, default=0.0002
32+
--beta1 BETA1 beta1 for adam. default=0.5
33+
--cuda enables cuda
34+
--netG NETG path to netG (to continue training)
35+
--netD NETD path to netD (to continue training)
36+
```

dcgan/main.py

Lines changed: 242 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,242 @@
1+
from __future__ import print_function
2+
import argparse
3+
import random
4+
import torch
5+
import torch.nn as nn
6+
import torch.nn.parallel
7+
import torch.backends.cudnn as cudnn
8+
import torch.optim as optim
9+
import torch.utils.data
10+
import torchvision.datasets as dset
11+
import torchvision.transforms as transforms
12+
import torchvision.utils as vutils
13+
from torch.autograd import Variable
14+
15+
16+
parser = argparse.ArgumentParser()
17+
parser.add_argument('--dataset', required=True, help='cifar10 | lsun | imagenet | folder | lfw ')
18+
parser.add_argument('--dataroot', required=True, help='path to dataset')
19+
parser.add_argument('--workers', help='number of data loading workers', default=2)
20+
parser.add_argument('--batchSize', default=64, help='input batch size')
21+
parser.add_argument('--imageSize', default=64, help='the height / width of the input image to network')
22+
parser.add_argument('--nz', default=100, help='size of the latent z vector')
23+
parser.add_argument('--ngf', default=64)
24+
parser.add_argument('--ndf', default=64)
25+
parser.add_argument('--niter', default=25, help='number of epochs to train for')
26+
parser.add_argument('--lr', default=0.0002, help='learning rate, default=0.0002')
27+
parser.add_argument('--beta1', default=0.5, help='beta1 for adam. default=0.5')
28+
parser.add_argument('--cuda' , action='store_true', help='enables cuda')
29+
parser.add_argument('--netG', default='', help="path to netG (to continue training)")
30+
parser.add_argument('--netD', default='', help="path to netD (to continue training)")
31+
opt = parser.parse_args()
32+
print(opt)
33+
34+
opt.manualSeed = random.randint(1, 10000) # fix seed
35+
print("Random Seed: ", opt.manualSeed)
36+
random.seed(opt.manualSeed)
37+
torch.manual_seed(opt.manualSeed)
38+
39+
cudnn.benchmark = True
40+
41+
if torch.cuda.is_available() and not opt.cuda:
42+
print("WARNING: You have a CUDA device, so you should probably run with --cuda")
43+
44+
if opt.dataset in ['imagenet', 'folder', 'lfw']:
45+
# folder dataset
46+
dataset = dset.ImageFolder(root=opt.dataroot,
47+
transform=transforms.Compose([
48+
transforms.Scale(opt.imageSize),
49+
transforms.CenterCrop(opt.imageSize),
50+
transforms.ToTensor(),
51+
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
52+
]))
53+
elif opt.dataset == 'lsun':
54+
opt.workers = 1 # need to do this because of an lmdb / python bug
55+
dataset = dset.LSUN(db_path=opt.dataroot, classes=['bedroom_train'],
56+
transform=transforms.Compose([
57+
transforms.Scale(opt.imageSize),
58+
transforms.CenterCrop(opt.imageSize),
59+
transforms.ToTensor(),
60+
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
61+
]))
62+
elif opt.dataset == 'cifar10':
63+
dataset = dset.CIFAR10(root=opt.dataroot, download=True,
64+
transform=transforms.Compose([
65+
transforms.ToTensor(),
66+
transforms.ToPILImage(),
67+
transforms.Scale(opt.imageSize),
68+
transforms.ToTensor(),
69+
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
70+
])
71+
)
72+
assert dataset
73+
dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batchSize,
74+
shuffle=True, num_workers=int(opt.workers))
75+
76+
nz = opt.nz
77+
ngf = opt.ngf
78+
ndf = opt.ndf
79+
nc = 3
80+
81+
# custom weights initialization called on netG and netD
82+
def weights_init(m):
83+
classname = m.__class__.__name__
84+
if classname.find('Conv') != -1:
85+
m.weight.data.normal_(0.0, 0.02)
86+
elif classname.find('BatchNorm') != -1:
87+
m.weight.data.normal_(1.0, 0.02)
88+
m.bias.data.fill_(0)
89+
90+
class _netG(nn.Container):
91+
def __init__(self):
92+
super(_netG, self).__init__()
93+
self.main = nn.Sequential(
94+
# input is Z, going into a convolution
95+
nn.ConvTranspose2d( nz, ngf * 8, 4, 1, 0, bias=False),
96+
nn.BatchNorm2d(ngf * 8),
97+
nn.ReLU(True),
98+
# state size. (ngf*8) x 4 x 4
99+
nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
100+
nn.BatchNorm2d(ngf * 4),
101+
nn.ReLU(True),
102+
# state size. (ngf*4) x 8 x 8
103+
nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
104+
nn.BatchNorm2d(ngf * 2),
105+
nn.ReLU(True),
106+
# state size. (ngf*2) x 16 x 16
107+
nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
108+
nn.BatchNorm2d(ngf),
109+
nn.ReLU(True),
110+
# state size. (ngf) x 32 x 32
111+
nn.ConvTranspose2d( ngf, nc, 4, 2, 1, bias=False),
112+
nn.Tanh()
113+
# state size. (nc) x 64 x 64
114+
)
115+
def forward(self, input):
116+
gpu_ids = None
117+
if isinstance(input.data, torch.cuda.FloatTensor) \
118+
and torch.cuda.device_count() > 1:
119+
gpu_ids = range(torch.cuda.device_count())
120+
return nn.parallel.data_parallel(self.main, input, gpu_ids)
121+
122+
if opt.netG != '':
123+
netG = torch.load(opt.netG)
124+
else:
125+
netG = _netG()
126+
netG.apply(weights_init)
127+
128+
print(netG)
129+
130+
class _netD(nn.Container):
131+
def __init__(self):
132+
super(_netD, self).__init__()
133+
self.main = nn.Sequential(
134+
# input is (nc) x 64 x 64
135+
nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
136+
nn.LeakyReLU(0.2, inplace=True),
137+
# state size. (ndf) x 32 x 32
138+
nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
139+
nn.BatchNorm2d(ndf * 2),
140+
nn.LeakyReLU(0.2, inplace=True),
141+
# state size. (ndf*2) x 16 x 16
142+
nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
143+
nn.BatchNorm2d(ndf * 4),
144+
nn.LeakyReLU(0.2, inplace=True),
145+
# state size. (ndf*4) x 8 x 8
146+
nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
147+
nn.BatchNorm2d(ndf * 8),
148+
nn.LeakyReLU(0.2, inplace=True),
149+
# state size. (ndf*8) x 4 x 4
150+
nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
151+
nn.Sigmoid()
152+
)
153+
def forward(self, input):
154+
gpu_ids = None
155+
# TODO: enable this
156+
#if isinstance(input.data, torch.cuda.FloatTensor) \
157+
# and torch.cuda.device_count() > 1:
158+
# gpu_ids = range(torch.cuda.device_count())
159+
output = nn.parallel.data_parallel(self.main, input, gpu_ids)
160+
return output.view(-1, 1)
161+
162+
if opt.netD != '':
163+
netD = torch.load(opt.netD)
164+
else:
165+
netD = _netD()
166+
netD.apply(weights_init)
167+
print(netD)
168+
169+
criterion = nn.BCELoss()
170+
171+
input = torch.FloatTensor(opt.batchSize, 3, opt.imageSize, opt.imageSize)
172+
noise = torch.FloatTensor(opt.batchSize, nz, 1, 1)
173+
fixed_noise = torch.FloatTensor(opt.batchSize, nz, 1, 1).normal_(0, 1)
174+
label = torch.FloatTensor(opt.batchSize)
175+
real_label = 1
176+
fake_label = 0
177+
178+
if opt.cuda:
179+
netD.cuda()
180+
netG.cuda()
181+
criterion.cuda()
182+
input, label = input.cuda(), label.cuda()
183+
noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
184+
185+
input = Variable(input)
186+
label = Variable(label)
187+
noise = Variable(noise)
188+
fixed_noise = Variable(fixed_noise)
189+
190+
# setup optimizer
191+
optimizerD = optim.Adam(netD.parameters(), lr = opt.lr, betas = (opt.beta1, 0.999))
192+
optimizerG = optim.Adam(netG.parameters(), lr = opt.lr, betas = (opt.beta1, 0.999))
193+
194+
for epoch in range(opt.niter):
195+
for i, data in enumerate(dataloader, 0):
196+
############################
197+
# (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
198+
###########################
199+
# train with real
200+
netD.zero_grad()
201+
real_cpu, _ = data
202+
input.data.copy_(real_cpu)
203+
label.data.fill_(real_label)
204+
205+
output = netD(input)
206+
errD_real = criterion(output, label)
207+
errD_real.backward()
208+
209+
# train with fake
210+
noise.data.normal_(0, 1)
211+
fake = netG(noise)
212+
input.data.copy_(fake.data)
213+
label.data.fill_(fake_label)
214+
output = netD(input)
215+
errD_fake = criterion(output, label)
216+
errD_fake.backward()
217+
errD = errD_real + errD_fake
218+
optimizerD.step()
219+
220+
############################
221+
# (2) Update G network: maximize log(D(G(z)))
222+
###########################
223+
netG.zero_grad()
224+
label.data.fill_(real_label) # fake labels are real for generator cost
225+
noise.data.normal_(0, 1)
226+
fake = netG(noise)
227+
output = netD(fake)
228+
errG = criterion(output, label)
229+
errG.backward()
230+
optimizerG.step()
231+
232+
print('[%d/%d][%d/%d] Loss_D: %f Loss_G: %f'
233+
% (epoch, opt.niter, i, len(dataloader)/opt.batchSize,
234+
errD.data[0], errG.data[0]))
235+
if i % 100 == 0:
236+
vutils.save_image(real_cpu, 'real_samples.png')
237+
fake = netG(fixed_noise)
238+
vutils.save_image(fake.data, 'fake_samples.png')
239+
240+
# do checkpointing
241+
torch.save(netG, 'netG_epoch_%d.pth' % epoch)
242+
torch.save(netD, 'netD_epoch_%d.pth' % epoch)

dcgan/requirements.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
torch
2+
torchvision

0 commit comments

Comments
 (0)