File tree Expand file tree Collapse file tree 2 files changed +23
-3
lines changed Expand file tree Collapse file tree 2 files changed +23
-3
lines changed Original file line number Diff line number Diff line change 1
1
from distutils .util import strtobool
2
2
import os
3
+ import shutil
3
4
from typing import Any , List , Set , NamedTuple
4
5
from distutils .version import LooseVersion
5
6
@@ -227,3 +228,20 @@ def _enforce_onnx_conversion() -> bool:
227
228
return strtobool (val )
228
229
except Exception :
229
230
return False
231
+
232
+
233
+ def copy_model_files (source_nn_path : str , destination_nn_path : str ) -> None :
234
+ """
235
+ Copy the .nn file at the given source to the destination.
236
+ Also copies the corresponding .onnx file if it exists.
237
+ """
238
+ shutil .copyfile (source_nn_path , destination_nn_path )
239
+ logger .info (f"Copied { source_nn_path } to { destination_nn_path } ." )
240
+ # Copy the onnx file if it exists
241
+ source_onnx_path = os .path .splitext (source_nn_path )[0 ] + ".onnx"
242
+ destination_onnx_path = os .path .splitext (destination_nn_path )[0 ] + ".onnx"
243
+ try :
244
+ shutil .copyfile (source_onnx_path , destination_onnx_path )
245
+ logger .info (f"Copied { source_onnx_path } to { destination_onnx_path } ." )
246
+ except OSError :
247
+ pass
Original file line number Diff line number Diff line change 5
5
import abc
6
6
import time
7
7
import attr
8
- from mlagents .model_serialization import SerializationSettings
8
+ from mlagents .model_serialization import SerializationSettings , copy_model_files
9
9
from mlagents .trainers .policy .checkpoint_manager import (
10
10
NNCheckpoint ,
11
11
NNCheckpointManager ,
@@ -131,12 +131,14 @@ def save_model(self) -> None:
131
131
"Trainer has multiple policies, but default behavior only saves the first."
132
132
)
133
133
policy = list (self .policies .values ())[0 ]
134
- settings = SerializationSettings (policy .model_path , self .brain_name )
135
134
model_checkpoint = self ._checkpoint ()
135
+
136
+ # Copy the checkpointed model files to the final output location
137
+ copy_model_files (model_checkpoint .file_path , f"{ policy .model_path } .nn" )
138
+
136
139
final_checkpoint = attr .evolve (
137
140
model_checkpoint , file_path = f"{ policy .model_path } .nn"
138
141
)
139
- policy .save (policy .model_path , settings )
140
142
NNCheckpointManager .track_final_checkpoint (self .brain_name , final_checkpoint )
141
143
142
144
@abc .abstractmethod
You can’t perform that action at this time.
0 commit comments