Skip to content
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

## Next

### Added
- Add optional custom_prompt arg to the Text2CypherRetriever class.

## 0.3.1

### Fixed
Expand Down
34 changes: 23 additions & 11 deletions src/neo4j_genai/retrievers/text2cypher.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ class Text2CypherRetriever(Retriever):
llm (neo4j_genai.generation.llm.LLMInterface): LLM object to generate the Cypher query.
neo4j_schema (Optional[str]): Neo4j schema used to generate the Cypher query.
examples (Optional[list[str], optional): Optional user input/query pairs for the LLM to use as examples.
custom_prompt (Optional[str]): Optional custom prompt to use instead of auto generated prompt. Will not include the neo4j_schema or examples args, if provided.

Raises:
RetrieverInitializationError: If validation of the input arguments fail.
Expand All @@ -69,6 +70,7 @@ def __init__(
result_formatter: Optional[
Callable[[neo4j.Record], RetrieverResultItem]
] = None,
custom_prompt: Optional[str] = None,
) -> None:
try:
driver_model = Neo4jDriverModel(driver=driver)
Expand All @@ -82,6 +84,7 @@ def __init__(
neo4j_schema_model=neo4j_schema_model,
examples=examples,
result_formatter=result_formatter,
custom_prompt=custom_prompt,
)
except ValidationError as e:
raise RetrieverInitializationError(e.errors()) from e
Expand All @@ -90,12 +93,17 @@ def __init__(
self.llm = validated_data.llm_model.llm
self.examples = validated_data.examples
self.result_formatter = validated_data.result_formatter
self.custom_prompt = validated_data.custom_prompt
try:
self.neo4j_schema = (
validated_data.neo4j_schema_model.neo4j_schema
if validated_data.neo4j_schema_model
else get_schema(validated_data.driver_model.driver)
)
if (
not validated_data.custom_prompt
): # don't need schema for a custom prompt
self.neo4j_schema = (
validated_data.neo4j_schema_model.neo4j_schema
if validated_data.neo4j_schema_model
else get_schema(validated_data.driver_model.driver)
)

except (Neo4jError, DriverError) as e:
error_message = getattr(e, "message", str(e))
raise SchemaFetchError(
Expand Down Expand Up @@ -124,12 +132,16 @@ def get_search_results(
except ValidationError as e:
raise SearchValidationError(e.errors()) from e

prompt_template = Text2CypherTemplate()
prompt = prompt_template.format(
schema=self.neo4j_schema,
examples="\n".join(self.examples) if self.examples else "",
query=validated_data.query_text,
)
if not self.custom_prompt:
prompt_template = Text2CypherTemplate()
prompt = prompt_template.format(
schema=self.neo4j_schema,
examples="\n".join(self.examples) if self.examples else "",
query=validated_data.query_text,
)
else:
prompt = self.custom_prompt

logger.debug("Text2CypherRetriever prompt: %s", prompt)

try:
Expand Down
1 change: 1 addition & 0 deletions src/neo4j_genai/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,3 +240,4 @@ class Text2CypherRetrieverModel(BaseModel):
neo4j_schema_model: Optional[Neo4jSchemaModel] = None
examples: Optional[list[str]] = None
result_formatter: Optional[Callable[[neo4j.Record], RetrieverResultItem]] = None
custom_prompt: Optional[str] = None
84 changes: 78 additions & 6 deletions tests/unit/retrievers/test_text2cypher.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from unittest.mock import MagicMock, patch

import pytest
Expand Down Expand Up @@ -63,7 +64,11 @@ def test_t2c_retriever_invalid_neo4j_schema(
_verify_version_mock: MagicMock, driver: MagicMock, llm: MagicMock
) -> None:
with pytest.raises(RetrieverInitializationError) as exc_info:
Text2CypherRetriever(driver=driver, llm=llm, neo4j_schema=42) # type: ignore
Text2CypherRetriever(
driver=driver,
llm=llm,
neo4j_schema=42, # type: ignore[arg-type, unused-ignore]
)

assert "neo4j_schema" in str(exc_info.value)
assert "Input should be a valid string" in str(exc_info.value)
Expand Down Expand Up @@ -92,7 +97,7 @@ def test_t2c_retriever_invalid_search_examples(
driver=driver,
llm=llm,
neo4j_schema="dummy-text",
examples=42, # type: ignore
examples=42, # type: ignore[arg-type, unused-ignore]
)

assert "examples" in str(exc_info.value)
Expand All @@ -113,8 +118,8 @@ def test_t2c_retriever_happy_path(
retriever = Text2CypherRetriever(
driver=driver, llm=llm, neo4j_schema=neo4j_schema, examples=examples
)
retriever.llm.invoke.return_value = LLMResponse(content=t2c_query)
retriever.driver.execute_query.return_value = ( # type: ignore
llm.invoke.return_value = LLMResponse(content=t2c_query)
driver.execute_query.return_value = (
[neo4j_record],
None,
None,
Expand All @@ -126,8 +131,8 @@ def test_t2c_retriever_happy_path(
query=query_text,
)
retriever.search(query_text=query_text)
retriever.llm.invoke.assert_called_once_with(prompt)
retriever.driver.execute_query.assert_called_once_with(query_=t2c_query) # type: ignore
llm.invoke.assert_called_once_with(prompt)
driver.execute_query.assert_called_once_with(query_=t2c_query)


@patch("neo4j_genai.retrievers.Text2CypherRetriever._verify_version")
Expand Down Expand Up @@ -178,3 +183,70 @@ def test_t2c_retriever_with_result_format_function(
],
metadata={"cypher": t2c_query, "__retriever": "Text2CypherRetriever"},
)


@pytest.mark.usefixtures("caplog")
@patch("neo4j_genai.retrievers.base.Retriever._verify_version")
def test_t2c_retriever_initialization_with_custom_prompt(
_verify_version_mock: MagicMock,
driver: MagicMock,
llm: MagicMock,
neo4j_record: MagicMock,
caplog: pytest.LogCaptureFixture,
) -> None:
prompt = "This is a custom prompt."
with caplog.at_level(logging.DEBUG):
retriever = Text2CypherRetriever(driver=driver, llm=llm, custom_prompt=prompt)
driver.execute_query.return_value = (
[neo4j_record],
None,
None,
)
retriever.search(query_text="test")

assert f"Text2CypherRetriever prompt: {prompt}" in caplog.text


@pytest.mark.usefixtures("caplog")
@patch("neo4j_genai.retrievers.base.Retriever._verify_version")
def test_t2c_retriever_initialization_with_custom_prompt_and_schema_and_examples(
_verify_version_mock: MagicMock,
driver: MagicMock,
llm: MagicMock,
neo4j_record: MagicMock,
caplog: pytest.LogCaptureFixture,
) -> None:
prompt = "This is another custom prompt."
neo4j_schema = "dummy-schema"
examples = ["example-1", "example-2"]
with caplog.at_level(logging.DEBUG):
retriever = Text2CypherRetriever(
driver=driver,
llm=llm,
custom_prompt=prompt,
neo4j_schema=neo4j_schema,
examples=examples,
)

driver.execute_query.return_value = (
[neo4j_record],
None,
None,
)
retriever.search(query_text="test")

assert f"Text2CypherRetriever prompt: {prompt}" in caplog.text


@patch("neo4j_genai.retrievers.Text2CypherRetriever._verify_version")
def test_t2c_retriever_invalid_custom_prompt_type(
_verify_version_mock: MagicMock, driver: MagicMock, llm: MagicMock
) -> None:
with pytest.raises(RetrieverInitializationError) as exc_info:
Text2CypherRetriever(
driver=driver,
llm=llm,
custom_prompt=42, # type: ignore[arg-type, unused-ignore]
)

assert "Input should be a valid string" in str(exc_info.value)