diff --git a/references/classification/train.py b/references/classification/train.py
index 79b99156a05..89eae31c2cd 100644
--- a/references/classification/train.py
+++ b/references/classification/train.py
@@ -17,7 +17,8 @@
 amp = None
 
 
-def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, print_freq, apex=False):
+def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch,
+                    print_freq, apex=False, model_ema=None):
     model.train()
     metric_logger = utils.MetricLogger(delimiter="  ")
     metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value}'))
@@ -45,11 +46,14 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, pri
         metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
         metric_logger.meters['img/s'].update(batch_size / (time.time() - start_time))
 
+        if model_ema:
+            model_ema.update_parameters(model)
 
-def evaluate(model, criterion, data_loader, device, print_freq=100):
+
+def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix=''):
     model.eval()
     metric_logger = utils.MetricLogger(delimiter="  ")
-    header = 'Test:'
+    header = f'Test: {log_suffix}'
     with torch.no_grad():
         for image, target in metric_logger.log_every(data_loader, print_freq, header):
             image = image.to(device, non_blocking=True)
@@ -199,12 +203,18 @@ def main(args):
         model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
         model_without_ddp = model.module
 
+    model_ema = None
+    if args.model_ema:
+        model_ema = utils.ExponentialMovingAverage(model_without_ddp, device=device, decay=args.model_ema_decay)
+
     if args.resume:
         checkpoint = torch.load(args.resume, map_location='cpu')
         model_without_ddp.load_state_dict(checkpoint['model'])
         optimizer.load_state_dict(checkpoint['optimizer'])
         lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
         args.start_epoch = checkpoint['epoch'] + 1
+        if model_ema:
+            model_ema.load_state_dict(checkpoint['model_ema'])
 
     if args.test_only:
         evaluate(model, criterion, data_loader_test, device=device)
@@ -215,9 +225,11 @@ def main(args):
     for epoch in range(args.start_epoch, args.epochs):
         if args.distributed:
             train_sampler.set_epoch(epoch)
-        train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args.print_freq, args.apex)
+        train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args.print_freq, args.apex, model_ema)
         lr_scheduler.step()
         evaluate(model, criterion, data_loader_test, device=device)
+        if model_ema:
+            evaluate(model_ema, criterion, data_loader_test, device=device, log_suffix='EMA')
         if args.output_dir:
             checkpoint = {
                 'model': model_without_ddp.state_dict(),
@@ -225,6 +237,8 @@ def main(args):
                 'lr_scheduler': lr_scheduler.state_dict(),
                 'epoch': epoch,
                 'args': args}
+            if model_ema:
+                checkpoint['model_ema'] = model_ema.state_dict()
             utils.save_on_master(
                 checkpoint,
                 os.path.join(args.output_dir, 'model_{}.pth'.format(epoch)))
@@ -306,6 +320,12 @@ def get_args_parser(add_help=True):
     parser.add_argument('--world-size', default=1, type=int,
                        help='number of distributed processes')
     parser.add_argument('--dist-url', default='env://', help='url used to set up distributed training')
 
+    parser.add_argument(
+        '--model-ema', action='store_true',
+        help='enable tracking Exponential Moving Average of model parameters')
+    parser.add_argument(
+        '--model-ema-decay', type=float, default=0.99,
+        help='decay factor for Exponential Moving Average of model parameters (default: 0.99)')
 
     return parser
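To sanity-check the wiring above without running the full reference script, here is a self-contained toy loop exercising the same hooks. Everything in it (the linear model, the synthetic batches, the step count) is made up for illustration, and `torch.optim.swa_utils.AveragedModel` is used directly in place of the new `utils.ExponentialMovingAverage` so the snippet runs on its own:

```python
import torch

# Stand-ins for the reference script's model/criterion/optimizer setup.
torch.manual_seed(0)
model = torch.nn.Linear(8, 2)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Same averaging rule that utils.ExponentialMovingAverage encodes (decay=0.99).
decay = 0.99
ema_avg = lambda avg_p, p, num_averaged: decay * avg_p + (1 - decay) * p
model_ema = torch.optim.swa_utils.AveragedModel(model, device='cpu', avg_fn=ema_avg)

for step in range(10):  # toy stand-in for one epoch over data_loader
    image, target = torch.randn(4, 8), torch.randint(0, 2, (4,))
    optimizer.zero_grad()
    criterion(model(image), target).backward()
    optimizer.step()
    model_ema.update_parameters(model)  # the per-iteration hook train_one_epoch gains

# The checkpoint round-trip main() now performs for the EMA weights.
checkpoint = {'model': model.state_dict(), 'model_ema': model_ema.state_dict()}
model_ema.load_state_dict(checkpoint['model_ema'])
```

Evaluating both `model` and `model_ema` every epoch, as `main()` now does, surfaces the EMA accuracy in the logs under the new `Test: EMA` header.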
diff --git a/references/classification/utils.py b/references/classification/utils.py
index 4e53ed1d3d7..644f1c4708a 100644
--- a/references/classification/utils.py
+++ b/references/classification/utils.py
@@ -161,6 +161,18 @@ def log_every(self, iterable, print_freq, header=None):
         print('{} Total time: {}'.format(header, total_time_str))
 
 
+class ExponentialMovingAverage(torch.optim.swa_utils.AveragedModel):
+    """Maintains moving averages of model parameters using an exponential decay.
+    ``ema_avg = decay * avg_model_param + (1 - decay) * model_param``
+    ``torch.optim.swa_utils.AveragedModel``
+    is used to compute the EMA.
+    """
+    def __init__(self, model, decay, device='cpu'):
+        ema_avg = (lambda avg_model_param, model_param, num_averaged:
+                   decay * avg_model_param + (1 - decay) * model_param)
+        super().__init__(model, device, ema_avg)
+
+
 def accuracy(output, target, topk=(1,)):
     """Computes the accuracy over the k top predictions for the specified values of k"""
     with torch.no_grad():
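And a quick standalone check of the decay rule the new class encodes. One subtlety worth flagging for reviewers: `AveragedModel` copies the weights verbatim on the first `update_parameters` call (while `n_averaged == 0`) and only applies `avg_fn` from the second call onwards, which the assertion below accounts for:

```python
import torch

decay = 0.99
model = torch.nn.Linear(4, 2)
ema_avg = lambda avg_p, p, num_averaged: decay * avg_p + (1 - decay) * p
ema_model = torch.optim.swa_utils.AveragedModel(model, device='cpu', avg_fn=ema_avg)

ema_model.update_parameters(model)      # first call: plain copy, no decay applied
with torch.no_grad():
    model.weight += 1.0                 # pretend an optimizer step moved the weights

old_ema = ema_model.module.weight.detach().clone()
ema_model.update_parameters(model)      # second call: the exponential rule kicks in
expected = decay * old_ema + (1 - decay) * model.weight
assert torch.allclose(ema_model.module.weight, expected)
```

Also note that with a custom `avg_fn` the average tracks parameters only, so BatchNorm running statistics are not part of the EMA; `torch.optim.swa_utils.update_bn` exists for recomputing those, should that ever matter for these models.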