Skip to content

Commit

Permalink
Merge pull request #108 from Xuyuanp/master
Browse files Browse the repository at this point in the history
Add new method WithTx for Database
  • Loading branch information
doug-martin authored Jul 24, 2019
2 parents c7d8e67 + 36be327 commit 8edff69
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 0 deletions.
9 changes: 9 additions & 0 deletions database.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,15 @@ func (d *Database) BeginTx(ctx context.Context, opts *sql.TxOptions) (*TxDatabas
return tx, nil
}

// WithTx starts a new transaction and executes it in Wrap method
func (d *Database) WithTx(fn func(*TxDatabase) error) error {
tx, err := d.Begin()
if err != nil {
return err
}
return tx.Wrap(func() error { return fn(tx) })
}

// Creates a new Dataset that uses the correct adapter and supports queries.
// var ids []uint32
// if err := db.From("items").Where(goqu.I("id").Gt(10)).Pluck("id", &ids); err != nil {
Expand Down
20 changes: 20 additions & 0 deletions database_example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,26 @@ func ExampleDatabase_BeginTx() {
// Updated users in transaction [ids:=[1 2 3]]
}

func ExampleDatabase_WithTx() {
db := getDb()
var ids []int64
if err := db.WithTx(func(tx *goqu.TxDatabase) error {
// use tx.From to get a dataset that will execute within this transaction
update := tx.From("goqu_user").
Where(goqu.Ex{"last_name": "Yukon"}).
Returning("id").
Update(goqu.Record{"last_name": "Ucon"})

return update.ScanVals(&ids)
}); err != nil {
fmt.Println("An error occurred in transaction\n\t", err.Error())
} else {
fmt.Printf("Updated users in transaction [ids:=%+v]", ids)
}
// Output:
// Updated users in transaction [ids:=[1 2 3]]
}

func ExampleDatabase_Dialect() {
db := getDb()

Expand Down
59 changes: 59 additions & 0 deletions database_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,65 @@ func (dt *databaseTest) TestBeginTx() {
assert.EqualError(t, err, "goqu: transaction error")
}

func (dt *databaseTest) TestWithTx() {
t := dt.T()
mDb, mock, err := sqlmock.New()
assert.NoError(t, err)

db := newDatabase("mock", mDb)

cases := []struct {
expectf func(sqlmock.Sqlmock)
f func(*TxDatabase) error
wantErr bool
errStr string
}{
{
expectf: func(mock sqlmock.Sqlmock) {
mock.ExpectBegin()
mock.ExpectCommit()
},
f: func(_ *TxDatabase) error { return nil },
wantErr: false,
},
{
expectf: func(mock sqlmock.Sqlmock) {
mock.ExpectBegin().WillReturnError(errors.New("transaction begin error"))
},
f: func(_ *TxDatabase) error { return nil },
wantErr: true,
errStr: "goqu: transaction begin error",
},
{
expectf: func(mock sqlmock.Sqlmock) {
mock.ExpectBegin()
mock.ExpectRollback()
},
f: func(_ *TxDatabase) error { return errors.New("transaction error") },
wantErr: true,
errStr: "goqu: transaction error",
},
{
expectf: func(mock sqlmock.Sqlmock) {
mock.ExpectBegin()
mock.ExpectRollback().WillReturnError(errors.New("transaction rollback error"))
},
f: func(_ *TxDatabase) error { return errors.New("something wrong") },
wantErr: true,
errStr: "goqu: transaction rollback error",
},
}
for _, c := range cases {
c.expectf(mock)
err := db.WithTx(c.f)
if c.wantErr {
assert.EqualError(t, err, c.errStr)
} else {
assert.NoError(t, err)
}
}
}

func TestDatabaseSuite(t *testing.T) {
suite.Run(t, new(databaseTest))
}
Expand Down

0 comments on commit 8edff69

Please sign in to comment.