
[float8] Investigate if workaround for slow cutlass rowwise GEMM when fast_accum=False is still needed after perf improvements and potentially optimize GEMM further #2184

Open
@danielvegamyhre

Description

For float8 training with rowwise scales and fast_accum=False, the cutlass rowwise GEMM was so slow that we instead use the cublas tensorwise GEMM and apply the input tensor scales to the output tensor manually (context).

With torch.compile, this unfortunately requires inductor to codegen at least one extra triton kernel per GEMM just for the rescaling, which results in inefficient data movement between HBM and SRAM, since the GEMM should handle the output rescaling natively in its epilogue.
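To make the shape of the two paths concrete, here is a minimal sketch, assuming PyTorch 2.5+ where `torch._scaled_mm` accepts both tensorwise and rowwise float32 scales; the tensor names and shapes are illustrative, not the actual torchao code:

```python
import torch

M, K, N = 4096, 4096, 4096
a = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
b = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn)
a_scale = torch.rand(M, 1, device="cuda")  # rowwise scales for A, float32
b_scale = torch.rand(1, N, device="cuda")  # rowwise scales for B, float32

# Native path: the cutlass rowwise kernel applies both scales in its epilogue.
out_native = torch._scaled_mm(
    a, b.t(), scale_a=a_scale, scale_b=b_scale,
    out_dtype=torch.bfloat16, use_fast_accum=False,
)

# Workaround path: run the cublas tensorwise GEMM with unit scales, then
# rescale the output manually. Under torch.compile the final line is the
# extra elementwise triton kernel that re-reads and rewrites the full output.
one = torch.tensor(1.0, device="cuda")
out_unscaled = torch._scaled_mm(
    a, b.t(), scale_a=one, scale_b=one,
    out_dtype=torch.bfloat16, use_fast_accum=False,
)
out_workaround = (out_unscaled.float() * a_scale * b_scale).to(torch.bfloat16)
```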

At the time, this workaround was the faster option, but some perf improvements to the cutlass rowwise GEMM have landed this half: pytorch/pytorch#144809

We should investigate whether perf is now good enough to remove this workaround. If not, we should optimize the rowwise GEMM further and eliminate the need for these extra rescaling kernels.
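One way to check is a micro-benchmark of the two paths at a few GEMM shapes. This is a sketch under the same assumptions as above (PyTorch 2.5+, H100-class hardware); the shapes are placeholders, and real training shapes would be more informative:

```python
import torch
from torch.utils import benchmark

def bench_us(fn):
    # Median runtime in microseconds via torch.utils.benchmark.
    return benchmark.Timer(stmt="fn()", globals={"fn": fn}).timeit(50).median * 1e6

def native_rowwise(a, b, a_scale, b_scale):
    # cutlass rowwise GEMM with scales applied in the epilogue.
    return torch._scaled_mm(
        a, b.t(), scale_a=a_scale, scale_b=b_scale,
        out_dtype=torch.bfloat16, use_fast_accum=False,
    )

def tensorwise_plus_rescale(a, b, a_scale, b_scale):
    # cublas tensorwise GEMM with unit scales, then a manual elementwise rescale.
    one = torch.tensor(1.0, device=a.device)
    out = torch._scaled_mm(
        a, b.t(), scale_a=one, scale_b=one,
        out_dtype=torch.bfloat16, use_fast_accum=False,
    )
    return (out.float() * a_scale * b_scale).to(torch.bfloat16)

for M, K, N in [(4096, 4096, 4096), (8192, 8192, 8192), (16384, 8192, 1280)]:
    a = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
    b = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn)
    a_scale = torch.rand(M, 1, device="cuda")
    b_scale = torch.rand(1, N, device="cuda")
    t_row = bench_us(lambda: native_rowwise(a, b, a_scale, b_scale))
    t_work = bench_us(lambda: tensorwise_plus_rescale(a, b, a_scale, b_scale))
    print(f"{M}x{K}x{N}: rowwise={t_row:.1f}us  tensorwise+rescale={t_work:.1f}us")
```

Note that the eager-mode numbers above understate the workaround's cost under torch.compile, where the rescale appears as an additional codegen'd kernel in every matmul's backward/forward, so a trace-level comparison on a real training step would also be useful.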

cc @drisspg @vkuzo
