Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PgWire auth with ApiKey #8283

Merged
merged 1 commit into from
Aug 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 24 additions & 101 deletions ydb/core/local_pgwire/local_pgwire.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,33 +18,8 @@ class TPgYdbProxy : public TActor<TPgYdbProxy> {
using TBase = TActor<TPgYdbProxy>;

struct TSecurityState {
TString Ticket;
Ydb::Auth::LoginResult LoginResult;
TEvTicketParser::TError Error;
TIntrusiveConstPtr<NACLib::TUserToken> Token;
TString SerializedToken;
};

struct TTokenState {
std::unordered_set<TActorId> Senders;
};

struct TEvPrivate {
enum EEv {
EvTokenReady = EventSpaceBegin(NActors::TEvents::ES_PRIVATE),
EvEnd
};

static_assert(EvEnd < EventSpaceEnd(NActors::TEvents::ES_PRIVATE), "expect EvEnd < EventSpaceEnd(NActors::TEvents::ES_PRIVATE)");

struct TEvTokenReady : TEventLocal<TEvTokenReady, EvTokenReady> {
Ydb::Auth::LoginResult LoginResult;
TActorId Sender;
TString Database;
TString PeerName;

TEvTokenReady() = default;
};
TString Ticket;
};

struct TConnectionState {
Expand All @@ -54,7 +29,6 @@ class TPgYdbProxy : public TActor<TPgYdbProxy> {

std::unordered_map<TActorId, TConnectionState> ConnectionState;
std::unordered_map<TActorId, TSecurityState> SecurityState;
std::unordered_map<TString, TTokenState> TokenState;
uint32_t ConnectionNum = 0;

public:
Expand All @@ -63,85 +37,24 @@ class TPgYdbProxy : public TActor<TPgYdbProxy> {
{
}

void Handle(TEvTicketParser::TEvAuthorizeTicketResult::TPtr& ev) {
auto token = ev->Get()->Ticket;
auto itTokenState = TokenState.find(token);
if (itTokenState == TokenState.end()) {
BLOG_W("Couldn't find token in reply from TicketParser");
return;
}
for (auto sender : itTokenState->second.Senders) {
auto& securityState(SecurityState[sender]);
securityState.Ticket = token;
securityState.Error = ev->Get()->Error;
securityState.Token = ev->Get()->Token;
securityState.SerializedToken = ev->Get()->SerializedToken;
auto authResponse = std::make_unique<NPG::TEvPGEvents::TEvAuthResponse>();
if (ev->Get()->Error) {
authResponse->Error = ev->Get()->Error.Message;
}
Send(sender, authResponse.release());
}
TokenState.erase(itTokenState);
}

void Handle(TEvPrivate::TEvTokenReady::TPtr& ev) {
auto token = ev->Get()->LoginResult.token();
auto itTokenState = TokenState.find(token);
if (itTokenState == TokenState.end()) {
itTokenState = TokenState.insert({token, {}}).first;
}
bool needSend = itTokenState->second.Senders.empty();
itTokenState->second.Senders.insert(ev->Get()->Sender);
if (needSend) {
Send(MakeTicketParserID(), new TEvTicketParser::TEvAuthorizeTicket({
.Database = ev->Get()->Database,
.Ticket = token,
.PeerName = ev->Get()->PeerName,
}));
}
SecurityState[ev->Get()->Sender].LoginResult = std::move(ev->Get()->LoginResult);
}

void Handle(NPG::TEvPGEvents::TEvAuth::TPtr& ev) {
std::unordered_map<TString, TString> clientParams = ev->Get()->InitialMessage->GetClientParams();
BLOG_D("TEvAuth " << ev->Get()->InitialMessage->Dump() << " cookie " << ev->Cookie);
Ydb::Auth::LoginRequest request;
request.set_user(clientParams["user"]);
std::unordered_map<TString, TString> clientParams = ev->Get()->InitialMessage->GetClientParams();
TPgWireAuthData pgWireAuthData;
pgWireAuthData.UserName = clientParams["user"];
if (ev->Get()->PasswordMessage) {
request.set_password(TString(ev->Get()->PasswordMessage->GetPassword()));
pgWireAuthData.Password = TString(ev->Get()->PasswordMessage->GetPassword());
}
TActorSystem* actorSystem = TActivationContext::ActorSystem();
TActorId sender = ev->Sender;
TString database = clientParams["database"];
if (database == "/postgres") {
pgWireAuthData.Sender = ev->Sender;
pgWireAuthData.DatabasePath = clientParams["database"];
if (pgWireAuthData.DatabasePath == "/postgres") {
auto authResponse = std::make_unique<NPG::TEvPGEvents::TEvAuthResponse>();
authResponse->Error = Ydb::StatusIds_StatusCode_Name(Ydb::StatusIds_StatusCode::StatusIds_StatusCode_BAD_REQUEST);
actorSystem->Send(sender, authResponse.release());
Send(pgWireAuthData.Sender, authResponse.release());
}
TString peerName = TStringBuilder() << ev->Get()->Address;
pgWireAuthData.PeerName = TStringBuilder() << ev->Get()->Address;

using TRpcEv = NGRpcService::TGRpcRequestWrapperNoAuth<NGRpcService::TRpcServices::EvLogin, Ydb::Auth::LoginRequest, Ydb::Auth::LoginResponse>;
auto rpcFuture = NRpcService::DoLocalRpc<TRpcEv>(std::move(request), database, {}, actorSystem);
rpcFuture.Subscribe([actorSystem, sender, database, peerName, selfId = SelfId()](const NThreading::TFuture<Ydb::Auth::LoginResponse>& future) {
auto& response = future.GetValueSync();
if (response.operation().status() == Ydb::StatusIds::SUCCESS) {
auto tokenReady = std::make_unique<TEvPrivate::TEvTokenReady>();
response.operation().result().UnpackTo(&(tokenReady->LoginResult));
tokenReady->Sender = sender;
tokenReady->Database = database;
tokenReady->PeerName = peerName;
actorSystem->Send(selfId, tokenReady.release());
} else {
auto authResponse = std::make_unique<NPG::TEvPGEvents::TEvAuthResponse>();
if (response.operation().issues_size() > 0) {
authResponse->Error = response.operation().issues(0).message();
} else {
authResponse->Error = Ydb::StatusIds_StatusCode_Name(response.operation().status());
}
actorSystem->Send(sender, authResponse.release());
}
});
Register(CreateLocalPgWireAuthActor(pgWireAuthData, SelfId()));
}

void Handle(NPG::TEvPGEvents::TEvConnectionOpened::TPtr& ev) {
Expand Down Expand Up @@ -173,7 +86,6 @@ class TPgYdbProxy : public TActor<TPgYdbProxy> {
}
SecurityState.erase(ev->Sender);
ConnectionState.erase(itConnection);
// TODO: cleanup TokenState too
}

void Handle(NPG::TEvPGEvents::TEvQuery::TPtr& ev) {
Expand Down Expand Up @@ -236,6 +148,18 @@ class TPgYdbProxy : public TActor<TPgYdbProxy> {
}
}

void Handle(TEvEvents::TEvAuthResponse::TPtr& ev) {
auto& securityState = SecurityState[ev->Get()->Sender];
auto authResponse = std::make_unique<NPG::TEvPGEvents::TEvAuthResponse>();
if (!ev->Get()->ErrorMessage.empty()) {
authResponse->Error = ev->Get()->ErrorMessage;
} else {
securityState.SerializedToken = ev->Get()->SerializedToken;
securityState.Ticket = ev->Get()->Ticket;
}
Send(ev->Get()->Sender, authResponse.release());
}

STATEFN(StateWork) {
switch (ev->GetTypeRewrite()) {
hFunc(NPG::TEvPGEvents::TEvAuth, Handle);
Expand All @@ -248,8 +172,7 @@ class TPgYdbProxy : public TActor<TPgYdbProxy> {
hFunc(NPG::TEvPGEvents::TEvExecute, Handle);
hFunc(NPG::TEvPGEvents::TEvClose, Handle);
hFunc(NPG::TEvPGEvents::TEvCancelRequest, Handle);
hFunc(TEvPrivate::TEvTokenReady, Handle);
hFunc(TEvTicketParser::TEvAuthorizeTicketResult, Handle);
hFunc(TEvEvents::TEvAuthResponse, Handle);
}
}
};
Expand Down
5 changes: 5 additions & 0 deletions ydb/core/local_pgwire/local_pgwire.h
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
#pragma once

#include "local_pgwire_util.h"
#include <ydb/library/actors/core/actor.h>

namespace NLocalPgWire {

inline NActors::TActorId CreateLocalPgWireProxyId(uint32_t nodeId = 0) { return NActors::TActorId(nodeId, "localpgwire"); }
NActors::IActor* CreateLocalPgWireProxy();

NActors::IActor* CreateLocalPgWireAuthActor(const TPgWireAuthData& pgWireAuthData, const NActors::TActorId& pgYdbProxy);

}
192 changes: 192 additions & 0 deletions ydb/core/local_pgwire/local_pgwire_auth_actor.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
#include "log_impl.h"
#include "local_pgwire.h"
#include "local_pgwire_util.h"

#include <ydb/core/base/path.h>
#include <ydb/core/base/ticket_parser.h>
#include <ydb/core/grpc_services/local_rpc/local_rpc.h>
#include <ydb/core/tx/scheme_cache/scheme_cache.h>

#include <ydb/library/actors/core/actor.h>
#include <ydb/library/actors/core/actor_bootstrapped.h>

#include <ydb/public/api/grpc/ydb_auth_v1.grpc.pb.h>

#include <ydb/services/persqueue_v1/actors/persqueue_utils.h>

namespace NLocalPgWire {

using namespace NActors;
using namespace NKikimr;

class TPgYdbAuthActor : public NActors::TActorBootstrapped<TPgYdbAuthActor> {
using TBase = TActor<TPgYdbAuthActor>;

struct TEvPrivate {
enum EEv {
EvTokenReady = EventSpaceBegin(NActors::TEvents::ES_PRIVATE),
EvAuthFailed,
EvEnd
};

static_assert(EvEnd < EventSpaceEnd(NActors::TEvents::ES_PRIVATE), "expect EvEnd < EventSpaceEnd(NActors::TEvents::ES_PRIVATE)");

struct TEvTokenReady : TEventLocal<TEvTokenReady, EvTokenReady> {
Ydb::Auth::LoginResult LoginResult;

TEvTokenReady() = default;
};

struct TEvAuthFailed : NActors::TEventLocal<TEvAuthFailed, EvAuthFailed> {
TString ErrorMessage;
};
};

TPgWireAuthData PgWireAuthData;
TActorId PgYdbProxy;

TString DatabaseId;
TString FolderId;
TString SerializedToken;
TString Ticket;

public:
TPgYdbAuthActor(const TPgWireAuthData& pgWireAuthData, const TActorId& pgYdbProxy)
: PgWireAuthData(pgWireAuthData)
, PgYdbProxy(pgYdbProxy) {
}

void Bootstrap() {
if (PgWireAuthData.UserName == "__ydb_apikey") {
if (PgWireAuthData.Password.empty()) {
SendResponseAndDie("Invalid password");
}
SendDescribeRequest();
} else {
SendLoginRequest();
}

Become(&TPgYdbAuthActor::StateWork);
}

void Handle(TEvTicketParser::TEvAuthorizeTicketResult::TPtr& ev) {
if (ev->Get()->Error) {
SendResponseAndDie(ev->Get()->Error.Message);
return;
}

SerializedToken = ev->Get()->SerializedToken;
Ticket = ev->Get()->Ticket;

SendResponseAndDie();
}

void Handle(TEvPrivate::TEvTokenReady::TPtr& ev) {
Send(MakeTicketParserID(), new TEvTicketParser::TEvAuthorizeTicket({
.Database = PgWireAuthData.DatabasePath,
.Ticket = ev->Get()->LoginResult.token(),
.PeerName = PgWireAuthData.PeerName,
}));
}

void Handle(TEvPrivate::TEvAuthFailed::TPtr& ev) {
SendResponseAndDie(ev->Get()->ErrorMessage);
}

void Handle(NKikimr::TEvTxProxySchemeCache::TEvNavigateKeySetResult::TPtr& ev) {
const NKikimr::NSchemeCache::TSchemeCacheNavigate* navigate = ev->Get()->Request.Get();
if (navigate->ErrorCount) {
SendResponseAndDie(TStringBuilder() << "Database with path '" << PgWireAuthData.DatabasePath << "' doesn't exists");
return;
}
Y_ABORT_UNLESS(navigate->ResultSet.size() == 1);

const auto& entry = navigate->ResultSet.front();

for (const auto& attr : entry.Attributes) {
if (attr.first == "folderId") FolderId = attr.second;
else if (attr.first == "database_id") DatabaseId = attr.second;
}

SendApiKeyRequest();
}

STATEFN(StateWork) {
switch (ev->GetTypeRewrite()) {
hFunc(TEvPrivate::TEvTokenReady, Handle);
hFunc(TEvTicketParser::TEvAuthorizeTicketResult, Handle);
hFunc(TEvTxProxySchemeCache::TEvNavigateKeySetResult, Handle);
hFunc(TEvPrivate::TEvAuthFailed, Handle);
}
}
private:
void SendLoginRequest() {
Ydb::Auth::LoginRequest request;
request.set_user(PgWireAuthData.UserName);
if (!PgWireAuthData.Password.empty()) {
request.set_password(PgWireAuthData.Password);
}

auto* actorSystem = TActivationContext::ActorSystem();;

using TRpcEv = NGRpcService::TGRpcRequestWrapperNoAuth<NGRpcService::TRpcServices::EvLogin, Ydb::Auth::LoginRequest, Ydb::Auth::LoginResponse>;
auto rpcFuture = NRpcService::DoLocalRpc<TRpcEv>(std::move(request), PgWireAuthData.DatabasePath, {}, actorSystem);
rpcFuture.Subscribe([actorSystem, selfId = SelfId()](const NThreading::TFuture<Ydb::Auth::LoginResponse>& future) {
auto& response = future.GetValueSync();
if (response.operation().status() == Ydb::StatusIds::SUCCESS) {
auto tokenReady = std::make_unique<TEvPrivate::TEvTokenReady>();
response.operation().result().UnpackTo(&(tokenReady->LoginResult));
actorSystem->Send(selfId, tokenReady.release());
} else {
auto authFailedEvent = std::make_unique<TEvPrivate::TEvAuthFailed>();
if (response.operation().issues_size() > 0) {
authFailedEvent->ErrorMessage = response.operation().issues(0).message();
} else {
authFailedEvent->ErrorMessage = Ydb::StatusIds_StatusCode_Name(response.operation().status());
}
actorSystem->Send(selfId, authFailedEvent.release());
}
});
}

void SendApiKeyRequest() {
auto entries = NKikimr::NGRpcProxy::V1::GetTicketParserEntries(DatabaseId, FolderId);

Send(NKikimr::MakeTicketParserID(), new NKikimr::TEvTicketParser::TEvAuthorizeTicket({
.Database = PgWireAuthData.DatabasePath,
.Ticket = "ApiKey " + PgWireAuthData.Password,
.PeerName = PgWireAuthData.PeerName,
.Entries = entries
}));
}

void SendDescribeRequest() {
auto schemeCacheRequest = std::make_unique<NKikimr::NSchemeCache::TSchemeCacheNavigate>();
NKikimr::NSchemeCache::TSchemeCacheNavigate::TEntry entry;
entry.Path = NKikimr::SplitPath(PgWireAuthData.DatabasePath);
entry.Operation = NKikimr::NSchemeCache::TSchemeCacheNavigate::OpPath;
entry.SyncVersion = false;
schemeCacheRequest->ResultSet.emplace_back(entry);
Send(NKikimr::MakeSchemeCacheID(), MakeHolder<NKikimr::TEvTxProxySchemeCache::TEvNavigateKeySet>(schemeCacheRequest.release()));
}

void SendResponseAndDie(const TString& errorMessage = "") {
std::unique_ptr<TEvEvents::TEvAuthResponse> authResponse;
if (!errorMessage.empty()) {
authResponse = std::make_unique<TEvEvents::TEvAuthResponse>(errorMessage, PgWireAuthData.Sender);
} else {
authResponse = std::make_unique<TEvEvents::TEvAuthResponse>(SerializedToken, Ticket, PgWireAuthData.Sender);
}

Send(PgYdbProxy, authResponse.release());

PassAway();
}
};


NActors::IActor* CreateLocalPgWireAuthActor(const TPgWireAuthData& pgWireAuthData, const TActorId& pgYdbProxy) {
return new TPgYdbAuthActor(pgWireAuthData, pgYdbProxy);
}

}
Loading
Loading