Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 32 additions & 3 deletions src/openai/resources/vector_stores/file_batches.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from __future__ import annotations

import asyncio
from typing import Dict, Iterable, Optional
from typing import Any, Dict, Iterable, Optional
from typing_extensions import Union, Literal
from concurrent.futures import Future, ThreadPoolExecutor, as_completed

Expand All @@ -15,6 +15,7 @@
from ..._types import Body, Omit, Query, Headers, NotGiven, FileTypes, SequenceNotStr, omit, not_given
from ..._utils import is_given, path_template, maybe_transform, async_maybe_transform
from ..._compat import cached_property
from ..._models import construct_type_unchecked
from ..._resource import SyncAPIResource, AsyncAPIResource
from ..._response import to_streamed_response_wrapper, async_to_streamed_response_wrapper
from ...pagination import SyncCursorPage, AsyncCursorPage
Expand All @@ -28,6 +29,26 @@
__all__ = ["FileBatches", "AsyncFileBatches"]


def _coerce_vector_store_poll_response(
data: dict[str, Any],
*,
batch_id: str,
vector_store_id: str,
) -> VectorStoreFileBatch | None:
if data.get("object") != "vector_store" or data.get("id") != vector_store_id:
return None

return construct_type_unchecked(
value={
**data,
"id": batch_id,
"object": "vector_store.files_batch",
"vector_store_id": vector_store_id,
},
type_=VectorStoreFileBatch,
)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Keep strict validation for coerced poll responses

When poll() receives a vector_store payload, this path builds VectorStoreFileBatch via construct_type_unchecked, which bypasses _strict_response_validation entirely. That means clients that explicitly enabled strict validation can silently accept malformed fields (wrong enum/value types) in this branch instead of getting a validation error, unlike the normal response.parse() path.

Useful? React with 👍 / 👎.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice catch — I switched the coerced vector_store poll path to run back through the client response processor so strict validation still applies, and added sync/async regression coverage for the malformed payload case. Updated in 0e79e20.



class FileBatches(SyncAPIResource):
@cached_property
def with_raw_response(self) -> FileBatchesWithRawResponse:
Expand Down Expand Up @@ -351,7 +372,11 @@ def poll(
extra_headers=headers,
)

batch = response.parse()
data = response.parse(to=dict)
batch = _coerce_vector_store_poll_response(data, batch_id=batch_id, vector_store_id=vector_store_id)
if batch is None:
batch = response.parse()

if batch.file_counts.in_progress > 0:
if not is_given(poll_interval_ms):
from_header = response.headers.get("openai-poll-after-ms")
Expand Down Expand Up @@ -739,7 +764,11 @@ async def poll(
extra_headers=headers,
)

batch = response.parse()
data = response.parse(to=dict)
batch = _coerce_vector_store_poll_response(data, batch_id=batch_id, vector_store_id=vector_store_id)
if batch is None:
batch = response.parse()

if batch.file_counts.in_progress > 0:
if not is_given(poll_interval_ms):
from_header = response.headers.get("openai-poll-after-ms")
Expand Down
56 changes: 56 additions & 0 deletions tests/api_resources/vector_stores/test_file_batches.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import os
from typing import Any, cast

import httpx
import pytest

from openai import OpenAI, AsyncOpenAI
Expand Down Expand Up @@ -462,3 +463,58 @@ def test_create_and_poll_method_in_sync(sync: bool, client: OpenAI, async_client
checking_client.vector_stores.file_batches.create,
checking_client.vector_stores.file_batches.create_and_poll,
)


def _completed_vector_store_response() -> dict[str, object]:
return {
"id": "vs_abc123",
"created_at": 1761991501,
"file_counts": {
"cancelled": 0,
"completed": 1,
"failed": 0,
"in_progress": 0,
"total": 1,
},
"object": "vector_store",
"status": "completed",
"vector_store_id": None,
}


def test_poll_coerces_completed_vector_store_response() -> None:
def handler(request: httpx.Request) -> httpx.Response:
assert request.url.path == "/vector_stores/vs_abc123/file_batches/vsfb_abc123"
return httpx.Response(200, json=_completed_vector_store_response())

with OpenAI(
api_key="My API Key",
base_url=base_url,
http_client=httpx.Client(transport=httpx.MockTransport(handler)),
_strict_response_validation=True,
) as client:
file_batch = client.vector_stores.file_batches.poll(batch_id="vsfb_abc123", vector_store_id="vs_abc123")

assert_matches_type(VectorStoreFileBatch, file_batch, path=["response"])
assert file_batch.id == "vsfb_abc123"
assert file_batch.vector_store_id == "vs_abc123"


async def test_async_poll_coerces_completed_vector_store_response() -> None:
async def handler(request: httpx.Request) -> httpx.Response:
assert request.url.path == "/vector_stores/vs_abc123/file_batches/vsfb_abc123"
return httpx.Response(200, json=_completed_vector_store_response())

async with AsyncOpenAI(
api_key="My API Key",
base_url=base_url,
http_client=httpx.AsyncClient(transport=httpx.MockTransport(handler)),
_strict_response_validation=True,
) as async_client:
file_batch = await async_client.vector_stores.file_batches.poll(
batch_id="vsfb_abc123", vector_store_id="vs_abc123"
)

assert_matches_type(VectorStoreFileBatch, file_batch, path=["response"])
assert file_batch.id == "vsfb_abc123"
assert file_batch.vector_store_id == "vs_abc123"