Skip to content

Commit

Permalink
squashed changes before rebase
Browse files Browse the repository at this point in the history
[core] Add ClusterID token to GRPC server [1/n] (#36517)

First of a stack of changes to plumb through token exchange between GCS client and server. This adds a ClusterID token that can be passed to a GRPC server, which then initializes each component GRPC service with the token by passing to the ServerCallFactory objects when they are set up. When the factories create ServerCall objects for the GRPC service completion queue, this token is also passed to the ServerCall to check against inbound request metadata. The actual authentication check does not take place in this PR.

Note: This change also minorly cleans up some code in GCS server (changes a string check to use an enum).

Next change (client-side analogue): #36526

[core] Generate GCS server token

Signed-off-by: vitsai <victoria@anyscale.com>

Add client-side logic for setting cluster ID.

Signed-off-by: vitsai <victoria@anyscale.com>

bug fixes

Signed-off-by: vitsai <victoria@anyscale.com>

comments

Signed-off-by: vitsai <victoria@anyscale.com>

bug workaround

Signed-off-by: vitsai <victoria@anyscale.com>

Fix windows build

Signed-off-by: vitsai <victoria@anyscale.com>

fix bug

Signed-off-by: vitsai <victoria@anyscale.com>

remove auth stuff from this pr

Signed-off-by: vitsai <victoria@anyscale.com>

fix mock build

Signed-off-by: vitsai <victoria@anyscale.com>

comments

Signed-off-by: vitsai <victoria@anyscale.com>

remove future

Signed-off-by: vitsai <victoria@anyscale.com>

Remove top-level changes

Signed-off-by: vitsai <victoria@anyscale.com>

comments

Signed-off-by: vitsai <victoria@anyscale.com>

Peel back everything that's not grpc-layer changes

Signed-off-by: vitsai <victoria@anyscale.com>

Change atomic to mutex

Signed-off-by: vitsai <victoria@anyscale.com>

Fix alignment of SafeClusterID

Signed-off-by: vitsai <victoria@anyscale.com>

comments

Signed-off-by: vitsai <victoria@anyscale.com>

Add back everything in GCS server except RPC definition

Signed-off-by: vitsai <victoria@anyscale.com>

fix bug

Signed-off-by: vitsai <victoria@anyscale.com>

comments

Signed-off-by: vitsai <victoria@anyscale.com>

comments

Signed-off-by: vitsai <victoria@anyscale.com>
  • Loading branch information
vitsai committed Jun 28, 2023
1 parent c464026 commit 246cbc5
Show file tree
Hide file tree
Showing 16 changed files with 117 additions and 18 deletions.
2 changes: 1 addition & 1 deletion src/mock/ray/gcs/gcs_server/gcs_node_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ namespace gcs {

class MockGcsNodeManager : public GcsNodeManager {
public:
MockGcsNodeManager() : GcsNodeManager(nullptr, nullptr, nullptr) {}
MockGcsNodeManager() : GcsNodeManager(nullptr, nullptr, nullptr, ClusterID::Nil()) {}
MOCK_METHOD(void,
HandleRegisterNode,
(rpc::RegisterNodeRequest request,
Expand Down
2 changes: 1 addition & 1 deletion src/ray/core_worker/core_worker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_
std::make_unique<rpc::GrpcServer>(WorkerTypeString(options_.worker_type),
assigned_port,
options_.node_ip_address == "127.0.0.1");
core_worker_server_->RegisterService(grpc_service_);
core_worker_server_->RegisterService(grpc_service_, false /* token_auth */);
core_worker_server_->Run();

// Set our own address.
Expand Down
15 changes: 13 additions & 2 deletions src/ray/gcs/gcs_server/gcs_node_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,21 @@ namespace gcs {
GcsNodeManager::GcsNodeManager(
std::shared_ptr<GcsPublisher> gcs_publisher,
std::shared_ptr<gcs::GcsTableStorage> gcs_table_storage,
std::shared_ptr<rpc::NodeManagerClientPool> raylet_client_pool)
std::shared_ptr<rpc::NodeManagerClientPool> raylet_client_pool,
const ClusterID &cluster_id)
: gcs_publisher_(std::move(gcs_publisher)),
gcs_table_storage_(std::move(gcs_table_storage)),
raylet_client_pool_(std::move(raylet_client_pool)) {}
raylet_client_pool_(std::move(raylet_client_pool)),
cluster_id_(cluster_id) {}

// Note: ServerCall will populate the cluster_id.
void GcsNodeManager::HandleGetClusterId(rpc::GetClusterIdRequest request,
rpc::GetClusterIdReply *reply,
rpc::SendReplyCallback send_reply_callback) {
RAY_LOG(DEBUG) << "Registering GCS client!";
reply->set_cluster_id(cluster_id_.Binary());
GCS_RPC_SEND_REPLY(send_reply_callback, reply, Status::OK());
}

void GcsNodeManager::HandleRegisterNode(rpc::RegisterNodeRequest request,
rpc::RegisterNodeReply *reply,
Expand Down
10 changes: 9 additions & 1 deletion src/ray/gcs/gcs_server/gcs_node_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,13 @@ class GcsNodeManager : public rpc::NodeInfoHandler {
/// \param gcs_table_storage GCS table external storage accessor.
explicit GcsNodeManager(std::shared_ptr<GcsPublisher> gcs_publisher,
std::shared_ptr<gcs::GcsTableStorage> gcs_table_storage,
std::shared_ptr<rpc::NodeManagerClientPool> raylet_client_pool);
std::shared_ptr<rpc::NodeManagerClientPool> raylet_client_pool,
const ClusterID &cluster_id);

/// Handle register rpc request come from raylet.
void HandleGetClusterId(rpc::GetClusterIdRequest request,
rpc::GetClusterIdReply *reply,
rpc::SendReplyCallback send_reply_callback) override;

/// Handle register rpc request come from raylet.
void HandleRegisterNode(rpc::RegisterNodeRequest request,
Expand Down Expand Up @@ -167,6 +173,8 @@ class GcsNodeManager : public rpc::NodeInfoHandler {
std::shared_ptr<gcs::GcsTableStorage> gcs_table_storage_;
/// Raylet client pool.
std::shared_ptr<rpc::NodeManagerClientPool> raylet_client_pool_;
/// Cluster ID.
const ClusterID &cluster_id_;

// Debug info.
enum CountType {
Expand Down
55 changes: 50 additions & 5 deletions src/ray/gcs/gcs_server/gcs_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include "ray/gcs/gcs_server/store_client_kv.h"
#include "ray/gcs/store_client/observable_store_client.h"
#include "ray/pubsub/publisher.h"
#include "ray/util/util.h"

namespace ray {
namespace gcs {
Expand Down Expand Up @@ -82,10 +83,16 @@ GcsServer::GcsServer(const ray::gcs::GcsServerConfig &config,
RAY_LOG(FATAL) << "Unexpected storage type: " << storage_type_;
}

// Init KV Manager. This needs to be initialized first here so that
// it can be used to retrieve the cluster ID.
InitKVManager();
CacheAndSetClusterId();

auto on_done = [this](const ray::Status &status) {
RAY_CHECK(status.ok()) << "Failed to put internal config";
this->main_service_.stop();
};

ray::rpc::StoredConfig stored_config;
stored_config.set_config(config_.raylet_config_list);
RAY_CHECK_OK(gcs_table_storage_->InternalConfigTable().Put(
Expand Down Expand Up @@ -141,6 +148,39 @@ void GcsServer::Start() {
gcs_init_data->AsyncLoad([this, gcs_init_data] { DoStart(*gcs_init_data); });
}

void GcsServer::CacheAndSetClusterId() {
static std::string const kTokenNamespace = "cluster";
kv_manager_->GetInstance().Get(
kTokenNamespace, kClusterIdKey, [this](std::optional<std::string> token) mutable {
if (!token.has_value()) {
ClusterID cluster_id = ClusterID::FromRandom();
RAY_LOG(INFO) << "No existing server token found. Generating new token: "
<< cluster_id.Hex();
kv_manager_->GetInstance().Put(kTokenNamespace,
kClusterIdKey,
cluster_id.Binary(),
false,
[this, &cluster_id](bool added_entry) mutable {
RAY_CHECK(added_entry)
<< "Failed to persist new token!";
rpc_server_.SetClusterId(cluster_id);
main_service_.stop();
});
} else {
ClusterID cluster_id = ClusterID::FromBinary(token.value());
RAY_LOG(INFO) << "Found existing server token: " << cluster_id;
rpc_server_.SetClusterId(cluster_id);
main_service_.stop();
}
});
// This will run the async Get and Put inline.
main_service_.run();
main_service_.restart();

// Check the cluster ID exists. There is a RAY_CHECK in here.
RAY_UNUSED(rpc_server_.GetClusterId());
}

void GcsServer::DoStart(const GcsInitData &gcs_init_data) {
// Init cluster resource scheduler.
InitClusterResourceScheduler();
Expand All @@ -160,8 +200,8 @@ void GcsServer::DoStart(const GcsInitData &gcs_init_data) {
// Init gcs health check manager.
InitGcsHealthCheckManager(gcs_init_data);

// Init KV Manager
InitKVManager();
// Init KV service.
InitKVService();

// Init function manager
InitFunctionManager();
Expand Down Expand Up @@ -208,7 +248,6 @@ void GcsServer::DoStart(const GcsInitData &gcs_init_data) {
gcs_actor_manager_->SetUsageStatsClient(usage_stats_client_.get());
gcs_placement_group_manager_->SetUsageStatsClient(usage_stats_client_.get());
gcs_task_manager_->SetUsageStatsClient(usage_stats_client_.get());

RecordMetrics();

periodical_runner_.RunFnPeriodically(
Expand Down Expand Up @@ -265,8 +304,10 @@ void GcsServer::Stop() {

void GcsServer::InitGcsNodeManager(const GcsInitData &gcs_init_data) {
RAY_CHECK(gcs_table_storage_ && gcs_publisher_);
gcs_node_manager_ = std::make_unique<GcsNodeManager>(
gcs_publisher_, gcs_table_storage_, raylet_client_pool_);
gcs_node_manager_ = std::make_unique<GcsNodeManager>(gcs_publisher_,
gcs_table_storage_,
raylet_client_pool_,
rpc_server_.GetClusterId());
// Initialize by gcs tables data.
gcs_node_manager_->Initialize(gcs_init_data);
// Register service.
Expand Down Expand Up @@ -547,6 +588,10 @@ void GcsServer::InitKVManager() {
}

kv_manager_ = std::make_unique<GcsInternalKVManager>(std::move(instance));
}

void GcsServer::InitKVService() {
RAY_CHECK(kv_manager_);
kv_service_ = std::make_unique<rpc::InternalKVGrpcService>(main_service_, *kv_manager_);
// Register service.
rpc_server_.RegisterService(*kv_service_, false /* token_auth */);
Expand Down
10 changes: 10 additions & 0 deletions src/ray/gcs/gcs_server/gcs_server.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

#pragma once

#include <atomic>

#include "ray/common/asio/instrumented_io_context.h"
#include "ray/common/ray_syncer/ray_syncer.h"
#include "ray/common/runtime_env_manager.h"
Expand Down Expand Up @@ -154,6 +156,9 @@ class GcsServer {
/// Initialize KV manager.
void InitKVManager();

/// Initialize KV service.
void InitKVService();

/// Initialize function manager.
void InitFunctionManager();

Expand Down Expand Up @@ -182,6 +187,11 @@ class GcsServer {
/// Collect stats from each module.
void RecordMetrics() const;

/// Get server token if persisted, otherwise generate
/// a new one and persist as necessary.
/// Expected to be idempotent while server is up.
void CacheAndSetClusterId();

/// Print the asio event loop stats for debugging.
void PrintAsioStats();

Expand Down
3 changes: 2 additions & 1 deletion src/ray/gcs/gcs_server/test/gcs_actor_scheduler_mock_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ class GcsActorSchedulerMockTest : public Test {
void SetUp() override {
store_client = std::make_shared<MockStoreClient>();
actor_table = std::make_unique<GcsActorTable>(store_client);
gcs_node_manager = std::make_unique<GcsNodeManager>(nullptr, nullptr, nullptr);
gcs_node_manager =
std::make_unique<GcsNodeManager>(nullptr, nullptr, nullptr, ClusterID::Nil());
raylet_client = std::make_shared<MockRayletClientInterface>();
core_worker_client = std::make_shared<rpc::MockCoreWorkerClientInterface>();
client_pool = std::make_shared<rpc::NodeManagerClientPool>(
Expand Down
2 changes: 1 addition & 1 deletion src/ray/gcs/gcs_server/test/gcs_actor_scheduler_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class GcsActorSchedulerTest : public ::testing::Test {
store_client_ = std::make_shared<gcs::InMemoryStoreClient>(io_service_);
gcs_table_storage_ = std::make_shared<gcs::InMemoryGcsTableStorage>(io_service_);
gcs_node_manager_ = std::make_shared<gcs::GcsNodeManager>(
gcs_publisher_, gcs_table_storage_, raylet_client_pool_);
gcs_publisher_, gcs_table_storage_, raylet_client_pool_, ClusterID::Nil());
gcs_actor_table_ =
std::make_shared<GcsServerMocker::MockedGcsActorTable>(store_client_);
local_node_id_ = NodeID::FromRandom();
Expand Down
6 changes: 4 additions & 2 deletions src/ray/gcs/gcs_server/test/gcs_node_manager_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ class GcsNodeManagerTest : public ::testing::Test {
};

TEST_F(GcsNodeManagerTest, TestManagement) {
gcs::GcsNodeManager node_manager(gcs_publisher_, gcs_table_storage_, client_pool_);
gcs::GcsNodeManager node_manager(
gcs_publisher_, gcs_table_storage_, client_pool_, ClusterID::Nil());
// Test Add/Get/Remove functionality.
auto node = Mocker::GenNodeInfo();
auto node_id = NodeID::FromBinary(node->node_id());
Expand All @@ -55,7 +56,8 @@ TEST_F(GcsNodeManagerTest, TestManagement) {
}

TEST_F(GcsNodeManagerTest, TestListener) {
gcs::GcsNodeManager node_manager(gcs_publisher_, gcs_table_storage_, client_pool_);
gcs::GcsNodeManager node_manager(
gcs_publisher_, gcs_table_storage_, client_pool_, ClusterID::Nil());
// Test AddNodeAddedListener.
int node_count = 1000;
std::vector<std::shared_ptr<rpc::GcsNodeInfo>> added_nodes;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ class GcsPlacementGroupSchedulerTest : public ::testing::Test {
raylet_client_pool_ = std::make_shared<rpc::NodeManagerClientPool>(
[this](const rpc::Address &addr) { return raylet_clients_[addr.port()]; });
gcs_node_manager_ = std::make_shared<gcs::GcsNodeManager>(
gcs_publisher_, gcs_table_storage_, raylet_client_pool_);
gcs_publisher_, gcs_table_storage_, raylet_client_pool_, ClusterID::Nil());
scheduler_ = std::make_shared<GcsServerMocker::MockedGcsPlacementGroupScheduler>(
io_service_,
gcs_table_storage_,
Expand Down
2 changes: 1 addition & 1 deletion src/ray/object_manager/object_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ void ObjectManager::StartRpcService() {
for (int i = 0; i < config_.rpc_service_threads_number; i++) {
rpc_threads_[i] = std::thread(&ObjectManager::RunRpcService, this, i);
}
object_manager_server_.RegisterService(object_manager_service_);
object_manager_server_.RegisterService(object_manager_service_, false /* token_auth */);
object_manager_server_.Run();
}

Expand Down
9 changes: 9 additions & 0 deletions src/ray/protobuf/gcs_service.proto
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,13 @@ service ActorInfoGcsService {
rpc KillActorViaGcs(KillActorViaGcsRequest) returns (KillActorViaGcsReply);
}

message GetClusterIdRequest {}

message GetClusterIdReply {
GcsStatus status = 1;
bytes cluster_id = 2;
}

message RegisterNodeRequest {
// Info of node.
GcsNodeInfo node_info = 1;
Expand Down Expand Up @@ -618,6 +625,8 @@ message GcsStatus {

// Service for node info access.
service NodeInfoGcsService {
// Register a client to GCS Service. Must be called before any other RPC in GCSClient.
rpc GetClusterId(GetClusterIdRequest) returns (GetClusterIdReply);
// Register a node to GCS Service.
rpc RegisterNode(RegisterNodeRequest) returns (RegisterNodeReply);
// Drain a node from GCS Service.
Expand Down
4 changes: 2 additions & 2 deletions src/ray/raylet/node_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -389,8 +389,8 @@ NodeManager::NodeManager(instrumented_io_context &io_service,

RAY_CHECK_OK(store_client_.Connect(config.store_socket_name.c_str()));
// Run the node manger rpc server.
node_manager_server_.RegisterService(node_manager_service_);
node_manager_server_.RegisterService(agent_manager_service_);
node_manager_server_.RegisterService(node_manager_service_, false /* token_auth */);
node_manager_server_.RegisterService(agent_manager_service_, false /* token_auth */);
if (RayConfig::instance().use_ray_syncer()) {
node_manager_server_.RegisterService(ray_syncer_service_);
}
Expand Down
5 changes: 5 additions & 0 deletions src/ray/rpc/gcs_server/gcs_rpc_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,11 @@ class GcsRpcClient {
KillActorViaGcs,
actor_info_grpc_client_,
/*method_timeout_ms*/ -1, )
/// Register a client to GCS Service.
VOID_GCS_RPC_CLIENT_METHOD(NodeInfoGcsService,
GetClusterId,
node_info_grpc_client_,
/*method_timeout_ms*/ -1, )

/// Register a node to GCS Service.
VOID_GCS_RPC_CLIENT_METHOD(NodeInfoGcsService,
Expand Down
5 changes: 5 additions & 0 deletions src/ray/rpc/gcs_server/gcs_rpc_server.h
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,10 @@ class NodeInfoGcsServiceHandler {
public:
virtual ~NodeInfoGcsServiceHandler() = default;

virtual void HandleGetClusterId(rpc::GetClusterIdRequest request,
rpc::GetClusterIdReply *reply,
rpc::SendReplyCallback send_reply_callback) = 0;

virtual void HandleRegisterNode(RegisterNodeRequest request,
RegisterNodeReply *reply,
SendReplyCallback send_reply_callback) = 0;
Expand Down Expand Up @@ -314,6 +318,7 @@ class NodeInfoGrpcService : public GrpcService {
const std::unique_ptr<grpc::ServerCompletionQueue> &cq,
std::vector<std::unique_ptr<ServerCallFactory>> *server_call_factories,
const ClusterID &cluster_id) override {
NODE_INFO_SERVICE_RPC_HANDLER(GetClusterId);
NODE_INFO_SERVICE_RPC_HANDLER(RegisterNode);
NODE_INFO_SERVICE_RPC_HANDLER(DrainNode);
NODE_INFO_SERVICE_RPC_HANDLER(GetAllNodeInfo);
Expand Down
3 changes: 3 additions & 0 deletions src/ray/rpc/grpc_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,9 @@ void GrpcServer::RegisterService(GrpcService &service, bool token_auth) {
services_.emplace_back(service.GetGrpcService());

for (int i = 0; i < num_threads_; i++) {
if (token_auth && cluster_id_.load().IsNil()) {
RAY_LOG(FATAL) << "Expected cluster ID for token auth!";
}
service.InitServerCallFactories(cqs_[i], &server_call_factories_, cluster_id_.load());
}
}
Expand Down

0 comments on commit 246cbc5

Please sign in to comment.