diff --git a/neo4j/directrouter.go b/neo4j/directrouter.go index fd0cfb1a..234efa5b 100644 --- a/neo4j/directrouter.go +++ b/neo4j/directrouter.go @@ -30,38 +30,32 @@ type directRouter struct { address string } -func (r *directRouter) InvalidateWriter(context.Context, string, string) error { - return nil -} +func (r *directRouter) InvalidateWriter(string, string) {} -func (r *directRouter) InvalidateReader(context.Context, string, string) error { - return nil -} +func (r *directRouter) InvalidateReader(string, string) {} + +func (r *directRouter) InvalidateServer(string) {} func (r *directRouter) GetOrUpdateReaders(context.Context, func(context.Context) ([]string, error), string, *db.ReAuthToken, log.BoltLogger) ([]string, error) { return []string{r.address}, nil } -func (r *directRouter) Readers(context.Context, string) ([]string, error) { - return []string{r.address}, nil +func (r *directRouter) Readers(string) []string { + return []string{r.address} } func (r *directRouter) GetOrUpdateWriters(context.Context, func(context.Context) ([]string, error), string, *db.ReAuthToken, log.BoltLogger) ([]string, error) { return []string{r.address}, nil } -func (r *directRouter) Writers(context.Context, string) ([]string, error) { - return []string{r.address}, nil +func (r *directRouter) Writers(string) []string { + return []string{r.address} } func (r *directRouter) GetNameOfDefaultDatabase(context.Context, []string, string, *db.ReAuthToken, log.BoltLogger) (string, error) { return db.DefaultDatabase, nil } -func (r *directRouter) Invalidate(context.Context, string) error { - return nil -} +func (r *directRouter) Invalidate(string) {} -func (r *directRouter) CleanUp(context.Context) error { - return nil -} +func (r *directRouter) CleanUp() {} diff --git a/neo4j/driver_with_context.go b/neo4j/driver_with_context.go index c95974d5..0c941875 100644 --- a/neo4j/driver_with_context.go +++ b/neo4j/driver_with_context.go @@ -254,6 +254,8 @@ func NewDriverWithContext(target string, 
auth auth.TokenManager, configurers ... d.router = router.New(address, routersResolver, routingContext, d.pool, d.log, d.logId, &d.now) } + d.pool.SetRouter(d.router) + d.log.Infof(log.Driver, d.logId, "Created { target: %s }", address) return &d, nil } @@ -300,20 +302,21 @@ type sessionRouter interface { // they should not be called when it is not needed (e.g. when a routing table is cached) GetOrUpdateReaders(ctx context.Context, bookmarks func(context.Context) ([]string, error), database string, auth *idb.ReAuthToken, boltLogger log.BoltLogger) ([]string, error) // Readers returns the list of servers that can serve reads on the requested database. - Readers(ctx context.Context, database string) ([]string, error) + Readers(database string) []string // GetOrUpdateWriters returns the list of servers that can serve writes on the requested database. // note: bookmarks are lazily supplied, see Readers documentation to learn why GetOrUpdateWriters(ctx context.Context, bookmarks func(context.Context) ([]string, error), database string, auth *idb.ReAuthToken, boltLogger log.BoltLogger) ([]string, error) // Writers returns the list of servers that can serve writes on the requested database. - Writers(ctx context.Context, database string) ([]string, error) + Writers(database string) []string // GetNameOfDefaultDatabase returns the name of the default database for the specified user. // The correct database name is needed when requesting readers or writers. 
// the bookmarks are eagerly provided since this method always fetches a new routing table GetNameOfDefaultDatabase(ctx context.Context, bookmarks []string, user string, auth *idb.ReAuthToken, boltLogger log.BoltLogger) (string, error) - Invalidate(ctx context.Context, database string) error - CleanUp(ctx context.Context) error - InvalidateWriter(ctx context.Context, name string, server string) error - InvalidateReader(ctx context.Context, name string, server string) error + Invalidate(db string) + CleanUp() + InvalidateWriter(db string, server string) + InvalidateReader(db string, server string) + InvalidateServer(server string) } type driverWithContext struct { @@ -394,9 +397,7 @@ func (d *driverWithContext) Close(ctx context.Context) error { defer d.mut.Unlock() // Safeguard against closing more than once if d.pool != nil { - if err := d.pool.Close(ctx); err != nil { - return err - } + d.pool.Close(ctx) d.pool = nil d.log.Infof(log.Driver, d.logId, "Closed") } diff --git a/neo4j/driver_with_context_testkit.go b/neo4j/driver_with_context_testkit.go index f30c08bb..94fb1c98 100644 --- a/neo4j/driver_with_context_testkit.go +++ b/neo4j/driver_with_context_testkit.go @@ -1,5 +1,3 @@ -//go:build internal_testkit - /* * Copyright (c) "Neo4j" * Neo4j Sweden AB [https://neo4j.com] @@ -10,18 +8,30 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ +//go:build internal_testkit + package neo4j -import "time" +import ( + "context" + "fmt" + idb "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/db" + "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/errorutil" + "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/router" + "github.com/neo4j/neo4j-go-driver/v5/neo4j/log" + "time" +) + +type RoutingTable = idb.RoutingTable func SetTimer(d DriverWithContext, timer func() time.Time) { driver := d.(*driverWithContext) @@ -32,3 +42,33 @@ func ResetTime(d DriverWithContext) { driver := d.(*driverWithContext) driver.now = time.Now } + +func ForceRoutingTableUpdate(d DriverWithContext, database string, bookmarks []string, logger log.BoltLogger) error { + driver := d.(*driverWithContext) + ctx := context.Background() + driver.router.Invalidate(database) + getBookmarks := func(context.Context) ([]string, error) { + return bookmarks, nil + } + auth := &idb.ReAuthToken{ + Manager: driver.auth, + FromSession: false, + ForceReAuth: false, + } + _, err := driver.router.GetOrUpdateReaders(ctx, getBookmarks, database, auth, logger) + if err != nil { + return errorutil.WrapError(err) + } + _, err = driver.router.GetOrUpdateWriters(ctx, getBookmarks, database, auth, logger) + return errorutil.WrapError(err) +} + +func GetRoutingTable(d DriverWithContext, database string) (*RoutingTable, error) { + driver := d.(*driverWithContext) + router, ok := driver.router.(*router.Router) + if !ok { + return nil, fmt.Errorf("GetRoutingTable is only supported for direct drivers") + } + table := router.GetTable(database) + return table, nil +} diff --git a/neo4j/internal/bolt/bolt3.go b/neo4j/internal/bolt/bolt3.go 
index 41ca4101..33ae9558 100644 --- a/neo4j/internal/bolt/bolt3.go +++ b/neo4j/internal/bolt/bolt3.go @@ -95,14 +95,14 @@ type bolt3 struct { auth map[string]any authManager auth.TokenManager resetAuth bool - onNeo4jError Neo4jErrorCallback + errorListener ConnectionErrorListener now *func() time.Time } func NewBolt3( serverName string, conn net.Conn, - callback Neo4jErrorCallback, + errorListener ConnectionErrorListener, timer *func() time.Time, logger log.Logger, boltLog log.BoltLogger, @@ -120,16 +120,22 @@ func NewBolt3( }, connReadTimeout: -1, }, - birthDate: now, - idleDate: now, - log: logger, - onNeo4jError: callback, - now: timer, + birthDate: now, + idleDate: now, + log: logger, + errorListener: errorListener, + now: timer, } b.out = &outgoing{ chunker: newChunker(), packer: packstream.Packer{}, - onErr: func(err error) { + onPackErr: func(err error) { + if b.err == nil { + b.err = err + } + b.state = bolt3_dead + }, + onIoErr: func(ctx context.Context, err error) { if b.err == nil { b.err = err } @@ -181,7 +187,7 @@ func (b *bolt3) receiveSuccess(ctx context.Context) *success { } else { b.log.Error(log.Bolt3, b.logId, message) } - if err := b.onNeo4jError(ctx, b, message); err != nil { + if err := b.errorListener.OnNeo4jError(ctx, b, message); err != nil { b.err = errorutil.CombineErrors(message, b.err) } return nil @@ -662,7 +668,7 @@ func (b *bolt3) receiveNext(ctx context.Context) (*db.Record, *db.Summary, error } else { b.log.Error(log.Bolt3, b.logId, message) } - if err := b.onNeo4jError(ctx, b, message); err != nil { + if err := b.errorListener.OnNeo4jError(ctx, b, message); err != nil { return nil, nil, errorutil.CombineErrors(message, err) } return nil, nil, message diff --git a/neo4j/internal/bolt/bolt3_test.go b/neo4j/internal/bolt/bolt3_test.go index ff494892..82f39199 100644 --- a/neo4j/internal/bolt/bolt3_test.go +++ b/neo4j/internal/bolt/bolt3_test.go @@ -114,7 +114,7 @@ func TestBolt3(outer *testing.T) { auth, "007", nil, - 
noopOnNeo4jError, + noopErrorListener{}, logger, nil, idb.NotificationConfig{}, @@ -168,7 +168,7 @@ func TestBolt3(outer *testing.T) { auth, "007", nil, - noopOnNeo4jError, + noopErrorListener{}, logger, nil, idb.NotificationConfig{}, diff --git a/neo4j/internal/bolt/bolt4.go b/neo4j/internal/bolt/bolt4.go index 9532e647..e28d6168 100644 --- a/neo4j/internal/bolt/bolt4.go +++ b/neo4j/internal/bolt/bolt4.go @@ -111,30 +111,30 @@ type bolt4 struct { auth map[string]any authManager auth.TokenManager resetAuth bool - onNeo4jError Neo4jErrorCallback + errorListener ConnectionErrorListener now *func() time.Time } func NewBolt4( serverName string, conn net.Conn, - callback Neo4jErrorCallback, + errorListener ConnectionErrorListener, timer *func() time.Time, logger log.Logger, boltLog log.BoltLogger, ) *bolt4 { now := (*timer)() b := &bolt4{ - state: bolt4_unauthorized, - conn: conn, - serverName: serverName, - birthDate: now, - idleDate: now, - log: logger, - streams: openstreams{}, - lastQid: -1, - onNeo4jError: callback, - now: timer, + state: bolt4_unauthorized, + conn: conn, + serverName: serverName, + birthDate: now, + idleDate: now, + log: logger, + streams: openstreams{}, + lastQid: -1, + errorListener: errorListener, + now: timer, } b.queue = newMessageQueue( conn, @@ -149,11 +149,12 @@ func NewBolt4( &outgoing{ chunker: newChunker(), packer: packstream.Packer{}, - onErr: func(err error) { b.setError(err, true) }, + onPackErr: func(err error) { b.setError(err, true) }, + onIoErr: b.onIoError, boltLogger: boltLog, }, b.onNextMessage, - b.onNextMessageError, + b.onIoError, ) return b @@ -938,6 +939,10 @@ func (b *bolt4) SelectDatabase(database string) { b.databaseName = database } +func (b *bolt4) Database() string { + return b.databaseName +} + func (b *bolt4) SetBoltLogger(boltLogger log.BoltLogger) { b.queue.setBoltLogger(boltLogger) } @@ -1079,7 +1084,7 @@ func (b *bolt4) resetResponseHandler() responseHandler { b.state = bolt4_ready }, onFailure: func(ctx 
context.Context, failure *db.Neo4jError) { - _ = b.onNeo4jError(ctx, b, failure) + _ = b.errorListener.OnNeo4jError(ctx, b, failure) b.state = bolt4_dead }, } @@ -1126,19 +1131,24 @@ func (b *bolt4) onNextMessage() { b.idleDate = (*b.now)() } -func (b *bolt4) onNextMessageError(err error) { - b.setError(err, true) -} - func (b *bolt4) onFailure(ctx context.Context, failure *db.Neo4jError) { var err error err = failure - if callbackErr := b.onNeo4jError(ctx, b, failure); callbackErr != nil { + if callbackErr := b.errorListener.OnNeo4jError(ctx, b, failure); callbackErr != nil { err = errorutil.CombineErrors(failure, callbackErr) } b.setError(err, isFatalError(failure)) } +func (b *bolt4) onIoError(ctx context.Context, err error) { + if b.state != bolt4_failed && b.state != bolt4_dead { + // Don't call callback when connections break after sending RESET. + // The server chooses to close the connection on some errors. + b.errorListener.OnIoError(ctx, b, err) + } + b.setError(err, true) +} + const readTimeoutHintName = "connection.recv_timeout_seconds" func (b *bolt4) initializeReadTimeoutHint(hints map[string]any) { diff --git a/neo4j/internal/bolt/bolt4_test.go b/neo4j/internal/bolt/bolt4_test.go index 8c740154..1dc55426 100644 --- a/neo4j/internal/bolt/bolt4_test.go +++ b/neo4j/internal/bolt/bolt4_test.go @@ -117,7 +117,7 @@ func TestBolt4(outer *testing.T) { auth, "007", nil, - noopOnNeo4jError, + noopErrorListener{}, logger, nil, idb.NotificationConfig{}, @@ -322,7 +322,7 @@ func TestBolt4(outer *testing.T) { auth, "007", nil, - noopOnNeo4jError, + noopErrorListener{}, logger, nil, idb.NotificationConfig{}, diff --git a/neo4j/internal/bolt/bolt5.go b/neo4j/internal/bolt/bolt5.go index ad69aff8..a658caca 100644 --- a/neo4j/internal/bolt/bolt5.go +++ b/neo4j/internal/bolt/bolt5.go @@ -113,30 +113,30 @@ type bolt5 struct { auth map[string]any authManager auth.TokenManager resetAuth bool - onNeo4jError Neo4jErrorCallback + errorListener ConnectionErrorListener now 
*func() time.Time } func NewBolt5( serverName string, conn net.Conn, - callback Neo4jErrorCallback, + errorListener ConnectionErrorListener, timer *func() time.Time, logger log.Logger, boltLog log.BoltLogger, ) *bolt5 { now := (*timer)() b := &bolt5{ - state: bolt5Unauthorized, - conn: conn, - serverName: serverName, - birthDate: now, - idleDate: now, - log: logger, - streams: openstreams{}, - lastQid: -1, - onNeo4jError: callback, - now: timer, + state: bolt5Unauthorized, + conn: conn, + serverName: serverName, + birthDate: now, + idleDate: now, + log: logger, + streams: openstreams{}, + lastQid: -1, + errorListener: errorListener, + now: timer, } b.queue = newMessageQueue( conn, @@ -152,12 +152,13 @@ func NewBolt5( &outgoing{ chunker: newChunker(), packer: packstream.Packer{}, - onErr: func(err error) { b.setError(err, true) }, + onPackErr: func(err error) { b.setError(err, true) }, + onIoErr: b.onIoError, boltLogger: boltLog, useUtc: true, }, b.onNextMessage, - b.onNextMessageError, + b.onIoError, ) return b } @@ -921,19 +922,23 @@ func (b *bolt5) reAuth(ctx context.Context, auth *idb.ReAuthToken) error { func (b *bolt5) Close(ctx context.Context) { b.log.Infof(log.Bolt5, b.logId, "Close") if b.state != bolt5Dead { + b.state = bolt5Dead b.queue.appendGoodbye() b.queue.send(ctx) } if err := b.conn.Close(); err != nil { b.log.Warnf(log.Driver, b.serverName, "could not close underlying socket") } - b.state = bolt5Dead } func (b *bolt5) SelectDatabase(database string) { b.databaseName = database } +func (b *bolt5) Database() string { + return b.databaseName +} + func (b *bolt5) Version() db.ProtocolVersion { return db.ProtocolVersion{ Major: 5, @@ -1077,7 +1082,7 @@ func (b *bolt5) resetResponseHandler() responseHandler { b.state = bolt5Ready }, onFailure: func(ctx context.Context, failure *db.Neo4jError) { - _ = b.onNeo4jError(ctx, b, failure) + _ = b.errorListener.OnNeo4jError(ctx, b, failure) b.state = bolt5Dead }, } @@ -1111,19 +1116,24 @@ func (b *bolt5) 
onNextMessage() { b.idleDate = (*b.now)() } -func (b *bolt5) onNextMessageError(err error) { - b.setError(err, true) -} - func (b *bolt5) onFailure(ctx context.Context, failure *db.Neo4jError) { var err error err = failure - if callbackErr := b.onNeo4jError(ctx, b, failure); callbackErr != nil { + if callbackErr := b.errorListener.OnNeo4jError(ctx, b, failure); callbackErr != nil { err = errorutil.CombineErrors(callbackErr, failure) } b.setError(err, isFatalError(failure)) } +func (b *bolt5) onIoError(ctx context.Context, err error) { + if b.state != bolt5Failed && b.state != bolt5Dead { + // Don't call callback when connections break after sending RESET. + // The server chooses to close the connection on some errors. + b.errorListener.OnIoError(ctx, b, err) + } + b.setError(err, true) +} + func (b *bolt5) initializeReadTimeoutHint(hints map[string]any) { readTimeoutHint, ok := hints[readTimeoutHintName] if !ok { diff --git a/neo4j/internal/bolt/bolt5_test.go b/neo4j/internal/bolt/bolt5_test.go index f147e020..ce2a971a 100644 --- a/neo4j/internal/bolt/bolt5_test.go +++ b/neo4j/internal/bolt/bolt5_test.go @@ -118,7 +118,7 @@ func TestBolt5(outer *testing.T) { auth, "007", nil, - noopOnNeo4jError, + noopErrorListener{}, logger, nil, idb.NotificationConfig{}, @@ -414,7 +414,7 @@ func TestBolt5(outer *testing.T) { auth, "007", nil, - noopOnNeo4jError, + noopErrorListener{}, logger, nil, idb.NotificationConfig{}, @@ -451,7 +451,7 @@ func TestBolt5(outer *testing.T) { auth, "007", nil, - noopOnNeo4jError, + noopErrorListener{}, logger, nil, idb.NotificationConfig{}, diff --git a/neo4j/internal/bolt/bolt_test.go b/neo4j/internal/bolt/bolt_test.go index 56de639c..e8b79908 100644 --- a/neo4j/internal/bolt/bolt_test.go +++ b/neo4j/internal/bolt/bolt_test.go @@ -25,6 +25,12 @@ import ( idb "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/db" ) -func noopOnNeo4jError(context.Context, idb.Connection, *db.Neo4jError) error { +type noopErrorListener struct{} + +func (n 
noopErrorListener) OnNeo4jError(context.Context, idb.Connection, *db.Neo4jError) error { return nil } + +func (n noopErrorListener) OnIoError(context.Context, idb.Connection, error) {} + +func (n noopErrorListener) OnDialError(context.Context, string, error) {} diff --git a/neo4j/internal/bolt/connect.go b/neo4j/internal/bolt/connect.go index e1b08b5f..17bb6397 100644 --- a/neo4j/internal/bolt/connect.go +++ b/neo4j/internal/bolt/connect.go @@ -54,7 +54,7 @@ func Connect(ctx context.Context, auth *db.ReAuthToken, userAgent string, routingContext map[string]string, - callback Neo4jErrorCallback, + errorListener ConnectionErrorListener, logger log.Logger, boltLogger log.BoltLogger, notificationConfig db.NotificationConfig, @@ -74,6 +74,7 @@ func Connect(ctx context.Context, } _, err := racing.NewRacingWriter(conn).Write(ctx, handshake) if err != nil { + errorListener.OnDialError(ctx, serverName, err) return nil, err } @@ -81,6 +82,7 @@ func Connect(ctx context.Context, buf := make([]byte, 4) _, err = racing.NewRacingReader(conn).ReadFull(ctx, buf) if err != nil { + errorListener.OnDialError(ctx, serverName, err) return nil, err } @@ -93,11 +95,11 @@ func Connect(ctx context.Context, var boltConn db.Connection switch major { case 3: - boltConn = NewBolt3(serverName, conn, callback, timer, logger, boltLogger) + boltConn = NewBolt3(serverName, conn, errorListener, timer, logger, boltLogger) case 4: - boltConn = NewBolt4(serverName, conn, callback, timer, logger, boltLogger) + boltConn = NewBolt4(serverName, conn, errorListener, timer, logger, boltLogger) case 5: - boltConn = NewBolt5(serverName, conn, callback, timer, logger, boltLogger) + boltConn = NewBolt5(serverName, conn, errorListener, timer, logger, boltLogger) case 0: return nil, fmt.Errorf("server did not accept any of the requested Bolt versions (%#v)", versions) default: diff --git a/neo4j/internal/bolt/connections.go b/neo4j/internal/bolt/connections.go index 0d3caf07..16218409 100644 --- 
a/neo4j/internal/bolt/connections.go +++ b/neo4j/internal/bolt/connections.go @@ -8,13 +8,13 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ package bolt @@ -27,7 +27,11 @@ import ( "net" ) -type Neo4jErrorCallback func(context.Context, idb.Connection, *db.Neo4jError) error +type ConnectionErrorListener interface { + OnNeo4jError(context.Context, idb.Connection, *db.Neo4jError) error + OnIoError(context.Context, idb.Connection, error) + OnDialError(context.Context, string, error) +} func handleTerminatedContextError(err error, connection net.Conn) error { if !contextTerminatedErr(err) { diff --git a/neo4j/internal/bolt/hydratedehydrate_test.go b/neo4j/internal/bolt/hydratedehydrate_test.go index ad735a04..9db79ff5 100644 --- a/neo4j/internal/bolt/hydratedehydrate_test.go +++ b/neo4j/internal/bolt/hydratedehydrate_test.go @@ -8,13 +8,13 @@ * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ package bolt @@ -39,9 +39,12 @@ func TestDehydrateHydrate(ot *testing.T) { out := &outgoing{ chunker: newChunker(), packer: packstream.Packer{}, - onErr: func(err error) { + onPackErr: func(err error) { ot.Fatalf("Should be no dehydration errors in this test: %s", err) }, + onIoErr: func(_ context.Context, err error) { + ot.Fatalf("Should be no io errors in this test: %s", err) + }, } serv, cli := net.Pipe() defer func() { diff --git a/neo4j/internal/bolt/message_queue.go b/neo4j/internal/bolt/message_queue.go index 4e52f5f8..de8e4906 100644 --- a/neo4j/internal/bolt/message_queue.go +++ b/neo4j/internal/bolt/message_queue.go @@ -8,13 +8,13 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ package bolt @@ -36,23 +36,23 @@ type messageQueue struct { targetConnection net.Conn err error - onNextMessage func() - onNextMessageErr func(error) + onNextMessage func() + onIoErr func(context.Context, error) } func newMessageQueue( target net.Conn, in *incoming, out *outgoing, onNext func(), - onNextErr func(error)) messageQueue { - + onIoErr func(context.Context, error), +) messageQueue { return messageQueue{ in: in, out: out, handlers: list.List{}, targetConnection: target, onNextMessage: onNext, - onNextMessageErr: onNextErr, + onIoErr: onIoErr, } } @@ -205,7 +205,7 @@ func (q *messageQueue) receiveMsg(ctx context.Context) any { msg, err := q.in.next(ctx, q.targetConnection) q.err = err if err != nil { - q.onNextMessageErr(err) + q.onIoErr(ctx, err) } else { q.onNextMessage() } diff --git a/neo4j/internal/bolt/outgoing.go b/neo4j/internal/bolt/outgoing.go index eab6be0d..771467c7 100644 --- a/neo4j/internal/bolt/outgoing.go +++ b/neo4j/internal/bolt/outgoing.go @@ -8,13 +8,13 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ package bolt @@ -35,7 +35,8 @@ import ( type outgoing struct { chunker chunker packer packstream.Packer - onErr func(err error) + onPackErr func(error) + onIoErr func(context.Context, error) boltLogger log.BoltLogger logId string useUtc bool @@ -51,7 +52,7 @@ func (o *outgoing) end() { o.chunker.buf = buf o.chunker.endMessage() if err != nil { - o.onErr(err) + o.onPackErr(err) } } @@ -259,7 +260,7 @@ func (o *outgoing) appendX(tag byte, fields ...any) { func (o *outgoing) send(ctx context.Context, wr io.Writer) { err := o.chunker.send(ctx, wr) if err != nil { - o.onErr(err) + o.onIoErr(ctx, err) } } @@ -347,7 +348,7 @@ func (o *outgoing) packStruct(x any) { o.packer.Int64(v.Seconds) o.packer.Int(v.Nanos) default: - o.onErr(&db.UnsupportedTypeError{Type: reflect.TypeOf(x)}) + o.onPackErr(&db.UnsupportedTypeError{Type: reflect.TypeOf(x)}) } } @@ -416,7 +417,7 @@ func (o *outgoing) packX(x any) { default: t := reflect.TypeOf(x) if t.Key().Kind() != reflect.String { - o.onErr(&db.UnsupportedTypeError{Type: reflect.TypeOf(x)}) + o.onPackErr(&db.UnsupportedTypeError{Type: reflect.TypeOf(x)}) return } o.packer.MapHeader(v.Len()) @@ -427,7 +428,7 @@ func (o *outgoing) packX(x any) { } } default: - o.onErr(&db.UnsupportedTypeError{Type: reflect.TypeOf(x)}) + o.onPackErr(&db.UnsupportedTypeError{Type: reflect.TypeOf(x)}) } } diff --git a/neo4j/internal/bolt/outgoing_test.go b/neo4j/internal/bolt/outgoing_test.go index 2eb5bc6f..af6fe096 100644 --- a/neo4j/internal/bolt/outgoing_test.go +++ b/neo4j/internal/bolt/outgoing_test.go @@ -8,13 +8,13 @@ * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ package bolt @@ -91,9 +91,12 @@ func TestOutgoing(ot *testing.T) { // Utility to unpack through dechunking and a custom build func dechunkAndUnpack := func(t *testing.T, build func(*testing.T, *outgoing)) any { out := &outgoing{ - chunker: newChunker(), - packer: packstream.Packer{}, - onErr: func(e error) { err = e }, + chunker: newChunker(), + packer: packstream.Packer{}, + onPackErr: func(e error) { err = e }, + onIoErr: func(_ context.Context, err error) { + ot.Fatalf("Should be no io errors in this test: %s", err) + }, } serv, cli := net.Pipe() defer func() { @@ -584,9 +587,12 @@ func TestOutgoing(ot *testing.T) { for _, c := range paramErrorCases { var err error out := &outgoing{ - chunker: newChunker(), - packer: packstream.Packer{}, - onErr: func(e error) { err = e }, + chunker: newChunker(), + packer: packstream.Packer{}, + onPackErr: func(e error) { err = e }, + onIoErr: func(_ context.Context, err error) { + ot.Fatalf("Should be no io errors in this test: %s", err) + }, } ot.Run(c.name, func(t *testing.T) { out.begin() diff --git a/neo4j/internal/connector/connector.go b/neo4j/internal/connector/connector.go index e89186d1..becb01d3 100644 --- 
a/neo4j/internal/connector/connector.go +++ b/neo4j/internal/connector/connector.go @@ -50,7 +50,7 @@ func (c Connector) Connect( ctx context.Context, address string, auth *db.ReAuthToken, - callback bolt.Neo4jErrorCallback, + errorListener bolt.ConnectionErrorListener, boltLogger log.BoltLogger, ) (connection db.Connection, err error) { if c.SupplyConnection == nil { @@ -59,13 +59,14 @@ func (c Connector) Connect( conn, err := c.SupplyConnection(ctx, address) if err != nil { + errorListener.OnDialError(ctx, address, err) return nil, err } defer func() { if err != nil && connection == nil { if err := conn.Close(); err != nil { - c.Log.Warnf(log.Driver, address, "could not close socket after failed connection") + c.Log.Warnf(log.Driver, address, "could not close socket after failed connection %s", err) } } }() @@ -84,7 +85,7 @@ func (c Connector) Connect( auth, c.Config.UserAgent, c.RoutingContext, - callback, + errorListener, c.Log, boltLogger, notificationConfig, @@ -99,6 +100,7 @@ func (c Connector) Connect( // TLS requested, continue with handshake serverName, _, err := net.SplitHostPort(address) if err != nil { + errorListener.OnDialError(ctx, address, err) return nil, err } tlsConn := tls.Client(conn, c.tlsConfig(serverName)) @@ -108,7 +110,9 @@ func (c Connector) Connect( // Give a bit nicer error message err = errors.New("remote end closed the connection, check that TLS is enabled on the server") } - return nil, &errorutil.TlsError{Inner: err} + err = &errorutil.TlsError{Inner: err} + errorListener.OnDialError(ctx, address, err) + return nil, err } connection, err = bolt.Connect(ctx, address, @@ -116,7 +120,7 @@ func (c Connector) Connect( auth, c.Config.UserAgent, c.RoutingContext, - callback, + errorListener, c.Log, boltLogger, notificationConfig, diff --git a/neo4j/internal/connector/connector_test.go b/neo4j/internal/connector/connector_test.go index 6b8ff0b1..633a68e6 100644 --- a/neo4j/internal/connector/connector_test.go +++ 
b/neo4j/internal/connector/connector_test.go @@ -22,7 +22,9 @@ package connector_test import ( "context" "github.com/neo4j/neo4j-go-driver/v5/neo4j/config" + "github.com/neo4j/neo4j-go-driver/v5/neo4j/db" "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/connector" + idb "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/db" . "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/testutil" "io" "net" @@ -30,6 +32,16 @@ import ( "time" ) +type noopErrorListener struct{} + +func (n noopErrorListener) OnNeo4jError(context.Context, idb.Connection, *db.Neo4jError) error { + return nil +} + +func (n noopErrorListener) OnIoError(context.Context, idb.Connection, error) {} + +func (n noopErrorListener) OnDialError(context.Context, string, error) {} + func TestConnect(outer *testing.T) { outer.Parallel() @@ -49,7 +61,7 @@ func TestConnect(outer *testing.T) { Now: &timer, } - connection, err := connector.Connect(ctx, "irrelevant", nil, nil, nil) + connection, err := connector.Connect(ctx, "irrelevant", nil, noopErrorListener{}, nil) AssertNil(t, connection) AssertErrorMessageContains(t, err, "unsupported version 1.0") @@ -70,7 +82,7 @@ func TestConnect(outer *testing.T) { Now: &timer, } - connection, err := connector.Connect(ctx, "irrelevant", nil, nil, nil) + connection, err := connector.Connect(ctx, "irrelevant", nil, noopErrorListener{}, nil) AssertNil(t, connection) AssertError(t, err) diff --git a/neo4j/internal/db/connection.go b/neo4j/internal/db/connection.go index 203ad962..6a6556b8 100644 --- a/neo4j/internal/db/connection.go +++ b/neo4j/internal/db/connection.go @@ -184,4 +184,5 @@ type DatabaseSelector interface { // SelectDatabase should be called immediately after Reset. Not allowed to call multiple times with different // databases without a reset in-between. 
SelectDatabase(database string) + Database() string } diff --git a/neo4j/internal/pool/no_test.go b/neo4j/internal/pool/no_test.go index 28e2072b..25741ff7 100644 --- a/neo4j/internal/pool/no_test.go +++ b/neo4j/internal/pool/no_test.go @@ -8,19 +8,18 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
*/ package pool import ( - "context" "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/db" "testing" ) @@ -44,24 +43,18 @@ func assertNoConnection(t *testing.T, c db.Connection, err error) { } } -func assertNumberOfServers(t *testing.T, ctx context.Context, p *Pool, expectedNum int) { +func assertNumberOfServers(t *testing.T, p *Pool, expectedNum int) { t.Helper() - servers, err := p.getServers(ctx) - if err != nil { - t.Fatalf("Expected nil error, got: %v", err) - } + servers := p.getServers() actualNum := len(servers) if actualNum != expectedNum { t.Fatalf("Expected number of servers to be %d but was %d", expectedNum, actualNum) } } -func assertNumberOfIdle(t *testing.T, ctx context.Context, p *Pool, serverName string, expectedNum int) { +func assertNumberOfIdle(t *testing.T, p *Pool, serverName string, expectedNum int) { t.Helper() - servers, err := p.getServers(ctx) - if err != nil { - t.Fatalf("Expected nil error, got: %v", err) - } + servers := p.getServers() server := servers[serverName] if server == nil { t.Fatalf("Server %s not found", serverName) diff --git a/neo4j/internal/pool/pool.go b/neo4j/internal/pool/pool.go index b2a8a069..163b220e 100644 --- a/neo4j/internal/pool/pool.go +++ b/neo4j/internal/pool/pool.go @@ -25,14 +25,12 @@ package pool import ( "container/list" "context" - "fmt" "github.com/neo4j/neo4j-go-driver/v5/neo4j/config" "github.com/neo4j/neo4j-go-driver/v5/neo4j/db" "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/auth" "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/bolt" idb "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/db" "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/errorutil" - "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/racing" "math" "sort" "sync" @@ -45,7 +43,13 @@ import ( // Liveness checks are performed before a connection is deemed idle enough to be reset const DefaultLivenessCheckThreshold = math.MaxInt64 -type Connect func(context.Context, string, *idb.ReAuthToken, bolt.Neo4jErrorCallback, 
log.BoltLogger) (idb.Connection, error) +type Connect func(context.Context, string, *idb.ReAuthToken, bolt.ConnectionErrorListener, log.BoltLogger) (idb.Connection, error) + +type poolRouter interface { + InvalidateWriter(db string, server string) + InvalidateReader(db string, server string) + InvalidateServer(server string) +} type qitem struct { wakeup chan bool @@ -54,9 +58,10 @@ type qitem struct { type Pool struct { config *config.Config connect Connect + router poolRouter servers map[string]*server - serversMut racing.Mutex - queueMut racing.Mutex + serversMut sync.Mutex + queueMut sync.Mutex queue list.List now *func() time.Time closed bool @@ -75,9 +80,10 @@ func New(config *config.Config, connect Connect, logger log.Logger, logId string p := &Pool{ config: config, connect: connect, + router: nil, servers: make(map[string]*server), - serversMut: racing.NewMutex(), - queueMut: racing.NewMutex(), + serversMut: sync.Mutex{}, + queueMut: sync.Mutex{}, now: now, logId: logId, log: logger, @@ -86,11 +92,13 @@ func New(config *config.Config, connect Connect, logger log.Logger, logId string return p } -func (p *Pool) Close(ctx context.Context) error { +func (p *Pool) SetRouter(router poolRouter) { + p.router = router +} + +func (p *Pool) Close(ctx context.Context) { p.closed = true - if !p.queueMut.TryLock(ctx) { - return racing.LockTimeoutError("could not acquire queue lock in time when closing pool") - } + p.queueMut.Lock() for e := p.queue.Front(); e != nil; e = e.Next() { queuedRequest := e.Value.(*qitem) p.queue.Remove(e) @@ -98,48 +106,39 @@ func (p *Pool) Close(ctx context.Context) error { } p.queueMut.Unlock() // Go through each server and close all connections to it - if !p.serversMut.TryLock(ctx) { - return racing.LockTimeoutError("could not acquire server lock in time when closing pool") - } + p.serversMut.Lock() for n, s := range p.servers { s.closeAll(ctx) delete(p.servers, n) } p.serversMut.Unlock() p.log.Infof(log.Pool, p.logId, "Closed") - return 
nil } // For testing -func (p *Pool) queueSize(ctx context.Context) (int, error) { - if !p.queueMut.TryLock(ctx) { - return -1, fmt.Errorf("could not acquire queue lock in time when checking queue size") - } +func (p *Pool) queueSize() int { + p.queueMut.Lock() defer p.queueMut.Unlock() - return p.queue.Len(), nil + return p.queue.Len() } // For testing -func (p *Pool) getServers(ctx context.Context) (map[string]*server, error) { - if !p.serversMut.TryLock(ctx) { - return nil, fmt.Errorf("could not acquire server lock in time when getting servers") - } +func (p *Pool) getServers() map[string]*server { + p.serversMut.Lock() defer p.serversMut.Unlock() servers := make(map[string]*server) for k, v := range p.servers { servers[k] = v } - return servers, nil + return servers } // CleanUp prunes all old connection on all the servers, this makes sure that servers // gets removed from the map at some point in time. If there is a noticed // failed connect still active we should wait a while with removal to get // prioritization right. 
-func (p *Pool) CleanUp(ctx context.Context) error { - if !p.serversMut.TryLock(ctx) { - return fmt.Errorf("could not acquire server lock in time when cleaning up pool") - } +func (p *Pool) CleanUp(ctx context.Context) { + p.serversMut.Lock() defer p.serversMut.Unlock() now := (*p.now)() for n, s := range p.servers { @@ -148,17 +147,14 @@ func (p *Pool) CleanUp(ctx context.Context) error { delete(p.servers, n) } } - return nil } func (p *Pool) Now() time.Time { return (*p.now)() } -func (p *Pool) getPenaltiesForServers(ctx context.Context, serverNames []string) ([]serverPenalty, error) { - if !p.serversMut.TryLock(ctx) { - return nil, fmt.Errorf("could not acquire server lock in time when computing server penalties") - } +func (p *Pool) getPenaltiesForServers(ctx context.Context, serverNames []string) []serverPenalty { + p.serversMut.Lock() defer p.serversMut.Unlock() // Retrieve penalty for each server @@ -175,13 +171,11 @@ func (p *Pool) getPenaltiesForServers(ctx context.Context, serverNames []string) penalties[i].penalty = newConnectionPenalty } } - return penalties, nil + return penalties } func (p *Pool) tryAnyIdle(ctx context.Context, serverNames []string, idlenessThreshold time.Duration, auth *idb.ReAuthToken, logger log.BoltLogger) (idb.Connection, error) { - if !p.serversMut.TryLock(ctx) { - return nil, racing.LockTimeoutError("could not acquire server lock in time when getting idle connection") - } + p.serversMut.Lock() var unlock = new(sync.Once) defer unlock.Do(p.serversMut.Unlock) serverLoop: @@ -198,16 +192,12 @@ serverLoop: if healthy { return conn, nil } - if err := p.unreg(context.Background(), serverName, conn, p.Now()); err != nil { - panic("lock with Background context should never time out") - } + p.unreg(ctx, serverName, conn, p.Now()) if err != nil { p.log.Debugf(log.Pool, p.logId, "Health check failed for %s: %s", serverName, err) return nil, err } - if !p.serversMut.TryLock(ctx) { - return nil, racing.LockTimeoutError("could not acquire 
lock in time when borrowing a connection") - } + p.serversMut.Lock() *unlock = sync.Once{} } } @@ -215,29 +205,25 @@ serverLoop: return nil, nil } -func (p *Pool) Borrow(ctx context.Context, getServerNames func(context.Context) ([]string, error), wait bool, boltLogger log.BoltLogger, idlenessThreshold time.Duration, auth *idb.ReAuthToken) (idb.Connection, error) { +func (p *Pool) Borrow(ctx context.Context, getServerNames func() []string, wait bool, boltLogger log.BoltLogger, idlenessThreshold time.Duration, auth *idb.ReAuthToken) (idb.Connection, error) { for { if p.closed { return nil, &errorutil.PoolClosed{} } - serverNames, err := getServerNames(ctx) - if err != nil { - return nil, err - } + serverNames := getServerNames() if len(serverNames) == 0 { return nil, &errorutil.PoolOutOfServers{} } p.log.Debugf(log.Pool, p.logId, "Trying to borrow connection from %s", serverNames) // Retrieve penalty for each server - penalties, err := p.getPenaltiesForServers(ctx, serverNames) - if err != nil { - return nil, err - } + penalties := p.getPenaltiesForServers(ctx, serverNames) // Sort server penalties by lowest penalty sort.Slice(penalties, func(i, j int) bool { return penalties[i].penalty < penalties[j].penalty }) + var err error + var conn idb.Connection for _, s := range penalties { conn, err = p.tryBorrow(ctx, s.name, boltLogger, idlenessThreshold, auth) @@ -265,9 +251,7 @@ func (p *Pool) Borrow(ctx context.Context, getServerNames func(context.Context) } // Wait for a matching connection to be returned from another thread. - if !p.queueMut.TryLock(ctx) { - return nil, racing.LockTimeoutError("could not acquire lock in time when trying to get an idle connection") - } + p.queueMut.Lock() // Ok, now that we own the queue we can add the item there but between getting the lock // and above check for an existing connection another thread might have returned a connection // so check again to avoid potentially starving this thread. 
@@ -294,9 +278,7 @@ func (p *Pool) Borrow(ctx context.Context, getServerNames func(context.Context) case <-q.wakeup: continue case <-ctx.Done(): - if !p.queueMut.TryLock(context.Background()) { - return nil, racing.LockTimeoutError("could not acquire lock in time when removing server wait request") - } + p.queueMut.Lock() p.queue.Remove(e) p.queueMut.Unlock() p.log.Warnf(log.Pool, p.logId, "Borrow time-out") @@ -309,15 +291,14 @@ func (p *Pool) tryBorrow(ctx context.Context, serverName string, boltLogger log. // For now, lock complete servers map to avoid over connecting but with the downside // that long connect times will block connects to other servers as well. To fix this // we would need to add a pending connect to the server and lock per server. - if !p.serversMut.TryLock(ctx) { - return nil, racing.LockTimeoutError("could not acquire lock in time when borrowing a connection") - } + p.serversMut.Lock() var unlock = new(sync.Once) defer unlock.Do(p.serversMut.Unlock) srv := p.servers[serverName] for { if srv != nil { + srv.closing = false connection := srv.getIdle() if connection == nil { if srv.size() >= p.config.MaxConnectionPoolSize { @@ -330,16 +311,12 @@ func (p *Pool) tryBorrow(ctx context.Context, serverName string, boltLogger log. if healthy { return connection, nil } - if err := p.unreg(context.Background(), serverName, connection, p.Now()); err != nil { - panic("lock with Background context should never time out") - } + p.unreg(ctx, serverName, connection, p.Now()) if err != nil { p.log.Debugf(log.Pool, p.logId, "Health check failed for %s: %s", serverName, err) return nil, err } - if !p.serversMut.TryLock(ctx) { - return nil, racing.LockTimeoutError("could not acquire lock in time when borrowing a connection") - } + p.serversMut.Lock() *unlock = sync.Once{} srv = p.servers[serverName] } else { @@ -355,18 +332,16 @@ func (p *Pool) tryBorrow(ctx context.Context, serverName string, boltLogger log. 
// No idle connection, try to connect p.log.Infof(log.Pool, p.logId, "Connecting to %s", serverName) - c, err := p.connect(ctx, serverName, auth, p.OnConnectionError, boltLogger) - if !p.serversMut.TryLock(context.Background()) { - panic("lock with Background context should never time out") - } + c, err := p.connect(ctx, serverName, auth, p, boltLogger) + p.serversMut.Lock() *unlock = sync.Once{} srv.reservations-- if err != nil { + p.log.Warnf(log.Pool, p.logId, "Failed to connect to %s: %s", serverName, err) // FeatureNotSupportedError is not the server fault, don't penalize it if _, ok := err.(*db.FeatureNotSupportedError); !ok { srv.notifyFailedConnect((*p.now)()) } - p.log.Warnf(log.Pool, p.logId, "Failed to connect to %s: %s", serverName, err) return nil, err } @@ -376,16 +351,13 @@ func (p *Pool) tryBorrow(ctx context.Context, serverName string, boltLogger log. return c, nil } -func (p *Pool) unreg(ctx context.Context, serverName string, c idb.Connection, now time.Time) error { - if !p.serversMut.TryLock(ctx) { - return racing.LockTimeoutError("could not acquire server lock in time when unregistering server") - } +func (p *Pool) unreg(ctx context.Context, serverName string, c idb.Connection, now time.Time) { + p.serversMut.Lock() defer p.serversMut.Unlock() - - return p.unregUnlocked(ctx, serverName, c, now) + p.unregUnlocked(ctx, serverName, c, now) } -func (p *Pool) unregUnlocked(ctx context.Context, serverName string, c idb.Connection, now time.Time) error { +func (p *Pool) unregUnlocked(ctx context.Context, serverName string, c idb.Connection, now time.Time) { defer func() { // Close connection in another thread to avoid potential long blocking operation during close. go c.Close(ctx) @@ -395,33 +367,29 @@ func (p *Pool) unregUnlocked(ctx context.Context, serverName string, c idb.Conne // Check for strange condition of not finding the server. 
if server == nil { p.log.Warnf(log.Pool, p.logId, "Server %s not found", serverName) - return nil + return } server.unregisterBusy(c) if server.size() == 0 && !server.hasFailedConnect(now) { delete(p.servers, serverName) } - return nil } -func (p *Pool) removeIdleOlderThanOnServer(ctx context.Context, serverName string, now time.Time, maxAge time.Duration) error { - if !p.serversMut.TryLock(ctx) { - return racing.LockTimeoutError("could not acquire server lock in time before removing old idle connections") - } +func (p *Pool) removeIdleOlderThanOnServer(ctx context.Context, serverName string, now time.Time, maxAge time.Duration) { + p.serversMut.Lock() defer p.serversMut.Unlock() server := p.servers[serverName] if server == nil { - return nil + return } server.removeIdleOlderThan(ctx, now, maxAge) - return nil } -func (p *Pool) Return(ctx context.Context, c idb.Connection) error { +func (p *Pool) Return(ctx context.Context, c idb.Connection) { if p.closed { p.log.Warnf(log.Pool, p.logId, "Trying to return connection to closed pool") - return nil + return } // Get the name of the server that the connection belongs to. @@ -441,9 +409,7 @@ func (p *Pool) Return(ctx context.Context, c idb.Connection) error { maxAge = age } } - if err := p.removeIdleOlderThanOnServer(ctx, serverName, now, maxAge); err != nil { - return err - } + p.removeIdleOlderThanOnServer(ctx, serverName, now, maxAge) // Prepare connection for being used by someone else if is alive. 
// Since reset could find the connection to be in a bad state or non-recoverable state, @@ -457,20 +423,19 @@ func (p *Pool) Return(ctx context.Context, c idb.Connection) error { // Shouldn't return a too old or dead connection back to the pool if !isAlive || age >= p.config.MaxConnectionLifetime { - if err := p.unreg(ctx, serverName, c, now); err != nil { - return err - } + p.unreg(ctx, serverName, c, now) p.log.Infof(log.Pool, p.logId, "Unregistering dead or too old connection to %s", serverName) } if isAlive { // Just put it back in the list of idle connections for this server - if !p.serversMut.TryLock(ctx) { - return racing.LockTimeoutError("could not acquire server lock when putting connection back to idle") - } + p.serversMut.Lock() server := p.servers[serverName] if server != nil { // Strange when server not found - server.returnBusy(c) + server.returnBusy(ctx, c) + if server.closing && server.size() == 0 { + delete(p.servers, serverName) + } } else { p.log.Warnf(log.Pool, p.logId, "Server %s not found", serverName) } @@ -478,44 +443,70 @@ func (p *Pool) Return(ctx context.Context, c idb.Connection) error { } // Check if there is anyone in the queue waiting for a connection to this server. 
- if !p.queueMut.TryLock(ctx) { - return racing.LockTimeoutError("could not acquire queue lock when checking connection requests") - } + p.queueMut.Lock() for e := p.queue.Front(); e != nil; e = e.Next() { queuedRequest := e.Value.(*qitem) p.queue.Remove(e) queuedRequest.wakeup <- true } p.queueMut.Unlock() - - return nil } -func (p *Pool) OnConnectionError(ctx context.Context, connection idb.Connection, error *db.Neo4jError) error { - if error.Code == "Neo.ClientError.Security.AuthorizationExpired" { +func (p *Pool) OnNeo4jError(ctx context.Context, connection idb.Connection, error *db.Neo4jError) error { + switch error.Code { + case "Neo.ClientError.Security.AuthorizationExpired": serverName := connection.ServerName() - if !p.serversMut.TryLock(ctx) { - return racing.LockTimeoutError(fmt.Sprintf( - "could not acquire server lock in time before marking all connection to %s for re-authentication", - serverName)) - } + p.serversMut.Lock() defer p.serversMut.Unlock() server := p.servers[serverName] server.executeForAllConnections(func(c idb.Connection) { c.ResetAuth() }) - } else if error.Code == "Neo.ClientError.Security.TokenExpired" { + case "Neo.ClientError.Security.TokenExpired": manager, token := connection.GetCurrentAuth() if manager != nil { if err := manager.OnTokenExpired(ctx, token); err != nil { return err } - if _, isStaticToken := manager.(auth.Token); isStaticToken { - return nil - } else { + if _, isStaticToken := manager.(auth.Token); !isStaticToken { error.MarkRetriable() } } + case "Neo.TransientError.General.DatabaseUnavailable": + p.deactivate(ctx, connection.ServerName()) + default: + if error.IsRetriableCluster() { + var database string + if dbSelector, ok := connection.(idb.DatabaseSelector); ok { + database = dbSelector.Database() + } + p.deactivateWriter(connection.ServerName(), database) + } } + return nil } + +func (p *Pool) OnIoError(ctx context.Context, connection idb.Connection, _ error) { + p.deactivate(ctx, connection.ServerName()) +} 
+ +func (p *Pool) OnDialError(ctx context.Context, serverName string, _ error) { + p.deactivate(ctx, serverName) +} + +func (p *Pool) deactivate(ctx context.Context, serverName string) { + p.log.Debugf(log.Pool, p.logId, "Deactivating server %s", serverName) + p.router.InvalidateServer(serverName) + p.serversMut.Lock() + defer p.serversMut.Unlock() + server := p.servers[serverName] + if server != nil { + server.startClosing(ctx) + } +} + +func (p *Pool) deactivateWriter(serverName string, db string) { + p.log.Debugf(log.Pool, p.logId, "Deactivating writer %s for database %s", serverName, db) + p.router.InvalidateWriter(db, serverName) +} diff --git a/neo4j/internal/pool/pool_test.go b/neo4j/internal/pool/pool_test.go index c7596ea6..9bea4863 100644 --- a/neo4j/internal/pool/pool_test.go +++ b/neo4j/internal/pool/pool_test.go @@ -23,9 +23,10 @@ import ( "context" "errors" "github.com/neo4j/neo4j-go-driver/v5/neo4j/config" + db "github.com/neo4j/neo4j-go-driver/v5/neo4j/db" iauth "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/auth" "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/bolt" - "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/db" + idb "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/db" "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/errorutil" "math/rand" "sync" @@ -38,18 +39,18 @@ import ( var logger = &log.Void{} var ctx = context.Background() -var reAuthToken = &db.ReAuthToken{FromSession: false, Manager: iauth.Token{Tokens: map[string]any{"scheme": "none"}}} +var reAuthToken = &idb.ReAuthToken{FromSession: false, Manager: iauth.Token{Tokens: map[string]any{"scheme": "none"}}} func TestPoolBorrowReturn(outer *testing.T) { maxAge := 1 * time.Second birthdate := time.Now() - succeedingConnect := func(_ context.Context, s string, _ *db.ReAuthToken, _ bolt.Neo4jErrorCallback, _ log.BoltLogger) (db.Connection, error) { + succeedingConnect := func(_ context.Context, s string, _ *idb.ReAuthToken, _ bolt.ConnectionErrorListener, _ 
log.BoltLogger) (idb.Connection, error) { return &ConnFake{Name: s, Alive: true, Birth: birthdate}, nil } failingError := errors.New("whatever") - failingConnect := func(_ context.Context, s string, _ *db.ReAuthToken, _ bolt.Neo4jErrorCallback, _ log.BoltLogger) (db.Connection, error) { + failingConnect := func(_ context.Context, s string, _ *idb.ReAuthToken, _ bolt.ConnectionErrorListener, _ log.BoltLogger) (idb.Connection, error) { return nil, failingError } @@ -58,19 +59,15 @@ func TestPoolBorrowReturn(outer *testing.T) { conf := config.Config{MaxConnectionLifetime: maxAge, MaxConnectionPoolSize: 1} p := New(&conf, succeedingConnect, logger, "pool id", &timer) defer func() { - if err := p.Close(ctx); err != nil { - t.Errorf("Should not fail closing the pool, but got: %v", err) - } + p.Close(ctx) }() serverNames := []string{"srv1"} conn, err := p.Borrow(ctx, getServers(serverNames), true, nil, DefaultLivenessCheckThreshold, reAuthToken) assertConnection(t, conn, err) - if err := p.Return(ctx, conn); err != nil { - t.Errorf("Should not fail returning connection to pool, but got: %v", err) - } + p.Return(ctx, conn) // Make sure that connection actually returned - servers, err := p.getServers(ctx) + servers := p.getServers() if err != nil { t.Errorf("Should not fail retrieving servers, got: %v", err) } @@ -84,9 +81,7 @@ func TestPoolBorrowReturn(outer *testing.T) { conf := config.Config{MaxConnectionLifetime: maxAge, MaxConnectionPoolSize: 1} p := New(&conf, succeedingConnect, logger, "pool id", &timer) defer func() { - if err := p.Close(ctx); err != nil { - t.Errorf("Should not fail closing the pool, but got: %v", err) - } + p.Close(ctx) }() serverNames := []string{"srv1"} wg := sync.WaitGroup{} @@ -105,12 +100,10 @@ func TestPoolBorrowReturn(outer *testing.T) { wg.Done() }() - waitForBorrowers(t, p, 1) + waitForBorrowers(p, 1) // Give back the connection - if err := p.Return(ctx, c1); err != nil { - t.Errorf("Should not fail returning connection to pool, but got: 
%v", err) - } + p.Return(ctx, c1) wg.Wait() }) @@ -119,9 +112,7 @@ func TestPoolBorrowReturn(outer *testing.T) { conf := config.Config{MaxConnectionLifetime: maxAge, MaxConnectionPoolSize: 1} p := New(&conf, succeedingConnect, logger, "pool id", &timer) defer func() { - if err := p.Close(ctx); err != nil { - t.Errorf("Should not fail closing the pool, but got: %v", err) - } + p.Close(ctx) }() serverNames := []string{"srv1"} @@ -151,9 +142,7 @@ func TestPoolBorrowReturn(outer *testing.T) { c, err := p.Borrow(ctx, getServers(serverNames), true, nil, DefaultLivenessCheckThreshold, reAuthToken) assertConnection(t, c, err) time.Sleep(time.Duration(rand.Int()%7) * time.Millisecond) - if err := p.Return(ctx, c); err != nil { - t.Errorf("Should not fail returning connection to pool, but got: %v", err) - } + p.Return(ctx, c) } wg.Done() } @@ -164,10 +153,7 @@ func TestPoolBorrowReturn(outer *testing.T) { wg.Wait() // Everything should be freed up, it's ok if there isn't a server as well... - servers, err := p.getServers(ctx) - if err != nil { - t.Errorf("Should not fail retrieving server, but got: %v", err) - } + servers := p.getServers() for _, v := range servers { if v.numIdle() != maxConnections { t.Error("A connection is still in use in the server") @@ -179,6 +165,7 @@ func TestPoolBorrowReturn(outer *testing.T) { timer := func() time.Time { return birthdate } conf := config.Config{MaxConnectionLifetime: maxAge, MaxConnectionPoolSize: 2} p := New(&conf, failingConnect, logger, "pool id", &timer) + p.SetRouter(&RouterFake{}) serverNames := []string{"srv1"} c, err := p.Borrow(ctx, getServers(serverNames), true, nil, DefaultLivenessCheckThreshold, reAuthToken) assertNoConnection(t, c, err) @@ -202,12 +189,10 @@ func TestPoolBorrowReturn(outer *testing.T) { wg.Done() }() - waitForBorrowers(t, p, 1) + waitForBorrowers(p, 1) cancel() wg.Wait() - if err := p.Return(ctx, c1); err != nil { - t.Errorf("Should not fail returning connection to pool, but got: %v", err) - } + 
p.Return(ctx, c1) if err == nil { t.Error("There should be an error due to cancelling") } @@ -226,7 +211,7 @@ func TestPoolBorrowReturn(outer *testing.T) { timer := time.Now conf := config.Config{MaxConnectionLifetime: maxAge, MaxConnectionPoolSize: 1} pool := New(&conf, nil, logger, "pool id", &timer) - setIdleConnections(pool, map[string][]db.Connection{"a server": { + setIdleConnections(pool, map[string][]idb.Connection{"a server": { deadAfterReset, stayingAlive, whatATimeToBeAlive, @@ -250,7 +235,7 @@ func TestPoolBorrowReturn(outer *testing.T) { timer := time.Now conf := config.Config{MaxConnectionLifetime: maxAge, MaxConnectionPoolSize: 1} pool := New(&conf, connectTo(healthyConnection), logger, "pool id", &timer) - setIdleConnections(pool, map[string][]db.Connection{serverName: {deadAfterReset1, deadAfterReset2}}) + setIdleConnections(pool, map[string][]idb.Connection{serverName: {deadAfterReset1, deadAfterReset2}}) result, err := pool.tryBorrow(ctx, serverName, nil, idlenessThreshold, reAuthToken) @@ -276,10 +261,10 @@ func TestPoolBorrowReturn(outer *testing.T) { wg.Done() }() - waitForBorrowers(t, p, 1) + waitForBorrowers(p, 1) // break the connection. 
then it shouldn't be picked up by the waiting borrow c1.(*ConnFake).Alive = false - err = p.Return(ctx, c1) + p.Return(ctx, c1) AssertNoError(t, err) wg.Wait() }) @@ -288,7 +273,7 @@ func TestPoolBorrowReturn(outer *testing.T) { token2 := iauth.Token{Tokens: map[string]any{"scheme": "foobar"}} // sanity check AssertNotDeepEquals(t, reAuthToken.Manager, token2) - reAuthToken2 := &db.ReAuthToken{FromSession: false, Manager: token2} + reAuthToken2 := &idb.ReAuthToken{FromSession: false, Manager: token2} timer := func() time.Time { return birthdate } conf := config.Config{MaxConnectionLifetime: maxAge, MaxConnectionPoolSize: 1} p := New(&conf, succeedingConnect, logger, "pool id", &timer) @@ -304,14 +289,14 @@ func TestPoolBorrowReturn(outer *testing.T) { wg.Done() }() - waitForBorrowers(t, p, 1) + waitForBorrowers(p, 1) reAuthCalled := false - c1.(*ConnFake).ReAuthHook = func(_ context.Context, token *db.ReAuthToken) error { + c1.(*ConnFake).ReAuthHook = func(_ context.Context, token *idb.ReAuthToken) error { AssertDeepEquals(t, token.Manager, token2) reAuthCalled = true return nil } - err = p.Return(ctx, c1) + p.Return(ctx, c1) AssertNoError(t, err) wg.Wait() AssertTrue(t, reAuthCalled) @@ -323,7 +308,7 @@ func TestPoolResourceUsage(ot *testing.T) { maxAge := 1 * time.Second birthdate := time.Now() - succeedingConnect := func(_ context.Context, s string, _ *db.ReAuthToken, _ bolt.Neo4jErrorCallback, _ log.BoltLogger) (db.Connection, error) { + succeedingConnect := func(_ context.Context, s string, _ *idb.ReAuthToken, _ bolt.ConnectionErrorListener, _ log.BoltLogger) (idb.Connection, error) { return &ConnFake{Name: s, Alive: true, Birth: birthdate}, nil } @@ -332,9 +317,7 @@ func TestPoolResourceUsage(ot *testing.T) { conf := config.Config{MaxConnectionLifetime: maxAge, MaxConnectionPoolSize: 1} p := New(&conf, succeedingConnect, logger, "pool id", &timer) defer func() { - if err := p.Close(ctx); err != nil { - t.Errorf("Should not fail closing the pool, but got: %v", 
err) - } + p.Close(ctx) }() serverNames := []string{"srvA", "srvB", "srvC", "srvD"} c, _ := p.Borrow(ctx, getServers(serverNames), true, nil, DefaultLivenessCheckThreshold, reAuthToken) @@ -348,20 +331,13 @@ func TestPoolResourceUsage(ot *testing.T) { conf := config.Config{MaxConnectionLifetime: maxAge, MaxConnectionPoolSize: 2} p := New(&conf, succeedingConnect, logger, "pool id", &timer) defer func() { - if err := p.Close(ctx); err != nil { - t.Errorf("Should not fail closing the pool, but got: %v", err) - } + p.Close(ctx) }() serverNames := []string{"srvA"} c, _ := p.Borrow(ctx, getServers(serverNames), true, nil, DefaultLivenessCheckThreshold, reAuthToken) c.(*ConnFake).Alive = false - if err := p.Return(ctx, c); err != nil { - t.Errorf("Should not fail returning connection to pool, but got: %v", err) - } - servers, err := p.getServers(ctx) - if err != nil { - t.Errorf("Should not fail retrieving server but got: %v", err) - } + p.Return(ctx, c) + servers := p.getServers() if len(servers) > 0 && servers[serverNames[0]].size() > 0 { t.Errorf("Should have either removed the server or kept it but emptied it") } @@ -372,19 +348,12 @@ func TestPoolResourceUsage(ot *testing.T) { conf := config.Config{MaxConnectionLifetime: maxAge, MaxConnectionPoolSize: 2} p := New(&conf, succeedingConnect, logger, "pool id", &timer) defer func() { - if err := p.Close(ctx); err != nil { - t.Errorf("Should not fail closing the pool, but got: %v", err) - } + p.Close(ctx) }() serverNames := []string{"srvA"} c, _ := p.Borrow(ctx, getServers(serverNames), true, nil, DefaultLivenessCheckThreshold, reAuthToken) - if err := p.Return(ctx, c); err != nil { - t.Errorf("Should not fail returning connection to pool, but got: %v", err) - } - servers, err := p.getServers(ctx) - if err != nil { - t.Errorf("Should not fail retrieving server but got: %v", err) - } + p.Return(ctx, c) + servers := p.getServers() if len(servers) > 0 && servers[serverNames[0]].size() > 0 { t.Errorf("Should have either 
removed the server or kept it but emptied it") } @@ -407,20 +376,14 @@ func TestPoolResourceUsage(ot *testing.T) { c3.(*ConnFake).Birth = nowTime.Add(1 * time.Second) c3.(*ConnFake).Id = 3 // Return the old and young connections to make them idle - if err := p.Return(ctx, c1); err != nil { - t.Errorf("Should not fail returning connection to pool, but got: %v", err) - } - if err := p.Return(ctx, c3); err != nil { - t.Errorf("Should not fail returning connection to pool, but got: %v", err) - } - assertNumberOfServers(t, ctx, p, 1) - assertNumberOfIdle(t, ctx, p, "A", 2) + p.Return(ctx, c1) + p.Return(ctx, c3) + assertNumberOfServers(t, p, 1) + assertNumberOfIdle(t, p, "A", 2) // Kill the middle-aged connection and return it c2.(*ConnFake).Alive = false - if err := p.Return(ctx, c2); err != nil { - t.Errorf("Should not fail returning connection to pool, but got: %v", err) - } - assertNumberOfIdle(t, ctx, p, "A", 1) + p.Return(ctx, c2) + assertNumberOfIdle(t, p, "A", 1) }) ot.Run("Do not borrow too old connections", func(t *testing.T) { @@ -434,17 +397,13 @@ func TestPoolResourceUsage(ot *testing.T) { conf := config.Config{MaxConnectionLifetime: maxAge, MaxConnectionPoolSize: 1} p := New(&conf, succeedingConnect, logger, "pool id", &timer) defer func() { - if err := p.Close(ctx); err != nil { - t.Errorf("Should not fail closing the pool, but got: %v", err) - } + p.Close(ctx) }() serverNames := []string{"srvA"} c1, _ := p.Borrow(ctx, getServers(serverNames), true, nil, DefaultLivenessCheckThreshold, reAuthToken) c1.(*ConnFake).Id = 123 // It's alive when returning it - if err := p.Return(ctx, c1); err != nil { - t.Errorf("Should not fail returning connection to pool, but got: %v", err) - } + p.Return(ctx, c1) nowMut.Lock() now = now.Add(2 * maxAge) nowMut.Unlock() @@ -460,27 +419,25 @@ func TestPoolResourceUsage(ot *testing.T) { conf := config.Config{MaxConnectionLifetime: maxAge, MaxConnectionPoolSize: 1} p := New(&conf, succeedingConnect, logger, "pool id", &timer) 
defer func() { - if err := p.Close(ctx); err != nil { - t.Errorf("Should not fail closing the pool, but got: %v", err) - } + p.Close(ctx) }() c1, err := p.Borrow(ctx, getServers([]string{"A"}), true, nil, DefaultLivenessCheckThreshold, reAuthToken) assertConnection(t, c1, err) c2, err := p.Borrow(ctx, getServers([]string{"B"}), true, nil, DefaultLivenessCheckThreshold, reAuthToken) assertConnection(t, c2, err) - assertNumberOfServers(t, ctx, p, 2) + assertNumberOfServers(t, p, 2) }) } func TestPoolCleanup(ot *testing.T) { birthdate := time.Now() maxLife := 1 * time.Second - succeedingConnect := func(_ context.Context, s string, _ *db.ReAuthToken, _ bolt.Neo4jErrorCallback, _ log.BoltLogger) (db.Connection, error) { + succeedingConnect := func(_ context.Context, s string, _ *idb.ReAuthToken, _ bolt.ConnectionErrorListener, _ log.BoltLogger) (idb.Connection, error) { return &ConnFake{Name: s, Alive: true, Birth: birthdate}, nil } // Borrows a connection in server A and another in server B - borrowConnections := func(t *testing.T, p *Pool) (db.Connection, db.Connection) { + borrowConnections := func(t *testing.T, p *Pool) (idb.Connection, idb.Connection) { c1, err := p.Borrow(ctx, getServers([]string{"A"}), true, nil, DefaultLivenessCheckThreshold, reAuthToken) assertConnection(t, c1, err) c2, err := p.Borrow(ctx, getServers([]string{"B"}), true, nil, DefaultLivenessCheckThreshold, reAuthToken) @@ -493,27 +450,19 @@ func TestPoolCleanup(ot *testing.T) { conf := config.Config{MaxConnectionLifetime: maxLife, MaxConnectionPoolSize: 0} p := New(&conf, succeedingConnect, logger, "pool id", &timer) defer func() { - if err := p.Close(ctx); err != nil { - t.Errorf("Should not fail closing the pool, but got: %v", err) - } + p.Close(ctx) }() c1, c2 := borrowConnections(t, p) - if err := p.Return(ctx, c1); err != nil { - t.Errorf("Should not fail returning connection to pool, but got: %v", err) - } - if err := p.Return(ctx, c2); err != nil { - t.Errorf("Should not fail returning 
connection to pool, but got: %v", err) - } - assertNumberOfServers(t, ctx, p, 2) - assertNumberOfIdle(t, ctx, p, "A", 1) - assertNumberOfIdle(t, ctx, p, "B", 1) + p.Return(ctx, c1) + p.Return(ctx, c2) + assertNumberOfServers(t, p, 2) + assertNumberOfIdle(t, p, "A", 1) + assertNumberOfIdle(t, p, "B", 1) // Now go into the future and cleanup, should remove both servers and close the connections timer = func() time.Time { return birthdate.Add(maxLife).Add(1 * time.Second) } - if err := p.CleanUp(ctx); err != nil { - t.Errorf("Should not fail cleaning up the pool, but got: %v", err) - } - assertNumberOfServers(t, ctx, p, 0) + p.CleanUp(ctx) + assertNumberOfServers(t, p, 0) }) ot.Run("Should not remove servers with busy connections", func(t *testing.T) { @@ -521,59 +470,48 @@ func TestPoolCleanup(ot *testing.T) { conf := config.Config{MaxConnectionLifetime: maxLife, MaxConnectionPoolSize: 0} p := New(&conf, succeedingConnect, logger, "pool id", &timer) defer func() { - if err := p.Close(ctx); err != nil { - t.Errorf("Should not fail closing the pool, but got: %v", err) - } + p.Close(ctx) }() _, c2 := borrowConnections(t, p) - if err := p.Return(ctx, c2); err != nil { - t.Errorf("Should not fail returning connection to pool, but got: %v", err) - } - assertNumberOfServers(t, ctx, p, 2) - assertNumberOfIdle(t, ctx, p, "A", 0) - assertNumberOfIdle(t, ctx, p, "B", 1) + p.Return(ctx, c2) + assertNumberOfServers(t, p, 2) + assertNumberOfIdle(t, p, "A", 0) + assertNumberOfIdle(t, p, "B", 1) // Now go into the future and cleanup, should only remove B timer = func() time.Time { return birthdate.Add(maxLife).Add(1 * time.Second) } - if err := p.CleanUp(ctx); err != nil { - t.Errorf("Should not fail cleaning up the pool, but got: %v", err) - } - assertNumberOfServers(t, ctx, p, 1) + p.CleanUp(ctx) + assertNumberOfServers(t, p, 1) }) ot.Run("Should not remove servers with only idle connections but with recent connect failures ", func(t *testing.T) { - failingConnect := func(_ 
context.Context, s string, _ *db.ReAuthToken, _ bolt.Neo4jErrorCallback, _ log.BoltLogger) (db.Connection, error) { + failingConnect := func(_ context.Context, s string, _ *idb.ReAuthToken, _ bolt.ConnectionErrorListener, _ log.BoltLogger) (idb.Connection, error) { return nil, errors.New("an error") } timer := time.Now conf := config.Config{MaxConnectionLifetime: maxLife, MaxConnectionPoolSize: 0} p := New(&conf, failingConnect, logger, "pool id", &timer) + p.SetRouter(&RouterFake{}) defer func() { - if err := p.Close(ctx); err != nil { - t.Errorf("Should not fail closing the pool, but got: %v", err) - } + p.Close(ctx) }() c1, err := p.Borrow(ctx, getServers([]string{"A"}), true, nil, DefaultLivenessCheckThreshold, reAuthToken) assertNoConnection(t, c1, err) - assertNumberOfServers(t, ctx, p, 1) - assertNumberOfIdle(t, ctx, p, "A", 0) + assertNumberOfServers(t, p, 1) + assertNumberOfIdle(t, p, "A", 0) // Now go into the future and cleanup, should not remove server A even if it has no connections since // we should remember the failure a bit longer timer = func() time.Time { return birthdate.Add(maxLife).Add(1 * time.Second) } - if err := p.CleanUp(ctx); err != nil { - t.Errorf("Should not fail cleaning up the pool, but got: %v", err) - } - assertNumberOfServers(t, ctx, p, 1) + p.CleanUp(ctx) + assertNumberOfServers(t, p, 1) // Further in the future, the failure should have been forgotten timer = func() time.Time { return birthdate.Add(maxLife).Add(rememberFailedConnectDuration).Add(1 * time.Second) } - if err := p.CleanUp(ctx); err != nil { - t.Errorf("Should not fail cleaning up the pool, but got: %v", err) - } - assertNumberOfServers(t, ctx, p, 0) + p.CleanUp(ctx) + assertNumberOfServers(t, p, 0) }) ot.Run("wakes up borrowers when closing", func(t *testing.T) { @@ -592,9 +530,9 @@ func TestPoolCleanup(ot *testing.T) { _, err := p.Borrow(ctx, servers, true, nil, DefaultLivenessCheckThreshold, reAuthToken) borrowErrChan <- err }() - waitForBorrowers(t, p, 1) + 
waitForBorrowers(p, 1) - AssertNoError(t, p.Close(ctx)) + p.Close(ctx) select { case err := <-borrowErrChan: @@ -605,13 +543,115 @@ func TestPoolCleanup(ot *testing.T) { }) } -func connectTo(singleConnection *ConnFake) func(ctx context.Context, name string, _ *db.ReAuthToken, _ bolt.Neo4jErrorCallback, _ log.BoltLogger) (db.Connection, error) { - return func(ctx context.Context, name string, _ *db.ReAuthToken, _ bolt.Neo4jErrorCallback, _ log.BoltLogger) (db.Connection, error) { +func TestPoolErrorHandling(ot *testing.T) { + const ServerName = "A" + const DbName = "some database" + + type TestCase struct { + name string + errorCall func(bolt.ConnectionErrorListener, idb.Connection) error + expectedInvalidateMode string + expectedInvalidatedServer string + expectedInvalidatedDb string + } + + dbUnavailableErr := db.Neo4jError{Code: "Neo.TransientError.General.DatabaseUnavailable"} + notALeaderErr := db.Neo4jError{Code: "Neo.ClientError.Cluster.NotALeader"} + forbiddenRoDbError := db.Neo4jError{Code: "Neo.ClientError.General.ForbiddenOnReadOnlyDatabase"} + + cases := []TestCase{ + { + name: "should invalidate server on io error", + errorCall: func(listener bolt.ConnectionErrorListener, conn idb.Connection) error { + listener.OnIoError(ctx, conn, errors.New("an error")) + return nil + }, + expectedInvalidatedServer: ServerName, + }, + { + name: "should invalidate server on dial error", + errorCall: func(listener bolt.ConnectionErrorListener, _ idb.Connection) error { + listener.OnDialError(ctx, "what ever server", errors.New("an error")) + return nil + }, + expectedInvalidatedServer: "what ever server", + }, + { + name: "should invalidate borrowed server on dial error", + errorCall: func(listener bolt.ConnectionErrorListener, _ idb.Connection) error { + listener.OnDialError(ctx, ServerName, errors.New("an error")) + return nil + }, + expectedInvalidatedServer: ServerName, + }, + { + name: "should invalidate server on db unavailable error", + errorCall: func(listener
bolt.ConnectionErrorListener, conn idb.Connection) error { + return listener.OnNeo4jError(ctx, conn, &dbUnavailableErr) + }, + expectedInvalidatedServer: ServerName, + }, + { + name: "should invalidate writer for db on not a leader error", + errorCall: func(listener bolt.ConnectionErrorListener, conn idb.Connection) error { + return listener.OnNeo4jError(ctx, conn, &notALeaderErr) + }, + expectedInvalidateMode: "writer", + expectedInvalidatedServer: ServerName, + expectedInvalidatedDb: DbName, + }, + { + name: "should invalidate writer for db on forbidden error", + errorCall: func(listener bolt.ConnectionErrorListener, conn idb.Connection) error { + return listener.OnNeo4jError(ctx, conn, &forbiddenRoDbError) + }, + expectedInvalidateMode: "writer", + expectedInvalidatedServer: ServerName, + expectedInvalidatedDb: DbName, + }, + } + + for _, testCase := range cases { + ot.Run(testCase.name, func(t *testing.T) { + errorListeners := make([]bolt.ConnectionErrorListener, 0) + connections := make([]*ConnFake, 0) + succeedingConnect := func(_ context.Context, s string, _ *idb.ReAuthToken, errorListener bolt.ConnectionErrorListener, _ log.BoltLogger) (idb.Connection, error) { + errorListeners = append(errorListeners, errorListener) + connection := ConnFake{Name: s, Alive: true, DatabaseName: DbName} + connections = append(connections, &connection) + return &connection, nil + } + + now := time.Now + router := RouterFake{} + p := New(&config.Config{}, succeedingConnect, logger, "pool id", &now) + p.SetRouter(&router) + defer p.Close(ctx) + conn, err := p.Borrow(ctx, getServers([]string{ServerName}), false, nil, DefaultLivenessCheckThreshold, reAuthToken) + assertConnection(t, conn, err) + AssertLen(t, errorListeners, 1) + AssertLen(t, connections, 1) + errorListener := errorListeners[0] + connection := connections[0] + AssertFalse(t, router.Invalidated) + + err = testCase.errorCall(errorListener, connection) + AssertNoError(t, err) + AssertTrue(t, router.Invalidated) + 
AssertStringEqual(t, router.InvalidateMode, testCase.expectedInvalidateMode) + AssertStringEqual(t, router.InvalidatedServer, testCase.expectedInvalidatedServer) + AssertStringEqual(t, router.InvalidatedDb, testCase.expectedInvalidatedDb) + }) + } +} + +func connectTo(singleConnection *ConnFake) func(ctx context.Context, name string, _ *idb.ReAuthToken, _ bolt.ConnectionErrorListener, _ log.BoltLogger) (idb.Connection, error) { + return func(ctx context.Context, name string, _ *idb.ReAuthToken, _ bolt.ConnectionErrorListener, _ log.BoltLogger) (idb.Connection, error) { return singleConnection, nil } } -func setIdleConnections(pool *Pool, servers map[string][]db.Connection) { +func setIdleConnections(pool *Pool, servers map[string][]idb.Connection) { poolServers := make(map[string]*server, len(servers)) for serverName, connections := range servers { srv := NewServer() @@ -633,18 +673,14 @@ func deadConnectionAfterForceReset(name string, idleness time.Time) *ConnFake { return result } -func getServers(servers []string) func(context.Context) ([]string, error) { - return func(context.Context) ([]string, error) { - return servers, nil +func getServers(servers []string) func() []string { + return func() []string { + return servers } } -func waitForBorrowers(t *testing.T, p *Pool, minBorrowers int) { - for { - if size, err := p.queueSize(ctx); err != nil { - t.Errorf("should not fail computing queue size, got: %v", err) - } else if size >= minBorrowers { - break - } +func waitForBorrowers(p *Pool, minBorrowers int) { + for p.queueSize() < minBorrowers { + // still waiting } } diff --git a/neo4j/internal/pool/server.go b/neo4j/internal/pool/server.go index 2ac1b53c..8e5de873 100644 --- a/neo4j/internal/pool/server.go +++ b/neo4j/internal/pool/server.go @@ -37,6 +37,7 @@ type server struct { reservations int failedConnectAt time.Time roundRobin uint32 + closing bool } func NewServer() *server { @@ -137,9 +138,13 @@ func (s *server) calculatePenalty(now time.Time) uint32 { } 
// Returns a busy connection, makes it idle -func (s *server) returnBusy(c db.Connection) { +func (s *server) returnBusy(ctx context.Context, c db.Connection) { s.unregisterBusy(c) - s.idle.PushFront(c) + if s.closing { + c.Close(ctx) + } else { + s.idle.PushFront(c) + } } // Number of idle connections @@ -206,10 +211,15 @@ func (s *server) executeForAllConnections(callback func(c db.Connection)) { } } +func (s *server) startClosing(ctx context.Context) { + s.closing = true + closeAndEmptyConnections(ctx, s.idle) +} + func closeAndEmptyConnections(ctx context.Context, l list.List) { for e := l.Front(); e != nil; e = e.Next() { c := e.Value.(db.Connection) - c.Close(ctx) + go c.Close(ctx) } l.Init() } diff --git a/neo4j/internal/pool/server_test.go b/neo4j/internal/pool/server_test.go index f9a42757..46c26f24 100644 --- a/neo4j/internal/pool/server_test.go +++ b/neo4j/internal/pool/server_test.go @@ -80,7 +80,7 @@ func TestServer(ot *testing.T) { c3 := s.getIdle() assertNilConnection(t, c3) - s.returnBusy(c2) + s.returnBusy(context.Background(), c2) c3 = s.getIdle() assertConnection(t, c3) }) @@ -110,8 +110,8 @@ func TestServer(ot *testing.T) { assertNilConnection(t, b3) // Return the connections and let all of them be too old - s.returnBusy(b1) - s.returnBusy(b2) + s.returnBusy(context.Background(), b1) + s.returnBusy(context.Background(), b2) conns[0].Birth = now.Add(-20 * time.Second) conns[2].Birth = now.Add(-20 * time.Second) s.removeIdleOlderThan(context.Background(), now, 10*time.Second) @@ -146,7 +146,7 @@ func TestServerPenalty(t *testing.T) { // Return the busy connection to srv1 // Now srv2 should have higher penalty than srv1 since using srv2 would require a new // connection. - srv1.returnBusy(c11) + srv1.returnBusy(context.Background(), c11) assertPenaltiesGreaterThan(srv2, srv1, now) // Add an idle connection to srv2 to make both servers have one idle connection each. 
@@ -162,7 +162,7 @@ func TestServerPenalty(t *testing.T) { idle := srv1.getIdle() _, _ = srv1.healthCheck(ctx, idle, DefaultLivenessCheckThreshold, nil, nil) testutil.AssertDeepEquals(t, idle, c11) - srv1.returnBusy(c11) + srv1.returnBusy(context.Background(), c11) assertPenaltiesGreaterThan(srv1, srv2, now) // Add one more connection each to the servers @@ -187,10 +187,10 @@ func TestServerPenalty(t *testing.T) { // Return the connections idle = srv2.getIdle() _, _ = srv2.healthCheck(ctx, idle, DefaultLivenessCheckThreshold, nil, nil) - srv2.returnBusy(c21) - srv2.returnBusy(c22) - srv1.returnBusy(c11) - srv1.returnBusy(c12) + srv2.returnBusy(context.Background(), c21) + srv2.returnBusy(context.Background(), c22) + srv1.returnBusy(context.Background(), c11) + srv1.returnBusy(context.Background(), c12) // Everything returned, srv2 should have higher penalty since it was last used assertPenaltiesGreaterThan(srv2, srv1, now) @@ -290,5 +290,5 @@ func TestIdlenessThreshold(outer *testing.T) { func registerIdle(srv *server, connection db.Connection) { srv.registerBusy(connection) - srv.returnBusy(connection) + srv.returnBusy(context.Background(), connection) } diff --git a/neo4j/internal/retry/state.go b/neo4j/internal/retry/state.go index 29de6e46..c4b5329c 100644 --- a/neo4j/internal/retry/state.go +++ b/neo4j/internal/retry/state.go @@ -32,10 +32,6 @@ import ( "github.com/neo4j/neo4j-go-driver/v5/neo4j/log" ) -type Router interface { - Invalidate(ctx context.Context, database string) error -} - type State struct { Errs []error MaxTransactionRetryTime time.Duration @@ -46,26 +42,22 @@ type State struct { Sleep func(time.Duration) Throttle Throttler MaxDeadConnections int - Router Router DatabaseName string - start time.Time - cause string - deadErrors int - skipSleep bool - OnDeadConnection func(server string) error + start time.Time + cause string + deadErrors int + skipSleep bool } -func (s *State) OnFailure(ctx context.Context, err error, conn idb.Connection, 
isCommitting bool) { +func (s *State) OnFailure(_ context.Context, err error, conn idb.Connection, isCommitting bool) { if conn != nil && !conn.IsAlive() { if isCommitting { + // FIXME: CommitFailedDeadError should be returned even when not using transaction functions s.Errs = append(s.Errs, &errorutil.CommitFailedDeadError{Inner: err}) } else { s.Errs = append(s.Errs, err) } - if err := s.OnDeadConnection(conn.ServerName()); err != nil { - s.Errs = append(s.Errs, err) - } s.deadErrors += 1 s.skipSleep = true return @@ -73,13 +65,6 @@ func (s *State) OnFailure(ctx context.Context, err error, conn idb.Connection, i s.Errs = append(s.Errs, err) s.skipSleep = false - - if dbErr, isDbErr := err.(*db.Neo4jError); isDbErr && dbErr.IsRetriableCluster() { - if err := s.Router.Invalidate(ctx, s.DatabaseName); err != nil { - s.Errs = append(s.Errs, err) - } - } - } func (s *State) Continue() bool { diff --git a/neo4j/internal/retry/state_test.go b/neo4j/internal/retry/state_test.go index 93a08302..8b0cc4dc 100644 --- a/neo4j/internal/retry/state_test.go +++ b/neo4j/internal/retry/state_test.go @@ -35,16 +35,13 @@ import ( ) type TStateInvocation struct { - conn idb.Connection - err error - isCommitting bool - now time.Time - expectContinued bool - expectRouterInvalidated bool - expectRouterInvalidatedDb string - expectRouterInvalidatedServer string - expectLastErrWasRetryable bool - expectLastErrType error + conn idb.Connection + err error + isCommitting bool + now time.Time + expectContinued bool + expectLastErrWasRetryable bool + expectLastErrType error } func TestState(outer *testing.T) { @@ -78,39 +75,29 @@ func TestState(outer *testing.T) { "Retry dead connection": { {conn: &testutil.ConnFake{Name: serverName, Alive: false}, err: errors.New("some error"), expectContinued: false, - expectLastErrWasRetryable: false, expectRouterInvalidated: true, - expectRouterInvalidatedDb: dbName, expectRouterInvalidatedServer: serverName}, + expectLastErrWasRetryable: false}, {conn: 
&testutil.ConnFake{Name: serverName, Alive: false}, err: dbTransientErr, - expectContinued: true, expectLastErrWasRetryable: true, - expectRouterInvalidated: true, - expectRouterInvalidatedDb: dbName, expectRouterInvalidatedServer: serverName}, + expectContinued: true, expectLastErrWasRetryable: true}, }, "Retry dead connection timeout": { {conn: &testutil.ConnFake{Name: serverName, Alive: false}, err: dbTransientErr, - expectContinued: true, now: baseTime, expectLastErrWasRetryable: true, - expectRouterInvalidated: true, expectRouterInvalidatedDb: dbName, - expectRouterInvalidatedServer: serverName}, + expectContinued: true, now: baseTime, expectLastErrWasRetryable: true}, {conn: &testutil.ConnFake{Name: serverName, Alive: false}, err: errors.New("some error 2"), expectContinued: false, now: overTime, - expectLastErrWasRetryable: false, - expectRouterInvalidated: true, - expectRouterInvalidatedDb: dbName, expectRouterInvalidatedServer: serverName}, + expectLastErrWasRetryable: false}, }, "Retry dead connection max": { {conn: &testutil.ConnFake{Name: serverName, Alive: false}, err: dbTransientErr, expectContinued: true, - expectLastErrWasRetryable: true, expectRouterInvalidated: true, - expectRouterInvalidatedDb: dbName, expectRouterInvalidatedServer: serverName}, + expectLastErrWasRetryable: true}, {conn: &testutil.ConnFake{Name: serverName, Alive: false}, err: dbTransientErr, expectContinued: true, - expectLastErrWasRetryable: true, expectRouterInvalidated: true, - expectRouterInvalidatedDb: dbName, expectRouterInvalidatedServer: serverName}, + expectLastErrWasRetryable: true}, {conn: &testutil.ConnFake{Name: serverName, Alive: false}, err: dbTransientErr, expectContinued: false, - expectLastErrWasRetryable: true, expectRouterInvalidated: true, - expectRouterInvalidatedDb: dbName, expectRouterInvalidatedServer: serverName}, + expectLastErrWasRetryable: true}, }, "Cluster error": { {conn: &testutil.ConnFake{Alive: true}, err: clusterErr, expectContinued: true, - 
expectRouterInvalidated: true, expectRouterInvalidatedDb: dbName, expectLastErrWasRetryable: true}, + expectLastErrWasRetryable: true}, }, "Database transient error": { {conn: &testutil.ConnFake{Alive: true}, err: dbTransientErr, expectContinued: true, @@ -128,17 +115,14 @@ func TestState(outer *testing.T) { }, "Fail during commit": { {conn: &testutil.ConnFake{Name: serverName, Alive: false}, err: io.EOF, isCommitting: true, expectContinued: false, - expectLastErrWasRetryable: false, expectLastErrType: &errorutil.CommitFailedDeadError{}, - expectRouterInvalidated: true, - expectRouterInvalidatedDb: dbName, expectRouterInvalidatedServer: serverName}, + expectLastErrWasRetryable: false, expectLastErrType: &errorutil.CommitFailedDeadError{}}, }, "Fail during commit after retry": { {conn: &testutil.ConnFake{Alive: true}, err: dbTransientErr, expectContinued: true, expectLastErrWasRetryable: true}, {conn: &testutil.ConnFake{Name: serverName, Alive: false}, err: io.EOF, isCommitting: true, expectContinued: false, expectLastErrWasRetryable: false, - expectLastErrType: &errorutil.CommitFailedDeadError{}, expectRouterInvalidated: true, - expectRouterInvalidatedDb: dbName, expectRouterInvalidatedServer: serverName}, + expectLastErrType: &errorutil.CommitFailedDeadError{}}, }, "Does not retry on auth errors": { {conn: nil, err: authErr, expectContinued: false, @@ -176,24 +160,12 @@ func TestState(outer *testing.T) { if !invocation.now.IsZero() { now = invocation.now } - router := &testutil.RouterFake{} - state.Router = router - state.OnDeadConnection = func(server string) error { - return router.InvalidateReader(ctx, dbName, server) - } state.OnFailure(ctx, invocation.err, invocation.conn, invocation.isCommitting) continued := state.Continue() if continued != invocation.expectContinued { t.Errorf("Expected continue to return %v but returned %v", invocation.expectContinued, continued) } - if invocation.expectRouterInvalidated != router.Invalidated || - 
invocation.expectRouterInvalidatedDb != router.InvalidatedDb || - invocation.expectRouterInvalidatedServer != router.InvalidatedServer { - t.Errorf("Expected router invalidated: expected (%v/%s/%s) vs. actual (%v/%s/%s)", - invocation.expectRouterInvalidated, invocation.expectRouterInvalidatedDb, invocation.expectRouterInvalidatedServer, - router.Invalidated, router.InvalidatedDb, router.InvalidatedServer) - } var lastError error if err, ok := state.Errs[0].(*errorutil.TransactionExecutionLimit); ok { errs := err.Errors diff --git a/neo4j/internal/router/no_test.go b/neo4j/internal/router/no_test.go index c64102ea..587957a6 100644 --- a/neo4j/internal/router/no_test.go +++ b/neo4j/internal/router/no_test.go @@ -32,15 +32,11 @@ type poolFake struct { cancel context.CancelFunc } -func (p *poolFake) Borrow(ctx context.Context, getServers func(context.Context) ([]string, error), _ bool, logger log.BoltLogger, _ time.Duration, _ *db.ReAuthToken) (db.Connection, error) { - servers, err := getServers(ctx) - if err != nil { - return nil, err - } +func (p *poolFake) Borrow(_ context.Context, getServers func() []string, _ bool, logger log.BoltLogger, _ time.Duration, _ *db.ReAuthToken) (db.Connection, error) { + servers := getServers() return p.borrow(servers, p.cancel, logger) } -func (p *poolFake) Return(_ context.Context, c db.Connection) error { +func (p *poolFake) Return(_ context.Context, c db.Connection) { p.returned = append(p.returned, c) - return nil } diff --git a/neo4j/internal/router/readtable.go b/neo4j/internal/router/readtable.go index bffc8420..15f12907 100644 --- a/neo4j/internal/router/readtable.go +++ b/neo4j/internal/router/readtable.go @@ -75,8 +75,8 @@ func readTable( return nil, err } -func getStaticServer(server string) func(context.Context) ([]string, error) { - return func(context.Context) ([]string, error) { - return []string{server}, nil +func getStaticServer(server string) func() []string { + return func() []string { + return []string{server} } 
} diff --git a/neo4j/internal/router/router.go b/neo4j/internal/router/router.go index ac82dcb5..ec68d689 100644 --- a/neo4j/internal/router/router.go +++ b/neo4j/internal/router/router.go @@ -26,6 +26,7 @@ import ( idb "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/db" "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/errorutil" "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/racing" + "sync" "time" "github.com/neo4j/neo4j-go-driver/v5/neo4j/log" @@ -44,7 +45,8 @@ type Router struct { routerContext map[string]string pool Pool dbRouters map[string]*databaseRouter - dbRoutersMut racing.Mutex + updating map[string][]chan struct{} + dbRoutersMut sync.Mutex now *func() time.Time sleep func(time.Duration) rootRouter string @@ -58,8 +60,8 @@ type Pool interface { // If all connections are busy and the pool is full, calls to Borrow may wait for a connection to become idle // If a connection has been idle for longer than idlenessThreshold, it will be reset // to check if it's still alive. 
- Borrow(ctx context.Context, getServers func(context.Context) ([]string, error), wait bool, boltLogger log.BoltLogger, idlenessThreshold time.Duration, auth *idb.ReAuthToken) (idb.Connection, error) - Return(ctx context.Context, c idb.Connection) error + Borrow(ctx context.Context, getServers func() []string, wait bool, boltLogger log.BoltLogger, idlenessThreshold time.Duration, auth *idb.ReAuthToken) (idb.Connection, error) + Return(ctx context.Context, c idb.Connection) } func New(rootRouter string, getRouters func() []string, routerContext map[string]string, pool Pool, logger log.Logger, logId string, timer *func() time.Time) *Router { @@ -69,7 +71,8 @@ func New(rootRouter string, getRouters func() []string, routerContext map[string routerContext: routerContext, pool: pool, dbRouters: make(map[string]*databaseRouter), - dbRoutersMut: racing.NewMutex(), + updating: make(map[string][]chan struct{}), + dbRoutersMut: sync.Mutex{}, now: timer, sleep: time.Sleep, log: logger, @@ -139,40 +142,63 @@ func (r *Router) readTable( return table, nil } -func (r *Router) getTable(ctx context.Context, database string) (*idb.RoutingTable, error) { - if !r.dbRoutersMut.TryLock(ctx) { - return nil, racing.LockTimeoutError("could not acquire router lock in time when getting routing table") - } +func (r *Router) getTable(database string) *idb.RoutingTable { + r.dbRoutersMut.Lock() defer r.dbRoutersMut.Unlock() dbRouter := r.dbRouters[database] - return r.getTableLocked(dbRouter, (*r.now)()), nil + return r.getTableLocked(dbRouter) } func (r *Router) getOrUpdateTable(ctx context.Context, bookmarksFn func(context.Context) ([]string, error), database string, auth *idb.ReAuthToken, boltLogger log.BoltLogger) (*idb.RoutingTable, error) { - now := (*r.now)() - - if !r.dbRoutersMut.TryLock(ctx) { - return nil, racing.LockTimeoutError("could not acquire router lock in time when getting routing table") - } - defer r.dbRoutersMut.Unlock() - - dbRouter := r.dbRouters[database] - if table := 
r.getTableLocked(dbRouter, now); table != nil { - return table, nil + r.dbRoutersMut.Lock() + var unlock = new(sync.Once) + defer unlock.Do(r.dbRoutersMut.Unlock) + for { + dbRouter := r.dbRouters[database] + if table := r.getTableLocked(dbRouter); table != nil { + return table, nil + } + waiters, ok := r.updating[database] + if ok { + // Wait for the table to be updated by other goroutine + ch := make(chan struct{}) + r.updating[database] = append(waiters, ch) + unlock.Do(r.dbRoutersMut.Unlock) + select { + case <-ctx.Done(): + return nil, racing.LockTimeoutError("timed out waiting for other goroutine to update routing table") + case <-ch: + r.dbRoutersMut.Lock() + *unlock = sync.Once{} + continue + } + } + // this goroutine will update the table + r.updating[database] = make([]chan struct{}, 0) + unlock.Do(r.dbRoutersMut.Unlock) + + table, err := r.updateTable(ctx, bookmarksFn, database, auth, boltLogger, dbRouter) + r.dbRoutersMut.Lock() + *unlock = sync.Once{} + // notify all waiters + for _, waiter := range r.updating[database] { + waiter <- struct{}{} + } + delete(r.updating, database) + return table, err } - - return r.updateTable(ctx, bookmarksFn, database, auth, boltLogger, dbRouter, now) } -func (r *Router) getTableLocked(dbRouter *databaseRouter, now time.Time) *idb.RoutingTable { +func (r *Router) getTableLocked(dbRouter *databaseRouter) *idb.RoutingTable { + now := (*r.now)() if dbRouter != nil && now.Unix() < dbRouter.dueUnix { return dbRouter.table } return nil } -func (r *Router) updateTable(ctx context.Context, bookmarksFn func(context.Context) ([]string, error), database string, auth *idb.ReAuthToken, boltLogger log.BoltLogger, dbRouter *databaseRouter, now time.Time) (*idb.RoutingTable, error) { +func (r *Router) updateTable(ctx context.Context, bookmarksFn func(context.Context) ([]string, error), database string, auth *idb.ReAuthToken, boltLogger log.BoltLogger, dbRouter *databaseRouter) (*idb.RoutingTable, error) { bookmarks, err := 
bookmarksFn(ctx) if err != nil { return nil, err @@ -182,7 +208,10 @@ func (r *Router) updateTable(ctx context.Context, bookmarksFn func(context.Conte return nil, err } - r.storeRoutingTable(database, table, now) + err = r.storeRoutingTable(ctx, database, table, (*r.now)()) + if err != nil { + return nil, err + } return table, nil } @@ -201,9 +230,7 @@ func (r *Router) GetOrUpdateReaders(ctx context.Context, bookmarks func(context. break } r.log.Infof(log.Router, r.logId, "Invalidating routing table, no readers") - if err := r.Invalidate(ctx, table.DatabaseName); err != nil { - return nil, err - } + r.Invalidate(table.DatabaseName) r.sleep(100 * time.Millisecond) table, err = r.getOrUpdateTable(ctx, bookmarks, database, auth, boltLogger) if err != nil { @@ -217,15 +244,12 @@ func (r *Router) GetOrUpdateReaders(ctx context.Context, bookmarks func(context. return table.Readers, nil } -func (r *Router) Readers(ctx context.Context, database string) ([]string, error) { - table, err := r.getTable(ctx, database) - if err != nil { - return nil, err - } +func (r *Router) Readers(database string) []string { + table := r.getTable(database) if table == nil { - return nil, nil + return nil } - return table.Readers, nil + return table.Readers } func (r *Router) GetOrUpdateWriters(ctx context.Context, bookmarks func(context.Context) ([]string, error), database string, auth *idb.ReAuthToken, boltLogger log.BoltLogger) ([]string, error) { @@ -242,9 +266,7 @@ func (r *Router) GetOrUpdateWriters(ctx context.Context, bookmarks func(context. break } r.log.Infof(log.Router, r.logId, "Invalidating routing table, no writers") - if err := r.Invalidate(ctx, database); err != nil { - return nil, err - } + r.Invalidate(database) r.sleep(100 * time.Millisecond) table, err = r.getOrUpdateTable(ctx, bookmarks, database, auth, boltLogger) if err != nil { @@ -258,29 +280,26 @@ func (r *Router) GetOrUpdateWriters(ctx context.Context, bookmarks func(context. 
return table.Writers, nil } -func (r *Router) Writers(ctx context.Context, database string) ([]string, error) { - table, err := r.getTable(ctx, database) - if err != nil { - return nil, err - } +func (r *Router) Writers(database string) []string { + table := r.getTable(database) if table == nil { - return nil, nil + return nil } - return table.Writers, nil + return table.Writers } func (r *Router) GetNameOfDefaultDatabase(ctx context.Context, bookmarks []string, user string, auth *idb.ReAuthToken, boltLogger log.BoltLogger) (string, error) { + // FIXME: this seems to indirectly cache the home db for the routing table's TTL table, err := r.readTable(ctx, nil, bookmarks, idb.DefaultDatabase, user, auth, boltLogger) if err != nil { return "", err } // Store the fresh routing table as well to avoid another roundtrip to receive servers from session. now := (*r.now)() - if !r.dbRoutersMut.TryLock(ctx) { - return "", racing.LockTimeoutError("could not acquire router lock in time when resolving home database") + err = r.storeRoutingTable(ctx, table.DatabaseName, table, now) + if err != nil { + return "", err } - defer r.dbRoutersMut.Unlock() - r.storeRoutingTable(table.DatabaseName, table, now) return table.DatabaseName, err } @@ -288,11 +307,9 @@ func (r *Router) Context() map[string]string { return r.routerContext } -func (r *Router) Invalidate(ctx context.Context, database string) error { +func (r *Router) Invalidate(database string) { r.log.Infof(log.Router, r.logId, "Invalidating routing table for '%s'", database) - if !r.dbRoutersMut.TryLock(ctx) { - return racing.LockTimeoutError("could not acquire router lock in time when invalidating database router") - } + r.dbRoutersMut.Lock() defer r.dbRoutersMut.Unlock() // Reset due time to the 70s, this will make next access refresh the routing table using // last set of routers instead of the original one. 
@@ -300,55 +317,53 @@ func (r *Router) Invalidate(ctx context.Context, database string) error { if dbRouter != nil { dbRouter.dueUnix = 0 } - return nil } -func (r *Router) InvalidateWriter(ctx context.Context, db string, server string) error { - if !r.dbRoutersMut.TryLock(ctx) { - return racing.LockTimeoutError("could not acquire router lock in time when getting routing table") - } +func (r *Router) InvalidateWriter(db string, server string) { + r.dbRoutersMut.Lock() defer r.dbRoutersMut.Unlock() router := r.dbRouters[db] if router == nil { - return nil + return } - writers := router.table.Writers - for i, writer := range writers { - if writer == server { - router.table.Writers = append(writers[0:i], writers[i+1:]...) - return nil - } - } - return nil + router.table.Writers = removeServerFromList(router.table.Writers, server) } -func (r *Router) InvalidateReader(ctx context.Context, db string, server string) error { - if !r.dbRoutersMut.TryLock(ctx) { - return racing.LockTimeoutError("could not acquire router lock in time when invalidating reader") - } +func (r *Router) InvalidateReader(db string, server string) { + r.dbRoutersMut.Lock() defer r.dbRoutersMut.Unlock() router := r.dbRouters[db] if router == nil { - return nil + return } - readers := router.table.Readers - for i, reader := range readers { - if reader == server { - router.table.Readers = append(readers[0:i], readers[i+1:]...) 
- return nil + router.table.Readers = removeServerFromList(router.table.Readers, server) +} + +func (r *Router) InvalidateServer(server string) { + r.dbRoutersMut.Lock() + defer r.dbRoutersMut.Unlock() + for _, routing := range r.dbRouters { + routing.table.Routers = removeServerFromList(routing.table.Routers, server) + routing.table.Readers = removeServerFromList(routing.table.Readers, server) + routing.table.Writers = removeServerFromList(routing.table.Writers, server) + } +} + +func removeServerFromList(list []string, server string) []string { + for i, s := range list { + if s == server { + return append(list[0:i], list[i+1:]...) } } - return nil + return list } -func (r *Router) CleanUp(ctx context.Context) error { +func (r *Router) CleanUp() { r.log.Debugf(log.Router, r.logId, "Cleaning up") now := (*r.now)().Unix() - if !r.dbRoutersMut.TryLock(ctx) { - return racing.LockTimeoutError("could not acquire router lock in time when invalidating reader") - } + r.dbRoutersMut.Lock() defer r.dbRoutersMut.Unlock() for dbName, dbRouter := range r.dbRouters { @@ -356,15 +371,17 @@ func (r *Router) CleanUp(ctx context.Context) error { delete(r.dbRouters, dbName) } } - return nil } -func (r *Router) storeRoutingTable(database string, table *idb.RoutingTable, now time.Time) { +func (r *Router) storeRoutingTable(ctx context.Context, database string, table *idb.RoutingTable, now time.Time) error { + r.dbRoutersMut.Lock() + defer r.dbRoutersMut.Unlock() r.dbRouters[database] = &databaseRouter{ table: table, dueUnix: now.Add(time.Duration(table.TimeToLive) * time.Second).Unix(), } r.log.Debugf(log.Router, r.logId, "New routing table for '%s', TTL %d", database, table.TimeToLive) + return nil } func wrapError(server string, err error) error { diff --git a/neo4j/internal/router/router_test.go b/neo4j/internal/router/router_test.go index 43323931..98eda90d 100644 --- a/neo4j/internal/router/router_test.go +++ b/neo4j/internal/router/router_test.go @@ -140,9 +140,7 @@ func 
TestRespectsTimeToLiveAndInvalidate(t *testing.T) { assertNum(t, numfetch, 2, "Should not have have fetched") // Invalidate should force fetching - if err := router.Invalidate(ctx, dbName); err != nil { - testutil.AssertNoError(t, err) - } + router.Invalidate(dbName) if _, err := router.GetOrUpdateReaders(ctx, nilBookmarks, dbName, nil, nil); err != nil { testutil.AssertNoError(t, err) } @@ -387,17 +385,13 @@ func TestCleanUp(t *testing.T) { } // Should not remove these since they still have time to live - if err := router.CleanUp(ctx); err != nil { - testutil.AssertNoError(t, err) - } + router.CleanUp() if len(router.dbRouters) != 2 { t.Fatal("Should not have removed routing tables") } timer = func() time.Time { return now.Add(1 * time.Minute) } - if err := router.CleanUp(ctx); err != nil { - testutil.AssertNoError(t, err) - } + router.CleanUp() if len(router.dbRouters) != 0 { t.Fatal("Should have cleaned up") } diff --git a/neo4j/internal/router/router_testkit.go b/neo4j/internal/router/router_testkit.go new file mode 100644 index 00000000..e4f0d2f2 --- /dev/null +++ b/neo4j/internal/router/router_testkit.go @@ -0,0 +1,28 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [https://neo4j.com] + * + * This file is part of Neo4j. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +//go:build internal_testkit + +package router + +import idb "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/db" + +func (r *Router) GetTable(database string) *idb.RoutingTable { + return r.getTable(database) +} diff --git a/neo4j/internal/testutil/connfake.go b/neo4j/internal/testutil/connfake.go index 906c64f2..577ec429 100644 --- a/neo4j/internal/testutil/connfake.go +++ b/neo4j/internal/testutil/connfake.go @@ -193,6 +193,10 @@ func (c *ConnFake) SelectDatabase(database string) { c.DatabaseName = database } +func (c *ConnFake) Database() string { + return c.DatabaseName +} + func (c *ConnFake) SetBoltLogger(log.BoltLogger) { } diff --git a/neo4j/internal/testutil/poolfake.go b/neo4j/internal/testutil/poolfake.go index 8e836ba7..0c569ca4 100644 --- a/neo4j/internal/testutil/poolfake.go +++ b/neo4j/internal/testutil/poolfake.go @@ -34,7 +34,7 @@ type PoolFake struct { BorrowHook func() (db.Connection, error) } -func (p *PoolFake) Borrow(context.Context, func(context.Context) ([]string, error), bool, log.BoltLogger, time.Duration, *db.ReAuthToken) (db.Connection, error) { +func (p *PoolFake) Borrow(context.Context, func() []string, bool, log.BoltLogger, time.Duration, *db.ReAuthToken) (db.Connection, error) { if p.BorrowHook != nil && (p.BorrowConn != nil || p.BorrowErr != nil) { panic("either use the hook or the desired return values, but not both") } @@ -44,18 +44,16 @@ func (p *PoolFake) Borrow(context.Context, func(context.Context) ([]string, erro return p.BorrowConn, p.BorrowErr } -func (p *PoolFake) Return(context.Context, db.Connection) error { +func (p *PoolFake) Return(context.Context, db.Connection) { if p.ReturnHook != nil { p.ReturnHook() } - return nil } -func (p *PoolFake) CleanUp(context.Context) error { +func (p *PoolFake) CleanUp(context.Context) { if p.CleanUpHook != nil { p.CleanUpHook() } - return nil } func (p *PoolFake) Now() time.Time { diff --git a/neo4j/internal/testutil/routerfake.go b/neo4j/internal/testutil/routerfake.go 
index 83a4ec3e..1c3422c6 100644 --- a/neo4j/internal/testutil/routerfake.go +++ b/neo4j/internal/testutil/routerfake.go @@ -28,6 +28,8 @@ import ( type RouterFake struct { Invalidated bool InvalidatedDb string + InvalidateMode string + InvalidatedServer string GetOrUpdateReadersRet []string GetOrUpdateReadersHook func(bookmarks func(context.Context) ([]string, error), database string) ([]string, error) GetOrUpdateWritersRet []string @@ -35,25 +37,28 @@ type RouterFake struct { Err error CleanUpHook func() GetNameOfDefaultDbHook func(user string) (string, error) - InvalidatedServer string } -func (r *RouterFake) InvalidateReader(ctx context.Context, database string, server string) error { - if err := r.Invalidate(ctx, database); err != nil { - return err - } +func (r *RouterFake) InvalidateReader(database string, server string) { + r.Invalidate(database) r.InvalidatedServer = server - return nil + r.InvalidateMode = "reader" } -func (r *RouterFake) InvalidateWriter(context.Context, string, string) error { - return nil +func (r *RouterFake) InvalidateWriter(database string, server string) { + r.Invalidate(database) + r.InvalidatedServer = server + r.InvalidateMode = "writer" +} + +func (r *RouterFake) InvalidateServer(server string) { + r.Invalidated = true + r.InvalidatedServer = server } -func (r *RouterFake) Invalidate(ctx context.Context, database string) error { +func (r *RouterFake) Invalidate(database string) { r.InvalidatedDb = database r.Invalidated = true - return nil } func (r *RouterFake) GetOrUpdateReaders(_ context.Context, bookmarksFn func(context.Context) ([]string, error), database string, _ *db.ReAuthToken, _ log.BoltLogger) ([]string, error) { @@ -63,8 +68,8 @@ func (r *RouterFake) GetOrUpdateReaders(_ context.Context, bookmarksFn func(cont return r.GetOrUpdateReadersRet, r.Err } -func (r *RouterFake) Readers(context.Context, string) ([]string, error) { - return nil, nil +func (r *RouterFake) Readers(string) []string { + return nil } func (r 
*RouterFake) GetOrUpdateWriters(_ context.Context, bookmarksFn func(context.Context) ([]string, error), database string, _ *db.ReAuthToken, _ log.BoltLogger) ([]string, error) { @@ -74,8 +79,8 @@ func (r *RouterFake) GetOrUpdateWriters(_ context.Context, bookmarksFn func(cont return r.GetOrUpdateWritersRet, r.Err } -func (r *RouterFake) Writers(context.Context, string) ([]string, error) { - return nil, nil +func (r *RouterFake) Writers(string) []string { + return nil } func (r *RouterFake) GetNameOfDefaultDatabase(_ context.Context, _ []string, user string, _ *db.ReAuthToken, _ log.BoltLogger) (string, error) { @@ -85,9 +90,8 @@ func (r *RouterFake) GetNameOfDefaultDatabase(_ context.Context, _ []string, use return "", nil } -func (r *RouterFake) CleanUp(ctx context.Context) error { +func (r *RouterFake) CleanUp() { if r.CleanUpHook != nil { r.CleanUpHook() } - return nil } diff --git a/neo4j/session_with_context.go b/neo4j/session_with_context.go index dd906b62..09433d43 100644 --- a/neo4j/session_with_context.go +++ b/neo4j/session_with_context.go @@ -189,9 +189,9 @@ const FetchDefault = 0 // Connection pool as seen by the session. type sessionPool interface { - Borrow(ctx context.Context, getServers func(context.Context) ([]string, error), wait bool, boltLogger log.BoltLogger, livenessCheckThreshold time.Duration, auth *idb.ReAuthToken) (idb.Connection, error) - Return(ctx context.Context, c idb.Connection) error - CleanUp(ctx context.Context) error + Borrow(ctx context.Context, getServerNames func() []string, wait bool, boltLogger log.BoltLogger, livenessCheckThreshold time.Duration, auth *idb.ReAuthToken) (idb.Connection, error) + Return(ctx context.Context, c idb.Connection) + CleanUp(ctx context.Context) Now() time.Time } @@ -312,7 +312,7 @@ func (s *sessionWithContext) BeginTransaction(ctx context.Context, configurers . 
beginBookmarks, err := s.getBookmarks(ctx) if err != nil { - _ = s.pool.Return(ctx, conn) + s.pool.Return(ctx, conn) return nil, errorutil.WrapError(err) } txHandle, err := conn.TxBegin(ctx, @@ -328,7 +328,7 @@ func (s *sessionWithContext) BeginTransaction(ctx context.Context, configurers . }, }) if err != nil { - _ = s.pool.Return(ctx, conn) + s.pool.Return(ctx, conn) return nil, errorutil.WrapError(err) } @@ -343,8 +343,8 @@ func (s *sessionWithContext) BeginTransaction(ctx context.Context, configurers . } // On run failure, transaction closed (rolled back or committed) bookmarkErr := s.retrieveBookmarks(ctx, tx.conn, beginBookmarks) - poolErr := s.pool.Return(ctx, tx.conn) - tx.err = errorutil.CombineAllErrors(tx.err, bookmarkErr, poolErr) + s.pool.Return(ctx, tx.conn) + tx.err = errorutil.CombineAllErrors(tx.err, bookmarkErr) tx.conn = nil s.explicitTx = nil }, @@ -397,20 +397,7 @@ func (s *sessionWithContext) runRetriable( Sleep: s.sleep, Throttle: retry.Throttler(s.throttleTime), MaxDeadConnections: s.driverConfig.MaxConnectionPoolSize, - Router: s.router, DatabaseName: s.config.DatabaseName, - OnDeadConnection: func(server string) error { - if mode == idb.WriteMode { - if err := s.router.InvalidateWriter(ctx, s.config.DatabaseName, server); err != nil { - return err - } - } else { - if err := s.router.InvalidateReader(ctx, s.config.DatabaseName, server); err != nil { - return err - } - } - return nil - }, } for state.Continue() { if hasCompleted, result := s.executeTransactionFunction(ctx, mode, config, &state, work); hasCompleted { @@ -438,7 +425,7 @@ func (s *sessionWithContext) executeTransactionFunction( // handle transaction function panic as well defer func() { - _ = s.pool.Return(ctx, conn) + s.pool.Return(ctx, conn) }() beginBookmarks, err := s.getBookmarks(ctx) @@ -496,12 +483,12 @@ func (s *sessionWithContext) getOrUpdateServers(ctx context.Context, mode idb.Ac } } -func (s *sessionWithContext) getServers(mode idb.AccessMode) func(context.Context) 
([]string, error) { - return func(ctx context.Context) ([]string, error) { +func (s *sessionWithContext) getServers(mode idb.AccessMode) func() []string { + return func() []string { if mode == idb.ReadMode { - return s.router.Readers(ctx, s.config.DatabaseName) + return s.router.Readers(s.config.DatabaseName) } else { - return s.router.Writers(ctx, s.config.DatabaseName) + return s.router.Writers(s.config.DatabaseName) } } } @@ -594,7 +581,7 @@ func (s *sessionWithContext) Run(ctx context.Context, runBookmarks, err := s.getBookmarks(ctx) if err != nil { - _ = s.pool.Return(ctx, conn) + s.pool.Return(ctx, conn) return nil, errorutil.WrapError(err) } stream, err := conn.Run( @@ -617,7 +604,7 @@ func (s *sessionWithContext) Run(ctx context.Context, }, ) if err != nil { - _ = s.pool.Return(ctx, conn) + s.pool.Return(ctx, conn) return nil, errorutil.WrapError(err) } @@ -630,7 +617,7 @@ func (s *sessionWithContext) Run(ctx context.Context, } }), onClosed: func() { - _ = s.pool.Return(ctx, conn) + s.pool.Return(ctx, conn) s.autocommitTx = nil }, } @@ -649,15 +636,19 @@ func (s *sessionWithContext) Close(ctx context.Context) error { } defer s.log.Debugf(log.Session, s.logId, "Closed") - poolErrChan := make(chan error, 1) - routerErrChan := make(chan error, 1) + poolCleanUpChan := make(chan struct{}, 1) + routerCleanUpChan := make(chan struct{}, 1) go func() { - poolErrChan <- s.pool.CleanUp(ctx) + s.pool.CleanUp(ctx) + poolCleanUpChan <- struct{}{} }() go func() { - routerErrChan <- s.router.CleanUp(ctx) + s.router.CleanUp() + routerCleanUpChan <- struct{}{} }() - return errorutil.CombineAllErrors(txErr, <-poolErrChan, <-routerErrChan) + <-poolCleanUpChan + <-routerCleanUpChan + return txErr } func (s *sessionWithContext) legacy() Session { diff --git a/neo4j/test-integration/dbconn_test.go b/neo4j/test-integration/dbconn_test.go index c1f2e776..57691f71 100644 --- a/neo4j/test-integration/dbconn_test.go +++ b/neo4j/test-integration/dbconn_test.go @@ -40,10 +40,21 @@ 
import ( "github.com/neo4j/neo4j-go-driver/v5/neo4j/test-integration/dbserver" ) -func noopOnNeo4jError(context.Context, idb.Connection, *db.Neo4jError) error { +type noopErrorListener struct{} + +func (n noopErrorListener) OnNeo4jError(_ context.Context, _ idb.Connection, e *db.Neo4jError) error { + fmt.Println("OnNeo4jError", e) return nil } +func (n noopErrorListener) OnIoError(_ context.Context, _ idb.Connection, e error) { + fmt.Println("OnIoError", e) +} + +func (n noopErrorListener) OnDialError(_ context.Context, _ string, e error) { + fmt.Println("OnDialError", e) +} + func makeRawConnection(ctx context.Context, logger log.Logger, boltLogger log.BoltLogger) ( dbserver.DbServer, idb.Connection) { server := dbserver.GetDbServer(ctx) @@ -77,7 +88,7 @@ func makeRawConnection(ctx context.Context, logger log.Logger, boltLogger log.Bo auth, "007", nil, - noopOnNeo4jError, + noopErrorListener{}, logger, boltLogger, idb.NotificationConfig{}, diff --git a/neo4j/test-integration/driver_test.go b/neo4j/test-integration/driver_test.go index 7b53a60b..89f8e968 100644 --- a/neo4j/test-integration/driver_test.go +++ b/neo4j/test-integration/driver_test.go @@ -40,7 +40,7 @@ func TestDriver(outer *testing.T) { outer.Run("VerifyConnectivity", func(inner *testing.T) { inner.Run("should return nil upon good connection", func(t *testing.T) { driver := server.Driver() - defer driver.Close(ctx) + defer func() { _ = driver.Close(ctx) }() assertNil(t, driver.VerifyConnectivity(ctx)) }) @@ -48,7 +48,7 @@ func TestDriver(outer *testing.T) { auth := neo4j.BasicAuth("bad user", "bad pass", "bad area") driver, err := neo4j.NewDriverWithContext(server.BoltURI(), auth, server.ConfigFunc()) assertNil(t, err) - defer driver.Close(ctx) + defer func() { _ = driver.Close(ctx) }() err = driver.VerifyConnectivity(ctx) assertNotNil(t, err) }) @@ -64,11 +64,11 @@ func TestDriver(outer *testing.T) { tearDown := func(driver neo4j.DriverWithContext, session neo4j.SessionWithContext) { if session != 
nil { - session.Close(ctx) + _ = session.Close(ctx) } if driver != nil { - driver.Close(ctx) + _ = driver.Close(ctx) } } @@ -126,7 +126,7 @@ func TestDriver(outer *testing.T) { defer func() { if driver != nil { - driver.Close(ctx) + _ = driver.Close(ctx) } }() @@ -165,7 +165,7 @@ func TestDriver(outer *testing.T) { defer func() { if driver != nil { - driver.Close(ctx) + _ = driver.Close(ctx) } }() diff --git a/testkit-backend/backend.go b/testkit-backend/backend.go index f413d157..07a07db8 100644 --- a/testkit-backend/backend.go +++ b/testkit-backend/backend.go @@ -150,6 +150,7 @@ func (b *backend) writeLine(s string) error { func (b *backend) writeLineLocked(s string) error { b.wrLock.Lock() defer b.wrLock.Unlock() + fmt.Println(s) return b.writeLine(s) } @@ -608,7 +609,7 @@ func (b *backend) handleRequest(req map[string]any) { case "NewSession": driver := b.drivers[data["driverId"].(string)] sessionConfig := neo4j.SessionConfig{ - BoltLogger: neo4j.ConsoleBoltLogger(), + BoltLogger: &streamLog{writeLine: b.writeLineLocked}, } if data["accessMode"] != nil { switch data["accessMode"].(string) { @@ -844,6 +845,54 @@ func (b *backend) handleRequest(req map[string]any) { } b.writeResponse("Summary", serializeSummary(summary)) + case "ForcedRoutingTableUpdate": + databaseRaw := data["database"] + var database string + if databaseRaw != nil { + database = databaseRaw.(string) + } + var bookmarks []string + bookmarksRaw := data["bookmarks"] + if bookmarksRaw != nil { + bookmarksSlice := bookmarksRaw.([]any) + bookmarks = make([]string, len(bookmarksSlice)) + for i, bookmark := range bookmarksSlice { + bookmarks[i] = bookmark.(string) + } + } + driverId := data["driverId"].(string) + driver := b.drivers[driverId] + err := neo4j.ForceRoutingTableUpdate(driver, database, bookmarks, &streamLog{writeLine: b.writeLineLocked}) + if err != nil { + b.writeError(err) + return + } + b.writeResponse("Driver", map[string]any{"id": driverId}) + + case "GetRoutingTable": + driver := 
b.drivers[data["driverId"].(string)] + databaseRaw := data["database"] + var database string + if databaseRaw != nil { + database = databaseRaw.(string) + } + table, err := neo4j.GetRoutingTable(driver, database) + if err != nil { + b.writeError(err) + return + } + var databaseName any = table.DatabaseName + if databaseName == "" { + databaseName = nil + } + b.writeResponse("RoutingTable", map[string]any{ + "database": databaseName, + "ttl": table.TimeToLive, + "routers": table.Routers, + "readers": table.Readers, + "writers": table.Writers, + }) + case "CheckMultiDBSupport": driver := b.drivers[data["driverId"].(string)] session := driver.NewSession(ctx, neo4j.SessionConfig{ @@ -1041,10 +1090,7 @@ func (b *backend) handleRequest(req map[string]any) { case "GetFeatures": b.writeResponse("FeatureList", map[string]any{ "features": []string{ - "AuthorizationExpiredTreatment", - "Backend:MockTime", - "ConfHint:connection.recv_timeout_seconds", - "Detail:ClosedDriverIsEncrypted", + // === FUNCTIONAL FEATURES === "Feature:API:BookmarkManager", "Feature:API:ConnectionAcquisitionTimeout", "Feature:API:Driver.ExecuteQuery", @@ -1053,15 +1099,21 @@ func (b *backend) handleRequest(req map[string]any) { "Feature:API:Driver:NotificationsConfig", "Feature:API:Driver.VerifyAuthentication", "Feature:API:Driver.VerifyConnectivity", - "Feature:API:Liveness.Check", + //"Feature:API:Driver.SupportsSessionAuth", + // Go driver does not support LivenessCheckTimeout yet + //"Feature:API:Liveness.Check", "Feature:API:Result.List", "Feature:API:Result.Peek", + //"Feature:API:Result.Single", + //"Feature:API:Result.SingleOptional", "Feature:API:Session:AuthConfig", - "Feature:API:Session:NotificationsConfig", + //"Feature:API:Session:NotificationsConfig", + //"Feature:API:SSLConfig", + //"Feature:API:SSLSchemes", "Feature:API:Type.Spatial", "Feature:API:Type.Temporal", - "Feature:Auth:Custom", "Feature:Auth:Bearer", + "Feature:Auth:Custom", "Feature:Auth:Kerberos", "Feature:Auth:Managed", 
"Feature:Bolt:3.0", @@ -1075,15 +1127,33 @@ func (b *backend) handleRequest(req map[string]any) { "Feature:Bolt:5.3", "Feature:Bolt:Patch:UTC", "Feature:Impersonation", + //"Feature:TLS:1.1", "Feature:TLS:1.2", "Feature:TLS:1.3", - "Optimization:AuthPipelining", + + // === OPTIMIZATIONS === + "AuthorizationExpiredTreatment", "Optimization:ConnectionReuse", "Optimization:EagerTransactionBegin", "Optimization:ImplicitDefaultArguments", "Optimization:MinimalBookmarksSet", "Optimization:MinimalResets", + //"Optimization:MinimalVerifyAuthentication", + "Optimization:AuthPipelining", "Optimization:PullPipelining", + //"Optimization:ResultListFetchAll", + + // === IMPLEMENTATION DETAILS === + "Detail:ClosedDriverIsEncrypted", + "Detail:DefaultSecurityConfigValueEquality", + + // === CONFIGURATION HINTS (BOLT 4.3+) === + "ConfHint:connection.recv_timeout_seconds", + + // === BACKEND FEATURES FOR TESTING === + "Backend:MockTime", + "Backend:RTFetch", + "Backend:RTForceUpdate", }, }) @@ -1393,34 +1463,36 @@ func firstRecordInvalidValue(record *db.Record) *neo4j.InvalidValue { // you can use '*' as wildcards anywhere in the qualified test name (useful to exclude a whole class e.g.) 
func testSkips() map[string]string { return map[string]string{ - "stub.disconnects.test_disconnects.TestDisconnects.test_fail_on_reset": "It is not resetting driver when put back to pool", - "stub.routing.test_routing_v3.RoutingV3.test_should_use_resolver_during_rediscovery_when_existing_routers_fail": "It needs investigation - custom resolver does not seem to be called", - "stub.routing.test_routing_v4x1.RoutingV4x1.test_should_use_resolver_during_rediscovery_when_existing_routers_fail": "It needs investigation - custom resolver does not seem to be called", - "stub.routing.test_routing_v4x3.RoutingV4x3.test_should_use_resolver_during_rediscovery_when_existing_routers_fail": "It needs investigation - custom resolver does not seem to be called", - "stub.routing.test_routing_v4x4.RoutingV4x4.test_should_use_resolver_during_rediscovery_when_existing_routers_fail": "It needs investigation - custom resolver does not seem to be called", - "stub.routing.test_routing_v5x0.RoutingV5x0.test_should_use_resolver_during_rediscovery_when_existing_routers_fail": "It needs investigation - custom resolver does not seem to be called", - "stub.routing.test_routing_v3.RoutingV3.test_should_revert_to_initial_router_if_known_router_throws_protocol_errors": "It needs investigation - custom resolver does not seem to be called", - "stub.routing.test_routing_v4x1.RoutingV4x1.test_should_revert_to_initial_router_if_known_router_throws_protocol_errors": "It needs investigation - custom resolver does not seem to be called", - "stub.routing.test_routing_v4x3.RoutingV4x3.test_should_revert_to_initial_router_if_known_router_throws_protocol_errors": "It needs investigation - custom resolver does not seem to be called", - "stub.routing.test_routing_v4x4.RoutingV4x4.test_should_revert_to_initial_router_if_known_router_throws_protocol_errors": "It needs investigation - custom resolver does not seem to be called", - 
"stub.routing.test_routing_v5x0.RoutingV5x0.test_should_revert_to_initial_router_if_known_router_throws_protocol_errors": "It needs investigation - custom resolver does not seem to be called", - "stub.configuration_hints.test_connection_recv_timeout_seconds.TestRoutingConnectionRecvTimeout.*": "No GetRoutingTable support - too tricky to implement in Go", - "stub.homedb.test_homedb.TestHomeDb.test_session_should_cache_home_db_despite_new_rt": "Driver does not remove servers from RT when connection breaks.", - "stub.iteration.test_result_scope.TestResultScope.*": "Results are always valid but don't return records when out of scope", - "stub.*.test_0_timeout": "Driver omits 0 as tx timeout value", - "stub.*.test_negative_timeout": "Driver omits negative tx timeout values", - "stub.routing.*.*.test_should_request_rt_from_all_initial_routers_until_successful_on_unknown_failure": "Add DNS resolver TestKit message and connection timeout support", - "stub.routing.*.*.test_should_request_rt_from_all_initial_routers_until_successful_on_authorization_expired": "Add DNS resolver TestKit message and connection timeout support", - "stub.summary.test_summary.TestSummary.test_server_info": "Needs some kind of server address DNS resolution", - "stub.summary.test_summary.TestSummary.test_invalid_query_type": "Driver does not verify query type returned from server.", - "stub.routing.*.test_should_drop_connections_failing_liveness_check": "Needs support for GetConnectionPoolMetrics", + // Won't fix - accepted/idiomatic behavioral differences + "stub.iteration.test_result_scope.TestResultScope.*": "Won't fix - Results are always valid but don't return records when out of scope", "stub.connectivity_check.test_get_server_info.TestGetServerInfo.test_routing_fail_when_no_reader_are_available": "Won't fix - Go driver retries routing table when no readers are available", "stub.connectivity_check.test_verify_connectivity.TestVerifyConnectivity.test_routing_fail_when_no_reader_are_available": 
"Won't fix - Go driver retries routing table when no readers are available", "stub.driver_parameters.test_connection_acquisition_timeout_ms.TestConnectionAcquisitionTimeoutMs.test_does_not_encompass_router_*": "Won't fix - ConnectionAcquisitionTimeout spans the whole process including db resolution, RT updates, connection acquisition from the pool, and creation of new connections.", "stub.driver_parameters.test_connection_acquisition_timeout_ms.TestConnectionAcquisitionTimeoutMs.test_router_handshake_has_own_timeout_*": "Won't fix - ConnectionAcquisitionTimeout spans the whole process including db resolution, RT updates, connection acquisition from the pool, and creation of new connections.", - "stub.tx_run.test_tx_run.TestTxRun.test_should_prevent_discard_after_tx_termination_on_run": "fixme: usage of failed transaction leads to unintelligible error that's treated as BackendError", - "stub.tx_run.test_tx_run.TestTxRun.test_should_prevent_pull_after_tx_termination_on_run": "fixme: usage of failed transaction leads to unintelligible error that's treated as BackendError", - "stub.tx_run.test_tx_run.TestTxRun.test_should_prevent_commit_after_tx_termination": "fixme: commit is still sent when transaction is terminated", + "stub.routing.test_routing_v*.RoutingV*.test_should_successfully_check_if_support_for_multi_db_is_available": "Won't fix - driver.SupportsMultiDb() is not implemented", + "stub.routing.test_no_routing_v*.NoRoutingV*.test_should_check_multi_db_support": "Won't fix - driver.SupportsMultiDb() is not implemented", + "stub.routing.test_routing_v3.RoutingV3.test_should_fail_discovery_when_router_fails_with_procedure_not_found_code": "Won't fix - only Bolt 3 affected (not officially supported by this driver) + this is only a difference in how errors are surfaced", + "stub.routing.test_routing_v3.RoutingV3.test_should_fail_when_writing_on_unexpectedly_interrupting_writer_on_pull_using_tx_run": "Won't fix - only Bolt 3 affected (not officially supported by 
this driver): broken servers are not removed from routing table", + "stub.routing.test_routing_v3.RoutingV3.test_should_fail_when_writing_on_unexpectedly_interrupting_writer_on_run_using_tx_run": "Won't fix - only Bolt 3 affected (not officially supported by this driver): broken servers are not removed from routing table", + "stub.routing.test_routing_v3.RoutingV3.test_should_fail_when_writing_on_unexpectedly_interrupting_writer_using_tx_run": "Won't fix - only Bolt 3 affected (not officially supported by this driver): broken servers are not removed from routing table", + + // Missing message support in testkit backend + "stub.routing.*.*.test_should_request_rt_from_all_initial_routers_until_successful_on_unknown_failure": "Add DNS resolver TestKit message and connection timeout support", + "stub.routing.*.*.test_should_request_rt_from_all_initial_routers_until_successful_on_authorization_expired": "Add DNS resolver TestKit message and connection timeout support", + + // To fix/to decide whether to fix + "stub.tx_run.test_tx_run.TestTxRun.test_should_prevent_discard_after_tx_termination_on_run": "fixme: usage of failed transaction leads to unintelligible error that's treated as BackendError", + "stub.tx_run.test_tx_run.TestTxRun.test_should_prevent_pull_after_tx_termination_on_run": "fixme: usage of failed transaction leads to unintelligible error that's treated as BackendError", + "stub.tx_run.test_tx_run.TestTxRun.test_should_prevent_commit_after_tx_termination": "fixme: commit is still sent when transaction is terminated", + "stub.routing.test_routing_v*.RoutingV*.test_should_revert_to_initial_router_if_known_router_throws_protocol_errors": "Driver always uses configured URL first and custom resolver only if that fails", + "stub.routing.test_routing_v*.RoutingV*.test_should_read_successfully_from_reachable_db_after_trying_unreachable_db": "Driver retries to fetch a routing table up to 100 times if it's empty", + 
"stub.routing.test_routing_v*.RoutingV*.test_should_write_successfully_after_leader_switch_using_tx_run": "Driver retries to fetch a routing table up to 100 times if it's emtpy", + "stub.routing.test_routing_v*.RoutingV*.test_should_fail_when_writing_without_writers_using_session_run": "Driver retries to fetch a routing table up to 100 times if it's emtpy", + "stub.routing.test_routing_v*.RoutingV*.test_should_accept_routing_table_without_writers_and_then_rediscover": "Driver retries to fetch a routing table up to 100 times if it's emtpy", + "stub.routing.test_routing_v*.RoutingV*.test_should_fail_on_routing_table_with_no_reader": "Driver retries to fetch a routing table up to 100 times if it's emtpy", + "stub.routing.test_routing_v*.RoutingV*.test_should_fail_discovery_when_router_fails_with_unknown_code": "Unify: other drivers have a list of fast failing errors during discover: on anything else, the driver will try the next router", + "stub.*.test_0_timeout": "Fixme: driver omits 0 as tx timeout value", + "stub.summary.test_summary.TestSummary.test_server_info": "pending unification: should the server address be pre or post DNS resolution?", } } diff --git a/testkit-backend/streamlogger.go b/testkit-backend/streamlogger.go index c51004be..fcdd1f9b 100644 --- a/testkit-backend/streamlogger.go +++ b/testkit-backend/streamlogger.go @@ -8,34 +8,56 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ package main import ( "fmt" + "time" ) +const timeFormat = "2006-01-02 15:04:05.000" + type streamLog struct { writeLine func(string) error } func (l *streamLog) Error(name string, id string, err error) { - l.writeLine(fmt.Sprintf("[%s %s] %s", name, id, err)) + _ = l.writeLine(fmt.Sprintf("[%s %s] %s", name, id, err)) } func (l *streamLog) Warnf(name string, id string, msg string, args ...any) { - l.writeLine(fmt.Sprintf("[%s %s] %s", name, id, fmt.Sprintf(msg, args...))) + _ = l.writeLine(fmt.Sprintf("[%s %s] %s", name, id, fmt.Sprintf(msg, args...))) } func (l *streamLog) Infof(name string, id string, msg string, args ...any) { - l.writeLine(fmt.Sprintf("[%s %s] %s", name, id, fmt.Sprintf(msg, args...))) + _ = l.writeLine(fmt.Sprintf("[%s %s] %s", name, id, fmt.Sprintf(msg, args...))) } func (l *streamLog) Debugf(name string, id string, msg string, args ...any) { - l.writeLine(fmt.Sprintf("[%s %s] %s", name, id, fmt.Sprintf(msg, args...))) + _ = l.writeLine(fmt.Sprintf("[%s %s] %s", name, id, fmt.Sprintf(msg, args...))) +} + +func (l *streamLog) LogClientMessage(id, msg string, args ...any) { + l.logBoltMessage("C", id, msg, args) +} + +func (l *streamLog) LogServerMessage(id, msg string, args ...any) { + l.logBoltMessage("S", id, msg, args) +} + +func (l *streamLog) logBoltMessage(src, id string, msg string, args []any) { + _ = l.writeLine(fmt.Sprintf("%s BOLT %s%s: %s", time.Now().Format(timeFormat), formatId(id), src, fmt.Sprintf(msg, args...))) +} + +func formatId(id string) string { + if id == "" { + return "" + } + return fmt.Sprintf("[%s] ", id) }