From 17c59c369a8f180e95386d7d434136b68b95bfbb Mon Sep 17 00:00:00 2001 From: Izzy Putterman Date: Mon, 18 Aug 2025 13:15:43 -0700 Subject: [PATCH 1/3] Save state first pass Signed-off-by: Izzy Putterman small changes Signed-off-by: Izzy Putterman fixes inflight Signed-off-by: Izzy Putterman maybe functional Signed-off-by: Izzy Putterman cleanup + small design changes Signed-off-by: Izzy Putterman drop unused lines Signed-off-by: Izzy Putterman --- .../_torch/models/modeling_speculative.py | 1 - tensorrt_llm/_torch/pyexecutor/py_executor.py | 4 + tensorrt_llm/_torch/speculative/__init__.py | 2 + tensorrt_llm/_torch/speculative/drafter.py | 12 +++ tensorrt_llm/_torch/speculative/interface.py | 6 +- .../_torch/speculative/model_drafter.py | 17 +++ .../_torch/speculative/save_hidden_state.py | 101 ++++++++++++++++++ tensorrt_llm/_torch/speculative/utils.py | 26 +++++ tensorrt_llm/llmapi/__init__.py | 4 +- tensorrt_llm/llmapi/llm_args.py | 60 +++++++++++ tensorrt_llm/models/modeling_utils.py | 3 + 11 files changed, 233 insertions(+), 3 deletions(-) create mode 100644 tensorrt_llm/_torch/speculative/save_hidden_state.py diff --git a/tensorrt_llm/_torch/models/modeling_speculative.py b/tensorrt_llm/_torch/models/modeling_speculative.py index 6eb989af33f..a988efb329a 100755 --- a/tensorrt_llm/_torch/models/modeling_speculative.py +++ b/tensorrt_llm/_torch/models/modeling_speculative.py @@ -539,7 +539,6 @@ def forward( spec_metadata=spec_metadata, **kwargs, ) - if spec_metadata is not None and spec_metadata.is_layer_capture( self.layer_idx): spec_metadata.maybe_capture_hidden_states(self.layer_idx, diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index aa0902484a5..89fdcc6d850 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -1100,6 +1100,10 @@ def _executor_loop(self): sample_state = self._sample_async(scheduled_batch, batch_outputs) + if self.drafter is not None: + self.drafter.run_drafter_post(scheduled_batch, + self.resource_manager, + self.is_warmup) self._update_request_states(scheduled_batch) self._update_requests(sample_state, self.resource_manager) diff --git a/tensorrt_llm/_torch/speculative/__init__.py b/tensorrt_llm/_torch/speculative/__init__.py index 8f6e0254faa..31ed71f76f3 100644 --- a/tensorrt_llm/_torch/speculative/__init__.py +++ b/tensorrt_llm/_torch/speculative/__init__.py @@ -3,6 +3,7 @@ from .interface import SpecMetadata from .mtp import MTPEagleWorker, MTPSpecMetadata, MTPWorker from .ngram import NGramDrafter, NGramPoolManager +from .save_hidden_state import SaveHiddenStatesDrafter from .spec_tree_manager import SpecTreeManager from .utils import (get_num_extra_kv_tokens, get_num_spec_layers, get_spec_decoder, get_spec_drafter, get_spec_metadata, @@ -16,6 +17,7 @@ "MTPWorker", "NGramDrafter", "NGramPoolManager", + "SaveHiddenStatesDrafter", "SpecMetadata", "get_num_extra_kv_tokens", "get_num_spec_layers", diff --git a/tensorrt_llm/_torch/speculative/drafter.py b/tensorrt_llm/_torch/speculative/drafter.py index 74384206740..485934f7b5c 100644 --- a/tensorrt_llm/_torch/speculative/drafter.py +++ b/tensorrt_llm/_torch/speculative/drafter.py @@ -67,3 +67,15 @@ def pad_draft_tokens_for_cuda_graph( num_draft_tokens = get_draft_token_length(req) req.py_draft_tokens.extend( 0 for _ in range(max_draft_tokens - num_draft_tokens)) + + def run_drafter_post( + self, + scheduled_requests: ScheduledRequests, + resource_manager: Optional[ResourceManager] = None, + 
is_warmup: bool = False, + ) -> None: + """ + If draft forward needs to be run directly after the target model forward, + this method can be overridden to do that. + Used in SaveHiddenStatesDrafter (to ensure correct input_ids) + """ diff --git a/tensorrt_llm/_torch/speculative/interface.py b/tensorrt_llm/_torch/speculative/interface.py index 16522e98320..9e88d6986e9 100644 --- a/tensorrt_llm/_torch/speculative/interface.py +++ b/tensorrt_llm/_torch/speculative/interface.py @@ -19,6 +19,7 @@ class SpeculativeDecodingMode(IntEnum): NGRAM = auto() DRAFT_TARGET = auto() USER_PROVIDED = auto() + SAVE_HIDDEN_STATES = auto() NONE = auto() AUTO = auto() @@ -55,6 +56,9 @@ def is_none(self): def is_draft_target(self): return self == SpeculativeDecodingMode.DRAFT_TARGET + def is_save_hidden_states(self): + return self == SpeculativeDecodingMode.SAVE_HIDDEN_STATES + def without_logits(self): return self.is_mtp_one_model() or self.is_eagle3_one_model() @@ -96,7 +100,7 @@ def has_spec_decoder(self): def has_spec_drafter(self): return self.is_eagle3() or self.is_draft_target() or self.is_ngram( - ) or self.is_user_provided() or self.is_mtp_eagle() + ) or self.is_user_provided() or self.is_mtp_eagle() or self.is_save_hidden_states() def extend_ctx(self, attention_backend: Type[AttentionBackend]): """ diff --git a/tensorrt_llm/_torch/speculative/model_drafter.py b/tensorrt_llm/_torch/speculative/model_drafter.py index f6082ac3264..8980fba9a20 100644 --- a/tensorrt_llm/_torch/speculative/model_drafter.py +++ b/tensorrt_llm/_torch/speculative/model_drafter.py @@ -794,3 +794,20 @@ def prepare_draft_tokens( error_msg = str(e) logger.error(f"Encountered an error in decode: {error_msg}") raise e + + @nvtx_range("prepare_draft_tokens_post") + def prepare_draft_tokens_post( + self, + scheduled_requests: ScheduledRequests, + resource_manager: Optional[ResourceManager] = None, + is_warmup: bool = False, + ) -> None: + """ + If draft forward needs to be run directly after the target model forward, + this method can be overridden to do that. 
+ Used in SaveHiddenStatesDrafter (to ensure correct input_ids) + + Args: + scheduled_requests: The scheduled requests for this iteration + resource_manager: The resource manager for this iteration + """ diff --git a/tensorrt_llm/_torch/speculative/save_hidden_state.py b/tensorrt_llm/_torch/speculative/save_hidden_state.py new file mode 100644 index 00000000000..fac5919b1ce --- /dev/null +++ b/tensorrt_llm/_torch/speculative/save_hidden_state.py @@ -0,0 +1,101 @@ +import os +from typing import Optional + +import torch + +from tensorrt_llm._utils import local_mpi_rank + +from ..pyexecutor.llm_request import LlmRequest +from ..pyexecutor.resource_manager import ResourceManager +from ..pyexecutor.scheduler import ScheduledRequests +from .drafter import Drafter + + +class SaveHiddenStatesDrafter(Drafter): + + def __init__( + self, + spec_config: "SaveHiddenStatesDecodingConfig", + spec_resource_manager: SaveHiddenStatesResourceManager, + ): + super().__init__(spec_config.max_concurrency) + self.spec_config = spec_config + self.max_draft_len = spec_config.max_draft_len + self._iter = 1 + self._output_directory = spec_config.output_directory + self._file_prefix = spec_config.file_prefix + self._write_interval = spec_config.write_interval + self._saved_state = [] + self.spec_resource_manager = spec_resource_manager + os.makedirs(self._output_directory, exist_ok=True) + + def _process_request( + self, request: LlmRequest, + resource_manager: SaveHiddenStatesResourceManager) -> None: + out_dict = {} + if local_mpi_rank() == 0: + input_ids = torch.tensor(list(request.get_tokens(0)), + dtype=torch.long, + device='cpu') + hidden_size = resource_manager.hidden_size + num_tokens = input_ids.shape[0] + hidden_states = resource_manager.hidden_states[:num_tokens, + -hidden_size:].cpu( + ).clone() + + out_dict = { + "id": self._iter, + "input_ids": input_ids, + "hidden_state": hidden_states, + } + if len(self.spec_config.eagle3_layers_to_capture) > 1: + if self.spec_config._last_hidden_in_save: + out_dict[ + "hidden_state_features"] = resource_manager.hidden_states[:num_tokens, :].cpu( + ).clone() + else: + out_dict[ + "hidden_state_features"] = resource_manager.hidden_states[: + num_tokens, : + -hidden_size].cpu( + ).clone( + ) + + self._saved_state.append(out_dict) + + def _write_to_file(self) -> None: + if local_mpi_rank() == 0: + output_path = os.path.join(self._output_directory, + f"{self._file_prefix}_{self._iter}.pt") + torch.save(self._saved_state, output_path) + self._saved_state = [] + + def prepare_draft_tokens( + self, + scheduled_requests: ScheduledRequests, + resource_manager: Optional[ResourceManager] = None, + ) -> None: + for request in sorted( + scheduled_requests.context_requests, + key=lambda r: + (r.py_batch_idx is None, r.py_batch_idx or r.request_id), + ): + request.py_max_new_tokens = 1 + + def run_drafter_post( + self, + scheduled_requests: ScheduledRequests, + resource_manager: Optional[ResourceManager] = None, + is_warmup: bool = False, + ) -> None: + for request in sorted( + scheduled_requests.context_requests, + key=lambda r: + (r.py_batch_idx is None, r.py_batch_idx or r.request_id), + ): + if is_warmup: + continue + self._process_request(request, self.spec_resource_manager) + if self._iter % self._write_interval == 0: + self._write_to_file() + self._iter += 1 diff --git a/tensorrt_llm/_torch/speculative/utils.py b/tensorrt_llm/_torch/speculative/utils.py index 56b44704c0e..feab608cf04 100644 --- a/tensorrt_llm/_torch/speculative/utils.py +++ 
b/tensorrt_llm/_torch/speculative/utils.py @@ -11,6 +11,7 @@ from .mtp import (MTPEagleWorker, MTPHiddenStatesManager, MTPSampler, MTPSpecMetadata, MTPWorker) from .ngram import NGramDrafter, NGramPoolManager +from .save_hidden_state import SaveHiddenStatesDrafter def get_spec_metadata(spec_config, @@ -55,6 +56,19 @@ def get_spec_metadata(spec_config, max_num_tokens=max_num_tokens, layers_to_capture=spec_config.eagle3_layers_to_capture, ) + if spec_config.spec_dec_mode.is_save_hidden_states(): + return Eagle3SpecMetadata( + max_draft_len=spec_config.max_draft_len, + spec_dec_mode=spec_config.spec_dec_mode, + max_num_requests=max_num_requests, + num_layers=model_config.num_hidden_layers, + hidden_size=model_config.hidden_size, + max_num_tokens=max_num_tokens, + dtype=model_config.torch_dtype, + is_draft_model=is_draft_model, + eagle3_resource_manager=spec_resource_manager, + layers_to_capture=spec_config.eagle3_layers_to_capture, + ) if spec_config.spec_dec_mode.is_draft_target() or \ spec_config.spec_dec_mode.is_ngram() or \ spec_config.spec_dec_mode.is_user_provided(): @@ -102,6 +116,15 @@ def get_spec_resource_manager(model_engine, draft_model_engine=None): max_seq_len, max_num_tokens, ) + if spec_dec_mode.is_save_hidden_states(): + return Eagle3ResourceManager( + spec_config, + model_engine.model.config.torch_dtype, + model_config.hidden_size, + max_num_requests, + max_seq_len, + max_num_tokens, + ) if spec_dec_mode.is_ngram(): return NGramPoolManager(spec_config, max_num_requests) if spec_dec_mode.is_user_provided(): @@ -151,6 +174,9 @@ def get_spec_drafter(model_engine, if spec_config.spec_dec_mode.is_ngram(): return NGramDrafter(spec_config, spec_resource_manager) + if spec_config.spec_dec_mode.is_save_hidden_states(): + return SaveHiddenStatesDrafter(spec_config, spec_resource_manager) + return None diff --git a/tensorrt_llm/llmapi/__init__.py b/tensorrt_llm/llmapi/__init__.py index 1c3ebd6e2b9..adc0e7e35c3 100644 --- a/tensorrt_llm/llmapi/__init__.py +++ b/tensorrt_llm/llmapi/__init__.py @@ -11,7 +11,8 @@ DynamicBatchConfig, EagleDecodingConfig, ExtendedRuntimePerfKnobConfig, KvCacheConfig, LlmArgs, LookaheadDecodingConfig, MedusaDecodingConfig, MoeConfig, - MTPDecodingConfig, NGramDecodingConfig, SchedulerConfig, + MTPDecodingConfig, NGramDecodingConfig, + SaveHiddenStatesDecodingConfig, SchedulerConfig, TorchCompileConfig, TorchLlmArgs, TrtLlmArgs, UserProvidedDecodingConfig) from .llm_utils import (BuildConfig, KvCacheRetentionConfig, QuantAlgo, @@ -59,4 +60,5 @@ 'AutoDecodingConfig', 'AttentionDpConfig', 'LoRARequest', + 'SaveHiddenStatesDecodingConfig', ] diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index 5a05ee741f3..07363226b31 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -374,6 +374,7 @@ def from_dict(cls, data: dict): "Lookahead": LookaheadDecodingConfig, "NGram": NGramDecodingConfig, "DraftTarget": DraftTargetDecodingConfig, + "SaveState": SaveHiddenStatesDecodingConfig, "UserProvided": UserProvidedDecodingConfig, "AUTO": AutoDecodingConfig, } @@ -556,6 +557,50 @@ def num_capture_layers(self) -> int: return 3 +class SaveHiddenStatesDecodingConfig(DecodingBaseConfig): + output_directory: str + write_interval: int = 20 + file_prefix: str = "data" + eagle3_layers_to_capture: Optional[Set[int]] = None + + def __post_init__(self): + self._last_hidden_in_save = True + if self.eagle3_layers_to_capture is None: + self._last_hidden_in_save = False + self.eagle3_layers_to_capture = { + 1, self.num_layers // 
2 - 1, self.num_layers - 4, -1 + } + if -1 not in self.eagle3_layers_to_capture: + self._last_hidden_in_save = False + self.eagle3_layers_to_capture.add(-1) + + @classmethod + def from_dict(cls, data: dict): + return cls(**data) + + decoding_type: ClassVar[str] = "SaveState" + + def validate(self) -> None: + if self.output_directory is None or not self.eagle3_layers_to_capture: + raise ValueError( + "Save directory and layers to capture must be provided") + + @functools.cached_property + def spec_dec_mode(self): + from tensorrt_llm._torch.speculative.interface import \ + SpeculativeDecodingMode as TorchSpeculativeDecodingMode + return TorchSpeculativeDecodingMode.SAVE_HIDDEN_STATES + + @functools.cached_property + def num_capture_layers(self): + """ + Returns the number of layers to capture of the target model. + If eagle3_layers_to_capture is not None, return the length of the set. + Otherwise, assume Eagle3 base set and return 3. + """ + return len(self.eagle3_layers_to_capture) + + class UserProvidedDecodingConfig(DecodingBaseConfig): # Cannot use real type annotations due to circular imports drafter: object # Type is Drafter @@ -1044,6 +1089,7 @@ def supports_backend(self, backend: str) -> bool: MTPDecodingConfig, NGramDecodingConfig, UserProvidedDecodingConfig, + SaveHiddenStatesDecodingConfig, AutoDecodingConfig, ]] @@ -1863,6 +1909,20 @@ def validate_speculative_config(self): self.build_config.speculative_decoding_mode = SpeculativeDecodingMode.AUTO self.build_config.max_draft_len = self.speculative_config.max_draft_len + elif isinstance(self.speculative_config, + SaveHiddenStatesDecodingConfig): + assert self.backend in ['pytorch'] + logger.warning( + "SaveHiddenStatesDecodingConfig is active, setting max_batch_size to 1, disabling overlap scheduler, and setting cuda_graph_config to None" + ) + self.build_config.max_batch_size = 1 + self.max_batch_size = 1 + self.disable_overlap_scheduler = True + self.cuda_graph_config = None + self.build_config.speculative_decoding_mode = SpeculativeDecodingMode.SAVE_HIDDEN_STATES + self.build_config.max_draft_len = 1 + self.speculative_config.max_draft_len = 1 + else: raise ValueError( f"Unrecognized speculative config type {type(self.speculative_config)}" diff --git a/tensorrt_llm/models/modeling_utils.py b/tensorrt_llm/models/modeling_utils.py index 6f4bfdf0bb0..b2ad8f82dfc 100644 --- a/tensorrt_llm/models/modeling_utils.py +++ b/tensorrt_llm/models/modeling_utils.py @@ -98,6 +98,7 @@ class SpeculativeDecodingMode(IntFlag): EAGLE = auto() NGRAM = auto() USER_PROVIDED = auto() + SAVE_HIDDEN_STATES = auto() AUTO = auto() @staticmethod @@ -120,6 +121,8 @@ def from_arguments(args: argparse.Namespace): return SpeculativeDecodingMode.USER_PROVIDED elif args.speculative_decoding_mode == "auto": return SpeculativeDecodingMode.AUTO + elif args.speculative_decoding_mode == "save_hidden_states": + return SpeculativeDecodingMode.SAVE_HIDDEN_STATES else: assert False, "Unknown speculative_decoding_mode " + args.speculative_decoding_mode From ffe9c77d9767c32f5ab7c5b1392382a81986b783 Mon Sep 17 00:00:00 2001 From: Izzy Putterman Date: Tue, 16 Sep 2025 23:04:35 -0700 Subject: [PATCH 2/3] Save state test Signed-off-by: Izzy Putterman --- .../_torch/speculative/test_save_state.py | 152 ++++++++++++++++++ .../references_committed/llm.yaml | 2 +- 2 files changed, 153 insertions(+), 1 deletion(-) create mode 100644 tests/unittest/_torch/speculative/test_save_state.py diff --git a/tests/unittest/_torch/speculative/test_save_state.py 
b/tests/unittest/_torch/speculative/test_save_state.py new file mode 100644 index 00000000000..8878083d938 --- /dev/null +++ b/tests/unittest/_torch/speculative/test_save_state.py @@ -0,0 +1,152 @@ +import os +import sys +import tempfile +import unittest + +import pytest +import torch +from utils.llm_data import llm_models_root + +from tensorrt_llm import LLM, SamplingParams +from tensorrt_llm.llmapi import (CudaGraphConfig, KvCacheConfig, + SaveHiddenStatesDecodingConfig) + +sys.path.append(os.path.join(os.path.dirname(__file__), '..')) + + +def test_multi_save_state(): + use_cuda_graph = True + attn_backend = "TRTLLM" + disable_overlap_scheduler = False + enable_block_reuse = False + enable_chunked_prefill = False + layers_to_capture = {1, 4, 8, 16} + + # Eagle3 one model works with overlap scheduler and block reuse. + total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9 + if total_mem_gb < 80: + pytest.skip("Not enough memory to load target + draft model") + + models_path = llm_models_root() + with tempfile.TemporaryDirectory() as temp_dir: + + target_model_dir = f"{models_path}/llama-3.1-model/Llama-3.1-8B-Instruct" + + max_batch_size = 16 + kv_cache_config = KvCacheConfig(enable_block_reuse=enable_block_reuse, + free_gpu_memory_fraction=0.5) + cuda_graph_config = CudaGraphConfig( + batch_sizes=[1, 2, 4]) if use_cuda_graph else None + + llm_common_config = dict( + model=target_model_dir, + attn_backend=attn_backend, + disable_overlap_scheduler=disable_overlap_scheduler, + cuda_graph_config=cuda_graph_config, + max_batch_size=max_batch_size, + kv_cache_config=kv_cache_config, + enable_chunked_prefill=enable_chunked_prefill, + ) + spec_config = SaveHiddenStatesDecodingConfig( + output_directory=temp_dir, + write_interval=1, + file_prefix="data", + eagle3_layers_to_capture=layers_to_capture) + + llm_spec = LLM(**llm_common_config, speculative_config=spec_config) + + tok_ids = llm_spec.tokenizer.encode("The future of AI is") + + sampling_params = SamplingParams(max_tokens=32, temperature=0) + for output in llm_spec.generate_async(tok_ids, + sampling_params, + streaming=True): + pass + + assert os.path.exists(os.path.join(temp_dir, "data_1.pt")) + # Read in .pt file + saved_data = torch.load(os.path.join(temp_dir, "data_1.pt")) + + assert saved_data["hidden_state_features"].shape == ( + 1, len(tok_ids), 2048 * len(layers_to_capture)) + assert saved_data["hidden_state"].shape == (1, len(tok_ids), 2048) + assert saved_data["input_ids"].tolist() == tok_ids + hidden_states = saved_data["hidden_state_features"] + + # start the HF version of the model + hf_model = AutoModelForCausalLM.from_pretrained( + target_model_dir, torch_dtype=torch.bfloat16, device_map="cuda") + # do the forward pass and collect hidden states + hf_hidden_states = hf_model(tok_ids, output_hidden_states=True) + # compare the hidden states of saved and HF version + concat_hidden_states = torch.cat( + hf_hidden_states[list(layers_to_capture)]) + assert torch.allclose(hidden_states, concat_hidden_states) + + +@pytest.mark.parametrize("layers_to_capture", [{-1}, None]) +def test_save_state(layers_to_capture): + use_cuda_graph = True + attn_backend = "TRTLLM" + disable_overlap_scheduler = False + enable_block_reuse = False + enable_chunked_prefill = False + + # Eagle3 one model works with overlap scheduler and block reuse. 
+ total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9 + if total_mem_gb < 80: + pytest.skip("Not enough memory to load target + draft model") + + models_path = llm_models_root() + with tempfile.TemporaryDirectory() as temp_dir: + + target_model_dir = f"{models_path}/llama-3.1-model/Llama-3.1-8B-Instruct" + + max_batch_size = 16 + kv_cache_config = KvCacheConfig(enable_block_reuse=enable_block_reuse, + free_gpu_memory_fraction=0.5) + cuda_graph_config = CudaGraphConfig( + batch_sizes=[1, 2, 4]) if use_cuda_graph else None + + llm_common_config = dict( + model=target_model_dir, + attn_backend=attn_backend, + disable_overlap_scheduler=disable_overlap_scheduler, + cuda_graph_config=cuda_graph_config, + max_batch_size=max_batch_size, + kv_cache_config=kv_cache_config, + enable_chunked_prefill=enable_chunked_prefill, + ) + spec_config = SaveHiddenStatesDecodingConfig( + output_directory=temp_dir, + write_interval=1, + file_prefix="data", + eagle3_layers_to_capture=layers_to_capture) + + llm_spec = LLM(**llm_common_config, speculative_config=spec_config) + + tok_ids = llm_spec.tokenizer.encode("The future of AI is") + + sampling_params = SamplingParams(max_tokens=32, temperature=0) + for output in llm_spec.generate_async(tok_ids, + sampling_params, + streaming=True): + pass + + assert os.path.exists(os.path.join(temp_dir, "data_1.pt")) + # Read in .pt file + saved_data = torch.load(os.path.join(temp_dir, "data_1.pt")) + if layers_to_capture is None: + assert saved_data["hidden_state_features"].shape == (1, + len(tok_ids), + 2048 * 3) + assert saved_data["hidden_state"].shape == (1, len(tok_ids), 2048) + assert saved_data["input_ids"].tolist() == tok_ids + else: + assert "hidden_state_features" not in saved_data + assert saved_data["hidden_state"].shape == (1, len(tok_ids), 2048) + assert saved_data["input_ids"].tolist() == tok_ids + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unittest/api_stability/references_committed/llm.yaml b/tests/unittest/api_stability/references_committed/llm.yaml index a722da54958..36e2ff28ea9 100644 --- a/tests/unittest/api_stability/references_committed/llm.yaml +++ b/tests/unittest/api_stability/references_committed/llm.yaml @@ -59,7 +59,7 @@ methods: default: null # Speculative decoding speculative_config: - annotation: Union[tensorrt_llm.llmapi.llm_args.DraftTargetDecodingConfig, tensorrt_llm.llmapi.llm_args.EagleDecodingConfig,tensorrt_llm.llmapi.llm_args.LookaheadDecodingConfig, tensorrt_llm.llmapi.llm_args.MedusaDecodingConfig, tensorrt_llm.llmapi.llm_args.MTPDecodingConfig, tensorrt_llm.llmapi.llm_args.NGramDecodingConfig, tensorrt_llm.llmapi.llm_args.UserProvidedDecodingConfig, tensorrt_llm.llmapi.llm_args.AutoDecodingConfig, NoneType] + annotation: Union[tensorrt_llm.llmapi.llm_args.DraftTargetDecodingConfig, tensorrt_llm.llmapi.llm_args.EagleDecodingConfig,tensorrt_llm.llmapi.llm_args.LookaheadDecodingConfig, tensorrt_llm.llmapi.llm_args.MedusaDecodingConfig, tensorrt_llm.llmapi.llm_args.MTPDecodingConfig, tensorrt_llm.llmapi.llm_args.NGramDecodingConfig, tensorrt_llm.llmapi.llm_args.UserProvidedDecodingConfig, tensorrt_llm.llmapi.llm_args.AutoDecodingConfig, tensorrt_llm.llmapi.llm_args.SaveHiddenStatesDecodingConfig, NoneType] default: null # generation constraints max_batch_size: From 6b6b73b5e98c4c863be6e53734fcb9d3e86a3d07 Mon Sep 17 00:00:00 2001 From: Izzy Putterman Date: Wed, 17 Sep 2025 21:38:37 -0700 Subject: [PATCH 3/3] Bug fixing Signed-off-by: Izzy Putterman --- .../_torch/models/modeling_speculative.py | 1 
+ tensorrt_llm/_torch/speculative/eagle3.py | 4 ++ tensorrt_llm/_torch/speculative/interface.py | 5 ++- .../_torch/speculative/model_drafter.py | 17 ------- .../_torch/speculative/save_hidden_state.py | 22 +++++----- tensorrt_llm/_torch/speculative/utils.py | 6 +++ tensorrt_llm/llmapi/llm_args.py | 14 +++--- .../_torch/speculative/test_save_state.py | 44 +++++++------------ 8 files changed, 47 insertions(+), 66 deletions(-) diff --git a/tensorrt_llm/_torch/models/modeling_speculative.py b/tensorrt_llm/_torch/models/modeling_speculative.py index a988efb329a..6eb989af33f 100755 --- a/tensorrt_llm/_torch/models/modeling_speculative.py +++ b/tensorrt_llm/_torch/models/modeling_speculative.py @@ -539,6 +539,7 @@ def forward( spec_metadata=spec_metadata, **kwargs, ) + if spec_metadata is not None and spec_metadata.is_layer_capture( self.layer_idx): spec_metadata.maybe_capture_hidden_states(self.layer_idx, diff --git a/tensorrt_llm/_torch/speculative/eagle3.py b/tensorrt_llm/_torch/speculative/eagle3.py index 571850c82da..42682c07934 100644 --- a/tensorrt_llm/_torch/speculative/eagle3.py +++ b/tensorrt_llm/_torch/speculative/eagle3.py @@ -126,6 +126,10 @@ def __post_init__(self): self.num_layers - 4) else: self.layers_to_capture = sorted(list(self.layers_to_capture)) + if self.layers_to_capture[0] == -1: + self.layers_to_capture = self.layers_to_capture[1:] + [ + self.layers_to_capture.pop(0) + ] self.num_capture_layers = len(self.layers_to_capture) # Initialize to 0 to avoid reading uninitialized memory during warmup diff --git a/tensorrt_llm/_torch/speculative/interface.py b/tensorrt_llm/_torch/speculative/interface.py index 9e88d6986e9..191eb92c7eb 100644 --- a/tensorrt_llm/_torch/speculative/interface.py +++ b/tensorrt_llm/_torch/speculative/interface.py @@ -99,8 +99,9 @@ def has_spec_decoder(self): ) or self.is_eagle3_one_model() def has_spec_drafter(self): - return self.is_eagle3() or self.is_draft_target() or self.is_ngram( - ) or self.is_user_provided() or self.is_mtp_eagle() or self.is_save_hidden_states() + return self.is_eagle3( + ) or self.is_draft_target() or self.is_ngram() or self.is_user_provided( + ) or self.is_mtp_eagle() or self.is_save_hidden_states() def extend_ctx(self, attention_backend: Type[AttentionBackend]): """ diff --git a/tensorrt_llm/_torch/speculative/model_drafter.py b/tensorrt_llm/_torch/speculative/model_drafter.py index 8980fba9a20..f6082ac3264 100644 --- a/tensorrt_llm/_torch/speculative/model_drafter.py +++ b/tensorrt_llm/_torch/speculative/model_drafter.py @@ -794,20 +794,3 @@ def prepare_draft_tokens( error_msg = str(e) logger.error(f"Encountered an error in decode: {error_msg}") raise e - - @nvtx_range("prepare_draft_tokens_post") - def prepare_draft_tokens_post( - self, - scheduled_requests: ScheduledRequests, - resource_manager: Optional[ResourceManager] = None, - is_warmup: bool = False, - ) -> None: - """ - If draft forward needs to be run directly after the target model forward, - this method can be overridden to do that. 
- Used in SaveHiddenStatesDrafter (to ensure correct input_ids) - - Args: - scheduled_requests: The scheduled requests for this iteration - resource_manager: The resource manager for this iteration - """ diff --git a/tensorrt_llm/_torch/speculative/save_hidden_state.py b/tensorrt_llm/_torch/speculative/save_hidden_state.py index fac5919b1ce..202088784fe 100644 --- a/tensorrt_llm/_torch/speculative/save_hidden_state.py +++ b/tensorrt_llm/_torch/speculative/save_hidden_state.py @@ -16,7 +16,7 @@ class SaveHiddenStatesDrafter(Drafter): def __init__( self, spec_config: "SaveHiddenStatesDecodingConfig", - spec_resource_manager: SaveHiddenStatesResourceManager, + spec_resource_manager, ): super().__init__(spec_config.max_concurrency) self.spec_config = spec_config @@ -29,9 +29,7 @@ def __init__( self.spec_resource_manager = spec_resource_manager os.makedirs(self._output_directory, exist_ok=True) - def _process_request( - self, request: LlmRequest, - resource_manager: SaveHiddenStatesResourceManager) -> None: + def _process_request(self, request: LlmRequest, resource_manager) -> None: out_dict = {} if local_mpi_rank() == 0: input_ids = torch.tensor(list(request.get_tokens(0)), @@ -51,15 +49,15 @@ def _process_request( if len(self.spec_config.eagle3_layers_to_capture) > 1: if self.spec_config._last_hidden_in_save: out_dict[ - "hidden_state_features"] = resource_manager.hidden_states[:num_tokens, :].cpu( + "aux_hidden_states"] = resource_manager.hidden_states[:num_tokens, :].cpu( ).clone() else: out_dict[ - "hidden_state_features"] = resource_manager.hidden_states[: - num_tokens, : - -hidden_size].cpu( - ).clone( - ) + "aux_hidden_states"] = resource_manager.hidden_states[: + num_tokens, : + -hidden_size].cpu( + ).clone( + ) self._saved_state.append(out_dict) @@ -88,13 +86,13 @@ def run_drafter_post( resource_manager: Optional[ResourceManager] = None, is_warmup: bool = False, ) -> None: + if is_warmup: + return for request in sorted( scheduled_requests.context_requests, key=lambda r: (r.py_batch_idx is None, r.py_batch_idx or r.request_id), ): - if is_warmup: - continue self._process_request(request, self.spec_resource_manager) if self._iter % self._write_interval == 0: self._write_to_file() diff --git a/tensorrt_llm/_torch/speculative/utils.py b/tensorrt_llm/_torch/speculative/utils.py index feab608cf04..152cbd1074e 100644 --- a/tensorrt_llm/_torch/speculative/utils.py +++ b/tensorrt_llm/_torch/speculative/utils.py @@ -57,6 +57,11 @@ def get_spec_metadata(spec_config, layers_to_capture=spec_config.eagle3_layers_to_capture, ) if spec_config.spec_dec_mode.is_save_hidden_states(): + if spec_config.eagle3_layers_to_capture is None: + spec_config.eagle3_layers_to_capture = { + 1, model_config.num_hidden_layers // 2 - 1, + model_config.num_hidden_layers - 4, -1 + } return Eagle3SpecMetadata( max_draft_len=spec_config.max_draft_len, spec_dec_mode=spec_config.spec_dec_mode, @@ -68,6 +73,7 @@ def get_spec_metadata(spec_config, is_draft_model=is_draft_model, eagle3_resource_manager=spec_resource_manager, layers_to_capture=spec_config.eagle3_layers_to_capture, + max_total_draft_tokens=1, ) if spec_config.spec_dec_mode.is_draft_target() or \ spec_config.spec_dec_mode.is_ngram() or \ diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index 07363226b31..4bad387a1b0 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -563,14 +563,14 @@ class SaveHiddenStatesDecodingConfig(DecodingBaseConfig): file_prefix: str = "data" eagle3_layers_to_capture: 
Optional[Set[int]] = None - def __post_init__(self): + max_total_draft_tokens: Optional[int] = Field(default=1, init=False) + eagle_choices: Optional[List[List[int]]] = Field(default=None, init=False) + + def model_post_init(self, __context): self._last_hidden_in_save = True if self.eagle3_layers_to_capture is None: self._last_hidden_in_save = False - self.eagle3_layers_to_capture = { - 1, self.num_layers // 2 - 1, self.num_layers - 4, -1 - } - if -1 not in self.eagle3_layers_to_capture: + elif -1 not in self.eagle3_layers_to_capture: self._last_hidden_in_save = False self.eagle3_layers_to_capture.add(-1) @@ -596,8 +596,10 @@ def num_capture_layers(self): """ Returns the number of layers to capture of the target model. If eagle3_layers_to_capture is not None, return the length of the set. - Otherwise, assume Eagle3 base set and return 3. + Otherwise, assume Eagle3 base set and return 3 + 1 (for post norm last hidden state). """ + if self.eagle3_layers_to_capture is None: + return 4 return len(self.eagle3_layers_to_capture) diff --git a/tests/unittest/_torch/speculative/test_save_state.py b/tests/unittest/_torch/speculative/test_save_state.py index 8878083d938..406dd4f8cf0 100644 --- a/tests/unittest/_torch/speculative/test_save_state.py +++ b/tests/unittest/_torch/speculative/test_save_state.py @@ -20,9 +20,8 @@ def test_multi_save_state(): disable_overlap_scheduler = False enable_block_reuse = False enable_chunked_prefill = False - layers_to_capture = {1, 4, 8, 16} + layers_to_capture = {10, 11, 12} - # Eagle3 one model works with overlap scheduler and block reuse. total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9 if total_mem_gb < 80: pytest.skip("Not enough memory to load target + draft model") @@ -30,7 +29,7 @@ def test_multi_save_state(): models_path = llm_models_root() with tempfile.TemporaryDirectory() as temp_dir: - target_model_dir = f"{models_path}/llama-3.1-model/Llama-3.1-8B-Instruct" + target_model_dir = f"{models_path}/llama-3.2-models/Llama-3.2-1B-Instruct" max_batch_size = 16 kv_cache_config = KvCacheConfig(enable_block_reuse=enable_block_reuse, @@ -62,26 +61,15 @@ def test_multi_save_state(): sampling_params, streaming=True): pass - + llm_spec.shutdown() assert os.path.exists(os.path.join(temp_dir, "data_1.pt")) # Read in .pt file - saved_data = torch.load(os.path.join(temp_dir, "data_1.pt")) + saved_data = torch.load(os.path.join(temp_dir, "data_1.pt"))[0] - assert saved_data["hidden_state_features"].shape == ( - 1, len(tok_ids), 2048 * len(layers_to_capture)) - assert saved_data["hidden_state"].shape == (1, len(tok_ids), 2048) + assert saved_data["aux_hidden_states"].shape == (len(tok_ids), 2048 * + len(layers_to_capture)) + assert saved_data["hidden_state"].shape == (len(tok_ids), 2048) assert saved_data["input_ids"].tolist() == tok_ids - hidden_states = saved_data["hidden_state_features"] - - # start the HF version of the model - hf_model = AutoModelForCausalLM.from_pretrained( - target_model_dir, torch_dtype=torch.bfloat16, device_map="cuda") - # do the forward pass and collect hidden states - hf_hidden_states = hf_model(tok_ids, output_hidden_states=True) - # compare the hidden states of saved and HF version - concat_hidden_states = torch.cat( - hf_hidden_states[list(layers_to_capture)]) - assert torch.allclose(hidden_states, concat_hidden_states) @pytest.mark.parametrize("layers_to_capture", [{-1}, None]) @@ -92,7 +80,6 @@ def test_save_state(layers_to_capture): enable_block_reuse = False enable_chunked_prefill = False - # Eagle3 one model works 
with overlap scheduler and block reuse. total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9 if total_mem_gb < 80: pytest.skip("Not enough memory to load target + draft model") @@ -100,7 +87,7 @@ def test_save_state(layers_to_capture): models_path = llm_models_root() with tempfile.TemporaryDirectory() as temp_dir: - target_model_dir = f"{models_path}/llama-3.1-model/Llama-3.1-8B-Instruct" + target_model_dir = f"{models_path}/llama-3.2-models/Llama-3.2-1B-Instruct" max_batch_size = 16 kv_cache_config = KvCacheConfig(enable_block_reuse=enable_block_reuse, @@ -132,19 +119,18 @@ def test_save_state(layers_to_capture): sampling_params, streaming=True): pass - + llm_spec.shutdown() assert os.path.exists(os.path.join(temp_dir, "data_1.pt")) # Read in .pt file - saved_data = torch.load(os.path.join(temp_dir, "data_1.pt")) + saved_data = torch.load(os.path.join(temp_dir, "data_1.pt"))[0] if layers_to_capture is None: - assert saved_data["hidden_state_features"].shape == (1, - len(tok_ids), - 2048 * 3) - assert saved_data["hidden_state"].shape == (1, len(tok_ids), 2048) + assert saved_data["aux_hidden_states"].shape == (len(tok_ids), + 2048 * 3) + assert saved_data["hidden_state"].shape == (len(tok_ids), 2048) assert saved_data["input_ids"].tolist() == tok_ids else: - assert "hidden_state_features" not in saved_data - assert saved_data["hidden_state"].shape == (1, len(tok_ids), 2048) + assert "aux_hidden_states" not in saved_data + assert saved_data["hidden_state"].shape == (len(tok_ids), 2048) assert saved_data["input_ids"].tolist() == tok_ids
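
For reviewers, a minimal usage sketch of the feature as it stands after this series, mirroring test_save_state.py; the model path and layer ids below are placeholders rather than values required by the patch, and the load/inspect step follows the final revision of the test.

    import torch

    from tensorrt_llm import LLM, SamplingParams
    from tensorrt_llm.llmapi import SaveHiddenStatesDecodingConfig

    # Capture hidden states from a few intermediate target-model layers.
    # Per model_post_init, -1 (the final hidden state) is added automatically
    # if it is not already in the set.
    spec_config = SaveHiddenStatesDecodingConfig(
        output_directory="/tmp/hidden_states",   # placeholder path
        write_interval=1,
        file_prefix="data",
        eagle3_layers_to_capture={10, 11, 12},   # placeholder layer ids
    )

    # validate_speculative_config forces max_batch_size=1, disables the
    # overlap scheduler, and drops the CUDA graph config for this mode.
    llm = LLM(model="path/to/Llama-3.2-1B-Instruct",  # placeholder model dir
              speculative_config=spec_config)

    tok_ids = llm.tokenizer.encode("The future of AI is")
    sampling_params = SamplingParams(max_tokens=32, temperature=0)
    for _ in llm.generate_async(tok_ids, sampling_params, streaming=True):
        pass
    llm.shutdown()

    # Each saved .pt file holds a list with one dict per processed request.
    saved = torch.load("/tmp/hidden_states/data_1.pt")[0]
    print(saved["input_ids"].shape)          # (num_tokens,)
    print(saved["hidden_state"].shape)       # (num_tokens, hidden_size)
    print(saved["aux_hidden_states"].shape)  # (num_tokens, hidden_size * num_intermediate_layers)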