From 09e1433c51ff6f1d16e184fb0bd3ba4dda6e032d Mon Sep 17 00:00:00 2001 From: Oleg Jukovec Date: Mon, 21 Nov 2022 19:00:47 +0300 Subject: [PATCH] api: add events subscription support MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit A user can create watcher by the Connection.NewWatcher() call: watcher = conn.NewWatcker("key", func(event WatchEvent) { // The callback code. }) After that, the watcher callback is invoked for the first time. In this case, the callback is triggered whether or not the key has already been broadcast. All subsequent invocations are triggered with box.broadcast() called on the remote host. If a watcher is subscribed for a key that has not been broadcast yet, the callback is triggered only once, after the registration of the watcher. If the key is updated while the watcher callback is running, the callback will be invoked again with the latest value as soon as it returns. Multiple watchers can be created for one key. If you don’t need the watcher anymore, you can unregister it using the Unregister method: watcher.Unregister() The api is similar to net.box implementation [1]. It also adds a BroadcastRequest to make it easier to send broadcast messages. 1. https://www.tarantool.io/en/doc/latest/reference/reference_lua/net_box/#conn-watch Closes #119 --- CHANGELOG.md | 2 + connection.go | 303 ++++++++++++++++++++- connection_pool/connection_pool.go | 309 +++++++++++++++++++--- connection_pool/connection_pool_test.go | 270 +++++++++++++++++++ connection_pool/connector.go | 8 + connection_pool/connector_test.go | 39 +++ connection_pool/pooler.go | 2 + connection_pool/round_robin.go | 24 +- connection_pool/round_robin_test.go | 19 ++ connector.go | 1 + const.go | 5 + multi/multi.go | 10 + multi/multi_test.go | 24 ++ request.go | 8 + request_test.go | 65 +++++ tarantool_test.go | 333 ++++++++++++++++++++++++ test_helpers/request_mock.go | 4 + test_helpers/utils.go | 17 +- watch.go | 138 ++++++++++ 19 files changed, 1521 insertions(+), 60 deletions(-) create mode 100644 watch.go diff --git a/CHANGELOG.md b/CHANGELOG.md index 4fb75664b..782dda025 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,8 @@ Versioning](http://semver.org/spec/v2.0.0.html) except to the first release. ### Added +- Event subscription support (#119) + ### Changed ### Fixed diff --git a/connection.go b/connection.go index a4ae8cc36..f21e0fc8b 100644 --- a/connection.go +++ b/connection.go @@ -53,6 +53,8 @@ const ( // LogUnexpectedResultId is logged when response with unknown id was received. // Most probably it is due to request timeout. LogUnexpectedResultId + // LogReadWatchEventFailed is logged when failed to read a watch event. + LogReadWatchEventFailed ) // ConnEvent is sent throw Notify channel specified in Opts. @@ -62,6 +64,12 @@ type ConnEvent struct { When time.Time } +// A raw watch event. +type connWatchEvent struct { + key string + value interface{} +} + var epoch = time.Now() // Logger is logger type expected to be passed in options. @@ -76,13 +84,16 @@ func (d defaultLogger) Report(event ConnLogKind, conn *Connection, v ...interfac case LogReconnectFailed: reconnects := v[0].(uint) err := v[1].(error) - log.Printf("tarantool: reconnect (%d/%d) to %s failed: %s\n", reconnects, conn.opts.MaxReconnects, conn.addr, err.Error()) + log.Printf("tarantool: reconnect (%d/%d) to %s failed: %s", reconnects, conn.opts.MaxReconnects, conn.addr, err.Error()) case LogLastReconnectFailed: err := v[0].(error) - log.Printf("tarantool: last reconnect to %s failed: %s, giving it up.\n", conn.addr, err.Error()) + log.Printf("tarantool: last reconnect to %s failed: %s, giving it up", conn.addr, err.Error()) case LogUnexpectedResultId: resp := v[0].(*Response) log.Printf("tarantool: connection %s got unexpected resultId (%d) in response", conn.addr, resp.RequestId) + case LogReadWatchEventFailed: + err := v[0].(error) + log.Printf("tarantool: unable to parse watch event: %s", err) default: args := append([]interface{}{"tarantool: unexpected event ", event, conn}, v...) log.Print(args...) @@ -146,6 +157,9 @@ type Connection struct { lenbuf [PacketLengthBytes]byte lastStreamId uint64 + + // watchMap is a map of key -> watchSharedData. + watchMap sync.Map } var _ = Connector(&Connection{}) // Check compatibility with connector interface. @@ -502,7 +516,7 @@ func (conn *Connection) dial() (err error) { conn.Greeting.Version = bytes.NewBuffer(greeting[:64]).String() conn.Greeting.auth = bytes.NewBuffer(greeting[64:108]).String() - // Auth + // Auth. if opts.User != "" { scr, err := scramble(conn.Greeting.auth, opts.Pass) if err != nil { @@ -520,7 +534,28 @@ func (conn *Connection) dial() (err error) { } } - // Only if connected and authenticated. + // Watchers. + conn.watchMap.Range(func(key, value interface{}) bool { + st := value.(chan watchState) + state := <-st + if state.cnt > 0 { + req := newWatchRequest(key.(string)) + if err = conn.writeRequest(w, req); err != nil { + st <- state + return false + } + state.init = true + state.ack = true + } + st <- state + return true + }) + + if err != nil { + return fmt.Errorf("unable to register watch: %w", err) + } + + // Only if connected and fully initialized. conn.lockShards() conn.c = connection atomic.StoreUint32(&conn.state, connConnected) @@ -581,23 +616,33 @@ func pack(h *smallWBuf, enc *encoder, reqid uint32, return } -func (conn *Connection) writeAuthRequest(w *bufio.Writer, scramble []byte) (err error) { +func (conn *Connection) writeRequest(w *bufio.Writer, req Request) (err error) { var packet smallWBuf - req := newAuthRequest(conn.opts.User, string(scramble)) err = pack(&packet, newEncoder(&packet), 0, req, ignoreStreamId, conn.Schema) if err != nil { - return errors.New("auth: pack error " + err.Error()) + return fmt.Errorf("pack error %w", err) } if err := write(w, packet.b); err != nil { - return errors.New("auth: write error " + err.Error()) + return fmt.Errorf("write error %w", err) } if err = w.Flush(); err != nil { - return errors.New("auth: flush error " + err.Error()) + return fmt.Errorf("flush error %w", err) } return } +func (conn *Connection) writeAuthRequest(w *bufio.Writer, scramble []byte) (err error) { + req := newAuthRequest(conn.opts.User, string(scramble)) + + err = conn.writeRequest(w, req) + if err != nil { + return fmt.Errorf("auth: %w", err) + } + + return nil +} + func (conn *Connection) readAuthResponse(r io.Reader) (err error) { respBytes, err := conn.read(r) if err != nil { @@ -774,7 +819,50 @@ func (conn *Connection) writer(w *bufio.Writer, c net.Conn) { } } +func readWatchEvent(reader io.Reader) (connWatchEvent, error) { + keyExist := false + event := connWatchEvent{} + d := newDecoder(reader) + + if l, err := d.DecodeMapLen(); err == nil { + for ; l > 0; l-- { + if cd, err := d.DecodeInt(); err == nil { + switch cd { + case KeyEvent: + if event.key, err = d.DecodeString(); err != nil { + return event, err + } + keyExist = true + case KeyEventData: + if event.value, err = d.DecodeInterface(); err != nil { + return event, err + } + default: + if err = d.Skip(); err != nil { + return event, err + } + } + } else { + return event, err + } + } + } else { + return event, err + } + + if !keyExist { + return event, errors.New("watch event does not have a key") + } + + return event, nil +} + func (conn *Connection) reader(r *bufio.Reader, c net.Conn) { + events := make(chan connWatchEvent, 1024) + defer close(events) + + go conn.eventer(events) + for atomic.LoadUint32(&conn.state) != connClosed { respBytes, err := conn.read(r) if err != nil { @@ -789,7 +877,14 @@ func (conn *Connection) reader(r *bufio.Reader, c net.Conn) { } var fut *Future = nil - if resp.Code == PushCode { + if resp.Code == EventCode { + if event, err := readWatchEvent(&resp.buf); err == nil { + events <- event + } else { + conn.opts.Logger.Report(LogReadWatchEventFailed, conn, err) + } + continue + } else if resp.Code == PushCode { if fut = conn.peekFuture(resp.RequestId); fut != nil { fut.AppendPush(resp) } @@ -799,12 +894,37 @@ func (conn *Connection) reader(r *bufio.Reader, c net.Conn) { conn.markDone(fut) } } + if fut == nil { conn.opts.Logger.Report(LogUnexpectedResultId, conn, resp) } } } +// eventer goroutine gets watch events and updates values for watchers. +func (conn *Connection) eventer(events <-chan connWatchEvent) { + for { + event, ok := <-events + if !ok { + // The channel is closed. + break + } + + if value, ok := conn.watchMap.Load(event.key); ok { + st := value.(chan watchState) + state := <-st + if state.changed != nil { + close(state.changed) + state.changed = nil + } + state.value = event.value + state.init = false + state.ack = false + st <- state + } + } +} + func (conn *Connection) newFuture(ctx context.Context) (fut *Future) { fut = NewFuture() if conn.rlimit != nil && conn.opts.RLimitAction == RLimitDrop { @@ -960,6 +1080,18 @@ func (conn *Connection) putFuture(fut *Future, req Request, streamId uint64) { return } shard.bufmut.Unlock() + + if req.Async() { + if fut = conn.fetchFuture(reqid); fut != nil { + resp := &Response{ + RequestId: reqid, + Code: OkCode, + } + fut.SetResponse(resp) + conn.markDone(fut) + } + } + if firstWritten { conn.dirtyShard <- shardn } @@ -1163,3 +1295,154 @@ func (conn *Connection) NewStream() (*Stream, error) { Conn: conn, }, nil } + +// watchState is the current state of the watcher. See the idea at p. 70, 105: +// https://drive.google.com/file/d/1nPdvhB0PutEJzdCq5ms6UI58dp50fcAN/view +type watchState struct { + // value is a current value. + value interface{} + // init is true if it is an initial state (no events reveived). + init bool + // ack true if the acknowledge is already sended. + ack bool + // cnt is a count of active watchers for the key. + cnt int32 + // changed is a channel for broadcast the value changes. + changed chan struct{} +} + +// connWatcher is an internal implementation of the Watcher interface. +type connWatcher struct { + unregister sync.Once + done chan struct{} + finished chan struct{} +} + +// Unregister unregisters the connection watcher. +func (w *connWatcher) Unregister() { + w.unregister.Do(func() { + close(w.done) + }) + <-w.finished +} + +// NewWatcher creates a new Watcher object for the connection. +// +// After watcher creation, the watcher callback is invoked for the first time. +// In this case, the callback is triggered whether or not the key has already +// been broadcast. All subsequent invocations are triggered with +// box.broadcast() called on the remote host. If a watcher is subscribed for a +// key that has not been broadcast yet, the callback is triggered only once, +// after the registration of the watcher. +// +// The watcher callbacks are always invoked in a separate goroutine. A watcher +// callback is never executed in parallel with itself, but they can be executed +// in parallel to other watchers. +// +// If the key is updated while the watcher callback is running, the callback +// will be invoked again with the latest value as soon as it returns. +// +// Watchers survive reconnection. All registered watchers are automatically +// resubscribed when the connection is reestablished. +// +// Keep in mind that garbage collection of a watcher handle doesn’t lead to the +// watcher’s destruction. In this case, the watcher remains registered. You +// need to call Unregister() directly. +// +// Unregister() guarantees that there will be no the watcher's callback calls +// after it, but Unregister() call from the callback leads to a deadlock. +// +// See: +// https://www.tarantool.io/en/doc/latest/reference/reference_lua/box_events/#box-watchers +// +// Since 1.10.0 +func (conn *Connection) NewWatcher(key string, callback WatchCallback) (Watcher, error) { + // TODO: check required features after: + // + // https://github.com/tarantool/go-tarantool/issues/120 + var st chan watchState + // Get or create a shared data for the key. + if val, ok := conn.watchMap.Load(key); !ok { + st = make(chan watchState, 1) + st <- watchState{ + value: nil, + init: true, + ack: false, + cnt: 0, + changed: nil, + } + + if val, ok := conn.watchMap.LoadOrStore(key, st); ok { + close(st) + st = val.(chan watchState) + } + } else { + st = val.(chan watchState) + } + + state := <-st + // Send an initial watch request if needed. + if state.cnt == 0 { + if _, err := conn.Do(newWatchRequest(key)).Get(); err != nil { + st <- state + return nil, err + } + } + state.cnt += 1 + st <- state + + // Start the watcher goroutine. + done := make(chan struct{}) + finished := make(chan struct{}) + + go func() { + for { + state := <-st + if state.changed == nil { + state.changed = make(chan struct{}) + } + st <- state + + if !state.init { + callback(WatchEvent{ + Conn: conn, + Key: key, + Value: state.value, + }) + + // Acknowledge the notification. + state = <-st + ack := state.ack + state.ack = true + st <- state + + if !ack { + conn.Do(newWatchRequest(key)).Get() + // We expect a reconnect and re-subscribe if it fails to + // send the watch request. So it looks ok do not check a + // result. + } + } + + select { + case <-done: + state := <-st + state.cnt -= 1 + if state.cnt == 0 { + // The last one sends IPROTO_UNWATCH. + conn.Do(newUnwatchRequest(key)).Get() + } + st <- state + + close(finished) + return + case <-state.changed: + } + } + }() + + return &connWatcher{ + done: done, + finished: finished, + }, nil +} diff --git a/connection_pool/connection_pool.go b/connection_pool/connection_pool.go index 6597e2dd0..81dc3c4ae 100644 --- a/connection_pool/connection_pool.go +++ b/connection_pool/connection_pool.go @@ -91,12 +91,13 @@ type ConnectionPool struct { connOpts tarantool.Opts opts OptsPool - state state - done chan struct{} - roPool *RoundRobinStrategy - rwPool *RoundRobinStrategy - anyPool *RoundRobinStrategy - poolsMutex sync.RWMutex + state state + done chan struct{} + roPool *RoundRobinStrategy + rwPool *RoundRobinStrategy + anyPool *RoundRobinStrategy + poolsMutex sync.RWMutex + watcherContainer watcherContainer } var _ Pooler = (*ConnectionPool)(nil) @@ -640,25 +641,6 @@ func (connPool *ConnectionPool) ExecuteAsync(expr string, args interface{}, user return conn.ExecuteAsync(expr, args) } -// Do sends the request and returns a future. -// For requests that belong to an only one connection (e.g. Unprepare or ExecutePrepared) -// the argument of type Mode is unused. -func (connPool *ConnectionPool) Do(req tarantool.Request, userMode Mode) *tarantool.Future { - if connectedReq, ok := req.(tarantool.ConnectedRequest); ok { - conn, _ := connPool.getConnectionFromPool(connectedReq.Conn().Addr()) - if conn == nil { - return newErrorFuture(fmt.Errorf("the passed connected request doesn't belong to the current connection or connection pool")) - } - return connectedReq.Conn().Do(req) - } - conn, err := connPool.getNextConnection(userMode) - if err != nil { - return newErrorFuture(err) - } - - return conn.Do(req) -} - // NewStream creates new Stream object for connection selected // by userMode from connPool. // @@ -682,6 +664,200 @@ func (connPool *ConnectionPool) NewPrepared(expr string, userMode Mode) (*tarant return conn.NewPrepared(expr) } +// watcherContainer is a very simple implementation of a thread-safe container +// for watchers. It is not expected that there will be too many watchers and +// they will registered/unregistered too frequently. +// +// Otherwise, the implementation will need to be optimized. +type watcherContainer struct { + head *poolWatcher + mutex sync.RWMutex +} + +// add adds a watcher to the container. +func (c *watcherContainer) add(watcher *poolWatcher) { + c.mutex.Lock() + defer c.mutex.Unlock() + + watcher.next = c.head + c.head = watcher +} + +// remove removes a watcher from the container. +func (c *watcherContainer) remove(watcher *poolWatcher) { + c.mutex.Lock() + defer c.mutex.Unlock() + + if watcher == c.head { + c.head = watcher.next + } else { + cur := c.head + for cur.next != nil { + if cur.next == watcher { + cur.next = watcher.next + break + } + cur = cur.next + } + } +} + +// foreach iterates over the container to the end or until the call returns +// false. +func (c *watcherContainer) foreach(call func(watcher *poolWatcher) error) error { + cur := c.head + for cur != nil { + if err := call(cur); err != nil { + return err + } + cur = cur.next + } + return nil +} + +// poolWatcher is an internal implementation of the tarantool.Watcher interface. +type poolWatcher struct { + // The watcher container data. We can split the structure into two parts + // in the future: a watcher data and a watcher container data, but it looks + // simple at now. + + // next item in the watcher container. + next *poolWatcher + // container is the container for all active poolWatcher objects. + container *watcherContainer + + // The watcher data. + // mode of the watcher. + mode Mode + key string + callback tarantool.WatchCallback + // watchers is a map connection -> connection watcher. + watchers map[string]tarantool.Watcher + // unregistered is true if the watcher already unregistered. + unregistered bool + // mutex for the pool watcher. + mutex sync.Mutex +} + +// Unregister unregisters the pool watcher. +func (w *poolWatcher) Unregister() { + w.mutex.Lock() + defer w.mutex.Unlock() + + if !w.unregistered { + w.container.remove(w) + w.unregistered = true + for _, watcher := range w.watchers { + watcher.Unregister() + } + } +} + +// watch adds a watcher for the connection. +func (w *poolWatcher) watch(conn *tarantool.Connection) error { + addr := conn.Addr() + + w.mutex.Lock() + defer w.mutex.Unlock() + + if !w.unregistered { + if _, ok := w.watchers[addr]; ok { + return nil + } + + if watcher, err := conn.NewWatcher(w.key, w.callback); err == nil { + w.watchers[addr] = watcher + return nil + } else { + return err + } + } + return nil +} + +// unwatch removes a watcher for the connection. +func (w *poolWatcher) unwatch(conn *tarantool.Connection) { + addr := conn.Addr() + + w.mutex.Lock() + defer w.mutex.Unlock() + + if !w.unregistered { + if watcher, ok := w.watchers[addr]; ok { + watcher.Unregister() + delete(w.watchers, addr) + } + } +} + +// NewWatcher creates a new Watcher object for the connection pool. +// +// The behavior is same as if Connection.NewWatcher() called for each +// connection with a suitable role. +// +// Keep in mind that garbage collection of a watcher handle doesn’t lead to the +// watcher’s destruction. In this case, the watcher remains registered. You +// need to call Unregister() directly. +// +// Unregister() guarantees that there will be no the watcher's callback calls +// after it, but Unregister() call from the callback leads to a deadlock. +// +// See: +// https://www.tarantool.io/en/doc/latest/reference/reference_lua/box_events/#box-watchers +// +// Since 1.10.0 +func (pool *ConnectionPool) NewWatcher(key string, + callback tarantool.WatchCallback, mode Mode) (tarantool.Watcher, error) { + watcher := &poolWatcher{ + container: &pool.watcherContainer, + mode: mode, + key: key, + callback: callback, + watchers: make(map[string]tarantool.Watcher), + unregistered: false, + } + + watcher.container.add(watcher) + + rr := pool.anyPool + if mode == RW { + rr = pool.rwPool + } else if mode == RO { + rr = pool.roPool + } + + conns := rr.GetConnections() + for _, conn := range conns { + // TODO: check required features after: + // + // https://github.com/tarantool/go-tarantool/issues/120 + if err := watcher.watch(conn); err != nil { + conn.Close() + } + } + + return watcher, nil +} + +// Do sends the request and returns a future. +// For requests that belong to the only one connection (e.g. Unprepare or ExecutePrepared) +// the argument of type Mode is unused. +func (connPool *ConnectionPool) Do(req tarantool.Request, userMode Mode) *tarantool.Future { + if connectedReq, ok := req.(tarantool.ConnectedRequest); ok { + conn, _ := connPool.getConnectionFromPool(connectedReq.Conn().Addr()) + if conn == nil { + return newErrorFuture(fmt.Errorf("the passed connected request doesn't belong to the current connection or connection pool")) + } + return connectedReq.Conn().Do(req) + } + conn, err := connPool.getNextConnection(userMode) + if err != nil { + return newErrorFuture(err) + } + + return conn.Do(req) +} + // // private // @@ -733,26 +909,63 @@ func (connPool *ConnectionPool) getConnectionFromPool(addr string) (*tarantool.C return connPool.anyPool.GetConnByAddr(addr), UnknownRole } -func (connPool *ConnectionPool) deleteConnection(addr string) { - if conn := connPool.anyPool.DeleteConnByAddr(addr); conn != nil { - if conn := connPool.rwPool.DeleteConnByAddr(addr); conn != nil { - return +func (pool *ConnectionPool) deleteConnection(addr string) { + if conn := pool.anyPool.DeleteConnByAddr(addr); conn != nil { + if conn := pool.rwPool.DeleteConnByAddr(addr); conn == nil { + pool.roPool.DeleteConnByAddr(addr) } - connPool.roPool.DeleteConnByAddr(addr) + // The internal connection deinitialization. + pool.watcherContainer.mutex.RLock() + defer pool.watcherContainer.mutex.RUnlock() + + pool.watcherContainer.foreach(func(watcher *poolWatcher) error { + watcher.unwatch(conn) + return nil + }) + } +} + +func (pool *ConnectionPool) addConnection(addr string, + conn *tarantool.Connection, role Role) error { + // The internal connection initialization. + pool.watcherContainer.mutex.RLock() + defer pool.watcherContainer.mutex.RUnlock() + + watched := []*poolWatcher{} + err := pool.watcherContainer.foreach(func(watcher *poolWatcher) error { + watch := false + if watcher.mode == RW { + watch = role == MasterRole + } else if watcher.mode == RO { + watch = role == ReplicaRole + } else { + watch = true + } + if watch { + if err := watcher.watch(conn); err != nil { + return err + } + watched = append(watched, watcher) + } + return nil + }) + if err != nil { + for _, watcher := range watched { + watcher.unwatch(conn) + } + log.Printf("tarantool: failed initialize watchers for %s: %s", addr, err) + return err } -} -func (connPool *ConnectionPool) addConnection(addr string, - conn *tarantool.Connection, role Role) { - - connPool.anyPool.AddConn(addr, conn) + pool.anyPool.AddConn(addr, conn) switch role { case MasterRole: - connPool.rwPool.AddConn(addr, conn) + pool.rwPool.AddConn(addr, conn) case ReplicaRole: - connPool.roPool.AddConn(addr, conn) + pool.roPool.AddConn(addr, conn) } + return nil } func (connPool *ConnectionPool) handlerDiscovered(conn *tarantool.Connection, @@ -811,7 +1024,10 @@ func (connPool *ConnectionPool) fillPools() ([]connState, bool) { } if connPool.handlerDiscovered(conn, role) { - connPool.addConnection(addr, conn, role) + if connPool.addConnection(addr, conn, role) != nil { + conn.Close() + connPool.handlerDeactivated(conn, role) + } if conn.ConnectedNow() { states[i].conn = conn @@ -864,7 +1080,15 @@ func (pool *ConnectionPool) updateConnection(s connState) connState { return s } - pool.addConnection(s.addr, s.conn, role) + if pool.addConnection(s.addr, s.conn, role) != nil { + pool.poolsMutex.Unlock() + + s.conn.Close() + pool.handlerDeactivated(s.conn, role) + s.conn = nil + s.role = UnknownRole + return s + } s.role = role } } @@ -911,7 +1135,12 @@ func (pool *ConnectionPool) tryConnect(s connState) connState { return s } - pool.addConnection(s.addr, conn, role) + if pool.addConnection(s.addr, conn, role) != nil { + pool.poolsMutex.Unlock() + conn.Close() + pool.handlerDeactivated(conn, role) + return s + } s.conn = conn s.role = role } diff --git a/connection_pool/connection_pool_test.go b/connection_pool/connection_pool_test.go index 60c9b4f91..3a650bfbc 100644 --- a/connection_pool/connection_pool_test.go +++ b/connection_pool/connection_pool_test.go @@ -2048,6 +2048,276 @@ func TestStream_TxnIsolationLevel(t *testing.T) { } } +func TestConnectionPool_NewWatcher_modes(t *testing.T) { + test_helpers.SkipIfWatchersUnsupported(t) + + const key = "TestConnectionPool_NewWatcher_modes" + + roles := []bool{true, false, false, true, true} + + err := test_helpers.SetClusterRO(servers, connOpts, roles) + require.Nilf(t, err, "fail to set roles for cluster") + + pool, err := connection_pool.Connect(servers, connOpts) + require.Nilf(t, err, "failed to connect") + require.NotNilf(t, pool, "conn is nil after Connect") + defer pool.Close() + + modes := []connection_pool.Mode{ + connection_pool.ANY, + connection_pool.RW, + connection_pool.RO, + connection_pool.PreferRW, + connection_pool.PreferRO, + } + for _, mode := range modes { + t.Run(fmt.Sprintf("%d", mode), func(t *testing.T) { + expectedServers := []string{} + for i, server := range servers { + if roles[i] && mode == connection_pool.RW { + continue + } else if !roles[i] && mode == connection_pool.RO { + continue + } + expectedServers = append(expectedServers, server) + } + + events := make(chan tarantool.WatchEvent, 1024) + defer close(events) + + watcher, err := pool.NewWatcher(key, func(event tarantool.WatchEvent) { + require.Equal(t, key, event.Key) + require.Equal(t, nil, event.Value) + events <- event + }, mode) + require.Nilf(t, err, "failed to register a watcher") + defer watcher.Unregister() + + testMap := make(map[string]int) + + for i := 0; i < len(expectedServers); i++ { + select { + case event := <-events: + require.NotNil(t, event.Conn) + addr := event.Conn.Addr() + if val, ok := testMap[addr]; ok { + testMap[addr] = val + 1 + } else { + testMap[addr] = 1 + } + case <-time.After(time.Second): + t.Errorf("Failed to get a watch event.") + break + } + } + + for _, server := range expectedServers { + if val, ok := testMap[server]; !ok { + t.Errorf("Server not found: %s", server) + } else { + require.Equal(t, val, 1, fmt.Sprintf("for server %s", server)) + } + } + }) + } +} + +func TestConnectionPool_NewWatcher_update(t *testing.T) { + test_helpers.SkipIfWatchersUnsupported(t) + + const key = "TestConnectionPool_NewWatcher_update" + const mode = connection_pool.RW + const initCnt = 2 + roles := []bool{true, false, false, true, true} + + err := test_helpers.SetClusterRO(servers, connOpts, roles) + require.Nilf(t, err, "fail to set roles for cluster") + + pool, err := connection_pool.Connect(servers, connOpts) + require.Nilf(t, err, "failed to connect") + require.NotNilf(t, pool, "conn is nil after Connect") + defer pool.Close() + + events := make(chan tarantool.WatchEvent, 1024) + defer close(events) + + watcher, err := pool.NewWatcher(key, func(event tarantool.WatchEvent) { + require.Equal(t, key, event.Key) + require.Equal(t, nil, event.Value) + events <- event + }, mode) + require.Nilf(t, err, "failed to create a watcher") + defer watcher.Unregister() + + // Wait for all initial events. + testMap := make(map[string]int) + for i := 0; i < initCnt; i++ { + select { + case event := <-events: + require.NotNil(t, event.Conn) + addr := event.Conn.Addr() + if val, ok := testMap[addr]; ok { + testMap[addr] = val + 1 + } else { + testMap[addr] = 1 + } + case <-time.After(time.Second): + t.Errorf("Failed to get a watch init event.") + break + } + } + + // Just invert roles for simplify the test. + for i, role := range roles { + roles[i] = !role + } + err = test_helpers.SetClusterRO(servers, connOpts, roles) + require.Nilf(t, err, "fail to set roles for cluster") + + // Wait for all updated events. + for i := 0; i < len(servers)-initCnt; i++ { + select { + case event := <-events: + require.NotNil(t, event.Conn) + addr := event.Conn.Addr() + if val, ok := testMap[addr]; ok { + testMap[addr] = val + 1 + } else { + testMap[addr] = 1 + } + case <-time.After(time.Second): + t.Errorf("Failed to get a watch update event.") + break + } + } + + // Check that all an event happen for an each connection. + for _, server := range servers { + if val, ok := testMap[server]; !ok { + t.Errorf("Server not found: %s", server) + } else { + require.Equal(t, val, 1, fmt.Sprintf("for server %s", server)) + } + } +} + +func TestWatcher_Unregister(t *testing.T) { + test_helpers.SkipIfWatchersUnsupported(t) + + const key = "TestWatcher_Unregister" + const mode = connection_pool.RW + const expectedCnt = 2 + roles := []bool{true, false, false, true, true} + + err := test_helpers.SetClusterRO(servers, connOpts, roles) + require.Nilf(t, err, "fail to set roles for cluster") + + pool, err := connection_pool.Connect(servers, connOpts) + require.Nilf(t, err, "failed to connect") + require.NotNilf(t, pool, "conn is nil after Connect") + defer pool.Close() + + events := make(chan tarantool.WatchEvent, 1024) + defer close(events) + + watcher, err := pool.NewWatcher(key, func(event tarantool.WatchEvent) { + require.Equal(t, key, event.Key) + require.Equal(t, nil, event.Value) + events <- event + }, mode) + require.Nilf(t, err, "failed to create a watcher") + + for i := 0; i < expectedCnt; i++ { + select { + case <-events: + case <-time.After(time.Second): + t.Fatalf("Failed to skip initial events.") + } + } + watcher.Unregister() + + broadcast := tarantool.NewBroadcastRequest(key).Value("foo") + for i := 0; i < expectedCnt; i++ { + _, err := pool.Do(broadcast, mode).Get() + require.Nilf(t, err, "failed to send a broadcast request") + } + + select { + case event := <-events: + t.Fatalf("Get unexpected event: %v", event) + case <-time.After(time.Second): + } +} + +func TestConnectionPool_NewWatcher_concurrent(t *testing.T) { + test_helpers.SkipIfWatchersUnsupported(t) + + const testConcurrency = 1000 + const key = "TestConnectionPool_NewWatcher_concurrent" + + roles := []bool{true, false, false, true, true} + + err := test_helpers.SetClusterRO(servers, connOpts, roles) + require.Nilf(t, err, "fail to set roles for cluster") + + pool, err := connection_pool.Connect(servers, connOpts) + require.Nilf(t, err, "failed to connect") + require.NotNilf(t, pool, "conn is nil after Connect") + defer pool.Close() + + var wg sync.WaitGroup + wg.Add(testConcurrency) + + mode := connection_pool.ANY + callback := func(event tarantool.WatchEvent) {} + for i := 0; i < testConcurrency; i++ { + go func(i int) { + defer wg.Done() + + watcher, err := pool.NewWatcher(key, callback, mode) + if err != nil { + t.Errorf("Failed to create a watcher: %s", err) + } else { + watcher.Unregister() + } + }(i) + } + wg.Wait() +} + +func TestWatcher_Unregister_concurrent(t *testing.T) { + test_helpers.SkipIfWatchersUnsupported(t) + + const testConcurrency = 1000 + const key = "TestWatcher_Unregister_concurrent" + + roles := []bool{true, false, false, true, true} + + err := test_helpers.SetClusterRO(servers, connOpts, roles) + require.Nilf(t, err, "fail to set roles for cluster") + + pool, err := connection_pool.Connect(servers, connOpts) + require.Nilf(t, err, "failed to connect") + require.NotNilf(t, pool, "conn is nil after Connect") + defer pool.Close() + + mode := connection_pool.ANY + watcher, err := pool.NewWatcher(key, func(event tarantool.WatchEvent) { + }, mode) + require.Nilf(t, err, "failed to create a watcher") + + var wg sync.WaitGroup + wg.Add(testConcurrency) + + for i := 0; i < testConcurrency; i++ { + go func() { + defer wg.Done() + watcher.Unregister() + }() + } + wg.Wait() +} + // runTestMain is a body of TestMain function // (see https://pkg.go.dev/testing#hdr-Main). // Using defer + os.Exit is not works so TestMain body diff --git a/connection_pool/connector.go b/connection_pool/connector.go index e52109d92..c108aba0b 100644 --- a/connection_pool/connector.go +++ b/connection_pool/connector.go @@ -299,6 +299,14 @@ func (c *ConnectorAdapter) NewStream() (*tarantool.Stream, error) { return c.pool.NewStream(c.mode) } +// NewWatcher creates new Watcher object for the pool +// +// Since 1.10.0 +func (c *ConnectorAdapter) NewWatcher(key string, + callback tarantool.WatchCallback) (tarantool.Watcher, error) { + return c.pool.NewWatcher(key, callback, c.mode) +} + // Do performs a request asynchronously on the connection. func (c *ConnectorAdapter) Do(req tarantool.Request) *tarantool.Future { return c.pool.Do(req, c.mode) diff --git a/connection_pool/connector_test.go b/connection_pool/connector_test.go index fa7cf06ba..f53a05b22 100644 --- a/connection_pool/connector_test.go +++ b/connection_pool/connector_test.go @@ -1139,6 +1139,45 @@ func TestConnectorNewStream(t *testing.T) { require.Equalf(t, testMode, m.mode, "unexpected proxy mode") } +type watcherMock struct{} + +func (w *watcherMock) Unregister() {} + +const reqWatchKey = "foo" + +var reqWatcher tarantool.Watcher = &watcherMock{} + +type newWatcherMock struct { + Pooler + key string + callback tarantool.WatchCallback + called int + mode Mode +} + +func (m *newWatcherMock) NewWatcher(key string, + callback tarantool.WatchCallback, mode Mode) (tarantool.Watcher, error) { + m.called++ + m.key = key + m.callback = callback + m.mode = mode + return reqWatcher, reqErr +} + +func TestConnectorNewWatcher(t *testing.T) { + m := &newWatcherMock{} + c := NewConnectorAdapter(m, testMode) + + w, err := c.NewWatcher(reqWatchKey, func(event tarantool.WatchEvent) {}) + + require.Equalf(t, reqWatcher, w, "unexpected watcher") + require.Equalf(t, reqErr, err, "unexpected error") + require.Equalf(t, 1, m.called, "should be called only once") + require.Equalf(t, reqWatchKey, m.key, "unexpected key") + require.NotNilf(t, m.callback, "callback must be set") + require.Equalf(t, testMode, m.mode, "unexpected proxy mode") +} + var reqRequest tarantool.Request = tarantool.NewPingRequest() type doMock struct { diff --git a/connection_pool/pooler.go b/connection_pool/pooler.go index a9dbe09f9..856f5d5be 100644 --- a/connection_pool/pooler.go +++ b/connection_pool/pooler.go @@ -84,6 +84,8 @@ type Pooler interface { NewPrepared(expr string, mode Mode) (*tarantool.Prepared, error) NewStream(mode Mode) (*tarantool.Stream, error) + NewWatcher(key string, callback tarantool.WatchCallback, + mode Mode) (tarantool.Watcher, error) Do(req tarantool.Request, mode Mode) (fut *tarantool.Future) } diff --git a/connection_pool/round_robin.go b/connection_pool/round_robin.go index b83d877d9..a7fb73e18 100644 --- a/connection_pool/round_robin.go +++ b/connection_pool/round_robin.go @@ -14,6 +14,15 @@ type RoundRobinStrategy struct { current uint } +func NewEmptyRoundRobin(size int) *RoundRobinStrategy { + return &RoundRobinStrategy{ + conns: make([]*tarantool.Connection, 0, size), + indexByAddr: make(map[string]uint), + size: 0, + current: 0, + } +} + func (r *RoundRobinStrategy) GetConnByAddr(addr string) *tarantool.Connection { r.mutex.RLock() defer r.mutex.RUnlock() @@ -71,13 +80,14 @@ func (r *RoundRobinStrategy) GetNextConnection() *tarantool.Connection { return r.conns[r.nextIndex()] } -func NewEmptyRoundRobin(size int) *RoundRobinStrategy { - return &RoundRobinStrategy{ - conns: make([]*tarantool.Connection, 0, size), - indexByAddr: make(map[string]uint), - size: 0, - current: 0, - } +func (r *RoundRobinStrategy) GetConnections() []*tarantool.Connection { + r.mutex.RLock() + defer r.mutex.RUnlock() + + ret := make([]*tarantool.Connection, len(r.conns)) + copy(ret, r.conns) + + return ret } func (r *RoundRobinStrategy) AddConn(addr string, conn *tarantool.Connection) { diff --git a/connection_pool/round_robin_test.go b/connection_pool/round_robin_test.go index 6b54ecfd8..03038eada 100644 --- a/connection_pool/round_robin_test.go +++ b/connection_pool/round_robin_test.go @@ -69,3 +69,22 @@ func TestRoundRobinGetNextConnection(t *testing.T) { } } } + +func TestRoundRobinStrategy_GetConnections(t *testing.T) { + rr := NewEmptyRoundRobin(10) + + addrs := []string{validAddr1, validAddr2} + conns := []*tarantool.Connection{&tarantool.Connection{}, &tarantool.Connection{}} + + for i, addr := range addrs { + rr.AddConn(addr, conns[i]) + } + + rr.GetConnections()[1] = conns[0] // GetConnections() returns a copy. + rrConns := rr.GetConnections() + for i, expected := range conns { + if expected != rrConns[i] { + t.Errorf("Unexpected connection on %d call", i) + } + } +} diff --git a/connector.go b/connector.go index d6c44c8dd..d93c69ec8 100644 --- a/connector.go +++ b/connector.go @@ -46,6 +46,7 @@ type Connector interface { NewPrepared(expr string) (*Prepared, error) NewStream() (*Stream, error) + NewWatcher(key string, callback WatchCallback) (Watcher, error) Do(req Request) (fut *Future) } diff --git a/const.go b/const.go index 4a3cb6833..95a0d366d 100644 --- a/const.go +++ b/const.go @@ -18,6 +18,8 @@ const ( RollbackRequestCode = 16 PingRequestCode = 64 SubscribeRequestCode = 66 + WatchRequestCode = 74 + UnwatchRequestCode = 75 KeyCode = 0x00 KeySync = 0x01 @@ -42,6 +44,8 @@ const ( KeySQLInfo = 0x42 KeyStmtID = 0x43 KeyTimeout = 0x56 + KeyEvent = 0x57 + KeyEventData = 0x58 KeyTxnIsolation = 0x59 KeyFieldName = 0x00 @@ -70,6 +74,7 @@ const ( RLimitWait = 2 OkCode = uint32(0) + EventCode = uint32(0x4c) PushCode = uint32(0x80) ErrorCodeBit = 0x8000 PacketLengthBytes = 5 diff --git a/multi/multi.go b/multi/multi.go index 67f450c5c..9d3828dd7 100644 --- a/multi/multi.go +++ b/multi/multi.go @@ -507,6 +507,16 @@ func (connMulti *ConnectionMulti) NewStream() (*tarantool.Stream, error) { return connMulti.getCurrentConnection().NewStream() } +// NewWatcher does not supported by the ConnectionMulti. The ConnectionMulti is +// deprecated: use ConnectionPool instead. +// +// Since 1.10.0 +func (connMulti *ConnectionMulti) NewWatcher(key string, + callback tarantool.WatchCallback) (tarantool.Watcher, error) { + return nil, errors.New("ConnectionMulti is deprecated " + + "use ConnectionPool") +} + // Do sends the request and returns a future. func (connMulti *ConnectionMulti) Do(req tarantool.Request) *tarantool.Future { if connectedReq, ok := req.(tarantool.ConnectedRequest); ok { diff --git a/multi/multi_test.go b/multi/multi_test.go index 2d43bb179..ef07d629b 100644 --- a/multi/multi_test.go +++ b/multi/multi_test.go @@ -548,6 +548,30 @@ func TestStream_Rollback(t *testing.T) { } } +func TestConnectionMulti_NewWatcher(t *testing.T) { + test_helpers.SkipIfStreamsUnsupported(t) + + multiConn, err := Connect([]string{server1, server2}, connOpts) + if err != nil { + t.Fatalf("Failed to connect: %s", err.Error()) + } + if multiConn == nil { + t.Fatalf("conn is nil after Connect") + } + defer multiConn.Close() + + watcher, err := multiConn.NewWatcher("foo", func(event tarantool.WatchEvent) {}) + if watcher != nil { + t.Errorf("Unexpected watcher") + } + if err == nil { + t.Fatalf("Unexpected success") + } + if err.Error() != "ConnectionMulti is deprecated use ConnectionPool" { + t.Fatalf("Unexpected error: %s", err) + } +} + // runTestMain is a body of TestMain function // (see https://pkg.go.dev/testing#hdr-Main). // Using defer + os.Exit is not works so TestMain body diff --git a/request.go b/request.go index cfa40e522..66eb4be41 100644 --- a/request.go +++ b/request.go @@ -538,6 +538,8 @@ type Request interface { Body(resolver SchemaResolver, enc *encoder) error // Ctx returns a context of the request. Ctx() context.Context + // Async returns true if the request does not expect response. + Async() bool } // ConnectedRequest is an interface that provides the info about a Connection @@ -550,6 +552,7 @@ type ConnectedRequest interface { type baseRequest struct { requestCode int32 + async bool ctx context.Context } @@ -558,6 +561,11 @@ func (req *baseRequest) Code() int32 { return req.requestCode } +// Async returns true if the request does not require a response. +func (req *baseRequest) Async() bool { + return req.async +} + // Ctx returns a context of the request. func (req *baseRequest) Ctx() context.Context { return req.ctx diff --git a/request_test.go b/request_test.go index 89d1d8884..a680cdcbb 100644 --- a/request_test.go +++ b/request_test.go @@ -19,6 +19,7 @@ const invalidIndex = 2 const validSpace = 1 // Any valid value != default. const validIndex = 3 // Any valid value != default. const validExpr = "any string" // We don't check the value here. +const validKey = "foo" // Any string. const defaultSpace = 0 // And valid too. const defaultIndex = 0 // And valid too. @@ -183,6 +184,7 @@ func TestRequestsCodes(t *testing.T) { {req: NewBeginRequest(), code: BeginRequestCode}, {req: NewCommitRequest(), code: CommitRequestCode}, {req: NewRollbackRequest(), code: RollbackRequestCode}, + {req: NewBroadcastRequest(validKey), code: EvalRequestCode}, } for _, test := range tests { @@ -192,6 +194,38 @@ func TestRequestsCodes(t *testing.T) { } } +func TestRequestsAsync(t *testing.T) { + tests := []struct { + req Request + async bool + }{ + {req: NewSelectRequest(validSpace), async: false}, + {req: NewUpdateRequest(validSpace), async: false}, + {req: NewUpsertRequest(validSpace), async: false}, + {req: NewInsertRequest(validSpace), async: false}, + {req: NewReplaceRequest(validSpace), async: false}, + {req: NewDeleteRequest(validSpace), async: false}, + {req: NewCall16Request(validExpr), async: false}, + {req: NewCall17Request(validExpr), async: false}, + {req: NewEvalRequest(validExpr), async: false}, + {req: NewExecuteRequest(validExpr), async: false}, + {req: NewPingRequest(), async: false}, + {req: NewPrepareRequest(validExpr), async: false}, + {req: NewUnprepareRequest(validStmt), async: false}, + {req: NewExecutePreparedRequest(validStmt), async: false}, + {req: NewBeginRequest(), async: false}, + {req: NewCommitRequest(), async: false}, + {req: NewRollbackRequest(), async: false}, + {req: NewBroadcastRequest(validKey), async: false}, + } + + for _, test := range tests { + if async := test.req.Async(); async != test.async { + t.Errorf("An invalid async %t, expected %t", async, test.async) + } + } +} + func TestPingRequestDefaultValues(t *testing.T) { var refBuf bytes.Buffer @@ -649,3 +683,34 @@ func TestRollbackRequestDefaultValues(t *testing.T) { req := NewRollbackRequest() assertBodyEqual(t, refBuf.Bytes(), req) } + +func TestBroadcastRequestDefaultValues(t *testing.T) { + var refBuf bytes.Buffer + + refEnc := NewEncoder(&refBuf) + expectedArgs := []interface{}{validKey} + err := RefImplEvalBody(refEnc, "box.broadcast(...)", expectedArgs) + if err != nil { + t.Errorf("An unexpected RefImplEvalBody() error: %q", err.Error()) + return + } + + req := NewBroadcastRequest(validKey) + assertBodyEqual(t, refBuf.Bytes(), req) +} + +func TestBroadcastRequestSetters(t *testing.T) { + value := []interface{}{uint(34), int(12)} + var refBuf bytes.Buffer + + refEnc := NewEncoder(&refBuf) + expectedArgs := []interface{}{validKey, value} + err := RefImplEvalBody(refEnc, "box.broadcast(...)", expectedArgs) + if err != nil { + t.Errorf("An unexpected RefImplEvalBody() error: %q", err.Error()) + return + } + + req := NewBroadcastRequest(validKey).Value(value) + assertBodyEqual(t, refBuf.Bytes(), req) +} diff --git a/tarantool_test.go b/tarantool_test.go index 1350390f9..301e6359e 100644 --- a/tarantool_test.go +++ b/tarantool_test.go @@ -2830,6 +2830,339 @@ func TestStream_DoWithClosedConn(t *testing.T) { } } +func TestConnection_NewWatcher(t *testing.T) { + test_helpers.SkipIfWatchersUnsupported(t) + + const key = "TestConnection_NewWatcher" + conn := test_helpers.ConnectWithValidation(t, server, opts) + defer conn.Close() + + events := make(chan WatchEvent) + defer close(events) + + watcher, err := conn.NewWatcher(key, func(event WatchEvent) { + events <- event + }) + if err != nil { + t.Fatalf("Failed to create a watch: %s", err) + } + defer watcher.Unregister() + + select { + case event := <-events: + if event.Conn != conn { + t.Errorf("Unexpected event connection: %v", event.Conn) + } + if event.Key != key { + t.Errorf("Unexpected event key: %s", event.Key) + } + if event.Value != nil { + t.Errorf("Unexpected event value: %v", event.Value) + } + case <-time.After(time.Second): + t.Fatalf("Failed to get watch event.") + } +} + +func TestConnection_NewWatcher_reconnect(t *testing.T) { + test_helpers.SkipIfWatchersUnsupported(t) + + const key = "TestConnection_NewWatcher_reconnect" + const server = "127.0.0.1:3014" + + inst, err := test_helpers.StartTarantool(test_helpers.StartOpts{ + InitScript: "config.lua", + Listen: server, + WorkDir: "work_dir", + User: opts.User, + Pass: opts.Pass, + WaitStart: 100 * time.Millisecond, + ConnectRetry: 3, + RetryTimeout: 500 * time.Millisecond, + }) + defer test_helpers.StopTarantoolWithCleanup(inst) + if err != nil { + t.Fatalf("Unable to start Tarantool: %s", err) + } + + reconnectOpts := opts + reconnectOpts.Reconnect = 100 * time.Millisecond + reconnectOpts.MaxReconnects = 10 + conn := test_helpers.ConnectWithValidation(t, server, reconnectOpts) + defer conn.Close() + + events := make(chan WatchEvent) + defer close(events) + watcher, err := conn.NewWatcher(key, func(event WatchEvent) { + events <- event + }) + if err != nil { + t.Fatalf("Failed to create a watch: %s", err) + } + defer watcher.Unregister() + + <-events + + test_helpers.StopTarantool(inst) + if err := test_helpers.RestartTarantool(&inst); err != nil { + t.Fatalf("Unable to restart Tarantool: %s", err) + } + + maxTime := reconnectOpts.Reconnect * time.Duration(reconnectOpts.MaxReconnects) + select { + case <-events: + case <-time.After(maxTime): + t.Fatalf("Failed to get watch event.") + } +} + +func TestBroadcastRequest(t *testing.T) { + test_helpers.SkipIfWatchersUnsupported(t) + + const key = "TestBroadcastRequest" + const value = "bar" + + conn := test_helpers.ConnectWithValidation(t, server, opts) + defer conn.Close() + + resp, err := conn.Do(NewBroadcastRequest(key).Value(value)).Get() + if err != nil { + t.Fatalf("Got broadcast error: %s", err) + } + if resp.Code != OkCode { + t.Errorf("Got unexpected broadcast response code: %d", resp.Code) + } + if !reflect.DeepEqual(resp.Data, []interface{}{}) { + t.Errorf("Got unexpected broadcast response data: %v", resp.Data) + } + + events := make(chan WatchEvent) + defer close(events) + + watcher, err := conn.NewWatcher(key, func(event WatchEvent) { + events <- event + }) + if err != nil { + t.Fatalf("Failed to create a watch: %s", err) + } + defer watcher.Unregister() + + select { + case event := <-events: + if event.Conn != conn { + t.Errorf("Unexpected event connection: %v", event.Conn) + } + if event.Key != key { + t.Errorf("Unexpected event key: %s", event.Key) + } + if event.Value != value { + t.Errorf("Unexpected event value: %v", event.Value) + } + case <-time.After(time.Second): + t.Fatalf("Failed to get watch event.") + } +} + +func TestBroadcastRequest_multi(t *testing.T) { + test_helpers.SkipIfWatchersUnsupported(t) + + const key = "TestBroadcastRequest_multi" + + conn := test_helpers.ConnectWithValidation(t, server, opts) + defer conn.Close() + + events := make(chan WatchEvent) + defer close(events) + + watcher, err := conn.NewWatcher(key, func(event WatchEvent) { + events <- event + }) + if err != nil { + t.Fatalf("Failed to create a watch: %s", err) + } + defer watcher.Unregister() + + <-events // Skip an initial event. + for i := 0; i < 10; i++ { + val := fmt.Sprintf("%d", i) + _, err := conn.Do(NewBroadcastRequest(key).Value(val)).Get() + if err != nil { + t.Fatalf("Failed to send a broadcast request: %s", err) + } + select { + case event := <-events: + if event.Conn != conn { + t.Errorf("Unexpected event connection: %v", event.Conn) + } + if event.Key != key { + t.Errorf("Unexpected event key: %s", event.Key) + } + if event.Value.(string) != val { + t.Errorf("Unexpected event value: %v", event.Value) + } + case <-time.After(time.Second): + t.Fatalf("Failed to get watch event %d", i) + } + } +} + +func TestConnection_NewWatcher_multiOnKey(t *testing.T) { + test_helpers.SkipIfWatchersUnsupported(t) + + const key = "TestConnection_NewWatcher_multiOnKey" + const value = "bar" + + conn := test_helpers.ConnectWithValidation(t, server, opts) + defer conn.Close() + + events := []chan WatchEvent{ + make(chan WatchEvent), + make(chan WatchEvent), + } + for _, ch := range events { + defer close(ch) + } + + for _, ch := range events { + channel := ch + watcher, err := conn.NewWatcher(key, func(event WatchEvent) { + channel <- event + }) + if err != nil { + t.Fatalf("Failed to create a watch: %s", err) + } + defer watcher.Unregister() + } + + for i, ch := range events { + select { + case <-ch: // Skip an initial event. + case <-time.After(2 * time.Second): + t.Fatalf("Failed to skip watch event for %d callback", i) + } + } + + _, err := conn.Do(NewBroadcastRequest(key).Value(value)).Get() + if err != nil { + t.Fatalf("Failed to send a broadcast request: %s", err) + } + + for i, ch := range events { + select { + case event := <-ch: + if event.Conn != conn { + t.Errorf("Unexpected event connection: %v", event.Conn) + } + if event.Key != key { + t.Errorf("Unexpected event key: %s", event.Key) + } + if event.Value.(string) != value { + t.Errorf("Unexpected event value: %v", event.Value) + } + case <-time.After(2 * time.Second): + t.Fatalf("Failed to get watch event from callback %d", i) + } + } +} + +func TestWatcher_Unregister(t *testing.T) { + test_helpers.SkipIfWatchersUnsupported(t) + + const key = "TestWatcher_Unregister" + const value = "bar" + + conn := test_helpers.ConnectWithValidation(t, server, opts) + defer conn.Close() + + events := make(chan WatchEvent) + defer close(events) + watcher, err := conn.NewWatcher(key, func(event WatchEvent) { + events <- event + }) + if err != nil { + t.Fatalf("Failed to create a watch: %s", err) + } + + <-events + watcher.Unregister() + + _, err = conn.Do(NewBroadcastRequest(key).Value(value)).Get() + if err != nil { + t.Fatalf("Got broadcast error: %s", err) + } + + select { + case event := <-events: + t.Fatalf("Get unexpected events: %v", event) + case <-time.After(time.Second): + } +} + +func TestConnection_NewWatcher_concurrent(t *testing.T) { + test_helpers.SkipIfWatchersUnsupported(t) + + const testConcurrency = 1000 + const key = "TestConnection_NewWatcher_concurrent" + conn := test_helpers.ConnectWithValidation(t, server, opts) + defer conn.Close() + + var wg sync.WaitGroup + wg.Add(testConcurrency) + + var ret error + for i := 0; i < testConcurrency; i++ { + go func(i int) { + defer wg.Done() + + events := make(chan struct{}) + + watcher, err := conn.NewWatcher(key, func(event WatchEvent) { + close(events) + }) + if err != nil { + ret = err + } else { + select { + case <-events: + case <-time.After(time.Second): + ret = fmt.Errorf("Unable to get an event %d.", i) + } + watcher.Unregister() + } + }(i) + } + wg.Wait() + + if ret != nil { + t.Fatalf("Unable to create a watcher: %s", ret) + } +} + +func TestWatcher_Unregister_concurrent(t *testing.T) { + test_helpers.SkipIfWatchersUnsupported(t) + + const testConcurrency = 1000 + const key = "TestWatcher_Unregister_concurrent" + conn := test_helpers.ConnectWithValidation(t, server, opts) + defer conn.Close() + + watcher, err := conn.NewWatcher(key, func(event WatchEvent) {}) + if err != nil { + t.Fatalf("Failed to create a watch: %s", err) + } + + var wg sync.WaitGroup + wg.Add(testConcurrency) + + for i := 0; i < testConcurrency; i++ { + go func() { + defer wg.Done() + watcher.Unregister() + }() + } + wg.Wait() +} + // runTestMain is a body of TestMain function // (see https://pkg.go.dev/testing#hdr-Main). // Using defer + os.Exit is not works so TestMain body diff --git a/test_helpers/request_mock.go b/test_helpers/request_mock.go index 93551e34a..19c18545e 100644 --- a/test_helpers/request_mock.go +++ b/test_helpers/request_mock.go @@ -17,6 +17,10 @@ func (sr *StrangerRequest) Code() int32 { return 0 } +func (sr *StrangerRequest) Async() bool { + return false +} + func (sr *StrangerRequest) Body(resolver tarantool.SchemaResolver, enc *encoder) error { return nil } diff --git a/test_helpers/utils.go b/test_helpers/utils.go index c936e90b3..f78c751f0 100644 --- a/test_helpers/utils.go +++ b/test_helpers/utils.go @@ -53,16 +53,27 @@ func SkipIfSQLUnsupported(t testing.TB) { } } -func SkipIfStreamsUnsupported(t *testing.T) { +func skipIfLess2_10(t *testing.T, feature string) { t.Helper() - // Tarantool supports streams and interactive transactions since version 2.10.0 isLess, err := IsTarantoolVersionLess(2, 10, 0) if err != nil { t.Fatalf("Could not check the Tarantool version") } if isLess { - t.Skip("Skipping test for Tarantool without streams support") + t.Skipf("Skipping test for Tarantool without %s support", feature) } } + +func SkipIfStreamsUnsupported(t *testing.T) { + t.Helper() + + skipIfLess2_10(t, "streams") +} + +func SkipIfWatchersUnsupported(t *testing.T) { + t.Helper() + + skipIfLess2_10(t, "watchers") +} diff --git a/watch.go b/watch.go new file mode 100644 index 000000000..2bd91b4bf --- /dev/null +++ b/watch.go @@ -0,0 +1,138 @@ +package tarantool + +import ( + "context" +) + +// BroadcastRequest helps to send broadcast messages. See: +// https://www.tarantool.io/en/doc/latest/reference/reference_lua/box_events/broadcast/ +type BroadcastRequest struct { + eval *EvalRequest + key string +} + +// NewBroadcastRequest returns a new broadcast request for a specified key. +func NewBroadcastRequest(key string) *BroadcastRequest { + req := new(BroadcastRequest) + req.key = key + req.eval = NewEvalRequest("box.broadcast(...)").Args([]interface{}{key}) + return req +} + +// Value sets the value for the broadcast request. +// Note: default value is nil. +func (req *BroadcastRequest) Value(value interface{}) *BroadcastRequest { + req.eval = req.eval.Args([]interface{}{req.key, value}) + return req +} + +// Context sets a passed context to the broadcast request. +func (req *BroadcastRequest) Context(ctx context.Context) *BroadcastRequest { + req.eval = req.eval.Context(ctx) + return req +} + +// Code returns IPROTO code for the broadcast request. +func (req *BroadcastRequest) Code() int32 { + return req.eval.Code() +} + +// Body fills an encoder with the broadcast request body. +func (req *BroadcastRequest) Body(res SchemaResolver, enc *encoder) error { + return req.eval.Body(res, enc) +} + +// Ctx returns a context of the broadcast request. +func (req *BroadcastRequest) Ctx() context.Context { + return req.eval.Ctx() +} + +// Async returns is the broadcast request expects a response. +func (req *BroadcastRequest) Async() bool { + return req.eval.Async() +} + +// watchRequest subscribes to the updates of a specified key defined on the +// server. After receiving the notification, you should send a new +// watchRequest to acknowledge the notification. +type watchRequest struct { + baseRequest + key string + ctx context.Context +} + +// newWatchRequest returns a new watchRequest. +func newWatchRequest(key string) *watchRequest { + req := new(watchRequest) + req.requestCode = WatchRequestCode + req.async = true + req.key = key + return req +} + +// Body fills an encoder with the watch request body. +func (req *watchRequest) Body(res SchemaResolver, enc *encoder) error { + if err := enc.EncodeMapLen(1); err != nil { + return err + } + if err := encodeUint(enc, KeyEvent); err != nil { + return err + } + return enc.EncodeString(req.key) +} + +// Context sets a passed context to the request. +func (req *watchRequest) Context(ctx context.Context) *watchRequest { + req.ctx = ctx + return req +} + +// unwatchRequest unregisters a watcher subscribed to the given notification +// key. +type unwatchRequest struct { + baseRequest + key string + ctx context.Context +} + +// newUnwatchRequest returns a new unwatchRequest. +func newUnwatchRequest(key string) *unwatchRequest { + req := new(unwatchRequest) + req.requestCode = UnwatchRequestCode + req.async = true + req.key = key + return req +} + +// Body fills an encoder with the unwatch request body. +func (req *unwatchRequest) Body(res SchemaResolver, enc *encoder) error { + if err := enc.EncodeMapLen(1); err != nil { + return err + } + if err := encodeUint(enc, KeyEvent); err != nil { + return err + } + return enc.EncodeString(req.key) +} + +// Context sets a passed context to the request. +func (req *unwatchRequest) Context(ctx context.Context) *unwatchRequest { + req.ctx = ctx + return req +} + +// WatchEvent is a watch notification event received from a server. +type WatchEvent struct { + Conn *Connection // A source connection. + Key string // A key. + Value interface{} // A value. +} + +// Watcher is a subscription to broadcast events. +type Watcher interface { + // Unregister unregisters the watcher. + Unregister() +} + +// WatchCallback is a callback to invoke when the key value is updated. +type WatchCallback func(event WatchEvent)