Skip to content

dcgan #11

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 18, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions dcgan/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Deep Convolutional Generative Adversarial Networks

This example implements the paper [Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks](http://arxiv.org/abs/1511.06434)

The implementation is very close to the Torch implementation [dcgan.torch](https://github.com/soumith/dcgan.torch)

Every 100 training iterations within an epoch, two image files are written to disk:
`real_samples.png` (the current batch of training images) and `fake_samples.png` (samples produced by the generator from a fixed noise batch).

After every epoch, models are saved to: `netG_epoch_%d.pth` and `netD_epoch_%d.pth`

```
usage: main.py [-h] --dataset DATASET --dataroot DATAROOT [--workers WORKERS]
[--batchSize BATCHSIZE] [--imageSize IMAGESIZE] [--nz NZ]
[--ngf NGF] [--ndf NDF] [--niter NITER] [--lr LR]
[--beta1 BETA1] [--cuda] [--netG NETG] [--netD NETD]

optional arguments:
-h, --help show this help message and exit
--dataset DATASET cifar10 | lsun | imagenet | folder | lfw
--dataroot DATAROOT path to dataset
--workers WORKERS number of data loading workers
--batchSize BATCHSIZE
input batch size
--imageSize IMAGESIZE
the height / width of the input image to network
--nz NZ size of the latent z vector
--ngf NGF
--ndf NDF
--niter NITER number of epochs to train for
--lr LR learning rate, default=0.0002
--beta1 BETA1 beta1 for adam. default=0.5
--cuda enables cuda
--netG NETG path to netG (to continue training)
--netD NETD path to netD (to continue training)
```
242 changes: 242 additions & 0 deletions dcgan/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,242 @@
from __future__ import print_function
import argparse
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.autograd import Variable


# Command-line options.
#
# Every numeric option declares an explicit ``type``. Without it argparse
# hands back the raw command-line string (only the *defaults* are numbers),
# so e.g. ``--ngf 128`` would make ``ngf * 8`` a string repetition and
# ``--batchSize 32`` would break the tensor allocations further down.
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', required=True, help='cifar10 | lsun | imagenet | folder | lfw ')
parser.add_argument('--dataroot', required=True, help='path to dataset')
parser.add_argument('--workers', type=int, help='number of data loading workers', default=2)
parser.add_argument('--batchSize', type=int, default=64, help='input batch size')
parser.add_argument('--imageSize', type=int, default=64, help='the height / width of the input image to network')
parser.add_argument('--nz', type=int, default=100, help='size of the latent z vector')
# ngf / ndf scale the channel counts of the generator / discriminator
# convolutions (see _netG and _netD).
parser.add_argument('--ngf', type=int, default=64)
parser.add_argument('--ndf', type=int, default=64)
parser.add_argument('--niter', type=int, default=25, help='number of epochs to train for')
parser.add_argument('--lr', type=float, default=0.0002, help='learning rate, default=0.0002')
parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
parser.add_argument('--cuda', action='store_true', help='enables cuda')
parser.add_argument('--netG', default='', help="path to netG (to continue training)")
parser.add_argument('--netD', default='', help="path to netD (to continue training)")
opt = parser.parse_args()
# Echo the effective configuration at start-up.
print(opt)

# Draw a seed for this run, print it, and feed it to both Python's and
# torch's random number generators.
opt.manualSeed = random.randint(1, 10000)
print("Random Seed: ", opt.manualSeed)
random.seed(opt.manualSeed)
torch.manual_seed(opt.manualSeed)

cudnn.benchmark = True

# Nudge the user when a CUDA device is present but --cuda was not passed.
if torch.cuda.is_available():
    if not opt.cuda:
        print("WARNING: You have a CUDA device, so you should probably run with --cuda")

def _resize_crop_transform(size):
    """Return the Scale -> CenterCrop -> ToTensor -> Normalize pipeline.

    Shared by the folder-style datasets and LSUN, which previously spelled
    out the same four transforms twice. ``Normalize`` is given 0.5 for every
    per-channel mean and std entry.
    """
    return transforms.Compose([
        transforms.Scale(size),
        transforms.CenterCrop(size),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])


# Build the training dataset selected by --dataset.
if opt.dataset in ['imagenet', 'folder', 'lfw']:
    # folder dataset
    dataset = dset.ImageFolder(root=opt.dataroot,
                               transform=_resize_crop_transform(opt.imageSize))
elif opt.dataset == 'lsun':
    opt.workers = 1  # need to do this because of an lmdb / python bug
    dataset = dset.LSUN(db_path=opt.dataroot, classes=['bedroom_train'],
                        transform=_resize_crop_transform(opt.imageSize))
elif opt.dataset == 'cifar10':
    # NOTE(review): the ToTensor -> ToPILImage round trip precedes Scale,
    # presumably because Scale needs a PIL image -- confirm against the
    # torchvision version in use.
    dataset = dset.CIFAR10(root=opt.dataroot, download=True,
                           transform=transforms.Compose([
                               transforms.ToTensor(),
                               transforms.ToPILImage(),
                               transforms.Scale(opt.imageSize),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ]))
else:
    # Previously an unknown name fell through every branch and surfaced as a
    # NameError on ``dataset`` below; fail with an actionable message instead.
    raise ValueError(
        "unknown --dataset %r; expected one of: "
        "cifar10, lsun, imagenet, folder, lfw" % (opt.dataset,))
# Sanity check kept from the original script (truthiness of the dataset object).
assert dataset
dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batchSize,
                                         shuffle=True, num_workers=int(opt.workers))

# Module-level hyper-parameters read by _netG / _netD:
#   nz  - size of the latent z vector
#   ngf - channel multiplier used throughout the generator
#   ndf - channel multiplier used throughout the discriminator
nz, ngf, ndf = opt.nz, opt.ngf, opt.ndf
nc = 3  # number of image channels produced by netG and consumed by netD

# custom weights initialization called on netG and netD
def weights_init(m):
    """Re-initialise the parameters of one module in place.

    The module is dispatched on its class name:

    * name contains ``Conv``      -> weight drawn from normal(0.0, 0.02)
    * name contains ``BatchNorm`` -> weight drawn from normal(1.0, 0.02),
      bias filled with 0
    * anything else is left untouched.

    Written to be passed to ``net.apply(...)``, as done for netG and netD.
    """
    layer_name = m.__class__.__name__
    if 'Conv' in layer_name:
        m.weight.data.normal_(0.0, 0.02)
    elif 'BatchNorm' in layer_name:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

class _netG(nn.Container):
    """Generator: a stack of transposed convolutions ending in ``Tanh``.

    Channel counts come from the module-level ``nz``, ``ngf`` and ``nc``.
    The state sizes noted below are the ones annotated in the original
    layer list (Z -> (ngf*8) x 4 x 4 -> ... -> (nc) x 64 x 64).
    """

    def __init__(self):
        super(_netG, self).__init__()
        # (in_channels, out_channels, stride, padding) for every
        # ConvTranspose2d -> BatchNorm2d -> ReLU stage, in order.
        stages = [
            (nz, ngf * 8, 1, 0),       # input is Z  -> (ngf*8) x 4 x 4
            (ngf * 8, ngf * 4, 2, 1),  #             -> (ngf*4) x 8 x 8
            (ngf * 4, ngf * 2, 2, 1),  #             -> (ngf*2) x 16 x 16
            (ngf * 2, ngf, 2, 1),      #             -> (ngf) x 32 x 32
        ]
        layers = []
        for in_ch, out_ch, stride, pad in stages:
            layers.append(nn.ConvTranspose2d(in_ch, out_ch, 4, stride, pad, bias=False))
            layers.append(nn.BatchNorm2d(out_ch))
            layers.append(nn.ReLU(True))
        # Output stage: (ngf) x 32 x 32 -> (nc) x 64 x 64, no batch norm.
        layers.append(nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False))
        layers.append(nn.Tanh())
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Run ``self.main`` on ``input`` through ``data_parallel``.

        All visible CUDA devices are used when the input holds a CUDA float
        tensor and more than one device is present; otherwise ``gpu_ids`` is
        passed as ``None``.
        """
        use_all_gpus = (isinstance(input.data, torch.cuda.FloatTensor)
                        and torch.cuda.device_count() > 1)
        if use_all_gpus:
            gpu_ids = range(torch.cuda.device_count())
        else:
            gpu_ids = None
        return nn.parallel.data_parallel(self.main, input, gpu_ids)

# Create the generator: resume from the --netG checkpoint when one is given,
# otherwise build a fresh network and apply the custom weight init.
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
if opt.netG:
    netG = torch.load(opt.netG)
else:
    netG = _netG()
    netG.apply(weights_init)

print(netG)

class _netD(nn.Container):
    """Discriminator: strided convolutions ending in a ``Sigmoid`` score.

    Channel counts come from the module-level ``nc`` and ``ndf``. The state
    sizes noted below are the ones annotated in the original layer list
    ((nc) x 64 x 64 -> ... -> (ndf*8) x 4 x 4 -> one value per sample).
    """

    def __init__(self):
        super(_netD, self).__init__()
        layers = [
            # input is (nc) x 64 x 64; the first stage has no batch norm
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Conv2d -> BatchNorm2d -> LeakyReLU stages, each doubling the
        # channel count:
        #   (ndf) x 32 x 32 -> (ndf*2) x 16 x 16 -> (ndf*4) x 8 x 8
        #   -> (ndf*8) x 4 x 4
        for factor in (1, 2, 4):
            layers.append(nn.Conv2d(ndf * factor, ndf * factor * 2, 4, 2, 1, bias=False))
            layers.append(nn.BatchNorm2d(ndf * factor * 2))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        # Final stage collapses (ndf*8) x 4 x 4 to a single channel.
        layers.append(nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False))
        layers.append(nn.Sigmoid())
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Run ``self.main`` on ``input`` and return the result as (-1, 1)."""
        # TODO: enable multi-GPU dispatch here the same way _netG.forward
        # does (use every device when the input is a CUDA float tensor and
        # torch.cuda.device_count() > 1). Until then gpu_ids stays None.
        gpu_ids = None
        scores = nn.parallel.data_parallel(self.main, input, gpu_ids)
        return scores.view(-1, 1)

