Merged
vllm/spec_decode/draft_model_runner.py (13 changes: 9 additions & 4 deletions)
@@ -133,7 +133,7 @@ def _gpu_advance_step(self, model_input: ModelRunnerInputBase,
     def supports_gpu_multi_step(self, execute_model_req: ExecuteModelRequest):
         """Determines if draft_model_runner GPU multi-step can be used.
         Currently required conditions are:
-            1. Only decodes 
+            1. Only decodes
             2. Only flash-attn
             3. No LORA
             4. No prompt_adapter_config
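For context, the gate this docstring describes is a short all-or-nothing check. A minimal sketch of the pattern follows; it is not vLLM's actual implementation, and the `req` attribute names are illustrative only:

```python
# Hypothetical sketch of an all-or-nothing gate for GPU multi-step.
# The attribute names on `req` are made up for illustration.
def supports_gpu_multi_step_sketch(req) -> bool:
    # 1. Only decodes: every sequence must be past its prompt phase.
    if any(sg.is_prompt for sg in req.seq_group_metadata_list):
        return False
    # 2. Only flash-attn: the on-device advance kernel targets that backend.
    if req.attn_backend_name != "flash-attn":
        return False
    # 3. No LoRA and 4. no prompt adapters: their extra inputs would not
    # be advanced on the GPU between steps.
    return req.lora_config is None and req.prompt_adapter_config is None
```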
@@ -171,12 +171,12 @@ def execute_model(
         num_steps: int = 1,
         **kwargs,
     ) -> Optional[List[SamplerOutput]]:
-        """Executes num_steps forward passes with advacement of input tensors 
+        """Executes num_steps forward passes with advacement of input tensors
         on the GPU. Look at supports_gpu_multi_step(..) for pre-conditions.
 
         Optimizations used:
             1. Input tensors are updated on the GPU directly
-            2. Skips GPU=>CPU serialization of sampler outputs (we don't need 
+            2. Skips GPU=>CPU serialization of sampler outputs (we don't need
                 them since we do batch expansion later that uses GPU outputs)
             3. Reuses sampling tensors (since we run only decodes and they have
                 a repeating sampling logic)
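Read together, the three optimizations keep the whole multi-step decode loop on-device. A rough, runnable sketch of that pattern, with stub stand-ins (`model`, `sampler`, `gpu_advance_step` here are hypothetical, not the real method body):

```python
import torch

# Stand-ins for the real components; shapes and logic are illustrative.
def model(x):                 # one decode forward pass
    return x + 1.0
def sampler(h, meta):         # "sampled ids", kept on the same device as h
    return h.argmax(dim=-1)
def gpu_advance_step(x, out): # on-device input advancement between steps
    return x

num_steps = 3
model_input = torch.zeros(4, 8)
sampling_metadata = None      # Optimization 3: reused across all steps

outputs = []
for step in range(num_steps):
    hidden_states = model(model_input)
    # Optimization 2: no GPU=>CPU copy of sampler outputs here; the later
    # batch-expansion scorer consumes the device tensors directly.
    outputs.append(sampler(hidden_states, sampling_metadata))
    if step != num_steps - 1:
        # Optimization 1: advance input ids/positions/slot mappings on the
        # GPU instead of rebuilding the inputs on the CPU.
        model_input = gpu_advance_step(model_input, outputs[-1])
```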
@@ -302,7 +302,12 @@ def execute_model(
             outputs.append(output)
 
             if self.return_hidden_states and is_fallback:
-                output.hidden_states = hidden_states
+                if use_cuda_graph:
+                    indices = model_input.sampling_metadata\
+                        .selected_token_indices
+                    output.hidden_states = hidden_states[:len(indices)]
+                else:
+                    output.hidden_states = hidden_states
 
             if model_input.attn_metadata.num_prefills == 0 \
                     and self.indices_of_seq_with_bonus_tokens is not None:
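The substantive change is the CUDA-graph branch: graph replay pads the batch to the captured size, so `hidden_states` carries rows for padding sequences, and `selected_token_indices` (one entry per real decode sequence) gives the number of real rows. A standalone sketch of the trimming behavior, with made-up shapes:

```python
import torch

# Under CUDA graph replay the batch is padded to the captured size, so
# the model returns hidden states for the padded (fake) rows as well.
real_batch, padded_batch, hidden_size = 3, 8, 16
hidden_states = torch.randn(padded_batch, hidden_size)

# Decode-only batches select one token per real sequence.
selected_token_indices = torch.arange(real_batch)

# The change above: drop the padded rows before attaching hidden states.
output_hidden_states = hidden_states[:len(selected_token_indices)]
assert output_hidden_states.shape == (real_batch, hidden_size)
```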