@@ -19,22 +19,21 @@
 amp = None
 
 
-def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch,
-                    print_freq, args, apex=False, model_ema=None):
+def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema=None):
     model.train()
     metric_logger = utils.MetricLogger(delimiter="  ")
     metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value}'))
     metric_logger.add_meter('img/s', utils.SmoothedValue(window_size=10, fmt='{value}'))
 
     header = 'Epoch: [{}]'.format(epoch)
-    for i, (image, target) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
+    for i, (image, target) in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)):
         start_time = time.time()
         image, target = image.to(device), target.to(device)
         output = model(image)
         loss = criterion(output, target)
 
         optimizer.zero_grad()
-        if apex:
+        if args.apex:
             with amp.scale_loss(loss, optimizer) as scaled_loss:
                 scaled_loss.backward()
         else:
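
The `args.apex` branch above scales the loss before backpropagation so that small fp16 gradients do not underflow. For comparison only, a minimal sketch of the same loss-scaling step with PyTorch's built-in `torch.cuda.amp.GradScaler` (the diff itself stays on `apex.amp`; `loss` and `optimizer` are the objects from `train.py`):

```python
import torch

# Sketch only: the native-AMP counterpart of the apex scale_loss block above.
scaler = torch.cuda.amp.GradScaler()

def backward_step(loss, optimizer):
    optimizer.zero_grad()
    scaler.scale(loss).backward()  # backprop through the scaled loss
    scaler.step(optimizer)         # unscales grads, skips the step on inf/NaN
    scaler.update()                # adapts the scale factor for the next step
```
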
@@ -197,16 +196,22 @@ def main(args):
 
     criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
 
+    if args.norm_weight_decay is None:
+        parameters = model.parameters()
+    else:
+        param_groups = torchvision.ops._utils.split_normalization_params(model)
+        wd_groups = [args.norm_weight_decay, args.weight_decay]
+        parameters = [{"params": p, "weight_decay": w} for p, w in zip(param_groups, wd_groups) if p]
+
     opt_name = args.opt.lower()
     if opt_name.startswith("sgd"):
-        optimizer = torch.optim.SGD(
-            model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay,
-            nesterov="nesterov" in opt_name)
+        optimizer = torch.optim.SGD(parameters, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay,
+                                    nesterov="nesterov" in opt_name)
     elif opt_name == 'rmsprop':
-        optimizer = torch.optim.RMSprop(model.parameters(), lr=args.lr, momentum=args.momentum,
-                                        weight_decay=args.weight_decay, eps=0.0316, alpha=0.9)
+        optimizer = torch.optim.RMSprop(parameters, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay,
+                                        eps=0.0316, alpha=0.9)
     elif opt_name == 'adamw':
-        optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
+        optimizer = torch.optim.AdamW(parameters, lr=args.lr, weight_decay=args.weight_decay)
     else:
         raise RuntimeError(f"Invalid optimizer {args.opt}. Only SGD, RMSprop and AdamW are supported.")
 
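
`torchvision.ops._utils.split_normalization_params` returns two lists, (normalization-layer parameters, everything else), which the list comprehension zips with `[args.norm_weight_decay, args.weight_decay]` to build one optimizer group per list; the `if p` guard drops a group that came back empty. A rough hand-rolled sketch of what the split amounts to (illustrative only; the real helper also accepts custom norm classes):

```python
import torch.nn as nn

def split_norm_params(model):
    # Visit every module once; recurse=False yields only the parameters a
    # module owns directly, so no parameter is counted twice.
    norm_classes = (nn.modules.batchnorm._BatchNorm, nn.LayerNorm, nn.GroupNorm)
    norm_params, other_params = [], []
    for module in model.modules():
        bucket = norm_params if isinstance(module, norm_classes) else other_params
        bucket.extend(p for p in module.parameters(recurse=False) if p.requires_grad)
    return norm_params, other_params
```

Passing a list of `{"params": ..., "weight_decay": ...}` dicts to any `torch.optim` optimizer overrides the constructor-level `weight_decay` per group, which is why the SGD/RMSprop/AdamW calls above keep their existing `weight_decay=` argument unchanged.
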
@@ -281,7 +286,7 @@ def main(args):
     for epoch in range(args.start_epoch, args.epochs):
         if args.distributed:
             train_sampler.set_epoch(epoch)
-        train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args.print_freq, args.apex, model_ema)
+        train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema)
         lr_scheduler.step()
         evaluate(model, criterion, data_loader_test, device=device)
         if model_ema:
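
`model_ema` is passed into `train_one_epoch`, where the averaged weights are refreshed as training progresses. The underlying update is the standard exponential-moving-average rule; a minimal sketch with a fixed `decay` (a hypothetical stand-alone helper, not the script's actual EMA class):

```python
import torch

@torch.no_grad()
def ema_update(ema_model, model, decay=0.999):
    # Blend each averaged parameter toward the live one:
    #   ema <- decay * ema + (1 - decay) * param
    for ema_p, p in zip(ema_model.parameters(), model.parameters()):
        ema_p.mul_(decay).add_(p, alpha=1 - decay)
```
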
@@ -326,6 +331,8 @@ def get_args_parser(add_help=True):
     parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
                         metavar='W', help='weight decay (default: 1e-4)',
                         dest='weight_decay')
+    parser.add_argument('--norm-weight-decay', default=None, type=float,
+                        help='weight decay for Normalization layers (default: None, same value as --wd)')
     parser.add_argument('--label-smoothing', default=0.0, type=float,
                         help='label smoothing (default: 0.0)',
                         dest='label_smoothing')
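
The new flag defaults to `None`, which preserves the old behaviour (normalization layers share `--wd`); setting it, typically to `0.0`, activates the parameter split added in `main`. A small sanity check, assuming every other option in `get_args_parser` keeps a default as in this file:

```python
parser = get_args_parser()

# Unset: normalization layers fall through to --wd (default 1e-4).
args = parser.parse_args([])
assert args.weight_decay == 1e-4 and args.norm_weight_decay is None

# Set: normalization layers get their own value, e.g. no decay at all.
args = parser.parse_args(['--norm-weight-decay', '0.0'])
assert args.norm_weight_decay == 0.0
```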