diff --git a/docs/source/features/reasoning_outputs.md b/docs/source/features/reasoning_outputs.md
index e39bbacf1138..5c0c1762f8aa 100644
--- a/docs/source/features/reasoning_outputs.md
+++ b/docs/source/features/reasoning_outputs.md
@@ -76,7 +76,13 @@ Streaming chat completions are also supported for reasoning models. The `reasoni
 }
 ```
 
-Please note that it is not compatible with the OpenAI Python client library. You can use the `requests` library to make streaming requests.
+Please note that it is not compatible with the OpenAI Python client library. You can use the `requests` library to make streaming requests. You can check out the [example](https://github.com/vllm-project/vllm/blob/main/examples/online_serving/openai_chat_completion_with_reasoning_streaming.py).
+
+## Limitations
+
+- The reasoning content is only available for online serving's chat completion endpoint (`/v1/chat/completions`).
+- It is not compatible with [`tool_calling`](#tool_calling).
+- The reasoning content is not available for all models. Check the model's documentation to see if it supports reasoning.
 
 ## How to support a new reasoning model
 
@@ -137,15 +143,36 @@ class ExampleParser(ReasoningParser):
         """
 ```
 
-After defining the reasoning parser, you can use it by specifying the `--reasoning-parser` flag when making a request to the chat completion endpoint.
+Additionally, to enable structured output, you'll need to create a new `Reasoner` similar to the one in `vllm/model_executor/guided_decoding/reasoner/deepseek_reasoner.py`.
+
+```python
+@dataclass
+class DeepSeekReasoner(Reasoner):
+    """
+    Reasoner for DeepSeek R series models.
+    """
+    start_token_id: int
+    end_token_id: int
+
+    start_token: str = "<think>"
+    end_token: str = "</think>"
+
+    @classmethod
+    def from_tokenizer(cls, tokenizer: PreTrainedTokenizer) -> Reasoner:
+        return cls(start_token_id=tokenizer.encode(
+            "<think>", add_special_tokens=False)[0],
+                   end_token_id=tokenizer.encode("</think>",
+                                                 add_special_tokens=False)[0])
+
+    def is_reasoning_end(self, input_ids: list[int]) -> bool:
+        return self.end_token_id in input_ids
+```
+
+A structured output engine such as xgrammar uses `end_token_id` to check whether the model is still producing reasoning content, and skips structured output while that is the case.
+
+Finally, you can enable reasoning for the model by using the `--enable-reasoning` and `--reasoning-parser` flags.
 
 ```bash
 vllm serve <model_name> \
     --enable-reasoning --reasoning-parser example
 ```
-
-## Limitations
-
-- The reasoning content is only available for online serving's chat completion endpoint (`/v1/chat/completions`).
-- It is not compatible with the [`structured_outputs`](#structured_outputs) and [`tool_calling`](#tool_calling) features.
-- The reasoning content is not available for all models. Check the model's documentation to see if it supports reasoning.
diff --git a/examples/online_serving/openai_chat_completion_structured_outputs_with_reasoning.py b/examples/online_serving/openai_chat_completion_structured_outputs_with_reasoning.py
new file mode 100644
index 000000000000..1f72e1164d42
--- /dev/null
+++ b/examples/online_serving/openai_chat_completion_structured_outputs_with_reasoning.py
@@ -0,0 +1,64 @@
+# SPDX-License-Identifier: Apache-2.0
+"""
+This example shows how to generate structured outputs from reasoning models
+like DeepSeek-R1. The thinking process will not be guided by the JSON
+schema provided by the user. Only the final output will be structured.
+ +To run this example, you need to start the vLLM server with the reasoning +parser: + +```bash +vllm serve deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B \ + --enable-reasoning --reasoning-parser deepseek_r1 +``` + +This example demonstrates how to generate chat completions from reasoning models +using the OpenAI Python client library. +""" + +from enum import Enum + +from openai import OpenAI +from pydantic import BaseModel + +# Modify OpenAI's API key and API base to use vLLM's API server. +openai_api_key = "EMPTY" +openai_api_base = "http://localhost:8000/v1" + +client = OpenAI( + api_key=openai_api_key, + base_url=openai_api_base, +) + +models = client.models.list() +model = models.data[0].id + + +# Guided decoding by JSON using Pydantic schema +class CarType(str, Enum): + sedan = "sedan" + suv = "SUV" + truck = "Truck" + coupe = "Coupe" + + +class CarDescription(BaseModel): + brand: str + model: str + car_type: CarType + + +json_schema = CarDescription.model_json_schema() + +prompt = ("Generate a JSON with the brand, model and car_type of" + "the most iconic car from the 90's, think in 100 tokens") +completion = client.chat.completions.create( + model=model, + messages=[{ + "role": "user", + "content": prompt, + }], + extra_body={"guided_json": json_schema}, +) +print("content", completion.choices[0].message.content) +print("reasoning_content: ", completion.choices[0].message.reasoning_content) diff --git a/tests/model_executor/test_guided_processors.py b/tests/model_executor/test_guided_processors.py index be544698fa03..531c3a8c13b2 100644 --- a/tests/model_executor/test_guided_processors.py +++ b/tests/model_executor/test_guided_processors.py @@ -16,17 +16,33 @@ MODEL_NAME = 'HuggingFaceH4/zephyr-7b-beta' GUIDED_DECODING_BACKENDS = ["outlines", "lm-format-enforcer", "xgrammar"] +GUIDED_DECODING_BACKENDS_WITH_REASONING_SUPPORT = ["outlines", "xgrammar"] +REASONING_MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B" -def test_guided_logits_processors(sample_regex, sample_json_schema): +# Initialize the tokenizer for the model here to avoid repeated loading +@pytest.fixture(scope="module") +def zephyr_7B_tokenzer(): + return AutoTokenizer.from_pretrained(MODEL_NAME) + + +@pytest.fixture(scope="module") +def deepseek_r1_qwen_tokenizer(): + return AutoTokenizer.from_pretrained(REASONING_MODEL_NAME) + + +def test_guided_logits_processors(zephyr_7B_tokenzer, sample_regex, + sample_json_schema): """Basic unit test for RegexLogitsProcessor and JSONLogitsProcessor.""" - tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta') - regex_LP = RegexLogitsProcessor(sample_regex, tokenizer) + regex_LP = RegexLogitsProcessor(sample_regex, + zephyr_7B_tokenzer, + reasoner=None) json_LP = JSONLogitsProcessor(sample_json_schema, - tokenizer, - whitespace_pattern=None) + zephyr_7B_tokenzer, + whitespace_pattern=None, + reasoner=None) - token_ids = tokenizer.encode( + token_ids = zephyr_7B_tokenzer.encode( f"Give an example IPv4 address with this regex: {sample_regex}") tensor = torch.rand(32000) original_tensor = torch.clone(tensor) @@ -34,7 +50,7 @@ def test_guided_logits_processors(sample_regex, sample_json_schema): assert tensor.shape == original_tensor.shape assert not torch.allclose(tensor, original_tensor) - token_ids = tokenizer.encode( + token_ids = zephyr_7B_tokenzer.encode( f"Give an employee profile that fits this schema: {sample_json_schema}" ) tensor = torch.rand(32000) @@ -49,7 +65,8 @@ def test_guided_logits_processors(sample_regex, sample_json_schema): 
@pytest.mark.parametrize("is_local", [True, False]) async def test_guided_logits_processor_black_box(backend: str, is_local: bool, sample_regex, - sample_json_schema): + sample_json_schema, + zephyr_7B_tokenzer): config = ModelConfig( MODEL_NAME, @@ -60,15 +77,14 @@ async def test_guided_logits_processor_black_box(backend: str, is_local: bool, seed=0, dtype="bfloat16", ) - tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) - token_ids = tokenizer.encode( + token_ids = zephyr_7B_tokenzer.encode( f"Give an example IPv4 address with this regex: {sample_regex}") regex_request = GuidedDecodingParams(regex=sample_regex, backend=backend) regex_lp = get_local_guided_decoding_logits_processor( - regex_request, tokenizer, config) if is_local else \ + regex_request, zephyr_7B_tokenzer, config) if is_local else \ await get_guided_decoding_logits_processor( - regex_request, tokenizer, config) + regex_request, zephyr_7B_tokenzer, config) assert regex_lp is not None tensor = torch.rand(32000) original_tensor = torch.clone(tensor) @@ -76,13 +92,85 @@ async def test_guided_logits_processor_black_box(backend: str, is_local: bool, assert tensor.shape == original_tensor.shape assert not torch.allclose(tensor, original_tensor) - token_ids = tokenizer.encode( + token_ids = zephyr_7B_tokenzer.encode( f"Give an employee profile that fits this schema: {sample_json_schema}" ) json_request = GuidedDecodingParams(json=sample_json_schema, backend=backend) json_lp = await get_guided_decoding_logits_processor( - json_request, tokenizer, config) + json_request, zephyr_7B_tokenzer, config) + assert json_lp is not None + tensor = torch.rand(32000) + original_tensor = torch.clone(tensor) + tensor = json_lp(token_ids, tensor) + assert tensor.shape == original_tensor.shape + assert not torch.allclose(tensor, original_tensor) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("backend", + GUIDED_DECODING_BACKENDS_WITH_REASONING_SUPPORT) +@pytest.mark.parametrize("is_local", [True, False]) +@pytest.mark.parametrize("reasoning_backend", ["deepseek_r1"]) +async def test_guided_logits_processor_with_reasoning( + backend: str, is_local: bool, reasoning_backend: str, sample_regex, + sample_json_schema, deepseek_r1_qwen_tokenizer): + + config = ModelConfig( + REASONING_MODEL_NAME, + task="generate", + tokenizer=REASONING_MODEL_NAME, + tokenizer_mode="auto", + trust_remote_code=False, + seed=0, + dtype="bfloat16", + ) + token_ids = deepseek_r1_qwen_tokenizer.encode( + f"Give an example IPv4 address with this regex: {sample_regex}." + "here is the thinking process") + regex_request = GuidedDecodingParams(regex=sample_regex, backend=backend) + + regex_lp = get_local_guided_decoding_logits_processor(regex_request, + deepseek_r1_qwen_tokenizer, config, + reasoning_backend) if is_local else \ + await get_guided_decoding_logits_processor( + regex_request, deepseek_r1_qwen_tokenizer, config, + reasoning_backend) + assert regex_lp is not None + tensor = torch.rand(32000) + original_tensor = torch.clone(tensor) + tensor = regex_lp(token_ids, tensor) + assert tensor.shape == original_tensor.shape + assert torch.allclose(tensor, original_tensor) + + token_ids = deepseek_r1_qwen_tokenizer.encode( + f"Give an employee profile that fits this schema: {sample_json_schema}." 
+ "here is the thinking process") + json_request = GuidedDecodingParams(json=sample_json_schema, + backend=backend) + json_lp = get_local_guided_decoding_logits_processor( + json_request, deepseek_r1_qwen_tokenizer, config, + reasoning_backend) if is_local else \ + await get_guided_decoding_logits_processor( + json_request, deepseek_r1_qwen_tokenizer, config, reasoning_backend) + assert json_lp is not None + tensor = torch.rand(32000) + original_tensor = torch.clone(tensor) + tensor = json_lp(token_ids, tensor) + assert tensor.shape == original_tensor.shape + assert torch.allclose(tensor, original_tensor) + + # Thinking is over, so the tensor should change. + token_ids = deepseek_r1_qwen_tokenizer.encode( + f"Give an employee profile that fits this schema: {sample_json_schema}." + "here is the thinking process Then") + json_request = GuidedDecodingParams(json=sample_json_schema, + backend=backend) + json_lp = get_local_guided_decoding_logits_processor( + json_request, deepseek_r1_qwen_tokenizer, config, + reasoning_backend) if is_local else \ + await get_guided_decoding_logits_processor( + json_request, deepseek_r1_qwen_tokenizer, config, reasoning_backend) assert json_lp is not None tensor = torch.rand(32000) original_tensor = torch.clone(tensor) diff --git a/vllm/config.py b/vllm/config.py index c7108473442b..54ed38418dd4 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2715,6 +2715,8 @@ class DecodingConfig: # 'outlines' / 'lm-format-enforcer' / 'xgrammar' guided_decoding_backend: str = 'xgrammar' + reasoning_backend: Optional[str] = None + def compute_hash(self) -> str: """ WARNING: Whenever a new field is added to this config, diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 1a2f794c9151..989eb4dbfd14 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -213,6 +213,8 @@ class EngineArgs: calculate_kv_scales: Optional[bool] = None additional_config: Optional[Dict[str, Any]] = None + enable_reasoning: Optional[bool] = None + reasoning_parser: Optional[str] = None def __post_init__(self): if not self.tokenizer: @@ -1059,6 +1061,25 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: "Different platforms may support different configs. Make sure the " "configs are valid for the platform you are using. The input format" " is like '{\"config_key\":\"config_value\"}'") + + parser.add_argument( + "--enable-reasoning", + action="store_true", + default=False, + help="Whether to enable reasoning_content for the model. " + "If enabled, the model will be able to generate reasoning content." + ) + + parser.add_argument( + "--reasoning-parser", + type=str, + choices=["deepseek_r1"], + default=None, + help= + "Select the reasoning parser depending on the model that you're " + "using. This is used to parse the reasoning content into OpenAI " + "API format. 
Required for ``--enable-reasoning``.") + return parser @classmethod @@ -1332,7 +1353,10 @@ def create_engine_config(self, if self.enable_prompt_adapter else None decoding_config = DecodingConfig( - guided_decoding_backend=self.guided_decoding_backend) + guided_decoding_backend=self.guided_decoding_backend, + reasoning_backend=self.reasoning_parser + if self.enable_reasoning else None, + ) show_hidden_metrics = False if self.show_hidden_metrics_for_version is not None: diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 93d9b74d8e1e..90e66b005f39 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -509,6 +509,7 @@ async def add_request_async( tokenizer=await self.get_tokenizer_async(lora_request), default_guided_backend=self.decoding_config. guided_decoding_backend, + reasoning_backend=self.decoding_config.reasoning_backend, model_config=self.model_config) self._add_processed_request( @@ -530,7 +531,7 @@ async def check_health_async(self) -> None: async def build_guided_decoding_logits_processor_async( sampling_params: SamplingParams, tokenizer: AnyTokenizer, - default_guided_backend: str, + default_guided_backend: str, reasoning_backend: Optional[str], model_config: ModelConfig) -> SamplingParams: """Constructs logits processors based on the guided_decoding, logits_bias, and allowed_token_ids fields in sampling_params. Deletes @@ -545,14 +546,18 @@ async def build_guided_decoding_logits_processor_async( sampling_params = copy.copy(sampling_params) guided_decoding = sampling_params.guided_decoding - logger.debug("Building guided decoding logits processor. " - "Params: %s", guided_decoding) + logger.info( + "Building guided decoding logits processor. " + "guided_decoding: %s%s", guided_decoding, + f", reasoning_backend: {reasoning_backend}" + if reasoning_backend is not None else "") guided_decoding.backend = guided_decoding.backend or default_guided_backend processor = await get_guided_decoding_logits_processor( guided_params=guided_decoding, tokenizer=tokenizer, + reasoning_backend=reasoning_backend, model_config=model_config) if processor: diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 9c83ea75ead7..f055438d1feb 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -2048,10 +2048,15 @@ def _build_logits_processors( guided_decoding.backend = guided_decoding.backend or \ self.decoding_config.guided_decoding_backend + logger.debug("Reasoning backend: %s", + self.decoding_config.reasoning_backend) + processor = get_local_guided_decoding_logits_processor( guided_params=guided_decoding, tokenizer=tokenizer, - model_config=self.model_config) + model_config=self.model_config, + reasoning_backend=self.decoding_config.reasoning_backend, + ) if processor: logits_processors.append(processor) diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index c12fe242082b..005ba81cd226 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -611,7 +611,8 @@ async def _process_request( default_guided_backend=(self.decoding_config.guided_decoding_backend if self.decoding_config else DecodingConfig.guided_decoding_backend), - model_config=self.model_config + model_config=self.model_config, + reasoning_backend=self.decoding_config.reasoning_backend, ) # 1) Create output queue for this requests. 
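For orientation, the engine-side plumbing above boils down to the following call pattern. This is a minimal sketch, not part of the diff: the model name and `ModelConfig` arguments mirror the new tests, while the JSON schema literal and the choice of the xgrammar backend are illustrative placeholders.

```python
from transformers import AutoTokenizer

from vllm.config import DecodingConfig, ModelConfig
from vllm.model_executor.guided_decoding import (
    get_local_guided_decoding_logits_processor)
from vllm.sampling_params import GuidedDecodingParams

MODEL = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"

tokenizer = AutoTokenizer.from_pretrained(MODEL)
model_config = ModelConfig(MODEL,
                           task="generate",
                           tokenizer=MODEL,
                           tokenizer_mode="auto",
                           trust_remote_code=False,
                           seed=0,
                           dtype="bfloat16")

# --enable-reasoning --reasoning-parser deepseek_r1 ends up stored here.
decoding_config = DecodingConfig(guided_decoding_backend="xgrammar",
                                 reasoning_backend="deepseek_r1")

guided_params = GuidedDecodingParams(
    json={"type": "object", "properties": {"answer": {"type": "string"}}},
    backend=decoding_config.guided_decoding_backend)

# The engine forwards reasoning_backend so the logits processor can skip
# structured decoding until the reasoning section of the output has ended.
processor = get_local_guided_decoding_logits_processor(
    guided_params=guided_params,
    tokenizer=tokenizer,
    model_config=model_config,
    reasoning_backend=decoding_config.reasoning_backend)
```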
diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index ba953c219708..8d877046f75f 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -13,7 +13,6 @@ from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption, validate_chat_template) -from vllm.entrypoints.openai.reasoning_parsers import ReasoningParserManager from vllm.entrypoints.openai.serving_models import (LoRAModulePath, PromptAdapterPath) from vllm.entrypoints.openai.tool_parsers import ToolParserManager @@ -215,23 +214,6 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: default=False, help="Enable auto tool choice for supported models. Use " "``--tool-call-parser`` to specify which parser to use.") - parser.add_argument( - "--enable-reasoning", - action="store_true", - default=False, - help="Whether to enable reasoning_content for the model. " - "If enabled, the model will be able to generate reasoning content.") - - valid_reasoning_parsers = ReasoningParserManager.reasoning_parsers.keys() - parser.add_argument( - "--reasoning-parser", - type=str, - metavar="{" + ",".join(valid_reasoning_parsers) + "}", - default=None, - help= - "Select the reasoning parser depending on the model that you're using." - " This is used to parse the reasoning content into OpenAI API " - "format. Required for ``--enable-reasoning``.") valid_tool_parsers = ToolParserManager.tool_parsers.keys() parser.add_argument( diff --git a/vllm/model_executor/guided_decoding/__init__.py b/vllm/model_executor/guided_decoding/__init__.py index 1522e3404182..86f6f0e5f907 100644 --- a/vllm/model_executor/guided_decoding/__init__.py +++ b/vllm/model_executor/guided_decoding/__init__.py @@ -5,6 +5,7 @@ from typing import TYPE_CHECKING from vllm.logger import init_logger +from vllm.model_executor.guided_decoding.reasoner import get_reasoner from vllm.model_executor.guided_decoding.utils import ( convert_lark_to_gbnf, grammar_is_likely_lark, has_lmf_unsupported_json_features, has_xgrammar_unsupported_json_features) @@ -103,8 +104,13 @@ def fallback_or_error(guided_params: GuidedDecodingParams, message: str, async def get_guided_decoding_logits_processor( - guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizer, - model_config: ModelConfig) -> LogitsProcessor | None: + guided_params: GuidedDecodingParams, + tokenizer: PreTrainedTokenizer, + model_config: ModelConfig, + reasoning_backend: str | None = None) -> LogitsProcessor | None: + + reasoner = get_reasoner(tokenizer, reasoning_backend) + guided_params = maybe_backend_fallback(guided_params) # CFG grammar not supported by LMFE, so we use outlines instead if guided_params.backend_name == 'outlines': @@ -112,8 +118,8 @@ async def get_guided_decoding_logits_processor( from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa get_outlines_guided_decoding_logits_processor) return await get_outlines_guided_decoding_logits_processor( - guided_params, tokenizer) - if guided_params.backend_name == 'lm-format-enforcer': + guided_params, tokenizer, reasoner) + if guided_params.backend == 'lm-format-enforcer': from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa get_local_lm_format_enforcer_guided_decoding_logits_processor) return get_local_lm_format_enforcer_guided_decoding_logits_processor( @@ -122,7 +128,7 @@ async def get_guided_decoding_logits_processor( from 
vllm.model_executor.guided_decoding.xgrammar_decoding import (  # noqa
            get_local_xgrammar_guided_decoding_logits_processor)
         return get_local_xgrammar_guided_decoding_logits_processor(
-            guided_params, tokenizer, model_config)
+            guided_params, tokenizer, model_config, reasoner)
 
     raise ValueError(
         f"Unknown guided decoding backend '{guided_params.backend}'. "
@@ -130,16 +136,22 @@
 def get_local_guided_decoding_logits_processor(
-        guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizer,
-        model_config: ModelConfig) -> LogitsProcessor | None:
+        guided_params: GuidedDecodingParams,
+        tokenizer: PreTrainedTokenizer,
+        model_config: ModelConfig,
+        reasoning_backend: str | None = None) -> LogitsProcessor | None:
     guided_params = maybe_backend_fallback(guided_params)
+
+    # Get the reasoner if needed, it will be None if reasoning_
+    # backend is None.
+    reasoner = get_reasoner(tokenizer, reasoning_backend)
+
     # CFG grammar not supported by LMFE, so we use outlines instead
     if guided_params.backend_name == 'outlines':
         # NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
         from vllm.model_executor.guided_decoding.outlines_decoding import (  # noqa
             get_local_outlines_guided_decoding_logits_processor)
         return get_local_outlines_guided_decoding_logits_processor(
-            guided_params, tokenizer)
+            guided_params, tokenizer, reasoner)
     if guided_params.backend_name == 'lm-format-enforcer':
         from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import (  # noqa
             get_local_lm_format_enforcer_guided_decoding_logits_processor)
@@ -149,7 +161,7 @@ def get_local_guided_decoding_logits_processor(
         from vllm.model_executor.guided_decoding.xgrammar_decoding import (  # noqa
             get_local_xgrammar_guided_decoding_logits_processor)
         return get_local_xgrammar_guided_decoding_logits_processor(
-            guided_params, tokenizer, model_config)
+            guided_params, tokenizer, model_config, reasoner)
 
     raise ValueError(
         f"Unknown guided decoding backend '{guided_params.backend}'. 
" diff --git a/vllm/model_executor/guided_decoding/outlines_decoding.py b/vllm/model_executor/guided_decoding/outlines_decoding.py index ba9c98290368..97f63ae11f45 100644 --- a/vllm/model_executor/guided_decoding/outlines_decoding.py +++ b/vllm/model_executor/guided_decoding/outlines_decoding.py @@ -6,12 +6,13 @@ from enum import Enum from json import dumps as json_dumps from re import escape as regex_escape -from typing import Tuple, Union +from typing import Optional, Tuple, Union from transformers import PreTrainedTokenizerBase from vllm.model_executor.guided_decoding.outlines_logits_processors import ( CFGLogitsProcessor, JSONLogitsProcessor, RegexLogitsProcessor) +from vllm.model_executor.guided_decoding.reasoner import Reasoner from vllm.sampling_params import GuidedDecodingParams @@ -58,7 +59,9 @@ class GuidedDecodingMode(Enum): async def get_outlines_guided_decoding_logits_processor( - guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizerBase + guided_params: GuidedDecodingParams, + tokenizer: PreTrainedTokenizerBase, + reasoner: Optional[Reasoner], ) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor, None]: """ @@ -82,11 +85,14 @@ async def get_outlines_guided_decoding_logits_processor( return await loop.run_in_executor(global_thread_pool, _get_logits_processor, guide, tokenizer, - mode, guided_params.whitespace_pattern) + mode, guided_params.whitespace_pattern, + reasoner) def get_local_outlines_guided_decoding_logits_processor( - guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizerBase + guided_params: GuidedDecodingParams, + tokenizer: PreTrainedTokenizerBase, + reasoner: Optional[Reasoner], ) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor, None]: """ @@ -100,7 +106,7 @@ def get_local_outlines_guided_decoding_logits_processor( return None return _get_logits_processor(guide, tokenizer, mode, - guided_params.whitespace_pattern) + guided_params.whitespace_pattern, reasoner) def _get_guide_and_mode( @@ -131,14 +137,18 @@ def _get_guide_and_mode( def _get_logits_processor( - guide: str, tokenizer: PreTrainedTokenizerBase, mode: GuidedDecodingMode, - whitespace_pattern: Union[str, None] + guide: str, + tokenizer: PreTrainedTokenizerBase, + mode: GuidedDecodingMode, + whitespace_pattern: Union[str, None], + reasoner: Optional[Reasoner], ) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor]: if mode == GuidedDecodingMode.JSON: - return JSONLogitsProcessor(guide, tokenizer, whitespace_pattern) + return JSONLogitsProcessor(guide, tokenizer, whitespace_pattern, + reasoner) elif mode == GuidedDecodingMode.REGEX or mode == GuidedDecodingMode.CHOICE: - return RegexLogitsProcessor(guide, tokenizer) + return RegexLogitsProcessor(guide, tokenizer, reasoner) elif mode == GuidedDecodingMode.GRAMMAR: - return CFGLogitsProcessor(guide, tokenizer) + return CFGLogitsProcessor(guide, tokenizer, reasoner) else: raise ValueError(f"Unknown guided decoding mode {mode}") diff --git a/vllm/model_executor/guided_decoding/outlines_logits_processors.py b/vllm/model_executor/guided_decoding/outlines_logits_processors.py index a05267d921d1..db5d738f42e4 100644 --- a/vllm/model_executor/guided_decoding/outlines_logits_processors.py +++ b/vllm/model_executor/guided_decoding/outlines_logits_processors.py @@ -19,7 +19,7 @@ import json from collections import defaultdict from functools import lru_cache -from typing import Callable, DefaultDict, Dict, List, Union +from typing import Callable, DefaultDict, Dict, List, 
Optional, Union import numpy as np import torch @@ -32,13 +32,18 @@ from pydantic import BaseModel from transformers import PreTrainedTokenizerBase +from vllm.logger import init_logger +from vllm.model_executor.guided_decoding.reasoner import Reasoner from vllm.platforms import current_platform +logger = init_logger(__name__) + class BaseLogitsProcessor: - def __init__(self, guide: Guide): + def __init__(self, guide: Guide, reasoner: Optional[Reasoner]): self._guide: Guide = guide + self._reasoner = reasoner # CFGState is used for the FSM state for CFGGuide self._fsm_state: DefaultDict[int, Union[int, CFGState]] = defaultdict(int) @@ -46,6 +51,14 @@ def __init__(self, guide: Guide): def __call__(self, input_ids: List[int], scores: torch.Tensor) -> torch.Tensor: """Use the FSM to bias the logits before sampling the next token.""" + + # Skip the structured logits processing if reasoning is not finished. + # reasoner is not None only when `--enable-reasoning` is set. + if self._reasoner is not None and \ + not self._reasoner.is_reasoning_end( + input_ids): + return scores + seq_id = hash(tuple(input_ids)) if len(input_ids) > 0: @@ -113,7 +126,12 @@ def _get_guide(cls, regex_string: str, tokenizer = _adapt_tokenizer(tokenizer) return RegexGuide.from_regex(regex_string, tokenizer) - def __init__(self, regex_string: str, tokenizer: PreTrainedTokenizerBase): + def __init__( + self, + regex_string: str, + tokenizer: PreTrainedTokenizerBase, + reasoner: Optional[Reasoner], + ): """Compile the FSM that drives the regex-structured generation. Parameters @@ -125,14 +143,15 @@ def __init__(self, regex_string: str, tokenizer: PreTrainedTokenizerBase): """ super().__init__( - RegexLogitsProcessor._get_guide(regex_string, tokenizer)) + RegexLogitsProcessor._get_guide(regex_string, tokenizer), reasoner) class JSONLogitsProcessor(RegexLogitsProcessor): def __init__(self, schema: Union[str, Dict, BaseModel], tokenizer: PreTrainedTokenizerBase, - whitespace_pattern: Union[str, None]): + whitespace_pattern: Union[str, None], + reasoner: Optional[Reasoner]): """Compile the FSM that drives the JSON-guided generation. Parameters @@ -160,7 +179,7 @@ def __init__(self, schema: Union[str, Dict, BaseModel], f"a Pydantic object, a dictionary or a string that contains " f"the JSON Schema specification") regex_string = build_regex_from_schema(schema_str, whitespace_pattern) - super().__init__(regex_string, tokenizer) + super().__init__(regex_string, tokenizer, reasoner) class CFGLogitsProcessor(BaseLogitsProcessor): @@ -171,7 +190,8 @@ def _get_guide(cls, cfg: str, tokenizer: PreTrainedTokenizerBase) -> Guide: tokenizer = _adapt_tokenizer(tokenizer) return CFGGuide(cfg, tokenizer) - def __init__(self, cfg: str, tokenizer: PreTrainedTokenizerBase): + def __init__(self, cfg: str, tokenizer: PreTrainedTokenizerBase, + reasoner: Optional[Reasoner]): """Compile the FSM that drives the context free grammar generation. 
Parameters @@ -182,7 +202,8 @@ def __init__(self, cfg: str, tokenizer: PreTrainedTokenizerBase): The model's tokenizer """ - super().__init__(CFGLogitsProcessor._get_guide(cfg, tokenizer)) + super().__init__(CFGLogitsProcessor._get_guide(cfg, tokenizer), + reasoner) self._guide = self._guide.copy() diff --git a/vllm/model_executor/guided_decoding/reasoner/__init__.py b/vllm/model_executor/guided_decoding/reasoner/__init__.py new file mode 100644 index 000000000000..5a91f791d45b --- /dev/null +++ b/vllm/model_executor/guided_decoding/reasoner/__init__.py @@ -0,0 +1,23 @@ +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from transformers import PreTrainedTokenizer + +from vllm.model_executor.guided_decoding.reasoner.deepseek_reasoner import ( # noqa: E501 + DeepSeekReasoner) +from vllm.model_executor.guided_decoding.reasoner.reasoner import Reasoner + + +def get_reasoner(tokenizer: PreTrainedTokenizer, + reasoning_backend: str | None) -> Reasoner | None: + if reasoning_backend is None: + # No reasoning backend specified + return None + elif reasoning_backend == "deepseek_r1": + return DeepSeekReasoner.from_tokenizer(tokenizer) + else: + raise ValueError(f"Unknown reasoning backend '{reasoning_backend}'") + + +__all__ = ["Reasoner", "get_reasoner"] diff --git a/vllm/model_executor/guided_decoding/reasoner/deepseek_reasoner.py b/vllm/model_executor/guided_decoding/reasoner/deepseek_reasoner.py new file mode 100644 index 000000000000..e762fb0659de --- /dev/null +++ b/vllm/model_executor/guided_decoding/reasoner/deepseek_reasoner.py @@ -0,0 +1,28 @@ +# SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass + +from transformers import PreTrainedTokenizer + +from vllm.model_executor.guided_decoding.reasoner.reasoner import Reasoner + + +@dataclass +class DeepSeekReasoner(Reasoner): + """ + Reasoner for DeepSeek R series models. 
+ """ + start_token_id: int + end_token_id: int + + start_token: str = "" + end_token: str = "" + + @classmethod + def from_tokenizer(cls, tokenizer: PreTrainedTokenizer) -> Reasoner: + return cls(start_token_id=tokenizer.encode( + "", add_special_tokens=False)[0], + end_token_id=tokenizer.encode("", + add_special_tokens=False)[0]) + + def is_reasoning_end(self, input_ids: list[int]) -> bool: + return self.end_token_id in input_ids diff --git a/vllm/model_executor/guided_decoding/reasoner/reasoner.py b/vllm/model_executor/guided_decoding/reasoner/reasoner.py new file mode 100644 index 000000000000..5db0c9bc7850 --- /dev/null +++ b/vllm/model_executor/guided_decoding/reasoner/reasoner.py @@ -0,0 +1,19 @@ +# SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass + +from transformers import PreTrainedTokenizer + + +@dataclass +class Reasoner(ABC): + + @abstractmethod + def from_tokenizer(cls, tokenizer: PreTrainedTokenizer) -> Reasoner: + pass + + @abstractmethod + def is_reasoning_end(self, input_ids: list[int]) -> bool: + pass diff --git a/vllm/model_executor/guided_decoding/xgrammar_decoding.py b/vllm/model_executor/guided_decoding/xgrammar_decoding.py index eb9d83acb286..ce278c15ab3b 100644 --- a/vllm/model_executor/guided_decoding/xgrammar_decoding.py +++ b/vllm/model_executor/guided_decoding/xgrammar_decoding.py @@ -11,6 +11,8 @@ import torch from transformers import PreTrainedTokenizerFast +from vllm.logger import init_logger + try: import xgrammar as xgr from xgrammar.base import _core as xgr_core @@ -19,7 +21,6 @@ xgr_installed = False pass -from vllm.logger import init_logger from vllm.model_executor.guided_decoding.utils import (convert_lark_to_gbnf, grammar_is_likely_lark) from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer @@ -28,6 +29,7 @@ from transformers import PreTrainedTokenizer from vllm.config import ModelConfig + from vllm.model_executor.guided_decoding.reasoner import Reasoner from vllm.sampling_params import GuidedDecodingParams logger = init_logger(__name__) @@ -38,12 +40,13 @@ def get_local_xgrammar_guided_decoding_logits_processor( guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizer, model_config: ModelConfig, + reasoner: Reasoner | None, max_threads: int = 8): config = GrammarConfig.from_guided_params(guided_params=guided_params, model_config=model_config, tokenizer=tokenizer, max_threads=max_threads) - return XGrammarLogitsProcessor(config) + return XGrammarLogitsProcessor(config, reasoner) @dataclass(frozen=True) @@ -293,6 +296,7 @@ def choice_as_grammar(choice: List[str] | None) -> str: class XGrammarLogitsProcessor: """Wrapper class to support pickle protocol""" config: GrammarConfig + reasoner: Reasoner | None = None ctx: xgr.CompiledGrammar | None = None token_bitmask: torch.Tensor = None # type: ignore[assignment] @@ -301,10 +305,11 @@ class XGrammarLogitsProcessor: prefilled: bool = field(default=False) def __getstate__(self) -> dict[str, Any]: - return {'config': self.config} + return {'config': self.config, 'reasoner': self.reasoner} def __setstate__(self, state: dict[str, Any]): self.config = state['config'] + self.reasoner = state['reasoner'] self.ctx = None self.matchers = [] @@ -331,6 +336,14 @@ def _ensure_ctx(self): def __call__(self, input_ids: list[int], scores: torch.Tensor) -> torch.Tensor: + + # Skip the structured logits processing if reasoning is not finished. 
+        # reasoner is not None only when `--enable-reasoning` is set.
+        if self.reasoner is not None and \
+            not self.reasoner.is_reasoning_end(
+                input_ids):
+            return scores
+
         if self.ctx is None:
             self._ensure_ctx()
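To make the skip condition above concrete, here is a small usage sketch of the new reasoner API. It is not part of the diff, and it assumes the DeepSeek R1 reasoning delimiters are the `<think>`/`</think>` markers handled by `DeepSeekReasoner`; running it downloads the tokenizer.

```python
from transformers import AutoTokenizer

from vllm.model_executor.guided_decoding.reasoner import get_reasoner

tokenizer = AutoTokenizer.from_pretrained(
    "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B")
reasoner = get_reasoner(tokenizer, "deepseek_r1")

# While the end-of-reasoning token has not been generated, the xgrammar and
# outlines logits processors return the scores untouched.
in_progress = tokenizer.encode("<think>let me think about this",
                               add_special_tokens=False)
print(reasoner.is_reasoning_end(in_progress))  # expected: False

# Once </think> appears in the generated ids, structured decoding applies.
finished = tokenizer.encode("<think>let me think</think> The answer is",
                            add_special_tokens=False)
print(reasoner.is_reasoning_end(finished))  # expected: True
```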