-
Notifications
You must be signed in to change notification settings - Fork 1.8k
[None][feat] add Hopper FP8 context MLA #7107
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Caution Review failedThe pull request is closed. 📝 WalkthroughWalkthroughExpands FMHA kernel enumeration and metadata to include SEPARATE_Q_K_V and new 192x128 context MLA variants (with output_dtype options), adjusts TMA store gating and kernel traits propagation, introduces a CUBIN-aware per-CUBIN QKV descaling factor and unified packed store path, fixes V-tile transpose indexing, and broadens FP8 MLA test gating to include SM90. Changes
Sequence Diagram(s)sequenceDiagram
autonumber
participant Build as Kernel Enumeration
participant Enum as enumerate_qgmma_flash_warpspec_kernels
participant Traits as Kernel_traits
participant Cubin as CUBIN Metadata
Build->>Enum: call(sm, dtype, output_dtype, layouts...)
Note over Enum: include InputLayout.SEPARATE_Q_K_V and 192x128 context MLA
Enum->>Traits: compute/instantiate traits (propagate RETURN_SOFTMAX_STATS_, OutputType, SAGE sizes)
Note over Enum: enable_tma_store checks output_dtype and head_size%16==0
Enum->>Cubin: emit kernel declarations/metadata for 192x128 (e4m3 + bf16-output)
sequenceDiagram
autonumber
participant Kernel as Gmem_tile_o_qgmma_fp32_16bits
participant Params as Params (runtime)
participant Packer as Acc_packer
participant GM as GlobalMemory
Kernel->>Params: initialize params_scale_bmm2_ (CUBIN-aware)
Kernel->>Kernel: store(accumulators...)
Kernel->>Packer: Acc_packer<Src,Out,Scale>::run(_src, params_scale_bmm2_)
Packer-->>Kernel: packed uint2 (_dst)
Kernel->>GM: stg(_dst) @ computed offset
Estimated code review effort🎯 4 (Complex) | ⏱️ ~60 minutes Possibly related PRs
Suggested reviewers
Tip 🔌 Remote MCP (Model Context Protocol) integration is now available!Pro plan users can now connect to remote MCP servers from the Integrations page. Connect with popular remote MCPs such as Notion and Linear to add more context to your reviews and chats. 📜 Recent review detailsConfiguration used: .coderabbit.yaml 💡 Knowledge Base configuration:
You can enable these sources in your CodeRabbit configuration. 📒 Files selected for processing (59)
✨ Finishing Touches
🧪 Generate unit tests
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. 🪧 TipsChatThere are 3 ways to chat with CodeRabbit:
SupportNeed help? Create a ticket on our support page for assistance with any issues or questions. CodeRabbit Commands (Invoked using PR/Issue comments)Type Other keywords and placeholders
Status, Documentation and Community
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
cpp/kernels/fmha_v2/fmha_test.py (1)
1-1
: Missing NVIDIA copyright header (2025)Per coding guidelines, prepend the current-year NVIDIA copyright header to all source files.
Apply at the very top of the file:
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + # SPDX-License-Identifier: Apache-2.0
🧹 Nitpick comments (5)
cpp/kernels/fmha_v2/fmha_test.py (1)
168-169
: De-duplicate supported-arch lists for FP8 MLA gatingSupported SMs for FP8 context MLA are hardcoded here, while generation-phase FP8 MLA uses a different guard (Lines 213-215). To reduce churn when enabling more archs, centralize these into a single constant.
Apply this diff in-place:
- if dtype in ["-e4m3", "-e4m3 -bf16-output"] and sm_version not in [90, 120]: + if dtype in ["-e4m3", "-e4m3 -bf16-output"] and sm_version not in SUPPORTED_FP8_CONTEXT_MLA: pytest.skip("FP8 MLAs are only supported on sm90 and sm120 currently.")Then, add this near the top of the file (e.g., below imports) to define the shared constant:
SUPPORTED_FP8_CONTEXT_MLA = (90, 120)cpp/kernels/fmha_v2/src/fmha/hopper/gmem_tile_o_packed.h (1)
1267-1281
: Improved packing implementation with configurable scaling.The refactored store path using
Acc_packer
with template parameterScale
is cleaner and more maintainable than the previous per-element approach. The conditional scaling based onUNIFIED_EPILOGUE_SCALE
provides flexibility for different build configurations.Consider extracting the packing logic into a helper function to reduce macro usage:
-#define STORE_COLUMNS() \ - { \ - /* we assume M = 1. some shortcuts. */ \ - static_assert(M == 1); \ - uint4 _src = { \ - .x = acc[0][mma_ni].reg(((ci + 0) * ROWS_PER_THREAD + ri) * 2), \ - .y = acc[0][mma_ni].reg(((ci + 1) * ROWS_PER_THREAD + ri) * 2), \ - .z = acc[0][mma_ni].reg(((ci + 0) * ROWS_PER_THREAD + ri) * 2 + 1), \ - .w = acc[0][mma_ni].reg(((ci + 1) * ROWS_PER_THREAD + ri) * 2 + 1), \ - }; \ - uint2 _dst = Acc_packer<float, Output_type, Scale>::run(this, _src); \ - int64_t _offset = \ - (int64_t)ri * step_m + (int64_t)(ci + mma_ni * COLS_PER_THREAD) * STEP_N; \ - fmha::stg(o_ptr_ + _offset, _dst); \ - } + template<bool Scale> + inline __device__ void store_columns(auto const& acc, int mma_ni, int ci, int ri, int64_t step_m) { + static_assert(M == 1); + uint4 src = { + .x = acc[0][mma_ni].reg(((ci + 0) * ROWS_PER_THREAD + ri) * 2), + .y = acc[0][mma_ni].reg(((ci + 1) * ROWS_PER_THREAD + ri) * 2), + .z = acc[0][mma_ni].reg(((ci + 0) * ROWS_PER_THREAD + ri) * 2 + 1), + .w = acc[0][mma_ni].reg(((ci + 1) * ROWS_PER_THREAD + ri) * 2 + 1), + }; + uint2 dst = Acc_packer<float, Output_type, Scale>::run(this, src); + int64_t offset = (int64_t)ri * step_m + (int64_t)(ci + mma_ni * COLS_PER_THREAD) * STEP_N; + fmha::stg(o_ptr_ + offset, dst); + }Then replace
STORE_COLUMNS()
withstore_columns<Scale>(acc, mma_ni, ci, ri, step_m);
in the loops.cpp/kernels/fmha_v2/setup.py (3)
1917-1921
: TMA-store gating: consider row-byte alignment (16B) rather than head_size only; optionally broaden to bf16 when validatedUsing the actual output dtype here is correct. However, aligning on element count (head_size % 16) assumes 1B elements. TMA operates on 16B granularity, so the robust check is row_bytes % 16 == 0. That keeps behavior identical for FP8 today and future-proofs the condition. If/when bf16 output store via TMA is validated, you can safely include it by keeping the same row alignment check.
Apply this minimal generalization now (no behavior change for FP8), keeping bf16 disabled until you validate:
def enable_tma_store(kspec): - output_dtype = kspec.output_dtype if kspec.output_dtype is not None else kspec.dtype - # TMA copies data in the 16B granularity. - return 'true' if (output_dtype in ['e4m3', 'e4m3_fp32'] - and kspec.head_size % 16 == 0) else 'false' + output_dtype = kspec.output_dtype if kspec.output_dtype is not None else kspec.dtype + # TMA copies data in 16B granularity: require row-size (in bytes) to be a multiple of 16. + row_bytes = kspec.head_size * dtype2bytes[output_dtype] + return 'true' if (output_dtype in ['e4m3', 'e4m3_fp32'] and (row_bytes % 16 == 0)) else 'false'Optionally, once store path readiness for 16-bit outputs is confirmed, enable bf16 with the same alignment:
- return 'true' if (output_dtype in ['e4m3', 'e4m3_fp32'] and (row_bytes % 16 == 0)) else 'false' + return 'true' if (output_dtype in ['e4m3', 'e4m3_fp32', 'bf16'] and (row_bytes % 16 == 0)) else 'false'Would you like me to add a small guard (env flag) to toggle bf16 TMA-store at runtime for A/B perf validation without rebuilds?
3816-3818
: Broadened input-layout combinations for FP8 WS kernels: OK; consider trimming to avoid generating unneeded SEPARATE_Q_K_V variants outside MLAIncluding SEPARATE_Q_K_V in the general cartesian product is functionally fine, but most of those specs will be filtered out later by specs_names, adding enumeration noise. Optional: restrict the general pass to PACKED/CONTIGUOUS/Q_PAGED and handle SEPARATE_Q_K_V only in the MLA block below (lines 3932-3971).
Light refactor (paired with the MLA block tweak below) to avoid spec explosion:
- combinations = product([False, True], \ - [InputLayout.PACKED_QKV, InputLayout.CONTIGUOUS_Q_KV, - InputLayout.Q_PAGED_KV, InputLayout.SEPARATE_Q_K_V], - [False, True]) + combinations = product( + [False, True], # alibi + [InputLayout.PACKED_QKV, InputLayout.CONTIGUOUS_Q_KV, InputLayout.Q_PAGED_KV], + [False, True], # enable_attn_logit_softcapping + )This change should be combined with forcing SEPARATE_Q_K_V in the 192x128 MLA block (see comment at lines 3932-3971).
3932-3971
: 192x128 context MLA variants: force SEPARATE_Q_K_V here to avoid generating unused variantsGreat to see explicit 192x128 WS variants with output bf16 and default output. Since this block is exclusively for Deepseek context MLA (separate Q/K/V), pin the input layout here instead of inheriting it from the outer combinations. This removes many discarded specs and keeps intent explicit.
Apply this small tweak:
- for output_type in [None, 'bf16']: + for output_type in [None, 'bf16']: specs.append( kernel_spec( sm=sm, sm_mma=90, dtype=dtype, seq_len=0, # support any sequence length head_size=192, head_size_v=128, warps_m=4, #4x1 warpgroups warps_n=1, version=2, interleaved=False, ldgsts_q= False, # for Hopper kernels, ldgsts = False signals TMA usage. ldgsts_k=False, ldgsts_v=False, share_smem_k_v=False, loop_step=64, q_tile_buffers=1, # only used by warp specialized kernels has_noloop=0, noloop_step=64, kv_loop_step=128, kv_tile_buffers=2, # only used by warp specialized kernels unroll_threshold=1, has_scale_max=False, flash_attention=True, warp_specialization=True, alibi=alibi, enable_attn_logit_softcapping=enable_attn_logit_softcapping, return_softmax_stats= False, # return softmax stats is not supported for fp8 now scheduling_mode=scheduling_mode, - input_layout=input_layout, + input_layout=InputLayout.SEPARATE_Q_K_V, sage_block_sizes=sage_block_sizes, output_dtype=output_type))This pairs with limiting the general combinations to PACKED/CONTIGUOUS/Q_PAGED (see lines 3816-3818).
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (6)
cpp/kernels/fmha_v2/fmha_test.py
(1 hunks)cpp/kernels/fmha_v2/setup.py
(5 hunks)cpp/kernels/fmha_v2/src/fmha/hopper/gmem_tile_o_packed.h
(3 hunks)cpp/kernels/fmha_v2/src/fmha/warpspec/dma.h
(2 hunks)cpp/kernels/fmha_v2/src/fmha/warpspec/kernel_traits.h
(1 hunks)cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_cubin.h
(2 hunks)
🧰 Additional context used
📓 Path-based instructions (5)
**/*.py
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
**/*.py
: Python code must target Python 3.8+
Python indentation: 4 spaces, no tabs
Maintain module namespace in imports (from package.subpackage import foo; then use foo.SomeClass())
Python file names use snake_case
Python class names use PascalCase
Python functions/methods and local variables use snake_case; variables starting with a number get k_ prefix (e.g., k_99th_percentile)
Global variables use G_ prefixed UPPER_SNAKE_CASE (e.g., G_MY_GLOBAL)
Constants use UPPER_SNAKE_CASE in Python
Avoid shadowing variables from outer scopes in Python
Initialize all externally visible members of a Python class in init
Prefer docstrings for interfaces used outside a file; comments for local code
Use Google-style docstrings for classes and functions (Sphinx-parsable)
Document attributes/variables inline with short docstrings
Avoid reflection when simple alternatives exist (e.g., prefer explicit parameters over dict(**locals()))
In try/except, catch the narrowest exceptions possible
For duck-typing with try/except, keep try body minimal and put logic in else
Files:
cpp/kernels/fmha_v2/fmha_test.py
cpp/kernels/fmha_v2/setup.py
**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh,py}
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
Prepend NVIDIA copyright header (current year) to all source files
Files:
cpp/kernels/fmha_v2/fmha_test.py
cpp/kernels/fmha_v2/src/fmha/warpspec/dma.h
cpp/kernels/fmha_v2/src/fmha/warpspec/kernel_traits.h
cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_cubin.h
cpp/kernels/fmha_v2/src/fmha/hopper/gmem_tile_o_packed.h
cpp/kernels/fmha_v2/setup.py
**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh}
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh}
: In C++, close namespaces with a comment naming the namespace (e.g., } // namespace foo)
Prefer const/constexpr variables over #define for constants
Declare variables const if not modified after initialization
Use Allman brace style in C++
C++ filenames use lowerCamelCase and must be case-insensitively unique within a build target
C++ type names use UpperCamelCase
Local variables, methods, and namespaces use lowerCamelCase
Global non-static variables not in anonymous namespace use gPrefix lowerCamelCase (e.g., gExample)
Static globals or globals in anonymous namespaces use sPrefix lowerCamelCase
Locally visible static variables start with 's' (e.g., static std::once_flag sFlag;)
Member variables use mPrefix lowerCamelCase; public members may omit but are encouraged to use 'm'
Constants (enums, global/static/function-scope magic numbers) use kPREFIXED_UPPER_SNAKE (e.g., kDIGIT_NUM)
If macros are unavoidable, use UPPER_SNAKE_CASE (prefer constants over #define)
Constructor parameter that conflicts with a public member name gets trailing underscore (foo_)
Literal suffixes should be uppercase (e.g., 1234L not 1234l)
C++: use spaces only; indent 4 spaces
Run clang-format (LLVM style) before submitting; wrap lines at 120 characters
If formatting must be bypassed, use // clang-format off/on around the section
Prefer smart pointers; use unique_ptr for sole ownership, shared_ptr for shared; weak_ptr only in exceptional cases
Do not use deprecated pre-C++11 smart pointers
Use C++ style comments; avoid C comments except special inline cases; prefer // single-line
Capitalize and punctuate full-sentence comments
Follow Doxygen rules: use //! for comments and //!< for members in C++
Disable code with #if/#endif and mnemonic conditions; avoid commented-out code; avoid dead code
Do not throw exceptions across library boundaries
Use least-forceful casts; avoid removing const/volatile; avoid C-style and functional casts (except constructors); p...
Files:
cpp/kernels/fmha_v2/src/fmha/warpspec/dma.h
cpp/kernels/fmha_v2/src/fmha/warpspec/kernel_traits.h
cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_cubin.h
cpp/kernels/fmha_v2/src/fmha/hopper/gmem_tile_o_packed.h
**/*.{h,hpp,hxx,hh,cuh,cpp,cxx,cc,cu}
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
Parameter names must be consistent between declarations and definitions
Files:
cpp/kernels/fmha_v2/src/fmha/warpspec/dma.h
cpp/kernels/fmha_v2/src/fmha/warpspec/kernel_traits.h
cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_cubin.h
cpp/kernels/fmha_v2/src/fmha/hopper/gmem_tile_o_packed.h
**/*.{h,hpp,hxx,hh,cuh}
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
Header files must use include guards named TRTLLM__H without underscores prefix/suffix (e.g., TRTLLM_FOO_BAR_HELLO_H)
Files:
cpp/kernels/fmha_v2/src/fmha/warpspec/dma.h
cpp/kernels/fmha_v2/src/fmha/warpspec/kernel_traits.h
cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_cubin.h
cpp/kernels/fmha_v2/src/fmha/hopper/gmem_tile_o_packed.h
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
🔇 Additional comments (10)
cpp/kernels/fmha_v2/fmha_test.py (2)
168-169
: Enable FP8 context MLA on Hopper (sm90): LGTMAllowing FP8 MLAs to run on sm90 (in addition to sm120) matches the new context kernels added in this PR and the stated goal “Hopper FP8 context MLA.” No functional concerns with the guard.
213-215
: Confirm intent: gen-phase FP8 MLA still sm120-onlyContext-phase FP8 MLAs now allow sm90, but generation-phase remains restricted to sm120. If this is intentional (e.g., no tested/available sm90 gen kernels yet), great—please confirm. If sm90 gen is expected soon, consider reusing the same constant approach so the test matrix can be updated in one place when ready.
cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_cubin.h (1)
1974-1977
: The struct definition forFusedMultiHeadAttentionKernelMetaInfoV2
wasn’t found with the above searches. This absence makes it impossible to confirm the exact bitfield ordering and flag positions automatically. I recommend manually verifying the following:
Locate the
FusedMultiHeadAttentionKernelMetaInfoV2
struct (or its alias) in your codebase.Confirm the ordering of its boolean fields to ensure the “separate Q/K/V” and “paged KV” flags align with these kernel entries.
Once confirmed, add the inline comments shown in the diff hunk to document:
- dq = 192, dv = 128
- causal = 0/1
- output = BF16 (where applicable)
Optionally, introduce named
constexpr
values for164096
and384
near the top of this file or in a shared header:constexpr int kSM90_MLA_192X128_BLOCK_SIZE = 164096; constexpr int kSM90_MLA_SHARED_MEM_WORDS = 384;Then replace the magic numbers in these initializers with those constants.
Please confirm these field mappings and flag meanings to prevent future dispatch regressions.
cpp/kernels/fmha_v2/src/fmha/warpspec/kernel_traits.h (1)
592-593
: Changes to Base alias look correct.The updated template parameter list now properly passes through
RETURN_SOFTMAX_STATS_
,OutputType
, and the three SAGE block size parameters to match the baseKernel_traits
template. This ensures proper trait inheritance and type propagation.cpp/kernels/fmha_v2/src/fmha/warpspec/dma.h (1)
758-767
: Correct loop bounds update for V tile dimension.The change from
Kernel_traits::D_GROUPS
toKernel_traits::DV_GROUPS
properly aligns with the V tensor's dimension groups, and the offset calculation usingKernel_traits::DV
instead ofKernel_traits::D
is consistent with the V tile layout. This is essential for proper handling of different V dimensions in context MLA variants.cpp/kernels/fmha_v2/src/fmha/hopper/gmem_tile_o_packed.h (3)
658-663
: Conditional QKV descale factor initialization.The conditional initialization of
params_scale_bmm2_
appropriately handles both CUBIN and non-CUBIN builds. The CUBIN path checks for dynamic scaling viaparams.scale_bmm2_d
, falling back to the static value if not present.
1225-1232
: Consistent scale parameter initialization across constructors.Good to see the same conditional initialization pattern applied to the
Gmem_tile_o_qgmma_fp32_16bits
constructor, maintaining consistency with the 8-bit variant.
1318-1319
: Clear documentation added for new member.The comment clearly explains that
params_scale_bmm2_
represents the QKV descale factor, which helps with code maintainability.cpp/kernels/fmha_v2/setup.py (2)
3917-3918
: KV step 128 for 128 < D <= 256: good, but please confirm shared memory headroom on sm90 across tile buffersBumping kv_loop_step to 128 to mitigate register pressure is reasonable. Verify that for worst-case batch/head configurations the computed Ktraits::BYTES_PER_SMEM remains under Hopper’s per-CTA SMEM budget given kv_tile_buffers=2 and q_tile_buffers=1. The doc block above (lines 3759-3768) explains the math; worth re-checking 192x128 too.
Would you like a quick script to print the generated SMEM sizes from generated/print_kernel_traits.cu output for these configs to prevent accidental regressions?
6422-6434
: Allowing e4m3 for 192x128 SEPARATE_Q_K_V path: LGTMBroadening the allowed dtype set to include e4m3 (alongside bf16 and e4m3_fp32) for the context MLA path matches the new kernel variants and the PR test scenario. This should unblock FP8-context selection on SM90.
Make sure the runtime side (where run_fmha_v2 is called) passes output_data_type as bf16 when selecting the bf16-output variant; otherwise the default FP8-output kernel will be chosen. I can provide a small rg-based script to sanity-check call sites if helpful.
extern void run_fmha_v2_flash_attention_e4m3_64_128_S_q_k_v_192x128_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream); | ||
extern void run_fmha_v2_flash_attention_e4m3_64_128_S_q_k_v_192x128_output_bf16_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
💡 Verification agent
🧩 Analysis chain
New externs for 192x128 S_q_k_v (and BF16 output) look consistent; please verify out-of-line definitions exist.
Signatures follow existing conventions and appear correct. To avoid ODR/link issues, confirm these have exactly one non-extern definition and that parameter names match between declaration and definition (per guidelines).
Run this script from repo root to verify definitions and parameter-name consistency:
Optionally, if these are SM90-only implementations, consider guarding the declarations/uses with the same compile-time condition used for other SM90-only symbols for consistency.
🏁 Script executed:
#!/bin/bash
set -euo pipefail
# Declarations (should match these two)
decls=(
'run_fmha_v2_flash_attention_e4m3_64_128_S_q_k_v_192x128_tma_ws_sm90'
'run_fmha_v2_flash_attention_e4m3_64_128_S_q_k_v_192x128_output_bf16_tma_ws_sm90'
)
for sym in "${decls[@]}"; do
echo "== Checking: $sym =="
# Find declaration sites
rg -nP --type=cpp --type=cu $'extern\\s+void\\s+'"$sym"$'\\s*\\(\\s*Fused_multihead_attention_params_v2\\s*&\\s*params\\s*,\\s*const\\s*Launch_params\\s*&\\s*launch_params\\s*,\\s*cudaStream_t\\s*stream\\s*\\)' -C1
# Find definitions (non-extern)
rg -nP --type=cpp --type=cu $'(?<!extern)\\s+void\\s+'"$sym"$'\\s*\\(\\s*Fused_multihead_attention_params_v2\\s*&\\s*params\\s*,\\s*const\\s*Launch_params\\s*&\\s*launch_params\\s*,\\s*cudaStream_t\\s*stream\\s*\\)' -C3
done
Length of output: 669
🏁 Script executed:
#!/bin/bash
set -euo pipefail
# Ensure we search all relevant file extensions instead of relying on rg file types.
decls=(
'run_fmha_v2_flash_attention_e4m3_64_128_S_q_k_v_192x128_tma_ws_sm90'
'run_fmha_v2_flash_attention_e4m3_64_128_S_q_k_v_192x128_output_bf16_tma_ws_sm90'
)
for sym in "${decls[@]}"; do
echo "== Checking: $sym =="
echo "-- extern declarations in headers --"
rg -nP --glob '*.h' "extern\s+void\s+$sym\s*\([^)]*\)" -C1 || true
echo "-- definitions in .cpp/.cu sources (non-extern) --"
rg -nP --glob '*.cpp' --glob '*.cu' "(?<!extern)\s+void\s+$sym\s*\([^)]*\)" -C2 || true
echo
done
Length of output: 3201
🏁 Script executed:
#!/bin/bash
set -euo pipefail
# Search for definitions/occurrences of new symbols and a nearby existing symbol for reference
symbols=(
run_fmha_v2_flash_attention_e4m3_64_128_S_q_k_v_192x128_tma_ws_sm90
run_fmha_v2_flash_attention_e4m3_64_128_S_q_k_v_192x128_output_bf16_tma_ws_sm90
run_fmha_v2_flash_attention_e4m3_64_256_S_qkv_32_alibi_tma_ws_sm90
)
for sym in "${symbols[@]}"; do
echo "=== Occurrences of $sym ==="
rg -n "$sym" -C2 || echo " (no matches found)"
echo
done
Length of output: 12014
Missing out-of-line definitions for new SM90 externs
Our scan of all .cpp
and .cu
files did not locate non-extern definitions for the two new functions declared in
cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_cubin.h
at lines 263–264:
run_fmha_v2_flash_attention_e4m3_64_128_S_q_k_v_192x128_tma_ws_sm90
run_fmha_v2_flash_attention_e4m3_64_128_S_q_k_v_192x128_output_bf16_tma_ws_sm90
Without matching definitions, the build will fail at link time. Please add implementations—e.g. in the existing fmha_cubin.cu
(or a new .cpp/.cu
)—ensuring:
- Exactly one non-extern definition per symbol.
- The parameter names in the definitions match the declarations (
params
,launch_params
,stream
). - Optionally, guard both declaration and definition with the same SM90 compile-time macro used elsewhere for consistency.
– File to update:
• cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_cubin.h
(lines 263–264)
• Corresponding .cu
or .cpp
where definitions belong.
🤖 Prompt for AI Agents
In cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_cubin.h
around lines 263–264, two functions were declared as extern for SM90 but no
non-extern (definition) implementations exist:
run_fmha_v2_flash_attention_e4m3_64_128_S_q_k_v_192x128_tma_ws_sm90 and
run_fmha_v2_flash_attention_e4m3_64_128_S_q_k_v_192x128_output_bf16_tma_ws_sm90;
add exactly one non-extern definition for each (preferably in
cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_cubin.cu),
ensure the function signatures use the same parameter names (params,
launch_params, stream), and wrap both declaration and definition with the same
SM90 compile-time macro guard used elsewhere so the linker finds the
implementations only when SM90 is enabled.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
cpp/kernels/fmha_v2/src/fmha/hopper/gmem_tile_o_packed.h (2)
1299-1305
: Handle odd‐column tail in last MMA column‐packing loopIn
cpp/kernels/fmha_v2/src/fmha/hopper/gmem_tile_o_packed.h
(around lines 1299–1305), the store loop unconditionally writes two columns (ci
andci+1
) per iteration:#pragma unroll for (int ci = 0; ci < VALID_COLS_PER_THREAD_FOR_LAST_MMA; ci += 2) { STORE_COLUMNS(); }If
VALID_COLS_PER_THREAD_FOR_LAST_MMA
is odd, the final iteration will read an out‐of‐range accumulator (ci+1
) and write a stray 8 bytes beyond the valid columns.Mandatory fix: restrict the paired‐store loop to full pairs and handle a leftover single column explicitly. For example:
@@ -1299,7 +1299,11 @@ int mma_ni = VALID_MMAS_N - 1; #pragma unroll - for (int ci = 0; ci < VALID_COLS_PER_THREAD_FOR_LAST_MMA; ci += 2) + // Store full 2-column packs. + for (int ci = 0; ci + 1 < VALID_COLS_PER_THREAD_FOR_LAST_MMA; ci += 2) { STORE_COLUMNS(); } + // Handle odd tail, if any. + if (VALID_COLS_PER_THREAD_FOR_LAST_MMA & 1) { + int ci = VALID_COLS_PER_THREAD_FOR_LAST_MMA - 1; + STORE_SINGLE_COLUMN(ci); + }– Adds no‐op cost when even
– Ensures safety for odd‐column cases
– Mirrors the per‐element tail‐store path in other code paths
1249-1254
: Fix two-column packing logic to eliminate gaps in output addressesThe off-device verification shows that the new “two-column” write pattern only covers even-indexed columns, causing half the elements to be skipped (holes) for any VALID_COLS > 1:
• VALID_COLS=1: legacy_writes=2, new_writes=2 (diff=0)
• VALID_COLS=2: legacy_writes=4, new_writes=2 (diff=2)
…
• VALID_COLS=16: legacy_writes=32, new_writes=16 (diff=16)These diffs indicate missing writes on all odd-indexed columns, so the output layout no longer matches the per-element path.
Affected location
• File: cpp/kernels/fmha_v2/src/fmha/hopper/gmem_tile_o_packed.h
• Lines ~1249–1254 (wherecol = lane % 4 * ELEMENTS_PER_STG * 2
and thewhile (ci < VALID_COLS) { … ci += 2; }
loop live)Action items
- Update the
col
calculation or loop stepping so that every column index (even + odd) is written, preserving the original dense layout.- Re-run the offset comparison script for your target N_PER_MMA and VALID_N sets to confirm
legacy_writes == new_writes
anddiff == 0
.
🧹 Nitpick comments (3)
cpp/kernels/fmha_v2/src/fmha/hopper/gmem_tile_o_packed.h (3)
1-11
: Update copyright year to include 2025.Repo guideline says to prepend the current year. Header still shows 2011-2024.
1225-1232
: Per‑CUBIN descale wiring looks right; consider constness and layout docs.Good: selecting scale_bmm2_d (if provided) at runtime enables CUBIN-specific scaling without branching in hot paths.
Minor:
- params_scale_bmm2_ could be made const (if Acc_packer does not require non-const access).
- Add a brief note about expected units/encoding of scale_bmm2 (e.g., E4M3 inverse scale in QKV descale), to prevent misuse by future traits.
1267-1280
: Avoid function-local macros; scope leak risk. Also generalize source type.
- STORE_COLUMNS is defined as a macro inside a function and never undefined. This leaks into the rest of the translation unit and risks collisions. Prefer a small inline helper or, minimally, add an #undef immediately after the last use.
- Use Traits::Accumulator_type instead of hard-coded float to keep this generic across traits.
Apply this change to generalize the source type within the macro:
- uint2 _dst = Acc_packer<float, Output_type, Scale>::run(this, _src); + uint2 _dst = Acc_packer<typename Traits::Accumulator_type, Output_type, Scale>::run(this, _src);And immediately after the last invocation (just before the end of the function), undefine the macro to avoid leaking it:
@@ for (int ci = 0; ci < VALID_COLS_PER_THREAD_FOR_LAST_MMA; ci += 2) { STORE_COLUMNS() } } } + +#undef STORE_COLUMNS
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (2)
cpp/kernels/fmha_v2/src/fmha/hopper/gmem_tile_o_packed.h
(3 hunks)cpp/kernels/fmha_v2/src/fmha/warpspec/kernel_traits.h
(1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- cpp/kernels/fmha_v2/src/fmha/warpspec/kernel_traits.h
🧰 Additional context used
📓 Path-based instructions (4)
**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh}
: In C++, close namespaces with a comment naming the namespace (e.g., } // namespace foo)
Prefer const/constexpr variables over #define for constants
Declare variables const if not modified after initialization
Use Allman brace style in C++
C++ filenames use lowerCamelCase and must be case-insensitively unique within a build target
C++ type names use UpperCamelCase
Local variables, methods, and namespaces use lowerCamelCase
Global non-static variables not in anonymous namespace use gPrefix lowerCamelCase (e.g., gExample)
Static globals or globals in anonymous namespaces use sPrefix lowerCamelCase
Locally visible static variables start with 's' (e.g., static std::once_flag sFlag;)
Member variables use mPrefix lowerCamelCase; public members may omit but are encouraged to use 'm'
Constants (enums, global/static/function-scope magic numbers) use kPREFIXED_UPPER_SNAKE (e.g., kDIGIT_NUM)
If macros are unavoidable, use UPPER_SNAKE_CASE (prefer constants over #define)
Constructor parameter that conflicts with a public member name gets trailing underscore (foo_)
Literal suffixes should be uppercase (e.g., 1234L not 1234l)
C++: use spaces only; indent 4 spaces
Run clang-format (LLVM style) before submitting; wrap lines at 120 characters
If formatting must be bypassed, use // clang-format off/on around the section
Prefer smart pointers; use unique_ptr for sole ownership, shared_ptr for shared; weak_ptr only in exceptional cases
Do not use deprecated pre-C++11 smart pointers
Use C++ style comments; avoid C comments except special inline cases; prefer // single-line
Capitalize and punctuate full-sentence comments
Follow Doxygen rules: use //! for comments and //!< for members in C++
Disable code with #if/#endif and mnemonic conditions; avoid commented-out code; avoid dead code
Do not throw exceptions across library boundaries
Use least-forceful casts; avoid removing const/volatile; avoid C-style and functional casts (except constructors); p...
Files:
cpp/kernels/fmha_v2/src/fmha/hopper/gmem_tile_o_packed.h
**/*.{h,hpp,hxx,hh,cuh,cpp,cxx,cc,cu}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Parameter names must be consistent between declarations and definitions
Files:
cpp/kernels/fmha_v2/src/fmha/hopper/gmem_tile_o_packed.h
**/*.{h,hpp,hxx,hh,cuh}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Header files must use include guards named TRTLLM__H without underscores prefix/suffix (e.g., TRTLLM_FOO_BAR_HELLO_H)
Files:
cpp/kernels/fmha_v2/src/fmha/hopper/gmem_tile_o_packed.h
**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh,py}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Prepend NVIDIA copyright header (current year) to all source files
Files:
cpp/kernels/fmha_v2/src/fmha/hopper/gmem_tile_o_packed.h
🔇 Additional comments (2)
cpp/kernels/fmha_v2/src/fmha/hopper/gmem_tile_o_packed.h (2)
1262-1266
: Compile-time Scale toggle is clear.Nice reduction of duplication vs. repeating #ifdef around each call site.
1316-1319
: Member added for scaling context—ensure Acc_packer reads this.Assuming Acc_packer<T, Output_type, Scale>::run expects “this” to expose params_scale_bmm2_, this addition is correct. If not, consider adding an accessor method to avoid relying on member naming.
Signed-off-by: Yuxin <[email protected]>
Signed-off-by: Yuxin <[email protected]>
Signed-off-by: Yuxin <[email protected]>
Signed-off-by: Kanghwan Jang <[email protected]> Signed-off-by: Yuxin <[email protected]>
Signed-off-by: Robin Kobus <[email protected]> Signed-off-by: Yuxin <[email protected]>
Signed-off-by: qqiao <[email protected]> Signed-off-by: Yuxin <[email protected]>
Signed-off-by: yuhyao <[email protected]> Signed-off-by: Yuxin <[email protected]>
…isper (NVIDIA#7035) Signed-off-by: Dom Brown <[email protected]> Signed-off-by: Yuxin <[email protected]>
) Signed-off-by: Jin Li <[email protected]> Signed-off-by: Yuxin <[email protected]>
Signed-off-by: junq <[email protected]> Signed-off-by: Yuxin <[email protected]>
Signed-off-by: bhsueh <[email protected]> Signed-off-by: Yuxin <[email protected]>
Signed-off-by: yechank <[email protected]> Signed-off-by: Yuxin <[email protected]>
…DIA#6957) Signed-off-by: Chang Liu (Enterprise Products) <[email protected]> Signed-off-by: Yuxin <[email protected]>
Signed-off-by: Batsheva Black <[email protected]> Signed-off-by: Bo Deng <[email protected]> Co-authored-by: Bo Deng <[email protected]> Signed-off-by: Yuxin <[email protected]>
…IA#7077) Signed-off-by: Frida Hou <[email protected]> Signed-off-by: Yuxin <[email protected]>
…ion/implementation (NVIDIA#6679) Signed-off-by: fanyunfan <[email protected]> Co-authored-by: fanyunfan <[email protected]> Co-authored-by: Yunfan Fan <[email protected]> Signed-off-by: Yuxin <[email protected]>
Signed-off-by: Xin He (SW-GPU) <[email protected]> Signed-off-by: Yuxin <[email protected]>
…#7087) Signed-off-by: Yao Yao <[email protected]> Signed-off-by: Yuxin <[email protected]>
…h 2.8 (NVIDIA#7076) Signed-off-by: Frida Hou <[email protected]> Signed-off-by: Yuxin <[email protected]>
Signed-off-by: bhsueh <[email protected]> Signed-off-by: Yuxin <[email protected]>
…VIDIA#7101) Signed-off-by: Farshad Ghodsian <[email protected]> Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> Signed-off-by: Yuxin <[email protected]>
255f062
to
810beb2
Compare
Tested with DeepSeekV3Lite with FP8 KVCache enabled.
Summary by CodeRabbit
New Features
Performance/Compatibility
Tests
Documentation