Skip to content

Commit

Permalink
add ChainInterceptors
Browse files Browse the repository at this point in the history
  • Loading branch information
lqs committed Sep 6, 2024
1 parent 2512270 commit 65c4cf7
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 1 deletion.
23 changes: 22 additions & 1 deletion interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,25 @@ type InvokerFunc = func(ctx context.Context, sql string) error
// InterceptorFunc is the function type of an interceptor. An interceptor should implement this function to fulfill it's purpose.
type InterceptorFunc = func(ctx context.Context, sql string, invoker InvokerFunc) error

// TODO: add some common interceptors
func noopInterceptor(ctx context.Context, sql string, invoker InvokerFunc) error {
return invoker(ctx, sql)
}

// ChainInterceptors chains multiple interceptors into one interceptor.
func ChainInterceptors(interceptors ...InterceptorFunc) InterceptorFunc {
if len(interceptors) == 0 {
return noopInterceptor
}
return func(ctx context.Context, sql string, invoker InvokerFunc) error {
var chain func(int, context.Context, string) error
chain = func(i int, ctx context.Context, sql string) error {
if i == len(interceptors) {
return invoker(ctx, sql)
}
return interceptors[i](ctx, sql, func(ctx context.Context, sql string) error {
return chain(i+1, ctx, sql)
})
}
return chain(0, ctx, sql)
}
}
55 changes: 55 additions & 0 deletions interceptor_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package sqlingo

import (
"context"
"testing"
)

func TestChainInterceptors(t *testing.T) {
s := ""
i1 := func(ctx context.Context, sql string, invoker InvokerFunc) error {
s += "<i1>"
s += sql
defer func() {
s += "</i1>"
}()
return invoker(ctx, sql+"s1")
}
i2 := func(ctx context.Context, sql string, invoker InvokerFunc) error {
s += "<i2>"
s += sql
defer func() {
s += "</i2>"
}()
return invoker(ctx, sql+"s2")
}
chain := ChainInterceptors(i1, i2)
_ = chain(context.Background(), "sql", func(ctx context.Context, sql string) error {
s += "<invoker>"
s += sql
defer func() {
s += "</invoker>"
}()
return nil
})
if s != "<i1>sql<i2>sqls1<invoker>sqls1s2</invoker></i2></i1>" {
t.Error(s)
}
}

func TestEmptyChainInterceptors(t *testing.T) {
s := ""
chain := ChainInterceptors()
_ = chain(context.Background(), "sql", func(ctx context.Context, sql string) error {
s += "<invoker>"
defer func() {
s += "</invoker>"
}()
s += sql
return nil
})

if s != "<invoker>sql</invoker>" {
t.Error(s)
}
}

0 comments on commit 65c4cf7

Please sign in to comment.