Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
This module provides the main entry point for the autogen_agentchat package.
It includes logger names for trace and event logs, and retrieves the package version.
"""
#from .agents._assistant_agent import AssistantAgent
from .utils.constants import EVENT_LOGGER_NAME
# autogen_agentchat/__init__.py
from .agents._assistant_agent import AssistantAgent

import importlib.metadata

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@
from pydantic import BaseModel, Field
from typing_extensions import Self

from .. import EVENT_LOGGER_NAME
#from .. import EVENT_LOGGER_NAME
from ..utils.constants import EVENT_LOGGER_NAME

from ..base import Handoff as HandoffBase
from ..base import Response
from ..messages import (
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# autogen_agentchat/utils/constants.py
# Name of the logger used for event logs across the package; import this
# constant rather than hard-coding the "event_logger" string at call sites.
EVENT_LOGGER_NAME = "event_logger"
16 changes: 16 additions & 0 deletions python/packages/autogen-core/src/autogen_core/utils/sanitizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
def sanitize_tool_calls(message: dict) -> dict:
    """Validate the tool-call fields of an LLM response message.

    Checks ``message`` for known malformed tool-call patterns and raises
    on the first one found; otherwise returns the message unchanged.

    Args:
        message: A chat-completion message dict that may contain
            ``"tool_calls"`` and/or ``"content"`` keys.

    Returns:
        The same ``message`` dict, unmodified, when no malformed
        pattern is detected.

    Raises:
        ValueError: If ``tool_calls`` is present but an empty list, or
            if the content is nothing but a spurious
            ``<|tool_call_end|>`` marker.
    """
    # Use None as the "key absent" sentinel: a message with no tool_calls
    # key at all (a plain text response) is valid and must not be rejected.
    # The previous default of [] wrongly flagged every tool-call-free
    # message as malformed.
    tool_calls = message.get("tool_calls")
    content = message.get("content", "")

    # Block an explicitly present but empty tool_calls list.
    if isinstance(tool_calls, list) and len(tool_calls) == 0:
        raise ValueError("Malformed tool call: tool_calls is an empty list.")

    # Block content that is nothing but a spurious end-of-tool-call marker.
    if isinstance(content, str) and content.strip() == "<|tool_call_end|>":
        raise ValueError("Invalid content-only tool call marker received.")

    return message
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import logging
import math
import os
import sys
import re
import warnings
from asyncio import Task
Expand Down Expand Up @@ -677,9 +678,9 @@ async def create(
json_output,
extra_create_args,
)

future: Union[Task[ParsedChatCompletion[BaseModel]], Task[ChatCompletion]]
if create_params.response_format is not None:
# Use beta client if response_format is not None
future = asyncio.ensure_future(
self._client.beta.chat.completions.parse(
messages=create_params.messages,
Expand All @@ -689,7 +690,6 @@ async def create(
)
)
else:
# Use the regular client
future = asyncio.ensure_future(
self._client.chat.completions.create(
messages=create_params.messages,
Expand All @@ -701,14 +701,19 @@ async def create(

if cancellation_token is not None:
cancellation_token.link_future(future)

result: Union[ParsedChatCompletion[BaseModel], ChatCompletion] = await future
if create_params.response_format is not None:
result = cast(ParsedChatCompletion[Any], result)


# Sanitize malformed tool call responses here
from autogen_core.utils.sanitizer import sanitize_tool_calls
result_dict = result.model_dump()
sanitized_result_dict = sanitize_tool_calls(result_dict)
result = result.__class__(**sanitized_result_dict) # Rebuild the same class

# Handle the case where OpenAI API might return None for token counts
# even when result.usage is not None
usage = RequestUsage(
# TODO backup token counting
prompt_tokens=getattr(result.usage, "prompt_tokens", 0) if result.usage is not None else 0,
completion_tokens=getattr(result.usage, "completion_tokens", 0) if result.usage is not None else 0,
)
Expand All @@ -723,20 +728,16 @@ async def create(
)
)

if self._resolved_model is not None:
if self._resolved_model != result.model:
warnings.warn(
f"Resolved model mismatch: {self._resolved_model} != {result.model}. "
"Model mapping in autogen_ext.models.openai may be incorrect. "
f"Set the model to {result.model} to enhance token/cost estimation and suppress this warning.",
stacklevel=2,
)
if self._resolved_model is not None and self._resolved_model != result.model:
warnings.warn(
f"Resolved model mismatch: {self._resolved_model} != {result.model}. "
"Model mapping in autogen_ext.models.openai may be incorrect. "
f"Set the model to {result.model} to enhance token/cost estimation and suppress this warning.",
stacklevel=2,
)

# Limited to a single choice currently.
choice: Union[ParsedChoice[Any], ParsedChoice[BaseModel], Choice] = result.choices[0]

# Detect whether it is a function call or not.
# We don't rely on choice.finish_reason as it is not always accurate, depending on the API used.
content: Union[str, List[FunctionCall]]
thought: str | None = None
if choice.message.function_call is not None:
Expand All @@ -745,20 +746,16 @@ async def create(
if choice.finish_reason != "tool_calls":
warnings.warn(
f"Finish reason mismatch: {choice.finish_reason} != tool_calls "
"when tool_calls are present. Finish reason may not be accurate. "
"This may be due to the API used that is not returning the correct finish reason.",
"when tool_calls are present. Finish reason may not be accurate.",
stacklevel=2,
)
if choice.message.content is not None and choice.message.content != "":
# Put the content in the thought field.
if choice.message.content:
thought = choice.message.content
# NOTE: If OAI response type changes, this will need to be updated
content = []
for tool_call in choice.message.tool_calls:
if not isinstance(tool_call.function.arguments, str):
warnings.warn(
f"Tool call function arguments field is not a string: {tool_call.function.arguments}."
"This is unexpected and may due to the API used not returning the correct type. "
f"Tool call function arguments field is not a string: {tool_call.function.arguments}. "
"Attempting to convert it to string.",
stacklevel=2,
)
Expand All @@ -773,11 +770,9 @@ async def create(
)
finish_reason = "tool_calls"
else:
# if not tool_calls, then it is a text response and we populate the content and thought fields.
finish_reason = choice.finish_reason
content = choice.message.content or ""
# if there is a reasoning_content field, then we populate the thought field. This is for models such as R1 - direct from deepseek api.
if choice.message.model_extra is not None:
if choice.message.model_extra:
reasoning_content = choice.message.model_extra.get("reasoning_content")
if reasoning_content is not None:
thought = reasoning_content
Expand All @@ -794,7 +789,6 @@ async def create(
for x in choice.logprobs.content
]

# This is for local R1 models.
if isinstance(content, str) and self._model_info["family"] == ModelFamily.R1 and thought is None:
thought, content = parse_r1_content(content)

Expand All @@ -810,9 +804,13 @@ async def create(
self._total_usage = _add_usage(self._total_usage, usage)
self._actual_usage = _add_usage(self._actual_usage, usage)

# TODO - why is this cast needed?
return response






async def create_stream(
self,
messages: Sequence[LLMMessage],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from autogen_ext.models.openai._transformation import TransformerMap, get_transformer
from autogen_ext.models.openai._transformation.registry import _find_model_family # pyright: ignore[reportPrivateUsage]
from openai.lib.streaming.chat import AsyncChatCompletionStreamManager

from openai.resources.chat.completions import AsyncCompletions
from openai.types.chat.chat_completion import ChatCompletion, Choice
from openai.types.chat.chat_completion_chunk import (
Expand Down
10 changes: 10 additions & 0 deletions python/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
[project]
name = "autogen"
version = "0.1.0"
description = "AutoGen Python SDK"
requires-python = ">=3.10"
dependencies = []


[dependency-groups]
dev = [
"pyright==1.1.389",
Expand Down Expand Up @@ -145,3 +153,5 @@ cmd = "python -m grpc_tools.protoc --python_out=./packages/autogen-ext/tests/pro
markers = [
"grpc: tests invoking gRPC functionality",
]
[tool.setuptools.packages.find]
where = ["packages"]
29 changes: 29 additions & 0 deletions test_sanitizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import sys
import os

# Make the in-repo autogen_core package importable without installation.
sys.path.append(os.path.abspath("python/packages/autogen-core/src"))

from autogen_core.utils.sanitizer import sanitize_tool_calls


def test_valid():
    """A well-formed message must pass through unchanged."""
    message = {"content": "All good", "tool_calls": [{"name": "tool1"}]}
    assert sanitize_tool_calls(message) == message
    print("Valid message passed.")


def test_empty_tool_calls():
    """An explicitly empty tool_calls list must be rejected."""
    try:
        sanitize_tool_calls({"content": "nothing", "tool_calls": []})
    except ValueError as e:
        print(f"Caught expected error: {e}")
    else:
        # Previously the test silently passed when no error was raised,
        # so it asserted nothing. Fail loudly instead.
        raise AssertionError("Expected ValueError for empty tool_calls list.")


def test_tool_call_end_only():
    """Content consisting only of the end marker must be rejected."""
    try:
        sanitize_tool_calls({"content": "<|tool_call_end|>"})
    except ValueError as e:
        print(f"Caught expected error: {e}")
    else:
        raise AssertionError("Expected ValueError for content-only tool call marker.")


if __name__ == "__main__":
    test_valid()
    test_empty_tool_calls()
    test_tool_call_end_only()