
Commit e3bd1e1

vmpuri authored and vmpuri committed
Make API and server compatible with OpenAI API
1 parent c7f56f2 commit e3bd1e1

File tree

README.md
api/api.py
browser/browser.py
server.py

4 files changed: +89 -97 lines changed

README.md

Lines changed: 13 additions & 10 deletions
@@ -181,16 +181,6 @@ This mode generates text based on an input prompt.
 python3 torchchat.py generate llama3.1 --prompt "write me a story about a boy and his bear"
 ```
 
-### Browser
-This mode allows you to chat with the model using a UI in your browser
-Running the command automatically open a tab in your browser.
-
-[skip default]: begin
-
-```
-streamlit run torchchat.py -- browser llama3.1
-```
-
 [skip default]: end
 
 ### Server
@@ -252,6 +242,19 @@ curl http://127.0.0.1:5000/v1/chat \
 
 </details>
 
+### Browser
+This command opens a basic browser interface for local chat by querying a local server.
+
+First, follow the steps in the Server section above to start a local server. Then, in another terminal, launch the interface. Running the following will open a tab in your browser.
+
+[skip default]: begin
+
+```
+streamlit run browser/browser.py
+```
+
+Use the "Max Response Tokens" slider to limit the maximum number of tokens generated by the model for each response. Click the "Reset Chat" button to remove the message history and start a fresh chat.
+
 
 ## Desktop/Server Execution
 
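The Server endpoint referenced above can also be exercised directly, without the browser UI. The following is an illustrative sketch only (it assumes a torchchat server already running on http://127.0.0.1:5000 as described in the Server section, a model alias of "llama3", and the `requests` package); with `stream` omitted, the server returns a single JSON completion object.

```python
# Illustrative sketch, not part of this commit: query the local torchchat server's
# OpenAI-style chat completions endpoint with a plain HTTP request.
import requests

resp = requests.post(
    "http://127.0.0.1:5000/v1/chat/completions",
    json={
        "model": "llama3",  # assumed model alias; use whatever the server was started with
        "messages": [
            {"role": "user", "content": "write me a story about a boy and his bear"}
        ],
    },
)
completion = resp.json()
print(completion["choices"][0]["message"]["content"])
```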
api/api.py

Lines changed: 9 additions & 5 deletions
@@ -125,6 +125,9 @@ class CompletionRequest:
     parallel_tool_calls: Optional[bool] = None  # unimplemented - Assistant features
     user: Optional[str] = None  # unimplemented
 
+    def __post_init__(self):
+        self.stream = bool(self.stream)
+
 
 @dataclass
 class CompletionChoice:
@@ -204,7 +207,7 @@ class CompletionResponseChunk:
     choices: List[CompletionChoiceChunk]
     created: int
     model: str
-    system_fingerprint: str
+    system_fingerprint: Optional[str] = None
     service_tier: Optional[str] = None
     object: str = "chat.completion.chunk"
     usage: Optional[UsageStats] = None
@@ -311,7 +314,7 @@ def callback(x, *, done_generating=False):
             sequential_prefill=generator_args.sequential_prefill,
             start_pos=start_pos,
             max_seq_length=self.max_seq_length,
-            seed=int(completion_request.seed),
+            seed=int(completion_request.seed or 0),
         ):
             if y is None:
                 continue
@@ -333,9 +336,10 @@ def callback(x, *, done_generating=False):
             choice_chunk = CompletionChoiceChunk(
                 delta=chunk_delta,
                 index=idx,
+                finish_reason=None,
             )
             chunk_response = CompletionResponseChunk(
-                id=str(id),
+                id="chatcmpl-" + str(id),
                 choices=[choice_chunk],
                 created=int(time.time()),
                 model=completion_request.model,
@@ -351,7 +355,7 @@ def callback(x, *, done_generating=False):
         )
 
         yield CompletionResponseChunk(
-            id=str(id),
+            id="chatcmpl-" + str(id),
            choices=[end_chunk],
            created=int(time.time()),
            model=completion_request.model,
@@ -367,7 +371,7 @@ def sync_completion(self, request: CompletionRequest):
 
         message = AssistantMessage(content=output)
         return CompletionResponse(
-            id=str(uuid.uuid4()),
+            id="chatcmpl-" + str(uuid.uuid4()),
             choices=[
                 CompletionChoice(
                     finish_reason="stop",
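The `__post_init__` hook and the `seed=int(completion_request.seed or 0)` guard both deal with fields a client may omit, which arrive as `None`. A trimmed, hypothetical stand-in for the dataclass (only `model`, `stream`, and `seed` are kept, and the defaults shown are assumptions) illustrates the intended behavior:

```python
# Hypothetical, trimmed stand-in for CompletionRequest, shown only to illustrate
# the coercion added in __post_init__; the real dataclass has many more fields.
from dataclasses import dataclass
from typing import Optional


@dataclass
class CompletionRequestSketch:
    model: str
    stream: Optional[bool] = None  # omitted by the client -> arrives as None
    seed: Optional[int] = None

    def __post_init__(self):
        # None becomes False, so the server can simply test `if req.stream:`.
        self.stream = bool(self.stream)


req = CompletionRequestSketch(model="llama3")
assert req.stream is False
assert int(req.seed or 0) == 0  # mirrors seed=int(completion_request.seed or 0)
```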

browser/browser.py

Lines changed: 61 additions & 78 deletions
@@ -7,85 +7,68 @@
 import time
 
 import streamlit as st
-from api.api import CompletionRequest, OpenAiApiGenerator
-
-from build.builder import BuilderArgs, TokenizerArgs
-
-from generate import GeneratorArgs
-
-
-def main(args):
-    builder_args = BuilderArgs.from_args(args)
-    speculative_builder_args = BuilderArgs.from_speculative_args(args)
-    tokenizer_args = TokenizerArgs.from_args(args)
-    generator_args = GeneratorArgs.from_args(args)
-    generator_args.chat_mode = False
-
-    @st.cache_resource
-    def initialize_generator() -> OpenAiApiGenerator:
-        return OpenAiApiGenerator(
-            builder_args,
-            speculative_builder_args,
-            tokenizer_args,
-            generator_args,
-            args.profile,
-            args.quantize,
-            args.draft_quantize,
-        )
-
-    gen = initialize_generator()
-
-    st.title("torchchat")
-
-    # Initialize chat history
-    if "messages" not in st.session_state:
-        st.session_state.messages = []
-
-    # Display chat messages from history on app rerun
-    for message in st.session_state.messages:
-        with st.chat_message(message["role"]):
-            st.markdown(message["content"])
-
-    # Accept user input
-    if prompt := st.chat_input("What is up?"):
-        # Add user message to chat history
-        st.session_state.messages.append({"role": "user", "content": prompt})
-        # Display user message in chat message container
-        with st.chat_message("user"):
-            st.markdown(prompt)
-
-        # Display assistant response in chat message container
-        with st.chat_message("assistant"), st.status(
-            "Generating... ", expanded=True
-        ) as status:
-
-            req = CompletionRequest(
-                model=gen.builder_args.checkpoint_path,
-                prompt=prompt,
-                temperature=generator_args.temperature,
-                messages=[],
+from openai import OpenAI
+
+st.title("torchchat")
+
+start_state = [
+    {
+        "role": "system",
+        "content": "You're an assistant. Answer questions directly, be brief, and have fun.",
+    },
+    {"role": "assistant", "content": "How can I help you?"},
+]
+
+with st.sidebar:
+    response_max_tokens = st.slider(
+        "Max Response Tokens", min_value=10, max_value=1000, value=250, step=10
+    )
+    if st.button("Reset Chat", type="primary"):
+        st.session_state["messages"] = start_state
+
+if "messages" not in st.session_state:
+    st.session_state["messages"] = start_state
+
+
+for msg in st.session_state.messages:
+    st.chat_message(msg["role"]).write(msg["content"])
+
+if prompt := st.chat_input():
+    client = OpenAI(
+        base_url="http://127.0.0.1:5000/v1",
+        api_key="813",  # The OpenAI API requires an API key, but since we don't consume it, this can be any non-empty string.
+    )
+
+    st.session_state.messages.append({"role": "user", "content": prompt})
+    st.chat_message("user").write(prompt)
+
+    with st.chat_message("assistant"), st.status(
+        "Generating... ", expanded=True
+    ) as status:
+
+        def get_streamed_completion(completion_generator):
+            start = time.time()
+            tokcount = 0
+            for chunk in completion_generator:
+                tokcount += 1
+                yield chunk.choices[0].delta.content
+
+            status.update(
+                label="Done, averaged {:.2f} tokens/second".format(
+                    tokcount / (time.time() - start)
+                ),
+                state="complete",
             )
 
-            def unwrap(completion_generator):
-                start = time.time()
-                tokcount = 0
-                for chunk_response in completion_generator:
-                    content = chunk_response.choices[0].delta.content
-                    if not gen.is_llama3_model or content not in set(
-                        gen.tokenizer.special_tokens.keys()
-                    ):
-                        yield content
-                    if content == gen.tokenizer.eos_id():
-                        yield "."
-                    tokcount += 1
-                status.update(
-                    label="Done, averaged {:.2f} tokens/second".format(
-                        tokcount / (time.time() - start)
-                    ),
-                    state="complete",
+        response = st.write_stream(
+            get_streamed_completion(
+                client.chat.completions.create(
+                    model="llama3",
+                    messages=st.session_state.messages,
+                    max_tokens=response_max_tokens,
+                    stream=True,
                 )
+            )
+        )[0]
 
-            response = st.write_stream(unwrap(gen.completion(req)))
-
-            # Add assistant response to chat history
-            st.session_state.messages.append({"role": "assistant", "content": response})
+    st.session_state.messages.append({"role": "assistant", "content": response})
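Since the rewritten browser/browser.py talks to the server purely through the `openai` client, the same streaming call also works outside Streamlit. A minimal sketch, assuming a local server on port 5000 and the `openai` package installed:

```python
# Minimal sketch: consume the local torchchat server's streaming endpoint with the
# openai client, the same way browser/browser.py does, but from a plain script.
from openai import OpenAI

client = OpenAI(
    base_url="http://127.0.0.1:5000/v1",
    api_key="unused",  # any non-empty string; the local server does not check it
)

stream = client.chat.completions.create(
    model="llama3",  # assumed model alias
    messages=[{"role": "user", "content": "Tell me a short joke."}],
    max_tokens=100,
    stream=True,
)

for chunk in stream:
    # Each chunk carries an incremental delta; the final chunk has finish_reason="stop".
    print(chunk.choices[0].delta.content or "", end="", flush=True)
print()
```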

server.py

Lines changed: 6 additions & 4 deletions
@@ -41,7 +41,7 @@ def _del_none(d: Union[Dict, List]) -> Union[Dict, List]:
         return [_del_none(v) for v in d if v]
     return d
 
-@app.route(f"/{OPENAI_API_VERSION}/chat", methods=["POST"])
+@app.route(f"/{OPENAI_API_VERSION}/chat/completions", methods=["POST"])
 def chat_endpoint():
     """
     Endpoint for the Chat API. This endpoint is used to generate a response to a user prompt.
@@ -63,7 +63,7 @@ def chat_endpoint():
     data = request.get_json()
     req = CompletionRequest(**data)
 
-    if data.get("stream") == "true":
+    if req.stream:
 
         def chunk_processor(chunked_completion_generator):
             """Inline function for postprocessing CompletionResponseChunk objects.
@@ -74,14 +74,16 @@ def chunk_processor(chunked_completion_generator):
             if (next_tok := chunk.choices[0].delta.content) is None:
                 next_tok = ""
             print(next_tok, end="", flush=True)
-            yield json.dumps(_del_none(asdict(chunk)))
+            yield f"data:{json.dumps(_del_none(asdict(chunk)))}\n\n"
 
-        return Response(
+        resp = Response(
             chunk_processor(gen.chunked_completion(req)),
             mimetype="text/event-stream",
         )
+        return resp
     else:
         response = gen.sync_completion(req)
+        print(response.choices[0].message.content)
 
     return json.dumps(_del_none(asdict(response)))
 
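With this change, the streaming branch emits each chunk as a server-sent-events style `data:{json}` line followed by a blank line. A rough sketch of reading that stream without the `openai` client (assuming `requests` and a local server on port 5000; note that `_del_none` drops keys whose values are `None` or empty, so the final chunk may have no `delta`):

```python
# Rough sketch, not part of this commit: read the raw event stream emitted by
# chunk_processor and print the incremental content.
import json

import requests

with requests.post(
    "http://127.0.0.1:5000/v1/chat/completions",
    json={
        "model": "llama3",  # assumed model alias
        "messages": [{"role": "user", "content": "Hello!"}],
        "stream": True,
    },
    stream=True,
) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data:"):
            continue  # skip blank separator lines
        chunk = json.loads(line[len("data:"):])
        delta = chunk["choices"][0].get("delta", {})
        print(delta.get("content", ""), end="", flush=True)
print()
```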
