Fix the incorrect step log for profiler after resuming from a checkpoint #293


Merged
fegin merged 2 commits into main on May 3, 2024

Conversation

@fegin fegin (Contributor) commented May 2, 2024

Summary:
The profiler currently maintains its own local step counter, which is not synchronized with the checkpointed train step, so the step it logs is incorrect after resuming from a checkpoint. This PR fixes the issue.

@facebook-github-bot added the "CLA Signed" label (managed by the Meta Open Source bot) on May 2, 2024
@fegin fegin requested a review from tianyu-l May 2, 2024 06:37
```diff
@@ -27,8 +27,7 @@ def maybe_enable_profiling(config: JobConfig, *pos_args, **kwargs):
     trace_dir = os.path.join(dump_dir, save_trace_dir)
     profile_freq = config.profiling.profile_freq

-    _global_iter_count = 0
-
+    _global_iter_count = global_step
```
A reviewer (Contributor) commented on this diff:
This does not sync with the profiler's internal step_num state. Let's remove it and replace all appearances of _global_iter_count with prof.step_num in trace_handler.

Also, torch_profiler.step_num = global_step needs to be set on line 71.
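
For readers outside the review, here is a minimal sketch of the approach the comment describes; the function signature, schedule values, and trace-directory naming below are illustrative assumptions, not the actual torchtitan code. The idea is that the profiler's internal step_num is seeded from the checkpointed global step, and the trace handler reads prof.step_num instead of a separate local counter.

```python
# Minimal sketch, not the torchtitan implementation: signature, schedule values,
# and trace-directory naming are assumptions for illustration only.
import contextlib
import os

import torch


@contextlib.contextmanager
def maybe_enable_profiling(trace_dir: str, profile_freq: int, global_step: int = 0):
    def trace_handler(prof):
        # prof.step_num is the profiler's internal counter; once seeded below it
        # matches the training loop's global step, so trace folders get the right name.
        curr_trace_dir = os.path.join(trace_dir, f"iteration_{prof.step_num}")
        os.makedirs(curr_trace_dir, exist_ok=True)
        prof.export_chrome_trace(os.path.join(curr_trace_dir, "trace.json"))

    with torch.profiler.profile(
        activities=[
            torch.profiler.ProfilerActivity.CPU,
            torch.profiler.ProfilerActivity.CUDA,
        ],
        schedule=torch.profiler.schedule(wait=profile_freq - 2, warmup=1, active=1),
        on_trace_ready=trace_handler,
    ) as torch_profiler:
        # Sync the internal counter with the checkpointed train step, per the comment above.
        torch_profiler.step_num = global_step
        yield torch_profiler
```

The training loop would then call torch_profiler.step() once per iteration, and the exported trace directory reflects the true global step even after a resume.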

@fegin fegin requested a review from tianyu-l May 2, 2024 22:17
@tianyu-l tianyu-l (Contributor) left a comment:

LGTM

Note: if resuming from a global_step such that global_step % profile_freq == profile_freq - 1, there won't be a profile trace at step global_step + 1, which would have been produced if not resuming from a checkpoint.

Essentially this is caused by the profiler not maintaining a state dict. The miss could be avoided by implementing save & load functions for it, but that doesn't seem worth the effort.
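
As a rough illustration of what "implementing save & load functions" for the profiler could mean, here is a hedged sketch; the wrapper class and key name are hypothetical and nothing like this exists in the PR:

```python
# Hypothetical sketch only; this wrapper is not part of the PR or of torchtitan.
class ProfilerCheckpointWrapper:
    def __init__(self, torch_profiler):
        self.torch_profiler = torch_profiler

    def state_dict(self) -> dict:
        # Persist the internal counter so a resumed run stays in the same schedule phase.
        return {"step_num": self.torch_profiler.step_num}

    def load_state_dict(self, state_dict: dict) -> None:
        self.torch_profiler.step_num = state_dict["step_num"]
```

Checkpointing step_num this way would preserve the wait/warmup/active phase across a resume and close the edge case above, at the cost of wiring the profiler into the checkpoint manager.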

@fegin fegin (Contributor, Author) commented May 3, 2024

In some trainers, profiling is designed to run only once, and the global step is used to prevent profiling from happening again after resuming from a checkpoint.
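
A tiny illustration of that pattern; the helper function and the chosen step are hypothetical, not taken from any particular trainer:

```python
# Hypothetical illustration of the "profile only once" pattern mentioned above.
def should_profile(global_step: int, profile_step: int = 10) -> bool:
    # Because the comparison uses the checkpointed global step, a run resumed
    # past profile_step never profiles again.
    return global_step == profile_step
```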

@fegin fegin merged commit 695bd01 into main on May 3, 2024
tianyu-l pushed a commit to tianyu-l/torchtitan_intern24 that referenced this pull request Aug 16, 2024
Fix the incorrect step log for profiler after resuming from a checkpoint (pytorch#293)

philippguevorguian pushed a commit to YerevaNN/YNNtitan that referenced this pull request Aug 17, 2024
Fix the incorrect step log for profiler after resuming from a checkpoint (pytorch#293)
