diff --git a/examples/configs/guardrails_ai/config.yml b/examples/configs/guardrails_ai/config.yml
new file mode 100644
index 000000000..c4c3d6b3c
--- /dev/null
+++ b/examples/configs/guardrails_ai/config.yml
@@ -0,0 +1,33 @@
+models:
+  - type: main
+    engine: openai
+    model: gpt-4
+
+rails:
+  config:
+    guardrails_ai:
+      validators:
+        - name: toxic_language
+          parameters:
+            threshold: 0.5
+            validation_method: "sentence"
+          metadata: {}
+        - name: guardrails_pii
+          parameters:
+            entities: ["phone_number", "email", "ssn"]
+          metadata: {}
+        - name: competitor_check
+          parameters:
+            competitors: ["Apple", "Google", "Microsoft"]
+          metadata: {}
+        - name: restricttotopic
+          parameters:
+            valid_topics: ["technology", "science", "education"]
+          metadata: {}
+  input:
+    flows:
+      - guardrailsai check input $validator="guardrails_pii"
+      - guardrailsai check input $validator="competitor_check"
+  output:
+    flows:
+      - guardrailsai check output $validator="restricttotopic"
diff --git a/nemoguardrails/library/guardrails_ai/__init__.py b/nemoguardrails/library/guardrails_ai/__init__.py
new file mode 100644
index 000000000..9ba9d4310
--- /dev/null
+++ b/nemoguardrails/library/guardrails_ai/__init__.py
@@ -0,0 +1,14 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/nemoguardrails/library/guardrails_ai/actions.py b/nemoguardrails/library/guardrails_ai/actions.py
new file mode 100644
index 000000000..a55e0fdb5
--- /dev/null
+++ b/nemoguardrails/library/guardrails_ai/actions.py
@@ -0,0 +1,255 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Dynamic validator loading for Guardrails AI integration."""
+
+import importlib
+import logging
+from functools import lru_cache
+from typing import Any, Dict, Optional, Type
+
+try:
+    from guardrails import Guard
+except ImportError:
+    # Fallback stub so this module stays importable when guardrails-ai is
+    # not installed; validation then always returns None.
+    class Guard:
+        def __init__(self):
+            pass
+
+        def use(self, validator):
+            return self
+
+        def validate(self, text, metadata=None):
+            return None
+
+
+from nemoguardrails.actions import action
+from nemoguardrails.library.guardrails_ai.errors import GuardrailsAIValidationError
+from nemoguardrails.library.guardrails_ai.registry import get_validator_info
+from nemoguardrails.rails.llm.config import RailsConfig
+
+log = logging.getLogger(__name__)
+
+# Cache for configured Guard instances, keyed by validator name + params.
+# (_validator_class_cache is retained for backward compatibility; class
+# loading itself is memoized via lru_cache on _load_validator_class.)
+_validator_class_cache: Dict[str, Type] = {}
+_guard_cache: Dict[tuple, Guard] = {}
+
+
+def guardrails_ai_validation_mapping(result: Dict[str, Any]) -> bool:
+    """Map a Guardrails AI action result to a pass/fail boolean.
+
+    The Guardrails AI `validate` call returns a ValidationResult
+    (PassResult or FailResult); both expose a `validation_passed`
+    boolean. This returns that flag directly: True when the text
+    passed validation, False when it violated the validator.
+    """
+    validation_result = result.get("validation_result", {})
+
+    # Support both ValidationResult objects and plain dicts.
+    if hasattr(validation_result, "validation_passed"):
+        return validation_result.validation_passed
+    return validation_result.get("validation_passed", False)
+
+
+# TODO: support several validators on one guard via Guard().use_many(...).
+
+
+def _resolve_and_validate(
+    validator: str, config: RailsConfig, text: str
+) -> Dict[str, Any]:
+    """Look up the validator's rails configuration and run it on `text`."""
+    validator_config = config.rails.config.guardrails_ai.get_validator_config(
+        validator
+    )
+    if validator_config is None:
+        raise ValueError(
+            f"No guardrails_ai configuration found for validator '{validator}'."
+        )
+
+    parameters = validator_config.parameters or {}
+    metadata = validator_config.metadata or {}
+
+    # Parameters configure the validator instance; metadata is passed to
+    # guard.validate() at validation time.
+    result = validate_guardrails_ai(validator, text, metadata=metadata, **parameters)
+
+    # Expose a `valid` flag for the Colang flows ($result["valid"]).
+    result["valid"] = guardrails_ai_validation_mapping(result)
+    return result
+
+
+@action(
+    name="validate_guardrails_ai_input",
+    output_mapping=guardrails_ai_validation_mapping,
+    is_system_action=False,
+)
+def validate_guardrails_ai_input(
+    validator: str,
+    config: RailsConfig,
+    context: Optional[dict] = None,
+    text: Optional[str] = None,
+    **kwargs,
+) -> Dict[str, Any]:
+    """Validate user input text with a configured Guardrails AI validator.
+
+    Args:
+        validator: Name of the validator to use (from VALIDATOR_REGISTRY).
+        config: Rails configuration holding the validator's parameters.
+        context: Optional context dict; `user_message` is read from it
+            when `text` is not given.
+        text: Text to validate.
+
+    Returns:
+        Dict with `validation_result` and a `valid` boolean.
+
+    Raises:
+        ValueError: If no text is available or the validator is not configured.
+    """
+    text = text or (context or {}).get("user_message", "")
+    if not text:
+        raise ValueError("Either 'text' or 'context' must be provided.")
+
+    return _resolve_and_validate(validator, config, text)
+
+
+@action(
+    name="validate_guardrails_ai_output",
+    output_mapping=guardrails_ai_validation_mapping,
+    is_system_action=False,
+)
+def validate_guardrails_ai_output(
+    validator: str,
+    context: Optional[dict] = None,
+    text: Optional[str] = None,
+    config: Optional[RailsConfig] = None,
+    **kwargs,
+) -> Dict[str, Any]:
+    """Validate bot output text with a configured Guardrails AI validator.
+
+    Same contract as `validate_guardrails_ai_input`, except the default
+    text is taken from `bot_message` and `config` must be provided.
+    """
+    text = text or (context or {}).get("bot_message", "")
+    if not text:
+        raise ValueError("Either 'text' or 'context' must be provided.")
+    if config is None:
+        raise ValueError("A RailsConfig instance is required.")
+
+    return _resolve_and_validate(validator, config, text)
+
+
+def validate_guardrails_ai(validator_name: str, text: str, **kwargs) -> Dict[str, Any]:
+    """Run a single Guardrails AI validator on `text`.
+
+    Args:
+        validator_name: Name of the validator (from VALIDATOR_REGISTRY).
+        text: Text to validate.
+        **kwargs: Validator constructor parameters; an optional `metadata`
+            dict is forwarded to `guard.validate` instead.
+
+    Returns:
+        Dict with `validation_result`.
+
+    Raises:
+        GuardrailsAIValidationError: If validation cannot be performed.
+    """
+    try:
+        # `metadata` belongs to validate(); the rest configures the validator.
+        metadata = kwargs.pop("metadata", {})
+        validator_params = {k: v for k, v in kwargs.items() if v is not None}
+
+        guard = _get_guard(validator_name, **validator_params)
+
+        try:
+            validation_result = guard.validate(text, metadata=metadata)
+            return {"validation_result": validation_result}
+        except GuardrailsAIValidationError as e:
+            # When a validator is configured with on_fail="exception",
+            # report a failed result instead of propagating the exception.
+            log.warning(f"Guardrails validation failed for {validator_name}: {str(e)}")
+
+            class FailedValidation:
+                validation_passed = False
+                error = str(e)
+
+            return {"validation_result": FailedValidation()}
+
+    except Exception as e:
+        log.error(f"Error validating with {validator_name}: {str(e)}")
+        raise GuardrailsAIValidationError(f"Validation failed: {str(e)}") from e
+
+
+@lru_cache(maxsize=None)
+def _load_validator_class(validator_name: str) -> Type:
+    """Dynamically import a validator class; results are memoized."""
+    try:
+        validator_info = get_validator_info(validator_name)
+    except Exception as e:
+        raise ImportError(
+            f"Failed to load validator {validator_name}: {str(e)}"
+        ) from e
+
+    module_name = validator_info["module"]
+    class_name = validator_info["class"]
+
+    try:
+        module = importlib.import_module(module_name)
+        return getattr(module, class_name)
+    except (ImportError, AttributeError) as e:
+        log.warning(
+            f"Could not import {class_name} from {module_name}. "
+            f"Make sure to install it first: guardrails hub install {validator_info['hub_path']}"
+        )
+        raise ImportError(
+            f"Validator {validator_name} not installed. "
+            f"Install with: guardrails hub install {validator_info['hub_path']}"
+        ) from e
+
+
+def _get_guard(validator_name: str, **validator_params) -> Guard:
+    """Get or create (and cache) a Guard instance for a validator."""
+
+    def make_hashable(obj):
+        # Lists/dicts cannot live inside a dict key; convert recursively.
+        if isinstance(obj, list):
+            return tuple(obj)
+        elif isinstance(obj, dict):
+            return tuple(sorted((k, make_hashable(v)) for k, v in obj.items()))
+        return obj
+
+    cache_items = [(k, make_hashable(v)) for k, v in validator_params.items()]
+    cache_key = (validator_name, tuple(sorted(cache_items)))
+
+    if cache_key not in _guard_cache:
+        validator_class = _load_validator_class(validator_name)
+
+        # Default to "noop" so failures are reported, not raised or auto-fixed.
+        if "on_fail" not in validator_params:
+            validator_params["on_fail"] = "noop"
+
+        try:
+            validator_instance = validator_class(**validator_params)
+        except TypeError as e:
+            log.error(
+                f"Failed to instantiate {validator_name} with params {validator_params}: {str(e)}"
+            )
+            raise
+
+        _guard_cache[cache_key] = Guard().use(validator_instance)
+
+    return _guard_cache[cache_key]
diff --git a/nemoguardrails/library/guardrails_ai/errors.py b/nemoguardrails/library/guardrails_ai/errors.py
new file mode 100644
index 000000000..4615814ec
--- /dev/null
+++ b/nemoguardrails/library/guardrails_ai/errors.py
@@ -0,0 +1,44 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+try:
+    from guardrails.errors import ValidationError
+
+    GuardrailsAIValidationError = ValidationError
+except ImportError:
+    # create a fallback error class when guardrails is not installed
+    class GuardrailsAIValidationError(Exception):
+        """Fallback validation error when guardrails package is not available."""
+
+        pass
+
+
+class GuardrailsAIError(Exception):
+    """Base exception for Guardrails AI integration."""
+
+    pass
+
+
+class GuardrailsAIConfigError(GuardrailsAIError):
+    """Raised when configuration is invalid."""
+
+    pass
+
+
+__all__ = [
+    "GuardrailsAIError",
+    "GuardrailsAIValidationError",
+    "GuardrailsAIConfigError",
+]
diff --git a/nemoguardrails/library/guardrails_ai/flows.co b/nemoguardrails/library/guardrails_ai/flows.co
new file mode 100644
index 000000000..8586ba47a
--- /dev/null
+++ b/nemoguardrails/library/guardrails_ai/flows.co
@@ -0,0 +1,20 @@
+flow guardrailsai check input $validator
+    """Check input text using relevant Guardrails AI validators."""
+    $result = await ValidateGuardrailsAiInputAction(validator=$validator, text=$user_message)
+    if not $result["valid"]
+        if $system.config.enable_rails_exceptions
+            send GuardrailsAIException(message="Guardrails AI {$validator} validation failed")
+        else
+            bot refuse to respond
+            abort
+
+
+flow guardrailsai check output $validator
+    """Check output text using relevant Guardrails AI validators."""
+    $result = await ValidateGuardrailsAiOutputAction(validator=$validator, text=$bot_message)
+    if not $result["valid"]
+        if $system.config.enable_rails_exceptions
+            send GuardrailsAIException(message="Guardrails AI {$validator} validation failed")
+        else
+            bot refuse to respond
+            abort
diff --git a/nemoguardrails/library/guardrails_ai/flows.v1.co b/nemoguardrails/library/guardrails_ai/flows.v1.co
new file mode 100644
index 000000000..4bc4621cf
--- /dev/null
+++ b/nemoguardrails/library/guardrails_ai/flows.v1.co
@@ -0,0 +1,24 @@
+define flow guardrailsai check input
+  """Check input text using relevant Guardrails AI validators."""
+
+  $result = execute validate_guardrails_ai_input(validator=$validator, text=$user_message)
+  if not $result["valid"]
+    if $config.enable_rails_exceptions
+      $msg = "Guardrails AI " + $validator + " validation failed"
+      create event GuardrailsAIException(message=$msg)
+    else
+      bot refuse to respond
+      stop
+
+
+define flow guardrailsai check output
+  """Check output text using relevant Guardrails AI validators."""
+
+  $result = execute validate_guardrails_ai_output(validator=$validator, text=$bot_message)
+  if not $result["valid"]
+    if $config.enable_rails_exceptions
+      $msg = "Guardrails AI " + $validator + " validation failed"
+      create event GuardrailsAIException(message=$msg)
+    else
+      bot refuse to respond
+      stop
diff --git a/nemoguardrails/library/guardrails_ai/registry.py b/nemoguardrails/library/guardrails_ai/registry.py
new file mode 100644
index 000000000..0aef9fcd0
--- /dev/null
+++ b/nemoguardrails/library/guardrails_ai/registry.py
@@ -0,0 +1,151 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import Dict + +from .errors import GuardrailsAIConfigError + +log = logging.getLogger(__name__) + +VALIDATOR_REGISTRY = { + "toxic_language": { + "module": "guardrails.hub", + "class": "ToxicLanguage", + "hub_path": "hub://guardrails/toxic_language", + "default_params": {"on_fail": "noop"}, + }, + "detect_jailbreak": { + "module": "guardrails.hub", + "class": "DetectJailbreak", + "hub_path": "hub://guardrails/detect_jailbreak", + "default_params": {"on_fail": "noop"}, + }, + "guardrails_pii": { + "module": "guardrails.hub", + "class": "GuardrailsPII", + "hub_path": "hub://guardrails/guardrails_pii", + "default_params": {"on_fail": "noop"}, + }, + "competitor_check": { + "module": "guardrails.hub", + "class": "CompetitorCheck", + "hub_path": "hub://guardrails/competitor_check", + "default_params": {"on_fail": "noop"}, + }, + "restricttotopic": { + "module": "guardrails.hub", + "class": "RestrictToTopic", + "hub_path": "hub://tryolabs/restricttotopic", + "default_params": {"on_fail": "noop"}, + }, + "provenance_llm": { + "module": "guardrails.hub", + "class": "ProvenanceLLM", + "hub_path": "hub://guardrails/provenance_llm", + "default_params": {"on_fail": "noop"}, + }, + "regex_match": { + "module": "guardrails.hub", + "class": "RegexMatch", + "hub_path": "hub://guardrails/regex_match", + "default_params": {"on_fail": "noop"}, + }, + "one_line": { + "module": "guardrails.hub", + "class": "OneLine", + "hub_path": "hub://guardrails/one_line", + "default_params": {"on_fail": "noop"}, + }, + "valid_json": { + "module": 
"guardrails.hub", + "class": "ValidJson", + "hub_path": "hub://guardrails/valid_json", + "default_params": {"on_fail": "noop"}, + }, + "valid_length": { + "module": "guardrails.hub", + "class": "ValidLength", + "hub_path": "hub://guardrails/valid_length", + "default_params": {"on_fail": "noop"}, + }, +} + + +def get_validator_info(validator_path: str) -> Dict[str, str]: + """Get validator information from registry or hub. + + Args: + validator_path: Either a simple name (e.g., "toxic_language") or + a full hub path (e.g., "guardrails/toxic_language") + + Returns: + Dict with module, class, and hub_path information + """ + if validator_path in VALIDATOR_REGISTRY: + return VALIDATOR_REGISTRY[validator_path] + + for _, info in VALIDATOR_REGISTRY.items(): + if info["hub_path"] == f"hub://{validator_path}": + return info + + # not in registry, try to fetch from hub + try: + try: + from guardrails.hub.validator_package_service import get_validator_manifest + except ImportError: + raise GuardrailsAIConfigError( + "Could not import get_validator_manifest. " + "Make sure guardrails-ai is properly installed." + ) + + log.info( + f"Validator '{validator_path}' not found in registry. " + f"Attempting to fetch from Guardrails Hub..." + ) + + manifest = get_validator_manifest(validator_path) + + if manifest.exports: + class_name = manifest.exports[0] + else: + # fallback: construct class name from package name + class_name = "".join( + word.capitalize() for word in manifest.package_name.split("_") + ) + + validator_info = { + "module": "guardrails.hub", + "class": class_name, + "hub_path": f"hub://{manifest.namespace}/{manifest.package_name}", + } + + log.info( + f"Using validator '{validator_path}' that is not in the built-in registry. " + f"Consider adding it to VALIDATOR_REGISTRY for better performance. 
" + f"Install with: guardrails hub install {validator_info['hub_path']}" + ) + + return validator_info + + except ImportError: + raise GuardrailsAIConfigError( + "Could not import get_validator_manifest. " + "Make sure guardrails-ai is properly installed." + ) + except Exception as e: + raise GuardrailsAIConfigError( + f"Failed to fetch validator info for '{validator_path}': {str(e)}" + ) diff --git a/nemoguardrails/rails/llm/config.py b/nemoguardrails/rails/llm/config.py index ffdd10220..d5a732a44 100644 --- a/nemoguardrails/rails/llm/config.py +++ b/nemoguardrails/rails/llm/config.py @@ -757,6 +757,40 @@ class ClavataRailConfig(BaseModel): ) +class GuardrailsAIValidatorConfig(BaseModel): + """Configuration for a single Guardrails AI validator.""" + + name: str = Field( + description="Unique identifier or import path for the Guardrails AI validator (e.g., 'toxic_language', 'pii', 'regex_match', or 'guardrails/competitor_check')." + ) + + parameters: Dict[str, Any] = Field( + default_factory=dict, + description="Parameters to pass to the validator during initialization (e.g., threshold, regex pattern).", + ) + + metadata: Dict[str, Any] = Field( + default_factory=dict, + description="Metadata to pass to the validator during validation (e.g., valid_topics, context).", + ) + + +class GuardrailsAIRailConfig(BaseModel): + """Configuration data for Guardrails AI integration.""" + + validators: List[GuardrailsAIValidatorConfig] = Field( + default_factory=list, + description="List of Guardrails AI validators to apply. 
Each validator can have its own parameters and metadata.", + ) + + def get_validator_config(self, name: str) -> Optional[GuardrailsAIValidatorConfig]: + """Get a specific validator configuration by name.""" + for _validator in self.validators: + if _validator.name == name: + return _validator + return None + + class RailsConfigData(BaseModel): """Configuration data for specific rails that are supported out-of-the-box.""" @@ -805,6 +839,11 @@ class RailsConfigData(BaseModel): description="Configuration for Clavata.", ) + guardrails_ai: Optional[GuardrailsAIRailConfig] = Field( + default_factory=GuardrailsAIRailConfig, + description="Configuration for Guardrails AI validators.", + ) + class Rails(BaseModel): """Configuration of specific rails.""" diff --git a/tests/test_guardrails_ai_actions.py b/tests/test_guardrails_ai_actions.py new file mode 100644 index 000000000..99562694d --- /dev/null +++ b/tests/test_guardrails_ai_actions.py @@ -0,0 +1,245 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for Guardrails AI integration - updated to match current implementation.""" + +import inspect +from typing import Any, Dict +from unittest.mock import Mock, patch + +import pytest + + +class TestGuardrailsAIIntegration: + """Test suite for Guardrails AI integration with current implementation.""" + + def test_module_imports_without_guardrails(self): + """Test that modules can be imported even without guardrails package.""" + from nemoguardrails.library.guardrails_ai.actions import ( + _get_guard, + guardrails_ai_validation_mapping, + validate_guardrails_ai, + ) + from nemoguardrails.library.guardrails_ai.registry import VALIDATOR_REGISTRY + + assert callable(validate_guardrails_ai) + assert callable(guardrails_ai_validation_mapping) + assert isinstance(VALIDATOR_REGISTRY, dict) + + def test_validator_registry_structure(self): + """Test that the validator registry has the expected structure.""" + from nemoguardrails.library.guardrails_ai.registry import VALIDATOR_REGISTRY + + assert isinstance(VALIDATOR_REGISTRY, dict) + assert len(VALIDATOR_REGISTRY) >= 6 + + expected_validators = [ + "toxic_language", + "detect_jailbreak", + "guardrails_pii", + "competitor_check", + "restricttotopic", + "provenance_llm", + ] + + for validator in expected_validators: + assert validator in VALIDATOR_REGISTRY + validator_info = VALIDATOR_REGISTRY[validator] + assert "module" in validator_info + assert "class" in validator_info + assert "hub_path" in validator_info + assert "default_params" in validator_info + assert isinstance(validator_info["default_params"], dict) + + def test_validation_mapping_function(self): + """Test the validation mapping function with current interface.""" + from nemoguardrails.library.guardrails_ai.actions import ( + guardrails_ai_validation_mapping, + ) + + mock_result = Mock() + mock_result.validation_passed = True + result1 = {"validation_result": mock_result} + mapped1 = guardrails_ai_validation_mapping(result1) + assert mapped1 is True + + 
mock_result2 = Mock() + mock_result2.validation_passed = False + result2 = {"validation_result": mock_result2} + mapped2 = guardrails_ai_validation_mapping(result2) + assert mapped2 is False + + result3 = {"validation_result": {"validation_passed": True}} + mapped3 = guardrails_ai_validation_mapping(result3) + assert mapped3 is True + + @patch("nemoguardrails.library.guardrails_ai.actions._get_guard") + def test_validate_guardrails_ai_success(self, mock_get_guard): + """Test successful validation with current interface.""" + from nemoguardrails.library.guardrails_ai.actions import validate_guardrails_ai + + mock_guard = Mock() + mock_validation_result = Mock() + mock_validation_result.validation_passed = True + mock_guard.validate.return_value = mock_validation_result + mock_get_guard.return_value = mock_guard + + result = validate_guardrails_ai( + validator_name="toxic_language", + text="Hello, this is a safe message", + threshold=0.5, + ) + + assert "validation_result" in result + assert result["validation_result"] == mock_validation_result + mock_guard.validate.assert_called_once_with( + "Hello, this is a safe message", metadata={} + ) + mock_get_guard.assert_called_once_with("toxic_language", threshold=0.5) + + @patch("nemoguardrails.library.guardrails_ai.actions._get_guard") + def test_validate_guardrails_ai_with_metadata(self, mock_get_guard): + """Test validation with metadata parameter.""" + from nemoguardrails.library.guardrails_ai.actions import validate_guardrails_ai + + mock_guard = Mock() + mock_validation_result = Mock() + mock_validation_result.validation_passed = False + mock_guard.validate.return_value = mock_validation_result + mock_get_guard.return_value = mock_guard + + metadata = {"source": "user_input"} + result = validate_guardrails_ai( + validator_name="detect_jailbreak", + text="Some text", + metadata=metadata, + threshold=0.8, + ) + + assert "validation_result" in result + assert result["validation_result"] == mock_validation_result + 
mock_guard.validate.assert_called_once_with("Some text", metadata=metadata) + mock_get_guard.assert_called_once_with("detect_jailbreak", threshold=0.8) + + @patch("nemoguardrails.library.guardrails_ai.actions._get_guard") + def test_validate_guardrails_ai_error_handling(self, mock_get_guard): + """Test error handling in validation.""" + from nemoguardrails.library.guardrails_ai.actions import validate_guardrails_ai + from nemoguardrails.library.guardrails_ai.errors import ( + GuardrailsAIValidationError, + ) + + mock_guard = Mock() + mock_guard.validate.side_effect = Exception("Validation service error") + mock_get_guard.return_value = mock_guard + + with pytest.raises(GuardrailsAIValidationError) as exc_info: + validate_guardrails_ai(validator_name="toxic_language", text="Any text") + + assert "Validation failed" in str(exc_info.value) + assert "Validation service error" in str(exc_info.value) + + @patch("nemoguardrails.library.guardrails_ai.actions._load_validator_class") + @patch("nemoguardrails.library.guardrails_ai.actions.Guard") + def test_get_guard_creates_and_caches(self, mock_guard_class, mock_load_validator): + """Test that _get_guard creates and caches guards properly.""" + from nemoguardrails.library.guardrails_ai.actions import _get_guard + + mock_validator_class = Mock() + mock_validator_instance = Mock() + mock_guard_instance = Mock() + mock_guard = Mock() + + mock_load_validator.return_value = mock_validator_class + mock_validator_class.return_value = mock_validator_instance + mock_guard_class.return_value = mock_guard + mock_guard.use.return_value = mock_guard_instance + + # clear cache + import nemoguardrails.library.guardrails_ai.actions as actions + + actions._guard_cache.clear() + + # first call should create new guard + result1 = _get_guard("toxic_language", threshold=0.5) + + assert result1 == mock_guard_instance + mock_validator_class.assert_called_once_with(threshold=0.5, on_fail="noop") + 
mock_guard.use.assert_called_once_with(mock_validator_instance) + + # reset mocks for second call + mock_load_validator.reset_mock() + mock_validator_class.reset_mock() + mock_guard_class.reset_mock() + + # second call with same params should use cache + result2 = _get_guard("toxic_language", threshold=0.5) + + assert result2 == mock_guard_instance + # should not create new validator or guard + mock_load_validator.assert_not_called() + mock_validator_class.assert_not_called() + mock_guard_class.assert_not_called() + + @patch("nemoguardrails.library.guardrails_ai.registry.get_validator_info") + def test_load_validator_class_unknown_validator(self, mock_get_info): + """Test error handling for unknown validators.""" + from nemoguardrails.library.guardrails_ai.actions import _load_validator_class + from nemoguardrails.library.guardrails_ai.errors import GuardrailsAIConfigError + + mock_get_info.side_effect = GuardrailsAIConfigError( + "Unknown validator: unknown_validator" + ) + + with pytest.raises(ImportError) as exc_info: + _load_validator_class("unknown_validator") + + assert "Failed to load validator unknown_validator" in str(exc_info.value) + + def test_validate_guardrails_ai_signature(self): + """Test that validate_guardrails_ai has the expected signature.""" + from nemoguardrails.library.guardrails_ai.actions import validate_guardrails_ai + + sig = inspect.signature(validate_guardrails_ai) + params = list(sig.parameters.keys()) + + assert "validator_name" in params + assert "text" in params + assert any(param.kind == param.VAR_KEYWORD for param in sig.parameters.values()) + + @patch("nemoguardrails.library.guardrails_ai.actions._load_validator_class") + @patch("nemoguardrails.library.guardrails_ai.actions.Guard") + def test_guard_cache_key_generation(self, mock_guard_class, mock_load): + """Test that guard cache keys are generated correctly for different parameter combinations.""" + from nemoguardrails.library.guardrails_ai.actions import _get_guard + + 
mock_validator_class = Mock() + mock_guard_instance = Mock() + mock_guard = Mock() + + mock_load.return_value = mock_validator_class + mock_guard_class.return_value = mock_guard + mock_guard.use.return_value = mock_guard_instance + + import nemoguardrails.library.guardrails_ai.actions as actions + + actions._guard_cache.clear() + + # create guards with different parameters + _get_guard("toxic_language", threshold=0.5) + _get_guard("toxic_language", threshold=0.8) + _get_guard("detect_jailbreak", threshold=0.5) + + assert len(actions._guard_cache) == 3 diff --git a/tests/test_guardrails_ai_config.py b/tests/test_guardrails_ai_config.py new file mode 100644 index 000000000..f24ff55ff --- /dev/null +++ b/tests/test_guardrails_ai_config.py @@ -0,0 +1,170 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for guardrails_ai configuration parsing.""" + +import pytest + +from nemoguardrails.rails.llm.config import RailsConfig + + +def test_guardrails_ai_config_parsing(): + """Test that guardrails_ai configuration is properly parsed.""" + + config_content = """ + models: + - type: main + engine: openai + model: gpt-4 + + rails: + config: + guardrails_ai: + validators: + - name: toxic_language + parameters: + threshold: 0.7 + validation_method: "full" + metadata: + context: "customer_service" + + - name: pii + parameters: + entities: ["email", "phone"] + metadata: {} + + - name: competitor_check + parameters: + competitors: ["Apple", "Google"] + metadata: + strict: true + """ + + config = RailsConfig.from_content(yaml_content=config_content) + + assert config.rails.config.guardrails_ai is not None + + validators = config.rails.config.guardrails_ai.validators + assert len(validators) == 3 + + toxic_validator = validators[0] + assert toxic_validator.name == "toxic_language" + assert toxic_validator.parameters["threshold"] == 0.7 + assert toxic_validator.parameters["validation_method"] == "full" + assert toxic_validator.metadata["context"] == "customer_service" + + pii_validator = validators[1] + assert pii_validator.name == "pii" + assert pii_validator.parameters["entities"] == ["email", "phone"] + assert pii_validator.metadata == {} + + competitor_validator = validators[2] + assert competitor_validator.name == "competitor_check" + assert competitor_validator.parameters["competitors"] == ["Apple", "Google"] + assert competitor_validator.metadata["strict"] is True + + +def test_guardrails_ai_get_validator_config(): + """Test that guardrails_ai configuration is properly parsed.""" + + config_content = """ + models: + - type: main + engine: openai + model: gpt-4 + + rails: + config: + guardrails_ai: + validators: + - name: toxic_language + parameters: + threshold: 0.7 + validation_method: "full" + metadata: + context: "customer_service" + + - name: pii + 
parameters: + entities: ["email", "phone"] + metadata: {} + + - name: competitor_check + parameters: + competitors: ["Apple", "Google"] + metadata: + strict: true + """ + + config = RailsConfig.from_content(yaml_content=config_content) + + assert config.rails.config.guardrails_ai is not None + + guardrails_ai = config.rails.config.guardrails_ai + validators = guardrails_ai.validators + assert len(validators) == 3 + + toxic_validator = guardrails_ai.get_validator_config("toxic_language") + assert toxic_validator.name == "toxic_language" + + pii_validator = guardrails_ai.get_validator_config("pii") + assert pii_validator.name == "pii" + assert pii_validator.parameters["entities"] == ["email", "phone"] + assert pii_validator.metadata == {} + + competitor_validator = validators[2] + assert competitor_validator.name == "competitor_check" + assert competitor_validator.parameters["competitors"] == ["Apple", "Google"] + assert competitor_validator.metadata["strict"] is True + + +def test_guardrails_ai_config_defaults(): + """Test default values for guardrails_ai configuration.""" + + config_content = """ + models: + - type: main + engine: openai + model: gpt-4 + + rails: + config: + guardrails_ai: + validators: + - name: simple_validator + """ + + config = RailsConfig.from_content(yaml_content=config_content) + + validator = config.rails.config.guardrails_ai.validators[0] + assert validator.name == "simple_validator" + assert validator.parameters == {} + assert validator.metadata == {} + + +def test_guardrails_ai_config_empty(): + """Test empty guardrails_ai configuration.""" + + config_content = """ + models: + - type: main + engine: openai + model: gpt-4 + """ + + config = RailsConfig.from_content(yaml_content=config_content) + + assert config.rails.config.guardrails_ai is not None + assert config.rails.config.guardrails_ai.validators == [] diff --git a/tests/test_guardrails_ai_e2e_actions.py b/tests/test_guardrails_ai_e2e_actions.py new file mode 100644 index 
000000000..65c2ba883 --- /dev/null +++ b/tests/test_guardrails_ai_e2e_actions.py @@ -0,0 +1,296 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""End-to-End tests for Guardrails AI integration with real validators. + +These tests run against actual Guardrails validators when installed. +They can be skipped in CI/environments where validators aren't available. 
+""" + +import pytest + +GUARDRAILS_AVAILABLE = False +VALIDATORS_AVAILABLE = {} + +try: + from guardrails import Guard + + GUARDRAILS_AVAILABLE = True + + try: + from guardrails.hub import ToxicLanguage + + VALIDATORS_AVAILABLE["toxic_language"] = True + except ImportError: + VALIDATORS_AVAILABLE["toxic_language"] = False + + try: + from guardrails.hub import RegexMatch + + VALIDATORS_AVAILABLE["regex_match"] = True + except ImportError: + VALIDATORS_AVAILABLE["regex_match"] = False + + try: + from guardrails.hub import ValidLength + + VALIDATORS_AVAILABLE["valid_length"] = True + except ImportError: + VALIDATORS_AVAILABLE["valid_length"] = False + + try: + from guardrails.hub import CompetitorCheck + + VALIDATORS_AVAILABLE["competitor_check"] = True + except ImportError: + VALIDATORS_AVAILABLE["competitor_check"] = False + +except ImportError: + GUARDRAILS_AVAILABLE = False + + +class TestGuardrailsAIE2EIntegration: + """End-to-End tests using real Guardrails validators when available.""" + + @pytest.mark.skipif( + not GUARDRAILS_AVAILABLE or not VALIDATORS_AVAILABLE.get("regex_match", False), + reason="Guardrails or RegexMatch validator not installed. 
Install with: guardrails hub install hub://guardrails/regex_match", + ) + def test_regex_match_e2e_success(self): + """E2E test: RegexMatch validator with text that should pass.""" + from nemoguardrails.library.guardrails_ai.actions import validate_guardrails_ai + + result = validate_guardrails_ai( + validator_name="regex_match", + text="Hello world", + regex="^[A-Z].*", + on_fail="noop", + ) + + assert "validation_result" in result + assert hasattr(result["validation_result"], "validation_passed") + assert result["validation_result"].validation_passed is True + + @pytest.mark.skipif( + not GUARDRAILS_AVAILABLE or not VALIDATORS_AVAILABLE.get("regex_match", False), + reason="Guardrails or RegexMatch validator not installed", + ) + def test_regex_match_e2e_failure(self): + """E2E test: RegexMatch validator with text that should fail.""" + from nemoguardrails.library.guardrails_ai.actions import validate_guardrails_ai + + result = validate_guardrails_ai( + validator_name="regex_match", + text="hello world", + regex="^[A-Z].*", + on_fail="noop", + ) + + assert "validation_result" in result + assert hasattr(result["validation_result"], "validation_passed") + assert result["validation_result"].validation_passed is False + + @pytest.mark.skipif( + not GUARDRAILS_AVAILABLE or not VALIDATORS_AVAILABLE.get("valid_length", False), + reason="Guardrails or ValidLength validator not installed", + ) + def test_valid_length_e2e(self): + """E2E test: ValidLength validator.""" + from nemoguardrails.library.guardrails_ai.actions import validate_guardrails_ai + + result_pass = validate_guardrails_ai( + validator_name="valid_length", text="Hello", min=1, max=10, on_fail="noop" + ) + + assert result_pass["validation_result"].validation_passed is True + + result_fail = validate_guardrails_ai( + validator_name="valid_length", + text="This is a very long text that exceeds the maximum length", + min=1, + max=10, + on_fail="noop", + ) + + assert 
result_fail["validation_result"].validation_passed is False + + @pytest.mark.skipif( + not GUARDRAILS_AVAILABLE + or not VALIDATORS_AVAILABLE.get("toxic_language", False), + reason="Guardrails or ToxicLanguage validator not installed. Install with: guardrails hub install hub://guardrails/toxic_language", + ) + def test_toxic_language_e2e(self): + """E2E test: ToxicLanguage validator with real content.""" + from nemoguardrails.library.guardrails_ai.actions import validate_guardrails_ai + + result_safe = validate_guardrails_ai( + validator_name="toxic_language", + text="Have a wonderful day! Thank you for your help.", + threshold=0.5, + on_fail="noop", + ) + + assert "validation_result" in result_safe + assert hasattr(result_safe["validation_result"], "validation_passed") + assert result_safe["validation_result"].validation_passed is True + + @pytest.mark.skipif( + not GUARDRAILS_AVAILABLE + or not VALIDATORS_AVAILABLE.get("competitor_check", False), + reason="Guardrails or CompetitorCheck validator not installed", + ) + def test_competitor_check_e2e(self): + """E2E test: CompetitorCheck validator.""" + from nemoguardrails.library.guardrails_ai.actions import validate_guardrails_ai + + competitors = ["Apple", "Google", "Microsoft"] + + result_safe = validate_guardrails_ai( + validator_name="competitor_check", + text="Our company provides excellent services.", + competitors=competitors, + on_fail="noop", + ) + + assert result_safe["validation_result"].validation_passed is True + + result_competitor = validate_guardrails_ai( + validator_name="competitor_check", + text="Apple makes great products.", + competitors=competitors, + on_fail="noop", + ) + + assert result_competitor["validation_result"].validation_passed is False + + @pytest.mark.skipif(not GUARDRAILS_AVAILABLE, reason="Guardrails not installed") + def test_validation_mapping_e2e(self): + """E2E test: Validation mapping with real validation results.""" + from nemoguardrails.library.guardrails_ai.actions import 
( + guardrails_ai_validation_mapping, + validate_guardrails_ai, + ) + + if VALIDATORS_AVAILABLE.get("regex_match", False): + result = validate_guardrails_ai( + validator_name="regex_match", + text="Hello world", + regex="^[A-Z].*", + on_fail="noop", + ) + + mapped = guardrails_ai_validation_mapping(result) + assert mapped["valid"] is True + assert "validation_result" in mapped + + result_fail = validate_guardrails_ai( + validator_name="regex_match", + text="hello world", + regex="^[A-Z].*", + on_fail="noop", + ) + + mapped_fail = guardrails_ai_validation_mapping(result_fail) + assert mapped_fail["valid"] is False + + @pytest.mark.skipif(not GUARDRAILS_AVAILABLE, reason="Guardrails not installed") + def test_metadata_parameter_e2e(self): + """E2E test: Metadata parameter handling with real validators.""" + from nemoguardrails.library.guardrails_ai.actions import validate_guardrails_ai + + if VALIDATORS_AVAILABLE.get("regex_match", False): + metadata = {"source": "user_input", "context": "test"} + result = validate_guardrails_ai( + validator_name="regex_match", + text="Hello world", + regex="^[A-Z].*", + metadata=metadata, + on_fail="noop", + ) + + assert "validation_result" in result + assert result["validation_result"].validation_passed is True + + @pytest.mark.skipif(not GUARDRAILS_AVAILABLE, reason="Guardrails not installed") + def test_guard_caching_e2e(self): + """E2E test: Verify guard caching works with real validators.""" + from nemoguardrails.library.guardrails_ai.actions import _get_guard + + if VALIDATORS_AVAILABLE.get("regex_match", False): + import nemoguardrails.library.guardrails_ai.actions as actions + + actions._guard_cache.clear() + + guard1 = _get_guard("regex_match", regex="^[A-Z].*", on_fail="noop") + guard2 = _get_guard("regex_match", regex="^[A-Z].*", on_fail="noop") + + # should be the same instance (cached) + assert guard1 is guard2 + + # different parameters should create different guard + guard3 = _get_guard("regex_match", 
regex="^[a-z].*", on_fail="noop") + assert guard3 is not guard1 + + def test_error_handling_unknown_validator_e2e(self): + """E2E test: Error handling for unknown validators.""" + from nemoguardrails.library.guardrails_ai.actions import validate_guardrails_ai + from nemoguardrails.library.guardrails_ai.errors import ( + GuardrailsAIValidationError, + ) + + # Test with completely unknown validator + with pytest.raises(GuardrailsAIValidationError) as exc_info: + validate_guardrails_ai( + validator_name="completely_unknown_validator", text="Test text" + ) + + assert "Validation failed" in str(exc_info.value) + + @pytest.mark.skipif(not GUARDRAILS_AVAILABLE, reason="Guardrails not installed") + def test_multiple_validators_sequence_e2e(self): + """E2E test: Using multiple validators in sequence.""" + from nemoguardrails.library.guardrails_ai.actions import validate_guardrails_ai + + test_text = "Hello World Test" + + available_validators = [] + if VALIDATORS_AVAILABLE.get("regex_match", False): + available_validators.append(("regex_match", {"regex": "^[A-Z].*"})) + if VALIDATORS_AVAILABLE.get("valid_length", False): + available_validators.append(("valid_length", {"min": 1, "max": 50})) + + # run each available validator + for validator_name, params in available_validators: + result = validate_guardrails_ai( + validator_name=validator_name, text=test_text, on_fail="noop", **params + ) + + assert "validation_result" in result + assert hasattr(result["validation_result"], "validation_passed") + # all should pass with the test text + assert result["validation_result"].validation_passed is True + + +def print_validator_availability(): + """Helper function to print which validators are available for testing.""" + print(f"Guardrails available: {GUARDRAILS_AVAILABLE}") + if GUARDRAILS_AVAILABLE: + for validator, available in VALIDATORS_AVAILABLE.items(): + print(f" {validator}: {available}") + + +if __name__ == "__main__": + print_validator_availability() + 
pytest.main([__file__, "-v", "-s"]) diff --git a/tests/test_guardrails_ai_e2e_v1.py b/tests/test_guardrails_ai_e2e_v1.py new file mode 100644 index 000000000..f8688c768 --- /dev/null +++ b/tests/test_guardrails_ai_e2e_v1.py @@ -0,0 +1,496 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from nemoguardrails import LLMRails, RailsConfig +from tests.utils import FakeLLM, TestChat + +try: + from guardrails import Guard + + GUARDRAILS_AVAILABLE = True + + try: + from guardrails.hub import RegexMatch + + REGEX_MATCH_AVAILABLE = True + except ImportError: + REGEX_MATCH_AVAILABLE = False + + try: + from guardrails.hub import ValidLength + + VALID_LENGTH_AVAILABLE = True + except ImportError: + VALID_LENGTH_AVAILABLE = False + +except ImportError: + GUARDRAILS_AVAILABLE = False + REGEX_MATCH_AVAILABLE = False + VALID_LENGTH_AVAILABLE = False + + +INPUT_RAILS_ONLY_CONFIG_EXCEPTION = """ +models: + - type: main + engine: fake + model: fake + +enable_rails_exceptions: true + +rails: + config: + guardrails_ai: + validators: + - name: regex_match + parameters: + regex: "^[A-Z].*" + metadata: {} + + input: + flows: + - guardrailsai check input $validator="regex_match" +""" + +INPUT_RAILS_ONLY_CONFIG_REFUSE = """ +models: + - type: main + engine: fake + model: fake + +enable_rails_exceptions: false + +rails: + config: + 
guardrails_ai: + validators: + - name: regex_match + parameters: + regex: "^[A-Z].*" + metadata: {} + + input: + flows: + - guardrailsai check input $validator="regex_match" +""" + +OUTPUT_RAILS_ONLY_CONFIG_EXCEPTION = """ +models: + - type: main + engine: fake + model: fake + +enable_rails_exceptions: true + +rails: + config: + guardrails_ai: + validators: + - name: valid_length + parameters: + min: 1 + max: 20 + metadata: {} + + output: + flows: + - guardrailsai check output $validator="valid_length" +""" + +OUTPUT_RAILS_ONLY_CONFIG_REFUSE = """ +models: + - type: main + engine: fake + model: fake + +enable_rails_exceptions: false + +rails: + config: + guardrails_ai: + validators: + - name: valid_length + parameters: + min: 1 + max: 20 + metadata: {} + + output: + flows: + - guardrailsai check output $validator="valid_length" +""" + +INPUT_AND_OUTPUT_RAILS_CONFIG_EXCEPTION = """ +models: + - type: main + engine: fake + model: fake + +enable_rails_exceptions: true + +rails: + config: + guardrails_ai: + validators: + - name: regex_match + parameters: + regex: "^[A-Z].*" + metadata: {} + - name: valid_length + parameters: + min: 1 + max: 30 + metadata: {} + + input: + flows: + - guardrailsai check input $validator="regex_match" + + output: + flows: + - guardrailsai check output $validator="valid_length" +""" + +COLANG_CONTENT = """ +define user express greeting + "hello" + "hi" + "hey" + +define bot express greeting + "Hello! How can I help you today?" + +define bot refuse to respond + "I can't help with that request." + +define flow greeting + user express greeting + bot express greeting +""" + +OUTPUT_RAILS_COLANG_CONTENT = """ +define user express greeting + "hello" + "hi" + "hey" + +define bot refuse to respond + "I can't help with that request." 
+ +define flow greeting + user express greeting + # No predefined bot response - will be LLM generated +""" + + +class TestGuardrailsAIBlockingBehavior: + @pytest.mark.skipif( + not GUARDRAILS_AVAILABLE or not REGEX_MATCH_AVAILABLE, + reason="Guardrails or RegexMatch validator not installed", + ) + def test_input_rails_only_validation_passes(self): + """Test input rails when validation passes - conversation continues normally.""" + config = RailsConfig.from_content( + colang_content=COLANG_CONTENT, + yaml_content=INPUT_RAILS_ONLY_CONFIG_EXCEPTION, + ) + + chat = TestChat( + config, + llm_completions=[" express greeting", "Hello! How can I help you today?"], + ) + + chat.user("Hello there!") + chat.bot("Hello! How can I help you today?") + + assert len(chat.history) == 2 + assert chat.history[0]["role"] == "user" + assert chat.history[0]["content"] == "Hello there!" + assert chat.history[1]["role"] == "assistant" + assert "Hello" in chat.history[1]["content"] + + @pytest.mark.skipif( + not GUARDRAILS_AVAILABLE or not REGEX_MATCH_AVAILABLE, + reason="Guardrails or RegexMatch validator not installed", + ) + def test_input_rails_only_validation_blocks_with_exception(self): + """Test input rails when validation fails - blocked with exception.""" + config = RailsConfig.from_content( + colang_content=COLANG_CONTENT, + yaml_content=INPUT_RAILS_ONLY_CONFIG_EXCEPTION, + ) + + llm = FakeLLM( + responses=[" express greeting", "Hello! 
How can I help you today?"] + ) + + rails = LLMRails(config=config, llm=llm) + + result = rails.generate(messages=[{"role": "user", "content": "hello there!"}]) + + assert result["role"] == "exception" + assert result["content"]["type"] == "GuardrailsAIException" + assert ( + "Guardrails AI regex_match validation failed" + in result["content"]["message"] + ) + + @pytest.mark.skipif( + not GUARDRAILS_AVAILABLE or not REGEX_MATCH_AVAILABLE, + reason="Guardrails or RegexMatch validator not installed", + ) + def test_input_rails_only_validation_blocks_with_refuse(self): + """Test input rails when validation fails - blocked with bot refuse.""" + config = RailsConfig.from_content( + colang_content=COLANG_CONTENT, yaml_content=INPUT_RAILS_ONLY_CONFIG_REFUSE + ) + + chat = TestChat( + config, + llm_completions=[" express greeting", "Hello! How can I help you today?"], + ) + + chat.user("hello there!") + chat.bot("I can't help with that request.") + + assert len(chat.history) == 2 + assert chat.history[0]["role"] == "user" + assert chat.history[0]["content"] == "hello there!" + assert chat.history[1]["role"] == "assistant" + assert "can't" in chat.history[1]["content"].lower() + + @pytest.mark.skipif( + not GUARDRAILS_AVAILABLE or not VALID_LENGTH_AVAILABLE, + reason="Guardrails or ValidLength validator not installed", + ) + def test_output_rails_only_validation_passes(self): + """Test output rails when validation passes - response is allowed.""" + config = RailsConfig.from_content( + colang_content=OUTPUT_RAILS_COLANG_CONTENT, + yaml_content=OUTPUT_RAILS_ONLY_CONFIG_EXCEPTION, + ) + + chat = TestChat( + config, + llm_completions=[" express greeting", "general response", "Hi!"], + ) + + chat.user("Hello") + chat.bot("Hi!") + + assert len(chat.history) == 2 + assert chat.history[0]["role"] == "user" + assert chat.history[0]["content"] == "Hello" + assert chat.history[1]["role"] == "assistant" + assert chat.history[1]["content"] == "Hi!" 
+ + @pytest.mark.skipif( + not GUARDRAILS_AVAILABLE or not VALID_LENGTH_AVAILABLE, + reason="Guardrails or ValidLength validator not installed", + ) + def test_output_rails_only_validation_blocks_with_exception(self): + """Test output rails when validation fails - blocked with exception.""" + config = RailsConfig.from_content( + colang_content=OUTPUT_RAILS_COLANG_CONTENT, + yaml_content=OUTPUT_RAILS_ONLY_CONFIG_EXCEPTION, + ) + + llm = FakeLLM( + responses=[ + " express greeting", + "general response", + "This is a very long response that exceeds the maximum length limit set in the validator configuration", + ] + ) + + rails = LLMRails(config=config, llm=llm) + + result = rails.generate(messages=[{"role": "user", "content": "Hello"}]) + + assert result["role"] == "exception" + assert result["content"]["type"] == "GuardrailsAIException" + assert ( + "Guardrails AI valid_length validation failed" + in result["content"]["message"] + ) + + @pytest.mark.skipif( + not GUARDRAILS_AVAILABLE or not VALID_LENGTH_AVAILABLE, + reason="Guardrails or ValidLength validator not installed", + ) + def test_output_rails_only_validation_blocks_with_refuse(self): + """Test output rails when validation fails - blocked with bot refuse.""" + config = RailsConfig.from_content( + colang_content=OUTPUT_RAILS_COLANG_CONTENT, + yaml_content=OUTPUT_RAILS_ONLY_CONFIG_REFUSE, + ) + + chat = TestChat( + config, + llm_completions=[ + " express greeting", + "general response", + "This is a very long response that exceeds the maximum length limit set in the validator configuration", + ], + ) + + chat.user("Hello") + chat.bot("I can't help with that request.") + + assert len(chat.history) == 2 + assert chat.history[0]["role"] == "user" + assert chat.history[0]["content"] == "Hello" + assert chat.history[1]["role"] == "assistant" + assert "can't" in chat.history[1]["content"].lower() + + @pytest.mark.skipif( + not GUARDRAILS_AVAILABLE + or not REGEX_MATCH_AVAILABLE + or not VALID_LENGTH_AVAILABLE, + 
reason="Guardrails, RegexMatch, or ValidLength validator not installed", + ) + def test_input_and_output_rails_both_pass(self): + """Test input+output rails when both validations pass - conversation flows normally.""" + config = RailsConfig.from_content( + colang_content=OUTPUT_RAILS_COLANG_CONTENT, + yaml_content=INPUT_AND_OUTPUT_RAILS_CONFIG_EXCEPTION, + ) + + chat = TestChat( + config, + llm_completions=[ + " express greeting", + "general response", + "Hello! How are you?", + ], + ) + + chat.user("Hello there!") + chat.bot("Hello! How are you?") + + assert len(chat.history) == 2 + assert chat.history[0]["role"] == "user" + assert chat.history[0]["content"] == "Hello there!" + assert chat.history[1]["role"] == "assistant" + assert chat.history[1]["content"] == "Hello! How are you?" + + @pytest.mark.skipif( + not GUARDRAILS_AVAILABLE or not REGEX_MATCH_AVAILABLE, + reason="Guardrails or RegexMatch validator not installed", + ) + def test_input_and_output_rails_input_blocks_with_exception(self): + """Test input+output rails when input validation fails - blocked at input with exception.""" + config = RailsConfig.from_content( + colang_content=OUTPUT_RAILS_COLANG_CONTENT, + yaml_content=INPUT_AND_OUTPUT_RAILS_CONFIG_EXCEPTION, + ) + + llm = FakeLLM( + responses=[" express greeting", "general response", "Hello! 
How are you?"] + ) + + rails = LLMRails(config=config, llm=llm) + + result = rails.generate(messages=[{"role": "user", "content": "hello there!"}]) + + assert result["role"] == "exception" + assert result["content"]["type"] == "GuardrailsAIException" + assert ( + "Guardrails AI regex_match validation failed" + in result["content"]["message"] + ) + + @pytest.mark.skipif( + not GUARDRAILS_AVAILABLE + or not REGEX_MATCH_AVAILABLE + or not VALID_LENGTH_AVAILABLE, + reason="Guardrails, RegexMatch, or ValidLength validator not installed", + ) + def test_input_and_output_rails_output_blocks_with_exception(self): + """Test input+output rails when output validation fails - blocked at output with exception.""" + config = RailsConfig.from_content( + colang_content=OUTPUT_RAILS_COLANG_CONTENT, + yaml_content=INPUT_AND_OUTPUT_RAILS_CONFIG_EXCEPTION, + ) + + llm = FakeLLM( + responses=[ + " express greeting", + "general response", + "This is a very long response that definitely exceeds the maximum length limit", + ] + ) + + rails = LLMRails(config=config, llm=llm) + + result = rails.generate(messages=[{"role": "user", "content": "Hello there!"}]) + + assert result["role"] == "exception" + assert result["content"]["type"] == "GuardrailsAIException" + assert ( + "Guardrails AI valid_length validation failed" + in result["content"]["message"] + ) + + def test_config_structures_are_valid(self): + """Test that all config structures parse correctly.""" + + input_config = RailsConfig.from_content( + colang_content=COLANG_CONTENT, + yaml_content=INPUT_RAILS_ONLY_CONFIG_EXCEPTION, + ) + assert input_config.rails.config.guardrails_ai is not None + assert len(input_config.rails.input.flows) == 1 + assert len(input_config.rails.output.flows) == 0 + + output_config = RailsConfig.from_content( + colang_content=COLANG_CONTENT, + yaml_content=OUTPUT_RAILS_ONLY_CONFIG_EXCEPTION, + ) + assert output_config.rails.config.guardrails_ai is not None + assert len(output_config.rails.input.flows) == 0 + 
assert len(output_config.rails.output.flows) == 1 + + both_config = RailsConfig.from_content( + colang_content=COLANG_CONTENT, + yaml_content=INPUT_AND_OUTPUT_RAILS_CONFIG_EXCEPTION, + ) + assert both_config.rails.config.guardrails_ai is not None + assert len(both_config.rails.input.flows) == 1 + assert len(both_config.rails.output.flows) == 1 + + def test_validator_configurations_are_accessible(self): + """Test that validator configurations can be accessed properly.""" + + config = RailsConfig.from_content( + colang_content=COLANG_CONTENT, + yaml_content=INPUT_AND_OUTPUT_RAILS_CONFIG_EXCEPTION, + ) + + guardrails_config = config.rails.config.guardrails_ai + + regex_validator = guardrails_config.get_validator_config("regex_match") + assert regex_validator.name == "regex_match" + assert regex_validator.parameters["regex"] == "^[A-Z].*" + + length_validator = guardrails_config.get_validator_config("valid_length") + assert length_validator.name == "valid_length" + assert length_validator.parameters["min"] == 1 + assert length_validator.parameters["max"] == 30 + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"])