[Data] Allow fusing MapOperator -> Repartition operators (ray-project#35178)

As a follow-up to ray-project#34847, allow fusing `MapOperator` -> `Repartition` operators for the shuffle repartition case. We do not support fusing for split repartition, which uses only `ShuffleTaskSpec.reduce` and therefore cannot call the upstream map function passed to `ShuffleTaskSpec.map`.
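For illustration, a minimal end-to-end sketch mirroring the fusion test added in this commit (the function `fn`, dataset size, and partition count are just example values):

    import ray

    def fn(batch):
        return {"id": [x + 1 for x in batch["id"]]}

    ds = ray.data.range(10)
    ds = ds.map_batches(fn, batch_size=None)
    # Shuffle repartition: the upstream map can now be fused into the shuffle map stage.
    ds = ds.repartition(2, shuffle=True)
    ds.take_all()
    # Per the test below, ds.stats() now reports a fused "DoRead->MapBatches->Repartition"
    # stage; with shuffle=False (split repartition), no fusion takes place.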

Signed-off-by: Scott Lee <sjl@anyscale.com>
Signed-off-by: e428265 <arvind.chandramouli@lmco.com>
scottjlee authored and arvind-chandra committed Aug 31, 2023
1 parent f013431 commit fb3542a
Showing 3 changed files with 53 additions and 18 deletions.
26 changes: 22 additions & 4 deletions python/ray/data/_internal/logical/rules/operator_fusion.py
@@ -1,4 +1,5 @@
from typing import Iterator, List, Tuple
from ray.data._internal.logical.operators.all_to_all_operator import Repartition
from ray.data._internal.execution.operators.map_operator import MapOperator
from ray.data._internal.logical.operators.all_to_all_operator import (
AbstractAllToAll,
@@ -107,8 +108,8 @@ def _can_fuse(self, down_op: PhysicalOperator, up_op: PhysicalOperator) -> bool:

# We currently only support fusing for the following cases:
# - MapOperator -> MapOperator
# - MapOperator -> AllToAllOperator (only RandomShuffle
# LogicalOperator is currently supported)
# - MapOperator -> AllToAllOperator
# (only RandomShuffle and Repartition LogicalOperators are currently supported)
if not isinstance(down_op, (MapOperator, AllToAllOperator)) or not isinstance(
up_op, MapOperator
):
@@ -125,11 +126,17 @@ def _can_fuse(self, down_op: PhysicalOperator, up_op: PhysicalOperator) -> bool:
# We currently only support fusing for the following cases:
# - AbstractMap -> AbstractMap
# - AbstractMap -> RandomShuffle
# - AbstractMap -> Repartition (shuffle=True)
if not isinstance(
down_logical_op, (AbstractMap, RandomShuffle)
down_logical_op, (AbstractMap, RandomShuffle, Repartition)
) or not isinstance(up_logical_op, AbstractMap):
return False

# Do not fuse Repartition operator if shuffle is disabled
# (i.e. using split shuffle).
if isinstance(down_logical_op, Repartition) and not down_logical_op._shuffle:
return False

# Allow fusing tasks->actors if the resources are compatible (read->map), but
# not the other way around. The latter (downstream op) will be used as the
# compute if fused.
@@ -306,7 +313,18 @@ def fused_all_to_all_transform_fn(
# Bottom out at the source logical op (e.g. Read()).
input_op = up_logical_op

logical_op = RandomShuffle(input_op, name=name, ray_remote_args=ray_remote_args)
if isinstance(down_logical_op, RandomShuffle):
logical_op = RandomShuffle(
input_op,
name=name,
ray_remote_args=ray_remote_args,
)
elif isinstance(down_logical_op, Repartition):
logical_op = Repartition(
input_op,
num_outputs=down_logical_op._num_outputs,
shuffle=down_logical_op._shuffle,
)
self._op_map[op] = logical_op
# Return the fused physical operator.
return op
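Condensed, a sketch of the fusion-eligibility check on the logical operators that this rule now performs (a simplified illustration, not the actual method body; the real check also verifies compute strategy and resource compatibility, and the AbstractMap import path is assumed):

    from ray.data._internal.logical.operators.all_to_all_operator import (
        RandomShuffle,
        Repartition,
    )
    from ray.data._internal.logical.operators.map_operator import AbstractMap  # path assumed

    def can_fuse_logical_sketch(down_logical_op, up_logical_op) -> bool:
        # The upstream operator must always be a map-like logical operator.
        if not isinstance(up_logical_op, AbstractMap):
            return False
        # Map -> Map and Map -> RandomShuffle were already fusable.
        if isinstance(down_logical_op, (AbstractMap, RandomShuffle)):
            return True
        # New in this commit: Repartition fuses only when shuffle is enabled,
        # since split repartition has no map stage to absorb the upstream map.
        if isinstance(down_logical_op, Repartition):
            return down_logical_op._shuffle
        return False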
18 changes: 16 additions & 2 deletions python/ray/data/_internal/planner/repartition.py
@@ -1,4 +1,4 @@
from typing import List, Tuple
from typing import List, Optional, Tuple, TYPE_CHECKING

from ray.data._internal.execution.interfaces import (
AllToAllTransformFn,
@@ -19,6 +19,9 @@
from ray.data._internal.stats import StatsDict
from ray.data.context import DataContext

if TYPE_CHECKING:
from ray.data._internal.execution.interfaces import MapTransformFn


def generate_repartition_fn(
num_outputs: int,
@@ -30,7 +33,18 @@ def shuffle_repartition_fn(
refs: List[RefBundle],
ctx: TaskContext,
) -> Tuple[List[RefBundle], StatsDict]:
shuffle_spec = ShuffleTaskSpec(random_shuffle=False)
# If map_transform_fn is specified (e.g. from fusing
# MapOperator->AllToAllOperator), we pass a map function which
# is applied to each block before shuffling.
map_transform_fn: Optional["MapTransformFn"] = ctx.upstream_map_transform_fn
upstream_map_fn = None
if map_transform_fn:
upstream_map_fn = lambda block: map_transform_fn(block, ctx) # noqa: E731

shuffle_spec = ShuffleTaskSpec(
random_shuffle=False,
upstream_map_fn=upstream_map_fn,
)

if DataContext.get_current().use_push_based_shuffle:
scheduler = PushBasedShuffleTaskScheduler(shuffle_spec)
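Conceptually, shuffle repartition can absorb the upstream map because its map stage gives the fused function a place to run, whereas split repartition executes only the reduce stage. A toy sketch of that idea (not the actual ShuffleTaskSpec code; the block type and partitioning logic are deliberately simplified stand-ins):

    from typing import Callable, List, Optional

    Block = List[int]  # toy stand-in for a real Arrow/pandas block

    def shuffle_map_stage_sketch(
        block: Block,
        output_num_blocks: int,
        upstream_map_fn: Optional[Callable[[Block], Block]] = None,
    ) -> List[Block]:
        # If an upstream MapOperator was fused into this shuffle repartition,
        # apply its transform to the block before partitioning it.
        if upstream_map_fn is not None:
            block = upstream_map_fn(block)
        # Toy partitioning: round-robin rows into output_num_blocks pieces.
        return [block[i::output_num_blocks] for i in range(output_num_blocks)]

    # Split repartition has no per-block map step like this one, so there is
    # nowhere for upstream_map_fn to run, which is why it is not fused.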
27 changes: 15 additions & 12 deletions python/ray/data/tests/test_execution_optimizer.py
@@ -288,10 +288,11 @@ def test_repartition_e2e(
def _check_repartition_usage_and_stats(ds):
_check_usage_record(["ReadRange", "Repartition"])
ds_stats: DatastreamStats = ds._plan.stats()
assert ds_stats.base_name == "Repartition"
if shuffle:
assert "RepartitionMap" in ds_stats.stages
assert ds_stats.base_name == "DoRead->Repartition"
assert "DoRead->RepartitionMap" in ds_stats.stages
else:
assert ds_stats.base_name == "Repartition"
assert "RepartitionSplit" in ds_stats.stages
assert "RepartitionReduce" in ds_stats.stages

@@ -630,7 +631,7 @@ def fn(batch):


def test_read_map_batches_operator_fusion_with_random_shuffle_operator(
ray_start_regular_shared, enable_optimizer
ray_start_regular_shared, enable_optimizer, use_push_based_shuffle
):
# Note: we currently only support fusing MapOperator->AllToAllOperator.
def fn(batch):
@@ -676,24 +677,26 @@ def fn(batch):
_check_usage_record(["ReadRange", "RandomShuffle", "MapBatches"])


@pytest.mark.parametrize("shuffle", (True, False))
def test_read_map_batches_operator_fusion_with_repartition_operator(
ray_start_regular_shared, enable_optimizer
ray_start_regular_shared, enable_optimizer, shuffle, use_push_based_shuffle
):
# Note: We currently do not fuse MapBatches->Repartition.
# This test is to ensure that we don't accidentally fuse them, until
# we implement it later.
def fn(batch):
return {"id": [x + 1 for x in batch["id"]]}

n = 10
ds = ray.data.range(n)
ds = ds.map_batches(fn, batch_size=None)
ds = ds.repartition(2)
ds = ds.repartition(2, shuffle=shuffle)
assert set(extract_values("id", ds.take_all())) == set(range(1, n + 1))
# TODO(Scott): update the below assertions after we support fusion.
assert "DoRead->MapBatches->Repartition" not in ds.stats()
assert "DoRead->MapBatches" in ds.stats()
assert "Repartition" in ds.stats()

# Operator fusion is only supported for shuffle repartition.
if shuffle:
assert "DoRead->MapBatches->Repartition" in ds.stats()
else:
assert "DoRead->MapBatches->Repartition" not in ds.stats()
assert "DoRead->MapBatches" in ds.stats()
assert "Repartition" in ds.stats()
_check_usage_record(["ReadRange", "MapBatches", "Repartition"])

