2 changes: 1 addition & 1 deletion nemoguardrails/actions/llm/generation.py
@@ -137,7 +137,7 @@ async def init(self):
self._init_flows_index(),
)

def _extract_user_message_example(self, flow: Flow):
def _extract_user_message_example(self, flow: Flow) -> None:
"""Heuristic to extract user message examples from a flow."""
elements = [
item
105 changes: 70 additions & 35 deletions nemoguardrails/server/api.py
@@ -22,7 +22,7 @@
import time
import warnings
from contextlib import asynccontextmanager
from typing import Any, List, Optional
from typing import Any, Callable, List, Optional

from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
@@ -42,14 +42,32 @@
logging.basicConfig(level=logging.INFO)
log = logging.getLogger(__name__)


class GuardrailsApp(FastAPI):
"""Custom FastAPI subclass with additional attributes for Guardrails server."""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Initialize custom attributes
self.default_config_id: Optional[str] = None
self.rails_config_path: str = ""
self.disable_chat_ui: bool = False
self.auto_reload: bool = False
self.stop_signal: bool = False
self.single_config_mode: bool = False
self.single_config_id: Optional[str] = None
self.loop: Optional[asyncio.AbstractEventLoop] = None
self.task: Optional[asyncio.Future] = None
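
The subclass above replaces the earlier pattern of assigning server state onto a plain `FastAPI` instance at runtime. A minimal sketch of the difference, assuming `GuardrailsApp` as defined in this file (the config path below is illustrative):

from fastapi import FastAPI

plain_app = FastAPI()
plain_app.rails_config_path = "./config"  # ad-hoc attribute, invisible to type checkers

typed_app = GuardrailsApp()               # attributes declared up front in __init__
assert typed_app.auto_reload is False     # safe default, no hasattr() checks needed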


# The list of registered loggers. Can be used to send logs to various
# backends and storage engines.
registered_loggers = []
registered_loggers: List[Callable] = []

api_description = """Guardrails Server API."""

# The headers for each request
api_request_headers = contextvars.ContextVar("headers")
api_request_headers: contextvars.ContextVar = contextvars.ContextVar("headers")

# The datastore that the Server should use.
# This is currently used only for storing threads.
@@ -59,7 +77,7 @@


@asynccontextmanager
async def lifespan(app: FastAPI):
async def lifespan(app: GuardrailsApp):
# Startup logic here
"""Register any additional challenges, if available at startup."""
challenges_files = os.path.join(app.rails_config_path, "challenges.json")
@@ -82,8 +100,11 @@ async def lifespan(app: FastAPI):
if os.path.exists(filepath):
filename = os.path.basename(filepath)
spec = importlib.util.spec_from_file_location(filename, filepath)
config_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(config_module)
if spec is not None and spec.loader is not None:
config_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(config_module)
else:
config_module = None

# If there is an `init` function, we call it with the reference to the app.
if config_module is not None and hasattr(config_module, "init"):
@@ -110,21 +131,22 @@ async def root_handler():

if app.auto_reload:
app.loop = asyncio.get_running_loop()
# Store the future directly as task
app.task = app.loop.run_in_executor(None, start_auto_reload_monitoring)

yield

# Shutdown logic here
if app.auto_reload:
app.stop_signal = True
if hasattr(app, "task"):
if hasattr(app, "task") and app.task is not None:
app.task.cancel()
log.info("Shutting down file observer")
else:
pass


app = FastAPI(
app = GuardrailsApp(
title="Guardrails Server API",
description=api_description,
version="0.1.0",
@@ -186,7 +208,7 @@ class RequestBody(BaseModel):
max_length=255,
description="The id of an existing thread to which the messages should be added.",
)
messages: List[dict] = Field(
messages: Optional[List[dict]] = Field(
default=None, description="The list of messages in the current conversation."
)
context: Optional[dict] = Field(
@@ -232,7 +254,7 @@ def ensure_config_ids(cls, v, values):


class ResponseBody(BaseModel):
messages: List[dict] = Field(
messages: Optional[List[dict]] = Field(
default=None, description="The new messages in the conversation"
)
llm_output: Optional[dict] = Field(
@@ -282,8 +304,8 @@ async def get_rails_configs():


# One instance of LLMRails per config id
llm_rails_instances = {}
llm_rails_events_history_cache = {}
llm_rails_instances: dict[str, LLMRails] = {}
llm_rails_events_history_cache: dict[str, dict] = {}


def _generate_cache_key(config_ids: List[str]) -> str:
@@ -310,7 +332,7 @@ def _get_rails(config_ids: List[str]) -> LLMRails:
# get the same thing.
config_ids = [""]

full_llm_rails_config = None
full_llm_rails_config: Optional[RailsConfig] = None

for config_id in config_ids:
base_path = os.path.abspath(app.rails_config_path)
@@ -330,6 +352,9 @@ def _get_rails(config_ids: List[str]) -> LLMRails:
else:
full_llm_rails_config += rails_config

if full_llm_rails_config is None:
raise ValueError("No valid rails configuration found.")

llm_rails = LLMRails(config=full_llm_rails_config, verbose=True)
llm_rails_instances[configs_cache_key] = llm_rails

@@ -368,22 +393,27 @@ async def chat_completion(body: RequestBody, request: Request):
"No 'config_id' provided and no default configuration is set for the server. "
"You must set a 'config_id' in your request or set use --default-config-id when . "
)

# Ensure config_ids is not None before passing to _get_rails
if config_ids is None:
raise GuardrailsConfigurationError("No valid configuration IDs available.")

try:
llm_rails = _get_rails(config_ids)
except ValueError as ex:
log.exception(ex)
return {
"messages": [
return ResponseBody(
messages=[
{
"role": "assistant",
"content": f"Could not load the {config_ids} guardrails configuration. "
f"An internal error has occurred.",
}
]
}
)

try:
messages = body.messages
messages = body.messages or []
if body.context:
messages.insert(0, {"role": "context", "content": body.context})

@@ -396,14 +426,14 @@

# We make sure the `thread_id` meets the minimum complexity requirement.
if len(body.thread_id) < 16:
return {
"messages": [
return ResponseBody(
messages=[
{
"role": "assistant",
"content": "The `thread_id` must have a minimum length of 16 characters.",
}
]
}
)

# Fetch the existing thread messages. For easier management, we prepend
# the string `thread-` to all thread keys.
@@ -440,32 +470,37 @@ async def chat_completion(body: RequestBody, request: Request):
)

if isinstance(res, GenerationResponse):
bot_message = res.response[0]
bot_message_content = res.response[0]
# Ensure bot_message is always a dict
if isinstance(bot_message_content, str):
bot_message = {"role": "assistant", "content": bot_message_content}
else:
bot_message = bot_message_content
else:
assert isinstance(res, dict)
bot_message = res

# If we're using threads, we also need to update the data before returning
# the message.
if body.thread_id:
if body.thread_id and datastore is not None and datastore_key is not None:
await datastore.set(datastore_key, json.dumps(messages + [bot_message]))

result = {"messages": [bot_message]}
result = ResponseBody(messages=[bot_message])

# If we have additional GenerationResponse fields, we return them as well
if isinstance(res, GenerationResponse):
result["llm_output"] = res.llm_output
result["output_data"] = res.output_data
result["log"] = res.log
result["state"] = res.state
result.llm_output = res.llm_output
result.output_data = res.output_data
result.log = res.log
result.state = res.state

return result

except Exception as ex:
log.exception(ex)
return {
"messages": [{"role": "assistant", "content": "Internal server error."}]
}
return ResponseBody(
messages=[{"role": "assistant", "content": "Internal server error."}]
)
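
With the branches above returning `ResponseBody` instead of bare dicts, every path through `chat_completion` yields the same schema and FastAPI serializes it consistently, with unset optional fields coming back as null. A small self-contained sketch of that behaviour (the model here is trimmed to two fields for illustration, not the full definition from this file):

from typing import List, Optional
from pydantic import BaseModel, Field

class ResponseBody(BaseModel):
    messages: Optional[List[dict]] = Field(default=None)
    llm_output: Optional[dict] = Field(default=None)

body = ResponseBody(messages=[{"role": "assistant", "content": "Internal server error."}])
print(body)  # messages=[{'role': 'assistant', 'content': 'Internal server error.'}] llm_output=None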


# By default, there are no challenges
@@ -498,7 +533,7 @@ def register_datastore(datastore_instance: DataStore):
datastore = datastore_instance


def register_logger(logger: callable):
def register_logger(logger: Callable):
"""Register an additional logger"""
registered_loggers.append(logger)

@@ -510,8 +545,7 @@ def start_auto_reload_monitoring():
from watchdog.observers import Observer

class Handler(FileSystemEventHandler):
@staticmethod
def on_any_event(event):
def on_any_event(self, event):
if event.is_directory:
return None

@@ -521,7 +555,8 @@ def on_any_event(event):
)

# Compute the relative path
rel_path = os.path.relpath(event.src_path, app.rails_config_path)
src_path_str = str(event.src_path)
rel_path = os.path.relpath(src_path_str, app.rails_config_path)

# The config_id is the first component
parts = rel_path.split(os.path.sep)
@@ -530,7 +565,7 @@ def on_any_event(event):
if (
not parts[-1].startswith(".")
and ".ipynb_checkpoints" not in parts
and os.path.isfile(event.src_path)
and os.path.isfile(src_path_str)
):
# We just remove the config from the cache so that a new one is used next time
if config_id in llm_rails_instances:
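
The watchdog handler above only evicts the cached instance for the changed config; `_get_rails()` rebuilds it lazily on the next request. A stripped-down sketch of that invalidation idea (the dict value is a stand-in, not a real `LLMRails`):

llm_rails_instances = {"my_config": object()}  # stands in for a cached LLMRails

def on_config_changed(config_id: str) -> None:
    # Eviction is all the handler needs to do; the next lookup reloads from disk.
    llm_rails_instances.pop(config_id, None)

on_config_changed("my_config")
assert "my_config" not in llm_rails_instances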
10 changes: 9 additions & 1 deletion nemoguardrails/server/datastore/redis_store.py
@@ -16,7 +16,10 @@
import asyncio
from typing import Optional

import aioredis
try:
import aioredis # type: ignore[import]
except ImportError:
aioredis = None # type: ignore[assignment]

from nemoguardrails.server.datastore.datastore import DataStore

@@ -35,6 +38,11 @@ def __init__(
username: [Optional] The username to use for authentication.
password: [Optional] The password to use for authentication
"""
if aioredis is None:
raise ImportError(
"aioredis is required for RedisStore. Install it with: pip install aioredis"
)

self.url = url
self.username = username
self.password = password
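
With the guarded import above, importing `redis_store` no longer requires `aioredis`; only constructing a `RedisStore` does. A short usage sketch of the fail-fast behaviour when the package is missing (the URL is illustrative and the keyword follows the docstring above):

from nemoguardrails.server.datastore.redis_store import RedisStore

try:
    store = RedisStore(url="redis://localhost:6379")
except ImportError as exc:
    # Raised only when aioredis is not installed; the message carries the install hint.
    print(exc)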