Commit 65fb7df

Revert "debugging"
This reverts commit 7e45bdd.
1 parent 7e45bdd commit 65fb7df

1 file changed: +5 −10 lines changed

examples/text_to_image/train_text_to_image.py

Lines changed: 5 additions & 10 deletions
@@ -973,10 +973,8 @@ def collate_fn(examples):
                 loss = loss.mean()

                 # Gather the losses across all processes for logging (if we use distributed training).
-                # avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
-                # train_loss += avg_loss.item() / args.gradient_accumulation_steps
-                print('accelerator.sync_gradients=', accelerator.sync_gradients)
-                accelerator.sync_gradients = False
+                avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
+                train_loss += avg_loss.item() / args.gradient_accumulation_steps

                 # Backpropagate
                 accelerator.backward(loss)
@@ -1020,14 +1018,11 @@ def collate_fn(examples):
                         save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                         accelerator.save_state(save_path)
                         logger.info(f"Saved state to {save_path}")
-            if step % 20 == 0:
-                logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
-                progress_bar.set_postfix(**logs)
+
+            logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+            progress_bar.set_postfix(**logs)

             if global_step >= args.max_train_steps:
-                import torch_xla.debug.metrics as met
-                # For short report that only contains a few key metrics.
-                print(met.short_metrics_report())
                 break

         if accelerator.is_main_process:
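
For context, the restored lines follow the usual Accelerate logging pattern: the scalar loss is repeated once per local sample, gathered across processes, averaged, and accumulated into a running `train_loss`. Below is a minimal standalone sketch of that pattern, not the training script itself; the batch size and gradient-accumulation values are illustrative stand-ins for `args.train_batch_size` and `args.gradient_accumulation_steps`.

import torch
from accelerate import Accelerator

accelerator = Accelerator()

# Stand-in for the per-process scalar loss produced by loss.mean().
loss = torch.rand((), device=accelerator.device)

train_batch_size = 16            # illustrative value for args.train_batch_size
gradient_accumulation_steps = 4  # illustrative value for args.gradient_accumulation_steps

# Repeat the scalar so each local sample contributes one entry, gather the
# entries from all processes, and average them for logging.
avg_loss = accelerator.gather(loss.repeat(train_batch_size)).mean()

# Accumulate the averaged loss; in the script this running value is reported
# (and reset) when an optimizer step completes.
train_loss = 0.0
train_loss += avg_loss.item() / gradient_accumulation_steps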