From 92e04cb6a29c21f9f333fa60f5fabdfdcacaa664 Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Thu, 25 May 2023 07:44:55 +0900 Subject: [PATCH] [4/N] Support async actor and async generator interface. (#35584) NOTE: It is a big PR, but most of the code is testing (200~300 lines, I believe). This is the fourth PR to support the streaming generator. The detailed design and API proposal can be found at https://docs.google.com/document/d/1hAASLe2sCoay23raqxqwJdSDiJWNMcNhlTwWJXsJOU4/edit#heading=h.w91y1fgnpu0m. The execution plan can be found at https://docs.google.com/document/d/1hAASLe2sCoay23raqxqwJdSDiJWNMcNhlTwWJXsJOU4/edit#heading=h.kxktymq5ihf7. There will be 4 PRs to enable the streaming generator for Ray Serve (phase 1): introduce cpp interfaces to handle intermediate task return ([1/N] Streaming Generator. Cpp interfaces and implementation #35291); support core worker APIs + cython generator interface ([2/N] Streaming Generator. Support core worker APIs + cython generator interface. #35324); E2e integration ([3/N] Streaming Generator. E2e integration #35325 (review)); support async actors ([4/N] Support async actor and async generator interface. #35382 <---- This PR). This PR adds async actor execution support to the generator implementation (basically, it keeps posting generator.anext to the event loop) and implements a standard async generator interface from Python (anext and aiter). --- python/ray/_private/async_compat.py | 7 +- python/ray/_private/ray_perf.py | 1 - python/ray/_raylet.pxd | 10 +- python/ray/_raylet.pyx | 421 ++++++++++++++---- python/ray/actor.py | 8 +- python/ray/includes/common.pxd | 4 +- python/ray/includes/libcoreworker.pxd | 9 +- python/ray/includes/unique_ids.pxd | 2 + python/ray/includes/unique_ids.pxi | 3 +- python/ray/tests/test_async.py | 20 + python/ray/tests/test_runtime_context.py | 21 + python/ray/tests/test_streaming_generator.py | 315 ++++++++++++- python/ray/util/tracing/tracing_helper.py | 6 + src/ray/common/status.h | 11 +- src/ray/core_worker/context.cc | 23 + src/ray/core_worker/context.h | 20 + src/ray/core_worker/core_worker.cc | 22 +- src/ray/core_worker/core_worker.h | 28 +- src/ray/core_worker/task_manager.cc | 2 +- src/ray/core_worker/task_manager.h | 9 +- src/ray/core_worker/test/core_worker_test.cc | 33 ++ src/ray/core_worker/test/task_manager_test.cc | 12 +- 22 files changed, 838 insertions(+), 149 deletions(-) diff --git a/python/ray/_private/async_compat.py b/python/ray/_private/async_compat.py index b1ecccf2590ec..2e3b03aca6238 100644 --- a/python/ray/_private/async_compat.py +++ b/python/ray/_private/async_compat.py @@ -19,10 +19,15 @@ def get_new_event_loop(): return asyncio.new_event_loop() +def is_async_func(func): + """Return True if the function is an async or async generator method.""" + return inspect.iscoroutinefunction(func) or inspect.isasyncgenfunction(func) + + def sync_to_async(func): """Convert a blocking function to async function""" - if inspect.iscoroutinefunction(func): + if is_async_func(func): return func async def wrapper(*args, **kwargs): diff --git a/python/ray/_private/ray_perf.py b/python/ray/_private/ray_perf.py index 316f3baeca846..d73378b674cd7 100644 --- a/python/ray/_private/ray_perf.py +++ b/python/ray/_private/ray_perf.py @@ -286,7 +286,6 @@ def async_actor_multi(): ray.get([async_actor_work.remote(a) for _ in range(m)]) results += timeit("n:n async-actor calls async", async_actor_multi, m * n) - ray.shutdown() NUM_PGS = 100 NUM_BUNDLES = 1 diff --git a/python/ray/_raylet.pxd b/python/ray/_raylet.pxd index 
f871e1bf1729a..7297547b32dfb 100644 --- a/python/ray/_raylet.pxd +++ b/python/ray/_raylet.pxd @@ -39,7 +39,8 @@ from ray.includes.libcoreworker cimport ( from ray.includes.unique_ids cimport ( CObjectID, - CActorID + CActorID, + CTaskID, ) from ray.includes.function_descriptor cimport ( CFunctionDescriptor, @@ -154,6 +155,13 @@ cdef class CoreWorker: cdef python_scheduling_strategy_to_c( self, python_scheduling_strategy, CSchedulingStrategy *c_scheduling_strategy) + cdef CObjectID allocate_dynamic_return_id_for_generator( + self, + const CAddress &owner_address, + const CTaskID &task_id, + return_size, + generator_index, + is_async_actor) cdef class FunctionDescriptor: cdef: diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index b929a83096a4b..756085d1bf874 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -23,6 +23,7 @@ import time import traceback import _thread import typing +from typing import Union, Awaitable, Callable, Any from libc.stdint cimport ( int32_t, @@ -107,6 +108,7 @@ from ray.includes.unique_ids cimport ( CObjectID, CNodeID, CPlacementGroupID, + ObjectIDIndexType, ) from ray.includes.libcoreworker cimport ( ActorHandleSharedPtr, @@ -124,7 +126,7 @@ from ray.includes.ray_config cimport RayConfig from ray.includes.global_state_accessor cimport CGlobalStateAccessor from ray.includes.global_state_accessor cimport RedisDelKeySync from ray.includes.optional cimport ( - optional + optional, nullopt ) import ray @@ -149,7 +151,11 @@ from ray.util.scheduling_strategies import ( import ray._private.ray_constants as ray_constants import ray.cloudpickle as ray_pickle from ray.core.generated.common_pb2 import ActorDiedErrorContext -from ray._private.async_compat import sync_to_async, get_new_event_loop +from ray._private.async_compat import ( + sync_to_async, + get_new_event_loop, + is_async_func +) from ray._private.client_mode_hook import disable_client_hook import ray._private.gcs_utils as gcs_utils import ray._private.memory_monitor as memory_monitor @@ -186,6 +192,10 @@ current_task_id_lock = threading.Lock() job_config_initialized = False job_config_initialization_lock = threading.Lock() +# It is used to indicate optional::nullopt for +# AllocateDynamicReturnId. +cdef optional[ObjectIDIndexType] NULL_PUT_INDEX = nullopt + class ObjectRefGenerator: def __init__(self, refs): @@ -202,7 +212,7 @@ class ObjectRefGenerator: return len(self._refs) -class ObjectRefStreamEoFError(RayError): +class ObjectRefStreamEneOfStreamError(RayError): pass @@ -217,7 +227,6 @@ class StreamingObjectRefGenerator: # Ray's worker class. ray._private.worker.global_worker self.worker = worker assert hasattr(worker, "core_worker") - self.worker.core_worker.create_object_ref_stream(self._generator_ref) def __iter__(self) -> "StreamingObjectRefGenerator": return self @@ -233,9 +242,15 @@ class StreamingObjectRefGenerator: up to N + 1 objects (if there's a system failure, the last object will contain a system level exception). """ - return self._next() + return self._next_sync() + + def __aiter__(self): + return self + + async def __anext__(self): + return await self._next_async() - def _next( + def _next_sync( self, timeout_s: float = -1, sleep_interval_s: float = 0.0001, @@ -267,73 +282,144 @@ class StreamingObjectRefGenerator: available within this time, it will hard fail the generator. """ - obj = self._handle_next() + obj = self._handle_next_sync() last_time = time.time() # The generator ref will be None if the task succeeds. 
# It will contain an exception if the task fails by # a system error. while obj.is_nil(): - if self._generator_task_exception: - # The generator task has failed already. - # We raise StopIteration - # to conform the next interface in Python. - raise StopIteration from None - else: - # Otherwise, we should ray.get on the generator - # ref to find if the task has a system failure. - # Return the generator ref that contains the system - # error as soon as possible. - r, _ = ray.wait([self._generator_ref], timeout=0) - if len(r) > 0: - try: - ray.get(r) - except Exception as e: - # If it has failed, return the generator task ref - # so that the ref will raise an exception. - self._generator_task_exception = e - return self._generator_ref - finally: - if self._generator_task_completed_time is None: - self._generator_task_completed_time = time.time() - - # Currently, since the ordering of intermediate result report - # is not guaranteed, it is possible that althoug the task - # has succeeded, all of the object references are not reported - # (e.g., when there are network failures). - # If all the object refs are not reported to the generator - # within 30 seconds, we consider is as an unreconverable error. - if self._generator_task_completed_time: - if (time.time() - self._generator_task_completed_time - > unexpected_network_failure_timeout_s): - # It means the next wasn't reported although the task - # has been terminated 30 seconds ago. - self._generator_task_exception = AssertionError - assert False, "Unexpected network failure occured." - - if timeout_s != -1 and time.time() - last_time > timeout_s: - return ObjectRef.nil() - - # 100us busy waiting + error_ref = self._handle_error( + False, + last_time, + timeout_s, + unexpected_network_failure_timeout_s) + if error_ref is not None: + return error_ref + time.sleep(sleep_interval_s) - obj = self._handle_next() + obj = self._handle_next_sync() + return obj - def _handle_next(self) -> ObjectRef: + async def _next_async( + self, + timeout_s: float = -1, + sleep_interval_s: float = 0.0001, + unexpected_network_failure_timeout_s: float = 30): + """Same API as _next_sync, but it is for async context.""" + obj = await self._handle_next_async() + last_time = time.time() + + # The generator ref will be None if the task succeeds. + # It will contain an exception if the task fails by + # a system error. + while obj.is_nil(): + error_ref = self._handle_error( + True, + last_time, + timeout_s, + unexpected_network_failure_timeout_s) + if error_ref is not None: + return error_ref + + await asyncio.sleep(sleep_interval_s) + obj = await self._handle_next_async() + + return obj + + async def _handle_next_async(self): try: - if hasattr(self.worker, "core_worker"): - obj = self.worker.core_worker.try_read_next_object_ref_stream( - self._generator_ref) - return obj + return self._handle_next() + except ObjectRefStreamEneOfStreamError: + raise StopAsyncIteration + + def _handle_next_sync(self): + try: + return self._handle_next() + except ObjectRefStreamEneOfStreamError: + raise StopIteration + + def _handle_next(self): + """Get the next item from the ObjectRefStream. + + This API return immediately all the time. It returns a nil object + if it doesn't have the next item ready. It raises + ObjectRefStreamEneOfStreamError if there's nothing more to read. + If there's a next item, it will return a object ref. 
+ """ + if hasattr(self.worker, "core_worker"): + obj = self.worker.core_worker.try_read_next_object_ref_stream( + self._generator_ref) + return obj + else: + raise ValueError( + "Cannot access the core worker. " + "Did you already shutdown Ray via ray.shutdown()?") + + def _handle_error( + self, + is_async: bool, + last_time: int, + timeout_s: float, + unexpected_network_failure_timeout_s: float): + """Handle the error case of next APIs. + + Return None if there's no error. Returns a ref if + the ref is supposed to be return. + """ + if self._generator_task_exception: + # The generator task has failed already. + # We raise StopIteration + # to conform the next interface in Python. + if is_async: + raise StopAsyncIteration else: - raise ValueError( - "Cannot access the core worker. " - "Did you already shutdown Ray via ray.shutdown()?") - except ObjectRefStreamEoFError: - raise StopIteration from None + raise StopIteration + else: + # Otherwise, we should ray.get on the generator + # ref to find if the task has a system failure. + # Return the generator ref that contains the system + # error as soon as possible. + r, _ = ray.wait([self._generator_ref], timeout=0) + if len(r) > 0: + try: + ray.get(r) + except Exception as e: + # If it has failed, return the generator task ref + # so that the ref will raise an exception. + self._generator_task_exception = e + return self._generator_ref + finally: + if self._generator_task_completed_time is None: + self._generator_task_completed_time = time.time() + + # Currently, since the ordering of intermediate result report + # is not guaranteed, it is possible that althoug the task + # has succeeded, all of the object references are not reported + # (e.g., when there are network failures). + # If all the object refs are not reported to the generator + # within 30 seconds, we consider is as an unreconverable error. + if self._generator_task_completed_time: + if (time.time() - self._generator_task_completed_time + > unexpected_network_failure_timeout_s): + # It means the next wasn't reported although the task + # has been terminated 30 seconds ago. + self._generator_task_exception = AssertionError + assert False, ( + "Unexpected network failure occured. " + f"Task ID: {self._generator_ref.task_id().hex()}" + ) + + if timeout_s != -1 and time.time() - last_time > timeout_s: + return ObjectRef.nil() + + return None def __del__(self): if hasattr(self.worker, "core_worker"): + # The stream is created when a task is first submitted via + # CreateObjectRefStream. # NOTE: This can be called multiple times # because python doesn't guarantee __del__ is called # only once. 
@@ -356,8 +442,8 @@ cdef int check_status(const CRayStatus& status) nogil except -1: raise ObjectStoreFullError(message) elif status.IsOutOfDisk(): raise OutOfDiskError(message) - elif status.IsObjectRefStreamEoF(): - raise ObjectRefStreamEoFError(message) + elif status.IsObjectRefEndOfStream(): + raise ObjectRefStreamEneOfStreamError(message) elif status.IsInterrupted(): raise KeyboardInterrupt() elif status.IsTimedOut(): @@ -442,7 +528,6 @@ cdef increase_recursion_limit(): int CURRENT_DEPTH(CPyThreadState *x) int current_depth = CURRENT_DEPTH(s) - if current_limit - current_depth < 500: Py_SetRecursionLimit(new_limit) logger.debug("Increasing Python recursion limit to {} " @@ -818,6 +903,8 @@ cdef execute_streaming_generator( title, actor, actor_id, + name_of_concurrency_group_to_execute, + return_size, c_bool *is_retryable_error, c_string *application_error): """Execute a given generator and streaming-report the @@ -850,23 +937,37 @@ cdef execute_streaming_generator( actor: The instance of the actor created in this worker. It is used to write an error message. actor_id: The ID of the actor. It is used to write an error message. + return_size: The number of static returns. is_retryable_error(out): It is set to True if the generator raises an exception, and the error is retryable. application_error(out): It is set if the generator raises an application error. """ worker = ray._private.worker.global_worker + # Generator task should only have 1 return object ref, + # which contains None or exceptions (if system error occurs). + assert return_size == 1 + cdef: CoreWorker core_worker = worker.core_worker generator_index = 0 - assert inspect.isgenerator(generator), ( - "execute_streaming_generator's first argument must be a generator." - ) + is_async = inspect.isasyncgen(generator) while True: try: - output = next(generator) + if is_async: + output = core_worker.run_async_func_or_coro_in_event_loop( + generator.__anext__(), + function_descriptor, + name_of_concurrency_group_to_execute) + else: + output = next(generator) + except AsyncioActorExit: + # Make the task handle this exception. + raise + except StopAsyncIteration: + break except StopIteration: break except Exception as e: @@ -882,6 +983,9 @@ cdef execute_streaming_generator( title, actor, actor_id, + return_size, + generator_index, + is_async, is_retryable_error, application_error ) @@ -899,7 +1003,11 @@ cdef execute_streaming_generator( output, generator_id, worker, - caller_address) + caller_address, + task_id, + return_size, + generator_index, + is_async) # Del output here so that we can GC the memory # usage asap. del output @@ -932,7 +1040,11 @@ cdef c_pair[CObjectID, shared_ptr[CRayObject]] create_generator_return_obj( output, const CObjectID &generator_id, worker: "Worker", - const CAddress &caller_address): + const CAddress &caller_address, + TaskID task_id, + return_size, + generator_index, + is_async): """Create a generator return object based on a given output. Args: @@ -942,6 +1054,11 @@ cdef c_pair[CObjectID, shared_ptr[CRayObject]] create_generator_return_obj( caller_address: The address of the caller. By our protocol, the caller of the streaming generator task is always the owner, so we can also call it "owner address". + task_id: The task ID of the generator task. + return_size: The number of static returns. + generator_index: The index of a current error object. + is_async: Whether or not the given object is created within + an async actor. Returns: A Ray Object that contains the given output. 
@@ -950,9 +1067,13 @@ cdef c_pair[CObjectID, shared_ptr[CRayObject]] create_generator_return_obj( c_vector[c_pair[CObjectID, shared_ptr[CRayObject]]] intermediate_result CoreWorker core_worker = worker.core_worker - return_id = ( - CCoreWorkerProcess.GetCoreWorker().AllocateDynamicReturnId( - caller_address)) + return_id = core_worker.allocate_dynamic_return_id_for_generator( + caller_address, + task_id.native(), + return_size, + generator_index, + is_async, + ) intermediate_result.push_back( c_pair[CObjectID, shared_ptr[CRayObject]]( return_id, shared_ptr[CRayObject]())) @@ -977,6 +1098,9 @@ cdef c_pair[CObjectID, shared_ptr[CRayObject]] create_generator_error_object( title, actor, actor_id, + return_size, + generator_index, + is_async, c_bool *is_retryable_error, c_string *application_error): """Create a generator error object. @@ -1004,6 +1128,10 @@ cdef c_pair[CObjectID, shared_ptr[CRayObject]] create_generator_error_object( actor: The instance of the actor created in this worker. It is used to write an error message. actor_id: The ID of the actor. It is used to write an error message. + return_size: The number of static returns. + generator_index: The index of a current error object. + is_async: Whether or not the given object is created within + an async actor. is_retryable_error(out): It is set to True if the generator raises an exception, and the error is retryable. application_error(out): It is set if the generator raises an @@ -1038,8 +1166,13 @@ cdef c_pair[CObjectID, shared_ptr[CRayObject]] create_generator_error_object( "Task failed with unretryable exception:" " {}.".format(task_id), exc_info=True) - error_id = (CCoreWorkerProcess.GetCoreWorker() - .AllocateDynamicReturnId(caller_address)) + error_id = core_worker.allocate_dynamic_return_id_for_generator( + caller_address, + task_id.native(), + return_size, + generator_index, + is_async, + ) intermediate_result.push_back( c_pair[CObjectID, shared_ptr[CRayObject]]( error_id, shared_ptr[CRayObject]())) @@ -1105,7 +1238,8 @@ cdef execute_dynamic_generator_and_store_task_outputs( # generate one additional ObjectRef. This last # ObjectRef will contain the error. error_id = (CCoreWorkerProcess.GetCoreWorker() - .AllocateDynamicReturnId(caller_address)) + .AllocateDynamicReturnId( + caller_address, CTaskID.Nil(), NULL_PUT_INDEX)) dynamic_returns[0].push_back( c_pair[CObjectID, shared_ptr[CRayObject]]( error_id, shared_ptr[CRayObject]())) @@ -1213,10 +1347,10 @@ cdef void execute_task( if core_worker.current_actor_is_asyncio(): if len(inspect.getmembers( actor.__class__, - predicate=inspect.iscoroutinefunction)) == 0: + predicate=is_async_func)) == 0: error_message = ( - "Failed to create actor. The failure reason " - "is that you set the async flag, but the actor does not " + "Failed to create actor. You set the async flag, " + "but the actor does not " "have any coroutine functions.") raise RayActorError( ActorDiedErrorContext( @@ -1226,7 +1360,7 @@ cdef void execute_task( ) ) - if inspect.iscoroutinefunction(function.method): + if is_async_func(function.method): async_function = function else: # Just execute the method if it's ray internal method. 
@@ -1234,10 +1368,15 @@ cdef void execute_task( return function(actor, *arguments, **kwarguments) async_function = sync_to_async(function) - return core_worker.run_async_func_in_event_loop( - async_function, function_descriptor, - name_of_concurrency_group_to_execute, actor, - *arguments, **kwarguments) + if inspect.isasyncgenfunction(function.method): + # The coroutine will be handled separately by + # execute_dynamic_generator_and_store_task_outputs + return async_function(actor, *arguments, **kwarguments) + else: + return core_worker.run_async_func_or_coro_in_event_loop( + async_function, function_descriptor, + name_of_concurrency_group_to_execute, actor, + *arguments, **kwarguments) return function(actor, *arguments, **kwarguments) @@ -1259,7 +1398,7 @@ cdef void execute_task( return (ray._private.worker.global_worker .deserialize_objects( metadata_pairs, object_refs)) - args = core_worker.run_async_func_in_event_loop( + args = core_worker.run_async_func_or_coro_in_event_loop( deserialize_args, function_descriptor, name_of_concurrency_group_to_execute) else: @@ -1313,7 +1452,8 @@ cdef void execute_task( # which is the generator task return. assert returns[0].size() == 1 - if not inspect.isgenerator(outputs): + if (not inspect.isgenerator(outputs) + and not inspect.isasyncgen(outputs)): raise ValueError( "Functions with " "@ray.remote(num_returns=\"streaming\" " @@ -1331,6 +1471,8 @@ cdef void execute_task( title, actor, actor_id, + name_of_concurrency_group_to_execute, + returns[0].size(), is_retryable_error, application_error) # Streaming generator output is not used, so set it to None. @@ -1386,7 +1528,9 @@ cdef void execute_task( print(task_attempt_magic_token, end="") print(task_attempt_magic_token, file=sys.stderr, end="") - if returns[0].size() == 1 and not inspect.isgenerator(outputs): + if (returns[0].size() == 1 + and not inspect.isgenerator(outputs) + and not inspect.isasyncgen(outputs)): # If there is only one return specified, we should return # all return values as a single object. outputs = (outputs,) @@ -1411,9 +1555,10 @@ cdef void execute_task( # like GCS has such info. core_worker.set_actor_repr_name(actor_repr) - if (returns[0].size() > 0 and - not inspect.isgenerator(outputs) and - len(outputs) != int(returns[0].size())): + if (returns[0].size() > 0 + and not inspect.isgenerator(outputs) + and not inspect.isasyncgen(outputs) + and len(outputs) != int(returns[0].size())): raise ValueError( "Task returned {} objects, but num_returns={}.".format( len(outputs), returns[0].size())) @@ -1603,7 +1748,8 @@ cdef execute_task_with_cancellation_handler( actor, actor_id, execution_info.function_name, - task_type, title, caller_address, returns, + task_type, title, caller_address, + returns, # application_error: we are passing NULL since we don't want the # cancel tasks to fail. NULL) @@ -3267,6 +3413,7 @@ cdef class CoreWorker: int64_t task_output_inlined_bytes int64_t num_returns = -1 shared_ptr[CRayObject] *return_ptr + num_outputs_stored = 0 if not ref_generator_id.IsNil(): # The task specified a dynamic number of return values. Determine @@ -3301,7 +3448,8 @@ cdef class CoreWorker: # enabled by default. 
while i >= returns[0].size(): return_id = (CCoreWorkerProcess.GetCoreWorker() - .AllocateDynamicReturnId(caller_address)) + .AllocateDynamicReturnId( + caller_address, CTaskID.Nil(), NULL_PUT_INDEX)) returns[0].push_back( c_pair[CObjectID, shared_ptr[CRayObject]]( return_id, shared_ptr[CRayObject]())) @@ -3431,9 +3579,23 @@ cdef class CoreWorker: return self.eventloop_for_default_cg, self.thread_for_default_cg - def run_async_func_in_event_loop( - self, func, function_descriptor, specified_cgname, *args, **kwargs): + def run_async_func_or_coro_in_event_loop( + self, + func_or_coro: Union[Callable[[Any, Any], Awaitable[Any]], Awaitable], + function_descriptor: FunctionDescriptor, + specified_cgname: str, + *args, + **kwargs): + """Run the async function or coroutine to the event loop. + The event loop is running in a separate thread. + Args: + func_or_coro: Async function (not a generator) or awaitable objects. + function_descriptor: The function descriptor. + specified_cgname: The name of a concurrent group. + args: The arguments for the async function. + kwargs: The keyword arguments for the async function. + """ cdef: CFiberEvent event @@ -3448,7 +3610,12 @@ cdef class CoreWorker: eventloop, async_thread = self.get_event_loop( function_descriptor, specified_cgname) - coroutine = func(*args, **kwargs) + + if inspect.isawaitable(func_or_coro): + coroutine = func_or_coro + else: + coroutine = func_or_coro(*args, **kwargs) + future = asyncio.run_coroutine_threadsafe(coroutine, eventloop) future.add_done_callback(lambda _: event.Notify()) with nogil: @@ -3583,11 +3750,71 @@ cdef class CoreWorker: CCoreWorkerProcess.GetCoreWorker() \ .RecordTaskLogEnd(out_end_offset, err_end_offset) - def create_object_ref_stream(self, ObjectRef generator_id): - cdef: - CObjectID c_generator_id = generator_id.native() + cdef CObjectID allocate_dynamic_return_id_for_generator( + self, + const CAddress &owner_address, + const CTaskID &task_id, + return_size, + generator_index, + is_async_actor): + """Allocate a dynamic return ID for a generator task. + + NOTE: When is_async_actor is True, + this API SHOULD NOT BE called + within an async actor's event IO thread. The caller MUST ensure + this for correctness. It is due to the limitation WorkerContext + API when async actor is used. + See https://github.com/ray-project/ray/issues/10324 for further details. - CCoreWorkerProcess.GetCoreWorker().CreateObjectRefStream(c_generator_id) + Args: + owner_address: The address of the owner (caller) of the + generator task. + task_id: The task ID of the generator task. + return_size: The size of the static return from the task. + generator_index: The index of dynamically generated object + ref. + is_async_actor: True if the allocation is for async actor. + If async actor is used, we should calculate the + put_index ourselves. + """ + # Generator only has 1 static return. + assert return_size == 1 + if is_async_actor: + # This part of code has a couple of assumptions. + # - This API is not called within an asyncio event loop + # thread. + # - Ray object ref is generated by incrementing put_index + # whenever a new return value is added or ray.put is called. + # + # When an async actor is used, it uses its own thread to execute + # async tasks. That means all the ray.put will use a put_index + # scoped to a asyncio event loop thread. + # This means the execution thread that this API will be called + # will only create "return" objects. 
That means if we use + # return_size + genreator_index as a put_index, it is guaranteed + # to be unique. + # + # Why do we need it? + # + # We have to provide a put_index ourselves here because + # the current implementation only has 1 worker context at any + # given time, meaning WorkerContext::TaskID & WorkerContext::PutIndex + # both could be incorrect (duplicated) when this API is called. + return CCoreWorkerProcess.GetCoreWorker().AllocateDynamicReturnId( + owner_address, + task_id, + # Should add 1 because put index is always incremented + # before it is used. So if you have 1 return object + # the next index will be 2. + make_optional[ObjectIDIndexType]( + 1 + return_size + generator_index) # put_index + ) + else: + return CCoreWorkerProcess.GetCoreWorker().AllocateDynamicReturnId( + owner_address, + CTaskID.Nil(), + NULL_PUT_INDEX + ) def delete_object_ref_stream(self, ObjectRef generator_id): cdef: diff --git a/python/ray/actor.py b/python/ray/actor.py index bfd28aa9ab0f2..c9a0600315123 100644 --- a/python/ray/actor.py +++ b/python/ray/actor.py @@ -9,6 +9,7 @@ import ray._raylet from ray import ActorClassID, Language, cross_language from ray._private import ray_option_utils +from ray._private.async_compat import is_async_func from ray._private.auto_init_hook import auto_init_ray from ray._private.client_mode_hook import ( client_mode_convert_actor, @@ -756,12 +757,7 @@ def _remote(self, args=None, kwargs=None, **actor_options): kwargs = {} meta = self.__ray_metadata__ actor_has_async_methods = ( - len( - inspect.getmembers( - meta.modified_class, predicate=inspect.iscoroutinefunction - ) - ) - > 0 + len(inspect.getmembers(meta.modified_class, predicate=is_async_func)) > 0 ) is_asyncio = actor_has_async_methods diff --git a/python/ray/includes/common.pxd b/python/ray/includes/common.pxd index 2d284fd7c8ae3..13ed1c06d6429 100644 --- a/python/ray/includes/common.pxd +++ b/python/ray/includes/common.pxd @@ -100,7 +100,7 @@ cdef extern from "ray/common/status.h" namespace "ray" nogil: CRayStatus NotFound() @staticmethod - CRayStatus ObjectRefStreamEoF() + CRayStatus ObjectRefEndOfStream() c_bool ok() c_bool IsOutOfMemory() @@ -121,7 +121,7 @@ cdef extern from "ray/common/status.h" namespace "ray" nogil: c_bool IsObjectUnknownOwner() c_bool IsRpcError() c_bool IsOutOfResource() - c_bool IsObjectRefStreamEoF() + c_bool IsObjectRefEndOfStream() c_string ToString() c_string CodeAsString() diff --git a/python/ray/includes/libcoreworker.pxd b/python/ray/includes/libcoreworker.pxd index 3998de724433b..cc0b3092ffb28 100644 --- a/python/ray/includes/libcoreworker.pxd +++ b/python/ray/includes/libcoreworker.pxd @@ -19,6 +19,7 @@ from ray.includes.unique_ids cimport ( CObjectID, CPlacementGroupID, CWorkerID, + ObjectIDIndexType, ) from ray.includes.common cimport ( @@ -49,7 +50,7 @@ from ray.includes.function_descriptor cimport ( ) from ray.includes.optional cimport ( - optional + optional, ) ctypedef unordered_map[c_string, c_vector[pair[int64_t, double]]] \ @@ -148,11 +149,13 @@ cdef extern from "ray/core_worker/core_worker.h" nogil: shared_ptr[CRayObject] *return_object, const CObjectID& generator_id) void DelObjectRefStream(const CObjectID &generator_id) - void CreateObjectRefStream(const CObjectID &generator_id) CRayStatus TryReadObjectRefStream( const CObjectID &generator_id, CObjectReference *object_ref_out) - CObjectID AllocateDynamicReturnId(const CAddress &owner_address) + CObjectID AllocateDynamicReturnId( + const CAddress &owner_address, + const CTaskID &task_id, + 
optional[ObjectIDIndexType] put_index) CJobID GetCurrentJobId() CTaskID GetCurrentTaskId() diff --git a/python/ray/includes/unique_ids.pxd b/python/ray/includes/unique_ids.pxd index cd7890119a407..2fb14e6322c0f 100644 --- a/python/ray/includes/unique_ids.pxd +++ b/python/ray/includes/unique_ids.pxd @@ -173,3 +173,5 @@ cdef extern from "ray/common/id.h" namespace "ray" nogil: @staticmethod CPlacementGroupID Of(CJobID job_id) + + ctypedef uint32_t ObjectIDIndexType diff --git a/python/ray/includes/unique_ids.pxi b/python/ray/includes/unique_ids.pxi index 2b4f5c78f5ba3..8221111a29556 100644 --- a/python/ray/includes/unique_ids.pxi +++ b/python/ray/includes/unique_ids.pxi @@ -21,9 +21,10 @@ from ray.includes.unique_ids cimport ( CTaskID, CUniqueID, CWorkerID, - CPlacementGroupID + CPlacementGroupID, ) + import ray from ray._private.utils import decode diff --git a/python/ray/tests/test_async.py b/python/ray/tests/test_async.py index 21fa6a026c312..5136fb2bb593b 100644 --- a/python/ray/tests/test_async.py +++ b/python/ray/tests/test_async.py @@ -8,6 +8,7 @@ import pytest import ray +from ray._private.async_compat import is_async_func from ray._private.test_utils import wait_for_condition from ray._private.utils import ( get_or_create_event_loop, @@ -33,6 +34,25 @@ def f(n): return [f.remote(i) for i in range(5)] +def test_is_async_func(): + def f(): + return 1 + + def f_gen(): + yield 1 + + async def g(): + return 1 + + async def g_gen(): + yield 1 + + assert is_async_func(f) is False + assert is_async_func(f_gen) is False + assert is_async_func(g) is True + assert is_async_func(g_gen) is True + + def test_simple(init): @ray.remote def f(): diff --git a/python/ray/tests/test_runtime_context.py b/python/ray/tests/test_runtime_context.py index 42b7b5fed42e5..503ab6a10320e 100644 --- a/python/ray/tests/test_runtime_context.py +++ b/python/ray/tests/test_runtime_context.py @@ -240,6 +240,27 @@ async def func(self): assert max(result["AysncActor.func"]["pending"] for result in results) == 3 +def test_actor_stats_async_actor_generator(ray_start_regular): + signal = SignalActor.remote() + + @ray.remote + class AysncActor: + async def func(self): + await signal.wait.remote() + yield ray.get_runtime_context()._get_actor_call_stats() + + actor = AysncActor.options(max_concurrency=3).remote() + gens = [actor.func.options(num_returns="streaming").remote() for _ in range(6)] + time.sleep(1) + signal.send.remote() + results = [] + for gen in gens: + for ref in gen: + results.append(ray.get(ref)) + assert max(result["AysncActor.func"]["running"] for result in results) == 3 + assert max(result["AysncActor.func"]["pending"] for result in results) == 3 + + # Use default filterwarnings behavior for this test @pytest.mark.filterwarnings("default") def test_ids(ray_start_regular): diff --git a/python/ray/tests/test_streaming_generator.py b/python/ray/tests/test_streaming_generator.py index 4e46ba66837f2..68b0c6ba5ed3d 100644 --- a/python/ray/tests/test_streaming_generator.py +++ b/python/ray/tests/test_streaming_generator.py @@ -1,3 +1,4 @@ +import asyncio import pytest import numpy as np import sys @@ -9,7 +10,7 @@ import ray from ray._private.test_utils import wait_for_condition from ray.experimental.state.api import list_objects -from ray._raylet import StreamingObjectRefGenerator, ObjectRefStreamEoFError +from ray._raylet import StreamingObjectRefGenerator, ObjectRefStreamEneOfStreamError from ray.cloudpickle import dumps from ray.exceptions import WorkerCrashedError @@ -45,28 +46,26 @@ def 
test_streaming_object_ref_generator_basic_unit(mocked_worker): generator_ref = ray.ObjectRef.from_random() generator = StreamingObjectRefGenerator(generator_ref, mocked_worker) c.try_read_next_object_ref_stream.return_value = ray.ObjectRef.nil() - c.create_object_ref_stream.assert_called() # Test when there's no new ref, it returns a nil. mocked_ray_wait.return_value = [], [generator_ref] - ref = generator._next(timeout_s=0) + ref = generator._next_sync(timeout_s=0) assert ref.is_nil() # When the new ref is available, next should return it. for _ in range(3): new_ref = ray.ObjectRef.from_random() c.try_read_next_object_ref_stream.return_value = new_ref - ref = generator._next(timeout_s=0) + ref = generator._next_sync(timeout_s=0) assert new_ref == ref # When try_read_next_object_ref_stream raises a - # ObjectRefStreamEoFError, it should raise a stop iteration. - c.try_read_next_object_ref_stream.side_effect = ObjectRefStreamEoFError( + # ObjectRefStreamEneOfStreamError, it should raise a stop iteration. + c.try_read_next_object_ref_stream.side_effect = ObjectRefStreamEneOfStreamError( "" ) # noqa with pytest.raises(StopIteration): - ref = generator._next(timeout_s=0) - + ref = generator._next_sync(timeout_s=0) # Make sure we cannot serialize the generator. with pytest.raises(TypeError): dumps(generator) @@ -91,19 +90,17 @@ def test_streaming_object_ref_generator_task_failed_unit(mocked_worker): mocked_ray_get.side_effect = WorkerCrashedError() c.try_read_next_object_ref_stream.return_value = ray.ObjectRef.nil() - ref = generator._next(timeout_s=0) + ref = generator._next_sync(timeout_s=0) # If the generator task fails by a systsem error, # meaning the ref will raise an exception # it should be returned. - print(ref) - print(generator_ref) assert ref == generator_ref # Once exception is raised, it should always # raise stopIteration regardless of what # the ref contains now. with pytest.raises(StopIteration): - ref = generator._next(timeout_s=0) + ref = generator._next_sync(timeout_s=0) def test_streaming_object_ref_generator_network_failed_unit(mocked_worker): @@ -128,14 +125,83 @@ def test_streaming_object_ref_generator_network_failed_unit(mocked_worker): # unexpected_network_failure_timeout_s second, # it should fail. c.try_read_next_object_ref_stream.return_value = ray.ObjectRef.nil() - ref = generator._next(timeout_s=0, unexpected_network_failure_timeout_s=1) + ref = generator._next_sync( + timeout_s=0, unexpected_network_failure_timeout_s=1 + ) assert ref == ray.ObjectRef.nil() time.sleep(1) with pytest.raises(AssertionError): - generator._next(timeout_s=0, unexpected_network_failure_timeout_s=1) + generator._next_sync( + timeout_s=0, unexpected_network_failure_timeout_s=1 + ) # After that StopIteration should be raised. with pytest.raises(StopIteration): - generator._next(timeout_s=0, unexpected_network_failure_timeout_s=1) + generator._next_sync( + timeout_s=0, unexpected_network_failure_timeout_s=1 + ) + + +@pytest.mark.asyncio +async def test_streaming_object_ref_generator_unit_async(mocked_worker): + """ + Verify the basic case: + create a generator -> read values -> nothing more to read -> delete. + """ + with patch("ray.wait") as mocked_ray_wait: + c = mocked_worker.core_worker + generator_ref = ray.ObjectRef.from_random() + generator = StreamingObjectRefGenerator(generator_ref, mocked_worker) + c.try_read_next_object_ref_stream.return_value = ray.ObjectRef.nil() + + # Test when there's no new ref, it returns a nil. 
+ mocked_ray_wait.return_value = [], [generator_ref] + ref = await generator._next_async(timeout_s=0) + assert ref.is_nil() + + # When the new ref is available, next should return it. + for _ in range(3): + new_ref = ray.ObjectRef.from_random() + c.try_read_next_object_ref_stream.return_value = new_ref + ref = await generator._next_async(timeout_s=0) + assert new_ref == ref + + # When try_read_next_object_ref_stream raises a + # ObjectRefStreamEneOfStreamError, it should raise a stop iteration. + c.try_read_next_object_ref_stream.side_effect = ObjectRefStreamEneOfStreamError( + "" + ) # noqa + with pytest.raises(StopAsyncIteration): + ref = await generator._next_async(timeout_s=0) + + +@pytest.mark.asyncio +async def test_async_ref_generator_task_failed_unit(mocked_worker): + """ + Verify when a task is failed by a system error, + the generator ref is returned. + """ + with patch("ray.get") as mocked_ray_get: + with patch("ray.wait") as mocked_ray_wait: + c = mocked_worker.core_worker + generator_ref = ray.ObjectRef.from_random() + generator = StreamingObjectRefGenerator(generator_ref, mocked_worker) + + # Simulate the worker failure happens. + mocked_ray_wait.return_value = [generator_ref], [] + mocked_ray_get.side_effect = WorkerCrashedError() + + c.try_read_next_object_ref_stream.return_value = ray.ObjectRef.nil() + ref = await generator._next_async(timeout_s=0) + # If the generator task fails by a systsem error, + # meaning the ref will raise an exception + # it should be returned. + assert ref == generator_ref + + # Once exception is raised, it should always + # raise stopIteration regardless of what + # the ref contains now. + with pytest.raises(StopAsyncIteration): + ref = await generator._next_async(timeout_s=0) def test_generator_basic(shutdown_only): @@ -368,6 +434,7 @@ def generator(num_returns, store_in_plasma): def test_generator_dist_chain(ray_start_cluster): + """E2E test to verify chain of generator works properly.""" cluster = ray_start_cluster cluster.add_node(num_cpus=0, object_store_memory=1 * 1024 * 1024 * 1024) ray.init() @@ -402,6 +469,224 @@ def get_data(self): del ref +@pytest.mark.parametrize("store_in_plasma", [False, True]) +def test_actor_streaming_generator(shutdown_only, store_in_plasma): + """Test actor/async actor with sync/async generator interfaces.""" + ray.init() + + @ray.remote + class Actor: + def f(self, ref): + for i in range(3): + yield i + + async def async_f(self, ref): + for i in range(3): + await asyncio.sleep(0.1) + yield i + + def g(self): + return 3 + + a = Actor.remote() + if store_in_plasma: + arr = np.random.rand(5 * 1024 * 1024) + else: + arr = 3 + + def verify_sync_task_executor(): + generator = a.f.options(num_returns="streaming").remote(ray.put(arr)) + # Verify it works with next. + assert isinstance(generator, StreamingObjectRefGenerator) + assert ray.get(next(generator)) == 0 + assert ray.get(next(generator)) == 1 + assert ray.get(next(generator)) == 2 + with pytest.raises(StopIteration): + ray.get(next(generator)) + + # Verify it works with for. + generator = a.f.options(num_returns="streaming").remote(ray.put(3)) + for index, ref in enumerate(generator): + assert index == ray.get(ref) + + def verify_async_task_executor(): + # Verify it works with next. 
+ generator = a.async_f.options(num_returns="streaming").remote(ray.put(arr)) + assert isinstance(generator, StreamingObjectRefGenerator) + assert ray.get(next(generator)) == 0 + assert ray.get(next(generator)) == 1 + assert ray.get(next(generator)) == 2 + + # Verify it works with for. + generator = a.f.options(num_returns="streaming").remote(ray.put(3)) + for index, ref in enumerate(generator): + assert index == ray.get(ref) + + async def verify_sync_task_async_generator(): + # Verify anext + async_generator = a.f.options(num_returns="streaming").remote(ray.put(arr)) + assert isinstance(async_generator, StreamingObjectRefGenerator) + for expected in range(3): + ref = await async_generator.__anext__() + assert await ref == expected + with pytest.raises(StopAsyncIteration): + await async_generator.__anext__() + + # Verify async for. + async_generator = a.f.options(num_returns="streaming").remote(ray.put(arr)) + expected = 0 + async for ref in async_generator: + value = await ref + assert value == expected + expected += 1 + + async def verify_async_task_async_generator(): + async_generator = a.async_f.options(num_returns="streaming").remote( + ray.put(arr) + ) + assert isinstance(async_generator, StreamingObjectRefGenerator) + for expected in range(3): + ref = await async_generator.__anext__() + assert await ref == expected + with pytest.raises(StopAsyncIteration): + await async_generator.__anext__() + + # Verify async for. + async_generator = a.async_f.options(num_returns="streaming").remote( + ray.put(arr) + ) + expected = 0 + async for ref in async_generator: + value = await ref + assert value == expected + expected += 1 + + verify_sync_task_executor() + verify_async_task_executor() + asyncio.run(verify_sync_task_async_generator()) + asyncio.run(verify_async_task_async_generator()) + + +def test_streaming_generator_exception(shutdown_only): + # Verify the exceptions are correctly raised. + # Also verify the follow-up next will raise StopIteration. 
+ ray.init() + + @ray.remote + class Actor: + def f(self): + raise ValueError + yield 1 # noqa + + async def async_f(self): + raise ValueError + yield 1 # noqa + + a = Actor.remote() + g = a.f.options(num_returns="streaming").remote() + with pytest.raises(ValueError): + ray.get(next(g)) + + with pytest.raises(StopIteration): + ray.get(next(g)) + + with pytest.raises(StopIteration): + ray.get(next(g)) + + g = a.async_f.options(num_returns="streaming").remote() + with pytest.raises(ValueError): + ray.get(next(g)) + + with pytest.raises(StopIteration): + ray.get(next(g)) + + with pytest.raises(StopIteration): + ray.get(next(g)) + + +def test_threaded_actor_generator(shutdown_only): + ray.init() + + @ray.remote(max_concurrency=10) + class Actor: + def f(self): + for i in range(30): + time.sleep(0.1) + yield np.ones(1024 * 1024) * i + + @ray.remote(max_concurrency=20) + class AsyncActor: + async def f(self): + for i in range(30): + await asyncio.sleep(0.1) + yield np.ones(1024 * 1024) * i + + async def main(): + a = Actor.remote() + asy = AsyncActor.remote() + + async def run(): + i = 0 + async for ref in a.f.options(num_returns="streaming").remote(): + val = ray.get(ref) + print(val) + print(ref) + assert np.array_equal(val, np.ones(1024 * 1024) * i) + i += 1 + del ref + + async def run2(): + i = 0 + async for ref in asy.f.options(num_returns="streaming").remote(): + val = await ref + print(ref) + print(val) + assert np.array_equal(val, np.ones(1024 * 1024) * i), ref + i += 1 + del ref + + coroutines = [run() for _ in range(10)] + coroutines = [run2() for _ in range(20)] + + await asyncio.gather(*coroutines) + + asyncio.run(main()) + + +def test_generator_dist_gather(ray_start_cluster): + cluster = ray_start_cluster + cluster.add_node(num_cpus=0, object_store_memory=1 * 1024 * 1024 * 1024) + ray.init() + cluster.add_node(num_cpus=1) + cluster.add_node(num_cpus=1) + cluster.add_node(num_cpus=1) + cluster.add_node(num_cpus=1) + + @ray.remote(num_cpus=1) + class Actor: + def __init__(self, child=None): + self.child = child + + def get_data(self): + for _ in range(10): + time.sleep(0.1) + yield np.ones(5 * 1024 * 1024) + + async def all_gather(): + actor = Actor.remote() + async for ref in actor.get_data.options(num_returns="streaming").remote(): + val = await ref + assert np.array_equal(np.ones(5 * 1024 * 1024), val) + del ref + + async def main(): + await asyncio.gather(all_gather(), all_gather(), all_gather(), all_gather()) + + asyncio.run(main()) + summary = ray._private.internal_api.memory_summary(stats_only=True) + print(summary) + + if __name__ == "__main__": import os diff --git a/python/ray/util/tracing/tracing_helper.py b/python/ray/util/tracing/tracing_helper.py index 0c027c33a8e7f..985edb0d612cc 100644 --- a/python/ray/util/tracing/tracing_helper.py +++ b/python/ray/util/tracing/tracing_helper.py @@ -520,6 +520,12 @@ async def _resume_span( if is_static_method(_cls, name) or is_class_method(method): continue + if inspect.isgeneratorfunction(method) or inspect.isasyncgenfunction(method): + # Right now, this method somehow changes the signature of the method + # when they are generator. + # TODO(sang): Fix it. + continue + # Don't decorate the __del__ magic method. 
# It's because the __del__ can be called after Python # modules are garbage colleted, which means the modules diff --git a/src/ray/common/status.h b/src/ray/common/status.h index 25d9befdfd089..cfbcff3dfc897 100644 --- a/src/ray/common/status.h +++ b/src/ray/common/status.h @@ -115,7 +115,8 @@ enum class StatusCode : char { ObjectUnknownOwner = 29, RpcError = 30, OutOfResource = 31, - ObjectRefStreamEoF = 32 + // Means the ObjectRefStream has reached the end of the stream. + ObjectRefEndOfStream = 32 }; #if defined(__clang__) @@ -147,8 +148,8 @@ class RAY_EXPORT Status { return Status(StatusCode::KeyError, msg); } - static Status ObjectRefStreamEoF(const std::string &msg) { - return Status(StatusCode::ObjectRefStreamEoF, msg); + static Status ObjectRefEndOfStream(const std::string &msg) { + return Status(StatusCode::ObjectRefEndOfStream, msg); } static Status TypeError(const std::string &msg) { @@ -259,7 +260,9 @@ class RAY_EXPORT Status { bool IsOutOfMemory() const { return code() == StatusCode::OutOfMemory; } bool IsOutOfDisk() const { return code() == StatusCode::OutOfDisk; } bool IsKeyError() const { return code() == StatusCode::KeyError; } - bool IsObjectRefStreamEoF() const { return code() == StatusCode::ObjectRefStreamEoF; } + bool IsObjectRefEndOfStream() const { + return code() == StatusCode::ObjectRefEndOfStream; + } bool IsInvalid() const { return code() == StatusCode::Invalid; } bool IsIOError() const { return code() == StatusCode::IOError; } bool IsTypeError() const { return code() == StatusCode::TypeError; } diff --git a/src/ray/core_worker/context.cc b/src/ray/core_worker/context.cc index 125f42d17e392..7715d96368510 100644 --- a/src/ray/core_worker/context.cc +++ b/src/ray/core_worker/context.cc @@ -363,6 +363,29 @@ bool WorkerContext::CurrentActorDetached() const { return is_detached_actor_; } +const ObjectID WorkerContext::GetGeneratorReturnId( + const TaskID &task_id, std::optional<ObjectIDIndexType> put_index) { + TaskID current_task_id; + // We only allow specifying both task id and put index, or specifying neither of them. + RAY_CHECK((task_id.IsNil() && !put_index.has_value()) || + (!task_id.IsNil() || put_index.has_value())); + if (task_id.IsNil()) { + const auto &task_spec = GetCurrentTask(); + current_task_id = task_spec->TaskId(); + } else { + current_task_id = task_id; + } + + ObjectIDIndexType current_put_index; + if (!put_index.has_value()) { + current_put_index = GetNextPutIndex(); + } else { + current_put_index = put_index.value(); + } + + return ObjectID::FromIndex(current_task_id, current_put_index); +} + WorkerThreadContext &WorkerContext::GetThreadContext() const { if (thread_context_ == nullptr) { absl::ReaderMutexLock lock(&mutex_); diff --git a/src/ray/core_worker/context.h b/src/ray/core_worker/context.h index 6920639005268..b7d2d50e7260e 100644 --- a/src/ray/core_worker/context.h +++ b/src/ray/core_worker/context.h @@ -31,6 +31,26 @@ class WorkerContext { public: WorkerContext(WorkerType worker_type, const WorkerID &worker_id, const JobID &job_id); + /// Return the generator return ID. + /// + /// By default, it deduces a generator return ID from the current task + /// in the context. However, it also supports manual specification of + /// put index and task id to support `AllocateDynamicReturnId`. + /// See the docstring of AllocateDynamicReturnId for more details. + /// + /// The caller should either specify both task_id AND put_index + /// or neither of them. Otherwise it will panic. 
+ /// + /// \param[in] task_id The task id of the dynamically generated return ID. + /// If Nil() is specified, it will deduce the Task ID from the current + /// worker context. + /// \param[in] put_index The equivalent of the return value of + /// WorkerContext::GetNextPutIndex. + /// If std::nullopt is specified, it will deduce the put index from the + /// current worker context. + const ObjectID GetGeneratorReturnId(const TaskID &task_id, + std::optional<ObjectIDIndexType> put_index); + const WorkerType GetWorkerType() const; const WorkerID &GetWorkerID() const; diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index 0931631052262..327a04c671a2f 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -1947,6 +1947,13 @@ std::vector<rpc::ObjectReference> CoreWorker::SubmitTask( } else { returned_refs = task_manager_->AddPendingTask( task_spec.CallerAddress(), task_spec, CurrentCallSite(), max_retries); + + // If it is a generator task, create an object ref stream. + // The language frontend is responsible for calling DeleteObjectRefStream. + if (task_spec.IsStreamingGenerator()) { + CreateObjectRefStream(task_spec.ReturnId(0)); + } + io_service_.post( [this, task_spec]() { RAY_UNUSED(direct_task_submitter_->SubmitTask(task_spec)); @@ -2272,6 +2279,13 @@ Status CoreWorker::SubmitActorTask(const ActorID &actor_id, } else { returned_refs = task_manager_->AddPendingTask( rpc_address_, task_spec, CurrentCallSite(), actor_handle->MaxTaskRetries()); + + // If it is a generator task, create an object ref stream. + // The language frontend is responsible for calling DeleteObjectRefStream. + if (task_spec.IsStreamingGenerator()) { + CreateObjectRefStream(task_spec.ReturnId(0)); + } + RAY_CHECK_OK(direct_actor_submitter_->SubmitTask(task_spec)); } task_returns = std::move(returned_refs); @@ -2856,10 +2870,10 @@ bool CoreWorker::PinExistingReturnObject(const ObjectID &return_id, } } -ObjectID CoreWorker::AllocateDynamicReturnId(const rpc::Address &owner_address) { - const auto &task_spec = worker_context_.GetCurrentTask(); - const auto return_id = - ObjectID::FromIndex(task_spec->TaskId(), worker_context_.GetNextPutIndex()); +ObjectID CoreWorker::AllocateDynamicReturnId(const rpc::Address &owner_address, + const TaskID &task_id, + std::optional<ObjectIDIndexType> put_index) { + const auto return_id = worker_context_.GetGeneratorReturnId(task_id, put_index); AddLocalReference(return_id, ""); reference_counter_->AddBorrowedObject(return_id, ObjectID::Nil(), owner_address); return return_id; diff --git a/src/ray/core_worker/core_worker.h b/src/ray/core_worker/core_worker.h index e4793944d4253..574ea8b69a95e 100644 --- a/src/ray/core_worker/core_worker.h +++ b/src/ray/core_worker/core_worker.h @@ -370,14 +370,16 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { void CreateObjectRefStream(const ObjectID &generator_id); /// Read the next index of a ObjectRefStream of generator_id. + /// This API always returns immediately. /// /// \param[in] generator_id The object ref id of the streaming /// generator task. /// \param[out] object_ref_out The ObjectReference /// that the caller can convert to its own ObjectRef. /// The current process is always the owner of the - /// generated ObjectReference. - /// \return Status RayKeyError if the stream reaches to EoF. + /// generated ObjectReference. It will be Nil() if there's + /// no next item. + /// \return Status ObjectRefEndOfStream if the stream reaches EoF. /// OK otherwise. 
Status TryReadObjectRefStream(const ObjectID &generator_id, rpc::ObjectReference *object_ref_out); @@ -1025,11 +1027,27 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { /// object to the task caller and have the resulting ObjectRef be owned by /// the caller. This is in contrast to static allocation, where the caller /// decides at task invocation time how many returns the task should have. + /// + /// NOTE: Normally, task_id and put_index do not need to be specified + /// because we can obtain them from the global worker context. However, + /// when the async actor uses this API, it cannot find the correct + /// worker context due to an implementation limitation. + /// In this case, the caller is responsible for providing the correct + /// task ID and index. + /// See https://github.com/ray-project/ray/issues/10324 for further details. + /// /// \param[in] owner_address The address of the owner who will own this /// dynamically generated object. - /// - /// \param[out] The ObjectID that the caller should use to store the object. + /// \param[in] task_id The task id of the dynamically generated return ID. + /// If Nil() is specified, it will deduce the Task ID from the current + /// worker context. + /// \param[in] put_index The equivalent of the return value of + /// WorkerContext::GetNextPutIndex. + /// If std::nullopt is specified, it will deduce the put index from the + /// current worker context. + ObjectID AllocateDynamicReturnId(const rpc::Address &owner_address, + const TaskID &task_id = TaskID::Nil(), + std::optional<ObjectIDIndexType> put_index = -1); /// Get a handle to an actor. /// diff --git a/src/ray/core_worker/task_manager.cc b/src/ray/core_worker/task_manager.cc index a451217358475..4577a294f09a6 100644 --- a/src/ray/core_worker/task_manager.cc +++ b/src/ray/core_worker/task_manager.cc @@ -54,7 +54,7 @@ Status ObjectRefStream::TryReadNextItem(ObjectID *object_id_out) { RAY_LOG(DEBUG) << "ObjectRefStream of an id " << generator_id_ << " has no more objects."; *object_id_out = ObjectID::Nil(); - return Status::ObjectRefStreamEoF(""); + return Status::ObjectRefEndOfStream(""); } auto it = item_index_to_refs_.find(next_index_); diff --git a/src/ray/core_worker/task_manager.h b/src/ray/core_worker/task_manager.h index e3abfb24d48e3..5db92ca75ff1e 100644 --- a/src/ray/core_worker/task_manager.h +++ b/src/ray/core_worker/task_manager.h @@ -243,8 +243,12 @@ class TaskManager : public TaskFinisherInterface, public TaskResubmissionInterfa /// Create the object ref stream. /// If the object ref stream is not created by this API, /// all object ref stream operation will be no-op. + /// /// Once the stream is created, it has to be deleted /// by DelObjectRefStream when it is not used anymore. + /// Once you generate a stream, it is the caller's responsibility + /// to call DelObjectRefStream. + /// /// The API is not idempotent. /// /// \param[in] generator_id The object ref id of the streaming @@ -257,8 +261,9 @@ class TaskManager : public TaskFinisherInterface, public TaskResubmissionInterfa /// generator task. bool ObjectRefStreamExists(const ObjectID &generator_id); - /// Asynchronously read object reference of the next index from the + /// Read object reference of the next index from the /// object stream of a generator_id. + /// This API always returns immediately. /// /// The caller should ensure the ObjectRefStream is already created /// via CreateObjectRefStream. 
@@ -267,7 +272,7 @@ class TaskManager : public TaskFinisherInterface, public TaskResubmissionInterfa /// /// \param[out] object_id_out The next object ID from the stream. /// Nil ID is returned if the next index hasn't been written. - /// \return KeyError if it reaches to EoF. Ok otherwise. + /// \return ObjectRefEndOfStream if it reaches to EoF. Ok otherwise. Status TryReadObjectRefStream(const ObjectID &generator_id, ObjectID *object_id_out); /// Returns true if task can be retried. diff --git a/src/ray/core_worker/test/core_worker_test.cc b/src/ray/core_worker/test/core_worker_test.cc index 62dd91f4474b0..31b6089d854c3 100644 --- a/src/ray/core_worker/test/core_worker_test.cc +++ b/src/ray/core_worker/test/core_worker_test.cc @@ -684,6 +684,39 @@ TEST_F(ZeroNodeTest, TestWorkerContext) { ASSERT_EQ(context.GetNextPutIndex(), num_returns + 2); } +TEST_F(ZeroNodeTest, TestWorkerContextGeneratorReturn) { + auto job_id = NextJobId(); + + WorkerContext context(WorkerType::WORKER, WorkerID::FromRandom(), job_id); + TaskSpecification task_spec; + size_t num_returns = 1; + task_spec.GetMutableMessage().set_job_id(job_id.Binary()); + task_spec.GetMutableMessage().set_num_returns(num_returns); + context.ResetCurrentTask(); + context.SetCurrentTask(task_spec); + ASSERT_EQ(context.GetCurrentTaskID(), task_spec.TaskId()); + ; + + // Verify when task ID is nil and put index is nullopt, + // it deduces the next return ID from the current context. + auto return_id = context.GetGeneratorReturnId(TaskID::Nil(), std::nullopt); + ASSERT_EQ(return_id.TaskId(), context.GetCurrentTaskID()); + ASSERT_EQ(return_id, ObjectID::FromIndex(context.GetCurrentTaskID(), 2)); + auto return_id2 = context.GetGeneratorReturnId(TaskID::Nil(), std::nullopt); + ASSERT_EQ(return_id2.TaskId(), context.GetCurrentTaskID()); + ASSERT_EQ(return_id2, ObjectID::FromIndex(context.GetCurrentTaskID(), 3)); + + // Verify manual specification of put index and taskId. + auto task_id = TaskID::FromRandom(job_id); + auto put_index = 1; + return_id = context.GetGeneratorReturnId(task_id, put_index); + ASSERT_EQ(return_id.TaskId(), task_id); + ASSERT_EQ(return_id, ObjectID::FromIndex(task_id, put_index)); + // Although we repeat, it should return the same value. + return_id = context.GetGeneratorReturnId(task_id, put_index); + ASSERT_EQ(return_id, ObjectID::FromIndex(task_id, put_index)); +} + TEST_F(ZeroNodeTest, TestActorHandle) { // Test actor handle serialization and deserialization round trip. 
JobID job_id = NextJobId(); diff --git a/src/ray/core_worker/test/task_manager_test.cc b/src/ray/core_worker/test/task_manager_test.cc index 69600d77feb9f..e01d8f8f8d31b 100644 --- a/src/ray/core_worker/test/task_manager_test.cc +++ b/src/ray/core_worker/test/task_manager_test.cc @@ -1268,7 +1268,7 @@ TEST_F(TaskManagerTest, TestObjectRefStreamBasic) { } // READ (EoF) auto status = manager_.TryReadObjectRefStream(generator_id, &obj_id); - ASSERT_TRUE(status.IsObjectRefStreamEoF()); + ASSERT_TRUE(status.IsObjectRefEndOfStream()); ASSERT_EQ(obj_id, ObjectID::Nil()); // DELETE manager_.DelObjectRefStream(generator_id); @@ -1315,13 +1315,13 @@ TEST_F(TaskManagerTest, TestObjectRefStreamMixture) { ObjectID obj_id; // READ (EoF) auto status = manager_.TryReadObjectRefStream(generator_id, &obj_id); - ASSERT_TRUE(status.IsObjectRefStreamEoF()); + ASSERT_TRUE(status.IsObjectRefEndOfStream()); ASSERT_EQ(obj_id, ObjectID::Nil()); // DELETE manager_.DelObjectRefStream(generator_id); } -TEST_F(TaskManagerTest, TestObjectRefStreamEoF) { +TEST_F(TaskManagerTest, TestObjectRefEndOfStream) { /** * Test that after writing EoF, write/read doesn't work. * CREATE WRITE WRITEEoF, WRITE(verify no op) DELETE @@ -1364,7 +1364,7 @@ TEST_F(TaskManagerTest, TestObjectRefStreamEoF) { ASSERT_TRUE(manager_.HandleReportGeneratorItemReturns(req)); // READ (doesn't works because EoF is already written) status = manager_.TryReadObjectRefStream(generator_id, &obj_id); - ASSERT_TRUE(status.IsObjectRefStreamEoF()); + ASSERT_TRUE(status.IsObjectRefEndOfStream()); } TEST_F(TaskManagerTest, TestObjectRefStreamIndexDiscarded) { @@ -1529,7 +1529,7 @@ TEST_F(TaskManagerTest, TestObjectRefStreamEndtoEnd) { // Nothing more to read. status = manager_.TryReadObjectRefStream(generator_id, &obj_id); - ASSERT_TRUE(status.IsObjectRefStreamEoF()); + ASSERT_TRUE(status.IsObjectRefEndOfStream()); manager_.DelObjectRefStream(generator_id); } @@ -1670,7 +1670,7 @@ TEST_F(TaskManagerTest, TestObjectRefStreamOutofOrder) { // READ (EoF) auto status = manager_.TryReadObjectRefStream(generator_id, &obj_id); - ASSERT_TRUE(status.IsObjectRefStreamEoF()); + ASSERT_TRUE(status.IsObjectRefEndOfStream()); ASSERT_EQ(obj_id, ObjectID::Nil()); // DELETE manager_.DelObjectRefStream(generator_id);
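
The following is a minimal, illustrative usage sketch of the async generator interface this patch enables; it is not part of the patch and mirrors the patterns exercised in test_streaming_generator.py. The Counter actor and its stream method are hypothetical names used only for the example.

    import asyncio
    import ray


    @ray.remote
    class Counter:
        # Hypothetical async actor; any async generator method works the same way.
        async def stream(self, n):
            for i in range(n):
                await asyncio.sleep(0.1)
                yield i


    async def consume():
        actor = Counter.remote()
        # num_returns="streaming" returns a StreamingObjectRefGenerator, which
        # after this patch also implements __aiter__/__anext__.
        gen = actor.stream.options(num_returns="streaming").remote(3)
        async for ref in gen:
            print(await ref)  # prints 0, 1, 2 as the actor yields them


    if __name__ == "__main__":
        ray.init()
        asyncio.run(consume())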