Commit 654bc5c

Support for guided decoding for offline LLM (#6878)
Co-authored-by: Cyrus Leung <[email protected]>
1 parent 825b044 commit 654bc5c

File tree

9 files changed: +352 −12 lines


docs/source/conf.py

Lines changed: 1 addition & 0 deletions
@@ -111,6 +111,7 @@ def setup(app):
     "tqdm",
     "tensorizer",
     "pynvml",
+    "outlines",
 ]
 
 for mock_target in autodoc_mock_imports:

tests/entrypoints/openai/conftest.py renamed to tests/entrypoints/conftest.py

Lines changed: 21 additions & 1 deletion
@@ -1,6 +1,26 @@
 import pytest
 
 
+@pytest.fixture
+def sample_prompts():
+    return [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+        "The future of AI is",
+    ]
+
+
+@pytest.fixture
+def sample_token_ids():
+    return [
+        [0],
+        [0, 1],
+        [0, 2, 1],
+        [0, 3, 1, 2],
+    ]
+
+
 @pytest.fixture
 def sample_regex():
     return (r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}"
@@ -66,4 +86,4 @@ def sample_sql_statements():
 table: "table_1" | "table_2"
 condition: column "=" number
 number: "1" | "2"
-""")
+""")

Lines changed: 142 additions & 0 deletions

@@ -0,0 +1,142 @@
+import json
+import re
+import weakref
+
+import jsonschema
+import pytest
+
+from vllm.entrypoints.llm import LLM
+from vllm.outputs import RequestOutput
+from vllm.sampling_params import SamplingParams
+
+from ...conftest import cleanup
+
+MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
+
+
+@pytest.fixture(scope="module")
+def llm():
+    # pytest caches the fixture so we use weakref.proxy to
+    # enable garbage collection
+    llm = LLM(model=MODEL_NAME, max_model_len=1024)
+
+    with llm.deprecate_legacy_api():
+        yield weakref.proxy(llm)
+        del llm
+    cleanup()
+
+
+@pytest.mark.skip_global_cleanup
+def test_guided_regex(sample_regex, llm):
+    sampling_params = SamplingParams(
+        temperature=0.8,
+        top_p=0.95,
+    )
+    outputs = llm.generate(
+        prompts=[
+            f"Give an example IPv4 address with this regex: {sample_regex}"
+        ] * 2,
+        sampling_params=sampling_params,
+        use_tqdm=True,
+        guided_options_request=dict(guided_regex=sample_regex))
+
+    assert outputs is not None
+    for output in outputs:
+        assert output is not None
+        assert isinstance(output, RequestOutput)
+        prompt = output.prompt
+        generated_text = output.outputs[0].text
+        print(generated_text)
+        assert generated_text is not None
+        assert re.fullmatch(sample_regex, generated_text) is not None
+        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+
+
+@pytest.mark.skip_global_cleanup
+def test_guided_json_completion(sample_json_schema, llm):
+    sampling_params = SamplingParams(
+        temperature=1.0,
+        max_tokens=1000,
+    )
+    outputs = llm.generate(
+        prompts=[
+            f"Give an example JSON for an employee profile "
+            f"that fits this schema: {sample_json_schema}"
+        ] * 2,
+        sampling_params=sampling_params,
+        use_tqdm=True,
+        guided_options_request=dict(guided_json=sample_json_schema))
+
+    assert outputs is not None
+
+    for output in outputs:
+        assert output is not None
+        assert isinstance(output, RequestOutput)
+        prompt = output.prompt
+
+        generated_text = output.outputs[0].text
+        assert generated_text is not None
+        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+        output_json = json.loads(generated_text)
+        jsonschema.validate(instance=output_json, schema=sample_json_schema)
+
+
+@pytest.mark.skip_global_cleanup
+def test_guided_choice_completion(sample_guided_choice, llm):
+    sampling_params = SamplingParams(
+        temperature=0.8,
+        top_p=0.95,
+    )
+    outputs = llm.generate(
+        prompts="The best language for type-safe systems programming is ",
+        sampling_params=sampling_params,
+        use_tqdm=True,
+        guided_options_request=dict(guided_choice=sample_guided_choice))
+
+    assert outputs is not None
+    for output in outputs:
+        assert output is not None
+        assert isinstance(output, RequestOutput)
+        prompt = output.prompt
+        generated_text = output.outputs[0].text
+        print(generated_text)
+        assert generated_text is not None
+        assert generated_text in sample_guided_choice
+        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+
+
+@pytest.mark.skip_global_cleanup
+def test_guided_grammar(sample_sql_statements, llm):
+
+    sampling_params = SamplingParams(
+        temperature=0.8,
+        top_p=0.95,
+        max_tokens=1000,
+    )
+    outputs = llm.generate(
+        prompts=("Generate a sql state that select col_1 from "
+                 "table_1 where it is equals to 1"),
+        sampling_params=sampling_params,
+        use_tqdm=True,
+        guided_options_request=dict(guided_grammar=sample_sql_statements))
+
+    assert outputs is not None
+    for output in outputs:
+        assert output is not None
+        assert isinstance(output, RequestOutput)
+        prompt = output.prompt
+
+        generated_text = output.outputs[0].text
+        assert generated_text is not None
+        # use Lark to parse the output, and make sure it's a valid parse tree
+        from lark import Lark
+        parser = Lark(sample_sql_statements)
+        parser.parse(generated_text)
+
+        # remove spaces for comparison b/c we removed them in the grammar
+        ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace(
+            " ", "")
+
+        assert generated_text.strip() == ground_truth
+
+        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

vllm/entrypoints/llm.py

Lines changed: 43 additions & 1 deletion
@@ -10,6 +10,9 @@
                          parse_and_batch_prompt)
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
+from vllm.model_executor.guided_decoding import (
+    GuidedDecodingRequest, get_local_guided_decoding_logits_processor)
+from vllm.model_executor.guided_decoding.guided_fields import LLMGuidedOptions
 from vllm.outputs import EmbeddingRequestOutput, RequestOutput
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
@@ -262,6 +265,8 @@ def generate(
         use_tqdm: bool = True,
         lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        guided_options_request: Optional[Union[LLMGuidedOptions,
+                                               GuidedDecodingRequest]] = None
     ) -> List[RequestOutput]:
         """Generates the completions for the input prompts.
 
@@ -303,6 +308,14 @@
         else:
             inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts)
 
+        if isinstance(guided_options_request, dict):
+            if len(guided_options_request) > 1:
+                raise ValueError(
+                    "You can only use one guided decoding but multiple is "
+                    f"specified: {guided_options_request}")
+            guided_options_request = GuidedDecodingRequest(
+                **guided_options_request)
+
         if sampling_params is None:
             # Use default sampling params.
             sampling_params = SamplingParams()
@@ -311,7 +324,8 @@
             inputs=inputs,
             params=sampling_params,
             lora_request=lora_request,
-            prompt_adapter_request=prompt_adapter_request)
+            prompt_adapter_request=prompt_adapter_request,
+            guided_options=guided_options_request)
 
         outputs = self._run_engine(use_tqdm=use_tqdm)
         return LLMEngine.validate_outputs(outputs, RequestOutput)
@@ -508,6 +522,7 @@ def _validate_and_add_requests(
                       Sequence[PoolingParams]],
         lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
         prompt_adapter_request: Optional[PromptAdapterRequest],
+        guided_options: Optional[GuidedDecodingRequest] = None,
     ) -> None:
         if isinstance(inputs, (str, dict)):
             # Convert a single prompt to a list.
@@ -523,6 +538,15 @@
             raise ValueError("The lengths of prompts and lora_request "
                              "must be the same.")
 
+        if isinstance(params, list):
+            params = [
+                self._add_guided_processor(param, guided_options)
+                if isinstance(param, SamplingParams) else param
+                for param in params
+            ]
+        elif isinstance(params, SamplingParams):
+            params = self._add_guided_processor(params, guided_options)
+
         # Add requests to the engine.
         for i, request_inputs in enumerate(inputs):
             self._add_request(
@@ -548,6 +572,24 @@
             lora_request=lora_request,
             prompt_adapter_request=prompt_adapter_request)
 
+    def _add_guided_processor(
+            self,
+            params: SamplingParams,
+            guided_options: Optional[GuidedDecodingRequest] = None):
+        if guided_options:
+            if guided_options.guided_decoding_backend is None:
+                decoding_config = self.llm_engine.get_decoding_config()
+                guided_options.guided_decoding_backend = (
+                    decoding_config.guided_decoding_backend)
+            guided_logits_processor = get_local_guided_decoding_logits_processor( #noqa
+                guided_options.guided_decoding_backend, guided_options,
+                self.get_tokenizer())
+            if guided_logits_processor:
+                if params.logits_processors is None:
+                    params.logits_processors = []
+                params.logits_processors.append(guided_logits_processor)
+        return params
+
     def _run_engine(
             self, *, use_tqdm: bool
     ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
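
Taken together, LLM.generate now accepts guided decoding options for the offline API: a plain dict is normalized into a GuidedDecodingRequest (exactly one guided_* field may be set), and _add_guided_processor attaches the backend's logits processor to each SamplingParams. A minimal sketch of the calling convention; the model and choice list are illustrative:

from vllm import LLM, SamplingParams

llm = LLM(model="HuggingFaceH4/zephyr-7b-beta", max_model_len=1024)  # illustrative

# A single guided option constrains decoding to the listed choices.
outputs = llm.generate(
    prompts="The best language for type-safe systems programming is ",
    sampling_params=SamplingParams(temperature=0.8, top_p=0.95),
    guided_options_request=dict(guided_choice=["Rust", "Haskell", "C++"]))

# Passing more than one guided_* key in the dict raises ValueError
# ("You can only use one guided decoding but multiple is specified: ...").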

vllm/entrypoints/openai/protocol.py

Lines changed: 20 additions & 6 deletions
@@ -1,6 +1,7 @@
 # Adapted from
 # https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
 import time
+from argparse import Namespace
 from typing import Any, Dict, List, Literal, Optional, Union
 
 import torch
@@ -14,6 +15,23 @@
 from vllm.sampling_params import LogitsProcessor, SamplingParams
 from vllm.utils import random_uuid
 
+# torch is mocked during docs generation,
+# so we have to provide the values as literals
+_MOCK_LONG_INFO = Namespace(min=-9223372036854775808, max=9223372036854775807)
+
+try:
+    from sphinx.ext.autodoc.mock import _MockModule
+
+    if isinstance(torch, _MockModule):
+        _LONG_INFO = _MOCK_LONG_INFO
+    else:
+        _LONG_INFO = torch.iinfo(torch.long)
+except ModuleNotFoundError:
+    _LONG_INFO = torch.iinfo(torch.long)
+
+assert _LONG_INFO.min == _MOCK_LONG_INFO.min
+assert _LONG_INFO.max == _MOCK_LONG_INFO.max
+
 
 class OpenAIBaseModel(BaseModel):
     # OpenAI API does not allow extra fields
@@ -108,9 +126,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
     n: Optional[int] = 1
     presence_penalty: Optional[float] = 0.0
     response_format: Optional[ResponseFormat] = None
-    seed: Optional[int] = Field(None,
-                                ge=torch.iinfo(torch.long).min,
-                                le=torch.iinfo(torch.long).max)
+    seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
     stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
     stream: Optional[bool] = False
     stream_options: Optional[StreamOptions] = None
@@ -327,9 +343,7 @@ class CompletionRequest(OpenAIBaseModel):
     max_tokens: Optional[int] = 16
     n: int = 1
     presence_penalty: Optional[float] = 0.0
-    seed: Optional[int] = Field(None,
-                                ge=torch.iinfo(torch.long).min,
-                                le=torch.iinfo(torch.long).max)
+    seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
     stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
     stream: Optional[bool] = False
     stream_options: Optional[StreamOptions] = None
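
The seed bounds are unchanged in effect; they are now read from _LONG_INFO so the docs build (where torch is mocked) can still import this module. A quick check of the resulting validation; the model and prompt values are illustrative:

import pydantic

from vllm.entrypoints.openai.protocol import CompletionRequest

# Within the signed 64-bit range: accepted.
CompletionRequest(model="any-model", prompt="hi", seed=2**63 - 1)

# Outside the range: rejected by the ge/le bounds on `seed`.
try:
    CompletionRequest(model="any-model", prompt="hi", seed=2**63)
except pydantic.ValidationError as err:
    print(err)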

vllm/model_executor/guided_decoding/__init__.py

Lines changed: 24 additions & 2 deletions
@@ -3,9 +3,10 @@
 from vllm.entrypoints.openai.protocol import (
     ChatCompletionNamedToolChoiceParam, ChatCompletionRequest,
     CompletionRequest)
-from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import (
-    get_lm_format_enforcer_guided_decoding_logits_processor)
+from vllm.model_executor.guided_decoding.guided_fields import (
+    GuidedDecodingRequest)
 from vllm.model_executor.guided_decoding.outlines_decoding import (
+    get_local_outlines_guided_decoding_logits_processor,
     get_outlines_guided_decoding_logits_processor)
 from vllm.sampling_params import LogitsProcessor
 
@@ -20,6 +21,8 @@ async def get_guided_decoding_logits_processor(
         return await get_outlines_guided_decoding_logits_processor(
             request, tokenizer)
     if guided_decoding_backend == 'lm-format-enforcer':
+        from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import (  # noqa
+            get_lm_format_enforcer_guided_decoding_logits_processor)
         return await get_lm_format_enforcer_guided_decoding_logits_processor(
             request, tokenizer)
 
@@ -28,6 +31,25 @@
         "Must be one of 'outlines, 'lm-format-enforcer'")
 
 
+def get_local_guided_decoding_logits_processor(
+        guided_decoding_backend: str, guided_options: GuidedDecodingRequest,
+        tokenizer) -> Optional[LogitsProcessor]:
+    # request = _adapt_request_for_tool_use(request)
+
+    if guided_decoding_backend == 'outlines':
+        return get_local_outlines_guided_decoding_logits_processor(
+            guided_options, tokenizer)
+    if guided_decoding_backend == 'lm-format-enforcer':
+        from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import (  # noqa
+            get_local_lm_format_enforcer_guided_decoding_logits_processor)
+        return get_local_lm_format_enforcer_guided_decoding_logits_processor(
+            guided_options, tokenizer)
+
+    raise ValueError(
+        f"Unknown guided decoding backend '{guided_decoding_backend}'. "
+        "Must be one of 'outlines, 'lm-format-enforcer'")
+
+
 def _adapt_request_for_tool_use(request: Union[CompletionRequest,
                                                ChatCompletionRequest]):
     # the legacy completion API does not support tool use
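
get_local_guided_decoding_logits_processor is the synchronous counterpart of the existing async helper and is what LLM._add_guided_processor calls. A minimal sketch of building a processor directly, assuming a Hugging Face tokenizer and the 'outlines' backend:

from transformers import AutoTokenizer

from vllm.model_executor.guided_decoding import (
    GuidedDecodingRequest, get_local_guided_decoding_logits_processor)
from vllm.sampling_params import SamplingParams

tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
options = GuidedDecodingRequest(guided_choice=["Positive", "Negative"])

# Synchronous: no event loop needed, unlike get_guided_decoding_logits_processor.
processor = get_local_guided_decoding_logits_processor(
    "outlines", options, tokenizer)

params = SamplingParams(temperature=0.0)
if processor is not None:
    # The processor masks logits so only the allowed choices can be generated.
    params.logits_processors = [processor]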
