Skip to content

Commit

Permalink
feat: support rotating credentials in newsigv4
Browse files Browse the repository at this point in the history
This adds an additional feature set to a fork of the sigv4
extension. An explicit configuration option can define
a shared credential file location and profile name, this
file will be watched for changes and propogate the fresh
credentials to upstream consumers.

Where possible the AWS SDK machinery has been leveraged
to maintain best compatibility with other AWS toolchains.

The change should maintain backwards compatibility with
the existing sigv4 extension, and both should be
interchangable for uses requiring static credentials or
sts credentials.

Closes rocketsciencegg/aws-gamelift#174
  • Loading branch information
tanuck committed Jan 8, 2025
1 parent 3a5903a commit 9ec5466
Show file tree
Hide file tree
Showing 9 changed files with 192 additions and 16 deletions.
17 changes: 13 additions & 4 deletions extensions/newsigv4/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@ import (

// Config stores the configuration for the Sigv4 Authenticator
type Config struct {
Region string `mapstructure:"region,omitempty"`
Service string `mapstructure:"service,omitempty"`
AssumeRole AssumeRole `mapstructure:"assume_role"`
credsProvider *aws.CredentialsProvider
Region string `mapstructure:"region,omitempty"`
Service string `mapstructure:"service,omitempty"`
AssumeRole AssumeRole `mapstructure:"assume_role"`
SharedCredentialsWatcher SharedCredentialsWatcher `mapstructure:"shared_credentials_watcher"`
credsProvider *aws.CredentialsProvider
}

// AssumeRole holds the configuration needed to assume a role
Expand All @@ -25,6 +26,14 @@ type AssumeRole struct {
STSRegion string `mapstructure:"sts_region,omitempty"`
}

// SharedCredentialsWatcher holds the configuration to setup a file based
// watch for environments where the shared credentials file is updated
// periodically by an external process.
type SharedCredentialsWatcher struct {
FileLocation string `mapstructure:"file_location,omitempty"`
ProfileName string `mapstructure:"profile_name,omitempty"`
}

// compile time check that the Config struct satisfies the component.Config interface
var _ component.Config = (*Config)(nil)

Expand Down
4 changes: 4 additions & 0 deletions extensions/newsigv4/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ func TestLoadConfig(t *testing.T) {
SessionName: "role_session_name",
STSRegion: "region",
},
SharedCredentialsWatcher: SharedCredentialsWatcher{
FileLocation: "/local/credentials/credentials",
ProfileName: "default",
},
// Ensure creds are the same for load config test; tested in extension_test.go
credsProvider: cfg.(*Config).credsProvider,
}, cfg)
Expand Down
26 changes: 26 additions & 0 deletions extensions/newsigv4/credentials.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package newsigv4

import (
"context"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
)

type sharedCredentialsProvider struct {
sfile string
profile string
}

// Retrieve returns fresh credentials from the given shared
// credentials file.
func (s *sharedCredentialsProvider) Retrieve(ctx context.Context) (aws.Credentials, error) {
sharedcfg, err := config.LoadSharedConfigProfile(ctx, s.profile, func(opts *config.LoadSharedConfigOptions) {
opts.CredentialsFiles = []string{s.sfile}
})
if err != nil {
return aws.Credentials{}, err
}

return sharedcfg.Credentials, nil
}
38 changes: 38 additions & 0 deletions extensions/newsigv4/credentials_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package newsigv4

import (
"context"
"os"
"testing"

"github.com/stretchr/testify/require"
)

