diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index 98e9ea0fc61a..dac7fe81903a 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -24,7 +24,8 @@
     RequestResponseMetadata, ToolCall, UsageInfo)
 from vllm.entrypoints.openai.reasoning_parsers import (ReasoningParser,
                                                        ReasoningParserManager)
-from vllm.entrypoints.openai.serving_engine import OpenAIServing
+from vllm.entrypoints.openai.serving_engine import (OpenAIServing,
+                                                    clamp_prompt_logprobs)
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
 from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
 from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import (
@@ -846,7 +847,7 @@ async def chat_completion_full_generator(
             model=model_name,
             choices=choices,
             usage=usage,
-            prompt_logprobs=final_res.prompt_logprobs,
+            prompt_logprobs=clamp_prompt_logprobs(final_res.prompt_logprobs),
         )
 
         return response
diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py
index ed09af84f64b..df7d575d5b98 100644
--- a/vllm/entrypoints/openai/serving_completion.py
+++ b/vllm/entrypoints/openai/serving_completion.py
@@ -23,7 +23,8 @@
                                               RequestResponseMetadata,
                                               UsageInfo)
 # yapf: enable
-from vllm.entrypoints.openai.serving_engine import OpenAIServing
+from vllm.entrypoints.openai.serving_engine import (OpenAIServing,
+                                                    clamp_prompt_logprobs)
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
 from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
@@ -396,13 +397,7 @@ def request_output_to_completion_response(
         for final_res in final_res_batch:
             prompt_token_ids = final_res.prompt_token_ids
             assert prompt_token_ids is not None
-            prompt_logprobs = final_res.prompt_logprobs
-            if prompt_logprobs:
-                for logprob_dict in prompt_logprobs:
-                    if logprob_dict:
-                        for logprob_values in logprob_dict.values():
-                            if logprob_values.logprob == float('-inf'):
-                                logprob_values.logprob = -9999.0
+            prompt_logprobs = clamp_prompt_logprobs(final_res.prompt_logprobs)
             prompt_text = final_res.prompt
 
             token_ids: GenericSequence[int]
diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py
index 59333dbfd24e..125812d2cc01 100644
--- a/vllm/entrypoints/openai/serving_engine.py
+++ b/vllm/entrypoints/openai/serving_engine.py
@@ -42,7 +42,7 @@
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import BeamSearchParams, SamplingParams
-from vllm.sequence import Logprob
+from vllm.sequence import Logprob, PromptLogprobs
 from vllm.tracing import (contains_trace_headers, extract_trace_headers,
                           log_tracing_disabled_warning)
 from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
@@ -535,3 +535,18 @@ def _get_model_name(self,
         if model_name is None:
             return self.models.base_model_paths[0].name
         return model_name
+
+
+def clamp_prompt_logprobs(
+    prompt_logprobs: Union[PromptLogprobs,
+                           None]) -> Union[PromptLogprobs, None]:
+    if prompt_logprobs is None:
+        return prompt_logprobs
+
+    for logprob_dict in prompt_logprobs:
+        if logprob_dict is None:
+            continue
+        for logprob_values in logprob_dict.values():
+            if logprob_values.logprob == float('-inf'):
+                logprob_values.logprob = -9999.0
+    return prompt_logprobs
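
For illustration only (not part of the patch): a minimal, self-contained Python sketch of the clamping that the new clamp_prompt_logprobs helper performs. The Logprob dataclass below is a stand-in assuming the logprob/rank/decoded_token fields of vllm.sequence.Logprob, and the sample values are invented; the point is that float('-inf') entries, which can break JSON serialization of the API response, are replaced in place with the finite sentinel -9999.0 while all other logprobs pass through unchanged.

# Illustrative sketch only: a stand-in Logprob mirroring the assumed
# logprob/rank/decoded_token fields of vllm.sequence.Logprob.
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class Logprob:
    logprob: float
    rank: Optional[int] = None
    decoded_token: Optional[str] = None


# PromptLogprobs is a list with one (possibly None) dict per prompt token,
# keyed by token id; the first entry is None because the first prompt token
# has no preceding context. Values here are made up for the example.
prompt_logprobs: List[Optional[Dict[int, Logprob]]] = [
    None,
    {42: Logprob(logprob=float('-inf'), rank=1, decoded_token="foo")},
    {7: Logprob(logprob=-0.25, rank=1, decoded_token="bar")},
]

# The same clamping loop the patch centralizes in clamp_prompt_logprobs:
# replace -inf with a finite sentinel, leave everything else untouched.
for logprob_dict in prompt_logprobs:
    if logprob_dict is None:
        continue
    for logprob_values in logprob_dict.values():
        if logprob_values.logprob == float('-inf'):
            logprob_values.logprob = -9999.0

print(prompt_logprobs[1][42].logprob)  # -9999.0
print(prompt_logprobs[2][7].logprob)   # -0.25 (unchanged)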