Skip to content

Conversation

lancelly
Copy link
Collaborator

@lancelly lancelly commented Aug 8, 2025

We try to use PyBind11 to bind the attentionOp functions instead of relying on custom torch operator, thereby removing the parameter-count restriction on attention_inplace.

Summary by CodeRabbit

  • New Features

    • Introduced new Python bindings for advanced attention operations, enabling direct access to multi-head attention and capability query functions from Python.
    • Added a new internal submodule for attention operations, accessible via Python.
  • Refactor

    • Updated the attention backend to use the new Python bindings for attention operations instead of previous operator registrations.
  • Bug Fixes

    • Removed obsolete operator registrations to streamline and modernize the attention operation integration.

@lancelly lancelly requested a review from a team as a code owner August 8, 2025 07:11
@lancelly lancelly requested a review from QiJune August 8, 2025 07:11
Copy link
Contributor

coderabbitai bot commented Aug 8, 2025

📝 Walkthrough

Note

🔌 MCP (Model Context Protocol) integration is now available in Early Access!

Pro users can now connect to remote MCP servers under the Integrations page to get reviews and chat conversations that understand additional development context.

✨ Finishing Touches
  • 📝 Generate Docstrings
🧪 Generate unit tests
  • Create PR with unit tests
  • Post copyable unit tests in a comment

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share
🪧 Tips

Chat

There are 3 ways to chat with CodeRabbit:

  • Review comments: Directly reply to a review comment made by CodeRabbit. Example:
    • I pushed a fix in commit <commit_id>, please review it.
    • Explain this complex logic.
    • Open a follow-up GitHub issue for this discussion.
  • Files and specific lines of code (under the "Files changed" tab): Tag @coderabbitai in a new review comment at the desired location with your query. Examples:
    • @coderabbitai explain this code block.
  • PR comments: Tag @coderabbitai in a new PR comment to ask questions about the PR branch. For the best results, please provide a very specific query, as very limited context is provided in this mode. Examples:
    • @coderabbitai gather interesting stats about this repository and render them as a table. Additionally, render a pie chart showing the language distribution in the codebase.
    • @coderabbitai read src/utils.ts and explain its main purpose.
    • @coderabbitai read the files in the src/scheduler package and generate a class diagram using mermaid and a README in the markdown format.

Support

Need help? Create a ticket on our support page for assistance with any issues or questions.

CodeRabbit Commands (Invoked using PR comments)

  • @coderabbitai pause to pause the reviews on a PR.
  • @coderabbitai resume to resume the paused reviews.
  • @coderabbitai review to trigger an incremental review. This is useful when automatic reviews are disabled for the repository.
  • @coderabbitai full review to do a full review from scratch and review all the files again.
  • @coderabbitai summary to regenerate the summary of the PR.
  • @coderabbitai generate docstrings to generate docstrings for this PR.
  • @coderabbitai generate sequence diagram to generate a sequence diagram of the changes in this PR.
  • @coderabbitai generate unit tests to generate unit tests for this PR.
  • @coderabbitai resolve resolve all the CodeRabbit review comments.
  • @coderabbitai configuration to show the current CodeRabbit configuration for the repository.
  • @coderabbitai help to get help.

Other keywords and placeholders

  • Add @coderabbitai ignore anywhere in the PR description to prevent this PR from being reviewed.
  • Add @coderabbitai summary to generate the high-level summary at a specific location in the PR description.
  • Add @coderabbitai or @coderabbitai title anywhere in the PR title to generate the title automatically.

Documentation and Community

  • Visit our Documentation for detailed information on how to use CodeRabbit.
  • Join our Discord Community to get help, request features, and share feedback.
  • Follow us on X/Twitter for updates and announcements.

@svc-trtllm-gh-bot svc-trtllm-gh-bot added the Community want to contribute PRs initiated from Community label Aug 8, 2025
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 5

🔭 Outside diff range comments (1)
tensorrt_llm/_torch/attention_backend/trtllm.py (1)

