From 691474bae6b6b335da7f41607b76694a795dda4a Mon Sep 17 00:00:00 2001 From: Marco Rizzi Date: Wed, 18 Sep 2024 16:36:01 +0200 Subject: [PATCH] TC-1793 S3 connector (#125) Signed-off-by: mrizzi --- go.mod | 5 +- pkg/handler/collector/s3/bucket/bucket.go | 89 ++++++++++--------- pkg/handler/collector/s3/messaging/kafka.go | 97 ++++++++++++++++++++- pkg/handler/collector/s3/messaging/sqs.go | 33 +++++-- pkg/handler/collector/s3/s3.go | 30 +++++++ 5 files changed, 208 insertions(+), 46 deletions(-) diff --git a/go.mod b/go.mod index fe1625f0eb..e398d2e6d3 100644 --- a/go.mod +++ b/go.mod @@ -76,7 +76,6 @@ require ( github.com/apparentlymart/go-textseg/v13 v13.0.0 // indirect github.com/arangodb/go-velocypack v0.0.0-20200318135517-5af53c29c67e // indirect github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 // indirect - github.com/aws/aws-sdk-go-v2/credentials v1.17.16 // indirect github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.3 // indirect github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.16.9 // indirect github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.7 // indirect @@ -238,6 +237,9 @@ require ( github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect github.com/xanzy/go-gitlab v0.93.1 // indirect github.com/xanzy/ssh-agent v0.3.3 // indirect + github.com/xdg-go/pbkdf2 v1.0.0 // indirect + github.com/xdg-go/scram v1.1.2 // indirect + github.com/xdg-go/stringprep v1.0.4 // indirect github.com/xrash/smetrics v0.0.0-20231213231151-1d8dd44e695e // indirect github.com/zclconf/go-cty v1.10.0 // indirect go.etcd.io/etcd/api/v3 v3.5.12 // indirect @@ -279,6 +281,7 @@ require ( github.com/aws/aws-sdk-go v1.53.1 github.com/aws/aws-sdk-go-v2 v1.27.0 github.com/aws/aws-sdk-go-v2/config v1.27.16 + github.com/aws/aws-sdk-go-v2/credentials v1.17.16 github.com/aws/aws-sdk-go-v2/service/s3 v1.53.1 github.com/aws/aws-sdk-go-v2/service/sqs v1.31.4 github.com/cdevents/sdk-go v0.3.2 diff --git a/pkg/handler/collector/s3/bucket/bucket.go b/pkg/handler/collector/s3/bucket/bucket.go index c58996058b..23f7c9e11e 100644 --- a/pkg/handler/collector/s3/bucket/bucket.go +++ b/pkg/handler/collector/s3/bucket/bucket.go @@ -24,8 +24,10 @@ import ( "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/credentials" "github.com/aws/aws-sdk-go-v2/service/s3" "github.com/guacsec/guac/pkg/logging" + "github.com/spf13/viper" ) type BuildBucket interface { @@ -58,22 +60,11 @@ func GetDefaultBucket(url string, region string) Bucket { } func (d *s3Bucket) ListFiles(ctx context.Context, bucket string, prefix string, token *string, max int32) ([]string, *string, error) { - cfg, err := config.LoadDefaultConfig(ctx) + client, err := d.getS3Client(ctx) if err != nil { - return nil, nil, fmt.Errorf("error loading AWS SDK config: %w", err) + return nil, nil, fmt.Errorf("error creating S3 client: %w", err) } - client := s3.NewFromConfig(cfg, func(o *s3.Options) { - o.UsePathStyle = true - if d.url != "" { - o.BaseEndpoint = aws.String(d.url) - } - - if d.region != "" { - o.Region = d.region - } - }) - input := &s3.ListObjectsV2Input{ Bucket: &bucket, Prefix: &prefix, @@ -97,22 +88,11 @@ func (d *s3Bucket) ListFiles(ctx context.Context, bucket string, prefix string, } func (d *s3Bucket) DownloadFile(ctx context.Context, bucket string, item string) ([]byte, error) { - cfg, err := config.LoadDefaultConfig(ctx) + client, err := d.getS3Client(ctx) if err != nil { - return nil, fmt.Errorf("error loading AWS SDK config: %w", err) + return nil, fmt.Errorf("error creating S3 client: %w", err) } - client := s3.NewFromConfig(cfg, func(o *s3.Options) { - o.UsePathStyle = true - if d.url != "" { - o.BaseEndpoint = aws.String(d.url) - } - - if d.region != "" { - o.Region = d.region - } - }) - // Create a GetObjectInput with the bucket name and object key. input := &s3.GetObjectInput{ Bucket: aws.String(bucket), @@ -121,7 +101,7 @@ func (d *s3Bucket) DownloadFile(ctx context.Context, bucket string, item string) resp, err := client.GetObject(ctx, input) if err != nil { - return nil, fmt.Errorf("unable to download file: %s %w", item, err) + return nil, fmt.Errorf("unable to download file: %w", err) } defer resp.Body.Close() @@ -136,21 +116,11 @@ func (d *s3Bucket) DownloadFile(ctx context.Context, bucket string, item string) func (d *s3Bucket) GetEncoding(ctx context.Context, bucket string, item string) (string, error) { logger := logging.FromContext(ctx) - cfg, err := config.LoadDefaultConfig(ctx) + client, err := d.getS3Client(ctx) if err != nil { - return "", fmt.Errorf("error loading AWS SDK config: %w", err) + return "", fmt.Errorf("error creating S3 client: %w", err) } - client := s3.NewFromConfig(cfg, func(o *s3.Options) { - o.UsePathStyle = true - if d.url != "" { - o.BaseEndpoint = aws.String(d.url) - } - if d.region != "" { - o.Region = d.region - } - }) - logger.Infof("Downloading document %v from bucket %v", item, bucket) headObject, err := client.HeadObject(context.Background(), &s3.HeadObjectInput{Bucket: aws.String(bucket), Key: aws.String(item)}) @@ -164,3 +134,44 @@ func (d *s3Bucket) GetEncoding(ctx context.Context, bucket string, item string) return *headObject.ContentEncoding, nil } + +func (d *s3Bucket) getS3Client(ctx context.Context) (*s3.Client, error) { + s3Config := &viper.Viper{} + s3Config.SetEnvKeyReplacer(strings.NewReplacer("-", "_")) + s3Config.AutomaticEnv() + + accessKey := s3Config.GetString("storage-access-key") + secretKey := s3Config.GetString("storage-secret-key") + region := s3Config.GetString("storage-region") + if region == "" { + region = d.region + } + + cfg, err := config.LoadDefaultConfig(ctx) + + if err != nil { + return nil, fmt.Errorf("error loading AWS SDK config: %w", err) + } + + return s3.NewFromConfig(cfg, func(o *s3.Options) { + o.UsePathStyle = true + if d.url != "" { + o.BaseEndpoint = aws.String(d.url) + } + + if region != "" { + o.Region = region + } + + if accessKey != "" && secretKey != "" { + staticProvider := credentials.NewStaticCredentialsProvider( + accessKey, + secretKey, + "", + ) + o.Credentials = staticProvider + } + + }), nil + +} diff --git a/pkg/handler/collector/s3/messaging/kafka.go b/pkg/handler/collector/s3/messaging/kafka.go index d122b7ab69..a37c2675fc 100644 --- a/pkg/handler/collector/s3/messaging/kafka.go +++ b/pkg/handler/collector/s3/messaging/kafka.go @@ -17,12 +17,20 @@ package messaging import ( "context" + "crypto/tls" + "crypto/x509" "encoding/json" "fmt" + "os" "strings" + "time" "github.com/guacsec/guac/pkg/logging" "github.com/segmentio/kafka-go" + "github.com/segmentio/kafka-go/sasl" + "github.com/segmentio/kafka-go/sasl/plain" + "github.com/segmentio/kafka-go/sasl/scram" + "github.com/spf13/viper" ) type KafkaProvider struct { @@ -63,12 +71,40 @@ func NewKafkaProvider(mpConfig MessageProviderConfig) (KafkaProvider, error) { kafkaTopic := mpConfig.Queue kafkaProvider := KafkaProvider{} + + kafkaConfig := &viper.Viper{} + + prefix := os.Getenv("KAFKA_PROPERTIES_ENV_PREFIX") + prefix = strings.TrimSuffix(prefix, "_") + + kafkaConfig.SetEnvPrefix(prefix) + kafkaConfig.SetEnvKeyReplacer(strings.NewReplacer("-", "__")) + kafkaConfig.AutomaticEnv() + + mechanism, err := SASLMechanism(*kafkaConfig) + if err != nil { + return KafkaProvider{}, err + } + + tlsConfig, err := TLSConfig(*kafkaConfig) + if err != nil { + return KafkaProvider{}, err + } + + dialer := &kafka.Dialer{ + Timeout: 10 * time.Second, + DualStack: true, + SASLMechanism: mechanism, + TLS: tlsConfig, + } + kafkaProvider.reader = kafka.NewReader(kafka.ReaderConfig{ Brokers: []string{mpConfig.Endpoint}, Topic: kafkaTopic, Partition: 0, + Dialer: dialer, }) - err := kafkaProvider.reader.SetOffset(kafka.LastOffset) + err = kafkaProvider.reader.SetOffset(kafka.LastOffset) if err != nil { return KafkaProvider{}, err } @@ -76,6 +112,65 @@ func NewKafkaProvider(mpConfig MessageProviderConfig) (KafkaProvider, error) { return kafkaProvider, nil } +func SASLMechanism(kafkaConfig viper.Viper) (sasl.Mechanism, error) { + protocol := kafkaConfig.GetString("security-protocol") + saslProtocols := make(map[string]struct{}) + saslProtocols["SASL_PLAINTEXT"] = struct{}{} + saslProtocols["SASL_SSL"] = struct{}{} + + _, isSasl := saslProtocols[protocol] + if !isSasl { + return nil, nil + } + mechanism := kafkaConfig.GetString("sasl-mechanism") + username := kafkaConfig.GetString("sasl-username") + password := kafkaConfig.GetString("sasl-password") + + switch mechanism { + case "SCRAM-SHA-256": + return scram.Mechanism(scram.SHA256, username, password) + case "SCRAM-SHA-512": + return scram.Mechanism(scram.SHA512, username, password) + case "PLAIN": + return plain.Mechanism{ + Username: username, + Password: password, + }, nil + default: + return nil, nil + } +} + +func TLSConfig(kafkaConfig viper.Viper) (*tls.Config, error) { + protocol := kafkaConfig.GetString("security-protocol") + tlsProtocols := make(map[string]struct{}) + tlsProtocols["SSL"] = struct{}{} + tlsProtocols["SASL_SSL"] = struct{}{} + + _, isTls := tlsProtocols[protocol] + if !isTls { + return nil, nil + } + sslCaLocation := kafkaConfig.GetString("ssl-ca-location") + verifyClientCert := kafkaConfig.GetBool("enable-ssl-certificate-verification") + + caFile, err := os.ReadFile(sslCaLocation) + if err != nil { + return nil, nil + } + rootCAs, _ := x509.SystemCertPool() + if rootCAs == nil { + rootCAs = x509.NewCertPool() + } + rootCAs.AppendCertsFromPEM(caFile) + + return &tls.Config{ + InsecureSkipVerify: !verifyClientCert, + RootCAs: rootCAs, + }, nil + +} + func (k *KafkaProvider) ReceiveMessage(ctx context.Context) (Message, error) { logger := logging.FromContext(ctx) diff --git a/pkg/handler/collector/s3/messaging/sqs.go b/pkg/handler/collector/s3/messaging/sqs.go index 07e716839e..745c88ce01 100644 --- a/pkg/handler/collector/s3/messaging/sqs.go +++ b/pkg/handler/collector/s3/messaging/sqs.go @@ -19,11 +19,14 @@ import ( "context" "encoding/json" "fmt" + "strings" "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/credentials" "github.com/aws/aws-sdk-go-v2/service/sqs" "github.com/guacsec/guac/pkg/logging" + "github.com/spf13/viper" ) type SqsProvider struct { @@ -83,6 +86,14 @@ func NewSqsProvider(mpConfig MessageProviderConfig) (SqsProvider, error) { sqsQueue := mpConfig.Queue sqsProvider := SqsProvider{} + sqsConfig := &viper.Viper{} + sqsConfig.SetEnvKeyReplacer(strings.NewReplacer("-", "_")) + sqsConfig.AutomaticEnv() + + accessKey := sqsConfig.GetString("sqs-access-key") + secretKey := sqsConfig.GetString("sqs-secret-key") + region := sqsConfig.GetString("sqs-region") + cfg, err := config.LoadDefaultConfig(context.TODO()) if err != nil { return SqsProvider{}, fmt.Errorf("error loading AWS SDK config: %w", err) @@ -92,8 +103,17 @@ func NewSqsProvider(mpConfig MessageProviderConfig) (SqsProvider, error) { if mpConfig.Endpoint != "" { o.BaseEndpoint = aws.String(mpConfig.Endpoint) } - if mpConfig.Region != "" { - o.Region = mpConfig.Region + if region != "" { + o.Region = region + } + + if accessKey != "" && secretKey != "" { + staticProvider := credentials.NewStaticCredentialsProvider( + accessKey, + secretKey, + "", + ) + o.Credentials = staticProvider } }) @@ -113,7 +133,9 @@ func (s *SqsProvider) ReceiveMessage(ctx context.Context) (Message, error) { // Get URL of queue urlResult, err := s.client.GetQueueUrl(ctx, gQInput) if err != nil { - return nil, fmt.Errorf("Got an error getting the queue URL : %w", err) + fmt.Println("Got an error getting the queue URL:") + fmt.Println(err) + return nil, err } addr := urlResult.QueueUrl @@ -131,7 +153,8 @@ func (s *SqsProvider) ReceiveMessage(ctx context.Context) (Message, error) { default: receiveOutput, err := s.client.ReceiveMessage(ctx, receiveInput) if err != nil { - return &SqsMessage{}, fmt.Errorf("error receiving message, skipping: %w", err) + fmt.Printf("error receiving message, skipping: %v\n", err) + //continue } messages := receiveOutput.Messages @@ -153,7 +176,7 @@ func (s *SqsProvider) ReceiveMessage(ctx context.Context) (Message, error) { } _, err = s.client.DeleteMessage(context.TODO(), deleteInput) if err != nil { - return nil, fmt.Errorf("error deleting message: %w", err) + logger.Errorf("error deleting message: %v\n", err) } logger.Debugf("Message deleted from the queue") diff --git a/pkg/handler/collector/s3/s3.go b/pkg/handler/collector/s3/s3.go index b291a5f940..293cad7515 100644 --- a/pkg/handler/collector/s3/s3.go +++ b/pkg/handler/collector/s3/s3.go @@ -18,6 +18,7 @@ package s3 import ( "context" "fmt" + "os" "strings" "sync" @@ -42,11 +43,13 @@ type S3CollectorConfig struct { S3Url string // optional (uses aws sdk defaults) S3Bucket string // bucket name to collect from S3Path string // optional (only for non-polling) s3 folder path to collect from + Limit int // optional max number of files to download from the bucket S3Item string // optional (only for non-polling behaviour) S3Region string // optional (defaults to us-east-1, assumes same region for s3 and sqs) Queues string // optional (comma-separated list of queues/topics) MpBuilder messaging.MessageProviderBuilder // optional BucketBuilder bucket.BuildBucket // optional + SigChan chan os.Signal // optional Poll bool } @@ -58,6 +61,7 @@ func NewS3Collector(cfg S3CollectorConfig) *S3Collector { } func (s *S3Collector) RetrieveArtifacts(ctx context.Context, docChannel chan<- *processor.Document) error { + if s.config.Poll { retrieveWithPoll(*s, ctx, docChannel) } else { @@ -100,6 +104,7 @@ func retrieve(s S3Collector, ctx context.Context, docChannel chan<- *processor.D } else { var token *string const MaxKeys = 100 + var total = 0 for { files, t, err := downloader.ListFiles(ctx, s.config.S3Bucket, s.config.S3Path, token, MaxKeys) if err != nil { @@ -109,12 +114,25 @@ func retrieve(s S3Collector, ctx context.Context, docChannel chan<- *processor.D token = t for _, item := range files { + logger.Infof("Processing %v", item) + + if !strings.HasPrefix(item, "data/") { + logger.Infof("Skipping non-data file") + continue + } + blob, err := downloader.DownloadFile(ctx, s.config.S3Bucket, item) if err != nil { logger.Errorf("could not download item %v, skipping: %v", item, err) continue } + // TODO make this configurable + // if len(blob) > 6291456 { + // logger.Infof("Skipping %s due to its size %d", item, len(blob)) + // continue + // } + enc, err := downloader.GetEncoding(ctx, s.config.S3Bucket, item) if err != nil { logger.Errorf("could not get encoding for item %v, skipping: %v", item, err) @@ -132,7 +150,16 @@ func retrieve(s S3Collector, ctx context.Context, docChannel chan<- *processor.D DocumentRef: events.GetDocRef(blob), }, } + + logger.Infof("Ingesting item %s of size %d", item, len(blob)) + docChannel <- doc + + total += 1 + if s.config.Limit > 0 && total >= s.config.Limit { + logger.Infof("Configured limit of %d reached. Exiting.", s.config.Limit) + break + } } if len(files) < MaxKeys { @@ -245,6 +272,9 @@ func getMessageProvider(s S3Collector, queue string) (messaging.MessageProvider, mpBuilder = s.config.MpBuilder } else { mpBuilder = messaging.GetDefaultMessageProviderBuilder() + if err != nil { + return nil, fmt.Errorf("error getting message provider: %w", err) + } } mp, err := mpBuilder.GetMessageProvider(messaging.MessageProviderConfig{