diff --git a/common.go b/common.go index 6dc593e..ff26294 100644 --- a/common.go +++ b/common.go @@ -140,6 +140,29 @@ func commaOrderBys(scope scope, orderBys []OrderBy) (string, error) { } func getCallerInfo(db database, retry bool) string { + if !db.enableCallerInfo { + return "" + } + extraInfo := "" + if retry { + extraInfo += " (retry)" + } + for i := 0; true; i++ { + _, file, line, ok := runtime.Caller(i) + if !ok { + break + } + if file == "" || strings.Contains(file, "/sqlingo@v") { + continue + } + segs := strings.Split(file, "/") + name := segs[len(segs)-1] + return fmt.Sprintf("/* %s:%d%s */ ", name, line, extraInfo) + } + return "" +} + +func getTxCallerInfo(db transaction, retry bool) string { if !db.enableCallerInfo { return "" } diff --git a/database.go b/database.go index 499e95a..c16b2f9 100644 --- a/database.go +++ b/database.go @@ -43,16 +43,12 @@ type Database interface { Update(table Table) updateWithSet // Initiate a DELETE FROM statement DeleteFrom(table Table) deleteWithTable -} -type txOrDB interface { - QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) - ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) + Begin() (Transaction, error) } type database struct { db *sql.DB - tx *sql.Tx logger func(sql string, durationNano int64) dialect dialect retryPolicy func(error) bool @@ -101,13 +97,6 @@ func (d database) GetDB() *sql.DB { return d.db } -func (d database) getTxOrDB() txOrDB { - if d.tx != nil { - return d.tx - } - return d.db -} - func (d database) Query(sqlString string) (Cursor, error) { return d.QueryContext(context.Background(), sqlString) } @@ -119,7 +108,7 @@ func (d database) QueryContext(ctx context.Context, sqlString string) (Cursor, e rows, err := d.queryContextOnce(ctx, sqlStringWithCallerInfo) if err != nil { - isRetry = d.tx == nil && d.retryPolicy != nil && d.retryPolicy(err) + isRetry = d.retryPolicy != nil && d.retryPolicy(err) if isRetry { continue } @@ -144,7 +133,7 @@ func (d database) queryContextOnce(ctx context.Context, sqlStringWithCallerInfo interceptor := d.interceptor var rows *sql.Rows invoker := func(ctx context.Context, sql string) (err error) { - rows, err = d.getTxOrDB().QueryContext(ctx, sql) + rows, err = d.GetDB().QueryContext(ctx, sql) return } @@ -180,7 +169,7 @@ func (d database) ExecuteContext(ctx context.Context, sqlString string) (sql.Res var result sql.Result invoker := func(ctx context.Context, sql string) (err error) { - result, err = d.getTxOrDB().ExecContext(ctx, sql) + result, err = d.GetDB().ExecContext(ctx, sql) return } var err error diff --git a/expression.go b/expression.go index a4a1170..e2c52fc 100644 --- a/expression.go +++ b/expression.go @@ -147,9 +147,10 @@ func (e expression) GetTable() Table { } type scope struct { - Database *database - Tables []Table - lastJoin *join + Transaction *transaction + Database *database + Tables []Table + lastJoin *join } func staticExpression(sql string, priority priority, isBool bool) expression { diff --git a/field.go b/field.go index cbe5674..a20551f 100644 --- a/field.go +++ b/field.go @@ -61,7 +61,9 @@ func newField(table Table, fieldName string) actualField { expression: expression{ builder: func(scope scope) (string, error) { dialect := dialectUnknown - if scope.Database != nil { + if scope.Transaction != nil { + dialect = scope.Transaction.dialect + } else if scope.Database != nil { dialect = scope.Database.dialect } if len(scope.Tables) != 1 || scope.lastJoin != nil || scope.Tables[0].GetName() != tableName { diff --git a/select.go b/select.go index 4797b01..61233e7 100644 --- a/select.go +++ b/select.go @@ -35,6 +35,7 @@ type toSelectJoin interface { type selectWithJoin interface { On(condition BooleanExpression) selectWithJoinOn + Using(fields ...Field) selectWithJoinOn } type selectWithJoinOn interface { @@ -143,6 +144,7 @@ type join struct { prefix string table Table on BooleanExpression + using []Field } type selectBase struct { @@ -219,6 +221,14 @@ func (s selectStatus) On(condition BooleanExpression) selectWithJoinOn { return s } +func (s selectStatus) Using(fields ...Field) selectWithJoinOn { + base := activeSelectBase(&s) + join := *base.scope.lastJoin + join.using = fields + base.scope.lastJoin = &join + return s +} + func getFields(fields []interface{}) (result []Field) { fields = expandSliceValues(fields) result = make([]Field, 0, len(fields)) @@ -579,11 +589,17 @@ func (s selectStatus) FetchCursor() (Cursor, error) { return nil, err } - cursor, err := s.base.scope.Database.QueryContext(s.ctx, sqlString) + var c Cursor + if s.base.scope.Transaction != nil { + c, err = s.base.scope.Transaction.QueryContext(s.ctx, sqlString) + } else { + c, err = s.base.scope.Database.QueryContext(s.ctx, sqlString) + } + if err != nil { return nil, err } - return cursor, nil + return c, nil } func (s selectStatus) FetchFirst(dest ...interface{}) (ok bool, err error) { diff --git a/table.go b/table.go index dc363f3..09afdf9 100644 --- a/table.go +++ b/table.go @@ -24,6 +24,9 @@ func (t table) GetName() string { } func (t table) GetSQL(scope scope) string { + if scope.Transaction != nil { + return t.sqlDialects[scope.Transaction.dialect] + } return t.sqlDialects[scope.Database.dialect] } diff --git a/transaction.go b/transaction.go index 00a6778..e419cfe 100644 --- a/transaction.go +++ b/transaction.go @@ -3,11 +3,11 @@ package sqlingo import ( "context" "database/sql" + "time" ) // Transaction is the interface of a transaction with underlying sql.Tx object. type Transaction interface { - GetDB() *sql.DB GetTx() *sql.Tx Query(sql string) (Cursor, error) Execute(sql string) (sql.Result, error) @@ -18,10 +18,13 @@ type Transaction interface { InsertInto(table Table) insertWithTable Update(table Table) updateWithSet DeleteFrom(table Table) deleteWithTable -} - -func (d *database) GetTx() *sql.Tx { - return d.tx + ReplaceInto(table Table) insertWithTable + // ReplaceInto(table Table) insertWithTable + Commit() error + Rollback() error + Savepoint(name string) error + RollbackTo(name string) error + ReleaseSavepoint(name string) error } func (d *database) BeginTx(ctx context.Context, opts *sql.TxOptions, f func(tx Transaction) error) error { @@ -40,8 +43,16 @@ func (d *database) BeginTx(ctx context.Context, opts *sql.TxOptions, f func(tx T }() if f != nil { - db := *d + db := transaction{ + tx: tx, + logger: d.logger, + dialect: d.dialect, + retryPolicy: d.retryPolicy, + enableCallerInfo: d.enableCallerInfo, + interceptor: d.interceptor, + } db.tx = tx + err = f(&db) if err != nil { return err @@ -55,3 +66,213 @@ func (d *database) BeginTx(ctx context.Context, opts *sql.TxOptions, f func(tx T isCommitted = true return nil } + +func (d *database) Begin() (Transaction, error) { + var err error + tx, err := d.db.Begin() + if err != nil { + return nil, err + } + t := &transaction{ + tx: tx, + logger: d.logger, + dialect: d.dialect, + retryPolicy: d.retryPolicy, + enableCallerInfo: d.enableCallerInfo, + interceptor: d.interceptor, + } + return t, nil +} + +type transaction struct { + tx *sql.Tx + logger func(sql string, durationNano int64) + dialect dialect + retryPolicy func(error) bool + enableCallerInfo bool + interceptor InterceptorFunc +} + +func (t transaction) GetTx() *sql.Tx { + return t.tx +} + +func (t transaction) Query(sql string) (Cursor, error) { + return t.QueryContext(context.Background(), sql) +} + +func (t transaction) QueryContext(ctx context.Context, sqlString string) (Cursor, error) { + isRetry := false + for { + sqlStringWithCallerInfo := getTxCallerInfo(t, isRetry) + sqlString + + rows, err := t.queryContextOnce(ctx, sqlStringWithCallerInfo) + if err != nil { + isRetry = t.tx == nil && t.retryPolicy != nil && t.retryPolicy(err) + if isRetry { + continue + } + return nil, err + } + return cursor{rows: rows}, nil + } +} + +func (t transaction) queryContextOnce(ctx context.Context, sqlStringWithCallerInfo string) (*sql.Rows, error) { + if ctx == nil { + ctx = context.Background() + } + startTime := time.Now().UnixNano() + defer func() { + endTime := time.Now().UnixNano() + if t.logger != nil { + t.logger(sqlStringWithCallerInfo, endTime-startTime) + } + }() + + interceptor := t.interceptor + var rows *sql.Rows + invoker := func(ctx context.Context, sql string) (err error) { + rows, err = t.GetTx().QueryContext(ctx, sql) + return + } + + var err error + if interceptor == nil { + err = invoker(ctx, sqlStringWithCallerInfo) + } else { + err = interceptor(ctx, sqlStringWithCallerInfo, invoker) + } + if err != nil { + return nil, err + } + return rows, nil +} + +func (t transaction) Execute(sql string) (sql.Result, error) { + return t.ExecuteContext(context.Background(), sql) +} + +func (t transaction) ExecuteContext(ctx context.Context, sqlString string) (sql.Result, error) { + if ctx == nil { + ctx = context.Background() + } + sqlStringWithCallerInfo := getTxCallerInfo(t, false) + sqlString + startTime := time.Now().UnixNano() + defer func() { + endTime := time.Now().UnixNano() + if t.logger != nil { + t.logger(sqlStringWithCallerInfo, endTime-startTime) + } + }() + + var result sql.Result + invoker := func(ctx context.Context, sql string) (err error) { + result, err = t.GetTx().ExecContext(ctx, sql) + return + } + var err error + if t.interceptor == nil { + err = invoker(ctx, sqlStringWithCallerInfo) + } else { + err = t.interceptor(ctx, sqlStringWithCallerInfo, invoker) + } + if err != nil { + return nil, err + } + + return result, err +} + +func (t transaction) Select(fields ...interface{}) selectWithFields { + return selectStatus{ + base: selectBase{ + scope: scope{ + Transaction: &t, + }, + fields: getFields(fields), + }, + } +} + +func (t transaction) SelectDistinct(fields ...interface{}) selectWithFields { + return selectStatus{ + base: selectBase{ + scope: scope{ + Transaction: &t, + }, + fields: getFields(fields), + distinct: true, + }, + } +} + +func (t transaction) SelectFrom(tables ...Table) selectWithTables { + return selectStatus{ + base: selectBase{ + scope: scope{ + Transaction: &t, + Tables: tables, + }, + }, + } +} + +func (t transaction) InsertInto(table Table) insertWithTable { + return insertStatus{ + scope: scope{ + Transaction: &t, + Tables: []Table{table}, + }, + } +} + +func (t transaction) Update(table Table) updateWithSet { + return updateStatus{ + scope: scope{ + Transaction: &t, + Tables: []Table{table}}, + } +} + +func (t transaction) DeleteFrom(table Table) deleteWithTable { + return deleteStatus{ + scope: scope{ + Transaction: &t, + Tables: []Table{table}, + }, + } +} + +func (t transaction) ReplaceInto(table Table) insertWithTable { + return insertStatus{ + method: "REPLACE", + scope: scope{ + Transaction: &t, + Tables: []Table{table}, + }, + } +} + +func (t transaction) Commit() error { + return t.GetTx().Commit() +} + +func (t transaction) Rollback() error { + return t.GetTx().Rollback() +} + +func (t transaction) Savepoint(name string) error { + _, err := t.GetTx().Exec("SAVEPOINT " + name) + return err +} + +func (t transaction) RollbackTo(name string) error { + _, err := t.GetTx().Exec("ROLLBACK TO " + name) + return err +} + +func (t transaction) ReleaseSavepoint(name string) error { + _, err := t.GetTx().Exec("RELEASE SAVEPOINT " + name) + return err +} diff --git a/transaction_test.go b/transaction_test.go index 619d5b6..7f82473 100644 --- a/transaction_test.go +++ b/transaction_test.go @@ -29,9 +29,7 @@ func TestTransaction(t *testing.T) { db := newMockDatabase() err := db.BeginTx(nil, nil, func(tx Transaction) error { - if tx.GetDB() != db.GetDB() { - t.Error() - } + if tx.GetTx() == nil { t.Error() }