Skip to content

Commit

Permalink
feat: add support for Auto IAM AuthN (#358)
Browse files Browse the repository at this point in the history
AlloyDB uses a metadata exchange prior to handing over to the database
driver to begin the Postgres protocol. As part of this metadata
exchange, the dialer sends the IAM principal's OAuth2 token which the
proxy server uses to verify connection permissions and, if IAM AuthN is
requested, to authenticate to the database.
  • Loading branch information
enocom authored Nov 15, 2023
1 parent c7e192d commit e50dd25
Show file tree
Hide file tree
Showing 6 changed files with 319 additions and 19 deletions.
157 changes: 155 additions & 2 deletions dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ import (
"crypto/rsa"
"crypto/tls"
_ "embed"
"encoding/binary"
"errors"
"fmt"
"net"
"strings"
Expand All @@ -28,12 +30,16 @@ import (
"time"

alloydbadmin "cloud.google.com/go/alloydb/apiv1beta"
"cloud.google.com/go/alloydb/connectors/apiv1beta/connectorspb"
"cloud.google.com/go/alloydbconn/errtype"
"cloud.google.com/go/alloydbconn/internal/alloydb"
"cloud.google.com/go/alloydbconn/internal/trace"
"github.com/google/uuid"
"golang.org/x/net/proxy"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
"google.golang.org/api/option"
"google.golang.org/protobuf/proto"
)

const (
Expand All @@ -42,6 +48,9 @@ const (
defaultTCPKeepAlive = 30 * time.Second
// serverProxyPort is the port the server-side proxy receives connections on.
serverProxyPort = "5433"
// ioTimeout is the maximum amount of time to wait before aborting a
// metadata exhange
ioTimeout = 30 * time.Second
)

var (
Expand Down Expand Up @@ -86,6 +95,12 @@ type Dialer struct {
// dialFunc is the function used to connect to the address on the named
// network. By default it is golang.org/x/net/proxy#Dial.
dialFunc func(cxt context.Context, network, addr string) (net.Conn, error)

useIAMAuthN bool
iamTokenSource oauth2.TokenSource
userAgent string

buffer *buffer
}

// NewDialer creates a new Dialer.
Expand All @@ -97,16 +112,17 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) {
cfg := &dialerConfig{
refreshTimeout: alloydb.RefreshTimeout,
dialFunc: proxy.Dial,
useragents: []string{userAgent},
userAgents: []string{userAgent},
}
for _, opt := range opts {
opt(cfg)
if cfg.err != nil {
return nil, cfg.err
}
}
userAgent := strings.Join(cfg.userAgents, " ")
// Add this to the end to make sure it's not overridden
cfg.adminOpts = append(cfg.adminOpts, option.WithUserAgent(strings.Join(cfg.useragents, " ")))
cfg.adminOpts = append(cfg.adminOpts, option.WithUserAgent(userAgent))

if cfg.rsaKey == nil {
key, err := getDefaultKeys()
Expand All @@ -116,6 +132,16 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) {
cfg.rsaKey = key
}

// If no token source is configured, use ADC's token source.
ts := cfg.tokenSource
if ts == nil {
var err error
ts, err = google.DefaultTokenSource(ctx, CloudPlatformScope)
if err != nil {
return nil, err
}
}

client, err := alloydbadmin.NewAlloyDBAdminRESTClient(ctx, cfg.adminOpts...)
if err != nil {
return nil, fmt.Errorf("failed to create AlloyDB Admin API client: %v", err)
Expand All @@ -139,6 +165,10 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) {
defaultDialCfg: dialCfg,
dialerID: uuid.New().String(),
dialFunc: cfg.dialFunc,
useIAMAuthN: cfg.useIAMAuthN,
iamTokenSource: ts,
userAgent: userAgent,
buffer: newBuffer(),
}
return d, nil
}
Expand Down Expand Up @@ -212,6 +242,14 @@ func (d *Dialer) Dial(ctx context.Context, instance string, opts ...DialOption)
return nil, errtype.NewDialError("handshake failed", i.String(), err)
}

// The metadata exchange must occur after the TLS connection is established
// to avoid leaking sensitive information.
err = d.metadataExchange(tlsConn)
if err != nil {
_ = tlsConn.Close() // best effort close attempt
return nil, err
}

latency := time.Since(startTime).Milliseconds()
go func() {
n := atomic.AddUint64(&i.OpenConns, 1)
Expand All @@ -225,6 +263,121 @@ func (d *Dialer) Dial(ctx context.Context, instance string, opts ...DialOption)
}), nil
}

