From 57d7df064b0d7c388c41d41f17bac58e81e2c4b0 Mon Sep 17 00:00:00 2001 From: Ihor Indyk Date: Wed, 26 Feb 2025 13:22:31 -0800 Subject: [PATCH] Allow providing worker initialization function in `mp_prefetch`. PiperOrigin-RevId: 731433157 --- grain/_src/python/dataset/dataset.py | 5 +++ .../dataset/transformations/prefetch.py | 13 ++++-- .../dataset/transformations/prefetch_test.py | 40 +++++++++++++++++-- grain/_src/python/grain_pool.py | 34 +++++++++++----- grain/_src/python/grain_pool_test.py | 22 ++++++++++ 5 files changed, 98 insertions(+), 16 deletions(-) diff --git a/grain/_src/python/dataset/dataset.py b/grain/_src/python/dataset/dataset.py index d5f0fb06..ad2dac36 100644 --- a/grain/_src/python/dataset/dataset.py +++ b/grain/_src/python/dataset/dataset.py @@ -1082,6 +1082,7 @@ def prefetch( def mp_prefetch( self, options: grain_options.MultiprocessingOptions | None = None, + worker_init_fn: Callable[[int, int], None] | None = None, ) -> IterDataset[T]: """Returns a dataset prefetching elements in multiple processes. @@ -1101,6 +1102,9 @@ def mp_prefetch( be greater than or equal to 0. If `options.num_workers` is 0, `mp_prefetch` has no effect. Defaults to `MultiprocessingOptions(num_workers=10)`. + worker_init_fn: A function that is called in each worker process before + the data is processed. The function takes two arguments: the current + worker index and the total worker count. Returns: A dataset prefetching input elements in separate processes. @@ -1113,6 +1117,7 @@ def mp_prefetch( return prefetch.MultiprocessPrefetchIterDataset( self, multiprocessing_options=options, + worker_init_fn=worker_init_fn, ) @abc.abstractmethod diff --git a/grain/_src/python/dataset/transformations/prefetch.py b/grain/_src/python/dataset/transformations/prefetch.py index a060cc40..e1702b32 100644 --- a/grain/_src/python/dataset/transformations/prefetch.py +++ b/grain/_src/python/dataset/transformations/prefetch.py @@ -222,6 +222,7 @@ def __init__( self, parent: dataset.IterDataset[T], multiprocessing_options: grain_options.MultiprocessingOptions, + worker_init_fn: Callable[[int, int], None] | None = None, ): if multiprocessing_options.num_workers < 0: raise ValueError( @@ -230,6 +231,7 @@ def __init__( ) super().__init__(parent) self._multiprocessing_options = multiprocessing_options + self._worker_init_fn = worker_init_fn self._validate_parent_dataset() def __str__(self) -> str: @@ -253,8 +255,8 @@ def _validate_parent_dataset(self) -> None: def __iter__(self) -> dataset.DatasetIterator[T]: if self._multiprocessing_options.num_workers == 0: return self._parent.__iter__() - return MultiprocessPrefetchDatasetIterator( - self._parent, self._multiprocessing_options + return _MultiprocessPrefetchDatasetIterator( + self._parent, self._multiprocessing_options, self._worker_init_fn ) @@ -370,7 +372,7 @@ def _check_picklable( class GetElementProducerFn(grain_pool.GetElementProducerFn, Generic[T]): """Implements `GetElementProducerFn` for `grain_pool.MultiProcessIterator`. - This class implements `GetElementProducerFn` with `serialize` being overriden + This class implements `GetElementProducerFn` with `serialize` being overridden to generate better error messages if user-provided dataset is not pickle-able. 
""" @@ -438,13 +440,14 @@ def _get_dataset_options(ds: dataset.IterDataset) -> base.DatasetOptions: return result -class MultiprocessPrefetchDatasetIterator(dataset.DatasetIterator[T]): +class _MultiprocessPrefetchDatasetIterator(dataset.DatasetIterator[T]): """Iterator that performs prefetching using a multiprocessing pool.""" def __init__( self, parent: dataset.IterDataset[T], multiprocessing_options: grain_options.MultiprocessingOptions, + worker_init_fn: Callable[[int, int], None] | None = None, ): super().__init__() self._iter_parent = parent @@ -453,6 +456,7 @@ def __init__( # propagate them. self._ctx.dataset_options = _get_dataset_options(parent) self._multiprocessing_options = multiprocessing_options + self._worker_init_fn = worker_init_fn # The underlying iterator producing elements and workers state. self._iterator = None # Raw reference to the underlying iterator that can be used to determine the @@ -557,6 +561,7 @@ def _create_iterator_context(self) -> grain_pool.MultiProcessIterator[T]: self._multiprocessing_options, (self._state[_LAST_WORKER_INDEX] + 1) % self._multiprocessing_options.num_workers, + self._worker_init_fn, ) def __str__(self) -> str: diff --git a/grain/_src/python/dataset/transformations/prefetch_test.py b/grain/_src/python/dataset/transformations/prefetch_test.py index f079bea7..e7c62c71 100644 --- a/grain/_src/python/dataset/transformations/prefetch_test.py +++ b/grain/_src/python/dataset/transformations/prefetch_test.py @@ -13,11 +13,11 @@ # limitations under the License. from concurrent import futures import dataclasses +import logging as std_logging import sys import time from typing import TypeVar, cast from unittest import mock - from absl import logging from absl.testing import absltest from absl.testing import parameterized @@ -530,8 +530,7 @@ def map(self, features): options.MultiprocessingOptions(num_workers, per_worker_buffer_size), ) - it = iter(ds) - assert isinstance(it, prefetch.MultiprocessPrefetchDatasetIterator) + it = ds.__iter__() for _ in range(start_prefetch_calls): it.start_prefetch() @@ -611,6 +610,41 @@ def test_options_after_prefetch(self): with self.assertRaises(Exception): list(ds) + def test_worker_init_fn(self): + def set_worker_index_and_count(worker_index: int, worker_count: int): + log_formatter = std_logging.Formatter( + f'[Worker {worker_index} out of {worker_count}] %(message)s' + ) + logging.get_absl_handler().setFormatter(log_formatter) + + def map_fn(x): + # absl logging from workers is not propagated to the main process in unit + # tests. Therefore, we manually pass the formatted log message. + record = logging.get_absl_logger().makeRecord( + 'grain', + logging.INFO, + 'grain_pool_test', + 123, + f'processing element {x}', + (), + None, + ) + return logging.get_absl_handler().format(record) + + ds = dataset.MapDataset.range(2).map(map_fn) + ds = ds.to_iter_dataset() + ds = ds.mp_prefetch( + options.MultiprocessingOptions(num_workers=2), + worker_init_fn=set_worker_index_and_count, + ) + self.assertEqual( + list(ds), + [ + '[Worker 0 out of 2] processing element 0', + '[Worker 1 out of 2] processing element 1', + ], + ) + class ThreadPrefetchIterDatasetTest(parameterized.TestCase): diff --git a/grain/_src/python/grain_pool.py b/grain/_src/python/grain_pool.py index 237ae9d9..7f4f5c45 100644 --- a/grain/_src/python/grain_pool.py +++ b/grain/_src/python/grain_pool.py @@ -14,7 +14,7 @@ """This module provides a way to distribute processing across multiple workers. 
In the context of Grain we use the term "process" similar to JAX, where usually -each machine runs one Python process (identified by `jax.proccess_index()`). +each machine runs one Python process (identified by `jax.process_index()`). In Grain each "process" can create additional Python child processes that we call "workers". @@ -32,8 +32,8 @@ * Child processes are launched as Daemon processes. In case of (unexpected) parent termination, child processes will be terminated by OS. * System uses a multiprocessing event ("termination_event") for termination. - Parent and child processes continously check if the "termination_event" and if - set, they break from what they are doing. + Parent and child processes continuously check if the "termination_event" and + if set, they break from what they are doing. * We never block indefinitely when calling get() or put() on a queue. This ensures parent and child processes continue to check the termination_event. @@ -176,15 +176,18 @@ def _initialize_and_get_element_producer( """Unpickles the element producer from the args queue and closes the queue.""" ( serialized_flag_parse_fn, - serialized_init_fn, + serialized_init_fns, serialized_element_producer_fn, ) = args_queue.get() flag_parse_fn: Callable[[Any], None] = cloudpickle.loads( serialized_flag_parse_fn ) flag_parse_fn(debug_flags) - init_fn: Callable[[], None] = cloudpickle.loads(serialized_init_fn) - init_fn() + init_fns: list[Callable[[int, int], None]] = cloudpickle.loads( + serialized_init_fns + ) + for init_fn in init_fns: + init_fn(worker_index, worker_count) element_producer_fn: GetElementProducerFn[Any] = ( GetElementProducerFn.deserialize(serialized_element_producer_fn) ) @@ -307,6 +310,7 @@ def __init__( worker_index_to_start_reading: int = 0, termination_event: threading.Event | None = None, options: MultiprocessingOptions, + worker_init_fn: Callable[[int, int], None] | None = None, ): """Initialise a Grain Pool. @@ -320,6 +324,9 @@ def __init__( the pool will terminate when either one of the workers failed or when all workers are done processing data. GrainPool will not set this event. options: Options for multiprocessing. See MultiprocessingOptions. + worker_init_fn: Function to run in each worker process before the element + producer. The function takes two arguments: the current worker index and + the total worker count. """ self.num_processes = options.num_workers logging.info("Grain pool will use %i processes.", self.num_processes) @@ -333,6 +340,7 @@ def __init__( # separate events. self._reader_termination_event = termination_event or threading.Event() self._workers_termination_event = ctx.Event() + self._worker_init_fn = worker_init_fn self.completed_processes = set() # Queue to propagate errors from child processes to the parent. Note that # this queue is shared by all child processes. @@ -372,12 +380,12 @@ def __init__( # absl.app.run() is called. We send arguments via a queue to ensure that # they are unpickled after absl.app.run() was called in the child # processes. 
- worker_init_fn = lambda: None + worker_init_fns = [self._worker_init_fn] if self._worker_init_fn else [] parse_debug_flags_fn = parse_debug_flags - worker_init_fn = cloudpickle.dumps(worker_init_fn) + worker_init_fns = cloudpickle.dumps(worker_init_fns) parse_debug_flags_fn = cloudpickle.dumps(parse_debug_flags_fn) worker_args_queue.put( - (parse_debug_flags_fn, worker_init_fn, get_element_producer_fn) + (parse_debug_flags_fn, worker_init_fns, get_element_producer_fn) ) process = ctx.Process( # pytype: disable=attribute-error # re-none target=_worker_loop, kwargs=process_kwargs, daemon=True @@ -561,6 +569,7 @@ def _process_elements_in_grain_pool( thread_pool: pool.ThreadPool, termination_event: threading.Event, worker_index_to_start_reading: int, + worker_init_fn: Callable[[int, int], None] | None, ) -> None: """Processes elements in grain worker pool asynchronously.""" @@ -576,6 +585,7 @@ def read_thread_should_stop(): worker_index_to_start_reading=worker_index_to_start_reading, termination_event=termination_event, options=multiprocessing_options, + worker_init_fn=worker_init_fn, ) as g_pool: for element in g_pool: if read_thread_should_stop(): @@ -628,6 +638,7 @@ def __init__( get_element_producer_fn: GetElementProducerFn, multiprocessing_options: MultiprocessingOptions, worker_index_to_start_reading: int, + worker_init_fn: Callable[[int, int], None] | None = None, ): """Initializes MultiProcessIterator. @@ -637,10 +648,14 @@ def __init__( multiprocessing_options: options for distributing the record iterators. worker_index_to_start_reading: Index of the next worker to read from. This is useful for recovering from a checkpoint. + worker_init_fn: Function to run in each worker process before the element + producer. The function takes two arguments: the current worker index and + the total worker count. """ self._get_element_producer_fn = get_element_producer_fn self._multiprocessing_options = multiprocessing_options self._last_worker_index = worker_index_to_start_reading - 1 + self._worker_init_fn = worker_init_fn self._reader_queue = None self._reader_thread_pool = None self._termination_event = None @@ -673,6 +688,7 @@ def start_prefetch(self) -> None: thread_pool=self._reader_thread_pool, termination_event=self._termination_event, worker_index_to_start_reading=self._last_worker_index + 1, + worker_init_fn=self._worker_init_fn, ), ) self._reader_thread.start() diff --git a/grain/_src/python/grain_pool_test.py b/grain/_src/python/grain_pool_test.py index 095b9817..1bbc9bc3 100644 --- a/grain/_src/python/grain_pool_test.py +++ b/grain/_src/python/grain_pool_test.py @@ -420,6 +420,28 @@ def __call__( with self.assertRaisesRegex(ValueError, error_msg): list(iterator) + def test_worker_init_fn(self): + + def _set_worker_index_and_count(worker_index: int, worker_count: int): + gp.monkey_patched_index_and_count = (worker_index, worker_count) + + class GetElementProducerFnReturningGlobal(gp.GetElementProducerFn): + + def __call__( + self, *, worker_index: int, worker_count: int + ) -> Iterator[tuple[int, int]]: + del self, worker_index, worker_count + yield gp.monkey_patched_index_and_count # pytype: disable=module-attr + + with gp.MultiProcessIterator( + GetElementProducerFnReturningGlobal(), + MultiprocessingOptions(num_workers=2), + 0, + worker_init_fn=_set_worker_index_and_count, + ) as iterator: + result = list(iterator) + self.assertEqual(result, [(0, 2), (1, 2)]) + if __name__ == "__main__": absltest.main()
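
Usage sketch (illustration only): the `worker_init_fn` hook added to `IterDataset.mp_prefetch` above is called once in every worker process, with the worker's index and the total worker count, before any elements are produced. The snippet below assumes the public `grain.python` alias re-exports `MapDataset` and `MultiprocessingOptions` touched by this patch; the per-worker seeding is a hypothetical use, analogous to the log-formatter setup in the new `prefetch_test.py` case.

import random

import grain.python as grain


def seed_worker(worker_index: int, worker_count: int) -> None:
  # Runs in each worker process before the element producer starts, mirroring
  # the init_fn(worker_index, worker_count) call added in
  # _initialize_and_get_element_producer.
  random.seed(1234 + worker_index)


ds = (
    grain.MapDataset.range(8)
    .map(lambda x: x * x)
    .to_iter_dataset()
    .mp_prefetch(
        grain.MultiprocessingOptions(num_workers=2),
        worker_init_fn=seed_worker,
    )
)
elements = list(ds)

Because the callback is serialized with cloudpickle and sent through the worker args queue together with the flag-parsing function and the element producer, closures and lambdas should work as well. If `options.num_workers` is 0, `mp_prefetch` has no effect and the callback is never invoked.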
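At the `grain_pool` level, the same hook threads through `GrainPool`, `_process_elements_in_grain_pool`, and `MultiProcessIterator`, and runs before the element producer is deserialized, so state set by it is visible to the producer. A minimal sketch of driving `MultiProcessIterator` directly, along the lines of the new `grain_pool_test.py` case; the environment-variable trick and the import paths (assumed to mirror the test module) are illustrative only, the test itself uses a monkey-patched module attribute instead.

import os
from typing import Iterator

from grain._src.python import grain_pool as gp
from grain._src.python import options


def tag_worker(worker_index: int, worker_count: int) -> None:
  # Executed in the child process before the element producer is called,
  # so the environment variable below is visible to the producer.
  os.environ["GRAIN_WORKER_TAG"] = f"{worker_index}/{worker_count}"


class ProducerReturningTag(gp.GetElementProducerFn):

  def __call__(self, *, worker_index: int, worker_count: int) -> Iterator[str]:
    del worker_index, worker_count
    yield os.environ["GRAIN_WORKER_TAG"]


with gp.MultiProcessIterator(
    ProducerReturningTag(),
    options.MultiprocessingOptions(num_workers=2),
    0,
    worker_init_fn=tag_worker,
) as iterator:
  print(list(iterator))  # Expected: ["0/2", "1/2"], read round-robin.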