From ac229903150e68907d34aacc544cd32ba3199ac6 Mon Sep 17 00:00:00 2001 From: alex <73862627+a-s-g93@users.noreply.github.com> Date: Mon, 29 Jul 2024 14:42:31 -0500 Subject: [PATCH 01/10] Update CHANGELOG.md --- CHANGELOG.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2ab9e77ca..3b079dc34 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,7 +1,8 @@ # @neo4j/neo4j-genai-python ## Next - +- Add optional custom_prompt arg to the Text2CypherRetriever class. + ## 0.3.1 ### Fixed From 9cc46617efc8eb900a81dc9680ca1b360cca886f Mon Sep 17 00:00:00 2001 From: alex <73862627+a-s-g93@users.noreply.github.com> Date: Mon, 29 Jul 2024 14:42:50 -0500 Subject: [PATCH 02/10] Update CHANGELOG.md --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3b079dc34..b9ecbf53c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,8 @@ # @neo4j/neo4j-genai-python ## Next + +### Added - Add optional custom_prompt arg to the Text2CypherRetriever class. ## 0.3.1 From bc19cd916d95d32c2a989d5eecb3d0df14249a17 Mon Sep 17 00:00:00 2001 From: alex Date: Mon, 29 Jul 2024 14:29:27 -0500 Subject: [PATCH 03/10] add custom prompt option to text2cypher, tested --- src/neo4j_genai/retrievers/text2cypher.py | 34 +++++++++----- src/neo4j_genai/types.py | 1 + tests/unit/retrievers/test_text2cypher.py | 54 +++++++++++++++++++++++ 3 files changed, 78 insertions(+), 11 deletions(-) diff --git a/src/neo4j_genai/retrievers/text2cypher.py b/src/neo4j_genai/retrievers/text2cypher.py index 3384c5979..71a8281fc 100644 --- a/src/neo4j_genai/retrievers/text2cypher.py +++ b/src/neo4j_genai/retrievers/text2cypher.py @@ -55,6 +55,7 @@ class Text2CypherRetriever(Retriever): llm (neo4j_genai.generation.llm.LLMInterface): LLM object to generate the Cypher query. neo4j_schema (Optional[str]): Neo4j schema used to generate the Cypher query. 
examples (Optional[list[str], optional): Optional user input/query pairs for the LLM to use as examples. + custom_prompt (Optional[str]): Optional custom prompt to use instead of the auto-generated prompt. When provided, the neo4j_schema and examples args are not included in the prompt. Raises: RetrieverInitializationError: If validation of the input arguments fail. @@ -69,6 +70,7 @@ def __init__( result_formatter: Optional[ Callable[[neo4j.Record], RetrieverResultItem] ] = None, + custom_prompt: Optional[str] = None, ) -> None: try: driver_model = Neo4jDriverModel(driver=driver) @@ -82,6 +84,7 @@ def __init__( neo4j_schema_model=neo4j_schema_model, examples=examples, result_formatter=result_formatter, + custom_prompt=custom_prompt, ) except ValidationError as e: raise RetrieverInitializationError(e.errors()) from e @@ -90,12 +93,17 @@ def __init__( self.llm = validated_data.llm_model.llm self.examples = validated_data.examples self.result_formatter = validated_data.result_formatter + self.custom_prompt = validated_data.custom_prompt try: - self.neo4j_schema = ( - validated_data.neo4j_schema_model.neo4j_schema - if validated_data.neo4j_schema_model - else get_schema(validated_data.driver_model.driver) - ) + if ( + not validated_data.custom_prompt + ): # don't need schema for a custom prompt + self.neo4j_schema = ( + validated_data.neo4j_schema_model.neo4j_schema + if validated_data.neo4j_schema_model + else get_schema(validated_data.driver_model.driver) + ) + except (Neo4jError, DriverError) as e: error_message = getattr(e, "message", str(e)) raise SchemaFetchError( @@ -124,12 +132,16 @@ def get_search_results( except ValidationError as e: raise SearchValidationError(e.errors()) from e - prompt_template = Text2CypherTemplate() - prompt = prompt_template.format( - schema=self.neo4j_schema, - examples="\n".join(self.examples) if self.examples else "", - query=validated_data.query_text, - ) + if not self.custom_prompt: + prompt_template = Text2CypherTemplate() + prompt = 
prompt_template.format( + schema=self.neo4j_schema, + examples="\n".join(self.examples) if self.examples else "", + query=validated_data.query_text, + ) + else: + prompt = self.custom_prompt + logger.debug("Text2CypherRetriever prompt: %s", prompt) try: diff --git a/src/neo4j_genai/types.py b/src/neo4j_genai/types.py index 45826dd5a..1f4888bf9 100644 --- a/src/neo4j_genai/types.py +++ b/src/neo4j_genai/types.py @@ -240,3 +240,4 @@ class Text2CypherRetrieverModel(BaseModel): neo4j_schema_model: Optional[Neo4jSchemaModel] = None examples: Optional[list[str]] = None result_formatter: Optional[Callable[[neo4j.Record], RetrieverResultItem]] = None + custom_prompt: Optional[str] = None diff --git a/tests/unit/retrievers/test_text2cypher.py b/tests/unit/retrievers/test_text2cypher.py index 17a217e14..14ec67111 100644 --- a/tests/unit/retrievers/test_text2cypher.py +++ b/tests/unit/retrievers/test_text2cypher.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging from unittest.mock import MagicMock, patch import pytest @@ -178,3 +179,56 @@ def test_t2c_retriever_with_result_format_function( ], metadata={"cypher": t2c_query, "__retriever": "Text2CypherRetriever"}, ) + + +@pytest.mark.usefixtures("caplog") +@patch("neo4j_genai.retrievers.base.Retriever._verify_version") +def test_t2c_retriever_initialization_with_custom_prompt( + _verify_version_mock: MagicMock, + driver: MagicMock, + llm: MagicMock, + neo4j_record: MagicMock, + caplog: pytest.LogCaptureFixture, +) -> None: + prompt = "This is a custom prompt." 
+ with caplog.at_level(logging.DEBUG): + retriever = Text2CypherRetriever(driver=driver, llm=llm, custom_prompt=prompt) + retriever.driver.execute_query.return_value = ( # type: ignore + [neo4j_record], + None, + None, + ) + retriever.search(query_text="test") + + assert f"Text2CypherRetriever prompt: {prompt}" in caplog.text + + +@pytest.mark.usefixtures("caplog") +@patch("neo4j_genai.retrievers.base.Retriever._verify_version") +def test_t2c_retriever_initialization_with_custom_prompt_and_schema_and_examples( + _verify_version_mock: MagicMock, + driver: MagicMock, + llm: MagicMock, + neo4j_record: MagicMock, + caplog: pytest.LogCaptureFixture, +) -> None: + prompt = "This is another custom prompt." + neo4j_schema = "dummy-schema" + examples = ["example-1", "example-2"] + with caplog.at_level(logging.DEBUG): + retriever = Text2CypherRetriever( + driver=driver, + llm=llm, + custom_prompt=prompt, + neo4j_schema=neo4j_schema, + examples=examples, + ) + + retriever.driver.execute_query.return_value = ( # type: ignore + [neo4j_record], + None, + None, + ) + retriever.search(query_text="test") + + assert f"Text2CypherRetriever prompt: {prompt}" in caplog.text From f453cb3bfca43e3f1cd9fc74803a603ce906cd62 Mon Sep 17 00:00:00 2001 From: alex Date: Thu, 1 Aug 2024 10:30:04 -0500 Subject: [PATCH 04/10] add test for RetrieverInitializationError --- tests/unit/retrievers/test_text2cypher.py | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/tests/unit/retrievers/test_text2cypher.py b/tests/unit/retrievers/test_text2cypher.py index 14ec67111..4c3837025 100644 --- a/tests/unit/retrievers/test_text2cypher.py +++ b/tests/unit/retrievers/test_text2cypher.py @@ -64,7 +64,7 @@ def test_t2c_retriever_invalid_neo4j_schema( _verify_version_mock: MagicMock, driver: MagicMock, llm: MagicMock ) -> None: with pytest.raises(RetrieverInitializationError) as exc_info: - Text2CypherRetriever(driver=driver, llm=llm, neo4j_schema=42) # type: ignore + 
Text2CypherRetriever(driver=driver, llm=llm, neo4j_schema=42) assert "neo4j_schema" in str(exc_info.value) assert "Input should be a valid string" in str(exc_info.value) @@ -93,7 +93,7 @@ def test_t2c_retriever_invalid_search_examples( driver=driver, llm=llm, neo4j_schema="dummy-text", - examples=42, # type: ignore + examples=42, ) assert "examples" in str(exc_info.value) @@ -115,7 +115,7 @@ def test_t2c_retriever_happy_path( driver=driver, llm=llm, neo4j_schema=neo4j_schema, examples=examples ) retriever.llm.invoke.return_value = LLMResponse(content=t2c_query) - retriever.driver.execute_query.return_value = ( # type: ignore + retriever.driver.execute_query.return_value = ( [neo4j_record], None, None, @@ -128,7 +128,7 @@ def test_t2c_retriever_happy_path( ) retriever.search(query_text=query_text) retriever.llm.invoke.assert_called_once_with(prompt) - retriever.driver.execute_query.assert_called_once_with(query_=t2c_query) # type: ignore + retriever.driver.execute_query.assert_called_once_with(query_=t2c_query) @patch("neo4j_genai.retrievers.Text2CypherRetriever._verify_version") @@ -193,7 +193,7 @@ def test_t2c_retriever_initialization_with_custom_prompt( prompt = "This is a custom prompt." 
with caplog.at_level(logging.DEBUG): retriever = Text2CypherRetriever(driver=driver, llm=llm, custom_prompt=prompt) - retriever.driver.execute_query.return_value = ( # type: ignore + retriever.driver.execute_query.return_value = ( [neo4j_record], None, None, @@ -224,7 +224,7 @@ def test_t2c_retriever_initialization_with_custom_prompt_and_schema_and_examples examples=examples, ) - retriever.driver.execute_query.return_value = ( # type: ignore + retriever.driver.execute_query.return_value = ( [neo4j_record], None, None, @@ -232,3 +232,13 @@ def test_t2c_retriever_initialization_with_custom_prompt_and_schema_and_examples retriever.search(query_text="test") assert f"Text2CypherRetriever prompt: {prompt}" in caplog.text + + +@patch("neo4j_genai.retrievers.Text2CypherRetriever._verify_version") +def test_t2c_retriever_invalid_custom_prompt_type( + _verify_version_mock: MagicMock, driver: MagicMock, llm: MagicMock +) -> None: + with pytest.raises(RetrieverInitializationError) as exc_info: + Text2CypherRetriever(driver=driver, llm=llm, custom_prompt=42) + + assert "Input should be a valid string" in str(exc_info.value) From 82fb9c61470ddbfdee7b2a439eb29aca285f0522 Mon Sep 17 00:00:00 2001 From: alex <73862627+a-s-g93@users.noreply.github.com> Date: Fri, 2 Aug 2024 08:42:08 -0500 Subject: [PATCH 05/10] Update tests/unit/retrievers/test_text2cypher.py Co-authored-by: Alex Thomas --- tests/unit/retrievers/test_text2cypher.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/retrievers/test_text2cypher.py b/tests/unit/retrievers/test_text2cypher.py index 4c3837025..db900012c 100644 --- a/tests/unit/retrievers/test_text2cypher.py +++ b/tests/unit/retrievers/test_text2cypher.py @@ -115,7 +115,7 @@ def test_t2c_retriever_happy_path( driver=driver, llm=llm, neo4j_schema=neo4j_schema, examples=examples ) retriever.llm.invoke.return_value = LLMResponse(content=t2c_query) - retriever.driver.execute_query.return_value = ( + 
retriever.driver.execute_query.return_value = ( # type: ignore [neo4j_record], None, None, From 7b09a143d1f302921720b03cf9c73d0ee5fb46b3 Mon Sep 17 00:00:00 2001 From: alex <73862627+a-s-g93@users.noreply.github.com> Date: Fri, 2 Aug 2024 08:42:13 -0500 Subject: [PATCH 06/10] Update tests/unit/retrievers/test_text2cypher.py Co-authored-by: Alex Thomas --- tests/unit/retrievers/test_text2cypher.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/retrievers/test_text2cypher.py b/tests/unit/retrievers/test_text2cypher.py index db900012c..d8b061299 100644 --- a/tests/unit/retrievers/test_text2cypher.py +++ b/tests/unit/retrievers/test_text2cypher.py @@ -128,7 +128,7 @@ def test_t2c_retriever_happy_path( ) retriever.search(query_text=query_text) retriever.llm.invoke.assert_called_once_with(prompt) - retriever.driver.execute_query.assert_called_once_with(query_=t2c_query) + retriever.driver.execute_query.assert_called_once_with(query_=t2c_query) # type: ignore @patch("neo4j_genai.retrievers.Text2CypherRetriever._verify_version") From 3a9b80d6e3e88930b78eb80c75fb763be2328c11 Mon Sep 17 00:00:00 2001 From: alex <73862627+a-s-g93@users.noreply.github.com> Date: Fri, 2 Aug 2024 08:42:20 -0500 Subject: [PATCH 07/10] Update tests/unit/retrievers/test_text2cypher.py Co-authored-by: Alex Thomas --- tests/unit/retrievers/test_text2cypher.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/retrievers/test_text2cypher.py b/tests/unit/retrievers/test_text2cypher.py index d8b061299..820d72f8a 100644 --- a/tests/unit/retrievers/test_text2cypher.py +++ b/tests/unit/retrievers/test_text2cypher.py @@ -239,6 +239,6 @@ def test_t2c_retriever_invalid_custom_prompt_type( _verify_version_mock: MagicMock, driver: MagicMock, llm: MagicMock ) -> None: with pytest.raises(RetrieverInitializationError) as exc_info: - Text2CypherRetriever(driver=driver, llm=llm, custom_prompt=42) + Text2CypherRetriever(driver=driver, llm=llm, custom_prompt=42) 
# type: ignore assert "Input should be a valid string" in str(exc_info.value) From 1e520ac2d7d74ef340ef7c9fb0ac302e191a32c6 Mon Sep 17 00:00:00 2001 From: alex <73862627+a-s-g93@users.noreply.github.com> Date: Fri, 2 Aug 2024 08:42:27 -0500 Subject: [PATCH 08/10] Update tests/unit/retrievers/test_text2cypher.py Co-authored-by: Alex Thomas --- tests/unit/retrievers/test_text2cypher.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/retrievers/test_text2cypher.py b/tests/unit/retrievers/test_text2cypher.py index 820d72f8a..75e867c89 100644 --- a/tests/unit/retrievers/test_text2cypher.py +++ b/tests/unit/retrievers/test_text2cypher.py @@ -93,7 +93,7 @@ def test_t2c_retriever_invalid_search_examples( driver=driver, llm=llm, neo4j_schema="dummy-text", - examples=42, + examples=42, # type: ignore ) assert "examples" in str(exc_info.value) From 6b8b39252ba457a957c3cedb026d2783113720f0 Mon Sep 17 00:00:00 2001 From: alex <73862627+a-s-g93@users.noreply.github.com> Date: Fri, 2 Aug 2024 08:42:32 -0500 Subject: [PATCH 09/10] Update tests/unit/retrievers/test_text2cypher.py Co-authored-by: Alex Thomas --- tests/unit/retrievers/test_text2cypher.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/retrievers/test_text2cypher.py b/tests/unit/retrievers/test_text2cypher.py index 75e867c89..03391621e 100644 --- a/tests/unit/retrievers/test_text2cypher.py +++ b/tests/unit/retrievers/test_text2cypher.py @@ -64,7 +64,7 @@ def test_t2c_retriever_invalid_neo4j_schema( _verify_version_mock: MagicMock, driver: MagicMock, llm: MagicMock ) -> None: with pytest.raises(RetrieverInitializationError) as exc_info: - Text2CypherRetriever(driver=driver, llm=llm, neo4j_schema=42) + Text2CypherRetriever(driver=driver, llm=llm, neo4j_schema=42) # type: ignore assert "neo4j_schema" in str(exc_info.value) assert "Input should be a valid string" in str(exc_info.value) From 3cd464406206d6f74a18360ecef448ef6f38e8fd Mon Sep 17 00:00:00 2001 From: 
Alex Thomas Date: Wed, 7 Aug 2024 16:03:49 +0100 Subject: [PATCH 10/10] Pre-commit fix --- tests/unit/retrievers/test_text2cypher.py | 26 +++++++++++++++-------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/tests/unit/retrievers/test_text2cypher.py b/tests/unit/retrievers/test_text2cypher.py index 03391621e..4bd7cf02e 100644 --- a/tests/unit/retrievers/test_text2cypher.py +++ b/tests/unit/retrievers/test_text2cypher.py @@ -64,7 +64,11 @@ def test_t2c_retriever_invalid_neo4j_schema( _verify_version_mock: MagicMock, driver: MagicMock, llm: MagicMock ) -> None: with pytest.raises(RetrieverInitializationError) as exc_info: - Text2CypherRetriever(driver=driver, llm=llm, neo4j_schema=42) # type: ignore + Text2CypherRetriever( + driver=driver, + llm=llm, + neo4j_schema=42, # type: ignore[arg-type, unused-ignore] + ) assert "neo4j_schema" in str(exc_info.value) assert "Input should be a valid string" in str(exc_info.value) @@ -93,7 +97,7 @@ def test_t2c_retriever_invalid_search_examples( driver=driver, llm=llm, neo4j_schema="dummy-text", - examples=42, # type: ignore + examples=42, # type: ignore[arg-type, unused-ignore] ) assert "examples" in str(exc_info.value) @@ -114,8 +118,8 @@ def test_t2c_retriever_happy_path( retriever = Text2CypherRetriever( driver=driver, llm=llm, neo4j_schema=neo4j_schema, examples=examples ) - retriever.llm.invoke.return_value = LLMResponse(content=t2c_query) - retriever.driver.execute_query.return_value = ( # type: ignore + llm.invoke.return_value = LLMResponse(content=t2c_query) + driver.execute_query.return_value = ( [neo4j_record], None, None, @@ -127,8 +131,8 @@ def test_t2c_retriever_happy_path( query=query_text, ) retriever.search(query_text=query_text) - retriever.llm.invoke.assert_called_once_with(prompt) - retriever.driver.execute_query.assert_called_once_with(query_=t2c_query) # type: ignore + llm.invoke.assert_called_once_with(prompt) + driver.execute_query.assert_called_once_with(query_=t2c_query) 
@patch("neo4j_genai.retrievers.Text2CypherRetriever._verify_version") @@ -193,7 +197,7 @@ def test_t2c_retriever_initialization_with_custom_prompt( prompt = "This is a custom prompt." with caplog.at_level(logging.DEBUG): retriever = Text2CypherRetriever(driver=driver, llm=llm, custom_prompt=prompt) - retriever.driver.execute_query.return_value = ( + driver.execute_query.return_value = ( [neo4j_record], None, None, @@ -224,7 +228,7 @@ def test_t2c_retriever_initialization_with_custom_prompt_and_schema_and_examples examples=examples, ) - retriever.driver.execute_query.return_value = ( + driver.execute_query.return_value = ( [neo4j_record], None, None, @@ -239,6 +243,10 @@ def test_t2c_retriever_invalid_custom_prompt_type( _verify_version_mock: MagicMock, driver: MagicMock, llm: MagicMock ) -> None: with pytest.raises(RetrieverInitializationError) as exc_info: - Text2CypherRetriever(driver=driver, llm=llm, custom_prompt=42) # type: ignore + Text2CypherRetriever( + driver=driver, + llm=llm, + custom_prompt=42, # type: ignore[arg-type, unused-ignore] + ) assert "Input should be a valid string" in str(exc_info.value)