
Commit 6f46b30

Author: Ervin T
[add-fire] Revert unneeded changes back to master (#4389)
1 parent 9406624 commit 6f46b30

9 files changed: +21 −22


ml-agents/mlagents/trainers/ppo/optimizer_tf.py renamed to ml-agents/mlagents/trainers/ppo/optimizer.py

Lines changed: 1 addition & 1 deletion
@@ -177,7 +177,7 @@ def _create_dc_critic(
             name="old_probabilities",
         )
 
-        # Break old log log_probs into separate branches
+        # Break old log probs into separate branches
         old_log_prob_branches = ModelUtils.break_into_branches(
             self.all_old_log_probs, self.policy.act_size
         )
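
For context on the comment being corrected, ModelUtils.break_into_branches splits a tensor whose last axis concatenates every discrete action branch into one slice per branch. A minimal NumPy sketch of the same idea, purely illustrative (the real helper operates on TF tensors):

import numpy as np

def break_into_branches(concatenated, action_size):
    """Split a [batch, sum(action_size)] array into per-branch [batch, size] arrays."""
    offsets = np.cumsum([0] + list(action_size))
    return [
        concatenated[:, offsets[i] : offsets[i + 1]]
        for i in range(len(action_size))
    ]

# Example: two branches with 3 and 2 possible actions each.
old_log_probs = np.zeros((4, 5))
branches = break_into_branches(old_log_probs, [3, 2])
assert [b.shape for b in branches] == [(4, 3), (4, 2)]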

ml-agents/mlagents/trainers/ppo/trainer.py

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@
 from mlagents.trainers.trainer.rl_trainer import RLTrainer
 from mlagents.trainers.policy import Policy
 from mlagents.trainers.policy.tf_policy import TFPolicy
-from mlagents.trainers.ppo.optimizer_tf import PPOOptimizer
+from mlagents.trainers.ppo.optimizer import PPOOptimizer
 from mlagents.trainers.trajectory import Trajectory
 from mlagents.trainers.behavior_id_utils import BehaviorIdentifiers
 from mlagents.trainers.settings import TrainerSettings, PPOSettings, FrameworkType

ml-agents/mlagents/trainers/saver/saver.py

Lines changed: 5 additions & 5 deletions
@@ -34,23 +34,23 @@ def _register_optimizer(self, optimizer):
         pass
 
     @abc.abstractmethod
-    def save_checkpoint(self, behavior_name: str, step: int) -> str:
+    def save_checkpoint(self, brain_name: str, step: int) -> str:
         """
         Checkpoints the policy on disk.
         :param checkpoint_path: filepath to write the checkpoint
-        :param behavior_name: Behavior name of behavior to be trained
+        :param brain_name: Brain name of brain to be trained
         """
         pass
 
     @abc.abstractmethod
-    def export(self, output_filepath: str, behavior_name: str) -> None:
+    def export(self, output_filepath: str, brain_name: str) -> None:
         """
-        Saves the serialized model, given a path and behavior name.
+        Saves the serialized model, given a path and brain name.
         This method will save the policy graph to the given filepath. The path
         should be provided without an extension as multiple serialized model formats
         may be generated as a result.
         :param output_filepath: path (without suffix) for the model file(s)
-        :param behavior_name: Behavior name of behavior to be trained.
+        :param brain_name: Brain name of brain to be trained.
         """
         pass
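
This abstract interface is the contract both concrete savers below implement. To make it concrete, a toy implementation of the two renamed methods (the directory, pickle payload, and "3DBall" name are invented for the example):

import os
import pickle

class ToySaver:
    """Minimal stand-in that satisfies the renamed interface."""

    def __init__(self, model_path):
        self.model_path = model_path

    def save_checkpoint(self, brain_name: str, step: int) -> str:
        os.makedirs(self.model_path, exist_ok=True)
        checkpoint_path = os.path.join(self.model_path, f"{brain_name}-{step}")
        with open(f"{checkpoint_path}.pkl", "wb") as f:
            pickle.dump({"step": step}, f)  # placeholder payload
        return checkpoint_path

    def export(self, output_filepath: str, brain_name: str) -> None:
        # A real saver would serialize the policy model here.
        print(f"exporting {brain_name} to {output_filepath}")

saver = ToySaver("./results")
path = saver.save_checkpoint(brain_name="3DBall", step=50000)
saver.export(path, brain_name="3DBall")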

ml-agents/mlagents/trainers/saver/tf_saver.py

Lines changed: 5 additions & 6 deletions
@@ -55,8 +55,8 @@ def _register_policy(self, policy: TFPolicy) -> None:
         with self.policy.graph.as_default():
             self.tf_saver = tf.train.Saver(max_to_keep=self._keep_checkpoints)
 
-    def save_checkpoint(self, behavior_name: str, step: int) -> str:
-        checkpoint_path = os.path.join(self.model_path, f"{behavior_name}-{step}")
+    def save_checkpoint(self, brain_name: str, step: int) -> str:
+        checkpoint_path = os.path.join(self.model_path, f"{brain_name}-{step}")
         # Save the TF checkpoint and graph definition
         if self.graph:
             with self.graph.as_default():
@@ -66,16 +66,16 @@ def save_checkpoint(self, behavior_name: str, step: int) -> str:
                     self.graph, self.model_path, "raw_graph_def.pb", as_text=False
                 )
         # also save the policy so we have optimized model files for each checkpoint
-        self.export(checkpoint_path, behavior_name)
+        self.export(checkpoint_path, brain_name)
         return checkpoint_path
 
-    def export(self, output_filepath: str, behavior_name: str) -> None:
+    def export(self, output_filepath: str, brain_name: str) -> None:
         # save model if there is only one worker or
         # only on worker-0 if there are multiple workers
         if self.policy and self.policy.rank is not None and self.policy.rank != 0:
             return
         export_policy_model(
-            self.model_path, output_filepath, behavior_name, self.graph, self.sess
+            self.model_path, output_filepath, brain_name, self.graph, self.sess
         )
 
     def initialize_or_load(self, policy: Optional[TFPolicy] = None) -> None:
@@ -94,7 +94,6 @@ def initialize_or_load(self, policy: Optional[TFPolicy] = None) -> None:
             self._load_graph(policy, self.model_path, reset_global_steps=reset_steps)
         else:
             policy.initialize()
-
         TFPolicy.broadcast_global_variables(0)
 
     def _load_graph(
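
The early return on non-zero rank in export is a standard multi-worker pattern: every worker trains, but only worker 0 writes model files so parallel trainers do not clobber each other's exports. A generic sketch of the guard, with the rank source purely illustrative (ml-agents reads the rank from the policy instead):

import os

def export_on_rank_zero(export_fn, *args):
    """Call export_fn only on worker 0 so model files are written exactly once."""
    rank = int(os.environ.get("RANK", "0"))  # illustrative rank source
    if rank != 0:
        return None  # non-zero workers skip the write entirely
    return export_fn(*args)

export_on_rank_zero(print, "writing model files from worker 0")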

ml-agents/mlagents/trainers/saver/torch_saver.py

Lines changed: 4 additions & 4 deletions
@@ -45,19 +45,19 @@ def register(self, module: Union[TorchPolicy, TorchOptimizer]) -> None:
             self.policy = module
             self.exporter = ModelSerializer(self.policy)
 
-    def save_checkpoint(self, behavior_name: str, step: int) -> str:
+    def save_checkpoint(self, brain_name: str, step: int) -> str:
         if not os.path.exists(self.model_path):
             os.makedirs(self.model_path)
-        checkpoint_path = os.path.join(self.model_path, f"{behavior_name}-{step}")
+        checkpoint_path = os.path.join(self.model_path, f"{brain_name}-{step}")
         state_dict = {
             name: module.state_dict() for name, module in self.modules.items()
         }
         torch.save(state_dict, f"{checkpoint_path}.pt")
         torch.save(state_dict, os.path.join(self.model_path, "checkpoint.pt"))
-        self.export(checkpoint_path, behavior_name)
+        self.export(checkpoint_path, brain_name)
         return checkpoint_path
 
-    def export(self, output_filepath: str, behavior_name: str) -> None:
+    def export(self, output_filepath: str, brain_name: str) -> None:
         if self.exporter is not None:
             self.exporter.export_policy_model(output_filepath)
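
The checkpoint written above is a single file holding one state_dict per registered module, keyed by name. A self-contained sketch of that layout and its round trip (the module names and shapes here are invented for the example):

import torch
import torch.nn as nn

policy = nn.Linear(4, 2)
optimizer = torch.optim.Adam(policy.parameters())
modules = {"Policy": policy, "Optimizer:adam": optimizer}

# Save: one nested state_dict per module, all in one file.
state_dict = {name: m.state_dict() for name, m in modules.items()}
torch.save(state_dict, "checkpoint.pt")

# Restore: each module loads back its own entry by name.
restored = torch.load("checkpoint.pt")
policy.load_state_dict(restored["Policy"])
optimizer.load_state_dict(restored["Optimizer:adam"])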

ml-agents/mlagents/trainers/tests/test_ppo.py

Lines changed: 1 addition & 1 deletion
@@ -9,7 +9,7 @@
 
 from mlagents.trainers.trainer.rl_trainer import RLTrainer
 from mlagents.trainers.ppo.trainer import PPOTrainer, discount_rewards
-from mlagents.trainers.ppo.optimizer_tf import PPOOptimizer
+from mlagents.trainers.ppo.optimizer import PPOOptimizer
 from mlagents.trainers.policy.tf_policy import TFPolicy
 from mlagents.trainers.agent_processor import AgentManagerQueue
 from mlagents.trainers.tests import mock_brain as mb

ml-agents/mlagents/trainers/tests/test_reward_signals.py

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@
 import mlagents.trainers.tests.mock_brain as mb
 from mlagents.trainers.policy.tf_policy import TFPolicy
 from mlagents.trainers.sac.optimizer import SACOptimizer
-from mlagents.trainers.ppo.optimizer_tf import PPOOptimizer
+from mlagents.trainers.ppo.optimizer import PPOOptimizer
 from mlagents.trainers.tests.test_simple_rl import PPO_CONFIG, SAC_CONFIG
 from mlagents.trainers.settings import (
     GAILSettings,

ml-agents/mlagents/trainers/tests/test_saver.py

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@
 from mlagents.trainers.policy.tf_policy import TFPolicy
 from mlagents.trainers.tests import mock_brain as mb
 from mlagents.trainers.tests.test_nn_policy import create_policy_mock
-from mlagents.trainers.ppo.optimizer_tf import PPOOptimizer
+from mlagents.trainers.ppo.optimizer import PPOOptimizer
 
 
 def test_register(tmp_path):

ml-agents/mlagents/trainers/tf/models.py

Lines changed: 2 additions & 2 deletions
@@ -510,8 +510,8 @@ def create_discrete_action_masking_layer(
         :param action_masks: The mask for the logits. Must be of dimension [None x total_number_of_action]
         :param action_size: A list containing the number of possible actions for each branch
         :return: The action output dimension [batch_size, num_branches], the concatenated
-            normalized log_probs (after softmax)
-            and the concatenated normalized log log_probs
+            normalized probs (after softmax)
+            and the concatenated normalized log probs
         """
         branch_masks = ModelUtils.break_into_branches(action_masks, action_size)
         raw_probs = [
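
The docstring fix matters because the function returns both quantities: probabilities renormalized over only the unmasked actions, and their logs. A NumPy sketch of that per-branch computation, illustrative rather than the TF implementation:

import numpy as np

def masked_branch_probs(logits, mask, eps=1e-10):
    """Softmax over logits, zero out masked actions, renormalize, return probs and log probs."""
    raw = np.exp(logits - logits.max(axis=-1, keepdims=True))
    raw /= raw.sum(axis=-1, keepdims=True)             # ordinary softmax
    masked = raw * mask                                 # disallowed actions -> 0
    probs = masked / (masked.sum(axis=-1, keepdims=True) + eps)
    return probs, np.log(probs + eps)

logits = np.array([[1.0, 2.0, 0.5]])
mask = np.array([[1.0, 0.0, 1.0]])                      # middle action masked out
probs, log_probs = masked_branch_probs(logits, mask)
assert probs[0, 1] == 0.0 and abs(probs.sum() - 1.0) < 1e-6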
