Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(spanner): fine-grained access control #9669

Merged
merged 2 commits into from
Aug 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ using ::testing::EndsWith;
using ::testing::HasSubstr;
using ::testing::IsEmpty;
using ::testing::Not;
using ::testing::StartsWith;

// Constants used to identify the encryption key.
auto constexpr kKeyRing = "spanner-cmek";
Expand Down Expand Up @@ -276,6 +277,26 @@ TEST_F(DatabaseAdminClientTest, DatabaseBasicCRUD) {
AllOf(HasSubstr("Key has type JSON"),
HasSubstr("part of the primary key"))));

// Verify that a new role can be created and returned.
statements.clear();
statements.emplace_back(R"""(
CREATE ROLE test_role
)""");
metadata = client_.UpdateDatabaseDdl(database_.FullName(), statements).get();
if (emulator_) {
EXPECT_THAT(metadata, StatusIs(StatusCode::kInvalidArgument));
} else {
EXPECT_THAT(metadata, IsOk());
std::vector<std::string> roles;
for (auto const& role : client_.ListDatabaseRoles(database_.FullName())) {
EXPECT_THAT(role->name(),
StartsWith(database_.FullName() + "/databaseRoles/"));
roles.push_back(role->name());
}
EXPECT_THAT(roles, AllOf(Contains(EndsWith("/public")),
Contains(EndsWith("/test_role"))));
}

