From 3943ca9e6474fa1136c76d837b46bf4738ac98df Mon Sep 17 00:00:00 2001 From: Prabhat Roy Date: Tue, 7 Sep 2021 23:31:40 +0100 Subject: [PATCH 1/3] Added Exponential Moving Average support to classification reference script --- references/classification/train.py | 28 ++++++++++++++++++++++++---- references/classification/utils.py | 19 +++++++++++++++++++ 2 files changed, 43 insertions(+), 4 deletions(-) diff --git a/references/classification/train.py b/references/classification/train.py index 79b99156a05..8cd43cd2f90 100644 --- a/references/classification/train.py +++ b/references/classification/train.py @@ -17,7 +17,8 @@ amp = None -def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, print_freq, apex=False): +def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, + print_freq, apex=False, model_ema=None): model.train() metric_logger = utils.MetricLogger(delimiter=" ") metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value}')) @@ -45,11 +46,14 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, pri metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) metric_logger.meters['img/s'].update(batch_size / (time.time() - start_time)) + if model_ema: + model_ema.update_parameters(model) -def evaluate(model, criterion, data_loader, device, print_freq=100): + +def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix=''): model.eval() metric_logger = utils.MetricLogger(delimiter=" ") - header = 'Test:' + header = f'Test: {log_suffix}' with torch.no_grad(): for image, target in metric_logger.log_every(data_loader, print_freq, header): image = image.to(device, non_blocking=True) @@ -199,12 +203,18 @@ def main(args): model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) model_without_ddp = model.module + model_ema = None + if args.model_ema: + model_ema = utils.ExponentialMovingAverage(model, device=device, 
decay=args.model_ema_decay) + if args.resume: checkpoint = torch.load(args.resume, map_location='cpu') model_without_ddp.load_state_dict(checkpoint['model']) optimizer.load_state_dict(checkpoint['optimizer']) lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) args.start_epoch = checkpoint['epoch'] + 1 + if model_ema: + model_ema.load_state_dict(checkpoint['model_ema']) if args.test_only: evaluate(model, criterion, data_loader_test, device=device) @@ -215,9 +225,11 @@ def main(args): for epoch in range(args.start_epoch, args.epochs): if args.distributed: train_sampler.set_epoch(epoch) - train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args.print_freq, args.apex) + train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args.print_freq, args.apex, model_ema) lr_scheduler.step() evaluate(model, criterion, data_loader_test, device=device) + if model_ema: + evaluate(model_ema, criterion, data_loader_test, device=device, log_suffix='EMA') if args.output_dir: checkpoint = { 'model': model_without_ddp.state_dict(), @@ -225,6 +237,8 @@ def main(args): 'lr_scheduler': lr_scheduler.state_dict(), 'epoch': epoch, 'args': args} + if model_ema: + checkpoint['model_ema'] = model_ema.state_dict() utils.save_on_master( checkpoint, os.path.join(args.output_dir, 'model_{}.pth'.format(epoch))) @@ -306,6 +320,12 @@ def get_args_parser(add_help=True): parser.add_argument('--world-size', default=1, type=int, help='number of distributed processes') parser.add_argument('--dist-url', default='env://', help='url used to set up distributed training') + parser.add_argument( + '--model-ema', action='store_true', + help='enable tracking Exponential Moving Average of model parameters') + parser.add_argument( + '--model-ema-decay', type=float, default=0.99, + help='decay factor for Exponential Moving Average of model parameters (default: 0.99)') return parser diff --git a/references/classification/utils.py b/references/classification/utils.py index
4e53ed1d3d7..245bf181df1 100644 --- a/references/classification/utils.py +++ b/references/classification/utils.py @@ -161,6 +161,25 @@ def log_every(self, iterable, print_freq, header=None): print('{} Total time: {}'.format(header, total_time_str)) +class ExponentialMovingAverage(torch.optim.swa_utils.AveragedModel): + """Maintains moving averages of model parameters using an + exponential decay. + `ema_avg = decay * avg_model_param + (1 - decay) * model_param` + `torch.optim.swa_utils.AveragedModel ` + is used to compute the EMA. + """ + def __init__(self, model, decay, device='cpu', name='ExponentialMovingAverage'): + ema_avg = (lambda avg_model_param, model_param, num_averaged: + decay * avg_model_param + (1 - decay) * model_param) + super().__init__(model, device, ema_avg) + self._name = name + + @property + def name(self): + """ExponentialMovingAverage object's name.""" + return self._name + + def accuracy(output, target, topk=(1,)): """Computes the accuracy over the k top predictions for the specified values of k""" with torch.no_grad(): From 8ca6212eb364e5aad33703705b1a6e50b0316826 Mon Sep 17 00:00:00 2001 From: Prabhat Roy Date: Wed, 8 Sep 2021 17:10:09 +0100 Subject: [PATCH 2/3] Addressed review comments --- references/classification/utils.py | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/references/classification/utils.py b/references/classification/utils.py index 245bf181df1..644f1c4708a 100644 --- a/references/classification/utils.py +++ b/references/classification/utils.py @@ -162,22 +162,15 @@ def log_every(self, iterable, print_freq, header=None): class ExponentialMovingAverage(torch.optim.swa_utils.AveragedModel): - """Maintains moving averages of model parameters using an - exponential decay. - `ema_avg = decay * avg_model_param + (1 - decay) * model_param` - `torch.optim.swa_utils.AveragedModel ` + """Maintains moving averages of model parameters using an exponential decay. 
+ ``ema_avg = decay * avg_model_param + (1 - decay) * model_param`` + `torch.optim.swa_utils.AveragedModel <https://pytorch.org/docs/stable/optim.html#custom-averaging-strategies>`_ is used to compute the EMA. """ - def __init__(self, model, decay, device='cpu', name='ExponentialMovingAverage'): + def __init__(self, model, decay, device='cpu'): ema_avg = (lambda avg_model_param, model_param, num_averaged: decay * avg_model_param + (1 - decay) * model_param) super().__init__(model, device, ema_avg) - self._name = name - - @property - def name(self): - """ExponentialMovingAverage object's name.""" - return self._name def accuracy(output, target, topk=(1,)): """Computes the accuracy over the k top predictions for the specified values of k""" From 79856190d80eadeae706fc32807f2f2cc4cb4678 Mon Sep 17 00:00:00 2001 From: Prabhat Roy Date: Thu, 9 Sep 2021 16:55:25 +0100 Subject: [PATCH 3/3] Updated model argument --- references/classification/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/references/classification/train.py b/references/classification/train.py index 8cd43cd2f90..89eae31c2cd 100644 --- a/references/classification/train.py +++ b/references/classification/train.py @@ -205,7 +205,7 @@ def main(args): model_ema = None if args.model_ema: - model_ema = utils.ExponentialMovingAverage(model, device=device, decay=args.model_ema_decay) + model_ema = utils.ExponentialMovingAverage(model_without_ddp, device=device, decay=args.model_ema_decay) if args.resume: checkpoint = torch.load(args.resume, map_location='cpu')