[TRTLLM-6393][feat] add static tree sampling and verification #7161
Conversation
📝 Walkthrough

Adds SpecTree-based speculative decoding driven by new …

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
autonumber
actor User
participant Quickstart
participant Config as EagleDecodingConfig
participant Exec as PyExecutor
participant RM as ResourceManager
participant SpecTree as SpecTreeManager
participant Sampler
User->>Quickstart: launch (--eagle_choices)
Quickstart->>Config: construct (eagle_choices=...)
Quickstart->>Exec: create executor (spec config)
Exec->>RM: attach ResourceManager
alt eagle_choices present
RM->>SpecTree: instantiate SpecTreeManager(...)
end
loop scheduling
Exec->>Sampler: sample_async(..., resource_manager=RM)
Sampler->>RM: get_spec_tree_manager(RM)
alt SpecTree present
Sampler->>SpecTree: tree_sampling(...)
SpecTree-->>Sampler: accepted tokens
else
Sampler->>Sampler: legacy sampling path
end
Sampler-->>Exec: SampleState
Exec->>Sampler: update_requests(state, resource_manager=RM)
end
```

```mermaid
sequenceDiagram
autonumber
participant Drafter as ModelDrafter
participant RM as ResourceManager
participant Sampler
Drafter->>RM: update_cur_draft_layer_idx(0)
loop per draft layer
Drafter->>Sampler: _sample_async(..., resource_manager=RM)
Sampler-->>Drafter: SampleState
Drafter->>Sampler: _update_requests(state, resource_manager=RM)
Drafter->>RM: update_cur_draft_layer_idx(next)
end
Drafter->>RM: update_cur_draft_layer_idx(max_layer)
```

Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60–90 minutes
/bot run --disable-fail-fast
Actionable comments posted: 9
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (5)
tensorrt_llm/_torch/pyexecutor/py_executor.py (1)
881-894: Align `update_requests` signature with `resource_manager` parameter

The call to `self.sampler.update_requests(sample_state, resource_manager)` in both the pipeline-parallel and overlap paths will raise a `TypeError`, because the `TorchSampler.update_requests` method currently only accepts a single `state` argument. You must update the sampler's signature (and all overrides) to accept an optional `resource_manager` so that PP and overlap schedulers can forward resources correctly.

• Locations needing fixes:
- Pipeline-parallel executor (`_executor_loop_pp`), around py_executor.py lines 881-884
- Overlap scheduler (`_executor_loop_overlap`), around py_executor.py lines 1119-1122

• Suggested diff for `tensorrt_llm/_torch/pyexecutor/sampler.py`:

```diff
@@ class TorchSampler(…):
-    def update_requests(self, state: SampleState) -> None:
+    def update_requests(
+        self,
+        state: SampleState,
+        resource_manager: Optional[ResourceManager] = None,
+    ) -> None:
         assert isinstance(state, SampleState)
         scheduled_requests = state.scheduled_requests
         …
```

Ensure any subclass implementations of `update_requests` are updated accordingly so they can safely ignore or use the `resource_manager`.

tensorrt_llm/_torch/speculative/eagle3.py (1)
1-1: Add NVIDIA copyright header (2025) per repo guidelines.

Coding guidelines require the NVIDIA copyright header at the top of all source files. Apply this diff at the top of the file:

```diff
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
```
tensorrt_llm/_torch/pyexecutor/resource_manager.py (1)

1-1: Add NVIDIA copyright header (2025).

Per guidelines, prepend the header to this file:

```diff
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
```
tensorrt_llm/_torch/speculative/model_drafter.py (1)

1-1: Add NVIDIA copyright header (2025).

Please prepend the standard header:

```diff
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
```
tensorrt_llm/_torch/pyexecutor/sampler.py (1)

1-1: Add NVIDIA copyright header (2025).

Add the required header:

```diff
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
```
🧹 Nitpick comments (18)
tensorrt_llm/llmapi/llm_args.py (1)
460-471: Fix typos and long lines flagged by Ruff (E501)

- "Base on" -> "Based on"
- Break long f-strings into multi-arg logger.warning calls to satisfy line-length.

```diff
-                logger.warning(
-                    f"Base on the input choices, reset the num_eagle_layers from {self.num_eagle_layers} to {num_eagle_layers_from_choices}"
-                )
+                logger.warning(
+                    "Based on the input choices, reset num_eagle_layers from %s to %s",
+                    self.num_eagle_layers, num_eagle_layers_from_choices
+                )
@@
-                logger.warning(
-                    f"Base on the input choices, reset the max_draft_len from {self.max_draft_len} to {max_draft_len_from_choices}"
-                )
+                logger.warning(
+                    "Based on the input choices, reset max_draft_len from %s to %s",
+                    self.max_draft_len, max_draft_len_from_choices
+                )
```

examples/llm-api/quickstart_advanced.py (1)
131-131: Clarify and harden --eagle_choices CLI input

Currently it's a raw string parsed via ast.literal_eval. That's brittle and user-hostile. Prefer JSON and document the expected format. Optionally support @file to load from a path.

```diff
-    parser.add_argument('--eagle_choices', type=str, default=None)
+    parser.add_argument(
+        '--eagle_choices',
+        type=str,
+        default=None,
+        help='Static tree choices as JSON string (e.g. [[0],[0,1],[1]]), or @/path/to/choices.json'
+    )
```

And in setup_llm, add a minimal loader:

```diff
@@
-    elif spec_decode_algo == "EAGLE3":
+    elif spec_decode_algo == "EAGLE3":
+        # Support @file syntax for eagle_choices
+        eagle_choices = args.eagle_choices
+        if isinstance(eagle_choices, str) and eagle_choices.startswith("@"):
+            import json
+            with open(eagle_choices[1:], "r") as f:
+                eagle_choices = json.load(f)
         spec_config = EagleDecodingConfig(
@@
-            eagle_choices=args.eagle_choices)
+            eagle_choices=eagle_choices)
```

I can wire this change for you across both files if you want.
tensorrt_llm/_torch/speculative/utils.py (1)
41-45: Thread tree-specific flags conservatively

The boolean is_spec_dec_tree is set as (eagle_choices is not None). Prefer bool(eagle_choices) to treat an empty list as False. Also, guard num_eagle_layers being None.

```diff
-        num_eagle_layers=spec_config.num_eagle_layers,
-        eagle_choices=spec_config.eagle_choices,
-        is_spec_dec_tree=spec_config.eagle_choices is not None,
+        num_eagle_layers=spec_config.num_eagle_layers or 0,
+        eagle_choices=spec_config.eagle_choices,
+        is_spec_dec_tree=bool(spec_config.eagle_choices),
         is_spec_dec_dynamic_tree=spec_config.use_dynamic_tree,
```

tensorrt_llm/_torch/pyexecutor/py_executor.py (1)
46-46: Import location consistency

You import ResourceManager from .resource_manager while ResourceManagerType/request_context come from tensorrt_llm._torch.pyexecutor.resource_manager earlier. Consider using the same absolute path for consistency.

```diff
-from .resource_manager import ResourceManager
+from tensorrt_llm._torch.pyexecutor.resource_manager import ResourceManager
```

tests/unittest/_torch/speculative/test_draft_token_tree_sampling.py (3)
30-44: Avoid reliance on an actual model path

speculative_model_dir is required by EagleDecodingConfig.validate(), but tests don't actually load the model. Consider using a temp directory or bypassing validation if not needed.
If validate() is invoked in the test path, we should create a temp dir instead of a hard-coded path. Do you want me to adjust the test to use tmp_path?
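A minimal sketch of the tmp_path idea, assuming EagleDecodingConfig accepts these keyword arguments and only checks that the directory exists (the values are illustrative, not taken from the test):

```python
import pytest

from tensorrt_llm.llmapi import EagleDecodingConfig


def test_config_without_real_checkpoint(tmp_path):
    # tmp_path is pytest's built-in per-test temporary directory,
    # so no hard-coded model path is needed.
    spec_config = EagleDecodingConfig(
        max_draft_len=3,
        speculative_model_dir=str(tmp_path),  # placeholder directory
    )
    assert spec_config.max_draft_len == 3
```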
412-414: Remove unittest main boilerplate in a pytest-style test module

This file is structured as a pytest module. Keeping unittest.main() runs zero tests when executed directly and can confuse contributors.

```diff
-if __name__ == "__main__":
-    unittest.main()
+# Intentionally left blank; tests are discovered by pytest.
```
55-88: Improve readability and robustness of the test harness (see the sketch after this list)
- Use torch.arange for tensors and explicit dtypes.
- Assert shapes explicitly before equality to aid debugging.
- Add comments clarifying how num_logits_per_request relates to tree nodes.
I can push a small cleanup PR to incorporate these improvements if helpful.
Also applies to: 97-109, 118-134, 136-158, 160-182, 184-214, 216-246, 248-280, 282-410
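A small illustration of those three points; the variable names (token_ids, logits, num_logits_per_request) are hypothetical stand-ins for the test's actual tensors:

```python
import torch

# Prefer torch.arange with an explicit dtype over hand-written Python lists.
token_ids = torch.arange(0, 8, dtype=torch.int32)
assert token_ids.dtype == torch.int32

# One logits row per tree node that still has children in this layer,
# so num_logits_per_request * num_requests rows in total.
num_logits_per_request, num_requests, vocab_size = 3, 2, 32
logits = torch.zeros(num_logits_per_request * num_requests, vocab_size)

# Assert shapes explicitly before comparing values, so a failure reports
# the mismatched shape instead of a confusing element-wise diff.
expected = torch.zeros_like(logits)
assert logits.shape == expected.shape, (
    f"shape mismatch: {logits.shape} vs {expected.shape}")
assert torch.equal(logits, expected)
```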
tensorrt_llm/_torch/speculative/eagle3.py (1)

126-129: Static-only flags; future dynamic-tree will report as non-tree. Consider aligning metadata now.

__post_init__ sets is_spec_dec_tree True only when eagle_choices is provided and hard-codes is_spec_dec_dynamic_tree False. If/when dynamic-tree is enabled, this metadata will misreport the mode.

If the spec config is available here, set these flags from the config; otherwise, thread a boolean into the dataclass and use it:

```diff
-        if self.eagle_choices is not None:
-            self.is_spec_dec_tree = True
-            self.is_spec_dec_dynamic_tree = False
+        if self.eagle_choices is not None:
+            self.is_spec_dec_tree = True
+            self.is_spec_dec_dynamic_tree = False
+        # else: dynamic-tree would set {True, True} when enabled
```

tests/unittest/_torch/speculative/test_draft_token_tree_verification.py (3)
15-16: Avoid sys.path mutation in tests.

Tests should import from the installed package/module path. Mutating sys.path can mask packaging/import issues and leads to brittle CI.

```diff
-sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
```
348-349: unittest.main() won't run top-level test functions. Confirm the runner or wrap in TestCase.

These tests are plain functions using bare asserts (pytest-style). Running the module as a script via unittest.main() will discover 0 tests. If CI uses pytest, drop this block; if it uses unittest, wrap functions in a unittest.TestCase.

Two options:

- PyTest: remove the block entirely:

```diff
-if __name__ == "__main__":
-    unittest.main()
```

- unittest: wrap functions in a TestCase and use self.assertEqual/True (see the sketch below).
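A minimal sketch of the second option; test_tree_verification is a placeholder for the module's actual test functions:

```python
import unittest


def test_tree_verification():
    # Existing pytest-style function with bare asserts.
    assert 1 + 1 == 2


class TreeVerificationTest(unittest.TestCase):
    """Thin wrapper so `python -m unittest <module>` discovers the tests."""

    def test_tree_verification(self):
        # Delegate to the plain function; bare asserts still surface as
        # failures because AssertionError is what unittest expects.
        test_tree_verification()


if __name__ == "__main__":
    unittest.main()
```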
215-215: Address line-length warnings flagged by Ruff (E501).

Long literals are fine during bring-up, but the repo enforces 120 cols. Please wrap or split these lines.

Example fix pattern:

```diff
-    ################## CASE 1 static tree, target model's request, no draft tokens are accepted ##########################
+    # CASE 1: static tree, target model's request, no draft tokens accepted
```

And split long lists across lines.
Also applies to: 259-259, 303-303
tensorrt_llm/_torch/pyexecutor/resource_manager.py (4)

1153-1161: Constructor typing for eagle_choices should accept Optional.

Dynamic-tree passes no eagle_choices. Tighten the signature to reflect that.

```diff
-    def __init__(self, max_num_requests: int, use_dynamic_tree: bool,
-                 max_draft_len: int, num_eagle_layers: int,
-                 eagle_choices: [List[List[int]]], dynamic_tree_max_topK: int):
+    def __init__(self, max_num_requests: int, use_dynamic_tree: bool,
+                 max_draft_len: int, num_eagle_layers: int,
+                 eagle_choices: Optional[List[List[int]]],
+                 dynamic_tree_max_topK: int):
```
1257-1260: Explicit dtype for child node tensor to avoid implicit float->int casts.

torch.tensor() defaults to float32; writing into int32 buffers causes implicit casts. Specify dtype to keep integer semantics and avoid silent precision loss.

```diff
-            self.child_nodes_list[tree_idx][i][:len(tmp_child_nodes_list[i])] = torch.tensor(
-                tmp_child_nodes_list[i])
+            self.child_nodes_list[tree_idx][i][:len(tmp_child_nodes_list[i])] = torch.tensor(
+                tmp_child_nodes_list[i], dtype=torch.int32)
```
1143-1150: Minor nit: fix typos in comments and align naming.

"dynamice tree" → "dynamic tree"; "drafter layer" → "drafter layers".
1186-1204: Potential memory growth for child_nodes_list on large trees.

child_nodes_list is [num_trees, max_draft_len+1, max_draft_len]. This is O(N^2) per tree. If max_draft_len is large, consider a ragged structure (CSR-like) or per-layer compact vectors to reduce footprint.
tensorrt_llm/_torch/speculative/model_drafter.py (1)

369-371: Make error message actionable when resource_manager is missing.

Current ValueError is fine. Consider hinting that SPEC_RESOURCE_MANAGER must be registered when tree mode is enabled; a sketch follows.
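One possible wording, sketched under the assumption that the check lives where the drafter fetches the manager (the helper name and message text are illustrative):

```python
def _require_spec_resource_manager(resource_manager):
    """Fail fast with a hint instead of a bare 'resource_manager is None'."""
    if resource_manager is None:
        raise ValueError(
            "resource_manager is required for tree-based speculative decoding; "
            "register a SPEC_RESOURCE_MANAGER (see ResourceManagerType) on the "
            "executor's ResourceManager before enabling eagle_choices.")
    return resource_manager
```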
tensorrt_llm/_torch/pyexecutor/sampler.py (2)

477-494: Docstring says "target model only", but function also handles drafter-side behavior.

The early branch (len(py_draft_tokens) == 0) explicitly handles the drafter model path. Update the docstring to prevent confusion.

```diff
-    """ Tree verification for draft token tree based speculative decoding. This function will only be called for the target model.
+    """Tree-driven drafting/verification for speculative decoding.
+    Handles:
+    - Drafter requests (no existing draft tokens): append tokens from the previous layer.
+    - Target requests (has draft tokens): verify longest-accepted path and append tokens accordingly.
```
667-687: Prefer explicit NotImplementedError over assert for dynamic-tree branch.

Asserts can be stripped with -O and provide less actionable messages.

```diff
-        if spec_tree_manager.dynamic_tree:
-            assert False, "Dynamic tree is not supported yet."
+        if spec_tree_manager.dynamic_tree:
+            raise NotImplementedError("Dynamic tree sampling is not supported yet.")
```
📜 Review details
📒 Files selected for processing (10)
- examples/llm-api/quickstart_advanced.py (2 hunks)
- tensorrt_llm/_torch/pyexecutor/py_executor.py (3 hunks)
- tensorrt_llm/_torch/pyexecutor/resource_manager.py (1 hunks)
- tensorrt_llm/_torch/pyexecutor/sampler.py (10 hunks)
- tensorrt_llm/_torch/speculative/eagle3.py (4 hunks)
- tensorrt_llm/_torch/speculative/model_drafter.py (8 hunks)
- tensorrt_llm/_torch/speculative/utils.py (1 hunks)
- tensorrt_llm/llmapi/llm_args.py (3 hunks)
- tests/unittest/_torch/speculative/test_draft_token_tree_sampling.py (1 hunks)
- tests/unittest/_torch/speculative/test_draft_token_tree_verification.py (1 hunks)
🧰 Additional context used
📓 Path-based instructions (2)

**/*.py (📄 CodeRabbit inference engine, CODING_GUIDELINES.md):
- Python code must target Python 3.8+
- Python indentation: 4 spaces, no tabs
- Maintain module namespace in imports (from package.subpackage import foo; then use foo.SomeClass())
- Python file names use snake_case
- Python class names use PascalCase
- Python functions/methods and local variables use snake_case; variables starting with a number get k_ prefix (e.g., k_99th_percentile)
- Global variables use G_ prefixed UPPER_SNAKE_CASE (e.g., G_MY_GLOBAL)
- Constants use UPPER_SNAKE_CASE in Python
- Avoid shadowing variables from outer scopes in Python
- Initialize all externally visible members of a Python class in __init__
- Prefer docstrings for interfaces used outside a file; comments for local code
- Use Google-style docstrings for classes and functions (Sphinx-parsable)
- Document attributes/variables inline with short docstrings
- Avoid reflection when simple alternatives exist (e.g., prefer explicit parameters over dict(**locals()))
- In try/except, catch the narrowest exceptions possible
- For duck-typing with try/except, keep try body minimal and put logic in else

Applies to all ten files selected above.

**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh,py} (📄 CodeRabbit inference engine, CODING_GUIDELINES.md):
- Prepend NVIDIA copyright header (current year) to all source files

Applies to all ten files selected above.
🧬 Code graph analysis (8)

tests/unittest/_torch/speculative/test_draft_token_tree_sampling.py (4)
- tensorrt_llm/sampling_params.py: SamplingParams (125-486), _get_sampling_config (408-438)
- tensorrt_llm/_torch/pyexecutor/llm_request.py: LlmRequest (282-422)
- tensorrt_llm/_torch/pyexecutor/resource_manager.py: SpecTreeManager (1113-1260)
- tensorrt_llm/_torch/pyexecutor/sampler.py: TorchSampler (353-971), beam_width (1135-1138), Args (366-371), tree_sampling (667-747)

tensorrt_llm/_torch/speculative/eagle3.py (1)
- tensorrt_llm/_torch/pyexecutor/resource_manager.py: BaseResourceManager (58-81), SlotManager (934-963), SpecTreeManager (1113-1260)

tensorrt_llm/llmapi/llm_args.py (1)
- tensorrt_llm/logger.py: warning (131-132)

tensorrt_llm/_torch/pyexecutor/py_executor.py (3)
- tensorrt_llm/_torch/pyexecutor/resource_manager.py: ResourceManager (966-1001)
- tensorrt_llm/_torch/speculative/model_drafter.py: _update_requests (290-296)
- tensorrt_llm/_torch/pyexecutor/sampler.py: SampleState (43-49), update_requests (68-69, 84-99, 129-144, 749-777, 1264-1277)

tests/unittest/_torch/speculative/test_draft_token_tree_verification.py (4)
- tensorrt_llm/sampling_params.py: _get_sampling_config (408-438)
- tensorrt_llm/_torch/pyexecutor/llm_request.py: LlmRequest (282-422)
- tensorrt_llm/_torch/pyexecutor/resource_manager.py: SpecTreeManager (1113-1260)
- tensorrt_llm/_torch/pyexecutor/sampler.py: beam_width (1135-1138), Args (366-371), _process_draft_tokens_tree (477-585)

examples/llm-api/quickstart_advanced.py (1)
- tensorrt_llm/builder.py: default (50-58)

tensorrt_llm/_torch/speculative/model_drafter.py (2)
- tensorrt_llm/_torch/pyexecutor/resource_manager.py: BaseResourceManager (58-81), ResourceManager (966-1001), ResourceManagerType (46-51), get_resource_manager (978-979)
- tensorrt_llm/_torch/pyexecutor/sampler.py: SampleState (43-49), sample_async (63-65, 78-82, 122-127, 796-820, 1155-1261), update_requests (68-69, 84-99, 129-144, 749-777, 1264-1277)

tensorrt_llm/_torch/pyexecutor/sampler.py (3)
- tensorrt_llm/_torch/pyexecutor/resource_manager.py: ResourceManager (966-1001), ResourceManagerType (46-51), SpecTreeManager (1113-1260), get_resource_manager (978-979)
- tensorrt_llm/_torch/pyexecutor/llm_request.py: LlmRequest (282-422)
- tensorrt_llm/_torch/pyexecutor/scheduler.py: ScheduledRequests (18-39)
🪛 Ruff (0.12.2)

tensorrt_llm/_torch/pyexecutor/resource_manager.py
- 1124: Line too long (123 > 120) (E501)
- 1131: Line too long (130 > 120) (E501)
- 1138: Line too long (125 > 120) (E501)
- 1149: Line too long (126 > 120) (E501)

tensorrt_llm/llmapi/llm_args.py
- 452: Line too long (140 > 120) (E501)
- 461: Line too long (131 > 120) (E501)

tests/unittest/_torch/speculative/test_draft_token_tree_verification.py
- 215: Line too long (122 > 120) (E501)
- 259: Line too long (140 > 120) (E501)
- 303: Line too long (150 > 120) (E501)

tensorrt_llm/_torch/pyexecutor/sampler.py
- 480: Line too long (134 > 120) (E501)
- 486: Line too long (147 > 120) (E501)
- 489: Line too long (160 > 120) (E501)
- 490: Line too long (179 > 120) (E501)
- 492: Line too long (165 > 120) (E501)
- 493: Line too long (121 > 120) (E501)
- 676: Line too long (171 > 120) (E501)
- 678: Line too long (121 > 120) (E501)
- 680: Line too long (158 > 120) (E501)
- 716: Line too long (137 > 120) (E501)
🔇 Additional comments (10)
tensorrt_llm/llmapi/llm_args.py (1)
420-428: Double-check semantics: mapping max_draft_len -> num_eagle_layers

The initial "mapping" of max_draft_len to num_eagle_layers is overridden later by choices, which is fine. Verify that in dynamic-tree mode (no eagle_choices), equating num_eagle_layers=max_draft_len matches downstream expectations; otherwise derive it, or make num_eagle_layers required for dynamic mode.

Would you like me to scan downstream uses to confirm dynamic-tree relies on num_eagle_layers and not max_draft_len? I can provide a repo query to verify.
Also applies to: 456-472, 479-482
tensorrt_llm/_torch/pyexecutor/py_executor.py (1)

1560-1565: LGTM: sampler.update_requests now receives ResourceManager

Plumbs ResourceManager to the sampler; required for tree-based draft processing. Good change.
tests/unittest/_torch/speculative/test_draft_token_tree_verification.py (2)

56-61: These asserts are the critical oracles — good coverage for accept counts and mtokens.

The checks correctly validate both acceptance length and the post-state tokens, which exercises the sampler's tree path. Nice.

1-1: Add NVIDIA copyright header (2025).

Tests are Python source and must include the header per guidelines.

```diff
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
```

⛔ Skipped due to learnings

- Learnt from: CR, PR NVIDIA/TensorRT-LLM#0, File CODING_GUIDELINES.md:0-0, Timestamp 2025-08-12T10:28:57.320Z. Learning: Applies to **/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh,py}: Prepend NVIDIA copyright header (current year) to all source files
- Learnt from: galagam, PR NVIDIA/TensorRT-LLM#6487, File tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_trtllm_bench.py:1-12, Timestamp 2025-08-06T13:58:07.506Z. Learning: In TensorRT-LLM, test files (files under tests/ directories) do not require NVIDIA copyright headers, unlike production source code files. Test files typically start directly with imports, docstrings, or code.
tensorrt_llm/_torch/speculative/model_drafter.py (4)

265-276: Propagating resource_manager through sampling is the right call.

Threading the ResourceManager into sample_async ensures the sampler can choose the tree path when available. This aligns with the new SpecTreeManager plumbing.

387-395: Good: set cur_eagle_layer_idx=0 before the first forward.

Ensures tree_sampling extracts the correct nodes for the first draft layer.

409-420: Good: keep cur_eagle_layer_idx in sync each layer and update requests between forwards.

This sequencing is critical for correct node expansion per layer.

431-436: Final layer index update is subtle but necessary.

Bumping to num_eagle_layers before the last _update_requests lets the sampler know to pick the last-layer node count correctly.
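Taken together, these comments describe a per-layer protocol. A simplified sketch of that control flow; the bodies are illustrative (the real drafter interleaves its private _sample_async/_update_requests helpers with the index bump, as in the sequence diagram above):

```python
def run_draft_layers(sampler, spec_tree_manager, requests):
    """Illustrative per-layer draft loop for static-tree Eagle decoding."""
    # Before the first forward, the sampler must see layer index 0 so
    # tree_sampling expands the root layer's nodes.
    spec_tree_manager.cur_eagle_layer_idx = 0

    for layer in range(spec_tree_manager.num_eagle_layers):
        state = sampler.sample_async(requests)  # forward + sample layer `layer`
        sampler.update_requests(state)          # one-step-delayed update
        # Keep the index in sync between forwards; after the last iteration
        # it equals num_eagle_layers, which tells the sampler to use the
        # last layer's node count for the final update.
        spec_tree_manager.cur_eagle_layer_idx = layer + 1
```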
tensorrt_llm/_torch/pyexecutor/sampler.py (2)
886-919: Fast-path routing looks correct; keep embedding bias in non-tree branch.

Nice separation: tree present → use tree_sampling; otherwise keep the original greedy path with embedding bias and eagle3 d2t offset.

585-586: Return counts exclude root — clear and consistent.

Returning num_accepted_draft_tokens - 1 (to exclude the root) matches how new_tokens is indexed by node id. Good consistency.
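A tiny worked example of that convention (the node ids are made up):

```python
# Longest accepted path through the draft tree, root included.
accepted_path = [0, 3, 7]  # hypothetical node ids; 0 is the root

num_accepted_draft_tokens = len(accepted_path)  # 3, counts the root
returned = num_accepted_draft_tokens - 1        # 2 accepted draft tokens
assert returned == 2
```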
Great work!
Force-pushed from 235dc62 to 49e2c68 (Compare)
/bot run --disable-fail-fast

PR_Github #16313 [ run ] triggered by Bot
Actionable comments posted: 1
♻️ Duplicate comments (5)
tensorrt_llm/_torch/pyexecutor/resource_manager.py (1)

1212-1233: Add input validation to prevent index overflow

Missing validation could lead to out-of-bounds access when writing to pre-allocated buffers. Since past reviews already identified this issue, it appears unresolved.

Add validation at the start of init_tree_from_input_choices:

```diff
 def init_tree_from_input_choices(self):
     if self.dynamic_tree:
         return

     # For the static tree
     tree_idx = 0
+
+    # Validate inputs to prevent buffer overflow
+    if self.eagle_choices is None:
+        raise ValueError("eagle_choices must be provided for static tree")
+
+    max_choice_depth = max((len(c) for c in self.eagle_choices), default=0)
+    if len(self.eagle_choices) > self.max_draft_len:
+        raise ValueError(f"Number of eagle_choices ({len(self.eagle_choices)}) "
+                         f"exceeds max_draft_len ({self.max_draft_len})")
+    if max_choice_depth > self.num_eagle_layers:
+        raise ValueError(f"Maximum choice depth ({max_choice_depth}) "
+                         f"exceeds num_eagle_layers ({self.num_eagle_layers})")
+
     # 1) Map the index
     self.index_mapping_list[tree_idx].clear()
```

tensorrt_llm/llmapi/llm_args.py (2)
431-448: Overriding Pydantic BaseModel.__init__ bypasses validation - use validators instead

Overriding __init__ in a Pydantic model bypasses field validation, default handling, and json_schema generation. This can lead to invalid configurations leaking through.

Replace the custom __init__ with Pydantic validators:

```python
@model_validator(mode="before")
@classmethod
def _parse_and_seed_defaults(cls, data: dict):
    if data is None:
        return {}
    # Accept eagle_choices as a JSON/Python-literal string
    choices = data.get("eagle_choices", None)
    if isinstance(choices, str):
        logger.warning(
            "NOTE: The Draft token tree is still under development, PLEASE DO NOT USE IT !!!"
        )
        try:
            data["eagle_choices"] = ast.literal_eval(choices.replace(" ", ""))
        except Exception as e:
            raise ValueError(f"Invalid eagle_choices string: {choices}") from e
    # If num_eagle_layers not explicitly set, seed from max_draft_len
    if "num_eagle_layers" not in data and "max_draft_len" in data and data["max_draft_len"] is not None:
        data["num_eagle_layers"] = data["max_draft_len"]
    return data

@model_validator(mode="after")
def _normalize_tree_settings(self):
    if self.eagle_choices is None:
        return self
    # Force dynamic tree off when static choices are provided
    if self.use_dynamic_tree:
        self.use_dynamic_tree = False
        logger.warning("If eagle_choices is provided, use_dynamic_tree will be set to False")
    # Validate/normalize choices and align num_eagle_layers/max_draft_len
    num_layers = self.check_eagle_choices()
    if self.num_eagle_layers != num_layers:
        logger.warning(
            "Based on the input choices, reset num_eagle_layers from %s to %s",
            self.num_eagle_layers, num_layers
        )
        self.num_eagle_layers = num_layers
    max_draft_len_from_choices = len(self.eagle_choices)
    if self.max_draft_len != max_draft_len_from_choices:
        logger.warning(
            "Based on the input choices, reset max_draft_len from %s to %s",
            self.max_draft_len, max_draft_len_from_choices
        )
        self.max_draft_len = max_draft_len_from_choices
    return self
```
487-507: Use explicit validation instead of assertions for user input

Assertions can be stripped with Python's -O flag and are unsuitable for validating user input. Additionally, the membership check against a list of lists is O(n²) and should be optimized.

```diff
 def check_eagle_choices(self):
     # 1) Check connectivity
-    unique_choices = set(
-        tuple(sub_choice)
-        for sub_choice in self.eagle_choices)  # remove repeated choices
-    self.eagle_choices = sorted([list(t) for t in unique_choices],
-                                key=lambda x: (len(x), x))  # sort choices
+    unique_choices = {tuple(sub_choice) for sub_choice in self.eagle_choices}
+    self.eagle_choices = sorted([list(t) for t in unique_choices], key=lambda x: (len(x), x))
+
     for choice in self.eagle_choices:
         if len(choice) > 1:
-            assert choice[
-                0:
-                -1] in self.eagle_choices, f"Error: choice {choice} is not connected"
+            prefix = tuple(choice[:-1])
+            if prefix not in unique_choices:
+                raise ValueError(f"Invalid eagle_choices: choice {choice} has no parent prefix {list(prefix)}")

     # 2) Get num_eagle_layers_from_choices
-    num_eagle_layers_from_choices = 0
-    for choice in self.eagle_choices:
-        num_eagle_layers_from_choices = max(num_eagle_layers_from_choices,
-                                            len(choice))
-
-    return num_eagle_layers_from_choices
+    return max((len(choice) for choice in self.eagle_choices), default=0)
```

tests/unittest/_torch/speculative/test_draft_token_tree_sampling.py (1)
18-44: Add CUDA availability check to prevent test failures on CPU-only systems

The test creates CUDA tensors and will fail on machines without GPU support.

```diff
+import pytest
 import os
 import sys
 import unittest

 import torch
 from utils.llm_data import llm_models_root

 from tensorrt_llm import SamplingParams
 from tensorrt_llm._torch.pyexecutor.llm_request import (LlmRequest,
                                                         SamplingConfig)
 from tensorrt_llm._torch.pyexecutor.resource_manager import SpecTreeManager
 from tensorrt_llm._torch.pyexecutor.sampler import TorchSampler
 from tensorrt_llm.llmapi import EagleDecodingConfig

 sys.path.append(os.path.join(os.path.dirname(__file__), '..'))


+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required for tree sampling tests")
 def test_draft_token_static_tree_sampling():
```

tensorrt_llm/_torch/pyexecutor/sampler.py (1)
417-425: Use enum value to reliably fetch SPEC_RESOURCE_MANAGER from registry

The resource manager registry may store entries by the enum's string value, not the enum itself. Using .value ensures consistent retrieval.

```diff
 def get_spec_tree_manager(self, resource_manager: ResourceManager):
     if resource_manager is None:
         return None
     spec_resource_manager = resource_manager.get_resource_manager(
-        ResourceManagerType.SPEC_RESOURCE_MANAGER)
+        ResourceManagerType.SPEC_RESOURCE_MANAGER.value)
     if spec_resource_manager is None:
         return None
     return spec_resource_manager.spec_tree_manager
```
🧹 Nitpick comments (6)
tests/unittest/_torch/speculative/test_draft_token_tree_sampling.py (1)

86-87: Improve test assertion with descriptive error message

The assertion would benefit from a more descriptive error message to aid debugging when tests fail.

```diff
-    assert torch.all(new_tokens == ref_new_tokens.transpose(0, 1).unsqueeze(
-        dim=-1))
+    assert torch.all(new_tokens == ref_new_tokens.transpose(0, 1).unsqueeze(dim=-1)), \
+        f"Token mismatch: expected {ref_new_tokens.tolist()}, got {new_tokens.transpose(1, 0).squeeze(-1).tolist()}"
```

tests/unittest/_torch/speculative/test_draft_token_tree_verification.py (2)
74-75: Fix typo in comment

"dn" should be "do".

```diff
-    # We dn not need to test the case of draft_layer_id = 0, because _update_requests() is one step delay.
+    # We do not need to test the case of draft_layer_id = 0, because _update_requests() is one step delayed.
```
215-215: Split long line for better readability

Line exceeds 120 characters. Consider splitting for readability.

```diff
-    ################## CASE 1 static tree, target model's request, no draft tokens are accepted ##########################
+    ################## CASE 1 static tree, target model's request, no draft tokens are accepted #########
```

tensorrt_llm/_torch/pyexecutor/sampler.py (3)
483-592: Consider extracting tree verification logic into separate methods for clarity

The _process_draft_tokens_tree method handles both drafter and target model logic, making it complex. Consider splitting into _process_drafter_tokens and _process_target_tokens for better maintainability.

```python
def _process_drafter_tokens(self, request: LlmRequest, new_tokens: torch.Tensor,
                            spec_tree_manager: SpecTreeManager) -> int:
    """Process tokens for drafter model - add draft tokens from previous layer."""
    cur_draft_layer_idx = spec_tree_manager.cur_eagle_layer_idx
    if spec_tree_manager.dynamic_tree:
        # TODO: For the last layer, we need to resampling all the draft tokens.
        cur_layer_num_nodes = spec_tree_manager.dynamic_tree_max_topK
    else:
        cur_layer_num_nodes = spec_tree_manager.num_nodes_per_layer[0][cur_draft_layer_idx]
    for i in range(cur_layer_num_nodes):
        new_token = add_token(request, new_tokens, beam=self.BEAM, step=i)
        if self._handle_stop_criteria(request, new_token):
            break
    return 0

def _process_target_tokens(self, request: LlmRequest, new_tokens: torch.Tensor,
                           spec_tree_manager: SpecTreeManager) -> int:
    """Process tokens for target model - perform tree verification."""
    # ... existing target model logic ...
```
673-753: Add validation for tree_sampling static tree assumption

The method assumes static tree but doesn't validate this early, leading to a late assertion failure.

```diff
 def tree_sampling(self, requests: list[LlmRequest], beam_width: int,
                   model_outputs: dict[str, torch.Tensor],
                   spec_tree_manager: SpecTreeManager, seq_slots: torch.Tensor,
                   new_tokens: torch.Tensor):
     """
     Tree sampling for draft token tree based speculative decoding. Each node may expand to multiple child nodes.
     Args:
         requests: list[LlmRequest]. List of LlmRequest.
         beam_width: int. Currently only support beam_width = 1 for speculative decoding.
         model_outputs: dict[str, torch.Tensor]. Model outputs, including logits, d2t.
             logits: [N, vocab_size], N = num_logits_per_request * len(requests).
             num_logits_per_request: The number of nodes that has child nodes in the current layer.
             d2t: [draft_vocab_size]
         spec_tree_manager: SpecTreeManager. which contains the tree structure and other meta information of the tree.
         seq_slots: torch.Tensor. [max_num_sequences]. The sequence slots of the requests.
         new_tokens: torch.Tensor. [max_draft_len + 1, max_num_sequences, MAX_BEAM_WIDTH], device buffer.
             The output buffer for new generated draft tokens.
     """
     assert beam_width == 1, "speculative decoding only supports beam_width = 1"
+
+    if spec_tree_manager.dynamic_tree:
+        raise NotImplementedError("Dynamic tree is not supported yet in tree_sampling")
+
     raw_logits = model_outputs["logits"]

-    if spec_tree_manager.dynamic_tree:
-        assert False, "Dynamic tree is not supported yet."
-    else:
-        # Static tree branch
+    # Static tree branch
```
722-726: Clean up TODO comment and optimize tensor allocation

The TODO comment could be clearer about the optimization opportunity.

```diff
-    # TODO: this tensor can be optimized
-    # new_draft_tokens_cuda = torch.empty((spec_tree_manager.max_draft_len + 1, len(requests)), dtype=torch.int64, device='cuda')
+    # TODO: Optimize by using torch.empty instead of torch.zeros when values will be fully overwritten
     new_draft_tokens_cuda = torch.zeros(
         (len(requests), spec_tree_manager.max_draft_len + 1),
         dtype=torch.int64,
         device='cuda')
```
📜 Review details
📒 Files selected for processing (10)
- examples/llm-api/quickstart_advanced.py (2 hunks)
- tensorrt_llm/_torch/pyexecutor/py_executor.py (3 hunks)
- tensorrt_llm/_torch/pyexecutor/resource_manager.py (1 hunks)
- tensorrt_llm/_torch/pyexecutor/sampler.py (9 hunks)
- tensorrt_llm/_torch/speculative/eagle3.py (4 hunks)
- tensorrt_llm/_torch/speculative/model_drafter.py (9 hunks)
- tensorrt_llm/_torch/speculative/utils.py (1 hunks)
- tensorrt_llm/llmapi/llm_args.py (3 hunks)
- tests/unittest/_torch/speculative/test_draft_token_tree_sampling.py (1 hunks)
- tests/unittest/_torch/speculative/test_draft_token_tree_verification.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (5)
- tensorrt_llm/_torch/speculative/utils.py
- examples/llm-api/quickstart_advanced.py
- tensorrt_llm/_torch/speculative/eagle3.py
- tensorrt_llm/_torch/speculative/model_drafter.py
- tensorrt_llm/_torch/pyexecutor/py_executor.py
🧰 Additional context used
🧬 Code graph analysis (4)

tests/unittest/_torch/speculative/test_draft_token_tree_sampling.py (4)
- tensorrt_llm/sampling_params.py: SamplingParams (125-486), _get_sampling_config (408-438)
- tensorrt_llm/_torch/pyexecutor/llm_request.py: LlmRequest (282-422)
- tensorrt_llm/_torch/pyexecutor/resource_manager.py: SpecTreeManager (1113-1260)
- tensorrt_llm/_torch/pyexecutor/sampler.py: TorchSampler (360-983), beam_width (70-73), Args (376-381), tree_sampling (673-753)

tensorrt_llm/_torch/pyexecutor/sampler.py (3)
- tensorrt_llm/_torch/pyexecutor/resource_manager.py: ResourceManager (966-1001), ResourceManagerType (46-51), SpecTreeManager (1113-1260), get_resource_manager (978-979)
- tensorrt_llm/_torch/pyexecutor/llm_request.py: LlmRequest (282-422)
- tensorrt_llm/_torch/pyexecutor/scheduler.py: ScheduledRequests (18-39)

tensorrt_llm/llmapi/llm_args.py (1)
- tensorrt_llm/logger.py: warning (131-132)

tests/unittest/_torch/speculative/test_draft_token_tree_verification.py (5)
- tensorrt_llm/sampling_params.py: _get_sampling_config (408-438)
- tensorrt_llm/_torch/pyexecutor/llm_request.py: LlmRequest (282-422)
- tensorrt_llm/_torch/pyexecutor/resource_manager.py: SpecTreeManager (1113-1260)
- tensorrt_llm/_torch/pyexecutor/sampler.py: beam_width (70-73), Args (376-381), _process_draft_tokens_tree (483-591)
- tensorrt_llm/llmapi/llm_args.py: EagleDecodingConfig (420-525)
🪛 Ruff (0.12.2)

tensorrt_llm/_torch/pyexecutor/sampler.py
- 486: Line too long (134 > 120) (E501)
- 492: Line too long (147 > 120) (E501)
- 495: Line too long (160 > 120) (E501)
- 496: Line too long (179 > 120) (E501)
- 498: Line too long (165 > 120) (E501)
- 499: Line too long (121 > 120) (E501)
- 682: Line too long (171 > 120) (E501)
- 684: Line too long (121 > 120) (E501)
- 686: Line too long (158 > 120) (E501)
- 722: Line too long (137 > 120) (E501)

tensorrt_llm/_torch/pyexecutor/resource_manager.py
- 1124: Line too long (123 > 120) (E501)
- 1131: Line too long (130 > 120) (E501)
- 1138: Line too long (125 > 120) (E501)
- 1149: Line too long (126 > 120) (E501)

tensorrt_llm/llmapi/llm_args.py
- 438: Line too long (143 > 120) (E501)
- 456: Line too long (140 > 120) (E501)
- 465: Line too long (131 > 120) (E501)

tests/unittest/_torch/speculative/test_draft_token_tree_verification.py
- 215: Line too long (122 > 120) (E501)
- 259: Line too long (140 > 120) (E501)
- 303: Line too long (150 > 120) (E501)
⏰ Context from checks skipped due to timeout of 90000ms (1):
- GitHub Check: Pre-commit Check
PR_Github #16313 [ run ] completed with state

/bot run --disable-fail-fast

PR_Github #16424 [ run ] triggered by Bot
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tensorrt_llm/_torch/speculative/mtp.py (1)
69-71: Bug: misuse of torch.copy_ with a scalar; use fill_/zero_ instead

Tensor.copy_ expects a Tensor source. Copying from a Python scalar is invalid and the non_blocking arg is ignored. Use fill_(0) (or zero_()) for scalar fills.

```diff
-                self.mtp_relaxed_delta_pool[slot_id].copy_(
-                    0, non_blocking=True)
+                self.mtp_relaxed_delta_pool[slot_id].fill_(0)
@@
-            self.mtp_relaxed_delta_pool[free_slot_id].copy_(0,
-                                                            non_blocking=True)
+            self.mtp_relaxed_delta_pool[free_slot_id].fill_(0)
```

Also applies to: 79-80
♻️ Duplicate comments (7)
tensorrt_llm/llmapi/llm_args.py (2)
436-456: Do not override Pydantic BaseModel.__init__; use validators instead

Overriding __init__ bypasses Pydantic v2 validation, defaults, schema generation, and field coercion. Move parsing/normalization to @model_validator hooks. Also catch parse errors from ast.literal_eval and raise a typed exception.

```diff
@@ class EagleDecodingConfig(DecodingBaseConfig):
-    def __init__(self, **kwargs):
-        super().__init__()
-        for attr_name, attr_value in kwargs.items():
-            if attr_name == 'max_draft_len':
-                self.num_eagle_layers = attr_value
-                self.max_total_draft_tokens = attr_value  # If using linear-tree, the max_total_draft_tokens is the same as max_draft_len
-            # Convert the data type of Eagle choice from str to List[List[int]]
-            if attr_name == 'eagle_choices' and attr_value is not None:
-                logger.warning(
-                    "NOTE: The Draft token tree is still under development, PLEASE DO NOT USE IT !!!"
-                )
-                if not isinstance(attr_value, list):
-                    if isinstance(attr_value, str):
-                        attr_value = ast.literal_eval(
-                            attr_value.replace(" ", ""))
-                    else:
-                        raise ValueError(
-                            "Wrong eagle choices type. Eagle choices should be a List[List[int]] or a string like [[0], [1], [2], [0, 0], [0, 1]]."
-                        )
-            setattr(self, attr_name, attr_value)
-
-        assert self.max_draft_len is not None, "max_draft_len is required for Eagle"
-
-        # Checks whether the input eagle choices is valid
-        # and reset the max_draft_len and num_eagle_layers if necessary
-        if self.eagle_choices is not None:
-            # If eagle_choices is provided, use_dynamic_tree will not be used
-            if self.use_dynamic_tree:
-                self.use_dynamic_tree = False
-                logger.warning(
-                    "If eagle_choices is provided, use_dynamic_tree will be set to False"
-                )
-
-            # Get num_eagle_layers from eagle_choices
-            num_eagle_layers_from_choices = self.check_eagle_choices()
-            if num_eagle_layers_from_choices != self.num_eagle_layers:
-                logger.warning(
-                    f"Base on the input choices, reset the num_eagle_layers(max_draft_len) from {self.num_eagle_layers} to {num_eagle_layers_from_choices}"
-                )
-                self.num_eagle_layers = num_eagle_layers_from_choices
-                self.max_draft_len = num_eagle_layers_from_choices
-
-            # The max_draft_len is the length of the longest choice minus 1.
-            max_draft_len_from_choices = max(
-                len(choice) for choice in self.eagle_choices)
-            if max_draft_len_from_choices != self.max_draft_len:
-                logger.warning(
-                    f"Base on the input choices, reset the max_draft_len from {self.max_draft_len} to {max_draft_len_from_choices}"
-                )
-                self.max_draft_len = max_draft_len_from_choices
-
-            # Each draft node has a path(choice) from the root to it.
-            # So the number of choices also represents the number of max draft nodes.
-            len(self.eagle_choices)
+    @model_validator(mode="before")
+    @classmethod
+    def _seed_and_parse(cls, data: dict | None):
+        data = {} if data is None else dict(data)
+        # Seed legacy alias
+        if "num_eagle_layers" not in data and data.get("max_draft_len") is not None:
+            data["num_eagle_layers"] = data["max_draft_len"]
+        # Parse string eagle_choices -> List[List[int]]
+        choices = data.get("eagle_choices")
+        if isinstance(choices, str):
+            try:
+                data["eagle_choices"] = ast.literal_eval(choices.replace(" ", ""))
+            except Exception as e:
+                raise ValueError(f"Invalid eagle_choices string: {choices}") from e
+        return data
+
+    @model_validator(mode="after")
+    def _normalize_tree_settings(self):
+        if self.max_draft_len is None:
+            raise ValueError("max_draft_len is required for Eagle")
+        # If static choices provided, disable dynamic tree
+        if self.eagle_choices is not None and self.use_dynamic_tree:
+            self.use_dynamic_tree = False
+            logger.warning("If eagle_choices is provided, use_dynamic_tree will be set to False")
+        # Validate/normalize choices and align fields
+        if self.eagle_choices is not None:
+            num_layers = self.check_eagle_choices()
+            if self.num_eagle_layers != num_layers:
+                logger.warning(
+                    "Based on eagle_choices, reset num_eagle_layers from %s to %s",
+                    self.num_eagle_layers, num_layers)
+                self.num_eagle_layers = num_layers
+            if self.max_draft_len != num_layers:
+                logger.warning(
+                    "Based on eagle_choices, reset max_draft_len from %s to %s",
+                    self.max_draft_len, num_layers)
+                self.max_draft_len = num_layers
+            # For static tree, total nodes equals number of choices
+            self.max_total_draft_tokens = len(self.eagle_choices)
+        else:
+            # Linear chain default: depth equals total draft tokens
+            if self.max_total_draft_tokens is None:
+                self.max_total_draft_tokens = self.max_draft_len
+        return self
```

Also applies to: 457-486
501-521: Use explicit validation instead of asserts; normalize choices efficiently

Replace assert-based validation with clear exceptions. Deduplicate choices via a set of tuples and compute depth safely.

```diff
-    def check_eagle_choices(self):
-        # 1) Check connectivity
-        unique_choices = set(
-            tuple(sub_choice)
-            for sub_choice in self.eagle_choices)  # remove repeated choices
-        self.eagle_choices = sorted([list(t) for t in unique_choices],
-                                    key=lambda x: (len(x), x))  # sort choices
-        for choice in self.eagle_choices:
-            if len(choice) > 1:
-                assert choice[
-                    0:
-                    -1] in self.eagle_choices, f"Error: choice {choice} is not connected"
-
-        # 2) Get num_eagle_layers_from_choices
-        num_eagle_layers_from_choices = 0
-        for choice in self.eagle_choices:
-            num_eagle_layers_from_choices = max(num_eagle_layers_from_choices,
-                                                len(choice))
-
-        return num_eagle_layers_from_choices
+    def check_eagle_choices(self) -> int:
+        if self.eagle_choices is None or len(self.eagle_choices) == 0:
+            raise ValueError("eagle_choices must be non-empty for static tree mode.")
+        unique = {tuple(c) for c in self.eagle_choices}
+        self.eagle_choices = sorted([list(t) for t in unique], key=lambda x: (len(x), x))
+        for choice in self.eagle_choices:
+            if len(choice) > 1:
+                prefix = tuple(choice[:-1])
+                if prefix not in unique:
+                    raise ValueError(f"Invalid eagle_choices: choice {choice} has no parent prefix {list(prefix)}")
+        return max((len(c) for c in self.eagle_choices), default=0)
```

tensorrt_llm/_torch/pyexecutor/resource_manager.py (2)
1209-1265: Add capacity and integrity validation for static-tree construction

Protect against OOB writes when choices exceed configured capacities and ensure eagle_choices is present in static mode. Also deduplicate choices locally to avoid duplicate node indices.

```diff
 def init_tree_from_input_choices(self):
     if self.dynamic_tree:
         return

     # For the static tree
     tree_idx = 0
+    if self.eagle_choices is None:
+        raise ValueError("eagle_choices must be provided for static tree mode")
+    # Deduplicate and sort for deterministic layout
+    choices = sorted({tuple(c) for c in self.eagle_choices}, key=lambda x: (len(x), x))
+    max_choice_len = max((len(c) for c in choices), default=0)
+    if len(choices) > self.max_total_draft_tokens:
+        raise ValueError(f"len(eagle_choices) ({len(choices)}) > max_total_draft_tokens ({self.max_total_draft_tokens})")
+    if max_choice_len > self.max_draft_len:
+        raise ValueError(f"max choice depth ({max_choice_len}) > max_draft_len ({self.max_draft_len})")

     # 1) Map the index
     self.index_mapping_list[tree_idx].clear()
-    for i, choice in enumerate(self.eagle_choices):
+    for i, choice in enumerate(choices):
         self.index_mapping_list[tree_idx][str(choice)] = i + 1

     # 2) Reconstruct the eagle_paths
     self.eagle_paths[tree_idx][0][0] = 0  # root node
-    for i, choice in enumerate(self.eagle_choices):
+    for i, choice in enumerate(choices):
         self.eagle_paths[tree_idx][i + 1][0] = 0
         for j, token in enumerate(choice):
             self.eagle_paths[tree_idx][i + 1][
                 j + 1] = self.index_mapping_list[tree_idx][str(choice[:j + 1])]
@@
-    for choice in self.eagle_choices:
+    for choice in choices:
         cur_layer = len(choice)
         self.nodes_list_per_layer[tree_idx][cur_layer][
             self.num_nodes_per_layer[tree_idx]
             [cur_layer]] = self.index_mapping_list[tree_idx][str(choice)]
         self.num_nodes_per_layer[tree_idx][cur_layer] += 1
@@
-    for choice in self.eagle_choices:
+    for choice in choices:
         cur_node_index = self.index_mapping_list[tree_idx][str(choice)]
         if len(choice) == 1:
             self.parent_node_index[tree_idx][cur_node_index] = 0
         else:
             self.parent_node_index[tree_idx][
                 cur_node_index] = self.index_mapping_list[tree_idx][str(
                     choice[:-1])]
@@
-    for choice in self.eagle_choices:
+    for choice in choices:
         if len(choice) == 1:
             tmp_child_nodes_list[0].append(
                 self.index_mapping_list[tree_idx][str(choice)])
         else:
             tmp_child_nodes_list[self.index_mapping_list[tree_idx][str(
                 choice[:-1])]].append(
                     self.index_mapping_list[tree_idx][str(choice)])
```
1153-1156: Type hint for eagle_choices is incorrect

[List[List[int]]] indicates a single-element list type. Use List[List[int]].

```diff
-    def __init__(self, max_num_requests: int, use_dynamic_tree: bool,
-                 max_total_draft_tokens: int, max_draft_len: int,
-                 eagle_choices: [List[List[int]]], dynamic_tree_max_topK: int):
+    def __init__(self, max_num_requests: int, use_dynamic_tree: bool,
+                 max_total_draft_tokens: int, max_draft_len: int,
+                 eagle_choices: List[List[int]] | None, dynamic_tree_max_topK: int):
```

tensorrt_llm/_torch/speculative/eagle3.py (1)
50-59: Instantiate SpecTreeManager for dynamic-tree mode too

Today it's only created when eagle_choices is provided (static). For use_dynamic_tree=True with no choices, the manager is never initialized, breaking downstream accesses.

```diff
-        self.spec_tree_manager = None
-        if config.eagle_choices is not None:
+        self.spec_tree_manager = None
+        if config.use_dynamic_tree or (config.eagle_choices is not None):
             self.spec_tree_manager = SpecTreeManager(
                 max_num_requests=self.max_num_requests,
                 use_dynamic_tree=config.use_dynamic_tree,
-                max_draft_len=self.max_draft_len,
-                max_total_draft_tokens=self.max_total_draft_tokens,
+                max_total_draft_tokens=self.max_total_draft_tokens,
+                max_draft_len=self.max_draft_len,
                 eagle_choices=config.eagle_choices,
                 dynamic_tree_max_topK=config.dynamic_tree_max_topK,
             )
```

tensorrt_llm/_torch/pyexecutor/sampler.py (1)
432-433: Use .value to reliably fetch SPEC_RESOURCE_MANAGER.

The ResourceManager stores managers under string keys, but this code passes the enum directly. This can cause key mismatch if the registration used the string value.

Apply this diff to fix the key lookup:

```diff
-        spec_resource_manager = resource_manager.get_resource_manager(
-            ResourceManagerType.SPEC_RESOURCE_MANAGER)
+        spec_resource_manager = resource_manager.get_resource_manager(
+            ResourceManagerType.SPEC_RESOURCE_MANAGER.value)
```

tensorrt_llm/_torch/speculative/model_drafter.py (1)
359-360: Use .value to reliably fetch SPEC_RESOURCE_MANAGER.

Same issue as in the sampler - the enum should be converted to its string value for reliable key lookup.

Apply this diff to fix the key lookup:

```diff
-        spec_resource_manager = resource_manager.get_resource_manager(
-            ResourceManagerType.SPEC_RESOURCE_MANAGER)
+        spec_resource_manager = resource_manager.get_resource_manager(
+            ResourceManagerType.SPEC_RESOURCE_MANAGER.value)
```
🧹 Nitpick comments (6)
tensorrt_llm/llmapi/llm_args.py (3)
424-433: Clarify semantics between max_draft_len, max_total_draft_tokens, num_eagle_layers

These three are easy to confuse. Please document and enforce via validators:
- max_draft_len: maximum depth (levels) of the draft tree.
- max_total_draft_tokens: maximum number of draft nodes (can exceed depth for trees; equals depth for chains).
- num_eagle_layers: legacy alias kept for TRT compatibility; mirror max_draft_len.

Would you like me to add validators that keep these in sync (and set max_total_draft_tokens=len(eagle_choices) when static choices are provided)? A sketch follows below.
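A minimal sketch of such a validator, assuming the Pydantic v2 API and the field names above (the class name and defaults are illustrative, not the library's):

```python
from typing import List, Optional

from pydantic import BaseModel, model_validator


class EagleTreeFields(BaseModel):
    """Illustrative subset of EagleDecodingConfig's tree-related fields."""
    max_draft_len: Optional[int] = None           # tree depth (levels)
    max_total_draft_tokens: Optional[int] = None  # total draft nodes
    num_eagle_layers: Optional[int] = None        # legacy alias for depth
    eagle_choices: Optional[List[List[int]]] = None

    @model_validator(mode="after")
    def _keep_in_sync(self):
        # Legacy alias mirrors the depth.
        if self.num_eagle_layers is None:
            self.num_eagle_layers = self.max_draft_len
        if self.eagle_choices is not None:
            # Static tree: one node per choice, depth = longest choice.
            self.max_total_draft_tokens = len(self.eagle_choices)
            self.max_draft_len = max(len(c) for c in self.eagle_choices)
            self.num_eagle_layers = self.max_draft_len
        elif self.max_total_draft_tokens is None:
            # Linear chain: total nodes equals depth.
            self.max_total_draft_tokens = self.max_draft_len
        return self
```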
478-486: Comment contradicts code: depth vs "minus 1"

The comment says "length of the longest choice minus 1" but the code sets max_draft_len = max(len(choice)) (no minus 1). Clarify by fixing the comment or the computation.

```diff
-            # The max_draft_len is the length of the longest choice minus 1.
+            # max_draft_len equals the depth (length) of the longest choice.
```
445-445: Minor: wrap long logger.warning strings to satisfy Ruff E501

Several lines exceed 120 chars per Ruff. Consider splitting the f-strings or using logger params.

Also applies to: 473-475, 483-484
tensorrt_llm/_torch/pyexecutor/resource_manager.py (2)

1149-1151: Docstring shape mismatch for num_child_nodes

Comment says [num_trees, max_draft_len + 1] but the tensor is allocated as [num_trees, max_total_draft_tokens + 1]. Update the comment to match the actual shape (per-node count).

```diff
-        # shape: [num_trees, max_draft_len + 1]
+        # shape: [num_trees, max_total_draft_tokens + 1]
```
1194-1205: Memory growth concern: child_nodes_list is O(N^2)

For large trees, [num_trees, N+1, N] can be sizable. If memory pressure shows up, consider a compact CSR-like representation (flat children array + offsets, sketched below) or per-layer adjacency lists.
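A small sketch of the CSR-style layout this comment suggests; the names and the example tree are illustrative, not taken from the codebase:

```python
import torch

# Dense layout: child_nodes_list[node][k] wastes O(N^2) slots per tree.
# CSR layout: one flat array of children plus per-node offsets, O(N) total.
children_per_node = [[1, 2], [3], [], []]  # example tree: 0 -> {1, 2}, 1 -> {3}

child_offsets = torch.zeros(len(children_per_node) + 1, dtype=torch.int32)
for node, kids in enumerate(children_per_node):
    child_offsets[node + 1] = child_offsets[node] + len(kids)

flat_children = torch.tensor(
    [c for kids in children_per_node for c in kids], dtype=torch.int32)


def children_of(node: int) -> torch.Tensor:
    """Return the child node ids of `node` from the CSR arrays."""
    return flat_children[child_offsets[node]:child_offsets[node + 1]]


assert children_of(0).tolist() == [1, 2]
assert children_of(2).tolist() == []
```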
217-217
: Line length violations should be addressed.The static analysis tool flagged several lines exceeding 120 characters. Consider breaking these long comments into multiple lines for better readability.
Apply these formatting fixes:
- ################## CASE 1 static tree, target model's request, no draft tokens are accepted ########################## + ################## CASE 1 static tree, target model's request, + # no draft tokens are accepted ########################## - ################## CASE 2 static tree, target model's request, only one path is accepted, not the longest one ########################## + ################## CASE 2 static tree, target model's request, + # only one path is accepted, not the longest one ########################## - ################## CASE 3 static tree, target model's request, only one path is accepted, which is also the longest one ########################## + ################## CASE 3 static tree, target model's request, + # only one path is accepted, which is also the longest one ##########################Also applies to: 262-262, 307-307
📜 Review details
📒 Files selected for processing (11)
- tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py (2 hunks)
- tensorrt_llm/_torch/pyexecutor/_util.py (1 hunks)
- tensorrt_llm/_torch/pyexecutor/resource_manager.py (1 hunks)
- tensorrt_llm/_torch/pyexecutor/sampler.py (14 hunks)
- tensorrt_llm/_torch/speculative/eagle3.py (5 hunks)
- tensorrt_llm/_torch/speculative/model_drafter.py (9 hunks)
- tensorrt_llm/_torch/speculative/mtp.py (1 hunks)
- tensorrt_llm/_torch/speculative/utils.py (1 hunks)
- tensorrt_llm/llmapi/llm_args.py (3 hunks)
- tests/unittest/_torch/speculative/test_draft_token_tree_sampling.py (1 hunks)
- tests/unittest/_torch/speculative/test_draft_token_tree_verification.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
- tests/unittest/_torch/speculative/test_draft_token_tree_sampling.py
- tensorrt_llm/_torch/speculative/utils.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-07-28T17:06:08.621Z
Learnt from: moraxu
PR: NVIDIA/TensorRT-LLM#6303
File: tests/integration/test_lists/qa/examples_test_list.txt:494-494
Timestamp: 2025-07-28T17:06:08.621Z
Learning: In TensorRT-LLM testing, it's common to have both CLI flow tests (test_cli_flow.py) and PyTorch API tests (test_llm_api_pytorch.py) for the same model. These serve different purposes: CLI flow tests validate the traditional command-line workflow, while PyTorch API tests validate the newer LLM API backend. Both are legitimate and should coexist.
Applied to files:
tests/unittest/_torch/speculative/test_draft_token_tree_verification.py
🧬 Code graph analysis (7)
tensorrt_llm/_torch/speculative/mtp.py (2)
- tensorrt_llm/_torch/pyexecutor/sampler.py (5): update_requests (66-70), update_requests (94-105), update_requests (140-159), update_requests (766-794), update_requests (1264-1280)
- tensorrt_llm/_torch/pyexecutor/resource_manager.py (1): BaseResourceManager (58-81)
tensorrt_llm/_torch/pyexecutor/_util.py (1)
- tensorrt_llm/_torch/pyexecutor/sampler.py (2): TorchSampler (371-994), Args (387-393)
tensorrt_llm/llmapi/llm_args.py (1)
- tensorrt_llm/logger.py (1): warning (131-132)
tensorrt_llm/_torch/speculative/eagle3.py (1)
- tensorrt_llm/_torch/pyexecutor/resource_manager.py (3): BaseResourceManager (58-81), SlotManager (934-963), SpecTreeManager (1113-1264)
tensorrt_llm/_torch/speculative/model_drafter.py (4)
- tensorrt_llm/_torch/pyexecutor/resource_manager.py (4): BaseResourceManager (58-81), ResourceManager (966-1001), ResourceManagerType (46-51), get_resource_manager (978-979)
- tensorrt_llm/_torch/pyexecutor/py_executor.py (2): _sample_async (1533-1557), _update_requests (1570-1579)
- tensorrt_llm/_torch/pyexecutor/sampler.py (6): SampleState (41-47), update_requests (66-70), update_requests (94-105), update_requests (140-159), update_requests (766-794), update_requests (1264-1280)
- tensorrt_llm/_torch/speculative/mtp.py (1): update_requests (245-273)
tensorrt_llm/_torch/pyexecutor/sampler.py (3)
- tensorrt_llm/_torch/pyexecutor/resource_manager.py (4): ResourceManager (966-1001), ResourceManagerType (46-51), SpecTreeManager (1113-1264), get_resource_manager (978-979)
- tensorrt_llm/_torch/speculative/mtp.py (1): update_requests (245-273)
- tensorrt_llm/_torch/pyexecutor/llm_request.py (1): LlmRequest (282-422)
tests/unittest/_torch/speculative/test_draft_token_tree_verification.py (4)
- tensorrt_llm/sampling_params.py (2): SamplingParams (125-486), _get_sampling_config (408-438)
- tensorrt_llm/_torch/pyexecutor/llm_request.py (1): LlmRequest (282-422)
- tensorrt_llm/_torch/pyexecutor/resource_manager.py (1): SpecTreeManager (1113-1264)
- tensorrt_llm/_torch/pyexecutor/sampler.py (4): TorchSampler (371-994), beam_width (73-76), Args (387-393), _process_draft_tokens_tree (496-603)
🪛 Ruff (0.12.2)
tensorrt_llm/llmapi/llm_args.py
433-433: Line too long (137 > 120)
(E501)
445-445: Line too long (147 > 120)
(E501)
465-465: Line too long (155 > 120)
(E501)
475-475: Line too long (131 > 120)
(E501)
tensorrt_llm/_torch/pyexecutor/resource_manager.py
1117-1117: Line too long (135 > 120)
(E501)
tensorrt_llm/_torch/pyexecutor/sampler.py
499-499: Line too long (134 > 120)
(E501)
505-505: Line too long (147 > 120)
(E501)
508-508: Line too long (160 > 120)
(E501)
509-509: Line too long (179 > 120)
(E501)
511-511: Line too long (165 > 120)
(E501)
512-512: Line too long (121 > 120)
(E501)
694-694: Line too long (171 > 120)
(E501)
696-696: Line too long (121 > 120)
(E501)
698-698: Line too long (158 > 120)
(E501)
tests/unittest/_torch/speculative/test_draft_token_tree_verification.py
217-217: Line too long (122 > 120)
(E501)
262-262: Line too long (140 > 120)
(E501)
307-307: Line too long (150 > 120)
(E501)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
🔇 Additional comments (12)
tensorrt_llm/_torch/speculative/eagle3.py (1)
35-36: LGTM: threading max_total_draft_tokens into the resource manager. Good to thread capacity into Eagle3ResourceManager; it aligns with static-tree buffer sizing.
tensorrt_llm/_torch/pyexecutor/sampler.py (3)
496-603: Tree verification implementation looks comprehensive. The implementation correctly handles both drafter and target model paths, uses proper tree traversal logic to find the longest matching path, and includes appropriate token handling with stop criteria.
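Conceptually, the target-side verification keeps the longest root-anchored tree path whose drafted tokens all match the target model's tokens. A simplified, standalone sketch of that idea (not the PR's actual code; the indexing is deliberately simplified):

def longest_accepted_path(paths, draft_tokens, target_tokens):
    """paths: lists of node indices from the root downward.
    draft_tokens[n]: token the draft model proposed at node n.
    target_tokens[n]: token the target model produced at the position
    node n predicts (simplified; real code indexes via the node's parent).
    Returns the longest fully matching prefix of any path."""
    best = []
    for path in paths:
        accepted = []
        for node in path:
            if draft_tokens[node] != target_tokens[node]:
                break
            accepted.append(node)
        if len(accepted) > len(best):
            best = accepted
    return best

# Example: path [0, 1] matches fully, path [0, 2] breaks at node 2.
print(longest_accepted_path([[0, 1], [0, 2]],
                            {0: 5, 1: 7, 2: 9},
                            {0: 5, 1: 7, 2: 8}))  # -> [0, 1]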
685-764: Tree sampling implementation is well-structured. The static tree sampling logic correctly calculates nodes per layer, extracts topK values, and properly handles the tensor operations for draft token generation. The integration with Eagle3 d2t mapping is also handled correctly.
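As a rough illustration of the per-layer top-K idea (a standalone sketch; the helper name and shapes are assumptions, not the PR's tree_sampling signature):

import torch

def expand_tree_layer(logits: torch.Tensor, top_k_per_node: list[int]) -> torch.Tensor:
    """Pick top-k draft tokens for each active node on the current tree layer.

    logits: [num_active_nodes, vocab_size]; top_k_per_node[i] is how many
    children node i has in the static tree. Returns the concatenated child
    token ids, in node order.
    """
    children = []
    for node_idx, k in enumerate(top_k_per_node):
        if k > 0:
            children.append(torch.topk(logits[node_idx], k).indices)
    if not children:
        return torch.empty(0, dtype=torch.long)
    return torch.cat(children)

layer_logits = torch.randn(2, 32000)            # two active nodes
print(expand_tree_layer(layer_logits, [2, 1]))  # 2 + 1 = 3 draft tokens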
925-944: Tree-based fast path implementation is correct. The fast path correctly routes between the non-spec-dec/linear tree mode and the tree sampling mode based on the presence of SpecTreeManager. The fallback logic preserves existing behavior when the tree manager is not available.
tensorrt_llm/_torch/pyexecutor/_util.py (1)
581-587: Safe fallback logic for the max_total_draft_tokens derivation. The implementation correctly handles three cases: no speculative config (0), an explicit max_total_draft_tokens field, and a fallback to max_draft_len. The hasattr check ensures compatibility with older config formats.
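Paraphrased, the derivation amounts to something like this (a sketch, not the literal code in _util.py):

if spec_config is None:
    max_total_draft_tokens = 0
elif getattr(spec_config, "max_total_draft_tokens", None) is not None:
    # Tree decoding: total number of nodes in the draft tree.
    max_total_draft_tokens = spec_config.max_total_draft_tokens
else:
    # Linear (chain) drafting: one node per draft step.
    max_total_draft_tokens = spec_config.max_draft_len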
tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py (1)
288-294: Consistent max_total_draft_tokens handling. The autodeploy path uses the same derivation logic as the main executor path in _util.py, ensuring consistency across different execution modes.
tests/unittest/_torch/speculative/test_draft_token_tree_verification.py (3)
18-62: Test structure and validation logic look solid. The test properly constructs the necessary components (EagleDecodingConfig, SpecTreeManager, TorchSampler) and validates both the number of accepted draft tokens and the final token sequence. The diagnostic prints will be helpful for debugging.
66-204: Draft model test cases are comprehensive. The three test cases properly cover different draft layer IDs and verify that draft tokens are correctly added to the request without acceptance (ref_num_accepted_draft_tokens = 0).
207-351: Target model test cases cover key scenarios. The test cases appropriately cover the no-acceptance, partial-acceptance, and full-path-acceptance scenarios. The test data correctly simulates different tree verification outcomes.
tensorrt_llm/_torch/speculative/model_drafter.py (3)
355-367: Resource manager integration for draft layer tracking. The implementation correctly extracts the SpecTreeManager from the resource manager and updates the current draft layer index. The null checks ensure robustness.
266-294: Resource manager parameter threading looks good. The _sample_async method properly forwards the resource_manager to the sampler's sample_async method, maintaining the chain of resource management throughout the drafting process.
402-404: Draft layer progression is properly managed. The code correctly updates the draft layer index at key points: before the initial forward (0), at each iteration (i+1), and after completion (max_draft_len). This ensures the SpecTreeManager stays synchronized with the drafting progress.
Also applies to: 424-424, 433-434, 446-448, 450-450
PR_Github #16424 [ run ] completed with state
Force-pushed from acd4a56 to 05ba03e (Compare)
/bot run --disable-fail-fast
PR_Github #16525 [ run ] triggered by Bot
Actionable comments posted: 1
♻️ Duplicate comments (10)
tensorrt_llm/llmapi/llm_args.py (3)
439-458: Consider using validators instead of a custom __init__. Overriding __init__ in a Pydantic model can bypass validation and lead to subtle issues. Consider refactoring to use Pydantic validators for better maintainability and validation guarantees.
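A rough sketch of the validator-based alternative (Pydantic v2 style; the field subset and the derivation shown are illustrative assumptions, not the real class):

from typing import List, Optional

from pydantic import BaseModel, model_validator

class EagleDecodingConfigSketch(BaseModel):
    # Illustrative subset of the real config's fields.
    max_draft_len: Optional[int] = None
    eagle_choices: Optional[List[List[int]]] = None

    @model_validator(mode="after")
    def _derive_lengths(self):
        # Runs after normal field validation, unlike a hand-rolled __init__.
        if self.eagle_choices is not None and self.max_draft_len is None:
            self.max_draft_len = max(len(choice) for choice in self.eagle_choices)
        return self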
448-449: Use the logger consistently instead of print. Replace the print call with logger.warning() for consistency with the rest of the codebase.
504-523: Use explicit validation instead of asserts. Asserts can be stripped with the -O flag and are unsuitable for user input validation. Convert to a set for O(1) connectivity checks and raise explicit errors.
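For instance, a set-based connectivity check with explicit errors might look like this (a hypothetical helper, not the PR's code):

def validate_eagle_choices(eagle_choices: list[list[int]]) -> None:
    """Raise ValueError instead of asserting, so checks survive python -O."""
    seen: set[tuple[int, ...]] = {()}  # () stands for the implicit root path
    for choice in sorted(eagle_choices, key=len):  # parents before children
        if not choice:
            raise ValueError("eagle_choices must not contain empty paths")
        if tuple(choice[:-1]) not in seen:  # O(1) connectivity lookup
            raise ValueError(f"choice {choice} has no parent path in the tree")
        seen.add(tuple(choice))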
tensorrt_llm/_torch/pyexecutor/resource_manager.py (2)
1171-1171: Fix the type hint for the eagle_choices parameter. The type hint [List[List[int]]] incorrectly suggests a list containing a single list of lists. It should be List[List[int]].
1230-1253: Add validation to prevent buffer overflow. The static tree reconstruction could overflow if eagle_choices lengths exceed buffer capacities. Add bounds checking before processing.
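A bounds check before reconstruction could be as simple as the following sketch (the capacity names are assumptions about the buffers involved):

def check_choices_fit(eagle_choices, max_draft_len, max_total_draft_tokens):
    # Path depth must fit the per-layer buffers.
    if any(len(choice) > max_draft_len for choice in eagle_choices):
        raise ValueError("an eagle choice is deeper than max_draft_len")
    # One tree node per choice must fit the node buffers.
    if len(eagle_choices) > max_total_draft_tokens:
        raise ValueError("eagle_choices exceeds max_total_draft_tokens nodes")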
tests/unittest/_torch/speculative/test_draft_token_tree_sampling.py (1)
18-411: Add a CUDA availability check. The test creates CUDA tensors and will fail on machines without a GPU. Add a skip guard to gracefully handle this case.
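For example, assuming the test file uses pytest, a module-level guard would do:

import pytest
import torch

# Skip every test in this module when no GPU is present.
pytestmark = pytest.mark.skipif(not torch.cuda.is_available(),
                                reason="test allocates CUDA tensors")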
tensorrt_llm/_torch/pyexecutor/sampler.py (1)
429-438: Consider using the enum's string value for reliability. The dictionary key might be registered using the enum's string value rather than the enum itself.
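That is, registration and lookup should agree on the key type. A minimal illustration of the pitfall (the enum shown is hypothetical, not necessarily the real ResourceManagerType definition):

from enum import Enum

class ResourceManagerTypeSketch(str, Enum):
    SPEC_TREE_MANAGER = "spec_tree_manager"

# If registration used the string value...
managers = {ResourceManagerTypeSketch.SPEC_TREE_MANAGER.value: object()}
# ...then using .value on the lookup side as well avoids depending on the
# str-subclass equality behavior of the enum member.
assert managers.get(ResourceManagerTypeSketch.SPEC_TREE_MANAGER.value) is not None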
tensorrt_llm/_torch/pyexecutor/py_executor.py (3)
842-843: Missing ResourceManager in the PP finalize path disables tree-aware processing. _executor_loop_pp calls _update_requests(previous_batch.sample_state) without the resource_manager, so static/dynamic tree verification won't run in PP mode. Apply:
- self._update_requests(previous_batch.sample_state)
+ self._update_requests(previous_batch.sample_state, self.resource_manager)
1084-1086: Missing ResourceManager in the overlap scheduler path. Same issue as PP: the overlap path invokes _update_requests without resource_manager, breaking draft-token tree verification whenever overlap is enabled. Apply:
  if self.previous_batch is not None:
-     self._update_requests(self.previous_batch.sample_state)
+     self._update_requests(self.previous_batch.sample_state, self.resource_manager)
: Missing ResourceManager in remaining_update_requests
callsThe two call sites below still pass only the
sample_state
argument. They need to includeself.resource_manager
to ensure tree-aware draft-token processing in the non-PP/non-overlap path:
- tensorrt_llm/_torch/pyexecutor/py_executor.py:842
- tensorrt_llm/_torch/pyexecutor/py_executor.py:1085
Suggested replacement diff at both locations:
- self._update_requests(previous_batch.sample_state) + self._update_requests(previous_batch.sample_state, self.resource_manager)Please update these call sites to thread the
ResourceManager
as done elsewhere.
🧹 Nitpick comments (8)
tests/unittest/_torch/speculative/test_draft_token_tree_sampling.py (1)
21-21
: Consider parameterizing the hardcoded model path.The eagle model directory path is hardcoded. Consider making it configurable via environment variable or test parameter for better portability.
- eagle_model_dir = f"{models_path}/EAGLE3-LLaMA3.1-Instruct-8B" # It will not actually be used. + eagle_model_dir = os.environ.get('EAGLE_MODEL_DIR', + f"{models_path}/EAGLE3-LLaMA3.1-Instruct-8B") # It will not actually be used.tensorrt_llm/_torch/pyexecutor/py_executor.py (1)
1540-1545
: Consider documenting the new optional parameter.The signature change is sound. Add a short docstring on what
resource_manager
is used for (e.g., tree-aware draft token processing) to guide future maintainers.Apply:
@nvtx_range("_update_requests") def _update_requests(self, sample_state: SampleState, resource_manager: Optional[ResourceManager] = None): - try: + """Update requests using the sampler (may use resource_manager for tree-aware drafting).""" + try: self.sampler.update_requests(sample_state, resource_manager)tests/unittest/_torch/speculative/test_draft_token_tree_verification.py (6)
53-57
: Remove noisy prints in tests; rely on assertions.These prints clutter CI logs and aren’t needed for failures (asserts will show values). Prefer logging at DEBUG if absolutely necessary.
Apply:
- print(f"num_accepted_draft_tokens: {num_accepted_draft_tokens}") - print(f"ref_num_accepted_draft_tokens: {ref_num_accepted_draft_tokens}") - print(f"input_request.get_tokens(0): {input_request.get_tokens(0)}") - print(f"ref_mtokens: {ref_mtokens}") + # Debug traces can be enabled locally if needed.
76-78: Fix minor typos in comments. Polish for clarity. Apply:
- # We dn not need to test the case of draft_layer_id = 0, because _update_requests() is one step delay.
- # And we do not need to extract the root node of in the draft_layer_id = 0.
+ # We do not need to test draft_layer_id = 0 because _update_requests() is one-step delayed.
+ # We also do not need to extract the root node for draft_layer_id = 0.
217-217: Wrap the long comment line to satisfy Ruff E501 (<=120 chars). The current line length exceeds the limit. Apply:
- ################## CASE 1 static tree, target model's request, no draft tokens are accepted ##########################
+ ################## CASE 1 ##########################
+ # static tree, target-model request; no draft tokens are accepted
262-262: Wrap the long comment line to satisfy Ruff E501 (<=120 chars). Same formatting issue as Case 1. Apply:
- ################## CASE 2 static tree, target model's request, only one path is accepted, not the longest one ##########################
+ ################## CASE 2 ##########################
+ # static tree, target-model request; only one path accepted, not the longest
307-307: Wrap the long comment line to satisfy Ruff E501 (<=120 chars). Same formatting issue as the prior cases. Apply:
- ################## CASE 3 static tree, target model's request, only one path is accepted, which is also the longest one ##########################
+ ################## CASE 3 ##########################
+ # static tree, target-model request; one path accepted (also the longest)
65-204: Optional: reduce duplication via parametrization. The three draft-model cases and the three target-model cases repeat boilerplate. Consider pytest parametrize or a small table-driven loop to improve maintainability.
If you’re open to pytest, I can provide a parametrized version; otherwise, we can factor out the common fixture-style setup inside helpers.
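A parametrized shape could look roughly like this (the helper run_tree_verification_case is hypothetical; it stands for the shared setup the cases repeat):

import pytest

@pytest.mark.parametrize(
    "draft_layer_id, ref_num_accepted_draft_tokens",
    [(1, 0), (2, 0), (3, 0)],  # the three draft-model cases
)
def test_draft_model_tree(draft_layer_id, ref_num_accepted_draft_tokens):
    run_tree_verification_case(
        draft_layer_id=draft_layer_id,
        ref_num_accepted_draft_tokens=ref_num_accepted_draft_tokens,
    )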
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (13)
- examples/llm-api/quickstart_advanced.py (2 hunks)
- tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py (2 hunks)
- tensorrt_llm/_torch/pyexecutor/_util.py (1 hunks)
- tensorrt_llm/_torch/pyexecutor/py_executor.py (3 hunks)
- tensorrt_llm/_torch/pyexecutor/resource_manager.py (1 hunks)
- tensorrt_llm/_torch/pyexecutor/sampler.py (14 hunks)
- tensorrt_llm/_torch/speculative/eagle3.py (5 hunks)
- tensorrt_llm/_torch/speculative/model_drafter.py (9 hunks)
- tensorrt_llm/_torch/speculative/mtp.py (1 hunks)
- tensorrt_llm/_torch/speculative/utils.py (1 hunks)
- tensorrt_llm/llmapi/llm_args.py (3 hunks)
- tests/unittest/_torch/speculative/test_draft_token_tree_sampling.py (1 hunks)
- tests/unittest/_torch/speculative/test_draft_token_tree_verification.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (7)
- tensorrt_llm/_torch/pyexecutor/_util.py
- examples/llm-api/quickstart_advanced.py
- tensorrt_llm/_torch/speculative/utils.py
- tensorrt_llm/_torch/speculative/eagle3.py
- tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py
- tensorrt_llm/_torch/speculative/model_drafter.py
- tensorrt_llm/_torch/speculative/mtp.py
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.py
: Code must target Python 3.8+
Indent Python code with 4 spaces; do not use tabs
Preserve module namespaces when importing; import modules/packages and access members via the module (e.g., from package.subpackage import foo; foo.SomeClass())
Python file names should be snake_case
Python class names should be PascalCase
Python functions/methods and local variables should be snake_case; variables beginning with a number should be prefixed with k_ (e.g., k_99th_percentile)
Global variables should be UPPER_SNAKE_CASE prefixed with G_ (e.g., G_MY_GLOBAL); constants should be UPPER_SNAKE_CASE
Avoid shadowing variables from outer scopes; initialize all externally visible members in __init__
Prefer docstrings for interfaces used outside a file; comments should be reserved for in-function or file-local interfaces
Use Google-style docstrings for classes and functions; attributes and variables may be documented inline with trailing string literals
Avoid reflection when simpler, explicit code suffices (e.g., avoid dict(**locals()) patterns)
In try/except, catch the narrowest exceptions possible
For duck-typing patterns, keep the try body minimal and move logic to else to avoid masking unrelated failures
Files:
tensorrt_llm/_torch/pyexecutor/resource_manager.py
tests/unittest/_torch/speculative/test_draft_token_tree_sampling.py
tensorrt_llm/llmapi/llm_args.py
tests/unittest/_torch/speculative/test_draft_token_tree_verification.py
tensorrt_llm/_torch/pyexecutor/sampler.py
tensorrt_llm/_torch/pyexecutor/py_executor.py
**/*.{c,cc,cpp,cxx,h,hh,hpp,hxx,cu,cuh,py}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Prepend the NVIDIA copyright header (current year) to all source files (.cpp, .h, .cu, .py, etc.)
Files:
tensorrt_llm/_torch/pyexecutor/resource_manager.py
tests/unittest/_torch/speculative/test_draft_token_tree_sampling.py
tensorrt_llm/llmapi/llm_args.py
tests/unittest/_torch/speculative/test_draft_token_tree_verification.py
tensorrt_llm/_torch/pyexecutor/sampler.py
tensorrt_llm/_torch/pyexecutor/py_executor.py
🧠 Learnings (1)
📚 Learning: 2025-07-28T17:06:08.621Z
Learnt from: moraxu
PR: NVIDIA/TensorRT-LLM#6303
File: tests/integration/test_lists/qa/examples_test_list.txt:494-494
Timestamp: 2025-07-28T17:06:08.621Z
Learning: In TensorRT-LLM testing, it's common to have both CLI flow tests (test_cli_flow.py) and PyTorch API tests (test_llm_api_pytorch.py) for the same model. These serve different purposes: CLI flow tests validate the traditional command-line workflow, while PyTorch API tests validate the newer LLM API backend. Both are legitimate and should coexist.
Applied to files:
tests/unittest/_torch/speculative/test_draft_token_tree_verification.py
🧬 Code graph analysis (5)
tests/unittest/_torch/speculative/test_draft_token_tree_sampling.py (3)
- tensorrt_llm/_torch/pyexecutor/llm_request.py (1): LlmRequest (282-424)
- tensorrt_llm/_torch/pyexecutor/resource_manager.py (1): SpecTreeManager (1129-1280)
- tensorrt_llm/_torch/pyexecutor/sampler.py (4): TorchSampler (371-994), beam_width (73-76), Args (387-393), tree_sampling (685-764)
tensorrt_llm/llmapi/llm_args.py (1)
- tensorrt_llm/logger.py (1): warning (131-132)
tests/unittest/_torch/speculative/test_draft_token_tree_verification.py (3)
- tensorrt_llm/_torch/pyexecutor/llm_request.py (1): LlmRequest (282-424)
- tensorrt_llm/_torch/pyexecutor/resource_manager.py (1): SpecTreeManager (1129-1280)
- tensorrt_llm/_torch/pyexecutor/sampler.py (2): TorchSampler (371-994), _process_draft_tokens_tree (496-603)
tensorrt_llm/_torch/pyexecutor/sampler.py (3)
- tensorrt_llm/_torch/pyexecutor/resource_manager.py (4): ResourceManager (982-1017), ResourceManagerType (46-51), SpecTreeManager (1129-1280), get_resource_manager (994-995)
- tensorrt_llm/_torch/pyexecutor/llm_request.py (1): LlmRequest (282-424)
- tensorrt_llm/_torch/pyexecutor/scheduler.py (1): ScheduledRequests (18-39)
tensorrt_llm/_torch/pyexecutor/py_executor.py (3)
- tensorrt_llm/_torch/pyexecutor/resource_manager.py (1): ResourceManager (982-1017)
- tensorrt_llm/_torch/speculative/model_drafter.py (1): _update_requests (305-311)
- tensorrt_llm/_torch/pyexecutor/sampler.py (6): SampleState (41-47), update_requests (66-70), update_requests (94-105), update_requests (140-159), update_requests (766-794), update_requests (1264-1280)
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/pyexecutor/resource_manager.py
1133-1133: Line too long (135 > 120)
(E501)
tensorrt_llm/llmapi/llm_args.py
436-436: Line too long (137 > 120)
(E501)
448-448: Line too long (147 > 120)
(E501)
468-468: Line too long (155 > 120)
(E501)
478-478: Line too long (131 > 120)
(E501)
tests/unittest/_torch/speculative/test_draft_token_tree_verification.py
217-217: Line too long (122 > 120)
(E501)
262-262: Line too long (140 > 120)
(E501)
307-307: Line too long (150 > 120)
(E501)
tensorrt_llm/_torch/pyexecutor/sampler.py
499-499: Line too long (134 > 120)
(E501)
505-505: Line too long (147 > 120)
(E501)
508-508: Line too long (160 > 120)
(E501)
509-509: Line too long (179 > 120)
(E501)
511-511: Line too long (165 > 120)
(E501)
512-512: Line too long (121 > 120)
(E501)
694-694: Line too long (171 > 120)
(E501)
696-696: Line too long (121 > 120)
(E501)
698-698: Line too long (158 > 120)
(E501)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
🔇 Additional comments (11)
tensorrt_llm/llmapi/llm_args.py (2)
1-1: LGTM! The addition of the ast import is appropriate for safely evaluating string representations of eagle choices.
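For context, safely turning a CLI string into nested lists with ast.literal_eval looks like this (a sketch; the helper name is made up):

import ast

def parse_eagle_choices(raw: str) -> list[list[int]]:
    """Parse e.g. "[[0], [0, 0], [1]]" without executing arbitrary code."""
    choices = ast.literal_eval(raw)  # only evaluates Python literals
    if not isinstance(choices, list) or not all(isinstance(c, list) for c in choices):
        raise ValueError(f"eagle_choices must be a list of lists, got {raw!r}")
    return choices

print(parse_eagle_choices("[[0], [0, 0], [1]]"))  # [[0], [0, 0], [1]]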
431-431: Good addition for tree-based speculative decoding. The max_total_draft_tokens field properly complements max_draft_len to distinguish between total tree nodes and maximum depth.
tensorrt_llm/_torch/pyexecutor/resource_manager.py (1)
1129-1168: Well-structured tree management class. The SpecTreeManager class is properly designed with clear attributes and documentation. The buffers are appropriately pinned for efficient GPU transfer.
tensorrt_llm/_torch/pyexecutor/sampler.py (6)
26-27: LGTM! Clean imports for resource management. The imports for ResourceManager components are properly structured.
66-70: Good API consistency with the optional resource_manager parameter. Adding an optional resource_manager parameter across all update_requests methods maintains interface consistency while preserving backward compatibility.
Also applies to: 94-99, 140-145
496-604: Well-implemented tree verification logic. The tree verification algorithm correctly handles both draft and target model scenarios with proper longest-prefix matching. The documentation clearly explains the verification process.
685-765: Efficient tree sampling implementation. The static tree sampling correctly handles per-layer topK sampling with proper batching across requests. The temporary CUDA buffer approach for collecting results is efficient.
766-789: Proper resource manager threading through the update pipeline. The resource_manager is correctly threaded through the update pipeline to enable tree-based processing when available.
936-944: Good fallback to tree sampling in the fast path. The logic properly detects when tree sampling should be used and falls back appropriately when spec_tree_manager is available.
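The routing boils down to an accessor that returns None when tree decoding is off, roughly as follows (a sketch; the function name follows the review's cross-references, but the body and the enum member are assumptions):

def get_spec_tree_manager(resource_manager):
    if resource_manager is None:
        return None
    # Returns the SpecTreeManager when registered, else None (linear mode).
    return resource_manager.get_resource_manager(
        ResourceManagerType.SPEC_TREE_MANAGER)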
tensorrt_llm/_torch/pyexecutor/py_executor.py (1)
47-47: Importing ResourceManager for type hints is correct and scoped. This keeps annotations precise without altering runtime behavior. No concerns.
tests/unittest/_torch/speculative/test_draft_token_tree_verification.py (1)
18-37: Test scaffolding looks good; the config and manager wiring match the intended flow. The construction of EagleDecodingConfig, SpecTreeManager, and TorchSampler aligns with the tree-verification design and isolates the logic for unit testing.
PR_Github #19444 [ run ] triggered by Bot
PR_Github #19444 [ run ] completed with state
/bot run
PR_Github #19525 [ run ] triggered by Bot
PR_Github #19525 [ run ] completed with state
/bot run
PR_Github #19609 [ run ] triggered by Bot
PR_Github #19609 [ run ] completed with state
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM on the llmapi changes
Force-pushed from b5670e0 to 257ea5b (Compare)
/bot run --disable-fail-fast
PR_Github #19819 [ run ] triggered by Bot
Not sure why this PR requires approval from trt-llm-doc-owners; approving to unblock the merge.
Signed-off-by: Yue Weng <[email protected]>
Force-pushed from 257ea5b to affb076 (Compare)
/bot kill
/bot run --disable-fail-fast
PR_Github #19945 [ kill ] triggered by Bot
PR_Github #19819 [ run ] completed with state
PR_Github #19945 [ kill ] completed with state
PR_Github #19946 [ run ] triggered by Bot
PR_Github #19946 [ run ] completed with state
/bot run
PR_Github #20015 [ run ] triggered by Bot
PR_Github #20015 [ run ] completed with state
Summary by CodeRabbit
New Features
Tests
Description
This PR implements sampling and verification for the static draft-token tree.
Test Coverage
GitHub Bot Help
/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...
Provide a user friendly way for developers to interact with a Jenkins server.
Run /bot [-h|--help] to print this help message. See details below for each supported subcommand.
run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]
Launch build/test pipelines. All previously running jobs will be killed.
--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline, or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will always be ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.
--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.
--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.
--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.
--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.
--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.
--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages which don't match the specified backends. Only supports [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.
--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.
--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.
--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running the L0 pre-merge pipeline.
--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.
--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and the specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".
--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.
--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purposes. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.
For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md and the scripts/test_to_stage_mapping.py helper.
kill
kill
Kill all running builds associated with pull request.
skip
skip --comment COMMENT
Skip testing for the latest commit on the pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.
reuse-pipeline
reuse-pipeline
Reuse a previous pipeline to validate the current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.