Update ImageNet training code #16

Merged: 1 commit, Dec 3, 2016
299 changes: 232 additions & 67 deletions imagenet/main.py
@@ -1,97 +1,262 @@
import argparse
import os
import shutil
import time

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets

import resnet


parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
parser.add_argument('--data', metavar='PATH', required=True,
                    help='path to dataset')
parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18',
                    help='model architecture: resnet18 | resnet34 | ...'
                         ' (default: resnet18)')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('--epochs', default=90, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=256, type=int,
                    metavar='N', help='mini-batch size (default: 256)')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                    metavar='LR', help='initial learning rate')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
                    metavar='W', help='weight decay (default: 1e-4)')
parser.add_argument('--print-freq', '-p', default=10, type=int,
                    metavar='N', help='print frequency (default: 10)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')

best_prec1 = 0
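
# Usage sketch (the paths below are illustrative, not part of the repo):
#
#   python main.py --data /path/to/imagenet --arch resnet34 -b 256 -j 4
#
# --data must point at a directory containing 'train' and 'val' subfolders
# laid out one class per directory, as torchvision's ImageFolder expects.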


def main():
    global args, best_prec1
    args = parser.parse_args()

    # create model
    if args.arch.startswith('resnet'):
        print("=> creating model '{}'".format(args.arch))
        model = resnet.__dict__[args.arch]()
        model.cuda()
    else:
        parser.error('invalid architecture: {}'.format(args.arch))

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print(" | resuming from epoch {}".format(args.start_epoch))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(traindir, transforms.Compose([
            transforms.RandomSizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True)

    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(valdir, transforms.Compose([
            transforms.Scale(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    # parallelize model across all visible GPUs
    model = torch.nn.DataParallel(model)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        model.train()
        train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        model.eval()
        prec1 = validate(val_loader, model, criterion)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint({
            'epoch': epoch,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
        }, is_best)
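
# Note on the parallelization step above: torch.nn.DataParallel splits each
# batch along dim 0 across all visible GPUs and gathers the outputs on the
# first one, so the effective per-GPU batch is roughly batch_size / num_gpus.
# On a single GPU it reduces to a thin wrapper around the model.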


def train(train_loader, model, criterion, optimizer, epoch):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    end = time.time()
    for i, (input, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        target = target.cuda(async=True)
        input_var = torch.autograd.Variable(input)
        target_var = torch.autograd.Variable(target)

        # compute output
        output = model(input_var)
        loss = criterion(output, target_var)

        # measure accuracy and record loss
        prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
        losses.update(loss.data[0])
        top1.update(prec1[0])
        top5.update(prec5[0])

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                      epoch, i, len(train_loader), batch_time=batch_time,
                      data_time=data_time, loss=losses, top1=top1, top5=top5))
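
# The async=True device copy overlaps the target transfer with compute when
# the source tensor lives in pinned memory, which is why the DataLoaders in
# main() are built with pin_memory=True.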


def validate(val_loader, model, criterion):
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    end = time.time()
    for i, (input, target) in enumerate(val_loader):
        target = target.cuda(async=True)
        input_var = torch.autograd.Variable(input, volatile=True)
        target_var = torch.autograd.Variable(target, volatile=True)

        # compute output
        output = model(input_var)
        loss = criterion(output, target_var)

        # measure accuracy and record loss
        prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
        losses.update(loss.data[0])
        top1.update(prec1[0])
        top5.update(prec5[0])

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Test: [{0}/{1}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                      i, len(val_loader), batch_time=batch_time, loss=losses,
                      top1=top1, top5=top5))

    return top1.avg
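
# volatile=True marks these Variables as inference-only, so autograd records
# no graph during validation and memory use stays flat.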


def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')
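
# Both files land in the working directory: 'checkpoint.pth.tar' is
# overwritten every epoch, while 'model_best.pth.tar' tracks the best
# validation prec@1 so far; pass --resume checkpoint.pth.tar to pick up an
# interrupted run.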


class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.n = 0

    def update(self, val):
        self.val = val
        self.sum += val
        self.n += 1
        self.avg = self.sum / self.n
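
# Example (hypothetical values): after update(2.0) then update(4.0),
# val == 4.0 and avg == 3.0. Every update is weighted equally, so per-batch
# losses are averaged over batches rather than over individual samples.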


def adjust_learning_rate(optimizer, epoch):
    """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
    lr = args.lr * (0.1 ** (epoch // 30))
    for param_group in optimizer.state_dict()['param_groups']:
        param_group['lr'] = lr
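
# With the default --lr 0.1 this schedule gives lr = 0.1 for epochs 0-29,
# 0.01 for epochs 30-59, and 0.001 for epochs 60-89.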

def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].view(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
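
# Worked example (hypothetical logits): with targets [3, 1] and top-5
# predictions [[3, 0, 4, 1, 2], [0, 2, 1, 3, 4]], prec@1 is 50.0 (only the
# first sample's argmax matches) and prec@5 is 100.0 (both targets appear
# somewhere in the top five).
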
if __name__ == '__main__':
    main()
2 changes: 1 addition & 1 deletion imagenet/resnet.py
@@ -114,7 +114,7 @@ def _make_layer(self, block, planes, blocks, stride=1):

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))
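        # After the first block, the running channel count picks up the
        # block's expansion factor (1 for BasicBlock, 4 for Bottleneck in the
        # standard ResNet definitions), so later blocks in this layer see the
        # widened input.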
