
Commit a386b20

Merge pull request #114 from TNG/llama-cpp-squashed
**Description:** This adds a new Python backend which uses llama-cpp-python as an inference backend for the Answer section. This allows users to run text generation using single-file GGUF models.

**Changes Made:**
* add llama.cpp backend
* add installation management for llama.cpp
* adjust build scripts
* adjust add model dialog

**Testing Done:** Tested locally on BMG.

**Checklist:**
- [x] I have tested the changes locally.
- [x] I have self-reviewed the code changes.
2 parents 2557c88 + ca365d1 commit a386b20
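For orientation, the backend added here wraps llama-cpp-python's `Llama` class. Below is a minimal sketch of the streaming chat-completion pattern it relies on, assuming a GGUF file already on disk; the model path and file name are placeholders, not something introduced by this PR.

```python
# Minimal sketch of the llama-cpp-python streaming pattern the new backend wraps.
# The model path below is a placeholder; any single-file GGUF model works.
from llama_cpp import Llama

llm = Llama(
    model_path="./models/llm/ggufLLM/example-model-q4_k_m.gguf",  # placeholder path
    n_gpu_layers=-1,   # offload all layers to the GPU where possible
    n_ctx=16000,       # context window, mirroring the backend's default
)

stream = llm.create_chat_completion(
    messages=[{"role": "user", "content": "Hello!"}],
    stream=True,
)
for chunk in stream:
    # Each chunk follows the OpenAI-style streaming delta format.
    print(chunk["choices"][0]["delta"].get("content", ""), end="", flush=True)
```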

41 files changed (+915, -70 lines)

.gitignore

Lines changed: 2 additions & 0 deletions
```diff
@@ -10,6 +10,8 @@ WebUI/external/service/
 *.7z
 *.whl
 ComfyUI/
+env
+llama-cpp-env/
 *env_tmp/
 *service_tmp/
 *-env/
```

LlamaCPP/.gitignore

Lines changed: 12 additions & 0 deletions
@@ -0,0 +1,12 @@
```
.vscode/
__pycache__/
models/llm/
temp/
test/
dist/
build/
cache/
test/
env/

!tools/*.exe
```

LlamaCPP/llama_adapter.py

Lines changed: 187 additions & 0 deletions
@@ -0,0 +1,187 @@
```python
import threading
from queue import Empty, Queue
import json
import traceback
from typing import Dict, List, Callable
# from model_downloader import NotEnoughDiskSpaceException, DownloadException
# from psutil._common import bytes2human
from llama_interface import LLMInterface
from llama_params import LLMParams


RAG_PROMPT_FORMAT = "Answer the questions based on the information below. \n{context}\n\nQuestion: {prompt}"


class LLM_SSE_Adapter:
    msg_queue: Queue
    finish: bool
    signal: threading.Event
    llm_interface: LLMInterface
    should_stop: bool

    def __init__(self, llm_interface: LLMInterface):
        self.msg_queue = Queue(-1)
        self.finish = False
        self.signal = threading.Event()
        self.llm_interface = llm_interface
        self.should_stop = False

    def put_msg(self, data):
        self.msg_queue.put_nowait(data)
        self.signal.set()

    def load_model_callback(self, event: str):
        data = {"type": "load_model", "event": event}
        self.put_msg(data)

    def text_in_callback(self, msg: str):
        data = {"type": "text_in", "value": msg}
        self.put_msg(data)

    def text_out_callback(self, msg: str, type=1):
        data = {"type": "text_out", "value": msg, "dtype": type}
        self.put_msg(data)

    def first_latency_callback(self, first_latency: str):
        data = {"type": "first_token_latency", "value": first_latency}
        self.put_msg(data)

    def after_latency_callback(self, after_latency: str):
        data = {"type": "after_token_latency", "value": after_latency}
        self.put_msg(data)

    def sr_latency_callback(self, sr_latency: str):
        data = {"type": "sr_latency", "value": sr_latency}
        self.put_msg(data)

    def error_callback(self, ex: Exception):
        if (
            isinstance(ex, NotImplementedError)
            and ex.__str__() == "Access to repositories lists is not implemented."
        ):
            self.put_msg(
                {
                    "type": "error",
                    "err_type": "repositories_not_found",
                }
            )
        # elif isinstance(ex, NotEnoughDiskSpaceException):
        #     self.put_msg(
        #         {
        #             "type": "error",
        #             "err_type": "not_enough_disk_space",
        #             "need": bytes2human(ex.requires_space),
        #             "free": bytes2human(ex.free_space),
        #         }
        #     )
        # elif isinstance(ex, DownloadException):
        #     self.put_msg({"type": "error", "err_type": "download_exception"})
        # # elif isinstance(ex, llm_biz.StopGenerateException):
        # #     pass
        elif isinstance(ex, RuntimeError):
            self.put_msg({"type": "error", "err_type": "runtime_error"})
        else:
            self.put_msg({"type": "error", "err_type": "unknow_exception"})
        print(f"exception:{str(ex)}")

    def text_conversation(self, params: LLMParams):
        thread = threading.Thread(
            target=self.text_conversation_run,
            args=[params],
        )
        thread.start()
        return self.generator()

    def stream_function(self, stream):
        for output in stream:
            if self.llm_interface.stop_generate:
                self.llm_interface.stop_generate = False
                break

            if self.llm_interface.get_backend_type() == "ipex_llm":
                # transformers style: the stream yields plain text chunks
                self.text_out_callback(output)
            else:
                # OpenAI style: the stream yields chat-completion chunk dicts
                self.text_out_callback(output["choices"][0]["delta"].get("content", ""))
        self.put_msg({"type": "finish"})

    def text_conversation_run(
        self,
        params: LLMParams,
    ):
        try:
            if not self.llm_interface._model:
                self.load_model_callback("start")
                self.llm_interface.load_model(params)
                self.load_model_callback("finish")

            prompt = params.prompt
            if params.enable_rag:
                last_prompt = prompt[-1]
                last_prompt["question"] = process_rag(
                    last_prompt.get("question"), params.device
                )

            full_prompt = convert_prompt(prompt)
            stream = self.llm_interface.create_chat_completion(full_prompt)
            self.stream_function(stream)

        except Exception as ex:
            traceback.print_exc()
            self.error_callback(ex)
        finally:
            self.finish = True
            self.signal.set()

    def generator(self):
        # Drain queued messages as SSE-style "data:<json>" frames (NUL-terminated),
        # blocking until new data arrives or generation finishes.
        while True:
            while not self.msg_queue.empty():
                try:
                    data = self.msg_queue.get_nowait()
                    msg = f"data:{json.dumps(data)}\0"
                    print(msg)
                    yield msg
                except Empty:
                    break
            if not self.finish:
                self.signal.clear()
                self.signal.wait()
            else:
                break


_default_prompt = {
    "role": "system",
    "content": "You are a helpful digital assistant. Please provide safe, ethical and accurate information to the user. Please keep the output text language the same as the user input.",
}


def convert_prompt(prompt: List[Dict[str, str]]):
    chat_history = [_default_prompt]
    prompt_len = len(prompt)
    i = 0
    while i < prompt_len:
        chat_history.append({"role": "user", "content": prompt[i].get("question")})
        if i < prompt_len - 1:
            chat_history.append(
                {"role": "assistant", "content": prompt[i].get("answer")}
            )
        i = i + 1
    return chat_history


def process_rag(
    prompt: str,
    device: str,
    text_out_callback: Callable[[str, int], None] = None,
):
    import rag

    rag.to(device)
    query_success, context, rag_source = rag.query(prompt)
    if query_success:
        print("rag query input\r\n{}output:\r\n{}".format(prompt, context))
        prompt = RAG_PROMPT_FORMAT.format(prompt=prompt, context=context)
        if text_out_callback is not None:
            text_out_callback(rag_source, 2)
    return prompt
```
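The diff does not show how this adapter is exposed over HTTP, but `llama_cpp_test.py` below posts to `/api/llm/chat`, so the service presumably streams the generator back to the client. The following is a purely hypothetical sketch of that wiring, assuming a Flask-style app; the framework, route function, and port handling here are illustrative assumptions, not part of this commit.

```python
# Hypothetical wiring sketch (not part of this diff): serving the adapter's
# generator as a streaming HTTP response with Flask. The endpoint path matches
# the one exercised by llama_cpp_test.py.
from flask import Flask, Response, request

from llama_adapter import LLM_SSE_Adapter
from llama_cpp_backend import LlamaCpp
from llama_params import LLMParams

app = Flask(__name__)
llm_backend = LlamaCpp()


@app.route("/api/llm/chat", methods=["POST"])
def chat():
    body = request.get_json()
    params = LLMParams(
        prompt=body["prompt"],
        device=body["device"],
        enable_rag=body["enable_rag"],
        model_repo_id=body["model_repo_id"],
    )
    adapter = LLM_SSE_Adapter(llm_backend)
    # The adapter starts generation on a worker thread and returns a generator.
    return Response(adapter.text_conversation(params), mimetype="text/event-stream")


if __name__ == "__main__":
    app.run(port=59003)
```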

LlamaCPP/llama_cpp_backend.py

Lines changed: 47 additions & 0 deletions
@@ -0,0 +1,47 @@
```python
from typing import Dict, List
from os import path
from llama_interface import LLMInterface
from llama_cpp import CreateChatCompletionStreamResponse, Iterator, Llama
from llama_params import LLMParams
import model_config
import gc


class LlamaCpp(LLMInterface):
    def __init__(self):
        self._model = None
        self.stop_generate = False
        self._last_repo_id = None

    def load_model(self, params: LLMParams, n_gpu_layers: int = -1, context_length: int = 16000):
        model_repo_id = params.model_repo_id
        if self._model is None or self._last_repo_id != model_repo_id:
            self.unload_model()

            model_base_path = model_config.llamaCppConfig.get("ggufLLM")
            namespace, repo, *model = model_repo_id.split("/")
            model_path = path.abspath(
                path.join(model_base_path, "---".join([namespace, repo]), "---".join(model))
            )

            self._model = Llama(
                model_path=model_path,
                n_gpu_layers=n_gpu_layers,
                n_ctx=context_length,
            )

            self._last_repo_id = model_repo_id

    def create_chat_completion(self, messages: List[Dict[str, str]]):
        completion: Iterator[CreateChatCompletionStreamResponse] = self._model.create_chat_completion(
            messages=messages,
            stream=True,
        )
        return completion

    def unload_model(self):
        if self._model is not None:
            self._model.close()
            del self._model
            gc.collect()
        self._model = None

    def get_backend_type(self):
        return "llama_cpp"
```

LlamaCPP/llama_cpp_test.py

Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
```python
import requests


url = "http://127.0.0.1:59003/api/llm/chat"
params = {
    "prompt": [{"question": "Who is the president of the United States in 5 years?"}],
    "device": "",
    "enable_rag": False,
    "model_repo_id": "meta-llama-3.1-8b-instruct-q5_k_m.gguf",
}
response = requests.post(url, json=params, stream=True)
# Check if the response status code is 200 (OK)
response.raise_for_status()
e = 1
# Iterate over the response lines
for line in response.iter_lines():
    e += 1
    if line:
        # Decode the line (assuming UTF-8 encoding)
        decoded_line = line.decode('utf-8')

        # SSE events typically start with "data: "
        if decoded_line.startswith("data:"):
            # Extract the data part
            data = decoded_line[len("data:"):]
            print(data)  # Process the data as needed
```

LlamaCPP/llama_interface.py

Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
```python
from abc import ABC, abstractmethod
from typing import Dict, List, Optional
from llama_params import LLMParams


class LLMInterface(ABC):
    stop_generate: bool
    _model: Optional[object]

    @abstractmethod
    def load_model(self, params: LLMParams, **kwargs):
        pass

    @abstractmethod
    def unload_model(self):
        pass

    @abstractmethod
    def create_chat_completion(self, messages: List[Dict[str, str]]):
        pass

    @abstractmethod
    def get_backend_type(self):
        pass
```
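Any other backend only has to satisfy these four methods plus the `stop_generate` and `_model` attributes. As a hypothetical sketch (not part of this commit), a trivial echo backend that streams OpenAI-style chunks compatible with `LLM_SSE_Adapter.stream_function` could look like this:

```python
# Hypothetical sketch: a trivial backend implementing LLMInterface, useful for
# exercising LLM_SSE_Adapter without loading any model. Not part of this diff.
from typing import Dict, List

from llama_interface import LLMInterface
from llama_params import LLMParams


class EchoBackend(LLMInterface):
    def __init__(self):
        self._model = None
        self.stop_generate = False

    def load_model(self, params: LLMParams, **kwargs):
        self._model = object()  # pretend something was loaded

    def unload_model(self):
        self._model = None

    def create_chat_completion(self, messages: List[Dict[str, str]]):
        # Yield OpenAI-style streaming chunks echoing the last user message,
        # which is the shape stream_function expects for non-ipex_llm backends.
        text = messages[-1]["content"] or ""
        for token in text.split():
            yield {"choices": [{"delta": {"content": token + " "}}]}

    def get_backend_type(self):
        return "echo"
```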

LlamaCPP/llama_params.py

Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
```python
from typing import Dict, List


class LLMParams:
    prompt: List[Dict[str, str]]
    device: int
    enable_rag: bool
    model_repo_id: str

    def __init__(
        self, prompt: list, device: int, enable_rag: bool, model_repo_id: str
    ) -> None:
        self.prompt = prompt
        self.device = device
        self.enable_rag = enable_rag
        self.model_repo_id = model_repo_id
```
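For reference, these fields mirror the JSON body that `llama_cpp_test.py` sends; an illustrative construction with the values copied from that test script:

```python
# Illustrative only: LLMParams built from the same payload llama_cpp_test.py
# posts to /api/llm/chat.
from llama_params import LLMParams

params = LLMParams(
    prompt=[{"question": "Who is the president of the United States in 5 years?"}],
    device="",            # note: annotated as int above, but the test sends a string
    enable_rag=False,
    model_repo_id="meta-llama-3.1-8b-instruct-q5_k_m.gguf",
)
```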
