From e0b30bf883f23b6bd6b72fe1691911062b7881da Mon Sep 17 00:00:00 2001 From: Clark Zinzow Date: Wed, 18 Jan 2023 11:54:38 -0800 Subject: [PATCH] [Datasets] Add support for callable classes to new execution backend. (#31706) This PR adds support for CallableClass UDFs in map APIs to the new execution backend, and also makes the requisite test modifications to confirm that the new execution backend supports passthrough args. --- python/ray/data/__init__.py | 6 +++ .../data/_internal/execution/legacy_compat.py | 51 ++++++++++++++++--- python/ray/data/tests/test_dataset.py | 35 ++++++++----- 3 files changed, 71 insertions(+), 21 deletions(-) diff --git a/python/ray/data/__init__.py b/python/ray/data/__init__.py index f2a98463eb82..9a114f84490a 100644 --- a/python/ray/data/__init__.py +++ b/python/ray/data/__init__.py @@ -37,6 +37,12 @@ read_tfrecords, ) + +# Module-level cached global functions for callable classes. It needs to be defined here +# since it has to be process-global across cloudpickled funcs. +_cached_fn = None +_cached_cls = None + __all__ = [ "ActorPoolStrategy", "Dataset", diff --git a/python/ray/data/_internal/execution/legacy_compat.py b/python/ray/data/_internal/execution/legacy_compat.py index 34d7a95c65b0..b487412f82ee 100644 --- a/python/ray/data/_internal/execution/legacy_compat.py +++ b/python/ray/data/_internal/execution/legacy_compat.py @@ -4,7 +4,7 @@ """ import ray.cloudpickle as cloudpickle -from typing import Iterator, Tuple +from typing import Iterator, Tuple, Any import ray from ray.types import ObjectRef @@ -14,7 +14,12 @@ from ray.data._internal.stage_impl import RandomizeBlocksStage from ray.data._internal.block_list import BlockList from ray.data._internal.lazy_block_list import LazyBlockList -from ray.data._internal.compute import get_compute +from ray.data._internal.compute import ( + get_compute, + CallableClass, + TaskPoolStrategy, + ActorPoolStrategy, +) from ray.data._internal.memory_tracing import trace_allocation from ray.data._internal.plan import ExecutionPlan, OneToOneStage, AllToAllStage, Stage from ray.data._internal.execution.operators.map_operator import MapOperator @@ -162,12 +167,44 @@ def _stage_to_operator(stage: Stage, input_op: PhysicalOperator) -> PhysicalOper """ if isinstance(stage, OneToOneStage): - if stage.fn_constructor_args or stage.fn_constructor_kwargs: - raise NotImplementedError + compute = get_compute(stage.compute) block_fn = stage.block_fn - # TODO: implement arg packing and passing for test_map_batches_extra_args - fn_args = (stage.fn,) if stage.fn else () + if stage.fn: + if isinstance(stage.fn, CallableClass): + if isinstance(compute, TaskPoolStrategy): + raise ValueError( + "``compute`` must be specified when using a callable class, " + "and must specify the actor compute strategy. " + 'For example, use ``compute="actors"`` or ' + "``compute=ActorPoolStrategy(min, max)``." + ) + assert isinstance(compute, ActorPoolStrategy) + + fn_constructor_args = stage.fn_constructor_args or () + fn_constructor_kwargs = stage.fn_constructor_kwargs or {} + fn_ = stage.fn + + def fn(item: Any) -> Any: + # Wrapper providing cached instantiation of stateful callable class + # UDFs. + if ray.data._cached_fn is None: + ray.data._cached_cls = fn_ + ray.data._cached_fn = fn_( + *fn_constructor_args, **fn_constructor_kwargs + ) + else: + # A worker is destroyed when its actor is killed, so we + # shouldn't have any worker reuse across different UDF + # applications (i.e. different map operators). + assert ray.data._cached_cls == fn_ + return ray.data._cached_fn(item) + + else: + fn = stage.fn + fn_args = (fn,) + else: + fn_args = () if stage.fn_args: fn_args += stage.fn_args fn_kwargs = stage.fn_kwargs or {} @@ -179,7 +216,7 @@ def do_map(blocks: Iterator[Block]) -> Iterator[Block]: do_map, input_op, name=stage.name, - compute_strategy=get_compute(stage.compute), + compute_strategy=compute, min_rows_per_bundle=stage.target_block_size, ray_remote_args=stage.ray_remote_args, ) diff --git a/python/ray/data/tests/test_dataset.py b/python/ray/data/tests/test_dataset.py index d7935fa5b879..4c46171cf8e5 100644 --- a/python/ray/data/tests/test_dataset.py +++ b/python/ray/data/tests/test_dataset.py @@ -2432,6 +2432,13 @@ def test_map_batches_basic(ray_start_regular_shared, tmp_path): def test_map_batches_extra_args(ray_start_regular_shared, tmp_path): + def put(x): + # We only support automatic deref in the legacy backend. + if DatasetContext.get_current().new_execution_backend: + return x + else: + return ray.put(x) + # Test input validation ds = ray.data.range(5) @@ -2483,7 +2490,7 @@ def udf(batch, a): udf, batch_size=1, batch_format="pandas", - fn_args=(ray.put(1),), + fn_args=(put(1),), ) assert ds2.dataset_format() == "pandas" ds_list = ds2.take() @@ -2502,7 +2509,7 @@ def udf(batch, b=None): udf, batch_size=1, batch_format="pandas", - fn_kwargs={"b": ray.put(2)}, + fn_kwargs={"b": put(2)}, ) assert ds2.dataset_format() == "pandas" ds_list = ds2.take() @@ -2522,8 +2529,8 @@ def udf(batch, a, b=None): udf, batch_size=1, batch_format="pandas", - fn_args=(ray.put(1),), - fn_kwargs={"b": ray.put(2)}, + fn_args=(put(1),), + fn_kwargs={"b": put(2)}, ) assert ds2.dataset_format() == "pandas" ds_list = ds2.take() @@ -2548,7 +2555,7 @@ def __call__(self, x): batch_size=1, batch_format="pandas", compute="actors", - fn_constructor_args=(ray.put(1),), + fn_constructor_args=(put(1),), ) assert ds2.dataset_format() == "pandas" ds_list = ds2.take() @@ -2572,7 +2579,7 @@ def __call__(self, x): batch_size=1, batch_format="pandas", compute="actors", - fn_constructor_kwargs={"b": ray.put(2)}, + fn_constructor_kwargs={"b": put(2)}, ) assert ds2.dataset_format() == "pandas" ds_list = ds2.take() @@ -2598,8 +2605,8 @@ def __call__(self, x): batch_size=1, batch_format="pandas", compute="actors", - fn_constructor_args=(ray.put(1),), - fn_constructor_kwargs={"b": ray.put(2)}, + fn_constructor_args=(put(1),), + fn_constructor_kwargs={"b": put(2)}, ) assert ds2.dataset_format() == "pandas" ds_list = ds2.take() @@ -2610,8 +2617,8 @@ def __call__(self, x): # Test callable chain. ds = ray.data.read_parquet(str(tmp_path)) - fn_constructor_args = (ray.put(1),) - fn_constructor_kwargs = {"b": ray.put(2)} + fn_constructor_args = (put(1),) + fn_constructor_kwargs = {"b": put(2)} ds2 = ( ds.lazy() .map_batches( @@ -2640,8 +2647,8 @@ def __call__(self, x): # Test function + callable chain. ds = ray.data.read_parquet(str(tmp_path)) - fn_constructor_args = (ray.put(1),) - fn_constructor_kwargs = {"b": ray.put(2)} + fn_constructor_args = (put(1),) + fn_constructor_kwargs = {"b": put(2)} ds2 = ( ds.lazy() .map_batches( @@ -2649,8 +2656,8 @@ def __call__(self, x): batch_size=1, batch_format="pandas", compute="actors", - fn_args=(ray.put(1),), - fn_kwargs={"b": ray.put(2)}, + fn_args=(put(1),), + fn_kwargs={"b": put(2)}, ) .map_batches( CallableFn,