diff --git a/examples/text_to_image/train_text_to_image_xla.py b/examples/text_to_image/train_text_to_image_xla.py index 9b8bce3ef..9c6e7c86a 100644 --- a/examples/text_to_image/train_text_to_image_xla.py +++ b/examples/text_to_image/train_text_to_image_xla.py @@ -71,23 +71,26 @@ def start_training(self): assert measure_start_step < self.args.max_train_steps total_time = 0 for step in range(0, self.args.max_train_steps): - if step == measure_start_step and PROFILE_DIR is not None: - xm.wait_device_ops() - xp.trace_detached('localhost:9012', PROFILE_DIR, duration_ms=args.profile_duration) - last_time = time.time() try: batch = next(self.dataloader) except Exception as e: dataloader_exception = True print(e) break + if step == measure_start_step and PROFILE_DIR is not None: + xm.wait_device_ops() + xp.trace_detached('localhost:9012', PROFILE_DIR, duration_ms=args.profile_duration) + last_time = time.time() loss = self.step_fn(batch["pixel_values"], batch["input_ids"]) self.global_step += 1 - + xm.mark_step() if not dataloader_exception: xm.wait_device_ops() total_time = time.time() - last_time print(f"Average step time: {total_time/(self.args.max_train_steps-measure_start_step)}") + else: + print("dataloader exception happen, skip result") + return def step_fn( self,