Skip to content

Commit 3e46624

Browse files
nv-guomingz and syuoni authored
[https://nvbugs/5375594][fix] fix oom issue on structural_tag test case (#6838)
Signed-off-by: nv-guomingz <[email protected]>
Signed-off-by: Enwei Zhu <[email protected]>
Co-authored-by: Enwei Zhu <[email protected]>
1 parent fd8f417 commit 3e46624

File tree

3 files changed

+35
-26
lines changed

3 files changed

+35
-26
lines changed

tests/integration/test_lists/waives.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@ examples/test_multimodal.py::test_llm_multimodal_general[video-neva-pp:1-tp:1-bf
4343
examples/test_whisper.py::test_llm_whisper_general[large-v3-enable_gemm_plugin-enable_attention_plugin-disable_weight_only-float16-nb:1-use_python_runtime] SKIP (https://nvbugs/4866931)
4444
examples/test_nemotron.py::test_llm_nemotron_3_8b_1gpu[bfloat16-fp8] SKIP (https://nvbugs/4961624)
4545
examples/test_mistral.py::test_llm_mistral_v1_1gpu[mistral-7b-v0.1-float16-max_attention_window_size_4096-chunked_summarization_long] SKIP (https://nvbugs/5321371)
46-
test_e2e.py::test_openai_chat_structural_tag_example SKIP (https://nvbugspro.nvidia.com/bug/5375594)
4746
cpp/test_e2e.py::test_model[fp8-chatglm-90] SKIP (https://nvbugs/5034830)
4847
full:B200_PCIe/unittest/trt/functional SKIP (Disable for Blackwell)
4948
full:B200_PCIe/unittest/trt/quantization SKIP (Disable for Blackwell)

tests/unittest/llmapi/apps/_test_openai_chat_json.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,7 @@ def temp_extra_llm_api_options_file(request):
2626
temp_dir = tempfile.gettempdir()
2727
temp_file_path = os.path.join(temp_dir, "extra_llm_api_options.yaml")
2828
try:
29-
extra_llm_api_options_dict = {
30-
"guided_decoding_backend": "xgrammar",
31-
"disable_overlap_scheduler":
32-
True, # Guided decoding is not supported with overlap scheduler
33-
}
29+
extra_llm_api_options_dict = {"guided_decoding_backend": "xgrammar"}
3430

3531
with open(temp_file_path, "w") as f:
3632
yaml.dump(extra_llm_api_options_dict, f)

tests/unittest/llmapi/apps/_test_openai_chat_structural_tag.py

Lines changed: 34 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,28 @@
11
# Adapted from
22
# https://github.com/vllm-project/vllm/blob/aae6927be06dedbda39c6b0c30f6aa3242b84388/tests/entrypoints/openai/test_chat.py
3+
import json
34
import os
5+
import re
46
import tempfile
57

8+
import jsonschema
69
import openai
710
import pytest
811
import yaml
912

10-
from ..test_llm import get_model_path, similar
13+
from ..test_llm import get_model_path
1114
from .openai_server import RemoteOpenAIServer
1215

1316
pytestmark = pytest.mark.threadleak(enabled=False)
1417

1518

16-
@pytest.fixture(scope="module", ids=["TinyLlama-1.1B-Chat"])
19+
@pytest.fixture(scope="module")
1720
def model_name():
1821
return "llama-3.1-model/Llama-3.1-8B-Instruct"
1922

2023

2124
@pytest.fixture(scope="module")
22-
def temp_extra_llm_api_options_file(request):
25+
def temp_extra_llm_api_options_file():
2326
temp_dir = tempfile.gettempdir()
2427
temp_file_path = os.path.join(temp_dir, "extra_llm_api_options.yaml")
2528
try:
@@ -37,7 +40,12 @@ def temp_extra_llm_api_options_file(request):
3740
@pytest.fixture(scope="module")
3841
def server(model_name: str, temp_extra_llm_api_options_file: str):
3942
model_path = get_model_path(model_name)
40-
args = ["--extra_llm_api_options", temp_extra_llm_api_options_file]
43+
44+
# Use small max_batch_size/max_seq_len/max_num_tokens to avoid OOM on A10/A30 GPUs.
45+
args = [
46+
"--max_batch_size=8", "--max_seq_len=1024", "--max_num_tokens=1024",
47+
f"--extra_llm_api_options={temp_extra_llm_api_options_file}"
48+
]
4149
with RemoteOpenAIServer(model_path, args) as remote_server:
4250
yield remote_server
4351

@@ -112,12 +120,7 @@ def tool_get_current_date():
112120

113121
def test_chat_structural_tag(client: openai.OpenAI, model_name: str,
114122
tool_get_current_weather, tool_get_current_date):
115-
messages = [
116-
{
117-
"role":
118-
"system",
119-
"content":
120-
f"""
123+
system_prompt = f"""
121124
# Tool Instructions
122125
- Always execute python code in messages that you share.
123126
- When looking for real time information use relevant functions if available else fallback to brave_search
@@ -140,20 +143,24 @@ def test_chat_structural_tag(client: openai.OpenAI, model_name: str,
140143
- Only call one function at a time
141144
- Put the entire function call reply on one line
142145
- Always add your sources when using search results to answer the user query
143-
You are a helpful assistant.""",
146+
You are a helpful assistant."""
147+
user_prompt = "You are in New York. Please get the current date and time, and the weather."
148+
149+
messages = [
150+
{
151+
"role": "system",
152+
"content": system_prompt,
144153
},
145154
{
146-
"role":
147-
"user",
148-
"content":
149-
"You are in New York. Please get the current date and time, and the weather.",
155+
"role": "user",
156+
"content": user_prompt,
150157
},
151158
]
152159

153160
chat_completion = client.chat.completions.create(
154161
model=model_name,
155162
messages=messages,
156-
max_completion_tokens=100,
163+
max_completion_tokens=256,
157164
response_format={
158165
"type":
159166
"structural_tag",
@@ -173,11 +180,18 @@ def test_chat_structural_tag(client: openai.OpenAI, model_name: str,
173180
"triggers": ["<function="],
174181
},
175182
)
176-
assert chat_completion.id is not None
177-
assert len(chat_completion.choices) == 1
183+
178184
message = chat_completion.choices[0].message
179185
assert message.content is not None
180186
assert message.role == "assistant"
181187

182-
reference = '<function=get_current_date>{"timezone": "America/New_York"}</function>\n<function=get_current_weather>{"city": "New York", "state": "NY", "unit": "fahrenheit"}</function>\n\nSources:\n- get_current_date function\n- get_current_weather function'
183-
assert similar(chat_completion.choices[0].message.content, reference)
188+
match = re.search(r'<function=get_current_weather>([\S\s]+?)</function>',
189+
message.content)
190+
params = json.loads(match.group(1))
191+
jsonschema.validate(params,
192+
tool_get_current_weather["function"]["parameters"])
193+
194+
match = re.search(r'<function=get_current_date>([\S\s]+?)</function>',
195+
message.content)
196+
params = json.loads(match.group(1))
197+
jsonschema.validate(params, tool_get_current_date["function"]["parameters"])

0 commit comments

Comments (0)