From 0e1f994e3b59553e7dbd4ec83add6c0ef44e3abf Mon Sep 17 00:00:00 2001 From: zpcore Date: Thu, 12 Sep 2024 11:27:45 -0700 Subject: [PATCH 1/2] add mark_step --- examples/text_to_image/train_text_to_image_xla.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/examples/text_to_image/train_text_to_image_xla.py b/examples/text_to_image/train_text_to_image_xla.py index 9b8bce3ef..bffea7c21 100644 --- a/examples/text_to_image/train_text_to_image_xla.py +++ b/examples/text_to_image/train_text_to_image_xla.py @@ -71,23 +71,26 @@ def start_training(self): assert measure_start_step < self.args.max_train_steps total_time = 0 for step in range(0, self.args.max_train_steps): - if step == measure_start_step and PROFILE_DIR is not None: - xm.wait_device_ops() - xp.trace_detached('localhost:9012', PROFILE_DIR, duration_ms=args.profile_duration) - last_time = time.time() try: batch = next(self.dataloader) except Exception as e: dataloader_exception = True print(e) break + if step == measure_start_step and PROFILE_DIR is not None: + xm.wait_device_ops() + xp.trace_detached('localhost:9012', PROFILE_DIR, duration_ms=args.profile_duration) + last_time = time.time() loss = self.step_fn(batch["pixel_values"], batch["input_ids"]) - self.global_step += 1 - + self.global_step += 1 + xm.mark_step() if not dataloader_exception: xm.wait_device_ops() total_time = time.time() - last_time print(f"Average step time: {total_time/(self.args.max_train_steps-measure_start_step)}") + else: + print("dataloader exception happen, skip result") + return def step_fn( self, From d8e12d960b176e5f323477ae97d1d7842fef595b Mon Sep 17 00:00:00 2001 From: zpcore Date: Thu, 12 Sep 2024 13:00:50 -0700 Subject: [PATCH 2/2] nit --- examples/text_to_image/train_text_to_image_xla.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/text_to_image/train_text_to_image_xla.py b/examples/text_to_image/train_text_to_image_xla.py index bffea7c21..9c6e7c86a 100644 --- a/examples/text_to_image/train_text_to_image_xla.py +++ b/examples/text_to_image/train_text_to_image_xla.py @@ -82,7 +82,7 @@ def start_training(self): xp.trace_detached('localhost:9012', PROFILE_DIR, duration_ms=args.profile_duration) last_time = time.time() loss = self.step_fn(batch["pixel_values"], batch["input_ids"]) - self.global_step += 1 + self.global_step += 1 xm.mark_step() if not dataloader_exception: xm.wait_device_ops()