Skip to content

Commit

Permalink
[Kafka] Refactor the way we handle Kafka connections (#789)
Browse files Browse the repository at this point in the history
  • Loading branch information
Tang8330 authored Jul 11, 2024
1 parent eea982d commit 4036e76
Show file tree
Hide file tree
Showing 3 changed files with 198 additions and 37 deletions.
119 changes: 119 additions & 0 deletions lib/kafkalib/connection.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
package kafkalib

import (
"context"
"crypto/tls"
"fmt"
"time"

awsCfg "github.com/aws/aws-sdk-go-v2/config"
"github.com/segmentio/kafka-go"
"github.com/segmentio/kafka-go/sasl/aws_msk_iam_v2"
"github.com/segmentio/kafka-go/sasl/scram"
)

// Mechanism names the authentication mechanism used when connecting to Kafka.
type Mechanism string

// Values that Connection.Mechanism can return.
const (
// Plain is the fallback used when neither SCRAM credentials nor AWS MSK IAM
// are configured; no SASL mechanism is attached for it.
Plain Mechanism = "PLAIN"
// ScramSha512 is used when both a username and a password are provided.
ScramSha512 Mechanism = "SCRAM-SHA-512"
// AwsMskIam is used when AWS MSK IAM is enabled and SCRAM credentials are
// not both provided.
AwsMskIam Mechanism = "AWS-MSK-IAM"
)

// Connection holds the Kafka authentication and TLS settings from which a
// kafka.Dialer or kafka.Transport is built.
type Connection struct {
// enableAWSMSKIAM requests AWS MSK IAM authentication; it only takes effect
// when username and password are not both set (see Mechanism).
enableAWSMSKIAM bool
// disableTLS, when true, leaves TLS unconfigured for SCRAM connections.
// See Dialer and Transport for how the AWS MSK IAM branch treats it.
disableTLS bool
// username and password are the SCRAM credentials; both must be non-empty
// for SCRAM-SHA-512 to be selected.
username string
password string
}

// NewConnection returns a Connection carrying the supplied settings verbatim.
// No validation is performed here; the effective authentication mechanism is
// resolved later by Connection.Mechanism.
func NewConnection(enableAWSMSKIAM bool, disableTLS bool, username, password string) Connection {
	var conn Connection
	conn.username, conn.password = username, password
	conn.enableAWSMSKIAM = enableAWSMSKIAM
	conn.disableTLS = disableTLS
	return conn
}

// Mechanism reports which authentication mechanism this connection will use.
//
// Precedence: SCRAM-SHA-512 when both username and password are non-empty,
// then AWS MSK IAM when it is enabled, otherwise Plain.
func (c Connection) Mechanism() Mechanism {
	hasCredentials := c.username != "" && c.password != ""

	switch {
	case hasCredentials:
		// SCRAM wins even if AWS MSK IAM is also enabled.
		return ScramSha512
	case c.enableAWSMSKIAM:
		return AwsMskIam
	default:
		return Plain
	}
}

// Dialer builds a kafka.Dialer configured for this connection's mechanism.
//
// ctx is used only when loading the default AWS configuration for the
// AWS MSK IAM mechanism. A non-nil error is returned if the SCRAM mechanism
// cannot be created, the AWS configuration cannot be loaded, or the resolved
// mechanism is not one this function knows how to configure.
func (c Connection) Dialer(ctx context.Context) (*kafka.Dialer, error) {
	d := &kafka.Dialer{
		Timeout:   10 * time.Second,
		DualStack: true,
	}

	selected := c.Mechanism()
	switch selected {
	case Plain:
		// No SASL mechanism and no TLS.
		return d, nil
	case ScramSha512:
		sasl, err := scram.Mechanism(scram.SHA512, c.username, c.password)
		if err != nil {
			return nil, fmt.Errorf("failed to create SCRAM mechanism: %w", err)
		}

		d.SASLMechanism = sasl
		if c.disableTLS {
			return d, nil
		}

		d.TLS = &tls.Config{}
		return d, nil
	case AwsMskIam:
		loaded, err := awsCfg.LoadDefaultConfig(ctx)
		if err != nil {
			return nil, fmt.Errorf("failed to load aws configuration: %w", err)
		}

		d.SASLMechanism = aws_msk_iam_v2.NewMechanism(loaded)
		// disableTLS is intentionally ignored here: MSK will always enable TLS for IAM.
		d.TLS = &tls.Config{}
		return d, nil
	}

	return nil, fmt.Errorf("unsupported kafka mechanism: %s", selected)
}

// Transport builds a kafka.Transport configured for this connection's
// mechanism, mirroring the SASL/TLS choices made by Dialer.
//
// A non-nil error is returned if the SCRAM mechanism cannot be created, the
// AWS configuration cannot be loaded, or the resolved mechanism is not one
// this function knows how to configure.
func (c Connection) Transport() (*kafka.Transport, error) {
	transport := &kafka.Transport{
		DialTimeout: 10 * time.Second,
	}

	selected := c.Mechanism()
	switch selected {
	case ScramSha512:
		sasl, err := scram.Mechanism(scram.SHA512, c.username, c.password)
		if err != nil {
			return nil, fmt.Errorf("failed to create SCRAM mechanism: %w", err)
		}

		transport.SASL = sasl
		if !c.disableTLS {
			transport.TLS = &tls.Config{}
		}
	case AwsMskIam:
		// This method takes no context, so the AWS configuration is loaded
		// with a background context.
		_awsCfg, err := awsCfg.LoadDefaultConfig(context.Background())
		if err != nil {
			return nil, fmt.Errorf("failed to load AWS configuration: %w", err)
		}

		transport.SASL = aws_msk_iam_v2.NewMechanism(_awsCfg)
		// Always enable TLS for AWS IAM, matching Dialer: MSK will always
		// enable TLS, so honouring disableTLS here would only produce a
		// transport that cannot connect.
		transport.TLS = &tls.Config{}
	case Plain:
		// No mechanism
	default:
		return nil, fmt.Errorf("unsupported kafka mechanism: %s", selected)
	}

	return transport, nil
}
69 changes: 69 additions & 0 deletions lib/kafkalib/connection_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package kafkalib

import (
"context"
"testing"

"github.com/stretchr/testify/assert"
)

func TestConnection_Mechanism(t *testing.T) {
{
c := NewConnection(false, false, "", "")
assert.Equal(t, Plain, c.Mechanism())
}
{
c := NewConnection(false, false, "username", "password")
assert.Equal(t, ScramSha512, c.Mechanism())

// AWS MSK IAM is enabled, but SCRAM is preferred
c = NewConnection(true, false, "username", "password")
assert.Equal(t, ScramSha512, c.Mechanism())
}
{
c := NewConnection(true, false, "", "")
assert.Equal(t, AwsMskIam, c.Mechanism())
}
}

func TestConnection_Dialer(t *testing.T) {
ctx := context.Background()
{
// Plain
c := NewConnection(false, false, "", "")
dialer, err := c.Dialer(ctx)
assert.NoError(t, err)
assert.Nil(t, dialer.TLS)
assert.Nil(t, dialer.SASLMechanism)
}
{
// SCRAM enabled with TLS
c := NewConnection(false, false, "username", "password")
dialer, err := c.Dialer(ctx)
assert.NoError(t, err)
assert.NotNil(t, dialer.TLS)
assert.NotNil(t, dialer.SASLMechanism)

// w/o TLS
c = NewConnection(false, true, "username", "password")
dialer, err = c.Dialer(ctx)
assert.NoError(t, err)
assert.Nil(t, dialer.TLS)
assert.NotNil(t, dialer.SASLMechanism)
}
{
// AWS IAM w/ TLS
c := NewConnection(true, false, "", "")
dialer, err := c.Dialer(ctx)
assert.NoError(t, err)
assert.NotNil(t, dialer.TLS)
assert.NotNil(t, dialer.SASLMechanism)

// w/o TLS (still enabled because AWS doesn't support not having TLS)
c = NewConnection(true, true, "", "")
dialer, err = c.Dialer(ctx)
assert.NoError(t, err)
assert.NotNil(t, dialer.TLS)
assert.NotNil(t, dialer.SASLMechanism)
}
}
47 changes: 10 additions & 37 deletions processes/consumer/kafka.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,10 @@ package consumer

import (
"context"
"crypto/tls"
"log/slog"
"sync"
"time"

awsCfg "github.com/aws/aws-sdk-go-v2/config"
"github.com/segmentio/kafka-go"
"github.com/segmentio/kafka-go/sasl/aws_msk_iam_v2"
"github.com/segmentio/kafka-go/sasl/scram"

"github.com/artie-labs/transfer/lib/artie"
"github.com/artie-labs/transfer/lib/cdc/format"
"github.com/artie-labs/transfer/lib/config"
Expand All @@ -21,6 +15,7 @@ import (
"github.com/artie-labs/transfer/lib/logger"
"github.com/artie-labs/transfer/lib/telemetry/metrics/base"
"github.com/artie-labs/transfer/models"
"github.com/segmentio/kafka-go"
)

var topicToConsumer *TopicToConsumer
Expand Down Expand Up @@ -49,37 +44,15 @@ func (t *TopicToConsumer) Get(topic string) kafkalib.Consumer {
}

func StartConsumer(ctx context.Context, cfg config.Config, inMemDB *models.DatabaseData, dest destination.Baseline, metricsClient base.Client) {
slog.Info("Starting Kafka consumer...", slog.Any("config", cfg.Kafka))
dialer := &kafka.Dialer{
Timeout: 10 * time.Second,
DualStack: true,
}

// If using AWS MSK IAM, we expect this to be set in the ENV VAR
// (AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY and AWS_REGION, or the AWS Profile should be called default.)
if cfg.Kafka.EnableAWSMSKIAM {
_awsCfg, err := awsCfg.LoadDefaultConfig(ctx)
if err != nil {
logger.Panic("Failed to load aws configuration", slog.Any("err", err))
}

dialer.SASLMechanism = aws_msk_iam_v2.NewMechanism(_awsCfg)
if !cfg.Kafka.DisableTLS {
dialer.TLS = &tls.Config{}
}
}

// If username and password are provided, we'll use SCRAM w/ SHA512.
if cfg.Kafka.Username != "" {
mechanism, err := scram.Mechanism(scram.SHA512, cfg.Kafka.Username, cfg.Kafka.Password)
if err != nil {
logger.Panic("Failed to create SCRAM mechanism", slog.Any("err", err))
}

dialer.SASLMechanism = mechanism
if !cfg.Kafka.DisableTLS {
dialer.TLS = &tls.Config{}
}
kafkaConn := kafkalib.NewConnection(cfg.Kafka.EnableAWSMSKIAM, cfg.Kafka.DisableTLS, cfg.Kafka.Username, cfg.Kafka.Password)
slog.Info("Starting Kafka consumer...",
slog.Any("config", cfg.Kafka),
slog.Any("authMechanism", kafkaConn.Mechanism()),
)

dialer, err := kafkaConn.Dialer(ctx)
if err != nil {
logger.Panic("Failed to create Kafka dialer", slog.Any("err", err))
}

tcFmtMap := NewTcFmtMap()
Expand Down

0 comments on commit 4036e76

Please sign in to comment.