Skip to content

Commit

Permalink
[Core/Logging] Worker startup hook (#34738)
Browse files Browse the repository at this point in the history
This PR supports the basic worker setup hook API to runtime env according to the design; https://docs.google.com/document/d/1ngiuAZAMnl9c4LozoTpWh37KPviDRIpmEjEI6BsNL7w/edit

This PR also exposes exit API to the Python so that we can easily fail the worker with a exception we want

It is the first PR to support this feature. The PR allows users to add a setup method using runtime env. There will be 2 more PRs that will be coming a follow up

Merge the runtime env when the job + driver specifies the runtime env.
Allow to specify setup hook for individual task and actor
  • Loading branch information
rkooo567 authored May 15, 2023
1 parent 5d0b15e commit 2dbe747
Show file tree
Hide file tree
Showing 16 changed files with 470 additions and 46 deletions.
92 changes: 77 additions & 15 deletions python/ray/_private/function_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import time
import traceback
from collections import defaultdict, namedtuple
from typing import Optional
from typing import Optional, Callable

import ray
import ray._private.profiling as profiling
Expand All @@ -27,11 +27,16 @@
format_error_message,
)
from ray._private.serialization import pickle_dumps
from ray._raylet import JobID, PythonFunctionDescriptor
from ray._raylet import JobID, PythonFunctionDescriptor, WORKER_SETUP_HOOK_KEY_NAME_GCS

FunctionExecutionInfo = namedtuple(
"FunctionExecutionInfo", ["function", "function_name", "max_calls"]
)
ImportedFunctionInfo = namedtuple(
"ImportedFunctionInfo",
["job_id", "function_id", "function_name", "function", "module", "max_calls"],
)

"""FunctionExecutionInfo: A named tuple storing remote function information."""

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -175,6 +180,53 @@ def export_key(self, key):
# TODO(mwtian) implement per-job notification here.
self._worker.gcs_publisher.publish_function_key(key)

def export_setup_func(
self, setup_func: Callable, timeout: Optional[int] = None
) -> bytes:
"""Export the setup hook function and return the key."""
pickled_function = pickle_dumps(
setup_func, f"Cannot serialize the worker_setup_hook {setup_func.__name__}"
)

function_to_run_id = hashlib.shake_128(pickled_function).digest(
ray_constants.ID_SIZE
)
key = make_function_table_key(
# This value should match with gcs_function_manager.h.
# Otherwise, it won't be GC'ed.
WORKER_SETUP_HOOK_KEY_NAME_GCS.encode(),
# b"FunctionsToRun",
self._worker.current_job_id.binary(),
function_to_run_id,
)

check_oversized_function(
pickled_function, setup_func.__name__, "function", self._worker
)

try:
self._worker.gcs_client.internal_kv_put(
key,
pickle.dumps(
{
"job_id": self._worker.current_job_id.binary(),
"function_id": function_to_run_id,
"function": pickled_function,
}
),
# overwrite
True,
ray_constants.KV_NAMESPACE_FUNCTION_TABLE,
timeout=timeout,
)
except Exception as e:
logger.exception(
"Failed to export the setup hook " f"{setup_func.__name__}."
)
raise e

return key

def export(self, remote_function):
"""Pickle a remote function and export it to redis.
Args:
Expand Down Expand Up @@ -224,29 +276,39 @@ def export(self, remote_function):
key, val, True, KV_NAMESPACE_FUNCTION_TABLE
)

def fetch_and_register_remote_function(self, key):
"""Import a remote function."""
vals = self._worker.gcs_client.internal_kv_get(key, KV_NAMESPACE_FUNCTION_TABLE)
def fetch_registered_method(
self, key: str, timeout: Optional[int] = None
) -> Optional[ImportedFunctionInfo]:
vals = self._worker.gcs_client.internal_kv_get(
key, KV_NAMESPACE_FUNCTION_TABLE, timeout=timeout
)
if vals is None:
return False
return None
else:
vals = pickle.loads(vals)
fields = [
"job_id",
"function_id",
"function_name",
"function",
"module",
"max_calls",
]
fields = [
"job_id",
"function_id",
"function_name",
"function",
"module",
"max_calls",
]
return ImportedFunctionInfo._make(vals.get(field) for field in fields)

def fetch_and_register_remote_function(self, key):
"""Import a remote function."""
remote_function_info = self.fetch_registered_method(key)
if not remote_function_info:
return False
(
job_id_str,
function_id_str,
function_name,
serialized_function,
module,
max_calls,
) = (vals.get(field) for field in fields)
) = remote_function_info

function_id = ray.FunctionID(function_id_str)
job_id = ray.JobID(job_id_str)
Expand Down
3 changes: 3 additions & 0 deletions python/ray/_private/ray_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,3 +430,6 @@ def gcs_actor_scheduling_enabled():
}

RAY_ENABLE_RECORD_TASK_LOGGING = env_bool("RAY_ENABLE_RECORD_TASK_LOGGING", False)

WORKER_SETUP_HOOK_ENV_VAR = "__RAY_WORKER_SETUP_HOOK_ENV_VAR"
RAY_WORKER_SETUP_HOOK_LOAD_TIMEOUT_ENV_VAR = "RAY_WORKER_SETUP_HOOK_LOAD_TIMEOUT"
2 changes: 2 additions & 0 deletions python/ray/_private/runtime_env/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ class RuntimeEnvPlugin(ABC):
def validate(runtime_env_dict: dict) -> None:
"""Validate user entry for this plugin.
The method is invoked upon installation of runtime env.
Args:
runtime_env_dict: the user-supplied runtime environment dict.
Expand Down
131 changes: 131 additions & 0 deletions python/ray/_private/runtime_env/setup_hook.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
import traceback
import logging
import base64
import os

from typing import Dict, Any, Callable, Union, Optional

import ray
import ray._private.ray_constants as ray_constants
import ray.cloudpickle as pickle
from ray.runtime_env import RuntimeEnv

logger = logging.getLogger(__name__)


def get_import_export_timeout():
return int(
os.environ.get(ray_constants.RAY_WORKER_SETUP_HOOK_LOAD_TIMEOUT_ENV_VAR, "60")
)


def _decode_function_key(key: bytes) -> str:
return base64.b64encode(key).decode()


def _encode_function_key(key: str) -> bytes:
return base64.b64decode(key)


def upload_worker_setup_hook_if_needed(
runtime_env: Union[Dict[str, Any], RuntimeEnv],
worker: "ray.Worker",
) -> Union[Dict[str, Any], RuntimeEnv]:
"""Uploads the worker_setup_hook to GCS with a key.
runtime_env["worker_setup_hook"] is converted to a decoded key
that can load the worker setup hook function from GCS.
I.e., you can use internalKV.Get(runtime_env["worker_setup_hook])
to access the worker setup hook from GCS.
Args:
runtime_env: The runtime_env. The value will be modified
when returned.
worker: ray.worker instance.
decoder: GCS requires the function key to be bytes. However,
we cannot json serialize (which is required to serialize
runtime env) the bytes. So the key should be decoded to
a string. The given decoder is used to decode the function
key.
"""
setup_func = runtime_env.get("worker_setup_hook")
if setup_func is None:
return runtime_env

if not isinstance(setup_func, Callable):
raise TypeError(
"worker_setup_hook must be a function, " f"got {type(setup_func)}."
)
# TODO(sang): Support modules.

try:
key = worker.function_actor_manager.export_setup_func(
setup_func, timeout=get_import_export_timeout()
)
except Exception as e:
raise ray.exceptions.RuntimeEnvSetupError(
"Failed to export the setup function."
) from e
env_vars = runtime_env.get("env_vars", {})
assert ray_constants.WORKER_SETUP_HOOK_ENV_VAR not in env_vars, (
f"The env var, {ray_constants.WORKER_SETUP_HOOK_ENV_VAR}, "
"is not permitted because it is reserved for the internal use."
)
env_vars[ray_constants.WORKER_SETUP_HOOK_ENV_VAR] = _decode_function_key(key)
runtime_env["env_vars"] = env_vars
# Note: This field is no-op. We don't have a plugin for the setup hook
# because we can implement it simply using an env var.
# This field is just for the observability purpose, so we store
# the name of the method.
runtime_env["worker_setup_hook"] = setup_func.__name__
return runtime_env


