Skip to content

Commit

Permalink
Proof of concept for ErrorHandler interface
Browse files Browse the repository at this point in the history
This implementation is not complete. It just demonstrates the
ideas behind the ErrorHandler interface.
  • Loading branch information
Craig Fitzpatrick committed Jul 9, 2021
1 parent b597672 commit 9cbf4b8
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 6 deletions.
60 changes: 60 additions & 0 deletions errorhandler.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package mssql

import (
"context"
"database/sql/driver"
)

// ErrorHandler interface allows user to alter the error returned by
// the driver to database/sql
type ErrorHandler interface {
HandleError(
queryCtx context.Context,
transactionCtx context.Context,
stmt string,
connectionGood bool,
err error) error
}

// optionalErrorHandler implements the ErrorHandler interface with
// a default behavior (pass back the error unaltered) that can be
// overridden with an optional ErrorHandler supplied by the user.
type optionalErrorHandler struct {
errorHandler ErrorHandler
}

// HandleError returns the orignal error by default unless the user
// has specified an optional ErrorHandler to override the default.
func (o optionalErrorHandler) HandleError(
queryCtx context.Context,
transactionCtx context.Context,
stmt string,
connectionGood bool,
err error) error {
if nil != o.errorHandler {
return o.errorHandler.HandleError(queryCtx, transactionCtx, stmt, connectionGood, err)
} else {
return err
}
}

// AutoRetryErrorHandler implements the ErrorHandler interface by
// returning driver.ErrBadConn when the underlying connection is bad.
// This allows the auto-retry logic in database/sql to work as designed,
// at the expense of error detail (see issues #275 and #586)
type AutoRetryErrorHandler struct{}

// HandleError returns driver.ErrBadConn when the underlying connection
// is bad, allowing database/sql auto-retry logic to work as designed.
func (a AutoRetryErrorHandler) HandleError(
queryCtx context.Context,
transactionCtx context.Context,
stmt string,
connectionGood bool,
err error) error {
if !connectionGood {
return driver.ErrBadConn
} else {
return err
}
}
24 changes: 18 additions & 6 deletions mssql.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ func (d netDialer) DialContext(ctx context.Context, network string, addr string)
}

type Driver struct {
log optionalLogger
log optionalLogger
errHandler optionalErrorHandler

processQueryText bool
}
Expand Down Expand Up @@ -78,6 +79,15 @@ func (d *Driver) SetLogger(logger Logger) {
d.log = optionalLogger{logger}
}

func SetErrorHandler(errorHandler ErrorHandler) {
driverInstance.SetErrorHandler(errorHandler)
driverInstanceNoProcess.SetErrorHandler(errorHandler)
}

func (d *Driver) SetErrorHandler(errorHandler ErrorHandler) {
d.errHandler = optionalErrorHandler{errorHandler}
}

// NewConnector creates a new connector from a DSN.
// The returned connector may be used with sql.OpenDB.
func NewConnector(dsn string) (*Connector, error) {
Expand Down Expand Up @@ -148,6 +158,7 @@ func (c *Connector) getDialer(p *connectParams) Dialer {

type Conn struct {
connector *Connector
errHandler optionalErrorHandler
sess *tdsSession
transactionCtx context.Context
resetSession bool
Expand Down Expand Up @@ -179,7 +190,8 @@ func (c *Conn) checkBadConn(err error) error {
return nil
case io.EOF:
c.connectionGood = false
return driver.ErrBadConn
// fixme: This is just a POC. The arguments aren't fully populated.
return c.errHandler.HandleError(context.TODO(), context.TODO(), "", c.connectionGood, driver.ErrBadConn)
case driver.ErrBadConn:
// It is an internal programming error if driver.ErrBadConn
// is ever passed to this function. driver.ErrBadConn should
Expand All @@ -191,13 +203,12 @@ func (c *Conn) checkBadConn(err error) error {
switch err.(type) {
case net.Error:
c.connectionGood = false
return err
case StreamError:
c.connectionGood = false
return err
default:
return err
}

// fixme: This is just a POC. The arguments aren't fully populated.
return c.errHandler.HandleError(context.TODO(), context.TODO(), "", c.connectionGood, err)
}

func (c *Conn) clearOuts() {
Expand Down Expand Up @@ -347,6 +358,7 @@ func (d *Driver) connect(ctx context.Context, c *Connector, params connectParams

conn := &Conn{
connector: c,
errHandler: d.errHandler,
sess: sess,
transactionCtx: context.Background(),
processQueryText: d.processQueryText,
Expand Down

0 comments on commit 9cbf4b8

Please sign in to comment.