Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions ml-agents/mlagents/model_serialization.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from distutils.util import strtobool
import os
import shutil
from typing import Any, List, Set, NamedTuple
from distutils.version import LooseVersion

Expand Down Expand Up @@ -227,3 +228,20 @@ def _enforce_onnx_conversion() -> bool:
return strtobool(val)
except Exception:
return False


def copy_model_files(source_nn_path: str, destination_nn_path: str) -> None:
"""
Copy the .nn file at the given source to the destination.
Also copies the corresponding .onnx file if it exists.
"""
shutil.copyfile(source_nn_path, destination_nn_path)
logger.info(f"Copied {source_nn_path} to {destination_nn_path}.")
# Copy the onnx file if it exists
source_onnx_path = os.path.splitext(source_nn_path)[0] + ".onnx"
destination_onnx_path = os.path.splitext(destination_nn_path)[0] + ".onnx"
try:
shutil.copyfile(source_onnx_path, destination_onnx_path)
logger.info(f"Copied {source_onnx_path} to {destination_onnx_path}.")
except OSError:
pass
8 changes: 5 additions & 3 deletions ml-agents/mlagents/trainers/trainer/rl_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import abc
import time
import attr
from mlagents.model_serialization import SerializationSettings
from mlagents.model_serialization import SerializationSettings, copy_model_files
from mlagents.trainers.policy.checkpoint_manager import (
NNCheckpoint,
NNCheckpointManager,
Expand Down Expand Up @@ -131,12 +131,14 @@ def save_model(self) -> None:
"Trainer has multiple policies, but default behavior only saves the first."
)
policy = list(self.policies.values())[0]
settings = SerializationSettings(policy.model_path, self.brain_name)
model_checkpoint = self._checkpoint()

# Copy the checkpointed model files to the final output location
copy_model_files(model_checkpoint.file_path, f"{policy.model_path}.nn")

final_checkpoint = attr.evolve(
model_checkpoint, file_path=f"{policy.model_path}.nn"
)
policy.save(policy.model_path, settings)
NNCheckpointManager.track_final_checkpoint(self.brain_name, final_checkpoint)

@abc.abstractmethod
Expand Down