From d93b57dae2babb7c598baabde2212bf9d2796ede Mon Sep 17 00:00:00 2001 From: maxpetrusenkoagent <[REDACTED EMAIL]> Date: Sun, 28 Jun 2026 00:06:05 -0400 Subject: [PATCH] docs: clarify inter-turn input filtering --- docs/running_agents.md | 46 +++++++++++++ tests/test_call_model_input_filter.py | 94 ++++++++++++++++++++++++++- 2 files changed, 138 insertions(+), 2 deletions(-) diff --git a/docs/running_agents.md b/docs/running_agents.md index d72d47a2ce..3d15953ba0 100644 --- a/docs/running_agents.md +++ b/docs/running_agents.md @@ -461,6 +461,52 @@ If you are using OpenAI server-managed conversation state with `conversation_id` Set the hook per run via `run_config` to redact sensitive data, trim long histories, or inject additional system guidance. +#### Injecting input between tool turns + +Because `call_model_input_filter` runs before every model call, it can also reconcile host-side state that changed while tools were running. A common pattern is to keep a small queue in your local context, let tools or background application code append new user messages to that queue, and drain it in the filter before the next model turn. + +```python +from dataclasses import dataclass, field + +from agents import Agent, RunConfig, RunContextWrapper, Runner, function_tool +from agents.run import CallModelData, ModelInputData + + +@dataclass +class AppContext: + queued_user_messages: list[str] = field(default_factory=list) + + +@function_tool +def long_running_tool(wrapper: RunContextWrapper[AppContext]) -> str: + wrapper.context.queued_user_messages.append("The user added a constraint.") + return "tool result" + + +def inject_queued_messages(data: CallModelData[AppContext]) -> ModelInputData: + input_items = list(data.model_data.input) + if data.context is not None: + input_items.extend( + {"role": "user", "content": message} + for message in data.context.queued_user_messages + ) + data.context.queued_user_messages.clear() + return ModelInputData(input=input_items, instructions=data.model_data.instructions) + + +context = AppContext() +agent = Agent(name="Assistant", tools=[long_running_tool]) + +result = await Runner.run( + agent, + "Use the tool, then continue with any queued user updates.", + context=context, + run_config=RunConfig(call_model_input_filter=inject_queued_messages), +) +``` + +The filter changes the model-facing payload for that turn; it does not mutate `RunResult.input`, automatically persist injected items to a session, or pause a run by itself. If you need durable conversation storage, also write the accepted queued messages to your own session or store. For streaming runs, keep consuming `stream_events()` until completion so the next model turn and any session writes finish before you inspect the result. + ## Errors and recovery ### Error handlers diff --git a/tests/test_call_model_input_filter.py b/tests/test_call_model_input_filter.py index f0239089c6..4e76b857e1 100644 --- a/tests/test_call_model_input_filter.py +++ b/tests/test_call_model_input_filter.py @@ -1,14 +1,46 @@ from __future__ import annotations +from dataclasses import dataclass, field from typing import Any, cast import pytest -from agents import Agent, RunConfig, Runner, TResponseInputItem, UserError +from agents import Agent, RunConfig, RunContextWrapper, Runner, TResponseInputItem, UserError from agents.run import CallModelData, ModelInputData +from agents.tool import function_tool from .fake_model import FakeModel -from .test_responses import get_text_input_item, get_text_message +from .test_responses import get_function_tool_call, get_text_input_item, get_text_message + + +@dataclass +class ExternalEventContext: + queued_user_messages: list[str] = field(default_factory=list) + + +EXTERNAL_MESSAGE = "The user added a new constraint while the tool was running." + + +@function_tool +def collect_external_update(wrapper: RunContextWrapper[ExternalEventContext]) -> str: + wrapper.context.queued_user_messages.append(EXTERNAL_MESSAGE) + return "tool-result" + + +def inject_queued_messages(data: CallModelData[ExternalEventContext]) -> ModelInputData: + input_items = list(data.model_data.input) + if data.context is not None: + input_items.extend( + get_text_input_item(message) for message in data.context.queued_user_messages + ) + data.context.queued_user_messages.clear() + return ModelInputData(input=input_items, instructions=data.model_data.instructions) + + +def assert_external_message_injected(input_items: list[TResponseInputItem]) -> None: + last_item = cast(dict[str, Any], input_items[-1]) + assert last_item["content"] == EXTERNAL_MESSAGE + assert any(item.get("type") == "function_call_output" for item in input_items) @pytest.mark.asyncio @@ -63,6 +95,64 @@ async def filter_fn(data: CallModelData[Any]) -> ModelInputData: assert model.last_turn_args["input"][-1]["content"] == "added-async" +@pytest.mark.asyncio +async def test_call_model_input_filter_injects_external_input_between_tool_turns() -> None: + model = FakeModel() + model.add_multiple_turn_outputs( + [ + [get_function_tool_call("collect_external_update")], + [get_text_message("done")], + ] + ) + agent = Agent[ExternalEventContext]( + name="test", + model=model, + tools=[collect_external_update], + ) + context = ExternalEventContext() + + await Runner.run( + agent, + input="start", + context=context, + run_config=RunConfig(call_model_input_filter=inject_queued_messages), + ) + + assert isinstance(model.last_turn_args["input"], list) + assert_external_message_injected(model.last_turn_args["input"]) + assert context.queued_user_messages == [] + + +@pytest.mark.asyncio +async def test_call_model_input_filter_injects_external_input_between_streamed_tool_turns() -> None: + model = FakeModel() + model.add_multiple_turn_outputs( + [ + [get_function_tool_call("collect_external_update")], + [get_text_message("done")], + ] + ) + agent = Agent[ExternalEventContext]( + name="test", + model=model, + tools=[collect_external_update], + ) + context = ExternalEventContext() + + result = Runner.run_streamed( + agent, + input="start", + context=context, + run_config=RunConfig(call_model_input_filter=inject_queued_messages), + ) + async for _ in result.stream_events(): + pass + + assert isinstance(model.last_turn_args["input"], list) + assert_external_message_injected(model.last_turn_args["input"]) + assert context.queued_user_messages == [] + + @pytest.mark.asyncio async def test_call_model_input_filter_invalid_return_type_raises() -> None: model = FakeModel()