Description
For float8 training with rowwise scales, the cutlass rowwise GEMM was so slow with fast_accum=False that we fall back to the cublas tensorwise GEMM and apply the input tensor scales to the output tensor manually (context).
With torch.compile this unfortunately requires inductor to codegen at least one extra triton kernel per GEMM, launched just for the rescaling, which results in inefficient data movement between HBM and SRAM, since the GEMM should natively handle the output rescaling.
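A minimal sketch of the two paths, assuming the `torch._scaled_mm` dequant-scale convention (the function names, shapes, and exact rescaling below are illustrative, not the actual torchao implementation):

```python
import torch

def rowwise_mm_workaround(a_fp8, b_fp8, a_scale, b_scale, out_dtype=torch.bfloat16):
    # Workaround path (illustrative): run the GEMM with unit tensorwise scales
    # (dispatches to cuBLAS), then rescale the output by the rowwise scales.
    # a_fp8: (M, K) row-major float8, a_scale: (M, 1) fp32
    # b_fp8: (K, N) column-major float8, b_scale: (1, N) fp32
    one = torch.ones(1, device=a_fp8.device, dtype=torch.float32)
    out = torch._scaled_mm(
        a_fp8, b_fp8, scale_a=one, scale_b=one,
        out_dtype=out_dtype, use_fast_accum=False,
    )
    # This elementwise rescale is the extra kernel inductor has to codegen,
    # costing an additional read/write of the full output tensor in HBM.
    return out * a_scale.to(out_dtype) * b_scale.to(out_dtype)

def rowwise_mm_native(a_fp8, b_fp8, a_scale, b_scale, out_dtype=torch.bfloat16):
    # Native path: the cutlass rowwise GEMM applies the scales in its epilogue,
    # so no extra rescaling kernel is needed.
    return torch._scaled_mm(
        a_fp8, b_fp8, scale_a=a_scale, scale_b=b_scale,
        out_dtype=out_dtype, use_fast_accum=False,
    )
```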
At the time, this workaround was faster, but some perf improvements to the cutlass rowwise GEMM have landed this half: pytorch/pytorch#144809
We should investigate if perf is good enough to remove this workaround now. If not, we should optimize this rowwise GEMM further and eliminate the need for these extra rescaling kernels.
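A quick way to start this investigation, reusing the two sketch functions above (the shape and harness are assumptions; a real comparison should sweep the GEMM shapes seen in training):

```python
import torch
from torch.utils.benchmark import Timer

# Hypothetical shape for illustration only.
M, K, N = 8192, 4096, 8192
a = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
b = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn).t()  # column-major, as _scaled_mm requires
a_scale = torch.rand(M, 1, device="cuda", dtype=torch.float32)
b_scale = torch.rand(1, N, device="cuda", dtype=torch.float32)

def bench_us(fn):
    t = Timer(stmt="fn()", globals={"fn": fn})
    return t.blocked_autorange(min_run_time=1.0).median * 1e6

print("cutlass rowwise, fast_accum=False (us):",
      bench_us(lambda: rowwise_mm_native(a, b, a_scale, b_scale)))
print("cublas tensorwise + manual rescale (us):",
      bench_us(lambda: rowwise_mm_workaround(a, b, a_scale, b_scale)))
```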