198 changes: 5 additions & 193 deletions tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -11,7 +11,7 @@
import weakref
from collections import deque, namedtuple
from contextlib import contextmanager
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Union

import torch

@@ -308,7 +308,7 @@ def __init__(self,
if is_trace_enabled("TLLM_TRACE_EXECUTOR_LOOP"):
self.event_loop = trace_func(self.event_loop)

-if self.draft_model_engine is not None:
+if self.drafter is not None:
if self.event_loop.__name__ != self._executor_loop.__name__:
raise NotImplementedError(
"Drafting is not supported for selected executor loop. "
@@ -905,10 +905,6 @@ def _executor_loop_pp(self):

def _executor_loop(self):
torch.cuda.set_device(self.device_id)
-is_ngram = hasattr(
-    self.model_engine, "spec_config"
-) and self.model_engine.spec_config is not None and self.model_engine.spec_config.spec_dec_mode.is_ngram(
-)
with self._profiler() as profile_step:
sample_state = None
iter_start_time = time.time()
@@ -931,7 +927,7 @@

self._pad_attention_dp_dummy_request()

-if self.draft_model_engine is not None or is_ngram or self.drafter is not None:
+if self.drafter is not None:
self._prepare_draft_requests(self.active_requests)

scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule(
@@ -971,11 +967,9 @@
scheduled_batch)

self.resource_manager.prepare_resources(scheduled_batch)
-if self.draft_model_engine is not None:
-    self._prepare_draft_tokens(scheduled_batch)
-
if self.drafter is not None:
-    self.drafter.prepare_draft_tokens(scheduled_batch)
+    self.drafter.prepare_draft_tokens(
+        scheduled_batch, self.resource_manager)

if self.kv_cache_transceiver:
# For generation requests which have completed KV cache transfer
@@ -1798,188 +1792,6 @@ def _update_requests(self, sample_state: SampleState):
logger.error(f"Encountered an error in sampling: {error_msg}")
self._handle_errors(error_msg)

@nvtx_range("_prepare_draft_batch")
def _prepare_draft_batch(
self, scheduled_requests: ScheduledRequests
) -> Tuple[ScheduledRequests, Dict[int, LlmRequest]]:
"""
Prepares a batch for the draft model engine. Draft tokens are only produced
for generation requests.

The requests are prepared as follows:
1. The first time the draft engine sees a request, it's a context request.
2. Otherwise, if draft tokens were accepted on the last target model decoding
step, it's a chunked context request (we process all the accepted tokens together).
3. Otherwise, it's a generation request.
"""
try:
draft_batch = ScheduledRequests()

for request in scheduled_requests.generation_requests:
if request.py_draft_pages_allocated == 0:
# No space for draft tokens.
continue

# Stop drafting when we hit the max seqlen. We still need dummy draft
# tokens attached to the requests to make sure everything works properly
# with CUDA graph. These dummy tokens are already added by
# _prepare_draft_requests to make the KV cache/scheduler aware of the fact
# that we want to do spec decoding, so no need to do anything else here.
# This makes the perf for this case suboptimal, but that's OK - this is
# a corner case for weird models like the llama 3.1 8b EAGLE3 implementation.
if request.max_beam_num_tokens - 1 >= self.draft_model_engine.max_seq_len:
continue

num_draft_tokens = len(
request.py_last_draft_tokens
) if request.py_last_draft_tokens is not None else 0
request.py_draft_tokens = []

num_accepted_tokens = request.py_num_accepted_draft_tokens
num_rejected_tokens = num_draft_tokens - num_accepted_tokens
assert num_rejected_tokens >= 0

spec_config = self.model_engine.spec_config
beam_idx = 0
input_tokens = spec_config.get_draft_model_prompt(
request.get_tokens()[beam_idx])

def create_new_request(input_tokens):
return LlmRequest(
request_id=request.py_request_id,
max_new_tokens=request.py_max_new_tokens,
input_tokens=input_tokens,
sampling_config=request.sampling_config,
return_perf_metrics=request.return_perf_metrics,
is_streaming=False,
is_draft=True)

if request.max_beam_num_tokens - 1 == request.py_prompt_len:
# This is the first time the draft model is seeing this request.
# Prepare a context request. We discard the first token and take
# the newly decoded one - this is the convention for EAGLE 2 and 3.
new_request = create_new_request(input_tokens)
draft_batch.context_requests.append(new_request)
elif num_accepted_tokens == 0:
new_request = create_new_request(input_tokens[:-1])
# Explicitly add the last token so get_last_tokens() returns
# the right value
new_request.add_new_token(input_tokens[-1], beam_idx)
new_request.state = LlmRequestState.GENERATION_IN_PROGRESS
draft_batch.generation_requests.append(new_request)
else:
new_request = create_new_request(input_tokens)
new_request.context_chunk_size = num_accepted_tokens + 1
new_request.context_current_position = len(
input_tokens) - num_accepted_tokens - 1
new_request.context_chunk_size = num_accepted_tokens + 1
new_request.context_current_position = len(
input_tokens) - num_accepted_tokens - 1

draft_batch.context_requests.append(new_request)

new_request.py_stop_words_list = request.py_stop_words_list

return draft_batch

except Exception as e:
traceback.print_exc()
error_msg = str(e)
logger.error(f"Encountered an error in decode: {error_msg}")
self._handle_errors(error_msg)

@nvtx_range("_prepare_draft_tokens")
def _prepare_draft_tokens(self, scheduled_requests: ScheduledRequests):
if not self.draft_model_engine:
raise ValueError("Draft model engine is not set")

try:
draft_batch = self._prepare_draft_batch(scheduled_requests)

if draft_batch.batch_size == 0:
return
self.draft_seq_slot_manager.prepare_resources(draft_batch)

req_id_to_old_request = {
req.py_request_id: req
for req in scheduled_requests.all_requests()
}

# Disable cuda graph for the 1st draft model forward
if self.model_engine.spec_config.spec_dec_mode.needs_kv_cache_recompute(
):
with self.draft_model_engine.no_cuda_graph():
outputs = self.draft_model_engine.forward(
draft_batch, self.resource_manager)
else:
outputs = self.draft_model_engine.forward(
draft_batch, self.resource_manager)
if hasattr(self.draft_model_engine.model.model, 'd2t'):
outputs['d2t'] = self.draft_model_engine.model.model.d2t.data

sample_state = self._sample_async(draft_batch, outputs)
previous_batch = sample_state

self._update_request_states(draft_batch)

def _process_decoded_tokens(draft_batch):
new_requests = []
for req in draft_batch.all_requests():
target_model_req = req_id_to_old_request[req.py_request_id]
target_model_req.py_draft_tokens.append(
req.get_last_tokens(0))
if req.state != LlmRequestState.GENERATION_COMPLETE and len(
target_model_req.py_draft_tokens
) < target_model_req.py_draft_pages_allocated:
new_requests.append(req)
else:
self.draft_seq_slot_manager.free_resources(req)

return new_requests

# The TRTLLM attention kernels cannot handle generation requests with
# different seqlens. No issues with flashinfer, should we look into removing
# this? Just needs proper kernel support.
def _pad_to_max_draft_tokens():
for req in scheduled_requests.generation_requests:
max_draft_len = self.max_draft_len
num_draft_tokens = len(req.py_draft_tokens)
req.py_draft_tokens.extend(
0 for _ in range(max_draft_len - num_draft_tokens))

draft_batch.generation_requests = draft_batch.context_requests + draft_batch.generation_requests
draft_batch.context_requests = []

for i in range(self.max_draft_len - 1):
if len(draft_batch.generation_requests) == 0:
break

outputs = self.draft_model_engine.forward(
draft_batch,
self.resource_manager,
new_tensors_device=previous_batch.device)

if hasattr(self.draft_model_engine.model.model, 'd2t'):
outputs[
'd2t'] = self.draft_model_engine.model.model.d2t.data
sample_state = self._sample_async(draft_batch, outputs)
self._update_request_states(draft_batch)
self._update_requests(previous_batch)
new_requests = _process_decoded_tokens(
previous_batch.scheduled_requests)
draft_batch.generation_requests = new_requests
previous_batch = sample_state
self._update_requests(previous_batch)
new_requests = _process_decoded_tokens(
previous_batch.scheduled_requests)
_pad_to_max_draft_tokens()

except Exception as e:
traceback.print_exc()
error_msg = str(e)
logger.error(f"Encountered an error in decode: {error_msg}")
self._handle_errors(error_msg)

def _handle_errors(self, error_msg: Optional[str] = None):
error_responses = {}
error_msg = error_msg or "error"
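For reference, the branching that the deleted _prepare_draft_batch applied to each generation request (see its docstring above) can be condensed as the sketch below. _ReqView and classify_for_draft are hypothetical names used only for illustration; the field names and thresholds follow the removed code.

from dataclasses import dataclass
from typing import Optional


@dataclass
class _ReqView:
    """Minimal stand-in for the LlmRequest fields the deleted logic reads."""
    py_prompt_len: int
    max_beam_num_tokens: int
    py_num_accepted_draft_tokens: int
    py_draft_pages_allocated: int


def classify_for_draft(req: _ReqView, draft_max_seq_len: int) -> Optional[str]:
    """Condensed restatement of the branching in the removed _prepare_draft_batch."""
    if req.py_draft_pages_allocated == 0:
        return None  # no KV pages reserved for draft tokens
    if req.max_beam_num_tokens - 1 >= draft_max_seq_len:
        return None  # at the draft engine's max seqlen; keep only dummy tokens
    if req.max_beam_num_tokens - 1 == req.py_prompt_len:
        return "context"          # first time the draft model sees this request
    if req.py_num_accepted_draft_tokens == 0:
        return "generation"       # no draft tokens were accepted last step
    return "chunked_context"      # process all accepted tokens together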
3 changes: 2 additions & 1 deletion tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
@@ -382,7 +382,8 @@ def create_py_executor(

# Drafter for speculative decoding
with mem_monitor.observe_creation_stage(_ExecutorCreationStage.DRAFTER):
-drafter = get_spec_drafter(model_engine, spec_resource_manager)
+drafter = get_spec_drafter(model_engine, draft_model_engine, sampler,
+    spec_resource_manager)

with mem_monitor.observe_creation_stage(
_ExecutorCreationStage.INIT_EXTRA_RESOURCES
7 changes: 7 additions & 0 deletions tensorrt_llm/_torch/speculative/drafter.py
@@ -1,16 +1,23 @@
from abc import ABC, abstractmethod
from typing import Optional

from ..pyexecutor.resource_manager import ResourceManager
from ..pyexecutor.scheduler import ScheduledRequests


class Drafter(ABC):
"""Abstract base class for all drafter implementations."""

@abstractmethod
def prepare_draft_tokens(
self,
scheduled_requests: ScheduledRequests,
resource_manager: Optional[ResourceManager] = None,
) -> None:
"""
        Prepare the draft tokens for the forward computation this step.

Args:
scheduled_requests: The scheduled requests for this iteration
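            resource_manager: Optional resource manager for this iteration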
"""
raise NotImplementedError
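For orientation, a minimal sketch of a concrete drafter against this widened interface is shown below. EchoLastTokenDrafter and its token-proposal rule are hypothetical; only the method signature and the attribute/method names on the request objects follow the code in this PR.

from typing import Optional

from tensorrt_llm._torch.pyexecutor.resource_manager import ResourceManager
from tensorrt_llm._torch.pyexecutor.scheduler import ScheduledRequests
from tensorrt_llm._torch.speculative.drafter import Drafter


class EchoLastTokenDrafter(Drafter):
    """Toy drafter: proposes the request's last token as every draft token."""

    def __init__(self, max_draft_len: int):
        self.max_draft_len = max_draft_len

    def prepare_draft_tokens(
        self,
        scheduled_requests: ScheduledRequests,
        resource_manager: Optional[ResourceManager] = None,
    ) -> None:
        # resource_manager is accepted but unused here; a model-based drafter
        # would use it to run its draft engine against the shared KV cache.
        for request in scheduled_requests.generation_requests:
            last_token = request.get_last_tokens(0)
            request.py_draft_tokens = [last_token] * self.max_draft_len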