diff --git a/src/agents/models/openai_chatcompletions.py b/src/agents/models/openai_chatcompletions.py index 4da04f0f83..4996b9a3e5 100644 --- a/src/agents/models/openai_chatcompletions.py +++ b/src/agents/models/openai_chatcompletions.py @@ -1,5 +1,7 @@ from __future__ import annotations +import asyncio +import inspect import json import time from collections.abc import AsyncIterator @@ -95,6 +97,31 @@ def _handle_unsupported_prompt(self, prompt: ResponsePromptParam | None) -> None def get_retry_advice(self, request: ModelRetryAdviceRequest) -> ModelRetryAdvice | None: return get_openai_retry_advice(request) + async def _maybe_aclose_async_iterator(self, iterator: Any) -> None: + aclose = getattr(iterator, "aclose", None) + if callable(aclose): + await aclose() + return + + close = getattr(iterator, "close", None) + if callable(close): + close_result = close() + if inspect.isawaitable(close_result): + await close_result + + def _schedule_async_iterator_close(self, iterator: Any) -> None: + task = asyncio.create_task(self._maybe_aclose_async_iterator(iterator)) + task.add_done_callback(self._consume_background_cleanup_task_result) + + @staticmethod + def _consume_background_cleanup_task_result(task: asyncio.Task[Any]) -> None: + try: + task.result() + except asyncio.CancelledError: + pass + except Exception as exc: + logger.debug(f"Background stream cleanup failed after cancellation: {exc}") + def _validate_official_openai_input_content_types( self, request_input: str | list[TResponseInputItem] ) -> None: @@ -307,16 +334,35 @@ async def stream_response( else: stream_for_handler = stream - async for chunk in ChatCmplStreamHandler.handle_stream( - response, - cast(AsyncStream[ChatCompletionChunk], stream_for_handler), - model=self.model, - strict_feature_validation=self._strict_feature_validation, - ): - yield chunk - - if chunk.type == "response.completed": - final_response = chunk.response + close_stream_in_background = False + yielded_terminal_event = False + try: + async for chunk in ChatCmplStreamHandler.handle_stream( + response, + cast(AsyncStream[ChatCompletionChunk], stream_for_handler), + model=self.model, + strict_feature_validation=self._strict_feature_validation, + ): + if chunk.type == "response.completed": + final_response = chunk.response + yielded_terminal_event = True + + yield chunk + except asyncio.CancelledError: + close_stream_in_background = True + self._schedule_async_iterator_close(stream) + raise + finally: + if not close_stream_in_background: + try: + await self._maybe_aclose_async_iterator(stream) + except Exception as exc: + if yielded_terminal_event: + logger.debug( + f"Ignoring stream cleanup error after terminal event: {exc}" + ) + else: + raise if tracing.include_data() and final_response: span_generation.span_data.output = [final_response.model_dump()] diff --git a/tests/models/test_openai_chatcompletions_stream.py b/tests/models/test_openai_chatcompletions_stream.py index 33de9ca194..251bcb9bf1 100644 --- a/tests/models/test_openai_chatcompletions_stream.py +++ b/tests/models/test_openai_chatcompletions_stream.py @@ -161,6 +161,66 @@ async def patched_fetch_response(self, *args, **kwargs): assert completed_resp.usage.output_tokens_details.reasoning_tokens == 3 +@pytest.mark.allow_call_model_methods +@pytest.mark.asyncio +async def test_stream_response_close_closes_provider_stream_with_async_close( + monkeypatch, +) -> None: + chunk = ChatCompletionChunk( + id="chunk-id", + created=1, + model="fake", + object="chat.completion.chunk", + choices=[Choice(index=0, delta=ChoiceDelta(content="Hi"))], + ) + + class ClosableChatStream: + def __init__(self) -> None: + self._yielded = False + self.close_calls = 0 + + def __aiter__(self) -> "ClosableChatStream": + return self + + async def __anext__(self) -> ChatCompletionChunk: + if self._yielded: + raise StopAsyncIteration + self._yielded = True + return chunk + + async def close(self) -> None: + self.close_calls += 1 + + provider_stream = ClosableChatStream() + + async def patched_fetch_response(self, *args, **kwargs): + return _empty_response(), provider_stream + + monkeypatch.setattr(OpenAIChatCompletionsModel, "_fetch_response", patched_fetch_response) + model = OpenAIProvider(use_responses=False).get_model("gpt-4") + + stream = model.stream_response( + system_instructions=None, + input="", + model_settings=ModelSettings(), + tools=[], + output_schema=None, + handoffs=[], + tracing=ModelTracing.DISABLED, + previous_response_id=None, + conversation_id=None, + prompt=None, + ) + stream_agen = cast(Any, stream) + + event = await stream_agen.__anext__() + assert event.type == "response.created" + + await stream_agen.aclose() + + assert provider_stream.close_calls == 1 + + @pytest.mark.asyncio async def test_stream_handler_filters_multiple_choices_by_default( caplog: pytest.LogCaptureFixture,