diff --git a/dcgan/README.md b/dcgan/README.md new file mode 100644 index 0000000000..4fd1d8f720 --- /dev/null +++ b/dcgan/README.md @@ -0,0 +1,36 @@ +# Deep Convolution Generative Adversarial Networks + +This example implements the paper [Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks](http://arxiv.org/abs/1511.06434) + +The implementation is very close to the Torch implementation [dcgan.torch](https://github.com/soumith/dcgan.torch) + +After every 100 training iterations, the files `real_samples.png` and `fake_samples.png` are written to disk +with the samples from the generative model. + +After every epoch, models are saved to: `netG_epoch_%d.pth` and `netD_epoch_%d.pth` + +``` +usage: main.py [-h] --dataset DATASET --dataroot DATAROOT [--workers WORKERS] + [--batchSize BATCHSIZE] [--imageSize IMAGESIZE] [--nz NZ] + [--ngf NGF] [--ndf NDF] [--niter NITER] [--lr LR] + [--beta1 BETA1] [--cuda] [--netG NETG] [--netD NETD] + +optional arguments: + -h, --help show this help message and exit + --dataset DATASET cifar10 | lsun | imagenet | folder | lfw + --dataroot DATAROOT path to dataset + --workers WORKERS number of data loading workers + --batchSize BATCHSIZE + input batch size + --imageSize IMAGESIZE + the height / width of the input image to network + --nz NZ size of the latent z vector + --ngf NGF + --ndf NDF + --niter NITER number of epochs to train for + --lr LR learning rate, default=0.0002 + --beta1 BETA1 beta1 for adam. default=0.5 + --cuda enables cuda + --netG NETG path to netG (to continue training) + --netD NETD path to netD (to continue training) +``` diff --git a/dcgan/main.py b/dcgan/main.py new file mode 100644 index 0000000000..8814c81f54 --- /dev/null +++ b/dcgan/main.py @@ -0,0 +1,242 @@ +from __future__ import print_function +import argparse +import random +import torch +import torch.nn as nn +import torch.nn.parallel +import torch.backends.cudnn as cudnn +import torch.optim as optim +import torch.utils.data +import torchvision.datasets as dset +import torchvision.transforms as transforms +import torchvision.utils as vutils +from torch.autograd import Variable + + +parser = argparse.ArgumentParser() +parser.add_argument('--dataset', required=True, help='cifar10 | lsun | imagenet | folder | lfw ') +parser.add_argument('--dataroot', required=True, help='path to dataset') +parser.add_argument('--workers', help='number of data loading workers', default=2) +parser.add_argument('--batchSize', default=64, help='input batch size') +parser.add_argument('--imageSize', default=64, help='the height / width of the input image to network') +parser.add_argument('--nz', default=100, help='size of the latent z vector') +parser.add_argument('--ngf', default=64) +parser.add_argument('--ndf', default=64) +parser.add_argument('--niter', default=25, help='number of epochs to train for') +parser.add_argument('--lr', default=0.0002, help='learning rate, default=0.0002') +parser.add_argument('--beta1', default=0.5, help='beta1 for adam. default=0.5') +parser.add_argument('--cuda' , action='store_true', help='enables cuda') +parser.add_argument('--netG', default='', help="path to netG (to continue training)") +parser.add_argument('--netD', default='', help="path to netD (to continue training)") +opt = parser.parse_args() +print(opt) + +opt.manualSeed = random.randint(1, 10000) # fix seed +print("Random Seed: ", opt.manualSeed) +random.seed(opt.manualSeed) +torch.manual_seed(opt.manualSeed) + +cudnn.benchmark = True + +if torch.cuda.is_available() and not opt.cuda: + print("WARNING: You have a CUDA device, so you should probably run with --cuda") + +if opt.dataset in ['imagenet', 'folder', 'lfw']: + # folder dataset + dataset = dset.ImageFolder(root=opt.dataroot, + transform=transforms.Compose([ + transforms.Scale(opt.imageSize), + transforms.CenterCrop(opt.imageSize), + transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), + ])) +elif opt.dataset == 'lsun': + opt.workers = 1 # need to do this because of an lmdb / python bug + dataset = dset.LSUN(db_path=opt.dataroot, classes=['bedroom_train'], + transform=transforms.Compose([ + transforms.Scale(opt.imageSize), + transforms.CenterCrop(opt.imageSize), + transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), + ])) +elif opt.dataset == 'cifar10': + dataset = dset.CIFAR10(root=opt.dataroot, download=True, + transform=transforms.Compose([ + transforms.ToTensor(), + transforms.ToPILImage(), + transforms.Scale(opt.imageSize), + transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), + ]) + ) +assert dataset +dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batchSize, + shuffle=True, num_workers=int(opt.workers)) + +nz = opt.nz +ngf = opt.ngf +ndf = opt.ndf +nc = 3 + +# custom weights initialization called on netG and netD +def weights_init(m): + classname = m.__class__.__name__ + if classname.find('Conv') != -1: + m.weight.data.normal_(0.0, 0.02) + elif classname.find('BatchNorm') != -1: + m.weight.data.normal_(1.0, 0.02) + m.bias.data.fill_(0) + +class _netG(nn.Container): + def __init__(self): + super(_netG, self).__init__() + self.main = nn.Sequential( + # input is Z, going into a convolution + nn.ConvTranspose2d( nz, ngf * 8, 4, 1, 0, bias=False), + nn.BatchNorm2d(ngf * 8), + nn.ReLU(True), + # state size. (ngf*8) x 4 x 4 + nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False), + nn.BatchNorm2d(ngf * 4), + nn.ReLU(True), + # state size. (ngf*4) x 8 x 8 + nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False), + nn.BatchNorm2d(ngf * 2), + nn.ReLU(True), + # state size. (ngf*2) x 16 x 16 + nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False), + nn.BatchNorm2d(ngf), + nn.ReLU(True), + # state size. (ngf) x 32 x 32 + nn.ConvTranspose2d( ngf, nc, 4, 2, 1, bias=False), + nn.Tanh() + # state size. (nc) x 64 x 64 + ) + def forward(self, input): + gpu_ids = None + if isinstance(input.data, torch.cuda.FloatTensor) \ + and torch.cuda.device_count() > 1: + gpu_ids = range(torch.cuda.device_count()) + return nn.parallel.data_parallel(self.main, input, gpu_ids) + +if opt.netG != '': + netG = torch.load(opt.netG) +else: + netG = _netG() + netG.apply(weights_init) + +print(netG) + +class _netD(nn.Container): + def __init__(self): + super(_netD, self).__init__() + self.main = nn.Sequential( + # input is (nc) x 64 x 64 + nn.Conv2d(nc, ndf, 4, 2, 1, bias=False), + nn.LeakyReLU(0.2, inplace=True), + # state size. (ndf) x 32 x 32 + nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False), + nn.BatchNorm2d(ndf * 2), + nn.LeakyReLU(0.2, inplace=True), + # state size. (ndf*2) x 16 x 16 + nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False), + nn.BatchNorm2d(ndf * 4), + nn.LeakyReLU(0.2, inplace=True), + # state size. (ndf*4) x 8 x 8 + nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False), + nn.BatchNorm2d(ndf * 8), + nn.LeakyReLU(0.2, inplace=True), + # state size. (ndf*8) x 4 x 4 + nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False), + nn.Sigmoid() + ) + def forward(self, input): + gpu_ids = None + # TODO: enable this + #if isinstance(input.data, torch.cuda.FloatTensor) \ + # and torch.cuda.device_count() > 1: + # gpu_ids = range(torch.cuda.device_count()) + output = nn.parallel.data_parallel(self.main, input, gpu_ids) + return output.view(-1, 1) + +if opt.netD != '': + netD = torch.load(opt.netD) +else: + netD = _netD() + netD.apply(weights_init) +print(netD) + +criterion = nn.BCELoss() + +input = torch.FloatTensor(opt.batchSize, 3, opt.imageSize, opt.imageSize) +noise = torch.FloatTensor(opt.batchSize, nz, 1, 1) +fixed_noise = torch.FloatTensor(opt.batchSize, nz, 1, 1).normal_(0, 1) +label = torch.FloatTensor(opt.batchSize) +real_label = 1 +fake_label = 0 + +if opt.cuda: + netD.cuda() + netG.cuda() + criterion.cuda() + input, label = input.cuda(), label.cuda() + noise, fixed_noise = noise.cuda(), fixed_noise.cuda() + +input = Variable(input) +label = Variable(label) +noise = Variable(noise) +fixed_noise = Variable(fixed_noise) + +# setup optimizer +optimizerD = optim.Adam(netD.parameters(), lr = opt.lr, betas = (opt.beta1, 0.999)) +optimizerG = optim.Adam(netG.parameters(), lr = opt.lr, betas = (opt.beta1, 0.999)) + +for epoch in range(opt.niter): + for i, data in enumerate(dataloader, 0): + ############################ + # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z))) + ########################### + # train with real + netD.zero_grad() + real_cpu, _ = data + input.data.copy_(real_cpu) + label.data.fill_(real_label) + + output = netD(input) + errD_real = criterion(output, label) + errD_real.backward() + + # train with fake + noise.data.normal_(0, 1) + fake = netG(noise) + input.data.copy_(fake.data) + label.data.fill_(fake_label) + output = netD(input) + errD_fake = criterion(output, label) + errD_fake.backward() + errD = errD_real + errD_fake + optimizerD.step() + + ############################ + # (2) Update G network: maximize log(D(G(z))) + ########################### + netG.zero_grad() + label.data.fill_(real_label) # fake labels are real for generator cost + noise.data.normal_(0, 1) + fake = netG(noise) + output = netD(fake) + errG = criterion(output, label) + errG.backward() + optimizerG.step() + + print('[%d/%d][%d/%d] Loss_D: %f Loss_G: %f' + % (epoch, opt.niter, i, len(dataloader)/opt.batchSize, + errD.data[0], errG.data[0])) + if i % 100 == 0: + vutils.save_image(real_cpu, 'real_samples.png') + fake = netG(fixed_noise) + vutils.save_image(fake.data, 'fake_samples.png') + + # do checkpointing + torch.save(netG, 'netG_epoch_%d.pth' % epoch) + torch.save(netD, 'netD_epoch_%d.pth' % epoch) diff --git a/dcgan/requirements.txt b/dcgan/requirements.txt new file mode 100644 index 0000000000..ac988bdf84 --- /dev/null +++ b/dcgan/requirements.txt @@ -0,0 +1,2 @@ +torch +torchvision