[TRTLLM-6291] feat: Add user-provided speculative decoding support #5204
Conversation
Force-pushed from c9c52b4 to 44a9be3
Force-pushed from 44a9be3 to c5c77c4
/bot run
PR_Github #10500 [ run ] triggered by Bot
PR_Github #10500 [ run ] completed with state
/bot run
PR_Github #10510 [ run ] triggered by Bot
PR_Github #10510 [ run ] completed with state
Force-pushed from 0716e0d to f5d428c
/bot run --stage-list "B200_PCIe-PyTorch-1, B200_PCIe-PyTorch-2, B200_PCIe-PyTorch-3"
PR_Github #10588 [ run ] triggered by Bot
PR_Github #10588 [ run ] completed with state
Force-pushed from f5d428c to a339d8a
/bot run
Pull Request Overview

This PR adds support for user-provided speculative decoding by introducing a new config type, wiring it through the API and runtime, and adding corresponding tests.

- Added `UserProvidedDecodingConfig` and `UserProvidedConfig` to allow users to supply their own drafter.
- Extended `SpeculativeDecodingMode`, CLI parsing, and runtime helpers (`get_spec_metadata`, `get_spec_resource_manager`, `get_spec_drafter`) to handle the new mode.
- Updated the Python executor (`py_executor_creator` and `PyExecutor`) to accept and propagate a `drafter`, and added unit tests for the user-provided decoding path.
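In other words, the user-facing contract is "bring your own drafter". The sketch below is illustrative only: the `Drafter` interface and the `prepare_draft_tokens` method name follow this PR's discussion, but the exact TensorRT-LLM signatures and the request structure are assumptions, not the library's real API.

```python
# Illustrative sketch of the user-provided drafter contract described in
# this PR. The interface name and method mirror the PR discussion; the
# dict-based request format below is a stand-in for the real one.
from abc import ABC, abstractmethod


class Drafter(ABC):
    """Interface a user-provided drafter is expected to implement."""

    @abstractmethod
    def prepare_draft_tokens(self, scheduled_requests) -> None:
        """Attach draft tokens to each scheduled request in place."""


class RepeatLastTokenDrafter(Drafter):
    """Toy drafter: proposes the last accepted token, repeated."""

    def __init__(self, max_draft_len: int = 3):
        self.max_draft_len = max_draft_len

    def prepare_draft_tokens(self, scheduled_requests) -> None:
        for req in scheduled_requests:
            last = req["tokens"][-1]
            req["draft_tokens"] = [last] * self.max_draft_len


requests = [{"tokens": [1, 2, 7]}]
RepeatLastTokenDrafter().prepare_draft_tokens(requests)
print(requests[0]["draft_tokens"])  # [7, 7, 7]
```

A real drafter would typically run a cheap draft model or a lookup scheme (e.g. n-gram matching) instead of echoing the last token.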
Reviewed Changes
Copilot reviewed 20 out of 20 changed files in this pull request and generated no comments.
| File | Description |
|---|---|
| `tests/unittest/api_stability/references_committed/llm.yaml` | Updated `speculative_config` annotation to include the new config type |
| `tests/unittest/_torch/test_pytorch_model_engine.py` | Adjusted `super().__init__` call to use keyword args |
| `tensorrt_llm/models/modeling_utils.py` | Added `USER_PROVIDED` flag to `SpeculativeDecodingMode` |
| `tensorrt_llm/llmapi/llm_args.py` | Defined `UserProvidedDecodingConfig` and updated factory mapping |
| `tensorrt_llm/llmapi/llm_utils.py` | Exported new decoding config in build stats |
| `tensorrt_llm/llmapi/__init__.py` | Exposed `UserProvidedDecodingConfig` in public API |
| `tensorrt_llm/_torch/speculative/user_provided.py` | Introduced `UserProvidedConfig` |
| `tensorrt_llm/_torch/speculative/utils.py` | Extended helper functions to support `USER_PROVIDED` |
| `tensorrt_llm/_torch/pyexecutor/py_executor_creator.py` | Plumbed `drafter` into executor creation |
| `tensorrt_llm/_torch/pyexecutor/py_executor.py` | Updated executor API to accept and invoke `drafter` |
| `tests/unittest/_torch/speculative/test_user_provided.py` | Added unit test for user-provided speculative decoding |
| `tests/unittest/_torch/speculative/test_ngram.py` | Aligned CUDA graph field name with updated config |
| `tests/integration/test_lists/waives.txt` | Removed obsolete n-gram skip |
| `tests/integration/test_lists/test-db/l0_b200.yml` | Registered new user-provided test in L0 suite |
Comments suppressed due to low confidence (3)
tests/unittest/_torch/speculative/test_user_provided.py:18

- Expand the test matrix to include cases with `disable_overlap_scheduler=False` (remove the TODO) so that user-provided decoding is verified both with and without overlap scheduling.

`# TODO: add disable_overlap_scheduler=False`
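The expanded matrix the comment asks for might look like the following stand-in sketch. `disable_overlap_scheduler` is the flag named in the PR; `use_cuda_graph` and `run_case` are hypothetical placeholders for the real end-to-end check in `test_user_provided.py`.

```python
# Sketch of covering both overlap-scheduler settings instead of only one.
# run_case is a placeholder: the real test would build an LLM with a
# user-provided drafter and verify generation under each setting.
from itertools import product


def run_case(disable_overlap_scheduler: bool, use_cuda_graph: bool) -> bool:
    # Stand-in for the actual end-to-end check.
    return True


cases = list(product([True, False], [True, False]))
results = [run_case(d, g) for d, g in cases]
print(len(cases))  # 4 combinations, overlap scheduler on and off
```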
tensorrt_llm/_torch/speculative/user_provided.py:25

- [nitpick] Add a docstring or implementation to `update_from_model_config` explaining how the user-provided config should adapt to the model's config (e.g., populating `num_extra_kv_tokens` or validating the drafter).

`def update_from_model_config(self, model_config):`
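One possible shape for that method is sketched below. The field names `num_extra_kv_tokens` and `max_draft_len` are taken from the comment and PR context; the actual `UserProvidedConfig` fields and the `model_config` handling are assumptions for illustration.

```python
# Hedged sketch of a documented update_from_model_config. The fields and
# the policy (reserve one extra KV slot per draft token) are illustrative
# assumptions, not the PR's actual implementation.
from dataclasses import dataclass


@dataclass
class UserProvidedConfig:
    max_draft_len: int = 0
    num_extra_kv_tokens: int = 0

    def update_from_model_config(self, model_config) -> None:
        """Adapt this config to the target model.

        For example, reserve KV-cache slots for the draft tokens so the
        runtime allocates enough space per request; this is also a natural
        place to validate the supplied drafter against the model.
        """
        self.num_extra_kv_tokens = self.max_draft_len


cfg = UserProvidedConfig(max_draft_len=4)
cfg.update_from_model_config(model_config=None)
print(cfg.num_extra_kv_tokens)  # 4
```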
tests/unittest/api_stability/references_committed/llm.yaml:69

- [nitpick] Insert a space after the comma between `EagleDecodingConfig,` and `tensorrt_llm.llmapi.llm_args.LookaheadDecodingConfig` for consistency and readability.
annotation: Union[tensorrt_llm.llmapi.llm_args.DraftTargetDecodingConfig, tensorrt_llm.llmapi.llm_args.EagleDecodingConfig,tensorrt_llm.llmapi.llm_args.LookaheadDecodingConfig, tensorrt_llm.llmapi.llm_args.MedusaDecodingConfig, tensorrt_llm.llmapi.llm_args.MTPDecodingConfig, tensorrt_llm.llmapi.llm_args.NGramDecodingConfig, tensorrt_llm.llmapi.llm_args.UserProvidedDecodingConfig, NoneType]
/bot run
PR_Github #10791 [ run ] triggered by Bot
PR_Github #10791 [ run ] completed with state
- LGTM!
- Good improvement on the Drafter construction (from "construct in resource manager and pass to `get_spec_drafter()`" to "construct in `get_spec_drafter()` and pass to resource manager"), making the logic clearer.
- Remove the input argument `state` from `Drafter.prepare_draft_tokens()`. It was kept for the overlap scheduler, and we do not need it at this moment.
- Maybe we need a documentation update (in a later PR). For example, users need to implement their own `prepare_draft_tokens()` in `class UserProvidedConfig`.
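To illustrate the direction discussed here (dropping the `state`/`SampleState` argument from `prepare_draft_tokens()`), below is a toy model of the executor-drafter handshake. The names mirror the PR (`PyExecutor`, `prepare_draft_tokens`), but the request structure and the n-gram lookup table are invented for the example.

```python
# Toy model of how the executor invokes a user-provided drafter after the
# signature cleanup: prepare_draft_tokens takes only the scheduled requests.
class NGramDrafter:
    """Stand-in drafter keyed on the last token (n-gram-style lookup)."""

    def __init__(self, table):
        self.table = table  # last_token -> draft continuation

    def prepare_draft_tokens(self, scheduled_requests):  # no SampleState
        for req in scheduled_requests:
            req["draft_tokens"] = self.table.get(req["tokens"][-1], [])


class PyExecutor:
    """Minimal sketch: hold an optional drafter and call it each step."""

    def __init__(self, drafter=None):
        self.drafter = drafter

    def executor_step(self, scheduled_requests):
        if self.drafter is not None:
            self.drafter.prepare_draft_tokens(scheduled_requests)
        return scheduled_requests


executor = PyExecutor(drafter=NGramDrafter({7: [8, 9]}))
batch = executor.executor_step([{"tokens": [1, 7]}])
print(batch[0]["draft_tokens"])  # [8, 9]
```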
- Updated the `create_py_executor` function to pass parameters as keyword arguments for better readability.
- Modified the `PyTorchModelEngine` instantiation to use named parameters for `model_path` and `pytorch_backend_config`.
- Added a keyword-only argument to `create_py_executor_instance` for consistency in parameter handling.

Signed-off-by: Robin Kobus <[email protected]>

- Updated the `prepare_draft_tokens` method in the `Drafter` and `NGramDrafter` classes to remove the unused `SampleState` parameter.
- Adjusted the `PyExecutor` class to call `prepare_draft_tokens` without the `SampleState` argument.

Signed-off-by: Robin Kobus <[email protected]>

- Removed the `NGramSpecMetadata` class, consolidating its functionality into `SpecMetadata`.
- Updated imports in `ngram.py` and `utils.py` to reflect the removal of `NGramSpecMetadata`.
- Added a placeholder `__post_init__` method in `SpecMetadata` as a default implementation.

Signed-off-by: Robin Kobus <[email protected]>

- Removed a `test_ngram` test case from `waives.txt`.
- Updated the batch sizes parameter in the `CudaGraphConfig` in `test_ngram.py`.

Signed-off-by: Robin Kobus <[email protected]>

- Introduced `UserProvidedConfig` for user-defined speculative decoding configurations.
- Updated `SpeculativeDecodingMode` to include the `USER_PROVIDED` option.
- Enhanced the `get_spec_resource_manager` and `get_spec_drafter` functions to handle a user-provided drafter.
- Modified `PyExecutor` and related classes to accommodate the new user-provided decoding feature.
- Added unit tests to validate user-provided decoding functionality.

Signed-off-by: Robin Kobus <[email protected]>
Signed-off-by: Robin Kobus <[email protected]>
Force-pushed from a339d8a to 0985212
/bot run
PR_Github #11012 [ run ] triggered by Bot
PR_Github #11012 [ run ] completed with state
/bot run
PR_Github #11019 [ run ] triggered by Bot
PR_Github #11019 [ run ] completed with state
Only have nitpicks, thanks
…VIDIA#5204)

Signed-off-by: Robin Kobus <[email protected]>
Signed-off-by: Yuxin <[email protected]>
Description
Test Coverage
GitHub Bot Help

`/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...`

Provide a user-friendly way for developers to interact with a Jenkins server.

Run `/bot [-h|--help]` to print this help message. See details below for each supported subcommand.

run

`run [--disable-fail-fast --skip-test --stage-list "A10-1, xxx" --gpu-type "A30, H100_PCIe" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-[Post-Merge]-1, xxx"]`

Launch build/test pipelines. All previously running jobs will be killed.

- `--disable-fail-fast` (OPTIONAL): Disable fail fast on build/tests/infra failures.
- `--skip-test` (OPTIONAL): Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.
- `--stage-list "A10-1, xxx"` (OPTIONAL): Only run the specified test stages. Examples: "A10-1, xxx". Note: Does NOT update GitHub check status.
- `--gpu-type "A30, H100_PCIe"` (OPTIONAL): Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.
- `--only-multi-gpu-test` (OPTIONAL): Only run the multi-GPU tests. Note: Does NOT update GitHub check status.
- `--disable-multi-gpu-test` (OPTIONAL): Disable the multi-GPU tests. Note: Does NOT update GitHub check status.
- `--add-multi-gpu-test` (OPTIONAL): Force run the multi-GPU tests. Will also run the L0 pre-merge pipeline.
- `--post-merge` (OPTIONAL): Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.
- `--extra-stage "H100_PCIe-[Post-Merge]-1, xxx"` (OPTIONAL): Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-[Post-Merge]-1, xxx".

For guidance on mapping tests to stage names, see `docs/source/reference/ci-overview.md`.

kill

`kill`

Kill all running builds associated with the pull request.

skip

`skip --comment COMMENT`

Skip testing for the latest commit on the pull request. `--comment "Reason for skipping build/test"` is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline

`reuse-pipeline`

Reuse a previous pipeline to validate the current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.