From 15162afaf2a1cd1ee8c63ebc0dc14b8baa0613f7 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 10 Aug 2023 13:30:48 +0800 Subject: [PATCH] Support GetDBConnWithContext PreparedStmtDB --- prepare_stmt.go | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/prepare_stmt.go b/prepare_stmt.go index 10fefc317..9d98c86e0 100644 --- a/prepare_stmt.go +++ b/prepare_stmt.go @@ -30,15 +30,19 @@ func NewPreparedStmtDB(connPool ConnPool) *PreparedStmtDB { } } -func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) { - if dbConnector, ok := db.ConnPool.(GetDBConnector); ok && dbConnector != nil { - return dbConnector.GetDBConn() - } - +func (db *PreparedStmtDB) GetDBConnWithContext(gormdb *DB) (*sql.DB, error) { if sqldb, ok := db.ConnPool.(*sql.DB); ok { return sqldb, nil } + if connector, ok := db.ConnPool.(GetDBConnectorWithContext); ok && connector != nil { + return connector.GetDBConnWithContext(gormdb) + } + + if dbConnector, ok := db.ConnPool.(GetDBConnector); ok && dbConnector != nil { + return dbConnector.GetDBConn() + } + return nil, ErrInvalidDB } @@ -54,15 +58,15 @@ func (db *PreparedStmtDB) Close() { } } -func (db *PreparedStmtDB) Reset() { - db.Mux.Lock() - defer db.Mux.Unlock() +func (sdb *PreparedStmtDB) Reset() { + sdb.Mux.Lock() + defer sdb.Mux.Unlock() - for _, stmt := range db.Stmts { + for _, stmt := range sdb.Stmts { go stmt.Close() } - db.PreparedSQL = make([]string, 0, 100) - db.Stmts = make(map[string]*Stmt) + sdb.PreparedSQL = make([]string, 0, 100) + sdb.Stmts = make(map[string]*Stmt) } func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (Stmt, error) {