
[Performance]: Custom fused kernel tracking #25179

@ProExpertProg

Description

We need various additional fused kernels so that fusion can happen orthogonally to other features. For example, we want to support more kinds of quantization with the existing passes, and we want quantized KV cache to be orthogonal to fusion. Some of these kernels will require new passes, and most will require new patterns in existing passes.
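
For concreteness, here is a minimal PyTorch sketch (illustrative only, not vLLM code; the function name and the e4m3 dtype choice are assumptions) of the kind of two-op pattern a fusion pass matches and rewrites into a single fused kernel call, using rms_norm + static fp8 quant as the example:

```python
import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3

def rms_norm_then_static_fp8_quant(x: torch.Tensor, weight: torch.Tensor,
                                   scale: torch.Tensor, eps: float = 1e-6):
    # Op 1: RMSNorm over the hidden dimension.
    var = x.float().pow(2).mean(dim=-1, keepdim=True)
    y = x.float() * torch.rsqrt(var + eps) * weight.float()
    # Op 2: static fp8 quant with a precomputed per-tensor scale.
    q = (y / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    # A fusion pass replaces this two-op pattern with a single fused kernel
    # that reads x once and writes q once, never materializing y in bf16.
    return q
```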

Flashinfer (B200)

Currently available (and integrated):

  • silu_mul + nvfp4
  • fp8 attention (decode) + static fp8 quant
  • fp8 attention (prefill) + static fp8 quant
  • fp8 attention (decode) + nvfp4 quant
  • fp8 attention (prefill) + nvfp4 quant
  • all_reduce + rms_norm
  • all_reduce + rms_norm + static fp8 quant
  • all_reduce + rms_norm + nvfp4 quant

Currently available (and not integrated):

  • ? bf16 attention (decode) + static fp8 quant
  • ? bf16 attention (decode) + nvfp4 quant
  • rope + static fp8 quant
  • rope + kvcache
  • rope + static fp8 quant + kvcache
  • ? mla attention (prefill) + static fp8 quant
  • ? mla attention (prefill) + nvfp4 quant

I don't remember whether bf16 attention with fused output quant was supported for prefill or decode, but I think it was only one of the two. @pavanimajety is already working on integrating RoPE fusion. I also recall that MLA prefill uses the MHA kernel, so this should be supported already, but somebody mentioned that some cubins might need to be regenerated.

Needed (easy integration):

  • rms_norm + nvfp4
  • ? bf16 attention (prefill) + static fp8 quant
  • ? bf16 attention (prefill) + nvfp4 quant
  • all_reduce + rms_norm + dynamic per-token fp8 quant
  • all_reduce + rms_norm + dynamic group fp8 quant

Needed (lower urgency):

  • mla attention (decode) + static fp8 quant & nvfp4 quant
    • integration requires splitting decode/prefill out of the MLA custom op and quantizing the BMMs
  • mla attention (prefill) + dynamic fp8 quant
    • This is likely harder to fuse into attention, but it should be possible if the attention output has group_size elements (up to 128) available at a time; see the group-quant sketch after this list
    • @bringlein proposed calculating the scale in a way similar to softmax, and we might be able to reuse that to do full-row dynamic per-token quant
  • mla attention (prefill) + dynamic group fp8 quant
  • all_reduce + rms_norm + dynamic per-tensor fp8 quant (low priority)
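
As a reference for the dynamic group quant items above, here is a minimal sketch of the assumed semantics (illustrative only; groups taken along the last dim, scale derived from each group's absmax):

```python
import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3

def dynamic_group_fp8_quant(x: torch.Tensor, group_size: int = 128):
    """One scale per contiguous group of `group_size` elements along the last
    dim (assumes the hidden dim is divisible by group_size)."""
    shape = x.shape
    g = x.float().reshape(*shape[:-1], shape[-1] // group_size, group_size)
    # Per-group absmax -> scale; an online max tracked alongside softmax
    # (as proposed above) would make the same quantity available per row.
    scales = g.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / FP8_MAX
    q = (g / scales).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    return q.reshape(shape), scales.squeeze(-1)
```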

Other hardware (Hopper, AMD, etc.)

I think the goal for expanded support should be custom fused Triton (or even Helion) kernels, for better portability and to avoid a combinatorial explosion of platform-specific kernels; that said, if there are platform-specific kernels to be contributed, they are certainly extremely welcome. Platform-specific kernels might still be required for collectives (or can we leverage Triton-Distributed?).
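
As a rough sketch of what such a kernel could look like, here is a fused silu_mul + static fp8 quant in Triton (illustrative only, not vLLM's kernel; it assumes a (num_tokens, 2*d) gate-up input layout, a precomputed per-tensor scale, and an e4m3 output):

```python
import torch
import triton
import triton.language as tl

@triton.jit
def _silu_mul_fp8_kernel(x_ptr, out_ptr, scale_ptr, d,
                         stride_x, stride_out, BLOCK: tl.constexpr):
    # One program per token row; each row of x holds [gate | up], width 2*d.
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK)
    mask = cols < d
    gate = tl.load(x_ptr + row * stride_x + cols, mask=mask, other=0.0).to(tl.float32)
    up = tl.load(x_ptr + row * stride_x + d + cols, mask=mask, other=0.0).to(tl.float32)
    y = gate * tl.sigmoid(gate) * up              # SiLU(gate) * up
    inv_scale = 1.0 / tl.load(scale_ptr)          # static (precomputed) scale
    yq = tl.maximum(tl.minimum(y * inv_scale, 448.0), -448.0)  # clamp to e4m3 range
    tl.store(out_ptr + row * stride_out + cols, yq.to(tl.float8e4nv), mask=mask)

def silu_mul_fp8(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    n_tokens, two_d = x.shape
    d = two_d // 2
    out = torch.empty((n_tokens, d), device=x.device, dtype=torch.float8_e4m3fn)
    _silu_mul_fp8_kernel[(n_tokens,)](x, out, scale, d,
                                      x.stride(0), out.stride(0),
                                      BLOCK=triton.next_power_of_2(d))
    return out
```

Whether a hand-written kernel like this actually beats what torch Inductor already generates for such elementwise chains is exactly the caveat in the last "Wanted" section below.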

Current:

  • attention + static fp8 quant fusion (Triton/ROCm backends)
    • we should make sure we maintain support for these in both the Triton and ROCm kernels (cc @tdoublep @bringlein)

Wanted:

  • all_reduce + rms_norm (+ quant: all flavors)
  • GEMM + static fp8 quant
    • Dynamic quant might be harder; again, group quant could be easier than per-token
    • Can we do this with our existing CUTLASS kernel epilogues? (An unfused reference for the output-quant semantics is sketched after this list.)
  • mla attention + quant
  • aiter attention + quant
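
For the GEMM + static fp8 quant item, here is an unfused reference of the output-quant semantics the epilogue fusion would have to reproduce (plain PyTorch, illustrative only):

```python
import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3

def gemm_then_static_fp8_quant(a: torch.Tensor, b: torch.Tensor,
                               out_scale: torch.Tensor) -> torch.Tensor:
    # Unfused reference: GEMM in bf16, then static fp8 quant of the output.
    y = a @ b
    q = (y.float() / out_scale).clamp(-FP8_MAX, FP8_MAX)
    # The fusion goal is to fold this divide/clamp/cast into the GEMM epilogue
    # (e.g. a CUTLASS epilogue) so the output is written to memory once,
    # already in fp8, instead of taking a bf16 round trip.
    return q.to(torch.float8_e4m3fn)
```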

Wanted (only if faster than torch Inductor):

  • rms_norm + quant
  • silu_mul + quant

cc @pavanimajety @gshtras @kushanam @nvpohanh @mgoin @zou3519 @BoyuanFeng @tlrmchlsmth @LucasWilkinson
