From 10e06ceff97f4dd207f96542aee50326541aec68 Mon Sep 17 00:00:00 2001 From: Gabriel Corado Date: Fri, 17 Jan 2025 13:31:41 -0300 Subject: [PATCH] Migrate AWS Cassandra (Keyspaces) authenticator initialization to AWS SDK V2 (#51147) * refactor(cassandra): use aws sdk v2 for creating keyspaces signer * refactor(cassandra): only use external id for one assume role --- lib/srv/db/cassandra/engine.go | 4 ++-- lib/srv/db/cassandra/handshake.go | 29 +++++++++++++---------------- 2 files changed, 15 insertions(+), 18 deletions(-) diff --git a/lib/srv/db/cassandra/engine.go b/lib/srv/db/cassandra/engine.go index 3ae48a1daa7ce..babc4a97c0abf 100644 --- a/lib/srv/db/cassandra/engine.go +++ b/lib/srv/db/cassandra/engine.go @@ -309,8 +309,8 @@ func (e *Engine) getAuth(sessionCtx *common.Session) (handshakeHandler, error) { switch { case sessionCtx.Database.IsAWSHosted(): return &authAWSSigV4Auth{ - cloudClients: e.CloudClients, - ses: sessionCtx, + ses: sessionCtx, + awsConfig: e.AWSConfigProvider, }, nil default: return &basicHandshake{ses: sessionCtx}, nil diff --git a/lib/srv/db/cassandra/handshake.go b/lib/srv/db/cassandra/handshake.go index da118559bb74a..adcccaded08aa 100644 --- a/lib/srv/db/cassandra/handshake.go +++ b/lib/srv/db/cassandra/handshake.go @@ -28,7 +28,7 @@ import ( "github.com/gocql/gocql" "github.com/gravitational/trace" - "github.com/gravitational/teleport/lib/cloud" + "github.com/gravitational/teleport/lib/cloud/awsconfig" "github.com/gravitational/teleport/lib/srv/db/cassandra/protocol" "github.com/gravitational/teleport/lib/srv/db/common" awsutils "github.com/gravitational/teleport/lib/utils/aws" @@ -190,8 +190,8 @@ func sendAuthenticationErrorMessage(authErr error, clientConn *protocol.Conn, in // authHandler is a handler that performs the Cassandra authentication flow. type authAWSSigV4Auth struct { - ses *common.Session - cloudClients cloud.Clients + ses *common.Session + awsConfig awsconfig.Provider } func (a *authAWSSigV4Auth) getSigV4Authenticator(ctx context.Context) (gocql.Authenticator, error) { @@ -200,25 +200,22 @@ func (a *authAWSSigV4Auth) getSigV4Authenticator(ctx context.Context) (gocql.Aut if err != nil { return nil, trace.Wrap(err) } - baseSession, err := a.cloudClients.GetAWSSession(ctx, meta.Region, - cloud.WithAssumeRoleFromAWSMeta(meta), - cloud.WithAmbientCredentials(), - ) - if err != nil { - return nil, trace.Wrap(err) - } - var externalID string + // ExternalID should only be used in one of the assumed roles. If the + // configuration doesn't specify the AssumeRoleARN, it should be used for + // the database role. + var dbRoleExternalID string if meta.AssumeRoleARN == "" { - externalID = meta.ExternalID + dbRoleExternalID = meta.ExternalID } - session, err := a.cloudClients.GetAWSSession(ctx, meta.Region, - cloud.WithChainedAssumeRole(baseSession, roleARN, externalID), - cloud.WithAmbientCredentials(), + awsCfg, err := a.awsConfig.GetConfig(ctx, meta.Region, + awsconfig.WithAssumeRole(meta.AssumeRoleARN, meta.ExternalID), + awsconfig.WithAssumeRole(roleARN, dbRoleExternalID), + awsconfig.WithAmbientCredentials(), ) if err != nil { return nil, trace.Wrap(err) } - cred, err := session.Config.Credentials.GetWithContext(ctx) + cred, err := awsCfg.Credentials.Retrieve(ctx) if err != nil { return nil, trace.Wrap(err) }