 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.utils import get_open_port
 
+from ...utils import VLLM_PATH, RemoteOpenAIServer
+
+chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja"
+assert chatml_jinja_path.exists()
+
 
 class MyOPTForCausalLM(OPTForCausalLM):
 
@@ -21,12 +26,25 @@ def compute_logits(self, hidden_states: torch.Tensor,
         return logits
 
 
-def server_function(port):
+def server_function(port: int):
     # register our dummy model
     ModelRegistry.register_model("OPTForCausalLM", MyOPTForCausalLM)
-    sys.argv = ["placeholder.py"] + \
-        ("--model facebook/opt-125m --gpu-memory-utilization 0.10 "
-         f"--dtype float32 --api-key token-abc123 --port {port}").split()
+
+    sys.argv = ["placeholder.py"] + [
+        "--model",
+        "facebook/opt-125m",
+        "--gpu-memory-utilization",
+        "0.10",
+        "--dtype",
+        "float32",
+        "--api-key",
+        "token-abc123",
+        "--port",
+        str(port),
+        "--chat-template",
+        str(chatml_jinja_path),
+    ]
+
     import runpy
     runpy.run_module('vllm.entrypoints.openai.api_server', run_name='__main__')
 
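A note on the launch pattern above: rather than shelling out to `python -m`, `server_function` registers the out-of-tree model with `ModelRegistry` and then starts the API server in the same process by rewriting `sys.argv` and calling `runpy.run_module`, so the registration stays visible to the server. A minimal standalone sketch of that trick (the helper name and example flags are illustrative, not part of this diff):

import runpy
import sys


def run_module_as_main(module: str, *cli_args: str) -> None:
    # Emulate `python -m <module> <args...>` inside the current process;
    # the target module's argument parser only reads sys.argv[1:].
    sys.argv = ["placeholder.py", *cli_args]
    runpy.run_module(module, run_name="__main__")


# e.g. run_module_as_main("vllm.entrypoints.openai.api_server",
#                         "--model", "facebook/opt-125m", "--port", "8000")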
@@ -36,35 +54,40 @@ def test_oot_registration_for_api_server():
     ctx = torch.multiprocessing.get_context()
     server = ctx.Process(target=server_function, args=(port, ))
     server.start()
-    MAX_SERVER_START_WAIT_S = 60
-    client = OpenAI(
-        base_url=f"http://localhost:{port}/v1",
-        api_key="token-abc123",
-    )
-    now = time.time()
-    while True:
-        try:
-            completion = client.chat.completions.create(
-                model="facebook/opt-125m",
-                messages=[{
-                    "role": "system",
-                    "content": "You are a helpful assistant."
-                }, {
-                    "role": "user",
-                    "content": "Hello!"
-                }],
-                temperature=0,
-            )
-            break
-        except OpenAIError as e:
-            if "Connection error" in str(e):
-                time.sleep(3)
-                if time.time() - now > MAX_SERVER_START_WAIT_S:
-                    raise RuntimeError("Server did not start in time") from e
-            else:
-                raise e
-    server.kill()
+
+    try:
+        client = OpenAI(
+            base_url=f"http://localhost:{port}/v1",
+            api_key="token-abc123",
+        )
+        now = time.time()
+        while True:
+            try:
+                completion = client.chat.completions.create(
+                    model="facebook/opt-125m",
+                    messages=[{
+                        "role": "system",
+                        "content": "You are a helpful assistant."
+                    }, {
+                        "role": "user",
+                        "content": "Hello!"
+                    }],
+                    temperature=0,
+                )
+                break
+            except OpenAIError as e:
+                if "Connection error" in str(e):
+                    time.sleep(3)
+                    if time.time() - now > RemoteOpenAIServer.MAX_START_WAIT_S:
+                        msg = "Server did not start in time"
+                        raise RuntimeError(msg) from e
+                else:
+                    raise e
+    finally:
+        server.terminate()
 
     generated_text = completion.choices[0].message.content
+    assert generated_text is not None
     # make sure only the first token is generated
     rest = generated_text.replace("<s>", "")
     assert rest == ""
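The rewritten wait loop is essentially a readiness probe: keep issuing a cheap chat completion until the connection succeeds, sleeping between attempts and enforcing a deadline on connection errors. A condensed sketch of the same pattern outside the test (the helper name and the 60-second default are illustrative; the test itself uses RemoteOpenAIServer.MAX_START_WAIT_S):

import time

from openai import OpenAI, OpenAIError


def wait_for_first_completion(port: int, timeout_s: float = 60.0):
    # Poll the OpenAI-compatible endpoint until the server answers.
    client = OpenAI(base_url=f"http://localhost:{port}/v1",
                    api_key="token-abc123")
    deadline = time.time() + timeout_s
    while True:
        try:
            return client.chat.completions.create(
                model="facebook/opt-125m",
                messages=[{"role": "user", "content": "Hello!"}],
                temperature=0,
            )
        except OpenAIError as e:
            # Only connection errors mean "not up yet"; anything else is real.
            if "Connection error" not in str(e):
                raise
            if time.time() > deadline:
                raise RuntimeError("Server did not start in time") from e
            time.sleep(3)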