From 39d4411dee95f43db6ce9a5b782699314076cc04 Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Fri, 18 Aug 2023 09:48:03 +0200 Subject: [PATCH] Catching up with TestKit (#520) 1. Add support for routing table specific TestKit messages (retrieving routing tables and forcing updates) 2. Improve logging in the TestKit backend (all logs – driver and bolt – go both to stdout and TestKit) for easier debugging 3. Overhaul how/when the driver drops servers from the cached routing table: * Move logic of deactivating servers on failures into the pool. This functionality should not only be present when using transaction functions. * Drop writers on certain error codes: `Neo.ClientError.Cluster.NotALeader` and `Neo.ClientError.General.ForbiddenOnReadOnlyDatabase` 4. To simplify the code and avoid deadlocks or inconsistent driver states, both pool and router use blocking locks now. However, they will never perform IO while holding the lock, removing the need for lock acquisition timeouts. * Reformat feature list to reflect TestKit's state * Clean up test skips * TestKit backend: improve logging Log everything to stdout + TestKit socket. * Drop servers from routing table on IO failure * Remove writer from routing table on certain errors * Refactor responsibilities and locking The retry state is no longer responsible for invalidating servers on broken connections. The pool and/or routing logic should take care of that. This is necessary and logical since invalidation should happen regardless of which API (session.Run, transaction.Run, ...) is used. Pool and router no longer do any IO while holding locks. Hence, the locks can be turned into blocking locks without risking blocking for too long. 
* Add support for RT related TestKit messages * Remove unused context parameters --------- Signed-off-by: Rouven Bauer Co-authored-by: Florent Biville --- neo4j/directrouter.go | 26 +- neo4j/driver_with_context.go | 19 +- neo4j/driver_with_context_testkit.go | 58 +++- neo4j/internal/bolt/bolt3.go | 26 +- neo4j/internal/bolt/bolt3_test.go | 4 +- neo4j/internal/bolt/bolt4.go | 50 +-- neo4j/internal/bolt/bolt4_test.go | 4 +- neo4j/internal/bolt/bolt5.go | 52 +-- neo4j/internal/bolt/bolt5_test.go | 6 +- neo4j/internal/bolt/bolt_test.go | 8 +- neo4j/internal/bolt/connect.go | 10 +- neo4j/internal/bolt/connections.go | 18 +- neo4j/internal/bolt/hydratedehydrate_test.go | 17 +- neo4j/internal/bolt/message_queue.go | 24 +- neo4j/internal/bolt/outgoing.go | 25 +- neo4j/internal/bolt/outgoing_test.go | 30 +- neo4j/internal/connector/connector.go | 14 +- neo4j/internal/connector/connector_test.go | 16 +- neo4j/internal/db/connection.go | 1 + neo4j/internal/pool/no_test.go | 27 +- neo4j/internal/pool/pool.go | 223 ++++++------- neo4j/internal/pool/pool_test.go | 328 ++++++++++--------- neo4j/internal/pool/server.go | 16 +- neo4j/internal/pool/server_test.go | 20 +- neo4j/internal/retry/state.go | 27 +- neo4j/internal/retry/state_test.go | 62 +--- neo4j/internal/router/no_test.go | 10 +- neo4j/internal/router/readtable.go | 6 +- neo4j/internal/router/router.go | 183 ++++++----- neo4j/internal/router/router_test.go | 12 +- neo4j/internal/router/router_testkit.go | 28 ++ neo4j/internal/testutil/connfake.go | 4 + neo4j/internal/testutil/poolfake.go | 8 +- neo4j/internal/testutil/routerfake.go | 36 +- neo4j/session_with_context.go | 57 ++-- neo4j/test-integration/dbconn_test.go | 15 +- neo4j/test-integration/driver_test.go | 12 +- testkit-backend/backend.go | 138 ++++++-- testkit-backend/streamlogger.go | 42 ++- 39 files changed, 943 insertions(+), 719 deletions(-) create mode 100644 neo4j/internal/router/router_testkit.go diff --git a/neo4j/directrouter.go b/neo4j/directrouter.go 
index fd0cfb1a..234efa5b 100644 --- a/neo4j/directrouter.go +++ b/neo4j/directrouter.go @@ -30,38 +30,32 @@ type directRouter struct { address string } -func (r *directRouter) InvalidateWriter(context.Context, string, string) error { - return nil -} +func (r *directRouter) InvalidateWriter(string, string) {} -func (r *directRouter) InvalidateReader(context.Context, string, string) error { - return nil -} +func (r *directRouter) InvalidateReader(string, string) {} + +func (r *directRouter) InvalidateServer(string) {} func (r *directRouter) GetOrUpdateReaders(context.Context, func(context.Context) ([]string, error), string, *db.ReAuthToken, log.BoltLogger) ([]string, error) { return []string{r.address}, nil } -func (r *directRouter) Readers(context.Context, string) ([]string, error) { - return []string{r.address}, nil +func (r *directRouter) Readers(string) []string { + return []string{r.address} } func (r *directRouter) GetOrUpdateWriters(context.Context, func(context.Context) ([]string, error), string, *db.ReAuthToken, log.BoltLogger) ([]string, error) { return []string{r.address}, nil } -func (r *directRouter) Writers(context.Context, string) ([]string, error) { - return []string{r.address}, nil +func (r *directRouter) Writers(string) []string { + return []string{r.address} } func (r *directRouter) GetNameOfDefaultDatabase(context.Context, []string, string, *db.ReAuthToken, log.BoltLogger) (string, error) { return db.DefaultDatabase, nil } -func (r *directRouter) Invalidate(context.Context, string) error { - return nil -} +func (r *directRouter) Invalidate(string) {} -func (r *directRouter) CleanUp(context.Context) error { - return nil -} +func (r *directRouter) CleanUp() {} diff --git a/neo4j/driver_with_context.go b/neo4j/driver_with_context.go index c95974d5..0c941875 100644 --- a/neo4j/driver_with_context.go +++ b/neo4j/driver_with_context.go @@ -254,6 +254,8 @@ func NewDriverWithContext(target string, auth auth.TokenManager, configurers ... 
d.router = router.New(address, routersResolver, routingContext, d.pool, d.log, d.logId, &d.now) } + d.pool.SetRouter(d.router) + d.log.Infof(log.Driver, d.logId, "Created { target: %s }", address) return &d, nil } @@ -300,20 +302,21 @@ type sessionRouter interface { // they should not be called when it is not needed (e.g. when a routing table is cached) GetOrUpdateReaders(ctx context.Context, bookmarks func(context.Context) ([]string, error), database string, auth *idb.ReAuthToken, boltLogger log.BoltLogger) ([]string, error) // Readers returns the list of servers that can serve reads on the requested database. - Readers(ctx context.Context, database string) ([]string, error) + Readers(database string) []string // GetOrUpdateWriters returns the list of servers that can serve writes on the requested database. // note: bookmarks are lazily supplied, see Readers documentation to learn why GetOrUpdateWriters(ctx context.Context, bookmarks func(context.Context) ([]string, error), database string, auth *idb.ReAuthToken, boltLogger log.BoltLogger) ([]string, error) // Writers returns the list of servers that can serve writes on the requested database. - Writers(ctx context.Context, database string) ([]string, error) + Writers(database string) []string // GetNameOfDefaultDatabase returns the name of the default database for the specified user. // The correct database name is needed when requesting readers or writers. 
// the bookmarks are eagerly provided since this method always fetches a new routing table GetNameOfDefaultDatabase(ctx context.Context, bookmarks []string, user string, auth *idb.ReAuthToken, boltLogger log.BoltLogger) (string, error) - Invalidate(ctx context.Context, database string) error - CleanUp(ctx context.Context) error - InvalidateWriter(ctx context.Context, name string, server string) error - InvalidateReader(ctx context.Context, name string, server string) error + Invalidate(db string) + CleanUp() + InvalidateWriter(db string, server string) + InvalidateReader(db string, server string) + InvalidateServer(server string) } type driverWithContext struct { @@ -394,9 +397,7 @@ func (d *driverWithContext) Close(ctx context.Context) error { defer d.mut.Unlock() // Safeguard against closing more than once if d.pool != nil { - if err := d.pool.Close(ctx); err != nil { - return err - } + d.pool.Close(ctx) d.pool = nil d.log.Infof(log.Driver, d.logId, "Closed") } diff --git a/neo4j/driver_with_context_testkit.go b/neo4j/driver_with_context_testkit.go index f30c08bb..94fb1c98 100644 --- a/neo4j/driver_with_context_testkit.go +++ b/neo4j/driver_with_context_testkit.go @@ -1,5 +1,3 @@ -//go:build internal_testkit - /* * Copyright (c) "Neo4j" * Neo4j Sweden AB [https://neo4j.com] @@ -10,18 +8,30 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ +//go:build internal_testkit + package neo4j -import "time" +import ( + "context" + "fmt" + idb "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/db" + "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/errorutil" + "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/router" + "github.com/neo4j/neo4j-go-driver/v5/neo4j/log" + "time" +) + +type RoutingTable = idb.RoutingTable func SetTimer(d DriverWithContext, timer func() time.Time) { driver := d.(*driverWithContext) @@ -32,3 +42,33 @@ func ResetTime(d DriverWithContext) { driver := d.(*driverWithContext) driver.now = time.Now } + +func ForceRoutingTableUpdate(d DriverWithContext, database string, bookmarks []string, logger log.BoltLogger) error { + driver := d.(*driverWithContext) + ctx := context.Background() + driver.router.Invalidate(database) + getBookmarks := func(context.Context) ([]string, error) { + return bookmarks, nil + } + auth := &idb.ReAuthToken{ + Manager: driver.auth, + FromSession: false, + ForceReAuth: false, + } + _, err := driver.router.GetOrUpdateReaders(ctx, getBookmarks, database, auth, logger) + if err != nil { + return errorutil.WrapError(err) + } + _, err = driver.router.GetOrUpdateWriters(ctx, getBookmarks, database, auth, logger) + return errorutil.WrapError(err) +} + +func GetRoutingTable(d DriverWithContext, database string) (*RoutingTable, error) { + driver := d.(*driverWithContext) + router, ok := driver.router.(*router.Router) + if !ok { + return nil, fmt.Errorf("GetRoutingTable is only supported for direct drivers") + } + table := router.GetTable(database) + return table, nil +} diff --git a/neo4j/internal/bolt/bolt3.go b/neo4j/internal/bolt/bolt3.go 
index 41ca4101..33ae9558 100644 --- a/neo4j/internal/bolt/bolt3.go +++ b/neo4j/internal/bolt/bolt3.go @@ -95,14 +95,14 @@ type bolt3 struct { auth map[string]any authManager auth.TokenManager resetAuth bool - onNeo4jError Neo4jErrorCallback + errorListener ConnectionErrorListener now *func() time.Time } func NewBolt3( serverName string, conn net.Conn, - callback Neo4jErrorCallback, + errorListener ConnectionErrorListener, timer *func() time.Time, logger log.Logger, boltLog log.BoltLogger, @@ -120,16 +120,22 @@ func NewBolt3( }, connReadTimeout: -1, }, - birthDate: now, - idleDate: now, - log: logger, - onNeo4jError: callback, - now: timer, + birthDate: now, + idleDate: now, + log: logger, + errorListener: errorListener, + now: timer, } b.out = &outgoing{ chunker: newChunker(), packer: packstream.Packer{}, - onErr: func(err error) { + onPackErr: func(err error) { + if b.err == nil { + b.err = err + } + b.state = bolt3_dead + }, + onIoErr: func(ctx context.Context, err error) { if b.err == nil { b.err = err } @@ -181,7 +187,7 @@ func (b *bolt3) receiveSuccess(ctx context.Context) *success { } else { b.log.Error(log.Bolt3, b.logId, message) } - if err := b.onNeo4jError(ctx, b, message); err != nil { + if err := b.errorListener.OnNeo4jError(ctx, b, message); err != nil { b.err = errorutil.CombineErrors(message, b.err) } return nil @@ -662,7 +668,7 @@ func (b *bolt3) receiveNext(ctx context.Context) (*db.Record, *db.Summary, error } else { b.log.Error(log.Bolt3, b.logId, message) } - if err := b.onNeo4jError(ctx, b, message); err != nil { + if err := b.errorListener.OnNeo4jError(ctx, b, message); err != nil { return nil, nil, errorutil.CombineErrors(message, err) } return nil, nil, message diff --git a/neo4j/internal/bolt/bolt3_test.go b/neo4j/internal/bolt/bolt3_test.go index ff494892..82f39199 100644 --- a/neo4j/internal/bolt/bolt3_test.go +++ b/neo4j/internal/bolt/bolt3_test.go @@ -114,7 +114,7 @@ func TestBolt3(outer *testing.T) { auth, "007", nil, - 
noopOnNeo4jError, + noopErrorListener{}, logger, nil, idb.NotificationConfig{}, @@ -168,7 +168,7 @@ func TestBolt3(outer *testing.T) { auth, "007", nil, - noopOnNeo4jError, + noopErrorListener{}, logger, nil, idb.NotificationConfig{}, diff --git a/neo4j/internal/bolt/bolt4.go b/neo4j/internal/bolt/bolt4.go index 9532e647..e28d6168 100644 --- a/neo4j/internal/bolt/bolt4.go +++ b/neo4j/internal/bolt/bolt4.go @@ -111,30 +111,30 @@ type bolt4 struct { auth map[string]any authManager auth.TokenManager resetAuth bool - onNeo4jError Neo4jErrorCallback + errorListener ConnectionErrorListener now *func() time.Time } func NewBolt4( serverName string, conn net.Conn, - callback Neo4jErrorCallback, + errorListener ConnectionErrorListener, timer *func() time.Time, logger log.Logger, boltLog log.BoltLogger, ) *bolt4 { now := (*timer)() b := &bolt4{ - state: bolt4_unauthorized, - conn: conn, - serverName: serverName, - birthDate: now, - idleDate: now, - log: logger, - streams: openstreams{}, - lastQid: -1, - onNeo4jError: callback, - now: timer, + state: bolt4_unauthorized, + conn: conn, + serverName: serverName, + birthDate: now, + idleDate: now, + log: logger, + streams: openstreams{}, + lastQid: -1, + errorListener: errorListener, + now: timer, } b.queue = newMessageQueue( conn, @@ -149,11 +149,12 @@ func NewBolt4( &outgoing{ chunker: newChunker(), packer: packstream.Packer{}, - onErr: func(err error) { b.setError(err, true) }, + onPackErr: func(err error) { b.setError(err, true) }, + onIoErr: b.onIoError, boltLogger: boltLog, }, b.onNextMessage, - b.onNextMessageError, + b.onIoError, ) return b @@ -938,6 +939,10 @@ func (b *bolt4) SelectDatabase(database string) { b.databaseName = database } +func (b *bolt4) Database() string { + return b.databaseName +} + func (b *bolt4) SetBoltLogger(boltLogger log.BoltLogger) { b.queue.setBoltLogger(boltLogger) } @@ -1079,7 +1084,7 @@ func (b *bolt4) resetResponseHandler() responseHandler { b.state = bolt4_ready }, onFailure: func(ctx 
context.Context, failure *db.Neo4jError) { - _ = b.onNeo4jError(ctx, b, failure) + _ = b.errorListener.OnNeo4jError(ctx, b, failure) b.state = bolt4_dead }, } @@ -1126,19 +1131,24 @@ func (b *bolt4) onNextMessage() { b.idleDate = (*b.now)() } -func (b *bolt4) onNextMessageError(err error) { - b.setError(err, true) -} - func (b *bolt4) onFailure(ctx context.Context, failure *db.Neo4jError) { var err error err = failure - if callbackErr := b.onNeo4jError(ctx, b, failure); callbackErr != nil { + if callbackErr := b.errorListener.OnNeo4jError(ctx, b, failure); callbackErr != nil { err = errorutil.CombineErrors(failure, callbackErr) } b.setError(err, isFatalError(failure)) } +func (b *bolt4) onIoError(ctx context.Context, err error) { + if b.state != bolt4_failed && b.state != bolt4_dead { + // Don't call callback when connections break after sending RESET. + // The server chooses to close the connection on some errors. + b.errorListener.OnIoError(ctx, b, err) + } + b.setError(err, true) +} + const readTimeoutHintName = "connection.recv_timeout_seconds" func (b *bolt4) initializeReadTimeoutHint(hints map[string]any) { diff --git a/neo4j/internal/bolt/bolt4_test.go b/neo4j/internal/bolt/bolt4_test.go index 8c740154..1dc55426 100644 --- a/neo4j/internal/bolt/bolt4_test.go +++ b/neo4j/internal/bolt/bolt4_test.go @@ -117,7 +117,7 @@ func TestBolt4(outer *testing.T) { auth, "007", nil, - noopOnNeo4jError, + noopErrorListener{}, logger, nil, idb.NotificationConfig{}, @@ -322,7 +322,7 @@ func TestBolt4(outer *testing.T) { auth, "007", nil, - noopOnNeo4jError, + noopErrorListener{}, logger, nil, idb.NotificationConfig{}, diff --git a/neo4j/internal/bolt/bolt5.go b/neo4j/internal/bolt/bolt5.go index ad69aff8..a658caca 100644 --- a/neo4j/internal/bolt/bolt5.go +++ b/neo4j/internal/bolt/bolt5.go @@ -113,30 +113,30 @@ type bolt5 struct { auth map[string]any authManager auth.TokenManager resetAuth bool - onNeo4jError Neo4jErrorCallback + errorListener ConnectionErrorListener now 
*func() time.Time } func NewBolt5( serverName string, conn net.Conn, - callback Neo4jErrorCallback, + errorListener ConnectionErrorListener, timer *func() time.Time, logger log.Logger, boltLog log.BoltLogger, ) *bolt5 { now := (*timer)() b := &bolt5{ - state: bolt5Unauthorized, - conn: conn, - serverName: serverName, - birthDate: now, - idleDate: now, - log: logger, - streams: openstreams{}, - lastQid: -1, - onNeo4jError: callback, - now: timer, + state: bolt5Unauthorized, + conn: conn, + serverName: serverName, + birthDate: now, + idleDate: now, + log: logger, + streams: openstreams{}, + lastQid: -1, + errorListener: errorListener, + now: timer, } b.queue = newMessageQueue( conn, @@ -152,12 +152,13 @@ func NewBolt5( &outgoing{ chunker: newChunker(), packer: packstream.Packer{}, - onErr: func(err error) { b.setError(err, true) }, + onPackErr: func(err error) { b.setError(err, true) }, + onIoErr: b.onIoError, boltLogger: boltLog, useUtc: true, }, b.onNextMessage, - b.onNextMessageError, + b.onIoError, ) return b } @@ -921,19 +922,23 @@ func (b *bolt5) reAuth(ctx context.Context, auth *idb.ReAuthToken) error { func (b *bolt5) Close(ctx context.Context) { b.log.Infof(log.Bolt5, b.logId, "Close") if b.state != bolt5Dead { + b.state = bolt5Dead b.queue.appendGoodbye() b.queue.send(ctx) } if err := b.conn.Close(); err != nil { b.log.Warnf(log.Driver, b.serverName, "could not close underlying socket") } - b.state = bolt5Dead } func (b *bolt5) SelectDatabase(database string) { b.databaseName = database } +func (b *bolt5) Database() string { + return b.databaseName +} + func (b *bolt5) Version() db.ProtocolVersion { return db.ProtocolVersion{ Major: 5, @@ -1077,7 +1082,7 @@ func (b *bolt5) resetResponseHandler() responseHandler { b.state = bolt5Ready }, onFailure: func(ctx context.Context, failure *db.Neo4jError) { - _ = b.onNeo4jError(ctx, b, failure) + _ = b.errorListener.OnNeo4jError(ctx, b, failure) b.state = bolt5Dead }, } @@ -1111,19 +1116,24 @@ func (b *bolt5) 
onNextMessage() { b.idleDate = (*b.now)() } -func (b *bolt5) onNextMessageError(err error) { - b.setError(err, true) -} - func (b *bolt5) onFailure(ctx context.Context, failure *db.Neo4jError) { var err error err = failure - if callbackErr := b.onNeo4jError(ctx, b, failure); callbackErr != nil { + if callbackErr := b.errorListener.OnNeo4jError(ctx, b, failure); callbackErr != nil { err = errorutil.CombineErrors(callbackErr, failure) } b.setError(err, isFatalError(failure)) } +func (b *bolt5) onIoError(ctx context.Context, err error) { + if b.state != bolt5Failed && b.state != bolt5Dead { + // Don't call callback when connections break after sending RESET. + // The server chooses to close the connection on some errors. + b.errorListener.OnIoError(ctx, b, err) + } + b.setError(err, true) +} + func (b *bolt5) initializeReadTimeoutHint(hints map[string]any) { readTimeoutHint, ok := hints[readTimeoutHintName] if !ok { diff --git a/neo4j/internal/bolt/bolt5_test.go b/neo4j/internal/bolt/bolt5_test.go index f147e020..ce2a971a 100644 --- a/neo4j/internal/bolt/bolt5_test.go +++ b/neo4j/internal/bolt/bolt5_test.go @@ -118,7 +118,7 @@ func TestBolt5(outer *testing.T) { auth, "007", nil, - noopOnNeo4jError, + noopErrorListener{}, logger, nil, idb.NotificationConfig{}, @@ -414,7 +414,7 @@ func TestBolt5(outer *testing.T) { auth, "007", nil, - noopOnNeo4jError, + noopErrorListener{}, logger, nil, idb.NotificationConfig{}, @@ -451,7 +451,7 @@ func TestBolt5(outer *testing.T) { auth, "007", nil, - noopOnNeo4jError, + noopErrorListener{}, logger, nil, idb.NotificationConfig{}, diff --git a/neo4j/internal/bolt/bolt_test.go b/neo4j/internal/bolt/bolt_test.go index 56de639c..e8b79908 100644 --- a/neo4j/internal/bolt/bolt_test.go +++ b/neo4j/internal/bolt/bolt_test.go @@ -25,6 +25,12 @@ import ( idb "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/db" ) -func noopOnNeo4jError(context.Context, idb.Connection, *db.Neo4jError) error { +type noopErrorListener struct{} + +func (n 
noopErrorListener) OnNeo4jError(context.Context, idb.Connection, *db.Neo4jError) error { return nil } + +func (n noopErrorListener) OnIoError(context.Context, idb.Connection, error) {} + +func (n noopErrorListener) OnDialError(context.Context, string, error) {} diff --git a/neo4j/internal/bolt/connect.go b/neo4j/internal/bolt/connect.go index e1b08b5f..17bb6397 100644 --- a/neo4j/internal/bolt/connect.go +++ b/neo4j/internal/bolt/connect.go @@ -54,7 +54,7 @@ func Connect(ctx context.Context, auth *db.ReAuthToken, userAgent string, routingContext map[string]string, - callback Neo4jErrorCallback, + errorListener ConnectionErrorListener, logger log.Logger, boltLogger log.BoltLogger, notificationConfig db.NotificationConfig, @@ -74,6 +74,7 @@ func Connect(ctx context.Context, } _, err := racing.NewRacingWriter(conn).Write(ctx, handshake) if err != nil { + errorListener.OnDialError(ctx, serverName, err) return nil, err } @@ -81,6 +82,7 @@ func Connect(ctx context.Context, buf := make([]byte, 4) _, err = racing.NewRacingReader(conn).ReadFull(ctx, buf) if err != nil { + errorListener.OnDialError(ctx, serverName, err) return nil, err } @@ -93,11 +95,11 @@ func Connect(ctx context.Context, var boltConn db.Connection switch major { case 3: - boltConn = NewBolt3(serverName, conn, callback, timer, logger, boltLogger) + boltConn = NewBolt3(serverName, conn, errorListener, timer, logger, boltLogger) case 4: - boltConn = NewBolt4(serverName, conn, callback, timer, logger, boltLogger) + boltConn = NewBolt4(serverName, conn, errorListener, timer, logger, boltLogger) case 5: - boltConn = NewBolt5(serverName, conn, callback, timer, logger, boltLogger) + boltConn = NewBolt5(serverName, conn, errorListener, timer, logger, boltLogger) case 0: return nil, fmt.Errorf("server did not accept any of the requested Bolt versions (%#v)", versions) default: diff --git a/neo4j/internal/bolt/connections.go b/neo4j/internal/bolt/connections.go index 0d3caf07..16218409 100644 --- 
a/neo4j/internal/bolt/connections.go +++ b/neo4j/internal/bolt/connections.go @@ -8,13 +8,13 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ package bolt @@ -27,7 +27,11 @@ import ( "net" ) -type Neo4jErrorCallback func(context.Context, idb.Connection, *db.Neo4jError) error +type ConnectionErrorListener interface { + OnNeo4jError(context.Context, idb.Connection, *db.Neo4jError) error + OnIoError(context.Context, idb.Connection, error) + OnDialError(context.Context, string, error) +} func handleTerminatedContextError(err error, connection net.Conn) error { if !contextTerminatedErr(err) { diff --git a/neo4j/internal/bolt/hydratedehydrate_test.go b/neo4j/internal/bolt/hydratedehydrate_test.go index ad735a04..9db79ff5 100644 --- a/neo4j/internal/bolt/hydratedehydrate_test.go +++ b/neo4j/internal/bolt/hydratedehydrate_test.go @@ -8,13 +8,13 @@ * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ package bolt @@ -39,9 +39,12 @@ func TestDehydrateHydrate(ot *testing.T) { out := &outgoing{ chunker: newChunker(), packer: packstream.Packer{}, - onErr: func(err error) { + onPackErr: func(err error) { ot.Fatalf("Should be no dehydration errors in this test: %s", err) }, + onIoErr: func(_ context.Context, err error) { + ot.Fatalf("Should be no io errors in this test: %s", err) + }, } serv, cli := net.Pipe() defer func() { diff --git a/neo4j/internal/bolt/message_queue.go b/neo4j/internal/bolt/message_queue.go index 4e52f5f8..de8e4906 100644 --- a/neo4j/internal/bolt/message_queue.go +++ b/neo4j/internal/bolt/message_queue.go @@ -8,13 +8,13 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ package bolt @@ -36,23 +36,23 @@ type messageQueue struct { targetConnection net.Conn err error - onNextMessage func() - onNextMessageErr func(error) + onNextMessage func() + onIoErr func(context.Context, error) } func newMessageQueue( target net.Conn, in *incoming, out *outgoing, onNext func(), - onNextErr func(error)) messageQueue { - + onIoErr func(context.Context, error), +) messageQueue { return messageQueue{ in: in, out: out, handlers: list.List{}, targetConnection: target, onNextMessage: onNext, - onNextMessageErr: onNextErr, + onIoErr: onIoErr, } } @@ -205,7 +205,7 @@ func (q *messageQueue) receiveMsg(ctx context.Context) any { msg, err := q.in.next(ctx, q.targetConnection) q.err = err if err != nil { - q.onNextMessageErr(err) + q.onIoErr(ctx, err) } else { q.onNextMessage() } diff --git a/neo4j/internal/bolt/outgoing.go b/neo4j/internal/bolt/outgoing.go index eab6be0d..771467c7 100644 --- a/neo4j/internal/bolt/outgoing.go +++ b/neo4j/internal/bolt/outgoing.go @@ -8,13 +8,13 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ package bolt @@ -35,7 +35,8 @@ import ( type outgoing struct { chunker chunker packer packstream.Packer - onErr func(err error) + onPackErr func(error) + onIoErr func(context.Context, error) boltLogger log.BoltLogger logId string useUtc bool @@ -51,7 +52,7 @@ func (o *outgoing) end() { o.chunker.buf = buf o.chunker.endMessage() if err != nil { - o.onErr(err) + o.onPackErr(err) } } @@ -259,7 +260,7 @@ func (o *outgoing) appendX(tag byte, fields ...any) { func (o *outgoing) send(ctx context.Context, wr io.Writer) { err := o.chunker.send(ctx, wr) if err != nil { - o.onErr(err) + o.onIoErr(ctx, err) } } @@ -347,7 +348,7 @@ func (o *outgoing) packStruct(x any) { o.packer.Int64(v.Seconds) o.packer.Int(v.Nanos) default: - o.onErr(&db.UnsupportedTypeError{Type: reflect.TypeOf(x)}) + o.onPackErr(&db.UnsupportedTypeError{Type: reflect.TypeOf(x)}) } } @@ -416,7 +417,7 @@ func (o *outgoing) packX(x any) { default: t := reflect.TypeOf(x) if t.Key().Kind() != reflect.String { - o.onErr(&db.UnsupportedTypeError{Type: reflect.TypeOf(x)}) + o.onPackErr(&db.UnsupportedTypeError{Type: reflect.TypeOf(x)}) return } o.packer.MapHeader(v.Len()) @@ -427,7 +428,7 @@ func (o *outgoing) packX(x any) { } } default: - o.onErr(&db.UnsupportedTypeError{Type: reflect.TypeOf(x)}) + o.onPackErr(&db.UnsupportedTypeError{Type: reflect.TypeOf(x)}) } } diff --git a/neo4j/internal/bolt/outgoing_test.go b/neo4j/internal/bolt/outgoing_test.go index 2eb5bc6f..af6fe096 100644 --- a/neo4j/internal/bolt/outgoing_test.go +++ b/neo4j/internal/bolt/outgoing_test.go @@ -8,13 +8,13 @@ * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ package bolt @@ -91,9 +91,12 @@ func TestOutgoing(ot *testing.T) { // Utility to unpack through dechunking and a custom build func dechunkAndUnpack := func(t *testing.T, build func(*testing.T, *outgoing)) any { out := &outgoing{ - chunker: newChunker(), - packer: packstream.Packer{}, - onErr: func(e error) { err = e }, + chunker: newChunker(), + packer: packstream.Packer{}, + onPackErr: func(e error) { err = e }, + onIoErr: func(_ context.Context, err error) { + ot.Fatalf("Should be no io errors in this test: %s", err) + }, } serv, cli := net.Pipe() defer func() { @@ -584,9 +587,12 @@ func TestOutgoing(ot *testing.T) { for _, c := range paramErrorCases { var err error out := &outgoing{ - chunker: newChunker(), - packer: packstream.Packer{}, - onErr: func(e error) { err = e }, + chunker: newChunker(), + packer: packstream.Packer{}, + onPackErr: func(e error) { err = e }, + onIoErr: func(_ context.Context, err error) { + ot.Fatalf("Should be no io errors in this test: %s", err) + }, } ot.Run(c.name, func(t *testing.T) { out.begin() diff --git a/neo4j/internal/connector/connector.go b/neo4j/internal/connector/connector.go index e89186d1..becb01d3 100644 --- 
a/neo4j/internal/connector/connector.go +++ b/neo4j/internal/connector/connector.go @@ -50,7 +50,7 @@ func (c Connector) Connect( ctx context.Context, address string, auth *db.ReAuthToken, - callback bolt.Neo4jErrorCallback, + errorListener bolt.ConnectionErrorListener, boltLogger log.BoltLogger, ) (connection db.Connection, err error) { if c.SupplyConnection == nil { @@ -59,13 +59,14 @@ func (c Connector) Connect( conn, err := c.SupplyConnection(ctx, address) if err != nil { + errorListener.OnDialError(ctx, address, err) return nil, err } defer func() { if err != nil && connection == nil { if err := conn.Close(); err != nil { - c.Log.Warnf(log.Driver, address, "could not close socket after failed connection") + c.Log.Warnf(log.Driver, address, "could not close socket after failed connection %s", err) } } }() @@ -84,7 +85,7 @@ func (c Connector) Connect( auth, c.Config.UserAgent, c.RoutingContext, - callback, + errorListener, c.Log, boltLogger, notificationConfig, @@ -99,6 +100,7 @@ func (c Connector) Connect( // TLS requested, continue with handshake serverName, _, err := net.SplitHostPort(address) if err != nil { + errorListener.OnDialError(ctx, address, err) return nil, err } tlsConn := tls.Client(conn, c.tlsConfig(serverName)) @@ -108,7 +110,9 @@ func (c Connector) Connect( // Give a bit nicer error message err = errors.New("remote end closed the connection, check that TLS is enabled on the server") } - return nil, &errorutil.TlsError{Inner: err} + err = &errorutil.TlsError{Inner: err} + errorListener.OnDialError(ctx, address, err) + return nil, err } connection, err = bolt.Connect(ctx, address, @@ -116,7 +120,7 @@ func (c Connector) Connect( auth, c.Config.UserAgent, c.RoutingContext, - callback, + errorListener, c.Log, boltLogger, notificationConfig, diff --git a/neo4j/internal/connector/connector_test.go b/neo4j/internal/connector/connector_test.go index 6b8ff0b1..633a68e6 100644 --- a/neo4j/internal/connector/connector_test.go +++ 
b/neo4j/internal/connector/connector_test.go @@ -22,7 +22,9 @@ package connector_test import ( "context" "github.com/neo4j/neo4j-go-driver/v5/neo4j/config" + "github.com/neo4j/neo4j-go-driver/v5/neo4j/db" "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/connector" + idb "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/db" . "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/testutil" "io" "net" @@ -30,6 +32,16 @@ import ( "time" ) +type noopErrorListener struct{} + +func (n noopErrorListener) OnNeo4jError(context.Context, idb.Connection, *db.Neo4jError) error { + return nil +} + +func (n noopErrorListener) OnIoError(context.Context, idb.Connection, error) {} + +func (n noopErrorListener) OnDialError(context.Context, string, error) {} + func TestConnect(outer *testing.T) { outer.Parallel() @@ -49,7 +61,7 @@ func TestConnect(outer *testing.T) { Now: &timer, } - connection, err := connector.Connect(ctx, "irrelevant", nil, nil, nil) + connection, err := connector.Connect(ctx, "irrelevant", nil, noopErrorListener{}, nil) AssertNil(t, connection) AssertErrorMessageContains(t, err, "unsupported version 1.0") @@ -70,7 +82,7 @@ func TestConnect(outer *testing.T) { Now: &timer, } - connection, err := connector.Connect(ctx, "irrelevant", nil, nil, nil) + connection, err := connector.Connect(ctx, "irrelevant", nil, noopErrorListener{}, nil) AssertNil(t, connection) AssertError(t, err) diff --git a/neo4j/internal/db/connection.go b/neo4j/internal/db/connection.go index 203ad962..6a6556b8 100644 --- a/neo4j/internal/db/connection.go +++ b/neo4j/internal/db/connection.go @@ -184,4 +184,5 @@ type DatabaseSelector interface { // SelectDatabase should be called immediately after Reset. Not allowed to call multiple times with different // databases without a reset in-between. 
SelectDatabase(database string) + Database() string } diff --git a/neo4j/internal/pool/no_test.go b/neo4j/internal/pool/no_test.go index 28e2072b..25741ff7 100644 --- a/neo4j/internal/pool/no_test.go +++ b/neo4j/internal/pool/no_test.go @@ -8,19 +8,18 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
*/ package pool import ( - "context" "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/db" "testing" ) @@ -44,24 +43,18 @@ func assertNoConnection(t *testing.T, c db.Connection, err error) { } } -func assertNumberOfServers(t *testing.T, ctx context.Context, p *Pool, expectedNum int) { +func assertNumberOfServers(t *testing.T, p *Pool, expectedNum int) { t.Helper() - servers, err := p.getServers(ctx) - if err != nil { - t.Fatalf("Expected nil error, got: %v", err) - } + servers := p.getServers() actualNum := len(servers) if actualNum != expectedNum { t.Fatalf("Expected number of servers to be %d but was %d", expectedNum, actualNum) } } -func assertNumberOfIdle(t *testing.T, ctx context.Context, p *Pool, serverName string, expectedNum int) { +func assertNumberOfIdle(t *testing.T, p *Pool, serverName string, expectedNum int) { t.Helper() - servers, err := p.getServers(ctx) - if err != nil { - t.Fatalf("Expected nil error, got: %v", err) - } + servers := p.getServers() server := servers[serverName] if server == nil { t.Fatalf("Server %s not found", serverName) diff --git a/neo4j/internal/pool/pool.go b/neo4j/internal/pool/pool.go index b2a8a069..163b220e 100644 --- a/neo4j/internal/pool/pool.go +++ b/neo4j/internal/pool/pool.go @@ -25,14 +25,12 @@ package pool import ( "container/list" "context" - "fmt" "github.com/neo4j/neo4j-go-driver/v5/neo4j/config" "github.com/neo4j/neo4j-go-driver/v5/neo4j/db" "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/auth" "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/bolt" idb "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/db" "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/errorutil" - "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/racing" "math" "sort" "sync" @@ -45,7 +43,13 @@ import ( // Liveness checks are performed before a connection is deemed idle enough to be reset const DefaultLivenessCheckThreshold = math.MaxInt64 -type Connect func(context.Context, string, *idb.ReAuthToken, bolt.Neo4jErrorCallback, 
log.BoltLogger) (idb.Connection, error) +type Connect func(context.Context, string, *idb.ReAuthToken, bolt.ConnectionErrorListener, log.BoltLogger) (idb.Connection, error) + +type poolRouter interface { + InvalidateWriter(db string, server string) + InvalidateReader(db string, server string) + InvalidateServer(server string) +} type qitem struct { wakeup chan bool @@ -54,9 +58,10 @@ type qitem struct { type Pool struct { config *config.Config connect Connect + router poolRouter servers map[string]*server - serversMut racing.Mutex - queueMut racing.Mutex + serversMut sync.Mutex + queueMut sync.Mutex queue list.List now *func() time.Time closed bool @@ -75,9 +80,10 @@ func New(config *config.Config, connect Connect, logger log.Logger, logId string p := &Pool{ config: config, connect: connect, + router: nil, servers: make(map[string]*server), - serversMut: racing.NewMutex(), - queueMut: racing.NewMutex(), + serversMut: sync.Mutex{}, + queueMut: sync.Mutex{}, now: now, logId: logId, log: logger, @@ -86,11 +92,13 @@ func New(config *config.Config, connect Connect, logger log.Logger, logId string return p } -func (p *Pool) Close(ctx context.Context) error { +func (p *Pool) SetRouter(router poolRouter) { + p.router = router +} + +func (p *Pool) Close(ctx context.Context) { p.closed = true - if !p.queueMut.TryLock(ctx) { - return racing.LockTimeoutError("could not acquire queue lock in time when closing pool") - } + p.queueMut.Lock() for e := p.queue.Front(); e != nil; e = e.Next() { queuedRequest := e.Value.(*qitem) p.queue.Remove(e) @@ -98,48 +106,39 @@ func (p *Pool) Close(ctx context.Context) error { } p.queueMut.Unlock() // Go through each server and close all connections to it - if !p.serversMut.TryLock(ctx) { - return racing.LockTimeoutError("could not acquire server lock in time when closing pool") - } + p.serversMut.Lock() for n, s := range p.servers { s.closeAll(ctx) delete(p.servers, n) } p.serversMut.Unlock() p.log.Infof(log.Pool, p.logId, "Closed") - return 
nil } // For testing -func (p *Pool) queueSize(ctx context.Context) (int, error) { - if !p.queueMut.TryLock(ctx) { - return -1, fmt.Errorf("could not acquire queue lock in time when checking queue size") - } +func (p *Pool) queueSize() int { + p.queueMut.Lock() defer p.queueMut.Unlock() - return p.queue.Len(), nil + return p.queue.Len() } // For testing -func (p *Pool) getServers(ctx context.Context) (map[string]*server, error) { - if !p.serversMut.TryLock(ctx) { - return nil, fmt.Errorf("could not acquire server lock in time when getting servers") - } +func (p *Pool) getServers() map[string]*server { + p.serversMut.Lock() defer p.serversMut.Unlock() servers := make(map[string]*server) for k, v := range p.servers { servers[k] = v } - return servers, nil + return servers } // CleanUp prunes all old connection on all the servers, this makes sure that servers // gets removed from the map at some point in time. If there is a noticed // failed connect still active we should wait a while with removal to get // prioritization right. 
-func (p *Pool) CleanUp(ctx context.Context) error { - if !p.serversMut.TryLock(ctx) { - return fmt.Errorf("could not acquire server lock in time when cleaning up pool") - } +func (p *Pool) CleanUp(ctx context.Context) { + p.serversMut.Lock() defer p.serversMut.Unlock() now := (*p.now)() for n, s := range p.servers { @@ -148,17 +147,14 @@ func (p *Pool) CleanUp(ctx context.Context) error { delete(p.servers, n) } } - return nil } func (p *Pool) Now() time.Time { return (*p.now)() } -func (p *Pool) getPenaltiesForServers(ctx context.Context, serverNames []string) ([]serverPenalty, error) { - if !p.serversMut.TryLock(ctx) { - return nil, fmt.Errorf("could not acquire server lock in time when computing server penalties") - } +func (p *Pool) getPenaltiesForServers(ctx context.Context, serverNames []string) []serverPenalty { + p.serversMut.Lock() defer p.serversMut.Unlock() // Retrieve penalty for each server @@ -175,13 +171,11 @@ func (p *Pool) getPenaltiesForServers(ctx context.Context, serverNames []string) penalties[i].penalty = newConnectionPenalty } } - return penalties, nil + return penalties } func (p *Pool) tryAnyIdle(ctx context.Context, serverNames []string, idlenessThreshold time.Duration, auth *idb.ReAuthToken, logger log.BoltLogger) (idb.Connection, error) { - if !p.serversMut.TryLock(ctx) { - return nil, racing.LockTimeoutError("could not acquire server lock in time when getting idle connection") - } + p.serversMut.Lock() var unlock = new(sync.Once) defer unlock.Do(p.serversMut.Unlock) serverLoop: @@ -198,16 +192,12 @@ serverLoop: if healthy { return conn, nil } - if err := p.unreg(context.Background(), serverName, conn, p.Now()); err != nil { - panic("lock with Background context should never time out") - } + p.unreg(ctx, serverName, conn, p.Now()) if err != nil { p.log.Debugf(log.Pool, p.logId, "Health check failed for %s: %s", serverName, err) return nil, err } - if !p.serversMut.TryLock(ctx) { - return nil, racing.LockTimeoutError("could not acquire 
lock in time when borrowing a connection") - } + p.serversMut.Lock() *unlock = sync.Once{} } } @@ -215,29 +205,25 @@ serverLoop: return nil, nil } -func (p *Pool) Borrow(ctx context.Context, getServerNames func(context.Context) ([]string, error), wait bool, boltLogger log.BoltLogger, idlenessThreshold time.Duration, auth *idb.ReAuthToken) (idb.Connection, error) { +func (p *Pool) Borrow(ctx context.Context, getServerNames func() []string, wait bool, boltLogger log.BoltLogger, idlenessThreshold time.Duration, auth *idb.ReAuthToken) (idb.Connection, error) { for { if p.closed { return nil, &errorutil.PoolClosed{} } - serverNames, err := getServerNames(ctx) - if err != nil { - return nil, err - } + serverNames := getServerNames() if len(serverNames) == 0 { return nil, &errorutil.PoolOutOfServers{} } p.log.Debugf(log.Pool, p.logId, "Trying to borrow connection from %s", serverNames) // Retrieve penalty for each server - penalties, err := p.getPenaltiesForServers(ctx, serverNames) - if err != nil { - return nil, err - } + penalties := p.getPenaltiesForServers(ctx, serverNames) // Sort server penalties by lowest penalty sort.Slice(penalties, func(i, j int) bool { return penalties[i].penalty < penalties[j].penalty }) + var err error + var conn idb.Connection for _, s := range penalties { conn, err = p.tryBorrow(ctx, s.name, boltLogger, idlenessThreshold, auth) @@ -265,9 +251,7 @@ func (p *Pool) Borrow(ctx context.Context, getServerNames func(context.Context) } // Wait for a matching connection to be returned from another thread. - if !p.queueMut.TryLock(ctx) { - return nil, racing.LockTimeoutError("could not acquire lock in time when trying to get an idle connection") - } + p.queueMut.Lock() // Ok, now that we own the queue we can add the item there but between getting the lock // and above check for an existing connection another thread might have returned a connection // so check again to avoid potentially starving this thread. 
@@ -294,9 +278,7 @@ func (p *Pool) Borrow(ctx context.Context, getServerNames func(context.Context) case <-q.wakeup: continue case <-ctx.Done(): - if !p.queueMut.TryLock(context.Background()) { - return nil, racing.LockTimeoutError("could not acquire lock in time when removing server wait request") - } + p.queueMut.Lock() p.queue.Remove(e) p.queueMut.Unlock() p.log.Warnf(log.Pool, p.logId, "Borrow time-out") @@ -309,15 +291,14 @@ func (p *Pool) tryBorrow(ctx context.Context, serverName string, boltLogger log. // For now, lock complete servers map to avoid over connecting but with the downside // that long connect times will block connects to other servers as well. To fix this // we would need to add a pending connect to the server and lock per server. - if !p.serversMut.TryLock(ctx) { - return nil, racing.LockTimeoutError("could not acquire lock in time when borrowing a connection") - } + p.serversMut.Lock() var unlock = new(sync.Once) defer unlock.Do(p.serversMut.Unlock) srv := p.servers[serverName] for { if srv != nil { + srv.closing = false connection := srv.getIdle() if connection == nil { if srv.size() >= p.config.MaxConnectionPoolSize { @@ -330,16 +311,12 @@ func (p *Pool) tryBorrow(ctx context.Context, serverName string, boltLogger log. if healthy { return connection, nil } - if err := p.unreg(context.Background(), serverName, connection, p.Now()); err != nil { - panic("lock with Background context should never time out") - } + p.unreg(ctx, serverName, connection, p.Now()) if err != nil { p.log.Debugf(log.Pool, p.logId, "Health check failed for %s: %s", serverName, err) return nil, err } - if !p.serversMut.TryLock(ctx) { - return nil, racing.LockTimeoutError("could not acquire lock in time when borrowing a connection") - } + p.serversMut.Lock() *unlock = sync.Once{} srv = p.servers[serverName] } else { @@ -355,18 +332,16 @@ func (p *Pool) tryBorrow(ctx context.Context, serverName string, boltLogger log. 
// No idle connection, try to connect p.log.Infof(log.Pool, p.logId, "Connecting to %s", serverName) - c, err := p.connect(ctx, serverName, auth, p.OnConnectionError, boltLogger) - if !p.serversMut.TryLock(context.Background()) { - panic("lock with Background context should never time out") - } + c, err := p.connect(ctx, serverName, auth, p, boltLogger) + p.serversMut.Lock() *unlock = sync.Once{} srv.reservations-- if err != nil { + p.log.Warnf(log.Pool, p.logId, "Failed to connect to %s: %s", serverName, err) // FeatureNotSupportedError is not the server fault, don't penalize it if _, ok := err.(*db.FeatureNotSupportedError); !ok { srv.notifyFailedConnect((*p.now)()) } - p.log.Warnf(log.Pool, p.logId, "Failed to connect to %s: %s", serverName, err) return nil, err } @@ -376,16 +351,13 @@ func (p *Pool) tryBorrow(ctx context.Context, serverName string, boltLogger log. return c, nil } -func (p *Pool) unreg(ctx context.Context, serverName string, c idb.Connection, now time.Time) error { - if !p.serversMut.TryLock(ctx) { - return racing.LockTimeoutError("could not acquire server lock in time when unregistering server") - } +func (p *Pool) unreg(ctx context.Context, serverName string, c idb.Connection, now time.Time) { + p.serversMut.Lock() defer p.serversMut.Unlock() - - return p.unregUnlocked(ctx, serverName, c, now) + p.unregUnlocked(ctx, serverName, c, now) } -func (p *Pool) unregUnlocked(ctx context.Context, serverName string, c idb.Connection, now time.Time) error { +func (p *Pool) unregUnlocked(ctx context.Context, serverName string, c idb.Connection, now time.Time) { defer func() { // Close connection in another thread to avoid potential long blocking operation during close. go c.Close(ctx) @@ -395,33 +367,29 @@ func (p *Pool) unregUnlocked(ctx context.Context, serverName string, c idb.Conne // Check for strange condition of not finding the server. 
if server == nil { p.log.Warnf(log.Pool, p.logId, "Server %s not found", serverName) - return nil + return } server.unregisterBusy(c) if server.size() == 0 && !server.hasFailedConnect(now) { delete(p.servers, serverName) } - return nil } -func (p *Pool) removeIdleOlderThanOnServer(ctx context.Context, serverName string, now time.Time, maxAge time.Duration) error { - if !p.serversMut.TryLock(ctx) { - return racing.LockTimeoutError("could not acquire server lock in time before removing old idle connections") - } +func (p *Pool) removeIdleOlderThanOnServer(ctx context.Context, serverName string, now time.Time, maxAge time.Duration) { + p.serversMut.Lock() defer p.serversMut.Unlock() server := p.servers[serverName] if server == nil { - return nil + return } server.removeIdleOlderThan(ctx, now, maxAge) - return nil } -func (p *Pool) Return(ctx context.Context, c idb.Connection) error { +func (p *Pool) Return(ctx context.Context, c idb.Connection) { if p.closed { p.log.Warnf(log.Pool, p.logId, "Trying to return connection to closed pool") - return nil + return } // Get the name of the server that the connection belongs to. @@ -441,9 +409,7 @@ func (p *Pool) Return(ctx context.Context, c idb.Connection) error { maxAge = age } } - if err := p.removeIdleOlderThanOnServer(ctx, serverName, now, maxAge); err != nil { - return err - } + p.removeIdleOlderThanOnServer(ctx, serverName, now, maxAge) // Prepare connection for being used by someone else if is alive. 
// Since reset could find the connection to be in a bad state or non-recoverable state, @@ -457,20 +423,19 @@ func (p *Pool) Return(ctx context.Context, c idb.Connection) error { // Shouldn't return a too old or dead connection back to the pool if !isAlive || age >= p.config.MaxConnectionLifetime { - if err := p.unreg(ctx, serverName, c, now); err != nil { - return err - } + p.unreg(ctx, serverName, c, now) p.log.Infof(log.Pool, p.logId, "Unregistering dead or too old connection to %s", serverName) } if isAlive { // Just put it back in the list of idle connections for this server - if !p.serversMut.TryLock(ctx) { - return racing.LockTimeoutError("could not acquire server lock when putting connection back to idle") - } + p.serversMut.Lock() server := p.servers[serverName] if server != nil { // Strange when server not found - server.returnBusy(c) + server.returnBusy(ctx, c) + if server.closing && server.size() == 0 { + delete(p.servers, serverName) + } } else { p.log.Warnf(log.Pool, p.logId, "Server %s not found", serverName) } @@ -478,44 +443,70 @@ func (p *Pool) Return(ctx context.Context, c idb.Connection) error { } // Check if there is anyone in the queue waiting for a connection to this server. 
- if !p.queueMut.TryLock(ctx) { - return racing.LockTimeoutError("could not acquire queue lock when checking connection requests") - } + p.queueMut.Lock() for e := p.queue.Front(); e != nil; e = e.Next() { queuedRequest := e.Value.(*qitem) p.queue.Remove(e) queuedRequest.wakeup <- true } p.queueMut.Unlock() - - return nil } -func (p *Pool) OnConnectionError(ctx context.Context, connection idb.Connection, error *db.Neo4jError) error { - if error.Code == "Neo.ClientError.Security.AuthorizationExpired" { +func (p *Pool) OnNeo4jError(ctx context.Context, connection idb.Connection, error *db.Neo4jError) error { + switch error.Code { + case "Neo.ClientError.Security.AuthorizationExpired": serverName := connection.ServerName() - if !p.serversMut.TryLock(ctx) { - return racing.LockTimeoutError(fmt.Sprintf( - "could not acquire server lock in time before marking all connection to %s for re-authentication", - serverName)) - } + p.serversMut.Lock() defer p.serversMut.Unlock() server := p.servers[serverName] server.executeForAllConnections(func(c idb.Connection) { c.ResetAuth() }) - } else if error.Code == "Neo.ClientError.Security.TokenExpired" { + case "Neo.ClientError.Security.TokenExpired": manager, token := connection.GetCurrentAuth() if manager != nil { if err := manager.OnTokenExpired(ctx, token); err != nil { return err } - if _, isStaticToken := manager.(auth.Token); isStaticToken { - return nil - } else { + if _, isStaticToken := manager.(auth.Token); !isStaticToken { error.MarkRetriable() } } + case "Neo.TransientError.General.DatabaseUnavailable": + p.deactivate(ctx, connection.ServerName()) + default: + if error.IsRetriableCluster() { + var database string + if dbSelector, ok := connection.(idb.DatabaseSelector); ok { + database = dbSelector.Database() + } + p.deactivateWriter(connection.ServerName(), database) + } } + return nil } + +func (p *Pool) OnIoError(ctx context.Context, connection idb.Connection, _ error) { + p.deactivate(ctx, connection.ServerName()) +} 
+ +func (p *Pool) OnDialError(ctx context.Context, serverName string, _ error) { + p.deactivate(ctx, serverName) +} + +func (p *Pool) deactivate(ctx context.Context, serverName string) { + p.log.Debugf(log.Pool, p.logId, "Deactivating server %s", serverName) + p.router.InvalidateServer(serverName) + p.serversMut.Lock() + defer p.serversMut.Unlock() + server := p.servers[serverName] + if server != nil { + server.startClosing(ctx) + } +} + +func (p *Pool) deactivateWriter(serverName string, db string) { + p.log.Debugf(log.Pool, p.logId, "Deactivating writer %s for database %s", serverName, db) + p.router.InvalidateWriter(db, serverName) +} diff --git a/neo4j/internal/pool/pool_test.go b/neo4j/internal/pool/pool_test.go index c7596ea6..9bea4863 100644 --- a/neo4j/internal/pool/pool_test.go +++ b/neo4j/internal/pool/pool_test.go @@ -23,9 +23,10 @@ import ( "context" "errors" "github.com/neo4j/neo4j-go-driver/v5/neo4j/config" + db "github.com/neo4j/neo4j-go-driver/v5/neo4j/db" iauth "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/auth" "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/bolt" - "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/db" + idb "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/db" "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/errorutil" "math/rand" "sync" @@ -38,18 +39,18 @@ import ( var logger = &log.Void{} var ctx = context.Background() -var reAuthToken = &db.ReAuthToken{FromSession: false, Manager: iauth.Token{Tokens: map[string]any{"scheme": "none"}}} +var reAuthToken = &idb.ReAuthToken{FromSession: false, Manager: iauth.Token{Tokens: map[string]any{"scheme": "none"}}} func TestPoolBorrowReturn(outer *testing.T) { maxAge := 1 * time.Second birthdate := time.Now() - succeedingConnect := func(_ context.Context, s string, _ *db.ReAuthToken, _ bolt.Neo4jErrorCallback, _ log.BoltLogger) (db.Connection, error) { + succeedingConnect := func(_ context.Context, s string, _ *idb.ReAuthToken, _ bolt.ConnectionErrorListener, _ 
log.BoltLogger) (idb.Connection, error) { return &ConnFake{Name: s, Alive: true, Birth: birthdate}, nil } failingError := errors.New("whatever") - failingConnect := func(_ context.Context, s string, _ *db.ReAuthToken, _ bolt.Neo4jErrorCallback, _ log.BoltLogger) (db.Connection, error) { + failingConnect := func(_ context.Context, s string, _ *idb.ReAuthToken, _ bolt.ConnectionErrorListener, _ log.BoltLogger) (idb.Connection, error) { return nil, failingError } @@ -58,19 +59,15 @@ func TestPoolBorrowReturn(outer *testing.T) { conf := config.Config{MaxConnectionLifetime: maxAge, MaxConnectionPoolSize: 1} p := New(&conf, succeedingConnect, logger, "pool id", &timer) defer func() { - if err := p.Close(ctx); err != nil { - t.Errorf("Should not fail closing the pool, but got: %v", err) - } + p.Close(ctx) }() serverNames := []string{"srv1"} conn, err := p.Borrow(ctx, getServers(serverNames), true, nil, DefaultLivenessCheckThreshold, reAuthToken) assertConnection(t, conn, err) - if err := p.Return(ctx, conn); err != nil { - t.Errorf("Should not fail returning connection to pool, but got: %v", err) - } + p.Return(ctx, conn) // Make sure that connection actually returned - servers, err := p.getServers(ctx) + servers := p.getServers() if err != nil { t.Errorf("Should not fail retrieving servers, got: %v", err) } @@ -84,9 +81,7 @@ func TestPoolBorrowReturn(outer *testing.T) { conf := config.Config{MaxConnectionLifetime: maxAge, MaxConnectionPoolSize: 1} p := New(&conf, succeedingConnect, logger, "pool id", &timer) defer func() { - if err := p.Close(ctx); err != nil { - t.Errorf("Should not fail closing the pool, but got: %v", err) - } + p.Close(ctx) }() serverNames := []string{"srv1"} wg := sync.WaitGroup{} @@ -105,12 +100,10 @@ func TestPoolBorrowReturn(outer *testing.T) { wg.Done() }() - waitForBorrowers(t, p, 1) + waitForBorrowers(p, 1) // Give back the connection - if err := p.Return(ctx, c1); err != nil { - t.Errorf("Should not fail returning connection to pool, but got: 
%v", err) - } + p.Return(ctx, c1) wg.Wait() }) @@ -119,9 +112,7 @@ func TestPoolBorrowReturn(outer *testing.T) { conf := config.Config{MaxConnectionLifetime: maxAge, MaxConnectionPoolSize: 1} p := New(&conf, succeedingConnect, logger, "pool id", &timer) defer func() { - if err := p.Close(ctx); err != nil { - t.Errorf("Should not fail closing the pool, but got: %v", err) - } + p.Close(ctx) }() serverNames := []string{"srv1"} @@ -151,9 +142,7 @@ func TestPoolBorrowReturn(outer *testing.T) { c, err := p.Borrow(ctx, getServers(serverNames), true, nil, DefaultLivenessCheckThreshold, reAuthToken) assertConnection(t, c, err) time.Sleep(time.Duration(rand.Int()%7) * time.Millisecond) - if err := p.Return(ctx, c); err != nil { - t.Errorf("Should not fail returning connection to pool, but got: %v", err) - } + p.Return(ctx, c) } wg.Done() } @@ -164,10 +153,7 @@ func TestPoolBorrowReturn(outer *testing.T) { wg.Wait() // Everything should be freed up, it's ok if there isn't a server as well... - servers, err := p.getServers(ctx) - if err != nil { - t.Errorf("Should not fail retrieving server, but got: %v", err) - } + servers := p.getServers() for _, v := range servers { if v.numIdle() != maxConnections { t.Error("A connection is still in use in the server") @@ -179,6 +165,7 @@ func TestPoolBorrowReturn(outer *testing.T) { timer := func() time.Time { return birthdate } conf := config.Config{MaxConnectionLifetime: maxAge, MaxConnectionPoolSize: 2} p := New(&conf, failingConnect, logger, "pool id", &timer) + p.SetRouter(&RouterFake{}) serverNames := []string{"srv1"} c, err := p.Borrow(ctx, getServers(serverNames), true, nil, DefaultLivenessCheckThreshold, reAuthToken) assertNoConnection(t, c, err) @@ -202,12 +189,10 @@ func TestPoolBorrowReturn(outer *testing.T) { wg.Done() }() - waitForBorrowers(t, p, 1) + waitForBorrowers(p, 1) cancel() wg.Wait() - if err := p.Return(ctx, c1); err != nil { - t.Errorf("Should not fail returning connection to pool, but got: %v", err) - } + 
p.Return(ctx, c1) if err == nil { t.Error("There should be an error due to cancelling") } @@ -226,7 +211,7 @@ func TestPoolBorrowReturn(outer *testing.T) { timer := time.Now conf := config.Config{MaxConnectionLifetime: maxAge, MaxConnectionPoolSize: 1} pool := New(&conf, nil, logger, "pool id", &timer) - setIdleConnections(pool, map[string][]db.Connection{"a server": { + setIdleConnections(pool, map[string][]idb.Connection{"a server": { deadAfterReset, stayingAlive, whatATimeToBeAlive, @@ -250,7 +235,7 @@ func TestPoolBorrowReturn(outer *testing.T) { timer := time.Now conf := config.Config{MaxConnectionLifetime: maxAge, MaxConnectionPoolSize: 1} pool := New(&conf, connectTo(healthyConnection), logger, "pool id", &timer) - setIdleConnections(pool, map[string][]db.Connection{serverName: {deadAfterReset1, deadAfterReset2}}) + setIdleConnections(pool, map[string][]idb.Connection{serverName: {deadAfterReset1, deadAfterReset2}}) result, err := pool.tryBorrow(ctx, serverName, nil, idlenessThreshold, reAuthToken) @@ -276,10 +261,10 @@ func TestPoolBorrowReturn(outer *testing.T) { wg.Done() }() - waitForBorrowers(t, p, 1) + waitForBorrowers(p, 1) // break the connection. 
then it shouldn't be picked up by the waiting borrow c1.(*ConnFake).Alive = false - err = p.Return(ctx, c1) + p.Return(ctx, c1) AssertNoError(t, err) wg.Wait() }) @@ -288,7 +273,7 @@ func TestPoolBorrowReturn(outer *testing.T) { token2 := iauth.Token{Tokens: map[string]any{"scheme": "foobar"}} // sanity check AssertNotDeepEquals(t, reAuthToken.Manager, token2) - reAuthToken2 := &db.ReAuthToken{FromSession: false, Manager: token2} + reAuthToken2 := &idb.ReAuthToken{FromSession: false, Manager: token2} timer := func() time.Time { return birthdate } conf := config.Config{MaxConnectionLifetime: maxAge, MaxConnectionPoolSize: 1} p := New(&conf, succeedingConnect, logger, "pool id", &timer) @@ -304,14 +289,14 @@ func TestPoolBorrowReturn(outer *testing.T) { wg.Done() }() - waitForBorrowers(t, p, 1) + waitForBorrowers(p, 1) reAuthCalled := false - c1.(*ConnFake).ReAuthHook = func(_ context.Context, token *db.ReAuthToken) error { + c1.(*ConnFake).ReAuthHook = func(_ context.Context, token *idb.ReAuthToken) error { AssertDeepEquals(t, token.Manager, token2) reAuthCalled = true return nil } - err = p.Return(ctx, c1) + p.Return(ctx, c1) AssertNoError(t, err) wg.Wait() AssertTrue(t, reAuthCalled) @@ -323,7 +308,7 @@ func TestPoolResourceUsage(ot *testing.T) { maxAge := 1 * time.Second birthdate := time.Now() - succeedingConnect := func(_ context.Context, s string, _ *db.ReAuthToken, _ bolt.Neo4jErrorCallback, _ log.BoltLogger) (db.Connection, error) { + succeedingConnect := func(_ context.Context, s string, _ *idb.ReAuthToken, _ bolt.ConnectionErrorListener, _ log.BoltLogger) (idb.Connection, error) { return &ConnFake{Name: s, Alive: true, Birth: birthdate}, nil } @@ -332,9 +317,7 @@ func TestPoolResourceUsage(ot *testing.T) { conf := config.Config{MaxConnectionLifetime: maxAge, MaxConnectionPoolSize: 1} p := New(&conf, succeedingConnect, logger, "pool id", &timer) defer func() { - if err := p.Close(ctx); err != nil { - t.Errorf("Should not fail closing the pool, but got: %v", 
err) - } + p.Close(ctx) }() serverNames := []string{"srvA", "srvB", "srvC", "srvD"} c, _ := p.Borrow(ctx, getServers(serverNames), true, nil, DefaultLivenessCheckThreshold, reAuthToken) @@ -348,20 +331,13 @@ func TestPoolResourceUsage(ot *testing.T) { conf := config.Config{MaxConnectionLifetime: maxAge, MaxConnectionPoolSize: 2} p := New(&conf, succeedingConnect, logger, "pool id", &timer) defer func() { - if err := p.Close(ctx); err != nil { - t.Errorf("Should not fail closing the pool, but got: %v", err) - } + p.Close(ctx) }() serverNames := []string{"srvA"} c, _ := p.Borrow(ctx, getServers(serverNames), true, nil, DefaultLivenessCheckThreshold, reAuthToken) c.(*ConnFake).Alive = false - if err := p.Return(ctx, c); err != nil { - t.Errorf("Should not fail returning connection to pool, but got: %v", err) - } - servers, err := p.getServers(ctx) - if err != nil { - t.Errorf("Should not fail retrieving server but got: %v", err) - } + p.Return(ctx, c) + servers := p.getServers() if len(servers) > 0 && servers[serverNames[0]].size() > 0 { t.Errorf("Should have either removed the server or kept it but emptied it") } @@ -372,19 +348,12 @@ func TestPoolResourceUsage(ot *testing.T) { conf := config.Config{MaxConnectionLifetime: maxAge, MaxConnectionPoolSize: 2} p := New(&conf, succeedingConnect, logger, "pool id", &timer) defer func() { - if err := p.Close(ctx); err != nil { - t.Errorf("Should not fail closing the pool, but got: %v", err) - } + p.Close(ctx) }() serverNames := []string{"srvA"} c, _ := p.Borrow(ctx, getServers(serverNames), true, nil, DefaultLivenessCheckThreshold, reAuthToken) - if err := p.Return(ctx, c); err != nil { - t.Errorf("Should not fail returning connection to pool, but got: %v", err) - } - servers, err := p.getServers(ctx) - if err != nil { - t.Errorf("Should not fail retrieving server but got: %v", err) - } + p.Return(ctx, c) + servers := p.getServers() if len(servers) > 0 && servers[serverNames[0]].size() > 0 { t.Errorf("Should have either 
removed the server or kept it but emptied it") } @@ -407,20 +376,14 @@ func TestPoolResourceUsage(ot *testing.T) { c3.(*ConnFake).Birth = nowTime.Add(1 * time.Second) c3.(*ConnFake).Id = 3 // Return the old and young connections to make them idle - if err := p.Return(ctx, c1); err != nil { - t.Errorf("Should not fail returning connection to pool, but got: %v", err) - } - if err := p.Return(ctx, c3); err != nil { - t.Errorf("Should not fail returning connection to pool, but got: %v", err) - } - assertNumberOfServers(t, ctx, p, 1) - assertNumberOfIdle(t, ctx, p, "A", 2) + p.Return(ctx, c1) + p.Return(ctx, c3) + assertNumberOfServers(t, p, 1) + assertNumberOfIdle(t, p, "A", 2) // Kill the middle-aged connection and return it c2.(*ConnFake).Alive = false - if err := p.Return(ctx, c2); err != nil { - t.Errorf("Should not fail returning connection to pool, but got: %v", err) - } - assertNumberOfIdle(t, ctx, p, "A", 1) + p.Return(ctx, c2) + assertNumberOfIdle(t, p, "A", 1) }) ot.Run("Do not borrow too old connections", func(t *testing.T) { @@ -434,17 +397,13 @@ func TestPoolResourceUsage(ot *testing.T) { conf := config.Config{MaxConnectionLifetime: maxAge, MaxConnectionPoolSize: 1} p := New(&conf, succeedingConnect, logger, "pool id", &timer) defer func() { - if err := p.Close(ctx); err != nil { - t.Errorf("Should not fail closing the pool, but got: %v", err) - } + p.Close(ctx) }() serverNames := []string{"srvA"} c1, _ := p.Borrow(ctx, getServers(serverNames), true, nil, DefaultLivenessCheckThreshold, reAuthToken) c1.(*ConnFake).Id = 123 // It's alive when returning it - if err := p.Return(ctx, c1); err != nil { - t.Errorf("Should not fail returning connection to pool, but got: %v", err) - } + p.Return(ctx, c1) nowMut.Lock() now = now.Add(2 * maxAge) nowMut.Unlock() @@ -460,27 +419,25 @@ func TestPoolResourceUsage(ot *testing.T) { conf := config.Config{MaxConnectionLifetime: maxAge, MaxConnectionPoolSize: 1} p := New(&conf, succeedingConnect, logger, "pool id", &timer) 
defer func() { - if err := p.Close(ctx); err != nil { - t.Errorf("Should not fail closing the pool, but got: %v", err) - } + p.Close(ctx) }() c1, err := p.Borrow(ctx, getServers([]string{"A"}), true, nil, DefaultLivenessCheckThreshold, reAuthToken) assertConnection(t, c1, err) c2, err := p.Borrow(ctx, getServers([]string{"B"}), true, nil, DefaultLivenessCheckThreshold, reAuthToken) assertConnection(t, c2, err) - assertNumberOfServers(t, ctx, p, 2) + assertNumberOfServers(t, p, 2) }) } func TestPoolCleanup(ot *testing.T) { birthdate := time.Now() maxLife := 1 * time.Second - succeedingConnect := func(_ context.Context, s string, _ *db.ReAuthToken, _ bolt.Neo4jErrorCallback, _ log.BoltLogger) (db.Connection, error) { + succeedingConnect := func(_ context.Context, s string, _ *idb.ReAuthToken, _ bolt.ConnectionErrorListener, _ log.BoltLogger) (idb.Connection, error) { return &ConnFake{Name: s, Alive: true, Birth: birthdate}, nil } // Borrows a connection in server A and another in server B - borrowConnections := func(t *testing.T, p *Pool) (db.Connection, db.Connection) { + borrowConnections := func(t *testing.T, p *Pool) (idb.Connection, idb.Connection) { c1, err := p.Borrow(ctx, getServers([]string{"A"}), true, nil, DefaultLivenessCheckThreshold, reAuthToken) assertConnection(t, c1, err) c2, err := p.Borrow(ctx, getServers([]string{"B"}), true, nil, DefaultLivenessCheckThreshold, reAuthToken) @@ -493,27 +450,19 @@ func TestPoolCleanup(ot *testing.T) { conf := config.Config{MaxConnectionLifetime: maxLife, MaxConnectionPoolSize: 0} p := New(&conf, succeedingConnect, logger, "pool id", &timer) defer func() { - if err := p.Close(ctx); err != nil { - t.Errorf("Should not fail closing the pool, but got: %v", err) - } + p.Close(ctx) }() c1, c2 := borrowConnections(t, p) - if err := p.Return(ctx, c1); err != nil { - t.Errorf("Should not fail returning connection to pool, but got: %v", err) - } - if err := p.Return(ctx, c2); err != nil { - t.Errorf("Should not fail returning 
connection to pool, but got: %v", err) - } - assertNumberOfServers(t, ctx, p, 2) - assertNumberOfIdle(t, ctx, p, "A", 1) - assertNumberOfIdle(t, ctx, p, "B", 1) + p.Return(ctx, c1) + p.Return(ctx, c2) + assertNumberOfServers(t, p, 2) + assertNumberOfIdle(t, p, "A", 1) + assertNumberOfIdle(t, p, "B", 1) // Now go into the future and cleanup, should remove both servers and close the connections timer = func() time.Time { return birthdate.Add(maxLife).Add(1 * time.Second) } - if err := p.CleanUp(ctx); err != nil { - t.Errorf("Should not fail cleaning up the pool, but got: %v", err) - } - assertNumberOfServers(t, ctx, p, 0) + p.CleanUp(ctx) + assertNumberOfServers(t, p, 0) }) ot.Run("Should not remove servers with busy connections", func(t *testing.T) { @@ -521,59 +470,48 @@ func TestPoolCleanup(ot *testing.T) { conf := config.Config{MaxConnectionLifetime: maxLife, MaxConnectionPoolSize: 0} p := New(&conf, succeedingConnect, logger, "pool id", &timer) defer func() { - if err := p.Close(ctx); err != nil { - t.Errorf("Should not fail closing the pool, but got: %v", err) - } + p.Close(ctx) }() _, c2 := borrowConnections(t, p) - if err := p.Return(ctx, c2); err != nil { - t.Errorf("Should not fail returning connection to pool, but got: %v", err) - } - assertNumberOfServers(t, ctx, p, 2) - assertNumberOfIdle(t, ctx, p, "A", 0) - assertNumberOfIdle(t, ctx, p, "B", 1) + p.Return(ctx, c2) + assertNumberOfServers(t, p, 2) + assertNumberOfIdle(t, p, "A", 0) + assertNumberOfIdle(t, p, "B", 1) // Now go into the future and cleanup, should only remove B timer = func() time.Time { return birthdate.Add(maxLife).Add(1 * time.Second) } - if err := p.CleanUp(ctx); err != nil { - t.Errorf("Should not fail cleaning up the pool, but got: %v", err) - } - assertNumberOfServers(t, ctx, p, 1) + p.CleanUp(ctx) + assertNumberOfServers(t, p, 1) }) ot.Run("Should not remove servers with only idle connections but with recent connect failures ", func(t *testing.T) { - failingConnect := func(_ 
context.Context, s string, _ *db.ReAuthToken, _ bolt.Neo4jErrorCallback, _ log.BoltLogger) (db.Connection, error) { + failingConnect := func(_ context.Context, s string, _ *idb.ReAuthToken, _ bolt.ConnectionErrorListener, _ log.BoltLogger) (idb.Connection, error) { return nil, errors.New("an error") } timer := time.Now conf := config.Config{MaxConnectionLifetime: maxLife, MaxConnectionPoolSize: 0} p := New(&conf, failingConnect, logger, "pool id", &timer) + p.SetRouter(&RouterFake{}) defer func() { - if err := p.Close(ctx); err != nil { - t.Errorf("Should not fail closing the pool, but got: %v", err) - } + p.Close(ctx) }() c1, err := p.Borrow(ctx, getServers([]string{"A"}), true, nil, DefaultLivenessCheckThreshold, reAuthToken) assertNoConnection(t, c1, err) - assertNumberOfServers(t, ctx, p, 1) - assertNumberOfIdle(t, ctx, p, "A", 0) + assertNumberOfServers(t, p, 1) + assertNumberOfIdle(t, p, "A", 0) // Now go into the future and cleanup, should not remove server A even if it has no connections since // we should remember the failure a bit longer timer = func() time.Time { return birthdate.Add(maxLife).Add(1 * time.Second) } - if err := p.CleanUp(ctx); err != nil { - t.Errorf("Should not fail cleaning up the pool, but got: %v", err) - } - assertNumberOfServers(t, ctx, p, 1) + p.CleanUp(ctx) + assertNumberOfServers(t, p, 1) // Further in the future, the failure should have been forgotten timer = func() time.Time { return birthdate.Add(maxLife).Add(rememberFailedConnectDuration).Add(1 * time.Second) } - if err := p.CleanUp(ctx); err != nil { - t.Errorf("Should not fail cleaning up the pool, but got: %v", err) - } - assertNumberOfServers(t, ctx, p, 0) + p.CleanUp(ctx) + assertNumberOfServers(t, p, 0) }) ot.Run("wakes up borrowers when closing", func(t *testing.T) { @@ -592,9 +530,9 @@ func TestPoolCleanup(ot *testing.T) { _, err := p.Borrow(ctx, servers, true, nil, DefaultLivenessCheckThreshold, reAuthToken) borrowErrChan <- err }() - waitForBorrowers(t, p, 1) + 
waitForBorrowers(p, 1) - AssertNoError(t, p.Close(ctx)) + p.Close(ctx) select { case err := <-borrowErrChan: @@ -605,13 +543,115 @@ func TestPoolCleanup(ot *testing.T) { }) } -func connectTo(singleConnection *ConnFake) func(ctx context.Context, name string, _ *db.ReAuthToken, _ bolt.Neo4jErrorCallback, _ log.BoltLogger) (db.Connection, error) { - return func(ctx context.Context, name string, _ *db.ReAuthToken, _ bolt.Neo4jErrorCallback, _ log.BoltLogger) (db.Connection, error) { +func TestPoolErrorHanding(ot *testing.T) { + const ServerName = "A" + const DbName = "some database" + + type TestCase struct { + name string + errorCall func(bolt.ConnectionErrorListener, idb.Connection) error + expectedInvalidateMode string + expectedInvalidatedServer string + expectedInvalidatedDb string + } + + dbUnavailableErr := db.Neo4jError{Code: "Neo.TransientError.General.DatabaseUnavailable"} + notALeaderErr := db.Neo4jError{Code: "Neo.ClientError.Cluster.NotALeader"} + forbiddenRoDbError := db.Neo4jError{Code: "Neo.ClientError.General.ForbiddenOnReadOnlyDatabase"} + + cases := []TestCase{ + { + name: "should invalidate server on io error", + errorCall: func(listener bolt.ConnectionErrorListener, conn idb.Connection) error { + listener.OnIoError(ctx, conn, errors.New("an error")) + return nil + }, + expectedInvalidatedServer: ServerName, + }, + { + name: "should invalidate server on dial error", + errorCall: func(listener bolt.ConnectionErrorListener, _ idb.Connection) error { + listener.OnDialError(ctx, "what ever server", errors.New("an error")) + return nil + }, + expectedInvalidatedServer: "what ever server", + }, + { + name: "should invalidate server on dial error", + errorCall: func(listener bolt.ConnectionErrorListener, _ idb.Connection) error { + listener.OnDialError(ctx, ServerName, errors.New("an error")) + return nil + }, + expectedInvalidatedServer: ServerName, + }, + { + name: "should invalidate server on db unavailable error", + errorCall: func(listener 
bolt.ConnectionErrorListener, conn idb.Connection) error { + return listener.OnNeo4jError(ctx, conn, &dbUnavailableErr) + }, + expectedInvalidatedServer: ServerName, + }, + { + name: "should invalidate writer for db on not a leader error", + errorCall: func(listener bolt.ConnectionErrorListener, conn idb.Connection) error { + return listener.OnNeo4jError(ctx, conn, ¬ALeaderErr) + }, + expectedInvalidateMode: "writer", + expectedInvalidatedServer: ServerName, + expectedInvalidatedDb: DbName, + }, + { + name: "should invalidate writer for db on forbidden error", + errorCall: func(listener bolt.ConnectionErrorListener, conn idb.Connection) error { + return listener.OnNeo4jError(ctx, conn, &forbiddenRoDbError) + }, + expectedInvalidateMode: "writer", + expectedInvalidatedServer: ServerName, + expectedInvalidatedDb: DbName, + }, + } + + for _, testCase := range cases { + ot.Run(testCase.name, func(t *testing.T) { + errorListeners := make([]bolt.ConnectionErrorListener, 0) + connections := make([]*ConnFake, 0) + succeedingConnect := func(_ context.Context, s string, _ *idb.ReAuthToken, errorListener bolt.ConnectionErrorListener, _ log.BoltLogger) (idb.Connection, error) { + errorListeners = append(errorListeners, errorListener) + connection := ConnFake{Name: s, Alive: true, DatabaseName: DbName} + connections = append(connections, &connection) + return &connection, nil + } + + now := time.Now + router := RouterFake{} + p := New(&config.Config{}, succeedingConnect, logger, "pool id", &now) + p.SetRouter(&router) + defer p.Close(ctx) + conn, err := p.Borrow(ctx, getServers([]string{ServerName}), false, nil, DefaultLivenessCheckThreshold, reAuthToken) + assertConnection(t, conn, err) + AssertLen(t, errorListeners, 1) + AssertLen(t, connections, 1) + errorListener := errorListeners[0] + connection := connections[0] + AssertFalse(t, router.Invalidated) + + err = testCase.errorCall(errorListener, connection) + AssertNoError(t, err) + AssertTrue(t, router.Invalidated) + 
AssertStringEqual(t, router.InvalidateMode, testCase.expectedInvalidateMode) + AssertStringEqual(t, router.InvalidatedServer, testCase.expectedInvalidatedServer) + AssertStringEqual(t, router.InvalidatedDb, testCase.expectedInvalidatedDb) + }) + } +} + +func connectTo(singleConnection *ConnFake) func(ctx context.Context, name string, _ *idb.ReAuthToken, _ bolt.ConnectionErrorListener, _ log.BoltLogger) (idb.Connection, error) { + return func(ctx context.Context, name string, _ *idb.ReAuthToken, _ bolt.ConnectionErrorListener, _ log.BoltLogger) (idb.Connection, error) { return singleConnection, nil } } -func setIdleConnections(pool *Pool, servers map[string][]db.Connection) { +func setIdleConnections(pool *Pool, servers map[string][]idb.Connection) { poolServers := make(map[string]*server, len(servers)) for serverName, connections := range servers { srv := NewServer() @@ -633,18 +673,14 @@ func deadConnectionAfterForceReset(name string, idleness time.Time) *ConnFake { return result } -func getServers(servers []string) func(context.Context) ([]string, error) { - return func(context.Context) ([]string, error) { - return servers, nil +func getServers(servers []string) func() []string { + return func() []string { + return servers } } -func waitForBorrowers(t *testing.T, p *Pool, minBorrowers int) { - for { - if size, err := p.queueSize(ctx); err != nil { - t.Errorf("should not fail computing queue size, got: %v", err) - } else if size >= minBorrowers { - break - } +func waitForBorrowers(p *Pool, minBorrowers int) { + for p.queueSize() < minBorrowers { + // still waiting } } diff --git a/neo4j/internal/pool/server.go b/neo4j/internal/pool/server.go index 2ac1b53c..8e5de873 100644 --- a/neo4j/internal/pool/server.go +++ b/neo4j/internal/pool/server.go @@ -37,6 +37,7 @@ type server struct { reservations int failedConnectAt time.Time roundRobin uint32 + closing bool } func NewServer() *server { @@ -137,9 +138,13 @@ func (s *server) calculatePenalty(now time.Time) uint32 { } 
// Returns a busy connection, makes it idle -func (s *server) returnBusy(c db.Connection) { +func (s *server) returnBusy(ctx context.Context, c db.Connection) { s.unregisterBusy(c) - s.idle.PushFront(c) + if s.closing { + c.Close(ctx) + } else { + s.idle.PushFront(c) + } } // Number of idle connections @@ -206,10 +211,15 @@ func (s *server) executeForAllConnections(callback func(c db.Connection)) { } } +func (s *server) startClosing(ctx context.Context) { + s.closing = true + closeAndEmptyConnections(ctx, s.idle) +} + func closeAndEmptyConnections(ctx context.Context, l list.List) { for e := l.Front(); e != nil; e = e.Next() { c := e.Value.(db.Connection) - c.Close(ctx) + go c.Close(ctx) } l.Init() } diff --git a/neo4j/internal/pool/server_test.go b/neo4j/internal/pool/server_test.go index f9a42757..46c26f24 100644 --- a/neo4j/internal/pool/server_test.go +++ b/neo4j/internal/pool/server_test.go @@ -80,7 +80,7 @@ func TestServer(ot *testing.T) { c3 := s.getIdle() assertNilConnection(t, c3) - s.returnBusy(c2) + s.returnBusy(context.Background(), c2) c3 = s.getIdle() assertConnection(t, c3) }) @@ -110,8 +110,8 @@ func TestServer(ot *testing.T) { assertNilConnection(t, b3) // Return the connections and let all of them be too old - s.returnBusy(b1) - s.returnBusy(b2) + s.returnBusy(context.Background(), b1) + s.returnBusy(context.Background(), b2) conns[0].Birth = now.Add(-20 * time.Second) conns[2].Birth = now.Add(-20 * time.Second) s.removeIdleOlderThan(context.Background(), now, 10*time.Second) @@ -146,7 +146,7 @@ func TestServerPenalty(t *testing.T) { // Return the busy connection to srv1 // Now srv2 should have higher penalty than srv1 since using srv2 would require a new // connection. - srv1.returnBusy(c11) + srv1.returnBusy(context.Background(), c11) assertPenaltiesGreaterThan(srv2, srv1, now) // Add an idle connection to srv2 to make both servers have one idle connection each. 
@@ -162,7 +162,7 @@ func TestServerPenalty(t *testing.T) { idle := srv1.getIdle() _, _ = srv1.healthCheck(ctx, idle, DefaultLivenessCheckThreshold, nil, nil) testutil.AssertDeepEquals(t, idle, c11) - srv1.returnBusy(c11) + srv1.returnBusy(context.Background(), c11) assertPenaltiesGreaterThan(srv1, srv2, now) // Add one more connection each to the servers @@ -187,10 +187,10 @@ func TestServerPenalty(t *testing.T) { // Return the connections idle = srv2.getIdle() _, _ = srv2.healthCheck(ctx, idle, DefaultLivenessCheckThreshold, nil, nil) - srv2.returnBusy(c21) - srv2.returnBusy(c22) - srv1.returnBusy(c11) - srv1.returnBusy(c12) + srv2.returnBusy(context.Background(), c21) + srv2.returnBusy(context.Background(), c22) + srv1.returnBusy(context.Background(), c11) + srv1.returnBusy(context.Background(), c12) // Everything returned, srv2 should have higher penalty since it was last used assertPenaltiesGreaterThan(srv2, srv1, now) @@ -290,5 +290,5 @@ func TestIdlenessThreshold(outer *testing.T) { func registerIdle(srv *server, connection db.Connection) { srv.registerBusy(connection) - srv.returnBusy(connection) + srv.returnBusy(context.Background(), connection) } diff --git a/neo4j/internal/retry/state.go b/neo4j/internal/retry/state.go index 29de6e46..c4b5329c 100644 --- a/neo4j/internal/retry/state.go +++ b/neo4j/internal/retry/state.go @@ -32,10 +32,6 @@ import ( "github.com/neo4j/neo4j-go-driver/v5/neo4j/log" ) -type Router interface { - Invalidate(ctx context.Context, database string) error -} - type State struct { Errs []error MaxTransactionRetryTime time.Duration @@ -46,26 +42,22 @@ type State struct { Sleep func(time.Duration) Throttle Throttler MaxDeadConnections int - Router Router DatabaseName string - start time.Time - cause string - deadErrors int - skipSleep bool - OnDeadConnection func(server string) error + start time.Time + cause string + deadErrors int + skipSleep bool } -func (s *State) OnFailure(ctx context.Context, err error, conn idb.Connection, 
isCommitting bool) { +func (s *State) OnFailure(_ context.Context, err error, conn idb.Connection, isCommitting bool) { if conn != nil && !conn.IsAlive() { if isCommitting { + // FIXME: CommitFailedDeadError should be returned even when not using transaction functions s.Errs = append(s.Errs, &errorutil.CommitFailedDeadError{Inner: err}) } else { s.Errs = append(s.Errs, err) } - if err := s.OnDeadConnection(conn.ServerName()); err != nil { - s.Errs = append(s.Errs, err) - } s.deadErrors += 1 s.skipSleep = true return @@ -73,13 +65,6 @@ func (s *State) OnFailure(ctx context.Context, err error, conn idb.Connection, i s.Errs = append(s.Errs, err) s.skipSleep = false - - if dbErr, isDbErr := err.(*db.Neo4jError); isDbErr && dbErr.IsRetriableCluster() { - if err := s.Router.Invalidate(ctx, s.DatabaseName); err != nil { - s.Errs = append(s.Errs, err) - } - } - } func (s *State) Continue() bool { diff --git a/neo4j/internal/retry/state_test.go b/neo4j/internal/retry/state_test.go index 93a08302..8b0cc4dc 100644 --- a/neo4j/internal/retry/state_test.go +++ b/neo4j/internal/retry/state_test.go @@ -35,16 +35,13 @@ import ( ) type TStateInvocation struct { - conn idb.Connection - err error - isCommitting bool - now time.Time - expectContinued bool - expectRouterInvalidated bool - expectRouterInvalidatedDb string - expectRouterInvalidatedServer string - expectLastErrWasRetryable bool - expectLastErrType error + conn idb.Connection + err error + isCommitting bool + now time.Time + expectContinued bool + expectLastErrWasRetryable bool + expectLastErrType error } func TestState(outer *testing.T) { @@ -78,39 +75,29 @@ func TestState(outer *testing.T) { "Retry dead connection": { {conn: &testutil.ConnFake{Name: serverName, Alive: false}, err: errors.New("some error"), expectContinued: false, - expectLastErrWasRetryable: false, expectRouterInvalidated: true, - expectRouterInvalidatedDb: dbName, expectRouterInvalidatedServer: serverName}, + expectLastErrWasRetryable: false}, {conn: 
&testutil.ConnFake{Name: serverName, Alive: false}, err: dbTransientErr, - expectContinued: true, expectLastErrWasRetryable: true, - expectRouterInvalidated: true, - expectRouterInvalidatedDb: dbName, expectRouterInvalidatedServer: serverName}, + expectContinued: true, expectLastErrWasRetryable: true}, }, "Retry dead connection timeout": { {conn: &testutil.ConnFake{Name: serverName, Alive: false}, err: dbTransientErr, - expectContinued: true, now: baseTime, expectLastErrWasRetryable: true, - expectRouterInvalidated: true, expectRouterInvalidatedDb: dbName, - expectRouterInvalidatedServer: serverName}, + expectContinued: true, now: baseTime, expectLastErrWasRetryable: true}, {conn: &testutil.ConnFake{Name: serverName, Alive: false}, err: errors.New("some error 2"), expectContinued: false, now: overTime, - expectLastErrWasRetryable: false, - expectRouterInvalidated: true, - expectRouterInvalidatedDb: dbName, expectRouterInvalidatedServer: serverName}, + expectLastErrWasRetryable: false}, }, "Retry dead connection max": { {conn: &testutil.ConnFake{Name: serverName, Alive: false}, err: dbTransientErr, expectContinued: true, - expectLastErrWasRetryable: true, expectRouterInvalidated: true, - expectRouterInvalidatedDb: dbName, expectRouterInvalidatedServer: serverName}, + expectLastErrWasRetryable: true}, {conn: &testutil.ConnFake{Name: serverName, Alive: false}, err: dbTransientErr, expectContinued: true, - expectLastErrWasRetryable: true, expectRouterInvalidated: true, - expectRouterInvalidatedDb: dbName, expectRouterInvalidatedServer: serverName}, + expectLastErrWasRetryable: true}, {conn: &testutil.ConnFake{Name: serverName, Alive: false}, err: dbTransientErr, expectContinued: false, - expectLastErrWasRetryable: true, expectRouterInvalidated: true, - expectRouterInvalidatedDb: dbName, expectRouterInvalidatedServer: serverName}, + expectLastErrWasRetryable: true}, }, "Cluster error": { {conn: &testutil.ConnFake{Alive: true}, err: clusterErr, expectContinued: true, - 
expectRouterInvalidated: true, expectRouterInvalidatedDb: dbName, expectLastErrWasRetryable: true}, + expectLastErrWasRetryable: true}, }, "Database transient error": { {conn: &testutil.ConnFake{Alive: true}, err: dbTransientErr, expectContinued: true, @@ -128,17 +115,14 @@ func TestState(outer *testing.T) { }, "Fail during commit": { {conn: &testutil.ConnFake{Name: serverName, Alive: false}, err: io.EOF, isCommitting: true, expectContinued: false, - expectLastErrWasRetryable: false, expectLastErrType: &errorutil.CommitFailedDeadError{}, - expectRouterInvalidated: true, - expectRouterInvalidatedDb: dbName, expectRouterInvalidatedServer: serverName}, + expectLastErrWasRetryable: false, expectLastErrType: &errorutil.CommitFailedDeadError{}}, }, "Fail during commit after retry": { {conn: &testutil.ConnFake{Alive: true}, err: dbTransientErr, expectContinued: true, expectLastErrWasRetryable: true}, {conn: &testutil.ConnFake{Name: serverName, Alive: false}, err: io.EOF, isCommitting: true, expectContinued: false, expectLastErrWasRetryable: false, - expectLastErrType: &errorutil.CommitFailedDeadError{}, expectRouterInvalidated: true, - expectRouterInvalidatedDb: dbName, expectRouterInvalidatedServer: serverName}, + expectLastErrType: &errorutil.CommitFailedDeadError{}}, }, "Does not retry on auth errors": { {conn: nil, err: authErr, expectContinued: false, @@ -176,24 +160,12 @@ func TestState(outer *testing.T) { if !invocation.now.IsZero() { now = invocation.now } - router := &testutil.RouterFake{} - state.Router = router - state.OnDeadConnection = func(server string) error { - return router.InvalidateReader(ctx, dbName, server) - } state.OnFailure(ctx, invocation.err, invocation.conn, invocation.isCommitting) continued := state.Continue() if continued != invocation.expectContinued { t.Errorf("Expected continue to return %v but returned %v", invocation.expectContinued, continued) } - if invocation.expectRouterInvalidated != router.Invalidated || - 
invocation.expectRouterInvalidatedDb != router.InvalidatedDb || - invocation.expectRouterInvalidatedServer != router.InvalidatedServer { - t.Errorf("Expected router invalidated: expected (%v/%s/%s) vs. actual (%v/%s/%s)", - invocation.expectRouterInvalidated, invocation.expectRouterInvalidatedDb, invocation.expectRouterInvalidatedServer, - router.Invalidated, router.InvalidatedDb, router.InvalidatedServer) - } var lastError error if err, ok := state.Errs[0].(*errorutil.TransactionExecutionLimit); ok { errs := err.Errors diff --git a/neo4j/internal/router/no_test.go b/neo4j/internal/router/no_test.go index c64102ea..587957a6 100644 --- a/neo4j/internal/router/no_test.go +++ b/neo4j/internal/router/no_test.go @@ -32,15 +32,11 @@ type poolFake struct { cancel context.CancelFunc } -func (p *poolFake) Borrow(ctx context.Context, getServers func(context.Context) ([]string, error), _ bool, logger log.BoltLogger, _ time.Duration, _ *db.ReAuthToken) (db.Connection, error) { - servers, err := getServers(ctx) - if err != nil { - return nil, err - } +func (p *poolFake) Borrow(_ context.Context, getServers func() []string, _ bool, logger log.BoltLogger, _ time.Duration, _ *db.ReAuthToken) (db.Connection, error) { + servers := getServers() return p.borrow(servers, p.cancel, logger) } -func (p *poolFake) Return(_ context.Context, c db.Connection) error { +func (p *poolFake) Return(_ context.Context, c db.Connection) { p.returned = append(p.returned, c) - return nil } diff --git a/neo4j/internal/router/readtable.go b/neo4j/internal/router/readtable.go index bffc8420..15f12907 100644 --- a/neo4j/internal/router/readtable.go +++ b/neo4j/internal/router/readtable.go @@ -75,8 +75,8 @@ func readTable( return nil, err } -func getStaticServer(server string) func(context.Context) ([]string, error) { - return func(context.Context) ([]string, error) { - return []string{server}, nil +func getStaticServer(server string) func() []string { + return func() []string { + return []string{server} } 
} diff --git a/neo4j/internal/router/router.go b/neo4j/internal/router/router.go index ac82dcb5..ec68d689 100644 --- a/neo4j/internal/router/router.go +++ b/neo4j/internal/router/router.go @@ -26,6 +26,7 @@ import ( idb "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/db" "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/errorutil" "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/racing" + "sync" "time" "github.com/neo4j/neo4j-go-driver/v5/neo4j/log" @@ -44,7 +45,8 @@ type Router struct { routerContext map[string]string pool Pool dbRouters map[string]*databaseRouter - dbRoutersMut racing.Mutex + updating map[string][]chan struct{} + dbRoutersMut sync.Mutex now *func() time.Time sleep func(time.Duration) rootRouter string @@ -58,8 +60,8 @@ type Pool interface { // If all connections are busy and the pool is full, calls to Borrow may wait for a connection to become idle // If a connection has been idle for longer than idlenessThreshold, it will be reset // to check if it's still alive. 
- Borrow(ctx context.Context, getServers func(context.Context) ([]string, error), wait bool, boltLogger log.BoltLogger, idlenessThreshold time.Duration, auth *idb.ReAuthToken) (idb.Connection, error) - Return(ctx context.Context, c idb.Connection) error + Borrow(ctx context.Context, getServers func() []string, wait bool, boltLogger log.BoltLogger, idlenessThreshold time.Duration, auth *idb.ReAuthToken) (idb.Connection, error) + Return(ctx context.Context, c idb.Connection) } func New(rootRouter string, getRouters func() []string, routerContext map[string]string, pool Pool, logger log.Logger, logId string, timer *func() time.Time) *Router { @@ -69,7 +71,8 @@ func New(rootRouter string, getRouters func() []string, routerContext map[string routerContext: routerContext, pool: pool, dbRouters: make(map[string]*databaseRouter), - dbRoutersMut: racing.NewMutex(), + updating: make(map[string][]chan struct{}), + dbRoutersMut: sync.Mutex{}, now: timer, sleep: time.Sleep, log: logger, @@ -139,40 +142,63 @@ func (r *Router) readTable( return table, nil } -func (r *Router) getTable(ctx context.Context, database string) (*idb.RoutingTable, error) { - if !r.dbRoutersMut.TryLock(ctx) { - return nil, racing.LockTimeoutError("could not acquire router lock in time when getting routing table") - } +func (r *Router) getTable(database string) *idb.RoutingTable { + r.dbRoutersMut.Lock() defer r.dbRoutersMut.Unlock() dbRouter := r.dbRouters[database] - return r.getTableLocked(dbRouter, (*r.now)()), nil + return r.getTableLocked(dbRouter) } func (r *Router) getOrUpdateTable(ctx context.Context, bookmarksFn func(context.Context) ([]string, error), database string, auth *idb.ReAuthToken, boltLogger log.BoltLogger) (*idb.RoutingTable, error) { - now := (*r.now)() - - if !r.dbRoutersMut.TryLock(ctx) { - return nil, racing.LockTimeoutError("could not acquire router lock in time when getting routing table") - } - defer r.dbRoutersMut.Unlock() - - dbRouter := r.dbRouters[database] - if table := 
r.getTableLocked(dbRouter, now); table != nil { - return table, nil + r.dbRoutersMut.Lock() + var unlock = new(sync.Once) + defer unlock.Do(r.dbRoutersMut.Unlock) + for { + dbRouter := r.dbRouters[database] + if table := r.getTableLocked(dbRouter); table != nil { + return table, nil + } + waiters, ok := r.updating[database] + if ok { + // Wait for the table to be updated by other goroutine + ch := make(chan struct{}) + r.updating[database] = append(waiters, ch) + unlock.Do(r.dbRoutersMut.Unlock) + select { + case <-ctx.Done(): + return nil, racing.LockTimeoutError("timed out waiting for other goroutine to update routing table") + case <-ch: + r.dbRoutersMut.Lock() + *unlock = sync.Once{} + continue + } + } + // this goroutine will update the table + r.updating[database] = make([]chan struct{}, 0) + unlock.Do(r.dbRoutersMut.Unlock) + + table, err := r.updateTable(ctx, bookmarksFn, database, auth, boltLogger, dbRouter) + r.dbRoutersMut.Lock() + *unlock = sync.Once{} + // notify all waiters + for _, waiter := range r.updating[database] { + waiter <- struct{}{} + } + delete(r.updating, database) + return table, err } - - return r.updateTable(ctx, bookmarksFn, database, auth, boltLogger, dbRouter, now) } -func (r *Router) getTableLocked(dbRouter *databaseRouter, now time.Time) *idb.RoutingTable { +func (r *Router) getTableLocked(dbRouter *databaseRouter) *idb.RoutingTable { + now := (*r.now)() if dbRouter != nil && now.Unix() < dbRouter.dueUnix { return dbRouter.table } return nil } -func (r *Router) updateTable(ctx context.Context, bookmarksFn func(context.Context) ([]string, error), database string, auth *idb.ReAuthToken, boltLogger log.BoltLogger, dbRouter *databaseRouter, now time.Time) (*idb.RoutingTable, error) { +func (r *Router) updateTable(ctx context.Context, bookmarksFn func(context.Context) ([]string, error), database string, auth *idb.ReAuthToken, boltLogger log.BoltLogger, dbRouter *databaseRouter) (*idb.RoutingTable, error) { bookmarks, err := 
bookmarksFn(ctx) if err != nil { return nil, err @@ -182,7 +208,10 @@ func (r *Router) updateTable(ctx context.Context, bookmarksFn func(context.Conte return nil, err } - r.storeRoutingTable(database, table, now) + err = r.storeRoutingTable(ctx, database, table, (*r.now)()) + if err != nil { + return nil, err + } return table, nil } @@ -201,9 +230,7 @@ func (r *Router) GetOrUpdateReaders(ctx context.Context, bookmarks func(context. break } r.log.Infof(log.Router, r.logId, "Invalidating routing table, no readers") - if err := r.Invalidate(ctx, table.DatabaseName); err != nil { - return nil, err - } + r.Invalidate(table.DatabaseName) r.sleep(100 * time.Millisecond) table, err = r.getOrUpdateTable(ctx, bookmarks, database, auth, boltLogger) if err != nil { @@ -217,15 +244,12 @@ func (r *Router) GetOrUpdateReaders(ctx context.Context, bookmarks func(context. return table.Readers, nil } -func (r *Router) Readers(ctx context.Context, database string) ([]string, error) { - table, err := r.getTable(ctx, database) - if err != nil { - return nil, err - } +func (r *Router) Readers(database string) []string { + table := r.getTable(database) if table == nil { - return nil, nil + return nil } - return table.Readers, nil + return table.Readers } func (r *Router) GetOrUpdateWriters(ctx context.Context, bookmarks func(context.Context) ([]string, error), database string, auth *idb.ReAuthToken, boltLogger log.BoltLogger) ([]string, error) { @@ -242,9 +266,7 @@ func (r *Router) GetOrUpdateWriters(ctx context.Context, bookmarks func(context. break } r.log.Infof(log.Router, r.logId, "Invalidating routing table, no writers") - if err := r.Invalidate(ctx, database); err != nil { - return nil, err - } + r.Invalidate(database) r.sleep(100 * time.Millisecond) table, err = r.getOrUpdateTable(ctx, bookmarks, database, auth, boltLogger) if err != nil { @@ -258,29 +280,26 @@ func (r *Router) GetOrUpdateWriters(ctx context.Context, bookmarks func(context. 
return table.Writers, nil } -func (r *Router) Writers(ctx context.Context, database string) ([]string, error) { - table, err := r.getTable(ctx, database) - if err != nil { - return nil, err - } +func (r *Router) Writers(database string) []string { + table := r.getTable(database) if table == nil { - return nil, nil + return nil } - return table.Writers, nil + return table.Writers } func (r *Router) GetNameOfDefaultDatabase(ctx context.Context, bookmarks []string, user string, auth *idb.ReAuthToken, boltLogger log.BoltLogger) (string, error) { + // FIXME: this seems to indirectly cache the home db for the routing table's TTL table, err := r.readTable(ctx, nil, bookmarks, idb.DefaultDatabase, user, auth, boltLogger) if err != nil { return "", err } // Store the fresh routing table as well to avoid another roundtrip to receive servers from session. now := (*r.now)() - if !r.dbRoutersMut.TryLock(ctx) { - return "", racing.LockTimeoutError("could not acquire router lock in time when resolving home database") + err = r.storeRoutingTable(ctx, table.DatabaseName, table, now) + if err != nil { + return "", err } - defer r.dbRoutersMut.Unlock() - r.storeRoutingTable(table.DatabaseName, table, now) return table.DatabaseName, err } @@ -288,11 +307,9 @@ func (r *Router) Context() map[string]string { return r.routerContext } -func (r *Router) Invalidate(ctx context.Context, database string) error { +func (r *Router) Invalidate(database string) { r.log.Infof(log.Router, r.logId, "Invalidating routing table for '%s'", database) - if !r.dbRoutersMut.TryLock(ctx) { - return racing.LockTimeoutError("could not acquire router lock in time when invalidating database router") - } + r.dbRoutersMut.Lock() defer r.dbRoutersMut.Unlock() // Reset due time to the 70s, this will make next access refresh the routing table using // last set of routers instead of the original one. 
@@ -300,55 +317,53 @@ func (r *Router) Invalidate(ctx context.Context, database string) error { if dbRouter != nil { dbRouter.dueUnix = 0 } - return nil } -func (r *Router) InvalidateWriter(ctx context.Context, db string, server string) error { - if !r.dbRoutersMut.TryLock(ctx) { - return racing.LockTimeoutError("could not acquire router lock in time when getting routing table") - } +func (r *Router) InvalidateWriter(db string, server string) { + r.dbRoutersMut.Lock() defer r.dbRoutersMut.Unlock() router := r.dbRouters[db] if router == nil { - return nil + return } - writers := router.table.Writers - for i, writer := range writers { - if writer == server { - router.table.Writers = append(writers[0:i], writers[i+1:]...) - return nil - } - } - return nil + router.table.Writers = removeServerFromList(router.table.Writers, server) } -func (r *Router) InvalidateReader(ctx context.Context, db string, server string) error { - if !r.dbRoutersMut.TryLock(ctx) { - return racing.LockTimeoutError("could not acquire router lock in time when invalidating reader") - } +func (r *Router) InvalidateReader(db string, server string) { + r.dbRoutersMut.Lock() defer r.dbRoutersMut.Unlock() router := r.dbRouters[db] if router == nil { - return nil + return } - readers := router.table.Readers - for i, reader := range readers { - if reader == server { - router.table.Readers = append(readers[0:i], readers[i+1:]...) 
- return nil + router.table.Readers = removeServerFromList(router.table.Readers, server) +} + +func (r *Router) InvalidateServer(server string) { + r.dbRoutersMut.Lock() + defer r.dbRoutersMut.Unlock() + for _, routing := range r.dbRouters { + routing.table.Routers = removeServerFromList(routing.table.Routers, server) + routing.table.Readers = removeServerFromList(routing.table.Readers, server) + routing.table.Writers = removeServerFromList(routing.table.Writers, server) + } +} + +func removeServerFromList(list []string, server string) []string { + for i, s := range list { + if s == server { + return append(list[0:i], list[i+1:]...) } } - return nil + return list } -func (r *Router) CleanUp(ctx context.Context) error { +func (r *Router) CleanUp() { r.log.Debugf(log.Router, r.logId, "Cleaning up") now := (*r.now)().Unix() - if !r.dbRoutersMut.TryLock(ctx) { - return racing.LockTimeoutError("could not acquire router lock in time when invalidating reader") - } + r.dbRoutersMut.Lock() defer r.dbRoutersMut.Unlock() for dbName, dbRouter := range r.dbRouters { @@ -356,15 +371,17 @@ func (r *Router) CleanUp(ctx context.Context) error { delete(r.dbRouters, dbName) } } - return nil } -func (r *Router) storeRoutingTable(database string, table *idb.RoutingTable, now time.Time) { +func (r *Router) storeRoutingTable(ctx context.Context, database string, table *idb.RoutingTable, now time.Time) error { + r.dbRoutersMut.Lock() + defer r.dbRoutersMut.Unlock() r.dbRouters[database] = &databaseRouter{ table: table, dueUnix: now.Add(time.Duration(table.TimeToLive) * time.Second).Unix(), } r.log.Debugf(log.Router, r.logId, "New routing table for '%s', TTL %d", database, table.TimeToLive) + return nil } func wrapError(server string, err error) error { diff --git a/neo4j/internal/router/router_test.go b/neo4j/internal/router/router_test.go index 43323931..98eda90d 100644 --- a/neo4j/internal/router/router_test.go +++ b/neo4j/internal/router/router_test.go @@ -140,9 +140,7 @@ func 
TestRespectsTimeToLiveAndInvalidate(t *testing.T) { assertNum(t, numfetch, 2, "Should not have have fetched") // Invalidate should force fetching - if err := router.Invalidate(ctx, dbName); err != nil { - testutil.AssertNoError(t, err) - } + router.Invalidate(dbName) if _, err := router.GetOrUpdateReaders(ctx, nilBookmarks, dbName, nil, nil); err != nil { testutil.AssertNoError(t, err) } @@ -387,17 +385,13 @@ func TestCleanUp(t *testing.T) { } // Should not remove these since they still have time to live - if err := router.CleanUp(ctx); err != nil { - testutil.AssertNoError(t, err) - } + router.CleanUp() if len(router.dbRouters) != 2 { t.Fatal("Should not have removed routing tables") } timer = func() time.Time { return now.Add(1 * time.Minute) } - if err := router.CleanUp(ctx); err != nil { - testutil.AssertNoError(t, err) - } + router.CleanUp() if len(router.dbRouters) != 0 { t.Fatal("Should have cleaned up") } diff --git a/neo4j/internal/router/router_testkit.go b/neo4j/internal/router/router_testkit.go new file mode 100644 index 00000000..e4f0d2f2 --- /dev/null +++ b/neo4j/internal/router/router_testkit.go @@ -0,0 +1,28 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [https://neo4j.com] + * + * This file is part of Neo4j. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +//go:build internal_testkit + +package router + +import idb "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/db" + +func (r *Router) GetTable(database string) *idb.RoutingTable { + return r.getTable(database) +} diff --git a/neo4j/internal/testutil/connfake.go b/neo4j/internal/testutil/connfake.go index 906c64f2..577ec429 100644 --- a/neo4j/internal/testutil/connfake.go +++ b/neo4j/internal/testutil/connfake.go @@ -193,6 +193,10 @@ func (c *ConnFake) SelectDatabase(database string) { c.DatabaseName = database } +func (c *ConnFake) Database() string { + return c.DatabaseName +} + func (c *ConnFake) SetBoltLogger(log.BoltLogger) { } diff --git a/neo4j/internal/testutil/poolfake.go b/neo4j/internal/testutil/poolfake.go index 8e836ba7..0c569ca4 100644 --- a/neo4j/internal/testutil/poolfake.go +++ b/neo4j/internal/testutil/poolfake.go @@ -34,7 +34,7 @@ type PoolFake struct { BorrowHook func() (db.Connection, error) } -func (p *PoolFake) Borrow(context.Context, func(context.Context) ([]string, error), bool, log.BoltLogger, time.Duration, *db.ReAuthToken) (db.Connection, error) { +func (p *PoolFake) Borrow(context.Context, func() []string, bool, log.BoltLogger, time.Duration, *db.ReAuthToken) (db.Connection, error) { if p.BorrowHook != nil && (p.BorrowConn != nil || p.BorrowErr != nil) { panic("either use the hook or the desired return values, but not both") } @@ -44,18 +44,16 @@ func (p *PoolFake) Borrow(context.Context, func(context.Context) ([]string, erro return p.BorrowConn, p.BorrowErr } -func (p *PoolFake) Return(context.Context, db.Connection) error { +func (p *PoolFake) Return(context.Context, db.Connection) { if p.ReturnHook != nil { p.ReturnHook() } - return nil } -func (p *PoolFake) CleanUp(context.Context) error { +func (p *PoolFake) CleanUp(context.Context) { if p.CleanUpHook != nil { p.CleanUpHook() } - return nil } func (p *PoolFake) Now() time.Time { diff --git a/neo4j/internal/testutil/routerfake.go b/neo4j/internal/testutil/routerfake.go 
index 83a4ec3e..1c3422c6 100644 --- a/neo4j/internal/testutil/routerfake.go +++ b/neo4j/internal/testutil/routerfake.go @@ -28,6 +28,8 @@ import ( type RouterFake struct { Invalidated bool InvalidatedDb string + InvalidateMode string + InvalidatedServer string GetOrUpdateReadersRet []string GetOrUpdateReadersHook func(bookmarks func(context.Context) ([]string, error), database string) ([]string, error) GetOrUpdateWritersRet []string @@ -35,25 +37,28 @@ type RouterFake struct { Err error CleanUpHook func() GetNameOfDefaultDbHook func(user string) (string, error) - InvalidatedServer string } -func (r *RouterFake) InvalidateReader(ctx context.Context, database string, server string) error { - if err := r.Invalidate(ctx, database); err != nil { - return err - } +func (r *RouterFake) InvalidateReader(database string, server string) { + r.Invalidate(database) r.InvalidatedServer = server - return nil + r.InvalidateMode = "reader" } -func (r *RouterFake) InvalidateWriter(context.Context, string, string) error { - return nil +func (r *RouterFake) InvalidateWriter(database string, server string) { + r.Invalidate(database) + r.InvalidatedServer = server + r.InvalidateMode = "writer" +} + +func (r *RouterFake) InvalidateServer(server string) { + r.Invalidated = true + r.InvalidatedServer = server } -func (r *RouterFake) Invalidate(ctx context.Context, database string) error { +func (r *RouterFake) Invalidate(database string) { r.InvalidatedDb = database r.Invalidated = true - return nil } func (r *RouterFake) GetOrUpdateReaders(_ context.Context, bookmarksFn func(context.Context) ([]string, error), database string, _ *db.ReAuthToken, _ log.BoltLogger) ([]string, error) { @@ -63,8 +68,8 @@ func (r *RouterFake) GetOrUpdateReaders(_ context.Context, bookmarksFn func(cont return r.GetOrUpdateReadersRet, r.Err } -func (r *RouterFake) Readers(context.Context, string) ([]string, error) { - return nil, nil +func (r *RouterFake) Readers(string) []string { + return nil } func (r 
*RouterFake) GetOrUpdateWriters(_ context.Context, bookmarksFn func(context.Context) ([]string, error), database string, _ *db.ReAuthToken, _ log.BoltLogger) ([]string, error) { @@ -74,8 +79,8 @@ func (r *RouterFake) GetOrUpdateWriters(_ context.Context, bookmarksFn func(cont return r.GetOrUpdateWritersRet, r.Err } -func (r *RouterFake) Writers(context.Context, string) ([]string, error) { - return nil, nil +func (r *RouterFake) Writers(string) []string { + return nil } func (r *RouterFake) GetNameOfDefaultDatabase(_ context.Context, _ []string, user string, _ *db.ReAuthToken, _ log.BoltLogger) (string, error) { @@ -85,9 +90,8 @@ func (r *RouterFake) GetNameOfDefaultDatabase(_ context.Context, _ []string, use return "", nil } -func (r *RouterFake) CleanUp(ctx context.Context) error { +func (r *RouterFake) CleanUp() { if r.CleanUpHook != nil { r.CleanUpHook() } - return nil } diff --git a/neo4j/session_with_context.go b/neo4j/session_with_context.go index dd906b62..09433d43 100644 --- a/neo4j/session_with_context.go +++ b/neo4j/session_with_context.go @@ -189,9 +189,9 @@ const FetchDefault = 0 // Connection pool as seen by the session. type sessionPool interface { - Borrow(ctx context.Context, getServers func(context.Context) ([]string, error), wait bool, boltLogger log.BoltLogger, livenessCheckThreshold time.Duration, auth *idb.ReAuthToken) (idb.Connection, error) - Return(ctx context.Context, c idb.Connection) error - CleanUp(ctx context.Context) error + Borrow(ctx context.Context, getServerNames func() []string, wait bool, boltLogger log.BoltLogger, livenessCheckThreshold time.Duration, auth *idb.ReAuthToken) (idb.Connection, error) + Return(ctx context.Context, c idb.Connection) + CleanUp(ctx context.Context) Now() time.Time } @@ -312,7 +312,7 @@ func (s *sessionWithContext) BeginTransaction(ctx context.Context, configurers . 
beginBookmarks, err := s.getBookmarks(ctx) if err != nil { - _ = s.pool.Return(ctx, conn) + s.pool.Return(ctx, conn) return nil, errorutil.WrapError(err) } txHandle, err := conn.TxBegin(ctx, @@ -328,7 +328,7 @@ func (s *sessionWithContext) BeginTransaction(ctx context.Context, configurers . }, }) if err != nil { - _ = s.pool.Return(ctx, conn) + s.pool.Return(ctx, conn) return nil, errorutil.WrapError(err) } @@ -343,8 +343,8 @@ func (s *sessionWithContext) BeginTransaction(ctx context.Context, configurers . } // On run failure, transaction closed (rolled back or committed) bookmarkErr := s.retrieveBookmarks(ctx, tx.conn, beginBookmarks) - poolErr := s.pool.Return(ctx, tx.conn) - tx.err = errorutil.CombineAllErrors(tx.err, bookmarkErr, poolErr) + s.pool.Return(ctx, tx.conn) + tx.err = errorutil.CombineAllErrors(tx.err, bookmarkErr) tx.conn = nil s.explicitTx = nil }, @@ -397,20 +397,7 @@ func (s *sessionWithContext) runRetriable( Sleep: s.sleep, Throttle: retry.Throttler(s.throttleTime), MaxDeadConnections: s.driverConfig.MaxConnectionPoolSize, - Router: s.router, DatabaseName: s.config.DatabaseName, - OnDeadConnection: func(server string) error { - if mode == idb.WriteMode { - if err := s.router.InvalidateWriter(ctx, s.config.DatabaseName, server); err != nil { - return err - } - } else { - if err := s.router.InvalidateReader(ctx, s.config.DatabaseName, server); err != nil { - return err - } - } - return nil - }, } for state.Continue() { if hasCompleted, result := s.executeTransactionFunction(ctx, mode, config, &state, work); hasCompleted { @@ -438,7 +425,7 @@ func (s *sessionWithContext) executeTransactionFunction( // handle transaction function panic as well defer func() { - _ = s.pool.Return(ctx, conn) + s.pool.Return(ctx, conn) }() beginBookmarks, err := s.getBookmarks(ctx) @@ -496,12 +483,12 @@ func (s *sessionWithContext) getOrUpdateServers(ctx context.Context, mode idb.Ac } } -func (s *sessionWithContext) getServers(mode idb.AccessMode) func(context.Context) 
([]string, error) { - return func(ctx context.Context) ([]string, error) { +func (s *sessionWithContext) getServers(mode idb.AccessMode) func() []string { + return func() []string { if mode == idb.ReadMode { - return s.router.Readers(ctx, s.config.DatabaseName) + return s.router.Readers(s.config.DatabaseName) } else { - return s.router.Writers(ctx, s.config.DatabaseName) + return s.router.Writers(s.config.DatabaseName) } } } @@ -594,7 +581,7 @@ func (s *sessionWithContext) Run(ctx context.Context, runBookmarks, err := s.getBookmarks(ctx) if err != nil { - _ = s.pool.Return(ctx, conn) + s.pool.Return(ctx, conn) return nil, errorutil.WrapError(err) } stream, err := conn.Run( @@ -617,7 +604,7 @@ func (s *sessionWithContext) Run(ctx context.Context, }, ) if err != nil { - _ = s.pool.Return(ctx, conn) + s.pool.Return(ctx, conn) return nil, errorutil.WrapError(err) } @@ -630,7 +617,7 @@ func (s *sessionWithContext) Run(ctx context.Context, } }), onClosed: func() { - _ = s.pool.Return(ctx, conn) + s.pool.Return(ctx, conn) s.autocommitTx = nil }, } @@ -649,15 +636,19 @@ func (s *sessionWithContext) Close(ctx context.Context) error { } defer s.log.Debugf(log.Session, s.logId, "Closed") - poolErrChan := make(chan error, 1) - routerErrChan := make(chan error, 1) + poolCleanUpChan := make(chan struct{}, 1) + routerCleanUpChan := make(chan struct{}, 1) go func() { - poolErrChan <- s.pool.CleanUp(ctx) + s.pool.CleanUp(ctx) + poolCleanUpChan <- struct{}{} }() go func() { - routerErrChan <- s.router.CleanUp(ctx) + s.router.CleanUp() + routerCleanUpChan <- struct{}{} }() - return errorutil.CombineAllErrors(txErr, <-poolErrChan, <-routerErrChan) + <-poolCleanUpChan + <-routerCleanUpChan + return txErr } func (s *sessionWithContext) legacy() Session { diff --git a/neo4j/test-integration/dbconn_test.go b/neo4j/test-integration/dbconn_test.go index c1f2e776..57691f71 100644 --- a/neo4j/test-integration/dbconn_test.go +++ b/neo4j/test-integration/dbconn_test.go @@ -40,10 +40,21 @@ 
import ( "github.com/neo4j/neo4j-go-driver/v5/neo4j/test-integration/dbserver" ) -func noopOnNeo4jError(context.Context, idb.Connection, *db.Neo4jError) error { +type noopErrorListener struct{} + +func (n noopErrorListener) OnNeo4jError(_ context.Context, _ idb.Connection, e *db.Neo4jError) error { + fmt.Println("OnNeo4jError", e) return nil } +func (n noopErrorListener) OnIoError(_ context.Context, _ idb.Connection, e error) { + fmt.Println("OnIoError", e) +} + +func (n noopErrorListener) OnDialError(_ context.Context, _ string, e error) { + fmt.Println("OnDialError", e) +} + func makeRawConnection(ctx context.Context, logger log.Logger, boltLogger log.BoltLogger) ( dbserver.DbServer, idb.Connection) { server := dbserver.GetDbServer(ctx) @@ -77,7 +88,7 @@ func makeRawConnection(ctx context.Context, logger log.Logger, boltLogger log.Bo auth, "007", nil, - noopOnNeo4jError, + noopErrorListener{}, logger, boltLogger, idb.NotificationConfig{}, diff --git a/neo4j/test-integration/driver_test.go b/neo4j/test-integration/driver_test.go index 7b53a60b..89f8e968 100644 --- a/neo4j/test-integration/driver_test.go +++ b/neo4j/test-integration/driver_test.go @@ -40,7 +40,7 @@ func TestDriver(outer *testing.T) { outer.Run("VerifyConnectivity", func(inner *testing.T) { inner.Run("should return nil upon good connection", func(t *testing.T) { driver := server.Driver() - defer driver.Close(ctx) + defer func() { _ = driver.Close(ctx) }() assertNil(t, driver.VerifyConnectivity(ctx)) }) @@ -48,7 +48,7 @@ func TestDriver(outer *testing.T) { auth := neo4j.BasicAuth("bad user", "bad pass", "bad area") driver, err := neo4j.NewDriverWithContext(server.BoltURI(), auth, server.ConfigFunc()) assertNil(t, err) - defer driver.Close(ctx) + defer func() { _ = driver.Close(ctx) }() err = driver.VerifyConnectivity(ctx) assertNotNil(t, err) }) @@ -64,11 +64,11 @@ func TestDriver(outer *testing.T) { tearDown := func(driver neo4j.DriverWithContext, session neo4j.SessionWithContext) { if session != 
nil { - session.Close(ctx) + _ = session.Close(ctx) } if driver != nil { - driver.Close(ctx) + _ = driver.Close(ctx) } } @@ -126,7 +126,7 @@ func TestDriver(outer *testing.T) { defer func() { if driver != nil { - driver.Close(ctx) + _ = driver.Close(ctx) } }() @@ -165,7 +165,7 @@ func TestDriver(outer *testing.T) { defer func() { if driver != nil { - driver.Close(ctx) + _ = driver.Close(ctx) } }() diff --git a/testkit-backend/backend.go b/testkit-backend/backend.go index f413d157..07a07db8 100644 --- a/testkit-backend/backend.go +++ b/testkit-backend/backend.go @@ -150,6 +150,7 @@ func (b *backend) writeLine(s string) error { func (b *backend) writeLineLocked(s string) error { b.wrLock.Lock() defer b.wrLock.Unlock() + fmt.Println(s) return b.writeLine(s) } @@ -608,7 +609,7 @@ func (b *backend) handleRequest(req map[string]any) { case "NewSession": driver := b.drivers[data["driverId"].(string)] sessionConfig := neo4j.SessionConfig{ - BoltLogger: neo4j.ConsoleBoltLogger(), + BoltLogger: &streamLog{writeLine: b.writeLineLocked}, } if data["accessMode"] != nil { switch data["accessMode"].(string) { @@ -844,6 +845,54 @@ func (b *backend) handleRequest(req map[string]any) { } b.writeResponse("Summary", serializeSummary(summary)) + case "ForcedRoutingTableUpdate": + databaseRaw := data["database"] + var database string + if databaseRaw != nil { + database = databaseRaw.(string) + } + var bookmarks []string + bookmarksRaw := data["bookmarks"] + if bookmarksRaw != nil { + bookmarksSlice := bookmarksRaw.([]any) + bookmarks = make([]string, len(bookmarksSlice)) + for i, bookmark := range bookmarksSlice { + bookmarks[i] = bookmark.(string) + } + } + driverId := data["driverId"].(string) + driver := b.drivers[driverId] + err := neo4j.ForceRoutingTableUpdate(driver, database, bookmarks, &streamLog{writeLine: b.writeLineLocked}) + if err != nil { + b.writeError(err) + return + } + b.writeResponse("Driver", map[string]any{"id": driverId}) + + case "GetRoutingTable": + driver := 
b.drivers[data["driverId"].(string)] + databaseRaw := data["database"] + var database string + if databaseRaw != nil { + database = databaseRaw.(string) + } + table, err := neo4j.GetRoutingTable(driver, database) + if err != nil { + b.writeError(err) + return + } + var databaseName any = table.DatabaseName + if databaseName == "" { + databaseName = nil + } + b.writeResponse("RoutingTable", map[string]any{ + "database": databaseName, + "ttl": table.TimeToLive, + "routers": table.Routers, + "readers": table.Readers, + "writers": table.Writers, + }) + case "CheckMultiDBSupport": driver := b.drivers[data["driverId"].(string)] session := driver.NewSession(ctx, neo4j.SessionConfig{ @@ -1041,10 +1090,7 @@ func (b *backend) handleRequest(req map[string]any) { case "GetFeatures": b.writeResponse("FeatureList", map[string]any{ "features": []string{ - "AuthorizationExpiredTreatment", - "Backend:MockTime", - "ConfHint:connection.recv_timeout_seconds", - "Detail:ClosedDriverIsEncrypted", + // === FUNCTIONAL FEATURES === "Feature:API:BookmarkManager", "Feature:API:ConnectionAcquisitionTimeout", "Feature:API:Driver.ExecuteQuery", @@ -1053,15 +1099,21 @@ func (b *backend) handleRequest(req map[string]any) { "Feature:API:Driver:NotificationsConfig", "Feature:API:Driver.VerifyAuthentication", "Feature:API:Driver.VerifyConnectivity", - "Feature:API:Liveness.Check", + //"Feature:API:Driver.SupportsSessionAuth", + // Go driver does not support LivenessCheckTimeout yet + //"Feature:API:Liveness.Check", "Feature:API:Result.List", "Feature:API:Result.Peek", + //"Feature:API:Result.Single", + //"Feature:API:Result.SingleOptional", "Feature:API:Session:AuthConfig", - "Feature:API:Session:NotificationsConfig", + //"Feature:API:Session:NotificationsConfig", + //"Feature:API:SSLConfig", + //"Feature:API:SSLSchemes", "Feature:API:Type.Spatial", "Feature:API:Type.Temporal", - "Feature:Auth:Custom", "Feature:Auth:Bearer", + "Feature:Auth:Custom", "Feature:Auth:Kerberos", "Feature:Auth:Managed", 
"Feature:Bolt:3.0", @@ -1075,15 +1127,33 @@ func (b *backend) handleRequest(req map[string]any) { "Feature:Bolt:5.3", "Feature:Bolt:Patch:UTC", "Feature:Impersonation", + //"Feature:TLS:1.1", "Feature:TLS:1.2", "Feature:TLS:1.3", - "Optimization:AuthPipelining", + + // === OPTIMIZATIONS === + "AuthorizationExpiredTreatment", "Optimization:ConnectionReuse", "Optimization:EagerTransactionBegin", "Optimization:ImplicitDefaultArguments", "Optimization:MinimalBookmarksSet", "Optimization:MinimalResets", + //"Optimization:MinimalVerifyAuthentication", + "Optimization:AuthPipelining", "Optimization:PullPipelining", + //"Optimization:ResultListFetchAll", + + // === IMPLEMENTATION DETAILS === + "Detail:ClosedDriverIsEncrypted", + "Detail:DefaultSecurityConfigValueEquality", + + // === CONFIGURATION HINTS (BOLT 4.3+) === + "ConfHint:connection.recv_timeout_seconds", + + // === BACKEND FEATURES FOR TESTING === + "Backend:MockTime", + "Backend:RTFetch", + "Backend:RTForceUpdate", }, }) @@ -1393,34 +1463,36 @@ func firstRecordInvalidValue(record *db.Record) *neo4j.InvalidValue { // you can use '*' as wildcards anywhere in the qualified test name (useful to exclude a whole class e.g.) 
func testSkips() map[string]string { return map[string]string{ - "stub.disconnects.test_disconnects.TestDisconnects.test_fail_on_reset": "It is not resetting driver when put back to pool", - "stub.routing.test_routing_v3.RoutingV3.test_should_use_resolver_during_rediscovery_when_existing_routers_fail": "It needs investigation - custom resolver does not seem to be called", - "stub.routing.test_routing_v4x1.RoutingV4x1.test_should_use_resolver_during_rediscovery_when_existing_routers_fail": "It needs investigation - custom resolver does not seem to be called", - "stub.routing.test_routing_v4x3.RoutingV4x3.test_should_use_resolver_during_rediscovery_when_existing_routers_fail": "It needs investigation - custom resolver does not seem to be called", - "stub.routing.test_routing_v4x4.RoutingV4x4.test_should_use_resolver_during_rediscovery_when_existing_routers_fail": "It needs investigation - custom resolver does not seem to be called", - "stub.routing.test_routing_v5x0.RoutingV5x0.test_should_use_resolver_during_rediscovery_when_existing_routers_fail": "It needs investigation - custom resolver does not seem to be called", - "stub.routing.test_routing_v3.RoutingV3.test_should_revert_to_initial_router_if_known_router_throws_protocol_errors": "It needs investigation - custom resolver does not seem to be called", - "stub.routing.test_routing_v4x1.RoutingV4x1.test_should_revert_to_initial_router_if_known_router_throws_protocol_errors": "It needs investigation - custom resolver does not seem to be called", - "stub.routing.test_routing_v4x3.RoutingV4x3.test_should_revert_to_initial_router_if_known_router_throws_protocol_errors": "It needs investigation - custom resolver does not seem to be called", - "stub.routing.test_routing_v4x4.RoutingV4x4.test_should_revert_to_initial_router_if_known_router_throws_protocol_errors": "It needs investigation - custom resolver does not seem to be called", - 
"stub.routing.test_routing_v5x0.RoutingV5x0.test_should_revert_to_initial_router_if_known_router_throws_protocol_errors": "It needs investigation - custom resolver does not seem to be called", - "stub.configuration_hints.test_connection_recv_timeout_seconds.TestRoutingConnectionRecvTimeout.*": "No GetRoutingTable support - too tricky to implement in Go", - "stub.homedb.test_homedb.TestHomeDb.test_session_should_cache_home_db_despite_new_rt": "Driver does not remove servers from RT when connection breaks.", - "stub.iteration.test_result_scope.TestResultScope.*": "Results are always valid but don't return records when out of scope", - "stub.*.test_0_timeout": "Driver omits 0 as tx timeout value", - "stub.*.test_negative_timeout": "Driver omits negative tx timeout values", - "stub.routing.*.*.test_should_request_rt_from_all_initial_routers_until_successful_on_unknown_failure": "Add DNS resolver TestKit message and connection timeout support", - "stub.routing.*.*.test_should_request_rt_from_all_initial_routers_until_successful_on_authorization_expired": "Add DNS resolver TestKit message and connection timeout support", - "stub.summary.test_summary.TestSummary.test_server_info": "Needs some kind of server address DNS resolution", - "stub.summary.test_summary.TestSummary.test_invalid_query_type": "Driver does not verify query type returned from server.", - "stub.routing.*.test_should_drop_connections_failing_liveness_check": "Needs support for GetConnectionPoolMetrics", + // Won't fix - accepted/idiomatic behavioral differences + "stub.iteration.test_result_scope.TestResultScope.*": "Won't fix - Results are always valid but don't return records when out of scope", "stub.connectivity_check.test_get_server_info.TestGetServerInfo.test_routing_fail_when_no_reader_are_available": "Won't fix - Go driver retries routing table when no readers are available", "stub.connectivity_check.test_verify_connectivity.TestVerifyConnectivity.test_routing_fail_when_no_reader_are_available": 
"Won't fix - Go driver retries routing table when no readers are available", "stub.driver_parameters.test_connection_acquisition_timeout_ms.TestConnectionAcquisitionTimeoutMs.test_does_not_encompass_router_*": "Won't fix - ConnectionAcquisitionTimeout spans the whole process including db resolution, RT updates, connection acquisition from the pool, and creation of new connections.", "stub.driver_parameters.test_connection_acquisition_timeout_ms.TestConnectionAcquisitionTimeoutMs.test_router_handshake_has_own_timeout_*": "Won't fix - ConnectionAcquisitionTimeout spans the whole process including db resolution, RT updates, connection acquisition from the pool, and creation of new connections.", - "stub.tx_run.test_tx_run.TestTxRun.test_should_prevent_discard_after_tx_termination_on_run": "fixme: usage of failed transaction leads to unintelligible error that's treated as BackendError", - "stub.tx_run.test_tx_run.TestTxRun.test_should_prevent_pull_after_tx_termination_on_run": "fixme: usage of failed transaction leads to unintelligible error that's treated as BackendError", - "stub.tx_run.test_tx_run.TestTxRun.test_should_prevent_commit_after_tx_termination": "fixme: commit is still sent when transaction is terminated", + "stub.routing.test_routing_v*.RoutingV*.test_should_successfully_check_if_support_for_multi_db_is_available": "Won't fix - driver.SupportsMultiDb() is not implemented", + "stub.routing.test_no_routing_v*.NoRoutingV*.test_should_check_multi_db_support": "Won't fix - driver.SupportsMultiDb() is not implemented", + "stub.routing.test_routing_v3.RoutingV3.test_should_fail_discovery_when_router_fails_with_procedure_not_found_code": "Won't fix - only Bolt 3 affected (not officially supported by this driver) + this is only a difference in how errors are surfaced", + "stub.routing.test_routing_v3.RoutingV3.test_should_fail_when_writing_on_unexpectedly_interrupting_writer_on_pull_using_tx_run": "Won't fix - only Bolt 3 affected (not officially supported by 
this driver): broken servers are not removed from routing table", + "stub.routing.test_routing_v3.RoutingV3.test_should_fail_when_writing_on_unexpectedly_interrupting_writer_on_run_using_tx_run": "Won't fix - only Bolt 3 affected (not officially supported by this driver): broken servers are not removed from routing table", + "stub.routing.test_routing_v3.RoutingV3.test_should_fail_when_writing_on_unexpectedly_interrupting_writer_using_tx_run": "Won't fix - only Bolt 3 affected (not officially supported by this driver): broken servers are not removed from routing table", + + // Missing message support in testkit backend + "stub.routing.*.*.test_should_request_rt_from_all_initial_routers_until_successful_on_unknown_failure": "Add DNS resolver TestKit message and connection timeout support", + "stub.routing.*.*.test_should_request_rt_from_all_initial_routers_until_successful_on_authorization_expired": "Add DNS resolver TestKit message and connection timeout support", + + // To fix/to decide whether to fix + "stub.tx_run.test_tx_run.TestTxRun.test_should_prevent_discard_after_tx_termination_on_run": "fixme: usage of failed transaction leads to unintelligible error that's treated as BackendError", + "stub.tx_run.test_tx_run.TestTxRun.test_should_prevent_pull_after_tx_termination_on_run": "fixme: usage of failed transaction leads to unintelligible error that's treated as BackendError", + "stub.tx_run.test_tx_run.TestTxRun.test_should_prevent_commit_after_tx_termination": "fixme: commit is still sent when transaction is terminated", + "stub.routing.test_routing_v*.RoutingV*.test_should_revert_to_initial_router_if_known_router_throws_protocol_errors": "Driver always uses configured URL first and custom resolver only if that fails", + "stub.routing.test_routing_v*.RoutingV*.test_should_read_successfully_from_reachable_db_after_trying_unreachable_db": "Driver retries to fetch a routing table up to 100 times if it's empty", + 
"stub.routing.test_routing_v*.RoutingV*.test_should_write_successfully_after_leader_switch_using_tx_run": "Driver retries to fetch a routing table up to 100 times if it's empty", + "stub.routing.test_routing_v*.RoutingV*.test_should_fail_when_writing_without_writers_using_session_run": "Driver retries to fetch a routing table up to 100 times if it's empty", + "stub.routing.test_routing_v*.RoutingV*.test_should_accept_routing_table_without_writers_and_then_rediscover": "Driver retries to fetch a routing table up to 100 times if it's empty", + "stub.routing.test_routing_v*.RoutingV*.test_should_fail_on_routing_table_with_no_reader": "Driver retries to fetch a routing table up to 100 times if it's empty", + "stub.routing.test_routing_v*.RoutingV*.test_should_fail_discovery_when_router_fails_with_unknown_code": "Unify: other drivers have a list of fast failing errors during discover: on anything else, the driver will try the next router", + "stub.*.test_0_timeout": "Fixme: driver omits 0 as tx timeout value", + "stub.summary.test_summary.TestSummary.test_server_info": "pending unification: should the server address be pre or post DNS resolution?", + } } diff --git a/testkit-backend/streamlogger.go b/testkit-backend/streamlogger.go index c51004be..fcdd1f9b 100644 --- a/testkit-backend/streamlogger.go +++ b/testkit-backend/streamlogger.go @@ -8,34 +8,56 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ package main import ( "fmt" + "time" ) +const timeFormat = "2006-01-02 15:04:05.000" + type streamLog struct { writeLine func(string) error } func (l *streamLog) Error(name string, id string, err error) { - l.writeLine(fmt.Sprintf("[%s %s] %s", name, id, err)) + _ = l.writeLine(fmt.Sprintf("[%s %s] %s", name, id, err)) } func (l *streamLog) Warnf(name string, id string, msg string, args ...any) { - l.writeLine(fmt.Sprintf("[%s %s] %s", name, id, fmt.Sprintf(msg, args...))) + _ = l.writeLine(fmt.Sprintf("[%s %s] %s", name, id, fmt.Sprintf(msg, args...))) } func (l *streamLog) Infof(name string, id string, msg string, args ...any) { - l.writeLine(fmt.Sprintf("[%s %s] %s", name, id, fmt.Sprintf(msg, args...))) + _ = l.writeLine(fmt.Sprintf("[%s %s] %s", name, id, fmt.Sprintf(msg, args...))) } func (l *streamLog) Debugf(name string, id string, msg string, args ...any) { - l.writeLine(fmt.Sprintf("[%s %s] %s", name, id, fmt.Sprintf(msg, args...))) + _ = l.writeLine(fmt.Sprintf("[%s %s] %s", name, id, fmt.Sprintf(msg, args...))) +} + +func (l *streamLog) LogClientMessage(id, msg string, args ...any) { + l.logBoltMessage("C", id, msg, args) +} + +func (l *streamLog) LogServerMessage(id, msg string, args ...any) { + l.logBoltMessage("S", id, msg, args) +} + +func (l *streamLog) logBoltMessage(src, id string, msg string, args []any) { + _ = l.writeLine(fmt.Sprintf("%s BOLT %s%s: %s", time.Now().Format(timeFormat), formatId(id), src, fmt.Sprintf(msg, args...))) +} + +func formatId(id string) string { + if id == "" { + return "" + } + return fmt.Sprintf("[%s] ", id) }