// Patch summary (review notes; everything below this preamble is the original diff, unchanged).
// NOTE(review): the diff text appears to have lost its line breaks and template arguments
// (e.g. `std::make_shared()`, `std::atomic is_running_`), so it will not apply as-is.
//
// src/ray/common/asio/instrumented_io_context.h
//   - Adds an `is_running_` flag, initialised to false in the constructor.
//   - `running()` returns the flag; `run()` sets it to true and then calls
//     `boost::asio::io_context::run()`; `stop()` sets it to false and then calls
//     `boost::asio::io_context::stop()`.
//   - NOTE(review): `run()` does nothing after the base call returns, so the flag stays true
//     once `run()` has returned unless `stop()` was called through this class.
//   - NOTE(review): presumably `io_context::run/stop` are non-virtual, in which case calls made
//     through a `boost::asio::io_context&` bypass the flag -- confirm against Boost.Asio.
//
// src/ray/gcs/gcs_client/gcs_client.cc (GcsClient::Connect, nil cluster-ID path)
//   - The GetClusterId callback still sets the cluster ID on `client_call_manager_`, but now
//     stops `io_service` only when the `do_stop` future yields true, and then fulfils `wait_sync`.
//   - If `io_service.running()` is false: `temporary_start` is set to true and the service is
//     run and restarted in place, as before.
//   - Otherwise: `temporary_start` is set to false and the caller blocks on `wait_sync`.
//   - NOTE(review): the `running()` check and the branch taken are not one atomic step.
//   - NOTE(review): blocking on `wait_sync` from the thread that drives `io_service` would
//     presumably never return if the callback is delivered on that service -- verify callers.
//
// src/ray/gcs/gcs_client/test/gcs_client_test.cc
//   - `client_io_service_` is now created right before `Connect`, and its worker thread is
//     started only after `Connect` returns.
//
// src/ray/rpc/client_call.h
//   - Fixes the `ClientCallMangager` spelling in two comments.
//   - `SetClusterId` now exchanges in `cluster_id` instead of `ClusterID::Nil()`; the existing
//     mismatch check on the previous value is untouched.
diff --git a/src/ray/common/asio/instrumented_io_context.h b/src/ray/common/asio/instrumented_io_context.h index aff11c36b9b28..8a72bc5f7209d 100644 --- a/src/ray/common/asio/instrumented_io_context.h +++ b/src/ray/common/asio/instrumented_io_context.h @@ -28,7 +28,20 @@ class instrumented_io_context : public boost::asio::io_context { public: /// Initializes the global stats struct after calling the base contructor. /// TODO(ekl) allow taking an externally defined event tracker. - instrumented_io_context() : event_stats_(std::make_shared()) {} + instrumented_io_context() + : event_stats_(std::make_shared()), is_running_(false) {} + + bool running() { return is_running_.load(); } + + void run() { + is_running_.store(true); + boost::asio::io_context::run(); + } + + void stop() { + is_running_.store(false); + boost::asio::io_context::stop(); + } /// A proxy post function that collects count, queueing, and execution statistics for /// the given handler. @@ -58,4 +71,6 @@ class instrumented_io_context : public boost::asio::io_context { private: /// The event stats tracker to use to record asio handler stats to. std::shared_ptr event_stats_; + + std::atomic is_running_; }; diff --git a/src/ray/gcs/gcs_client/gcs_client.cc b/src/ray/gcs/gcs_client/gcs_client.cc index ebe6ffbb610c6..a4d586412db15 100644 --- a/src/ray/gcs/gcs_client/gcs_client.cc +++ b/src/ray/gcs/gcs_client/gcs_client.cc @@ -90,18 +90,35 @@ Status GcsClient::Connect(instrumented_io_context &io_service, if (cluster_id.IsNil()) { rpc::GetClusterIdReply reply; + std::promise temporary_start; + std::promise wait_sync; gcs_rpc_client_->GetClusterId( rpc::GetClusterIdRequest(), - [this, &io_service](const Status &status, const rpc::GetClusterIdReply &reply) { + [this, + &io_service, + &wait_sync, + do_stop = std::shared_future(temporary_start.get_future())]( + const Status &status, const rpc::GetClusterIdReply &reply) { RAY_CHECK(status.ok()) << "Failed to get Cluster ID! 
Status: " << status; auto cluster_id = ClusterID::FromBinary(reply.cluster_id()); RAY_LOG(DEBUG) << "Setting cluster ID to " << cluster_id; client_call_manager_->SetClusterId(cluster_id); - io_service.stop(); + if (do_stop.get()) { + io_service.stop(); + } + wait_sync.set_value(true); }); // Run the IO service here to make the above call synchronous. - io_service.run(); - io_service.restart(); + // If it is already running, then wait for our particular callback + // to be processed. + if (!io_service.running()) { + temporary_start.set_value(true); + io_service.run(); + io_service.restart(); + } else { + temporary_start.set_value(false); + wait_sync.get_future().get(); + } } else { client_call_manager_->SetClusterId(cluster_id); } diff --git a/src/ray/gcs/gcs_client/test/gcs_client_test.cc b/src/ray/gcs/gcs_client/test/gcs_client_test.cc index 9f581619273d6..fdc38a85a4266 100644 --- a/src/ray/gcs/gcs_client/test/gcs_client_test.cc +++ b/src/ray/gcs/gcs_client/test/gcs_client_test.cc @@ -69,15 +69,6 @@ class GcsClientTest : public ::testing::TestWithParam { config_.node_ip_address = "127.0.0.1"; config_.enable_sharding_conn = false; - // Tests legacy code paths. The poller and broadcaster have their own dedicated unit - // test targets. - client_io_service_ = std::make_unique(); - client_io_service_thread_ = std::make_unique([this] { - std::unique_ptr work( - new boost::asio::io_service::work(*client_io_service_)); - client_io_service_->run(); - }); - server_io_service_ = std::make_unique(); gcs_server_ = std::make_unique(config_, *server_io_service_); gcs_server_->Start(); @@ -95,7 +86,15 @@ class GcsClientTest : public ::testing::TestWithParam { // Create GCS client. gcs::GcsClientOptions options("127.0.0.1:5397"); gcs_client_ = std::make_unique(options); + // Tests legacy code paths. The poller and broadcaster have their own dedicated unit + // test targets. 
+ client_io_service_ = std::make_unique(); RAY_CHECK_OK(gcs_client_->Connect(*client_io_service_)); + client_io_service_thread_ = std::make_unique([this] { + std::unique_ptr work( + new boost::asio::io_service::work(*client_io_service_)); + client_io_service_->run(); + }); } void TearDown() override { diff --git a/src/ray/rpc/client_call.h b/src/ray/rpc/client_call.h index a72f36de524e4..df97b1e0dd066 100644 --- a/src/ray/rpc/client_call.h +++ b/src/ray/rpc/client_call.h @@ -148,10 +148,10 @@ class ClientCallImpl : public ClientCall { /// The lifecycle of a `ClientCallTag` is as follows. /// /// When a client submits a new gRPC request, a new `ClientCallTag` object will be created -/// by `ClientCallMangager::CreateCall`. Then the object will be used as the tag of +/// by `ClientCallManager::CreateCall`. Then the object will be used as the tag of /// `CompletionQueue`. /// -/// When the reply is received, `ClientCallMangager` will get the address of this object +/// When the reply is received, `ClientCallManager` will get the address of this object /// via `CompletionQueue`'s tag. And the manager should call /// `GetCall()->OnReplyReceived()` and then delete this object. class ClientCallTag { @@ -271,7 +271,7 @@ class ClientCallManager { } void SetClusterId(const ClusterID &cluster_id) { - auto old_id = cluster_id_.exchange(ClusterID::Nil()); + auto old_id = cluster_id_.exchange(cluster_id); if (!old_id.IsNil() && (old_id != cluster_id)) { RAY_LOG(FATAL) << "Expected cluster ID to be Nil or " << cluster_id << ", but got" << old_id;