
Commit f7ef5df

WIP: expose dimensions in the Embedders
1 parent 2d7abb8 commit f7ef5df

11 files changed: +73, -26 lines changed

examples/customize/embeddings/cohere_embeddings.py

Lines changed: 8 additions & 2 deletions

@@ -4,8 +4,14 @@
 api_key = None

 embeder = CohereEmbeddings(
-    model="embed-english-v3.0",
+    model="embed-v4.0",
     api_key=api_key,
 )
-res = embeder.embed_query("my question")
+res = embeder.embed_query(
+    "my question",
+    # optionally, set output dimensions if it's supported by the model
+    dimensions=256,
+    input_type="search_query",
+)
+print("Embedding dimensions", len(res))
 print(res[:10])
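
The same CohereEmbeddings instance can also embed document text. The sketch below is not part of this commit; it assumes Cohere's "search_document" input type, forwarded to the client through **kwargs exactly like the input_type="search_query" argument above.

# Sketch only: embed a document chunk with the embedder configured above.
doc_embedding = embeder.embed_query(
    "Neo4j is a graph database.",
    dimensions=256,
    input_type="search_document",
)
print("Embedding dimensions", len(doc_embedding))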

examples/customize/embeddings/custom_embeddings.py

Lines changed: 1 addition & 1 deletion

@@ -8,7 +8,7 @@ class CustomEmbeddings(Embedder):
     def __init__(self, dimension: int = 10, **kwargs: Any):
         self.dimension = dimension

-    def embed_query(self, input: str) -> list[float]:
+    def embed_query(self, input: str, **kwargs) -> list[float]:
         return [random.random() for _ in range(self.dimension)]

examples/customize/embeddings/mistalai_embeddings.py

Lines changed: 6 additions & 1 deletion

@@ -8,5 +8,10 @@
 api_key = None

 embeder = MistralAIEmbeddings(model="mistral-embed", api_key=api_key)
-res = embeder.embed_query("my question")
+res = embeder.embed_query(
+    "my question",
+    # optionally, set output dimensions
+    dimensions=256,
+)
+print("Embedding dimensions", len(res))
 print(res[:10])

examples/customize/embeddings/openai_embeddings.py

Lines changed: 8 additions & 2 deletions

@@ -7,6 +7,12 @@
 # set api key here on in the OPENAI_API_KEY env var
 api_key = None

-embeder = OpenAIEmbeddings(model="text-embedding-ada-002", api_key=api_key)
-res = embeder.embed_query("my question")
+embeder = OpenAIEmbeddings(model="text-embedding-3-small", api_key=api_key)
+res = embeder.embed_query(
+    "my question",
+    # optionally, set output dimensions
+    # dimensions=256,
+)
+
+print("Embedding dimensions", len(res))
 print(res[:10])
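
For reference, enabling the commented-out argument shrinks the output. The sketch below is not part of the commit and assumes text-embedding-3-small honours the dimensions parameter, which OpenAI documents for the text-embedding-3 family (native size 1536).

# Sketch: request a reduced output size from the same embedder.
res_small = embeder.embed_query("my question", dimensions=256)
print("Embedding dimensions", len(res_small))  # expected 256 rather than the native 1536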

examples/customize/embeddings/vertexai_embeddings.py

Lines changed: 5 additions & 1 deletion

@@ -5,5 +5,9 @@
 from neo4j_graphrag.embeddings import VertexAIEmbeddings

 embeder = VertexAIEmbeddings(model="text-embedding-005")
-res = embeder.embed_query("my question")
+res = embeder.embed_query(
+    "my question",
+    dimensions=256,
+)
+print("Embedding dimensions", len(res))
 print(res[:10])

src/neo4j_graphrag/embeddings/base.py

Lines changed: 4 additions & 1 deletion

@@ -24,11 +24,14 @@ class Embedder(ABC):
     """

     @abstractmethod
-    def embed_query(self, text: str) -> list[float]:
+    def embed_query(
+        self, text: str, dimensions: int | None = None, **kwargs
+    ) -> list[float]:
         """Embed query text.

         Args:
             text (str): Text to convert to vector embedding
+            dimensions (Optional[int]): The number of dimensions the resulting output embeddings should have. Only for models supporting it.

         Returns:
             list[float]: A vector embedding.
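
Since the abstract signature now takes an optional dimensions argument plus **kwargs, custom embedders should mirror it. A minimal sketch of a conforming subclass, modelled on the custom_embeddings.py example above (the RandomEmbeddings name is hypothetical):

import random
from typing import Any

from neo4j_graphrag.embeddings.base import Embedder


class RandomEmbeddings(Embedder):
    def __init__(self, default_dimension: int = 10, **kwargs: Any):
        self.default_dimension = default_dimension

    def embed_query(
        self, text: str, dimensions: int | None = None, **kwargs: Any
    ) -> list[float]:
        # Honour the requested output size when given; fall back to the default otherwise.
        size = dimensions or self.default_dimension
        return [random.random() for _ in range(size)]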

src/neo4j_graphrag/embeddings/cohere.py

Lines changed: 14 additions & 3 deletions

@@ -32,12 +32,23 @@ def __init__(self, model: str = "", **kwargs: Any) -> None:
                 Please install it with `pip install "neo4j-graphrag[cohere]"`."""
             )
         self.model = model
-        self.client = cohere.Client(**kwargs)
+        self.client = cohere.ClientV2(**kwargs)

-    def embed_query(self, text: str, **kwargs: Any) -> list[float]:
+    def embed_query(
+        self, text: str, dimensions: int | None = None, **kwargs: Any
+    ) -> list[float]:
+        """
+        Generate embeddings for a given query using a Cohere text embedding model.
+
+        Args:
+            text (str): The text to generate an embedding for.
+            dimensions (Optional[int]): The number of dimensions the resulting output embeddings should have. Only for models supporting it.
+            **kwargs (Any): Additional keyword arguments to pass to the Cohere ClientV2.embed method.
+        """
         response = self.client.embed(
             texts=[text],
             model=self.model,
+            output_dimension=dimensions,
             **kwargs,
         )
-        return response.embeddings[0]  # type: ignore
+        return response.embeddings.float[0]  # type: ignore

src/neo4j_graphrag/embeddings/mistral.py

Lines changed: 6 additions & 3 deletions

@@ -48,16 +48,19 @@ def __init__(self, model: str = "mistral-embed", **kwargs: Any) -> None:
         self.model = model
         self.mistral_client = Mistral(api_key=api_key, **kwargs)

-    def embed_query(self, text: str, **kwargs: Any) -> list[float]:
+    def embed_query(
+        self, text: str, dimensions: int | None = None, **kwargs: Any
+    ) -> list[float]:
         """
         Generate embeddings for a given query using a Mistral AI text embedding model.

         Args:
             text (str): The text to generate an embedding for.
-            **kwargs (Any): Additional keyword arguments to pass to the Mistral AI client.
+            dimensions (Optional[int]): The number of dimensions the resulting output embeddings should have. Only for models supporting it.
+            **kwargs (Any): Additional keyword arguments to pass to the embeddings.create method.
         """
         embeddings_batch_response = self.mistral_client.embeddings.create(
-            model=self.model, inputs=[text], **kwargs
+            model=self.model, inputs=[text], output_dimension=dimensions, **kwargs
         )
         if embeddings_batch_response is None or not embeddings_batch_response.data:
             raise EmbeddingsGenerationError("Failed to retrieve embeddings.")

src/neo4j_graphrag/embeddings/openai.py

Lines changed: 8 additions & 2 deletions

@@ -51,15 +51,21 @@ def _initialize_client(self, **kwargs: Any) -> Any:
         """
         pass

-    def embed_query(self, text: str, **kwargs: Any) -> list[float]:
+    def embed_query(
+        self, text: str, dimensions: int | None = None, **kwargs: Any
+    ) -> list[float]:
         """
         Generate embeddings for a given query using an OpenAI text embedding model.

         Args:
             text (str): The text to generate an embedding for.
+            dimensions (Optional[int]): The number of dimensions the resulting output embeddings should have. Only for models supporting it.
+
             **kwargs (Any): Additional arguments to pass to the OpenAI embedding generation function.
         """
-        response = self.client.embeddings.create(input=text, model=self.model, **kwargs)
+        response = self.client.embeddings.create(
+            input=text, model=self.model, dimensions=dimensions, **kwargs
+        )
         embedding: list[float] = response.data[0].embedding
         return embedding
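
Because dimensions now lives on the Embedder base interface, code that only sees the abstract type can request a target size without knowing which provider backs it; in this commit Cohere, Mistral AI and OpenAI each map it to their own client parameter. A minimal sketch with a hypothetical helper name:

from neo4j_graphrag.embeddings.base import Embedder


def embed_at_size(
    embedder: Embedder, text: str, dimensions: int | None = None
) -> list[float]:
    # Works with any concrete embedder that follows the updated base signature.
    return embedder.embed_query(text, dimensions=dimensions)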

src/neo4j_graphrag/embeddings/sentence_transformers.py

Lines changed: 1 addition & 1 deletion

@@ -35,7 +35,7 @@ def __init__(
         self.np = np
         self.model = sentence_transformers.SentenceTransformer(model, *args, **kwargs)

-    def embed_query(self, text: str) -> Any:
+    def embed_query(self, text: str, **kwargs) -> Any:
         result = self.model.encode([text])
         if isinstance(result, self.torch.Tensor) or isinstance(result, self.np.ndarray):
             return result.flatten().tolist()
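
Here **kwargs is accepted only to keep the interface uniform; it is not forwarded to model.encode, so a requested size is currently ignored. A usage sketch against this WIP state, assuming SentenceTransformerEmbeddings is importable from neo4j_graphrag.embeddings and wraps the all-MiniLM-L6-v2 model:

from neo4j_graphrag.embeddings import SentenceTransformerEmbeddings

embedder = SentenceTransformerEmbeddings(model="all-MiniLM-L6-v2")
res = embedder.embed_query("my question", dimensions=256)  # captured by **kwargs, not used
print("Embedding dimensions", len(res))  # stays at the model's native 384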
