fix(utils.py): support streaming cached response logging #2124

Merged: 9 commits, Feb 22, 2024

litellm/caching.py: 3 additions & 1 deletion

@@ -124,7 +124,9 @@ def set_cache(self, key, value, **kwargs):
self.redis_client.set(name=key, value=str(value), ex=ttl)
except Exception as e:
# NON blocking - notify users Redis is throwing an exception
print_verbose("LiteLLM Caching: set() - Got exception from REDIS : ", e)
print_verbose(
f"LiteLLM Caching: set() - Got exception from REDIS : {str(e)}"
)

async def async_set_cache(self, key, value, **kwargs):
_redis_client = self.init_async_client()
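
The caching.py change matters because the logging helper appears to format only a single message argument, so passing the exception as a second positional argument can lose the error text. Below is a minimal sketch of the corrected pattern; the print_verbose helper shown here is a hypothetical stand-in gated on a verbosity flag, not litellm's implementation.

# Minimal sketch of the corrected logging pattern. print_verbose below is a
# hypothetical stand-in assumed to accept a single message string and print it
# only when verbose mode is on.
set_verbose = True

def print_verbose(message: str) -> None:
    if set_verbose:
        print(message)

try:
    raise ConnectionError("Connection refused")  # stand-in for a Redis failure
except Exception as e:
    # Interpolating the exception into the f-string keeps the error text in the
    # log line instead of relying on a second positional argument.
    print_verbose(f"LiteLLM Caching: set() - Got exception from REDIS : {str(e)}")
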
litellm/model_prices_and_context_window_backup.json: 8 additions & 1 deletion

@@ -936,7 +936,14 @@
"mode": "chat"
},
"openrouter/mistralai/mistral-7b-instruct": {
"max_tokens": 4096,
"max_tokens": 8192,
"input_cost_per_token": 0.00000013,
"output_cost_per_token": 0.00000013,
"litellm_provider": "openrouter",
"mode": "chat"
},
"openrouter/mistralai/mistral-7b-instruct:free": {
"max_tokens": 8192,
"input_cost_per_token": 0.0,
"output_cost_per_token": 0.0,
"litellm_provider": "openrouter",
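
The pricing entries above feed litellm's cost tracking. As a rough illustration of how such an entry can be used, the sketch below computes a per-request cost from token counts; estimate_cost is a hypothetical helper written for this example, not litellm's completion_cost implementation.

# Rough sketch of a cost lookup against entries like the ones added above.
# estimate_cost() is a hypothetical helper for illustration, not litellm's API.
MODEL_PRICES = {
    "openrouter/mistralai/mistral-7b-instruct": {
        "max_tokens": 8192,
        "input_cost_per_token": 0.00000013,
        "output_cost_per_token": 0.00000013,
    },
    "openrouter/mistralai/mistral-7b-instruct:free": {
        "max_tokens": 8192,
        "input_cost_per_token": 0.0,
        "output_cost_per_token": 0.0,
    },
}

def estimate_cost(model: str, prompt_tokens: int, completion_tokens: int) -> float:
    entry = MODEL_PRICES[model]
    return (
        prompt_tokens * entry["input_cost_per_token"]
        + completion_tokens * entry["output_cost_per_token"]
    )

# 1000 prompt tokens + 200 completion tokens -> ~$0.000156
print(estimate_cost("openrouter/mistralai/mistral-7b-instruct", 1000, 200))
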
litellm/tests/test_completion.py: 6 additions & 0 deletions

@@ -1986,6 +1986,8 @@ def test_completion_gemini():
response = completion(model=model_name, messages=messages)
# Add any assertions here to check the response
print(response)
except litellm.APIError as e:
pass
except Exception as e:
pytest.fail(f"Error occurred: {e}")

@@ -2015,6 +2017,8 @@ def test_completion_palm():
response = completion(model=model_name, messages=messages)
# Add any assertions here to check the response
print(response)
except litellm.APIError as e:
pass
except Exception as e:
pytest.fail(f"Error occurred: {e}")

@@ -2037,6 +2041,8 @@ def test_completion_palm_stream():
# Add any assertions here to check the response
for chunk in response:
print(chunk)
except litellm.APIError as e:
pass
except Exception as e:
pytest.fail(f"Error occurred: {e}")

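
The new except litellm.APIError branches exist because the exception mapping added in utils.py (see the exception_type hunk further down) converts PaLM/Gemini "500 An internal error has occurred." responses into litellm.APIError. The same pattern works for application code that wants to retry or skip on a transient provider failure; a hedged sketch follows, where the model name and attribute access are illustrative assumptions.

import litellm
from litellm import completion

try:
    response = completion(
        model="gemini/gemini-pro",  # illustrative; any palm/gemini route is mapped the same way
        messages=[{"role": "user", "content": "Hello"}],
    )
    print(response)
except litellm.APIError as e:
    # Transient provider-side 500 mapped by exception_type(); safe to retry or skip here.
    print(f"provider error from {getattr(e, 'llm_provider', 'unknown')}: {e}")
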
litellm/tests/test_custom_callback_input.py: 48 additions & 1 deletion

@@ -2,7 +2,7 @@
## This test asserts the type of data passed into each method of the custom callback handler
import sys, os, time, inspect, asyncio, traceback
from datetime import datetime
import pytest
import pytest, uuid
from pydantic import BaseModel

sys.path.insert(0, os.path.abspath("../.."))
@@ -795,6 +795,53 @@ async def test_async_completion_azure_caching():
assert len(customHandler_caching.states) == 4 # pre, post, success, success


@pytest.mark.asyncio
async def test_async_completion_azure_caching_streaming():
import copy

litellm.set_verbose = True
customHandler_caching = CompletionCustomHandler()
litellm.cache = Cache(
type="redis",
host=os.environ["REDIS_HOST"],
port=os.environ["REDIS_PORT"],
password=os.environ["REDIS_PASSWORD"],
)
litellm.callbacks = [customHandler_caching]
unique_time = uuid.uuid4()
response1 = await litellm.acompletion(
model="azure/chatgpt-v-2",
messages=[
{"role": "user", "content": f"Hi 👋 - i'm async azure {unique_time}"}
],
caching=True,
stream=True,
)
async for chunk in response1:
print(f"chunk in response1: {chunk}")
await asyncio.sleep(1)
initial_customhandler_caching_states = len(customHandler_caching.states)
print(f"customHandler_caching.states pre-cache hit: {customHandler_caching.states}")
response2 = await litellm.acompletion(
model="azure/chatgpt-v-2",
messages=[
{"role": "user", "content": f"Hi 👋 - i'm async azure {unique_time}"}
],
caching=True,
stream=True,
)
async for chunk in response2:
print(f"chunk in response2: {chunk}")
await asyncio.sleep(1) # success callbacks are done in parallel
print(
f"customHandler_caching.states post-cache hit: {customHandler_caching.states}"
)
assert len(customHandler_caching.errors) == 0
assert (
len(customHandler_caching.states) > initial_customhandler_caching_states
) # pre, post, streaming .., success, success


@pytest.mark.asyncio
async def test_async_embedding_azure_caching():
print("Testing custom callback input - Azure Caching")
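
test_async_completion_azure_caching_streaming drives the same streamed request through the Redis cache twice and asserts the custom callback still records success states on the cache hit. A condensed sketch of the callback side is shown below; it assumes the CustomLogger base class and the async_log_success_event signature that the utils.py hunk further down calls into, so treat the import path and method shape as assumptions.

import asyncio, os
import litellm
from litellm import Cache
from litellm.integrations.custom_logger import CustomLogger  # assumed import path

class CacheHitLogger(CustomLogger):
    """Collects success events so a test can assert they fire for cached, streamed calls."""

    def __init__(self):
        self.success_events = []

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        # kwargs mirrors model_call_details, including any complete_streaming_response.
        self.success_events.append(kwargs.get("model"))

async def main():
    litellm.cache = Cache(
        type="redis",
        host=os.environ["REDIS_HOST"],
        port=os.environ["REDIS_PORT"],
        password=os.environ["REDIS_PASSWORD"],
    )
    litellm.callbacks = [CacheHitLogger()]
    for attempt in range(2):  # the second pass should be served from the cache
        response = await litellm.acompletion(
            model="azure/chatgpt-v-2",
            messages=[{"role": "user", "content": "Hi"}],
            caching=True,
            stream=True,
        )
        async for chunk in response:
            print(f"attempt {attempt}: {chunk}")
        await asyncio.sleep(1)  # success callbacks are dispatched asynchronously

asyncio.run(main())
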
litellm/tests/test_streaming.py: 6 additions & 0 deletions

@@ -392,6 +392,8 @@ def test_completion_palm_stream():
if complete_response.strip() == "":
raise Exception("Empty response received")
print(f"completion_response: {complete_response}")
except litellm.APIError as e:
pass
except Exception as e:
pytest.fail(f"Error occurred: {e}")

@@ -425,6 +427,8 @@ def test_completion_gemini_stream():
if complete_response.strip() == "":
raise Exception("Empty response received")
print(f"completion_response: {complete_response}")
except litellm.APIError as e:
pass
except Exception as e:
pytest.fail(f"Error occurred: {e}")

@@ -461,6 +465,8 @@ async def test_acompletion_gemini_stream():
print(f"completion_response: {complete_response}")
if complete_response.strip() == "":
raise Exception("Empty response received")
except litellm.APIError as e:
pass
except Exception as e:
pytest.fail(f"Error occurred: {e}")

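
For context, the streaming tests accumulate chunk deltas into complete_response before asserting it is non-empty. A condensed sketch of that consumption pattern is below; the model name and delta handling are illustrative, though the chunk shape matches the choices[0].delta.content access used in the cached_response handling in utils.py further down.

import litellm
from litellm import completion

try:
    response = completion(
        model="gemini/gemini-pro",  # illustrative model name
        messages=[{"role": "user", "content": "Tell me a short story"}],
        stream=True,
    )
    complete_response = ""
    for chunk in response:
        # OpenAI-shaped chunks; delta.content may be None on the final chunk.
        complete_response += chunk.choices[0].delta.content or ""
    if complete_response.strip() == "":
        raise Exception("Empty response received")
    print(f"completion_response: {complete_response}")
except litellm.APIError:
    pass  # tolerated provider-side failure, as in the tests above
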
litellm/utils.py: 77 additions & 36 deletions

@@ -1411,7 +1411,7 @@ def success_handler(
print_verbose(
f"success_callback: reaches cache for logging, there is no complete_streaming_response. Kwargs={kwargs}\n\n"
)
return
pass
else:
print_verbose(
"success_callback: reaches cache for logging, there is a complete_streaming_response. Adding to cache"
@@ -1616,7 +1616,7 @@ async def async_success_handler(
print_verbose(
f"async success_callback: reaches cache for logging, there is no complete_streaming_response. Kwargs={kwargs}\n\n"
)
return
pass
else:
print_verbose(
"async success_callback: reaches cache for logging, there is a complete_streaming_response. Adding to cache"
@@ -1625,8 +1625,10 @@
# only add to cache once we have a complete streaming response
litellm.cache.add_cache(result, **kwargs)
if isinstance(callback, CustomLogger): # custom logger class
print_verbose(f"Async success callbacks: {callback}")
if self.stream:
print_verbose(
f"Async success callbacks: {callback}; self.stream: {self.stream}; complete_streaming_response: {self.model_call_details.get('complete_streaming_response', None)}"
)
if self.stream == True:
if "complete_streaming_response" in self.model_call_details:
await callback.async_log_success_event(
kwargs=self.model_call_details,
@@ -2328,6 +2330,13 @@ def wrapper(*args, **kwargs):
model_response_object=ModelResponse(),
stream=kwargs.get("stream", False),
)
if kwargs.get("stream", False) == True:
cached_result = CustomStreamWrapper(
completion_stream=cached_result,
model=model,
custom_llm_provider="cached_response",
logging_obj=logging_obj,
)
elif call_type == CallTypes.embedding.value and isinstance(
cached_result, dict
):
@@ -2624,28 +2633,6 @@ async def wrapper_async(*args, **kwargs):
cached_result, list
):
print_verbose(f"Cache Hit!")
call_type = original_function.__name__
if call_type == CallTypes.acompletion.value and isinstance(
cached_result, dict
):
if kwargs.get("stream", False) == True:
cached_result = convert_to_streaming_response_async(
response_object=cached_result,
)
else:
cached_result = convert_to_model_response_object(
response_object=cached_result,
model_response_object=ModelResponse(),
)
elif call_type == CallTypes.aembedding.value and isinstance(
cached_result, dict
):
cached_result = convert_to_model_response_object(
response_object=cached_result,
model_response_object=EmbeddingResponse(),
response_type="embedding",
)
# LOG SUCCESS
cache_hit = True
end_time = datetime.datetime.now()
(
@@ -2685,15 +2672,44 @@ async def wrapper_async(*args, **kwargs):
additional_args=None,
stream=kwargs.get("stream", False),
)
asyncio.create_task(
logging_obj.async_success_handler(
cached_result, start_time, end_time, cache_hit
call_type = original_function.__name__
if call_type == CallTypes.acompletion.value and isinstance(
cached_result, dict
):
if kwargs.get("stream", False) == True:
cached_result = convert_to_streaming_response_async(
response_object=cached_result,
)
cached_result = CustomStreamWrapper(
completion_stream=cached_result,
model=model,
custom_llm_provider="cached_response",
logging_obj=logging_obj,
)
else:
cached_result = convert_to_model_response_object(
response_object=cached_result,
model_response_object=ModelResponse(),
)
elif call_type == CallTypes.aembedding.value and isinstance(
cached_result, dict
):
cached_result = convert_to_model_response_object(
response_object=cached_result,
model_response_object=EmbeddingResponse(),
response_type="embedding",
)
)
threading.Thread(
target=logging_obj.success_handler,
args=(cached_result, start_time, end_time, cache_hit),
).start()
if kwargs.get("stream", False) == False:
# LOG SUCCESS
asyncio.create_task(
logging_obj.async_success_handler(
cached_result, start_time, end_time, cache_hit
)
)
threading.Thread(
target=logging_obj.success_handler,
args=(cached_result, start_time, end_time, cache_hit),
).start()
return cached_result
elif (
call_type == CallTypes.aembedding.value
@@ -4296,7 +4312,9 @@ def _map_and_modify_arg(supported_params: dict, provider: str, model: str):
parameters=tool["function"].get("parameters", {}),
)
gtool_func_declarations.append(gtool_func_declaration)
optional_params["tools"] = [generative_models.Tool(function_declarations=gtool_func_declarations)]
optional_params["tools"] = [
generative_models.Tool(function_declarations=gtool_func_declarations)
]
elif custom_llm_provider == "sagemaker":
## check if unsupported param passed in
supported_params = ["stream", "temperature", "max_tokens", "top_p", "stop", "n"]
@@ -6795,7 +6813,7 @@ def exception_type(
llm_provider="vertex_ai",
request=original_exception.request,
)
elif custom_llm_provider == "palm":
elif custom_llm_provider == "palm" or custom_llm_provider == "gemini":
if "503 Getting metadata" in error_str:
# auth errors look like this
# 503 Getting metadata from plugin failed with error: Reauthentication is needed. Please run `gcloud auth application-default login` to reauthenticate.
@@ -6814,6 +6832,15 @@
llm_provider="palm",
response=original_exception.response,
)
if "500 An internal error has occurred." in error_str:
exception_mapping_worked = True
raise APIError(
status_code=original_exception.status_code,
message=f"PalmException - {original_exception.message}",
llm_provider="palm",
model=model,
request=original_exception.request,
)
if hasattr(original_exception, "status_code"):
if original_exception.status_code == 400:
exception_mapping_worked = True
@@ -8524,6 +8551,19 @@ def chunk_creator(self, chunk):
]
elif self.custom_llm_provider == "text-completion-openai":
response_obj = self.handle_openai_text_completion_chunk(chunk)
completion_obj["content"] = response_obj["text"]
print_verbose(f"completion obj content: {completion_obj['content']}")
if response_obj["is_finished"]:
model_response.choices[0].finish_reason = response_obj[
"finish_reason"
]
elif self.custom_llm_provider == "cached_response":
response_obj = {
"text": chunk.choices[0].delta.content,
"is_finished": True,
"finish_reason": chunk.choices[0].finish_reason,
}

completion_obj["content"] = response_obj["text"]
print_verbose(f"completion obj content: {completion_obj['content']}")
if response_obj["is_finished"]:
@@ -8732,6 +8772,7 @@ async def __anext__(self):
or self.custom_llm_provider == "vertex_ai"
or self.custom_llm_provider == "sagemaker"
or self.custom_llm_provider == "gemini"
or self.custom_llm_provider == "cached_response"
or self.custom_llm_provider in litellm.openai_compatible_endpoints
):
async for chunk in self.completion_stream:
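
Taken together, the utils.py changes make a cache hit on a streamed call behave like a live stream: the cached dict is converted back into a streaming response, wrapped in CustomStreamWrapper with custom_llm_provider="cached_response", replayed by chunk_creator as a single finished chunk, and still routed through the success handlers. The standalone sketch below illustrates that replay idea with simplified stand-in types; it is not litellm's implementation.

# Standalone sketch of replaying a cached completion as a stream. The dataclasses
# are simplified stand-ins for OpenAI-shaped chunks, not litellm's classes.
import asyncio
from dataclasses import dataclass, field
from typing import List, Optional

@dataclass
class Delta:
    content: Optional[str] = None

@dataclass
class Choice:
    delta: Delta = field(default_factory=Delta)
    finish_reason: Optional[str] = None

@dataclass
class Chunk:
    choices: List[Choice] = field(default_factory=list)

async def replay_cached_response(cached: dict):
    """Yield the cached completion text as one finished chunk, like a 'cached_response' stream."""
    text = cached["choices"][0]["message"]["content"]
    yield Chunk(choices=[Choice(delta=Delta(content=text), finish_reason="stop")])

async def main():
    cached_result = {"choices": [{"message": {"content": "Hello from the cache"}}]}
    async for chunk in replay_cached_response(cached_result):
        # The whole cached text arrives at once, with the finish_reason already set.
        print(chunk.choices[0].delta.content, chunk.choices[0].finish_reason)

asyncio.run(main())
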