Skip to content

Commit

Permalink
Pass contexts to checkBadConn and HandleError
Browse files Browse the repository at this point in the history
  • Loading branch information
Craig Fitzpatrick committed Jul 9, 2021
1 parent 9cbf4b8 commit a47dcd0
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 17 deletions.
2 changes: 1 addition & 1 deletion bulkcopy.go
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ func (b *Bulk) Done() (rowcount int64, err error) {
reader := startReading(b.cn.sess, b.ctx, nil)
err = reader.iterateResponse()
if err != nil {
return 0, b.cn.checkBadConn(err)
return 0, b.cn.checkBadConn(b.ctx, err)
}

return reader.rowCount, nil
Expand Down
31 changes: 15 additions & 16 deletions mssql.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ func (c *Conn) IsValid() bool {
return c.connectionGood
}

func (c *Conn) checkBadConn(err error) error {
func (c *Conn) checkBadConn(queryCtx context.Context, err error) error {
// this is a hack to address Issue #275
// we set connectionGood flag to false if
// error indicates that connection is not usable
Expand All @@ -191,7 +191,7 @@ func (c *Conn) checkBadConn(err error) error {
case io.EOF:
c.connectionGood = false
// fixme: This is just a POC. The arguments aren't fully populated.
return c.errHandler.HandleError(context.TODO(), context.TODO(), "", c.connectionGood, driver.ErrBadConn)
return c.errHandler.HandleError(queryCtx, c.transactionCtx, "fixme", c.connectionGood, driver.ErrBadConn)
case driver.ErrBadConn:
// It is an internal programming error if driver.ErrBadConn
// is ever passed to this function. driver.ErrBadConn should
Expand All @@ -208,7 +208,7 @@ func (c *Conn) checkBadConn(err error) error {
}

// fixme: This is just a POC. The arguments aren't fully populated.
return c.errHandler.HandleError(context.TODO(), context.TODO(), "", c.connectionGood, err)
return c.errHandler.HandleError(queryCtx, c.transactionCtx, "fixme", c.connectionGood, err)
}

func (c *Conn) clearOuts() {
Expand All @@ -219,20 +219,19 @@ func (c *Conn) simpleProcessResp(ctx context.Context) error {
reader := startReading(c.sess, ctx, c.outs)
c.clearOuts()

var resultError error
err := reader.iterateResponse()
if err != nil {
return c.checkBadConn(err)
return c.checkBadConn(ctx, err)
}
return resultError
return nil
}

func (c *Conn) Commit() error {
if !c.connectionGood {
return driver.ErrBadConn
}
if err := c.sendCommitRequest(); err != nil {
return c.checkBadConn(err)
return c.checkBadConn(c.transactionCtx, err)
}
return c.simpleProcessResp(c.transactionCtx)
}
Expand All @@ -259,7 +258,7 @@ func (c *Conn) Rollback() error {
return driver.ErrBadConn
}
if err := c.sendRollbackRequest(); err != nil {
return c.checkBadConn(err)
return c.checkBadConn(c.transactionCtx, err)
}
return c.simpleProcessResp(c.transactionCtx)
}
Expand Down Expand Up @@ -291,7 +290,7 @@ func (c *Conn) begin(ctx context.Context, tdsIsolation isoLevel) (tx driver.Tx,
}
err = c.sendBeginRequest(ctx, tdsIsolation)
if err != nil {
return nil, c.checkBadConn(err)
return nil, c.checkBadConn(ctx, err)
}
tx, err = c.processBeginResponse(ctx)
if err != nil {
Expand Down Expand Up @@ -600,7 +599,7 @@ func (s *Stmt) queryContext(ctx context.Context, args []namedValue) (rows driver
return nil, driver.ErrBadConn
}
if err = s.sendQuery(args); err != nil {
return nil, s.c.checkBadConn(err)
return nil, s.c.checkBadConn(ctx, err)
}
return s.processQueryResponse(ctx)
}
Expand Down Expand Up @@ -633,7 +632,7 @@ loop:
if token.isError() {
// need to cleanup cancellable context
cancel()
return nil, s.c.checkBadConn(token.getError())
return nil, s.c.checkBadConn(ctx, token.getError())
}
case ReturnStatus:
s.c.sess.setReturnStatus(token)
Expand All @@ -642,7 +641,7 @@ loop:
} else {
// need to cleanup cancellable context
cancel()
return nil, s.c.checkBadConn(err)
return nil, s.c.checkBadConn(ctx, err)
}
}
res = &Rows{stmt: s, reader: reader, cols: cols, cancel: cancel}
Expand All @@ -658,7 +657,7 @@ func (s *Stmt) exec(ctx context.Context, args []namedValue) (res driver.Result,
return nil, driver.ErrBadConn
}
if err = s.sendQuery(args); err != nil {
return nil, s.c.checkBadConn(err)
return nil, s.c.checkBadConn(ctx, err)
}
if res, err = s.processExec(ctx); err != nil {
return nil, err
Expand All @@ -671,7 +670,7 @@ func (s *Stmt) processExec(ctx context.Context) (res driver.Result, err error) {
s.c.clearOuts()
err = reader.iterateResponse()
if err != nil {
return nil, s.c.checkBadConn(err)
return nil, s.c.checkBadConn(ctx, err)
}
return &Result{s.c, reader.rowCount}, nil
}
Expand Down Expand Up @@ -741,15 +740,15 @@ func (rc *Rows) Next(dest []driver.Value) error {
return nil
case doneStruct:
if tokdata.isError() {
return rc.stmt.c.checkBadConn(tokdata.getError())
return rc.stmt.c.checkBadConn(rc.reader.ctx, tokdata.getError())
}
case ReturnStatus:
rc.stmt.c.sess.setReturnStatus(tokdata)
}
}

} else {
return rc.stmt.c.checkBadConn(err)
return rc.stmt.c.checkBadConn(rc.reader.ctx, err)
}
}
}
Expand Down

0 comments on commit a47dcd0

Please sign in to comment.