21 changes: 8 additions & 13 deletions tests/async_engine/test_chat_template.py
@@ -1,22 +1,16 @@
import os
import pathlib

import pytest

from vllm.entrypoints.chat_utils import load_chat_template
from vllm.entrypoints.chat_utils import apply_chat_template, load_chat_template
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
from vllm.transformers_utils.tokenizer import get_tokenizer

chatml_jinja_path = pathlib.Path(os.path.dirname(os.path.abspath(
__file__))).parent.parent / "examples/template_chatml.jinja"
from ..utils import VLLM_PATH

chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja"
assert chatml_jinja_path.exists()

# Define models, templates, and their corresponding expected outputs
MODEL_TEMPLATE_GENERATON_OUTPUT = [
("facebook/opt-125m", None, True,
"Hello</s>Hi there!</s>What is the capital of</s>"),
("facebook/opt-125m", None, False,
"Hello</s>Hi there!</s>What is the capital of</s>"),
("facebook/opt-125m", chatml_jinja_path, True, """<|im_start|>user
Hello<|im_end|>
<|im_start|>assistant
@@ -93,11 +87,12 @@ def test_get_gen_prompt(model, template, add_generation_prompt,
add_generation_prompt=add_generation_prompt)

# Call the function and get the result
result = tokenizer.apply_chat_template(
result = apply_chat_template(
tokenizer,
conversation=mock_request.messages,
tokenize=False,
chat_template=mock_request.chat_template or template_content,
add_generation_prompt=mock_request.add_generation_prompt,
chat_template=mock_request.chat_template or template_content)
)

# Test assertion
assert result == expected_output, (
10 changes: 7 additions & 3 deletions tests/async_engine/test_openapi_server_ray.py
@@ -1,10 +1,12 @@
import openai # use the official client for correctness check
import pytest

from ..utils import RemoteOpenAIServer
from ..utils import VLLM_PATH, RemoteOpenAIServer

# any model with a chat template should work here
MODEL_NAME = "facebook/opt-125m"
chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja"
assert chatml_jinja_path.exists()


@pytest.fixture(scope="module")
@@ -16,7 +18,9 @@ def server():
"--max-model-len",
"2048",
"--enforce-eager",
"--engine-use-ray"
"--engine-use-ray",
"--chat-template",
str(chatml_jinja_path),
]

with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
@@ -83,7 +87,7 @@ async def test_single_chat_session(client: openai.AsyncOpenAI):
choice = chat_completion.choices[0]
assert choice.finish_reason == "length"
assert chat_completion.usage == openai.types.CompletionUsage(
completion_tokens=10, prompt_tokens=13, total_tokens=23)
completion_tokens=10, prompt_tokens=55, total_tokens=65)

message = choice.message
assert message.content is not None and len(message.content) >= 10
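
Note on the updated usage assertion above: the expected prompt token count grows from 13 to 55 because the server now renders the ChatML template before tokenizing, and the `<|im_start|>`/`<|im_end|>` markers are not special tokens for the OPT tokenizer, so they split into many pieces. A hedged sketch of how the new count could be reproduced offline; the exact messages used by `test_single_chat_session` are an assumption here, so treat 55 as an empirical value rather than a constant:

```python
# Hedged sketch: estimate the prompt token count after applying ChatML.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
with open("examples/template_chatml.jinja") as f:
    chatml_template = f.read()

# Assumed conversation; the real test may use different messages.
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello!"},
]
prompt = tokenizer.apply_chat_template(
    messages,
    chat_template=chatml_template,
    tokenize=False,
    add_generation_prompt=True,
)
print(len(tokenizer(prompt).input_ids))  # expected to land near 55
```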
87 changes: 55 additions & 32 deletions tests/entrypoints/openai/test_oot_registration.py
@@ -9,6 +9,11 @@
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.utils import get_open_port

from ...utils import VLLM_PATH, RemoteOpenAIServer

chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja"
assert chatml_jinja_path.exists()


class MyOPTForCausalLM(OPTForCausalLM):

@@ -21,12 +26,25 @@ def compute_logits(self, hidden_states: torch.Tensor,
return logits


def server_function(port):
def server_function(port: int):
# register our dummy model
ModelRegistry.register_model("OPTForCausalLM", MyOPTForCausalLM)
sys.argv = ["placeholder.py"] + \
("--model facebook/opt-125m --gpu-memory-utilization 0.10 "
f"--dtype float32 --api-key token-abc123 --port {port}").split()

sys.argv = ["placeholder.py"] + [
"--model",
"facebook/opt-125m",
"--gpu-memory-utilization",
"0.10",
"--dtype",
"float32",
"--api-key",
"token-abc123",
"--port",
str(port),
"--chat-template",
str(chatml_jinja_path),
]

import runpy
runpy.run_module('vllm.entrypoints.openai.api_server', run_name='__main__')

@@ -36,35 +54,40 @@ def test_oot_registration_for_api_server():
ctx = torch.multiprocessing.get_context()
server = ctx.Process(target=server_function, args=(port, ))
server.start()
MAX_SERVER_START_WAIT_S = 60
client = OpenAI(
base_url=f"http://localhost:{port}/v1",
api_key="token-abc123",
)
now = time.time()
while True:
try:
completion = client.chat.completions.create(
model="facebook/opt-125m",
messages=[{
"role": "system",
"content": "You are a helpful assistant."
}, {
"role": "user",
"content": "Hello!"
}],
temperature=0,
)
break
except OpenAIError as e:
if "Connection error" in str(e):
time.sleep(3)
if time.time() - now > MAX_SERVER_START_WAIT_S:
raise RuntimeError("Server did not start in time") from e
else:
raise e
server.kill()

try:
client = OpenAI(
base_url=f"http://localhost:{port}/v1",
api_key="token-abc123",
)
now = time.time()
while True:
try:
completion = client.chat.completions.create(
model="facebook/opt-125m",
messages=[{
"role": "system",
"content": "You are a helpful assistant."
}, {
"role": "user",
"content": "Hello!"
}],
temperature=0,
)
break
except OpenAIError as e:
if "Connection error" in str(e):
time.sleep(3)
if time.time() - now > RemoteOpenAIServer.MAX_START_WAIT_S:
msg = "Server did not start in time"
raise RuntimeError(msg) from e
else:
raise e
finally:
server.terminate()

generated_text = completion.choices[0].message.content
assert generated_text is not None
# make sure only the first token is generated
rest = generated_text.replace("<s>", "")
assert rest == ""
4 changes: 2 additions & 2 deletions tests/utils.py
@@ -50,7 +50,7 @@ def _nvml():

class RemoteOpenAIServer:
DUMMY_API_KEY = "token-abc123" # vLLM's OpenAI server does not need API key
MAX_SERVER_START_WAIT_S = 120 # wait for server to start for 120 seconds
MAX_START_WAIT_S = 120 # wait for server to start for 120 seconds

def __init__(
self,
@@ -85,7 +85,7 @@ def __init__(
stdout=sys.stdout,
stderr=sys.stderr)
self._wait_for_server(url=self.url_for("health"),
timeout=self.MAX_SERVER_START_WAIT_S)
timeout=self.MAX_START_WAIT_S)

def __enter__(self):
return self
37 changes: 34 additions & 3 deletions vllm/entrypoints/chat_utils.py
@@ -1,8 +1,9 @@
import codecs
from dataclasses import dataclass
from functools import lru_cache
from typing import (Awaitable, Iterable, List, Optional, Tuple, Union, cast,
final)
from pathlib import Path
from typing import (Any, Awaitable, Iterable, List, Optional, Tuple, Union,
cast, final)

# yapf conflicts with isort for this block
# yapf: disable
@@ -22,6 +23,7 @@
from vllm.logger import init_logger
from vllm.multimodal import MultiModalDataDict
from vllm.multimodal.utils import async_get_and_parse_image
from vllm.transformers_utils.tokenizer import AnyTokenizer

logger = init_logger(__name__)

@@ -69,13 +71,17 @@ class ChatMessageParseResult:
mm_futures: List[Awaitable[MultiModalDataDict]]


def load_chat_template(chat_template: Optional[str]) -> Optional[str]:
def load_chat_template(
chat_template: Optional[Union[Path, str]]) -> Optional[str]:
if chat_template is None:
return None
try:
with open(chat_template, "r") as f:
resolved_chat_template = f.read()
except OSError as e:
if isinstance(chat_template, Path):
raise

JINJA_CHARS = "{}\n"
if not any(c in chat_template for c in JINJA_CHARS):
msg = (f"The supplied chat template ({chat_template}) "
@@ -208,3 +214,28 @@ def parse_chat_messages(
mm_futures.extend(parse_result.mm_futures)

return conversation, mm_futures


def apply_chat_template(
tokenizer: AnyTokenizer,
conversation: List[ConversationMessage],
chat_template: Optional[str],
*,
tokenize: bool = False, # Different from HF's default
**kwargs: Any,
) -> str:
if chat_template is None and tokenizer.chat_template is None:
raise ValueError(
"As of transformers v4.44, default chat template is no longer "
"allowed, so you must provide a chat template if the tokenizer "
"does not define one.")

prompt = tokenizer.apply_chat_template(
conversation=conversation,
chat_template=chat_template,
tokenize=tokenize,
**kwargs,
)
assert isinstance(prompt, str)

return prompt
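
For reference, a minimal usage sketch of the two helpers in this module as they appear in this diff; the model name and template path are borrowed from the tests above and are not part of the change itself:

```python
# Minimal sketch, assuming the signatures shown in this diff.
from vllm.entrypoints.chat_utils import apply_chat_template, load_chat_template
from vllm.transformers_utils.tokenizer import get_tokenizer

tokenizer = get_tokenizer("facebook/opt-125m")
template = load_chat_template("examples/template_chatml.jinja")

conversation = [{"role": "user", "content": "Hello!"}]
prompt = apply_chat_template(
    tokenizer,
    conversation=conversation,
    chat_template=template,
    add_generation_prompt=True,
)
# tokenize defaults to False, so apply_chat_template always returns a str here.
print(prompt)
```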
5 changes: 3 additions & 2 deletions vllm/entrypoints/openai/protocol.py
@@ -190,8 +190,9 @@ class ChatCompletionRequest(OpenAIBaseModel):
default=None,
description=(
"A Jinja template to use for this conversion. "
"If this is not passed, the model's default chat template will be "
"used instead."),
"As of transformers v4.44, default chat template is no longer "
"allowed, so you must provide a chat template if the tokenizer "
"does not define one."),
)
chat_template_kwargs: Optional[Dict[str, Any]] = Field(
default=None,
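
With the default-template fallback gone, a client talking to a model whose tokenizer defines no chat template has to supply one, either at server start via `--chat-template` or per request. A hedged sketch of the per-request route using the openai client's `extra_body` passthrough; the field name comes from `ChatCompletionRequest` above, while the server URL and model are placeholders:

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="token-abc123")

with open("examples/template_chatml.jinja") as f:
    chatml_template = f.read()

completion = client.chat.completions.create(
    model="facebook/opt-125m",
    messages=[{"role": "user", "content": "Hello!"}],
    # extra_body forwards non-OpenAI fields into ChatCompletionRequest.
    extra_body={"chat_template": chatml_template},
)
print(completion.choices[0].message.content)
```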
8 changes: 4 additions & 4 deletions vllm/entrypoints/openai/serving_chat.py
@@ -9,6 +9,7 @@
from vllm.config import ModelConfig
from vllm.engine.protocol import AsyncEngineClient
from vllm.entrypoints.chat_utils import (ConversationMessage,
apply_chat_template,
load_chat_template,
parse_chat_messages)
from vllm.entrypoints.logger import RequestLogger
@@ -98,16 +99,15 @@ async def create_chat_completion(
tool.model_dump() for tool in request.tools
]

prompt = tokenizer.apply_chat_template(
prompt = apply_chat_template(
tokenizer,
conversation=conversation,
tokenize=False,
chat_template=request.chat_template or self.chat_template,
add_generation_prompt=request.add_generation_prompt,
tools=tool_dicts,
documents=request.documents,
chat_template=request.chat_template or self.chat_template,
**(request.chat_template_kwargs or {}),
)
assert isinstance(prompt, str)
except Exception as e:
logger.error("Error in applying chat template from request: %s", e)
return self.create_error_response(str(e))
14 changes: 8 additions & 6 deletions vllm/entrypoints/openai/serving_tokenization.py
@@ -2,7 +2,9 @@

from vllm.config import ModelConfig
from vllm.engine.protocol import AsyncEngineClient
from vllm.entrypoints.chat_utils import load_chat_template, parse_chat_messages
from vllm.entrypoints.chat_utils import (apply_chat_template,
load_chat_template,
parse_chat_messages)
from vllm.entrypoints.logger import RequestLogger
# yapf conflicts with isort for this block
# yapf: disable
@@ -70,12 +72,12 @@ async def create_tokenize(
logger.warning(
"Multi-modal inputs are ignored during tokenization")

prompt = tokenizer.apply_chat_template(
add_generation_prompt=request.add_generation_prompt,
prompt = apply_chat_template(
tokenizer,
conversation=conversation,
tokenize=False,
chat_template=self.chat_template)
assert isinstance(prompt, str)
chat_template=self.chat_template,
add_generation_prompt=request.add_generation_prompt,
)
else:
prompt = request.prompt

8 changes: 4 additions & 4 deletions vllm/transformers_utils/tokenizer.py
@@ -12,12 +12,12 @@
from vllm.transformers_utils.tokenizers import BaichuanTokenizer
from vllm.utils import make_async

from .tokenizer_group import AnyTokenizer

logger = init_logger(__name__)


def get_cached_tokenizer(
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer:
"""Get tokenizer with cached properties.

This will patch the tokenizer object in place.
@@ -63,7 +63,7 @@ def get_tokenizer(
revision: Optional[str] = None,
download_dir: Optional[str] = None,
**kwargs,
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
) -> AnyTokenizer:
"""Gets a tokenizer for the given model name via HuggingFace or ModelScope.
"""
if VLLM_USE_MODELSCOPE:
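
For context, `AnyTokenizer` replaces the repeated `Union` annotation in this module. A rough sketch of what the alias presumably expands to; the real definition lives in `vllm.transformers_utils.tokenizer_group` and may cover additional tokenizer types:

```python
# Hedged sketch only -- not the actual vLLM definition.
from typing import Union

from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
```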