1
1
# Adapted from
2
2
# https://github.com/vllm-project/vllm/blob/aae6927be06dedbda39c6b0c30f6aa3242b84388/tests/entrypoints/openai/test_chat.py
3
+ import json
3
4
import os
5
+ import re
4
6
import tempfile
5
7
8
+ import jsonschema
6
9
import openai
7
10
import pytest
8
11
import yaml
9
12
10
- from ..test_llm import get_model_path , similar
13
+ from ..test_llm import get_model_path
11
14
from .openai_server import RemoteOpenAIServer
12
15
13
16
# Disable the thread-leak checker for every test in this module.
pytestmark = pytest.mark.threadleak(enabled=False)
14
17
15
18
16
@pytest.fixture(scope="module")
def model_name():
    """Name of the model under test, relative to the shared model root."""
    return "llama-3.1-model/Llama-3.1-8B-Instruct"
19
22
20
23
21
24
@pytest .fixture (scope = "module" )
22
- def temp_extra_llm_api_options_file (request ):
25
+ def temp_extra_llm_api_options_file ():
23
26
temp_dir = tempfile .gettempdir ()
24
27
temp_file_path = os .path .join (temp_dir , "extra_llm_api_options.yaml" )
25
28
try :
@@ -37,7 +40,12 @@ def temp_extra_llm_api_options_file(request):
37
40
@pytest.fixture(scope="module")
def server(model_name: str, temp_extra_llm_api_options_file: str):
    """Start a RemoteOpenAIServer for the module's model and yield it.

    The server is torn down automatically when the module's tests finish
    (context manager exit).
    """
    model_path = get_model_path(model_name)

    # Keep max_batch_size/max_seq_len/max_num_tokens small to avoid OOM
    # on A10/A30 GPUs.
    launch_args = [
        "--max_batch_size=8",
        "--max_seq_len=1024",
        "--max_num_tokens=1024",
        f"--extra_llm_api_options={temp_extra_llm_api_options_file}",
    ]
    with RemoteOpenAIServer(model_path, launch_args) as remote:
        yield remote
43
51
@@ -112,12 +120,7 @@ def tool_get_current_date():
112
120
113
121
def test_chat_structural_tag (client : openai .OpenAI , model_name : str ,
114
122
tool_get_current_weather , tool_get_current_date ):
115
- messages = [
116
- {
117
- "role" :
118
- "system" ,
119
- "content" :
120
- f"""
123
+ system_prompt = f"""
121
124
# Tool Instructions
122
125
- Always execute python code in messages that you share.
123
126
- When looking for real time information use relevant functions if available else fallback to brave_search
@@ -140,20 +143,24 @@ def test_chat_structural_tag(client: openai.OpenAI, model_name: str,
140
143
- Only call one function at a time
141
144
- Put the entire function call reply on one line
142
145
- Always add your sources when using search results to answer the user query
143
- You are a helpful assistant.""" ,
146
+ You are a helpful assistant."""
147
+ user_prompt = "You are in New York. Please get the current date and time, and the weather."
148
+
149
+ messages = [
150
+ {
151
+ "role" : "system" ,
152
+ "content" : system_prompt ,
144
153
},
145
154
{
146
- "role" :
147
- "user" ,
148
- "content" :
149
- "You are in New York. Please get the current date and time, and the weather." ,
155
+ "role" : "user" ,
156
+ "content" : user_prompt ,
150
157
},
151
158
]
152
159
153
160
chat_completion = client .chat .completions .create (
154
161
model = model_name ,
155
162
messages = messages ,
156
- max_completion_tokens = 100 ,
163
+ max_completion_tokens = 256 ,
157
164
response_format = {
158
165
"type" :
159
166
"structural_tag" ,
@@ -173,11 +180,18 @@ def test_chat_structural_tag(client: openai.OpenAI, model_name: str,
173
180
"triggers" : ["<function=" ],
174
181
},
175
182
)
176
- assert chat_completion .id is not None
177
- assert len (chat_completion .choices ) == 1
183
+
178
184
message = chat_completion .choices [0 ].message
179
185
assert message .content is not None
180
186
assert message .role == "assistant"
181
187
182
- reference = '<function=get_current_date>{"timezone": "America/New_York"}</function>\n <function=get_current_weather>{"city": "New York", "state": "NY", "unit": "fahrenheit"}</function>\n \n Sources:\n - get_current_date function\n - get_current_weather function'
183
- assert similar (chat_completion .choices [0 ].message .content , reference )
188
+ match = re .search (r'<function=get_current_weather>([\S\s]+?)</function>' ,
189
+ message .content )
190
+ params = json .loads (match .group (1 ))
191
+ jsonschema .validate (params ,
192
+ tool_get_current_weather ["function" ]["parameters" ])
193
+
194
+ match = re .search (r'<function=get_current_date>([\S\s]+?)</function>' ,
195
+ message .content )
196
+ params = json .loads (match .group (1 ))
197
+ jsonschema .validate (params , tool_get_current_date ["function" ]["parameters" ])
0 commit comments