515-524: Also switch supports check to new thop binding to avoid missing op

attention_supports_nvfp4_output still uses torch.ops.trtllm. If the torch custom op was removed, this will break at runtime. Update to call the new pybind function.

Apply:

-        return torch.ops.trtllm.attention_supports_nvfp4_output(
+        return trtllm_thop.attention_supports_nvfp4_output(
             self.num_heads,
             self.num_kv_heads,
             self.head_size,
             tokens_per_block,
             int(mask_type),
             self.quant_mode,
             use_paged_context_fmha,
             is_mla_enable,
         )
🧹 Nitpick comments (6)
cpp/tensorrt_llm/pybind/thop/attentionOp.h (1)

2-2: Update copyright year

Per project guidelines, include the current year (2025).

- * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
cpp/tensorrt_llm/thop/attentionOp.cpp (1)

24-24: Redundant include – but harmless

attentionOp.cpp already includes the implementation header (tensorrt_llm/thop/attentionOp.cpp) further below. Adding the public declaration header here has no functional impact but adds an extra compilation unit dependency and slows incremental builds.
Unless a new interface from the header is actually needed in this TU, consider dropping the line.

cpp/tensorrt_llm/pybind/thop/attentionOp.cpp (3)

63-68: Cheap construction of c10::List

The manual push_back loop is fine but can be simplified:

c10::List<bool> cpp_spec_decoding_bool_list;
cpp_spec_decoding_bool_list.reserve(spec_decoding_bool_params.size());
for (auto v : spec_decoding_bool_params) { cpp_spec_decoding_bool_list.emplace_back(v); }

(or construct from ArrayRef). Minor, purely readability.


85-110: Very long argument lines exceed 120-char limit

Several of the py::arg lines break the project’s 120-character rule, e.g. lines 97-103.
Please clang-format or wrap.


112-116: Return-value policy missing

attention_supports_nvfp4_output returns a plain bool; default policy works, but explicitly adding py::return_value_policy::automatic makes intent clear and silences some linters.

cpp/tensorrt_llm/thop/attentionOp.h (1)

38-62: Monster signature – consider a config struct

attention_inplace now takes >80 parameters. Even with pybind this hurts readability, maintenance and ABI stability. Encapsulate stable groups (e.g. model hyper-params, optional tensors) into small POD structs passed by const-ref.

This is an architectural suggestion and not a blocker for this PR but will pay off quickly.

📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between d913955 and 1d86697.

📒 Files selected for processing (7)
  • cpp/tensorrt_llm/pybind/CMakeLists.txt (2 hunks)
  • cpp/tensorrt_llm/pybind/bindings.cpp (2 hunks)
  • cpp/tensorrt_llm/pybind/thop/attentionOp.cpp (1 hunks)
  • cpp/tensorrt_llm/pybind/thop/attentionOp.h (1 hunks)
  • cpp/tensorrt_llm/thop/attentionOp.cpp (1 hunks)
  • cpp/tensorrt_llm/thop/attentionOp.h (1 hunks)
  • tensorrt_llm/_torch/attention_backend/trtllm.py (2 hunks)
🧰 Additional context used
📓 Path-based instructions (4)
**/*.{cpp,h,hpp,cc,cxx}

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

**/*.{cpp,h,hpp,cc,cxx}: Closing braces of namespaces should have a comment saying the namespace it closes (e.g., } // namespace foo).
Prefer const or constexpr variables over #defines whenever possible.
A variable that is not modified after its initialization should be declared as const.
Except 0 (used for checking signness/existence/emptiness), nullptr, true, false, all other literals should only be used for variable initialization.
Use the Allman indentation style for braces in C++ code.
Put the semicolon for an empty for or while loop in a new line.
The statement forming the body of a switch, while, do..while, or for statement shall be a compound statement (use brace-delimited statements).
If and else should always be followed by brace-delimited statements, even if empty or a single statement.
C++ filenames should use camel case with the first letter lowercase (e.g., thisIsAFilename.cpp), and all files involved in a compilation target must have case-insensitive unique filenames.
All types (including class names) should use camel case with uppercase first letter (e.g., FooBarClass).
Local variables, methods, and namespaces should use camel case with first letter lowercase (e.g., localFooBar).
Non-magic-number global variables that are non-static and not defined in anonymous namespace should use camel case prefixed by 'g' (e.g., gDontUseGlobalFoos).
Non-magic-number global variables that are static or defined in an anonymous namespace should use camel case prefixed by 's' (e.g., sMutableStaticGlobal).
Locally visible static variables should use camel case with lowercase prefix 's' as the first letter (e.g., static std::once_flag sFlag;).
Class member variables should use camel case prefixed with 'm' (e.g., mNbFooValues). Public member variables do not require the 'm' prefix but it is encouraged for clarity.
Enumerations, global constants, static constants at class-scope, and function-scope magic-number/literal constants should be uppercase snake case with prefix...

Files:

  • cpp/tensorrt_llm/thop/attentionOp.cpp
  • cpp/tensorrt_llm/pybind/bindings.cpp
  • cpp/tensorrt_llm/pybind/thop/attentionOp.h
  • cpp/tensorrt_llm/pybind/thop/attentionOp.cpp
  • cpp/tensorrt_llm/thop/attentionOp.h
**/*.{cpp,h,hpp,cc,cxx,cu,py}

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

All TensorRT-LLM Open Source Software code should contain an NVIDIA copyright header that includes the current year. This includes .cpp, .h, .cu, .py, and any other source files which are compiled or interpreted.

Files:

  • cpp/tensorrt_llm/thop/attentionOp.cpp
  • cpp/tensorrt_llm/pybind/bindings.cpp
  • cpp/tensorrt_llm/pybind/thop/attentionOp.h
  • tensorrt_llm/_torch/attention_backend/trtllm.py
  • cpp/tensorrt_llm/pybind/thop/attentionOp.cpp
  • cpp/tensorrt_llm/thop/attentionOp.h
**/*.{h,hpp}

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

Use a preprocessor guard in header files. The guard name must have prefix TRTLLM_ followed by the filename, all in caps, and no trailing underscore.

Files:

  • cpp/tensorrt_llm/pybind/thop/attentionOp.h
  • cpp/tensorrt_llm/thop/attentionOp.h
**/*.py

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

**/*.py: Python code should conform to Python 3.8+.
Indent Python code with 4 spaces. Do not use tabs.
Always maintain the namespace when importing in Python, even if only one class or function from a module is used.
Python filenames should use snake_case (e.g., some_file.py).
Python classes should use PascalCase (e.g., class SomeClass).
Python functions and methods should use snake_case (e.g., def my_awesome_function():).
Python local variables should use snake_case. Prefix k for variable names that start with a number (e.g., k_99th_percentile).
Python global variables should use upper snake_case and prefix G (e.g., G_MY_GLOBAL).
Python constants should use upper snake_case (e.g., MY_CONSTANT).
Avoid shadowing variables declared in an outer scope in Python.
Initialize all externally visible members of a Python class in the constructor.
For interfaces that may be used outside a Python file, prefer docstrings over comments.
Comments in Python should be reserved for code within a function, or interfaces that are local to a file.
Use Google style docstrings for Python classes and functions, which can be parsed by Sphinx.
Attributes and variables in Python can be documented inline; attribute docstrings will be rendered under the class docstring.
Avoid using reflection in Python when functionality can be easily achieved without it.
When using try-except blocks in Python, limit the except to the smallest set of errors possible.
When using try-except blocks to handle multiple possible variable types in Python, keep the body of the try as small as possible, using the else block to implement the logic.

Files:

  • tensorrt_llm/_torch/attention_backend/trtllm.py
🧠 Learnings (10)
📚 Learning: 2025-08-06T21:22:55.018Z
Learnt from: CR
PR: NVIDIA/TensorRT-LLM#0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-08-06T21:22:55.018Z
Learning: Applies to **/*.{cpp,h,hpp,cc,cxx,cu,py} : All TensorRT-LLM Open Source Software code should contain an NVIDIA copyright header that includes the current year. This includes .cpp, .h, .cu, .py, and any other source files which are compiled or interpreted.

Applied to files:

  • cpp/tensorrt_llm/pybind/CMakeLists.txt
  • cpp/tensorrt_llm/thop/attentionOp.cpp
  • cpp/tensorrt_llm/pybind/bindings.cpp
  • cpp/tensorrt_llm/pybind/thop/attentionOp.h
  • cpp/tensorrt_llm/pybind/thop/attentionOp.cpp
📚 Learning: 2025-08-08T05:06:31.537Z
Learnt from: sklevtsov-nvidia
PR: NVIDIA/TensorRT-LLM#3294
File: cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/fusion/sm90_visitor_scatter.hpp:36-36
Timestamp: 2025-08-08T05:06:31.537Z
Learning: CUTLASS extension files (under cpp/tensorrt_llm/cutlass_extensions/) follow CUTLASS coding style conventions, including using #pragma once instead of TRTLLM_ prefixed header guards, even though they are .hpp files.

Applied to files:

  • cpp/tensorrt_llm/pybind/CMakeLists.txt
  • cpp/tensorrt_llm/thop/attentionOp.cpp
  • cpp/tensorrt_llm/pybind/bindings.cpp
  • cpp/tensorrt_llm/pybind/thop/attentionOp.h
  • cpp/tensorrt_llm/pybind/thop/attentionOp.cpp
  • cpp/tensorrt_llm/thop/attentionOp.h
📚 Learning: 2025-08-06T13:58:07.506Z
Learnt from: galagam
PR: NVIDIA/TensorRT-LLM#6487
File: tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_trtllm_bench.py:1-12
Timestamp: 2025-08-06T13:58:07.506Z
Learning: In TensorRT-LLM, test files (files under tests/ directories) do not require NVIDIA copyright headers, unlike production source code files. Test files typically start directly with imports, docstrings, or code.

Applied to files:

  • cpp/tensorrt_llm/pybind/CMakeLists.txt
  • cpp/tensorrt_llm/thop/attentionOp.cpp
  • cpp/tensorrt_llm/pybind/bindings.cpp
  • cpp/tensorrt_llm/pybind/thop/attentionOp.h
  • tensorrt_llm/_torch/attention_backend/trtllm.py
📚 Learning: 2025-08-06T21:22:55.018Z
Learnt from: CR
PR: NVIDIA/TensorRT-LLM#0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-08-06T21:22:55.018Z
Learning: Applies to **/*.{cpp,h,hpp,cc,cxx} : Use smart pointers for allocating objects on the heap in C++.

Applied to files:

  • cpp/tensorrt_llm/pybind/CMakeLists.txt
📚 Learning: 2025-08-06T21:22:55.018Z
Learnt from: CR
PR: NVIDIA/TensorRT-LLM#0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-08-06T21:22:55.018Z
Learning: Applies to **/*.{cpp,h,hpp,cc,cxx} : Do not use smart pointers that have been deprecated in C++11.

Applied to files:

  • cpp/tensorrt_llm/pybind/CMakeLists.txt
📚 Learning: 2025-08-06T21:22:55.018Z
Learnt from: CR
PR: NVIDIA/TensorRT-LLM#0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-08-06T21:22:55.018Z
Learning: Applies to **/*.{h,hpp} : Use a preprocessor guard in header files. The guard name must have prefix TRTLLM_ followed by the filename, all in caps, and no trailing underscore.

Applied to files:

  • cpp/tensorrt_llm/pybind/CMakeLists.txt
  • cpp/tensorrt_llm/thop/attentionOp.cpp
📚 Learning: 2025-08-08T04:10:18.987Z
Learnt from: djns99
PR: NVIDIA/TensorRT-LLM#6728
File: cpp/tensorrt_llm/plugins/mixtureOfExperts/mixtureOfExpertsPlugin.cpp:966-966
Timestamp: 2025-08-08T04:10:18.987Z
Learning: TensorRT plugins currently don't support padding functionality, and TensorRT is not getting new features (in maintenance mode). This means that duplicating parameters like mExpertHiddenSize in function calls, even with TODO comments, can be acceptable as pragmatic solutions within these constraints.

Applied to files:

  • cpp/tensorrt_llm/thop/attentionOp.cpp
  • tensorrt_llm/_torch/attention_backend/trtllm.py
  • cpp/tensorrt_llm/pybind/thop/attentionOp.cpp
  • cpp/tensorrt_llm/thop/attentionOp.h
📚 Learning: 2025-07-28T17:06:08.621Z
Learnt from: moraxu
PR: NVIDIA/TensorRT-LLM#6303
File: tests/integration/test_lists/qa/examples_test_list.txt:494-494
Timestamp: 2025-07-28T17:06:08.621Z
Learning: In TensorRT-LLM testing, it's common to have both CLI flow tests (test_cli_flow.py) and PyTorch API tests (test_llm_api_pytorch.py) for the same model. These serve different purposes: CLI flow tests validate the traditional command-line workflow, while PyTorch API tests validate the newer LLM API backend. Both are legitimate and should coexist.

Applied to files:

  • cpp/tensorrt_llm/thop/attentionOp.cpp
  • cpp/tensorrt_llm/pybind/bindings.cpp
  • tensorrt_llm/_torch/attention_backend/trtllm.py
📚 Learning: 2025-08-08T05:10:38.893Z
Learnt from: sklevtsov-nvidia
PR: NVIDIA/TensorRT-LLM#3294
File: cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/fusion/sm90_visitor_scatter.hpp:0-0
Timestamp: 2025-08-08T05:10:38.893Z
Learning: The ScaledAccPerRowBiasPerColScaleScatter fusion in CUTLASS extensions (cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/fusion/sm90_visitor_scatter.hpp) is specifically designed for per-column scaling factors only, so it uses a fixed Stride<_0,_1,int64_t> rather than conditional stride logic.

Applied to files:

  • cpp/tensorrt_llm/thop/attentionOp.cpp
📚 Learning: 2025-08-01T15:14:45.673Z
Learnt from: yibinl-nvidia
PR: NVIDIA/TensorRT-LLM#6506
File: examples/models/core/mixtral/requirements.txt:3-3
Timestamp: 2025-08-01T15:14:45.673Z
Learning: In TensorRT-LLM, examples directory can have different dependency versions than the root requirements.txt file. Version conflicts between root and examples dependencies are acceptable because examples are designed to be standalone and self-contained.

Applied to files:

  • tensorrt_llm/_torch/attention_backend/trtllm.py
🔇 Additional comments (6)
cpp/tensorrt_llm/pybind/CMakeLists.txt (1)

21-21: SRCS update to include thop/attentionOp.cpp looks correct

This integrates the new pybind bindings into the module build. Path and placement are fine.

cpp/tensorrt_llm/pybind/bindings.cpp (3)

36-36: Include of thop attentionOp bindings header

Header path and namespace intent look good.


120-120: Add internal.thop submodule

Appropriate submodule name and placement under internal. Good.


124-124: Initialize thop attention bindings

Calling tensorrt_llm::pybind::thop::attentionOp::initBindings(mInternalThop); wires the new functions into the module as intended.

tensorrt_llm/_torch/attention_backend/trtllm.py (2)

10-10: Switch import to new pybind thop module

Importing tensorrt_llm.bindings.internal.thop as trtllm_thop is consistent with the new binding location.


420-484: Route attention call through pybind (thop) instead of torch.ops

The callsite migration to trtllm_thop.attention_inplace aligns with the PR’s objective and preserves the previous argument flow.

If the new binding relaxed the parameter-count restriction, please confirm that no argument truncation or defaulted parameters are occurring at this callsite.

@lancelly lancelly force-pushed the feat/pybind_attentionOp branch from 1d86697 to 40c8a01 Compare August 8, 2025 08:29
@lancelly
Copy link
Collaborator Author

lancelly commented Aug 8, 2025

/bot run

1 similar comment
@lancelly
Copy link
Collaborator Author

lancelly commented Aug 8, 2025

/bot run

@tensorrt-cicd
Copy link
Collaborator

PR_Github #14606 [ run ] triggered by Bot

@lancelly
Copy link
Collaborator Author

lancelly commented Aug 8, 2025

/bot run

@tensorrt-cicd
Copy link
Collaborator

PR_Github #14609 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #14606 [ run ] completed with state ABORTED

@tensorrt-cicd
Copy link
Collaborator

PR_Github #14609 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #11035 completed with status: 'FAILURE'

@lancelly
Copy link
Collaborator Author

lancelly commented Aug 8, 2025

/bot run

@tensorrt-cicd
Copy link
Collaborator

PR_Github #14616 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #14616 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #11041 completed with status: 'FAILURE'

Copy link
Collaborator

@jhaotingc jhaotingc left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM.
General q: I packed rotary_embedding_scales, rotary_embedding_max_position_info, spec-dec related params into vectors before.
It's a design choice to leave it as it is or unpacking it, given that it required some packing logics in trtllm.py and create API discrepancy between attention.py and trtllm.py.
It can be another chore in separate PR.

Thanks for the work!

Signed-off-by: Lanyu Liao <[email protected]>
@lancelly
Copy link
Collaborator Author

lancelly commented Aug 9, 2025

/bot run

@tensorrt-cicd
Copy link
Collaborator

PR_Github #14676 [ run ] triggered by Bot

@lancelly
Copy link
Collaborator Author

LGTM. General q: I packed rotary_embedding_scales, rotary_embedding_max_position_info, spec-dec related params into vectors before. It's a design choice to leave it as it is or unpacking it, given that it required some packing logics in trtllm.py and create API discrepancy between attention.py and trtllm.py. It can be another chore in separate PR.

Thanks for the work!

Thanks for reviewing~ We will unpack these vectors in the following PR, tracked in this jira: https://jirasw.nvidia.com/browse/TRTLLM-7076

Signed-off-by: Lanyu Liao <[email protected]>
Signed-off-by: Lanyu Liao <[email protected]>
@lancelly
Copy link
Collaborator Author

/bot run

@tensorrt-cicd
Copy link
Collaborator

PR_Github #14795 [ run ] triggered by Bot

Copy link
Collaborator

@yuxianq yuxianq left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Signed-off-by: Lanyu Liao <[email protected]>
@lancelly
Copy link
Collaborator Author

/bot run

@tensorrt-cicd
Copy link
Collaborator

PR_Github #14800 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #14795 [ run ] completed with state ABORTED

@tensorrt-cicd
Copy link
Collaborator

PR_Github #14800 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #11173 completed with status: 'FAILURE'

@lancelly
Copy link
Collaborator Author

/bot run

@tensorrt-cicd
Copy link
Collaborator

PR_Github #14818 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #14818 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #11190 completed with status: 'FAILURE'

@lancelly
Copy link
Collaborator Author

/bot run

@tensorrt-cicd
Copy link
Collaborator

PR_Github #14863 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #14863 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #11219 completed with status: 'FAILURE'

@lancelly
Copy link
Collaborator Author

/bot run

@tensorrt-cicd
Copy link
Collaborator

PR_Github #14899 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #14899 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #11243 completed with status: 'SUCCESS'
Pipeline passed with automatic retried tests. Check the rerun report for details.

@yuxianq yuxianq merged commit f7c13a4 into NVIDIA:main Aug 12, 2025
4 checks passed
MartinMarciniszyn added a commit to MartinMarciniszyn/TensorRT-LLM that referenced this pull request Aug 12, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Community want to contribute PRs initiated from Community
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants