Conversation


@zhumakhan zhumakhan commented Jul 29, 2025

attention plugin

Summary by CodeRabbit

  • New Features
    • Introduced a custom WanAttention multi-head attention plugin optimized for transformer models, now integrated with TensorRT.
    • Added a specialized CrossAttention mechanism supporting image and text context separation for advanced cross-attention tasks.
    • Provided Python API for invoking WanAttention via a new functional interface.
    • Enabled building and serialization of networks using the new attention mechanism with TensorRT and tensorrt_llm.
  • Tests
    • Added scripts to test and benchmark the new WanAttention plugin and quantized attention kernels, including integration with CUDA and PyTorch for validation and performance measurement.

Signed-off-by: Sultan <[email protected]>

coderabbitai bot commented Jul 29, 2025

Walkthrough

This change introduces a new TensorRT plugin called WanAttentionPlugin for optimized multi-head attention, integrating it into the build system and plugin registry. It provides full implementation, headers, and build scripts for the plugin, adds Python bindings for its usage, and includes new model code and test scripts for validation and benchmarking.

Changes

  • CMake Integration (cpp/tensorrt_llm/plugins/CMakeLists.txt, cpp/tensorrt_llm/plugins/wanAttentionPlugin/CMakeLists.txt): Added wanAttentionPlugin to the plugin build list and created a dedicated CMake build script for its sources, integrating it into the TensorRT-LLM plugin build process.
  • Plugin Registration (cpp/tensorrt_llm/plugins/api/tllmPlugin.cpp): Registered the new WanAttentionPlugin in the plugin creator system, ensuring it is available for use within TensorRT.
  • WanAttention Plugin Implementation (cpp/tensorrt_llm/plugins/wanAttentionPlugin/wanAttentionPlugin.cpp, cpp/tensorrt_llm/plugins/wanAttentionPlugin/wanAttentionPlugin.h): Added the full implementation and header for WanAttentionPlugin, including plugin lifecycle, serialization, deserialization, execution, and creator classes. Supports FP16, FP32, and BF16, with optimized fused multi-head attention kernels and TensorRT plugin interfaces.
  • Python Functional API (tensorrt_llm/functional.py): Introduced a new wan_attention function to wrap the plugin as a TensorRT layer, enabling multi-head attention via the plugin in Python workflows.
  • Model Integration (model.py, tensorrt_llm/models/wan/model.py): Added a CrossAttention class utilizing the new plugin, with custom forward logic for image/text context separation, normalization, and TensorRT engine building. Integrated with tensorrt_llm and TensorRT builder APIs.
  • Testing and Benchmarking (test.py, test_sage.py): Added test and benchmarking scripts: test.py builds, serializes, deserializes, and runs the plugin in TensorRT, comparing outputs to PyTorch; test_sage.py benchmarks a quantized attention CUDA kernel, measuring performance for various configurations.

Sequence Diagram(s)

sequenceDiagram
    participant User
    participant Python API
    participant TensorRT Network
    participant WanAttentionPlugin
    participant CUDA Kernel

    User->>Python API: Call wan_attention(...)
    Python API->>TensorRT Network: Add WanAttention plugin layer
    TensorRT Network->>WanAttentionPlugin: Instantiate plugin with params
    WanAttentionPlugin->>CUDA Kernel: Launch fused multi-head attention
    CUDA Kernel-->>WanAttentionPlugin: Return computed output
    WanAttentionPlugin-->>TensorRT Network: Output tensor
    TensorRT Network-->>User: Engine with WanAttention
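
For orientation, here is a minimal sketch of the flow above using the raw TensorRT plugin-registry API rather than the wan_attention wrapper. The creator name/version/namespace, plugin field names, and input shapes below are assumptions for illustration only; the authoritative definitions live in wanAttentionPlugin.cpp and tensorrt_llm.functional.wan_attention.

    # Illustrative sketch only: creator name, fields, and shapes are assumptions.
    import numpy as np
    import tensorrt as trt
    import tensorrt_llm  # assumed to load/register the TensorRT-LLM plugin library

    logger = trt.Logger(trt.Logger.INFO)
    builder = trt.Builder(logger)
    network = builder.create_network(
        1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))

    registry = trt.get_plugin_registry()
    creator = registry.get_plugin_creator("WanAttention", "1", "tensorrt_llm")

    # Assumed plugin fields; the real set is defined by WanAttentionPluginCreator.
    fields = trt.PluginFieldCollection([
        trt.PluginField("type_id", np.array([int(trt.float16)], dtype=np.int32),
                        trt.PluginFieldType.INT32),
        trt.PluginField("context_fmha_type", np.array(np.int8(2), dtype=np.int8),
                        trt.PluginFieldType.INT8),
    ])
    plugin = creator.create_plugin("wan_attention", fields)

    # Assumed inputs, following the shape checks quoted later in the review:
    # a 4-D query, a 5-D key/value, and two cumulative sequence-length tensors.
    q = network.add_input("query", trt.float16, (-1, -1, 40, 128))
    kv = network.add_input("key_value", trt.float16, (-1, -1, 2, 40, 128))
    cu_seqlen_q = network.add_input("cu_seqlen_q", trt.int32, (-1,))
    cu_seqlen_kv = network.add_input("cu_seqlen_kv", trt.int32, (-1,))
    layer = network.add_plugin_v2([q, kv, cu_seqlen_q, cu_seqlen_kv], plugin)
    network.mark_output(layer.get_output(0))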

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Suggested labels

Community want to contribute

Suggested reviewers

  • nv-guomingz
  • shaharmor98
  • juney-nvidia
  • litaotju
  • Superjomn

Poem

A plugin hops into the fray,
WanAttention leads the way!
Heads and tensors, fused with might,
CUDA kernels take their flight.
Models, tests, all in sync—
This rabbit’s code is quick as a wink!
🐇✨


🪧 Tips

Chat

There are 3 ways to chat with CodeRabbit:

  • Review comments: Directly reply to a review comment made by CodeRabbit. Example:
    • I pushed a fix in commit <commit_id>, please review it.
    • Explain this complex logic.
    • Open a follow-up GitHub issue for this discussion.
  • Files and specific lines of code (under the "Files changed" tab): Tag @coderabbitai in a new review comment at the desired location with your query. Examples:
    • @coderabbitai explain this code block.
    • @coderabbitai modularize this function.
  • PR comments: Tag @coderabbitai in a new PR comment to ask questions about the PR branch. For the best results, please provide a very specific query, as very limited context is provided in this mode. Examples:
    • @coderabbitai gather interesting stats about this repository and render them as a table. Additionally, render a pie chart showing the language distribution in the codebase.
    • @coderabbitai read src/utils.ts and explain its main purpose.
    • @coderabbitai read the files in the src/scheduler package and generate a class diagram using mermaid and a README in the markdown format.
    • @coderabbitai help me debug CodeRabbit configuration file.

Support

Need help? Create a ticket on our support page for assistance with any issues or questions.

Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments.

CodeRabbit Commands (Invoked using PR comments)

  • @coderabbitai pause to pause the reviews on a PR.
  • @coderabbitai resume to resume the paused reviews.
  • @coderabbitai review to trigger an incremental review. This is useful when automatic reviews are disabled for the repository.
  • @coderabbitai full review to do a full review from scratch and review all the files again.
  • @coderabbitai summary to regenerate the summary of the PR.
  • @coderabbitai generate docstrings to generate docstrings for this PR.
  • @coderabbitai generate sequence diagram to generate a sequence diagram of the changes in this PR.
  • @coderabbitai generate unit tests to generate unit tests for this PR.
  • @coderabbitai resolve to resolve all the CodeRabbit review comments.
  • @coderabbitai configuration to show the current CodeRabbit configuration for the repository.
  • @coderabbitai help to get help.

Other keywords and placeholders

  • Add @coderabbitai ignore anywhere in the PR description to prevent this PR from being reviewed.
  • Add @coderabbitai summary to generate the high-level summary at a specific location in the PR description.
  • Add @coderabbitai anywhere in the PR title to generate the title automatically.

CodeRabbit Configuration File (.coderabbit.yaml)

  • You can programmatically configure CodeRabbit by adding a .coderabbit.yaml file to the root of your repository.
  • Please see the configuration documentation for more information.
  • If your editor has YAML language server enabled, you can add the path at the top of this file to enable auto-completion and validation: # yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json

Documentation and Community

  • Visit our Documentation for detailed information on how to use CodeRabbit.
  • Join our Discord Community to get help, request features, and share feedback.
  • Follow us on X/Twitter for updates and announcements.

@zhumakhan zhumakhan closed this Jul 29, 2025

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 25

🔭 Outside diff range comments (1)
tensorrt_llm/models/wan/model.py (1)

209-250: Move module-level code to a function.

Like in model.py, this execution code should be in a function, not at module level.
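
For reference, the pattern being suggested is the usual entry-point guard (a generic sketch, not the file's actual contents):

    def main():
        # Build the network / run the example here instead of at import time.
        ...

    if __name__ == "__main__":
        main()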

🧹 Nitpick comments (16)
cpp/tensorrt_llm/plugins/wanAttentionPlugin/CMakeLists.txt (1)

2-3: Update the copyright year to include current year.

The copyright header shows "1993-2022" but according to coding guidelines, all TensorRT-LLM source files should contain an NVIDIA copyright header that includes the current year (2025).

Apply this diff to update the copyright year:

-# SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION &
+# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION &
cpp/tensorrt_llm/plugins/api/tllmPlugin.cpp (1)

2-2: Update the copyright year to include current year.

The copyright header shows "1993-2022" but according to coding guidelines, all TensorRT-LLM source files should contain an NVIDIA copyright header that includes the current year (2025).

Apply this diff to update the copyright year:

- * SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
test_sage.py (1)

95-98: Consider removing unused warmup loop.

The warmup loop runs 0 iterations, making it effectively unused. If warmup is not needed, consider removing it for code clarity.

Apply this diff to remove the unused loop:

-for i in range(0):
-    kernel(q, k, v, o, q_scale, k_scale, 1, _is_causal, _qk_quant_gran,
-           sm_scale, 0)
-torch.cuda.synchronize()
+torch.cuda.synchronize()
tensorrt_llm/functional.py (1)

4732-4736: Consider making hardcoded context_fmha_type configurable.

The context_fmha_type is hardcoded to value 2. Consider making this configurable, especially since there's a commented line suggesting it should come from default_net().plugin_config.context_fmha_type.

     context_fmha_type = trt.PluginField(
         "context_fmha_type",
-        # default_net().plugin_config.context_fmha_type
-        np.array(np.int8(2), dtype=np.int8),
+        np.array(np.int8(default_net().plugin_config.context_fmha_type), dtype=np.int8),
         trt.PluginFieldType.INT8)

If the plugin config doesn't have this attribute, consider adding a function parameter with a default value.

model.py (4)

12-12: Remove debug flag from production code.

The ADD_DEBUG_TENSOR flag appears to be for debugging purposes but is not used anywhere in the code. Consider removing it or properly implementing its usage.

-ADD_DEBUG_TENSOR = True

30-30: Address the incomplete RMS norm implementation.

The comment indicates that qk_norm='rms_norm' is not used because rms_norm_cross_attention is not released yet. This suggests incomplete functionality that should be tracked or implemented.

Would you like me to create an issue to track the implementation of rms_norm_cross_attention functionality?


121-121: Remove no-op statements.

These lines don't have any effect and should be removed.

-        query.dtype
-        head_dim * self.heads

Also applies to: 126-126


263-263: Fix line length violations.

Lines exceed the 120 character limit specified in the coding guidelines.

Break these long lines or remove them if they're just debug information.

Also applies to: 265-265

test.py (2)

32-32: Make logger verbosity configurable.

The logger is set to VERBOSE which can produce excessive output. Consider making this configurable.

-logger = trt.Logger(trt.Logger.VERBOSE)
+log_level = os.environ.get("TRT_LOG_LEVEL", "INFO")
+log_levels = {
+    "VERBOSE": trt.Logger.VERBOSE,
+    "INFO": trt.Logger.INFO,
+    "WARNING": trt.Logger.WARNING,
+    "ERROR": trt.Logger.ERROR
+}
+logger = trt.Logger(log_levels.get(log_level, trt.Logger.INFO))

81-82: Use constants for file names.

Hardcoded file names should be defined as constants or made configurable.

+ENGINE_FILE = "wan_attention.engine"
+
-with open("wan_attention.engine", "wb") as f:
+with open(ENGINE_FILE, "wb") as f:
     f.write(serialized_engine)

-with open("wan_attention.engine", "rb") as f, trt.Runtime(logger) as runtime:
+with open(ENGINE_FILE, "rb") as f, trt.Runtime(logger) as runtime:
tensorrt_llm/models/wan/model.py (2)

49-49: Make data type configurable.

The dtype is hardcoded to trt.bfloat16. Consider making it configurable to support different precision requirements.

-            dtype=trt.bfloat16,
+            dtype=trt.float16,  # or make it a parameter with default

115-115: Remove unused variable assignments.

These variables are assigned but never used effectively.

-        q_dtype = query.dtype
+        # Store dtype for later casting if needed
+        q_dtype = query.dtype
-        inner_dim = head_dim * self.heads
+        # inner_dim = head_dim * self.heads  # Unused

Also applies to: 120-120

cpp/tensorrt_llm/plugins/wanAttentionPlugin/wanAttentionPlugin.h (1)

121-121: Add namespace closing comment.

According to the coding guidelines, closing braces of namespaces should have a comment.

-} // namespace tensorrt_llm::plugins
+} // namespace tensorrt_llm::plugins
cpp/tensorrt_llm/plugins/wanAttentionPlugin/wanAttentionPlugin.cpp (3)

1-16: Fix copyright header formatting.

The copyright header has incorrect line breaking. The second line should not be split.

 /*
- * SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION &
- * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
+ * SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0

154-155: Define NUM_BUFFERS as a constant.

The NUM_BUFFERS should be defined as a class constant rather than a local variable.

+    static constexpr int NUM_BUFFERS = 1;
-    int const NUM_BUFFERS = 1;

447-448: Fix misleading comment.

The comment mentions "BertAttentionPlugin" but this is "WanAttentionPlugin".

     // This object will be deleted when the network is destroyed, which will
-    // call BertAttentionPlugin::destroy()
+    // call WanAttentionPlugin::destroy()
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between f1086e7 and f5d87c0.

📒 Files selected for processing (10)
  • cpp/tensorrt_llm/plugins/CMakeLists.txt (1 hunks)
  • cpp/tensorrt_llm/plugins/api/tllmPlugin.cpp (3 hunks)
  • cpp/tensorrt_llm/plugins/wanAttentionPlugin/CMakeLists.txt (1 hunks)
  • cpp/tensorrt_llm/plugins/wanAttentionPlugin/wanAttentionPlugin.cpp (1 hunks)
  • cpp/tensorrt_llm/plugins/wanAttentionPlugin/wanAttentionPlugin.h (1 hunks)
  • model.py (1 hunks)
  • tensorrt_llm/functional.py (1 hunks)
  • tensorrt_llm/models/wan/model.py (1 hunks)
  • test.py (1 hunks)
  • test_sage.py (1 hunks)
🧰 Additional context used
📓 Path-based instructions (4)
**/*.{cpp,h,hpp,cc,cxx}

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

**/*.{cpp,h,hpp,cc,cxx}: Closing braces of namespaces should have a comment saying the namespace it closes (e.g., } // namespace foo)
Prefer const or constexpr variables over #defines whenever possible, as the latter are not visible to the compiler.
A variable that is not modified after its initialization should be declared as const.
Except 0 (only used in comparison for checking signness/existence/emptiness) and nullptr, true, false, all other literals should only be used for variable initialization.
Use the Allman indentation style for braces in C++ code.
Put the semicolon for an empty for or while loop in a new line.
The statement forming the body of a switch, while, do .. while or for statement shall be a compound statement (use brace-delimited statements).
If and else should always be followed by brace-delimited statements, even if empty or a single statement.
C++ filenames should use camel case with first letter lowercase (e.g., thisIsAFilename.cpp), and all files involved in the compilation of a target must have filenames that are case-insensitive unique.
All types (including class names) in C++ should use camel case with uppercase first letter (e.g., FooBarClass).
Local variables, methods, and namespaces in C++ should use camel case with first letter lowercase (e.g., localFooBar).
Non-magic-number global variables that are non-static and not defined in anonymous namespace should use camel case prefixed by a lower case 'g' (e.g., gDontUseGlobalFoos).
Non-magic-number global variables that are static or defined in an anonymous namespace should use camel case prefixed by a lower case 's' (e.g., sMutableStaticGlobal).
Locally visible static variable should use camel case with lowercase prefix 's' as the first letter of the name (e.g., static std::once_flag sFlag;).
Class member variables should use camel case prefixed with an 'm' (e.g., mNbFooValues). Public member variables do not require the 'm' prefix but it is encouraged for clarity.
Enumerations, global c...

Files:

  • cpp/tensorrt_llm/plugins/api/tllmPlugin.cpp
  • cpp/tensorrt_llm/plugins/wanAttentionPlugin/wanAttentionPlugin.cpp
  • cpp/tensorrt_llm/plugins/wanAttentionPlugin/wanAttentionPlugin.h
**/*.{cpp,h,hpp,cc,cxx,cu,py}

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

All TensorRT-LLM Open Source Software code should contain an NVIDIA copyright header that includes the current year. This includes .cpp, .h, .cu, .py, and any other source files which are compiled or interpreted.

Files:

  • cpp/tensorrt_llm/plugins/api/tllmPlugin.cpp
  • test_sage.py
  • tensorrt_llm/models/wan/model.py
  • model.py
  • tensorrt_llm/functional.py
  • test.py
  • cpp/tensorrt_llm/plugins/wanAttentionPlugin/wanAttentionPlugin.cpp
  • cpp/tensorrt_llm/plugins/wanAttentionPlugin/wanAttentionPlugin.h
**/*.py

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

**/*.py: The code developed for TensorRT-LLM should conform to Python 3.8+.
Indent Python code with 4 spaces. Do not use tabs.
Always maintain the namespace when importing in Python, even if only one class or function from a module is used.
Python filenames should use snake_case (e.g., some_file.py).
Python classes should use PascalCase (e.g., class SomeClass).
Python functions and methods should use snake_case (e.g., def my_awesome_function():).
Python local variables should use snake_case. Prefix k for variable names that start with a number (e.g., k_99th_percentile = ...).
Python global variables should use upper snake_case and prefix G (e.g., G_MY_GLOBAL = ...).
Python constants should use upper snake_case (e.g., MY_CONSTANT = ...).
Avoid shadowing variables declared in an outer scope in Python.
Initialize all externally visible members of a class in the constructor in Python.
For interfaces that may be used outside a file, prefer docstrings over comments in Python.
Comments in Python should be reserved for code within a function, or interfaces that are local to a file.
Use Google style docstrings for classes and functions in Python, which can be parsed by Sphinx.
Attributes and variables in Python can be documented inline; attribute docstrings will be rendered under the docstring for the class.
Avoid using reflection in Python when functionality can be easily achieved without it.
When using try-except blocks in Python, limit the except to the smallest set of errors possible.
When using try-except blocks to handle multiple possible variable types in Python, keep the body of the try as small as possible, using the else block to implement the logic.

Files:

  • test_sage.py
  • tensorrt_llm/models/wan/model.py
  • model.py
  • tensorrt_llm/functional.py
  • test.py
**/*.{h,hpp}

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

Use a preprocessor guard in header files. The guard name must have prefix TRTLLM_ followed by the filename, all in caps, and no trailing underscore.

Files:

  • cpp/tensorrt_llm/plugins/wanAttentionPlugin/wanAttentionPlugin.h
🧠 Learnings (3)
cpp/tensorrt_llm/plugins/api/tllmPlugin.cpp (1)

Learnt from: CR
PR: NVIDIA/TensorRT-LLM#0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-07-29T08:25:52.877Z
Learning: Applies to **/*.{cpp,h,hpp,cc,cxx,cu,py} : All TensorRT-LLM Open Source Software code should contain an NVIDIA copyright header that includes the current year. This includes .cpp, .h, .cu, .py, and any other source files which are compiled or interpreted.

cpp/tensorrt_llm/plugins/wanAttentionPlugin/CMakeLists.txt (1)

Learnt from: CR
PR: NVIDIA/TensorRT-LLM#0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-07-29T08:25:52.877Z
Learning: Applies to **/*.{cpp,h,hpp,cc,cxx,cu,py} : All TensorRT-LLM Open Source Software code should contain an NVIDIA copyright header that includes the current year. This includes .cpp, .h, .cu, .py, and any other source files which are compiled or interpreted.

test.py (3)

Learnt from: moraxu
PR: #6303
File: tests/integration/test_lists/qa/examples_test_list.txt:494-494
Timestamp: 2025-07-28T17:06:08.621Z
Learning: In TensorRT-LLM testing, it's common to have both CLI flow tests (test_cli_flow.py) and PyTorch API tests (test_llm_api_pytorch.py) for the same model. These serve different purposes: CLI flow tests validate the traditional command-line workflow, while PyTorch API tests validate the newer LLM API backend. Both are legitimate and should coexist.

Learnt from: CR
PR: NVIDIA/TensorRT-LLM#0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-07-29T08:25:52.877Z
Learning: Applies to **/*.py : The code developed for TensorRT-LLM should conform to Python 3.8+.

Learnt from: amitz-nv
PR: #5616
File: tensorrt_llm/executor/worker.py:375-384
Timestamp: 2025-07-17T09:01:27.402Z
Learning: In tensorrt_llm/executor/worker.py, the LoRA adapter cache optimization logic that checks is_adapter_in_cpu_cache() and conditionally passes None for weights/config has a known race condition issue that cannot be solved with simple error handling or verification checks. This is a known limitation that requires a more comprehensive solution.

🪛 Ruff (0.12.2)
model.py

263-263: Line too long (428 > 120)

(E501)


265-265: Line too long (437 > 120)

(E501)

🔇 Additional comments (12)
cpp/tensorrt_llm/plugins/wanAttentionPlugin/CMakeLists.txt (1)

17-21: LGTM!

The CMakeLists.txt structure correctly integrates the plugin into the build system by gathering source files and exporting them to the parent scope.

cpp/tensorrt_llm/plugins/CMakeLists.txt (1)

51-51: LGTM!

The wanAttentionPlugin is correctly added to the PLUGIN_LISTS in alphabetical order, following the established pattern for plugin integration.

cpp/tensorrt_llm/plugins/api/tllmPlugin.cpp (3)

37-37: LGTM!

The WanAttentionPlugin header is correctly included following the established pattern.


212-212: LGTM!

The WanAttentionPluginCreator static instance is correctly declared following the established naming convention and pattern.


255-255: LGTM!

The wanAttentionPluginCreator is correctly added to the pluginCreators array, maintaining the established order and pattern.

test_sage.py (4)

7-20: LGTM!

The argument parser configuration follows Python best practices with appropriate defaults, type specifications, and help text.


26-26: LGTM!

The assertion appropriately guards against unsupported accumulation modes with a clear error message.


45-47: LGTM!

The FLOPS calculation correctly accounts for attention operations: 4 * batch * heads * head_dim * seq_len * seq_len_kv, with appropriate adjustment for causal attention.
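
As a cross-check of that formula, a small standalone helper (an illustrative sketch, not code taken from test_sage.py) computing attention FLOPs the same way, with the causal count halved:

    def attention_flops(batch, heads, head_dim, seq_len, seq_len_kv, is_causal=False):
        """FLOPs for one attention forward pass (QK^T and PV matmuls).

        Each matmul costs 2 * batch * heads * seq_len * seq_len_kv * head_dim FLOPs,
        hence the factor of 4; causal masking computes roughly half the score matrix.
        """
        flops = 4 * batch * heads * head_dim * seq_len * seq_len_kv
        return flops // 2 if is_causal else flops

    # Example: batch 2, 40 heads, head_dim 128, 4096 queries over 4096 keys
    print(attention_flops(2, 40, 128, 4096, 4096) / 1e12)  # ~0.69 TFLOPs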


99-114: LGTM!

The benchmarking setup using flash_attn's benchmark_forward is appropriate for performance measurement with proper CUDA synchronization and single iteration for accurate timing.

tensorrt_llm/functional.py (1)

4738-4741: No change needed—bert_attention_plugin is the correct config for WanAttention

The PluginConfig.bert_attention_plugin field is intentionally used to configure the WanAttention plugin. There is no separate wan_attention_plugin property—models that rely on WanAttention (e.g. in tensorrt_llm/models/wan/model.py) assert and consume plugin_config.bert_attention_plugin, and the factory in functional.py correctly reads that same field.

• In tensorrt_llm/models/wan/model.py (≈line 154), the model asserts plugin_config.bert_attention_plugin is not None to enable WanAttention.
• In tensorrt_llm/functional.py, the plugin builder for "WanAttention" pulls its dtype from the very same plugin_config.bert_attention_plugin.

Ignore the suggestion to use a non-existent wan_attention_plugin.

Likely an incorrect or invalid review comment.

test.py (1)

183-183: Fix incorrect assertion error messages.

The assertion uses incorrect syntax for dimension checking.

cpp/tensorrt_llm/plugins/wanAttentionPlugin/wanAttentionPlugin.h (1)

78-98: Follow member variable naming convention.

According to the coding guidelines, class member variables should use camelCase prefixed with 'm'. Some members follow this (like mLayerName) but others don't match the exact convention.

On closer inspection, the member variables do follow the naming convention with the 'm' prefix, so no change is needed here.

Comment on lines +132 to +135
printf("inner dim of q,k and v should match!");
ret = true;
}
if (batch_size + 1 != inputs[2].dims.d[0])

🛠️ Refactor suggestion

Replace printf with proper error handling.

Use TLLM_CHECK or proper logging instead of printf for error messages.

-    if (inputs[0].dims.d[3] != inputs[1].dims.d[4])
-    {
-        printf("inner dim of q,k and v should match!");
-        ret = true;
-    }
+    TLLM_CHECK_WITH_INFO(inputs[0].dims.d[3] == inputs[1].dims.d[4],
+        "Inner dimensions of query and key/value must match");

-    if (batch_size + 1 != inputs[2].dims.d[0])
-    {
-        printf("cu seqlen_q should have dim of (batch_size + 1)");
-        ret = true;
-    }
+    TLLM_CHECK_WITH_INFO(batch_size + 1 == inputs[2].dims.d[0],
+        "cu_seqlen_q should have dimension of (batch_size + 1)");

-    if (batch_size + 1 != inputs[3].dims.d[0])
-    {
-        printf("cu seqlen_kv should have dim of (batch_size + 1)");
-        ret = true;
-    }
+    TLLM_CHECK_WITH_INFO(batch_size + 1 == inputs[3].dims.d[0],
+        "cu_seqlen_kv should have dimension of (batch_size + 1)");

Also applies to: 138-141

🤖 Prompt for AI Agents
In cpp/tensorrt_llm/plugins/wanAttentionPlugin/wanAttentionPlugin.cpp around
lines 132 to 135 and also 138 to 141, replace the printf statements used for
error messages with proper error handling by using TLLM_CHECK or the project's
logging mechanism. This involves removing printf calls and instead invoking
TLLM_CHECK with the appropriate condition and error message or using the logging
utility to report errors consistently.

Comment on lines +145 to +148
if (ret)
{
return -1;
}

⚠️ Potential issue

Remove error-prone return value.

Returning -1 for size_t (unsigned) will result in a very large number, not a negative value.

-    if (ret)
-    {
-        return -1;
-    }
+    // Errors are now handled by TLLM_CHECK above
🤖 Prompt for AI Agents
In cpp/tensorrt_llm/plugins/wanAttentionPlugin/wanAttentionPlugin.cpp around
lines 145 to 148, the code returns -1 for a size_t return type, which is
unsigned and causes an unintended large value. Replace the return value with a
suitable error indicator for size_t, such as returning 0 or a defined constant
representing an error, to avoid incorrect large values.
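
To make the failure mode concrete, here is a quick Python stand-in for the C++ behaviour (illustration only): -1 converted to an unsigned 64-bit size becomes the maximum representable value, so TensorRT would see an enormous workspace request rather than an error code.

    import ctypes

    # -1 reinterpreted as size_t (unsigned, 64-bit on typical platforms)
    print(ctypes.c_size_t(-1).value)                 # 18446744073709551615
    print(ctypes.c_size_t(-1).value == 2**64 - 1)    # True on 64-bit builds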

Comment on lines +183 to +185
assert(inputs[0].dims.d[3] == inputs[1].dims.d[4] && "inner dim of q,k and v should match!");
assert(batch_size + 1 == inputs[2].dims.d[0] && "cu seqlen should have dim of (batch_size + 1)");
assert(batch_size + 1 == inputs[3].dims.d[0] && "cu seqlen should have dim of (batch_size + 1)");

🛠️ Refactor suggestion

Replace assert with TLLM_CHECK.

Use TLLM_CHECK instead of assert for consistency with the codebase.

-    assert(inputs[0].dims.d[3] == inputs[1].dims.d[4] && "inner dim of q,k and v should match!");
-    assert(batch_size + 1 == inputs[2].dims.d[0] && "cu seqlen should have dim of (batch_size + 1)");
-    assert(batch_size + 1 == inputs[3].dims.d[0] && "cu seqlen should have dim of (batch_size + 1)");
+    TLLM_CHECK(inputDesc[0].dims.d[3] == inputDesc[1].dims.d[4]);
+    TLLM_CHECK(batch_size + 1 == inputDesc[2].dims.d[0]);
+    TLLM_CHECK(batch_size + 1 == inputDesc[3].dims.d[0]);

-    assert(mEnableContextFMHA && "mEnableContextFMHA is false!");
+    TLLM_CHECK_WITH_INFO(mEnableContextFMHA, "Context FMHA must be enabled for this plugin");

Also applies to: 203-203

🤖 Prompt for AI Agents
In cpp/tensorrt_llm/plugins/wanAttentionPlugin/wanAttentionPlugin.cpp around
lines 183 to 185 and line 203, replace the assert statements with TLLM_CHECK
macros to maintain consistency with the codebase's error handling conventions.
Change each assert condition to a TLLM_CHECK call with the same condition and
error message string.

Comment on lines +330 to +331
printf("\n%s\n", fmhaParams.convertToStrOutput().c_str());
printf("\nmEnableContextFMHA: %s\n", mEnableContextFMHA ? "true" : "false");

🛠️ Refactor suggestion

Remove debug printf statements.

Debug output should use proper logging mechanisms, not printf.

-        printf("\n%s\n", fmhaParams.convertToStrOutput().c_str());
-        printf("\nmEnableContextFMHA: %s\n", mEnableContextFMHA ? "true" : "false");
+        TLLM_LOG_DEBUG("FMHA params: %s", fmhaParams.convertToStrOutput().c_str());
+        TLLM_LOG_DEBUG("mEnableContextFMHA: %s", mEnableContextFMHA ? "true" : "false");
🤖 Prompt for AI Agents
In cpp/tensorrt_llm/plugins/wanAttentionPlugin/wanAttentionPlugin.cpp at lines
330 to 331, remove the debug printf statements as they are not appropriate for
production code. Replace these with proper logging calls using the project's
logging framework or remove them entirely if logging is not needed here.

* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

🛠️ Refactor suggestion

Use proper include guard format.

According to the coding guidelines, header files should use a preprocessor guard with prefix TRTLLM_ followed by the filename in caps. Replace #pragma once with the proper guard.

-#pragma once
+#ifndef TRTLLM_WANATTENTIONPLUGIN_H
+#define TRTLLM_WANATTENTIONPLUGIN_H

And add at the end of the file:

 } // namespace tensorrt_llm::plugins
+
+#endif // TRTLLM_WANATTENTIONPLUGIN_H
🤖 Prompt for AI Agents
In cpp/tensorrt_llm/plugins/wanAttentionPlugin/wanAttentionPlugin.h at line 17,
replace the #pragma once directive with a proper include guard using the prefix
TRTLLM_ followed by the filename in uppercase (e.g.,
TRTLLM_WANATTENTIONPLUGIN_H). Add the corresponding #ifndef, #define at the top
and #endif at the end of the file to properly guard against multiple inclusions.

attn(hidden_states, encoder_hidden_states)

config = builder.create_builder_config()
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 25) # 1 MiB

⚠️ Potential issue

Fix incorrect workspace size comment.

The comment says "1 MiB" but 1 << 25 is 32 MiB.

-config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 25)  # 1 MiB
+config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 25)  # 32 MiB
🤖 Prompt for AI Agents
In tensorrt_llm/models/wan/model.py at line 245, the comment incorrectly states
the workspace size as "1 MiB" while the code sets it to 1 << 25, which equals 32
MiB. Update the comment to accurately reflect the workspace size as "32 MiB" to
match the code.

@@ -0,0 +1,114 @@
import argparse

⚠️ Potential issue

Add required NVIDIA copyright header.

According to coding guidelines, all TensorRT-LLM source files including Python files should contain an NVIDIA copyright header that includes the current year.

Add this copyright header at the beginning of the file:

+#
+# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
 import argparse
🤖 Prompt for AI Agents
In test_sage.py at line 1, the file is missing the required NVIDIA copyright
header. Add the standard NVIDIA copyright header including the current year at
the very top of the file before any imports to comply with coding guidelines.

import tensorrt_llm as tllm
from tensorrt_llm.functional import wan_attention

os.environ["CUDA_VISIBLE_DEVICES"] = "0" # To use GPU 0

🛠️ Refactor suggestion

Avoid setting environment variables in code.

Setting CUDA_VISIBLE_DEVICES in the code can cause issues in multi-GPU environments. This should be handled externally or made configurable.

-os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # To use GPU 0
+# Set CUDA_VISIBLE_DEVICES externally or make it configurable
+# gpu_id = os.environ.get("CUDA_VISIBLE_DEVICES", "0")
🤖 Prompt for AI Agents
In test.py at line 10, avoid setting the environment variable
CUDA_VISIBLE_DEVICES directly in the code. Instead, remove this line and
configure CUDA device visibility externally via environment settings or pass it
as a configurable parameter to the program to ensure flexibility in multi-GPU
environments.
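
One possible way to make the device selection configurable (a sketch under assumed flag names, not part of the PR): take it from a CLI argument and only set the environment variable if the caller has not already pinned devices.

    import argparse
    import os

    parser = argparse.ArgumentParser()
    parser.add_argument("--gpu", default="0",
                        help="GPU index/indices to expose via CUDA_VISIBLE_DEVICES")
    args, _ = parser.parse_known_args()

    # Respect an externally exported CUDA_VISIBLE_DEVICES; otherwise use the flag.
    os.environ.setdefault("CUDA_VISIBLE_DEVICES", args.gpu)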

builder = trt.Builder(logger)
config = builder.create_builder_config()
config.set_flag(trt.BuilderFlag.FP16)
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 35) # 1 MiB

⚠️ Potential issue

Fix workspace size comment.

The comment indicates "1 MiB" but 1 << 35 is 32 GiB.

-config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 35)  # 1 MiB
+config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 35)  # 32 GiB
🤖 Prompt for AI Agents
In test.py at line 37, the comment for the memory pool limit is incorrect; it
states "1 MiB" but the value `1 << 35` actually represents 32 GiB. Update the
comment to correctly reflect the size as "32 GiB" to match the value set.
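
For reference, the shift values from this comment and the earlier one in tensorrt_llm/models/wan/model.py, converted explicitly (plain arithmetic, not code from the PR):

    def human_size(num_bytes):
        """Render a byte count in binary units for quick sanity checks."""
        for unit in ("B", "KiB", "MiB", "GiB", "TiB"):
            if num_bytes < 1024:
                return f"{num_bytes:g} {unit}"
            num_bytes /= 1024
        return f"{num_bytes:g} PiB"

    print(human_size(1 << 20))  # 1 MiB  (what the original comments claimed)
    print(human_size(1 << 25))  # 32 MiB (tensorrt_llm/models/wan/model.py)
    print(human_size(1 << 35))  # 32 GiB (test.py)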

Comment on lines +149 to +155
# print(hidden_states_true)
# print(hidden_states)

# print(torch.max(hidden_states_true), torch.min(hidden_states_true))
# print(torch.max(hidden_states), torch.min(hidden_states))

# torch.testing.assert_close(hidden_states.float().transpose(1,2), hidden_states_true.float(), rtol=0.001, atol=0.001)

⚠️ Potential issue

Uncomment and fix test assertions.

The test assertions are commented out, which means the test doesn't actually validate the output. These should be enabled for proper testing.

-cuda_call(cudart.cudaStreamSynchronize(stream))
-# print(hidden_states_true)
-# print(hidden_states)
-
-# print(torch.max(hidden_states_true), torch.min(hidden_states_true))
-# print(torch.max(hidden_states), torch.min(hidden_states))
-
-# torch.testing.assert_close(hidden_states.float().transpose(1,2), hidden_states_true.float(), rtol=0.001, atol=0.001)
+cuda_call(cudart.cudaStreamSynchronize(stream))
+
+# Transpose hidden_states to match the expected output format
+hidden_states_transposed = hidden_states.transpose(1, 2)
+
+# Validate output
+torch.testing.assert_close(
+    hidden_states_transposed.float(), 
+    hidden_states_true.float(), 
+    rtol=0.01,  # Relaxed tolerance for numerical stability
+    atol=0.01
+)
+print("Test passed: WanAttention output matches PyTorch reference")

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In test.py around lines 149 to 155, the test assertions are commented out, so
the test does not validate the output. Uncomment the torch.testing.assert_close
line and ensure the tensors compared have compatible shapes and types, applying
any necessary transformations like transpose or type casting to match them
correctly. This will enable proper validation of the output in the test.
