43 changes: 35 additions & 8 deletions docs/source/features/reasoning_outputs.md
@@ -76,7 +76,13 @@ Streaming chat completions are also supported for reasoning models. The `reasoni
}
```

Please note that it is not compatible with the OpenAI Python client library. You can use the `requests` library to make streaming requests.
Please note that it is not compatible with the OpenAI Python client library. You can use the `requests` library to make streaming requests. You can check out the [example](https://github.com/vllm-project/vllm/blob/main/examples/online_serving/openai_chat_completion_with_reasoning_streaming.py).
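
The sketch below shows one way to do this with `requests`; it is a minimal sketch that assumes a local server on the default port, the DeepSeek-R1-Distill model used elsewhere in these docs, and SSE chunks whose `delta` carries `reasoning_content` and `content` as described above:

```python
# Minimal sketch: stream a chat completion with `requests` and print the
# reasoning content as it arrives. Assumes a vLLM server on localhost:8000
# serving a reasoning model with the reasoning parser enabled.
import json

import requests

response = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
        "messages": [{"role": "user", "content": "What is 9.11 minus 9.8?"}],
        "stream": True,
    },
    stream=True,
)

for line in response.iter_lines():
    if not line:
        continue
    chunk = line.decode("utf-8").removeprefix("data: ")
    if chunk == "[DONE]":
        break
    delta = json.loads(chunk)["choices"][0]["delta"]
    # Reasoning tokens arrive in `reasoning_content`; the final answer in `content`.
    if delta.get("reasoning_content"):
        print(delta["reasoning_content"], end="", flush=True)
    elif delta.get("content"):
        print(delta["content"], end="", flush=True)
```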

## Limitations

- The reasoning content is only available for online serving's chat completion endpoint (`/v1/chat/completions`).
- It is not compatible with [`tool_calling`](#tool_calling).
- The reasoning content is not available for all models. Check the model's documentation to see if it supports reasoning.

## How to support a new reasoning model

@@ -137,15 +143,36 @@ class ExampleParser(ReasoningParser):
"""
```

After defining the reasoning parser, you can use it by specifying the `--reasoning-parser` flag when making a request to the chat completion endpoint.
Additionally, to enable structured output, you'll need to create a new `Reasoner` similar to the one in `vllm/model_executor/guided_decoding/reasoner/deepseek_reasoner.py`.

```python
@dataclass
class DeepSeekReasoner(Reasoner):
    """
    Reasoner for DeepSeek R series models.
    """
    start_token_id: int
    end_token_id: int

    start_token: str = "<think>"
    end_token: str = "</think>"

    @classmethod
    def from_tokenizer(cls, tokenizer: PreTrainedTokenizer) -> Reasoner:
        return cls(start_token_id=tokenizer.encode(
            "<think>", add_special_tokens=False)[0],
                   end_token_id=tokenizer.encode("</think>",
                                                 add_special_tokens=False)[0])

    def is_reasoning_end(self, input_ids: list[int]) -> bool:
        return self.end_token_id in input_ids
```

A structured output engine such as xgrammar uses `end_token_id` to check whether the model is still producing reasoning content, and skips applying the structured output constraints until the reasoning section has ended.
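
For illustration only, here is a hedged sketch of how such a check could gate a guided-decoding logits processor; `ReasoningAwareLogitsProcessor` and `apply_guidance` are hypothetical names used for the example, not part of vLLM:

```python
# Hypothetical sketch: skip guided decoding until the reasoning section ends.
# `reasoner` is assumed to expose `is_reasoning_end`, as in the class above;
# `apply_guidance` stands in for the backend-specific (e.g. xgrammar) logic.
import torch


class ReasoningAwareLogitsProcessor:

    def __init__(self, reasoner, apply_guidance):
        self.reasoner = reasoner
        self.apply_guidance = apply_guidance

    def __call__(self, input_ids: list[int],
                 scores: torch.Tensor) -> torch.Tensor:
        # While the model is still thinking, leave the logits untouched so
        # the reasoning text is not forced into the JSON schema.
        if not self.reasoner.is_reasoning_end(input_ids):
            return scores
        # Reasoning is over: constrain the final answer as usual.
        return self.apply_guidance(input_ids, scores)
```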

Finally, you can enable reasoning for the model by using the `--enable-reasoning` and `--reasoning-parser` flags.

```bash
vllm serve <model_tag> \
    --enable-reasoning --reasoning-parser example
```

## Limitations

- The reasoning content is only available for online serving's chat completion endpoint (`/v1/chat/completions`).
- It is not compatible with the [`structured_outputs`](#structured_outputs) and [`tool_calling`](#tool_calling) features.
- The reasoning content is not available for all models. Check the model's documentation to see if it supports reasoning.
@@ -0,0 +1,64 @@
# SPDX-License-Identifier: Apache-2.0
"""
An example showing how to generate structured outputs from reasoning models
like DeepSeek R1. The thinking process will not be guided by the JSON
schema provided by the user. Only the final output will be structured.

To run this example, you need to start the vLLM server with the reasoning
parser:

```bash
vllm serve deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B \
    --enable-reasoning --reasoning-parser deepseek_r1
```

This example demonstrates how to generate chat completions from reasoning models
using the OpenAI Python client library.
"""

from enum import Enum

from openai import OpenAI
from pydantic import BaseModel

# Modify OpenAI's API key and API base to use vLLM's API server.
openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1"

client = OpenAI(
    api_key=openai_api_key,
    base_url=openai_api_base,
)

models = client.models.list()
model = models.data[0].id


# Guided decoding by JSON using Pydantic schema
class CarType(str, Enum):
    sedan = "sedan"
    suv = "SUV"
    truck = "Truck"
    coupe = "Coupe"


class CarDescription(BaseModel):
    brand: str
    model: str
    car_type: CarType


json_schema = CarDescription.model_json_schema()

prompt = ("Generate a JSON with the brand, model and car_type of"
"the most iconic car from the 90's, think in 100 tokens")
completion = client.chat.completions.create(
    model=model,
    messages=[{
        "role": "user",
        "content": prompt,
    }],
    extra_body={"guided_json": json_schema},
)
print("content", completion.choices[0].message.content)
print("reasoning_content: ", completion.choices[0].message.reasoning_content)
116 changes: 102 additions & 14 deletions tests/model_executor/test_guided_processors.py
@@ -16,25 +16,41 @@

MODEL_NAME = 'HuggingFaceH4/zephyr-7b-beta'
GUIDED_DECODING_BACKENDS = ["outlines", "lm-format-enforcer", "xgrammar"]
GUIDED_DECODING_BACKENDS_WITH_REASONING_SUPPORT = ["outlines", "xgrammar"]
REASONING_MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"


def test_guided_logits_processors(sample_regex, sample_json_schema):
# Initialize the tokenizer for the model here to avoid repeated loading
@pytest.fixture(scope="module")
def zephyr_7B_tokenzer():
return AutoTokenizer.from_pretrained(MODEL_NAME)


@pytest.fixture(scope="module")
def deepseek_r1_qwen_tokenizer():
return AutoTokenizer.from_pretrained(REASONING_MODEL_NAME)


def test_guided_logits_processors(zephyr_7B_tokenzer, sample_regex,
sample_json_schema):
"""Basic unit test for RegexLogitsProcessor and JSONLogitsProcessor."""
tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta')
regex_LP = RegexLogitsProcessor(sample_regex, tokenizer)
regex_LP = RegexLogitsProcessor(sample_regex,
zephyr_7B_tokenzer,
reasoner=None)
json_LP = JSONLogitsProcessor(sample_json_schema,
tokenizer,
whitespace_pattern=None)
zephyr_7B_tokenzer,
whitespace_pattern=None,
reasoner=None)

token_ids = tokenizer.encode(
token_ids = zephyr_7B_tokenzer.encode(
f"Give an example IPv4 address with this regex: {sample_regex}")
tensor = torch.rand(32000)
original_tensor = torch.clone(tensor)
regex_LP(token_ids, tensor)
assert tensor.shape == original_tensor.shape
assert not torch.allclose(tensor, original_tensor)

token_ids = tokenizer.encode(
token_ids = zephyr_7B_tokenzer.encode(
f"Give an employee profile that fits this schema: {sample_json_schema}"
)
tensor = torch.rand(32000)
@@ -49,7 +65,8 @@ def test_guided_logits_processors(sample_regex, sample_json_schema):
@pytest.mark.parametrize("is_local", [True, False])
async def test_guided_logits_processor_black_box(backend: str, is_local: bool,
sample_regex,
sample_json_schema):
sample_json_schema,
zephyr_7B_tokenzer):

config = ModelConfig(
MODEL_NAME,
@@ -60,29 +77,100 @@ async def test_guided_logits_processor_black_box(backend: str, is_local: bool,
seed=0,
dtype="bfloat16",
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
token_ids = tokenizer.encode(
token_ids = zephyr_7B_tokenzer.encode(
f"Give an example IPv4 address with this regex: {sample_regex}")
regex_request = GuidedDecodingParams(regex=sample_regex, backend=backend)

regex_lp = get_local_guided_decoding_logits_processor(
regex_request, tokenizer, config) if is_local else \
regex_request, zephyr_7B_tokenzer, config) if is_local else \
await get_guided_decoding_logits_processor(
regex_request, tokenizer, config)
regex_request, zephyr_7B_tokenzer, config)
assert regex_lp is not None
tensor = torch.rand(32000)
original_tensor = torch.clone(tensor)
tensor = regex_lp(token_ids, tensor)
assert tensor.shape == original_tensor.shape
assert not torch.allclose(tensor, original_tensor)

token_ids = tokenizer.encode(
token_ids = zephyr_7B_tokenzer.encode(
f"Give an employee profile that fits this schema: {sample_json_schema}"
)
json_request = GuidedDecodingParams(json=sample_json_schema,
backend=backend)
json_lp = await get_guided_decoding_logits_processor(
json_request, tokenizer, config)
json_request, zephyr_7B_tokenzer, config)
assert json_lp is not None
tensor = torch.rand(32000)
original_tensor = torch.clone(tensor)
tensor = json_lp(token_ids, tensor)
assert tensor.shape == original_tensor.shape
assert not torch.allclose(tensor, original_tensor)


@pytest.mark.asyncio
@pytest.mark.parametrize("backend",
GUIDED_DECODING_BACKENDS_WITH_REASONING_SUPPORT)
@pytest.mark.parametrize("is_local", [True, False])
@pytest.mark.parametrize("reasoning_backend", ["deepseek_r1"])
async def test_guided_logits_processor_with_reasoning(
backend: str, is_local: bool, reasoning_backend: str, sample_regex,
sample_json_schema, deepseek_r1_qwen_tokenizer):

config = ModelConfig(
REASONING_MODEL_NAME,
task="generate",
tokenizer=REASONING_MODEL_NAME,
tokenizer_mode="auto",
trust_remote_code=False,
seed=0,
dtype="bfloat16",
)
token_ids = deepseek_r1_qwen_tokenizer.encode(
f"Give an example IPv4 address with this regex: {sample_regex}."
"<think>here is the thinking process")
regex_request = GuidedDecodingParams(regex=sample_regex, backend=backend)

regex_lp = get_local_guided_decoding_logits_processor(regex_request,
deepseek_r1_qwen_tokenizer, config,
reasoning_backend) if is_local else \
await get_guided_decoding_logits_processor(
regex_request, deepseek_r1_qwen_tokenizer, config,
reasoning_backend)
assert regex_lp is not None
tensor = torch.rand(32000)
original_tensor = torch.clone(tensor)
tensor = regex_lp(token_ids, tensor)
assert tensor.shape == original_tensor.shape
assert torch.allclose(tensor, original_tensor)

token_ids = deepseek_r1_qwen_tokenizer.encode(
f"Give an employee profile that fits this schema: {sample_json_schema}."
"<think>here is the thinking process")
json_request = GuidedDecodingParams(json=sample_json_schema,
backend=backend)
json_lp = get_local_guided_decoding_logits_processor(
json_request, deepseek_r1_qwen_tokenizer, config,
reasoning_backend) if is_local else \
await get_guided_decoding_logits_processor(
json_request, deepseek_r1_qwen_tokenizer, config, reasoning_backend)
assert json_lp is not None
tensor = torch.rand(32000)
original_tensor = torch.clone(tensor)
tensor = json_lp(token_ids, tensor)
assert tensor.shape == original_tensor.shape
assert torch.allclose(tensor, original_tensor)

# Thinking is over, so the tensor should change.
token_ids = deepseek_r1_qwen_tokenizer.encode(
f"Give an employee profile that fits this schema: {sample_json_schema}."
"<think>here is the thinking process</think> Then")
json_request = GuidedDecodingParams(json=sample_json_schema,
backend=backend)
json_lp = get_local_guided_decoding_logits_processor(
json_request, deepseek_r1_qwen_tokenizer, config,
reasoning_backend) if is_local else \
await get_guided_decoding_logits_processor(
json_request, deepseek_r1_qwen_tokenizer, config, reasoning_backend)
assert json_lp is not None
tensor = torch.rand(32000)
original_tensor = torch.clone(tensor)
2 changes: 2 additions & 0 deletions vllm/config.py
@@ -2715,6 +2715,8 @@ class DecodingConfig:
    # 'outlines' / 'lm-format-enforcer' / 'xgrammar'
    guided_decoding_backend: str = 'xgrammar'

    reasoning_backend: Optional[str] = None

    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
26 changes: 25 additions & 1 deletion vllm/engine/arg_utils.py
Expand Up @@ -213,6 +213,8 @@ class EngineArgs:
calculate_kv_scales: Optional[bool] = None

additional_config: Optional[Dict[str, Any]] = None
enable_reasoning: Optional[bool] = None
reasoning_parser: Optional[str] = None

def __post_init__(self):
if not self.tokenizer:
@@ -1059,6 +1061,25 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
            "Different platforms may support different configs. Make sure the "
            "configs are valid for the platform you are using. The input format"
            " is like '{\"config_key\":\"config_value\"}'")

        parser.add_argument(
            "--enable-reasoning",
            action="store_true",
            default=False,
            help="Whether to enable reasoning_content for the model. "
            "If enabled, the model will be able to generate reasoning content."
        )

        parser.add_argument(
            "--reasoning-parser",
            type=str,
            choices=["deepseek_r1"],
            default=None,
            help=
            "Select the reasoning parser depending on the model that you're "
            "using. This is used to parse the reasoning content into OpenAI "
            "API format. Required for ``--enable-reasoning``.")

        return parser

    @classmethod
@@ -1332,7 +1353,10 @@ def create_engine_config(self,
if self.enable_prompt_adapter else None

decoding_config = DecodingConfig(
guided_decoding_backend=self.guided_decoding_backend)
guided_decoding_backend=self.guided_decoding_backend,
reasoning_backend=self.reasoning_parser
if self.enable_reasoning else None,
)

show_hidden_metrics = False
if self.show_hidden_metrics_for_version is not None:
11 changes: 8 additions & 3 deletions vllm/engine/async_llm_engine.py
@@ -509,6 +509,7 @@ async def add_request_async(
tokenizer=await self.get_tokenizer_async(lora_request),
default_guided_backend=self.decoding_config.
guided_decoding_backend,
reasoning_backend=self.decoding_config.reasoning_backend,
model_config=self.model_config)

self._add_processed_request(
@@ -530,7 +531,7 @@ async def check_health_async(self) -> None:

async def build_guided_decoding_logits_processor_async(
sampling_params: SamplingParams, tokenizer: AnyTokenizer,
default_guided_backend: str,
default_guided_backend: str, reasoning_backend: Optional[str],
model_config: ModelConfig) -> SamplingParams:
"""Constructs logits processors based on the guided_decoding,
logits_bias, and allowed_token_ids fields in sampling_params. Deletes
@@ -545,14 +546,18 @@ async def build_guided_decoding_logits_processor_async(
sampling_params = copy.copy(sampling_params)
guided_decoding = sampling_params.guided_decoding

logger.debug("Building guided decoding logits processor. "
"Params: %s", guided_decoding)
logger.info(
"Building guided decoding logits processor. "
"guided_decoding: %s%s", guided_decoding,
f", reasoning_backend: {reasoning_backend}"
if reasoning_backend is not None else "")

guided_decoding.backend = guided_decoding.backend or default_guided_backend

processor = await get_guided_decoding_logits_processor(
guided_params=guided_decoding,
tokenizer=tokenizer,
reasoning_backend=reasoning_backend,
model_config=model_config)

if processor: