fastdeploy/engine/request.py (2 changes: 1 addition & 1 deletion)
@@ -71,7 +71,7 @@ def __init__(
         guided_grammar: Optional[Any] = None,
         structural_tag: Optional[Any] = None,
         guided_json_object: Optional[bool] = None,
-        enable_thinking: Optional[bool] = True,
+        enable_thinking: Optional[bool] = False,
         trace_carrier: dict = dict(),
         dp_rank: Optional[int] = None,
         chat_template: Optional[str] = None,
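Note: with the default flipped to False, reasoning ("thinking") output becomes opt-in per request. A minimal sketch of how a client might opt back in through chat_template_kwargs, which the processing pipeline presumably resolves downstream after this PR; the base URL, API key, and model name are illustrative assumptions, not values from the diff:

    import openai

    client = openai.OpenAI(base_url="http://127.0.0.1:8000/v1", api_key="EMPTY")

    # extra_body forwards non-standard fields to the server;
    # chat_template_kwargs["enable_thinking"] opts this request back in.
    resp = client.chat.completions.create(
        model="ernie-x1",  # assumed deployment name
        messages=[{"role": "user", "content": "What is 17 * 24?"}],
        extra_body={"chat_template_kwargs": {"enable_thinking": True}},
    )
    print(resp.choices[0].message.content)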
fastdeploy/entrypoints/openai/response_processors.py (6 changes: 1 addition & 5 deletions)
@@ -67,13 +67,12 @@ def accumulate_token_ids(self, request_output):
         else:
             self._multipart_buffer.append({"decode_type": decode_type, "request_output": request_output})
 
-    async def process_response_chat(self, request_outputs, stream, enable_thinking, include_stop_str_in_output):
+    async def process_response_chat(self, request_outputs, stream, include_stop_str_in_output):
         """
         Process a list of responses into a generator that yields each processed response as it's generated.
         Args:
             request_outputs: The list of outputs to be processed.
             stream: Whether or not to stream the output.
-            enable_thinking: Whether or not to show thinking messages.
             include_stop_str_in_output: Whether or not to include stop strings in the output.
         """
         for request_output in request_outputs:
@@ -82,7 +81,6 @@ async def process_response_chat(self, request_outputs, stream, enable_thinking,
                 yield self.data_processor.process_response_dict(
                     response_dict=request_output,
                     stream=stream,
-                    enable_thinking=enable_thinking,
                     include_stop_str_in_output=include_stop_str_in_output,
                 )
             elif stream:
@@ -108,7 +106,6 @@ async def process_response_chat(self, request_outputs, stream, enable_thinking,
                     self.data_processor.process_response_dict(
                         response_dict=request_output,
                         stream=stream,
-                        enable_thinking=enable_thinking,
                         include_stop_str_in_output=include_stop_str_in_output,
                     )
                     text = {"type": "text", "text": request_output["outputs"]["text"]}
@@ -128,7 +125,6 @@ async def process_response_chat(self, request_outputs, stream, enable_thinking,
                         self.data_processor.process_response_dict(
                             response_dict=part["request_output"],
                             stream=False,
-                            enable_thinking=enable_thinking,
                             include_stop_str_in_output=include_stop_str_in_output,
                         )
                         text = {"type": "text", "text": part["request_output"]["outputs"]["text"]}
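The signature change means callers no longer thread enable_thinking through; it is resolved from the request before outputs reach this processor. A small sketch of consuming the updated async generator (the processor instance and the request_outputs list are assumed to come from the serving layer):

    async def collect_responses(response_processor, request_outputs):
        # Updated call: only stream and include_stop_str_in_output remain.
        generator = response_processor.process_response_chat(
            request_outputs,
            stream=False,
            include_stop_str_in_output=False,
        )
        # Drain the async generator into a list of processed response dicts.
        return [data async for data in generator]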
fastdeploy/entrypoints/openai/serving_chat.py (9 changes: 0 additions & 9 deletions)
@@ -187,10 +187,6 @@ async def chat_completion_stream_generator(

         max_streaming_response_tokens = max(1, max_streaming_response_tokens)
 
-        enable_thinking = request.chat_template_kwargs.get("enable_thinking") if request.chat_template_kwargs else None
-        if enable_thinking is None:
-            enable_thinking = request.metadata.get("enable_thinking") if request.metadata else None
-
         include_stop_str_in_output = request.include_stop_str_in_output
 
         stream_options = request.stream_options
@@ -242,7 +238,6 @@
             generator = response_processor.process_response_chat(
                 response,
                 stream=True,
-                enable_thinking=enable_thinking,
                 include_stop_str_in_output=include_stop_str_in_output,
             )
 
@@ -418,9 +413,6 @@ async def chat_completion_full_generator(
         """
         created_time = int(time.time())
         final_res = None
-        enable_thinking = request.chat_template_kwargs.get("enable_thinking") if request.chat_template_kwargs else None
-        if enable_thinking is None:
-            enable_thinking = request.metadata.get("enable_thinking") if request.metadata else None
 
         include_stop_str_in_output = request.include_stop_str_in_output
         try:
@@ -464,7 +456,6 @@
             generator = response_processor.process_response_chat(
                 response,
                 stream=False,
-                enable_thinking=enable_thinking,
                 include_stop_str_in_output=include_stop_str_in_output,
             )
             async for data in generator:
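Both generators previously resolved enable_thinking inline, preferring chat_template_kwargs and falling back to the legacy metadata field. For reference, the removed resolution is equivalent to the helper below; after this change that decision presumably lives in the data processor rather than the serving layer:

    from typing import Optional

    def resolve_enable_thinking(request) -> Optional[bool]:
        # chat_template_kwargs takes precedence over the older metadata field.
        enable_thinking = None
        if request.chat_template_kwargs:
            enable_thinking = request.chat_template_kwargs.get("enable_thinking")
        if enable_thinking is None and request.metadata:
            enable_thinking = request.metadata.get("enable_thinking")
        return enable_thinking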
fastdeploy/entrypoints/openai/tool_parsers/ernie_x1_tool_parser.py (172 changes: 24 additions & 148 deletions)
@@ -16,18 +16,10 @@

 import json
 import re
-import uuid
 from collections.abc import Sequence
 from typing import Union
 
-import partial_json_parser
-
-
-def random_tool_call_id() -> str:
-    """Generate a random tool call ID"""
-    return f"chatcmpl-tool-{str(uuid.uuid4().hex)}"
-
-
+from fastdeploy.entrypoints.chat_utils import random_tool_call_id
 from fastdeploy.entrypoints.openai.protocol import (
     ChatCompletionRequest,
     DeltaFunctionCall,
@@ -63,12 +55,12 @@ def __init__(self, tokenizer):
         self.tool_call_start_token: str = "<tool_call>"
         self.tool_call_end_token: str = "</tool_call>"
 
+        self.tool_call_regex = re.compile(r"<tool_call>(.*?)</tool_call>|<tool_call>(.*)", re.DOTALL)
+
         self.tool_call_start_token_id = self.vocab.get(self.tool_call_start_token)
         self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
         if self.tool_call_start_token_id is None or self.tool_call_end_token_id is None:
-            raise RuntimeError(
-                "Hermes 2 Pro Tool parser could not locate tool call start/end " "tokens in the tokenizer!"
-            )
+            raise RuntimeError("Ernie x1 Tool parser could not locate tool call start/end tokens in the tokenizer!")
 
         if not self.model_tokenizer:
             raise ValueError(
@@ -88,143 +80,27 @@ def extract_tool_calls(self, model_output: str, request: ChatCompletionRequest)
         """
 
         try:
-            tool_calls = []
-
-            # Check for invalid <response> tags before tool calls
-            if re.search(r"<response>[\s\S]*?</response>\s*(?=<tool_call>)", model_output):
-                data_processor_logger.error("Invalid format: <response> tags found before <tool_call>")
-                return ExtractedToolCallInformation(tools_called=False, content=model_output)
-
-            function_call_arr = []
-            remaining_text = model_output
-
-            while True:
-                # Find the next <tool_call>
-                tool_call_pos = remaining_text.find("<tool_call>")
-                if tool_call_pos == -1:
-                    break
-
-                # Extract content after <tool_call>
-                tool_content_start = tool_call_pos + len("<tool_call>")
-                tool_content_end = remaining_text.find("</tool_call>", tool_content_start)
-
-                tool_json = ""
-                if tool_content_end == -1:
-                    # Processing unclosed tool_call block (truncated case)
-                    tool_json = remaining_text[tool_content_start:].strip()
-                    remaining_text = ""  # No more content to process
-                else:
-                    # Processing closed </tool_call> block
-                    tool_json = remaining_text[tool_content_start:tool_content_end].strip()
-                    remaining_text = remaining_text[tool_content_end + len("</tool_call>") :]
-
-                if not tool_json:
-                    continue
-
-                # Process tool_json
-                tool_json = tool_json.strip()
-                if not tool_json.startswith("{"):
-                    tool_json = "{" + tool_json
-                if not tool_json.endswith("}"):
-                    tool_json = tool_json + "}"
-
-                try:
-                    # Parsing strategy: First try standard json.loads
-                    try:
-                        tool_data = json.loads(tool_json)
-
-                        if isinstance(tool_data, dict) and "name" in tool_data and "arguments" in tool_data:
-                            function_call_arr.append(
-                                {
-                                    "name": tool_data["name"],
-                                    "arguments": tool_data["arguments"],
-                                    "_is_complete": True,  # Mark as complete
-                                }
-                            )
-                            continue
-                    except json.JSONDecodeError:
-                        pass
-
-                    # Try partial_json_parser when standard parsing fails
-                    from partial_json_parser.core.options import Allow
-
-                    try:
-                        tool_data = {}
-                        flags = Allow.ALL & ~Allow.STR
-
-                        # Parse the name field
-                        name_match = re.search(r'"name"\s*:\s*"([^"]*)"', tool_json)
-                        if name_match:
-                            tool_data["name"] = name_match.group(1)
-
-                        # Parse the arguments field
-                        args_match = re.search(r'"arguments"\s*:\s*(\{.*)', tool_json)
-                        if args_match:
-                            try:
-                                tool_data["arguments"] = partial_json_parser.loads(args_match.group(1), flags=flags)
-                            except:
-                                tool_data["arguments"] = None
-
-                        if isinstance(tool_data, dict):
-                            function_call_arr.append(
-                                {
-                                    "name": tool_data.get("name", ""),
-                                    "arguments": tool_data.get("arguments", {}),
-                                    "_is_partial": True,  # Mark as partial
-                                }
-                            )
-                    except Exception as e:
-                        data_processor_logger.debug(f"Failed to parse tool call: {str(e)}")
-                        continue
-                except Exception as e:
-                    data_processor_logger.debug(f"Failed to parse tool call: {str(e)}")
-                    continue
-
-            if not function_call_arr:
-                data_processor_logger.error("No valid tool calls found")
-                return ExtractedToolCallInformation(tools_called=False, content=model_output)
-
-            tool_calls = []
-            all_complete = True  # Initialize as all complete
-
-            for tool_call in function_call_arr:
-                # Set flags
-                is_complete = tool_call.get("_is_complete", False)
-                is_partial = tool_call.get("_is_partial", False)
-
-                # If any tool call is incomplete or partial, mark all_complete as False
-                if not is_complete or is_partial:
-                    all_complete = False
-
-                # Process arguments
-                tool_args = tool_call.get("arguments", {})
-                if not isinstance(tool_args, dict):
-                    tool_args = {}
-
-                try:
-                    args_str = json.dumps(tool_args, ensure_ascii=False) if tool_args else "{}"
-                except:
-                    args_str = "{}"
-
-                tool_calls.append(
-                    ToolCall(
-                        type="function",
-                        id=random_tool_call_id(),
-                        function=FunctionCall(
-                            name=tool_call.get("name", ""),
-                            arguments=args_str,
-                        ),
-                    )
-                )
-
-            # Only return tools_called=True if all tool calls are complete
-            return ExtractedToolCallInformation(
-                tools_called=all_complete, tool_calls=tool_calls if tool_calls else None, content=""
-            )
-
-        except Exception as e:
-            data_processor_logger.error(f"Error in extracting tool call from response: {str(e)}")
-            return ExtractedToolCallInformation(tools_called=False, tool_calls=None, content=model_output)
+            if self.tool_call_start_token not in model_output:
+                return ExtractedToolCallInformation(tools_called=False, tool_calls=[], content=model_output)
+            function_call_tuples = self.tool_call_regex.findall(model_output)
+
+            raw_function_calls = [json.loads(match[0] if match[0] else match[1]) for match in function_call_tuples]
+
+            tool_calls = [
+                ToolCall(
+                    type="function",
+                    function=FunctionCall(
+                        name=function_call["name"],
+                        # function call args are JSON but as a string
+                        arguments=json.dumps(function_call["arguments"], ensure_ascii=False),
+                    ),
+                )
+                for function_call in raw_function_calls
+            ]
+            return ExtractedToolCallInformation(tools_called=True, tool_calls=tool_calls, content="")
+        except Exception:
+            data_processor_logger.error("Error in extracting tool call from response.")
+            return ExtractedToolCallInformation(tools_called=False, tool_calls=[], content=model_output)
 
     def extract_tool_calls_streaming(
         self,
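The rewritten extractor replaces the ~140-line manual scanner and partial-JSON repair with a single regex whose two alternatives capture either a closed <tool_call>...</tool_call> block (group 1) or a truncated trailing <tool_call> (group 2). A standalone demo of that behavior; the regex is copied from the diff, while the sample outputs are made up for illustration:

    import json
    import re

    tool_call_regex = re.compile(r"<tool_call>(.*?)</tool_call>|<tool_call>(.*)", re.DOTALL)

    closed = '<tool_call>{"name": "get_weather", "arguments": {"city": "Boston"}}</tool_call>'
    truncated = '<tool_call>{"name": "get_weather", "arguments": {"ci'

    for sample in (closed, truncated):
        for match in tool_call_regex.findall(sample):
            raw = match[0] if match[0] else match[1]  # group 1 if closed, group 2 if truncated
            try:
                call = json.loads(raw)
                print(call["name"], "->", json.dumps(call["arguments"]))
            except json.JSONDecodeError:
                # Unlike the old parser, truncated JSON is no longer repaired; a
                # failure here makes extract_tool_calls report tools_called=False.
                print("truncated tool call: not parseable")

One behavioral consequence worth noting: the old parser emitted partially parsed calls with tools_called=False, whereas the new one fails the whole extraction if any captured block is not valid JSON, returning the raw model output as content.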