Skip to content

Commit 7b69305

Browse files
committed
Update LLMInterface to restore LC compatibility
1 parent 2d7abb8 commit 7b69305

File tree

4 files changed

+82
-61
lines changed

4 files changed

+82
-61
lines changed
Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,28 @@
11
from neo4j_graphrag.llm import LLMResponse, OpenAILLM
2+
from neo4j_graphrag.message_history import InMemoryMessageHistory
3+
from neo4j_graphrag.types import LLMMessage
24

35
# set api key here or in the OPENAI_API_KEY env var
46
api_key = None
57

8+
messages: list[LLMMessage] = [
9+
{
10+
"role": "system",
11+
"content": "You are a seasoned actor and expert performer, renowned for your one-man shows and comedic talent.",
12+
},
13+
{
14+
"role": "user",
15+
"content": "say something",
16+
},
17+
]
18+
19+
620
llm = OpenAILLM(model_name="gpt-4o", api_key=api_key)
7-
res: LLMResponse = llm.invoke("say something")
21+
res: LLMResponse = llm.invoke(
22+
# "say something",
23+
# messages,
24+
InMemoryMessageHistory(
25+
messages=messages,
26+
)
27+
)
828
print(res.content)

src/neo4j_graphrag/llm/base.py

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,18 @@
1919

2020
from neo4j_graphrag.message_history import MessageHistory
2121
from neo4j_graphrag.types import LLMMessage
22+
from .rate_limit import rate_limit_handler
2223

2324
from .types import LLMResponse, ToolCallResponse
2425
from .rate_limit import (
2526
DEFAULT_RATE_LIMIT_HANDLER,
27+
async_rate_limit_handler,
2628
)
2729

2830
from neo4j_graphrag.tool import Tool
2931

3032
from .rate_limit import RateLimitHandler
33+
from .utils import legacy_inputs_to_message_history
3134

3235

