From 015481c49580cf2c15b40578a77ef05d1472b277 Mon Sep 17 00:00:00 2001 From: Prabhat Roy Date: Tue, 14 Sep 2021 10:37:29 +0100 Subject: [PATCH] Added update_parameters to EMA to fix calculation --- references/classification/utils.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/references/classification/utils.py b/references/classification/utils.py index 644f1c4708a..bf7662ad023 100644 --- a/references/classification/utils.py +++ b/references/classification/utils.py @@ -172,6 +172,17 @@ def __init__(self, model, decay, device='cpu'): decay * avg_model_param + (1 - decay) * model_param) super().__init__(model, device, ema_avg) + def update_parameters(self, model): + for p_swa, p_model in zip(self.module.state_dict().values(), model.state_dict().values()): + device = p_swa.device + p_model_ = p_model.detach().to(device) + if self.n_averaged == 0: + p_swa.detach().copy_(p_model_) + else: + p_swa.detach().copy_(self.avg_fn(p_swa.detach(), p_model_, + self.n_averaged.to(device))) + self.n_averaged += 1 + def accuracy(output, target, topk=(1,)): """Computes the accuracy over the k top predictions for the specified values of k"""