53 changes: 39 additions & 14 deletions rsl_rl/modules/actor_critic.py
@@ -27,6 +27,7 @@ def __init__(
         activation="elu",
         init_noise_std=1.0,
         noise_std_type: str = "scalar",
+        state_dependent_std=False,
         **kwargs,
     ):
         if kwargs:
@@ -47,8 +48,12 @@ def __init__(
             assert len(obs[obs_group].shape) == 2, "The ActorCritic module only supports 1D observations."
             num_critic_obs += obs[obs_group].shape[-1]

+        self.state_dependent_std = state_dependent_std
         # actor
-        self.actor = MLP(num_actor_obs, num_actions, actor_hidden_dims, activation)
+        if self.state_dependent_std:
+            self.actor = MLP(num_actor_obs, [2, num_actions], actor_hidden_dims, activation)
+        else:
+            self.actor = MLP(num_actor_obs, num_actions, actor_hidden_dims, activation)
         # actor observation normalization
         self.actor_obs_normalization = actor_obs_normalization
         if actor_obs_normalization:
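
For readers following along: with state_dependent_std enabled, the actor's final layer emits both heads in one tensor of shape (batch, 2, num_actions). A minimal standalone sketch of that shape contract, mirroring the Linear-plus-Unflatten tail of MLP in rsl_rl/networks/mlp.py shown further down (all sizes here are made-up placeholders):

    import torch
    import torch.nn as nn

    num_actor_obs, num_actions, batch = 48, 12, 4096  # hypothetical sizes

    # Equivalent tail of MLP(num_actor_obs, [2, num_actions], ...):
    # a Linear producing 2 * num_actions features, then an Unflatten
    # that splits them into a mean row and a std row.
    head = nn.Sequential(
        nn.Linear(num_actor_obs, 2 * num_actions),
        nn.Unflatten(-1, (2, num_actions)),
    )
    out = head(torch.randn(batch, num_actor_obs))
    print(out.shape)  # torch.Size([4096, 2, 12])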
@@ -69,12 +74,21 @@ def __init__(

         # Action noise
         self.noise_std_type = noise_std_type
-        if self.noise_std_type == "scalar":
-            self.std = nn.Parameter(init_noise_std * torch.ones(num_actions))
-        elif self.noise_std_type == "log":
-            self.log_std = nn.Parameter(torch.log(init_noise_std * torch.ones(num_actions)))
-        else:
-            raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'")
+        if self.state_dependent_std:
+            torch.nn.init.zeros_(self.actor[-2].weight[num_actions:])
+            if self.noise_std_type == "scalar":
+                torch.nn.init.constant_(self.actor[-2].bias[num_actions:], init_noise_std)
+            elif self.noise_std_type == "log":
+                torch.nn.init.constant_(self.actor[-2].bias[num_actions:], torch.log(torch.tensor(init_noise_std + 1e-7)))
+            else:
+                raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'")
+        else:
+            if self.noise_std_type == "scalar":
+                self.std = nn.Parameter(init_noise_std * torch.ones(num_actions))
+            elif self.noise_std_type == "log":
+                self.log_std = nn.Parameter(torch.log(init_noise_std * torch.ones(num_actions)))
+            else:
+                raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'")

         # Action distribution (populated in update_distribution)
         self.distribution = None

[Review thread on the `elif self.noise_std_type == "log":` branch]
Member: Is the dangling else for a reason? Or should it also throw a value error?
Author: Great catch. Fixed.
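
A quick sanity check of what this initialization buys (a standalone sketch, not part of the diff; the sizes are invented): zeroing the std head's input weights and setting its bias to init_noise_std makes the predicted standard deviation start out state-independent and equal to init_noise_std, so training begins with the same exploration noise as the scalar parameterization.

    import torch
    import torch.nn as nn

    hidden_dim, num_actions, init_noise_std = 64, 8, 1.0  # hypothetical sizes

    # Stand-in for the actor's last Linear layer (self.actor[-2] in the PR);
    # output features [num_actions:] feed the std head.
    last = nn.Linear(hidden_dim, 2 * num_actions)
    nn.init.zeros_(last.weight[num_actions:])
    nn.init.constant_(last.bias[num_actions:], init_noise_std)

    h = torch.randn(5, hidden_dim)  # arbitrary hidden activations
    std = last(h)[:, num_actions:]
    print(torch.allclose(std, torch.full_like(std, init_noise_std)))  # True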
@@ -100,15 +114,26 @@ def entropy(self):
         return self.distribution.entropy().sum(dim=-1)

     def update_distribution(self, obs):
-        # compute mean
-        mean = self.actor(obs)
-        # compute standard deviation
-        if self.noise_std_type == "scalar":
-            std = self.std.expand_as(mean)
-        elif self.noise_std_type == "log":
-            std = torch.exp(self.log_std).expand_as(mean)
-        else:
-            raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'")
+        if self.state_dependent_std:
+            # compute mean and standard deviation
+            mean_and_std = self.actor(obs)
+            if self.noise_std_type == "scalar":
+                mean, std = torch.unbind(mean_and_std, dim=-2)
+            elif self.noise_std_type == "log":
+                mean, self.log_std = torch.unbind(mean_and_std, dim=-2)
+                std = torch.exp(self.log_std)
+            else:
+                raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'")
+        else:
+            # compute mean
+            mean = self.actor(obs)
+            # compute standard deviation
+            if self.noise_std_type == "scalar":
+                std = self.std.expand_as(mean)
+            elif self.noise_std_type == "log":
+                std = torch.exp(self.log_std).expand_as(mean)
+            else:
+                raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'")
         # create distribution
         self.distribution = Normal(mean, std)

[Review thread on `mean, std = torch.unbind(mean_and_std, dim=-2)`]
Member: NIT: the common unbind step can be moved outside the if-else.
Author: I debated this, but went with this option to make it explicit that unbind returns std in one case and log_std in the other. Happy to move the unbind step out if you feel strongly about it.
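
To make the unbind discussion concrete, a standalone sketch (illustrative shapes) of how dim=-2 removes the stacked-head dimension and yields one tensor per head, shown here for the "log" path, where the second slice is exponentiated before building the distribution:

    import torch

    batch, num_actions = 4, 6  # hypothetical sizes
    mean_and_std = torch.randn(batch, 2, num_actions)  # e.g. the actor's output

    # unbind along dim=-2 drops that dimension and returns a tuple of slices.
    mean, log_std = torch.unbind(mean_and_std, dim=-2)
    print(mean.shape, log_std.shape)  # torch.Size([4, 6]) torch.Size([4, 6])

    std = torch.exp(log_std)  # the "scalar" path would use the slice directly
    dist = torch.distributions.Normal(mean, std)
    print(dist.sample().shape)  # torch.Size([4, 6])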

54 changes: 40 additions & 14 deletions rsl_rl/modules/actor_critic_recurrent.py
@@ -28,6 +28,7 @@ def __init__(
         activation="elu",
         init_noise_std=1.0,
         noise_std_type: str = "scalar",
+        state_dependent_std=False,
         rnn_type="lstm",
         rnn_hidden_dim=256,
         rnn_num_layers=1,
@@ -58,9 +59,14 @@
             assert len(obs[obs_group].shape) == 2, "The ActorCriticRecurrent module only supports 1D observations."
             num_critic_obs += obs[obs_group].shape[-1]

+        self.state_dependent_std = state_dependent_std
         # actor
         self.memory_a = Memory(num_actor_obs, type=rnn_type, num_layers=rnn_num_layers, hidden_size=rnn_hidden_dim)
-        self.actor = MLP(rnn_hidden_dim, num_actions, actor_hidden_dims, activation)
+        if self.state_dependent_std:
+            self.actor = MLP(rnn_hidden_dim, [2, num_actions], actor_hidden_dims, activation)
+        else:
+            self.actor = MLP(rnn_hidden_dim, num_actions, actor_hidden_dims, activation)

         # actor observation normalization
         self.actor_obs_normalization = actor_obs_normalization
         if actor_obs_normalization:
@@ -84,12 +90,21 @@

         # Action noise
         self.noise_std_type = noise_std_type
-        if self.noise_std_type == "scalar":
-            self.std = nn.Parameter(init_noise_std * torch.ones(num_actions))
-        elif self.noise_std_type == "log":
-            self.log_std = nn.Parameter(torch.log(init_noise_std * torch.ones(num_actions)))
-        else:
-            raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'")
+        if self.state_dependent_std:
+            torch.nn.init.zeros_(self.actor[-2].weight[num_actions:])
+            if self.noise_std_type == "scalar":
+                torch.nn.init.constant_(self.actor[-2].bias[num_actions:], init_noise_std)
+            elif self.noise_std_type == "log":
+                torch.nn.init.constant_(self.actor[-2].bias[num_actions:], torch.log(torch.tensor(init_noise_std + 1e-7)))
+            else:
+                raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'")
+        else:
+            if self.noise_std_type == "scalar":
+                self.std = nn.Parameter(init_noise_std * torch.ones(num_actions))
+            elif self.noise_std_type == "log":
+                self.log_std = nn.Parameter(torch.log(init_noise_std * torch.ones(num_actions)))
+            else:
+                raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'")

         # Action distribution (populated in update_distribution)
         self.distribution = None
@@ -116,15 +131,26 @@ def forward(self):
         raise NotImplementedError

     def update_distribution(self, obs):
-        # compute mean
-        mean = self.actor(obs)
-        # compute standard deviation
-        if self.noise_std_type == "scalar":
-            std = self.std.expand_as(mean)
-        elif self.noise_std_type == "log":
-            std = torch.exp(self.log_std).expand_as(mean)
-        else:
-            raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'")
+        if self.state_dependent_std:
+            # compute mean and standard deviation
+            mean_and_std = self.actor(obs)
+            if self.noise_std_type == "scalar":
+                mean, std = torch.unbind(mean_and_std, dim=-2)
+            elif self.noise_std_type == "log":
+                mean, self.log_std = torch.unbind(mean_and_std, dim=-2)
+                std = torch.exp(self.log_std)
+            else:
+                raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'")
+        else:
+            # compute mean
+            mean = self.actor(obs)
+            # compute standard deviation
+            if self.noise_std_type == "scalar":
+                std = self.std.expand_as(mean)
+            elif self.noise_std_type == "log":
+                std = torch.exp(self.log_std).expand_as(mean)
+            else:
+                raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'")
         # create distribution
         self.distribution = Normal(mean, std)

2 changes: 1 addition & 1 deletion rsl_rl/networks/mlp.py
@@ -72,7 +72,7 @@ def __init__(
             total_out_dim = reduce(lambda x, y: x * y, output_dim)
             # add a layer to reshape the output to the desired shape
             layers.append(nn.Linear(hidden_dims_processed[-1], total_out_dim))
-            layers.append(nn.Unflatten(output_dim))
+            layers.append(nn.Unflatten(-1, output_dim))
[Review thread on `layers.append(nn.Unflatten(-1, output_dim))`]
Member: I think the output dim should be changed according to the right shape. Otherwise this is always adding an extra dimension.
Author: The first argument specifies which dimension to apply the Unflatten operation to. Based on the prior nn.Linear layer, this would be the final dimension, hence the argument -1.
https://docs.pytorch.org/docs/stable/generated/torch.nn.Unflatten.html#torch.nn.Unflatten
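
A quick check of the behavior the author describes (a standalone sketch): nn.Unflatten takes the target dimension as its first argument, so -1 reshapes only the feature dimension produced by the preceding nn.Linear, regardless of how many leading batch (or batch and time) dimensions are present.

    import torch
    import torch.nn as nn

    unflatten = nn.Unflatten(-1, (2, 12))
    print(unflatten(torch.randn(4096, 24)).shape)    # torch.Size([4096, 2, 12])
    print(unflatten(torch.randn(8, 512, 24)).shape)  # torch.Size([8, 512, 2, 12])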


         # add last activation function if specified
         if last_activation_mod is not None: