Pytorch ghost trainer #4370
Changes from all commits: 88b9e40, c180bed, 16690d9, d04fe05, d964085, 9588362, b2da109, ef4a5a0, 453e1aa
@@ -1,6 +1,7 @@
 from typing import Any, Dict, List
 import numpy as np
 import torch
+import copy

 from mlagents.trainers.action_info import ActionInfo
 from mlagents.trainers.behavior_id_utils import get_global_agent_id
@@ -249,13 +250,13 @@ def increment_step(self, n_steps):
         return self.get_current_step()

     def load_weights(self, values: List[np.ndarray]) -> None:
-        pass
+        self.actor_critic.load_state_dict(values)

     def init_load_weights(self) -> None:
         pass

     def get_weights(self) -> List[np.ndarray]:
-        return []
+        return copy.deepcopy(self.actor_critic.state_dict())

     def get_modules(self):
         return {"Policy": self.actor_critic, "global_step": self.global_step}

Review comment (on init_load_weights): We don't need this function with torch.

Reply: Maybe add a comment and mark for removal when TF is deprecated.
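For context, `get_weights` and `load_weights` now snapshot and restore the policy network through PyTorch `state_dict`s. A minimal sketch of that round trip, using a plain `nn.Linear` as a stand-in for the actual `actor_critic` network (illustration only, not ML-Agents code):

```python
import copy

import torch
import torch.nn as nn

# Two stand-ins for separately initialized policy networks.
net_a = nn.Linear(8, 2)
net_b = nn.Linear(8, 2)

# get_weights(): deep-copy the state_dict so later updates to net_a
# cannot mutate the stored snapshot.
snapshot = copy.deepcopy(net_a.state_dict())

# load_weights(): push the snapshot into the other network.
net_b.load_state_dict(snapshot)

# Both networks now produce identical outputs for the same input.
x = torch.ones(1, 8)
assert torch.allclose(net_a(x), net_b(x))
```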
@@ -217,18 +217,23 @@ def _update_policy(self):
         return True

     def create_tf_policy(
-        self, parsed_behavior_id: BehaviorIdentifiers, behavior_spec: BehaviorSpec
+        self,
+        parsed_behavior_id: BehaviorIdentifiers,
+        behavior_spec: BehaviorSpec,
+        create_graph: bool = False,
     ) -> TFPolicy:
         """
         Creates a PPO policy to trainers list of policies.
         :param behavior_spec: specifications for policy construction
+        :param create_graph: whether to create the graph when policy is constructed
         :return policy
         """
         policy = TFPolicy(
             self.seed,
             behavior_spec,
             self.trainer_settings,
             condition_sigma_on_obs=False,  # Faster training for PPO
+            create_tf_graph=create_graph,
         )
         return policy

Review comment (on create_graph): Not sure what create_graph does. Can you add a comment?

Reply: Added. It sets the `create_tf_graph` flag on the `TFPolicy`.
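The new flag follows a deferred-initialization pattern: construction records the configuration, and the expensive TF graph build happens at construction time only when requested. A generic, hypothetical sketch of that pattern (not the actual TFPolicy implementation):

```python
class LazyGraphPolicy:
    """Illustration of a create_graph-style flag: build an expensive
    structure at construction time only when asked to."""

    def __init__(self, seed: int, create_graph: bool = False) -> None:
        self.seed = seed
        self.graph_built = False
        if create_graph:
            self.build_graph()

    def build_graph(self) -> None:
        # Stand-in for TF graph construction / variable initialization.
        self.graph_built = True


eager = LazyGraphPolicy(seed=0, create_graph=True)
lazy = LazyGraphPolicy(seed=0)  # graph deferred until needed
assert eager.graph_built and not lazy.graph_built
lazy.build_graph()              # e.g. built later by the trainer
assert lazy.graph_built
```

Deferring the build is presumably useful when a trainer constructs several policies but only some of them need a full graph immediately.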
@@ -0,0 +1,177 @@
import pytest

import numpy as np

from mlagents.trainers.ghost.trainer import GhostTrainer
from mlagents.trainers.ghost.controller import GhostController
from mlagents.trainers.behavior_id_utils import BehaviorIdentifiers
from mlagents.trainers.ppo.trainer import PPOTrainer
from mlagents.trainers.agent_processor import AgentManagerQueue
from mlagents.trainers.tests import mock_brain as mb
from mlagents.trainers.tests.test_trajectory import make_fake_trajectory
from mlagents.trainers.settings import TrainerSettings, SelfPlaySettings, FrameworkType


@pytest.fixture
def dummy_config():
    return TrainerSettings(
        self_play=SelfPlaySettings(), framework=FrameworkType.PYTORCH
    )


VECTOR_ACTION_SPACE = 1
VECTOR_OBS_SPACE = 8
DISCRETE_ACTION_SPACE = [3, 3, 3, 2]
BUFFER_INIT_SAMPLES = 513
NUM_AGENTS = 12
@pytest.mark.parametrize("use_discrete", [True, False])
def test_load_and_set(dummy_config, use_discrete):
    mock_specs = mb.setup_test_behavior_specs(
        use_discrete,
        False,
        vector_action_space=DISCRETE_ACTION_SPACE
        if use_discrete
        else VECTOR_ACTION_SPACE,
        vector_obs_space=VECTOR_OBS_SPACE,
    )

    trainer_params = dummy_config
    trainer = PPOTrainer("test", 0, trainer_params, True, False, 0, "0")
    trainer.seed = 1
    policy = trainer.create_policy("test", mock_specs)
    trainer.seed = 20  # otherwise graphs are the same
    to_load_policy = trainer.create_policy("test", mock_specs)

    weights = policy.get_weights()
    load_weights = to_load_policy.get_weights()
    # Freshly created policies are expected to differ, so the equality
    # check below is allowed to fail before the load.
    try:
        for w, lw in zip(weights, load_weights):
            np.testing.assert_array_equal(w, lw)
    except AssertionError:
        pass

    to_load_policy.load_weights(weights)
    load_weights = to_load_policy.get_weights()

    for w, lw in zip(weights, load_weights):
        np.testing.assert_array_equal(w, lw)
def test_process_trajectory(dummy_config):
    mock_specs = mb.setup_test_behavior_specs(
        True, False, vector_action_space=[2], vector_obs_space=1
    )
    behavior_id_team0 = "test_brain?team=0"
    behavior_id_team1 = "test_brain?team=1"
    brain_name = BehaviorIdentifiers.from_name_behavior_id(behavior_id_team0).brain_name

    ppo_trainer = PPOTrainer(brain_name, 0, dummy_config, True, False, 0, "0")
    controller = GhostController(100)
    trainer = GhostTrainer(
        ppo_trainer, brain_name, controller, 0, dummy_config, True, "0"
    )

    # first policy encountered becomes policy trained by wrapped PPO
    parsed_behavior_id0 = BehaviorIdentifiers.from_name_behavior_id(behavior_id_team0)
    policy = trainer.create_policy(parsed_behavior_id0, mock_specs)
    trainer.add_policy(parsed_behavior_id0, policy)
    trajectory_queue0 = AgentManagerQueue(behavior_id_team0)
    trainer.subscribe_trajectory_queue(trajectory_queue0)

    # Ghost trainer should ignore this queue because off policy
    parsed_behavior_id1 = BehaviorIdentifiers.from_name_behavior_id(behavior_id_team1)
    policy = trainer.create_policy(parsed_behavior_id1, mock_specs)
    trainer.add_policy(parsed_behavior_id1, policy)
    trajectory_queue1 = AgentManagerQueue(behavior_id_team1)
    trainer.subscribe_trajectory_queue(trajectory_queue1)

    time_horizon = 15
    trajectory = make_fake_trajectory(
        length=time_horizon,
        max_step_complete=True,
        observation_shapes=[(1,)],
        action_space=[2],
    )
    trajectory_queue0.put(trajectory)
    trainer.advance()

    # Check that trainer put trajectory in update buffer
    assert trainer.trainer.update_buffer.num_experiences == 15

    trajectory_queue1.put(trajectory)
    trainer.advance()

    # Check that ghost trainer ignored off policy queue
    assert trainer.trainer.update_buffer.num_experiences == 15
    # Check that it emptied the queue
    assert trajectory_queue1.empty()
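The behavior ids used in these tests follow a `brain_name?team=N` convention, which `BehaviorIdentifiers.from_name_behavior_id` splits into a brain name and a team id. A rough illustrative parser of that convention (`parse_behavior_id` is a hypothetical helper, not the ML-Agents implementation):

```python
from urllib.parse import parse_qs


def parse_behavior_id(behavior_id: str):
    # Hypothetical helper: the real logic lives in
    # mlagents.trainers.behavior_id_utils.BehaviorIdentifiers.
    name, _, query = behavior_id.partition("?")
    team_id = int(parse_qs(query).get("team", ["0"])[0])
    return name, team_id


assert parse_behavior_id("test_brain?team=0") == ("test_brain", 0)
assert parse_behavior_id("test_brain?team=1") == ("test_brain", 1)
```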
def test_publish_queue(dummy_config):
    mock_specs = mb.setup_test_behavior_specs(
        True, False, vector_action_space=[1], vector_obs_space=8
    )

    behavior_id_team0 = "test_brain?team=0"
    behavior_id_team1 = "test_brain?team=1"

    parsed_behavior_id0 = BehaviorIdentifiers.from_name_behavior_id(behavior_id_team0)

    brain_name = parsed_behavior_id0.brain_name

    ppo_trainer = PPOTrainer(brain_name, 0, dummy_config, True, False, 0, "0")
    controller = GhostController(100)
    trainer = GhostTrainer(
        ppo_trainer, brain_name, controller, 0, dummy_config, True, "0"
    )

    # First policy encountered becomes policy trained by wrapped PPO
    # This queue should remain empty after swap snapshot
    policy = trainer.create_policy(parsed_behavior_id0, mock_specs)
    trainer.add_policy(parsed_behavior_id0, policy)
    policy_queue0 = AgentManagerQueue(behavior_id_team0)
    trainer.publish_policy_queue(policy_queue0)

    # Ghost trainer should use this queue for ghost policy swap
    parsed_behavior_id1 = BehaviorIdentifiers.from_name_behavior_id(behavior_id_team1)
    policy = trainer.create_policy(parsed_behavior_id1, mock_specs)
    trainer.add_policy(parsed_behavior_id1, policy)
    policy_queue1 = AgentManagerQueue(behavior_id_team1)
    trainer.publish_policy_queue(policy_queue1)

    # check ghost trainer swap pushes to ghost queue and not trainer
    assert policy_queue0.empty() and policy_queue1.empty()
    trainer._swap_snapshots()
    assert policy_queue0.empty() and not policy_queue1.empty()
    # clear
    policy_queue1.get_nowait()

    mock_specs = mb.setup_test_behavior_specs(
        False,
        False,
        vector_action_space=VECTOR_ACTION_SPACE,
        vector_obs_space=VECTOR_OBS_SPACE,
    )

    buffer = mb.simulate_rollout(BUFFER_INIT_SAMPLES, mock_specs)
    # Mock out reward signal eval
    buffer["extrinsic_rewards"] = buffer["environment_rewards"]
    buffer["extrinsic_returns"] = buffer["environment_rewards"]
    buffer["extrinsic_value_estimates"] = buffer["environment_rewards"]
    buffer["curiosity_rewards"] = buffer["environment_rewards"]
    buffer["curiosity_returns"] = buffer["environment_rewards"]
    buffer["curiosity_value_estimates"] = buffer["environment_rewards"]
    buffer["advantages"] = buffer["environment_rewards"]
    trainer.trainer.update_buffer = buffer

    # when ghost trainer advance and wrapped trainer buffers full
    # the wrapped trainer pushes updated policy to correct queue
    assert policy_queue0.empty() and policy_queue1.empty()
    trainer.advance()
    assert not policy_queue0.empty() and policy_queue1.empty()


if __name__ == "__main__":
    pytest.main()
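Taken together, these tests exercise the ghost trainer's snapshot-and-swap cycle: the learning policy's weights are periodically saved, and an older snapshot is loaded into the opposing policy and pushed onto its policy queue. A conceptual, self-contained sketch of that cycle built on the get_weights/load_weights contract above (names and structure are illustrative, not the ML-Agents implementation):

```python
import copy


class TinyPolicy:
    """Stand-in policy exposing the get_weights/load_weights contract."""

    def __init__(self, weights):
        self._weights = weights

    def get_weights(self):
        return copy.deepcopy(self._weights)

    def load_weights(self, values):
        self._weights = copy.deepcopy(values)


learning = TinyPolicy({"w": 1.0})
opponent = TinyPolicy({"w": 0.0})
snapshots = []

# Periodically snapshot the learning policy.
snapshots.append(learning.get_weights())

# Training moves the learning policy on...
learning.load_weights({"w": 2.0})

# ...while a swap loads an older snapshot into the opponent policy,
# roughly what the tests check around _swap_snapshots and the ghost queue.
opponent.load_weights(snapshots[-1])
assert opponent.get_weights() == {"w": 1.0}
```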
Review comment: A bit strange to have ghost-specific code in the TF policy. Isn't there another reason to call init_load_weights here?

Reply: Previously, this was TF-specific code in the ghost trainer. I don't need to call it ghost-specific, but the ghost trainer is the only thing that uses the load/get methods.

Reply: Ok, I think we can do without that comment, but I don't mind leaving it either.