4 changes: 2 additions & 2 deletions examples/README.md
@@ -69,7 +69,7 @@ are listed in [the last section of this file](#customize).
- [OpenAI (GPT)](./customize/llms/openai_llm.py)
- [Azure OpenAI]()
- [VertexAI (Gemini)](./customize/llms/vertexai_llm.py)
- [MistralAI](./customize/llms/mistalai_llm.py)
- [MistralAI](customize/llms/mistralai_llm.py)
- [Cohere](./customize/llms/cohere_llm.py)
- [Anthropic (Claude)](./customize/llms/anthropic_llm.py)
- [Ollama](./customize/llms/ollama_llm.py)
@@ -142,7 +142,7 @@ are listed in [the last section of this file](#customize).

### Answer: GraphRAG

- [LangChain compatibility](./customize/answer/langchain_compatiblity.py)
- [LangChain compatibility](customize/answer/langchain_compatibility.py)
- [Use a custom prompt](./customize/answer/custom_prompt.py)


18 changes: 17 additions & 1 deletion examples/customize/llms/anthropic_llm.py
@@ -1,12 +1,28 @@
from neo4j_graphrag.llm import AnthropicLLM, LLMResponse
from neo4j_graphrag.types import LLMMessage

# set api key here or in the ANTHROPIC_API_KEY env var
api_key = None

messages: list[LLMMessage] = [
{
"role": "system",
"content": "You are a seasoned actor and expert performer, renowned for your one-man shows and comedic talent.",
},
{
"role": "user",
"content": "say something",
},
]


llm = AnthropicLLM(
model_name="claude-3-opus-20240229",
model_params={"max_tokens": 1000}, # max_tokens must be specified
api_key=api_key,
)
res: LLMResponse = llm.invoke("say something")
res: LLMResponse = llm.invoke(
# "say something",
messages,
)
print(res.content)
14 changes: 13 additions & 1 deletion examples/customize/llms/cohere_llm.py
@@ -1,11 +1,23 @@
from neo4j_graphrag.llm import CohereLLM, LLMResponse
from neo4j_graphrag.types import LLMMessage

# set api key here or in the CO_API_KEY env var
api_key = None

messages: list[LLMMessage] = [
{
"role": "system",
"content": "You are a seasoned actor and expert performer, renowned for your one-man shows and comedic talent.",
},
{
"role": "user",
"content": "say something",
},
]

llm = CohereLLM(
model_name="command-r",
api_key=api_key,
)
res: LLMResponse = llm.invoke("say something")
res: LLMResponse = llm.invoke(input=messages)
print(res.content)
25 changes: 6 additions & 19 deletions examples/customize/llms/custom_llm.py
@@ -1,14 +1,13 @@
import random
import string
from typing import Any, Awaitable, Callable, List, Optional, TypeVar, Union
from typing import Any, Awaitable, Callable, Optional, TypeVar

from neo4j_graphrag.llm import LLMInterface, LLMResponse
from neo4j_graphrag.llm.rate_limit import (
RateLimitHandler,
# rate_limit_handler,
# async_rate_limit_handler,
)
from neo4j_graphrag.message_history import MessageHistory
from neo4j_graphrag.types import LLMMessage


@@ -18,38 +17,26 @@ def __init__(
):
super().__init__(model_name, **kwargs)

# Optional: Apply rate limit handling to synchronous invoke method
# @rate_limit_handler
def invoke(
def _invoke(
self,
input: str,
message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
system_instruction: Optional[str] = None,
input: list[LLMMessage],
) -> LLMResponse:
content: str = (
self.model_name + ": " + "".join(random.choices(string.ascii_letters, k=30))
)
return LLMResponse(content=content)

# Optional: Apply rate limit handling to asynchronous ainvoke method
# @async_rate_limit_handler
async def ainvoke(
async def _ainvoke(
self,
input: str,
message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
system_instruction: Optional[str] = None,
input: list[LLMMessage],
) -> LLMResponse:
raise NotImplementedError()


llm = CustomLLM(
""
) # if rate_limit_handler and async_rate_limit_handler decorators are used, the default rate limit handler will be applied automatically (retry with exponential backoff)
llm = CustomLLM("")
res: LLMResponse = llm.invoke("text")
print(res.content)

# If rate_limit_handler and async_rate_limit_handler decorators are used and you want to use a custom rate limit handler
# Type variables for function signatures used in rate limit handlers
Contributor comment: the above comments on how to customise a rate limit handler are still valid, no?
(A sketch of one possible customisation follows this file's diff.)

F = TypeVar("F", bound=Callable[..., Any])
AF = TypeVar("AF", bound=Callable[..., Awaitable[Any]])

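In answer to the review comment above: a minimal sketch of how a custom rate limit handler could still be wired in now that the rate_limit_handler / async_rate_limit_handler decorators are no longer exported. It assumes RateLimitHandler keeps handle_sync/handle_async hooks (consistent with the type variables kept in this example) and that the CustomLLM constructor still forwards a rate_limit_handler keyword to LLMInterface; neither is confirmed by this diff.

from typing import Any, Awaitable, Callable, TypeVar

from neo4j_graphrag.llm.rate_limit import RateLimitHandler

F = TypeVar("F", bound=Callable[..., Any])
AF = TypeVar("AF", bound=Callable[..., Awaitable[Any]])


class PassthroughRateLimitHandler(RateLimitHandler):
    """Hypothetical handler that applies no throttling or retries."""

    def handle_sync(self, func: F) -> F:
        # return the wrapped callable unchanged (no retry/backoff)
        return func

    def handle_async(self, func: AF) -> AF:
        return func


# Assumed wiring, not confirmed by this diff:
# llm = CustomLLM("", rate_limit_handler=PassthroughRateLimitHandler())
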
10 changes: 0 additions & 10 deletions examples/customize/llms/mistalai_llm.py

This file was deleted.

32 changes: 32 additions & 0 deletions examples/customize/llms/mistralai_llm.py
@@ -0,0 +1,32 @@
from neo4j_graphrag.llm import MistralAILLM, LLMResponse
from neo4j_graphrag.message_history import InMemoryMessageHistory
from neo4j_graphrag.types import LLMMessage

# set api key here or in the MISTRAL_API_KEY env var
api_key = None


messages: list[LLMMessage] = [
{
"role": "system",
"content": "You are a seasoned actor and expert performer, renowned for your one-man shows and comedic talent.",
},
{
"role": "user",
"content": "say something",
},
]


llm = MistralAILLM(
model_name="mistral-small-latest",
api_key=api_key,
)
res: LLMResponse = llm.invoke(
# "say something",
# messages,
InMemoryMessageHistory(
messages=messages,
)
)
print(res.content)
19 changes: 17 additions & 2 deletions examples/customize/llms/ollama_llm.py
@@ -3,11 +3,26 @@
"""

from neo4j_graphrag.llm import LLMResponse, OllamaLLM
from neo4j_graphrag.types import LLMMessage

messages: list[LLMMessage] = [
{
"role": "system",
"content": "You are a seasoned actor and expert performer, renowned for your one-man shows and comedic talent.",
},
{
"role": "user",
"content": "say something",
},
]


llm = OllamaLLM(
model_name="<model_name>",
model_name="orca-mini:latest",
# model_params={"options": {"temperature": 0}, "format": "json"},
# host="...", # if using a remote server
)
res: LLMResponse = llm.invoke("What is the additive color model?")
res: LLMResponse = llm.invoke(
messages,
)
print(res.content)
22 changes: 21 additions & 1 deletion examples/customize/llms/openai_llm.py
@@ -1,8 +1,28 @@
from neo4j_graphrag.llm import LLMResponse, OpenAILLM
from neo4j_graphrag.message_history import InMemoryMessageHistory
from neo4j_graphrag.types import LLMMessage

# set api key here or in the OPENAI_API_KEY env var
api_key = None

messages: list[LLMMessage] = [
{
"role": "system",
"content": "You are a seasoned actor and expert performer, renowned for your one-man shows and comedic talent.",
},
{
"role": "user",
"content": "say something",
},
]


llm = OpenAILLM(model_name="gpt-4o", api_key=api_key)
res: LLMResponse = llm.invoke("say something")
res: LLMResponse = llm.invoke(
# "say something",
# messages,
InMemoryMessageHistory(
messages=messages,
)
)
print(res.content)
17 changes: 15 additions & 2 deletions examples/customize/llms/vertexai_llm.py
@@ -1,6 +1,20 @@
from neo4j_graphrag.llm import LLMResponse, VertexAILLM
from vertexai.generative_models import GenerationConfig

from neo4j_graphrag.types import LLMMessage

messages: list[LLMMessage] = [
{
"role": "system",
"content": "You are a seasoned actor and expert performer, renowned for your one-man shows and comedic talent.",
},
{
"role": "user",
"content": "say something",
},
]


generation_config = GenerationConfig(temperature=1.0)
llm = VertexAILLM(
model_name="gemini-2.0-flash-001",
Expand All @@ -9,7 +23,6 @@
# vertexai.generative_models.GenerativeModel client
)
res: LLMResponse = llm.invoke(
"say something",
system_instruction="You are living in 3000 where AI rules the world",
input=messages,
)
print(res.content)
19 changes: 14 additions & 5 deletions src/neo4j_graphrag/generation/graphrag.py
@@ -27,6 +27,7 @@
from neo4j_graphrag.generation.prompts import RagTemplate
from neo4j_graphrag.generation.types import RagInitModel, RagResultModel, RagSearchModel
from neo4j_graphrag.llm import LLMInterface
from neo4j_graphrag.llm.utils import legacy_inputs_to_messages
from neo4j_graphrag.message_history import MessageHistory
from neo4j_graphrag.retrievers.base import Retriever
from neo4j_graphrag.types import LLMMessage, RetrieverResult
@@ -145,12 +146,17 @@ def search(
prompt = self.prompt_template.format(
query_text=query_text, context=context, examples=validated_data.examples
)

messages = legacy_inputs_to_messages(
prompt,
message_history=message_history,
system_instruction=self.prompt_template.system_instructions,
)

logger.debug(f"RAG: retriever_result={prettify(retriever_result)}")
logger.debug(f"RAG: prompt={prompt}")
llm_response = self.llm.invoke(
prompt,
message_history,
system_instruction=self.prompt_template.system_instructions,
messages,
)
answer = llm_response.content
result: dict[str, Any] = {"answer": answer}
@@ -168,9 +174,12 @@ def _build_query(
summarization_prompt = self._chat_summary_prompt(
message_history=message_history
)
summary = self.llm.invoke(
input=summarization_prompt,
messages = legacy_inputs_to_messages(
summarization_prompt,
system_instruction=summary_system_message,
)
summary = self.llm.invoke(
messages,
).content
return self.conversation_prompt(summary=summary, current_query=query_text)
return query_text
4 changes: 0 additions & 4 deletions src/neo4j_graphrag/llm/__init__.py
@@ -22,8 +22,6 @@
RateLimitHandler,
NoOpRateLimitHandler,
RetryRateLimitHandler,
rate_limit_handler,
async_rate_limit_handler,
)
from .types import LLMResponse
from .vertexai_llm import VertexAILLM
@@ -42,6 +40,4 @@
"RateLimitHandler",
"NoOpRateLimitHandler",
"RetryRateLimitHandler",
"rate_limit_handler",
"async_rate_limit_handler",
]