Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 56 additions & 10 deletions src/agents/models/openai_chatcompletions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

import asyncio
import inspect
import json
import time
from collections.abc import AsyncIterator
Expand Down Expand Up @@ -95,6 +97,31 @@ def _handle_unsupported_prompt(self, prompt: ResponsePromptParam | None) -> None
def get_retry_advice(self, request: ModelRetryAdviceRequest) -> ModelRetryAdvice | None:
    """Return retry advice for ``request``.

    This is a thin delegate: the request is handed to
    ``get_openai_retry_advice`` and whatever that helper returns
    (advice or ``None``) is passed back unchanged.
    """
    advice = get_openai_retry_advice(request)
    return advice

async def _maybe_aclose_async_iterator(self, iterator: Any) -> None:
    """Best-effort close of ``iterator`` via ``aclose()`` or, failing that, ``close()``.

    The first of ``aclose`` / ``close`` that exists and is callable is invoked
    exactly once; if its return value is awaitable it is awaited. Objects that
    expose neither method are left untouched.

    Args:
        iterator: Any object, typically a provider stream / async iterator.

    Raises:
        Whatever the underlying ``aclose()`` / ``close()`` raises; errors are
        deliberately not swallowed here so callers can decide how to react.
    """
    # ``aclose`` takes priority; ``close`` is only a fallback.
    for method_name in ("aclose", "close"):
        closer = getattr(iterator, method_name, None)
        if not callable(closer):
            continue

        outcome = closer()
        # Tolerate synchronous implementations of either method: awaiting a
        # non-awaitable (e.g. ``None``) would raise TypeError *after* the
        # stream has already been closed.
        if inspect.isawaitable(outcome):
            await outcome
        return

def _schedule_async_iterator_close(self, iterator: Any) -> None:
    """Close ``iterator`` in a background task without blocking the caller.

    Must be called while an event loop is running (``asyncio.create_task``
    raises ``RuntimeError`` otherwise). The outcome of the task is consumed by
    ``_consume_background_cleanup_task_result`` so that failures never surface
    as "exception was never retrieved" warnings.

    Args:
        iterator: The stream / async iterator to close.
    """
    task = asyncio.create_task(self._maybe_aclose_async_iterator(iterator))

    # The event loop only keeps weak references to tasks, so an otherwise
    # unreferenced task may be garbage-collected before it finishes. Hold a
    # strong reference on the instance until the task is done.
    pending = getattr(self, "_background_cleanup_tasks", None)
    if pending is None:
        pending = set()
        self._background_cleanup_tasks = pending
    pending.add(task)

    task.add_done_callback(pending.discard)
    task.add_done_callback(self._consume_background_cleanup_task_result)

@staticmethod
def _consume_background_cleanup_task_result(task: asyncio.Task[Any]) -> None:
    """Done-callback that retrieves the outcome of a background cleanup task.

    Calling ``task.result()`` marks any stored exception as retrieved.
    Cancellation is ignored silently; any other ``Exception`` is reported at
    debug level and not re-raised.
    """
    try:
        task.result()
    except asyncio.CancelledError:
        # A cancelled cleanup task is not worth reporting.
        return
    except Exception as cleanup_error:
        logger.debug(f"Background stream cleanup failed after cancellation: {cleanup_error}")

def _validate_official_openai_input_content_types(
self, request_input: str | list[TResponseInputItem]
) -> None:
Expand Down Expand Up @@ -307,16 +334,35 @@ async def stream_response(
else:
stream_for_handler = stream

async for chunk in ChatCmplStreamHandler.handle_stream(
response,
cast(AsyncStream[ChatCompletionChunk], stream_for_handler),
model=self.model,
strict_feature_validation=self._strict_feature_validation,
):
yield chunk

if chunk.type == "response.completed":
final_response = chunk.response
close_stream_in_background = False
yielded_terminal_event = False
try:
async for chunk in ChatCmplStreamHandler.handle_stream(
response,
cast(AsyncStream[ChatCompletionChunk], stream_for_handler),
model=self.model,
strict_feature_validation=self._strict_feature_validation,
):
yield chunk

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

**P2** — Set terminal state before yielding completed events

When a consumer breaks out of, or closes, the async generator immediately after receiving `response.completed`, execution is still suspended at this `yield`, so the bookkeeping on lines 348-350 never runs. In that common early-exit-after-terminal scenario, `yielded_terminal_event` remains `False`, and any provider-stream close error is re-raised instead of being ignored for an already-terminal stream. Move the completed-event bookkeeping before yielding the chunk.

Useful? React with 👍 / 👎.


if chunk.type == "response.completed":
final_response = chunk.response
yielded_terminal_event = True
except asyncio.CancelledError:
close_stream_in_background = True
self._schedule_async_iterator_close(stream)
raise
finally:
if not close_stream_in_background:
try:
await self._maybe_aclose_async_iterator(stream)
except Exception as exc:
if yielded_terminal_event:
logger.debug(
f"Ignoring stream cleanup error after terminal event: {exc}"
)
else:
raise

if tracing.include_data() and final_response:
span_generation.span_data.output = [final_response.model_dump()]
Expand Down
58 changes: 58 additions & 0 deletions tests/models/test_openai_chatcompletions_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,64 @@ async def patched_fetch_response(self, *args, **kwargs):
assert completed_resp.usage.output_tokens_details.reasoning_tokens == 3


@pytest.mark.allow_call_model_methods
@pytest.mark.asyncio
async def test_stream_response_close_closes_provider_stream(monkeypatch) -> None:
    """Closing the response generator early closes the provider stream exactly once.

    A fake provider stream that counts ``aclose()`` calls is injected through a
    patched ``_fetch_response``. The test pulls one event from
    ``stream_response`` and then closes the generator.
    """
    only_chunk = ChatCompletionChunk(
        id="chunk-id",
        created=1,
        model="fake",
        object="chat.completion.chunk",
        choices=[Choice(index=0, delta=ChoiceDelta(content="Hi"))],
    )

    class ClosableChatStream:
        """Async iterator that yields one chunk and records ``aclose()`` calls."""

        def __init__(self) -> None:
            self._yielded = False
            self.close_calls = 0

        def __aiter__(self) -> "ClosableChatStream":
            return self

        async def __anext__(self) -> ChatCompletionChunk:
            if not self._yielded:
                self._yielded = True
                return only_chunk
            raise StopAsyncIteration

        async def aclose(self) -> None:
            self.close_calls += 1

    provider_stream = ClosableChatStream()

    async def patched_fetch_response(self, *args, **kwargs):
        # Hand the model our fake stream instead of calling the real API.
        return _empty_response(), provider_stream

    monkeypatch.setattr(OpenAIChatCompletionsModel, "_fetch_response", patched_fetch_response)
    model = OpenAIProvider(use_responses=False).get_model("gpt-4")

    stream_kwargs = dict(
        system_instructions=None,
        input="",
        model_settings=ModelSettings(),
        tools=[],
        output_schema=None,
        handoffs=[],
        tracing=ModelTracing.DISABLED,
        previous_response_id=None,
        conversation_id=None,
        prompt=None,
    )
    stream_agen = cast(Any, model.stream_response(**stream_kwargs))

    # Advance to the first event so the generator is suspended mid-stream.
    first_event = await stream_agen.__anext__()
    assert first_event.type == "response.created"

    # Closing the outer generator must propagate to the provider stream.
    await stream_agen.aclose()

    assert provider_stream.close_calls == 1


@pytest.mark.asyncio
async def test_stream_handler_filters_multiple_choices_by_default(
caplog: pytest.LogCaptureFixture,
Expand Down