@@ -32,8 +32,8 @@
 from ..sampling_params import SamplingParams
 from ..scheduling_params import SchedulingParams
 from .llm_args import (TORCH_LLMARGS_EXPLICIT_DOCSTRING,
-                       TRT_LLMARGS_EXPLICIT_DOCSTRING, NGramDecodingConfig,
-                       PeftCacheConfig, PybindMirror, TorchLlmArgs, TrtLlmArgs)
+                       TRT_LLMARGS_EXPLICIT_DOCSTRING, PeftCacheConfig,
+                       PybindMirror, TorchLlmArgs, TrtLlmArgs)
 from .llm_utils import (CachedModelLoader, KvCacheRetentionConfig,
                         LlmBuildStats, ModelLoader, _ModelRuntimeContext)
 from .mpi_session import MpiPoolSession, external_mpi_comm_available
@@ -1015,32 +1015,10 @@ def _build_model(self):
 
         spec_config = self.args.speculative_config
         max_batch_size = self._executor_config.max_batch_size
-        # Apply default heuristic to AutoDecodingConfig based on benchmark results
-        # With concurrency <= 4, max_draft_len = 5, max_matching_ngram_size = 3
-        # With concurrency <= 32, max_draft_len = 3, max_matching_ngram_size = 5
-        # With concurrency > 32, speculative decoding is disabled.
-        if spec_config is not None and spec_config.decoding_type == "AUTO":
-            if not self.args.disable_overlap_scheduler:
-                logger.info(
-                    "Disable overlap scheduler to enable Auto speculative decoding with Ngram."
-                )
-                # From benchmark results, we found that NGram speculative decoding provides better performance than overlap scheduler with low concurrency <= 32.
-                # Therefore, we disable overlap scheduler to enable NGram speculative decoding.
-                self.args.disable_overlap_scheduler = True
-
-            spec_config = NGramDecodingConfig(
-                max_draft_len=5 if max_batch_size <= 4 else 3,
-                max_matching_ngram_size=3 if max_batch_size <= 4 else 5,
-                is_keep_all=True,
-                is_use_oldest=True,
-                is_public_pool=True,
-                # Flag to indicate the NGramDecodingConfig is instantiated by auto heuristic.
-                is_auto_heuristic=True,
-            )
 
-            logger.info(
-                f"Apply heuristic to incomplete NGramDecodingConfig: max_draft_len={spec_config.max_draft_len}, max_matching_ngram_size={spec_config.max_matching_ngram_size}"
-            )
+        if spec_config is not None and spec_config.decoding_type == "AUTO":
+            from tensorrt_llm._torch.speculative import suggest_spec_config
+            spec_config = suggest_spec_config(max_batch_size)
 
         update_executor_config(
             self._executor_config,
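The change collapses the inline AUTO heuristic into a shared helper. For context, here is a rough sketch of what suggest_spec_config plausibly does, reconstructed from the deleted comments and the removed NGramDecodingConfig call above; the actual implementation in tensorrt_llm._torch.speculative may differ (for instance in how it handles concurrency above 32), and the import path and signature below are assumptions:

    from tensorrt_llm.llmapi.llm_args import NGramDecodingConfig

    def suggest_spec_config(max_batch_size: int) -> NGramDecodingConfig:
        """Sketch: map expected concurrency to NGram speculation settings.

        Mirrors the deleted inline heuristic: with concurrency <= 4, draft
        aggressively (max_draft_len=5) with short n-gram matching; up to 32,
        draft conservatively (max_draft_len=3) with longer matching.
        """
        return NGramDecodingConfig(
            max_draft_len=5 if max_batch_size <= 4 else 3,
            max_matching_ngram_size=3 if max_batch_size <= 4 else 5,
            # Pool-management flags copied verbatim from the removed inline config.
            is_keep_all=True,
            is_use_oldest=True,
            is_public_pool=True,
        )

Per the deleted comment, the real helper may also disable speculation entirely once concurrency exceeds 32, so treat this sketch as illustrative rather than a drop-in replacement.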