Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions docs/running_agents.md
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,52 @@ If you are using OpenAI server-managed conversation state with `conversation_id`

Set the hook per run via `run_config` to redact sensitive data, trim long histories, or inject additional system guidance.

#### Injecting input between tool turns

Because `call_model_input_filter` runs before every model call, it can also reconcile host-side state that changed while tools were running. A common pattern is to keep a small queue in your local context, let tools or background application code append new user messages to that queue, and drain it in the filter before the next model turn.

```python
from dataclasses import dataclass, field

from agents import Agent, RunConfig, RunContextWrapper, Runner, function_tool
from agents.run import CallModelData, ModelInputData


@dataclass
class AppContext:
    """Local run context holding user messages queued for the next model call."""

    # Each instance gets its own list (default_factory), so queues are never
    # shared between separate contexts.
    queued_user_messages: list[str] = field(default_factory=list)


@function_tool
def long_running_tool(wrapper: RunContextWrapper[AppContext]) -> str:
    # Stand-in for host-side state changing while the tool runs: queue a new
    # user message on the run context so the input filter can pick it up.
    wrapper.context.queued_user_messages.append("The user added a constraint.")
    return "tool result"


def inject_queued_messages(data: CallModelData[AppContext]) -> ModelInputData:
    """Append queued user messages to the model input and empty the queue.

    Args:
        data: The pending model call; ``data.model_data`` carries the input
            items and instructions, ``data.context`` the ``AppContext``.

    Returns:
        A ``ModelInputData`` with the original items followed by one user
        message per queued entry, and the instructions unchanged.
    """
    # Work on a copy so the list held by data.model_data is left untouched.
    input_items = list(data.model_data.input)
    if data.context is not None:
        input_items.extend(
            {"role": "user", "content": message}
            for message in data.context.queued_user_messages
        )
        # Empty the queue once drained so each message is injected only once.
        data.context.queued_user_messages.clear()
    return ModelInputData(input=input_items, instructions=data.model_data.instructions)


context = AppContext()
agent = Agent(name="Assistant", tools=[long_running_tool])


async def main() -> None:
    # The filter runs before every model call, so messages queued by the tool
    # are injected ahead of the turn that follows the tool call.
    result = await Runner.run(
        agent,
        "Use the tool, then continue with any queued user updates.",
        context=context,
        run_config=RunConfig(call_model_input_filter=inject_queued_messages),
    )
    print(result.final_output)


if __name__ == "__main__":
    # `await` is only valid inside a coroutine, so drive the run with an
    # event loop when the example is executed as a script.
    import asyncio

    asyncio.run(main())
```

The filter changes the model-facing payload for that turn; it does not mutate `RunResult.input`, automatically persist injected items to a session, or pause a run by itself. If you need durable conversation storage, also write the accepted queued messages to your own session or store. For streaming runs, keep consuming `stream_events()` until completion so the next model turn and any session writes finish before you inspect the result.

## Errors and recovery

### Error handlers
Expand Down
94 changes: 92 additions & 2 deletions tests/test_call_model_input_filter.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,46 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, cast

import pytest

from agents import Agent, RunConfig, Runner, TResponseInputItem, UserError
from agents import Agent, RunConfig, RunContextWrapper, Runner, TResponseInputItem, UserError
from agents.run import CallModelData, ModelInputData
from agents.tool import function_tool

from .fake_model import FakeModel
from .test_responses import get_text_input_item, get_text_message
from .test_responses import get_function_tool_call, get_text_input_item, get_text_message


@dataclass
class ExternalEventContext:
    """Run context used by the tests to queue user messages between model turns."""

    # Each instance gets its own list (default_factory), so tests never share
    # queued messages.
    queued_user_messages: list[str] = field(default_factory=list)


# Message that collect_external_update queues during the tool call and that
# assert_external_message_injected expects as the last model input item.
EXTERNAL_MESSAGE = "The user added a new constraint while the tool was running."


