Skip to content

Commit

Permalink
feat: add TestOnBorrowContext
Browse files Browse the repository at this point in the history
Add TestOnBorrowContext to the Pool struct for checking the health of the idle connection with a given context.
  • Loading branch information
vasayxtx committed Feb 12, 2024
1 parent 9f0d2e9 commit fa1f1b6
Show file tree
Hide file tree
Showing 3 changed files with 143 additions and 22 deletions.
11 changes: 10 additions & 1 deletion redis/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ type Pool struct {
// DialContext is an application supplied function for creating and configuring a
// connection with the given context.
//
// The connection returned from Dial must not be in a special state
// The connection returned from DialContext must not be in a special state
// (subscribed to pubsub channel, transaction started, ...).
DialContext func(ctx context.Context) (Conn, error)

Expand All @@ -139,6 +139,14 @@ type Pool struct {
// closed.
TestOnBorrow func(c Conn, t time.Time) error

// TestOnBorrowContext is an optional application supplied function
// for checking the health of an idle connection with the given context
// before the connection is used again by the application.
// Argument t is the time that the connection was returned
// to the pool. If the function returns an error, then the connection is
// closed.
TestOnBorrowContext func(ctx context.Context, c Conn, t time.Time) error

// Maximum number of idle connections in the pool.
MaxIdle int

Expand Down Expand Up @@ -228,6 +236,7 @@ func (p *Pool) GetContext(ctx context.Context) (Conn, error) {
p.idle.popFront()
p.mu.Unlock()
if (p.TestOnBorrow == nil || p.TestOnBorrow(pc.c, pc.t) == nil) &&
(p.TestOnBorrowContext == nil || p.TestOnBorrowContext(ctx, pc.c, pc.t) == nil) &&
(p.MaxConnLifetime == 0 || nowFunc().Sub(pc.created) < p.MaxConnLifetime) {
return &activeConn{p: p, pc: pc}, nil
}
Expand Down
143 changes: 124 additions & 19 deletions redis/pool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,21 +48,48 @@ func (c *poolTestConn) Close() error {
func (c *poolTestConn) Err() error { return c.err }

func (c *poolTestConn) Do(commandName string, args ...interface{}) (interface{}, error) {
return c.do(c.Conn.Do, commandName, args...)
}

func (c *poolTestConn) DoContext(ctx context.Context, commandName string, args ...interface{}) (interface{}, error) {
cwc, ok := c.Conn.(redis.ConnWithContext)
if !ok {
return nil, errors.New("redis: connection does not support ConnWithContext")
}
return c.do(
func(c string, a ...interface{}) (interface{}, error) {
return cwc.DoContext(ctx, c, a...)
},
commandName, args)
}

func (c *poolTestConn) do(
fn func(commandName string, args ...interface{}) (interface{}, error),
commandName string, args ...interface{},
) (interface{}, error) {
if commandName == "ERR" {
c.err = args[0].(error)
commandName = "PING"
}
if commandName != "" {
c.d.commands = append(c.d.commands, commandName)
}
return c.Conn.Do(commandName, args...)
return fn(commandName, args...)
}

func (c *poolTestConn) Send(commandName string, args ...interface{}) error {
c.d.commands = append(c.d.commands, commandName)
return c.Conn.Send(commandName, args...)
}

func (c *poolTestConn) ReceiveContext(ctx context.Context) (reply interface{}, err error) {
cwc, ok := c.Conn.(redis.ConnWithContext)
if !ok {
return nil, errors.New("redis: connection does not support ConnWithContext")
}
return cwc.ReceiveContext(ctx)
}

type poolDialer struct {
mu sync.Mutex
t *testing.T
Expand All @@ -73,14 +100,18 @@ type poolDialer struct {
}

func (d *poolDialer) dial() (redis.Conn, error) {
return d.dialContext(context.Background())
}

func (d *poolDialer) dialContext(ctx context.Context) (redis.Conn, error) {
d.mu.Lock()
d.dialed += 1
dialErr := d.dialErr
d.mu.Unlock()
if dialErr != nil {
return nil, d.dialErr
}
c, err := redis.DialDefaultServer()
c, err := redis.DialDefaultServerContext(ctx)
if err != nil {
return nil, err
}
Expand All @@ -90,15 +121,14 @@ func (d *poolDialer) dial() (redis.Conn, error) {
return &poolTestConn{d: d, Conn: c}, nil
}

func (d *poolDialer) dialContext(ctx context.Context) (redis.Conn, error) {
return d.dial()
}

func (d *poolDialer) check(message string, p *redis.Pool, dialed, open, inuse int) {
d.t.Helper()
d.checkAll(message, p, dialed, open, inuse, 0, 0)
}

func (d *poolDialer) checkAll(message string, p *redis.Pool, dialed, open, inuse int, waitCountMax int64, waitDurationMax time.Duration) {
d.t.Helper()

d.mu.Lock()
defer d.mu.Unlock()

Expand Down Expand Up @@ -368,21 +398,96 @@ func TestPoolConcurrenSendReceive(t *testing.T) {
}

func TestPoolBorrowCheck(t *testing.T) {
d := poolDialer{t: t}
p := &redis.Pool{
MaxIdle: 2,
Dial: d.dial,
TestOnBorrow: func(redis.Conn, time.Time) error { return redis.Error("BLAH") },
pingN := func(ctx context.Context, p *redis.Pool, n int) {
for i := 0; i < n; i++ {
func() {
c, err := p.GetContext(ctx)
require.NoError(t, err)
defer func() {
require.NoError(t, c.Close())
}()
_, err = redis.DoContext(c, ctx, "PING")
require.NoError(t, err)
}()
}
}
defer p.Close()

for i := 0; i < 10; i++ {
c := p.Get()
_, err := c.Do("PING")
require.NoError(t, err)
c.Close()
}
d.check("1", p, 10, 1, 0)
t.Run("TestOnBorrow-error", func(t *testing.T) {
d := poolDialer{t: t}
p := &redis.Pool{
MaxIdle: 2,
DialContext: d.dialContext,
TestOnBorrow: func(redis.Conn, time.Time) error { return redis.Error("BLAH") },
}
defer p.Close()
pingN(context.Background(), p, 10)
d.check("1", p, 10, 1, 0)
})

t.Run("TestOnBorrowContext-error", func(t *testing.T) {
d := poolDialer{t: t}
p := &redis.Pool{
MaxIdle: 2,
DialContext: d.dialContext,
TestOnBorrowContext: func(context.Context, redis.Conn, time.Time) error { return redis.Error("BLAH") },
}
defer p.Close()
pingN(context.Background(), p, 10)
d.check("1", p, 10, 1, 0)
})

t.Run("TestOnBorrowContext-nil-error", func(t *testing.T) {
d := poolDialer{t: t}
var borrowErrs []error
p := &redis.Pool{
MaxIdle: 2,
DialContext: d.dialContext,
TestOnBorrowContext: func(ctx context.Context, c redis.Conn, t time.Time) error {
_, err := redis.DoContext(c, ctx, "PING")
if err != nil {
borrowErrs = append(borrowErrs, err)
}
return err
},
}
defer p.Close()
pingN(context.Background(), p, 10)
require.Empty(t, borrowErrs)
d.check("1", p, 1, 1, 0)
})

t.Run("TestOnBorrowContext-context.Canceled", func(t *testing.T) {
d := poolDialer{t: t}
var borrowErrs []error
p := &redis.Pool{
MaxIdle: 2,
DialContext: d.dialContext,
TestOnBorrowContext: func(ctx context.Context, c redis.Conn, t time.Time) error {
_, err := redis.DoContext(c, ctx, "PING")
if err != nil {
borrowErrs = append(borrowErrs, err)
}
return err
},
}
defer p.Close()

ctx, ctxCancel := context.WithCancel(context.Background())
defer ctxCancel()

pingN(ctx, p, 2)
d.check("1", p, 1, 1, 0)
require.Empty(t, borrowErrs)

ctxCancel()

_, err := p.GetContext(ctx)
require.ErrorIs(t, err, context.Canceled)

d.check("1", p, 2, 0, 0)
require.Len(t, borrowErrs, 1)
require.ErrorIs(t, borrowErrs[0], context.Canceled)
})
}

func TestPoolMaxActive(t *testing.T) {
Expand Down
11 changes: 9 additions & 2 deletions redis/test_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package redis

import (
"bufio"
"context"
"errors"
"flag"
"fmt"
Expand Down Expand Up @@ -197,15 +198,21 @@ func DefaultServerAddr() (string, error) {
// DialDefaultServer starts the test server if not already started and dials a
// connection to the server.
func DialDefaultServer(options ...DialOption) (Conn, error) {
return DialDefaultServerContext(context.Background(), options...)
}

// DialDefaultServerContext starts the test server if not already started and
// dials a connection to the server with the given context.
func DialDefaultServerContext(ctx context.Context, options ...DialOption) (Conn, error) {
addr, err := DefaultServerAddr()
if err != nil {
return nil, err
}
c, err := Dial("tcp", addr, append([]DialOption{DialReadTimeout(1 * time.Second), DialWriteTimeout(1 * time.Second)}, options...)...)
c, err := DialContext(ctx, "tcp", addr, append([]DialOption{DialReadTimeout(1 * time.Second), DialWriteTimeout(1 * time.Second)}, options...)...)
if err != nil {
return nil, err
}
if _, err = c.Do("FLUSHDB"); err != nil {
if _, err = DoContext(c, ctx, "FLUSHDB"); err != nil {
return nil, err
}
return c, nil
Expand Down

0 comments on commit fa1f1b6

Please sign in to comment.