Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions src/openai/lib/streaming/chat/_completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,16 +357,57 @@ def _get_choice_state(self, choice: ChoiceChunk) -> ChoiceEventState:
self.__choice_event_states.append(choice_state)
return choice_state

def _infer_missing_tool_call_indexes(
self,
*,
choice: ChoiceChunk,
previous_tool_calls: list[Any],
) -> None:
tool_calls = choice.delta.tool_calls
if not tool_calls or len(tool_calls) != 1:
return

tool_call = cast(Any, tool_calls[0])
if tool_call.index is not None:
return

if len(previous_tool_calls) == 0:
tool_call.index = 0
return

if len(previous_tool_calls) != 1:
return

if not self._is_missing_index_tool_call_continuation(tool_call):
return

tool_call.index = previous_tool_calls[0].index
Comment thread
pragnyanramtha marked this conversation as resolved.

@staticmethod
def _is_missing_index_tool_call_continuation(tool_call: Any) -> bool:
if tool_call.id is not None or tool_call.type is not None:
return False

function = tool_call.function
if function is not None and function.name is not None:
return False

return True

def _accumulate_chunk(self, chunk: ChatCompletionChunk) -> ParsedChatCompletionSnapshot:
completion_snapshot = self.__current_completion_snapshot

if completion_snapshot is None:
for choice in chunk.choices:
self._infer_missing_tool_call_indexes(choice=choice, previous_tool_calls=[])

return _convert_initial_chunk_into_snapshot(chunk)

for choice in chunk.choices:
try:
choice_snapshot = completion_snapshot.choices[choice.index]
previous_tool_calls = choice_snapshot.message.tool_calls or []
self._infer_missing_tool_call_indexes(choice=choice, previous_tool_calls=previous_tool_calls)

choice_snapshot.message = cast(
ParsedChatCompletionMessageSnapshot,
Expand Down
120 changes: 120 additions & 0 deletions tests/lib/chat/test_completions_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from openai import OpenAI, AsyncOpenAI
from openai._utils import consume_sync_iterator, assert_signatures_in_sync
from openai._compat import model_copy
from openai._models import construct_type
from openai.types.chat import ChatCompletionChunk
from openai.lib.streaming.chat import (
ContentDoneEvent,
Expand Down Expand Up @@ -1018,6 +1019,125 @@ def test_allows_non_strict_tools_but_no_parsing(
)


def _make_chat_completion_chunk(delta: dict[str, object]) -> ChatCompletionChunk:
return construct_type(
type_=ChatCompletionChunk,
value={
"id": "chatcmpl_123",
"object": "chat.completion.chunk",
"created": 1727346167,
"model": "gpt-4o-2024-08-06",
"choices": [
{
"index": 0,
"delta": delta,
}
],
},
)


def test_tool_call_delta_without_index_starts_single_tool_call() -> None:
state = ChatCompletionStreamState()

chunk = _make_chat_completion_chunk(
{
"role": "assistant",
"tool_calls": [
{
"id": "call_123",
"type": "function",
"function": {"name": "get_weather", "arguments": '{"location":"Chicago, IL"}'},
}
],
}
)

events = list(state.handle_chunk(chunk))

tool_call_delta = chunk.choices[0].delta.tool_calls[0]
tool_call = state.current_completion_snapshot.choices[0].message.tool_calls[0]
assert tool_call_delta.index == 0
assert tool_call.index == 0
assert tool_call.function.arguments == '{"location":"Chicago, IL"}'
assert [event.type for event in events] == ["chunk", "tool_calls.function.arguments.delta"]


def test_tool_call_delta_without_index_continues_active_tool_call() -> None:
state = ChatCompletionStreamState()

first_chunk = _make_chat_completion_chunk(
{
"role": "assistant",
"tool_calls": [
{
"index": 0,
"id": "call_123",
"type": "function",
"function": {"name": "get_weather", "arguments": ""},
}
],
}
)
second_chunk = _make_chat_completion_chunk(
{
"tool_calls": [
{
"function": {"arguments": '{"city":"Chicago"}'},
}
],
}
)

state.handle_chunk(first_chunk)
events = list(state.handle_chunk(second_chunk))

tool_call_delta = second_chunk.choices[0].delta.tool_calls[0]
tool_call = state.current_completion_snapshot.choices[0].message.tool_calls[0]
assert tool_call_delta.index == 0
assert tool_call.function.arguments == '{"city":"Chicago"}'
assert [event.type for event in events] == ["chunk", "tool_calls.function.arguments.delta"]


def test_tool_call_delta_without_index_does_not_start_second_tool_call() -> None:
state = ChatCompletionStreamState()

first_chunk = _make_chat_completion_chunk(
{
"role": "assistant",
"tool_calls": [
{
"index": 0,
"id": "call_123",
"type": "function",
"function": {"name": "get_weather", "arguments": '{"city":"Chicago"}'},
}
],
}
)
second_chunk = _make_chat_completion_chunk(
{
"tool_calls": [
{
"id": "call_456",
"type": "function",
"function": {"name": "get_time", "arguments": '{"city":"Chicago"}'},
}
],
}
)

state.handle_chunk(first_chunk)
with pytest.raises(RuntimeError, match="Expected list delta entry to have an `index` key"):
state.handle_chunk(second_chunk)

tool_call_delta = second_chunk.choices[0].delta.tool_calls[0]
tool_call = state.current_completion_snapshot.choices[0].message.tool_calls[0]
assert tool_call_delta.index is None
assert tool_call.id == "call_123"
assert tool_call.function.name == "get_weather"


@pytest.mark.respx(base_url=base_url)
def test_chat_completion_state_helper(client: OpenAI, respx_mock: MockRouter, monkeypatch: pytest.MonkeyPatch) -> None:
state = ChatCompletionStreamState()
Expand Down