From 03b690b34b6e1cbfe5001d4e1a3533bd9a6c52d0 Mon Sep 17 00:00:00 2001
From: DarkLight1337
Date: Wed, 7 Aug 2024 03:02:07 +0000
Subject: [PATCH 1/9] Gracefully handle missing chat template

---
 tests/async_engine/test_chat_template.py     | 12 ++++----
 vllm/entrypoints/chat_utils.py               | 30 +++++++++++++++++--
 vllm/entrypoints/openai/protocol.py          |  5 ++--
 vllm/entrypoints/openai/serving_chat.py      |  8 ++---
 .../openai/serving_tokenization.py           | 14 +++++----
 vllm/transformers_utils/tokenizer.py         |  8 ++---
 6 files changed, 52 insertions(+), 25 deletions(-)

diff --git a/tests/async_engine/test_chat_template.py b/tests/async_engine/test_chat_template.py
index aea8a7fed6e3..76936fb4e195 100644
--- a/tests/async_engine/test_chat_template.py
+++ b/tests/async_engine/test_chat_template.py
@@ -3,7 +3,7 @@

 import pytest

-from vllm.entrypoints.chat_utils import load_chat_template
+from vllm.entrypoints.chat_utils import apply_chat_template, load_chat_template
 from vllm.entrypoints.openai.protocol import ChatCompletionRequest
 from vllm.transformers_utils.tokenizer import get_tokenizer

@@ -13,10 +13,6 @@
 # Define models, templates, and their corresponding expected outputs
 MODEL_TEMPLATE_GENERATON_OUTPUT = [
-    ("facebook/opt-125m", None, True,
-     "Hello</s>Hi there!</s>What is the capital of</s>"),
-    ("facebook/opt-125m", None, False,
-     "Hello</s>Hi there!</s>What is the capital of</s>"),
     ("facebook/opt-125m", chatml_jinja_path, True, """<|im_start|>user
 Hello<|im_end|>
 <|im_start|>assistant
 """),
@@ -93,11 +89,13 @@ def test_get_gen_prompt(model, template, add_generation_prompt,
         add_generation_prompt=add_generation_prompt)

     # Call the function and get the result
-    result = tokenizer.apply_chat_template(
+    result = apply_chat_template(
+        tokenizer,
         conversation=mock_request.messages,
+        chat_template=mock_request.chat_template or template_content,
         tokenize=False,
         add_generation_prompt=mock_request.add_generation_prompt,
-        chat_template=mock_request.chat_template or template_content)
+    )

     # Test assertion
     assert result == expected_output, (
diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py
index 072450a6146e..b1d499f0b649 100644
--- a/vllm/entrypoints/chat_utils.py
+++ b/vllm/entrypoints/chat_utils.py
@@ -1,8 +1,8 @@
 import codecs
 from dataclasses import dataclass
 from functools import lru_cache
-from typing import (Awaitable, Iterable, List, Optional, Tuple, Union, cast,
-                    final)
+from typing import (Any, Awaitable, Iterable, List, Optional, Tuple, Union,
+                    cast, final)

 # yapf conflicts with isort for this block
 # yapf: disable
@@ -22,6 +22,7 @@
 from vllm.logger import init_logger
 from vllm.multimodal import MultiModalDataDict
 from vllm.multimodal.utils import async_get_and_parse_image
+from vllm.transformers_utils.tokenizer import AnyTokenizer

 logger = init_logger(__name__)

@@ -208,3 +209,28 @@ def parse_chat_messages(
         mm_futures.extend(parse_result.mm_futures)

     return conversation, mm_futures
+
+
+def apply_chat_template(
+    tokenizer: AnyTokenizer,
+    conversation: List[ConversationMessage],
+    chat_template: Optional[str],
+    *,
+    tokenize: bool = False,  # Different from HF's default
+    **kwargs: Any,
+) -> str:
+    if chat_template is None and tokenizer.chat_template is None:
+        raise ValueError(
+            "As of transformers v4.44, default chat template is no longer "
+            "allowed, so you must provide a chat template if the tokenizer "
+            "does not define one.")
+
+    prompt = tokenizer.apply_chat_template(
+        conversation=conversation,
+        chat_template=chat_template,
+        tokenize=tokenize,
+        **kwargs,
+    )
+    assert isinstance(prompt, str)
+
+    return prompt
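
For reference, a minimal sketch of how the new apply_chat_template helper is meant to be used once this diff lands. This is illustrative only, not part of the patch; it assumes a checkout where the bundled examples/template_chatml.jinja file is available and uses only names introduced above.

    from vllm.entrypoints.chat_utils import (apply_chat_template,
                                             load_chat_template)
    from vllm.transformers_utils.tokenizer import get_tokenizer

    tokenizer = get_tokenizer("facebook/opt-125m")
    conversation = [{"role": "user", "content": "Hello"}]

    # With no template on the request and none defined by the tokenizer,
    # the helper raises instead of relying on a transformers default.
    try:
        apply_chat_template(tokenizer, conversation, chat_template=None)
    except ValueError as exc:
        print(exc)  # "As of transformers v4.44, default chat template is ..."

    # Supplying a template (here the bundled chatml example) succeeds and
    # returns the rendered prompt string, since tokenize defaults to False.
    template = load_chat_template("examples/template_chatml.jinja")
    prompt = apply_chat_template(tokenizer, conversation, chat_template=template)
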
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index 76318a127122..70467bd87969 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -190,8 +190,9 @@ class ChatCompletionRequest(OpenAIBaseModel):
         default=None,
         description=(
             "A Jinja template to use for this conversion. "
-            "If this is not passed, the model's default chat template will be "
-            "used instead."),
+            "As of transformers v4.44, default chat template is no longer "
+            "allowed, so you must provide a chat template if the tokenizer "
+            "does not define one."),
     )
     chat_template_kwargs: Optional[Dict[str, Any]] = Field(
         default=None,
diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index d215754993e8..73ee2c75c1bd 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -9,6 +9,7 @@
 from vllm.config import ModelConfig
 from vllm.engine.protocol import AsyncEngineClient
 from vllm.entrypoints.chat_utils import (ConversationMessage,
+                                         apply_chat_template,
                                          load_chat_template,
                                          parse_chat_messages)
 from vllm.entrypoints.logger import RequestLogger
@@ -98,16 +99,15 @@ async def create_chat_completion(
                     tool.model_dump() for tool in request.tools
                 ]

-            prompt = tokenizer.apply_chat_template(
+            prompt = apply_chat_template(
+                tokenizer,
                 conversation=conversation,
-                tokenize=False,
+                chat_template=request.chat_template or self.chat_template,
                 add_generation_prompt=request.add_generation_prompt,
                 tools=tool_dicts,
                 documents=request.documents,
-                chat_template=request.chat_template or self.chat_template,
                 **(request.chat_template_kwargs or {}),
             )
-            assert isinstance(prompt, str)
         except Exception as e:
             logger.error("Error in applying chat template from request: %s", e)
             return self.create_error_response(str(e))
diff --git a/vllm/entrypoints/openai/serving_tokenization.py b/vllm/entrypoints/openai/serving_tokenization.py
index 5b6b979b9b9e..1aeabb7a7d72 100644
--- a/vllm/entrypoints/openai/serving_tokenization.py
+++ b/vllm/entrypoints/openai/serving_tokenization.py
@@ -2,7 +2,9 @@

 from vllm.config import ModelConfig
 from vllm.engine.protocol import AsyncEngineClient
-from vllm.entrypoints.chat_utils import load_chat_template, parse_chat_messages
+from vllm.entrypoints.chat_utils import (apply_chat_template,
+                                         load_chat_template,
+                                         parse_chat_messages)
 from vllm.entrypoints.logger import RequestLogger
 # yapf conflicts with isort for this block
 # yapf: disable
@@ -70,12 +72,12 @@ async def create_tokenize(
                 logger.warning(
                     "Multi-modal inputs are ignored during tokenization")

-            prompt = tokenizer.apply_chat_template(
-                add_generation_prompt=request.add_generation_prompt,
+            prompt = apply_chat_template(
+                tokenizer,
                 conversation=conversation,
-                tokenize=False,
-                chat_template=self.chat_template)
-            assert isinstance(prompt, str)
+                chat_template=self.chat_template,
+                add_generation_prompt=request.add_generation_prompt,
+            )
         else:
             prompt = request.prompt

diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py
index bf26d889d138..25e4c41592c6 100644
--- a/vllm/transformers_utils/tokenizer.py
+++ b/vllm/transformers_utils/tokenizer.py
@@ -12,12 +12,12 @@
 from vllm.transformers_utils.tokenizers import BaichuanTokenizer
 from vllm.utils import make_async

+from .tokenizer_group import AnyTokenizer
+
 logger = init_logger(__name__)


-def get_cached_tokenizer(
-    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
-) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
+def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer:
     """Get tokenizer with cached properties.

     This will patch the tokenizer object in place.
@@ -63,7 +63,7 @@ def get_tokenizer(
     revision: Optional[str] = None,
     download_dir: Optional[str] = None,
     **kwargs,
-) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
+) -> AnyTokenizer:
     """Gets a tokenizer for the given model name via HuggingFace or ModelScope.
     """
     if VLLM_USE_MODELSCOPE:
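
With the description updated in protocol.py above, an API caller is expected to supply the template itself whenever the served model's tokenizer does not define one. A hedged sketch of what that looks like from the client side; the port is an assumption, and the extra_body pass-through is a feature of the openai Python client rather than anything this patch adds.

    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:8000/v1", api_key="token-abc123")

    with open("examples/template_chatml.jinja") as f:
        chatml_template = f.read()

    completion = client.chat.completions.create(
        model="facebook/opt-125m",
        messages=[{"role": "user", "content": "Hello!"}],
        # Extra, vLLM-specific fields ride along in extra_body and populate
        # ChatCompletionRequest.chat_template on the server.
        extra_body={"chat_template": chatml_template},
    )
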
From 309a6c87f5652d4e35305e31de75815f965bcf00 Mon Sep 17 00:00:00 2001
From: DarkLight1337
Date: Wed, 7 Aug 2024 03:06:15 +0000
Subject: [PATCH 2/9] Remove redundant arg

---
 tests/async_engine/test_chat_template.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tests/async_engine/test_chat_template.py b/tests/async_engine/test_chat_template.py
index 76936fb4e195..58f9ce2108f7 100644
--- a/tests/async_engine/test_chat_template.py
+++ b/tests/async_engine/test_chat_template.py
@@ -93,7 +93,6 @@ def test_get_gen_prompt(model, template, add_generation_prompt,
         tokenizer,
         conversation=mock_request.messages,
         chat_template=mock_request.chat_template or template_content,
-        tokenize=False,
         add_generation_prompt=mock_request.add_generation_prompt,
     )

From a77db250c6f538f37effd3566d0fa9806e69c17f Mon Sep 17 00:00:00 2001
From: Roger Wang
Date: Tue, 6 Aug 2024 21:29:12 -0700
Subject: [PATCH 3/9] update test with example template

---
 tests/async_engine/test_openapi_server_ray.py | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/tests/async_engine/test_openapi_server_ray.py b/tests/async_engine/test_openapi_server_ray.py
index 5ecd770ede83..d30a4d9c945f 100644
--- a/tests/async_engine/test_openapi_server_ray.py
+++ b/tests/async_engine/test_openapi_server_ray.py
@@ -1,3 +1,6 @@
+import os
+import pathlib
+
 import openai  # use the official client for correctness check
 import pytest

@@ -5,6 +8,9 @@

 # any model with a chat template should work here
 MODEL_NAME = "facebook/opt-125m"
+chatml_jinja_path = pathlib.Path(os.path.dirname(os.path.abspath(
+    __file__))).parent.parent / "examples/template_chatml.jinja"
+assert chatml_jinja_path.exists()


 @pytest.fixture(scope="module")
@@ -16,7 +22,9 @@ def server():
         "--max-model-len",
         "2048",
         "--enforce-eager",
-        "--engine-use-ray"
+        "--engine-use-ray",
+        "--chat-template",
+        str(chatml_jinja_path),
     ]

     with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
From d762e3989afe2e1c9e99e05229517a6f4e1bb291 Mon Sep 17 00:00:00 2001
From: DarkLight1337
Date: Wed, 7 Aug 2024 05:12:35 +0000
Subject: [PATCH 4/9] Fix and clean test

---
 tests/async_engine/test_chat_template.py      |  8 +++-----
 tests/async_engine/test_openapi_server_ray.py | 10 +++-------
 2 files changed, 6 insertions(+), 12 deletions(-)

diff --git a/tests/async_engine/test_chat_template.py b/tests/async_engine/test_chat_template.py
index 58f9ce2108f7..4df6c0297328 100644
--- a/tests/async_engine/test_chat_template.py
+++ b/tests/async_engine/test_chat_template.py
@@ -1,14 +1,12 @@
-import os
-import pathlib
-
 import pytest

 from vllm.entrypoints.chat_utils import apply_chat_template, load_chat_template
 from vllm.entrypoints.openai.protocol import ChatCompletionRequest
 from vllm.transformers_utils.tokenizer import get_tokenizer

-chatml_jinja_path = pathlib.Path(os.path.dirname(os.path.abspath(
-    __file__))).parent.parent / "examples/template_chatml.jinja"
+from ..utils import VLLM_PATH
+
+chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja"
 assert chatml_jinja_path.exists()

 # Define models, templates, and their corresponding expected outputs
diff --git a/tests/async_engine/test_openapi_server_ray.py b/tests/async_engine/test_openapi_server_ray.py
index d30a4d9c945f..0d53b39e7ce1 100644
--- a/tests/async_engine/test_openapi_server_ray.py
+++ b/tests/async_engine/test_openapi_server_ray.py
@@ -1,15 +1,11 @@
-import os
-import pathlib
-
 import openai  # use the official client for correctness check
 import pytest

-from ..utils import RemoteOpenAIServer
+from ..utils import VLLM_PATH, RemoteOpenAIServer

 # any model with a chat template should work here
 MODEL_NAME = "facebook/opt-125m"
-chatml_jinja_path = pathlib.Path(os.path.dirname(os.path.abspath(
-    __file__))).parent.parent / "examples/template_chatml.jinja"
+chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja"
 assert chatml_jinja_path.exists()


@@ -91,7 +87,7 @@ async def test_single_chat_session(client: openai.AsyncOpenAI):
     choice = chat_completion.choices[0]
     assert choice.finish_reason == "length"
     assert chat_completion.usage == openai.types.CompletionUsage(
-        completion_tokens=10, prompt_tokens=13, total_tokens=23)
+        completion_tokens=10, prompt_tokens=55, total_tokens=65)

     message = choice.message
     assert message.content is not None and len(message.content) >= 10

From 24c6a055ef6a48b92fbf91b673fa420c5bc79d9b Mon Sep 17 00:00:00 2001
From: DarkLight1337
Date: Wed, 7 Aug 2024 06:45:54 +0000
Subject: [PATCH 5/9] Fix entrypoints test

---
 .../openai/test_oot_registration.py | 24 ++++++++++++++++---
 1 file changed, 21 insertions(+), 3 deletions(-)

diff --git a/tests/entrypoints/openai/test_oot_registration.py b/tests/entrypoints/openai/test_oot_registration.py
index 5272ac4065f1..6889824978ed 100644
--- a/tests/entrypoints/openai/test_oot_registration.py
+++ b/tests/entrypoints/openai/test_oot_registration.py
@@ -9,6 +9,11 @@
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.utils import get_open_port

+from ...utils import VLLM_PATH
+
+chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja"
+assert chatml_jinja_path.exists()
+

 class MyOPTForCausalLM(OPTForCausalLM):

@@ -24,9 +29,22 @@ def compute_logits(self, hidden_states: torch.Tensor,

 def server_function(port):
     # register our dummy model
     ModelRegistry.register_model("OPTForCausalLM", MyOPTForCausalLM)
-    sys.argv = ["placeholder.py"] + \
-        ("--model facebook/opt-125m --gpu-memory-utilization 0.10 "
-         f"--dtype float32 --api-key token-abc123 --port {port}").split()
+
+    sys.argv = ["placeholder.py"] + [
+        "--model",
+        "facebook/opt-125m",
+        "--gpu-memory-utilization",
+        "0.10",
+        "--dtype",
+        "float32",
+        "--api-key",
+        "token-abc123",
+        "--port",
+        str(port),
+        "--chat-template",
+        str(chatml_jinja_path),
+    ]
+
     import runpy
     runpy.run_module('vllm.entrypoints.openai.api_server', run_name='__main__')
From 407432a3bbaeae05529d1169370fd491ac0805d6 Mon Sep 17 00:00:00 2001
From: DarkLight1337
Date: Wed, 7 Aug 2024 07:02:20 +0000
Subject: [PATCH 6/9] Fix `load_chat_template` not accept `Path` objects

---
 vllm/entrypoints/chat_utils.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py
index b1d499f0b649..12634c326185 100644
--- a/vllm/entrypoints/chat_utils.py
+++ b/vllm/entrypoints/chat_utils.py
@@ -1,6 +1,7 @@
 import codecs
 from dataclasses import dataclass
 from functools import lru_cache
+from pathlib import Path
 from typing import (Any, Awaitable, Iterable, List, Optional, Tuple, Union,
                     cast, final)

@@ -70,13 +71,17 @@ class ChatMessageParseResult:
     mm_futures: List[Awaitable[MultiModalDataDict]]


-def load_chat_template(chat_template: Optional[str]) -> Optional[str]:
+def load_chat_template(
+        chat_template: Optional[Union[Path, str]]) -> Optional[str]:
     if chat_template is None:
         return None
     try:
         with open(chat_template, "r") as f:
             resolved_chat_template = f.read()
     except OSError as e:
+        if isinstance(chat_template, Path):
+            raise
+
         JINJA_CHARS = "{}\n"
         if not any(c in chat_template for c in JINJA_CHARS):
             msg = (f"The supplied chat template ({chat_template}) "
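
A short sketch of the behaviour this patch is after, shown as an illustrative aside rather than anything the series adds; the inline-template fallback at the end of the hunk is existing behaviour that is not shown in full here.

    from pathlib import Path

    from vllm.entrypoints.chat_utils import load_chat_template

    # A Path (or a str path) to a real file: the file's contents are returned.
    template = load_chat_template(Path("examples/template_chatml.jinja"))

    # A string that cannot be opened but contains Jinja characters continues
    # to be treated as an inline template rather than a file path.
    inline = load_chat_template(
        "{% for m in messages %}{{ m.content }}{% endfor %}")

    # A Path that cannot be opened is now a hard error: the OSError from
    # open() propagates instead of the Path being mistaken for a template.
    load_chat_template(Path("examples/does_not_exist.jinja"))  # raises OSError
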
From ae9363b8d40f4daa5edcfb44b7f088e201e1fa46 Mon Sep 17 00:00:00 2001
From: DarkLight1337
Date: Wed, 7 Aug 2024 07:13:13 +0000
Subject: [PATCH 7/9] Fix RPC server failing to be shut down after the test

---
 tests/entrypoints/openai/test_oot_registration.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/tests/entrypoints/openai/test_oot_registration.py b/tests/entrypoints/openai/test_oot_registration.py
index 6889824978ed..1972cf2a2a1e 100644
--- a/tests/entrypoints/openai/test_oot_registration.py
+++ b/tests/entrypoints/openai/test_oot_registration.py
@@ -26,7 +26,7 @@ def compute_logits(self, hidden_states: torch.Tensor,
     return logits


-def server_function(port):
+def server_function(port: int):
     # register our dummy model
     ModelRegistry.register_model("OPTForCausalLM", MyOPTForCausalLM)

@@ -81,8 +81,10 @@ def test_oot_registration_for_api_server():
                     raise RuntimeError("Server did not start in time") from e
             else:
                 raise e
-    server.kill()
+    server.terminate()
+
     generated_text = completion.choices[0].message.content
+    assert generated_text is not None
     # make sure only the first token is generated
     rest = generated_text.replace("<s>", "")
     assert rest == ""

From 82d0e4cdf3027485c000e30349ba6c7ccab63b06 Mon Sep 17 00:00:00 2001
From: DarkLight1337
Date: Wed, 7 Aug 2024 07:17:10 +0000
Subject: [PATCH 8/9] Fix possible failure to clean up

---
 .../openai/test_oot_registration.py           | 61 ++++++++++---------
 tests/utils.py                                |  4 ++--
 2 files changed, 34 insertions(+), 31 deletions(-)

diff --git a/tests/entrypoints/openai/test_oot_registration.py b/tests/entrypoints/openai/test_oot_registration.py
index 1972cf2a2a1e..f33039b55a2c 100644
--- a/tests/entrypoints/openai/test_oot_registration.py
+++ b/tests/entrypoints/openai/test_oot_registration.py
@@ -9,7 +9,7 @@
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.utils import get_open_port

-from ...utils import VLLM_PATH
+from ...utils import RemoteOpenAIServer, VLLM_PATH

 chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja"
 assert chatml_jinja_path.exists()
@@ -54,34 +54,37 @@ def test_oot_registration_for_api_server():
     ctx = torch.multiprocessing.get_context()
     server = ctx.Process(target=server_function, args=(port, ))
     server.start()
-    MAX_SERVER_START_WAIT_S = 60
-    client = OpenAI(
-        base_url=f"http://localhost:{port}/v1",
-        api_key="token-abc123",
-    )
-    now = time.time()
-    while True:
-        try:
-            completion = client.chat.completions.create(
-                model="facebook/opt-125m",
-                messages=[{
-                    "role": "system",
-                    "content": "You are a helpful assistant."
-                }, {
-                    "role": "user",
-                    "content": "Hello!"
-                }],
-                temperature=0,
-            )
-            break
-        except OpenAIError as e:
-            if "Connection error" in str(e):
-                time.sleep(3)
-                if time.time() - now > MAX_SERVER_START_WAIT_S:
-                    raise RuntimeError("Server did not start in time") from e
-            else:
-                raise e
-    server.terminate()
+
+    try:
+        client = OpenAI(
+            base_url=f"http://localhost:{port}/v1",
+            api_key="token-abc123",
+        )
+        now = time.time()
+        while True:
+            try:
+                completion = client.chat.completions.create(
+                    model="facebook/opt-125m",
+                    messages=[{
+                        "role": "system",
+                        "content": "You are a helpful assistant."
+                    }, {
+                        "role": "user",
+                        "content": "Hello!"
+                    }],
+                    temperature=0,
+                )
+                break
+            except OpenAIError as e:
+                if "Connection error" in str(e):
+                    time.sleep(3)
+                    if time.time() - now > RemoteOpenAIServer.MAX_START_WAIT_S:
+                        msg = "Server did not start in time"
+                        raise RuntimeError(msg) from e
+                else:
+                    raise e
+    finally:
+        server.terminate()

     generated_text = completion.choices[0].message.content
     assert generated_text is not None
diff --git a/tests/utils.py b/tests/utils.py
index 666694299d39..6333854bc43d 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -50,7 +50,7 @@ def _nvml():


 class RemoteOpenAIServer:
     DUMMY_API_KEY = "token-abc123"  # vLLM's OpenAI server does not need API key
-    MAX_SERVER_START_WAIT_S = 120  # wait for server to start for 120 seconds
+    MAX_START_WAIT_S = 120  # wait for server to start for 120 seconds

     def __init__(
@@ -85,7 +85,7 @@ def __init__(
                                         stdout=sys.stdout,
                                         stderr=sys.stderr)
         self._wait_for_server(url=self.url_for("health"),
-                              timeout=self.MAX_SERVER_START_WAIT_S)
+                              timeout=self.MAX_START_WAIT_S)

     def __enter__(self):
         return self

From a74a338273d8ab620c40149b1660bf605fc3b89b Mon Sep 17 00:00:00 2001
From: DarkLight1337
Date: Wed, 7 Aug 2024 07:20:54 +0000
Subject: [PATCH 9/9] isort

---
 tests/entrypoints/openai/test_oot_registration.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/entrypoints/openai/test_oot_registration.py b/tests/entrypoints/openai/test_oot_registration.py
index f33039b55a2c..9f9a4cd972c5 100644
--- a/tests/entrypoints/openai/test_oot_registration.py
+++ b/tests/entrypoints/openai/test_oot_registration.py
@@ -9,7 +9,7 @@
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.utils import get_open_port

-from ...utils import RemoteOpenAIServer, VLLM_PATH
+from ...utils import VLLM_PATH, RemoteOpenAIServer

 chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja"
 assert chatml_jinja_path.exists()
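
Taken together, patches 7-9 leave the out-of-tree registration test with a bounded, leak-free startup loop. A condensed sketch of that pattern follows, illustrative only; start_server_process and make_request are hypothetical stand-ins for the torch.multiprocessing and OpenAI-client details in the test above, and the exception type checked here is simplified.

    import time

    from ...utils import RemoteOpenAIServer


    def query_with_cleanup(start_server_process, make_request):
        server = start_server_process()  # e.g. ctx.Process(...); server.start()
        try:
            deadline = time.time() + RemoteOpenAIServer.MAX_START_WAIT_S
            while True:
                try:
                    return make_request()
                except ConnectionError as e:
                    if time.time() > deadline:
                        raise RuntimeError("Server did not start in time") from e
                    time.sleep(3)
        finally:
            # Runs on success, on timeout and on unexpected errors alike, so a
            # failing test can no longer leave the server process behind.
            server.terminate()
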