diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py
index 9f3b0e8ae079..86b5e1e0ab7c 100644
--- a/tests/basic_correctness/test_basic_correctness.py
+++ b/tests/basic_correctness/test_basic_correctness.py
@@ -8,12 +8,13 @@
 from unittest.mock import Mock
 
 import pytest
+import torch
 
-from vllm import LLM
+from vllm import LLM, envs
 from vllm.platforms import current_platform
 from vllm.v1.engine.llm_engine import LLMEngine as LLMEngineV1
 
-from ..conftest import VllmRunner
+from ..conftest import HfRunner, VllmRunner
 from ..models.utils import check_outputs_equal
 from ..utils import multi_gpu_test
 
@@ -43,11 +44,26 @@ def test_vllm_gc_ed():
     assert weak_llm() is None
 
 
+def _fix_prompt_embed_outputs(
+        vllm_outputs: list[tuple[list[int], str]], hf_model: HfRunner,
+        example_prompts: list[str]) -> list[tuple[list[int], str]]:
+    fixed_vllm_outputs = []
+    for vllm_output, hf_input, prompt in zip(
+            vllm_outputs, hf_model.get_inputs(example_prompts),
+            example_prompts):
+        hf_input_ids = hf_input["input_ids"].tolist()[0]
+        fixed_vllm_outputs.append(
+            (hf_input_ids + vllm_output[0][len(hf_input_ids):],
+             prompt + vllm_output[1]))
+    return fixed_vllm_outputs
+
+
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("backend", ["FLASH_ATTN"])
 @pytest.mark.parametrize("dtype", ["half"])
 @pytest.mark.parametrize("max_tokens", [5])
 @pytest.mark.parametrize("enforce_eager", [False])
