Skip to content

Commit 41ccc69

Browse files
committed
Add models endpoint
1 parent 6401f55 commit 41ccc69

File tree

2 files changed

+91
-52
lines changed

2 files changed

+91
-52
lines changed

api/models.py

+40
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
from typing import Any, Dict, List, Optional, Union
2+
3+
4+
from dataclasses import dataclass
5+
6+
from download import is_model_downloaded, load_model_configs
7+
from pwd import getpwuid
8+
9+
import os
10+
import time
11+
12+
@dataclass
class ModelInfo:
    """Information about a model that can be used to generate completions."""

    # Model identifier exposed to API clients (populated from the model
    # config's ``name`` — see get_model_info_list).
    id: str
    # Creation time as a Unix timestamp in seconds (taken from the ctime of
    # the model's on-disk directory).
    created: int
    # Local username owning the model files on disk.
    owner: str
    # Fixed type tag, mirroring the OpenAI "model" object schema.
    object: str = "model"
19+
20+
21+
@dataclass
class ModelInfoResponse:
    """A list of models that can be used to generate completions."""

    # The ModelInfo entries, one per locally available model.
    data: List[ModelInfo]
    # Fixed type tag, mirroring the OpenAI "list" response schema.
    object: str = "list"
26+
27+
28+
def get_model_info_list(args) -> ModelInfoResponse:
    """Returns a list of models that can be used to generate completions.

    Args:
        args: Parsed server arguments; only ``args.model_directory`` (a
            ``pathlib.Path``-like directory) is read here.

    Returns:
        A ModelInfoResponse whose ``data`` contains one ModelInfo per model
        that is actually present under ``args.model_directory``.
    """
    data = []
    # Loop-invariant: the model directory does not change per model.
    model_dir = args.model_directory
    for model_id, model_config in load_model_configs().items():
        # Only advertise models that are downloaded; skipping here also
        # guarantees `created`/`owner` are never read before assignment.
        if not is_model_downloaded(model_id, model_dir):
            continue
        path = model_dir / model_id
        # Directory creation time stands in for the model's "created" stamp.
        created = int(os.path.getctime(path))
        # Resolve the owning user via the Unix password database.
        # NOTE(review): `pwd` is POSIX-only, so this endpoint is Unix-only.
        owner = getpwuid(os.stat(path).st_uid).pw_name
        data.append(ModelInfo(id=model_config.name, created=created, owner=owner))
    return ModelInfoResponse(data=data)

server.py

+51-52
Original file line numberDiff line numberDiff line change
@@ -10,78 +10,78 @@
1010
from typing import Dict, List, Union
1111

1212
from api.api import AssistantMessage, CompletionRequest, OpenAiApiGenerator, UserMessage
13+
from api.models import get_model_info_list, ModelInfoResponse
1314

1415
from build.builder import BuilderArgs, TokenizerArgs
1516
from flask import Flask, request, Response
1617
from generate import GeneratorArgs
18+
from download import load_model_configs, is_model_downloaded
1719

1820

19-
"""
20-
Creates a flask app that can be used to serve the model as a chat API.
21-
"""
22-
app = Flask(__name__)
23-
# Messages and gen are kept global so they can be accessed by the flask app endpoints.
24-
messages: list = []
25-
gen: OpenAiApiGenerator = None
21+
def create_app(args):
    """
    Creates a flask app that can be used to serve the model as a chat API.
    """
    app = Flask(__name__)

    # Built once per app; captured by the endpoint closures below.
    gen: OpenAiApiGenerator = initialize_generator(args)

    def _del_none(d: Union[Dict, List]) -> Union[Dict, List]:
        """Recursively delete None values from a dictionary."""
        if type(d) is dict:
            return {k: _del_none(v) for k, v in d.items() if v}
        elif type(d) is list:
            return [_del_none(v) for v in d if v]
        return d

    @app.route("/chat", methods=["POST"])
    def chat_endpoint():
        """
        Endpoint for the Chat API. This endpoint is used to generate a response to a user prompt.
        This endpoint emulates the behavior of the OpenAI Chat API. (https://platform.openai.com/docs/api-reference/chat)

        ** Warning ** : Not all arguments of the CompletionRequest are consumed.

        See https://github.com/pytorch/torchchat/issues/973 and the OpenAiApiGenerator class for more details.

        If stream is set to true, the response will be streamed back as a series of CompletionResponseChunk objects. Otherwise,
        a single CompletionResponse object will be returned.
        """

        print(" === Completion Request ===")

        # Parse the request in to a CompletionRequest object
        data = request.get_json()
        req = CompletionRequest(**data)

        # The OpenAI API sends `stream` as a JSON boolean; also accept the
        # legacy string "true" for backward compatibility.
        if data.get("stream") in (True, "true"):

            def chunk_processor(chunked_completion_generator):
                """Inline function for postprocessing CompletionResponseChunk objects.

                Here, we just jsonify the chunk and yield it as a string.
                """
                for chunk in chunked_completion_generator:
                    if (next_tok := chunk.choices[0].delta.content) is None:
                        next_tok = ""
                    print(next_tok, end="")
                    yield json.dumps(_del_none(asdict(chunk)))

            return Response(
                chunk_processor(gen.chunked_completion(req)), mimetype="text/event-stream"
            )
        else:
            response = gen.sync_completion(req)

            return json.dumps(_del_none(asdict(response)))

    @app.route("/models", methods=["GET"])
    def models_endpoint():
        """Returns the list of locally available models, OpenAI "list" style."""
        return json.dumps(asdict(get_model_info_list(args)))

    return app
8585

8686

8787
def initialize_generator(args) -> OpenAiApiGenerator:
@@ -103,6 +103,5 @@ def initialize_generator(args) -> OpenAiApiGenerator:
103103

104104

105105
def main(args):
    """Build the chat-API Flask app for the given args and start serving it."""
    create_app(args).run()

0 commit comments

Comments
 (0)