Skip to content

Commit

Permalink
[#88] fix default ContextOperator: change B to comparable
Browse files Browse the repository at this point in the history
  • Loading branch information
kozmod committed Jan 23, 2024
1 parent d903aba commit 50b9ec4
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 43 deletions.
14 changes: 8 additions & 6 deletions operator.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,26 @@ import (
)

// ContextOperator inject and extract Tx from context.Context.
type ContextOperator[B any, T Tx] struct {
beginner *B
//
// Default ContextOperator uses comparable key for context.Context value operation.
type ContextOperator[B comparable, T Tx] struct {
key B
}

// NewContextOperator returns new ContextOperator.
func NewContextOperator[B any, T Tx](b *B) *ContextOperator[B, T] {
func NewContextOperator[B comparable, T Tx](key B) *ContextOperator[B, T] {
return &ContextOperator[B, T]{
beginner: b,
key: key,
}
}

// Inject returns new context.Context contains Tx as value.
func (p *ContextOperator[B, T]) Inject(ctx context.Context, tx T) context.Context {
return context.WithValue(ctx, p.beginner, tx)
return context.WithValue(ctx, p.key, tx)
}

// Extract returns Tx extracted from context.Context.
func (p *ContextOperator[B, T]) Extract(ctx context.Context) (T, bool) {
c, ok := ctx.Value(p.beginner).(T)
c, ok := ctx.Value(p.key).(T)
return c, ok
}
6 changes: 3 additions & 3 deletions stdlib/transactor.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ type dbWrapper struct {
}

// BeginTx starts a transaction.
func (db *dbWrapper) BeginTx(ctx context.Context, opts ...oniontx.Option[*sql.TxOptions]) (*txWrapper, error) {
func (db dbWrapper) BeginTx(ctx context.Context, opts ...oniontx.Option[*sql.TxOptions]) (*txWrapper, error) {
var txOptions sql.TxOptions
for _, opt := range opts {
opt.Apply(&txOptions)
Expand Down Expand Up @@ -58,12 +58,12 @@ type Transactor struct {
// NewTransactor returns new Transactor.
func NewTransactor(db *sql.DB) *Transactor {
var (
base = &dbWrapper{DB: db}
base = dbWrapper{DB: db}
operator = oniontx.NewContextOperator[*dbWrapper, *txWrapper](&base)
transactor = Transactor{
operator: operator,
transactor: oniontx.NewTransactor[*dbWrapper, *txWrapper, *sql.TxOptions](
base,
&base,
operator,
),
}
Expand Down
125 changes: 91 additions & 34 deletions transactor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,51 @@ import (

func Test_CtxOperator(t *testing.T) {
t.Run("success_extract_committer", func(t *testing.T) {
var (
ctx = context.Background()
c = committerMock{}
b = &beginnerMock[*committerMock, any]{}
o = NewContextOperator[*beginnerMock[*committerMock, any], *committerMock](&b)
)
ctx = o.Inject(ctx, &c)
extracted, ok := o.Extract(ctx)
assertTrue(t, ok)
assertTrue(t, extracted == &c)
t.Run("extract_pointer", func(t *testing.T) {
var (
ctx = context.Background()
c = committerMock{}
b = beginnerMock[*committerMock, any]{}
o = NewContextOperator[*beginnerMock[*committerMock, any], *committerMock](&b)
)
ctx = o.Inject(ctx, &c)
extracted, ok := o.Extract(ctx)
assertTrue(t, ok)
assertTrue(t, extracted == &c)
})
t.Run("extract_value", func(t *testing.T) {
var (
ctx = context.Background()
c = committerValueMock{
committer: &committerMock{},
}
b = beginnerValueMock[committerValueMock, any]{
beginner: &beginnerMock[committerValueMock, any]{},
}
o = NewContextOperator[beginnerValueMock[committerValueMock, any], committerValueMock](b)
)
ctx = o.Inject(ctx, c)
extracted, ok := o.Extract(ctx)
assertTrue(t, ok)
assertTrue(t, extracted == c)
})
t.Run("extract_nil_value", func(t *testing.T) {
var (
ctx = context.Background()
c = committerValueMock{
committer: nil,
}
b = beginnerValueMock[committerValueMock, any]{
beginner: nil,
}
o = NewContextOperator[beginnerValueMock[committerValueMock, any], committerValueMock](b)
)
ctx = o.Inject(ctx, c)
extracted, ok := o.Extract(ctx)
assertTrue(t, ok)
assertTrue(t, extracted == c)
})

})
}

Expand All @@ -36,15 +71,15 @@ func Test_Transactor(t *testing.T) {
return nil
},
}
b = &beginnerMock[*committerMock, any]{
b = beginnerMock[*committerMock, any]{
beginFn: func(ctx context.Context, opts ...Option[any]) (*committerMock, error) {
beginnerCalled = true
assertTrue(t, opts == nil)
return &c, nil
},
}
o = NewContextOperator[*beginnerMock[*committerMock, any], *committerMock](&b)
tr = NewTransactor[*beginnerMock[*committerMock, any], *committerMock, any](b, o)
tr = NewTransactor[*beginnerMock[*committerMock, any], *committerMock, any](&b, o)
)
err := tr.WithinTx(ctx, func(ctx context.Context) error {
tx, ok := tr.TryGetTx(ctx)
Expand All @@ -67,20 +102,20 @@ func Test_Transactor(t *testing.T) {
return nil
},
}
b = &beginnerMock[*committerMock, any]{
b = beginnerMock[*committerMock, any]{
beginFn: func(ctx context.Context, opts ...Option[any]) (*committerMock, error) {
beginnerCalled = true
assertTrue(t, opts == nil)
return &c, nil
},
}
o = NewContextOperator[*beginnerMock[*committerMock, any], *committerMock](&b)
tr = NewTransactor[*beginnerMock[*committerMock, any], *committerMock, any](b, o)
tr = NewTransactor[*beginnerMock[*committerMock, any], *committerMock, any](&b, o)
)
err := tr.WithinTx(ctx, func(ctx context.Context) error {
beginner := tr.TxBeginner()
assertTrue(t, beginner != nil)
assertTrue(t, b == beginner)
assertTrue(t, &b == beginner)
return nil
})
assertTrue(t, err == nil)
Expand All @@ -99,15 +134,15 @@ func Test_Transactor(t *testing.T) {
return nil
},
}
b = &beginnerMock[*committerMock, any]{
b = beginnerMock[*committerMock, any]{
beginFn: func(ctx context.Context, opts ...Option[any]) (*committerMock, error) {
beginnerCalled = true
assertTrue(t, opts == nil)
return &c, nil
},
}
o = NewContextOperator[*beginnerMock[*committerMock, any], *committerMock](&b)
tr = NewTransactor[*beginnerMock[*committerMock, any], *committerMock, any](b, o)
tr = NewTransactor[*beginnerMock[*committerMock, any], *committerMock, any](&b, o)
)
err := tr.WithinTx(ctx, func(ctx context.Context) error {
tx, ok := o.Extract(ctx)
Expand All @@ -129,9 +164,9 @@ func Test_Transactor(t *testing.T) {
return nil
},
}
b = &beginnerMock[*committerMock, any]{}
b = beginnerMock[*committerMock, any]{}
o = NewContextOperator[*beginnerMock[*committerMock, any], *committerMock](&b)
tr = NewTransactor[*beginnerMock[*committerMock, any], *committerMock, any](b, o)
tr = NewTransactor[*beginnerMock[*committerMock, any], *committerMock, any](&b, o)
)
ctx = o.Inject(ctx, &c)
err := tr.WithinTx(ctx, func(ctx context.Context) error {
Expand All @@ -155,14 +190,14 @@ func Test_Transactor(t *testing.T) {
return expError
},
}
b = &beginnerMock[*committerMock, any]{
b = beginnerMock[*committerMock, any]{
beginFn: func(ctx context.Context, opts ...Option[any]) (*committerMock, error) {
assertTrue(t, opts == nil)
return &c, nil
},
}
o = NewContextOperator[*beginnerMock[*committerMock, any], *committerMock](&b)
tr = NewTransactor[*beginnerMock[*committerMock, any], *committerMock, any](b, o)
tr = NewTransactor[*beginnerMock[*committerMock, any], *committerMock, any](&b, o)
)
err := tr.WithinTx(ctx, func(ctx context.Context) error {
tx, ok := o.Extract(ctx)
Expand All @@ -187,15 +222,15 @@ func Test_Transactor(t *testing.T) {
return nil
},
}
b = &beginnerMock[*committerMock, any]{
b = beginnerMock[*committerMock, any]{
beginFn: func(ctx context.Context, opts ...Option[any]) (*committerMock, error) {
beginCalled = true
assertTrue(t, opts == nil)
return &c, nil
},
}
o = NewContextOperator[*beginnerMock[*committerMock, any], *committerMock](&b)
tr = NewTransactor[*beginnerMock[*committerMock, any], *committerMock, any](b, o)
tr = NewTransactor[*beginnerMock[*committerMock, any], *committerMock, any](&b, o)
)
err := tr.WithinTx(ctx, func(ctx context.Context) error {
tx, ok := o.Extract(ctx)
Expand All @@ -222,15 +257,15 @@ func Test_Transactor(t *testing.T) {
return rollbackErr
},
}
b = &beginnerMock[*committerMock, any]{
b = beginnerMock[*committerMock, any]{
beginFn: func(ctx context.Context, opts ...Option[any]) (*committerMock, error) {
beginCalled = true
assertTrue(t, opts == nil)
return &c, nil
},
}
o = NewContextOperator[*beginnerMock[*committerMock, any], *committerMock](&b)
tr = NewTransactor[*beginnerMock[*committerMock, any], *committerMock, any](b, o)
tr = NewTransactor[*beginnerMock[*committerMock, any], *committerMock, any](&b, o)
)
err := tr.WithinTx(ctx, func(ctx context.Context) error {
tx, ok := o.Extract(ctx)
Expand All @@ -256,15 +291,15 @@ func Test_Transactor(t *testing.T) {
return nil
},
}
b = &beginnerMock[*committerMock, any]{
b = beginnerMock[*committerMock, any]{
beginFn: func(ctx context.Context, opts ...Option[any]) (*committerMock, error) {
beginCalled = true
assertTrue(t, opts == nil)
return &c, nil
},
}
o = NewContextOperator[*beginnerMock[*committerMock, any], *committerMock](&b)
tr = NewTransactor[*beginnerMock[*committerMock, any], *committerMock, any](b, o)
tr = NewTransactor[*beginnerMock[*committerMock, any], *committerMock, any](&b, o)
)
err := tr.WithinTx(ctx, func(ctx context.Context) error {
tx, ok := o.Extract(ctx)
Expand Down Expand Up @@ -300,7 +335,7 @@ func Test_Transactor(t *testing.T) {
return &c, nil
},
}
o = NewContextOperator[*beginnerMock[*committerMock, any], *committerMock](&b)
o = NewContextOperator[*beginnerMock[*committerMock, any], *committerMock](b)
tr = NewTransactor[*beginnerMock[*committerMock, any], *committerMock, any](b, o)
)
err := tr.WithinTx(ctx, func(ctx context.Context) error {
Expand All @@ -327,7 +362,7 @@ func Test_Transactor(t *testing.T) {
return nil, expError
},
}
o = NewContextOperator[*beginnerMock[*committerMock, any], *committerMock](&b)
o = NewContextOperator[*beginnerMock[*committerMock, any], *committerMock](b)
tr = NewTransactor[*beginnerMock[*committerMock, any], *committerMock, any](b, o)
)
err := tr.WithinTx(ctx, func(ctx context.Context) error {
Expand Down Expand Up @@ -435,7 +470,7 @@ func Test_Transactor_recursive_call(t *testing.T) {
return &c, nil
},
}
o = NewContextOperator[*beginnerMock[*committerMock, any], *committerMock](&b)
o = NewContextOperator[*beginnerMock[*committerMock, any], *committerMock](b)
tr = NewTransactor[*beginnerMock[*committerMock, any], *committerMock, any](b, o)
)
err := tr.WithinTx(ctx, func(ctx context.Context) error {
Expand Down Expand Up @@ -472,7 +507,7 @@ func Test_Transactor_recursive_call(t *testing.T) {
return &c, nil
},
}
o = NewContextOperator[*beginnerMock[*committerMock, any], *committerMock](&b)
o = NewContextOperator[*beginnerMock[*committerMock, any], *committerMock](b)
tr = NewTransactor[*beginnerMock[*committerMock, any], *committerMock, any](b, o)
)

Expand Down Expand Up @@ -535,7 +570,7 @@ func Test_Transactor_recursive_call(t *testing.T) {
return &c, nil
},
}
o = NewContextOperator[*beginnerMock[*committerMock, any], *committerMock](&b)
o = NewContextOperator[*beginnerMock[*committerMock, any], *committerMock](b)
tr = NewTransactor[*beginnerMock[*committerMock, any], *committerMock, any](b, o)
)

Expand Down Expand Up @@ -600,7 +635,7 @@ func Test_Transactor_recursive_call(t *testing.T) {
return &c, nil
},
}
o = NewContextOperator[*beginnerMock[*committerMock, any], *committerMock](&b)
o = NewContextOperator[*beginnerMock[*committerMock, any], *committerMock](b)
tr = NewTransactor[*beginnerMock[*committerMock, any], *committerMock, any](b, o)
)

Expand Down Expand Up @@ -667,7 +702,7 @@ func Test_Transactor_recursive_call(t *testing.T) {
return &c, nil
},
}
o = NewContextOperator[*beginnerMock[*committerMock, any], *committerMock](&b)
o = NewContextOperator[*beginnerMock[*committerMock, any], *committerMock](b)
tr = NewTransactor[*beginnerMock[*committerMock, any], *committerMock, any](b, o)
)

Expand Down Expand Up @@ -735,6 +770,28 @@ func (c *committerMock) Rollback(ctx context.Context) error {
return c.rollbackFn(ctx)
}

// beginnerValueMock was added to avoid to use external dependencies for mocking
type beginnerValueMock[T Tx, O any] struct {
beginner *beginnerMock[T, O]
}

func (b beginnerValueMock[T, O]) BeginTx(ctx context.Context, opts ...Option[O]) (T, error) {
return b.beginner.beginFn(ctx, opts...)
}

// committerValueMock was added to avoid to use external dependencies for mocking
type committerValueMock struct {
committer *committerMock
}

func (c committerValueMock) Commit(ctx context.Context) error {
return c.committer.commitFn(ctx)
}

func (c committerValueMock) Rollback(ctx context.Context) error {
return c.committer.commitFn(ctx)
}

// assertTrue was added to avoid to use external dependencies for mocking
func assertTrue(t *testing.T, val bool) {
t.Helper()
Expand Down

0 comments on commit 50b9ec4

Please sign in to comment.