EXPECT_TRUE(DatabaseExists()) << "Database " << database_;
auto drop_status = client_.DropDatabase(database_.FullName());
EXPECT_STATUS_OK(drop_status);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ struct SessionPoolFriendForTest {
CompletionQueue& cq,
std::shared_ptr<SpannerStub> const& stub,
std::map<std::string, std::string> const& labels,
int num_sessions) {
return session_pool->AsyncBatchCreateSessions(cq, stub, labels,
std::string const& role, int num_sessions) {
return session_pool->AsyncBatchCreateSessions(cq, stub, labels, role,
num_sessions);
}

Expand Down Expand Up @@ -73,10 +73,11 @@ TEST_F(SessionPoolIntegrationTest, SessionAsyncCRUD) {

// Make an asynchronous request, but immediately block until the response
// arrives
auto constexpr kSessionCreatorRole = "public";
auto constexpr kNumTestSession = 4;
auto create_response =
spanner_internal::SessionPoolFriendForTest::AsyncBatchCreateSessions(
session_pool, cq, stub, {}, kNumTestSession)
session_pool, cq, stub, {}, kSessionCreatorRole, kNumTestSession)
.get();
ASSERT_STATUS_OK(create_response);
EXPECT_EQ(kNumTestSession, create_response->session_size());
Expand Down
19 changes: 13 additions & 6 deletions google/cloud/spanner/internal/connection_impl_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -142,10 +142,14 @@ MATCHER(HasBeginTransaction, "request has begin TransactionSelector set") {
return arg.transaction().has_begin();
}

MATCHER_P(HasDatabase, database, "Request has expected database") {
MATCHER_P(HasDatabase, database, "request has expected database") {
return arg.database() == database.FullName();
}

MATCHER_P(HasCreatorRole, role, "request has expected creator role") {
return arg.session_template().creator_role() == role;
}

// Matches a `spanner::Transaction` that is bound to a "bad" session.
MATCHER(HasBadSession, "bound to a session that's marked bad") {
return Visit(arg, [&](SessionHolder& session,
Expand Down Expand Up @@ -265,11 +269,11 @@ Options MakeLimitedRetryOptions() {

std::shared_ptr<ConnectionImpl> MakeConnectionImpl(
spanner::Database db,
std::shared_ptr<spanner_testing::MockSpannerStub> mock) {
auto opts = MakeLimitedRetryOptions();
std::shared_ptr<spanner_testing::MockSpannerStub> mock, Options opts = {}) {
// No actual credential needed for unit tests
opts.set<UnifiedCredentialsOption>(MakeInsecureCredentials());
opts = DefaultOptions(std::move(opts));
opts = DefaultOptions(
internal::MergeOptions(std::move(opts), MakeLimitedRetryOptions()));
auto background = internal::MakeBackgroundThreadsFactory(opts)();
std::vector<std::shared_ptr<SpannerStub>> stubs = {std::move(mock)};
return std::make_shared<ConnectionImpl>(std::move(db), std::move(background),
Expand Down Expand Up @@ -2705,8 +2709,10 @@ TEST(ConnectionImplTest, MultipleThreads) {
auto mock = std::make_shared<spanner_testing::MockSpannerStub>();
auto db = spanner::Database("project", "instance", "database");
std::string const session_prefix = "test-session-prefix-";
std::string const role = "TestRole";
std::atomic<int> session_counter(0);
EXPECT_CALL(*mock, BatchCreateSessions(_, HasDatabase(db)))
EXPECT_CALL(*mock, BatchCreateSessions(
_, AllOf(HasDatabase(db), HasCreatorRole(role))))
.WillRepeatedly(
[&session_prefix, &session_counter](
grpc::ClientContext&,
Expand Down Expand Up @@ -2745,7 +2751,8 @@ TEST(ConnectionImplTest, MultipleThreads) {
}
};

auto conn = MakeConnectionImpl(db, mock);
auto conn = MakeConnectionImpl(
db, mock, Options{}.set<spanner::SessionCreatorRoleOption>(role));
std::vector<std::future<void>> tasks;
for (unsigned i = 0; i != thread_count; ++i) {
tasks.push_back(std::async(std::launch::async, runner, i,
Expand Down
37 changes: 26 additions & 11 deletions google/cloud/spanner/internal/session_pool.cc
Original file line number Diff line number Diff line change
Expand Up @@ -243,18 +243,20 @@ Status SessionPool::CreateSessions(
std::vector<CreateCount> const& create_counts,
WaitForSessionAllocation wait) {
Status return_status;
auto const& labels = opts_.get<spanner::SessionPoolLabelsOption>();
auto const& role = opts_.get<spanner::SessionCreatorRoleOption>();
for (auto const& op : create_counts) {
auto const& labels = opts_.get<spanner::SessionPoolLabelsOption>();
switch (wait) {
case WaitForSessionAllocation::kWait: {
auto status = CreateSessionsSync(op.channel, labels, op.session_count);
auto status =
CreateSessionsSync(op.channel, labels, role, op.session_count);
if (!status.ok()) {
return_status = status;
}
break;
}
case WaitForSessionAllocation::kNoWait:
CreateSessionsAsync(op.channel, labels, op.session_count);
CreateSessionsAsync(op.channel, labels, role, op.session_count);
break;
}
}
Expand Down Expand Up @@ -355,11 +357,17 @@ void SessionPool::Release(std::unique_ptr<Session> session) {
// Creates `num_sessions` on `channel` and adds them to the pool.
Status SessionPool::CreateSessionsSync(
std::shared_ptr<Channel> const& channel,
std::map<std::string, std::string> const& labels, int num_sessions) {
std::map<std::string, std::string> const& labels, std::string const& role,
int num_sessions) {
google::spanner::v1::BatchCreateSessionsRequest request;
request.set_database(db_.FullName());
request.mutable_session_template()->mutable_labels()->insert(labels.begin(),
labels.end());
if (!labels.empty()) {
request.mutable_session_template()->mutable_labels()->insert(labels.begin(),
labels.end());
}
if (!role.empty()) {
request.mutable_session_template()->set_creator_role(role);
}
request.set_session_count(std::int32_t{num_sessions});
auto const& stub = channel->stub;
auto response = RetryLoop(
Expand All @@ -375,9 +383,10 @@ Status SessionPool::CreateSessionsSync(

void SessionPool::CreateSessionsAsync(
std::shared_ptr<Channel> const& channel,
std::map<std::string, std::string> const& labels, int num_sessions) {
std::map<std::string, std::string> const& labels, std::string const& role,
int num_sessions) {
std::weak_ptr<SessionPool> pool = shared_from_this();
AsyncBatchCreateSessions(cq_, channel->stub, labels, num_sessions)
AsyncBatchCreateSessions(cq_, channel->stub, labels, role, num_sessions)
.then(
[pool, channel](
future<StatusOr<google::spanner::v1::BatchCreateSessionsResponse>>
Expand Down Expand Up @@ -408,11 +417,17 @@ SessionHolder SessionPool::MakeSessionHolder(std::unique_ptr<Session> session,
future<StatusOr<google::spanner::v1::BatchCreateSessionsResponse>>
SessionPool::AsyncBatchCreateSessions(
CompletionQueue& cq, std::shared_ptr<SpannerStub> const& stub,
std::map<std::string, std::string> const& labels, int num_sessions) {
std::map<std::string, std::string> const& labels, std::string const& role,
int num_sessions) {
google::spanner::v1::BatchCreateSessionsRequest request;
request.set_database(db_.FullName());
request.mutable_session_template()->mutable_labels()->insert(labels.begin(),
labels.end());
if (!labels.empty()) {
request.mutable_session_template()->mutable_labels()->insert(labels.begin(),
labels.end());
}
if (!role.empty()) {
request.mutable_session_template()->set_creator_role(role);
}
request.set_session_count(std::int32_t{num_sessions});
return google::cloud::internal::AsyncRetryLoop(
retry_policy_prototype_->clone(), backoff_policy_prototype_->clone(),
Expand Down
4 changes: 3 additions & 1 deletion google/cloud/spanner/internal/session_pool.h
Original file line number Diff line number Diff line change
Expand Up @@ -144,9 +144,11 @@ class SessionPool : public std::enable_shared_from_this<SessionPool> {
WaitForSessionAllocation wait); // LOCKS_EXCLUDED(mu_)
Status CreateSessionsSync(std::shared_ptr<Channel> const& channel,
std::map<std::string, std::string> const& labels,
std::string const& role,
int num_sessions); // LOCKS_EXCLUDED(mu_)
void CreateSessionsAsync(std::shared_ptr<Channel> const& channel,
std::map<std::string, std::string> const& labels,
std::string const& role,
int num_sessions); // LOCKS_EXCLUDED(mu_)

SessionHolder MakeSessionHolder(std::unique_ptr<Session> session,
Expand All @@ -158,7 +160,7 @@ class SessionPool : public std::enable_shared_from_this<SessionPool> {
AsyncBatchCreateSessions(CompletionQueue& cq,
std::shared_ptr<SpannerStub> const& stub,
std::map<std::string, std::string> const& labels,
int num_sessions);
std::string const& role, int num_sessions);
future<Status> AsyncDeleteSession(CompletionQueue& cq,
std::shared_ptr<SpannerStub> const& stub,
std::string session_name);
Expand Down
Loading