diff --git a/java/api/src/main/java/org/ray/api/options/ActorCreationOptions.java b/java/api/src/main/java/org/ray/api/options/ActorCreationOptions.java index d1e92f7bb9e9..2e14ca8584dd 100644 --- a/java/api/src/main/java/org/ray/api/options/ActorCreationOptions.java +++ b/java/api/src/main/java/org/ray/api/options/ActorCreationOptions.java @@ -13,9 +13,14 @@ public class ActorCreationOptions extends BaseTaskOptions { public final int maxReconstructions; - private ActorCreationOptions(Map resources, int maxReconstructions) { + public final String jvmOptions; + + private ActorCreationOptions(Map resources, + int maxReconstructions, + String jvmOptions) { super(resources); this.maxReconstructions = maxReconstructions; + this.jvmOptions = jvmOptions; } /** @@ -25,6 +30,7 @@ public static class Builder { private Map resources = new HashMap<>(); private int maxReconstructions = NO_RECONSTRUCTION; + private String jvmOptions = ""; public Builder setResources(Map resources) { this.resources = resources; @@ -36,8 +42,13 @@ public Builder setMaxReconstructions(int maxReconstructions) { return this; } + public Builder setJvmOptions(String jvmOptions) { + this.jvmOptions = jvmOptions; + return this; + } + public ActorCreationOptions createActorCreationOptions() { - return new ActorCreationOptions(resources, maxReconstructions); + return new ActorCreationOptions(resources, maxReconstructions, jvmOptions); } } diff --git a/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java b/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java index fbd03bf10483..26a8d6e541ba 100644 --- a/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java +++ b/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java @@ -35,6 +35,7 @@ import org.ray.runtime.task.TaskLanguage; import org.ray.runtime.task.TaskSpec; import org.ray.runtime.util.IdUtil; +import org.ray.runtime.util.StringUtil; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -363,8 +364,13 @@ private TaskSpec createTaskSpec(RayFunc func, PyFunctionDescriptor pyFunctionDes } int maxActorReconstruction = 0; + List dynamicWorkerOptions = ImmutableList.of(); if (taskOptions instanceof ActorCreationOptions) { maxActorReconstruction = ((ActorCreationOptions) taskOptions).maxReconstructions; + String jvmOptions = ((ActorCreationOptions) taskOptions).jvmOptions; + if (!StringUtil.isNullOrEmpty(jvmOptions)) { + dynamicWorkerOptions = ImmutableList.of(((ActorCreationOptions) taskOptions).jvmOptions); + } } TaskLanguage language; @@ -393,7 +399,8 @@ private TaskSpec createTaskSpec(RayFunc func, PyFunctionDescriptor pyFunctionDes numReturns, resources, language, - functionDescriptor + functionDescriptor, + dynamicWorkerOptions ); } diff --git a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java index 01b9e4675016..c369e6f2cab8 100644 --- a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java +++ b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java @@ -190,9 +190,16 @@ private static TaskSpec parseTaskSpecFromFlatbuffer(ByteBuffer bb) { JavaFunctionDescriptor functionDescriptor = new JavaFunctionDescriptor( info.functionDescriptor(0), info.functionDescriptor(1), info.functionDescriptor(2) ); + + // Deserialize dynamic worker options. + List dynamicWorkerOptions = new ArrayList<>(); + for (int i = 0; i < info.dynamicWorkerOptionsLength(); ++i) { + dynamicWorkerOptions.add(info.dynamicWorkerOptions(i)); + } + return new TaskSpec(driverId, taskId, parentTaskId, parentCounter, actorCreationId, maxActorReconstructions, actorId, actorHandleId, actorCounter, newActorHandles, - args, numReturns, resources, TaskLanguage.JAVA, functionDescriptor); + args, numReturns, resources, TaskLanguage.JAVA, functionDescriptor, dynamicWorkerOptions); } private static ByteBuffer convertTaskSpecToFlatbuffer(TaskSpec task) { @@ -275,6 +282,12 @@ private static ByteBuffer convertTaskSpecToFlatbuffer(TaskSpec task) { functionDescriptorOffset = fbb.createVectorOfTables(functionDescriptorOffsets); } + int [] dynamicWorkerOptionsOffsets = new int[task.dynamicWorkerOptions.size()]; + for (int index = 0; index < task.dynamicWorkerOptions.size(); ++index) { + dynamicWorkerOptionsOffsets[index] = fbb.createString(task.dynamicWorkerOptions.get(index)); + } + int dynamicWorkerOptionsOffset = fbb.createVectorOfTables(dynamicWorkerOptionsOffsets); + int root = TaskInfo.createTaskInfo( fbb, driverIdOffset, @@ -293,7 +306,8 @@ private static ByteBuffer convertTaskSpecToFlatbuffer(TaskSpec task) { requiredResourcesOffset, requiredPlacementResourcesOffset, language, - functionDescriptorOffset); + functionDescriptorOffset, + dynamicWorkerOptionsOffset); fbb.finish(root); ByteBuffer buffer = fbb.dataBuffer(); diff --git a/java/runtime/src/main/java/org/ray/runtime/runner/RunManager.java b/java/runtime/src/main/java/org/ray/runtime/runner/RunManager.java index 15240e43e234..773499fcf5cf 100644 --- a/java/runtime/src/main/java/org/ray/runtime/runner/RunManager.java +++ b/java/runtime/src/main/java/org/ray/runtime/runner/RunManager.java @@ -319,6 +319,9 @@ private String buildWorkerCommandRaylet() { cmd.addAll(rayConfig.jvmParameters); + // jvm options + cmd.add("RAY_WORKER_OPTION_0"); + // Main class cmd.add(WORKER_CLASS); String command = Joiner.on(" ").join(cmd); diff --git a/java/runtime/src/main/java/org/ray/runtime/task/TaskSpec.java b/java/runtime/src/main/java/org/ray/runtime/task/TaskSpec.java index 3473a9bdb3cc..060ca6fff4c3 100644 --- a/java/runtime/src/main/java/org/ray/runtime/task/TaskSpec.java +++ b/java/runtime/src/main/java/org/ray/runtime/task/TaskSpec.java @@ -63,6 +63,8 @@ public class TaskSpec { // Language of this task. public final TaskLanguage language; + public final List dynamicWorkerOptions; + // Descriptor of the remote function. // Note, if task language is Java, the type is JavaFunctionDescriptor. If the task language // is Python, the type is PyFunctionDescriptor. @@ -93,7 +95,8 @@ public TaskSpec( int numReturns, Map resources, TaskLanguage language, - FunctionDescriptor functionDescriptor) { + FunctionDescriptor functionDescriptor, + List dynamicWorkerOptions) { this.driverId = driverId; this.taskId = taskId; this.parentTaskId = parentTaskId; @@ -106,6 +109,8 @@ public TaskSpec( this.newActorHandles = newActorHandles; this.args = args; this.numReturns = numReturns; + this.dynamicWorkerOptions = dynamicWorkerOptions; + returnIds = new ObjectId[numReturns]; for (int i = 0; i < numReturns; ++i) { returnIds[i] = IdUtil.computeReturnId(taskId, i + 1); @@ -157,6 +162,7 @@ public String toString() { ", resources=" + resources + ", language=" + language + ", functionDescriptor=" + functionDescriptor + + ", dynamicWorkerOptions=" + dynamicWorkerOptions + ", executionDependencies=" + executionDependencies + '}'; } diff --git a/java/test/src/main/java/org/ray/api/test/WorkerJvmOptionsTest.java b/java/test/src/main/java/org/ray/api/test/WorkerJvmOptionsTest.java new file mode 100644 index 000000000000..90a2817a8366 --- /dev/null +++ b/java/test/src/main/java/org/ray/api/test/WorkerJvmOptionsTest.java @@ -0,0 +1,31 @@ +package org.ray.api.test; + +import org.ray.api.Ray; +import org.ray.api.RayActor; +import org.ray.api.RayObject; +import org.ray.api.TestUtils; +import org.ray.api.annotation.RayRemote; +import org.ray.api.options.ActorCreationOptions; +import org.testng.Assert; +import org.testng.annotations.Test; + +public class WorkerJvmOptionsTest extends BaseTest { + + @RayRemote + public static class Echo { + String getOptions() { + return System.getProperty("test.suffix"); + } + } + + @Test + public void testJvmOptions() { + TestUtils.skipTestUnderSingleProcess(); + ActorCreationOptions options = new ActorCreationOptions.Builder() + .setJvmOptions("-Dtest.suffix=suffix") + .createActorCreationOptions(); + RayActor actor = Ray.createActor(Echo::new, options); + RayObject obj = Ray.call(Echo::getOptions, actor); + Assert.assertEquals(obj.get(), "suffix"); + } +} diff --git a/python/ray/services.py b/python/ray/services.py index 2c843f7bbbc7..14e13620eea2 100644 --- a/python/ray/services.py +++ b/python/ray/services.py @@ -1233,6 +1233,7 @@ def build_java_worker_command( assert java_worker_options is not None command = "java " + if redis_address is not None: command += "-Dray.redis.address={} ".format(redis_address) @@ -1253,6 +1254,8 @@ def build_java_worker_command( # Put `java_worker_options` in the last, so it can overwrite the # above options. command += java_worker_options + " " + + command += "RAY_WORKER_OPTION_0 " command += "org.ray.runtime.runner.worker.DefaultWorker" return command diff --git a/src/ray/common/constants.h b/src/ray/common/constants.h index c92e6a74aa5d..1f50b8025d57 100644 --- a/src/ray/common/constants.h +++ b/src/ray/common/constants.h @@ -36,4 +36,6 @@ constexpr char kObjectTablePrefix[] = "ObjectTable"; /// Prefix for the task table keys in redis. constexpr char kTaskTablePrefix[] = "TaskTable"; +constexpr char kWorkerDynamicOptionPlaceholderPrefix[] = "RAY_WORKER_OPTION_"; + #endif // RAY_CONSTANTS_H_ diff --git a/src/ray/gcs/format/gcs.fbs b/src/ray/gcs/format/gcs.fbs index 614c80b27672..90476da73425 100644 --- a/src/ray/gcs/format/gcs.fbs +++ b/src/ray/gcs/format/gcs.fbs @@ -106,6 +106,11 @@ table TaskInfo { // For a Python function, it should be: [module_name, class_name, function_name] // For a Java function, it should be: [class_name, method_name, type_descriptor] function_descriptor: [string]; + // The dynamic options used in the worker command when starting the worker process for + // an actor creation task. If the list isn't empty, the options will be used to replace + // the placeholder strings (`RAY_WORKER_OPTION_0`, `RAY_WORKER_OPTION_1`, etc) in the + // worker command. + dynamic_worker_options: [string]; } table ResourcePair { diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index a0bde1ff0655..fc364539ccce 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -83,7 +83,8 @@ NodeManager::NodeManager(boost::asio::io_service &io_service, initial_config_(config), local_available_resources_(config.resource_config), worker_pool_(config.num_initial_workers, config.num_workers_per_process, - config.maximum_startup_concurrency, config.worker_commands), + config.maximum_startup_concurrency, gcs_client_, + config.worker_commands), scheduling_policy_(local_queues_), reconstruction_policy_( io_service_, @@ -1723,18 +1724,6 @@ bool NodeManager::AssignTask(const Task &task) { std::shared_ptr worker = worker_pool_.PopWorker(spec); if (worker == nullptr) { // There are no workers that can execute this task. - if (!spec.IsActorTask()) { - // There are no more non-actor workers available to execute this task. - // Start a new worker. - worker_pool_.StartWorkerProcess(spec.GetLanguage()); - // Push an error message to the user if the worker pool tells us that it is - // getting too big. - const std::string warning_message = worker_pool_.WarningAboutSize(); - if (warning_message != "") { - RAY_CHECK_OK(gcs_client_->error_table().PushErrorToDriver( - DriverID::Nil(), "worker_pool_large", warning_message, current_time_ms())); - } - } // We couldn't assign this task, as no worker available. return false; } @@ -2205,6 +2194,12 @@ void NodeManager::ForwardTask( const auto &spec = task.GetTaskSpecification(); auto task_id = spec.TaskId(); + if (worker_pool_.HasPendingWorkerForTask(spec.GetLanguage(), task_id)) { + // There is a worker being starting for this task, + // so we shouldn't forward this task to another node. + return; + } + // Get and serialize the task's unforwarded, uncommitted lineage. Lineage uncommitted_lineage; if (lineage_cache_.ContainsTask(task_id)) { diff --git a/src/ray/raylet/task_spec.cc b/src/ray/raylet/task_spec.cc index eeab29272126..1d722de18f73 100644 --- a/src/ray/raylet/task_spec.cc +++ b/src/ray/raylet/task_spec.cc @@ -80,12 +80,12 @@ TaskSpecification::TaskSpecification( const std::vector> &task_arguments, int64_t num_returns, const std::unordered_map &required_resources, const std::unordered_map &required_placement_resources, - const Language &language, const std::vector &function_descriptor) + const Language &language, const std::vector &function_descriptor, + const std::vector &dynamic_worker_options) : spec_() { flatbuffers::FlatBufferBuilder fbb; TaskID task_id = GenerateTaskId(driver_id, parent_task_id, parent_counter); - // Add argument object IDs. std::vector> arguments; for (auto &argument : task_arguments) { @@ -101,7 +101,8 @@ TaskSpecification::TaskSpecification( ids_to_flatbuf(fbb, new_actor_handles), fbb.CreateVector(arguments), num_returns, map_to_flatbuf(fbb, required_resources), map_to_flatbuf(fbb, required_placement_resources), language, - string_vec_to_flatbuf(fbb, function_descriptor)); + string_vec_to_flatbuf(fbb, function_descriptor), + string_vec_to_flatbuf(fbb, dynamic_worker_options)); fbb.Finish(spec); AssignSpecification(fbb.GetBufferPointer(), fbb.GetSize()); } @@ -258,6 +259,11 @@ std::vector TaskSpecification::NewActorHandles() const { return ids_from_flatbuf(*message->new_actor_handles()); } +std::vector TaskSpecification::DynamicWorkerOptions() const { + auto message = flatbuffers::GetRoot(spec_.data()); + return string_vec_from_flatbuf(*message->dynamic_worker_options()); +} + } // namespace raylet } // namespace ray diff --git a/src/ray/raylet/task_spec.h b/src/ray/raylet/task_spec.h index d557c188ae68..8a08e9974ef2 100644 --- a/src/ray/raylet/task_spec.h +++ b/src/ray/raylet/task_spec.h @@ -128,6 +128,7 @@ class TaskSpecification { /// will default to be equal to the required_resources argument. /// \param language The language of the worker that must execute the function. /// \param function_descriptor The function descriptor. + /// \param dynamic_worker_options The dynamic options for starting an actor worker. TaskSpecification( const DriverID &driver_id, const TaskID &parent_task_id, int64_t parent_counter, const ActorID &actor_creation_id, const ObjectID &actor_creation_dummy_object_id, @@ -138,7 +139,8 @@ class TaskSpecification { int64_t num_returns, const std::unordered_map &required_resources, const std::unordered_map &required_placement_resources, - const Language &language, const std::vector &function_descriptor); + const Language &language, const std::vector &function_descriptor, + const std::vector &dynamic_worker_options = {}); /// Deserialize a task specification from a string. /// @@ -214,6 +216,8 @@ class TaskSpecification { ObjectID ActorDummyObject() const; std::vector NewActorHandles() const; + std::vector DynamicWorkerOptions() const; + private: /// Assign the specification data from a pointer. void AssignSpecification(const uint8_t *spec, size_t spec_size); diff --git a/src/ray/raylet/worker_pool.cc b/src/ray/raylet/worker_pool.cc index d4ac4cf4ecce..719378216fb7 100644 --- a/src/ray/raylet/worker_pool.cc +++ b/src/ray/raylet/worker_pool.cc @@ -5,10 +5,12 @@ #include #include +#include "ray/common/constants.h" #include "ray/common/ray_config.h" #include "ray/common/status.h" #include "ray/stats/stats.h" #include "ray/util/logging.h" +#include "ray/util/util.h" namespace { @@ -41,11 +43,12 @@ namespace raylet { /// (num_worker_processes * num_workers_per_process) workers for each language. WorkerPool::WorkerPool( int num_worker_processes, int num_workers_per_process, - int maximum_startup_concurrency, + int maximum_startup_concurrency, std::shared_ptr gcs_client, const std::unordered_map> &worker_commands) : num_workers_per_process_(num_workers_per_process), multiple_for_warning_(std::max(num_worker_processes, maximum_startup_concurrency)), maximum_startup_concurrency_(maximum_startup_concurrency), + gcs_client_(std::move(gcs_client)), last_warning_multiple_(0) { RAY_CHECK(num_workers_per_process > 0) << "num_workers_per_process must be positive."; RAY_CHECK(maximum_startup_concurrency > 0); @@ -98,7 +101,8 @@ uint32_t WorkerPool::Size(const Language &language) const { } } -void WorkerPool::StartWorkerProcess(const Language &language) { +int WorkerPool::StartWorkerProcess(const Language &language, + const std::vector &dynamic_options) { auto &state = GetStateForLanguage(language); // If we are already starting up too many workers, then return without starting // more. @@ -108,7 +112,7 @@ void WorkerPool::StartWorkerProcess(const Language &language) { RAY_LOG(DEBUG) << "Worker not started, " << state.starting_worker_processes.size() << " worker processes of language type " << static_cast(language) << " pending registration"; - return; + return -1; } // Either there are no workers pending registration or the worker start is being forced. RAY_LOG(DEBUG) << "Starting new worker process, current pool has " @@ -117,8 +121,20 @@ void WorkerPool::StartWorkerProcess(const Language &language) { // Extract pointers from the worker command to pass into execvp. std::vector worker_command_args; + size_t dynamic_option_index = 0; for (auto const &token : state.worker_command) { - worker_command_args.push_back(token.c_str()); + const auto option_placeholder = + kWorkerDynamicOptionPlaceholderPrefix + std::to_string(dynamic_option_index); + + if (token == option_placeholder) { + if (!dynamic_options.empty()) { + RAY_CHECK(dynamic_option_index < dynamic_options.size()); + worker_command_args.push_back(dynamic_options[dynamic_option_index].c_str()); + ++dynamic_option_index; + } + } else { + worker_command_args.push_back(token.c_str()); + } } worker_command_args.push_back(nullptr); @@ -126,14 +142,14 @@ void WorkerPool::StartWorkerProcess(const Language &language) { if (pid < 0) { // Failure case. RAY_LOG(FATAL) << "Failed to fork worker process: " << strerror(errno); - return; } else if (pid > 0) { // Parent process case. RAY_LOG(DEBUG) << "Started worker process with pid " << pid; state.starting_worker_processes.emplace( std::make_pair(pid, num_workers_per_process_)); - return; + return pid; } + return -1; } pid_t WorkerPool::StartProcess(const std::vector &worker_command_args) { @@ -158,7 +174,7 @@ pid_t WorkerPool::StartProcess(const std::vector &worker_command_a } void WorkerPool::RegisterWorker(const std::shared_ptr &worker) { - auto pid = worker->Pid(); + const auto pid = worker->Pid(); RAY_LOG(DEBUG) << "Registering worker with pid " << pid; auto &state = GetStateForLanguage(worker->GetLanguage()); state.registered_workers.insert(std::move(worker)); @@ -207,30 +223,74 @@ void WorkerPool::PushWorker(const std::shared_ptr &worker) { RAY_CHECK(worker->GetAssignedTaskId().IsNil()) << "Idle workers cannot have an assigned task ID"; auto &state = GetStateForLanguage(worker->GetLanguage()); - // Add the worker to the idle pool. - if (worker->GetActorId().IsNil()) { - state.idle.insert(std::move(worker)); + + auto it = state.dedicated_workers_to_tasks.find(worker->Pid()); + if (it != state.dedicated_workers_to_tasks.end()) { + // The worker is used for the actor creation task with dynamic options. + // Put it into idle dedicated worker pool. + const auto task_id = it->second; + state.idle_dedicated_workers[task_id] = std::move(worker); } else { - state.idle_actor[worker->GetActorId()] = std::move(worker); + // The worker is not used for the actor creation task without dynamic options. + // Put the worker to the corresponding idle pool. + if (worker->GetActorId().IsNil()) { + state.idle.insert(std::move(worker)); + } else { + state.idle_actor[worker->GetActorId()] = std::move(worker); + } } } std::shared_ptr WorkerPool::PopWorker(const TaskSpecification &task_spec) { auto &state = GetStateForLanguage(task_spec.GetLanguage()); const auto &actor_id = task_spec.ActorId(); + std::shared_ptr worker = nullptr; - if (actor_id.IsNil()) { + int pid = -1; + if (task_spec.IsActorCreationTask() && !task_spec.DynamicWorkerOptions().empty()) { + // Code path of actor creation task with dynamic worker options. + // Try to pop it from idle dedicated pool. + auto it = state.idle_dedicated_workers.find(task_spec.TaskId()); + if (it != state.idle_dedicated_workers.end()) { + // There is an idle dedicated worker for this task. + worker = std::move(it->second); + state.idle_dedicated_workers.erase(it); + // Because we found a worker that can perform this task, + // we can remove it from dedicated_workers_to_tasks. + state.dedicated_workers_to_tasks.erase(worker->Pid()); + state.tasks_to_dedicated_workers.erase(task_spec.TaskId()); + } else if (!HasPendingWorkerForTask(task_spec.GetLanguage(), task_spec.TaskId())) { + // We are not pending a registration from a worker for this task, + // so start a new worker process for this task. + pid = StartWorkerProcess(task_spec.GetLanguage(), task_spec.DynamicWorkerOptions()); + if (pid > 0) { + state.dedicated_workers_to_tasks[pid] = task_spec.TaskId(); + state.tasks_to_dedicated_workers[task_spec.TaskId()] = pid; + } + } + } else if (!task_spec.IsActorTask()) { + // Code path of normal task or actor creation task without dynamic worker options. if (!state.idle.empty()) { worker = std::move(*state.idle.begin()); state.idle.erase(state.idle.begin()); + } else { + // There are no more non-actor workers available to execute this task. + // Start a new worker process. + pid = StartWorkerProcess(task_spec.GetLanguage()); } } else { + // Code path of actor task. auto actor_entry = state.idle_actor.find(actor_id); if (actor_entry != state.idle_actor.end()) { worker = std::move(actor_entry->second); state.idle_actor.erase(actor_entry); } } + + if (worker == nullptr && pid > 0) { + WarnAboutSize(); + } + return worker; } @@ -274,7 +334,7 @@ std::vector> WorkerPool::GetWorkersRunningTasksForDriver return workers; } -std::string WorkerPool::WarningAboutSize() { +void WorkerPool::WarnAboutSize() { int64_t num_workers_started_or_registered = 0; for (const auto &entry : states_by_lang_) { num_workers_started_or_registered += @@ -285,6 +345,8 @@ std::string WorkerPool::WarningAboutSize() { int64_t multiple = num_workers_started_or_registered / multiple_for_warning_; std::stringstream warning_message; if (multiple >= 3 && multiple > last_warning_multiple_) { + // Push an error message to the user if the worker pool tells us that it is + // getting too big. last_warning_multiple_ = multiple; warning_message << "WARNING: " << num_workers_started_or_registered << " workers have been started. This could be a result of using " @@ -292,8 +354,16 @@ std::string WorkerPool::WarningAboutSize() { << "using nested tasks " << "(see https://github.com/ray-project/ray/issues/3644) for " << "some a discussion of workarounds."; + RAY_CHECK_OK(gcs_client_->error_table().PushErrorToDriver( + DriverID::Nil(), "worker_pool_large", warning_message.str(), current_time_ms())); } - return warning_message.str(); +} + +bool WorkerPool::HasPendingWorkerForTask(const Language &language, + const TaskID &task_id) { + auto &state = GetStateForLanguage(language); + auto it = state.tasks_to_dedicated_workers.find(task_id); + return it != state.tasks_to_dedicated_workers.end(); } std::string WorkerPool::DebugString() const { diff --git a/src/ray/raylet/worker_pool.h b/src/ray/raylet/worker_pool.h index 03443447cf58..e1e726268093 100644 --- a/src/ray/raylet/worker_pool.h +++ b/src/ray/raylet/worker_pool.h @@ -7,6 +7,7 @@ #include #include "ray/common/client_connection.h" +#include "ray/gcs/client.h" #include "ray/gcs/format/util.h" #include "ray/raylet/task.h" #include "ray/raylet/worker.h" @@ -37,22 +38,12 @@ class WorkerPool { /// language. WorkerPool( int num_worker_processes, int num_workers_per_process, - int maximum_startup_concurrency, + int maximum_startup_concurrency, std::shared_ptr gcs_client, const std::unordered_map> &worker_commands); /// Destructor responsible for freeing a set of workers owned by this class. virtual ~WorkerPool(); - /// Asynchronously start a new worker process. Once the worker process has - /// registered with an external server, the process should create and - /// register num_workers_per_process_ workers, then add them to the pool. - /// Failure to start the worker process is a fatal error. If too many workers - /// are already being started, then this function will return without starting - /// any workers. - /// - /// \param language Which language this worker process should be. - void StartWorkerProcess(const Language &language); - /// Register a new worker. The Worker should be added by the caller to the /// pool after it becomes idle (e.g., requests a work assignment). /// @@ -118,6 +109,15 @@ class WorkerPool { std::vector> GetWorkersRunningTasksForDriver( const DriverID &driver_id) const; + /// Whether there is a pending worker for the given task. + /// Note that, this is only used for actor creation task with dynamic options. + /// And if the worker registered but isn't assigned a task, + /// the worker also is in pending state, and this'll return true. + /// + /// \param language The required language. + /// \param task_id The task that we want to query. + bool HasPendingWorkerForTask(const Language &language, const TaskID &task_id); + /// Returns debug string for class. /// /// \return string. @@ -126,24 +126,37 @@ class WorkerPool { /// Record metrics. void RecordMetrics() const; - /// Generate a warning about the number of workers that have registered or - /// started if appropriate. + protected: + /// Asynchronously start a new worker process. Once the worker process has + /// registered with an external server, the process should create and + /// register num_workers_per_process_ workers, then add them to the pool. + /// Failure to start the worker process is a fatal error. If too many workers + /// are already being started, then this function will return without starting + /// any workers. /// - /// \return An empty string if no warning should be generated and otherwise a - /// string with a warning message. - std::string WarningAboutSize(); + /// \param language Which language this worker process should be. + /// \param dynamic_options The dynamic options that we should add for worker command. + /// \return The id of the process that we started if it's positive, + /// otherwise it means we didn't start a process. + int StartWorkerProcess(const Language &language, + const std::vector &dynamic_options = {}); - protected: /// The implementation of how to start a new worker process with command arguments. /// /// \param worker_command_args The command arguments of new worker process. /// \return The process ID of started worker process. virtual pid_t StartProcess(const std::vector &worker_command_args); + /// Push an warning message to user if worker pool is getting to big. + virtual void WarnAboutSize(); + /// An internal data structure that maintains the pool state per language. struct State { /// The commands and arguments used to start the worker process std::vector worker_command; + /// The pool of dedicated workers for actor creation tasks + /// with prefix or suffix worker command. + std::unordered_map> idle_dedicated_workers; /// The pool of idle non-actor workers. std::unordered_set> idle; /// The pool of idle actor workers. @@ -156,6 +169,11 @@ class WorkerPool { /// A map from the pids of starting worker processes /// to the number of their unregistered workers. std::unordered_map starting_worker_processes; + /// A map for looking up the task with dynamic options by the pid of + /// worker. Note that this is used for the dedicated worker processes. + std::unordered_map dedicated_workers_to_tasks; + /// A map for speeding up looking up the pending worker for the given task. + std::unordered_map tasks_to_dedicated_workers; }; /// The number of workers per process. @@ -166,7 +184,7 @@ class WorkerPool { private: /// A helper function that returns the reference of the pool state /// for a given language. - inline State &GetStateForLanguage(const Language &language); + State &GetStateForLanguage(const Language &language); /// We'll push a warning to the user every time a multiple of this many /// workers has been started. @@ -176,6 +194,8 @@ class WorkerPool { /// The last size at which a warning about the number of registered workers /// was generated. int64_t last_warning_multiple_; + /// A client connection to the GCS. + std::shared_ptr gcs_client_; }; } // namespace raylet diff --git a/src/ray/raylet/worker_pool_test.cc b/src/ray/raylet/worker_pool_test.cc index 143ffd57dda6..15a5fb0471e0 100644 --- a/src/ray/raylet/worker_pool_test.cc +++ b/src/ray/raylet/worker_pool_test.cc @@ -1,6 +1,7 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" +#include "ray/common/constants.h" #include "ray/raylet/node_manager.h" #include "ray/raylet/worker_pool.h" @@ -14,21 +15,46 @@ int MAXIMUM_STARTUP_CONCURRENCY = 5; class WorkerPoolMock : public WorkerPool { public: WorkerPoolMock() - : WorkerPool(0, NUM_WORKERS_PER_PROCESS, MAXIMUM_STARTUP_CONCURRENCY, - {{Language::PYTHON, {"dummy_py_worker_command"}}, - {Language::JAVA, {"dummy_java_worker_command"}}}), + : WorkerPoolMock({{Language::PYTHON, {"dummy_py_worker_command"}}, + {Language::JAVA, {"dummy_java_worker_command"}}}) {} + + explicit WorkerPoolMock( + const std::unordered_map> &worker_commands) + : WorkerPool(0, NUM_WORKERS_PER_PROCESS, MAXIMUM_STARTUP_CONCURRENCY, nullptr, + worker_commands), last_worker_pid_(0) {} + ~WorkerPoolMock() { // Avoid killing real processes states_by_lang_.clear(); } + void StartWorkerProcess(const Language &language, + const std::vector &dynamic_options = {}) { + WorkerPool::StartWorkerProcess(language, dynamic_options); + } + pid_t StartProcess(const std::vector &worker_command_args) override { - return ++last_worker_pid_; + last_worker_pid_ += 1; + std::vector local_worker_commands_args; + for (auto item : worker_command_args) { + if (item == nullptr) { + break; + } + local_worker_commands_args.push_back(std::string(item)); + } + worker_commands_by_pid[last_worker_pid_] = std::move(local_worker_commands_args); + return last_worker_pid_; } + void WarnAboutSize() override {} + pid_t LastStartedWorkerProcess() const { return last_worker_pid_; } + const std::vector &GetWorkerCommand(int pid) { + return worker_commands_by_pid[pid]; + } + int NumWorkerProcessesStarting() const { int total = 0; for (auto &entry : states_by_lang_) { @@ -39,6 +65,8 @@ class WorkerPoolMock : public WorkerPool { private: int last_worker_pid_; + // The worker commands by pid. + std::unordered_map> worker_commands_by_pid; }; class WorkerPoolTest : public ::testing::Test { @@ -61,6 +89,12 @@ class WorkerPoolTest : public ::testing::Test { return std::shared_ptr(new Worker(pid, language, client)); } + void SetWorkerCommands( + const std::unordered_map> &worker_commands) { + WorkerPoolMock worker_pool(worker_commands); + this->worker_pool_ = std::move(worker_pool); + } + protected: WorkerPoolMock worker_pool_; boost::asio::io_service io_service_; @@ -72,10 +106,10 @@ class WorkerPoolTest : public ::testing::Test { }; static inline TaskSpecification ExampleTaskSpec( - const ActorID actor_id = ActorID::Nil(), - const Language &language = Language::PYTHON) { + const ActorID actor_id = ActorID::Nil(), const Language &language = Language::PYTHON, + const ActorID actor_creation_id = ActorID::Nil()) { std::vector function_descriptor(3); - return TaskSpecification(DriverID::Nil(), TaskID::Nil(), 0, ActorID::Nil(), + return TaskSpecification(DriverID::Nil(), TaskID::Nil(), 0, actor_creation_id, ObjectID::Nil(), 0, actor_id, ActorHandleID::Nil(), 0, {}, {}, 0, {}, {}, language, function_descriptor); } @@ -186,6 +220,23 @@ TEST_F(WorkerPoolTest, PopWorkersOfMultipleLanguages) { ASSERT_NE(worker_pool_.PopWorker(java_task_spec), nullptr); } +TEST_F(WorkerPoolTest, StartWorkerWithDynamicOptionsCommand) { + const std::vector java_worker_command = { + "RAY_WORKER_OPTION_0", "dummy_java_worker_command", "RAY_WORKER_OPTION_1"}; + SetWorkerCommands({{Language::PYTHON, {"dummy_py_worker_command"}}, + {Language::JAVA, java_worker_command}}); + + TaskSpecification task_spec(DriverID::Nil(), TaskID::Nil(), 0, ActorID::FromRandom(), + ObjectID::Nil(), 0, ActorID::Nil(), ActorHandleID::Nil(), 0, + {}, {}, 0, {}, {}, Language::JAVA, {"", "", ""}, + {"test_op_0", "test_op_1"}); + worker_pool_.StartWorkerProcess(Language::JAVA, task_spec.DynamicWorkerOptions()); + const auto real_command = + worker_pool_.GetWorkerCommand(worker_pool_.LastStartedWorkerProcess()); + ASSERT_EQ(real_command, std::vector( + {"test_op_0", "dummy_java_worker_command", "test_op_1"})); +} + } // namespace raylet } // namespace ray