diff --git a/ml-agents/mlagents/trainers/ppo/optimizer_torch.py b/ml-agents/mlagents/trainers/ppo/optimizer_torch.py
index 0c92f35d8b..ec09c74af5 100644
--- a/ml-agents/mlagents/trainers/ppo/optimizer_torch.py
+++ b/ml-agents/mlagents/trainers/ppo/optimizer_torch.py
@@ -26,6 +26,24 @@ def __init__(self, policy: TorchPolicy, trainer_settings: TrainerSettings):
         self.hyperparameters: PPOSettings = cast(
             PPOSettings, trainer_settings.hyperparameters
         )
+        self.decay_learning_rate = ModelUtils.DecayedValue(
+            self.hyperparameters.learning_rate_schedule,
+            self.hyperparameters.learning_rate,
+            1e-10,
+            self.trainer_settings.max_steps,
+        )
+        self.decay_epsilon = ModelUtils.DecayedValue(
+            self.hyperparameters.learning_rate_schedule,
+            self.hyperparameters.epsilon,
+            0.1,
+            self.trainer_settings.max_steps,
+        )
+        self.decay_beta = ModelUtils.DecayedValue(
+            self.hyperparameters.learning_rate_schedule,
+            self.hyperparameters.beta,
+            1e-5,
+            self.trainer_settings.max_steps,
+        )
         self.optimizer = torch.optim.Adam(
             params, lr=self.trainer_settings.hyperparameters.learning_rate
@@ -37,22 +55,25 @@ def __init__(self, policy: TorchPolicy, trainer_settings: TrainerSettings):
 
         self.stream_names = list(self.reward_signals.keys())
 
-    def ppo_value_loss(self, values, old_values, returns):
+    def ppo_value_loss(
+        self,
+        values: Dict[str, torch.Tensor],
+        old_values: Dict[str, torch.Tensor],
+        returns: Dict[str, torch.Tensor],
+        epsilon: float,
+    ) -> torch.Tensor:
         """
         Creates training-specific Tensorflow ops for PPO models.
         :param returns:
         :param old_values:
         :param values:
         """
-
-        decay_epsilon = self.hyperparameters.epsilon
-
         value_losses = []
         for name, head in values.items():
             old_val_tensor = old_values[name]
             returns_tensor = returns[name]
             clipped_value_estimate = old_val_tensor + torch.clamp(
-                head - old_val_tensor, -decay_epsilon, decay_epsilon
+                head - old_val_tensor, -1 * epsilon, epsilon
             )
             v_opt_a = (returns_tensor - head) ** 2
             v_opt_b = (returns_tensor - clipped_value_estimate) ** 2
@@ -89,6 +110,10 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
         :param num_sequences: Number of sequences to process.
         :return: Results of update.
""" + # Get decayed parameters + decay_lr = self.decay_learning_rate.get_value(self.policy.get_current_step()) + decay_eps = self.decay_epsilon.get_value(self.policy.get_current_step()) + decay_bet = self.decay_beta.get_value(self.policy.get_current_step()) returns = {} old_values = {} for name in self.reward_signals: @@ -128,18 +153,17 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]: memories=memories, seq_len=self.policy.sequence_length, ) - value_loss = self.ppo_value_loss(values, old_values, returns) + value_loss = self.ppo_value_loss(values, old_values, returns, decay_eps) policy_loss = self.ppo_policy_loss( ModelUtils.list_to_tensor(batch["advantages"]), log_probs, ModelUtils.list_to_tensor(batch["action_probs"]), ModelUtils.list_to_tensor(batch["masks"], dtype=torch.int32), ) - loss = ( - policy_loss - + 0.5 * value_loss - - self.hyperparameters.beta * torch.mean(entropy) - ) + loss = policy_loss + 0.5 * value_loss - decay_bet * torch.mean(entropy) + + # Set optimizer learning rate + ModelUtils.update_learning_rate(self.optimizer, decay_lr) self.optimizer.zero_grad() loss.backward() @@ -147,6 +171,9 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]: update_stats = { "Losses/Policy Loss": abs(policy_loss.detach().cpu().numpy()), "Losses/Value Loss": value_loss.detach().cpu().numpy(), + "Policy/Learning Rate": decay_lr, + "Policy/Epsilon": decay_eps, + "Policy/Beta": decay_bet, } return update_stats diff --git a/ml-agents/mlagents/trainers/sac/optimizer_torch.py b/ml-agents/mlagents/trainers/sac/optimizer_torch.py index b5653a9f65..85dc2e218d 100644 --- a/ml-agents/mlagents/trainers/sac/optimizer_torch.py +++ b/ml-agents/mlagents/trainers/sac/optimizer_torch.py @@ -65,18 +65,12 @@ def forward( def __init__(self, policy: TorchPolicy, trainer_params: TrainerSettings): super().__init__(policy, trainer_params) hyperparameters: SACSettings = cast(SACSettings, trainer_params.hyperparameters) - lr = hyperparameters.learning_rate - # lr_schedule = hyperparameters.learning_rate_schedule - # max_step = trainer_params.max_steps self.tau = hyperparameters.tau self.init_entcoef = hyperparameters.init_entcoef self.policy = policy self.act_size = policy.act_size policy_network_settings = policy.network_settings - # h_size = policy_network_settings.hidden_units - # num_layers = policy_network_settings.num_layers - # vis_encode_type = policy_network_settings.vis_encode_type self.tau = hyperparameters.tau self.burn_in_ratio = 0.0 @@ -137,9 +131,21 @@ def __init__(self, policy: TorchPolicy, trainer_params: TrainerSettings): for param in policy_params: logger.debug(param.shape) - self.policy_optimizer = torch.optim.Adam(policy_params, lr=lr) - self.value_optimizer = torch.optim.Adam(value_params, lr=lr) - self.entropy_optimizer = torch.optim.Adam([self._log_ent_coef], lr=lr) + self.decay_learning_rate = ModelUtils.DecayedValue( + hyperparameters.learning_rate_schedule, + hyperparameters.learning_rate, + 1e-10, + self.trainer_settings.max_steps, + ) + self.policy_optimizer = torch.optim.Adam( + policy_params, lr=hyperparameters.learning_rate + ) + self.value_optimizer = torch.optim.Adam( + value_params, lr=hyperparameters.learning_rate + ) + self.entropy_optimizer = torch.optim.Adam( + [self._log_ent_coef], lr=hyperparameters.learning_rate + ) def sac_q_loss( self, @@ -436,14 +442,18 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]: total_value_loss = q1_loss + q2_loss + value_loss + decay_lr = 
+        ModelUtils.update_learning_rate(self.policy_optimizer, decay_lr)
         self.policy_optimizer.zero_grad()
         policy_loss.backward()
         self.policy_optimizer.step()
 
+        ModelUtils.update_learning_rate(self.value_optimizer, decay_lr)
         self.value_optimizer.zero_grad()
         total_value_loss.backward()
         self.value_optimizer.step()
 
+        ModelUtils.update_learning_rate(self.entropy_optimizer, decay_lr)
         self.entropy_optimizer.zero_grad()
         entropy_loss.backward()
         self.entropy_optimizer.step()
@@ -459,6 +469,7 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
             .detach()
             .cpu()
             .numpy(),
+            "Policy/Learning Rate": decay_lr,
         }
 
         return update_stats
diff --git a/ml-agents/mlagents/trainers/tests/torch/test_utils.py b/ml-agents/mlagents/trainers/tests/torch/test_utils.py
index b9a58c4617..f6f286c84a 100644
--- a/ml-agents/mlagents/trainers/tests/torch/test_utils.py
+++ b/ml-agents/mlagents/trainers/tests/torch/test_utils.py
@@ -2,7 +2,7 @@
 import torch
 import numpy as np
 
-from mlagents.trainers.settings import EncoderType
+from mlagents.trainers.settings import EncoderType, ScheduleType
 from mlagents.trainers.torch.utils import ModelUtils
 from mlagents.trainers.exception import UnityTrainerException
 from mlagents.trainers.torch.encoders import (
@@ -79,6 +79,38 @@ def test_create_encoders(
     assert isinstance(enc, ModelUtils.get_encoder_for_type(encoder_type))
 
 
+def test_decayed_value():
+    test_steps = [0, 4, 9]
+    # Test constant decay
+    param = ModelUtils.DecayedValue(ScheduleType.CONSTANT, 1.0, 0.2, test_steps[-1])
+    for _step in test_steps:
+        _param = param.get_value(_step)
+        assert _param == 1.0
+
+    test_results = [1.0, 0.6444, 0.2]
+    # Test linear decay
+    param = ModelUtils.DecayedValue(ScheduleType.LINEAR, 1.0, 0.2, test_steps[-1])
+    for _step, _result in zip(test_steps, test_results):
+        _param = param.get_value(_step)
+        assert _param == pytest.approx(_result, abs=0.01)
+
+    # Test invalid
+    with pytest.raises(UnityTrainerException):
+        ModelUtils.DecayedValue(
+            "SomeOtherSchedule", 1.0, 0.2, test_steps[-1]
+        ).get_value(0)
+
+
+def test_polynomial_decay():
+    test_steps = [0, 4, 9]
+    test_results = [1.0, 0.7, 0.2]
+    for _step, _result in zip(test_steps, test_results):
+        decayed = ModelUtils.polynomial_decay(
+            1.0, 0.2, test_steps[-1], _step, power=0.8
+        )
+        assert decayed == pytest.approx(_result, abs=0.01)
+
+
 def test_list_to_tensor():
     # Test converting pure list
     unconverted_list = [[1, 2], [1, 3], [1, 4]]
diff --git a/ml-agents/mlagents/trainers/torch/utils.py b/ml-agents/mlagents/trainers/torch/utils.py
index 5d815cea0b..0644defec9 100644
--- a/ml-agents/mlagents/trainers/torch/utils.py
+++ b/ml-agents/mlagents/trainers/torch/utils.py
@@ -10,7 +10,7 @@
     VectorEncoder,
     VectorAndUnnormalizedInputEncoder,
 )
-from mlagents.trainers.settings import EncoderType
+from mlagents.trainers.settings import EncoderType, ScheduleType
 from mlagents.trainers.exception import UnityTrainerException
 from mlagents.trainers.torch.distributions import DistInstance, DiscreteDistInstance
 
@@ -29,6 +29,77 @@ def swish(input_activation: torch.Tensor) -> torch.Tensor:
         """Swish activation function. For more info: https://arxiv.org/abs/1710.05941"""
         return torch.mul(input_activation, torch.sigmoid(input_activation))
 
+    @staticmethod
+    def update_learning_rate(optim: torch.optim.Optimizer, lr: float) -> None:
+        """
+        Apply a learning rate to a torch optimizer.
+        :param optim: Optimizer
+        :param lr: Learning rate
+        """
+        for param_group in optim.param_groups:
+            param_group["lr"] = lr
+
+    class DecayedValue:
+        def __init__(
+            self,
+            schedule: ScheduleType,
+            initial_value: float,
+            min_value: float,
+            max_step: int,
+        ):
+            """
+            Object that represents the value of a parameter that should be decayed, assuming it is a function of
+            global_step.
+            :param schedule: Type of learning rate schedule.
+            :param initial_value: Initial value before decay.
+            :param min_value: Decay value to this value by max_step.
+            :param max_step: The final step count where the return value should equal min_value.
+            :param global_step: The current step count.
+            :return: The value.
+            """
+            self.schedule = schedule
+            self.initial_value = initial_value
+            self.min_value = min_value
+            self.max_step = max_step
+
+        def get_value(self, global_step: int) -> float:
+            """
+            Get the value at a given global step.
+            :param global_step: Step count.
+            :returns: Decayed value at this global step.
+            """
+            if self.schedule == ScheduleType.CONSTANT:
+                return self.initial_value
+            elif self.schedule == ScheduleType.LINEAR:
+                return ModelUtils.polynomial_decay(
+                    self.initial_value, self.min_value, self.max_step, global_step
+                )
+            else:
+                raise UnityTrainerException(f"The schedule {self.schedule} is invalid.")
+
+    @staticmethod
+    def polynomial_decay(
+        initial_value: float,
+        min_value: float,
+        max_step: int,
+        global_step: int,
+        power: float = 1.0,
+    ) -> float:
+        """
+        Get a decayed value based on a polynomial schedule, with respect to the current global step.
+        :param initial_value: Initial value before decay.
+        :param min_value: Decay value to this value by max_step.
+        :param max_step: The final step count where the return value should equal min_value.
+        :param global_step: The current step count.
+        :param power: Power of polynomial decay. 1.0 (default) is a linear decay.
+        :return: The current decayed value.
+        """
+        global_step = min(global_step, max_step)
+        decayed_value = (initial_value - min_value) * (
+            1 - float(global_step) / max_step
+        ) ** (power) + min_value
+        return decayed_value
+
     @staticmethod
     def get_encoder_for_type(encoder_type: EncoderType) -> nn.Module:
         ENCODER_FUNCTION_BY_TYPE = {
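
For reviewers who want to sanity-check the schedule outside the trainers, below is a minimal standalone sketch (not part of this patch) that re-implements the polynomial_decay math from the torch/utils.py hunk above and checks it against the values used in test_polynomial_decay. The function name mirrors the one in the diff, but the snippet is self-contained and does not import ml-agents.

# Standalone sketch: mirrors ModelUtils.polynomial_decay from the hunk above.
def polynomial_decay(
    initial_value: float,
    min_value: float,
    max_step: int,
    global_step: int,
    power: float = 1.0,
) -> float:
    # Clamp the step so the value never decays past min_value.
    global_step = min(global_step, max_step)
    return (initial_value - min_value) * (
        1 - float(global_step) / max_step
    ) ** power + min_value


if __name__ == "__main__":
    # Same values as test_polynomial_decay: 1.0 -> 0.2 over 9 steps, power=0.8.
    for step, expected in zip([0, 4, 9], [1.0, 0.7, 0.2]):
        value = polynomial_decay(1.0, 0.2, 9, step, power=0.8)
        assert abs(value - expected) < 0.01
        print(f"step {step}: {value:.4f}")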