Skip to content

Commit

Permalink
test: add unit test for transaction
Browse files Browse the repository at this point in the history
  • Loading branch information
VarusHsu committed Feb 18, 2024
1 parent 5d8d7fc commit ed3a4af
Show file tree
Hide file tree
Showing 2 changed files with 139 additions and 2 deletions.
3 changes: 3 additions & 0 deletions transaction.go
Original file line number Diff line number Diff line change
Expand Up @@ -263,16 +263,19 @@ func (t transaction) Rollback() error {
return t.GetTx().Rollback()
}

// Savepoint todo defend sql injection
func (t transaction) Savepoint(name string) error {
_, err := t.GetTx().Exec("SAVEPOINT " + name)
return err
}

// RollbackTo todo defend sql injection
func (t transaction) RollbackTo(name string) error {
_, err := t.GetTx().Exec("ROLLBACK TO " + name)
return err
}

// ReleaseSavepoint todo defend sql injection
func (t transaction) ReleaseSavepoint(name string) error {
_, err := t.GetTx().Exec("RELEASE SAVEPOINT " + name)
return err
Expand Down
138 changes: 136 additions & 2 deletions transaction_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ package sqlingo

import (
"context"
"database/sql"
"errors"
"github.com/go-playground/assert/v2"
"testing"
)

Expand All @@ -27,9 +29,7 @@ func (m *mockTx) Rollback() error {

func TestTransaction(t *testing.T) {
db := newMockDatabase()

err := db.BeginTx(nil, nil, func(tx Transaction) error {

if tx.GetTx() == nil {
t.Error()
}
Expand Down Expand Up @@ -80,3 +80,137 @@ func TestTransaction(t *testing.T) {
t.Error("should get error here")
}
}

func TestTransaction_Commit(t *testing.T) {
db := newMockDatabase()
tx, err := db.Begin()
if err != nil {
t.Error(err)
}

if err = tx.Commit(); err != nil {
t.Error(err)
}

if !sharedMockConn.mockTx.isCommitted {
t.Error()
}
}

func TestTransaction_Rollback(t *testing.T) {
db := newMockDatabase()
tx, err := db.Begin()
if err != nil {
t.Error(err)
}

if err = tx.Rollback(); err != nil {
t.Error(err)
}
if !sharedMockConn.mockTx.isRolledBack {
t.Error()
}
}

func TestTransaction_Done(t *testing.T) {
db := newMockDatabase()
tx, err := db.Begin()
if err != nil {
t.Error(err)
}

if err = tx.Commit(); err != nil {
t.Error(err)
}

if err = tx.Rollback(); !errors.Is(err, sql.ErrTxDone) {
t.Error(err)
}

if err = tx.Commit(); !errors.Is(err, sql.ErrTxDone) {
t.Error(err)
}

if _, err = tx.Select(1).FetchAll(); !errors.Is(err, sql.ErrTxDone) {
t.Error(err)
}
}

func TestTransaction_Execute(t *testing.T) {
var sqlCount = make(map[string]int)
db := newMockDatabase()

tx, err := db.Begin()
if err != nil {
t.Error(err)
}
db.SetInterceptor(func(ctx context.Context, sql string, invoker InvokerFunc) error {
sqlCount[sql]++
return invoker(ctx, sql)
})

if _, err = tx.Execute("SQL 1 NOT SET INTERCEPTOR"); err != nil {
t.Error(err)
}
assert.Equal(t, sqlCount["SQL 1 NOT SET INTERCEPTOR"], 0)

if err = tx.Rollback(); err != nil {
t.Error(err)
}

tx, err = db.Begin()
if err != nil {
t.Error(err)
}
if _, err = tx.Execute("SQL 2 SET INTERCEPTOR"); err != nil {
t.Error(err)
}
assert.Equal(t, sqlCount["SQL 2 SET INTERCEPTOR"], 1)

if err = tx.Commit(); err != nil {
t.Error(err)
}
}

// TestTransaction_CRUD tests the CRUD operations in a transaction, cause sql build is tested on database,
// so we only insure there is no panic here.
func TestTransaction_CRUD(t *testing.T) {
db := newMockDatabase()
tx, err := db.Begin()
if err != nil {
t.Error(err)
}
_, err = tx.Select().From(table1).FetchAll()
if err != nil {
t.Error(err)
}

if _, err = tx.SelectFrom(table1).FetchAll(); err != nil {
t.Error(err)
}

if _, err = tx.SelectDistinct(field2).From(table1).FetchAll(); err != nil {
t.Error(err)
}

if _, err = tx.InsertInto(Test).Values(1, 2).Execute(); err != nil {
t.Error(err)
}

if _, err = tx.ReplaceInto(Test).Values(1, 2).Execute(); err != nil {
t.Error(err)
}

if _, err = tx.DeleteFrom(table1).Where().Execute(); err != nil {
t.Error(err)
}

if _, err = tx.Update(table1).Set(field1, 1).Where().Execute(); err != nil {
t.Error(err)
}

if err = tx.Rollback(); err != nil {
t.Error(err)
}

}

0 comments on commit ed3a4af

Please sign in to comment.