
Only calls destroy_process_group if the trainer exits successfully #1342

Merged: 3 commits, Jun 27, 2025
11 changes: 6 additions & 5 deletions torchtitan/experiments/flux/train.py
@@ -225,10 +225,11 @@ def train_step(
             logger.info("Created seed checkpoint")
         else:
             trainer.train()
-    finally:
+    except Exception:
         if trainer:
             trainer.close()
-
-        if torch.distributed.is_initialized():
-            torch.distributed.destroy_process_group()
-            logger.info("Process group destroyed.")
+        raise
+    else:
+        trainer.close()
+        torch.distributed.destroy_process_group()
+        logger.info("Process group destroyed.")
Comment on lines -228 to +235
Contributor

what is the reason for wrapping almost the entire program in try-except?
It seems trainer.close() just closes a file?

    def close(self) -> None:
        if self.checkpointer:
            self.checkpointer.close()

May I say that Python will automatically close a file even if the program ends due to an unhandled exception?

Here is what will happen upon program exit (exception or not):

  1. CPython dereferences all objects, and all objects have their destructors called, even if the program ends due to an unhandled exception.
  2. When the reference count hits zero, no Python code can reach the object anymore, so the object gets deallocated. And when it gets deallocated, Python calls the __del__() destructor.
  3. Python’s __del__() method for files flushes the buffers and closes the file from the operating system’s point of view.
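For example, a tiny sketch (made-up path, not torchtitan code) showing that the file still ends up flushed and closed even though the program dies with an unhandled exception:

    # /tmp/demo.txt is a made-up path; after running this, the text is on disk
    # even though close() was never called explicitly.
    f = open("/tmp/demo.txt", "w")
    f.write("written before the crash\n")
    raise RuntimeError("unhandled on purpose")  # interpreter shutdown still flushes and closes f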

Contributor Author
@fegin Jun 27, 2025

We need close() to ensure child processes exit and finish correctly; the Checkpointer currently relies on it. Because of the GIL, some work runs in background processes, and not necessarily just the Checkpointer's. A proper close() is required to ensure no data is lost.
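As a rough illustration (made-up names, not the actual torchtitan Checkpointer), the kind of background worker that close() has to wind down looks something like this; skipping the explicit join can lose data that is still queued:

    import multiprocessing as mp

    class Checkpointer:
        # Stand-in for a checkpointer that writes shards from a child process.
        def __init__(self):
            self._queue = mp.Queue()
            self._worker = mp.Process(target=self._drain, args=(self._queue,))
            self._worker.start()

        @staticmethod
        def _drain(queue):
            # Child process: pull pending checkpoint work and write it out.
            while True:
                item = queue.get()
                if item is None:
                    break
                # ... write the shard to disk ...

        def close(self):
            # Signal shutdown and wait for every pending write to finish;
            # without this join the parent can exit while data is still queued.
            self._queue.put(None)
            self._worker.join()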

11 changes: 6 additions & 5 deletions torchtitan/train.py
@@ -551,10 +551,11 @@ def close(self) -> None:
             logger.info("Created seed checkpoint")
         else:
             trainer.train()
-    finally:
+    except Exception:
         if trainer:
             trainer.close()
-
-        if torch.distributed.is_initialized():
-            torch.distributed.destroy_process_group()
-            logger.info("Process group destroyed.")
+        raise
+    else:
+        trainer.close()
+        torch.distributed.destroy_process_group()
+        logger.info("Process group destroyed.")
Contributor

Is destroy_process_group() causing the hang, or trainer.close()?

Because ideally calling destroy_process_group() itself would not hang; if it does, that seems like another bug we should look into.

Contributor Author

destroy_process_group(), I attached the py-spy result in the summary.

Contributor

Hmm. Cc @kwen2501 is that supposed to happen?

Contributor

There is no guarantee that destroy_process_group will never hang, whatever the situation.
From the NCCL documentation for ncclCommDestroy:

This function is an intra-node collective call, which all ranks on the same node should call to avoid a hang.

Contributor

torchrun should crash the other ranks when one rank crashes. It seems it failed to do so here.

Contributor

eventually timeout

This is the right behavior upon collective mismatch.

Contributor
@kwen2501 Jun 27, 2025

My q is, why does the user program mute the exception and not re-throw? Does it believe it is recoverable? In this case it does not seem so?

Contributor

ok, then it sounds like we are actually using destroy_process_group wrong in the torchtitan scripts. We should only call it if we are on the clean exit path.

Contributor

The docs here look correct to me, but I think we could add an example of the recommended way to do this kind of exception handling on exit:
https://docs.pytorch.org/docs/stable/distributed.html#shutdown

I'll make a PR tomorrow.
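Roughly, a sketch of that pattern (build_trainer is a hypothetical helper; this is not the exact code that will land in the docs):

    import torch.distributed as dist

    def main():
        dist.init_process_group(backend="nccl")
        trainer = None
        try:
            trainer = build_trainer()  # hypothetical setup helper
            trainer.train()
        except Exception:
            # Clean up what we can locally, then re-raise so the launcher
            # (torchrun) sees the failure and tears down the other ranks.
            if trainer:
                trainer.close()
            raise
        else:
            # Only reached on the clean exit path, where every rank is expected
            # to get here, so destroy_process_group() will not be mismatched.
            trainer.close()
            dist.destroy_process_group()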

Contributor Author

My q is, why does the user program mute the exception and not re-throw? Does it believe it is recoverable? In this case it does not seem so?

Good question, @kwen2501. I had the same impression as @wconstab: I thought destroy_process_group was a purely local call. That's why I wrapped it in finally:, and it didn't go wrong until this week when we were debugging CP issues.

This PR should be the right way to call destroy_process_group().

logger.info("Process group destroyed.")
Loading