During large-scale runs, I found that compiled_rmsnorm produced aberrant loss curves compared to TP or async TP with rmsnorm.
I verified that this reproduces at small scale, so I am opening this issue for tracking.
Compiled rmsnorm works fine under 1D FSDP.
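When comparing a compiled or sharded norm against a known-good baseline, a plain reference of the computation can help narrow down where the numerics diverge. Below is a minimal sketch in pure Python, assuming the standard RMSNorm definition `x / sqrt(mean(x^2) + eps) * weight`. The names `rms_norm` and `max_abs_diff` and the default `eps` are illustrative and are not taken from the repository.

```python
import math


def rms_norm(x, weight, eps=1e-6):
    """Reference RMSNorm over a single vector (hypothetical helper).

    Computes x / sqrt(mean(x^2) + eps) * weight, elementwise.
    """
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv_rms * w for v, w in zip(x, weight)]


def max_abs_diff(a, b):
    """Largest elementwise absolute difference between two equal-length vectors."""
    return max(abs(p - q) for p, q in zip(a, b))


# Example: normalize a small vector with unit weights.
reference = rms_norm([3.0, 4.0], [1.0, 1.0], eps=0.0)
```

The output of another implementation on the same input can then be passed to `max_abs_diff` together with `reference` to see how far it drifts from the baseline.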
An easy repro is to run the debug model for 50 steps with compiled rms_norm under TP, versus TP with plain rms_norm:
fsdp + compile + rmsnorm: