Skip to content

Commit

Permalink
fix: ensure cert refresh recovers from computer sleep (#456)
Browse files Browse the repository at this point in the history
  • Loading branch information
enocom authored Dec 13, 2023
1 parent fe4f6de commit 79fcbc8
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 6 deletions.
32 changes: 32 additions & 0 deletions dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,25 @@ func (d *Dialer) Dial(ctx context.Context, instance string, opts ...DialOption)
}
endInfo(err)

// If the client certificate has expired (as when the computer goes to
// sleep, and the refresh cycle cannot run), force a refresh immediately.
// The TLS handshake will not fail on an expired client certificate. It's
// not until the first read where the client cert error will be surfaced.
// So check that the certificate is valid before proceeding.
if invalidClientCert(tlsCfg) {
i.ForceRefresh()
// Block on refreshed connection info
addr, tlsCfg, err = i.ConnectInfo(ctx)
if err != nil {
d.lock.Lock()
defer d.lock.Unlock()
// Stop all background refreshes
i.Close()
delete(d.instances, inst)
return nil, err
}
}

var connectEnd trace.EndSpanFunc
ctx, connectEnd = trace.StartSpan(ctx, "cloud.google.com/go/alloydbconn/internal.Connect")
defer func() { connectEnd(err) }()
Expand Down Expand Up @@ -275,6 +294,19 @@ func (d *Dialer) Dial(ctx context.Context, instance string, opts ...DialOption)
}), nil
}

func invalidClientCert(c *tls.Config) bool {
// The following conditions should be impossible (no certs, nil leaf), but
// just in case there's an unknown edge case, check assumptions before
// proceeding.
if len(c.Certificates) == 0 {
return true
}
if c.Certificates[0].Leaf == nil {
return true
}
return time.Now().After(c.Certificates[0].Leaf.NotAfter)
}

// metadataExchange sends metadata about the connection prior to the database
// protocol taking over. The exchange consists of four steps:
//
Expand Down
99 changes: 93 additions & 6 deletions dialer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package alloydbconn
import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -238,7 +239,12 @@ func TestDialerRemovesInvalidInstancesFromCache(t *testing.T) {
badInst, _ := alloydb.ParseInstURI(badInstanceName)

spy := &spyConnectionInfoCache{
connectInfoError: errors.New("connect info failed"),
connectInfoCalls: []struct {
tls *tls.Config
err error
}{{
err: errors.New("connect info failed"),
}},
}
d.instances[badInst] = spy

Expand All @@ -261,17 +267,92 @@ func TestDialerRemovesInvalidInstancesFromCache(t *testing.T) {
}
}

type spyConnectionInfoCache struct {
connectInfoError error
func TestDialRefreshesExpiredCertificates(t *testing.T) {
d, err := NewDialer(
context.Background(),
WithTokenSource(stubTokenSource{}),
)
if err != nil {
t.Fatalf("expected NewDialer to succeed, but got error: %v", err)
}

sentinel := errors.New("connect info failed")
inst := "/projects/my-project/locations/my-region/clusters/my-cluster/instances/my-instance"
cn, _ := alloydb.ParseInstURI(inst)
spy := &spyConnectionInfoCache{
connectInfoCalls: []struct {
tls *tls.Config
err error
}{
// First call returns expired certificate
{
tls: &tls.Config{
Certificates: []tls.Certificate{{
Leaf: &x509.Certificate{
// Certificate expired 10 hours ago.
NotAfter: time.Now().Add(-10 * time.Hour),
},
}},
},
},
// Second call errors to validate error path
{
err: sentinel,
},
},
}
d.instances[cn] = spy

_, err = d.Dial(context.Background(), inst)
if !errors.Is(err, sentinel) {
t.Fatalf("expected Dial to return sentinel error, instead got = %v", err)
}

// Verify that the cache was refreshed
if got, want := spy.ForceRefreshWasCalled(), true; got != want {
t.Fatal("ForceRefresh was not called")
}

// Verify that the connection info cache was closed (to prevent
// further failed refresh operations)
if got, want := spy.CloseWasCalled(), true; got != want {
t.Fatal("Close was not called")
}

mu sync.Mutex
closeWasCalled bool
// Now verify that bad connection name has been deleted from map.
d.lock.RLock()
_, ok := d.instances[cn]
d.lock.RUnlock()
if ok {
t.Fatal("bad instance was not removed from the cache")
}
}

type spyConnectionInfoCache struct {
mu sync.Mutex
connectInfoIndex int
connectInfoCalls []struct {
tls *tls.Config
err error
}
closeWasCalled bool
forceRefreshWasCalled bool
// embed interface to avoid having to implement irrelevant methods
connectionInfoCache
}

func (s *spyConnectionInfoCache) ConnectInfo(_ context.Context) (string, *tls.Config, error) {
return "", nil, s.connectInfoError
s.mu.Lock()
defer s.mu.Unlock()
res := s.connectInfoCalls[s.connectInfoIndex]
s.connectInfoIndex++
return "unused", res.tls, res.err
}

func (s *spyConnectionInfoCache) ForceRefresh() {
s.mu.Lock()
defer s.mu.Unlock()
s.forceRefreshWasCalled = true
}

func (s *spyConnectionInfoCache) Close() error {
Expand All @@ -287,6 +368,12 @@ func (s *spyConnectionInfoCache) CloseWasCalled() bool {
return s.closeWasCalled
}

func (s *spyConnectionInfoCache) ForceRefreshWasCalled() bool {
s.mu.Lock()
defer s.mu.Unlock()
return s.forceRefreshWasCalled
}

func TestDialerSupportsOneOffDialFunction(t *testing.T) {
ctx := context.Background()
inst := mock.NewFakeInstance(
Expand Down

0 comments on commit 79fcbc8

Please sign in to comment.