diff --git a/src/agents/tool.py b/src/agents/tool.py index c8563e2e1b..cc38c082f8 100644 --- a/src/agents/tool.py +++ b/src/agents/tool.py @@ -552,6 +552,12 @@ def __agents_bind_function_tool__( return bound_invoker async def __call__(self, ctx: ToolContext[Any], input: str) -> Any: + if not isinstance(ctx, RunContextWrapper): + raise TypeError( + f"on_invoke_tool requires a ToolContext, got {type(ctx).__name__}. " + "Construct one with ToolContext(context=..., tool_name=..., " + "tool_call_id=..., tool_arguments=...) or invoke the tool through Runner." + ) try: return await self._invoke_tool_impl(ctx, input) except Exception as e: diff --git a/tests/test_function_tool.py b/tests/test_function_tool.py index 60ae2558cc..476cf0d980 100644 --- a/tests/test_function_tool.py +++ b/tests/test_function_tool.py @@ -34,6 +34,7 @@ ) from agents.tool import default_tool_error_function from agents.tool_context import ToolContext +from openai.types.responses import ResponseFunctionToolCall def argless_function() -> str: @@ -1063,3 +1064,35 @@ def test_function_tool_timeout_error_function_must_be_callable() -> None: on_invoke_tool=_noop_on_invoke_tool, timeout_error_function=cast(Any, "not-callable"), ) + + +async def test_on_invoke_tool_rejects_non_tool_context() -> None: + """Calling on_invoke_tool with a non-context value should fail fast and clearly.""" + + @function_tool + def add(a: int, b: int) -> int: + """Add two numbers.""" + return a + b + + with pytest.raises(TypeError, match="on_invoke_tool requires a ToolContext"): + await add.on_invoke_tool(cast(Any, None), '{"a": 1, "b": 2}') + + with pytest.raises(TypeError, match="on_invoke_tool requires a ToolContext"): + await add.on_invoke_tool(cast(Any, "not a context"), '{"a": 1, "b": 2}') + + # A valid ToolContext should still work. + tool_call = ResponseFunctionToolCall( + type="function_call", + name="add", + call_id="call-add", + arguments='{"a": 1, "b": 2}', + ) + tool_context = ToolContext( + context=None, + tool_name="add", + tool_call_id="call-add", + tool_arguments=tool_call.arguments, + tool_call=tool_call, + ) + result = await add.on_invoke_tool(tool_context, tool_call.arguments) + assert result == 3