Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

introduce RPC interceptor mechanism #389

Merged
merged 19 commits into from
Dec 18, 2021
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 68 additions & 0 deletions integration_tests/interceptor_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
// Copyright 2021 TiKV Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package tikv_test

import (
"context"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/tikv/client-go/v2/tikv"
"github.com/tikv/client-go/v2/tikvrpc"
)

func TestInterceptor(t *testing.T) {
store := NewTestStore(t)
defer func() {
assert.NoError(t, store.Close())
}()
store.SetTiKVClient(&mockRPCClient{store.GetTiKVClient()})
manager := tikvrpc.MockInterceptorManager{}

txn, err := store.Begin()
txn.SetInterceptor(manager.CreateMockInterceptor())
assert.NoError(t, err)
err = txn.Set([]byte("KEY-1"), []byte("VALUE-1"))
assert.NoError(t, err)
err = txn.Commit(context.Background())
assert.NoError(t, err)
assert.Equal(t, 2, manager.BeginCount())
assert.Equal(t, 2, manager.EndCount())
manager.Reset()

txn, err = store.Begin()
txn.SetInterceptor(manager.CreateMockInterceptor())
assert.NoError(t, err)
value, err := txn.Get(context.Background(), []byte("KEY-1"))
assert.NoError(t, err)
assert.Equal(t, []byte("VALUE-1"), value)
assert.Equal(t, 1, manager.BeginCount())
assert.Equal(t, 1, manager.EndCount())
manager.Reset()
}

type mockRPCClient struct {
tikv.Client
}

func (c *mockRPCClient) SendRequest(ctx context.Context, addr string, req *tikvrpc.Request, timeout time.Duration) (*tikvrpc.Response, error) {
if interceptor := tikvrpc.GetInterceptorFromCtx(ctx); interceptor != nil {
return interceptor(func(target string, req *tikvrpc.Request) (*tikvrpc.Response, error) {
return c.Client.SendRequest(ctx, addr, req, timeout)
})(addr, req)
}
return c.Client.SendRequest(ctx, addr, req, timeout)
}
11 changes: 11 additions & 0 deletions internal/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,18 @@ func (c *RPCClient) updateTiKVSendReqHistogram(req *tikvrpc.Request, start time.
}

// SendRequest sends a Request to server and receives Response.
// If tikvrpc.Interceptor has been set in ctx, it will be used to wrap RPC action.
func (c *RPCClient) SendRequest(ctx context.Context, addr string, req *tikvrpc.Request, timeout time.Duration) (*tikvrpc.Response, error) {
if interceptor := tikvrpc.GetInterceptorFromCtx(ctx); interceptor != nil {
return interceptor(func(target string, req *tikvrpc.Request) (*tikvrpc.Response, error) {
return c.sendRequest(ctx, target, req, timeout)
})(addr, req)
}
return c.sendRequest(ctx, addr, req, timeout)
}

