Skip to content

Commit 9874a35

Browse files
Fixed the reporting of the discriminator loss (#4348)
* Fixed the reporting of the discriminator loss * Update ml-agents/mlagents/trainers/torch/components/reward_providers/gail_reward_provider.py * fixing pre-commit test
1 parent ff667e7 commit 9874a35

File tree

1 file changed

+37
-35
lines changed

1 file changed

+37
-35
lines changed

ml-agents/mlagents/trainers/torch/components/reward_providers/gail_reward_provider.py

Lines changed: 37 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -46,22 +46,12 @@ def update(self, mini_batch: AgentBuffer) -> Dict[str, np.ndarray]:
4646
expert_batch = self._demo_buffer.sample_mini_batch(
4747
mini_batch.num_experiences, 1
4848
)
49-
loss, policy_mean_estimate, expert_mean_estimate, kl_loss = self._discriminator_network.compute_loss(
49+
loss, stats_dict = self._discriminator_network.compute_loss(
5050
mini_batch, expert_batch
5151
)
5252
self.optimizer.zero_grad()
5353
loss.backward()
5454
self.optimizer.step()
55-
stats_dict = {
56-
"Losses/GAIL Discriminator Loss": loss.detach().cpu().numpy(),
57-
"Policy/GAIL Policy Estimate": policy_mean_estimate.detach().cpu().numpy(),
58-
"Policy/GAIL Expert Estimate": expert_mean_estimate.detach().cpu().numpy(),
59-
}
60-
if self._discriminator_network.use_vail:
61-
stats_dict["Policy/GAIL Beta"] = (
62-
self._discriminator_network.beta.detach().cpu().numpy()
63-
)
64-
stats_dict["Losses/GAIL KL Loss"] = kl_loss.detach().cpu().numpy()
6555
return stats_dict
6656

6757

@@ -76,7 +66,7 @@ class DiscriminatorNetwork(torch.nn.Module):
7666
def __init__(self, specs: BehaviorSpec, settings: GAILSettings) -> None:
7767
super().__init__()
7868
self._policy_specs = specs
79-
self.use_vail = settings.use_vail
69+
self._use_vail = settings.use_vail
8070
self._settings = settings
8171

8272
state_encoder_settings = NetworkSettings(
@@ -108,20 +98,20 @@ def __init__(self, specs: BehaviorSpec, settings: GAILSettings) -> None:
10898
estimator_input_size = settings.encoding_size
10999
if settings.use_vail:
110100
estimator_input_size = self.z_size
111-
self.z_sigma = torch.nn.Parameter(
101+
self._z_sigma = torch.nn.Parameter(
112102
torch.ones((self.z_size), dtype=torch.float), requires_grad=True
113103
)
114-
self.z_mu_layer = linear_layer(
104+
self._z_mu_layer = linear_layer(
115105
settings.encoding_size,
116106
self.z_size,
117107
kernel_init=Initialization.KaimingHeNormal,
118108
kernel_gain=0.1,
119109
)
120-
self.beta = torch.nn.Parameter(
110+
self._beta = torch.nn.Parameter(
121111
torch.tensor(self.initial_beta, dtype=torch.float), requires_grad=False
122112
)
123113

124-
self.estimator = torch.nn.Sequential(
114+
self._estimator = torch.nn.Sequential(
125115
linear_layer(estimator_input_size, 1), torch.nn.Sigmoid()
126116
)
127117

@@ -166,9 +156,9 @@ def compute_estimate(
166156
hidden = self.encoder(encoder_input)
167157
z_mu: Optional[torch.Tensor] = None
168158
if self._settings.use_vail:
169-
z_mu = self.z_mu_layer(hidden)
170-
hidden = torch.normal(z_mu, self.z_sigma * use_vail_noise)
171-
estimate = self.estimator(hidden)
159+
z_mu = self._z_mu_layer(hidden)
160+
hidden = torch.normal(z_mu, self._z_sigma * use_vail_noise)
161+
estimate = self._estimator(hidden)
172162
return estimate, z_mu
173163

174164
def compute_loss(
@@ -177,41 +167,53 @@ def compute_loss(
177167
"""
178168
Given a policy mini_batch and an expert mini_batch, computes the loss of the discriminator.
179169
"""
170+
total_loss = torch.zeros(1)
171+
stats_dict: Dict[str, np.ndarray] = {}
180172
policy_estimate, policy_mu = self.compute_estimate(
181173
policy_batch, use_vail_noise=True
182174
)
183175
expert_estimate, expert_mu = self.compute_estimate(
184176
expert_batch, use_vail_noise=True
185177
)
186-
loss = -(
187-
torch.log(expert_estimate * (1 - self.EPSILON))
188-
+ torch.log(1.0 - policy_estimate * (1 - self.EPSILON))
178+
stats_dict["Policy/GAIL Policy Estimate"] = (
179+
policy_estimate.mean().detach().cpu().numpy()
180+
)
181+
stats_dict["Policy/GAIL Expert Estimate"] = (
182+
expert_estimate.mean().detach().cpu().numpy()
183+
)
184+
discriminator_loss = -(
185+
torch.log(expert_estimate + self.EPSILON)
186+
+ torch.log(1.0 - policy_estimate + self.EPSILON)
189187
).mean()
190-
kl_loss: Optional[torch.Tensor] = None
188+
stats_dict["Losses/GAIL Loss"] = discriminator_loss.detach().cpu().numpy()
189+
total_loss += discriminator_loss
191190
if self._settings.use_vail:
192191
# KL divergence loss (encourage latent representation to be normal)
193192
kl_loss = torch.mean(
194193
-torch.sum(
195194
1
196-
+ (self.z_sigma ** 2).log()
195+
+ (self._z_sigma ** 2).log()
197196
- 0.5 * expert_mu ** 2
198197
- 0.5 * policy_mu ** 2
199-
- (self.z_sigma ** 2),
198+
- (self._z_sigma ** 2),
200199
dim=1,
201200
)
202201
)
203-
vail_loss = self.beta * (kl_loss - self.mutual_information)
202+
vail_loss = self._beta * (kl_loss - self.mutual_information)
204203
with torch.no_grad():
205-
self.beta.data = torch.max(
206-
self.beta + self.alpha * (kl_loss - self.mutual_information),
204+
self._beta.data = torch.max(
205+
self._beta + self.alpha * (kl_loss - self.mutual_information),
207206
torch.tensor(0.0),
208207
)
209-
loss += vail_loss
208+
total_loss += vail_loss
209+
stats_dict["Policy/GAIL Beta"] = self._beta.detach().cpu().numpy()
210+
stats_dict["Losses/GAIL KL Loss"] = kl_loss.detach().cpu().numpy()
210211
if self.gradient_penalty_weight > 0.0:
211-
loss += self.gradient_penalty_weight * self.compute_gradient_magnitude(
212-
policy_batch, expert_batch
212+
total_loss += (
213+
self.gradient_penalty_weight
214+
* self.compute_gradient_magnitude(policy_batch, expert_batch)
213215
)
214-
return loss, torch.mean(policy_estimate), torch.mean(expert_estimate), kl_loss
216+
return total_loss, stats_dict
215217

216218
def compute_gradient_magnitude(
217219
self, policy_batch: AgentBuffer, expert_batch: AgentBuffer
@@ -243,9 +245,9 @@ def compute_gradient_magnitude(
243245
hidden = self.encoder(encoder_input)
244246
if self._settings.use_vail:
245247
use_vail_noise = True
246-
z_mu = self.z_mu_layer(hidden)
247-
hidden = torch.normal(z_mu, self.z_sigma * use_vail_noise)
248-
hidden = self.estimator(hidden)
248+
z_mu = self._z_mu_layer(hidden)
249+
hidden = torch.normal(z_mu, self._z_sigma * use_vail_noise)
250+
hidden = self._estimator(hidden)
249251
estimate = torch.mean(torch.sum(hidden, dim=1))
250252
gradient = torch.autograd.grad(estimate, encoder_input)[0]
251253
# Norm's gradient could be NaN at 0. Use our own safe_norm

0 commit comments

Comments
 (0)