diff --git a/ml-agents/mlagents/trainers/torch/components/reward_providers/gail_reward_provider.py b/ml-agents/mlagents/trainers/torch/components/reward_providers/gail_reward_provider.py
index dd3a9854c4..b59ef0c494 100644
--- a/ml-agents/mlagents/trainers/torch/components/reward_providers/gail_reward_provider.py
+++ b/ml-agents/mlagents/trainers/torch/components/reward_providers/gail_reward_provider.py
@@ -46,22 +46,12 @@ def update(self, mini_batch: AgentBuffer) -> Dict[str, np.ndarray]:
         expert_batch = self._demo_buffer.sample_mini_batch(
             mini_batch.num_experiences, 1
         )
-        loss, policy_mean_estimate, expert_mean_estimate, kl_loss = self._discriminator_network.compute_loss(
+        loss, stats_dict = self._discriminator_network.compute_loss(
             mini_batch, expert_batch
         )
         self.optimizer.zero_grad()
         loss.backward()
         self.optimizer.step()
-        stats_dict = {
-            "Losses/GAIL Discriminator Loss": loss.detach().cpu().numpy(),
-            "Policy/GAIL Policy Estimate": policy_mean_estimate.detach().cpu().numpy(),
-            "Policy/GAIL Expert Estimate": expert_mean_estimate.detach().cpu().numpy(),
-        }
-        if self._discriminator_network.use_vail:
-            stats_dict["Policy/GAIL Beta"] = (
-                self._discriminator_network.beta.detach().cpu().numpy()
-            )
-            stats_dict["Losses/GAIL KL Loss"] = kl_loss.detach().cpu().numpy()
         return stats_dict
 
 
@@ -76,7 +66,7 @@ class DiscriminatorNetwork(torch.nn.Module):
     def __init__(self, specs: BehaviorSpec, settings: GAILSettings) -> None:
         super().__init__()
         self._policy_specs = specs
-        self.use_vail = settings.use_vail
+        self._use_vail = settings.use_vail
         self._settings = settings
 
         state_encoder_settings = NetworkSettings(
@@ -108,20 +98,20 @@ def __init__(self, specs: BehaviorSpec, settings: GAILSettings) -> None:
         estimator_input_size = settings.encoding_size
         if settings.use_vail:
             estimator_input_size = self.z_size
-            self.z_sigma = torch.nn.Parameter(
+            self._z_sigma = torch.nn.Parameter(
                 torch.ones((self.z_size), dtype=torch.float), requires_grad=True
             )
-            self.z_mu_layer = linear_layer(
+            self._z_mu_layer = linear_layer(
                 settings.encoding_size,
                 self.z_size,
                 kernel_init=Initialization.KaimingHeNormal,
                 kernel_gain=0.1,
             )
-            self.beta = torch.nn.Parameter(
+            self._beta = torch.nn.Parameter(
                 torch.tensor(self.initial_beta, dtype=torch.float), requires_grad=False
             )
 
-        self.estimator = torch.nn.Sequential(
+        self._estimator = torch.nn.Sequential(
             linear_layer(estimator_input_size, 1), torch.nn.Sigmoid()
         )
 
@@ -166,9 +156,9 @@ def compute_estimate(
         hidden = self.encoder(encoder_input)
         z_mu: Optional[torch.Tensor] = None
         if self._settings.use_vail:
-            z_mu = self.z_mu_layer(hidden)
-            hidden = torch.normal(z_mu, self.z_sigma * use_vail_noise)
-        estimate = self.estimator(hidden)
+            z_mu = self._z_mu_layer(hidden)
+            hidden = torch.normal(z_mu, self._z_sigma * use_vail_noise)
+        estimate = self._estimator(hidden)
         return estimate, z_mu
 
     def compute_loss(
@@ -177,41 +167,53 @@ def compute_loss(
         """
         Given a policy mini_batch and an expert mini_batch, computes the loss of the discriminator.
         """
+        total_loss = torch.zeros(1)
+        stats_dict: Dict[str, np.ndarray] = {}
         policy_estimate, policy_mu = self.compute_estimate(
             policy_batch, use_vail_noise=True
         )
         expert_estimate, expert_mu = self.compute_estimate(
             expert_batch, use_vail_noise=True
        )
-        loss = -(
-            torch.log(expert_estimate * (1 - self.EPSILON))
-            + torch.log(1.0 - policy_estimate * (1 - self.EPSILON))
+        stats_dict["Policy/GAIL Policy Estimate"] = (
+            policy_estimate.mean().detach().cpu().numpy()
+        )
+        stats_dict["Policy/GAIL Expert Estimate"] = (
+            expert_estimate.mean().detach().cpu().numpy()
+        )
+        discriminator_loss = -(
+            torch.log(expert_estimate + self.EPSILON)
+            + torch.log(1.0 - policy_estimate + self.EPSILON)
         ).mean()
-        kl_loss: Optional[torch.Tensor] = None
+        stats_dict["Losses/GAIL Loss"] = discriminator_loss.detach().cpu().numpy()
+        total_loss += discriminator_loss
         if self._settings.use_vail:
             # KL divergence loss (encourage latent representation to be normal)
             kl_loss = torch.mean(
                 -torch.sum(
                     1
-                    + (self.z_sigma ** 2).log()
+                    + (self._z_sigma ** 2).log()
                     - 0.5 * expert_mu ** 2
                     - 0.5 * policy_mu ** 2
-                    - (self.z_sigma ** 2),
+                    - (self._z_sigma ** 2),
                     dim=1,
                 )
             )
-            vail_loss = self.beta * (kl_loss - self.mutual_information)
+            vail_loss = self._beta * (kl_loss - self.mutual_information)
             with torch.no_grad():
-                self.beta.data = torch.max(
-                    self.beta + self.alpha * (kl_loss - self.mutual_information),
+                self._beta.data = torch.max(
+                    self._beta + self.alpha * (kl_loss - self.mutual_information),
                     torch.tensor(0.0),
                 )
-            loss += vail_loss
+            total_loss += vail_loss
+            stats_dict["Policy/GAIL Beta"] = self._beta.detach().cpu().numpy()
+            stats_dict["Losses/GAIL KL Loss"] = kl_loss.detach().cpu().numpy()
         if self.gradient_penalty_weight > 0.0:
-            loss += self.gradient_penalty_weight * self.compute_gradient_magnitude(
-                policy_batch, expert_batch
+            total_loss += (
+                self.gradient_penalty_weight
+                * self.compute_gradient_magnitude(policy_batch, expert_batch)
             )
-        return loss, torch.mean(policy_estimate), torch.mean(expert_estimate), kl_loss
+        return total_loss, stats_dict
 
     def compute_gradient_magnitude(
         self, policy_batch: AgentBuffer, expert_batch: AgentBuffer
@@ -243,9 +245,9 @@ def compute_gradient_magnitude(
         hidden = self.encoder(encoder_input)
         if self._settings.use_vail:
             use_vail_noise = True
-            z_mu = self.z_mu_layer(hidden)
-            hidden = torch.normal(z_mu, self.z_sigma * use_vail_noise)
-        hidden = self.estimator(hidden)
+            z_mu = self._z_mu_layer(hidden)
+            hidden = torch.normal(z_mu, self._z_sigma * use_vail_noise)
+        hidden = self._estimator(hidden)
         estimate = torch.mean(torch.sum(hidden, dim=1))
         gradient = torch.autograd.grad(estimate, encoder_input)[0]
         # Norm's gradient could be NaN at 0. Use our own safe_norm