From c25cc31961ca099fcb4a2a24014b4a6fa84e1e83 Mon Sep 17 00:00:00 2001 From: Florent Biville <445792+fbiville@users.noreply.github.com> Date: Thu, 13 Apr 2023 17:10:35 +0200 Subject: [PATCH] Revisit retry logic (#465) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Restrict what's retried, invalidate server more often Remove unneeded Sleep in tests Simplify retry loop 💄 Co-authored-by: Rouven Bauer --- neo4j/error.go | 112 +----------- neo4j/error_test.go | 8 +- neo4j/internal/bolt/bolt3.go | 3 +- neo4j/internal/bolt/bolt4.go | 3 +- neo4j/internal/bolt/bolt5.go | 3 +- neo4j/internal/bolt/chunker.go | 13 +- neo4j/internal/bolt/connections.go | 8 +- neo4j/internal/bolt/dechunker.go | 15 +- neo4j/internal/connector/connector.go | 15 +- .../{bolt/errors.go => errorutil/bolt.go} | 30 +-- neo4j/internal/errorutil/errors.go | 71 ++++++++ neo4j/internal/errorutil/pool.go | 27 +++ neo4j/internal/errorutil/retry.go | 29 +++ neo4j/internal/errorutil/router.go | 17 ++ neo4j/internal/errorutil/tls.go | 13 ++ neo4j/internal/pool/errors.go | 48 ----- neo4j/internal/pool/pool.go | 14 +- neo4j/internal/pool/pool_test.go | 5 +- neo4j/internal/retry/state.go | 171 +++++++----------- neo4j/internal/retry/state_test.go | 85 +++++---- neo4j/internal/router/errors.go | 47 ----- neo4j/internal/router/readtable.go | 3 +- neo4j/internal/router/readtable_test.go | 5 +- neo4j/internal/router/router.go | 29 ++- neo4j/result_with_context.go | 13 +- neo4j/session_with_context.go | 65 +++---- neo4j/session_with_context_test.go | 8 +- neo4j/test-integration/transaction_test.go | 2 - neo4j/transaction_with_context.go | 9 +- 29 files changed, 403 insertions(+), 468 deletions(-) rename neo4j/internal/{bolt/errors.go => errorutil/bolt.go} (74%) create mode 100644 neo4j/internal/errorutil/pool.go create mode 100644 neo4j/internal/errorutil/retry.go create mode 100644 neo4j/internal/errorutil/router.go create mode 100644 neo4j/internal/errorutil/tls.go delete mode 100644 neo4j/internal/pool/errors.go delete mode 100644 neo4j/internal/router/errors.go diff --git a/neo4j/error.go b/neo4j/error.go index 0449f4c5..51d0fc96 100644 --- a/neo4j/error.go +++ b/neo4j/error.go @@ -21,17 +21,9 @@ package neo4j import ( "context" - "errors" - "fmt" "github.com/neo4j/neo4j-go-driver/v5/neo4j/db" - "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/bolt" - "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/connector" "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/errorutil" - "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/pool" "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/retry" - "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/router" - "io" - "net" ) // IsRetryable determines whether an operation can be retried based on the error @@ -42,12 +34,6 @@ func IsRetryable(err error) bool { if err == nil { return false } - var connectivityErr *ConnectivityError - var commitFailedError *retry.CommitFailedDeadError - if errors.As(err, &connectivityErr) && !errors.As(connectivityErr.inner, &commitFailedError) { - // all connectivity errors are safe to retry except during transaction commit - return true - } return retry.IsRetryable(err) } @@ -56,55 +42,11 @@ func IsRetryable(err error) bool { // used internally. type Neo4jError = db.Neo4jError -// UsageError represents errors caused by incorrect usage of the driver API. -// This does not include Cypher syntax (those errors will be Neo4jError). -type UsageError struct { - Message string -} - -func (e *UsageError) Error() string { - return e.Message -} - -// TransactionExecutionLimit error indicates that a retryable transaction has -// failed due to reaching a limit like a timeout or maximum number of attempts. -type TransactionExecutionLimit struct { - Errors []error - Causes []string -} - -func newTransactionExecutionLimit(errors []error, causes []string) *TransactionExecutionLimit { - tel := &TransactionExecutionLimit{Errors: make([]error, len(errors)), Causes: causes} - for i, err := range errors { - tel.Errors[i] = wrapError(err) - } - - return tel -} - -func (e *TransactionExecutionLimit) Error() string { - cause := "Unknown cause" - l := len(e.Causes) - if l > 0 { - cause = e.Causes[l-1] - } - var err error - l = len(e.Errors) - if l > 0 { - err = e.Errors[l-1] - } - return fmt.Sprintf("TransactionExecutionLimit: %s after %d attempts, last error: %s", cause, len(e.Errors), err) -} +type UsageError = errorutil.UsageError -// ConnectivityError represent errors caused by the driver not being able to connect to Neo4j services, -// or lost connections. -type ConnectivityError struct { - inner error -} +type ConnectivityError = errorutil.ConnectivityError -func (e *ConnectivityError) Error() string { - return fmt.Sprintf("ConnectivityError: %s", e.inner.Error()) -} +type TransactionExecutionLimit = errorutil.TransactionExecutionLimit // IsNeo4jError returns true if the provided error is an instance of Neo4jError. func IsNeo4jError(err error) bool { @@ -130,53 +72,7 @@ func IsTransactionExecutionLimit(err error) bool { return is } -// TokenExpiredError represent errors caused by the driver not being able to connect to Neo4j services, -// or lost connections. -type TokenExpiredError struct { - Code string - Message string -} - -func (e *TokenExpiredError) Error() string { - return fmt.Sprintf("TokenExpiredError: %s (%s)", e.Code, e.Message) -} - -func wrapError(err error) error { - if err == nil { - return nil - } - if err == io.EOF { - return &ConnectivityError{inner: err} - } - switch e := err.(type) { - case *db.UnsupportedTypeError, *db.FeatureNotSupportedError: - // Usage of a type not supported by database network protocol or feature - // not supported by current version or edition. - return &UsageError{Message: err.Error()} - case *pool.PoolClosed: - return &UsageError{Message: err.Error()} - case *connector.TlsError, net.Error: - return &ConnectivityError{inner: err} - case *pool.PoolTimeout, *pool.PoolFull: - return &ConnectivityError{inner: err} - case *router.ReadRoutingTableError: - return &ConnectivityError{inner: err} - case *retry.CommitFailedDeadError: - return &ConnectivityError{inner: err} - case *bolt.ConnectionReadTimeout: - return &ConnectivityError{inner: err} - case *bolt.ConnectionWriteTimeout: - return &ConnectivityError{inner: err} - case *db.Neo4jError: - if e.Code == "Neo.ClientError.Security.TokenExpired" { - return &TokenExpiredError{Code: e.Code, Message: e.Msg} - } - } - if err != nil && err.Error() == bolt.InvalidTransactionError { - return &UsageError{Message: bolt.InvalidTransactionError} - } - return err -} +type TokenExpiredError = errorutil.TokenExpiredError type ctxCloser interface { Close(ctx context.Context) error diff --git a/neo4j/error_test.go b/neo4j/error_test.go index 0f595d17..ec2b1e9c 100644 --- a/neo4j/error_test.go +++ b/neo4j/error_test.go @@ -22,7 +22,7 @@ package neo4j import ( "fmt" "github.com/neo4j/neo4j-go-driver/v5/neo4j/db" - "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/retry" + "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/errorutil" "testing" ) @@ -34,7 +34,7 @@ func TestIsRetryable(outer *testing.T) { testCases := []retryableTestCase{ {true, &ConnectivityError{ - inner: fmt.Errorf("hello, is it me you are looking for"), + Inner: fmt.Errorf("hello, is it me you are looking for"), }}, {true, &db.Neo4jError{ Code: "Neo.TransientError.No.Stress", @@ -50,7 +50,7 @@ func TestIsRetryable(outer *testing.T) { }}, {false, nil}, {false, &ConnectivityError{ - inner: &retry.CommitFailedDeadError{}, + Inner: &errorutil.CommitFailedDeadError{}, }}, {false, &db.Neo4jError{ Code: "Neo.TransientError.Transaction.Terminated", @@ -68,7 +68,7 @@ func TestIsRetryable(outer *testing.T) { } for _, testCase := range testCases { - outer.Run(fmt.Sprintf("is error %s retryable?", testCase.err), func(t *testing.T) { + outer.Run(fmt.Sprintf("is error %v retryable?", testCase.err), func(t *testing.T) { expected := testCase.isRetryable actual := IsRetryable(testCase.err) diff --git a/neo4j/internal/bolt/bolt3.go b/neo4j/internal/bolt/bolt3.go index 9b2727b4..dd54753d 100644 --- a/neo4j/internal/bolt/bolt3.go +++ b/neo4j/internal/bolt/bolt3.go @@ -24,6 +24,7 @@ import ( "errors" "fmt" idb "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/db" + "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/errorutil" "net" "time" @@ -271,7 +272,7 @@ func (b *bolt3) TxBegin( // misuse from clients that stick to their connections when they shouldn't. func (b *bolt3) assertTxHandle(h1, h2 idb.TxHandle) error { if h1 != h2 { - err := errors.New(InvalidTransactionError) + err := errors.New(errorutil.InvalidTransactionError) b.log.Error(log.Bolt3, b.logId, err) return err } diff --git a/neo4j/internal/bolt/bolt4.go b/neo4j/internal/bolt/bolt4.go index 493e7bb5..83a641dd 100644 --- a/neo4j/internal/bolt/bolt4.go +++ b/neo4j/internal/bolt/bolt4.go @@ -25,6 +25,7 @@ import ( "fmt" "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/collections" idb "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/db" + "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/errorutil" "net" "time" @@ -311,7 +312,7 @@ func (b *bolt4) TxBegin( // misuse from clients that stick to their connections when they shouldn't. func (b *bolt4) assertTxHandle(h1, h2 idb.TxHandle) error { if h1 != h2 { - err := errors.New(InvalidTransactionError) + err := errors.New(errorutil.InvalidTransactionError) b.log.Error(log.Bolt4, b.logId, err) return err } diff --git a/neo4j/internal/bolt/bolt5.go b/neo4j/internal/bolt/bolt5.go index d2bf9623..e5cde0fc 100644 --- a/neo4j/internal/bolt/bolt5.go +++ b/neo4j/internal/bolt/bolt5.go @@ -24,6 +24,7 @@ import ( "errors" "fmt" idb "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/db" + "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/errorutil" "net" "time" @@ -300,7 +301,7 @@ func (b *bolt5) TxBegin( // misuse from clients that stick to their connections when they shouldn't. func (b *bolt5) assertTxHandle(h1, h2 idb.TxHandle) error { if h1 != h2 { - err := errors.New(InvalidTransactionError) + err := errors.New(errorutil.InvalidTransactionError) b.log.Error(log.Bolt5, b.logId, err) return err } diff --git a/neo4j/internal/bolt/chunker.go b/neo4j/internal/bolt/chunker.go index 16b189b6..b4d42f9b 100644 --- a/neo4j/internal/bolt/chunker.go +++ b/neo4j/internal/bolt/chunker.go @@ -22,6 +22,7 @@ package bolt import ( "context" "encoding/binary" + "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/errorutil" rio "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/racing" "io" ) @@ -111,15 +112,15 @@ func (c *chunker) send(ctx context.Context, wr io.Writer) error { } func processWriteError(err error, ctx context.Context) error { - if IsTimeoutError(err) { - return &ConnectionWriteTimeout{ - userContext: ctx, - err: err, + if errorutil.IsTimeoutError(err) { + return &errorutil.ConnectionWriteTimeout{ + UserContext: ctx, + Err: err, } } if err == context.Canceled { - return &ConnectionWriteCanceled{ - err: err, + return &errorutil.ConnectionWriteCanceled{ + Err: err, } } return err diff --git a/neo4j/internal/bolt/connections.go b/neo4j/internal/bolt/connections.go index 3f74c696..dcbc9204 100644 --- a/neo4j/internal/bolt/connections.go +++ b/neo4j/internal/bolt/connections.go @@ -37,13 +37,13 @@ func handleTerminatedContextError(err error, connection net.Conn) error { func contextTerminatedErr(err error) bool { switch err.(type) { - case *ConnectionWriteTimeout: + case *errorutil.ConnectionWriteTimeout: return true - case *ConnectionReadTimeout: + case *errorutil.ConnectionReadTimeout: return true - case *ConnectionWriteCanceled: + case *errorutil.ConnectionWriteCanceled: return true - case *ConnectionReadCanceled: + case *errorutil.ConnectionReadCanceled: return true } return false diff --git a/neo4j/internal/bolt/dechunker.go b/neo4j/internal/bolt/dechunker.go index 015b885c..86589d0f 100644 --- a/neo4j/internal/bolt/dechunker.go +++ b/neo4j/internal/bolt/dechunker.go @@ -22,6 +22,7 @@ package bolt import ( "context" "encoding/binary" + "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/errorutil" rio "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/racing" "net" "time" @@ -86,16 +87,16 @@ func newContext(ctx context.Context, readTimeout time.Duration) (context.Context } func processReadError(err error, ctx context.Context, readTimeout time.Duration) error { - if IsTimeoutError(err) { - return &ConnectionReadTimeout{ - userContext: ctx, - readTimeout: readTimeout, - err: err, + if errorutil.IsTimeoutError(err) { + return &errorutil.ConnectionReadTimeout{ + UserContext: ctx, + ReadTimeout: readTimeout, + Err: err, } } if err == context.Canceled { - return &ConnectionReadCanceled{ - err: err, + return &errorutil.ConnectionReadCanceled{ + Err: err, } } return err diff --git a/neo4j/internal/connector/connector.go b/neo4j/internal/connector/connector.go index cd2a8b3c..853dea99 100644 --- a/neo4j/internal/connector/connector.go +++ b/neo4j/internal/connector/connector.go @@ -26,6 +26,7 @@ import ( "errors" "github.com/neo4j/neo4j-go-driver/v5/neo4j/config" "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/db" + "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/errorutil" "io" "net" "time" @@ -96,7 +97,7 @@ func (c Connector) Connect(ctx context.Context, address string, boltLogger log.B err = errors.New("remote end closed the connection, check that TLS is enabled on the server") } conn.Close() - return nil, &TlsError{inner: err} + return nil, &errorutil.TlsError{Inner: err} } connection, err := bolt.Connect(ctx, address, @@ -141,15 +142,3 @@ func (c Connector) tlsConfig(serverName string) *tls.Config { config.ServerName = serverName return config } - -// TlsError encapsulates all errors related to TLS connection creation -// This is needed since the tls package does not provide a common error type -// à la net.Error, and a common type is needed to properly classify the error -// for Testkit -type TlsError struct { - inner error -} - -func (e *TlsError) Error() string { - return e.inner.Error() -} diff --git a/neo4j/internal/bolt/errors.go b/neo4j/internal/errorutil/bolt.go similarity index 74% rename from neo4j/internal/bolt/errors.go rename to neo4j/internal/errorutil/bolt.go index a93d4d9e..0866d2e2 100644 --- a/neo4j/internal/bolt/errors.go +++ b/neo4j/internal/errorutil/bolt.go @@ -1,4 +1,4 @@ -package bolt +package errorutil import ( "context" @@ -9,50 +9,50 @@ import ( const InvalidTransactionError = "invalid transaction handle" type ConnectionReadTimeout struct { - userContext context.Context - readTimeout time.Duration - err error + UserContext context.Context + ReadTimeout time.Duration + Err error } func (crt *ConnectionReadTimeout) Error() string { userDeadline := "N/A" - if deadline, ok := crt.userContext.Deadline(); ok { + if deadline, ok := crt.UserContext.Deadline(); ok { userDeadline = deadline.String() } return fmt.Sprintf( "Timeout while reading from connection [server-side timeout hint: %s, user-provided context deadline: %s]: %s", - crt.readTimeout.String(), + crt.ReadTimeout.String(), userDeadline, - crt.err) + crt.Err) } type ConnectionWriteTimeout struct { - userContext context.Context - err error + UserContext context.Context + Err error } func (cwt *ConnectionWriteTimeout) Error() string { userDeadline := "N/A" - if deadline, ok := cwt.userContext.Deadline(); ok { + if deadline, ok := cwt.UserContext.Deadline(); ok { userDeadline = deadline.String() } - return fmt.Sprintf("Timeout while writing to connection [user-provided context deadline: %s]: %s", userDeadline, cwt.err) + return fmt.Sprintf("Timeout while writing to connection [user-provided context deadline: %s]: %s", userDeadline, cwt.Err) } type ConnectionReadCanceled struct { - err error + Err error } func (crc *ConnectionReadCanceled) Error() string { - return fmt.Sprintf("Reading from connection has been canceled: %s", crc.err) + return fmt.Sprintf("Reading from connection has been canceled: %s", crc.Err) } type ConnectionWriteCanceled struct { - err error + Err error } func (cwc *ConnectionWriteCanceled) Error() string { - return fmt.Sprintf("Writing to connection has been canceled: %s", cwc.err) + return fmt.Sprintf("Writing to connection has been canceled: %s", cwc.Err) } type timeout interface { diff --git a/neo4j/internal/errorutil/errors.go b/neo4j/internal/errorutil/errors.go index 1959cc04..54c93e81 100644 --- a/neo4j/internal/errorutil/errors.go +++ b/neo4j/internal/errorutil/errors.go @@ -21,6 +21,9 @@ package errorutil import ( "fmt" + "github.com/neo4j/neo4j-go-driver/v5/neo4j/db" + "io" + "net" ) func CombineAllErrors(errs ...error) error { @@ -43,3 +46,71 @@ func CombineErrors(err1, err2 error) error { } return fmt.Errorf("error %v occurred after previous error %w", err2, err1) } + +func WrapError(err error) error { + if err == nil { + return nil + } + if err == io.EOF { + return &ConnectivityError{Inner: err} + } + switch e := err.(type) { + case *db.UnsupportedTypeError, *db.FeatureNotSupportedError: + // Usage of a type not supported by database network protocol or feature + // not supported by current version or edition. + return &UsageError{Message: err.Error()} + case *PoolClosed: + return &UsageError{Message: err.Error()} + case *TlsError, net.Error: + return &ConnectivityError{Inner: err} + case *PoolTimeout, *PoolFull: + return &ConnectivityError{Inner: err} + case *ReadRoutingTableError: + return &ConnectivityError{Inner: err} + case *CommitFailedDeadError: + return &ConnectivityError{Inner: err} + case *ConnectionReadTimeout: + return &ConnectivityError{Inner: err} + case *ConnectionWriteTimeout: + return &ConnectivityError{Inner: err} + case *db.Neo4jError: + if e.Code == "Neo.ClientError.Security.TokenExpired" { + return &TokenExpiredError{Code: e.Code, Message: e.Msg} + } + } + if err != nil && err.Error() == InvalidTransactionError { + return &UsageError{Message: InvalidTransactionError} + } + return err +} + +// UsageError represents errors caused by incorrect usage of the driver API. +// This does not include Cypher syntax (those errors will be Neo4jError). +type UsageError struct { + Message string +} + +func (e *UsageError) Error() string { + return e.Message +} + +// ConnectivityError represent errors caused by the driver not being able to connect to Neo4j services, +// or lost connections. +type ConnectivityError struct { + Inner error +} + +func (e *ConnectivityError) Error() string { + return fmt.Sprintf("ConnectivityError: %s", e.Inner.Error()) +} + +// TokenExpiredError represent errors caused by the driver not being able to connect to Neo4j services, +// or lost connections. +type TokenExpiredError struct { + Code string + Message string +} + +func (e *TokenExpiredError) Error() string { + return fmt.Sprintf("TokenExpiredError: %s (%s)", e.Code, e.Message) +} diff --git a/neo4j/internal/errorutil/pool.go b/neo4j/internal/errorutil/pool.go new file mode 100644 index 00000000..6c07c72f --- /dev/null +++ b/neo4j/internal/errorutil/pool.go @@ -0,0 +1,27 @@ +package errorutil + +import "fmt" + +type PoolTimeout struct { + Err error + Servers []string +} + +func (e *PoolTimeout) Error() string { + return fmt.Sprintf("Timeout while waiting for connection to any of [%s]: %s", e.Servers, e.Err) +} + +type PoolFull struct { + Servers []string +} + +func (e *PoolFull) Error() string { + return fmt.Sprintf("No idle connections on any of [%s]", e.Servers) +} + +type PoolClosed struct { +} + +func (e *PoolClosed) Error() string { + return "Pool closed" +} diff --git a/neo4j/internal/errorutil/retry.go b/neo4j/internal/errorutil/retry.go new file mode 100644 index 00000000..af2593ae --- /dev/null +++ b/neo4j/internal/errorutil/retry.go @@ -0,0 +1,29 @@ +package errorutil + +import "fmt" + +type CommitFailedDeadError struct { + Inner error +} + +func (e *CommitFailedDeadError) Error() string { + return fmt.Sprintf("Connection lost during commit: %s", e.Inner) +} + +// TransactionExecutionLimit error indicates that a retryable transaction has +// failed due to reaching a limit like a timeout or maximum number of attempts. +type TransactionExecutionLimit struct { + Cause string + Errors []error +} + +func (e *TransactionExecutionLimit) Error() string { + cause := e.Cause + var err error + l := len(e.Errors) + + if l > 0 { + err = e.Errors[l-1] + } + return fmt.Sprintf("TransactionExecutionLimit: %s after %d attempts, last error: %s", cause, len(e.Errors), err) +} diff --git a/neo4j/internal/errorutil/router.go b/neo4j/internal/errorutil/router.go new file mode 100644 index 00000000..ed967f39 --- /dev/null +++ b/neo4j/internal/errorutil/router.go @@ -0,0 +1,17 @@ +package errorutil + +import ( + "fmt" +) + +type ReadRoutingTableError struct { + Err error + Server string +} + +func (e *ReadRoutingTableError) Error() string { + if e.Err != nil || len(e.Server) > 0 { + return fmt.Sprintf("Unable to retrieve routing table from %s: %s", e.Server, e.Err) + } + return "Unable to retrieve routing table, no router provided" +} diff --git a/neo4j/internal/errorutil/tls.go b/neo4j/internal/errorutil/tls.go new file mode 100644 index 00000000..1149ed1e --- /dev/null +++ b/neo4j/internal/errorutil/tls.go @@ -0,0 +1,13 @@ +package errorutil + +// TlsError encapsulates all errors related to TLS connection creation +// This is needed since the tls package does not provide a common error type +// à la net.Error, and a common type is needed to properly classify the error +// for Testkit +type TlsError struct { + Inner error +} + +func (e *TlsError) Error() string { + return e.Inner.Error() +} diff --git a/neo4j/internal/pool/errors.go b/neo4j/internal/pool/errors.go deleted file mode 100644 index 9d1d1007..00000000 --- a/neo4j/internal/pool/errors.go +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Copyright (c) "Neo4j" - * Neo4j Sweden AB [https://neo4j.com] - * - * This file is part of Neo4j. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package pool - -import ( - "fmt" -) - -type PoolTimeout struct { - err error - servers []string -} - -func (e *PoolTimeout) Error() string { - return fmt.Sprintf("Timeout while waiting for connection to any of [%s]: %s", e.servers, e.err) -} - -type PoolFull struct { - servers []string -} - -func (e *PoolFull) Error() string { - return fmt.Sprintf("No idle connections on any of [%s]", e.servers) -} - -type PoolClosed struct { -} - -func (e *PoolClosed) Error() string { - return "Pool closed" -} diff --git a/neo4j/internal/pool/pool.go b/neo4j/internal/pool/pool.go index 1c9bf55e..3a094378 100644 --- a/neo4j/internal/pool/pool.go +++ b/neo4j/internal/pool/pool.go @@ -27,8 +27,8 @@ import ( "context" "fmt" "github.com/neo4j/neo4j-go-driver/v5/neo4j/config" - "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/bolt" "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/db" + "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/errorutil" "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/racing" "math" "sort" @@ -205,7 +205,7 @@ func (p *Pool) tryAnyIdle(ctx context.Context, serverNames []string, idlenessThr func (p *Pool) Borrow(ctx context.Context, serverNames []string, wait bool, boltLogger log.BoltLogger, idlenessThreshold time.Duration) (db.Connection, error) { if p.closed { - return nil, &PoolClosed{} + return nil, &errorutil.PoolClosed{} } p.log.Debugf(log.Pool, p.logId, "Trying to borrow connection from %s", serverNames) @@ -226,9 +226,9 @@ func (p *Pool) Borrow(ctx context.Context, serverNames []string, wait bool, bolt return conn, nil } - if bolt.IsTimeoutError(err) { + if errorutil.IsTimeoutError(err) { p.log.Warnf(log.Pool, p.logId, "Borrow time-out") - return nil, &PoolTimeout{servers: serverNames, err: err} + return nil, &errorutil.PoolTimeout{Servers: serverNames, Err: err} } } @@ -249,7 +249,7 @@ func (p *Pool) Borrow(ctx context.Context, serverNames []string, wait bool, bolt } if !wait { - return nil, &PoolFull{servers: serverNames} + return nil, &errorutil.PoolFull{Servers: serverNames} } // Wait for a matching connection to be returned from another thread. @@ -293,7 +293,7 @@ func (p *Pool) Borrow(ctx context.Context, serverNames []string, wait bool, bolt return q.conn, nil } p.log.Warnf(log.Pool, p.logId, "Borrow time-out") - return nil, &PoolTimeout{err: ctx.Err(), servers: serverNames} + return nil, &errorutil.PoolTimeout{Err: ctx.Err(), Servers: serverNames} } } @@ -318,7 +318,7 @@ func (p *Pool) tryBorrow(ctx context.Context, serverName string, boltLogger log. return connection, nil } if srv.size() >= p.config.MaxConnectionPoolSize { - return nil, &PoolFull{servers: []string{serverName}} + return nil, &errorutil.PoolFull{Servers: []string{serverName}} } break } diff --git a/neo4j/internal/pool/pool_test.go b/neo4j/internal/pool/pool_test.go index 43b06a31..3bec06d9 100644 --- a/neo4j/internal/pool/pool_test.go +++ b/neo4j/internal/pool/pool_test.go @@ -24,6 +24,7 @@ import ( "errors" "github.com/neo4j/neo4j-go-driver/v5/neo4j/config" "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/db" + "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/errorutil" "math/rand" "sync" "testing" @@ -133,7 +134,7 @@ func TestPoolBorrowReturn(outer *testing.T) { c2, err2 := p.Borrow(ctx, serverNames, false, nil, DefaultLivenessCheckThreshold) assertNoConnection(t, c2, err2) // Error should be pool full - _ = err2.(*PoolFull) + _ = err2.(*errorutil.PoolFull) }) outer.Run("Multiple threads borrows and returns randomly", func(t *testing.T) { @@ -216,7 +217,7 @@ func TestPoolBorrowReturn(outer *testing.T) { t.Error("There should be an error due to cancelling") } // Should be a pool error with the cancellation error in it - _ = err.(*PoolTimeout) + _ = err.(*errorutil.PoolTimeout) }) outer.Run("Borrows the first successfully reset long-idle connection", func(t *testing.T) { diff --git a/neo4j/internal/retry/state.go b/neo4j/internal/retry/state.go index b369da4d..e75c9d96 100644 --- a/neo4j/internal/retry/state.go +++ b/neo4j/internal/retry/state.go @@ -8,13 +8,13 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ // Package retry handles retry operations. @@ -25,6 +25,7 @@ import ( "errors" "fmt" idb "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/db" + "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/errorutil" "time" "github.com/neo4j/neo4j-go-driver/v5/neo4j/db" @@ -35,20 +36,8 @@ type Router interface { Invalidate(ctx context.Context, database string) error } -type CommitFailedDeadError struct { - inner error -} - -func (e *CommitFailedDeadError) Error() string { - return fmt.Sprintf("Connection lost during commit: %s", e.inner) -} - type State struct { - LastErrWasRetryable bool - LastErr error - stop bool Errs []error - Causes []string MaxTransactionRetryTime time.Duration Log log.Logger LogName string @@ -67,112 +56,92 @@ type State struct { OnDeadConnection func(server string) error } -func (s *State) OnFailure(ctx context.Context, conn idb.Connection, err error, isCommitting bool) { - s.LastErr = err - s.cause = "" - s.skipSleep = false - - // Check timeout - if s.start.IsZero() { - s.start = s.Now() - } - if s.Now().Sub(s.start) > s.MaxTransactionRetryTime { - s.stop = true - s.cause = "Timeout" - return - } - - // Reset after determined to evaluate this error - s.LastErrWasRetryable = false - - if neo4jErr, ok := err.(*db.Neo4jError); ok && neo4jErr.IsAuthenticationFailed() { - s.cause = "Authentication failed" - s.stop = true - return - } - - if _, ok := err.(*db.ProtocolError); ok { - s.cause = "Protocol error detected" - s.stop = true - return - } - - // Failed to connect - if conn == nil { - s.LastErrWasRetryable = true - s.cause = "No available connection" - return - } - - // Check if the connection died, if it died during commit it is not safe to retry. - if !conn.IsAlive() { +func (s *State) OnFailure(ctx context.Context, err error, conn idb.Connection, isCommitting bool) { + if conn != nil && !conn.IsAlive() { if isCommitting { - s.stop = true - // The error is most probably io.EOF so enrich the error - // to make this error more recognizable. - s.LastErr = &CommitFailedDeadError{inner: s.LastErr} - return + s.Errs = append(s.Errs, &errorutil.CommitFailedDeadError{Inner: err}) + } else { + s.Errs = append(s.Errs, err) + } + if err := s.OnDeadConnection(conn.ServerName()); err != nil { + s.Errs = append(s.Errs, err) } - - s.OnDeadConnection(conn.ServerName()) s.deadErrors += 1 - s.stop = s.deadErrors > s.MaxDeadConnections - s.LastErrWasRetryable = true - s.cause = "Connection lost" s.skipSleep = true return } - s.LastErrWasRetryable = IsRetryable(err) - if dbErr, isDbErr := err.(*db.Neo4jError); isDbErr { - if dbErr.IsRetriableCluster() { - // Force routing tables to be updated before trying again - if err := s.Router.Invalidate(ctx, s.DatabaseName); err != nil { - s.stop = true - s.LastErr = err - } - s.cause = "Cluster error" - return - } - if dbErr.IsRetriableTransient() { - s.cause = "Transient error" - return + s.Errs = append(s.Errs, err) + s.skipSleep = false + + if dbErr, isDbErr := err.(*db.Neo4jError); isDbErr && dbErr.IsRetriableCluster() { + if err := s.Router.Invalidate(ctx, s.DatabaseName); err != nil { + s.Errs = append(s.Errs, err) } } - s.stop = true } func (s *State) Continue() bool { - // No error happened yet - if !s.stop && s.LastErr == nil { + if s.start.IsZero() { + s.start = s.Now() + } + + if len(s.Errs) == 0 { return true } - // Track the error and the cause - s.Errs = append(s.Errs, s.LastErr) - if s.cause != "" { - s.Causes = append(s.Causes, s.cause) + lastErr := s.Errs[len(s.Errs)-1] + if !IsRetryable(errorutil.WrapError(lastErr)) { + return false } - // Retry after optional sleep - if !s.stop { - if s.skipSleep { - s.Log.Debugf(s.LogName, s.LogId, "Retrying transaction (%s): %s", s.cause, s.LastErr) - } else { - s.Throttle = s.Throttle.next() - sleepTime := s.Throttle.delay() - s.Log.Debugf(s.LogName, s.LogId, - "Retrying transaction (%s): %s [after %s]", s.cause, s.LastErr, sleepTime) - s.Sleep(sleepTime) - } - return true + if s.Now().Sub(s.start) > s.MaxTransactionRetryTime { + s.Errs = []error{&errorutil.TransactionExecutionLimit{ + Cause: fmt.Sprintf("timeout (exceeded max retry time: %s)", s.MaxTransactionRetryTime.String()), + Errors: s.Errs, + }} + return false } - return false + if s.deadErrors > s.MaxDeadConnections { + s.Errs = []error{&errorutil.TransactionExecutionLimit{ + Cause: fmt.Sprintf("too many failed connection attempts (allowed max %d)", s.MaxDeadConnections), + Errors: s.Errs, + }} + return false + } + + if s.skipSleep { + s.Log.Debugf(s.LogName, s.LogId, "Retrying transaction (%s): %s", s.cause, lastErr) + } else { + s.Throttle = s.Throttle.next() + sleepTime := s.Throttle.delay() + s.Log.Debugf(s.LogName, s.LogId, + "Retrying transaction (%s): %s [after %s]", s.cause, lastErr, sleepTime) + s.Sleep(sleepTime) + } + return true +} + +func (s *State) ProduceError() error { + lastErr := s.Errs[len(s.Errs)-1] + if limitReachedErr, ok := lastErr.(*errorutil.TransactionExecutionLimit); ok { + return limitReachedErr + } + return errorutil.WrapError(lastErr) } func IsRetryable(err error) bool { + if connectivityErr, ok := err.(*errorutil.ConnectivityError); ok { + if _, ok := connectivityErr.Inner.(*errorutil.CommitFailedDeadError); ok { + return false + } + return true + } + if _, ok := err.(*errorutil.PoolTimeout); ok { + return true + } var dbError *db.Neo4jError if !errors.As(err, &dbError) { return false diff --git a/neo4j/internal/retry/state_test.go b/neo4j/internal/retry/state_test.go index 654d4fcc..d462339e 100644 --- a/neo4j/internal/retry/state_test.go +++ b/neo4j/internal/retry/state_test.go @@ -8,13 +8,13 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ package retry @@ -23,13 +23,13 @@ import ( "context" "errors" idb "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/db" + "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/errorutil" "io" "reflect" "testing" "time" "github.com/neo4j/neo4j-go-driver/v5/neo4j/db" - "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/pool" "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/testutil" "github.com/neo4j/neo4j-go-driver/v5/neo4j/log" ) @@ -64,37 +64,47 @@ func TestState(outer *testing.T) { testCases := map[string][]TStateInvocation{ "Retry connect": { - {conn: nil, err: &pool.PoolTimeout{}, expectContinued: true, - expectLastErrWasRetryable: true, expectLastErrType: &pool.PoolTimeout{}}, + {conn: nil, err: &errorutil.PoolTimeout{}, expectContinued: true, + expectLastErrWasRetryable: true, expectLastErrType: &errorutil.PoolTimeout{}}, }, "Retry connect timeout": { - {conn: nil, err: errors.New("connect error 1"), expectContinued: true, now: baseTime, + {conn: nil, err: dbTransientErr, expectContinued: true, now: baseTime, expectLastErrWasRetryable: true}, - {conn: nil, err: errors.New("connect error 2"), expectContinued: true, now: halfTime, + {conn: nil, err: dbTransientErr, expectContinued: true, now: halfTime, expectLastErrWasRetryable: true}, - {conn: nil, err: errors.New("connect error 3"), expectContinued: false, now: overTime, + {conn: nil, err: dbTransientErr, expectContinued: false, now: overTime, expectLastErrWasRetryable: true}, }, "Retry dead connection": { {conn: &testutil.ConnFake{Name: serverName, Alive: false}, - err: errors.New("some error"), expectContinued: true, - expectLastErrWasRetryable: true, expectRouterInvalidated: true, + err: errors.New("some error"), expectContinued: false, + expectLastErrWasRetryable: false, expectRouterInvalidated: true, + expectRouterInvalidatedDb: dbName, expectRouterInvalidatedServer: serverName}, + {conn: &testutil.ConnFake{Name: serverName, Alive: false}, + err: dbTransientErr, + expectContinued: true, expectLastErrWasRetryable: true, + expectRouterInvalidated: true, expectRouterInvalidatedDb: dbName, expectRouterInvalidatedServer: serverName}, }, "Retry dead connection timeout": { - {conn: &testutil.ConnFake{Name: serverName, Alive: false}, err: errors.New("some error 1"), expectContinued: true, now: baseTime, - expectLastErrWasRetryable: true, expectRouterInvalidated: true, expectRouterInvalidatedDb: dbName, expectRouterInvalidatedServer: serverName}, - {conn: &testutil.ConnFake{Alive: false}, err: errors.New("some error 2"), expectContinued: false, now: overTime, - expectLastErrWasRetryable: true}, + {conn: &testutil.ConnFake{Name: serverName, Alive: false}, err: dbTransientErr, + expectContinued: true, now: baseTime, expectLastErrWasRetryable: true, + expectRouterInvalidated: true, expectRouterInvalidatedDb: dbName, + expectRouterInvalidatedServer: serverName}, + {conn: &testutil.ConnFake{Name: serverName, Alive: false}, err: errors.New("some error 2"), + expectContinued: false, now: overTime, + expectLastErrWasRetryable: false, + expectRouterInvalidated: true, + expectRouterInvalidatedDb: dbName, expectRouterInvalidatedServer: serverName}, }, "Retry dead connection max": { - {conn: &testutil.ConnFake{Name: serverName, Alive: false}, err: errors.New("some error 1"), expectContinued: true, + {conn: &testutil.ConnFake{Name: serverName, Alive: false}, err: dbTransientErr, expectContinued: true, expectLastErrWasRetryable: true, expectRouterInvalidated: true, expectRouterInvalidatedDb: dbName, expectRouterInvalidatedServer: serverName}, - {conn: &testutil.ConnFake{Name: serverName, Alive: false}, err: errors.New("some error 2"), expectContinued: true, + {conn: &testutil.ConnFake{Name: serverName, Alive: false}, err: dbTransientErr, expectContinued: true, expectLastErrWasRetryable: true, expectRouterInvalidated: true, expectRouterInvalidatedDb: dbName, expectRouterInvalidatedServer: serverName}, - {conn: &testutil.ConnFake{Name: serverName, Alive: false}, err: errors.New("some error 3"), expectContinued: false, + {conn: &testutil.ConnFake{Name: serverName, Alive: false}, err: dbTransientErr, expectContinued: false, expectLastErrWasRetryable: true, expectRouterInvalidated: true, expectRouterInvalidatedDb: dbName, expectRouterInvalidatedServer: serverName}, }, @@ -102,12 +112,6 @@ func TestState(outer *testing.T) { {conn: &testutil.ConnFake{Alive: true}, err: clusterErr, expectContinued: true, expectRouterInvalidated: true, expectRouterInvalidatedDb: dbName, expectLastErrWasRetryable: true}, }, - "Cluster error timeout": { - {conn: &testutil.ConnFake{Alive: true}, err: clusterErr, expectContinued: true, - expectRouterInvalidated: true, expectRouterInvalidatedDb: dbName, expectLastErrWasRetryable: true}, - {conn: &testutil.ConnFake{Alive: true}, err: clusterErr, expectContinued: false, now: overTime, - expectLastErrWasRetryable: true}, - }, "Database transient error": { {conn: &testutil.ConnFake{Alive: true}, err: dbTransientErr, expectContinued: true, expectLastErrWasRetryable: true}, @@ -123,14 +127,18 @@ func TestState(outer *testing.T) { expectLastErrWasRetryable: false}, }, "Fail during commit": { - {conn: &testutil.ConnFake{Alive: false}, err: io.EOF, isCommitting: true, expectContinued: false, - expectLastErrWasRetryable: false, expectLastErrType: &CommitFailedDeadError{}}, + {conn: &testutil.ConnFake{Name: serverName, Alive: false}, err: io.EOF, isCommitting: true, expectContinued: false, + expectLastErrWasRetryable: false, expectLastErrType: &errorutil.CommitFailedDeadError{}, + expectRouterInvalidated: true, + expectRouterInvalidatedDb: dbName, expectRouterInvalidatedServer: serverName}, }, "Fail during commit after retry": { {conn: &testutil.ConnFake{Alive: true}, err: dbTransientErr, expectContinued: true, expectLastErrWasRetryable: true}, - {conn: &testutil.ConnFake{Alive: false}, err: io.EOF, isCommitting: true, expectContinued: false, - expectLastErrWasRetryable: false, expectLastErrType: &CommitFailedDeadError{}}, + {conn: &testutil.ConnFake{Name: serverName, Alive: false}, err: io.EOF, isCommitting: true, + expectContinued: false, expectLastErrWasRetryable: false, + expectLastErrType: &errorutil.CommitFailedDeadError{}, expectRouterInvalidated: true, + expectRouterInvalidatedDb: dbName, expectRouterInvalidatedServer: serverName}, }, "Does not retry on auth errors": { {conn: nil, err: authErr, expectContinued: false, @@ -173,7 +181,7 @@ func TestState(outer *testing.T) { return router.InvalidateReader(ctx, dbName, server) } - state.OnFailure(ctx, invocation.conn, invocation.err, invocation.isCommitting) + state.OnFailure(ctx, invocation.err, invocation.conn, invocation.isCommitting) continued := state.Continue() if continued != invocation.expectContinued { t.Errorf("Expected continue to return %v but returned %v", invocation.expectContinued, continued) @@ -185,14 +193,19 @@ func TestState(outer *testing.T) { invocation.expectRouterInvalidated, invocation.expectRouterInvalidatedDb, invocation.expectRouterInvalidatedServer, router.Invalidated, router.InvalidatedDb, router.InvalidatedServer) } - if state.LastErr == nil { - t.Errorf("LastErr should be set") + var lastError error + if err, ok := state.Errs[0].(*errorutil.TransactionExecutionLimit); ok { + errs := err.Errors + lastError = errs[len(errs)-1] + } else { + lastError = state.Errs[len(state.Errs)-1] } - if state.LastErrWasRetryable != invocation.expectLastErrWasRetryable { + + if IsRetryable(lastError) != invocation.expectLastErrWasRetryable { t.Errorf("LastErrWasRetryable mismatch") } if invocation.expectLastErrType != nil { - t1 := reflect.TypeOf(state.LastErr) + t1 := reflect.TypeOf(lastError) t2 := reflect.TypeOf(invocation.expectLastErrType) if t1 != t2 { t.Errorf("LastErr type mismatch: %s vs %s", t1, t2) diff --git a/neo4j/internal/router/errors.go b/neo4j/internal/router/errors.go deleted file mode 100644 index 7b58fab2..00000000 --- a/neo4j/internal/router/errors.go +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Copyright (c) "Neo4j" - * Neo4j Sweden AB [https://neo4j.com] - * - * This file is part of Neo4j. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package router - -import ( - "fmt" - - "github.com/neo4j/neo4j-go-driver/v5/neo4j/db" -) - -type ReadRoutingTableError struct { - err error - server string -} - -func (e *ReadRoutingTableError) Error() string { - if e.err != nil || len(e.server) > 0 { - return fmt.Sprintf("Unable to retrieve routing table from %s: %s", e.server, e.err) - } - return "Unable to retrieve routing table, no router provided" -} - -func wrapError(server string, err error) error { - // Preserve error originating from the database, wrap other errors - _, isNeo4jErr := err.(*db.Neo4jError) - if isNeo4jErr { - return err - } - return &ReadRoutingTableError{server: server, err: err} -} diff --git a/neo4j/internal/router/readtable.go b/neo4j/internal/router/readtable.go index ff150065..eb6942d7 100644 --- a/neo4j/internal/router/readtable.go +++ b/neo4j/internal/router/readtable.go @@ -22,6 +22,7 @@ package router import ( "context" "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/db" + "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/errorutil" "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/pool" "github.com/neo4j/neo4j-go-driver/v5/neo4j/log" ) @@ -31,7 +32,7 @@ import ( func readTable(ctx context.Context, connectionPool Pool, routers []string, routerContext map[string]string, bookmarks []string, database, impersonatedUser string, boltLogger log.BoltLogger) (*db.RoutingTable, error) { // Preserve last error to be returned, set a default for case of no routers - var err error = &ReadRoutingTableError{} + var err error = &errorutil.ReadRoutingTableError{} // Try the routers one at the time since some of them might no longer support routing and we // can't force the pool to not re-use these when putting them back in the pool and retrieving diff --git a/neo4j/internal/router/readtable_test.go b/neo4j/internal/router/readtable_test.go index 01d961d9..e8faaa20 100644 --- a/neo4j/internal/router/readtable_test.go +++ b/neo4j/internal/router/readtable_test.go @@ -23,6 +23,7 @@ import ( "context" "errors" idb "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/db" + "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/errorutil" "github.com/neo4j/neo4j-go-driver/v5/neo4j/log" "testing" @@ -54,9 +55,9 @@ func TestReadTableTable(ot *testing.T) { } assertRoutingTableError := func(t *testing.T, err error) { - _, is := err.(*ReadRoutingTableError) + _, is := err.(*errorutil.ReadRoutingTableError) if !is { - r := &ReadRoutingTableError{} + r := &errorutil.ReadRoutingTableError{} t.Errorf("Error should be %T but was %T", r, err) } } diff --git a/neo4j/internal/router/router.go b/neo4j/internal/router/router.go index fd207221..31aa1079 100644 --- a/neo4j/internal/router/router.go +++ b/neo4j/internal/router/router.go @@ -22,7 +22,9 @@ package router import ( "context" "errors" - "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/db" + "github.com/neo4j/neo4j-go-driver/v5/neo4j/db" + idb "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/db" + "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/errorutil" "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/racing" "time" @@ -34,7 +36,7 @@ const missingReaderRetries = 100 type databaseRouter struct { dueUnix int64 - table *db.RoutingTable + table *idb.RoutingTable } // Router is thread safe @@ -56,8 +58,8 @@ type Pool interface { // If all connections are busy and the pool is full, calls to Borrow may wait for a connection to become idle // If a connection has been idle for longer than idlenessThreshold, it will be reset // to check if it's still alive. - Borrow(ctx context.Context, servers []string, wait bool, boltLogger log.BoltLogger, idlenessThreshold time.Duration) (db.Connection, error) - Return(ctx context.Context, c db.Connection) error + Borrow(ctx context.Context, servers []string, wait bool, boltLogger log.BoltLogger, idlenessThreshold time.Duration) (idb.Connection, error) + Return(ctx context.Context, c idb.Connection) error } func New(rootRouter string, getRouters func() []string, routerContext map[string]string, pool Pool, logger log.Logger, logId string) *Router { @@ -77,9 +79,9 @@ func New(rootRouter string, getRouters func() []string, routerContext map[string return r } -func (r *Router) readTable(ctx context.Context, dbRouter *databaseRouter, bookmarks []string, database, impersonatedUser string, boltLogger log.BoltLogger) (*db.RoutingTable, error) { +func (r *Router) readTable(ctx context.Context, dbRouter *databaseRouter, bookmarks []string, database, impersonatedUser string, boltLogger log.BoltLogger) (*idb.RoutingTable, error) { var ( - table *db.RoutingTable + table *idb.RoutingTable err error ) @@ -117,7 +119,7 @@ func (r *Router) readTable(ctx context.Context, dbRouter *databaseRouter, bookma return table, nil } -func (r *Router) getOrReadTable(ctx context.Context, bookmarksFn func(context.Context) ([]string, error), database string, boltLogger log.BoltLogger) (*db.RoutingTable, error) { +func (r *Router) getOrReadTable(ctx context.Context, bookmarksFn func(context.Context) ([]string, error), database string, boltLogger log.BoltLogger) (*idb.RoutingTable, error) { now := r.now() if !r.dbRoutersMut.TryLock(ctx) { @@ -205,7 +207,7 @@ func (r *Router) Writers(ctx context.Context, bookmarks func(context.Context) ([ } func (r *Router) GetNameOfDefaultDatabase(ctx context.Context, bookmarks []string, user string, boltLogger log.BoltLogger) (string, error) { - table, err := r.readTable(ctx, nil, bookmarks, db.DefaultDatabase, user, boltLogger) + table, err := r.readTable(ctx, nil, bookmarks, idb.DefaultDatabase, user, boltLogger) if err != nil { return "", err } @@ -294,10 +296,19 @@ func (r *Router) CleanUp(ctx context.Context) error { return nil } -func (r *Router) storeRoutingTable(database string, table *db.RoutingTable, now time.Time) { +func (r *Router) storeRoutingTable(database string, table *idb.RoutingTable, now time.Time) { r.dbRouters[database] = &databaseRouter{ table: table, dueUnix: now.Add(time.Duration(table.TimeToLive) * time.Second).Unix(), } r.log.Debugf(log.Router, r.logId, "New routing table for '%s', TTL %d", database, table.TimeToLive) } + +func wrapError(server string, err error) error { + // Preserve error originating from the database, wrap other errors + _, isNeo4jErr := err.(*db.Neo4jError) + if isNeo4jErr { + return err + } + return &errorutil.ReadRoutingTableError{Server: server, Err: err} +} diff --git a/neo4j/result_with_context.go b/neo4j/result_with_context.go index 83b76bcf..9a080b7f 100644 --- a/neo4j/result_with_context.go +++ b/neo4j/result_with_context.go @@ -23,6 +23,7 @@ import ( "context" "github.com/neo4j/neo4j-go-driver/v5/neo4j/db" idb "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/db" + "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/errorutil" ) type ResultWithContext interface { @@ -126,7 +127,7 @@ func (r *resultWithContext) Peek(ctx context.Context) bool { } func (r *resultWithContext) Err() error { - return wrapError(r.err) + return errorutil.WrapError(r.err) } func (r *resultWithContext) Record() *Record { @@ -145,7 +146,7 @@ func (r *resultWithContext) Collect(ctx context.Context) ([]*Record, error) { } } if r.err != nil { - return nil, wrapError(r.err) + return nil, errorutil.WrapError(r.err) } r.callAfterConsumptionHook() return recs, nil @@ -155,7 +156,7 @@ func (r *resultWithContext) Single(ctx context.Context) (*Record, error) { // Try retrieving the single record r.advance(ctx) if r.err != nil { - return nil, wrapError(r.err) + return nil, errorutil.WrapError(r.err) } if r.summary != nil { r.err = &UsageError{Message: "Result contains no more records"} @@ -178,7 +179,7 @@ func (r *resultWithContext) Single(ctx context.Context) (*Record, error) { if r.err != nil { // Might be more records or not, anyway something is bad. // Both r.record and r.summary are nil at this point which is good. - return nil, wrapError(r.err) + return nil, errorutil.WrapError(r.err) } // We got the expected summary // r.record contains the single record and r.summary the summary. @@ -192,13 +193,13 @@ func (r *resultWithContext) Consume(ctx context.Context) (ResultSummary, error) // set by Single to indicate some kind of usage error that "destroyed" // the result. if r.err != nil { - return nil, wrapError(r.err) + return nil, errorutil.WrapError(r.err) } r.record = nil r.summary, r.err = r.conn.Consume(ctx, r.streamHandle) if r.err != nil { - return nil, wrapError(r.err) + return nil, errorutil.WrapError(r.err) } r.callAfterConsumptionHook() return r.toResultSummary(), nil diff --git a/neo4j/session_with_context.go b/neo4j/session_with_context.go index b72fed91..77ec09ce 100644 --- a/neo4j/session_with_context.go +++ b/neo4j/session_with_context.go @@ -280,14 +280,14 @@ func (s *sessionWithContext) BeginTransaction(ctx context.Context, configurers . // Get a connection from the pool. This could fail in clustered environment. conn, err := s.getConnection(ctx, s.defaultMode, pool.DefaultLivenessCheckThreshold) if err != nil { - return nil, wrapError(err) + return nil, errorutil.WrapError(err) } // Begin transaction beginBookmarks, err := s.getBookmarks(ctx) if err != nil { _ = s.pool.Return(ctx, conn) - return nil, wrapError(err) + return nil, errorutil.WrapError(err) } txHandle, err := conn.TxBegin(ctx, idb.TxConfig{ @@ -303,7 +303,7 @@ func (s *sessionWithContext) BeginTransaction(ctx context.Context, configurers . }) if err != nil { _ = s.pool.Return(ctx, conn) - return nil, wrapError(err) + return nil, errorutil.WrapError(err) } // Create transaction wrapper @@ -383,26 +383,13 @@ func (s *sessionWithContext) runRetriable( }, } for state.Continue() { - if tryAgain, result := s.executeTransactionFunction(ctx, mode, config, &state, work); tryAgain { - continue - } else { + if hasCompleted, result := s.executeTransactionFunction(ctx, mode, config, &state, work); hasCompleted { return result, nil } } - // When retries has occurred wrap the error, the last error is always added but - // cause is only set when the retry logic could detect something strange. - if state.LastErrWasRetryable { - err := newTransactionExecutionLimit(state.Errs, state.Causes) - s.log.Error(log.Session, s.logId, err) - return nil, err - } - // Wrap and log the error if it belongs to the driver - err := wrapError(state.LastErr) - switch err.(type) { - case *UsageError, *ConnectivityError: - s.log.Error(log.Session, s.logId, err) - } + err := state.ProduceError() + s.log.Error(log.Session, s.logId, err) return nil, err } @@ -415,8 +402,8 @@ func (s *sessionWithContext) executeTransactionFunction( conn, err := s.getConnection(ctx, mode, pool.DefaultLivenessCheckThreshold) if err != nil { - state.OnFailure(ctx, conn, err, false) - return true, nil + state.OnFailure(ctx, err, conn, false) + return false, nil } // handle transaction function panic as well @@ -426,8 +413,8 @@ func (s *sessionWithContext) executeTransactionFunction( beginBookmarks, err := s.getBookmarks(ctx) if err != nil { - state.OnFailure(ctx, conn, err, false) - return true, nil + state.OnFailure(ctx, err, conn, false) + return false, nil } txHandle, err := conn.TxBegin(ctx, idb.TxConfig{ @@ -442,8 +429,8 @@ func (s *sessionWithContext) executeTransactionFunction( }, }) if err != nil { - state.OnFailure(ctx, conn, err, false) - return true, nil + state.OnFailure(ctx, err, conn, false) + return false, nil } tx := managedTransaction{conn: conn, fetchSize: s.fetchSize, txHandle: txHandle} @@ -453,14 +440,14 @@ func (s *sessionWithContext) executeTransactionFunction( // client wants to rollback. We don't do an explicit rollback here // but instead rely on the pool invoking reset on the connection, // that will do an implicit rollback. - state.OnFailure(ctx, conn, err, false) - return true, nil + state.OnFailure(ctx, err, conn, false) + return false, nil } err = conn.TxCommit(ctx, txHandle) if err != nil { - state.OnFailure(ctx, conn, err, true) - return true, nil + state.OnFailure(ctx, err, conn, true) + return false, nil } // transaction has been committed so let's ignore (ie just log) the error @@ -468,7 +455,7 @@ func (s *sessionWithContext) executeTransactionFunction( s.log.Warnf(log.Session, s.logId, "could not retrieve bookmarks after successful commit: %s\n"+ "the results of this transaction may not be visible to subsequent operations", err.Error()) } - return false, x + return true, x } func (s *sessionWithContext) getServers(ctx context.Context, mode idb.AccessMode) ([]string, error) { @@ -495,16 +482,16 @@ func (s *sessionWithContext) getConnection(ctx context.Context, mode idb.AccessM } if err := s.resolveHomeDatabase(ctx); err != nil { - return nil, wrapError(err) + return nil, errorutil.WrapError(err) } servers, err := s.getServers(ctx, mode) if err != nil { - return nil, wrapError(err) + return nil, errorutil.WrapError(err) } conn, err := s.pool.Borrow(ctx, servers, s.driverConfig.ConnectionAcquisitionTimeout != 0, s.config.BoltLogger, livenessCheckThreshold) if err != nil { - return nil, wrapError(err) + return nil, errorutil.WrapError(err) } // Select database on server @@ -557,13 +544,13 @@ func (s *sessionWithContext) Run(ctx context.Context, conn, err := s.getConnection(ctx, s.defaultMode, pool.DefaultLivenessCheckThreshold) if err != nil { - return nil, wrapError(err) + return nil, errorutil.WrapError(err) } runBookmarks, err := s.getBookmarks(ctx) if err != nil { _ = s.pool.Return(ctx, conn) - return nil, wrapError(err) + return nil, errorutil.WrapError(err) } stream, err := conn.Run( ctx, @@ -586,7 +573,7 @@ func (s *sessionWithContext) Run(ctx context.Context, ) if err != nil { _ = s.pool.Return(ctx, conn) - return nil, wrapError(err) + return nil, errorutil.WrapError(err) } s.autocommitTx = &autocommitTransaction{ @@ -634,15 +621,15 @@ func (s *sessionWithContext) legacy() Session { func (s *sessionWithContext) getServerInfo(ctx context.Context) (ServerInfo, error) { if err := s.resolveHomeDatabase(ctx); err != nil { - return nil, wrapError(err) + return nil, errorutil.WrapError(err) } servers, err := s.getServers(ctx, idb.ReadMode) if err != nil { - return nil, wrapError(err) + return nil, errorutil.WrapError(err) } conn, err := s.pool.Borrow(ctx, servers, s.driverConfig.ConnectionAcquisitionTimeout != 0, s.config.BoltLogger, 0) if err != nil { - return nil, wrapError(err) + return nil, errorutil.WrapError(err) } defer s.pool.Return(ctx, conn) return &simpleServerInfo{ diff --git a/neo4j/session_with_context_test.go b/neo4j/session_with_context_test.go index 7fe36941..388dfb28 100644 --- a/neo4j/session_with_context_test.go +++ b/neo4j/session_with_context_test.go @@ -24,6 +24,7 @@ import ( "errors" "fmt" idb "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/db" + "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/errorutil" "io" "reflect" "sync" @@ -31,7 +32,6 @@ import ( "time" "github.com/neo4j/neo4j-go-driver/v5/neo4j/db" - "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/retry" . "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/testutil" "github.com/neo4j/neo4j-go-driver/v5/neo4j/log" ) @@ -50,7 +50,7 @@ func TestSession(outer *testing.T) { } createSession := func() (*RouterFake, *PoolFake, *sessionWithContext) { - conf := Config{MaxTransactionRetryTime: 3 * time.Millisecond} + conf := Config{MaxTransactionRetryTime: 3 * time.Millisecond, MaxConnectionPoolSize: 100} router := RouterFake{} pool := PoolFake{} sessConfig := SessionConfig{AccessMode: AccessModeRead, BoltLogger: boltLogger} @@ -125,7 +125,7 @@ func TestSession(outer *testing.T) { assertCleanSessionState(t, sess) }) - // Check that sesssion is in clean state after connection fails to commit. + // Check that session is in clean state after connection fails to commit. inner.Run("Failed commit", func(t *testing.T) { _, pool, sess := createSession() pool.BorrowConn = &ConnFake{Alive: false, TxCommitErr: io.EOF} @@ -139,7 +139,7 @@ func TestSession(outer *testing.T) { } // Should not be a TransactionExecutionLimitError here AssertTrue(t, IsConnectivityError(err)) - AssertSameType(t, err.(*ConnectivityError).inner, &retry.CommitFailedDeadError{}) + AssertSameType(t, err.(*ConnectivityError).Inner, &errorutil.CommitFailedDeadError{}) assertCleanSessionState(t, sess) }) diff --git a/neo4j/test-integration/transaction_test.go b/neo4j/test-integration/transaction_test.go index 9f8231e2..6588cc3e 100644 --- a/neo4j/test-integration/transaction_test.go +++ b/neo4j/test-integration/transaction_test.go @@ -59,7 +59,6 @@ func TestTransaction(outer *testing.T) { times := 0 _, err = session.ExecuteWrite(ctx, func(transaction neo4j.ManagedTransaction) (any, error) { times++ - time.Sleep(1 * time.Second) return nil, transientError }) @@ -71,7 +70,6 @@ func TestTransaction(outer *testing.T) { times := 0 _, err = session.ExecuteRead(ctx, func(transaction neo4j.ManagedTransaction) (any, error) { times++ - time.Sleep(1 * time.Second) return nil, transientError }) diff --git a/neo4j/transaction_with_context.go b/neo4j/transaction_with_context.go index ded10668..bdf93d0d 100644 --- a/neo4j/transaction_with_context.go +++ b/neo4j/transaction_with_context.go @@ -22,6 +22,7 @@ package neo4j import ( "context" "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/db" + "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/errorutil" ) // ManagedTransaction represents a transaction managed by the driver and operated on by the user, via transaction functions @@ -71,7 +72,7 @@ func (tx *explicitTransaction) Run(ctx context.Context, cypher string, tx.err = err tx.runFailed = true tx.onClosed(tx) - return nil, wrapError(tx.err) + return nil, errorutil.WrapError(tx.err) } // no result consumption hook here since bookmarks are sent after commit, not after pulling results return newResultWithContext(tx.conn, stream, cypher, params, nil), nil @@ -88,7 +89,7 @@ func (tx *explicitTransaction) Commit(ctx context.Context) error { tx.err = tx.conn.TxCommit(ctx, tx.txHandle) tx.done = true tx.onClosed(tx) - return wrapError(tx.err) + return errorutil.WrapError(tx.err) } func (tx *explicitTransaction) Close(ctx context.Context) error { @@ -115,7 +116,7 @@ func (tx *explicitTransaction) Rollback(ctx context.Context) error { } tx.done = true tx.onClosed(tx) - return wrapError(tx.err) + return errorutil.WrapError(tx.err) } func (tx *explicitTransaction) legacy() Transaction { @@ -134,7 +135,7 @@ type managedTransaction struct { func (tx *managedTransaction) Run(ctx context.Context, cypher string, params map[string]any) (ResultWithContext, error) { stream, err := tx.conn.RunTx(ctx, tx.txHandle, db.Command{Cypher: cypher, Params: params, FetchSize: tx.fetchSize}) if err != nil { - return nil, wrapError(err) + return nil, errorutil.WrapError(err) } // no result consumption hook here since bookmarks are sent after commit, not after pulling results return newResultWithContext(tx.conn, stream, cypher, params, nil), nil