diff --git a/petastorm/arrow_reader_worker.py b/petastorm/arrow_reader_worker.py index fedad01d..b7a3f090 100644 --- a/petastorm/arrow_reader_worker.py +++ b/petastorm/arrow_reader_worker.py @@ -15,6 +15,7 @@ import hashlib import operator +import logging import numpy as np import pandas as pd @@ -26,6 +27,9 @@ from petastorm.workers_pool import EmptyResultError from petastorm.workers_pool.worker_base import WorkerBase +# Initialize logger +logger = logging.getLogger(__name__) + class ArrowReaderWorkerResultsQueueReader(object): def __init__(self): @@ -91,6 +95,9 @@ class ArrowReaderWorker(WorkerBase): def __init__(self, worker_id, publish_func, args): super(ArrowReaderWorker, self).__init__(worker_id, publish_func, args) + # Add debug log in the constructor + print(f'DEBUG: Initializing ArrowReaderWorker with worker_id: {worker_id}') + self._filesystem = args[0] self._dataset_path_or_paths = args[1] self._schema = args[2] @@ -101,7 +108,10 @@ def __init__(self, worker_id, publish_func, args): self._transformed_schema = args[7] self._arrow_filters = args[8] self._shuffle_rows = args[9] - self._random_state = np.random.RandomState(seed=args[10]) + self._random_seed = args[10] + + # Initialize random number generator + self._rng = np.random.default_rng(self._random_seed) if self._ngram: raise NotImplementedError('ngrams are not supported by ArrowReaderWorker') @@ -128,12 +138,18 @@ def process(self, piece_index, worker_predicate, shuffle_row_drop_partition): :return: """ + # Add debug log in the process method + print(f'DEBUG: Processing piece_index: {piece_index}') + if not self._dataset: self._dataset = pq.ParquetDataset( self._dataset_path_or_paths, filesystem=self._filesystem, validate_schema=False, filters=self._arrow_filters) + # Add debug log after dataset is initialized + print(f'DEBUG: ParquetDataset initialized with path: {self._dataset_path_or_paths}') + piece = self._split_pieces[piece_index] # Create pyarrow file system @@ -160,11 +176,16 @@ def 
process(self, piece_index, worker_predicate, shuffle_row_drop_partition): path_str = self._dataset_path_or_paths cache_key = '{}:{}:{}'.format(hashlib.md5(path_str.encode('utf-8')).hexdigest(), piece.path, piece_index) + + # Add debug log for cache key + print(f'DEBUG: Cache key generated: {cache_key}') + all_cols = self._local_cache.get(cache_key, lambda: self._load_rows(parquet_file, piece, shuffle_row_drop_partition)) if all_cols: self.publish_func(all_cols) + print(f'DEBUG: Published columns for piece_index: {piece_index}') @staticmethod def _check_shape_and_ravel(x, field): @@ -289,9 +310,19 @@ def _read_with_shuffle_row_drop(self, piece, pq_file, column_names, shuffle_row_ # pyarrow would fail if we request a column names that the dataset is partitioned by table = piece.read(columns=column_names - partition_names, partitions=self._dataset.partitions) + + # Handle row shuffling based on shuffle_rows setting if self._shuffle_rows: - indices = self._random_state.permutation(table.num_rows) - table = table.take(indices) + if self._random_seed is not None and self._random_seed != 0: + # Deterministic randomization: use provided seed + indices = self._rng.permutation(table.num_rows) + else: + # Non-deterministic randomization: use np.random directly + indices = np.random.permutation(table.num_rows) + else: + # Deterministic natural order: shuffle_rows=False + indices = np.arange(table.num_rows) + table = table.take(indices) # Drop columns we did not explicitly request. This may happen when a table is partitioned. Besides columns # requested, pyarrow will also return partition values. 
Having these unexpected fields will break some diff --git a/petastorm/reader.py b/petastorm/reader.py index 8fa69935..3b10625b 100644 --- a/petastorm/reader.py +++ b/petastorm/reader.py @@ -38,6 +38,7 @@ from petastorm.workers_pool.thread_pool import ThreadPool from petastorm.workers_pool.ventilator import ConcurrentVentilator +# Initialize logger logger = logging.getLogger(__name__) # Ventilator guarantees that no more than workers + _VENTILATE_EXTRA_ROWGROUPS are processed at a moment by a @@ -159,7 +160,7 @@ def make_reader(dataset_url, 'To read from a non-Petastorm Parquet store use make_batch_reader') if reader_pool_type == 'thread': - reader_pool = ThreadPool(workers_count, results_queue_size) + reader_pool = ThreadPool(workers_count, results_queue_size, shuffle_rows=shuffle_rows, seed=seed) elif reader_pool_type == 'process': if pyarrow_serialize: warnings.warn("pyarrow_serializer was deprecated and will be removed in future versions. " @@ -315,7 +316,7 @@ def make_batch_reader(dataset_url_or_urls, raise ValueError('Unknown cache_type: {}'.format(cache_type)) if reader_pool_type == 'thread': - reader_pool = ThreadPool(workers_count, results_queue_size) + reader_pool = ThreadPool(workers_count, results_queue_size, shuffle_rows=shuffle_rows, seed=seed) elif reader_pool_type == 'process': serializer = ArrowTableSerializer() reader_pool = ProcessPool(workers_count, serializer, zmq_copy_buffers=zmq_copy_buffers) @@ -400,6 +401,7 @@ def __init__(self, pyarrow_filesystem, dataset_path, schema_fields=None, These will be applied when loading the parquet file with PyArrow. More information here: https://arrow.apache.org/docs/python/generated/pyarrow.parquet.ParquetDataset.html """ + print(f'DEBUG: Initializing Reader with dataset_path: {dataset_path}, num_epochs: {num_epochs}') self.num_epochs = num_epochs # 1. 
Open the parquet storage (dataset) @@ -437,9 +439,11 @@ def __init__(self, pyarrow_filesystem, dataset_path, schema_fields=None, raise NotImplementedError('Using timestamp_overlap=False is not implemented with' ' shuffle_options.shuffle_row_drop_partitions > 1') + print(f'DEBUG: Reader initialized with schema_fields: {schema_fields}') + cache = cache or NullCache() - self._workers_pool = reader_pool or ThreadPool(10) + self._workers_pool = reader_pool or ThreadPool(10, shuffle_rows=shuffle_rows, seed=seed) # Make a schema view (a view is a Unischema containing only a subset of fields # Will raise an exception if invalid schema fields are in schema_fields @@ -483,7 +487,7 @@ def __init__(self, pyarrow_filesystem, dataset_path, schema_fields=None, self.ngram, row_groups, cache, transform_spec, self.schema, filters, shuffle_rows, seed), ventilator=self.ventilator) - logger.debug('Workers pool started') + print('DEBUG: Workers pool started') self.last_row_consumed = False self.stopped = False @@ -653,6 +657,7 @@ def _normalize_shuffle_options(shuffle_row_drop_partitions, dataset): def _create_ventilator(self, row_group_indexes, shuffle_row_groups, shuffle_row_drop_partitions, num_epochs, worker_predicate, max_ventilation_queue_size, seed): + print(f'DEBUG: Creating ventilator with row_group_indexes: {row_group_indexes}') items_to_ventilate = [] for piece_index in row_group_indexes: for shuffle_row_drop_partition in range(shuffle_row_drop_partitions): @@ -670,12 +675,12 @@ def _create_ventilator(self, row_group_indexes, shuffle_row_groups, shuffle_row_ random_seed=seed) def stop(self): - """Stops all worker threads/processes.""" + print('DEBUG: Stopping Reader') self._workers_pool.stop() self.stopped = True def join(self): - """Joins all worker threads/processes. 
Will block until all worker workers have been fully terminated.""" + print('DEBUG: Joining Reader') self._workers_pool.join() @property diff --git a/petastorm/tests/test_tf_dataset.py b/petastorm/tests/test_tf_dataset.py index 17e812a9..6ae74809 100644 --- a/petastorm/tests/test_tf_dataset.py +++ b/petastorm/tests/test_tf_dataset.py @@ -128,6 +128,7 @@ def test_with_dataset_repeat(synthetic_dataset, reader_factory): def test_with_dataset_repeat_after_cache(synthetic_dataset, reader_factory): """ Check if ``tf.data.Dataset``'s ``repeat`` works after ``tf.data.Dataset``'s ``cache``.""" epochs = 3 + print(f"Starting test_with_dataset_repeat_after_cache with {epochs} epochs") with reader_factory(synthetic_dataset.url, schema_fields=[TestSchema.id]) as reader: dataset = make_petastorm_dataset(reader) dataset = dataset.cache() @@ -138,18 +139,22 @@ def test_with_dataset_repeat_after_cache(synthetic_dataset, reader_factory): with tf.Session() as sess: with pytest.warns(None): # Expect no warnings since cache() is called before repeat() - for _ in range(epochs): + for epoch in range(epochs): + print(f"Starting epoch {epoch}") actual_res = [] - for _, _ in enumerate(synthetic_dataset.data): + for i, _ in enumerate(synthetic_dataset.data): actual = sess.run(it_op)._asdict() actual_res.append(actual["id"]) + print(f"iteration: {i} {actual['id']}") expected_res = list(range(len(synthetic_dataset.data))) + print(f"Epoch: {epoch} actual {sorted(actual_res)}, expected {expected_res}") # sort dataset output since row_groups are shuffled from reader. np.testing.assert_equal(sorted(actual_res), expected_res) - + print(f"Completed epoch {epoch}") # Exhausted all epochs. 
Fetching next value should trigger OutOfRangeError with pytest.raises(tf.errors.OutOfRangeError): sess.run(it_op) + print("Completed test_with_dataset_repeat_after_cache") @pytest.mark.forked diff --git a/petastorm/workers_pool/tests/test_workers_pool.py b/petastorm/workers_pool/tests/test_workers_pool.py index 142c608a..a3029476 100644 --- a/petastorm/workers_pool/tests/test_workers_pool.py +++ b/petastorm/workers_pool/tests/test_workers_pool.py @@ -141,15 +141,17 @@ def test_stop_when_result_queue_is_full(self): SLEEP_DELTA = 0.01 TIMEOUT = 20 QUEUE_SIZE = 2 + WORKERS_COUNT = 10 - pool = ThreadPool(10, results_queue_size=QUEUE_SIZE) + pool = ThreadPool(WORKERS_COUNT, results_queue_size=QUEUE_SIZE) pool.start(WorkerIdGeneratingWorker) - for _ in range(100): + for _ in range(1000): pool.ventilate() + expected_queue_size = WORKERS_COUNT * max(5, QUEUE_SIZE // WORKERS_COUNT) cumulative_wait = 0 - while pool.results_qsize() != QUEUE_SIZE: + while pool.results_qsize() != expected_queue_size: time.sleep(SLEEP_DELTA) cumulative_wait += SLEEP_DELTA # Make sure we wait no longer than the timeout. Otherwise, something is very wrong diff --git a/petastorm/workers_pool/thread_pool.py b/petastorm/workers_pool/thread_pool.py index 649aa77f..7ee1a8b9 100644 --- a/petastorm/workers_pool/thread_pool.py +++ b/petastorm/workers_pool/thread_pool.py @@ -22,12 +22,15 @@ from six.moves import queue from petastorm.workers_pool import EmptyResultError, VentilatedItemProcessedMessage +import logging # Defines how frequently will we check the stop event while waiting on a blocking queue IO_TIMEOUT_INTERVAL_S = 0.001 # Amount of time we will wait on a the queue to get the next result. 
If no results received until then, we will # recheck if no more items are expected to be ventilated -_VERIFY_END_OF_VENTILATION_PERIOD = 0.1 +_VERIFY_END_OF_VENTILATION_PERIOD = 1 + +logger = logging.getLogger(__name__) class WorkerTerminationRequested(Exception): @@ -76,7 +79,7 @@ def run(self): class ThreadPool(object): - def __init__(self, workers_count, results_queue_size=50, profiling_enabled=False): + def __init__(self, workers_count, results_queue_size=50, shuffle_rows=False, seed=None, profiling_enabled=False): """Initializes a thread pool. TODO: consider using a standard thread pool @@ -88,9 +91,13 @@ def __init__(self, workers_count, results_queue_size=50, profiling_enabled=False :param workers_count: Number of threads :param profile: Whether to run a profiler on the threads """ + print(f'DEBUG: Initializing ThreadPool with workers_count: {workers_count}') self._seed = random.randint(0, 100000) + self._shuffle_rows = shuffle_rows + self._seed = seed self._workers = [] - self._ventilator_queue = None + self._ventilator_queues = [] + self.workers_count = workers_count self._results_queue_size = results_queue_size # Worker threads will watch this event and gracefully shutdown when the event is set @@ -98,9 +105,14 @@ def __init__(self, workers_count, results_queue_size=50, profiling_enabled=False self._profiling_enabled = profiling_enabled self._ventilated_items = 0 - self._ventilated_items_processed = 0 + # Count of items ventilated by each worker + self._ventilated_items_by_worker = [0 for _ in range(self.workers_count)] + # Count of items processed by each worker + self._ventilated_items_processed_by_worker = [0 for _ in range(self.workers_count)] self._ventilator = None + self._get_results_worker_id = 0 + def start(self, worker_class, worker_args=None, ventilator=None): """Starts worker threads. 
@@ -110,19 +122,30 @@ class must implement :class:`.WorkerBase` protocol :class:`.WorkerBase` :return: ``None`` """ + print(f'DEBUG: Starting ThreadPool with worker_class: {worker_class}') # Verify stop_event and raise exception if it's already set! if self._stop_event.is_set(): raise RuntimeError('ThreadPool({}) cannot be reused! stop_event set? {}' .format(len(self._workers), self._stop_event.is_set())) - # Set up a channel to send work - self._ventilator_queue = queue.Queue() - self._results_queue = queue.Queue(self._results_queue_size) + # Set up a channel for each worker to send work + self._ventilator_queues = [queue.Queue() for _ in range(self.workers_count)] + + # Set up a channel for each worker to send results + self._results_queues = [ + queue.Queue(max(5, self._results_queue_size // self.workers_count)) + for _ in range(self.workers_count) + ] + self._workers = [] for worker_id in range(self.workers_count): - worker_impl = worker_class(worker_id, self._stop_aware_put, worker_args) - new_thread = WorkerThread(worker_impl, self._stop_event, self._ventilator_queue, - self._results_queue, self._profiling_enabled) + # Create a closure that captures the worker_id for this specific worker + def make_publish_func(worker_id): + return lambda data: self._stop_aware_put(worker_id, data) + + worker_impl = worker_class(worker_id, make_publish_func(worker_id), worker_args) + new_thread = WorkerThread(worker_impl, self._stop_event, self._ventilator_queues[worker_id], + self._results_queues[worker_id], self._profiling_enabled) # Make the thread daemonic. Since it only reads it's ok to abort while running - no resource corruption # will occur. new_thread.daemon = True @@ -139,8 +162,24 @@ class must implement :class:`.WorkerBase` protocol def ventilate(self, *args, **kargs): """Sends a work item to a worker process. Will result in ``worker.process(...)`` call with arbitrary arguments. 
""" + print(f'DEBUG: Ventilating work item with args: {args}, kargs: {kargs}') + # Distribute work items in a round-robin manner across each worker ventilator queue + current_worker_id = self._ventilated_items % self.workers_count self._ventilated_items += 1 - self._ventilator_queue.put((args, kargs)) + self._ventilated_items_by_worker[current_worker_id] += 1 + self._ventilator_queues[current_worker_id].put((args, kargs)) + + def current_worker_done(self, worker_id): + # Check if the current worker has processed all the items it was assigned and if the results queue is empty + return (self._ventilated_items_processed_by_worker[worker_id] == self._ventilated_items_by_worker[worker_id] + and self._results_queues[worker_id].empty()) + + def all_workers_done(self): + # Check if all workers have processed all the items they were assigned and if the results queues are empty + for i in range(self.workers_count): + if not self.current_worker_done(i): + return False + return True def get_results(self): """Returns results from worker pool or re-raise worker's exception if any happen in worker thread. @@ -151,20 +190,35 @@ def get_results(self): :return: arguments passed to ``publish_func(...)`` by a worker. If no more results are anticipated, :class:`.EmptyResultError`. 
""" - + # If shuffle_rows is enabled and the seed is not set, we can use a non-blocking get, + # because strict round-robin ordering of results does not matter in that case + use_non_blocking_get = self._shuffle_rows and (self._seed is None or self._seed == 0) while True: # If there is no more work to do, raise an EmptyResultError - if self._results_queue.empty() and self._ventilated_items == self._ventilated_items_processed: + if self.all_workers_done(): # We also need to check if we are using a ventilator and if it is completed if not self._ventilator or self._ventilator.completed(): raise EmptyResultError() + # If the current worker is done, we need to get the result from the next worker + if self.current_worker_done(self._get_results_worker_id): + self._get_results_worker_id = (self._get_results_worker_id + 1) % self.workers_count + continue + try: - result = self._results_queue.get(timeout=_VERIFY_END_OF_VENTILATION_PERIOD) + # Get the result from the current worker's results queue. + # Use blocking/strict round robin if shuffle_rows is disabled or the seed is set + result = self._results_queues[self._get_results_worker_id].get( + block=not use_non_blocking_get, timeout=_VERIFY_END_OF_VENTILATION_PERIOD) + print(f'DEBUG: Result from worker {self._get_results_worker_id}: {result}') + # If the result is a VentilatedItemProcessedMessage, we need to increment the count of items + # processed by the current worker if isinstance(result, VentilatedItemProcessedMessage): - self._ventilated_items_processed += 1 + self._ventilated_items_processed_by_worker[self._get_results_worker_id] += 1 if self._ventilator: self._ventilator.processed_item() + # Move to the next worker + self._get_results_worker_id = (self._get_results_worker_id + 1) % self.workers_count continue elif isinstance(result, Exception): self.stop() @@ -197,7 +251,7 @@ def join(self): stats = pstats.Stats(w.prof) stats.sort_stats('cumulative').print_stats() - def _stop_aware_put(self, data): + def _stop_aware_put(self, worker_id, 
data): """This method is called to write the results to the results queue. We use ``put`` in a non-blocking way so we can gracefully terminate the worker thread without being stuck on :func:`Queue.put`. @@ -205,7 +259,7 @@ def _stop_aware_put(self, data): :func:`WorkerThread.run` which will gracefully terminate main worker loop.""" while True: try: - self._results_queue.put(data, block=True, timeout=IO_TIMEOUT_INTERVAL_S) + self._results_queues[worker_id].put(data, block=True, timeout=IO_TIMEOUT_INTERVAL_S) return except queue.Full: pass @@ -214,7 +268,7 @@ def _stop_aware_put(self, data): raise WorkerTerminationRequested() def results_qsize(self): - return self._results_queue.qsize() + return sum(queue.qsize() for queue in self._results_queues) @property def diagnostics(self): diff --git a/petastorm/workers_pool/ventilator.py b/petastorm/workers_pool/ventilator.py index 0f26bec1..d542b586 100644 --- a/petastorm/workers_pool/ventilator.py +++ b/petastorm/workers_pool/ventilator.py @@ -18,6 +18,10 @@ from time import sleep import six +import logging + +# Initialize logger +logger = logging.getLogger(__name__) _VENTILATION_INTERVAL = 0.01 @@ -98,7 +102,8 @@ def __init__(self, self._items_to_ventilate = items_to_ventilate self._iterations_remaining = iterations self._randomize_item_order = randomize_item_order - self._random_state = np.random.RandomState(seed=random_seed) + self._random_seed = random_seed + self._rng = np.random.default_rng(self._random_seed) self._iterations = iterations # For the default max ventilation queue size we will use the size of the items to ventilate @@ -136,15 +141,22 @@ def reset(self): self.start() def _ventilate(self): + # Randomize the item order before starting the ventilation if randomize_item_order is set + print(f'DEBUG: Items to ventilate before shuffle: {self._items_to_ventilate}') + if self._randomize_item_order: + if self._random_seed is not None and self._random_seed != 0: + # Deterministic randomization: use provided seed + 
self._items_to_ventilate = list(self._rng.permutation(self._items_to_ventilate)) + else: + # Non-deterministic randomization: use np.random + self._items_to_ventilate = list(np.random.permutation(self._items_to_ventilate)) + print(f'DEBUG: Items to ventilate after shuffle: {self._items_to_ventilate}') + while True: # Stop condition is when no iterations are remaining or there are no items to ventilate if self.completed(): break - # If we are ventilating the first item, we check if we would like to randomize the item order - if self._current_item_to_ventilate == 0 and self._randomize_item_order: - self._random_state.shuffle(self._items_to_ventilate) - # Block until queue has room, but use continue to allow for checking if stop has been called if self._ventilated_items_count - self._processed_items_count >= self._max_ventilation_queue_size: sleep(self._ventilation_interval)