
Commit ba7ee74

[TRTLLM-6392][feat] Support turning on/off spec decoding dynamically
Signed-off-by: ziyixiong-nv <[email protected]>
1 parent 470544c commit ba7ee74

12 files changed: +246 -82 lines changed

tensorrt_llm/_torch/models/modeling_speculative.py

Lines changed: 7 additions & 3 deletions
@@ -130,8 +130,9 @@ def forward(
         # PyExecutor will extract these from the draft model engine's spec metadata.
         # They will be passed to the draft model engine on the next iteration.
         # TODO: can we support multiple model outputs instead?
-        spec_metadata.maybe_capture_hidden_states(self.layer_idx, hidden_states,
-                                                  residual)
+        if spec_metadata is not None:
+            spec_metadata.maybe_capture_hidden_states(self.layer_idx,
+                                                      hidden_states, residual)
         return hidden_states, residual

@@ -249,6 +250,9 @@ def forward(
         hidden_states: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> torch.Tensor:
+        if spec_metadata is None:
+            return None
+
         hidden_states = self.apply_eagle3_fc(spec_metadata.get_hidden_states())
         output, _ = self.model(
             input_ids=input_ids,

@@ -380,7 +384,7 @@ def forward(
             **kwargs,
         )

-        if self.draft_model is not None:
+        if self.draft_model is not None and spec_metadata is not None:
             # get logits
             logits = self.logits_processor.forward(
                 hidden_states[spec_metadata.gather_ids],
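
The guards above let the same forward pass run whether or not speculative decoding is active on a given iteration: when spec decode is dynamically switched off, no spec metadata is produced and the draft path becomes a no-op. A minimal sketch of the pattern, with hypothetical class and argument names rather than the real TensorRT-LLM modules:

    from typing import Optional

    class ToyDraftHead:
        """Stand-in for a draft head whose forward tolerates disabled spec decode."""

        def forward(self, hidden_states, spec_metadata: Optional[object]):
            # With dynamic spec decoding, spec_metadata may be None on iterations
            # where drafting is switched off; skip the draft computation entirely.
            if spec_metadata is None:
                return None
            return hidden_states  # the real model would run the draft layers here

    head = ToyDraftHead()
    assert head.forward([0.1, 0.2], spec_metadata=None) is None          # spec decode off
    assert head.forward([0.1, 0.2], spec_metadata=object()) is not None  # spec decode on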

tensorrt_llm/_torch/pyexecutor/llm_request.py

Lines changed: 14 additions & 0 deletions
@@ -477,3 +477,17 @@ def executor_request_to_llm_request(
         py_multimodal_data=getattr(executor_request, "py_multimodal_data",
                                    None))
     return llm_request
+
+
+def get_draft_token_length(request: LlmRequest) -> int:
+    """Get the length of draft tokens for a given request.
+
+    Args:
+        request: The LlmRequest to get draft token length for
+
+    Returns:
+        The number of draft tokens, or 0 if no draft tokens exist
+    """
+    if request.py_draft_tokens is not None:
+        return len(request.py_draft_tokens)
+    return 0
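
For illustration, the helper's behavior on a stub request object (a standalone sketch; the real code passes an LlmRequest):

    from types import SimpleNamespace

    def get_draft_token_length(request) -> int:
        # Same logic as the helper added above: a missing draft-token list counts as 0.
        if request.py_draft_tokens is not None:
            return len(request.py_draft_tokens)
        return 0

    print(get_draft_token_length(SimpleNamespace(py_draft_tokens=[7, 8, 9])))  # 3
    print(get_draft_token_length(SimpleNamespace(py_draft_tokens=None)))       # 0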

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 100 additions & 63 deletions
Large diffs are not rendered by default.

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 6 additions & 2 deletions
@@ -826,6 +826,10 @@ def _executor_loop(self):
                 self._pad_attention_dp_dummy_request()

                 if self.drafter is not None:
+                    self.use_spec_decode = self.drafter.should_use_spec_decode(
+                        self.active_requests)
+                    self.model_engine.use_runtime_spec_decode(
+                        self.use_spec_decode)
                     self._prepare_draft_requests(self.active_requests)

                 scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule(

@@ -868,7 +872,7 @@ def _executor_loop(self):
                     self._handle_first_token_response(scheduled_batch)

                     self.resource_manager.prepare_resources(scheduled_batch)
-                    if self.drafter is not None:
+                    if self.drafter is not None and self.use_spec_decode:
                         self.drafter.prepare_draft_tokens(
                             scheduled_batch, self.resource_manager)

@@ -924,7 +928,7 @@ def _prepare_draft_requests(self, requests):
                 req.py_last_draft_tokens = req.py_draft_tokens
                 max_draft_len = self.model_engine.spec_config.max_draft_len

-                if max_draft_len > 0:
+                if max_draft_len > 0 and self.use_spec_decode:
                     req.py_draft_tokens = [0] * max_draft_len
                     req.py_draft_pages_allocated = max_draft_len
                 else:
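
Taken together, the executor now re-evaluates the toggle on every iteration: it asks the drafter whether spec decode is worthwhile for the current active requests, propagates the decision to the model engine, and only prepares and generates draft tokens when the answer is yes. A simplified, self-contained sketch of that control flow (toy classes, not the real executor):

    class _ToyDrafter:
        """Stand-in drafter with a hypothetical policy: only draft when the batch is small."""

        def should_use_spec_decode(self, requests) -> bool:
            return len(requests) <= 4

        def prepare_draft_tokens(self, requests) -> None:
            print(f"drafting tokens for {len(requests)} requests")

    class _ToyEngine:
        def use_runtime_spec_decode(self, enabled: bool) -> None:
            print(f"spec decode {'on' if enabled else 'off'} this iteration")

    def run_iteration(drafter, engine, active_requests) -> None:
        # Mirrors the new executor-loop logic: decide per iteration, tell the
        # engine, and only draft when the toggle is on.
        use_spec_decode = drafter.should_use_spec_decode(active_requests)
        engine.use_runtime_spec_decode(use_spec_decode)
        if use_spec_decode:
            drafter.prepare_draft_tokens(active_requests)

    run_iteration(_ToyDrafter(), _ToyEngine(), active_requests=["req"] * 2)  # drafts
    run_iteration(_ToyDrafter(), _ToyEngine(), active_requests=["req"] * 8)  # skips drafting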

tensorrt_llm/_torch/pyexecutor/resource_manager.py

Lines changed: 5 additions & 4 deletions
@@ -15,7 +15,8 @@
 from ..._utils import binding_dtype_size, nvtx_range
 from ...logger import logger
 from ...mapping import Mapping
-from .llm_request import LlmRequest, LlmRequestState, SamplingConfig
+from .llm_request import (LlmRequest, LlmRequestState, SamplingConfig,
+                          get_draft_token_length)
 from .scheduler import ScheduledRequests

 if ENABLE_MULTI_DEVICE:

@@ -353,12 +354,12 @@ def prepare_resources(self, scheduled_batch: ScheduledRequests):
                                              req_beam_width, req)
                 for _ in range(self.num_extra_kv_tokens):
                     self.impl.add_token(req.py_request_id)
-                for _ in range(len(req.py_draft_tokens)):
-                    self.impl.add_token(req.py_request_id)
+                for _ in range(get_draft_token_length(req)):
+                    self.impl.add_token(req.py_request_id)

         for req in generation_batch:
             self.impl.add_token(req.py_request_id)
-            for _ in range(len(req.py_draft_tokens)):
+            for _ in range(get_draft_token_length(req)):
                 self.impl.add_token(req.py_request_id)

     def add_dummy_requests(

tensorrt_llm/_torch/pyexecutor/sampler.py

Lines changed: 3 additions & 3 deletions
@@ -23,7 +23,7 @@
 from tensorrt_llm.mapping import Mapping

 from .finish_reason import FinishedState
-from .llm_request import LlmRequest, LlmRequestState
+from .llm_request import LlmRequest, LlmRequestState, get_draft_token_length
 from .scheduler import ScheduledRequests

@@ -337,7 +337,7 @@ def update_requests(self, state: SampleState) -> None:
             new_token = add_token(req, new_tokens, beam=self.BEAM)
             stop = self._handle_stop_criteria(req, new_token)
             processed = 1
-            if not stop and len(req.py_draft_tokens) > 0:
+            if not stop and get_draft_token_length(req) > 0:
                 num_accepted = self.process_draft_tokens(
                     req, new_tokens, new_token)
                 req.py_num_accepted_draft_tokens = num_accepted

@@ -401,7 +401,7 @@ def _process_requests(self,
         beam_width = self.MAX_BEAM_WIDTH
         beam = self.BEAM
         raw_logits = model_outputs["logits"]
-        num_steps = [1 + len(req.py_draft_tokens) for req in requests]
+        num_steps = [1 + get_draft_token_length(req) for req in requests]
         sum_steps = sum(num_steps)
         no_draft_tokens = len(requests) == sum_steps
         fast_path = not self.enable_mixed_sampler and no_draft_tokens and gen_logits_host is None and log_probs_host is None

tensorrt_llm/_torch/pyexecutor/scheduler.py

Lines changed: 2 additions & 2 deletions
@@ -5,7 +5,7 @@
 from tensorrt_llm.bindings import executor as tb_executor
 from tensorrt_llm.bindings import internal as tb_internal

-from .llm_request import LlmRequest, LlmRequestState
+from .llm_request import LlmRequest, LlmRequestState, get_draft_token_length

 RequestList = list[LlmRequest]

@@ -185,7 +185,7 @@ def schedule(
             self, active_requests: RequestList, inflight_request_ids: set[int]
     ) -> tuple[list[LlmRequest], list[LlmRequest]]:
         for request in active_requests:
-            if len(request.py_draft_tokens) > 0:
+            if get_draft_token_length(request) > 0:
                 request.draft_tokens = request.py_draft_tokens
         return self.impl(active_requests, inflight_request_ids,
                          self.max_batch_size, self.max_num_tokens)

tensorrt_llm/_torch/speculative/drafter.py

Lines changed: 7 additions & 1 deletion
@@ -1,6 +1,7 @@
 from abc import ABC, abstractmethod
-from typing import Optional
+from typing import List, Optional

+from ..pyexecutor.llm_request import LlmRequest
 from ..pyexecutor.resource_manager import ResourceManager
 from ..pyexecutor.scheduler import ScheduledRequests

@@ -21,3 +22,8 @@ def prepare_draft_tokens(
             scheduled_requests: The scheduled requests for this iteration
         """
         raise NotImplementedError
+
+    @abstractmethod
+    def should_use_spec_decode(self, requests: List[LlmRequest]) -> bool:
+        """Check if spec decode should be used for the current iteration."""
+        raise NotImplementedError
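
Every concrete drafter must now implement should_use_spec_decode; the two implementations touched by this commit (model_drafter.py and ngram.py, below) simply return True. The hook is what enables dynamic policies, for example a hypothetical subclass that skips drafting for large batches. A sketch, assuming the abstract base class in drafter.py is named Drafter and paraphrasing the prepare_draft_tokens signature; this subclass is not part of the commit:

    from typing import List, Optional

    from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequest
    from tensorrt_llm._torch.pyexecutor.resource_manager import ResourceManager
    from tensorrt_llm._torch.pyexecutor.scheduler import ScheduledRequests
    from tensorrt_llm._torch.speculative.drafter import Drafter  # assumed class name

    class BatchSizeGatedDrafter(Drafter):
        """Hypothetical policy: disable spec decoding once the batch is large."""

        def __init__(self, max_concurrency: int = 8):
            self.max_concurrency = max_concurrency

        def should_use_spec_decode(self, requests: List[LlmRequest]) -> bool:
            # Drafting helps most at low concurrency; fall back to plain decoding otherwise.
            return len(requests) <= self.max_concurrency

        def prepare_draft_tokens(
                self,
                scheduled_requests: ScheduledRequests,
                resource_manager: Optional[ResourceManager] = None) -> None:
            ...  # draft-token generation elided in this sketch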

tensorrt_llm/_torch/speculative/model_drafter.py

Lines changed: 6 additions & 4 deletions
@@ -8,7 +8,8 @@
 from tensorrt_llm._utils import nvtx_range
 from tensorrt_llm.logger import logger

-from ..pyexecutor.llm_request import LlmRequest, LlmRequestState, SamplingConfig
+from ..pyexecutor.llm_request import (LlmRequest, LlmRequestState,
+                                      SamplingConfig, get_draft_token_length)
 from ..pyexecutor.resource_manager import BaseResourceManager, ResourceManager
 from ..pyexecutor.sampler import Sampler, SampleState
 from ..pyexecutor.scheduler import ScheduledRequests

@@ -59,7 +60,6 @@ def __init__(
         # Configuration
         self.spec_config = spec_config
         self.max_draft_tokens = max_draft_tokens
-
         # Sampling
         self.sampler = sampler

@@ -214,7 +214,6 @@ def _prepare_draft_batch(
                 if request.py_draft_pages_allocated == 0:
                     # No space for draft tokens
                     continue
-
                 # Stop drafting when we hit the max seqlen. We still need dummy draft
                 # tokens attached to the requests to make sure everything works properly
                 # with CUDA graph. These dummy tokens are already added by

@@ -320,7 +319,7 @@ def _pad_to_max_draft_tokens(self,
         """Pad draft tokens to maximum length for all generation requests."""
         for req in scheduled_requests.generation_requests:
             max_draft_tokens = self.max_draft_tokens
-            num_draft_tokens = len(req.py_draft_tokens)
+            num_draft_tokens = get_draft_token_length(req)
             req.py_draft_tokens.extend(
                 0 for _ in range(max_draft_tokens - num_draft_tokens))

@@ -399,3 +398,6 @@ def prepare_draft_tokens(
             error_msg = str(e)
             logger.error(f"Encountered an error in decode: {error_msg}")
             raise e
+
+    def should_use_spec_decode(self, requests: List[LlmRequest]) -> bool:
+        return True

tensorrt_llm/_torch/speculative/ngram.py

Lines changed: 4 additions & 0 deletions
@@ -1,4 +1,5 @@
 from itertools import chain
+from typing import List

 from ordered_set import OrderedSet

@@ -198,3 +199,6 @@ def prepare_draft_tokens(
             pad_length = self.max_draft_len - len(draft_tokens)
             draft_tokens.extend([request.py_end_id] * pad_length)
             request.py_draft_tokens = draft_tokens
+
+    def should_use_spec_decode(self, requests: List[LlmRequest]) -> bool:
+        return True
