Skip to content

Commit 23c2b92

Browse files
committed
x
1 parent b88115f commit 23c2b92

File tree

2 files changed

+50
-5
lines changed

2 files changed

+50
-5
lines changed

libs/core/langchain_core/tools/base.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -694,9 +694,7 @@ def _parse_input(
694694
f"args_schema must be a Pydantic BaseModel, got {self.args_schema}"
695695
)
696696
raise NotImplementedError(msg)
697-
return {
698-
k: getattr(result, k) for k, v in result_dict.items() if k in tool_input
699-
}
697+
return {k: getattr(result, k) for k, v in result_dict.items()}
700698
return tool_input
701699

702700
@model_validator(mode="before")

libs/core/tests/unit_tests/test_tools.py

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
)
2222

2323
import pytest
24-
from pydantic import BaseModel, ConfigDict, Field, ValidationError
24+
from pydantic import BaseModel, ConfigDict, Field, ValidationError, model_validator
2525
from pydantic.v1 import BaseModel as BaseModelV1
2626
from pydantic.v1 import ValidationError as ValidationErrorV1
2727
from typing_extensions import TypedDict, override
@@ -643,7 +643,7 @@ def test_named_tool_decorator_return_direct() -> None:
643643
"""Test functionality when arguments and return direct are provided as input."""
644644

645645
@tool("search", return_direct=True)
646-
def search_api(query: str, *args: Any) -> str:
646+
def search_api(query: str) -> str:
647647
"""Search the API for the query."""
648648
return "API result"
649649

@@ -2766,3 +2766,50 @@ def test_tool(
27662766
"type": "array",
27672767
}
27682768
}
2769+
2770+
2771+
def test_tool_args_schema_with_pydantic_validator() -> None:
2772+
"""Test that Pydantic model validators can transform input structure.
2773+
2774+
This test verifies that when a Pydantic validator wraps input in a nested
2775+
structure, the tool correctly processes the transformed input rather than
2776+
filtering it back to only the original input keys.
2777+
2778+
Before the fix, the tool would filter the result to only include keys from
2779+
the original tool_input, which broke validators that added structure.
2780+
"""
2781+
2782+
class InnerModel(BaseModel):
2783+
query: str
2784+
count: int = 10
2785+
2786+
class OuterModel(BaseModel):
2787+
x: InnerModel
2788+
2789+
@model_validator(mode="before")
2790+
@classmethod
2791+
def wrap_in_x(cls, data: Any) -> Any:
2792+
"""Wrap flat input in nested 'x' structure if not already wrapped."""
2793+
if isinstance(data, dict) and "x" not in data:
2794+
return {"x": data}
2795+
return data
2796+
2797+
@tool(args_schema=OuterModel)
2798+
def search_with_nested_schema(x: InnerModel) -> str:
2799+
"""Search with nested input schema and validator transformation."""
2800+
return f"Searched for '{x.query}' with count {x.count}"
2801+
2802+
# Test 1: Direct nested input
2803+
nested_input = {"x": {"query": "test", "count": 5}}
2804+
result1 = search_with_nested_schema.invoke(nested_input)
2805+
assert result1 == "Searched for 'test' with count 5"
2806+
2807+
# Test 2: Flat input that gets wrapped by validator
2808+
flat_input = {"query": "test query", "count": 3}
2809+
result2 = search_with_nested_schema.invoke(flat_input)
2810+
assert result2 == "Searched for 'test query' with count 3"
2811+
2812+
# Test 3: Flat input with default values
2813+
minimal_input = {"query": "minimal test"}
2814+
result3 = search_with_nested_schema.invoke(minimal_input)
2815+
assert result3 == "Searched for 'minimal test' with count 10"

0 commit comments

Comments
 (0)