[Streaming Generator] Make it compatible with wait (ray-project#36071)
This PR makes the streaming generator compatible with ray.wait.

The semantics are as follows:

import ray

@ray.remote
def f():
    for _ in range(3):
        yield 1

# Assumed for this example: an ordinary task whose ObjectRef is waited on alongside the generator.
@ray.remote
def g():
    return 1

generator = f.options(num_returns="streaming").remote()
ref = g.remote()

# The generator is returned in ready if its next reference is available; otherwise it is in unready.
# This works with all other ray.wait options (including fetch_local=True/False).
ready, unready = ray.wait([generator])

# If the generator's next ref is not ready within 0.1 seconds, the generator is in unready;
# otherwise it is in ready.
ready, unready = ray.wait([generator], timeout=0.1)

# If the generator's next ref is available, it counts as one return.
# Here, wait returns once both the generator and ref are ready.
ready, unready = ray.wait([generator, ref], num_returns=2)

# If the generator's next ref is available, fetch_local=True fetches the object to the local node.
ready, unready = ray.wait([generator, ref], fetch_local=True)
Building on the previous PR (ray-project#36070), we can now peek the generator's next object reference, and the peeked reference is guaranteed to be resolved. By always peeking the next reference from the generator and waiting on that reference, the generator becomes compatible with ray.wait.
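
As a usage sketch (an illustration rather than part of this PR; it assumes next() on the generator returns an ObjectRef and raises StopIteration when the stream ends, as StreamingObjectRefGenerator does here), a generator reported as ready can be consumed and then waited on again:

import ray

@ray.remote
def stream():
    for i in range(3):
        yield i

gen = stream.options(num_returns="streaming").remote()
pending = [gen]

while pending:
    ready, pending = ray.wait(pending, timeout=1.0)
    for g in ready:
        try:
            next_ref = next(g)        # ObjectRef to the next yielded value
            print(ray.get(next_ref))  # prints 0, 1, 2 over the iterations
            pending.append(g)         # keep waiting for the remaining items
        except StopIteration:
            pass                      # stream exhausted; stop waiting on it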

Signed-off-by: e428265 <arvind.chandramouli@lmco.com>
rkooo567 authored and arvind-chandra committed Aug 31, 2023
1 parent 87e1ef0 commit 939d1d2
Showing 4 changed files with 310 additions and 34 deletions.
29 changes: 20 additions & 9 deletions python/ray/_private/worker.py
@@ -58,6 +58,7 @@
import ray.job_config
import ray.remote_function
from ray import ActorID, JobID, Language, ObjectRef
from ray._raylet import StreamingObjectRefGenerator
from ray._private import ray_option_utils
from ray._private.client_mode_hook import client_mode_hook
from ray._private.function_manager import FunctionActorManager
@@ -2463,7 +2464,7 @@ def get(
with profiling.profile("ray.get"):
# TODO(sang): Should make StreamingObjectRefGenerator
# compatible to ray.get for dataset.
if isinstance(object_refs, ray._raylet.StreamingObjectRefGenerator):
if isinstance(object_refs, StreamingObjectRefGenerator):
return object_refs

is_individual_id = isinstance(object_refs, ray.ObjectRef)
@@ -2605,8 +2606,9 @@ def wait(
- :doc:`/ray-core/patterns/ray-get-submission-order`
Args:
object_refs: List of object refs for objects that may
or may not be ready. Note that these IDs must be unique.
object_refs: List of :class:`~ObjectRefs` or
:class:`~StreamingObjectRefGenerators` for objects that may or may
not be ready. Note that these must be unique.
num_returns: The number of object refs that should be returned.
timeout: The maximum amount of time in seconds to wait before
returning.
@@ -2637,14 +2639,20 @@
)
blocking_wait_inside_async_warned = True

if isinstance(object_refs, ObjectRef):
if isinstance(object_refs, ObjectRef) or isinstance(
object_refs, StreamingObjectRefGenerator
):
raise TypeError(
"wait() expected a list of ray.ObjectRef, got a single ray.ObjectRef"
"wait() expected a list of ray.ObjectRef or ray.StreamingObjectRefGenerator"
", got a single ray.ObjectRef or ray.StreamingObjectRefGenerator "
f"{object_refs}"
)

if not isinstance(object_refs, list):
raise TypeError(
"wait() expected a list of ray.ObjectRef, " f"got {type(object_refs)}"
"wait() expected a list of ray.ObjectRef or "
"ray.StreamingObjectRefGenerator, "
f"got {type(object_refs)}"
)

if timeout is not None and timeout < 0:
@@ -2653,13 +2661,16 @@
)

for object_ref in object_refs:
if not isinstance(object_ref, ObjectRef):
if not isinstance(object_ref, ObjectRef) and not isinstance(
object_ref, StreamingObjectRefGenerator
):
raise TypeError(
"wait() expected a list of ray.ObjectRef, "
"wait() expected a list of ray.ObjectRef or "
"ray.StreamingObjectRefGenerator, "
f"got list containing {type(object_ref)}"
)

worker.check_connected()

# TODO(swang): Check main thread.
with profiling.profile("ray.wait"):
# TODO(rkn): This is a temporary workaround for
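To illustrate the new validation in worker.py above: wait() now accepts StreamingObjectRefGenerators inside the list argument, while a bare generator (or any non-list) is still rejected with a TypeError. A minimal sketch of the expected behavior (the task g below is assumed for illustration):

import ray

@ray.remote
def g():
    for i in range(2):
        yield i

gen = g.options(num_returns="streaming").remote()

# Accepted: a list that contains a generator (it may also mix in plain ObjectRefs).
ready, unready = ray.wait([gen], timeout=0)

# Rejected: passing the generator directly, without wrapping it in a list.
try:
    ray.wait(gen)
except TypeError as err:
    print(f"rejected as expected: {err}")
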
42 changes: 30 additions & 12 deletions python/ray/_raylet.pyx
@@ -235,8 +235,15 @@ class StreamingObjectRefGenerator:
self._generator_task_exception = None
# Ray's worker class. ray._private.worker.global_worker
self.worker = worker
self.worker.check_connected()
assert hasattr(worker, "core_worker")

def get_next_ref(self) -> ObjectRef:
self.worker.check_connected()
core_worker = self.worker.core_worker
return core_worker.peek_object_ref_stream(
self._generator_ref)

def __iter__(self) -> "StreamingObjectRefGenerator":
return self

@@ -284,10 +291,7 @@ class StreamingObjectRefGenerator:
timeout_s: If the next object is not ready within
this timeout, it returns the nil object ref.
"""
if not hasattr(self.worker, "core_worker"):
raise ValueError(
"Cannot access the core worker. "
"Did you already shutdown Ray via ray.shutdown()?")
self.worker.check_connected()
core_worker = self.worker.core_worker

# Wait for the next ObjectRef to become ready.
@@ -325,10 +329,7 @@
timeout_s: Optional[float] = None,
sleep_interval_s: float = 0.0001):
"""Same API as _next_sync, but it is for async context."""
if not hasattr(self.worker, "core_worker"):
raise ValueError(
"Cannot access the core worker. "
"Did you already shutdown Ray via ray.shutdown()?")
self.worker.check_connected()
core_worker = self.worker.core_worker

ref = core_worker.peek_object_ref_stream(
@@ -2941,13 +2942,30 @@ cdef class CoreWorker:

return c_object_id.Binary()

def wait(self, object_refs, int num_returns, int64_t timeout_ms,
def wait(self, object_refs_or_generators, int num_returns, int64_t timeout_ms,
TaskID current_task_id, c_bool fetch_local):
cdef:
c_vector[CObjectID] wait_ids
c_vector[c_bool] results
CTaskID c_task_id = current_task_id.native()

object_refs = []
for ref_or_generator in object_refs_or_generators:
if (not isinstance(ref_or_generator, ObjectRef)
and not isinstance(ref_or_generator, StreamingObjectRefGenerator)):
raise TypeError(
"wait() expected a list of ray.ObjectRef "
"or StreamingObjectRefGenerator, "
f"got list containing {type(ref_or_generator)}"
)

if isinstance(ref_or_generator, StreamingObjectRefGenerator):
# Before calling wait,
# get the next reference from a generator.
object_refs.append(ref_or_generator.get_next_ref())
else:
object_refs.append(ref_or_generator)

wait_ids = ObjectRefsToVector(object_refs)
with nogil:
op_status = CCoreWorkerProcess.GetCoreWorker().Wait(
@@ -2957,11 +2975,11 @@ cdef class CoreWorker:
assert len(results) == len(object_refs)

ready, not_ready = [], []
for i, object_ref in enumerate(object_refs):
for i, object_ref_or_generator in enumerate(object_refs_or_generators):
if results[i]:
ready.append(object_ref)
ready.append(object_ref_or_generator)
else:
not_ready.append(object_ref)
not_ready.append(object_ref_or_generator)

return ready, not_ready

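The CoreWorker.wait change above keeps a parallel list: each generator is replaced by its peeked next reference (via get_next_ref) for the actual wait, and the results are then mapped back onto the original inputs, so callers receive the generators themselves in ready/not_ready. A pure-Python sketch of that bookkeeping, where wait_on_refs stands in for the underlying core-worker call and is an assumed parameter for illustration:

from typing import Callable, List, Tuple

def wait_with_generators(objects: list,
                         wait_on_refs: Callable[[list], List[bool]]) -> Tuple[list, list]:
    """Sketch of the substitution CoreWorker.wait performs.

    objects may mix ObjectRefs and StreamingObjectRefGenerators;
    wait_on_refs(refs) stands in for the underlying core-worker wait
    and returns one boolean per ref.
    """
    refs = []
    for obj in objects:
        if hasattr(obj, "get_next_ref"):
            # A generator is represented by its peeked next reference.
            refs.append(obj.get_next_ref())
        else:
            refs.append(obj)

    results = wait_on_refs(refs)

    # Results are reported against the original inputs, not the peeked refs.
    ready, not_ready = [], []
    for obj, is_ready in zip(objects, results):
        (ready if is_ready else not_ready).append(obj)
    return ready, not_ready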