Skip to content

Commit 16b9e78

Browse files
committed
Save state first pass
Signed-off-by: Izzy Putterman <[email protected]>
1 parent d1d17db commit 16b9e78

File tree

5 files changed

+194
-2
lines changed

5 files changed

+194
-2
lines changed

tensorrt_llm/_torch/models/modeling_speculative.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,9 @@ def forward(
418418
spec_metadata=spec_metadata,
419419
**kwargs,
420420
)
421+
if spec_metadata is not None and spec_metadata.is_final_output_capture(
422+
):
423+
spec_metadata.maybe_capture_final_hidden_states(hidden_states)
421424

422425
if self.draft_model is not None:
423426
# get logits

tensorrt_llm/_torch/speculative/interface.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ class SpeculativeDecodingMode(IntEnum):
1717
NGRAM = auto()
1818
DRAFT_TARGET = auto()
1919
USER_PROVIDED = auto()
20+
SAVE_HIDDEN_STATES = auto()
2021
NONE = auto()
2122
AUTO = auto()
2223

@@ -50,6 +51,9 @@ def is_none(self):
5051
def is_draft_target(self):
5152
return self == SpeculativeDecodingMode.DRAFT_TARGET
5253

54+
def is_save_hidden_states(self):
55+
return self == SpeculativeDecodingMode.SAVE_HIDDEN_STATES
56+
5357
def without_logits(self):
5458
return self.is_mtp() or self.is_eagle3_one_model()
5559

@@ -82,7 +86,7 @@ def has_spec_decoder(self):
8286

8387
def has_spec_drafter(self):
8488
return self.is_eagle3() or self.is_draft_target() or self.is_ngram(
85-
) or self.is_user_provided()
89+
) or self.is_user_provided() or self.is_save_hidden_states()
8690

8791
def extend_ctx(self, attention_backend: Type[AttentionBackend]):
8892
"""
@@ -185,6 +189,9 @@ def create_cuda_graph_metadata(self, max_batch_size: int):
185189
cuda_graph_metadata.__post_init__()
186190
return cuda_graph_metadata
187191

192+
def is_final_output_capture(self):
193+
return False
194+
188195
def maybe_capture_hidden_states(self, layer_id: int,
189196
hidden_states: torch.Tensor,
190197
residual: torch.Tensor) -> None:
@@ -193,6 +200,13 @@ def maybe_capture_hidden_states(self, layer_id: int,
193200
model. Use this method to record them. By default, does nothing.
194201
"""
195202

203+
def maybe_capture_final_hidden_states(self,
204+
hidden_states: torch.Tensor) -> None:
205+
"""
206+
Some spec decode algorithms require hidden states from the target
207+
model. Use this method to record them. By default, does nothing.
208+
"""
209+
196210
@property
197211
def all_rank_num_tokens(self) -> Optional[List[int]]:
198212
return self._all_rank_num_tokens
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
import os
from dataclasses import dataclass
from typing import Optional

import torch

from tensorrt_llm._utils import local_mpi_rank

from ..pyexecutor.llm_request import LlmRequest
from ..pyexecutor.resource_manager import ResourceManager
from ..pyexecutor.scheduler import ScheduledRequests
from .drafter import Drafter
from .eagle3 import Eagle3ResourceManager, Eagle3SpecMetadata
10+
11+
12+
@dataclass
class SaveHiddenStatesSpecMetadata(Eagle3SpecMetadata):
    """Eagle3 spec metadata that can additionally capture the target
    model's final (post-norm) hidden states into the resource manager's
    ``last_hidden_states`` buffer.
    """

    # When True, the final post-norm hidden states are recorded each
    # forward pass (buffer allocated by SaveHiddenStatesResourceManager).
    save_last_layer_post_norm: bool = False

    def is_final_output_capture(self) -> bool:
        """Return True when the final hidden states should be captured."""
        return self.save_last_layer_post_norm

    def maybe_capture_final_hidden_states(self,
                                          hidden_states: torch.Tensor) -> None:
        """Record the target model's final hidden states.

        Assumes no chunked prefill and batch size 1, so *hidden_states*
        covers all tokens of the single in-flight request.
        """
        if self.save_last_layer_post_norm:
            buffer = self.eagle3_resource_manager.last_hidden_states
            # The buffer is pre-allocated for max_num_tokens rows; copy into
            # a matching slice so shorter sequences don't make copy_ raise a
            # shape-mismatch error.
            buffer[:hidden_states.shape[0]].copy_(hidden_states)
