diff --git a/kms/keysource.go b/kms/keysource.go index e1441b492..0ed044ddb 100644 --- a/kms/keysource.go +++ b/kms/keysource.go @@ -181,6 +181,41 @@ func ParseKMSContext(in interface{}) map[string]*string { return out } +// GetArnByAlias takes a AWS KMS Client, key. +// When argument comes with alias, convert it to arn +func GetArnByAlias(client *kms.Client, key *MasterKey) error { + input := &kms.ListAliasesInput{} + + paginator := kms.NewListAliasesPaginator(client, input) + found := false + for paginator.HasMorePages() && !found { + output, err := paginator.NextPage(context.Background()) + if err != nil { + return fmt.Errorf("failed to get kms key: %w", err) + } + + for _, alias := range output.Aliases { + if strings.HasSuffix(*alias.AliasArn, key.Arn) { + + describeInput := &kms.DescribeKeyInput{ + KeyId: aws.String(*alias.TargetKeyId), + } + + describeOutput, err := client.DescribeKey(context.Background(), describeInput) + + if err != nil { + return fmt.Errorf("failed to describe key: %w", err) + } + + key.Arn = *describeOutput.KeyMetadata.Arn + found = true + } + } + } + + return nil +} + // CredentialsProvider is a wrapper around aws.CredentialsProvider used for // authentication towards AWS KMS. type CredentialsProvider struct { @@ -209,6 +244,31 @@ func (key *MasterKey) Encrypt(dataKey []byte) error { return err } client := key.createClient(cfg) + + // condition that input is an alias + if !strings.HasPrefix(string(key.Arn), "arn:aws:kms") { + + err = GetArnByAlias(client, key) + + if err != nil { + return err + } + + encryptInput := &kms.EncryptInput{ + KeyId: &key.Arn, + Plaintext: dataKey, + EncryptionContext: stringPointerToStringMap(key.EncryptionContext), + } + out, err := client.Encrypt(context.TODO(), encryptInput) + if err != nil { + log.WithField("arn", key.Arn).Info("Encryption failed") + return fmt.Errorf("failed to encrypt sops data key with AWS KMS 222: %s", key.Arn) + } + key.EncryptedKey = base64.StdEncoding.EncodeToString(out.CiphertextBlob) + log.WithField("arn", key.Arn).Info("Encryption succeeded") + return nil + } + input := &kms.EncryptInput{ KeyId: &key.Arn, Plaintext: dataKey, @@ -257,6 +317,52 @@ func (key *MasterKey) Decrypt() ([]byte, error) { return nil, err } client := key.createClient(cfg) + + if !strings.HasPrefix(string(key.Arn), "arn:aws:kms") { + input := &kms.ListAliasesInput{} + + paginator := kms.NewListAliasesPaginator(client, input) + found := false + for paginator.HasMorePages() && !found { + output, err := paginator.NextPage(context.Background()) + if err != nil { + log.WithField("arn", key.Arn).Info("Error listing aliases") + break + } + + for _, alias := range output.Aliases { + if strings.HasSuffix(*alias.AliasArn, key.Arn) { + + describeInput := &kms.DescribeKeyInput{ + KeyId: aws.String(*alias.TargetKeyId), + } + + describeOutput, err := client.DescribeKey(context.Background(), describeInput) + + if err != nil { + return nil, fmt.Errorf("failed to describe key: %w", err) + } + + key.Arn = *describeOutput.KeyMetadata.Arn + found = true + } + } + } + + decryptedInput := &kms.DecryptInput{ + KeyId: &key.Arn, + CiphertextBlob: k, + EncryptionContext: stringPointerToStringMap(key.EncryptionContext), + } + decrypted, err := client.Decrypt(context.TODO(), decryptedInput) + if err != nil { + log.WithField("arn", key.Arn).Info("Decryption failed") + return nil, fmt.Errorf("failed to decrypt sops data key with AWS KMS: %w", err) + } + log.WithField("arn", key.Arn).Info("Decryption succeeded") + return decrypted.Plaintext, nil + } + input := &kms.DecryptInput{ KeyId: &key.Arn, CiphertextBlob: k, @@ -307,13 +413,7 @@ func (key *MasterKey) TypeToIdentifier() string { // createKMSConfig returns an AWS config with the credentialsProvider of the // MasterKey, or the default configuration sources. -func (key MasterKey) createKMSConfig() (*aws.Config, error) { - re := regexp.MustCompile(arnRegex) - matches := re.FindStringSubmatch(key.Arn) - if matches == nil { - return nil, fmt.Errorf("no valid ARN found in '%s'", key.Arn) - } - region := matches[1] +func (key *MasterKey) createKMSConfig() (*aws.Config, error) { cfg, err := config.LoadDefaultConfig(context.TODO(), func(lo *config.LoadOptions) error { // Use the credentialsProvider if present, otherwise default to reading credentials @@ -324,13 +424,28 @@ func (key MasterKey) createKMSConfig() (*aws.Config, error) { if key.AwsProfile != "" { lo.SharedConfigProfile = key.AwsProfile } - lo.Region = region return nil }) + if err != nil { return nil, fmt.Errorf("could not load AWS config: %w", err) } + re := regexp.MustCompile(arnRegex) + matches := re.FindStringSubmatch(key.Arn) + + if matches == nil { + client := key.createClient(&cfg) + err = GetArnByAlias(client, key) + + if err != nil { + return nil, err + } + + matches = re.FindStringSubmatch(key.Arn) + cfg.Region = matches[1] + } + if key.Role != "" { return key.createSTSConfig(&cfg) }