@@ -973,10 +973,8 @@ def collate_fn(examples):
                    loss = loss.mean()

                # Gather the losses across all processes for logging (if we use distributed training).
-                # avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
-                # train_loss += avg_loss.item() / args.gradient_accumulation_steps
-                print('accelerator.sync_gradients=', accelerator.sync_gradients)
-                accelerator.sync_gradients = False
+                avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
+                train_loss += avg_loss.item() / args.gradient_accumulation_steps

                # Backpropagate
                accelerator.backward(loss)
@@ -1020,14 +1018,11 @@ def collate_fn(examples):
                        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                        accelerator.save_state(save_path)
                        logger.info(f"Saved state to {save_path}")
-            if step % 20 == 0:
-                logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
-                progress_bar.set_postfix(**logs)
+
+            logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+            progress_bar.set_postfix(**logs)

            if global_step >= args.max_train_steps:
-                import torch_xla.debug.metrics as met
-                # For short report that only contains a few key metrics.
-                print(met.short_metrics_report())
                break

        if accelerator.is_main_process:
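For reference, here is a minimal, self-contained sketch (not taken from this PR) of the logging pattern the restored lines follow: repeat the scalar loss to the batch size, gather it across processes with accelerator.gather, and average it into a running train_loss. The model, dataset, and hyperparameters below are placeholder assumptions; only the gather-and-average lines mirror the diff above.

# Minimal sketch, assuming a generic Hugging Face Accelerate training loop;
# the model, dataset, and hyperparameters are placeholders, not from this PR.
import torch
from accelerate import Accelerator

accelerator = Accelerator(gradient_accumulation_steps=4)
model = torch.nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
dataloader = torch.utils.data.DataLoader(torch.randn(64, 8), batch_size=4)
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

train_loss = 0.0
for step, batch in enumerate(dataloader):
    with accelerator.accumulate(model):
        loss = model(batch).pow(2).mean()

        # Repeat the scalar loss to batch size, gather it from every process,
        # and average; accumulate the result for logging, as in the diff above.
        avg_loss = accelerator.gather(loss.repeat(batch.shape[0])).mean()
        train_loss += avg_loss.item() / accelerator.gradient_accumulation_steps

        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()

    if accelerator.sync_gradients:
        # A full optimizer update has happened; report and reset the running loss.
        accelerator.print(f"step {step}: train_loss={train_loss:.4f}")
        train_loss = 0.0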