// metadataExchange sends metadata about the connection prior to the database
// protocol taking over. The exchange consists of four steps:
//
// 1. Prepare a MetadataExchangeRequest including the IAM Principal's OAuth2
// token, the user agent, and the requested authentication type.
//
// 2. Write the size of the message as a big endian uint32 (4 bytes) to the
// server followed by the marshaled message. The length does not include the
// initial four bytes.
//
// 3. Read a big endian uint32 (4 bytes) from the server. This is the
// MetadataExchangeResponse message length and does not include the initial
// four bytes.
//
// 4. Unmarshal the response using the message length in step 3. If the
// response is not OK, return the response's error. If there is no error, the
// metadata exchange has succeeded and the connection is complete.
//
// Subsequent interactions with the server use the database protocol.
func (d *Dialer) metadataExchange(conn net.Conn) error {
tok, err := d.iamTokenSource.Token()
if err != nil {
return err
}
authType := connectorspb.MetadataExchangeRequest_DB_NATIVE
if d.useIAMAuthN {
authType = connectorspb.MetadataExchangeRequest_AUTO_IAM
}
req := &connectorspb.MetadataExchangeRequest{
UserAgent: d.userAgent,
AuthType: authType,
Oauth2Token: tok.AccessToken,
}
m, err := proto.Marshal(req)
if err != nil {
return err
}
b := d.buffer.get()
defer d.buffer.put(b)

buf := *b
reqSize := proto.Size(req)
binary.BigEndian.PutUint32(buf, uint32(reqSize))
buf = append(buf[:4], m...)

// Set IO deadline before write
err = conn.SetDeadline(time.Now().Add(ioTimeout))
if err != nil {
return err
}
defer conn.SetDeadline(time.Time{})

_, err = conn.Write(buf)
if err != nil {
return err
}

// Reset IO deadline before read
err = conn.SetDeadline(time.Now().Add(ioTimeout))
if err != nil {
return err
}
defer conn.SetDeadline(time.Time{})

buf = buf[:4]
_, err = conn.Read(buf)
if err != nil {
return err
}

respSize := binary.BigEndian.Uint32(buf)
resp := buf[:respSize]
_, err = conn.Read(resp)
if err != nil {
return err
}

var mdxResp connectorspb.MetadataExchangeResponse
err = proto.Unmarshal(resp, &mdxResp)
if err != nil {
return err
}

if mdxResp.GetResponseCode() != connectorspb.MetadataExchangeResponse_OK {
return errors.New(mdxResp.GetError())
}

return nil
}

const maxMessageSize = 16 * 1024 // 16 kb

type buffer struct {
pool sync.Pool
}

func newBuffer() *buffer {
return &buffer{
pool: sync.Pool{
New: func() any {
buf := make([]byte, maxMessageSize)
return &buf
},
},
}
}

func (b *buffer) get() *[]byte {
return b.pool.Get().(*[]byte)
}

func (b *buffer) put(buf *[]byte) {
b.pool.Put(buf)
}

