diff --git a/src/mock/ray/gcs/gcs_server/gcs_node_manager.h b/src/mock/ray/gcs/gcs_server/gcs_node_manager.h index 417b966c65a4f..de5b471255335 100644 --- a/src/mock/ray/gcs/gcs_server/gcs_node_manager.h +++ b/src/mock/ray/gcs/gcs_server/gcs_node_manager.h @@ -17,7 +17,7 @@ namespace gcs { class MockGcsNodeManager : public GcsNodeManager { public: - MockGcsNodeManager() : GcsNodeManager(nullptr, nullptr, nullptr) {} + MockGcsNodeManager() : GcsNodeManager(nullptr, nullptr, nullptr, ClusterID::Nil()) {} MOCK_METHOD(void, HandleRegisterNode, (rpc::RegisterNodeRequest request, diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index b1c14ef87f996..1f58a50e21baf 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -206,7 +206,7 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_ std::make_unique(WorkerTypeString(options_.worker_type), assigned_port, options_.node_ip_address == "127.0.0.1"); - core_worker_server_->RegisterService(grpc_service_); + core_worker_server_->RegisterService(grpc_service_, false /* token_auth */); core_worker_server_->Run(); // Set our own address. diff --git a/src/ray/gcs/gcs_server/gcs_node_manager.cc b/src/ray/gcs/gcs_server/gcs_node_manager.cc index 3da4a49c3f5e3..6ed990e10c884 100644 --- a/src/ray/gcs/gcs_server/gcs_node_manager.cc +++ b/src/ray/gcs/gcs_server/gcs_node_manager.cc @@ -30,10 +30,21 @@ namespace gcs { GcsNodeManager::GcsNodeManager( std::shared_ptr gcs_publisher, std::shared_ptr gcs_table_storage, - std::shared_ptr raylet_client_pool) + std::shared_ptr raylet_client_pool, + const ClusterID &cluster_id) : gcs_publisher_(std::move(gcs_publisher)), gcs_table_storage_(std::move(gcs_table_storage)), - raylet_client_pool_(std::move(raylet_client_pool)) {} + raylet_client_pool_(std::move(raylet_client_pool)), + cluster_id_(cluster_id) {} + +// Note: ServerCall will populate the cluster_id. +void GcsNodeManager::HandleGetClusterId(rpc::GetClusterIdRequest request, + rpc::GetClusterIdReply *reply, + rpc::SendReplyCallback send_reply_callback) { + RAY_LOG(DEBUG) << "Registering GCS client!"; + reply->set_cluster_id(cluster_id_.Binary()); + GCS_RPC_SEND_REPLY(send_reply_callback, reply, Status::OK()); +} void GcsNodeManager::HandleRegisterNode(rpc::RegisterNodeRequest request, rpc::RegisterNodeReply *reply, diff --git a/src/ray/gcs/gcs_server/gcs_node_manager.h b/src/ray/gcs/gcs_server/gcs_node_manager.h index 6767f1bd6ef33..21db16c5da236 100644 --- a/src/ray/gcs/gcs_server/gcs_node_manager.h +++ b/src/ray/gcs/gcs_server/gcs_node_manager.h @@ -48,7 +48,13 @@ class GcsNodeManager : public rpc::NodeInfoHandler { /// \param gcs_table_storage GCS table external storage accessor. explicit GcsNodeManager(std::shared_ptr gcs_publisher, std::shared_ptr gcs_table_storage, - std::shared_ptr raylet_client_pool); + std::shared_ptr raylet_client_pool, + const ClusterID &cluster_id); + + /// Handle register rpc request come from raylet. + void HandleGetClusterId(rpc::GetClusterIdRequest request, + rpc::GetClusterIdReply *reply, + rpc::SendReplyCallback send_reply_callback) override; /// Handle register rpc request come from raylet. void HandleRegisterNode(rpc::RegisterNodeRequest request, @@ -167,6 +173,8 @@ class GcsNodeManager : public rpc::NodeInfoHandler { std::shared_ptr gcs_table_storage_; /// Raylet client pool. std::shared_ptr raylet_client_pool_; + /// Cluster ID to be shared with clients when connecting. + const ClusterID cluster_id_; // Debug info. enum CountType { diff --git a/src/ray/gcs/gcs_server/gcs_server.cc b/src/ray/gcs/gcs_server/gcs_server.cc index 33c74faa70bac..aabeeb335b21c 100644 --- a/src/ray/gcs/gcs_server/gcs_server.cc +++ b/src/ray/gcs/gcs_server/gcs_server.cc @@ -33,6 +33,7 @@ #include "ray/gcs/gcs_server/store_client_kv.h" #include "ray/gcs/store_client/observable_store_client.h" #include "ray/pubsub/publisher.h" +#include "ray/util/util.h" namespace ray { namespace gcs { @@ -86,6 +87,7 @@ GcsServer::GcsServer(const ray::gcs::GcsServerConfig &config, RAY_CHECK(status.ok()) << "Failed to put internal config"; this->main_service_.stop(); }; + ray::rpc::StoredConfig stored_config; stored_config.set_config(config_.raylet_config_list); RAY_CHECK_OK(gcs_table_storage_->InternalConfigTable().Put( @@ -138,7 +140,45 @@ RedisClientOptions GcsServer::GetRedisClientOptions() const { void GcsServer::Start() { // Load gcs tables data asynchronously. auto gcs_init_data = std::make_shared(gcs_table_storage_); - gcs_init_data->AsyncLoad([this, gcs_init_data] { DoStart(*gcs_init_data); }); + // Init KV Manager. This needs to be initialized first here so that + // it can be used to retrieve the cluster ID. + InitKVManager(); + gcs_init_data->AsyncLoad([this, gcs_init_data] { + GetOrGenerateClusterId([this, gcs_init_data](ClusterID cluster_id) { + rpc_server_.SetClusterId(cluster_id); + DoStart(*gcs_init_data); + }); + }); +} + +void GcsServer::GetOrGenerateClusterId( + std::function &&continuation) { + static std::string const kTokenNamespace = "cluster"; + kv_manager_->GetInstance().Get( + kTokenNamespace, + kClusterIdKey, + [this, continuation = std::move(continuation)]( + std::optional provided_cluster_id) mutable { + if (!provided_cluster_id.has_value()) { + ClusterID cluster_id = ClusterID::FromRandom(); + RAY_LOG(INFO) << "No existing server cluster ID found. Generating new ID: " + << cluster_id.Hex(); + kv_manager_->GetInstance().Put( + kTokenNamespace, + kClusterIdKey, + cluster_id.Binary(), + false, + [&cluster_id, + continuation = std::move(continuation)](bool added_entry) mutable { + RAY_CHECK(added_entry) << "Failed to persist new cluster ID!"; + continuation(cluster_id); + }); + } else { + ClusterID cluster_id = ClusterID::FromBinary(provided_cluster_id.value()); + RAY_LOG(INFO) << "Found existing server token: " << cluster_id; + continuation(cluster_id); + } + }); } void GcsServer::DoStart(const GcsInitData &gcs_init_data) { @@ -160,8 +200,8 @@ void GcsServer::DoStart(const GcsInitData &gcs_init_data) { // Init gcs health check manager. InitGcsHealthCheckManager(gcs_init_data); - // Init KV Manager - InitKVManager(); + // Init KV service. + InitKVService(); // Init function manager InitFunctionManager(); @@ -208,7 +248,6 @@ void GcsServer::DoStart(const GcsInitData &gcs_init_data) { gcs_actor_manager_->SetUsageStatsClient(usage_stats_client_.get()); gcs_placement_group_manager_->SetUsageStatsClient(usage_stats_client_.get()); gcs_task_manager_->SetUsageStatsClient(usage_stats_client_.get()); - RecordMetrics(); periodical_runner_.RunFnPeriodically( @@ -265,8 +304,10 @@ void GcsServer::Stop() { void GcsServer::InitGcsNodeManager(const GcsInitData &gcs_init_data) { RAY_CHECK(gcs_table_storage_ && gcs_publisher_); - gcs_node_manager_ = std::make_unique( - gcs_publisher_, gcs_table_storage_, raylet_client_pool_); + gcs_node_manager_ = std::make_unique(gcs_publisher_, + gcs_table_storage_, + raylet_client_pool_, + rpc_server_.GetClusterId()); // Initialize by gcs tables data. gcs_node_manager_->Initialize(gcs_init_data); // Register service. @@ -547,6 +588,10 @@ void GcsServer::InitKVManager() { } kv_manager_ = std::make_unique(std::move(instance)); +} + +void GcsServer::InitKVService() { + RAY_CHECK(kv_manager_); kv_service_ = std::make_unique(main_service_, *kv_manager_); // Register service. rpc_server_.RegisterService(*kv_service_, false /* token_auth */); diff --git a/src/ray/gcs/gcs_server/gcs_server.h b/src/ray/gcs/gcs_server/gcs_server.h index 88fe312a2b6fd..b80f1f906f6d8 100644 --- a/src/ray/gcs/gcs_server/gcs_server.h +++ b/src/ray/gcs/gcs_server/gcs_server.h @@ -154,6 +154,9 @@ class GcsServer { /// Initialize KV manager. void InitKVManager(); + /// Initialize KV service. + void InitKVService(); + /// Initialize function manager. void InitFunctionManager(); @@ -182,6 +185,11 @@ class GcsServer { /// Collect stats from each module. void RecordMetrics() const; + /// Get cluster id if persisted, otherwise generate + /// a new one and persist as necessary. + /// Expected to be idempotent while server is up. + void GetOrGenerateClusterId(std::function &&continuation); + /// Print the asio event loop stats for debugging. void PrintAsioStats(); diff --git a/src/ray/gcs/gcs_server/test/gcs_actor_scheduler_mock_test.cc b/src/ray/gcs/gcs_server/test/gcs_actor_scheduler_mock_test.cc index c58311f869276..19a7f2d044915 100644 --- a/src/ray/gcs/gcs_server/test/gcs_actor_scheduler_mock_test.cc +++ b/src/ray/gcs/gcs_server/test/gcs_actor_scheduler_mock_test.cc @@ -38,7 +38,8 @@ class GcsActorSchedulerMockTest : public Test { void SetUp() override { store_client = std::make_shared(); actor_table = std::make_unique(store_client); - gcs_node_manager = std::make_unique(nullptr, nullptr, nullptr); + gcs_node_manager = + std::make_unique(nullptr, nullptr, nullptr, ClusterID::Nil()); raylet_client = std::make_shared(); core_worker_client = std::make_shared(); client_pool = std::make_shared( diff --git a/src/ray/gcs/gcs_server/test/gcs_actor_scheduler_test.cc b/src/ray/gcs/gcs_server/test/gcs_actor_scheduler_test.cc index 681d73bae0105..617a9083f2171 100644 --- a/src/ray/gcs/gcs_server/test/gcs_actor_scheduler_test.cc +++ b/src/ray/gcs/gcs_server/test/gcs_actor_scheduler_test.cc @@ -39,7 +39,7 @@ class GcsActorSchedulerTest : public ::testing::Test { store_client_ = std::make_shared(io_service_); gcs_table_storage_ = std::make_shared(io_service_); gcs_node_manager_ = std::make_shared( - gcs_publisher_, gcs_table_storage_, raylet_client_pool_); + gcs_publisher_, gcs_table_storage_, raylet_client_pool_, ClusterID::Nil()); gcs_actor_table_ = std::make_shared(store_client_); local_node_id_ = NodeID::FromRandom(); diff --git a/src/ray/gcs/gcs_server/test/gcs_node_manager_test.cc b/src/ray/gcs/gcs_server/test/gcs_node_manager_test.cc index 1adab463d1218..0424b209faa25 100644 --- a/src/ray/gcs/gcs_server/test/gcs_node_manager_test.cc +++ b/src/ray/gcs/gcs_server/test/gcs_node_manager_test.cc @@ -42,7 +42,8 @@ class GcsNodeManagerTest : public ::testing::Test { }; TEST_F(GcsNodeManagerTest, TestManagement) { - gcs::GcsNodeManager node_manager(gcs_publisher_, gcs_table_storage_, client_pool_); + gcs::GcsNodeManager node_manager( + gcs_publisher_, gcs_table_storage_, client_pool_, ClusterID::Nil()); // Test Add/Get/Remove functionality. auto node = Mocker::GenNodeInfo(); auto node_id = NodeID::FromBinary(node->node_id()); @@ -55,7 +56,8 @@ TEST_F(GcsNodeManagerTest, TestManagement) { } TEST_F(GcsNodeManagerTest, TestListener) { - gcs::GcsNodeManager node_manager(gcs_publisher_, gcs_table_storage_, client_pool_); + gcs::GcsNodeManager node_manager( + gcs_publisher_, gcs_table_storage_, client_pool_, ClusterID::Nil()); // Test AddNodeAddedListener. int node_count = 1000; std::vector> added_nodes; diff --git a/src/ray/gcs/gcs_server/test/gcs_placement_group_scheduler_test.cc b/src/ray/gcs/gcs_server/test/gcs_placement_group_scheduler_test.cc index bdb514c65f122..30b2cf306b6cb 100644 --- a/src/ray/gcs/gcs_server/test/gcs_placement_group_scheduler_test.cc +++ b/src/ray/gcs/gcs_server/test/gcs_placement_group_scheduler_test.cc @@ -65,7 +65,7 @@ class GcsPlacementGroupSchedulerTest : public ::testing::Test { raylet_client_pool_ = std::make_shared( [this](const rpc::Address &addr) { return raylet_clients_[addr.port()]; }); gcs_node_manager_ = std::make_shared( - gcs_publisher_, gcs_table_storage_, raylet_client_pool_); + gcs_publisher_, gcs_table_storage_, raylet_client_pool_, ClusterID::Nil()); scheduler_ = std::make_shared( io_service_, gcs_table_storage_, diff --git a/src/ray/object_manager/object_manager.cc b/src/ray/object_manager/object_manager.cc index cf054aecfe422..44e6ad7c19144 100644 --- a/src/ray/object_manager/object_manager.cc +++ b/src/ray/object_manager/object_manager.cc @@ -174,7 +174,7 @@ void ObjectManager::StartRpcService() { for (int i = 0; i < config_.rpc_service_threads_number; i++) { rpc_threads_[i] = std::thread(&ObjectManager::RunRpcService, this, i); } - object_manager_server_.RegisterService(object_manager_service_); + object_manager_server_.RegisterService(object_manager_service_, false /* token_auth */); object_manager_server_.Run(); } diff --git a/src/ray/protobuf/gcs_service.proto b/src/ray/protobuf/gcs_service.proto index 7bc382bc08425..8eff5fbf98756 100644 --- a/src/ray/protobuf/gcs_service.proto +++ b/src/ray/protobuf/gcs_service.proto @@ -177,6 +177,13 @@ service ActorInfoGcsService { rpc KillActorViaGcs(KillActorViaGcsRequest) returns (KillActorViaGcsReply); } +message GetClusterIdRequest {} + +message GetClusterIdReply { + GcsStatus status = 1; + bytes cluster_id = 2; +} + message RegisterNodeRequest { // Info of node. GcsNodeInfo node_info = 1; @@ -618,6 +625,8 @@ message GcsStatus { // Service for node info access. service NodeInfoGcsService { + // Register a client to GCS Service. Must be called before any other RPC in GCSClient. + rpc GetClusterId(GetClusterIdRequest) returns (GetClusterIdReply); // Register a node to GCS Service. rpc RegisterNode(RegisterNodeRequest) returns (RegisterNodeReply); // Drain a node from GCS Service. diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index 63c48bed5d4b4..9341d44240b90 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -389,8 +389,8 @@ NodeManager::NodeManager(instrumented_io_context &io_service, RAY_CHECK_OK(store_client_.Connect(config.store_socket_name.c_str())); // Run the node manger rpc server. - node_manager_server_.RegisterService(node_manager_service_); - node_manager_server_.RegisterService(agent_manager_service_); + node_manager_server_.RegisterService(node_manager_service_, false /* token_auth */); + node_manager_server_.RegisterService(agent_manager_service_, false /* token_auth */); if (RayConfig::instance().use_ray_syncer()) { node_manager_server_.RegisterService(ray_syncer_service_); } diff --git a/src/ray/rpc/gcs_server/gcs_rpc_client.h b/src/ray/rpc/gcs_server/gcs_rpc_client.h index 9d20611778159..e021d4287c558 100644 --- a/src/ray/rpc/gcs_server/gcs_rpc_client.h +++ b/src/ray/rpc/gcs_server/gcs_rpc_client.h @@ -314,6 +314,11 @@ class GcsRpcClient { KillActorViaGcs, actor_info_grpc_client_, /*method_timeout_ms*/ -1, ) + /// Register a client to GCS Service. + VOID_GCS_RPC_CLIENT_METHOD(NodeInfoGcsService, + GetClusterId, + node_info_grpc_client_, + /*method_timeout_ms*/ -1, ) /// Register a node to GCS Service. VOID_GCS_RPC_CLIENT_METHOD(NodeInfoGcsService, diff --git a/src/ray/rpc/gcs_server/gcs_rpc_server.h b/src/ray/rpc/gcs_server/gcs_rpc_server.h index 0d3a526d8c915..4a46da66a38e6 100644 --- a/src/ray/rpc/gcs_server/gcs_rpc_server.h +++ b/src/ray/rpc/gcs_server/gcs_rpc_server.h @@ -276,6 +276,10 @@ class NodeInfoGcsServiceHandler { public: virtual ~NodeInfoGcsServiceHandler() = default; + virtual void HandleGetClusterId(rpc::GetClusterIdRequest request, + rpc::GetClusterIdReply *reply, + rpc::SendReplyCallback send_reply_callback) = 0; + virtual void HandleRegisterNode(RegisterNodeRequest request, RegisterNodeReply *reply, SendReplyCallback send_reply_callback) = 0; @@ -314,6 +318,7 @@ class NodeInfoGrpcService : public GrpcService { const std::unique_ptr &cq, std::vector> *server_call_factories, const ClusterID &cluster_id) override { + NODE_INFO_SERVICE_RPC_HANDLER(GetClusterId); NODE_INFO_SERVICE_RPC_HANDLER(RegisterNode); NODE_INFO_SERVICE_RPC_HANDLER(DrainNode); NODE_INFO_SERVICE_RPC_HANDLER(GetAllNodeInfo); diff --git a/src/ray/rpc/grpc_server.cc b/src/ray/rpc/grpc_server.cc index d4d95574b833c..0143bc39ee942 100644 --- a/src/ray/rpc/grpc_server.cc +++ b/src/ray/rpc/grpc_server.cc @@ -159,7 +159,10 @@ void GrpcServer::RegisterService(GrpcService &service, bool token_auth) { services_.emplace_back(service.GetGrpcService()); for (int i = 0; i < num_threads_; i++) { - service.InitServerCallFactories(cqs_[i], &server_call_factories_, cluster_id_.load()); + if (token_auth && cluster_id_.IsNil()) { + RAY_LOG(FATAL) << "Expected cluster ID for token auth!"; + } + service.InitServerCallFactories(cqs_[i], &server_call_factories_, cluster_id_); } } diff --git a/src/ray/rpc/grpc_server.h b/src/ray/rpc/grpc_server.h index 7e7cfa7dbdbaf..89ce79db734ee 100644 --- a/src/ray/rpc/grpc_server.h +++ b/src/ray/rpc/grpc_server.h @@ -118,17 +118,17 @@ class GrpcServer { grpc::Server &GetServer() { return *server_; } const ClusterID GetClusterId() { - RAY_CHECK(!cluster_id_.load().IsNil()) << "Cannot fetch cluster ID before it is set."; - return cluster_id_.load(); + RAY_CHECK(!cluster_id_.IsNil()) << "Cannot fetch cluster ID before it is set."; + return cluster_id_; } void SetClusterId(const ClusterID &cluster_id) { RAY_CHECK(!cluster_id.IsNil()) << "Cannot set cluster ID back to Nil!"; - auto old_id = cluster_id_.exchange(cluster_id); - if (!old_id.IsNil() && old_id != cluster_id) { + if (!cluster_id_.IsNil() && cluster_id_ != cluster_id) { RAY_LOG(FATAL) << "Resetting non-nil cluster ID! Setting to " << cluster_id - << ", but old value is " << old_id; + << ", but old value is " << cluster_id_; } + cluster_id_ = cluster_id; } protected: @@ -148,7 +148,7 @@ class GrpcServer { /// interfaces (0.0.0.0) const bool listen_to_localhost_only_; /// Token representing ID of this cluster. - SafeClusterID cluster_id_; + ClusterID cluster_id_; /// Indicates whether this server has been closed. bool is_closed_; /// The `grpc::Service` objects which should be registered to `ServerBuilder`.