4 changes: 4 additions & 0 deletions tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -1100,6 +1100,10 @@ def _executor_loop(self):

sample_state = self._sample_async(scheduled_batch,
batch_outputs)
if self.drafter is not None:
self.drafter.run_drafter_post(scheduled_batch,
self.resource_manager,
self.is_warmup)

self._update_request_states(scheduled_batch)
self._update_requests(sample_state, self.resource_manager)
2 changes: 2 additions & 0 deletions tensorrt_llm/_torch/speculative/__init__.py
@@ -3,6 +3,7 @@
from .interface import SpecMetadata
from .mtp import MTPEagleWorker, MTPSpecMetadata, MTPWorker
from .ngram import NGramDrafter, NGramPoolManager
from .save_hidden_state import SaveHiddenStatesDrafter
from .spec_tree_manager import SpecTreeManager
from .utils import (get_num_extra_kv_tokens, get_num_spec_layers,
get_spec_decoder, get_spec_drafter, get_spec_metadata,
@@ -16,6 +17,7 @@
"MTPWorker",
"NGramDrafter",
"NGramPoolManager",
"SaveHiddenStatesDrafter",
"SpecMetadata",
"get_num_extra_kv_tokens",
"get_num_spec_layers",
12 changes: 12 additions & 0 deletions tensorrt_llm/_torch/speculative/drafter.py
@@ -67,3 +67,15 @@ def pad_draft_tokens_for_cuda_graph(
num_draft_tokens = get_draft_token_length(req)
req.py_draft_tokens.extend(
0 for _ in range(max_draft_tokens - num_draft_tokens))

def run_drafter_post(
self,
scheduled_requests: ScheduledRequests,
resource_manager: Optional[ResourceManager] = None,
is_warmup: bool = False,
) -> None:
"""
If the draft forward pass needs to run directly after the target model forward pass,
this method can be overridden to do so.
Used by SaveHiddenStatesDrafter (to ensure correct input_ids).
"""
4 changes: 4 additions & 0 deletions tensorrt_llm/_torch/speculative/eagle3.py
@@ -126,6 +126,10 @@ def __post_init__(self):
self.num_layers - 4)
else:
self.layers_to_capture = sorted(list(self.layers_to_capture))
if self.layers_to_capture[0] == -1:
self.layers_to_capture = self.layers_to_capture[1:] + [
self.layers_to_capture.pop(0)
]
self.num_capture_layers = len(self.layers_to_capture)

# Initialize to 0 to avoid reading uninitialized memory during warmup
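The added branch above keeps the -1 sentinel (post-norm last hidden state) at the end of the capture order rather than at the front, where sorted() places it. An illustrative run with made-up layer indices:

layers_to_capture = sorted([-1, 1, 15, 28])          # [-1, 1, 15, 28]
if layers_to_capture[0] == -1:
    layers_to_capture = layers_to_capture[1:] + [layers_to_capture.pop(0)]
print(layers_to_capture)                             # [1, 15, 28, -1]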
9 changes: 7 additions & 2 deletions tensorrt_llm/_torch/speculative/interface.py
@@ -19,6 +19,7 @@ class SpeculativeDecodingMode(IntEnum):
NGRAM = auto()
DRAFT_TARGET = auto()
USER_PROVIDED = auto()
SAVE_HIDDEN_STATES = auto()
NONE = auto()
AUTO = auto()

@@ -55,6 +56,9 @@ def is_none(self):
def is_draft_target(self):
return self == SpeculativeDecodingMode.DRAFT_TARGET

def is_save_hidden_states(self):
return self == SpeculativeDecodingMode.SAVE_HIDDEN_STATES

def without_logits(self):
return self.is_mtp_one_model() or self.is_eagle3_one_model()

@@ -95,8 +99,9 @@ def has_spec_decoder(self):
) or self.is_eagle3_one_model()

def has_spec_drafter(self):
return self.is_eagle3() or self.is_draft_target() or self.is_ngram(
) or self.is_user_provided() or self.is_mtp_eagle()
return self.is_eagle3(
) or self.is_draft_target() or self.is_ngram() or self.is_user_provided(
) or self.is_mtp_eagle() or self.is_save_hidden_states()

def extend_ctx(self, attention_backend: Type[AttentionBackend]):
"""
99 changes: 99 additions & 0 deletions tensorrt_llm/_torch/speculative/save_hidden_state.py
@@ -0,0 +1,99 @@
import os
from typing import Optional

import torch

from tensorrt_llm._utils import local_mpi_rank

from ..pyexecutor.llm_request import LlmRequest
from ..pyexecutor.resource_manager import ResourceManager
from ..pyexecutor.scheduler import ScheduledRequests
from .drafter import Drafter


class SaveHiddenStatesDrafter(Drafter):

def __init__(
self,
spec_config: "SaveHiddenStatesDecodingConfig",
spec_resource_manager,
):
super().__init__(spec_config.max_concurrency)
self.spec_config = spec_config
self.max_draft_len = spec_config.max_draft_len
self._iter = 1
self._output_directory = spec_config.output_directory
self._file_prefix = spec_config.file_prefix
self._write_interval = spec_config.write_interval
self._saved_state = []
self.spec_resource_manager = spec_resource_manager
os.makedirs(self._output_directory, exist_ok=True)

def _process_request(self, request: LlmRequest, resource_manager) -> None:
out_dict = {}
if local_mpi_rank() == 0:
input_ids = torch.tensor(list(request.get_tokens(0)),
dtype=torch.long,
device='cpu')
hidden_size = resource_manager.hidden_size
num_tokens = input_ids.shape[0]
hidden_states = resource_manager.hidden_states[:num_tokens,
-hidden_size:].cpu(
).clone()

out_dict = {
"id": self._iter,
"input_ids": input_ids,
"hidden_state": hidden_states,
}
if len(self.spec_config.eagle3_layers_to_capture) > 1:
if self.spec_config._last_hidden_in_save:
out_dict[
"aux_hidden_states"] = resource_manager.hidden_states[:num_tokens, :].cpu(
).clone()
else:
out_dict[
"aux_hidden_states"] = resource_manager.hidden_states[:
num_tokens, :
-hidden_size].cpu(
).clone(
)

self._saved_state.append(out_dict)

def _write_to_file(self) -> None:
if local_mpi_rank() == 0:
output_path = os.path.join(self._output_directory,
f"{self._file_prefix}_{self._iter}.pt")
torch.save(self._saved_state, output_path)
self._saved_state = []

def prepare_draft_tokens(
self,
scheduled_requests: ScheduledRequests,
resource_manager: Optional[ResourceManager] = None,
) -> None:
for request in sorted(
scheduled_requests.context_requests,
key=lambda r:
(r.py_batch_idx is None, r.py_batch_idx or r.request_id),
):
request.py_max_new_tokens = 1

def run_drafter_post(
self,
scheduled_requests: ScheduledRequests,
resource_manager: Optional[ResourceManager] = None,
is_warmup: bool = False,
) -> None:
if is_warmup:
return
for request in sorted(
scheduled_requests.context_requests,
key=lambda r:
(r.py_batch_idx is None, r.py_batch_idx or r.request_id),
):
self._process_request(request, self.spec_resource_manager)
if self._iter % self._write_interval == 0:
self._write_to_file()
self._iter += 1
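
A hedged sketch of reading the dumped files: the dict keys and the {prefix}_{iter}.pt naming follow the code above, while the directory and prefix values are placeholders.

import glob
import torch

# Assumes output_directory="saved_states" and file_prefix="data" were configured.
for path in sorted(glob.glob("saved_states/data_*.pt")):
    records = torch.load(path)                   # list of dicts, one per processed request
    for rec in records:
        input_ids = rec["input_ids"]             # CPU LongTensor, shape [num_tokens]
        hidden = rec["hidden_state"]             # [num_tokens, hidden_size] slice
        aux = rec.get("aux_hidden_states")       # present when more than one layer is captured
        print(rec["id"], input_ids.shape, hidden.shape,
              None if aux is None else aux.shape)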
32 changes: 32 additions & 0 deletions tensorrt_llm/_torch/speculative/utils.py
@@ -11,6 +11,7 @@
from .mtp import (MTPEagleWorker, MTPHiddenStatesManager, MTPSampler,
MTPSpecMetadata, MTPWorker)
from .ngram import NGramDrafter, NGramPoolManager
from .save_hidden_state import SaveHiddenStatesDrafter


def get_spec_metadata(spec_config,
@@ -55,6 +56,25 @@ def get_spec_metadata(spec_config,
max_num_tokens=max_num_tokens,
layers_to_capture=spec_config.eagle3_layers_to_capture,
)
if spec_config.spec_dec_mode.is_save_hidden_states():
if spec_config.eagle3_layers_to_capture is None:
spec_config.eagle3_layers_to_capture = {
1, model_config.num_hidden_layers // 2 - 1,
model_config.num_hidden_layers - 4, -1
}
return Eagle3SpecMetadata(
max_draft_len=spec_config.max_draft_len,
spec_dec_mode=spec_config.spec_dec_mode,
max_num_requests=max_num_requests,
num_layers=model_config.num_hidden_layers,
hidden_size=model_config.hidden_size,
max_num_tokens=max_num_tokens,
dtype=model_config.torch_dtype,
is_draft_model=is_draft_model,
eagle3_resource_manager=spec_resource_manager,
layers_to_capture=spec_config.eagle3_layers_to_capture,
max_total_draft_tokens=1,
)
if spec_config.spec_dec_mode.is_draft_target() or \
spec_config.spec_dec_mode.is_ngram() or \
spec_config.spec_dec_mode.is_user_provided():
@@ -102,6 +122,15 @@ def get_spec_resource_manager(model_engine, draft_model_engine=None):
max_seq_len,
max_num_tokens,
)
if spec_dec_mode.is_save_hidden_states():
return Eagle3ResourceManager(
spec_config,
model_engine.model.config.torch_dtype,
model_config.hidden_size,
max_num_requests,
max_seq_len,
max_num_tokens,
)
if spec_dec_mode.is_ngram():
return NGramPoolManager(spec_config, max_num_requests)
if spec_dec_mode.is_user_provided():
@@ -151,6 +180,9 @@ def get_spec_drafter(model_engine,
if spec_config.spec_dec_mode.is_ngram():
return NGramDrafter(spec_config, spec_resource_manager)

if spec_config.spec_dec_mode.is_save_hidden_states():
return SaveHiddenStatesDrafter(spec_config, spec_resource_manager)

return None


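For a concrete (hypothetical) 32-layer target model, the default eagle3_layers_to_capture chosen in get_spec_metadata above works out as follows:

num_hidden_layers = 32                                     # illustrative value
eagle3_layers_to_capture = {
    1, num_hidden_layers // 2 - 1, num_hidden_layers - 4, -1
}
print(sorted(eagle3_layers_to_capture))                    # [-1, 1, 15, 28]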
4 changes: 3 additions & 1 deletion tensorrt_llm/llmapi/__init__.py
@@ -11,7 +11,8 @@
DynamicBatchConfig, EagleDecodingConfig,
ExtendedRuntimePerfKnobConfig, KvCacheConfig, LlmArgs,
LookaheadDecodingConfig, MedusaDecodingConfig, MoeConfig,
MTPDecodingConfig, NGramDecodingConfig, SchedulerConfig,
MTPDecodingConfig, NGramDecodingConfig,
SaveHiddenStatesDecodingConfig, SchedulerConfig,
TorchCompileConfig, TorchLlmArgs, TrtLlmArgs,
UserProvidedDecodingConfig)
from .llm_utils import (BuildConfig, KvCacheRetentionConfig, QuantAlgo,
@@ -59,4 +60,5 @@
'AutoDecodingConfig',
'AttentionDpConfig',
'LoRARequest',
'SaveHiddenStatesDecodingConfig',
]
62 changes: 62 additions & 0 deletions tensorrt_llm/llmapi/llm_args.py
@@ -374,6 +374,7 @@ def from_dict(cls, data: dict):
"Lookahead": LookaheadDecodingConfig,
"NGram": NGramDecodingConfig,
"DraftTarget": DraftTargetDecodingConfig,
"SaveState": SaveHiddenStatesDecodingConfig,
"UserProvided": UserProvidedDecodingConfig,
"AUTO": AutoDecodingConfig,
}
@@ -556,6 +557,52 @@ def num_capture_layers(self) -> int:
return 3


class SaveHiddenStatesDecodingConfig(DecodingBaseConfig):
output_directory: str
write_interval: int = 20
file_prefix: str = "data"
eagle3_layers_to_capture: Optional[Set[int]] = None

max_total_draft_tokens: Optional[int] = Field(default=1, init=False)
eagle_choices: Optional[List[List[int]]] = Field(default=None, init=False)

def model_post_init(self, __context):
self._last_hidden_in_save = True
if self.eagle3_layers_to_capture is None:
self._last_hidden_in_save = False
elif -1 not in self.eagle3_layers_to_capture:
self._last_hidden_in_save = False
self.eagle3_layers_to_capture.add(-1)

@classmethod
def from_dict(cls, data: dict):
return cls(**data)

decoding_type: ClassVar[str] = "SaveState"

def validate(self) -> None:
if self.output_directory is None or not self.eagle3_layers_to_capture:
raise ValueError(
"Save directory and layers to capture must be provided")

@functools.cached_property
def spec_dec_mode(self):
from tensorrt_llm._torch.speculative.interface import \
SpeculativeDecodingMode as TorchSpeculativeDecodingMode
return TorchSpeculativeDecodingMode.SAVE_HIDDEN_STATES

@functools.cached_property
def num_capture_layers(self):
"""
Returns the number of target-model layers to capture.
If eagle3_layers_to_capture is not None, return the size of that set.
Otherwise, assume the Eagle3 base set and return 3 + 1 (for the post-norm last hidden state).
"""
if self.eagle3_layers_to_capture is None:
return 4
return len(self.eagle3_layers_to_capture)


class UserProvidedDecodingConfig(DecodingBaseConfig):
# Cannot use real type annotations due to circular imports
drafter: object # Type is Drafter
@@ -1044,6 +1091,7 @@ def supports_backend(self, backend: str) -> bool:
MTPDecodingConfig,
NGramDecodingConfig,
UserProvidedDecodingConfig,
SaveHiddenStatesDecodingConfig,
AutoDecodingConfig,
]]

@@ -1863,6 +1911,20 @@ def validate_speculative_config(self):
self.build_config.speculative_decoding_mode = SpeculativeDecodingMode.AUTO
self.build_config.max_draft_len = self.speculative_config.max_draft_len

elif isinstance(self.speculative_config,
SaveHiddenStatesDecodingConfig):
assert self.backend in ['pytorch']
logger.warning(
"SaveHiddenStatesDecodingConfig is active, setting max_batch_size to 1, disabling overlap scheduler, and setting cuda_graph_config to None"
)
self.build_config.max_batch_size = 1
self.max_batch_size = 1
self.disable_overlap_scheduler = True
self.cuda_graph_config = None
self.build_config.speculative_decoding_mode = SpeculativeDecodingMode.SAVE_HIDDEN_STATES
self.build_config.max_draft_len = 1
self.speculative_config.max_draft_len = 1

else:
raise ValueError(
f"Unrecognized speculative config type {type(self.speculative_config)}"
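A hedged end-to-end usage sketch of the new config: the field names mirror SaveHiddenStatesDecodingConfig above, while the model path, output directory, and the speculative_config keyword on LLM are assumptions based on how the other decoding configs are wired in.

from tensorrt_llm import LLM
from tensorrt_llm.llmapi import SaveHiddenStatesDecodingConfig

spec_config = SaveHiddenStatesDecodingConfig(
    output_directory="saved_states",           # placeholder path
    write_interval=20,                         # flush to disk every 20 requests
    file_prefix="data",
    eagle3_layers_to_capture={1, 15, 28, -1},
)

# When this config is active, validate_speculative_config forces max_batch_size to 1
# and disables the overlap scheduler and CUDA graphs, as shown above.
llm = LLM(model="/path/to/target-model", speculative_config=spec_config)
outputs = llm.generate(["The capital of France is"])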
3 changes: 3 additions & 0 deletions tensorrt_llm/models/modeling_utils.py
@@ -98,6 +98,7 @@ class SpeculativeDecodingMode(IntFlag):
EAGLE = auto()
NGRAM = auto()
USER_PROVIDED = auto()
SAVE_HIDDEN_STATES = auto()
AUTO = auto()

@staticmethod
@@ -120,6 +121,8 @@ def from_arguments(args: argparse.Namespace):
return SpeculativeDecodingMode.USER_PROVIDED
elif args.speculative_decoding_mode == "auto":
return SpeculativeDecodingMode.AUTO
elif args.speculative_decoding_mode == "save_hidden_states":
return SpeculativeDecodingMode.SAVE_HIDDEN_STATES
else:
assert False, "Unknown speculative_decoding_mode " + args.speculative_decoding_mode
