What this cast is doing:

* reshape the tensor into shape `(-1, block_size)`, where `block_size` is usually 32 or 16
* for each block, calculate a single scale, and then cast that block to `torch.float8_e4m3fn`
* return the cast elements and the scales

We really should do all of this in one kernel, but today we see two kernels.

How to reproduce (requires the latest main branch):

```
TORCH_LOGS_FORMAT=short TORCH_LOGS=aot_graphs,output_code python benchmarks/float8/profile_lowp_training.py ~/local/tmp/20250223_test --mx_recipe_name mxfp8_emulated --experiment_filter lowp --mode_filter cast_only
```

Output logs: https://gist.github.com/vkuzo/ce205fde5ae6b0fc223892c8a46560d4 (we currently see two kernels)