Skip to content

Commit

Permalink
PgWire auth with ApiKey (ydb-platform#8283)
Browse files Browse the repository at this point in the history
  • Loading branch information
shnikd committed Sep 2, 2024
1 parent 9ffc09d commit b493b71
Show file tree
Hide file tree
Showing 5 changed files with 251 additions and 101 deletions.
125 changes: 24 additions & 101 deletions ydb/core/local_pgwire/local_pgwire.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,33 +18,8 @@ class TPgYdbProxy : public TActor<TPgYdbProxy> {
using TBase = TActor<TPgYdbProxy>;

struct TSecurityState {
TString Ticket;
Ydb::Auth::LoginResult LoginResult;
TEvTicketParser::TError Error;
TIntrusiveConstPtr<NACLib::TUserToken> Token;
TString SerializedToken;
};

struct TTokenState {
std::unordered_set<TActorId> Senders;
};

struct TEvPrivate {
enum EEv {
EvTokenReady = EventSpaceBegin(NActors::TEvents::ES_PRIVATE),
EvEnd
};

static_assert(EvEnd < EventSpaceEnd(NActors::TEvents::ES_PRIVATE), "expect EvEnd < EventSpaceEnd(NActors::TEvents::ES_PRIVATE)");

struct TEvTokenReady : TEventLocal<TEvTokenReady, EvTokenReady> {
Ydb::Auth::LoginResult LoginResult;
TActorId Sender;
TString Database;
TString PeerName;

TEvTokenReady() = default;
};
TString Ticket;
};

struct TConnectionState {
Expand All @@ -54,7 +29,6 @@ class TPgYdbProxy : public TActor<TPgYdbProxy> {

std::unordered_map<TActorId, TConnectionState> ConnectionState;
std::unordered_map<TActorId, TSecurityState> SecurityState;
std::unordered_map<TString, TTokenState> TokenState;
uint32_t ConnectionNum = 0;

public:
Expand All @@ -63,85 +37,24 @@ class TPgYdbProxy : public TActor<TPgYdbProxy> {
{
}

void Handle(TEvTicketParser::TEvAuthorizeTicketResult::TPtr& ev) {
auto token = ev->Get()->Ticket;
auto itTokenState = TokenState.find(token);
if (itTokenState == TokenState.end()) {
BLOG_W("Couldn't find token in reply from TicketParser");
return;
}
for (auto sender : itTokenState->second.Senders) {
auto& securityState(SecurityState[sender]);
securityState.Ticket = token;
securityState.Error = ev->Get()->Error;
securityState.Token = ev->Get()->Token;
securityState.SerializedToken = ev->Get()->SerializedToken;
auto authResponse = std::make_unique<NPG::TEvPGEvents::TEvAuthResponse>();
if (ev->Get()->Error) {
authResponse->Error = ev->Get()->Error.Message;
}
Send(sender, authResponse.release());
}
TokenState.erase(itTokenState);
}

void Handle(TEvPrivate::TEvTokenReady::TPtr& ev) {
auto token = ev->Get()->LoginResult.token();
auto itTokenState = TokenState.find(token);
if (itTokenState == TokenState.end()) {
itTokenState = TokenState.insert({token, {}}).first;
}
bool needSend = itTokenState->second.Senders.empty();
itTokenState->second.Senders.insert(ev->Get()->Sender);
if (needSend) {
Send(MakeTicketParserID(), new TEvTicketParser::TEvAuthorizeTicket({
.Database = ev->Get()->Database,
.Ticket = token,
.PeerName = ev->Get()->PeerName,
}));
}
SecurityState[ev->Get()->Sender].LoginResult = std::move(ev->Get()->LoginResult);
}

void Handle(NPG::TEvPGEvents::TEvAuth::TPtr& ev) {
std::unordered_map<TString, TString> clientParams = ev->Get()->InitialMessage->GetClientParams();
BLOG_D("TEvAuth " << ev->Get()->InitialMessage->Dump() << " cookie " << ev->Cookie);
Ydb::Auth::LoginRequest request;
request.set_user(clientParams["user"]);
std::unordered_map<TString, TString> clientParams = ev->Get()->InitialMessage->GetClientParams();
TPgWireAuthData pgWireAuthData;
pgWireAuthData.UserName = clientParams["user"];
if (ev->Get()->PasswordMessage) {
request.set_password(TString(ev->Get()->PasswordMessage->GetPassword()));
pgWireAuthData.Password = TString(ev->Get()->PasswordMessage->GetPassword());
}
TActorSystem* actorSystem = TActivationContext::ActorSystem();
TActorId sender = ev->Sender;
TString database = clientParams["database"];
if (database == "/postgres") {
pgWireAuthData.Sender = ev->Sender;
pgWireAuthData.DatabasePath = clientParams["database"];
if (pgWireAuthData.DatabasePath == "/postgres") {
auto authResponse = std::make_unique<NPG::TEvPGEvents::TEvAuthResponse>();
authResponse->Error = Ydb::StatusIds_StatusCode_Name(Ydb::StatusIds_StatusCode::StatusIds_StatusCode_BAD_REQUEST);
actorSystem->Send(sender, authResponse.release());
Send(pgWireAuthData.Sender, authResponse.release());
}
TString peerName = TStringBuilder() << ev->Get()->Address;
pgWireAuthData.PeerName = TStringBuilder() << ev->Get()->Address;

using TRpcEv = NGRpcService::TGRpcRequestWrapperNoAuth<NGRpcService::TRpcServices::EvLogin, Ydb::Auth::LoginRequest, Ydb::Auth::LoginResponse>;
auto rpcFuture = NRpcService::DoLocalRpc<TRpcEv>(std::move(request), database, {}, actorSystem);
rpcFuture.Subscribe([actorSystem, sender, database, peerName, selfId = SelfId()](const NThreading::TFuture<Ydb::Auth::LoginResponse>& future) {
auto& response = future.GetValueSync();
if (response.operation().status() == Ydb::StatusIds::SUCCESS) {
auto tokenReady = std::make_unique<TEvPrivate::TEvTokenReady>();
response.operation().result().UnpackTo(&(tokenReady->LoginResult));
tokenReady->Sender = sender;
tokenReady->Database = database;
tokenReady->PeerName = peerName;
actorSystem->Send(selfId, tokenReady.release());
} else {
auto authResponse = std::make_unique<NPG::TEvPGEvents::TEvAuthResponse>();
if (response.operation().issues_size() > 0) {
authResponse->Error = response.operation().issues(0).message();
} else {
authResponse->Error = Ydb::StatusIds_StatusCode_Name(response.operation().status());
}
actorSystem->Send(sender, authResponse.release());
}
});
Register(CreateLocalPgWireAuthActor(pgWireAuthData, SelfId()));
}

void Handle(NPG::TEvPGEvents::TEvConnectionOpened::TPtr& ev) {
Expand Down Expand Up @@ -173,7 +86,6 @@ class TPgYdbProxy : public TActor<TPgYdbProxy> {
}
SecurityState.erase(ev->Sender);
ConnectionState.erase(itConnection);
// TODO: cleanup TokenState too
}

void Handle(NPG::TEvPGEvents::TEvQuery::TPtr& ev) {
Expand Down Expand Up @@ -236,6 +148,18 @@ class TPgYdbProxy : public TActor<TPgYdbProxy> {
}
}

void Handle(TEvEvents::TEvAuthResponse::TPtr& ev) {
auto& securityState = SecurityState[ev->Get()->Sender];
auto authResponse = std::make_unique<NPG::TEvPGEvents::TEvAuthResponse>();
if (!ev->Get()->ErrorMessage.empty()) {
authResponse->Error = ev->Get()->ErrorMessage;
} else {
securityState.SerializedToken = ev->Get()->SerializedToken;
securityState.Ticket = ev->Get()->Ticket;
}
Send(ev->Get()->Sender, authResponse.release());
}

STATEFN(StateWork) {
switch (ev->GetTypeRewrite()) {
hFunc(NPG::TEvPGEvents::TEvAuth, Handle);
Expand All @@ -248,8 +172,7 @@ class TPgYdbProxy : public TActor<TPgYdbProxy> {
hFunc(NPG::TEvPGEvents::TEvExecute, Handle);
hFunc(NPG::TEvPGEvents::TEvClose, Handle);
hFunc(NPG::TEvPGEvents::TEvCancelRequest, Handle);
hFunc(TEvPrivate::TEvTokenReady, Handle);
hFunc(TEvTicketParser::TEvAuthorizeTicketResult, Handle);
hFunc(TEvEvents::TEvAuthResponse, Handle);
}
}
};
Expand Down
5 changes: 5 additions & 0 deletions ydb/core/local_pgwire/local_pgwire.h
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
#pragma once

#include "local_pgwire_util.h"
#include <ydb/library/actors/core/actor.h>

namespace NLocalPgWire {

inline NActors::TActorId CreateLocalPgWireProxyId(uint32_t nodeId = 0) { return NActors::TActorId(nodeId, "localpgwire"); }
NActors::IActor* CreateLocalPgWireProxy();

NActors::IActor* CreateLocalPgWireAuthActor(const TPgWireAuthData& pgWireAuthData, const NActors::TActorId& pgYdbProxy);

}
192 changes: 192 additions & 0 deletions ydb/core/local_pgwire/local_pgwire_auth_actor.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
#include "log_impl.h"
#include "local_pgwire.h"
#include "local_pgwire_util.h"

#include <ydb/core/base/path.h>
#include <ydb/core/base/ticket_parser.h>
#include <ydb/core/grpc_services/local_rpc/local_rpc.h>
#include <ydb/core/tx/scheme_cache/scheme_cache.h>

#include <ydb/library/actors/core/actor.h>
#include <ydb/library/actors/core/actor_bootstrapped.h>

#include <ydb/public/api/grpc/ydb_auth_v1.grpc.pb.h>

#include <ydb/services/persqueue_v1/actors/persqueue_utils.h>

namespace NLocalPgWire {

using namespace NActors;
using namespace NKikimr;

class TPgYdbAuthActor : public NActors::TActorBootstrapped<TPgYdbAuthActor> {
using TBase = TActor<TPgYdbAuthActor>;

struct TEvPrivate {
enum EEv {
EvTokenReady = EventSpaceBegin(NActors::TEvents::ES_PRIVATE),
EvAuthFailed,
EvEnd
};

static_assert(EvEnd < EventSpaceEnd(NActors::TEvents::ES_PRIVATE), "expect EvEnd < EventSpaceEnd(NActors::TEvents::ES_PRIVATE)");

struct TEvTokenReady : TEventLocal<TEvTokenReady, EvTokenReady> {
Ydb::Auth::LoginResult LoginResult;

TEvTokenReady() = default;
};

struct TEvAuthFailed : NActors::TEventLocal<TEvAuthFailed, EvAuthFailed> {
TString ErrorMessage;
};
};

TPgWireAuthData PgWireAuthData;
TActorId PgYdbProxy;

TString DatabaseId;
TString FolderId;
TString SerializedToken;
TString Ticket;

public:
TPgYdbAuthActor(const TPgWireAuthData& pgWireAuthData, const TActorId& pgYdbProxy)
: PgWireAuthData(pgWireAuthData)
, PgYdbProxy(pgYdbProxy) {
}

void Bootstrap() {
if (PgWireAuthData.UserName == "__ydb_apikey") {
if (PgWireAuthData.Password.empty()) {
SendResponseAndDie("Invalid password");
}
SendDescribeRequest();
} else {
SendLoginRequest();
}

Become(&TPgYdbAuthActor::StateWork);
}

void Handle(TEvTicketParser::TEvAuthorizeTicketResult::TPtr& ev) {
if (ev->Get()->Error) {
SendResponseAndDie(ev->Get()->Error.Message);
return;
}

SerializedToken = ev->Get()->SerializedToken;
Ticket = ev->Get()->Ticket;

SendResponseAndDie();
}

void Handle(TEvPrivate::TEvTokenReady::TPtr& ev) {
Send(MakeTicketParserID(), new TEvTicketParser::TEvAuthorizeTicket({
.Database = PgWireAuthData.DatabasePath,
.Ticket = ev->Get()->LoginResult.token(),
.PeerName = PgWireAuthData.PeerName,
}));
}

void Handle(TEvPrivate::TEvAuthFailed::TPtr& ev) {
SendResponseAndDie(ev->Get()->ErrorMessage);
}

void Handle(NKikimr::TEvTxProxySchemeCache::TEvNavigateKeySetResult::TPtr& ev) {
const NKikimr::NSchemeCache::TSchemeCacheNavigate* navigate = ev->Get()->Request.Get();
if (navigate->ErrorCount) {
SendResponseAndDie(TStringBuilder() << "Database with path '" << PgWireAuthData.DatabasePath << "' doesn't exists");
return;
}
Y_ABORT_UNLESS(navigate->ResultSet.size() == 1);

const auto& entry = navigate->ResultSet.front();

for (const auto& attr : entry.Attributes) {
if (attr.first == "folderId") FolderId = attr.second;
else if (attr.first == "database_id") DatabaseId = attr.second;
}

SendApiKeyRequest();
}

STATEFN(StateWork) {
switch (ev->GetTypeRewrite()) {
hFunc(TEvPrivate::TEvTokenReady, Handle);
hFunc(TEvTicketParser::TEvAuthorizeTicketResult, Handle);
hFunc(TEvTxProxySchemeCache::TEvNavigateKeySetResult, Handle);
hFunc(TEvPrivate::TEvAuthFailed, Handle);
}
}
private:
void SendLoginRequest() {
Ydb::Auth::LoginRequest request;
request.set_user(PgWireAuthData.UserName);
if (!PgWireAuthData.Password.empty()) {
request.set_password(PgWireAuthData.Password);
}

auto* actorSystem = TActivationContext::ActorSystem();;

using TRpcEv = NGRpcService::TGRpcRequestWrapperNoAuth<NGRpcService::TRpcServices::EvLogin, Ydb::Auth::LoginRequest, Ydb::Auth::LoginResponse>;
auto rpcFuture = NRpcService::DoLocalRpc<TRpcEv>(std::move(request), PgWireAuthData.DatabasePath, {}, actorSystem);
rpcFuture.Subscribe([actorSystem, selfId = SelfId()](const NThreading::TFuture<Ydb::Auth::LoginResponse>& future) {
auto& response = future.GetValueSync();
if (response.operation().status() == Ydb::StatusIds::SUCCESS) {
auto tokenReady = std::make_unique<TEvPrivate::TEvTokenReady>();
response.operation().result().UnpackTo(&(tokenReady->LoginResult));
actorSystem->Send(selfId, tokenReady.release());
} else {
auto authFailedEvent = std::make_unique<TEvPrivate::TEvAuthFailed>();
if (response.operation().issues_size() > 0) {
authFailedEvent->ErrorMessage = response.operation().issues(0).message();
} else {
authFailedEvent->ErrorMessage = Ydb::StatusIds_StatusCode_Name(response.operation().status());
}
actorSystem->Send(selfId, authFailedEvent.release());
}
});
}

void SendApiKeyRequest() {
auto entries = NKikimr::NGRpcProxy::V1::GetTicketParserEntries(DatabaseId, FolderId);

Send(NKikimr::MakeTicketParserID(), new NKikimr::TEvTicketParser::TEvAuthorizeTicket({
.Database = PgWireAuthData.DatabasePath,
.Ticket = "ApiKey " + PgWireAuthData.Password,
.PeerName = PgWireAuthData.PeerName,
.Entries = entries
}));
}

void SendDescribeRequest() {
auto schemeCacheRequest = std::make_unique<NKikimr::NSchemeCache::TSchemeCacheNavigate>();
NKikimr::NSchemeCache::TSchemeCacheNavigate::TEntry entry;
entry.Path = NKikimr::SplitPath(PgWireAuthData.DatabasePath);
entry.Operation = NKikimr::NSchemeCache::TSchemeCacheNavigate::OpPath;
entry.SyncVersion = false;
schemeCacheRequest->ResultSet.emplace_back(entry);
Send(NKikimr::MakeSchemeCacheID(), MakeHolder<NKikimr::TEvTxProxySchemeCache::TEvNavigateKeySet>(schemeCacheRequest.release()));
}

void SendResponseAndDie(const TString& errorMessage = "") {
std::unique_ptr<TEvEvents::TEvAuthResponse> authResponse;
if (!errorMessage.empty()) {
authResponse = std::make_unique<TEvEvents::TEvAuthResponse>(errorMessage, PgWireAuthData.Sender);
} else {
authResponse = std::make_unique<TEvEvents::TEvAuthResponse>(SerializedToken, Ticket, PgWireAuthData.Sender);
}

Send(PgYdbProxy, authResponse.release());

PassAway();
}
};


NActors::IActor* CreateLocalPgWireAuthActor(const TPgWireAuthData& pgWireAuthData, const TActorId& pgYdbProxy) {
return new TPgYdbAuthActor(pgWireAuthData, pgYdbProxy);
}

}
Loading

0 comments on commit b493b71

Please sign in to comment.