From fbe0079317835943c57ee54b50f91a728aa9e20e Mon Sep 17 00:00:00 2001 From: Will Tai Date: Tue, 9 Jul 2024 16:47:45 +0100 Subject: [PATCH 1/2] Fix init class of OpenAIEmbeddings --- examples/graphrag.py | 2 +- src/neo4j_genai/embeddings/openai.py | 13 +++++++------ 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/examples/graphrag.py b/examples/graphrag.py index 63d81b189..c6e6dbccb 100644 --- a/examples/graphrag.py +++ b/examples/graphrag.py @@ -19,7 +19,7 @@ URI = "neo4j://localhost:7687" AUTH = ("neo4j", "password") DATABASE = "neo4j" -INDEX = "moviePlotsEmbedding" +INDEX = "vector-index-name" # setup logger config diff --git a/src/neo4j_genai/embeddings/openai.py b/src/neo4j_genai/embeddings/openai.py index cf9a5d8b0..6aaf1a53b 100644 --- a/src/neo4j_genai/embeddings/openai.py +++ b/src/neo4j_genai/embeddings/openai.py @@ -6,7 +6,7 @@ class OpenAIEmbeddings(Embedder): - def __init__(self, *args: Any, **kwargs: Any) -> None: + def __init__(self, model: str = "text-embedding-ada-002") -> None: try: import openai except ImportError: @@ -15,10 +15,11 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: "Please install it with `pip install openai`." 
) - self.model = openai.OpenAI(*args, **kwargs) + self.openai_model = openai.OpenAI() + self.model = model - def embed_query( - self, text: str, model: str = "text-embedding-ada-002", **kwargs: Any - ) -> list[float]: - response = self.model.embeddings.create(input=text, model=model, **kwargs) + def embed_query(self, text: str, **kwargs: Any) -> list[float]: + response = self.openai_model.embeddings.create( + input=text, model=self.model, **kwargs + ) return response.data[0].embedding From 9dd967c4c79cc0bbf07d454b56199d3f4c182e01 Mon Sep 17 00:00:00 2001 From: Will Tai Date: Tue, 9 Jul 2024 16:57:28 +0100 Subject: [PATCH 2/2] Removed sentence_transformers from init --- CHANGELOG.md | 4 ++++ docs/source/user_guide.rst | 2 +- examples/graphrag.py | 2 +- src/neo4j_genai/embeddings/__init__.py | 19 ++++++++++++++----- src/neo4j_genai/embeddings/openai.py | 15 +++++++++++++++ .../embeddings/sentence_transformers.py | 15 +++++++++++++++ .../embeddings/test_sentence_transformers.py | 2 +- 7 files changed, 51 insertions(+), 8 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 10c1766dd..5a33e9df2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,10 @@ ## Next +### Fixed +- Corrected `OpenAIEmbeddings` initialization to allow specifying the embedding model name. +- Removed sentence_transformers from embeddings/__init__.py to avoid ImportError when the package is not installed. + ## 0.3.0 ### Added diff --git a/docs/source/user_guide.rst b/docs/source/user_guide.rst index 5eb5ef5bd..cdc3c8d9a 100644 --- a/docs/source/user_guide.rst +++ b/docs/source/user_guide.rst @@ -266,7 +266,7 @@ The `OpenAIEmbedder` was illustrated previously. Here is how to use the `Sentenc .. 
code:: python - from neo4j_genai.embeddings import SentenceTransformerEmbeddings + from neo4j_genai.embeddings.sentence_transformers import SentenceTransformerEmbeddings embedder = SentenceTransformerEmbeddings(model="all-MiniLM-L6-v2") # Note: this is the default model diff --git a/examples/graphrag.py b/examples/graphrag.py index c6e6dbccb..63d81b189 100644 --- a/examples/graphrag.py +++ b/examples/graphrag.py @@ -19,7 +19,7 @@ URI = "neo4j://localhost:7687" AUTH = ("neo4j", "password") DATABASE = "neo4j" -INDEX = "vector-index-name" +INDEX = "moviePlotsEmbedding" # setup logger config diff --git a/src/neo4j_genai/embeddings/__init__.py b/src/neo4j_genai/embeddings/__init__.py index 29442f9c2..c0199c144 100644 --- a/src/neo4j_genai/embeddings/__init__.py +++ b/src/neo4j_genai/embeddings/__init__.py @@ -1,5 +1,14 @@ -from .sentence_transformers import SentenceTransformerEmbeddings - -__all__ = [ - "SentenceTransformerEmbeddings", -] +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/src/neo4j_genai/embeddings/openai.py b/src/neo4j_genai/embeddings/openai.py index 6aaf1a53b..46ff09341 100644 --- a/src/neo4j_genai/embeddings/openai.py +++ b/src/neo4j_genai/embeddings/openai.py @@ -1,3 +1,18 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from __future__ import annotations from typing import Any diff --git a/src/neo4j_genai/embeddings/sentence_transformers.py b/src/neo4j_genai/embeddings/sentence_transformers.py index 8c2290d2b..4e5835edb 100644 --- a/src/neo4j_genai/embeddings/sentence_transformers.py +++ b/src/neo4j_genai/embeddings/sentence_transformers.py @@ -1,3 +1,18 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ from typing import Any import numpy as np diff --git a/tests/unit/embeddings/test_sentence_transformers.py b/tests/unit/embeddings/test_sentence_transformers.py index f9c0b12e7..c7f08fc14 100644 --- a/tests/unit/embeddings/test_sentence_transformers.py +++ b/tests/unit/embeddings/test_sentence_transformers.py @@ -3,7 +3,7 @@ import numpy as np import pytest from neo4j_genai.embedder import Embedder -from neo4j_genai.embeddings import SentenceTransformerEmbeddings +from neo4j_genai.embeddings.sentence_transformers import SentenceTransformerEmbeddings @patch("sentence_transformers.SentenceTransformer")