[Data] [2/N] Enable optimizer: fix fusion (#35621)
## Why are these changes needed?

This PR is the second part of enabling the optimizer by default (split from #34937).
It fixes the following issues:
- `ray_remote_args` not correctly set for a fused operator.
- `init_fn` not correctly set for a fused operator.
- Incorrect set of allowed cases for fusion (see `operator_fusion.py`).
- `ray_remote_args` compatibility check for fusion (see the sketch below).
- Limit operator not handled when converting logical operator to physical.
- Other small fixes.

Note: some changes in this PR may not be covered by this PR's CI, because the optimizer must be enabled to exercise them; they have already been verified in #34937's CI.
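
For reference, here is a minimal, self-contained sketch of the remote-args compatibility check added in `operator_fusion.py` (it mirrors the `_are_remote_args_compatible` / `_canonicalize` helpers in the diff below; the contents of `INHERITABLE_REMOTE_ARGS` are assumed to be `["scheduling_strategy"]` for illustration):

```python
from typing import Any, Dict

# Assumed for this sketch; the real list is defined in Ray Data.
INHERITABLE_REMOTE_ARGS = ["scheduling_strategy"]


def _canonicalize(remote_args: Dict[str, Any]) -> Dict[str, Any]:
    """Fill in defaults and drop empty resources so equivalent args compare equal."""
    remote_args = remote_args.copy()
    if remote_args.get("num_cpus") is None:
        remote_args["num_cpus"] = 1
    if remote_args.get("num_gpus") is None:
        remote_args["num_gpus"] = 0
    resources = remote_args.get("resources", {})
    remote_args["resources"] = {k: v for k, v in resources.items() if v}
    return remote_args


def _are_remote_args_compatible(prev_args: Dict[str, Any], next_args: Dict[str, Any]) -> bool:
    """Two operators may fuse only if their canonicalized remote args match,
    after inheriting the inheritable keys from the upstream operator."""
    prev_args = _canonicalize(prev_args)
    next_args = _canonicalize(next_args)
    merged = next_args.copy()
    for key in INHERITABLE_REMOTE_ARGS:
        if key in prev_args:
            merged[key] = prev_args[key]
    return prev_args == merged


# Defaults are normalized, so these two are considered compatible:
assert _are_remote_args_compatible({}, {"num_cpus": 1})
# Different GPU requirements block fusion:
assert not _are_remote_args_compatible({"num_gpus": 1}, {})
```

Relatedly, the fused map operator now takes the upstream operator's `ray_remote_args` (see `_get_fused_map_operator` in the diff below).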

## Related issue number

#32596
raulchen authored May 31, 2023
1 parent 89bc406 commit 6d18218
Showing 20 changed files with 251 additions and 121 deletions.
15 changes: 14 additions & 1 deletion python/ray/data/_internal/execution/interfaces.py
@@ -1,6 +1,6 @@
from dataclasses import dataclass, field
import os
from typing import Dict, List, Optional, Iterable, Iterator, Tuple, Callable, Union
from typing import Any, Dict, List, Optional, Iterable, Iterator, Tuple, Callable, Union

import ray
from ray.util.annotations import DeveloperAPI
@@ -233,10 +233,23 @@ class TaskContext:
# TODO(chengsu): clean it up from TaskContext with new optimizer framework.
sub_progress_bar_dict: Optional[Dict[str, ProgressBar]] = None

# NOTE(hchen): `upstream_map_transform_fn` and `upstream_map_ray_remote_args`
# are only used for `RandomShuffle`. DO NOT use them for other operators.
# Ideally, they should be handled by the optimizer, and should be transparent
# to the specific operators.
# But for `RandomShuffle`, the AllToAllOperator doesn't do the shuffle itself.
# It uses `ExchangeTaskScheduler` to launch new tasks to do the shuffle.
# That's why we need to pass them to `ExchangeTaskScheduler`.
# TODO(hchen): Use a physical operator to do the shuffle directly.

# The underlying function called in a MapOperator; this is used when fusing
# an AllToAllOperator with an upstream MapOperator.
upstream_map_transform_fn: Optional["MapTransformFn"] = None

# The Ray remote arguments of the fused upstream MapOperator.
# This should be set if upstream_map_transform_fn is set.
upstream_map_ray_remote_args: Dict[str, Any] = None


# Block transform function applied by task and actor pools in MapOperator.
MapTransformFn = Callable[[Iterable[Block], TaskContext], Iterable[Block]]
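
To make the NOTE above concrete, here is a hedged sketch of the hand-off: the fused operator's transform function stashes the upstream map transform and its remote args on the `TaskContext`, and the shuffle path reads them later. It mirrors `fused_all_to_all_transform_fn` further down in this diff; everything except the two `TaskContext` fields is illustrative, not Ray Data's actual code.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, Iterable, Optional


@dataclass
class SketchTaskContext:
    """Stand-in for TaskContext, showing only the two fields added here."""
    task_idx: int = 0
    upstream_map_transform_fn: Optional[Callable] = None
    upstream_map_ray_remote_args: Dict[str, Any] = field(default_factory=dict)


def make_fused_transform(up_transform_fn, down_transform_fn, up_ray_remote_args):
    def fused(blocks: Iterable, ctx: SketchTaskContext):
        # Record the upstream map so the shuffle tasks can apply it themselves,
        # with the same remote args the upstream map would have used.
        ctx.upstream_map_transform_fn = up_transform_fn
        ctx.upstream_map_ray_remote_args = up_ray_remote_args
        return down_transform_fn(blocks, ctx)

    return fused
```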
6 changes: 5 additions & 1 deletion python/ray/data/_internal/execution/legacy_compat.py
@@ -135,7 +135,11 @@ def _get_execution_dag(
record_operators_usage(plan._logical_plan.dag)

# Get DAG of physical operators and input statistics.
if DataContext.get_current().optimizer_enabled:
if (
DataContext.get_current().optimizer_enabled
# TODO(hchen): Remove this when all operators support logical plan.
and getattr(plan, "_logical_plan", None) is not None
):
dag = get_execution_plan(plan._logical_plan).dag
stats = _get_initial_stats_from_plan(plan)
else:
@@ -94,6 +94,9 @@ def __init__(
self._inputs_done = False
self._next_task_idx = 0

def get_init_fn(self) -> Callable[[], None]:
return self._init_fn

def internal_queue_size(self) -> int:
return len(self._bundle_queue)

77 changes: 57 additions & 20 deletions python/ray/data/_internal/logical/rules/operator_fusion.py
@@ -1,6 +1,13 @@
from typing import Iterator, List, Tuple

from ray.data._internal.logical.operators.all_to_all_operator import Repartition
from ray.data._internal.execution.operators.map_operator import MapOperator
from ray.data._internal.execution.operators.actor_pool_map_operator import (
ActorPoolMapOperator,
)
from ray.data._internal.execution.operators.task_pool_map_operator import (
TaskPoolMapOperator,
)
from ray.data._internal.logical.operators.all_to_all_operator import (
AbstractAllToAll,
RandomShuffle,
@@ -102,16 +109,22 @@ def _can_fuse(self, down_op: PhysicalOperator, up_op: PhysicalOperator) -> bool:
the same class AND constructor args are the same for both.
* They have compatible remote arguments.
"""
from ray.data._internal.execution.operators.map_operator import MapOperator
from ray.data._internal.logical.operators.map_operator import AbstractMap
from ray.data._internal.logical.operators.map_operator import AbstractUDFMap

# We currently only support fusing for the following cases:
# - MapOperator -> MapOperator
# - MapOperator -> AllToAllOperator
# - TaskPoolMapOperator -> TaskPoolMapOperator/ActorPoolMapOperator
# - TaskPoolMapOperator -> AllToAllOperator
# (only RandomShuffle and Repartition LogicalOperators are currently supported)
if not isinstance(down_op, (MapOperator, AllToAllOperator)) or not isinstance(
up_op, MapOperator
if not (
(
isinstance(up_op, TaskPoolMapOperator)
and isinstance(down_op, (TaskPoolMapOperator, ActorPoolMapOperator))
)
or (
isinstance(up_op, TaskPoolMapOperator)
and isinstance(down_op, AllToAllOperator)
)
):
return False

@@ -195,7 +208,11 @@ def _get_fused_map_operator(
up_logical_op = self._op_map.pop(up_op)

# Merge target block sizes.
down_target_block_size = down_logical_op._target_block_size
down_target_block_size = (
down_logical_op._target_block_size
if isinstance(down_logical_op, AbstractUDFMap)
else None
)
up_target_block_size = (
up_logical_op._target_block_size
if isinstance(up_logical_op, AbstractUDFMap)
@@ -219,10 +236,17 @@ def fused_map_transform_fn(
# TODO(Scott): Add zero-copy batching between transform functions.
return down_transform_fn(blocks, ctx)

# Fuse init functions.
fused_init_fn = (
down_op.get_init_fn() if isinstance(down_op, ActorPoolMapOperator) else None
)

# We take the downstream op's compute in case we're fusing upstream tasks with a
# downstream actor pool (e.g. read->map).
compute = get_compute(down_logical_op._compute)
ray_remote_args = down_logical_op._ray_remote_args
compute = None
if isinstance(down_logical_op, AbstractUDFMap):
compute = get_compute(down_logical_op._compute)
ray_remote_args = up_logical_op._ray_remote_args
# Make the upstream operator's inputs the new, fused operator's inputs.
input_deps = up_op.input_dependencies
assert len(input_deps) == 1
@@ -233,6 +257,7 @@ def fused_map_transform_fn(
fused_map_transform_fn,
input_op,
name=name,
init_fn=fused_init_fn,
compute_strategy=compute,
min_rows_per_bundle=target_block_size,
ray_remote_args=ray_remote_args,
@@ -287,6 +312,7 @@ def _get_fused_all_to_all_operator(
up_logical_op: AbstractUDFMap = self._op_map.pop(up_op)

# Fuse transformation functions.
ray_remote_args = up_logical_op._ray_remote_args
down_transform_fn = down_op.get_transformation_fn()
up_transform_fn = up_op.get_transformation_fn()

@@ -297,9 +323,9 @@ def fused_all_to_all_transform_fn(
in the TaskContext so that it may be used by the downstream
AllToAllOperator's transform function."""
ctx.upstream_map_transform_fn = up_transform_fn
ctx.upstream_map_ray_remote_args = ray_remote_args
return down_transform_fn(blocks, ctx)

ray_remote_args = down_logical_op._ray_remote_args
# Make the upstream operator's inputs the new, fused operator's inputs.
input_deps = up_op.input_dependencies
assert len(input_deps) == 1
@@ -330,18 +356,29 @@ def fused_all_to_all_transform_fn(
return op


def _are_remote_args_compatible(up_args, down_args):
def _are_remote_args_compatible(prev_args, next_args):
"""Check if Ray remote arguments are compatible for merging."""
from ray.data._internal.execution.operators.map_operator import (
_canonicalize_ray_remote_args,
)

up_args = _canonicalize_ray_remote_args(up_args)
down_args = _canonicalize_ray_remote_args(down_args)
remote_args = down_args.copy()
prev_args = _canonicalize(prev_args)
next_args = _canonicalize(next_args)
remote_args = next_args.copy()
for key in INHERITABLE_REMOTE_ARGS:
if key in up_args:
remote_args[key] = up_args[key]
if up_args != remote_args:
if key in prev_args:
remote_args[key] = prev_args[key]
if prev_args != remote_args:
return False
return True


def _canonicalize(remote_args: dict) -> dict:
"""Returns canonical form of given remote args."""
remote_args = remote_args.copy()
if "num_cpus" not in remote_args or remote_args["num_cpus"] is None:
remote_args["num_cpus"] = 1
if "num_gpus" not in remote_args or remote_args["num_gpus"] is None:
remote_args["num_gpus"] = 0
resources = remote_args.get("resources", {})
for k, v in list(resources.items()):
if v is None or v == 0.0:
del resources[k]
remote_args["resources"] = resources
return remote_args
1 change: 1 addition & 0 deletions python/ray/data/_internal/logical/util.py
@@ -56,6 +56,7 @@
"Aggregate",
# N-ary
"Zip",
"Union",
]


29 changes: 1 addition & 28 deletions python/ray/data/_internal/plan.py
@@ -34,6 +34,7 @@
from ray.data._internal.dataset_logger import DatasetLogger
from ray.data._internal.execution.interfaces import TaskContext
from ray.data._internal.lazy_block_list import LazyBlockList
from ray.data._internal.logical.rules.operator_fusion import _are_remote_args_compatible
from ray.data._internal.stats import DatasetStats, DatasetStatsSummary
from ray.data.block import Block
from ray.data.context import DataContext
@@ -1266,34 +1267,6 @@ def _fuse_one_to_one_stages(stages: List[Stage]) -> List[Stage]:
return fused_stages


def _are_remote_args_compatible(prev_args, next_args):
"""Check if Ray remote arguments are compatible for merging."""
prev_args = _canonicalize(prev_args)
next_args = _canonicalize(next_args)
remote_args = next_args.copy()
for key in INHERITABLE_REMOTE_ARGS:
if key in prev_args:
remote_args[key] = prev_args[key]
if prev_args != remote_args:
return False
return True


def _canonicalize(remote_args: dict) -> dict:
"""Returns canonical form of given remote args."""
remote_args = remote_args.copy()
if "num_cpus" not in remote_args or remote_args["num_cpus"] is None:
remote_args["num_cpus"] = 1
if "num_gpus" not in remote_args or remote_args["num_gpus"] is None:
remote_args["num_gpus"] = 0
resources = remote_args.get("resources", {})
for k, v in list(resources.items()):
if v is None or v == 0.0:
del resources[k]
remote_args["resources"] = resources
return remote_args


def _is_lazy(blocks: BlockList) -> bool:
"""Whether the provided block list is lazy."""
return isinstance(blocks, LazyBlockList)
@@ -75,7 +75,9 @@ def execute(
for j in range(output_num_blocks)
]

new_blocks, new_metadata = zip(*shuffle_reduce_out)
new_blocks, new_metadata = [], []
if shuffle_reduce_out:
new_blocks, new_metadata = zip(*shuffle_reduce_out)
new_metadata = reduce_bar.fetch_until_complete(list(new_metadata))
reduce_bar.close()
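
The guard above (and the identical one in the sort-merge hunk that follows) exists because unpacking `zip(*...)` of an empty list raises; a quick illustration of the failure mode being fixed:

```python
pairs = []  # e.g. no reduce outputs because the input dataset is empty
try:
    blocks, metadata = zip(*pairs)  # zip() yields nothing -> unpacking fails
except ValueError:
    blocks, metadata = [], []       # the behavior the new guard provides
```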

@@ -503,7 +503,10 @@ def merge(*args, **kwargs):
for i, block in enumerate(new_blocks)
]
sorted_blocks.sort(key=lambda x: x[0])
_, new_blocks, reduce_stage_metadata = zip(*sorted_blocks)

new_blocks, reduce_stage_metadata = [], []
if sorted_blocks:
_, new_blocks, reduce_stage_metadata = zip(*sorted_blocks)
del sorted_blocks

assert (
@@ -39,14 +39,19 @@ def execute(
if not ref_bundle.owns_blocks:
input_owned_by_consumer = False

# Compute the (output_num_blocks-1) indices needed for
# an equal split of the input blocks.
# Compute the (output_num_blocks) indices needed for an equal split of the
# input blocks. When output_num_blocks=1, the total number of
# input rows is used as the end index during the split calculation,
# so that we can combine all input blocks into a single output block.
indices = []
cur_idx = 0
for _ in range(output_num_blocks - 1):
cur_idx += input_num_rows / output_num_blocks
indices.append(int(cur_idx))
assert len(indices) < output_num_blocks, (indices, output_num_blocks)
if output_num_blocks == 1:
indices = [input_num_rows]
else:
cur_idx = 0
for _ in range(output_num_blocks - 1):
cur_idx += input_num_rows / output_num_blocks
indices.append(int(cur_idx))
assert len(indices) <= output_num_blocks, (indices, output_num_blocks)

if map_ray_remote_args is None:
map_ray_remote_args = {}
@@ -59,19 +64,13 @@
blocks_with_metadata: List[Tuple[ObjectRef[Block], BlockMetadata]] = []
for ref_bundle in refs:
blocks_with_metadata.extend(ref_bundle.blocks)
if indices:
split_return = _split_at_indices(
blocks_with_metadata, indices, input_owned_by_consumer
)
split_block_refs, split_metadata = [], []
for b, m in zip(*split_return):
split_block_refs.append(b)
split_metadata.extend(m)
else:
split_block_refs, split_metadata = [], []
for b, m in blocks_with_metadata:
split_block_refs.append([b])
split_metadata.append(m)
split_return = _split_at_indices(
blocks_with_metadata, indices, input_owned_by_consumer
)
split_block_refs, split_metadata = [], []
for b, m in zip(*split_return):
split_block_refs.append(b)
split_metadata.extend(m)

reduce_bar = ProgressBar("Split Repartition", total=output_num_blocks)
reduce_task = cached_remote_fn(self._exchange_spec.reduce)
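
A quick illustration (not the operator itself) of the equal-split index computation shown earlier in this file's diff: `output_num_blocks - 1` cut points in the general case, or a single `[input_num_rows]` index when all input is combined into one output block.

```python
from typing import List


def split_indices(input_num_rows: int, output_num_blocks: int) -> List[int]:
    # Mirrors the logic above: cumulative fractional steps, truncated to ints.
    if output_num_blocks == 1:
        return [input_num_rows]
    indices = []
    cur = 0.0
    for _ in range(output_num_blocks - 1):
        cur += input_num_rows / output_num_blocks
        indices.append(int(cur))
    return indices


print(split_indices(10, 3))  # [3, 6]
print(split_indices(10, 1))  # [10]
```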
6 changes: 3 additions & 3 deletions python/ray/data/_internal/planner/map_rows.py
@@ -8,9 +8,9 @@
from ray.data.context import DataContext


def generate_map_rows_fn() -> Callable[
[Iterator[Block], TaskContext, UserDefinedFunction], Iterator[Block]
]:
def generate_map_rows_fn() -> (
Callable[[Iterator[Block], TaskContext, UserDefinedFunction], Iterator[Block]]
):
"""Generate function to apply the UDF to each record of blocks."""

context = DataContext.get_current()
2 changes: 1 addition & 1 deletion python/ray/data/_internal/planner/plan_all_to_all_op.py
@@ -27,7 +27,7 @@ def _plan_all_to_all_op(
if isinstance(op, RandomizeBlocks):
fn = generate_randomize_blocks_fn(op._seed)
elif isinstance(op, RandomShuffle):
fn = generate_random_shuffle_fn(op._seed, op._num_outputs)
fn = generate_random_shuffle_fn(op._seed, op._num_outputs, op._ray_remote_args)
elif isinstance(op, Repartition):
fn = generate_repartition_fn(op._num_outputs, op._shuffle)
elif isinstance(op, Sort):
2 changes: 1 addition & 1 deletion python/ray/data/_internal/planner/plan_from_arrow_op.py
@@ -22,7 +22,7 @@ def get_input_data() -> List[RefBundle]:
get_metadata = cached_remote_fn(get_table_block_metadata)
metadata = ray.get([get_metadata.remote(t) for t in op._tables])
ref_bundles: List[RefBundle] = [
RefBundle([(table_ref, block_metadata)], owns_blocks=True)
RefBundle([(table_ref, block_metadata)], owns_blocks=False)
for table_ref, block_metadata in zip(op._tables, metadata)
]
return ref_bundles
11 changes: 9 additions & 2 deletions python/ray/data/_internal/planner/plan_from_pandas_op.py
@@ -64,16 +64,23 @@ def get_input_data() -> List[RefBundle]:
get_table_block_metadata,
)

owns_blocks = True
if isinstance(op, FromDask):
if isinstance(op, FromPandasRefs):
# Data is already put into the Ray object store.
# So owns_blocks should be False.
owns_blocks = False
elif isinstance(op, FromDask):
_init_data_from_dask(op)
owns_blocks = True
elif isinstance(op, FromModin):
_init_data_from_modin(op)
owns_blocks = True
elif isinstance(op, FromMars):
_init_data_from_mars(op)
# MARS holds the MARS dataframe in memory in `to_ray_dataset()`
# to avoid object GC, so this operator cannot own the blocks.
owns_blocks = False
else:
raise ValueError(f"Unsupported operator type: {type(op)}")

context = DataContext.get_current()
