
Commit 1920b35

Update utils and remove legacy code
Signed-off-by: ziyixiong-nv <[email protected]>
1 parent 069721d commit 1920b35

File tree

7 files changed: +76 -286 lines changed


tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 6 additions & 203 deletions
@@ -11,7 +11,7 @@
 import weakref
 from collections import deque, namedtuple
 from contextlib import contextmanager
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Union
 
 import torch
 
@@ -30,7 +30,7 @@
 from tensorrt_llm.logger import logger
 
 from ..distributed import Distributed
-from ..speculative.drafter import Drafter, create_drafter
+from ..speculative.drafter import Drafter
 from .kv_cache_transceiver import KvCacheTransceiver
 from .llm_request import (ExecutorRequest, LlmRequest, LlmRequestState,
                           LlmResponse, executor_request_to_llm_request)
@@ -305,7 +305,7 @@ def __init__(self,
         if is_trace_enabled("TLLM_TRACE_EXECUTOR_LOOP"):
             self.event_loop = trace_func(self.event_loop)
 
-        if self.draft_model_engine is not None:
+        if self.drafter is not None:
             if self.event_loop.__name__ != self._executor_loop.__name__:
                 raise NotImplementedError(
                     "Drafting is not supported for selected executor loop. "
@@ -918,8 +918,7 @@ def _executor_loop(self):
 
                 self._pad_attention_dp_dummy_request()
 
-                if self.draft_model_engine is not None or hasattr(
-                        self, 'drafter') and self.drafter is not None:
+                if self.drafter is not None:
                     self._prepare_draft_requests(self.active_requests)
 
                 scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule(
@@ -959,23 +958,9 @@ def _executor_loop(self):
                         scheduled_batch)
 
                     self.resource_manager.prepare_resources(scheduled_batch)
-                    if self.draft_model_engine is not None and self.drafter is None:
-                        spec_resource_manager = self.resource_manager.get_resource_manager(
-                            ResourceManagerType.SPEC_RESOURCE_MANAGER)
-                        self.drafter = create_drafter(
-                            spec_decoding_mode=self.model_engine.spec_config.
-                            spec_dec_mode,
-                            spec_config=self.model_engine.spec_config,
-                            draft_model_engine=self.draft_model_engine,
-                            max_draft_tokens=self.max_draft_tokens,
-                            draft_seq_slot_manager=self.draft_seq_slot_manager,
-                            sampler=self.sampler,
-                            resource_manager=self.resource_manager,
-                            spec_resource_manager=spec_resource_manager,
-                        )
-
                     if self.drafter is not None:
-                        self.drafter.prepare_draft_tokens(scheduled_batch)
+                        self.drafter.prepare_draft_tokens(
+                            scheduled_batch, self.resource_manager)
 
                     if self.kv_cache_transceiver:
                         # For generation requests which have completed KV cache transfer
@@ -1780,188 +1765,6 @@ def _update_requests(self, sample_state: SampleState):
             logger.error(f"Encountered an error in sampling: {error_msg}")
             self._handle_errors(error_msg)
 
-    @nvtx_range("_prepare_draft_batch")
-    def _prepare_draft_batch(
-            self, scheduled_requests: ScheduledRequests
-    ) -> Tuple[ScheduledRequests, Dict[int, LlmRequest]]:
-        """
-        Prepares a batch for the draft model engine. Draft tokens are only produced
-        for generation requests.
-
-        The requests are prepared as follows:
-        1. The first time the draft engine sees a request, it's a context request.
-        2. Otherwise, if draft tokens were accepted on the last target model decoding
-           step, it's a chunked context request (we process all the accepted tokens together).
-        3. Otherwise, it's a generation request.
-        """
-        try:
-            draft_batch = ScheduledRequests()
-
-            for request in scheduled_requests.generation_requests:
-                if request.py_draft_pages_allocated == 0:
-                    # No space for draft tokens.
-                    continue
-
-                # Stop drafting when we hit the max seqlen. We still need dummy draft
-                # tokens attached to the requests to make sure everything works properly
-                # with CUDA graph. These dummy tokens are already added by
-                # _prepare_draft_requests to make the KV cache/scheduler aware of the fact
-                # that we want to do spec decoding, so no need to do anything else here.
-                # This makes the perf for this case suboptimal, but that's OK - this is
-                # a corner case for weird models like the llama 3.1 8b EAGLE3 implementation.
-                if request.max_beam_num_tokens - 1 >= self.draft_model_engine.max_seq_len:
-                    continue
-
-                num_draft_tokens = len(
-                    request.py_last_draft_tokens
-                ) if request.py_last_draft_tokens is not None else 0
-                request.py_draft_tokens = []
-
-                num_accepted_tokens = request.py_num_accepted_draft_tokens
-                num_rejected_tokens = num_draft_tokens - num_accepted_tokens
-                assert num_rejected_tokens >= 0
-
-                spec_config = self.model_engine.spec_config
-                beam_idx = 0
-                input_tokens = spec_config.get_draft_model_prompt(
-                    request.get_tokens()[beam_idx])
-
-                def create_new_request(input_tokens):
-                    return LlmRequest(
-                        request_id=request.py_request_id,
-                        max_new_tokens=request.py_max_new_tokens,
-                        input_tokens=input_tokens,
-                        sampling_config=request.sampling_config,
-                        return_perf_metrics=request.return_perf_metrics,
-                        is_streaming=False,
-                        is_draft=True)
-
-                if request.max_beam_num_tokens - 1 == request.py_prompt_len:
-                    # This is the first time the draft model is seeing this request.
-                    # Prepare a context request. We discard the first token and take
-                    # the newly decoded one - this is the convention for EAGLE 2 and 3.
-                    new_request = create_new_request(input_tokens)
-                    draft_batch.context_requests.append(new_request)
-                elif num_accepted_tokens == 0:
-                    new_request = create_new_request(input_tokens[:-1])
-                    # Explicitly add the last token so get_last_tokens() returns
-                    # the right value
-                    new_request.add_new_token(input_tokens[-1], beam_idx)
-                    new_request.state = LlmRequestState.GENERATION_IN_PROGRESS
-                    draft_batch.generation_requests.append(new_request)
-                else:
-                    new_request = create_new_request(input_tokens)
-                    new_request.context_chunk_size = num_accepted_tokens + 1
-                    new_request.context_current_position = len(
-                        input_tokens) - num_accepted_tokens - 1
-                    new_request.context_chunk_size = num_accepted_tokens + 1
-                    new_request.context_current_position = len(
-                        input_tokens) - num_accepted_tokens - 1
-
-                    draft_batch.context_requests.append(new_request)
-
-                new_request.py_stop_words_list = request.py_stop_words_list
-
-            return draft_batch
-
-        except Exception as e:
-            traceback.print_exc()
-            error_msg = str(e)
-            logger.error(f"Encountered an error in decode: {error_msg}")
-            self._handle_errors(error_msg)
-
-    @nvtx_range("_prepare_draft_tokens")
-    def _prepare_draft_tokens(self, scheduled_requests: ScheduledRequests):
-        if not self.draft_model_engine:
-            raise ValueError("Draft model engine is not set")
-
-        try:
-            draft_batch = self._prepare_draft_batch(scheduled_requests)
-
-            if draft_batch.batch_size == 0:
-                return
-            self.draft_seq_slot_manager.prepare_resources(draft_batch)
-
-            req_id_to_old_request = {
-                req.py_request_id: req
-                for req in scheduled_requests.all_requests()
-            }
-
-            # Disable cuda graph for the 1st draft model forward
-            if self.model_engine.spec_config.spec_dec_mode.needs_kv_cache_recompute(
-            ):
-                with self.draft_model_engine.no_cuda_graph():
-                    outputs = self.draft_model_engine.forward(
-                        draft_batch, self.resource_manager)
-            else:
-                outputs = self.draft_model_engine.forward(
-                    draft_batch, self.resource_manager)
-            if hasattr(self.draft_model_engine.model.model, 'd2t'):
-                outputs['d2t'] = self.draft_model_engine.model.model.d2t.data
-
-            sample_state = self._sample_async(draft_batch, outputs)
-            previous_batch = sample_state
-
-            self._update_request_states(draft_batch)
-
-            def _process_decoded_tokens(draft_batch):
-                new_requests = []
-                for req in draft_batch.all_requests():
-                    target_model_req = req_id_to_old_request[req.py_request_id]
-                    target_model_req.py_draft_tokens.append(
-                        req.get_last_tokens(0))
-                    if req.state != LlmRequestState.GENERATION_COMPLETE and len(
-                            target_model_req.py_draft_tokens
-                    ) < target_model_req.py_draft_pages_allocated:
-                        new_requests.append(req)
-                    else:
-                        self.draft_seq_slot_manager.free_resources(req)
-
-                return new_requests
-
-            # The TRTLLM attention kernels cannot handle generation requests with
-            # different seqlens. No issues with flashinfer, should we look into removing
-            # this? Just needs proper kernel support.
-            def _pad_to_max_draft_tokens():
-                for req in scheduled_requests.generation_requests:
-                    max_draft_len = self.max_draft_len
-                    num_draft_tokens = len(req.py_draft_tokens)
-                    req.py_draft_tokens.extend(
-                        0 for _ in range(max_draft_len - num_draft_tokens))
-
-            draft_batch.generation_requests = draft_batch.context_requests + draft_batch.generation_requests
-            draft_batch.context_requests = []
-
-            for i in range(self.max_draft_len - 1):
-                if len(draft_batch.generation_requests) == 0:
-                    break
-
-                outputs = self.draft_model_engine.forward(
-                    draft_batch,
-                    self.resource_manager,
-                    new_tensors_device=previous_batch.device)
-
-                if hasattr(self.draft_model_engine.model.model, 'd2t'):
-                    outputs[
-                        'd2t'] = self.draft_model_engine.model.model.d2t.data
-                sample_state = self._sample_async(draft_batch, outputs)
-                self._update_request_states(draft_batch)
-                self._update_requests(previous_batch)
-                new_requests = _process_decoded_tokens(
-                    previous_batch.scheduled_requests)
-                draft_batch.generation_requests = new_requests
-                previous_batch = sample_state
-            self._update_requests(previous_batch)
-            new_requests = _process_decoded_tokens(
-                previous_batch.scheduled_requests)
-            _pad_to_max_draft_tokens()
-
-        except Exception as e:
-            traceback.print_exc()
-            error_msg = str(e)
-            logger.error(f"Encountered an error in decode: {error_msg}")
-            self._handle_errors(error_msg)
-
     def _handle_errors(self, error_msg: Optional[str] = None):
         error_responses = {}
         error_msg = error_msg or "error"
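
Net effect of the py_executor.py changes: the executor no longer builds a drafter lazily from draft_model_engine, and the roughly 200 lines of _prepare_draft_batch/_prepare_draft_tokens move behind the Drafter interface (presumably into a ModelDrafter implementation created elsewhere). A condensed sketch of the resulting drafting step; the free function maybe_prepare_draft_tokens and its argument names are illustrative, only the prepare_draft_tokens call and its new arguments come from the diff above.

# Illustrative condensation of the new control flow, not the actual executor code.
def maybe_prepare_draft_tokens(drafter, scheduled_batch, resource_manager):
    # The drafter is created once in py_executor_creator and injected into the
    # executor; the loop no longer inspects draft_model_engine or builds one lazily.
    if drafter is not None:
        # Updated signature: the resource manager is passed explicitly so the
        # drafter can look up speculative-decoding resources itself.
        drafter.prepare_draft_tokens(scheduled_batch, resource_manager)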

tensorrt_llm/_torch/pyexecutor/py_executor_creator.py

Lines changed: 2 additions & 1 deletion
@@ -369,7 +369,8 @@ def create_py_executor(
 
     # Drafter for speculative decoding
     with mem_monitor.observe_creation_stage(_ExecutorCreationStage.DRAFTER):
-        drafter = get_spec_drafter(model_engine, spec_resource_manager)
+        drafter = get_spec_drafter(model_engine, draft_model_engine, sampler,
+                                   spec_resource_manager)
 
     with mem_monitor.observe_creation_stage(
             _ExecutorCreationStage.INIT_EXTRA_RESOURCES
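
get_spec_drafter now receives the draft model engine and the sampler, taking over the dispatch role of the create_drafter factory deleted from the drafter module (next file). The body of get_spec_drafter is not shown in this commit, so the following is only a sketch of a factory in that spirit: the SpeculativeDecodingMode members and drafter classes are the ones referenced by the deleted code, while build_drafter and **drafter_kwargs are stand-ins for the real signature.

# Hypothetical sketch mirroring the deleted create_drafter dispatch; not the
# actual get_spec_drafter implementation.
from tensorrt_llm._torch.speculative.interface import SpeculativeDecodingMode


def build_drafter(spec_decoding_mode, **drafter_kwargs):
    if spec_decoding_mode == SpeculativeDecodingMode.NGRAM:
        from tensorrt_llm._torch.speculative.ngram import NGramDrafter
        return NGramDrafter(**drafter_kwargs)
    if spec_decoding_mode in (SpeculativeDecodingMode.EAGLE3,
                              SpeculativeDecodingMode.DRAFT_TARGET):
        # Imported lazily, as in the deleted code, to avoid a circular import.
        from tensorrt_llm._torch.speculative.model_drafter import ModelDrafter
        return ModelDrafter(**drafter_kwargs)
    raise ValueError(
        f"Unsupported speculative decoding mode: {spec_decoding_mode}")
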
tensorrt_llm/_torch/speculative/drafter.py

Lines changed: 3 additions & 29 deletions
@@ -1,7 +1,8 @@
 from abc import ABC, abstractmethod
+from typing import Optional
 
+from ..pyexecutor.resource_manager import ResourceManager
 from ..pyexecutor.scheduler import ScheduledRequests
-from .interface import SpeculativeDecodingMode
 
 
 class Drafter(ABC):
@@ -11,6 +12,7 @@ class Drafter(ABC):
     def prepare_draft_tokens(
         self,
        scheduled_requests: ScheduledRequests,
+        resource_manager: Optional[ResourceManager] = None,
     ) -> None:
         """
         Prepare the drafter tokens for the forward computation this step.
@@ -19,31 +21,3 @@ def prepare_draft_tokens(
             scheduled_requests: The scheduled requests for this iteration
         """
         raise NotImplementedError
-
-
-def create_drafter(spec_decoding_mode: SpeculativeDecodingMode,
-                   **kwargs) -> Drafter:
-    """
-    Factory function to create the appropriate drafter based on the mode.
-
-    Args:
-        spec_decoding_mode: The speculative decoding mode
-        **kwargs: Additional arguments for drafter construction
-
-    Returns:
-        Drafter: The appropriate drafter instance
-
-    Raises:
-        ValueError: If the speculative decoding mode is not supported
-    """
-    match spec_decoding_mode:
-        case SpeculativeDecodingMode.NGRAM:
-            from .ngram import NGramDrafter
-            return NGramDrafter(**kwargs)
-        case SpeculativeDecodingMode.EAGLE3 | SpeculativeDecodingMode.DRAFT_TARGET:
-            # Import here to avoid circular import
-            from .model_drafter import ModelDrafter
-            return ModelDrafter(**kwargs)
-        case _:
-            raise ValueError(
-                f"Unsupported speculative decoding mode: {spec_decoding_mode}")

tensorrt_llm/_torch/speculative/eagle3.py

Lines changed: 20 additions & 31 deletions
@@ -117,40 +117,29 @@ def prepare(self):
         # hidden state space before the target model forward.
         start_idx = 0
         if not self.is_draft_model:
-            if self.request_ids is not None and self.seq_lens is not None:
-                for req_id, seq_len in zip(self.request_ids, self.seq_lens):
-                    slot_id = self.eagle3_resource_manager.slot_manager.get_slot(
-                        req_id
-                    ) if self.eagle3_resource_manager is not None else None
-                    if self.eagle3_resource_manager is not None and slot_id is not None:
-                        self.eagle3_resource_manager.start_indices[
-                            slot_id] = start_idx
-                    start_idx += seq_len
+            for req_id, seq_len in zip(self.request_ids, self.seq_lens):
+                slot_id = self.eagle3_resource_manager.slot_manager.get_slot(
+                    req_id)
+                self.eagle3_resource_manager.start_indices[slot_id] = start_idx
+                start_idx += seq_len
         # Prepare hidden states gather ids
         hidden_states_read_indices = []
         hidden_states_write_indices = []
-        if self.request_ids is not None and self.seq_lens is not None:
-            for req_id, seq_len in zip(self.request_ids, self.seq_lens):
-                if self.eagle3_resource_manager is not None:
-                    slot_id = self.eagle3_resource_manager.slot_manager.get_slot(
-                        req_id)
-                    start_idx = self.eagle3_resource_manager.start_indices[
-                        slot_id]
-                    # If this is the first draft or the target model forward, we need to
-                    # read/write all of the hidden states, otherwise, only read the last token
-                    if is_first_draft or not self.is_draft_model:
-                        hidden_states_read_indices.extend(
-                            list(range(start_idx, start_idx + seq_len)))
-                        hidden_states_write_indices.extend(
-                            list(range(start_idx, start_idx + seq_len)))
-                    else:
-                        old_seq_len = self.eagle3_resource_manager.seq_lens[
-                            slot_id]
-                        hidden_states_read_indices.append(start_idx +
-                                                          old_seq_len - 1)
-                        hidden_states_write_indices.append(start_idx + seq_len -
-                                                           1)
-                    self.eagle3_resource_manager.seq_lens[slot_id] = seq_len
+        for req_id, seq_len in zip(self.request_ids, self.seq_lens):
+            slot_id = self.eagle3_resource_manager.slot_manager.get_slot(req_id)
+            start_idx = self.eagle3_resource_manager.start_indices[slot_id]
+            # If this is the first draft or the target model forward, we need to
+            # read/write all of the hidden states, otherwise, only read the last token
+            if is_first_draft or not self.is_draft_model:
+                hidden_states_read_indices.extend(
+                    list(range(start_idx, start_idx + seq_len)))
+                hidden_states_write_indices.extend(
+                    list(range(start_idx, start_idx + seq_len)))
+            else:
+                old_seq_len = self.eagle3_resource_manager.seq_lens[slot_id]
+                hidden_states_read_indices.append(start_idx + old_seq_len - 1)
+                hidden_states_write_indices.append(start_idx + seq_len - 1)
+            self.eagle3_resource_manager.seq_lens[slot_id] = seq_len
         # Prepare hidden states gather ids
         self.hidden_states_read_indices_host = torch.tensor(
             hidden_states_read_indices, dtype=torch.long, pin_memory=True)
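
The eagle3.py refactor drops the defensive None checks on request_ids, seq_lens, and eagle3_resource_manager (they are assumed to be populated whenever prepare() runs) and flattens the nesting; the indexing scheme itself is unchanged. The standalone sketch below restates that scheme outside the class to make the two cases explicit; the function and its dict-based slot bookkeeping are illustrative stand-ins for the Eagle3 resource manager.

# Standalone illustration of the hidden-state gather/scatter index bookkeeping.
from typing import Dict, List, Tuple


def gather_indices(start_indices: Dict[int, int], prev_seq_lens: Dict[int, int],
                   seq_lens: List[Tuple[int, int]],
                   first_draft_or_target: bool) -> Tuple[List[int], List[int]]:
    read_idx: List[int] = []
    write_idx: List[int] = []
    for slot_id, seq_len in seq_lens:
        start = start_indices[slot_id]
        if first_draft_or_target:
            # First draft step or target forward: touch the whole span.
            read_idx.extend(range(start, start + seq_len))
            write_idx.extend(range(start, start + seq_len))
        else:
            # Later draft steps: read only the last token of the previous
            # length, write only the last token of the new length.
            read_idx.append(start + prev_seq_lens[slot_id] - 1)
            write_idx.append(start + seq_len - 1)
        prev_seq_lens[slot_id] = seq_len
    return read_idx, write_idx


# One sequence in slot 0 at offset 0, growing from 5 to 6 tokens:
print(gather_indices({0: 0}, {0: 5}, [(0, 6)], first_draft_or_target=False))
# -> ([4], [5])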
