From 20506dd41da5a33e3ad684075274a9123a9b8a63 Mon Sep 17 00:00:00 2001
From: NickLucche
Date: Wed, 12 Feb 2025 08:29:32 +0000
Subject: [PATCH 1/8] stream with own chunk type + usage info

Signed-off-by: NickLucche
---
 .../serving/openai_compatible_server.md       |   4 +
 .../openai_chat_completion_client.py          |  34 ++--
 .../openai_transcription_client.py            |  63 +++++--
 .../openai/test_transcription_validation.py   |  71 ++++++++
 vllm/entrypoints/openai/protocol.py           |  40 ++++-
 .../openai/serving_transcription.py           | 166 ++++++++++++++++--
 6 files changed, 333 insertions(+), 45 deletions(-)

diff --git a/docs/source/serving/openai_compatible_server.md b/docs/source/serving/openai_compatible_server.md
index 5ab46da90ea6..0880a4530d8c 100644
--- a/docs/source/serving/openai_compatible_server.md
+++ b/docs/source/serving/openai_compatible_server.md
@@ -379,6 +379,10 @@ For chat-like input (i.e. if `messages` is passed), these extra parameters are s
 
 Our Transcriptions API is compatible with [OpenAI's Transcriptions API](https://platform.openai.com/docs/api-reference/audio/createTranscription);
 you can use the [official OpenAI Python client](https://github.com/openai/openai-python) to interact with it.
 
+:::{note}
+To use the Transcriptions API, please install with extra audio dependencies using `pip install vllm[audio]`.
+:::
+
 Code example:
 
diff --git a/examples/online_serving/openai_chat_completion_client.py b/examples/online_serving/openai_chat_completion_client.py
index a81562041130..12af3b001d20 100644
--- a/examples/online_serving/openai_chat_completion_client.py
+++ b/examples/online_serving/openai_chat_completion_client.py
@@ -15,24 +15,20 @@
 models = client.models.list()
 model = models.data[0].id
 
-chat_completion = client.chat.completions.create(
-    messages=[{
-        "role": "system",
-        "content": "You are a helpful assistant."
-    }, {
-        "role": "user",
-        "content": "Who won the world series in 2020?"
-    }, {
-        "role":
-        "assistant",
-        "content":
-        "The Los Angeles Dodgers won the World Series in 2020."
-    }, {
-        "role": "user",
-        "content": "Where was it played?"
-    }],
-    model=model,
-)
+chat_completion = client.chat.completions.create(messages=[{
+    "role":
+    "system",
+    "content":
+    "You are a helpful assistant."
+}, {
+    "role": "user",
+    "content": "Say Hi"
+}],
+                                                  model=model,
+                                                  stream=True)
 
 print("Chat completion results:")
-print(chat_completion)
+# print(chat_completion)
+for chunk in chat_completion:
+    print(chunk, '\n\n')
+    # print(chunk.choices[0].delta.content)
diff --git a/examples/online_serving/openai_transcription_client.py b/examples/online_serving/openai_transcription_client.py
index bd3c02a8a95e..50960f8c3373 100644
--- a/examples/online_serving/openai_transcription_client.py
+++ b/examples/online_serving/openai_transcription_client.py
@@ -1,5 +1,9 @@
 # SPDX-License-Identifier: Apache-2.0
-from openai import OpenAI
+import asyncio
+import json
+
+import httpx
+from openai import AsyncOpenAI
 
 from vllm.assets.audio import AudioAsset
 
@@ -9,15 +13,54 @@
 # Modify OpenAI's API key and API base to use vLLM's API server.
 openai_api_key = "EMPTY"
 openai_api_base = "http://localhost:8000/v1"
-client = OpenAI(
+client = AsyncOpenAI(
     api_key=openai_api_key,
     base_url=openai_api_base,
 )
-with open(str(mary_had_lamb), "rb") as f:
-    transcription = client.audio.transcriptions.create(
-        file=f,
-        model="openai/whisper-large-v3",
-        language="en",
-        response_format="text",
-        temperature=0.0)
-    print("transcription result:", transcription)
+
+
+async def main():
+    with open(str(mary_had_lamb), "rb") as f:
+        transcription = await client.audio.transcriptions.create(
+            file=f,
+            model="openai/whisper-small",
+            language="en",
+            response_format="json",
+            temperature=0.0)
+        print("transcription result:", transcription.text)
+
+
+asyncio.run(main())
+
+
+# OpenAI Transcription API client does not support streaming.
+async def stream_openai_response():
+    data = {
+        "language": "en",
+        'stream': True,
+        "model": "openai/whisper-large-v3",
+    }
+    url = openai_api_base + "/audio/transcriptions"
+    print("transcription result:", end=' ')
+    async with httpx.AsyncClient() as client:
+        with open(str(winning_call), "rb") as f:
+            async with client.stream('POST', url, files={'file': f},
+                                     data=data) as response:
+                async for line in response.aiter_lines():
+                    # Each line is a JSON object prefixed with 'data: '
+                    if line:
+                        if line.startswith('data: '):
+                            line = line[len('data: '):]
+                        # Last chunk, stream ends
+                        if line.strip() == '[DONE]':
+                            break
+                        # Parse the JSON response
+                        chunk = json.loads(line)
+                        # Extract and print the content
+                        content = chunk['choices'][0].get('delta',
+                                                          {}).get('content')
+                        print(content, end='')
+
+
+# Run the asynchronous function
+asyncio.run(stream_openai_response())
 
diff --git a/tests/entrypoints/openai/test_transcription_validation.py b/tests/entrypoints/openai/test_transcription_validation.py
index 5d4a5de4badd..453d2d14db16 100644
--- a/tests/entrypoints/openai/test_transcription_validation.py
+++ b/tests/entrypoints/openai/test_transcription_validation.py
@@ -3,12 +3,14 @@
 # imports for guided decoding tests
 import io
 import json
+from unittest.mock import patch
 
 import librosa
 import numpy as np
 import openai
 import pytest
 import soundfile as sf
+from openai._base_client import AsyncAPIClient
 
 from vllm.assets.audio import AudioAsset
 
@@ -120,3 +122,72 @@ async def test_completion_endpoints():
         res = await client.completions.create(model=model_name, prompt="Hello")
         assert res.code == 400
         assert res.message == "The model does not support Completions API"
+
+
+async def test_streaming_response(winning_call):
+    model_name = "openai/whisper-small"
+    server_args = ["--enforce-eager"]
+    transcription = ""
+    with RemoteOpenAIServer(model_name, server_args) as remote_server:
+        client = remote_server.get_async_client()
+        res_no_stream = await client.audio.transcriptions.create(
+            model=model_name,
+            file=winning_call,
+            response_format="json",
+            language="en",
+            temperature=0.0)
+        # Unfortunately this only works when the openai client is patched
+        # to use streaming mode, not exposed in the transcription api.
+        original_post = AsyncAPIClient.post
+
+        async def post_with_stream(*args, **kwargs):
+            kwargs['stream'] = True
+            return await original_post(*args, **kwargs)
+
+        with patch.object(AsyncAPIClient, "post", new=post_with_stream):
+            client = remote_server.get_async_client()
+            res = await client.audio.transcriptions.create(
+                model=model_name,
+                file=winning_call,
+                language="en",
+                temperature=0.0,
+                extra_body=dict(stream=True))
+            # Reconstruct from chunks and validate
+            async for chunk in res:
+                # just a chunk
+                text = chunk.choices[0]['delta']['content']
+                transcription += text
+
+        assert transcription == res_no_stream.text
+
+
+@pytest.mark.asyncio
+async def test_stream_options(winning_call):
+    model_name = "openai/whisper-small"
+    server_args = ["--enforce-eager"]
+    with RemoteOpenAIServer(model_name, server_args) as remote_server:
+        original_post = AsyncAPIClient.post
+
+        async def post_with_stream(*args, **kwargs):
+            kwargs['stream'] = True
+            return await original_post(*args, **kwargs)
+
+        with patch.object(AsyncAPIClient, "post", new=post_with_stream):
+            client = remote_server.get_async_client()
+            res = await client.audio.transcriptions.create(
+                model=model_name,
+                file=winning_call,
+                language="en",
+                temperature=0.0,
+                extra_body=dict(stream=True,
+                                stream_include_usage=True,
+                                stream_continuous_usage_stats=True))
+            final = False
+            continuous = True
+            async for chunk in res:
+                if not len(chunk.choices):
+                    # final usage sent
+                    final = True
+                else:
+                    continuous = continuous and hasattr(chunk, 'usage')
+            assert final and continuous
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index 2c740caf20fb..ab74a6592cf6 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -1289,6 +1289,21 @@ class ChatCompletionStreamResponse(OpenAIBaseModel):
     usage: Optional[UsageInfo] = Field(default=None)
 
 
+class TranscriptionResponseStreamChoice(OpenAIBaseModel):
+    delta: DeltaMessage
+    finish_reason: Optional[str] = None
+    stop_reason: Optional[Union[int, str]] = None
+
+
+class TranscriptionStreamResponse(OpenAIBaseModel):
+    id: str = Field(default_factory=lambda: f"trsc-{random_uuid()}")
+    object: Literal["transcription.chunk"] = "transcription.chunk"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    model: str
+    choices: List[TranscriptionResponseStreamChoice]
+    usage: Optional[UsageInfo] = Field(default=None)
+
+
 class BatchRequestInput(OpenAIBaseModel):
     """
     The per-line object of the batch input file.
@@ -1514,6 +1529,15 @@ class TranscriptionRequest(OpenAIBaseModel):
     timestamps incurs additional latency.
     """
 
+    stream: Optional[bool] = False
+    """Custom field not present in the original OpenAI definition. When set,
+    it will enable output to be streamed in a similar fashion as the Chat
+    Completion endpoint.
+    """
+    # Flattened stream option to simplify form data.
+    stream_include_usage: Optional[bool] = False
+    stream_continuous_usage_stats: Optional[bool] = False
+
     # Default sampling parameters for transcription requests.
     _DEFAULT_SAMPLING_PARAMS: dict = {
         "temperature": 0,
@@ -1534,7 +1558,21 @@ def to_sampling_params(
             "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"])
 
         return SamplingParams.from_optional(temperature=temperature,
-                                            max_tokens=max_tokens)
+                                            max_tokens=max_tokens,
+                                            output_kind=RequestOutputKind.DELTA
+                                            if self.stream \
+                                            else RequestOutputKind.FINAL_ONLY)
+
+    @model_validator(mode="before")
+    @classmethod
+    def validate_stream_options(cls, data):
+        stream_opts = ["stream_include_usage", "stream_continuous_usage_stats"]
+        stream = data.get("stream", False)
+        if any(bool(data.get(so, False)) for so in stream_opts) and not stream:
+            raise ValueError(
+                "Stream options can only be defined when `stream=True`.")
+
+        return data
 
 
 # Transcription response objects
diff --git a/vllm/entrypoints/openai/serving_transcription.py b/vllm/entrypoints/openai/serving_transcription.py
index 402a0bb7a6b0..add99b75df75 100644
--- a/vllm/entrypoints/openai/serving_transcription.py
+++ b/vllm/entrypoints/openai/serving_transcription.py
@@ -2,23 +2,25 @@
 import asyncio
 import io
 from collections.abc import AsyncGenerator
-from typing import Optional, Union, cast
+import time
+from math import ceil
+from typing import AsyncGenerator, Final, Optional, Union, cast
 
 from fastapi import Request
 
 from vllm.config import ModelConfig
 from vllm.engine.protocol import EngineClient
 from vllm.entrypoints.logger import RequestLogger
-from vllm.entrypoints.openai.protocol import (ErrorResponse,
-                                              RequestResponseMetadata,
-                                              TranscriptionRequest,
-                                              TranscriptionResponse,
-                                              TranscriptionResponseVerbose)
+from vllm.entrypoints.openai.protocol import (
+    DeltaMessage, ErrorResponse, RequestResponseMetadata, TranscriptionRequest,
+    TranscriptionResponse, TranscriptionResponseStreamChoice,
+    TranscriptionResponseVerbose, TranscriptionStreamResponse, UsageInfo)
 from vllm.entrypoints.openai.serving_engine import OpenAIServing
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
 from vllm.inputs.data import PromptType
 from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
+from vllm.transformers_utils.processor import cached_get_processor
 from vllm.utils import PlaceholderModule
 
 try:
@@ -140,8 +142,6 @@
 # As per https://platform.openai.com/docs/guides/speech-to-text#overview.
 # TODO configurable
 MAX_AUDIO_CLIP_FILESIZE_MB = 25
-# TODO get from processor.feature_extractor.chunk_length
-MAX_AUDIO_CLIP_DURATION_S = 30
 
 
 class OpenAIServingTranscription(OpenAIServing):
@@ -163,6 +163,11 @@ def __init__(
         self.default_sampling_params = (
             self.model_config.get_diff_sampling_param())
 
+        processor = cached_get_processor(model_config.model)
+        self.max_audio_clip_s = processor.feature_extractor.chunk_length
+        self.model_sr = processor.feature_extractor.sampling_rate
+        self.hop_length = processor.feature_extractor.hop_length
+
         if self.default_sampling_params:
             logger.info(
                 "Overwriting default completion sampling param with: %s",
@@ -198,9 +203,11 @@ async def _preprocess_transcription(
 
         with io.BytesIO(audio_data) as bytes_:
             y, sr = librosa.load(bytes_)
-        if librosa.get_duration(y=y, sr=sr) > MAX_AUDIO_CLIP_DURATION_S:
+
+        duration = librosa.get_duration(y=y, sr=sr)
+        if duration > self.max_audio_clip_s:
             raise ValueError(
-                f"Maximum clip duration ({MAX_AUDIO_CLIP_DURATION_S}s) "
+                f"Maximum clip duration ({self.max_audio_clip_s}s) "
                 "exceeded.")
 
         prompt = {
@@ -213,7 +220,7 @@ async def _preprocess_transcription(
             "decoder_prompt":
             f"<|startoftranscript|>{lang_token}<|transcribe|><|notimestamps|>{request.prompt}"
         }
-        return cast(PromptType, prompt)
+        return cast(PromptType, prompt), duration
 
     # TODO (varun) : Make verbose response work !
     async def create_transcription(
@@ -240,8 +247,7 @@ async def create_transcription(
             return self.create_error_response(
                 "Currently only support response_format `text` or `json`")
 
-        # TODO cmpl->transcription?
-        request_id = f"cmpl-{self._base_request_id(raw_request)}"
+        request_id = f"trsc-{self._base_request_id(raw_request)}"
         request_metadata = RequestResponseMetadata(request_id=request_id)
 
         if raw_request:
@@ -261,7 +267,7 @@ async def create_transcription(
                 "Currently do not support PromptAdapter for Transcription."
             )
 
-        prompt = await self._preprocess_transcription(
+        prompt, duration_s = await self._preprocess_transcription(
            request=request,
            audio_data=audio_data,
        )
@@ -293,7 +299,12 @@ async def create_transcription(
             # TODO: Use a vllm-specific Validation Error
             return self.create_error_response(str(e))
 
-        # TODO(rob): figure out a way to pipe streaming in.
+        if request.stream:
+            return self.transcription_stream_generator(request,
+                                                        result_generator,
+                                                        request_id,
+                                                        request_metadata,
+                                                        duration_s)
         # Non-streaming response.
         try:
             assert result_generator is not None
@@ -305,3 +316,128 @@ async def create_transcription(
         except ValueError as e:
             # TODO: Use a vllm-specific Validation Error
             return self.create_error_response(str(e))
+
+    async def transcription_stream_generator(
+            self, request: TranscriptionRequest,
+            result_generator: AsyncGenerator[RequestOutput, None],
+            request_id: str, request_metadata: RequestResponseMetadata,
+            audio_duration_s: float) -> AsyncGenerator[str, None]:
+        created_time = int(time.time())
+        model_name = request.model
+        chunk_object_type: Final = "transcription.chunk"
+        first_iteration = True
+
+        completion_tokens = 0
+        num_prompt_tokens = 0
+
+        include_usage = request.stream_include_usage \
+            if request.stream_include_usage else False
+        include_continuous_usage = request.stream_continuous_usage_stats\
+            if include_usage and request.stream_continuous_usage_stats\
+            else False
+
+        try:
+            async for res in result_generator:
+                # On first result.
+                if res.prompt_token_ids is not None:
+                    # Do not count the 4 tokens `<|startoftranscript|>..`
+                    # Could be negative when language token is not specified.
+                    num_prompt_tokens = max(len(res.prompt_token_ids) - 4, 0)
+                    # NOTE(NickLucche) user can't pass encoder prompts directly
+                    # at least not to Whisper. One indicator of the encoder
+                    # amount of processing is the log-mel spectrogram length.
+                    num_prompt_tokens = ceil(audio_duration_s * self.model_sr /
+                                             self.hop_length)
+
+                # We need to do it here, because if there are exceptions in
+                # the result_generator, it needs to be sent as the FIRST
+                # response (by the try...catch).
+                if first_iteration:
+                    # First delta message.
+                    choice_data = TranscriptionResponseStreamChoice(
+                        delta=DeltaMessage(content="", ), finish_reason=None)
+                    chunk = TranscriptionStreamResponse(
+                        id=request_id,
+                        object=chunk_object_type,
+                        created=created_time,
+                        choices=[choice_data],
+                        model=model_name)
+
+                    # if continuous usage stats are requested, add it
+                    if include_continuous_usage:
+                        chunk.usage = UsageInfo(
+                            prompt_tokens=num_prompt_tokens,
+                            completion_tokens=0,
+                            total_tokens=num_prompt_tokens)
+
+                    data = chunk.model_dump_json(exclude_unset=True)
+                    yield f"data: {data}\n\n"
+
+                    first_iteration = False
+
+                # Just one output (n=1) supported.
+                output = res.outputs[0]
+
+                delta_message = DeltaMessage(content=output.text)
+                completion_tokens += len(output.token_ids)
+
+                if output.finish_reason is None:
+                    # Still generating, send delta update.
+                    choice_data = TranscriptionResponseStreamChoice(
+                        delta=delta_message)
+                else:
+                    # Model is finished generating.
+                    choice_data = TranscriptionResponseStreamChoice(
+                        delta=delta_message,
+                        finish_reason=output.finish_reason,
+                        stop_reason=output.stop_reason)
+
+                chunk = TranscriptionStreamResponse(id=request_id,
+                                                    object=chunk_object_type,
+                                                    created=created_time,
+                                                    choices=[choice_data],
+                                                    model=model_name)
+
+                # handle usage stats if requested & if continuous
+                if include_continuous_usage:
+                    chunk.usage = UsageInfo(
+                        prompt_tokens=num_prompt_tokens,
+                        completion_tokens=completion_tokens,
+                        total_tokens=num_prompt_tokens + completion_tokens,
+                    )
+
+                data = chunk.model_dump_json(exclude_unset=True)
+                yield f"data: {data}\n\n"
+
+            # Once the final token is handled, if stream_options.include_usage
+            # is sent, send the usage.
+            if include_usage:
+                final_usage = UsageInfo(prompt_tokens=num_prompt_tokens,
+                                        completion_tokens=completion_tokens,
+                                        total_tokens=num_prompt_tokens +
+                                        completion_tokens)
+
+                final_usage_chunk = TranscriptionStreamResponse(
+                    id=request_id,
+                    object=chunk_object_type,
+                    created=created_time,
+                    choices=[],
+                    model=model_name,
+                    usage=final_usage)
+                final_usage_data = (final_usage_chunk.model_dump_json(
+                    exclude_unset=True, exclude_none=True))
+                yield f"data: {final_usage_data}\n\n"
+
+            # report to FastAPI middleware aggregate usage across all choices
+            request_metadata.final_usage_info = UsageInfo(
+                prompt_tokens=num_prompt_tokens,
+                completion_tokens=completion_tokens,
+                total_tokens=num_prompt_tokens + completion_tokens)
+
+        except Exception as e:
+            # TODO: Use a vllm-specific Validation Error
+            logger.exception("Error in transcription stream generator.")
+            data = self.create_streaming_error_response(str(e))
+            yield f"data: {data}\n\n"
+        # Send the final done message after all response.n are finished
+        yield "data: [DONE]\n\n"

From 878a499fdf89355a11c0b8e02e7396d3ade14b7f Mon Sep 17 00:00:00 2001
From: NickLucche
Date: Fri, 14 Feb 2025 17:43:20 +0000
Subject: [PATCH 2/8] rebase leftover

Signed-off-by: NickLucche
---
 .../openai_chat_completion_client.py | 34 +++++++++++--------
 1 file changed, 19 insertions(+), 15 deletions(-)

diff --git a/examples/online_serving/openai_chat_completion_client.py b/examples/online_serving/openai_chat_completion_client.py
index 12af3b001d20..a81562041130 100644
--- a/examples/online_serving/openai_chat_completion_client.py
+++ b/examples/online_serving/openai_chat_completion_client.py
@@ -15,20 +15,24 @@
 models = client.models.list()
 model = models.data[0].id
 
-chat_completion = client.chat.completions.create(messages=[{
-    "role":
-    "system",
-    "content":
-    "You are a helpful assistant."
-}, {
-    "role": "user",
-    "content": "Say Hi"
-}],
-                                                  model=model,
-                                                  stream=True)
+chat_completion = client.chat.completions.create(
+    messages=[{
+        "role": "system",
+        "content": "You are a helpful assistant."
+    }, {
+        "role": "user",
+        "content": "Who won the world series in 2020?"
+    }, {
+        "role":
+        "assistant",
+        "content":
+        "The Los Angeles Dodgers won the World Series in 2020."
+    }, {
+        "role": "user",
+        "content": "Where was it played?"
+    }],
+    model=model,
+)
 
 print("Chat completion results:")
-# print(chat_completion)
-for chunk in chat_completion:
-    print(chunk, '\n\n')
-    # print(chunk.choices[0].delta.content)
+print(chat_completion)

From a01373ed3470f8a7c86b176c9eea0affd5b21b14 Mon Sep 17 00:00:00 2001
From: NickLucche
Date: Fri, 14 Feb 2025 17:45:52 +0000
Subject: [PATCH 3/8] rebase leftover

Signed-off-by: NickLucche
---
 tests/entrypoints/openai/test_transcription_validation.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tests/entrypoints/openai/test_transcription_validation.py b/tests/entrypoints/openai/test_transcription_validation.py
index 453d2d14db16..29571bcd7649 100644
--- a/tests/entrypoints/openai/test_transcription_validation.py
+++ b/tests/entrypoints/openai/test_transcription_validation.py
@@ -124,6 +124,7 @@ async def test_completion_endpoints():
         assert res.message == "The model does not support Completions API"
 
 
+@pytest.mark.asyncio
 async def test_streaming_response(winning_call):
     model_name = "openai/whisper-small"
     server_args = ["--enforce-eager"]

From ab8802f82ab9fafd8ddbf6f915d76db1c9c5a783 Mon Sep 17 00:00:00 2001
From: NickLucche
Date: Fri, 14 Feb 2025 18:12:27 +0000
Subject: [PATCH 4/8] minor

Signed-off-by: NickLucche
---
 examples/online_serving/openai_transcription_client.py | 10 +++++-----
 vllm/entrypoints/openai/serving_transcription.py       |  4 ++--
 2 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/examples/online_serving/openai_transcription_client.py b/examples/online_serving/openai_transcription_client.py
index 50960f8c3373..494e7c8ebe12 100644
--- a/examples/online_serving/openai_transcription_client.py
+++ b/examples/online_serving/openai_transcription_client.py
@@ -3,7 +3,7 @@
 import json
 
 import httpx
-from openai import AsyncOpenAI
+from openai import OpenAI
 
 from vllm.assets.audio import AudioAsset
 
@@ -13,15 +13,15 @@
 # Modify OpenAI's API key and API base to use vLLM's API server.
 openai_api_key = "EMPTY"
 openai_api_base = "http://localhost:8000/v1"
-client = AsyncOpenAI(
+client = OpenAI(
     api_key=openai_api_key,
     base_url=openai_api_base,
 )
 
 
-async def main():
+def sync_openai():
     with open(str(mary_had_lamb), "rb") as f:
-        transcription = await client.audio.transcriptions.create(
+        transcription = client.audio.transcriptions.create(
             file=f,
             model="openai/whisper-small",
             language="en",
@@ -30,7 +30,7 @@ async def main():
     print("transcription result:", transcription.text)
 
 
-asyncio.run(main())
+sync_openai()
 
 
 # OpenAI Transcription API client does not support streaming.
diff --git a/vllm/entrypoints/openai/serving_transcription.py b/vllm/entrypoints/openai/serving_transcription.py
index add99b75df75..4a3267b8befa 100644
--- a/vllm/entrypoints/openai/serving_transcription.py
+++ b/vllm/entrypoints/openai/serving_transcription.py
@@ -346,8 +346,8 @@ async def transcription_stream_generator(
                     # NOTE(NickLucche) user can't pass encoder prompts directly
                     # at least not to Whisper. One indicator of the encoder
                     # amount of processing is the log-mel spectrogram length.
-                    num_prompt_tokens = ceil(audio_duration_s * self.model_sr /
-                                             self.hop_length)
+                    num_prompt_tokens += ceil(audio_duration_s *
+                                              self.model_sr / self.hop_length)
 
                 # We need to do it here, because if there are exceptions in
                 # the result_generator, it needs to be sent as the FIRST

From 121b226901faa55408f9145b968491b80699a420 Mon Sep 17 00:00:00 2001
From: NickLucche
Date: Mon, 17 Feb 2025 08:05:34 +0000
Subject: [PATCH 5/8] types

Signed-off-by: NickLucche
---
 vllm/entrypoints/openai/serving_transcription.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/vllm/entrypoints/openai/serving_transcription.py b/vllm/entrypoints/openai/serving_transcription.py
index 4a3267b8befa..74e28bc9eade 100644
--- a/vllm/entrypoints/openai/serving_transcription.py
+++ b/vllm/entrypoints/openai/serving_transcription.py
@@ -4,7 +4,7 @@
 from collections.abc import AsyncGenerator
 import time
 from math import ceil
-from typing import AsyncGenerator, Final, Optional, Union, cast
+from typing import AsyncGenerator, Final, Optional, Tuple, Union, cast
 
 from fastapi import Request
 
@@ -14,7 +14,7 @@
 from vllm.entrypoints.openai.protocol import (
     DeltaMessage, ErrorResponse, RequestResponseMetadata, TranscriptionRequest,
     TranscriptionResponse, TranscriptionResponseStreamChoice,
-    TranscriptionResponseVerbose, TranscriptionStreamResponse, UsageInfo)
+    TranscriptionStreamResponse, UsageInfo)
 from vllm.entrypoints.openai.serving_engine import OpenAIServing
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
 from vllm.inputs.data import PromptType
@@ -177,7 +177,7 @@ async def _preprocess_transcription(
         self,
         request: TranscriptionRequest,
         audio_data: bytes,
-    ) -> PromptType:
+    ) -> Tuple[PromptType, float]:
         # Validate request
         # TODO language should be optional and can be guessed.
         # For now we default to en. See
@@ -226,7 +226,7 @@ async def create_transcription(
         self, audio_data: bytes, request: TranscriptionRequest,
         raw_request: Request
-    ) -> Union[TranscriptionResponse, TranscriptionResponseVerbose,
+    ) -> Union[TranscriptionResponse, AsyncGenerator[str, None],
                ErrorResponse]:
         """Transcription API similar to OpenAI's API.
 

From 96ee3521d11c8045ea32268a62f8532bde5a260d Mon Sep 17 00:00:00 2001
From: NickLucche
Date: Wed, 26 Feb 2025 13:55:23 +0000
Subject: [PATCH 6/8] remove first empty chunk from stream response

Signed-off-by: NickLucche
---
 .../openai/serving_transcription.py | 24 +------------------
 1 file changed, 1 insertion(+), 23 deletions(-)

diff --git a/vllm/entrypoints/openai/serving_transcription.py b/vllm/entrypoints/openai/serving_transcription.py
index 74e28bc9eade..788e33f08ddb 100644
--- a/vllm/entrypoints/openai/serving_transcription.py
+++ b/vllm/entrypoints/openai/serving_transcription.py
@@ -325,7 +325,6 @@ async def transcription_stream_generator(
         created_time = int(time.time())
         model_name = request.model
         chunk_object_type: Final = "transcription.chunk"
-        first_iteration = True
 
         completion_tokens = 0
         num_prompt_tokens = 0
@@ -352,30 +351,9 @@ async def transcription_stream_generator(
                 # We need to do it here, because if there are exceptions in
                 # the result_generator, it needs to be sent as the FIRST
                 # response (by the try...catch).
-                if first_iteration:
-                    # First delta message.
-                    choice_data = TranscriptionResponseStreamChoice(
-                        delta=DeltaMessage(content="", ), finish_reason=None)
-                    chunk = TranscriptionStreamResponse(
-                        id=request_id,
-                        object=chunk_object_type,
-                        created=created_time,
-                        choices=[choice_data],
-                        model=model_name)
-
-                    # if continuous usage stats are requested, add it
-                    if include_continuous_usage:
-                        chunk.usage = UsageInfo(
-                            prompt_tokens=num_prompt_tokens,
-                            completion_tokens=0,
-                            total_tokens=num_prompt_tokens)
-
-                    data = chunk.model_dump_json(exclude_unset=True)
-                    yield f"data: {data}\n\n"
-
-                    first_iteration = False
-
                 # Just one output (n=1) supported.
+                assert len(res.outputs) == 1
                 output = res.outputs[0]
 
                 delta_message = DeltaMessage(content=output.text)
                 completion_tokens += len(output.token_ids)
 

From b980d1e16083a08a9f0516436abb4db6c8d5dd0a Mon Sep 17 00:00:00 2001
From: NickLucche
Date: Wed, 5 Mar 2025 13:40:41 +0000
Subject: [PATCH 7/8] fix List

Signed-off-by: NickLucche
---
 vllm/entrypoints/openai/protocol.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index ab74a6592cf6..6b519e1b7041 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -1300,7 +1300,7 @@ class TranscriptionStreamResponse(OpenAIBaseModel):
     object: Literal["transcription.chunk"] = "transcription.chunk"
     created: int = Field(default_factory=lambda: int(time.time()))
     model: str
-    choices: List[TranscriptionResponseStreamChoice]
+    choices: list[TranscriptionResponseStreamChoice]
     usage: Optional[UsageInfo] = Field(default=None)
 
 

From e9683246d6d196a6b3b0a577d350b9a452b339e0 Mon Sep 17 00:00:00 2001
From: NickLucche
Date: Wed, 5 Mar 2025 14:00:19 +0000
Subject: [PATCH 8/8] fix Tuple

Signed-off-by: NickLucche
---
 vllm/entrypoints/openai/serving_transcription.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/vllm/entrypoints/openai/serving_transcription.py b/vllm/entrypoints/openai/serving_transcription.py
index 788e33f08ddb..13565d0ef8dd 100644
--- a/vllm/entrypoints/openai/serving_transcription.py
+++ b/vllm/entrypoints/openai/serving_transcription.py
@@ -1,10 +1,10 @@
 # SPDX-License-Identifier: Apache-2.0
 import asyncio
 import io
-from collections.abc import AsyncGenerator
 import time
+from collections.abc import AsyncGenerator
 from math import ceil
-from typing import AsyncGenerator, Final, Optional, Tuple, Union, cast
+from typing import Final, Optional, Union, cast
 
 from fastapi import Request
 
@@ -177,7 +177,7 @@ async def _preprocess_transcription(
         self,
         request: TranscriptionRequest,
         audio_data: bytes,
-    ) -> Tuple[PromptType, float]:
+    ) -> tuple[PromptType, float]:
         # Validate request
         # TODO language should be optional and can be guessed.
         # For now we default to en. See