diff --git a/internal/alloydb/instance.go b/internal/alloydb/instance.go index d86b5640..c242b2d7 100644 --- a/internal/alloydb/instance.go +++ b/internal/alloydb/instance.go @@ -189,7 +189,11 @@ func (i *Instance) OpenConns() *uint64 { // Close closes the instance; it stops the refresh cycle and prevents it from // making additional calls to the AlloyDB Admin API. func (i *Instance) Close() error { + i.resultGuard.Lock() + defer i.resultGuard.Unlock() i.cancel() + i.cur.cancel() + i.next.cancel() return nil } @@ -230,6 +234,8 @@ func (i *Instance) result(ctx context.Context) (*refreshOperation, error) { err = res.err case <-ctx.Done(): err = ctx.Err() + case <-i.ctx.Done(): + err = i.ctx.Err() } if err != nil { return nil, err @@ -260,6 +266,13 @@ func (i *Instance) scheduleRefresh(d time.Duration) *refreshOperation { r := &refreshOperation{} r.ready = make(chan struct{}) r.timer = time.AfterFunc(d, func() { + // instance has been closed, don't schedule anything + if err := i.ctx.Err(); err != nil { + r.err = err + close(r.ready) + return + } + ctx, cancel := context.WithTimeout(i.ctx, i.refreshTimeout) defer cancel() @@ -280,6 +293,7 @@ func (i *Instance) scheduleRefresh(d time.Duration) *refreshOperation { // result and schedule a new refresh i.resultGuard.Lock() defer i.resultGuard.Unlock() + // if failed, scheduled the next refresh immediately if r.err != nil { i.next = i.scheduleRefresh(0) @@ -297,12 +311,6 @@ func (i *Instance) scheduleRefresh(d time.Duration) *refreshOperation { // Update the current results, and schedule the next refresh in // the future i.cur = r - select { - case <-i.ctx.Done(): - // instance has been closed, don't schedule anything - return - default: - } t := refreshDuration(time.Now(), i.cur.result.expiry) i.next = i.scheduleRefresh(t) }) diff --git a/internal/alloydb/instance_test.go b/internal/alloydb/instance_test.go index f3a9c149..2db130b3 100644 --- a/internal/alloydb/instance_test.go +++ b/internal/alloydb/instance_test.go @@ -19,7 +19,6 @@ import ( "crypto/rand" "crypto/rsa" "errors" - "strings" "testing" "time" @@ -216,7 +215,7 @@ func TestClose(t *testing.T) { i.Close() _, _, err = i.ConnectInfo(ctx) - if !strings.Contains(err.Error(), "context was canceled or expired") { + if !errors.Is(err, context.Canceled) { t.Fatalf("failed to retrieve connect info: %v", err) } }