diff --git a/python/ray/data/tests/test_operators.py b/python/ray/data/tests/test_operators.py index 441878cf7687a..1cf912a9cb5d9 100644 --- a/python/ray/data/tests/test_operators.py +++ b/python/ray/data/tests/test_operators.py @@ -17,6 +17,7 @@ PhysicalOperator, RefBundle, ) +from ray.data._internal.execution.interfaces.task_context import TaskContext from ray.data._internal.execution.operators.actor_pool_map_operator import ( ActorPoolMapOperator, ) @@ -1022,6 +1023,45 @@ def yield_five(block_iter: Iterable[Block], ctx) -> Iterable[Block]: assert op._estimated_num_output_bundles == 100 +@pytest.mark.parametrize("use_actors", [False, True]) +def test_map_kwargs(ray_start_regular_shared, use_actors): + """Test propagating additional kwargs to map tasks.""" + foo = 1 + bar = ray.put(np.zeros(1024 * 1024)) + kwargs = { + "foo": foo, # Pass by value + "bar": bar, # Pass by ObjectRef + } + + def map_fn(block_iter: Iterable[Block], ctx: TaskContext) -> Iterable[Block]: + assert ctx.kwargs["foo"] == foo + # bar should be automatically deref'ed. + assert ctx.kwargs["bar"] == bar + + yield from block_iter + + input_op = InputDataBuffer( + DataContext.get_current(), make_ref_bundles([[i] for i in range(100)]) + ) + compute_strategy = ActorPoolStrategy() if use_actors else TaskPoolStrategy() + op = MapOperator.create( + create_map_transformer_from_block_fn(map_fn), + input_op=input_op, + data_context=DataContext.get_current(), + name="TestMapper", + compute_strategy=compute_strategy, + ) + op.add_map_task_kwargs_fn(lambda: kwargs) + op.start(ExecutionOptions()) + while input_op.has_next(): + op.add_input(input_op.get_next(), 0) + op.all_inputs_done() + run_op_tasks_sync(op) + + _take_outputs(op) + assert op.completed() + + def test_limit_estimated_num_output_bundles(): # Test limit operator estimation input_op = InputDataBuffer(