diff --git a/connection.go b/connection.go index 018df6caa..3eff1fc26 100644 --- a/connection.go +++ b/connection.go @@ -90,15 +90,15 @@ func (d defaultLogger) Report(event ConnLogKind, conn *Connection, v ...interfac reconnects := v[0].(uint) err := v[1].(error) log.Printf("tarantool: reconnect (%d/%d) to %s failed: %s", - reconnects, conn.opts.MaxReconnects, conn.addr, err) + reconnects, conn.opts.MaxReconnects, conn.Addr(), err) case LogLastReconnectFailed: err := v[0].(error) log.Printf("tarantool: last reconnect to %s failed: %s, giving it up", - conn.addr, err) + conn.Addr(), err) case LogUnexpectedResultId: resp := v[0].(*Response) log.Printf("tarantool: connection %s got unexpected resultId (%d) in response", - conn.addr, resp.RequestId) + conn.Addr(), resp.RequestId) case LogWatchEventReadFailed: err := v[0].(error) log.Printf("tarantool: unable to parse watch event: %s", err) @@ -156,10 +156,10 @@ func (d defaultLogger) Report(event ConnLogKind, conn *Connection, v ...interfac // More on graceful shutdown: // https://www.tarantool.io/en/doc/latest/dev_guide/internals/iproto/graceful_shutdown/ type Connection struct { - addr string - c Conn - mutex sync.Mutex - cond *sync.Cond + dialer Dialer + c Conn + mutex sync.Mutex + cond *sync.Cond // Schema contains schema loaded on connection. Schema *Schema // requestId contains the last request ID for requests with nil context. @@ -260,11 +260,6 @@ const ( // Opts is a way to configure Connection type Opts struct { - // Auth is an authentication method. - Auth Auth - // Dialer is a Dialer object used to create a new connection to a - // Tarantool instance. TtDialer is a default one. - Dialer Dialer // Timeout for response to a particular request. The timeout is reset when // push messages are received. If Timeout is zero, any request can be // blocked infinitely. @@ -287,10 +282,6 @@ type Opts struct { // endlessly. // After MaxReconnects attempts Connection becomes closed. MaxReconnects uint - // Username for logging in to Tarantool. - User string - // User password for logging in to Tarantool. - Pass string // RateLimit limits number of 'in-fly' request, i.e. already put into // requests queue, but not yet answered by server or timeouted. // It is disabled by default. @@ -315,44 +306,6 @@ type Opts struct { Handle interface{} // Logger is user specified logger used for error messages. Logger Logger - // Transport is the connection type, by default the connection is unencrypted. - Transport string - // SslOpts is used only if the Transport == 'ssl' is set. - Ssl SslOpts - // RequiredProtocolInfo contains minimal protocol version and - // list of protocol features that should be supported by - // Tarantool server. By default there are no restrictions. - RequiredProtocolInfo ProtocolInfo -} - -// SslOpts is a way to configure ssl transport. -type SslOpts struct { - // KeyFile is a path to a private SSL key file. - KeyFile string - // CertFile is a path to an SSL certificate file. - CertFile string - // CaFile is a path to a trusted certificate authorities (CA) file. - CaFile string - // Ciphers is a colon-separated (:) list of SSL cipher suites the connection - // can use. - // - // We don't provide a list of supported ciphers. This is what OpenSSL - // does. The only limitation is usage of TLSv1.2 (because other protocol - // versions don't seem to support the GOST cipher). To add additional - // ciphers (GOST cipher), you must configure OpenSSL. - // - // See also - // - // * https://www.openssl.org/docs/man1.1.1/man1/ciphers.html - Ciphers string - // Password is a password for decrypting the private SSL key file. - // The priority is as follows: try to decrypt with Password, then - // try PasswordFile. - Password string - // PasswordFile is a path to the list of passwords for decrypting - // the private SSL key file. The connection tries every line from the - // file as a password. - PasswordFile string } // Clone returns a copy of the Opts object. @@ -360,24 +313,13 @@ type SslOpts struct { // RequiredProtocolInfo value. func (opts Opts) Clone() Opts { optsCopy := opts - optsCopy.RequiredProtocolInfo = opts.RequiredProtocolInfo.Clone() - return optsCopy } // Connect creates and configures a new Connection. -// -// Address could be specified in following ways: -// -// - TCP connections (tcp://192.168.1.1:3013, tcp://my.host:3013, -// tcp:192.168.1.1:3013, tcp:my.host:3013, 192.168.1.1:3013, my.host:3013) -// -// - Unix socket, first '/' or '.' indicates Unix socket -// (unix:///abs/path/tnt.sock, unix:path/tnt.sock, /abs/path/tnt.sock, -// ./rel/path/tnt.sock, unix/:path/tnt.sock) -func Connect(ctx context.Context, addr string, opts Opts) (conn *Connection, err error) { +func Connect(ctx context.Context, dialer Dialer, opts Opts) (conn *Connection, err error) { conn = &Connection{ - addr: addr, + dialer: dialer, requestId: 0, contextRequestId: 1, Greeting: &Greeting{}, @@ -389,9 +331,6 @@ func Connect(ctx context.Context, addr string, opts Opts) (conn *Connection, err if conn.opts.Concurrency == 0 || conn.opts.Concurrency > maxprocs*128 { conn.opts.Concurrency = maxprocs * 4 } - if conn.opts.Dialer == nil { - conn.opts.Dialer = TtDialer{} - } if c := conn.opts.Concurrency; c&(c-1) != 0 { for i := uint(1); i < 32; i *= 2 { c |= c >> i @@ -473,27 +412,7 @@ func (conn *Connection) CloseGraceful() error { // Addr returns a configured address of Tarantool socket. func (conn *Connection) Addr() string { - return conn.addr -} - -// RemoteAddr returns an address of Tarantool socket. -func (conn *Connection) RemoteAddr() string { - conn.mutex.Lock() - defer conn.mutex.Unlock() - if conn.c == nil { - return "" - } - return conn.c.RemoteAddr().String() -} - -// LocalAddr returns an address of outgoing socket. -func (conn *Connection) LocalAddr() string { - conn.mutex.Lock() - defer conn.mutex.Unlock() - if conn.c == nil { - return "" - } - return conn.c.LocalAddr().String() + return conn.c.GetAddr() } // Handle returns a user-specified handle from Opts. @@ -512,14 +431,8 @@ func (conn *Connection) dial(ctx context.Context) error { opts := conn.opts var c Conn - c, err := conn.opts.Dialer.Dial(ctx, conn.addr, DialOpts{ - IoTimeout: opts.Timeout, - Transport: opts.Transport, - Ssl: opts.Ssl, - RequiredProtocol: opts.RequiredProtocolInfo, - Auth: opts.Auth, - User: opts.User, - Password: opts.Pass, + c, err := conn.dialer.Dial(ctx, DialOpts{ + IoTimeout: opts.Timeout, }) if err != nil { return err @@ -1474,7 +1387,7 @@ func (conn *Connection) NewWatcher(key string, callback WatchCallback) (Watcher, // That's why we can't just check the Tarantool response for an unsupported // request error. if !isFeatureInSlice(iproto.IPROTO_FEATURE_WATCHERS, - conn.opts.RequiredProtocolInfo.Features) { + conn.c.ProtocolInfo().Features) { err := fmt.Errorf("the feature %s must be required by connection "+ "options to create a watcher", iproto.IPROTO_FEATURE_WATCHERS) return nil, err @@ -1580,7 +1493,7 @@ func (conn *Connection) ServerProtocolInfo() ProtocolInfo { // Since 1.10.0 func (conn *Connection) ClientProtocolInfo() ProtocolInfo { info := clientProtocolInfo.Clone() - info.Auth = conn.opts.Auth + info.Auth = conn.serverProtocolInfo.Auth return info } diff --git a/dial.go b/dial.go index 5b17c0534..a6c93471c 100644 --- a/dial.go +++ b/dial.go @@ -15,10 +15,7 @@ import ( "github.com/vmihailenco/msgpack/v5" ) -const ( - dialTransportNone = "" - dialTransportSsl = "ssl" -) +const bufSize = 128 * 1024 // Greeting is a message sent by Tarantool on connect. type Greeting struct { @@ -45,34 +42,18 @@ type Conn interface { // Any blocked Read or Flush operations will be unblocked and return // errors. Close() error - // LocalAddr returns the local network address, if known. - LocalAddr() net.Addr - // RemoteAddr returns the remote network address, if known. - RemoteAddr() net.Addr // Greeting returns server greeting. Greeting() Greeting // ProtocolInfo returns server protocol info. ProtocolInfo() ProtocolInfo + // GetAddr returns the connection address. + GetAddr() string } // DialOpts is a way to configure a Dial method to create a new Conn. type DialOpts struct { // IoTimeout is a timeout per a network read/write. IoTimeout time.Duration - // Transport is a connect transport type. - Transport string - // Ssl configures "ssl" transport. - Ssl SslOpts - // RequiredProtocol contains minimal protocol version and - // list of protocol features that should be supported by - // Tarantool server. By default there are no restrictions. - RequiredProtocol ProtocolInfo - // Auth is an authentication method. - Auth Auth - // Username for logging in to Tarantool. - User string - // User password for logging in to Tarantool. - Password string } // Dialer is the interface that wraps a method to connect to a Tarantool @@ -85,10 +66,11 @@ type DialOpts struct { type Dialer interface { // Dial connects to a Tarantool instance to the address with specified // options. - Dial(ctx context.Context, address string, opts DialOpts) (Conn, error) + Dial(ctx context.Context, opts DialOpts) (Conn, error) } type tntConn struct { + addr string net net.Conn reader io.Reader writer writeFlusher @@ -96,61 +78,186 @@ type tntConn struct { protocol ProtocolInfo } -// TtDialer is a default implementation of the Dialer interface which is -// used by the connector. +// rawDial does basic dial operations: +// reads greeting, identifies a protocol and validates it. +func rawDial(conn *tntConn, requiredProto ProtocolInfo) (string, error) { + version, salt, err := readGreeting(conn.reader) + if err != nil { + return "", fmt.Errorf("failed to read greeting: %w", err) + } + conn.greeting.Version = version + + if conn.protocol, err = identify(conn.writer, conn.reader); err != nil { + return "", fmt.Errorf("failed to read greeting: %w", err) + } + + if err = checkProtocolInfo(requiredProto, conn.protocol); err != nil { + return "", fmt.Errorf("invalid server protocol: %w", err) + } + return salt, err +} + type TtDialer struct { + // Address is an address to connect. + // It could be specified in following ways: + // + // - TCP connections (tcp://192.168.1.1:3013, tcp://my.host:3013, + // tcp:192.168.1.1:3013, tcp:my.host:3013, 192.168.1.1:3013, my.host:3013) + // + // - Unix socket, first '/' or '.' indicates Unix socket + // (unix:///abs/path/tnt.sock, unix:path/tnt.sock, /abs/path/tnt.sock, + // ./rel/path/tnt.sock, unix/:path/tnt.sock) + Address string + // Username for logging in to Tarantool. + User string + // User password for logging in to Tarantool. + Password string + // RequiredProtocol contains minimal protocol version and + // list of protocol features that should be supported by + // Tarantool server. By default, there are no restrictions. + RequiredProtocolInfo ProtocolInfo } -// Dial connects to a Tarantool instance to the address with specified -// options. -func (t TtDialer) Dial(ctx context.Context, address string, opts DialOpts) (Conn, error) { +// Dial makes TtDialer satisfy the Dialer interface. +func (d TtDialer) Dial(ctx context.Context, opts DialOpts) (Conn, error) { var err error conn := new(tntConn) + conn.addr = d.Address - if conn.net, err = dial(ctx, address, opts); err != nil { + network, address := parseAddress(d.Address) + dialer := net.Dialer{} + conn.net, err = dialer.DialContext(ctx, network, address) + if err != nil { return nil, fmt.Errorf("failed to dial: %w", err) } dc := &deadlineIO{to: opts.IoTimeout, c: conn.net} - conn.reader = bufio.NewReaderSize(dc, 128*1024) - conn.writer = bufio.NewWriterSize(dc, 128*1024) + conn.reader = bufio.NewReaderSize(dc, bufSize) + conn.writer = bufio.NewWriterSize(dc, bufSize) - var version, salt string - if version, salt, err = readGreeting(conn.reader); err != nil { + salt, err := rawDial(conn, d.RequiredProtocolInfo) + if err != nil { conn.net.Close() - return nil, fmt.Errorf("failed to read greeting: %w", err) + return nil, err } - conn.greeting.Version = version - if conn.protocol, err = identify(conn.writer, conn.reader); err != nil { + if d.User == "" { + return conn, nil + } + + if err = authenticate(conn, ChapSha1Auth, d.User, d.Password, salt); err != nil { conn.net.Close() - return nil, fmt.Errorf("failed to identify: %w", err) + return nil, fmt.Errorf("failed to authenticate: %w", err) + } + + return conn, nil +} + +type OpenSSLDialer struct { + // Address is an address to connect. + // It could be specified in following ways: + // + // - TCP connections (tcp://192.168.1.1:3013, tcp://my.host:3013, + // tcp:192.168.1.1:3013, tcp:my.host:3013, 192.168.1.1:3013, my.host:3013) + // + // - Unix socket, first '/' or '.' indicates Unix socket + // (unix:///abs/path/tnt.sock, unix:path/tnt.sock, /abs/path/tnt.sock, + // ./rel/path/tnt.sock, unix/:path/tnt.sock) + Address string + // Auth is an authentication method. + Auth Auth + // Username for logging in to Tarantool. + User string + // User password for logging in to Tarantool. + Password string + // RequiredProtocol contains minimal protocol version and + // list of protocol features that should be supported by + // Tarantool server. By default, there are no restrictions. + RequiredProtocolInfo ProtocolInfo + // KeyFile is a path to a private SSL key file. + KeyFile string + // CertFile is a path to an SSL certificate file. + CertFile string + // CaFile is a path to a trusted certificate authorities (CA) file. + CaFile string + // Ciphers is a colon-separated (:) list of SSL cipher suites the connection + // can use. + // + // We don't provide a list of supported ciphers. This is what OpenSSL + // does. The only limitation is usage of TLSv1.2 (because other protocol + // versions don't seem to support the GOST cipher). To add additional + // ciphers (GOST cipher), you must configure OpenSSL. + // + // See also + // + // * https://www.openssl.org/docs/man1.1.1/man1/ciphers.html + Ciphers string + // SSLPassword is a password for decrypting the private SSL key file. + // The priority is as follows: try to decrypt with SSLPassword, then + // try PasswordFile. + SSLPassword string + // PasswordFile is a path to the list of passwords for decrypting + // the private SSL key file. The connection tries every line from the + // file as a password. + PasswordFile string +} + +type TlsDialer = OpenSSLDialer + +// Dial makes OpenSSLDialer satisfy the Dialer interface. +func (d OpenSSLDialer) Dial(ctx context.Context, opts DialOpts) (Conn, error) { + var err error + conn := new(tntConn) + conn.addr = d.Address + + network, address := parseAddress(d.Address) + conn.net, err = sslDialContext(ctx, network, address, sslOpts{ + keyFile: d.KeyFile, + certFile: d.CertFile, + caFile: d.CaFile, + ciphers: d.Ciphers, + password: d.SSLPassword, + passwordFile: d.PasswordFile, + }) + if err != nil { + return nil, fmt.Errorf("failed to dial: %w", err) } - if err = checkProtocolInfo(opts.RequiredProtocol, conn.protocol); err != nil { + dc := &deadlineIO{to: opts.IoTimeout, c: conn.net} + conn.reader = bufio.NewReaderSize(dc, bufSize) + conn.writer = bufio.NewWriterSize(dc, bufSize) + + salt, err := rawDial(conn, d.RequiredProtocolInfo) + if err != nil { conn.net.Close() - return nil, fmt.Errorf("invalid server protocol: %w", err) + return nil, err } - if opts.User != "" { - if opts.Auth == AutoAuth { - if conn.protocol.Auth != AutoAuth { - opts.Auth = conn.protocol.Auth - } else { - opts.Auth = ChapSha1Auth - } - } + if d.User == "" { + return conn, nil + } - err := authenticate(conn, opts, salt) - if err != nil { - conn.net.Close() - return nil, fmt.Errorf("failed to authenticate: %w", err) + if d.Auth == AutoAuth { + if conn.protocol.Auth != AutoAuth { + d.Auth = conn.protocol.Auth + } else { + d.Auth = ChapSha1Auth } } + if err = authenticate(conn, d.Auth, d.User, d.Password, salt); err != nil { + conn.net.Close() + return nil, fmt.Errorf("failed to authenticate: %w", err) + } + return conn, nil } +// GetAddr makes tntConn satisfy the Conn interface. +func (c *tntConn) GetAddr() string { + return c.addr +} + // Read makes tntConn satisfy the Conn interface. func (c *tntConn) Read(p []byte) (int, error) { return c.reader.Read(p) @@ -177,16 +284,6 @@ func (c *tntConn) Close() error { return c.net.Close() } -// RemoteAddr makes tntConn satisfy the Conn interface. -func (c *tntConn) RemoteAddr() net.Addr { - return c.net.RemoteAddr() -} - -// LocalAddr makes tntConn satisfy the Conn interface. -func (c *tntConn) LocalAddr() net.Addr { - return c.net.LocalAddr() -} - // Greeting makes tntConn satisfy the Conn interface. func (c *tntConn) Greeting() Greeting { return c.greeting @@ -197,20 +294,6 @@ func (c *tntConn) ProtocolInfo() ProtocolInfo { return c.protocol } -// dial connects to a Tarantool instance. -func dial(ctx context.Context, address string, opts DialOpts) (net.Conn, error) { - network, address := parseAddress(address) - switch opts.Transport { - case dialTransportNone: - dialer := net.Dialer{} - return dialer.DialContext(ctx, network, address) - case dialTransportSsl: - return sslDialContext(ctx, network, address, opts.Ssl) - default: - return nil, fmt.Errorf("unsupported transport type: %s", opts.Transport) - } -} - // parseAddress split address into network and address parts. func parseAddress(address string) (string, string) { network := "tcp" @@ -316,29 +399,21 @@ func checkProtocolInfo(required ProtocolInfo, actual ProtocolInfo) error { } } -// authenticate authenticate for a connection. -func authenticate(c Conn, opts DialOpts, salt string) error { - auth := opts.Auth - user := opts.User - pass := opts.Password - +// authenticate authenticates for a connection. +func authenticate(c Conn, auth Auth, user string, pass string, salt string) error { var req Request var err error - switch opts.Auth { + switch auth { case ChapSha1Auth: req, err = newChapSha1AuthRequest(user, pass, salt) if err != nil { return err } case PapSha256Auth: - if opts.Transport != dialTransportSsl { - return errors.New("forbidden to use " + auth.String() + - " unless SSL is enabled for the connection") - } req = newPapSha256AuthRequest(user, pass) default: - return errors.New("unsupported method " + opts.Auth.String()) + return errors.New("unsupported method " + auth.String()) } if err = writeRequest(c, req); err != nil { diff --git a/pool/connection_pool.go b/pool/connection_pool.go index aa84d8c24..ea07ae32f 100644 --- a/pool/connection_pool.go +++ b/pool/connection_pool.go @@ -23,7 +23,7 @@ import ( ) var ( - ErrEmptyAddrs = errors.New("addrs (first argument) should not be empty") + ErrEmptyDialers = errors.New("dialers (second argument) should not be empty") ErrWrongCheckTimeout = errors.New("wrong check timeout, must be greater than 0") ErrNoConnection = errors.New("no active connections") ErrTooManyArgs = errors.New("too many arguments") @@ -94,8 +94,8 @@ Main features: - Automatic master discovery by mode parameter. */ type ConnectionPool struct { - addrs map[string]*endpoint - addrsMutex sync.RWMutex + ends map[string]*endpoint + endsMutex sync.RWMutex connOpts tarantool.Opts opts Opts @@ -112,7 +112,8 @@ type ConnectionPool struct { var _ Pooler = (*ConnectionPool)(nil) type endpoint struct { - addr string + id string + dialer tarantool.Dialer notify chan tarantool.ConnEvent conn *tarantool.Connection role Role @@ -124,9 +125,10 @@ type endpoint struct { closeErr error } -func newEndpoint(addr string) *endpoint { +func newEndpoint(id string, dialer tarantool.Dialer) *endpoint { return &endpoint{ - addr: addr, + id: id, + dialer: dialer, notify: make(chan tarantool.ConnEvent, 100), conn: nil, role: UnknownRole, @@ -137,24 +139,24 @@ func newEndpoint(addr string) *endpoint { } } -// ConnectWithOpts creates pool for instances with addresses addrs -// with options opts. -func ConnectWithOpts(ctx context.Context, addrs []string, +// ConnectWithOpts creates pool for instances with specified dialers and options opts. +// Each dialer corresponds to a certain id by which they will be distinguished. +func ConnectWithOpts(ctx context.Context, dialers map[string]tarantool.Dialer, connOpts tarantool.Opts, opts Opts) (*ConnectionPool, error) { - if len(addrs) == 0 { - return nil, ErrEmptyAddrs + if len(dialers) == 0 { + return nil, ErrEmptyDialers } if opts.CheckTimeout <= 0 { return nil, ErrWrongCheckTimeout } - size := len(addrs) + size := len(dialers) rwPool := newRoundRobinStrategy(size) roPool := newRoundRobinStrategy(size) anyPool := newRoundRobinStrategy(size) connPool := &ConnectionPool{ - addrs: make(map[string]*endpoint), + ends: make(map[string]*endpoint), connOpts: connOpts.Clone(), opts: opts, state: unknownState, @@ -164,11 +166,7 @@ func ConnectWithOpts(ctx context.Context, addrs []string, anyPool: anyPool, } - for _, addr := range addrs { - connPool.addrs[addr] = nil - } - - somebodyAlive, ctxCanceled := connPool.fillPools(ctx) + somebodyAlive, ctxCanceled := connPool.fillPools(ctx, dialers) if !somebodyAlive { connPool.state.set(closedState) if ctxCanceled { @@ -179,7 +177,7 @@ func ConnectWithOpts(ctx context.Context, addrs []string, connPool.state.set(connectedState) - for _, s := range connPool.addrs { + for _, s := range connPool.ends { endpointCtx, cancel := context.WithCancel(context.Background()) s.cancel = cancel go connPool.controller(endpointCtx, s) @@ -188,17 +186,18 @@ func ConnectWithOpts(ctx context.Context, addrs []string, return connPool, nil } -// ConnectWithOpts creates pool for instances with addresses addrs. +// Connect creates pool for instances with specified dialers. +// Each dialer corresponds to a certain id by which they will be distinguished. // // It is useless to set up tarantool.Opts.Reconnect value for a connection. // The connection pool has its own reconnection logic. See // Opts.CheckTimeout description. -func Connect(ctx context.Context, addrs []string, +func Connect(ctx context.Context, dialers map[string]tarantool.Dialer, connOpts tarantool.Opts) (*ConnectionPool, error) { opts := Opts{ CheckTimeout: 1 * time.Second, } - return ConnectWithOpts(ctx, addrs, connOpts, opts) + return ConnectWithOpts(ctx, dialers, connOpts, opts) } // ConnectedNow gets connected status of pool. @@ -235,32 +234,32 @@ func (p *ConnectionPool) ConfiguredTimeout(mode Mode) (time.Duration, error) { return conn.ConfiguredTimeout(), nil } -// Add adds a new endpoint with the address into the pool. This function +// Add adds a new endpoint with the id into the pool. This function // adds the endpoint only after successful connection. -func (p *ConnectionPool) Add(ctx context.Context, addr string) error { - e := newEndpoint(addr) +func (p *ConnectionPool) Add(ctx context.Context, id string, dialer tarantool.Dialer) error { + e := newEndpoint(id, dialer) - p.addrsMutex.Lock() + p.endsMutex.Lock() // Ensure that Close()/CloseGraceful() not in progress/done. if p.state.get() != connectedState { - p.addrsMutex.Unlock() + p.endsMutex.Unlock() return ErrClosed } - if _, ok := p.addrs[addr]; ok { - p.addrsMutex.Unlock() + if _, ok := p.ends[id]; ok { + p.endsMutex.Unlock() return ErrExists } endpointCtx, cancel := context.WithCancel(context.Background()) e.cancel = cancel - p.addrs[addr] = e - p.addrsMutex.Unlock() + p.ends[id] = e + p.endsMutex.Unlock() if err := p.tryConnect(ctx, e); err != nil { - p.addrsMutex.Lock() - delete(p.addrs, addr) - p.addrsMutex.Unlock() + p.endsMutex.Lock() + delete(p.ends, id) + p.endsMutex.Unlock() e.cancel() close(e.closed) return err @@ -270,13 +269,13 @@ func (p *ConnectionPool) Add(ctx context.Context, addr string) error { return nil } -// Remove removes an endpoint with the address from the pool. The call +// Remove removes an endpoint with the id from the pool. The call // closes an active connection gracefully. -func (p *ConnectionPool) Remove(addr string) error { - p.addrsMutex.Lock() - endpoint, ok := p.addrs[addr] +func (p *ConnectionPool) Remove(id string) error { + p.endsMutex.Lock() + endpoint, ok := p.ends[id] if !ok { - p.addrsMutex.Unlock() + p.endsMutex.Unlock() return errors.New("endpoint not exist") } @@ -290,20 +289,20 @@ func (p *ConnectionPool) Remove(addr string) error { close(endpoint.shutdown) } - delete(p.addrs, addr) - p.addrsMutex.Unlock() + delete(p.ends, id) + p.endsMutex.Unlock() <-endpoint.closed return nil } func (p *ConnectionPool) waitClose() []error { - p.addrsMutex.RLock() - endpoints := make([]*endpoint, 0, len(p.addrs)) - for _, e := range p.addrs { + p.endsMutex.RLock() + endpoints := make([]*endpoint, 0, len(p.ends)) + for _, e := range p.ends { endpoints = append(endpoints, e) } - p.addrsMutex.RUnlock() + p.endsMutex.RUnlock() errs := make([]error, 0, len(endpoints)) for _, e := range endpoints { @@ -319,12 +318,12 @@ func (p *ConnectionPool) waitClose() []error { func (p *ConnectionPool) Close() []error { if p.state.cas(connectedState, closedState) || p.state.cas(shutdownState, closedState) { - p.addrsMutex.RLock() - for _, s := range p.addrs { + p.endsMutex.RLock() + for _, s := range p.ends { s.cancel() close(s.close) } - p.addrsMutex.RUnlock() + p.endsMutex.RUnlock() } return p.waitClose() @@ -334,39 +333,23 @@ func (p *ConnectionPool) Close() []error { // for all requests to complete. func (p *ConnectionPool) CloseGraceful() []error { if p.state.cas(connectedState, shutdownState) { - p.addrsMutex.RLock() - for _, s := range p.addrs { + p.endsMutex.RLock() + for _, s := range p.ends { s.cancel() close(s.shutdown) } - p.addrsMutex.RUnlock() + p.endsMutex.RUnlock() } return p.waitClose() } -// GetAddrs gets addresses of connections in pool. -func (p *ConnectionPool) GetAddrs() []string { - p.addrsMutex.RLock() - defer p.addrsMutex.RUnlock() - - cpy := make([]string, len(p.addrs)) - - i := 0 - for addr := range p.addrs { - cpy[i] = addr - i++ - } - - return cpy -} - -// GetPoolInfo gets information of connections (connected status, ro/rw role). -func (p *ConnectionPool) GetPoolInfo() map[string]*ConnectionInfo { +// GetInfo gets information of connections (connected status, ro/rw role). +func (p *ConnectionPool) GetInfo() map[string]*ConnectionInfo { info := make(map[string]*ConnectionInfo) - p.addrsMutex.RLock() - defer p.addrsMutex.RUnlock() + p.endsMutex.RLock() + defer p.endsMutex.RUnlock() p.poolsMutex.RLock() defer p.poolsMutex.RUnlock() @@ -374,7 +357,7 @@ func (p *ConnectionPool) GetPoolInfo() map[string]*ConnectionInfo { return info } - for addr := range p.addrs { + for addr := range p.ends { conn, role := p.getConnectionFromPool(addr) if conn != nil { info[addr] = &ConnectionInfo{ConnectedNow: conn.ConnectedNow(), ConnRole: role} @@ -932,16 +915,33 @@ func (p *ConnectionPool) NewPrepared(expr string, userMode Mode) (*tarantool.Pre // Since 1.10.0 func (p *ConnectionPool) NewWatcher(key string, callback tarantool.WatchCallback, mode Mode) (tarantool.Watcher, error) { - watchersRequired := false - for _, feature := range p.connOpts.RequiredProtocolInfo.Features { - if iproto.IPROTO_FEATURE_WATCHERS == feature { - watchersRequired = true + + rr := p.anyPool + if mode == RW { + rr = p.rwPool + } else if mode == RO { + rr = p.roPool + } + + conns := rr.GetConnections() + + watchersRequired := true + for _, conn := range conns { + watchersRequired = false + for _, feature := range conn.ServerProtocolInfo().Features { + if iproto.IPROTO_FEATURE_WATCHERS == feature { + watchersRequired = true + break + } + } + if !watchersRequired { break } } + if !watchersRequired { return nil, errors.New("the feature IPROTO_FEATURE_WATCHERS must " + - "be required by connection options to create a watcher") + "be required by any connection to create a watcher") } watcher := &poolWatcher{ @@ -955,14 +955,6 @@ func (p *ConnectionPool) NewWatcher(key string, watcher.container.add(watcher) - rr := p.anyPool - if mode == RW { - rr = p.rwPool - } else if mode == RO { - rr = p.roPool - } - - conns := rr.GetConnections() for _, conn := range conns { if err := watcher.watch(conn); err != nil { conn.Close() @@ -1030,22 +1022,22 @@ func (p *ConnectionPool) getConnectionRole(conn *tarantool.Connection) (Role, er return UnknownRole, nil } -func (p *ConnectionPool) getConnectionFromPool(addr string) (*tarantool.Connection, Role) { - if conn := p.rwPool.GetConnByAddr(addr); conn != nil { +func (p *ConnectionPool) getConnectionFromPool(id string) (*tarantool.Connection, Role) { + if conn := p.rwPool.GetConnById(id); conn != nil { return conn, MasterRole } - if conn := p.roPool.GetConnByAddr(addr); conn != nil { + if conn := p.roPool.GetConnById(id); conn != nil { return conn, ReplicaRole } - return p.anyPool.GetConnByAddr(addr), UnknownRole + return p.anyPool.GetConnById(id), UnknownRole } -func (p *ConnectionPool) deleteConnection(addr string) { - if conn := p.anyPool.DeleteConnByAddr(addr); conn != nil { - if conn := p.rwPool.DeleteConnByAddr(addr); conn == nil { - p.roPool.DeleteConnByAddr(addr) +func (p *ConnectionPool) deleteConnection(id string) { + if conn := p.anyPool.DeleteConnById(id); conn != nil { + if conn := p.rwPool.DeleteConnById(id); conn == nil { + p.roPool.DeleteConnById(id) } // The internal connection deinitialization. p.watcherContainer.mutex.RLock() @@ -1058,7 +1050,7 @@ func (p *ConnectionPool) deleteConnection(addr string) { } } -func (p *ConnectionPool) addConnection(addr string, +func (p *ConnectionPool) addConnection(id string, conn *tarantool.Connection, role Role) error { // The internal connection initialization. p.watcherContainer.mutex.RLock() @@ -1087,17 +1079,17 @@ func (p *ConnectionPool) addConnection(addr string, for _, watcher := range watched { watcher.unwatch(conn) } - log.Printf("tarantool: failed initialize watchers for %s: %s", addr, err) + log.Printf("tarantool: failed initialize watchers for %s: %s", id, err) return err } - p.anyPool.AddConn(addr, conn) + p.anyPool.AddConn(id, conn) switch role { case MasterRole: - p.rwPool.AddConn(addr, conn) + p.rwPool.AddConn(id, conn) case ReplicaRole: - p.roPool.AddConn(addr, conn) + p.roPool.AddConn(id, conn) } return nil } @@ -1130,27 +1122,27 @@ func (p *ConnectionPool) handlerDeactivated(conn *tarantool.Connection, } } -func (p *ConnectionPool) deactivateConnection(addr string, +func (p *ConnectionPool) deactivateConnection(id string, conn *tarantool.Connection, role Role) { - p.deleteConnection(addr) + p.deleteConnection(id) conn.Close() p.handlerDeactivated(conn, role) } func (p *ConnectionPool) deactivateConnections() { - for address, endpoint := range p.addrs { + for id, endpoint := range p.ends { if endpoint != nil && endpoint.conn != nil { - p.deactivateConnection(address, endpoint.conn, endpoint.role) + p.deactivateConnection(id, endpoint.conn, endpoint.role) } } } func (p *ConnectionPool) processConnection(conn *tarantool.Connection, - addr string, end *endpoint) bool { + id string, end *endpoint) bool { role, err := p.getConnectionRole(conn) if err != nil { conn.Close() - log.Printf("tarantool: storing connection to %s failed: %s\n", addr, err) + log.Printf("tarantool: storing connection %s failed: %s\n", id, err) return false } @@ -1158,7 +1150,7 @@ func (p *ConnectionPool) processConnection(conn *tarantool.Connection, conn.Close() return false } - if p.addConnection(addr, conn, role) != nil { + if p.addConnection(id, conn, role) != nil { conn.Close() p.handlerDeactivated(conn, role) return false @@ -1169,26 +1161,27 @@ func (p *ConnectionPool) processConnection(conn *tarantool.Connection, return true } -func (p *ConnectionPool) fillPools(ctx context.Context) (bool, bool) { +func (p *ConnectionPool) fillPools( + ctx context.Context, + dialers map[string]tarantool.Dialer) (bool, bool) { somebodyAlive := false ctxCanceled := false - // It is called before controller() goroutines so we don't expect + // It is called before controller() goroutines, so we don't expect // concurrency issues here. - for addr := range p.addrs { - end := newEndpoint(addr) - p.addrs[addr] = end - + for id, dialer := range dialers { + end := newEndpoint(id, dialer) + p.ends[id] = end connOpts := p.connOpts connOpts.Notify = end.notify - conn, err := tarantool.Connect(ctx, addr, connOpts) + conn, err := tarantool.Connect(ctx, dialer, connOpts) if err != nil { - log.Printf("tarantool: connect to %s failed: %s\n", addr, err.Error()) + log.Printf("tarantool: connect to %s failed: %s\n", conn.Addr(), err.Error()) select { case <-ctx.Done(): ctxCanceled = true - p.addrs[addr] = nil + p.ends[id] = nil log.Printf("tarantool: operation was canceled") p.deactivateConnections() @@ -1196,7 +1189,7 @@ func (p *ConnectionPool) fillPools(ctx context.Context) (bool, bool) { return false, ctxCanceled default: } - } else if p.processConnection(conn, addr, end) { + } else if p.processConnection(conn, id, end) { somebodyAlive = true } } @@ -1214,7 +1207,7 @@ func (p *ConnectionPool) updateConnection(e *endpoint) { if role, err := p.getConnectionRole(e.conn); err == nil { if e.role != role { - p.deleteConnection(e.addr) + p.deleteConnection(e.id) p.poolsMutex.Unlock() p.handlerDeactivated(e.conn, e.role) @@ -1237,7 +1230,7 @@ func (p *ConnectionPool) updateConnection(e *endpoint) { return } - if p.addConnection(e.addr, e.conn, role) != nil { + if p.addConnection(e.id, e.conn, role) != nil { p.poolsMutex.Unlock() e.conn.Close() @@ -1251,7 +1244,7 @@ func (p *ConnectionPool) updateConnection(e *endpoint) { p.poolsMutex.Unlock() return } else { - p.deleteConnection(e.addr) + p.deleteConnection(e.id) p.poolsMutex.Unlock() e.conn.Close() @@ -1275,14 +1268,15 @@ func (p *ConnectionPool) tryConnect(ctx context.Context, e *endpoint) error { connOpts := p.connOpts connOpts.Notify = e.notify - conn, err := tarantool.Connect(ctx, e.addr, connOpts) + conn, err := tarantool.Connect(ctx, e.dialer, connOpts) if err == nil { role, err := p.getConnectionRole(conn) p.poolsMutex.Unlock() if err != nil { conn.Close() - log.Printf("tarantool: storing connection to %s failed: %s\n", e.addr, err) + log.Printf("tarantool: storing connection to %s failed: %s\n", + e.conn.Addr(), err) return err } @@ -1300,7 +1294,7 @@ func (p *ConnectionPool) tryConnect(ctx context.Context, e *endpoint) error { return ErrClosed } - if err = p.addConnection(e.addr, conn, role); err != nil { + if err = p.addConnection(e.id, conn, role); err != nil { p.poolsMutex.Unlock() conn.Close() p.handlerDeactivated(conn, role) @@ -1322,7 +1316,7 @@ func (p *ConnectionPool) reconnect(ctx context.Context, e *endpoint) { return } - p.deleteConnection(e.addr) + p.deleteConnection(e.id) p.poolsMutex.Unlock() p.handlerDeactivated(e.conn, e.role) @@ -1358,7 +1352,7 @@ func (p *ConnectionPool) controller(ctx context.Context, e *endpoint) { case <-e.close: if e.conn != nil { p.poolsMutex.Lock() - p.deleteConnection(e.addr) + p.deleteConnection(e.id) p.poolsMutex.Unlock() if !shutdown { @@ -1380,7 +1374,7 @@ func (p *ConnectionPool) controller(ctx context.Context, e *endpoint) { shutdown = true if e.conn != nil { p.poolsMutex.Lock() - p.deleteConnection(e.addr) + p.deleteConnection(e.id) p.poolsMutex.Unlock() // We need to catch s.close in the current goroutine, so @@ -1402,7 +1396,7 @@ func (p *ConnectionPool) controller(ctx context.Context, e *endpoint) { if e.conn != nil && e.conn.ClosedNow() { p.poolsMutex.Lock() if p.state.get() == connectedState { - p.deleteConnection(e.addr) + p.deleteConnection(e.id) p.poolsMutex.Unlock() p.handlerDeactivated(e.conn, e.role) e.conn = nil diff --git a/pool/round_robin.go b/pool/round_robin.go index c3400d371..68cd4e3f4 100644 --- a/pool/round_robin.go +++ b/pool/round_robin.go @@ -8,27 +8,27 @@ import ( ) type roundRobinStrategy struct { - conns []*tarantool.Connection - indexByAddr map[string]uint - mutex sync.RWMutex - size uint64 - current uint64 + conns []*tarantool.Connection + indexById map[string]uint + mutex sync.RWMutex + size uint64 + current uint64 } func newRoundRobinStrategy(size int) *roundRobinStrategy { return &roundRobinStrategy{ - conns: make([]*tarantool.Connection, 0, size), - indexByAddr: make(map[string]uint), - size: 0, - current: 0, + conns: make([]*tarantool.Connection, 0, size), + indexById: make(map[string]uint), + size: 0, + current: 0, } } -func (r *roundRobinStrategy) GetConnByAddr(addr string) *tarantool.Connection { +func (r *roundRobinStrategy) GetConnById(id string) *tarantool.Connection { r.mutex.RLock() defer r.mutex.RUnlock() - index, found := r.indexByAddr[addr] + index, found := r.indexById[id] if !found { return nil } @@ -36,7 +36,7 @@ func (r *roundRobinStrategy) GetConnByAddr(addr string) *tarantool.Connection { return r.conns[index] } -func (r *roundRobinStrategy) DeleteConnByAddr(addr string) *tarantool.Connection { +func (r *roundRobinStrategy) DeleteConnById(id string) *tarantool.Connection { r.mutex.Lock() defer r.mutex.Unlock() @@ -44,20 +44,20 @@ func (r *roundRobinStrategy) DeleteConnByAddr(addr string) *tarantool.Connection return nil } - index, found := r.indexByAddr[addr] + index, found := r.indexById[id] if !found { return nil } - delete(r.indexByAddr, addr) + delete(r.indexById, id) conn := r.conns[index] r.conns = append(r.conns[:index], r.conns[index+1:]...) r.size -= 1 - for k, v := range r.indexByAddr { + for k, v := range r.indexById { if v > index { - r.indexByAddr[k] = v - 1 + r.indexById[k] = v - 1 } } @@ -91,15 +91,15 @@ func (r *roundRobinStrategy) GetConnections() []*tarantool.Connection { return ret } -func (r *roundRobinStrategy) AddConn(addr string, conn *tarantool.Connection) { +func (r *roundRobinStrategy) AddConn(id string, conn *tarantool.Connection) { r.mutex.Lock() defer r.mutex.Unlock() - if idx, ok := r.indexByAddr[addr]; ok { + if idx, ok := r.indexById[id]; ok { r.conns[idx] = conn } else { r.conns = append(r.conns, conn) - r.indexByAddr[addr] = uint(r.size) + r.indexById[id] = uint(r.size) r.size += 1 } } diff --git a/ssl.go b/ssl.go index 8ca430559..67fd2e22c 100644 --- a/ssl.go +++ b/ssl.go @@ -15,8 +15,37 @@ import ( "github.com/tarantool/go-openssl" ) +type sslOpts struct { + // keyFile is a path to a private SSL key file. + keyFile string + // certFile is a path to an SSL certificate file. + certFile string + // caFile is a path to a trusted certificate authorities (CA) file. + caFile string + // ciphers is a colon-separated (:) list of SSL cipher suites the connection + // can use. + // + // We don't provide a list of supported ciphers. This is what OpenSSL + // does. The only limitation is usage of TLSv1.2 (because other protocol + // versions don't seem to support the GOST cipher). To add additional + // ciphers (GOST cipher), you must configure OpenSSL. + // + // See also + // + // * https://www.openssl.org/docs/man1.1.1/man1/ciphers.html + ciphers string + // password is a password for decrypting the private SSL key file. + // The priority is as follows: try to decrypt with Password, then + // try PasswordFile. + password string + // passwordFile is a path to the list of passwords for decrypting + // the private SSL key file. The connection tries every line from the + // file as a password. + passwordFile string +} + func sslDialContext(ctx context.Context, network, address string, - opts SslOpts) (connection net.Conn, err error) { + opts sslOpts) (connection net.Conn, err error) { var sslCtx interface{} if sslCtx, err = sslCreateContext(opts); err != nil { return @@ -27,7 +56,7 @@ func sslDialContext(ctx context.Context, network, address string, // interface{} is a hack. It helps to avoid dependency of go-openssl in build // of tests with the tag 'go_tarantool_ssl_disable'. -func sslCreateContext(opts SslOpts) (ctx interface{}, err error) { +func sslCreateContext(opts sslOpts) (ctx interface{}, err error) { var sslCtx *openssl.Ctx // Require TLSv1.2, because other protocol versions don't seem to @@ -39,28 +68,28 @@ func sslCreateContext(opts SslOpts) (ctx interface{}, err error) { sslCtx.SetMaxProtoVersion(openssl.TLS1_2_VERSION) sslCtx.SetMinProtoVersion(openssl.TLS1_2_VERSION) - if opts.CertFile != "" { - if err = sslLoadCert(sslCtx, opts.CertFile); err != nil { + if opts.certFile != "" { + if err = sslLoadCert(sslCtx, opts.certFile); err != nil { return } } - if opts.KeyFile != "" { - if err = sslLoadKey(sslCtx, opts.KeyFile, opts.Password, opts.PasswordFile); err != nil { + if opts.keyFile != "" { + if err = sslLoadKey(sslCtx, opts.keyFile, opts.password, opts.passwordFile); err != nil { return } } - if opts.CaFile != "" { - if err = sslCtx.LoadVerifyLocations(opts.CaFile, ""); err != nil { + if opts.caFile != "" { + if err = sslCtx.LoadVerifyLocations(opts.caFile, ""); err != nil { return } verifyFlags := openssl.VerifyPeer | openssl.VerifyFailIfNoPeerCert sslCtx.SetVerify(verifyFlags, nil) } - if opts.Ciphers != "" { - sslCtx.SetCipherList(opts.Ciphers) + if opts.ciphers != "" { + sslCtx.SetCipherList(opts.ciphers) } return