diff --git a/database.go b/database.go index 7e457b47..e58d6e0f 100644 --- a/database.go +++ b/database.go @@ -90,6 +90,15 @@ func (d *Database) BeginTx(ctx context.Context, opts *sql.TxOptions) (*TxDatabas return tx, nil } +// WithTx starts a new transaction and executes it in Wrap method +func (d *Database) WithTx(fn func(*TxDatabase) error) error { + tx, err := d.Begin() + if err != nil { + return err + } + return tx.Wrap(func() error { return fn(tx) }) +} + // Creates a new Dataset that uses the correct adapter and supports queries. // var ids []uint32 // if err := db.From("items").Where(goqu.I("id").Gt(10)).Pluck("id", &ids); err != nil { diff --git a/database_example_test.go b/database_example_test.go index 7878dcf5..627b7437 100644 --- a/database_example_test.go +++ b/database_example_test.go @@ -74,6 +74,26 @@ func ExampleDatabase_BeginTx() { // Updated users in transaction [ids:=[1 2 3]] } +func ExampleDatabase_WithTx() { + db := getDb() + var ids []int64 + if err := db.WithTx(func(tx *goqu.TxDatabase) error { + // use tx.From to get a dataset that will execute within this transaction + update := tx.From("goqu_user"). + Where(goqu.Ex{"last_name": "Yukon"}). + Returning("id"). + Update(goqu.Record{"last_name": "Ucon"}) + + return update.ScanVals(&ids) + }); err != nil { + fmt.Println("An error occurred in transaction\n\t", err.Error()) + } else { + fmt.Printf("Updated users in transaction [ids:=%+v]", ids) + } + // Output: + // Updated users in transaction [ids:=[1 2 3]] +} + func ExampleDatabase_Dialect() { db := getDb() diff --git a/database_test.go b/database_test.go index 31152363..036e387e 100644 --- a/database_test.go +++ b/database_test.go @@ -267,6 +267,65 @@ func (dt *databaseTest) TestBeginTx() { assert.EqualError(t, err, "goqu: transaction error") } +func (dt *databaseTest) TestWithTx() { + t := dt.T() + mDb, mock, err := sqlmock.New() + assert.NoError(t, err) + + db := newDatabase("mock", mDb) + + cases := []struct { + expectf func(sqlmock.Sqlmock) + f func(*TxDatabase) error + wantErr bool + errStr string + }{ + { + expectf: func(mock sqlmock.Sqlmock) { + mock.ExpectBegin() + mock.ExpectCommit() + }, + f: func(_ *TxDatabase) error { return nil }, + wantErr: false, + }, + { + expectf: func(mock sqlmock.Sqlmock) { + mock.ExpectBegin().WillReturnError(errors.New("transaction begin error")) + }, + f: func(_ *TxDatabase) error { return nil }, + wantErr: true, + errStr: "goqu: transaction begin error", + }, + { + expectf: func(mock sqlmock.Sqlmock) { + mock.ExpectBegin() + mock.ExpectRollback() + }, + f: func(_ *TxDatabase) error { return errors.New("transaction error") }, + wantErr: true, + errStr: "goqu: transaction error", + }, + { + expectf: func(mock sqlmock.Sqlmock) { + mock.ExpectBegin() + mock.ExpectRollback().WillReturnError(errors.New("transaction rollback error")) + }, + f: func(_ *TxDatabase) error { return errors.New("something wrong") }, + wantErr: true, + errStr: "goqu: transaction rollback error", + }, + } + for _, c := range cases { + c.expectf(mock) + err := db.WithTx(c.f) + if c.wantErr { + assert.EqualError(t, err, c.errStr) + } else { + assert.NoError(t, err) + } + } +} + func TestDatabaseSuite(t *testing.T) { suite.Run(t, new(databaseTest)) }