From 515a0c713ab235eb6cc85676168ce0d0d6e97148 Mon Sep 17 00:00:00 2001
From: Ce Gao
Date: Wed, 19 Feb 2025 09:57:01 +0800
Subject: [PATCH 01/12] [v0][structured output] Support reasoning output

Signed-off-by: Ce Gao
---
 ...etion_structured_outputs_with_reasoning.py | 64 +++++++++++++++
 .../model_executor/test_guided_processors.py  | 80 ++++++++++++++++++-
 vllm/config.py                                |  2 +
 vllm/engine/arg_utils.py                      | 26 +++++-
 vllm/engine/async_llm_engine.py               | 10 ++-
 vllm/engine/llm_engine.py                     |  7 +-
 vllm/engine/multiprocessing/client.py         |  3 +-
 vllm/entrypoints/openai/cli_args.py           | 18 -----
 .../guided_decoding/__init__.py               | 32 ++++++--
 .../guided_decoding/outlines_decoding.py      | 31 ++++---
 .../outlines_logits_processors.py             | 32 ++++++--
 .../guided_decoding/reasoner/__init__.py      | 19 +++++
 .../reasoner/deepseek_reasoner.py             | 35 ++++++++
 .../guided_decoding/reasoner/reasoner.py      | 28 +++++++
 .../guided_decoding/xgrammar_decoding.py      | 16 +++-
 15 files changed, 349 insertions(+), 54 deletions(-)
 create mode 100644 examples/online_serving/openai_chat_completion_structured_outputs_with_reasoning.py
 create mode 100644 vllm/model_executor/guided_decoding/reasoner/__init__.py
 create mode 100644 vllm/model_executor/guided_decoding/reasoner/deepseek_reasoner.py
 create mode 100644 vllm/model_executor/guided_decoding/reasoner/reasoner.py

diff --git a/examples/online_serving/openai_chat_completion_structured_outputs_with_reasoning.py b/examples/online_serving/openai_chat_completion_structured_outputs_with_reasoning.py
new file mode 100644
index 000000000000..1f72e1164d42
--- /dev/null
+++ b/examples/online_serving/openai_chat_completion_structured_outputs_with_reasoning.py
@@ -0,0 +1,64 @@
+# SPDX-License-Identifier: Apache-2.0
+"""
+An example showing how to generate structured outputs from reasoning models
+like DeepSeek R1. The thinking process will not be guided by the JSON
+schema provided by the user. Only the final output will be structured.
+
+To run this example, you need to start the vLLM server with the reasoning
+parser:
+
+```bash
+vllm serve deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B \
+    --enable-reasoning --reasoning-parser deepseek_r1
+```
+
+This example demonstrates how to generate chat completions from reasoning
+models using the OpenAI Python client library.
+"""
+
+from enum import Enum
+
+from openai import OpenAI
+from pydantic import BaseModel
+
+# Modify OpenAI's API key and API base to use vLLM's API server.
+openai_api_key = "EMPTY"
+openai_api_base = "http://localhost:8000/v1"
+
+client = OpenAI(
+    api_key=openai_api_key,
+    base_url=openai_api_base,
+)
+
+models = client.models.list()
+model = models.data[0].id
+
+
+# Guided decoding by JSON using Pydantic schema
+class CarType(str, Enum):
+    sedan = "sedan"
+    suv = "SUV"
+    truck = "Truck"
+    coupe = "Coupe"
+
+
+class CarDescription(BaseModel):
+    brand: str
+    model: str
+    car_type: CarType
+
+
+json_schema = CarDescription.model_json_schema()
+
+prompt = ("Generate a JSON with the brand, model and car_type of "
+          "the most iconic car from the 90's, think in 100 tokens")
+completion = client.chat.completions.create(
+    model=model,
+    messages=[{
+        "role": "user",
+        "content": prompt,
+    }],
+    extra_body={"guided_json": json_schema},
+)
+print("content: ", completion.choices[0].message.content)
+print("reasoning_content: ", completion.choices[0].message.reasoning_content)
diff --git a/tests/model_executor/test_guided_processors.py b/tests/model_executor/test_guided_processors.py
index be544698fa03..30fc8b4b5350 100644
--- a/tests/model_executor/test_guided_processors.py
+++ b/tests/model_executor/test_guided_processors.py
@@ -16,15 +16,19 @@
 MODEL_NAME = 'HuggingFaceH4/zephyr-7b-beta'
 GUIDED_DECODING_BACKENDS = ["outlines", "lm-format-enforcer", "xgrammar"]
+GUIDED_DECODING_BACKENDS_WITH_REASONING_SUPPORT = ["outlines", "xgrammar"]
 
 
 def test_guided_logits_processors(sample_regex, sample_json_schema):
     """Basic unit test for RegexLogitsProcessor and JSONLogitsProcessor."""
     tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta')
-    regex_LP = RegexLogitsProcessor(sample_regex, tokenizer)
+    regex_LP = RegexLogitsProcessor(sample_regex,
+                                    tokenizer,
+                                    reasoner_config=None)
     json_LP = JSONLogitsProcessor(sample_json_schema,
                                   tokenizer,
-                                  whitespace_pattern=None)
+                                  whitespace_pattern=None,
+                                  reasoner_config=None)
 
     token_ids = tokenizer.encode(
         f"Give an example IPv4 address with this regex: {sample_regex}")
@@ -91,6 +95,78 @@ async def test_guided_logits_processor_black_box(backend: str, is_local: bool,
     assert not torch.allclose(tensor, original_tensor)
 
 
+@pytest.mark.asyncio
+@pytest.mark.parametrize("backend",
+                         GUIDED_DECODING_BACKENDS_WITH_REASONING_SUPPORT)
+@pytest.mark.parametrize("is_local", [True, False])
+@pytest.mark.parametrize("reasoning_backend", ["deepseek_r1"])
+async def test_guided_logits_processor_with_reasoning(backend: str,
+                                                      is_local: bool,
+                                                      reasoning_backend: str,
+                                                      sample_regex,
+                                                      sample_json_schema):
+
+    reasoning_model = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
+    config = ModelConfig(
+        reasoning_model,
+        task="generate",
+        tokenizer=reasoning_model,
+        tokenizer_mode="auto",
+        trust_remote_code=False,
+        seed=0,
+        dtype="bfloat16",
+    )
+    tokenizer = AutoTokenizer.from_pretrained(reasoning_model)
+    token_ids = tokenizer.encode(
+        f"Give an example IPv4 address with this regex: {sample_regex}."
+ "here is the thinking process") + regex_request = GuidedDecodingParams(regex=sample_regex, backend=backend) + + regex_lp = get_local_guided_decoding_logits_processor(regex_request, + tokenizer, config, reasoning_backend) if is_local else \ + await get_guided_decoding_logits_processor( + regex_request, tokenizer, config, reasoning_backend) + assert regex_lp is not None + tensor = torch.rand(32000) + original_tensor = torch.clone(tensor) + tensor = regex_lp(token_ids, tensor) + assert tensor.shape == original_tensor.shape + assert torch.allclose(tensor, original_tensor) + + token_ids = tokenizer.encode( + f"Give an employee profile that fits this schema: {sample_json_schema}." + "here is the thinking process") + json_request = GuidedDecodingParams(json=sample_json_schema, + backend=backend) + json_lp = get_local_guided_decoding_logits_processor( + json_request, tokenizer, config, reasoning_backend) if is_local else \ + await get_guided_decoding_logits_processor( + json_request, tokenizer, config, reasoning_backend) + assert json_lp is not None + tensor = torch.rand(32000) + original_tensor = torch.clone(tensor) + tensor = json_lp(token_ids, tensor) + assert tensor.shape == original_tensor.shape + assert torch.allclose(tensor, original_tensor) + + # Thinking is over, so the tensor should change. + token_ids = tokenizer.encode( + f"Give an employee profile that fits this schema: {sample_json_schema}." + "here is the thinking process Then") + json_request = GuidedDecodingParams(json=sample_json_schema, + backend=backend) + json_lp = get_local_guided_decoding_logits_processor( + json_request, tokenizer, config, reasoning_backend) if is_local else \ + await get_guided_decoding_logits_processor( + json_request, tokenizer, config, reasoning_backend) + assert json_lp is not None + tensor = torch.rand(32000) + original_tensor = torch.clone(tensor) + tensor = json_lp(token_ids, tensor) + assert tensor.shape == original_tensor.shape + assert not torch.allclose(tensor, original_tensor) + + def test_multiple_guided_options_not_allowed(sample_json_schema, sample_regex): with pytest.raises(ValueError, match="You can only use one kind of guided"): diff --git a/vllm/config.py b/vllm/config.py index c7108473442b..54ed38418dd4 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2715,6 +2715,8 @@ class DecodingConfig: # 'outlines' / 'lm-format-enforcer' / 'xgrammar' guided_decoding_backend: str = 'xgrammar' + reasoning_backend: Optional[str] = None + def compute_hash(self) -> str: """ WARNING: Whenever a new field is added to this config, diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 1a2f794c9151..989eb4dbfd14 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -213,6 +213,8 @@ class EngineArgs: calculate_kv_scales: Optional[bool] = None additional_config: Optional[Dict[str, Any]] = None + enable_reasoning: Optional[bool] = None + reasoning_parser: Optional[str] = None def __post_init__(self): if not self.tokenizer: @@ -1059,6 +1061,25 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: "Different platforms may support different configs. Make sure the " "configs are valid for the platform you are using. The input format" " is like '{\"config_key\":\"config_value\"}'") + + parser.add_argument( + "--enable-reasoning", + action="store_true", + default=False, + help="Whether to enable reasoning_content for the model. " + "If enabled, the model will be able to generate reasoning content." 
+ ) + + parser.add_argument( + "--reasoning-parser", + type=str, + choices=["deepseek_r1"], + default=None, + help= + "Select the reasoning parser depending on the model that you're " + "using. This is used to parse the reasoning content into OpenAI " + "API format. Required for ``--enable-reasoning``.") + return parser @classmethod @@ -1332,7 +1353,10 @@ def create_engine_config(self, if self.enable_prompt_adapter else None decoding_config = DecodingConfig( - guided_decoding_backend=self.guided_decoding_backend) + guided_decoding_backend=self.guided_decoding_backend, + reasoning_backend=self.reasoning_parser + if self.enable_reasoning else None, + ) show_hidden_metrics = False if self.show_hidden_metrics_for_version is not None: diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 93d9b74d8e1e..8da44c44ead5 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -509,6 +509,7 @@ async def add_request_async( tokenizer=await self.get_tokenizer_async(lora_request), default_guided_backend=self.decoding_config. guided_decoding_backend, + reasoning_backend=self.decoding_config.reasoning_backend, model_config=self.model_config) self._add_processed_request( @@ -530,7 +531,7 @@ async def check_health_async(self) -> None: async def build_guided_decoding_logits_processor_async( sampling_params: SamplingParams, tokenizer: AnyTokenizer, - default_guided_backend: str, + default_guided_backend: str, reasoning_backend: Optional[str], model_config: ModelConfig) -> SamplingParams: """Constructs logits processors based on the guided_decoding, logits_bias, and allowed_token_ids fields in sampling_params. Deletes @@ -545,14 +546,17 @@ async def build_guided_decoding_logits_processor_async( sampling_params = copy.copy(sampling_params) guided_decoding = sampling_params.guided_decoding - logger.debug("Building guided decoding logits processor. " - "Params: %s", guided_decoding) + logger.debug( + "Building guided decoding logits processor. 
" + "guided_decoding: %s, reasoning: %s", guided_decoding, + reasoning_backend) guided_decoding.backend = guided_decoding.backend or default_guided_backend processor = await get_guided_decoding_logits_processor( guided_params=guided_decoding, tokenizer=tokenizer, + reasoning_backend=reasoning_backend, model_config=model_config) if processor: diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 9c83ea75ead7..f055438d1feb 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -2048,10 +2048,15 @@ def _build_logits_processors( guided_decoding.backend = guided_decoding.backend or \ self.decoding_config.guided_decoding_backend + logger.debug("Reasoning backend: %s", + self.decoding_config.reasoning_backend) + processor = get_local_guided_decoding_logits_processor( guided_params=guided_decoding, tokenizer=tokenizer, - model_config=self.model_config) + model_config=self.model_config, + reasoning_backend=self.decoding_config.reasoning_backend, + ) if processor: logits_processors.append(processor) diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index c12fe242082b..005ba81cd226 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -611,7 +611,8 @@ async def _process_request( default_guided_backend=(self.decoding_config.guided_decoding_backend if self.decoding_config else DecodingConfig.guided_decoding_backend), - model_config=self.model_config + model_config=self.model_config, + reasoning_backend=self.decoding_config.reasoning_backend, ) # 1) Create output queue for this requests. diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index ba953c219708..8d877046f75f 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -13,7 +13,6 @@ from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption, validate_chat_template) -from vllm.entrypoints.openai.reasoning_parsers import ReasoningParserManager from vllm.entrypoints.openai.serving_models import (LoRAModulePath, PromptAdapterPath) from vllm.entrypoints.openai.tool_parsers import ToolParserManager @@ -215,23 +214,6 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: default=False, help="Enable auto tool choice for supported models. Use " "``--tool-call-parser`` to specify which parser to use.") - parser.add_argument( - "--enable-reasoning", - action="store_true", - default=False, - help="Whether to enable reasoning_content for the model. " - "If enabled, the model will be able to generate reasoning content.") - - valid_reasoning_parsers = ReasoningParserManager.reasoning_parsers.keys() - parser.add_argument( - "--reasoning-parser", - type=str, - metavar="{" + ",".join(valid_reasoning_parsers) + "}", - default=None, - help= - "Select the reasoning parser depending on the model that you're using." - " This is used to parse the reasoning content into OpenAI API " - "format. 
Required for ``--enable-reasoning``.") valid_tool_parsers = ToolParserManager.tool_parsers.keys() parser.add_argument( diff --git a/vllm/model_executor/guided_decoding/__init__.py b/vllm/model_executor/guided_decoding/__init__.py index 1522e3404182..af4b64c2da77 100644 --- a/vllm/model_executor/guided_decoding/__init__.py +++ b/vllm/model_executor/guided_decoding/__init__.py @@ -5,6 +5,8 @@ from typing import TYPE_CHECKING from vllm.logger import init_logger +from vllm.model_executor.guided_decoding.reasoner import (ReasonerConfig, + get_reasoner) from vllm.model_executor.guided_decoding.utils import ( convert_lark_to_gbnf, grammar_is_likely_lark, has_lmf_unsupported_json_features, has_xgrammar_unsupported_json_features) @@ -104,7 +106,14 @@ def fallback_or_error(guided_params: GuidedDecodingParams, message: str, async def get_guided_decoding_logits_processor( guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizer, - model_config: ModelConfig) -> LogitsProcessor | None: + model_config: ModelConfig, + reasoning_backend: str | None) -> LogitsProcessor | None: + + reasoner_config = None + if reasoning_backend is not None: + reasoner = get_reasoner(reasoning_backend, tokenizer) + reasoner_config = ReasonerConfig.from_reasoner(reasoner) + guided_params = maybe_backend_fallback(guided_params) # CFG grammar not supported by LMFE, so we use outlines instead if guided_params.backend_name == 'outlines': @@ -112,8 +121,8 @@ async def get_guided_decoding_logits_processor( from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa get_outlines_guided_decoding_logits_processor) return await get_outlines_guided_decoding_logits_processor( - guided_params, tokenizer) - if guided_params.backend_name == 'lm-format-enforcer': + guided_params, tokenizer, reasoner_config) + if guided_params.backend == 'lm-format-enforcer': from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa get_local_lm_format_enforcer_guided_decoding_logits_processor) return get_local_lm_format_enforcer_guided_decoding_logits_processor( @@ -122,7 +131,7 @@ async def get_guided_decoding_logits_processor( from vllm.model_executor.guided_decoding.xgrammar_decoding import ( # noqa get_local_xgrammar_guided_decoding_logits_processor) return get_local_xgrammar_guided_decoding_logits_processor( - guided_params, tokenizer, model_config) + guided_params, tokenizer, model_config, reasoner_config) raise ValueError( f"Unknown guided decoding backend '{guided_params.backend}'. 
" @@ -131,16 +140,23 @@ async def get_guided_decoding_logits_processor( def get_local_guided_decoding_logits_processor( guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizer, - model_config: ModelConfig) -> LogitsProcessor | None: + model_config: ModelConfig, + reasoning_backend: str | None) -> LogitsProcessor | None: guided_params = maybe_backend_fallback(guided_params) + + reasoner_config = None + if reasoning_backend is not None: + reasoner = get_reasoner(reasoning_backend, tokenizer) + reasoner_config = ReasonerConfig.from_reasoner(reasoner) + # CFG grammar not supported by LMFE, so we use outlines instead if guided_params.backend_name == 'outlines': # NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193 from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa get_local_outlines_guided_decoding_logits_processor) return get_local_outlines_guided_decoding_logits_processor( - guided_params, tokenizer) - if guided_params.backend_name == 'lm-format-enforcer': + guided_params, tokenizer, reasoner_config) + if guided_params.backend == 'lm-format-enforcer': from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa get_local_lm_format_enforcer_guided_decoding_logits_processor) return get_local_lm_format_enforcer_guided_decoding_logits_processor( @@ -149,7 +165,7 @@ def get_local_guided_decoding_logits_processor( from vllm.model_executor.guided_decoding.xgrammar_decoding import ( # noqa get_local_xgrammar_guided_decoding_logits_processor) return get_local_xgrammar_guided_decoding_logits_processor( - guided_params, tokenizer, model_config) + guided_params, tokenizer, model_config, reasoner_config) raise ValueError( f"Unknown guided decoding backend '{guided_params.backend}'. 
" diff --git a/vllm/model_executor/guided_decoding/outlines_decoding.py b/vllm/model_executor/guided_decoding/outlines_decoding.py index ba9c98290368..8df1f6053dca 100644 --- a/vllm/model_executor/guided_decoding/outlines_decoding.py +++ b/vllm/model_executor/guided_decoding/outlines_decoding.py @@ -6,12 +6,13 @@ from enum import Enum from json import dumps as json_dumps from re import escape as regex_escape -from typing import Tuple, Union +from typing import Optional, Tuple, Union from transformers import PreTrainedTokenizerBase from vllm.model_executor.guided_decoding.outlines_logits_processors import ( CFGLogitsProcessor, JSONLogitsProcessor, RegexLogitsProcessor) +from vllm.model_executor.guided_decoding.reasoner import ReasonerConfig from vllm.sampling_params import GuidedDecodingParams @@ -58,7 +59,9 @@ class GuidedDecodingMode(Enum): async def get_outlines_guided_decoding_logits_processor( - guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizerBase + guided_params: GuidedDecodingParams, + tokenizer: PreTrainedTokenizerBase, + reasoner_config: Optional[ReasonerConfig], ) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor, None]: """ @@ -82,11 +85,14 @@ async def get_outlines_guided_decoding_logits_processor( return await loop.run_in_executor(global_thread_pool, _get_logits_processor, guide, tokenizer, - mode, guided_params.whitespace_pattern) + mode, guided_params.whitespace_pattern, + reasoner_config) def get_local_outlines_guided_decoding_logits_processor( - guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizerBase + guided_params: GuidedDecodingParams, + tokenizer: PreTrainedTokenizerBase, + reasoner_config: Optional[ReasonerConfig], ) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor, None]: """ @@ -100,7 +106,8 @@ def get_local_outlines_guided_decoding_logits_processor( return None return _get_logits_processor(guide, tokenizer, mode, - guided_params.whitespace_pattern) + guided_params.whitespace_pattern, + reasoner_config) def _get_guide_and_mode( @@ -131,14 +138,18 @@ def _get_guide_and_mode( def _get_logits_processor( - guide: str, tokenizer: PreTrainedTokenizerBase, mode: GuidedDecodingMode, - whitespace_pattern: Union[str, None] + guide: str, + tokenizer: PreTrainedTokenizerBase, + mode: GuidedDecodingMode, + whitespace_pattern: Union[str, None], + reasoner_config: Optional[ReasonerConfig], ) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor]: if mode == GuidedDecodingMode.JSON: - return JSONLogitsProcessor(guide, tokenizer, whitespace_pattern) + return JSONLogitsProcessor(guide, tokenizer, whitespace_pattern, + reasoner_config) elif mode == GuidedDecodingMode.REGEX or mode == GuidedDecodingMode.CHOICE: - return RegexLogitsProcessor(guide, tokenizer) + return RegexLogitsProcessor(guide, tokenizer, reasoner_config) elif mode == GuidedDecodingMode.GRAMMAR: - return CFGLogitsProcessor(guide, tokenizer) + return CFGLogitsProcessor(guide, tokenizer, reasoner_config) else: raise ValueError(f"Unknown guided decoding mode {mode}") diff --git a/vllm/model_executor/guided_decoding/outlines_logits_processors.py b/vllm/model_executor/guided_decoding/outlines_logits_processors.py index a05267d921d1..99b7ef92b9e4 100644 --- a/vllm/model_executor/guided_decoding/outlines_logits_processors.py +++ b/vllm/model_executor/guided_decoding/outlines_logits_processors.py @@ -19,7 +19,7 @@ import json from collections import defaultdict from functools import lru_cache -from typing import Callable, DefaultDict, 
Dict, List, Union +from typing import Callable, DefaultDict, Dict, List, Optional, Union import numpy as np import torch @@ -32,13 +32,19 @@ from pydantic import BaseModel from transformers import PreTrainedTokenizerBase +from vllm.logger import init_logger +from vllm.model_executor.guided_decoding.reasoner import ReasonerConfig from vllm.platforms import current_platform +logger = init_logger(__name__) + class BaseLogitsProcessor: - def __init__(self, guide: Guide): + def __init__(self, guide: Guide, + reasoner_config: Optional[ReasonerConfig]): self._guide: Guide = guide + self._reasoner_config = reasoner_config # CFGState is used for the FSM state for CFGGuide self._fsm_state: DefaultDict[int, Union[int, CFGState]] = defaultdict(int) @@ -46,6 +52,11 @@ def __init__(self, guide: Guide): def __call__(self, input_ids: List[int], scores: torch.Tensor) -> torch.Tensor: """Use the FSM to bias the logits before sampling the next token.""" + if self._reasoner_config is not None and \ + not self._reasoner_config.is_reasoning_end( + input_ids): + return scores + seq_id = hash(tuple(input_ids)) if len(input_ids) > 0: @@ -113,7 +124,8 @@ def _get_guide(cls, regex_string: str, tokenizer = _adapt_tokenizer(tokenizer) return RegexGuide.from_regex(regex_string, tokenizer) - def __init__(self, regex_string: str, tokenizer: PreTrainedTokenizerBase): + def __init__(self, regex_string: str, tokenizer: PreTrainedTokenizerBase, + reasoner_config: Optional[ReasonerConfig]): """Compile the FSM that drives the regex-structured generation. Parameters @@ -125,14 +137,16 @@ def __init__(self, regex_string: str, tokenizer: PreTrainedTokenizerBase): """ super().__init__( - RegexLogitsProcessor._get_guide(regex_string, tokenizer)) + RegexLogitsProcessor._get_guide(regex_string, tokenizer), + reasoner_config) class JSONLogitsProcessor(RegexLogitsProcessor): def __init__(self, schema: Union[str, Dict, BaseModel], tokenizer: PreTrainedTokenizerBase, - whitespace_pattern: Union[str, None]): + whitespace_pattern: Union[str, None], + reasoner_config: Optional[ReasonerConfig]): """Compile the FSM that drives the JSON-guided generation. Parameters @@ -160,7 +174,7 @@ def __init__(self, schema: Union[str, Dict, BaseModel], f"a Pydantic object, a dictionary or a string that contains " f"the JSON Schema specification") regex_string = build_regex_from_schema(schema_str, whitespace_pattern) - super().__init__(regex_string, tokenizer) + super().__init__(regex_string, tokenizer, reasoner_config) class CFGLogitsProcessor(BaseLogitsProcessor): @@ -171,7 +185,8 @@ def _get_guide(cls, cfg: str, tokenizer: PreTrainedTokenizerBase) -> Guide: tokenizer = _adapt_tokenizer(tokenizer) return CFGGuide(cfg, tokenizer) - def __init__(self, cfg: str, tokenizer: PreTrainedTokenizerBase): + def __init__(self, cfg: str, tokenizer: PreTrainedTokenizerBase, + reasoner_config: Optional[ReasonerConfig]): """Compile the FSM that drives the context free grammar generation. 
 
         Parameters
         ----------
         cfg
             A string that represents a context-free grammar
         tokenizer
             The model's tokenizer
 
         """
-        super().__init__(CFGLogitsProcessor._get_guide(cfg, tokenizer))
+        super().__init__(CFGLogitsProcessor._get_guide(cfg, tokenizer),
+                         reasoner_config)
         self._guide = self._guide.copy()
 
 
diff --git a/vllm/model_executor/guided_decoding/reasoner/__init__.py b/vllm/model_executor/guided_decoding/reasoner/__init__.py
new file mode 100644
index 000000000000..ab89f5eb27b5
--- /dev/null
+++ b/vllm/model_executor/guided_decoding/reasoner/__init__.py
@@ -0,0 +1,19 @@
+# SPDX-License-Identifier: Apache-2.0
+from transformers import PreTrainedTokenizer
+
+from vllm.model_executor.guided_decoding.reasoner.reasoner import (
+    Reasoner, ReasonerConfig)
+
+
+def get_reasoner(reasoning_backend: str,
+                 tokenizer: PreTrainedTokenizer) -> Reasoner:
+    if reasoning_backend == "deepseek_r1":
+        from vllm.model_executor.guided_decoding.reasoner.deepseek_reasoner import (  # noqa
+            DeepSeekReasoner)
+        return DeepSeekReasoner(tokenizer)
+
+    raise ValueError(f"Unknown reasoner '{reasoning_backend}'. "
+                     "Must be one of 'deepseek_r1'")
+
+
+__all__ = ["get_reasoner", "ReasonerConfig", "Reasoner"]
diff --git a/vllm/model_executor/guided_decoding/reasoner/deepseek_reasoner.py b/vllm/model_executor/guided_decoding/reasoner/deepseek_reasoner.py
new file mode 100644
index 000000000000..0a4dc60b93c8
--- /dev/null
+++ b/vllm/model_executor/guided_decoding/reasoner/deepseek_reasoner.py
@@ -0,0 +1,35 @@
+# SPDX-License-Identifier: Apache-2.0
+from transformers import PreTrainedTokenizer
+
+from vllm.model_executor.guided_decoding.reasoner.reasoner import Reasoner
+
+
+class DeepSeekReasoner(Reasoner):
+    _instance = None
+    _start_token_id = None
+    _end_token_id = None
+
+    def __new__(cls, tokenizer: PreTrainedTokenizer):
+        if cls._instance is None:
+            cls._instance = super().__new__(cls)
+        return cls._instance
+
+    def __init__(self, tokenizer: PreTrainedTokenizer):
+        self.tokenizer = tokenizer
+
+        # Initialize token IDs only once
+        if self.__class__._start_token_id is None:
+            self.__class__._start_token_id = tokenizer.encode(
+                "<think>", add_special_tokens=False)[0]
+            self.__class__._end_token_id = tokenizer.encode(
+                "</think>", add_special_tokens=False)[0]
+
+        # Use class variables
+        self.start_token_id = self.__class__._start_token_id
+        self.end_token_id = self.__class__._end_token_id
+
+    def get_start_token_id(self) -> int:
+        return self.start_token_id
+
+    def get_end_token_id(self) -> int:
+        return self.end_token_id
diff --git a/vllm/model_executor/guided_decoding/reasoner/reasoner.py b/vllm/model_executor/guided_decoding/reasoner/reasoner.py
new file mode 100644
index 000000000000..f9214d8d1271
--- /dev/null
+++ b/vllm/model_executor/guided_decoding/reasoner/reasoner.py
@@ -0,0 +1,28 @@
+# SPDX-License-Identifier: Apache-2.0
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+
+
+class Reasoner(ABC):
+
+    @abstractmethod
+    def get_start_token_id(self) -> int:
+        pass
+
+    @abstractmethod
+    def get_end_token_id(self) -> int:
+        pass
+
+
+@dataclass
+class ReasonerConfig:
+    start_token_id: int
+    end_token_id: int
+
+    @classmethod
+    def from_reasoner(cls, reasoner: Reasoner) -> 'ReasonerConfig':
+        return cls(start_token_id=reasoner.get_start_token_id(),
+                   end_token_id=reasoner.get_end_token_id())
+
+    def is_reasoning_end(self, input_ids: list[int]) -> bool:
+        return self.end_token_id in input_ids
diff --git a/vllm/model_executor/guided_decoding/xgrammar_decoding.py
b/vllm/model_executor/guided_decoding/xgrammar_decoding.py index eb9d83acb286..c2a1e757edb9 100644 --- a/vllm/model_executor/guided_decoding/xgrammar_decoding.py +++ b/vllm/model_executor/guided_decoding/xgrammar_decoding.py @@ -11,6 +11,8 @@ import torch from transformers import PreTrainedTokenizerFast +from vllm.logger import init_logger + try: import xgrammar as xgr from xgrammar.base import _core as xgr_core @@ -28,6 +30,7 @@ from transformers import PreTrainedTokenizer from vllm.config import ModelConfig + from vllm.model_executor.guided_decoding.reasoner import ReasonerConfig from vllm.sampling_params import GuidedDecodingParams logger = init_logger(__name__) @@ -38,12 +41,13 @@ def get_local_xgrammar_guided_decoding_logits_processor( guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizer, model_config: ModelConfig, + reasoner_config: ReasonerConfig | None, max_threads: int = 8): config = GrammarConfig.from_guided_params(guided_params=guided_params, model_config=model_config, tokenizer=tokenizer, max_threads=max_threads) - return XGrammarLogitsProcessor(config) + return XGrammarLogitsProcessor(config, reasoner_config) @dataclass(frozen=True) @@ -293,6 +297,7 @@ def choice_as_grammar(choice: List[str] | None) -> str: class XGrammarLogitsProcessor: """Wrapper class to support pickle protocol""" config: GrammarConfig + reasoner_config: ReasonerConfig | None = None ctx: xgr.CompiledGrammar | None = None token_bitmask: torch.Tensor = None # type: ignore[assignment] @@ -301,10 +306,11 @@ class XGrammarLogitsProcessor: prefilled: bool = field(default=False) def __getstate__(self) -> dict[str, Any]: - return {'config': self.config} + return {'config': self.config, 'reasoner_config': self.reasoner_config} def __setstate__(self, state: dict[str, Any]): self.config = state['config'] + self.reasoner_config = state['reasoner_config'] self.ctx = None self.matchers = [] @@ -331,6 +337,12 @@ def _ensure_ctx(self): def __call__(self, input_ids: list[int], scores: torch.Tensor) -> torch.Tensor: + + if self.reasoner_config is not None and \ + not self.reasoner_config.is_reasoning_end( + input_ids): + return scores + if self.ctx is None: self._ensure_ctx() From 7c7d83577ae08cc163f12473634a5dbc5b5005ac Mon Sep 17 00:00:00 2001 From: Ce Gao Date: Wed, 19 Feb 2025 10:27:37 +0800 Subject: [PATCH 02/12] chore: Add docs Signed-off-by: Ce Gao --- docs/source/features/reasoning_outputs.md | 52 +++++++++++++++++++---- 1 file changed, 44 insertions(+), 8 deletions(-) diff --git a/docs/source/features/reasoning_outputs.md b/docs/source/features/reasoning_outputs.md index e39bbacf1138..6a78bdee6ce5 100644 --- a/docs/source/features/reasoning_outputs.md +++ b/docs/source/features/reasoning_outputs.md @@ -76,7 +76,13 @@ Streaming chat completions are also supported for reasoning models. The `reasoni } ``` -Please note that it is not compatible with the OpenAI Python client library. You can use the `requests` library to make streaming requests. +Please note that it is not compatible with the OpenAI Python client library. You can use the `requests` library to make streaming requests. You could checkout the [example](https://github.com/vllm-project/vllm/blob/main/examples/online_serving/openai_chat_completion_with_reasoning_streaming.py). + +## Limitations + +- The reasoning content is only available for online serving's chat completion endpoint (`/v1/chat/completions`). +- It is not compatible with [`tool_calling`](#tool_calling). +- The reasoning content is not available for all models. 
Check the model's documentation to see if it supports reasoning.
 
 ## How to support a new reasoning model
 
@@ -137,15 +143,45 @@ class ExampleParser(ReasoningParser):
     """
 ```
 
-After defining the reasoning parser, you can use it by specifying the `--reasoning-parser` flag when making a request to the chat completion endpoint.
+Additionally, to enable structured output, you'll need to create a new `Reasoner` similar to the one in `vllm/model_executor/guided_decoding/reasoner/deepseek_reasoner.py`.
+
+```python
+class DeepSeekReasoner(Reasoner):
+    _instance = None
+    _start_token_id = None
+    _end_token_id = None
+
+    def __new__(cls, tokenizer: PreTrainedTokenizer):
+        if cls._instance is None:
+            cls._instance = super().__new__(cls)
+        return cls._instance
+
+    def __init__(self, tokenizer: PreTrainedTokenizer):
+        self.tokenizer = tokenizer
+
+        # Initialize token IDs only once
+        if self.__class__._start_token_id is None:
+            self.__class__._start_token_id = tokenizer.encode(
+                "<think>", add_special_tokens=False)[0]
+            self.__class__._end_token_id = tokenizer.encode(
+                "</think>", add_special_tokens=False)[0]
+
+        # Use class variables
+        self.start_token_id = self.__class__._start_token_id
+        self.end_token_id = self.__class__._end_token_id
+
+    def get_start_token_id(self) -> int:
+        return self.start_token_id
+
+    def get_end_token_id(self) -> int:
+        return self.end_token_id
+```
+
+A structured output engine such as xgrammar uses `end_token_id` to check whether the model is still generating reasoning content, and skips structured output enforcement until that end token appears.
+
+Finally, you can enable reasoning for the model by using the `--enable-reasoning` and `--reasoning-parser` flags.
 
 ```bash
 vllm serve <model> \
     --enable-reasoning --reasoning-parser example
 ```
-
-## Limitations
-
-- The reasoning content is only available for online serving's chat completion endpoint (`/v1/chat/completions`).
-- It is not compatible with the [`structured_outputs`](#structured_outputs) and [`tool_calling`](#tool_calling) features.
-- The reasoning content is not available for all models. Check the model's documentation to see if it supports reasoning.
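For readers following the doc change above: the end-of-reasoning gate it describes boils down to a token-membership check performed before any grammar mask is applied. The sketch below is illustrative only and is not part of this patch series; `enforce_grammar` is a hypothetical stand-in for the backend-specific logits biasing done by outlines or xgrammar.

```python
import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B")
# DeepSeek R1 wraps its reasoning in <think>...</think> tokens.
END_THINK_ID = tokenizer.encode("</think>", add_special_tokens=False)[0]


def enforce_grammar(input_ids: list[int],
                    scores: torch.Tensor) -> torch.Tensor:
    # Stand-in for the backend-specific structured-output masking.
    return scores


def gated_processor(input_ids: list[int],
                    scores: torch.Tensor) -> torch.Tensor:
    # While the model is still thinking, leave the logits untouched;
    # only apply the structured-output mask once </think> has appeared.
    if END_THINK_ID not in input_ids:
        return scores
    return enforce_grammar(input_ids, scores)
```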
From 1dc8f577dd0276a1c6a3c078b1ffe3149cb81160 Mon Sep 17 00:00:00 2001 From: Ce Gao Date: Wed, 26 Feb 2025 09:54:02 +0800 Subject: [PATCH 03/12] resolve the conflicts and address comments Signed-off-by: Ce Gao --- .../model_executor/test_guided_processors.py | 14 +++++---- .../guided_decoding/__init__.py | 11 ++++--- .../outlines_logits_processors.py | 3 ++ .../reasoner/deepseek_reasoner.py | 29 ++++++++++++------- .../guided_decoding/xgrammar_decoding.py | 2 ++ 5 files changed, 38 insertions(+), 21 deletions(-) diff --git a/tests/model_executor/test_guided_processors.py b/tests/model_executor/test_guided_processors.py index 30fc8b4b5350..223654f7a5d6 100644 --- a/tests/model_executor/test_guided_processors.py +++ b/tests/model_executor/test_guided_processors.py @@ -18,19 +18,21 @@ GUIDED_DECODING_BACKENDS = ["outlines", "lm-format-enforcer", "xgrammar"] GUIDED_DECODING_BACKENDS_WITH_REASONING_SUPPORT = ["outlines", "xgrammar"] +# Load the tokenizer for the model to speed up the tests +zephyr_7B_tokenzer = AutoTokenizer.from_pretrained(MODEL_NAME) + def test_guided_logits_processors(sample_regex, sample_json_schema): """Basic unit test for RegexLogitsProcessor and JSONLogitsProcessor.""" - tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta') regex_LP = RegexLogitsProcessor(sample_regex, - tokenizer, + zephyr_7B_tokenzer, reasoner_config=None) json_LP = JSONLogitsProcessor(sample_json_schema, - tokenizer, + zephyr_7B_tokenzer, whitespace_pattern=None, reasoner_config=None) - token_ids = tokenizer.encode( + token_ids = zephyr_7B_tokenzer.encode( f"Give an example IPv4 address with this regex: {sample_regex}") tensor = torch.rand(32000) original_tensor = torch.clone(tensor) @@ -38,7 +40,7 @@ def test_guided_logits_processors(sample_regex, sample_json_schema): assert tensor.shape == original_tensor.shape assert not torch.allclose(tensor, original_tensor) - token_ids = tokenizer.encode( + token_ids = zephyr_7B_tokenzer.encode( f"Give an employee profile that fits this schema: {sample_json_schema}" ) tensor = torch.rand(32000) @@ -64,7 +66,7 @@ async def test_guided_logits_processor_black_box(backend: str, is_local: bool, seed=0, dtype="bfloat16", ) - tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) + tokenizer = zephyr_7B_tokenzer token_ids = tokenizer.encode( f"Give an example IPv4 address with this regex: {sample_regex}") regex_request = GuidedDecodingParams(regex=sample_regex, backend=backend) diff --git a/vllm/model_executor/guided_decoding/__init__.py b/vllm/model_executor/guided_decoding/__init__.py index af4b64c2da77..24c1409af3a1 100644 --- a/vllm/model_executor/guided_decoding/__init__.py +++ b/vllm/model_executor/guided_decoding/__init__.py @@ -105,9 +105,10 @@ def fallback_or_error(guided_params: GuidedDecodingParams, message: str, async def get_guided_decoding_logits_processor( - guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizer, + guided_params: GuidedDecodingParams, + tokenizer: PreTrainedTokenizer, model_config: ModelConfig, - reasoning_backend: str | None) -> LogitsProcessor | None: + reasoning_backend: str | None = None) -> LogitsProcessor | None: reasoner_config = None if reasoning_backend is not None: @@ -139,11 +140,13 @@ async def get_guided_decoding_logits_processor( def get_local_guided_decoding_logits_processor( - guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizer, + guided_params: GuidedDecodingParams, + tokenizer: PreTrainedTokenizer, model_config: ModelConfig, - reasoning_backend: str | None) -> 
LogitsProcessor | None:
     guided_params = maybe_backend_fallback(guided_params)
 
+    # Get the reasoner config if a reasoning backend is specified
     reasoner_config = None
     if reasoning_backend is not None:
         reasoner = get_reasoner(reasoning_backend, tokenizer)
         reasoner_config = ReasonerConfig.from_reasoner(reasoner)
diff --git a/vllm/model_executor/guided_decoding/outlines_logits_processors.py b/vllm/model_executor/guided_decoding/outlines_logits_processors.py
index 99b7ef92b9e4..9ccca67ed84c 100644
--- a/vllm/model_executor/guided_decoding/outlines_logits_processors.py
+++ b/vllm/model_executor/guided_decoding/outlines_logits_processors.py
@@ -52,6 +52,9 @@ def __init__(self, guide: Guide,
     def __call__(self, input_ids: List[int],
                  scores: torch.Tensor) -> torch.Tensor:
         """Use the FSM to bias the logits before sampling the next token."""
+
+        # Skip the structured logits processing if reasoning is not finished.
+        # reasoner_config is not None only when `--enable-reasoning` is set.
         if self._reasoner_config is not None and \
             not self._reasoner_config.is_reasoning_end(
                 input_ids):
diff --git a/vllm/model_executor/guided_decoding/reasoner/deepseek_reasoner.py b/vllm/model_executor/guided_decoding/reasoner/deepseek_reasoner.py
index 0a4dc60b93c8..d37375c6c75b 100644
--- a/vllm/model_executor/guided_decoding/reasoner/deepseek_reasoner.py
+++ b/vllm/model_executor/guided_decoding/reasoner/deepseek_reasoner.py
@@ -1,30 +1,37 @@
 # SPDX-License-Identifier: Apache-2.0
+from threading import Lock
+
 from transformers import PreTrainedTokenizer
 
 from vllm.model_executor.guided_decoding.reasoner.reasoner import Reasoner
 
 
 class DeepSeekReasoner(Reasoner):
+    """
+    Reasoner for DeepSeek.
+
+    This class is a singleton and should be instantiated with the tokenizer
+    to ensure that the start and end token IDs are initialized only once.
+    """
     _instance = None
     _start_token_id = None
     _end_token_id = None
+    _lock = Lock()
 
     def __new__(cls, tokenizer: PreTrainedTokenizer):
-        if cls._instance is None:
-            cls._instance = super().__new__(cls)
+        with cls._lock:
+            if cls._instance is None:
+                cls._instance = super().__new__(cls)
+                # Initialize token IDs in __new__
+                cls._start_token_id = tokenizer.encode(
+                    "<think>", add_special_tokens=False)[0]
+                cls._end_token_id = tokenizer.encode(
+                    "</think>", add_special_tokens=False)[0]
         return cls._instance
 
     def __init__(self, tokenizer: PreTrainedTokenizer):
         self.tokenizer = tokenizer
-
-        # Initialize token IDs only once
-        if self.__class__._start_token_id is None:
-            self.__class__._start_token_id = tokenizer.encode(
-                "<think>", add_special_tokens=False)[0]
-            self.__class__._end_token_id = tokenizer.encode(
-                "</think>", add_special_tokens=False)[0]
-
-        # Use class variables
+        # Use class variables to avoid reinitializing the token IDs
         self.start_token_id = self.__class__._start_token_id
         self.end_token_id = self.__class__._end_token_id
 
diff --git a/vllm/model_executor/guided_decoding/xgrammar_decoding.py b/vllm/model_executor/guided_decoding/xgrammar_decoding.py
index c2a1e757edb9..551cbc887fee 100644
--- a/vllm/model_executor/guided_decoding/xgrammar_decoding.py
+++ b/vllm/model_executor/guided_decoding/xgrammar_decoding.py
@@ -338,6 +338,8 @@ def _ensure_ctx(self):
 
     def __call__(self, input_ids: list[int],
                  scores: torch.Tensor) -> torch.Tensor:
+        # Skip the structured logits processing if reasoning is not finished.
+        # reasoner_config is not None only when `--enable-reasoning` is set.
if self.reasoner_config is not None and \ not self.reasoner_config.is_reasoning_end( input_ids): From 1d634b03cbd7c7c1c75a5d093c2ad48b556df7fa Mon Sep 17 00:00:00 2001 From: Ce Gao Date: Wed, 26 Feb 2025 10:15:47 +0800 Subject: [PATCH 04/12] make pre commit happy Signed-off-by: Ce Gao --- .../guided_decoding/reasoner/deepseek_reasoner.py | 4 ++-- .../guided_decoding/reasoner/reasoner.py | 11 +++++++---- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/guided_decoding/reasoner/deepseek_reasoner.py b/vllm/model_executor/guided_decoding/reasoner/deepseek_reasoner.py index d37375c6c75b..49af6b612c74 100644 --- a/vllm/model_executor/guided_decoding/reasoner/deepseek_reasoner.py +++ b/vllm/model_executor/guided_decoding/reasoner/deepseek_reasoner.py @@ -35,8 +35,8 @@ def __init__(self, tokenizer: PreTrainedTokenizer): self.start_token_id = self.__class__._start_token_id self.end_token_id = self.__class__._end_token_id - def get_start_token_id(self) -> int: + def get_start_token_id(self) -> int | None: return self.start_token_id - def get_end_token_id(self) -> int: + def get_end_token_id(self) -> int | None: return self.end_token_id diff --git a/vllm/model_executor/guided_decoding/reasoner/reasoner.py b/vllm/model_executor/guided_decoding/reasoner/reasoner.py index f9214d8d1271..f0fe25313a3b 100644 --- a/vllm/model_executor/guided_decoding/reasoner/reasoner.py +++ b/vllm/model_executor/guided_decoding/reasoner/reasoner.py @@ -6,11 +6,11 @@ class Reasoner(ABC): @abstractmethod - def get_start_token_id(self) -> int: + def get_start_token_id(self) -> int | None: pass @abstractmethod - def get_end_token_id(self) -> int: + def get_end_token_id(self) -> int | None: pass @@ -21,8 +21,11 @@ class ReasonerConfig: @classmethod def from_reasoner(cls, reasoner: Reasoner) -> 'ReasonerConfig': - return cls(start_token_id=reasoner.get_start_token_id(), - end_token_id=reasoner.get_end_token_id()) + if reasoner is None or reasoner.get_start_token_id() is None or \ + reasoner.get_end_token_id() is None: + raise ValueError("The reasoner must have token IDs.") + return cls(start_token_id=int(reasoner.get_start_token_id()), + end_token_id=int(reasoner.get_end_token_id())) def is_reasoning_end(self, input_ids: list[int]) -> bool: return self.end_token_id in input_ids From 2abb04e1b25dc5316e157ce883983d703111c316 Mon Sep 17 00:00:00 2001 From: Ce Gao Date: Wed, 26 Feb 2025 10:53:03 +0800 Subject: [PATCH 05/12] fix test cases Signed-off-by: Ce Gao --- .../model_executor/test_guided_processors.py | 66 +++++++++++-------- 1 file changed, 38 insertions(+), 28 deletions(-) diff --git a/tests/model_executor/test_guided_processors.py b/tests/model_executor/test_guided_processors.py index 223654f7a5d6..b1d7eaa0ab0e 100644 --- a/tests/model_executor/test_guided_processors.py +++ b/tests/model_executor/test_guided_processors.py @@ -17,12 +17,22 @@ MODEL_NAME = 'HuggingFaceH4/zephyr-7b-beta' GUIDED_DECODING_BACKENDS = ["outlines", "lm-format-enforcer", "xgrammar"] GUIDED_DECODING_BACKENDS_WITH_REASONING_SUPPORT = ["outlines", "xgrammar"] +REASONING_MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B" -# Load the tokenizer for the model to speed up the tests -zephyr_7B_tokenzer = AutoTokenizer.from_pretrained(MODEL_NAME) +# Initialize the tokenizer for the model here to avoid repeated loading +@pytest.fixture(scope="module") +def zephyr_7B_tokenzer(): + return AutoTokenizer.from_pretrained(MODEL_NAME) -def test_guided_logits_processors(sample_regex, sample_json_schema): + 
+@pytest.fixture(scope="module") +def deepseek_r1_qwen_tokenizer(): + return AutoTokenizer.from_pretrained(REASONING_MODEL_NAME) + + +def test_guided_logits_processors(zephyr_7B_tokenzer, sample_regex, + sample_json_schema): """Basic unit test for RegexLogitsProcessor and JSONLogitsProcessor.""" regex_LP = RegexLogitsProcessor(sample_regex, zephyr_7B_tokenzer, @@ -55,7 +65,8 @@ def test_guided_logits_processors(sample_regex, sample_json_schema): @pytest.mark.parametrize("is_local", [True, False]) async def test_guided_logits_processor_black_box(backend: str, is_local: bool, sample_regex, - sample_json_schema): + sample_json_schema, + zephyr_7B_tokenzer): config = ModelConfig( MODEL_NAME, @@ -66,15 +77,14 @@ async def test_guided_logits_processor_black_box(backend: str, is_local: bool, seed=0, dtype="bfloat16", ) - tokenizer = zephyr_7B_tokenzer - token_ids = tokenizer.encode( + token_ids = zephyr_7B_tokenzer.encode( f"Give an example IPv4 address with this regex: {sample_regex}") regex_request = GuidedDecodingParams(regex=sample_regex, backend=backend) regex_lp = get_local_guided_decoding_logits_processor( - regex_request, tokenizer, config) if is_local else \ + regex_request, zephyr_7B_tokenzer, config) if is_local else \ await get_guided_decoding_logits_processor( - regex_request, tokenizer, config) + regex_request, zephyr_7B_tokenzer, config) assert regex_lp is not None tensor = torch.rand(32000) original_tensor = torch.clone(tensor) @@ -82,13 +92,13 @@ async def test_guided_logits_processor_black_box(backend: str, is_local: bool, assert tensor.shape == original_tensor.shape assert not torch.allclose(tensor, original_tensor) - token_ids = tokenizer.encode( + token_ids = zephyr_7B_tokenzer.encode( f"Give an employee profile that fits this schema: {sample_json_schema}" ) json_request = GuidedDecodingParams(json=sample_json_schema, backend=backend) json_lp = await get_guided_decoding_logits_processor( - json_request, tokenizer, config) + json_request, zephyr_7B_tokenzer, config) assert json_lp is not None tensor = torch.rand(32000) original_tensor = torch.clone(tensor) @@ -102,32 +112,30 @@ async def test_guided_logits_processor_black_box(backend: str, is_local: bool, GUIDED_DECODING_BACKENDS_WITH_REASONING_SUPPORT) @pytest.mark.parametrize("is_local", [True, False]) @pytest.mark.parametrize("reasoning_backend", ["deepseek_r1"]) -async def test_guided_logits_processor_with_reasoning(backend: str, - is_local: bool, - reasoning_backend: str, - sample_regex, - sample_json_schema): +async def test_guided_logits_processor_with_reasoning( + backend: str, is_local: bool, reasoning_backend: str, sample_regex, + sample_json_schema, deepseek_r1_qwen_tokenizer): - reasoning_model = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B" config = ModelConfig( - reasoning_model, + REASONING_MODEL_NAME, task="generate", - tokenizer=reasoning_model, + tokenizer=REASONING_MODEL_NAME, tokenizer_mode="auto", trust_remote_code=False, seed=0, dtype="bfloat16", ) - tokenizer = AutoTokenizer.from_pretrained(reasoning_model) - token_ids = tokenizer.encode( + token_ids = deepseek_r1_qwen_tokenizer.encode( f"Give an example IPv4 address with this regex: {sample_regex}." 
"here is the thinking process") regex_request = GuidedDecodingParams(regex=sample_regex, backend=backend) regex_lp = get_local_guided_decoding_logits_processor(regex_request, - tokenizer, config, reasoning_backend) if is_local else \ + deepseek_r1_qwen_tokenizer, config, + reasoning_backend) if is_local else \ await get_guided_decoding_logits_processor( - regex_request, tokenizer, config, reasoning_backend) + regex_request, deepseek_r1_qwen_tokenizer, config, + reasoning_backend) assert regex_lp is not None tensor = torch.rand(32000) original_tensor = torch.clone(tensor) @@ -135,15 +143,16 @@ async def test_guided_logits_processor_with_reasoning(backend: str, assert tensor.shape == original_tensor.shape assert torch.allclose(tensor, original_tensor) - token_ids = tokenizer.encode( + token_ids = deepseek_r1_qwen_tokenizer.encode( f"Give an employee profile that fits this schema: {sample_json_schema}." "here is the thinking process") json_request = GuidedDecodingParams(json=sample_json_schema, backend=backend) json_lp = get_local_guided_decoding_logits_processor( - json_request, tokenizer, config, reasoning_backend) if is_local else \ + json_request, deepseek_r1_qwen_tokenizer, config, + reasoning_backend) if is_local else \ await get_guided_decoding_logits_processor( - json_request, tokenizer, config, reasoning_backend) + json_request, deepseek_r1_qwen_tokenizer, config, reasoning_backend) assert json_lp is not None tensor = torch.rand(32000) original_tensor = torch.clone(tensor) @@ -152,15 +161,16 @@ async def test_guided_logits_processor_with_reasoning(backend: str, assert torch.allclose(tensor, original_tensor) # Thinking is over, so the tensor should change. - token_ids = tokenizer.encode( + token_ids = deepseek_r1_qwen_tokenizer.encode( f"Give an employee profile that fits this schema: {sample_json_schema}." 
"here is the thinking process Then") json_request = GuidedDecodingParams(json=sample_json_schema, backend=backend) json_lp = get_local_guided_decoding_logits_processor( - json_request, tokenizer, config, reasoning_backend) if is_local else \ + json_request, deepseek_r1_qwen_tokenizer, config, + reasoning_backend) if is_local else \ await get_guided_decoding_logits_processor( - json_request, tokenizer, config, reasoning_backend) + json_request, deepseek_r1_qwen_tokenizer, config, reasoning_backend) assert json_lp is not None tensor = torch.rand(32000) original_tensor = torch.clone(tensor) From aa5c2186af0703f0c7746c9a8b1d2a79f515d5f1 Mon Sep 17 00:00:00 2001 From: Ce Gao Date: Wed, 26 Feb 2025 13:22:18 +0800 Subject: [PATCH 06/12] refactor: Simplify the code base Signed-off-by: Ce Gao --- .../model_executor/test_guided_processors.py | 4 +- .../guided_decoding/__init__.py | 23 ++++------ .../guided_decoding/outlines_decoding.py | 19 ++++---- .../outlines_logits_processors.py | 26 +++++------ .../guided_decoding/reasoner/__init__.py | 26 ++++++----- .../reasoner/deepseek_reasoner.py | 44 +++++++------------ .../guided_decoding/reasoner/reasoner.py | 26 +++-------- .../guided_decoding/xgrammar_decoding.py | 18 ++++---- 8 files changed, 76 insertions(+), 110 deletions(-) diff --git a/tests/model_executor/test_guided_processors.py b/tests/model_executor/test_guided_processors.py index b1d7eaa0ab0e..531c3a8c13b2 100644 --- a/tests/model_executor/test_guided_processors.py +++ b/tests/model_executor/test_guided_processors.py @@ -36,11 +36,11 @@ def test_guided_logits_processors(zephyr_7B_tokenzer, sample_regex, """Basic unit test for RegexLogitsProcessor and JSONLogitsProcessor.""" regex_LP = RegexLogitsProcessor(sample_regex, zephyr_7B_tokenzer, - reasoner_config=None) + reasoner=None) json_LP = JSONLogitsProcessor(sample_json_schema, zephyr_7B_tokenzer, whitespace_pattern=None, - reasoner_config=None) + reasoner=None) token_ids = zephyr_7B_tokenzer.encode( f"Give an example IPv4 address with this regex: {sample_regex}") diff --git a/vllm/model_executor/guided_decoding/__init__.py b/vllm/model_executor/guided_decoding/__init__.py index 24c1409af3a1..4ae5bac7a1e1 100644 --- a/vllm/model_executor/guided_decoding/__init__.py +++ b/vllm/model_executor/guided_decoding/__init__.py @@ -5,8 +5,7 @@ from typing import TYPE_CHECKING from vllm.logger import init_logger -from vllm.model_executor.guided_decoding.reasoner import (ReasonerConfig, - get_reasoner) +from vllm.model_executor.guided_decoding.reasoner import get_reasoner from vllm.model_executor.guided_decoding.utils import ( convert_lark_to_gbnf, grammar_is_likely_lark, has_lmf_unsupported_json_features, has_xgrammar_unsupported_json_features) @@ -110,10 +109,7 @@ async def get_guided_decoding_logits_processor( model_config: ModelConfig, reasoning_backend: str | None = None) -> LogitsProcessor | None: - reasoner_config = None - if reasoning_backend is not None: - reasoner = get_reasoner(reasoning_backend, tokenizer) - reasoner_config = ReasonerConfig.from_reasoner(reasoner) + reasoner = get_reasoner(tokenizer, reasoning_backend) guided_params = maybe_backend_fallback(guided_params) # CFG grammar not supported by LMFE, so we use outlines instead @@ -122,7 +118,7 @@ async def get_guided_decoding_logits_processor( from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa get_outlines_guided_decoding_logits_processor) return await get_outlines_guided_decoding_logits_processor( - guided_params, tokenizer, reasoner_config) + guided_params, 
     if guided_params.backend == 'lm-format-enforcer':
         from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import (  # noqa
             get_local_lm_format_enforcer_guided_decoding_logits_processor)
         return get_local_lm_format_enforcer_guided_decoding_logits_processor(
@@ -132,7 +128,7 @@ async def get_guided_decoding_logits_processor(
         from vllm.model_executor.guided_decoding.xgrammar_decoding import (  # noqa
             get_local_xgrammar_guided_decoding_logits_processor)
         return get_local_xgrammar_guided_decoding_logits_processor(
-            guided_params, tokenizer, model_config, reasoner_config)
+            guided_params, tokenizer, model_config, reasoner)
 
     raise ValueError(
         f"Unknown guided decoding backend '{guided_params.backend}'. "
@@ -146,11 +142,8 @@ def get_local_guided_decoding_logits_processor(
     guided_params: GuidedDecodingParams,
     tokenizer: PreTrainedTokenizer,
     model_config: ModelConfig,
     reasoning_backend: str | None = None) -> LogitsProcessor | None:
     guided_params = maybe_backend_fallback(guided_params)
 
-    # Get the reasoner config if a reasoning backend is specified
-    reasoner_config = None
-    if reasoning_backend is not None:
-        reasoner = get_reasoner(reasoning_backend, tokenizer)
-        reasoner_config = ReasonerConfig.from_reasoner(reasoner)
+    # Get the reasoner if needed, it will be None if reasoning_backend is None
+    reasoner = get_reasoner(tokenizer, reasoning_backend)
 
     # CFG grammar not supported by LMFE, so we use outlines instead
     if guided_params.backend_name == 'outlines':
         # NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
         from vllm.model_executor.guided_decoding.outlines_decoding import (  # noqa
             get_local_outlines_guided_decoding_logits_processor)
         return get_local_outlines_guided_decoding_logits_processor(
-            guided_params, tokenizer, reasoner_config)
+            guided_params, tokenizer, reasoner)
     if guided_params.backend == 'lm-format-enforcer':
         from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import (  # noqa
             get_local_lm_format_enforcer_guided_decoding_logits_processor)
         return get_local_lm_format_enforcer_guided_decoding_logits_processor(
@@ -168,7 +161,7 @@ def get_local_guided_decoding_logits_processor(
         from vllm.model_executor.guided_decoding.xgrammar_decoding import (  # noqa
             get_local_xgrammar_guided_decoding_logits_processor)
         return get_local_xgrammar_guided_decoding_logits_processor(
-            guided_params, tokenizer, model_config, reasoner_config)
+            guided_params, tokenizer, model_config, reasoner)
 
     raise ValueError(
         f"Unknown guided decoding backend '{guided_params.backend}'. 
" diff --git a/vllm/model_executor/guided_decoding/outlines_decoding.py b/vllm/model_executor/guided_decoding/outlines_decoding.py index 8df1f6053dca..97f63ae11f45 100644 --- a/vllm/model_executor/guided_decoding/outlines_decoding.py +++ b/vllm/model_executor/guided_decoding/outlines_decoding.py @@ -12,7 +12,7 @@ from vllm.model_executor.guided_decoding.outlines_logits_processors import ( CFGLogitsProcessor, JSONLogitsProcessor, RegexLogitsProcessor) -from vllm.model_executor.guided_decoding.reasoner import ReasonerConfig +from vllm.model_executor.guided_decoding.reasoner import Reasoner from vllm.sampling_params import GuidedDecodingParams @@ -61,7 +61,7 @@ class GuidedDecodingMode(Enum): async def get_outlines_guided_decoding_logits_processor( guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizerBase, - reasoner_config: Optional[ReasonerConfig], + reasoner: Optional[Reasoner], ) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor, None]: """ @@ -86,13 +86,13 @@ async def get_outlines_guided_decoding_logits_processor( return await loop.run_in_executor(global_thread_pool, _get_logits_processor, guide, tokenizer, mode, guided_params.whitespace_pattern, - reasoner_config) + reasoner) def get_local_outlines_guided_decoding_logits_processor( guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizerBase, - reasoner_config: Optional[ReasonerConfig], + reasoner: Optional[Reasoner], ) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor, None]: """ @@ -106,8 +106,7 @@ def get_local_outlines_guided_decoding_logits_processor( return None return _get_logits_processor(guide, tokenizer, mode, - guided_params.whitespace_pattern, - reasoner_config) + guided_params.whitespace_pattern, reasoner) def _get_guide_and_mode( @@ -142,14 +141,14 @@ def _get_logits_processor( tokenizer: PreTrainedTokenizerBase, mode: GuidedDecodingMode, whitespace_pattern: Union[str, None], - reasoner_config: Optional[ReasonerConfig], + reasoner: Optional[Reasoner], ) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor]: if mode == GuidedDecodingMode.JSON: return JSONLogitsProcessor(guide, tokenizer, whitespace_pattern, - reasoner_config) + reasoner) elif mode == GuidedDecodingMode.REGEX or mode == GuidedDecodingMode.CHOICE: - return RegexLogitsProcessor(guide, tokenizer, reasoner_config) + return RegexLogitsProcessor(guide, tokenizer, reasoner) elif mode == GuidedDecodingMode.GRAMMAR: - return CFGLogitsProcessor(guide, tokenizer, reasoner_config) + return CFGLogitsProcessor(guide, tokenizer, reasoner) else: raise ValueError(f"Unknown guided decoding mode {mode}") diff --git a/vllm/model_executor/guided_decoding/outlines_logits_processors.py b/vllm/model_executor/guided_decoding/outlines_logits_processors.py index 9ccca67ed84c..bd7ae7d2e4a0 100644 --- a/vllm/model_executor/guided_decoding/outlines_logits_processors.py +++ b/vllm/model_executor/guided_decoding/outlines_logits_processors.py @@ -33,7 +33,7 @@ from transformers import PreTrainedTokenizerBase from vllm.logger import init_logger -from vllm.model_executor.guided_decoding.reasoner import ReasonerConfig +from vllm.model_executor.guided_decoding.reasoner import Reasoner from vllm.platforms import current_platform logger = init_logger(__name__) @@ -41,10 +41,9 @@ class BaseLogitsProcessor: - def __init__(self, guide: Guide, - reasoner_config: Optional[ReasonerConfig]): + def __init__(self, guide: Guide, reasoner: Optional[Reasoner]): self._guide: Guide = guide - self._reasoner_config = 
reasoner_config + self._reasoner = reasoner # CFGState is used for the FSM state for CFGGuide self._fsm_state: DefaultDict[int, Union[int, CFGState]] = defaultdict(int) @@ -54,9 +53,9 @@ def __call__(self, input_ids: List[int], """Use the FSM to bias the logits before sampling the next token.""" # Skip the structured logits processing if reasoning is not finished. - # reasoner_config is not None only when `--enable-reasoning` is set. - if self._reasoner_config is not None and \ - not self._reasoner_config.is_reasoning_end( + # reasoner is not None only when `--enable-reasoning` is set. + if self._reasoner is not None and \ + not self._reasoner.is_reasoning_end( input_ids): return scores @@ -128,7 +127,7 @@ def _get_guide(cls, regex_string: str, return RegexGuide.from_regex(regex_string, tokenizer) def __init__(self, regex_string: str, tokenizer: PreTrainedTokenizerBase, - reasoner_config: Optional[ReasonerConfig]): + reasoner: Optional[Reasoner]): """Compile the FSM that drives the regex-structured generation. Parameters @@ -140,8 +139,7 @@ def __init__(self, regex_string: str, tokenizer: PreTrainedTokenizerBase, """ super().__init__( - RegexLogitsProcessor._get_guide(regex_string, tokenizer), - reasoner_config) + RegexLogitsProcessor._get_guide(regex_string, tokenizer), reasoner) class JSONLogitsProcessor(RegexLogitsProcessor): @@ -149,7 +147,7 @@ class JSONLogitsProcessor(RegexLogitsProcessor): def __init__(self, schema: Union[str, Dict, BaseModel], tokenizer: PreTrainedTokenizerBase, whitespace_pattern: Union[str, None], - reasoner_config: Optional[ReasonerConfig]): + reasoner: Optional[Reasoner]): """Compile the FSM that drives the JSON-guided generation. Parameters @@ -177,7 +175,7 @@ def __init__(self, schema: Union[str, Dict, BaseModel], f"a Pydantic object, a dictionary or a string that contains " f"the JSON Schema specification") regex_string = build_regex_from_schema(schema_str, whitespace_pattern) - super().__init__(regex_string, tokenizer, reasoner_config) + super().__init__(regex_string, tokenizer, reasoner) class CFGLogitsProcessor(BaseLogitsProcessor): @@ -189,7 +187,7 @@ def _get_guide(cls, cfg: str, tokenizer: PreTrainedTokenizerBase) -> Guide: return CFGGuide(cfg, tokenizer) def __init__(self, cfg: str, tokenizer: PreTrainedTokenizerBase, - reasoner_config: Optional[ReasonerConfig]): + reasoner: Optional[Reasoner]): """Compile the FSM that drives the context free grammar generation. 
Parameters @@ -201,7 +199,7 @@ def __init__(self, cfg: str, tokenizer: PreTrainedTokenizerBase, """ super().__init__(CFGLogitsProcessor._get_guide(cfg, tokenizer), - reasoner_config) + reasoner) self._guide = self._guide.copy() diff --git a/vllm/model_executor/guided_decoding/reasoner/__init__.py b/vllm/model_executor/guided_decoding/reasoner/__init__.py index ab89f5eb27b5..65a3892cd761 100644 --- a/vllm/model_executor/guided_decoding/reasoner/__init__.py +++ b/vllm/model_executor/guided_decoding/reasoner/__init__.py @@ -1,19 +1,21 @@ # SPDX-License-Identifier: Apache-2.0 -from transformers import PreTrainedTokenizer -from vllm.model_executor.guided_decoding.reasoner.reasoner import ( - Reasoner, ReasonerConfig) +from transformers import PreTrainedTokenizer +from vllm.model_executor.guided_decoding.reasoner.deepseek_reasoner import ( # noqa: E501 + DeepSeekReasoner) +from vllm.model_executor.guided_decoding.reasoner.reasoner import Reasoner -def get_reasoner(reasoning_backend: str, - tokenizer: PreTrainedTokenizer) -> Reasoner: - if reasoning_backend == "deepseek_r1": - from vllm.model_executor.guided_decoding.reasoner.deepseek_reasoner import ( # noqa - DeepSeekReasoner) - return DeepSeekReasoner(tokenizer) - raise ValueError(f"Unknown reasoner '{reasoning_backend}'. " - "Must be one of 'deepseek'") +def get_reasoner(tokenizer: PreTrainedTokenizer, + reasoning_backend: str | None) -> Reasoner | None: + if reasoning_backend is None: + # No reasoning backend specified + return None + elif reasoning_backend == "deepseek_r1": + return DeepSeekReasoner.from_tokenizer(tokenizer) + else: + raise ValueError(f"Unknown reasoning backend '{reasoning_backend}'") -__all__ = ["get_reasoner", "ReasonerConfig", "Reasoner"] +__all__ = ["Reasoner", "get_reasoner"] diff --git a/vllm/model_executor/guided_decoding/reasoner/deepseek_reasoner.py b/vllm/model_executor/guided_decoding/reasoner/deepseek_reasoner.py index 49af6b612c74..e762fb0659de 100644 --- a/vllm/model_executor/guided_decoding/reasoner/deepseek_reasoner.py +++ b/vllm/model_executor/guided_decoding/reasoner/deepseek_reasoner.py @@ -1,42 +1,28 @@ # SPDX-License-Identifier: Apache-2.0 -from threading import Lock +from dataclasses import dataclass from transformers import PreTrainedTokenizer from vllm.model_executor.guided_decoding.reasoner.reasoner import Reasoner +@dataclass class DeepSeekReasoner(Reasoner): """ - Reasoner for DeepSeek. - - This class is a singleton and should be instantiated with the tokenizer - to ensure that the start and end token IDs are initialized only once. + Reasoner for DeepSeek R series models. 
""" - _instance = None - _start_token_id = None - _end_token_id = None - _lock = Lock() - - def __new__(cls, tokenizer: PreTrainedTokenizer): - with cls._lock: - if cls._instance is None: - cls._instance = super().__new__(cls) - # Initialize token IDs in __new__ - cls._start_token_id = tokenizer.encode( - "", add_special_tokens=False)[0] - cls._end_token_id = tokenizer.encode( - "", add_special_tokens=False)[0] - return cls._instance + start_token_id: int + end_token_id: int - def __init__(self, tokenizer: PreTrainedTokenizer): - self.tokenizer = tokenizer - # Use class variables to avoid reinitializing the token IDs - self.start_token_id = self.__class__._start_token_id - self.end_token_id = self.__class__._end_token_id + start_token: str = "" + end_token: str = "" - def get_start_token_id(self) -> int | None: - return self.start_token_id + @classmethod + def from_tokenizer(cls, tokenizer: PreTrainedTokenizer) -> Reasoner: + return cls(start_token_id=tokenizer.encode( + "", add_special_tokens=False)[0], + end_token_id=tokenizer.encode("", + add_special_tokens=False)[0]) - def get_end_token_id(self) -> int | None: - return self.end_token_id + def is_reasoning_end(self, input_ids: list[int]) -> bool: + return self.end_token_id in input_ids diff --git a/vllm/model_executor/guided_decoding/reasoner/reasoner.py b/vllm/model_executor/guided_decoding/reasoner/reasoner.py index f0fe25313a3b..5db0c9bc7850 100644 --- a/vllm/model_executor/guided_decoding/reasoner/reasoner.py +++ b/vllm/model_executor/guided_decoding/reasoner/reasoner.py @@ -1,31 +1,19 @@ # SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + from abc import ABC, abstractmethod from dataclasses import dataclass +from transformers import PreTrainedTokenizer + +@dataclass class Reasoner(ABC): @abstractmethod - def get_start_token_id(self) -> int | None: + def from_tokenizer(cls, tokenizer: PreTrainedTokenizer) -> Reasoner: pass @abstractmethod - def get_end_token_id(self) -> int | None: - pass - - -@dataclass -class ReasonerConfig: - start_token_id: int - end_token_id: int - - @classmethod - def from_reasoner(cls, reasoner: Reasoner) -> 'ReasonerConfig': - if reasoner is None or reasoner.get_start_token_id() is None or \ - reasoner.get_end_token_id() is None: - raise ValueError("The reasoner must have token IDs.") - return cls(start_token_id=int(reasoner.get_start_token_id()), - end_token_id=int(reasoner.get_end_token_id())) - def is_reasoning_end(self, input_ids: list[int]) -> bool: - return self.end_token_id in input_ids + pass diff --git a/vllm/model_executor/guided_decoding/xgrammar_decoding.py b/vllm/model_executor/guided_decoding/xgrammar_decoding.py index 551cbc887fee..3dc525820b39 100644 --- a/vllm/model_executor/guided_decoding/xgrammar_decoding.py +++ b/vllm/model_executor/guided_decoding/xgrammar_decoding.py @@ -30,7 +30,7 @@ from transformers import PreTrainedTokenizer from vllm.config import ModelConfig - from vllm.model_executor.guided_decoding.reasoner import ReasonerConfig + from vllm.model_executor.guided_decoding.reasoner import Reasoner from vllm.sampling_params import GuidedDecodingParams logger = init_logger(__name__) @@ -41,13 +41,13 @@ def get_local_xgrammar_guided_decoding_logits_processor( guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizer, model_config: ModelConfig, - reasoner_config: ReasonerConfig | None, + reasoner: Reasoner | None, max_threads: int = 8): config = GrammarConfig.from_guided_params(guided_params=guided_params, model_config=model_config, 
tokenizer=tokenizer, max_threads=max_threads) - return XGrammarLogitsProcessor(config, reasoner_config) + return XGrammarLogitsProcessor(config, reasoner) @dataclass(frozen=True) @@ -297,7 +297,7 @@ def choice_as_grammar(choice: List[str] | None) -> str: class XGrammarLogitsProcessor: """Wrapper class to support pickle protocol""" config: GrammarConfig - reasoner_config: ReasonerConfig | None = None + reasoner: Reasoner | None = None ctx: xgr.CompiledGrammar | None = None token_bitmask: torch.Tensor = None # type: ignore[assignment] @@ -306,11 +306,11 @@ class XGrammarLogitsProcessor: prefilled: bool = field(default=False) def __getstate__(self) -> dict[str, Any]: - return {'config': self.config, 'reasoner_config': self.reasoner_config} + return {'config': self.config, 'reasoner': self.reasoner} def __setstate__(self, state: dict[str, Any]): self.config = state['config'] - self.reasoner_config = state['reasoner_config'] + self.reasoner = state['reasoner'] self.ctx = None self.matchers = [] @@ -339,9 +339,9 @@ def __call__(self, input_ids: list[int], scores: torch.Tensor) -> torch.Tensor: # Skip the structured logits processing if reasoning is not finished. - # reasoner_config is not None only when `--enable-reasoning` is set. - if self.reasoner_config is not None and \ - not self.reasoner_config.is_reasoning_end( + # reasoner is not None only when `--enable-reasoning` is set. + if self.reasoner is not None and \ + not self.reasoner.is_reasoning_end( input_ids): return scores From c04291fb51fa7fe0d153d2b94840dfbef1508546 Mon Sep 17 00:00:00 2001 From: Ce Gao Date: Wed, 26 Feb 2025 15:06:38 +0800 Subject: [PATCH 07/12] address comments Signed-off-by: Ce Gao --- vllm/model_executor/guided_decoding/reasoner/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/model_executor/guided_decoding/reasoner/__init__.py b/vllm/model_executor/guided_decoding/reasoner/__init__.py index 65a3892cd761..5a91f791d45b 100644 --- a/vllm/model_executor/guided_decoding/reasoner/__init__.py +++ b/vllm/model_executor/guided_decoding/reasoner/__init__.py @@ -1,5 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + from transformers import PreTrainedTokenizer from vllm.model_executor.guided_decoding.reasoner.deepseek_reasoner import ( # noqa: E501 From 807be61034b8918ea9264c250608e94fa23da7f9 Mon Sep 17 00:00:00 2001 From: Ce Gao Date: Thu, 27 Feb 2025 09:35:46 +0800 Subject: [PATCH 08/12] fix docs Signed-off-by: Ce Gao --- docs/source/features/reasoning_outputs.md | 47 +++++++++-------------- 1 file changed, 19 insertions(+), 28 deletions(-) diff --git a/docs/source/features/reasoning_outputs.md b/docs/source/features/reasoning_outputs.md index 6a78bdee6ce5..5c0c1762f8aa 100644 --- a/docs/source/features/reasoning_outputs.md +++ b/docs/source/features/reasoning_outputs.md @@ -146,35 +146,26 @@ class ExampleParser(ReasoningParser): Additionally, to enable structured output, you'll need to create a new `Reasoner` similar to the one in `vllm/model_executor/guided_decoding/reasoner/deepseek_reasoner.py`. 
```python +@dataclass class DeepSeekReasoner(Reasoner): - _instance = None - _start_token_id = None - _end_token_id = None - - def __new__(cls, tokenizer: PreTrainedTokenizer): - if cls._instance is None: - cls._instance = super().__new__(cls) - return cls._instance - - def __init__(self, tokenizer: PreTrainedTokenizer): - self.tokenizer = tokenizer - - # Initialize token IDs only once - if self.__class__._start_token_id is None: - self.__class__._start_token_id = tokenizer.encode( - "<think>", add_special_tokens=False)[0] - self.__class__._end_token_id = tokenizer.encode( - "</think>", add_special_tokens=False)[0] - - # Use class variables - self.start_token_id = self.__class__._start_token_id - self.end_token_id = self.__class__._end_token_id - - def get_start_token_id(self) -> int: - return self.start_token_id - - def get_end_token_id(self) -> int: - return self.end_token_id + """ + Reasoner for DeepSeek R series models. + """ + start_token_id: int + end_token_id: int + + start_token: str = "<think>" + end_token: str = "</think>" + + @classmethod + def from_tokenizer(cls, tokenizer: PreTrainedTokenizer) -> Reasoner: + return cls(start_token_id=tokenizer.encode( + "<think>", add_special_tokens=False)[0], + end_token_id=tokenizer.encode("</think>", + add_special_tokens=False)[0]) + + def is_reasoning_end(self, input_ids: list[int]) -> bool: + return self.end_token_id in input_ids ``` Structured output engines such as xgrammar use `end_token_id` to detect whether the model is still producing reasoning content, and skip structured output enforcement until the reasoning section has ended. From 2c40d3a91ea47096b65f1da370c7d749e9b8ce13 Mon Sep 17 00:00:00 2001 From: Ce Gao Date: Thu, 27 Feb 2025 09:41:41 +0800 Subject: [PATCH 09/12] fix import Signed-off-by: Ce Gao --- vllm/model_executor/guided_decoding/xgrammar_decoding.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/model_executor/guided_decoding/xgrammar_decoding.py b/vllm/model_executor/guided_decoding/xgrammar_decoding.py index 3dc525820b39..ce278c15ab3b 100644 --- a/vllm/model_executor/guided_decoding/xgrammar_decoding.py +++ b/vllm/model_executor/guided_decoding/xgrammar_decoding.py @@ -21,7 +21,6 @@ xgr_installed = False pass -from vllm.logger import init_logger from vllm.model_executor.guided_decoding.utils import (convert_lark_to_gbnf, grammar_is_likely_lark) from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer From 732c583dd140d04cc9afc70fca43b23f4506afb6 Mon Sep 17 00:00:00 2001 From: Ce Gao Date: Fri, 28 Feb 2025 09:16:39 +0800 Subject: [PATCH 10/12] fix Signed-off-by: Ce Gao --- vllm/model_executor/guided_decoding/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/guided_decoding/__init__.py b/vllm/model_executor/guided_decoding/__init__.py index 4ae5bac7a1e1..86f6f0e5f907 100644 --- a/vllm/model_executor/guided_decoding/__init__.py +++ b/vllm/model_executor/guided_decoding/__init__.py @@ -152,7 +152,7 @@ def get_local_guided_decoding_logits_processor( get_local_outlines_guided_decoding_logits_processor) return get_local_outlines_guided_decoding_logits_processor( guided_params, tokenizer, reasoner) - if guided_params.backend == 'lm-format-enforcer': + if guided_params.backend_name == 'lm-format-enforcer': from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa get_local_lm_format_enforcer_guided_decoding_logits_processor) return get_local_lm_format_enforcer_guided_decoding_logits_processor( From 17969af402c08984d60753c866052ecdd1fa7fdf Mon Sep 17 00:00:00 2001 From: Ce Gao Date: Sat, 1 Mar 2025 
09:25:01 +0800 Subject: [PATCH 11/12] fix: Fix log Signed-off-by: Ce Gao --- vllm/engine/async_llm_engine.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 8da44c44ead5..90e66b005f39 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -546,10 +546,11 @@ async def build_guided_decoding_logits_processor_async( sampling_params = copy.copy(sampling_params) guided_decoding = sampling_params.guided_decoding - logger.debug( + logger.info( "Building guided decoding logits processor. " - "guided_decoding: %s, reasoning: %s", guided_decoding, - reasoning_backend) + "guided_decoding: %s%s", guided_decoding, + f", reasoning_backend: {reasoning_backend}" + if reasoning_backend is not None else "") guided_decoding.backend = guided_decoding.backend or default_guided_backend From f895fa2435aba1a9d2a161814951856146cd78da Mon Sep 17 00:00:00 2001 From: Ce Gao Date: Sun, 2 Mar 2025 13:07:27 +0800 Subject: [PATCH 12/12] fix format Signed-off-by: Ce Gao --- .../guided_decoding/outlines_logits_processors.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/guided_decoding/outlines_logits_processors.py b/vllm/model_executor/guided_decoding/outlines_logits_processors.py index bd7ae7d2e4a0..db5d738f42e4 100644 --- a/vllm/model_executor/guided_decoding/outlines_logits_processors.py +++ b/vllm/model_executor/guided_decoding/outlines_logits_processors.py @@ -126,8 +126,12 @@ def _get_guide(cls, regex_string: str, tokenizer = _adapt_tokenizer(tokenizer) return RegexGuide.from_regex(regex_string, tokenizer) - def __init__(self, regex_string: str, tokenizer: PreTrainedTokenizerBase, - reasoner: Optional[Reasoner]): + def __init__( + self, + regex_string: str, + tokenizer: PreTrainedTokenizerBase, + reasoner: Optional[Reasoner], + ): """Compile the FSM that drives the regex-structured generation. Parameters
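For anyone building on this series to support another reasoning model, a minimal sketch of a custom reasoner follows. It implements the `Reasoner` ABC introduced in these patches; the `ExampleReasoner` name and the `<reasoning>`/`</reasoning>` delimiters are illustrative assumptions for demonstration, not part of the patch set or of any shipped model.

```python
# Illustrative sketch only: the class name and delimiter strings below are
# assumptions, not part of this patch series.
from dataclasses import dataclass

from transformers import PreTrainedTokenizer

from vllm.model_executor.guided_decoding.reasoner.reasoner import Reasoner


@dataclass
class ExampleReasoner(Reasoner):
    """Reasoner for a hypothetical model that wraps its chain of thought
    in <reasoning>...</reasoning> markers (assumed delimiters)."""

    start_token_id: int
    end_token_id: int

    start_token: str = "<reasoning>"
    end_token: str = "</reasoning>"

    @classmethod
    def from_tokenizer(cls, tokenizer: PreTrainedTokenizer) -> "ExampleReasoner":
        # Resolve the delimiter token IDs once, mirroring
        # DeepSeekReasoner.from_tokenizer.
        return cls(
            start_token_id=tokenizer.encode("<reasoning>",
                                            add_special_tokens=False)[0],
            end_token_id=tokenizer.encode("</reasoning>",
                                          add_special_tokens=False)[0],
        )

    def is_reasoning_end(self, input_ids: list[int]) -> bool:
        # Logits processors call this on every decoding step; structured
        # output stays disabled until the closing delimiter has appeared.
        return self.end_token_id in input_ids
```

Wiring it in would then be a one-branch change to `get_reasoner` in `vllm/model_executor/guided_decoding/reasoner/__init__.py`, dispatching on a new (hypothetical) reasoning backend name alongside `deepseek_r1`.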