@@ -113,7 +113,9 @@ def __init__(self, policy: TorchPolicy, trainer_params: TrainerSettings):
             self.policy.behavior_spec.observation_shapes,
             policy_network_settings,
         )
-        self.soft_update(self.policy.actor_critic.critic, self.target_network, 1.0)
+        ModelUtils.soft_update(
+            self.policy.actor_critic.critic, self.target_network, 1.0
+        )
 
         self._log_ent_coef = torch.nn.Parameter(
             torch.log(torch.as_tensor([self.init_entcoef] * len(self.act_size))),
@@ -201,12 +203,6 @@ def sac_q_loss(
         q2_loss = torch.mean(torch.stack(q2_losses))
         return q1_loss, q2_loss
 
-    def soft_update(self, source: nn.Module, target: nn.Module, tau: float) -> None:
-        for source_param, target_param in zip(source.parameters(), target.parameters()):
-            target_param.data.copy_(
-                target_param.data * (1.0 - tau) + source_param.data * tau
-            )
-
     def sac_value_loss(
         self,
         log_probs: torch.Tensor,
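
The method removed above is the standard Polyak (soft) target-network update, which the diff swaps for the shared ModelUtils.soft_update helper. A minimal standalone sketch of the same logic, assuming only torch (all names outside the diff are placeholders); the tau=1.0 call in __init__ then amounts to a hard copy of the critic weights into the target network:

import torch
from torch import nn


def soft_update(source: nn.Module, target: nn.Module, tau: float) -> None:
    # Polyak averaging: target <- (1 - tau) * target + tau * source.
    # tau=1.0 is a hard copy; a small tau lets the target track the critic slowly.
    with torch.no_grad():
        for source_param, target_param in zip(source.parameters(), target.parameters()):
            target_param.data.copy_(
                target_param.data * (1.0 - tau) + source_param.data * tau
            )


# Placeholder usage with two identically shaped modules:
critic = nn.Linear(4, 2)
target = nn.Linear(4, 2)
soft_update(critic, target, tau=1.0)    # at initialization: exact copy
soft_update(critic, target, tau=0.005)  # per training update: slow tracking
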
@@ -441,20 +437,17 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
         self.target_network.network_body.copy_normalization(
             self.policy.actor_critic.network_body
         )
-        (
-            sampled_actions,
-            log_probs,
-            entropies,
-            sampled_values,
-            _,
-        ) = self.policy.sample_actions(
+        (sampled_actions, log_probs, _, _) = self.policy.sample_actions(
             vec_obs,
             vis_obs,
             masks=act_masks,
             memories=memories,
             seq_len=self.policy.sequence_length,
             all_log_probs=not self.policy.use_continuous_act,
         )
+        value_estimates, _ = self.policy.actor_critic.critic_pass(
+            vec_obs, vis_obs, memories, sequence_length=self.policy.sequence_length
+        )
         if self.policy.use_continuous_act:
             squeezed_actions = actions.squeeze(-1)
             q1p_out, q2p_out = self.value_network(
@@ -504,7 +497,7 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
             q1_stream, q2_stream, target_values, dones, rewards, masks
         )
         value_loss = self.sac_value_loss(
-            log_probs, sampled_values, q1p_out, q2p_out, masks, use_discrete
+            log_probs, value_estimates, q1p_out, q2p_out, masks, use_discrete
         )
         policy_loss = self.sac_policy_loss(log_probs, q1p_out, masks, use_discrete)
         entropy_loss = self.sac_entropy_loss(log_probs, masks, use_discrete)
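
The value estimates consumed here now come from the explicit critic_pass added in the previous hunk rather than from the sample_actions return tuple. As a rough schematic of the quantity sac_value_loss fits them to (continuous-action case only, ignoring masking, memories, and the per-reward-signal streams; all tensors below are placeholders, not the real network outputs):

import torch
import torch.nn.functional as F

# Placeholders standing in for the real quantities:
# value_estimates ~ V(s) from critic_pass, q1p/q2p ~ Q(s, a~pi) from the value network,
# log_probs ~ log pi(a|s) for the freshly sampled actions, ent_coef ~ exp(self._log_ent_coef).
value_estimates = torch.randn(8)
q1p, q2p = torch.randn(8), torch.randn(8)
log_probs = torch.randn(8)
ent_coef = torch.tensor(0.2)

# SAC value target: V_target(s) = min(Q1, Q2) - alpha * log pi(a|s)
v_target = (torch.min(q1p, q2p) - ent_coef * log_probs).detach()
value_loss = F.mse_loss(value_estimates, v_target)
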
@@ -528,7 +521,9 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
         self.entropy_optimizer.step()
 
         # Update target network
-        self.soft_update(self.policy.actor_critic.critic, self.target_network, self.tau)
+        ModelUtils.soft_update(
+            self.policy.actor_critic.critic, self.target_network, self.tau
+        )
         update_stats = {
             "Losses/Policy Loss": policy_loss.item(),
             "Losses/Value Loss": value_loss.item(),