Skip to content

Commit

Permalink
add WithContext for insert / update / delete
Browse files Browse the repository at this point in the history
  • Loading branch information
lqs committed Nov 15, 2024
1 parent e01ed0b commit 57bcb50
Show file tree
Hide file tree
Showing 6 changed files with 64 additions and 3 deletions.
16 changes: 15 additions & 1 deletion delete.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package sqlingo

import (
"context"
"database/sql"
"strconv"
"strings"
Expand All @@ -11,27 +12,35 @@ type deleteStatus struct {
where BooleanExpression
orderBys []OrderBy
limit *int
ctx context.Context
}

type deleteWithTable interface {
Where(conditions ...BooleanExpression) deleteWithWhere
}

type deleteWithWhere interface {
toDeleteWithContext
toDeleteFinal
OrderBy(orderBys ...OrderBy) deleteWithOrder
Limit(limit int) deleteWithLimit
}

type deleteWithOrder interface {
toDeleteWithContext
toDeleteFinal
Limit(limit int) deleteWithLimit
}

type deleteWithLimit interface {
toDeleteWithContext
toDeleteFinal
}

type toDeleteWithContext interface {
WithContext(ctx context.Context) toDeleteFinal
}

type toDeleteFinal interface {
GetSQL() (string, error)
Execute() (result sql.Result, err error)
Expand Down Expand Up @@ -86,10 +95,15 @@ func (s deleteStatus) GetSQL() (string, error) {
return sb.String(), nil
}

func (s deleteStatus) WithContext(ctx context.Context) toDeleteFinal {
s.ctx = ctx
return s
}

func (s deleteStatus) Execute() (sql.Result, error) {
sqlString, err := s.GetSQL()
if err != nil {
return nil, err
}
return s.scope.Database.Execute(sqlString)
return s.scope.Database.ExecuteContext(s.ctx, sqlString)
}
6 changes: 6 additions & 0 deletions delete_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package sqlingo

import (
"context"
"errors"
"testing"
)
Expand Down Expand Up @@ -30,4 +31,9 @@ func TestDelete(t *testing.T) {
t.Error(err)
}
assertLastSql(t, "DELETE FROM `table1` WHERE #1# ORDER BY #2# LIMIT 3")

if _, err := db.DeleteFrom(Table1).Where(Raw("#1#")).WithContext(context.Background()).Execute(); err != nil {
t.Error(err)
}
assertLastSql(t, "DELETE FROM `table1` WHERE #1#")
}
16 changes: 15 additions & 1 deletion insert.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package sqlingo

import (
"context"
"database/sql"
"errors"
"fmt"
Expand All @@ -14,6 +15,7 @@ type insertStatus struct {
values []interface{}
models []interface{}
onDuplicateKeyUpdateAssignments []assignment
ctx context.Context
}

type insertWithTable interface {
Expand All @@ -23,13 +25,15 @@ type insertWithTable interface {
}

type insertWithValues interface {
toInsertWithContext
toInsertFinal
Values(values ...interface{}) insertWithValues
OnDuplicateKeyIgnore() toInsertFinal
OnDuplicateKeyUpdate() insertWithOnDuplicateKeyUpdateBegin
}

type insertWithModels interface {
toInsertWithContext
toInsertFinal
Models(models ...interface{}) insertWithModels
OnDuplicateKeyIgnore() toInsertFinal
Expand All @@ -42,11 +46,16 @@ type insertWithOnDuplicateKeyUpdateBegin interface {
}

type insertWithOnDuplicateKeyUpdate interface {
toInsertWithContext
toInsertFinal
Set(Field Field, value interface{}) insertWithOnDuplicateKeyUpdate
SetIf(condition bool, Field Field, value interface{}) insertWithOnDuplicateKeyUpdate
}

type toInsertWithContext interface {
WithContext(ctx context.Context) toInsertFinal
}

type toInsertFinal interface {
GetSQL() (string, error)
Execute() (result sql.Result, err error)
Expand Down Expand Up @@ -182,10 +191,15 @@ func (s insertStatus) GetSQL() (string, error) {
return sqlString, nil
}

func (s insertStatus) WithContext(ctx context.Context) toInsertFinal {
s.ctx = ctx
return s
}

func (s insertStatus) Execute() (result sql.Result, err error) {
sqlString, err := s.GetSQL()
if err != nil {
return nil, err
}
return s.scope.Database.Execute(sqlString)
return s.scope.Database.ExecuteContext(s.ctx, sqlString)
}
5 changes: 5 additions & 0 deletions insert_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package sqlingo

import (
"context"
"errors"
"testing"
)
Expand Down Expand Up @@ -142,4 +143,8 @@ func TestInsert(t *testing.T) {
OnDuplicateKeyUpdate().Set(Test.F1, errExpr).Execute(); err == nil {
t.Error("should get error here")
}

if _, err := db.InsertInto(Test).Fields(Test.F1).Values(1).WithContext(context.Background()).Execute(); err != nil {
t.Error(err)
}
}
16 changes: 15 additions & 1 deletion update.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package sqlingo

import (
"context"
"database/sql"
"strconv"
"strings"
Expand All @@ -12,6 +13,7 @@ type updateStatus struct {
where BooleanExpression
orderBys []OrderBy
limit *int
ctx context.Context
}

func (d *database) Update(table Table) updateWithSet {
Expand All @@ -27,20 +29,27 @@ type updateWithSet interface {
}

type updateWithWhere interface {
toUpdateWithContext
toUpdateFinal
OrderBy(orderBys ...OrderBy) updateWithOrder
Limit(limit int) updateWithLimit
}

type updateWithOrder interface {
toUpdateWithContext
toUpdateFinal
Limit(limit int) updateWithLimit
}

type updateWithLimit interface {
toUpdateWithContext
toUpdateFinal
}

type toUpdateWithContext interface {
WithContext(ctx context.Context) toUpdateFinal
}

type toUpdateFinal interface {
GetSQL() (string, error)
Execute() (sql.Result, error)
Expand Down Expand Up @@ -120,10 +129,15 @@ func (s updateStatus) GetSQL() (string, error) {
return sb.String(), nil
}

func (s updateStatus) WithContext(ctx context.Context) toUpdateFinal {
s.ctx = ctx
return s
}

func (s updateStatus) Execute() (sql.Result, error) {
sqlString, err := s.GetSQL()
if err != nil {
return nil, err
}
return s.scope.Database.Execute(sqlString)
return s.scope.Database.ExecuteContext(s.ctx, sqlString)
}
8 changes: 8 additions & 0 deletions update_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package sqlingo

import (
"context"
"errors"
"testing"
)
Expand Down Expand Up @@ -62,4 +63,11 @@ func TestUpdate(t *testing.T) {
t.Error("should get error here")
}

if _, err := db.Update(Table1).
Set(field1, 10).
Where(True()).
WithContext(context.Background()).
Execute(); err != nil {
t.Error(err)
}
}

0 comments on commit 57bcb50

Please sign in to comment.