Skip to content

Commit 0094676

Browse files
authored
Node episodes list (#381)
* added episode list virtual field * in progress tests * add tests * update search return type * linter * copyright notice * mark integration tests
1 parent 064d920 commit 0094676

File tree

4 files changed

+186
-75
lines changed

4 files changed

+186
-75
lines changed

graphiti_core/nodes.py

+28-31
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,20 @@
3838

3939
logger = logging.getLogger(__name__)
4040

41+
ENTITY_NODE_RETURN: LiteralString = """
42+
OPTIONAL MATCH (e:Episodic)-[r:MENTIONS]->(n)
43+
WITH n, collect(e.uuid) AS episodes
44+
RETURN
45+
n.uuid As uuid,
46+
n.name AS name,
47+
n.name_embedding AS name_embedding,
48+
n.group_id AS group_id,
49+
n.created_at AS created_at,
50+
n.summary AS summary,
51+
labels(n) AS labels,
52+
properties(n) AS attributes,
53+
episodes"""
54+
4155

4256
class EpisodeType(Enum):
4357
"""
@@ -280,6 +294,9 @@ async def get_by_entity_node_uuid(cls, driver: AsyncDriver, entity_node_uuid: st
280294
class EntityNode(Node):
281295
name_embedding: list[float] | None = Field(default=None, description='embedding of the name')
282296
summary: str = Field(description='regional summary of surrounding edges', default_factory=str)
297+
episodes: list[str] | None = Field(
298+
default=None, description='List of episode uuids that mention this node.'
299+
)
283300
attributes: dict[str, Any] = Field(
284301
default={}, description='Additional attributes of the node. Dependent on node labels'
285302
)
@@ -318,19 +335,14 @@ async def save(self, driver: AsyncDriver):
318335

319336
@classmethod
320337
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
321-
records, _, _ = await driver.execute_query(
338+
query = (
322339
"""
323-
MATCH (n:Entity {uuid: $uuid})
324-
RETURN
325-
n.uuid As uuid,
326-
n.name AS name,
327-
n.name_embedding AS name_embedding,
328-
n.group_id AS group_id,
329-
n.created_at AS created_at,
330-
n.summary AS summary,
331-
labels(n) AS labels,
332-
properties(n) AS attributes
333-
""",
340+
MATCH (n:Entity {uuid: $uuid})
341+
"""
342+
+ ENTITY_NODE_RETURN
343+
)
344+
records, _, _ = await driver.execute_query(
345+
query,
334346
uuid=uuid,
335347
database_=DEFAULT_DATABASE,
336348
routing_='r',
@@ -348,16 +360,8 @@ async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
348360
records, _, _ = await driver.execute_query(
349361
"""
350362
MATCH (n:Entity) WHERE n.uuid IN $uuids
351-
RETURN
352-
n.uuid As uuid,
353-
n.name AS name,
354-
n.name_embedding AS name_embedding,
355-
n.group_id AS group_id,
356-
n.created_at AS created_at,
357-
n.summary AS summary,
358-
labels(n) AS labels,
359-
properties(n) AS attributes
360-
""",
363+
"""
364+
+ ENTITY_NODE_RETURN,
361365
uuids=uuids,
362366
database_=DEFAULT_DATABASE,
363367
routing_='r',
@@ -383,16 +387,8 @@ async def get_by_group_ids(
383387
MATCH (n:Entity) WHERE n.group_id IN $group_ids
384388
"""
385389
+ cursor_query
390+
+ ENTITY_NODE_RETURN
386391
+ """
387-
RETURN
388-
n.uuid As uuid,
389-
n.name AS name,
390-
n.name_embedding AS name_embedding,
391-
n.group_id AS group_id,
392-
n.created_at AS created_at,
393-
n.summary AS summary,
394-
labels(n) AS labels,
395-
properties(n) AS attributes
396392
ORDER BY n.uuid DESC
397393
"""
398394
+ limit_query,
@@ -548,6 +544,7 @@ def get_entity_node_from_record(record: Any) -> EntityNode:
548544
created_at=record['created_at'].to_native(),
549545
summary=record['summary'],
550546
attributes=record['attributes'],
547+
episodes=record['episodes'],
551548
)
552549

553550
entity_node.attributes.pop('uuid', None)

graphiti_core/search/search_utils.py

+35-41
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
semaphore_gather,
3333
)
3434
from graphiti_core.nodes import (
35+
ENTITY_NODE_RETURN,
3536
CommunityNode,
3637
EntityNode,
3738
EpisodicNode,
@@ -53,6 +54,20 @@
5354
MAX_SEARCH_DEPTH = 3
5455
MAX_QUERY_LENGTH = 32
5556

57+
SEARCH_ENTITY_NODE_RETURN: LiteralString = """
58+
OPTIONAL MATCH (e:Episodic)-[r:MENTIONS]->(n)
59+
WITH n, score, collect(e.uuid) AS episodes
60+
RETURN
61+
n.uuid As uuid,
62+
n.name AS name,
63+
n.name_embedding AS name_embedding,
64+
n.group_id AS group_id,
65+
n.created_at AS created_at,
66+
n.summary AS summary,
67+
labels(n) AS labels,
68+
properties(n) AS attributes,
69+
episodes"""
70+
5671

5772
def fulltext_query(query: str, group_ids: list[str] | None = None):
5873
group_ids_filter_list = (
@@ -230,8 +245,8 @@ async def edge_similarity_search(
230245

231246
query: LiteralString = (
232247
"""
233-
MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
234-
"""
248+
MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
249+
"""
235250
+ group_filter_query
236251
+ filter_query
237252
+ """\nWITH DISTINCT r, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score
@@ -341,27 +356,21 @@ async def node_fulltext_search(
341356

342357
filter_query, filter_params = node_search_filter_query_constructor(search_filter)
343358

344-
records, _, _ = await driver.execute_query(
345-
"""
346-
CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit})
347-
YIELD node AS node, score
348-
MATCH (n:Entity)
349-
WHERE n.uuid = node.uuid
359+
query = (
350360
"""
361+
CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit})
362+
YIELD node AS n, score
363+
WHERE n:Entity
364+
"""
351365
+ filter_query
366+
+ SEARCH_ENTITY_NODE_RETURN
352367
+ """
353-
RETURN
354-
n.uuid AS uuid,
355-
n.group_id AS group_id,
356-
n.name AS name,
357-
n.name_embedding AS name_embedding,
358-
n.created_at AS created_at,
359-
n.summary AS summary,
360-
labels(n) AS labels,
361-
properties(n) AS attributes
362368
ORDER BY score DESC
363-
LIMIT $limit
364-
""",
369+
"""
370+
)
371+
372+
records, _, _ = await driver.execute_query(
373+
query,
365374
filter_params,
366375
query=fuzzy_query,
367376
group_ids=group_ids,
@@ -406,19 +415,12 @@ async def node_similarity_search(
406415
+ filter_query
407416
+ """
408417
WITH n, vector.similarity.cosine(n.name_embedding, $search_vector) AS score
409-
WHERE score > $min_score
410-
RETURN
411-
n.uuid As uuid,
412-
n.group_id AS group_id,
413-
n.name AS name,
414-
n.name_embedding AS name_embedding,
415-
n.created_at AS created_at,
416-
n.summary AS summary,
417-
labels(n) AS labels,
418-
properties(n) AS attributes
419-
ORDER BY score DESC
420-
LIMIT $limit
421-
""",
418+
WHERE score > $min_score"""
419+
+ SEARCH_ENTITY_NODE_RETURN
420+
+ """
421+
ORDER BY score DESC
422+
LIMIT $limit
423+
""",
422424
query_params,
423425
search_vector=search_vector,
424426
group_ids=group_ids,
@@ -452,16 +454,8 @@ async def node_bfs_search(
452454
WHERE n.group_id = origin.group_id
453455
"""
454456
+ filter_query
457+
+ ENTITY_NODE_RETURN
455458
+ """
456-
RETURN DISTINCT
457-
n.uuid As uuid,
458-
n.group_id AS group_id,
459-
n.name AS name,
460-
n.name_embedding AS name_embedding,
461-
n.created_at AS created_at,
462-
n.summary AS summary,
463-
labels(n) AS labels,
464-
properties(n) AS attributes
465459
LIMIT $limit
466460
""",
467461
filter_params,

tests/test_graphiti_int.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,7 @@ async def test_graphiti_init():
6565
logger = setup_logging()
6666
graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD)
6767

68-
results = await graphiti.search_(
69-
query='Who is the User?',
70-
)
68+
results = await graphiti.search_(query='Who is the User?')
7169

7270
pretty_results = search_results_to_context_string(results)
7371

tests/test_node_int.py

+122
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
"""
2+
Copyright 2024, Zep Software, Inc.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
17+
import os
18+
from datetime import datetime, timezone
19+
from uuid import uuid4
20+
21+
import pytest
22+
from neo4j import AsyncGraphDatabase
23+
24+
from graphiti_core.nodes import (
25+
CommunityNode,
26+
EntityNode,
27+
EpisodeType,
28+
EpisodicNode,
29+
)
30+
31+
NEO4J_URI = os.getenv('NEO4J_URI', 'bolt://localhost:7687')
32+
NEO4J_USER = os.getenv('NEO4J_USER', 'neo4j')
33+
NEO4J_PASSWORD = os.getenv('NEO4J_PASSWORD', 'test')
34+
35+
36+
@pytest.fixture
37+
def sample_entity_node():
38+
return EntityNode(
39+
uuid=str(uuid4()),
40+
name='Test Entity',
41+
group_id='test_group',
42+
labels=['Entity'],
43+
name_embedding=[0.5] * 1024,
44+
summary='Entity Summary',
45+
)
46+
47+
48+
@pytest.fixture
49+
def sample_episodic_node():
50+
return EpisodicNode(
51+
uuid=str(uuid4()),
52+
name='Episode 1',
53+
group_id='test_group',
54+
source=EpisodeType.text,
55+
source_description='Test source',
56+
content='Some content here',
57+
valid_at=datetime.now(timezone.utc),
58+
)
59+
60+
61+
@pytest.fixture
62+
def sample_community_node():
63+
return CommunityNode(
64+
uuid=str(uuid4()),
65+
name='Community A',
66+
name_embedding=[0.5] * 1024,
67+
group_id='test_group',
68+
summary='Community summary',
69+
)
70+
71+
72+
@pytest.mark.asyncio
73+
@pytest.mark.integration
74+
async def test_entity_node_save_get_and_delete(sample_entity_node):
75+
neo4j_driver = AsyncGraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD))
76+
await sample_entity_node.save(neo4j_driver)
77+
retrieved = await EntityNode.get_by_uuid(neo4j_driver, sample_entity_node.uuid)
78+
assert retrieved.uuid == sample_entity_node.uuid
79+
assert retrieved.name == 'Test Entity'
80+
assert retrieved.group_id == 'test_group'
81+
82+
await sample_entity_node.delete(neo4j_driver)
83+
84+
await neo4j_driver.close()
85+
86+
87+
@pytest.mark.asyncio
88+
@pytest.mark.integration
89+
async def test_community_node_save_get_and_delete(sample_community_node):
90+
neo4j_driver = AsyncGraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD))
91+
92+
await sample_community_node.save(neo4j_driver)
93+
94+
retrieved = await CommunityNode.get_by_uuid(neo4j_driver, sample_community_node.uuid)
95+
assert retrieved.uuid == sample_community_node.uuid
96+
assert retrieved.name == 'Community A'
97+
assert retrieved.group_id == 'test_group'
98+
assert retrieved.summary == 'Community summary'
99+
100+
await sample_community_node.delete(neo4j_driver)
101+
102+
await neo4j_driver.close()
103+
104+
105+
@pytest.mark.asyncio
106+
@pytest.mark.integration
107+
async def test_episodic_node_save_get_and_delete(sample_episodic_node):
108+
neo4j_driver = AsyncGraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD))
109+
110+
await sample_episodic_node.save(neo4j_driver)
111+
112+
retrieved = await EpisodicNode.get_by_uuid(neo4j_driver, sample_episodic_node.uuid)
113+
assert retrieved.uuid == sample_episodic_node.uuid
114+
assert retrieved.name == 'Episode 1'
115+
assert retrieved.group_id == 'test_group'
116+
assert retrieved.source == EpisodeType.text
117+
assert retrieved.source_description == 'Test source'
118+
assert retrieved.content == 'Some content here'
119+
120+
await sample_episodic_node.delete(neo4j_driver)
121+
122+
await neo4j_driver.close()

0 commit comments

Comments
 (0)