Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 18 additions & 5 deletions autogen/cache/abstract_cache_base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,12 @@
from abc import ABC, abstractmethod
from types import TracebackType
from typing import Any, Optional, Type
import sys

if sys.version_info >= (3, 11):
from typing import Self
else:
from typing_extensions import Self


class AbstractCache(ABC):
Expand All @@ -11,7 +19,7 @@ class AbstractCache(ABC):
"""

@abstractmethod
def get(self, key, default=None):
def get(self, key: str, default: Optional[Any] = None) -> Optional[Any]:
"""
Retrieve an item from the cache.

Expand All @@ -31,7 +39,7 @@ def get(self, key, default=None):
"""

@abstractmethod
def set(self, key, value):
def set(self, key: str, value: Any) -> None:
"""
Set an item in the cache.

Expand All @@ -47,7 +55,7 @@ def set(self, key, value):
"""

@abstractmethod
def close(self):
def close(self) -> None:
"""
Close the cache.

Expand All @@ -60,7 +68,7 @@ def close(self):
"""

@abstractmethod
def __enter__(self):
def __enter__(self) -> Self:
"""
Enter the runtime context related to this object.

Expand All @@ -72,7 +80,12 @@ def __enter__(self):
"""

@abstractmethod
def __exit__(self, exc_type, exc_value, traceback):
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
"""
Exit the runtime context and close the cache.

Expand Down
38 changes: 27 additions & 11 deletions autogen/cache/cache.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,17 @@
from typing import Dict, Any
from __future__ import annotations
from types import TracebackType
from typing import Dict, Any, Optional, Type, Union

from autogen.cache.cache_factory import CacheFactory
from .abstract_cache_base import AbstractCache

from .cache_factory import CacheFactory

import sys

if sys.version_info >= (3, 11):
from typing import Self
else:
from typing_extensions import Self


class Cache:
Expand All @@ -19,12 +30,12 @@ class Cache:
ALLOWED_CONFIG_KEYS = ["cache_seed", "redis_url", "cache_path_root"]

@staticmethod
def redis(cache_seed=42, redis_url="redis://localhost:6379/0"):
def redis(cache_seed: Union[str, int] = 42, redis_url: str = "redis://localhost:6379/0") -> Cache:
"""
Create a Redis cache instance.

Args:
cache_seed (int, optional): A seed for the cache. Defaults to 42.
cache_seed (Union[str, int], optional): A seed for the cache. Defaults to 42.
redis_url (str, optional): The URL for the Redis server. Defaults to "redis://localhost:6379/0".

Returns:
Expand All @@ -33,12 +44,12 @@ def redis(cache_seed=42, redis_url="redis://localhost:6379/0"):
return Cache({"cache_seed": cache_seed, "redis_url": redis_url})

@staticmethod
def disk(cache_seed=42, cache_path_root=".cache"):
def disk(cache_seed: Union[str, int] = 42, cache_path_root: str = ".cache") -> Cache:
"""
Create a Disk cache instance.

Args:
cache_seed (int, optional): A seed for the cache. Defaults to 42.
cache_seed (Union[str, int], optional): A seed for the cache. Defaults to 42.
cache_path_root (str, optional): The root path for the disk cache. Defaults to ".cache".

Returns:
Expand Down Expand Up @@ -70,7 +81,7 @@ def __init__(self, config: Dict[str, Any]):
self.config.get("cache_path_root", None),
)

def __enter__(self):
def __enter__(self) -> AbstractCache:
"""
Enter the runtime context related to the cache object.

Expand All @@ -79,7 +90,12 @@ def __enter__(self):
"""
return self.cache.__enter__()

def __exit__(self, exc_type, exc_value, traceback):
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
"""
Exit the runtime context related to the cache object.

Expand All @@ -93,7 +109,7 @@ def __exit__(self, exc_type, exc_value, traceback):
"""
return self.cache.__exit__(exc_type, exc_value, traceback)

def get(self, key, default=None):
def get(self, key: str, default: Optional[Any] = None) -> Optional[Any]:
"""
Retrieve an item from the cache.

Expand All @@ -107,7 +123,7 @@ def get(self, key, default=None):
"""
return self.cache.get(key, default)

def set(self, key, value):
def set(self, key: str, value: Any) -> None:
"""
Set an item in the cache.

Expand All @@ -117,7 +133,7 @@ def set(self, key, value):
"""
self.cache.set(key, value)

