diff --git a/.gitignore b/.gitignore index fcd3c3236..c9f687eb1 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,4 @@ work_dir* .rocks bench* +testdata/sidecar/main diff --git a/dial.go b/dial.go index 6b75adafa..eae8e1283 100644 --- a/dial.go +++ b/dial.go @@ -8,6 +8,7 @@ import ( "fmt" "io" "net" + "os" "strings" "time" @@ -252,6 +253,61 @@ func (d OpenSslDialer) Dial(ctx context.Context, opts DialOpts) (Conn, error) { return conn, nil } +// FdDialer allows to use an existing socket fd for connection. +type FdDialer struct { + // Fd is a socket file descrpitor. + Fd uintptr + // RequiredProtocol contains minimal protocol version and + // list of protocol features that should be supported by + // Tarantool server. By default, there are no restrictions. + RequiredProtocolInfo ProtocolInfo +} + +type fdAddr struct { + Fd uintptr +} + +func (a fdAddr) Network() string { + return "fd" +} + +func (a fdAddr) String() string { + return fmt.Sprintf("fd://%d", a.Fd) +} + +type fdConn struct { + net.Conn + Addr fdAddr +} + +func (c *fdConn) RemoteAddr() net.Addr { + return c.Addr +} + +// Dial makes FdDialer satisfy the Dialer interface. +func (d FdDialer) Dial(ctx context.Context, opts DialOpts) (Conn, error) { + file := os.NewFile(d.Fd, "") + c, err := net.FileConn(file) + if err != nil { + return nil, fmt.Errorf("failed to dial: %w", err) + } + + conn := new(tntConn) + conn.net = &fdConn{Conn: c, Addr: fdAddr{Fd: d.Fd}} + + dc := &deadlineIO{to: opts.IoTimeout, c: conn.net} + conn.reader = bufio.NewReaderSize(dc, bufSize) + conn.writer = bufio.NewWriterSize(dc, bufSize) + + _, err = rawDial(conn, d.RequiredProtocolInfo) + if err != nil { + conn.net.Close() + return nil, err + } + + return conn, nil +} + // Addr makes tntConn satisfy the Conn interface. func (c *tntConn) Addr() net.Addr { return c.net.RemoteAddr() diff --git a/dial_test.go b/dial_test.go index 750c100c9..a1b7f2317 100644 --- a/dial_test.go +++ b/dial_test.go @@ -442,6 +442,7 @@ type testDialOpts struct { isIdUnsupported bool isPapSha256Auth bool isErrAuth bool + isEmptyAuth bool } type dialServerActual struct { @@ -483,6 +484,8 @@ func testDialAccept(t *testing.T, opts testDialOpts, l net.Listener) chan dialSe authRequestExpected := authRequestExpectedChapSha1 if opts.isPapSha256Auth { authRequestExpected = authRequestExpectedPapSha256 + } else if opts.isEmptyAuth { + authRequestExpected = []byte{} } authRequestActual := make([]byte, len(authRequestExpected)) client.Read(authRequestActual) @@ -525,6 +528,8 @@ func testDialer(t *testing.T, l net.Listener, dialer tarantool.Dialer, authRequestExpected := authRequestExpectedChapSha1 if opts.isPapSha256Auth { authRequestExpected = authRequestExpectedPapSha256 + } else if opts.isEmptyAuth { + authRequestExpected = []byte{} } require.Equal(t, authRequestExpected, actual.AuthRequest) conn.Close() @@ -769,3 +774,48 @@ func TestOpenSslDialer_Dial_ctx_cancel(t *testing.T) { _, err := dialer.Dial(ctx, tarantool.DialOpts{}) require.Error(t, err) } + +func TestFdDialer_Dial(t *testing.T) { + l, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + addr := l.Addr().String() + + cases := []testDialOpts{ + { + name: "all is ok", + expectedProtocolInfo: idResponseTyped.Clone(), + isEmptyAuth: true, + }, + { + name: "id request unsupported", + expectedProtocolInfo: tarantool.ProtocolInfo{}, + isIdUnsupported: true, + isEmptyAuth: true, + }, + { + name: "greeting response error", + wantErr: true, + expectedErr: "failed to read greeting", + isErrGreeting: true, + }, + { + name: "id response error", + wantErr: true, + expectedErr: "failed to identify", + isErrId: true, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + sock, err := net.Dial("tcp", addr) + require.NoError(t, err) + f, err := sock.(*net.TCPConn).File() + require.NoError(t, err) + dialer := tarantool.FdDialer{ + Fd: f.Fd(), + } + testDialer(t, l, dialer, tc) + }) + } +} diff --git a/example_test.go b/example_test.go index d4da11853..7952b0a11 100644 --- a/example_test.go +++ b/example_test.go @@ -3,6 +3,7 @@ package tarantool_test import ( "context" "fmt" + "net" "time" "github.com/tarantool/go-iproto" @@ -1330,3 +1331,34 @@ func ExampleWatchOnceRequest() { fmt.Println(resp.Data) } } + +// This example demonstrates how to use an existing socket file descriptor +// to establish a connection with Tarantool. This can be useful if the socket fd +// was inherited from the Tarantool process itself. +// For details, please see TestFdDialer in tarantool_test.go. +func ExampleFdDialer() { + addr := dialer.Address + c, err := net.Dial("tcp", addr) + if err != nil { + fmt.Printf("can't establish connection: %v\n", err) + return + } + f, err := c.(*net.TCPConn).File() + if err != nil { + fmt.Printf("unexpected error: %v\n", err) + return + } + dialer := tarantool.FdDialer{ + Fd: f.Fd(), + } + // Use an existing socket fd to create connection with Tarantool. + conn, err := tarantool.Connect(context.Background(), dialer, opts) + if err != nil { + fmt.Printf("connect error: %v\n", err) + return + } + resp, err := conn.Do(tarantool.NewPingRequest()).Get() + fmt.Println(resp.Code, err) + // Output: + // 0 +} diff --git a/tarantool_test.go b/tarantool_test.go index 5d1d76aa6..d3976fde1 100644 --- a/tarantool_test.go +++ b/tarantool_test.go @@ -8,6 +8,8 @@ import ( "log" "math" "os" + "os/exec" + "path/filepath" "reflect" "runtime" "strings" @@ -77,6 +79,7 @@ func (m *Member) DecodeMsgpack(d *msgpack.Decoder) error { } var server = "127.0.0.1:3013" +var fdDialerTestServer = "127.0.0.1:3014" var spaceNo = uint32(617) var spaceName = "test" var indexNo = uint32(0) @@ -3927,6 +3930,87 @@ func TestConnect_context_cancel(t *testing.T) { } } +func buildSidecar(dir string) error { + goPath, err := exec.LookPath("go") + if err != nil { + return err + } + cmd := exec.Command(goPath, "build", "main.go") + cmd.Dir = filepath.Join(dir, "testdata", "sidecar") + return cmd.Run() +} + +func TestFdDialer(t *testing.T) { + isLess, err := test_helpers.IsTarantoolVersionLess(3, 0, 0) + if err != nil || isLess { + t.Skip("box.session.new present in Tarantool since version 3.0") + } + + wd, err := os.Getwd() + require.NoError(t, err) + + err = buildSidecar(wd) + require.NoErrorf(t, err, "failed to build sidecar: %v", err) + + instOpts := startOpts + instOpts.Listen = fdDialerTestServer + instOpts.Dialer = NetDialer{ + Address: fdDialerTestServer, + User: "test", + Password: "test", + } + + inst, err := test_helpers.StartTarantool(instOpts) + require.NoError(t, err) + defer test_helpers.StopTarantoolWithCleanup(inst) + + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + sidecarExe := filepath.Join(wd, "testdata", "sidecar", "main") + + evalBody := fmt.Sprintf(` + local socket = require('socket') + local popen = require('popen') + local os = require('os') + local s1, s2 = socket.socketpair('AF_UNIX', 'SOCK_STREAM', 0) + + --[[ Tell sidecar which fd use to connect. --]] + os.setenv('SOCKET_FD', tostring(s2:fd())) + + box.session.new({ + type = 'binary', + fd = s1:fd(), + user = 'test', + }) + s1:detach() + + local ph, err = popen.new({'%s'}, { + stdout = popen.opts.PIPE, + stderr = popen.opts.PIPE, + inherit_fds = {s2:fd()}, + }) + + if err ~= nil then + return 1, err + end + + ph:wait() + + local status_code = ph:info().status.exit_code + local stderr = ph:read({stderr=true}):rstrip() + local stdout = ph:read({stdout=true}):rstrip() + return status_code, stderr, stdout + `, sidecarExe) + + var resp []interface{} + err = conn.EvalTyped(evalBody, []interface{}{}, &resp) + require.NoError(t, err) + require.Equal(t, "", resp[1], resp[1]) + require.Equal(t, "", resp[2], resp[2]) + require.Equal(t, int8(0), resp[0]) +} + // runTestMain is a body of TestMain function // (see https://pkg.go.dev/testing#hdr-Main). // Using defer + os.Exit is not works so TestMain body diff --git a/testdata/sidecar/main.go b/testdata/sidecar/main.go new file mode 100644 index 000000000..971b8694c --- /dev/null +++ b/testdata/sidecar/main.go @@ -0,0 +1,37 @@ +package main + +import ( + "context" + "os" + "strconv" + + "github.com/tarantool/go-tarantool/v2" +) + +func main() { + fd, err := strconv.Atoi(os.Getenv("SOCKET_FD")) + if err != nil { + panic(err) + } + dialer := tarantool.FdDialer{ + Fd: uintptr(fd), + } + conn, err := tarantool.Connect(context.Background(), dialer, tarantool.Opts{}) + if err != nil { + panic(err) + } + if _, err := conn.Do(tarantool.NewPingRequest()).Get(); err != nil { + panic(err) + } + // Insert new tuple. + if _, err := conn.Do(tarantool.NewInsertRequest("test"). + Tuple([]interface{}{239})).Get(); err != nil { + panic(err) + } + // Delete inserted tuple. + if _, err := conn.Do(tarantool.NewDeleteRequest("test"). + Index("primary"). + Key([]interface{}{239})).Get(); err != nil { + panic(err) + } +}