diff --git a/storey/flow.py b/storey/flow.py index ec95e2aa..97cc3c2b 100644 --- a/storey/flow.py +++ b/storey/flow.py @@ -393,7 +393,7 @@ async def _do(self, event): class _UnaryFunctionFlow(Flow): - def __init__(self, fn, long_running=None, **kwargs): + def __init__(self, fn, long_running=None, pass_context=None, **kwargs): super().__init__(**kwargs) if not callable(fn): raise TypeError(f"Expected a callable, got {type(fn)}") @@ -402,12 +402,16 @@ def __init__(self, fn, long_running=None, **kwargs): raise ValueError("long_running=True cannot be used in conjunction with a coroutine") self._long_running = long_running self._fn = fn + self._pass_context = pass_context async def _call(self, element): if self._long_running: res = await asyncio.get_running_loop().run_in_executor(None, self._fn, element) else: - res = self._fn(element) + kwargs = {} + if self._pass_context: + kwargs = {"context": self.context} + res = self._fn(element, **kwargs) if self._is_async: res = await res return res diff --git a/tests/test_flow.py b/tests/test_flow.py index 951ea1a8..c611fec9 100644 --- a/tests/test_flow.py +++ b/tests/test_flow.py @@ -108,6 +108,22 @@ def test_functional_flow(): assert termination_result == 3300 +def test_pass_context_to_function(): + controller = build_flow( + [ + SyncEmitSource(), + Map(lambda x, context: x + context, pass_context=True, context=10), + Reduce(0, lambda acc, x: acc + x), + ] + ).run() + + for i in range(5): + controller.emit(i) + controller.terminate() + termination_result = controller.await_termination() + assert termination_result == 60 + + class Committer: def __init__(self): self.offsets = {}