Skip to content

Commit b68beb4

Browse files
author
Chris Elion
authored
Don't save model twice, copy instead (#4302)
* Don't save model twice, copy instead * narrower exception
1 parent 9d92e8a commit b68beb4

File tree

2 files changed

+23
-3
lines changed

2 files changed

+23
-3
lines changed

ml-agents/mlagents/model_serialization.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from distutils.util import strtobool
22
import os
3+
import shutil
34
from typing import Any, List, Set, NamedTuple
45
from distutils.version import LooseVersion
56

@@ -227,3 +228,20 @@ def _enforce_onnx_conversion() -> bool:
227228
return strtobool(val)
228229
except Exception:
229230
return False
231+
232+
233+
def copy_model_files(source_nn_path: str, destination_nn_path: str) -> None:
234+
"""
235+
Copy the .nn file at the given source to the destination.
236+
Also copies the corresponding .onnx file if it exists.
237+
"""
238+
shutil.copyfile(source_nn_path, destination_nn_path)
239+
logger.info(f"Copied {source_nn_path} to {destination_nn_path}.")
240+
# Copy the onnx file if it exists
241+
source_onnx_path = os.path.splitext(source_nn_path)[0] + ".onnx"
242+
destination_onnx_path = os.path.splitext(destination_nn_path)[0] + ".onnx"
243+
try:
244+
shutil.copyfile(source_onnx_path, destination_onnx_path)
245+
logger.info(f"Copied {source_onnx_path} to {destination_onnx_path}.")
246+
except OSError:
247+
pass

ml-agents/mlagents/trainers/trainer/rl_trainer.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import abc
66
import time
77
import attr
8-
from mlagents.model_serialization import SerializationSettings
8+
from mlagents.model_serialization import SerializationSettings, copy_model_files
99
from mlagents.trainers.policy.checkpoint_manager import (
1010
NNCheckpoint,
1111
NNCheckpointManager,
@@ -131,12 +131,14 @@ def save_model(self) -> None:
131131
"Trainer has multiple policies, but default behavior only saves the first."
132132
)
133133
policy = list(self.policies.values())[0]
134-
settings = SerializationSettings(policy.model_path, self.brain_name)
135134
model_checkpoint = self._checkpoint()
135+
136+
# Copy the checkpointed model files to the final output location
137+
copy_model_files(model_checkpoint.file_path, f"{policy.model_path}.nn")
138+
136139
final_checkpoint = attr.evolve(
137140
model_checkpoint, file_path=f"{policy.model_path}.nn"
138141
)
139-
policy.save(policy.model_path, settings)
140142
NNCheckpointManager.track_final_checkpoint(self.brain_name, final_checkpoint)
141143

142144
@abc.abstractmethod

0 commit comments

Comments
 (0)