Closed
Description
What this cast is doing
- reshape the tensor into shape (-1, block_size), where block_size is usually 32 or 16
- for each block, calculate a single scale, then cast that block to torch.float8_e4m3fn
- return the cast elements and the scales
We really should do this all in one kernel, but today we see two kernels.
How to reproduce (requires latest main branch)
TORCH_LOGS_FORMAT=short TORCH_LOGS=aot_graphs,output_code python benchmarks/float8/profile_lowp_training.py ~/local/tmp/20250223_test --mx_recipe_name mxfp8_emulated --experiment_filter lowp --mode_filter cast_only
Output logs: https://gist.github.com/vkuzo/ce205fde5ae6b0fc223892c8a46560d4 - we currently see two kernels
eellison commented on Feb 26, 2025
If you change the view to occur at the end, you do get a single fused kernel, and a 2.5x speedup. I am going to look into making the fusion occur automatically in inductor. Still need to scope out what changes are involved. But maybe this is possible as a manual change workaround for now?