from typing import Dict, Optional, Tuple, List
import torch
import numpy as np

from mlagents.trainers.buffer import AgentBuffer
from mlagents.trainers.trajectory import SplitObservations
from mlagents.trainers.torch.components.bc.module import BCModule
from mlagents.trainers.torch.components.reward_providers import create_reward_provider

from mlagents.trainers.policy.torch_policy import TorchPolicy
from mlagents.trainers.optimizer import Optimizer
from mlagents.trainers.settings import TrainerSettings
from mlagents.trainers.torch.utils import ModelUtils


class TorchOptimizer(Optimizer):  # pylint: disable=W0223
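    """
    Base class for the PyTorch optimizers. Holds the policy, the reward signal
    providers, and an optional behavioral cloning module; trainer-specific
    optimizers subclass this and implement update().
    """
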
    def __init__(self, policy: TorchPolicy, trainer_settings: TrainerSettings):
        super().__init__()
        self.policy = policy
        self.trainer_settings = trainer_settings
        self.update_dict: Dict[str, torch.Tensor] = {}
        self.value_heads: Dict[str, torch.Tensor] = {}
        self.memory_in: Optional[torch.Tensor] = None
        self.memory_out: Optional[torch.Tensor] = None
        self.m_size: int = 0
        self.global_step = torch.tensor(0)
        self.bc_module: Optional[BCModule] = None
        self.create_reward_signals(trainer_settings.reward_signals)
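        # Create a behavioral cloning module only if BC is enabled in the trainer settings.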
        if trainer_settings.behavioral_cloning is not None:
            self.bc_module = BCModule(
                self.policy,
                trainer_settings.behavioral_cloning,
                policy_learning_rate=trainer_settings.hyperparameters.learning_rate,
                default_batch_size=trainer_settings.hyperparameters.batch_size,
                default_num_epoch=3,
            )

    def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
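        """
        Update the policy from a batch of experiences. Implemented by the
        trainer-specific optimizers that subclass this class.
        :param batch: AgentBuffer of experiences to learn from.
        :param num_sequences: Number of sequences in the batch (relevant for recurrent policies).
        :return: Dict of statistics from the update.
        """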
        pass

    def create_reward_signals(self, reward_signal_configs) -> None:
        """
        Create the reward signal providers for this optimizer.
        :param reward_signal_configs: Dict mapping each reward signal type to its settings.
        """
        for reward_signal, settings in reward_signal_configs.items():
            # Name reward signals by string in case we have duplicates later
            self.reward_signals[reward_signal.value] = create_reward_provider(
                reward_signal, self.policy.behavior_spec, settings
            )

    def get_trajectory_value_estimates(
        self, batch: AgentBuffer, next_obs: List[np.ndarray], done: bool
    ) -> Tuple[Dict[str, np.ndarray], Dict[str, float]]:
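        """
        Compute value estimates for every experience in a trajectory, plus a bootstrap
        value estimate for the observation that follows it.
        :param batch: AgentBuffer containing the trajectory.
        :param next_obs: Observations following the last experience in the trajectory.
        :param done: Whether the trajectory ended in a terminal state.
        :return: Tuple of (value estimates per reward signal, bootstrap value per reward signal).
        """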
        vector_obs = [ModelUtils.list_to_tensor(batch["vector_obs"])]
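        # Collect visual observations from the buffer, one entry per visual encoder.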
        if self.policy.use_vis_obs:
            visual_obs = []
            for idx, _ in enumerate(
                self.policy.actor_critic.network_body.visual_encoders
            ):
                visual_ob = ModelUtils.list_to_tensor(batch["visual_obs%d" % idx])
                visual_obs.append(visual_ob)
        else:
            visual_obs = []

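        # Start from a zeroed memory (empty when the policy is not recurrent); the
        # critic unrolls it over the full trajectory below.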
        memory = torch.zeros([1, 1, self.policy.m_size])

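        # Split the next observation into vector/visual parts and add a batch dimension.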
        vec_vis_obs = SplitObservations.from_observations(next_obs)
        next_vec_obs = [
            ModelUtils.list_to_tensor(vec_vis_obs.vector_observations).unsqueeze(0)
        ]
        next_vis_obs = [
            ModelUtils.list_to_tensor(_vis_ob).unsqueeze(0)
            for _vis_ob in vec_vis_obs.visual_observations
        ]

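        # One critic pass over the whole trajectory, then a second pass on the next
        # observation (seeded with the trajectory's final memory) to get the bootstrap value.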
        value_estimates, next_memory = self.policy.actor_critic.critic_pass(
            vector_obs, visual_obs, memory, sequence_length=batch.num_experiences
        )

        next_value_estimate, _ = self.policy.actor_critic.critic_pass(
            next_vec_obs, next_vis_obs, next_memory, sequence_length=1
        )

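        # Detach and convert the estimates to numpy for consumption by the trainer.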
        for name, estimate in value_estimates.items():
            value_estimates[name] = estimate.detach().cpu().numpy()
            next_value_estimate[name] = next_value_estimate[name].detach().cpu().numpy()

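        # If the trajectory ended in a terminal state, zero the bootstrap value for any
        # reward signal that does not ignore episode termination.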
        if done:
            for k in next_value_estimate:
                if not self.reward_signals[k].ignore_done:
                    next_value_estimate[k] = 0.0

        return value_estimates, next_value_estimate