@@ -6,6 +6,6 @@
 import dataclasses
 import random
 import string
-from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple, Union, Protocol
+from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple, Union, Protocol, cast
 
 import jinja2
@@ -338,6 +338,7 @@ def _convert_completion_to_chat_function(
                         }
                     ],
                 },
+                "logprobs": None,
                 "finish_reason": "tool_calls",
             }
         ],
@@ -1191,7 +1192,6 @@ def format_mistral_instruct(
         elif (
             message["role"] == "assistant"
             and message["content"] is not None
-            and isinstance(message["content"], str)
         ):
             prompt += " [/INST]" + message["content"] + eos
     prompt += " [/INST]"
@@ -1263,7 +1263,7 @@ def format_gemma(
     **kwargs: Any,
 ) -> ChatFormatterResponse:
     system_message = _get_system_message(messages)
-    if system_message is not None and system_message != "":
+    if system_message != "":
         logger.debug(
             "`role='system'` messages are not allowed on Google's Gemma models."
         )
@@ -1628,6 +1628,7 @@ def message_to_str(msg: llama_types.ChatCompletionRequestMessage):
                         }
                     ],
                 },
+                "logprobs": None,
                 "finish_reason": "tool_calls",
             }
         ],
@@ -1909,14 +1910,14 @@ def get_grammar(function_call):
         return grammar
 
     def create_completion(stop):
-        completion: llama_types.Completion = llama.create_completion(
+        completion = cast(llama_types.Completion, llama.create_completion(
             prompt=prompt,
             temperature=temperature,
             top_p=top_p,
             top_k=top_k,
             min_p=min_p,
             typical_p=typical_p,
-            stream=stream,
+            stream=False,
             stop=stop,
             max_tokens=max_tokens,
             presence_penalty=presence_penalty,
@@ -1929,7 +1930,7 @@ def create_completion(stop):
             model=model,
             logits_processor=logits_processor,
             grammar=grammar,
-        )
+        ))
 
         return completion
 
@@ -2050,7 +2051,7 @@ def create_completion(stop):
     assert "usage" in completion
     assert len(function_calls) == len(function_bodies)
 
-    tool_calls = []
+    tool_calls: List[llama_types.ChatCompletionMessageToolCall] = []
     for function_call, function_body in zip(function_calls, function_bodies):
         tool_calls.append(
             {
@@ -2070,6 +2071,12 @@ def create_completion(stop):
         )
 
     # TODO: support stream mode
+    function_call_dict: Union[Dict[str, str], Dict[Literal["function_call"], llama_types.ChatCompletionRequestAssistantMessageFunctionCall]] = {
+        "function_call": {
+            "name": tool_calls[0]["function"]["name"],
+            "arguments": tool_calls[0]["function"]["arguments"],
+        }
+    } if len(tool_calls) == 1 else {}
    return llama_types.CreateChatCompletionResponse(
        id="chat" + completion["id"],
        object="chat.completion",
@@ -2078,14 +2085,12 @@ def create_completion(stop):
         choices=[
             {
                 "index": 0,
+                "logprobs": None,
                 "message": {
                     "role": "assistant",
                     "content": None if content == "" else content,
-                    "function_call": {
-                        "name": tool_calls[0]["function"]["name"],
-                        "arguments": tool_calls[0]["function"]["arguments"],
-                    } if len(tool_calls) > 0 else None,
-                    "tool_calls": tool_calls if len(tool_calls) > 0 else None,
+                    "tool_calls": tool_calls,
+                    **function_call_dict,
                 },
                 "finish_reason": "tool_calls" if len(tool_calls) > 0 else "stop",
             }
@@ -2565,8 +2570,8 @@ def chatml_function_calling(
         tool_name = text[len("functions.") :]
         tool = next((tool for tool in tools if tool["function"]["name"] == tool_name), None)
         if not stream:
-            completions = []
-            completions_tool_name = []
+            completions: List[llama_types.CreateCompletionResponse] = []
+            completions_tool_name: List[str] = []
             while tool is not None:
                 prompt += f"functions.{tool_name}:\n"
                 try:
@@ -2603,6 +2608,7 @@ def chatml_function_calling(
                     logits_processor=logits_processor,
                     grammar=grammar,
                 )
+                completion_or_chunks = cast(llama_types.CreateCompletionResponse, completion_or_chunks)
                 completions.append(completion_or_chunks)
                 completions_tool_name.append(tool_name)
                 prompt += completion_or_chunks["choices"][0]["text"]
@@ -2631,14 +2637,15 @@ def chatml_function_calling(
                         follow_up_gbnf_tool_grammar, verbose=llama.verbose
                     ),
                 )
+                response = cast(llama_types.CreateCompletionResponse, response)
 
                 tool_name = response["choices"][0]["text"][len("functions.") :]
                 tool = next(
                     (tool for tool in tools if tool["function"]["name"] == tool_name), None
                 )
 
             # Merge completions
-            function_call = {
+            function_call_dict: Union[Dict[str, str], Dict[Literal["function_call"], llama_types.ChatCompletionRequestAssistantMessageFunctionCall]] = {
                 "function_call": {
                     "name": tool_name,
                     "arguments": completions[0]["choices"][0]["text"],
@@ -2653,6 +2660,7 @@ def chatml_function_calling(
                 {
                     "finish_reason": "tool_calls",
                     "index": 0,
+                    "logprobs": None,
                     "message": {
                         "role": "assistant",
                         "content": None,
@@ -2673,20 +2681,22 @@ def chatml_function_calling(
                                 zip(completions_tool_name, completions)
                             )
                         ],
-                        **function_call
+                        **function_call_dict
                     },
                 }
            ],
            "usage": {
                "completion_tokens": sum(
-                    completion["usage"]["completion_tokens"]
+                    completion["usage"]["completion_tokens"] if "usage" in completion else 0
                    for completion in completions
                ),
                "prompt_tokens": sum(
-                    completion["usage"]["prompt_tokens"] for completion in completions
+                    completion["usage"]["prompt_tokens"] if "usage" in completion else 0
+                    for completion in completions
                ),
                "total_tokens": sum(
-                    completion["usage"]["total_tokens"] for completion in completions
+                    completion["usage"]["total_tokens"] if "usage" in completion else 0
+                    for completion in completions
                ),
            },
        }
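
A note on the cast pattern this commit leans on: Llama.create_completion is annotated as returning either a single completion or an iterator of stream chunks, so even after stream=False is pinned the static type is still a union, and cast() is used to narrow it for the type checker. The sketch below shows the same idea in isolation; the Completion TypedDict and the create_completion signature are simplified stand-ins for illustration, not llama-cpp-python's actual definitions.

from typing import Iterator, TypedDict, Union, cast

class Completion(TypedDict):
    id: str
    text: str

def create_completion(stream: bool) -> Union[Completion, Iterator[Completion]]:
    # Stand-in for an API whose declared return type is a union:
    # which variant the caller gets depends on the runtime `stream` flag.
    return {"id": "cmpl-1", "text": "hello"}

# With stream=False the result is known to be a single Completion, but the
# checker only sees the union; cast() narrows it with no runtime cost.
completion = cast(Completion, create_completion(stream=False))
print(completion["text"])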
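The function_call_dict change shows the other recurring pattern here: build the optional, deprecated function_call field as a separately annotated dict and splat it into the message with **, so the key is present only when there is exactly one tool call. A reduced sketch under assumptions: FunctionCall below is a hypothetical TypedDict standing in for llama_types.ChatCompletionRequestAssistantMessageFunctionCall, and the tool-call payload is invented for illustration.

from typing import Any, Dict, List, Literal, TypedDict, Union

class FunctionCall(TypedDict):
    name: str
    arguments: str

tool_calls: List[Dict[str, Any]] = [
    {"function": {"name": "get_weather", "arguments": '{"city": "Paris"}'}}
]

# Emit the deprecated `function_call` field only for a single tool call;
# otherwise spread an empty dict so the key is omitted entirely.
function_call_dict: Union[Dict[str, str], Dict[Literal["function_call"], FunctionCall]] = (
    {
        "function_call": {
            "name": tool_calls[0]["function"]["name"],
            "arguments": tool_calls[0]["function"]["arguments"],
        }
    }
    if len(tool_calls) == 1
    else {}
)

message = {
    "role": "assistant",
    "tool_calls": tool_calls,
    **function_call_dict,  # adds "function_call" only when built above
}
print("function_call" in message)  # True only when exactly one tool call exists

The same defensive style appears in the usage totals, where each completion["usage"][...] read is guarded with `if "usage" in completion else 0` so a completion lacking usage data cannot raise a KeyError.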