1 change: 1 addition & 0 deletions tensorrt_llm/_torch/pyexecutor/llm_request.py
@@ -303,6 +303,7 @@ def __init__(
        self.py_batch_idx = None
        self.py_rewind_len = 0
        self.py_draft_tokens = [] if self.draft_tokens is None else self.draft_tokens
        self.py_last_context_chunk = (None, None)
        self.py_last_draft_tokens = None
        self.py_num_accepted_draft_tokens = 0
        self.py_decoding_iter = 0
4 changes: 4 additions & 0 deletions tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -1316,6 +1316,10 @@ def _update_request_states_tp(self, scheduled_requests: ScheduledRequests):

        for request in scheduled_requests.context_requests:
            if request.state != LlmRequestState.GENERATION_COMPLETE: # skip failed requests
                request.py_last_context_chunk = (
                    request.context_current_position,
                    request.context_current_position +
                    request.context_chunk_size)
                request.move_to_next_context_chunk()
                if request.context_remaining_length == 0:
                    request.state = LlmRequestState.GENERATION_IN_PROGRESS
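Taken together, the two hunks above add simple bookkeeping: immediately before the executor advances a context request to its next prefill chunk, it snapshots the [begin, end) token range of the chunk that was just computed into the new py_last_context_chunk field. A minimal, self-contained sketch of that bookkeeping (illustration only, not part of the diff; FakeRequest is a stand-in for the real LlmRequest):

# Illustration only: FakeRequest mimics the few LlmRequest fields used above.
class FakeRequest:

    def __init__(self, prompt_len: int, chunk_size: int):
        self.prompt_len = prompt_len
        self.context_current_position = 0
        self.context_chunk_size = chunk_size
        self.py_last_context_chunk = (None, None)  # field added by this PR

    def move_to_next_context_chunk(self):
        # Advance past the chunk that was just processed; the last chunk may be shorter.
        self.context_current_position += self.context_chunk_size
        remaining = self.prompt_len - self.context_current_position
        self.context_chunk_size = min(self.context_chunk_size, remaining)


req = FakeRequest(prompt_len=100, chunk_size=64)
history = []
while req.context_current_position < req.prompt_len:
    # Snapshot the bounds of the chunk being completed, then advance,
    # mirroring _update_request_states_tp above.
    req.py_last_context_chunk = (req.context_current_position,
                                 req.context_current_position +
                                 req.context_chunk_size)
    req.move_to_next_context_chunk()
    history.append(req.py_last_context_chunk)

print(history)  # [(0, 64), (64, 100)]

The drafter (next file) reads this tuple on the following iteration to replay exactly the same token range through the draft model.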
46 changes: 39 additions & 7 deletions tensorrt_llm/_torch/speculative/model_drafter.py
@@ -92,10 +92,17 @@ def _initialize_draft_tokens(self, request: LlmRequest) -> Tuple[int, int]:
    def _create_context_request(self, request: LlmRequest,
                                input_tokens: Any) -> LlmRequest:
        """Create a context request for first-time drafting."""
        return self._create_draft_request(request.py_request_id,
                                          request.py_max_new_tokens,
                                          input_tokens, request.sampling_config,
                                          request.return_perf_metrics)
        new_request = self._create_draft_request(request.py_request_id,
                                                 request.py_max_new_tokens,
                                                 input_tokens,
                                                 request.sampling_config,
                                                 request.return_perf_metrics)

        begin_compute, end_compute = request.py_last_context_chunk
        if begin_compute is not None:
            new_request.context_current_position = begin_compute
            new_request.context_chunk_size = end_compute - begin_compute
        return new_request

    def _create_generation_request(self, request: LlmRequest,
                                   input_tokens: Any) -> LlmRequest:
@@ -110,10 +117,13 @@ def _create_generation_request(self, request: LlmRequest,
        new_request.state = LlmRequestState.GENERATION_IN_PROGRESS
        return new_request

    def _create_chunked_context_request(self, request: LlmRequest,
    def _create_accepted_tokens_request(self, request: LlmRequest,
                                        input_tokens: Any,
                                        num_accepted_tokens: int) -> LlmRequest:
        """Create a chunked context request when some tokens were accepted."""
        """
        Create a chunked context request for accepted tokens.
        Only applicable if the draft model needs to recompute the KV cache for
        accepted tokens (e.g. Eagle3).
        """
        new_request = self._create_draft_request(request.py_request_id,
                                                 request.py_max_new_tokens,
                                                 input_tokens,
@@ -146,7 +156,7 @@ def _create_draft_request_for_request(

        # Tokens accepted - chunked context request
        else:
            return self._create_chunked_context_request(request, input_tokens,
            return self._create_accepted_tokens_request(request, input_tokens,
                                                        num_accepted_tokens)

    def _add_to_draft_batch(self, draft_batch: ScheduledRequests,
@@ -184,6 +194,22 @@ def _prepare_draft_batch(
        try:
            draft_batch = ScheduledRequests()

            for request in scheduled_requests.context_requests:
                if request.is_first_context_chunk:
                    # Ignore requests which still need to be processed by the target model.
                    continue

                # We hit this path if we're doing chunked prefill. The target model processed
                # a prefill chunk on the last iteration. Now, we need to fill in the KV cache
                # for the draft model too.
                all_tokens = request.get_tokens()[0]
                input_tokens = get_draft_model_prompt(
                    self.spec_config.spec_dec_mode, all_tokens)

                new_request = self._create_context_request(
                    request, input_tokens)
                self._add_to_draft_batch(draft_batch, new_request, request)

            for request in scheduled_requests.generation_requests:
                if request.py_draft_pages_allocated == 0:
                    # No space for draft tokens
@@ -273,6 +299,12 @@ def _process_decoded_tokens(
        new_requests = []
        for req in draft_batch.all_requests():
            target_model_req = req_id_to_old_request[req.py_request_id]
            if target_model_req.state != LlmRequestState.GENERATION_IN_PROGRESS:
                # This is a chunked prefill request and we have more prefill chunks
                # to process. Defer adding draft tokens until the whole prompt is processed.
                self.draft_seq_slot_manager.free_resources(req)
                continue

            target_model_req.py_draft_tokens.append(req.get_last_tokens(0))
            if req.state != LlmRequestState.GENERATION_COMPLETE and len(
                    target_model_req.py_draft_tokens
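For orientation, here is a condensed sketch of the control flow the drafter changes above implement for chunked prefill, assuming the two-model Eagle3 path: while the target model is still prefilling, each completed prompt chunk is replayed as a draft-model context request sized from py_last_context_chunk, and draft tokens are only attached to a target request once its whole prompt has been processed. The helper names and the simplified request object below are illustrative, not the module's API.

# Illustration only: condensed from _prepare_draft_batch / _process_decoded_tokens above.
from dataclasses import dataclass, field
from typing import List, Optional, Tuple


@dataclass
class TargetRequest:
    request_id: int
    state: str  # e.g. "CONTEXT_INIT" or "GENERATION_IN_PROGRESS"
    is_first_context_chunk: bool
    py_last_context_chunk: Tuple[Optional[int], Optional[int]] = (None, None)
    py_draft_tokens: List[int] = field(default_factory=list)


def draft_prefill_chunks(context_requests: List[TargetRequest]) -> List[Tuple[int, int, int]]:
    """Replay the chunk the target model just prefilled so the draft model's KV cache keeps up."""
    chunks = []
    for req in context_requests:
        if req.is_first_context_chunk:
            # The target model has not processed a chunk yet; nothing to replay.
            continue
        begin, end = req.py_last_context_chunk
        chunks.append((req.request_id, begin, end))
    return chunks


def attach_draft_token(target: TargetRequest, token: int) -> bool:
    """Defer draft tokens until the whole prompt has been prefilled by the target model."""
    if target.state != "GENERATION_IN_PROGRESS":
        return False  # more prefill chunks to go; free the draft slot and skip
    target.py_draft_tokens.append(token)
    return True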
42 changes: 27 additions & 15 deletions tests/unittest/_torch/speculative/test_eagle3.py
@@ -14,21 +14,21 @@


@pytest.mark.parametrize(
    "use_cuda_graph,attn_backend,disable_overlap_scheduler,enable_block_reuse,use_one_model",
    "use_cuda_graph,attn_backend,disable_overlap_scheduler,enable_block_reuse,use_one_model,enable_chunked_prefill",
    [
        [True, "TRTLLM", True, False, False],
        [False, "TRTLLM", True, False, False],
        [True, "TRTLLM", True, True, False],
        [False, "TRTLLM", True, True, False],
        [True, "FLASHINFER", True, False, False],
        [False, "FLASHINFER", True, False, False],
        [False, "TRTLLM", False, True, True],
        [True, "TRTLLM", False, True, True],
        [True, "TRTLLM", True, False, False, False],
        [False, "TRTLLM", True, False, False, False],
        [True, "FLASHINFER", True, False, False, False],
        [False, "FLASHINFER", True, False, False, False],
        [False, "TRTLLM", False, True, True, False],
        [True, "TRTLLM", False, True, True, False],
        [True, "TRTLLM", True, False, True, True],
        [True, "TRTLLM", True, False, False, True],
    ])
@pytest.mark.high_cuda_memory
def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str,
                      disable_overlap_scheduler: bool, enable_block_reuse: bool,
                      use_one_model: bool):
                      use_one_model: bool, enable_chunked_prefill: bool):
    # Eagle3 one model works with overlap scheduler and block reuse.
    total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    if total_mem_gb < 35:
@@ -59,7 +59,11 @@ def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str,
        # that the draft model won't go above its max in warmup
        # in this test.
        max_seq_len=8192,
        enable_chunked_prefill=enable_chunked_prefill,
    )
    if enable_chunked_prefill:
        # Use a small max_num_tokens so that the chunked prefill path gets exercised.
        llm_common_config['max_num_tokens'] = 64

    spec_config = EagleDecodingConfig(
        max_draft_len=max_draft_len,
@@ -71,7 +75,19 @@ def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str,
    llm_spec = LLM(**llm_common_config, speculative_config=spec_config)

    # Acceptance rate tests
    tok_ids = llm_spec.tokenizer.encode("The future of AI is")
    if enable_chunked_prefill:
        # Use a long prompt for chunked prefill tests.
        prompts = [
            "The capital of France is a city of romance, art, fashion, and cuisine. Paris is a must-visit destination for anyone who loves history, architecture, and culture. From the iconic Eiffel Tower to the world-famous Louvre Museum, Paris has something to offer for every interest and age.\nThe city is divided into 20 arrondissements, each with its own unique character and charm. The Latin Quarter is a popular area for students and young travelers, while the Champs-Élysées is a hub for shopping and dining. The Montmartre neighborhood is famous for its bohemian vibe and stunning views of the city.\nParis is also known for its beautiful parks and gardens, such as the Luxembourg Gardens and the Tuileries Garden. The city has a rich history, with landmarks like the Notre-Dame Cathedral and the Arc de Triomphe. Visitors can also explore the city's many museums, including the Musée d'Orsay and the Musée Rodin.\nIn addition to its cultural and historical attractions, Paris is also a great destination for foodies. The city is famous for its cuisine, including croissants, baguettes, and cheese. Visitors can sample the city's famous dishes at one of the many restaurants, cafes, and "
        ]
        tok_ids = llm_spec.tokenizer.encode(prompts[0])
    else:
        prompts = [
            "The capital of France is",
            "The president of the United States is",
        ]
        tok_ids = llm_spec.tokenizer.encode("The future of AI is")

    num_tokens = 0
    num_drafted = 0
    num_accepted = 0
Expand All @@ -88,10 +104,6 @@ def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str,
    assert accept_rate > 0.15

    # Output tests
    prompts = [
        "The capital of France is",
        "The president of the United States is",
    ]
    sampling_params = SamplingParams(max_tokens=10, temperature=0)

    results_spec = llm_spec.generate(prompts, sampling_params)
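A note on the test parameters: with enable_chunked_prefill=True and max_num_tokens=64, the long Paris prompt (a few hundred tokens) is prefilled over several iterations, so the drafter's chunked-prefill path gets exercised repeatedly, while the short prompts in the other branch fit in a single chunk. Rough arithmetic, with an assumed rather than measured token count:

import math

approx_prompt_tokens = 280  # assumed length of the long prompt; the real count depends on the tokenizer
max_num_tokens = 64         # per-iteration token budget set in the test

num_prefill_chunks = math.ceil(approx_prompt_tokens / max_num_tokens)
print(num_prefill_chunks)   # -> 5 target-model prefill iterations before generation starts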