diff --git a/python/ray/_private/worker.py b/python/ray/_private/worker.py index 25368ce09468..bac5551b37db 100644 --- a/python/ray/_private/worker.py +++ b/python/ray/_private/worker.py @@ -58,6 +58,7 @@ import ray.job_config import ray.remote_function from ray import ActorID, JobID, Language, ObjectRef +from ray._raylet import StreamingObjectRefGenerator from ray._private import ray_option_utils from ray._private.client_mode_hook import client_mode_hook from ray._private.function_manager import FunctionActorManager @@ -2463,7 +2464,7 @@ def get( with profiling.profile("ray.get"): # TODO(sang): Should make StreamingObjectRefGenerator # compatible to ray.get for dataset. - if isinstance(object_refs, ray._raylet.StreamingObjectRefGenerator): + if isinstance(object_refs, StreamingObjectRefGenerator): return object_refs is_individual_id = isinstance(object_refs, ray.ObjectRef) @@ -2605,8 +2606,9 @@ def wait( - :doc:`/ray-core/patterns/ray-get-submission-order` Args: - object_refs: List of object refs for objects that may - or may not be ready. Note that these IDs must be unique. + object_refs: List of :class:`~ObjectRefs` or + :class:`~StreamingObjectRefGenerators` for objects that may or may + not be ready. Note that these must be unique. num_returns: The number of object refs that should be returned. timeout: The maximum amount of time in seconds to wait before returning. @@ -2637,14 +2639,20 @@ def wait( ) blocking_wait_inside_async_warned = True - if isinstance(object_refs, ObjectRef): + if isinstance(object_refs, ObjectRef) or isinstance( + object_refs, StreamingObjectRefGenerator + ): raise TypeError( - "wait() expected a list of ray.ObjectRef, got a single ray.ObjectRef" + "wait() expected a list of ray.ObjectRef or ray.StreamingObjectRefGenerator" + ", got a single ray.ObjectRef or ray.StreamingObjectRefGenerator " + f"{object_refs}" ) if not isinstance(object_refs, list): raise TypeError( - "wait() expected a list of ray.ObjectRef, " f"got {type(object_refs)}" + "wait() expected a list of ray.ObjectRef or " + "ray.StreamingObjectRefGenerator, " + f"got {type(object_refs)}" ) if timeout is not None and timeout < 0: @@ -2653,13 +2661,16 @@ def wait( ) for object_ref in object_refs: - if not isinstance(object_ref, ObjectRef): + if not isinstance(object_ref, ObjectRef) and not isinstance( + object_ref, StreamingObjectRefGenerator + ): raise TypeError( - "wait() expected a list of ray.ObjectRef, " + "wait() expected a list of ray.ObjectRef or " + "ray.StreamingObjectRefGenerator, " f"got list containing {type(object_ref)}" ) - worker.check_connected() + # TODO(swang): Check main thread. with profiling.profile("ray.wait"): # TODO(rkn): This is a temporary workaround for diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index 771e8b7bd62d..21c6ad9c63b3 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -235,8 +235,15 @@ class StreamingObjectRefGenerator: self._generator_task_exception = None # Ray's worker class. ray._private.worker.global_worker self.worker = worker + self.worker.check_connected() assert hasattr(worker, "core_worker") + def get_next_ref(self) -> ObjectRef: + self.worker.check_connected() + core_worker = self.worker.core_worker + return core_worker.peek_object_ref_stream( + self._generator_ref) + def __iter__(self) -> "StreamingObjectRefGenerator": return self @@ -284,10 +291,7 @@ class StreamingObjectRefGenerator: timeout_s: If the next object is not ready within this timeout, it returns the nil object ref. """ - if not hasattr(self.worker, "core_worker"): - raise ValueError( - "Cannot access the core worker. " - "Did you already shutdown Ray via ray.shutdown()?") + self.worker.check_connected() core_worker = self.worker.core_worker # Wait for the next ObjectRef to become ready. @@ -325,10 +329,7 @@ class StreamingObjectRefGenerator: timeout_s: Optional[float] = None, sleep_interval_s: float = 0.0001): """Same API as _next_sync, but it is for async context.""" - if not hasattr(self.worker, "core_worker"): - raise ValueError( - "Cannot access the core worker. " - "Did you already shutdown Ray via ray.shutdown()?") + self.worker.check_connected() core_worker = self.worker.core_worker ref = core_worker.peek_object_ref_stream( @@ -2941,13 +2942,30 @@ cdef class CoreWorker: return c_object_id.Binary() - def wait(self, object_refs, int num_returns, int64_t timeout_ms, + def wait(self, object_refs_or_generators, int num_returns, int64_t timeout_ms, TaskID current_task_id, c_bool fetch_local): cdef: c_vector[CObjectID] wait_ids c_vector[c_bool] results CTaskID c_task_id = current_task_id.native() + object_refs = [] + for ref_or_generator in object_refs_or_generators: + if (not isinstance(ref_or_generator, ObjectRef) + and not isinstance(ref_or_generator, StreamingObjectRefGenerator)): + raise TypeError( + "wait() expected a list of ray.ObjectRef " + "or StreamingObjectRefGenerator, " + f"got list containing {type(ref_or_generator)}" + ) + + if isinstance(ref_or_generator, StreamingObjectRefGenerator): + # Before calling wait, + # get the next reference from a generator. + object_refs.append(ref_or_generator.get_next_ref()) + else: + object_refs.append(ref_or_generator) + wait_ids = ObjectRefsToVector(object_refs) with nogil: op_status = CCoreWorkerProcess.GetCoreWorker().Wait( @@ -2957,11 +2975,11 @@ cdef class CoreWorker: assert len(results) == len(object_refs) ready, not_ready = [], [] - for i, object_ref in enumerate(object_refs): + for i, object_ref_or_generator in enumerate(object_refs_or_generators): if results[i]: - ready.append(object_ref) + ready.append(object_ref_or_generator) else: - not_ready.append(object_ref) + not_ready.append(object_ref_or_generator) return ready, not_ready diff --git a/python/ray/tests/test_streaming_generator.py b/python/ray/tests/test_streaming_generator.py index 07101d43a91a..06e85a87ac2a 100644 --- a/python/ray/tests/test_streaming_generator.py +++ b/python/ray/tests/test_streaming_generator.py @@ -5,6 +5,9 @@ import time import threading import gc +import random + +from collections import Counter from unittest.mock import patch, Mock @@ -27,6 +30,16 @@ } +def assert_no_leak(): + gc.collect() + core_worker = ray._private.worker.global_worker.core_worker + ref_counts = core_worker.get_all_reference_counts() + print(ref_counts) + for rc in ref_counts.values(): + assert rc["local"] == 0 + assert rc["submitted"] == 0 + + class MockedWorker: def __init__(self, mocked_core_worker): self.core_worker = mocked_core_worker @@ -37,6 +50,9 @@ def reset_core_worker(self): """ self.core_worker = None + def check_connected(self): + return True + @pytest.fixture def mocked_worker(): @@ -755,18 +771,164 @@ async def main(): print(summary) -def test_reconstruction(ray_start_cluster): - cluster = ray_start_cluster - # Head node with no resources. - cluster.add_node( - num_cpus=0, - _system_config=RECONSTRUCTION_CONFIG, - enable_object_reconstruction=True, - ) - ray.init(address=cluster.address) - # Node to place the initial object. - node_to_kill = cluster.add_node(num_cpus=1, object_store_memory=10**8) - cluster.wait_for_nodes() +def test_generator_wait(shutdown_only): + """ + Make sure the generator works with ray.wait. + """ + ray.init(num_cpus=8) + + @ray.remote + def f(sleep_time): + for i in range(2): + time.sleep(sleep_time) + yield i + + @ray.remote + def g(sleep_time): + time.sleep(sleep_time) + return 10 + + gen = f.options(num_returns="streaming").remote(1) + + """ + Test basic cases. + """ + for expected_rval in [0, 1]: + s = time.time() + r, ur = ray.wait([gen], num_returns=1) + print(time.time() - s) + assert len(r) == 1 + assert ray.get(next(r[0])) == expected_rval + assert len(ur) == 0 + + # Should raise a stop iteration. + for _ in range(3): + s = time.time() + r, ur = ray.wait([gen], num_returns=1) + print(time.time() - s) + assert len(r) == 1 + with pytest.raises(StopIteration): + assert next(r[0]) == 0 + assert len(ur) == 0 + + gen = f.options(num_returns="streaming").remote(0) + # Wait until the generator task finishes + ray.get(gen._generator_ref) + for i in range(2): + r, ur = ray.wait([gen], timeout=0) + assert len(r) == 1 + assert len(ur) == 0 + assert ray.get(next(r[0])) == i + + """ + Test the case ref is mixed with regular object ref. + """ + gen = f.options(num_returns="streaming").remote(0) + ref = g.remote(3) + ready, unready = [], [gen, ref] + result_set = set() + while unready: + ready, unready = ray.wait(unready) + print(ready, unready) + assert len(ready) == 1 + for r in ready: + if isinstance(r, StreamingObjectRefGenerator): + try: + ref = next(r) + print(ref) + print(ray.get(ref)) + result_set.add(ray.get(ref)) + except StopIteration: + pass + else: + unready.append(r) + else: + result_set.add(ray.get(r)) + + assert result_set == {0, 1, 10} + + """ + Test timeout. + """ + gen = f.options(num_returns="streaming").remote(3) + ref = g.remote(1) + ready, unready = ray.wait([gen, ref], timeout=2) + assert len(ready) == 1 + assert len(unready) == 1 + + """ + Test num_returns + """ + gen = f.options(num_returns="streaming").remote(1) + ref = g.remote(1) + ready, unready = ray.wait([ref, gen], num_returns=2) + assert len(ready) == 2 + assert len(unready) == 0 + + +def test_generator_wait_e2e(shutdown_only): + ray.init(num_cpus=8) + + @ray.remote + def f(sleep_time): + for i in range(2): + time.sleep(sleep_time) + yield i + + @ray.remote + def g(sleep_time): + time.sleep(sleep_time) + return 10 + + gen = [f.options(num_returns="streaming").remote(1) for _ in range(4)] + ref = [g.remote(2) for _ in range(4)] + ready, unready = [], [*gen, *ref] + result = [] + start = time.time() + while unready: + ready, unready = ray.wait(unready, num_returns=len(unready), timeout=0.1) + for r in ready: + if isinstance(r, StreamingObjectRefGenerator): + try: + ref = next(r) + result.append(ray.get(ref)) + except StopIteration: + pass + else: + unready.append(r) + else: + result.append(ray.get(r)) + elapsed = time.time() - start + assert elapsed < 3 + assert 2 < elapsed + + assert len(result) == 12 + result = Counter(result) + assert result[0] == 4 + assert result[1] == 4 + assert result[10] == 4 + + +@pytest.mark.parametrize("delay", [True]) +def test_reconstruction(monkeypatch, ray_start_cluster, delay): + with monkeypatch.context() as m: + if delay: + m.setenv( + "RAY_testing_asio_delay_us", + "CoreWorkerService.grpc_server." + "ReportGeneratorItemReturns=10000:1000000", + ) + cluster = ray_start_cluster + # Head node with no resources. + cluster.add_node( + num_cpus=0, + _system_config=RECONSTRUCTION_CONFIG, + enable_object_reconstruction=True, + ) + ray.init(address=cluster.address) + # Node to place the initial object. + node_to_kill = cluster.add_node(num_cpus=1, object_store_memory=10**8) + cluster.wait_for_nodes() @ray.remote(num_returns="streaming", max_retries=2) def dynamic_generator(num_returns): @@ -901,6 +1063,70 @@ def fetch(x): assert "The worker died" in str(e.value) +def test_ray_datasetlike_mini_stress_test(monkeypatch, ray_start_cluster): + """ + Test a workload that's like ray dataset + lineage reconstruction. + """ + with monkeypatch.context() as m: + m.setenv( + "RAY_testing_asio_delay_us", + "CoreWorkerService.grpc_server." "ReportGeneratorItemReturns=10000:1000000", + ) + cluster = ray_start_cluster + # Head node with no resources. + cluster.add_node( + num_cpus=1, + resources={"head": 1}, + _system_config=RECONSTRUCTION_CONFIG, + enable_object_reconstruction=True, + ) + ray.init(address=cluster.address) + + @ray.remote(num_returns="streaming", max_retries=-1) + def dynamic_generator(num_returns): + for i in range(num_returns): + time.sleep(0.1) + yield np.ones(1_000_000, dtype=np.int8) * i + + @ray.remote(num_cpus=0, resources={"head": 1}) + def driver(): + unready = [dynamic_generator.remote(10) for _ in range(5)] + ready = [] + while unready: + ready, unready = ray.wait( + unready, num_returns=len(unready), timeout=0.1 + ) + for r in ready: + try: + ref = next(r) + print(ref) + ray.get(ref) + except StopIteration: + pass + else: + unready.append(r) + return None + + ref = driver.remote() + + nodes = [] + for _ in range(4): + nodes.append(cluster.add_node(num_cpus=1, object_store_memory=10**8)) + cluster.wait_for_nodes() + + for _ in range(10): + time.sleep(0.1) + node_to_kill = random.choices(nodes)[0] + nodes.remove(node_to_kill) + cluster.remove_node(node_to_kill, allow_graceful=False) + nodes.append(cluster.add_node(num_cpus=1, object_store_memory=10**8)) + + ray.get(ref) + del ref + + assert_no_leak() + + def test_generator_max_returns(monkeypatch, shutdown_only): """ Test when generator returns more than system limit values @@ -928,6 +1154,27 @@ def driver(): ray.get(driver.remote()) +def test_return_yield_mix(shutdown_only): + """ + Test the case where yield and return is mixed within a + generator task. + """ + + @ray.remote + def g(): + for i in range(3): + yield i + return + + generator = g.options(num_returns="streaming").remote() + result = [] + for ref in generator: + result.append(ray.get(ref)) + + assert len(result) == 1 + assert result[0] == 0 + + if __name__ == "__main__": import os diff --git a/src/ray/core_worker/task_manager.cc b/src/ray/core_worker/task_manager.cc index 0c9d6b8e6e02..34f842583a01 100644 --- a/src/ray/core_worker/task_manager.cc +++ b/src/ray/core_worker/task_manager.cc @@ -755,7 +755,7 @@ void TaskManager::CompletePendingTask(const TaskID &task_id, RAY_CHECK_EQ(reply.return_objects_size(), 1); for (size_t i = 0; i < spec.NumStreamingGeneratorReturns(); i++) { const auto generator_return_id = spec.StreamingGeneratorReturnId(i); - RAY_LOG(DEBUG) << "Failing streamed object " << generator_return_id; + RAY_CHECK_EQ(reply.return_objects_size(), 1); const auto &return_object = reply.return_objects(0); HandleTaskReturn(generator_return_id, return_object,