func TestSharedCredentialsProvider_Retrieve(t *testing.T) {
t.Run("Retrieve valid credentials from a temp file", func(t *testing.T) {
tmpFile, err := os.CreateTemp("", "shared-credentials")
require.NoError(t, err)
defer os.Remove(tmpFile.Name())

sampleProfile := `[default]
aws_access_key_id = TEST_ACCESS_KEY
aws_secret_access_key = TEST_SECRET_KEY
`
_, err = tmpFile.WriteString(sampleProfile)
require.NoError(t, err)

// Close the file so the provider can read it properly.
err = tmpFile.Close()
require.NoError(t, err)

provider := &sharedCredentialsProvider{
profile: "default",
sfile: tmpFile.Name(),
}
creds, err := provider.Retrieve(context.Background())
require.NoError(t, err)

require.Equal(t, creds.AccessKeyID, "TEST_ACCESS_KEY")
require.Equal(t, creds.SecretAccessKey, "TEST_SECRET_KEY")
})
}
106 changes: 100 additions & 6 deletions extensions/newsigv4/extension.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ import (
awsconfig "github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials/stscreds"
"github.com/aws/aws-sdk-go-v2/service/sts"
"github.com/fsnotify/fsnotify"
"go.opentelemetry.io/collector/component"
"go.opentelemetry.io/collector/component/componentstatus"
"go.opentelemetry.io/collector/extension/auth"
"go.uber.org/zap"
grpcCredentials "google.golang.org/grpc/credentials"
Expand All @@ -22,11 +24,10 @@ import (
// sigv4Auth is a struct that implements the auth.Client interface.
// It provides the implementation for providing Sigv4 authentication for HTTP requests only.
type sigv4Auth struct {
cfg *Config
logger *zap.Logger
awsSDKInfo string
component.StartFunc // embedded default behavior to do nothing with Start()
component.ShutdownFunc // embedded default behavior to do nothing with Shutdown()
cfg *Config
logger *zap.Logger
awsSDKInfo string
watcher *fsnotify.Watcher
}

// compile time check that the sigv4Auth struct satisfies the auth.Client interface
Expand Down Expand Up @@ -57,6 +58,85 @@ func (sa *sigv4Auth) PerRPCCredentials() (grpcCredentials.PerRPCCredentials, err
return nil, errors.New("not implemented")
}

// Start is implemented to satisfy the component.Component interface. Start
// is called on extension inialization and will setup the fsnotify
// file watcher when credentials are provided by a shared credentials file
// that requires refreshing over time.
func (sa *sigv4Auth) Start(_ context.Context, host component.Host) error {
if sa.cfg.SharedCredentialsWatcher.FileLocation != "" {
watcher, err := fsnotify.NewWatcher()
if err != nil {
componentstatus.ReportStatus(host, componentstatus.NewFatalErrorEvent(err))
return nil
}
sa.watcher = watcher

if err := sa.startWatcher(); err != nil {
componentstatus.ReportStatus(host, componentstatus.NewFatalErrorEvent(err))
}
sa.logger.Info("Started credentials file watcher")
}

return nil
}

// Shutdown is implemented to satisfy the component.Component interface. Shutdown
// closes any open fsnotify watches. Any goroutines active from startWatcher will
// subsequently exit safely.
func (sa *sigv4Auth) Shutdown(_ context.Context) error {
if sa.watcher != nil {
if err := sa.watcher.Close(); err != nil {
return err
}
}

return nil
}

func (sa *sigv4Auth) startWatcher() error {
location := sa.cfg.SharedCredentialsWatcher.FileLocation

// invalidator is a local copy of the internal interface for cache invalidators
// from the AWS Go SDK.
// https://github.com/aws/aws-sdk-go-v2/blob/main/internal/sdk/interfaces.go
type invalidator interface {
Invalidate()
}

cache, ok := (*sa.cfg.credsProvider).(invalidator)
if !ok {
return nil
}

go func() {
for {
select {
case event, ok := <-sa.watcher.Events:
if !ok {
return
}

if event.Has(fsnotify.Create | fsnotify.Write | fsnotify.Rename) {
sa.logger.Info("Detected changes within shared credentials file")
cache.Invalidate()
}
case err, ok := <-sa.watcher.Errors:
if !ok {
return
}

sa.logger.Error("Error event from file watcher", zap.Error(err))
}
}
}()

if err := sa.watcher.Add(location); err != nil {
return err
}

return nil
}

// newSigv4Extension() is called by createExtension() in factory.go and
// returns a new sigv4Auth struct.
func newSigv4Extension(cfg *Config, awsSDKInfo string, logger *zap.Logger) *sigv4Auth {
Expand All @@ -76,10 +156,24 @@ func getCredsProviderFromConfig(cfg *Config) (*aws.CredentialsProvider, error) {
if err != nil {
return nil, err
}

var provider aws.CredentialsProvider

// Create new wrapped CredentialProvider from awscfg
if cfg.SharedCredentialsWatcher.FileLocation != "" {
provider = &sharedCredentialsProvider{
sfile: cfg.SharedCredentialsWatcher.FileLocation,
profile: cfg.SharedCredentialsWatcher.ProfileName,
}
}

if cfg.AssumeRole.ARN != "" {
stsSvc := sts.NewFromConfig(awscfg)

provider := stscreds.NewAssumeRoleProvider(stsSvc, cfg.AssumeRole.ARN)
provider = stscreds.NewAssumeRoleProvider(stsSvc, cfg.AssumeRole.ARN)
}

if provider != nil {
awscfg.Credentials = aws.NewCredentialsCache(provider)
}

Expand Down
2 changes: 1 addition & 1 deletion extensions/newsigv4/generated_component_test.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 5 additions & 2 deletions extensions/newsigv4/testdata/config.yaml
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
sigv4auth:
newsigv4:
region: "region"
service: "service"
assume_role:
session_name: "role_session_name"
sigv4auth/missing_credentials:
shared_credentials_watcher:
file_location: "/local/credentials/credentials"
profile_name: "default"
newsigv4/missing_credentials:
region: "region"
service: "service"
6 changes: 4 additions & 2 deletions pkg/defaultcomponents/defaults_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ import (

const (
exportersCount = 16
receiversCount = 10
extensionsCount = 8
receiversCount = 11
extensionsCount = 9
processorCount = 15
)

Expand Down Expand Up @@ -69,13 +69,15 @@ func TestComponents(t *testing.T) {
assert.NotNil(t, receivers[component.MustNewType("jaeger")])
assert.NotNil(t, receivers[component.MustNewType("kafka")])
assert.NotNil(t, receivers[component.MustNewType("filelog")])
assert.NotNil(t, receivers[component.MustNewType("hostmetrics")])

extensions := factories.Extensions
assert.Len(t, extensions, extensionsCount)
// aws extensions
assert.NotNil(t, extensions[component.MustNewType("awsproxy")])
assert.NotNil(t, extensions[component.MustNewType("ecs_observer")])
assert.NotNil(t, extensions[component.MustNewType("sigv4auth")])
assert.NotNil(t, extensions[component.MustNewType("newsigv4")])
// core extensions
assert.NotNil(t, extensions[component.MustNewType("zpages")])
assert.NotNil(t, extensions[component.MustNewType("memory_ballast")])
Expand Down
2 changes: 1 addition & 1 deletion tools/packaging/linux/create_rpm.sh
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ mv "${RPM_NAME}-${VERSION}.tar.gz" "${BUILD_ROOT}/SOURCES/"
rm -rf "${WORK_DIR}"

echo "Creating the rpm package"
rpmbuild --define "VERSION $VERSION" --define "RPM_NAME $RPM_NAME" --define "_topdir ${BUILD_ROOT}" --define "_source_filedigest_algorithm 8" --define "_binary_filedigest_algorithm 8" -bb -v --clean ${SPEC_FILE} --target "${ARCH}-linux"
rpmbuild --define "VERSION $VERSION" --define "RPM_NAME $RPM_NAME" --define "_topdir ${BUILD_ROOT}" --define "_source_filedigest_algorithm 8" --define "_binary_filedigest_algorithm 8" --define "_use_weak_usergroup_deps 1" -bb -v --clean ${SPEC_FILE} --target "${ARCH}-linux"

echo "Copy rpm file to ${DEST}"
mkdir -p "${DEST}"
Expand Down

0 comments on commit 9ec5466

Please sign in to comment.