
Commit 0539915

Author: Ervin T
[refactor] Make classes except Optimizer framework agnostic (#4268)
1 parent 1ad1e05 · commit 0539915

File tree: 11 files changed (+46 -84 lines)


ml-agents/mlagents/trainers/agent_processor.py
Lines changed: 2 additions & 3 deletions

@@ -14,7 +14,6 @@
     EnvironmentStats,
 )
 from mlagents.trainers.trajectory import Trajectory, AgentExperience
-from mlagents.trainers.policy.tf_policy import TFPolicy
 from mlagents.trainers.policy import Policy
 from mlagents.trainers.action_info import ActionInfo, ActionInfoOutputs
 from mlagents.trainers.stats import StatsReporter
@@ -32,7 +31,7 @@ class AgentProcessor:
 
     def __init__(
         self,
-        policy: TFPolicy,
+        policy: Policy,
         behavior_id: str,
         stats_reporter: StatsReporter,
         max_trajectory_length: int = sys.maxsize,
@@ -290,7 +289,7 @@ class AgentManager(AgentProcessor):
 
     def __init__(
         self,
-        policy: TFPolicy,
+        policy: Policy,
         behavior_id: str,
         stats_reporter: StatsReporter,
         max_trajectory_length: int = sys.maxsize,

ml-agents/mlagents/trainers/env_manager.py
Lines changed: 3 additions & 3 deletions

@@ -8,7 +8,7 @@
 )
 from mlagents_envs.side_channel.stats_side_channel import EnvironmentStats
 
-from mlagents.trainers.policy.tf_policy import TFPolicy
+from mlagents.trainers.policy import Policy
 from mlagents.trainers.agent_processor import AgentManager, AgentManagerQueue
 from mlagents.trainers.action_info import ActionInfo
 from mlagents_envs.logging_util import get_logger
@@ -36,11 +36,11 @@ def empty(worker_id: int) -> "EnvironmentStep":
 
 class EnvManager(ABC):
     def __init__(self):
-        self.policies: Dict[BehaviorName, TFPolicy] = {}
+        self.policies: Dict[BehaviorName, Policy] = {}
         self.agent_managers: Dict[BehaviorName, AgentManager] = {}
         self.first_step_infos: List[EnvironmentStep] = []
 
-    def set_policy(self, brain_name: BehaviorName, policy: TFPolicy) -> None:
+    def set_policy(self, brain_name: BehaviorName, policy: Policy) -> None:
         self.policies[brain_name] = policy
         if brain_name in self.agent_managers:
             self.agent_managers[brain_name].policy = policy

ml-agents/mlagents/trainers/ghost/trainer.py
Lines changed: 5 additions & 6 deletions

@@ -2,14 +2,13 @@
 # ## ML-Agent Learning (Ghost Trainer)
 
 from collections import defaultdict
-from typing import Deque, Dict, DefaultDict, List, cast
+from typing import Deque, Dict, DefaultDict, List
 
 import numpy as np
 
 from mlagents_envs.logging_util import get_logger
 from mlagents_envs.base_env import BehaviorSpec
 from mlagents.trainers.policy import Policy
-from mlagents.trainers.policy.tf_policy import TFPolicy
 
 from mlagents.trainers.trainer import Trainer
 from mlagents.trainers.trajectory import Trajectory
@@ -262,7 +261,7 @@ def advance(self) -> None:
         for brain_name in self._internal_policy_queues:
             internal_policy_queue = self._internal_policy_queues[brain_name]
             try:
-                policy = cast(TFPolicy, internal_policy_queue.get_nowait())
+                policy = internal_policy_queue.get_nowait()
                 self.current_policy_snapshot[brain_name] = policy.get_weights()
             except AgentManagerQueue.Empty:
                 pass
@@ -306,7 +305,7 @@ def save_model(self) -> None:
 
     def create_policy(
         self, parsed_behavior_id: BehaviorIdentifiers, behavior_spec: BehaviorSpec
-    ) -> TFPolicy:
+    ) -> Policy:
         """
         Creates policy with the wrapped trainer's create_policy function
         The first policy encountered sets the wrapped
@@ -339,7 +338,7 @@ def create_policy(
         return policy
 
     def add_policy(
-        self, parsed_behavior_id: BehaviorIdentifiers, policy: TFPolicy
+        self, parsed_behavior_id: BehaviorIdentifiers, policy: Policy
     ) -> None:
         """
         Adds policy to GhostTrainer.
@@ -350,7 +349,7 @@ def add_policy(
         self._name_to_parsed_behavior_id[name_behavior_id] = parsed_behavior_id
         self.policies[name_behavior_id] = policy
 
-    def get_policy(self, name_behavior_id: str) -> TFPolicy:
+    def get_policy(self, name_behavior_id: str) -> Policy:
         """
         Gets policy associated with name_behavior_id
         :param name_behavior_id: Fully qualified behavior name
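
Note on the dropped cast in advance(): the snapshot logic only calls get_weights() on whatever comes off the internal policy queue, and that method is now declared on the abstract Policy (see the policy/policy.py hunk further down), so the TFPolicy cast is no longer needed. A minimal, framework-agnostic sketch of the snapshot/restore pattern this relies on; StubPolicy below is an illustrative stand-in, not part of this commit:

from typing import Dict, List
import numpy as np

# StubPolicy is an illustrative stand-in: any concrete Policy must provide these methods.
class StubPolicy:
    def __init__(self, weights: List[np.ndarray]):
        self._weights = [w.copy() for w in weights]

    def get_weights(self) -> List[np.ndarray]:
        return [w.copy() for w in self._weights]

    def load_weights(self, values: List[np.ndarray]) -> None:
        self._weights = [v.copy() for v in values]

    def init_load_weights(self) -> None:
        pass  # a TF backend would build its assign ops here; nothing to do for the stub

# The ghost trainer's snapshot bookkeeping only needs the interface above:
current_policy_snapshot: Dict[str, List[np.ndarray]] = {}
policy = StubPolicy([np.zeros(4, dtype=np.float32)])
current_policy_snapshot["behavior?team=0"] = policy.get_weights()  # take a snapshot
policy.load_weights(current_policy_snapshot["behavior?team=0"])    # restore it later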

ml-agents/mlagents/trainers/optimizer/optimizer.py
Lines changed: 3 additions & 0 deletions

@@ -10,6 +10,9 @@ class Optimizer(abc.ABC):
     Provides methods to update the Policy.
     """
 
+    def __init__(self):
+        self.reward_signals = {}
+
     @abc.abstractmethod
     def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
         """

ml-agents/mlagents/trainers/optimizer/tf_optimizer.py
Lines changed: 1 addition & 1 deletion

@@ -15,6 +15,7 @@
 
 class TFOptimizer(Optimizer):  # pylint: disable=W0223
     def __init__(self, policy: TFPolicy, trainer_params: TrainerSettings):
+        super().__init__()
         self.sess = policy.sess
         self.policy = policy
         self.update_dict: Dict[str, tf.Tensor] = {}
@@ -129,7 +130,6 @@ def create_reward_signals(
         Create reward signals
         :param reward_signal_configs: Reward signal config.
         """
-        self.reward_signals = {}
         # Create reward signals
         for reward_signal, settings in reward_signal_configs.items():
             # Name reward signals by string in case we have duplicates later

ml-agents/mlagents/trainers/policy/policy.py
Lines changed: 12 additions & 0 deletions

@@ -158,3 +158,15 @@ def checkpoint(self, checkpoint_path: str, settings: SerializationSettings) -> None:
     @abstractmethod
     def save(self, output_filepath: str, settings: SerializationSettings) -> None:
         pass
+
+    @abstractmethod
+    def load_weights(self, values: List[np.ndarray]) -> None:
+        pass
+
+    @abstractmethod
+    def get_weights(self) -> List[np.ndarray]:
+        return []
+
+    @abstractmethod
+    def init_load_weights(self) -> None:
+        pass
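
The three new abstract methods make weight snapshotting part of the Policy contract, so any future non-TensorFlow backend must provide them before it can be instantiated. A small self-contained sketch of how abstractmethod enforces this, using simplified stand-in classes rather than the real mlagents Policy:

from abc import ABC, abstractmethod
from typing import List
import numpy as np

# Simplified stand-in mirroring the additions above (not the full mlagents Policy).
class PolicyLike(ABC):
    @abstractmethod
    def load_weights(self, values: List[np.ndarray]) -> None:
        pass

    @abstractmethod
    def get_weights(self) -> List[np.ndarray]:
        return []

    @abstractmethod
    def init_load_weights(self) -> None:
        pass

class IncompleteBackend(PolicyLike):
    # Implements only one of the three required methods.
    def get_weights(self) -> List[np.ndarray]:
        return []

try:
    IncompleteBackend()  # type: ignore[abstract]
except TypeError as err:
    print(err)  # "Can't instantiate abstract class IncompleteBackend with abstract methods ..."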

ml-agents/mlagents/trainers/policy/tf_policy.py
Lines changed: 0 additions & 56 deletions

@@ -376,62 +376,6 @@ def fill_eval_dict(self, feed_dict, batched_step_result):
         feed_dict[self.action_masks] = mask
         return feed_dict
 
-    def make_empty_memory(self, num_agents):
-        """
-        Creates empty memory for use with RNNs
-        :param num_agents: Number of agents.
-        :return: Numpy array of zeros.
-        """
-        return np.zeros((num_agents, self.m_size), dtype=np.float32)
-
-    def save_memories(
-        self, agent_ids: List[str], memory_matrix: Optional[np.ndarray]
-    ) -> None:
-        if memory_matrix is None:
-            return
-        for index, agent_id in enumerate(agent_ids):
-            self.memory_dict[agent_id] = memory_matrix[index, :]
-
-    def retrieve_memories(self, agent_ids: List[str]) -> np.ndarray:
-        memory_matrix = np.zeros((len(agent_ids), self.m_size), dtype=np.float32)
-        for index, agent_id in enumerate(agent_ids):
-            if agent_id in self.memory_dict:
-                memory_matrix[index, :] = self.memory_dict[agent_id]
-        return memory_matrix
-
-    def remove_memories(self, agent_ids):
-        for agent_id in agent_ids:
-            if agent_id in self.memory_dict:
-                self.memory_dict.pop(agent_id)
-
-    def make_empty_previous_action(self, num_agents):
-        """
-        Creates empty previous action for use with RNNs and discrete control
-        :param num_agents: Number of agents.
-        :return: Numpy array of zeros.
-        """
-        return np.zeros((num_agents, self.num_branches), dtype=np.int)
-
-    def save_previous_action(
-        self, agent_ids: List[str], action_matrix: Optional[np.ndarray]
-    ) -> None:
-        if action_matrix is None:
-            return
-        for index, agent_id in enumerate(agent_ids):
-            self.previous_action_dict[agent_id] = action_matrix[index, :]
-
-    def retrieve_previous_action(self, agent_ids: List[str]) -> np.ndarray:
-        action_matrix = np.zeros((len(agent_ids), self.num_branches), dtype=np.int)
-        for index, agent_id in enumerate(agent_ids):
-            if agent_id in self.previous_action_dict:
-                action_matrix[index, :] = self.previous_action_dict[agent_id]
-        return action_matrix
-
-    def remove_previous_action(self, agent_ids):
-        for agent_id in agent_ids:
-            if agent_id in self.previous_action_dict:
-                self.previous_action_dict.pop(agent_id)
-
     def get_current_step(self):
         """
         Gets current model step.
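
The removed helpers are plain dict-and-NumPy bookkeeping with no TensorFlow dependency, which is presumably why they no longer need to live on the TF-specific class. A standalone usage sketch of the same pattern follows; the module-level state and m_size are assumed for illustration only, and this is not a claim about where the helpers live after this commit:

from typing import Dict, List, Optional
import numpy as np

# Standalone illustration of the bookkeeping removed above. It is pure dict/NumPy code
# with no TensorFlow dependency; names and m_size are assumed for the sketch.
m_size = 8
memory_dict: Dict[str, np.ndarray] = {}

def save_memories(agent_ids: List[str], memory_matrix: Optional[np.ndarray]) -> None:
    if memory_matrix is None:
        return
    for index, agent_id in enumerate(agent_ids):
        memory_dict[agent_id] = memory_matrix[index, :]

def retrieve_memories(agent_ids: List[str]) -> np.ndarray:
    # Unknown agents get zero-filled memories, matching the removed helper's behaviour.
    memory_matrix = np.zeros((len(agent_ids), m_size), dtype=np.float32)
    for index, agent_id in enumerate(agent_ids):
        if agent_id in memory_dict:
            memory_matrix[index, :] = memory_dict[agent_id]
    return memory_matrix

save_memories(["agent-0"], np.ones((1, m_size), dtype=np.float32))
print(retrieve_memories(["agent-0", "agent-1"]))  # row 0 is ones, row 1 stays zeros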

ml-agents/mlagents/trainers/ppo/trainer.py
Lines changed: 7 additions & 4 deletions

@@ -10,6 +10,7 @@
 from mlagents_envs.logging_util import get_logger
 from mlagents_envs.base_env import BehaviorSpec
 from mlagents.trainers.trainer.rl_trainer import RLTrainer
+from mlagents.trainers.policy import Policy
 from mlagents.trainers.policy.tf_policy import TFPolicy
 from mlagents.trainers.ppo.optimizer import PPOOptimizer
 from mlagents.trainers.trajectory import Trajectory
@@ -51,7 +52,7 @@ def __init__(
         )
         self.load = load
         self.seed = seed
-        self.policy: TFPolicy = None  # type: ignore
+        self.policy: Policy = None  # type: ignore
 
     def _process_trajectory(self, trajectory: Trajectory) -> None:
         """
@@ -208,7 +209,7 @@ def create_policy(
         return policy
 
     def add_policy(
-        self, parsed_behavior_id: BehaviorIdentifiers, policy: TFPolicy
+        self, parsed_behavior_id: BehaviorIdentifiers, policy: Policy
     ) -> None:
         """
         Adds policy to trainer.
@@ -224,13 +225,15 @@ def add_policy(
         )
         self.policy = policy
         self.policies[parsed_behavior_id.behavior_id] = policy
-        self.optimizer = PPOOptimizer(self.policy, self.trainer_settings)
+        self.optimizer = PPOOptimizer(
+            cast(TFPolicy, self.policy), self.trainer_settings
+        )
         for _reward_signal in self.optimizer.reward_signals.keys():
             self.collected_rewards[_reward_signal] = defaultdict(lambda: 0)
         # Needed to resume loads properly
         self.step = policy.get_current_step()
 
-    def get_policy(self, name_behavior_id: str) -> TFPolicy:
+    def get_policy(self, name_behavior_id: str) -> Policy:
         """
         Gets policy from trainer associated with name_behavior_id
         :param name_behavior_id: full identifier of policy
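
PPOOptimizer is still TensorFlow-specific, so the trainer, which now stores its policy as the abstract Policy, uses typing.cast to tell the type checker that the object it hands over is in fact a TFPolicy; cast() is a no-op at runtime. A short sketch of the idiom with stand-in classes (not the mlagents definitions):

from typing import cast

# Stand-in classes for illustration only, not the mlagents classes.
class Policy:
    pass

class TFPolicy(Policy):
    sess = "tf-session-handle"  # placeholder for the TensorFlow session attribute

def build_optimizer(policy: Policy) -> str:
    # cast() does nothing at runtime; it only narrows the static type so that
    # attribute access on the TF-specific class type-checks.
    tf_policy = cast(TFPolicy, policy)
    return tf_policy.sess

print(build_optimizer(TFPolicy()))  # prints the placeholder session handle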

ml-agents/mlagents/trainers/sac/trainer.py
Lines changed: 7 additions & 4 deletions

@@ -13,6 +13,7 @@
 from mlagents_envs.timers import timed
 from mlagents_envs.base_env import BehaviorSpec
 from mlagents.trainers.policy.tf_policy import TFPolicy
+from mlagents.trainers.policy import Policy
 from mlagents.trainers.sac.optimizer import SACOptimizer
 from mlagents.trainers.trainer.rl_trainer import RLTrainer
 from mlagents.trainers.trajectory import Trajectory, SplitObservations
@@ -57,7 +58,7 @@ def __init__(
 
         self.load = load
         self.seed = seed
-        self.policy: TFPolicy = None  # type: ignore
+        self.policy: Policy = None  # type: ignore
         self.optimizer: SACOptimizer = None  # type: ignore
         self.hyperparameters: SACSettings = cast(
             SACSettings, trainer_settings.hyperparameters
@@ -312,7 +313,7 @@ def _update_reward_signals(self) -> None:
             self._stats_reporter.add_stat(stat, np.mean(stat_list))
 
     def add_policy(
-        self, parsed_behavior_id: BehaviorIdentifiers, policy: TFPolicy
+        self, parsed_behavior_id: BehaviorIdentifiers, policy: Policy
    ) -> None:
         """
         Adds policy to trainer.
@@ -326,7 +327,9 @@ def add_policy(
         )
         self.policy = policy
         self.policies[parsed_behavior_id.behavior_id] = policy
-        self.optimizer = SACOptimizer(self.policy, self.trainer_settings)
+        self.optimizer = SACOptimizer(
+            cast(TFPolicy, self.policy), self.trainer_settings
+        )
         for _reward_signal in self.optimizer.reward_signals.keys():
             self.collected_rewards[_reward_signal] = defaultdict(lambda: 0)
         # Needed to resume loads properly
@@ -337,7 +340,7 @@ def add_policy(
             max(1, self.step / self.reward_signal_steps_per_update)
         )
 
-    def get_policy(self, name_behavior_id: str) -> TFPolicy:
+    def get_policy(self, name_behavior_id: str) -> Policy:
         """
         Gets policy from trainer associated with name_behavior_id
         :param name_behavior_id: full identifier of policy

ml-agents/mlagents/trainers/trainer/rl_trainer.py
Lines changed: 2 additions & 2 deletions

@@ -12,7 +12,7 @@
 )
 from mlagents_envs.logging_util import get_logger
 from mlagents_envs.timers import timed
-from mlagents.trainers.optimizer.tf_optimizer import TFOptimizer
+from mlagents.trainers.optimizer import Optimizer
 from mlagents.trainers.buffer import AgentBuffer
 from mlagents.trainers.trainer import Trainer
 from mlagents.trainers.components.reward_signals import RewardSignalResult
@@ -56,7 +56,7 @@ def end_episode(self) -> None:
         for agent_id in rewards:
             rewards[agent_id] = 0
 
-    def _update_end_episode_stats(self, agent_id: str, optimizer: TFOptimizer) -> None:
+    def _update_end_episode_stats(self, agent_id: str, optimizer: Optimizer) -> None:
         for name, rewards in self.collected_rewards.items():
             if name == "environment":
                 self.stats_reporter.add_stat(
