diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py index bc1b3e2319d0..3ad9b4993327 100644 --- a/vllm/spec_decode/draft_model_runner.py +++ b/vllm/spec_decode/draft_model_runner.py @@ -133,7 +133,7 @@ def _gpu_advance_step(self, model_input: ModelRunnerInputBase, def supports_gpu_multi_step(self, execute_model_req: ExecuteModelRequest): """Determines if draft_model_runner GPU multi-step can be used. Currently required conditions are: - 1. Only decodes + 1. Only decodes 2. Only flash-attn 3. No LORA 4. No prompt_adapter_config @@ -171,12 +171,12 @@ def execute_model( num_steps: int = 1, **kwargs, ) -> Optional[List[SamplerOutput]]: - """Executes num_steps forward passes with advacement of input tensors + """Executes num_steps forward passes with advacement of input tensors on the GPU. Look at supports_gpu_multi_step(..) for pre-conditions. Optimizations used: 1. Input tensors are updated on the GPU directly - 2. Skips GPU=>CPU serialization of sampler outputs (we don't need + 2. Skips GPU=>CPU serialization of sampler outputs (we don't need them since we do batch expansion later that uses GPU outputs) 3. Reuses sampling tensors (since we run only decodes and they have a repeating sampling logic) @@ -302,7 +302,12 @@ def execute_model( outputs.append(output) if self.return_hidden_states and is_fallback: - output.hidden_states = hidden_states + if use_cuda_graph: + indices = model_input.sampling_metadata\ + .selected_token_indices + output.hidden_states = hidden_states[:len(indices)] + else: + output.hidden_states = hidden_states if model_input.attn_metadata.num_prefills == 0 \ and self.indices_of_seq_with_bonus_tokens is not None: