diff --git a/go/pools/smartconnpool/connection.go b/go/pools/smartconnpool/connection.go
index cdb5720596e..dbc235a8218 100644
--- a/go/pools/smartconnpool/connection.go
+++ b/go/pools/smartconnpool/connection.go
@@ -19,7 +19,6 @@ package smartconnpool
 import (
     "context"
     "sync/atomic"
-    "time"
 )
 
 type Connection interface {
@@ -33,8 +32,8 @@ type Connection interface {
 
 type Pooled[C Connection] struct {
     next atomic.Pointer[Pooled[C]]
 
-    timeCreated time.Time
-    timeUsed    time.Time
+    timeCreated timestamp
+    timeUsed    timestamp
     pool        *ConnPool[C]
     Conn        C
diff --git a/go/pools/smartconnpool/pool.go b/go/pools/smartconnpool/pool.go
index 52346adb1a4..87a288bec3d 100644
--- a/go/pools/smartconnpool/pool.go
+++ b/go/pools/smartconnpool/pool.go
@@ -19,7 +19,6 @@ package smartconnpool
 import (
     "context"
     "math/rand/v2"
-    "slices"
     "sync"
     "sync/atomic"
     "time"
@@ -408,13 +407,13 @@ func (pool *ConnPool[C]) put(conn *Pooled[C]) {
             return
         }
     } else {
-        conn.timeUsed = time.Now()
+        conn.timeUsed.update()
 
         lifetime := pool.extendedMaxLifetime()
-        if lifetime > 0 && time.Until(conn.timeCreated.Add(lifetime)) < 0 {
+        if lifetime > 0 && conn.timeCreated.elapsed() > lifetime {
             pool.Metrics.maxLifetimeClosed.Add(1)
             conn.Close()
-            if err := pool.connReopen(context.Background(), conn, conn.timeUsed); err != nil {
+            if err := pool.connReopen(context.Background(), conn, conn.timeUsed.get()); err != nil {
                 pool.closedConn()
                 return
             }
@@ -442,12 +441,30 @@ func (pool *ConnPool[C]) tryReturnConn(conn *Pooled[C]) bool {
     return false
 }
 
+func (pool *ConnPool[C]) pop(stack *connStack[C]) *Pooled[C] {
+    // retry-loop: pop a connection from the stack and atomically check whether
+    // its timeout has elapsed. If the timeout has elapsed, the borrow will fail,
+    // which means that a background worker has already marked this connection
+    // as stale and is in the process of shutting it down.
+    // If we successfully mark the timeout as borrowed, we know that background
+    // workers will not be able to expire this connection (even if it's still
+    // visible to them), so it's safe to return it.
+    for conn, ok := stack.Pop(); ok; conn, ok = stack.Pop() {
+        if conn.timeUsed.borrow() {
+            return conn
+        }
+    }
+    return nil
+}
+
 func (pool *ConnPool[C]) tryReturnAnyConn() bool {
-    if conn, ok := pool.clean.Pop(); ok {
+    if conn := pool.pop(&pool.clean); conn != nil {
+        conn.timeUsed.update()
         return pool.tryReturnConn(conn)
     }
     for u := 0; u <= stackMask; u++ {
-        if conn, ok := pool.settings[u].Pop(); ok {
+        if conn := pool.pop(&pool.settings[u]); conn != nil {
+            conn.timeUsed.update()
             return pool.tryReturnConn(conn)
         }
     }
@@ -479,15 +496,15 @@ func (pool *ConnPool[D]) extendedMaxLifetime() time.Duration {
     return time.Duration(maxLifetime) + time.Duration(rand.Uint32N(uint32(maxLifetime)))
 }
 
-func (pool *ConnPool[C]) connReopen(ctx context.Context, dbconn *Pooled[C], now time.Time) error {
+func (pool *ConnPool[C]) connReopen(ctx context.Context, dbconn *Pooled[C], now time.Duration) error {
     var err error
     dbconn.Conn, err = pool.config.connect(ctx)
     if err != nil {
         return err
     }
 
-    dbconn.timeUsed = now
-    dbconn.timeCreated = now
+    dbconn.timeUsed.set(now)
+    dbconn.timeCreated.set(now)
     return nil
 }
 
@@ -496,13 +513,14 @@ func (pool *ConnPool[C]) connNew(ctx context.Context) (*Pooled[C], error) {
     if err != nil {
         return nil, err
     }
-    now := time.Now()
-    return &Pooled[C]{
-        timeCreated: now,
-        timeUsed:    now,
-        pool:        pool,
-        Conn:        conn,
-    }, nil
+    pooled := &Pooled[C]{
+        pool: pool,
+        Conn: conn,
+    }
+    now := monotonicNow()
+    pooled.timeUsed.set(now)
+    pooled.timeCreated.set(now)
+    return pooled, nil
 }
 
 func (pool *ConnPool[C]) getFromSettingsStack(setting *Setting) *Pooled[C] {
@@ -515,7 +533,7 @@
     for i := uint32(0); i <= stackMask; i++ {
         pos := (i + start) & stackMask
-        if conn, ok := pool.settings[pos].Pop(); ok {
+        if conn := pool.pop(&pool.settings[pos]); conn != nil {
             return conn
         }
     }
@@ -549,7 +567,7 @@ func (pool *ConnPool[C]) get(ctx context.Context) (*Pooled[C], error) {
     pool.Metrics.getCount.Add(1)
 
     // best case: if there's a connection in the clean stack, return it right away
-    if conn, ok := pool.clean.Pop(); ok {
+    if conn := pool.pop(&pool.clean); conn != nil {
         pool.borrowed.Add(1)
         return conn, nil
     }
@@ -585,7 +603,7 @@ func (pool *ConnPool[C]) get(ctx context.Context) (*Pooled[C], error) {
         err = conn.Conn.ResetSetting(ctx)
         if err != nil {
             conn.Close()
-            err = pool.connReopen(ctx, conn, time.Now())
+            err = pool.connReopen(ctx, conn, monotonicNow())
             if err != nil {
                 pool.closedConn()
                 return nil, err
@@ -603,10 +621,10 @@ func (pool *ConnPool[C]) getWithSetting(ctx context.Context, setting *Setting) (
     var err error
 
     // best case: check if there's a connection in the setting stack where our Setting belongs
-    conn, _ := pool.settings[setting.bucket&stackMask].Pop()
+    conn := pool.pop(&pool.settings[setting.bucket&stackMask])
     // if there's connection with our setting, try popping a clean connection
     if conn == nil {
-        conn, _ = pool.clean.Pop()
+        conn = pool.pop(&pool.clean)
     }
     // otherwise try opening a brand new connection and we'll apply the setting to it
     if conn == nil {
@@ -645,7 +663,7 @@ func (pool *ConnPool[C]) getWithSetting(ctx context.Context, setting *Setting) (
         err = conn.Conn.ResetSetting(ctx)
         if err != nil {
             conn.Close()
-            err = pool.connReopen(ctx, conn, time.Now())
+            err = pool.connReopen(ctx, conn, monotonicNow())
             if err != nil {
                 pool.closedConn()
                 return nil, err
@@ -710,7 +728,7 @@ func (pool *ConnPool[C]) setCapacity(ctx context.Context, newcap int64) error {
         // try closing from connections which are currently idle in the stacks
         conn := pool.getFromSettingsStack(nil)
         if conn == nil {
-            conn, _ = pool.clean.Pop()
+            conn = pool.pop(&pool.clean)
         }
         if conn == nil {
             time.Sleep(delay)
@@ -732,21 +750,22 @@ func (pool *ConnPool[C]) closeIdleResources(now time.Time) {
         return
     }
 
-    var conns []*Pooled[C]
+    mono := monotonicFromTime(now)
 
     closeInStack := func(s *connStack[C]) {
-        conns = s.PopAll(conns[:0])
-        slices.Reverse(conns)
-
-        for _, conn := range conns {
-            if conn.timeUsed.Add(timeout).Sub(now) < 0 {
+        // Do a read-only, best-effort iteration of all the connections in this
+        // stack and atomically attempt to mark them as expired.
+        // Any connections that are marked as expired are _not_ removed from
+        // the stack; it's generally unsafe to remove nodes from the stack
+        // besides the head. When clients pop from the stack, they'll immediately
+        // notice the expired connection and ignore it.
+        // see: timestamp.expired
+        for conn := s.Peek(); conn != nil; conn = conn.next.Load() {
+            if conn.timeUsed.expired(mono, timeout) {
                 pool.Metrics.idleClosed.Add(1)
                 conn.Close()
                 pool.closedConn()
-                continue
             }
-
-            s.Push(conn)
         }
     }
diff --git a/go/pools/smartconnpool/pool_test.go b/go/pools/smartconnpool/pool_test.go
index 2ac3d1b00e3..202261e6f3b 100644
--- a/go/pools/smartconnpool/pool_test.go
+++ b/go/pools/smartconnpool/pool_test.go
@@ -619,8 +619,14 @@ func TestIdleTimeout(t *testing.T) {
         p.put(conn)
     }
 
+    time.Sleep(1 * time.Second)
+
     for _, closed := range closers {
-        <-closed
+        select {
+        case <-closed:
+        default:
+            t.Fatalf("Connections remain open after 1 second")
+        }
     }
 
     // no need to assert anything: all the connections in the pool should are idle-closed
diff --git a/go/pools/smartconnpool/stack.go b/go/pools/smartconnpool/stack.go
index ea7ae50201e..8d656ee4e8d 100644
--- a/go/pools/smartconnpool/stack.go
+++ b/go/pools/smartconnpool/stack.go
@@ -25,6 +25,9 @@ import (
 // connStack is a lock-free stack for Connection objects. It is safe to
 // use from several goroutines.
 type connStack[C Connection] struct {
+    // top is a pointer to the top node on the stack and to an increasing
+    // counter of pop operations, to prevent A-B-A races.
+    // See: https://en.wikipedia.org/wiki/ABA_problem
     top atomic2.PointerAndUint64[Pooled[C]]
 }
 
@@ -54,24 +57,7 @@ func (s *connStack[C]) Pop() (*Pooled[C], bool) {
     }
 }
 
-func (s *connStack[C]) PopAll(out []*Pooled[C]) []*Pooled[C] {
-    var oldHead *Pooled[C]
-
-    for {
-        var popCount uint64
-        oldHead, popCount = s.top.Load()
-        if oldHead == nil {
-            return out
-        }
-        if s.top.CompareAndSwap(oldHead, popCount, nil, popCount+1) {
-            break
-        }
-        runtime.Gosched()
-    }
-
-    for oldHead != nil {
-        out = append(out, oldHead)
-        oldHead = oldHead.next.Load()
-    }
-    return out
+func (s *connStack[C]) Peek() *Pooled[C] {
+    top, _ := s.top.Load()
+    return top
 }
diff --git a/go/pools/smartconnpool/timestamp.go b/go/pools/smartconnpool/timestamp.go
new file mode 100644
index 00000000000..961ff18a5c5
--- /dev/null
+++ b/go/pools/smartconnpool/timestamp.go
@@ -0,0 +1,94 @@
+package smartconnpool
+
+import (
+    "math"
+    "sync/atomic"
+    "time"
+)
+
+var monotonicRoot = time.Now()
+
+// timestamp is a monotonic point in time, stored as a number of
+// nanoseconds since the monotonic root.
+// This type is only 8 bytes and hence can always be accessed atomically.
+type timestamp struct {
+    nano atomic.Int64
+}
+
+// timestampExpired is a special value that means this timestamp is now past
+// an arbitrary expiration point, and hence no longer needs to store an actual time.
+const timestampExpired = math.MaxInt64
+
+// timestampBusy is a special value that means this timestamp no longer
+// tracks an expiration point
+const timestampBusy = math.MinInt64
+
+// monotonicNow returns the current monotonic time as a time.Duration.
+// This is a very efficient operation because time.Since performs direct
+// subtraction of monotonic times without considering the wall-clock times.
+func monotonicNow() time.Duration {
+    return time.Since(monotonicRoot)
+}
+
+// monotonicFromTime converts a wall-clock time from time.Now into a
+// monotonic timestamp.
+// This is a very efficient operation because time.(*Time).Sub performs direct
+// subtraction of monotonic times without considering the wall-clock times.
+func monotonicFromTime(now time.Time) time.Duration {
+    return now.Sub(monotonicRoot)
+}
+
+// set sets this timestamp to the given monotonic value
+func (t *timestamp) set(mono time.Duration) {
+    t.nano.Store(int64(mono))
+}
+
+// get returns the monotonic time of this timestamp as the number of nanoseconds
+// since the monotonic root.
+func (t *timestamp) get() time.Duration {
+    return time.Duration(t.nano.Load())
+}
+
+// elapsed returns the amount of time that has passed since
+// this timestamp was last updated
+func (t *timestamp) elapsed() time.Duration {
+    return monotonicNow() - t.get()
+}
+
+// update sets this timestamp's value to the current monotonic time
+func (t *timestamp) update() {
+    t.nano.Store(int64(monotonicNow()))
+}
+
+// borrow attempts to borrow this timestamp atomically.
+// It only succeeds if we can ensure that nobody else has marked
+// this timestamp as expired. On success, the timestamp is set to
+// "busy", as it no longer tracks an expiration point.
+func (t *timestamp) borrow() bool {
+    stamp := t.nano.Load()
+    switch stamp {
+    case timestampExpired:
+        return false
+    case timestampBusy:
+        panic("timestampBusy when borrowing a time")
+    default:
+        return t.nano.CompareAndSwap(stamp, timestampBusy)
+    }
+}
+
+// expired attempts to atomically expire this timestamp.
+// It only succeeds if we can ensure the timestamp hasn't been
+// concurrently expired or borrowed.
+func (t *timestamp) expired(now time.Duration, timeout time.Duration) bool {
+    stamp := t.nano.Load()
+    if stamp == timestampExpired {
+        return false
+    }
+    if stamp == timestampBusy {
+        return false
+    }
+    if now-time.Duration(stamp) > timeout {
+        return t.nano.CompareAndSwap(stamp, timestampExpired)
+    }
+    return false
+}
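
Note: the sketch below is illustrative and not part of the patch. It re-declares a minimal copy of the timestamp type from timestamp.go to show the invariant that pop() and closeIdleResources() rely on: the client's borrow() and the sweeper's expired() both CAS against the value they loaded, so exactly one of them can win and a connection is either handed out or idle-closed, never both. The goroutines and printed output are only for demonstration.

package main

import (
    "fmt"
    "math"
    "sync/atomic"
    "time"
)

// Minimal copy of the sentinel values and the two operations involved in the race.
const (
    timestampExpired = math.MaxInt64
    timestampBusy    = math.MinInt64
)

var monotonicRoot = time.Now()

func monotonicNow() time.Duration { return time.Since(monotonicRoot) }

type timestamp struct{ nano atomic.Int64 }

func (t *timestamp) set(mono time.Duration) { t.nano.Store(int64(mono)) }

// borrow mirrors timestamp.borrow: it fails if the sweeper already expired us.
func (t *timestamp) borrow() bool {
    stamp := t.nano.Load()
    switch stamp {
    case timestampExpired:
        return false
    case timestampBusy:
        panic("double borrow")
    default:
        return t.nano.CompareAndSwap(stamp, timestampBusy)
    }
}

// expired mirrors timestamp.expired: it fails if the value changed under us.
func (t *timestamp) expired(now, timeout time.Duration) bool {
    stamp := t.nano.Load()
    if stamp == timestampExpired || stamp == timestampBusy {
        return false
    }
    if now-time.Duration(stamp) > timeout {
        return t.nano.CompareAndSwap(stamp, timestampExpired)
    }
    return false
}

func main() {
    var ts timestamp
    ts.set(monotonicNow() - 2*time.Second) // pretend the connection went idle 2s ago

    borrowed := make(chan bool, 1)
    expired := make(chan bool, 1)
    go func() { borrowed <- ts.borrow() }()                            // client popping the stack
    go func() { expired <- ts.expired(monotonicNow(), time.Second) }() // idle-close sweeper

    b, e := <-borrowed, <-expired
    fmt.Printf("borrowed=%v expired=%v\n", b, e) // exactly one of the two is true
}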
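
On the representation itself: both monotonicNow and monotonicFromTime return a plain nanosecond offset from the fixed monotonicRoot, which is why a timestamp fits in a single atomic.Int64 and why age checks never touch the wall clock. A small sketch under the same assumptions (the helper names match the diff; the main function is illustrative only):

package main

import (
    "fmt"
    "sync/atomic"
    "time"
)

var monotonicRoot = time.Now()

// monotonicNow mirrors the helper in timestamp.go: time.Since subtracts
// monotonic clock readings directly, so this is effectively one int64 subtraction.
func monotonicNow() time.Duration { return time.Since(monotonicRoot) }

// monotonicFromTime mirrors the helper used by closeIdleResources, for callers
// that already hold a time.Time from a ticker.
func monotonicFromTime(now time.Time) time.Duration { return now.Sub(monotonicRoot) }

func main() {
    var lastUsed atomic.Int64 // same 8-byte storage as timestamp.nano

    lastUsed.Store(int64(monotonicNow()))
    time.Sleep(10 * time.Millisecond)

    // Elapsed time is a subtraction of two offsets; immune to wall-clock jumps.
    elapsed := monotonicFromTime(time.Now()) - time.Duration(lastUsed.Load())
    fmt.Println(elapsed >= 10*time.Millisecond) // true
}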