// newInstrumentedConn initializes an instrumentedConn that on closing will
// decrement the number of open connects and record the result.
func newInstrumentedConn(conn net.Conn, closeFunc func()) *instrumentedConn {
Expand Down
34 changes: 21 additions & 13 deletions dialer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package alloydbconn
import (
"context"
"errors"
"fmt"
"io"
"net"
"os"
Expand All @@ -35,7 +36,7 @@ import (
type stubTokenSource struct{}

func (stubTokenSource) Token() (*oauth2.Token, error) {
return nil, nil
return &oauth2.Token{}, nil
}

func TestDialerCanConnectToInstance(t *testing.T) {
Expand All @@ -54,7 +55,8 @@ func TestDialerCanConnectToInstance(t *testing.T) {
t.Fatalf("%v", err)
}
}()
c, err := alloydbadmin.NewAlloyDBAdminRESTClient(ctx, option.WithHTTPClient(mc), option.WithEndpoint(url))
c, err := alloydbadmin.NewAlloyDBAdminRESTClient(
ctx, option.WithHTTPClient(mc), option.WithEndpoint(url))
if err != nil {
t.Fatalf("expected NewClient to succeed, but got error: %v", err)
}
Expand All @@ -65,19 +67,25 @@ func TestDialerCanConnectToInstance(t *testing.T) {
}
d.client = c

conn, err := d.Dial(ctx, "/projects/my-project/locations/my-region/clusters/my-cluster/instances/my-instance")
if err != nil {
t.Fatalf("expected Dial to succeed, but got error: %v", err)
// Run several tests to ensure the underlying shared buffer is properly
// reset between connections.
for i := 0; i < 10; i++ {
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
conn, err := d.Dial(ctx, "projects/my-project/locations/my-region/clusters/my-cluster/instances/my-instance")
if err != nil {
t.Fatalf("expected Dial to succeed, but got error: %v", err)
}
defer conn.Close()
data, err := io.ReadAll(conn)
if err != nil {
t.Fatalf("expected ReadAll to succeed, got error %v", err)
}
if string(data) != "my-instance" {
t.Fatalf("expected known response from the server, but got %v", string(data))
}
})
}
defer conn.Close()

data, err := io.ReadAll(conn)
if err != nil {
t.Fatalf("expected ReadAll to succeed, got error %v", err)
}
if string(data) != "my-instance" {
t.Fatalf("expected known response from the server, but got %v", string(data))
}
}

func TestDialWithAdminAPIErrors(t *testing.T) {
Expand Down
49 changes: 49 additions & 0 deletions e2e_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,14 @@ import (
"context"
"database/sql"
"fmt"
"net"
"os"
"testing"
"time"

"cloud.google.com/go/alloydbconn"
"cloud.google.com/go/alloydbconn/driver/pgxv4"
"github.com/jackc/pgx/v4"
)

var (
Expand All @@ -48,6 +51,8 @@ func requireAlloyDBVars(t *testing.T) {
t.Fatal("'ALLOYDB_INSTANCE_NAME' env var not set")
case alloydbUser:
t.Fatal("'ALLOYDB_USER' env var not set")
case alloydbIAMUser:
t.Fatal("'ALLOYDB_IAM_USER' env var not set")
case alloydbPass:
t.Fatal("'ALLOYDB_PASS' env var not set")
case alloydbDB:
Expand Down Expand Up @@ -75,6 +80,13 @@ func TestPgxConnect(t *testing.T) {
// best effort
_ = cleanup()
}()

var now time.Time
err = pool.QueryRow(context.Background(), "SELECT NOW()").Scan(&now)
if err != nil {
t.Fatalf("QueryRow failed: %s", err)
}
t.Log(now)
}

// TestDatabaseSQLConnect uses the latest pgx driver under the hood
Expand Down Expand Up @@ -178,3 +190,40 @@ func TestDirectPGXAutoIAMAuthN(t *testing.T) {
}
t.Log(tt)
}

func TestAutoIAMAuthN(t *testing.T) {
if testing.Short() {
t.Skip("skipping integration test")
}
ctx := context.Background()

d, err := alloydbconn.NewDialer(ctx, alloydbconn.WithIAMAuthN())
if err != nil {
t.Fatalf("failed to init Dialer: %v", err)
}

dsn := fmt.Sprintf(
"user=%s dbname=%s sslmode=disable",
alloydbIAMUser, alloydbDB,
)
config, err := pgx.ParseConfig(dsn)
if err != nil {
t.Fatalf("failed to parse pgx config: %v", err)
}

config.DialFunc = func(ctx context.Context, network string, instance string) (net.Conn, error) {
return d.Dial(ctx, alloydbInstanceName)
}

conn, connErr := pgx.ConnectConfig(ctx, config)
if connErr != nil {
t.Fatalf("failed to connect: %s", connErr)
}
defer conn.Close(ctx)

var tt time.Time
if err := conn.QueryRow(context.Background(), "SELECT NOW()").Scan(&tt); err != nil {
t.Fatal(err)
}
t.Log(tt)
}
5 changes: 3 additions & 2 deletions internal/alloydb/refresh.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,9 @@ func fetchEphemeralCert(
Parent: fmt.Sprintf(
"projects/%s/locations/%s/clusters/%s", inst.project, inst.region, inst.cluster,
),
PublicKey: buf.String(),
CertDuration: durationpb.New(time.Second * 3600),
PublicKey: buf.String(),
CertDuration: durationpb.New(time.Second * 3600),
UseMetadataExchange: true,
}
resp, err := cl.GenerateClientCertificate(ctx, req)
if err != nil {
Expand Down
Loading

0 comments on commit e50dd25

Please sign in to comment.