+@pytest.mark.parametrize("enable_prompt_embeds", [True, False])
 def test_models(
     monkeypatch: pytest.MonkeyPatch,
     hf_runner,
@@ -56,8 +72,13 @@ def test_models(
     dtype: str,
     max_tokens: int,
     enforce_eager: bool,
+    enable_prompt_embeds: bool,
 ) -> None:
 
+    if enable_prompt_embeds and envs.is_set(
+            "VLLM_USE_V1") and envs.VLLM_USE_V1:
+        pytest.skip("enable_prompt_embeds is not supported in v1.")
+
     if backend == "FLASHINFER" and current_platform.is_rocm():
         pytest.skip("Flashinfer does not support ROCm/HIP.")
 
@@ -78,14 +99,25 @@ def test_models(
 
     with hf_runner(model, dtype=dtype) as hf_model:
         hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
+        if enable_prompt_embeds:
+            with torch.no_grad():
+                prompt_embeds = hf_model.get_prompt_embeddings(
+                    example_prompts)
 
     with VllmRunner(model,
                     max_model_len=8192,
                     dtype=dtype,
                     enforce_eager=enforce_eager,
+                    enable_prompt_embeds=enable_prompt_embeds,
                     gpu_memory_utilization=0.7) as vllm_model:
-        vllm_outputs = vllm_model.generate_greedy(example_prompts,
-                                                  max_tokens)
+        if enable_prompt_embeds:
+            vllm_outputs = vllm_model.generate_greedy(
+                prompt_embeds, max_tokens)
+            vllm_outputs = _fix_prompt_embed_outputs(
+                vllm_outputs, hf_model, example_prompts)
+        else:
+            vllm_outputs = vllm_model.generate_greedy(
+                example_prompts, max_tokens)
 
     check_outputs_equal(
         outputs_0_lst=hf_outputs,
@@ -108,6 +140,7 @@ def test_models(
     ("distilbert/distilgpt2", "mp", "FLASHINFER", "A100"),
     ("meta-llama/Meta-Llama-3-8B", "ray", "FLASHINFER", "A100"),
 ])
+@pytest.mark.parametrize("enable_prompt_embeds", [True, False])
 def test_models_distributed(
     monkeypatch: pytest.MonkeyPatch,
     hf_runner,
@@ -117,14 +150,22 @@ def test_models_distributed(
     distributed_executor_backend: str,
     attention_backend: str,
     test_suite: str,
+    enable_prompt_embeds: bool,
 ) -> None:
 
+    if enable_prompt_embeds and envs.is_set(
+            "VLLM_USE_V1") and envs.VLLM_USE_V1:
+        pytest.skip("enable_prompt_embeds is not supported in v1.")
+
     if test_suite != TARGET_TEST_SUITE:
         pytest.skip(f"Skip test for {test_suite}")
 
     with monkeypatch.context() as monkeypatch_context:
         if model == "meta-llama/Llama-3.2-1B-Instruct" and distributed_executor_backend == "ray" and attention_backend == "" and test_suite == "L4":  # noqa
-            # test Ray Compiled Graph
+            if enable_prompt_embeds:
+                pytest.skip(
+                    "enable_prompt_embeds does not work with ray compiled dag."
+                )
             monkeypatch_context.setenv("VLLM_USE_RAY_SPMD_WORKER", "1")
             monkeypatch_context.setenv("VLLM_USE_RAY_COMPILED_DAG", "1")
 
@@ -147,12 +188,26 @@ def test_models_distributed(
                 dtype=dtype,
                 tensor_parallel_size=2,
                 distributed_executor_backend=distributed_executor_backend,
+                enable_prompt_embeds=enable_prompt_embeds,
+                gpu_memory_utilization=0.7,
         ) as vllm_model:
-            vllm_outputs = vllm_model.generate_greedy(example_prompts,
-                                                      max_tokens)
-
-        with hf_runner(model, dtype=dtype) as hf_model:
-            hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
+            if enable_prompt_embeds:
+                with hf_runner(model, dtype=dtype) as hf_model:
+                    with torch.no_grad():
+                        prompt_embeds = hf_model.get_prompt_embeddings(
+                            example_prompts)
+                    vllm_outputs = vllm_model.generate_greedy(
+                        prompt_embeds, max_tokens)
+                    vllm_outputs = _fix_prompt_embed_outputs(
+                        vllm_outputs, hf_model, example_prompts)
+                    hf_outputs = hf_model.generate_greedy(
+                        example_prompts, max_tokens)
+            else:
+                vllm_outputs = vllm_model.generate_greedy(
+                    example_prompts, max_tokens)
+                with hf_runner(model, dtype=dtype) as hf_model:
+                    hf_outputs = hf_model.generate_greedy(
+                        example_prompts, max_tokens)
 
         check_outputs_equal(
             outputs_0_lst=hf_outputs,
diff --git a/tests/conftest.py b/tests/conftest.py
index c5700179c228..19c2c6247129 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -430,6 +430,15 @@ def get_inputs(
 
         return all_inputs
 
+    def get_prompt_embeddings(self, prompts: list[str]) -> list[torch.Tensor]:
+        all_inputs = self.get_inputs(prompts)
+        embeddings = []
+        for inputs in all_inputs:
+            input_ids = self.wrap_device(inputs)["input_ids"]
+            embedding = self.model.get_input_embeddings()(input_ids).squeeze(0)
+            embeddings.append(embedding)
+        return embeddings
+
     def classify(self, prompts: list[str]) -> list[str]:
         # output is final logits
         all_inputs = self.get_inputs(prompts)
diff --git a/vllm/sequence.py b/vllm/sequence.py
index 5aa9ae62f542..f5f9c56a7db2 100644
--- a/vllm/sequence.py
+++ b/vllm/sequence.py
@@ -112,12 +112,12 @@ class RequestMetrics:
                             will include model forward, block/sync across workers,
                             cpu-gpu sync time and sampling time.
         spec_token_acceptance_counts: number of accepted speculative tokens at
-                                      each position; the first token is from 
+                                      each position; the first token is from
                                       the target model and is always accepted;
-                                      e.g., when it's [10, 8, 4, 2] for a req, 
+                                      e.g., when it's [10, 8, 4, 2] for a req,
                                       it means there were 10 forward passes in
-                                      total, and there were 8, 4, 2 accepted 
-                                      tokens at 1st, 2nd, 3rd speculation step. 
+                                      total, and there were 8, 4, 2 accepted
+                                      tokens at 1st, 2nd, 3rd speculation step.
     """
     arrival_time: float
     last_token_time: float
@@ -714,9 +714,9 @@ class SequenceGroup:
         trace_headers: OpenTelemetry trace headers.
        prompt_adapter_request: Prompt Adapter request.
         priority: User-defined priority of the request.
-        draft_size: The number of speculative tokens plus one from the target 
+        draft_size: The number of speculative tokens plus one from the target
                     model; equal to max number of tokens a step can generate
-                    for single-draft speculative decoding but larger than 
+                    for single-draft speculative decoding but larger than
                     that for multi-draft SD (currently not supported).
     """
 
@@ -1123,7 +1123,7 @@ def __repr__(self) -> str:
             self.output_embed.shape if self.output_embed is not None else None
         return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
                 f"output_token={self.output_token}, "
-                f"output_embed.shape={output_embed_shape}"
+                f"output_embed.shape={output_embed_shape}, "
                 f"logprobs={self.logprobs})")
 
     def __eq__(self, other: object) -> bool:
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index 8a294de45c81..12025617e512 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -23,7 +23,7 @@
 from vllm.attention.backends.utils import CommonAttentionState
 from vllm.config import CompilationLevel, VllmConfig
 from vllm.core.scheduler import SchedulerOutputs
-from vllm.distributed import get_pp_group
+from vllm.distributed import broadcast_tensor_dict, get_pp_group
 from vllm.distributed.kv_transfer import get_kv_transfer_group
 from vllm.distributed.parallel_state import (get_tensor_model_parallel_rank,
                                              graph_capture)
@@ -872,7 +872,7 @@ def build(self) -> ModelInputForGPU:
         """
         # Combine and flatten intermediate data.
         input_tokens = list[int]()
-        inputs_embeds_lst = list[torch.Tensor]()
+        inputs_embeds_list = list[torch.Tensor]()
         token_types = list[int]()
         for inter_data in self.inter_data_list:
             for cur_input_tokens in inter_data.input_tokens:
@@ -880,15 +880,15 @@ def build(self) -> ModelInputForGPU:
             for cur_token_types in inter_data.token_types:
                 token_types.extend(cur_token_types)
             if inter_data.inputs_embeds is not None:
-                inputs_embeds_lst.append(
+                inputs_embeds_list.append(
                     inter_data.inputs_embeds.to(
                         dtype=self.runner.model_config.dtype,
                         device=self.runner.device))
         inputs_embeds: Optional[torch.Tensor]
-        if len(inputs_embeds_lst) == 0:
+        if len(inputs_embeds_list) == 0:
             inputs_embeds = None
         else:
-            inputs_embeds = torch.cat(inputs_embeds_lst, dim=0).to(
+            inputs_embeds = torch.cat(inputs_embeds_list, dim=0).to(
                 dtype=self.runner.model_config.dtype,
                 device=self.runner.device)
             assert len(inputs_embeds) == len(input_tokens)
@@ -1893,50 +1893,60 @@ def execute_model(
         logits = self.model.compute_logits(hidden_or_intermediate_states,
                                            model_input.sampling_metadata)
 
-        if not self.is_driver_worker:
-            return []
+        if self.is_driver_worker:
+            if model_input.async_callback is not None:
+                model_input.async_callback()
 
-        if model_input.async_callback is not None:
-            model_input.async_callback()
+            # Sample the next token.
+            assert isinstance(self.sampler, Sampler)
+            orig_include_gpu_probs = self.sampler.include_gpu_probs_tensor
+            if model_input.inputs_embeds is not None:
+                self.sampler.include_gpu_probs_tensor = True
 
-        # Sample the next token.
-        assert isinstance(self.sampler, Sampler)
-        orig_include_gpu_probs_tensor = self.sampler.include_gpu_probs_tensor
-        if model_input.inputs_embeds is not None:
-            self.sampler.include_gpu_probs_tensor = True
-
-        output: SamplerOutput = self.sampler(
-            logits=logits,
-            sampling_metadata=model_input.sampling_metadata,
-        )
-        if (self.observability_config is not None
-                and self.observability_config.collect_model_forward_time
-                and output is not None):
-            model_forward_end.synchronize()
-            model_forward_time = model_forward_start.elapsed_time(
-                model_forward_end)
-            orig_model_forward_time = 0.0
-            if intermediate_tensors is not None:
-                orig_model_forward_time = intermediate_tensors.tensors.get(
-                    "model_forward_time", torch.tensor(0.0)).item()
-            # If there are multiple workers, we are still tracking the latency
-            # from the start time of the driver worker to the end time of the
-            # driver worker. The model forward time will then end up covering
-            # the communication time as well.
-            output.model_forward_time = (orig_model_forward_time +
-                                         model_forward_time)
+            output: SamplerOutput = self.sampler(
+                logits=logits,
+                sampling_metadata=model_input.sampling_metadata,
+            )
+            if (self.observability_config is not None
+                    and self.observability_config.collect_model_forward_time
+                    and output is not None):
+                model_forward_end.synchronize()
+                model_forward_time = model_forward_start.elapsed_time(
+                    model_forward_end)
+                orig_model_forward_time = 0.0
+                if intermediate_tensors is not None:
+                    orig_model_forward_time = intermediate_tensors.tensors.get(
+                        "model_forward_time", torch.tensor(0.0)).item()
+                # If there are multiple workers, we are still tracking the
+                # latency from the start time of the driver worker to the end
+                # time of the driver worker. The model forward time will then
+                # end up covering the communication time as well.
+                output.model_forward_time = (orig_model_forward_time +
+                                             model_forward_time)
 
         if model_input.inputs_embeds is not None:
-            self.sampler.include_gpu_probs_tensor = \
-                orig_include_gpu_probs_tensor
-            if output.sampled_token_ids is not None:
-                output.sampled_token_embeds = self.model.get_input_embeddings(
-                    output.sampled_token_ids.squeeze(1))
-
-                for token_embed, sequence_group_output in zip(
-                        output.sampled_token_embeds, output.outputs):
-                    assert len(sequence_group_output.samples) == 1
-                    sequence_group_output.samples[0].output_embed = token_embed
+            if self.is_driver_worker:
+                sampled = broadcast_tensor_dict(
+                    {"token_ids": output.sampled_token_ids})
+            else:
+                sampled = broadcast_tensor_dict()
+            if sampled["token_ids"] is not None:
+                sampled_token_embeds = self.model.get_input_embeddings(
+                    sampled["token_ids"].squeeze(1))
+                if self.is_driver_worker:
+                    self.sampler.include_gpu_probs_tensor = \
+                        orig_include_gpu_probs
+
+                    output.sampled_token_embeds = sampled_token_embeds
+
+                    for token_embed, sequence_group_output in zip(
+                            output.sampled_token_embeds, output.outputs):
+                        assert len(sequence_group_output.samples) == 1
+                        sequence_group_output.samples[
+                            0].output_embed = token_embed
+
+        if not self.is_driver_worker:
+            return []
 
         if self.return_hidden_states:
             # we only need to pass hidden states of most recent token