diff --git a/petastorm/reader.py b/petastorm/reader.py
index 8fa69935..7aa0fbb8 100644
--- a/petastorm/reader.py
+++ b/petastorm/reader.py
@@ -35,7 +35,7 @@
 from petastorm.transform import transform_schema
 from petastorm.workers_pool.dummy_pool import DummyPool
 from petastorm.workers_pool.process_pool import ProcessPool
-from petastorm.workers_pool.thread_pool import ThreadPool
+from petastorm.workers_pool.thread_pool import ThreadPool, OrderedThreadPool
 from petastorm.workers_pool.ventilator import ConcurrentVentilator
 
 logger = logging.getLogger(__name__)
@@ -88,8 +88,9 @@ def make_reader(dataset_url,
         or ``[file:///tmp/mydataset/00000.parquet, file:///tmp/mydataset/00001.parquet]``.
     :param schema_fields: Can be: a list of unischema fields and/or regex pattern strings; ``None`` to read all
         fields; an NGram object, then it will return an NGram of the specified fields.
-    :param reader_pool_type: A string denoting the reader pool type. Should be one of ['thread', 'process', 'dummy']
-        denoting a thread pool, process pool, or running everything in the master thread. Defaults to 'thread'
+    :param reader_pool_type: A string denoting the reader pool type. Should be one of ['thread', 'orderedthread',
+        'process', 'dummy'], denoting a thread pool, a thread pool that returns dataset pieces in order, a process
+        pool, or running everything in the master thread. Defaults to 'thread'.
     :param workers_count: An int for the number of workers to use in the reader pool. This only is used for the
         thread or process pool. Defaults to 10
     :param pyarrow_serialize: THE ARGUMENT IS DEPRECATED AND WILL BE REMOVED IN FUTURE VERSIONS.
@@ -160,6 +161,8 @@ def make_reader(dataset_url,
 
     if reader_pool_type == 'thread':
         reader_pool = ThreadPool(workers_count, results_queue_size)
+    elif reader_pool_type == 'orderedthread':
+        reader_pool = OrderedThreadPool(workers_count, results_queue_size)
     elif reader_pool_type == 'process':
         if pyarrow_serialize:
             warnings.warn("pyarrow_serializer was deprecated and will be removed in future versions. "
@@ -240,8 +243,9 @@ def make_batch_reader(dataset_url_or_urls,
         or ``[file:///tmp/mydataset/00000.parquet, file:///tmp/mydataset/00001.parquet]``.
     :param schema_fields: A list of regex pattern strings. Only columns matching at least one of the patterns in the
         list will be loaded.
-    :param reader_pool_type: A string denoting the reader pool type. Should be one of ['thread', 'process', 'dummy']
-        denoting a thread pool, process pool, or running everything in the master thread. Defaults to 'thread'
+    :param reader_pool_type: A string denoting the reader pool type. Should be one of ['thread', 'orderedthread',
+        'process', 'dummy'], denoting a thread pool, a thread pool that returns dataset pieces in order, a process
+        pool, or running everything in the master thread. Defaults to 'thread'.
     :param workers_count: An int for the number of workers to use in the reader pool. This only is used for the
         thread or process pool. Defaults to 10
     :param results_queue_size: Size of the results queue to store prefetched row-groups. Currently only applicable to
@@ -316,6 +320,8 @@ def make_batch_reader(dataset_url_or_urls,
 
     if reader_pool_type == 'thread':
         reader_pool = ThreadPool(workers_count, results_queue_size)
+    elif reader_pool_type == 'orderedthread':
+        reader_pool = OrderedThreadPool(workers_count, results_queue_size)
     elif reader_pool_type == 'process':
         serializer = ArrowTableSerializer()
         reader_pool = ProcessPool(workers_count, serializer, zmq_copy_buffers=zmq_copy_buffers)
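For orientation (not part of the diff itself): with this change a caller opts into deterministic piece ordering by selecting the new pool type. A minimal usage sketch, assuming a placeholder dataset URL; keeping `shuffle_row_groups=False` is what makes the ventilation order, and therefore the output order, deterministic, exactly as the new test below does:

```python
from petastorm import make_batch_reader

# 'file:///tmp/mydataset' is a placeholder URL, not a dataset from this PR.
with make_batch_reader('file:///tmp/mydataset',
                       reader_pool_type='orderedthread',  # pool type added by this PR
                       shuffle_row_groups=False) as reader:  # keep ventilation order deterministic
    for batch in reader:
        pass  # batches arrive in dataset piece order
```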
diff --git a/petastorm/tests/test_parquet_reader.py b/petastorm/tests/test_parquet_reader.py
index 62708e1b..11b7695e 100644
--- a/petastorm/tests/test_parquet_reader.py
+++ b/petastorm/tests/test_parquet_reader.py
@@ -33,7 +33,9 @@
     lambda url, **kwargs: make_batch_reader(url, reader_pool_type='thread', **kwargs),
     lambda url, **kwargs: make_batch_reader(url, reader_pool_type='process', **kwargs),
 ]
-
+_OT = [
+    lambda url, **kwargs: make_batch_reader(url, reader_pool_type='orderedthread', **kwargs),
+]
 
 def _check_simple_reader(reader, expected_data):
     # Read a bunch of entries from the dataset and compare the data to reference
@@ -54,6 +56,22 @@ def _check_simple_reader(reader, expected_data):
     assert count == len(expected_data)
 
 
+def _check_reader_order(reader, expected_data):
+    # Read all entries from the dataset and verify that the values arrive in dataset order
+    count = 0
+    idx = 0
+    for row in reader:
+        actual = row._asdict()
+
+        # Ids are expected to arrive as a consecutive, zero-based sequence
+        for id_value in actual['id']:
+            assert id_value == idx
+            idx += 1
+
+        count += len(actual['id'])
+
+    assert count == len(expected_data)
+
+
 def _get_bad_field_name(field_list):
     """
     Grab first name from list of valid fields, append random characters to it to get an invalid
@@ -71,6 +89,13 @@ def test_simple_read(scalar_dataset, reader_factory):
         _check_simple_reader(reader, scalar_dataset.data)
 
 
+@pytest.mark.parametrize('reader_factory', _OT)
+def test_simple_read_ordered(scalar_dataset, reader_factory):
+    """Read all values and check that they arrive in dataset order when using the ordered reader pool"""
+    with reader_factory(scalar_dataset.url, shuffle_row_groups=False) as reader:
+        _check_reader_order(reader, scalar_dataset.data)
+
+
 @pytest.mark.parametrize('reader_factory', _D)
 def test_simple_read_from_a_single_file(scalar_dataset, reader_factory):
     """See if we can read data when a single parquet file is specified instead of a parquet directory"""
diff --git a/petastorm/workers_pool/__init__.py b/petastorm/workers_pool/__init__.py
index 47243dd6..5e5b4f4f 100644
--- a/petastorm/workers_pool/__init__.py
+++ b/petastorm/workers_pool/__init__.py
@@ -24,3 +24,8 @@ class TimeoutWaitingForResultError(RuntimeError):
 
 class VentilatedItemProcessedMessage(object):
     """Object to signal that a worker has completed processing an item from the ventilation queue"""
+
+
+class OrderedVentilatedItemProcessedMessage(VentilatedItemProcessedMessage):
+    """Processed-item signal that also carries the index of the ventilated piece it corresponds to."""
+    def __init__(self, idx):
+        self.idx = idx
\ No newline at end of file
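A side note on the subclassing choice above: because `OrderedVentilatedItemProcessedMessage` extends `VentilatedItemProcessedMessage`, the `isinstance` checks that the pools perform on the results queue keep recognizing it as a processed-item signal, while the ordered pool can additionally read the piece index back. A tiny illustration, assuming this PR is applied:

```python
from petastorm.workers_pool import (VentilatedItemProcessedMessage,
                                    OrderedVentilatedItemProcessedMessage)

msg = OrderedVentilatedItemProcessedMessage(7)
assert isinstance(msg, VentilatedItemProcessedMessage)  # base-class checks still fire
assert msg.idx == 7  # ordering metadata carried alongside the signal
```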
diff --git a/petastorm/workers_pool/thread_pool.py b/petastorm/workers_pool/thread_pool.py
index 649aa77f..27bf0259 100644
--- a/petastorm/workers_pool/thread_pool.py
+++ b/petastorm/workers_pool/thread_pool.py
@@ -19,9 +19,9 @@
 from threading import Thread, Event
 from traceback import format_exc
 
-from six.moves import queue
+from six.moves import queue  # type: ignore
 
-from petastorm.workers_pool import EmptyResultError, VentilatedItemProcessedMessage
+from petastorm.workers_pool import EmptyResultError, VentilatedItemProcessedMessage, OrderedVentilatedItemProcessedMessage
 
 # Defines how frequently will we check the stop event while waiting on a blocking queue
 IO_TIMEOUT_INTERVAL_S = 0.001
@@ -74,6 +74,33 @@ def run(self):
             if self._profiling_enabled:
                 self.prof.disable()
 
 
+class OrderedWorkerThread(WorkerThread):
+    def run(self):
+        if self._profiling_enabled:
+            self.prof.enable()
+        # Loop and accept messages from both channels, acting accordingly
+        while True:
+            # Check for stop event first to prevent erroneous reuse
+            if self._stop_event.is_set():
+                break
+            # Process the next ventilated item and publish a completion message tagged with its piece index
+            try:
+                (args, kargs) = self._ventilator_queue.get(block=True, timeout=IO_TIMEOUT_INTERVAL_S)
+                self._worker_impl.process(*args, **kargs)
+                self._worker_impl.publish_func(OrderedVentilatedItemProcessedMessage(kargs['piece_index']))
+            except queue.Empty:
+                pass
+            except WorkerTerminationRequested:
+                pass
+            except Exception as e:  # pylint: disable=broad-except
+                stderr_message = 'Worker %d terminated: unexpected exception:\n' % self._worker_impl.worker_id
+                stderr_message += format_exc()
+                sys.stderr.write(stderr_message)
+                self._results_queue.put(e)
+                break
+        if self._profiling_enabled:
+            self.prof.disable()
+
 
 class ThreadPool(object):
     def __init__(self, workers_count, results_queue_size=50, profiling_enabled=False):
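The only behavioral change relative to `WorkerThread.run` is the tagged completion message, which assumes every ventilated item carries a `piece_index` keyword argument. That holds for the items `reader.py` ventilates (the pool below reads the same key from `ventilator._items_to_ventilate`); roughly, each ventilated item is a dict of the following shape, with illustrative values that are not taken from this PR:

```python
# Illustrative shape of a ventilated item; 'piece_index' is the key that
# OrderedWorkerThread echoes back in its OrderedVentilatedItemProcessedMessage.
item = {'piece_index': 3,
        'worker_predicate': None,
        'shuffle_row_drop_partition': (0, 1)}
```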
+ """ + + while True: + # If there is no more work to do, raise an EmptyResultError + if ( + self._results_queue.empty() + and self._ventilated_items == self._ventilated_items_processed + ): + # We also need to check if we are using a ventilator and if it is completed, and + # whether we have emptied the unordered results buffer + if (not self._ventilator or self._ventilator.completed()) and not len( + self._unordered_idx_buffer + ): + raise EmptyResultError() + + # If the next rowgroup is in the unordered buffer, return it + if self._result_order[0] in self._unordered_idx_buffer: + idx = self._unordered_idx_buffer.index(self._result_order.pop(0)) + _ = self._unordered_idx_buffer.pop(idx) + return self._unordered_results_buffer.pop(idx) + + try: + result = self._results_queue.get( + timeout=_VERIFY_END_OF_VENTILATION_PERIOD + ) + if isinstance(result, VentilatedItemProcessedMessage): + self._ventilated_items_processed += 1 + if self._ventilator: + self._ventilator.processed_item() + self._unordered_idx_buffer.append(result.idx) + continue + elif isinstance(result, Exception): + self.stop() + self.join() + raise result + else: + self._unordered_results_buffer.append(result) + continue + except queue.Empty: + continue