
Commit 31a9a18

Ruff formatting
1 parent 385492b

File tree: 7 files changed, +50 −64 lines


examples/customize/build_graph/pipeline/kg_builder_from_pdf.py

Lines changed: 6 additions & 13 deletions
@@ -17,7 +17,6 @@
 import asyncio
 import logging
 
-import neo4j
 from neo4j_graphrag.experimental.components.entity_relation_extractor import (
     LLMEntityRelationExtractor,
     OnError,
@@ -35,12 +34,12 @@
 from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult
 from neo4j_graphrag.llm import LLMInterface, OpenAILLM
 
+import neo4j
+
 logging.basicConfig(level=logging.INFO)
 
 
-async def define_and_run_pipeline(
-    neo4j_driver: neo4j.Driver, llm: LLMInterface
-) -> PipelineResult:
+async def define_and_run_pipeline(neo4j_driver: neo4j.Driver, llm: LLMInterface) -> PipelineResult:
     from neo4j_graphrag.experimental.pipeline import Pipeline
 
     # Instantiate Entity and Relation objects
@@ -57,9 +56,7 @@ async def define_and_run_pipeline(
         ),
     ]
     relations = [
-        SchemaRelation(
-            label="SITUATED_AT", description="Indicates the location of a person."
-        ),
+        SchemaRelation(label="SITUATED_AT", description="Indicates the location of a person."),
         SchemaRelation(
             label="LED_BY",
             description="Indicates the leader of an organization.",
@@ -68,9 +65,7 @@ async def define_and_run_pipeline(
             label="OWNS",
             description="Indicates the ownership of an item such as a Horcrux.",
         ),
-        SchemaRelation(
-            label="INTERACTS", description="The interaction between two people."
-        ),
+        SchemaRelation(label="INTERACTS", description="The interaction between two people."),
     ]
     potential_schema = [
         ("PERSON", "SITUATED_AT", "LOCATION"),
@@ -131,9 +126,7 @@ async def main() -> PipelineResult:
             "response_format": {"type": "json_object"},
         },
     )
-    driver = neo4j.GraphDatabase.driver(
-        "bolt://localhost:7687", auth=("neo4j", "password")
-    )
+    driver = neo4j.GraphDatabase.driver("bolt://localhost:7687", auth=("neo4j", "password"))
     res = await define_and_run_pipeline(driver, llm)
     driver.close()
     await llm.async_client.close()
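The same import reshuffle recurs in each of the five example files in this commit: import neo4j moves from the top of the third-party block to its own block after the neo4j_graphrag imports, presumably driven by the project's import-sorting configuration, which is not part of this diff. Illustrative only (not a file in this commit), the header each example converges on is roughly:

# Resulting import layout after this commit (trimmed for illustration).
import asyncio
import logging

from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult
from neo4j_graphrag.llm import LLMInterface, OpenAILLM

import neo4j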

examples/customize/build_graph/pipeline/kg_builder_from_text.py

Lines changed: 5 additions & 10 deletions
@@ -16,7 +16,6 @@
 
 import asyncio
 
-import neo4j
 from neo4j_graphrag.embeddings.openai import OpenAIEmbeddings
 from neo4j_graphrag.experimental.components.embedder import TextChunkEmbedder
 from neo4j_graphrag.experimental.components.entity_relation_extractor import (
@@ -37,10 +36,10 @@
 from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult
 from neo4j_graphrag.llm import LLMInterface, OpenAILLM
 
+import neo4j
+
 
-async def define_and_run_pipeline(
-    neo4j_driver: neo4j.Driver, llm: LLMInterface
-) -> PipelineResult:
+async def define_and_run_pipeline(neo4j_driver: neo4j.Driver, llm: LLMInterface) -> PipelineResult:
     """This is where we define and run the KG builder pipeline, instantiating a few
     components:
     - Text Splitter: in this example we use the fixed size text splitter
@@ -75,9 +74,7 @@ async def define_and_run_pipeline(
     # and how the output of previous components must be used
     pipe.connect("splitter", "chunk_embedder", input_config={"text_chunks": "splitter"})
     pipe.connect("schema", "extractor", input_config={"schema": "schema"})
-    pipe.connect(
-        "chunk_embedder", "extractor", input_config={"chunks": "chunk_embedder"}
-    )
+    pipe.connect("chunk_embedder", "extractor", input_config={"chunks": "chunk_embedder"})
     pipe.connect(
         "extractor",
         "writer",
@@ -148,9 +145,7 @@ async def main() -> PipelineResult:
             "response_format": {"type": "json_object"},
         },
     )
-    driver = neo4j.GraphDatabase.driver(
-        "bolt://localhost:7687", auth=("neo4j", "password")
-    )
+    driver = neo4j.GraphDatabase.driver("bolt://localhost:7687", auth=("neo4j", "password"))
     res = await define_and_run_pipeline(driver, llm)
     driver.close()
     await llm.async_client.close()

examples/customize/build_graph/pipeline/lexical_graph_builder_from_text.py

Lines changed: 3 additions & 4 deletions
@@ -2,7 +2,6 @@
 
 import asyncio
 
-import neo4j
 from neo4j_graphrag.embeddings.openai import OpenAIEmbeddings
 from neo4j_graphrag.experimental.components.embedder import TextChunkEmbedder
 from neo4j_graphrag.experimental.components.kg_writer import Neo4jWriter
@@ -14,6 +13,8 @@
 from neo4j_graphrag.experimental.pipeline import Pipeline
 from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult
 
+import neo4j
+
 
 async def main(neo4j_driver: neo4j.Driver) -> PipelineResult:
     """This is where we define and run the Lexical Graph builder pipeline, instantiating
@@ -78,7 +79,5 @@ async def main(neo4j_driver: neo4j.Driver) -> PipelineResult:
 
 
 if __name__ == "__main__":
-    with neo4j.GraphDatabase.driver(
-        "bolt://localhost:7687", auth=("neo4j", "password")
-    ) as driver:
+    with neo4j.GraphDatabase.driver("bolt://localhost:7687", auth=("neo4j", "password")) as driver:
         print(asyncio.run(main(driver)))

examples/customize/build_graph/pipeline/text_to_lexical_graph_to_entity_graph_single_pipeline.py

Lines changed: 5 additions & 8 deletions
@@ -7,7 +7,6 @@
 
 import asyncio
 
-import neo4j
 from neo4j_graphrag.embeddings.openai import OpenAIEmbeddings
 from neo4j_graphrag.experimental.components.embedder import TextChunkEmbedder
 from neo4j_graphrag.experimental.components.entity_relation_extractor import (
@@ -29,6 +28,8 @@
 from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult
 from neo4j_graphrag.llm import LLMInterface, OpenAILLM
 
+import neo4j
+
 
 async def define_and_run_pipeline(
     neo4j_driver: neo4j.Driver,
@@ -56,7 +57,7 @@ async def define_and_run_pipeline(
     pipe = Pipeline()
     # define the components
     pipe.add_component(
-        FixedSizeSplitter(chunk_size=200, chunk_overlap=50,approximate=False),
+        FixedSizeSplitter(chunk_size=200, chunk_overlap=50, approximate=False),
         "splitter",
     )
     pipe.add_component(TextChunkEmbedder(embedder=OpenAIEmbeddings()), "chunk_embedder")
@@ -92,9 +93,7 @@ async def define_and_run_pipeline(
     )
     # define the execution order of component
     # and how the output of previous components must be used
-    pipe.connect(
-        "chunk_embedder", "extractor", input_config={"chunks": "chunk_embedder"}
-    )
+    pipe.connect("chunk_embedder", "extractor", input_config={"chunks": "chunk_embedder"})
     pipe.connect("schema", "extractor", input_config={"schema": "schema"})
     pipe.connect(
         "extractor",
@@ -189,7 +188,5 @@ async def main(driver: neo4j.Driver) -> PipelineResult:
 
 
 if __name__ == "__main__":
-    with neo4j.GraphDatabase.driver(
-        "bolt://localhost:7687", auth=("neo4j", "password")
-    ) as driver:
+    with neo4j.GraphDatabase.driver("bolt://localhost:7687", auth=("neo4j", "password")) as driver:
         print(asyncio.run(main(driver)))

examples/customize/build_graph/pipeline/text_to_lexical_graph_to_entity_graph_two_pipelines.py

Lines changed: 4 additions & 7 deletions
@@ -8,7 +8,6 @@
 
 import asyncio
 
-import neo4j
 from neo4j_graphrag.embeddings.openai import OpenAIEmbeddings
 from neo4j_graphrag.experimental.components.embedder import TextChunkEmbedder
 from neo4j_graphrag.experimental.components.entity_relation_extractor import (
@@ -31,6 +30,8 @@
 from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult
 from neo4j_graphrag.llm import LLMInterface, OpenAILLM
 
+import neo4j
+
 
 async def build_lexical_graph(
     neo4j_driver: neo4j.Driver,
@@ -200,15 +201,11 @@ async def main(driver: neo4j.Driver) -> PipelineResult:
         },
     )
     await build_lexical_graph(driver, lexical_graph_config, text=text)
-    res = await read_chunk_and_perform_entity_extraction(
-        driver, llm, lexical_graph_config
-    )
+    res = await read_chunk_and_perform_entity_extraction(driver, llm, lexical_graph_config)
     await llm.async_client.close()
     return res
 
 
 if __name__ == "__main__":
-    with neo4j.GraphDatabase.driver(
-        "bolt://localhost:7687", auth=("neo4j", "password")
-    ) as driver:
+    with neo4j.GraphDatabase.driver("bolt://localhost:7687", auth=("neo4j", "password")) as driver:
         print(asyncio.run(main(driver)))

src/neo4j_graphrag/experimental/components/text_splitters/fixed_size_splitter.py

Lines changed: 5 additions & 3 deletions
@@ -58,7 +58,7 @@ def _adjust_chunk_end(text: str, start: int, approximate_end: int) -> int:
     """
     end = approximate_end
     if end < len(text):
-        while end > start and not text[end].isspace() and not text[end-1].isspace():
+        while end > start and not text[end].isspace() and not text[end - 1].isspace():
            end -= 1
 
     # fallback if no whitespace is found
@@ -92,7 +92,9 @@ class FixedSizeSplitter(TextSplitter):
     """
 
     @validate_call
-    def __init__(self, chunk_size: int = 4000, chunk_overlap: int = 200, approximate: bool = True) -> None:
+    def __init__(
+        self, chunk_size: int = 4000, chunk_overlap: int = 200, approximate: bool = True
+    ) -> None:
         if chunk_size <= 0:
             raise ValueError("chunk_size must be strictly greater than 0")
         if chunk_overlap >= chunk_size:
@@ -131,7 +133,7 @@ async def run(self, text: str) -> TextChunks:
                 end = _adjust_chunk_end(text, start, approximate_end)
                 # when avoiding splitting words in the middle is not possible, revert to
                 # initial chunk end and skip adjusting next chunk start
-                skip_adjust_chunk_start = (end == approximate_end)
+                skip_adjust_chunk_start = end == approximate_end
             else:
                 # apply fixed size splitting with possibly words cut in half at chunk
                 # boundaries
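For context on the behaviour these hunks touch: with approximate=True (the default) the splitter pulls a chunk's end back to the nearest whitespace so words are not cut, falling back to the raw boundary when no whitespace is available; with approximate=False it cuts at exact character offsets. A minimal usage sketch, not part of this commit, assuming TextChunks exposes a chunks list of TextChunk objects with a text field:

import asyncio

from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import (
    FixedSizeSplitter,
)


async def demo() -> None:
    text = "Hello World, this is a test message."
    # word-boundary-aware splitting (the default behaviour)
    approx = FixedSizeSplitter(chunk_size=10, chunk_overlap=2, approximate=True)
    # strict fixed-size splitting; words may be cut at chunk boundaries
    exact = FixedSizeSplitter(chunk_size=10, chunk_overlap=2, approximate=False)
    print([c.text for c in (await approx.run(text)).chunks])
    print([c.text for c in (await exact.run(text)).chunks])


asyncio.run(demo())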

tests/unit/experimental/components/text_splitters/test_fixed_size_splitter.py

Lines changed: 22 additions & 19 deletions
@@ -16,7 +16,9 @@
 
 import pytest
 from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import (
-    FixedSizeSplitter, _adjust_chunk_start, _adjust_chunk_end,
+    FixedSizeSplitter,
+    _adjust_chunk_end,
+    _adjust_chunk_start,
 )
 from neo4j_graphrag.experimental.components.types import TextChunk
 
@@ -101,7 +103,11 @@ def test_invalid_chunk_size() -> None:
         ("Hello World", 0, 0),
     ],
 )
-def test_adjust_chunk_start(text, approximate_start, expected_start):
+def test_adjust_chunk_start(
+    text: str,
+    approximate_start: int,
+    expected_start: int
+) -> None:
     """
     Test that the _adjust_chunk_start function correctly shifts
     the start index to avoid breaking words, unless no whitespace is found.
@@ -125,7 +131,12 @@ def test_adjust_chunk_start(text, approximate_start, expected_start):
         ("Hello World", 6, 15, 15),
     ],
 )
-def test_adjust_chunk_end(text, start, approximate_end, expected_end):
+def test_adjust_chunk_end(
+    text: str,
+    start: int,
+    approximate_end: int,
+    expected_end: int
+) -> None:
     """
     Test that the _adjust_chunk_end function correctly shifts
     the end index to avoid breaking words, unless no whitespace is found.
@@ -144,27 +155,15 @@ def test_adjust_chunk_end(text, start, approximate_end, expected_end):
             10,
             2,
             True,
-            [
-                "Hello ",
-                "World, ",
-                "this is a ",
-                "a test ",
-                "message."
-            ],
+            ["Hello ", "World, ", "this is a ", "a test ", "message."],
         ),
         # Case: fixed size splitting
         (
            "Hello World, this is a test message.",
            10,
            2,
            False,
-            [
-                "Hello Worl",
-                "rld, this ",
-                "s is a tes",
-                "est messag",
-                "age."
-            ],
+            ["Hello Worl", "rld, this ", "s is a tes", "est messag", "age."],
         ),
         # Case: short text => only one chunk
         (
@@ -193,8 +192,12 @@ def test_adjust_chunk_end(text, start, approximate_end, expected_end):
     ],
 )
 async def test_fixed_size_splitter_run(
-    text, chunk_size, chunk_overlap, approximate, expected_chunks
-):
+    text: str,
+    chunk_size: int,
+    chunk_overlap: int,
+    approximate: bool,
+    expected_chunks: list[str]
+) -> None:
     """
     Test that 'FixedSizeSplitter.run' returns the expected chunks
     for different configurations.
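The approximate=False expectations in the parametrized cases above can be re-derived by hand: consecutive chunks appear to start chunk_size - chunk_overlap characters apart. A quick sketch, not from this commit, that reproduces that test case:

# Re-derive the exact-splitting expectation: chunk_size=10, chunk_overlap=2,
# so each chunk starts 8 characters after the previous one.
text = "Hello World, this is a test message."
chunk_size, chunk_overlap = 10, 2
step = chunk_size - chunk_overlap
chunks = [text[i : i + chunk_size] for i in range(0, len(text), step)]
print(chunks)  # ['Hello Worl', 'rld, this ', 's is a tes', 'est messag', 'age.']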
