[data] Fix wrong output order of streaming_split (#36919)
The output order of OutputSplitter should be FIFO, not LIFO. This bug also makes streaming_split's output order non-deterministic, because it depends on when the OutputSplitter's outputs are taken.
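
As a quick illustration (plain Python, not code from this repo; the `simulate` helper and its `take_every` parameter are invented for the example), the sketch below shows why draining a list with `pop()` is LIFO and why the resulting order then depends on when outputs are taken, whereas `deque.popleft()` is FIFO and yields the input order regardless of timing:

```python
from collections import deque

def simulate(queue_cls, pop, take_every):
    """Push "b0".."b5" into a queue, taking one output every `take_every`
    pushes, then drain whatever is left. Returns the order in which
    outputs were taken."""
    q, taken = queue_cls(), []
    for i in range(6):
        q.append(f"b{i}")
        if (i + 1) % take_every == 0:
            taken.append(pop(q))
    while q:
        taken.append(pop(q))
    return taken

# Old behavior (list + pop() -> LIFO): the order changes with pop timing.
print(simulate(list, list.pop, take_every=1))   # ['b0', 'b1', 'b2', 'b3', 'b4', 'b5']
print(simulate(list, list.pop, take_every=3))   # ['b2', 'b5', 'b4', 'b3', 'b1', 'b0']

# Fixed behavior (deque + popleft() -> FIFO): always the input order.
print(simulate(deque, deque.popleft, take_every=1))  # ['b0', 'b1', 'b2', 'b3', 'b4', 'b5']
print(simulate(deque, deque.popleft, take_every=3))  # ['b0', 'b1', 'b2', 'b3', 'b4', 'b5']
```

The updated test_split_operator below feeds inputs in batches for the same reason: only when several bundles sit in the queue at once does the old LIFO behavior reorder them.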

---------

Signed-off-by: Hao Chen <chenh1024@gmail.com>
raulchen authored Jun 29, 2023
1 parent cf0bdd6 commit f9912eb
Showing 3 changed files with 43 additions and 15 deletions.
@@ -1,4 +1,5 @@
import math
from collections import deque
from typing import Dict, List, Optional

from ray.data._internal.execution.interfaces import (
@@ -42,7 +43,7 @@ def __init__(
# Buffer of bundles not yet assigned to output splits.
self._buffer: List[RefBundle] = []
# The outputted bundles with output_split attribute set.
self._output_queue: List[RefBundle] = []
self._output_queue: deque[RefBundle] = deque()
# The number of rows output to each output split so far.
self._num_output: List[int] = [0 for _ in range(n)]

@@ -84,7 +85,7 @@ def has_next(self) -> bool:
return len(self._output_queue) > 0

def get_next(self) -> RefBundle:
return self._output_queue.pop()
return self._output_queue.popleft()

def get_stats(self) -> StatsDict:
return {"split": []} # TODO(ekl) add split metrics?
42 changes: 29 additions & 13 deletions python/ray/data/tests/test_operators.py
@@ -229,30 +229,46 @@ def test_map_operator_streamed(ray_start_regular_shared, use_actors):
@pytest.mark.parametrize("equal", [False, True])
@pytest.mark.parametrize("chunk_size", [1, 10])
def test_split_operator(ray_start_regular_shared, equal, chunk_size):
input_op = InputDataBuffer(make_ref_bundles([[i] * chunk_size for i in range(100)]))
op = OutputSplitter(input_op, 3, equal=equal)
num_input_blocks = 100
num_splits = 3
# Add this many input blocks each time.
# Make sure it is greater than num_splits * 2,
# so we can test the output order of `OutputSplitter.get_next`.
num_add_input_blocks = 10
input_op = InputDataBuffer(
make_ref_bundles([[i] * chunk_size for i in range(num_input_blocks)])
)
op = OutputSplitter(input_op, num_splits, equal=equal)

# Feed data and implement streaming exec.
output_splits = collections.defaultdict(list)
output_splits = [[] for _ in range(num_splits)]
op.start(ExecutionOptions())
while input_op.has_next():
op.add_input(input_op.get_next(), 0)
for _ in range(num_add_input_blocks):
if not input_op.has_next():
break
op.add_input(input_op.get_next(), 0)
while op.has_next():
ref = op.get_next()
assert ref.owns_blocks, ref
for block, _ in ref.blocks:
assert ref.output_split_idx is not None
output_splits[ref.output_split_idx].extend(list(ray.get(block)["id"]))
op.all_inputs_done()

expected_splits = [[] for _ in range(num_splits)]
for i in range(num_splits):
for j in range(i, num_input_blocks, num_splits):
expected_splits[i].extend([j] * chunk_size)
if equal:
for i in range(3):
assert len(output_splits[i]) == 33 * chunk_size, output_splits
else:
assert sum(len(output_splits[i]) for i in range(3)) == (100 * chunk_size)
for i in range(3):
assert len(output_splits[i]) in [
33 * chunk_size,
34 * chunk_size,
], output_splits
min_len = min(len(expected_splits[i]) for i in range(num_splits))
for i in range(num_splits):
expected_splits[i] = expected_splits[i][:min_len]
for i in range(num_splits):
assert output_splits[i] == expected_splits[i], (
output_splits[i],
expected_splits[i],
)


@pytest.mark.parametrize("equal", [False, True])
11 changes: 11 additions & 0 deletions python/ray/data/tests/test_streaming_integration.py
@@ -103,6 +103,17 @@ def run(self):
c1.start()
c0.join()
c1.join()

def get_outputs(out: List[RefBundle]):
outputs = []
for bundle in out:
for block, _ in bundle.blocks:
ids: pd.Series = ray.get(block)["id"]
outputs.extend(ids.values)
return outputs

assert get_outputs(c0.out) == list(range(0, 20, 2))
assert get_outputs(c1.out) == list(range(1, 20, 2))
assert len(c0.out) == 10, c0.out
assert len(c1.out) == 10, c1.out
