@@ -55,8 +55,8 @@ def _register_policy(self, policy: TFPolicy) -> None:
55
55
with self .policy .graph .as_default ():
56
56
self .tf_saver = tf .train .Saver (max_to_keep = self ._keep_checkpoints )
57
57
58
- def save_checkpoint (self , behavior_name : str , step : int ) -> str :
59
- checkpoint_path = os .path .join (self .model_path , f"{ behavior_name } -{ step } " )
58
+ def save_checkpoint (self , brain_name : str , step : int ) -> str :
59
+ checkpoint_path = os .path .join (self .model_path , f"{ brain_name } -{ step } " )
60
60
# Save the TF checkpoint and graph definition
61
61
if self .graph :
62
62
with self .graph .as_default ():
@@ -66,16 +66,16 @@ def save_checkpoint(self, behavior_name: str, step: int) -> str:
66
66
self .graph , self .model_path , "raw_graph_def.pb" , as_text = False
67
67
)
68
68
# also save the policy so we have optimized model files for each checkpoint
69
- self .export (checkpoint_path , behavior_name )
69
+ self .export (checkpoint_path , brain_name )
70
70
return checkpoint_path
71
71
72
- def export (self , output_filepath : str , behavior_name : str ) -> None :
72
+ def export (self , output_filepath : str , brain_name : str ) -> None :
73
73
# save model if there is only one worker or
74
74
# only on worker-0 if there are multiple workers
75
75
if self .policy and self .policy .rank is not None and self .policy .rank != 0 :
76
76
return
77
77
export_policy_model (
78
- self .model_path , output_filepath , behavior_name , self .graph , self .sess
78
+ self .model_path , output_filepath , brain_name , self .graph , self .sess
79
79
)
80
80
81
81
def initialize_or_load (self , policy : Optional [TFPolicy ] = None ) -> None :
@@ -94,7 +94,6 @@ def initialize_or_load(self, policy: Optional[TFPolicy] = None) -> None:
94
94
self ._load_graph (policy , self .model_path , reset_global_steps = reset_steps )
95
95
else :
96
96
policy .initialize ()
97
-
98
97
TFPolicy .broadcast_global_variables (0 )
99
98
100
99
def _load_graph (
0 commit comments