Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Core] Fix the race condition where gRPC requests are handled while the core worker is not yet initialized #37117

Merged
merged 11 commits into from
Jul 11, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions src/ray/core_worker/core_worker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -611,6 +611,13 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_
// Verify driver and worker are never mixed in the same process.
RAY_CHECK_EQ(options_.worker_type != WorkerType::DRIVER, niced);
#endif

// Notify that core worker is initialized.
{
scv119 marked this conversation as resolved.
Show resolved Hide resolved
absl::MutexLock lock(&initialize_mutex_);
initialized_ = true;
scv119 marked this conversation as resolved.
Show resolved Hide resolved
intialize_cv_.SignalAll();
}
}

CoreWorker::~CoreWorker() { RAY_LOG(INFO) << "Core worker is destructed"; }
Expand Down Expand Up @@ -2867,6 +2874,7 @@ void CoreWorker::HandleReportGeneratorItemReturns(
rpc::ReportGeneratorItemReturnsRequest request,
rpc::ReportGeneratorItemReturnsReply *reply,
rpc::SendReplyCallback send_reply_callback) {
WaitUntilInitialized();
scv119 marked this conversation as resolved.
Show resolved Hide resolved
task_manager_->HandleReportGeneratorItemReturns(request);
send_reply_callback(Status::OK(), nullptr, nullptr);
}
Expand Down Expand Up @@ -3011,6 +3019,7 @@ void CoreWorker::HandlePushTask(rpc::PushTaskRequest request,
rpc::SendReplyCallback send_reply_callback) {
RAY_LOG(DEBUG) << "Received Handle Push Task "
<< TaskID::FromBinary(request.task_spec().task_id());
WaitUntilInitialized();
if (HandleWrongRecipient(WorkerID::FromBinary(request.intended_worker_id()),
send_reply_callback)) {
return;
Expand Down Expand Up @@ -3070,6 +3079,7 @@ void CoreWorker::HandleDirectActorCallArgWaitComplete(
rpc::DirectActorCallArgWaitCompleteRequest request,
rpc::DirectActorCallArgWaitCompleteReply *reply,
rpc::SendReplyCallback send_reply_callback) {
WaitUntilInitialized();
if (HandleWrongRecipient(WorkerID::FromBinary(request.intended_worker_id()),
send_reply_callback)) {
return;
Expand All @@ -3091,13 +3101,15 @@ void CoreWorker::HandleRayletNotifyGCSRestart(
rpc::RayletNotifyGCSRestartRequest request,
rpc::RayletNotifyGCSRestartReply *reply,
rpc::SendReplyCallback send_reply_callback) {
WaitUntilInitialized();
gcs_client_->AsyncResubscribe();
send_reply_callback(Status::OK(), nullptr, nullptr);
}

void CoreWorker::HandleGetObjectStatus(rpc::GetObjectStatusRequest request,
rpc::GetObjectStatusReply *reply,
rpc::SendReplyCallback send_reply_callback) {
WaitUntilInitialized();
if (HandleWrongRecipient(WorkerID::FromBinary(request.owner_worker_id()),
send_reply_callback)) {
RAY_LOG(INFO) << "Handling GetObjectStatus for object produced by a previous worker "
Expand Down Expand Up @@ -3177,6 +3189,7 @@ void CoreWorker::HandleWaitForActorOutOfScope(
rpc::WaitForActorOutOfScopeRequest request,
rpc::WaitForActorOutOfScopeReply *reply,
rpc::SendReplyCallback send_reply_callback) {
WaitUntilInitialized();
// Currently WaitForActorOutOfScope is only used when GCS actor service is enabled.
if (HandleWrongRecipient(WorkerID::FromBinary(request.intended_worker_id()),
send_reply_callback)) {
Expand Down Expand Up @@ -3304,6 +3317,7 @@ void CoreWorker::ProcessPubsubCommands(const Commands &commands,
void CoreWorker::HandlePubsubLongPolling(rpc::PubsubLongPollingRequest request,
rpc::PubsubLongPollingReply *reply,
rpc::SendReplyCallback send_reply_callback) {
WaitUntilInitialized();
const auto subscriber_id = NodeID::FromBinary(request.subscriber_id());
RAY_LOG(DEBUG) << "Got a long polling request from a node " << subscriber_id;
object_info_publisher_->ConnectToSubscriber(
Expand All @@ -3313,6 +3327,7 @@ void CoreWorker::HandlePubsubLongPolling(rpc::PubsubLongPollingRequest request,
void CoreWorker::HandlePubsubCommandBatch(rpc::PubsubCommandBatchRequest request,
rpc::PubsubCommandBatchReply *reply,
rpc::SendReplyCallback send_reply_callback) {
WaitUntilInitialized();
const auto subscriber_id = NodeID::FromBinary(request.subscriber_id());
ProcessPubsubCommands(request.commands(), subscriber_id);
send_reply_callback(Status::OK(), nullptr, nullptr);
Expand All @@ -3322,6 +3337,7 @@ void CoreWorker::HandleUpdateObjectLocationBatch(
rpc::UpdateObjectLocationBatchRequest request,
rpc::UpdateObjectLocationBatchReply *reply,
rpc::SendReplyCallback send_reply_callback) {
WaitUntilInitialized();
const auto &worker_id = request.intended_worker_id();
if (HandleWrongRecipient(WorkerID::FromBinary(worker_id), send_reply_callback)) {
return;
Expand Down Expand Up @@ -3457,6 +3473,7 @@ void CoreWorker::HandleGetObjectLocationsOwner(
rpc::GetObjectLocationsOwnerRequest request,
rpc::GetObjectLocationsOwnerReply *reply,
rpc::SendReplyCallback send_reply_callback) {
WaitUntilInitialized();
auto &object_location_request = request.object_location_request();
if (HandleWrongRecipient(
WorkerID::FromBinary(object_location_request.intended_worker_id()),
Expand Down Expand Up @@ -3496,6 +3513,7 @@ void CoreWorker::ProcessSubscribeForRefRemoved(
void CoreWorker::HandleRemoteCancelTask(rpc::RemoteCancelTaskRequest request,
rpc::RemoteCancelTaskReply *reply,
rpc::SendReplyCallback send_reply_callback) {
WaitUntilInitialized();
auto status = CancelTask(ObjectID::FromBinary(request.remote_object_id()),
request.force_kill(),
request.recursive());
Expand All @@ -3505,6 +3523,7 @@ void CoreWorker::HandleRemoteCancelTask(rpc::RemoteCancelTaskRequest request,
void CoreWorker::HandleCancelTask(rpc::CancelTaskRequest request,
rpc::CancelTaskReply *reply,
rpc::SendReplyCallback send_reply_callback) {
WaitUntilInitialized();
TaskID task_id = TaskID::FromBinary(request.intended_task_id());
bool requested_task_running;
{
Expand Down Expand Up @@ -3559,6 +3578,7 @@ void CoreWorker::HandleCancelTask(rpc::CancelTaskRequest request,
void CoreWorker::HandleKillActor(rpc::KillActorRequest request,
rpc::KillActorReply *reply,
rpc::SendReplyCallback send_reply_callback) {
WaitUntilInitialized();
ActorID intended_actor_id = ActorID::FromBinary(request.intended_actor_id());
if (intended_actor_id != worker_context_.GetCurrentActorID()) {
std::ostringstream stream;
Expand Down Expand Up @@ -3590,6 +3610,7 @@ void CoreWorker::HandleKillActor(rpc::KillActorRequest request,
void CoreWorker::HandleGetCoreWorkerStats(rpc::GetCoreWorkerStatsRequest request,
rpc::GetCoreWorkerStatsReply *reply,
rpc::SendReplyCallback send_reply_callback) {
WaitUntilInitialized();
absl::MutexLock lock(&mutex_);
auto limit = request.has_limit() ? request.limit() : -1;
auto stats = reply->mutable_core_worker_stats();
Expand Down Expand Up @@ -3649,6 +3670,7 @@ void CoreWorker::HandleGetCoreWorkerStats(rpc::GetCoreWorkerStatsRequest request
void CoreWorker::HandleLocalGC(rpc::LocalGCRequest request,
rpc::LocalGCReply *reply,
rpc::SendReplyCallback send_reply_callback) {
WaitUntilInitialized();
if (options_.gc_collect != nullptr) {
options_.gc_collect(request.triggered_by_global_gc());
send_reply_callback(Status::OK(), nullptr, nullptr);
Expand All @@ -3661,6 +3683,7 @@ void CoreWorker::HandleLocalGC(rpc::LocalGCRequest request,
void CoreWorker::HandleDeleteObjects(rpc::DeleteObjectsRequest request,
rpc::DeleteObjectsReply *reply,
rpc::SendReplyCallback send_reply_callback) {
WaitUntilInitialized();
std::vector<ObjectID> object_ids;
for (const auto &obj_id : request.object_ids()) {
object_ids.push_back(ObjectID::FromBinary(obj_id));
Expand Down Expand Up @@ -3692,6 +3715,7 @@ Status CoreWorker::DeleteImpl(const std::vector<ObjectID> &object_ids, bool loca
void CoreWorker::HandleSpillObjects(rpc::SpillObjectsRequest request,
rpc::SpillObjectsReply *reply,
rpc::SendReplyCallback send_reply_callback) {
WaitUntilInitialized();
if (options_.spill_objects != nullptr) {
auto object_refs =
VectorFromProtobuf<rpc::ObjectReference>(request.object_refs_to_spill());
Expand All @@ -3709,6 +3733,7 @@ void CoreWorker::HandleSpillObjects(rpc::SpillObjectsRequest request,
void CoreWorker::HandleRestoreSpilledObjects(rpc::RestoreSpilledObjectsRequest request,
rpc::RestoreSpilledObjectsReply *reply,
rpc::SendReplyCallback send_reply_callback) {
WaitUntilInitialized();
if (options_.restore_spilled_objects != nullptr) {
// Get a list of object ids.
std::vector<rpc::ObjectReference> object_refs_to_restore;
Expand Down Expand Up @@ -3739,6 +3764,7 @@ void CoreWorker::HandleRestoreSpilledObjects(rpc::RestoreSpilledObjectsRequest r
void CoreWorker::HandleDeleteSpilledObjects(rpc::DeleteSpilledObjectsRequest request,
rpc::DeleteSpilledObjectsReply *reply,
rpc::SendReplyCallback send_reply_callback) {
WaitUntilInitialized();
if (options_.delete_spilled_objects != nullptr) {
std::vector<std::string> spilled_objects_url;
spilled_objects_url.reserve(request.spilled_objects_url_size());
Expand All @@ -3758,6 +3784,7 @@ void CoreWorker::HandleDeleteSpilledObjects(rpc::DeleteSpilledObjectsRequest req
void CoreWorker::HandleExit(rpc::ExitRequest request,
rpc::ExitReply *reply,
rpc::SendReplyCallback send_reply_callback) {
WaitUntilInitialized();
bool own_objects = reference_counter_->OwnObjects();
int64_t pins_in_flight = local_raylet_client_->GetPinsInFlight();
// We consider the worker to be idle if it doesn't own any objects and it doesn't have
Expand Down Expand Up @@ -3794,6 +3821,7 @@ void CoreWorker::HandleExit(rpc::ExitRequest request,
void CoreWorker::HandleAssignObjectOwner(rpc::AssignObjectOwnerRequest request,
rpc::AssignObjectOwnerReply *reply,
rpc::SendReplyCallback send_reply_callback) {
WaitUntilInitialized();
ObjectID object_id = ObjectID::FromBinary(request.object_id());
const auto &borrower_address = request.borrower_address();
std::string call_site = request.call_site();
Expand Down Expand Up @@ -3821,6 +3849,7 @@ void CoreWorker::HandleAssignObjectOwner(rpc::AssignObjectOwnerRequest request,
void CoreWorker::HandleNumPendingTasks(rpc::NumPendingTasksRequest request,
rpc::NumPendingTasksReply *reply,
rpc::SendReplyCallback send_reply_callback) {
WaitUntilInitialized();
RAY_LOG(DEBUG) << "Received NumPendingTasks request.";
reply->set_num_pending_tasks(task_manager_->NumPendingTasks());
send_reply_callback(Status::OK(), nullptr, nullptr);
Expand Down Expand Up @@ -3895,6 +3924,7 @@ void CoreWorker::PlasmaCallback(SetResultCallback success,
void CoreWorker::HandlePlasmaObjectReady(rpc::PlasmaObjectReadyRequest request,
rpc::PlasmaObjectReadyReply *reply,
rpc::SendReplyCallback send_reply_callback) {
WaitUntilInitialized();
std::vector<std::function<void(void)>> callbacks;
{
absl::MutexLock lock(&plasma_mutex_);
Expand Down
13 changes: 13 additions & 0 deletions src/ray/core_worker/core_worker.h
Original file line number Diff line number Diff line change
Expand Up @@ -1520,6 +1520,14 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler {
}
}

/// Block the calling thread until `initialized_` has been set to true.
///
/// Acquires `initialize_mutex_` for the duration of the call. While the flag
/// is still false the thread sleeps on `intialize_cv_`, which releases the
/// mutex during the wait and re-acquires it before returning.
void WaitUntilInitialized() {
  absl::MutexLock guard(&initialize_mutex_);
  for (;;) {
    if (initialized_) {
      return;
    }
    // Wake up on a signal or after one second, whichever comes first, and
    // then re-check the flag; the wait's return value is intentionally unused.
    intialize_cv_.WaitWithTimeout(&initialize_mutex_, absl::Seconds(1));
  }
}

const CoreWorkerOptions options_;

/// Callback to get the current language (e.g., Python) call site.
Expand Down Expand Up @@ -1548,6 +1556,11 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler {

std::string main_thread_task_name_ GUARDED_BY(mutex_);

/// State used to track whether initialization has completed, so that callers
/// of WaitUntilInitialized() can block until it has.
/// Mutex protecting `initialized_`; also the mutex passed to `intialize_cv_`
/// when waiting.
absl::Mutex initialize_mutex_;
/// Condition variable that waiters sleep on until `initialized_` becomes true.
/// NOTE: the identifier is spelled `intialize_cv_` (sic); the spelling is kept
/// because other code refers to the member by this name.
absl::CondVar intialize_cv_;
/// Whether initialization has completed. Starts out false.
bool initialized_ GUARDED_BY(initialize_mutex_) = false;

/// Event loop where the IO events are handled. e.g. async GCS operations.
instrumented_io_context io_service_;

Expand Down