-
Notifications
You must be signed in to change notification settings - Fork 548
chore(types): Type-clean embeddings/ (25 errors) #1383
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,7 +17,7 @@ | |
import logging | ||
from typing import Any, Dict, List, Optional, Union | ||
|
||
from annoy import AnnoyIndex | ||
from annoy import AnnoyIndex # type: ignore | ||
|
||
from nemoguardrails.embeddings.cache import cache_embeddings | ||
from nemoguardrails.embeddings.index import EmbeddingsIndex, IndexItem | ||
|
@@ -45,26 +45,16 @@ class BasicEmbeddingsIndex(EmbeddingsIndex): | |
max_batch_hold: The maximum time a batch is held before being processed | ||
""" | ||
|
||
embedding_model: str | ||
embedding_engine: str | ||
embedding_params: Dict[str, Any] | ||
index: AnnoyIndex | ||
embedding_size: int | ||
cache_config: EmbeddingsCacheConfig | ||
embeddings: List[List[float]] | ||
search_threshold: float | ||
use_batching: bool | ||
max_batch_size: int | ||
max_batch_hold: float | ||
# Instance attributes are defined in __init__ and accessed via properties | ||
|
||
def __init__( | ||
self, | ||
embedding_model=None, | ||
embedding_engine=None, | ||
embedding_params=None, | ||
index=None, | ||
cache_config: Union[EmbeddingsCacheConfig, Dict[str, Any]] = None, | ||
search_threshold: float = None, | ||
embedding_model: Optional[str] = None, | ||
embedding_engine: Optional[str] = None, | ||
embedding_params: Optional[Dict[str, Any]] = None, | ||
index: Optional[AnnoyIndex] = None, | ||
cache_config: Optional[Union[EmbeddingsCacheConfig, Dict[str, Any]]] = None, | ||
search_threshold: Optional[float] = None, | ||
use_batching: bool = False, | ||
max_batch_size: int = 10, | ||
max_batch_hold: float = 0.01, | ||
|
@@ -81,10 +71,10 @@ def __init__( | |
max_batch_hold: The maximum time a batch is held before being processed | ||
""" | ||
self._model: Optional[EmbeddingModel] = None | ||
self._items = [] | ||
self._embeddings = [] | ||
self.embedding_model = embedding_model | ||
self.embedding_engine = embedding_engine | ||
self._items: List[IndexItem] = [] | ||
self._embeddings: List[List[float]] = [] | ||
self.embedding_model: Optional[str] = embedding_model | ||
self.embedding_engine: Optional[str] = embedding_engine | ||
self.embedding_params = embedding_params or {} | ||
self._embedding_size = 0 | ||
self.search_threshold = search_threshold or float("inf") | ||
|
@@ -95,12 +85,12 @@ def __init__( | |
self._index = index | ||
|
||
# Data structures for batching embedding requests | ||
self._req_queue = {} | ||
self._req_results = {} | ||
self._req_idx = 0 | ||
self._current_batch_finished_event = None | ||
self._current_batch_full_event = None | ||
self._current_batch_submitted = asyncio.Event() | ||
self._req_queue: Dict[int, str] = {} | ||
self._req_results: Dict[int, List[float]] = {} | ||
self._req_idx: int = 0 | ||
self._current_batch_finished_event: Optional[asyncio.Event] = None | ||
self._current_batch_full_event: Optional[asyncio.Event] = None | ||
self._current_batch_submitted: asyncio.Event = asyncio.Event() | ||
|
||
# Initialize the batching configuration | ||
self.use_batching = use_batching | ||
|
@@ -112,6 +102,11 @@ def embeddings_index(self): | |
"""Get the current embedding index""" | ||
return self._index | ||
|
||
@embeddings_index.setter | ||
def embeddings_index(self, index): | ||
"""Setter to allow replacing the index dynamically.""" | ||
self._index = index | ||
|
||
@property | ||
def cache_config(self): | ||
"""Get the cache configuration.""" | ||
|
@@ -127,19 +122,23 @@ def embeddings(self): | |
"""Get the computed embeddings.""" | ||
return self._embeddings | ||
|
||
@embeddings_index.setter | ||
def embeddings_index(self, index): | ||
"""Setter to allow replacing the index dynamically.""" | ||
self._index = index | ||
|
||
def _init_model(self): | ||
"""Initialize the model used for computing the embeddings.""" | ||
# Provide defaults if not specified | ||
model = self.embedding_model or "sentence-transformers/all-MiniLM-L6-v2" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We have some defaults in On my side, there are some type checking errors that might happen (for example, having |
||
engine = self.embedding_engine or "SentenceTransformers" | ||
|
||
self._model = init_embedding_model( | ||
embedding_model=self.embedding_model, | ||
embedding_engine=self.embedding_engine, | ||
embedding_model=model, | ||
embedding_engine=engine, | ||
embedding_params=self.embedding_params, | ||
) | ||
|
||
if not self._model: | ||
raise ValueError( | ||
f"Couldn't create embedding model with model {model} and engine {engine}" | ||
) | ||
|
||
@cache_embeddings | ||
async def _get_embeddings(self, texts: List[str]) -> List[List[float]]: | ||
"""Compute embeddings for a list of texts. | ||
|
@@ -153,6 +152,8 @@ async def _get_embeddings(self, texts: List[str]) -> List[List[float]]: | |
if self._model is None: | ||
self._init_model() | ||
|
||
if not self._model: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We are already throwing an |
||
raise Exception("Couldn't initialize embedding model") | ||
embeddings = await self._model.encode_async(texts) | ||
return embeddings | ||
|
||
|
@@ -199,6 +200,10 @@ async def _run_batch(self): | |
"""Runs the current batch of embeddings.""" | ||
|
||
# Wait up to `max_batch_hold` time or until `max_batch_size` is reached. | ||
if not self._current_batch_full_event: | ||
raise Exception("self._current_batch_full_event not initialized") | ||
|
||
assert self._current_batch_full_event is not None | ||
done, pending = await asyncio.wait( | ||
[ | ||
asyncio.create_task(asyncio.sleep(self.max_batch_hold)), | ||
|
@@ -210,6 +215,10 @@ async def _run_batch(self): | |
task.cancel() | ||
|
||
# Reset the batch event | ||
if not self._current_batch_finished_event: | ||
raise Exception("self._current_batch_finished_event not initialized") | ||
|
||
assert self._current_batch_finished_event is not None | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. isn't this redundant? also it must not use assert. |
||
batch_event: asyncio.Event = self._current_batch_finished_event | ||
self._current_batch_finished_event = None | ||
|
||
|
@@ -252,9 +261,13 @@ async def _batch_get_embeddings(self, text: str) -> List[float]: | |
|
||
# We check if we reached the max batch size | ||
if len(self._req_queue) >= self.max_batch_size: | ||
if not self._current_batch_full_event: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Again this is redundant, |
||
raise Exception("self._current_batch_full_event not initialized") | ||
self._current_batch_full_event.set() | ||
|
||
# Wait for the batch to finish | ||
# Wait for the batch to finish | ||
if not self._current_batch_finished_event: | ||
raise Exception("self._current_batch_finished_event not initialized") | ||
await self._current_batch_finished_event.wait() | ||
|
||
# Remove the result and return it | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,7 +20,12 @@ | |
from abc import ABC, abstractmethod | ||
from functools import singledispatchmethod | ||
from pathlib import Path | ||
from typing import Dict, List | ||
from typing import Dict, List, Optional | ||
|
||
try: | ||
import redis # type: ignore | ||
except ImportError: | ||
redis = None # type: ignore | ||
|
||
from nemoguardrails.rails.llm.config import EmbeddingsCacheConfig | ||
|
||
|
@@ -30,18 +35,20 @@ | |
class KeyGenerator(ABC): | ||
"""Abstract class for key generators.""" | ||
|
||
name: str # Class attribute that should be defined in subclasses | ||
|
||
@abstractmethod | ||
def generate_key(self, text: str) -> str: | ||
pass | ||
|
||
@classmethod | ||
def from_name(cls, name): | ||
for subclass in cls.__subclasses__(): | ||
if subclass.name == name: | ||
if hasattr(subclass, "name") and subclass.name == name: | ||
return subclass | ||
raise ValueError( | ||
f"Unknown {cls.__name__}: {name}. Available {cls.__name__}s are: " | ||
f"{', '.join([subclass.name for subclass in cls.__subclasses__()])}" | ||
f"{', '.join([subclass.name for subclass in cls.__subclasses__() if hasattr(subclass, 'name')])}" | ||
". Make sure to import the derived class before using it." | ||
) | ||
|
||
|
@@ -76,6 +83,8 @@ def generate_key(self, text: str) -> str: | |
class CacheStore(ABC): | ||
"""Abstract class for cache stores.""" | ||
|
||
name: str # Class attribute that should be defined in subclasses | ||
|
||
@abstractmethod | ||
def get(self, key): | ||
"""Get a value from the cache.""" | ||
|
@@ -94,11 +103,11 @@ def clear(self): | |
@classmethod | ||
def from_name(cls, name): | ||
for subclass in cls.__subclasses__(): | ||
if subclass.name == name: | ||
if hasattr(subclass, "name") and subclass.name == name: | ||
return subclass | ||
raise ValueError( | ||
f"Unknown {cls.__name__}: {name}. Available {cls.__name__}s are: " | ||
f"{', '.join([subclass.name for subclass in cls.__subclasses__()])}" | ||
f"{', '.join([subclass.name for subclass in cls.__subclasses__() if hasattr(subclass, 'name')])}" | ||
". Make sure to import the derived class before using it." | ||
) | ||
|
||
|
@@ -147,7 +156,7 @@ class FilesystemCacheStore(CacheStore): | |
|
||
name = "filesystem" | ||
|
||
def __init__(self, cache_dir: str = None): | ||
def __init__(self, cache_dir: Optional[str] = None): | ||
self._cache_dir = Path(cache_dir or ".cache/embeddings") | ||
self._cache_dir.mkdir(parents=True, exist_ok=True) | ||
|
||
|
@@ -190,8 +199,10 @@ class RedisCacheStore(CacheStore): | |
name = "redis" | ||
|
||
def __init__(self, host: str = "localhost", port: int = 6379, db: int = 0): | ||
import redis | ||
|
||
if redis is None: | ||
raise ImportError( | ||
"Could not import redis, please install it with `pip install redis`." | ||
) | ||
self._redis = redis.Redis(host=host, port=port, db=db) | ||
|
||
def get(self, key): | ||
|
@@ -207,9 +218,9 @@ def clear(self): | |
class EmbeddingsCache: | ||
def __init__( | ||
self, | ||
key_generator: KeyGenerator = None, | ||
cache_store: CacheStore = None, | ||
store_config: dict = None, | ||
key_generator: Optional[KeyGenerator] = None, | ||
cache_store: Optional[CacheStore] = None, | ||
store_config: Optional[dict] = None, | ||
): | ||
self._key_generator = key_generator | ||
self._cache_store = cache_store | ||
|
@@ -218,7 +229,10 @@ def __init__( | |
@classmethod | ||
def from_dict(cls, d: Dict[str, str]): | ||
key_generator = KeyGenerator.from_name(d.get("key_generator"))() | ||
store_config = d.get("store_config") | ||
store_config_raw = d.get("store_config") | ||
store_config: dict = ( | ||
store_config_raw if isinstance(store_config_raw, dict) else {} | ||
) | ||
cache_store = CacheStore.from_name(d.get("store"))(**store_config) | ||
|
||
return cls(key_generator=key_generator, cache_store=cache_store) | ||
|
@@ -230,25 +244,27 @@ def from_config(cls, config: EmbeddingsCacheConfig): | |
|
||
def get_config(self): | ||
return EmbeddingsCacheConfig( | ||
key_generator=self._key_generator.name, | ||
store=self._cache_store.name, | ||
key_generator=self._key_generator.name if self._key_generator else "sha256", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. i find defining defaults here problematic. What Pyright did not like about it? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There already is a default here in
|
||
store=self._cache_store.name if self._cache_store else "filesystem", | ||
store_config=self._store_config, | ||
) | ||
|
||
@singledispatchmethod | ||
def get(self, texts): | ||
raise NotImplementedError | ||
|
||
@get.register | ||
@get.register(str) | ||
def _(self, text: str): | ||
if self._key_generator is None or self._cache_store is None: | ||
return None | ||
key = self._key_generator.generate_key(text) | ||
log.info(f"Fetching key {key} for text '{text[:20]}...' from cache") | ||
|
||
result = self._cache_store.get(key) | ||
|
||
return result | ||
|
||
@get.register | ||
@get.register(list) | ||
def _(self, texts: list): | ||
cached = {} | ||
|
||
|
@@ -266,19 +282,22 @@ def _(self, texts: list): | |
def set(self, texts): | ||
raise NotImplementedError | ||
|
||
@set.register | ||
@set.register(str) | ||
def _(self, text: str, value: List[float]): | ||
if self._key_generator is None or self._cache_store is None: | ||
return | ||
key = self._key_generator.generate_key(text) | ||
log.info(f"Cache miss for text '{text}'. Storing key {key} in cache.") | ||
self._cache_store.set(key, value) | ||
|
||
@set.register | ||
@set.register(list) | ||
def _(self, texts: list, values: List[List[float]]): | ||
for text, value in zip(texts, values): | ||
self.set(text, value) | ||
|
||
def clear(self): | ||
self._cache_store.clear() | ||
if self._cache_store is not None: | ||
self._cache_store.clear() | ||
|
||
|
||
def cache_embeddings(func): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can we move the defaults to constructor? at line 52