Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[core] Add ClusterID token to GCS server [3/n] #36535

Merged
merged 5 commits into from
Jun 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/mock/ray/gcs/gcs_server/gcs_node_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ namespace gcs {

class MockGcsNodeManager : public GcsNodeManager {
public:
MockGcsNodeManager() : GcsNodeManager(nullptr, nullptr, nullptr) {}
MockGcsNodeManager() : GcsNodeManager(nullptr, nullptr, nullptr, ClusterID::Nil()) {}
MOCK_METHOD(void,
HandleRegisterNode,
(rpc::RegisterNodeRequest request,
Expand Down
2 changes: 1 addition & 1 deletion src/ray/core_worker/core_worker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_
std::make_unique<rpc::GrpcServer>(WorkerTypeString(options_.worker_type),
assigned_port,
options_.node_ip_address == "127.0.0.1");
core_worker_server_->RegisterService(grpc_service_);
core_worker_server_->RegisterService(grpc_service_, false /* token_auth */);
core_worker_server_->Run();

// Set our own address.
Expand Down
15 changes: 13 additions & 2 deletions src/ray/gcs/gcs_server/gcs_node_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,21 @@ namespace gcs {
/// Construct a GcsNodeManager.
///
/// \param gcs_publisher GCS message publisher.
/// \param gcs_table_storage GCS table external storage accessor.
/// \param raylet_client_pool Pool of RPC clients to raylets (node managers).
/// \param cluster_id Cluster ID shared with clients when they connect.
GcsNodeManager::GcsNodeManager(
    std::shared_ptr<GcsPublisher> gcs_publisher,
    std::shared_ptr<gcs::GcsTableStorage> gcs_table_storage,
    std::shared_ptr<rpc::NodeManagerClientPool> raylet_client_pool,
    const ClusterID &cluster_id)
    : gcs_publisher_(std::move(gcs_publisher)),
      gcs_table_storage_(std::move(gcs_table_storage)),
      raylet_client_pool_(std::move(raylet_client_pool)),
      cluster_id_(cluster_id) {}

/// Handle a GetClusterId RPC: reply with this cluster's ID so the client can
/// attach it as an auth token to subsequent requests.
/// Note: ServerCall will populate the cluster_id.
void GcsNodeManager::HandleGetClusterId(rpc::GetClusterIdRequest request,
                                        rpc::GetClusterIdReply *reply,
                                        rpc::SendReplyCallback send_reply_callback) {
  // Previous log text ("Registering GCS client!") misdescribed this handler.
  RAY_LOG(DEBUG) << "Replying to GetClusterId with cluster ID " << cluster_id_.Hex();
  reply->set_cluster_id(cluster_id_.Binary());
  GCS_RPC_SEND_REPLY(send_reply_callback, reply, Status::OK());
}

void GcsNodeManager::HandleRegisterNode(rpc::RegisterNodeRequest request,
rpc::RegisterNodeReply *reply,
Expand Down
10 changes: 9 additions & 1 deletion src/ray/gcs/gcs_server/gcs_node_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,13 @@ class GcsNodeManager : public rpc::NodeInfoHandler {
/// \param gcs_table_storage GCS table external storage accessor.
explicit GcsNodeManager(std::shared_ptr<GcsPublisher> gcs_publisher,
std::shared_ptr<gcs::GcsTableStorage> gcs_table_storage,
std::shared_ptr<rpc::NodeManagerClientPool> raylet_client_pool);
std::shared_ptr<rpc::NodeManagerClientPool> raylet_client_pool,
const ClusterID &cluster_id);

/// Handle a GetClusterId rpc request: reply with this cluster's ID.
void HandleGetClusterId(rpc::GetClusterIdRequest request,
rpc::GetClusterIdReply *reply,
rpc::SendReplyCallback send_reply_callback) override;

/// Handle register rpc request come from raylet.
void HandleRegisterNode(rpc::RegisterNodeRequest request,
Expand Down Expand Up @@ -167,6 +173,8 @@ class GcsNodeManager : public rpc::NodeInfoHandler {
std::shared_ptr<gcs::GcsTableStorage> gcs_table_storage_;
/// Raylet client pool.
std::shared_ptr<rpc::NodeManagerClientPool> raylet_client_pool_;
/// Cluster ID to be shared with clients when connecting.
const ClusterID cluster_id_;

// Debug info.
enum CountType {
Expand Down
57 changes: 51 additions & 6 deletions src/ray/gcs/gcs_server/gcs_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include "ray/gcs/gcs_server/store_client_kv.h"
#include "ray/gcs/store_client/observable_store_client.h"
#include "ray/pubsub/publisher.h"
#include "ray/util/util.h"

namespace ray {
namespace gcs {
Expand Down Expand Up @@ -86,6 +87,7 @@ GcsServer::GcsServer(const ray::gcs::GcsServerConfig &config,
RAY_CHECK(status.ok()) << "Failed to put internal config";
this->main_service_.stop();
};

ray::rpc::StoredConfig stored_config;
stored_config.set_config(config_.raylet_config_list);
RAY_CHECK_OK(gcs_table_storage_->InternalConfigTable().Put(
Expand Down Expand Up @@ -138,7 +140,45 @@ RedisClientOptions GcsServer::GetRedisClientOptions() const {
/// Start the GCS server: initialize the KV manager, asynchronously load GCS
/// table data, resolve (or generate) the cluster ID, then complete startup
/// via DoStart().
void GcsServer::Start() {
  // Init KV Manager. This needs to be initialized first here so that
  // it can be used to retrieve the cluster ID.
  InitKVManager();
  // Load gcs tables data asynchronously.
  auto gcs_init_data = std::make_shared<GcsInitData>(gcs_table_storage_);
  gcs_init_data->AsyncLoad([this, gcs_init_data] {
    GetOrGenerateClusterId([this, gcs_init_data](ClusterID cluster_id) {
      rpc_server_.SetClusterId(cluster_id);
      DoStart(*gcs_init_data);
    });
  });
}

/// Fetch the persisted cluster ID from internal KV storage, or generate a new
/// one and persist it if none exists, then invoke `continuation` with the
/// resulting ID. Expected to be idempotent while the server is up.
///
/// \param continuation Callback invoked (possibly asynchronously) with the
///        resolved cluster ID.
void GcsServer::GetOrGenerateClusterId(
    std::function<void(ClusterID cluster_id)> &&continuation) {
  static std::string const kTokenNamespace = "cluster";
  kv_manager_->GetInstance().Get(
      kTokenNamespace,
      kClusterIdKey,
      [this, continuation = std::move(continuation)](
          std::optional<std::string> provided_cluster_id) mutable {
        if (!provided_cluster_id.has_value()) {
          ClusterID cluster_id = ClusterID::FromRandom();
          RAY_LOG(INFO) << "No existing server cluster ID found. Generating new ID: "
                        << cluster_id.Hex();
          // Capture cluster_id BY VALUE: the Put callback may run after this
          // lambda frame is destroyed, so a reference capture would dangle.
          kv_manager_->GetInstance().Put(
              kTokenNamespace,
              kClusterIdKey,
              cluster_id.Binary(),
              false,
              [cluster_id,
               continuation = std::move(continuation)](bool added_entry) mutable {
                RAY_CHECK(added_entry) << "Failed to persist new cluster ID!";
                continuation(cluster_id);
              });
        } else {
          ClusterID cluster_id = ClusterID::FromBinary(provided_cluster_id.value());
          RAY_LOG(INFO) << "Found existing server cluster ID: " << cluster_id;
          continuation(cluster_id);
        }
      });
}

void GcsServer::DoStart(const GcsInitData &gcs_init_data) {
Expand All @@ -160,8 +200,8 @@ void GcsServer::DoStart(const GcsInitData &gcs_init_data) {
// Init gcs health check manager.
InitGcsHealthCheckManager(gcs_init_data);

// Init KV Manager
InitKVManager();
// Init KV service.
InitKVService();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we remove this? It's already called in GcsServer::Start

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One is KV Manager and one is KV service


// Init function manager
InitFunctionManager();
Expand Down Expand Up @@ -208,7 +248,6 @@ void GcsServer::DoStart(const GcsInitData &gcs_init_data) {
gcs_actor_manager_->SetUsageStatsClient(usage_stats_client_.get());
gcs_placement_group_manager_->SetUsageStatsClient(usage_stats_client_.get());
gcs_task_manager_->SetUsageStatsClient(usage_stats_client_.get());

RecordMetrics();

periodical_runner_.RunFnPeriodically(
Expand Down Expand Up @@ -265,8 +304,10 @@ void GcsServer::Stop() {

void GcsServer::InitGcsNodeManager(const GcsInitData &gcs_init_data) {
RAY_CHECK(gcs_table_storage_ && gcs_publisher_);
gcs_node_manager_ = std::make_unique<GcsNodeManager>(
gcs_publisher_, gcs_table_storage_, raylet_client_pool_);
gcs_node_manager_ = std::make_unique<GcsNodeManager>(gcs_publisher_,
gcs_table_storage_,
raylet_client_pool_,
rpc_server_.GetClusterId());
// Initialize by gcs tables data.
gcs_node_manager_->Initialize(gcs_init_data);
// Register service.
Expand Down Expand Up @@ -547,6 +588,10 @@ void GcsServer::InitKVManager() {
}

kv_manager_ = std::make_unique<GcsInternalKVManager>(std::move(instance));
}

void GcsServer::InitKVService() {
RAY_CHECK(kv_manager_);
kv_service_ = std::make_unique<rpc::InternalKVGrpcService>(main_service_, *kv_manager_);
// Register service.
rpc_server_.RegisterService(*kv_service_, false /* token_auth */);
Expand Down
8 changes: 8 additions & 0 deletions src/ray/gcs/gcs_server/gcs_server.h
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,9 @@ class GcsServer {
/// Initialize KV manager.
void InitKVManager();

/// Initialize KV service.
void InitKVService();

/// Initialize function manager.
void InitFunctionManager();

Expand Down Expand Up @@ -182,6 +185,11 @@ class GcsServer {
/// Collect stats from each module.
void RecordMetrics() const;

/// Get cluster id if persisted, otherwise generate
/// a new one and persist as necessary.
/// Expected to be idempotent while server is up.
void GetOrGenerateClusterId(std::function<void(ClusterID cluster_id)> &&continuation);

/// Print the asio event loop stats for debugging.
void PrintAsioStats();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ class GcsActorSchedulerMockTest : public Test {
void SetUp() override {
store_client = std::make_shared<MockStoreClient>();
actor_table = std::make_unique<GcsActorTable>(store_client);
gcs_node_manager = std::make_unique<GcsNodeManager>(nullptr, nullptr, nullptr);
gcs_node_manager =
std::make_unique<GcsNodeManager>(nullptr, nullptr, nullptr, ClusterID::Nil());
raylet_client = std::make_shared<MockRayletClientInterface>();
core_worker_client = std::make_shared<rpc::MockCoreWorkerClientInterface>();
client_pool = std::make_shared<rpc::NodeManagerClientPool>(
Expand Down
2 changes: 1 addition & 1 deletion src/ray/gcs/gcs_server/test/gcs_actor_scheduler_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class GcsActorSchedulerTest : public ::testing::Test {
store_client_ = std::make_shared<gcs::InMemoryStoreClient>(io_service_);
gcs_table_storage_ = std::make_shared<gcs::InMemoryGcsTableStorage>(io_service_);
gcs_node_manager_ = std::make_shared<gcs::GcsNodeManager>(
gcs_publisher_, gcs_table_storage_, raylet_client_pool_);
gcs_publisher_, gcs_table_storage_, raylet_client_pool_, ClusterID::Nil());
gcs_actor_table_ =
std::make_shared<GcsServerMocker::MockedGcsActorTable>(store_client_);
local_node_id_ = NodeID::FromRandom();
Expand Down
6 changes: 4 additions & 2 deletions src/ray/gcs/gcs_server/test/gcs_node_manager_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ class GcsNodeManagerTest : public ::testing::Test {
};

TEST_F(GcsNodeManagerTest, TestManagement) {
gcs::GcsNodeManager node_manager(gcs_publisher_, gcs_table_storage_, client_pool_);
gcs::GcsNodeManager node_manager(
gcs_publisher_, gcs_table_storage_, client_pool_, ClusterID::Nil());
// Test Add/Get/Remove functionality.
auto node = Mocker::GenNodeInfo();
auto node_id = NodeID::FromBinary(node->node_id());
Expand All @@ -55,7 +56,8 @@ TEST_F(GcsNodeManagerTest, TestManagement) {
}

TEST_F(GcsNodeManagerTest, TestListener) {
gcs::GcsNodeManager node_manager(gcs_publisher_, gcs_table_storage_, client_pool_);
gcs::GcsNodeManager node_manager(
gcs_publisher_, gcs_table_storage_, client_pool_, ClusterID::Nil());
// Test AddNodeAddedListener.
int node_count = 1000;
std::vector<std::shared_ptr<rpc::GcsNodeInfo>> added_nodes;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ class GcsPlacementGroupSchedulerTest : public ::testing::Test {
raylet_client_pool_ = std::make_shared<rpc::NodeManagerClientPool>(
[this](const rpc::Address &addr) { return raylet_clients_[addr.port()]; });
gcs_node_manager_ = std::make_shared<gcs::GcsNodeManager>(
gcs_publisher_, gcs_table_storage_, raylet_client_pool_);
gcs_publisher_, gcs_table_storage_, raylet_client_pool_, ClusterID::Nil());
scheduler_ = std::make_shared<GcsServerMocker::MockedGcsPlacementGroupScheduler>(
io_service_,
gcs_table_storage_,
Expand Down
2 changes: 1 addition & 1 deletion src/ray/object_manager/object_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ void ObjectManager::StartRpcService() {
for (int i = 0; i < config_.rpc_service_threads_number; i++) {
rpc_threads_[i] = std::thread(&ObjectManager::RunRpcService, this, i);
}
object_manager_server_.RegisterService(object_manager_service_);
object_manager_server_.RegisterService(object_manager_service_, false /* token_auth */);
object_manager_server_.Run();
}

Expand Down
9 changes: 9 additions & 0 deletions src/ray/protobuf/gcs_service.proto
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,13 @@ service ActorInfoGcsService {
rpc KillActorViaGcs(KillActorViaGcsRequest) returns (KillActorViaGcsReply);
}

message GetClusterIdRequest {}

message GetClusterIdReply {
  // Kept for consistency with other GCS RPC replies, even though callers do
  // not currently inspect it.
  GcsStatus status = 1;
  // Binary representation of the cluster's unique ID.
  bytes cluster_id = 2;
}

message RegisterNodeRequest {
// Info of node.
GcsNodeInfo node_info = 1;
Expand Down Expand Up @@ -618,6 +625,8 @@ message GcsStatus {

// Service for node info access.
service NodeInfoGcsService {
// Register a client to GCS Service. Must be called before any other RPC in GCSClient.
rpc GetClusterId(GetClusterIdRequest) returns (GetClusterIdReply);
// Register a node to GCS Service.
rpc RegisterNode(RegisterNodeRequest) returns (RegisterNodeReply);
// Drain a node from GCS Service.
Expand Down
4 changes: 2 additions & 2 deletions src/ray/raylet/node_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -389,8 +389,8 @@ NodeManager::NodeManager(instrumented_io_context &io_service,

RAY_CHECK_OK(store_client_.Connect(config.store_socket_name.c_str()));
// Run the node manger rpc server.
node_manager_server_.RegisterService(node_manager_service_);
node_manager_server_.RegisterService(agent_manager_service_);
node_manager_server_.RegisterService(node_manager_service_, false /* token_auth */);
node_manager_server_.RegisterService(agent_manager_service_, false /* token_auth */);
if (RayConfig::instance().use_ray_syncer()) {
node_manager_server_.RegisterService(ray_syncer_service_);
}
Expand Down
5 changes: 5 additions & 0 deletions src/ray/rpc/gcs_server/gcs_rpc_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,11 @@ class GcsRpcClient {
KillActorViaGcs,
actor_info_grpc_client_,
/*method_timeout_ms*/ -1, )
/// Register a client to GCS Service.
VOID_GCS_RPC_CLIENT_METHOD(NodeInfoGcsService,
GetClusterId,
node_info_grpc_client_,
/*method_timeout_ms*/ -1, )

/// Register a node to GCS Service.
VOID_GCS_RPC_CLIENT_METHOD(NodeInfoGcsService,
Expand Down
5 changes: 5 additions & 0 deletions src/ray/rpc/gcs_server/gcs_rpc_server.h
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,10 @@ class NodeInfoGcsServiceHandler {
public:
virtual ~NodeInfoGcsServiceHandler() = default;

virtual void HandleGetClusterId(rpc::GetClusterIdRequest request,
rpc::GetClusterIdReply *reply,
rpc::SendReplyCallback send_reply_callback) = 0;

virtual void HandleRegisterNode(RegisterNodeRequest request,
RegisterNodeReply *reply,
SendReplyCallback send_reply_callback) = 0;
Expand Down Expand Up @@ -314,6 +318,7 @@ class NodeInfoGrpcService : public GrpcService {
const std::unique_ptr<grpc::ServerCompletionQueue> &cq,
std::vector<std::unique_ptr<ServerCallFactory>> *server_call_factories,
const ClusterID &cluster_id) override {
NODE_INFO_SERVICE_RPC_HANDLER(GetClusterId);
NODE_INFO_SERVICE_RPC_HANDLER(RegisterNode);
NODE_INFO_SERVICE_RPC_HANDLER(DrainNode);
NODE_INFO_SERVICE_RPC_HANDLER(GetAllNodeInfo);
Expand Down
5 changes: 4 additions & 1 deletion src/ray/rpc/grpc_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,10 @@ void GrpcServer::RegisterService(GrpcService &service, bool token_auth) {
services_.emplace_back(service.GetGrpcService());

for (int i = 0; i < num_threads_; i++) {
service.InitServerCallFactories(cqs_[i], &server_call_factories_, cluster_id_.load());
if (token_auth && cluster_id_.IsNil()) {
RAY_LOG(FATAL) << "Expected cluster ID for token auth!";
}
service.InitServerCallFactories(cqs_[i], &server_call_factories_, cluster_id_);
}
}

Expand Down
12 changes: 6 additions & 6 deletions src/ray/rpc/grpc_server.h
Original file line number Diff line number Diff line change
Expand Up @@ -118,17 +118,17 @@ class GrpcServer {
grpc::Server &GetServer() { return *server_; }

/// Get this server's cluster ID. Fatal if called before the ID is set.
const ClusterID GetClusterId() {
  RAY_CHECK(!cluster_id_.IsNil()) << "Cannot fetch cluster ID before it is set.";
  return cluster_id_;
}

/// Set this server's cluster ID. The ID may only transition from Nil to a
/// non-Nil value; setting it to Nil, or resetting an already-set ID to a
/// different value, is fatal.
void SetClusterId(const ClusterID &cluster_id) {
  RAY_CHECK(!cluster_id.IsNil()) << "Cannot set cluster ID back to Nil!";
  if (!cluster_id_.IsNil() && cluster_id_ != cluster_id) {
    RAY_LOG(FATAL) << "Resetting non-nil cluster ID! Setting to " << cluster_id
                   << ", but old value is " << cluster_id_;
  }
  cluster_id_ = cluster_id;
}

protected:
Expand All @@ -148,7 +148,7 @@ class GrpcServer {
/// interfaces (0.0.0.0)
const bool listen_to_localhost_only_;
/// Token representing ID of this cluster.
SafeClusterID cluster_id_;
ClusterID cluster_id_;
/// Indicates whether this server has been closed.
bool is_closed_;
/// The `grpc::Service` objects which should be registered to `ServerBuilder`.
Expand Down