Skip to content

chore: Update the context for date extraction + bug fixes #31

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Aug 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 15 additions & 11 deletions core/graphiti.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@
from core.utils.maintenance.node_operations import dedupe_extracted_nodes, extract_nodes
from core.utils.maintenance.temporal_operations import (
extract_edge_dates,
extract_node_edge_node_triplet,
invalidate_edges,
prepare_edges_for_invalidation,
)
Expand Down Expand Up @@ -183,22 +182,27 @@ async def add_episode(
)

for edge in invalidated_edges:
for existing_edge in existing_edges:
if existing_edge.uuid == edge.uuid:
existing_edge.expired_at = edge.expired_at
for deduped_edge in deduped_edges:
if deduped_edge.uuid == edge.uuid:
deduped_edge.expired_at = edge.expired_at
edge_touched_node_uuids.append(edge.source_node_uuid)
edge_touched_node_uuids.append(edge.target_node_uuid)

edges_to_save = invalidated_edges
edges_to_save = existing_edges + deduped_edges

# There may be an overlap between deduped and invalidated edges, so we want to make sure to save the invalidated one
for deduped_edge in deduped_edges:
if deduped_edge.uuid not in [edge.uuid for edge in invalidated_edges]:
edges_to_save.append(deduped_edge)
for deduped_edge in deduped_edges:
triplet = extract_node_edge_node_triplet(deduped_edge, nodes)
for edge_to_extract_dates_from in edges_to_save:
valid_at, invalid_at, _ = await extract_edge_dates(
self.llm_client, triplet, episode.valid_at, episode, previous_episodes
self.llm_client,
edge_to_extract_dates_from,
episode.valid_at,
episode,
previous_episodes,
)
deduped_edge.valid_at = valid_at
deduped_edge.invalid_at = invalid_at
edge_to_extract_dates_from.valid_at = valid_at
edge_to_extract_dates_from.invalid_at = invalid_at
entity_edges.extend(edges_to_save)

edge_touched_node_uuids = list(set(edge_touched_node_uuids))
Expand Down
2 changes: 0 additions & 2 deletions core/prompts/extract_edge_dates.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,7 @@ def v1(context: dict[str, Any]) -> list[Message]:
role='user',
content=f"""
Edge:
Source Node: {context['source_node']}
Edge Name: {context['edge_name']}
Target Node: {context['target_node']}
Fact: {context['edge_fact']}

Current Episode: {context['current_episode']}
Expand Down
8 changes: 1 addition & 7 deletions core/utils/maintenance/temporal_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,7 @@ async def invalidate_edges(
current_episode,
previous_episodes,
)
logger.info(prompt_library.invalidate_edges.v1(context))
llm_response = await llm_client.generate_response(prompt_library.invalidate_edges.v1(context))
logger.info(f'invalidate_edges LLM response: {llm_response}')

edges_to_invalidate = llm_response.get('invalidated_edges', [])
invalidated_edges = process_edge_invalidation_llm_response(
Expand Down Expand Up @@ -139,17 +137,13 @@ def process_edge_invalidation_llm_response(

async def extract_edge_dates(
llm_client: LLMClient,
edge_triplet: NodeEdgeNodeTriplet,
edge: EntityEdge,
reference_time: datetime,
current_episode: EpisodicNode,
previous_episodes: List[EpisodicNode],
) -> tuple[datetime | None, datetime | None, str]:
source_node, edge, target_node = edge_triplet

context = {
'source_node': source_node.name,
'edge_name': edge.name,
'target_node': target_node.name,
'edge_fact': edge.fact,
'current_episode': current_episode.content,
'previous_episodes': [ep.content for ep in previous_episodes],
Expand Down
1 change: 0 additions & 1 deletion runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,6 @@ async def main():
await clear_data(client.driver)
await client.build_indices_and_constraints()

# await client.build_indices()
for i, message in enumerate(bmw_sales):
await client.add_episode(
name=f'Message {i}',
Expand Down
Loading