From 45f4923f52d0b745a66e4f17d397a4ebff70b615 Mon Sep 17 00:00:00 2001 From: pragnyanramtha Date: Sat, 16 May 2026 16:40:21 +0000 Subject: [PATCH] fix: infer missing streaming tool call index --- src/openai/lib/streaming/chat/_completions.py | 41 ++++++ tests/lib/chat/test_completions_streaming.py | 120 ++++++++++++++++++ 2 files changed, 161 insertions(+) diff --git a/src/openai/lib/streaming/chat/_completions.py b/src/openai/lib/streaming/chat/_completions.py index 5f072cafbd..38f9193eb2 100644 --- a/src/openai/lib/streaming/chat/_completions.py +++ b/src/openai/lib/streaming/chat/_completions.py @@ -357,16 +357,57 @@ def _get_choice_state(self, choice: ChoiceChunk) -> ChoiceEventState: self.__choice_event_states.append(choice_state) return choice_state + def _infer_missing_tool_call_indexes( + self, + *, + choice: ChoiceChunk, + previous_tool_calls: list[Any], + ) -> None: + tool_calls = choice.delta.tool_calls + if not tool_calls or len(tool_calls) != 1: + return + + tool_call = cast(Any, tool_calls[0]) + if tool_call.index is not None: + return + + if len(previous_tool_calls) == 0: + tool_call.index = 0 + return + + if len(previous_tool_calls) != 1: + return + + if not self._is_missing_index_tool_call_continuation(tool_call): + return + + tool_call.index = previous_tool_calls[0].index + + @staticmethod + def _is_missing_index_tool_call_continuation(tool_call: Any) -> bool: + if tool_call.id is not None or tool_call.type is not None: + return False + + function = tool_call.function + if function is not None and function.name is not None: + return False + + return True + def _accumulate_chunk(self, chunk: ChatCompletionChunk) -> ParsedChatCompletionSnapshot: completion_snapshot = self.__current_completion_snapshot if completion_snapshot is None: + for choice in chunk.choices: + self._infer_missing_tool_call_indexes(choice=choice, previous_tool_calls=[]) + return _convert_initial_chunk_into_snapshot(chunk) for choice in chunk.choices: try: choice_snapshot = completion_snapshot.choices[choice.index] previous_tool_calls = choice_snapshot.message.tool_calls or [] + self._infer_missing_tool_call_indexes(choice=choice, previous_tool_calls=previous_tool_calls) choice_snapshot.message = cast( ParsedChatCompletionMessageSnapshot, diff --git a/tests/lib/chat/test_completions_streaming.py b/tests/lib/chat/test_completions_streaming.py index eb3a0973ac..6f2cc790de 100644 --- a/tests/lib/chat/test_completions_streaming.py +++ b/tests/lib/chat/test_completions_streaming.py @@ -20,6 +20,7 @@ from openai import OpenAI, AsyncOpenAI from openai._utils import consume_sync_iterator, assert_signatures_in_sync from openai._compat import model_copy +from openai._models import construct_type from openai.types.chat import ChatCompletionChunk from openai.lib.streaming.chat import ( ContentDoneEvent, @@ -1018,6 +1019,125 @@ def test_allows_non_strict_tools_but_no_parsing( ) +def _make_chat_completion_chunk(delta: dict[str, object]) -> ChatCompletionChunk: + return construct_type( + type_=ChatCompletionChunk, + value={ + "id": "chatcmpl_123", + "object": "chat.completion.chunk", + "created": 1727346167, + "model": "gpt-4o-2024-08-06", + "choices": [ + { + "index": 0, + "delta": delta, + } + ], + }, + ) + + +def test_tool_call_delta_without_index_starts_single_tool_call() -> None: + state = ChatCompletionStreamState() + + chunk = _make_chat_completion_chunk( + { + "role": "assistant", + "tool_calls": [ + { + "id": "call_123", + "type": "function", + "function": {"name": "get_weather", "arguments": '{"location":"Chicago, IL"}'}, + } + ], + } + ) + + events = list(state.handle_chunk(chunk)) + + tool_call_delta = chunk.choices[0].delta.tool_calls[0] + tool_call = state.current_completion_snapshot.choices[0].message.tool_calls[0] + assert tool_call_delta.index == 0 + assert tool_call.index == 0 + assert tool_call.function.arguments == '{"location":"Chicago, IL"}' + assert [event.type for event in events] == ["chunk", "tool_calls.function.arguments.delta"] + + +def test_tool_call_delta_without_index_continues_active_tool_call() -> None: + state = ChatCompletionStreamState() + + first_chunk = _make_chat_completion_chunk( + { + "role": "assistant", + "tool_calls": [ + { + "index": 0, + "id": "call_123", + "type": "function", + "function": {"name": "get_weather", "arguments": ""}, + } + ], + } + ) + second_chunk = _make_chat_completion_chunk( + { + "tool_calls": [ + { + "function": {"arguments": '{"city":"Chicago"}'}, + } + ], + } + ) + + state.handle_chunk(first_chunk) + events = list(state.handle_chunk(second_chunk)) + + tool_call_delta = second_chunk.choices[0].delta.tool_calls[0] + tool_call = state.current_completion_snapshot.choices[0].message.tool_calls[0] + assert tool_call_delta.index == 0 + assert tool_call.function.arguments == '{"city":"Chicago"}' + assert [event.type for event in events] == ["chunk", "tool_calls.function.arguments.delta"] + + +def test_tool_call_delta_without_index_does_not_start_second_tool_call() -> None: + state = ChatCompletionStreamState() + + first_chunk = _make_chat_completion_chunk( + { + "role": "assistant", + "tool_calls": [ + { + "index": 0, + "id": "call_123", + "type": "function", + "function": {"name": "get_weather", "arguments": '{"city":"Chicago"}'}, + } + ], + } + ) + second_chunk = _make_chat_completion_chunk( + { + "tool_calls": [ + { + "id": "call_456", + "type": "function", + "function": {"name": "get_time", "arguments": '{"city":"Chicago"}'}, + } + ], + } + ) + + state.handle_chunk(first_chunk) + with pytest.raises(RuntimeError, match="Expected list delta entry to have an `index` key"): + state.handle_chunk(second_chunk) + + tool_call_delta = second_chunk.choices[0].delta.tool_calls[0] + tool_call = state.current_completion_snapshot.choices[0].message.tool_calls[0] + assert tool_call_delta.index is None + assert tool_call.id == "call_123" + assert tool_call.function.name == "get_weather" + + @pytest.mark.respx(base_url=base_url) def test_chat_completion_state_helper(client: OpenAI, respx_mock: MockRouter, monkeypatch: pytest.MonkeyPatch) -> None: state = ChatCompletionStreamState()