From c509c9c3e5c348d1c7bf40d069b16069e04b743b Mon Sep 17 00:00:00 2001 From: "Jeremy L. Morris" Date: Thu, 31 Oct 2019 21:41:00 -0400 Subject: [PATCH] [WIP] Create function like TestOnBorrow that includes context --- redis/pool.go | 7 ++++- redis/pool_test.go | 65 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 71 insertions(+), 1 deletion(-) diff --git a/redis/pool.go b/redis/pool.go index 568374ab..2f417a97 100644 --- a/redis/pool.go +++ b/redis/pool.go @@ -141,6 +141,10 @@ type Pool struct { // closed. TestOnBorrow func(c Conn, t time.Time) error + // TestOnBorrowWithContext is the same as TestOnBorrow, but includes + // the context. + TestOnBorrowWithContext func(c Conn, t time.Time, ctx context.Context) error + // Maximum number of idle connections in the pool. MaxIdle int @@ -354,7 +358,8 @@ func (p *Pool) get(ctx context.Context) (*poolConn, error) { pc := p.idle.front p.idle.popFront() p.mu.Unlock() - if (p.TestOnBorrow == nil || p.TestOnBorrow(pc.c, pc.t) == nil) && + if (p.TestOnBorrowWithContext == nil || p.TestOnBorrowWithContext(pc.c, pc.t, p.GetContext(ctx))) && + (p.TestOnBorrow == nil || p.TestOnBorrow(pc.c, pc.t) == nil) && (p.MaxConnLifetime == 0 || nowFunc().Sub(pc.created) < p.MaxConnLifetime) { return pc, nil } diff --git a/redis/pool_test.go b/redis/pool_test.go index 3864d2a9..37350c3d 100644 --- a/redis/pool_test.go +++ b/redis/pool_test.go @@ -368,6 +368,23 @@ func TestPoolBorrowCheck(t *testing.T) { d.check("1", p, 10, 1, 0) } +func TestPoolBorrowWithContextCheck(t *testing.T) { + d := poolDialer{t: t} + p := &redis.Pool{ + MaxIdle: 2, + Dial: d.dial, + TestOnBorrowWithContext: func(redis.Conn, time.Time, context.Context) error { return redis.Error("BLAH") }, + } + defer p.Close() + + for i := 0; i < 10; i++ { + c := p.Get() + c.Do("PING") + c.Close() + } + d.check("1", p, 10, 1, 0) +} + func TestPoolMaxActive(t *testing.T) { d := poolDialer{t: t} p := &redis.Pool{ @@ -754,6 +771,54 @@ func TestLocking_TestOnBorrowFails_PoolDoesntCrash(t *testing.T) { } } +// Borrowing requires us to iterate over the idle connections, unlock the pool, +// and perform a blocking operation to check the connection still works. If +// TestOnBorrow fails, we must reacquire the lock and continue iteration. This +// test ensures that iteration will work correctly if multiple threads are +// iterating simultaneously. +func TestLocking_TestOnBorrowWithContextFails_PoolDoesntCrash(t *testing.T) { + const count = 100 + + // First we'll Create a pool where the pilfering of idle connections fails. + d := poolDialer{t: t} + p := &redis.Pool{ + MaxIdle: count, + MaxActive: count, + Dial: d.dial, + TestOnBorrowWithContext: func(c redis.Conn, t time.Time, ctx context.Context) error { + return errors.New("No way back into the real world.") + }, + } + defer p.Close() + + // Fill the pool with idle connections. + conns := make([]redis.Conn, count) + for i := range conns { + conns[i] = p.Get() + } + for i := range conns { + conns[i].Close() + } + + // Spawn a bunch of goroutines to thrash the pool. + var wg sync.WaitGroup + wg.Add(count) + for i := 0; i < count; i++ { + go func() { + c := p.Get() + if c.Err() != nil { + t.Errorf("pool get failed: %v", c.Err()) + } + c.Close() + wg.Done() + }() + } + wg.Wait() + if d.dialed != count*2 { + t.Errorf("Expected %d dials, got %d", count*2, d.dialed) + } +} + func BenchmarkPoolGet(b *testing.B) { b.StopTimer() p := redis.Pool{Dial: func() (redis.Conn, error) { return redis.DialDefaultServer() }, MaxIdle: 2}