diff --git a/docs_src/src/pages/documentation/en/api_reference/server_sent_events.mdx b/docs_src/src/pages/documentation/en/api_reference/server_sent_events.mdx index cff7e2c95..64d238280 100644 --- a/docs_src/src/pages/documentation/en/api_reference/server_sent_events.mdx +++ b/docs_src/src/pages/documentation/en/api_reference/server_sent_events.mdx @@ -189,6 +189,68 @@ async def stream_api_events(request): --- +## Streaming raw bytes (binary data & file downloads) + +Not everything Batman streams is an SSE event. To stream **arbitrary bytes** — a large file, a generated archive, an `application/octet-stream` body — he uses `StreamingResponse` directly and yields chunks, without loading the whole payload into memory. Each chunk may be `bytes` (sent as-is) or `str` (UTF-8 encoded). + + + +A sync generator yielding `bytes` chunks, served as a binary download: + + + + +```python +from robyn import Robyn, StreamingResponse, Headers + +app = Robyn(__file__) + +@app.get("/download") +def download(request): + def file_chunks(): + with open("large_file.bin", "rb") as f: + while chunk := f.read(8192): + yield chunk # each chunk is `bytes` + + return StreamingResponse( + file_chunks(), + media_type="application/octet-stream", + headers=Headers({ + "Content-Type": "application/octet-stream", + "Content-Disposition": "attachment; filename=large_file.bin", + }), + ) +``` + + + + + + + +`StreamingResponse` accepts both sync and async generators. 
When you use an **async generator**, it runs on the same event loop as your handler, so you can safely `await` async resources — an async database session, an HTTP client — **inside** the generator: + + + + +```python +@app.get("/export") +async def export(request): + async def rows(): + async with AsyncSessionLocal() as session: + result = await session.execute(select(Record)) # await works here + for row in result.scalars(): + yield f"{row.id},{row.name}\n".encode() + + return StreamingResponse(rows(), media_type="text/csv") +``` + + + + + +--- + ## What's next? Batman has mastered Server-Sent Events and can now stream real-time updates to his crime dashboard. While SSE is perfect for one-way communication from server to client, Batman realizes he needs bidirectional communication for more interactive features like real-time chat with his allies. diff --git a/integration_tests/base_routes.py b/integration_tests/base_routes.py index 21c3abe5c..a8745ea37 100644 --- a/integration_tests/base_routes.py +++ b/integration_tests/base_routes.py @@ -9,7 +9,7 @@ from typing import TypedDict from integration_tests.subroutes import async_auth_subrouter, di_subrouter, inherited_auth_subrouter, static_router, sub_router -from robyn import Headers, Request, Response, Robyn, SSEMessage, SSEResponse, WebSocketDisconnect, jsonify, serve_file, serve_html +from robyn import Headers, Request, Response, Robyn, SSEMessage, SSEResponse, StreamingResponse, WebSocketDisconnect, jsonify, serve_file, serve_html from robyn.authentication import AuthenticationHandler, BearerGetter, Identity from robyn.robyn import QueryParams, Url from robyn.templating import JinjaTemplate @@ -1559,6 +1559,37 @@ async def async_event_generator(): return SSEResponse(async_event_generator()) +@app.get("/stream/bytes") +def stream_bytes(request): + """Stream raw binary chunks (sync generator) — regression test for #1236.""" + + def gen(): + for i in range(3): + yield bytes([i]) * 4 # 4 bytes per chunk + + return 
StreamingResponse( + gen(), + media_type="application/octet-stream", + headers=Headers({"Content-Type": "application/octet-stream"}), + ) + + +@app.get("/stream/bytes_async") +async def stream_bytes_async(request): + """Stream raw binary chunks from an async generator (#1236 + #1219).""" + + async def gen(): + for i in range(3): + await asyncio.sleep(0) # exercise the async driver + yield bytes([i]) * 4 + + return StreamingResponse( + gen(), + media_type="application/octet-stream", + headers=Headers({"Content-Type": "application/octet-stream"}), + ) + + @app.get("/sse/streaming_sync") def sse_streaming_sync(request): """SSE endpoint to test real-time sync streaming""" diff --git a/integration_tests/test_binary_streaming.py b/integration_tests/test_binary_streaming.py new file mode 100644 index 000000000..50a01a4a1 --- /dev/null +++ b/integration_tests/test_binary_streaming.py @@ -0,0 +1,21 @@ +import requests + +BASE_URL = "http://127.0.0.1:8080" +TIMEOUT = 5 + +EXPECTED = bytes([0]) * 4 + bytes([1]) * 4 + bytes([2]) * 4 + + +def test_stream_bytes_sync(session): + """A sync generator yielding bytes streams binary data unchanged (#1236).""" + r = requests.get(f"{BASE_URL}/stream/bytes", timeout=TIMEOUT) + assert r.status_code == 200 + assert r.headers.get("Content-Type") == "application/octet-stream" + assert r.content == EXPECTED + + +def test_stream_bytes_async(session): + """An async generator yielding bytes also streams correctly (#1236 + #1219).""" + r = requests.get(f"{BASE_URL}/stream/bytes_async", timeout=TIMEOUT) + assert r.status_code == 200 + assert r.content == EXPECTED diff --git a/robyn/responses.py b/robyn/responses.py index 68cd99589..f46148d51 100644 --- a/robyn/responses.py +++ b/robyn/responses.py @@ -1,7 +1,9 @@ import asyncio import mimetypes import os -from typing import AsyncGenerator, Generator +import threading +import weakref +from typing import AsyncGenerator, Generator, Optional, Union from robyn.robyn import Headers, Response @@ -63,13 
class AsyncGeneratorWrapper:
    """Drive an async generator through Robyn's synchronous streaming protocol.

    The wrapper is a plain (synchronous) iterator: every ``__next__`` call
    schedules exactly one ``__anext__`` step of the wrapped async generator on
    an event loop and blocks the calling thread until that step produces a
    chunk.  Chunks (``str`` or ``bytes``) are returned unchanged.

    Which loop drives the generator is decided once, at construction time:

    * Constructed while an event loop is running (an ``async`` handler): that
      running loop is reused, so async resources created by the handler (DB
      sessions, HTTP clients) stay on the loop they are bound to (#1219).
    * Constructed with no running loop (a sync handler): a dedicated loop is
      started on a daemon thread and stopped again when the stream finishes or
      the wrapper is garbage-collected.

    ``__next__`` must be called from a thread other than the one running the
    driving loop; calling it on the loop's own thread could never complete, so
    it raises ``RuntimeError`` instead of blocking forever.

    Errors raised by the generator are propagated to the caller rather than
    swallowed, so a failing stream surfaces its real traceback.
    """

    def __init__(self, async_gen: AsyncGenerator[Union[str, bytes], None]):
        """Bind *async_gen* to the loop that will drive it.

        Args:
            async_gen: The async generator whose chunks are to be streamed.
        """
        self._async_gen = async_gen
        self._iterator: Optional[AsyncGenerator] = None
        self._exhausted = False
        self._owns_loop = False
        self._thread: Optional[threading.Thread] = None
        self._finalizer: Optional[weakref.finalize] = None
        try:
            # Constructed inside an async handler -> reuse its running loop.
            self._loop = asyncio.get_running_loop()
        except RuntimeError:
            # Constructed in a sync handler -> drive on a dedicated background
            # loop. The thread target and the finalizer are staticmethods so
            # neither captures ``self``; otherwise the running thread would
            # keep the wrapper alive and the finalizer could never fire.
            self._loop = asyncio.new_event_loop()
            self._owns_loop = True
            self._thread = threading.Thread(target=self._run_loop, args=(self._loop,), daemon=True)
            self._thread.start()
            # Stop the background loop even if iteration ends early (client
            # disconnect, unsupported chunk type) and _finish() is never
            # reached, so the daemon thread does not leak.
            self._finalizer = weakref.finalize(self, self._stop_loop, self._loop)

    @property
    def async_gen(self):
        """Read-only alias of the wrapped generator (its former public name)."""
        return self._async_gen

    @staticmethod
    def _run_loop(loop):
        """Thread target: run *loop* until stopped, then shut it down cleanly."""
        asyncio.set_event_loop(loop)
        try:
            loop.run_forever()
        finally:
            try:
                # Give async generators that are still open (stream abandoned
                # mid-way) a chance to run their cleanup before the loop is
                # closed, mirroring what ``asyncio.run`` does on exit.
                loop.run_until_complete(loop.shutdown_asyncgens())
            finally:
                loop.close()

    @staticmethod
    def _stop_loop(loop):
        """Ask the background *loop* to stop; safe to call more than once."""
        if loop.is_closed():
            return
        try:
            loop.call_soon_threadsafe(loop.stop)
        except RuntimeError:
            # The loop thread closed the loop between the check above and the
            # call; there is nothing left to stop.
            pass

    def _called_on_driving_loop(self) -> bool:
        """Return True if the current thread is running the driving loop."""
        try:
            return asyncio.get_running_loop() is self._loop
        except RuntimeError:
            return False

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next chunk, blocking until the generator yields it.

        Raises:
            StopIteration: The generator is exhausted (or already failed).
            RuntimeError: Called on the thread that runs the driving loop.
            BaseException: Whatever the generator itself raised.
        """
        if self._exhausted:
            raise StopIteration

        if self._called_on_driving_loop():
            # Blocking here would freeze the very loop that has to produce the
            # value, i.e. a guaranteed deadlock. Fail loudly instead.
            raise RuntimeError(
                "AsyncGeneratorWrapper.__next__ was called on the event loop thread that drives "
                "the generator; iterate it from a worker thread instead"
            )

        if self._iterator is None:
            self._iterator = self._async_gen.__aiter__()

        # Schedule one step on the driving loop and block until it completes.
        # run_coroutine_threadsafe works whether or not we own the loop.
        future = asyncio.run_coroutine_threadsafe(self._iterator.__anext__(), self._loop)
        try:
            return future.result()
        except StopAsyncIteration:
            self._finish()
            # Translate to the sync protocol; the async traceback adds nothing.
            raise StopIteration from None
        except BaseException:
            # Surface real errors instead of silently ending the stream.
            self._finish()
            raise

    def _finish(self):
        """Mark the stream as done and stop the background loop if we own one."""
        self._exhausted = True
        if self._finalizer is not None:
            # weakref.finalize objects run at most once, so this is idempotent
            # and also covers the GC path when the stream is dropped early.
            self._finalizer()
+ if let Ok(py_bytes) = value.downcast::() { + Some((py_bytes.as_bytes().to_vec(), generator)) + } else if let Ok(s) = value.extract::() { + Some((s.into_bytes(), generator)) + } else { + log::error!( + "StreamingResponse generator yielded a value that is neither str nor bytes; ending stream" + ); + None + } + } Err(e) => { if !e.is_instance_of::(py) { log::error!("Generator error: {}", e); @@ -141,7 +154,7 @@ fn create_python_stream( }) .await { - Ok(Some((string_value, generator))) => Some((Ok(Bytes::from(string_value)), generator)), + Ok(Some((bytes_value, generator))) => Some((Ok(Bytes::from(bytes_value)), generator)), _ => None, } })) diff --git a/unit_tests/test_streaming_response.py b/unit_tests/test_streaming_response.py new file mode 100644 index 000000000..02317b99a --- /dev/null +++ b/unit_tests/test_streaming_response.py @@ -0,0 +1,95 @@ +import asyncio + +import pytest + +from robyn.responses import AsyncGeneratorWrapper + + +def test_wrapper_drives_generator_on_constructing_loop(): + """The async generator must run on the loop that was active at construction + (the handler's loop), so async resources bound to it work (#1219).""" + captured = {} + + async def gen(): + captured["loop"] = asyncio.get_running_loop() + yield "a" + yield "b" + + async def main(): + wrapper = AsyncGeneratorWrapper(gen()) # constructed on THIS loop + # Robyn drives __next__ from a worker thread, not the loop thread. 
+ chunks = await asyncio.to_thread(lambda: list(wrapper)) + return chunks, asyncio.get_running_loop() + + chunks, handler_loop = asyncio.run(main()) + assert chunks == ["a", "b"] + assert captured["loop"] is handler_loop + + +def test_wrapper_propagates_generator_errors(): + """Errors inside the generator are raised, not silently swallowed.""" + + async def gen(): + yield "ok" + raise ValueError("boom") + + async def main(): + wrapper = AsyncGeneratorWrapper(gen()) + + def drive(): + collected = [] + with pytest.raises(ValueError, match="boom"): + for chunk in wrapper: + collected.append(chunk) + return collected + + return await asyncio.to_thread(drive) + + assert asyncio.run(main()) == ["ok"] + + +def test_wrapper_without_running_loop_uses_background_loop(): + """When constructed outside an async context (sync handler), the wrapper + runs the generator on its own background loop.""" + + async def gen(): + yield "x" + yield "y" + + wrapper = AsyncGeneratorWrapper(gen()) # no running loop here + assert wrapper._owns_loop is True + assert list(wrapper) == ["x", "y"] + + +def test_wrapper_supports_bytes_chunks(): + """The wrapper passes bytes chunks through unchanged (Rust forwards bytes as-is; only str chunks are UTF-8 encoded).""" + + async def gen(): + yield b"\x00\x01" + yield b"\x02" + + wrapper = AsyncGeneratorWrapper(gen()) + assert list(wrapper) == [b"\x00\x01", b"\x02"] + + +def test_owned_loop_thread_is_cleaned_up_when_dropped_early(): + """The background loop thread must not leak if the stream is abandoned + before exhaustion (e.g. a client disconnect).""" + import gc + + async def gen(): + yield "a" + yield "b" + yield "c" + + wrapper = AsyncGeneratorWrapper(gen()) # sync context -> owns a background loop + assert wrapper._owns_loop is True + thread = wrapper._thread + assert thread.is_alive() + + assert next(wrapper) == "a" # consume one chunk, then abandon the rest + del wrapper + gc.collect() + + thread.join(timeout=3) + assert not thread.is_alive()