diff --git a/sdks/python/apache_beam/transforms/async_dofn.py b/sdks/python/apache_beam/transforms/async_dofn.py index 5e1c6d219f4b..d64046cddfb5 100644 --- a/sdks/python/apache_beam/transforms/async_dofn.py +++ b/sdks/python/apache_beam/transforms/async_dofn.py @@ -17,15 +17,20 @@ from __future__ import absolute_import +import asyncio import logging import random +import threading import uuid +from collections.abc import AsyncIterable +from collections.abc import Iterable from concurrent.futures import ThreadPoolExecutor from math import floor from threading import RLock from time import sleep from time import time from types import GeneratorType +from typing import Optional import apache_beam as beam from apache_beam import TimeDomain @@ -60,6 +65,9 @@ class AsyncWrapper(beam.DoFn): [coders.FastPrimitivesCoder(), coders.FastPrimitivesCoder()])) # The below items are one per dofn (not instance) so are maps of UUID to # value. + _event_loop: Optional[asyncio.AbstractEventLoop] = None + _event_loop_thread: Optional[threading.Thread] = None + _loop_started = threading.Event() _processing_elements = {} _items_in_buffer = {} _pool = {} @@ -78,6 +86,7 @@ def __init__( timeout=1, max_wait_time=0.5, id_fn=None, + use_asyncio=False, ): """Wraps the sync_fn to create an asynchronous version. @@ -104,6 +113,8 @@ def __init__( schedule an item. Used in testing to ensure timeouts are met. id_fn: A function that returns a hashable object from an element. This will be used to track items instead of the element's default hash. + use_asyncio: If true, use asyncio and coroutines to process items. If + false, use ThreadPoolExecutor. """ self._sync_fn = sync_fn self._uuid = uuid.uuid4().hex @@ -112,6 +123,7 @@ def __init__( self._max_wait_time = max_wait_time self._timer_frequency = callback_frequency self._id_fn = id_fn or (lambda x: x) + self._use_asyncio = use_asyncio if max_items_to_buffer is None: self._max_items_to_buffer = max(parallelism * 2, 10) else: @@ -126,8 +138,28 @@ def __init__( def initialize_pool(parallelism): return lambda: ThreadPoolExecutor(max_workers=parallelism) + @staticmethod + def _run_event_loop(): + """Sets up and runs the asyncio event loop in a background thread.""" + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + AsyncWrapper._event_loop = loop + AsyncWrapper._loop_started.set() + loop.run_forever() + loop.close() + @staticmethod def reset_state(): + if AsyncWrapper._event_loop: + AsyncWrapper._event_loop.call_soon_threadsafe( + AsyncWrapper._event_loop.stop) + if AsyncWrapper._event_loop_thread: + AsyncWrapper._event_loop_thread.join() + + AsyncWrapper._event_loop = None + AsyncWrapper._event_loop_thread = None + AsyncWrapper._loop_started.clear() + for pool in AsyncWrapper._pool.values(): pool.acquire(AsyncWrapper.initialize_pool(1)).shutdown( wait=True, cancel_futures=True) @@ -140,6 +172,12 @@ def setup(self): """Forwards to the wrapped dofn's setup method.""" self._sync_fn.setup() with AsyncWrapper._lock: + if self._use_asyncio and AsyncWrapper._event_loop_thread is None: + AsyncWrapper._event_loop_thread = threading.Thread( + target=AsyncWrapper._run_event_loop, daemon=True) + AsyncWrapper._event_loop_thread.start() + AsyncWrapper._loop_started.wait() + if not self._uuid in AsyncWrapper._pool: AsyncWrapper._pool[self._uuid] = Shared() AsyncWrapper._processing_elements[self._uuid] = {} @@ -190,6 +228,52 @@ def sync_fn_process(self, element, *args, **kwargs): return to_return + async def async_fn_process(self, element, *args, **kwargs): + """Makes the call to the wrapped dofn's start_bundle, process + and finish_bundle methods for asynchronous DoFns. + + Args: + element: The element to process. + *args: Any additional arguments to pass to the wrapped dofn's process + method. + **kwargs: Any additional keyword arguments to pass to the wrapped dofn's + process method. + + Returns: + A list of elements produced by the input element. + """ + self._sync_fn.start_bundle() + process_result = self._sync_fn.process(element, *args, **kwargs) + bundle_result = self._sync_fn.finish_bundle() + + if not process_result: + process_result = [] + elif isinstance(process_result, AsyncIterable): + temp = [] + async for item in process_result: + temp.append(item) + process_result = temp + elif not isinstance(process_result, (GeneratorType, Iterable)): + process_result = [process_result] + + if not bundle_result: + bundle_result = [] + elif isinstance(bundle_result, AsyncIterable): + temp = [] + async for item in bundle_result: + temp.append(item) + bundle_result = temp + elif not isinstance(bundle_result, (GeneratorType, Iterable)): + bundle_result = [bundle_result] + + to_return = [] + for x in process_result: + to_return.append(x) + for x in bundle_result: + to_return.append(x) + + return to_return + def decrement_items_in_buffer(self, future): with AsyncWrapper._lock: AsyncWrapper._items_in_buffer[self._uuid] -= 1 @@ -214,10 +298,16 @@ def schedule_if_room(self, element, ignore_buffer=False, *args, **kwargs): logging.info('item %s already in processing elements', element) return True if self.accepting_items() or ignore_buffer: - result = AsyncWrapper._pool[self._uuid].acquire( - AsyncWrapper.initialize_pool(self._parallelism)).submit( - lambda: self.sync_fn_process(element, *args, **kwargs), - ) + if self._use_asyncio: + result = asyncio.run_coroutine_threadsafe( + self.async_fn_process(element, *args, **kwargs), + AsyncWrapper._event_loop, + ) + else: + result = AsyncWrapper._pool[self._uuid].acquire( + AsyncWrapper.initialize_pool(self._parallelism)).submit( + lambda: self.sync_fn_process(element, *args, **kwargs), + ) result.add_done_callback(self.decrement_items_in_buffer) AsyncWrapper._processing_elements[self._uuid][element_id] = ( element, result) diff --git a/sdks/python/apache_beam/transforms/async_dofn_test.py b/sdks/python/apache_beam/transforms/async_dofn_test.py index fe75de05ccd5..c956ff26cbb8 100644 --- a/sdks/python/apache_beam/transforms/async_dofn_test.py +++ b/sdks/python/apache_beam/transforms/async_dofn_test.py @@ -86,7 +86,9 @@ def set(self, time): self.time = time -class AsyncTest(unittest.TestCase): +class _AsyncTestBase: + use_asyncio: bool + def setUp(self): super().setUp() async_lib.AsyncWrapper.reset_state() @@ -132,7 +134,8 @@ def __eq__(self, other): return self.element_id == other.element_id dofn = BasicDofn() - async_dofn = async_lib.AsyncWrapper(dofn, id_fn=lambda x: x.element_id) + async_dofn = async_lib.AsyncWrapper( + dofn, id_fn=lambda x: x.element_id, use_asyncio=self.use_asyncio) async_dofn.setup() fake_bag_state = FakeBagState([]) fake_timer = FakeTimer(0) @@ -156,7 +159,7 @@ def __eq__(self, other): def test_basic(self): # Setup an async dofn and send a message in to process. dofn = BasicDofn() - async_dofn = async_lib.AsyncWrapper(dofn) + async_dofn = async_lib.AsyncWrapper(dofn, use_asyncio=self.use_asyncio) async_dofn.setup() fake_bag_state = FakeBagState([]) fake_timer = FakeTimer(0) @@ -183,7 +186,7 @@ def test_basic(self): def test_multi_key(self): # Send in two messages with different keys.. dofn = BasicDofn() - async_dofn = async_lib.AsyncWrapper(dofn) + async_dofn = async_lib.AsyncWrapper(dofn, use_asyncio=self.use_asyncio) async_dofn.setup() fake_bag_state_key1 = FakeBagState([]) fake_bag_state_key2 = FakeBagState([]) @@ -211,7 +214,7 @@ def test_multi_key(self): def test_long_item(self): # Test that everything still works with a long running time for the dofn. dofn = BasicDofn(sleep_time=5) - async_dofn = async_lib.AsyncWrapper(dofn) + async_dofn = async_lib.AsyncWrapper(dofn, use_asyncio=self.use_asyncio) async_dofn.setup() fake_bag_state = FakeBagState([]) fake_timer = FakeTimer(0) @@ -234,7 +237,7 @@ def test_lost_item(self): # Setup an element in the bag stat thats not in processing state. # The async dofn should reschedule this element. dofn = BasicDofn() - async_dofn = async_lib.AsyncWrapper(dofn) + async_dofn = async_lib.AsyncWrapper(dofn, use_asyncio=self.use_asyncio) async_dofn.setup() fake_timer = FakeTimer(0) msg = ('key1', 1) @@ -252,7 +255,7 @@ def test_cancelled_item(self): # it is not present in the bag state. Either this item moved or a commit # failed making the local state and bag stat inconsistent. dofn = BasicDofn() - async_dofn = async_lib.AsyncWrapper(dofn) + async_dofn = async_lib.AsyncWrapper(dofn, use_asyncio=self.use_asyncio) async_dofn.setup() msg = ('key1', 1) msg2 = ('key1', 2) @@ -272,7 +275,7 @@ def test_multi_element_dofn(self): # Test that async works when a dofn produces multiple elements in process # and finish_bundle. dofn = MultiElementDoFn() - async_dofn = async_lib.AsyncWrapper(dofn) + async_dofn = async_lib.AsyncWrapper(dofn, use_asyncio=self.use_asyncio) async_dofn.setup() fake_bag_state = FakeBagState([]) fake_timer = FakeTimer(0) @@ -289,7 +292,7 @@ def test_duplicates(self): # Test that async will produce a single output when a given input is sent # multiple times. dofn = BasicDofn(5) - async_dofn = async_lib.AsyncWrapper(dofn) + async_dofn = async_lib.AsyncWrapper(dofn, use_asyncio=self.use_asyncio) async_dofn.setup() fake_bag_state = FakeBagState([]) fake_timer = FakeTimer(0) @@ -310,7 +313,7 @@ def test_slow_duplicates(self): # Test that async will produce a single output when a given input is sent # multiple times. dofn = BasicDofn(5) - async_dofn = async_lib.AsyncWrapper(dofn) + async_dofn = async_lib.AsyncWrapper(dofn, use_asyncio=self.use_asyncio) async_dofn.setup() fake_bag_state = FakeBagState([]) fake_timer = FakeTimer(0) @@ -335,7 +338,7 @@ def test_slow_duplicates(self): def test_buffer_count(self): # Test that the buffer count is correctly incremented when adding items. dofn = BasicDofn(5) - async_dofn = async_lib.AsyncWrapper(dofn) + async_dofn = async_lib.AsyncWrapper(dofn, use_asyncio=self.use_asyncio) async_dofn.setup() msg = ('key1', 1) fake_timer = FakeTimer(0) @@ -353,7 +356,10 @@ def test_buffer_stops_accepting_items(self): # Test that the buffer stops accepting items when it is full. dofn = BasicDofn(5) async_dofn = async_lib.AsyncWrapper( - dofn, parallelism=1, max_items_to_buffer=5) + dofn, + parallelism=1, + max_items_to_buffer=5, + use_asyncio=self.use_asyncio) async_dofn.setup() fake_timer = FakeTimer(0) fake_bag_state = FakeBagState([]) @@ -391,7 +397,7 @@ def add_item(i): def test_buffer_with_cancellation(self): dofn = BasicDofn(3) - async_dofn = async_lib.AsyncWrapper(dofn) + async_dofn = async_lib.AsyncWrapper(dofn, use_asyncio=self.use_asyncio) async_dofn.setup() msg = ('key1', 1) msg2 = ('key1', 2) @@ -423,7 +429,8 @@ def test_load_correctness(self): # Test AsyncDofn over heavy load. dofn = BasicDofn(1) max_sleep = 10 - async_dofn = async_lib.AsyncWrapper(dofn, max_wait_time=max_sleep) + async_dofn = async_lib.AsyncWrapper( + dofn, max_wait_time=max_sleep, use_asyncio=self.use_asyncio) async_dofn.setup() bag_states = {} timers = {} @@ -473,5 +480,13 @@ def add_item(i): self.assertEqual(bag_states['key' + str(i)].items, []) +class AsyncTestThreadPool(_AsyncTestBase, unittest.TestCase): + use_asyncio = False + + +class AsyncTestAsyncio(_AsyncTestBase, unittest.TestCase): + use_asyncio = True + + if __name__ == '__main__': unittest.main()