From da23ca9579f5b90e86287e5b7dc689a549ea9240 Mon Sep 17 00:00:00 2001 From: Eno Compton Date: Tue, 29 Mar 2022 09:41:48 -0600 Subject: [PATCH] feat: add AlloyDB instance type --- internal/cloudsql/instance.go | 63 +++++----------- internal/cloudsql/instance_test.go | 116 +++++++++++------------------ internal/cloudsql/refresh_test.go | 3 - 3 files changed, 61 insertions(+), 121 deletions(-) diff --git a/internal/cloudsql/instance.go b/internal/cloudsql/instance.go index 6ac38a83..a8a1370b 100644 --- a/internal/cloudsql/instance.go +++ b/internal/cloudsql/instance.go @@ -24,8 +24,7 @@ import ( "time" errtype "cloud.google.com/go/cloudsqlconn/errtype" - "golang.org/x/oauth2" - sqladmin "google.golang.org/api/sqladmin/v1beta4" + "cloud.google.com/go/cloudsqlconn/internal/alloydb" ) const ( @@ -81,9 +80,7 @@ type metadata struct { // refreshOperation is a pending result of a refresh operation of data used to connect securely. It should // only be initialized by the Instance struct as part of a refresh cycle. type refreshOperation struct { - md metadata - tlsCfg *tls.Config - expiry time.Time + result refreshResult err error // timer that triggers refresh, can be used to cancel. @@ -115,7 +112,7 @@ func (r *refreshOperation) IsValid() bool { default: return false case <-r.ready: - if r.err != nil || time.Now().After(r.expiry) { + if r.err != nil || time.Now().After(r.result.expiry) { return false } return true @@ -150,10 +147,9 @@ type Instance struct { // NewInstance initializes a new Instance given an instance connection name func NewInstance( instance string, - client *sqladmin.Service, + client *alloydb.Client, key *rsa.PrivateKey, refreshTimeout time.Duration, - ts oauth2.TokenSource, dialerID string, ) (*Instance, error) { cn, err := parseConnName(instance) @@ -164,15 +160,13 @@ func NewInstance( i := &Instance{ connName: cn, key: key, - // TODO: we'll update this when we do instance - // r: newRefresher( - // refreshTimeout, - // 30*time.Second, - // 2, - // client, - // ts, - // dialerID, - // ), + r: newRefresher( + client, + refreshTimeout, + 30*time.Second, + 2, + dialerID, + ), ctx: ctx, cancel: cancel, } @@ -194,31 +188,12 @@ func (i *Instance) Close() { // ConnectInfo returns an IP address specified by ipType (i.e., public or // private) and a TLS config that can be used to connect to a Cloud SQL // instance. -func (i *Instance) ConnectInfo(ctx context.Context, ipType string) (string, *tls.Config, error) { +func (i *Instance) ConnectInfo(ctx context.Context) (string, *tls.Config, error) { res, err := i.result(ctx) if err != nil { return "", nil, err } - addr, ok := res.md.ipAddrs[ipType] - if !ok { - err := errtype.NewConfigError( - fmt.Sprintf("instance does not have IP of type %q", ipType), - i.String(), - ) - return "", nil, err - } - return addr, res.tlsCfg, nil -} - -// InstanceEngineVersion returns the engine type and version for the instance. The value -// coresponds to one of the following types for the instance: -// https://cloud.google.com/sql/docs/mysql/admin-api/rest/v1beta4/SqlDatabaseVersion -func (i *Instance) InstanceEngineVersion(ctx context.Context) (string, error) { - res, err := i.result(ctx) - if err != nil { - return "", err - } - return res.md.version, nil + return res.result.instanceIPAddr, res.result.conf, nil } // ForceRefresh triggers an immediate refresh operation to be scheduled and used for future connection attempts. @@ -246,17 +221,13 @@ func (i *Instance) result(ctx context.Context) (*refreshOperation, error) { } // scheduleRefresh schedules a refresh operation to be triggered after a given -// duration. The returned refreshOperation -// can be used to either Cancel or Wait for the operations result. +// duration. The returned refreshOperation can be used to either Cancel or Wait +// for the operations result. func (i *Instance) scheduleRefresh(d time.Duration) *refreshOperation { res := &refreshOperation{} res.ready = make(chan struct{}) res.timer = time.AfterFunc(d, func() { - // TODO: fix this - // res.md, res.tlsCfg, res.expiry, res.err = i.r.performRefresh(i.ctx, i.connName, i.key) - r, err := i.r.performRefresh(i.ctx, i.connName, i.key) - _ = r - _ = err + res.result, res.err = i.r.performRefresh(i.ctx, i.connName, i.key) close(res.ready) // Once the refresh is complete, update "current" with working result and schedule a new refresh @@ -282,7 +253,7 @@ func (i *Instance) scheduleRefresh(d time.Duration) *refreshOperation { return default: } - nextRefresh := i.cur.expiry.Add(-refreshBuffer) + nextRefresh := i.cur.result.expiry.Add(-refreshBuffer) i.next = i.scheduleRefresh(time.Until(nextRefresh)) }) return res diff --git a/internal/cloudsql/instance_test.go b/internal/cloudsql/instance_test.go index ee4b8fbe..b2fb0ade 100644 --- a/internal/cloudsql/instance_test.go +++ b/internal/cloudsql/instance_test.go @@ -23,7 +23,9 @@ import ( "time" "cloud.google.com/go/cloudsqlconn/errtype" + "cloud.google.com/go/cloudsqlconn/internal/alloydb" "cloud.google.com/go/cloudsqlconn/internal/mock" + "google.golang.org/api/option" ) // genRSAKey generates an RSA key used for test. @@ -112,67 +114,39 @@ func TestParseConnNameErrors(t *testing.T) { } } -func TestInstanceEngineVersion(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - tests := []string{ - "MYSQL_5_7", "POSTGRES_14", "SQLSERVER_2019_STANDARD", "MYSQL_8_0_18", - } - for _, wantEV := range tests { - inst := mock.NewFakeCSQLInstance("my-project", "my-region", "my-instance", mock.WithEngineVersion(wantEV)) - client, cleanup, err := mock.NewSQLAdminService( - ctx, - mock.DELETEInstanceGetSuccess(inst, 1), - mock.DELETECreateEphemeralSuccess(inst, 1), - ) - if err != nil { - t.Fatalf("%s", err) - } - defer func() { - if err := cleanup(); err != nil { - t.Fatalf("%v", err) - } - }() - i, err := NewInstance("my-project:my-region:my-instance", client, RSAKey, 30*time.Second, nil, "") - if err != nil { - t.Fatalf("failed to init instance: %v", err) - } - - gotEV, err := i.InstanceEngineVersion(ctx) - if err != nil { - t.Fatalf("failed to retrieve engine version: %v", err) - } - if wantEV != gotEV { - t.Errorf("InstanceEngineVersion(%s) failed: want %v, got %v", wantEV, gotEV, err) - } - - } -} - func TestConnectInfo(t *testing.T) { ctx := context.Background() + wantAddr := "0.0.0.0" - inst := mock.NewFakeCSQLInstance("my-project", "my-region", "my-instance", mock.WithPublicIP(wantAddr)) - client, cleanup, err := mock.NewSQLAdminService( - ctx, - mock.DELETEInstanceGetSuccess(inst, 1), - mock.DELETECreateEphemeralSuccess(inst, 1), + inst := mock.NewFakeInstance( + "my-project", "my-region", "my-cluster", "my-instance", + mock.WithIPAddr(wantAddr), ) - if err != nil { - t.Fatalf("%s", err) - } + mc, url, cleanup := mock.HTTPClient( + mock.InstanceGetSuccess(inst, 1), + mock.CreateEphemeralSuccess(inst, 1), + ) + stop := mock.StartServerProxy(t, inst) defer func() { + stop() if err := cleanup(); err != nil { t.Fatalf("%v", err) } }() + c, err := alloydb.NewClient(ctx, option.WithHTTPClient(mc), option.WithEndpoint(url)) + if err != nil { + t.Fatalf("expected NewClient to succeed, but got error: %v", err) + } - i, err := NewInstance("my-project:my-region:my-instance", client, RSAKey, 30*time.Second, nil, "") + i, err := NewInstance( + "my-project:my-region:my-cluster:my-instance", + c, RSAKey, 30*time.Second, "dialer-id", + ) if err != nil { t.Fatalf("failed to create mock instance: %v", err) } - gotAddr, gotTLSCfg, err := i.ConnectInfo(ctx, "PUBLIC") + gotAddr, gotTLSCfg, err := i.ConnectInfo(ctx) if err != nil { t.Fatalf("failed to retrieve connect info: %v", err) } @@ -184,60 +158,58 @@ func TestConnectInfo(t *testing.T) { ) } - wantServerName := "my-project:my-region:my-instance" - if gotTLSCfg.ServerName != wantServerName { - t.Fatalf( - "ConnectInfo return unexpected server name in TLS Config, want = %v, got = %v", - wantServerName, gotTLSCfg.ServerName, - ) - } + _ = gotTLSCfg + // TODO: this should be the instance UID + // wantServerName := "TODO instance UID" + // if gotTLSCfg.ServerName != wantServerName { + // t.Fatalf( + // "ConnectInfo return unexpected server name in TLS Config, want = %v, got = %v", + // wantServerName, gotTLSCfg.ServerName, + // ) + // } } func TestConnectInfoErrors(t *testing.T) { ctx := context.Background() - - client, cleanup, err := mock.NewSQLAdminService(ctx) + c, err := alloydb.NewClient(ctx) if err != nil { - t.Fatalf("%s", err) + t.Fatalf("expected NewClient to succeed, but got error: %v", err) } - defer cleanup() // Use a timeout that should fail instantly - im, err := NewInstance("my-project:my-region:my-instance", client, RSAKey, 0, nil, "") + im, err := NewInstance( + "my-project:my-region:my-cluster:my-instance", + c, RSAKey, 0, "dialer-id", + ) if err != nil { t.Fatalf("failed to initialize Instance: %v", err) } - _, _, err = im.ConnectInfo(ctx, "PUBLIC") + _, _, err = im.ConnectInfo(ctx) var wantErr *errtype.DialError if !errors.As(err, &wantErr) { t.Fatalf("when connect info fails, want = %T, got = %v", wantErr, err) } - - // when client asks for wrong IP address type - gotAddr, _, err := im.ConnectInfo(ctx, "PUBLIC") - if err == nil { - t.Fatalf("expected ConnectInfo to fail but returned IP address = %v", gotAddr) - } } func TestClose(t *testing.T) { ctx := context.Background() - - client, cleanup, err := mock.NewSQLAdminService(ctx) + c, err := alloydb.NewClient(ctx) if err != nil { - t.Fatalf("%s", err) + t.Fatalf("expected NewClient to succeed, but got error: %v", err) } - defer cleanup() // Set up an instance and then close it immediately - im, err := NewInstance("my-proj:my-region:my-inst", client, RSAKey, 30, nil, "") + im, err := NewInstance( + "my-proj:my-region:my-cluster:my-inst", + c, RSAKey, 30, "dialer-ider", + ) if err != nil { t.Fatalf("failed to initialize Instance: %v", err) } im.Close() - _, _, err = im.ConnectInfo(ctx, "PUBLIC") + _, _, err = im.ConnectInfo(ctx) if !errors.Is(err, context.Canceled) { t.Fatalf("failed to retrieve connect info: %v", err) } diff --git a/internal/cloudsql/refresh_test.go b/internal/cloudsql/refresh_test.go index 843e5a87..09a43a91 100644 --- a/internal/cloudsql/refresh_test.go +++ b/internal/cloudsql/refresh_test.go @@ -69,9 +69,6 @@ func TestRefresh(t *testing.T) { if got := res.expiry; wantExpiry != got { t.Fatalf("expiry mismatch, want = %v, got = %v", wantExpiry, got) } - if got := res.conf.ServerName; "client.alloydb" != got { - t.Fatalf("server name mismatch, want = %v, got = %v", "client.alloydb", got) - } } func TestRefreshFailsFast(t *testing.T) {