diff --git a/neo4j/error.go b/neo4j/error.go index cfc6bb0f..e0eeb620 100644 --- a/neo4j/error.go +++ b/neo4j/error.go @@ -20,6 +20,8 @@ package neo4j import ( + "context" + "errors" "fmt" "io" "net" @@ -114,7 +116,9 @@ func wrapError(err error) error { if err == nil { return nil } - if err == io.EOF { + if err == io.EOF || + errors.Is(err, context.DeadlineExceeded) || + errors.Is(err, context.Canceled) { return &ConnectivityError{inner: err} } switch e := err.(type) { diff --git a/neo4j/internal/bolt/dechunker.go b/neo4j/internal/bolt/dechunker.go index 1a370b35..5fee7b76 100644 --- a/neo4j/internal/bolt/dechunker.go +++ b/neo4j/internal/bolt/dechunker.go @@ -20,9 +20,10 @@ package bolt import ( + "context" "encoding/binary" + rio "github.com/neo4j/neo4j-go-driver/v4/neo4j/internal/racingio" "github.com/neo4j/neo4j-go-driver/v4/neo4j/log" - "io" "net" "time" ) @@ -30,26 +31,37 @@ import ( // dechunkMessage takes a buffer to be reused and returns the reusable buffer // (might have been reallocated to handle growth), the message buffer and // error. -// If a non-default connection read timeout configuration hint is passed, the dechunker resets the connection read -// deadline as well after successfully reading a chunk (NOOP messages included) -func dechunkMessage(conn net.Conn, msgBuf []byte, readTimeout time.Duration, - logger log.Logger, logName, logId string) ([]byte, []byte, error) { +// Reads will race against the provided context ctx +// If the server provides the connection read timeout hint readTimeout, a new context will be created from that timeout +// and the user-provided context ctx before every read +func dechunkMessage( + conn net.Conn, + msgBuf []byte, + readTimeout time.Duration, + logger log.Logger, + logName string, + logId string) ([]byte, []byte, error) { + sizeBuf := []byte{0x00, 0x00} off := 0 + reader := rio.NewRacingReader(conn) + for { - _, err := io.ReadFull(conn, sizeBuf) + updatedCtx, cancelFunc := newContext(readTimeout, logger, logName, logId) + _, err := reader.ReadFull(updatedCtx, sizeBuf) if err != nil { return msgBuf, nil, err } + if cancelFunc != nil { // reading has been completed, time to release the context + cancelFunc() + } chunkSize := int(binary.BigEndian.Uint16(sizeBuf)) if chunkSize == 0 { if off > 0 { return msgBuf, msgBuf[:off], nil } // Got a nop chunk - resetConnectionReadDeadline(conn, readTimeout, logger, - logName, logId) continue } @@ -60,20 +72,42 @@ func dechunkMessage(conn net.Conn, msgBuf []byte, readTimeout time.Duration, msgBuf = newMsgBuf } // Read the chunk into buffer - _, err = io.ReadFull(conn, msgBuf[off:(off+chunkSize)]) + updatedCtx, cancelFunc = newContext(readTimeout, logger, logName, logId) + _, err = reader.ReadFull(updatedCtx, msgBuf[off:(off+chunkSize)]) if err != nil { return msgBuf, nil, err } + if cancelFunc != nil { // reading has been completed, time to release the context + cancelFunc() + } off += chunkSize - resetConnectionReadDeadline(conn, readTimeout, logger, logName, logId) } } -func resetConnectionReadDeadline(conn net.Conn, readTimeout time.Duration, logger log.Logger, logName, logId string) { - if readTimeout < 0 { - return +// newContext computes a new context and cancel function if a readTimeout is set +func newContext( + readTimeout time.Duration, + logger log.Logger, + logName string, + logId string) (context.Context, context.CancelFunc) { + + ctx := context.Background() + if readTimeout >= 0 { + newCtx, cancelFunc := context.WithTimeout(ctx, readTimeout) + logger.Debugf(logName, logId, + "read timeout of %s applied, chunk read deadline is now: %s", + readTimeout.String(), + deadlineOf(newCtx), + ) + return newCtx, cancelFunc } - if err := conn.SetReadDeadline(time.Now().Add(readTimeout)); err != nil { - logger.Error(logName, logId, err) + return ctx, nil +} + +func deadlineOf(ctx context.Context) string { + deadline, hasDeadline := ctx.Deadline() + if !hasDeadline { + return "N/A (no deadline set)" } + return deadline.String() } diff --git a/neo4j/internal/bolt/dechunker_test.go b/neo4j/internal/bolt/dechunker_test.go index 33bf1428..d76c9e09 100644 --- a/neo4j/internal/bolt/dechunker_test.go +++ b/neo4j/internal/bolt/dechunker_test.go @@ -22,6 +22,7 @@ package bolt import ( "bytes" "encoding/binary" + "github.com/neo4j/neo4j-go-driver/v4/neo4j/log" "net" "reflect" "testing" @@ -108,17 +109,10 @@ func TestDechunker(t *testing.T) { func TestDechunkerWithTimeout(ot *testing.T) { timeout := time.Millisecond * 600 - serv, cli := net.Pipe() - defer func() { - AssertNoError(ot, serv.Close()) - AssertNoError(ot, cli.Close()) - }() - AssertNoError(ot, serv.SetReadDeadline(time.Now().Add(timeout))) - logger := &noopLogger{} - logName := "dechunker" - logId := "dechunker-test" ot.Run("Resets connection deadline upon successful reads", func(t *testing.T) { + serv, cli := net.Pipe() + defer closePipe(ot, serv, cli) go func() { time.Sleep(timeout / 2) AssertWriteSucceeds(t, cli, []byte{0x00, 0x00}) @@ -128,32 +122,24 @@ func TestDechunkerWithTimeout(ot *testing.T) { AssertWriteSucceeds(t, cli, []byte{0x00, 0x00}) }() buffer := make([]byte, 2) - _, _, err := dechunkMessage(serv, buffer, timeout, logger, logName, - logId) + _, _, err := dechunkMessage(serv, buffer, timeout, log.Void{}, "", "") AssertNoError(t, err) AssertTrue(t, reflect.DeepEqual(buffer, []byte{0xCA, 0xFE})) }) ot.Run("Fails when connection deadline is reached", func(t *testing.T) { - _, _, err := dechunkMessage(serv, nil, timeout, logger, logName, - logId) - AssertError(t, err) - AssertStringContain(t, err.Error(), "read pipe") - }) - -} - -type noopLogger struct { -} + serv, cli := net.Pipe() + defer closePipe(ot, serv, cli) -func (*noopLogger) Error(string, string, error) { -} + _, _, err := dechunkMessage(serv, nil, timeout, log.Void{}, "", "") -func (*noopLogger) Warnf(string, string, string, ...interface{}) { -} + AssertError(t, err) + AssertStringContain(t, err.Error(), "context deadline exceeded") + }) -func (*noopLogger) Infof(string, string, string, ...interface{}) { } -func (*noopLogger) Debugf(string, string, string, ...interface{}) { +func closePipe(t *testing.T, srv, cli net.Conn) { + AssertNoError(t, srv.Close()) + AssertNoError(t, cli.Close()) } diff --git a/neo4j/internal/racingio/reader.go b/neo4j/internal/racingio/reader.go new file mode 100644 index 00000000..8a04c5b8 --- /dev/null +++ b/neo4j/internal/racingio/reader.go @@ -0,0 +1,99 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [http://neo4j.com] + * + * This file is part of Neo4j. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package racingio + +import ( + "context" + "errors" + "fmt" + "io" +) + +type RacingReader interface { + Read(ctx context.Context, bytes []byte) (int, error) + ReadFull(ctx context.Context, bytes []byte) (int, error) +} + +func NewRacingReader(reader io.Reader) RacingReader { + return &racingReader{reader: reader} +} + +type racingReader struct { + reader io.Reader +} + +func (rr *racingReader) Read(ctx context.Context, bytes []byte) (int, error) { + return rr.race(ctx, bytes, read) +} + +func (rr *racingReader) ReadFull(ctx context.Context, bytes []byte) (int, error) { + return rr.race(ctx, bytes, readFull) +} + +func (rr *racingReader) race(ctx context.Context, bytes []byte, readFn func(io.Reader, []byte) (int, error)) (int, error) { + if err := ctx.Err(); err != nil { + return 0, wrapRaceError(err) + } + resultChan := make(chan *ioResult, 1) + defer close(resultChan) + go func() { + n, err := readFn(rr.reader, bytes) + defer func() { + // When the read operation completes, the outer function may have returned already. + // In that situation, the channel will have been closed and the result emission will crash. + // Let's just swallow the panic that may happen and ignore it + _ = recover() + }() + resultChan <- &ioResult{ + n: n, + err: err, + } + }() + select { + case <-ctx.Done(): + return 0, wrapRaceError(ctx.Err()) + case result := <-resultChan: + return result.n, wrapRaceError(result.err) + } +} + +func wrapRaceError(err error) error { + if err == nil { + return nil + } + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + // temporary adjustment for 4.x + return fmt.Errorf("i/o timeout: %w", err) + } + return err +} + +type ioResult struct { + n int + err error +} + +func read(reader io.Reader, bytes []byte) (int, error) { + return reader.Read(bytes) +} + +func readFull(reader io.Reader, bytes []byte) (int, error) { + return io.ReadFull(reader, bytes) +} diff --git a/neo4j/internal/racingio/reader_test.go b/neo4j/internal/racingio/reader_test.go new file mode 100644 index 00000000..81af9185 --- /dev/null +++ b/neo4j/internal/racingio/reader_test.go @@ -0,0 +1,153 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [http://neo4j.com] + * + * This file is part of Neo4j. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package racingio_test + +import ( + "bytes" + "context" + "errors" + "fmt" + rio "github.com/neo4j/neo4j-go-driver/v4/neo4j/internal/racingio" + . "github.com/neo4j/neo4j-go-driver/v4/neo4j/internal/testutil" + "net" + "reflect" + "testing" + "time" +) + +func TestRacingReader(outer *testing.T) { + + type readFn func(context.Context, []byte) (int, error) + + type testCase struct { + qualifier string + readOperation func(rio.RacingReader) readFn + } + + testCases := []testCase{ + {qualifier: "Read", readOperation: func(reader rio.RacingReader) readFn { return reader.Read }}, + {qualifier: "ReadFull", readOperation: func(reader rio.RacingReader) readFn { return reader.ReadFull }}, + } + + for _, testCase := range testCases { + outer.Run(fmt.Sprintf(`[%s] reads fine with non-cancelling contexts`, testCase.qualifier), func(t *testing.T) { + source := []byte{1, 2, 3} + racingReader := rio.NewRacingReader(bytes.NewBuffer(source)) + result := make([]byte, 2) + + n, err := testCase.readOperation(racingReader)(context.Background(), result) + + if err != nil { + t.Errorf("expected nil error, got %v", err) + } + expectedN := len(result) + if n != expectedN { + t.Errorf("expected %d bytes to be written, got %d", expectedN, n) + } + if !reflect.DeepEqual(result, source[:expectedN]) { + t.Errorf("expected %v bytes, got %v", source, result) + } + }) + + outer.Run(fmt.Sprintf(`[%s] fails reading when context is already canceled`, testCase.qualifier), func(t *testing.T) { + reader := &bytes.Buffer{} + racingReader := rio.NewRacingReader(reader) + result := make([]byte, 2) + + n, err := testCase.readOperation(racingReader)(canceledContext(), result) + + if !errors.Is(err, context.Canceled) { + t.Errorf("expected cancelation error, got %v", err) + } + if n > 0 { + t.Errorf("expected no bytes to be written, got %d", n) + } + if len(reader.Bytes()) > 0 { + t.Errorf("expected empty slice, got %v", reader.Bytes()) + } + }) + + outer.Run(fmt.Sprintf(`[%s] completes before read occurs`, testCase.qualifier), func(t *testing.T) { + reader := &slowFailingReader{sleep: 2 * time.Minute} + racingReader := rio.NewRacingReader(reader) + result := make([]byte, 2) + ctx, cancelFunc := context.WithTimeout(context.Background(), 200*time.Millisecond) + defer cancelFunc() + + n, err := testCase.readOperation(racingReader)(ctx, result) + + if n > 0 { + t.Errorf("expected 0 written bytes, got %d", n) + } + if !errors.Is(err, context.DeadlineExceeded) { + t.Errorf("expected deadline exceeded error, got %v", err) + } + }) + + outer.Run("connection read times out after 1 successful read", func(t *testing.T) { + timeout := 400 * time.Millisecond + server, client := net.Pipe() + defer closePipe(t, server, client) + go func() { + time.Sleep(timeout / 2) + AssertWriteSucceeds(t, server, []byte{0xca, 0xfe}) + time.Sleep(timeout * 2) + _, _ = server.Write([]byte{0xba, 0xba}) + }() + reader := rio.NewRacingReader(client) + ctx, cancelFunc := context.WithTimeout(context.Background(), timeout) + defer cancelFunc() + + response1 := make([]byte, 2) + n1, err1 := reader.Read(ctx, response1) + response2 := make([]byte, 2) + n2, err2 := reader.Read(ctx, response2) + + AssertIntEqual(t, n1, 2) + AssertNoError(t, err1) + AssertIntEqual(t, n2, 0) + if !errors.Is(err2, context.DeadlineExceeded) { + t.Fatalf("expected underlying connection's read to time out") + } + }) + + } + +} + +func canceledContext() context.Context { + ctx, cancelFunc := context.WithCancel(context.Background()) + cancelFunc() + return ctx +} + +type slowFailingReader struct { + sleep time.Duration +} + +func (hw *slowFailingReader) Read([]byte) (int, error) { + time.Sleep(hw.sleep) + return 0, fmt.Errorf("not gonna read") +} + +func closePipe(t *testing.T, srv, cli net.Conn) { + AssertNoError(t, srv.Close()) + AssertNoError(t, cli.Close()) +}