
Commit 8a19c60

Update ImageNet training code
- add learning rate schedule and weight decay
- checkpointing
- use pinned memory in data loader
- don't use Trainer API for now
1 parent 12e213c commit 8a19c60
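
For reference, a typical run with the renamed flags, and a restart from a saved checkpoint, would look roughly like this (the dataset path is a placeholder):

    python main.py --data /path/to/imagenet --arch resnet18 --workers 4
    python main.py --data /path/to/imagenet --resume checkpoint.pth.tar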

2 files changed: +233 −68 lines


imagenet/main.py

Lines changed: 232 additions & 67 deletions
@@ -1,97 +1,262 @@
 import argparse
 import os
+import shutil
+import time

 import torch
 import torch.nn as nn
 import torch.nn.parallel
 import torch.backends.cudnn as cudnn
 import torch.optim
-import torch.utils.trainer as trainer
-import torch.utils.trainer.plugins
 import torch.utils.data
 import torchvision.transforms as transforms
 import torchvision.datasets as datasets

 import resnet

+
 parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
 parser.add_argument('--data', metavar='PATH', required=True,
                     help='path to dataset')
 parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18',
                     help='model architecture: resnet18 | resnet34 | ...'
                          '(default: resnet18)')
-parser.add_argument('--gen', default='gen', metavar='PATH',
-                    help='path to save generated files (default: gen)')
-parser.add_argument('--nThreads', '-j', default=2, type=int, metavar='N',
-                    help='number of data loading threads (default: 2)')
-parser.add_argument('--nEpochs', default=90, type=int, metavar='N',
+parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
+                    help='number of data loading workers (default: 4)')
+parser.add_argument('--epochs', default=90, type=int, metavar='N',
                     help='number of total epochs to run')
-parser.add_argument('--epochNumber', default=1, type=int, metavar='N',
+parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                     help='manual epoch number (useful on restarts)')
-parser.add_argument('--batchSize', '-b', default=256, type=int, metavar='N',
-                    help='mini-batch size (1 = pure stochastic) Default: 256')
-parser.add_argument('--lr', default=0.1, type=float, metavar='LR',
-                    help='initial learning rate')
+parser.add_argument('-b', '--batch-size', default=256, type=int,
+                    metavar='N', help='mini-batch size (default: 256)')
+parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
+                    metavar='LR', help='initial learning rate')
 parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                     help='momentum')
-parser.add_argument('--weightDecay', default=1e-4, type=float, metavar='W',
-                    help='weight decay')
-args = parser.parse_args()
-
-
-if args.arch.startswith('resnet'):
-    model = resnet.__dict__[args.arch]()
-    model.cuda()
-else:
-    parser.error('invalid architecture: {}'.format(args.arch))
-
-cudnn.benchmark = True
-
-# Data loading code
-transform = transforms.Compose([
-    transforms.RandomSizedCrop(224),
-    transforms.RandomHorizontalFlip(),
-    transforms.ToTensor(),
-    transforms.Normalize(mean = [ 0.485, 0.456, 0.406 ],
-                         std = [ 0.229, 0.224, 0.225 ]),
-])
-
-traindir = os.path.join(args.data, 'train')
-valdir = os.path.join(args.data, 'val')
-train = datasets.ImageFolder(traindir, transform)
-val = datasets.ImageFolder(valdir, transform)
-train_loader = torch.utils.data.DataLoader(
-    train, batch_size=args.batchSize, shuffle=True, num_workers=args.nThreads)
-
-
-# create a small container to apply DataParallel to the ResNet
-class DataParallel(nn.Container):
-    def __init__(self):
-        super(DataParallel, self).__init__(
-            model=model,
-        )
-
-    def forward(self, input):
-        if torch.cuda.device_count() > 1:
-            gpu_ids = range(torch.cuda.device_count())
-            return nn.parallel.data_parallel(self.model, input, gpu_ids)
+parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
+                    metavar='W', help='weight decay (default: 1e-4)')
+parser.add_argument('--print-freq', '-p', default=10, type=int,
+                    metavar='N', help='print frequency (default: 10)')
+parser.add_argument('--resume', default='', type=str, metavar='PATH',
+                    help='path to latest checkpoint (default: none)')
+
+best_prec1 = 0
+
+
+def main():
+    global args, best_prec1
+    args = parser.parse_args()
+
+    # create model
+    if args.arch.startswith('resnet'):
+        print("=> creating model '{}'".format(args.arch))
+        model = resnet.__dict__[args.arch]()
+        model.cuda()
+    else:
+        parser.error('invalid architecture: {}'.format(args.arch))
+
+    # optionally resume from a checkpoint
+    if args.resume:
+        if os.path.isfile(args.resume):
+            print("=> loading checkpoint '{}'".format(args.resume))
+            checkpoint = torch.load(args.resume)
+            args.start_epoch = checkpoint['epoch']
+            best_prec1 = checkpoint['best_prec1']
+            model.load_state_dict(checkpoint['state_dict'])
+            print(" | resuming from epoch {}".format(args.start_epoch))
         else:
-            return self.model(input.cuda()).cpu()
+            print("=> no checkpoint found at '{}'".format(args.resume))
+
+    cudnn.benchmark = True
+
+    # Data loading code
+    traindir = os.path.join(args.data, 'train')
+    valdir = os.path.join(args.data, 'val')
+    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                                     std=[0.229, 0.224, 0.225])
+
+    train_loader = torch.utils.data.DataLoader(
+        datasets.ImageFolder(traindir, transforms.Compose([
+            transforms.RandomSizedCrop(224),
+            transforms.RandomHorizontalFlip(),
+            transforms.ToTensor(),
+            normalize,
+        ])),
+        batch_size=args.batch_size, shuffle=True,
+        num_workers=args.workers, pin_memory=True)
+
+    val_loader = torch.utils.data.DataLoader(
+        datasets.ImageFolder(valdir, transforms.Compose([
+            transforms.Scale(256),
+            transforms.CenterCrop(224),
+            transforms.ToTensor(),
+            normalize,
+        ])),
+        batch_size=args.batch_size, shuffle=False,
+        num_workers=args.workers, pin_memory=True)
+
+    # parallelize model across all visible GPUs
+    model = torch.nn.DataParallel(model)
+
+    # define loss function (criterion) and optimizer
+    criterion = nn.CrossEntropyLoss().cuda()
+
+    optimizer = torch.optim.SGD(model.parameters(), args.lr,
+                                momentum=args.momentum,
+                                weight_decay=args.weight_decay)
+
+    for epoch in range(args.start_epoch, args.epochs):
+        adjust_learning_rate(optimizer, epoch)
+
+        # train for one epoch
+        model.train()
+        train(train_loader, model, criterion, optimizer, epoch)
+
+        # evaluate on validation set
+        model.eval()
+        prec1 = validate(val_loader, model, criterion)
+
+        # remember best prec@1 and save checkpoint
+        is_best = prec1 > best_prec1
+        best_prec1 = max(prec1, best_prec1)
+        save_checkpoint({
+            'epoch': epoch,
+            'arch': args.arch,
+            'state_dict': model.state_dict(),
+            'best_prec1': best_prec1,
+        }, is_best)
+
+
+def train(train_loader, model, criterion, optimizer, epoch):
+    batch_time = AverageMeter()
+    data_time = AverageMeter()
+    losses = AverageMeter()
+    top1 = AverageMeter()
+    top5 = AverageMeter()
+
+    end = time.time()
+    for i, (input, target) in enumerate(train_loader):
+        # measure data loading time
+        data_time.update(time.time() - end)
+
+        target = target.cuda(async=True)
+        input_var = torch.autograd.Variable(input)
+        target_var = torch.autograd.Variable(target)
+
+        # compute output
+        output = model(input_var)
+        loss = criterion(output, target_var)
+
+        # measure accuracy and record loss
+        prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
+        losses.update(loss.data[0])
+        top1.update(prec1[0])
+        top5.update(prec5[0])
+
+        # compute gradient and do SGD step
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+
+        # measure elapsed time
+        batch_time.update(time.time() - end)
+        end = time.time()
+
+        if i % args.print_freq == 0:
+            print('Epoch: [{0}][{1}/{2}]\t'
+                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
+                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
+                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
+                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
+                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
+                   epoch, i, len(train_loader), batch_time=batch_time,
+                   data_time=data_time, loss=losses, top1=top1, top5=top5))
+
+
+def validate(val_loader, model, criterion):
+    batch_time = AverageMeter()
+    losses = AverageMeter()
+    top1 = AverageMeter()
+    top5 = AverageMeter()
+
+    end = time.time()
+    for i, (input, target) in enumerate(val_loader):
+        target = target.cuda(async=True)
+        input_var = torch.autograd.Variable(input, volatile=True)
+        target_var = torch.autograd.Variable(target, volatile=True)
+
+        # compute output
+        output = model(input_var)
+        loss = criterion(output, target_var)
+
+        # measure accuracy and record loss
+        prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
+        losses.update(loss.data[0])
+        top1.update(prec1[0])
+        top5.update(prec5[0])
+
+        # measure elapsed time
+        batch_time.update(time.time() - end)
+        end = time.time()
+
+        if i % args.print_freq == 0:
+            print('Test: [{0}/{1}]\t'
+                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
+                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
+                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
+                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
+                   i, len(val_loader), batch_time=batch_time, loss=losses,
+                   top1=top1, top5=top5))
+
+    return top1.avg
+
+
+def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
+    torch.save(state, filename)
+    if is_best:
+        shutil.copyfile(filename, 'model_best.pth.tar')
+
+
+class AverageMeter(object):
+    """Computes and stores the average and current value"""
+    def __init__(self):
+        self.reset()
+
+    def reset(self):
+        self.val = 0
+        self.avg = 0
+        self.sum = 0
+        self.n = 0
+
+    def update(self, val):
+        self.val = val
+        self.sum += val
+        self.n += 1
+        self.avg = self.sum / self.n
+
+
+def adjust_learning_rate(optimizer, epoch):
+    """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
+    lr = args.lr * (0.1 ** (epoch // 30))
+    for param_group in optimizer.state_dict()['param_groups']:
+        param_group['lr'] = lr

-model = DataParallel()

-# define Loss Function and Optimizer
-criterion = nn.CrossEntropyLoss().cuda()
-optimizer = torch.optim.SGD(model.parameters(), args.lr, args.momentum)
+def accuracy(output, target, topk=(1,)):
+    """Computes the precision@k for the specified values of k"""
+    maxk = max(topk)
+    batch_size = target.size(0)

+    _, pred = output.topk(maxk, True, True)
+    pred = pred.t()
+    correct = pred.eq(target.view(1, -1).expand_as(pred))

-# pass model, loss, optimizer and dataset to the trainer
-t = trainer.Trainer(model, criterion, optimizer, train_loader)
+    res = []
+    for k in topk:
+        correct_k = correct[:k].view(-1).float().sum(0)
+        res.append(correct_k.mul_(100.0 / batch_size))
+    return res

-# register some monitoring plugins
-t.register_plugin(trainer.plugins.ProgressMonitor())
-t.register_plugin(trainer.plugins.AccuracyMonitor())
-t.register_plugin(trainer.plugins.LossMonitor())
-t.register_plugin(trainer.plugins.TimeMonitor())
-t.register_plugin(trainer.plugins.Logger(['progress', 'accuracy', 'loss', 'time']))

-# train!
-t.run(args.nEpochs)
+if __name__ == '__main__':
+    main()
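
Two of the new pieces can be checked in isolation. The sketch below is not part of the commit (the checkpoint values are made up): it shows that save_checkpoint() writes a plain dict that torch.load round-trips, and it works out the step schedule from adjust_learning_rate(), which starts the learning rate at 0.1 and drops it tenfold at epochs 30 and 60.

    import torch

    # Hypothetical checkpoint contents; 'state_dict' is left empty for brevity.
    state = {'epoch': 31, 'arch': 'resnet18', 'best_prec1': 52.3, 'state_dict': {}}
    torch.save(state, 'checkpoint.pth.tar')

    checkpoint = torch.load('checkpoint.pth.tar')
    print(checkpoint['epoch'], checkpoint['best_prec1'])    # 31 52.3

    # The step decay used by adjust_learning_rate() with the default --lr 0.1:
    for epoch in (0, 29, 30, 59, 60, 89):
        print(epoch, round(0.1 * (0.1 ** (epoch // 30)), 6))
    # -> 0.1, 0.1, 0.01, 0.01, 0.001, 0.001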

imagenet/resnet.py

Lines changed: 1 addition & 1 deletion
@@ -114,7 +114,7 @@ def _make_layer(self, block, planes, blocks, stride=1):

         layers = []
         layers.append(block(self.inplanes, planes, stride, downsample))
-        self.inplanes = planes * block.expansion
+        self.inplanes = planes * block.expansion
         for i in range(1, blocks):
             layers.append(block(self.inplanes, planes))

(The removed and added lines differ only in whitespace; this hunk fixes the indentation of the self.inplanes assignment.)