-
Notifications
You must be signed in to change notification settings - Fork 7.1k
Added update_parameters to EMA to fix calculation #4406
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks Prabhat.
As @kazhang this is an issue on PyTorch core. Ideally the base class should include all state not just the params. We use a similar approach in other averaging schemes (see this) so this is aligned to what we've seen in the past. Though this workaround will do for our use-case, I think it's still worth raising it on PyTorch core and either correcting it or introducing better control over which params we should consider.
Makes sense. Will follow-up on this. |
…65495) Summary: While implementing [EMA](pytorch/vision#4381 extends AveragedModel) in torchvision, update_parameters() from AveragedModel could not be used as it did not handle state_dict(), so a custom update_parameters() needed to be defined in [EMA class](pytorch/vision#4406). This PR aims to handle this scenario removing the need for this custom update_parameters() implementation. Discussion: pytorch/vision#4406 (review) Pull Request resolved: #65495 Reviewed By: datumbox Differential Revision: D31176742 Pulled By: prabhat00155 fbshipit-source-id: 326d14876018f21cf602bab5eaba344678dbabe2
Reviewed By: datumbox Differential Revision: D31268055 fbshipit-source-id: 2bedf7cd5db0a345dffa42a9ff94ce7d425e1008
…ytorch#65495) Summary: While implementing [EMA](pytorch/vision#4381 extends AveragedModel) in torchvision, update_parameters() from AveragedModel could not be used as it did not handle state_dict(), so a custom update_parameters() needed to be defined in [EMA class](pytorch/vision#4406). This PR aims to handle this scenario removing the need for this custom update_parameters() implementation. Discussion: pytorch/vision#4406 (review) Pull Request resolved: pytorch#65495 Reviewed By: datumbox Differential Revision: D31176742 Pulled By: prabhat00155 fbshipit-source-id: 326d14876018f21cf602bab5eaba344678dbabe2 (cherry picked from commit 2ea724b)
…65495) (#65755) * Added option to update parameters using state_dict in AveragedModel (#65495) Summary: While implementing [EMA](pytorch/vision#4381 extends AveragedModel) in torchvision, update_parameters() from AveragedModel could not be used as it did not handle state_dict(), so a custom update_parameters() needed to be defined in [EMA class](pytorch/vision#4406). This PR aims to handle this scenario removing the need for this custom update_parameters() implementation. Discussion: pytorch/vision#4406 (review) Pull Request resolved: #65495 Reviewed By: datumbox Differential Revision: D31176742 Pulled By: prabhat00155 fbshipit-source-id: 326d14876018f21cf602bab5eaba344678dbabe2 (cherry picked from commit 2ea724b) * Added validation of mode parameter in AveragedModel (#65921) Summary: Discussion: #65495 (comment) Pull Request resolved: #65921 Reviewed By: albanD Differential Revision: D31310105 Pulled By: prabhat00155 fbshipit-source-id: 417691832a7c793744830c11e0ce53e3972d21a3 (cherry picked from commit c7748fc)
Resolves #4391.
Output logs:
resnext101_32x8d_training_log_ema_fix.txt
resnet18_training_log_ema_fix.txt
cc @datumbox