@function_tool
def collect_external_update(wrapper: RunContextWrapper[ExternalEventContext]) -> str:
    # Queue EXTERNAL_MESSAGE on the run context while the tool executes, so
    # the input filter has something to inject before the next model call.
    wrapper.context.queued_user_messages.append(EXTERNAL_MESSAGE)
    return "tool-result"


def inject_queued_messages(data: CallModelData[ExternalEventContext]) -> ModelInputData:
    """Append queued messages to the model input as text items and empty the queue.

    Returns a ``ModelInputData`` holding the original input items followed by
    one ``get_text_input_item`` entry per queued message, with the
    instructions passed through unchanged.
    """
    # Work on a copy so the list held by data.model_data is left untouched.
    input_items = list(data.model_data.input)
    if data.context is not None:
        input_items.extend(
            get_text_input_item(message) for message in data.context.queued_user_messages
        )
        # Empty the queue once drained so each message is injected only once.
        data.context.queued_user_messages.clear()
    return ModelInputData(input=input_items, instructions=data.model_data.instructions)


def assert_external_message_injected(input_items: list[TResponseInputItem]) -> None:
    """Assert that EXTERNAL_MESSAGE is the last input item and a tool output is present.

    Args:
        input_items: The input list that was passed to the model for a turn.
    """
    # Cast to a plain dict so the "content" key can be read regardless of
    # which input-item variant the last entry is.
    last_item = cast(dict[str, Any], input_items[-1])
    assert last_item["content"] == EXTERNAL_MESSAGE
    # The tool's output must also be in the input for this turn.
    assert any(item.get("type") == "function_call_output" for item in input_items)


@pytest.mark.asyncio
Expand Down Expand Up @@ -63,6 +95,64 @@ async def filter_fn(data: CallModelData[Any]) -> ModelInputData:
assert model.last_turn_args["input"][-1]["content"] == "added-async"


@pytest.mark.asyncio
async def test_call_model_input_filter_injects_external_input_between_tool_turns() -> None:
    """A message queued by a tool is injected into the following model call."""
    model = FakeModel()
    # Turn 1 calls the tool (which queues the message); turn 2 ends the run.
    model.add_multiple_turn_outputs(
        [
            [get_function_tool_call("collect_external_update")],
            [get_text_message("done")],
        ]
    )
    agent = Agent[ExternalEventContext](
        name="test",
        model=model,
        tools=[collect_external_update],
    )
    context = ExternalEventContext()

    await Runner.run(
        agent,
        input="start",
        context=context,
        run_config=RunConfig(call_model_input_filter=inject_queued_messages),
    )

    assert isinstance(model.last_turn_args["input"], list)
    assert_external_message_injected(model.last_turn_args["input"])
    # The filter must have drained the queue.
    assert context.queued_user_messages == []


@pytest.mark.asyncio
async def test_call_model_input_filter_injects_external_input_between_streamed_tool_turns() -> None:
    """A message queued by a tool is injected into the following streamed model call."""
    model = FakeModel()
    # Turn 1 calls the tool (which queues the message); turn 2 ends the run.
    model.add_multiple_turn_outputs(
        [
            [get_function_tool_call("collect_external_update")],
            [get_text_message("done")],
        ]
    )
    agent = Agent[ExternalEventContext](
        name="test",
        model=model,
        tools=[collect_external_update],
    )
    context = ExternalEventContext()

    result = Runner.run_streamed(
        agent,
        input="start",
        context=context,
        run_config=RunConfig(call_model_input_filter=inject_queued_messages),
    )
    # Consume every event before inspecting what the model received.
    async for _ in result.stream_events():
        pass

    assert isinstance(model.last_turn_args["input"], list)
    assert_external_message_injected(model.last_turn_args["input"])
    # The filter must have drained the queue.
    assert context.queued_user_messages == []


@pytest.mark.asyncio
async def test_call_model_input_filter_invalid_return_type_raises() -> None:
model = FakeModel()
Expand Down