diff --git a/python/ray/tests/BUILD b/python/ray/tests/BUILD
index 9ff01599c6494..300d19f0fb963 100644
--- a/python/ray/tests/BUILD
+++ b/python/ray/tests/BUILD
@@ -455,6 +455,7 @@ py_test_module_list(
         "test_basic_4.py",
         "test_basic_5.py",
         "test_asyncio.py",
+        "test_object_assign_owner.py",
         "test_multiprocessing.py",
         "test_list_actors.py",
         "test_list_actors_2.py",
diff --git a/python/ray/util/client/__init__.py b/python/ray/util/client/__init__.py
index 17094dae04f14..123d334f4444f 100644
--- a/python/ray/util/client/__init__.py
+++ b/python/ray/util/client/__init__.py
@@ -19,7 +19,7 @@
 
 # This version string is incremented to indicate breaking changes in the
 # protocol that require upgrading the client version.
-CURRENT_PROTOCOL_VERSION = "2022-07-24"
+CURRENT_PROTOCOL_VERSION = "2022-10-05"
 
 
 class _ClientContext:
diff --git a/python/ray/util/client/dataclient.py b/python/ray/util/client/dataclient.py
index 8158e3b554b8f..33b80948cdd2b 100644
--- a/python/ray/util/client/dataclient.py
+++ b/python/ray/util/client/dataclient.py
@@ -65,6 +65,7 @@ def chunk_put(req: ray_client_pb2.DataRequest):
             chunk_id=chunk_id,
             total_chunks=total_chunks,
             total_size=total_size,
+            owner_id=req.put.owner_id,
         )
         yield ray_client_pb2.DataRequest(req_id=req.req_id, put=chunk)
 
diff --git a/python/ray/util/client/server/dataservicer.py b/python/ray/util/client/server/dataservicer.py
index c459b1e80628f..ca09af07fe3f5 100644
--- a/python/ray/util/client/server/dataservicer.py
+++ b/python/ray/util/client/server/dataservicer.py
@@ -231,6 +231,7 @@ def Datapath(self, request_iterator, context):
                         self.put_request_chunk_collector.data,
                         req.put.client_ref_id,
                         client_id,
+                        req.put.owner_id,
                     )
                     self.put_request_chunk_collector.reset()
                     resp = ray_client_pb2.DataResponse(put=put_resp)
diff --git a/python/ray/util/client/server/server.py b/python/ray/util/client/server/server.py
index 4d21c27f12a53..ab9e5f937ef27 100644
--- a/python/ray/util/client/server/server.py
+++ b/python/ray/util/client/server/server.py
@@ -507,13 +507,16 @@ def PutObject(
         self, request: ray_client_pb2.PutRequest, context=None
     ) -> ray_client_pb2.PutResponse:
         """gRPC entrypoint for unary PutObject"""
-        return self._put_object(request.data, request.client_ref_id, "", context)
+        return self._put_object(
+            request.data, request.client_ref_id, "", request.owner_id, context
+        )
 
     def _put_object(
         self,
         data: Union[bytes, bytearray],
         client_ref_id: bytes,
         client_id: str,
+        owner_id: bytes,
         context=None,
     ):
         """Put an object in the cluster with ray.put() via gRPC.
@@ -524,12 +527,18 @@ def _put_object(
             client_ref_id: The id associated with this object on the client.
             client_id: The client who owns this data, for tracking when to
                 delete this reference.
+            owner_id: The owner id of the object.
             context: gRPC context.
         """
         try:
             obj = loads_from_client(data, self)
+
+            if owner_id:
+                owner = self.actor_refs[owner_id]
+            else:
+                owner = None
             with disable_client_hook():
-                objectref = ray.put(obj)
+                objectref = ray.put(obj, _owner=owner)
         except Exception as e:
             logger.exception("Put failed:")
             return ray_client_pb2.PutResponse(
diff --git a/python/ray/util/client/worker.py b/python/ray/util/client/worker.py
index 084bc17838f6d..91acde7467759 100644
--- a/python/ray/util/client/worker.py
+++ b/python/ray/util/client/worker.py
@@ -477,7 +477,13 @@ def _get(self, ref: List[ClientObjectRef], timeout: float):
                 raise decode_exception(e)
         return loads_from_server(data)
 
-    def put(self, val, *, client_ref_id: bytes = None):
+    def put(
+        self,
+        val,
+        *,
+        client_ref_id: bytes = None,
+        _owner: Optional[ClientActorHandle] = None,
+    ):
         if isinstance(val, ClientObjectRef):
             raise TypeError(
                 "Calling 'put' on an ObjectRef is not allowed "
@@ -487,12 +493,17 @@ def put(self, val, *, client_ref_id: bytes = None):
                 "call 'put' on it (or return it)."
             )
         data = dumps_from_client(val, self._client_id)
-        return self._put_pickled(data, client_ref_id)
+        return self._put_pickled(data, client_ref_id, _owner)
 
-    def _put_pickled(self, data, client_ref_id: bytes):
+    def _put_pickled(
+        self, data, client_ref_id: bytes, owner: Optional[ClientActorHandle] = None
+    ):
         req = ray_client_pb2.PutRequest(data=data)
         if client_ref_id is not None:
             req.client_ref_id = client_ref_id
+        if owner is not None:
+            req.owner_id = owner.actor_ref.id
+
         resp = self.data_client.PutObject(req)
         if not resp.valid:
             try:
diff --git a/python/ray/workflow/tests/conftest.py b/python/ray/workflow/tests/conftest.py
index c18e17ea24a2c..f14583bbbcace 100644
--- a/python/ray/workflow/tests/conftest.py
+++ b/python/ray/workflow/tests/conftest.py
@@ -39,13 +39,13 @@ def _workflow_start(storage_url, shared, use_ray_client, **kwargs):
     assert use_ray_client in {"no_ray_client", "ray_client"}
     with _init_cluster(storage_url, **kwargs) as cluster:
         if use_ray_client == "ray_client":
-            address_info = ray.init(
-                address=f"ray://{cluster.address.split(':')[0]}:10001"
-            )
+            address = f"ray://{cluster.address.split(':')[0]}:10001"
         else:
-            address_info = ray.init(address=cluster.address)
+            address = cluster.address
 
-        yield address_info
+        ray.init(address=address)
+
+        yield address
 
 
 @pytest.fixture(scope="function")
diff --git a/python/ray/workflow/tests/test_basic_workflows_4.py b/python/ray/workflow/tests/test_basic_workflows_4.py
index 622420607a1f2..5f8feaa1534e5 100644
--- a/python/ray/workflow/tests/test_basic_workflows_4.py
+++ b/python/ray/workflow/tests/test_basic_workflows_4.py
@@ -1,6 +1,7 @@
 """Basic tests isolated from other tests for shared fixtures."""
 import os
 import pytest
+from ray._private.test_utils import run_string_as_driver
 
 import ray
 from ray import workflow
@@ -68,6 +69,35 @@ def test_no_init_api(shutdown_only):
         workflow.list_all()
 
 
+def test_object_valid(workflow_start_regular):
+    # Test the async API and make sure the object lives
+    # across the lifetime of the job.
+    import uuid
+
+    workflow_id = str(uuid.uuid4())
+    script = f"""
+import ray
+from ray import workflow
+from typing import List
+
+ray.init(address="{workflow_start_regular}")
+
+@ray.remote
+def echo(data, sleep_s=0, others=None):
+    from time import sleep
+    sleep(sleep_s)
+    print(data)
+
+a = {{"abc": "def"}}
+e1 = echo.bind(a, 5)
+e2 = echo.bind(a, 0, e1)
+workflow.run_async(e2, workflow_id="{workflow_id}")
+"""
+    run_string_as_driver(script)
+
+    print(ray.get(workflow.get_output_async(workflow_id=workflow_id)))
+
+
 if __name__ == "__main__":
     import sys
 
diff --git a/python/ray/workflow/workflow_state_from_dag.py b/python/ray/workflow/workflow_state_from_dag.py
index 9eed250780055..9fd44a9448e35 100644
--- a/python/ray/workflow/workflow_state_from_dag.py
+++ b/python/ray/workflow/workflow_state_from_dag.py
@@ -4,7 +4,6 @@
 
 import ray
 from ray.workflow.common import WORKFLOW_OPTIONS
-
 from ray.dag import DAGNode, FunctionNode, InputNode
 from ray.dag.input_node import InputAttributeNode, DAGInputData
 from ray import cloudpickle
@@ -168,7 +167,9 @@ def _node_visitor(node: Any) -> Any:
                 flattened_args = _SerializationContextPreservingWrapper(
                     flattened_args
                 )
-            input_placeholder: ray.ObjectRef = ray.put(flattened_args)
+            # Set the owner of the objects to the actor so that even if the driver
+            # exits, these objects are still available.
+            input_placeholder: ray.ObjectRef = ray.put(flattened_args, _owner=mgr)
 
             orig_task_id = workflow_options.get("task_id", None)
             if orig_task_id is None:
diff --git a/src/ray/protobuf/ray_client.proto b/src/ray/protobuf/ray_client.proto
index 6e3c01c331c7e..bcbe5bb64d017 100644
--- a/src/ray/protobuf/ray_client.proto
+++ b/src/ray/protobuf/ray_client.proto
@@ -114,6 +114,8 @@ message PutRequest {
   int32 total_chunks = 4;
   // Total size in bytes of the data being put
   int64 total_size = 5;
+  // The ID of the owner of the object being put.
+  bytes owner_id = 6;
 }
 
 message PutResponse {
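
Note (not part of the patch): a minimal usage sketch of the owner-assignment path this change enables when connected through Ray Client. The detached Owner actor, the address, and the values below are illustrative assumptions; only ray.put(obj, _owner=...) and the client-side plumbing come from the diff above.

# Illustrative sketch only -- actor name, address, and data are assumptions.
import ray

# Connect through Ray Client. With this change, the client worker attaches
# owner.actor_ref.id to PutRequest.owner_id, and the server resolves it from
# its actor_refs table before calling ray.put(obj, _owner=owner) in-cluster.
ray.init(address="ray://127.0.0.1:10001")


@ray.remote
class Owner:
    def ready(self):
        return True


# A detached actor outlives the driver, so objects it owns remain available
# after this script exits.
owner = Owner.options(name="owner", lifetime="detached").remote()
ray.get(owner.ready.remote())

# Assign ownership of the put object to the detached actor.
ref = ray.put({"abc": "def"}, _owner=owner)
print(ray.get(ref))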