// sendRequest sends a Request to server and receives Response.
func (c *RPCClient) sendRequest(ctx context.Context, addr string, req *tikvrpc.Request, timeout time.Duration) (*tikvrpc.Response, error) {
if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil {
span1 := span.Tracer().StartSpan(fmt.Sprintf("rpcClient.SendRequest, region ID: %d, type: %s", req.RegionId, req.Type), opentracing.ChildOf(span.Context()))
defer span1.Finish()
Expand Down
187 changes: 187 additions & 0 deletions tikvrpc/interceptor.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
// Copyright 2021 TiKV Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package tikvrpc
mornyx marked this conversation as resolved.
Show resolved Hide resolved

import (
"context"
"sync/atomic"
)

// Interceptor is used to decorate the RPC requests to TiKV.
//
// The definition of an interceptor is: Given an InterceptorFunc, we will
// get the decorated InterceptorFunc with additional logic before and after
// the execution of the given InterceptorFunc.
//
// The decorated InterceptorFunc will be executed before and after the real
// RPC request is initiated to TiKV.
//
// We can implement an Interceptor like this:
// ```
// func LogInterceptor(next InterceptorFunc) InterceptorFunc {
// return func(target string, req *Request) (*Response, error) {
// log.Println("before")
// resp, err := next(target, req)
// log.Println("after")
// return resp, err
// }
// }
// txn.SetInterceptor(LogInterceptor)
// ```
//
// Or you want to inject some dependent modules:
// ```
// func GetLogInterceptor(lg *log.Logger) Interceptor {
// return func(next InterceptorFunc) InterceptorFunc {
// return func(target string, req *Request) (*Response, error) {
// lg.Println("before")
// resp, err := next(target, req)
// lg.Println("after")
// return resp, err
// }
// }
// }
// txn.SetInterceptor(GetLogInterceptor())
// ```
type Interceptor func(next InterceptorFunc) InterceptorFunc
mornyx marked this conversation as resolved.
Show resolved Hide resolved

// InterceptorFunc is a callable function used to initiate a request to TiKV.
// It is mainly used as the parameter and return value of Interceptor.
type InterceptorFunc func(target string, req *Request) (*Response, error)

// InterceptorChain is used to combine multiple Interceptors into one.
// Multiple interceptors will be executed in the order of link time, but are more
// similar to the onion model: The earlier the interceptor is executed, the later
// it will return.
//
// We can use InterceptorChain like this:
// ```
// func Interceptor1(next InterceptorFunc) InterceptorFunc {
// return func(target string, req *Request) (*Response, error) {
// fmt.Println("begin-interceptor-1")
// defer fmt.Println("end-interceptor-1")
// return next(target, req)
// }
// }
// func Interceptor2(next InterceptorFunc) InterceptorFunc {
// return func(target string, req *Request) (*Response, error) {
// fmt.Println("begin-interceptor-2")
// defer fmt.Println("end-interceptor-2")
// return next(target, req)
// }
// }
// txn.SetInterceptor(NewInterceptorChain().Link(Interceptor1).Link(Interceptor2).Build())
// ```
//
// Then every time an RPC request is initiated, the following text will be printed:
// ```
// begin-interceptor-1
// begin-interceptor-2
// /* do request & respond here */
// end-interceptor-2
// end-interceptor-1
// ```
type InterceptorChain struct {
chain []Interceptor
}

// NewInterceptorChain creates an empty InterceptorChain.
func NewInterceptorChain() *InterceptorChain {
return &InterceptorChain{}
}

// Link is used to link the next Interceptor.
// Multiple interceptors will be executed in the order of link time.
func (c *InterceptorChain) Link(it Interceptor) *InterceptorChain {
c.chain = append(c.chain, it)
return c
}

// Build merges the previously linked interceptors into one.
func (c *InterceptorChain) Build() Interceptor {
return func(next InterceptorFunc) InterceptorFunc {
for n := len(c.chain) - 1; n >= 0; n-- {
next = c.chain[n](next)
}
return next
}
}
mornyx marked this conversation as resolved.
Show resolved Hide resolved

type interceptorCtxKeyType struct{}

var interceptorCtxKey = interceptorCtxKeyType{}

// SetInterceptorIntoCtx is a helper function used to bind Interceptor into ctx.
// Different from the behavior of calling context.WithValue() directly, calling
// SetInterceptorIntoCtx multiple times will not bind multiple Interceptors, but
// will replace the original value each time.
// Be careful not to forget to use the returned ctx.
func SetInterceptorIntoCtx(ctx context.Context, interceptor Interceptor) context.Context {
mornyx marked this conversation as resolved.
Show resolved Hide resolved
if v := ctx.Value(interceptorCtxKey); v != nil {
v.(*atomic.Value).Store(interceptor)
return ctx
}
v := new(atomic.Value)
v.Store(interceptor)
return context.WithValue(ctx, interceptorCtxKey, v)
}

// GetInterceptorFromCtx gets the Interceptor bound by the previous call to SetInterceptorIntoCtx,
// and returns nil if there is none.
func GetInterceptorFromCtx(ctx context.Context) Interceptor {
if v := ctx.Value(interceptorCtxKey); v != nil {
v := v.(*atomic.Value).Load()
if interceptor, ok := v.(Interceptor); ok && interceptor != nil {
return interceptor
}
}
return nil
}

/* Suite for testing */

// MockInterceptorManager can be used to create Interceptor and record the
// number of executions of the created Interceptor.
type MockInterceptorManager struct {
mornyx marked this conversation as resolved.
Show resolved Hide resolved
begin int32
end int32
}

// CreateMockInterceptor creates an Interceptor for testing.
func (m *MockInterceptorManager) CreateMockInterceptor() Interceptor {
return func(next InterceptorFunc) InterceptorFunc {
return func(target string, req *Request) (*Response, error) {
atomic.AddInt32(&m.begin, 1)
defer atomic.AddInt32(&m.end, 1)
return next(target, req)
}
}
}

// Reset clear all counters.
func (m *MockInterceptorManager) Reset() {
atomic.StoreInt32(&m.begin, 0)
atomic.StoreInt32(&m.end, 0)
}

// BeginCount gets how many times the previously created Interceptor has been executed.
func (m *MockInterceptorManager) BeginCount() int {
return int(atomic.LoadInt32(&m.begin))
}

// EndCount gets how many times the previously created Interceptor has been returned.
func (m *MockInterceptorManager) EndCount() int {
return int(atomic.LoadInt32(&m.end))
}
59 changes: 59 additions & 0 deletions tikvrpc/interceptor_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
// Copyright 2021 TiKV Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package tikvrpc

import (
"context"
"fmt"
"testing"

"github.com/stretchr/testify/assert"
)

func TestInterceptorChain(t *testing.T) {
chain := NewInterceptorChain()
manager := MockInterceptorManager{}
it := chain.
Link(manager.CreateMockInterceptor()).
Link(manager.CreateMockInterceptor()).
Build()
_, _ = it(func(target string, req *Request) (*Response, error) {
return nil, nil
})("", nil)
assert.Equal(t, 2, manager.BeginCount())
assert.Equal(t, 2, manager.EndCount())
}

func TestGetAndSetInterceptorCtx(t *testing.T) {
ctx := context.Background()
assert.Nil(t, GetInterceptorFromCtx(ctx))
var it1 Interceptor = func(next InterceptorFunc) InterceptorFunc {
return next
}
ctx = SetInterceptorIntoCtx(ctx, it1)
it2 := GetInterceptorFromCtx(ctx)
assert.Equal(t, funcKey(it1), funcKey(it2))
var it3 Interceptor = func(next InterceptorFunc) InterceptorFunc {
return next
}
assert.NotEqual(t, funcKey(it1), funcKey(it3))
ctx = SetInterceptorIntoCtx(ctx, it3)
it4 := GetInterceptorFromCtx(ctx)
assert.Equal(t, funcKey(it3), funcKey(it4))
}

func funcKey(v interface{}) string {
return fmt.Sprintf("%v", v)
}
28 changes: 28 additions & 0 deletions txnkv/transaction/txn.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,8 @@ type KVTxn struct {
resourceGroupTag []byte
resourceGroupTagger tikvrpc.ResourceGroupTagger // use this when resourceGroupTag is nil
diskFullOpt kvrpcpb.DiskFullOpt
// interceptor is used to decorate the RPC request logic related to the txn.
interceptor tikvrpc.Interceptor
}

// NewTiKVTxn creates a new KVTxn.
Expand Down Expand Up @@ -242,6 +244,13 @@ func (txn *KVTxn) SetResourceGroupTagger(tagger tikvrpc.ResourceGroupTagger) {
txn.GetSnapshot().SetResourceGroupTagger(tagger)
}

// SetInterceptor sets tikvrpc.Interceptor for the transaction and its related snapshot.
// tikvrpc.Interceptor will be executed before each RPC request is initiated.
func (txn *KVTxn) SetInterceptor(interceptor tikvrpc.Interceptor) {
txn.interceptor = interceptor
txn.GetSnapshot().SetInterceptor(interceptor)
}

// SetSchemaAmender sets an amender to update mutations after schema change.
func (txn *KVTxn) SetSchemaAmender(sa SchemaAmender) {
txn.schemaAmender = sa
Expand Down Expand Up @@ -343,6 +352,13 @@ func (txn *KVTxn) Commit(ctx context.Context) error {
sessionID = val.(uint64)
}

if txn.interceptor != nil {
// User has called txn.SetInterceptor() to explicitly set an interceptor, we
// need to bind it to ctx so that the internal client can perceive and execute
// it before initiating an RPC request.
ctx = tikvrpc.SetInterceptorIntoCtx(ctx, txn.interceptor)
mornyx marked this conversation as resolved.
Show resolved Hide resolved
}

var err error
// If the txn use pessimistic lock, committer is initialized.
committer := txn.committer
Expand Down Expand Up @@ -450,6 +466,12 @@ func (txn *KVTxn) rollbackPessimisticLocks() error {
return nil
}
bo := retry.NewBackofferWithVars(context.Background(), cleanupMaxBackoff, txn.vars)
if txn.interceptor != nil {
// User has called txn.SetInterceptor() to explicitly set an interceptor, we
// need to bind it to ctx so that the internal client can perceive and execute
// it before initiating an RPC request.
bo.SetCtx(tikvrpc.SetInterceptorIntoCtx(bo.GetCtx(), txn.interceptor))
}
keys := txn.collectLockedKeys()
return txn.committer.pessimisticRollbackMutations(bo, &PlainMutations{keys: keys})
}
Expand Down Expand Up @@ -526,6 +548,12 @@ func (txn *KVTxn) LockKeysWithWaitTime(ctx context.Context, lockWaitTime int64,
// LockKeys tries to lock the entries with the keys in KV store.
// lockCtx is the context for lock, lockCtx.lockWaitTime in ms
func (txn *KVTxn) LockKeys(ctx context.Context, lockCtx *tikv.LockCtx, keysInput ...[]byte) error {
if txn.interceptor != nil {
// User has called txn.SetInterceptor() to explicitly set an interceptor, we
// need to bind it to ctx so that the internal client can perceive and execute
// it before initiating an RPC request.
ctx = tikvrpc.SetInterceptorIntoCtx(ctx, txn.interceptor)
}
mornyx marked this conversation as resolved.
Show resolved Hide resolved
// Exclude keys that are already locked.
var err error
keys := make([][]byte, 0, len(keysInput))
Expand Down
Loading