2 changes: 1 addition & 1 deletion nemoguardrails/actions/llm/generation.py
```diff
@@ -137,7 +137,7 @@ async def init(self):
             self._init_flows_index(),
         )
 
-    def _extract_user_message_example(self, flow: Flow):
+    def _extract_user_message_example(self, flow: Flow) -> None:
        """Heuristic to extract user message examples from a flow."""
        elements = [
            item
```
18 changes: 14 additions & 4 deletions nemoguardrails/context.py
```diff
@@ -14,21 +14,31 @@
 # limitations under the License.
 
 import contextvars
-from typing import Optional
+from typing import TYPE_CHECKING, Optional
+
+if TYPE_CHECKING:
+    from nemoguardrails.logging.explain import ExplainInfo, LLMCallInfo
+    from nemoguardrails.logging.stats import LLMStats
 
 streaming_handler_var = contextvars.ContextVar("streaming_handler", default=None)
 
 # The object that holds additional explanation information.
-explain_info_var = contextvars.ContextVar("explain_info", default=None)
+explain_info_var: contextvars.ContextVar[
+    Optional["ExplainInfo"]
+] = contextvars.ContextVar("explain_info", default=None)
 
 # The current LLM call.
-llm_call_info_var = contextvars.ContextVar("llm_call_info", default=None)
+llm_call_info_var: contextvars.ContextVar[
+    Optional["LLMCallInfo"]
+] = contextvars.ContextVar("llm_call_info", default=None)
 
 # All the generation options applicable to the current context.
 generation_options_var = contextvars.ContextVar("generation_options", default=None)
 
 # The stats about the LLM calls.
-llm_stats_var = contextvars.ContextVar("llm_stats", default=None)
+llm_stats_var: contextvars.ContextVar[Optional["LLMStats"]] = contextvars.ContextVar(
+    "llm_stats", default=None
+)
 
 # The raw LLM request that comes from the user.
 # This is used in passthrough mode.
```
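Note: the `if TYPE_CHECKING:` guard imports `ExplainInfo`, `LLMCallInfo`, and `LLMStats` for the type checker only, so the `ContextVar` annotations can use forward references without creating a runtime import cycle. A minimal self-contained sketch of the same pattern (the `mystats` module and `Stats` class are hypothetical, not part of NeMo Guardrails):

```python
import contextvars
from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    # Evaluated only by type checkers; never imported at runtime.
    from mystats import Stats  # hypothetical module

# The string annotation is a forward reference, so "Stats" does not
# need to be importable when this module loads.
stats_var: contextvars.ContextVar[Optional["Stats"]] = contextvars.ContextVar(
    "stats", default=None
)

def current_stats() -> Optional["Stats"]:
    # Type checkers now infer Optional[Stats] here instead of plain object.
    return stats_var.get()
```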
38 changes: 25 additions & 13 deletions nemoguardrails/logging/callbacks.py
```diff
@@ -15,11 +15,15 @@
 import logging
 import uuid
 from time import time
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Union, cast
 from uuid import UUID
 
 from langchain.callbacks import StdOutCallbackHandler
-from langchain.callbacks.base import AsyncCallbackHandler, BaseCallbackManager
+from langchain.callbacks.base import (
+    AsyncCallbackHandler,
+    BaseCallbackHandler,
+    BaseCallbackManager,
+)
 from langchain.callbacks.manager import AsyncCallbackManagerForChainRun
 from langchain.schema import AgentAction, AgentFinish, AIMessage, BaseMessage, LLMResult
 from langchain_core.outputs import ChatGeneration
@@ -33,7 +37,7 @@
 log = logging.getLogger(__name__)
 
 
-class LoggingCallbackHandler(AsyncCallbackHandler, StdOutCallbackHandler):
+class LoggingCallbackHandler(AsyncCallbackHandler):
     """Async callback handler that can be used to handle callbacks from langchain."""
 
     async def on_llm_start(
@@ -184,10 +188,17 @@ async def on_llm_end(
         )
 
         log.info("Output Stats :: %s", response.llm_output)
-        took = llm_call_info.finished_at - llm_call_info.started_at
-        log.info("--- :: LLM call took %.2f seconds", took)
-        llm_stats.inc("total_time", took)
-        llm_call_info.duration = took
+        if (
+            llm_call_info.finished_at is not None
+            and llm_call_info.started_at is not None
+        ):
+            took = llm_call_info.finished_at - llm_call_info.started_at
+            log.info("--- :: LLM call took %.2f seconds", took)
+            llm_stats.inc("total_time", took)
+            llm_call_info.duration = took
+        else:
+            log.warning("LLM call timing information incomplete")
+            llm_call_info.duration = 0.0
 
         # Update the token usage stats as well
         token_stats_found = False
@@ -259,7 +270,7 @@ async def on_llm_end(
     async def on_llm_error(
         self,
-        error: Union[Exception, KeyboardInterrupt],
+        error: BaseException,
         *,
         run_id: UUID,
         parent_run_id: Optional[UUID] = None,
@@ -290,7 +301,7 @@ async def on_chain_end(
     async def on_chain_error(
         self,
-        error: Union[Exception, KeyboardInterrupt],
+        error: BaseException,
         *,
         run_id: UUID,
         parent_run_id: Optional[UUID] = None,
@@ -321,7 +332,7 @@ async def on_tool_end(
     async def on_tool_error(
         self,
-        error: Union[Exception, KeyboardInterrupt],
+        error: BaseException,
         *,
         run_id: UUID,
         parent_run_id: Optional[UUID] = None,
@@ -362,14 +373,15 @@ async def on_agent_finish(
 handlers = [LoggingCallbackHandler()]
 logging_callbacks = BaseCallbackManager(
-    handlers=handlers, inheritable_handlers=handlers
+    handlers=cast(List[BaseCallbackHandler], handlers),
+    inheritable_handlers=cast(List[BaseCallbackHandler], handlers),
 )
 
 logging_callback_manager_for_chain = AsyncCallbackManagerForChainRun(
     run_id=uuid.uuid4(),
     parent_run_id=None,
-    handlers=handlers,
-    inheritable_handlers=handlers,
+    handlers=cast(List[BaseCallbackHandler], handlers),
+    inheritable_handlers=cast(List[BaseCallbackHandler], handlers),
     tags=[],
     inheritable_tags=[],
 )
```
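Note: two independent fixes here. The `on_*_error` hooks now accept `BaseException`, matching the wider signature used by recent langchain callback base classes, and `on_llm_end` no longer subtracts timestamps that may be `None`. A distilled sketch of the timing guard; `compute_duration` is an illustrative helper under that assumption, not a function in this codebase:

```python
import logging
from typing import Optional

log = logging.getLogger(__name__)

def compute_duration(
    started_at: Optional[float], finished_at: Optional[float]
) -> float:
    """Elapsed seconds, or 0.0 when either timestamp is missing."""
    if finished_at is not None and started_at is not None:
        took = finished_at - started_at
        log.info("--- :: LLM call took %.2f seconds", took)
        return took
    log.warning("LLM call timing information incomplete")
    return 0.0

assert compute_duration(None, 10.0) == 0.0
assert abs(compute_duration(1.0, 2.5) - 1.5) < 1e-9
```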
78 changes: 52 additions & 26 deletions nemoguardrails/logging/processing_log.py
```diff
@@ -153,25 +153,36 @@ def compute_generation_log(processing_log: List[dict]) -> GenerationLog:
                 action_params=event_data["action_params"],
                 started_at=event["timestamp"],
             )
-            activated_rail.executed_actions.append(executed_action)
+            if activated_rail is not None:
+                activated_rail.executed_actions.append(executed_action)
 
         elif event_type == "InternalSystemActionFinished":
             action_name = event_data["action_name"]
             if action_name in ignored_actions:
                 continue
 
-            executed_action.finished_at = event["timestamp"]
-            executed_action.duration = (
-                executed_action.finished_at - executed_action.started_at
-            )
-            executed_action.return_value = event_data["return_value"]
+            if executed_action is not None:
+                executed_action.finished_at = event["timestamp"]
+                if (
+                    executed_action.finished_at is not None
+                    and executed_action.started_at is not None
+                ):
+                    executed_action.duration = (
+                        executed_action.finished_at - executed_action.started_at
+                    )
+                executed_action.return_value = event_data["return_value"]
             executed_action = None
 
         elif event_type in ["InputRailFinished", "OutputRailFinished"]:
-            activated_rail.finished_at = event["timestamp"]
-            activated_rail.duration = (
-                activated_rail.finished_at - activated_rail.started_at
-            )
+            if activated_rail is not None:
+                activated_rail.finished_at = event["timestamp"]
+                if (
+                    activated_rail.finished_at is not None
+                    and activated_rail.started_at is not None
+                ):
+                    activated_rail.duration = (
+                        activated_rail.finished_at - activated_rail.started_at
+                    )
             activated_rail = None
 
         elif event_type == "InputRailsFinished":
@@ -181,14 +192,21 @@ def compute_generation_log(processing_log: List[dict]) -> GenerationLog:
             output_rails_finished_at = event["timestamp"]
 
         elif event["type"] == "llm_call_info":
-            executed_action.llm_calls.append(event["data"])
+            if executed_action is not None:
+                executed_action.llm_calls.append(event["data"])
 
     # If at the end of the processing we still have an active rail, it is because
     # we have hit a stop. In this case, we take the last timestamp as the timestamp for
     # finishing the rail.
     if activated_rail is not None:
         activated_rail.finished_at = last_timestamp
-        activated_rail.duration = activated_rail.finished_at - activated_rail.started_at
+        if (
+            activated_rail.finished_at is not None
+            and activated_rail.started_at is not None
+        ):
+            activated_rail.duration = (
+                activated_rail.finished_at - activated_rail.started_at
+            )
 
         if activated_rail.type in ["input", "output"]:
             activated_rail.stop = True
@@ -213,9 +231,13 @@ def compute_generation_log(processing_log: List[dict]) -> GenerationLog:
         if activated_rail.type in ["dialog", "generation"]:
             next_rail = generation_log.activated_rails[i + 1]
             activated_rail.finished_at = next_rail.started_at
-            activated_rail.duration = (
-                activated_rail.finished_at - activated_rail.started_at
-            )
+            if (
+                activated_rail.finished_at is not None
+                and activated_rail.started_at is not None
+            ):
+                activated_rail.duration = (
+                    activated_rail.finished_at - activated_rail.started_at
+                )
 
     # If we have output rails, we also record the general stats
     if output_rails_started_at:
@@ -257,17 +279,21 @@ def compute_generation_log(processing_log: List[dict]) -> GenerationLog:
 
         for executed_action in activated_rail.executed_actions:
             for llm_call in executed_action.llm_calls:
-                generation_log.stats.llm_calls_count += 1
-                generation_log.stats.llm_calls_duration += llm_call.duration
-                generation_log.stats.llm_calls_total_prompt_tokens += (
-                    llm_call.prompt_tokens or 0
-                )
-                generation_log.stats.llm_calls_total_completion_tokens += (
-                    llm_call.completion_tokens or 0
-                )
-                generation_log.stats.llm_calls_total_tokens += (
-                    llm_call.total_tokens or 0
-                )
+                generation_log.stats.llm_calls_count = (
+                    generation_log.stats.llm_calls_count or 0
+                ) + 1
+                generation_log.stats.llm_calls_duration = (
+                    generation_log.stats.llm_calls_duration or 0
+                ) + (llm_call.duration or 0)
+                generation_log.stats.llm_calls_total_prompt_tokens = (
+                    generation_log.stats.llm_calls_total_prompt_tokens or 0
+                ) + (llm_call.prompt_tokens or 0)
+                generation_log.stats.llm_calls_total_completion_tokens = (
+                    generation_log.stats.llm_calls_total_completion_tokens or 0
+                ) + (llm_call.completion_tokens or 0)
+                generation_log.stats.llm_calls_total_tokens = (
+                    generation_log.stats.llm_calls_total_tokens or 0
+                ) + (llm_call.total_tokens or 0)
 
     generation_log.stats.total_duration = (
         processing_log[-1]["timestamp"] - processing_log[0]["timestamp"]
```
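Note: the `+=` accumulations were replaced with the `(total or 0) + (value or 0)` idiom because both the stats fields and the per-call values can start out as `None`. One reviewer caveat: `or 0` also coerces a legitimate `0`/`0.0` to zero, which is harmless for addition but would matter for other folds. A small sketch of the idiom under that assumption:

```python
from typing import Optional

def accumulate(total: Optional[float], value: Optional[float]) -> float:
    """None-safe addition: missing operands count as zero."""
    return (total or 0) + (value or 0)

# Folding token counts where both the running total and the
# individual measurements may be None.
total = None
for tokens in [12, None, 30]:
    total = accumulate(total, tokens)
assert total == 42
```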
9 changes: 7 additions & 2 deletions nemoguardrails/logging/verbose.py
```diff
@@ -54,7 +54,10 @@ def emit(self, record) -> None:
                 skip_print = True
                 if verbose_llm_calls:
                     console.print("")
-                    console.print(f"[cyan]LLM {title} ({record.id[:5]}..)[/]")
+                    record_id = getattr(record, "id", "unknown")
+                    console.print(
+                        f"[cyan]LLM {title} ({record_id[:5] if record_id != 'unknown' else record_id}..)[/]"
+                    )
                     for line in body.split("\n"):
                         text = Text(line, style="black on #006600", end="\n")
                         text.pad_right(console.width)
@@ -66,8 +69,10 @@ def emit(self, record) -> None:
             if verbose_llm_calls:
                 skip_print = True
                 console.print("")
+                record_id = getattr(record, "id", "unknown")
+                record_task = getattr(record, "task", "unknown")
                 console.print(
-                    f"[cyan]LLM Prompt ({record.id[:5]}..) - {record.task}[/]"
+                    f"[cyan]LLM Prompt ({record_id[:5] if record_id != 'unknown' else record_id}..) - {record_task}[/]"
                 )
 
                 for line in body.split("\n"):
```
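Note: `getattr(record, "id", "unknown")` keeps `emit()` from raising `AttributeError` on log records created without the custom `id`/`task` attributes, and the conditional expression avoids slicing the sentinel string. A quick sketch of the behavior; `prompt_label` is an illustrative helper, not code from this file:

```python
import logging

def prompt_label(record: logging.LogRecord) -> str:
    """Label a record, tolerating missing custom attributes."""
    record_id = getattr(record, "id", "unknown")
    record_task = getattr(record, "task", "unknown")
    # Only truncate real ids; print the "unknown" sentinel as-is.
    shown = record_id[:5] if record_id != "unknown" else record_id
    return f"LLM Prompt ({shown}..) - {record_task}"

rec = logging.LogRecord("x", logging.INFO, __file__, 1, "msg", None, None)
print(prompt_label(rec))  # -> LLM Prompt (unknown..) - unknown
```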