103 changes: 103 additions & 0 deletions enterprise/enterprise_hooks/banned_keywords.py
@@ -0,0 +1,103 @@
# +------------------------------+
#
# Banned Keywords
#
# +------------------------------+
# Thank you users! We ❤️ you! - Krrish & Ishaan
## Reject a call / response if it contains certain keywords


from typing import Literal
import litellm
from litellm.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth
from litellm.integrations.custom_logger import CustomLogger
from litellm._logging import verbose_proxy_logger
from fastapi import HTTPException
import traceback


class _ENTERPRISE_BannedKeywords(CustomLogger):
# Class variables or attributes
def __init__(self):
banned_keywords_list = litellm.banned_keywords_list

        if banned_keywords_list is None:
            raise Exception(
                "`banned_keywords_list` must be either a list of keywords or a filepath. None set."
            )

if isinstance(banned_keywords_list, list):
self.banned_keywords_list = banned_keywords_list

        if isinstance(banned_keywords_list, str):  # assume it's a filepath
            try:
                with open(banned_keywords_list, "r") as file:
                    data = file.read()
                    # drop blank lines - an empty keyword would match every request
                    self.banned_keywords_list = [
                        line.strip() for line in data.split("\n") if line.strip()
                    ]
except FileNotFoundError:
raise Exception(
f"File not found. banned_keywords_list={banned_keywords_list}"
)
except Exception as e:
raise Exception(
f"An error occurred: {str(e)}, banned_keywords_list={banned_keywords_list}"
)

def print_verbose(self, print_statement, level: Literal["INFO", "DEBUG"] = "DEBUG"):
if level == "INFO":
verbose_proxy_logger.info(print_statement)
elif level == "DEBUG":
verbose_proxy_logger.debug(print_statement)

if litellm.set_verbose is True:
print(print_statement) # noqa

    def test_violation(self, test_str: str):
        # case-insensitive substring match against each banned keyword
        for word in self.banned_keywords_list:
            if word.lower() in test_str.lower():
                raise HTTPException(
                    status_code=400,
                    detail={"error": f"Keyword banned. Keyword={word}"},
                )

async def async_pre_call_hook(
self,
user_api_key_dict: UserAPIKeyAuth,
cache: DualCache,
data: dict,
call_type: str, # "completion", "embeddings", "image_generation", "moderation"
):
        try:
            """
            - scan the incoming request messages for banned keywords
            - reject the call with a 400 if any keyword is found
            """
            self.print_verbose("Inside Banned Keyword List Pre-Call Hook")
            if call_type == "completion" and "messages" in data:
                for m in data["messages"]:
                    if "content" in m and isinstance(m["content"], str):
                        self.test_violation(test_str=m["content"])

        except HTTPException as e:
            raise e
        except Exception:
            traceback.print_exc()

async def async_post_call_success_hook(
self,
user_api_key_dict: UserAPIKeyAuth,
response,
):
        if isinstance(response, litellm.ModelResponse) and isinstance(
            response.choices[0], litellm.utils.Choices
        ):
            # test_violation already loops over every banned keyword,
            # so one call checks the whole response message
            self.test_violation(test_str=response.choices[0].message.content)

async def async_post_call_streaming_hook(
self,
user_api_key_dict: UserAPIKeyAuth,
response: str,
):
self.test_violation(test_str=response)
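For reviewers, a minimal sketch of exercising the new pre-call hook directly. The module path, class, and hook signature are as added above; the keyword and message mirror the test file in this PR:

import asyncio
import litellm
from litellm.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.enterprise.enterprise_hooks.banned_keywords import (
    _ENTERPRISE_BannedKeywords,
)

# module value read by the hook's __init__ - a list, or a filepath with one keyword per line
litellm.banned_keywords_list = ["hello"]


async def main():
    hook = _ENTERPRISE_BannedKeywords()
    try:
        await hook.async_pre_call_hook(
            user_api_key_dict=UserAPIKeyAuth(api_key="sk-12345"),
            cache=DualCache(),
            call_type="completion",
            data={"messages": [{"role": "user", "content": "Hello world"}]},
        )
    except Exception as e:
        print(e)  # HTTPException(400): "Keyword banned. Keyword=hello"


asyncio.run(main())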
1 change: 1 addition & 0 deletions litellm/__init__.py
@@ -61,6 +61,7 @@
google_moderation_confidence_threshold: Optional[float] = None
llamaguard_unsafe_content_categories: Optional[str] = None
blocked_user_list: Optional[Union[str, List]] = None
banned_keywords_list: Optional[Union[str, List]] = None
##################
logging: bool = True
caching: bool = (
10 changes: 10 additions & 0 deletions litellm/proxy/proxy_server.py
@@ -1489,6 +1489,16 @@ async def load_config(

blocked_user_list = _ENTERPRISE_BlockedUserList()
imported_list.append(blocked_user_list)
elif (
isinstance(callback, str)
and callback == "banned_keywords"
):
from litellm.proxy.enterprise.enterprise_hooks.banned_keywords import (
_ENTERPRISE_BannedKeywords,
)

banned_keywords_obj = _ENTERPRISE_BannedKeywords()
imported_list.append(banned_keywords_obj)
else:
imported_list.append(
get_instance_fn(
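A sketch of what this branch wires up, re-created outside load_config() for review; the assumption (by analogy with the blocked-user hook above) is that the string "banned_keywords" arrives here from the callbacks list in the proxy's YAML config:

import litellm
from litellm.proxy.enterprise.enterprise_hooks.banned_keywords import (
    _ENTERPRISE_BannedKeywords,
)

# values load_config() would normally read from the proxy config
litellm.banned_keywords_list = ["hello"]  # or a filepath
callback = "banned_keywords"

imported_list = []
if isinstance(callback, str) and callback == "banned_keywords":
    # same instantiation the new elif branch performs
    imported_list.append(_ENTERPRISE_BannedKeywords())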
63 changes: 63 additions & 0 deletions litellm/tests/test_banned_keyword_list.py
@@ -0,0 +1,63 @@
# What is this?
## This tests the banned keywords pre-call hook for the proxy server


import sys, os, asyncio, time, random
from datetime import datetime
import traceback
from dotenv import load_dotenv

load_dotenv()

sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import pytest
import litellm
from litellm.proxy.enterprise.enterprise_hooks.banned_keywords import (
_ENTERPRISE_BannedKeywords,
)
from litellm import Router, mock_completion
from litellm.proxy.utils import ProxyLogging
from litellm.proxy._types import UserAPIKeyAuth
from litellm.caching import DualCache


@pytest.mark.asyncio
async def test_banned_keywords_check():
"""
    - Set some banned keywords as a litellm module value
    - Test that a call containing a banned keyword raises an error
    - Test that a call without banned keywords passes
"""
litellm.banned_keywords_list = ["hello"]

banned_keywords_obj = _ENTERPRISE_BannedKeywords()

_api_key = "sk-12345"
user_api_key_dict = UserAPIKeyAuth(api_key=_api_key)
local_cache = DualCache()

    ## Case 1: message contains a banned keyword
try:
await banned_keywords_obj.async_pre_call_hook(
user_api_key_dict=user_api_key_dict,
cache=local_cache,
call_type="completion",
data={"messages": [{"role": "user", "content": "Hello world"}]},
)
pytest.fail(f"Expected call to fail")
except Exception as e:
pass

    ## Case 2: message contains no banned keywords
try:
await banned_keywords_obj.async_pre_call_hook(
user_api_key_dict=user_api_key_dict,
cache=local_cache,
call_type="completion",
data={"messages": [{"role": "user", "content": "Hey, how's it going?"}]},
)
except Exception as e:
pytest.fail(f"An error occurred - {str(e)}")
2 changes: 2 additions & 0 deletions litellm/tests/test_streaming.py
@@ -392,6 +392,8 @@ def test_completion_palm_stream():
if complete_response.strip() == "":
raise Exception("Empty response received")
print(f"completion_response: {complete_response}")
except litellm.Timeout as e:
pass
except litellm.APIError as e:
pass
except Exception as e:
8 changes: 8 additions & 0 deletions litellm/utils.py
@@ -6824,6 +6824,14 @@ def exception_type(
llm_provider="palm",
response=original_exception.response,
)
if "504 Deadline expired before operation could complete." in error_str:
exception_mapping_worked = True
raise Timeout(
message=f"PalmException - {original_exception.message}",
model=model,
llm_provider="palm",
request=original_exception.request,
)
if "400 Request payload size exceeds" in error_str:
exception_mapping_worked = True
raise ContextWindowExceededError(
Expand Down