Skip to content

Commit e3ac4fe

Browse files
committed
Mypy
1 parent f7ef5df commit e3ac4fe

File tree

5 files changed

+10
-6
lines changed

5 files changed

+10
-6
lines changed

src/neo4j_graphrag/embeddings/base.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515
from __future__ import annotations
16+
from typing import Any
1617

1718
from abc import ABC, abstractmethod
1819

@@ -25,7 +26,7 @@ class Embedder(ABC):
2526

2627
@abstractmethod
2728
def embed_query(
28-
self, text: str, dimensions: int | None = None, **kwargs
29+
self, text: str, dimensions: int | None = None, **kwargs: Any
2930
) -> list[float]:
3031
"""Embed query text.
3132

src/neo4j_graphrag/embeddings/ollama.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def __init__(self, model: str, **kwargs: Any) -> None:
4141
self.model = model
4242
self.client = ollama.Client(**kwargs)
4343

44-
def embed_query(self, text: str, **kwargs: Any) -> list[float]:
44+
def embed_query(self, text: str, **kwargs: Any) -> list[float]: # type: ignore[override]
4545
"""
4646
Generate embeddings for a given query using an Ollama text embedding model.
4747

src/neo4j_graphrag/embeddings/openai.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
import abc
1919
from typing import TYPE_CHECKING, Any
2020

21+
from openai import NotGiven
22+
2123
from neo4j_graphrag.embeddings.base import Embedder
2224

2325
if TYPE_CHECKING:
@@ -63,8 +65,9 @@ def embed_query(
6365
6466
**kwargs (Any): Additional arguments to pass to the OpenAI embedding generation function.
6567
"""
68+
d = dimensions or NotGiven()
6669
response = self.client.embeddings.create(
67-
input=text, model=self.model, dimensions=dimensions, **kwargs
70+
input=text, model=self.model, dimensions=d, **kwargs
6871
)
6972
embedding: list[float] = response.data[0].embedding
7073
return embedding

src/neo4j_graphrag/embeddings/sentence_transformers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def __init__(
3535
self.np = np
3636
self.model = sentence_transformers.SentenceTransformer(model, *args, **kwargs)
3737

38-
def embed_query(self, text: str, **kwargs) -> Any:
38+
def embed_query(self, text: str, **kwargs: Any) -> Any: # type: ignore[override]
3939
result = self.model.encode([text])
4040
if isinstance(result, self.torch.Tensor) or isinstance(result, self.np.ndarray):
4141
return result.flatten().tolist()

src/neo4j_graphrag/embeddings/vertexai.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def __init__(self, model: str = "text-embedding-004") -> None:
4242
)
4343
self.model = TextEmbeddingModel.from_pretrained(model)
4444

45-
def embed_query(
45+
def embed_query( # type: ignore[override]
4646
self,
4747
text: str,
4848
task_type: str = "RETRIEVAL_QUERY",
@@ -62,4 +62,4 @@ def embed_query(
6262
embeddings = self.model.get_embeddings(
6363
inputs, output_dimensionality=dimensions, **kwargs
6464
)
65-
return embeddings[0].values # type: ignore
65+
return embeddings[0].values

0 commit comments

Comments
 (0)