From f2f162338194378f060d44a12cf938e4f4d459bb Mon Sep 17 00:00:00 2001 From: Eno Compton Date: Wed, 3 Apr 2024 12:04:45 -0600 Subject: [PATCH] fix: return a friendly error if the dialer is closed If the dialer has already been closed, return a clear error. Fixes #522 --- README.md | 2 ++ dialer.go | 22 ++++++++++++++++++++++ dialer_test.go | 41 ++++++++++++++++++++++++++++++++++------- 3 files changed, 58 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 125e7bf5..b694b0e4 100644 --- a/README.md +++ b/README.md @@ -106,6 +106,8 @@ d, err := alloydbconn.NewDialer(ctx) if err != nil { log.Fatalf("failed to initialize dialer: %v", err) } +// Don't close the dialer until you're done with the database connection +// e.g. at the end of your main function defer d.Close() // Tell the driver to use the AlloyDB Go Connector to create connections diff --git a/dialer.go b/dialer.go index 8dc431d2..e0fec7f1 100644 --- a/dialer.go +++ b/dialer.go @@ -56,6 +56,12 @@ const ( ) var ( + // ErrDialerClosed is used when a caller invokes Dial after closing the + // Dialer. + ErrDialerClosed = errors.New( + "Dialer has been closed. Close should be " + + "called only when a database connection is no longer needed.", + ) // versionString indicates the version of this library. //go:embed version.txt versionString string @@ -90,6 +96,8 @@ type Dialer struct { instances map[alloydb.InstanceURI]connectionInfoCache key *rsa.PrivateKey refreshTimeout time.Duration + // closed reports if the dialer has been closed. + closed chan struct{} client *alloydbadmin.AlloyDBAdminClient logger debug.Logger @@ -174,6 +182,7 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) { return nil, err } d := &Dialer{ + closed: make(chan struct{}), instances: make(map[alloydb.InstanceURI]connectionInfoCache), key: cfg.rsaKey, refreshTimeout: cfg.refreshTimeout, @@ -194,6 +203,11 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) { // instance argument must be the instance's URI, which is in the format // projects//locations//clusters//instances/ func (d *Dialer) Dial(ctx context.Context, instance string, opts ...DialOption) (conn net.Conn, err error) { + select { + case <-d.closed: + return nil, ErrDialerClosed + default: + } startTime := time.Now() var endDial trace.EndSpanFunc ctx, endDial = trace.StartSpan(ctx, "cloud.google.com/go/alloydbconn.Dial", @@ -494,6 +508,14 @@ func (i *instrumentedConn) Close() error { // needed to connect. Additional dial operations may succeed until the information // expires. func (d *Dialer) Close() error { + // Check if Close has already been called. + select { + case <-d.closed: + return nil + default: + } + close(d.closed) + d.lock.Lock() defer d.lock.Unlock() for _, i := range d.instances { diff --git a/dialer_test.go b/dialer_test.go index ce088af2..07aa9fe9 100644 --- a/dialer_test.go +++ b/dialer_test.go @@ -34,6 +34,9 @@ import ( "google.golang.org/api/option" ) +const testInstanceURI = "projects/my-project/locations/my-region/" + + "clusters/my-cluster/instances/my-instance" + type stubTokenSource struct{} func (stubTokenSource) Token() (*oauth2.Token, error) { @@ -72,7 +75,7 @@ func TestDialerCanConnectToInstance(t *testing.T) { // reset between connections. for i := 0; i < 10; i++ { t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { - conn, err := d.Dial(ctx, "projects/my-project/locations/my-region/clusters/my-cluster/instances/my-instance") + conn, err := d.Dial(ctx, testInstanceURI) if err != nil { t.Fatalf("expected Dial to succeed, but got error: %v", err) } @@ -116,12 +119,12 @@ func TestDialWithAdminAPIErrors(t *testing.T) { ctx, cancel := context.WithCancel(ctx) cancel() - _, err = d.Dial(ctx, "/projects/my-project/locations/my-region/clusters/my-cluster/instances/my-instance") + _, err = d.Dial(ctx, testInstanceURI) if !errors.Is(err, context.Canceled) { t.Fatalf("when context is canceled, want = %T, got = %v", context.Canceled, err) } - _, err = d.Dial(context.Background(), "/projects/my-project/locations/my-region/clusters/my-cluster/instances/my-instance") + _, err = d.Dial(context.Background(), testInstanceURI) var wantErr2 *errtype.RefreshError if !errors.As(err, &wantErr2) { t.Fatalf("when API call fails, want = %T, got = %v", wantErr2, err) @@ -152,7 +155,7 @@ func TestDialWithUnavailableServerErrors(t *testing.T) { } d.client = c - _, err = d.Dial(ctx, "/projects/my-project/locations/my-region/clusters/my-cluster/instances/my-instance") + _, err = d.Dial(ctx, testInstanceURI) var wantErr2 *errtype.DialError if !errors.As(err, &wantErr2) { t.Fatalf("when server proxy socket is unavailable, want = %T, got = %v", wantErr2, err) @@ -191,7 +194,7 @@ func TestDialerWithCustomDialFunc(t *testing.T) { } d.client = c - _, err = d.Dial(ctx, "/projects/my-project/locations/my-region/clusters/my-cluster/instances/my-instance") + _, err = d.Dial(ctx, testInstanceURI) if !strings.Contains(err.Error(), "sentinel error") { t.Fatalf("want = sentinel error, got = %v", err) } @@ -275,7 +278,7 @@ func TestDialRefreshesExpiredCertificates(t *testing.T) { } sentinel := errors.New("connect info failed") - inst := "/projects/my-project/locations/my-region/clusters/my-cluster/instances/my-instance" + inst := testInstanceURI cn, _ := alloydb.ParseInstURI(inst) spy := &spyConnectionInfoCache{ connectInfoCalls: []struct { @@ -410,8 +413,32 @@ func TestDialerSupportsOneOffDialFunction(t *testing.T) { return nil, sentinelErr } - _, err = d.Dial(ctx, "/projects/my-project/locations/my-region/clusters/my-cluster/instances/my-instance", WithOneOffDialFunc(f)) + _, err = d.Dial(ctx, testInstanceURI, WithOneOffDialFunc(f)) if !errors.Is(err, sentinelErr) { t.Fatal("one-off dial func was not called") } } + +func TestDialerCloseReportsFriendlyError(t *testing.T) { + d, err := NewDialer( + context.Background(), + WithTokenSource(stubTokenSource{}), + ) + if err != nil { + t.Fatal(err) + } + _ = d.Close() + + _, err = d.Dial(context.Background(), testInstanceURI) + if !errors.Is(err, ErrDialerClosed) { + t.Fatalf("want = %v, got = %v", ErrDialerClosed, err) + } + + // Ensure multiple calls to close don't panic + _ = d.Close() + + _, err = d.Dial(context.Background(), testInstanceURI) + if !errors.Is(err, ErrDialerClosed) { + t.Fatalf("want = %v, got = %v", ErrDialerClosed, err) + } +}