diff --git a/src/openai/lib/streaming/_assistants.py b/src/openai/lib/streaming/_assistants.py index 6efb3ca3f1..e234004473 100644 --- a/src/openai/lib/streaming/_assistants.py +++ b/src/openai/lib/streaming/_assistants.py @@ -977,14 +977,55 @@ def accumulate_event( return current_message_snapshot, new_content +def _accumulate_list_delta(acc_value: list[object], delta_value: list[object]) -> list[object]: + # for lists of non-dictionary items we'll only ever get new entries + # in the array, existing entries will never be changed + if all(isinstance(x, (str, int, float)) for x in acc_value + delta_value): + acc_value.extend(delta_value) + return acc_value + + for delta_entry in delta_value: + if not is_dict(delta_entry): + raise TypeError(f"Unexpected list delta entry is not a dictionary: {delta_entry}") + + try: + index = delta_entry["index"] + except KeyError as exc: + raise RuntimeError(f"Expected list delta entry to have an `index` key; {delta_entry}") from exc + + if not isinstance(index, int): + raise TypeError(f"Unexpected, list delta entry `index` value is not an integer; {index}") + + existing_index = next( + (idx for idx, entry in enumerate(acc_value) if is_dict(entry) and entry.get("index") == index), + None, + ) + + if existing_index is None: + acc_value.insert(index, delta_entry) + continue + + acc_entry = acc_value[existing_index] + if not is_dict(acc_entry): + raise TypeError("not handled yet") + + acc_value[existing_index] = accumulate_delta(acc_entry, delta_entry) + + return acc_value + + def accumulate_delta(acc: dict[object, object], delta: dict[object, object]) -> dict[object, object]: for key, delta_value in delta.items(): if key not in acc: + if is_list(delta_value): + delta_value = _accumulate_list_delta([], delta_value) acc[key] = delta_value continue acc_value = acc[key] if acc_value is None: + if is_list(delta_value): + delta_value = _accumulate_list_delta([], delta_value) acc[key] = delta_value continue @@ -1005,33 +1046,7 @@ def accumulate_delta(acc: dict[object, object], delta: dict[object, object]) -> elif is_dict(acc_value) and is_dict(delta_value): acc_value = accumulate_delta(acc_value, delta_value) elif is_list(acc_value) and is_list(delta_value): - # for lists of non-dictionary items we'll only ever get new entries - # in the array, existing entries will never be changed - if all(isinstance(x, (str, int, float)) for x in acc_value): - acc_value.extend(delta_value) - continue - - for delta_entry in delta_value: - if not is_dict(delta_entry): - raise TypeError(f"Unexpected list delta entry is not a dictionary: {delta_entry}") - - try: - index = delta_entry["index"] - except KeyError as exc: - raise RuntimeError(f"Expected list delta entry to have an `index` key; {delta_entry}") from exc - - if not isinstance(index, int): - raise TypeError(f"Unexpected, list delta entry `index` value is not an integer; {index}") - - try: - acc_entry = acc_value[index] - except IndexError: - acc_value.insert(index, delta_entry) - else: - if not is_dict(acc_entry): - raise TypeError("not handled yet") - - acc_value[index] = accumulate_delta(acc_entry, delta_entry) + acc_value = _accumulate_list_delta(acc_value, delta_value) acc[key] = acc_value diff --git a/src/openai/lib/streaming/_deltas.py b/src/openai/lib/streaming/_deltas.py index a5e1317612..1b956e6d29 100644 --- a/src/openai/lib/streaming/_deltas.py +++ b/src/openai/lib/streaming/_deltas.py @@ -3,14 +3,55 @@ from ..._utils import is_dict, is_list +def _accumulate_list_delta(acc_value: list[object], delta_value: list[object]) -> list[object]: + # for lists of non-dictionary items we'll only ever get new entries + # in the array, existing entries will never be changed + if all(isinstance(x, (str, int, float)) for x in acc_value + delta_value): + acc_value.extend(delta_value) + return acc_value + + for delta_entry in delta_value: + if not is_dict(delta_entry): + raise TypeError(f"Unexpected list delta entry is not a dictionary: {delta_entry}") + + try: + index = delta_entry["index"] + except KeyError as exc: + raise RuntimeError(f"Expected list delta entry to have an `index` key; {delta_entry}") from exc + + if not isinstance(index, int): + raise TypeError(f"Unexpected, list delta entry `index` value is not an integer; {index}") + + existing_index = next( + (idx for idx, entry in enumerate(acc_value) if is_dict(entry) and entry.get("index") == index), + None, + ) + + if existing_index is None: + acc_value.insert(index, delta_entry) + continue + + acc_entry = acc_value[existing_index] + if not is_dict(acc_entry): + raise TypeError("not handled yet") + + acc_value[existing_index] = accumulate_delta(acc_entry, delta_entry) + + return acc_value + + def accumulate_delta(acc: dict[object, object], delta: dict[object, object]) -> dict[object, object]: for key, delta_value in delta.items(): if key not in acc: + if is_list(delta_value): + delta_value = _accumulate_list_delta([], delta_value) acc[key] = delta_value continue acc_value = acc[key] if acc_value is None: + if is_list(delta_value): + delta_value = _accumulate_list_delta([], delta_value) acc[key] = delta_value continue @@ -31,33 +72,7 @@ def accumulate_delta(acc: dict[object, object], delta: dict[object, object]) -> elif is_dict(acc_value) and is_dict(delta_value): acc_value = accumulate_delta(acc_value, delta_value) elif is_list(acc_value) and is_list(delta_value): - # for lists of non-dictionary items we'll only ever get new entries - # in the array, existing entries will never be changed - if all(isinstance(x, (str, int, float)) for x in acc_value): - acc_value.extend(delta_value) - continue - - for delta_entry in delta_value: - if not is_dict(delta_entry): - raise TypeError(f"Unexpected list delta entry is not a dictionary: {delta_entry}") - - try: - index = delta_entry["index"] - except KeyError as exc: - raise RuntimeError(f"Expected list delta entry to have an `index` key; {delta_entry}") from exc - - if not isinstance(index, int): - raise TypeError(f"Unexpected, list delta entry `index` value is not an integer; {index}") - - try: - acc_entry = acc_value[index] - except IndexError: - acc_value.insert(index, delta_entry) - else: - if not is_dict(acc_entry): - raise TypeError("not handled yet") - - acc_value[index] = accumulate_delta(acc_entry, delta_entry) + acc_value = _accumulate_list_delta(acc_value, delta_value) acc[key] = acc_value diff --git a/tests/lib/test_streaming_deltas.py b/tests/lib/test_streaming_deltas.py new file mode 100644 index 0000000000..a2939f7aa0 --- /dev/null +++ b/tests/lib/test_streaming_deltas.py @@ -0,0 +1,85 @@ +from openai.lib.streaming._deltas import accumulate_delta + + +def test_accumulate_delta_merges_duplicate_indexes_in_initial_list() -> None: + acc: dict[object, object] = {"tool_calls": None} + + accumulate_delta( + acc, + { + "tool_calls": [ + { + "index": 0, + "id": "call_abc", + "function": {"name": "get_weather"}, + "type": "function", + }, + { + "index": 0, + "function": {"arguments": '{"city"'}, + }, + ], + }, + ) + accumulate_delta( + acc, + { + "tool_calls": [ + { + "index": 0, + "function": {"arguments": ': "London"}'}, + }, + ], + }, + ) + + assert acc["tool_calls"] == [ + { + "index": 0, + "id": "call_abc", + "function": {"name": "get_weather", "arguments": '{"city": "London"}'}, + "type": "function", + } + ] + + +def test_accumulate_delta_merges_later_entries_by_logical_index() -> None: + acc: dict[object, object] = { + "tool_calls": [ + { + "index": 0, + "function": {"arguments": "a"}, + }, + { + "index": 1, + "function": {"arguments": "x"}, + }, + ], + } + + accumulate_delta( + acc, + { + "tool_calls": [ + { + "index": 1, + "function": {"arguments": "y"}, + }, + { + "index": 0, + "function": {"arguments": "b"}, + }, + ], + }, + ) + + assert acc["tool_calls"] == [ + { + "index": 0, + "function": {"arguments": "ab"}, + }, + { + "index": 1, + "function": {"arguments": "xy"}, + }, + ]