diff --git a/dialer.go b/dialer.go index c491f34e..f044d1f8 100644 --- a/dialer.go +++ b/dialer.go @@ -20,6 +20,8 @@ import ( "crypto/rsa" "crypto/tls" _ "embed" + "encoding/binary" + "errors" "fmt" "net" "strings" @@ -28,12 +30,16 @@ import ( "time" alloydbadmin "cloud.google.com/go/alloydb/apiv1beta" + "cloud.google.com/go/alloydb/connectors/apiv1beta/connectorspb" "cloud.google.com/go/alloydbconn/errtype" "cloud.google.com/go/alloydbconn/internal/alloydb" "cloud.google.com/go/alloydbconn/internal/trace" "github.com/google/uuid" "golang.org/x/net/proxy" + "golang.org/x/oauth2" + "golang.org/x/oauth2/google" "google.golang.org/api/option" + "google.golang.org/protobuf/proto" ) const ( @@ -42,6 +48,9 @@ const ( defaultTCPKeepAlive = 30 * time.Second // serverProxyPort is the port the server-side proxy receives connections on. serverProxyPort = "5433" + // ioTimeout is the maximum amount of time to wait before aborting a + // metadata exhange + ioTimeout = 30 * time.Second ) var ( @@ -86,6 +95,12 @@ type Dialer struct { // dialFunc is the function used to connect to the address on the named // network. By default it is golang.org/x/net/proxy#Dial. dialFunc func(cxt context.Context, network, addr string) (net.Conn, error) + + useIAMAuthN bool + iamTokenSource oauth2.TokenSource + userAgent string + + buffer *buffer } // NewDialer creates a new Dialer. @@ -97,7 +112,7 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) { cfg := &dialerConfig{ refreshTimeout: alloydb.RefreshTimeout, dialFunc: proxy.Dial, - useragents: []string{userAgent}, + userAgents: []string{userAgent}, } for _, opt := range opts { opt(cfg) @@ -105,8 +120,9 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) { return nil, cfg.err } } + userAgent := strings.Join(cfg.userAgents, " ") // Add this to the end to make sure it's not overridden - cfg.adminOpts = append(cfg.adminOpts, option.WithUserAgent(strings.Join(cfg.useragents, " "))) + cfg.adminOpts = append(cfg.adminOpts, option.WithUserAgent(userAgent)) if cfg.rsaKey == nil { key, err := getDefaultKeys() @@ -116,6 +132,16 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) { cfg.rsaKey = key } + // If no token source is configured, use ADC's token source. + ts := cfg.tokenSource + if ts == nil { + var err error + ts, err = google.DefaultTokenSource(ctx, CloudPlatformScope) + if err != nil { + return nil, err + } + } + client, err := alloydbadmin.NewAlloyDBAdminRESTClient(ctx, cfg.adminOpts...) if err != nil { return nil, fmt.Errorf("failed to create AlloyDB Admin API client: %v", err) @@ -139,6 +165,10 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) { defaultDialCfg: dialCfg, dialerID: uuid.New().String(), dialFunc: cfg.dialFunc, + useIAMAuthN: cfg.useIAMAuthN, + iamTokenSource: ts, + userAgent: userAgent, + buffer: newBuffer(), } return d, nil } @@ -212,6 +242,14 @@ func (d *Dialer) Dial(ctx context.Context, instance string, opts ...DialOption) return nil, errtype.NewDialError("handshake failed", i.String(), err) } + // The metadata exchange must occur after the TLS connection is established + // to avoid leaking sensitive information. + err = d.metadataExchange(tlsConn) + if err != nil { + _ = tlsConn.Close() // best effort close attempt + return nil, err + } + latency := time.Since(startTime).Milliseconds() go func() { n := atomic.AddUint64(&i.OpenConns, 1) @@ -225,6 +263,121 @@ func (d *Dialer) Dial(ctx context.Context, instance string, opts ...DialOption) }), nil } +// metadataExchange sends metadata about the connection prior to the database +// protocol taking over. The exchange consists of four steps: +// +// 1. Prepare a MetadataExchangeRequest including the IAM Principal's OAuth2 +// token, the user agent, and the requested authentication type. +// +// 2. Write the size of the message as a big endian uint32 (4 bytes) to the +// server followed by the marshaled message. The length does not include the +// initial four bytes. +// +// 3. Read a big endian uint32 (4 bytes) from the server. This is the +// MetadataExchangeResponse message length and does not include the initial +// four bytes. +// +// 4. Unmarshal the response using the message length in step 3. If the +// response is not OK, return the response's error. If there is no error, the +// metadata exchange has succeeded and the connection is complete. +// +// Subsequent interactions with the server use the database protocol. +func (d *Dialer) metadataExchange(conn net.Conn) error { + tok, err := d.iamTokenSource.Token() + if err != nil { + return err + } + authType := connectorspb.MetadataExchangeRequest_DB_NATIVE + if d.useIAMAuthN { + authType = connectorspb.MetadataExchangeRequest_AUTO_IAM + } + req := &connectorspb.MetadataExchangeRequest{ + UserAgent: d.userAgent, + AuthType: authType, + Oauth2Token: tok.AccessToken, + } + m, err := proto.Marshal(req) + if err != nil { + return err + } + b := d.buffer.get() + defer d.buffer.put(b) + + buf := *b + reqSize := proto.Size(req) + binary.BigEndian.PutUint32(buf, uint32(reqSize)) + buf = append(buf[:4], m...) + + // Set IO deadline before write + err = conn.SetDeadline(time.Now().Add(ioTimeout)) + if err != nil { + return err + } + defer conn.SetDeadline(time.Time{}) + + _, err = conn.Write(buf) + if err != nil { + return err + } + + // Reset IO deadline before read + err = conn.SetDeadline(time.Now().Add(ioTimeout)) + if err != nil { + return err + } + defer conn.SetDeadline(time.Time{}) + + buf = buf[:4] + _, err = conn.Read(buf) + if err != nil { + return err + } + + respSize := binary.BigEndian.Uint32(buf) + resp := buf[:respSize] + _, err = conn.Read(resp) + if err != nil { + return err + } + + var mdxResp connectorspb.MetadataExchangeResponse + err = proto.Unmarshal(resp, &mdxResp) + if err != nil { + return err + } + + if mdxResp.GetResponseCode() != connectorspb.MetadataExchangeResponse_OK { + return errors.New(mdxResp.GetError()) + } + + return nil +} + +const maxMessageSize = 16 * 1024 // 16 kb + +type buffer struct { + pool sync.Pool +} + +func newBuffer() *buffer { + return &buffer{ + pool: sync.Pool{ + New: func() any { + buf := make([]byte, maxMessageSize) + return &buf + }, + }, + } +} + +func (b *buffer) get() *[]byte { + return b.pool.Get().(*[]byte) +} + +func (b *buffer) put(buf *[]byte) { + b.pool.Put(buf) +} + // newInstrumentedConn initializes an instrumentedConn that on closing will // decrement the number of open connects and record the result. func newInstrumentedConn(conn net.Conn, closeFunc func()) *instrumentedConn { diff --git a/dialer_test.go b/dialer_test.go index ce22417c..15aeb15c 100644 --- a/dialer_test.go +++ b/dialer_test.go @@ -17,6 +17,7 @@ package alloydbconn import ( "context" "errors" + "fmt" "io" "net" "os" @@ -35,7 +36,7 @@ import ( type stubTokenSource struct{} func (stubTokenSource) Token() (*oauth2.Token, error) { - return nil, nil + return &oauth2.Token{}, nil } func TestDialerCanConnectToInstance(t *testing.T) { @@ -54,7 +55,8 @@ func TestDialerCanConnectToInstance(t *testing.T) { t.Fatalf("%v", err) } }() - c, err := alloydbadmin.NewAlloyDBAdminRESTClient(ctx, option.WithHTTPClient(mc), option.WithEndpoint(url)) + c, err := alloydbadmin.NewAlloyDBAdminRESTClient( + ctx, option.WithHTTPClient(mc), option.WithEndpoint(url)) if err != nil { t.Fatalf("expected NewClient to succeed, but got error: %v", err) } @@ -65,19 +67,25 @@ func TestDialerCanConnectToInstance(t *testing.T) { } d.client = c - conn, err := d.Dial(ctx, "/projects/my-project/locations/my-region/clusters/my-cluster/instances/my-instance") - if err != nil { - t.Fatalf("expected Dial to succeed, but got error: %v", err) + // Run several tests to ensure the underlying shared buffer is properly + // reset between connections. + for i := 0; i < 10; i++ { + t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { + conn, err := d.Dial(ctx, "projects/my-project/locations/my-region/clusters/my-cluster/instances/my-instance") + if err != nil { + t.Fatalf("expected Dial to succeed, but got error: %v", err) + } + defer conn.Close() + data, err := io.ReadAll(conn) + if err != nil { + t.Fatalf("expected ReadAll to succeed, got error %v", err) + } + if string(data) != "my-instance" { + t.Fatalf("expected known response from the server, but got %v", string(data)) + } + }) } - defer conn.Close() - data, err := io.ReadAll(conn) - if err != nil { - t.Fatalf("expected ReadAll to succeed, got error %v", err) - } - if string(data) != "my-instance" { - t.Fatalf("expected known response from the server, but got %v", string(data)) - } } func TestDialWithAdminAPIErrors(t *testing.T) { diff --git a/e2e_test.go b/e2e_test.go index 77ce3f88..c5931273 100644 --- a/e2e_test.go +++ b/e2e_test.go @@ -18,11 +18,14 @@ import ( "context" "database/sql" "fmt" + "net" "os" "testing" "time" + "cloud.google.com/go/alloydbconn" "cloud.google.com/go/alloydbconn/driver/pgxv4" + "github.com/jackc/pgx/v4" ) var ( @@ -48,6 +51,8 @@ func requireAlloyDBVars(t *testing.T) { t.Fatal("'ALLOYDB_INSTANCE_NAME' env var not set") case alloydbUser: t.Fatal("'ALLOYDB_USER' env var not set") + case alloydbIAMUser: + t.Fatal("'ALLOYDB_IAM_USER' env var not set") case alloydbPass: t.Fatal("'ALLOYDB_PASS' env var not set") case alloydbDB: @@ -75,6 +80,13 @@ func TestPgxConnect(t *testing.T) { // best effort _ = cleanup() }() + + var now time.Time + err = pool.QueryRow(context.Background(), "SELECT NOW()").Scan(&now) + if err != nil { + t.Fatalf("QueryRow failed: %s", err) + } + t.Log(now) } // TestDatabaseSQLConnect uses the latest pgx driver under the hood @@ -178,3 +190,40 @@ func TestDirectPGXAutoIAMAuthN(t *testing.T) { } t.Log(tt) } + +func TestAutoIAMAuthN(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test") + } + ctx := context.Background() + + d, err := alloydbconn.NewDialer(ctx, alloydbconn.WithIAMAuthN()) + if err != nil { + t.Fatalf("failed to init Dialer: %v", err) + } + + dsn := fmt.Sprintf( + "user=%s dbname=%s sslmode=disable", + alloydbIAMUser, alloydbDB, + ) + config, err := pgx.ParseConfig(dsn) + if err != nil { + t.Fatalf("failed to parse pgx config: %v", err) + } + + config.DialFunc = func(ctx context.Context, network string, instance string) (net.Conn, error) { + return d.Dial(ctx, alloydbInstanceName) + } + + conn, connErr := pgx.ConnectConfig(ctx, config) + if connErr != nil { + t.Fatalf("failed to connect: %s", connErr) + } + defer conn.Close(ctx) + + var tt time.Time + if err := conn.QueryRow(context.Background(), "SELECT NOW()").Scan(&tt); err != nil { + t.Fatal(err) + } + t.Log(tt) +} diff --git a/internal/alloydb/refresh.go b/internal/alloydb/refresh.go index 1625cde9..1c131ac8 100644 --- a/internal/alloydb/refresh.go +++ b/internal/alloydb/refresh.go @@ -92,8 +92,9 @@ func fetchEphemeralCert( Parent: fmt.Sprintf( "projects/%s/locations/%s/clusters/%s", inst.project, inst.region, inst.cluster, ), - PublicKey: buf.String(), - CertDuration: durationpb.New(time.Second * 3600), + PublicKey: buf.String(), + CertDuration: durationpb.New(time.Second * 3600), + UseMetadataExchange: true, } resp, err := cl.GenerateClientCertificate(ctx, req) if err != nil { diff --git a/internal/mock/alloydb.go b/internal/mock/alloydb.go index c56d5112..b3295d94 100644 --- a/internal/mock/alloydb.go +++ b/internal/mock/alloydb.go @@ -21,10 +21,15 @@ import ( "crypto/tls" "crypto/x509" "crypto/x509/pkix" + "encoding/binary" + "fmt" "math/big" "net" "testing" "time" + + "cloud.google.com/go/alloydb/connectors/apiv1beta/connectorspb" + "google.golang.org/protobuf/proto" ) // Option configures a FakeAlloyDBInstance @@ -231,6 +236,12 @@ func StartServerProxy(t *testing.T, inst FakeAlloyDBInstance) func() { if err != nil { return } + if err := metadataExchange(conn); err != nil { + conn.Close() + return + } + + // Database protocol takes over from here. conn.Write([]byte(inst.name)) conn.Close() } @@ -241,3 +252,70 @@ func StartServerProxy(t *testing.T, inst FakeAlloyDBInstance) func() { ln.Close() } } + +// metadataExchange mimics server side behavior in four steps: +// +// 1. Read a big endian uint32 (4 bytes) from the client. This is the number of +// bytes the message consumes. The length does not include the initial four +// bytes. +// +// 2. Read the message from the client using the message length and unmarshal +// it into a MetadataExchangeResponse message. +// +// The real server implementation will then validate the client has connection +// permissions using the provided OAuth2 token based on the auth type. Here in +// the test implementation, the server does nothing. +// +// 3. Prepare a response and write the size of the response as a uint32 (4 +// bytes) +// +// 4. Marshal the response to bytes and write those to the client as well. +// +// Subsequent interactions with the test server use the database protocol. +func metadataExchange(conn net.Conn) error { + msgSize := make([]byte, 4) + n, err := conn.Read(msgSize) + if err != nil { + return err + } + if n != 4 { + return fmt.Errorf("read %d bytes, want = 4", n) + } + + size := binary.BigEndian.Uint32(msgSize) + buf := make([]byte, size) + n, err = conn.Read(buf) + if err != nil { + return err + } + if n != int(size) { + return fmt.Errorf("read %d bytes, want = %d", n, size) + } + + m := &connectorspb.MetadataExchangeRequest{} + err = proto.Unmarshal(buf, m) + if err != nil { + return err + } + + resp := &connectorspb.MetadataExchangeResponse{ + ResponseCode: connectorspb.MetadataExchangeResponse_OK, + } + data, err := proto.Marshal(resp) + if err != nil { + return err + } + respSize := proto.Size(resp) + buf = make([]byte, 4) + binary.BigEndian.PutUint32(buf, uint32(respSize)) + + buf = append(buf, data...) + n, err = conn.Write(buf) + if err != nil { + return err + } + if n != len(buf) { + return fmt.Errorf("write %d bytes, want = %d", n, len(buf)) + } + return nil +} diff --git a/options.go b/options.go index c803d10b..fc653370 100644 --- a/options.go +++ b/options.go @@ -41,7 +41,8 @@ type dialerConfig struct { dialFunc func(ctx context.Context, network, addr string) (net.Conn, error) refreshTimeout time.Duration tokenSource oauth2.TokenSource - useragents []string + userAgents []string + useIAMAuthN bool // err tracks any dialer options that may have failed. err error } @@ -88,7 +89,7 @@ func WithCredentialsJSON(b []byte) Option { // WithUserAgent returns an Option that sets the User-Agent. func WithUserAgent(ua string) Option { return func(d *dialerConfig) { - d.useragents = append(d.useragents, ua) + d.userAgents = append(d.userAgents, ua) } } @@ -151,6 +152,16 @@ func WithDialFunc(dial func(ctx context.Context, network, addr string) (net.Conn } } +// WithIAMAuthN enables automatic IAM Authentication. If no token source has +// been configured (such as with WithTokenSource, WithCredentialsFile, etc), the +// dialer will use the default token source as defined by +// https://pkg.go.dev/golang.org/x/oauth2/google#FindDefaultCredentialsWithParams. +func WithIAMAuthN() Option { + return func(d *dialerConfig) { + d.useIAMAuthN = true + } +} + // A DialOption is an option for configuring how a Dialer's Dial call is executed. type DialOption func(d *dialCfg)