10 changes: 10 additions & 0 deletions examples/pytorch/README.md
@@ -81,3 +81,13 @@ python3 examples/pytorch/quickstart_advanced.py \
--max_matching_ngram_size=2 \
--spec_decode_nextn=4
```

```bash
# Draft Target
python3 examples/pytorch/quickstart_advanced.py \
--model_dir meta-llama/Llama-3.1-8B-Instruct \
--spec_decode_algo draft_target \
--spec_decode_nextn 5 \
--draft_model_dir meta-llama/Llama-3.2-1B-Instruct \
--disable_overlap_scheduler
```
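For reference, the same draft-target setup can be driven through the LLM API instead of the CLI script. The sketch below mirrors the flags above; the model identifiers are the ones used in this example, and the constructor arguments (`max_draft_len`, `pytorch_weights_path`, `disable_overlap_scheduler`) follow this PR's quickstart and test changes.

```python
from tensorrt_llm import SamplingParams
from tensorrt_llm._torch import LLM
from tensorrt_llm.llmapi import DraftTargetDecodingConfig

# Mirrors --spec_decode_algo draft_target --spec_decode_nextn 5 --draft_model_dir ...
spec_config = DraftTargetDecodingConfig(
    max_draft_len=5,
    pytorch_weights_path="meta-llama/Llama-3.2-1B-Instruct")

# Draft-target does not support the overlap scheduler, hence the flag in the CLI example.
llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct",
          speculative_config=spec_config,
          disable_overlap_scheduler=True)

outputs = llm.generate(["The capital of France is"],
                       SamplingParams(max_tokens=32, temperature=0))
print(outputs[0].outputs[0].text)
```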
17 changes: 12 additions & 5 deletions examples/pytorch/quickstart_advanced.py
@@ -2,9 +2,9 @@

from tensorrt_llm import SamplingParams
from tensorrt_llm._torch import LLM
from tensorrt_llm.llmapi import (EagleDecodingConfig, KvCacheConfig,
MTPDecodingConfig, NGramDecodingConfig,
TorchCompileConfig)
from tensorrt_llm.llmapi import (DraftTargetDecodingConfig, EagleDecodingConfig,
KvCacheConfig, MTPDecodingConfig,
NGramDecodingConfig, TorchCompileConfig)

example_prompts = [
"Hello, my name is",
@@ -109,7 +109,10 @@ def add_llm_args(parser):
# Speculative decoding
parser.add_argument('--spec_decode_algo', type=str, default=None)
parser.add_argument('--spec_decode_nextn', type=int, default=1)
parser.add_argument('--eagle_model_dir', type=str, default=None)
parser.add_argument('--draft_model_dir',
'--eagle_model_dir',
type=str,
default=None)
parser.add_argument('--max_matching_ngram_size', type=int, default=5)
parser.add_argument('--use_one_model', default=False, action='store_true')

@@ -166,8 +169,12 @@ def setup_llm(args):
elif spec_decode_algo == "EAGLE3":
spec_config = EagleDecodingConfig(
max_draft_len=args.spec_decode_nextn,
pytorch_eagle_weights_path=args.eagle_model_dir,
pytorch_weights_path=args.draft_model_dir,
eagle3_one_model=args.use_one_model)
elif spec_decode_algo == "DRAFT_TARGET":
spec_config = DraftTargetDecodingConfig(
max_draft_len=args.spec_decode_nextn,
pytorch_weights_path=args.draft_model_dir)
elif spec_decode_algo == "NGRAM":
spec_config = NGramDecodingConfig(
prompt_lookup_num_tokens=args.spec_decode_nextn,
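A note on the flag rename above: passing two option strings to `add_argument` makes argparse derive the destination from the first long option, so `--eagle_model_dir` keeps working as a backward-compatible alias that fills `args.draft_model_dir`. A minimal standalone sketch:

```python
import argparse

parser = argparse.ArgumentParser()
# The first long option determines the dest; the second is kept as an alias.
parser.add_argument('--draft_model_dir', '--eagle_model_dir', type=str, default=None)

args = parser.parse_args(['--eagle_model_dir', '/models/draft'])
assert args.draft_model_dir == '/models/draft'  # old flag still works
```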
5 changes: 3 additions & 2 deletions tensorrt_llm/_torch/pyexecutor/sampler.py
@@ -307,8 +307,9 @@ def handle_logits(request: LlmRequest, tokens: list[int], count=1):
if request.state != LlmRequestState.GENERATION_COMPLETE:
new_token = new_tokens_list[token_idx]
num_tokens = request.add_new_token(new_token, beam_idx)
self._handle_stop_criteria(request, new_token, num_tokens,
beam_idx)
if self._handle_stop_criteria(request, new_token, num_tokens,
beam_idx):
continue

# Accept draft tokens (if we have any) if and only if they match the new
# token exactly.
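The sampler change above makes the stop-criteria handler's return value meaningful: when a request hits a stop condition on the newly sampled token, the loop now skips the draft-token acceptance that follows instead of extending a finished sequence. A simplified sketch of the resulting control flow (names abridged; not the real classes):

```python
def sample_step(requests, new_tokens_list, handle_stop_criteria, accept_draft_tokens):
    for token_idx, request in enumerate(requests):
        new_token = new_tokens_list[token_idx]
        num_tokens = request.add_new_token(new_token)
        # If EOS / stop words / max length were hit, the request is finished:
        # do not try to accept any pending draft tokens for it.
        if handle_stop_criteria(request, new_token, num_tokens):
            continue
        accept_draft_tokens(request, new_token)
```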
3 changes: 2 additions & 1 deletion tensorrt_llm/_torch/speculative/__init__.py
@@ -1,3 +1,4 @@
from .draft_target import DraftTargetConfig
from .eagle3 import Eagle3Config, Eagle3SpecMetadata
from .interface import SpecConfig, SpecMetadata
from .mtp import MTPConfig, MTPEagleWorker, MTPSpecMetadata, MTPWorker
@@ -9,5 +10,5 @@
"SpecConfig", "SpecMetadata", "MTPConfig", "MTPEagleWorker",
"MTPSpecMetadata", "MTPWorker", "Eagle3Config", "Eagle3SpecMetadata",
"get_spec_metadata", "get_spec_resource_manager", "get_spec_decoder",
"get_num_spec_layers", "get_spec_worker", "NGramConfig"
"get_num_spec_layers", "get_spec_worker", "NGramConfig", "DraftTargetConfig"
]
35 changes: 35 additions & 0 deletions tensorrt_llm/_torch/speculative/draft_target.py
@@ -0,0 +1,35 @@
from dataclasses import dataclass

import torch

from .interface import SpecConfig, SpecMetadata, SpeculativeDecodingMode


@dataclass
class DraftTargetConfig(SpecConfig):
spec_dec_name: str = "DRAFT_TARGET"

def __post_init__(self):
if self.draft_model_path is None:
raise ValueError("Path to Draft weights must be specified.")

self.spec_dec_mode = SpeculativeDecodingMode.from_string(
self.spec_dec_name)
self.num_extra_kv_tokens = 0

def update_from_model_config(self, model_config):
pass

def get_draft_model_prompt(self,
input_tokens: torch.Tensor) -> torch.Tensor:
return input_tokens


@dataclass
class DraftTargetSpecMetadata(SpecMetadata):

def __post_init__(self):
pass

def prepare(self):
pass
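The new `DraftTargetConfig` is intentionally thin: it only validates that a draft checkpoint was given and resolves its `SpeculativeDecodingMode`; there is no model-specific state to carry, which is why `DraftTargetSpecMetadata` is empty. A hedged usage sketch (assuming `max_draft_tokens` and `draft_model_path` are inherited `SpecConfig` fields with `None` defaults, as the constructor call in `llm_args.py` below suggests):

```python
from tensorrt_llm._torch.speculative import DraftTargetConfig

cfg = DraftTargetConfig(max_draft_tokens=4,
                        draft_model_path="/models/Llama-3.2-1B-Instruct")  # placeholder path
assert cfg.spec_dec_mode.is_draft_target()
assert cfg.spec_dec_mode.has_draft_model()  # a separate draft model is loaded

# Omitting the draft checkpoint fails fast in __post_init__
# (assumes draft_model_path defaults to None on SpecConfig):
try:
    DraftTargetConfig(max_draft_tokens=4)
except ValueError as err:
    print(err)  # "Path to Draft weights must be specified."
```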
10 changes: 7 additions & 3 deletions tensorrt_llm/_torch/speculative/interface.py
@@ -16,6 +16,7 @@ class SpeculativeDecodingMode(IntEnum):
EAGLE3 = auto()
EAGLE3_ONE_MODEL = auto()
NGRAM = auto()
DRAFT_TARGET = auto()
NONE = auto()

def is_mtp(self):
@@ -39,6 +40,9 @@ def is_ngram(self):
def is_none(self):
return self == SpeculativeDecodingMode.NONE

def is_draft_target(self):
return self == SpeculativeDecodingMode.DRAFT_TARGET

def without_logits(self):
return self.is_mtp() or self.is_eagle3_one_model()

@@ -49,7 +53,7 @@ def support_overlap_scheduler(self):
return self.is_mtp() or self.is_eagle3_one_model()

def has_draft_model(self):
return self.is_eagle3()
return self.is_eagle3() or self.is_draft_target()

def needs_kv_cache_recompute(self):
"""
@@ -77,8 +81,8 @@ def extend_ctx(self, attention_backend: Type[AttentionBackend]):
"""

# Fixme: only trtllm attention backend supports eagle3 generation-phase kernels on blackwell.
return (self.is_eagle3()
and not (issubclass(attention_backend, TrtllmAttention)
return ((self.is_eagle3() or self.is_draft_target())
                and not (issubclass(attention_backend, TrtllmAttention)
and get_sm_version() == 100)) or self.is_ngram()

@staticmethod
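The interface change routes `DRAFT_TARGET` through the same two-model plumbing as EAGLE3: `has_draft_model()` is now true for both, the overlap scheduler remains unsupported, and `extend_ctx` keeps the same TRTLLM-backend/SM100 carve-out. A small sketch of the resulting mode queries (a sketch only; it assumes `from_string` maps the name to the enum member, as `DraftTargetConfig.__post_init__` above does):

```python
from tensorrt_llm._torch.speculative.interface import SpeculativeDecodingMode

mode = SpeculativeDecodingMode.from_string("DRAFT_TARGET")
assert mode.is_draft_target()
assert mode.has_draft_model()                # like EAGLE3, a separate draft model is used
assert not mode.support_overlap_scheduler()  # why the README example disables the overlap scheduler
assert not mode.without_logits()             # target logits are still needed for verification
```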
6 changes: 6 additions & 0 deletions tensorrt_llm/_torch/speculative/utils.py
@@ -1,3 +1,4 @@
from .draft_target import DraftTargetSpecMetadata
from .eagle3 import (Eagle3OneModelDecoder, Eagle3OneModelSpecMetadata,
Eagle3OneModelWorker, Eagle3ResourceManager, Eagle3Sampler,
Eagle3SpecMetadata)
@@ -36,6 +37,11 @@ def get_spec_metadata(spec_config,
num_layers=spec_config.num_layers,
hidden_size=spec_config.hidden_size,
max_num_tokens=max_num_tokens)
elif spec_config.spec_dec_mode.is_draft_target():
return DraftTargetSpecMetadata(
max_draft_tokens=spec_config.max_draft_tokens,
spec_dec_mode=spec_config.spec_dec_mode,
max_num_requests=max_num_requests)
else:
return None

14 changes: 9 additions & 5 deletions tensorrt_llm/llmapi/__init__.py
@@ -3,13 +3,16 @@
from ..sampling_params import GuidedDecodingParams, SamplingParams
from .build_cache import BuildCacheConfig
from .llm import LLM, RequestOutput
# yapf: disable
from .llm_args import (BatchingType, CacheTransceiverConfig, CalibConfig,
CapacitySchedulerPolicy, ContextChunkingPolicy,
DynamicBatchConfig, EagleDecodingConfig,
ExtendedRuntimePerfKnobConfig, KvCacheConfig, LlmArgs,
LookaheadDecodingConfig, MedusaDecodingConfig,
MTPDecodingConfig, NGramDecodingConfig, SchedulerConfig,
TorchCompileConfig, TorchLlmArgs, TrtLlmArgs)
DraftTargetDecodingConfig, DynamicBatchConfig,
EagleDecodingConfig, ExtendedRuntimePerfKnobConfig,
KvCacheConfig, LlmArgs, LookaheadDecodingConfig,
MedusaDecodingConfig, MTPDecodingConfig,
NGramDecodingConfig, SchedulerConfig, TorchCompileConfig,
TorchLlmArgs, TrtLlmArgs)
# yapf: enable
from .llm_utils import (BuildConfig, KvCacheRetentionConfig, QuantAlgo,
QuantConfig)
from .mpi_session import MpiCommSession
@@ -43,6 +46,7 @@
'CacheTransceiverConfig',
'NGramDecodingConfig',
'TorchCompileConfig',
'DraftTargetDecodingConfig',
'LlmArgs',
'TorchLlmArgs',
'TrtLlmArgs',
34 changes: 28 additions & 6 deletions tensorrt_llm/llmapi/llm_args.py
@@ -207,6 +207,7 @@ def from_dict(cls, data: dict):
"Eagle": EagleDecodingConfig,
"Lookahead": LookaheadDecodingConfig,
"NGram": NGramDecodingConfig,
"DraftTarget": DraftTargetDecodingConfig,
}

config_class = config_classes.get(decoding_type)
@@ -238,7 +239,7 @@ class EagleDecodingConfig(DecodingBaseConfig):
dynamic_tree_max_topK: Optional[int] = None
num_eagle_layers: Optional[int] = None
max_non_leaves_per_layer: Optional[int] = None
pytorch_eagle_weights_path: Optional[str] = None
pytorch_weights_path: Optional[str] = None
eagle3_one_model: Optional[bool] = True

@classmethod
@@ -282,6 +283,16 @@ def from_dict(cls, data: dict):
decoding_type: ClassVar[str] = "NGram"


class DraftTargetDecodingConfig(DecodingBaseConfig):
pytorch_weights_path: Optional[str] = None

@classmethod
def from_dict(cls, data: dict):
return cls(**data)

decoding_type: ClassVar[str] = "DraftTarget"


class MTPDecodingConfig(DecodingBaseConfig):
num_nextn_predict_layers: Optional[int] = 1
use_relaxed_acceptance_for_thinking: Optional[bool] = False
@@ -896,10 +907,11 @@ class BaseLlmArgs(BaseModel):
default=None, description="Cache transceiver config.")

# Speculative decoding parameters
speculative_config: Optional[Union[
LookaheadDecodingConfig, MedusaDecodingConfig, EagleDecodingConfig,
MTPDecodingConfig, NGramDecodingConfig]] = Field(
default=None, description="Speculative decoding config.")
speculative_config: Optional[
Union[LookaheadDecodingConfig, MedusaDecodingConfig,
EagleDecodingConfig, MTPDecodingConfig, NGramDecodingConfig,
DraftTargetDecodingConfig]] = Field(
default=None, description="Speculative decoding config.")

batching_type: Optional[BatchingType] = Field(default=None,
description="Batching type.")
@@ -1302,7 +1314,7 @@ def validate_speculative_config(self):
self.speculative_config = Eagle3Config(
max_draft_tokens=self.speculative_config.max_draft_len,
draft_model_path=self.speculative_config.
pytorch_eagle_weights_path,
pytorch_weights_path,
eagle3_one_model=self.speculative_config.
eagle3_one_model)
elif isinstance(self.speculative_config, NGramDecodingConfig):
@@ -1320,6 +1332,16 @@
is_use_oldest=self.speculative_config.is_use_oldest,
is_public_pool=self.speculative_config.is_public_pool,
)
elif isinstance(self.speculative_config, DraftTargetDecodingConfig):
self.build_config.speculative_decoding_mode = SpeculativeDecodingMode.DRAFT_TOKENS_EXTERNAL
assert self.backend == 'pytorch'
assert self.speculative_config.max_draft_len > 0
self.build_config.max_draft_len = self.speculative_config.max_draft_len
from tensorrt_llm._torch.speculative import DraftTargetConfig
self.speculative_config = DraftTargetConfig(
max_draft_tokens=self.speculative_config.max_draft_len,
draft_model_path=self.speculative_config.
pytorch_weights_path)
elif isinstance(self.speculative_config, MTPDecodingConfig):
from tensorrt_llm._torch.speculative import MTPConfig
self.speculative_config = MTPConfig(
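Tying the `llm_args.py` pieces together: `decoding_type: "DraftTarget"` now resolves to `DraftTargetDecodingConfig` in the `config_classes` map, and `validate_speculative_config` converts that user-facing config into the internal `DraftTargetConfig` used by the PyTorch backend. A hedged sketch of that conversion (field names taken from the hunks above; the path is a placeholder):

```python
from tensorrt_llm.llmapi import DraftTargetDecodingConfig
from tensorrt_llm._torch.speculative import DraftTargetConfig

user_cfg = DraftTargetDecodingConfig.from_dict({
    "max_draft_len": 5,
    "pytorch_weights_path": "/models/Llama-3.2-1B-Instruct",  # placeholder
})

# What validate_speculative_config does for the pytorch backend:
internal_cfg = DraftTargetConfig(
    max_draft_tokens=user_cfg.max_draft_len,
    draft_model_path=user_cfg.pytorch_weights_path)
```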
4 changes: 3 additions & 1 deletion tensorrt_llm/llmapi/llm_utils.py
@@ -30,7 +30,8 @@
from ..module import Module
from .build_cache import (BuildCache, BuildCacheConfig, CachedStage,
get_build_cache_config_from_env)
from .llm_args import (CalibConfig, EagleDecodingConfig, KvCacheConfig, LlmArgs,
from .llm_args import (CalibConfig, DraftTargetDecodingConfig,
EagleDecodingConfig, KvCacheConfig, LlmArgs,
LookaheadDecodingConfig, MedusaDecodingConfig,
MTPDecodingConfig, NGramDecodingConfig, _ModelFormatKind,
_ModelWrapper, _ParallelConfig, get_model_format,
@@ -871,6 +872,7 @@ class LlmBuildStats:
'MedusaDecodingConfig',
'MTPDecodingConfig',
'NGramDecodingConfig',
'DraftTargetDecodingConfig',
'ContextChunkingPolicy',
'CapacitySchedulerPolicy',
'BuildConfig',
4 changes: 2 additions & 2 deletions tests/integration/defs/accuracy/test_llm_api_pytorch.py
@@ -231,8 +231,8 @@ def test_eagle3(self):
target_model_dir = f"{llm_models_root()}/llama-3.1-model/Llama-3.1-8B-Instruct"

draft_len = 4
spec_config = EagleDecodingConfig(
max_draft_len=draft_len, pytorch_eagle_weights_path=eagle_model_dir)
spec_config = EagleDecodingConfig(max_draft_len=draft_len,
pytorch_weights_path=eagle_model_dir)

llm = LLM(model=target_model_dir,
**pytorch_config,
73 changes: 73 additions & 0 deletions tests/unittest/_torch/speculative/test_draft_target.py
@@ -0,0 +1,73 @@
import os
import sys
import unittest

import pytest
import torch

from tensorrt_llm import SamplingParams
from tensorrt_llm._torch import LLM
from tensorrt_llm.llmapi import DraftTargetDecodingConfig, KvCacheConfig

sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from utils.llm_data import llm_models_root


@pytest.mark.parametrize("use_cuda_graph,attn_backend",
[[False, "TRTLLM"], [True, "TRTLLM"]])
def test_llama_draft_target(use_cuda_graph: bool, attn_backend: str):
total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
if total_mem_gb < 60:
pytest.skip("Not enough memory to load target model")

models_path = llm_models_root()

kv_cache_config = KvCacheConfig(enable_block_reuse=False, max_tokens=2080)

sampling_params = SamplingParams(
max_tokens=32,
temperature=0,
)
max_batch_size = 1

target_model_dir = f"{models_path}/llama-3.1-model/Llama-3.1-8B-Instruct"
draft_model_dir = f"{models_path}/llama-3.1-model/Llama-3.1-8B-Instruct"

draft_len = 4
spec_config = DraftTargetDecodingConfig(
max_draft_len=draft_len, pytorch_weights_path=draft_model_dir)
llm_spec = LLM(model=target_model_dir,
max_batch_size=max_batch_size,
disable_overlap_scheduler=True,
use_cuda_graph=use_cuda_graph,
attn_backend=attn_backend,
cuda_graph_batch_sizes=[1],
kv_cache_config=kv_cache_config,
speculative_config=spec_config)

prompts = [
"The capital of France is", "The president of the United States is"
]
results_spec = llm_spec.generate(prompts, sampling_params)
generated_text_spec = [result.outputs[0].text for result in results_spec]
llm_spec.shutdown()

llm_ref = LLM(model=target_model_dir,
max_batch_size=max_batch_size,
disable_overlap_scheduler=True,
use_cuda_graph=use_cuda_graph,
attn_backend=attn_backend,
cuda_graph_batch_sizes=[1],
kv_cache_config=kv_cache_config)

results_ref = llm_ref.generate(prompts, sampling_params)
generated_text_ref = [result.outputs[0].text for result in results_ref]
llm_ref.shutdown()

for text_spec, text_ref in zip(generated_text_spec, generated_text_ref):
# The spec decode algorithm currently guarantees identical results
assert text_spec == text_ref


if __name__ == "__main__":
unittest.main()
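Note on the test design: the draft and target both point at the same Llama-3.1-8B checkpoint and decoding is greedy (`temperature=0`), so essentially every drafted token should be accepted; and since the sampler only accepts draft tokens that match the target's own next token exactly, the speculative run is expected to reproduce the reference run token-for-token, which is what the final equality assertion checks.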
2 changes: 1 addition & 1 deletion tests/unittest/_torch/speculative/test_eagle3.py
@@ -39,7 +39,7 @@ def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str):
draft_len = 4
spec_config = EagleDecodingConfig(
max_draft_len=draft_len,
pytorch_eagle_weights_path=eagle_model_dir,
pytorch_weights_path=eagle_model_dir,
# Llama 3 does not support one model eagle.
eagle3_one_model=False)

@@ -74,7 +74,8 @@ methods:
# Speculative decoding
speculative_config:
annotation: Union[tensorrt_llm.llmapi.llm_args.LookaheadDecodingConfig, tensorrt_llm.llmapi.llm_utils.MedusaDecodingConfig,
tensorrt_llm.llmapi.llm_utils.EagleDecodingConfig, tensorrt_llm.llmapi.MTPDecodingConfig, tensorrt_llm.llmapi.llm_args.NGramDecodingConfig, NoneType]
tensorrt_llm.llmapi.llm_utils.EagleDecodingConfig, tensorrt_llm.llmapi.MTPDecodingConfig, tensorrt_llm.llmapi.llm_args.NGramDecodingConfig,
tensorrt_llm.llmapi.llm_args.DraftTargetDecodingConfig, NoneType]
default: null
# generation constraints
max_batch_size: