Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions haystack/utils/requests_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,11 @@ def request_with_retry(
if status_codes_to_retry is None:
status_codes_to_retry = [408, 418, 429, 503]

# Pop timeout once, before the retry loop. Popping inside @retry-decorated
# run() would remove the key on the first attempt so all subsequent retries
# silently fall back to the default 10-second timeout.
timeout = kwargs.pop("timeout", 10)

@retry(
reraise=True,
wait=wait_exponential(),
Expand All @@ -67,7 +72,6 @@ def request_with_retry(
after=after_log(logger, logging.DEBUG),
)
def run() -> httpx.Response:
timeout = kwargs.pop("timeout", 10)
with httpx.Client() as client:
res = client.request(**kwargs, timeout=timeout)

Expand Down Expand Up @@ -168,6 +172,11 @@ async def example_5xx():
if status_codes_to_retry is None:
status_codes_to_retry = [408, 418, 429, 503]

# Pop timeout once, before the retry loop. Popping inside @retry-decorated
# run() would remove the key on the first attempt so all subsequent retries
# silently fall back to the default 10-second timeout.
timeout = kwargs.pop("timeout", 10)

@retry(
reraise=True,
wait=wait_exponential(),
Expand All @@ -177,7 +186,6 @@ async def example_5xx():
after=after_log(logger, logging.DEBUG),
)
async def run() -> httpx.Response:
timeout = kwargs.pop("timeout", 10)
async with httpx.AsyncClient() as client:
res = await client.request(**kwargs, timeout=timeout)

Expand Down
55 changes: 55 additions & 0 deletions test/utils/test_requests_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,33 @@ def raise_for_status():
assert mock_request.call_count == 2
mock_sleep.assert_called()

def test_request_with_retry_preserves_timeout_on_retry(self):
"""Regression test: custom timeout must be forwarded on every retry attempt.

Previously, ``timeout`` was popped from ``kwargs`` *inside* the
``@retry``-decorated ``run()`` closure. On the first attempt the pop
succeeded and the caller-supplied value was used; on all subsequent
retries the key was already gone so ``kwargs.pop("timeout", 10)``
silently fell back to the 10-second default, ignoring the user's value.
"""
with patch("time.sleep"):
success_response = httpx.Response(status_code=200, request=httpx.Request("GET", "https://example.com"))
success_response.raise_for_status = lambda: None

with patch("httpx.Client.request") as mock_request:
mock_request.side_effect = [
httpx.RequestError("transient error", request=httpx.Request("GET", "https://example.com")),
success_response,
]

request_with_retry(method="GET", url="https://example.com", attempts=2, timeout=42)

assert mock_request.call_count == 2
for call in mock_request.call_args_list:
assert call.kwargs["timeout"] == 42, (
f"Expected timeout=42 on every attempt, got {call.kwargs['timeout']!r}"
)


class TestAsyncRequestWithRetry:
@pytest.mark.asyncio
Expand Down Expand Up @@ -234,3 +261,31 @@ def raise_for_status():
assert response == success_response
assert mock_request.call_count == 2
mock_sleep.assert_called()

@pytest.mark.asyncio
async def test_async_request_with_retry_preserves_timeout_on_retry(self):
"""Regression test: custom timeout must be forwarded on every retry attempt.

Previously, ``timeout`` was popped from ``kwargs`` *inside* the
``@retry``-decorated ``run()`` closure. On the first attempt the pop
succeeded and the caller-supplied value was used; on all subsequent
retries the key was already gone so ``kwargs.pop("timeout", 10)``
silently fell back to the 10-second default, ignoring the user's value.
"""
with patch("asyncio.sleep"):
success_response = httpx.Response(status_code=200, request=httpx.Request("GET", "https://example.com"))
success_response.raise_for_status = lambda: None

with patch("httpx.AsyncClient.request") as mock_request:
mock_request.side_effect = [
httpx.RequestError("transient error", request=httpx.Request("GET", "https://example.com")),
success_response,
]

await async_request_with_retry(method="GET", url="https://example.com", attempts=2, timeout=42)

assert mock_request.call_count == 2
for call in mock_request.call_args_list:
assert call.kwargs["timeout"] == 42, (
f"Expected timeout=42 on every attempt, got {call.kwargs['timeout']!r}"
)