
Commit 1ae3abb

fix: missing logprobs in response, incorrect response type for functionary, minor type issues. Closes #1328 Closes #1314
1 parent 9111b6e commit 1ae3abb

1 file changed: +29 -19 lines changed

llama_cpp/llama_chat_format.py

@@ -6,7 +6,7 @@
 import dataclasses
 import random
 import string
-from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple, Union, Protocol
+from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple, Union, Protocol, cast

 import jinja2

@@ -338,6 +338,7 @@ def _convert_completion_to_chat_function(
                     }
                 ],
             },
+            "logprobs": None,
             "finish_reason": "tool_calls",
         }
     ],
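Note: the OpenAI chat-completion schema expects every choice object to carry a "logprobs" key even when log probabilities were not requested, which is why the converted tool-call choices now set it to None explicitly. A minimal sketch of the resulting choice shape, using a plain dict with made-up tool-call data rather than the library's typed response:

from typing import Any, Dict, List

def make_tool_call_choice(tool_calls: List[Dict[str, Any]]) -> Dict[str, Any]:
    # The explicit "logprobs": None keeps the choice shaped like the OpenAI
    # chat-completion schema even when log probabilities were not requested.
    return {
        "index": 0,
        "message": {"role": "assistant", "content": None, "tool_calls": tool_calls},
        "logprobs": None,
        "finish_reason": "tool_calls",
    }

print(make_tool_call_choice([{"id": "call_0", "type": "function",
                              "function": {"name": "get_weather", "arguments": "{}"}}]))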
@@ -1191,7 +1192,6 @@ def format_mistral_instruct(
         elif (
             message["role"] == "assistant"
             and message["content"] is not None
-            and isinstance(message["content"], str)
         ):
             prompt += " [/INST]" + message["content"] + eos
     prompt += " [/INST]"
@@ -1263,7 +1263,7 @@ def format_gemma(
     **kwargs: Any,
 ) -> ChatFormatterResponse:
     system_message = _get_system_message(messages)
-    if system_message is not None and system_message != "":
+    if system_message != "":
         logger.debug(
             "`role='system'` messages are not allowed on Google's Gemma models."
         )
@@ -1628,6 +1628,7 @@ def message_to_str(msg: llama_types.ChatCompletionRequestMessage):
                     }
                 ],
             },
+            "logprobs": None,
             "finish_reason": "tool_calls",
         }
     ],
@@ -1909,14 +1910,14 @@ def get_grammar(function_call):
         return grammar

     def create_completion(stop):
-        completion: llama_types.Completion = llama.create_completion(
+        completion = cast(llama_types.Completion, llama.create_completion(
             prompt=prompt,
             temperature=temperature,
             top_p=top_p,
             top_k=top_k,
             min_p=min_p,
             typical_p=typical_p,
-            stream=stream,
+            stream=False,
             stop=stop,
             max_tokens=max_tokens,
             presence_penalty=presence_penalty,
@@ -1929,7 +1930,7 @@ def create_completion(stop):
             model=model,
             logits_processor=logits_processor,
             grammar=grammar,
-        )
+        ))

         return completion

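Note: llama.create_completion is typed as returning either a full completion or an iterator of chunks depending on stream, so hard-coding stream=False and wrapping the call in typing.cast narrows that union for the type checker without changing runtime behaviour. A minimal sketch of the pattern, with simplified stand-in types (Completion and CompletionChunk below are not the library's real definitions):

from typing import Dict, Iterator, Union, cast

Completion = Dict[str, object]       # stand-in for llama_types.Completion
CompletionChunk = Dict[str, object]  # stand-in for a streamed chunk

def create_completion(stream: bool) -> Union[Completion, Iterator[CompletionChunk]]:
    # Simplified stand-in for Llama.create_completion.
    if stream:
        return iter([{"choices": [{"text": "partial"}]}])
    return {"choices": [{"text": "full"}]}

# With stream=False the union collapses to a single completion dict;
# cast() records that fact for the type checker and is a no-op at runtime.
completion = cast(Completion, create_completion(stream=False))
print(completion["choices"])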
@@ -2050,7 +2051,7 @@ def create_completion(stop):
         assert "usage" in completion
         assert len(function_calls) == len(function_bodies)

-        tool_calls = []
+        tool_calls: List[llama_types.ChatCompletionMessageToolCall] = []
         for function_call, function_body in zip(function_calls, function_bodies):
             tool_calls.append(
                 {
@@ -2070,6 +2071,12 @@ def create_completion(stop):
             )

         # TODO: support stream mode
+        function_call_dict: Union[Dict[str, str], Dict[Literal["function_call"], llama_types.ChatCompletionRequestAssistantMessageFunctionCall]] = {
+            "function_call": {
+                "name": tool_calls[0]["function"]["name"],
+                "arguments": tool_calls[0]["function"]["arguments"],
+            }
+        } if len(tool_calls) == 1 else {}
         return llama_types.CreateChatCompletionResponse(
             id="chat" + completion["id"],
             object="chat.completion",
@@ -2078,14 +2085,12 @@ def create_completion(stop):
             choices=[
                 {
                     "index": 0,
+                    "logprobs": None,
                     "message": {
                         "role": "assistant",
                         "content": None if content == "" else content,
-                        "function_call": {
-                            "name": tool_calls[0]["function"]["name"],
-                            "arguments": tool_calls[0]["function"]["arguments"],
-                        } if len(tool_calls) > 0 else None,
-                        "tool_calls": tool_calls if len(tool_calls) > 0 else None,
+                        "tool_calls": tool_calls,
+                        **function_call_dict,
                     },
                     "finish_reason": "tool_calls" if len(tool_calls) > 0 else "stop",
                 }
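Note: spreading **function_call_dict means the deprecated top-level "function_call" field is only present when exactly one tool call was produced; previously the message always carried function_call/tool_calls entries that could be None. A small sketch of the pattern with plain dicts (names and data are illustrative only):

from typing import Any, Dict, List

def build_assistant_message(tool_calls: List[Dict[str, Any]]) -> Dict[str, Any]:
    # Include the legacy "function_call" key only for a single tool call;
    # with zero or several calls the key is omitted rather than set to None.
    function_call_dict: Dict[str, Any] = (
        {"function_call": tool_calls[0]["function"]} if len(tool_calls) == 1 else {}
    )
    return {
        "role": "assistant",
        "content": None,
        "tool_calls": tool_calls,
        **function_call_dict,
    }

print(build_assistant_message([{"function": {"name": "get_weather", "arguments": "{}"}}]))
print(build_assistant_message([]))  # no "function_call" key at all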
@@ -2565,8 +2570,8 @@ def chatml_function_calling(
         tool_name = text[len("functions.") :]
         tool = next((tool for tool in tools if tool["function"]["name"] == tool_name), None)
         if not stream:
-            completions = []
-            completions_tool_name = []
+            completions: List[llama_types.CreateCompletionResponse] = []
+            completions_tool_name: List[str] = []
             while tool is not None:
                 prompt += f"functions.{tool_name}:\n"
                 try:
@@ -2603,6 +2608,7 @@ def chatml_function_calling(
                     logits_processor=logits_processor,
                     grammar=grammar,
                 )
+                completion_or_chunks = cast(llama_types.CreateCompletionResponse, completion_or_chunks)
                 completions.append(completion_or_chunks)
                 completions_tool_name.append(tool_name)
                 prompt += completion_or_chunks["choices"][0]["text"]
@@ -2631,14 +2637,15 @@ def chatml_function_calling(
                     follow_up_gbnf_tool_grammar, verbose=llama.verbose
                 ),
            )
+            response = cast(llama_types.CreateCompletionResponse, response)

            tool_name = response["choices"][0]["text"][len("functions.") :]
            tool = next(
                (tool for tool in tools if tool["function"]["name"] == tool_name), None
            )

            # Merge completions
-            function_call = {
+            function_call_dict: Union[Dict[str, str], Dict[Literal["function_call"], llama_types.ChatCompletionRequestAssistantMessageFunctionCall]] = {
                "function_call": {
                    "name": tool_name,
                    "arguments": completions[0]["choices"][0]["text"],
@@ -2653,6 +2660,7 @@ def chatml_function_calling(
                {
                    "finish_reason": "tool_calls",
                    "index": 0,
+                    "logprobs": None,
                    "message": {
                        "role": "assistant",
                        "content": None,
@@ -2673,20 +2681,22 @@ def chatml_function_calling(
                            zip(completions_tool_name, completions)
                        )
                    ],
-                    **function_call
+                    **function_call_dict
                },
            }
        ],
        "usage": {
            "completion_tokens": sum(
-                completion["usage"]["completion_tokens"]
+                completion["usage"]["completion_tokens"] if "usage" in completion else 0
                for completion in completions
            ),
            "prompt_tokens": sum(
-                completion["usage"]["prompt_tokens"] for completion in completions
+                completion["usage"]["prompt_tokens"] if "usage" in completion else 0
+                for completion in completions
            ),
            "total_tokens": sum(
-                completion["usage"]["total_tokens"] for completion in completions
+                completion["usage"]["total_tokens"] if "usage" in completion else 0
+                for completion in completions
            ),
        },
    }
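Note: when several per-tool completions are merged into one chat response, any of them may lack a "usage" block, so the token totals now substitute 0 instead of raising KeyError. A minimal sketch of the aggregation with made-up data:

completions = [
    {"usage": {"completion_tokens": 5, "prompt_tokens": 20, "total_tokens": 25}},
    {},  # e.g. a completion that carried no usage information
]

usage = {
    key: sum(c["usage"][key] if "usage" in c else 0 for c in completions)
    for key in ("completion_tokens", "prompt_tokens", "total_tokens")
}
print(usage)  # {'completion_tokens': 5, 'prompt_tokens': 20, 'total_tokens': 25}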
