diff --git a/imagenet/main.py b/imagenet/main.py
index 004af8b096..5006378d6f 100644
--- a/imagenet/main.py
+++ b/imagenet/main.py
@@ -16,10 +16,10 @@
 parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
-parser.add_argument('--data', metavar='PATH', required=True,
+parser.add_argument('data', metavar='DIR',
                     help='path to dataset')
 parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18',
-                    help='model architecture: resnet18 | resnet34 | ...'
+                    help='model architecture: resnet18 | resnet34 | ... '
                          '(default: resnet18)')
 parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                     help='number of data loading workers (default: 4)')
@@ -39,6 +39,8 @@
                     metavar='N', help='print frequency (default: 10)')
 parser.add_argument('--resume', default='', type=str, metavar='PATH',
                     help='path to latest checkpoint (default: none)')
+parser.add_argument('-e', '--evaluate', type=str, metavar='FILE',
+                    help='evaluate model FILE on validation set')
 
 best_prec1 = 0
 
@@ -50,20 +52,28 @@ def main():
     # create model
     if args.arch.startswith('resnet'):
         print("=> creating model '{}'".format(args.arch))
-        model = resnet.__dict__[args.arch]()
+        model = torch.nn.DataParallel(resnet.__dict__[args.arch]())
         model.cuda()
     else:
         parser.error('invalid architecture: {}'.format(args.arch))
 
     # optionally resume from a checkpoint
-    if args.resume:
+    if args.evaluate:
+        if not os.path.isfile(args.evaluate):
+            parser.error('invalid checkpoint: {}'.format(args.evaluate))
+        checkpoint = torch.load(args.evaluate)
+        model.load_state_dict(checkpoint['state_dict'])
+        print("=> loaded checkpoint '{}' (epoch {})"
+              .format(args.evaluate, checkpoint['epoch']))
+    elif args.resume:
         if os.path.isfile(args.resume):
             print("=> loading checkpoint '{}'".format(args.resume))
             checkpoint = torch.load(args.resume)
             args.start_epoch = checkpoint['epoch']
             best_prec1 = checkpoint['best_prec1']
             model.load_state_dict(checkpoint['state_dict'])
-            print(" | resuming from epoch {}".format(args.start_epoch))
+            print("=> loaded checkpoint '{}' (epoch {})"
+                  .format(args.resume, checkpoint['epoch']))
         else:
             print("=> no checkpoint found at '{}'".format(args.resume))
 
@@ -95,9 +105,6 @@ def main():
         batch_size=args.batch_size, shuffle=False,
         num_workers=args.workers, pin_memory=True)
 
-    # parallelize model across all visible GPUs
-    model = torch.nn.DataParallel(model)
-
     # define loss function (criterion) and pptimizer
     criterion = nn.CrossEntropyLoss().cuda()
 
@@ -105,22 +112,24 @@ def main():
                                 momentum=args.momentum,
                                 weight_decay=args.weight_decay)
 
+    if args.evaluate:
+        validate(val_loader, model, criterion)
+        return
+
     for epoch in range(args.start_epoch, args.epochs):
         adjust_learning_rate(optimizer, epoch)
 
         # train for one epoch
-        model.train()
         train(train_loader, model, criterion, optimizer, epoch)
 
         # evaluate on validation set
-        model.eval()
         prec1 = validate(val_loader, model, criterion)
 
         # remember best prec@1 and save checkpoint
         is_best = prec1 > best_prec1
         best_prec1 = max(prec1, best_prec1)
         save_checkpoint({
-            'epoch': epoch,
+            'epoch': epoch + 1,
             'arch': args.arch,
             'state_dict': model.state_dict(),
             'best_prec1': best_prec1,
@@ -134,6 +143,9 @@ def train(train_loader, model, criterion, optimizer, epoch):
     top1 = AverageMeter()
     top5 = AverageMeter()
 
+    # switch to train mode
+    model.train()
+
     end = time.time()
     for i, (input, target) in enumerate(train_loader):
         # measure data loading time
@@ -149,9 +161,9 @@ def train(train_loader, model, criterion, optimizer, epoch):
         # measure accuracy and record loss
         prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
-        losses.update(loss.data[0])
-        top1.update(prec1[0])
-        top5.update(prec5[0])
+        losses.update(loss.data[0], input.size(0))
+        top1.update(prec1[0], input.size(0))
+        top5.update(prec5[0], input.size(0))
 
         # compute gradient and do SGD step
         optimizer.zero_grad()
@@ -179,6 +191,9 @@ def validate(val_loader, model, criterion):
     top1 = AverageMeter()
     top5 = AverageMeter()
 
+    # switch to evaluate mode
+    model.eval()
+
     end = time.time()
     for i, (input, target) in enumerate(val_loader):
         target = target.cuda(async=True)
@@ -191,9 +206,9 @@ def validate(val_loader, model, criterion):
         # measure accuracy and record loss
         prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
-        losses.update(loss.data[0])
-        top1.update(prec1[0])
-        top5.update(prec5[0])
+        losses.update(loss.data[0], input.size(0))
+        top1.update(prec1[0], input.size(0))
+        top5.update(prec5[0], input.size(0))
 
         # measure elapsed time
         batch_time.update(time.time() - end)
@@ -208,6 +223,9 @@ def validate(val_loader, model, criterion):
                   i, len(val_loader), batch_time=batch_time, loss=losses,
                   top1=top1, top5=top5))
 
+    print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'
+          .format(top1=top1, top5=top5))
+
     return top1.avg
 
@@ -226,13 +244,13 @@ def reset(self):
         self.val = 0
         self.avg = 0
         self.sum = 0
-        self.n = 0
+        self.count = 0
 
-    def update(self, val):
+    def update(self, val, n=1):
         self.val = val
-        self.sum += val
-        self.n += 1
-        self.avg = self.sum / self.n
+        self.sum += val * n
+        self.count += n
+        self.avg = self.sum / self.count
 
 
 def adjust_learning_rate(optimizer, epoch):
@@ -247,7 +265,7 @@ def accuracy(output, target, topk=(1,)):
     maxk = max(topk)
     batch_size = target.size(0)
 
-    _, pred = output.topk(maxk, True, True)
+    _, pred = output.topk(maxk, 1, True, True)
     pred = pred.t()
     correct = pred.eq(target.view(1, -1).expand_as(pred))
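For reference, a minimal sketch of what the reworked AverageMeter does after this change, with a small usage example of the weighted update; the class skeleton outside the hunk is assumed to match the existing file, and the example values are illustrative only.

```python
class AverageMeter(object):
    """Computes and stores the average and current value."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        # n weights each sample, e.g. by batch size, so the running average
        # stays exact even when the last batch is smaller than the others
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


# hypothetical example: a batch of 64 with loss 0.5, then a batch of 32 with loss 1.0
meter = AverageMeter()
meter.update(0.5, n=64)
meter.update(1.0, n=32)
print(meter.avg)  # (0.5 * 64 + 1.0 * 32) / 96 ~= 0.667
```

This is why the train/validate loops now pass input.size(0) as the second argument: without the weight, per-batch means were averaged with equal weight regardless of batch size.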