diff --git a/vllm/entrypoints/api_server.py b/vllm/entrypoints/api_server.py
index 6127177b4d88..dd7fc025a384 100644
--- a/vllm/entrypoints/api_server.py
+++ b/vllm/entrypoints/api_server.py
@@ -44,11 +44,13 @@ async def generate(request: Request) -> Response:
     The request should be a JSON object with the following fields:
     - prompt: the prompt to use for the generation.
     - stream: whether to stream the results or not.
+    - is_return_prompt: whether to include the prompt in the response.
     - other fields: the sampling parameters (See `SamplingParams` for details).
     """
     request_dict = await request.json()
     prompt = request_dict.pop("prompt")
     stream = request_dict.pop("stream", False)
+    is_return_prompt = request_dict.pop("is_return_prompt", False)
     sampling_params = SamplingParams(**request_dict)
     request_id = random_uuid()
 
@@ -63,7 +65,8 @@ async def stream_results() -> AsyncGenerator[bytes, None]:
             prompt = request_output.prompt
             assert prompt is not None
             text_outputs = [
-                prompt + output.text for output in request_output.outputs
+                (prompt + output.text) if is_return_prompt else output.text
+                for output in request_output.outputs
             ]
             ret = {"text": text_outputs}
             yield (json.dumps(ret) + "\0").encode("utf-8")
@@ -82,9 +85,13 @@ async def stream_results() -> AsyncGenerator[bytes, None]:
     assert final_output is not None
     prompt = final_output.prompt
     assert prompt is not None
-    text_outputs = [prompt + output.text for output in final_output.outputs]
+    text_outputs = [
+        (prompt + output.text) if is_return_prompt else output.text
+        for output in final_output.outputs
+    ]
     ret = {"text": text_outputs}
     return JSONResponse(ret)
+
 
 
 def build_app(args: Namespace) -> FastAPI:
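
For reference, a minimal client-side sketch of the new flag. This is an illustration, not part of the patch: the `requests` dependency, the example prompt, and the `max_tokens` field are assumptions, and the `http://localhost:8000/generate` URL assumes the demo server in this file is running with its default host and port.

import json

import requests

# Hypothetical local run of `python -m vllm.entrypoints.api_server`;
# adjust the URL to match your deployment.
API_URL = "http://localhost:8000/generate"

payload = {
    "prompt": "San Francisco is a",
    "max_tokens": 32,           # forwarded to SamplingParams (assumed field)
    "is_return_prompt": False,  # new flag: return only the generated text
}
response = requests.post(API_URL, json=payload)
response.raise_for_status()
print(json.dumps(response.json(), indent=2))  # {"text": ["..."]}

With is_return_prompt left at its default of False, the "text" entries contain only the completion; setting it to True restores the previous behavior of prefixing each completion with the prompt.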