[add-fire] Add learning rate and beta/epsilon decay to PyTorch #4318

Merged · 8 commits · Aug 7, 2020
49 changes: 38 additions & 11 deletions ml-agents/mlagents/trainers/ppo/optimizer_torch.py
@@ -26,6 +26,24 @@ def __init__(self, policy: TorchPolicy, trainer_settings: TrainerSettings):
self.hyperparameters: PPOSettings = cast(
PPOSettings, trainer_settings.hyperparameters
)
self.decay_learning_rate = ModelUtils.DecayedValue(
self.hyperparameters.learning_rate_schedule,
self.hyperparameters.learning_rate,
1e-10,
self.trainer_settings.max_steps,
)
self.decay_epsilon = ModelUtils.DecayedValue(
self.hyperparameters.learning_rate_schedule,
self.hyperparameters.epsilon,
0.1,
self.trainer_settings.max_steps,
)
self.decay_beta = ModelUtils.DecayedValue(
self.hyperparameters.learning_rate_schedule,
self.hyperparameters.beta,
1e-5,
self.trainer_settings.max_steps,
)

self.optimizer = torch.optim.Adam(
params, lr=self.trainer_settings.hyperparameters.learning_rate
@@ -37,22 +55,25 @@ def __init__(self, policy: TorchPolicy, trainer_settings: TrainerSettings):

self.stream_names = list(self.reward_signals.keys())

def ppo_value_loss(self, values, old_values, returns):
def ppo_value_loss(
self,
values: Dict[str, torch.Tensor],
old_values: Dict[str, torch.Tensor],
returns: Dict[str, torch.Tensor],
epsilon: float,
) -> torch.Tensor:
"""
Creates training-specific Tensorflow ops for PPO models.
:param returns:
:param old_values:
:param values:
"""

decay_epsilon = self.hyperparameters.epsilon

value_losses = []
for name, head in values.items():
old_val_tensor = old_values[name]
returns_tensor = returns[name]
clipped_value_estimate = old_val_tensor + torch.clamp(
head - old_val_tensor, -decay_epsilon, decay_epsilon
head - old_val_tensor, -1 * epsilon, epsilon
)
v_opt_a = (returns_tensor - head) ** 2
v_opt_b = (returns_tensor - clipped_value_estimate) ** 2
@@ -89,6 +110,10 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
:param num_sequences: Number of sequences to process.
:return: Results of update.
"""
# Get decayed parameters
decay_lr = self.decay_learning_rate.get_value(self.policy.get_current_step())
decay_eps = self.decay_epsilon.get_value(self.policy.get_current_step())
decay_bet = self.decay_beta.get_value(self.policy.get_current_step())
returns = {}
old_values = {}
for name in self.reward_signals:
@@ -128,25 +153,27 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
memories=memories,
seq_len=self.policy.sequence_length,
)
value_loss = self.ppo_value_loss(values, old_values, returns)
value_loss = self.ppo_value_loss(values, old_values, returns, decay_eps)
policy_loss = self.ppo_policy_loss(
ModelUtils.list_to_tensor(batch["advantages"]),
log_probs,
ModelUtils.list_to_tensor(batch["action_probs"]),
ModelUtils.list_to_tensor(batch["masks"], dtype=torch.int32),
)
loss = (
policy_loss
+ 0.5 * value_loss
- self.hyperparameters.beta * torch.mean(entropy)
)
loss = policy_loss + 0.5 * value_loss - decay_bet * torch.mean(entropy)

# Set optimizer learning rate
ModelUtils.update_learning_rate(self.optimizer, decay_lr)
self.optimizer.zero_grad()
loss.backward()

self.optimizer.step()
update_stats = {
"Losses/Policy Loss": abs(policy_loss.detach().cpu().numpy()),
"Losses/Value Loss": value_loss.detach().cpu().numpy(),
"Policy/Learning Rate": decay_lr,
"Policy/Epsilon": decay_eps,
"Policy/Beta": decay_bet,
}

return update_stats
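
Note: as a quick illustration of the value clipping that ppo_value_loss performs per reward stream, here is a minimal standalone sketch in plain PyTorch (names and sample numbers are illustrative, and the real method also applies loss masks via ModelUtils.masked_mean):

```python
import torch

def clipped_value_loss(values, old_values, returns, epsilon):
    # Keep the new value estimate within +/- epsilon of the old estimate.
    clipped = old_values + torch.clamp(values - old_values, -epsilon, epsilon)
    unclipped_loss = (returns - values) ** 2
    clipped_loss = (returns - clipped) ** 2
    # Use the pessimistic (larger) of the two squared errors.
    return torch.mean(torch.max(unclipped_loss, clipped_loss))

values = torch.tensor([1.5, 0.3])
old_values = torch.tensor([1.0, 0.5])
returns = torch.tensor([1.2, 0.4])
print(clipped_value_loss(values, old_values, returns, epsilon=0.2))
```

As decay_eps shrinks toward its floor of 0.1 over training, the clipping range tightens and late updates move the value head more conservatively.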
29 changes: 20 additions & 9 deletions ml-agents/mlagents/trainers/sac/optimizer_torch.py
@@ -65,18 +65,12 @@ def forward(
def __init__(self, policy: TorchPolicy, trainer_params: TrainerSettings):
super().__init__(policy, trainer_params)
hyperparameters: SACSettings = cast(SACSettings, trainer_params.hyperparameters)
lr = hyperparameters.learning_rate
# lr_schedule = hyperparameters.learning_rate_schedule
# max_step = trainer_params.max_steps
self.tau = hyperparameters.tau
self.init_entcoef = hyperparameters.init_entcoef

self.policy = policy
self.act_size = policy.act_size
policy_network_settings = policy.network_settings
# h_size = policy_network_settings.hidden_units
# num_layers = policy_network_settings.num_layers
# vis_encode_type = policy_network_settings.vis_encode_type

self.tau = hyperparameters.tau
self.burn_in_ratio = 0.0
@@ -137,9 +131,21 @@ def __init__(self, policy: TorchPolicy, trainer_params: TrainerSettings):
for param in policy_params:
logger.debug(param.shape)

self.policy_optimizer = torch.optim.Adam(policy_params, lr=lr)
self.value_optimizer = torch.optim.Adam(value_params, lr=lr)
self.entropy_optimizer = torch.optim.Adam([self._log_ent_coef], lr=lr)
self.decay_learning_rate = ModelUtils.DecayedValue(
hyperparameters.learning_rate_schedule,
hyperparameters.learning_rate,
1e-10,
self.trainer_settings.max_steps,
)
self.policy_optimizer = torch.optim.Adam(
policy_params, lr=hyperparameters.learning_rate
)
self.value_optimizer = torch.optim.Adam(
value_params, lr=hyperparameters.learning_rate
)
self.entropy_optimizer = torch.optim.Adam(
[self._log_ent_coef], lr=hyperparameters.learning_rate
)

def sac_q_loss(
self,
@@ -436,14 +442,18 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:

total_value_loss = q1_loss + q2_loss + value_loss

decay_lr = self.decay_learning_rate.get_value(self.policy.get_current_step())
ModelUtils.update_learning_rate(self.policy_optimizer, decay_lr)
self.policy_optimizer.zero_grad()
policy_loss.backward()
self.policy_optimizer.step()

ModelUtils.update_learning_rate(self.value_optimizer, decay_lr)
self.value_optimizer.zero_grad()
total_value_loss.backward()
self.value_optimizer.step()

ModelUtils.update_learning_rate(self.entropy_optimizer, decay_lr)
self.entropy_optimizer.zero_grad()
entropy_loss.backward()
self.entropy_optimizer.step()
@@ -459,6 +469,7 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
.detach()
.cpu()
.numpy(),
"Policy/Learning Rate": decay_lr,
}

return update_stats
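
Note: the same decayed learning rate is pushed into all three SAC optimizers before stepping. A minimal sketch of that pattern in plain PyTorch (illustrative parameters, not the trainer's actual networks):

```python
import torch

def set_lr(optimizer: torch.optim.Optimizer, lr: float) -> None:
    # Same idea as ModelUtils.update_learning_rate: overwrite lr on every param group.
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr

params = [torch.nn.Parameter(torch.zeros(1)) for _ in range(3)]
optimizers = [torch.optim.Adam([p], lr=3.0e-4) for p in params]

decayed_lr = 1.5e-4  # e.g. roughly halfway through a linear schedule
for opt in optimizers:
    set_lr(opt, decayed_lr)
```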
34 changes: 33 additions & 1 deletion ml-agents/mlagents/trainers/tests/torch/test_utils.py
@@ -2,7 +2,7 @@
import torch
import numpy as np

from mlagents.trainers.settings import EncoderType
from mlagents.trainers.settings import EncoderType, ScheduleType
from mlagents.trainers.torch.utils import ModelUtils
from mlagents.trainers.exception import UnityTrainerException
from mlagents.trainers.torch.encoders import (
@@ -79,6 +79,38 @@ def test_create_encoders(
assert isinstance(enc, ModelUtils.get_encoder_for_type(encoder_type))


def test_decayed_value():
test_steps = [0, 4, 9]
# Test constant decay
param = ModelUtils.DecayedValue(ScheduleType.CONSTANT, 1.0, 0.2, test_steps[-1])
for _step in test_steps:
_param = param.get_value(_step)
assert _param == 1.0

test_results = [1.0, 0.6444, 0.2]
# Test linear decay
param = ModelUtils.DecayedValue(ScheduleType.LINEAR, 1.0, 0.2, test_steps[-1])
for _step, _result in zip(test_steps, test_results):
_param = param.get_value(_step)
assert _param == pytest.approx(_result, abs=0.01)

# Test invalid
with pytest.raises(UnityTrainerException):
ModelUtils.DecayedValue(
"SomeOtherSchedule", 1.0, 0.2, test_steps[-1]
).get_value(0)


def test_polynomial_decay():
test_steps = [0, 4, 9]
test_results = [1.0, 0.7, 0.2]
for _step, _result in zip(test_steps, test_results):
decayed = ModelUtils.polynomial_decay(
1.0, 0.2, test_steps[-1], _step, power=0.8
)
assert decayed == pytest.approx(_result, abs=0.01)


def test_list_to_tensor():
# Test converting pure list
unconverted_list = [[1, 2], [1, 3], [1, 4]]
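
The expected values in test_decayed_value and test_polynomial_decay above follow directly from the decay formula value = (initial - min) * (1 - step / max_step) ** power + min. A quick standalone check (the formula is reproduced here rather than imported from mlagents):

```python
def poly(initial, minimum, max_step, step, power=1.0):
    # Same formula as ModelUtils.polynomial_decay, reproduced for a quick check.
    step = min(step, max_step)
    return (initial - minimum) * (1 - step / max_step) ** power + minimum

print(round(poly(1.0, 0.2, 9, 4), 4))             # 0.6444 -> linear case
print(round(poly(1.0, 0.2, 9, 4, power=0.8), 4))  # ~0.70  -> power 0.8 case
```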
73 changes: 72 additions & 1 deletion ml-agents/mlagents/trainers/torch/utils.py
@@ -10,7 +10,7 @@
VectorEncoder,
VectorAndUnnormalizedInputEncoder,
)
from mlagents.trainers.settings import EncoderType
from mlagents.trainers.settings import EncoderType, ScheduleType
from mlagents.trainers.exception import UnityTrainerException
from mlagents.trainers.torch.distributions import DistInstance, DiscreteDistInstance

@@ -29,6 +29,77 @@ def swish(input_activation: torch.Tensor) -> torch.Tensor:
"""Swish activation function. For more info: https://arxiv.org/abs/1710.05941"""
return torch.mul(input_activation, torch.sigmoid(input_activation))

@staticmethod
def update_learning_rate(optim: torch.optim.Optimizer, lr: float) -> None:
"""
Apply a learning rate to a torch optimizer.
:param optim: Optimizer
:param lr: Learning rate
"""
for param_group in optim.param_groups:
param_group["lr"] = lr

class DecayedValue:
def __init__(
self,
schedule: ScheduleType,
initial_value: float,
min_value: float,
max_step: int,
):
"""
Object that represnets value of a parameter that should be decayed, assuming it is a function of
global_step.
:param schedule: Type of learning rate schedule.
:param initial_value: Initial value before decay.
:param min_value: Decay value to this value by max_step.
:param max_step: The final step count where the return value should equal min_value.
:param global_step: The current step count.
:return: The value.
"""
self.schedule = schedule
self.initial_value = initial_value
self.min_value = min_value
self.max_step = max_step

def get_value(self, global_step: int) -> float:
"""
Get the value at a given global step.
:param global_step: Step count.
:returns: Decayed value at this global step.
"""
if self.schedule == ScheduleType.CONSTANT:
return self.initial_value
elif self.schedule == ScheduleType.LINEAR:
return ModelUtils.polynomial_decay(
self.initial_value, self.min_value, self.max_step, global_step
)
else:
raise UnityTrainerException(f"The schedule {self.schedule} is invalid.")

@staticmethod
def polynomial_decay(
initial_value: float,
min_value: float,
max_step: int,
global_step: int,
power: float = 1.0,
Contributor:

There's really no way to configure this properly. Should it even be an arg?

Contributor Author:

If at some point we want a ScheduleType.EXPONENTIAL, we do

) -> float:
"""
Get a decayed value based on a polynomial schedule, with respect to the current global step.
:param initial_value: Initial value before decay.
:param min_value: Decay value to this value by max_step.
:param max_step: The final step count where the return value should equal min_value.
:param global_step: The current step count.
:param power: Power of polynomial decay. 1.0 (default) is a linear decay.
:return: The current decayed value.
"""
global_step = min(global_step, max_step)
decayed_value = (initial_value - min_value) * (
1 - float(global_step) / max_step
) ** (power) + min_value
return decayed_value

@staticmethod
def get_encoder_for_type(encoder_type: EncoderType) -> nn.Module:
ENCODER_FUNCTION_BY_TYPE = {
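
Finally, a short usage sketch of the new DecayedValue helper (assuming the mlagents package with this change installed; the step counts are illustrative):

```python
from mlagents.trainers.settings import ScheduleType
from mlagents.trainers.torch.utils import ModelUtils

# Linear decay of a learning rate from 3e-4 down to (effectively) 0 over 500k steps.
decay_lr = ModelUtils.DecayedValue(ScheduleType.LINEAR, 3.0e-4, 1e-10, 500000)

print(decay_lr.get_value(0))        # 3e-4 at the start of training
print(decay_lr.get_value(250000))   # about half the initial rate
print(decay_lr.get_value(500000))   # 1e-10 once max_step is reached
```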