diff --git a/docs/user-guides/configuration-guide.md b/docs/user-guides/configuration-guide.md index 803c4a519..7a622e6b4 100644 --- a/docs/user-guides/configuration-guide.md +++ b/docs/user-guides/configuration-guide.md @@ -538,6 +538,7 @@ The following tables lists the supported embedding providers: | OpenAI | `openai` | `text-embedding-ada-002`, etc. | | SentenceTransformers | `SentenceTransformers` | `all-MiniLM-L6-v2`, etc. | | NVIDIA AI Endpoints | `nvidia_ai_endpoints` | `nv-embed-v1`, etc. | +| Google | `google` | `gemini-embedding-001`, etc. | ```{note} You can use any of the supported models for any of the supported embedding providers. diff --git a/nemoguardrails/embeddings/providers/__init__.py b/nemoguardrails/embeddings/providers/__init__.py index c9a8f2896..b95d89d39 100644 --- a/nemoguardrails/embeddings/providers/__init__.py +++ b/nemoguardrails/embeddings/providers/__init__.py @@ -18,7 +18,7 @@ from typing import Optional, Type -from . import fastembed, nim, openai, sentence_transformers +from . import fastembed, google, nim, openai, sentence_transformers from .base import EmbeddingModel from .registry import EmbeddingProviderRegistry @@ -68,6 +68,7 @@ def register_embedding_provider( register_embedding_provider(sentence_transformers.SentenceTransformerEmbeddingModel) register_embedding_provider(nim.NIMEmbeddingModel) register_embedding_provider(nim.NVIDIAAIEndpointsEmbeddingModel) +register_embedding_provider(google.GoogleEmbeddingModel) def init_embedding_model( diff --git a/nemoguardrails/embeddings/providers/google.py b/nemoguardrails/embeddings/providers/google.py new file mode 100644 index 000000000..83e3e7950 --- /dev/null +++ b/nemoguardrails/embeddings/providers/google.py @@ -0,0 +1,91 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List

from .base import EmbeddingModel


class GoogleEmbeddingModel(EmbeddingModel):
    """Embedding model using langchain_google_genai.

    This class is a wrapper for using embedding models powered by Google AI
    (hosted in the Google Cloud).

    To use, you must have either:

        1. The ``GOOGLE_API_KEY`` environment variable set with your API key, or
        2. Pass your API key using the google_api_key kwarg to the
        GoogleGenerativeAIEmbeddings constructor.

    Args:
        embedding_model (str): The name of the embedding model to be used.
        **kwargs: Forwarded unchanged to ``GoogleGenerativeAIEmbeddings``.

    Attributes:
        model (str): The name of the embedding model, exactly as passed in.
        document_embedder: The underlying ``GoogleGenerativeAIEmbeddings`` client.
        embedding_size_dict (dict): Known embedding sizes, keyed by bare model name.
        embedding_size (int): The size of the embeddings.

    Raises:
        ImportError: If ``langchain_google_genai`` is not installed.
    """

    engine_name = "google"

    def __init__(self, embedding_model: str, **kwargs):
        # The dependency is optional, so it is imported lazily: the provider
        # module can be imported (and registered) without it being installed.
        try:
            from langchain_google_genai import GoogleGenerativeAIEmbeddings
        except ImportError as err:
            # Chain explicitly so the original import failure stays visible.
            raise ImportError(
                "Could not import langchain_google_genai, please install it with "
                "`pip install langchain-google-genai`."
            ) from err

        self.model = embedding_model
        self.document_embedder = GoogleGenerativeAIEmbeddings(
            model=embedding_model, **kwargs
        )

        self.embedding_size_dict = {
            "gemini-embedding-001": 3072,
            "text-embedding-005": 768,
            "text-multilingual-embedding-002": 768,
        }

        # Model names may be given in the fully qualified "models/<name>" form;
        # strip that prefix for the lookup so a known model does not trigger
        # the network round-trip below.
        prefix = "models/"
        lookup_name = self.model
        if lookup_name.startswith(prefix):
            lookup_name = lookup_name[len(prefix):]

        if lookup_name in self.embedding_size_dict:
            self.embedding_size = self.embedding_size_dict[lookup_name]
        else:
            # Unknown model: perform a first encoding to get the embedding size.
            # NOTE: this is a blocking call to the embedding service.
            self.embedding_size = len(self.encode(["test"])[0])

    async def encode_async(self, documents: List[str]) -> List[List[float]]:
        """Encode a list of documents into their corresponding sentence embeddings.

        Args:
            documents (List[str]): The list of documents to be encoded.

        Returns:
            List[List[float]]: The list of sentence embeddings, where each
            embedding is a list of floats. Empty when ``documents`` is empty.
        """
        # Nothing to embed: skip the remote request entirely.
        if not documents:
            return []

        return await self.document_embedder.aembed_documents(documents)

    def encode(self, documents: List[str]) -> List[List[float]]:
        """Encode a list of documents into their corresponding sentence embeddings.

        Args:
            documents (List[str]): The list of documents to be encoded.

        Returns:
            List[List[float]]: The list of sentence embeddings, where each
            embedding is a list of floats. Empty when ``documents`` is empty.
        """
        # Nothing to embed: skip the remote request entirely.
        if not documents:
            return []

        return self.document_embedder.embed_documents(documents)
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Live tests for the Google embedding provider.

Every test is skipped unless the ``LIVE_TEST`` environment variable is set,
because they call the real embedding (and LLM) services.
"""

import os

import pytest

from nemoguardrails import LLMRails, RailsConfig

try:
    from nemoguardrails.embeddings.providers.google import GoogleEmbeddingModel
except ImportError:
    # Ignore this if running in a test environment where
    # langchain-google-genai is not installed.
    GoogleEmbeddingModel = None

CONFIGS_FOLDER = os.path.join(os.path.dirname(__file__), ".", "test_configs")

LIVE_TEST_MODE = os.environ.get("LIVE_TEST")


@pytest.fixture
def app():
    """Load the configuration where we replace FastEmbed with Google."""
    config = RailsConfig.from_path(
        os.path.join(CONFIGS_FOLDER, "with_google_embeddings")
    )

    return LLMRails(config)


@pytest.mark.skipif(not LIVE_TEST_MODE, reason="Not in live mode.")
def test_custom_llm_registration(app):
    """The flows index must be backed by the Google embedding model."""
    assert isinstance(
        app.llm_generation_actions.flows_index._model, GoogleEmbeddingModel
    )


# Renamed from ``test_live_query``: the module previously defined two functions
# with that name, so the later (sync) one shadowed this one and the async path
# was never collected or run.
@pytest.mark.skipif(not LIVE_TEST_MODE, reason="Not in live mode.")
@pytest.mark.asyncio
async def test_live_query_async():
    """The async generation API returns the canned bot message."""
    config = RailsConfig.from_path(
        os.path.join(CONFIGS_FOLDER, "with_google_embeddings")
    )
    app = LLMRails(config)

    result = await app.generate_async(
        messages=[{"role": "user", "content": "tell me what you can do"}]
    )

    assert result == {
        "role": "assistant",
        "content": "I am an AI assistant that helps answer questions.",
    }


# Synchronous test: the ``asyncio`` marker is only meant for coroutine tests,
# so it has been dropped here.
@pytest.mark.skipif(not LIVE_TEST_MODE, reason="Not in live mode.")
def test_live_query(app):
    """The sync generation API returns the canned bot message."""
    result = app.generate(
        messages=[{"role": "user", "content": "tell me what you can do"}]
    )

    assert result == {
        "role": "assistant",
        "content": "I am an AI assistant that helps answer questions.",
    }


@pytest.mark.skipif(not LIVE_TEST_MODE, reason="Not in live mode.")
def test_sync_embeddings():
    """``encode`` returns 3072-dimensional vectors for gemini-embedding-001."""
    model = GoogleEmbeddingModel("gemini-embedding-001")

    result = model.encode(["test"])

    assert len(result[0]) == 3072


@pytest.mark.skipif(not LIVE_TEST_MODE, reason="Not in live mode.")
@pytest.mark.asyncio
async def test_async_embeddings():
    """``encode_async`` returns 3072-dimensional vectors for gemini-embedding-001."""
    model = GoogleEmbeddingModel("gemini-embedding-001")

    result = await model.encode_async(["test"])

    assert len(result[0]) == 3072