diff --git a/dialer.go b/dialer.go index c93dbade..fc6301a3 100644 --- a/dialer.go +++ b/dialer.go @@ -275,12 +275,7 @@ func (d *Dialer) Dial(ctx context.Context, icn string, opts ...DialOption) (conn c := d.connectionInfoCache(ctx, cn, &cfg.useIAMAuthN) ci, err := c.ConnectionInfo(ctx) if err != nil { - d.lock.Lock() - defer d.lock.Unlock() - d.logger.Debugf(ctx, "[%v] Removing connection info from cache", cn.String()) - // Stop all background refreshes - c.Close() - delete(d.cache, cn) + d.removeCached(ctx, cn, c, err) endInfo(err) return nil, err } @@ -297,12 +292,7 @@ func (d *Dialer) Dial(ctx context.Context, icn string, opts ...DialOption) (conn // Block on refreshed connection info ci, err = c.ConnectionInfo(ctx) if err != nil { - d.lock.Lock() - defer d.lock.Unlock() - d.logger.Debugf(ctx, "[%v] Removing connection info from cache", cn.String()) - // Stop all background refreshes - c.Close() - delete(d.cache, cn) + d.removeCached(ctx, cn, c, err) return nil, err } } @@ -312,6 +302,7 @@ func (d *Dialer) Dial(ctx context.Context, icn string, opts ...DialOption) (conn defer func() { connectEnd(err) }() addr, err := ci.Addr(cfg.ipType) if err != nil { + d.removeCached(ctx, cn, c, err) return nil, err } addr = net.JoinHostPort(addr, serverProxyPort) @@ -359,10 +350,31 @@ func (d *Dialer) Dial(ctx context.Context, icn string, opts ...DialOption) (conn }), nil } +// removeCached stops all background refreshes and deletes the connection +// info cache from the map of caches. +func (d *Dialer) removeCached( + ctx context.Context, + i instance.ConnName, c connectionInfoCache, err error, +) { + d.logger.Debugf( + ctx, + "[%v] Removing connection info from cache: %v", + i.String(), + err, + ) + d.lock.Lock() + defer d.lock.Unlock() + c.Close() + delete(d.cache, i) +} + // validClientCert checks that the ephemeral client certificate retrieved from // the cache is unexpired. The time comparisons strip the monotonic clock value // to ensure an accurate result, even after laptop sleep. -func validClientCert(ctx context.Context, cn instance.ConnName, l debug.ContextLogger, expiration time.Time) bool { +func validClientCert( + ctx context.Context, cn instance.ConnName, + l debug.ContextLogger, expiration time.Time, +) bool { // Use UTC() to strip monotonic clock value to guard against inaccurate // comparisons, especially after laptop sleep. // See the comments on the monotonic clock in the Go documentation for diff --git a/dialer_test.go b/dialer_test.go index 051e9b92..76850489 100644 --- a/dialer_test.go +++ b/dialer_test.go @@ -16,6 +16,8 @@ package cloudsqlconn import ( "context" + "crypto/tls" + "crypto/x509" "errors" "fmt" "io" @@ -622,35 +624,71 @@ func TestDialerRemovesInvalidInstancesFromCache(t *testing.T) { // Populate instance map with connection info cache that will always fail // This allows the test to verify the error case path invoking close. badInstanceConnectionName := "doesntexist:us-central1:doesntexist" - badCN, _ := instance.ParseConnName(badInstanceConnectionName) - spy := &spyConnectionInfoCache{ - connectInfoCalls: []struct { - info cloudsql.ConnectionInfo - err error - }{{ - err: errors.New("connect info failed"), - }}, + tcs := []struct { + desc string + icn string + resp connectionInfoResp + opts []DialOption + }{ + { + desc: "dialing a bad instance URI", + icn: badInstanceConnectionName, + resp: connectionInfoResp{ + err: errors.New("connect info failed"), + }, + }, + { + desc: "specifying an invalid IP type", + icn: "myproject:myregion:myinstance", + resp: connectionInfoResp{ + info: cloudsql.NewConnectionInfo( + instance.ConnName{}, + "", + map[string]string{ + // no public IP + cloudsql.PrivateIP: "10.0.0.1", + }, + nil, + tls.Certificate{Leaf: &x509.Certificate{ + NotAfter: time.Now().Add(time.Hour), + }}, + ), + }, + opts: []DialOption{WithPublicIP()}, + }, } - d.cache[badCN] = monitoredCache{connectionInfoCache: spy} - _, err = d.Dial(context.Background(), badInstanceConnectionName) - if err == nil { - t.Fatal("expected Dial to return error") - } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + // Manually populate the internal cache with a spy + inst, _ := instance.ParseConnName(tc.icn) + spy := &spyConnectionInfoCache{ + connectInfoCalls: []connectionInfoResp{tc.resp}, + } + d.cache[inst] = monitoredCache{ + connectionInfoCache: spy, + } - // Verify that the connection info cache was closed (to prevent - // further failed refresh operations) - if got, want := spy.CloseWasCalled(), true; got != want { - t.Fatal("Close was not called") - } + _, err = d.Dial(context.Background(), tc.icn, tc.opts...) + if err == nil { + t.Fatal("expected Dial to return error") + } + // Verify that the connection info cache was closed (to prevent + // further failed refresh operations) + if got, want := spy.closeWasCalled(), true; got != want { + t.Fatal("Close was not called") + } - // Now verify that bad connection name has been deleted from map. - d.lock.RLock() - _, ok := d.cache[badCN] - d.lock.RUnlock() - if ok { - t.Fatal("bad instance was not removed from the cache") + // Now verify that bad connection name has been deleted from map. + d.lock.RLock() + _, ok := d.cache[inst] + d.lock.RUnlock() + if ok { + t.Fatal("connection info was not removed from cache") + } + }) } + } func TestDialRefreshesExpiredCertificates(t *testing.T) { @@ -665,10 +703,7 @@ func TestDialRefreshesExpiredCertificates(t *testing.T) { icn := "project:region:instance" cn, _ := instance.ParseConnName(icn) spy := &spyConnectionInfoCache{ - connectInfoCalls: []struct { - info cloudsql.ConnectionInfo - err error - }{ + connectInfoCalls: []connectionInfoResp{ // First call returns expired certificate { // Certificate expired 10 hours ago. @@ -690,13 +725,13 @@ func TestDialRefreshesExpiredCertificates(t *testing.T) { } // Verify that the cache was refreshed - if got, want := spy.ForceRefreshWasCalled(), true; got != want { + if got, want := spy.forceRefreshWasCalled(), true; got != want { t.Fatal("ForceRefresh was not called") } // Verify that the connection info cache was closed (to prevent // further failed refresh operations) - if got, want := spy.CloseWasCalled(), true; got != want { + if got, want := spy.closeWasCalled(), true; got != want { t.Fatal("Close was not called") } @@ -710,15 +745,18 @@ func TestDialRefreshesExpiredCertificates(t *testing.T) { } +type connectionInfoResp struct { + info cloudsql.ConnectionInfo + err error +} + type spyConnectionInfoCache struct { mu sync.Mutex connectInfoIndex int - connectInfoCalls []struct { - info cloudsql.ConnectionInfo - err error - } - closeWasCalled bool - forceRefreshWasCalled bool + connectInfoCalls []connectionInfoResp + + closed bool + forceRefreshed bool // embed interface to avoid having to implement irrelevant methods connectionInfoCache } @@ -736,7 +774,7 @@ func (s *spyConnectionInfoCache) ConnectionInfo( func (s *spyConnectionInfoCache) ForceRefresh() { s.mu.Lock() defer s.mu.Unlock() - s.forceRefreshWasCalled = true + s.forceRefreshed = true } func (s *spyConnectionInfoCache) UpdateRefresh(*bool) {} @@ -744,20 +782,20 @@ func (s *spyConnectionInfoCache) UpdateRefresh(*bool) {} func (s *spyConnectionInfoCache) Close() error { s.mu.Lock() defer s.mu.Unlock() - s.closeWasCalled = true + s.closed = true return nil } -func (s *spyConnectionInfoCache) CloseWasCalled() bool { +func (s *spyConnectionInfoCache) closeWasCalled() bool { s.mu.Lock() defer s.mu.Unlock() - return s.closeWasCalled + return s.closed } -func (s *spyConnectionInfoCache) ForceRefreshWasCalled() bool { +func (s *spyConnectionInfoCache) forceRefreshWasCalled() bool { s.mu.Lock() defer s.mu.Unlock() - return s.forceRefreshWasCalled + return s.forceRefreshed } func TestDialerSupportsOneOffDialFunction(t *testing.T) { diff --git a/internal/cloudsql/instance.go b/internal/cloudsql/instance.go index bebce43a..2e0b5d02 100644 --- a/internal/cloudsql/instance.go +++ b/internal/cloudsql/instance.go @@ -178,6 +178,24 @@ type ConnectionInfo struct { addrs map[string]string } +// NewConnectionInfo initializes a ConnectionInfo struct. +func NewConnectionInfo( + cn instance.ConnName, + version string, + ipAddrs map[string]string, + serverCaCert *x509.Certificate, + clientCert tls.Certificate, +) ConnectionInfo { + return ConnectionInfo{ + addrs: ipAddrs, + ServerCaCert: serverCaCert, + ClientCertificate: clientCert, + Expiration: clientCert.Leaf.NotAfter, + DBVersion: version, + ConnectionName: cn, + } +} + // Addr returns the IP address or DNS name for the given IP type. func (c ConnectionInfo) Addr(ipType string) (string, error) { var ( diff --git a/internal/cloudsql/refresh.go b/internal/cloudsql/refresh.go index f13874f2..2448c060 100644 --- a/internal/cloudsql/refresh.go +++ b/internal/cloudsql/refresh.go @@ -340,14 +340,9 @@ func (r refresher) ConnectionInfo( return ConnectionInfo{}, fmt.Errorf("refresh failed: %w", ctx.Err()) } - return ConnectionInfo{ - addrs: md.ipAddrs, - ServerCaCert: md.serverCaCert, - ClientCertificate: ec, - Expiration: ec.Leaf.NotAfter, - DBVersion: md.version, - ConnectionName: cn, - }, nil + return NewConnectionInfo( + cn, md.version, md.ipAddrs, md.serverCaCert, ec, + ), nil } // supportsAutoIAMAuthN checks that the engine support automatic IAM authn. If