diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index 543701ed144e..5f5c7140fce4 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -3,6 +3,7 @@
 # Adapted from
 # https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
+import copy
 import json
 import time
 from http import HTTPStatus
@@ -722,66 +723,118 @@ def _get_guided_json_from_tool(
             if tool_name not in tools:
                 raise ValueError(
                     f"Tool '{tool_name}' has not been passed in `tools`.")
-            tool = tools[tool_name]
-            return tool.parameters
+            return self.generate_function_call_array_schema_anyOf(
+                self.tools, tool_name=tool_name)
 
         if self.tool_choice == "required":
-            # Pydantic schema generation cannot be used since the JSON schema
-            # has to be constructed for a specific instantiation of a tool list
-            # so that parameters of a function are correctly generated
-            # based on the chosen function name
-            def get_tool_schema(tool: ChatCompletionToolsParam) -> dict:
-                return {
-                    "properties": {
-                        "name": {
-                            "type": "string",
-                            "enum": [tool.function.name]
-                        },
-                        # parameters are always generated as '{}' in the final
-                        # output if they are missing from the request
-                        # (i.e. are None or '{}') so the schema is
-                        # updated to produce an empty object in that case
-                        "parameters": tool.function.parameters
-                        if tool.function.parameters else {
-                            "type": "object",
-                            "properties": {}
-                        }
-                    },
-                    "required": ["name", "parameters"]
-                }
-
-            def get_tool_schema_defs(
-                    tools: list[ChatCompletionToolsParam]) -> dict:
-                all_defs = dict[str, dict[str, Any]]()
-                for tool in tools:
-                    if tool.function.parameters is None:
-                        continue
-                    defs = tool.function.parameters.pop("$defs", {})
-                    for def_name, def_schema in defs.items():
-                        if def_name in all_defs and all_defs[
-                                def_name] != def_schema:
-                            raise ValueError(
-                                f"Tool definition '{def_name}' has "
-                                "multiple schemas, which is not "
-                                "supported.")
-                        else:
-                            all_defs[def_name] = def_schema
-                return all_defs
-
-            json_schema = {
-                "type": "array",
-                "minItems": 1,
-                "items": {
-                    "type": "object",
-                    "anyOf": [get_tool_schema(tool) for tool in self.tools]
-                }
-            }
-            json_schema_defs = get_tool_schema_defs(self.tools)
-            if json_schema_defs:
-                json_schema["$defs"] = json_schema_defs
-            return json_schema
-
-        return None
+            return self.generate_function_call_array_schema_anyOf(
+                self.tools, required=True)
+
+        return self.generate_function_call_array_schema_anyOf(self.tools)
+
+    def generate_function_call_array_schema_anyOf(
+        self,
+        function_defs: Optional[list[ChatCompletionToolsParam]] = None,
+        required: bool = False,
+        tool_name: Optional[str] = None,
+    ) -> Optional[dict[str, Any]]:
+        """
+        Generate a JSON Schema for an array of function calls (anyOf
+        version), optionally filtered to a specific tool name.
+
+        Parameters:
+            function_defs: List of tool definitions. Each item's "function"
+                contains a "name" (str), an optional "description" (str) and
+                "parameters" (dict). "parameters" may contain a "$defs" dict
+                whose sub-schemas are hoisted to the top-level "$defs".
+            required: bool
+                If True and tool_name is None:
+                - The array must have at least one element (minItems=1)
+                - Arbitrary (non tool call) values in the array are NOT
+                  allowed
+            tool_name: str | None
+                If specified:
+                - Only this tool's schema is allowed in the array items
+                - The array must contain exactly one element
+                  (minItems=1, maxItems=1)
+                - Arbitrary values are NOT allowed
+
+        Returns:
+            A JSON Schema dict describing the array structure.
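+
+        Example (illustrative only; the "get_weather" tool below is
+        hypothetical, not part of vLLM): for a single tool whose
+        parameters schema is
+        {"type": "object", "properties": {"city": {"type": "string"}}},
+        calling this method with required=True produces roughly:
+
+            {
+                "$schema": "https://json-schema.org/draft/2020-12/schema",
+                "title": "FunctionCallArraySchema",
+                "type": "array",
+                "items": {
+                    "anyOf": [{
+                        "type": "object",
+                        "additionalProperties": False,
+                        "description": "",
+                        "properties": {
+                            "name": {"const": "get_weather"},
+                            "arguments": {
+                                "type": "object",
+                                "properties": {"city": {"type": "string"}}
+                            }
+                        },
+                        "required": ["name", "arguments"]
+                    }]
+                },
+                "minItems": 1
+            }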
+ """ + if not function_defs: + return None + item_schemas = [] + # Collect all top-level $defs + json_schema_defs = dict[str, dict[str, Any]]() + + # If a specific tool is requested, automatically enforce required=True + if tool_name is not None: + required = True + + # Build item schemas for each function + for f in function_defs: + if not isinstance(f, ChatCompletionToolsParam): + continue + fname = f.function.name + + # Skip tools that do not match the requested tool_name + if tool_name and fname != tool_name: + continue + + # Deep copy parameters to avoid mutating input + fparams = copy.deepcopy(f.function.parameters or {}) + fdescription = f.function.description or "" + # Extract any $defs from this function's parameters + defs = fparams.pop("$defs", {}) + for def_name, def_schema in defs.items(): + # Ensure no duplicate $defs with different schemas + if def_name in json_schema_defs and json_schema_defs[ + def_name] != def_schema: + raise ValueError( + f"Tool definition '{def_name}' has multiple schemas, " + "which is not supported.") + else: + json_schema_defs[def_name] = def_schema + + # Build the function object schema + item_schemas.append({ + "type": "object", + "additionalProperties": False, + "description": fdescription, + "properties": { + "name": { + "const": fname + }, # Ensure "name" matches this function + "arguments": fparams # Function parameters schema + }, + "required": ["name", "arguments"] + }) + + # If required=False and tool_name is None, allow arbitrary values in the array + if not required and not tool_name: + item_schemas.append({}) # {} means any type of value is allowed + + # Build top-level array schema + schema = { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "title": "FunctionCallArraySchema", + "type": "array", + "items": { + "anyOf": item_schemas + } # Each item matches any of the function schemas + } + + # If required=True, array must have at least one element + if required: + schema["minItems"] = 1 + + # If tool_name is specified, restrict array to exactly one element + if tool_name: + schema["maxItems"] = 1 + + # Include top-level $defs if any were collected + if json_schema_defs: + schema["$defs"] = json_schema_defs + + return schema @model_validator(mode="before") @classmethod diff --git a/vllm/entrypoints/openai/tool_parsers/__init__.py b/vllm/entrypoints/openai/tool_parsers/__init__.py index 099e456aa486..32f3c487b47b 100644 --- a/vllm/entrypoints/openai/tool_parsers/__init__.py +++ b/vllm/entrypoints/openai/tool_parsers/__init__.py @@ -20,6 +20,7 @@ from .qwen3coder_tool_parser import Qwen3CoderToolParser from .step3_tool_parser import Step3ToolParser from .xlam_tool_parser import xLAMToolParser +from .generic_tool_parser import GenericToolParser __all__ = [ "ToolParser", @@ -42,4 +43,5 @@ "Glm4MoeModelToolParser", "Qwen3CoderToolParser", "Step3ToolParser", + "GenericToolParser", ] diff --git a/vllm/entrypoints/openai/tool_parsers/generic_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/generic_tool_parser.py new file mode 100644 index 000000000000..0bb20ec1a261 --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/generic_tool_parser.py @@ -0,0 +1,74 @@ +from typing import Union + +import json + +from vllm.entrypoints.chat_utils import random_tool_call_id +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaFunctionCall, DeltaMessage, + ExtractedToolCallInformation, + FunctionCall, ToolCall) +from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( + ToolParser, ToolParserManager) 
+from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import (
+    MistralToolCall)
+from vllm.logger import init_logger
+from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
+
+logger = init_logger(__name__)
+
+
+@ToolParserManager.register_module("generic")
+class GenericToolParser(ToolParser):
+    """
+    Generic tool parser for models whose output is a plain JSON array of
+    {"name": ..., "arguments": ...} objects, as enforced by the guided
+    schema built in ChatCompletionRequest. Array items that are not tool
+    call objects are returned as regular content.
+    """
+
+    def __init__(self, tokenizer: AnyTokenizer):
+        super().__init__(tokenizer)
+        # Mistral models require Mistral-style tool call ids
+        self.tool_call_class = MistralToolCall if isinstance(
+            tokenizer, MistralTokenizer) else ToolCall
+
+    def extract_tool_calls(
+            self, model_output: str,
+            request: ChatCompletionRequest) -> ExtractedToolCallInformation:
+        logger.debug("Extracting tool calls from model output: %s",
+                     model_output)
+        try:
+            # The model output is expected to be a JSON array of tool calls
+            # and/or plain values
+            function_calls = json.loads(model_output)
+            if not isinstance(function_calls, list):
+                function_calls = [function_calls]
+            tool_calls = []
+            content = ""
+            for f in function_calls:
+                if isinstance(f, dict) and f.get("name"):
+                    tool_calls.append(
+                        self.tool_call_class(function=FunctionCall(
+                            name=f.get("name"),
+                            arguments=json.dumps(f.get("arguments", {}),
+                                                 ensure_ascii=False),
+                        )))
+                elif isinstance(f, str):
+                    content += f
+                else:
+                    content += json.dumps(f, ensure_ascii=False)
+            return ExtractedToolCallInformation(
+                tools_called=len(tool_calls) > 0,
+                tool_calls=tool_calls,
+                content=content if content else None)
+
+        except Exception:
+            logger.exception("Error in extracting tool call from response.")
+            return ExtractedToolCallInformation(tools_called=False,
+                                                tool_calls=[],
+                                                content=model_output)
+
+    def extract_tool_calls_streaming(
+        self,
+        previous_text: str,
+        current_text: str,
+        delta_text: str,
+        previous_token_ids: Sequence[int],
+        current_token_ids: Sequence[int],
+        delta_token_ids: Sequence[int],
+        request: ChatCompletionRequest,
+    ) -> Union[DeltaMessage, None]:
+        # Streaming extraction of tool calls (DeltaToolCall deltas) is not
+        # implemented yet; the raw delta text is streamed back as content.
+        return DeltaMessage(content=delta_text)
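
Illustrative usage note (not part of the diff): since the parser registers under the name "generic", it would be selected through vLLM's existing --enable-auto-tool-choice / --tool-call-parser flags (--tool-call-parser generic). A minimal sketch of the expected extract_tool_calls behaviour, assuming a hypothetical "get_weather" tool and omitting the tokenizer/request wiring:

    model_output = '[{"name": "get_weather", "arguments": {"city": "Paris"}}, "Done."]'
    # parser = GenericToolParser(tokenizer)   # tokenizer: any loaded AnyTokenizer
    # info = parser.extract_tool_calls(model_output, request)
    # info.tools_called                       -> True
    # info.tool_calls[0].function.name        -> "get_weather"
    # info.tool_calls[0].function.arguments   -> '{"city": "Paris"}'
    # info.content                            -> "Done."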