From 71d00f083fb59bda34c82b82eea85602c1710265 Mon Sep 17 00:00:00 2001
From: tgasser-nv <200644301+tgasser-nv@users.noreply.github.com>
Date: Tue, 2 Sep 2025 11:17:40 -0500
Subject: [PATCH 1/2] Dummy commit to set up the chore/type-clean-guardrails
 PR and branch

---
 nemoguardrails/actions/llm/generation.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/nemoguardrails/actions/llm/generation.py b/nemoguardrails/actions/llm/generation.py
index 2a57e1c26..cd11e70a7 100644
--- a/nemoguardrails/actions/llm/generation.py
+++ b/nemoguardrails/actions/llm/generation.py
@@ -137,7 +137,7 @@ async def init(self):
             self._init_flows_index(),
         )
 
-    def _extract_user_message_example(self, flow: Flow):
+    def _extract_user_message_example(self, flow: Flow) -> None:
         """Heuristic to extract user message examples from a flow."""
         elements = [
             item

From 5517ce81fe11bf22ab9787f2abb5bc488c30a41d Mon Sep 17 00:00:00 2001
From: tgasser-nv <200644301+tgasser-nv@users.noreply.github.com>
Date: Tue, 9 Sep 2025 11:39:57 -0500
Subject: [PATCH 2/2] Cleaned embeddings/

---
 nemoguardrails/embeddings/basic.py            | 85 +++++++++++--------
 nemoguardrails/embeddings/cache.py            | 57 ++++++++-----
 .../embeddings/providers/fastembed.py         |  2 +-
 nemoguardrails/embeddings/providers/openai.py |  6 +-
 .../providers/sentence_transformers.py        |  4 +-
 5 files changed, 93 insertions(+), 61 deletions(-)

diff --git a/nemoguardrails/embeddings/basic.py b/nemoguardrails/embeddings/basic.py
index ad3109e8f..3fe17fe49 100644
--- a/nemoguardrails/embeddings/basic.py
+++ b/nemoguardrails/embeddings/basic.py
@@ -17,7 +17,7 @@
 import logging
 from typing import Any, Dict, List, Optional, Union
 
-from annoy import AnnoyIndex
+from annoy import AnnoyIndex  # type: ignore
 
 from nemoguardrails.embeddings.cache import cache_embeddings
 from nemoguardrails.embeddings.index import EmbeddingsIndex, IndexItem
@@ -45,26 +45,16 @@ class BasicEmbeddingsIndex(EmbeddingsIndex):
         max_batch_hold: The maximum time a batch is held before being processed
     """
 
-    embedding_model: str
-    embedding_engine: str
-    embedding_params: Dict[str, Any]
-    index: AnnoyIndex
-    embedding_size: int
-    cache_config: EmbeddingsCacheConfig
-    embeddings: List[List[float]]
-    search_threshold: float
-    use_batching: bool
-    max_batch_size: int
-    max_batch_hold: float
+    # Instance attributes are defined in __init__ and accessed via properties
 
     def __init__(
        self,
-        embedding_model=None,
-        embedding_engine=None,
-        embedding_params=None,
-        index=None,
-        cache_config: Union[EmbeddingsCacheConfig, Dict[str, Any]] = None,
-        search_threshold: float = None,
+        embedding_model: Optional[str] = None,
+        embedding_engine: Optional[str] = None,
+        embedding_params: Optional[Dict[str, Any]] = None,
+        index: Optional[AnnoyIndex] = None,
+        cache_config: Optional[Union[EmbeddingsCacheConfig, Dict[str, Any]]] = None,
+        search_threshold: Optional[float] = None,
         use_batching: bool = False,
         max_batch_size: int = 10,
         max_batch_hold: float = 0.01,
@@ -81,10 +71,10 @@ def __init__(
             max_batch_hold: The maximum time a batch is held before being processed
         """
         self._model: Optional[EmbeddingModel] = None
-        self._items = []
-        self._embeddings = []
-        self.embedding_model = embedding_model
-        self.embedding_engine = embedding_engine
+        self._items: List[IndexItem] = []
+        self._embeddings: List[List[float]] = []
+        self.embedding_model: Optional[str] = embedding_model
+        self.embedding_engine: Optional[str] = embedding_engine
         self.embedding_params = embedding_params or {}
         self._embedding_size = 0
         self.search_threshold = search_threshold or float("inf")
@@ -95,12 +85,12 @@ def __init__(
         self._index = index
 
         # Data structures for batching embedding requests
-        self._req_queue = {}
-        self._req_results = {}
-        self._req_idx = 0
-        self._current_batch_finished_event = None
-        self._current_batch_full_event = None
-        self._current_batch_submitted = asyncio.Event()
+        self._req_queue: Dict[int, str] = {}
+        self._req_results: Dict[int, List[float]] = {}
+        self._req_idx: int = 0
+        self._current_batch_finished_event: Optional[asyncio.Event] = None
+        self._current_batch_full_event: Optional[asyncio.Event] = None
+        self._current_batch_submitted: asyncio.Event = asyncio.Event()
 
         # Initialize the batching configuration
         self.use_batching = use_batching
@@ -112,6 +102,11 @@ def embeddings_index(self):
         """Get the current embedding index"""
         return self._index
 
+    @embeddings_index.setter
+    def embeddings_index(self, index):
+        """Setter to allow replacing the index dynamically."""
+        self._index = index
+
     @property
     def cache_config(self):
         """Get the cache configuration."""
@@ -127,19 +122,23 @@ def embeddings(self):
         """Get the computed embeddings."""
         return self._embeddings
 
-    @embeddings_index.setter
-    def embeddings_index(self, index):
-        """Setter to allow replacing the index dynamically."""
-        self._index = index
-
     def _init_model(self):
         """Initialize the model used for computing the embeddings."""
+        # Provide defaults if not specified
+        model = self.embedding_model or "sentence-transformers/all-MiniLM-L6-v2"
+        engine = self.embedding_engine or "SentenceTransformers"
+
         self._model = init_embedding_model(
-            embedding_model=self.embedding_model,
-            embedding_engine=self.embedding_engine,
+            embedding_model=model,
+            embedding_engine=engine,
             embedding_params=self.embedding_params,
         )
 
+        if not self._model:
+            raise ValueError(
+                f"Couldn't create embedding model with model {model} and engine {engine}"
+            )
+
     @cache_embeddings
     async def _get_embeddings(self, texts: List[str]) -> List[List[float]]:
         """Compute embeddings for a list of texts.
@@ -153,6 +152,8 @@ async def _get_embeddings(self, texts: List[str]) -> List[List[float]]:
         if self._model is None:
             self._init_model()
+            if not self._model:
+                raise Exception("Couldn't initialize embedding model")
 
         embeddings = await self._model.encode_async(texts)
         return embeddings
 
@@ -199,6 +200,10 @@ async def _run_batch(self):
         """Runs the current batch of embeddings."""
 
         # Wait up to `max_batch_hold` time or until `max_batch_size` is reached.
+        if not self._current_batch_full_event:
+            raise Exception("self._current_batch_full_event not initialized")
+
+        assert self._current_batch_full_event is not None
         done, pending = await asyncio.wait(
             [
                 asyncio.create_task(asyncio.sleep(self.max_batch_hold)),
@@ -210,6 +215,10 @@ async def _run_batch(self):
             task.cancel()
 
         # Reset the batch event
+        if not self._current_batch_finished_event:
+            raise Exception("self._current_batch_finished_event not initialized")
+
+        assert self._current_batch_finished_event is not None
         batch_event: asyncio.Event = self._current_batch_finished_event
         self._current_batch_finished_event = None
 
@@ -252,9 +261,13 @@ async def _batch_get_embeddings(self, text: str) -> List[float]:
 
         # We check if we reached the max batch size
         if len(self._req_queue) >= self.max_batch_size:
+            if not self._current_batch_full_event:
+                raise Exception("self._current_batch_full_event not initialized")
             self._current_batch_full_event.set()
 
-        # Wait for the batch to finish
+        # Wait for the batch to finish
+        if not self._current_batch_finished_event:
+            raise Exception("self._current_batch_finished_event not initialized")
         await self._current_batch_finished_event.wait()
 
         # Remove the result and return it
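
Note for reviewers: a minimal sketch of how the batched path above is exercised, illustrative only. The `add_items`/`build`/`search` method names are assumed from the `EmbeddingsIndex` interface, and the model/engine values mirror the new `_init_model` defaults.

    import asyncio

    from nemoguardrails.embeddings.basic import BasicEmbeddingsIndex
    from nemoguardrails.embeddings.index import IndexItem


    async def demo():
        index = BasicEmbeddingsIndex(
            embedding_model="sentence-transformers/all-MiniLM-L6-v2",
            embedding_engine="SentenceTransformers",
            use_batching=True,    # opt in; batching stays off by default
            max_batch_size=10,    # flush once 10 requests are queued,
            max_batch_hold=0.01,  # or after 10 ms, whichever comes first
        )
        await index.add_items([IndexItem(text="hello there", meta={})])
        await index.build()
        # Concurrent lookups are coalesced into a single encode_async() call.
        return await asyncio.gather(*(index.search("hi") for _ in range(5)))


    asyncio.run(demo())
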
diff --git a/nemoguardrails/embeddings/cache.py b/nemoguardrails/embeddings/cache.py
index e8a348049..8a93eabfe 100644
--- a/nemoguardrails/embeddings/cache.py
+++ b/nemoguardrails/embeddings/cache.py
@@ -20,7 +20,12 @@
 from abc import ABC, abstractmethod
 from functools import singledispatchmethod
 from pathlib import Path
-from typing import Dict, List
+from typing import Dict, List, Optional
+
+try:
+    import redis  # type: ignore
+except ImportError:
+    redis = None  # type: ignore
 
 from nemoguardrails.rails.llm.config import EmbeddingsCacheConfig
 
@@ -30,6 +35,8 @@
 class KeyGenerator(ABC):
     """Abstract class for key generators."""
 
+    name: str  # Class attribute that should be defined in subclasses
+
     @abstractmethod
     def generate_key(self, text: str) -> str:
         pass
@@ -37,11 +44,11 @@ def generate_key(self, text: str) -> str:
     @classmethod
     def from_name(cls, name):
         for subclass in cls.__subclasses__():
-            if subclass.name == name:
+            if hasattr(subclass, "name") and subclass.name == name:
                 return subclass
         raise ValueError(
             f"Unknown {cls.__name__}: {name}. Available {cls.__name__}s are: "
-            f"{', '.join([subclass.name for subclass in cls.__subclasses__()])}"
+            f"{', '.join([subclass.name for subclass in cls.__subclasses__() if hasattr(subclass, 'name')])}"
             ". Make sure to import the derived class before using it."
         )
 
@@ -76,6 +83,8 @@ def generate_key(self, text: str) -> str:
 class CacheStore(ABC):
     """Abstract class for cache stores."""
 
+    name: str  # Class attribute that should be defined in subclasses
+
     @abstractmethod
     def get(self, key):
         """Get a value from the cache."""
@@ -94,11 +103,11 @@ def clear(self):
     @classmethod
     def from_name(cls, name):
         for subclass in cls.__subclasses__():
-            if subclass.name == name:
+            if hasattr(subclass, "name") and subclass.name == name:
                 return subclass
         raise ValueError(
             f"Unknown {cls.__name__}: {name}. Available {cls.__name__}s are: "
-            f"{', '.join([subclass.name for subclass in cls.__subclasses__()])}"
+            f"{', '.join([subclass.name for subclass in cls.__subclasses__() if hasattr(subclass, 'name')])}"
             ". Make sure to import the derived class before using it."
         )
 
@@ -147,7 +156,7 @@ class FilesystemCacheStore(CacheStore):
 
     name = "filesystem"
 
-    def __init__(self, cache_dir: str = None):
+    def __init__(self, cache_dir: Optional[str] = None):
         self._cache_dir = Path(cache_dir or ".cache/embeddings")
         self._cache_dir.mkdir(parents=True, exist_ok=True)
 
@@ -190,8 +199,10 @@ class RedisCacheStore(CacheStore):
     name = "redis"
 
     def __init__(self, host: str = "localhost", port: int = 6379, db: int = 0):
-        import redis
-
+        if redis is None:
+            raise ImportError(
+                "Could not import redis, please install it with `pip install redis`."
+            )
         self._redis = redis.Redis(host=host, port=port, db=db)
 
     def get(self, key):
@@ -207,9 +218,9 @@ def clear(self):
 class EmbeddingsCache:
     def __init__(
         self,
-        key_generator: KeyGenerator = None,
-        cache_store: CacheStore = None,
-        store_config: dict = None,
+        key_generator: Optional[KeyGenerator] = None,
+        cache_store: Optional[CacheStore] = None,
+        store_config: Optional[dict] = None,
     ):
         self._key_generator = key_generator
         self._cache_store = cache_store
@@ -218,7 +229,10 @@ def __init__(
     @classmethod
     def from_dict(cls, d: Dict[str, str]):
         key_generator = KeyGenerator.from_name(d.get("key_generator"))()
-        store_config = d.get("store_config")
+        store_config_raw = d.get("store_config")
+        store_config: dict = (
+            store_config_raw if isinstance(store_config_raw, dict) else {}
+        )
         cache_store = CacheStore.from_name(d.get("store"))(**store_config)
         return cls(key_generator=key_generator, cache_store=cache_store)
 
@@ -230,8 +244,8 @@ def from_config(cls, config: EmbeddingsCacheConfig):
 
     def get_config(self):
         return EmbeddingsCacheConfig(
-            key_generator=self._key_generator.name,
-            store=self._cache_store.name,
+            key_generator=self._key_generator.name if self._key_generator else "sha256",
+            store=self._cache_store.name if self._cache_store else "filesystem",
             store_config=self._store_config,
         )
 
@@ -239,8 +253,10 @@ def get_config(self):
     def get(self, texts):
         raise NotImplementedError
 
-    @get.register
+    @get.register(str)
     def _(self, text: str):
+        if self._key_generator is None or self._cache_store is None:
+            return None
         key = self._key_generator.generate_key(text)
         log.info(f"Fetching key {key} for text '{text[:20]}...' from cache")
 
@@ -248,7 +264,7 @@ def _(self, text: str):
 
         return result
 
-    @get.register
+    @get.register(list)
     def _(self, texts: list):
         cached = {}
 
@@ -266,19 +282,22 @@ def _(self, texts: list):
     def set(self, texts):
         raise NotImplementedError
 
-    @set.register
+    @set.register(str)
     def _(self, text: str, value: List[float]):
+        if self._key_generator is None or self._cache_store is None:
+            return
         key = self._key_generator.generate_key(text)
         log.info(f"Cache miss for text '{text}'. Storing key {key} in cache.")
         self._cache_store.set(key, value)
 
-    @set.register
+    @set.register(list)
     def _(self, texts: list, values: List[List[float]]):
         for text, value in zip(texts, values):
             self.set(text, value)
 
     def clear(self):
-        self._cache_store.clear()
+        if self._cache_store is not None:
+            self._cache_store.clear()
 
 
 def cache_embeddings(func):
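
Note for reviewers: the `from_name` registries above resolve subclasses by their `name` class attribute (which is why subclasses missing it are now skipped), and `from_dict` now tolerates a missing `store_config`. A sketch of how a custom store plugs in; the `DictCacheStore` class and its `dict_store` name are hypothetical, purely for illustration:

    from nemoguardrails.embeddings.cache import CacheStore, EmbeddingsCache


    class DictCacheStore(CacheStore):
        """Hypothetical store keeping embeddings in a plain dict."""

        name = "dict_store"  # what from_name()/from_dict() match on

        def __init__(self):
            self._data = {}

        def get(self, key):
            return self._data.get(key)

        def set(self, key, value):
            self._data[key] = value

        def clear(self):
            self._data.clear()


    # No "store_config" key: from_dict now falls back to {} instead of
    # raising when unpacking it into the store constructor.
    cache = EmbeddingsCache.from_dict({"key_generator": "sha256", "store": "dict_store"})
    cache.set("some text", [0.1, 0.2, 0.3])
    print(cache.get("some text"))  # -> [0.1, 0.2, 0.3]
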
diff --git a/nemoguardrails/embeddings/providers/fastembed.py b/nemoguardrails/embeddings/providers/fastembed.py
index f4b806c13..4bb10e6ff 100644
--- a/nemoguardrails/embeddings/providers/fastembed.py
+++ b/nemoguardrails/embeddings/providers/fastembed.py
@@ -42,7 +42,7 @@ class FastEmbedEmbeddingModel(EmbeddingModel):
     engine_name = "FastEmbed"
 
     def __init__(self, embedding_model: str, **kwargs):
-        from fastembed import TextEmbedding as Embedding
+        from fastembed import TextEmbedding as Embedding  # type: ignore
 
         # Enabling a short form model name for all-MiniLM-L6-v2.
         if embedding_model == "all-MiniLM-L6-v2":
diff --git a/nemoguardrails/embeddings/providers/openai.py b/nemoguardrails/embeddings/providers/openai.py
index 4a567b86e..ed2512188 100644
--- a/nemoguardrails/embeddings/providers/openai.py
+++ b/nemoguardrails/embeddings/providers/openai.py
@@ -46,14 +46,14 @@ def __init__(
         **kwargs,
     ):
         try:
-            import openai
-            from openai import AsyncOpenAI, OpenAI
+            import openai  # type: ignore
+            from openai import AsyncOpenAI, OpenAI  # type: ignore
         except ImportError:
             raise ImportError(
                 "Could not import openai, please install it with "
                 "`pip install openai`."
             )
-        if openai.__version__ < "1.0.0":
+        if openai.__version__ < "1.0.0":  # type: ignore
             raise RuntimeError(
                 "`openai<1.0.0` is no longer supported. "
                 "Please upgrade using `pip install openai>=1.0.0`."
diff --git a/nemoguardrails/embeddings/providers/sentence_transformers.py b/nemoguardrails/embeddings/providers/sentence_transformers.py
index 21932b725..1c389990a 100644
--- a/nemoguardrails/embeddings/providers/sentence_transformers.py
+++ b/nemoguardrails/embeddings/providers/sentence_transformers.py
@@ -43,7 +43,7 @@ class SentenceTransformerEmbeddingModel(EmbeddingModel):
 
     def __init__(self, embedding_model: str, **kwargs):
         try:
-            from sentence_transformers import SentenceTransformer
+            from sentence_transformers import SentenceTransformer  # type: ignore
         except ImportError:
             raise ImportError(
                 "Could not import sentence-transformers, please install it with "
@@ -51,7 +51,7 @@ def __init__(self, embedding_model: str, **kwargs):
             )
 
         try:
-            from torch import cuda
+            from torch import cuda  # type: ignore
         except ImportError:
             raise ImportError(
                 "Could not import torch, please install it with `pip install torch`."
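
Note for reviewers: the provider diffs share one convention: keep the third-party import lazy or guarded, raise an actionable `ImportError`, and add `# type: ignore` because these packages ship without type stubs. The pattern in isolation (the `_load_encoder` helper is hypothetical):

    def _load_encoder(model_name: str):
        try:
            from sentence_transformers import SentenceTransformer  # type: ignore
        except ImportError:
            raise ImportError(
                "Could not import sentence-transformers, please install it with "
                "`pip install sentence-transformers`."
            )
        return SentenceTransformer(model_name)
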