Skip to content

Commit 069721d

Browse files
committed
[TRTLLM-6352][feat] Migrate EAGLE3 and draft/target speculation to Drafter
Signed-off-by: ziyixiong-nv <[email protected]>
1 parent 64aec14 commit 069721d

File tree

5 files changed

+443
-25
lines changed

5 files changed

+443
-25
lines changed

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
from tensorrt_llm.logger import logger
3131

3232
from ..distributed import Distributed
33-
from ..speculative.drafter import Drafter
33+
from ..speculative.drafter import Drafter, create_drafter
3434
from .kv_cache_transceiver import KvCacheTransceiver
3535
from .llm_request import (ExecutorRequest, LlmRequest, LlmRequestState,
3636
LlmResponse, executor_request_to_llm_request)
@@ -959,8 +959,20 @@ def _executor_loop(self):
959959
scheduled_batch)
960960

961961
self.resource_manager.prepare_resources(scheduled_batch)
962-
if self.draft_model_engine is not None:
963-
self._prepare_draft_tokens(scheduled_batch)
962+
if self.draft_model_engine is not None and self.drafter is None:
963+
spec_resource_manager = self.resource_manager.get_resource_manager(
964+
ResourceManagerType.SPEC_RESOURCE_MANAGER)
965+
self.drafter = create_drafter(
966+
spec_decoding_mode=self.model_engine.spec_config.
967+
spec_dec_mode,
968+
spec_config=self.model_engine.spec_config,
969+
draft_model_engine=self.draft_model_engine,
970+
max_draft_tokens=self.max_draft_tokens,
971+
draft_seq_slot_manager=self.draft_seq_slot_manager,
972+
sampler=self.sampler,
973+
resource_manager=self.resource_manager,
974+
spec_resource_manager=spec_resource_manager,
975+
)
964976

965977
if self.drafter is not None:
966978
self.drafter.prepare_draft_tokens(scheduled_batch)

tensorrt_llm/_torch/speculative/drafter.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
from abc import ABC, abstractmethod
22

33
from ..pyexecutor.scheduler import ScheduledRequests
4+
from .interface import SpeculativeDecodingMode
45

56

67
class Drafter(ABC):
8+
"""Abstract base class for all drafter implementations."""
79

810
@abstractmethod
911
def prepare_draft_tokens(
@@ -12,5 +14,36 @@ def prepare_draft_tokens(
1214
) -> None:
1315
"""
1416
Prepare the drafter tokens for the forward computation this step.
17+
18+
Args:
19+
scheduled_requests: The scheduled requests for this iteration
1520
"""
1621
raise NotImplementedError
22+
23+
24+
def create_drafter(spec_decoding_mode: SpeculativeDecodingMode,
                   **kwargs) -> Drafter:
    """
    Factory function to create the appropriate drafter based on the mode.

    Args:
        spec_decoding_mode: The speculative decoding mode selecting the
            drafter implementation
        **kwargs: Additional arguments forwarded unchanged to the chosen
            drafter's constructor

    Returns:
        Drafter: The appropriate drafter instance

    Raises:
        ValueError: If the speculative decoding mode is not supported
    """
    # Drafter modules are imported lazily inside each branch so that only
    # the implementation actually requested is loaded.
    if spec_decoding_mode == SpeculativeDecodingMode.NGRAM:
        from .ngram import NGramDrafter
        return NGramDrafter(**kwargs)
    if spec_decoding_mode in (SpeculativeDecodingMode.EAGLE3,
                              SpeculativeDecodingMode.DRAFT_TARGET):
        # Deferred import breaks a circular dependency with model_drafter.
        from .model_drafter import ModelDrafter
        return ModelDrafter(**kwargs)
    raise ValueError(
        f"Unsupported speculative decoding mode: {spec_decoding_mode}")

tensorrt_llm/_torch/speculative/eagle3.py

Lines changed: 31 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -117,29 +117,40 @@ def prepare(self):
117117
# hidden state space before the target model forward.
118118
start_idx = 0
119119
if not self.is_draft_model:
120-
for req_id, seq_len in zip(self.request_ids, self.seq_lens):
121-
slot_id = self.eagle3_resource_manager.slot_manager.get_slot(
122-
req_id)
123-
self.eagle3_resource_manager.start_indices[slot_id] = start_idx
124-
start_idx += seq_len
120+
if self.request_ids is not None and self.seq_lens is not None:
121+
for req_id, seq_len in zip(self.request_ids, self.seq_lens):
122+
slot_id = self.eagle3_resource_manager.slot_manager.get_slot(
123+
req_id
124+
) if self.eagle3_resource_manager is not None else None
125+
if self.eagle3_resource_manager is not None and slot_id is not None:
126+
self.eagle3_resource_manager.start_indices[
127+
slot_id] = start_idx
128+
start_idx += seq_len
125129
# Prepare hidden states gather ids
126130
hidden_states_read_indices = []
127131
hidden_states_write_indices = []
128-
for req_id, seq_len in zip(self.request_ids, self.seq_lens):
129-
slot_id = self.eagle3_resource_manager.slot_manager.get_slot(req_id)
130-
start_idx = self.eagle3_resource_manager.start_indices[slot_id]
131-
# If this is the first draft or the target model forward, we need to
132-
# read/write all of the hidden states, otherwise, only read the last token
133-
if is_first_draft or not self.is_draft_model:
134-
hidden_states_read_indices.extend(
135-
list(range(start_idx, start_idx + seq_len)))
136-
hidden_states_write_indices.extend(
137-
list(range(start_idx, start_idx + seq_len)))
138-
else:
139-
old_seq_len = self.eagle3_resource_manager.seq_lens[slot_id]
140-
hidden_states_read_indices.append(start_idx + old_seq_len - 1)
141-
hidden_states_write_indices.append(start_idx + seq_len - 1)
142-
self.eagle3_resource_manager.seq_lens[slot_id] = seq_len
132+
if self.request_ids is not None and self.seq_lens is not None:
133+
for req_id, seq_len in zip(self.request_ids, self.seq_lens):
134+
if self.eagle3_resource_manager is not None:
135+
slot_id = self.eagle3_resource_manager.slot_manager.get_slot(
136+
req_id)
137+
start_idx = self.eagle3_resource_manager.start_indices[
138+
slot_id]
139+
# If this is the first draft or the target model forward, we need to
140+
# read/write all of the hidden states, otherwise, only read the last token
141+
if is_first_draft or not self.is_draft_model:
142+
hidden_states_read_indices.extend(
143+
list(range(start_idx, start_idx + seq_len)))
144+
hidden_states_write_indices.extend(
145+
list(range(start_idx, start_idx + seq_len)))
146+
else:
147+
old_seq_len = self.eagle3_resource_manager.seq_lens[
148+
slot_id]
149+
hidden_states_read_indices.append(start_idx +
150+
old_seq_len - 1)
151+
hidden_states_write_indices.append(start_idx + seq_len -
152+
1)
153+
self.eagle3_resource_manager.seq_lens[slot_id] = seq_len
143154
# Prepare hidden states gather ids
144155
self.hidden_states_read_indices_host = torch.tensor(
145156
hidden_states_read_indices, dtype=torch.long, pin_memory=True)

0 commit comments

Comments
 (0)