@@ -46,22 +46,12 @@ def update(self, mini_batch: AgentBuffer) -> Dict[str, np.ndarray]:
         expert_batch = self._demo_buffer.sample_mini_batch(
             mini_batch.num_experiences, 1
         )
-        loss, policy_mean_estimate, expert_mean_estimate, kl_loss = self._discriminator_network.compute_loss(
+        loss, stats_dict = self._discriminator_network.compute_loss(
             mini_batch, expert_batch
         )
         self.optimizer.zero_grad()
         loss.backward()
         self.optimizer.step()
-        stats_dict = {
-            "Losses/GAIL Discriminator Loss": loss.detach().cpu().numpy(),
-            "Policy/GAIL Policy Estimate": policy_mean_estimate.detach().cpu().numpy(),
-            "Policy/GAIL Expert Estimate": expert_mean_estimate.detach().cpu().numpy(),
-        }
-        if self._discriminator_network.use_vail:
-            stats_dict["Policy/GAIL Beta"] = (
-                self._discriminator_network.beta.detach().cpu().numpy()
-            )
-            stats_dict["Losses/GAIL KL Loss"] = kl_loss.detach().cpu().numpy()
         return stats_dict

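Note on the new contract above: `compute_loss` now owns stat collection, so `update()` no longer branches on `use_vail` or unpacks individual stat tensors. A minimal sketch of the calling convention (the stub and its values are illustrative, not the real network):

```python
# Hypothetical sketch of the new update() contract; compute_loss_stub stands
# in for DiscriminatorNetwork.compute_loss and its stat keys.
from typing import Dict, Tuple

import numpy as np
import torch

param = torch.nn.Parameter(torch.ones(1))
optimizer = torch.optim.Adam([param])


def compute_loss_stub() -> Tuple[torch.Tensor, Dict[str, np.ndarray]]:
    loss = (param ** 2).sum()  # stand-in for the discriminator loss
    return loss, {"Losses/GAIL Loss": loss.detach().cpu().numpy()}


loss, stats_dict = compute_loss_stub()
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(stats_dict)  # forwarded to the trainer as-is; no VAIL-specific filtering
```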
@@ -76,7 +66,7 @@ class DiscriminatorNetwork(torch.nn.Module):
     def __init__(self, specs: BehaviorSpec, settings: GAILSettings) -> None:
         super().__init__()
         self._policy_specs = specs
-        self.use_vail = settings.use_vail
+        self._use_vail = settings.use_vail
         self._settings = settings

         state_encoder_settings = NetworkSettings(
@@ -108,20 +98,20 @@ def __init__(self, specs: BehaviorSpec, settings: GAILSettings) -> None:
         estimator_input_size = settings.encoding_size
         if settings.use_vail:
             estimator_input_size = self.z_size
-            self.z_sigma = torch.nn.Parameter(
+            self._z_sigma = torch.nn.Parameter(
                 torch.ones((self.z_size), dtype=torch.float), requires_grad=True
             )
-            self.z_mu_layer = linear_layer(
+            self._z_mu_layer = linear_layer(
                 settings.encoding_size,
                 self.z_size,
                 kernel_init=Initialization.KaimingHeNormal,
                 kernel_gain=0.1,
             )
-            self.beta = torch.nn.Parameter(
+            self._beta = torch.nn.Parameter(
                 torch.tensor(self.initial_beta, dtype=torch.float), requires_grad=False
             )

-        self.estimator = torch.nn.Sequential(
+        self._estimator = torch.nn.Sequential(
             linear_layer(estimator_input_size, 1), torch.nn.Sigmoid()
         )

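Aside on the renamed `_beta`: it stays an `nn.Parameter` with `requires_grad=False`, so it is still captured by `state_dict()` and moved by `.to(device)` along with the module, while no optimizer ever updates it; the manual write in `compute_loss` goes through `.data` under `no_grad`. A small illustration (the `Toy` module is made up):

```python
import torch


class Toy(torch.nn.Module):  # made-up module, for illustration only
    def __init__(self) -> None:
        super().__init__()
        # Registered with the module, but never updated by an optimizer:
        self._beta = torch.nn.Parameter(torch.tensor(0.0), requires_grad=False)


toy = Toy()
print(list(toy.state_dict().keys()))  # ['_beta'] -> saved with checkpoints
with torch.no_grad():
    toy._beta.data = torch.tensor(1.0)  # manual update, as compute_loss does
```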
@@ -166,9 +156,9 @@ def compute_estimate(
         hidden = self.encoder(encoder_input)
         z_mu: Optional[torch.Tensor] = None
         if self._settings.use_vail:
-            z_mu = self.z_mu_layer(hidden)
-            hidden = torch.normal(z_mu, self.z_sigma * use_vail_noise)
-        estimate = self.estimator(hidden)
+            z_mu = self._z_mu_layer(hidden)
+            hidden = torch.normal(z_mu, self._z_sigma * use_vail_noise)
+        estimate = self._estimator(hidden)
         return estimate, z_mu

     def compute_loss(
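The `use_vail_noise` flag in `compute_estimate` works by scaling the standard deviation: multiplying `_z_sigma` by `False` zeroes it, and `torch.normal` with a zero std returns the mean, so the estimate is deterministic at inference. A quick check of that behavior:

```python
import torch

z_mu = torch.tensor([0.5, -1.0])
z_sigma = torch.ones(2)

noisy = torch.normal(z_mu, z_sigma * True)           # training: sampled latent
deterministic = torch.normal(z_mu, z_sigma * False)  # std == 0 -> the mean
assert torch.equal(deterministic, z_mu)
```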
@@ -177,41 +167,53 @@ def compute_loss(
         """
         Given a policy mini_batch and an expert mini_batch, computes the loss of the discriminator.
         """
+        total_loss = torch.zeros(1)
+        stats_dict: Dict[str, np.ndarray] = {}
         policy_estimate, policy_mu = self.compute_estimate(
             policy_batch, use_vail_noise=True
         )
         expert_estimate, expert_mu = self.compute_estimate(
             expert_batch, use_vail_noise=True
         )
-        loss = -(
-            torch.log(expert_estimate * (1 - self.EPSILON))
-            + torch.log(1.0 - policy_estimate * (1 - self.EPSILON))
+        stats_dict["Policy/GAIL Policy Estimate"] = (
+            policy_estimate.mean().detach().cpu().numpy()
+        )
+        stats_dict["Policy/GAIL Expert Estimate"] = (
+            expert_estimate.mean().detach().cpu().numpy()
+        )
+        discriminator_loss = -(
+            torch.log(expert_estimate + self.EPSILON)
+            + torch.log(1.0 - policy_estimate + self.EPSILON)
         ).mean()
-        kl_loss: Optional[torch.Tensor] = None
+        stats_dict["Losses/GAIL Loss"] = discriminator_loss.detach().cpu().numpy()
+        total_loss += discriminator_loss
         if self._settings.use_vail:
             # KL divergence loss (encourage latent representation to be normal)
             kl_loss = torch.mean(
                 -torch.sum(
                     1
-                    + (self.z_sigma ** 2).log()
+                    + (self._z_sigma ** 2).log()
                     - 0.5 * expert_mu ** 2
                     - 0.5 * policy_mu ** 2
-                    - (self.z_sigma ** 2),
+                    - (self._z_sigma ** 2),
                     dim=1,
                 )
             )
-            vail_loss = self.beta * (kl_loss - self.mutual_information)
+            vail_loss = self._beta * (kl_loss - self.mutual_information)
             with torch.no_grad():
-                self.beta.data = torch.max(
-                    self.beta + self.alpha * (kl_loss - self.mutual_information),
+                self._beta.data = torch.max(
+                    self._beta + self.alpha * (kl_loss - self.mutual_information),
                     torch.tensor(0.0),
                 )
-            loss += vail_loss
+            total_loss += vail_loss
+            stats_dict["Policy/GAIL Beta"] = self._beta.detach().cpu().numpy()
+            stats_dict["Losses/GAIL KL Loss"] = kl_loss.detach().cpu().numpy()

         if self.gradient_penalty_weight > 0.0:
-            loss += self.gradient_penalty_weight * self.compute_gradient_magnitude(
-                policy_batch, expert_batch
+            total_loss += (
+                self.gradient_penalty_weight
+                * self.compute_gradient_magnitude(policy_batch, expert_batch)
             )
-        return loss, torch.mean(policy_estimate), torch.mean(expert_estimate), kl_loss
+        return total_loss, stats_dict

     def compute_gradient_magnitude(
         self, policy_batch: AgentBuffer, expert_batch: AgentBuffer
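Two semantic changes hide in this hunk. First, the numeric stabilizer moved inside the log as an additive term: `log(estimate + EPSILON)` stays finite when an estimate saturates at 0, which the old multiplicative `estimate * (1 - EPSILON)` did not guard against. Second, the `_beta` update is the dual ascent rule from the VAIL formulation: beta <- max(0, beta + alpha * (KL - I_c)). A worked numeric sketch of just that rule (the alpha and mutual-information values here are illustrative, not necessarily the class defaults):

```python
import torch

alpha = 0.0005            # illustrative step size for the dual update
mutual_information = 0.5  # illustrative KL target (I_c)
beta = torch.tensor(0.0)

for kl in (0.9, 0.7, 0.4):  # pretend the KL shrinks over training
    # beta rises while KL exceeds the target and is floored at zero otherwise
    beta = torch.max(beta + alpha * (kl - mutual_information), torch.tensor(0.0))
    print(f"KL={kl:.1f} -> beta={beta.item():.6f}")
```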
@@ -243,9 +245,9 @@ def compute_gradient_magnitude(
         hidden = self.encoder(encoder_input)
         if self._settings.use_vail:
             use_vail_noise = True
-            z_mu = self.z_mu_layer(hidden)
-            hidden = torch.normal(z_mu, self.z_sigma * use_vail_noise)
-        hidden = self.estimator(hidden)
+            z_mu = self._z_mu_layer(hidden)
+            hidden = torch.normal(z_mu, self._z_sigma * use_vail_noise)
+        hidden = self._estimator(hidden)
         estimate = torch.mean(torch.sum(hidden, dim=1))
         gradient = torch.autograd.grad(estimate, encoder_input)[0]
         # Norm's gradient could be NaN at 0. Use our own safe_norm
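On the `safe_norm` comment above: a plain `torch.norm` has a NaN gradient at exactly zero (0/0), so the gradient penalty computes the norm with an epsilon under the square root. A sketch of one common formulation (the constant here is illustrative; the class defines its own `EPSILON`):

```python
import torch

EPSILON = 1e-7  # illustrative; not necessarily the class constant


def safe_norm(x: torch.Tensor) -> torch.Tensor:
    # sqrt(sum(x^2) + eps) has a finite gradient everywhere, including x == 0,
    # unlike torch.norm, whose gradient is x / ||x|| and NaN at the origin.
    return torch.sqrt(torch.sum(x ** 2, dim=1) + EPSILON)


x = torch.zeros(2, 3, requires_grad=True)
safe_norm(x).sum().backward()
assert not torch.isnan(x.grad).any()  # finite gradient even at the origin
```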