3336
class LLMInterface(ABC):
@@ -55,20 +58,27 @@ def __init__(
5558
else:
5659
self._rate_limit_handler = DEFAULT_RATE_LIMIT_HANDLER
5760

58-
@abstractmethod
61+
@rate_limit_handler
5962
def invoke(
6063
self,
61-
input: str,
64+
input: Union[str, List[LLMMessage], MessageHistory],
6265
message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
6366
system_instruction: Optional[str] = None,
67+
) -> LLMResponse:
68+
message_history = legacy_inputs_to_message_history(
69+
input, message_history, system_instruction
70+
)
71+
return self._invoke(message_history.messages)
72+
73+
@abstractmethod
74+
def _invoke(
75+
self,
76+
input: list[LLMMessage],
6477
) -> LLMResponse:
6578
"""Sends a text input to the LLM and retrieves a response.
6679
6780
Args:
68-
input (str): Text sent to the LLM.
69-
message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection previous messages,
70-
with each message having a specific role assigned.
71-
system_instruction (Optional[str]): An option to override the llm system message for this invocation.
81+
input (MessageHistory): Text sent to the LLM.
7282
7383
Returns:
7484
LLMResponse: The response from the LLM.
@@ -77,20 +87,27 @@ def invoke(
7787
LLMGenerationError: If anything goes wrong.
7888
"""
7989

80-
@abstractmethod
90+
@async_rate_limit_handler
8191
async def ainvoke(
8292
self,
83-
input: str,
93+
input: Union[str, List[LLMMessage], MessageHistory],
8494
message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
8595
system_instruction: Optional[str] = None,
96+
) -> LLMResponse:
97+
message_history = legacy_inputs_to_message_history(
98+
input, message_history, system_instruction
99+
)
100+
return await self._ainvoke(message_history.messages)
101+
102+
@abstractmethod
103+
async def _ainvoke(
104+
self,
105+
input: list[LLMMessage],
86106
) -> LLMResponse:
87107
"""Asynchronously sends a text input to the LLM and retrieves a response.
88108
89109
Args:
90110
input (str): Text sent to the LLM.
91-
message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection previous messages,
92-
with each message having a specific role assigned.
93-
system_instruction (Optional[str]): An option to override the llm system message for this invocation.
94111
95112
Returns:
96113
LLMResponse: The response from the LLM.

src/neo4j_graphrag/llm/openai_llm.py

Lines changed: 30 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -28,37 +28,33 @@
2828
cast,
2929
)
3030

31-
from pydantic import ValidationError
32-
3331
from neo4j_graphrag.message_history import MessageHistory
3432
from neo4j_graphrag.types import LLMMessage
3533

3634
from ..exceptions import LLMGenerationError
3735
from .base import LLMInterface
38-
from .rate_limit import RateLimitHandler, rate_limit_handler, async_rate_limit_handler
3936
from .types import (
40-
BaseMessage,
4137
LLMResponse,
42-
MessageList,
4338
ToolCall,
4439
ToolCallResponse,
45-
SystemMessage,
46-
UserMessage,
4740
)
4841

4942
from neo4j_graphrag.tool import Tool
5043

5144
if TYPE_CHECKING:
5245
from openai.types.chat import (
5346
ChatCompletionMessageParam,
54-
ChatCompletionToolParam,
55-
)
47+
ChatCompletionToolParam, ChatCompletionUserMessageParam,
48+
ChatCompletionSystemMessageParam, ChatCompletionAssistantMessageParam,
49+
)
5650
from openai import OpenAI, AsyncOpenAI
51+
from .rate_limit import RateLimitHandler
5752
else:
5853
ChatCompletionMessageParam = Any
5954
ChatCompletionToolParam = Any
6055
OpenAI = Any
6156
AsyncOpenAI = Any
57+
RateLimitHandler = Any
6258

6359

6460
class BaseOpenAILLM(LLMInterface, abc.ABC):
@@ -93,23 +89,26 @@ def __init__(
9389

9490
def get_messages(
9591
self,
96-
input: str,
97-
message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
98-
system_instruction: Optional[str] = None,
92+
messages: list[LLMMessage],
9993
) -> Iterable[ChatCompletionMessageParam]:
100-
messages = []
101-
if system_instruction:
102-
messages.append(SystemMessage(content=system_instruction).model_dump())
103-
if message_history:
104-
if isinstance(message_history, MessageHistory):
105-
message_history = message_history.messages
106-
try:
107-
MessageList(messages=cast(list[BaseMessage], message_history))
108-
except ValidationError as e:
109-
raise LLMGenerationError(e.errors()) from e
110-
messages.extend(cast(Iterable[dict[str, Any]], message_history))
111-
messages.append(UserMessage(content=input).model_dump())
112-
return messages # type: ignore
94+
chat_messages = []
95+
for m in messages:
96+
message_type: ChatCompletionMessageParam
97+
if m["role"] == "system":
98+
message_type = ChatCompletionSystemMessageParam
99+
elif m["role"] == "user":
100+
message_type = ChatCompletionUserMessageParam
101+
elif m["role"] == "assistant":
102+
message_type = ChatCompletionAssistantMessageParam
103+
else:
104+
raise ValueError(f"Unknown message type: {m['role']}")
105+
chat_messages.append(
106+
message_type(
107+
role=m["role"],
108+
content=m["content"],
109+
)
110+
)
111+
return chat_messages
113112

114113
def _convert_tool_to_openai_format(self, tool: Tool) -> Dict[str, Any]:
115114
"""Convert a Tool object to OpenAI's expected format.
@@ -132,21 +131,15 @@ def _convert_tool_to_openai_format(self, tool: Tool) -> Dict[str, Any]:
132131
except AttributeError:
133132
raise LLMGenerationError(f"Tool {tool} is not a valid Tool object")
134133

135-
@rate_limit_handler
136-
def invoke(
134+
def _invoke(
137135
self,
138-
input: str,
139-
message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
140-
system_instruction: Optional[str] = None,
136+
input: list[LLMMessage],
141137
) -> LLMResponse:
142138
"""Sends a text input to the OpenAI chat completion model
143139
and returns the response's content.
144140
145141
Args:
146142
input (str): Text sent to the LLM.
147-
message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection previous messages,
148-
with each message having a specific role assigned.
149-
system_instruction (Optional[str]): An option to override the llm system message for this invocation.
150143
151144
Returns:
152145
LLMResponse: The response from OpenAI.
@@ -155,10 +148,8 @@ def invoke(
155148
LLMGenerationError: If anything goes wrong.
156149
"""
157150
try:
158-
if isinstance(message_history, MessageHistory):
159-
message_history = message_history.messages
160151
response = self.client.chat.completions.create(
161-
messages=self.get_messages(input, message_history, system_instruction),
152+
messages=self.get_messages(input),
162153
model=self.model_name,
163154
**self.model_params,
164155
)
@@ -167,7 +158,6 @@ def invoke(
167158
except self.openai.OpenAIError as e:
168159
raise LLMGenerationError(e)
169160

170-
@rate_limit_handler
171161
def invoke_with_tools(
172162
self,
173163
input: str,
@@ -242,21 +232,15 @@ def invoke_with_tools(
242232
except self.openai.OpenAIError as e:
243233
raise LLMGenerationError(e)
244234

245-
@async_rate_limit_handler
246-
async def ainvoke(
235+
async def _ainvoke(
247236
self,
248-
input: str,
249-
message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
250-
system_instruction: Optional[str] = None,
237+
input: list[LLMMessage],
251238
) -> LLMResponse:
252239
"""Asynchronously sends a text input to the OpenAI chat
253240
completion model and returns the response's content.
254241
255242
Args:
256243
input (str): Text sent to the LLM.
257-
message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection previous messages,
258-
with each message having a specific role assigned.
259-
system_instruction (Optional[str]): An option to override the llm system message for this invocation.
260244
261245
Returns:
262246
LLMResponse: The response from OpenAI.
@@ -265,10 +249,8 @@ async def ainvoke(
265249
LLMGenerationError: If anything goes wrong.
266250
"""
267251
try:
268-
if isinstance(message_history, MessageHistory):
269-
message_history = message_history.messages
270252
response = await self.async_client.chat.completions.create(
271-
messages=self.get_messages(input, message_history, system_instruction),
253+
messages=self.get_messages(input),
272254
model=self.model_name,
273255
**self.model_params,
274256
)
@@ -277,7 +259,6 @@ async def ainvoke(
277259
except self.openai.OpenAIError as e:
278260
raise LLMGenerationError(e)
279261

280-
@async_rate_limit_handler
281262
async def ainvoke_with_tools(
282263
self,
283264
input: str,

src/neo4j_graphrag/message_history.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,9 @@ class MessageHistory(ABC):
7474
@abstractmethod
7575
def messages(self) -> List[LLMMessage]: ...
7676

77+
def is_empty(self) -> bool:
78+
return len(self.messages) == 0
79+
7780
@abstractmethod
7881
def add_message(self, message: LLMMessage) -> None: ...
7982

0 commit comments

Comments
 (0)