Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions tests/tokenization/test_mistral_tokenizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# SPDX-License-Identifier: Apache-2.0

import pytest
from mistral_common.protocol.instruct.messages import UserMessage
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.protocol.instruct.tool_calls import Function, Tool

from vllm.transformers_utils.tokenizers.mistral import (
make_mistral_chat_completion_request)


# yapf: enable
@pytest.mark.parametrize(
    "openai_request,expected_mistral_request",
    [(
        {
            "messages": [{
                "role": "user",
                "content": "What is the current local date and time?",
            }],
            "tools": [{
                "type": "function",
                "function": {
                    "description": "Fetch the current local date and time.",
                    "name": "get_current_time",
                },
            }],
        },
        ChatCompletionRequest(
            messages=[
                UserMessage(content="What is the current local date and time?")
            ],
            tools=[
                Tool(
                    type="function",
                    function=Function(
                        name="get_current_time",
                        description="Fetch the current local date and time.",
                        # The OpenAI-style payload above omits "parameters";
                        # the converter is expected to default it to {}.
                        parameters={},
                    ),
                )
            ],
        ),
    )],
)
def test_make_mistral_chat_completion_request(openai_request,
                                              expected_mistral_request):
    """OpenAI-style dict payloads convert to the mistral-common request."""
    actual = make_mistral_chat_completion_request(openai_request["messages"],
                                                  openai_request["tools"])
    assert actual == expected_mistral_request
57 changes: 38 additions & 19 deletions vllm/transformers_utils/tokenizers/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,42 @@ def find_tokenizer_file(files: List[str]):
return matched_files[0]


def make_mistral_chat_completion_request(
        messages: List["ChatCompletionMessageParam"],
        tools: Optional[List[Dict[str,
                                  Any]]] = None) -> "ChatCompletionRequest":
    """Build a mistral-common ``ChatCompletionRequest`` from OpenAI-style
    ``messages`` (list of role/content dicts) and optional ``tools``.

    Mutates ``messages``/``tools`` in place to satisfy mistral-common's
    stricter schema, then hands them to ``ChatCompletionRequest``.
    """
    # A trailing assistant message marks a "prefix" continuation request.
    last_message = cast(Dict[str, Any], messages[-1])
    if last_message["role"] == "assistant":
        last_message["prefix"] = True

    # mistral-common requires AssistantMessage content to be string [1].
    #
    # [1]: https://github.com/mistralai/mistral-common/blob/f4a06998b75ed78bbf5aaf569590b772ea26c9f6/src/mistral_common/protocol/instruct/messages.py#L80
    for message in messages:
        if message.get("role") == "assistant":
            content = message.get("content")
            if isinstance(content, list):
                # NOTE(review): assumes every chunk is a text part with a
                # "text" key; a non-text chunk would join a None — confirm
                # callers only pass text parts here.
                content = "\n".join(chunk.get("text") for chunk in content)
            message["content"] = content

    # The Mistral client, in comparison to the OpenAI client, requires the
    # "parameters" dict to be present, even if it's empty.
    if tools:
        for function in [
                tool["function"] for tool in tools
                if tool["type"] == "function"
        ]:
            function.setdefault("parameters", {})

    # Imported lazily so importing this module does not require
    # mistral-common to be installed.
    from mistral_common.protocol.instruct.request import ChatCompletionRequest
    return ChatCompletionRequest(messages=messages,
                                 tools=tools)  # type: ignore[type-var]


class MistralTokenizer:

def __init__(self, tokenizer: "PublicMistralTokenizer") -> None:
Expand Down Expand Up @@ -283,27 +319,10 @@ def encode(self, prompt: str) -> List[int]:

def apply_chat_template(self,
messages: List["ChatCompletionMessageParam"],
tools: Optional[Dict[str, Any]] = None,
tools: Optional[List[Dict[str, Any]]] = None,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The typing here was wrong :)

**kwargs) -> List[int]:

last_message = cast(Dict[str, Any], messages[-1])
if last_message["role"] == "assistant":
last_message["prefix"] = True

from mistral_common.protocol.instruct.request import (
ChatCompletionRequest)

# mistral-common requires AssistantMessage content to be string [1].
#
# [1]: https://github.com/mistralai/mistral-common/blob/f4a06998b75ed78bbf5aaf569590b772ea26c9f6/src/mistral_common/protocol/instruct/messages.py#L80
for message in messages:
if message.get("role") == "assistant":
content = message.get("content")
if isinstance(content, list):
content = "\n".join(chunk.get("text") for chunk in content)
message["content"] = content
request = ChatCompletionRequest(messages=messages,
tools=tools) # type: ignore[type-var]
request = make_mistral_chat_completion_request(messages, tools)
encoded = self.mistral.encode_chat_completion(request)

# encode-decode to get clean prompt
Expand Down