[workflow] Fix the object loss due to driver exit issues. (#29092)
When a workflow runs in driver mode, the driver owns the object refs, so once the driver exits the objects are no longer available. This happens when the workflow is started with `run_async`.

This PR fixes the issue by making the workflow manager actor the owner of the objects.
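For context (not part of the diff), the mechanism relied on here is Ray's `_owner` argument to `ray.put`, which transfers ownership of an object to a long-lived actor so it survives the exit of the driver that created it. A minimal sketch, with illustrative names (`Manager`, `object_owner` are not from this commit):

```python
import ray

ray.init()

# A detached, named actor that outlives the driver; the class and name are illustrative.
@ray.remote
class Manager:
    def ready(self):
        return True

mgr = Manager.options(name="object_owner", lifetime="detached").remote()

# Without _owner the driver owns this object, and it becomes unavailable once
# the driver exits. With _owner, ownership is transferred to the manager actor.
ref = ray.put({"abc": "def"}, _owner=mgr)
```

This is the same mechanism the workflow code below uses when it calls `ray.put(flattened_args, _owner=mgr)`.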
  • Loading branch information
fishbone authored and maxpumperla committed Oct 7, 2022
1 parent 39fe2aa commit 5449b23
Showing 10 changed files with 69 additions and 13 deletions.
1 change: 1 addition & 0 deletions python/ray/tests/BUILD
@@ -455,6 +455,7 @@ py_test_module_list(
"test_basic_4.py",
"test_basic_5.py",
"test_asyncio.py",
"test_object_assign_owner.py",
"test_multiprocessing.py",
"test_list_actors.py",
"test_list_actors_2.py",
2 changes: 1 addition & 1 deletion python/ray/util/client/__init__.py
@@ -19,7 +19,7 @@

# This version string is incremented to indicate breaking changes in the
# protocol that require upgrading the client version.
CURRENT_PROTOCOL_VERSION = "2022-07-24"
CURRENT_PROTOCOL_VERSION = "2022-10-05"


class _ClientContext:
1 change: 1 addition & 0 deletions python/ray/util/client/dataclient.py
@@ -65,6 +65,7 @@ def chunk_put(req: ray_client_pb2.DataRequest):
chunk_id=chunk_id,
total_chunks=total_chunks,
total_size=total_size,
owner_id=req.put.owner_id,
)
yield ray_client_pb2.DataRequest(req_id=req.req_id, put=chunk)

1 change: 1 addition & 0 deletions python/ray/util/client/server/dataservicer.py
@@ -231,6 +231,7 @@ def Datapath(self, request_iterator, context):
self.put_request_chunk_collector.data,
req.put.client_ref_id,
client_id,
req.put.owner_id,
)
self.put_request_chunk_collector.reset()
resp = ray_client_pb2.DataResponse(put=put_resp)
13 changes: 11 additions & 2 deletions python/ray/util/client/server/server.py
@@ -507,13 +507,16 @@ def PutObject(
self, request: ray_client_pb2.PutRequest, context=None
) -> ray_client_pb2.PutResponse:
"""gRPC entrypoint for unary PutObject"""
return self._put_object(request.data, request.client_ref_id, "", context)
return self._put_object(
request.data, request.client_ref_id, "", request.owner_id, context
)

def _put_object(
self,
data: Union[bytes, bytearray],
client_ref_id: bytes,
client_id: str,
owner_id: bytes,
context=None,
):
"""Put an object in the cluster with ray.put() via gRPC.
@@ -524,12 +527,18 @@ def _put_object(
client_ref_id: The id associated with this object on the client.
client_id: The client who owns this data, for tracking when to
delete this reference.
owner_id: The owner id of the object.
context: gRPC context.
"""
try:
obj = loads_from_client(data, self)

if owner_id:
owner = self.actor_refs[owner_id]
else:
owner = None
with disable_client_hook():
objectref = ray.put(obj)
objectref = ray.put(obj, _owner=owner)
except Exception as e:
logger.exception("Put failed:")
return ray_client_pb2.PutResponse(
17 changes: 14 additions & 3 deletions python/ray/util/client/worker.py
@@ -477,7 +477,13 @@ def _get(self, ref: List[ClientObjectRef], timeout: float):
raise decode_exception(e)
return loads_from_server(data)

def put(self, val, *, client_ref_id: bytes = None):
def put(
self,
val,
*,
client_ref_id: bytes = None,
_owner: Optional[ClientActorHandle] = None,
):
if isinstance(val, ClientObjectRef):
raise TypeError(
"Calling 'put' on an ObjectRef is not allowed "
@@ -487,12 +493,17 @@ def put(self, val, *, client_ref_id: bytes = None):
"call 'put' on it (or return it)."
)
data = dumps_from_client(val, self._client_id)
return self._put_pickled(data, client_ref_id)
return self._put_pickled(data, client_ref_id, _owner)

def _put_pickled(self, data, client_ref_id: bytes):
def _put_pickled(
self, data, client_ref_id: bytes, owner: Optional[ClientActorHandle] = None
):
req = ray_client_pb2.PutRequest(data=data)
if client_ref_id is not None:
req.client_ref_id = client_ref_id
if owner is not None:
req.owner_id = owner.actor_ref.id

resp = self.data_client.PutObject(req)
if not resp.valid:
try:
10 changes: 5 additions & 5 deletions python/ray/workflow/tests/conftest.py
@@ -39,13 +39,13 @@ def _workflow_start(storage_url, shared, use_ray_client, **kwargs):
assert use_ray_client in {"no_ray_client", "ray_client"}
with _init_cluster(storage_url, **kwargs) as cluster:
if use_ray_client == "ray_client":
address_info = ray.init(
address=f"ray://{cluster.address.split(':')[0]}:10001"
)
address = f"ray://{cluster.address.split(':')[0]}:10001"
else:
address_info = ray.init(address=cluster.address)
address = cluster.address

yield address_info
ray.init(address=address)

yield address


@pytest.fixture(scope="function")
30 changes: 30 additions & 0 deletions python/ray/workflow/tests/test_basic_workflows_4.py
@@ -1,6 +1,7 @@
"""Basic tests isolated from other tests for shared fixtures."""
import os
import pytest
from ray._private.test_utils import run_string_as_driver

import ray
from ray import workflow
@@ -68,6 +69,35 @@ def test_no_init_api(shutdown_only):
workflow.list_all()


def test_object_valid(workflow_start_regular):
# Test the async API and make sure the objects stay alive across the
# lifetime of the job, even after the submitting driver exits.
import uuid

workflow_id = str(uuid.uuid4())
script = f"""
import ray
from ray import workflow
from typing import List
ray.init(address="{workflow_start_regular}")
@ray.remote
def echo(data, sleep_s=0, others=None):
from time import sleep
sleep(sleep_s)
print(data)
a = {{"abc": "def"}}
e1 = echo.bind(a, 5)
e2 = echo.bind(a, 0, e1)
workflow.run_async(e2, workflow_id="{workflow_id}")
"""
run_string_as_driver(script)

print(ray.get(workflow.get_output_async(workflow_id=workflow_id)))


if __name__ == "__main__":
import sys

5 changes: 3 additions & 2 deletions python/ray/workflow/workflow_state_from_dag.py
@@ -4,7 +4,6 @@

import ray
from ray.workflow.common import WORKFLOW_OPTIONS

from ray.dag import DAGNode, FunctionNode, InputNode
from ray.dag.input_node import InputAttributeNode, DAGInputData
from ray import cloudpickle
@@ -168,7 +167,9 @@ def _node_visitor(node: Any) -> Any:
flattened_args = _SerializationContextPreservingWrapper(
flattened_args
)
input_placeholder: ray.ObjectRef = ray.put(flattened_args)
# Set the owner of the objects to the manager actor so that the objects
# remain available even after the driver exits.
input_placeholder: ray.ObjectRef = ray.put(flattened_args, _owner=mgr)

orig_task_id = workflow_options.get("task_id", None)
if orig_task_id is None:
2 changes: 2 additions & 0 deletions src/ray/protobuf/ray_client.proto
@@ -114,6 +114,8 @@ message PutRequest {
int32 total_chunks = 4;
// Total size in bytes of the data being put
int64 total_size = 5;
// The actor ID of the owner of the object being put
bytes owner_id = 6;
}

message PutResponse {
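
A rough sketch of how the new `owner_id` field is expected to flow through the Ray Client put path; this is not a verified snippet, and the module path and literal values are assumptions based on the diff above:

```python
# Client side (sketch): the worker attaches the owner actor's binary ID to the request.
from ray.core.generated import ray_client_pb2

req = ray_client_pb2.PutRequest(data=b"<pickled payload>")
req.owner_id = b"<actor id bytes>"  # taken from owner.actor_ref.id in worker.py

# Server side (sketch): the proxy resolves the ID to a local actor handle and
# forwards it to ray.put, making that actor the owner of the new object:
#     owner = self.actor_refs[owner_id] if owner_id else None
#     objectref = ray.put(obj, _owner=owner)
```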