14 changes: 14 additions & 0 deletions tensorrt_llm/_torch/pyexecutor/llm_request.py
@@ -477,3 +477,17 @@ def executor_request_to_llm_request(
py_multimodal_data=getattr(executor_request, "py_multimodal_data",
None))
return llm_request


def get_draft_token_length(request: LlmRequest) -> int:
"""Get the length of draft tokens for a given request.

Args:
request: The LlmRequest to get draft token length for

Returns:
The number of draft tokens, or 0 if no draft tokens exist
"""
if request.py_draft_tokens is not None:
return len(request.py_draft_tokens)
return 0
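
A minimal usage sketch of the new helper (not part of the diff): DummyRequest below is a hypothetical stand-in for LlmRequest, used only to show why this None-safe helper replaces the direct len(req.py_draft_tokens) pattern at the call sites updated later in this PR.

from dataclasses import dataclass
from typing import List, Optional

from tensorrt_llm._torch.pyexecutor.llm_request import get_draft_token_length

@dataclass
class DummyRequest:
    # Stand-in for LlmRequest: py_draft_tokens may be None when speculative
    # decoding is disabled for the current iteration.
    py_draft_tokens: Optional[List[int]] = None

# The old pattern raised TypeError when py_draft_tokens was None:
#   num_tokens = len(req.py_draft_tokens)
# The helper degrades to zero draft tokens instead:
assert get_draft_token_length(DummyRequest()) == 0
assert get_draft_token_length(DummyRequest(py_draft_tokens=[7, 8, 9])) == 3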
152 changes: 89 additions & 63 deletions tensorrt_llm/_torch/pyexecutor/model_engine.py

Large diffs are not rendered by default.

7 changes: 5 additions & 2 deletions tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -861,6 +861,9 @@ def _prepare_and_schedule_batch(self):
self._pad_attention_dp_dummy_request()

if self.drafter is not None:
self.use_spec_decode = self.drafter.should_use_spec_decode(
self.active_requests)
self.model_engine.enable_spec_decode = self.use_spec_decode
self._prepare_draft_requests(self.active_requests)

scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule(
@@ -922,7 +925,7 @@ def _executor_loop(self):
self._handle_first_token_response(scheduled_batch)

self.resource_manager.prepare_resources(scheduled_batch)
if self.drafter is not None:
if self.drafter is not None and self.use_spec_decode:
self.drafter.prepare_draft_tokens(
scheduled_batch, self.resource_manager)

@@ -973,7 +976,7 @@ def _prepare_draft_requests(self, requests):
req.py_last_draft_tokens = req.py_draft_tokens
max_draft_len = self.model_engine.spec_config.max_draft_len

if max_draft_len > 0:
if max_draft_len > 0 and self.use_spec_decode:
req.py_draft_tokens = [0] * max_draft_len
req.py_draft_pages_allocated = max_draft_len
else:
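A condensed sketch of the control flow these hunks add. The attributes drafter, model_engine, active_requests, use_spec_decode, and enable_spec_decode come from the diff above; the two standalone function names and the executor parameter are illustrative stand-ins for the PyExecutor methods, not real API.

def gate_spec_decode_for_iteration(executor) -> None:
    # Illustrative condensation of _prepare_and_schedule_batch: once per
    # iteration, the drafter decides whether speculation is worthwhile for
    # the current active requests, and the model engine is kept in sync so
    # its buffers and CUDA graphs match the chosen mode.
    if executor.drafter is not None:
        executor.use_spec_decode = executor.drafter.should_use_spec_decode(
            executor.active_requests)
        executor.model_engine.enable_spec_decode = executor.use_spec_decode

def reserve_draft_slots(req, max_draft_len: int, use_spec_decode: bool) -> None:
    # Illustrative condensation of _prepare_draft_requests: draft slots are
    # reserved only when speculation is enabled this iteration; otherwise the
    # request decodes autoregressively with no draft tokens attached.
    if max_draft_len > 0 and use_spec_decode:
        req.py_draft_tokens = [0] * max_draft_len
        req.py_draft_pages_allocated = max_draft_len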
7 changes: 4 additions & 3 deletions tensorrt_llm/_torch/pyexecutor/resource_manager.py
@@ -15,7 +15,8 @@
from ..._utils import binding_dtype_size, nvtx_range
from ...logger import logger
from ...mapping import Mapping
from .llm_request import LlmRequest, LlmRequestState, SamplingConfig
from .llm_request import (LlmRequest, LlmRequestState, SamplingConfig,
get_draft_token_length)
from .scheduler import ScheduledRequests

if ENABLE_MULTI_DEVICE:
@@ -368,12 +369,12 @@ def prepare_resources(self, scheduled_batch: ScheduledRequests):
req_beam_width, req)
for _ in range(self.num_extra_kv_tokens):
self.impl.add_token(req.py_request_id)
for _ in range(len(req.py_draft_tokens)):
for _ in range(get_draft_token_length(req)):
self.impl.add_token(req.py_request_id)

for req in generation_batch:
self.impl.add_token(req.py_request_id)
for _ in range(len(req.py_draft_tokens)):
for _ in range(get_draft_token_length(req)):
self.impl.add_token(req.py_request_id)

def add_dummy_requests(
6 changes: 3 additions & 3 deletions tensorrt_llm/_torch/pyexecutor/sampler.py
@@ -23,7 +23,7 @@
from tensorrt_llm.mapping import Mapping

from .finish_reason import FinishedState
from .llm_request import LlmRequest, LlmRequestState
from .llm_request import LlmRequest, LlmRequestState, get_draft_token_length
from .scheduler import ScheduledRequests


@@ -337,7 +337,7 @@ def update_requests(self, state: SampleState) -> None:
new_token = add_token(req, new_tokens, beam=self.BEAM)
stop = self._handle_stop_criteria(req, new_token)
processed = 1
if not stop and len(req.py_draft_tokens) > 0:
if not stop and get_draft_token_length(req) > 0:
num_accepted = self.process_draft_tokens(
req, new_tokens, new_token)
req.py_num_accepted_draft_tokens = num_accepted
@@ -401,7 +401,7 @@ def _process_requests(self,
beam_width = self.MAX_BEAM_WIDTH
beam = self.BEAM
raw_logits = model_outputs["logits"]
num_steps = [1 + len(req.py_draft_tokens) for req in requests]
num_steps = [1 + get_draft_token_length(req) for req in requests]
sum_steps = sum(num_steps)
no_draft_tokens = len(requests) == sum_steps
fast_path = not self.enable_mixed_sampler and no_draft_tokens and gen_logits_host is None and log_probs_host is None
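A small self-contained illustration (hypothetical values, not from the PR) of why num_steps is 1 + get_draft_token_length(req): every request yields one newly sampled token plus one verification step per draft token, and the fast path applies only when no request in the batch carries drafts.

def draft_len(tokens):
    # Same None-safe rule as get_draft_token_length.
    return len(tokens) if tokens is not None else 0

# Hypothetical draft-token lists for a batch of three requests.
draft_tokens_per_request = [None, [], [11, 12, 13, 14]]

num_steps = [1 + draft_len(t) for t in draft_tokens_per_request]
print(num_steps)        # [1, 1, 5]
print(sum(num_steps))   # 7 logits rows to process for this batch
# no_draft_tokens (and hence the fast path) requires
# len(requests) == sum(num_steps), i.e. no drafts anywhere:
print(len(draft_tokens_per_request) == sum(num_steps))  # False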
4 changes: 2 additions & 2 deletions tensorrt_llm/_torch/pyexecutor/scheduler.py
@@ -5,7 +5,7 @@
from tensorrt_llm.bindings import executor as tb_executor
from tensorrt_llm.bindings import internal as tb_internal

from .llm_request import LlmRequest, LlmRequestState
from .llm_request import LlmRequest, LlmRequestState, get_draft_token_length

RequestList = list[LlmRequest]

@@ -185,7 +185,7 @@ def schedule(
self, active_requests: RequestList, inflight_request_ids: set[int]
) -> tuple[list[LlmRequest], list[LlmRequest]]:
for request in active_requests:
if len(request.py_draft_tokens) > 0:
if get_draft_token_length(request) > 0:
request.draft_tokens = request.py_draft_tokens
return self.impl(active_requests, inflight_request_ids,
self.max_batch_size, self.max_num_tokens)
7 changes: 6 additions & 1 deletion tensorrt_llm/_torch/speculative/drafter.py
@@ -1,6 +1,7 @@
from abc import ABC, abstractmethod
from typing import Optional
from typing import List, Optional

from ..pyexecutor.llm_request import LlmRequest
from ..pyexecutor.resource_manager import ResourceManager
from ..pyexecutor.scheduler import ScheduledRequests

@@ -21,3 +22,7 @@ def prepare_draft_tokens(
scheduled_requests: The scheduled requests for this iteration
"""
raise NotImplementedError

def should_use_spec_decode(self, requests: List[LlmRequest]) -> bool:
"""Check if spec decode should be used for the current iteration."""
return True
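
The default keeps existing drafters unchanged; subclasses can override the hook to gate speculation dynamically. Below is a hypothetical override (not part of this PR) that speculates only at low concurrency; it assumes the base class defined in drafter.py is named Drafter and that prepare_draft_tokens takes (scheduled_requests, resource_manager), both of which are assumptions for illustration.

from typing import List

from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequest
from tensorrt_llm._torch.speculative.drafter import Drafter

class BatchSizeGatedDrafter(Drafter):
    """Hypothetical drafter that disables speculation for large batches."""

    def __init__(self, max_concurrency: int = 8):
        self.max_concurrency = max_concurrency

    def prepare_draft_tokens(self, scheduled_requests, resource_manager=None) -> None:
        # Populate req.py_draft_tokens for the scheduled requests here.
        ...

    def should_use_spec_decode(self, requests: List[LlmRequest]) -> bool:
        # Drafting tends to pay off only at low concurrency, so turn it off
        # once the number of active requests passes a threshold.
        return len(requests) <= self.max_concurrency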
7 changes: 3 additions & 4 deletions tensorrt_llm/_torch/speculative/model_drafter.py
@@ -8,7 +8,8 @@
from tensorrt_llm._utils import nvtx_range
from tensorrt_llm.logger import logger

from ..pyexecutor.llm_request import LlmRequest, LlmRequestState, SamplingConfig
from ..pyexecutor.llm_request import (LlmRequest, LlmRequestState,
SamplingConfig, get_draft_token_length)
from ..pyexecutor.resource_manager import BaseResourceManager, ResourceManager
from ..pyexecutor.sampler import Sampler, SampleState
from ..pyexecutor.scheduler import ScheduledRequests
@@ -59,7 +60,6 @@ def __init__(
# Configuration
self.spec_config = spec_config
self.max_draft_tokens = max_draft_tokens

# Sampling
self.sampler = sampler

@@ -214,7 +214,6 @@ def _prepare_draft_batch(
if request.py_draft_pages_allocated == 0:
# No space for draft tokens
continue

# Stop drafting when we hit the max seqlen. We still need dummy draft
# tokens attached to the requests to make sure everything works properly
# with CUDA graph. These dummy tokens are already added by
@@ -320,7 +319,7 @@ def _pad_to_max_draft_tokens(self,
"""Pad draft tokens to maximum length for all generation requests."""
for req in scheduled_requests.generation_requests:
max_draft_tokens = self.max_draft_tokens
num_draft_tokens = len(req.py_draft_tokens)
num_draft_tokens = get_draft_token_length(req)
req.py_draft_tokens.extend(
0 for _ in range(max_draft_tokens - num_draft_tokens))

1 change: 1 addition & 0 deletions tests/integration/test_lists/test-db/l0_b200.yml
@@ -59,6 +59,7 @@ l0_b200:
- unittest/_torch/auto_deploy/unit/singlegpu
- unittest/_torch/speculative/test_eagle3.py
- unittest/_torch/speculative/test_kv_cache_reuse.py
- unittest/_torch/speculative/test_dynamic_spec_decode.py
- condition:
ranges:
system_gpu_count:
91 changes: 91 additions & 0 deletions tests/unittest/_torch/speculative/test_dynamic_spec_decode.py
@@ -0,0 +1,91 @@
import os
import sys
import unittest
from unittest.mock import patch

import pytest
import torch
from utils.llm_data import llm_models_root

from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import (CudaGraphConfig, EagleDecodingConfig,
KvCacheConfig)

sys.path.append(os.path.join(os.path.dirname(__file__), '..'))


@pytest.mark.high_cuda_memory
def test_dynamic_spec_decode():
total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
if total_mem_gb < 35:
pytest.skip("Not enough memory to load target + draft model")

models_path = llm_models_root()
eagle_model_dir = f"{models_path}/EAGLE3-LLaMA3.1-Instruct-8B"
target_model_dir = f"{models_path}/llama-3.1-model/Llama-3.1-8B-Instruct"

max_batch_size = 1
max_draft_len = 4
kv_cache_config = KvCacheConfig(enable_block_reuse=True,
free_gpu_memory_fraction=0.5)
cuda_graph_config = CudaGraphConfig(batch_sizes=[1])

llm_common_config = dict(
model=target_model_dir,
attn_backend="TRTLLM",
disable_overlap_scheduler=True,
cuda_graph_config=cuda_graph_config,
max_batch_size=max_batch_size,
kv_cache_config=kv_cache_config,
# This max_seq_len is larger than the one specified
# in the llama 3 8B eagle's config. We want to make sure
# that the draft model won't go above its max in warmup
# in this test.
max_seq_len=8192,
)

spec_config = EagleDecodingConfig(
max_draft_len=max_draft_len,
speculative_model_dir=eagle_model_dir,
# Llama 3 does not support one model eagle.
eagle3_one_model=False,
)

# Mock should_use_spec_decode to return True for first two calls, then False
def mock_should_use_spec_decode(self, requests):
if not hasattr(mock_should_use_spec_decode, 'call_count'):
mock_should_use_spec_decode.call_count = 0
mock_should_use_spec_decode.call_count += 1
return mock_should_use_spec_decode.call_count <= 2

with patch(
'tensorrt_llm._torch.speculative.model_drafter.ModelDrafter.should_use_spec_decode',
side_effect=mock_should_use_spec_decode):
llm_spec = LLM(**llm_common_config, speculative_config=spec_config)
sampling_params = SamplingParams(max_tokens=128, temperature=0)

# Output tests
prompts = [
"The capital of France is",
"The president of the United States is",
]
sampling_params = SamplingParams(max_tokens=10, temperature=0)

results_spec = llm_spec.generate(prompts, sampling_params)
generated_text_spec = [
result.outputs[0].text for result in results_spec
]
llm_spec.shutdown()

llm_ref = LLM(**llm_common_config)
results_ref = llm_ref.generate(prompts, sampling_params)
generated_text_ref = [result.outputs[0].text for result in results_ref]
llm_ref.shutdown()

for text_spec, text_ref in zip(generated_text_spec, generated_text_ref):
# The spec decode algorithm currently guarantees identical results
assert text_spec == text_ref


if __name__ == "__main__":
unittest.main()