From 1a49b281a97e216b45a34645d114dc2764477bea Mon Sep 17 00:00:00 2001 From: DerekBum Date: Thu, 18 Jan 2024 15:51:57 +0300 Subject: [PATCH] api: create `AuthDialer` and `ProtocolDialer` To disable SSL by default we want to transfer `OpenSslDialer` to the go-openssl repository. In order to do so, we need to minimize the amount of copy-paste of the private functions. `AuthDialer` is created as a dialer-wrapper, that calls authentication methods. `ProtoDialer` is created to check the `ProtocolInfo` in the created connection. Part of #301 --- CHANGELOG.md | 2 + connection.go | 4 +- dial.go | 308 ++++++++++++++++++++++++++++++++-------------- dial_test.go | 206 ++++++++++++++++++++++++++++++- ssl_test.go | 8 +- tarantool_test.go | 14 +++ 6 files changed, 441 insertions(+), 101 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c1b1fc170..cd16f0289 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -37,6 +37,8 @@ Versioning](http://semver.org/spec/v2.0.0.html) except to the first release. the response (#237) - Ability to mock connections for tests (#237). Added new types `MockDoer`, `MockRequest` to `test_helpers`. +- `AuthDialer` and `ProtocolDialer` types for creating a dialer with + authentication and `ProtocolInfo` check (#301) ### Changed diff --git a/connection.go b/connection.go index 8f8631a31..24edebece 100644 --- a/connection.go +++ b/connection.go @@ -440,7 +440,9 @@ func (conn *Connection) dial(ctx context.Context) error { } conn.addr = c.Addr() - conn.Greeting.Version = c.Greeting().Version + connGreeting := c.Greeting() + conn.Greeting.Version = connGreeting.Version + conn.Greeting.Salt = connGreeting.Salt conn.serverProtocolInfo = c.ProtocolInfo() spaceAndIndexNamesSupported := diff --git a/dial.go b/dial.go index ff5419760..5fcbd3bc4 100644 --- a/dial.go +++ b/dial.go @@ -20,7 +20,10 @@ const bufSize = 128 * 1024 // Greeting is a message sent by Tarantool on connect. type Greeting struct { + // Version is the supported protocol version. Version string + // Salt is used to authenticate a user. + Salt string } // writeFlusher is the interface that groups the basic Write and Flush methods. @@ -71,30 +74,43 @@ type Dialer interface { } type tntConn struct { - net net.Conn - reader io.Reader - writer writeFlusher + net net.Conn + reader io.Reader + writer writeFlusher +} + +// protocolConn is a wrapper for connections, so they contain the ProtocolInfo. +type protocolConn struct { + Conn + protocolInfo ProtocolInfo +} + +// greetingConn is a wrapper for connections, so they contain the Greeting. +type greetingConn struct { + Conn greeting Greeting - protocol ProtocolInfo } -// rawDial does basic dial operations: -// reads greeting, identifies a protocol and validates it. -func rawDial(conn *tntConn, requiredProto ProtocolInfo) (string, error) { - version, salt, err := readGreeting(conn.reader) +type netDialer struct { + address string +} + +func (d netDialer) Dial(ctx context.Context, opts DialOpts) (Conn, error) { + var err error + conn := new(tntConn) + + network, address := parseAddress(d.address) + dialer := net.Dialer{} + conn.net, err = dialer.DialContext(ctx, network, address) if err != nil { - return "", fmt.Errorf("failed to read greeting: %w", err) + return nil, fmt.Errorf("failed to dial: %w", err) } - conn.greeting.Version = version - if conn.protocol, err = identify(conn.writer, conn.reader); err != nil { - return "", fmt.Errorf("failed to identify: %w", err) - } + dc := &deadlineIO{to: opts.IoTimeout, c: conn.net} + conn.reader = bufio.NewReaderSize(dc, bufSize) + conn.writer = bufio.NewWriterSize(dc, bufSize) - if err = checkProtocolInfo(requiredProto, conn.protocol); err != nil { - return "", fmt.Errorf("invalid server protocol: %w", err) - } - return salt, err + return conn, nil } // NetDialer is a basic Dialer implementation. @@ -121,12 +137,44 @@ type NetDialer struct { // Dial makes NetDialer satisfy the Dialer interface. func (d NetDialer) Dial(ctx context.Context, opts DialOpts) (Conn, error) { + dialer := AuthDialer{ + Dialer: ProtocolDialer{ + Dialer: netDialer{ + address: d.Address, + }, + RequiredProtocolInfo: d.RequiredProtocolInfo, + }, + Auth: ChapSha1Auth, + Username: d.User, + Password: d.Password, + } + + return dialer.Dial(ctx, opts) +} + +type openSslDialer struct { + address string + sslKeyFile string + sslCertFile string + sslCaFile string + sslCiphers string + sslPassword string + sslPasswordFile string +} + +func (d openSslDialer) Dial(ctx context.Context, opts DialOpts) (Conn, error) { var err error conn := new(tntConn) - network, address := parseAddress(d.Address) - dialer := net.Dialer{} - conn.net, err = dialer.DialContext(ctx, network, address) + network, address := parseAddress(d.address) + conn.net, err = sslDialContext(ctx, network, address, sslOpts{ + KeyFile: d.sslKeyFile, + CertFile: d.sslCertFile, + CaFile: d.sslCaFile, + Ciphers: d.sslCiphers, + Password: d.sslPassword, + PasswordFile: d.sslPasswordFile, + }) if err != nil { return nil, fmt.Errorf("failed to dial: %w", err) } @@ -135,22 +183,6 @@ func (d NetDialer) Dial(ctx context.Context, opts DialOpts) (Conn, error) { conn.reader = bufio.NewReaderSize(dc, bufSize) conn.writer = bufio.NewWriterSize(dc, bufSize) - salt, err := rawDial(conn, d.RequiredProtocolInfo) - if err != nil { - conn.net.Close() - return nil, err - } - - if d.User == "" { - return conn, nil - } - - conn.protocol.Auth = ChapSha1Auth - if err = authenticate(conn, ChapSha1Auth, d.User, d.Password, salt); err != nil { - conn.net.Close() - return nil, fmt.Errorf("failed to authenticate: %w", err) - } - return conn, nil } @@ -206,51 +238,25 @@ type OpenSslDialer struct { // Dial makes OpenSslDialer satisfy the Dialer interface. func (d OpenSslDialer) Dial(ctx context.Context, opts DialOpts) (Conn, error) { - var err error - conn := new(tntConn) - - network, address := parseAddress(d.Address) - conn.net, err = sslDialContext(ctx, network, address, sslOpts{ - KeyFile: d.SslKeyFile, - CertFile: d.SslCertFile, - CaFile: d.SslCaFile, - Ciphers: d.SslCiphers, - Password: d.SslPassword, - PasswordFile: d.SslPasswordFile, - }) - if err != nil { - return nil, fmt.Errorf("failed to dial: %w", err) - } - - dc := &deadlineIO{to: opts.IoTimeout, c: conn.net} - conn.reader = bufio.NewReaderSize(dc, bufSize) - conn.writer = bufio.NewWriterSize(dc, bufSize) - - salt, err := rawDial(conn, d.RequiredProtocolInfo) - if err != nil { - conn.net.Close() - return nil, err - } - - if d.User == "" { - return conn, nil - } - - if d.Auth == AutoAuth { - if conn.protocol.Auth != AutoAuth { - d.Auth = conn.protocol.Auth - } else { - d.Auth = ChapSha1Auth - } - } - conn.protocol.Auth = d.Auth - - if err = authenticate(conn, d.Auth, d.User, d.Password, salt); err != nil { - conn.net.Close() - return nil, fmt.Errorf("failed to authenticate: %w", err) - } - - return conn, nil + dialer := AuthDialer{ + Dialer: ProtocolDialer{ + Dialer: openSslDialer{ + address: d.Address, + sslKeyFile: d.SslKeyFile, + sslCertFile: d.SslCertFile, + sslCaFile: d.SslCaFile, + sslCiphers: d.SslCiphers, + sslPassword: d.SslPassword, + sslPasswordFile: d.SslPasswordFile, + }, + RequiredProtocolInfo: d.RequiredProtocolInfo, + }, + Auth: d.Auth, + Username: d.User, + Password: d.Password, + } + + return dialer.Dial(ctx, opts) } // FdDialer allows to use an existing socket fd for connection. @@ -263,6 +269,10 @@ type FdDialer struct { RequiredProtocolInfo ProtocolInfo } +type fdDialer struct { + fd uintptr +} + type fdAddr struct { Fd uintptr } @@ -284,30 +294,132 @@ func (c *fdConn) RemoteAddr() net.Addr { return c.Addr } -// Dial makes FdDialer satisfy the Dialer interface. -func (d FdDialer) Dial(ctx context.Context, opts DialOpts) (Conn, error) { - file := os.NewFile(d.Fd, "") +func (d fdDialer) Dial(ctx context.Context, opts DialOpts) (Conn, error) { + file := os.NewFile(d.fd, "") c, err := net.FileConn(file) if err != nil { return nil, fmt.Errorf("failed to dial: %w", err) } conn := new(tntConn) - conn.net = &fdConn{Conn: c, Addr: fdAddr{Fd: d.Fd}} + conn.net = &fdConn{Conn: c, Addr: fdAddr{Fd: d.fd}} dc := &deadlineIO{to: opts.IoTimeout, c: conn.net} conn.reader = bufio.NewReaderSize(dc, bufSize) conn.writer = bufio.NewWriterSize(dc, bufSize) - _, err = rawDial(conn, d.RequiredProtocolInfo) + return conn, nil +} + +// Dial makes FdDialer satisfy the Dialer interface. +func (d FdDialer) Dial(ctx context.Context, opts DialOpts) (Conn, error) { + dialer := ProtocolDialer{ + Dialer: fdDialer{ + fd: d.Fd, + }, + RequiredProtocolInfo: d.RequiredProtocolInfo, + } + + return dialer.Dial(ctx, opts) +} + +// AuthDialer is a dialer-wrapper that does authentication of a user. +type AuthDialer struct { + // Dialer is a base dialer. + Dialer Dialer + // Authentication options. + Auth Auth + // Username is a name of a user for authentication. + Username string + // Password is a user password for authentication. + Password string +} + +// Dial makes AuthDialer satisfy the Dialer interface. +func (d AuthDialer) Dial(ctx context.Context, opts DialOpts) (Conn, error) { + conn, err := d.Dialer.Dial(ctx, opts) if err != nil { - conn.net.Close() - return nil, err + return conn, err + } + greeting := conn.Greeting() + if greeting.Salt == "" { + conn.Close() + return nil, fmt.Errorf("failed to authenticate: " + + "an invalid connection without salt") + } + + if d.Username == "" { + return conn, nil } + protocolAuth := conn.ProtocolInfo().Auth + if d.Auth == AutoAuth { + if protocolAuth != AutoAuth { + d.Auth = protocolAuth + } else { + d.Auth = ChapSha1Auth + } + } + + if err := authenticate(conn, d.Auth, d.Username, d.Password, + conn.Greeting().Salt); err != nil { + conn.Close() + return nil, fmt.Errorf("failed to authenticate: %w", err) + } return conn, nil } +// ProtocolDialer is a dialer-wrapper that reads and fills the Greeting and +// ProtocolInfo of a connection. +type ProtocolDialer struct { + // Dialer is a base dialer. + Dialer Dialer + // RequiredProtocol contains minimal protocol version and + // list of protocol features that should be supported by + // Tarantool server. By default, there are no restrictions. + RequiredProtocolInfo ProtocolInfo +} + +// Dial makes ProtocolDialer satisfy the Dialer interface. +func (d ProtocolDialer) Dial(ctx context.Context, opts DialOpts) (Conn, error) { + conn, err := d.Dialer.Dial(ctx, opts) + if err != nil { + return conn, err + } + + greetingConn := greetingConn{ + Conn: conn, + } + version, salt, err := readGreeting(greetingConn) + if err != nil { + greetingConn.Close() + return nil, fmt.Errorf("failed to read greeting: %w", err) + } + greetingConn.greeting = Greeting{ + Version: version, + Salt: salt, + } + + protocolConn := protocolConn{ + Conn: &greetingConn, + protocolInfo: d.RequiredProtocolInfo, + } + + protocolConn.protocolInfo, err = identify(&protocolConn) + if err != nil { + protocolConn.Close() + return nil, fmt.Errorf("failed to identify: %w", err) + } + + err = checkProtocolInfo(d.RequiredProtocolInfo, protocolConn.protocolInfo) + if err != nil { + protocolConn.Close() + return nil, fmt.Errorf("invalid server protocol: %w", err) + } + + return &protocolConn, nil +} + // Addr makes tntConn satisfy the Conn interface. func (c *tntConn) Addr() net.Addr { return c.net.RemoteAddr() @@ -341,12 +453,22 @@ func (c *tntConn) Close() error { // Greeting makes tntConn satisfy the Conn interface. func (c *tntConn) Greeting() Greeting { - return c.greeting + return Greeting{} } // ProtocolInfo makes tntConn satisfy the Conn interface. func (c *tntConn) ProtocolInfo() ProtocolInfo { - return c.protocol + return ProtocolInfo{} +} + +// ProtocolInfo returns ProtocolInfo of a protocolConn. +func (c *protocolConn) ProtocolInfo() ProtocolInfo { + return c.protocolInfo +} + +// Greeting returns Greeting of a greetingConn. +func (c *greetingConn) Greeting() Greeting { + return c.greeting } // parseAddress split address into network and address parts. @@ -390,15 +512,15 @@ func readGreeting(reader io.Reader) (string, string, error) { // identify sends info about client protocol, receives info // about server protocol in response and stores it in the connection. -func identify(w writeFlusher, r io.Reader) (ProtocolInfo, error) { +func identify(conn Conn) (ProtocolInfo, error) { var info ProtocolInfo req := NewIdRequest(clientProtocolInfo) - if err := writeRequest(w, req); err != nil { + if err := writeRequest(conn, req); err != nil { return info, err } - resp, err := readResponse(r, req) + resp, err := readResponse(conn, req) if err != nil { if resp != nil && resp.Header().Error == iproto.ER_UNKNOWN_REQUEST_TYPE { diff --git a/dial_test.go b/dial_test.go index 8e7ec8727..4fa4d4704 100644 --- a/dial_test.go +++ b/dial_test.go @@ -3,6 +3,7 @@ package tarantool_test import ( "bytes" "context" + "encoding/base64" "errors" "fmt" "net" @@ -235,6 +236,7 @@ func TestConn_Addr(t *testing.T) { func TestConn_Greeting(t *testing.T) { greeting := tarantool.Greeting{ Version: "any", + Salt: "salt", } conn, dialer := dialIo(t, func(conn *mockIoConn) { conn.greeting = greeting @@ -522,6 +524,7 @@ func testDialer(t *testing.T, l net.Listener, dialer tarantool.Dialer, require.NoError(t, err) require.Equal(t, opts.expectedProtocolInfo, conn.ProtocolInfo()) require.Equal(t, testDialVersion[:], []byte(conn.Greeting().Version)) + require.Equal(t, testDialSalt[:44], []byte(conn.Greeting().Salt)) actual := <-ch require.Equal(t, idRequestExpected, actual.IdRequest) @@ -551,9 +554,8 @@ func TestNetDialer_Dial(t *testing.T) { expectedProtocolInfo: idResponseTyped.Clone(), }, { - name: "id request unsupported", - // Dialer sets auth. - expectedProtocolInfo: tarantool.ProtocolInfo{Auth: tarantool.ChapSha1Auth}, + name: "id request unsupported", + expectedProtocolInfo: tarantool.ProtocolInfo{}, isIdUnsupported: true, }, { @@ -678,3 +680,201 @@ func TestFdDialer_Dial_requirements(t *testing.T) { require.Error(t, err) require.Contains(t, err.Error(), "invalid server protocol") } + +func TestAuthDialer_Dial_DialerError(t *testing.T) { + dialer := tarantool.AuthDialer{ + Dialer: mockErrorDialer{ + err: fmt.Errorf("some error"), + }, + } + + ctx, cancel := test_helpers.GetConnectContext() + defer cancel() + conn, err := dialer.Dial(ctx, tarantool.DialOpts{}) + if conn != nil { + conn.Close() + } + + assert.NotNil(t, err) + assert.EqualError(t, err, "some error") +} + +func TestAuthDialer_Dial_NoSalt(t *testing.T) { + dialer := tarantool.AuthDialer{ + Dialer: &mockIoDialer{ + init: func(conn *mockIoConn) { + conn.greeting = tarantool.Greeting{ + Salt: "", + } + }, + }, + } + + ctx, cancel := test_helpers.GetConnectContext() + defer cancel() + conn, err := dialer.Dial(ctx, tarantool.DialOpts{}) + + assert.NotNil(t, err) + assert.ErrorContains(t, err, "an invalid connection without salt") + if conn != nil { + conn.Close() + t.Errorf("connection is not nil") + } +} + +func TestAuthDialer_Dial(t *testing.T) { + salt := fmt.Sprintf("%s", testDialSalt) + salt = base64.StdEncoding.EncodeToString([]byte(salt)) + dialer := mockIoDialer{ + init: func(conn *mockIoConn) { + conn.greeting.Salt = salt + conn.writeWgDelay = 1 + conn.readWgDelay = 2 + conn.readbuf.Write(okResponse) + }, + } + defer func() { + dialer.conn.writeWg.Done() + }() + + authDialer := tarantool.AuthDialer{ + Dialer: &dialer, + Username: "test", + Password: "test", + } + ctx, cancel := test_helpers.GetConnectContext() + defer cancel() + conn, err := authDialer.Dial(ctx, tarantool.DialOpts{}) + if conn != nil { + conn.Close() + } + + assert.Nil(t, err) + assert.NotNil(t, conn) + assert.Equal(t, authRequestExpectedChapSha1[:41], dialer.conn.writebuf.Bytes()[:41]) +} + +func TestProtocolDialer_Dial_DialerError(t *testing.T) { + dialer := tarantool.ProtocolDialer{ + Dialer: mockErrorDialer{ + err: fmt.Errorf("some error"), + }, + } + + ctx, cancel := test_helpers.GetConnectContext() + defer cancel() + conn, err := dialer.Dial(ctx, tarantool.DialOpts{}) + if conn != nil { + conn.Close() + } + + assert.NotNil(t, err) + assert.EqualError(t, err, "some error") +} + +func TestProtocolDialer_Dial_GreetingFailed(t *testing.T) { + dialer := tarantool.ProtocolDialer{ + Dialer: &mockIoDialer{ + init: func(conn *mockIoConn) { + conn.info = tarantool.ProtocolInfo{Version: 1} + conn.writeWgDelay = 1 + conn.readWgDelay = 2 + conn.readbuf.Write(errResponse) + }, + }, + } + + ctx, cancel := test_helpers.GetConnectContext() + defer cancel() + conn, err := dialer.Dial(ctx, tarantool.DialOpts{}) + + assert.NotNil(t, err) + assert.ErrorContains(t, err, "failed to read greeting") + if conn != nil { + conn.Close() + t.Errorf("connection is not nil") + } +} + +func TestProtocolDialer_Dial_IdentifyFailed(t *testing.T) { + dialer := tarantool.ProtocolDialer{ + Dialer: &mockIoDialer{ + init: func(conn *mockIoConn) { + conn.info = tarantool.ProtocolInfo{Version: 1} + conn.writeWgDelay = 1 + conn.readWgDelay = 3 + conn.readbuf.Write(append(testDialVersion[:], testDialSalt[:]...)) + conn.readbuf.Write(errResponse) + }, + }, + } + + ctx, cancel := test_helpers.GetConnectContext() + defer cancel() + conn, err := dialer.Dial(ctx, tarantool.DialOpts{}) + + assert.NotNil(t, err) + assert.ErrorContains(t, err, "failed to identify") + if conn != nil { + conn.Close() + t.Errorf("connection is not nil") + } +} + +func TestProtocolDialer_Dial_WrongInfo(t *testing.T) { + dialer := tarantool.ProtocolDialer{ + Dialer: &mockIoDialer{ + init: func(conn *mockIoConn) { + conn.info = tarantool.ProtocolInfo{Version: 1} + conn.writeWgDelay = 1 + conn.readWgDelay = 3 + conn.readbuf.Write(append(testDialVersion[:], testDialSalt[:]...)) + conn.readbuf.Write(idResponse) + }, + }, + RequiredProtocolInfo: validProtocolInfo, + } + + ctx, cancel := test_helpers.GetConnectContext() + defer cancel() + conn, err := dialer.Dial(ctx, tarantool.DialOpts{}) + + assert.NotNil(t, err) + assert.ErrorContains(t, err, "invalid server protocol") + if conn != nil { + conn.Close() + t.Errorf("connection is not nil") + } +} + +func TestProtocolDialer_Dial(t *testing.T) { + protoInfo := tarantool.ProtocolInfo{ + Auth: tarantool.ChapSha1Auth, + Version: 6, + Features: []iproto.Feature{0x01, 0x15}, + } + + dialer := tarantool.ProtocolDialer{ + Dialer: &mockIoDialer{ + init: func(conn *mockIoConn) { + conn.info = tarantool.ProtocolInfo{Version: 1} + conn.writeWgDelay = 1 + conn.readWgDelay = 3 + conn.readbuf.Write(append(testDialVersion[:], testDialSalt[:]...)) + conn.readbuf.Write(idResponse) + }, + }, + RequiredProtocolInfo: protoInfo, + } + + ctx, cancel := test_helpers.GetConnectContext() + defer cancel() + conn, err := dialer.Dial(ctx, tarantool.DialOpts{}) + if conn != nil { + conn.Close() + } + + assert.Nil(t, err) + assert.NotNil(t, conn) + assert.Equal(t, protoInfo, conn.ProtocolInfo()) +} diff --git a/ssl_test.go b/ssl_test.go index 44b26eb05..f161b98f0 100644 --- a/ssl_test.go +++ b/ssl_test.go @@ -654,9 +654,8 @@ func TestOpenSslDialer_Dial_basic(t *testing.T) { expectedProtocolInfo: idResponseTyped.Clone(), }, { - name: "id request unsupported", - // Dialer sets auth. - expectedProtocolInfo: ProtocolInfo{Auth: ChapSha1Auth}, + name: "id request unsupported", + expectedProtocolInfo: ProtocolInfo{}, isIdUnsupported: true, }, { @@ -730,8 +729,9 @@ func TestOpenSslDialer_Dial_papSha256Auth(t *testing.T) { Auth: PapSha256Auth, } + // Response from the server. protocol := idResponseTyped.Clone() - protocol.Auth = PapSha256Auth + protocol.Auth = ChapSha1Auth testDialer(t, l, dialer, testDialOpts{ expectedProtocolInfo: protocol, diff --git a/tarantool_test.go b/tarantool_test.go index c3f6b4c0b..3cdef2857 100644 --- a/tarantool_test.go +++ b/tarantool_test.go @@ -772,6 +772,20 @@ func TestNetDialer(t *testing.T) { assert.Equal([]byte{0x83, 0x00, 0xce, 0x00, 0x00, 0x00, 0x00}, buf[:7]) } +func TestNetDialer_BadUser(t *testing.T) { + badDialer := NetDialer{ + Address: server, + User: "Cpt Smollett", + Password: "none", + } + ctx, cancel := test_helpers.GetConnectContext() + defer cancel() + + conn, err := Connect(ctx, badDialer, opts) + assert.NotNil(t, err) + assert.Nil(t, conn) +} + func TestFutureMultipleGetGetTyped(t *testing.T) { conn := test_helpers.ConnectWithValidation(t, dialer, opts) defer conn.Close()