diff --git a/fastdeploy/engine/request.py b/fastdeploy/engine/request.py index 3906cd29b5..f24a9b463b 100644 --- a/fastdeploy/engine/request.py +++ b/fastdeploy/engine/request.py @@ -71,7 +71,7 @@ def __init__( guided_grammar: Optional[Any] = None, structural_tag: Optional[Any] = None, guided_json_object: Optional[bool] = None, - enable_thinking: Optional[bool] = True, + enable_thinking: Optional[bool] = False, trace_carrier: dict = dict(), dp_rank: Optional[int] = None, chat_template: Optional[str] = None, diff --git a/fastdeploy/entrypoints/openai/response_processors.py b/fastdeploy/entrypoints/openai/response_processors.py index e51147899e..0640ec9985 100644 --- a/fastdeploy/entrypoints/openai/response_processors.py +++ b/fastdeploy/entrypoints/openai/response_processors.py @@ -67,13 +67,12 @@ def accumulate_token_ids(self, request_output): else: self._multipart_buffer.append({"decode_type": decode_type, "request_output": request_output}) - async def process_response_chat(self, request_outputs, stream, enable_thinking, include_stop_str_in_output): + async def process_response_chat(self, request_outputs, stream, include_stop_str_in_output): """ Process a list of responses into a generator that yields each processed response as it's generated. Args: request_outputs: The list of outputs to be processed. stream: Whether or not to stream the output. - enable_thinking: Whether or not to show thinking messages. include_stop_str_in_output: Whether or not to include stop strings in the output. """ for request_output in request_outputs: @@ -82,7 +81,6 @@ async def process_response_chat(self, request_outputs, stream, enable_thinking, yield self.data_processor.process_response_dict( response_dict=request_output, stream=stream, - enable_thinking=enable_thinking, include_stop_str_in_output=include_stop_str_in_output, ) elif stream: @@ -108,7 +106,6 @@ async def process_response_chat(self, request_outputs, stream, enable_thinking, self.data_processor.process_response_dict( response_dict=request_output, stream=stream, - enable_thinking=enable_thinking, include_stop_str_in_output=include_stop_str_in_output, ) text = {"type": "text", "text": request_output["outputs"]["text"]} @@ -128,7 +125,6 @@ async def process_response_chat(self, request_outputs, stream, enable_thinking, self.data_processor.process_response_dict( response_dict=part["request_output"], stream=False, - enable_thinking=enable_thinking, include_stop_str_in_output=include_stop_str_in_output, ) text = {"type": "text", "text": part["request_output"]["outputs"]["text"]} diff --git a/fastdeploy/entrypoints/openai/serving_chat.py b/fastdeploy/entrypoints/openai/serving_chat.py index 52cd556916..36f5a97c53 100644 --- a/fastdeploy/entrypoints/openai/serving_chat.py +++ b/fastdeploy/entrypoints/openai/serving_chat.py @@ -187,10 +187,6 @@ async def chat_completion_stream_generator( max_streaming_response_tokens = max(1, max_streaming_response_tokens) - enable_thinking = request.chat_template_kwargs.get("enable_thinking") if request.chat_template_kwargs else None - if enable_thinking is None: - enable_thinking = request.metadata.get("enable_thinking") if request.metadata else None - include_stop_str_in_output = request.include_stop_str_in_output stream_options = request.stream_options @@ -242,7 +238,6 @@ async def chat_completion_stream_generator( generator = response_processor.process_response_chat( response, stream=True, - enable_thinking=enable_thinking, include_stop_str_in_output=include_stop_str_in_output, ) @@ -418,9 +413,6 @@ async def 
chat_completion_full_generator(
         """
         created_time = int(time.time())
         final_res = None
-        enable_thinking = request.chat_template_kwargs.get("enable_thinking") if request.chat_template_kwargs else None
-        if enable_thinking is None:
-            enable_thinking = request.metadata.get("enable_thinking") if request.metadata else None
         include_stop_str_in_output = request.include_stop_str_in_output

         try:
@@ -464,7 +456,6 @@ async def chat_completion_full_generator(
                 generator = response_processor.process_response_chat(
                     response,
                     stream=False,
-                    enable_thinking=enable_thinking,
                     include_stop_str_in_output=include_stop_str_in_output,
                 )
                 async for data in generator:
diff --git a/fastdeploy/entrypoints/openai/tool_parsers/ernie_x1_tool_parser.py b/fastdeploy/entrypoints/openai/tool_parsers/ernie_x1_tool_parser.py
index 14a784f174..ec3ff9ce14 100644
--- a/fastdeploy/entrypoints/openai/tool_parsers/ernie_x1_tool_parser.py
+++ b/fastdeploy/entrypoints/openai/tool_parsers/ernie_x1_tool_parser.py
@@ -16,18 +16,10 @@
 import json
 import re
-import uuid
 from collections.abc import Sequence
 from typing import Union

-import partial_json_parser
-
-
-def random_tool_call_id() -> str:
-    """Generate a random tool call ID"""
-    return f"chatcmpl-tool-{str(uuid.uuid4().hex)}"
-
-
+from fastdeploy.entrypoints.chat_utils import random_tool_call_id
 from fastdeploy.entrypoints.openai.protocol import (
     ChatCompletionRequest,
     DeltaFunctionCall,
@@ -63,12 +55,12 @@ def __init__(self, tokenizer):
         self.tool_call_start_token: str = "<tool_call>"
         self.tool_call_end_token: str = "</tool_call>"

+        self.tool_call_regex = re.compile(r"<tool_call>(.*?)</tool_call>|<tool_call>(.*)", re.DOTALL)
+
         self.tool_call_start_token_id = self.vocab.get(self.tool_call_start_token)
         self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)

         if self.tool_call_start_token_id is None or self.tool_call_end_token_id is None:
-            raise RuntimeError(
-                "Hermes 2 Pro Tool parser could not locate tool call start/end " "tokens in the tokenizer!"
- ) + raise RuntimeError("Ernie x1 Tool parser could not locate tool call start/end tokens in the tokenizer!") if not self.model_tokenizer: raise ValueError( @@ -88,143 +80,27 @@ def extract_tool_calls(self, model_output: str, request: ChatCompletionRequest) """ try: - tool_calls = [] - - # Check for invalid tags before tool calls - if re.search(r"[\s\S]*?\s*(?=)", model_output): - data_processor_logger.error("Invalid format: tags found before ") - return ExtractedToolCallInformation(tools_called=False, content=model_output) - - function_call_arr = [] - remaining_text = model_output - - while True: - # Find the next - tool_call_pos = remaining_text.find("") - if tool_call_pos == -1: - break - - # Extract content after - tool_content_start = tool_call_pos + len("") - tool_content_end = remaining_text.find("", tool_content_start) - - tool_json = "" - if tool_content_end == -1: - # Processing unclosed tool_call block (truncated case) - tool_json = remaining_text[tool_content_start:].strip() - remaining_text = "" # No more content to process - else: - # Processing closed block - tool_json = remaining_text[tool_content_start:tool_content_end].strip() - remaining_text = remaining_text[tool_content_end + len("") :] - - if not tool_json: - continue - - # Process tool_json - tool_json = tool_json.strip() - if not tool_json.startswith("{"): - tool_json = "{" + tool_json - if not tool_json.endswith("}"): - tool_json = tool_json + "}" - - try: - # Parsing strategy: First try standard json.loads - try: - tool_data = json.loads(tool_json) - - if isinstance(tool_data, dict) and "name" in tool_data and "arguments" in tool_data: - function_call_arr.append( - { - "name": tool_data["name"], - "arguments": tool_data["arguments"], - "_is_complete": True, # Mark as complete - } - ) - continue - except json.JSONDecodeError: - pass - - # Try partial_json_parser when standard parsing fails - from partial_json_parser.core.options import Allow - - try: - tool_data = {} - flags = Allow.ALL & ~Allow.STR - - # Parse the name field - name_match = re.search(r'"name"\s*:\s*"([^"]*)"', tool_json) - if name_match: - tool_data["name"] = name_match.group(1) - - # Parse the arguments field - args_match = re.search(r'"arguments"\s*:\s*(\{.*)', tool_json) - if args_match: - try: - tool_data["arguments"] = partial_json_parser.loads(args_match.group(1), flags=flags) - except: - tool_data["arguments"] = None - - if isinstance(tool_data, dict): - function_call_arr.append( - { - "name": tool_data.get("name", ""), - "arguments": tool_data.get("arguments", {}), - "_is_partial": True, # Mark as partial - } - ) - except Exception as e: - data_processor_logger.debug(f"Failed to parse tool call: {str(e)}") - continue - except Exception as e: - data_processor_logger.debug(f"Failed to parse tool call: {str(e)}") - continue - - if not function_call_arr: - data_processor_logger.error("No valid tool calls found") - return ExtractedToolCallInformation(tools_called=False, content=model_output) - - tool_calls = [] - all_complete = True # Initialize as all complete - - for tool_call in function_call_arr: - # Set flags - is_complete = tool_call.get("_is_complete", False) - is_partial = tool_call.get("_is_partial", False) - - # If any tool call is incomplete or partial, mark all_complete as False - if not is_complete or is_partial: - all_complete = False - - # Process arguments - tool_args = tool_call.get("arguments", {}) - if not isinstance(tool_args, dict): - tool_args = {} - - try: - args_str = json.dumps(tool_args, ensure_ascii=False) if 
tool_args else "{}" - except: - args_str = "{}" - - tool_calls.append( - ToolCall( - type="function", - id=random_tool_call_id(), - function=FunctionCall( - name=tool_call.get("name", ""), - arguments=args_str, - ), - ) + if self.tool_call_start_token not in model_output: + return ExtractedToolCallInformation(tools_called=False, tool_calls=[], content=model_output) + function_call_tuples = self.tool_call_regex.findall(model_output) + + raw_function_calls = [json.loads(match[0] if match[0] else match[1]) for match in function_call_tuples] + + tool_calls = [ + ToolCall( + type="function", + function=FunctionCall( + name=function_call["name"], + # function call args are JSON but as a string + arguments=json.dumps(function_call["arguments"], ensure_ascii=False), + ), ) - - # Only return tools_called=True if all tool calls are complete - return ExtractedToolCallInformation( - tools_called=all_complete, tool_calls=tool_calls if tool_calls else None, content="" - ) - - except Exception as e: - data_processor_logger.error(f"Error in extracting tool call from response: {str(e)}") - return ExtractedToolCallInformation(tools_called=False, tool_calls=None, content=model_output) + for function_call in raw_function_calls + ] + return ExtractedToolCallInformation(tools_called=True, tool_calls=tool_calls, content="") + except Exception: + data_processor_logger.error("Error in extracting tool call from response.") + return ExtractedToolCallInformation(tools_called=False, tool_calls=[], content=model_output) def extract_tool_calls_streaming( self, diff --git a/fastdeploy/input/ernie4_5_processor.py b/fastdeploy/input/ernie4_5_processor.py index 8d2463a088..b75d2c4fbe 100644 --- a/fastdeploy/input/ernie4_5_processor.py +++ b/fastdeploy/input/ernie4_5_processor.py @@ -60,6 +60,7 @@ def __init__(self, model_name_or_path, reasoning_parser_obj=None, tool_parser_ob self.decode_status = dict() self.tool_parser_dict = dict() self.thinking_parser_dict = dict() + self.model_status_dict = dict() self._load_tokenizer() data_processor_logger.info( f"tokenizer information: bos_token is {self.tokenizer.bos_token} \ @@ -153,6 +154,10 @@ def process_request(self, request, max_model_len=None, **kwargs): request.set("top_p", _SAMPLING_EPS) if self.reasoning_parser and self.reasoning_parser.__class__.__name__ == "ErnieX1ReasoningParser": request.enable_thinking = True + if self.reasoning_parser: + model_status = self.reasoning_parser.get_model_status(request.prompt_token_ids) + self.model_status_dict[request.request_id] = model_status + request.enable_thinking = model_status == "think_start" data_processor_logger.info(f"Processed request: {request}") return request @@ -231,7 +236,10 @@ def process_request_dict(self, request, max_model_len=None): request["top_p"] = _SAMPLING_EPS if self.reasoning_parser and self.reasoning_parser.__class__.__name__ == "ErnieX1ReasoningParser": request["enable_thinking"] = True - + if self.reasoning_parser: + model_status = self.reasoning_parser.get_model_status(request["prompt_token_ids"]) + self.model_status_dict[request["request_id"]] = model_status + request["enable_thinking"] = model_status == "think_start" data_processor_logger.info(f"Processed request dict: {request}") return request @@ -253,7 +261,11 @@ def process_response(self, response_dict, **kwargs): token_ids = token_ids[:-1] full_text = self.tokenizer.decode(token_ids) if self.reasoning_parser: - reasoning_content, text = self.reasoning_parser.extract_reasoning_content(full_text, response_dict) + reasoning_content, text = 
self.reasoning_parser.extract_reasoning_content( + full_text, + response_dict, + self.model_status_dict.get(req_id), + ) response_dict.outputs.text = text response_dict.outputs.reasoning_content = reasoning_content else: @@ -267,6 +279,8 @@ def process_response(self, response_dict, **kwargs): data_processor_logger.info(f"req_id:{req_id}, token_ids: {token_ids}") if response_dict.outputs.text == "" and response_dict.outputs.reasoning_content == "": return None + if req_id in self.model_status_dict: + del self.model_status_dict[req_id] return response_dict def process_response_dict(self, response_dict, stream, **kwargs): @@ -294,7 +308,6 @@ def process_response_dict_normal(self, response_dict, **kwargs): Returns: Dict: response contain text fields """ - enable_thinking = kwargs.get("enable_thinking") token_ids = response_dict["outputs"]["token_ids"] is_end = response_dict["finished"] req_id = response_dict["request_id"] @@ -304,14 +317,15 @@ def process_response_dict_normal(self, response_dict, **kwargs): delta_text, _, previous_texts = self.ids2tokens(token_ids, req_id) if is_end: full_text = previous_texts + delta_text - if self.reasoning_parser and ( - enable_thinking or self.reasoning_parser.__class__.__name__ == "ErnieX1ReasoningParser" - ): - reasoning_content, text = self.reasoning_parser.extract_reasoning_content(full_text, response_dict) + response_dict["outputs"]["text"] = full_text + if self.reasoning_parser: + reasoning_content, text = self.reasoning_parser.extract_reasoning_content( + full_text, + response_dict, + self.model_status_dict.get(req_id), + ) response_dict["outputs"]["text"] = text response_dict["outputs"]["reasoning_content"] = reasoning_content - else: - response_dict["outputs"]["text"] = full_text if self.tool_parser_obj: tool_parser = self.tool_parser_obj(self.tokenizer) tool_call_info = tool_parser.extract_tool_calls(full_text, response_dict) @@ -321,6 +335,8 @@ def process_response_dict_normal(self, response_dict, **kwargs): response_dict["outputs"]["raw_prediction"] = full_text data_processor_logger.info(f"req_id:{req_id}, decode_status: {self.decode_status[req_id]}") del self.decode_status[req_id] + if req_id in self.model_status_dict: + del self.model_status_dict[req_id] return response_dict def process_response_dict_streaming(self, response_dict, **kwargs): @@ -333,7 +349,6 @@ def process_response_dict_streaming(self, response_dict, **kwargs): Returns: Dict: response contain text fields """ - enable_thinking = kwargs.get("enable_thinking") is_end = response_dict["finished"] req_id = response_dict["request_id"] token_ids = response_dict["outputs"]["token_ids"] @@ -343,9 +358,7 @@ def process_response_dict_streaming(self, response_dict, **kwargs): token_ids = token_ids[:-1] delta_text, previous_token_ids, previous_texts = self.ids2tokens(token_ids, req_id) response_dict["outputs"]["raw_prediction"] = delta_text - if self.reasoning_parser and ( - enable_thinking or self.reasoning_parser.__class__.__name__ == "ErnieX1ReasoningParser" - ): + if self.reasoning_parser: reasoning_delta_message = self.reasoning_parser.extract_reasoning_content_streaming( previous_texts, previous_texts + delta_text, @@ -353,6 +366,7 @@ def process_response_dict_streaming(self, response_dict, **kwargs): previous_token_ids, previous_token_ids + token_ids, token_ids, + self.model_status_dict.get(req_id), ) response_dict["outputs"]["delta_message"] = reasoning_delta_message if self.tool_parser_obj: @@ -376,6 +390,8 @@ def process_response_dict_streaming(self, response_dict, **kwargs): del 
self.decode_status[req_id] if req_id in self.tool_parser_dict: del self.tool_parser_dict[req_id] + if req_id in self.model_status_dict: + del self.model_status_dict[req_id] return response_dict def messages2ids(self, request_or_messages, **kwargs): diff --git a/fastdeploy/input/ernie4_5_vl_processor/ernie4_5_vl_processor.py b/fastdeploy/input/ernie4_5_vl_processor/ernie4_5_vl_processor.py index 9251dd9d95..2ce5f9dde7 100644 --- a/fastdeploy/input/ernie4_5_vl_processor/ernie4_5_vl_processor.py +++ b/fastdeploy/input/ernie4_5_vl_processor/ernie4_5_vl_processor.py @@ -54,6 +54,7 @@ def __init__( self.tool_parser_dict = dict() self.decode_status = dict() + self.model_status_dict = dict() self._load_tokenizer() # Generation config @@ -258,6 +259,11 @@ def process_request_dict(self, request, max_model_len=None): request["reasoning_max_tokens"] = max(int(request["max_tokens"] * 0.8), 1) data_processor_logger.info(f"Processed request {request}") + if self.reasoning_parser: + model_status = self.reasoning_parser.get_model_status(request["prompt_token_ids"]) + self.model_status_dict[request["request_id"]] = model_status + request["enable_thinking"] = model_status == "think_start" + return request def append_completion_tokens(self, multimodal_inputs, completion_token_ids): @@ -290,21 +296,3 @@ def pack_outputs(self, outs): outs["position_ids"] = np.array(outs["position_ids"], dtype=np.int64) return outs - - def process_response_dict(self, response_dict, stream, **kwargs): - """ - Preprocess the response - - Args: - response_dict (Dict): response for engine, contain ids fields - - Returns: - Dict: response contain text fields - """ - enable_thinking = kwargs.pop("enable_thinking", True) - if enable_thinking is None: - enable_thinking = True - if stream: - return self.process_response_dict_streaming(response_dict, enable_thinking=enable_thinking, **kwargs) - else: - return self.process_response_dict_normal(response_dict, enable_thinking=enable_thinking, **kwargs) diff --git a/fastdeploy/input/text_processor.py b/fastdeploy/input/text_processor.py index a29e1b2605..0ea1ea6c33 100644 --- a/fastdeploy/input/text_processor.py +++ b/fastdeploy/input/text_processor.py @@ -175,6 +175,7 @@ def __init__(self, model_name_or_path, reasoning_parser_obj=None, tool_parser_ob self.generation_config = None self.decode_status = dict() + self.model_status_dict = dict() self.tool_parser_dict = dict() self.tokenizer = self._load_tokenizer() data_processor_logger.info( @@ -267,6 +268,10 @@ def process_request(self, request, max_model_len=None, **kwargs): request.set("temperature", 1) if request.get("top_p") < _SAMPLING_EPS: request.set("top_p", _SAMPLING_EPS) + if self.reasoning_parser: + model_status = self.reasoning_parser.get_model_status(request.prompt_token_ids) + self.model_status_dict[request.request_id] = model_status + request.enable_thinking = model_status == "think_start" data_processor_logger.info(f"Processed request: {request}") return request @@ -341,6 +346,10 @@ def process_request_dict(self, request, max_model_len=None, **kwargs): request["temperature"] = 1 if request.get("top_p") < _SAMPLING_EPS: request["top_p"] = _SAMPLING_EPS + if self.reasoning_parser: + model_status = self.reasoning_parser.get_model_status(request["prompt_token_ids"]) + self.model_status_dict[request["request_id"]] = model_status + request["enable_thinking"] = model_status == "think_start" data_processor_logger.info(f"Processed request dict: {request}") return request @@ -364,21 +373,21 @@ def process_response(self, response_dict, 
**kwargs): if token_ids[-1] == self.tokenizer.eos_token_id: token_ids = token_ids[:-1] full_text = self.tokenizer.decode(token_ids) - - # 模型支持思考,并且支持思考 + response_dict.outputs.text = full_text if self.reasoning_parser: - reasoning_content, text = self.reasoning_parser.extract_reasoning_content(full_text, response_dict) + reasoning_content, text = self.reasoning_parser.extract_reasoning_content( + full_text, response_dict, self.model_status_dict[req_id] + ) response_dict.outputs.text = text response_dict.outputs.reasoning_content = reasoning_content - else: - # 模型不支持思考,并且没单独设置enable_thinking为false - response_dict.outputs.text = full_text if self.tool_parser_obj: tool_parser = self.tool_parser_obj(self.tokenizer) tool_call_info = tool_parser.extract_tool_calls(full_text, response_dict) if tool_call_info.tools_called: response_dict.outputs.tool_calls = tool_call_info.tool_calls response_dict.outputs.text = tool_call_info.content + if req_id in self.model_status_dict: + del self.model_status_dict[req_id] data_processor_logger.info(f"req_id:{req_id}, token_ids: {token_ids}") return response_dict @@ -393,7 +402,6 @@ def process_response_dict_normal(self, response_dict, **kwargs): Returns: Dict: response contain text fields """ - enable_thinking = kwargs.get("enable_thinking") token_ids = response_dict["outputs"]["token_ids"] is_end = response_dict["finished"] req_id = response_dict["request_id"] @@ -404,12 +412,13 @@ def process_response_dict_normal(self, response_dict, **kwargs): if is_end: full_text = previous_texts + delta_text response_dict["outputs"]["raw_prediction"] = full_text - if enable_thinking and self.reasoning_parser: - reasoning_content, text = self.reasoning_parser.extract_reasoning_content(full_text, response_dict) + response_dict["outputs"]["text"] = full_text + if self.reasoning_parser: + reasoning_content, text = self.reasoning_parser.extract_reasoning_content( + full_text, response_dict, self.model_status_dict[req_id] + ) response_dict["outputs"]["text"] = text response_dict["outputs"]["reasoning_content"] = reasoning_content - else: - response_dict["outputs"]["text"] = full_text if self.tool_parser_obj: tool_parser = self.tool_parser_obj(self.tokenizer) tool_call_info = tool_parser.extract_tool_calls(full_text, response_dict) @@ -430,7 +439,6 @@ def process_response_dict_streaming(self, response_dict, **kwargs): Returns: Dict: response contain text fields """ - enable_thinking = kwargs.get("enable_thinking") is_end = response_dict["finished"] req_id = response_dict["request_id"] token_ids = response_dict["outputs"]["token_ids"] @@ -440,9 +448,7 @@ def process_response_dict_streaming(self, response_dict, **kwargs): token_ids = token_ids[:-1] delta_text, previous_token_ids, previous_texts = self.ids2tokens(token_ids, req_id) response_dict["outputs"]["raw_prediction"] = delta_text - if self.reasoning_parser and ( - enable_thinking or self.reasoning_parser.__class__.__name__ == "ErnieX1ReasoningParser" - ): + if self.reasoning_parser: reasoning_delta_message = self.reasoning_parser.extract_reasoning_content_streaming( previous_texts, previous_texts + delta_text, @@ -450,6 +456,7 @@ def process_response_dict_streaming(self, response_dict, **kwargs): previous_token_ids, previous_token_ids + token_ids, token_ids, + self.model_status_dict[req_id], ) response_dict["outputs"]["delta_message"] = reasoning_delta_message if self.tool_parser_obj: @@ -473,6 +480,8 @@ def process_response_dict_streaming(self, response_dict, **kwargs): del self.decode_status[req_id] if req_id in 
self.tool_parser_dict:
             del self.tool_parser_dict[req_id]
+        if req_id in self.model_status_dict:
+            del self.model_status_dict[req_id]
         return response_dict

     def process_response_dict(self, response_dict, **kwargs):
@@ -485,16 +494,12 @@ def process_response_dict(self, response_dict, **kwargs):
         Returns:
             Dict: response contain text fields
         """
-        enable_thinking = kwargs.pop("enable_thinking", True)
-        if enable_thinking is None:
-            enable_thinking = True
         stream = kwargs.get("stream", True)
         if stream:
-            return self.process_response_dict_streaming(response_dict, enable_thinking=enable_thinking, **kwargs)
+            return self.process_response_dict_streaming(response_dict, **kwargs)
         else:
             return self.process_response_dict_normal(
                 response_dict=response_dict,
-                enable_thinking=enable_thinking,
                 **kwargs,
             )
diff --git a/fastdeploy/reasoning/ernie_vl_reasoning_parsers.py b/fastdeploy/reasoning/ernie_vl_reasoning_parsers.py
index 5636ee9f5e..5daaa986ce 100644
--- a/fastdeploy/reasoning/ernie_vl_reasoning_parsers.py
+++ b/fastdeploy/reasoning/ernie_vl_reasoning_parsers.py
@@ -35,20 +35,48 @@ class ErnieVLReasoningParser(ReasoningParser):

     def __init__(self, tokenizer):
         super().__init__(tokenizer)
-        self.think_end_token = "</think>"
+        token_definitions = {
+            "think_start_token": "<think>",
+            "think_end_token": "</think>",
+        }

         if not self.model_tokenizer:
-            raise ValueError(
-                "The model tokenizer must be passed to the ReasoningParser " "constructor during construction."
-            )
+            raise ValueError("The model tokenizer must be passed to the ReasoningParser constructor.")
+
+        missing_tokens = []
+        for name, token_value in token_definitions.items():
+            setattr(self, name, token_value)
+            token_id = self.vocab.get(token_value)
+            setattr(self, f"{name}_id", token_id)
+            if token_id is None:
+                missing_tokens.append(f"{name.replace('_', ' ')} token")

-        self.think_end_token_id = self.vocab.get(self.think_end_token)
-        if self.think_end_token_id is None:
-            raise RuntimeError("Ernie VL reasoning parser could not locate think end " "tokens in the tokenizer!")
+        if missing_tokens:
+            raise RuntimeError(
+                f"ernie vl reasoning parser could not find the following token ids in tokenizer vocabulary: {', '.join(missing_tokens)}"
+            )
+        self.token_status_mapping = {
+            self.think_start_token_id: "think_start",
+            self.think_end_token_id: "think_end",
+        }

     def is_reasoning_end(self, input_ids: list[int]) -> bool:
         return self.think_end_token_id in input_ids

+    def find_last_special_token(self, prompt_token_ids: list[int]) -> int:
+        for i in range(len(prompt_token_ids) - 1, -1, -1):
+            if prompt_token_ids[i] in self.token_status_mapping:
+                return prompt_token_ids[i]
+        return -1
+
+    def get_model_status(self, prompt_token_ids: list[int]):
+        special_token_id = self.find_last_special_token(prompt_token_ids)
+
+        if special_token_id == -1:
+            return "think_start"
+
+        return self.token_status_mapping[special_token_id]
+
     def extract_reasoning_content_streaming(
         self,
         previous_text: str,
@@ -57,6 +85,7 @@ def extract_reasoning_content_streaming(
         previous_token_ids: Sequence[int],
         current_token_ids: Sequence[int],
         delta_token_ids: Sequence[int],
+        model_status: str,
     ) -> Union[DeltaMessage, None]:
         """
         Extract reasoning content from a delta message.
@@ -69,18 +98,24 @@
         # Skip single special tokens
         if len(delta_token_ids) == 1 and delta_token_ids[0] == self.think_end_token_id:
             return None
-        if self.think_end_token_id in delta_token_ids:
-            end_index = delta_text.find(self.end_token)
-            reasoning_content = delta_text[:end_index]
-            content = delta_text[end_index + len(self.end_token) :]
-            return DeltaMessage(reasoning_content=reasoning_content, content=content)
-        elif self.think_end_token_id in previous_token_ids:
-            return DeltaMessage(content=delta_text)
+        if model_status == "think_start":
+            if self.think_end_token_id in delta_token_ids:
+                end_index = delta_text.find(self.think_end_token)
+                reasoning_content = delta_text[:end_index]
+                content = delta_text[end_index + len(self.think_end_token) :]
+                return DeltaMessage(reasoning_content=reasoning_content, content=content)
+            elif self.think_end_token_id in previous_token_ids:
+                return DeltaMessage(content=delta_text)
+            else:
+                return DeltaMessage(reasoning_content=delta_text)
         else:
-            return DeltaMessage(reasoning_content=delta_text)
+            return DeltaMessage(content=delta_text)

     def extract_reasoning_content(
-        self, model_output: str, request: ChatCompletionRequest
+        self,
+        model_output: str,
+        request: ChatCompletionRequest,
+        model_status: str,
     ) -> tuple[Optional[str], Optional[str]]:
         """
         Extract reasoning content from the model output.
@@ -94,9 +129,11 @@ def extract_reasoning_content(
         """

         # Check if the model output contains the <think> tokens.
-        if self.think_end_token not in model_output:
+        if model_status == "think_start":
+            if self.think_end_token not in model_output:
+                return model_output, ""
+            reasoning_content, _, content = model_output.partition(self.think_end_token)
+            final_content = content or ""
+            return reasoning_content, final_content
+        else:
             return "", model_output
-        reasoning_content, _, content = model_output.partition(self.think_end_token)
-
-        final_content = content or ""
-        return reasoning_content, final_content
diff --git a/fastdeploy/reasoning/ernie_x1_reasoning_parsers.py b/fastdeploy/reasoning/ernie_x1_reasoning_parsers.py
index 54b72a0eb5..0ab2f26f09 100644
--- a/fastdeploy/reasoning/ernie_x1_reasoning_parsers.py
+++ b/fastdeploy/reasoning/ernie_x1_reasoning_parsers.py
@@ -18,19 +18,55 @@ class ErnieX1ReasoningParser(ReasoningParser):

     def __init__(self, tokenizer):
         super().__init__(tokenizer)
-        self.think_end_token = "</think>"
-        self.response_start_token = "<response>"
-        self.response_end_token = "</response>"
-        self.tool_call_start_token = "<tool_call>"
-        self.tool_call_end_token = "</tool_call>"
+
+        # 定义所有需要检查的token
+        token_definitions = {
+            "think_start_token": "<think>",
+            "think_end_token": "</think>",
+            "response_start_token": "<response>",
+            "response_end_token": "</response>",
+            "tool_call_start_token": "<tool_call>",
+            "tool_call_end_token": "</tool_call>",
+        }

         if not self.model_tokenizer:
             raise ValueError("The model tokenizer must be passed to the ReasoningParser constructor.")

-        self.think_end_token_id = self.vocab.get("</think>")
-        if self.think_end_token_id is None:
-            raise RuntimeError("Could not find think end token id in tokenizer vocabulary")
-        self.tool_call_start_token_id = self.vocab.get("<tool_call>")
+        missing_tokens = []
+        for name, token_value in token_definitions.items():
+            setattr(self, name, token_value)
+            token_id = self.vocab.get(token_value)
+            setattr(self, f"{name}_id", token_id)
+            if token_id is None:
+                missing_tokens.append(token_value)
+
+        if missing_tokens:
+            raise RuntimeError(
+                f"ernie x1 reasoning parser could not find the following token ids in tokenizer vocabulary: {', '.join(missing_tokens)}"
+            )
+
+
self.token_status_mapping = { + self.think_start_token_id: "think_start", + self.think_end_token_id: "think_end", + self.response_start_token_id: "response_start", + self.response_end_token_id: "response_end", + self.tool_call_start_token_id: "tool_call_start", + self.tool_call_end_token_id: "tool_call_end", + } + + def find_last_special_token(self, prompt_token_ids: list[int]) -> int: + for i in range(len(prompt_token_ids) - 1, -1, -1): + if prompt_token_ids[i] in self.token_status_mapping: + return prompt_token_ids[i] + return -1 + + def get_model_status(self, prompt_token_ids: list[int]): + special_token_id = self.find_last_special_token(prompt_token_ids) + + if special_token_id == -1: + return "think_start" + + return self.token_status_mapping[special_token_id] def extract_reasoning_content_streaming( self, @@ -40,64 +76,81 @@ def extract_reasoning_content_streaming( previous_token_ids: Sequence[int], current_token_ids: Sequence[int], delta_token_ids: Sequence[int], + model_status: str, ) -> Union[DeltaMessage, None]: - # Ignore the single token - if len(delta_token_ids) == 1 and delta_token_ids[0] == self.think_end_token_id: - return None - # --- Thinking stage handling --- - if not previous_text.endswith(self.think_end_token) and self.think_end_token not in previous_text: - # If delta is , stop thinking, do not return - if delta_text.startswith(self.think_end_token): - return None - # Otherwise, return thinking content (keep \n as-is) - return DeltaMessage(reasoning_content=delta_text) - - # --- After thinking ends, check tool_call or response --- - remaining_text = previous_text + delta_text - after_think = remaining_text[remaining_text.find(self.think_end_token) + len(self.think_end_token) :] - after_think = after_think.lstrip("\n") - - # Handle tool_call case: skip it - if after_think.startswith(self.tool_call_start_token): + if len(delta_token_ids) == 1 and delta_token_ids[0] in [ + self.think_end_token_id, + self.response_start_token_id, + self.response_end_token_id, + ]: return None - # Handle response case - if after_think.startswith(self.response_start_token) and self.response_end_token not in after_think: - # Do not return when tag itself appears - if delta_text == self.response_start_token or delta_text == self.response_end_token: - return None - return DeltaMessage(content=delta_text) + if model_status == "think_start": + if self.think_end_token_id in delta_token_ids: + reasoning_content = "" + response_content = "" + end_index = delta_text.find(self.think_end_token) + reasoning_content = delta_text[:end_index] + response_start_pos = delta_text.find(self.response_start_token) + if response_start_pos != -1: + response_content = self._extract_response_content( + delta_text[response_start_pos + len(self.response_start_token) :] + ) + return DeltaMessage(reasoning_content=reasoning_content, content=response_content) + elif self.think_end_token in previous_text: + if self.response_start_token in previous_text and self.response_end_token not in previous_text: + return DeltaMessage(content=delta_text) + else: + return DeltaMessage(reasoning_content=delta_text) + elif model_status == "think_end": + if self.response_start_token in previous_text and self.response_end_token not in previous_text: + return DeltaMessage(content=delta_text) + elif model_status == "response_start": + if self.response_end_token not in previous_text: + return DeltaMessage(content=delta_text) - # Default case: return nothing return None - def extract_reasoning_content(self, model_output: str, request: 
ChatCompletionRequest) -> Tuple[str, str]:
+    def extract_reasoning_content(
+        self, model_output: str, request: ChatCompletionRequest, model_status: str
+    ) -> Tuple[str, str]:
+        """
+        优化版解析器。保留推理和响应内容中的换行符,
+        仅删除闭合标签前的单个换行符。
+        """
         reasoning_content = ""
         response_content = ""

-        think_end_pos = model_output.find(self.think_end_token)
-        if think_end_pos != -1:
-            reasoning_content = model_output[:think_end_pos]
-
-            remaining = model_output[think_end_pos + len(self.think_end_token) :]
+        if model_status in ["think_start", "think_end"]:
+            if model_status == "think_start":
+                think_end_pos = model_output.find(self.think_end_token)
+                if think_end_pos != -1:
+                    reasoning_content = model_output[:think_end_pos]
+                    remaining = model_output[think_end_pos + len(self.think_end_token) :].lstrip("\n")
+                else:
+                    reasoning_content = model_output
+                    remaining = ""
+            else:
+                remaining = model_output.lstrip("\n")

-            # find <response> or <tool_call>
-            response_pos = remaining.find(self.response_start_token)
-            tool_pos = remaining.find(self.tool_call_start_token)
+            response_start_pos = remaining.find(self.response_start_token)
+            if response_start_pos != -1:
+                response_content = self._extract_response_content(
+                    remaining[response_start_pos + len(self.response_start_token) :]
+                )

-            # <response> first
-            if response_pos != -1 and (tool_pos == -1 or response_pos < tool_pos):
-                # The content after the response_start position
-                remaining_response = remaining[response_pos + len(self.response_start_token) :]
-                response_end_pos = remaining_response.find(self.response_end_token)
-                if response_end_pos != -1:
-                    response_content = remaining_response[:response_end_pos]
-                else:
-                    response_content = remaining_response
-            # The content after the response_start position is tool_call
-        else:
-            reasoning_content = model_output
-            response_content = ""
+        elif model_status == "response_start":
+            response_content = self._extract_response_content(model_output)

         return reasoning_content, response_content
+
+    def _extract_response_content(self, remaining: str) -> str:
+        """
+        Extracts response content, ensuring that the last newline before
+        the </response> tag is removed.
+        """
+        response_end_pos = remaining.find(self.response_end_token)
+        if response_end_pos != -1:
+            return remaining[:response_end_pos]
+        return remaining
diff --git a/fastdeploy/reasoning/qwen3_reasoning_parsers.py b/fastdeploy/reasoning/qwen3_reasoning_parsers.py
index 463cab83df..b01cdf0d69 100644
--- a/fastdeploy/reasoning/qwen3_reasoning_parsers.py
+++ b/fastdeploy/reasoning/qwen3_reasoning_parsers.py
@@ -35,22 +35,50 @@ class Qwen3ReasoningParser(ReasoningParser):

     def __init__(self, tokenizer):
         super().__init__(tokenizer)
-        self.think_start_token = "<think>"
-        self.think_end_token = "</think>"
+
+        # 定义所有需要检查的token
+        token_definitions = {
+            "think_start_token": "<think>",
+            "think_end_token": "</think>",
+        }

         if not self.model_tokenizer:
-            raise ValueError(
-                "The model tokenizer must be passed to the ReasoningParser " "constructor during construction."
+ raise ValueError("The model tokenizer must be passed to the ReasoningParser constructor.") + + missing_tokens = [] + for name, token_value in token_definitions.items(): + setattr(self, name, token_value) + token_id = self.vocab.get(token_value) + setattr(self, f"{name}_id", token_id) + if token_id is None: + missing_tokens.append(token_value) + + if missing_tokens: + raise RuntimeError( + f"Qwen3 reasoning parser could not find the following token ids in tokenizer vocabulary: {', '.join(missing_tokens)}" ) - - self.think_start_token_id = self.vocab.get(self.think_start_token) - self.think_end_token_id = self.vocab.get(self.think_end_token) - if self.think_end_token_id is None: - raise RuntimeError("Qwen3 reasoning parser could not locate think end " "tokens in the tokenizer!") + self.token_status_mapping = { + self.think_start_token_id: "think_start", + self.think_end_token_id: "think_end", + } def is_reasoning_end(self, input_ids: list[int]) -> bool: return self.think_end_token_id in input_ids + def find_last_special_token(self, prompt_token_ids: list[int]) -> int: + for i in range(len(prompt_token_ids) - 1, -1, -1): + if prompt_token_ids[i] in self.token_status_mapping: + return prompt_token_ids[i] + return -1 + + def get_model_status(self, prompt_token_ids: list[int]): + special_token_id = self.find_last_special_token(prompt_token_ids) + + if special_token_id == -1: + return "think_start" + + return self.token_status_mapping[special_token_id] + def extract_reasoning_content_streaming( self, previous_text: str, @@ -59,6 +87,7 @@ def extract_reasoning_content_streaming( previous_token_ids: Sequence[int], current_token_ids: Sequence[int], delta_token_ids: Sequence[int], + model_status: str, ) -> Union[DeltaMessage, None]: """ Extract reasoning content from a delta message. 
@@ -71,39 +100,42 @@ def extract_reasoning_content_streaming( if len(delta_token_ids) == 1 and (delta_token_ids[0] in [self.think_start_token_id, self.think_end_token_id]): return None - # in delta - if self.think_end_token_id in delta_token_ids: - # in delta, in delta, extract reasoning content - if self.think_start_token_id in delta_token_ids: + if model_status == "think_start": + # in delta + if self.think_end_token_id in delta_token_ids: + # in delta, in delta, extract reasoning content + if self.think_start_token_id in delta_token_ids: + start_index = delta_text.find(self.think_start_token) + end_index = delta_token_ids.find(self.think_end_token) + reasoning_content = delta_text[start_index + len(self.think_start_token) : end_index] + content = delta_text[end_index + len(self.think_end_token) :] + return DeltaMessage(reasoning_content=reasoning_content, content=content) + # in previous, in delta, + else: + end_index = delta_text.find(self.think_end_token) + reasoning_content = delta_text[:end_index] + content = delta_text[end_index + len(self.think_end_token) :] + content = content if content else None + return DeltaMessage(reasoning_content=reasoning_content, content=content) + # in previous reasoning content continues + elif self.think_end_token_id in previous_token_ids: + return DeltaMessage(content=delta_text) + # in previous + elif self.think_start_token_id in previous_token_ids: + return DeltaMessage(reasoning_content=delta_text) + # in delta + elif self.think_start_token_id in delta_token_ids: start_index = delta_text.find(self.think_start_token) - end_index = delta_token_ids.find(self.think_end_token) - reasoning_content = delta_text[start_index + len(self.think_start_token) : end_index] - content = delta_text[end_index + len(self.think_end_token) :] + reasoning_content = delta_text[start_index + len(self.think_start_token) :] + content = "" return DeltaMessage(reasoning_content=reasoning_content, content=content) - # in previous, in delta, else: - end_index = delta_text.find(self.think_end_token) - reasoning_content = delta_text[:end_index] - content = delta_text[end_index + len(self.think_end_token) :] - content = content if content else None - return DeltaMessage(reasoning_content=reasoning_content, content=content) - # in previous reasoning content continues - elif self.think_end_token_id in previous_token_ids: - return DeltaMessage(content=delta_text) - # in previous - elif self.think_start_token_id in previous_token_ids: - return DeltaMessage(reasoning_content=delta_text) - # in delta - elif self.think_start_token_id in delta_token_ids: - start_index = delta_text.find(self.think_start_token) - reasoning_content = delta_text[start_index + len(self.think_start_token) :] - content = "" - return DeltaMessage(reasoning_content=reasoning_content, content=content) + return DeltaMessage(reasoning_content=delta_text) else: - return DeltaMessage(reasoning_content=delta_text) + return DeltaMessage(content=delta_text) def extract_reasoning_content( - self, model_output: str, request: ChatCompletionRequest + self, model_output: str, request: ChatCompletionRequest, model_status: str ) -> tuple[Optional[str], Optional[str]]: """ Extract reasoning content from the model output. 
@@ -116,36 +148,39 @@ def extract_reasoning_content( tuple[Optional[str], Optional[str]]: reasoning content and content """ - # 检查是否包含结束标签 - if self.think_end_token not in model_output: - return None, model_output - - # 检查是否有起始标签 - if self.think_start_token in model_output: - # 标准格式:contentanswer - if self.think_start_token not in model_output or self.think_end_token not in model_output: - return None, model_output - # Check if the is present in the model output, remove it - # if it is present. - model_output_parts = model_output.partition(self.think_start_token) - model_output = model_output_parts[2] if model_output_parts[1] else model_output_parts[0] - # Check if the model output contains the tokens. - # If the end token is not found, return the model output as is. + if model_status == "think_start": + # 检查是否包含结束标签 if self.think_end_token not in model_output: return None, model_output - # Extract reasoning content from the model output. - reasoning_content, _, content = model_output.partition(self.think_end_token) - - final_content = content or None - return reasoning_content, final_content - else: - # 缺少起始标签的格式:contentanswer - parts = model_output.split(self.think_end_token, 1) - - if len(parts) == 2: - reasoning_content = parts[0].strip() - final_content = parts[1].strip() if parts[1].strip() else None + # 检查是否有起始标签 + if self.think_start_token in model_output: + # 标准格式:contentanswer + if self.think_start_token not in model_output or self.think_end_token not in model_output: + return None, model_output + # Check if the is present in the model output, remove it + # if it is present. + model_output_parts = model_output.partition(self.think_start_token) + model_output = model_output_parts[2] if model_output_parts[1] else model_output_parts[0] + # Check if the model output contains the tokens. + # If the end token is not found, return the model output as is. + if self.think_end_token not in model_output: + return None, model_output + + # Extract reasoning content from the model output. 
+ reasoning_content, _, content = model_output.partition(self.think_end_token) + + final_content = content or None return reasoning_content, final_content + else: + # 缺少起始标签的格式:contentanswer + parts = model_output.split(self.think_end_token, 1) + + if len(parts) == 2: + reasoning_content = parts[0].strip() + final_content = parts[1].strip() if parts[1].strip() else None + return reasoning_content, final_content - return None, model_output + return None, model_output + else: + return None, model_output diff --git a/tests/e2e/test_EB_VL_Lite_serving.py b/tests/e2e/test_EB_VL_Lite_serving.py index 41dd81a097..e116e8bb9e 100644 --- a/tests/e2e/test_EB_VL_Lite_serving.py +++ b/tests/e2e/test_EB_VL_Lite_serving.py @@ -532,7 +532,7 @@ def test_chat_with_thinking(openai_client, capsys): max_tokens=10, extra_body={"chat_template_kwargs": {"enable_thinking": False}}, ) - assert response.choices[0].message.reasoning_content is None + assert response.choices[0].message.reasoning_content == "" assert "" not in response.choices[0].message.content # test logic @@ -703,4 +703,4 @@ def test_thinking_logic_flag(openai_client, capsys): "chat_template_kwargs": {"enable_thinking": False}, }, ) - assert response_case_3.choices[0].message.reasoning_content is None + assert response_case_3.choices[0].message.reasoning_content == "" diff --git a/tests/entrypoints/openai/test_max_streaming_tokens.py b/tests/entrypoints/openai/test_max_streaming_tokens.py index 61d5f88d45..0c8a3f8d22 100644 --- a/tests/entrypoints/openai/test_max_streaming_tokens.py +++ b/tests/entrypoints/openai/test_max_streaming_tokens.py @@ -141,7 +141,7 @@ async def test_integration_with_chat_stream_generator(self, mock_processor_class mock_processor_instance = Mock() - async def mock_process_response_chat_single(response, stream, enable_thinking, include_stop_str_in_output): + async def mock_process_response_chat_single(response, stream, include_stop_str_in_output): yield response mock_processor_instance.process_response_chat = mock_process_response_chat_single diff --git a/tests/entrypoints/openai/test_response_processors.py b/tests/entrypoints/openai/test_response_processors.py index afab163b97..34cade7cd8 100644 --- a/tests/entrypoints/openai/test_response_processors.py +++ b/tests/entrypoints/openai/test_response_processors.py @@ -48,7 +48,7 @@ async def test_text_only_mode(self): results = [ r async for r in processor.process_response_chat( - request_outputs, stream=False, enable_thinking=False, include_stop_str_in_output=False + request_outputs, stream=False, include_stop_str_in_output=False ) ] @@ -67,7 +67,7 @@ async def test_streaming_text_and_image(self): results = [ r async for r in self.processor_mm.process_response_chat( - request_outputs, stream=True, enable_thinking=False, include_stop_str_in_output=False + request_outputs, stream=True, include_stop_str_in_output=False ) ] @@ -94,7 +94,7 @@ async def test_streaming_buffer_accumulation(self): results = [ r async for r in self.processor_mm.process_response_chat( - request_outputs, stream=True, enable_thinking=False, include_stop_str_in_output=False + request_outputs, stream=True, include_stop_str_in_output=False ) ] @@ -112,7 +112,7 @@ async def test_non_streaming_accumulate_and_emit(self): results = [ r async for r in self.processor_mm.process_response_chat( - request_outputs, stream=False, enable_thinking=False, include_stop_str_in_output=False + request_outputs, stream=False, include_stop_str_in_output=False ) ] diff --git 
a/tests/entrypoints/openai/tool_parsers/test_ernie_x1_tool_parser.py b/tests/entrypoints/openai/tool_parsers/test_ernie_x1_tool_parser.py index e818801d93..1b8b58d1e9 100644 --- a/tests/entrypoints/openai/tool_parsers/test_ernie_x1_tool_parser.py +++ b/tests/entrypoints/openai/tool_parsers/test_ernie_x1_tool_parser.py @@ -52,33 +52,12 @@ def test_extract_tool_calls_complete(self): self.assertTrue(result.tools_called) self.assertEqual(result.tool_calls[0].function.name, "get_weather") - def test_extract_tool_calls_partial_arguments(self): - """Test partial extraction when arguments incomplete""" - output = '{"name": "get_weather", "arguments": {"location": "北"' - result = self.parser.extract_tool_calls(output, self.dummy_request) - self.assertFalse(result.tools_called) - self.assertEqual(result.tool_calls[0].function.name, "get_weather") - - def test_extract_tool_calls_invalid_response_before_toolcall(self): - """Test case where before is invalid""" - output = 'hello{"name": "get_weather", "arguments": {}}' - result = self.parser.extract_tool_calls(output, self.dummy_request) - self.assertFalse(result.tools_called) - self.assertIn("", result.content) - def test_extract_tool_calls_no_toolcall(self): """Test when no tool_call tags are present""" output = "no tool call here" result = self.parser.extract_tool_calls(output, self.dummy_request) self.assertFalse(result.tools_called) - def test_extract_tool_calls_invalid_json(self): - """Test tool_call with badly formatted JSON triggers fallback parser""" - output = '"name": "get_weather", "arguments": {' - result = self.parser.extract_tool_calls(output, self.dummy_request) - self.assertFalse(result.tools_called) - self.assertEqual(result.tool_calls[0].function.name, "get_weather") - def test_extract_tool_calls_exception(self): """Force exception to cover error branch""" with patch( diff --git a/tests/input/test_ernie_processor.py b/tests/input/test_ernie_processor.py index b2357eeaa8..75da4786bd 100644 --- a/tests/input/test_ernie_processor.py +++ b/tests/input/test_ernie_processor.py @@ -4,6 +4,11 @@ from fastdeploy.input.ernie4_5_processor import Ernie4_5Processor +class MockReasoningParser: + def get_model_status(self, prompt_token_ids): + return "think_start" + + class TestErnie4_5ProcessorProcessResponseDictStreaming(unittest.TestCase): def setUp(self): # 创建 Ernie4_5Processor 实例的模拟对象 @@ -14,11 +19,13 @@ def setUp(self): # 设置必要的属性 self.processor.tokenizer = MagicMock() self.processor.tokenizer.eos_token_id = 1 - self.processor.decode_status = {} + self.processor.decode_status = {"test": []} self.processor.reasoning_end_dict = {} self.processor.tool_parser_dict = {} self.processor.generation_config = MagicMock() self.processor.eos_token_ids = [1] + self.processor.reasoning_parser = MockReasoningParser() + self.processor.model_status_dict = {"test": "think_start"} # 模拟 ids2tokens 方法 def mock_ids2tokens(token_ids, task_id): @@ -63,8 +70,17 @@ def test_process_response_dict_streaming_normal_case(self): # 验证结果 self.assertEqual(result["outputs"]["raw_prediction"], "delta_text") + response_dict = {"finished": True, "request_id": "test", "outputs": {"token_ids": [4, 5]}} + + # 调用方法 + result = self.processor.process_response_dict_streaming(response_dict) + + # 验证结果 + self.assertEqual(result["outputs"]["raw_prediction"], "delta_text") + def test_process_request_dict(self): request_dict = { + "request_id": "123", "messages": [{"role": "user", "content": "Hello!"}], "chat_template_kwargs": {"chat_template": "Hello!"}, "eos_token_ids": [1], diff --git 
a/tests/input/test_text_processor.py b/tests/input/test_text_processor.py
index 6ca0178fe8..337ad0a0d3 100644
--- a/tests/input/test_text_processor.py
+++ b/tests/input/test_text_processor.py
@@ -5,6 +5,11 @@
 from fastdeploy.input.text_processor import DataProcessor


+class MockReasoningParser:
+    def get_model_status(self, prompt_token_ids):
+        return "think_start"
+
+
 class TestDataProcessorProcess(unittest.TestCase):
     def setUp(self):
         # 创建 DataProcessor 实例的模拟对象
@@ -20,6 +25,8 @@ def setUp(self):
         self.processor.tool_parser_dict = {}
         self.processor.generation_config = MagicMock()
         self.processor.eos_token_ids = [1]
+        self.processor.reasoning_parser = MockReasoningParser()
+        self.processor.model_status_dict = {}

         def mock_messages2ids(request, **kwargs):
             if "chat_template" in kwargs:
@@ -49,6 +56,7 @@ def test_process_request(self):

     def test_process_request_dict(self):
         request_dict = {
+            "request_id": "123",
             "messages": [{"role": "user", "content": "Hello!"}],
             "chat_template_kwargs": {"chat_template": "Hello!"},
             "eos_token_ids": [1],
diff --git a/tests/reasoning/test_reasoning_parser.py b/tests/reasoning/test_reasoning_parser.py
index 90a48c8990..4b938a7a25 100644
--- a/tests/reasoning/test_reasoning_parser.py
+++ b/tests/reasoning/test_reasoning_parser.py
@@ -27,10 +27,11 @@ class DummyTokenizer:
     def __init__(self):
         self.vocab = {
             "<think>": 100,
-            "</think>": 101,
-            "<response>": 102,
-            "</response>": 103,
-            "<tool_call>": 104,
+            "</think>": 101,
+            "<response>": 102,
+            "</response>": 103,
+            "<tool_call>": 104,
+            "</tool_call>": 105,
         }

     def get_vocab(self):
@@ -137,6 +138,7 @@ def test_streaming_thinking_content(self):
            previous_token_ids=[],
            current_token_ids=[],
            delta_token_ids=[200],
+           model_status="think_start",
        )
        self.assertEqual(msg.reasoning_content, "a")

@@ -148,6 +150,7 @@ def test_streaming_thinking_newline_preserved(self):
            previous_token_ids=[],
            current_token_ids=[],
            delta_token_ids=[201],
+           model_status="think_start",
        )
        self.assertEqual(msg.reasoning_content, "\n")

@@ -159,6 +162,7 @@ def test_streaming_thinking_end_tag(self):
            previous_token_ids=[],
            current_token_ids=[],
            delta_token_ids=[self.parser.think_end_token_id],
+           model_status="think_start",
        )
        self.assertIsNone(msg)

@@ -170,6 +174,7 @@ def test_streaming_response_content(self):
            previous_token_ids=[],
            current_token_ids=[],
            delta_token_ids=[202],
+           model_status="think_start",
        )
        self.assertEqual(msg.content, "h")

@@ -181,6 +186,7 @@ def test_streaming_response_newline_preserved(self):
            previous_token_ids=[],
            current_token_ids=[],
            delta_token_ids=[203],
+           model_status="think_start",
        )
        self.assertEqual(msg.content, "\n")

@@ -193,6 +199,7 @@ def test_streaming_response_ignore_tags(self):
                previous_token_ids=[],
                current_token_ids=[],
                delta_token_ids=[self.parser.vocab["<response>"]],
+               model_status="think_start",
            )
        )

@@ -203,6 +210,7 @@ def test_streaming_response_ignore_tags(self):
            previous_token_ids=[],
            current_token_ids=[],
            delta_token_ids=[204],
+           model_status="think_start",
        )
        self.assertIsInstance(msg, DeltaMessage)
        self.assertEqual(msg.content, "\n")
@@ -215,6 +223,7 @@ def test_streaming_response_ignore_tags(self):
                previous_token_ids=[],
                current_token_ids=[],
                delta_token_ids=[self.parser.vocab["</response>"]],
+               model_status="think_start",
            )
        )

@@ -226,37 +235,39 @@ def test_streaming_tool_call(self):
            previous_token_ids=[],
            current_token_ids=[],
            delta_token_ids=[self.parser.vocab["<tool_call>"]],
+           model_status="think_start",
        )
+       print(msg)
        self.assertIsNone(msg)

    # ---- Batch parsing ----
    def test_batch_reasoning_and_response(self):
        text = "abc\n</think>\n<response>hello\nworld</response>"
-       reasoning, response = self.parser.extract_reasoning_content(text, self.request)
+       reasoning,
response = self.parser.extract_reasoning_content(text, self.request, "think_start") self.assertEqual(reasoning, "abc\n") self.assertEqual(response, "hello\nworld") def test_batch_reasoning_and_tool_call(self): text = "abccall_here" - reasoning, response = self.parser.extract_reasoning_content(text, self.request) + reasoning, response = self.parser.extract_reasoning_content(text, self.request, "think_start") self.assertEqual(reasoning, "abc") self.assertEqual(response, "") def test_batch_no_thinking_tag(self): text = "no_thinking_here" - reasoning, response = self.parser.extract_reasoning_content(text, self.request) + reasoning, response = self.parser.extract_reasoning_content(text, self.request, "think_start") self.assertEqual(reasoning, "no_thinking_here") self.assertEqual(response, "") def test_batch_response_without_end_tag(self): text = "abcpartial response" - reasoning, response = self.parser.extract_reasoning_content(text, self.request) + reasoning, response = self.parser.extract_reasoning_content(text, self.request, "think_start") self.assertEqual(reasoning, "abc") self.assertEqual(response, "partial response") def test_batch_preserve_all_newlines(self): text = "abc\n\nline1\nline2\n" - reasoning, response = self.parser.extract_reasoning_content(text, self.request) + reasoning, response = self.parser.extract_reasoning_content(text, self.request, "think_start") self.assertEqual(reasoning, "abc\n") self.assertEqual(response, "line1\nline2\n") diff --git a/tests/reasoning/test_vl_reasoning_parser.py b/tests/reasoning/test_vl_reasoning_parser.py new file mode 100644 index 0000000000..7eaa5fb4f8 --- /dev/null +++ b/tests/reasoning/test_vl_reasoning_parser.py @@ -0,0 +1,135 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" + +import unittest + +from fastdeploy.entrypoints.openai.protocol import ChatCompletionRequest +from fastdeploy.reasoning.ernie_vl_reasoning_parsers import ErnieVLReasoningParser + + +class MockTokenizer: + """Minimal tokenizer with vocab for testing.""" + + def __init__(self): + self.vocab = { + "": 100, + "": 101, + } + + def get_vocab(self): + """Return vocab dict for testing.""" + return self.vocab + + +class TestErnieVLReasoningParser(unittest.TestCase): + def setUp(self): + self.parser = ErnieVLReasoningParser(MockTokenizer()) + self.request = ChatCompletionRequest(model="test", messages=[{"role": "user", "content": "test message"}]) + self.tokenizer = MockTokenizer() + + def test_get_model_status(self): + status = self.parser.get_model_status([1, 2, 100]) + self.assertEqual(status, "think_start") + status = self.parser.get_model_status([1, 2, 101]) + self.assertEqual(status, "think_end") + status = self.parser.get_model_status([1]) + self.assertEqual(status, "think_start") + + def test_streaming_thinking_content(self): + msg = self.parser.extract_reasoning_content_streaming( + previous_text="", + current_text="a", + delta_text="a", + previous_token_ids=[], + current_token_ids=[], + delta_token_ids=[200], + model_status="think_start", + ) + self.assertEqual(msg.reasoning_content, "a") + + msg = self.parser.extract_reasoning_content_streaming( + previous_text="", + current_text="ab", + delta_text="ab", + previous_token_ids=[], + current_token_ids=[], + delta_token_ids=[100, 101, 102], + model_status="think_start", + ) + self.assertEqual(msg.reasoning_content, "a") + self.assertEqual(msg.content, "b") + + msg = self.parser.extract_reasoning_content_streaming( + previous_text="a", + current_text="ab", + delta_text="b", + previous_token_ids=[1, 101], + current_token_ids=[], + delta_token_ids=[102], + model_status="think_start", + ) + self.assertEqual(msg.content, "b") + + msg = self.parser.extract_reasoning_content_streaming( + previous_text="", + current_text="a", + delta_text="a", + previous_token_ids=[], + current_token_ids=[], + delta_token_ids=[], + model_status="think_start", + ) + self.assertEqual(msg.reasoning_content, "a") + + msg = self.parser.extract_reasoning_content_streaming( + previous_text="", + current_text="a", + delta_text="a", + previous_token_ids=[], + current_token_ids=[], + delta_token_ids=[200], + model_status="think_end", + ) + self.assertEqual(msg.content, "a") + + def test_none_streaming_thinking_content(self): + reasoning_content, content = self.parser.extract_reasoning_content( + model_output="a", + request={}, + model_status="think_start", + ) + self.assertEqual(reasoning_content, "a") + self.assertEqual(content, "") + + reasoning_content, content = self.parser.extract_reasoning_content( + model_output="ab", + request={}, + model_status="think_start", + ) + self.assertEqual(reasoning_content, "a") + self.assertEqual(content, "b") + + reasoning_content, content = self.parser.extract_reasoning_content( + model_output="a", + request={}, + model_status="think_end", + ) + self.assertEqual(reasoning_content, "") + self.assertEqual(content, "a") + + +if __name__ == "__main__": + unittest.main()