diff --git a/admin/README.md b/admin/README.md index 1694be1f765..2d003798b0f 100644 --- a/admin/README.md +++ b/admin/README.md @@ -96,6 +96,11 @@ curl localhost:9002/admin/run_command -H 'Content-Type: application/json' -d '{" curl localhost:9002/admin/run_command -H 'Content-Type: application/json' -d '{"commandName": "stop-at-height", "data": { "height": 1111, "crash": false }}' ``` +### Trigger checkpoint creation on execution +``` +curl localhost:9002/admin/run_command -H 'Content-Type: application/json' -d '{"commandName": "trigger-checkpoint"}' +``` + ### Add/Remove/Get address to rate limit a payer from adding transactions to collection nodes' mempool ``` curl localhost:9002/admin/run_command -H 'Content-Type: application/json' -d '{"commandName": "ingest-tx-rate-limit", "data": { "command": "add", "addresses": "a08d349e8037d6e5,e6765c6113547fb7" }}' diff --git a/cmd/access/node_builder/access_node_builder.go b/cmd/access/node_builder/access_node_builder.go index 299334f96b2..4f7270b1d32 100644 --- a/cmd/access/node_builder/access_node_builder.go +++ b/cmd/access/node_builder/access_node_builder.go @@ -1505,12 +1505,12 @@ func (builder *FlowAccessNodeBuilder) Build() (cmd.Node, error) { cacheSize := int(backendConfig.ConnectionPoolSize) var connBackendCache *rpcConnection.Cache + var err error if cacheSize > 0 { - backendCache, err := backend.NewCache(node.Logger, accessMetrics, cacheSize) + connBackendCache, err = rpcConnection.NewCache(node.Logger, accessMetrics, cacheSize) if err != nil { - return nil, fmt.Errorf("could not initialize backend cache: %w", err) + return nil, fmt.Errorf("could not initialize connection cache: %w", err) } - connBackendCache = rpcConnection.NewCache(backendCache, cacheSize) } connFactory := &rpcConnection.ConnectionFactoryImpl{ @@ -1521,9 +1521,9 @@ func (builder *FlowAccessNodeBuilder) Build() (cmd.Node, error) { AccessMetrics: accessMetrics, Log: node.Logger, Manager: rpcConnection.NewManager( - connBackendCache, node.Logger, accessMetrics, + connBackendCache, config.MaxMsgSize, backendConfig.CircuitBreakerConfig, config.CompressorName, diff --git a/cmd/bootstrap/run/execution_state.go b/cmd/bootstrap/run/execution_state.go index 38bd1d8de10..c1896668c38 100644 --- a/cmd/bootstrap/run/execution_state.go +++ b/cmd/bootstrap/run/execution_state.go @@ -43,7 +43,7 @@ func GenerateExecutionState( return flow.DummyStateCommitment, err } - compactor, err := complete.NewCompactor(ledgerStorage, diskWal, zerolog.Nop(), capacity, checkpointDistance, checkpointsToKeep, atomic.NewBool(false)) + compactor, err := complete.NewCompactor(ledgerStorage, diskWal, zerolog.Nop(), capacity, checkpointDistance, checkpointsToKeep, atomic.NewBool(false), metricsCollector) if err != nil { return flow.DummyStateCommitment, err } diff --git a/cmd/execution_builder.go b/cmd/execution_builder.go index f80295a6af2..285c61c2d29 100644 --- a/cmd/execution_builder.go +++ b/cmd/execution_builder.go @@ -881,6 +881,7 @@ func (exeNode *ExecutionNode) LoadExecutionStateLedgerWALCompactor( exeNode.exeConf.checkpointDistance, exeNode.exeConf.checkpointsToKeep, exeNode.toTriggerCheckpoint, // compactor will listen to the signal from admin tool for force triggering checkpointing + exeNode.collector, ) } diff --git a/cmd/observer/node_builder/observer_builder.go b/cmd/observer/node_builder/observer_builder.go index 91d2b50e4f1..dcda8127563 100644 --- a/cmd/observer/node_builder/observer_builder.go +++ b/cmd/observer/node_builder/observer_builder.go @@ -1210,12 +1210,12 @@ func (builder *ObserverServiceBuilder) enqueueRPCServer() { cacheSize := int(backendConfig.ConnectionPoolSize) var connBackendCache *rpcConnection.Cache + var err error if cacheSize > 0 { - backendCache, err := backend.NewCache(node.Logger, accessMetrics, cacheSize) + connBackendCache, err = rpcConnection.NewCache(node.Logger, accessMetrics, cacheSize) if err != nil { - return nil, fmt.Errorf("could not initialize backend cache: %w", err) + return nil, fmt.Errorf("could not initialize connection cache: %w", err) } - connBackendCache = rpcConnection.NewCache(backendCache, cacheSize) } connFactory := &rpcConnection.ConnectionFactoryImpl{ @@ -1226,9 +1226,9 @@ func (builder *ObserverServiceBuilder) enqueueRPCServer() { AccessMetrics: accessMetrics, Log: node.Logger, Manager: rpcConnection.NewManager( - connBackendCache, node.Logger, accessMetrics, + connBackendCache, config.MaxMsgSize, backendConfig.CircuitBreakerConfig, config.CompressorName, diff --git a/cmd/util/cmd/checkpoint-collect-stats/cmd.go b/cmd/util/cmd/checkpoint-collect-stats/cmd.go index cf74b467758..29c7bd1c5ef 100644 --- a/cmd/util/cmd/checkpoint-collect-stats/cmd.go +++ b/cmd/util/cmd/checkpoint-collect-stats/cmd.go @@ -93,7 +93,7 @@ func run(*cobra.Command, []string) { if err != nil { log.Fatal().Err(err).Msg("cannot create ledger from write-a-head logs and checkpoints") } - compactor, err := complete.NewCompactor(led, diskWal, zerolog.Nop(), complete.DefaultCacheSize, math.MaxInt, 1, atomic.NewBool(false)) + compactor, err := complete.NewCompactor(led, diskWal, zerolog.Nop(), complete.DefaultCacheSize, math.MaxInt, 1, atomic.NewBool(false), &metrics.NoopCollector{}) if err != nil { log.Fatal().Err(err).Msg("cannot create compactor") } diff --git a/cmd/util/cmd/exec-data-json-export/ledger_exporter.go b/cmd/util/cmd/exec-data-json-export/ledger_exporter.go index ee8573d8963..a9d75734d9b 100644 --- a/cmd/util/cmd/exec-data-json-export/ledger_exporter.go +++ b/cmd/util/cmd/exec-data-json-export/ledger_exporter.go @@ -35,7 +35,7 @@ func ExportLedger(ledgerPath string, targetstate string, outputPath string) erro return fmt.Errorf("cannot create ledger from write-a-head logs and checkpoints: %w", err) } - compactor, err := complete.NewCompactor(led, diskWal, zerolog.Nop(), complete.DefaultCacheSize, checkpointDistance, checkpointsToKeep, atomic.NewBool(false)) + compactor, err := complete.NewCompactor(led, diskWal, zerolog.Nop(), complete.DefaultCacheSize, checkpointDistance, checkpointsToKeep, atomic.NewBool(false), &metrics.NoopCollector{}) if err != nil { return fmt.Errorf("cannot create compactor: %w", err) } diff --git a/cmd/util/cmd/execution-state-extract/execution_state_extract.go b/cmd/util/cmd/execution-state-extract/execution_state_extract.go index 90bcd70533d..b2146878898 100644 --- a/cmd/util/cmd/execution-state-extract/execution_state_extract.go +++ b/cmd/util/cmd/execution-state-extract/execution_state_extract.go @@ -70,7 +70,7 @@ func extractExecutionState( log.Info().Msg("init compactor") - compactor, err := complete.NewCompactor(led, diskWal, log, complete.DefaultCacheSize, checkpointDistance, checkpointsToKeep, atomic.NewBool(false)) + compactor, err := complete.NewCompactor(led, diskWal, log, complete.DefaultCacheSize, checkpointDistance, checkpointsToKeep, atomic.NewBool(false), &metrics.NoopCollector{}) if err != nil { return fmt.Errorf("cannot create compactor: %w", err) } diff --git a/cmd/util/cmd/execution-state-extract/execution_state_extract_test.go b/cmd/util/cmd/execution-state-extract/execution_state_extract_test.go index 2f91ea7d603..70f8ca6bc89 100644 --- a/cmd/util/cmd/execution-state-extract/execution_state_extract_test.go +++ b/cmd/util/cmd/execution-state-extract/execution_state_extract_test.go @@ -90,7 +90,7 @@ func TestExtractExecutionState(t *testing.T) { require.NoError(t, err) f, err := complete.NewLedger(diskWal, size*10, metr, zerolog.Nop(), complete.DefaultPathFinderVersion) require.NoError(t, err) - compactor, err := complete.NewCompactor(f, diskWal, zerolog.Nop(), uint(size), checkpointDistance, checkpointsToKeep, atomic.NewBool(false)) + compactor, err := complete.NewCompactor(f, diskWal, zerolog.Nop(), uint(size), checkpointDistance, checkpointsToKeep, atomic.NewBool(false), &metrics.NoopCollector{}) require.NoError(t, err) <-compactor.Ready() @@ -166,7 +166,7 @@ func TestExtractExecutionState(t *testing.T) { checkpointDistance = math.MaxInt // A large number to prevent checkpoint creation. checkpointsToKeep = 1 ) - compactor, err := complete.NewCompactor(storage, diskWal, zerolog.Nop(), uint(size), checkpointDistance, checkpointsToKeep, atomic.NewBool(false)) + compactor, err := complete.NewCompactor(storage, diskWal, zerolog.Nop(), uint(size), checkpointDistance, checkpointsToKeep, atomic.NewBool(false), &metrics.NoopCollector{}) require.NoError(t, err) <-compactor.Ready() diff --git a/consensus/hotstuff/votecollector/combined_vote_processor_v2_test.go b/consensus/hotstuff/votecollector/combined_vote_processor_v2_test.go index deb5cce438a..fe0f48c0f29 100644 --- a/consensus/hotstuff/votecollector/combined_vote_processor_v2_test.go +++ b/consensus/hotstuff/votecollector/combined_vote_processor_v2_test.go @@ -118,7 +118,7 @@ func (s *CombinedVoteProcessorV2TestSuite) TestProcess_InvalidSignatureFormat() rapid.Check(s.T(), func(t *rapid.T) { // create a signature with invalid length vote := unittest.VoteForBlockFixture(s.proposal.Block, func(vote *model.Vote) { - vote.SigData = unittest.RandomBytes(generator.Draw(t, "sig-size").(int)) + vote.SigData = unittest.RandomBytes(generator.Draw(t, "sig-size")) }) err := s.processor.Process(vote) require.Error(s.T(), err) @@ -434,8 +434,8 @@ func TestCombinedVoteProcessorV2_PropertyCreatingQCCorrectness(testifyT *testing rapid.Check(testifyT, func(t *rapid.T) { // draw participants in range 1 <= participants <= maxParticipants - participants := rapid.Uint64Range(1, maxParticipants).Draw(t, "participants").(uint64) - beaconSignersCount := rapid.Uint64Range(participants/2+1, participants).Draw(t, "beaconSigners").(uint64) + participants := rapid.Uint64Range(1, maxParticipants).Draw(t, "participants") + beaconSignersCount := rapid.Uint64Range(participants/2+1, participants).Draw(t, "beaconSigners") stakingSignersCount := participants - beaconSignersCount require.Equal(t, participants, stakingSignersCount+beaconSignersCount) @@ -638,20 +638,20 @@ func TestCombinedVoteProcessorV2_PropertyCreatingQCCorrectness(testifyT *testing func TestCombinedVoteProcessorV2_PropertyCreatingQCLiveness(testifyT *testing.T) { rapid.Check(testifyT, func(t *rapid.T) { // draw beacon signers in range 1 <= beaconSignersCount <= 53 - beaconSignersCount := rapid.Uint64Range(1, 53).Draw(t, "beaconSigners").(uint64) + beaconSignersCount := rapid.Uint64Range(1, 53).Draw(t, "beaconSigners") // draw staking signers in range 0 <= stakingSignersCount <= 10 - stakingSignersCount := rapid.Uint64Range(0, 10).Draw(t, "stakingSigners").(uint64) + stakingSignersCount := rapid.Uint64Range(0, 10).Draw(t, "stakingSigners") stakingWeightRange, beaconWeightRange := rapid.Uint64Range(1, 10), rapid.Uint64Range(1, 10) minRequiredWeight := uint64(0) // draw weight for each signer randomly stakingSigners := unittest.IdentityListFixture(int(stakingSignersCount), func(identity *flow.Identity) { - identity.InitialWeight = stakingWeightRange.Draw(t, identity.String()).(uint64) + identity.InitialWeight = stakingWeightRange.Draw(t, identity.String()) minRequiredWeight += identity.InitialWeight }) beaconSigners := unittest.IdentityListFixture(int(beaconSignersCount), func(identity *flow.Identity) { - identity.InitialWeight = beaconWeightRange.Draw(t, identity.String()).(uint64) + identity.InitialWeight = beaconWeightRange.Draw(t, identity.String()) minRequiredWeight += identity.InitialWeight }) diff --git a/consensus/hotstuff/votecollector/combined_vote_processor_v3_test.go b/consensus/hotstuff/votecollector/combined_vote_processor_v3_test.go index 50d435f5a50..1f632428c05 100644 --- a/consensus/hotstuff/votecollector/combined_vote_processor_v3_test.go +++ b/consensus/hotstuff/votecollector/combined_vote_processor_v3_test.go @@ -434,8 +434,8 @@ func TestCombinedVoteProcessorV3_PropertyCreatingQCCorrectness(testifyT *testing rapid.Check(testifyT, func(t *rapid.T) { // draw participants in range 1 <= participants <= maxParticipants - participants := rapid.Uint64Range(1, maxParticipants).Draw(t, "participants").(uint64) - beaconSignersCount := rapid.Uint64Range(participants/2+1, participants).Draw(t, "beaconSigners").(uint64) + participants := rapid.Uint64Range(1, maxParticipants).Draw(t, "participants") + beaconSignersCount := rapid.Uint64Range(participants/2+1, participants).Draw(t, "beaconSigners") stakingSignersCount := participants - beaconSignersCount require.Equal(t, participants, stakingSignersCount+beaconSignersCount) @@ -749,20 +749,20 @@ func TestCombinedVoteProcessorV3_OnlyRandomBeaconSigners(testifyT *testing.T) { func TestCombinedVoteProcessorV3_PropertyCreatingQCLiveness(testifyT *testing.T) { rapid.Check(testifyT, func(t *rapid.T) { // draw beacon signers in range 1 <= beaconSignersCount <= 53 - beaconSignersCount := rapid.Uint64Range(1, 53).Draw(t, "beaconSigners").(uint64) + beaconSignersCount := rapid.Uint64Range(1, 53).Draw(t, "beaconSigners") // draw staking signers in range 0 <= stakingSignersCount <= 10 - stakingSignersCount := rapid.Uint64Range(0, 10).Draw(t, "stakingSigners").(uint64) + stakingSignersCount := rapid.Uint64Range(0, 10).Draw(t, "stakingSigners") stakingWeightRange, beaconWeightRange := rapid.Uint64Range(1, 10), rapid.Uint64Range(1, 10) minRequiredWeight := uint64(0) // draw weight for each signer randomly stakingSigners := unittest.IdentityListFixture(int(stakingSignersCount), func(identity *flow.Identity) { - identity.InitialWeight = stakingWeightRange.Draw(t, identity.String()).(uint64) + identity.InitialWeight = stakingWeightRange.Draw(t, identity.String()) minRequiredWeight += identity.InitialWeight }) beaconSigners := unittest.IdentityListFixture(int(beaconSignersCount), func(identity *flow.Identity) { - identity.InitialWeight = beaconWeightRange.Draw(t, identity.String()).(uint64) + identity.InitialWeight = beaconWeightRange.Draw(t, identity.String()) minRequiredWeight += identity.InitialWeight }) diff --git a/engine/access/apiproxy/access_api_proxy_test.go b/engine/access/apiproxy/access_api_proxy_test.go index 27f96413c52..a4c27896f08 100644 --- a/engine/access/apiproxy/access_api_proxy_test.go +++ b/engine/access/apiproxy/access_api_proxy_test.go @@ -152,9 +152,9 @@ func TestNewFlowCachedAccessAPIProxy(t *testing.T) { AccessMetrics: metrics, CollectionNodeGRPCTimeout: time.Second, Manager: connection.NewManager( - nil, unittest.Logger(), metrics, + nil, grpcutils.DefaultMaxMsgSize, connection.CircuitBreakerConfig{}, grpcutils.NoCompressor, diff --git a/engine/access/rpc/backend/backend.go b/engine/access/rpc/backend/backend.go index 00707e01633..76519974de5 100644 --- a/engine/access/rpc/backend/backend.go +++ b/engine/access/rpc/backend/backend.go @@ -236,27 +236,6 @@ func New(params Params) (*Backend, error) { return b, nil } -// NewCache constructs cache for storing connections to other nodes. -// No errors are expected during normal operations. -func NewCache( - log zerolog.Logger, - metrics module.AccessMetrics, - connectionPoolSize int, -) (*lru.Cache[string, *connection.CachedClient], error) { - cache, err := lru.NewWithEvict(connectionPoolSize, func(_ string, client *connection.CachedClient) { - go client.Close() // close is blocking, so run in a goroutine - - log.Debug().Str("grpc_conn_evicted", client.Address).Msg("closing grpc connection evicted from pool") - metrics.ConnectionFromPoolEvicted() - }) - - if err != nil { - return nil, fmt.Errorf("could not initialize connection pool cache: %w", err) - } - - return cache, nil -} - func identifierList(ids []string) (flow.IdentifierList, error) { idList := make(flow.IdentifierList, len(ids)) for i, idStr := range ids { diff --git a/engine/access/rpc/connection/cache.go b/engine/access/rpc/connection/cache.go index 1b12deb6f17..ba0231fe452 100644 --- a/engine/access/rpc/connection/cache.go +++ b/engine/access/rpc/connection/cache.go @@ -1,22 +1,62 @@ package connection import ( + "fmt" "sync" "time" lru "github.com/hashicorp/golang-lru/v2" + "github.com/onflow/crypto" + "github.com/rs/zerolog" "go.uber.org/atomic" "google.golang.org/grpc" + "google.golang.org/grpc/connectivity" + + "github.com/onflow/flow-go/module" ) // CachedClient represents a gRPC client connection that is cached for reuse. type CachedClient struct { - ClientConn *grpc.ClientConn - Address string - timeout time.Duration + conn *grpc.ClientConn + address string + timeout time.Duration + + cache *Cache closeRequested *atomic.Bool wg sync.WaitGroup - mu sync.Mutex + mu sync.RWMutex +} + +// ClientConn returns the underlying gRPC client connection. +func (cc *CachedClient) ClientConn() *grpc.ClientConn { + cc.mu.RLock() + defer cc.mu.RUnlock() + return cc.conn +} + +// Address returns the address of the remote server. +func (cc *CachedClient) Address() string { + return cc.address +} + +// CloseRequested returns true if the CachedClient has been marked for closure. +func (cc *CachedClient) CloseRequested() bool { + return cc.closeRequested.Load() +} + +// AddRequest increments the in-flight request counter for the CachedClient. +// It returns a function that should be called when the request completes to decrement the counter +func (cc *CachedClient) AddRequest() func() { + cc.wg.Add(1) + return cc.wg.Done +} + +// Invalidate removes the CachedClient from the cache and closes the connection. +func (cc *CachedClient) Invalidate() { + cc.cache.invalidate(cc.address) + + // Close the connection asynchronously to avoid blocking requests + go cc.Close() } // Close closes the CachedClient connection. It marks the connection for closure and waits asynchronously for ongoing @@ -28,16 +68,17 @@ func (cc *CachedClient) Close() { } // Obtain the lock to ensure that any connection attempts have completed - cc.mu.Lock() - conn := cc.ClientConn - cc.mu.Unlock() + cc.mu.RLock() + conn := cc.conn + cc.mu.RUnlock() - // If the initial connection attempt failed, ClientConn will be nil + // If the initial connection attempt failed, conn will be nil if conn == nil { return } // If there are ongoing requests, wait for them to complete asynchronously + // this avoids tearing down the connection while requests are in-flight resulting in errors cc.wg.Wait() // Close the connection @@ -46,59 +87,95 @@ func (cc *CachedClient) Close() { // Cache represents a cache of CachedClient instances with a given maximum size. type Cache struct { - cache *lru.Cache[string, *CachedClient] - size int + cache *lru.Cache[string, *CachedClient] + maxSize int + + logger zerolog.Logger + metrics module.GRPCConnectionPoolMetrics } // NewCache creates a new Cache with the specified maximum size and the underlying LRU cache. -func NewCache(cache *lru.Cache[string, *CachedClient], size int) *Cache { - return &Cache{ - cache: cache, - size: size, +func NewCache( + log zerolog.Logger, + metrics module.GRPCConnectionPoolMetrics, + maxSize int, +) (*Cache, error) { + cache, err := lru.NewWithEvict(maxSize, func(_ string, client *CachedClient) { + go client.Close() // close is blocking, so run in a goroutine + + log.Debug().Str("grpc_conn_evicted", client.address).Msg("closing grpc connection evicted from pool") + metrics.ConnectionFromPoolEvicted() + }) + + if err != nil { + return nil, fmt.Errorf("could not initialize connection pool cache: %w", err) } -} -// Get retrieves the CachedClient for the given address from the cache. -// It returns the CachedClient and a boolean indicating whether the entry exists in the cache. -func (c *Cache) Get(address string) (*CachedClient, bool) { - val, ok := c.cache.Get(address) - if !ok { - return nil, false - } - return val, true + return &Cache{ + cache: cache, + maxSize: maxSize, + logger: log, + metrics: metrics, + }, nil } -// GetOrAdd atomically gets the CachedClient for the given address from the cache, or adds a new one -// if none existed. -// New entries are added to the cache with their mutex locked. This ensures that the caller gets -// priority when working with the new client, allowing it to create the underlying connection. -// Clients retrieved from the cache are returned without modifying their lock. -func (c *Cache) GetOrAdd(address string, timeout time.Duration) (*CachedClient, bool) { - client := &CachedClient{} - client.mu.Lock() +// GetConnected returns a CachedClient for the given address that has an active connection. +// If the address is not in the cache, it creates a new entry and connects. +func (c *Cache) GetConnected( + address string, + timeout time.Duration, + networkPubKey crypto.PublicKey, + connectFn func(string, time.Duration, crypto.PublicKey, *CachedClient) (*grpc.ClientConn, error), +) (*CachedClient, error) { + client := &CachedClient{ + address: address, + timeout: timeout, + closeRequested: atomic.NewBool(false), + cache: c, + } + // Note: PeekOrAdd does not "visit" the existing entry, so we need to call Get explicitly + // to mark the entry as "visited" and update the LRU order. Unfortunately, the lru library + // doesn't have a GetOrAdd method, so this is the simplest way to achieve atomic get-or-add val, existed, _ := c.cache.PeekOrAdd(address, client) if existed { - return val, true + client = val + _, _ = c.cache.Get(address) + c.metrics.ConnectionFromPoolReused() + } else { + c.metrics.ConnectionAddedToPool() } - client.Address = address - client.timeout = timeout - client.closeRequested = atomic.NewBool(false) + client.mu.Lock() + defer client.mu.Unlock() - return client, false -} + // after getting the lock, check if the connection is still active + if client.conn != nil && client.conn.GetState() != connectivity.Shutdown { + return client, nil + } -// Add adds a CachedClient to the cache with the given address. -// It returns a boolean indicating whether an existing entry was evicted. -func (c *Cache) Add(address string, client *CachedClient) (evicted bool) { - return c.cache.Add(address, client) + // if the connection is not setup yet or closed, create a new connection and cache it + conn, err := connectFn(client.address, client.timeout, networkPubKey, client) + if err != nil { + return nil, err + } + + c.metrics.NewConnectionEstablished() + c.metrics.TotalConnectionsInPool(uint(c.Len()), uint(c.MaxSize())) + + client.conn = conn + return client, nil } -// Remove removes the CachedClient entry from the cache with the given address. -// It returns a boolean indicating whether the entry was present and removed. -func (c *Cache) Remove(address string) (present bool) { - return c.cache.Remove(address) +// invalidate removes the CachedClient entry from the cache with the given address, and shuts +// down the connection. +func (c *Cache) invalidate(address string) { + if !c.cache.Remove(address) { + return + } + + c.logger.Debug().Str("cached_client_invalidated", address).Msg("invalidating cached client") + c.metrics.ConnectionFromPoolInvalidated() } // Len returns the number of CachedClient entries in the cache. @@ -108,11 +185,5 @@ func (c *Cache) Len() int { // MaxSize returns the maximum size of the cache. func (c *Cache) MaxSize() int { - return c.size -} - -// Contains checks if the cache contains an entry with the given address. -// It returns a boolean indicating whether the address is present in the cache. -func (c *Cache) Contains(address string) (containKey bool) { - return c.cache.Contains(address) + return c.maxSize } diff --git a/engine/access/rpc/connection/cache_test.go b/engine/access/rpc/connection/cache_test.go new file mode 100644 index 00000000000..5dd07c3fe7f --- /dev/null +++ b/engine/access/rpc/connection/cache_test.go @@ -0,0 +1,212 @@ +package connection + +import ( + "net" + "sync" + "testing" + "time" + + "github.com/onflow/crypto" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/atomic" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + + "github.com/onflow/flow-go/module/metrics" + "github.com/onflow/flow-go/utils/unittest" +) + +func TestCachedClientShutdown(t *testing.T) { + // Test that a completely uninitialized client can be closed without panics + t.Run("uninitialized client", func(t *testing.T) { + client := &CachedClient{ + closeRequested: atomic.NewBool(false), + } + client.Close() + assert.True(t, client.closeRequested.Load()) + }) + + // Test closing a client with no outstanding requests + // Close() should return quickly + t.Run("with no outstanding requests", func(t *testing.T) { + client := &CachedClient{ + closeRequested: atomic.NewBool(false), + conn: setupGRPCServer(t), + } + + unittest.RequireReturnsBefore(t, func() { + client.Close() + }, 100*time.Millisecond, "client timed out closing connection") + + assert.True(t, client.closeRequested.Load()) + }) + + // Test closing a client with outstanding requests waits for requests to complete + // Close() should block until the request completes + t.Run("with some outstanding requests", func(t *testing.T) { + client := &CachedClient{ + closeRequested: atomic.NewBool(false), + conn: setupGRPCServer(t), + } + done := client.AddRequest() + + doneCalled := atomic.NewBool(false) + go func() { + defer done() + time.Sleep(50 * time.Millisecond) + doneCalled.Store(true) + }() + + unittest.RequireReturnsBefore(t, func() { + client.Close() + }, 100*time.Millisecond, "client timed out closing connection") + + assert.True(t, client.closeRequested.Load()) + assert.True(t, doneCalled.Load()) + }) + + // Test closing a client that is already closing does not block + // Close() should return immediately + t.Run("already closing", func(t *testing.T) { + client := &CachedClient{ + closeRequested: atomic.NewBool(true), // close already requested + conn: setupGRPCServer(t), + } + done := client.AddRequest() + + doneCalled := atomic.NewBool(false) + go func() { + defer done() + + // use a long delay and require Close() to complete faster + time.Sleep(5 * time.Second) + doneCalled.Store(true) + }() + + // should return immediately + unittest.RequireReturnsBefore(t, func() { + client.Close() + }, 10*time.Millisecond, "client timed out closing connection") + + assert.True(t, client.closeRequested.Load()) + assert.False(t, doneCalled.Load()) + }) + + // Test closing a client that is locked during connection setup + // Close() should wait for the lock before shutting down + t.Run("connection setting up", func(t *testing.T) { + client := &CachedClient{ + closeRequested: atomic.NewBool(false), + } + + // simulate an in-progress connection setup + client.mu.Lock() + + go func() { + // unlock after setting up the connection + defer client.mu.Unlock() + + // pause before setting the connection to cause client.Close() to block + time.Sleep(100 * time.Millisecond) + client.conn = setupGRPCServer(t) + }() + + // should wait at least 100 milliseconds before returning + unittest.RequireReturnsBefore(t, func() { + client.Close() + }, 500*time.Millisecond, "client timed out closing connection") + + assert.True(t, client.closeRequested.Load()) + assert.NotNil(t, client.conn) + }) +} + +// Test that rapid connections and disconnects do not cause a panic. +func TestConcurrentConnectionsAndDisconnects(t *testing.T) { + logger := unittest.Logger() + metrics := metrics.NewNoopCollector() + + cache, err := NewCache(logger, metrics, 1) + require.NoError(t, err) + + connectionCount := 100_000 + conn := setupGRPCServer(t) + + t.Run("test concurrent connections", func(t *testing.T) { + wg := sync.WaitGroup{} + wg.Add(connectionCount) + callCount := atomic.NewInt32(0) + for i := 0; i < connectionCount; i++ { + go func() { + defer wg.Done() + cachedConn, err := cache.GetConnected("foo", DefaultClientTimeout, nil, func(string, time.Duration, crypto.PublicKey, *CachedClient) (*grpc.ClientConn, error) { + callCount.Inc() + return conn, nil + }) + require.NoError(t, err) + + done := cachedConn.AddRequest() + time.Sleep(1 * time.Millisecond) + done() + }() + } + unittest.RequireReturnsBefore(t, wg.Wait, time.Second, "timed out waiting for connections to finish") + + // the client should be cached, so only a single connection is created + assert.Equal(t, int32(1), callCount.Load()) + }) + + t.Run("test rapid connections and invalidations", func(t *testing.T) { + wg := sync.WaitGroup{} + wg.Add(connectionCount) + callCount := atomic.NewInt32(0) + for i := 0; i < connectionCount; i++ { + go func() { + defer wg.Done() + cachedConn, err := cache.GetConnected("foo", DefaultClientTimeout, nil, func(string, time.Duration, crypto.PublicKey, *CachedClient) (*grpc.ClientConn, error) { + callCount.Inc() + return conn, nil + }) + require.NoError(t, err) + + done := cachedConn.AddRequest() + time.Sleep(1 * time.Millisecond) + cachedConn.Invalidate() + done() + }() + } + wg.Wait() + + // since all connections are invalidated, the cache should be empty at the end + require.Eventually(t, func() bool { + return cache.Len() == 0 + }, time.Second, 20*time.Millisecond, "cache should be empty") + + // Many connections should be created, but some will be shared + assert.Greater(t, callCount.Load(), int32(1)) + assert.LessOrEqual(t, callCount.Load(), int32(connectionCount)) + }) +} + +// setupGRPCServer starts a dummy grpc server for connection tests +func setupGRPCServer(t *testing.T) *grpc.ClientConn { + l, err := net.Listen("tcp", net.JoinHostPort("localhost", "0")) + require.NoError(t, err) + + server := grpc.NewServer() + + t.Cleanup(func() { + server.Stop() + }) + + go func() { + err = server.Serve(l) + require.NoError(t, err) + }() + + conn, err := grpc.Dial(l.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials())) + require.NoError(t, err) + + return conn +} diff --git a/engine/access/rpc/connection/connection.go b/engine/access/rpc/connection/connection.go index 161aa2949d2..c9533f945bc 100644 --- a/engine/access/rpc/connection/connection.go +++ b/engine/access/rpc/connection/connection.go @@ -75,7 +75,7 @@ func (cf *ConnectionFactoryImpl) GetAccessAPIClient(address string, networkPubKe // The networkPubKey is the public key used for secure gRPC connection. Can be nil for an unsecured connection. // The returned io.Closer should close the connection after the call if no error occurred during client creation. func (cf *ConnectionFactoryImpl) GetAccessAPIClientWithPort(address string, networkPubKey crypto.PublicKey) (access.AccessAPIClient, io.Closer, error) { - conn, closer, err := cf.Manager.GetConnection(address, cf.CollectionNodeGRPCTimeout, AccessClient, networkPubKey) + conn, closer, err := cf.Manager.GetConnection(address, cf.CollectionNodeGRPCTimeout, networkPubKey) if err != nil { return nil, nil, err } @@ -91,7 +91,7 @@ func (cf *ConnectionFactoryImpl) GetExecutionAPIClient(address string) (executio return nil, nil, err } - conn, closer, err := cf.Manager.GetConnection(grpcAddress, cf.ExecutionNodeGRPCTimeout, ExecutionClient, nil) + conn, closer, err := cf.Manager.GetConnection(grpcAddress, cf.ExecutionNodeGRPCTimeout, nil) if err != nil { return nil, nil, err } diff --git a/engine/access/rpc/connection/connection_test.go b/engine/access/rpc/connection/connection_test.go index 4f024105a95..4ef7d9a978b 100644 --- a/engine/access/rpc/connection/connection_test.go +++ b/engine/access/rpc/connection/connection_test.go @@ -2,7 +2,9 @@ package connection import ( "context" + "crypto/rand" "fmt" + "math/big" "net" "sync" "testing" @@ -19,7 +21,6 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/connectivity" - "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/status" "pgregory.net/rapid" @@ -29,6 +30,9 @@ import ( ) func TestProxyAccessAPI(t *testing.T) { + logger := unittest.Logger() + metrics := metrics.NewNoopCollector() + // create a collection node cn := new(collectionNode) cn.start(t) @@ -43,11 +47,11 @@ func TestProxyAccessAPI(t *testing.T) { // set the collection grpc port connectionFactory.CollectionGRPCPort = cn.port // set metrics reporting - connectionFactory.AccessMetrics = metrics.NewNoopCollector() + connectionFactory.AccessMetrics = metrics connectionFactory.Manager = NewManager( - nil, - unittest.Logger(), + logger, connectionFactory.AccessMetrics, + nil, 0, CircuitBreakerConfig{}, grpcutils.NoCompressor, @@ -70,15 +74,10 @@ func TestProxyAccessAPI(t *testing.T) { assert.Equal(t, resp, expected) } -func getCache(t *testing.T, cacheSize int) *lru.Cache[string, *CachedClient] { - cache, err := lru.NewWithEvict[string, *CachedClient](cacheSize, func(_ string, client *CachedClient) { - client.Close() - }) - require.NoError(t, err) - return cache -} - func TestProxyExecutionAPI(t *testing.T) { + logger := unittest.Logger() + metrics := metrics.NewNoopCollector() + // create an execution node en := new(executionNode) en.start(t) @@ -94,11 +93,11 @@ func TestProxyExecutionAPI(t *testing.T) { connectionFactory.ExecutionGRPCPort = en.port // set metrics reporting - connectionFactory.AccessMetrics = metrics.NewNoopCollector() + connectionFactory.AccessMetrics = metrics connectionFactory.Manager = NewManager( - nil, - unittest.Logger(), + logger, connectionFactory.AccessMetrics, + nil, 0, CircuitBreakerConfig{}, grpcutils.NoCompressor, @@ -121,6 +120,9 @@ func TestProxyExecutionAPI(t *testing.T) { } func TestProxyAccessAPIConnectionReuse(t *testing.T) { + logger := unittest.Logger() + metrics := metrics.NewNoopCollector() + // create a collection node cn := new(collectionNode) cn.start(t) @@ -134,16 +136,18 @@ func TestProxyAccessAPIConnectionReuse(t *testing.T) { connectionFactory := new(ConnectionFactoryImpl) // set the collection grpc port connectionFactory.CollectionGRPCPort = cn.port + // set the connection pool cache size cacheSize := 1 - connectionCache := NewCache(getCache(t, cacheSize), cacheSize) + connectionCache, err := NewCache(logger, metrics, cacheSize) + require.NoError(t, err) // set metrics reporting - connectionFactory.AccessMetrics = metrics.NewNoopCollector() + connectionFactory.AccessMetrics = metrics connectionFactory.Manager = NewManager( - connectionCache, - unittest.Logger(), + logger, connectionFactory.AccessMetrics, + connectionCache, 0, CircuitBreakerConfig{}, grpcutils.NoCompressor, @@ -161,9 +165,9 @@ func TestProxyAccessAPIConnectionReuse(t *testing.T) { assert.Nil(t, closer.Close()) var conn *grpc.ClientConn - res, ok := connectionCache.Get(proxyConnectionFactory.targetAddress) + res, ok := connectionCache.cache.Get(proxyConnectionFactory.targetAddress) assert.True(t, ok) - conn = res.ClientConn + conn = res.ClientConn() // check if api client can be rebuilt with retrieved connection accessAPIClient := access.NewAccessAPIClient(conn) @@ -174,6 +178,9 @@ func TestProxyAccessAPIConnectionReuse(t *testing.T) { } func TestProxyExecutionAPIConnectionReuse(t *testing.T) { + logger := unittest.Logger() + metrics := metrics.NewNoopCollector() + // create an execution node en := new(executionNode) en.start(t) @@ -187,15 +194,18 @@ func TestProxyExecutionAPIConnectionReuse(t *testing.T) { connectionFactory := new(ConnectionFactoryImpl) // set the execution grpc port connectionFactory.ExecutionGRPCPort = en.port + // set the connection pool cache size cacheSize := 5 - connectionCache := NewCache(getCache(t, cacheSize), cacheSize) + connectionCache, err := NewCache(logger, metrics, cacheSize) + require.NoError(t, err) + // set metrics reporting - connectionFactory.AccessMetrics = metrics.NewNoopCollector() + connectionFactory.AccessMetrics = metrics connectionFactory.Manager = NewManager( - connectionCache, - unittest.Logger(), + logger, connectionFactory.AccessMetrics, + connectionCache, 0, CircuitBreakerConfig{}, grpcutils.NoCompressor, @@ -213,9 +223,9 @@ func TestProxyExecutionAPIConnectionReuse(t *testing.T) { assert.Nil(t, closer.Close()) var conn *grpc.ClientConn - res, ok := connectionCache.Get(proxyConnectionFactory.targetAddress) + res, ok := connectionCache.cache.Get(proxyConnectionFactory.targetAddress) assert.True(t, ok) - conn = res.ClientConn + conn = res.ClientConn() // check if api client can be rebuilt with retrieved connection executionAPIClient := execution.NewExecutionAPIClient(conn) @@ -227,6 +237,8 @@ func TestProxyExecutionAPIConnectionReuse(t *testing.T) { // TestExecutionNodeClientTimeout tests that the execution API client times out after the timeout duration func TestExecutionNodeClientTimeout(t *testing.T) { + logger := unittest.Logger() + metrics := metrics.NewNoopCollector() timeout := 10 * time.Millisecond @@ -246,15 +258,18 @@ func TestExecutionNodeClientTimeout(t *testing.T) { connectionFactory.ExecutionGRPCPort = en.port // set the execution grpc client timeout connectionFactory.ExecutionNodeGRPCTimeout = timeout + // set the connection pool cache size cacheSize := 5 - connectionCache := NewCache(getCache(t, cacheSize), cacheSize) + connectionCache, err := NewCache(logger, metrics, cacheSize) + require.NoError(t, err) + // set metrics reporting - connectionFactory.AccessMetrics = metrics.NewNoopCollector() + connectionFactory.AccessMetrics = metrics connectionFactory.Manager = NewManager( - connectionCache, - unittest.Logger(), + logger, connectionFactory.AccessMetrics, + connectionCache, 0, CircuitBreakerConfig{}, grpcutils.NoCompressor, @@ -274,6 +289,8 @@ func TestExecutionNodeClientTimeout(t *testing.T) { // TestCollectionNodeClientTimeout tests that the collection API client times out after the timeout duration func TestCollectionNodeClientTimeout(t *testing.T) { + logger := unittest.Logger() + metrics := metrics.NewNoopCollector() timeout := 10 * time.Millisecond @@ -293,15 +310,18 @@ func TestCollectionNodeClientTimeout(t *testing.T) { connectionFactory.CollectionGRPCPort = cn.port // set the collection grpc client timeout connectionFactory.CollectionNodeGRPCTimeout = timeout + // set the connection pool cache size cacheSize := 5 - connectionCache := NewCache(getCache(t, cacheSize), cacheSize) + connectionCache, err := NewCache(logger, metrics, cacheSize) + require.NoError(t, err) + // set metrics reporting - connectionFactory.AccessMetrics = metrics.NewNoopCollector() + connectionFactory.AccessMetrics = metrics connectionFactory.Manager = NewManager( - connectionCache, - unittest.Logger(), + logger, connectionFactory.AccessMetrics, + connectionCache, 0, CircuitBreakerConfig{}, grpcutils.NoCompressor, @@ -321,6 +341,9 @@ func TestCollectionNodeClientTimeout(t *testing.T) { // TestConnectionPoolFull tests that the LRU cache replaces connections when full func TestConnectionPoolFull(t *testing.T) { + logger := unittest.Logger() + metrics := metrics.NewNoopCollector() + // create a collection node cn1, cn2, cn3 := new(collectionNode), new(collectionNode), new(collectionNode) cn1.start(t) @@ -340,16 +363,18 @@ func TestConnectionPoolFull(t *testing.T) { connectionFactory := new(ConnectionFactoryImpl) // set the collection grpc port connectionFactory.CollectionGRPCPort = cn1.port + // set the connection pool cache size cacheSize := 2 - connectionCache := NewCache(getCache(t, cacheSize), cacheSize) + connectionCache, err := NewCache(logger, metrics, cacheSize) + require.NoError(t, err) // set metrics reporting - connectionFactory.AccessMetrics = metrics.NewNoopCollector() + connectionFactory.AccessMetrics = metrics connectionFactory.Manager = NewManager( - connectionCache, - unittest.Logger(), + logger, connectionFactory.AccessMetrics, + connectionCache, 0, CircuitBreakerConfig{}, grpcutils.NoCompressor, @@ -361,7 +386,7 @@ func TestConnectionPoolFull(t *testing.T) { // get a collection API client // Create and add first client to cache - _, _, err := connectionFactory.GetAccessAPIClient(cn1Address, nil) + _, _, err = connectionFactory.GetAccessAPIClient(cn1Address, nil) assert.Equal(t, connectionCache.Len(), 1) assert.NoError(t, err) @@ -370,38 +395,40 @@ func TestConnectionPoolFull(t *testing.T) { assert.Equal(t, connectionCache.Len(), 2) assert.NoError(t, err) - // Peek first client from cache. "recently used"-ness will not be updated, so it will be wiped out first. + // Get the first client from cache. _, _, err = connectionFactory.GetAccessAPIClient(cn1Address, nil) assert.Equal(t, connectionCache.Len(), 2) assert.NoError(t, err) - // Create and add third client to cache, firs client will be removed from cache + // Create and add third client to cache, second client will be removed from cache _, _, err = connectionFactory.GetAccessAPIClient(cn3Address, nil) assert.Equal(t, connectionCache.Len(), 2) assert.NoError(t, err) var hostnameOrIP string + hostnameOrIP, _, err = net.SplitHostPort(cn1Address) - assert.NoError(t, err) + require.NoError(t, err) grpcAddress1 := fmt.Sprintf("%s:%d", hostnameOrIP, connectionFactory.CollectionGRPCPort) + hostnameOrIP, _, err = net.SplitHostPort(cn2Address) - assert.NoError(t, err) + require.NoError(t, err) grpcAddress2 := fmt.Sprintf("%s:%d", hostnameOrIP, connectionFactory.CollectionGRPCPort) + hostnameOrIP, _, err = net.SplitHostPort(cn3Address) - assert.NoError(t, err) + require.NoError(t, err) grpcAddress3 := fmt.Sprintf("%s:%d", hostnameOrIP, connectionFactory.CollectionGRPCPort) - contains1 := connectionCache.Contains(grpcAddress1) - contains2 := connectionCache.Contains(grpcAddress2) - contains3 := connectionCache.Contains(grpcAddress3) - - assert.False(t, contains1) - assert.True(t, contains2) - assert.True(t, contains3) + assert.True(t, connectionCache.cache.Contains(grpcAddress1)) + assert.False(t, connectionCache.cache.Contains(grpcAddress2)) + assert.True(t, connectionCache.cache.Contains(grpcAddress3)) } // TestConnectionPoolStale tests that a new connection will be established if the old one cached is stale func TestConnectionPoolStale(t *testing.T) { + logger := unittest.Logger() + metrics := metrics.NewNoopCollector() + // create a collection node cn := new(collectionNode) cn.start(t) @@ -415,16 +442,18 @@ func TestConnectionPoolStale(t *testing.T) { connectionFactory := new(ConnectionFactoryImpl) // set the collection grpc port connectionFactory.CollectionGRPCPort = cn.port + // set the connection pool cache size cacheSize := 5 - connectionCache := NewCache(getCache(t, cacheSize), cacheSize) + connectionCache, err := NewCache(logger, metrics, cacheSize) + require.NoError(t, err) // set metrics reporting - connectionFactory.AccessMetrics = metrics.NewNoopCollector() + connectionFactory.AccessMetrics = metrics connectionFactory.Manager = NewManager( - connectionCache, - unittest.Logger(), + logger, connectionFactory.AccessMetrics, + connectionCache, 0, CircuitBreakerConfig{}, grpcutils.NoCompressor, @@ -440,10 +469,10 @@ func TestConnectionPoolStale(t *testing.T) { assert.Equal(t, connectionCache.Len(), 1) assert.NoError(t, err) // close connection to simulate something "going wrong" with our stored connection - res, _ := connectionCache.Get(proxyConnectionFactory.targetAddress) + cachedClient, _ := connectionCache.cache.Get(proxyConnectionFactory.targetAddress) - connectionCache.Remove(proxyConnectionFactory.targetAddress) - res.Close() + cachedClient.Invalidate() + cachedClient.Close() ctx := context.Background() // make the call to the collection node (should fail, connection closed) @@ -455,9 +484,9 @@ func TestConnectionPoolStale(t *testing.T) { assert.Equal(t, connectionCache.Len(), 1) var conn *grpc.ClientConn - res, ok := connectionCache.Get(proxyConnectionFactory.targetAddress) + res, ok := connectionCache.cache.Get(proxyConnectionFactory.targetAddress) assert.True(t, ok) - conn = res.ClientConn + conn = res.ClientConn() // check if api client can be rebuilt with retrieved connection accessAPIClient := access.NewAccessAPIClient(conn) @@ -475,6 +504,9 @@ func TestConnectionPoolStale(t *testing.T) { // - Wait for all goroutines to finish. // - Verify that the number of completed requests matches the number of sent responses. func TestExecutionNodeClientClosedGracefully(t *testing.T) { + logger := unittest.Logger() + metrics := metrics.NewNoopCollector() + // Add createExecNode function to recreate it each time for rapid test createExecNode := func() (*executionNode, func()) { en := new(executionNode) @@ -503,16 +535,18 @@ func TestExecutionNodeClientClosedGracefully(t *testing.T) { connectionFactory.ExecutionGRPCPort = en.port // set the execution grpc client timeout connectionFactory.ExecutionNodeGRPCTimeout = time.Second + // set the connection pool cache size cacheSize := 1 - connectionCache := NewCache(getCache(t, cacheSize), cacheSize) + connectionCache, err := NewCache(logger, metrics, cacheSize) + require.NoError(t, err) // set metrics reporting - connectionFactory.AccessMetrics = metrics.NewNoopCollector() + connectionFactory.AccessMetrics = metrics connectionFactory.Manager = NewManager( - connectionCache, - unittest.Logger(), + logger, connectionFactory.AccessMetrics, + connectionCache, 0, CircuitBreakerConfig{}, grpcutils.NoCompressor, @@ -526,7 +560,7 @@ func TestExecutionNodeClientClosedGracefully(t *testing.T) { ctx := context.Background() // Generate random number of requests - nofRequests := rapid.IntRange(10, 100).Draw(tt, "nofRequests").(int) + nofRequests := rapid.IntRange(10, 100).Draw(tt, "nofRequests") reqCompleted := atomic.NewUint64(0) var waitGroup sync.WaitGroup @@ -548,7 +582,7 @@ func TestExecutionNodeClientClosedGracefully(t *testing.T) { } // Close connection - connectionFactory.Manager.Remove(clientAddress) + // connectionFactory.Manager.Remove(clientAddress) waitGroup.Wait() @@ -566,6 +600,9 @@ func TestExecutionNodeClientClosedGracefully(t *testing.T) { // error response. // - Wait for the client state to change from "Ready" to "Shutdown", indicating that the client connection was closed. func TestEvictingCacheClients(t *testing.T) { + logger := unittest.Logger() + metrics := metrics.NewNoopCollector() + // Create a new collection node for testing cn := new(collectionNode) cn.start(t) @@ -600,19 +637,21 @@ func TestEvictingCacheClients(t *testing.T) { // Set the connection pool cache size cacheSize := 1 + connectionCache, err := NewCache(logger, metrics, cacheSize) + require.NoError(t, err) + // create a non-blocking cache - cache, err := lru.NewWithEvict[string, *CachedClient](cacheSize, func(_ string, client *CachedClient) { + connectionCache.cache, err = lru.NewWithEvict[string, *CachedClient](cacheSize, func(_ string, client *CachedClient) { go client.Close() }) require.NoError(t, err) - connectionCache := NewCache(cache, cacheSize) // set metrics reporting - connectionFactory.AccessMetrics = metrics.NewNoopCollector() + connectionFactory.AccessMetrics = metrics connectionFactory.Manager = NewManager( - connectionCache, - unittest.Logger(), + logger, connectionFactory.AccessMetrics, + connectionCache, 0, CircuitBreakerConfig{}, grpcutils.NoCompressor, @@ -626,12 +665,12 @@ func TestEvictingCacheClients(t *testing.T) { ctx := context.Background() // Retrieve the cached client from the cache - cachedClient, ok := connectionCache.Get(clientAddress) + cachedClient, ok := connectionCache.cache.Get(clientAddress) require.True(t, ok) // wait until the client connection is ready require.Eventually(t, func() bool { - return cachedClient.ClientConn.GetState() == connectivity.Ready + return cachedClient.ClientConn().GetState() == connectivity.Ready }, 100*time.Millisecond, 10*time.Millisecond, "client timed out before ready") // Schedule the invalidation of the access API client while the Ping call is in progress @@ -643,9 +682,9 @@ func TestEvictingCacheClients(t *testing.T) { <-startPing // wait until Ping is called // Invalidate the access API client - connectionFactory.Manager.Remove(clientAddress) + cachedClient.Invalidate() - // Remove marks the connection for closure asynchronously, so give it some time to run + // Invalidate marks the connection for closure asynchronously, so give it some time to run require.Eventually(t, func() bool { return cachedClient.closeRequested.Load() }, 100*time.Millisecond, 10*time.Millisecond, "client timed out closing connection") @@ -666,140 +705,116 @@ func TestEvictingCacheClients(t *testing.T) { // Wait for the client connection to change state from "Ready" to "Shutdown" as connection was closed. require.Eventually(t, func() bool { - return cachedClient.ClientConn.WaitForStateChange(ctx, connectivity.Ready) + return cachedClient.ClientConn().WaitForStateChange(ctx, connectivity.Ready) }, 100*time.Millisecond, 10*time.Millisecond, "client timed out transitioning state") - assert.Equal(t, connectivity.Shutdown, cachedClient.ClientConn.GetState()) + assert.Equal(t, connectivity.Shutdown, cachedClient.ClientConn().GetState()) assert.Equal(t, 0, connectionCache.Len()) wg.Wait() // wait until the move test routine is done } -func TestCachedClientShutdown(t *testing.T) { - // Test that a completely uninitialized client can be closed without panics - t.Run("uninitialized client", func(t *testing.T) { - client := &CachedClient{ - closeRequested: atomic.NewBool(false), - } - client.Close() - assert.True(t, client.closeRequested.Load()) - }) +func TestConcurrentConnections(t *testing.T) { + logger := unittest.Logger() + metrics := metrics.NewNoopCollector() - // Test closing a client with no outstanding requests - // Close() should return quickly - t.Run("with no outstanding requests", func(t *testing.T) { - client := &CachedClient{ - closeRequested: atomic.NewBool(false), - ClientConn: setupGRPCServer(t), - } - - unittest.RequireReturnsBefore(t, func() { - client.Close() - }, 100*time.Millisecond, "client timed out closing connection") - - assert.True(t, client.closeRequested.Load()) - }) - - // Test closing a client with outstanding requests waits for requests to complete - // Close() should block until the request completes - t.Run("with some outstanding requests", func(t *testing.T) { - client := &CachedClient{ - closeRequested: atomic.NewBool(false), - ClientConn: setupGRPCServer(t), + // Add createExecNode function to recreate it each time for rapid test + createExecNode := func() (*executionNode, func()) { + en := new(executionNode) + en.start(t) + return en, func() { + en.stop(t) } - client.wg.Add(1) - - done := atomic.NewBool(false) - go func() { - defer client.wg.Done() - time.Sleep(50 * time.Millisecond) - done.Store(true) - }() + } - unittest.RequireReturnsBefore(t, func() { - client.Close() - }, 100*time.Millisecond, "client timed out closing connection") + // setup the handler mock + req := &execution.PingRequest{} + resp := &execution.PingResponse{} - assert.True(t, client.closeRequested.Load()) - assert.True(t, done.Load()) - }) + // Note: rapid will randomly fail with an error: "group did not use any data from bitstream" + // See https://github.com/flyingmutant/rapid/issues/65 + rapid.Check(t, func(tt *rapid.T) { + en, closer := createExecNode() + defer closer() - // Test closing a client that is already closing does not block - // Close() should return immediately - t.Run("already closing", func(t *testing.T) { - client := &CachedClient{ - closeRequested: atomic.NewBool(true), // close already requested - ClientConn: setupGRPCServer(t), + // Note: rapid does not support concurrent calls to Draw for a given T, so they must be serialized + mu := sync.Mutex{} + getSleep := func() time.Duration { + mu.Lock() + defer mu.Unlock() + return time.Duration(rapid.Int64Range(100, 10_000).Draw(tt, "s")) } - client.wg.Add(1) - done := atomic.NewBool(false) - go func() { - defer client.wg.Done() + requestCount := rapid.IntRange(50, 1000).Draw(tt, "r") + responsesSent := atomic.NewInt32(0) + en.handler. + On("Ping", testifymock.Anything, req). + Return(func(_ context.Context, _ *execution.PingRequest) (*execution.PingResponse, error) { + time.Sleep(getSleep() * time.Microsecond) - // use a long delay and require Close() to complete faster - time.Sleep(5 * time.Second) - done.Store(true) - }() + // randomly fail ~25% of the time to test that client connection and reuse logic + // handles concurrent connect/disconnects + fail, err := rand.Int(rand.Reader, big.NewInt(4)) + require.NoError(tt, err) - // should return immediately - unittest.RequireReturnsBefore(t, func() { - client.Close() - }, 10*time.Millisecond, "client timed out closing connection") - - assert.True(t, client.closeRequested.Load()) - assert.False(t, done.Load()) - }) + if fail.Uint64()%4 == 0 { + err = status.Errorf(codes.Unavailable, "random error") + } - // Test closing a client that is locked during connection setup - // Close() should wait for the lock before shutting down - t.Run("connection setting up", func(t *testing.T) { - client := &CachedClient{ - closeRequested: atomic.NewBool(false), + responsesSent.Inc() + return resp, err + }) + + connectionCache, err := NewCache(logger, metrics, 1) + require.NoError(tt, err) + + connectionFactory := &ConnectionFactoryImpl{ + ExecutionGRPCPort: en.port, + ExecutionNodeGRPCTimeout: time.Second, + AccessMetrics: metrics, + Manager: NewManager( + logger, + metrics, + connectionCache, + 0, + CircuitBreakerConfig{}, + grpcutils.NoCompressor, + ), } - // simulate an in-progress connection setup - client.mu.Lock() + clientAddress := en.listener.Addr().String() - go func() { - // unlock after setting up the connection - defer client.mu.Unlock() + ctx := context.Background() - // pause before setting the connection to cause client.Close() to block - time.Sleep(100 * time.Millisecond) - client.ClientConn = setupGRPCServer(t) - }() + // Generate random number of requests + var wg sync.WaitGroup + wg.Add(requestCount) - // should wait at least 100 milliseconds before returning - unittest.RequireReturnsBefore(t, func() { - client.Close() - }, 500*time.Millisecond, "client timed out closing connection") + for i := 0; i < requestCount; i++ { + go func() { + defer wg.Done() - assert.True(t, client.closeRequested.Load()) - assert.NotNil(t, client.ClientConn) - }) -} + client, _, err := connectionFactory.GetExecutionAPIClient(clientAddress) + require.NoError(tt, err) -// setupGRPCServer starts a dummy grpc server for connection tests -func setupGRPCServer(t *testing.T) *grpc.ClientConn { - l, err := net.Listen("tcp", net.JoinHostPort("localhost", "0")) - require.NoError(t, err) + _, err = client.Ping(ctx, req) - server := grpc.NewServer() + if err != nil { + // Note: for some reason, when Unavailable is returned, the error message is + // changed to "the connection to 127.0.0.1:57753 was closed". Other error codes + // preserve the message. + require.Equalf(tt, codes.Unavailable, status.Code(err), "unexpected error: %v", err) + } + }() + } + wg.Wait() - t.Cleanup(func() { - server.Stop() + // the grpc client seems to throttle requests to servers that return Unavailable, so not + // all of the requests make it through to the backend every test. Requiring that at least 1 + // request is handled for these cases, but all should be handled in most runs. + assert.LessOrEqual(tt, responsesSent.Load(), int32(requestCount)) + assert.Greater(tt, responsesSent.Load(), int32(0)) }) - - go func() { - err = server.Serve(l) - require.NoError(t, err) - }() - - conn, err := grpc.Dial(l.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials())) - require.NoError(t, err) - - return conn } var successCodes = []codes.Code{ @@ -812,6 +827,9 @@ var successCodes = []codes.Code{ // TestCircuitBreakerExecutionNode tests the circuit breaker for execution nodes. func TestCircuitBreakerExecutionNode(t *testing.T) { + logger := unittest.Logger() + metrics := metrics.NewNoopCollector() + requestTimeout := 500 * time.Millisecond circuitBreakerRestoreTimeout := 1500 * time.Millisecond @@ -831,13 +849,13 @@ func TestCircuitBreakerExecutionNode(t *testing.T) { // Set the connection pool cache size. cacheSize := 1 - connectionCache, err := lru.New[string, *CachedClient](cacheSize) + connectionCache, err := NewCache(logger, metrics, cacheSize) require.NoError(t, err) connectionFactory.Manager = NewManager( - NewCache(connectionCache, cacheSize), - unittest.Logger(), + logger, connectionFactory.AccessMetrics, + connectionCache, 0, CircuitBreakerConfig{ Enabled: true, @@ -849,7 +867,7 @@ func TestCircuitBreakerExecutionNode(t *testing.T) { ) // Set metrics reporting. - connectionFactory.AccessMetrics = metrics.NewNoopCollector() + connectionFactory.AccessMetrics = metrics // Create the execution API client. client, _, err := connectionFactory.GetExecutionAPIClient(en.listener.Addr().String()) @@ -915,6 +933,9 @@ func TestCircuitBreakerExecutionNode(t *testing.T) { // TestCircuitBreakerCollectionNode tests the circuit breaker for collection nodes. func TestCircuitBreakerCollectionNode(t *testing.T) { + logger := unittest.Logger() + metrics := metrics.NewNoopCollector() + requestTimeout := 500 * time.Millisecond circuitBreakerRestoreTimeout := 1500 * time.Millisecond @@ -934,13 +955,13 @@ func TestCircuitBreakerCollectionNode(t *testing.T) { // Set the connection pool cache size. cacheSize := 1 - connectionCache, err := lru.New[string, *CachedClient](cacheSize) + connectionCache, err := NewCache(logger, metrics, cacheSize) require.NoError(t, err) connectionFactory.Manager = NewManager( - NewCache(connectionCache, cacheSize), - unittest.Logger(), + logger, connectionFactory.AccessMetrics, + connectionCache, 0, CircuitBreakerConfig{ Enabled: true, @@ -952,7 +973,7 @@ func TestCircuitBreakerCollectionNode(t *testing.T) { ) // Set metrics reporting. - connectionFactory.AccessMetrics = metrics.NewNoopCollector() + connectionFactory.AccessMetrics = metrics // Create the collection API client. client, _, err := connectionFactory.GetAccessAPIClient(cn.listener.Addr().String(), nil) diff --git a/engine/access/rpc/connection/grpc_compression_benchmark_test.go b/engine/access/rpc/connection/grpc_compression_benchmark_test.go index 1854d845d72..6ab86fa39a4 100644 --- a/engine/access/rpc/connection/grpc_compression_benchmark_test.go +++ b/engine/access/rpc/connection/grpc_compression_benchmark_test.go @@ -75,9 +75,9 @@ func runBenchmark(b *testing.B, compressorName string) { // set metrics reporting connectionFactory.AccessMetrics = metrics.NewNoopCollector() connectionFactory.Manager = NewManager( - nil, unittest.Logger(), connectionFactory.AccessMetrics, + nil, grpcutils.DefaultMaxMsgSize, CircuitBreakerConfig{}, compressorName, diff --git a/engine/access/rpc/connection/manager.go b/engine/access/rpc/connection/manager.go index add02afb4ca..356fbef1b0c 100644 --- a/engine/access/rpc/connection/manager.go +++ b/engine/access/rpc/connection/manager.go @@ -11,7 +11,6 @@ import ( "github.com/sony/gobreaker" "google.golang.org/grpc" "google.golang.org/grpc/codes" - "google.golang.org/grpc/connectivity" "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" _ "google.golang.org/grpc/encoding/gzip" //required for gRPC compression @@ -24,17 +23,9 @@ import ( "github.com/onflow/flow-go/utils/grpcutils" ) -// DefaultClientTimeout is used when making a GRPC request to a collection node or an execution node. +// DefaultClientTimeout is used when making a GRPC request to a collection or execution node. const DefaultClientTimeout = 3 * time.Second -// clientType is an enumeration type used to differentiate between different types of gRPC clients. -type clientType int - -const ( - AccessClient clientType = iota - ExecutionClient -) - type noopCloser struct{} func (c *noopCloser) Close() error { @@ -43,9 +34,9 @@ func (c *noopCloser) Close() error { // Manager provides methods for getting and managing gRPC client connections. type Manager struct { - cache *Cache logger zerolog.Logger metrics module.AccessMetrics + cache *Cache maxMsgSize uint circuitBreakerConfig CircuitBreakerConfig compressorName string @@ -67,9 +58,9 @@ type CircuitBreakerConfig struct { // NewManager creates a new Manager with the specified parameters. func NewManager( - cache *Cache, logger zerolog.Logger, metrics module.AccessMetrics, + cache *Cache, maxMsgSize uint, circuitBreakerConfig CircuitBreakerConfig, compressorName string, @@ -91,18 +82,18 @@ func NewManager( func (m *Manager) GetConnection( grpcAddress string, timeout time.Duration, - clientType clientType, networkPubKey crypto.PublicKey, ) (*grpc.ClientConn, io.Closer, error) { if m.cache != nil { - conn, err := m.retrieveConnection(grpcAddress, timeout, clientType, networkPubKey) + client, err := m.cache.GetConnected(grpcAddress, timeout, networkPubKey, m.createConnection) if err != nil { return nil, nil, err } - return conn, &noopCloser{}, nil + + return client.ClientConn(), &noopCloser{}, nil } - conn, err := m.createConnection(grpcAddress, timeout, nil, clientType, networkPubKey) + conn, err := m.createConnection(grpcAddress, timeout, networkPubKey, nil) if err != nil { return nil, nil, err } @@ -110,80 +101,6 @@ func (m *Manager) GetConnection( return conn, io.Closer(conn), nil } -// Remove removes the gRPC client connection associated with the given grpcAddress from the cache. -// It returns true if the connection was removed successfully, false otherwise. -func (m *Manager) Remove(grpcAddress string) bool { - if m.cache == nil { - return false - } - - client, ok := m.cache.Get(grpcAddress) - if !ok { - return false - } - - // First, remove the client from the cache to ensure other callers create a new entry - // Remove is done atomically, so only the first caller will succeed - if !m.cache.Remove(grpcAddress) { - return false - } - - // Close the connection asynchronously to avoid blocking requests - go client.Close() - - return true -} - -// HasCache returns true if the Manager has a cache, false otherwise. -func (m *Manager) HasCache() bool { - return m.cache != nil -} - -// retrieveConnection retrieves the CachedClient for the given grpcAddress from the cache or adds a new one if not present. -// If the connection is already cached, it waits for the lock and returns the connection from the cache. -// Otherwise, it creates a new connection and caches it. -// The networkPubKey is the public key used for retrieving secure gRPC connection. Can be nil for an unsecured connection. -func (m *Manager) retrieveConnection( - grpcAddress string, - timeout time.Duration, - clientType clientType, - networkPubKey crypto.PublicKey, -) (*grpc.ClientConn, error) { - client, ok := m.cache.GetOrAdd(grpcAddress, timeout) - if ok { - // The client was retrieved from the cache, wait for the lock - client.mu.Lock() - if m.metrics != nil { - m.metrics.ConnectionFromPoolReused() - } - } else { - // The client is new, lock is already held - if m.metrics != nil { - m.metrics.ConnectionAddedToPool() - } - } - defer client.mu.Unlock() - - if client.ClientConn != nil && client.ClientConn.GetState() != connectivity.Shutdown { - // Return the client connection from the cache - return client.ClientConn, nil - } - - // The connection is not cached or is closed, create a new connection and cache it - conn, err := m.createConnection(grpcAddress, timeout, client, clientType, networkPubKey) - if err != nil { - return nil, err - } - - client.ClientConn = conn - if m.metrics != nil { - m.metrics.NewConnectionEstablished() - m.metrics.TotalConnectionsInPool(uint(m.cache.Len()), uint(m.cache.MaxSize())) - } - - return client.ClientConn, nil -} - // createConnection creates a new gRPC connection to the remote node at the given address with the specified timeout. // If the cachedClient is not nil, it means a new entry in the cache is being created, so it's locked to give priority // to the caller working with the new client, allowing it to create the underlying connection. @@ -192,9 +109,8 @@ func (m *Manager) retrieveConnection( func (m *Manager) createConnection( address string, timeout time.Duration, - cachedClient *CachedClient, - clientType clientType, networkPubKey crypto.PublicKey, + cachedClient *CachedClient, ) (*grpc.ClientConn, error) { if timeout == 0 { timeout = DefaultClientTimeout @@ -210,8 +126,8 @@ func (m *Manager) createConnection( // https://grpc.io/blog/grpc-web-interceptor/#binding-interceptors var connInterceptors []grpc.UnaryClientInterceptor - if !m.circuitBreakerConfig.Enabled { - connInterceptors = append(connInterceptors, m.createClientInvalidationInterceptor(address, clientType)) + if !m.circuitBreakerConfig.Enabled && cachedClient != nil { + connInterceptors = append(connInterceptors, m.createClientInvalidationInterceptor(cachedClient)) } connInterceptors = append(connInterceptors, createClientTimeoutInterceptor(timeout)) @@ -272,13 +188,13 @@ func createRequestWatcherInterceptor(cachedClient *CachedClient) grpc.UnaryClien opts ...grpc.CallOption, ) error { // Prevent new requests from being sent if the connection is marked for closure. - if cachedClient.closeRequested.Load() { - return status.Errorf(codes.Unavailable, "the connection to %s was closed", cachedClient.Address) + if cachedClient.CloseRequested() { + return status.Errorf(codes.Unavailable, "the connection to %s was closed", cachedClient.Address()) } // Increment the request counter to track ongoing requests, then decrement the request counter before returning. - cachedClient.wg.Add(1) - defer cachedClient.wg.Done() + done := cachedClient.AddRequest() + defer done() // Invoke the actual RPC method. return invoker(ctx, method, req, reply, cc, opts...) @@ -320,49 +236,23 @@ func createClientTimeoutInterceptor(timeout time.Duration) grpc.UnaryClientInter // createClientInvalidationInterceptor creates a client interceptor for client invalidation. It should only be created // if the circuit breaker is disabled. If the response from the server indicates an unavailable status, it invalidates // the corresponding client. -func (m *Manager) createClientInvalidationInterceptor( - address string, - clientType clientType, -) grpc.UnaryClientInterceptor { - if !m.circuitBreakerConfig.Enabled { - clientInvalidationInterceptor := func( - ctx context.Context, - method string, - req interface{}, - reply interface{}, - cc *grpc.ClientConn, - invoker grpc.UnaryInvoker, - opts ...grpc.CallOption, - ) error { - err := invoker(ctx, method, req, reply, cc, opts...) - if status.Code(err) == codes.Unavailable { - switch clientType { - case AccessClient: - if m.Remove(address) { - m.logger.Debug().Str("cached_access_client_invalidated", address).Msg("invalidating cached access client") - if m.metrics != nil { - m.metrics.ConnectionFromPoolInvalidated() - } - } - case ExecutionClient: - if m.Remove(address) { - m.logger.Debug().Str("cached_execution_client_invalidated", address).Msg("invalidating cached execution client") - if m.metrics != nil { - m.metrics.ConnectionFromPoolInvalidated() - } - } - default: - m.logger.Info().Str("client_invalidation_interceptor", address).Msg(fmt.Sprintf("unexpected client type: %d", clientType)) - } - } - - return err +func (m *Manager) createClientInvalidationInterceptor(cachedClient *CachedClient) grpc.UnaryClientInterceptor { + return func( + ctx context.Context, + method string, + req interface{}, + reply interface{}, + cc *grpc.ClientConn, + invoker grpc.UnaryInvoker, + opts ...grpc.CallOption, + ) error { + err := invoker(ctx, method, req, reply, cc, opts...) + if status.Code(err) == codes.Unavailable { + cachedClient.Invalidate() } - return clientInvalidationInterceptor + return err } - - return nil } // The simplified representation and description of circuit breaker pattern, that used to handle node connectivity: diff --git a/engine/execution/ingestion/engine.go b/engine/execution/ingestion/engine.go index 778fbf880b8..560e695f8d8 100644 --- a/engine/execution/ingestion/engine.go +++ b/engine/execution/ingestion/engine.go @@ -499,8 +499,6 @@ func (e *Engine) onBlockExecuted( e.metrics.ExecutionStorageStateCommitment(int64(len(finalState))) e.metrics.ExecutionLastExecutedBlockHeight(executed.Block.Header.Height) - // e.checkStateSyncStop(executed.Block.Header.Height) - missingCollections := make(map[*entity.ExecutableBlock][]*flow.CollectionGuarantee) err := e.mempool.Run( func( diff --git a/engine/testutil/nodes.go b/engine/testutil/nodes.go index c61472d0d4e..3815fe9220d 100644 --- a/engine/testutil/nodes.go +++ b/engine/testutil/nodes.go @@ -573,7 +573,7 @@ func ExecutionNode(t *testing.T, hub *stub.Hub, identity bootstrap.NodeInfo, ide ls, err := completeLedger.NewLedger(diskWal, capacity, metricsCollector, node.Log.With().Str("compontent", "ledger").Logger(), completeLedger.DefaultPathFinderVersion) require.NoError(t, err) - compactor, err := completeLedger.NewCompactor(ls, diskWal, zerolog.Nop(), capacity, checkpointDistance, checkpointsToKeep, atomic.NewBool(false)) + compactor, err := completeLedger.NewCompactor(ls, diskWal, zerolog.Nop(), capacity, checkpointDistance, checkpointsToKeep, atomic.NewBool(false), metricsCollector) require.NoError(t, err) <-compactor.Ready() // Need to start compactor here because BootstrapLedger() updates ledger state. diff --git a/go.mod b/go.mod index 6364bb237de..5244ac5c976 100644 --- a/go.mod +++ b/go.mod @@ -94,12 +94,13 @@ require ( google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.2.0 google.golang.org/protobuf v1.31.0 gotest.tools v2.2.0+incompatible - pgregory.net/rapid v0.4.7 + pgregory.net/rapid v1.1.0 ) require ( github.com/cockroachdb/pebble v0.0.0-20230928194634-aa077af62593 github.com/coreos/go-semver v0.3.0 + github.com/docker/go-units v0.5.0 github.com/go-playground/validator/v10 v10.14.1 github.com/golang/snappy v0.0.5-0.20220116011046-fa5810519dcb github.com/gorilla/websocket v1.5.0 @@ -160,7 +161,6 @@ require ( github.com/desertbit/timer v0.0.0-20180107155436-c41aec40b27f // indirect github.com/dgraph-io/ristretto v0.1.0 // indirect github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2 // indirect - github.com/docker/go-units v0.5.0 // indirect github.com/dustin/go-humanize v1.0.1 // indirect github.com/elastic/gosigar v0.14.2 // indirect github.com/ethereum/c-kzg-4844 v0.4.0 // indirect diff --git a/go.sum b/go.sum index 2e58165b70b..758282755e7 100644 --- a/go.sum +++ b/go.sum @@ -2478,8 +2478,9 @@ lukechampine.com/blake3 v1.2.1/go.mod h1:0OFRp7fBtAylGVCO40o87sbupkyIGgbpv1+M1k1 nhooyr.io/websocket v1.8.6/go.mod h1:B70DZP8IakI65RVQ51MsWP/8jndNma26DVA/nFSCgW0= nhooyr.io/websocket v1.8.7 h1:usjR2uOr/zjjkVMy0lW+PPohFok7PCow5sDjLgX4P4g= nhooyr.io/websocket v1.8.7/go.mod h1:B70DZP8IakI65RVQ51MsWP/8jndNma26DVA/nFSCgW0= -pgregory.net/rapid v0.4.7 h1:MTNRktPuv5FNqOO151TM9mDTa+XHcX6ypYeISDVD14g= pgregory.net/rapid v0.4.7/go.mod h1:UYpPVyjFHzYBGHIxLFoupi8vwk6rXNzRY9OMvVxFIOU= +pgregory.net/rapid v1.1.0 h1:CMa0sjHSru3puNx+J0MIAuiiEV4N0qj8/cMWGBBCsjw= +pgregory.net/rapid v1.1.0/go.mod h1:PY5XlDGj0+V1FCq0o192FdRhpKHGTRIWBgqjDBTrq04= rsc.io/binaryregexp v0.2.0/go.mod h1:qTv7/COck+e2FymRvadv62gMdZztPaShugOCi3I+8D8= rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= rsc.io/quote/v3 v3.1.0/go.mod h1:yEA65RcK8LyAZtP9Kv3t0HmxON59tX3rD+tICJqUlj0= diff --git a/insecure/go.sum b/insecure/go.sum index 02d7808d8aa..d25308d39dd 100644 --- a/insecure/go.sum +++ b/insecure/go.sum @@ -2398,8 +2398,8 @@ lukechampine.com/blake3 v1.2.1/go.mod h1:0OFRp7fBtAylGVCO40o87sbupkyIGgbpv1+M1k1 nhooyr.io/websocket v1.8.6/go.mod h1:B70DZP8IakI65RVQ51MsWP/8jndNma26DVA/nFSCgW0= nhooyr.io/websocket v1.8.7 h1:usjR2uOr/zjjkVMy0lW+PPohFok7PCow5sDjLgX4P4g= nhooyr.io/websocket v1.8.7/go.mod h1:B70DZP8IakI65RVQ51MsWP/8jndNma26DVA/nFSCgW0= -pgregory.net/rapid v0.4.7 h1:MTNRktPuv5FNqOO151TM9mDTa+XHcX6ypYeISDVD14g= pgregory.net/rapid v0.4.7/go.mod h1:UYpPVyjFHzYBGHIxLFoupi8vwk6rXNzRY9OMvVxFIOU= +pgregory.net/rapid v1.1.0 h1:CMa0sjHSru3puNx+J0MIAuiiEV4N0qj8/cMWGBBCsjw= rsc.io/binaryregexp v0.2.0/go.mod h1:qTv7/COck+e2FymRvadv62gMdZztPaShugOCi3I+8D8= rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= rsc.io/quote/v3 v3.1.0/go.mod h1:yEA65RcK8LyAZtP9Kv3t0HmxON59tX3rD+tICJqUlj0= diff --git a/integration/go.sum b/integration/go.sum index 4404781edfa..5f27761ccc5 100644 --- a/integration/go.sum +++ b/integration/go.sum @@ -2579,8 +2579,8 @@ modernc.org/memory v1.5.0 h1:N+/8c5rE6EqugZwHii4IFsaJ7MUhoWX07J5tC/iI5Ds= modernc.org/memory v1.5.0/go.mod h1:PkUhL0Mugw21sHPeskwZW4D6VscE/GQJOnIpCnW6pSU= modernc.org/sqlite v1.21.1 h1:GyDFqNnESLOhwwDRaHGdp2jKLDzpyT/rNLglX3ZkMSU= modernc.org/sqlite v1.21.1/go.mod h1:XwQ0wZPIh1iKb5mkvCJ3szzbhk+tykC8ZWqTRTgYRwI= -pgregory.net/rapid v0.4.7 h1:MTNRktPuv5FNqOO151TM9mDTa+XHcX6ypYeISDVD14g= pgregory.net/rapid v0.4.7/go.mod h1:UYpPVyjFHzYBGHIxLFoupi8vwk6rXNzRY9OMvVxFIOU= +pgregory.net/rapid v1.1.0 h1:CMa0sjHSru3puNx+J0MIAuiiEV4N0qj8/cMWGBBCsjw= rsc.io/binaryregexp v0.2.0/go.mod h1:qTv7/COck+e2FymRvadv62gMdZztPaShugOCi3I+8D8= rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= rsc.io/quote/v3 v3.1.0/go.mod h1:yEA65RcK8LyAZtP9Kv3t0HmxON59tX3rD+tICJqUlj0= diff --git a/ledger/complete/compactor.go b/ledger/complete/compactor.go index ef603900af1..a08a36d2232 100644 --- a/ledger/complete/compactor.go +++ b/ledger/complete/compactor.go @@ -13,6 +13,7 @@ import ( "github.com/onflow/flow-go/ledger" "github.com/onflow/flow-go/ledger/complete/mtrie/trie" realWAL "github.com/onflow/flow-go/ledger/complete/wal" + "github.com/onflow/flow-go/module" "github.com/onflow/flow-go/module/lifecycle" "github.com/onflow/flow-go/module/observable" ) @@ -57,6 +58,7 @@ type Compactor struct { stopCh chan chan struct{} trieUpdateCh <-chan *WALTrieUpdate triggerCheckpointOnNextSegmentFinish *atomic.Bool // to trigger checkpoint manually + metrics module.WALMetrics } // NewCompactor creates new Compactor which writes WAL record and triggers @@ -76,6 +78,7 @@ func NewCompactor( checkpointDistance uint, checkpointsToKeep uint, triggerCheckpointOnNextSegmentFinish *atomic.Bool, + metrics module.WALMetrics, ) (*Compactor, error) { if checkpointDistance < 1 { checkpointDistance = 1 @@ -114,6 +117,7 @@ func NewCompactor( checkpointDistance: checkpointDistance, checkpointsToKeep: checkpointsToKeep, triggerCheckpointOnNextSegmentFinish: triggerCheckpointOnNextSegmentFinish, + metrics: metrics, }, nil } @@ -288,7 +292,7 @@ Loop: // Since this function is only for checkpointing, Compactor isn't affected by returned error. func (c *Compactor) checkpoint(ctx context.Context, tries []*trie.MTrie, checkpointNum int) error { - err := createCheckpoint(c.checkpointer, c.logger, tries, checkpointNum) + err := createCheckpoint(c.checkpointer, c.logger, tries, checkpointNum, c.metrics) if err != nil { return &createCheckpointError{num: checkpointNum, err: err} } @@ -325,7 +329,7 @@ func (c *Compactor) checkpoint(ctx context.Context, tries []*trie.MTrie, checkpo // createCheckpoint creates checkpoint with given checkpointNum and tries. // Errors indicate that checkpoint file can't be created. // Caller should handle returned errors by retrying checkpointing when appropriate. -func createCheckpoint(checkpointer *realWAL.Checkpointer, logger zerolog.Logger, tries []*trie.MTrie, checkpointNum int) error { +func createCheckpoint(checkpointer *realWAL.Checkpointer, logger zerolog.Logger, tries []*trie.MTrie, checkpointNum int, metrics module.WALMetrics) error { logger.Info().Msgf("serializing checkpoint %d with %v tries", checkpointNum, len(tries)) @@ -337,6 +341,13 @@ func createCheckpoint(checkpointer *realWAL.Checkpointer, logger zerolog.Logger, return fmt.Errorf("error serializing checkpoint (%d): %w", checkpointNum, err) } + size, err := realWAL.ReadCheckpointFileSize(checkpointer.Dir(), fileName) + if err != nil { + return fmt.Errorf("error reading checkpoint file size (%d): %w", checkpointNum, err) + } + + metrics.ExecutionCheckpointSize(size) + duration := time.Since(startTime) logger.Info().Float64("total_time_s", duration.Seconds()).Msgf("created checkpoint %d", checkpointNum) diff --git a/ledger/complete/compactor_test.go b/ledger/complete/compactor_test.go index 3258361cb04..15cf89a446f 100644 --- a/ledger/complete/compactor_test.go +++ b/ledger/complete/compactor_test.go @@ -90,7 +90,7 @@ func TestCompactorCreation(t *testing.T) { // WAL segments are 32kB, so here we generate 2 keys 64kB each, times `size` // so we should get at least `size` segments - compactor, err := NewCompactor(l, wal, unittest.Logger(), forestCapacity, checkpointDistance, checkpointsToKeep, atomic.NewBool(false)) + compactor, err := NewCompactor(l, wal, unittest.Logger(), forestCapacity, checkpointDistance, checkpointsToKeep, atomic.NewBool(false), metrics.NewNoopCollector()) require.NoError(t, err) co := CompactorObserver{fromBound: 8, done: make(chan struct{})} @@ -316,7 +316,7 @@ func TestCompactorSkipCheckpointing(t *testing.T) { // WAL segments are 32kB, so here we generate 2 keys 64kB each, times `size` // so we should get at least `size` segments - compactor, err := NewCompactor(l, wal, unittest.Logger(), forestCapacity, checkpointDistance, checkpointsToKeep, atomic.NewBool(false)) + compactor, err := NewCompactor(l, wal, unittest.Logger(), forestCapacity, checkpointDistance, checkpointsToKeep, atomic.NewBool(false), metrics.NewNoopCollector()) require.NoError(t, err) co := CompactorObserver{fromBound: 8, done: make(chan struct{})} @@ -442,7 +442,7 @@ func TestCompactorAccuracy(t *testing.T) { l, err := NewLedger(wal, forestCapacity, metricsCollector, zerolog.Logger{}, DefaultPathFinderVersion) require.NoError(t, err) - compactor, err := NewCompactor(l, wal, unittest.Logger(), forestCapacity, checkpointDistance, checkpointsToKeep, atomic.NewBool(false)) + compactor, err := NewCompactor(l, wal, unittest.Logger(), forestCapacity, checkpointDistance, checkpointsToKeep, atomic.NewBool(false), metrics.NewNoopCollector()) require.NoError(t, err) fromBound := lastCheckpointNum + (size / 2) @@ -552,7 +552,7 @@ func TestCompactorTriggeredByAdminTool(t *testing.T) { l, err := NewLedger(wal, forestCapacity, metricsCollector, unittest.LoggerWithName("ledger"), DefaultPathFinderVersion) require.NoError(t, err) - compactor, err := NewCompactor(l, wal, unittest.LoggerWithName("compactor"), forestCapacity, checkpointDistance, checkpointsToKeep, atomic.NewBool(true)) + compactor, err := NewCompactor(l, wal, unittest.LoggerWithName("compactor"), forestCapacity, checkpointDistance, checkpointsToKeep, atomic.NewBool(true), metrics.NewNoopCollector()) require.NoError(t, err) fmt.Println("should stop as soon as segment 5 is generated, which should trigger checkpoint 5 to be created") @@ -656,7 +656,7 @@ func TestCompactorConcurrency(t *testing.T) { l, err := NewLedger(wal, forestCapacity, metricsCollector, zerolog.Logger{}, DefaultPathFinderVersion) require.NoError(t, err) - compactor, err := NewCompactor(l, wal, unittest.Logger(), forestCapacity, checkpointDistance, checkpointsToKeep, atomic.NewBool(false)) + compactor, err := NewCompactor(l, wal, unittest.Logger(), forestCapacity, checkpointDistance, checkpointsToKeep, atomic.NewBool(false), metrics.NewNoopCollector()) require.NoError(t, err) fromBound := lastCheckpointNum + (size / 2 * numGoroutine) diff --git a/ledger/complete/ledger_benchmark_test.go b/ledger/complete/ledger_benchmark_test.go index 6c0855be914..a97257ac2a6 100644 --- a/ledger/complete/ledger_benchmark_test.go +++ b/ledger/complete/ledger_benchmark_test.go @@ -47,7 +47,7 @@ func benchmarkStorage(steps int, b *testing.B) { led, err := complete.NewLedger(diskWal, steps+1, &metrics.NoopCollector{}, zerolog.Logger{}, complete.DefaultPathFinderVersion) require.NoError(b, err) - compactor, err := complete.NewCompactor(led, diskWal, zerolog.Nop(), uint(steps+1), checkpointDistance, checkpointsToKeep, atomic.NewBool(false)) + compactor, err := complete.NewCompactor(led, diskWal, zerolog.Nop(), uint(steps+1), checkpointDistance, checkpointsToKeep, atomic.NewBool(false), metrics.NewNoopCollector()) require.NoError(b, err) <-compactor.Ready() @@ -160,7 +160,7 @@ func BenchmarkTrieUpdate(b *testing.B) { led, err := complete.NewLedger(diskWal, capacity, &metrics.NoopCollector{}, zerolog.Logger{}, complete.DefaultPathFinderVersion) require.NoError(b, err) - compactor, err := complete.NewCompactor(led, diskWal, zerolog.Nop(), capacity, checkpointDistance, checkpointsToKeep, atomic.NewBool(false)) + compactor, err := complete.NewCompactor(led, diskWal, zerolog.Nop(), capacity, checkpointDistance, checkpointsToKeep, atomic.NewBool(false), metrics.NewNoopCollector()) require.NoError(b, err) <-compactor.Ready() @@ -212,7 +212,7 @@ func BenchmarkTrieRead(b *testing.B) { led, err := complete.NewLedger(diskWal, capacity, &metrics.NoopCollector{}, zerolog.Logger{}, complete.DefaultPathFinderVersion) require.NoError(b, err) - compactor, err := complete.NewCompactor(led, diskWal, zerolog.Nop(), capacity, checkpointDistance, checkpointsToKeep, atomic.NewBool(false)) + compactor, err := complete.NewCompactor(led, diskWal, zerolog.Nop(), capacity, checkpointDistance, checkpointsToKeep, atomic.NewBool(false), metrics.NewNoopCollector()) require.NoError(b, err) <-compactor.Ready() @@ -273,7 +273,7 @@ func BenchmarkLedgerGetOneValue(b *testing.B) { led, err := complete.NewLedger(diskWal, capacity, &metrics.NoopCollector{}, zerolog.Logger{}, complete.DefaultPathFinderVersion) require.NoError(b, err) - compactor, err := complete.NewCompactor(led, diskWal, zerolog.Nop(), capacity, checkpointDistance, checkpointsToKeep, atomic.NewBool(false)) + compactor, err := complete.NewCompactor(led, diskWal, zerolog.Nop(), capacity, checkpointDistance, checkpointsToKeep, atomic.NewBool(false), metrics.NewNoopCollector()) require.NoError(b, err) <-compactor.Ready() @@ -351,7 +351,7 @@ func BenchmarkTrieProve(b *testing.B) { led, err := complete.NewLedger(diskWal, capacity, &metrics.NoopCollector{}, zerolog.Logger{}, complete.DefaultPathFinderVersion) require.NoError(b, err) - compactor, err := complete.NewCompactor(led, diskWal, zerolog.Nop(), capacity, checkpointDistance, checkpointsToKeep, atomic.NewBool(false)) + compactor, err := complete.NewCompactor(led, diskWal, zerolog.Nop(), capacity, checkpointDistance, checkpointsToKeep, atomic.NewBool(false), metrics.NewNoopCollector()) require.NoError(b, err) <-compactor.Ready() diff --git a/ledger/complete/ledger_test.go b/ledger/complete/ledger_test.go index b0685fb7ef4..f429aa851f4 100644 --- a/ledger/complete/ledger_test.go +++ b/ledger/complete/ledger_test.go @@ -514,7 +514,7 @@ func Test_WAL(t *testing.T) { led, err := complete.NewLedger(diskWal, size, metricsCollector, logger, complete.DefaultPathFinderVersion) require.NoError(t, err) - compactor, err := complete.NewCompactor(led, diskWal, zerolog.Nop(), size, checkpointDistance, checkpointsToKeep, atomic.NewBool(false)) + compactor, err := complete.NewCompactor(led, diskWal, zerolog.Nop(), size, checkpointDistance, checkpointsToKeep, atomic.NewBool(false), metrics.NewNoopCollector()) require.NoError(t, err) <-compactor.Ready() @@ -551,7 +551,7 @@ func Test_WAL(t *testing.T) { led2, err := complete.NewLedger(diskWal2, size+10, metricsCollector, logger, complete.DefaultPathFinderVersion) require.NoError(t, err) - compactor2, err := complete.NewCompactor(led2, diskWal2, zerolog.Nop(), uint(size), checkpointDistance, checkpointsToKeep, atomic.NewBool(false)) + compactor2, err := complete.NewCompactor(led2, diskWal2, zerolog.Nop(), uint(size), checkpointDistance, checkpointsToKeep, atomic.NewBool(false), metrics.NewNoopCollector()) require.NoError(t, err) <-compactor2.Ready() @@ -613,7 +613,7 @@ func TestLedgerFunctionality(t *testing.T) { require.NoError(t, err) led, err := complete.NewLedger(diskWal, activeTries, metricsCollector, logger, complete.DefaultPathFinderVersion) assert.NoError(t, err) - compactor, err := complete.NewCompactor(led, diskWal, zerolog.Nop(), uint(activeTries), checkpointDistance, checkpointsToKeep, atomic.NewBool(false)) + compactor, err := complete.NewCompactor(led, diskWal, zerolog.Nop(), uint(activeTries), checkpointDistance, checkpointsToKeep, atomic.NewBool(false), metrics.NewNoopCollector()) require.NoError(t, err) <-compactor.Ready() @@ -730,7 +730,7 @@ func TestWALUpdateFailuresBubbleUp(t *testing.T) { led, err := complete.NewLedger(w, capacity, &metrics.NoopCollector{}, zerolog.Logger{}, complete.DefaultPathFinderVersion) require.NoError(t, err) - compactor, err := complete.NewCompactor(led, w, zerolog.Nop(), capacity, checkpointDistance, checkpointsToKeep, atomic.NewBool(false)) + compactor, err := complete.NewCompactor(led, w, zerolog.Nop(), capacity, checkpointDistance, checkpointsToKeep, atomic.NewBool(false), metrics.NewNoopCollector()) require.NoError(t, err) <-compactor.Ready() diff --git a/ledger/complete/mtrie/trie/trie.go b/ledger/complete/mtrie/trie/trie.go index 7f03e3558bd..064e7f157e3 100644 --- a/ledger/complete/mtrie/trie/trie.go +++ b/ledger/complete/mtrie/trie/trie.go @@ -78,7 +78,7 @@ func (mt *MTrie) AllocatedRegCount() uint64 { return mt.regCount } -// AllocatedRegSize returns the size of allocated registers in the trie. +// AllocatedRegSize returns the size (number of bytes) of allocated registers in the trie. // Concurrency safe (as Tries are immutable structures by convention) func (mt *MTrie) AllocatedRegSize() uint64 { return mt.regSize diff --git a/ledger/complete/wal/checkpoint_v6_reader.go b/ledger/complete/wal/checkpoint_v6_reader.go index 2b8f626d80c..c2703261d24 100644 --- a/ledger/complete/wal/checkpoint_v6_reader.go +++ b/ledger/complete/wal/checkpoint_v6_reader.go @@ -105,6 +105,34 @@ func OpenAndReadCheckpointV6(dir string, fileName string, logger zerolog.Logger) return triesToReturn, errToReturn } +// ReadCheckpointFileSize returns the total size of the checkpoint file +func ReadCheckpointFileSize(dir string, fileName string) (uint64, error) { + paths := allFilePaths(dir, fileName) + totalSize := uint64(0) + for _, path := range paths { + fileInfo, err := os.Stat(path) + if err != nil { + return 0, fmt.Errorf("could not get file info for %v: %w", path, err) + } + + totalSize += uint64(fileInfo.Size()) + } + + return totalSize, nil +} + +func allFilePaths(dir string, fileName string) []string { + paths := make([]string, 0, 1+subtrieCount+1) + paths = append(paths, filePathCheckpointHeader(dir, fileName)) + for i := 0; i < subtrieCount; i++ { + subTriePath, _, _ := filePathSubTries(dir, fileName, i) + paths = append(paths, subTriePath) + } + topTriePath, _ := filePathTopTries(dir, fileName) + paths = append(paths, topTriePath) + return paths +} + func filePathCheckpointHeader(dir string, fileName string) string { return path.Join(dir, fileName) } diff --git a/ledger/complete/wal/checkpoint_v6_writer.go b/ledger/complete/wal/checkpoint_v6_writer.go index 93f97151b0e..5c420b8842d 100644 --- a/ledger/complete/wal/checkpoint_v6_writer.go +++ b/ledger/complete/wal/checkpoint_v6_writer.go @@ -10,6 +10,7 @@ import ( "path" "path/filepath" + "github.com/docker/go-units" "github.com/hashicorp/go-multierror" "github.com/rs/zerolog" @@ -79,8 +80,10 @@ func storeCheckpointV6( lg.Info(). Str("first_hash", first.RootHash().String()). Uint64("first_reg_count", first.AllocatedRegCount()). + Str("first_reg_size", units.BytesSize(float64(first.AllocatedRegSize()))). Str("last_hash", last.RootHash().String()). Uint64("last_reg_count", last.AllocatedRegCount()). + Str("last_reg_size", units.BytesSize(float64(last.AllocatedRegSize()))). Msg("storing checkpoint") // make sure a checkpoint file with same name doesn't exist diff --git a/ledger/complete/wal/checkpointer.go b/ledger/complete/wal/checkpointer.go index 1c6aaa0aef3..937d82e79a7 100644 --- a/ledger/complete/wal/checkpointer.go +++ b/ledger/complete/wal/checkpointer.go @@ -13,6 +13,7 @@ import ( "strconv" "strings" + "github.com/docker/go-units" "github.com/rs/zerolog" "github.com/rs/zerolog/log" "golang.org/x/sync/errgroup" @@ -252,7 +253,14 @@ func (c *Checkpointer) Checkpoint(to int) (err error) { return fmt.Errorf("could not create checkpoint for %v: %w", to, err) } - c.wal.log.Info().Msgf("created checkpoint %d with %d tries", to, len(tries)) + checkpointFileSize, err := ReadCheckpointFileSize(c.wal.dir, fileName) + if err != nil { + return fmt.Errorf("could not read checkpoint file size: %w", err) + } + + c.wal.log.Info(). + Str("checkpoint_file_size", units.BytesSize(float64(checkpointFileSize))). + Msgf("created checkpoint %d with %d tries", to, len(tries)) return nil } diff --git a/ledger/complete/wal/checkpointer_test.go b/ledger/complete/wal/checkpointer_test.go index a0a828748d3..dd46ffdb85e 100644 --- a/ledger/complete/wal/checkpointer_test.go +++ b/ledger/complete/wal/checkpointer_test.go @@ -59,7 +59,7 @@ func Test_WAL(t *testing.T) { led, err := complete.NewLedger(diskWal, size*10, metricsCollector, logger, complete.DefaultPathFinderVersion) require.NoError(t, err) - compactor, err := complete.NewCompactor(led, diskWal, unittest.Logger(), size, checkpointDistance, checkpointsToKeep, atomic.NewBool(false)) + compactor, err := complete.NewCompactor(led, diskWal, unittest.Logger(), size, checkpointDistance, checkpointsToKeep, atomic.NewBool(false), metrics.NewNoopCollector()) require.NoError(t, err) <-compactor.Ready() diff --git a/module/chainsync/core_rapid_test.go b/module/chainsync/core_rapid_test.go index 649fce871d8..2554577caa3 100644 --- a/module/chainsync/core_rapid_test.go +++ b/module/chainsync/core_rapid_test.go @@ -25,8 +25,8 @@ func populatedBlockStore(t *rapid.T) []*flow.Header { store := []*flow.Header{unittest.BlockHeaderFixture()} for i := 1; i < NUM_BLOCKS; i++ { // we sample from the store 2/3 times to get deeper trees - b := rapid.OneOf(rapid.Just(unittest.BlockHeaderFixture()), rapid.SampledFrom(store), rapid.SampledFrom(store)).Draw(t, "parent").(flow.Header) - store = append(store, unittest.BlockHeaderWithParentFixture(&b)) + b := rapid.OneOf(rapid.Just(unittest.BlockHeaderFixture()), rapid.SampledFrom(store), rapid.SampledFrom(store)).Draw(t, "parent") + store = append(store, unittest.BlockHeaderWithParentFixture(b)) } return store } @@ -38,8 +38,8 @@ type rapidSync struct { heightRequests map[uint64]bool // depth 1 pushdown automaton to track height requests } -// Init is an action for initializing a rapidSync instance. -func (r *rapidSync) Init(t *rapid.T) { +// init is an action for initializing a rapidSync instance. +func (r *rapidSync) init(t *rapid.T) { var err error r.core, err = New(zerolog.New(io.Discard), DefaultConfig(), metrics.NewNoopCollector(), flow.Localnet) @@ -52,7 +52,7 @@ func (r *rapidSync) Init(t *rapid.T) { // RequestByID is an action that requests a block by its ID. func (r *rapidSync) RequestByID(t *rapid.T) { - b := rapid.SampledFrom(r.store).Draw(t, "id_request").(*flow.Header) + b := rapid.SampledFrom(r.store).Draw(t, "id_request") r.core.RequestBlock(b.ID(), b.Height) // Re-queueing by ID should always succeed r.idRequests[b.ID()] = true @@ -62,7 +62,7 @@ func (r *rapidSync) RequestByID(t *rapid.T) { // RequestByHeight is an action that requests a specific height func (r *rapidSync) RequestByHeight(t *rapid.T) { - b := rapid.SampledFrom(r.store).Draw(t, "height_request").(*flow.Header) + b := rapid.SampledFrom(r.store).Draw(t, "height_request") r.core.RequestHeight(b.Height) // Re-queueing by height should always succeed r.heightRequests[b.Height] = true @@ -71,8 +71,8 @@ func (r *rapidSync) RequestByHeight(t *rapid.T) { // HandleHeight is an action that requests a heights // upon receiving an argument beyond a certain tolerance func (r *rapidSync) HandleHeight(t *rapid.T) { - b := rapid.SampledFrom(r.store).Draw(t, "height_hint_request").(*flow.Header) - incr := rapid.IntRange(0, (int)(DefaultConfig().Tolerance)+1).Draw(t, "height increment").(int) + b := rapid.SampledFrom(r.store).Draw(t, "height_hint_request") + incr := rapid.IntRange(0, (int)(DefaultConfig().Tolerance)+1).Draw(t, "height increment") requestHeight := b.Height + (uint64)(incr) r.core.HandleHeight(b, requestHeight) // Re-queueing by height should always succeed if beyond tolerance @@ -85,7 +85,7 @@ func (r *rapidSync) HandleHeight(t *rapid.T) { // HandleByID is an action that provides a block header to the sync engine func (r *rapidSync) HandleByID(t *rapid.T) { - b := rapid.SampledFrom(r.store).Draw(t, "id_handling").(*flow.Header) + b := rapid.SampledFrom(r.store).Draw(t, "id_handling") success := r.core.HandleBlock(b) assert.True(t, success || r.idRequests[b.ID()] == false) @@ -174,7 +174,11 @@ func (r *rapidSync) Check(t *rapid.T) { func TestRapidSync(t *testing.T) { unittest.SkipUnless(t, unittest.TEST_FLAKY, "flaky test") - rapid.Check(t, rapid.Run(&rapidSync{})) + rapid.Check(t, func(t *rapid.T) { + sm := new(rapidSync) + sm.init(t) + t.Repeat(rapid.StateMachineActions(sm)) + }) } // utility functions diff --git a/module/component/component_manager_test.go b/module/component/component_manager_test.go index fc99ca92af3..5fe55ae5460 100644 --- a/module/component/component_manager_test.go +++ b/module/component/component_manager_test.go @@ -345,7 +345,7 @@ func StartStateTransition() (func(t func()), func(*rapid.T)) { executeTransitions := func(t *rapid.T) { for i := 0; i < len(transitions); i++ { - j := rapid.IntRange(0, len(transitions)-i-1).Draw(t, "").(int) + j := rapid.IntRange(0, len(transitions)-i-1).Draw(t, "") transitions[i], transitions[j+i] = transitions[j+i], transitions[i] transitions[i]() } @@ -390,35 +390,34 @@ type ComponentManagerMachine struct { assertErrorThrownMatches func(t *rapid.T, err error, msgAndArgs ...interface{}) assertErrorNotThrown func(t *rapid.T) - cancelGenerator *rapid.Generator + cancelGenerator *rapid.Generator[bool] drawStateTransition func(t *rapid.T) *StateTransition } -func (c *ComponentManagerMachine) Init(t *rapid.T) { - numWorkers := rapid.IntRange(0, 5).Draw(t, "num_workers").(int) - pCancel := rapid.Float64Range(0, 100).Draw(t, "p_cancel").(float64) +func (c *ComponentManagerMachine) init(t *rapid.T) { + numWorkers := rapid.IntRange(0, 5).Draw(t, "num_workers") + pCancel := rapid.Float64Range(0, 100).Draw(t, "p_cancel") - c.cancelGenerator = rapid.Float64Range(0, 100). - Map(func(n float64) bool { - return pCancel == 100 || n < pCancel - }) + c.cancelGenerator = rapid.Map(rapid.Float64Range(0, 100), func(n float64) bool { + return pCancel == 100 || n < pCancel + }) c.drawStateTransition = func(t *rapid.T) *StateTransition { st := &StateTransition{} if !c.canceled { - st.cancel = c.cancelGenerator.Draw(t, "cancel").(bool) + st.cancel = c.cancelGenerator.Draw(t, "cancel") } for workerId, state := range c.workerStates { if allowedTransitions, ok := WorkerStateTransitions[state]; ok { label := fmt.Sprintf("worker_transition_%v", workerId) st.workerIDs = append(st.workerIDs, workerId) - st.workerTransitions = append(st.workerTransitions, rapid.SampledFrom(allowedTransitions).Draw(t, label).(WorkerStateTransition)) + st.workerTransitions = append(st.workerTransitions, rapid.SampledFrom(allowedTransitions).Draw(t, label)) } } - return rapid.Just(st).Draw(t, "state_transition").(*StateTransition) + return rapid.Just(st).Draw(t, "state_transition") } ctx, cancel := context.WithCancel(context.Background()) @@ -625,7 +624,11 @@ func (c *ComponentManagerMachine) Check(t *rapid.T) { func TestComponentManager(t *testing.T) { unittest.SkipUnless(t, unittest.TEST_LONG_RUNNING, "skip because this test takes too long") - rapid.Check(t, rapid.Run(&ComponentManagerMachine{})) + rapid.Check(t, func(t *rapid.T) { + sm := new(ComponentManagerMachine) + sm.init(t) + t.Repeat(rapid.StateMachineActions(sm)) + }) } func TestComponentManagerShutdown(t *testing.T) { diff --git a/module/mempool/stdmap/incorporated_result_seals_test.go b/module/mempool/stdmap/incorporated_result_seals_test.go index fb1a4b450b9..2f83fb0c128 100644 --- a/module/mempool/stdmap/incorporated_result_seals_test.go +++ b/module/mempool/stdmap/incorporated_result_seals_test.go @@ -18,14 +18,14 @@ type icrSealsMachine struct { state []*flow.IncorporatedResultSeal // model of the icrSeals } -// Init is an action for initializing a icrSeals instance. -func (m *icrSealsMachine) Init(t *rapid.T) { +// init is an action for initializing a icrSeals instance. +func (m *icrSealsMachine) init(t *rapid.T) { m.icrs = NewIncorporatedResultSeals(1000) } // Add is a conditional action which adds an item to the icrSeals. func (m *icrSealsMachine) Add(t *rapid.T) { - i := rapid.Uint64().Draw(t, "i").(uint64) + i := rapid.Uint64().Draw(t, "i") seal := unittest.IncorporatedResultSeal.Fixture(func(s *flow.IncorporatedResultSeal) { s.Header.Height = i @@ -49,7 +49,7 @@ func (m *icrSealsMachine) Add(t *rapid.T) { // Prune is a Conditional action that removes elements of height strictly lower than its argument func (m *icrSealsMachine) PruneUpToHeight(t *rapid.T) { - h := rapid.Uint64().Draw(t, "h").(uint64) + h := rapid.Uint64().Draw(t, "h") err := m.icrs.PruneUpToHeight(h) if h >= m.icrs.lowestHeight { require.NoError(t, err) @@ -72,7 +72,7 @@ func (m *icrSealsMachine) Get(t *rapid.T) { if n == 0 { return } - i := rapid.IntRange(0, n-1).Draw(t, "i").(int) + i := rapid.IntRange(0, n-1).Draw(t, "i") s := m.state[i] actual, ok := m.icrs.ByID(s.ID()) @@ -89,7 +89,7 @@ func (m *icrSealsMachine) GetUnknown(t *rapid.T) { if n == 0 { return } - i := rapid.IntRange(0, n-1).Draw(t, "i").(int) + i := rapid.IntRange(0, n-1).Draw(t, "i") seal := unittest.IncorporatedResultSeal.Fixture(func(s *flow.IncorporatedResultSeal) { s.Header.Height = uint64(i) }) @@ -117,7 +117,7 @@ func (m *icrSealsMachine) Remove(t *rapid.T) { if n == 0 { return } - i := rapid.IntRange(0, n-1).Draw(t, "i").(int) + i := rapid.IntRange(0, n-1).Draw(t, "i") s := m.state[i] ok := m.icrs.Remove(s.ID()) @@ -137,7 +137,7 @@ func (m *icrSealsMachine) RemoveUnknown(t *rapid.T) { if n == 0 { return } - i := rapid.IntRange(0, n-1).Draw(t, "i").(int) + i := rapid.IntRange(0, n-1).Draw(t, "i") seal := unittest.IncorporatedResultSeal.Fixture(func(s *flow.IncorporatedResultSeal) { s.Header.Height = uint64(i) }) @@ -168,7 +168,11 @@ func (m *icrSealsMachine) Check(t *rapid.T) { // Run the icrSeals state machine and test it against its model func TestIcrs(t *testing.T) { - rapid.Check(t, rapid.Run(&icrSealsMachine{})) + rapid.Check(t, func(t *rapid.T) { + sm := new(icrSealsMachine) + sm.init(t) + t.Repeat(rapid.StateMachineActions(sm)) + }) } func TestIncorporatedResultSeals(t *testing.T) { diff --git a/module/metrics.go b/module/metrics.go index 834a7ec04ef..1423e53c7b7 100644 --- a/module/metrics.go +++ b/module/metrics.go @@ -729,6 +729,8 @@ type LedgerMetrics interface { } type WALMetrics interface { + // ExecutionCheckpointSize reports the size of a checkpoint in bytes + ExecutionCheckpointSize(bytes uint64) } type RateLimitedBlockstoreMetrics interface { diff --git a/module/metrics/execution.go b/module/metrics/execution.go index 90fc9ea27f4..6dbf6f5c3b5 100644 --- a/module/metrics/execution.go +++ b/module/metrics/execution.go @@ -22,6 +22,7 @@ type ExecutionCollector struct { lastFinalizedExecutedBlockHeightGauge prometheus.Gauge stateStorageDiskTotal prometheus.Gauge storageStateCommitment prometheus.Gauge + checkpointSize prometheus.Gauge forestApproxMemorySize prometheus.Gauge forestNumberOfTrees prometheus.Gauge latestTrieRegCount prometheus.Gauge @@ -650,6 +651,7 @@ func NewExecutionCollector(tracer module.Tracer) *ExecutionCollector { Help: "the execution state size on disk in bytes", }), + // TODO: remove storageStateCommitment: promauto.NewGauge(prometheus.GaugeOpts{ Namespace: namespaceExecution, Subsystem: subsystemStateStorage, @@ -657,6 +659,13 @@ func NewExecutionCollector(tracer module.Tracer) *ExecutionCollector { Help: "the storage size of a state commitment in bytes", }), + checkpointSize: promauto.NewGauge(prometheus.GaugeOpts{ + Namespace: namespaceExecution, + Subsystem: subsystemStateStorage, + Name: "checkpoint_size_bytes", + Help: "the size of a checkpoint in bytes", + }), + stateSyncActive: promauto.NewGauge(prometheus.GaugeOpts{ Namespace: namespaceExecution, Subsystem: subsystemIngestion, @@ -799,6 +808,11 @@ func (ec *ExecutionCollector) ExecutionStorageStateCommitment(bytes int64) { ec.storageStateCommitment.Set(float64(bytes)) } +// ExecutionCheckpointSize reports the size of a checkpoint in bytes +func (ec *ExecutionCollector) ExecutionCheckpointSize(bytes uint64) { + ec.checkpointSize.Set(float64(bytes)) +} + // ExecutionLastExecutedBlockHeight reports last executed block height func (ec *ExecutionCollector) ExecutionLastExecutedBlockHeight(height uint64) { ec.lastExecutedBlockHeightGauge.Set(float64(height)) diff --git a/module/metrics/noop.go b/module/metrics/noop.go index 04a6d80b70e..def8a8f2d27 100644 --- a/module/metrics/noop.go +++ b/module/metrics/noop.go @@ -161,6 +161,7 @@ func (nc *NoopCollector) StartBlockReceivedToExecuted(blockID flow.Identifier) func (nc *NoopCollector) FinishBlockReceivedToExecuted(blockID flow.Identifier) {} func (nc *NoopCollector) ExecutionComputationUsedPerBlock(computation uint64) {} func (nc *NoopCollector) ExecutionStorageStateCommitment(bytes int64) {} +func (nc *NoopCollector) ExecutionCheckpointSize(bytes uint64) {} func (nc *NoopCollector) ExecutionLastExecutedBlockHeight(height uint64) {} func (nc *NoopCollector) ExecutionLastFinalizedExecutedBlockHeight(height uint64) {} func (nc *NoopCollector) ExecutionBlockExecuted(_ time.Duration, _ module.ExecutionResultStats) {} diff --git a/module/mock/execution_metrics.go b/module/mock/execution_metrics.go index bca785e7e75..cb9f6b632dc 100644 --- a/module/mock/execution_metrics.go +++ b/module/mock/execution_metrics.go @@ -46,6 +46,11 @@ func (_m *ExecutionMetrics) ExecutionBlockExecutionEffortVectorComponent(_a0 str _m.Called(_a0, _a1) } +// ExecutionCheckpointSize provides a mock function with given fields: bytes +func (_m *ExecutionMetrics) ExecutionCheckpointSize(bytes uint64) { + _m.Called(bytes) +} + // ExecutionChunkDataPackGenerated provides a mock function with given fields: proofSize, numberOfTransactions func (_m *ExecutionMetrics) ExecutionChunkDataPackGenerated(proofSize int, numberOfTransactions int) { _m.Called(proofSize, numberOfTransactions) diff --git a/module/mock/wal_metrics.go b/module/mock/wal_metrics.go index bf26cbb86ef..04806761950 100644 --- a/module/mock/wal_metrics.go +++ b/module/mock/wal_metrics.go @@ -9,6 +9,11 @@ type WALMetrics struct { mock.Mock } +// ExecutionCheckpointSize provides a mock function with given fields: bytes +func (_m *WALMetrics) ExecutionCheckpointSize(bytes uint64) { + _m.Called(bytes) +} + type mockConstructorTestingTNewWALMetrics interface { mock.TestingT Cleanup(func()) diff --git a/module/signature/checksum_test.go b/module/signature/checksum_test.go index 35a11408bca..9006565aca7 100644 --- a/module/signature/checksum_test.go +++ b/module/signature/checksum_test.go @@ -50,11 +50,11 @@ func TestCheckSum(t *testing.T) { // is able to extract the same data as the encoder. func TestPrefixCheckSum(t *testing.T) { rapid.Check(t, func(t *rapid.T) { - committeeSize := rapid.IntRange(0, 300).Draw(t, "committeeSize").(int) + committeeSize := rapid.IntRange(0, 300).Draw(t, "committeeSize") committee := unittest.IdentifierListFixture(committeeSize) - data := rapid.IntRange(0, 200).Map(func(count int) []byte { + data := rapid.Map(rapid.IntRange(0, 200), func(count int) []byte { return unittest.RandomBytes(count) - }).Draw(t, "data").([]byte) + }).Draw(t, "data") extracted, err := msig.CompareAndExtract(committee, msig.PrefixCheckSum(committee, data)) require.NoError(t, err) require.Equal(t, data, extracted) diff --git a/module/signature/signer_indices_test.go b/module/signature/signer_indices_test.go index 47be774088e..2a10311e2a9 100644 --- a/module/signature/signer_indices_test.go +++ b/module/signature/signer_indices_test.go @@ -104,9 +104,9 @@ func TestEncodeFail(t *testing.T) { func Test_EncodeSignerToIndicesAndSigType(t *testing.T) { rapid.Check(t, func(t *rapid.T) { // select total committee size, number of random beacon signers and number of staking signers - committeeSize := rapid.IntRange(1, 272).Draw(t, "committeeSize").(int) - numStakingSigners := rapid.IntRange(0, committeeSize).Draw(t, "numStakingSigners").(int) - numRandomBeaconSigners := rapid.IntRange(0, committeeSize-numStakingSigners).Draw(t, "numRandomBeaconSigners").(int) + committeeSize := rapid.IntRange(1, 272).Draw(t, "committeeSize") + numStakingSigners := rapid.IntRange(0, committeeSize).Draw(t, "numStakingSigners") + numRandomBeaconSigners := rapid.IntRange(0, committeeSize-numStakingSigners).Draw(t, "numRandomBeaconSigners") // create committee committeeIdentities := unittest.IdentityListFixture(committeeSize, unittest.WithRole(flow.RoleConsensus)).Sort(flow.Canonical[flow.Identity]) @@ -142,9 +142,9 @@ func Test_EncodeSignerToIndicesAndSigType(t *testing.T) { func Test_DecodeSigTypeToStakingAndBeaconSigners(t *testing.T) { rapid.Check(t, func(t *rapid.T) { // select total committee size, number of random beacon signers and number of staking signers - committeeSize := rapid.IntRange(1, 272).Draw(t, "committeeSize").(int) - numStakingSigners := rapid.IntRange(0, committeeSize).Draw(t, "numStakingSigners").(int) - numRandomBeaconSigners := rapid.IntRange(0, committeeSize-numStakingSigners).Draw(t, "numRandomBeaconSigners").(int) + committeeSize := rapid.IntRange(1, 272).Draw(t, "committeeSize") + numStakingSigners := rapid.IntRange(0, committeeSize).Draw(t, "numStakingSigners") + numRandomBeaconSigners := rapid.IntRange(0, committeeSize-numStakingSigners).Draw(t, "numRandomBeaconSigners") // create committee committeeIdentities := unittest.IdentityListFixture(committeeSize, unittest.WithRole(flow.RoleConsensus)). @@ -270,8 +270,8 @@ func TestValidPaddingErrIllegallyPaddedBitVector(t *testing.T) { func Test_EncodeSignersToIndices(t *testing.T) { rapid.Check(t, func(t *rapid.T) { // select total committee size, number of random beacon signers and number of staking signers - committeeSize := rapid.IntRange(1, 272).Draw(t, "committeeSize").(int) - numSigners := rapid.IntRange(0, committeeSize).Draw(t, "numSigners").(int) + committeeSize := rapid.IntRange(1, 272).Draw(t, "committeeSize") + numSigners := rapid.IntRange(0, committeeSize).Draw(t, "numSigners") // create committee identities := unittest.IdentityListFixture(committeeSize, unittest.WithRole(flow.RoleConsensus)).Sort(flow.Canonical[flow.Identity]) @@ -300,8 +300,8 @@ func Test_EncodeSignersToIndices(t *testing.T) { func Test_DecodeSignerIndicesToIdentifiers(t *testing.T) { rapid.Check(t, func(t *rapid.T) { // select total committee size, number of random beacon signers and number of staking signers - committeeSize := rapid.IntRange(1, 272).Draw(t, "committeeSize").(int) - numSigners := rapid.IntRange(0, committeeSize).Draw(t, "numSigners").(int) + committeeSize := rapid.IntRange(1, 272).Draw(t, "committeeSize") + numSigners := rapid.IntRange(0, committeeSize).Draw(t, "numSigners") // create committee identities := unittest.IdentityListFixture(committeeSize, unittest.WithRole(flow.RoleConsensus)).Sort(flow.Canonical[flow.Identity]) @@ -336,8 +336,8 @@ const UpperBoundCommitteeSize = 272 func Test_DecodeSignerIndicesToIdentities(t *testing.T) { rapid.Check(t, func(t *rapid.T) { // select total committee size, number of random beacon signers and number of staking signers - committeeSize := rapid.IntRange(1, UpperBoundCommitteeSize).Draw(t, "committeeSize").(int) - numSigners := rapid.IntRange(0, committeeSize).Draw(t, "numSigners").(int) + committeeSize := rapid.IntRange(1, UpperBoundCommitteeSize).Draw(t, "committeeSize") + numSigners := rapid.IntRange(0, committeeSize).Draw(t, "numSigners") // create committee identities := unittest.IdentityListFixture(committeeSize, unittest.WithRole(flow.RoleConsensus)).Sort(flow.Canonical[flow.Identity]) diff --git a/state/protocol/events/gadgets/views_test.go b/state/protocol/events/gadgets/views_test.go index 484531c4b53..a0393398322 100644 --- a/state/protocol/events/gadgets/views_test.go +++ b/state/protocol/events/gadgets/views_test.go @@ -19,7 +19,7 @@ type viewsMachine struct { expectedCalls int // expected value of calls at any given time } -func (m *viewsMachine) Init(_ *rapid.T) { +func (m *viewsMachine) init(_ *rapid.T) { m.views = NewViews() m.callbacks = make(map[uint64]int) m.calls = 0 @@ -27,7 +27,7 @@ func (m *viewsMachine) Init(_ *rapid.T) { } func (m *viewsMachine) OnView(t *rapid.T) { - view := rapid.Uint64().Draw(t, "view").(uint64) + view := rapid.Uint64().Draw(t, "view") m.views.OnView(view, func(_ *flow.Header) { m.calls++ // count actual number of calls invoked by Views }) @@ -37,7 +37,7 @@ func (m *viewsMachine) OnView(t *rapid.T) { } func (m *viewsMachine) BlockFinalized(t *rapid.T) { - view := rapid.Uint64().Draw(t, "view").(uint64) + view := rapid.Uint64().Draw(t, "view") block := unittest.BlockHeaderFixture() block.View = view @@ -58,5 +58,9 @@ func (m *viewsMachine) Check(t *rapid.T) { } func TestViewsRapid(t *testing.T) { - rapid.Check(t, rapid.Run(new(viewsMachine))) + rapid.Check(t, func(t *rapid.T) { + sm := new(viewsMachine) + sm.init(t) + t.Repeat(rapid.StateMachineActions(sm)) + }) }