Speed up add episode #77

Merged · 10 commits · Sep 3, 2024
5 changes: 3 additions & 2 deletions examples/podcast/podcast_runner.py
@@ -70,6 +70,7 @@ async def main(use_bulk: bool = True):
                reference_time=message.actual_timestamp,
                source_description='Podcast Transcript',
            )
+        return

    episodes: list[RawEpisode] = [
        RawEpisode(
@@ -79,10 +80,10 @@ async def main(use_bulk: bool = True):
            source_description='Podcast Transcript',
            reference_time=message.actual_timestamp,
        )
-        for i, message in enumerate(messages[3:14])
+        for i, message in enumerate(messages[3:20])
    ]

    await client.add_episode_bulk(episodes)


-asyncio.run(main(True))
+asyncio.run(main(False))
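
A note on what the runner change does: with `asyncio.run(main(False))`, the script now exercises the non-bulk path, adding each transcript message as its own episode and returning before the bulk branch is reached. A minimal sketch of that control flow (the Message shape and the loop bounds are illustrative assumptions; only add_episode, add_episode_bulk, and the argument names come from the diff):

import asyncio
from dataclasses import dataclass
from datetime import datetime


@dataclass
class Message:
    # Assumed shape of a parsed transcript entry.
    speaker_name: str
    content: str
    actual_timestamp: datetime


async def run(client, messages: list[Message], use_bulk: bool) -> None:
    if not use_bulk:
        # Non-bulk path: one add_episode call per message, so each episode
        # is fully extracted and resolved before the next one starts.
        for i, message in enumerate(messages):
            await client.add_episode(
                name=f'Message {i}',
                episode_body=f'{message.speaker_name}: {message.content}',
                reference_time=message.actual_timestamp,
                source_description='Podcast Transcript',
            )
        # Mirrors the early `return` added in this diff: the bulk branch
        # below is skipped entirely.
        return

    # Bulk path: build the RawEpisode list and save it in a single call.
    await client.add_episode_bulk([])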
173 changes: 94 additions & 79 deletions graphiti_core/graphiti.py
@@ -48,14 +48,17 @@
    retrieve_previous_episodes_bulk,
)
from graphiti_core.utils.maintenance.edge_operations import (
-    dedupe_extracted_edges,
    extract_edges,
+    resolve_extracted_edges,
)
from graphiti_core.utils.maintenance.graph_data_operations import (
    EPISODE_WINDOW_LEN,
    build_indices_and_constraints,
)
-from graphiti_core.utils.maintenance.node_operations import dedupe_extracted_nodes, extract_nodes
+from graphiti_core.utils.maintenance.node_operations import (
+    extract_nodes,
+    resolve_extracted_nodes,
+)
from graphiti_core.utils.maintenance.temporal_operations import (
extract_edge_dates,
invalidate_edges,
@@ -177,9 +180,9 @@ async def build_indices_and_constraints(self):
await build_indices_and_constraints(self.driver)

    async def retrieve_episodes(
-            self,
-            reference_time: datetime,
-            last_n: int = EPISODE_WINDOW_LEN,
+        self,
+        reference_time: datetime,
+        last_n: int = EPISODE_WINDOW_LEN,
    ) -> list[EpisodicNode]:
"""
Retrieve the last n episodic nodes from the graph.
@@ -207,14 +210,14 @@ async def retrieve_episodes(
return await retrieve_episodes(self.driver, reference_time, last_n)

    async def add_episode(
-            self,
-            name: str,
-            episode_body: str,
-            source_description: str,
-            reference_time: datetime,
-            source: EpisodeType = EpisodeType.message,
-            success_callback: Callable | None = None,
-            error_callback: Callable | None = None,
+        self,
+        name: str,
+        episode_body: str,
+        source_description: str,
+        reference_time: datetime,
+        source: EpisodeType = EpisodeType.message,
+        success_callback: Callable | None = None,
+        error_callback: Callable | None = None,
    ):
"""
Process an episode and update the graph.
@@ -265,7 +268,6 @@ async def add_episode_endpoint(episode_data: EpisodeData):

            nodes: list[EntityNode] = []
            entity_edges: list[EntityEdge] = []
-            episodic_edges: list[EpisodicEdge] = []
            embedder = self.llm_client.get_embedder()
            now = datetime.now()

@@ -280,6 +282,8 @@ async def add_episode_endpoint(episode_data: EpisodeData):
                valid_at=reference_time,
            )

+            # Extract entities as nodes
+
            extracted_nodes = await extract_nodes(self.llm_client, episode, previous_episodes)
            logger.info(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')

@@ -288,57 +292,82 @@ async def add_episode_endpoint(episode_data: EpisodeData):
            await asyncio.gather(
                *[node.generate_name_embedding(embedder) for node in extracted_nodes]
            )
-            existing_nodes = await get_relevant_nodes(extracted_nodes, self.driver)
-
-            logger.info(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')
-            touched_nodes, _, brand_new_nodes = await dedupe_extracted_nodes(
-                self.llm_client, extracted_nodes, existing_nodes
-            )
-            logger.info(f'Adjusted touched nodes: {[(n.name, n.uuid) for n in touched_nodes]}')
-            nodes.extend(touched_nodes)
+            # Resolve extracted nodes with nodes already in the graph
+            existing_nodes_lists: list[list[EntityNode]] = list(
+                await asyncio.gather(
+                    *[get_relevant_nodes([node], self.driver) for node in extracted_nodes]
+                )
+            )
+
+            mentioned_nodes, _ = await resolve_extracted_nodes(
+                self.llm_client, extracted_nodes, existing_nodes_lists
+            )
+            logger.info(f'Adjusted mentioned nodes: {[(n.name, n.uuid) for n in mentioned_nodes]}')
+            nodes.extend(mentioned_nodes)

            # Extract facts as edges given entity nodes
            extracted_edges = await extract_edges(
-                self.llm_client, episode, touched_nodes, previous_episodes
+                self.llm_client, episode, mentioned_nodes, previous_episodes
            )

            # calculate embeddings
            await asyncio.gather(*[edge.generate_embedding(embedder) for edge in extracted_edges])

-            existing_edges = await get_relevant_edges(extracted_edges, self.driver)
-            logger.info(f'Existing edges: {[(e.name, e.uuid) for e in existing_edges]}')
+            # Resolve extracted edges with edges already in the graph
+            existing_edges_list: list[list[EntityEdge]] = list(
+                await asyncio.gather(
+                    *[
+                        get_relevant_edges(
+                            [edge],
+                            self.driver,
+                            RELEVANT_SCHEMA_LIMIT,
+                            edge.source_node_uuid,
+                            edge.target_node_uuid,
+                        )
+                        for edge in extracted_edges
+                    ]
+                )
+            )
+            logger.info(
+                f'Existing edges lists: {[(e.name, e.uuid) for edges_lst in existing_edges_list for e in edges_lst]}'
+            )
            logger.info(f'Extracted edges: {[(e.name, e.uuid) for e in extracted_edges]}')

-            deduped_edges = await dedupe_extracted_edges(
-                self.llm_client,
-                extracted_edges,
-                existing_edges,
-            )
+            deduped_edges: list[EntityEdge] = await resolve_extracted_edges(
+                self.llm_client, extracted_edges, existing_edges_list
+            )

-            edge_touched_node_uuids = [n.uuid for n in brand_new_nodes]
-            for edge in deduped_edges:
-                edge_touched_node_uuids.append(edge.source_node_uuid)
-                edge_touched_node_uuids.append(edge.target_node_uuid)
-
-            for edge in deduped_edges:
-                valid_at, invalid_at, _ = await extract_edge_dates(
-                    self.llm_client,
-                    edge,
-                    episode,
-                    previous_episodes,
-                )
-                edge.valid_at = valid_at
-                edge.invalid_at = invalid_at
-                if edge.invalid_at:
-                    edge.expired_at = now
-            for edge in existing_edges:
-                valid_at, invalid_at, _ = await extract_edge_dates(
-                    self.llm_client,
-                    edge,
-                    episode,
-                    previous_episodes,
-                )
+            # Extract dates for the newly extracted edges
+            edge_dates = await asyncio.gather(
+                *[
+                    extract_edge_dates(
+                        self.llm_client,
+                        edge,
+                        episode,
+                        previous_episodes,
+                    )
+                    for edge in deduped_edges
+                ]
+            )
+
+            for i, edge in enumerate(deduped_edges):
+                valid_at = edge_dates[i][0]
+                invalid_at = edge_dates[i][1]

                edge.valid_at = valid_at
                edge.invalid_at = invalid_at
-                if edge.invalid_at:
+                if edge.invalid_at is not None:
                    edge.expired_at = now

+            entity_edges.extend(deduped_edges)
+
+            existing_edges: list[EntityEdge] = [
+                e for edge_lst in existing_edges_list for e in edge_lst
+            ]

            (
                old_edges_with_nodes_pending_invalidation,
                new_edges_with_nodes,
@@ -361,30 +390,18 @@ async def add_episode_endpoint(episode_data: EpisodeData):
                for deduped_edge in deduped_edges:
                    if deduped_edge.uuid == edge.uuid:
                        deduped_edge.expired_at = edge.expired_at
-                edge_touched_node_uuids.append(edge.source_node_uuid)
-                edge_touched_node_uuids.append(edge.target_node_uuid)
            logger.info(f'Invalidated edges: {[(e.name, e.uuid) for e in invalidated_edges]}')

-            edges_to_save = existing_edges + deduped_edges
-
-            entity_edges.extend(edges_to_save)
-
-            edge_touched_node_uuids = list(set(edge_touched_node_uuids))
-            involved_nodes = [node for node in nodes if node.uuid in edge_touched_node_uuids]
-
-            logger.info(f'Edge touched nodes: {[(n.name, n.uuid) for n in involved_nodes]}')
+            entity_edges.extend(existing_edges)

            logger.info(f'Deduped edges: {[(e.name, e.uuid) for e in deduped_edges]}')

-            episodic_edges.extend(
-                build_episodic_edges(
-                    # There may be an overlap between new_nodes and affected_nodes, so we're deduplicating them
-                    involved_nodes,
-                    episode,
-                    now,
-                )
-            )
+            episodic_edges: list[EpisodicEdge] = build_episodic_edges(
+                mentioned_nodes,
+                episode,
+                now,
+            )
            # Important to append the episode to the nodes at the end so that self referencing episodic edges are not built

            logger.info(f'Built episodic edges: {episodic_edges}')

# Future optimization would be using batch operations to save nodes and edges
@@ -395,9 +412,7 @@ async def add_episode_endpoint(episode_data: EpisodeData):

            end = time()
            logger.info(f'Completed add_episode in {(end - start) * 1000} ms')
-            # for node in nodes:
-            #     if isinstance(node, EntityNode):
-            #         await node.update_summary(self.driver)

            if success_callback:
                await success_callback(episode)
        except Exception as e:
@@ -407,8 +422,8 @@ async def add_episode_endpoint(episode_data: EpisodeData):
raise e

    async def add_episode_bulk(
-            self,
-            bulk_episodes: list[RawEpisode],
+        self,
+        bulk_episodes: list[RawEpisode],
    ):
"""
Process multiple episodes in bulk and update the graph.
@@ -572,18 +587,18 @@ async def search(self, query: str, center_node_uuid: str | None = None, num_resu
return edges

    async def _search(
-            self,
-            query: str,
-            timestamp: datetime,
-            config: SearchConfig,
-            center_node_uuid: str | None = None,
+        self,
+        query: str,
+        timestamp: datetime,
+        config: SearchConfig,
+        center_node_uuid: str | None = None,
    ):
return await hybrid_search(
self.driver, self.llm_client.get_embedder(), query, timestamp, config, center_node_uuid
)

    async def get_nodes_by_query(
-            self, query: str, limit: int = RELEVANT_SCHEMA_LIMIT
+        self, query: str, limit: int = RELEVANT_SCHEMA_LIMIT
    ) -> list[EntityNode]:
"""
Retrieve nodes from the graph database based on a text query.
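
Most of the speedup in this file comes from replacing sequential per-item awaits (one get_relevant_nodes, get_relevant_edges, or extract_edge_dates call at a time) with asyncio.gather fan-outs. A minimal, self-contained sketch of the pattern; fetch_relevant is a placeholder for the per-item DB or LLM calls in the diff:

import asyncio


async def fetch_relevant(item: str) -> list[str]:
    # Stand-in for a per-item I/O-bound call such as
    # get_relevant_nodes([node], driver).
    await asyncio.sleep(0.1)  # simulate network/LLM latency
    return [f'match-for-{item}']


async def sequential(items: list[str]) -> list[list[str]]:
    # Old shape: total latency grows linearly, ~0.1 s * len(items).
    return [await fetch_relevant(item) for item in items]


async def concurrent(items: list[str]) -> list[list[str]]:
    # New shape: the calls run concurrently, so total latency is roughly
    # that of the slowest single call (~0.1 s here), regardless of count.
    return list(await asyncio.gather(*[fetch_relevant(item) for item in items]))


if __name__ == '__main__':
    print(asyncio.run(concurrent(['a', 'b', 'c'])))

Because asyncio.gather preserves the order of its input awaitables, the diff can safely index edge_dates[i] against deduped_edges[i] when assigning valid_at and invalid_at.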
54 changes: 46 additions & 8 deletions graphiti_core/prompts/dedupe_edges.py
@@ -23,12 +23,14 @@
class Prompt(Protocol):
    v1: PromptVersion
    v2: PromptVersion
+    v3: PromptVersion
    edge_list: PromptVersion


class Versions(TypedDict):
    v1: PromptFunction
    v2: PromptFunction
+    v3: PromptFunction
    edge_list: PromptFunction


@@ -41,17 +43,17 @@ def v1(context: dict[str, Any]) -> list[Message]:
        Message(
            role='user',
            content=f"""
-        Given the following context, deduplicate facts from a list of new facts given a list of existing facts:
+        Given the following context, deduplicate facts from a list of new facts given a list of existing edges:

-        Existing Facts:
+        Existing Edges:
        {json.dumps(context['existing_edges'], indent=2)}

-        New Facts:
+        New Edges:
        {json.dumps(context['extracted_edges'], indent=2)}

        Task:
-        If any facts in New Facts is a duplicate of a fact in Existing Facts,
-        do not return it in the list of unique facts.
+        If any edge in New Edges is a duplicate of an edge in Existing Edges, add their uuids to the output list.
+        When finding duplicate edges, synthesize their facts into a short new fact.

        Guidelines:
        1. identical or near identical facts are duplicates
@@ -60,9 +62,11 @@ def v1(context: dict[str, Any]) -> list[Message]:

        Respond with a JSON object in the following format:
        {{
-            "unique_facts": [
+            "duplicates": [
                {{
-                    "uuid": "unique identifier of the fact"
+                    "uuid": "uuid of the new node like 5d643020624c42fa9de13f97b1b3fa39",
+                    "duplicate_of": "uuid of the existing node",
+                    "fact": "one sentence description of the fact"
                }}
            ]
        }}
@@ -113,6 +117,40 @@ def v2(context: dict[str, Any]) -> list[Message]:
]


+def v3(context: dict[str, Any]) -> list[Message]:
+    return [
+        Message(
+            role='system',
+            content='You are a helpful assistant that de-duplicates edges from edge lists.',
+        ),
+        Message(
+            role='user',
+            content=f"""
+        Given the following context, determine whether the New Edge represents any of the edges in the list of Existing Edges.
+
+        Existing Edges:
+        {json.dumps(context['existing_edges'], indent=2)}
+
+        New Edge:
+        {json.dumps(context['extracted_edges'], indent=2)}
+
+        Task:
+        1. If the New Edge represents the same factual information as any edge in Existing Edges, return 'is_duplicate: true' in the
+        response. Otherwise, return 'is_duplicate: false'
+        2. If is_duplicate is true, also return the uuid of the existing edge in the response
+
+        Guidelines:
+        1. The facts do not need to be completely identical to be duplicates, they just need to express the same information.
+
+        Respond with a JSON object in the following format:
+        {{
+            "is_duplicate": true or false,
+            "uuid": uuid of the existing edge like "5d643020624c42fa9de13f97b1b3fa39" or null
+        }}
+        """,
+        ),
+    ]


def edge_list(context: dict[str, Any]) -> list[Message]:
return [
Message(
Expand Down Expand Up @@ -151,4 +189,4 @@ def edge_list(context: dict[str, Any]) -> list[Message]:
]


-versions: Versions = {'v1': v1, 'v2': v2, 'edge_list': edge_list}
+versions: Versions = {'v1': v1, 'v2': v2, 'v3': v3, 'edge_list': edge_list}
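
For illustration, a caller might consume the JSON contract that v3 requests roughly like this (a sketch only: resolve_duplicate and the completion string are hypothetical, and graphiti's actual resolution logic lives in edge_operations.resolve_extracted_edges):

import json


def resolve_duplicate(raw_completion: str, new_edge_uuid: str) -> str:
    """Return the uuid to keep: the existing edge's uuid when the model
    flags a duplicate, otherwise the newly extracted edge's uuid."""
    response = json.loads(raw_completion)
    if response.get('is_duplicate') and response.get('uuid'):
        return response['uuid']
    return new_edge_uuid


# Hypothetical model output matching the format the v3 prompt asks for.
completion = '{"is_duplicate": true, "uuid": "5d643020624c42fa9de13f97b1b3fa39"}'
print(resolve_duplicate(completion, new_edge_uuid='0f12ab'))  # keeps the existing uuid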