From 2a5eede84570a91de094372aafcafd1e2fb13cea Mon Sep 17 00:00:00 2001 From: Rianov Viacheslav Date: Fri, 15 Jul 2022 10:32:13 +0300 Subject: [PATCH] api: proposal to add the context support This patch adds the support of using context in API. The proposed API is based on using request objects. Added tests that cover almost all cases of using the context in a query. Added benchamrk tests are equivalent to other, that use the same query but without any context. Closes #48 --- connection.go | 174 ++++++++++++++++++++++++++--------- future_canceler.go | 5 + prepared.go | 19 ++++ request.go | 69 ++++++++++++++ tarantool_test.go | 144 +++++++++++++++++++++++++++++ test_helpers/request_mock.go | 6 ++ 6 files changed, 374 insertions(+), 43 deletions(-) create mode 100644 future_canceler.go diff --git a/connection.go b/connection.go index 6de1e9d01..468505028 100644 --- a/connection.go +++ b/connection.go @@ -5,6 +5,7 @@ package tarantool import ( "bufio" "bytes" + "context" "errors" "fmt" "io" @@ -139,20 +140,63 @@ type Connection struct { state uint32 dec *msgpack.Decoder lenbuf [PacketLengthBytes]byte + + futCanceler futureCanceler } var _ = Connector(&Connection{}) // Check compatibility with connector interface. +type futureList struct { + first *Future + last **Future +} + +func (list *futureList) findFuture(reqid uint32, fetch bool) *Future { + root := &list.first + for { + fut := *root + if fut == nil { + return nil + } + if fut.requestId == reqid { + if fetch { + *root = fut.next + if fut.next == nil { + list.last = root + } else { + fut.next = nil + } + } + return fut + } + root = &fut.next + } +} + +func (list *futureList) addFuture(fut *Future) { + *list.last = fut + list.last = &fut.next +} + +func (list *futureList) clear(err error, conn *Connection) { + fut := list.first + list.first = nil + list.last = &list.first + for fut != nil { + fut.SetError(err) + conn.markDone(fut) + fut, fut.next = fut.next, nil + } +} + type connShard struct { - rmut sync.Mutex - requests [requestsMap]struct { - first *Future - last **Future - } - bufmut sync.Mutex - buf smallWBuf - enc *msgpack.Encoder - _pad [16]uint64 //nolint: unused,structcheck + rmut sync.Mutex + requests [requestsMap]futureList + requestsWithCtx [requestsMap]futureList + bufmut sync.Mutex + buf smallWBuf + enc *msgpack.Encoder + _pad [16]uint64 //nolint: unused,structcheck } // Greeting is a message sent by Tarantool on connect. @@ -286,6 +330,9 @@ func Connect(addr string, opts Opts) (conn *Connection, err error) { for j := range shard.requests { shard.requests[j].last = &shard.requests[j].first } + for j := range shard.requests { + shard.requestsWithCtx[j].last = &shard.requestsWithCtx[j].first + } } if opts.RateLimit > 0 { @@ -334,6 +381,7 @@ func Connect(addr string, opts Opts) (conn *Connection, err error) { return nil, err } } + conn.futCanceler = &futCanceler{conn: conn} return conn, err } @@ -387,6 +435,20 @@ func (conn *Connection) Handle() interface{} { return conn.opts.Handle } +type futCanceler struct { + conn *Connection +} + +func (canceler *futCanceler) Cancel(fut *Future, err error) error { + if fut == nil { + return fmt.Errorf("passed nil future") + } + fut.SetError(err) + canceler.conn.fetchFuture(fut.requestId) + canceler.conn.markDone(fut) + return nil +} + func (conn *Connection) dial() (err error) { var connection net.Conn network := "tcp" @@ -582,14 +644,11 @@ func (conn *Connection) closeConnection(neterr error, forever bool) (err error) conn.shard[i].buf.Reset() requests := &conn.shard[i].requests for pos := range requests { - fut := requests[pos].first - requests[pos].first = nil - requests[pos].last = &requests[pos].first - for fut != nil { - fut.SetError(neterr) - conn.markDone(fut) - fut, fut.next = fut.next, nil - } + requests[pos].clear(neterr, conn) + } + requestsWithCtx := &conn.shard[i].requestsWithCtx + for pos := range requestsWithCtx { + requestsWithCtx[pos].clear(neterr, conn) } } return @@ -721,7 +780,7 @@ func (conn *Connection) reader(r *bufio.Reader, c net.Conn) { } } -func (conn *Connection) newFuture() (fut *Future) { +func (conn *Connection) newFuture(ctx context.Context) (fut *Future) { fut = NewFuture() if conn.rlimit != nil && conn.opts.RLimitAction == RLimitDrop { select { @@ -761,11 +820,20 @@ func (conn *Connection) newFuture() (fut *Future) { return } pos := (fut.requestId / conn.opts.Concurrency) & (requestsMap - 1) - pair := &shard.requests[pos] - *pair.last = fut - pair.last = &fut.next - if conn.opts.Timeout > 0 { - fut.timeout = time.Since(epoch) + conn.opts.Timeout + if ctx != nil { + select { + case <-ctx.Done(): + fut.SetError(fmt.Errorf("context is done")) + shard.rmut.Unlock() + return + default: + } + shard.requestsWithCtx[pos].addFuture(fut) + } else { + shard.requests[pos].addFuture(fut) + if conn.opts.Timeout > 0 { + fut.timeout = time.Since(epoch) + conn.opts.Timeout + } } shard.rmut.Unlock() if conn.rlimit != nil && conn.opts.RLimitAction == RLimitWait { @@ -786,11 +854,37 @@ func (conn *Connection) newFuture() (fut *Future) { } func (conn *Connection) send(req Request) *Future { - fut := conn.newFuture() + fut := conn.newFuture(req.Context()) if fut.ready == nil { return fut } + if req.Context() != nil { + select { + case <-req.Context().Done(): + conn.futCanceler.Cancel(fut, fmt.Errorf("context is done")) + return fut + default: + } + } conn.putFuture(fut, req) + if req.Context() != nil { + go func() { + select { + case <-fut.done: + default: + select { + case <-req.Context().Done(): + conn.futCanceler.Cancel(fut, fmt.Errorf("context is done")) + default: + select { + case <-fut.done: + case <-req.Context().Done(): + conn.futCanceler.Cancel(fut, fmt.Errorf("context is done")) + } + } + } + }() + } return fut } @@ -877,26 +971,11 @@ func (conn *Connection) fetchFuture(reqid uint32) (fut *Future) { func (conn *Connection) getFutureImp(reqid uint32, fetch bool) *Future { shard := &conn.shard[reqid&(conn.opts.Concurrency-1)] pos := (reqid / conn.opts.Concurrency) & (requestsMap - 1) - pair := &shard.requests[pos] - root := &pair.first - for { - fut := *root - if fut == nil { - return nil - } - if fut.requestId == reqid { - if fetch { - *root = fut.next - if fut.next == nil { - pair.last = root - } else { - fut.next = nil - } - } - return fut - } - root = &fut.next + fut := shard.requests[pos].findFuture(reqid, fetch) + if fut == nil { + fut = shard.requestsWithCtx[pos].findFuture(reqid, fetch) } + return fut } func (conn *Connection) timeouts() { @@ -1000,6 +1079,15 @@ func (conn *Connection) Do(req Request) *Future { return fut } } + if req.Context() != nil { + select { + case <-req.Context().Done(): + fut := NewFuture() + fut.SetError(fmt.Errorf("context is done")) + return fut + default: + } + } return conn.send(req) } diff --git a/future_canceler.go b/future_canceler.go new file mode 100644 index 000000000..fd5424bd5 --- /dev/null +++ b/future_canceler.go @@ -0,0 +1,5 @@ +package tarantool + +type futureCanceler interface { + Cancel(fut *Future, err error) error +} diff --git a/prepared.go b/prepared.go index 9508f0546..dd65272dc 100644 --- a/prepared.go +++ b/prepared.go @@ -1,6 +1,7 @@ package tarantool import ( + "context" "fmt" "gopkg.in/vmihailenco/msgpack.v2" @@ -58,6 +59,12 @@ func (req *PrepareRequest) Body(res SchemaResolver, enc *msgpack.Encoder) error return fillPrepare(enc, req.expr) } +// WithContext sets a passed context to the request. +func (req *PrepareRequest) WithContext(ctx context.Context) *PrepareRequest { + req.ctx = ctx + return req +} + // UnprepareRequest helps you to create an unprepare request object for // execution by a Connection. type UnprepareRequest struct { @@ -83,6 +90,12 @@ func (req *UnprepareRequest) Body(res SchemaResolver, enc *msgpack.Encoder) erro return fillUnprepare(enc, *req.stmt) } +// WithContext sets a passed context to the request. +func (req *UnprepareRequest) WithContext(ctx context.Context) *UnprepareRequest { + req.ctx = ctx + return req +} + // ExecutePreparedRequest helps you to create an execute prepared request // object for execution by a Connection. type ExecutePreparedRequest struct { @@ -117,6 +130,12 @@ func (req *ExecutePreparedRequest) Body(res SchemaResolver, enc *msgpack.Encoder return fillExecutePrepared(enc, *req.stmt, req.args) } +// WithContext sets a passed context to the request. +func (req *ExecutePreparedRequest) WithContext(ctx context.Context) *ExecutePreparedRequest { + req.ctx = ctx + return req +} + func fillPrepare(enc *msgpack.Encoder, expr string) error { enc.EncodeMapLen(1) enc.EncodeUint64(KeySQLText) diff --git a/request.go b/request.go index a83094145..0987bc2c7 100644 --- a/request.go +++ b/request.go @@ -1,6 +1,7 @@ package tarantool import ( + "context" "errors" "reflect" "strings" @@ -537,6 +538,8 @@ type Request interface { Code() int32 // Body fills an encoder with a request body. Body(resolver SchemaResolver, enc *msgpack.Encoder) error + // Context returns a context of the request. + Context() context.Context } // ConnectedRequest is an interface that provides the info about a Connection @@ -549,6 +552,7 @@ type ConnectedRequest interface { type baseRequest struct { requestCode int32 + ctx context.Context } // Code returns a IPROTO code for the request. @@ -556,6 +560,11 @@ func (req *baseRequest) Code() int32 { return req.requestCode } +// Context returns a context of the request. +func (req *baseRequest) Context() context.Context { + return req.ctx +} + type spaceRequest struct { baseRequest space interface{} @@ -613,6 +622,12 @@ func (req *PingRequest) Body(res SchemaResolver, enc *msgpack.Encoder) error { return fillPing(enc) } +// WithContext sets a passed context to the request. +func (req *PingRequest) WithContext(ctx context.Context) *PingRequest { + req.ctx = ctx + return req +} + // SelectRequest allows you to create a select request object for execution // by a Connection. type SelectRequest struct { @@ -683,6 +698,12 @@ func (req *SelectRequest) Body(res SchemaResolver, enc *msgpack.Encoder) error { return fillSelect(enc, spaceNo, indexNo, req.offset, req.limit, req.iterator, req.key) } +// WithContext sets a passed context to the request. +func (req *SelectRequest) WithContext(ctx context.Context) *SelectRequest { + req.ctx = ctx + return req +} + // InsertRequest helps you to create an insert request object for execution // by a Connection. type InsertRequest struct { @@ -716,6 +737,12 @@ func (req *InsertRequest) Body(res SchemaResolver, enc *msgpack.Encoder) error { return fillInsert(enc, spaceNo, req.tuple) } +// WithContext sets a passed context to the request. +func (req *InsertRequest) WithContext(ctx context.Context) *InsertRequest { + req.ctx = ctx + return req +} + // ReplaceRequest helps you to create a replace request object for execution // by a Connection. type ReplaceRequest struct { @@ -749,6 +776,12 @@ func (req *ReplaceRequest) Body(res SchemaResolver, enc *msgpack.Encoder) error return fillInsert(enc, spaceNo, req.tuple) } +// WithContext sets a passed context to the request. +func (req *ReplaceRequest) WithContext(ctx context.Context) *ReplaceRequest { + req.ctx = ctx + return req +} + // DeleteRequest helps you to create a delete request object for execution // by a Connection. type DeleteRequest struct { @@ -789,6 +822,12 @@ func (req *DeleteRequest) Body(res SchemaResolver, enc *msgpack.Encoder) error { return fillDelete(enc, spaceNo, indexNo, req.key) } +// WithContext sets a passed context to the request. +func (req *DeleteRequest) WithContext(ctx context.Context) *DeleteRequest { + req.ctx = ctx + return req +} + // UpdateRequest helps you to create an update request object for execution // by a Connection. type UpdateRequest struct { @@ -840,6 +879,12 @@ func (req *UpdateRequest) Body(res SchemaResolver, enc *msgpack.Encoder) error { return fillUpdate(enc, spaceNo, indexNo, req.key, req.ops) } +// WithContext sets a passed context to the request. +func (req *UpdateRequest) WithContext(ctx context.Context) *UpdateRequest { + req.ctx = ctx + return req +} + // UpsertRequest helps you to create an upsert request object for execution // by a Connection. type UpsertRequest struct { @@ -884,6 +929,12 @@ func (req *UpsertRequest) Body(res SchemaResolver, enc *msgpack.Encoder) error { return fillUpsert(enc, spaceNo, req.tuple, req.ops) } +// WithContext sets a passed context to the request. +func (req *UpsertRequest) WithContext(ctx context.Context) *UpsertRequest { + req.ctx = ctx + return req +} + // CallRequest helps you to create a call request object for execution // by a Connection. type CallRequest struct { @@ -915,6 +966,12 @@ func (req *CallRequest) Body(res SchemaResolver, enc *msgpack.Encoder) error { return fillCall(enc, req.function, req.args) } +// WithContext sets a passed context to the request. +func (req *CallRequest) WithContext(ctx context.Context) *CallRequest { + req.ctx = ctx + return req +} + // NewCall16Request returns a new empty Call16Request. It uses request code for // Tarantool 1.6. // Deprecated since Tarantool 1.7.2. @@ -961,6 +1018,12 @@ func (req *EvalRequest) Body(res SchemaResolver, enc *msgpack.Encoder) error { return fillEval(enc, req.expr, req.args) } +// WithContext sets a passed context to the request. +func (req *EvalRequest) WithContext(ctx context.Context) *EvalRequest { + req.ctx = ctx + return req +} + // ExecuteRequest helps you to create an execute request object for execution // by a Connection. type ExecuteRequest struct { @@ -989,3 +1052,9 @@ func (req *ExecuteRequest) Args(args interface{}) *ExecuteRequest { func (req *ExecuteRequest) Body(res SchemaResolver, enc *msgpack.Encoder) error { return fillExecute(enc, req.expr, req.args) } + +// WithContext sets a passed context to the request. +func (req *ExecuteRequest) WithContext(ctx context.Context) *ExecuteRequest { + req.ctx = ctx + return req +} diff --git a/tarantool_test.go b/tarantool_test.go index 06771338c..0f63cf5b5 100644 --- a/tarantool_test.go +++ b/tarantool_test.go @@ -1,10 +1,12 @@ package tarantool_test import ( + "context" "fmt" "log" "os" "reflect" + "runtime" "strings" "sync" "testing" @@ -117,6 +119,36 @@ func BenchmarkClientSerialRequestObject(b *testing.B) { } } +func BenchmarkClientSerialRequestObjectWithContext(b *testing.B) { + var err error + + conn := test_helpers.ConnectWithValidation(b, server, opts) + defer conn.Close() + + _, err = conn.Replace(spaceNo, []interface{}{uint(1111), "hello", "world"}) + if err != nil { + b.Error(err) + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + req := NewSelectRequest(spaceNo). + Index(indexNo). + Offset(0). + Limit(1). + Iterator(IterEq). + Key([]interface{}{uint(1111)}). + WithContext(ctx) + _, err := conn.Do(req).Get() + if err != nil { + b.Error(err) + } + } +} + func BenchmarkClientSerialTyped(b *testing.B) { var err error @@ -342,6 +374,88 @@ func BenchmarkClientParallel(b *testing.B) { }) } +func benchmarkClientParallelRequestObject(multiplier int, b *testing.B) { + conn := test_helpers.ConnectWithValidation(b, server, opts) + defer conn.Close() + + _, err := conn.Replace(spaceNo, []interface{}{uint(1111), "hello", "world"}) + if err != nil { + b.Fatal("No connection available") + } + + req := NewSelectRequest(spaceNo). + Index(indexNo). + Offset(0). + Limit(1). + Iterator(IterEq). + Key([]interface{}{uint(1111)}) + + b.SetParallelism(multiplier) + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + _, err := conn.Do(req).Get() + if err != nil { + b.Error(err) + } + } + }) +} + +func benchmarkClientParallelRequestObjectWithContext(multiplier int, b *testing.B) { + conn := test_helpers.ConnectWithValidation(b, server, opts) + defer conn.Close() + + _, err := conn.Replace(spaceNo, []interface{}{uint(1111), "hello", "world"}) + if err != nil { + b.Fatal("No connection available") + } + + req := NewSelectRequest(spaceNo). + Index(indexNo). + Offset(0). + Limit(1). + Iterator(IterEq). + Key([]interface{}{uint(1111)}) + + b.SetParallelism(multiplier) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + req.WithContext(ctx) + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + _, err := conn.Do(req).Get() + if err != nil { + b.Error(err) + } + } + }) +} + +func BenchmarkClientParallelRequestObject(b *testing.B) { + multipliers := []int{10, 50, 500, 1000} + conn := test_helpers.ConnectWithValidation(b, server, opts) + defer conn.Close() + + _, err := conn.Replace(spaceNo, []interface{}{uint(1111), "hello", "world"}) + if err != nil { + b.Fatal("No connection available") + } + + for _, m := range multipliers { + goroutinesNum := runtime.GOMAXPROCS(0) * m + + b.Run(fmt.Sprintf("%d goroutines", goroutinesNum), func(b *testing.B) { + benchmarkClientParallelRequestObject(m, b) + }) + + b.Run(fmt.Sprintf("With Context %d goroutines", goroutinesNum), func(b *testing.B) { + benchmarkClientParallelRequestObjectWithContext(m, b) + }) + } +} + func BenchmarkClientParallelMassive(b *testing.B) { conn := test_helpers.ConnectWithValidation(b, server, opts) defer conn.Close() @@ -2081,6 +2195,36 @@ func TestClientRequestObjects(t *testing.T) { } } +func TestClientRequestObjectsWithContext(t *testing.T) { + var err error + + conn := test_helpers.ConnectWithValidation(t, server, opts) + defer conn.Close() + + ctx, cancel := context.WithCancel(context.Background()) + req := NewPingRequest().WithContext(ctx) + cancel() + resp, err := conn.Do(req).Get() + if err.Error() != "context is done" { + t.Fatalf("Failed to catch an error from done context") + } + if resp != nil { + t.Fatalf("Response is not nil after the occured error") + } + + req = NewPingRequest().WithContext(nil) //nolint + resp, err = conn.Do(req).Get() + if err != nil { + t.Fatalf("Failed to Ping: %s", err.Error()) + } + if resp == nil { + t.Fatalf("Response is nil after Ping") + } + if len(resp.Data) != 0 { + t.Errorf("Response Body len != 0") + } +} + func TestComplexStructs(t *testing.T) { var err error diff --git a/test_helpers/request_mock.go b/test_helpers/request_mock.go index 00674a3a7..99b55bb4e 100644 --- a/test_helpers/request_mock.go +++ b/test_helpers/request_mock.go @@ -1,6 +1,8 @@ package test_helpers import ( + "context" + "github.com/tarantool/go-tarantool" "gopkg.in/vmihailenco/msgpack.v2" ) @@ -23,3 +25,7 @@ func (sr *StrangerRequest) Body(resolver tarantool.SchemaResolver, enc *msgpack. func (sr *StrangerRequest) Conn() *tarantool.Connection { return &tarantool.Connection{} } + +func (sr *StrangerRequest) Context() context.Context { + return nil +}