Skip to content

Commit c88ba0a

Browse files
authored
Fixes to ImageNet training script (#19)
- Fix call to top-k which had produced an incorrect prec@1 calculation
- Fix checkpoint loading: create the DataParallel container before resuming checkpoints so that the structure matches
- Weight average precision by batch size
- Add '--evaluate FILE' option to only run on the validation set
1 parent d2acdb9 commit c88ba0a

File tree

1 file changed

+41
-23
lines changed

1 file changed

+41
-23
lines changed

imagenet/main.py

Lines changed: 41 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,10 @@
1616

1717

1818
parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
19-
parser.add_argument('--data', metavar='PATH', required=True,
19+
parser.add_argument('data', metavar='DIR',
2020
help='path to dataset')
2121
parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18',
22-
help='model architecture: resnet18 | resnet34 | ...'
22+
help='model architecture: resnet18 | resnet34 | ... '
2323
'(default: resnet18)')
2424
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
2525
help='number of data loading workers (default: 4)')
@@ -39,6 +39,8 @@
3939
metavar='N', help='print frequency (default: 10)')
4040
parser.add_argument('--resume', default='', type=str, metavar='PATH',
4141
help='path to latest checkpoint (default: none)')
42+
parser.add_argument('-e', '--evaluate', type=str, metavar='FILE',
43+
help='evaluate model FILE on validation set')
4244

4345
best_prec1 = 0
4446

@@ -50,20 +52,28 @@ def main():
5052
# create model
5153
if args.arch.startswith('resnet'):
5254
print("=> creating model '{}'".format(args.arch))
53-
model = resnet.__dict__[args.arch]()
55+
model = torch.nn.DataParallel(resnet.__dict__[args.arch]())
5456
model.cuda()
5557
else:
5658
parser.error('invalid architecture: {}'.format(args.arch))
5759

5860
# optionally resume from a checkpoint
59-
if args.resume:
61+
if args.evaluate:
62+
if not os.path.isfile(args.evaluate):
63+
parser.error('invalid checkpoint: {}'.format(args.evaluate))
64+
checkpoint = torch.load(args.evaluate)
65+
model.load_state_dict(checkpoint['state_dict'])
66+
print("=> loaded checkpoint '{}' (epoch {})"
67+
.format(args.evaluate, checkpoint['epoch']))
68+
elif args.resume:
6069
if os.path.isfile(args.resume):
6170
print("=> loading checkpoint '{}'".format(args.resume))
6271
checkpoint = torch.load(args.resume)
6372
args.start_epoch = checkpoint['epoch']
6473
best_prec1 = checkpoint['best_prec1']
6574
model.load_state_dict(checkpoint['state_dict'])
66-
print(" | resuming from epoch {}".format(args.start_epoch))
75+
print("=> loaded checkpoint '{}' (epoch {})"
76+
.format(args.resume, checkpoint['epoch']))
6777
else:
6878
print("=> no checkpoint found at '{}'".format(args.resume))
6979

@@ -95,32 +105,31 @@ def main():
95105
batch_size=args.batch_size, shuffle=False,
96106
num_workers=args.workers, pin_memory=True)
97107

98-
# parallelize model across all visible GPUs
99-
model = torch.nn.DataParallel(model)
100-
101108
# define loss function (criterion) and optimizer
102109
criterion = nn.CrossEntropyLoss().cuda()
103110

104111
optimizer = torch.optim.SGD(model.parameters(), args.lr,
105112
momentum=args.momentum,
106113
weight_decay=args.weight_decay)
107114

115+
if args.evaluate:
116+
validate(val_loader, model, criterion)
117+
return
118+
108119
for epoch in range(args.start_epoch, args.epochs):
109120
adjust_learning_rate(optimizer, epoch)
110121

111122
# train for one epoch
112-
model.train()
113123
train(train_loader, model, criterion, optimizer, epoch)
114124

115125
# evaluate on validation set
116-
model.eval()
117126
prec1 = validate(val_loader, model, criterion)
118127

119128
# remember best prec@1 and save checkpoint
120129
is_best = prec1 > best_prec1
121130
best_prec1 = max(prec1, best_prec1)
122131
save_checkpoint({
123-
'epoch': epoch,
132+
'epoch': epoch + 1,
124133
'arch': args.arch,
125134
'state_dict': model.state_dict(),
126135
'best_prec1': best_prec1,
@@ -134,6 +143,9 @@ def train(train_loader, model, criterion, optimizer, epoch):
134143
top1 = AverageMeter()
135144
top5 = AverageMeter()
136145

146+
# switch to train mode
147+
model.train()
148+
137149
end = time.time()
138150
for i, (input, target) in enumerate(train_loader):
139151
# measure data loading time
@@ -149,9 +161,9 @@ def train(train_loader, model, criterion, optimizer, epoch):
149161

150162
# measure accuracy and record loss
151163
prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
152-
losses.update(loss.data[0])
153-
top1.update(prec1[0])
154-
top5.update(prec5[0])
164+
losses.update(loss.data[0], input.size(0))
165+
top1.update(prec1[0], input.size(0))
166+
top5.update(prec5[0], input.size(0))
155167

156168
# compute gradient and do SGD step
157169
optimizer.zero_grad()
@@ -179,6 +191,9 @@ def validate(val_loader, model, criterion):
179191
top1 = AverageMeter()
180192
top5 = AverageMeter()
181193

194+
# switch to evaluate mode
195+
model.eval()
196+
182197
end = time.time()
183198
for i, (input, target) in enumerate(val_loader):
184199
target = target.cuda(async=True)
@@ -191,9 +206,9 @@ def validate(val_loader, model, criterion):
191206

192207
# measure accuracy and record loss
193208
prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
194-
losses.update(loss.data[0])
195-
top1.update(prec1[0])
196-
top5.update(prec5[0])
209+
losses.update(loss.data[0], input.size(0))
210+
top1.update(prec1[0], input.size(0))
211+
top5.update(prec5[0], input.size(0))
197212

198213
# measure elapsed time
199214
batch_time.update(time.time() - end)
@@ -208,6 +223,9 @@ def validate(val_loader, model, criterion):
208223
i, len(val_loader), batch_time=batch_time, loss=losses,
209224
top1=top1, top5=top5))
210225

226+
print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'
227+
.format(top1=top1, top5=top5))
228+
211229
return top1.avg
212230

213231

@@ -226,13 +244,13 @@ def reset(self):
226244
self.val = 0
227245
self.avg = 0
228246
self.sum = 0
229-
self.n = 0
247+
self.count = 0
230248

231-
def update(self, val):
249+
def update(self, val, n=1):
232250
self.val = val
233-
self.sum += val
234-
self.n += 1
235-
self.avg = self.sum / self.n
251+
self.sum += val * n
252+
self.count += n
253+
self.avg = self.sum / self.count
236254

237255

238256
def adjust_learning_rate(optimizer, epoch):
@@ -247,7 +265,7 @@ def accuracy(output, target, topk=(1,)):
247265
maxk = max(topk)
248266
batch_size = target.size(0)
249267

250-
_, pred = output.topk(maxk, True, True)
268+
_, pred = output.topk(maxk, 1, True, True)
251269
pred = pred.t()
252270
correct = pred.eq(target.view(1, -1).expand_as(pred))
253271

0 commit comments

Comments (0)