 import argparse
 import os
+import shutil
+import time
+
 import torch
 import torch.nn as nn
 import torch.nn.parallel
 import torch.backends.cudnn as cudnn
 import torch.optim
-import torch.utils.trainer as trainer
-import torch.utils.trainer.plugins
 import torch.utils.data
 import torchvision.transforms as transforms
 import torchvision.datasets as datasets
 
 import resnet
 
+
 parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
 parser.add_argument('--data', metavar='PATH', required=True,
                     help='path to dataset')
 parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18',
                     help='model architecture: resnet18 | resnet34 | ...'
                          ' (default: resnet18)')
-parser.add_argument('--gen', default='gen', metavar='PATH',
-                    help='path to save generated files (default: gen)')
-parser.add_argument('--nThreads', '-j', default=2, type=int, metavar='N',
-                    help='number of data loading threads (default: 2)')
-parser.add_argument('--nEpochs', default=90, type=int, metavar='N',
+parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
+                    help='number of data loading workers (default: 4)')
+parser.add_argument('--epochs', default=90, type=int, metavar='N',
                     help='number of total epochs to run')
-parser.add_argument('--epochNumber', default=1, type=int, metavar='N',
+parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                     help='manual epoch number (useful on restarts)')
-parser.add_argument('--batchSize', '-b', default=256, type=int, metavar='N',
-                    help='mini-batch size (1 = pure stochastic) Default: 256')
-parser.add_argument('--lr', default=0.1, type=float, metavar='LR',
-                    help='initial learning rate')
+parser.add_argument('-b', '--batch-size', default=256, type=int,
+                    metavar='N', help='mini-batch size (default: 256)')
+parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
+                    metavar='LR', help='initial learning rate')
 parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                     help='momentum')
-parser.add_argument('--weightDecay', default=1e-4, type=float, metavar='W',
-                    help='weight decay')
-args = parser.parse_args()
-
-if args.arch.startswith('resnet'):
-    model = resnet.__dict__[args.arch]()
-    model.cuda()
-else:
-    parser.error('invalid architecture: {}'.format(args.arch))
-
-cudnn.benchmark = True
-
-# Data loading code
-transform = transforms.Compose([
-    transforms.RandomSizedCrop(224),
-    transforms.RandomHorizontalFlip(),
-    transforms.ToTensor(),
-    transforms.Normalize(mean = [ 0.485, 0.456, 0.406 ],
-                         std = [ 0.229, 0.224, 0.225 ]),
-])
-
-traindir = os.path.join(args.data, 'train')
-valdir = os.path.join(args.data, 'val')
-train = datasets.ImageFolder(traindir, transform)
-val = datasets.ImageFolder(valdir, transform)
-train_loader = torch.utils.data.DataLoader(
-    train, batch_size=args.batchSize, shuffle=True, num_workers=args.nThreads)
-
-
-# create a small container to apply DataParallel to the ResNet
-class DataParallel(nn.Container):
-    def __init__(self):
-        super(DataParallel, self).__init__(
-            model=model,
-        )
-
-    def forward(self, input):
-        if torch.cuda.device_count() > 1:
-            gpu_ids = range(torch.cuda.device_count())
-            return nn.parallel.data_parallel(self.model, input, gpu_ids)
+parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
+                    metavar='W', help='weight decay (default: 1e-4)')
+parser.add_argument('--print-freq', '-p', default=10, type=int,
+                    metavar='N', help='print frequency (default: 10)')
+parser.add_argument('--resume', default='', type=str, metavar='PATH',
+                    help='path to latest checkpoint (default: none)')
+
+best_prec1 = 0
+
+
+def main():
+    global args, best_prec1
+    args = parser.parse_args()
+
+    # create model; wrapping it in DataParallel here, before any
+    # checkpoint is loaded or saved, keeps the 'module.'-prefixed
+    # state_dict keys consistent between save and resume
+    if args.arch.startswith('resnet'):
+        print("=> creating model '{}'".format(args.arch))
+        model = resnet.__dict__[args.arch]()
+        model = torch.nn.DataParallel(model).cuda()
+    else:
+        parser.error('invalid architecture: {}'.format(args.arch))
+
+    # optionally resume from a checkpoint
+    if args.resume:
+        if os.path.isfile(args.resume):
+            print("=> loading checkpoint '{}'".format(args.resume))
+            checkpoint = torch.load(args.resume)
+            args.start_epoch = checkpoint['epoch']
+            best_prec1 = checkpoint['best_prec1']
+            model.load_state_dict(checkpoint['state_dict'])
+            print(" | resuming from epoch {}".format(args.start_epoch))
         else:
-            return self.model(input.cuda()).cpu()
+            print("=> no checkpoint found at '{}'".format(args.resume))
+
+    # benchmark mode has cudnn auto-tune the fastest convolution
+    # algorithms for the fixed 224x224 input size
+    cudnn.benchmark = True
+
+    # Data loading code
+    traindir = os.path.join(args.data, 'train')
+    valdir = os.path.join(args.data, 'val')
+    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                                     std=[0.229, 0.224, 0.225])
+
+    train_loader = torch.utils.data.DataLoader(
+        datasets.ImageFolder(traindir, transforms.Compose([
+            transforms.RandomSizedCrop(224),
+            transforms.RandomHorizontalFlip(),
+            transforms.ToTensor(),
+            normalize,
+        ])),
+        batch_size=args.batch_size, shuffle=True,
+        num_workers=args.workers, pin_memory=True)
+
+    val_loader = torch.utils.data.DataLoader(
+        datasets.ImageFolder(valdir, transforms.Compose([
+            transforms.Scale(256),
+            transforms.CenterCrop(224),
+            transforms.ToTensor(),
+            normalize,
+        ])),
+        batch_size=args.batch_size, shuffle=False,
+        num_workers=args.workers, pin_memory=True)
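+    # (the train pipeline is randomized for augmentation, while the val
+    # pipeline is deterministic so results are reproducible; pin_memory
+    # keeps loader batches in page-locked memory, speeding up the
+    # host-to-GPU copies in train() and validate() below)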
+
+    # define loss function (criterion) and optimizer
+    criterion = nn.CrossEntropyLoss().cuda()
+
+    optimizer = torch.optim.SGD(model.parameters(), args.lr,
+                                momentum=args.momentum,
+                                weight_decay=args.weight_decay)
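+    # note: model.parameters() forms a single parameter group, so the
+    # weight decay above applies to every weight and bias in the model
+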
+    for epoch in range(args.start_epoch, args.epochs):
+        adjust_learning_rate(optimizer, epoch)
+
+        # train for one epoch
+        model.train()
+        train(train_loader, model, criterion, optimizer, epoch)
+
+        # evaluate on validation set
+        model.eval()
+        prec1 = validate(val_loader, model, criterion)
+
+        # remember best prec@1 and save checkpoint
+        is_best = prec1 > best_prec1
+        best_prec1 = max(prec1, best_prec1)
+        save_checkpoint({
+            'epoch': epoch + 1,  # the next epoch to run on restart
+            'arch': args.arch,
+            'state_dict': model.state_dict(),
+            'best_prec1': best_prec1,
+        }, is_best)
+
+
+def train(train_loader, model, criterion, optimizer, epoch):
+    batch_time = AverageMeter()
+    data_time = AverageMeter()
+    losses = AverageMeter()
+    top1 = AverageMeter()
+    top5 = AverageMeter()
+
+    end = time.time()
+    for i, (input, target) in enumerate(train_loader):
+        # measure data loading time
+        data_time.update(time.time() - end)
+
+        # async=True overlaps the copy with compute (the batches sit in
+        # pinned memory thanks to the DataLoader above)
+        target = target.cuda(async=True)
+        input_var = torch.autograd.Variable(input)
+        target_var = torch.autograd.Variable(target)
+
+        # compute output
+        output = model(input_var)
+        loss = criterion(output, target_var)
+
+        # measure accuracy and record loss
+        prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
+        losses.update(loss.data[0])
+        top1.update(prec1[0])
+        top5.update(prec5[0])
+
+        # compute gradient and do SGD step
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+
+        # measure elapsed time
+        batch_time.update(time.time() - end)
+        end = time.time()
+
+        if i % args.print_freq == 0:
+            print('Epoch: [{0}][{1}/{2}]\t'
+                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
+                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
+                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
+                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
+                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
+                   epoch, i, len(train_loader), batch_time=batch_time,
+                   data_time=data_time, loss=losses, top1=top1, top5=top5))
+
+
+def validate(val_loader, model, criterion):
+    batch_time = AverageMeter()
+    losses = AverageMeter()
+    top1 = AverageMeter()
+    top5 = AverageMeter()
+
+    end = time.time()
+    for i, (input, target) in enumerate(val_loader):
+        target = target.cuda(async=True)
+        # volatile=True disables graph construction for inference
+        input_var = torch.autograd.Variable(input, volatile=True)
+        target_var = torch.autograd.Variable(target, volatile=True)
+
+        # compute output
+        output = model(input_var)
+        loss = criterion(output, target_var)
+
+        # measure accuracy and record loss
+        prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
+        losses.update(loss.data[0])
+        top1.update(prec1[0])
+        top5.update(prec5[0])
+
+        # measure elapsed time
+        batch_time.update(time.time() - end)
+        end = time.time()
+
+        if i % args.print_freq == 0:
+            print('Test: [{0}/{1}]\t'
+                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
+                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
+                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
+                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
+                   i, len(val_loader), batch_time=batch_time, loss=losses,
+                   top1=top1, top5=top5))
+
+    return top1.avg
+
+
+def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
+    torch.save(state, filename)
+    if is_best:
+        shutil.copyfile(filename, 'model_best.pth.tar')
+
+
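+# to pick up training where it left off, pass --resume checkpoint.pth.tar
+# (or --resume model_best.pth.tar for the best-so-far weights)
+
+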
+class AverageMeter(object):
+    """Computes and stores the average and current value"""
+    def __init__(self):
+        self.reset()
+
+    def reset(self):
+        self.val = 0
+        self.avg = 0
+        self.sum = 0
+        self.n = 0
+
+    def update(self, val):
+        # every update is weighted equally, so avg is a per-batch average
+        # (a smaller final batch counts the same as the full-size ones)
+        self.val = val
+        self.sum += val
+        self.n += 1
+        self.avg = self.sum / self.n
+
+
+def adjust_learning_rate(optimizer, epoch):
+    """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
+    # e.g. with the default lr=0.1: epochs 0-29 run at 0.1,
+    # 30-59 at 0.01, and 60-89 at 0.001
+    lr = args.lr * (0.1 ** (epoch // 30))
+    # mutate optimizer.param_groups directly; state_dict() may hand back
+    # copies, in which case the assignment would be silently lost
+    for param_group in optimizer.param_groups:
+        param_group['lr'] = lr
 
-model = DataParallel()
 
-# define Loss Function and Optimizer
-criterion = nn.CrossEntropyLoss().cuda()
-optimizer = torch.optim.SGD(model.parameters(), args.lr, args.momentum)
+def accuracy(output, target, topk=(1,)):
+    """Computes the precision@k for the specified values of k"""
+    maxk = max(topk)
+    batch_size = target.size(0)
 
+    # top-maxk class indices per sample, transposed to (maxk, batch)
+    _, pred = output.topk(maxk, 1, True, True)
+    pred = pred.t()
+    correct = pred.eq(target.view(1, -1).expand_as(pred))
 
-# pass model, loss, optimizer and dataset to the trainer
-t = trainer.Trainer(model, criterion, optimizer, train_loader)
+    res = []
+    for k in topk:
+        # a sample is correct@k if the target is among its k best guesses
+        correct_k = correct[:k].view(-1).float().sum(0)
+        res.append(correct_k.mul_(100.0 / batch_size))
+    return res
 
-# register some monitoring plugins
-t.register_plugin(trainer.plugins.ProgressMonitor())
-t.register_plugin(trainer.plugins.AccuracyMonitor())
-t.register_plugin(trainer.plugins.LossMonitor())
-t.register_plugin(trainer.plugins.TimeMonitor())
-t.register_plugin(trainer.plugins.Logger(['progress', 'accuracy', 'loss', 'time']))
 
-# train!
-t.run(args.nEpochs)
+if __name__ == '__main__':
+    main()
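+
+# example invocation (assumes this file is saved as main.py and that
+# --data points at an ImageNet-style directory with train/ and val/):
+#   python main.py --data /path/to/imagenet --arch resnet18 -b 256 -j 4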