diff --git a/ydb/core/local_pgwire/local_pgwire.cpp b/ydb/core/local_pgwire/local_pgwire.cpp index f71cd38d07d5..7dc6f855ce00 100644 --- a/ydb/core/local_pgwire/local_pgwire.cpp +++ b/ydb/core/local_pgwire/local_pgwire.cpp @@ -18,33 +18,8 @@ class TPgYdbProxy : public TActor { using TBase = TActor; struct TSecurityState { - TString Ticket; - Ydb::Auth::LoginResult LoginResult; - TEvTicketParser::TError Error; - TIntrusiveConstPtr Token; TString SerializedToken; - }; - - struct TTokenState { - std::unordered_set Senders; - }; - - struct TEvPrivate { - enum EEv { - EvTokenReady = EventSpaceBegin(NActors::TEvents::ES_PRIVATE), - EvEnd - }; - - static_assert(EvEnd < EventSpaceEnd(NActors::TEvents::ES_PRIVATE), "expect EvEnd < EventSpaceEnd(NActors::TEvents::ES_PRIVATE)"); - - struct TEvTokenReady : TEventLocal { - Ydb::Auth::LoginResult LoginResult; - TActorId Sender; - TString Database; - TString PeerName; - - TEvTokenReady() = default; - }; + TString Ticket; }; struct TConnectionState { @@ -54,7 +29,6 @@ class TPgYdbProxy : public TActor { std::unordered_map ConnectionState; std::unordered_map SecurityState; - std::unordered_map TokenState; uint32_t ConnectionNum = 0; public: @@ -63,85 +37,24 @@ class TPgYdbProxy : public TActor { { } - void Handle(TEvTicketParser::TEvAuthorizeTicketResult::TPtr& ev) { - auto token = ev->Get()->Ticket; - auto itTokenState = TokenState.find(token); - if (itTokenState == TokenState.end()) { - BLOG_W("Couldn't find token in reply from TicketParser"); - return; - } - for (auto sender : itTokenState->second.Senders) { - auto& securityState(SecurityState[sender]); - securityState.Ticket = token; - securityState.Error = ev->Get()->Error; - securityState.Token = ev->Get()->Token; - securityState.SerializedToken = ev->Get()->SerializedToken; - auto authResponse = std::make_unique(); - if (ev->Get()->Error) { - authResponse->Error = ev->Get()->Error.Message; - } - Send(sender, authResponse.release()); - } - TokenState.erase(itTokenState); - } - - void Handle(TEvPrivate::TEvTokenReady::TPtr& ev) { - auto token = ev->Get()->LoginResult.token(); - auto itTokenState = TokenState.find(token); - if (itTokenState == TokenState.end()) { - itTokenState = TokenState.insert({token, {}}).first; - } - bool needSend = itTokenState->second.Senders.empty(); - itTokenState->second.Senders.insert(ev->Get()->Sender); - if (needSend) { - Send(MakeTicketParserID(), new TEvTicketParser::TEvAuthorizeTicket({ - .Database = ev->Get()->Database, - .Ticket = token, - .PeerName = ev->Get()->PeerName, - })); - } - SecurityState[ev->Get()->Sender].LoginResult = std::move(ev->Get()->LoginResult); - } - void Handle(NPG::TEvPGEvents::TEvAuth::TPtr& ev) { - std::unordered_map clientParams = ev->Get()->InitialMessage->GetClientParams(); BLOG_D("TEvAuth " << ev->Get()->InitialMessage->Dump() << " cookie " << ev->Cookie); - Ydb::Auth::LoginRequest request; - request.set_user(clientParams["user"]); + std::unordered_map clientParams = ev->Get()->InitialMessage->GetClientParams(); + TPgWireAuthData pgWireAuthData; + pgWireAuthData.UserName = clientParams["user"]; if (ev->Get()->PasswordMessage) { - request.set_password(TString(ev->Get()->PasswordMessage->GetPassword())); + pgWireAuthData.Password = TString(ev->Get()->PasswordMessage->GetPassword()); } - TActorSystem* actorSystem = TActivationContext::ActorSystem(); - TActorId sender = ev->Sender; - TString database = clientParams["database"]; - if (database == "/postgres") { + pgWireAuthData.Sender = ev->Sender; + pgWireAuthData.DatabasePath = clientParams["database"]; + if (pgWireAuthData.DatabasePath == "/postgres") { auto authResponse = std::make_unique(); authResponse->Error = Ydb::StatusIds_StatusCode_Name(Ydb::StatusIds_StatusCode::StatusIds_StatusCode_BAD_REQUEST); - actorSystem->Send(sender, authResponse.release()); + Send(pgWireAuthData.Sender, authResponse.release()); } - TString peerName = TStringBuilder() << ev->Get()->Address; + pgWireAuthData.PeerName = TStringBuilder() << ev->Get()->Address; - using TRpcEv = NGRpcService::TGRpcRequestWrapperNoAuth; - auto rpcFuture = NRpcService::DoLocalRpc(std::move(request), database, {}, actorSystem); - rpcFuture.Subscribe([actorSystem, sender, database, peerName, selfId = SelfId()](const NThreading::TFuture& future) { - auto& response = future.GetValueSync(); - if (response.operation().status() == Ydb::StatusIds::SUCCESS) { - auto tokenReady = std::make_unique(); - response.operation().result().UnpackTo(&(tokenReady->LoginResult)); - tokenReady->Sender = sender; - tokenReady->Database = database; - tokenReady->PeerName = peerName; - actorSystem->Send(selfId, tokenReady.release()); - } else { - auto authResponse = std::make_unique(); - if (response.operation().issues_size() > 0) { - authResponse->Error = response.operation().issues(0).message(); - } else { - authResponse->Error = Ydb::StatusIds_StatusCode_Name(response.operation().status()); - } - actorSystem->Send(sender, authResponse.release()); - } - }); + Register(CreateLocalPgWireAuthActor(pgWireAuthData, SelfId())); } void Handle(NPG::TEvPGEvents::TEvConnectionOpened::TPtr& ev) { @@ -173,7 +86,6 @@ class TPgYdbProxy : public TActor { } SecurityState.erase(ev->Sender); ConnectionState.erase(itConnection); - // TODO: cleanup TokenState too } void Handle(NPG::TEvPGEvents::TEvQuery::TPtr& ev) { @@ -236,6 +148,18 @@ class TPgYdbProxy : public TActor { } } + void Handle(TEvEvents::TEvAuthResponse::TPtr& ev) { + auto& securityState = SecurityState[ev->Get()->Sender]; + auto authResponse = std::make_unique(); + if (!ev->Get()->ErrorMessage.empty()) { + authResponse->Error = ev->Get()->ErrorMessage; + } else { + securityState.SerializedToken = ev->Get()->SerializedToken; + securityState.Ticket = ev->Get()->Ticket; + } + Send(ev->Get()->Sender, authResponse.release()); + } + STATEFN(StateWork) { switch (ev->GetTypeRewrite()) { hFunc(NPG::TEvPGEvents::TEvAuth, Handle); @@ -248,8 +172,7 @@ class TPgYdbProxy : public TActor { hFunc(NPG::TEvPGEvents::TEvExecute, Handle); hFunc(NPG::TEvPGEvents::TEvClose, Handle); hFunc(NPG::TEvPGEvents::TEvCancelRequest, Handle); - hFunc(TEvPrivate::TEvTokenReady, Handle); - hFunc(TEvTicketParser::TEvAuthorizeTicketResult, Handle); + hFunc(TEvEvents::TEvAuthResponse, Handle); } } }; diff --git a/ydb/core/local_pgwire/local_pgwire.h b/ydb/core/local_pgwire/local_pgwire.h index a5c9cc395793..b9d6588981b3 100644 --- a/ydb/core/local_pgwire/local_pgwire.h +++ b/ydb/core/local_pgwire/local_pgwire.h @@ -1,3 +1,6 @@ +#pragma once + +#include "local_pgwire_util.h" #include namespace NLocalPgWire { @@ -5,4 +8,6 @@ namespace NLocalPgWire { inline NActors::TActorId CreateLocalPgWireProxyId(uint32_t nodeId = 0) { return NActors::TActorId(nodeId, "localpgwire"); } NActors::IActor* CreateLocalPgWireProxy(); +NActors::IActor* CreateLocalPgWireAuthActor(const TPgWireAuthData& pgWireAuthData, const NActors::TActorId& pgYdbProxy); + } diff --git a/ydb/core/local_pgwire/local_pgwire_auth_actor.cpp b/ydb/core/local_pgwire/local_pgwire_auth_actor.cpp new file mode 100644 index 000000000000..fd9d91c60e1b --- /dev/null +++ b/ydb/core/local_pgwire/local_pgwire_auth_actor.cpp @@ -0,0 +1,192 @@ +#include "log_impl.h" +#include "local_pgwire.h" +#include "local_pgwire_util.h" + +#include +#include +#include +#include + +#include +#include + +#include + +#include + +namespace NLocalPgWire { + +using namespace NActors; +using namespace NKikimr; + +class TPgYdbAuthActor : public NActors::TActorBootstrapped { + using TBase = TActor; + + struct TEvPrivate { + enum EEv { + EvTokenReady = EventSpaceBegin(NActors::TEvents::ES_PRIVATE), + EvAuthFailed, + EvEnd + }; + + static_assert(EvEnd < EventSpaceEnd(NActors::TEvents::ES_PRIVATE), "expect EvEnd < EventSpaceEnd(NActors::TEvents::ES_PRIVATE)"); + + struct TEvTokenReady : TEventLocal { + Ydb::Auth::LoginResult LoginResult; + + TEvTokenReady() = default; + }; + + struct TEvAuthFailed : NActors::TEventLocal { + TString ErrorMessage; + }; + }; + + TPgWireAuthData PgWireAuthData; + TActorId PgYdbProxy; + + TString DatabaseId; + TString FolderId; + TString SerializedToken; + TString Ticket; + +public: + TPgYdbAuthActor(const TPgWireAuthData& pgWireAuthData, const TActorId& pgYdbProxy) + : PgWireAuthData(pgWireAuthData) + , PgYdbProxy(pgYdbProxy) { + } + + void Bootstrap() { + if (PgWireAuthData.UserName == "__ydb_apikey") { + if (PgWireAuthData.Password.empty()) { + SendResponseAndDie("Invalid password"); + } + SendDescribeRequest(); + } else { + SendLoginRequest(); + } + + Become(&TPgYdbAuthActor::StateWork); + } + + void Handle(TEvTicketParser::TEvAuthorizeTicketResult::TPtr& ev) { + if (ev->Get()->Error) { + SendResponseAndDie(ev->Get()->Error.Message); + return; + } + + SerializedToken = ev->Get()->SerializedToken; + Ticket = ev->Get()->Ticket; + + SendResponseAndDie(); + } + + void Handle(TEvPrivate::TEvTokenReady::TPtr& ev) { + Send(MakeTicketParserID(), new TEvTicketParser::TEvAuthorizeTicket({ + .Database = PgWireAuthData.DatabasePath, + .Ticket = ev->Get()->LoginResult.token(), + .PeerName = PgWireAuthData.PeerName, + })); + } + + void Handle(TEvPrivate::TEvAuthFailed::TPtr& ev) { + SendResponseAndDie(ev->Get()->ErrorMessage); + } + + void Handle(NKikimr::TEvTxProxySchemeCache::TEvNavigateKeySetResult::TPtr& ev) { + const NKikimr::NSchemeCache::TSchemeCacheNavigate* navigate = ev->Get()->Request.Get(); + if (navigate->ErrorCount) { + SendResponseAndDie(TStringBuilder() << "Database with path '" << PgWireAuthData.DatabasePath << "' doesn't exists"); + return; + } + Y_ABORT_UNLESS(navigate->ResultSet.size() == 1); + + const auto& entry = navigate->ResultSet.front(); + + for (const auto& attr : entry.Attributes) { + if (attr.first == "folderId") FolderId = attr.second; + else if (attr.first == "database_id") DatabaseId = attr.second; + } + + SendApiKeyRequest(); + } + + STATEFN(StateWork) { + switch (ev->GetTypeRewrite()) { + hFunc(TEvPrivate::TEvTokenReady, Handle); + hFunc(TEvTicketParser::TEvAuthorizeTicketResult, Handle); + hFunc(TEvTxProxySchemeCache::TEvNavigateKeySetResult, Handle); + hFunc(TEvPrivate::TEvAuthFailed, Handle); + } + } +private: + void SendLoginRequest() { + Ydb::Auth::LoginRequest request; + request.set_user(PgWireAuthData.UserName); + if (!PgWireAuthData.Password.empty()) { + request.set_password(PgWireAuthData.Password); + } + + auto* actorSystem = TActivationContext::ActorSystem();; + + using TRpcEv = NGRpcService::TGRpcRequestWrapperNoAuth; + auto rpcFuture = NRpcService::DoLocalRpc(std::move(request), PgWireAuthData.DatabasePath, {}, actorSystem); + rpcFuture.Subscribe([actorSystem, selfId = SelfId()](const NThreading::TFuture& future) { + auto& response = future.GetValueSync(); + if (response.operation().status() == Ydb::StatusIds::SUCCESS) { + auto tokenReady = std::make_unique(); + response.operation().result().UnpackTo(&(tokenReady->LoginResult)); + actorSystem->Send(selfId, tokenReady.release()); + } else { + auto authFailedEvent = std::make_unique(); + if (response.operation().issues_size() > 0) { + authFailedEvent->ErrorMessage = response.operation().issues(0).message(); + } else { + authFailedEvent->ErrorMessage = Ydb::StatusIds_StatusCode_Name(response.operation().status()); + } + actorSystem->Send(selfId, authFailedEvent.release()); + } + }); + } + + void SendApiKeyRequest() { + auto entries = NKikimr::NGRpcProxy::V1::GetTicketParserEntries(DatabaseId, FolderId); + + Send(NKikimr::MakeTicketParserID(), new NKikimr::TEvTicketParser::TEvAuthorizeTicket({ + .Database = PgWireAuthData.DatabasePath, + .Ticket = "ApiKey " + PgWireAuthData.Password, + .PeerName = PgWireAuthData.PeerName, + .Entries = entries + })); + } + + void SendDescribeRequest() { + auto schemeCacheRequest = std::make_unique(); + NKikimr::NSchemeCache::TSchemeCacheNavigate::TEntry entry; + entry.Path = NKikimr::SplitPath(PgWireAuthData.DatabasePath); + entry.Operation = NKikimr::NSchemeCache::TSchemeCacheNavigate::OpPath; + entry.SyncVersion = false; + schemeCacheRequest->ResultSet.emplace_back(entry); + Send(NKikimr::MakeSchemeCacheID(), MakeHolder(schemeCacheRequest.release())); + } + + void SendResponseAndDie(const TString& errorMessage = "") { + std::unique_ptr authResponse; + if (!errorMessage.empty()) { + authResponse = std::make_unique(errorMessage, PgWireAuthData.Sender); + } else { + authResponse = std::make_unique(SerializedToken, Ticket, PgWireAuthData.Sender); + } + + Send(PgYdbProxy, authResponse.release()); + + PassAway(); + } +}; + + +NActors::IActor* CreateLocalPgWireAuthActor(const TPgWireAuthData& pgWireAuthData, const TActorId& pgYdbProxy) { + return new TPgYdbAuthActor(pgWireAuthData, pgYdbProxy); +} + +} diff --git a/ydb/core/local_pgwire/local_pgwire_util.h b/ydb/core/local_pgwire/local_pgwire_util.h index 21ecf6dd88db..0ef16d84b27e 100644 --- a/ydb/core/local_pgwire/local_pgwire_util.h +++ b/ydb/core/local_pgwire/local_pgwire_util.h @@ -30,6 +30,14 @@ struct TConnectionState { uint32_t ConnectionNum = 0; }; +struct TPgWireAuthData { + TActorId Sender; + TString UserName; + TString DatabasePath; + TString Password; + TString PeerName; +}; + struct TParsedStatement { NPG::TPGParse::TQueryData QueryData; std::vector ParameterTypes; @@ -56,6 +64,7 @@ struct TEvEvents { EvUpdateStatement, EvSingleQuery, EvCancelRequest, + EvAuthResponse, EvEnd }; @@ -98,6 +107,24 @@ struct TEvEvents { struct TEvCancelRequest : NActors::TEventLocal { TEvCancelRequest() = default; }; + + struct TEvAuthResponse : NActors::TEventLocal { + TString SerializedToken; + TString Ticket; + TString ErrorMessage; + TActorId Sender; + + TEvAuthResponse(const TString& serializedToken, const TString& ticket, const TActorId& sender) + : SerializedToken(serializedToken) + , Ticket(ticket) + , Sender(sender) + {} + + TEvAuthResponse(const TString& errorMessage, const TActorId& sender) + : ErrorMessage(errorMessage) + , Sender(sender) + {} + }; }; TString ColumnPrimitiveValueToString(NYdb::TValueParser& valueParser); diff --git a/ydb/core/local_pgwire/ya.make b/ydb/core/local_pgwire/ya.make index d63b67b22f5e..71b533976682 100644 --- a/ydb/core/local_pgwire/ya.make +++ b/ydb/core/local_pgwire/ya.make @@ -1,6 +1,7 @@ LIBRARY() SRCS( + local_pgwire_auth_actor.cpp local_pgwire_connection.cpp local_pgwire.cpp local_pgwire.h @@ -18,6 +19,7 @@ PEERDIR( ydb/core/kqp/common/events ydb/core/kqp/common/simple ydb/core/kqp/executer_actor + ydb/core/base ydb/core/grpc_services ydb/core/grpc_services/local_rpc ydb/core/protos @@ -25,6 +27,7 @@ PEERDIR( ydb/core/ydb_convert ydb/public/api/grpc ydb/public/lib/operation_id/protos + ydb/services/persqueue_v1/actors ) YQL_LAST_ABI_VERSION()