# Create the discriminator: resume from the --netD checkpoint when one is
# given, otherwise build a fresh network and apply the custom weight init.
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
if opt.netD:
    netD = torch.load(opt.netD)
else:
    netD = _netD()
    netD.apply(weights_init)
print(netD)

# Binary cross-entropy between the discriminator output and the 0/1 labels.
criterion = nn.BCELoss()

# Buffers allocated once and refilled in place by the training loop.
input = torch.FloatTensor(opt.batchSize, 3, opt.imageSize, opt.imageSize)
noise = torch.FloatTensor(opt.batchSize, nz, 1, 1)
# Noise batch sampled once and reused for every fake_samples.png snapshot.
fixed_noise = torch.FloatTensor(opt.batchSize, nz, 1, 1).normal_(0, 1)
label = torch.FloatTensor(opt.batchSize)
real_label = 1
fake_label = 0

if opt.cuda:
    # Move the networks and the loss first, then the buffers.
    netD.cuda()
    netG.cuda()
    criterion.cuda()
    input, label, noise, fixed_noise = (
        buf.cuda() for buf in (input, label, noise, fixed_noise))

# Wrap every buffer in an autograd Variable.
input, label, noise, fixed_noise = (
    Variable(buf) for buf in (input, label, noise, fixed_noise))

