Fixes to ImageNet training script #19

Merged
merged 1 commit on Dec 19, 2016
imagenet/main.py: 64 changes (41 additions, 23 deletions)
@@ -16,10 +16,10 @@


parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
-parser.add_argument('--data', metavar='PATH', required=True,
+parser.add_argument('data', metavar='DIR',
help='path to dataset')
parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18',
-                    help='model architecture: resnet18 | resnet34 | ...'
+                    help='model architecture: resnet18 | resnet34 | ... '
'(default: resnet18)')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
help='number of data loading workers (default: 4)')
@@ -39,6 +39,8 @@
metavar='N', help='print frequency (default: 10)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
help='path to latest checkpoint (default: none)')
+parser.add_argument('-e', '--evaluate', type=str, metavar='FILE',
+                    help='evaluate model FILE on validation set')

best_prec1 = 0

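For orientation, here is a hedged sketch of how the reworked arguments are meant to be used: the dataset directory is now a positional argument and -e/--evaluate names a checkpoint file. The paths and the checkpoint filename below are placeholders, not part of this PR.

```python
# Hypothetical usage of the reworked CLI (paths are placeholders).
import argparse

parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
parser.add_argument('data', metavar='DIR', help='path to dataset')
parser.add_argument('-e', '--evaluate', type=str, metavar='FILE',
                    help='evaluate model FILE on validation set')

# Training run:   python main.py /data/imagenet
args = parser.parse_args(['/data/imagenet'])
assert args.evaluate is None

# Evaluation run: python main.py /data/imagenet -e model_best.pth.tar
args = parser.parse_args(['/data/imagenet', '-e', 'model_best.pth.tar'])
assert args.data == '/data/imagenet' and args.evaluate == 'model_best.pth.tar'
```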
@@ -50,20 +52,28 @@ def main():
# create model
if args.arch.startswith('resnet'):
print("=> creating model '{}'".format(args.arch))
-        model = resnet.__dict__[args.arch]()
+        model = torch.nn.DataParallel(resnet.__dict__[args.arch]())
model.cuda()
else:
parser.error('invalid architecture: {}'.format(args.arch))

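A side note on the construction change above: wrapping the model in DataParallel right at creation means checkpoints written and read by this script see the same module.-prefixed state_dict keys, so no key remapping is needed on load. A minimal sketch, with a torchvision model standing in for the repo's local resnet module (assumption, not part of this PR):

```python
# Minimal sketch; torchvision.models.resnet18 stands in for the local
# `resnet` module, and .cuda() is guarded so the snippet runs on CPU too.
import torch
import torchvision.models as models

model = torch.nn.DataParallel(models.resnet18())
if torch.cuda.is_available():
    model.cuda()

# Parameter names now carry the `module.` prefix:
print(next(iter(model.state_dict())))   # e.g. 'module.conv1.weight'
```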
# optionally resume from a checkpoint
-    if args.resume:
+    if args.evaluate:
+        if not os.path.isfile(args.evaluate):
+            parser.error('invalid checkpoint: {}'.format(args.evaluate))
+        checkpoint = torch.load(args.evaluate)
+        model.load_state_dict(checkpoint['state_dict'])
+        print("=> loaded checkpoint '{}' (epoch {})"
+              .format(args.evaluate, checkpoint['epoch']))
+    elif args.resume:
if os.path.isfile(args.resume):
print("=> loading checkpoint '{}'".format(args.resume))
checkpoint = torch.load(args.resume)
args.start_epoch = checkpoint['epoch']
best_prec1 = checkpoint['best_prec1']
model.load_state_dict(checkpoint['state_dict'])
-            print(" | resuming from epoch {}".format(args.start_epoch))
+            print("=> loaded checkpoint '{}' (epoch {})"
+                  .format(args.resume, checkpoint['epoch']))
else:
print("=> no checkpoint found at '{}'".format(args.resume))

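To make the resume logic concrete, here is a hedged sketch of the checkpoint round trip the code above expects; note that (further down in this diff) save_checkpoint stores epoch + 1, so a resumed run starts at the first epoch that has not been trained yet. The model and the numbers below are placeholders.

```python
# Hedged sketch of the checkpoint round trip assumed by the resume logic.
import torch
import torch.nn as nn

toy_model = nn.Linear(4, 2)             # placeholder model

# What the training loop saves after finishing epoch 0:
torch.save({
    'epoch': 0 + 1,                     # next epoch to run
    'arch': 'resnet18',
    'state_dict': toy_model.state_dict(),
    'best_prec1': 12.3,                 # made-up number
}, 'checkpoint.pth.tar')

# What --resume does with it:
checkpoint = torch.load('checkpoint.pth.tar')
start_epoch = checkpoint['epoch']       # 1, i.e. resume at epoch 1
best_prec1 = checkpoint['best_prec1']
toy_model.load_state_dict(checkpoint['state_dict'])
```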
@@ -95,32 +105,31 @@ def main():
batch_size=args.batch_size, shuffle=False,
num_workers=args.workers, pin_memory=True)

-    # parallelize model across all visible GPUs
-    model = torch.nn.DataParallel(model)

# define loss function (criterion) and optimizer
criterion = nn.CrossEntropyLoss().cuda()

optimizer = torch.optim.SGD(model.parameters(), args.lr,
momentum=args.momentum,
weight_decay=args.weight_decay)

+    if args.evaluate:
+        validate(val_loader, model, criterion)
+        return

for epoch in range(args.start_epoch, args.epochs):
adjust_learning_rate(optimizer, epoch)

# train for one epoch
-        model.train()
train(train_loader, model, criterion, optimizer, epoch)

# evaluate on validation set
-        model.eval()
prec1 = validate(val_loader, model, criterion)

# remember best prec@1 and save checkpoint
is_best = prec1 > best_prec1
best_prec1 = max(prec1, best_prec1)
save_checkpoint({
-            'epoch': epoch,
+            'epoch': epoch + 1,
'arch': args.arch,
'state_dict': model.state_dict(),
'best_prec1': best_prec1,
@@ -134,6 +143,9 @@ def train(train_loader, model, criterion, optimizer, epoch):
top1 = AverageMeter()
top5 = AverageMeter()

+    # switch to train mode
+    model.train()

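Moving model.train() into train() (and model.eval() into validate(), below) keeps the mode switch next to the code that depends on it; layers such as BatchNorm and Dropout behave differently in the two modes. An illustrative sketch, not specific to this script:

```python
# Illustrative only: Dropout is active in train mode, a no-op in eval mode.
import torch
import torch.nn as nn

drop = nn.Dropout(p=0.5)
x = torch.ones(1, 8)

drop.train()
print(drop(x))   # roughly half the entries zeroed, the rest scaled by 2

drop.eval()
print(drop(x))   # identical to x
```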
end = time.time()
for i, (input, target) in enumerate(train_loader):
# measure data loading time
@@ -149,9 +161,9 @@ def train(train_loader, model, criterion, optimizer, epoch):

# measure accuracy and record loss
prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
-        losses.update(loss.data[0])
-        top1.update(prec1[0])
-        top5.update(prec5[0])
+        losses.update(loss.data[0], input.size(0))
+        top1.update(prec1[0], input.size(0))
+        top5.update(prec5[0], input.size(0))

# compute gradient and do SGD step
optimizer.zero_grad()
@@ -179,6 +191,9 @@ def validate(val_loader, model, criterion):
top1 = AverageMeter()
top5 = AverageMeter()

+    # switch to evaluate mode
+    model.eval()

end = time.time()
for i, (input, target) in enumerate(val_loader):
target = target.cuda(async=True)
@@ -191,9 +206,9 @@ def validate(val_loader, model, criterion):

# measure accuracy and record loss
prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
-        losses.update(loss.data[0])
-        top1.update(prec1[0])
-        top5.update(prec5[0])
+        losses.update(loss.data[0], input.size(0))
+        top1.update(prec1[0], input.size(0))
+        top5.update(prec5[0], input.size(0))

# measure elapsed time
batch_time.update(time.time() - end)
@@ -208,6 +223,9 @@ def validate(val_loader, model, criterion):
i, len(val_loader), batch_time=batch_time, loss=losses,
top1=top1, top5=top5))

+    print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'
+          .format(top1=top1, top5=top5))

return top1.avg


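The new summary line passes the AverageMeter objects straight into str.format and reads their avg attribute inside the format field; a tiny stand-alone sketch of that mechanism (the value is made up):

```python
# Stand-alone sketch of the '{top1.avg:.3f}' formatting used above.
class Meter:
    avg = 67.123456   # stand-in value

print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'
      .format(top1=Meter(), top5=Meter()))
# -> ' * Prec@1 67.123 Prec@5 67.123'
```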
@@ -226,13 +244,13 @@ def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
-        self.n = 0
+        self.count = 0

-    def update(self, val):
+    def update(self, val, n=1):
self.val = val
-        self.sum += val
-        self.n += 1
-        self.avg = self.sum / self.n
+        self.sum += val * n
+        self.count += n
+        self.avg = self.sum / self.count


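Why update(val, n) matters: per-batch precision values are averaged over batches of unequal size (the last batch is usually short), so the running mean has to weight each value by its sample count. A stand-alone sketch of the new behaviour, with made-up batch sizes:

```python
# Stand-alone sketch of the weighted running average implemented above.
class AverageMeter(object):
    def __init__(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

top1 = AverageMeter()
top1.update(50.0, n=256)   # full batch, 50% correct
top1.update(100.0, n=64)   # short final batch, 100% correct
print(top1.avg)            # 60.0, not the unweighted 75.0
```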
def adjust_learning_rate(optimizer, epoch):
@@ -247,7 +265,7 @@ def accuracy(output, target, topk=(1,)):
maxk = max(topk)
batch_size = target.size(0)

-    _, pred = output.topk(maxk, True, True)
+    _, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.eq(target.view(1, -1).expand_as(pred))
