diff --git a/imagenet/main.py b/imagenet/main.py
index 31d0f3a576..004af8b096 100644
--- a/imagenet/main.py
+++ b/imagenet/main.py
@@ -1,97 +1,262 @@
 import argparse
 import os
+import shutil
+import time
+
 import torch
 import torch.nn as nn
 import torch.nn.parallel
 import torch.backends.cudnn as cudnn
 import torch.optim
-import torch.utils.trainer as trainer
-import torch.utils.trainer.plugins
 import torch.utils.data
 import torchvision.transforms as transforms
 import torchvision.datasets as datasets
 
 import resnet
 
+
 parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
 parser.add_argument('--data', metavar='PATH', required=True,
                     help='path to dataset')
 parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18',
                     help='model architecture: resnet18 | resnet34 | ...'
                          '(default: resnet18)')
-parser.add_argument('--gen', default='gen', metavar='PATH',
-                    help='path to save generated files (default: gen)')
-parser.add_argument('--nThreads', '-j', default=2, type=int, metavar='N',
-                    help='number of data loading threads (default: 2)')
-parser.add_argument('--nEpochs', default=90, type=int, metavar='N',
+parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
+                    help='number of data loading workers (default: 4)')
+parser.add_argument('--epochs', default=90, type=int, metavar='N',
                     help='number of total epochs to run')
-parser.add_argument('--epochNumber', default=1, type=int, metavar='N',
+parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                     help='manual epoch number (useful on restarts)')
-parser.add_argument('--batchSize', '-b', default=256, type=int, metavar='N',
-                    help='mini-batch size (1 = pure stochastic) Default: 256')
-parser.add_argument('--lr', default=0.1, type=float, metavar='LR',
-                    help='initial learning rate')
+parser.add_argument('-b', '--batch-size', default=256, type=int,
+                    metavar='N', help='mini-batch size (default: 256)')
+parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
+                    metavar='LR', help='initial learning rate')
 parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                     help='momentum')
-parser.add_argument('--weightDecay', default=1e-4, type=float, metavar='W',
-                    help='weight decay')
-args = parser.parse_args()
-
-if args.arch.startswith('resnet'):
-    model = resnet.__dict__[args.arch]()
-    model.cuda()
-else:
-    parser.error('invalid architecture: {}'.format(args.arch))
-
-cudnn.benchmark = True
-
-# Data loading code
-transform = transforms.Compose([
-    transforms.RandomSizedCrop(224),
-    transforms.RandomHorizontalFlip(),
-    transforms.ToTensor(),
-    transforms.Normalize(mean = [ 0.485, 0.456, 0.406 ],
-                         std = [ 0.229, 0.224, 0.225 ]),
-])
-
-traindir = os.path.join(args.data, 'train')
-valdir = os.path.join(args.data, 'val')
-train = datasets.ImageFolder(traindir, transform)
-val = datasets.ImageFolder(valdir, transform)
-train_loader = torch.utils.data.DataLoader(
-    train, batch_size=args.batchSize, shuffle=True, num_workers=args.nThreads)
-
-
-# create a small container to apply DataParallel to the ResNet
-class DataParallel(nn.Container):
-    def __init__(self):
-        super(DataParallel, self).__init__(
-            model=model,
-        )
-
-    def forward(self, input):
-        if torch.cuda.device_count() > 1:
-            gpu_ids = range(torch.cuda.device_count())
-            return nn.parallel.data_parallel(self.model, input, gpu_ids)
+parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
+                    metavar='W', help='weight decay (default: 1e-4)')
+parser.add_argument('--print-freq', '-p', default=10, type=int,
+                    metavar='N', help='print frequency (default: 10)')
+parser.add_argument('--resume', default='', type=str, metavar='PATH',
+                    help='path to latest checkpoint (default: none)')
+
+best_prec1 = 0
+
+
+def main():
+    global args, best_prec1
+    args = parser.parse_args()
+
+    # create model
+    if args.arch.startswith('resnet'):
+        print("=> creating model '{}'".format(args.arch))
+        model = resnet.__dict__[args.arch]()
+        model.cuda()
+    else:
+        parser.error('invalid architecture: {}'.format(args.arch))
+
+    # optionally resume from a checkpoint
+    if args.resume:
+        if os.path.isfile(args.resume):
+            print("=> loading checkpoint '{}'".format(args.resume))
+            checkpoint = torch.load(args.resume)
+            args.start_epoch = checkpoint['epoch']
+            best_prec1 = checkpoint['best_prec1']
+            model.load_state_dict(checkpoint['state_dict'])
+            print(" | resuming from epoch {}".format(args.start_epoch))
         else:
-            return self.model(input.cuda()).cpu()
+            print("=> no checkpoint found at '{}'".format(args.resume))
+
+    cudnn.benchmark = True
+
+    # Data loading code
+    traindir = os.path.join(args.data, 'train')
+    valdir = os.path.join(args.data, 'val')
+    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                                     std=[0.229, 0.224, 0.225])
+
+    train_loader = torch.utils.data.DataLoader(
+        datasets.ImageFolder(traindir, transforms.Compose([
+            transforms.RandomSizedCrop(224),
+            transforms.RandomHorizontalFlip(),
+            transforms.ToTensor(),
+            normalize,
+        ])),
+        batch_size=args.batch_size, shuffle=True,
+        num_workers=args.workers, pin_memory=True)
+
+    val_loader = torch.utils.data.DataLoader(
+        datasets.ImageFolder(valdir, transforms.Compose([
+            transforms.Scale(256),
+            transforms.CenterCrop(224),
+            transforms.ToTensor(),
+            normalize,
+        ])),
+        batch_size=args.batch_size, shuffle=False,
+        num_workers=args.workers, pin_memory=True)
+
+    # parallelize model across all visible GPUs
+    model = torch.nn.DataParallel(model)
+
+    # define loss function (criterion) and optimizer
+    criterion = nn.CrossEntropyLoss().cuda()
+
+    optimizer = torch.optim.SGD(model.parameters(), args.lr,
+                                momentum=args.momentum,
+                                weight_decay=args.weight_decay)
+
+    for epoch in range(args.start_epoch, args.epochs):
+        adjust_learning_rate(optimizer, epoch)
+
+        # train for one epoch
+        model.train()
+        train(train_loader, model, criterion, optimizer, epoch)
+
+        # evaluate on validation set
+        model.eval()
+        prec1 = validate(val_loader, model, criterion)
+
+        # remember best prec@1 and save checkpoint
+        is_best = prec1 > best_prec1
+        best_prec1 = max(prec1, best_prec1)
+        save_checkpoint({
+            'epoch': epoch + 1,  # the epoch to resume from
+            'arch': args.arch,
+            'state_dict': model.state_dict(),
+            'best_prec1': best_prec1,
+        }, is_best)
+
+
+def train(train_loader, model, criterion, optimizer, epoch):
+    batch_time = AverageMeter()
+    data_time = AverageMeter()
+    losses = AverageMeter()
+    top1 = AverageMeter()
+    top5 = AverageMeter()
+
+    end = time.time()
+    for i, (input, target) in enumerate(train_loader):
+        # measure data loading time
+        data_time.update(time.time() - end)
+
+        target = target.cuda(async=True)
+        input_var = torch.autograd.Variable(input)
+        target_var = torch.autograd.Variable(target)
+
+        # compute output
+        output = model(input_var)
+        loss = criterion(output, target_var)
+
+        # measure accuracy and record loss
+        prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
+        losses.update(loss.data[0], input.size(0))
+        top1.update(prec1[0], input.size(0))
+        top5.update(prec5[0], input.size(0))
+
+        # compute gradient and do SGD step
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+
+        # measure elapsed time
+        batch_time.update(time.time() - end)
+        end = time.time()
+
+        if i % args.print_freq == 0:
+            print('Epoch: [{0}][{1}/{2}]\t'
+                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
+                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
+                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
+                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
+                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
+                      epoch, i, len(train_loader), batch_time=batch_time,
+                      data_time=data_time, loss=losses, top1=top1, top5=top5))
+
+
+def validate(val_loader, model, criterion):
+    batch_time = AverageMeter()
+    losses = AverageMeter()
+    top1 = AverageMeter()
+    top5 = AverageMeter()
+
+    end = time.time()
+    for i, (input, target) in enumerate(val_loader):
+        target = target.cuda(async=True)
+        input_var = torch.autograd.Variable(input, volatile=True)
+        target_var = torch.autograd.Variable(target, volatile=True)
+
+        # compute output
+        output = model(input_var)
+        loss = criterion(output, target_var)
+
+        # measure accuracy and record loss
+        prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
+        losses.update(loss.data[0], input.size(0))
+        top1.update(prec1[0], input.size(0))
+        top5.update(prec5[0], input.size(0))
+
+        # measure elapsed time
+        batch_time.update(time.time() - end)
+        end = time.time()
+
+        if i % args.print_freq == 0:
+            print('Test: [{0}/{1}]\t'
+                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
+                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
+                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
+                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
+                      i, len(val_loader), batch_time=batch_time, loss=losses,
+                      top1=top1, top5=top5))
+
+    return top1.avg
+
+
+def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
+    torch.save(state, filename)
+    if is_best:
+        shutil.copyfile(filename, 'model_best.pth.tar')
+
+
+class AverageMeter(object):
+    """Computes and stores the average and current value"""
+    def __init__(self):
+        self.reset()
+
+    def reset(self):
+        self.val = 0
+        self.avg = 0
+        self.sum = 0
+        self.count = 0
+
+    def update(self, val, n=1):
+        # n lets callers weight the running average by batch size
+        self.val = val
+        self.sum += val * n
+        self.count += n
+        self.avg = self.sum / self.count
+
+
+def adjust_learning_rate(optimizer, epoch):
+    """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
+    lr = args.lr * (0.1 ** (epoch // 30))
+    for param_group in optimizer.param_groups:
+        param_group['lr'] = lr
 
-model = DataParallel()
-# define Loss Function and Optimizer
-criterion = nn.CrossEntropyLoss().cuda()
-optimizer = torch.optim.SGD(model.parameters(), args.lr, args.momentum)
 
+def accuracy(output, target, topk=(1,)):
+    """Computes the precision@k for the specified values of k"""
+    maxk = max(topk)
+    batch_size = target.size(0)
+    _, pred = output.topk(maxk, 1, True, True)
+    pred = pred.t()
+    correct = pred.eq(target.view(1, -1).expand_as(pred))
 
-# pass model, loss, optimizer and dataset to the trainer
-t = trainer.Trainer(model, criterion, optimizer, train_loader)
 
+    res = []
+    for k in topk:
+        correct_k = correct[:k].view(-1).float().sum(0)
+        res.append(correct_k.mul_(100.0 / batch_size))
+    return res
 
-# register some monitoring plugins
-t.register_plugin(trainer.plugins.ProgressMonitor())
-t.register_plugin(trainer.plugins.AccuracyMonitor())
-t.register_plugin(trainer.plugins.LossMonitor())
-t.register_plugin(trainer.plugins.TimeMonitor())
-t.register_plugin(trainer.plugins.Logger(['progress', 'accuracy', 'loss', 'time']))
-# train!
-t.run(args.nEpochs)
+if __name__ == '__main__':
+    main()
diff --git a/imagenet/resnet.py b/imagenet/resnet.py
index ea9fc111cb..c4100f65dd 100644
--- a/imagenet/resnet.py
+++ b/imagenet/resnet.py
@@ -114,7 +114,7 @@ def _make_layer(self, block, planes, blocks, stride=1):
 
         layers = []
         layers.append(block(self.inplanes, planes, stride, downsample))
-        self.inplanes = planes * block.expansion 
+        self.inplanes = planes * block.expansion
         for i in range(1, blocks):
             layers.append(block(self.inplanes, planes))
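
The accuracy helper added in main.py is worth a quick sanity check in isolation, since the transpose-and-compare bookkeeping is easy to get wrong. The standalone sketch below mirrors its top-k logic on made-up toy logits, written against the current PyTorch tensor API (torch.tensor, .item()) rather than the Variable-era API used in the patch; the values and the prec2 name are illustrative only, not part of the patch:

    import torch

    def accuracy(output, target, topk=(1,)):
        """Count how often the target class appears among the k largest logits."""
        maxk = max(topk)
        batch_size = target.size(0)
        _, pred = output.topk(maxk, 1, True, True)  # (batch, maxk) indices, best first
        pred = pred.t()                             # (maxk, batch)
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res

    # toy logits for a batch of 4 samples over 5 classes
    output = torch.tensor([[0.10, 0.90, 0.00, 0.00, 0.00],   # top-2: classes 1, 0
                           [0.80, 0.15, 0.05, 0.00, 0.00],   # top-2: classes 0, 1
                           [0.00, 0.20, 0.70, 0.10, 0.00],   # top-2: classes 2, 1
                           [0.40, 0.30, 0.20, 0.10, 0.00]])  # top-2: classes 0, 1
    target = torch.tensor([1, 0, 1, 2])

    prec1, prec2 = accuracy(output, target, topk=(1, 2))
    print(prec1.item(), prec2.item())  # 50.0 75.0: 2/4 hit at top-1, 3/4 at top-2

Returning percentages rather than raw hit counts is what lets train() and validate() feed the results straight into AverageMeter and report running Prec@1/Prec@5 figures without further scaling.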