|
21 | 21 | )
|
22 | 22 |
|
23 | 23 | import pytest
|
24 |
| -from pydantic import BaseModel, ConfigDict, Field, ValidationError |
| 24 | +from pydantic import BaseModel, ConfigDict, Field, ValidationError, model_validator |
25 | 25 | from pydantic.v1 import BaseModel as BaseModelV1
|
26 | 26 | from pydantic.v1 import ValidationError as ValidationErrorV1
|
27 | 27 | from typing_extensions import TypedDict, override
|
@@ -643,7 +643,7 @@ def test_named_tool_decorator_return_direct() -> None:
|
643 | 643 | """Test functionality when arguments and return direct are provided as input."""
|
644 | 644 |
|
645 | 645 | @tool("search", return_direct=True)
|
646 |
| - def search_api(query: str, *args: Any) -> str: |
| 646 | + def search_api(query: str) -> str: |
647 | 647 | """Search the API for the query."""
|
648 | 648 | return "API result"
|
649 | 649 |
|
@@ -2766,3 +2766,50 @@ def test_tool(
|
2766 | 2766 | "type": "array",
|
2767 | 2767 | }
|
2768 | 2768 | }
|
| 2769 | + |
| 2770 | + |
| 2771 | +def test_tool_args_schema_with_pydantic_validator() -> None: |
| 2772 | + """Test that Pydantic model validators can transform input structure. |
| 2773 | +
|
| 2774 | + This test verifies that when a Pydantic validator wraps input in a nested |
| 2775 | + structure, the tool correctly processes the transformed input rather than |
| 2776 | + filtering it back to only the original input keys. |
| 2777 | +
|
| 2778 | + Before the fix, the tool would filter the result to only include keys from |
| 2779 | + the original tool_input, which broke validators that added structure. |
| 2780 | + """ |
| 2781 | + |
| 2782 | + class InnerModel(BaseModel): |
| 2783 | + query: str |
| 2784 | + count: int = 10 |
| 2785 | + |
| 2786 | + class OuterModel(BaseModel): |
| 2787 | + x: InnerModel |
| 2788 | + |
| 2789 | + @model_validator(mode="before") |
| 2790 | + @classmethod |
| 2791 | + def wrap_in_x(cls, data: Any) -> Any: |
| 2792 | + """Wrap flat input in nested 'x' structure if not already wrapped.""" |
| 2793 | + if isinstance(data, dict) and "x" not in data: |
| 2794 | + return {"x": data} |
| 2795 | + return data |
| 2796 | + |
| 2797 | + @tool(args_schema=OuterModel) |
| 2798 | + def search_with_nested_schema(x: InnerModel) -> str: |
| 2799 | + """Search with nested input schema and validator transformation.""" |
| 2800 | + return f"Searched for '{x.query}' with count {x.count}" |
| 2801 | + |
| 2802 | + # Test 1: Direct nested input |
| 2803 | + nested_input = {"x": {"query": "test", "count": 5}} |
| 2804 | + result1 = search_with_nested_schema.invoke(nested_input) |
| 2805 | + assert result1 == "Searched for 'test' with count 5" |
| 2806 | + |
| 2807 | + # Test 2: Flat input that gets wrapped by validator |
| 2808 | + flat_input = {"query": "test query", "count": 3} |
| 2809 | + result2 = search_with_nested_schema.invoke(flat_input) |
| 2810 | + assert result2 == "Searched for 'test query' with count 3" |
| 2811 | + |
| 2812 | + # Test 3: Flat input with default values |
| 2813 | + minimal_input = {"query": "minimal test"} |
| 2814 | + result3 = search_with_nested_schema.invoke(minimal_input) |
| 2815 | + assert result3 == "Searched for 'minimal test' with count 10" |
0 commit comments