[Data] [2/N] Enable optimizer: fix fusion #35621
Changes from all commits
@@ -1,6 +1,13 @@
from typing import Iterator, List, Tuple

from ray.data._internal.logical.operators.all_to_all_operator import Repartition
from ray.data._internal.execution.operators.map_operator import MapOperator
from ray.data._internal.execution.operators.actor_pool_map_operator import (
    ActorPoolMapOperator,
)
from ray.data._internal.execution.operators.task_pool_map_operator import (
    TaskPoolMapOperator,
)
from ray.data._internal.logical.operators.all_to_all_operator import (
    AbstractAllToAll,
    RandomShuffle,
@@ -102,16 +109,22 @@ def _can_fuse(self, down_op: PhysicalOperator, up_op: PhysicalOperator) -> bool:
          the same class AND constructor args are the same for both.
        * They have compatible remote arguments.
        """
        from ray.data._internal.execution.operators.map_operator import MapOperator
        from ray.data._internal.logical.operators.map_operator import AbstractMap
        from ray.data._internal.logical.operators.map_operator import AbstractUDFMap

        # We currently only support fusing for the following cases:
        # - MapOperator -> MapOperator
        # - MapOperator -> AllToAllOperator
        # - TaskPoolMapOperator -> TaskPoolMapOperator/ActorPoolMapOperator
        # - TaskPoolMapOperator -> AllToAllOperator
        #   (only RandomShuffle and Repartition LogicalOperators are currently supported)
        if not isinstance(down_op, (MapOperator, AllToAllOperator)) or not isinstance(
            up_op, MapOperator
        if not (
            (
                isinstance(up_op, TaskPoolMapOperator)

Review comment: Nit: should we combine the two cases into a single isinstance check on down_op? Also, I recall discussing that we will potentially not support this for the actor case?

Reply: Having two separate cases looks clearer to me, but I don't have a strong preference.

                and isinstance(down_op, (TaskPoolMapOperator, ActorPoolMapOperator))
            )
            or (
                isinstance(up_op, TaskPoolMapOperator)
                and isinstance(down_op, AllToAllOperator)
            )
        ):
            return False
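For orientation, here is a hedged, standalone restatement of the new type check, using stand-in operator classes rather than Ray's so the logic can be exercised without the Ray Data internals:

```python
# Minimal sketch with stand-in classes; only the isinstance structure mirrors
# the diff above, everything else is illustrative.
class PhysicalOperator: ...
class AllToAllOperator(PhysicalOperator): ...
class MapOperator(PhysicalOperator): ...
class TaskPoolMapOperator(MapOperator): ...
class ActorPoolMapOperator(MapOperator): ...

def can_fuse_types(up_op: PhysicalOperator, down_op: PhysicalOperator) -> bool:
    # Tasks can fuse into a task/actor pool or into an all-to-all op;
    # an actor-pool upstream is never fused.
    return (
        isinstance(up_op, TaskPoolMapOperator)
        and isinstance(down_op, (TaskPoolMapOperator, ActorPoolMapOperator))
    ) or (
        isinstance(up_op, TaskPoolMapOperator)
        and isinstance(down_op, AllToAllOperator)
    )

assert can_fuse_types(TaskPoolMapOperator(), ActorPoolMapOperator())
assert not can_fuse_types(ActorPoolMapOperator(), TaskPoolMapOperator())
```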
@@ -195,7 +208,11 @@ def _get_fused_map_operator(
        up_logical_op = self._op_map.pop(up_op)

        # Merge target block sizes.
        down_target_block_size = down_logical_op._target_block_size
        down_target_block_size = (
            down_logical_op._target_block_size
            if isinstance(down_logical_op, AbstractUDFMap)
            else None
        )
        up_target_block_size = (
            up_logical_op._target_block_size
            if isinstance(up_logical_op, AbstractUDFMap)
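The actual merge of the two target block sizes falls outside the shown hunk; a plausible sketch, assuming the fused operator takes the larger of the two explicit settings and treats None as "unset":

```python
from typing import Optional

def merge_target_block_sizes(
    down_size: Optional[int], up_size: Optional[int]
) -> Optional[int]:
    # Assumed rule: keep the larger explicit setting; None means the
    # operator did not configure a target block size.
    if down_size is not None and up_size is not None:
        return max(down_size, up_size)
    return down_size if down_size is not None else up_size

assert merge_target_block_sizes(None, 128) == 128
assert merge_target_block_sizes(64, 128) == 128
assert merge_target_block_sizes(None, None) is None
```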
@@ -219,10 +236,17 @@ def fused_map_transform_fn(
            # TODO(Scott): Add zero-copy batching between transform functions.
            return down_transform_fn(blocks, ctx)

        # Fuse init functions.
        fused_init_fn = (
            down_op.get_init_fn() if isinstance(down_op, ActorPoolMapOperator) else None
        )

        # We take the downstream op's compute in case we're fusing upstream tasks with a
        # downstream actor pool (e.g. read->map).
        compute = get_compute(down_logical_op._compute)
        ray_remote_args = down_logical_op._ray_remote_args
        compute = None
        if isinstance(down_logical_op, AbstractUDFMap):
            compute = get_compute(down_logical_op._compute)
        ray_remote_args = up_logical_op._ray_remote_args
        # Make the upstream operator's inputs the new, fused operator's inputs.
        input_deps = up_op.input_dependencies
        assert len(input_deps) == 1
@@ -233,6 +257,7 @@ def fused_map_transform_fn(
            fused_map_transform_fn,
            input_op,
            name=name,
            init_fn=fused_init_fn,
            compute_strategy=compute,
            min_rows_per_bundle=target_block_size,
            ray_remote_args=ray_remote_args,
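The hunks above wire the fused MapOperator together: the downstream op contributes its init_fn and compute strategy, the upstream op contributes its remote args, and the two transform functions are chained. A hedged sketch of just the chaining step, with simplified stand-in types for Ray Data's blocks and context:

```python
from typing import Callable, Iterable, Iterator, List

# Hedged sketch of transform-function composition: the upstream transform runs
# first and its output blocks feed straight into the downstream transform.
Block = List[int]
TransformFn = Callable[[Iterable[Block], dict], Iterator[Block]]

def fuse_transforms(up_fn: TransformFn, down_fn: TransformFn) -> TransformFn:
    def fused(blocks: Iterable[Block], ctx: dict) -> Iterator[Block]:
        return down_fn(up_fn(blocks, ctx), ctx)
    return fused

def double(blocks, ctx):
    return ([x * 2 for x in b] for b in blocks)

def add_one(blocks, ctx):
    return ([x + 1 for x in b] for b in blocks)

fused = fuse_transforms(double, add_one)
assert list(fused([[1, 2]], {})) == [[3, 5]]
```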
@@ -287,6 +312,7 @@ def _get_fused_all_to_all_operator(
        up_logical_op: AbstractUDFMap = self._op_map.pop(up_op)

        # Fuse transformation functions.
        ray_remote_args = up_logical_op._ray_remote_args
        down_transform_fn = down_op.get_transformation_fn()
        up_transform_fn = up_op.get_transformation_fn()
@@ -297,9 +323,9 @@ def fused_all_to_all_transform_fn(
            in the TaskContext so that it may be used by the downstream
            AllToAllOperator's transform function."""
            ctx.upstream_map_transform_fn = up_transform_fn
            ctx.upstream_map_ray_remote_args = ray_remote_args
            return down_transform_fn(blocks, ctx)

        ray_remote_args = down_logical_op._ray_remote_args
        # Make the upstream operator's inputs the new, fused operator's inputs.
        input_deps = up_op.input_dependencies
        assert len(input_deps) == 1
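A hedged sketch of the context-stashing pattern used by fused_all_to_all_transform_fn above, with a stand-in task context in place of Ray Data's TaskContext:

```python
from dataclasses import dataclass, field
from typing import Callable, Optional

# Stand-in context for illustration; the real TaskContext lives in Ray Data.
@dataclass
class FakeTaskContext:
    upstream_map_transform_fn: Optional[Callable] = None
    upstream_map_ray_remote_args: dict = field(default_factory=dict)

def make_fused_all_to_all_fn(up_fn, down_fn, up_remote_args):
    def fused(blocks, ctx: FakeTaskContext):
        # Record the upstream map work on the context; the all-to-all
        # implementation applies it later inside its own tasks.
        ctx.upstream_map_transform_fn = up_fn
        ctx.upstream_map_ray_remote_args = up_remote_args
        return down_fn(blocks, ctx)
    return fused

ctx = FakeTaskContext()
fused = make_fused_all_to_all_fn(
    up_fn=lambda blocks, ctx: blocks,
    down_fn=lambda blocks, ctx: blocks,
    up_remote_args={"num_cpus": 2},
)
assert fused([[1]], ctx) == [[1]]
assert ctx.upstream_map_ray_remote_args == {"num_cpus": 2}
```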
@@ -330,18 +356,29 @@ def fused_all_to_all_transform_fn( | |
return op | ||
|
||
|
||
def _are_remote_args_compatible(up_args, down_args): | ||
def _are_remote_args_compatible(prev_args, next_args): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. would it be possible to add unit tests for this function? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good idea. Added in |
||
"""Check if Ray remote arguments are compatible for merging.""" | ||
from ray.data._internal.execution.operators.map_operator import ( | ||
_canonicalize_ray_remote_args, | ||
) | ||
|
||
up_args = _canonicalize_ray_remote_args(up_args) | ||
down_args = _canonicalize_ray_remote_args(down_args) | ||
remote_args = down_args.copy() | ||
prev_args = _canonicalize(prev_args) | ||
next_args = _canonicalize(next_args) | ||
remote_args = next_args.copy() | ||
for key in INHERITABLE_REMOTE_ARGS: | ||
if key in up_args: | ||
remote_args[key] = up_args[key] | ||
if up_args != remote_args: | ||
if key in prev_args: | ||
remote_args[key] = prev_args[key] | ||
if prev_args != remote_args: | ||
return False | ||
return True | ||
|
||
|
||
def _canonicalize(remote_args: dict) -> dict: | ||
"""Returns canonical form of given remote args.""" | ||
remote_args = remote_args.copy() | ||
if "num_cpus" not in remote_args or remote_args["num_cpus"] is None: | ||
remote_args["num_cpus"] = 1 | ||
if "num_gpus" not in remote_args or remote_args["num_gpus"] is None: | ||
remote_args["num_gpus"] = 0 | ||
resources = remote_args.get("resources", {}) | ||
for k, v in list(resources.items()): | ||
if v is None or v == 0.0: | ||
del resources[k] | ||
remote_args["resources"] = resources | ||
return remote_args |
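To show what canonicalization buys the compatibility check (and the kind of unit test the review asks for), here is a hedged, self-contained restatement of the two helpers; INHERITABLE_REMOTE_ARGS is assumed here to hold scheduling-related keys, while the real constant is defined elsewhere in the module:

```python
# Assumed contents; the real INHERITABLE_REMOTE_ARGS lives elsewhere.
INHERITABLE_REMOTE_ARGS = ["scheduling_strategy"]

def canonicalize(remote_args: dict) -> dict:
    # Same normalization as _canonicalize above: default CPUs/GPUs and drop
    # empty resource requests so "unspecified" compares equal to the default.
    out = dict(remote_args)
    if out.get("num_cpus") is None:
        out["num_cpus"] = 1
    if out.get("num_gpus") is None:
        out["num_gpus"] = 0
    out["resources"] = {
        k: v for k, v in out.get("resources", {}).items() if v not in (None, 0.0)
    }
    return out

def are_remote_args_compatible(prev_args: dict, next_args: dict) -> bool:
    prev_args, next_args = canonicalize(prev_args), canonicalize(next_args)
    merged = dict(next_args)
    for key in INHERITABLE_REMOTE_ARGS:
        if key in prev_args:
            merged[key] = prev_args[key]
    return prev_args == merged

# Unspecified args are equivalent to the explicit defaults...
assert are_remote_args_compatible({}, {"num_cpus": 1})
# ...but a GPU upstream cannot merge into a CPU-only downstream.
assert not are_remote_args_compatible({"num_gpus": 1}, {})
```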
@@ -56,6 +56,7 @@
    "Aggregate",
    # N-ary
    "Zip",
    "Union",
]
@@ -22,7 +22,7 @@ def get_input_data() -> List[RefBundle]:
        get_metadata = cached_remote_fn(get_table_block_metadata)
        metadata = ray.get([get_metadata.remote(t) for t in op._tables])
        ref_bundles: List[RefBundle] = [
            RefBundle([(table_ref, block_metadata)], owns_blocks=True)
            RefBundle([(table_ref, block_metadata)], owns_blocks=False)

Review comment: what are all the

Reply: The blocks are put into the object store inside the FromArrowRefs op, so this RefBundle shouldn't own the blocks. This was a bug. This function is used for the optimizer code path only.

            for table_ref, block_metadata in zip(op._tables, metadata)
        ]
        return ref_bundles
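The flag matters because the executor may eagerly free the blocks of bundles that own them; a hedged illustration with a stand-in bundle type, not Ray's RefBundle:

```python
from dataclasses import dataclass
from typing import List, Tuple

# Stand-in for illustration only: a bundle that owns its blocks may release
# them once consumed, while a bundle over caller-provided Arrow refs must not.
@dataclass
class FakeRefBundle:
    blocks: List[Tuple[str, dict]]  # (block ref, metadata) pairs
    owns_blocks: bool

    def destroy_if_owned(self) -> int:
        # Returns how many blocks were released; non-owning bundles release none.
        return len(self.blocks) if self.owns_blocks else 0

caller_owned = FakeRefBundle([("table_ref_0", {})], owns_blocks=False)
assert caller_owned.destroy_if_owned() == 0
```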
Review comment: Hmm, it's not ideal to pass this at runtime. Ideally, the optimizer would rewrite the downstream op's ray remote args to this value, instead of having each operator need to properly decide which of the two args to use and looking at the context.

Reply: I agree with you, but currently it's hard to avoid this. For most operators, we are already doing it the way you mentioned. `upstream_map_ray_remote_args`, along with `upstream_map_transform_fn`, is used only for `RandomShuffle`, because the corresponding `AllToAllOperator` physical op itself doesn't directly do the shuffle; instead, it uses `ExchangeTaskScheduler` to launch new tasks to do the shuffle. That's why we need this ad-hoc handling here. I'll add a TODO here. Update: see generate_random_shuffle_fn for more details.
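A hedged sketch of the consuming side described in this reply: a shuffle-map task applies the stashed upstream transform before partitioning its blocks. Names and signatures here are illustrative, not Ray's API:

```python
from types import SimpleNamespace

def run_shuffle_map_task(blocks, ctx):
    # If an upstream map was fused into the shuffle, apply it first; otherwise
    # the blocks are partitioned as-is.
    transform = getattr(ctx, "upstream_map_transform_fn", None)
    if transform is not None:
        blocks = list(transform(blocks, ctx))
    return blocks  # partitioning of `blocks` into output splits would follow

ctx = SimpleNamespace(
    upstream_map_transform_fn=lambda blocks, ctx: ([x * 2 for x in b] for b in blocks),
    upstream_map_ray_remote_args={"num_cpus": 1},
)
assert run_shuffle_map_task([[1, 2]], ctx) == [[2, 4]]
```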