def close(self):
def close(self) -> None:
"""
Close the cache.

Expand Down
24 changes: 14 additions & 10 deletions autogen/cache/cache_factory.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
from autogen.cache.disk_cache import DiskCache

try:
from autogen.cache.redis_cache import RedisCache
except ImportError:
RedisCache = None
from typing import Optional, Union, Type
from .abstract_cache_base import AbstractCache
from .disk_cache import DiskCache


class CacheFactory:
@staticmethod
def cache_factory(seed, redis_url=None, cache_path_root=".cache"):
def cache_factory(
seed: Union[str, int], redis_url: Optional[str] = None, cache_path_root: str = ".cache"
) -> AbstractCache:
"""
Factory function for creating cache instances.

Expand All @@ -17,7 +16,7 @@ def cache_factory(seed, redis_url=None, cache_path_root=".cache"):
a RedisCache instance is created. Otherwise, a DiskCache instance is used.

Args:
seed (str): A string used as a seed or namespace for the cache.
seed (Union[str, int]): A string or int used as a seed or namespace for the cache.
This could be useful for creating distinct cache instances
or for namespacing keys in the cache.
redis_url (str or None): The URL for the Redis server. If this is None
Expand All @@ -40,7 +39,12 @@ def cache_factory(seed, redis_url=None, cache_path_root=".cache"):
disk_cache = cache_factory("myseed", None)
```
"""
if RedisCache is not None and redis_url is not None:
return RedisCache(seed, redis_url)
if redis_url is not None:
try:
from .redis_cache import RedisCache

return RedisCache(seed, redis_url)
except ImportError:
return DiskCache(f"./{cache_path_root}/{seed}")
else:
return DiskCache(f"./{cache_path_root}/{seed}")
27 changes: 20 additions & 7 deletions autogen/cache/disk_cache.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
from types import TracebackType
from typing import Any, Optional, Type, Union
import diskcache
from .abstract_cache_base import AbstractCache
import sys

if sys.version_info >= (3, 11):
from typing import Self
else:
from typing_extensions import Self


class DiskCache(AbstractCache):
Expand All @@ -21,18 +29,18 @@ class DiskCache(AbstractCache):
__exit__(self, exc_type, exc_value, traceback): Context management exit.
"""

def __init__(self, seed):
def __init__(self, seed: Union[str, int]):
"""
Initialize the DiskCache instance.

Args:
seed (str): A seed or namespace for the cache. This is used to create
seed (Union[str, int]): A seed or namespace for the cache. This is used to create
a unique storage location for the cache data.

"""
self.cache = diskcache.Cache(seed)

def get(self, key, default=None):
def get(self, key: str, default: Optional[Any] = None) -> Optional[Any]:
"""
Retrieve an item from the cache.

Expand All @@ -46,7 +54,7 @@ def get(self, key, default=None):
"""
return self.cache.get(key, default)

def set(self, key, value):
def set(self, key: str, value: Any) -> None:
"""
Set an item in the cache.

Expand All @@ -56,7 +64,7 @@ def set(self, key, value):
"""
self.cache.set(key, value)

def close(self):
def close(self) -> None:
"""
Close the cache.

Expand All @@ -65,7 +73,7 @@ def close(self):
"""
self.cache.close()

def __enter__(self):
def __enter__(self) -> Self:
"""
Enter the runtime context related to the object.

Expand All @@ -74,7 +82,12 @@ def __enter__(self):
"""
return self

def __exit__(self, exc_type, exc_value, traceback):
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
"""
Exit the runtime context related to the object.

Expand Down
10 changes: 5 additions & 5 deletions autogen/cache/redis_cache.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pickle
from types import TracebackType
from typing import Any, Optional, Type
from typing import Any, Optional, Type, Union
import redis
import sys
from .abstract_cache_base import AbstractCache
Expand All @@ -19,7 +19,7 @@ class RedisCache(AbstractCache):
interface using the Redis database for caching data.

Attributes:
seed (str): A seed or namespace used as a prefix for cache keys.
seed (Union[str, int]): A seed or namespace used as a prefix for cache keys.
cache (redis.Redis): The Redis client used for caching.

Methods:
Expand All @@ -32,12 +32,12 @@ class RedisCache(AbstractCache):
__exit__(self, exc_type, exc_value, traceback): Context management exit.
"""

def __init__(self, seed: str, redis_url: str):
def __init__(self, seed: Union[str, int], redis_url: str):
"""
Initialize the RedisCache instance.

Args:
seed (str): A seed or namespace for the cache. This is used as a prefix for all cache keys.
seed (Union[str, int]): A seed or namespace for the cache. This is used as a prefix for all cache keys.
redis_url (str): The URL for the Redis server.

"""
Expand All @@ -56,7 +56,7 @@ def _prefixed_key(self, key: str) -> str:
"""
return f"autogen:{self.seed}:{key}"

def get(self, key: str, default: Optional[Any] = None) -> Any:
def get(self, key: str, default: Optional[Any] = None) -> Optional[Any]:
"""
Retrieve an item from the Redis cache.

Expand Down