Skip to content

Commit

Permalink
TC-1793 S3 connector (guacsec#125)
Browse files Browse the repository at this point in the history
Signed-off-by: mrizzi <mrizzi@redhat.com>
  • Loading branch information
mrizzi authored Sep 18, 2024
1 parent 0a5d272 commit 691474b
Show file tree
Hide file tree
Showing 5 changed files with 208 additions and 46 deletions.
5 changes: 4 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@ require (
github.com/apparentlymart/go-textseg/v13 v13.0.0 // indirect
github.com/arangodb/go-velocypack v0.0.0-20200318135517-5af53c29c67e // indirect
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 // indirect
github.com/aws/aws-sdk-go-v2/credentials v1.17.16 // indirect
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.3 // indirect
github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.16.9 // indirect
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.7 // indirect
Expand Down Expand Up @@ -238,6 +237,9 @@ require (
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
github.com/xanzy/go-gitlab v0.93.1 // indirect
github.com/xanzy/ssh-agent v0.3.3 // indirect
github.com/xdg-go/pbkdf2 v1.0.0 // indirect
github.com/xdg-go/scram v1.1.2 // indirect
github.com/xdg-go/stringprep v1.0.4 // indirect
github.com/xrash/smetrics v0.0.0-20231213231151-1d8dd44e695e // indirect
github.com/zclconf/go-cty v1.10.0 // indirect
go.etcd.io/etcd/api/v3 v3.5.12 // indirect
Expand Down Expand Up @@ -279,6 +281,7 @@ require (
github.com/aws/aws-sdk-go v1.53.1
github.com/aws/aws-sdk-go-v2 v1.27.0
github.com/aws/aws-sdk-go-v2/config v1.27.16
github.com/aws/aws-sdk-go-v2/credentials v1.17.16
github.com/aws/aws-sdk-go-v2/service/s3 v1.53.1
github.com/aws/aws-sdk-go-v2/service/sqs v1.31.4
github.com/cdevents/sdk-go v0.3.2
Expand Down
89 changes: 50 additions & 39 deletions pkg/handler/collector/s3/bucket/bucket.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,10 @@ import (

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/guacsec/guac/pkg/logging"
"github.com/spf13/viper"
)

type BuildBucket interface {
Expand Down Expand Up @@ -58,22 +60,11 @@ func GetDefaultBucket(url string, region string) Bucket {
}

func (d *s3Bucket) ListFiles(ctx context.Context, bucket string, prefix string, token *string, max int32) ([]string, *string, error) {
cfg, err := config.LoadDefaultConfig(ctx)
client, err := d.getS3Client(ctx)
if err != nil {
return nil, nil, fmt.Errorf("error loading AWS SDK config: %w", err)
return nil, nil, fmt.Errorf("error creating S3 client: %w", err)
}

client := s3.NewFromConfig(cfg, func(o *s3.Options) {
o.UsePathStyle = true
if d.url != "" {
o.BaseEndpoint = aws.String(d.url)
}

if d.region != "" {
o.Region = d.region
}
})

input := &s3.ListObjectsV2Input{
Bucket: &bucket,
Prefix: &prefix,
Expand All @@ -97,22 +88,11 @@ func (d *s3Bucket) ListFiles(ctx context.Context, bucket string, prefix string,
}

func (d *s3Bucket) DownloadFile(ctx context.Context, bucket string, item string) ([]byte, error) {
cfg, err := config.LoadDefaultConfig(ctx)
client, err := d.getS3Client(ctx)
if err != nil {
return nil, fmt.Errorf("error loading AWS SDK config: %w", err)
return nil, fmt.Errorf("error creating S3 client: %w", err)
}

client := s3.NewFromConfig(cfg, func(o *s3.Options) {
o.UsePathStyle = true
if d.url != "" {
o.BaseEndpoint = aws.String(d.url)
}

if d.region != "" {
o.Region = d.region
}
})

// Create a GetObjectInput with the bucket name and object key.
input := &s3.GetObjectInput{
Bucket: aws.String(bucket),
Expand All @@ -121,7 +101,7 @@ func (d *s3Bucket) DownloadFile(ctx context.Context, bucket string, item string)

resp, err := client.GetObject(ctx, input)
if err != nil {
return nil, fmt.Errorf("unable to download file: %s %w", item, err)
return nil, fmt.Errorf("unable to download file: %w", err)
}
defer resp.Body.Close()

Expand All @@ -136,21 +116,11 @@ func (d *s3Bucket) DownloadFile(ctx context.Context, bucket string, item string)

func (d *s3Bucket) GetEncoding(ctx context.Context, bucket string, item string) (string, error) {
logger := logging.FromContext(ctx)
cfg, err := config.LoadDefaultConfig(ctx)
client, err := d.getS3Client(ctx)
if err != nil {
return "", fmt.Errorf("error loading AWS SDK config: %w", err)
return "", fmt.Errorf("error creating S3 client: %w", err)
}

client := s3.NewFromConfig(cfg, func(o *s3.Options) {
o.UsePathStyle = true
if d.url != "" {
o.BaseEndpoint = aws.String(d.url)
}
if d.region != "" {
o.Region = d.region
}
})

logger.Infof("Downloading document %v from bucket %v", item, bucket)

headObject, err := client.HeadObject(context.Background(), &s3.HeadObjectInput{Bucket: aws.String(bucket), Key: aws.String(item)})
Expand All @@ -164,3 +134,44 @@ func (d *s3Bucket) GetEncoding(ctx context.Context, bucket string, item string)

return *headObject.ContentEncoding, nil
}

func (d *s3Bucket) getS3Client(ctx context.Context) (*s3.Client, error) {
s3Config := &viper.Viper{}
s3Config.SetEnvKeyReplacer(strings.NewReplacer("-", "_"))
s3Config.AutomaticEnv()

accessKey := s3Config.GetString("storage-access-key")
secretKey := s3Config.GetString("storage-secret-key")
region := s3Config.GetString("storage-region")
if region == "" {
region = d.region
}

cfg, err := config.LoadDefaultConfig(ctx)

if err != nil {
return nil, fmt.Errorf("error loading AWS SDK config: %w", err)
}

return s3.NewFromConfig(cfg, func(o *s3.Options) {
o.UsePathStyle = true
if d.url != "" {
o.BaseEndpoint = aws.String(d.url)
}

if region != "" {
o.Region = region
}

if accessKey != "" && secretKey != "" {
staticProvider := credentials.NewStaticCredentialsProvider(
accessKey,
secretKey,
"",
)
o.Credentials = staticProvider
}

}), nil

}
97 changes: 96 additions & 1 deletion pkg/handler/collector/s3/messaging/kafka.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,20 @@ package messaging

import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/json"
"fmt"
"os"
"strings"
"time"

"github.com/guacsec/guac/pkg/logging"
"github.com/segmentio/kafka-go"
"github.com/segmentio/kafka-go/sasl"
"github.com/segmentio/kafka-go/sasl/plain"
"github.com/segmentio/kafka-go/sasl/scram"
"github.com/spf13/viper"
)

type KafkaProvider struct {
Expand Down Expand Up @@ -63,19 +71,106 @@ func NewKafkaProvider(mpConfig MessageProviderConfig) (KafkaProvider, error) {
kafkaTopic := mpConfig.Queue

kafkaProvider := KafkaProvider{}

kafkaConfig := &viper.Viper{}

prefix := os.Getenv("KAFKA_PROPERTIES_ENV_PREFIX")
prefix = strings.TrimSuffix(prefix, "_")

kafkaConfig.SetEnvPrefix(prefix)
kafkaConfig.SetEnvKeyReplacer(strings.NewReplacer("-", "__"))
kafkaConfig.AutomaticEnv()

mechanism, err := SASLMechanism(*kafkaConfig)
if err != nil {
return KafkaProvider{}, err
}

tlsConfig, err := TLSConfig(*kafkaConfig)
if err != nil {
return KafkaProvider{}, err
}

dialer := &kafka.Dialer{
Timeout: 10 * time.Second,
DualStack: true,
SASLMechanism: mechanism,
TLS: tlsConfig,
}

kafkaProvider.reader = kafka.NewReader(kafka.ReaderConfig{
Brokers: []string{mpConfig.Endpoint},
Topic: kafkaTopic,
Partition: 0,
Dialer: dialer,
})
err := kafkaProvider.reader.SetOffset(kafka.LastOffset)
err = kafkaProvider.reader.SetOffset(kafka.LastOffset)
if err != nil {
return KafkaProvider{}, err
}

return kafkaProvider, nil
}

func SASLMechanism(kafkaConfig viper.Viper) (sasl.Mechanism, error) {
protocol := kafkaConfig.GetString("security-protocol")
saslProtocols := make(map[string]struct{})
saslProtocols["SASL_PLAINTEXT"] = struct{}{}
saslProtocols["SASL_SSL"] = struct{}{}

_, isSasl := saslProtocols[protocol]
if !isSasl {
return nil, nil
}
mechanism := kafkaConfig.GetString("sasl-mechanism")
username := kafkaConfig.GetString("sasl-username")
password := kafkaConfig.GetString("sasl-password")

switch mechanism {
case "SCRAM-SHA-256":
return scram.Mechanism(scram.SHA256, username, password)
case "SCRAM-SHA-512":
return scram.Mechanism(scram.SHA512, username, password)
case "PLAIN":
return plain.Mechanism{
Username: username,
Password: password,
}, nil
default:
return nil, nil
}
}

func TLSConfig(kafkaConfig viper.Viper) (*tls.Config, error) {
protocol := kafkaConfig.GetString("security-protocol")
tlsProtocols := make(map[string]struct{})
tlsProtocols["SSL"] = struct{}{}
tlsProtocols["SASL_SSL"] = struct{}{}

_, isTls := tlsProtocols[protocol]
if !isTls {
return nil, nil
}
sslCaLocation := kafkaConfig.GetString("ssl-ca-location")
verifyClientCert := kafkaConfig.GetBool("enable-ssl-certificate-verification")

caFile, err := os.ReadFile(sslCaLocation)
if err != nil {
return nil, nil
}
rootCAs, _ := x509.SystemCertPool()
if rootCAs == nil {
rootCAs = x509.NewCertPool()
}
rootCAs.AppendCertsFromPEM(caFile)

return &tls.Config{
InsecureSkipVerify: !verifyClientCert,
RootCAs: rootCAs,
}, nil

}

func (k *KafkaProvider) ReceiveMessage(ctx context.Context) (Message, error) {
logger := logging.FromContext(ctx)

Expand Down
33 changes: 28 additions & 5 deletions pkg/handler/collector/s3/messaging/sqs.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,14 @@ import (
"context"
"encoding/json"
"fmt"
"strings"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/sqs"
"github.com/guacsec/guac/pkg/logging"
"github.com/spf13/viper"
)

type SqsProvider struct {
Expand Down Expand Up @@ -83,6 +86,14 @@ func NewSqsProvider(mpConfig MessageProviderConfig) (SqsProvider, error) {
sqsQueue := mpConfig.Queue
sqsProvider := SqsProvider{}

sqsConfig := &viper.Viper{}
sqsConfig.SetEnvKeyReplacer(strings.NewReplacer("-", "_"))
sqsConfig.AutomaticEnv()

accessKey := sqsConfig.GetString("sqs-access-key")
secretKey := sqsConfig.GetString("sqs-secret-key")
region := sqsConfig.GetString("sqs-region")

cfg, err := config.LoadDefaultConfig(context.TODO())
if err != nil {
return SqsProvider{}, fmt.Errorf("error loading AWS SDK config: %w", err)
Expand All @@ -92,8 +103,17 @@ func NewSqsProvider(mpConfig MessageProviderConfig) (SqsProvider, error) {
if mpConfig.Endpoint != "" {
o.BaseEndpoint = aws.String(mpConfig.Endpoint)
}
if mpConfig.Region != "" {
o.Region = mpConfig.Region
if region != "" {
o.Region = region
}

if accessKey != "" && secretKey != "" {
staticProvider := credentials.NewStaticCredentialsProvider(
accessKey,
secretKey,
"",
)
o.Credentials = staticProvider
}
})

Expand All @@ -113,7 +133,9 @@ func (s *SqsProvider) ReceiveMessage(ctx context.Context) (Message, error) {
// Get URL of queue
urlResult, err := s.client.GetQueueUrl(ctx, gQInput)
if err != nil {
return nil, fmt.Errorf("Got an error getting the queue URL : %w", err)
fmt.Println("Got an error getting the queue URL:")
fmt.Println(err)
return nil, err
}

addr := urlResult.QueueUrl
Expand All @@ -131,7 +153,8 @@ func (s *SqsProvider) ReceiveMessage(ctx context.Context) (Message, error) {
default:
receiveOutput, err := s.client.ReceiveMessage(ctx, receiveInput)
if err != nil {
return &SqsMessage{}, fmt.Errorf("error receiving message, skipping: %w", err)
fmt.Printf("error receiving message, skipping: %v\n", err)
//continue
}

messages := receiveOutput.Messages
Expand All @@ -153,7 +176,7 @@ func (s *SqsProvider) ReceiveMessage(ctx context.Context) (Message, error) {
}
_, err = s.client.DeleteMessage(context.TODO(), deleteInput)
if err != nil {
return nil, fmt.Errorf("error deleting message: %w", err)
logger.Errorf("error deleting message: %v\n", err)
}
logger.Debugf("Message deleted from the queue")

Expand Down
Loading

0 comments on commit 691474b

Please sign in to comment.