This repository was archived by the owner on Jul 1, 2024. It is now read-only.

Commit 9f405b2

vreis authored and facebook-github-bot committed
Save/load amp state in checkpoints (#435)
Summary:
Pull Request resolved: #435

This is needed to restore the loss scale found by AMP.

Reviewed By: mannatsingh

Differential Revision: D20430214

fbshipit-source-id: b6257884f94a7e92bdf9abf2d24f8428b3926df7
1 parent da14757 commit 9f405b2
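Background on why the loss scale matters here: with apex's dynamic loss scaling, the scale is adjusted during training and lives in apex.amp itself, not in the model or optimizer state dicts. A minimal sketch of the generic save/restore pattern outside Classy Vision is shown below; the model, optimizer, opt_level, and file name are illustrative assumptions, not part of this commit.

import torch
import apex.amp as amp

# Illustrative model/optimizer; apex.amp expects CUDA tensors.
model = torch.nn.Linear(16, 16).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# amp.initialize must wrap the model/optimizer before any saved amp state
# can be restored -- the same ordering constraint this commit enforces in
# ClassificationTask.prepare().
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

# Saving: amp.state_dict() captures the loss scaler, i.e. the loss scale
# found so far by dynamic loss scaling.
checkpoint = {
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    "amp": amp.state_dict(),
}
torch.save(checkpoint, "checkpoint.pt")

# Restoring: load the model, optimizer, and amp state back in.
checkpoint = torch.load("checkpoint.pt")
model.load_state_dict(checkpoint["model"])
optimizer.load_state_dict(checkpoint["optimizer"])
amp.load_state_dict(checkpoint["amp"])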

File tree

1 file changed: +14, -6 lines changed


classy_vision/tasks/classification_task.py

Lines changed: 14 additions & 6 deletions
@@ -539,6 +539,15 @@ def prepare(
         # the appropriate device
         self.optimizer.init_pytorch_optimizer(self.base_model, loss=self.loss)
 
+        if self.amp_args is not None:
+            # Initialize apex.amp. This updates the model and the PyTorch optimizer (
+            # which is wrapped by the ClassyOptimizer in self.optimizer).
+            # Please note this must happen before loading the checkpoint, cause
+            # there's amp state to be restored.
+            self.base_model, self.optimizer.optimizer = apex.amp.initialize(
+                self.base_model, self.optimizer.optimizer, **self.amp_args
+            )
+
         classy_state_dict = (
             None
             if self.checkpoint is None
@@ -551,12 +560,6 @@ def prepare(
                 state_load_success
             ), "Update classy state from checkpoint was unsuccessful."
 
-        if self.amp_args is not None:
-            # Initialize apex.amp. This updates the model and the PyTorch optimizer (
-            # which is wrapped by the ClassyOptimizer in self.optimizer)
-            self.base_model, self.optimizer.optimizer = apex.amp.initialize(
-                self.base_model, self.optimizer.optimizer, **self.amp_args
-            )
         self.init_distributed_data_parallel_model()
 
     def init_distributed_data_parallel_model(self):
@@ -629,6 +632,8 @@ def get_classy_state(self, deep_copy: bool = False):
         }
         if isinstance(self.loss, ClassyLoss):
             classy_state_dict["loss"] = self.loss.get_classy_state()
+        if self.amp_args is not None:
+            classy_state_dict["amp"] = apex.amp.state_dict()
         if deep_copy:
             classy_state_dict = copy.deepcopy(classy_state_dict)
         return classy_state_dict
@@ -654,6 +659,9 @@ def set_classy_state(self, state):
         if state.get("loss") and isinstance(self.loss, ClassyLoss):
             self.loss.set_classy_state(state["loss"])
 
+        if "amp" in state:
+            apex.amp.load_state_dict(state["amp"])
+
         for hook in self.hooks:
             # we still want to be able to run when new hooks are added or old
             # hooks are removed
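With this change in place, a task-level checkpoint round trip would look roughly like the sketch below; `task` is assumed to be an already prepared ClassificationTask with amp_args set, and the file path is hypothetical.

import torch

# Save: the returned state dict now carries an "amp" entry alongside the
# model, optimizer, and loss state, so the current loss scale survives
# the checkpoint.
torch.save(task.get_classy_state(deep_copy=True), "task_checkpoint.pt")

# Restore: prepare() has already run apex.amp.initialize, so
# set_classy_state can hand the "amp" entry to apex.amp.load_state_dict.
task.set_classy_state(torch.load("task_checkpoint.pt"))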