25+
26+
27+
class SaveHiddenStatesResourceManager(Eagle3ResourceManager):
    """Eagle3 resource manager extended with an optional buffer for the
    target model's final (post-norm) hidden states.
    """

    def __init__(self, config: "SaveHiddenStatesDecodingConfig",
                 dtype: torch.dtype, hidden_size: int, max_num_requests: int,
                 max_seq_len: int, max_num_tokens: int):
        super().__init__(config, dtype, hidden_size, max_num_requests,
                         max_seq_len, max_num_tokens)
        # Only allocate the post-norm buffer when the config requests it;
        # otherwise leave the attribute as None so callers can check it.
        if config.save_last_layer_post_norm:
            self.last_hidden_states = torch.empty(
                (max_num_tokens, self.hidden_size),
                dtype=self.dtype,
                device='cuda')
        else:
            self.last_hidden_states = None
40+
41+
42+
class SaveHiddenStatesDrafter(Drafter):
    """Drafter that performs no real speculation; it snapshots the target
    model's hidden states (captured by SaveHiddenStatesResourceManager)
    and periodically dumps them to ``.pt`` files.

    Each request contributes one record; records are flushed to
    ``<output_directory>/<file_prefix>_<iter>.pt`` every
    ``write_interval`` iterations.
    """

    def __init__(
        self,
        spec_config: "SaveHiddenStatesDecodingConfig",
    ):
        super().__init__(spec_config.max_concurrency)
        self.spec_config = spec_config
        self.max_draft_len = spec_config.max_draft_len
        self._iter = 0
        self._output_directory = spec_config.output_directory
        self._file_prefix = spec_config.file_prefix
        self._write_interval = spec_config.write_interval
        # Accumulated per-request records awaiting a flush to disk.
        self._saved_state = []
        # Fail early (and idempotently) if the dump directory is missing.
        os.makedirs(self._output_directory, exist_ok=True)

    def _process_request(self, request: LlmRequest,
                         resource_manager: SaveHiddenStatesResourceManager) -> None:
        """Snapshot input ids and captured hidden states for one request.

        Only rank 0 records data; other ranks are no-ops (the original
        ``!= 0`` guard collected data on the wrong ranks, while
        ``_write_to_file`` only ever writes on rank 0).
        """
        if local_mpi_rank() != 0:
            return
        input_ids = torch.tensor(list(request.get_tokens(0)),
                                 dtype=torch.long,
                                 device='cpu')
        hidden_size = resource_manager.hidden_size
        if self.spec_config.save_last_layer_post_norm:
            hidden_states = resource_manager.last_hidden_states.cpu().clone()
        else:
            # Last hidden_size columns hold the final captured layer.
            hidden_states = resource_manager.hidden_states[:,
                                                           -hidden_size:].cpu(
                                                           ).clone()
        self._saved_state.append({
            # Use the real iteration counter (the original referenced an
            # undefined ``self.iteration``).
            "id": self._iter,
            "input_ids": input_ids,
            "hidden_state_features":
            resource_manager.hidden_states.cpu().clone(),
            "hidden_state": hidden_states,
        })

    def _write_to_file(self) -> None:
        """Flush accumulated records to disk (rank 0 only) and reset them."""
        # Skip the write when nothing has been collected yet — replaces the
        # original's undefined ``self.start_iteration`` check.
        if local_mpi_rank() == 0 and self._saved_state:
            output_path = os.path.join(self._output_directory,
                                       f"{self._file_prefix}_{self._iter}.pt")
            torch.save(self._saved_state, output_path)
        self._saved_state = []

    def prepare_draft_tokens(
        self,
        scheduled_requests: ScheduledRequests,
        resource_manager: Optional[ResourceManager] = None,
    ) -> None:
        """Record hidden states for each context request and emit dummy
        draft tokens (this mode never actually speculates)."""
        for request in sorted(
                scheduled_requests.context_requests,
                key=lambda r:
            (r.py_batch_idx is None, r.py_batch_idx or r.request_id),
        ):
            # One generated token is enough: we only need the prefill pass.
            request.py_max_new_tokens = 1
            self._process_request(request, resource_manager)
            if self._iter % self._write_interval == 0:
                self._write_to_file()
            self._iter += 1
            # Dummy draft tokens padded to `self.max_draft_len` (the
            # original read an undefined ``draft_tokens`` local).
            request.py_draft_tokens = [0] * self.max_draft_len

