diff --git a/docs/running_agents.md b/docs/running_agents.md index d72d47a2ce..5b2a2cbdf4 100644 --- a/docs/running_agents.md +++ b/docs/running_agents.md @@ -432,9 +432,9 @@ settings so the resumed turn continues in the same server-managed conversation. ### Call model input filter -Use `call_model_input_filter` to edit the model input right before the model call. The hook receives the current agent, context, and the combined input items (including session history when present) and returns a new `ModelInputData`. +Use `call_model_input_filter` to edit the model input right before the model call. The hook receives the current agent, context, and the combined input items (including session history when present) and returns a new `ModelInputData`. The payload also includes the effective output schema for the current call, so a filter can inspect or replace the structured-output schema sent to the model. -The return value must be a [`ModelInputData`][agents.run.ModelInputData] object. Its `input` field is required and must be a list of input items. Returning any other shape raises a `UserError`. +The return value must be a [`ModelInputData`][agents.run.ModelInputData] object. Its `input` field is required and must be a list of input items. Returning any other shape raises a `UserError`. If the returned `output_schema` is `None`, the current schema is preserved. Return another output schema object to override the schema for that model call, or a plain-text schema such as `AgentOutputSchema(str)` to switch a structured-output call back to plain text. ```python from agents import Agent, Runner, RunConfig diff --git a/src/agents/extensions/tool_output_trimmer.py b/src/agents/extensions/tool_output_trimmer.py index 26b307f14f..f9d2b055ce 100644 --- a/src/agents/extensions/tool_output_trimmer.py +++ b/src/agents/extensions/tool_output_trimmer.py @@ -152,7 +152,11 @@ def __call__(self, data: CallModelData[Any]) -> ModelInputData: f"saved ~{chars_saved} chars" ) - return _ModelInputData(input=new_items, instructions=model_data.instructions) + return _ModelInputData( + input=new_items, + instructions=model_data.instructions, + output_schema=model_data.output_schema, + ) def _find_recent_boundary(self, items: list[Any]) -> int: """Find the index separating 'old' items from 'recent' items. diff --git a/src/agents/run_config.py b/src/agents/run_config.py index 45dcca5b10..d367b206e1 100644 --- a/src/agents/run_config.py +++ b/src/agents/run_config.py @@ -22,6 +22,7 @@ if TYPE_CHECKING: from .agent import Agent + from .agent_output import AgentOutputSchemaBase from .run_context import RunContextWrapper from .sandbox.manifest import Manifest from .sandbox.session.base_sandbox_session import BaseSandboxSession @@ -50,6 +51,7 @@ class ModelInputData: input: list[TResponseInputItem] instructions: str | None + output_schema: AgentOutputSchemaBase | None = None @dataclass diff --git a/src/agents/run_internal/run_loop.py b/src/agents/run_internal/run_loop.py index 45f09c0fa0..578c13e3c1 100644 --- a/src/agents/run_internal/run_loop.py +++ b/src/agents/run_internal/run_loop.py @@ -1373,7 +1373,9 @@ def _tool_search_fingerprint(raw_item: Any) -> str: context_wrapper=context_wrapper, input_items=input, system_instructions=system_prompt, + output_schema=output_schema, ) + output_schema = filtered.output_schema if isinstance(filtered.input, list): filtered.input = deduplicate_input_items_preferring_latest(filtered.input) hosted_mcp_tool_metadata = collect_mcp_list_tools_metadata(streamed_result._model_input_items) @@ -1760,7 +1762,7 @@ async def run_single_turn( else: input = _prepare_turn_input_items(original_input, generated_items, reasoning_item_id_policy) - new_response = await get_new_response( + new_response, output_schema = await get_new_response( bindings, system_prompt, input, @@ -1811,7 +1813,7 @@ async def get_new_response( session: Session | None = None, session_items_to_rewind: list[TResponseInputItem] | None = None, prompt_cache_key_resolver: PromptCacheKeyResolver | None = None, -) -> ModelResponse: +) -> tuple[ModelResponse, AgentOutputSchemaBase | None]: """Call the model and return the raw response, handling retries and hooks.""" public_agent = bindings.public_agent execution_agent = bindings.execution_agent @@ -1821,7 +1823,9 @@ async def get_new_response( context_wrapper=context_wrapper, input_items=input, system_instructions=system_prompt, + output_schema=output_schema, ) + output_schema = filtered.output_schema if isinstance(filtered.input, list): filtered.input = deduplicate_input_items_preferring_latest(filtered.input) @@ -1917,4 +1921,4 @@ async def rewind_model_request() -> None: hooks.on_llm_end(context_wrapper, public_agent, new_response), ) - return new_response + return new_response, output_schema diff --git a/src/agents/run_internal/turn_preparation.py b/src/agents/run_internal/turn_preparation.py index 0a79ebd813..04ea4d3925 100644 --- a/src/agents/run_internal/turn_preparation.py +++ b/src/agents/run_internal/turn_preparation.py @@ -55,18 +55,24 @@ async def maybe_filter_model_input( context_wrapper: RunContextWrapper[TContext], input_items: list[TResponseInputItem], system_instructions: str | None, + output_schema: AgentOutputSchemaBase | None, ) -> ModelInputData: """Apply optional call_model_input_filter to modify model input.""" effective_instructions = system_instructions effective_input: list[TResponseInputItem] = input_items if run_config.call_model_input_filter is None: - return ModelInputData(input=effective_input, instructions=effective_instructions) + return ModelInputData( + input=effective_input, + instructions=effective_instructions, + output_schema=output_schema, + ) try: model_input = ModelInputData( input=effective_input.copy(), instructions=effective_instructions, + output_schema=output_schema, ) filter_payload: CallModelData[TContext] = CallModelData( model_data=model_input, @@ -77,6 +83,8 @@ async def maybe_filter_model_input( updated = await maybe_updated if inspect.isawaitable(maybe_updated) else maybe_updated if not isinstance(updated, ModelInputData): raise UserError("call_model_input_filter must return a ModelInputData instance") + if updated.output_schema is None: + updated.output_schema = output_schema return updated except Exception as e: _error_tracing.attach_error_to_current_span( diff --git a/tests/extensions/test_tool_output_trimmer.py b/tests/extensions/test_tool_output_trimmer.py index 04a0a70728..79ba6f58df 100644 --- a/tests/extensions/test_tool_output_trimmer.py +++ b/tests/extensions/test_tool_output_trimmer.py @@ -11,7 +11,9 @@ import pytest +from agents.agent_output import AgentOutputSchemaBase from agents.extensions.tool_output_trimmer import ToolOutputTrimmer +from agents.items import TResponseInputItem from agents.run_config import CallModelData, ModelInputData # --------------------------------------------------------------------------- @@ -43,6 +45,23 @@ def _make_data(items: list[Any]) -> CallModelData[Any]: return CallModelData(model_data=model_data, agent=MagicMock(), context=None) +class _Schema(AgentOutputSchemaBase): + def is_plain_text(self) -> bool: + return False + + def name(self) -> str: + return "Schema" + + def json_schema(self) -> dict[str, Any]: + return {"type": "object", "properties": {}} + + def is_strict_json_schema(self) -> bool: + return False + + def validate_json(self, json_str: str) -> Any: + return json.loads(json_str) + + def _output(result: ModelInputData, idx: int) -> Any: """Extract the ``output`` field from a result item (untyped for test convenience).""" item: Any = result.input[idx] @@ -182,6 +201,32 @@ def test_no_trimming_when_all_recent(self) -> None: result = trimmer(_make_data(items)) assert _output(result, 2) == large + def test_preserves_output_schema_when_trimming(self) -> None: + """The trimmer should not drop structured-output metadata from ModelInputData.""" + schema = _Schema() + items = [ + _user("q1"), + _func_call("c1", "search"), + _func_output("c1", "x" * 1000), + _assistant("a1"), + _user("q2"), + _assistant("a2"), + ] + trimmer = ToolOutputTrimmer(recent_turns=1) + data = CallModelData( + model_data=ModelInputData( + input=cast(list[TResponseInputItem], items), + instructions="You are helpful.", + output_schema=schema, + ), + agent=MagicMock(), + context=None, + ) + + result = trimmer(data) + + assert result.output_schema is schema + def test_trims_large_old_output(self) -> None: """Large output in an old turn should be trimmed.""" large = "x" * 1000 diff --git a/tests/test_agent_runner.py b/tests/test_agent_runner.py index 4b5ea867ce..b8a7d15ee2 100644 --- a/tests/test_agent_runner.py +++ b/tests/test_agent_runner.py @@ -2561,7 +2561,7 @@ async def test_conversation_lock_rewind_skips_when_no_snapshot() -> None: model.add_multiple_turn_outputs([locked_error, [get_text_message("ok")]]) agent = Agent(name="test", model=model) - result = await get_new_response( + result, output_schema = await get_new_response( bindings=bind_public_agent(agent), system_prompt=None, input=[history_item, new_item], @@ -2579,6 +2579,7 @@ async def test_conversation_lock_rewind_skips_when_no_snapshot() -> None: ) assert isinstance(result, ModelResponse) + assert output_schema is None assert session.pop_calls == 0 @@ -2606,7 +2607,7 @@ async def test_get_new_response_uses_agent_retry_settings() -> None: ), ) - result = await get_new_response( + result, output_schema = await get_new_response( bindings=bind_public_agent(agent), system_prompt=None, input=[get_text_input_item("hello")], @@ -2624,6 +2625,7 @@ async def test_get_new_response_uses_agent_retry_settings() -> None: ) assert isinstance(result, ModelResponse) + assert output_schema is None assert result.usage.requests == 2 diff --git a/tests/test_call_model_input_filter_unit.py b/tests/test_call_model_input_filter_unit.py index ff14fc2829..59107f82ea 100644 --- a/tests/test_call_model_input_filter_unit.py +++ b/tests/test_call_model_input_filter_unit.py @@ -1,5 +1,6 @@ from __future__ import annotations +import json import sys from pathlib import Path from typing import Any @@ -13,10 +14,41 @@ # Import directly from submodules to avoid heavy __init__ side effects from agents.agent import Agent +from agents.agent_output import AgentOutputSchema, AgentOutputSchemaBase from agents.exceptions import UserError from agents.run import CallModelData, ModelInputData, RunConfig, Runner +class _FilteredOutputSchema(AgentOutputSchemaBase): + def is_plain_text(self) -> bool: + return False + + def name(self) -> str: + return "FilteredOutput" + + def json_schema(self) -> dict[str, Any]: + return { + "type": "object", + "properties": {"value": {"type": "string"}}, + "required": ["value"], + "additionalProperties": False, + } + + def is_strict_json_schema(self) -> bool: + return True + + def validate_json(self, json_str: str) -> Any: + return {"parsed": json.loads(json_str)["value"]} + + +class _OriginalOutputSchema(_FilteredOutputSchema): + def name(self) -> str: + return "OriginalOutput" + + def validate_json(self, json_str: str) -> Any: + return {"original": json.loads(json_str)["value"]} + + @pytest.mark.asyncio async def test_call_model_input_filter_sync_non_streamed_unit() -> None: model = FakeModel() @@ -109,3 +141,220 @@ def invalid_filter(_data: CallModelData[Any]): input="start", run_config=RunConfig(call_model_input_filter=invalid_filter), ) + + +@pytest.mark.asyncio +async def test_call_model_input_filter_can_override_output_schema_non_streamed_unit() -> None: + model = FakeModel() + replacement_schema = _FilteredOutputSchema() + agent = Agent(name="test", model=model) + + model.set_next_output( + [ + ResponseOutputMessage( + id="1", + type="message", + role="assistant", + content=[ + ResponseOutputText( + text='{"value": "non-streamed"}', + type="output_text", + annotations=[], + logprobs=[], + ) + ], + status="completed", + ) + ] + ) + + def filter_fn(data: CallModelData[Any]) -> ModelInputData: + assert data.model_data.output_schema is None + return ModelInputData( + input=data.model_data.input, + instructions=data.model_data.instructions, + output_schema=replacement_schema, + ) + + result = await Runner.run( + agent, + input="start", + run_config=RunConfig(call_model_input_filter=filter_fn), + ) + + assert model.last_turn_args["output_schema"] is replacement_schema + assert result.final_output == {"parsed": "non-streamed"} + + +@pytest.mark.asyncio +async def test_call_model_input_filter_can_override_output_schema_streamed_unit() -> None: + model = FakeModel() + replacement_schema = _FilteredOutputSchema() + agent = Agent(name="test", model=model) + + model.set_next_output( + [ + ResponseOutputMessage( + id="1", + type="message", + role="assistant", + content=[ + ResponseOutputText( + text='{"value": "streamed"}', + type="output_text", + annotations=[], + logprobs=[], + ) + ], + status="completed", + ) + ] + ) + + def filter_fn(data: CallModelData[Any]) -> ModelInputData: + assert data.model_data.output_schema is None + return ModelInputData( + input=data.model_data.input, + instructions=data.model_data.instructions, + output_schema=replacement_schema, + ) + + result = Runner.run_streamed( + agent, + input="start", + run_config=RunConfig(call_model_input_filter=filter_fn), + ) + async for _ in result.stream_events(): + pass + + assert model.last_turn_args["output_schema"] is replacement_schema + assert result.final_output == {"parsed": "streamed"} + + +@pytest.mark.asyncio +async def test_call_model_input_filter_preserves_existing_output_schema_unit() -> None: + model = FakeModel() + original_schema = _OriginalOutputSchema() + agent = Agent(name="test", model=model, output_type=original_schema) + + model.set_next_output( + [ + ResponseOutputMessage( + id="1", + type="message", + role="assistant", + content=[ + ResponseOutputText( + text='{"value": "original"}', + type="output_text", + annotations=[], + logprobs=[], + ) + ], + status="completed", + ) + ] + ) + + def filter_fn(data: CallModelData[Any]) -> ModelInputData: + assert data.model_data.output_schema is original_schema + return ModelInputData( + input=data.model_data.input, instructions=data.model_data.instructions + ) + + result = await Runner.run( + agent, + input="start", + run_config=RunConfig(call_model_input_filter=filter_fn), + ) + + assert model.last_turn_args["output_schema"] is original_schema + assert result.final_output == {"original": "original"} + + +@pytest.mark.asyncio +async def test_call_model_input_filter_can_replace_existing_output_schema_unit() -> None: + model = FakeModel() + original_schema = _OriginalOutputSchema() + replacement_schema = _FilteredOutputSchema() + agent = Agent(name="test", model=model, output_type=original_schema) + + model.set_next_output( + [ + ResponseOutputMessage( + id="1", + type="message", + role="assistant", + content=[ + ResponseOutputText( + text='{"value": "replacement"}', + type="output_text", + annotations=[], + logprobs=[], + ) + ], + status="completed", + ) + ] + ) + + def filter_fn(data: CallModelData[Any]) -> ModelInputData: + assert data.model_data.output_schema is original_schema + return ModelInputData( + input=data.model_data.input, + instructions=data.model_data.instructions, + output_schema=replacement_schema, + ) + + result = await Runner.run( + agent, + input="start", + run_config=RunConfig(call_model_input_filter=filter_fn), + ) + + assert model.last_turn_args["output_schema"] is replacement_schema + assert result.final_output == {"parsed": "replacement"} + + +@pytest.mark.asyncio +async def test_call_model_input_filter_can_switch_existing_schema_to_plain_text_unit() -> None: + model = FakeModel() + original_schema = _OriginalOutputSchema() + plain_text_schema = AgentOutputSchema(str) + agent = Agent(name="test", model=model, output_type=original_schema) + + model.set_next_output( + [ + ResponseOutputMessage( + id="1", + type="message", + role="assistant", + content=[ + ResponseOutputText( + text="plain replacement", + type="output_text", + annotations=[], + logprobs=[], + ) + ], + status="completed", + ) + ] + ) + + def filter_fn(data: CallModelData[Any]) -> ModelInputData: + assert data.model_data.output_schema is original_schema + return ModelInputData( + input=data.model_data.input, + instructions=data.model_data.instructions, + output_schema=plain_text_schema, + ) + + result = await Runner.run( + agent, + input="start", + run_config=RunConfig(call_model_input_filter=filter_fn), + ) + + assert model.last_turn_args["output_schema"] is plain_text_schema + assert result.final_output == "plain replacement" diff --git a/tests/test_server_conversation_tracker.py b/tests/test_server_conversation_tracker.py index d7fff5a5ed..bd23a4f028 100644 --- a/tests/test_server_conversation_tracker.py +++ b/tests/test_server_conversation_tracker.py @@ -805,7 +805,7 @@ def _filter_input(payload: Any) -> ModelInputData: run_config = RunConfig(call_model_input_filter=_filter_input) - await get_new_response( + response, output_schema = await get_new_response( bind_public_agent(agent), None, [item_1, item_2], @@ -820,6 +820,8 @@ def _filter_input(payload: Any) -> ModelInputData: None, ) + assert isinstance(response, ModelResponse) + assert output_schema is None assert model.last_turn_args["input"] == [item_1] assert any(item is item_1 for item in tracker.sent_items) assert all(item is not item_2 for item in tracker.sent_items)