import threading
from queue import Empty, Queue
import json
import traceback
from typing import Dict, List, Callable
# from model_downloader import NotEnoughDiskSpaceException, DownloadException
# from psutil._common import bytes2human
from llama_interface import LLMInterface
from llama_params import LLMParams


RAG_PROMPT_FORMAT = "Answer the questions based on the information below. \n{context}\n\nQuestion: {prompt}"

class LLM_SSE_Adapter:
    msg_queue: Queue
    finish: bool
    signal: threading.Event
    llm_interface: LLMInterface
    should_stop: bool

    def __init__(self, llm_interface: LLMInterface):
        self.msg_queue = Queue(-1)
        self.finish = False
        self.signal = threading.Event()
        self.llm_interface = llm_interface
        self.should_stop = False

    def put_msg(self, data):
        self.msg_queue.put_nowait(data)
        self.signal.set()

    def load_model_callback(self, event: str):
        data = {"type": "load_model", "event": event}
        self.put_msg(data)

    def text_in_callback(self, msg: str):
        data = {"type": "text_in", "value": msg}
        self.put_msg(data)

    def text_out_callback(self, msg: str, type=1):
        data = {"type": "text_out", "value": msg, "dtype": type}
        self.put_msg(data)

    def first_latency_callback(self, first_latency: str):
        data = {"type": "first_token_latency", "value": first_latency}
        self.put_msg(data)

    def after_latency_callback(self, after_latency: str):
        data = {"type": "after_token_latency", "value": after_latency}
        self.put_msg(data)

    def sr_latency_callback(self, sr_latency: str):
        data = {"type": "sr_latency", "value": sr_latency}
        self.put_msg(data)

    def error_callback(self, ex: Exception):
        if (
            isinstance(ex, NotImplementedError)
            and str(ex) == "Access to repositories lists is not implemented."
        ):
            self.put_msg(
                {
                    "type": "error",
                    "err_type": "repositories_not_found",
                }
            )
        # elif isinstance(ex, NotEnoughDiskSpaceException):
        #     self.put_msg(
        #         {
        #             "type": "error",
        #             "err_type": "not_enough_disk_space",
        #             "need": bytes2human(ex.requires_space),
        #             "free": bytes2human(ex.free_space),
        #         }
        #     )
        # elif isinstance(ex, DownloadException):
        #     self.put_msg({"type": "error", "err_type": "download_exception"})
        # # elif isinstance(ex, llm_biz.StopGenerateException):
        # #     pass
        elif isinstance(ex, RuntimeError):
            self.put_msg({"type": "error", "err_type": "runtime_error"})
        else:
            self.put_msg({"type": "error", "err_type": "unknow_exception"})
        print(f"exception: {str(ex)}")

    def text_conversation(self, params: LLMParams):
        thread = threading.Thread(
            target=self.text_conversation_run,
            args=[params],
        )
        thread.start()
        return self.generator()

    def stream_function(self, stream):
        for output in stream:
            if self.llm_interface.stop_generate:
                self.llm_interface.stop_generate = False
                break

            if self.llm_interface.get_backend_type() == "ipex_llm":
                # transformer style: each item is the generated text itself
                self.text_out_callback(output)
            else:
                # openai style: pull the delta content out of each chunk
                self.text_out_callback(output["choices"][0]["delta"].get("content", ""))
        self.put_msg({"type": "finish"})
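
    # For reference (an assumption based on the OpenAI chat-completions
    # streaming format, not defined in this module), each chunk in the
    # "openai style" branch above looks roughly like:
    #   {"choices": [{"index": 0, "delta": {"content": "..."}, "finish_reason": None}]}
    # and only choices[0]["delta"]["content"] is consumed.
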
    def text_conversation_run(
        self,
        params: LLMParams,
    ):
        try:
            if not self.llm_interface._model:
                self.load_model_callback('start')
                self.llm_interface.load_model(params)
                self.load_model_callback('finish')

            prompt = params.prompt
            if params.enable_rag:
                last_prompt = prompt[-1]
                last_prompt["question"] = process_rag(
                    last_prompt.get("question"), params.device
                )

            full_prompt = convert_prompt(prompt)
            stream = self.llm_interface.create_chat_completion(full_prompt)
            self.stream_function(stream)

        except Exception as ex:
            traceback.print_exc()
            self.error_callback(ex)
        finally:
            self.finish = True
            self.signal.set()

    def generator(self):
        while True:
            while not self.msg_queue.empty():
                try:
                    data = self.msg_queue.get_nowait()
                    msg = f"data:{json.dumps(data)}\0"
                    print(msg)
                    yield msg
                except Empty:
                    break
            if not self.finish:
                self.signal.clear()
                self.signal.wait()
            else:
                break

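# A minimal usage sketch, kept as a comment. It assumes a Flask app and a
# "/api/llm/chat" route; Flask, the route path, and the mimetype are
# illustrative, not part of this module:
#
#   from flask import Flask, Response, request
#
#   app = Flask(__name__)
#
#   @app.post("/api/llm/chat")
#   def chat():
#       params = LLMParams(**request.get_json())
#       adapter = LLM_SSE_Adapter(llm_interface)  # llm_interface: an LLMInterface
#       # text_conversation() starts the worker thread and returns the generator
#       # of "data:{json}\0"-framed messages to stream back to the client.
#       return Response(adapter.text_conversation(params), mimetype="text/event-stream")
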
_default_prompt = {
    "role": "system",
    "content": "You are a helpful digital assistant. Please provide safe, ethical and accurate information to the user. Please keep the output text language the same as the user input.",
}

def convert_prompt(prompt: List[Dict[str, str]]):
    chat_history = [_default_prompt]
    prompt_len = len(prompt)
    for i, item in enumerate(prompt):
        chat_history.append({"role": "user", "content": item.get("question")})
        # Every turn except the last has an answer; the last question is the
        # one a reply is being generated for.
        if i < prompt_len - 1:
            chat_history.append({"role": "assistant", "content": item.get("answer")})
    return chat_history

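# Illustrative example (values made up): a two-turn history such as
#   [{"question": "Hi", "answer": "Hello!"}, {"question": "What is SSE?"}]
# is converted into OpenAI-style chat messages:
#   [_default_prompt,
#    {"role": "user", "content": "Hi"},
#    {"role": "assistant", "content": "Hello!"},
#    {"role": "user", "content": "What is SSE?"}]
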
def process_rag(
    prompt: str,
    device: str,
    text_out_callback: Callable[[str, int], None] = None,
):
    import rag  # imported lazily so the RAG stack is only loaded when needed
    rag.to(device)
    query_success, context, rag_source = rag.query(prompt)
    if query_success:
        print("rag query input\r\n{}output:\r\n{}".format(prompt, context))
        prompt = RAG_PROMPT_FORMAT.format(prompt=prompt, context=context)
        if text_out_callback is not None:
            text_out_callback(rag_source, 2)
    return prompt