[Data] [2/N] Enable optimizer: fix fusion #35621

Merged 12 commits on May 31, 2023
15 changes: 14 additions & 1 deletion python/ray/data/_internal/execution/interfaces.py
@@ -1,6 +1,6 @@
from dataclasses import dataclass, field
import os
from typing import Dict, List, Optional, Iterable, Iterator, Tuple, Callable, Union
from typing import Any, Dict, List, Optional, Iterable, Iterator, Tuple, Callable, Union

import ray
from ray.util.annotations import DeveloperAPI
@@ -233,10 +233,23 @@ class TaskContext:
# TODO(chengsu): clean it up from TaskContext with new optimizer framework.
sub_progress_bar_dict: Optional[Dict[str, ProgressBar]] = None

# NOTE(hchen): `upstream_map_transform_fn` and `upstream_map_ray_remote_args`
# are only used for `RandomShuffle`. DO NOT use them for other operators.
# Ideally, they should be handled by the optimizer, and should be transparent
# to the specific operators.
# But for `RandomShuffle`, the AllToAllOperator doesn't do the shuffle itself.
# It uses `ExchangeTaskScheduler` to launch new tasks to do the shuffle.
# That's why we need to pass them to `ExchangeTaskScheduler`.
# TODO(hchen): Use a physical operator to do the shuffle directly.

# The underlying function called in a MapOperator; this is used when fusing
# an AllToAllOperator with an upstream MapOperator.
upstream_map_transform_fn: Optional["MapTransformFn"] = None

# The Ray remote arguments of the fused upstream MapOperator.
# This should be set if upstream_map_transform_fn is set.
    upstream_map_ray_remote_args: Optional[Dict[str, Any]] = None
Contributor:
Hmm, it's not ideal to pass this at runtime. Ideally, the optimizer would rewrite the downstream op's ray remote args to this value, instead of having each operator decide which of the two args to use by inspecting the context.

Contributor Author (@raulchen, May 31, 2023):

I agree with you, but currently this is hard to avoid. For most operators, we already do it the way you mentioned.
upstream_map_ray_remote_args, along with upstream_map_transform_fn, is used only for RandomShuffle, because the corresponding AllToAllOperator physical op doesn't do the shuffle itself; instead, it uses ExchangeTaskScheduler to launch new tasks to do the shuffle. That's why we need this ad-hoc handling here. I'll add a TODO.

Update: see generate_random_shuffle_fn for more details.
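
To make the data flow concrete, here is a minimal sketch of how a shuffle-side consumer might use these two TaskContext fields. The helper below is illustrative only (stand-in types, not the actual ExchangeTaskScheduler code):

```python
from typing import Any, Dict, Iterable, List, Optional

Block = List[Any]  # stand-in for ray.data.block.Block


class FakeTaskContext:
    """Illustrative stand-in for TaskContext with the two new fields."""

    def __init__(self) -> None:
        self.upstream_map_transform_fn = None
        self.upstream_map_ray_remote_args: Optional[Dict[str, Any]] = None


def shuffle_map(blocks: Iterable[Block], ctx: FakeTaskContext) -> Iterable[Block]:
    # If an upstream MapOperator was fused into this RandomShuffle, apply its
    # transform first; the scheduler would also launch this task with
    # ctx.upstream_map_ray_remote_args so the fused map keeps its resources.
    if ctx.upstream_map_transform_fn is not None:
        blocks = ctx.upstream_map_transform_fn(blocks, ctx)
    return blocks
```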



# Block transform function applied by task and actor pools in MapOperator.
MapTransformFn = Callable[[Iterable[Block], TaskContext], Iterable[Block]]
6 changes: 5 additions & 1 deletion python/ray/data/_internal/execution/legacy_compat.py
@@ -135,7 +135,11 @@ def _get_execution_dag(
record_operators_usage(plan._logical_plan.dag)

# Get DAG of physical operators and input statistics.
if DataContext.get_current().optimizer_enabled:
if (
DataContext.get_current().optimizer_enabled
# TODO(hchen): Remove this when all operators support logical plan.
and getattr(plan, "_logical_plan", None) is not None
):
dag = get_execution_plan(plan._logical_plan).dag
stats = _get_initial_stats_from_plan(plan)
else:
@@ -94,6 +94,9 @@ def __init__(
self._inputs_done = False
self._next_task_idx = 0

def get_init_fn(self) -> Callable[[], None]:
return self._init_fn

def internal_queue_size(self) -> int:
return len(self._bundle_queue)

77 changes: 57 additions & 20 deletions python/ray/data/_internal/logical/rules/operator_fusion.py
@@ -1,6 +1,13 @@
from typing import Iterator, List, Tuple

from ray.data._internal.logical.operators.all_to_all_operator import Repartition
from ray.data._internal.execution.operators.map_operator import MapOperator
from ray.data._internal.execution.operators.actor_pool_map_operator import (
ActorPoolMapOperator,
)
from ray.data._internal.execution.operators.task_pool_map_operator import (
TaskPoolMapOperator,
)
from ray.data._internal.logical.operators.all_to_all_operator import (
AbstractAllToAll,
RandomShuffle,
@@ -102,16 +109,22 @@ def _can_fuse(self, down_op: PhysicalOperator, up_op: PhysicalOperator) -> bool:
the same class AND constructor args are the same for both.
* They have compatible remote arguments.
"""
from ray.data._internal.execution.operators.map_operator import MapOperator
from ray.data._internal.logical.operators.map_operator import AbstractMap
from ray.data._internal.logical.operators.map_operator import AbstractUDFMap

# We currently only support fusing for the following cases:
# - MapOperator -> MapOperator
# - MapOperator -> AllToAllOperator
# - TaskPoolMapOperator -> TaskPoolMapOperator/ActorPoolMapOperator
# - TaskPoolMapOperator -> AllToAllOperator
# (only RandomShuffle and Repartition LogicalOperators are currently supported)
if not isinstance(down_op, (MapOperator, AllToAllOperator)) or not isinstance(
up_op, MapOperator
if not (
(
isinstance(up_op, TaskPoolMapOperator)
Contributor:
Nit: should we combine the two cases into a single isinstance check on down_op?

Also, I recall discussing that we will potentially not support this for the actor case?

Contributor Author:

Having two separate cases looks clearer to me, but I don't have a strong preference.
We are dropping support for the actor->actor case; task->actor is still supported.

and isinstance(down_op, (TaskPoolMapOperator, ActorPoolMapOperator))
)
or (
isinstance(up_op, TaskPoolMapOperator)
and isinstance(down_op, AllToAllOperator)
)
):
return False
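
As a rough illustration, pipelines like the following would exercise these fusion paths (a sketch; actual fusion decisions also depend on remote-arg compatibility, checked by _are_remote_args_compatible below):

```python
import ray

ds = ray.data.range(1000)

# TaskPoolMapOperator -> TaskPoolMapOperator: two task-based maps can fuse
# into a single task per block.
ds = ds.map(lambda row: row).map(lambda row: row)

# TaskPoolMapOperator -> AllToAllOperator: a task-based map can fuse into
# the shuffle's map stage (RandomShuffle/Repartition only).
ds = ds.random_shuffle()
```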

@@ -195,7 +208,11 @@ def _get_fused_map_operator(
up_logical_op = self._op_map.pop(up_op)

# Merge target block sizes.
down_target_block_size = down_logical_op._target_block_size
down_target_block_size = (
down_logical_op._target_block_size
if isinstance(down_logical_op, AbstractUDFMap)
else None
)
up_target_block_size = (
up_logical_op._target_block_size
if isinstance(up_logical_op, AbstractUDFMap)
@@ -219,10 +236,17 @@ def fused_map_transform_fn(
# TODO(Scott): Add zero-copy batching between transform functions.
return down_transform_fn(blocks, ctx)

        # Fuse init functions.
fused_init_fn = (
down_op.get_init_fn() if isinstance(down_op, ActorPoolMapOperator) else None
)

# We take the downstream op's compute in case we're fusing upstream tasks with a
# downstream actor pool (e.g. read->map).
compute = get_compute(down_logical_op._compute)
ray_remote_args = down_logical_op._ray_remote_args
compute = None
if isinstance(down_logical_op, AbstractUDFMap):
compute = get_compute(down_logical_op._compute)
ray_remote_args = up_logical_op._ray_remote_args
# Make the upstream operator's inputs the new, fused operator's inputs.
input_deps = up_op.input_dependencies
assert len(input_deps) == 1
@@ -233,6 +257,7 @@
fused_map_transform_fn,
input_op,
name=name,
init_fn=fused_init_fn,
compute_strategy=compute,
min_rows_per_bundle=target_block_size,
ray_remote_args=ray_remote_args,
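
Conceptually, the fused transform is plain composition of the two block-transform functions within a single task; a minimal self-contained sketch (hypothetical helper, not the exact code above):

```python
from typing import Any, Callable, Iterable, List

Block = List[Any]
MapTransformFn = Callable[[Iterable[Block], Any], Iterable[Block]]


def fuse_transforms(up_fn: MapTransformFn, down_fn: MapTransformFn) -> MapTransformFn:
    def fused(blocks: Iterable[Block], ctx: Any) -> Iterable[Block]:
        # Upstream output blocks feed straight into the downstream transform,
        # all within one task (no extra object-store round trip).
        return down_fn(up_fn(blocks, ctx), ctx)

    return fused


double = lambda blocks, ctx: ([x * 2 for x in b] for b in blocks)
inc = lambda blocks, ctx: ([x + 1 for x in b] for b in blocks)
assert [list(b) for b in fuse_transforms(double, inc)([[1, 2]], None)] == [[3, 5]]
```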
@@ -287,6 +312,7 @@ def _get_fused_all_to_all_operator(
up_logical_op: AbstractUDFMap = self._op_map.pop(up_op)

# Fuse transformation functions.
ray_remote_args = up_logical_op._ray_remote_args
down_transform_fn = down_op.get_transformation_fn()
up_transform_fn = up_op.get_transformation_fn()

@@ -297,9 +323,9 @@ def fused_all_to_all_transform_fn(
in the TaskContext so that it may be used by the downstream
AllToAllOperator's transform function."""
ctx.upstream_map_transform_fn = up_transform_fn
ctx.upstream_map_ray_remote_args = ray_remote_args
return down_transform_fn(blocks, ctx)

ray_remote_args = down_logical_op._ray_remote_args
# Make the upstream operator's inputs the new, fused operator's inputs.
input_deps = up_op.input_dependencies
assert len(input_deps) == 1
@@ -330,18 +356,29 @@ def fused_all_to_all_transform_fn(
return op


def _are_remote_args_compatible(up_args, down_args):
def _are_remote_args_compatible(prev_args, next_args):
Contributor:
Would it be possible to add unit tests for this function?

Contributor Author:
Good idea. Added in test_read_map_batches_operator_fusion_compatible_remote_args and test_read_map_batches_operator_fusion_incompatible_remote_args.

"""Check if Ray remote arguments are compatible for merging."""
from ray.data._internal.execution.operators.map_operator import (
_canonicalize_ray_remote_args,
)

up_args = _canonicalize_ray_remote_args(up_args)
down_args = _canonicalize_ray_remote_args(down_args)
remote_args = down_args.copy()
prev_args = _canonicalize(prev_args)
next_args = _canonicalize(next_args)
remote_args = next_args.copy()
for key in INHERITABLE_REMOTE_ARGS:
if key in up_args:
remote_args[key] = up_args[key]
if up_args != remote_args:
if key in prev_args:
remote_args[key] = prev_args[key]
if prev_args != remote_args:
return False
return True


def _canonicalize(remote_args: dict) -> dict:
"""Returns canonical form of given remote args."""
remote_args = remote_args.copy()
if "num_cpus" not in remote_args or remote_args["num_cpus"] is None:
remote_args["num_cpus"] = 1
if "num_gpus" not in remote_args or remote_args["num_gpus"] is None:
remote_args["num_gpus"] = 0
resources = remote_args.get("resources", {})
for k, v in list(resources.items()):
if v is None or v == 0.0:
del resources[k]
remote_args["resources"] = resources
return remote_args
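
For intuition, a few hedged examples of the intended behavior of the functions above (assuming, as in this module, that INHERITABLE_REMOTE_ARGS contains "scheduling_strategy" but not "num_gpus"):

```python
# Compatible: an unset num_cpus canonicalizes to the default of 1.
assert _are_remote_args_compatible({}, {"num_cpus": 1})

# Compatible: an inheritable key may differ, since the upstream value wins.
assert _are_remote_args_compatible(
    {"scheduling_strategy": "SPREAD"}, {"scheduling_strategy": "DEFAULT"}
)

# Incompatible: a non-inheritable key differs after canonicalization.
assert not _are_remote_args_compatible({"num_gpus": 1}, {})
```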
1 change: 1 addition & 0 deletions python/ray/data/_internal/logical/util.py
@@ -56,6 +56,7 @@
"Aggregate",
# N-ary
"Zip",
"Union",
]


29 changes: 1 addition & 28 deletions python/ray/data/_internal/plan.py
@@ -34,6 +34,7 @@
from ray.data._internal.dataset_logger import DatasetLogger
from ray.data._internal.execution.interfaces import TaskContext
from ray.data._internal.lazy_block_list import LazyBlockList
from ray.data._internal.logical.rules.operator_fusion import _are_remote_args_compatible
from ray.data._internal.stats import DatasetStats, DatasetStatsSummary
from ray.data.block import Block
from ray.data.context import DataContext
@@ -1266,34 +1267,6 @@ def _fuse_one_to_one_stages(stages: List[Stage]) -> List[Stage]:
return fused_stages


def _are_remote_args_compatible(prev_args, next_args):
"""Check if Ray remote arguments are compatible for merging."""
prev_args = _canonicalize(prev_args)
next_args = _canonicalize(next_args)
remote_args = next_args.copy()
for key in INHERITABLE_REMOTE_ARGS:
if key in prev_args:
remote_args[key] = prev_args[key]
if prev_args != remote_args:
return False
return True


def _canonicalize(remote_args: dict) -> dict:
"""Returns canonical form of given remote args."""
remote_args = remote_args.copy()
if "num_cpus" not in remote_args or remote_args["num_cpus"] is None:
remote_args["num_cpus"] = 1
if "num_gpus" not in remote_args or remote_args["num_gpus"] is None:
remote_args["num_gpus"] = 0
resources = remote_args.get("resources", {})
for k, v in list(resources.items()):
if v is None or v == 0.0:
del resources[k]
remote_args["resources"] = resources
return remote_args


def _is_lazy(blocks: BlockList) -> bool:
"""Whether the provided block list is lazy."""
return isinstance(blocks, LazyBlockList)
@@ -75,7 +75,9 @@ def execute(
for j in range(output_num_blocks)
]

new_blocks, new_metadata = zip(*shuffle_reduce_out)
new_blocks, new_metadata = [], []
if shuffle_reduce_out:
new_blocks, new_metadata = zip(*shuffle_reduce_out)
new_metadata = reduce_bar.fetch_until_complete(list(new_metadata))
reduce_bar.close()
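
The guard is needed because unpacking an empty zip raises; the same pattern protects the sort path below. A minimal reproduction of the failure mode the change avoids:

```python
shuffle_reduce_out = []  # e.g., a shuffle that produced zero output blocks
try:
    new_blocks, new_metadata = zip(*shuffle_reduce_out)
except ValueError:
    # zip(*[]) is an empty iterator, so two-target unpacking fails with
    # "not enough values to unpack (expected 2, got 0)".
    new_blocks, new_metadata = [], []
assert new_blocks == [] and new_metadata == []
```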

@@ -503,7 +503,10 @@ def merge(*args, **kwargs):
for i, block in enumerate(new_blocks)
]
sorted_blocks.sort(key=lambda x: x[0])
_, new_blocks, reduce_stage_metadata = zip(*sorted_blocks)

new_blocks, reduce_stage_metadata = [], []
if sorted_blocks:
_, new_blocks, reduce_stage_metadata = zip(*sorted_blocks)
del sorted_blocks

assert (
@@ -39,14 +39,19 @@ def execute(
if not ref_bundle.owns_blocks:
input_owned_by_consumer = False

# Compute the (output_num_blocks-1) indices needed for
# an equal split of the input blocks.
# Compute the (output_num_blocks) indices needed for an equal split of the
# input blocks. When output_num_blocks=1, the total number of
# input rows is used as the end index during the split calculation,
# so that we can combine all input blocks into a single output block.
indices = []
cur_idx = 0
for _ in range(output_num_blocks - 1):
cur_idx += input_num_rows / output_num_blocks
indices.append(int(cur_idx))
assert len(indices) < output_num_blocks, (indices, output_num_blocks)
if output_num_blocks == 1:
indices = [input_num_rows]
else:
cur_idx = 0
for _ in range(output_num_blocks - 1):
cur_idx += input_num_rows / output_num_blocks
indices.append(int(cur_idx))
assert len(indices) <= output_num_blocks, (indices, output_num_blocks)

if map_ray_remote_args is None:
map_ray_remote_args = {}
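
A worked example of the index computation (a self-contained sketch mirroring the loop above):

```python
def split_indices(input_num_rows: int, output_num_blocks: int) -> list:
    if output_num_blocks == 1:
        # Use the total row count as the end index so all input blocks are
        # combined into a single output block.
        return [input_num_rows]
    indices, cur_idx = [], 0.0
    for _ in range(output_num_blocks - 1):
        cur_idx += input_num_rows / output_num_blocks
        indices.append(int(cur_idx))
    return indices


assert split_indices(10, 3) == [3, 6]  # block sizes: 3, 3, 4
assert split_indices(10, 1) == [10]    # one block containing all 10 rows
```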
@@ -59,19 +64,13 @@
blocks_with_metadata: List[Tuple[ObjectRef[Block], BlockMetadata]] = []
for ref_bundle in refs:
blocks_with_metadata.extend(ref_bundle.blocks)
if indices:
split_return = _split_at_indices(
blocks_with_metadata, indices, input_owned_by_consumer
)
split_block_refs, split_metadata = [], []
for b, m in zip(*split_return):
split_block_refs.append(b)
split_metadata.extend(m)
else:
split_block_refs, split_metadata = [], []
for b, m in blocks_with_metadata:
split_block_refs.append([b])
split_metadata.append(m)
split_return = _split_at_indices(
blocks_with_metadata, indices, input_owned_by_consumer
)
split_block_refs, split_metadata = [], []
for b, m in zip(*split_return):
split_block_refs.append(b)
split_metadata.extend(m)

reduce_bar = ProgressBar("Split Repartition", total=output_num_blocks)
reduce_task = cached_remote_fn(self._exchange_spec.reduce)
6 changes: 3 additions & 3 deletions python/ray/data/_internal/planner/map_rows.py
@@ -8,9 +8,9 @@
from ray.data.context import DataContext


def generate_map_rows_fn() -> Callable[
[Iterator[Block], TaskContext, UserDefinedFunction], Iterator[Block]
]:
def generate_map_rows_fn() -> (
Callable[[Iterator[Block], TaskContext, UserDefinedFunction], Iterator[Block]]
):
"""Generate function to apply the UDF to each record of blocks."""

context = DataContext.get_current()
2 changes: 1 addition & 1 deletion python/ray/data/_internal/planner/plan_all_to_all_op.py
@@ -27,7 +27,7 @@ def _plan_all_to_all_op(
if isinstance(op, RandomizeBlocks):
fn = generate_randomize_blocks_fn(op._seed)
elif isinstance(op, RandomShuffle):
fn = generate_random_shuffle_fn(op._seed, op._num_outputs)
fn = generate_random_shuffle_fn(op._seed, op._num_outputs, op._ray_remote_args)
elif isinstance(op, Repartition):
fn = generate_repartition_fn(op._num_outputs, op._shuffle)
elif isinstance(op, Sort):
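Based on this call site, the generator's signature now carries the operator's remote args. A hypothetical sketch of the shape (names and module layout are assumptions inferred from the call above, not the real implementation):

```python
from typing import Any, Callable, Dict, Optional


def generate_random_shuffle_fn(
    seed: Optional[int],
    num_outputs: Optional[int],
    ray_remote_args: Optional[Dict[str, Any]] = None,
) -> Callable:
    # Hypothetical: the remote args travel with the shuffle fn so the exchange
    # scheduler can launch shuffle tasks with them (or with
    # ctx.upstream_map_ray_remote_args when an upstream map was fused in).
    def fn(refs, ctx):
        args = (
            getattr(ctx, "upstream_map_ray_remote_args", None)
            or ray_remote_args
            or {}
        )
        _ = args  # the real implementation passes these to the task scheduler
        return refs

    return fn
```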
2 changes: 1 addition & 1 deletion python/ray/data/_internal/planner/plan_from_arrow_op.py
@@ -22,7 +22,7 @@ def get_input_data() -> List[RefBundle]:
get_metadata = cached_remote_fn(get_table_block_metadata)
metadata = ray.get([get_metadata.remote(t) for t in op._tables])
ref_bundles: List[RefBundle] = [
RefBundle([(table_ref, block_metadata)], owns_blocks=True)
RefBundle([(table_ref, block_metadata)], owns_blocks=False)
Contributor:

What are all the owns_blocks changes for?

Contributor Author:

The blocks are put into the object store inside the FromArrowRefs op, so this RefBundle shouldn't own the blocks. This was a bug. This function is used only for the optimizer code path.

for table_ref, block_metadata in zip(op._tables, metadata)
]
return ref_bundles
11 changes: 9 additions & 2 deletions python/ray/data/_internal/planner/plan_from_pandas_op.py
@@ -64,16 +64,23 @@ def get_input_data() -> List[RefBundle]:
get_table_block_metadata,
)

owns_blocks = True
if isinstance(op, FromDask):
if isinstance(op, FromPandasRefs):
        # Data is already put into the Ray object store.
# So owns_blocks should be False.
owns_blocks = False
elif isinstance(op, FromDask):
_init_data_from_dask(op)
owns_blocks = True
elif isinstance(op, FromModin):
_init_data_from_modin(op)
owns_blocks = True
elif isinstance(op, FromMars):
_init_data_from_mars(op)
# MARS holds the MARS dataframe in memory in `to_ray_dataset()`
        # to avoid object GC, so this operator cannot own the blocks.
owns_blocks = False
else:
raise ValueError(f"Unsupported operator type: {type(op)}")
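
To summarize the ownership rule in one place, a hedged sketch of the decision encoded by the branches above (stub classes stand in for the real logical operators):

```python
class FromPandasRefs: ...
class FromDask: ...
class FromModin: ...
class FromMars: ...


def bundle_owns_blocks(op) -> bool:
    # Blocks this operator itself creates (Dask/Modin conversions) may be
    # owned and eagerly freed; blocks the user already holds references to
    # (FromPandasRefs), or that are pinned elsewhere (Mars), must not be.
    if isinstance(op, (FromDask, FromModin)):
        return True
    if isinstance(op, (FromPandasRefs, FromMars)):
        return False
    raise ValueError(f"Unsupported operator type: {type(op)}")


assert bundle_owns_blocks(FromDask()) and not bundle_owns_blocks(FromMars())
```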

context = DataContext.get_current()
