
torch.compile cast to mxfp8 should only require one kernel #1769

@vkuzo (Contributor)

Description


What this cast is doing

  • reshape the tensor into shape of (-1, block_size), where block_size is usually 32 or 16
  • for each block, calculate a single scale, and then cast that block to torch.float8_e4m3fn
  • return the casted elements and the scale

We really should do this all in one kernel, but today we see two kernels (a rough sketch of the cast is below).
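
For reference, a minimal PyTorch sketch of the per-block cast described above. The function name is hypothetical and the scale computation is simplified to a plain amax-derived fp32 scale (the actual mxfp8 recipe uses a power-of-two E8M0 scale); it is only meant to illustrate the reshape -> per-block scale -> float8 cast pattern.

    import torch

    # Hypothetical sketch of the cast described above; not the torchao implementation.
    def cast_to_mxfp8_sketch(x: torch.Tensor, block_size: int = 32):
        orig_shape = x.shape
        # reshape the tensor into shape (-1, block_size)
        x_blocks = x.reshape(-1, block_size)
        # one scale per block, derived from the block's max absolute value
        amax = x_blocks.abs().amax(dim=1, keepdim=True)
        f8_max = torch.finfo(torch.float8_e4m3fn).max
        scale = amax.clamp(min=1e-12) / f8_max
        # scale, clamp to the representable range, and cast each block to float8_e4m3fn
        x_f8 = (x_blocks / scale).clamp(-f8_max, f8_max).to(torch.float8_e4m3fn)
        # return the casted elements (in the original shape) and the per-block scales
        return x_f8.reshape(orig_shape), scale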

How to reproduce (requires latest main branch)

    TORCH_LOGS_FORMAT=short TORCH_LOGS=aot_graphs,output_code python benchmarks/float8/profile_lowp_training.py ~/local/tmp/20250223_test --mx_recipe_name mxfp8_emulated --experiment_filter lowp --mode_filter cast_only

Output logs: https://gist.github.com/vkuzo/ce205fde5ae6b0fc223892c8a46560d4 - we currently see two kernels

Activity

eellison (Contributor) commented on Feb 26, 2025

If you change the view to occur at the end, you get a single fused kernel and a 2.5x speedup. I am going to look into making this fusion happen automatically in inductor; I still need to scope out what changes are involved. But maybe this works as a manual workaround for now? The relevant part of the aot graph changes from

    view_1: "f32[2048, 4096][4096, 1]cuda:0" = torch.ops.aten.view.default(clamp_max_1, [2048, 4096]);  clamp_max_1 = None
    convert_element_type_4: "f8e4m3fn[2048, 4096][4096, 1]cuda:0" = torch.ops.prims.convert_element_type.default(view_1, torch.float8_e4m3fn);  view_1 = None
    return (where, convert_element_type_4)

to

    convert_element_type_4: "f8e4m3fn[8388608][1]cuda:0" = torch.ops.prims.convert_element_type.default(clamp_max_1, torch.float8_e4m3fn);  clamp_max_1 = None
    view_1: "f8e4m3fn[2048, 4096][4096, 1]cuda:0" = torch.ops.aten.view.default(convert_element_type_4, [2048, 4096]);  convert_element_type_4 = None
    return (where, view_1)
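
In user-level terms, the workaround amounts to casting before reshaping instead of after. A hypothetical sketch of the reordering (the function name and shape argument are made up; the actual change would live in torchao's casting code):

    import torch

    def cast_blocks_to_f8(x_scaled: torch.Tensor, out_shape):
        # before: view first, then cast -> two kernels observed
        #   return x_scaled.view(out_shape).to(torch.float8_e4m3fn)
        # after: cast first, then view -> a single fused kernel (~2.5x faster per the comment above)
        return x_scaled.to(torch.float8_e4m3fn).view(out_shape)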