Skip to content

feat: support method chaining by returning self from LLMRails.register_* methods #1296

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 15 additions & 7 deletions nemoguardrails/rails/llm/llmrails.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

from langchain_core.language_models import BaseChatModel
from langchain_core.language_models.llms import BaseLLM
from typing_extensions import Self

from nemoguardrails.actions.llm.generation import LLMGenerationActions
from nemoguardrails.actions.llm.utils import (
Expand Down Expand Up @@ -1275,33 +1276,38 @@ def process_events(
self.process_events_async(events, state, blocking)
)

def register_action(self, action: callable, name: Optional[str] = None):
def register_action(self, action: callable, name: Optional[str] = None) -> Self:
"""Register a custom action for the rails configuration."""
self.runtime.register_action(action, name)
return self

def register_action_param(self, name: str, value: Any):
def register_action_param(self, name: str, value: Any) -> Self:
"""Registers a custom action parameter."""
self.runtime.register_action_param(name, value)
return self

def register_filter(self, filter_fn: callable, name: Optional[str] = None):
def register_filter(self, filter_fn: callable, name: Optional[str] = None) -> Self:
"""Register a custom filter for the rails configuration."""
self.runtime.llm_task_manager.register_filter(filter_fn, name)
return self

def register_output_parser(self, output_parser: callable, name: str):
def register_output_parser(self, output_parser: callable, name: str) -> Self:
"""Register a custom output parser for the rails configuration."""
self.runtime.llm_task_manager.register_output_parser(output_parser, name)
return self

def register_prompt_context(self, name: str, value_or_fn: Any):
def register_prompt_context(self, name: str, value_or_fn: Any) -> Self:
"""Register a value to be included in the prompt context.

:name: The name of the variable or function that will be used.
:value_or_fn: The value or function that will be used to generate the value.
"""
self.runtime.llm_task_manager.register_prompt_context(name, value_or_fn)
return self

def register_embedding_search_provider(
self, name: str, cls: Type[EmbeddingsIndex]
) -> None:
) -> Self:
"""Register a new embedding search provider.

Args:
Expand All @@ -1310,10 +1316,11 @@ def register_embedding_search_provider(
"""

self.embedding_search_providers[name] = cls
return self

def register_embedding_provider(
self, cls: Type[EmbeddingModel], name: Optional[str] = None
) -> None:
) -> Self:
"""Register a custom embedding provider.

Args:
Expand All @@ -1325,6 +1332,7 @@ def register_embedding_provider(
ValueError: If the model does not have 'encode' or 'encode_async' methods.
"""
register_embedding_provider(engine_name=name, model=cls)
return self

def explain(self) -> ExplainInfo:
"""Helper function to return the latest ExplainInfo object."""
Expand Down
97 changes: 97 additions & 0 deletions tests/test_llmrails.py
Original file line number Diff line number Diff line change
Expand Up @@ -1155,3 +1155,100 @@ async def test_stream_usage_enabled_for_all_providers_when_streaming(

# stream_usage should be set for all providers when streaming is enabled
assert kwargs.get("stream_usage") is True


# Add this test after the existing tests, around line 1100+


def test_register_methods_return_self():
"""Test that all register_* methods return self for method chaining."""
config = RailsConfig.from_content(config={"models": []})
rails = LLMRails(config=config, llm=FakeLLM(responses=[]))

# Test register_action returns self
def dummy_action():
pass

result = rails.register_action(dummy_action, "test_action")
assert result is rails, "register_action should return self"

# Test register_action_param returns self
result = rails.register_action_param("test_param", "test_value")
assert result is rails, "register_action_param should return self"

# Test register_filter returns self
def dummy_filter(text):
return text

result = rails.register_filter(dummy_filter, "test_filter")
assert result is rails, "register_filter should return self"

# Test register_output_parser returns self
def dummy_parser(text):
return text

result = rails.register_output_parser(dummy_parser, "test_parser")
assert result is rails, "register_output_parser should return self"

# Test register_prompt_context returns self
result = rails.register_prompt_context("test_context", "test_value")
assert result is rails, "register_prompt_context should return self"

# Test register_embedding_search_provider returns self
from nemoguardrails.embeddings.index import EmbeddingsIndex

class DummyEmbeddingProvider(EmbeddingsIndex):
def __init__(self, **kwargs):
pass

def build(self):
pass

def search(self, text, max_results=5):
return []

result = rails.register_embedding_search_provider(
"dummy_provider", DummyEmbeddingProvider
)
assert result is rails, "register_embedding_search_provider should return self"

# Test register_embedding_provider returns self
from nemoguardrails.embeddings.providers.base import EmbeddingModel

class DummyEmbeddingModel(EmbeddingModel):
def encode(self, texts):
return []

result = rails.register_embedding_provider(DummyEmbeddingModel, "dummy_embedding")
assert result is rails, "register_embedding_provider should return self"


def test_method_chaining():
"""Test that method chaining works correctly with register_* methods."""
config = RailsConfig.from_content(config={"models": []})
rails = LLMRails(config=config, llm=FakeLLM(responses=[]))

def dummy_action():
return "action_result"

def dummy_filter(text):
return text.upper()

def dummy_parser(text):
return {"parsed": text}

# Test chaining multiple register methods
result = (
rails.register_action(dummy_action, "chained_action")
.register_action_param("chained_param", "param_value")
.register_filter(dummy_filter, "chained_filter")
.register_output_parser(dummy_parser, "chained_parser")
.register_prompt_context("chained_context", "context_value")
)

assert result is rails, "Method chaining should return the same rails instance"

# Verify that all registrations actually worked
assert "chained_action" in rails.runtime.action_dispatcher.registered_actions
assert "chained_param" in rails.runtime.registered_action_params
assert rails.runtime.registered_action_params["chained_param"] == "param_value"