diff --git a/CHANGELOG.md b/CHANGELOG.md index 2ab9e77ca..b9ecbf53c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,9 @@ ## Next +### Added +- Add optional custom_prompt arg to the Text2CypherRetriever class. + ## 0.3.1 ### Fixed diff --git a/src/neo4j_genai/retrievers/text2cypher.py b/src/neo4j_genai/retrievers/text2cypher.py index 3384c5979..71a8281fc 100644 --- a/src/neo4j_genai/retrievers/text2cypher.py +++ b/src/neo4j_genai/retrievers/text2cypher.py @@ -55,6 +55,7 @@ class Text2CypherRetriever(Retriever): llm (neo4j_genai.generation.llm.LLMInterface): LLM object to generate the Cypher query. neo4j_schema (Optional[str]): Neo4j schema used to generate the Cypher query. examples (Optional[list[str], optional): Optional user input/query pairs for the LLM to use as examples. + custom_prompt (Optional[str]): Optional custom prompt to use instead of auto generated prompt. Will not include the neo4j_schema or examples args, if provided. Raises: RetrieverInitializationError: If validation of the input arguments fail. @@ -69,6 +70,7 @@ def __init__( result_formatter: Optional[ Callable[[neo4j.Record], RetrieverResultItem] ] = None, + custom_prompt: Optional[str] = None, ) -> None: try: driver_model = Neo4jDriverModel(driver=driver) @@ -82,6 +84,7 @@ def __init__( neo4j_schema_model=neo4j_schema_model, examples=examples, result_formatter=result_formatter, + custom_prompt=custom_prompt, ) except ValidationError as e: raise RetrieverInitializationError(e.errors()) from e @@ -90,12 +93,17 @@ def __init__( self.llm = validated_data.llm_model.llm self.examples = validated_data.examples self.result_formatter = validated_data.result_formatter + self.custom_prompt = validated_data.custom_prompt try: - self.neo4j_schema = ( - validated_data.neo4j_schema_model.neo4j_schema - if validated_data.neo4j_schema_model - else get_schema(validated_data.driver_model.driver) - ) + if ( + not validated_data.custom_prompt + ): # don't need schema for a custom prompt + self.neo4j_schema = ( + validated_data.neo4j_schema_model.neo4j_schema + if validated_data.neo4j_schema_model + else get_schema(validated_data.driver_model.driver) + ) + except (Neo4jError, DriverError) as e: error_message = getattr(e, "message", str(e)) raise SchemaFetchError( @@ -124,12 +132,16 @@ def get_search_results( except ValidationError as e: raise SearchValidationError(e.errors()) from e - prompt_template = Text2CypherTemplate() - prompt = prompt_template.format( - schema=self.neo4j_schema, - examples="\n".join(self.examples) if self.examples else "", - query=validated_data.query_text, - ) + if not self.custom_prompt: + prompt_template = Text2CypherTemplate() + prompt = prompt_template.format( + schema=self.neo4j_schema, + examples="\n".join(self.examples) if self.examples else "", + query=validated_data.query_text, + ) + else: + prompt = self.custom_prompt + logger.debug("Text2CypherRetriever prompt: %s", prompt) try: diff --git a/src/neo4j_genai/types.py b/src/neo4j_genai/types.py index 45826dd5a..1f4888bf9 100644 --- a/src/neo4j_genai/types.py +++ b/src/neo4j_genai/types.py @@ -240,3 +240,4 @@ class Text2CypherRetrieverModel(BaseModel): neo4j_schema_model: Optional[Neo4jSchemaModel] = None examples: Optional[list[str]] = None result_formatter: Optional[Callable[[neo4j.Record], RetrieverResultItem]] = None + custom_prompt: Optional[str] = None diff --git a/tests/unit/retrievers/test_text2cypher.py b/tests/unit/retrievers/test_text2cypher.py index 17a217e14..4bd7cf02e 100644 --- a/tests/unit/retrievers/test_text2cypher.py +++ b/tests/unit/retrievers/test_text2cypher.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging from unittest.mock import MagicMock, patch import pytest @@ -63,7 +64,11 @@ def test_t2c_retriever_invalid_neo4j_schema( _verify_version_mock: MagicMock, driver: MagicMock, llm: MagicMock ) -> None: with pytest.raises(RetrieverInitializationError) as exc_info: - Text2CypherRetriever(driver=driver, llm=llm, neo4j_schema=42) # type: ignore + Text2CypherRetriever( + driver=driver, + llm=llm, + neo4j_schema=42, # type: ignore[arg-type, unused-ignore] + ) assert "neo4j_schema" in str(exc_info.value) assert "Input should be a valid string" in str(exc_info.value) @@ -92,7 +97,7 @@ def test_t2c_retriever_invalid_search_examples( driver=driver, llm=llm, neo4j_schema="dummy-text", - examples=42, # type: ignore + examples=42, # type: ignore[arg-type, unused-ignore] ) assert "examples" in str(exc_info.value) @@ -113,8 +118,8 @@ def test_t2c_retriever_happy_path( retriever = Text2CypherRetriever( driver=driver, llm=llm, neo4j_schema=neo4j_schema, examples=examples ) - retriever.llm.invoke.return_value = LLMResponse(content=t2c_query) - retriever.driver.execute_query.return_value = ( # type: ignore + llm.invoke.return_value = LLMResponse(content=t2c_query) + driver.execute_query.return_value = ( [neo4j_record], None, None, @@ -126,8 +131,8 @@ def test_t2c_retriever_happy_path( query=query_text, ) retriever.search(query_text=query_text) - retriever.llm.invoke.assert_called_once_with(prompt) - retriever.driver.execute_query.assert_called_once_with(query_=t2c_query) # type: ignore + llm.invoke.assert_called_once_with(prompt) + driver.execute_query.assert_called_once_with(query_=t2c_query) @patch("neo4j_genai.retrievers.Text2CypherRetriever._verify_version") @@ -178,3 +183,70 @@ def test_t2c_retriever_with_result_format_function( ], metadata={"cypher": t2c_query, "__retriever": "Text2CypherRetriever"}, ) + + +@pytest.mark.usefixtures("caplog") +@patch("neo4j_genai.retrievers.base.Retriever._verify_version") +def test_t2c_retriever_initialization_with_custom_prompt( + _verify_version_mock: MagicMock, + driver: MagicMock, + llm: MagicMock, + neo4j_record: MagicMock, + caplog: pytest.LogCaptureFixture, +) -> None: + prompt = "This is a custom prompt." + with caplog.at_level(logging.DEBUG): + retriever = Text2CypherRetriever(driver=driver, llm=llm, custom_prompt=prompt) + driver.execute_query.return_value = ( + [neo4j_record], + None, + None, + ) + retriever.search(query_text="test") + + assert f"Text2CypherRetriever prompt: {prompt}" in caplog.text + + +@pytest.mark.usefixtures("caplog") +@patch("neo4j_genai.retrievers.base.Retriever._verify_version") +def test_t2c_retriever_initialization_with_custom_prompt_and_schema_and_examples( + _verify_version_mock: MagicMock, + driver: MagicMock, + llm: MagicMock, + neo4j_record: MagicMock, + caplog: pytest.LogCaptureFixture, +) -> None: + prompt = "This is another custom prompt." + neo4j_schema = "dummy-schema" + examples = ["example-1", "example-2"] + with caplog.at_level(logging.DEBUG): + retriever = Text2CypherRetriever( + driver=driver, + llm=llm, + custom_prompt=prompt, + neo4j_schema=neo4j_schema, + examples=examples, + ) + + driver.execute_query.return_value = ( + [neo4j_record], + None, + None, + ) + retriever.search(query_text="test") + + assert f"Text2CypherRetriever prompt: {prompt}" in caplog.text + + +@patch("neo4j_genai.retrievers.Text2CypherRetriever._verify_version") +def test_t2c_retriever_invalid_custom_prompt_type( + _verify_version_mock: MagicMock, driver: MagicMock, llm: MagicMock +) -> None: + with pytest.raises(RetrieverInitializationError) as exc_info: + Text2CypherRetriever( + driver=driver, + llm=llm, + custom_prompt=42, # type: ignore[arg-type, unused-ignore] + ) + + assert "Input should be a valid string" in str(exc_info.value)