2 changes: 1 addition & 1 deletion nemoguardrails/actions/llm/generation.py
@@ -137,7 +137,7 @@ async def init(self):
self._init_flows_index(),
)

def _extract_user_message_example(self, flow: Flow):
def _extract_user_message_example(self, flow: Flow) -> None:
"""Heuristic to extract user message examples from a flow."""
elements = [
item
85 changes: 49 additions & 36 deletions nemoguardrails/embeddings/basic.py
@@ -17,7 +17,7 @@
import logging
from typing import Any, Dict, List, Optional, Union

from annoy import AnnoyIndex
from annoy import AnnoyIndex # type: ignore

from nemoguardrails.embeddings.cache import cache_embeddings
from nemoguardrails.embeddings.index import EmbeddingsIndex, IndexItem
@@ -45,26 +45,16 @@ class BasicEmbeddingsIndex(EmbeddingsIndex):
max_batch_hold: The maximum time a batch is held before being processed
"""

embedding_model: str
embedding_engine: str
embedding_params: Dict[str, Any]
index: AnnoyIndex
embedding_size: int
cache_config: EmbeddingsCacheConfig
embeddings: List[List[float]]
search_threshold: float
use_batching: bool
max_batch_size: int
max_batch_hold: float
# Instance attributes are defined in __init__ and accessed via properties

def __init__(
self,
embedding_model=None,
embedding_engine=None,
embedding_params=None,
index=None,
cache_config: Union[EmbeddingsCacheConfig, Dict[str, Any]] = None,
search_threshold: float = None,
embedding_model: Optional[str] = None,
embedding_engine: Optional[str] = None,
embedding_params: Optional[Dict[str, Any]] = None,
index: Optional[AnnoyIndex] = None,
cache_config: Optional[Union[EmbeddingsCacheConfig, Dict[str, Any]]] = None,
search_threshold: Optional[float] = None,
use_batching: bool = False,
max_batch_size: int = 10,
max_batch_hold: float = 0.01,
@@ -81,10 +71,10 @@ def __init__(
max_batch_hold: The maximum time a batch is held before being processed
"""
self._model: Optional[EmbeddingModel] = None
self._items = []
self._embeddings = []
self.embedding_model = embedding_model
self.embedding_engine = embedding_engine
self._items: List[IndexItem] = []
self._embeddings: List[List[float]] = []
self.embedding_model: Optional[str] = embedding_model
self.embedding_engine: Optional[str] = embedding_engine
self.embedding_params = embedding_params or {}
self._embedding_size = 0
self.search_threshold = search_threshold or float("inf")
@@ -95,12 +85,12 @@
self._index = index

# Data structures for batching embedding requests
self._req_queue = {}
self._req_results = {}
self._req_idx = 0
self._current_batch_finished_event = None
self._current_batch_full_event = None
self._current_batch_submitted = asyncio.Event()
self._req_queue: Dict[int, str] = {}
self._req_results: Dict[int, List[float]] = {}
self._req_idx: int = 0
self._current_batch_finished_event: Optional[asyncio.Event] = None
self._current_batch_full_event: Optional[asyncio.Event] = None
self._current_batch_submitted: asyncio.Event = asyncio.Event()

# Initialize the batching configuration
self.use_batching = use_batching
@@ -112,6 +102,11 @@ def embeddings_index(self):
"""Get the current embedding index"""
return self._index

@embeddings_index.setter
def embeddings_index(self, index):
"""Setter to allow replacing the index dynamically."""
self._index = index

@property
def cache_config(self):
"""Get the cache configuration."""
@@ -127,19 +122,23 @@ def embeddings(self):
"""Get the computed embeddings."""
return self._embeddings

@embeddings_index.setter
def embeddings_index(self, index):
"""Setter to allow replacing the index dynamically."""
self._index = index

def _init_model(self):
"""Initialize the model used for computing the embeddings."""
# Provide defaults if not specified
Collaborator: Can we move the defaults to the constructor? (at line 52)

model = self.embedding_model or "sentence-transformers/all-MiniLM-L6-v2"
Member: We have some defaults in llmrails.py, so in "normal" usage these are never None. We should at least use the same defaults (e.g., FastEmbed):

https://github.com/NVIDIA/NeMo-Guardrails/blob/5d974e512582ca3a7e3dd16d806c3a888f94c90d/nemoguardrails/rails/llm/llmrails.py#L125-L127

On my side, some type-checking errors can show up (for example, a None here) purely from static analysis tools, but the actual "normal" usage flow in Guardrails means it never happens at runtime.
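A minimal sketch of what hoisting the defaults into the constructor could look like (illustration only: hypothetical constant and class names, default values copied from the llmrails.py snippet above):

```python
from typing import Any, Dict, Optional

# Assumed defaults, mirroring the llmrails.py snippet linked above.
DEFAULT_EMBEDDING_ENGINE = "FastEmbed"
DEFAULT_EMBEDDING_MODEL = "all-MiniLM-L6-v2"


class BasicEmbeddingsIndexSketch:
    """Hypothetical stand-in for BasicEmbeddingsIndex."""

    def __init__(
        self,
        embedding_model: str = DEFAULT_EMBEDDING_MODEL,
        embedding_engine: str = DEFAULT_EMBEDDING_ENGINE,
        embedding_params: Optional[Dict[str, Any]] = None,
    ):
        # With non-Optional string defaults, these attributes are always
        # `str`, so _init_model needs no `or` fallbacks and static
        # analysis stops flagging possible None values.
        self.embedding_model = embedding_model
        self.embedding_engine = embedding_engine
        self.embedding_params = embedding_params or {}
```

Under that scheme, the `model = ... or ...` and `engine = ... or ...` fallbacks below would become unnecessary.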

engine = self.embedding_engine or "SentenceTransformers"

self._model = init_embedding_model(
embedding_model=self.embedding_model,
embedding_engine=self.embedding_engine,
embedding_model=model,
embedding_engine=engine,
embedding_params=self.embedding_params,
)

if not self._model:
raise ValueError(
f"Couldn't create embedding model with model {model} and engine {engine}"
)

@cache_embeddings
async def _get_embeddings(self, texts: List[str]) -> List[List[float]]:
"""Compute embeddings for a list of texts.
@@ -153,6 +152,8 @@ async def _get_embeddings(self, texts: List[str]) -> List[List[float]]:
if self._model is None:
self._init_model()

if not self._model:
Member: We already throw a ValueError in _init_model if the model fails to initialize. Does it make sense to throw another one here?

raise Exception("Couldn't initialize embedding model")
embeddings = await self._model.encode_async(texts)
return embeddings
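One way to resolve the duplicated-error discussion above is to narrow the Optional exactly once in a small helper, so no caller re-checks. A rough sketch under stated assumptions (`_ensure_model` is a hypothetical helper; `FakeModel` stands in for the real embedding model):

```python
import asyncio
from typing import List, Optional


class FakeModel:
    """Stand-in for the real embedding model interface."""

    async def encode_async(self, texts: List[str]) -> List[List[float]]:
        return [[float(len(t))] for t in texts]


class IndexSketch:
    def __init__(self) -> None:
        self._model: Optional[FakeModel] = None

    def _ensure_model(self) -> FakeModel:
        # Lazy initialization; any failure raises here, in exactly one
        # place, and the non-Optional return type satisfies checkers.
        if self._model is None:
            self._model = FakeModel()
        return self._model

    async def _get_embeddings(self, texts: List[str]) -> List[List[float]]:
        model = self._ensure_model()  # no second None check needed
        return await model.encode_async(texts)


print(asyncio.run(IndexSketch()._get_embeddings(["hi", "there"])))  # [[2.0], [5.0]]
```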

@@ -199,6 +200,10 @@ async def _run_batch(self):
"""Runs the current batch of embeddings."""

# Wait up to `max_batch_hold` time or until `max_batch_size` is reached.
if not self._current_batch_full_event:
raise Exception("self._current_batch_full_event not initialized")

assert self._current_batch_full_event is not None
done, pending = await asyncio.wait(
[
asyncio.create_task(asyncio.sleep(self.max_batch_hold)),
@@ -210,6 +215,10 @@
task.cancel()

# Reset the batch event
if not self._current_batch_finished_event:
raise Exception("self._current_batch_finished_event not initialized")

assert self._current_batch_finished_event is not None
Collaborator: Isn't this redundant? Also, it must not use assert.

batch_event: asyncio.Event = self._current_batch_finished_event
self._current_batch_finished_event = None

@@ -252,9 +261,13 @@ async def _batch_get_embeddings(self, text: str) -> List[float]:

# We check if we reached the max batch size
if len(self._req_queue) >= self.max_batch_size:
if not self._current_batch_full_event:
Member: Again, this is redundant; self._current_batch_full_event cannot be None here per the earlier check and assertion.

raise Exception("self._current_batch_full_event not initialized")
self._current_batch_full_event.set()

# Wait for the batch to finish
if not self._current_batch_finished_event:
raise Exception("self._current_batch_finished_event not initialized")
await self._current_batch_finished_event.wait()

# Remove the result and return it
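Stepping back, the code above implements a micro-batching pattern: concurrent requests accumulate in a queue until either max_batch_hold seconds pass or max_batch_size requests arrive, then one batch runs and each waiter picks up its own result. A self-contained sketch of the idea, with a toy "embedding" in place of a real model call (simplified names, no error handling):

```python
import asyncio
from typing import Dict, List, Optional


class MicroBatcher:
    """Toy micro-batcher mirroring the max_batch_hold/max_batch_size scheme."""

    def __init__(self, max_batch_size: int = 10, max_batch_hold: float = 0.01):
        self.max_batch_size = max_batch_size
        self.max_batch_hold = max_batch_hold
        self._queue: Dict[int, str] = {}
        self._results: Dict[int, List[float]] = {}
        self._idx = 0
        self._full = asyncio.Event()
        self._finished = asyncio.Event()
        self._runner: Optional[asyncio.Task] = None

    async def _run_batch(self) -> None:
        # Wait until the batch fills up or the hold time expires.
        try:
            await asyncio.wait_for(self._full.wait(), self.max_batch_hold)
        except asyncio.TimeoutError:
            pass
        # Take ownership of the current batch and reset the events.
        batch, self._queue = self._queue, {}
        finished, self._finished = self._finished, asyncio.Event()
        self._full = asyncio.Event()
        # Toy "embedding": one float per text (its length).
        for idx, text in batch.items():
            self._results[idx] = [float(len(text))]
        finished.set()  # wake every request waiting on this batch

    async def embed(self, text: str) -> List[float]:
        if not self._queue:  # first request of a batch starts the runner
            self._runner = asyncio.ensure_future(self._run_batch())
        self._idx += 1
        idx = self._idx
        self._queue[idx] = text
        if len(self._queue) >= self.max_batch_size:
            self._full.set()  # batch is full; run it now
        finished = self._finished  # capture this batch's event
        await finished.wait()
        return self._results.pop(idx)


async def main() -> None:
    batcher = MicroBatcher()
    texts = ["a", "bb", "ccc"]
    print(await asyncio.gather(*(batcher.embed(t) for t in texts)))
    # -> [[1.0], [2.0], [3.0]]


asyncio.run(main())
```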
57 changes: 38 additions & 19 deletions nemoguardrails/embeddings/cache.py
@@ -20,7 +20,12 @@
from abc import ABC, abstractmethod
from functools import singledispatchmethod
from pathlib import Path
from typing import Dict, List
from typing import Dict, List, Optional

try:
import redis # type: ignore
except ImportError:
redis = None # type: ignore

from nemoguardrails.rails.llm.config import EmbeddingsCacheConfig

@@ -30,18 +35,20 @@
class KeyGenerator(ABC):
"""Abstract class for key generators."""

name: str # Class attribute that should be defined in subclasses

@abstractmethod
def generate_key(self, text: str) -> str:
pass

@classmethod
def from_name(cls, name):
for subclass in cls.__subclasses__():
if subclass.name == name:
if hasattr(subclass, "name") and subclass.name == name:
return subclass
raise ValueError(
f"Unknown {cls.__name__}: {name}. Available {cls.__name__}s are: "
f"{', '.join([subclass.name for subclass in cls.__subclasses__()])}"
f"{', '.join([subclass.name for subclass in cls.__subclasses__() if hasattr(subclass, 'name')])}"
". Make sure to import the derived class before using it."
)
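The from_name lookup above is a lightweight subclass registry: it scans cls.__subclasses__() for a matching name attribute, which is why the error message reminds you to import the derived class first. A short illustration with a hypothetical custom generator (the base class here is an abbreviated stand-in, not the real one):

```python
import zlib
from abc import ABC, abstractmethod


class KeyGenerator(ABC):  # abbreviated stand-in for the real base class
    name: str

    @abstractmethod
    def generate_key(self, text: str) -> str:
        ...

    @classmethod
    def from_name(cls, name):
        for subclass in cls.__subclasses__():
            if hasattr(subclass, "name") and subclass.name == name:
                return subclass
        raise ValueError(f"Unknown KeyGenerator: {name}")


class Crc32KeyGenerator(KeyGenerator):
    """Hypothetical custom generator; merely defining it registers it."""

    name = "crc32"

    def generate_key(self, text: str) -> str:
        return format(zlib.crc32(text.encode("utf-8")), "08x")


gen = KeyGenerator.from_name("crc32")()  # works once the subclass is imported
print(gen.generate_key("hello"))  # 3610a686
```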

@@ -76,6 +83,8 @@ def generate_key(self, text: str) -> str:
class CacheStore(ABC):
"""Abstract class for cache stores."""

name: str # Class attribute that should be defined in subclasses

@abstractmethod
def get(self, key):
"""Get a value from the cache."""
@@ -94,11 +103,11 @@ def clear(self):
@classmethod
def from_name(cls, name):
for subclass in cls.__subclasses__():
if subclass.name == name:
if hasattr(subclass, "name") and subclass.name == name:
return subclass
raise ValueError(
f"Unknown {cls.__name__}: {name}. Available {cls.__name__}s are: "
f"{', '.join([subclass.name for subclass in cls.__subclasses__()])}"
f"{', '.join([subclass.name for subclass in cls.__subclasses__() if hasattr(subclass, 'name')])}"
". Make sure to import the derived class before using it."
)

@@ -147,7 +156,7 @@ class FilesystemCacheStore(CacheStore):

name = "filesystem"

def __init__(self, cache_dir: str = None):
def __init__(self, cache_dir: Optional[str] = None):
self._cache_dir = Path(cache_dir or ".cache/embeddings")
self._cache_dir.mkdir(parents=True, exist_ok=True)

@@ -190,8 +199,10 @@ class RedisCacheStore(CacheStore):
name = "redis"

def __init__(self, host: str = "localhost", port: int = 6379, db: int = 0):
import redis

if redis is None:
raise ImportError(
"Could not import redis, please install it with `pip install redis`."
)
self._redis = redis.Redis(host=host, port=port, db=db)

def get(self, key):
@@ -207,9 +218,9 @@ def clear(self):
class EmbeddingsCache:
def __init__(
self,
key_generator: KeyGenerator = None,
cache_store: CacheStore = None,
store_config: dict = None,
key_generator: Optional[KeyGenerator] = None,
cache_store: Optional[CacheStore] = None,
store_config: Optional[dict] = None,
):
self._key_generator = key_generator
self._cache_store = cache_store
@@ -218,7 +229,10 @@ def __init__(
@classmethod
def from_dict(cls, d: Dict[str, str]):
key_generator = KeyGenerator.from_name(d.get("key_generator"))()
store_config = d.get("store_config")
store_config_raw = d.get("store_config")
store_config: dict = (
store_config_raw if isinstance(store_config_raw, dict) else {}
)
cache_store = CacheStore.from_name(d.get("store"))(**store_config)

return cls(key_generator=key_generator, cache_store=cache_store)
@@ -230,25 +244,27 @@ def from_config(cls, config: EmbeddingsCacheConfig):

def get_config(self):
return EmbeddingsCacheConfig(
key_generator=self._key_generator.name,
store=self._cache_store.name,
key_generator=self._key_generator.name if self._key_generator else "sha256",
Collaborator: I find defining defaults here problematic. What did Pyright not like about it?

Member: There is already a default here in EmbeddingsCacheConfig:

```python
key_generator: str = Field(
    default="sha256",
    description="The method to use for generating the cache keys.",
)
```

store=self._cache_store.name if self._cache_store else "filesystem",
store_config=self._store_config,
)

@singledispatchmethod
def get(self, texts):
raise NotImplementedError

@get.register
@get.register(str)
def _(self, text: str):
if self._key_generator is None or self._cache_store is None:
return None
key = self._key_generator.generate_key(text)
log.info(f"Fetching key {key} for text '{text[:20]}...' from cache")

result = self._cache_store.get(key)

return result

@get.register
@get.register(list)
def _(self, texts: list):
cached = {}

@@ -266,19 +282,22 @@ def _(self, texts: list):
def set(self, texts):
raise NotImplementedError

@set.register
@set.register(str)
def _(self, text: str, value: List[float]):
if self._key_generator is None or self._cache_store is None:
return
key = self._key_generator.generate_key(text)
log.info(f"Cache miss for text '{text}'. Storing key {key} in cache.")
self._cache_store.set(key, value)

@set.register
@set.register(list)
def _(self, texts: list, values: List[List[float]]):
for text, value in zip(texts, values):
self.set(text, value)

def clear(self):
self._cache_store.clear()
if self._cache_store is not None:
self._cache_store.clear()


def cache_embeddings(func):
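One more note on this file: switching from bare @get.register to @get.register(str) and @get.register(list) makes the dispatch type explicit rather than inferred from the parameter annotation, which is friendlier to static analysis and to string annotations. A minimal sketch of the pattern:

```python
from functools import singledispatchmethod


class CacheSketch:
    """Illustration of explicit singledispatchmethod registration."""

    @singledispatchmethod
    def get(self, texts):
        raise NotImplementedError

    @get.register(str)  # explicit type: no reliance on the annotation
    def _(self, text: str):
        return f"lookup:{text}"

    @get.register(list)
    def _(self, texts: list):
        return [self.get(t) for t in texts]  # re-dispatches per element


c = CacheSketch()
print(c.get("a"))         # lookup:a
print(c.get(["a", "b"]))  # ['lookup:a', 'lookup:b']
```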
2 changes: 1 addition & 1 deletion nemoguardrails/embeddings/providers/fastembed.py
@@ -42,7 +42,7 @@ class FastEmbedEmbeddingModel(EmbeddingModel):
engine_name = "FastEmbed"

def __init__(self, embedding_model: str, **kwargs):
from fastembed import TextEmbedding as Embedding
from fastembed import TextEmbedding as Embedding # type: ignore

# Enabling a short form model name for all-MiniLM-L6-v2.
if embedding_model == "all-MiniLM-L6-v2":
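For context on the hunk above: the short-form mapping means both spellings should load the same model. A hedged usage sketch (assumes fastembed is installed and that the mapping expands the short name as the code comment suggests):

```python
# Assumes `pip install fastembed`.
from nemoguardrails.embeddings.providers.fastembed import FastEmbedEmbeddingModel

# Short form and full form should resolve to the same underlying model.
short = FastEmbedEmbeddingModel("all-MiniLM-L6-v2")
full = FastEmbedEmbeddingModel("sentence-transformers/all-MiniLM-L6-v2")
```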
6 changes: 3 additions & 3 deletions nemoguardrails/embeddings/providers/openai.py
@@ -46,14 +46,14 @@ def __init__(
**kwargs,
):
try:
import openai
from openai import AsyncOpenAI, OpenAI
import openai # type: ignore
from openai import AsyncOpenAI, OpenAI # type: ignore
except ImportError:
raise ImportError(
"Could not import openai, please install it with "
"`pip install openai`."
)
if openai.__version__ < "1.0.0":
if openai.__version__ < "1.0.0": # type: ignore
raise RuntimeError(
"`openai<1.0.0` is no longer supported. "
"Please upgrade using `pip install openai>=1.0.0`."
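A side note on the openai.__version__ < "1.0.0" guard kept by this hunk: plain string comparison is lexicographic, so for example "1.10.0" < "1.9.0" is True and pre-release strings compare oddly. If stricter checking were ever wanted, a sketch using the packaging library (an assumed extra dependency, not something this PR adds):

```python
from packaging.version import Version  # assumes `pip install packaging`

import openai  # assumes openai is installed

# Semantic version comparison instead of lexicographic string comparison.
if Version(openai.__version__) < Version("1.0.0"):
    raise RuntimeError(
        "`openai<1.0.0` is no longer supported. "
        "Please upgrade using `pip install openai>=1.0.0`."
    )
```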
4 changes: 2 additions & 2 deletions nemoguardrails/embeddings/providers/sentence_transformers.py
@@ -43,15 +43,15 @@ class SentenceTransformerEmbeddingModel(EmbeddingModel):

def __init__(self, embedding_model: str, **kwargs):
try:
from sentence_transformers import SentenceTransformer
from sentence_transformers import SentenceTransformer # type: ignore
except ImportError:
raise ImportError(
"Could not import sentence-transformers, please install it with "
"`pip install sentence-transformers`."
)

try:
from torch import cuda
from torch import cuda # type: ignore
except ImportError:
raise ImportError(
"Could not import torch, please install it with `pip install torch`."