Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/running_agents.md
Original file line number Diff line number Diff line change
Expand Up @@ -432,9 +432,9 @@ settings so the resumed turn continues in the same server-managed conversation.

### Call model input filter

Use `call_model_input_filter` to edit the model input right before the model call. The hook receives the current agent, context, and the combined input items (including session history when present) and returns a new `ModelInputData`.
Use `call_model_input_filter` to edit the model input right before the model call. The hook receives the current agent, context, and the combined input items (including session history when present) and returns a new `ModelInputData`. The payload also includes the effective output schema for the current call, so a filter can inspect or replace the structured-output schema sent to the model.

The return value must be a [`ModelInputData`][agents.run.ModelInputData] object. Its `input` field is required and must be a list of input items. Returning any other shape raises a `UserError`.
The return value must be a [`ModelInputData`][agents.run.ModelInputData] object. Its `input` field is required and must be a list of input items. Returning any other shape raises a `UserError`. If the returned `output_schema` is `None`, the current schema is preserved. Return another output schema object to override the schema for that model call, or a plain-text schema such as `AgentOutputSchema(str)` to switch a structured-output call back to plain text.

```python
from agents import Agent, Runner, RunConfig
Expand Down
6 changes: 5 additions & 1 deletion src/agents/extensions/tool_output_trimmer.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,11 @@ def __call__(self, data: CallModelData[Any]) -> ModelInputData:
f"saved ~{chars_saved} chars"
)

return _ModelInputData(input=new_items, instructions=model_data.instructions)
return _ModelInputData(
input=new_items,
instructions=model_data.instructions,
output_schema=model_data.output_schema,
)

def _find_recent_boundary(self, items: list[Any]) -> int:
"""Find the index separating 'old' items from 'recent' items.
Expand Down
2 changes: 2 additions & 0 deletions src/agents/run_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

if TYPE_CHECKING:
from .agent import Agent
from .agent_output import AgentOutputSchemaBase
from .run_context import RunContextWrapper
from .sandbox.manifest import Manifest
from .sandbox.session.base_sandbox_session import BaseSandboxSession
Expand Down Expand Up @@ -50,6 +51,7 @@ class ModelInputData:

input: list[TResponseInputItem]
instructions: str | None
output_schema: AgentOutputSchemaBase | None = None


@dataclass
Expand Down
10 changes: 7 additions & 3 deletions src/agents/run_internal/run_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -1373,7 +1373,9 @@ def _tool_search_fingerprint(raw_item: Any) -> str:
context_wrapper=context_wrapper,
input_items=input,
system_instructions=system_prompt,
output_schema=output_schema,
)
output_schema = filtered.output_schema
if isinstance(filtered.input, list):
filtered.input = deduplicate_input_items_preferring_latest(filtered.input)
hosted_mcp_tool_metadata = collect_mcp_list_tools_metadata(streamed_result._model_input_items)
Expand Down Expand Up @@ -1760,7 +1762,7 @@ async def run_single_turn(
else:
input = _prepare_turn_input_items(original_input, generated_items, reasoning_item_id_policy)

new_response = await get_new_response(
new_response, output_schema = await get_new_response(
bindings,
system_prompt,
input,
Expand Down Expand Up @@ -1811,7 +1813,7 @@ async def get_new_response(
session: Session | None = None,
session_items_to_rewind: list[TResponseInputItem] | None = None,
prompt_cache_key_resolver: PromptCacheKeyResolver | None = None,
) -> ModelResponse:
) -> tuple[ModelResponse, AgentOutputSchemaBase | None]:
"""Call the model and return the raw response, handling retries and hooks."""
public_agent = bindings.public_agent
execution_agent = bindings.execution_agent
Expand All @@ -1821,7 +1823,9 @@ async def get_new_response(
context_wrapper=context_wrapper,
input_items=input,
system_instructions=system_prompt,
output_schema=output_schema,
)
output_schema = filtered.output_schema
if isinstance(filtered.input, list):
filtered.input = deduplicate_input_items_preferring_latest(filtered.input)

Expand Down Expand Up @@ -1917,4 +1921,4 @@ async def rewind_model_request() -> None:
hooks.on_llm_end(context_wrapper, public_agent, new_response),
)

return new_response
return new_response, output_schema
10 changes: 9 additions & 1 deletion src/agents/run_internal/turn_preparation.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,18 +55,24 @@ async def maybe_filter_model_input(
context_wrapper: RunContextWrapper[TContext],
input_items: list[TResponseInputItem],
system_instructions: str | None,
output_schema: AgentOutputSchemaBase | None,
) -> ModelInputData:
"""Apply optional call_model_input_filter to modify model input."""
effective_instructions = system_instructions
effective_input: list[TResponseInputItem] = input_items

if run_config.call_model_input_filter is None:
return ModelInputData(input=effective_input, instructions=effective_instructions)
return ModelInputData(
input=effective_input,
instructions=effective_instructions,
output_schema=output_schema,
)

try:
model_input = ModelInputData(
input=effective_input.copy(),
instructions=effective_instructions,
output_schema=output_schema,
)
filter_payload: CallModelData[TContext] = CallModelData(
model_data=model_input,
Expand All @@ -77,6 +83,8 @@ async def maybe_filter_model_input(
updated = await maybe_updated if inspect.isawaitable(maybe_updated) else maybe_updated
if not isinstance(updated, ModelInputData):
raise UserError("call_model_input_filter must return a ModelInputData instance")
if updated.output_schema is None:
updated.output_schema = output_schema
return updated
except Exception as e:
_error_tracing.attach_error_to_current_span(
Expand Down
45 changes: 45 additions & 0 deletions tests/extensions/test_tool_output_trimmer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@

import pytest

from agents.agent_output import AgentOutputSchemaBase
from agents.extensions.tool_output_trimmer import ToolOutputTrimmer
from agents.items import TResponseInputItem
from agents.run_config import CallModelData, ModelInputData

# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -43,6 +45,23 @@ def _make_data(items: list[Any]) -> CallModelData[Any]:
return CallModelData(model_data=model_data, agent=MagicMock(), context=None)


class _Schema(AgentOutputSchemaBase):
    """Bare-bones structured (non-plain-text) output schema stub for trimmer tests."""

    def is_plain_text(self) -> bool:
        """Report that this schema describes structured output, not plain text."""
        return False

    def name(self) -> str:
        """Return the fixed display name of this stub schema."""
        return "Schema"

    def json_schema(self) -> dict[str, Any]:
        """Return an object schema that declares no properties."""
        empty_object_schema: dict[str, Any] = {"type": "object", "properties": {}}
        return empty_object_schema

    def is_strict_json_schema(self) -> bool:
        """Report that the schema is not a strict JSON schema."""
        return False

    def validate_json(self, json_str: str) -> Any:
        """Parse ``json_str`` with ``json.loads`` and return the decoded value."""
        decoded: Any = json.loads(json_str)
        return decoded


def _output(result: ModelInputData, idx: int) -> Any:
"""Extract the ``output`` field from a result item (untyped for test convenience)."""
item: Any = result.input[idx]
Expand Down Expand Up @@ -182,6 +201,32 @@ def test_no_trimming_when_all_recent(self) -> None:
result = trimmer(_make_data(items))
assert _output(result, 2) == large

def test_preserves_output_schema_when_trimming(self) -> None:
    """Structured-output metadata on ModelInputData must survive a trimmer pass."""
    expected_schema = _Schema()

    # First exchange carries a 1000-character tool output; a second exchange follows it.
    first_exchange = [
        _user("q1"),
        _func_call("c1", "search"),
        _func_output("c1", "x" * 1000),
        _assistant("a1"),
    ]
    second_exchange = [_user("q2"), _assistant("a2")]

    model_data = ModelInputData(
        input=cast(list[TResponseInputItem], first_exchange + second_exchange),
        instructions="You are helpful.",
        output_schema=expected_schema,
    )
    payload = CallModelData(model_data=model_data, agent=MagicMock(), context=None)

    filtered = ToolOutputTrimmer(recent_turns=1)(payload)

    # Identity check: the very same schema object must be handed back.
    assert filtered.output_schema is expected_schema

def test_trims_large_old_output(self) -> None:
"""Large output in an old turn should be trimmed."""
large = "x" * 1000
Expand Down
6 changes: 4 additions & 2 deletions tests/test_agent_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2561,7 +2561,7 @@ async def test_conversation_lock_rewind_skips_when_no_snapshot() -> None:
model.add_multiple_turn_outputs([locked_error, [get_text_message("ok")]])
agent = Agent(name="test", model=model)

result = await get_new_response(
result, output_schema = await get_new_response(
bindings=bind_public_agent(agent),
system_prompt=None,
input=[history_item, new_item],
Expand All @@ -2579,6 +2579,7 @@ async def test_conversation_lock_rewind_skips_when_no_snapshot() -> None:
)

assert isinstance(result, ModelResponse)
assert output_schema is None
assert session.pop_calls == 0


Expand Down Expand Up @@ -2606,7 +2607,7 @@ async def test_get_new_response_uses_agent_retry_settings() -> None:
),
)

result = await get_new_response(
result, output_schema = await get_new_response(
bindings=bind_public_agent(agent),
system_prompt=None,
input=[get_text_input_item("hello")],
Expand All @@ -2624,6 +2625,7 @@ async def test_get_new_response_uses_agent_retry_settings() -> None:
)

assert isinstance(result, ModelResponse)
assert output_schema is None
assert result.usage.requests == 2


Expand Down
Loading