
Commit 66d617e

[Frontend] Gracefully handle missing chat template and fix CI failure (#7238)
Co-authored-by: Roger Wang <[email protected]>
1 parent 7b26109 commit 66d617e

File tree (9 files changed, +125 -69 lines)

tests/async_engine/test_chat_template.py
tests/async_engine/test_openapi_server_ray.py
tests/entrypoints/openai/test_oot_registration.py
tests/utils.py
vllm/entrypoints/chat_utils.py
vllm/entrypoints/openai/protocol.py
vllm/entrypoints/openai/serving_chat.py
vllm/entrypoints/openai/serving_tokenization.py
vllm/transformers_utils/tokenizer.py

tests/async_engine/test_chat_template.py

Lines changed: 8 additions & 13 deletions
@@ -1,22 +1,16 @@
-import os
-import pathlib
-
 import pytest
 
-from vllm.entrypoints.chat_utils import load_chat_template
+from vllm.entrypoints.chat_utils import apply_chat_template, load_chat_template
 from vllm.entrypoints.openai.protocol import ChatCompletionRequest
 from vllm.transformers_utils.tokenizer import get_tokenizer
 
-chatml_jinja_path = pathlib.Path(os.path.dirname(os.path.abspath(
-    __file__))).parent.parent / "examples/template_chatml.jinja"
+from ..utils import VLLM_PATH
+
+chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja"
 assert chatml_jinja_path.exists()
 
 # Define models, templates, and their corresponding expected outputs
 MODEL_TEMPLATE_GENERATON_OUTPUT = [
-    ("facebook/opt-125m", None, True,
-     "Hello</s>Hi there!</s>What is the capital of</s>"),
-    ("facebook/opt-125m", None, False,
-     "Hello</s>Hi there!</s>What is the capital of</s>"),
     ("facebook/opt-125m", chatml_jinja_path, True, """<|im_start|>user
 Hello<|im_end|>
 <|im_start|>assistant
@@ -93,11 +87,12 @@ def test_get_gen_prompt(model, template, add_generation_prompt,
        add_generation_prompt=add_generation_prompt)
 
     # Call the function and get the result
-    result = tokenizer.apply_chat_template(
+    result = apply_chat_template(
+        tokenizer,
         conversation=mock_request.messages,
-        tokenize=False,
+        chat_template=mock_request.chat_template or template_content,
         add_generation_prompt=mock_request.add_generation_prompt,
-        chat_template=mock_request.chat_template or template_content)
+    )
 
     # Test assertion
     assert result == expected_output, (

tests/async_engine/test_openapi_server_ray.py

Lines changed: 7 additions & 3 deletions
@@ -1,10 +1,12 @@
 import openai  # use the official client for correctness check
 import pytest
 
-from ..utils import RemoteOpenAIServer
+from ..utils import VLLM_PATH, RemoteOpenAIServer
 
 # any model with a chat template should work here
 MODEL_NAME = "facebook/opt-125m"
+chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja"
+assert chatml_jinja_path.exists()
 
 
 @pytest.fixture(scope="module")
@@ -16,7 +18,9 @@ def server():
         "--max-model-len",
         "2048",
         "--enforce-eager",
-        "--engine-use-ray"
+        "--engine-use-ray",
+        "--chat-template",
+        str(chatml_jinja_path),
     ]
 
     with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
@@ -83,7 +87,7 @@ async def test_single_chat_session(client: openai.AsyncOpenAI):
     choice = chat_completion.choices[0]
     assert choice.finish_reason == "length"
     assert chat_completion.usage == openai.types.CompletionUsage(
-        completion_tokens=10, prompt_tokens=13, total_tokens=23)
+        completion_tokens=10, prompt_tokens=55, total_tokens=65)
 
     message = choice.message
     assert message.content is not None and len(message.content) >= 10
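
Note on the expected usage change: prompt_tokens rises from 13 to 55 because the server fixture now passes the ChatML template explicitly instead of relying on transformers' deprecated default template (which, per the expectations removed above, only concatenated message contents with </s>), and ChatML wraps every message in <|im_start|>/<|im_end|> markers. A rough sketch of how the longer prompt arises (not part of the diff; the conversation below is illustrative rather than the one used by test_single_chat_session, and the relative template path assumes the repository root as working directory):

from vllm.entrypoints.chat_utils import apply_chat_template, load_chat_template
from vllm.transformers_utils.tokenizer import get_tokenizer

tokenizer = get_tokenizer("facebook/opt-125m")
template = load_chat_template("examples/template_chatml.jinja")

conversation = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello!"},
]

prompt = apply_chat_template(
    tokenizer,
    conversation=conversation,
    chat_template=template,
    add_generation_prompt=True,
)
# Each message now carries <|im_start|>/<|im_end|> markers, so the tokenized
# prompt is longer than the bare </s>-joined concatenation the old default
# template produced.
print(len(tokenizer(prompt).input_ids))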

tests/entrypoints/openai/test_oot_registration.py

Lines changed: 55 additions & 32 deletions
@@ -9,6 +9,11 @@
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.utils import get_open_port
 
+from ...utils import VLLM_PATH, RemoteOpenAIServer
+
+chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja"
+assert chatml_jinja_path.exists()
+
 
 class MyOPTForCausalLM(OPTForCausalLM):
 
@@ -21,12 +26,25 @@ def compute_logits(self, hidden_states: torch.Tensor,
         return logits
 
 
-def server_function(port):
+def server_function(port: int):
     # register our dummy model
     ModelRegistry.register_model("OPTForCausalLM", MyOPTForCausalLM)
-    sys.argv = ["placeholder.py"] + \
-        ("--model facebook/opt-125m --gpu-memory-utilization 0.10 "
-         f"--dtype float32 --api-key token-abc123 --port {port}").split()
+
+    sys.argv = ["placeholder.py"] + [
+        "--model",
+        "facebook/opt-125m",
+        "--gpu-memory-utilization",
+        "0.10",
+        "--dtype",
+        "float32",
+        "--api-key",
+        "token-abc123",
+        "--port",
+        str(port),
+        "--chat-template",
+        str(chatml_jinja_path),
+    ]
+
     import runpy
     runpy.run_module('vllm.entrypoints.openai.api_server', run_name='__main__')
 
@@ -36,35 +54,40 @@ def test_oot_registration_for_api_server():
     ctx = torch.multiprocessing.get_context()
     server = ctx.Process(target=server_function, args=(port, ))
     server.start()
-    MAX_SERVER_START_WAIT_S = 60
-    client = OpenAI(
-        base_url=f"http://localhost:{port}/v1",
-        api_key="token-abc123",
-    )
-    now = time.time()
-    while True:
-        try:
-            completion = client.chat.completions.create(
-                model="facebook/opt-125m",
-                messages=[{
-                    "role": "system",
-                    "content": "You are a helpful assistant."
-                }, {
-                    "role": "user",
-                    "content": "Hello!"
-                }],
-                temperature=0,
-            )
-            break
-        except OpenAIError as e:
-            if "Connection error" in str(e):
-                time.sleep(3)
-                if time.time() - now > MAX_SERVER_START_WAIT_S:
-                    raise RuntimeError("Server did not start in time") from e
-            else:
-                raise e
-    server.kill()
+
+    try:
+        client = OpenAI(
+            base_url=f"http://localhost:{port}/v1",
+            api_key="token-abc123",
+        )
+        now = time.time()
+        while True:
+            try:
+                completion = client.chat.completions.create(
+                    model="facebook/opt-125m",
+                    messages=[{
+                        "role": "system",
+                        "content": "You are a helpful assistant."
+                    }, {
+                        "role": "user",
+                        "content": "Hello!"
+                    }],
+                    temperature=0,
+                )
+                break
+            except OpenAIError as e:
+                if "Connection error" in str(e):
+                    time.sleep(3)
+                    if time.time() - now > RemoteOpenAIServer.MAX_START_WAIT_S:
+                        msg = "Server did not start in time"
+                        raise RuntimeError(msg) from e
+                else:
+                    raise e
+    finally:
+        server.terminate()
+
     generated_text = completion.choices[0].message.content
+    assert generated_text is not None
     # make sure only the first token is generated
     rest = generated_text.replace("<s>", "")
     assert rest == ""

tests/utils.py

Lines changed: 2 additions & 2 deletions
@@ -50,7 +50,7 @@ def _nvml():
 
 class RemoteOpenAIServer:
     DUMMY_API_KEY = "token-abc123"  # vLLM's OpenAI server does not need API key
-    MAX_SERVER_START_WAIT_S = 120  # wait for server to start for 120 seconds
+    MAX_START_WAIT_S = 120  # wait for server to start for 120 seconds
 
     def __init__(
         self,
@@ -85,7 +85,7 @@ def __init__(
                                      stdout=sys.stdout,
                                      stderr=sys.stderr)
         self._wait_for_server(url=self.url_for("health"),
-                              timeout=self.MAX_SERVER_START_WAIT_S)
+                              timeout=self.MAX_START_WAIT_S)
 
     def __enter__(self):
         return self

vllm/entrypoints/chat_utils.py

Lines changed: 34 additions & 3 deletions
@@ -1,8 +1,9 @@
 import codecs
 from dataclasses import dataclass
 from functools import lru_cache
-from typing import (Awaitable, Iterable, List, Optional, Tuple, Union, cast,
-                    final)
+from pathlib import Path
+from typing import (Any, Awaitable, Iterable, List, Optional, Tuple, Union,
+                    cast, final)
 
 # yapf conflicts with isort for this block
 # yapf: disable
@@ -22,6 +23,7 @@
 from vllm.logger import init_logger
 from vllm.multimodal import MultiModalDataDict
 from vllm.multimodal.utils import async_get_and_parse_image
+from vllm.transformers_utils.tokenizer import AnyTokenizer
 
 logger = init_logger(__name__)
 
@@ -69,13 +71,17 @@ class ChatMessageParseResult:
     mm_futures: List[Awaitable[MultiModalDataDict]]
 
 
-def load_chat_template(chat_template: Optional[str]) -> Optional[str]:
+def load_chat_template(
+        chat_template: Optional[Union[Path, str]]) -> Optional[str]:
     if chat_template is None:
         return None
     try:
        with open(chat_template, "r") as f:
            resolved_chat_template = f.read()
     except OSError as e:
+        if isinstance(chat_template, Path):
+            raise
+
         JINJA_CHARS = "{}\n"
         if not any(c in chat_template for c in JINJA_CHARS):
             msg = (f"The supplied chat template ({chat_template}) "
@@ -208,3 +214,28 @@ def parse_chat_messages(
         mm_futures.extend(parse_result.mm_futures)
 
     return conversation, mm_futures
+
+
+def apply_chat_template(
+    tokenizer: AnyTokenizer,
+    conversation: List[ConversationMessage],
+    chat_template: Optional[str],
+    *,
+    tokenize: bool = False,  # Different from HF's default
+    **kwargs: Any,
+) -> str:
+    if chat_template is None and tokenizer.chat_template is None:
+        raise ValueError(
+            "As of transformers v4.44, default chat template is no longer "
+            "allowed, so you must provide a chat template if the tokenizer "
+            "does not define one.")
+
+    prompt = tokenizer.apply_chat_template(
+        conversation=conversation,
+        chat_template=chat_template,
+        tokenize=tokenize,
+        **kwargs,
+    )
+    assert isinstance(prompt, str)
+
+    return prompt
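
The new apply_chat_template helper is the piece that makes the missing-template case explicit: it defaults to tokenize=False and raises a ValueError instead of letting transformers fall back to its deprecated default template. A minimal usage sketch (not part of the diff; the inline Jinja template is a placeholder, and facebook/opt-125m is used only because its tokenizer defines no chat template):

from vllm.entrypoints.chat_utils import apply_chat_template
from vllm.transformers_utils.tokenizer import get_tokenizer

tokenizer = get_tokenizer("facebook/opt-125m")  # OPT ships no chat template
conversation = [{"role": "user", "content": "Hello!"}]

# With no template from either the caller or the tokenizer, the helper raises
# a ValueError rather than silently using a transformers default.
try:
    apply_chat_template(tokenizer, conversation, chat_template=None)
except ValueError as e:
    print(e)

# Supplying a template (a trivial placeholder here) returns the rendered
# prompt as a string, since tokenize defaults to False.
prompt = apply_chat_template(
    tokenizer,
    conversation,
    chat_template="{% for m in messages %}{{ m['content'] }}{% endfor %}",
)
print(prompt)  # -> "Hello!"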

vllm/entrypoints/openai/protocol.py

Lines changed: 3 additions & 2 deletions
@@ -190,8 +190,9 @@ class ChatCompletionRequest(OpenAIBaseModel):
         default=None,
         description=(
             "A Jinja template to use for this conversion. "
-            "If this is not passed, the model's default chat template will be "
-            "used instead."),
+            "As of transformers v4.44, default chat template is no longer "
+            "allowed, so you must provide a chat template if the tokenizer "
+            "does not define one."),
     )
     chat_template_kwargs: Optional[Dict[str, Any]] = Field(
         default=None,
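
Because ChatCompletionRequest keeps accepting a per-request chat_template, a client can still supply one when the served tokenizer defines none. A hedged sketch using the official openai client, which forwards non-standard fields through extra_body (the server URL and template string below are placeholders, not values from this commit):

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

completion = client.chat.completions.create(
    model="facebook/opt-125m",
    messages=[{"role": "user", "content": "Hello!"}],
    # vLLM-specific request field, sent in the JSON body next to the
    # standard OpenAI parameters.
    extra_body={
        "chat_template": (
            "{% for m in messages %}<|im_start|>{{ m['role'] }}\n"
            "{{ m['content'] }}<|im_end|>\n{% endfor %}"
        ),
    },
)
print(completion.choices[0].message.content)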

vllm/entrypoints/openai/serving_chat.py

Lines changed: 4 additions & 4 deletions
@@ -10,6 +10,7 @@
 from vllm.config import ModelConfig
 from vllm.engine.protocol import AsyncEngineClient
 from vllm.entrypoints.chat_utils import (ConversationMessage,
+                                         apply_chat_template,
                                          load_chat_template,
                                          parse_chat_messages)
 from vllm.entrypoints.logger import RequestLogger
@@ -99,16 +100,15 @@ async def create_chat_completion(
                 tool.model_dump() for tool in request.tools
             ]
 
-            prompt = tokenizer.apply_chat_template(
+            prompt = apply_chat_template(
+                tokenizer,
                 conversation=conversation,
-                tokenize=False,
+                chat_template=request.chat_template or self.chat_template,
                 add_generation_prompt=request.add_generation_prompt,
                 tools=tool_dicts,
                 documents=request.documents,
-                chat_template=request.chat_template or self.chat_template,
                 **(request.chat_template_kwargs or {}),
             )
-            assert isinstance(prompt, str)
         except Exception as e:
             logger.error("Error in applying chat template from request: %s", e)
             return self.create_error_response(str(e))

vllm/entrypoints/openai/serving_tokenization.py

Lines changed: 8 additions & 6 deletions
@@ -2,7 +2,9 @@
 
 from vllm.config import ModelConfig
 from vllm.engine.protocol import AsyncEngineClient
-from vllm.entrypoints.chat_utils import load_chat_template, parse_chat_messages
+from vllm.entrypoints.chat_utils import (apply_chat_template,
+                                         load_chat_template,
+                                         parse_chat_messages)
 from vllm.entrypoints.logger import RequestLogger
 # yapf conflicts with isort for this block
 # yapf: disable
@@ -70,12 +72,12 @@ async def create_tokenize(
                 logger.warning(
                     "Multi-modal inputs are ignored during tokenization")
 
-            prompt = tokenizer.apply_chat_template(
-                add_generation_prompt=request.add_generation_prompt,
+            prompt = apply_chat_template(
+                tokenizer,
                 conversation=conversation,
-                tokenize=False,
-                chat_template=self.chat_template)
-            assert isinstance(prompt, str)
+                chat_template=self.chat_template,
+                add_generation_prompt=request.add_generation_prompt,
+            )
         else:
             prompt = request.prompt
 

vllm/transformers_utils/tokenizer.py

Lines changed: 4 additions & 4 deletions
@@ -12,12 +12,12 @@
 from vllm.transformers_utils.tokenizers import BaichuanTokenizer
 from vllm.utils import make_async
 
+from .tokenizer_group import AnyTokenizer
+
 logger = init_logger(__name__)
 
 
-def get_cached_tokenizer(
-    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
-) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
+def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer:
     """Get tokenizer with cached properties.
 
     This will patch the tokenizer object in place.
@@ -63,7 +63,7 @@ def get_tokenizer(
     revision: Optional[str] = None,
     download_dir: Optional[str] = None,
     **kwargs,
-) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
+) -> AnyTokenizer:
     """Gets a tokenizer for the given model name via HuggingFace or ModelScope.
     """
     if VLLM_USE_MODELSCOPE:
