diff --git a/docs/source/features/reasoning_outputs.md b/docs/source/features/reasoning_outputs.md
index e39bbacf1138..5c0c1762f8aa 100644
--- a/docs/source/features/reasoning_outputs.md
+++ b/docs/source/features/reasoning_outputs.md
@@ -76,7 +76,13 @@ Streaming chat completions are also supported for reasoning models. The `reasoni
}
```
-Please note that it is not compatible with the OpenAI Python client library. You can use the `requests` library to make streaming requests.
+Please note that it is not compatible with the OpenAI Python client library. You can use the `requests` library to make streaming requests instead; see the [example](https://github.com/vllm-project/vllm/blob/main/examples/online_serving/openai_chat_completion_with_reasoning_streaming.py).
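+
+For reference, here is a minimal sketch of a streaming request using `requests` (the host, model name, and prompt are placeholders; the delta fields follow the streaming chunks shown above):
+
+```python
+import json
+
+import requests
+
+response = requests.post(
+    "http://localhost:8000/v1/chat/completions",
+    json={
+        "model": "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
+        "messages": [{
+            "role": "user",
+            "content": "Which is greater, 9.11 or 9.8?"
+        }],
+        "stream": True,
+    },
+    stream=True,
+)
+
+for line in response.iter_lines():
+    if not line:
+        continue
+    chunk = line.decode("utf-8").removeprefix("data: ")
+    if chunk == "[DONE]":
+        break
+    delta = json.loads(chunk)["choices"][0]["delta"]
+    # The reasoning content is streamed first, followed by the final content.
+    print(delta.get("reasoning_content") or delta.get("content") or "", end="")
+```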
+
+## Limitations
+
+- The reasoning content is only available for online serving's chat completion endpoint (`/v1/chat/completions`).
+- It is not compatible with [`tool_calling`](#tool_calling).
+- The reasoning content is not available for all models. Check the model's documentation to see if it supports reasoning.
## How to support a new reasoning model
@@ -137,15 +143,36 @@ class ExampleParser(ReasoningParser):
"""
```
-After defining the reasoning parser, you can use it by specifying the `--reasoning-parser` flag when making a request to the chat completion endpoint.
+Additionally, to support structured output for reasoning models, you'll need to create a new `Reasoner` similar to the one in `vllm/model_executor/guided_decoding/reasoner/deepseek_reasoner.py`.
+
+```python
+@dataclass
+class DeepSeekReasoner(Reasoner):
+ """
+ Reasoner for DeepSeek R series models.
+ """
+ start_token_id: int
+ end_token_id: int
+
+    start_token: str = "<think>"
+    end_token: str = "</think>"
+
+ @classmethod
+ def from_tokenizer(cls, tokenizer: PreTrainedTokenizer) -> Reasoner:
+ return cls(start_token_id=tokenizer.encode(
+            "<think>", add_special_tokens=False)[0],
+                   end_token_id=tokenizer.encode("</think>",
+ add_special_tokens=False)[0])
+
+ def is_reasoning_end(self, input_ids: list[int]) -> bool:
+ return self.end_token_id in input_ids
+```
+
+Structured output engines such as xgrammar use `end_token_id` to check whether the reasoning content has finished in the model output, and skip the structured output constraints while it has not.
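+
+For reference, this is roughly how the guided decoding logits processors guard the structured decoding step (a simplified sketch of the check, not the full implementation):
+
+```python
+def __call__(self, input_ids: list[int],
+             scores: torch.Tensor) -> torch.Tensor:
+    # Skip the structured logits processing if reasoning is not finished.
+    # reasoner is not None only when `--enable-reasoning` is set.
+    if self.reasoner is not None and \
+            not self.reasoner.is_reasoning_end(input_ids):
+        return scores
+    # ... the usual grammar / JSON schema masking of `scores` follows ...
+```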
+
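+To make the new reasoner selectable via `--reasoning-parser`, you will also need to return it from `get_reasoner` in `vllm/model_executor/guided_decoding/reasoner/__init__.py` and add its name to the `--reasoning-parser` choices in `vllm/engine/arg_utils.py`. A minimal sketch, assuming a hypothetical `ExampleReasoner` registered under the name `example`:
+
+```python
+def get_reasoner(tokenizer: PreTrainedTokenizer,
+                 reasoning_backend: str | None) -> Reasoner | None:
+    if reasoning_backend is None:
+        # No reasoning backend specified
+        return None
+    elif reasoning_backend == "deepseek_r1":
+        return DeepSeekReasoner.from_tokenizer(tokenizer)
+    elif reasoning_backend == "example":
+        # Hypothetical reasoner for the example model
+        return ExampleReasoner.from_tokenizer(tokenizer)
+    else:
+        raise ValueError(f"Unknown reasoning backend '{reasoning_backend}'")
+```
+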
+Finally, you can enable reasoning for the model by using the `--enable-reasoning` and `--reasoning-parser` flags.
```bash
vllm serve \
--enable-reasoning --reasoning-parser example
```
-
-## Limitations
-
-- The reasoning content is only available for online serving's chat completion endpoint (`/v1/chat/completions`).
-- It is not compatible with the [`structured_outputs`](#structured_outputs) and [`tool_calling`](#tool_calling) features.
-- The reasoning content is not available for all models. Check the model's documentation to see if it supports reasoning.
diff --git a/examples/online_serving/openai_chat_completion_structured_outputs_with_reasoning.py b/examples/online_serving/openai_chat_completion_structured_outputs_with_reasoning.py
new file mode 100644
index 000000000000..1f72e1164d42
--- /dev/null
+++ b/examples/online_serving/openai_chat_completion_structured_outputs_with_reasoning.py
@@ -0,0 +1,64 @@
+# SPDX-License-Identifier: Apache-2.0
+"""
+An example showing how to generate structured outputs from reasoning models
+like DeepSeek R1. The thinking process will not be guided by the JSON
+schema provided by the user. Only the final output will be structured.
+
+To run this example, you need to start the vLLM server with the reasoning
+parser:
+
+```bash
+vllm serve deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B \
+ --enable-reasoning --reasoning-parser deepseek_r1
+```
+
+This example demonstrates how to generate chat completions from reasoning models
+using the OpenAI Python client library.
+"""
+
+from enum import Enum
+
+from openai import OpenAI
+from pydantic import BaseModel
+
+# Modify OpenAI's API key and API base to use vLLM's API server.
+openai_api_key = "EMPTY"
+openai_api_base = "http://localhost:8000/v1"
+
+client = OpenAI(
+ api_key=openai_api_key,
+ base_url=openai_api_base,
+)
+
+models = client.models.list()
+model = models.data[0].id
+
+
+# Guided decoding by JSON using Pydantic schema
+class CarType(str, Enum):
+ sedan = "sedan"
+ suv = "SUV"
+ truck = "Truck"
+ coupe = "Coupe"
+
+
+class CarDescription(BaseModel):
+ brand: str
+ model: str
+ car_type: CarType
+
+
+json_schema = CarDescription.model_json_schema()
+
+prompt = ("Generate a JSON with the brand, model and car_type of "
+ "the most iconic car from the 90's, think in 100 tokens")
+completion = client.chat.completions.create(
+ model=model,
+ messages=[{
+ "role": "user",
+ "content": prompt,
+ }],
+ extra_body={"guided_json": json_schema},
+)
+print("content: ", completion.choices[0].message.content)
+print("reasoning_content: ", completion.choices[0].message.reasoning_content)
diff --git a/tests/model_executor/test_guided_processors.py b/tests/model_executor/test_guided_processors.py
index be544698fa03..531c3a8c13b2 100644
--- a/tests/model_executor/test_guided_processors.py
+++ b/tests/model_executor/test_guided_processors.py
@@ -16,17 +16,33 @@
MODEL_NAME = 'HuggingFaceH4/zephyr-7b-beta'
GUIDED_DECODING_BACKENDS = ["outlines", "lm-format-enforcer", "xgrammar"]
+GUIDED_DECODING_BACKENDS_WITH_REASONING_SUPPORT = ["outlines", "xgrammar"]
+REASONING_MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
-def test_guided_logits_processors(sample_regex, sample_json_schema):
+# Initialize the tokenizer for the model here to avoid repeated loading
+@pytest.fixture(scope="module")
+def zephyr_7B_tokenzer():
+ return AutoTokenizer.from_pretrained(MODEL_NAME)
+
+
+@pytest.fixture(scope="module")
+def deepseek_r1_qwen_tokenizer():
+ return AutoTokenizer.from_pretrained(REASONING_MODEL_NAME)
+
+
+def test_guided_logits_processors(zephyr_7B_tokenzer, sample_regex,
+ sample_json_schema):
"""Basic unit test for RegexLogitsProcessor and JSONLogitsProcessor."""
- tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta')
- regex_LP = RegexLogitsProcessor(sample_regex, tokenizer)
+ regex_LP = RegexLogitsProcessor(sample_regex,
+ zephyr_7B_tokenzer,
+ reasoner=None)
json_LP = JSONLogitsProcessor(sample_json_schema,
- tokenizer,
- whitespace_pattern=None)
+ zephyr_7B_tokenzer,
+ whitespace_pattern=None,
+ reasoner=None)
- token_ids = tokenizer.encode(
+ token_ids = zephyr_7B_tokenzer.encode(
f"Give an example IPv4 address with this regex: {sample_regex}")
tensor = torch.rand(32000)
original_tensor = torch.clone(tensor)
@@ -34,7 +50,7 @@ def test_guided_logits_processors(sample_regex, sample_json_schema):
assert tensor.shape == original_tensor.shape
assert not torch.allclose(tensor, original_tensor)
- token_ids = tokenizer.encode(
+ token_ids = zephyr_7B_tokenzer.encode(
f"Give an employee profile that fits this schema: {sample_json_schema}"
)
tensor = torch.rand(32000)
@@ -49,7 +65,8 @@ def test_guided_logits_processors(sample_regex, sample_json_schema):
@pytest.mark.parametrize("is_local", [True, False])
async def test_guided_logits_processor_black_box(backend: str, is_local: bool,
sample_regex,
- sample_json_schema):
+ sample_json_schema,
+ zephyr_7B_tokenzer):
config = ModelConfig(
MODEL_NAME,
@@ -60,15 +77,14 @@ async def test_guided_logits_processor_black_box(backend: str, is_local: bool,
seed=0,
dtype="bfloat16",
)
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
- token_ids = tokenizer.encode(
+ token_ids = zephyr_7B_tokenzer.encode(
f"Give an example IPv4 address with this regex: {sample_regex}")
regex_request = GuidedDecodingParams(regex=sample_regex, backend=backend)
regex_lp = get_local_guided_decoding_logits_processor(
- regex_request, tokenizer, config) if is_local else \
+ regex_request, zephyr_7B_tokenzer, config) if is_local else \
await get_guided_decoding_logits_processor(
- regex_request, tokenizer, config)
+ regex_request, zephyr_7B_tokenzer, config)
assert regex_lp is not None
tensor = torch.rand(32000)
original_tensor = torch.clone(tensor)
@@ -76,13 +92,85 @@ async def test_guided_logits_processor_black_box(backend: str, is_local: bool,
assert tensor.shape == original_tensor.shape
assert not torch.allclose(tensor, original_tensor)
- token_ids = tokenizer.encode(
+ token_ids = zephyr_7B_tokenzer.encode(
f"Give an employee profile that fits this schema: {sample_json_schema}"
)
json_request = GuidedDecodingParams(json=sample_json_schema,
backend=backend)
json_lp = await get_guided_decoding_logits_processor(
- json_request, tokenizer, config)
+ json_request, zephyr_7B_tokenzer, config)
+ assert json_lp is not None
+ tensor = torch.rand(32000)
+ original_tensor = torch.clone(tensor)
+ tensor = json_lp(token_ids, tensor)
+ assert tensor.shape == original_tensor.shape
+ assert not torch.allclose(tensor, original_tensor)
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("backend",
+ GUIDED_DECODING_BACKENDS_WITH_REASONING_SUPPORT)
+@pytest.mark.parametrize("is_local", [True, False])
+@pytest.mark.parametrize("reasoning_backend", ["deepseek_r1"])
+async def test_guided_logits_processor_with_reasoning(
+ backend: str, is_local: bool, reasoning_backend: str, sample_regex,
+ sample_json_schema, deepseek_r1_qwen_tokenizer):
+
+ config = ModelConfig(
+ REASONING_MODEL_NAME,
+ task="generate",
+ tokenizer=REASONING_MODEL_NAME,
+ tokenizer_mode="auto",
+ trust_remote_code=False,
+ seed=0,
+ dtype="bfloat16",
+ )
+ token_ids = deepseek_r1_qwen_tokenizer.encode(
+ f"Give an example IPv4 address with this regex: {sample_regex}."
+        "<think>here is the thinking process")
+ regex_request = GuidedDecodingParams(regex=sample_regex, backend=backend)
+
+ regex_lp = get_local_guided_decoding_logits_processor(regex_request,
+ deepseek_r1_qwen_tokenizer, config,
+ reasoning_backend) if is_local else \
+ await get_guided_decoding_logits_processor(
+ regex_request, deepseek_r1_qwen_tokenizer, config,
+ reasoning_backend)
+ assert regex_lp is not None
+ tensor = torch.rand(32000)
+ original_tensor = torch.clone(tensor)
+ tensor = regex_lp(token_ids, tensor)
+ assert tensor.shape == original_tensor.shape
+ assert torch.allclose(tensor, original_tensor)
+
+ token_ids = deepseek_r1_qwen_tokenizer.encode(
+ f"Give an employee profile that fits this schema: {sample_json_schema}."
+        "<think>here is the thinking process")
+ json_request = GuidedDecodingParams(json=sample_json_schema,
+ backend=backend)
+ json_lp = get_local_guided_decoding_logits_processor(
+ json_request, deepseek_r1_qwen_tokenizer, config,
+ reasoning_backend) if is_local else \
+ await get_guided_decoding_logits_processor(
+ json_request, deepseek_r1_qwen_tokenizer, config, reasoning_backend)
+ assert json_lp is not None
+ tensor = torch.rand(32000)
+ original_tensor = torch.clone(tensor)
+ tensor = json_lp(token_ids, tensor)
+ assert tensor.shape == original_tensor.shape
+ assert torch.allclose(tensor, original_tensor)
+
+ # Thinking is over, so the tensor should change.
+ token_ids = deepseek_r1_qwen_tokenizer.encode(
+ f"Give an employee profile that fits this schema: {sample_json_schema}."
+        "<think>here is the thinking process</think> Then")
+ json_request = GuidedDecodingParams(json=sample_json_schema,
+ backend=backend)
+ json_lp = get_local_guided_decoding_logits_processor(
+ json_request, deepseek_r1_qwen_tokenizer, config,
+ reasoning_backend) if is_local else \
+ await get_guided_decoding_logits_processor(
+ json_request, deepseek_r1_qwen_tokenizer, config, reasoning_backend)
assert json_lp is not None
tensor = torch.rand(32000)
original_tensor = torch.clone(tensor)
diff --git a/vllm/config.py b/vllm/config.py
index c7108473442b..54ed38418dd4 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -2715,6 +2715,8 @@ class DecodingConfig:
# 'outlines' / 'lm-format-enforcer' / 'xgrammar'
guided_decoding_backend: str = 'xgrammar'
+ reasoning_backend: Optional[str] = None
+
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 1a2f794c9151..989eb4dbfd14 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -213,6 +213,8 @@ class EngineArgs:
calculate_kv_scales: Optional[bool] = None
additional_config: Optional[Dict[str, Any]] = None
+ enable_reasoning: Optional[bool] = None
+ reasoning_parser: Optional[str] = None
def __post_init__(self):
if not self.tokenizer:
@@ -1059,6 +1061,25 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
"Different platforms may support different configs. Make sure the "
"configs are valid for the platform you are using. The input format"
" is like '{\"config_key\":\"config_value\"}'")
+
+ parser.add_argument(
+ "--enable-reasoning",
+ action="store_true",
+ default=False,
+ help="Whether to enable reasoning_content for the model. "
+ "If enabled, the model will be able to generate reasoning content."
+ )
+
+ parser.add_argument(
+ "--reasoning-parser",
+ type=str,
+ choices=["deepseek_r1"],
+ default=None,
+ help=
+ "Select the reasoning parser depending on the model that you're "
+ "using. This is used to parse the reasoning content into OpenAI "
+ "API format. Required for ``--enable-reasoning``.")
+
return parser
@classmethod
@@ -1332,7 +1353,10 @@ def create_engine_config(self,
if self.enable_prompt_adapter else None
decoding_config = DecodingConfig(
- guided_decoding_backend=self.guided_decoding_backend)
+ guided_decoding_backend=self.guided_decoding_backend,
+ reasoning_backend=self.reasoning_parser
+ if self.enable_reasoning else None,
+ )
show_hidden_metrics = False
if self.show_hidden_metrics_for_version is not None:
diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py
index 93d9b74d8e1e..90e66b005f39 100644
--- a/vllm/engine/async_llm_engine.py
+++ b/vllm/engine/async_llm_engine.py
@@ -509,6 +509,7 @@ async def add_request_async(
tokenizer=await self.get_tokenizer_async(lora_request),
default_guided_backend=self.decoding_config.
guided_decoding_backend,
+ reasoning_backend=self.decoding_config.reasoning_backend,
model_config=self.model_config)
self._add_processed_request(
@@ -530,7 +531,7 @@ async def check_health_async(self) -> None:
async def build_guided_decoding_logits_processor_async(
sampling_params: SamplingParams, tokenizer: AnyTokenizer,
- default_guided_backend: str,
+ default_guided_backend: str, reasoning_backend: Optional[str],
model_config: ModelConfig) -> SamplingParams:
"""Constructs logits processors based on the guided_decoding,
logits_bias, and allowed_token_ids fields in sampling_params. Deletes
@@ -545,14 +546,18 @@ async def build_guided_decoding_logits_processor_async(
sampling_params = copy.copy(sampling_params)
guided_decoding = sampling_params.guided_decoding
- logger.debug("Building guided decoding logits processor. "
- "Params: %s", guided_decoding)
+ logger.info(
+ "Building guided decoding logits processor. "
+ "guided_decoding: %s%s", guided_decoding,
+ f", reasoning_backend: {reasoning_backend}"
+ if reasoning_backend is not None else "")
guided_decoding.backend = guided_decoding.backend or default_guided_backend
processor = await get_guided_decoding_logits_processor(
guided_params=guided_decoding,
tokenizer=tokenizer,
+ reasoning_backend=reasoning_backend,
model_config=model_config)
if processor:
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 9c83ea75ead7..f055438d1feb 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -2048,10 +2048,15 @@ def _build_logits_processors(
guided_decoding.backend = guided_decoding.backend or \
self.decoding_config.guided_decoding_backend
+ logger.debug("Reasoning backend: %s",
+ self.decoding_config.reasoning_backend)
+
processor = get_local_guided_decoding_logits_processor(
guided_params=guided_decoding,
tokenizer=tokenizer,
- model_config=self.model_config)
+ model_config=self.model_config,
+ reasoning_backend=self.decoding_config.reasoning_backend,
+ )
if processor:
logits_processors.append(processor)
diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py
index c12fe242082b..005ba81cd226 100644
--- a/vllm/engine/multiprocessing/client.py
+++ b/vllm/engine/multiprocessing/client.py
@@ -611,7 +611,8 @@ async def _process_request(
default_guided_backend=(self.decoding_config.guided_decoding_backend
if self.decoding_config
else DecodingConfig.guided_decoding_backend),
- model_config=self.model_config
+ model_config=self.model_config,
+ reasoning_backend=self.decoding_config.reasoning_backend,
)
# 1) Create output queue for this requests.
diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py
index ba953c219708..8d877046f75f 100644
--- a/vllm/entrypoints/openai/cli_args.py
+++ b/vllm/entrypoints/openai/cli_args.py
@@ -13,7 +13,6 @@
from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption,
validate_chat_template)
-from vllm.entrypoints.openai.reasoning_parsers import ReasoningParserManager
from vllm.entrypoints.openai.serving_models import (LoRAModulePath,
PromptAdapterPath)
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
@@ -215,23 +214,6 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
default=False,
help="Enable auto tool choice for supported models. Use "
"``--tool-call-parser`` to specify which parser to use.")
- parser.add_argument(
- "--enable-reasoning",
- action="store_true",
- default=False,
- help="Whether to enable reasoning_content for the model. "
- "If enabled, the model will be able to generate reasoning content.")
-
- valid_reasoning_parsers = ReasoningParserManager.reasoning_parsers.keys()
- parser.add_argument(
- "--reasoning-parser",
- type=str,
- metavar="{" + ",".join(valid_reasoning_parsers) + "}",
- default=None,
- help=
- "Select the reasoning parser depending on the model that you're using."
- " This is used to parse the reasoning content into OpenAI API "
- "format. Required for ``--enable-reasoning``.")
valid_tool_parsers = ToolParserManager.tool_parsers.keys()
parser.add_argument(
diff --git a/vllm/model_executor/guided_decoding/__init__.py b/vllm/model_executor/guided_decoding/__init__.py
index 1522e3404182..86f6f0e5f907 100644
--- a/vllm/model_executor/guided_decoding/__init__.py
+++ b/vllm/model_executor/guided_decoding/__init__.py
@@ -5,6 +5,7 @@
from typing import TYPE_CHECKING
from vllm.logger import init_logger
+from vllm.model_executor.guided_decoding.reasoner import get_reasoner
from vllm.model_executor.guided_decoding.utils import (
convert_lark_to_gbnf, grammar_is_likely_lark,
has_lmf_unsupported_json_features, has_xgrammar_unsupported_json_features)
@@ -103,8 +104,13 @@ def fallback_or_error(guided_params: GuidedDecodingParams, message: str,
async def get_guided_decoding_logits_processor(
- guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizer,
- model_config: ModelConfig) -> LogitsProcessor | None:
+ guided_params: GuidedDecodingParams,
+ tokenizer: PreTrainedTokenizer,
+ model_config: ModelConfig,
+ reasoning_backend: str | None = None) -> LogitsProcessor | None:
+
+ reasoner = get_reasoner(tokenizer, reasoning_backend)
+
guided_params = maybe_backend_fallback(guided_params)
# CFG grammar not supported by LMFE, so we use outlines instead
if guided_params.backend_name == 'outlines':
@@ -112,8 +118,8 @@ async def get_guided_decoding_logits_processor(
from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa
get_outlines_guided_decoding_logits_processor)
return await get_outlines_guided_decoding_logits_processor(
- guided_params, tokenizer)
- if guided_params.backend_name == 'lm-format-enforcer':
+ guided_params, tokenizer, reasoner)
+    if guided_params.backend_name == 'lm-format-enforcer':
from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa
get_local_lm_format_enforcer_guided_decoding_logits_processor)
return get_local_lm_format_enforcer_guided_decoding_logits_processor(
@@ -122,7 +128,7 @@ async def get_guided_decoding_logits_processor(
from vllm.model_executor.guided_decoding.xgrammar_decoding import ( # noqa
get_local_xgrammar_guided_decoding_logits_processor)
return get_local_xgrammar_guided_decoding_logits_processor(
- guided_params, tokenizer, model_config)
+ guided_params, tokenizer, model_config, reasoner)
raise ValueError(
f"Unknown guided decoding backend '{guided_params.backend}'. "
@@ -130,16 +136,22 @@ async def get_guided_decoding_logits_processor(
def get_local_guided_decoding_logits_processor(
- guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizer,
- model_config: ModelConfig) -> LogitsProcessor | None:
+ guided_params: GuidedDecodingParams,
+ tokenizer: PreTrainedTokenizer,
+ model_config: ModelConfig,
+ reasoning_backend: str | None = None) -> LogitsProcessor | None:
guided_params = maybe_backend_fallback(guided_params)
+
+    # Get the reasoner if needed, it will be None if reasoning_backend is None.
+ reasoner = get_reasoner(tokenizer, reasoning_backend)
+
# CFG grammar not supported by LMFE, so we use outlines instead
if guided_params.backend_name == 'outlines':
# NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa
get_local_outlines_guided_decoding_logits_processor)
return get_local_outlines_guided_decoding_logits_processor(
- guided_params, tokenizer)
+ guided_params, tokenizer, reasoner)
if guided_params.backend_name == 'lm-format-enforcer':
from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa
get_local_lm_format_enforcer_guided_decoding_logits_processor)
@@ -149,7 +161,7 @@ def get_local_guided_decoding_logits_processor(
from vllm.model_executor.guided_decoding.xgrammar_decoding import ( # noqa
get_local_xgrammar_guided_decoding_logits_processor)
return get_local_xgrammar_guided_decoding_logits_processor(
- guided_params, tokenizer, model_config)
+ guided_params, tokenizer, model_config, reasoner)
raise ValueError(
f"Unknown guided decoding backend '{guided_params.backend}'. "
diff --git a/vllm/model_executor/guided_decoding/outlines_decoding.py b/vllm/model_executor/guided_decoding/outlines_decoding.py
index ba9c98290368..97f63ae11f45 100644
--- a/vllm/model_executor/guided_decoding/outlines_decoding.py
+++ b/vllm/model_executor/guided_decoding/outlines_decoding.py
@@ -6,12 +6,13 @@
from enum import Enum
from json import dumps as json_dumps
from re import escape as regex_escape
-from typing import Tuple, Union
+from typing import Optional, Tuple, Union
from transformers import PreTrainedTokenizerBase
from vllm.model_executor.guided_decoding.outlines_logits_processors import (
CFGLogitsProcessor, JSONLogitsProcessor, RegexLogitsProcessor)
+from vllm.model_executor.guided_decoding.reasoner import Reasoner
from vllm.sampling_params import GuidedDecodingParams
@@ -58,7 +59,9 @@ class GuidedDecodingMode(Enum):
async def get_outlines_guided_decoding_logits_processor(
- guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizerBase
+ guided_params: GuidedDecodingParams,
+ tokenizer: PreTrainedTokenizerBase,
+ reasoner: Optional[Reasoner],
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor,
None]:
"""
@@ -82,11 +85,14 @@ async def get_outlines_guided_decoding_logits_processor(
return await loop.run_in_executor(global_thread_pool,
_get_logits_processor, guide, tokenizer,
- mode, guided_params.whitespace_pattern)
+ mode, guided_params.whitespace_pattern,
+ reasoner)
def get_local_outlines_guided_decoding_logits_processor(
- guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizerBase
+ guided_params: GuidedDecodingParams,
+ tokenizer: PreTrainedTokenizerBase,
+ reasoner: Optional[Reasoner],
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor,
None]:
"""
@@ -100,7 +106,7 @@ def get_local_outlines_guided_decoding_logits_processor(
return None
return _get_logits_processor(guide, tokenizer, mode,
- guided_params.whitespace_pattern)
+ guided_params.whitespace_pattern, reasoner)
def _get_guide_and_mode(
@@ -131,14 +137,18 @@ def _get_guide_and_mode(
def _get_logits_processor(
- guide: str, tokenizer: PreTrainedTokenizerBase, mode: GuidedDecodingMode,
- whitespace_pattern: Union[str, None]
+ guide: str,
+ tokenizer: PreTrainedTokenizerBase,
+ mode: GuidedDecodingMode,
+ whitespace_pattern: Union[str, None],
+ reasoner: Optional[Reasoner],
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor]:
if mode == GuidedDecodingMode.JSON:
- return JSONLogitsProcessor(guide, tokenizer, whitespace_pattern)
+ return JSONLogitsProcessor(guide, tokenizer, whitespace_pattern,
+ reasoner)
elif mode == GuidedDecodingMode.REGEX or mode == GuidedDecodingMode.CHOICE:
- return RegexLogitsProcessor(guide, tokenizer)
+ return RegexLogitsProcessor(guide, tokenizer, reasoner)
elif mode == GuidedDecodingMode.GRAMMAR:
- return CFGLogitsProcessor(guide, tokenizer)
+ return CFGLogitsProcessor(guide, tokenizer, reasoner)
else:
raise ValueError(f"Unknown guided decoding mode {mode}")
diff --git a/vllm/model_executor/guided_decoding/outlines_logits_processors.py b/vllm/model_executor/guided_decoding/outlines_logits_processors.py
index a05267d921d1..db5d738f42e4 100644
--- a/vllm/model_executor/guided_decoding/outlines_logits_processors.py
+++ b/vllm/model_executor/guided_decoding/outlines_logits_processors.py
@@ -19,7 +19,7 @@
import json
from collections import defaultdict
from functools import lru_cache
-from typing import Callable, DefaultDict, Dict, List, Union
+from typing import Callable, DefaultDict, Dict, List, Optional, Union
import numpy as np
import torch
@@ -32,13 +32,18 @@
from pydantic import BaseModel
from transformers import PreTrainedTokenizerBase
+from vllm.logger import init_logger
+from vllm.model_executor.guided_decoding.reasoner import Reasoner
from vllm.platforms import current_platform
+logger = init_logger(__name__)
+
class BaseLogitsProcessor:
- def __init__(self, guide: Guide):
+ def __init__(self, guide: Guide, reasoner: Optional[Reasoner]):
self._guide: Guide = guide
+ self._reasoner = reasoner
# CFGState is used for the FSM state for CFGGuide
self._fsm_state: DefaultDict[int, Union[int,
CFGState]] = defaultdict(int)
@@ -46,6 +51,14 @@ def __init__(self, guide: Guide):
def __call__(self, input_ids: List[int],
scores: torch.Tensor) -> torch.Tensor:
"""Use the FSM to bias the logits before sampling the next token."""
+
+ # Skip the structured logits processing if reasoning is not finished.
+ # reasoner is not None only when `--enable-reasoning` is set.
+ if self._reasoner is not None and \
+ not self._reasoner.is_reasoning_end(
+ input_ids):
+ return scores
+
seq_id = hash(tuple(input_ids))
if len(input_ids) > 0:
@@ -113,7 +126,12 @@ def _get_guide(cls, regex_string: str,
tokenizer = _adapt_tokenizer(tokenizer)
return RegexGuide.from_regex(regex_string, tokenizer)
- def __init__(self, regex_string: str, tokenizer: PreTrainedTokenizerBase):
+ def __init__(
+ self,
+ regex_string: str,
+ tokenizer: PreTrainedTokenizerBase,
+ reasoner: Optional[Reasoner],
+ ):
"""Compile the FSM that drives the regex-structured generation.
Parameters
@@ -125,14 +143,15 @@ def __init__(self, regex_string: str, tokenizer: PreTrainedTokenizerBase):
"""
super().__init__(
- RegexLogitsProcessor._get_guide(regex_string, tokenizer))
+ RegexLogitsProcessor._get_guide(regex_string, tokenizer), reasoner)
class JSONLogitsProcessor(RegexLogitsProcessor):
def __init__(self, schema: Union[str, Dict, BaseModel],
tokenizer: PreTrainedTokenizerBase,
- whitespace_pattern: Union[str, None]):
+ whitespace_pattern: Union[str, None],
+ reasoner: Optional[Reasoner]):
"""Compile the FSM that drives the JSON-guided generation.
Parameters
@@ -160,7 +179,7 @@ def __init__(self, schema: Union[str, Dict, BaseModel],
f"a Pydantic object, a dictionary or a string that contains "
f"the JSON Schema specification")
regex_string = build_regex_from_schema(schema_str, whitespace_pattern)
- super().__init__(regex_string, tokenizer)
+ super().__init__(regex_string, tokenizer, reasoner)
class CFGLogitsProcessor(BaseLogitsProcessor):
@@ -171,7 +190,8 @@ def _get_guide(cls, cfg: str, tokenizer: PreTrainedTokenizerBase) -> Guide:
tokenizer = _adapt_tokenizer(tokenizer)
return CFGGuide(cfg, tokenizer)
- def __init__(self, cfg: str, tokenizer: PreTrainedTokenizerBase):
+ def __init__(self, cfg: str, tokenizer: PreTrainedTokenizerBase,
+ reasoner: Optional[Reasoner]):
"""Compile the FSM that drives the context free grammar generation.
Parameters
@@ -182,7 +202,8 @@ def __init__(self, cfg: str, tokenizer: PreTrainedTokenizerBase):
The model's tokenizer
"""
- super().__init__(CFGLogitsProcessor._get_guide(cfg, tokenizer))
+ super().__init__(CFGLogitsProcessor._get_guide(cfg, tokenizer),
+ reasoner)
self._guide = self._guide.copy()
diff --git a/vllm/model_executor/guided_decoding/reasoner/__init__.py b/vllm/model_executor/guided_decoding/reasoner/__init__.py
new file mode 100644
index 000000000000..5a91f791d45b
--- /dev/null
+++ b/vllm/model_executor/guided_decoding/reasoner/__init__.py
@@ -0,0 +1,23 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+from transformers import PreTrainedTokenizer
+
+from vllm.model_executor.guided_decoding.reasoner.deepseek_reasoner import ( # noqa: E501
+ DeepSeekReasoner)
+from vllm.model_executor.guided_decoding.reasoner.reasoner import Reasoner
+
+
+def get_reasoner(tokenizer: PreTrainedTokenizer,
+ reasoning_backend: str | None) -> Reasoner | None:
+ if reasoning_backend is None:
+ # No reasoning backend specified
+ return None
+ elif reasoning_backend == "deepseek_r1":
+ return DeepSeekReasoner.from_tokenizer(tokenizer)
+ else:
+ raise ValueError(f"Unknown reasoning backend '{reasoning_backend}'")
+
+
+__all__ = ["Reasoner", "get_reasoner"]
diff --git a/vllm/model_executor/guided_decoding/reasoner/deepseek_reasoner.py b/vllm/model_executor/guided_decoding/reasoner/deepseek_reasoner.py
new file mode 100644
index 000000000000..e762fb0659de
--- /dev/null
+++ b/vllm/model_executor/guided_decoding/reasoner/deepseek_reasoner.py
@@ -0,0 +1,28 @@
+# SPDX-License-Identifier: Apache-2.0
+from dataclasses import dataclass
+
+from transformers import PreTrainedTokenizer
+
+from vllm.model_executor.guided_decoding.reasoner.reasoner import Reasoner
+
+
+@dataclass
+class DeepSeekReasoner(Reasoner):
+ """
+ Reasoner for DeepSeek R series models.
+ """
+ start_token_id: int
+ end_token_id: int
+
+    start_token: str = "<think>"
+    end_token: str = "</think>"
+
+ @classmethod
+ def from_tokenizer(cls, tokenizer: PreTrainedTokenizer) -> Reasoner:
+ return cls(start_token_id=tokenizer.encode(
+            "<think>", add_special_tokens=False)[0],
+                   end_token_id=tokenizer.encode("</think>",
+ add_special_tokens=False)[0])
+
+ def is_reasoning_end(self, input_ids: list[int]) -> bool:
+ return self.end_token_id in input_ids
diff --git a/vllm/model_executor/guided_decoding/reasoner/reasoner.py b/vllm/model_executor/guided_decoding/reasoner/reasoner.py
new file mode 100644
index 000000000000..5db0c9bc7850
--- /dev/null
+++ b/vllm/model_executor/guided_decoding/reasoner/reasoner.py
@@ -0,0 +1,19 @@
+# SPDX-License-Identifier: Apache-2.0
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+
+from transformers import PreTrainedTokenizer
+
+
+@dataclass
+class Reasoner(ABC):
+
+ @abstractmethod
+ def from_tokenizer(cls, tokenizer: PreTrainedTokenizer) -> Reasoner:
+ pass
+
+ @abstractmethod
+ def is_reasoning_end(self, input_ids: list[int]) -> bool:
+ pass
diff --git a/vllm/model_executor/guided_decoding/xgrammar_decoding.py b/vllm/model_executor/guided_decoding/xgrammar_decoding.py
index eb9d83acb286..ce278c15ab3b 100644
--- a/vllm/model_executor/guided_decoding/xgrammar_decoding.py
+++ b/vllm/model_executor/guided_decoding/xgrammar_decoding.py
@@ -11,6 +11,8 @@
import torch
from transformers import PreTrainedTokenizerFast
+from vllm.logger import init_logger
+
try:
import xgrammar as xgr
from xgrammar.base import _core as xgr_core
@@ -19,7 +21,6 @@
xgr_installed = False
pass
-from vllm.logger import init_logger
from vllm.model_executor.guided_decoding.utils import (convert_lark_to_gbnf,
grammar_is_likely_lark)
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
@@ -28,6 +29,7 @@
from transformers import PreTrainedTokenizer
from vllm.config import ModelConfig
+ from vllm.model_executor.guided_decoding.reasoner import Reasoner
from vllm.sampling_params import GuidedDecodingParams
logger = init_logger(__name__)
@@ -38,12 +40,13 @@ def get_local_xgrammar_guided_decoding_logits_processor(
guided_params: GuidedDecodingParams,
tokenizer: PreTrainedTokenizer,
model_config: ModelConfig,
+ reasoner: Reasoner | None,
max_threads: int = 8):
config = GrammarConfig.from_guided_params(guided_params=guided_params,
model_config=model_config,
tokenizer=tokenizer,
max_threads=max_threads)
- return XGrammarLogitsProcessor(config)
+ return XGrammarLogitsProcessor(config, reasoner)
@dataclass(frozen=True)
@@ -293,6 +296,7 @@ def choice_as_grammar(choice: List[str] | None) -> str:
class XGrammarLogitsProcessor:
"""Wrapper class to support pickle protocol"""
config: GrammarConfig
+ reasoner: Reasoner | None = None
ctx: xgr.CompiledGrammar | None = None
token_bitmask: torch.Tensor = None # type: ignore[assignment]
@@ -301,10 +305,11 @@ class XGrammarLogitsProcessor:
prefilled: bool = field(default=False)
def __getstate__(self) -> dict[str, Any]:
- return {'config': self.config}
+ return {'config': self.config, 'reasoner': self.reasoner}
def __setstate__(self, state: dict[str, Any]):
self.config = state['config']
+ self.reasoner = state['reasoner']
self.ctx = None
self.matchers = []
@@ -331,6 +336,14 @@ def _ensure_ctx(self):
def __call__(self, input_ids: list[int],
scores: torch.Tensor) -> torch.Tensor:
+
+ # Skip the structured logits processing if reasoning is not finished.
+ # reasoner is not None only when `--enable-reasoning` is set.
+ if self.reasoner is not None and \
+ not self.reasoner.is_reasoning_end(
+ input_ids):
+ return scores
+
if self.ctx is None:
self._ensure_ctx()