Attention attention plugin #6448

Conversation
Signed-off-by: Sultan <[email protected]>
Walkthrough

This change introduces a new TensorRT plugin called WanAttention, exposed through a wan_attention Python API and backed by a fused multi-head attention CUDA kernel.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant User
    participant Python API
    participant TensorRT Network
    participant WanAttentionPlugin
    participant CUDA Kernel
    User->>Python API: Call wan_attention(...)
    Python API->>TensorRT Network: Add WanAttention plugin layer
    TensorRT Network->>WanAttentionPlugin: Instantiate plugin with params
    WanAttentionPlugin->>CUDA Kernel: Launch fused multi-head attention
    CUDA Kernel-->>WanAttentionPlugin: Return computed output
    WanAttentionPlugin-->>TensorRT Network: Output tensor
    TensorRT Network-->>User: Engine with WanAttention
```
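For context, the "Add WanAttention plugin layer" step resolves the plugin through TensorRT's plugin registry. A minimal sketch of that lookup, assuming the plugin registers as "WanAttention", version "1", in the "tensorrt_llm" namespace (names inferred from this PR, not confirmed by it):

```python
import tensorrt as trt

# Look up the plugin creator the way the network-building step would.
# The name/version/namespace strings below are assumptions.
registry = trt.get_plugin_registry()
creator = registry.get_plugin_creator("WanAttention", "1", "tensorrt_llm")
if creator is not None:
    print("WanAttention plugin creator is registered")
```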
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
Actionable comments posted: 25
🔭 Outside diff range comments (1)
tensorrt_llm/models/wan/model.py (1)

Lines 209-250: Move module-level code to a function. Like in model.py, this execution code should live in a function, not at module level.
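A minimal sketch of the suggested restructuring (the function name and body are illustrative, not taken from the PR):

```python
def main():
    # Statements previously executed at module level (engine build, inference,
    # debug prints) move here, so importing the module has no side effects.
    print("engine build and inference would run here")


if __name__ == "__main__":
    main()
```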
🧹 Nitpick comments (16)
cpp/tensorrt_llm/plugins/wanAttentionPlugin/CMakeLists.txt (1)

Lines 2-3: Update the copyright year to include the current year. The copyright header shows "1993-2022", but according to the coding guidelines, all TensorRT-LLM source files should carry an NVIDIA copyright header that includes the current year (2025).

Apply this diff to update the copyright year:

```diff
-# SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION &
+# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION &
```

cpp/tensorrt_llm/plugins/api/tllmPlugin.cpp (1)
Line 2: Update the copyright year to include the current year. The copyright header shows "1993-2022", but the coding guidelines require an NVIDIA copyright header that includes the current year (2025).

Apply this diff to update the copyright year:

```diff
- * SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
```

test_sage.py (1)
Lines 95-98: Consider removing the unused warmup loop. The warmup loop runs zero iterations, making it effectively dead code. If warmup is not needed, remove it for clarity.

Apply this diff to remove the unused loop:

```diff
-for i in range(0):
-    kernel(q, k, v, o, q_scale, k_scale, 1, _is_causal, _qk_quant_gran,
-           sm_scale, 0)
-torch.cuda.synchronize()
+torch.cuda.synchronize()
```

tensorrt_llm/functional.py (1)
Lines 4732-4736: Consider making the hardcoded context_fmha_type configurable. The context_fmha_type is hardcoded to the value 2, even though a commented-out line suggests it should come from default_net().plugin_config.context_fmha_type.

```diff
 context_fmha_type = trt.PluginField(
     "context_fmha_type",
-    # default_net().plugin_config.context_fmha_type
-    np.array(np.int8(2), dtype=np.int8),
+    np.array(np.int8(default_net().plugin_config.context_fmha_type), dtype=np.int8),
     trt.PluginFieldType.INT8)
```

If the plugin config doesn't have this attribute, consider adding a function parameter with a default value.
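A sketch of that fallback, assuming a plain integer parameter is added to the wrapping function (the actual wan_attention signature is not shown in this excerpt):

```python
import numpy as np
import tensorrt as trt


def make_context_fmha_field(context_fmha_type: int = 2) -> trt.PluginField:
    """Build the plugin field, letting callers override the current hardcoded value."""
    return trt.PluginField(
        "context_fmha_type",
        np.array(np.int8(context_fmha_type), dtype=np.int8),
        trt.PluginFieldType.INT8)
```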
model.py (4)

Line 12: Remove the debug flag from production code. The ADD_DEBUG_TENSOR flag appears to be for debugging but is not used anywhere in the code. Either remove it or properly implement its usage.

```diff
-ADD_DEBUG_TENSOR = True
```
Line 30: Address the incomplete RMS norm implementation. The comment indicates that qk_norm='rms_norm' is not used because rms_norm_cross_attention has not been released yet, which suggests incomplete functionality that should be tracked or implemented.

Would you like me to create an issue to track the implementation of rms_norm_cross_attention?
Line 121: Remove no-op statements. These expression statements have no effect and should be removed.

```diff
-    query.dtype
-    head_dim * self.heads
```

Also applies to: line 126.
Line 263: Fix line length violations. These lines exceed the 120-character limit specified in the coding guidelines. Break the long lines, or remove them if they are just debug output.

Also applies to: line 265.
test.py (2)

Line 32: Make logger verbosity configurable. The logger is set to VERBOSE, which can produce excessive output; consider making it configurable.

```diff
-logger = trt.Logger(trt.Logger.VERBOSE)
+log_level = os.environ.get("TRT_LOG_LEVEL", "INFO")
+log_levels = {
+    "VERBOSE": trt.Logger.VERBOSE,
+    "INFO": trt.Logger.INFO,
+    "WARNING": trt.Logger.WARNING,
+    "ERROR": trt.Logger.ERROR
+}
+logger = trt.Logger(log_levels.get(log_level, trt.Logger.INFO))
```
Lines 81-82: Use constants for file names. Hardcoded file names should be defined as constants or made configurable.

```diff
+ENGINE_FILE = "wan_attention.engine"
+
-with open("wan_attention.engine", "wb") as f:
+with open(ENGINE_FILE, "wb") as f:
     f.write(serialized_engine)
-with open("wan_attention.engine", "rb") as f, trt.Runtime(logger) as runtime:
+with open(ENGINE_FILE, "rb") as f, trt.Runtime(logger) as runtime:
```

tensorrt_llm/models/wan/model.py (2)
Line 49: Make the data type configurable. The dtype is hardcoded to trt.bfloat16; consider making it configurable to support different precision requirements.

```diff
-    dtype=trt.bfloat16,
+    dtype=trt.float16,  # or make it a parameter with a default
```
Line 115: Remove unused variable assignments. These variables are assigned but never used effectively.

```diff
-    q_dtype = query.dtype
+    # Store dtype for later casting if needed
+    q_dtype = query.dtype
-    inner_dim = head_dim * self.heads
+    # inner_dim = head_dim * self.heads  # Unused
```

Also applies to: line 120.
cpp/tensorrt_llm/plugins/wanAttentionPlugin/wanAttentionPlugin.h (1)

Line 121: Add a namespace closing comment. According to the coding guidelines, the closing brace of a namespace should carry a comment naming the namespace it closes.

```diff
-}
+} // namespace tensorrt_llm::plugins
```

cpp/tensorrt_llm/plugins/wanAttentionPlugin/wanAttentionPlugin.cpp (3)
Lines 1-16: Fix copyright header formatting. The header has incorrect line breaking; the SPDX copyright line should not be split.

```diff
 /*
- * SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION &
- * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
+ * SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
```
Lines 154-155: Define NUM_BUFFERS as a constant. NUM_BUFFERS should be a static constexpr class constant rather than a local variable.

```diff
-    int const NUM_BUFFERS = 1;
+    static constexpr int NUM_BUFFERS = 1;
```
Lines 447-448: Fix a misleading comment. The comment mentions "BertAttentionPlugin", but this is the WanAttentionPlugin.

```diff
 // This object will be deleted when the network is destroyed, which will
-// call BertAttentionPlugin::destroy()
+// call WanAttentionPlugin::destroy()
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (10)
- cpp/tensorrt_llm/plugins/CMakeLists.txt (1 hunks)
- cpp/tensorrt_llm/plugins/api/tllmPlugin.cpp (3 hunks)
- cpp/tensorrt_llm/plugins/wanAttentionPlugin/CMakeLists.txt (1 hunks)
- cpp/tensorrt_llm/plugins/wanAttentionPlugin/wanAttentionPlugin.cpp (1 hunks)
- cpp/tensorrt_llm/plugins/wanAttentionPlugin/wanAttentionPlugin.h (1 hunks)
- model.py (1 hunks)
- tensorrt_llm/functional.py (1 hunks)
- tensorrt_llm/models/wan/model.py (1 hunks)
- test.py (1 hunks)
- test_sage.py (1 hunks)
🧰 Additional context used
📓 Path-based instructions (4)
**/*.{cpp,h,hpp,cc,cxx}
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
**/*.{cpp,h,hpp,cc,cxx}: Closing braces of namespaces should have a comment saying the namespace it closes (e.g., } // namespace foo)
Prefer const or constexpr variables over #defines whenever possible, as the latter are not visible to the compiler.
A variable that is not modified after its initialization should be declared as const.
Except 0 (only used in comparison for checking signness/existence/emptiness) and nullptr, true, false, all other literals should only be used for variable initialization.
Use the Allman indentation style for braces in C++ code.
Put the semicolon for an empty for or while loop in a new line.
The statement forming the body of a switch, while, do .. while or for statement shall be a compound statement (use brace-delimited statements).
If and else should always be followed by brace-delimited statements, even if empty or a single statement.
C++ filenames should use camel case with first letter lowercase (e.g., thisIsAFilename.cpp), and all files involved in the compilation of a target must have filenames that are case-insensitive unique.
All types (including class names) in C++ should use camel case with uppercase first letter (e.g., FooBarClass).
Local variables, methods, and namespaces in C++ should use camel case with first letter lowercase (e.g., localFooBar).
Non-magic-number global variables that are non-static and not defined in anonymous namespace should use camel case prefixed by a lower case 'g' (e.g., gDontUseGlobalFoos).
Non-magic-number global variables that are static or defined in an anonymous namespace should use camel case prefixed by a lower case 's' (e.g., sMutableStaticGlobal).
Locally visible static variable should use camel case with lowercase prefix 's' as the first letter of the name (e.g., static std::once_flag sFlag;).
Class member variables should use camel case prefixed with an 'm' (e.g., mNbFooValues). Public member variables do not require the 'm' prefix but it is encouraged for clarity.
Enumerations, global c...
Files:
cpp/tensorrt_llm/plugins/api/tllmPlugin.cpp
cpp/tensorrt_llm/plugins/wanAttentionPlugin/wanAttentionPlugin.cpp
cpp/tensorrt_llm/plugins/wanAttentionPlugin/wanAttentionPlugin.h
**/*.{cpp,h,hpp,cc,cxx,cu,py}
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
All TensorRT-LLM Open Source Software code should contain an NVIDIA copyright header that includes the current year. This includes .cpp, .h, .cu, .py, and any other source files which are compiled or interpreted.
Files:
cpp/tensorrt_llm/plugins/api/tllmPlugin.cpp
test_sage.py
tensorrt_llm/models/wan/model.py
model.py
tensorrt_llm/functional.py
test.py
cpp/tensorrt_llm/plugins/wanAttentionPlugin/wanAttentionPlugin.cpp
cpp/tensorrt_llm/plugins/wanAttentionPlugin/wanAttentionPlugin.h
**/*.py
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
**/*.py: The code developed for TensorRT-LLM should conform to Python 3.8+.
Indent Python code with 4 spaces. Do not use tabs.
Always maintain the namespace when importing in Python, even if only one class or function from a module is used.
Python filenames should use snake_case (e.g., some_file.py).
Python classes should use PascalCase (e.g., class SomeClass).
Python functions and methods should use snake_case (e.g., def my_awesome_function():).
Python local variables should use snake_case. Prefix k for variable names that start with a number (e.g., k_99th_percentile = ...).
Python global variables should use upper snake_case and prefix G (e.g., G_MY_GLOBAL = ...).
Python constants should use upper snake_case (e.g., MY_CONSTANT = ...).
Avoid shadowing variables declared in an outer scope in Python.
Initialize all externally visible members of a class in the constructor in Python.
For interfaces that may be used outside a file, prefer docstrings over comments in Python.
Comments in Python should be reserved for code within a function, or interfaces that are local to a file.
Use Google style docstrings for classes and functions in Python, which can be parsed by Sphinx.
Attributes and variables in Python can be documented inline; attribute docstrings will be rendered under the docstring for the class.
Avoid using reflection in Python when functionality can be easily achieved without it.
When using try-except blocks in Python, limit the except to the smallest set of errors possible.
When using try-except blocks to handle multiple possible variable types in Python, keep the body of the try as small as possible, using the else block to implement the logic.
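As a quick illustration of the try/except/else guideline above, a minimal sketch (the function and values are invented for illustration):

```python
def parse_timeout(raw_value: str) -> float:
    """Parse a timeout string, falling back to 30 seconds on bad input.

    Args:
        raw_value: The raw timeout value, e.g. from an environment variable.

    Returns:
        The timeout in seconds, never negative.
    """
    try:
        timeout = float(raw_value)  # keep the try body as small as possible
    except ValueError:
        return 30.0
    else:
        return max(timeout, 0.0)  # the actual logic lives in the else block
```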
Files:
test_sage.py
tensorrt_llm/models/wan/model.py
model.py
tensorrt_llm/functional.py
test.py
**/*.{h,hpp}
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
Use a preprocessor guard in header files. The guard name must have prefix TRTLLM_ followed by the filename, all in caps, and no trailing underscore.
Files:
cpp/tensorrt_llm/plugins/wanAttentionPlugin/wanAttentionPlugin.h
🧠 Learnings (3)
cpp/tensorrt_llm/plugins/api/tllmPlugin.cpp (1)
Learnt from: CR
PR: NVIDIA/TensorRT-LLM#0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-07-29T08:25:52.877Z
Learning: Applies to **/*.{cpp,h,hpp,cc,cxx,cu,py} : All TensorRT-LLM Open Source Software code should contain an NVIDIA copyright header that includes the current year. This includes .cpp, .h, .cu, .py, and any other source files which are compiled or interpreted.
cpp/tensorrt_llm/plugins/wanAttentionPlugin/CMakeLists.txt (1)
Learnt from: CR
PR: NVIDIA/TensorRT-LLM#0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-07-29T08:25:52.877Z
Learning: Applies to **/*.{cpp,h,hpp,cc,cxx,cu,py} : All TensorRT-LLM Open Source Software code should contain an NVIDIA copyright header that includes the current year. This includes .cpp, .h, .cu, .py, and any other source files which are compiled or interpreted.
test.py (3)
Learnt from: moraxu
PR: #6303
File: tests/integration/test_lists/qa/examples_test_list.txt:494-494
Timestamp: 2025-07-28T17:06:08.621Z
Learning: In TensorRT-LLM testing, it's common to have both CLI flow tests (test_cli_flow.py) and PyTorch API tests (test_llm_api_pytorch.py) for the same model. These serve different purposes: CLI flow tests validate the traditional command-line workflow, while PyTorch API tests validate the newer LLM API backend. Both are legitimate and should coexist.
Learnt from: CR
PR: NVIDIA/TensorRT-LLM#0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-07-29T08:25:52.877Z
Learning: Applies to **/*.py : The code developed for TensorRT-LLM should conform to Python 3.8+.
Learnt from: amitz-nv
PR: #5616
File: tensorrt_llm/executor/worker.py:375-384
Timestamp: 2025-07-17T09:01:27.402Z
Learning: In tensorrt_llm/executor/worker.py, the LoRA adapter cache optimization logic that checks is_adapter_in_cpu_cache()
and conditionally passes None for weights/config has a known race condition issue that cannot be solved with simple error handling or verification checks. This is a known limitation that requires a more comprehensive solution.
🪛 Ruff (0.12.2)
model.py
263-263: Line too long (428 > 120)
(E501)
265-265: Line too long (437 > 120)
(E501)
🔇 Additional comments (12)
cpp/tensorrt_llm/plugins/wanAttentionPlugin/CMakeLists.txt (1)
Lines 17-21: LGTM! The CMakeLists.txt structure correctly integrates the plugin into the build system by gathering source files and exporting them to the parent scope.
cpp/tensorrt_llm/plugins/CMakeLists.txt (1)
Line 51: LGTM! The wanAttentionPlugin is correctly added to PLUGIN_LISTS in alphabetical order, following the established pattern for plugin integration.
cpp/tensorrt_llm/plugins/api/tllmPlugin.cpp (3)
Line 37: LGTM! The WanAttentionPlugin header is correctly included, following the established pattern.
Line 212: LGTM! The WanAttentionPluginCreator static instance is correctly declared, following the established naming convention and pattern.
Line 255: LGTM! The wanAttentionPluginCreator is correctly added to the pluginCreators array, maintaining the established order and pattern.
test_sage.py (4)
Lines 7-20: LGTM! The argument parser configuration follows Python best practices, with appropriate defaults, type specifications, and help text.
Line 26: LGTM! The assertion appropriately guards against unsupported accumulation modes with a clear error message.
Lines 45-47: LGTM! The FLOPS calculation correctly accounts for the attention operations: 4 * batch * heads * head_dim * seq_len * seq_len_kv, with the appropriate adjustment for causal attention.
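For reference, a small sketch of that count; the factor 4 comes from the two matmuls (QK^T and PV) at 2*M*N*K FLOPs each, and the halving for causal masks is the usual approximation, assumed here rather than taken from the test:

```python
def attention_flops(batch, heads, head_dim, seq_len, seq_len_kv, is_causal=False):
    # Two matmuls (Q @ K^T and P @ V), each costing 2*M*N*K FLOPs.
    flops = 4 * batch * heads * head_dim * seq_len * seq_len_kv
    # A causal mask skips roughly half of each score matrix.
    return flops / 2 if is_causal else flops


print(attention_flops(batch=2, heads=16, head_dim=64, seq_len=1024, seq_len_kv=1024))
```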
Lines 99-114: LGTM! The benchmarking setup using flash_attn's benchmark_forward is appropriate for performance measurement, with proper CUDA synchronization and a single iteration for accurate timing.
tensorrt_llm/functional.py (1)
Lines 4738-4741: No change needed: bert_attention_plugin is the correct config for WanAttention.

The PluginConfig.bert_attention_plugin field is intentionally used to configure the WanAttention plugin. There is no separate wan_attention_plugin property; models that rely on WanAttention (e.g. in tensorrt_llm/models/wan/model.py) assert and consume plugin_config.bert_attention_plugin, and the factory in functional.py correctly reads that same field.

- In tensorrt_llm/models/wan/model.py (around line 154), the model asserts plugin_config.bert_attention_plugin is not None to enable WanAttention.
- In tensorrt_llm/functional.py, the plugin builder for "WanAttention" pulls its dtype from that same plugin_config.bert_attention_plugin.

Ignore the suggestion to use a non-existent wan_attention_plugin.

Likely an incorrect or invalid review comment.
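For reference, the relationship described above can be sketched as follows (the import path and dtype string are assumptions for illustration, not taken from this diff):

```python
from tensorrt_llm.plugin import PluginConfig  # import path assumed

plugin_config = PluginConfig()
plugin_config.bert_attention_plugin = "bfloat16"  # this same field gates WanAttention
assert plugin_config.bert_attention_plugin is not None
```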
test.py (1)
Line 183: Fix incorrect assertion error messages. The assertion uses incorrect syntax for dimension checking.
cpp/tensorrt_llm/plugins/wanAttentionPlugin/wanAttentionPlugin.h (1)
Lines 78-98: Follow the member variable naming convention. According to the coding guidelines, class member variables should use camelCase prefixed with 'm'. Some members (like mLayerName) clearly follow this; on verification, the remaining member variables also use the 'm' prefix correctly, so no change is needed.
printf("inner dim of q,k and v should match!"); | ||
ret = true; | ||
} | ||
if (batch_size + 1 != inputs[2].dims.d[0]) |
🛠️ Refactor suggestion
Replace printf with proper error handling.
Use TLLM_CHECK or proper logging instead of printf for error messages.
```diff
-    if (inputs[0].dims.d[3] != inputs[1].dims.d[4])
-    {
-        printf("inner dim of q,k and v should match!");
-        ret = true;
-    }
+    TLLM_CHECK_WITH_INFO(inputs[0].dims.d[3] == inputs[1].dims.d[4],
+        "Inner dimensions of query and key/value must match");
-    if (batch_size + 1 != inputs[2].dims.d[0])
-    {
-        printf("cu seqlen_q should have dim of (batch_size + 1)");
-        ret = true;
-    }
+    TLLM_CHECK_WITH_INFO(batch_size + 1 == inputs[2].dims.d[0],
+        "cu_seqlen_q should have dimension of (batch_size + 1)");
-    if (batch_size + 1 != inputs[3].dims.d[0])
-    {
-        printf("cu seqlen_kv should have dim of (batch_size + 1)");
-        ret = true;
-    }
+    TLLM_CHECK_WITH_INFO(batch_size + 1 == inputs[3].dims.d[0],
+        "cu_seqlen_kv should have dimension of (batch_size + 1)");
```
Also applies to: 138-141
```cpp
    if (ret)
    {
        return -1;
    }
```
Remove error-prone return value.
Returning -1 for size_t (unsigned) will result in a very large number, not a negative value.
```diff
-    if (ret)
-    {
-        return -1;
-    }
+    // Errors are now handled by TLLM_CHECK above
```
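To see the wraparound concretely, here is a quick check in Python via NumPy (illustrating the C++ behavior, since size_t is an unsigned 64-bit integer on common platforms):

```python
import numpy as np

# -1 reinterpreted as an unsigned 64-bit integer becomes the maximum
# representable value rather than an error sentinel.
print(np.int64(-1).astype(np.uint64))  # 18446744073709551615
```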
```cpp
    assert(inputs[0].dims.d[3] == inputs[1].dims.d[4] && "inner dim of q,k and v should match!");
    assert(batch_size + 1 == inputs[2].dims.d[0] && "cu seqlen should have dim of (batch_size + 1)");
    assert(batch_size + 1 == inputs[3].dims.d[0] && "cu seqlen should have dim of (batch_size + 1)");
```
🛠️ Refactor suggestion
Replace assert with TLLM_CHECK.
Use TLLM_CHECK instead of assert for consistency with the codebase.
```diff
-    assert(inputs[0].dims.d[3] == inputs[1].dims.d[4] && "inner dim of q,k and v should match!");
-    assert(batch_size + 1 == inputs[2].dims.d[0] && "cu seqlen should have dim of (batch_size + 1)");
-    assert(batch_size + 1 == inputs[3].dims.d[0] && "cu seqlen should have dim of (batch_size + 1)");
+    TLLM_CHECK(inputs[0].dims.d[3] == inputs[1].dims.d[4]);
+    TLLM_CHECK(batch_size + 1 == inputs[2].dims.d[0]);
+    TLLM_CHECK(batch_size + 1 == inputs[3].dims.d[0]);

-    assert(mEnableContextFMHA && "mEnableContextFMHA is false!");
+    TLLM_CHECK_WITH_INFO(mEnableContextFMHA, "Context FMHA must be enabled for this plugin");
```
Also applies to: 203-203
printf("\n%s\n", fmhaParams.convertToStrOutput().c_str()); | ||
printf("\nmEnableContextFMHA: %s\n", mEnableContextFMHA ? "true" : "false"); |
🛠️ Refactor suggestion
Remove debug printf statements.
Debug output should use proper logging mechanisms, not printf.
```diff
-    printf("\n%s\n", fmhaParams.convertToStrOutput().c_str());
-    printf("\nmEnableContextFMHA: %s\n", mEnableContextFMHA ? "true" : "false");
+    TLLM_LOG_DEBUG("FMHA params: %s", fmhaParams.convertToStrOutput().c_str());
+    TLLM_LOG_DEBUG("mEnableContextFMHA: %s", mEnableContextFMHA ? "true" : "false");
```
```cpp
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once
```
🛠️ Refactor suggestion
Use proper include guard format.
According to the coding guidelines, header files should use a preprocessor guard with prefix TRTLLM_
followed by the filename in caps. Replace #pragma once
with the proper guard.
```diff
-#pragma once
+#ifndef TRTLLM_WANATTENTIONPLUGIN_H
+#define TRTLLM_WANATTENTIONPLUGIN_H
```

And add at the end of the file:

```diff
 } // namespace tensorrt_llm::plugins
+
+#endif // TRTLLM_WANATTENTIONPLUGIN_H
```
```python
attn(hidden_states, encoder_hidden_states)

config = builder.create_builder_config()
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 25)  # 1 MiB
```
Fix incorrect workspace size comment.
The comment says "1 MiB", but 1 << 25 is 2^25 = 33,554,432 bytes, i.e. 32 MiB.

```diff
-config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 25)  # 1 MiB
+config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 25)  # 32 MiB
```
```diff
@@ -0,0 +1,114 @@
 import argparse
```
Add required NVIDIA copyright header.
According to coding guidelines, all TensorRT-LLM source files including Python files should contain an NVIDIA copyright header that includes the current year.
Add this copyright header at the beginning of the file:
```diff
+#
+# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
 import argparse
```
```python
import tensorrt_llm as tllm
from tensorrt_llm.functional import wan_attention

os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # To use GPU 0
```
🛠️ Refactor suggestion
Avoid setting environment variables in code.
Setting CUDA_VISIBLE_DEVICES
in the code can cause issues in multi-GPU environments. This should be handled externally or made configurable.
```diff
-os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # To use GPU 0
+# Set CUDA_VISIBLE_DEVICES externally or make it configurable
+# gpu_id = os.environ.get("CUDA_VISIBLE_DEVICES", "0")
```
```python
builder = trt.Builder(logger)
config = builder.create_builder_config()
config.set_flag(trt.BuilderFlag.FP16)
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 35)  # 1 MiB
```
Fix workspace size comment.
The comment indicates "1 MiB", but 1 << 35 is 2^35 bytes, i.e. 32 GiB.

```diff
-config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 35)  # 1 MiB
+config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 35)  # 32 GiB
```
```python
# print(hidden_states_true)
# print(hidden_states)

# print(torch.max(hidden_states_true), torch.min(hidden_states_true))
# print(torch.max(hidden_states), torch.min(hidden_states))

# torch.testing.assert_close(hidden_states.float().transpose(1,2), hidden_states_true.float(), rtol=0.001, atol=0.001)
```
Uncomment and fix test assertions.
The test assertions are commented out, which means the test doesn't actually validate the output. These should be enabled for proper testing.
```diff
-cuda_call(cudart.cudaStreamSynchronize(stream))
-# print(hidden_states_true)
-# print(hidden_states)
-
-# print(torch.max(hidden_states_true), torch.min(hidden_states_true))
-# print(torch.max(hidden_states), torch.min(hidden_states))
-
-# torch.testing.assert_close(hidden_states.float().transpose(1,2), hidden_states_true.float(), rtol=0.001, atol=0.001)
+cuda_call(cudart.cudaStreamSynchronize(stream))
+
+# Transpose hidden_states to match the expected output format
+hidden_states_transposed = hidden_states.transpose(1, 2)
+
+# Validate output
+torch.testing.assert_close(
+    hidden_states_transposed.float(),
+    hidden_states_true.float(),
+    rtol=0.01,  # Relaxed tolerance for numerical stability
+    atol=0.01
+)
+print("Test passed: WanAttention output matches PyTorch reference")
```
Committable suggestion skipped: line range outside the PR's diff.