
compiled rms_norm not numerically accurate (fails to produce good loss curves) when run under tp #497

@lessw2020


During large-scale runs, I found that compiled_rmsnorm produced aberrant loss curves compared with TP or async TP runs using plain rmsnorm.
I verified that this reproduces at small scale, so I'm opening this issue for tracking.
Compiled rmsnorm works fine under 1D FSDP.
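One way to narrow down where the TP path goes wrong is to check a single row of the compiled kernel's output against a plain high-precision reference. Below is a minimal sketch, assuming you can dump one input row, the norm weight, and the corresponding output as flat float lists. `rms_norm_reference` and `max_abs_error` are hypothetical helpers, not torchtitan APIs, and the default `eps=1e-6` is an assumption that should be matched to the model's actual setting.

```python
import math
from typing import Sequence


def rms_norm_reference(
    x: Sequence[float], weight: Sequence[float], eps: float = 1e-6
) -> list[float]:
    """Standard RMSNorm over one row: y_i = x_i / sqrt(mean(x^2) + eps) * w_i.

    Uses Python floats (float64) and math.fsum, so it serves as a
    higher-precision baseline than a bf16/fp32 kernel.
    """
    if len(x) != len(weight):
        raise ValueError("x and weight must have the same length")
    mean_sq = math.fsum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv_rms * w for v, w in zip(x, weight)]


def max_abs_error(actual: Sequence[float], expected: Sequence[float]) -> float:
    """Largest elementwise absolute difference between two equal-length rows."""
    if len(actual) != len(expected):
        raise ValueError("length mismatch")
    return max((abs(a - e) for a, e in zip(actual, expected)), default=0.0)
```

Running the same check on the TP and non-TP outputs for an identical input row would show whether the discrepancy is already present in the forward pass of the norm itself.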

An easy repro is to run the debug model for 50 steps with compiled rmsnorm under TP, versus TP with plain rmsnorm:
fsdp + compile + rmsnorm:
[Screenshot: loss curve, 2024-07-31 11:51:47 AM]

fsdp + tp (2) + compile + rmsnorm:
[Screenshot: loss curve, 2024-07-31 11:52:26 AM]

The issue shows up here:
fsdp + tp (2) + compiled rmsnorm:
[Screenshot: loss curve, 2024-07-31 11:52:09 AM]

Finally, compiled rmsnorm + fsdp + compile (1D) works fine:
[Screenshot: loss curve, 2024-07-31 12:51:24 PM]
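To compare runs like the ones above without eyeballing screenshots, the per-step losses from two training logs can be checked for the first point where they drift apart. This is a minimal sketch: `first_divergent_step` is a hypothetical helper, the loss lists are assumed to be parsed from the logs of the two runs, and the 5% relative tolerance is an arbitrary choice.

```python
import math
from typing import Optional, Sequence


def first_divergent_step(
    baseline: Sequence[float], candidate: Sequence[float], rel_tol: float = 0.05
) -> Optional[int]:
    """Return the 0-based index of the first step where the two loss curves
    differ by more than rel_tol (math.isclose semantics, i.e. relative to the
    larger magnitude), or None if they agree over their common length.

    A non-finite candidate loss (NaN or inf) always counts as a divergence.
    """
    for step, (b, c) in enumerate(zip(baseline, candidate)):
        if not math.isfinite(c) or not math.isclose(b, c, rel_tol=rel_tol):
            return step
    return None
```

Pointing this at the "fsdp + tp (2) + compile + rmsnorm" losses as the baseline and the "fsdp + tp (2) + compiled rmsnorm" losses as the candidate would give the step at which the compiled path starts to deviate.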

Labels: bug