     cast,
 )
 
-from pydantic import ValidationError
-
 from neo4j_graphrag.message_history import MessageHistory
 from neo4j_graphrag.types import LLMMessage
 
 from ..exceptions import LLMGenerationError
 from .base import LLMInterface
-from .rate_limit import RateLimitHandler, rate_limit_handler, async_rate_limit_handler
 from .types import (
-    BaseMessage,
     LLMResponse,
-    MessageList,
     ToolCall,
     ToolCallResponse,
-    SystemMessage,
-    UserMessage,
 )
 
 from neo4j_graphrag.tool import Tool
 
 if TYPE_CHECKING:
     from openai.types.chat import (
         ChatCompletionMessageParam,
-        ChatCompletionToolParam,
-    )
+        ChatCompletionToolParam, ChatCompletionUserMessageParam,
+        ChatCompletionSystemMessageParam, ChatCompletionAssistantMessageParam,
+    )
     from openai import OpenAI, AsyncOpenAI
+    from .rate_limit import RateLimitHandler
 else:
     ChatCompletionMessageParam = Any
     ChatCompletionToolParam = Any
     OpenAI = Any
     AsyncOpenAI = Any
+    RateLimitHandler = Any
 
 
 class BaseOpenAILLM(LLMInterface, abc.ABC):
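
The hunk above moves the `RateLimitHandler` import under the `TYPE_CHECKING` guard and adds a runtime `Any` fallback, the same trick already used for the OpenAI types: the type checker sees the real class, while at runtime the module is never imported. A minimal sketch of the pattern, with a hypothetical `heavy_module` standing in for the real dependency:

from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    # Seen only by static type checkers; never executed at runtime, so a
    # heavy, optional, or cyclic dependency costs nothing on import.
    from heavy_module import HeavyType  # hypothetical module
else:
    # Runtime fallback: the name still resolves in annotations that are
    # evaluated at runtime, but it is just Any.
    HeavyType = Any

def handle(obj: "HeavyType") -> None:
    print(type(obj).__name__)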
@@ -93,23 +89,26 @@ def __init__(
 
     def get_messages(
         self,
-        input: str,
-        message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
-        system_instruction: Optional[str] = None,
+        messages: list[LLMMessage],
     ) -> Iterable[ChatCompletionMessageParam]:
-        messages = []
-        if system_instruction:
-            messages.append(SystemMessage(content=system_instruction).model_dump())
-        if message_history:
-            if isinstance(message_history, MessageHistory):
-                message_history = message_history.messages
-            try:
-                MessageList(messages=cast(list[BaseMessage], message_history))
-            except ValidationError as e:
-                raise LLMGenerationError(e.errors()) from e
-            messages.extend(cast(Iterable[dict[str, Any]], message_history))
-        messages.append(UserMessage(content=input).model_dump())
-        return messages  # type: ignore
+        chat_messages = []
+        for m in messages:
+            message_type: ChatCompletionMessageParam
+            if m["role"] == "system":
+                message_type = ChatCompletionSystemMessageParam
+            elif m["role"] == "user":
+                message_type = ChatCompletionUserMessageParam
+            elif m["role"] == "assistant":
+                message_type = ChatCompletionAssistantMessageParam
+            else:
+                raise ValueError(f"Unknown message type: {m['role']}")
+            chat_messages.append(
+                message_type(
+                    role=m["role"],
+                    content=m["content"],
+                )
+            )
+        return chat_messages
 
     def _convert_tool_to_openai_format(self, tool: Tool) -> Dict[str, Any]:
         """Convert a Tool object to OpenAI's expected format.
@@ -132,21 +131,15 @@ def _convert_tool_to_openai_format(self, tool: Tool) -> Dict[str, Any]:
         except AttributeError:
             raise LLMGenerationError(f"Tool {tool} is not a valid Tool object")
 
-    @rate_limit_handler
-    def invoke(
+    def _invoke(
         self,
-        input: str,
-        message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
-        system_instruction: Optional[str] = None,
+        input: list[LLMMessage],
     ) -> LLMResponse:
         """Sends a text input to the OpenAI chat completion model
         and returns the response's content.
 
         Args:
             input (str): Text sent to the LLM.
-            message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection previous messages,
-                with each message having a specific role assigned.
-            system_instruction (Optional[str]): An option to override the llm system message for this invocation.
 
         Returns:
             LLMResponse: The response from OpenAI.
@@ -155,10 +148,8 @@ def invoke(
             LLMGenerationError: If anything goes wrong.
         """
         try:
-            if isinstance(message_history, MessageHistory):
-                message_history = message_history.messages
             response = self.client.chat.completions.create(
-                messages=self.get_messages(input, message_history, system_instruction),
+                messages=self.get_messages(input),
                 model=self.model_name,
                 **self.model_params,
             )
@@ -167,7 +158,6 @@ def invoke(
         except self.openai.OpenAIError as e:
             raise LLMGenerationError(e)
 
-    @rate_limit_handler
     def invoke_with_tools(
         self,
         input: str,
@@ -242,21 +232,15 @@ def invoke_with_tools(
         except self.openai.OpenAIError as e:
             raise LLMGenerationError(e)
 
-    @async_rate_limit_handler
-    async def ainvoke(
+    async def _ainvoke(
         self,
-        input: str,
-        message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
-        system_instruction: Optional[str] = None,
+        input: list[LLMMessage],
     ) -> LLMResponse:
         """Asynchronously sends a text input to the OpenAI chat
         completion model and returns the response's content.
 
         Args:
             input (str): Text sent to the LLM.
-            message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection previous messages,
-                with each message having a specific role assigned.
-            system_instruction (Optional[str]): An option to override the llm system message for this invocation.
 
         Returns:
             LLMResponse: The response from OpenAI.
@@ -265,10 +249,8 @@ async def ainvoke(
             LLMGenerationError: If anything goes wrong.
         """
         try:
-            if isinstance(message_history, MessageHistory):
-                message_history = message_history.messages
             response = await self.async_client.chat.completions.create(
-                messages=self.get_messages(input, message_history, system_instruction),
+                messages=self.get_messages(input),
                 model=self.model_name,
                 **self.model_params,
             )
@@ -277,7 +259,6 @@ async def ainvoke(
         except self.openai.OpenAIError as e:
             raise LLMGenerationError(e)
 
-    @async_rate_limit_handler
     async def ainvoke_with_tools(
         self,
         input: str,