diff --git a/lib/kafkalib/connection.go b/lib/kafkalib/connection.go
new file mode 100644
index 000000000..386207d28
--- /dev/null
+++ b/lib/kafkalib/connection.go
@@ -0,0 +1,116 @@
+package kafkalib
+
+import (
+	"context"
+	"crypto/tls"
+	"fmt"
+	"time"
+
+	awsCfg "github.com/aws/aws-sdk-go-v2/config"
+	"github.com/segmentio/kafka-go"
+	"github.com/segmentio/kafka-go/sasl/aws_msk_iam_v2"
+	"github.com/segmentio/kafka-go/sasl/scram"
+)
+
+// dialTimeout bounds how long establishing a broker connection may take.
+const dialTimeout = 10 * time.Second
+
+// Mechanism names the way a Connection authenticates against Kafka.
+type Mechanism string
+
+const (
+	// Plain means no SASL mechanism is configured.
+	Plain Mechanism = "PLAIN"
+
+	// ScramSha512 means SASL/SCRAM with SHA-512, using the username and password.
+	ScramSha512 Mechanism = "SCRAM-SHA-512"
+
+	// AwsMskIam means AWS MSK IAM, using the default AWS configuration.
+	AwsMskIam Mechanism = "AWS-MSK-IAM"
+)
+
+// Connection holds the settings needed to authenticate with Kafka.
+type Connection struct {
+	enableAWSMSKIAM bool
+	disableTLS      bool
+	username        string
+	password        string
+}
+
+// NewConnection returns a Connection built from the given settings.
+func NewConnection(enableAWSMSKIAM bool, disableTLS bool, username, password string) Connection {
+	return Connection{
+		enableAWSMSKIAM: enableAWSMSKIAM,
+		disableTLS:      disableTLS,
+		username:        username,
+		password:        password,
+	}
+}
+
+// Mechanism reports which authentication mechanism the settings select.
+// SCRAM wins when both a username and a password are set, then AWS MSK IAM
+// when it is enabled; otherwise no mechanism is used.
+func (c Connection) Mechanism() Mechanism {
+	switch {
+	case c.username != "" && c.password != "":
+		return ScramSha512
+	case c.enableAWSMSKIAM:
+		return AwsMskIam
+	default:
+		return Plain
+	}
+}
+
+// Dialer returns a kafka.Dialer configured for the selected mechanism.
+// ctx is only used to load the AWS configuration for AWS MSK IAM.
+func (c Connection) Dialer(ctx context.Context) (*kafka.Dialer, error) {
+	dialer := &kafka.Dialer{
+		Timeout:   dialTimeout,
+		DualStack: true,
+	}
+
+	switch selected := c.Mechanism(); selected {
+	case ScramSha512:
+		mechanism, err := scram.Mechanism(scram.SHA512, c.username, c.password)
+		if err != nil {
+			return nil, fmt.Errorf("failed to create SCRAM mechanism: %w", err)
+		}
+
+		dialer.SASLMechanism = mechanism
+		if !c.disableTLS {
+			dialer.TLS = &tls.Config{}
+		}
+	case AwsMskIam:
+		_awsCfg, err := awsCfg.LoadDefaultConfig(ctx)
+		if err != nil {
+			return nil, fmt.Errorf("failed to load aws configuration: %w", err)
+		}
+
+		dialer.SASLMechanism = aws_msk_iam_v2.NewMechanism(_awsCfg)
+		// We don't need to disable TLS for AWS IAM since MSK will always enable TLS.
+		dialer.TLS = &tls.Config{}
+	case Plain:
+		// No mechanism
+	default:
+		return nil, fmt.Errorf("unsupported kafka mechanism: %s", selected)
+	}
+
+	return dialer, nil
+}
+
+// Transport returns a kafka.Transport with the same SASL and TLS settings as
+// Dialer, so both agree for every mechanism (TLS is always on for AWS MSK IAM).
+func (c Connection) Transport() (*kafka.Transport, error) {
+	// Transport takes no context, so the AWS configuration is loaded with a
+	// background context.
+	dialer, err := c.Dialer(context.Background())
+	if err != nil {
+		return nil, err
+	}
+
+	return &kafka.Transport{
+		DialTimeout: dialTimeout,
+		SASL:        dialer.SASLMechanism,
+		TLS:         dialer.TLS,
+	}, nil
+}
diff --git a/lib/kafkalib/connection_test.go b/lib/kafkalib/connection_test.go
new file mode 100644
index 000000000..f64dabe3f
--- /dev/null
+++ b/lib/kafkalib/connection_test.go
@@ -0,0 +1,103 @@
+package kafkalib
+
+import (
+	"context"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestConnection_Mechanism(t *testing.T) {
+	{
+		c := NewConnection(false, false, "", "")
+		assert.Equal(t, Plain, c.Mechanism())
+	}
+	{
+		c := NewConnection(false, false, "username", "password")
+		assert.Equal(t, ScramSha512, c.Mechanism())
+
+		// AWS MSK IAM is enabled, but SCRAM is preferred
+		c = NewConnection(true, false, "username", "password")
+		assert.Equal(t, ScramSha512, c.Mechanism())
+	}
+	{
+		c := NewConnection(true, false, "", "")
+		assert.Equal(t, AwsMskIam, c.Mechanism())
+	}
+}
+
+func TestConnection_Dialer(t *testing.T) {
+	ctx := context.Background()
+	{
+		// Plain
+		c := NewConnection(false, false, "", "")
+		dialer, err := c.Dialer(ctx)
+		assert.NoError(t, err)
+		assert.Nil(t, dialer.TLS)
+		assert.Nil(t, dialer.SASLMechanism)
+	}
+	{
+		// SCRAM enabled with TLS
+		c := NewConnection(false, false, "username", "password")
+		dialer, err := c.Dialer(ctx)
+		assert.NoError(t, err)
+		assert.NotNil(t, dialer.TLS)
+		assert.NotNil(t, dialer.SASLMechanism)
+
+		// w/o TLS
+		c = NewConnection(false, true, "username", "password")
+		dialer, err = c.Dialer(ctx)
+		assert.NoError(t, err)
+		assert.Nil(t, dialer.TLS)
+		assert.NotNil(t, dialer.SASLMechanism)
+	}
+	{
+		// AWS IAM w/ TLS
+		c := NewConnection(true, false, "", "")
+		dialer, err := c.Dialer(ctx)
+		assert.NoError(t, err)
+		assert.NotNil(t, dialer.TLS)
+		assert.NotNil(t, dialer.SASLMechanism)
+
+		// w/o TLS (still enabled because AWS doesn't support not having TLS)
+		c = NewConnection(true, true, "", "")
+		dialer, err = c.Dialer(ctx)
+		assert.NoError(t, err)
+		assert.NotNil(t, dialer.TLS)
+		assert.NotNil(t, dialer.SASLMechanism)
+	}
+}
+
+func TestConnection_Transport(t *testing.T) {
+	{
+		// Plain
+		c := NewConnection(false, false, "", "")
+		transport, err := c.Transport()
+		assert.NoError(t, err)
+		assert.Nil(t, transport.TLS)
+		assert.Nil(t, transport.SASL)
+	}
+	{
+		// SCRAM enabled with TLS
+		c := NewConnection(false, false, "username", "password")
+		transport, err := c.Transport()
+		assert.NoError(t, err)
+		assert.NotNil(t, transport.TLS)
+		assert.NotNil(t, transport.SASL)
+
+		// w/o TLS
+		c = NewConnection(false, true, "username", "password")
+		transport, err = c.Transport()
+		assert.NoError(t, err)
+		assert.Nil(t, transport.TLS)
+		assert.NotNil(t, transport.SASL)
+	}
+	{
+		// AWS IAM always uses TLS, matching Dialer
+		c := NewConnection(true, true, "", "")
+		transport, err := c.Transport()
+		assert.NoError(t, err)
+		assert.NotNil(t, transport.TLS)
+		assert.NotNil(t, transport.SASL)
+	}
+}
diff --git a/processes/consumer/kafka.go b/processes/consumer/kafka.go
index 1c7eecf63..2b3a810e1 100644
--- a/processes/consumer/kafka.go
+++ b/processes/consumer/kafka.go
@@ -2,16 +2,10 @@ package consumer
 
 import (
 	"context"
-	"crypto/tls"
 	"log/slog"
 	"sync"
 	"time"
 
-	awsCfg "github.com/aws/aws-sdk-go-v2/config"
-	"github.com/segmentio/kafka-go"
-	"github.com/segmentio/kafka-go/sasl/aws_msk_iam_v2"
-	"github.com/segmentio/kafka-go/sasl/scram"
-
 	"github.com/artie-labs/transfer/lib/artie"
 	"github.com/artie-labs/transfer/lib/cdc/format"
 	"github.com/artie-labs/transfer/lib/config"
@@ -21,6 +15,7 @@ import (
 	"github.com/artie-labs/transfer/lib/logger"
 	"github.com/artie-labs/transfer/lib/telemetry/metrics/base"
 	"github.com/artie-labs/transfer/models"
+	"github.com/segmentio/kafka-go"
 )
 
 var topicToConsumer *TopicToConsumer
@@ -49,37 +44,15 @@ func (t *TopicToConsumer) Get(topic string) kafkalib.Consumer {
 }
 
 func StartConsumer(ctx context.Context, cfg config.Config, inMemDB *models.DatabaseData, dest destination.Baseline, metricsClient base.Client) {
-	slog.Info("Starting Kafka consumer...", slog.Any("config", cfg.Kafka))
-	dialer := &kafka.Dialer{
-		Timeout:   10 * time.Second,
-		DualStack: true,
-	}
-
-	// If using AWS MSK IAM, we expect this to be set in the ENV VAR
-	// (AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY and AWS_REGION, or the AWS Profile should be called default.)
-	if cfg.Kafka.EnableAWSMSKIAM {
-		_awsCfg, err := awsCfg.LoadDefaultConfig(ctx)
-		if err != nil {
-			logger.Panic("Failed to load aws configuration", slog.Any("err", err))
-		}
-
-		dialer.SASLMechanism = aws_msk_iam_v2.NewMechanism(_awsCfg)
-		if !cfg.Kafka.DisableTLS {
-			dialer.TLS = &tls.Config{}
-		}
-	}
-
-	// If username and password are provided, we'll use SCRAM w/ SHA512.
-	if cfg.Kafka.Username != "" {
-		mechanism, err := scram.Mechanism(scram.SHA512, cfg.Kafka.Username, cfg.Kafka.Password)
-		if err != nil {
-			logger.Panic("Failed to create SCRAM mechanism", slog.Any("err", err))
-		}
-
-		dialer.SASLMechanism = mechanism
-		if !cfg.Kafka.DisableTLS {
-			dialer.TLS = &tls.Config{}
-		}
+	kafkaConn := kafkalib.NewConnection(cfg.Kafka.EnableAWSMSKIAM, cfg.Kafka.DisableTLS, cfg.Kafka.Username, cfg.Kafka.Password)
+	slog.Info("Starting Kafka consumer...",
+		slog.Any("config", cfg.Kafka),
+		slog.Any("authMechanism", kafkaConn.Mechanism()),
+	)
+
+	dialer, err := kafkaConn.Dialer(ctx)
+	if err != nil {
+		logger.Panic("Failed to create Kafka dialer", slog.Any("err", err))
 	}
 
 	tcFmtMap := NewTcFmtMap()