tensorrt_llm/_torch/speculative/utils.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
from .mtp import (MTPEagleWorker, MTPHiddenStatesManager, MTPSampler,
1212
MTPSpecMetadata, MTPWorker)
1313
from .ngram import NGramDrafter, NGramPoolManager
14+
from .save_hidden_state import (SaveHiddenStatesResourceManager,
15+
SaveHiddenStatesSpecMetadata)
1416

1517

1618
def get_spec_metadata(spec_config,
@@ -48,6 +50,20 @@ def get_spec_metadata(spec_config,
4850
hidden_size=model_config.hidden_size,
4951
max_num_tokens=max_num_tokens,
5052
)
53+
if spec_config.spec_dec_mode.is_save_hidden_states():
54+
return SaveHiddenStatesSpecMetadata(
55+
max_draft_len=spec_config.max_draft_len,
56+
spec_dec_mode=spec_config.spec_dec_mode,
57+
max_num_requests=max_num_requests,
58+
num_layers=model_config.num_hidden_layers,
59+
hidden_size=model_config.hidden_size,
60+
max_num_tokens=max_num_tokens,
61+
dtype=model_config.torch_dtype,
62+
is_draft_model=is_draft_model,
63+
eagle3_resource_manager=spec_resource_manager,
64+
num_capture_layers=spec_config.num_capture_layers,
65+
save_last_layer_post_norm=spec_config.save_last_layer_post_norm,
66+
)
5167
if spec_config.spec_dec_mode.is_draft_target() or \
5268
spec_config.spec_dec_mode.is_ngram() or \
5369
spec_config.spec_dec_mode.is_user_provided():
@@ -95,6 +111,15 @@ def get_spec_resource_manager(model_engine, draft_model_engine=None):
95111
max_seq_len,
96112
max_num_tokens,
97113
)
114+
if spec_dec_mode.is_save_hidden_states():
115+
return SaveHiddenStatesResourceManager(
116+
spec_config,
117+
draft_model_engine.model.config.torch_dtype,
118+
model_config.hidden_size,
119+
max_num_requests,
120+
max_seq_len,
121+
max_num_tokens,
122+
)
98123
if spec_dec_mode.is_ngram():
99124
return NGramPoolManager(spec_config, max_num_requests)
100125
if spec_dec_mode.is_user_provided():

tensorrt_llm/llmapi/llm_args.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from enum import Enum, EnumMeta
1010
from pathlib import Path
1111
from typing import (TYPE_CHECKING, Any, ClassVar, Dict, List, Literal, Optional,
12-
Type, TypeAlias, TypeVar, Union, get_args, get_origin)
12+
Set, Type, TypeAlias, TypeVar, Union, get_args, get_origin)
1313

