Skip to content

Commit 0112192

Browse files
committed
add parametrized fixture
1 parent 2ba5d67 commit 0112192

File tree

1 file changed

+57
-1
lines changed

1 file changed

+57
-1
lines changed

tests/entrypoints/openai/test_basic.py

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from http import HTTPStatus
2+
from typing import List
23

34
import openai
45
import pytest
@@ -12,8 +13,44 @@
1213
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
1314

1415

16+
@pytest.fixture(scope="module")
def server_args(request: pytest.FixtureRequest) -> List[str]:
    """Provide extra arguments to the server via indirect parametrization.

    Usage:

    >>> @pytest.mark.parametrize(
    >>>     "server_args",
    >>>     [
    >>>         ["--disable-frontend-multiprocessing"],
    >>>         [
    >>>             "--model=NousResearch/Hermes-3-Llama-3.1-70B",
    >>>             "--enable-auto-tool-choice",
    >>>         ],
    >>>     ],
    >>>     indirect=True,
    >>> )
    >>> def test_foo(server, client):
    >>>     ...

    This will run `test_foo` twice with servers with:
    - `--disable-frontend-multiprocessing`
    - `--model=NousResearch/Hermes-3-Llama-3.1-70B --enable-auto-tool-choice`.
    """
    # Tests that do not parametrize this fixture (no indirect param
    # attached to the request) get a server with no extra arguments.
    if not hasattr(request, "param"):
        return []

    val = request.param

    # Allow a bare string param as shorthand for a single-argument list.
    if isinstance(val, str):
        return [val]

    return val
1552
@pytest.fixture(scope="module")
16-
def server():
53+
def server(server_args):
1754
args = [
1855
# use half precision for speed and memory savings in CI environment
1956
"--dtype",
@@ -23,6 +60,7 @@ def server():
2360
"--enforce-eager",
2461
"--max-num-seqs",
2562
"128",
63+
*server_args,
2664
]
2765

2866
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
@@ -35,6 +73,15 @@ async def client(server):
3573
yield async_client
3674

3775

76+
@pytest.mark.parametrize(
77+
"server_args",
78+
[
79+
pytest.param([], id="default-frontend-multiprocessing"),
80+
pytest.param(["--disable-frontend-multiprocessing"],
81+
id="disable-frontend-multiprocessing")
82+
],
83+
indirect=True,
84+
)
3885
@pytest.mark.asyncio
3986
async def test_show_version(client: openai.AsyncOpenAI):
4087
base_url = str(client.base_url)[:-3].strip("/")
@@ -45,6 +92,15 @@ async def test_show_version(client: openai.AsyncOpenAI):
4592
assert response.json() == {"version": VLLM_VERSION}
4693

4794

95+
@pytest.mark.parametrize(
96+
"server_args",
97+
[
98+
pytest.param([], id="default-frontend-multiprocessing"),
99+
pytest.param(["--disable-frontend-multiprocessing"],
100+
id="disable-frontend-multiprocessing")
101+
],
102+
indirect=True,
103+
)
48104
@pytest.mark.asyncio
49105
async def test_check_health(client: openai.AsyncOpenAI):
50106
base_url = str(client.base_url)[:-3].strip("/")

0 commit comments

Comments
 (0)