16 changes: 11 additions & 5 deletions petastorm/reader.py
@@ -35,7 +35,7 @@
from petastorm.transform import transform_schema
from petastorm.workers_pool.dummy_pool import DummyPool
from petastorm.workers_pool.process_pool import ProcessPool
from petastorm.workers_pool.thread_pool import ThreadPool
from petastorm.workers_pool.thread_pool import ThreadPool, OrderedThreadPool
from petastorm.workers_pool.ventilator import ConcurrentVentilator

logger = logging.getLogger(__name__)
@@ -88,8 +88,9 @@ def make_reader(dataset_url,
or ``[file:///tmp/mydataset/00000.parquet, file:///tmp/mydataset/00001.parquet]``.
:param schema_fields: Can be: a list of unischema fields and/or regex pattern strings; ``None`` to read all fields;
an NGram object, then it will return an NGram of the specified fields.
:param reader_pool_type: A string denoting the reader pool type. Should be one of ['thread', 'process', 'dummy']
denoting a thread pool, process pool, or running everything in the master thread. Defaults to 'thread'
:param reader_pool_type: A string denoting the reader pool type. Should be one of ['thread', 'orderedthread',
'process', 'dummy'], denoting a thread pool, a thread pool that returns dataset pieces in their original
(ventilation) order, a process pool, or running everything in the master thread. Defaults to 'thread'
:param workers_count: An int for the number of workers to use in the reader pool. This is only used for the
thread or process pool. Defaults to 10
:param pyarrow_serialize: THE ARGUMENT IS DEPRECATED AND WILL BE REMOVED IN FUTURE VERSIONS.
@@ -160,6 +161,8 @@ def make_reader(dataset_url,

    if reader_pool_type == 'thread':
        reader_pool = ThreadPool(workers_count, results_queue_size)
    elif reader_pool_type == 'orderedthread':
        reader_pool = OrderedThreadPool(workers_count, results_queue_size)
    elif reader_pool_type == 'process':
        if pyarrow_serialize:
            warnings.warn("pyarrow_serializer was deprecated and will be removed in future versions. "
@@ -240,8 +243,9 @@ def make_batch_reader(dataset_url_or_urls,
or ``[file:///tmp/mydataset/00000.parquet, file:///tmp/mydataset/00001.parquet]``.
:param schema_fields: A list of regex pattern strings. Only columns matching at least one of the
patterns in the list will be loaded.
:param reader_pool_type: A string denoting the reader pool type. Should be one of ['thread', 'process', 'dummy']
denoting a thread pool, process pool, or running everything in the master thread. Defaults to 'thread'
:param reader_pool_type: A string denoting the reader pool type. Should be one of ['thread', 'orderedthread',
'process', 'dummy'], denoting a thread pool, a thread pool that returns dataset pieces in their original
(ventilation) order, a process pool, or running everything in the master thread. Defaults to 'thread'
:param workers_count: An int for the number of workers to use in the reader pool. This is only used for the
thread or process pool. Defaults to 10
:param results_queue_size: Size of the results queue to store prefetched row-groups. Currently only applicable to
@@ -316,6 +320,8 @@ def make_batch_reader(dataset_url_or_urls,

    if reader_pool_type == 'thread':
        reader_pool = ThreadPool(workers_count, results_queue_size)
    elif reader_pool_type == 'orderedthread':
        reader_pool = OrderedThreadPool(workers_count, results_queue_size)
    elif reader_pool_type == 'process':
        serializer = ArrowTableSerializer()
        reader_pool = ProcessPool(workers_count, serializer, zmq_copy_buffers=zmq_copy_buffers)
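For orientation (not part of the diff): a minimal sketch of selecting the new pool type through the public API. The dataset URL and the 'id' column below are hypothetical placeholders, and shuffle_row_groups=False is assumed so that piece order is meaningful.

from petastorm import make_batch_reader

# Hypothetical dataset location and column; adjust to a real Parquet store.
DATASET_URL = 'file:///tmp/mydataset'

# 'orderedthread' selects the new OrderedThreadPool, so batches come back in the
# order the row-group pieces were ventilated rather than in completion order.
with make_batch_reader(DATASET_URL,
                       reader_pool_type='orderedthread',
                       shuffle_row_groups=False) as reader:
    for batch in reader:
        print(batch.id[:5])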
27 changes: 26 additions & 1 deletion petastorm/tests/test_parquet_reader.py
@@ -33,7 +33,9 @@
    lambda url, **kwargs: make_batch_reader(url, reader_pool_type='thread', **kwargs),
    lambda url, **kwargs: make_batch_reader(url, reader_pool_type='process', **kwargs),
]

_OT = [
    lambda url, **kwargs: make_batch_reader(url, reader_pool_type='orderedthread', **kwargs),
]

def _check_simple_reader(reader, expected_data):
    # Read a bunch of entries from the dataset and compare the data to reference
@@ -54,6 +56,22 @@ def _check_simple_reader(reader, expected_data):

    assert count == len(expected_data)

def _check_reader_order(reader, expected_data):
    # Read all entries from the dataset and verify that the 'id' values arrive in ascending dataset order
    count = 0
    idx = 0
    for row in reader:
        actual = row._asdict()

        # Each batch should continue the sequence exactly where the previous one ended
        for id_value in actual['id']:
            assert id_value == idx
            idx += 1

        count += len(actual['id'])

    assert count == len(expected_data)


def _get_bad_field_name(field_list):
    """ Grab first name from list of valid fields, append random characters to it to get an invalid
@@ -71,6 +89,13 @@ def test_simple_read(scalar_dataset, reader_factory):
        _check_simple_reader(reader, scalar_dataset.data)


@pytest.mark.parametrize('reader_factory', _OT)
def test_simple_read_ordered(scalar_dataset, reader_factory):
    """Read all values with the ordered reader pool and verify they come back in dataset order"""
    with reader_factory(scalar_dataset.url, shuffle_row_groups=False) as reader:
        _check_reader_order(reader, scalar_dataset.data)


@pytest.mark.parametrize('reader_factory', _D)
def test_simple_read_from_a_single_file(scalar_dataset, reader_factory):
    """See if we can read data when a single parquet file is specified instead of a parquet directory"""
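To run just the new ordered-pool test locally, a standard pytest -k selection works (nothing petastorm-specific is assumed here):

pytest petastorm/tests/test_parquet_reader.py -k test_simple_read_ordered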
5 changes: 5 additions & 0 deletions petastorm/workers_pool/__init__.py
@@ -24,3 +24,8 @@ class TimeoutWaitingForResultError(RuntimeError):

class VentilatedItemProcessedMessage(object):
    """Object to signal that a worker has completed processing an item from the ventilation queue"""

class OrderedVentilatedItemProcessedMessage(VentilatedItemProcessedMessage):
    """Ventilated signal object which contains ordering metadata."""
    def __init__(self, idx):
        self.idx = idx
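As a quick illustration (a sketch, not part of the diff): after this change the ordered message is a drop-in subclass of the existing signal that additionally carries the index of the processed piece.

from petastorm.workers_pool import (VentilatedItemProcessedMessage,
                                    OrderedVentilatedItemProcessedMessage)

# The base message only signals that one ventilated item was processed; the
# ordered variant also records which piece it was, so the pool can restore order.
msg = OrderedVentilatedItemProcessedMessage(idx=7)
assert msg.idx == 7
assert isinstance(msg, VentilatedItemProcessedMessage)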
128 changes: 126 additions & 2 deletions petastorm/workers_pool/thread_pool.py
@@ -19,9 +19,9 @@
from threading import Thread, Event
from traceback import format_exc

from six.moves import queue
from six.moves import queue # type: ignore

from petastorm.workers_pool import EmptyResultError, VentilatedItemProcessedMessage
from petastorm.workers_pool import (EmptyResultError, VentilatedItemProcessedMessage,
                                    OrderedVentilatedItemProcessedMessage)

# Defines how frequently will we check the stop event while waiting on a blocking queue
IO_TIMEOUT_INTERVAL_S = 0.001
@@ -74,6 +74,33 @@ def run(self):
        if self._profiling_enabled:
            self.prof.disable()

class OrderedWorkerThread(WorkerThread):
    """Worker thread that tags each processed item with its piece index so the pool can restore ordering."""
    def run(self):
        if self._profiling_enabled:
            self.prof.enable()
        # Loop and accept messages from both channels, acting accordingly
        while True:
            # Check for stop event first to prevent erroneous reuse
            if self._stop_event.is_set():
                break
            # If the message came from work_receiver channel
            try:
                (args, kargs) = self._ventilator_queue.get(block=True, timeout=IO_TIMEOUT_INTERVAL_S)
                self._worker_impl.process(*args, **kargs)
                self._worker_impl.publish_func(OrderedVentilatedItemProcessedMessage(kargs['piece_index']))
            except queue.Empty:
                pass
            except WorkerTerminationRequested:
                pass
            except Exception as e:  # pylint: disable=broad-except
                stderr_message = 'Worker %d terminated: unexpected exception:\n' % self._worker_impl.worker_id
                stderr_message += format_exc()
                sys.stderr.write(stderr_message)
                self._results_queue.put(e)
                break
        if self._profiling_enabled:
            self.prof.disable()


class ThreadPool(object):
    def __init__(self, workers_count, results_queue_size=50, profiling_enabled=False):
@@ -219,3 +246,100 @@ def results_qsize(self):
    @property
    def diagnostics(self):
        return {'output_queue_size': self.results_qsize()}

class OrderedThreadPool(ThreadPool):
    """A ThreadPool that returns results in the order the dataset pieces were ventilated."""

    def start(self, worker_class, worker_args=None, ventilator=None):
        """Starts worker threads.

        :param worker_class: The class of the worker. It will be instantiated by the pool and must implement
            the :class:`.WorkerBase` protocol.
        :param worker_args: Argument that will be passed to the ``args`` property of the instantiated
            :class:`.WorkerBase`.
        :param ventilator: An optional ventilator. Its ventilation order defines the order in which results are
            returned by :meth:`get_results`.
        :return: ``None``
        """
        # Verify stop_event and raise exception if it's already set!
        if self._stop_event.is_set():
            raise RuntimeError(
                "ThreadPool({}) cannot be reused! stop_event set? {}".format(
                    len(self._workers), self._stop_event.is_set()
                )
            )

        # Set up a channel to send work
        self._ventilator_queue = queue.Queue()
        self._results_queue = queue.Queue(self._results_queue_size)
        self._workers = []
        for worker_id in range(self.workers_count):
            worker_impl = worker_class(worker_id, self._stop_aware_put, worker_args)
            new_thread = OrderedWorkerThread(
                worker_impl,
                self._stop_event,
                self._ventilator_queue,
                self._results_queue,
                self._profiling_enabled,
            )
            # Make the thread daemonic. Since it only reads it's ok to abort while running - no resource corruption
            # will occur.
            new_thread.daemon = True
            self._workers.append(new_thread)

        # Spin up all worker threads
        for w in self._workers:
            w.start()

        # Buffers for results that arrive out of order, and the order in which pieces should be returned
        self._unordered_results_buffer = []
        self._unordered_idx_buffer = []
        self._result_order = []

        if ventilator:
            self._ventilator = ventilator
            # Capture the ventilation order before starting the ventilator so results can be re-ordered later
            self._result_order = [el['piece_index'] for el in ventilator._items_to_ventilate]
            self._ventilator.start()

    def get_results(self):
        """Returns the next result in ventilation order, or re-raises a worker's exception if one was raised
        in a worker thread.

        Blocks until the next in-order result is available. If no more results are anticipated, raises
        :class:`.EmptyResultError`.

        :return: arguments passed to ``publish_func(...)`` by a worker.
        """
        while True:
            # If there is no more work to do, raise an EmptyResultError
            if (
                self._results_queue.empty()
                and self._ventilated_items == self._ventilated_items_processed
            ):
                # We also need to check if we are using a ventilator and if it is completed, and
                # whether we have emptied the unordered results buffer
                if (not self._ventilator or self._ventilator.completed()) and not len(
                    self._unordered_idx_buffer
                ):
                    raise EmptyResultError()

            # If the next rowgroup is in the unordered buffer, return it
            if self._result_order and self._result_order[0] in self._unordered_idx_buffer:
                idx = self._unordered_idx_buffer.index(self._result_order.pop(0))
                self._unordered_idx_buffer.pop(idx)
                return self._unordered_results_buffer.pop(idx)

            try:
                result = self._results_queue.get(
                    timeout=_VERIFY_END_OF_VENTILATION_PERIOD
                )
                if isinstance(result, VentilatedItemProcessedMessage):
                    self._ventilated_items_processed += 1
                    if self._ventilator:
                        self._ventilator.processed_item()
                    self._unordered_idx_buffer.append(result.idx)
                    continue
                elif isinstance(result, Exception):
                    self.stop()
                    self.join()
                    raise result
                else:
                    self._unordered_results_buffer.append(result)
                    continue
            except queue.Empty:
                continue
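The reordering in get_results() above is easier to see in isolation. A standalone sketch of the same buffering idea, with the queues and pool plumbing replaced by a plain list of (piece_index, result) pairs (names here are illustrative, not petastorm API):

def reorder(completed, expected_order):
    """Yield results in expected_order even if they complete out of order."""
    pending = {}  # piece_index -> result, received but not yet released
    expected = list(expected_order)
    for idx, result in completed:
        pending[idx] = result
        # Release as many buffered results as possible in the expected order.
        while expected and expected[0] in pending:
            yield pending.pop(expected.pop(0))


# Pieces finish as 2, 0, 1 but are yielded as 0, 1, 2.
assert list(reorder([(2, 'c'), (0, 'a'), (1, 'b')], [0, 1, 2])) == ['a', 'b', 'c']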