def load_and_execute_setup_hook(
worker_setup_hook_key: str,
) -> Optional[str]:
"""Load the setup hook from a given key and execute.
Args:
worker_setup_hook_key: The key to import the setup hook
from GCS.
Returns:
An error message if it fails. None if it succeeds.
"""
assert worker_setup_hook_key is not None
worker = ray._private.worker.global_worker
assert worker.connected

func_manager = worker.function_actor_manager
try:
worker_setup_func_info = func_manager.fetch_registered_method(
_encode_function_key(worker_setup_hook_key),
timeout=get_import_export_timeout(),
)
except Exception:
error_message = (
"Failed to import setup hook within "
f"{get_import_export_timeout()} seconds.\n"
f"{traceback.format_exc()}"
)
return error_message

try:
setup_func = pickle.loads(worker_setup_func_info.function)
except Exception:
error_message = (
"Failed to deserialize the setup hook method.\n" f"{traceback.format_exc()}"
)
return error_message

try:
setup_func()
except Exception:
error_message = (
f"Failed to execute the setup hook method. Function name:"
f"{worker_setup_func_info.function_name}\n"
f"{traceback.format_exc()}"
)
return error_message

return None
5 changes: 5 additions & 0 deletions python/ray/_private/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@
from ray._private.runtime_env.constants import RAY_JOB_CONFIG_JSON_ENV_VAR
from ray._private.runtime_env.py_modules import upload_py_modules_if_needed
from ray._private.runtime_env.working_dir import upload_working_dir_if_needed
from ray._private.runtime_env.setup_hook import upload_worker_setup_hook_if_needed
from ray._private.storage import _load_class
from ray._private.utils import check_oversized_function, get_ray_doc_version
from ray.exceptions import ObjectStoreFullError, RayError, RaySystemError, RayTaskError
Expand Down Expand Up @@ -2173,6 +2174,10 @@ def connect(
runtime_env = upload_working_dir_if_needed(
runtime_env, scratch_dir, logger=logger
)
runtime_env = upload_worker_setup_hook_if_needed(
runtime_env,
worker,
)
# Remove excludes, it isn't relevant after the upload step.
runtime_env.pop("excludes", None)
job_config.set_runtime_env(runtime_env)
Expand Down
17 changes: 14 additions & 3 deletions python/ray/_private/workers/default_worker.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import argparse
import base64
import json
Expand All @@ -10,6 +11,7 @@
import ray.actor
from ray._private.parameter import RayParams
from ray._private.ray_logging import configure_log_file, get_worker_log_file_name
from ray._private.runtime_env.setup_hook import load_and_execute_setup_hook


parser = argparse.ArgumentParser(
Expand Down Expand Up @@ -236,20 +238,29 @@
worker_launched_time_ms=worker_launched_time_ms,
)

worker = ray._private.worker.global_worker

# Setup log file.
out_file, err_file = node.get_log_file_handles(
get_worker_log_file_name(args.worker_type)
)
configure_log_file(out_file, err_file)
ray._private.worker.global_worker.set_out_file(out_file)
ray._private.worker.global_worker.set_err_file(err_file)
worker.set_out_file(out_file)
worker.set_err_file(err_file)

if mode == ray.WORKER_MODE and args.worker_preload_modules:
module_names_to_import = args.worker_preload_modules.split(",")
ray._private.utils.try_import_each_module(module_names_to_import)

# If the worker setup function is configured, run it.
worker_setup_hook_key = os.getenv(ray_constants.WORKER_SETUP_HOOK_ENV_VAR)
if worker_setup_hook_key:
error = load_and_execute_setup_hook(worker_setup_hook_key)
if error is not None:
worker.core_worker.exit_worker("system", error)

if mode == ray.WORKER_MODE:
ray._private.worker.global_worker.main_loop()
worker.main_loop()
elif mode in [ray.RESTORE_WORKER_MODE, ray.SPILL_WORKER_MODE]:
# It is handled by another thread in the C++ core worker.
# We just need to keep the worker alive.
Expand Down
Loading

0 comments on commit 2dbe747

Please sign in to comment.