Allow providing worker initialization function in mp_prefetch.
PiperOrigin-RevId: 731433157
iindyk authored and copybara-github committed Feb 26, 2025
1 parent 5725eb3 commit 57d7df0
Showing 5 changed files with 98 additions and 16 deletions.
5 changes: 5 additions & 0 deletions grain/_src/python/dataset/dataset.py
@@ -1082,6 +1082,7 @@ def prefetch(
def mp_prefetch(
self,
options: grain_options.MultiprocessingOptions | None = None,
worker_init_fn: Callable[[int, int], None] | None = None,
) -> IterDataset[T]:
"""Returns a dataset prefetching elements in multiple processes.
@@ -1101,6 +1102,9 @@ def mp_prefetch(
be greater than or equal to 0. If `options.num_workers` is 0,
`mp_prefetch` has no effect. Defaults to
`MultiprocessingOptions(num_workers=10)`.
worker_init_fn: A function that is called in each worker process before
the data is processed. The function takes two arguments: the current
worker index and the total worker count.
Returns:
A dataset prefetching input elements in separate processes.
@@ -1113,6 +1117,7 @@ def mp_prefetch(
return prefetch.MultiprocessPrefetchIterDataset(
self,
multiprocessing_options=options,
worker_init_fn=worker_init_fn,
)

@abc.abstractmethod
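For context, a minimal usage sketch of the new parameter (not part of this commit): the pipeline and the `seed_per_worker` helper below are illustrative, and assume the public `grain.python` alias exposes `MapDataset` and `MultiprocessingOptions`.

import grain.python as grain


def seed_per_worker(worker_index: int, worker_count: int) -> None:
  # Runs once in every worker process before any element is produced, e.g.
  # to derive a distinct deterministic seed per worker (illustrative only).
  import random

  del worker_count  # Unused in this sketch.
  random.seed(1000 + worker_index)


ds = (
    grain.MapDataset.range(8)
    .map(lambda x: x * 2)
    .to_iter_dataset()
    .mp_prefetch(
        grain.MultiprocessingOptions(num_workers=2),
        worker_init_fn=seed_per_worker,
    )
)
print(list(ds))  # Elements are produced by the two initialized workers.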
13 changes: 9 additions & 4 deletions grain/_src/python/dataset/transformations/prefetch.py
@@ -222,6 +222,7 @@ def __init__(
self,
parent: dataset.IterDataset[T],
multiprocessing_options: grain_options.MultiprocessingOptions,
worker_init_fn: Callable[[int, int], None] | None = None,
):
if multiprocessing_options.num_workers < 0:
raise ValueError(
@@ -230,6 +231,7 @@ def __init__(
)
super().__init__(parent)
self._multiprocessing_options = multiprocessing_options
self._worker_init_fn = worker_init_fn
self._validate_parent_dataset()

def __str__(self) -> str:
@@ -253,8 +255,8 @@ def _validate_parent_dataset(self) -> None:
def __iter__(self) -> dataset.DatasetIterator[T]:
if self._multiprocessing_options.num_workers == 0:
return self._parent.__iter__()
return MultiprocessPrefetchDatasetIterator(
self._parent, self._multiprocessing_options
return _MultiprocessPrefetchDatasetIterator(
self._parent, self._multiprocessing_options, self._worker_init_fn
)


@@ -370,7 +372,7 @@ def _check_picklable(
class GetElementProducerFn(grain_pool.GetElementProducerFn, Generic[T]):
"""Implements `GetElementProducerFn` for `grain_pool.MultiProcessIterator`.
This class implements `GetElementProducerFn` with `serialize` being overriden
This class implements `GetElementProducerFn` with `serialize` being overridden
to generate better error messages if user-provided dataset is not pickle-able.
"""

@@ -438,13 +440,14 @@ def _get_dataset_options(ds: dataset.IterDataset) -> base.DatasetOptions:
return result


class MultiprocessPrefetchDatasetIterator(dataset.DatasetIterator[T]):
class _MultiprocessPrefetchDatasetIterator(dataset.DatasetIterator[T]):
"""Iterator that performs prefetching using a multiprocessing pool."""

def __init__(
self,
parent: dataset.IterDataset[T],
multiprocessing_options: grain_options.MultiprocessingOptions,
worker_init_fn: Callable[[int, int], None] | None = None,
):
super().__init__()
self._iter_parent = parent
@@ -453,6 +456,7 @@ def __init__(
# propagate them.
self._ctx.dataset_options = _get_dataset_options(parent)
self._multiprocessing_options = multiprocessing_options
self._worker_init_fn = worker_init_fn
# The underlying iterator producing elements and workers state.
self._iterator = None
# Raw reference to the underlying iterator that can be used to determine the
@@ -557,6 +561,7 @@ def _create_iterator_context(self) -> grain_pool.MultiProcessIterator[T]:
self._multiprocessing_options,
(self._state[_LAST_WORKER_INDEX] + 1)
% self._multiprocessing_options.num_workers,
self._worker_init_fn,
)

def __str__(self) -> str:
40 changes: 37 additions & 3 deletions grain/_src/python/dataset/transformations/prefetch_test.py
@@ -13,11 +13,11 @@
# limitations under the License.
from concurrent import futures
import dataclasses
import logging as std_logging
import sys
import time
from typing import TypeVar, cast
from unittest import mock

from absl import logging
from absl.testing import absltest
from absl.testing import parameterized
@@ -530,8 +530,7 @@ def map(self, features):
options.MultiprocessingOptions(num_workers, per_worker_buffer_size),
)

it = iter(ds)
assert isinstance(it, prefetch.MultiprocessPrefetchDatasetIterator)
it = ds.__iter__()
for _ in range(start_prefetch_calls):
it.start_prefetch()

@@ -611,6 +610,41 @@ def test_options_after_prefetch(self):
with self.assertRaises(Exception):
list(ds)

def test_worker_init_fn(self):
def set_worker_index_and_count(worker_index: int, worker_count: int):
log_formatter = std_logging.Formatter(
f'[Worker {worker_index} out of {worker_count}] %(message)s'
)
logging.get_absl_handler().setFormatter(log_formatter)

def map_fn(x):
# absl logging from workers is not propagated to the main process in unit
# tests. Therefore, we manually pass the formatted log message.
record = logging.get_absl_logger().makeRecord(
'grain',
logging.INFO,
'grain_pool_test',
123,
f'processing element {x}',
(),
None,
)
return logging.get_absl_handler().format(record)

ds = dataset.MapDataset.range(2).map(map_fn)
ds = ds.to_iter_dataset()
ds = ds.mp_prefetch(
options.MultiprocessingOptions(num_workers=2),
worker_init_fn=set_worker_index_and_count,
)
self.assertEqual(
list(ds),
[
'[Worker 0 out of 2] processing element 0',
'[Worker 1 out of 2] processing element 1',
],
)


class ThreadPrefetchIterDatasetTest(parameterized.TestCase):

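The initializer installed by `test_worker_init_fn` above can be reused outside tests to tag each worker's absl log lines; extracted as a standalone helper, this is a sketch assuming absl logging is set up in the worker processes.

import logging as std_logging

from absl import logging


def tag_worker_logs(worker_index: int, worker_count: int) -> None:
  # Prefix every absl log record emitted from this worker process.
  formatter = std_logging.Formatter(
      f'[Worker {worker_index} out of {worker_count}] %(message)s'
  )
  logging.get_absl_handler().setFormatter(formatter)

Passing `worker_init_fn=tag_worker_logs` to `mp_prefetch` runs it once per worker before any element is processed.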
34 changes: 25 additions & 9 deletions grain/_src/python/grain_pool.py
@@ -14,7 +14,7 @@
"""This module provides a way to distribute processing across multiple workers.
In the context of Grain we use the term "process" similar to JAX, where usually
each machine runs one Python process (identified by `jax.proccess_index()`).
each machine runs one Python process (identified by `jax.process_index()`).
In Grain each "process" can create additional Python child processes that we
call "workers".
@@ -32,8 +32,8 @@
* Child processes are launched as Daemon processes. In case of (unexpected)
parent termination, child processes will be terminated by OS.
* System uses a multiprocessing event ("termination_event") for termination.
Parent and child processes continously check if the "termination_event" and if
set, they break from what they are doing.
Parent and child processes continuously check if the "termination_event" and
if set, they break from what they are doing.
* We never block indefinitely when calling get() or put() on a queue. This
ensures parent and child processes continue to check the termination_event.
@@ -176,15 +176,18 @@ def _initialize_and_get_element_producer(
"""Unpickles the element producer from the args queue and closes the queue."""
(
serialized_flag_parse_fn,
serialized_init_fn,
serialized_init_fns,
serialized_element_producer_fn,
) = args_queue.get()
flag_parse_fn: Callable[[Any], None] = cloudpickle.loads(
serialized_flag_parse_fn
)
flag_parse_fn(debug_flags)
init_fn: Callable[[], None] = cloudpickle.loads(serialized_init_fn)
init_fn()
init_fns: list[Callable[[int, int], None]] = cloudpickle.loads(
serialized_init_fns
)
for init_fn in init_fns:
init_fn(worker_index, worker_count)
element_producer_fn: GetElementProducerFn[Any] = (
GetElementProducerFn.deserialize(serialized_element_producer_fn)
)
@@ -307,6 +310,7 @@ def __init__(
worker_index_to_start_reading: int = 0,
termination_event: threading.Event | None = None,
options: MultiprocessingOptions,
worker_init_fn: Callable[[int, int], None] | None = None,
):
"""Initialise a Grain Pool.
@@ -320,6 +324,9 @@
the pool will terminate when either one of the workers failed or when
all workers are done processing data. GrainPool will not set this event.
options: Options for multiprocessing. See MultiprocessingOptions.
worker_init_fn: Function to run in each worker process before the element
producer. The function takes two arguments: the current worker index and
the total worker count.
"""
self.num_processes = options.num_workers
logging.info("Grain pool will use %i processes.", self.num_processes)
@@ -333,6 +340,7 @@ def __init__(
# separate events.
self._reader_termination_event = termination_event or threading.Event()
self._workers_termination_event = ctx.Event()
self._worker_init_fn = worker_init_fn
self.completed_processes = set()
# Queue to propagate errors from child processes to the parent. Note that
# this queue is shared by all child processes.
@@ -372,12 +380,12 @@ def __init__(
# absl.app.run() is called. We send arguments via a queue to ensure that
# they are unpickled after absl.app.run() was called in the child
# processes.
worker_init_fn = lambda: None
worker_init_fns = [self._worker_init_fn] if self._worker_init_fn else []
parse_debug_flags_fn = parse_debug_flags
worker_init_fn = cloudpickle.dumps(worker_init_fn)
worker_init_fns = cloudpickle.dumps(worker_init_fns)
parse_debug_flags_fn = cloudpickle.dumps(parse_debug_flags_fn)
worker_args_queue.put(
(parse_debug_flags_fn, worker_init_fn, get_element_producer_fn)
(parse_debug_flags_fn, worker_init_fns, get_element_producer_fn)
)
process = ctx.Process( # pytype: disable=attribute-error # re-none
target=_worker_loop, kwargs=process_kwargs, daemon=True
@@ -561,6 +569,7 @@ def _process_elements_in_grain_pool(
thread_pool: pool.ThreadPool,
termination_event: threading.Event,
worker_index_to_start_reading: int,
worker_init_fn: Callable[[int, int], None] | None,
) -> None:
"""Processes elements in grain worker pool asynchronously."""

@@ -576,6 +585,7 @@ def read_thread_should_stop():
worker_index_to_start_reading=worker_index_to_start_reading,
termination_event=termination_event,
options=multiprocessing_options,
worker_init_fn=worker_init_fn,
) as g_pool:
for element in g_pool:
if read_thread_should_stop():
@@ -628,6 +638,7 @@ def __init__(
get_element_producer_fn: GetElementProducerFn,
multiprocessing_options: MultiprocessingOptions,
worker_index_to_start_reading: int,
worker_init_fn: Callable[[int, int], None] | None = None,
):
"""Initializes MultiProcessIterator.
@@ -637,10 +648,14 @@
multiprocessing_options: options for distributing the record iterators.
worker_index_to_start_reading: Index of the next worker to read from. This
is useful for recovering from a checkpoint.
worker_init_fn: Function to run in each worker process before the element
producer. The function takes two arguments: the current worker index and
the total worker count.
"""
self._get_element_producer_fn = get_element_producer_fn
self._multiprocessing_options = multiprocessing_options
self._last_worker_index = worker_index_to_start_reading - 1
self._worker_init_fn = worker_init_fn
self._reader_queue = None
self._reader_thread_pool = None
self._termination_event = None
@@ -673,6 +688,7 @@ def start_prefetch(self) -> None:
thread_pool=self._reader_thread_pool,
termination_event=self._termination_event,
worker_index_to_start_reading=self._last_worker_index + 1,
worker_init_fn=self._worker_init_fn,
),
)
self._reader_thread.start()
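The hand-off itself is a cloudpickle round trip: the parent serializes the list of init functions into `worker_args_queue`, and each worker deserializes and invokes them with its own index and the total worker count (see `_initialize_and_get_element_producer` above). Below is a minimal sketch of that contract outside the pool machinery; the literal index/count values are illustrative.

import cloudpickle


def init_fn(worker_index: int, worker_count: int) -> None:
  print(f'worker {worker_index} of {worker_count} initialized')


# Parent side: serialize the (possibly empty) list of init functions.
payload = cloudpickle.dumps([init_fn])

# Worker side: deserialize and call each init function before producing
# any elements.
for fn in cloudpickle.loads(payload):
  fn(0, 2)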
22 changes: 22 additions & 0 deletions grain/_src/python/grain_pool_test.py
@@ -420,6 +420,28 @@ def __call__(
with self.assertRaisesRegex(ValueError, error_msg):
list(iterator)

def test_worker_init_fn(self):

def _set_worker_index_and_count(worker_index: int, worker_count: int):
gp.monkey_patched_index_and_count = (worker_index, worker_count)

class GetElementProducerFnReturningGlobal(gp.GetElementProducerFn):

def __call__(
self, *, worker_index: int, worker_count: int
) -> Iterator[tuple[int, int]]:
del self, worker_index, worker_count
yield gp.monkey_patched_index_and_count # pytype: disable=module-attr

with gp.MultiProcessIterator(
GetElementProducerFnReturningGlobal(),
MultiprocessingOptions(num_workers=2),
0,
worker_init_fn=_set_worker_index_and_count,
) as iterator:
result = list(iterator)
self.assertEqual(result, [(0, 2), (1, 2)])


if __name__ == "__main__":
absltest.main()