# setup optimizer: one Adam instance per network, same hyper-parameters
adam_betas = (opt.beta1, 0.999)
optimizerD = optim.Adam(netD.parameters(), lr=opt.lr, betas=adam_betas)
optimizerG = optim.Adam(netG.parameters(), lr=opt.lr, betas=adam_betas)

# Training loop: for every batch, update the discriminator on a real and a
# generated batch, then update the generator; snapshot sample images every
# 100 iterations and checkpoint both networks after every epoch.

# len() of a DataLoader is already the number of batches per epoch, so it is
# reported as-is (the original divided it by the batch size a second time).
num_batches = len(dataloader)

for epoch in range(opt.niter):
    for i, data in enumerate(dataloader, 0):
        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        # train with real
        netD.zero_grad()
        real_cpu, _ = data
        # The last batch of an epoch can be smaller than opt.batchSize, so
        # size the reusable buffers to the batch actually received instead of
        # copying into fixed-size tensors (which fails on a size mismatch).
        batch_size = real_cpu.size(0)
        input.data.resize_(real_cpu.size()).copy_(real_cpu)
        label.data.resize_(batch_size).fill_(real_label)

        output = netD(input)
        errD_real = criterion(output, label)
        errD_real.backward()

        # train with fake
        noise.data.resize_(batch_size, nz, 1, 1)
        noise.data.normal_(0, 1)
        fake = netG(noise)
        # Copy only the generated values into the input buffer, so this
        # backward pass does not propagate into the generator.
        input.data.resize_(fake.data.size()).copy_(fake.data)
        label.data.fill_(fake_label)
        output = netD(input)
        errD_fake = criterion(output, label)
        errD_fake.backward()
        errD = errD_real + errD_fake
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.data.fill_(real_label)  # fake labels are real for generator cost
        noise.data.normal_(0, 1)
        fake = netG(noise)
        output = netD(fake)
        errG = criterion(output, label)
        errG.backward()
        optimizerG.step()

        print('[%d/%d][%d/%d] Loss_D: %f Loss_G: %f'
              % (epoch, opt.niter, i, num_batches,
                 errD.data[0], errG.data[0]))
        if i % 100 == 0:
            # Snapshot the current real batch and the generator's output on
            # the fixed noise batch.
            vutils.save_image(real_cpu, 'real_samples.png')
            fake = netG(fixed_noise)
            vutils.save_image(fake.data, 'fake_samples.png')

    # do checkpointing
    torch.save(netG, 'netG_epoch_%d.pth' % epoch)
    torch.save(netD, 'netD_epoch_%d.pth' % epoch)
2 changes: 2 additions & 0 deletions dcgan/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
torch
torchvision