torch.compile cast to mxfp8 should only require one kernel #1769

Closed
vkuzo opened this issue Feb 24, 2025 · 1 comment · Fixed by #1786

@vkuzo
Contributor

vkuzo commented Feb 24, 2025

What this cast is doing

  • reshape the tensor into shape (-1, block_size), where block_size is usually 32 or 16
  • for each block, calculate a single scale, then cast that block to torch.float8_e4m3fn
  • return the cast elements and the scales

We really should do all of this in a single kernel, but today torch.compile generates two kernels.
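
For reference, here is a minimal eager-mode sketch of the cast described above (my own illustration, not torchao's implementation; real MX scales are additionally constrained to power-of-two e8m0 values, which this simplified version ignores):

    import torch

    def cast_to_mxfp8_sketch(x: torch.Tensor, block_size: int = 32):
        # reshape the tensor into (-1, block_size) blocks
        x_blocks = x.reshape(-1, block_size)
        # one scale per block: map the block amax to the float8_e4m3fn max (448.0)
        f8_max = torch.finfo(torch.float8_e4m3fn).max
        amax = x_blocks.abs().amax(dim=1, keepdim=True)
        scale = (amax / f8_max).clamp(min=1e-12)  # guard against all-zero blocks
        # scale each block, clamp to the representable range, and cast
        x_fp8 = (x_blocks / scale).clamp(-f8_max, f8_max).to(torch.float8_e4m3fn)
        # return the cast elements (restored to the original shape) and the scales
        return x_fp8.reshape(x.shape), scale

    # compiling this function is the case that should ideally produce one kernel
    compiled_cast = torch.compile(cast_to_mxfp8_sketch)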

How to reproduce (requires latest main branch)

TORCH_LOGS_FORMAT=short TORCH_LOGS=aot_graphs,output_code python benchmarks/float8/profile_lowp_training.py ~/local/tmp/20250223_test --mx_recipe_name mxfp8_emulated --experiment_filter lowp --mode_filter cast_only

Output logs: https://gist.github.com/vkuzo/ce205fde5ae6b0fc223892c8a46560d4 - the generated code currently contains two kernels.

@eellison
Contributor

If you change the view to occur at the end, you do get a single fused kernel and a 2.5x speedup. I am going to look into making the fusion occur automatically in inductor; I still need to scope out what changes are involved. But maybe this is possible as a manual workaround for now? The change is from

    view_1: "f32[2048, 4096][4096, 1]cuda:0" = torch.ops.aten.view.default(clamp_max_1, [2048, 4096]);  clamp_max_1 = None
    convert_element_type_4: "f8e4m3fn[2048, 4096][4096, 1]cuda:0" = torch.ops.prims.convert_element_type.default(view_1, torch.float8_e4m3fn);  view_1 = None
    return (where, convert_element_type_4)

to

    convert_element_type_4: "f8e4m3fn[8388608][1]cuda:0" = torch.ops.prims.convert_element_type.default(clamp_max_1, torch.float8_e4m3fn);  clamp_max_1 = None
    view_1: "f8e4m3fn[2048, 4096][4096, 1]cuda:0" = torch.ops.aten.view.default(convert_element_type_4, [2048, 4096]);  convert_element_type_4 = None
    return (where, view_1)
