Skip to content

Commit c7748fc

Browse files
prabhat00155facebook-github-bot
authored andcommitted
Added validation of mode parameter in AveragedModel (pytorch#65921)
Summary: Discussion: pytorch#65495 (comment) Pull Request resolved: pytorch#65921 Reviewed By: albanD Differential Revision: D31310105 Pulled By: prabhat00155 fbshipit-source-id: 417691832a7c793744830c11e0ce53e3972d21a3
1 parent 0fc6bd2 commit c7748fc

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

torch/optim/swa_utils.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ class AveragedModel(Module):
2626
:class:`AveragedModel` parameter, the current value of :attr:`model`
2727
parameter and the number of models already averaged; if None,
2828
equally weighted average is used (default: None)
29-
mode (str, optional): whether to use parameters or state_dict for update
30-
(default: parameters)
29+
mode (str, optional): whether to use ``'parameters'`` or ``'state_dict'`` for update
30+
(default: ``'parameters'``)
3131
3232
Example:
3333
>>> loader, optimizer, model, loss_fn = ...
@@ -98,6 +98,9 @@ def avg_fn(averaged_model_parameter, model_parameter, num_averaged):
9898
return averaged_model_parameter + \
9999
(model_parameter - averaged_model_parameter) / (num_averaged + 1)
100100
self.avg_fn = avg_fn
101+
modes = ['parameters', 'state_dict']
102+
if mode not in modes:
103+
raise ValueError(f'Invalid mode passed, valid values are {", ".join(modes)}.')
101104
self.use_state_dict = mode == 'state_dict'
102105

103106
def forward(self, *args, **kwargs):

0 commit comments

Comments
 (0)