1414
import torch
1515
import yaml
@@ -361,9 +361,11 @@ def from_dict(cls, data: dict):
361361
"MTP": MTPDecodingConfig,
362362
"Medusa": MedusaDecodingConfig,
363363
"Eagle": EagleDecodingConfig,
364+
"SaveState": SaveHiddenStatesDecodingConfig,
364365
"Lookahead": LookaheadDecodingConfig,
365366
"NGram": NGramDecodingConfig,
366367
"DraftTarget": DraftTargetDecodingConfig,
368+
"SaveState": SaveHiddenStatesDecodingConfig,
367369
"UserProvided": UserProvidedDecodingConfig,
368370
"AUTO": AutoDecodingConfig,
369371
}
@@ -444,6 +446,31 @@ def spec_dec_mode(self):
444446
return TorchSpeculativeDecodingMode.EAGLE3
445447

446448

449+
class SaveHiddenStatesDecodingConfig(DecodingBaseConfig):
    """Decoding config for the SAVE_HIDDEN_STATES mode: runs the target
    model and dumps captured hidden states to disk instead of speculating.
    """

    # Directory the .pt dump files are written into.
    output_directory: str
    # Flush accumulated records to disk every this many iterations.
    write_interval: int = 20
    # Dump files are named "<file_prefix>_<iter>.pt".
    file_prefix: str = "data"
    # Which decoder layers to capture (Eagle3-style layer capture).
    eagle3_layers_to_capture: Optional[Set[int]] = None
    # Additionally capture the final post-norm hidden states.
    save_last_layer_post_norm: bool = False

    @classmethod
    def from_dict(cls, data: dict):
        """Build a config directly from a plain dict of field values."""
        return cls(**data)

    decoding_type: ClassVar[str] = "SaveState"

    def validate(self) -> None:
        """Raise if the required output directory / capture layers are unset.

        NOTE(review): output_directory is annotated as required `str` yet is
        checked for None here — presumably instances can be constructed
        without validation; confirm against DecodingBaseConfig.
        """
        if self.output_directory is None or not self.eagle3_layers_to_capture:
            raise ValueError(
                "Save directory and layers to capture must be provided")

    @functools.cached_property
    def spec_dec_mode(self):
        # Imported lazily to avoid a circular import with the torch backend.
        from tensorrt_llm._torch.speculative.interface import \
            SpeculativeDecodingMode as TorchSpeculativeDecodingMode
        return TorchSpeculativeDecodingMode.SAVE_HIDDEN_STATES
472+
473+
447474
class UserProvidedDecodingConfig(DecodingBaseConfig):
448475
# Cannot use real type annotations due to circular imports
449476
drafter: object # Type is Drafter
@@ -921,6 +948,7 @@ def supports_backend(self, backend: str) -> bool:
921948
MTPDecodingConfig,
922949
NGramDecodingConfig,
923950
UserProvidedDecodingConfig,
951+
SaveHiddenStateDecodingConfig,
924952
AutoDecodingConfig,
925953
]]
926954

@@ -1695,6 +1723,17 @@ def validate_speculative_config(self):
16951723
self.build_config.speculative_decoding_mode = SpeculativeDecodingMode.AUTO
16961724
self.build_config.max_draft_len = self.speculative_config.max_draft_len
16971725

1726+
elif isinstance(self.speculative_config,
1727+
SaveHiddenStatesDecodingConfig):
1728+
assert self.backend in ['pytorch']
1729+
assert self.speculative_config.max_draft_len > 0
1730+
self.build_config.max_batch_size = 1
1731+
self.max_batch_size = 1
1732+
self.disable_overlap_scheduler = True
1733+
self.cuda_graph_config = None
1734+
self.build_config.speculative_decoding_mode = SpeculativeDecodingMode.SAVE_HIDDEN_STATES
1735+
self.build_config.max_draft_len = 1
1736+
16981737
else:
16991738
raise ValueError(
17001739
f"Unrecognized speculative config type {type(self.speculative_config)}"

0 commit comments

Comments
 (0)