Skip to content

Commit

Permalink
feat: transaction support powerful more
Browse files Browse the repository at this point in the history
  • Loading branch information
VarusHsu committed Feb 6, 2024
1 parent 13b7aba commit cf1908f
Show file tree
Hide file tree
Showing 8 changed files with 283 additions and 30 deletions.
23 changes: 23 additions & 0 deletions common.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,29 @@ func commaOrderBys(scope scope, orderBys []OrderBy) (string, error) {
}

func getCallerInfo(db database, retry bool) string {
if !db.enableCallerInfo {
return ""
}
extraInfo := ""
if retry {
extraInfo += " (retry)"
}
for i := 0; true; i++ {
_, file, line, ok := runtime.Caller(i)
if !ok {
break
}
if file == "" || strings.Contains(file, "/sqlingo@v") {
continue
}
segs := strings.Split(file, "/")
name := segs[len(segs)-1]
return fmt.Sprintf("/* %s:%d%s */ ", name, line, extraInfo)
}
return ""
}

func getTxCallerInfo(db transaction, retry bool) string {
if !db.enableCallerInfo {
return ""
}
Expand Down
19 changes: 4 additions & 15 deletions database.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,16 +43,12 @@ type Database interface {
Update(table Table) updateWithSet
// Initiate a DELETE FROM statement
DeleteFrom(table Table) deleteWithTable
}

type txOrDB interface {
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
Begin() (Transaction, error)
}

type database struct {
db *sql.DB
tx *sql.Tx
logger func(sql string, durationNano int64)
dialect dialect
retryPolicy func(error) bool
Expand Down Expand Up @@ -101,13 +97,6 @@ func (d database) GetDB() *sql.DB {
return d.db
}

func (d database) getTxOrDB() txOrDB {
if d.tx != nil {
return d.tx
}
return d.db
}

func (d database) Query(sqlString string) (Cursor, error) {
return d.QueryContext(context.Background(), sqlString)
}
Expand All @@ -119,7 +108,7 @@ func (d database) QueryContext(ctx context.Context, sqlString string) (Cursor, e

rows, err := d.queryContextOnce(ctx, sqlStringWithCallerInfo)
if err != nil {
isRetry = d.tx == nil && d.retryPolicy != nil && d.retryPolicy(err)
isRetry = d.retryPolicy != nil && d.retryPolicy(err)
if isRetry {
continue
}
Expand All @@ -144,7 +133,7 @@ func (d database) queryContextOnce(ctx context.Context, sqlStringWithCallerInfo
interceptor := d.interceptor
var rows *sql.Rows
invoker := func(ctx context.Context, sql string) (err error) {
rows, err = d.getTxOrDB().QueryContext(ctx, sql)
rows, err = d.GetDB().QueryContext(ctx, sql)
return
}

Expand Down Expand Up @@ -180,7 +169,7 @@ func (d database) ExecuteContext(ctx context.Context, sqlString string) (sql.Res

var result sql.Result
invoker := func(ctx context.Context, sql string) (err error) {
result, err = d.getTxOrDB().ExecContext(ctx, sql)
result, err = d.GetDB().ExecContext(ctx, sql)
return
}
var err error
Expand Down
7 changes: 4 additions & 3 deletions expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,9 +147,10 @@ func (e expression) GetTable() Table {
}

type scope struct {
Database *database
Tables []Table
lastJoin *join
Transaction *transaction
Database *database
Tables []Table
lastJoin *join
}

func staticExpression(sql string, priority priority, isBool bool) expression {
Expand Down
4 changes: 3 additions & 1 deletion field.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,9 @@ func newField(table Table, fieldName string) actualField {
expression: expression{
builder: func(scope scope) (string, error) {
dialect := dialectUnknown
if scope.Database != nil {
if scope.Transaction != nil {
dialect = scope.Transaction.dialect
} else if scope.Database != nil {
dialect = scope.Database.dialect
}
if len(scope.Tables) != 1 || scope.lastJoin != nil || scope.Tables[0].GetName() != tableName {
Expand Down
20 changes: 18 additions & 2 deletions select.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ type toSelectJoin interface {

type selectWithJoin interface {
On(condition BooleanExpression) selectWithJoinOn
Using(fields ...Field) selectWithJoinOn
}

type selectWithJoinOn interface {
Expand Down Expand Up @@ -143,6 +144,7 @@ type join struct {
prefix string
table Table
on BooleanExpression
using []Field
}

type selectBase struct {
Expand Down Expand Up @@ -219,6 +221,14 @@ func (s selectStatus) On(condition BooleanExpression) selectWithJoinOn {
return s
}

func (s selectStatus) Using(fields ...Field) selectWithJoinOn {
base := activeSelectBase(&s)
join := *base.scope.lastJoin
join.using = fields
base.scope.lastJoin = &join
return s
}

func getFields(fields []interface{}) (result []Field) {
fields = expandSliceValues(fields)
result = make([]Field, 0, len(fields))
Expand Down Expand Up @@ -579,11 +589,17 @@ func (s selectStatus) FetchCursor() (Cursor, error) {
return nil, err
}

cursor, err := s.base.scope.Database.QueryContext(s.ctx, sqlString)
var c Cursor
if s.base.scope.Transaction != nil {
c, err = s.base.scope.Transaction.QueryContext(s.ctx, sqlString)
} else {
c, err = s.base.scope.Database.QueryContext(s.ctx, sqlString)
}

if err != nil {
return nil, err
}
return cursor, nil
return c, nil
}

func (s selectStatus) FetchFirst(dest ...interface{}) (ok bool, err error) {
Expand Down
3 changes: 3 additions & 0 deletions table.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ func (t table) GetName() string {
}

func (t table) GetSQL(scope scope) string {
if scope.Transaction != nil {
return t.sqlDialects[scope.Transaction.dialect]
}
return t.sqlDialects[scope.Database.dialect]
}

Expand Down
Loading

0 comments on commit cf1908f

Please sign in to comment.