diff --git a/nemoguardrails/actions/llm/generation.py b/nemoguardrails/actions/llm/generation.py
index 2a57e1c26..cd11e70a7 100644
--- a/nemoguardrails/actions/llm/generation.py
+++ b/nemoguardrails/actions/llm/generation.py
@@ -137,7 +137,7 @@ async def init(self):
             self._init_flows_index(),
         )

-    def _extract_user_message_example(self, flow: Flow):
+    def _extract_user_message_example(self, flow: Flow) -> None:
        """Heuristic to extract user message examples from a flow."""
         elements = [
             item
diff --git a/nemoguardrails/context.py b/nemoguardrails/context.py
index e66f1a0d5..2dbb2a61b 100644
--- a/nemoguardrails/context.py
+++ b/nemoguardrails/context.py
@@ -14,21 +14,31 @@
 # limitations under the License.

 import contextvars
-from typing import Optional
+from typing import TYPE_CHECKING, Optional
+
+if TYPE_CHECKING:
+    from nemoguardrails.logging.explain import ExplainInfo, LLMCallInfo
+    from nemoguardrails.logging.stats import LLMStats

 streaming_handler_var = contextvars.ContextVar("streaming_handler", default=None)

 # The object that holds additional explanation information.
-explain_info_var = contextvars.ContextVar("explain_info", default=None)
+explain_info_var: contextvars.ContextVar[
+    Optional["ExplainInfo"]
+] = contextvars.ContextVar("explain_info", default=None)

 # The current LLM call.
-llm_call_info_var = contextvars.ContextVar("llm_call_info", default=None)
+llm_call_info_var: contextvars.ContextVar[
+    Optional["LLMCallInfo"]
+] = contextvars.ContextVar("llm_call_info", default=None)

 # All the generation options applicable to the current context.
 generation_options_var = contextvars.ContextVar("generation_options", default=None)

 # The stats about the LLM calls.
-llm_stats_var = contextvars.ContextVar("llm_stats", default=None)
+llm_stats_var: contextvars.ContextVar[Optional["LLMStats"]] = contextvars.ContextVar(
+    "llm_stats", default=None
+)

 # The raw LLM request that comes from the user.
 # This is used in passthrough mode.
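Note on the context.py hunk: putting the ExplainInfo/LLMCallInfo/LLMStats imports under TYPE_CHECKING gives the context variables precise Optional[...] annotations without importing the logging modules at runtime (presumably to avoid a circular import). A minimal standalone sketch of the same pattern; the mymodule.RequestInfo class and request_info_var name are hypothetical, not part of this patch:

    import contextvars
    from typing import TYPE_CHECKING, Optional

    if TYPE_CHECKING:
        # Evaluated only by type checkers, never at runtime.
        from mymodule import RequestInfo

    # The string forward reference keeps the import out of runtime, while
    # checkers still see ContextVar[Optional[RequestInfo]] and can flag
    # .set() calls with the wrong type.
    request_info_var: contextvars.ContextVar[
        Optional["RequestInfo"]
    ] = contextvars.ContextVar("request_info", default=None)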
diff --git a/nemoguardrails/logging/callbacks.py b/nemoguardrails/logging/callbacks.py
index 48293bf13..9009c19ac 100644
--- a/nemoguardrails/logging/callbacks.py
+++ b/nemoguardrails/logging/callbacks.py
@@ -15,11 +15,15 @@
 import logging
 import uuid
 from time import time
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Union, cast
 from uuid import UUID

 from langchain.callbacks import StdOutCallbackHandler
-from langchain.callbacks.base import AsyncCallbackHandler, BaseCallbackManager
+from langchain.callbacks.base import (
+    AsyncCallbackHandler,
+    BaseCallbackHandler,
+    BaseCallbackManager,
+)
 from langchain.callbacks.manager import AsyncCallbackManagerForChainRun
 from langchain.schema import AgentAction, AgentFinish, AIMessage, BaseMessage, LLMResult
 from langchain_core.outputs import ChatGeneration
@@ -33,7 +37,7 @@

 log = logging.getLogger(__name__)

-class LoggingCallbackHandler(AsyncCallbackHandler, StdOutCallbackHandler):
+class LoggingCallbackHandler(AsyncCallbackHandler):
     """Async callback handler that can be used to handle callbacks from langchain."""

     async def on_llm_start(
@@ -184,10 +188,17 @@ async def on_llm_end(
         )
         log.info("Output Stats :: %s", response.llm_output)

-        took = llm_call_info.finished_at - llm_call_info.started_at
-        log.info("--- :: LLM call took %.2f seconds", took)
-        llm_stats.inc("total_time", took)
-        llm_call_info.duration = took
+        if (
+            llm_call_info.finished_at is not None
+            and llm_call_info.started_at is not None
+        ):
+            took = llm_call_info.finished_at - llm_call_info.started_at
+            log.info("--- :: LLM call took %.2f seconds", took)
+            llm_stats.inc("total_time", took)
+            llm_call_info.duration = took
+        else:
+            log.warning("LLM call timing information incomplete")
+            llm_call_info.duration = 0.0

         # Update the token usage stats as well
         token_stats_found = False
@@ -259,7 +270,7 @@ async def on_llm_end(

     async def on_llm_error(
         self,
-        error: Union[Exception, KeyboardInterrupt],
+        error: BaseException,
         *,
         run_id: UUID,
         parent_run_id: Optional[UUID] = None,
@@ -290,7 +301,7 @@ async def on_chain_end(

     async def on_chain_error(
         self,
-        error: Union[Exception, KeyboardInterrupt],
+        error: BaseException,
         *,
         run_id: UUID,
         parent_run_id: Optional[UUID] = None,
@@ -321,7 +332,7 @@ async def on_tool_end(

     async def on_tool_error(
         self,
-        error: Union[Exception, KeyboardInterrupt],
+        error: BaseException,
         *,
         run_id: UUID,
         parent_run_id: Optional[UUID] = None,
@@ -362,14 +373,15 @@ async def on_agent_finish(

 handlers = [LoggingCallbackHandler()]
 logging_callbacks = BaseCallbackManager(
-    handlers=handlers, inheritable_handlers=handlers
+    handlers=cast(List[BaseCallbackHandler], handlers),
+    inheritable_handlers=cast(List[BaseCallbackHandler], handlers),
 )

 logging_callback_manager_for_chain = AsyncCallbackManagerForChainRun(
     run_id=uuid.uuid4(),
     parent_run_id=None,
-    handlers=handlers,
-    inheritable_handlers=handlers,
+    handlers=cast(List[BaseCallbackHandler], handlers),
+    inheritable_handlers=cast(List[BaseCallbackHandler], handlers),
     tags=[],
     inheritable_tags=[],
 )
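Note on the callbacks.py hunks: the error hooks now take error: BaseException, which matches the base callback handler signatures in recent langchain versions, and the duration computation is guarded because started_at/finished_at are presumably Optional floats. A condensed sketch of that narrowing pattern; the CallInfo dataclass here is illustrative, not the library's actual LLMCallInfo:

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class CallInfo:
        started_at: Optional[float] = None
        finished_at: Optional[float] = None
        duration: float = 0.0

    def record_duration(info: CallInfo) -> None:
        # The explicit None checks let a type checker narrow both fields to
        # float before the subtraction; incomplete timing falls back to 0.0,
        # mirroring the patch's warning branch.
        if info.finished_at is not None and info.started_at is not None:
            info.duration = info.finished_at - info.started_at
        else:
            info.duration = 0.0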
diff --git a/nemoguardrails/logging/processing_log.py b/nemoguardrails/logging/processing_log.py
index decc50181..4841b1c97 100644
--- a/nemoguardrails/logging/processing_log.py
+++ b/nemoguardrails/logging/processing_log.py
@@ -153,25 +153,36 @@ def compute_generation_log(processing_log: List[dict]) -> GenerationLog:
                 action_params=event_data["action_params"],
                 started_at=event["timestamp"],
             )
-            activated_rail.executed_actions.append(executed_action)
+            if activated_rail is not None:
+                activated_rail.executed_actions.append(executed_action)

         elif event_type == "InternalSystemActionFinished":
             action_name = event_data["action_name"]
             if action_name in ignored_actions:
                 continue

-            executed_action.finished_at = event["timestamp"]
-            executed_action.duration = (
-                executed_action.finished_at - executed_action.started_at
-            )
-            executed_action.return_value = event_data["return_value"]
+            if executed_action is not None:
+                executed_action.finished_at = event["timestamp"]
+                if (
+                    executed_action.finished_at is not None
+                    and executed_action.started_at is not None
+                ):
+                    executed_action.duration = (
+                        executed_action.finished_at - executed_action.started_at
+                    )
+                executed_action.return_value = event_data["return_value"]
             executed_action = None

         elif event_type in ["InputRailFinished", "OutputRailFinished"]:
-            activated_rail.finished_at = event["timestamp"]
-            activated_rail.duration = (
-                activated_rail.finished_at - activated_rail.started_at
-            )
+            if activated_rail is not None:
+                activated_rail.finished_at = event["timestamp"]
+                if (
+                    activated_rail.finished_at is not None
+                    and activated_rail.started_at is not None
+                ):
+                    activated_rail.duration = (
+                        activated_rail.finished_at - activated_rail.started_at
+                    )
             activated_rail = None

         elif event_type == "InputRailsFinished":
@@ -181,14 +192,21 @@ def compute_generation_log(processing_log: List[dict]) -> GenerationLog:
             output_rails_finished_at = event["timestamp"]

         elif event["type"] == "llm_call_info":
-            executed_action.llm_calls.append(event["data"])
+            if executed_action is not None:
+                executed_action.llm_calls.append(event["data"])

     # If at the end of the processing we still have an active rail, it is because
     # we have hit a stop. In this case, we take the last timestamp as the timestamp for
     # finishing the rail.
     if activated_rail is not None:
         activated_rail.finished_at = last_timestamp
-        activated_rail.duration = activated_rail.finished_at - activated_rail.started_at
+        if (
+            activated_rail.finished_at is not None
+            and activated_rail.started_at is not None
+        ):
+            activated_rail.duration = (
+                activated_rail.finished_at - activated_rail.started_at
+            )

         if activated_rail.type in ["input", "output"]:
             activated_rail.stop = True
@@ -213,9 +231,13 @@ def compute_generation_log(processing_log: List[dict]) -> GenerationLog:
         if activated_rail.type in ["dialog", "generation"]:
             next_rail = generation_log.activated_rails[i + 1]
             activated_rail.finished_at = next_rail.started_at
-            activated_rail.duration = (
-                activated_rail.finished_at - activated_rail.started_at
-            )
+            if (
+                activated_rail.finished_at is not None
+                and activated_rail.started_at is not None
+            ):
+                activated_rail.duration = (
+                    activated_rail.finished_at - activated_rail.started_at
+                )

     # If we have output rails, we also record the general stats
     if output_rails_started_at:
@@ -257,17 +279,21 @@ def compute_generation_log(processing_log: List[dict]) -> GenerationLog:

     for executed_action in activated_rail.executed_actions:
         for llm_call in executed_action.llm_calls:
-            generation_log.stats.llm_calls_count += 1
-            generation_log.stats.llm_calls_duration += llm_call.duration
-            generation_log.stats.llm_calls_total_prompt_tokens += (
-                llm_call.prompt_tokens or 0
-            )
-            generation_log.stats.llm_calls_total_completion_tokens += (
-                llm_call.completion_tokens or 0
-            )
-            generation_log.stats.llm_calls_total_tokens += (
-                llm_call.total_tokens or 0
-            )
+            generation_log.stats.llm_calls_count = (
+                generation_log.stats.llm_calls_count or 0
+            ) + 1
+            generation_log.stats.llm_calls_duration = (
+                generation_log.stats.llm_calls_duration or 0
+            ) + (llm_call.duration or 0)
+            generation_log.stats.llm_calls_total_prompt_tokens = (
+                generation_log.stats.llm_calls_total_prompt_tokens or 0
+            ) + (llm_call.prompt_tokens or 0)
+            generation_log.stats.llm_calls_total_completion_tokens = (
+                generation_log.stats.llm_calls_total_completion_tokens or 0
+            ) + (llm_call.completion_tokens or 0)
+            generation_log.stats.llm_calls_total_tokens = (
+                generation_log.stats.llm_calls_total_tokens or 0
+            ) + (llm_call.total_tokens or 0)

     generation_log.stats.total_duration = (
         processing_log[-1]["timestamp"] - processing_log[0]["timestamp"]
diff --git a/nemoguardrails/logging/verbose.py b/nemoguardrails/logging/verbose.py
index a2f972238..316906fb1 100644
--- a/nemoguardrails/logging/verbose.py
+++ b/nemoguardrails/logging/verbose.py
@@ -54,7 +54,10 @@ def emit(self, record) -> None:
                 skip_print = True
                 if verbose_llm_calls:
                     console.print("")
-                    console.print(f"[cyan]LLM {title} ({record.id[:5]}..)[/]")
+                    record_id = getattr(record, "id", "unknown")
+                    console.print(
+                        f"[cyan]LLM {title} ({record_id[:5] if record_id != 'unknown' else record_id}..)[/]"
+                    )
                     for line in body.split("\n"):
                         text = Text(line, style="black on #006600", end="\n")
                         text.pad_right(console.width)
@@ -66,8 +69,10 @@ def emit(self, record) -> None:

                 if verbose_llm_calls:
                     skip_print = True
                     console.print("")
+                    record_id = getattr(record, "id", "unknown")
+                    record_task = getattr(record, "task", "unknown")
                     console.print(
-                        f"[cyan]LLM Prompt ({record.id[:5]}..) - {record.task}[/]"
+                        f"[cyan]LLM Prompt ({record_id[:5] if record_id != 'unknown' else record_id}..) - {record_task}[/]"
                     )
                     for line in body.split("\n"):
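Note on the verbose.py hunks: direct record.id / record.task access is replaced with getattr() fallbacks, since LogRecord only carries those attributes when the emitting call attached them via extra=, and the guard avoids slicing the fallback string. Likewise, the processing_log.py stats hunk rewrites each += as (value or 0) + increment so possibly-None counters never raise. A small sketch of the getattr pattern under that assumption; the SafeIdHandler and logger names are illustrative:

    import logging

    class SafeIdHandler(logging.Handler):
        def emit(self, record: logging.LogRecord) -> None:
            # "id" exists on the record only when passed through extra=;
            # getattr avoids an AttributeError, and the comparison avoids
            # slicing the "unknown" fallback down to "unkno".
            record_id = getattr(record, "id", "unknown")
            shown = record_id[:5] if record_id != "unknown" else record_id
            print(f"LLM Prompt ({shown}..)")

    log = logging.getLogger("demo")
    log.addHandler(SafeIdHandler())
    log.warning("prompt body", extra={"id": "abc123456"})  # -> LLM Prompt (abc12..)
    log.warning("prompt body")                             # -> LLM Prompt (unknown..)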