From 248372ec6ecb53a2fb7d3809628386bc7b36ee47 Mon Sep 17 00:00:00 2001 From: Masudur Rahman Date: Thu, 4 Apr 2024 02:26:20 +0600 Subject: [PATCH] Add UnitOfWork for all databases Signed-off-by: Masudur Rahman --- sql/postgres/postgres_test.go | 17 ++++++++++++ uow.go | 51 +++++++++++++++++++++++++++++++++++ 2 files changed, 68 insertions(+) create mode 100644 uow.go diff --git a/sql/postgres/postgres_test.go b/sql/postgres/postgres_test.go index fc508d6..d9d7601 100644 --- a/sql/postgres/postgres_test.go +++ b/sql/postgres/postgres_test.go @@ -51,6 +51,13 @@ func TestPostgres_FindOne(t *testing.T) { db, closer := initializeDB(t) defer closer() + db, err := db.BeginTx() + assert.Nil(t, err) + defer func() { + err = db.Commit() + assert.Nil(t, err) + }() + user := TestUser{} db = db.Table("test_user") @@ -93,10 +100,13 @@ func TestPostgres_FindMany(t *testing.T) { func TestPostgres_InsertOne(t *testing.T) { db, closer := initializeDB(t) defer closer() + db, err := db.BeginTx() + assert.Nil(t, err) db = db.Table("test_user") t.Run("insert data", func(t *testing.T) { suffix := xid.New().String() + //suffix := "hello" user := TestUser{ Name: "test-" + suffix, FullName: "Test Name", @@ -105,6 +115,13 @@ func TestPostgres_InsertOne(t *testing.T) { id, err := db.InsertOne(&user) assert.Nil(t, err) assert.NotEqual(t, 0, id) + if err != nil { + err = db.Rollback() + assert.Nil(t, err) + } + + err = db.Commit() + assert.Nil(t, err) }) } diff --git a/uow.go b/uow.go new file mode 100644 index 0000000..679f886 --- /dev/null +++ b/uow.go @@ -0,0 +1,51 @@ +package database + +import ( + "github.com/masudur-rahman/database/nosql" + "github.com/masudur-rahman/database/sql" +) + +// UnitOfWork represents the unit of work for coordinating transactions +type UnitOfWork struct { + SQL sql.Database + NoSQL nosql.Database +} + +// Begin starts a new transaction +func (uow UnitOfWork) Begin() (UnitOfWork, error) { + cp := UnitOfWork{ + SQL: uow.SQL, + NoSQL: uow.NoSQL, + } + if uow.SQL != nil { + sqlTx, err := uow.SQL.BeginTx() + if err != nil { + return UnitOfWork{}, err + } + cp.SQL = sqlTx + } + // For NoSQL databases, no action needed for beginning a transaction + return cp, nil +} + +// Commit commits the transaction +func (uow UnitOfWork) Commit() error { + if uow.SQL != nil { + if err := uow.SQL.Commit(); err != nil { + return err + } + } + // For NoSQL databases, no action needed for committing a transaction + return nil +} + +// Rollback rolls back the transaction +func (uow UnitOfWork) Rollback() error { + if uow.SQL != nil { + if err := uow.SQL.Rollback(); err != nil { + return err + } + } + // For NoSQL databases, no action needed for rolling back a transaction + return nil +}