Skip to content

Commit

Permalink
Migrate AWS STS client to AWS SDK v2 (#51573)
Browse files Browse the repository at this point in the history
  • Loading branch information
GavinFrazar authored Jan 29, 2025
1 parent 533202f commit a9be7db
Show file tree
Hide file tree
Showing 23 changed files with 185 additions and 157 deletions.
20 changes: 0 additions & 20 deletions lib/cloud/clients.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,6 @@ type GCPClients interface {
type AWSClients interface {
// GetAWSSession returns AWS session for the specified region and any role(s).
GetAWSSession(ctx context.Context, region string, opts ...AWSOptionsFn) (*awssession.Session, error)
// GetAWSSTSClient returns AWS STS client for the specified region.
GetAWSSTSClient(ctx context.Context, region string, opts ...AWSOptionsFn) (stsiface.STSAPI, error)
}

// AzureClients is an interface for Azure-specific API clients
Expand Down Expand Up @@ -468,15 +466,6 @@ func (c *cloudClients) GetAWSSession(ctx context.Context, region string, opts ..
return c.getAWSSessionForRole(ctx, region, options)
}

// GetAWSSTSClient returns AWS STS client for the specified region.
func (c *cloudClients) GetAWSSTSClient(ctx context.Context, region string, opts ...AWSOptionsFn) (stsiface.STSAPI, error) {
session, err := c.GetAWSSession(ctx, region, opts...)
if err != nil {
return nil, trace.Wrap(err)
}
return sts.New(session), nil
}

// GetGCPIAMClient returns GCP IAM client.
func (c *cloudClients) GetGCPIAMClient(ctx context.Context) (*gcpcredentials.IamCredentialsClient, error) {
c.mtx.RLock()
Expand Down Expand Up @@ -964,15 +953,6 @@ func (c *TestCloudClients) getAWSSessionForRegion(region string) (*awssession.Se
})
}

// GetAWSSTSClient returns AWS STS client for the specified region.
func (c *TestCloudClients) GetAWSSTSClient(ctx context.Context, region string, opts ...AWSOptionsFn) (stsiface.STSAPI, error) {
_, err := c.GetAWSSession(ctx, region, opts...)
if err != nil {
return nil, trace.Wrap(err)
}
return c.STS, nil
}

// GetGCPIAMClient returns GCP IAM client.
func (c *TestCloudClients) GetGCPIAMClient(ctx context.Context) (*gcpcredentials.IamCredentialsClient, error) {
return gcpcredentials.NewIamCredentialsClient(ctx,
Expand Down
2 changes: 1 addition & 1 deletion lib/cloud/mocks/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ func (m *STSClientV1) AssumeRoleWithContext(ctx aws.Context, in *sts.AssumeRoleI
expiry := time.Now().Add(60 * time.Minute)
return &sts.AssumeRoleOutput{
Credentials: &sts.Credentials{
AccessKeyId: in.RoleArn,
AccessKeyId: aws.String("FAKEACCESSKEYID"),
SecretAccessKey: aws.String("secret"),
SessionToken: aws.String("token"),
Expiration: &expiry,
Expand Down
17 changes: 15 additions & 2 deletions lib/cloud/mocks/aws_sts.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ import (
type STSClient struct {
STSClientV1

Unauth bool
// credentialProvider is only set when a chain of assumed roles is used.
credentialProvider aws.CredentialsProvider
// recordFn records the role and external ID when a role is assumed.
Expand All @@ -55,17 +56,25 @@ type STSClient struct {
}

func (m *STSClient) GetCallerIdentity(ctx context.Context, params *sts.GetCallerIdentityInput, optFns ...func(*sts.Options)) (*sts.GetCallerIdentityOutput, error) {
if m.Unauth {
return nil, trace.AccessDenied("unauthorized")
}

return &sts.GetCallerIdentityOutput{
Arn: aws.String(m.ARN),
}, nil
}

func (m *STSClient) AssumeRoleWithWebIdentity(ctx context.Context, in *sts.AssumeRoleWithWebIdentityInput, _ ...func(*sts.Options)) (*sts.AssumeRoleWithWebIdentityOutput, error) {
if m.Unauth {
return nil, trace.AccessDenied("unauthorized")
}

m.record(aws.ToString(in.RoleArn), "")
expiry := time.Now().Add(60 * time.Minute)
return &sts.AssumeRoleWithWebIdentityOutput{
Credentials: &ststypes.Credentials{
AccessKeyId: in.RoleArn,
AccessKeyId: aws.String("WEBIDENTITYFAKEACCESSKEYID"),
SecretAccessKey: aws.String("secret"),
SessionToken: aws.String("token"),
Expiration: &expiry,
Expand All @@ -74,6 +83,10 @@ func (m *STSClient) AssumeRoleWithWebIdentity(ctx context.Context, in *sts.Assum
}

func (m *STSClient) AssumeRole(ctx context.Context, in *sts.AssumeRoleInput, _ ...func(*sts.Options)) (*sts.AssumeRoleOutput, error) {
if m.Unauth {
return nil, trace.AccessDenied("unauthorized")
}

// Retrieve credentials if we have a credential provider, so that all
// assume-role providers in a role chain are triggered to call AssumeRole.
if m.credentialProvider != nil {
Expand All @@ -87,7 +100,7 @@ func (m *STSClient) AssumeRole(ctx context.Context, in *sts.AssumeRoleInput, _ .
expiry := time.Now().Add(60 * time.Minute)
return &sts.AssumeRoleOutput{
Credentials: &ststypes.Credentials{
AccessKeyId: in.RoleArn,
AccessKeyId: aws.String("FAKEACCESSKEYID"),
SecretAccessKey: aws.String("secret"),
SessionToken: aws.String("token"),
Expiration: &expiry,
Expand Down
6 changes: 6 additions & 0 deletions lib/srv/app/aws/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"net/http"
"net/http/httptest"
"net/url"
"os"
"strings"
"testing"
"time"
Expand Down Expand Up @@ -57,6 +58,11 @@ import (
awsutils "github.com/gravitational/teleport/lib/utils/aws"
)

func TestMain(m *testing.M) {
utils.InitLoggerForTests()
os.Exit(m.Run())
}

type makeRequest func(url string, provider client.ConfigProvider, awsHost string) error

func s3Request(url string, provider client.ConfigProvider, awsHost string) error {
Expand Down
6 changes: 4 additions & 2 deletions lib/srv/db/cloud/iam.go
Original file line number Diff line number Diff line change
Expand Up @@ -269,11 +269,13 @@ func (c *IAM) getAWSIdentity(ctx context.Context, database types.Database) (awsl
return c.agentIdentity, nil
}
c.mu.RUnlock()
sts, err := c.cfg.Clients.GetAWSSTSClient(ctx, meta.Region, cloud.WithAmbientCredentials())

awsCfg, err := c.cfg.AWSConfigProvider.GetConfig(ctx, meta.Region, awsconfig.WithAmbientCredentials())
if err != nil {
return nil, trace.Wrap(err)
}
awsIdentity, err := awslib.GetIdentityWithClient(ctx, sts)
clt := c.cfg.awsClients.getSTSClient(awsCfg)
awsIdentity, err := awslib.GetIdentityWithClientV2(ctx, clt)
if err != nil {
return nil, trace.Wrap(err)
}
Expand Down
73 changes: 32 additions & 41 deletions lib/srv/db/cloud/iam_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,16 +156,15 @@ func TestAWSIAM(t *testing.T) {
AWSConfigProvider: &mocks.AWSConfigProvider{
STSClient: stsClient,
},
Clients: &clients.TestCloudClients{
STS: &stsClient.STSClientV1,
},
HostID: "host-id",
Clients: &clients.TestCloudClients{},
HostID: "host-id",
onProcessedTask: func(iamTask, error) {
taskChan <- struct{}{}
},
awsClients: fakeAWSClients{
iamClient: iamClient,
rdsClient: clt,
stsClient: stsClient,
},
})
require.NoError(t, err)
Expand Down Expand Up @@ -291,8 +290,10 @@ func TestAWSIAMNoPermissions(t *testing.T) {
t.Cleanup(cancel)

// Create unauthorized mocks for AWS services.
stsClient := &mocks.STSClientV1{
ARN: "arn:aws:iam::123456789012:role/test-role",
stsClient := &mocks.STSClient{
STSClientV1: mocks.STSClientV1{
ARN: "arn:aws:iam::123456789012:role/test-role",
},
}
tests := []struct {
name string
Expand All @@ -301,70 +302,64 @@ func TestAWSIAMNoPermissions(t *testing.T) {
awsClients awsClientProvider
}{
{
name: "RDS database",
meta: types.AWS{Region: "localhost", AccountID: "123456789012", RDS: types.RDS{InstanceID: "postgres-rds", ResourceID: "postgres-rds-resource-id"}},
clients: &clients.TestCloudClients{
STS: stsClient,
},
name: "RDS database",
meta: types.AWS{Region: "localhost", AccountID: "123456789012", RDS: types.RDS{InstanceID: "postgres-rds", ResourceID: "postgres-rds-resource-id"}},
clients: &clients.TestCloudClients{},
awsClients: fakeAWSClients{
iamClient: &mocks.IAMMock{Unauth: true},
rdsClient: &mocks.RDSClient{Unauth: true},
stsClient: stsClient,
},
},
{
name: "Aurora cluster",
meta: types.AWS{Region: "localhost", AccountID: "123456789012", RDS: types.RDS{ClusterID: "postgres-aurora", ResourceID: "postgres-aurora-resource-id"}},
clients: &clients.TestCloudClients{
STS: stsClient,
},
name: "Aurora cluster",
meta: types.AWS{Region: "localhost", AccountID: "123456789012", RDS: types.RDS{ClusterID: "postgres-aurora", ResourceID: "postgres-aurora-resource-id"}},
clients: &clients.TestCloudClients{},
awsClients: fakeAWSClients{
iamClient: &mocks.IAMMock{Unauth: true},
rdsClient: &mocks.RDSClient{Unauth: true},
stsClient: stsClient,
},
},
{
name: "RDS database missing metadata",
meta: types.AWS{Region: "localhost", RDS: types.RDS{ClusterID: "postgres-aurora"}},
clients: &clients.TestCloudClients{
STS: stsClient,
},
name: "RDS database missing metadata",
meta: types.AWS{Region: "localhost", RDS: types.RDS{ClusterID: "postgres-aurora"}},
clients: &clients.TestCloudClients{},
awsClients: fakeAWSClients{
iamClient: &mocks.IAMMock{Unauth: true},
rdsClient: &mocks.RDSClient{Unauth: true},
stsClient: stsClient,
},
},
{
name: "Redshift cluster",
meta: types.AWS{Region: "localhost", AccountID: "123456789012", Redshift: types.Redshift{ClusterID: "redshift-cluster-1"}},
clients: &clients.TestCloudClients{
STS: stsClient,
},
name: "Redshift cluster",
meta: types.AWS{Region: "localhost", AccountID: "123456789012", Redshift: types.Redshift{ClusterID: "redshift-cluster-1"}},
clients: &clients.TestCloudClients{},
awsClients: fakeAWSClients{
iamClient: &mocks.IAMMock{Unauth: true},
stsClient: stsClient,
},
},
{
name: "ElastiCache",
meta: types.AWS{Region: "localhost", AccountID: "123456789012", ElastiCache: types.ElastiCache{ReplicationGroupID: "some-group"}},
clients: &clients.TestCloudClients{
STS: stsClient,
},
name: "ElastiCache",
meta: types.AWS{Region: "localhost", AccountID: "123456789012", ElastiCache: types.ElastiCache{ReplicationGroupID: "some-group"}},
clients: &clients.TestCloudClients{},
awsClients: fakeAWSClients{
iamClient: &mocks.IAMMock{Unauth: true},
stsClient: stsClient,
},
},
{
name: "IAM UnmodifiableEntityException",
meta: types.AWS{Region: "localhost", AccountID: "123456789012", Redshift: types.Redshift{ClusterID: "redshift-cluster-1"}},
clients: &clients.TestCloudClients{
STS: stsClient,
},
name: "IAM UnmodifiableEntityException",
meta: types.AWS{Region: "localhost", AccountID: "123456789012", Redshift: types.Redshift{ClusterID: "redshift-cluster-1"}},
clients: &clients.TestCloudClients{},
awsClients: fakeAWSClients{
iamClient: &mocks.IAMMock{
Error: &iamtypes.UnmodifiableEntityException{
Message: aws.String("Cannot perform the operation on the protected role"),
},
},
stsClient: stsClient,
},
},
}
Expand All @@ -377,11 +372,7 @@ func TestAWSIAMNoPermissions(t *testing.T) {
Clients: test.clients,
HostID: "host-id",
AWSConfigProvider: &mocks.AWSConfigProvider{
STSClient: &mocks.STSClient{
STSClientV1: mocks.STSClientV1{
ARN: "arn:aws:iam::123456789012:role/test-role",
},
},
STSClient: stsClient,
},
awsClients: test.awsClients,
})
Expand Down
11 changes: 11 additions & 0 deletions lib/srv/db/cloud/meta.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import (
redshifttypes "github.com/aws/aws-sdk-go-v2/service/redshift/types"
rss "github.com/aws/aws-sdk-go-v2/service/redshiftserverless"
rsstypes "github.com/aws/aws-sdk-go-v2/service/redshiftserverless/types"
"github.com/aws/aws-sdk-go-v2/service/sts"
"github.com/gravitational/trace"

"github.com/gravitational/teleport"
Expand Down Expand Up @@ -93,6 +94,11 @@ type rssClient interface {
GetWorkgroup(ctx context.Context, params *rss.GetWorkgroupInput, optFns ...func(*rss.Options)) (*rss.GetWorkgroupOutput, error)
}

// stsClient defines a subset of the AWS STS client API.
type stsClient interface {
GetCallerIdentity(ctx context.Context, params *sts.GetCallerIdentityInput, optFns ...func(*sts.Options)) (*sts.GetCallerIdentityOutput, error)
}

// awsClientProvider is an AWS SDK client provider.
type awsClientProvider interface {
getElastiCacheClient(cfg aws.Config, optFns ...func(*elasticache.Options)) elasticacheClient
Expand All @@ -102,6 +108,7 @@ type awsClientProvider interface {
getRDSClient(cfg aws.Config, optFns ...func(*rds.Options)) rdsClient
getRedshiftClient(cfg aws.Config, optFns ...func(*redshift.Options)) redshiftClient
getRedshiftServerlessClient(cfg aws.Config, optFns ...func(*rss.Options)) rssClient
getSTSClient(cfg aws.Config, optFns ...func(*sts.Options)) stsClient
}

type defaultAWSClients struct{}
Expand Down Expand Up @@ -134,6 +141,10 @@ func (defaultAWSClients) getRedshiftServerlessClient(cfg aws.Config, optFns ...f
return rss.NewFromConfig(cfg, optFns...)
}

func (defaultAWSClients) getSTSClient(cfg aws.Config, optFns ...func(*sts.Options)) stsClient {
return sts.NewFromConfig(cfg, optFns...)
}

// MetadataConfig is the cloud metadata service config.
type MetadataConfig struct {
// Clients is an interface for retrieving cloud clients.
Expand Down
6 changes: 6 additions & 0 deletions lib/srv/db/cloud/meta_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import (
redshifttypes "github.com/aws/aws-sdk-go-v2/service/redshift/types"
rss "github.com/aws/aws-sdk-go-v2/service/redshiftserverless"
rsstypes "github.com/aws/aws-sdk-go-v2/service/redshiftserverless/types"
"github.com/aws/aws-sdk-go-v2/service/sts"
"github.com/stretchr/testify/require"

"github.com/gravitational/teleport/api/types"
Expand Down Expand Up @@ -511,6 +512,7 @@ type fakeAWSClients struct {
rdsClient rdsClient
redshiftClient redshiftClient
rssClient rssClient
stsClient stsClient
}

func (f fakeAWSClients) getElastiCacheClient(cfg aws.Config, optFns ...func(*elasticache.Options)) elasticacheClient {
Expand Down Expand Up @@ -540,3 +542,7 @@ func (f fakeAWSClients) getRedshiftClient(aws.Config, ...func(*redshift.Options)
func (f fakeAWSClients) getRedshiftServerlessClient(aws.Config, ...func(*rss.Options)) rssClient {
return f.rssClient
}

func (f fakeAWSClients) getSTSClient(cfg aws.Config, optFns ...func(*sts.Options)) stsClient {
return f.stsClient
}
2 changes: 1 addition & 1 deletion lib/srv/db/cloud/resource_checker.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func NewDiscoveryResourceChecker(cfg DiscoveryResourceCheckerConfig) (DiscoveryR
return nil, trace.Wrap(err)
}

credentialsChecker, err := newCrednentialsChecker(cfg)
credentialsChecker, err := newCredentialsChecker(cfg)
if err != nil {
return nil, trace.Wrap(err)
}
Expand Down
Loading

0 comments on commit a9be7db

Please sign in to comment.