diff --git a/.github/stale.yml b/.github/workflows/stale.yml similarity index 100% rename from .github/stale.yml rename to .github/workflows/stale.yml diff --git a/SECURITY.md b/SECURITY.md index 6b370e9060b..2a38679616c 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -9,4 +9,4 @@ If you care about making a difference, please follow the guidelines below. # **Guidelines For Responsible Disclosure** -We ask that all researchers adhere to these guidelines [here](https://docs.onflow.org/bounties/responsible-disclosure/) +We ask that all researchers adhere to these guidelines [here](https://flow.com/flow-responsible-disclosure) diff --git a/access/handler.go b/access/handler.go index d9af3ea4720..3a49af84c68 100644 --- a/access/handler.go +++ b/access/handler.go @@ -24,6 +24,16 @@ type Handler struct { me module.Local } +// TODO: this is implemented in https://github.com/onflow/flow-go/pull/4957, remove when merged +func (h *Handler) GetProtocolStateSnapshotByBlockID(ctx context.Context, request *access.GetProtocolStateSnapshotByBlockIDRequest) (*access.ProtocolStateSnapshotResponse, error) { + panic("implement me") +} + +// TODO: this is implemented in https://github.com/onflow/flow-go/pull/4957, remove when merged +func (h *Handler) GetProtocolStateSnapshotByHeight(ctx context.Context, request *access.GetProtocolStateSnapshotByHeightRequest) (*access.ProtocolStateSnapshotResponse, error) { + panic("implement me") +} + // HandlerOption is used to hand over optional constructor parameters type HandlerOption func(*Handler) diff --git a/cmd/access/node_builder/access_node_builder.go b/cmd/access/node_builder/access_node_builder.go index bfbcc877df9..ce717777b19 100644 --- a/cmd/access/node_builder/access_node_builder.go +++ b/cmd/access/node_builder/access_node_builder.go @@ -146,6 +146,7 @@ type AccessNodeConfig struct { executionDataIndexingEnabled bool registersDBPath string checkpointFile string + scriptExecutorConfig query.QueryConfig } type PublicNetworkConfig struct { @@ -232,6 +233,7 @@ func DefaultAccessNodeConfig() *AccessNodeConfig { executionDataIndexingEnabled: false, registersDBPath: filepath.Join(homedir, ".flow", "execution_state"), checkpointFile: cmd.NotSet, + scriptExecutorConfig: query.NewDefaultConfig(), } } @@ -771,6 +773,7 @@ func (builder *FlowAccessNodeBuilder) BuildExecutionSyncComponents() *FlowAccess query.NewProtocolStateWrapper(builder.State), builder.Storage.Headers, builder.ExecutionIndexerCore.RegisterValue, + builder.scriptExecutorConfig, ) if err != nil { return nil, err @@ -929,6 +932,11 @@ func (builder *FlowAccessNodeBuilder) extraFlags() { // Script Execution flags.StringVar(&builder.rpcConf.BackendConfig.ScriptExecutionMode, "script-execution-mode", defaultConfig.rpcConf.BackendConfig.ScriptExecutionMode, "mode to use when executing scripts. one of (local-only, execution-nodes-only, failover, compare)") + flags.Uint64Var(&builder.scriptExecutorConfig.ComputationLimit, "script-execution-computation-limit", defaultConfig.scriptExecutorConfig.ComputationLimit, "maximum number of computation units a locally executed script can use. default: 100000") + flags.IntVar(&builder.scriptExecutorConfig.MaxErrorMessageSize, "script-execution-max-error-length", defaultConfig.scriptExecutorConfig.MaxErrorMessageSize, "maximum number characters to include in error message strings. additional characters are truncated. 
default: 1000") + flags.DurationVar(&builder.scriptExecutorConfig.LogTimeThreshold, "script-execution-log-time-threshold", defaultConfig.scriptExecutorConfig.LogTimeThreshold, "emit a log for any scripts that take over this threshold. default: 1s") + flags.DurationVar(&builder.scriptExecutorConfig.ExecutionTimeLimit, "script-execution-timeout", defaultConfig.scriptExecutorConfig.ExecutionTimeLimit, "timeout value for locally executed scripts. default: 10s") + }).ValidateFlags(func() error { if builder.supportsObserver && (builder.PublicNetworkConfig.BindAddress == cmd.NotSet || builder.PublicNetworkConfig.BindAddress == "") { return errors.New("public-network-address must be set if supports-observer is true") @@ -1563,6 +1571,7 @@ func (builder *FlowAccessNodeBuilder) initPublicLibp2pNode(networkKey crypto.Pri UpdateInterval: builder.FlowConfig.NetworkConfig.PeerUpdateInterval, ConnectorFactory: connection.DefaultLibp2pBackoffConnectorFactory(), }, + &builder.FlowConfig.NetworkConfig.GossipSubConfig.SubscriptionProviderConfig, &p2p.DisallowListCacheConfig{ MaxSize: builder.FlowConfig.NetworkConfig.DisallowListNotificationCacheSize, Metrics: metrics.DisallowListCacheMetricsFactory(builder.HeroCacheMetricsFactory(), network.PublicNetwork), diff --git a/cmd/consensus/main.go b/cmd/consensus/main.go index 63ddfaa9cdd..835668ce747 100644 --- a/cmd/consensus/main.go +++ b/cmd/consensus/main.go @@ -594,6 +594,8 @@ func main() { ) notifier.AddParticipantConsumer(telemetryConsumer) + notifier.AddCommunicatorConsumer(telemetryConsumer) + notifier.AddFinalizationConsumer(telemetryConsumer) notifier.AddFollowerConsumer(followerDistributor) // initialize the persister diff --git a/cmd/execution_builder.go b/cmd/execution_builder.go index 9f0fe5114f3..06d948443a2 100644 --- a/cmd/execution_builder.go +++ b/cmd/execution_builder.go @@ -540,16 +540,10 @@ func (exeNode *ExecutionNode) LoadProviderEngine( "cannot get the latest executed block id: %w", err) } - stateCommit, err := exeNode.executionState.StateCommitmentByBlockID( - ctx, - blockID) + blockSnapshot, _, err := exeNode.executionState.CreateStorageSnapshot(blockID) if err != nil { - return nil, fmt.Errorf( - "cannot get the state commitment at latest executed block id %s: %w", - blockID.String(), - err) + return nil, fmt.Errorf("cannot create a storage snapshot at block %v: %w", blockID, err) } - blockSnapshot := exeNode.executionState.NewStorageSnapshot(stateCommit) // Get the epoch counter from the smart contract at the last executed block. 
contractEpochCounter, err := getContractEpochCounter( @@ -868,7 +862,7 @@ func (exeNode *ExecutionNode) LoadIngestionEngine( } fetcher := fetcher.NewCollectionFetcher(node.Logger, exeNode.collectionRequester, node.State, exeNode.exeConf.onflowOnlyLNs) - loader := loader.NewLoader(node.Logger, node.State, node.Storage.Headers, exeNode.executionState) + loader := loader.NewUnexecutedLoader(node.Logger, node.State, node.Storage.Headers, exeNode.executionState) exeNode.ingestionEng, err = ingestion.New( exeNode.ingestionUnit, @@ -905,7 +899,6 @@ func (exeNode *ExecutionNode) LoadScriptsEngine(node *NodeConfig) (module.ReadyD exeNode.scriptsEng = scripts.New( node.Logger, - node.State, exeNode.computationManager.QueryExecutor(), exeNode.executionState, ) diff --git a/cmd/execution_config.go b/cmd/execution_config.go index 6c6e0033ad0..4dff7829141 100644 --- a/cmd/execution_config.go +++ b/cmd/execution_config.go @@ -11,6 +11,7 @@ import ( "github.com/onflow/flow-go/engine/common/provider" "github.com/onflow/flow-go/engine/execution/computation/query" exeprovider "github.com/onflow/flow-go/engine/execution/provider" + "github.com/onflow/flow-go/fvm" "github.com/onflow/flow-go/model/flow" "github.com/onflow/flow-go/module/mempool" "github.com/onflow/flow-go/utils/grpcutils" @@ -89,6 +90,8 @@ func (exeConf *ExecutionConfig) SetupFlags(flags *pflag.FlagSet) { "threshold for logging script execution") flags.DurationVar(&exeConf.computationConfig.QueryConfig.ExecutionTimeLimit, "script-execution-time-limit", query.DefaultExecutionTimeLimit, "script execution time limit") + flags.Uint64Var(&exeConf.computationConfig.QueryConfig.ComputationLimit, "script-execution-computation-limit", fvm.DefaultComputationLimit, + "script execution computation limit") flags.UintVar(&exeConf.transactionResultsCacheSize, "transaction-results-cache-size", 10000, "number of transaction results to be cached") flags.BoolVar(&exeConf.extensiveLog, "extensive-logging", false, "extensive logging logs tx contents and block headers") flags.DurationVar(&exeConf.chunkDataPackQueryTimeout, "chunk-data-pack-query-timeout", exeprovider.DefaultChunkDataPackQueryTimeout, "timeout duration to determine a chunk data pack query being slow") diff --git a/cmd/observer/node_builder/observer_builder.go b/cmd/observer/node_builder/observer_builder.go index f1c66d9a820..84e28f532f1 100644 --- a/cmd/observer/node_builder/observer_builder.go +++ b/cmd/observer/node_builder/observer_builder.go @@ -722,6 +722,7 @@ func (builder *ObserverServiceBuilder) initPublicLibp2pNode(networkKey crypto.Pr &builder.FlowConfig.NetworkConfig.ResourceManager, &builder.FlowConfig.NetworkConfig.GossipSubConfig.GossipSubRPCInspectorsConfig, p2pconfig.PeerManagerDisableConfig(), // disable peer manager for observer node. + &builder.FlowConfig.NetworkConfig.GossipSubConfig.SubscriptionProviderConfig, &p2p.DisallowListCacheConfig{ MaxSize: builder.FlowConfig.NetworkConfig.DisallowListNotificationCacheSize, Metrics: metrics.DisallowListCacheMetricsFactory(builder.HeroCacheMetricsFactory(), network.PublicNetwork), diff --git a/config/default-config.yml b/config/default-config.yml index 5c75ed4b338..d4f244286fa 100644 --- a/config/default-config.yml +++ b/config/default-config.yml @@ -33,18 +33,6 @@ network-config: # retry a unicast stream to a remote peer 3 times, the peer will give up and will not retry creating a unicast stream to that remote peer. # When it is set to zero it means that the peer will not retry creating a unicast stream to a remote peer if it fails. 
unicast-max-stream-creation-retry-attempt-times: 3 - # The number of seconds that the local peer waits since the last successful dial to a remote peer before resetting the unicast dial retry budget from zero to the maximum default. - # If it is set to 3600s (1h) for example, it means that if it has passed at least one hour since the last successful dial, and the remote peer has a zero dial retry budget, - # the unicast dial retry budget for that remote peer will be reset to the maximum default. - unicast-dial-zero-retry-reset-threshold: 3600s - # The maximum number of retry attempts for dialing a remote peer before giving up. If it is set to 3 for example, it means that if a peer fails to dial a remote peer 3 times, - # the peer will give up and will not retry dialing that remote peer. - unicast-max-dial-retry-attempt-times: 3 - # The backoff delay used in the exponential backoff for consecutive failed unicast dial attempts to a remote peer. - unicast-dial-backoff-delay: 1s - # The backoff delay used in the exponential backoff for backing off concurrent create stream attempts to the same remote peer - # when there is no available connections to that remote peer and a dial is in progress. - unicast-dial-in-progress-backoff-delay: 1s # The size of the dial config cache used to keep track of the dial config for each remote peer. The dial config is used to keep track of the dial retry budget for each remote peer. # Recommended to set it to the maximum number of remote peers in the network. unicast-dial-config-cache-size: 10_000 @@ -146,6 +134,15 @@ network-config: gossipsub-rpc-sent-tracker-workers: 5 # Peer scoring is the default value for enabling peer scoring gossipsub-peer-scoring-enabled: true + # The interval for updating the list of subscribed peers to all topics in gossipsub. This is used to keep track of subscriptions + # violations and penalize peers accordingly. Recommended value is in the order of a few minutes to avoid contentions; as the operation + # reads all topics and all peers subscribed to each topic. + gossipsub-subscription-provider-update-interval: 10m + # The size of cache for keeping the list of all peers subscribed to each topic (same as the local node). This cache is the local node's + # view of the network and is used to detect subscription violations and penalize peers accordingly. Recommended to be big enough to + # keep the entire network's size. Otherwise, the local node's view of the network will be incomplete due to cache eviction. + # Recommended size is 10x the number of peers in the network. 
+ gossipsub-subscription-provider-cache-size: 10000 # Gossipsub rpc inspectors configs # The size of the queue for notifications about invalid RPC messages diff --git a/consensus/hotstuff/notifications/telemetry.go b/consensus/hotstuff/notifications/telemetry.go index 7bbf57f79de..d6cc3852179 100644 --- a/consensus/hotstuff/notifications/telemetry.go +++ b/consensus/hotstuff/notifications/telemetry.go @@ -38,7 +38,10 @@ type TelemetryConsumer struct { noPathLogger zerolog.Logger } +// Telemetry implements consumers for _all happy-path_ interfaces in consensus/hotstuff/notifications/telemetry.go: var _ hotstuff.ParticipantConsumer = (*TelemetryConsumer)(nil) +var _ hotstuff.CommunicatorConsumer = (*TelemetryConsumer)(nil) +var _ hotstuff.FinalizationConsumer = (*TelemetryConsumer)(nil) var _ hotstuff.VoteCollectorConsumer = (*TelemetryConsumer)(nil) var _ hotstuff.TimeoutCollectorConsumer = (*TelemetryConsumer)(nil) diff --git a/consensus/hotstuff/pacemaker/pacemaker.go b/consensus/hotstuff/pacemaker/pacemaker.go index fc3ba87dbe3..ae62aee0ea2 100644 --- a/consensus/hotstuff/pacemaker/pacemaker.go +++ b/consensus/hotstuff/pacemaker/pacemaker.go @@ -31,7 +31,7 @@ type ActivePaceMaker struct { ctx context.Context timeoutControl *timeout.Controller - notifier hotstuff.Consumer + notifier hotstuff.ParticipantConsumer viewTracker viewTracker started bool } diff --git a/engine/access/apiproxy/access_api_proxy.go b/engine/access/apiproxy/access_api_proxy.go index f5898686fc6..c8afbe2f7bd 100644 --- a/engine/access/apiproxy/access_api_proxy.go +++ b/engine/access/apiproxy/access_api_proxy.go @@ -42,6 +42,16 @@ func (h *FlowAccessAPIRouter) log(handler, rpc string, err error) { logger.Info().Msg("request succeeded") } +// TODO: this is implemented in https://github.com/onflow/flow-go/pull/4957, remove when merged +func (h *FlowAccessAPIRouter) GetProtocolStateSnapshotByBlockID(ctx context.Context, request *access.GetProtocolStateSnapshotByBlockIDRequest) (*access.ProtocolStateSnapshotResponse, error) { + panic("implement me") +} + +// TODO: this is implemented in https://github.com/onflow/flow-go/pull/4957, remove when merged +func (h *FlowAccessAPIRouter) GetProtocolStateSnapshotByHeight(ctx context.Context, request *access.GetProtocolStateSnapshotByHeightRequest) (*access.ProtocolStateSnapshotResponse, error) { + panic("implement me") +} + // Ping pings the service. It is special in the sense that it responds successful, // only if all underlying services are ready. func (h *FlowAccessAPIRouter) Ping(context context.Context, req *access.PingRequest) (*access.PingResponse, error) { diff --git a/engine/access/mock/access_api_client.go b/engine/access/mock/access_api_client.go index 4e2b1d065c7..f44ebcaca02 100644 --- a/engine/access/mock/access_api_client.go +++ b/engine/access/mock/access_api_client.go @@ -677,6 +677,72 @@ func (_m *AccessAPIClient) GetNodeVersionInfo(ctx context.Context, in *access.Ge return r0, r1 } +// GetProtocolStateSnapshotByBlockID provides a mock function with given fields: ctx, in, opts +func (_m *AccessAPIClient) GetProtocolStateSnapshotByBlockID(ctx context.Context, in *access.GetProtocolStateSnapshotByBlockIDRequest, opts ...grpc.CallOption) (*access.ProtocolStateSnapshotResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) 
+ + var r0 *access.ProtocolStateSnapshotResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *access.GetProtocolStateSnapshotByBlockIDRequest, ...grpc.CallOption) (*access.ProtocolStateSnapshotResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *access.GetProtocolStateSnapshotByBlockIDRequest, ...grpc.CallOption) *access.ProtocolStateSnapshotResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*access.ProtocolStateSnapshotResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *access.GetProtocolStateSnapshotByBlockIDRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetProtocolStateSnapshotByHeight provides a mock function with given fields: ctx, in, opts +func (_m *AccessAPIClient) GetProtocolStateSnapshotByHeight(ctx context.Context, in *access.GetProtocolStateSnapshotByHeightRequest, opts ...grpc.CallOption) (*access.ProtocolStateSnapshotResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *access.ProtocolStateSnapshotResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *access.GetProtocolStateSnapshotByHeightRequest, ...grpc.CallOption) (*access.ProtocolStateSnapshotResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *access.GetProtocolStateSnapshotByHeightRequest, ...grpc.CallOption) *access.ProtocolStateSnapshotResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*access.ProtocolStateSnapshotResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *access.GetProtocolStateSnapshotByHeightRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // GetTransaction provides a mock function with given fields: ctx, in, opts func (_m *AccessAPIClient) GetTransaction(ctx context.Context, in *access.GetTransactionRequest, opts ...grpc.CallOption) (*access.TransactionResponse, error) { _va := make([]interface{}, len(opts)) diff --git a/engine/access/mock/access_api_server.go b/engine/access/mock/access_api_server.go index 1a2c3772e44..e20d93ed4c1 100644 --- a/engine/access/mock/access_api_server.go +++ b/engine/access/mock/access_api_server.go @@ -535,6 +535,58 @@ func (_m *AccessAPIServer) GetNodeVersionInfo(_a0 context.Context, _a1 *access.G return r0, r1 } +// GetProtocolStateSnapshotByBlockID provides a mock function with given fields: _a0, _a1 +func (_m *AccessAPIServer) GetProtocolStateSnapshotByBlockID(_a0 context.Context, _a1 *access.GetProtocolStateSnapshotByBlockIDRequest) (*access.ProtocolStateSnapshotResponse, error) { + ret := _m.Called(_a0, _a1) + + var r0 *access.ProtocolStateSnapshotResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *access.GetProtocolStateSnapshotByBlockIDRequest) (*access.ProtocolStateSnapshotResponse, error)); ok { + return rf(_a0, _a1) + } + if rf, ok := ret.Get(0).(func(context.Context, *access.GetProtocolStateSnapshotByBlockIDRequest) *access.ProtocolStateSnapshotResponse); ok { + r0 = rf(_a0, _a1) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*access.ProtocolStateSnapshotResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *access.GetProtocolStateSnapshotByBlockIDRequest) error); ok { + r1 = rf(_a0, _a1) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetProtocolStateSnapshotByHeight provides a mock function with given fields: _a0, _a1 +func (_m *AccessAPIServer) GetProtocolStateSnapshotByHeight(_a0 context.Context, _a1 *access.GetProtocolStateSnapshotByHeightRequest) (*access.ProtocolStateSnapshotResponse, error) { + ret := _m.Called(_a0, _a1) + + var r0 *access.ProtocolStateSnapshotResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *access.GetProtocolStateSnapshotByHeightRequest) (*access.ProtocolStateSnapshotResponse, error)); ok { + return rf(_a0, _a1) + } + if rf, ok := ret.Get(0).(func(context.Context, *access.GetProtocolStateSnapshotByHeightRequest) *access.ProtocolStateSnapshotResponse); ok { + r0 = rf(_a0, _a1) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*access.ProtocolStateSnapshotResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *access.GetProtocolStateSnapshotByHeightRequest) error); ok { + r1 = rf(_a0, _a1) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // GetTransaction provides a mock function with given fields: _a0, _a1 func (_m *AccessAPIServer) GetTransaction(_a0 context.Context, _a1 *access.GetTransactionRequest) (*access.TransactionResponse, error) { ret := _m.Called(_a0, _a1) diff --git a/engine/access/mock/execution_api_client.go b/engine/access/mock/execution_api_client.go index 759ca90c81f..597eae4f253 100644 --- a/engine/access/mock/execution_api_client.go +++ b/engine/access/mock/execution_api_client.go @@ -214,6 +214,105 @@ func (_m *ExecutionAPIClient) GetRegisterAtBlockID(ctx context.Context, in *exec return r0, r1 } +// GetTransactionErrorMessage provides a mock function with given fields: ctx, in, opts +func (_m *ExecutionAPIClient) GetTransactionErrorMessage(ctx context.Context, in *execution.GetTransactionErrorMessageRequest, opts ...grpc.CallOption) 
(*execution.GetTransactionErrorMessageResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *execution.GetTransactionErrorMessageResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *execution.GetTransactionErrorMessageRequest, ...grpc.CallOption) (*execution.GetTransactionErrorMessageResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *execution.GetTransactionErrorMessageRequest, ...grpc.CallOption) *execution.GetTransactionErrorMessageResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*execution.GetTransactionErrorMessageResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *execution.GetTransactionErrorMessageRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetTransactionErrorMessageByIndex provides a mock function with given fields: ctx, in, opts +func (_m *ExecutionAPIClient) GetTransactionErrorMessageByIndex(ctx context.Context, in *execution.GetTransactionErrorMessageByIndexRequest, opts ...grpc.CallOption) (*execution.GetTransactionErrorMessageResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *execution.GetTransactionErrorMessageResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *execution.GetTransactionErrorMessageByIndexRequest, ...grpc.CallOption) (*execution.GetTransactionErrorMessageResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *execution.GetTransactionErrorMessageByIndexRequest, ...grpc.CallOption) *execution.GetTransactionErrorMessageResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*execution.GetTransactionErrorMessageResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *execution.GetTransactionErrorMessageByIndexRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetTransactionErrorMessagesByBlockID provides a mock function with given fields: ctx, in, opts +func (_m *ExecutionAPIClient) GetTransactionErrorMessagesByBlockID(ctx context.Context, in *execution.GetTransactionErrorMessagesByBlockIDRequest, opts ...grpc.CallOption) (*execution.GetTransactionErrorMessagesResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *execution.GetTransactionErrorMessagesResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *execution.GetTransactionErrorMessagesByBlockIDRequest, ...grpc.CallOption) (*execution.GetTransactionErrorMessagesResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *execution.GetTransactionErrorMessagesByBlockIDRequest, ...grpc.CallOption) *execution.GetTransactionErrorMessagesResponse); ok { + r0 = rf(ctx, in, opts...) 
+ } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*execution.GetTransactionErrorMessagesResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *execution.GetTransactionErrorMessagesByBlockIDRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // GetTransactionResult provides a mock function with given fields: ctx, in, opts func (_m *ExecutionAPIClient) GetTransactionResult(ctx context.Context, in *execution.GetTransactionResultRequest, opts ...grpc.CallOption) (*execution.GetTransactionResultResponse, error) { _va := make([]interface{}, len(opts)) diff --git a/engine/access/mock/execution_api_server.go b/engine/access/mock/execution_api_server.go index 32ff605850a..11eff9dea23 100644 --- a/engine/access/mock/execution_api_server.go +++ b/engine/access/mock/execution_api_server.go @@ -170,6 +170,84 @@ func (_m *ExecutionAPIServer) GetRegisterAtBlockID(_a0 context.Context, _a1 *exe return r0, r1 } +// GetTransactionErrorMessage provides a mock function with given fields: _a0, _a1 +func (_m *ExecutionAPIServer) GetTransactionErrorMessage(_a0 context.Context, _a1 *execution.GetTransactionErrorMessageRequest) (*execution.GetTransactionErrorMessageResponse, error) { + ret := _m.Called(_a0, _a1) + + var r0 *execution.GetTransactionErrorMessageResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *execution.GetTransactionErrorMessageRequest) (*execution.GetTransactionErrorMessageResponse, error)); ok { + return rf(_a0, _a1) + } + if rf, ok := ret.Get(0).(func(context.Context, *execution.GetTransactionErrorMessageRequest) *execution.GetTransactionErrorMessageResponse); ok { + r0 = rf(_a0, _a1) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*execution.GetTransactionErrorMessageResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *execution.GetTransactionErrorMessageRequest) error); ok { + r1 = rf(_a0, _a1) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetTransactionErrorMessageByIndex provides a mock function with given fields: _a0, _a1 +func (_m *ExecutionAPIServer) GetTransactionErrorMessageByIndex(_a0 context.Context, _a1 *execution.GetTransactionErrorMessageByIndexRequest) (*execution.GetTransactionErrorMessageResponse, error) { + ret := _m.Called(_a0, _a1) + + var r0 *execution.GetTransactionErrorMessageResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *execution.GetTransactionErrorMessageByIndexRequest) (*execution.GetTransactionErrorMessageResponse, error)); ok { + return rf(_a0, _a1) + } + if rf, ok := ret.Get(0).(func(context.Context, *execution.GetTransactionErrorMessageByIndexRequest) *execution.GetTransactionErrorMessageResponse); ok { + r0 = rf(_a0, _a1) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*execution.GetTransactionErrorMessageResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *execution.GetTransactionErrorMessageByIndexRequest) error); ok { + r1 = rf(_a0, _a1) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetTransactionErrorMessagesByBlockID provides a mock function with given fields: _a0, _a1 +func (_m *ExecutionAPIServer) GetTransactionErrorMessagesByBlockID(_a0 context.Context, _a1 *execution.GetTransactionErrorMessagesByBlockIDRequest) (*execution.GetTransactionErrorMessagesResponse, error) { + ret := _m.Called(_a0, _a1) + + var r0 *execution.GetTransactionErrorMessagesResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, 
*execution.GetTransactionErrorMessagesByBlockIDRequest) (*execution.GetTransactionErrorMessagesResponse, error)); ok { + return rf(_a0, _a1) + } + if rf, ok := ret.Get(0).(func(context.Context, *execution.GetTransactionErrorMessagesByBlockIDRequest) *execution.GetTransactionErrorMessagesResponse); ok { + r0 = rf(_a0, _a1) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*execution.GetTransactionErrorMessagesResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *execution.GetTransactionErrorMessagesByBlockIDRequest) error); ok { + r1 = rf(_a0, _a1) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // GetTransactionResult provides a mock function with given fields: _a0, _a1 func (_m *ExecutionAPIServer) GetTransactionResult(_a0 context.Context, _a1 *execution.GetTransactionResultRequest) (*execution.GetTransactionResultResponse, error) { ret := _m.Called(_a0, _a1) diff --git a/engine/access/rest/routes/subscribe_events_test.go b/engine/access/rest/routes/subscribe_events_test.go index ec48ac0586b..0b5626c64b2 100644 --- a/engine/access/rest/routes/subscribe_events_test.go +++ b/engine/access/rest/routes/subscribe_events_test.go @@ -14,6 +14,8 @@ import ( "golang.org/x/exp/slices" + jsoncdc "github.com/onflow/cadence/encoding/json" + "github.com/onflow/flow/protobuf/go/flow/entities" mocks "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" @@ -24,6 +26,7 @@ import ( mockstatestream "github.com/onflow/flow-go/engine/access/state_stream/mock" "github.com/onflow/flow-go/model/flow" "github.com/onflow/flow-go/utils/unittest" + "github.com/onflow/flow-go/utils/unittest/generator" ) type testType struct { @@ -66,6 +69,9 @@ func (s *SubscribeEventsSuite) SetupTest() { s.blocks = make([]*flow.Block, 0, blockCount) s.blockEvents = make(map[flow.Identifier]flow.EventsList, blockCount) + // by default, events are in CCF encoding + eventsGenerator := generator.EventGenerator(generator.WithEncoding(entities.EventEncodingVersion_CCF_V0)) + for i := 0; i < blockCount; i++ { block := unittest.BlockWithParentFixture(parent) // update for next iteration @@ -74,6 +80,11 @@ func (s *SubscribeEventsSuite) SetupTest() { result := unittest.ExecutionResultFixture() blockEvents := unittest.BlockEventsFixture(block.Header, (i%len(testEventTypes))*3+1, testEventTypes...) + // update payloads with valid CCF encoded data + for i := range blockEvents.Events { + blockEvents.Events[i].Payload = eventsGenerator.New().Payload + } + s.blocks = append(s.blocks, block) s.blockEvents[block.ID()] = blockEvents.Events @@ -171,26 +182,35 @@ func (s *SubscribeEventsSuite) TestSubscribeEvents() { // construct expected event responses based on the provided test configuration for i, block := range s.blocks { - if startBlockFound || block.ID() == test.startBlockID { + blockID := block.ID() + if startBlockFound || blockID == test.startBlockID { startBlockFound = true if test.startHeight == request.EmptyHeight || block.Header.Height >= test.startHeight { - eventsForBlock := flow.EventsList{} - for _, event := range s.blockEvents[block.ID()] { + // track 2 lists, one for the expected results and one that is passed back + // from the subscription to the handler. 
These cannot be shared since the + // response struct is passed by reference from the mock to the handler, so + // a bug within the handler could go unnoticed + expectedEvents := flow.EventsList{} + subscriptionEvents := flow.EventsList{} + for _, event := range s.blockEvents[blockID] { if slices.Contains(test.eventTypes, string(event.Type)) || - len(test.eventTypes) == 0 { //Include all events - eventsForBlock = append(eventsForBlock, event) + len(test.eventTypes) == 0 { // Include all events + expectedEvents = append(expectedEvents, event) + subscriptionEvents = append(subscriptionEvents, event) } } - eventResponse := &backend.EventsResponse{ - Height: block.Header.Height, - BlockID: block.ID(), - Events: eventsForBlock, - } - - if len(eventsForBlock) > 0 || (i+1)%int(test.heartbeatInterval) == 0 { - expectedEventsResponses = append(expectedEventsResponses, eventResponse) + if len(expectedEvents) > 0 || (i+1)%int(test.heartbeatInterval) == 0 { + expectedEventsResponses = append(expectedEventsResponses, &backend.EventsResponse{ + Height: block.Header.Height, + BlockID: blockID, + Events: expectedEvents, + }) } - subscriptionEventsResponses = append(subscriptionEventsResponses, eventResponse) + subscriptionEventsResponses = append(subscriptionEventsResponses, &backend.EventsResponse{ + Height: block.Header.Height, + BlockID: blockID, + Events: subscriptionEvents, + }) } } } @@ -410,7 +430,11 @@ func requireResponse(t *testing.T, recorder *testHijackResponseRecorder, expecte require.Equal(t, expectedEvent.TransactionID, actualEvent.TransactionID) require.Equal(t, expectedEvent.TransactionIndex, actualEvent.TransactionIndex) require.Equal(t, expectedEvent.EventIndex, actualEvent.EventIndex) - require.Equal(t, expectedEvent.Payload, actualEvent.Payload) + // payload is not expected to match, but it should decode + + // payload must decode to valid json-cdc encoded data + _, err := jsoncdc.Decode(nil, actualEvent.Payload) + require.NoError(t, err) } } } diff --git a/engine/access/rest/routes/websocket_handler.go b/engine/access/rest/routes/websocket_handler.go index 063cc4ed5c4..221a18ea7b0 100644 --- a/engine/access/rest/routes/websocket_handler.go +++ b/engine/access/rest/routes/websocket_handler.go @@ -15,6 +15,7 @@ import ( "github.com/onflow/flow-go/engine/access/rest/request" "github.com/onflow/flow-go/engine/access/state_stream" "github.com/onflow/flow-go/engine/access/state_stream/backend" + "github.com/onflow/flow-go/engine/common/rpc/convert" "github.com/onflow/flow-go/model/flow" ) @@ -158,6 +159,18 @@ func (wsController *WebsocketController) writeEvents(sub state_stream.Subscripti blocksSinceLastMessage = 0 } + // EventsResponse contains CCF encoded events, and this API returns JSON-CDC events. + // convert event payload formats. 
+ for i, e := range resp.Events { + payload, err := convert.CcfPayloadToJsonPayload(e.Payload) + if err != nil { + err = fmt.Errorf("could not convert event payload from CCF to Json: %w", err) + wsController.wsErrorHandler(err) + return + } + resp.Events[i].Payload = payload + } + // Write the response to the WebSocket connection err = wsController.conn.WriteJSON(event) if err != nil { diff --git a/engine/access/rpc/backend/backend_scripts.go b/engine/access/rpc/backend/backend_scripts.go index 3d8591817d9..05398ccf9c2 100644 --- a/engine/access/rpc/backend/backend_scripts.go +++ b/engine/access/rpc/backend/backend_scripts.go @@ -125,9 +125,11 @@ func (b *backendScripts) executeScript( case ScriptExecutionModeFailover: localResult, localDuration, localErr := b.executeScriptLocally(ctx, scriptRequest) - if localErr == nil || isInvalidArgumentError(localErr) { + if localErr == nil || isInvalidArgumentError(localErr) || status.Code(localErr) == codes.Canceled { return localResult, localErr } + // Note: scripts that timeout are retried on the execution nodes since ANs may have performance + // issues for some scripts. execResult, execDuration, execErr := b.executeScriptOnAvailableExecutionNodes(ctx, scriptRequest) resultComparer := newScriptResultComparison(b.log, b.metrics, scriptRequest) @@ -185,11 +187,13 @@ func (b *backendScripts) executeScriptLocally( if err != nil { convertedErr := convertScriptExecutionError(err, r.height) - if status.Code(convertedErr) == codes.InvalidArgument { + switch status.Code(convertedErr) { + case codes.InvalidArgument, codes.Canceled, codes.DeadlineExceeded: lg.Debug().Err(err). Str("script", string(r.script)). Msg("script failed to execute locally") - } else { + + default: lg.Error().Err(err).Msg("script execution failed") b.metrics.ScriptExecutionErrorLocal() } @@ -332,8 +336,17 @@ func convertScriptExecutionError(err error, height uint64) error { return rpc.ConvertError(err, "failed to execute script", codes.Internal) } - // runtime errors - return status.Errorf(codes.InvalidArgument, "failed to execute script: %v", err) + switch coded.Code() { + case fvmerrors.ErrCodeScriptExecutionCancelledError: + return status.Errorf(codes.Canceled, "script execution canceled: %v", err) + + case fvmerrors.ErrCodeScriptExecutionTimedOutError: + return status.Errorf(codes.DeadlineExceeded, "script execution timed out: %v", err) + + default: + // runtime errors + return status.Errorf(codes.InvalidArgument, "failed to execute script: %v", err) + } } return convertIndexError(err, height, "failed to execute script") diff --git a/engine/access/rpc/backend/backend_scripts_test.go b/engine/access/rpc/backend/backend_scripts_test.go index 951adc9b50c..bb734cab657 100644 --- a/engine/access/rpc/backend/backend_scripts_test.go +++ b/engine/access/rpc/backend/backend_scripts_test.go @@ -34,6 +34,8 @@ var ( cadenceErr = fvmerrors.NewCodedError(fvmerrors.ErrCodeCadenceRunTimeError, "cadence error") fvmFailureErr = fvmerrors.NewCodedError(fvmerrors.FailureCodeBlockFinderFailure, "fvm error") + ctxCancelErr = fvmerrors.NewCodedError(fvmerrors.ErrCodeScriptExecutionCancelledError, "context canceled error") + timeoutErr = fvmerrors.NewCodedError(fvmerrors.ErrCodeScriptExecutionTimedOutError, "timeout error") ) // Create a suite similar to GetAccount that covers each of the modes @@ -319,31 +321,49 @@ func (s *BackendScriptsSuite) TestExecuteScriptWithFailover_HappyPath() { } } -// TestExecuteScriptWithFailover_SkippedForInvalidArgument tests that failover is skipped for -// FVM errors 
that result in InvalidArgument errors -func (s *BackendScriptsSuite) TestExecuteScriptWithFailover_SkippedForInvalidArgument() { +// TestExecuteScriptWithFailover_SkippedForCorrectCodes tests that failover is skipped for +// FVM errors that result in InvalidArgument or Canceled errors +func (s *BackendScriptsSuite) TestExecuteScriptWithFailover_SkippedForCorrectCodes() { ctx := context.Background() // configure local script executor to fail scriptExecutor := execmock.NewScriptExecutor(s.T()) - scriptExecutor.On("ExecuteAtBlockHeight", mock.Anything, s.failingScript, s.arguments, s.block.Header.Height). - Return(nil, cadenceErr) backend := s.defaultBackend() backend.scriptExecMode = ScriptExecutionModeFailover backend.scriptExecutor = scriptExecutor - s.Run("ExecuteScriptAtLatestBlock", func() { - s.testExecuteScriptAtLatestBlock(ctx, backend, codes.InvalidArgument) - }) + testCases := []struct { + err error + statusCode codes.Code + }{ + { + err: cadenceErr, + statusCode: codes.InvalidArgument, + }, + { + err: ctxCancelErr, + statusCode: codes.Canceled, + }, + } - s.Run("ExecuteScriptAtBlockID", func() { - s.testExecuteScriptAtBlockID(ctx, backend, codes.InvalidArgument) - }) + for _, tt := range testCases { + scriptExecutor.On("ExecuteAtBlockHeight", mock.Anything, s.failingScript, s.arguments, s.block.Header.Height). + Return(nil, tt.err). + Times(3) - s.Run("ExecuteScriptAtBlockHeight", func() { - s.testExecuteScriptAtBlockHeight(ctx, backend, codes.InvalidArgument) - }) + s.Run(fmt.Sprintf("ExecuteScriptAtLatestBlock - %s", tt.statusCode), func() { + s.testExecuteScriptAtLatestBlock(ctx, backend, tt.statusCode) + }) + + s.Run(fmt.Sprintf("ExecuteScriptAtBlockID - %s", tt.statusCode), func() { + s.testExecuteScriptAtBlockID(ctx, backend, tt.statusCode) + }) + + s.Run(fmt.Sprintf("ExecuteScriptAtBlockHeight - %s", tt.statusCode), func() { + s.testExecuteScriptAtBlockHeight(ctx, backend, tt.statusCode) + }) + } } // TestExecuteScriptWithFailover_ReturnsENErrors tests that when an error is returned from the execution diff --git a/engine/execution/checker/engine.go b/engine/execution/checker/engine.go index dcf330bd2c7..a1a96184105 100644 --- a/engine/execution/checker/engine.go +++ b/engine/execution/checker/engine.go @@ -82,7 +82,7 @@ func (e *Engine) checkLastSealed(finalizedID flow.Identifier) error { blockID := seal.BlockID sealedCommit := seal.FinalState - mycommit, err := e.execState.StateCommitmentByBlockID(e.unit.Ctx(), blockID) + mycommit, err := e.execState.StateCommitmentByBlockID(blockID) if errors.Is(err, storage.ErrNotFound) { // have not executed the sealed block yet // in other words, this can't detect execution fork, if the execution is behind diff --git a/engine/execution/computation/committer/committer.go b/engine/execution/computation/committer/committer.go index df2ebb035c5..86d72db1ead 100644 --- a/engine/execution/computation/committer/committer.go +++ b/engine/execution/computation/committer/committer.go @@ -6,6 +6,7 @@ import ( "github.com/hashicorp/go-multierror" + "github.com/onflow/flow-go/engine/execution" execState "github.com/onflow/flow-go/engine/execution/state" "github.com/onflow/flow-go/fvm/storage/snapshot" "github.com/onflow/flow-go/ledger" @@ -31,25 +32,26 @@ func NewLedgerViewCommitter( func (committer *LedgerViewCommitter) CommitView( snapshot *snapshot.ExecutionSnapshot, - baseState flow.StateCommitment, + baseStorageSnapshot execution.ExtendableStorageSnapshot, ) ( newCommit flow.StateCommitment, proof []byte, trieUpdate *ledger.TrieUpdate, + 
newStorageSnapshot execution.ExtendableStorageSnapshot, err error, ) { var err1, err2 error var wg sync.WaitGroup wg.Add(1) go func() { - proof, err2 = committer.collectProofs(snapshot, baseState) + proof, err2 = committer.collectProofs(snapshot, baseStorageSnapshot) wg.Done() }() - newCommit, trieUpdate, err1 = execState.CommitDelta( + newCommit, trieUpdate, newStorageSnapshot, err1 = execState.CommitDelta( committer.ledger, snapshot, - baseState) + baseStorageSnapshot) wg.Wait() if err1 != nil { @@ -63,11 +65,12 @@ func (committer *LedgerViewCommitter) CommitView( func (committer *LedgerViewCommitter) collectProofs( snapshot *snapshot.ExecutionSnapshot, - baseState flow.StateCommitment, + baseStorageSnapshot execution.ExtendableStorageSnapshot, ) ( proof []byte, err error, ) { + baseState := baseStorageSnapshot.Commitment() // Reason for including AllRegisterIDs (read and written registers) instead of ReadRegisterIDs (only read registers): // AllRegisterIDs returns deduplicated register IDs that were touched by both // reads and writes during the block execution. diff --git a/engine/execution/computation/committer/committer_test.go b/engine/execution/computation/committer/committer_test.go index 18657a67f13..b0f927c2807 100644 --- a/engine/execution/computation/committer/committer_test.go +++ b/engine/execution/computation/committer/committer_test.go @@ -1,48 +1,105 @@ package committer_test import ( + "fmt" "testing" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "github.com/onflow/flow-go/engine/execution/computation/committer" + "github.com/onflow/flow-go/engine/execution/storehouse" "github.com/onflow/flow-go/fvm/storage/snapshot" - led "github.com/onflow/flow-go/ledger" + "github.com/onflow/flow-go/ledger" + "github.com/onflow/flow-go/ledger/common/convert" + "github.com/onflow/flow-go/ledger/common/pathfinder" + "github.com/onflow/flow-go/ledger/complete" ledgermock "github.com/onflow/flow-go/ledger/mock" "github.com/onflow/flow-go/model/flow" "github.com/onflow/flow-go/module/trace" - utils "github.com/onflow/flow-go/utils/unittest" + "github.com/onflow/flow-go/utils/unittest" ) func TestLedgerViewCommitter(t *testing.T) { - t.Run("calls to set and prove", func(t *testing.T) { + // verify after committing a snapshot, proof will be generated, + // and changes are saved in storage snapshot + t.Run("CommitView should return proof and statecommitment", func(t *testing.T) { - ledger := new(ledgermock.Ledger) - com := committer.NewLedgerViewCommitter(ledger, trace.NewNoopTracer()) + l := ledgermock.NewLedger(t) + committer := committer.NewLedgerViewCommitter(l, trace.NewNoopTracer()) - var expectedStateCommitment led.State - copy(expectedStateCommitment[:], []byte{1, 2, 3}) - ledger.On("Set", mock.Anything). - Return(expectedStateCommitment, nil, nil). + // CommitDelta will call ledger.Set and ledger.Prove + + reg := unittest.MakeOwnerReg("key1", "val1") + startState := unittest.StateCommitmentFixture() + + update, err := ledger.NewUpdate(ledger.State(startState), []ledger.Key{convert.RegisterIDToLedgerKey(reg.Key)}, []ledger.Value{reg.Value}) + require.NoError(t, err) + + expectedTrieUpdate, err := pathfinder.UpdateToTrieUpdate(update, complete.DefaultPathFinderVersion) + require.NoError(t, err) + + endState := unittest.StateCommitmentFixture() + require.NotEqual(t, startState, endState) + + // mock ledger.Set + l.On("Set", mock.Anything). 
+ Return(func(update *ledger.Update) (newState ledger.State, trieUpdate *ledger.TrieUpdate, err error) { + if update.State().Equals(ledger.State(startState)) { + return ledger.State(endState), expectedTrieUpdate, nil + } + return ledger.DummyState, nil, fmt.Errorf("wrong update") + }). Once() - expectedProof := led.Proof([]byte{2, 3, 4}) - ledger.On("Prove", mock.Anything). - Return(expectedProof, nil). + // mock ledger.Prove + expectedProof := ledger.Proof([]byte{2, 3, 4}) + l.On("Prove", mock.Anything). + Return(func(query *ledger.Query) (proof ledger.Proof, err error) { + if query.Size() != 1 { + return nil, fmt.Errorf("wrong query size: %v", query.Size()) + } + + k := convert.RegisterIDToLedgerKey(reg.Key) + if !query.Keys()[0].Equals(&k) { + return nil, fmt.Errorf("in correct query key for prove: %v", query.Keys()[0]) + } + + return expectedProof, nil + }). Once() - newState, proof, _, err := com.CommitView( - &snapshot.ExecutionSnapshot{ - WriteSet: map[flow.RegisterID]flow.RegisterValue{ - flow.NewRegisterID("owner", "key"): []byte{1}, - }, + // previous block's storage snapshot + oldReg := unittest.MakeOwnerReg("key1", "oldvalue") + previousBlockSnapshot := storehouse.NewExecutingBlockSnapshot( + snapshot.MapStorageSnapshot{ + oldReg.Key: oldReg.Value, + }, + flow.StateCommitment(update.State()), + ) + + // this block's register updates + blockUpdates := &snapshot.ExecutionSnapshot{ + WriteSet: map[flow.RegisterID]flow.RegisterValue{ + reg.Key: oldReg.Value, }, - utils.StateCommitmentFixture()) + } + + newCommit, proof, trieUpdate, newStorageSnapshot, err := committer.CommitView( + blockUpdates, + previousBlockSnapshot, + ) + require.NoError(t, err) - require.Equal(t, flow.StateCommitment(expectedStateCommitment), newState) + + // verify CommitView returns expected proof and statecommitment + require.Equal(t, previousBlockSnapshot.Commitment(), flow.StateCommitment(trieUpdate.RootHash)) + require.Equal(t, newCommit, newStorageSnapshot.Commitment()) + require.Equal(t, endState, newCommit) require.Equal(t, []uint8(expectedProof), proof) + require.True(t, expectedTrieUpdate.Equals(trieUpdate)) + }) } diff --git a/engine/execution/computation/committer/noop.go b/engine/execution/computation/committer/noop.go index dcdefbac634..b4549a78c15 100644 --- a/engine/execution/computation/committer/noop.go +++ b/engine/execution/computation/committer/noop.go @@ -1,6 +1,7 @@ package committer import ( + "github.com/onflow/flow-go/engine/execution" "github.com/onflow/flow-go/fvm/storage/snapshot" "github.com/onflow/flow-go/ledger" "github.com/onflow/flow-go/model/flow" @@ -15,12 +16,17 @@ func NewNoopViewCommitter() *NoopViewCommitter { func (NoopViewCommitter) CommitView( _ *snapshot.ExecutionSnapshot, - s flow.StateCommitment, + baseStorageSnapshot execution.ExtendableStorageSnapshot, ) ( flow.StateCommitment, []byte, *ledger.TrieUpdate, + execution.ExtendableStorageSnapshot, error, ) { - return s, nil, nil, nil + + trieUpdate := &ledger.TrieUpdate{ + RootHash: ledger.RootHash(baseStorageSnapshot.Commitment()), + } + return baseStorageSnapshot.Commitment(), []byte{}, trieUpdate, baseStorageSnapshot, nil } diff --git a/engine/execution/computation/computer/computer.go b/engine/execution/computation/computer/computer.go index 6049ed3d3e8..25345aba997 100644 --- a/engine/execution/computation/computer/computer.go +++ b/engine/execution/computation/computer/computer.go @@ -346,7 +346,9 @@ func (e *blockComputer) executeBlock( parentBlockExecutionResultID, block, numTxns, - e.colResCons) + 
e.colResCons, + baseSnapshot, + ) defer collector.Stop() requestQueue := make(chan TransactionRequest, numTxns) diff --git a/engine/execution/computation/computer/computer_test.go b/engine/execution/computation/computer/computer_test.go index 4fc768809b6..286cfa588f6 100644 --- a/engine/execution/computation/computer/computer_test.go +++ b/engine/execution/computation/computer/computer_test.go @@ -27,6 +27,7 @@ import ( "github.com/onflow/flow-go/engine/execution/computation/committer" "github.com/onflow/flow-go/engine/execution/computation/computer" computermock "github.com/onflow/flow-go/engine/execution/computation/computer/mock" + "github.com/onflow/flow-go/engine/execution/storehouse" "github.com/onflow/flow-go/engine/execution/testutil" "github.com/onflow/flow-go/fvm" "github.com/onflow/flow-go/fvm/environment" @@ -40,6 +41,9 @@ import ( "github.com/onflow/flow-go/fvm/storage/state" "github.com/onflow/flow-go/fvm/systemcontracts" "github.com/onflow/flow-go/ledger" + "github.com/onflow/flow-go/ledger/common/convert" + "github.com/onflow/flow-go/ledger/common/pathfinder" + "github.com/onflow/flow-go/ledger/complete" "github.com/onflow/flow-go/model/flow" "github.com/onflow/flow-go/module/epochs" "github.com/onflow/flow-go/module/executiondatasync/execution_data" @@ -69,22 +73,46 @@ type fakeCommitter struct { func (committer *fakeCommitter) CommitView( view *snapshot.ExecutionSnapshot, - startState flow.StateCommitment, + baseStorageSnapshot execution.ExtendableStorageSnapshot, ) ( flow.StateCommitment, []byte, *ledger.TrieUpdate, + execution.ExtendableStorageSnapshot, error, ) { committer.callCount++ + startState := baseStorageSnapshot.Commitment() endState := incStateCommitment(startState) - trieUpdate := &ledger.TrieUpdate{} - trieUpdate.RootHash[0] = byte(committer.callCount) - return endState, + reg := unittest.MakeOwnerReg("key", fmt.Sprintf("%v", committer.callCount)) + regKey := convert.RegisterIDToLedgerKey(reg.Key) + path, err := pathfinder.KeyToPath( + regKey, + complete.DefaultPathFinderVersion, + ) + if err != nil { + return flow.DummyStateCommitment, nil, nil, nil, err + } + trieUpdate := &ledger.TrieUpdate{ + RootHash: ledger.RootHash(startState), + Paths: []ledger.Path{ + path, + }, + Payloads: []*ledger.Payload{ + ledger.NewPayload(regKey, reg.Value), + }, + } + + newStorageSnapshot := baseStorageSnapshot.Extend(endState, map[flow.RegisterID]flow.RegisterValue{ + reg.Key: reg.Value, + }) + + return newStorageSnapshot.Commitment(), []byte{byte(committer.callCount)}, trieUpdate, + newStorageSnapshot, nil } @@ -269,12 +297,12 @@ func TestBlockExecutor_ExecuteBlock(t *testing.T) { chunkDataPack1.Collection, chunkExecutionData1.Collection) assert.NotNil(t, chunkExecutionData1.TrieUpdate) - assert.Equal(t, byte(1), chunkExecutionData1.TrieUpdate.RootHash[0]) + assert.Equal(t, ledger.RootHash(chunk1.StartState), chunkExecutionData1.TrieUpdate.RootHash) chunkExecutionData2 := result.ChunkExecutionDatas[1] assert.NotNil(t, chunkExecutionData2.Collection) assert.NotNil(t, chunkExecutionData2.TrieUpdate) - assert.Equal(t, byte(2), chunkExecutionData2.TrieUpdate.RootHash[0]) + assert.Equal(t, ledger.RootHash(chunk2.StartState), chunkExecutionData2.TrieUpdate.RootHash) assert.GreaterOrEqual(t, vm.CallCount(), 3) // if every transaction is retried once, then the call count should be @@ -322,8 +350,13 @@ func TestBlockExecutor_ExecuteBlock(t *testing.T) { Return(noOpExecutor{}). 
Once() // just system chunk + snapshot := storehouse.NewExecutingBlockSnapshot( + snapshot.MapStorageSnapshot{}, + unittest.StateCommitmentFixture(), + ) + committer.On("CommitView", mock.Anything, mock.Anything). - Return(nil, nil, nil, nil). + Return(nil, nil, nil, snapshot, nil). Once() // just system chunk result, err := exe.ExecuteBlock( @@ -415,8 +448,13 @@ func TestBlockExecutor_ExecuteBlock(t *testing.T) { // create an empty block block := generateBlock(0, 0, rag) + snapshot := storehouse.NewExecutingBlockSnapshot( + snapshot.MapStorageSnapshot{}, + unittest.StateCommitmentFixture(), + ) + comm.On("CommitView", mock.Anything, mock.Anything). - Return(nil, nil, nil, nil). + Return(nil, nil, nil, snapshot, nil). Once() // just system chunk result, err := exe.ExecuteBlock( @@ -482,8 +520,13 @@ func TestBlockExecutor_ExecuteBlock(t *testing.T) { block := generateBlock(collectionCount, transactionsPerCollection, rag) derivedBlockData := derived.NewEmptyDerivedBlockData(0) + snapshot := storehouse.NewExecutingBlockSnapshot( + snapshot.MapStorageSnapshot{}, + unittest.StateCommitmentFixture(), + ) + committer.On("CommitView", mock.Anything, mock.Anything). - Return(nil, nil, nil, nil). + Return(nil, nil, nil, snapshot, nil). Times(collectionCount + 1) result, err := exe.ExecuteBlock( @@ -600,17 +643,17 @@ func TestBlockExecutor_ExecuteBlock(t *testing.T) { // events to emit for each iteration/transaction events := map[common.Location][]cadence.Event{ common.TransactionLocation(transactions[0].ID()): nil, - common.TransactionLocation(transactions[1].ID()): []cadence.Event{ + common.TransactionLocation(transactions[1].ID()): { serviceEventA, - cadence.Event{ + { EventType: &cadence.EventType{ Location: stdlib.FlowLocation{}, QualifiedIdentifier: "what.ever", }, }, }, - common.TransactionLocation(transactions[2].ID()): []cadence.Event{ - cadence.Event{ + common.TransactionLocation(transactions[2].ID()): { + { EventType: &cadence.EventType{ Location: stdlib.FlowLocation{}, QualifiedIdentifier: "what.ever", @@ -965,8 +1008,13 @@ func TestBlockExecutor_ExecuteBlock(t *testing.T) { transactionsPerCollection := 3 block := generateBlock(collectionCount, transactionsPerCollection, rag) + snapshot := storehouse.NewExecutingBlockSnapshot( + snapshot.MapStorageSnapshot{}, + unittest.StateCommitmentFixture(), + ) + committer.On("CommitView", mock.Anything, mock.Anything). - Return(nil, nil, nil, nil). + Return(nil, nil, nil, snapshot, nil). Times(collectionCount + 1) _, err = exe.ExecuteBlock( @@ -1196,8 +1244,13 @@ func Test_ExecutingSystemCollection(t *testing.T) { ledger := testutil.RootBootstrappedLedger(vm, execCtx) committer := new(computermock.ViewCommitter) + snapshot := storehouse.NewExecutingBlockSnapshot( + snapshot.MapStorageSnapshot{}, + unittest.StateCommitmentFixture(), + ) + committer.On("CommitView", mock.Anything, mock.Anything). - Return(nil, nil, nil, nil). + Return(nil, nil, nil, snapshot, nil). 
Times(1) // only system chunk noopCollector := metrics.NewNoopCollector() diff --git a/engine/execution/computation/computer/mock/view_committer.go b/engine/execution/computation/computer/mock/view_committer.go index dfcacb97c83..5b635f9804a 100644 --- a/engine/execution/computation/computer/mock/view_committer.go +++ b/engine/execution/computation/computer/mock/view_committer.go @@ -3,9 +3,11 @@ package mock import ( - ledger "github.com/onflow/flow-go/ledger" + execution "github.com/onflow/flow-go/engine/execution" flow "github.com/onflow/flow-go/model/flow" + ledger "github.com/onflow/flow-go/ledger" + mock "github.com/stretchr/testify/mock" snapshot "github.com/onflow/flow-go/fvm/storage/snapshot" @@ -17,17 +19,18 @@ type ViewCommitter struct { } // CommitView provides a mock function with given fields: _a0, _a1 -func (_m *ViewCommitter) CommitView(_a0 *snapshot.ExecutionSnapshot, _a1 flow.StateCommitment) (flow.StateCommitment, []byte, *ledger.TrieUpdate, error) { +func (_m *ViewCommitter) CommitView(_a0 *snapshot.ExecutionSnapshot, _a1 execution.ExtendableStorageSnapshot) (flow.StateCommitment, []byte, *ledger.TrieUpdate, execution.ExtendableStorageSnapshot, error) { ret := _m.Called(_a0, _a1) var r0 flow.StateCommitment var r1 []byte var r2 *ledger.TrieUpdate - var r3 error - if rf, ok := ret.Get(0).(func(*snapshot.ExecutionSnapshot, flow.StateCommitment) (flow.StateCommitment, []byte, *ledger.TrieUpdate, error)); ok { + var r3 execution.ExtendableStorageSnapshot + var r4 error + if rf, ok := ret.Get(0).(func(*snapshot.ExecutionSnapshot, execution.ExtendableStorageSnapshot) (flow.StateCommitment, []byte, *ledger.TrieUpdate, execution.ExtendableStorageSnapshot, error)); ok { return rf(_a0, _a1) } - if rf, ok := ret.Get(0).(func(*snapshot.ExecutionSnapshot, flow.StateCommitment) flow.StateCommitment); ok { + if rf, ok := ret.Get(0).(func(*snapshot.ExecutionSnapshot, execution.ExtendableStorageSnapshot) flow.StateCommitment); ok { r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { @@ -35,7 +38,7 @@ func (_m *ViewCommitter) CommitView(_a0 *snapshot.ExecutionSnapshot, _a1 flow.St } } - if rf, ok := ret.Get(1).(func(*snapshot.ExecutionSnapshot, flow.StateCommitment) []byte); ok { + if rf, ok := ret.Get(1).(func(*snapshot.ExecutionSnapshot, execution.ExtendableStorageSnapshot) []byte); ok { r1 = rf(_a0, _a1) } else { if ret.Get(1) != nil { @@ -43,7 +46,7 @@ func (_m *ViewCommitter) CommitView(_a0 *snapshot.ExecutionSnapshot, _a1 flow.St } } - if rf, ok := ret.Get(2).(func(*snapshot.ExecutionSnapshot, flow.StateCommitment) *ledger.TrieUpdate); ok { + if rf, ok := ret.Get(2).(func(*snapshot.ExecutionSnapshot, execution.ExtendableStorageSnapshot) *ledger.TrieUpdate); ok { r2 = rf(_a0, _a1) } else { if ret.Get(2) != nil { @@ -51,13 +54,21 @@ func (_m *ViewCommitter) CommitView(_a0 *snapshot.ExecutionSnapshot, _a1 flow.St } } - if rf, ok := ret.Get(3).(func(*snapshot.ExecutionSnapshot, flow.StateCommitment) error); ok { + if rf, ok := ret.Get(3).(func(*snapshot.ExecutionSnapshot, execution.ExtendableStorageSnapshot) execution.ExtendableStorageSnapshot); ok { r3 = rf(_a0, _a1) } else { - r3 = ret.Error(3) + if ret.Get(3) != nil { + r3 = ret.Get(3).(execution.ExtendableStorageSnapshot) + } + } + + if rf, ok := ret.Get(4).(func(*snapshot.ExecutionSnapshot, execution.ExtendableStorageSnapshot) error); ok { + r4 = rf(_a0, _a1) + } else { + r4 = ret.Error(4) } - return r0, r1, r2, r3 + return r0, r1, r2, r3, r4 } type mockConstructorTestingTNewViewCommitter interface { diff --git 
a/engine/execution/computation/computer/result_collector.go b/engine/execution/computation/computer/result_collector.go index 37ef4540748..4b367fda739 100644 --- a/engine/execution/computation/computer/result_collector.go +++ b/engine/execution/computation/computer/result_collector.go @@ -12,6 +12,7 @@ import ( "github.com/onflow/flow-go/crypto/hash" "github.com/onflow/flow-go/engine/execution" "github.com/onflow/flow-go/engine/execution/computation/result" + "github.com/onflow/flow-go/engine/execution/storehouse" "github.com/onflow/flow-go/fvm" "github.com/onflow/flow-go/fvm/meter" "github.com/onflow/flow-go/fvm/storage/snapshot" @@ -31,11 +32,12 @@ type ViewCommitter interface { // CommitView commits an execution snapshot and collects proofs CommitView( *snapshot.ExecutionSnapshot, - flow.StateCommitment, + execution.ExtendableStorageSnapshot, ) ( - flow.StateCommitment, + flow.StateCommitment, // TODO(leo): deprecate. see storehouse.ExtendableStorageSnapshot.Commitment() []byte, *ledger.TrieUpdate, + execution.ExtendableStorageSnapshot, error, ) } @@ -79,9 +81,10 @@ type resultCollector struct { blockStats module.ExecutionResultStats blockMeter *meter.Meter - currentCollectionStartTime time.Time - currentCollectionState *state.ExecutionState - currentCollectionStats module.ExecutionResultStats + currentCollectionStartTime time.Time + currentCollectionState *state.ExecutionState + currentCollectionStats module.ExecutionResultStats + currentCollectionStorageSnapshot execution.ExtendableStorageSnapshot } func newResultCollector( @@ -97,6 +100,7 @@ func newResultCollector( block *entity.ExecutableBlock, numTransactions int, consumers []result.ExecutedCollectionConsumer, + previousBlockSnapshot snapshot.StorageSnapshot, ) *resultCollector { numCollections := len(block.Collections()) + 1 now := time.Now() @@ -122,6 +126,10 @@ func newResultCollector( currentCollectionStats: module.ExecutionResultStats{ NumberOfCollections: 1, }, + currentCollectionStorageSnapshot: storehouse.NewExecutingBlockSnapshot( + previousBlockSnapshot, + *block.StartState, + ), } go collector.runResultProcessor() @@ -138,14 +146,19 @@ func (collector *resultCollector) commitCollection( collector.blockSpan, trace.EXECommitDelta).End() - startState := collector.result.CurrentEndState() - endState, proof, trieUpdate, err := collector.committer.CommitView( + startState := collector.currentCollectionStorageSnapshot.Commitment() + + _, proof, trieUpdate, newSnapshot, err := collector.committer.CommitView( collectionExecutionSnapshot, - startState) + collector.currentCollectionStorageSnapshot, + ) if err != nil { return fmt.Errorf("commit view failed: %w", err) } + endState := newSnapshot.Commitment() + collector.currentCollectionStorageSnapshot = newSnapshot + execColRes := collector.result.CollectionExecutionResultAt(collection.collectionIndex) execColRes.UpdateExecutionSnapshot(collectionExecutionSnapshot) diff --git a/engine/execution/computation/query/executor.go b/engine/execution/computation/query/executor.go index b129a4d3609..104fa2a9e77 100644 --- a/engine/execution/computation/query/executor.go +++ b/engine/execution/computation/query/executor.go @@ -54,6 +54,7 @@ type Executor interface { type QueryConfig struct { LogTimeThreshold time.Duration ExecutionTimeLimit time.Duration + ComputationLimit uint64 MaxErrorMessageSize int } @@ -61,6 +62,7 @@ func NewDefaultConfig() QueryConfig { return QueryConfig{ LogTimeThreshold: DefaultLogTimeThreshold, ExecutionTimeLimit: DefaultExecutionTimeLimit, + ComputationLimit: 
fvm.DefaultComputationLimit, MaxErrorMessageSize: DefaultMaxErrorMessageSize, } } @@ -87,6 +89,9 @@ func NewQueryExecutor( derivedChainData *derived.DerivedChainData, entropyPerBlock EntropyProviderPerBlock, ) *QueryExecutor { + if config.ComputationLimit > 0 { + vmCtx = fvm.NewContextFromParent(vmCtx, fvm.WithComputationLimit(config.ComputationLimit)) + } return &QueryExecutor{ config: config, logger: logger, diff --git a/engine/execution/execution_test.go b/engine/execution/execution_test.go index f52796f85e5..3d208a9a112 100644 --- a/engine/execution/execution_test.go +++ b/engine/execution/execution_test.go @@ -437,10 +437,10 @@ func TestFailedTxWillNotChangeStateCommitment(t *testing.T) { }) exe1Node.AssertHighestExecutedBlock(t, block1.Header) - scExe1Genesis, err := exe1Node.ExecutionState.StateCommitmentByBlockID(context.Background(), genesis.ID()) + scExe1Genesis, err := exe1Node.ExecutionState.StateCommitmentByBlockID(genesis.ID()) assert.NoError(t, err) - scExe1Block1, err := exe1Node.ExecutionState.StateCommitmentByBlockID(context.Background(), block1.ID()) + scExe1Block1, err := exe1Node.ExecutionState.StateCommitmentByBlockID(block1.ID()) assert.NoError(t, err) assert.NotEqual(t, scExe1Genesis, scExe1Block1) @@ -461,7 +461,7 @@ func TestFailedTxWillNotChangeStateCommitment(t *testing.T) { // exe2Node.AssertHighestExecutedBlock(t, block3.Header) // verify state commitment of block 2 is the same as block 1, since tx failed on seq number verification - scExe1Block2, err := exe1Node.ExecutionState.StateCommitmentByBlockID(context.Background(), block2.ID()) + scExe1Block2, err := exe1Node.ExecutionState.StateCommitmentByBlockID(block2.ID()) assert.NoError(t, err) // TODO this is no longer valid because the system chunk can change the state //assert.Equal(t, scExe1Block1, scExe1Block2) diff --git a/engine/execution/ingestion/engine.go b/engine/execution/ingestion/engine.go index bf8b67ad85c..729124181f4 100644 --- a/engine/execution/ingestion/engine.go +++ b/engine/execution/ingestion/engine.go @@ -284,7 +284,7 @@ func (e *Engine) handleBlock(ctx context.Context, block *flow.Block) error { span, _ := e.tracer.StartBlockSpan(ctx, blockID, trace.EXEHandleBlock) defer span.End() - executed, err := state.IsBlockExecuted(e.unit.Ctx(), e.execState, blockID) + executed, err := e.execState.IsBlockExecuted(block.Header.Height, blockID) if err != nil { return fmt.Errorf("could not check whether block is executed: %w", err) } @@ -357,7 +357,7 @@ func (e *Engine) enqueueBlockAndCheckExecutable( // check if the block's parent has been executed. (we can't execute the block if the parent has // not been executed yet) // check if there is a statecommitment for the parent block - parentCommitment, err := e.execState.StateCommitmentByBlockID(e.unit.Ctx(), block.Header.ParentID) + parentCommitment, err := e.execState.StateCommitmentByBlockID(block.Header.ParentID) // if we found the statecommitment for the parent block, then add it to the executable block. 
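Reviewer note: the `NewQueryExecutor` hunk above only pushes the new `ComputationLimit` setting into the FVM context when it was explicitly set, so a zero value keeps whatever limit the parent context already carries. A minimal sketch of that wiring; the helper name `applyComputationLimit` is illustrative, the production code inlines the check.

```go
package query

import "github.com/onflow/flow-go/fvm"

// applyComputationLimit overrides the computation limit of vmCtx only when the
// config sets a non-zero value, mirroring the NewQueryExecutor hunk above.
func applyComputationLimit(vmCtx fvm.Context, config QueryConfig) fvm.Context {
	if config.ComputationLimit > 0 {
		vmCtx = fvm.NewContextFromParent(vmCtx, fvm.WithComputationLimit(config.ComputationLimit))
	}
	return vmCtx
}
```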
if err == nil { @@ -429,7 +429,10 @@ func (e *Engine) executeBlock( return } - snapshot := e.execState.NewStorageSnapshot(*executableBlock.StartState) + snapshot := e.execState.NewStorageSnapshot(*executableBlock.StartState, + executableBlock.ID(), + executableBlock.Block.Header.Height, + ) computationResult, err := e.computationManager.ComputeBlock( ctx, diff --git a/engine/execution/ingestion/engine_test.go b/engine/execution/ingestion/engine_test.go index 7edd78a8351..f66e7d00f50 100644 --- a/engine/execution/ingestion/engine_test.go +++ b/engine/execution/ingestion/engine_test.go @@ -13,6 +13,7 @@ import ( "github.com/stretchr/testify/require" "github.com/onflow/flow-go/crypto" + "github.com/onflow/flow-go/fvm/storage/snapshot" enginePkg "github.com/onflow/flow-go/engine" "github.com/onflow/flow-go/engine/execution" @@ -112,7 +113,7 @@ func runWithEngine(t *testing.T, f func(testingContext)) { uploadMgr := uploader.NewManager(trace.NewNoopTracer()) fetcher := mocks.NewMockFetcher() - loader := loader.NewLoader(log, protocolState, headers, executionState) + loader := loader.NewUnexecutedLoader(log, protocolState, headers, executionState) engine, err = New( unit, @@ -171,9 +172,10 @@ func TestExecuteOneBlock(t *testing.T) { blockA := makeBlockWithCollection(store.RootBlock, &col) result := store.CreateBlockAndMockResult(t, blockA) + ctx.mockIsBlockExecuted(store) ctx.mockStateCommitmentByBlockID(store) ctx.mockGetExecutionResultID(store) - ctx.executionState.On("NewStorageSnapshot", mock.Anything).Return(nil) + ctx.executionState.On("NewStorageSnapshot", mock.Anything, mock.Anything, mock.Anything).Return(nil) // receive block err := ctx.engine.handleBlock(context.Background(), blockA.Block) @@ -226,9 +228,10 @@ func TestExecuteBlocks(t *testing.T) { store.CreateBlockAndMockResult(t, blockA) store.CreateBlockAndMockResult(t, blockB) + ctx.mockIsBlockExecuted(store) ctx.mockStateCommitmentByBlockID(store) ctx.mockGetExecutionResultID(store) - ctx.executionState.On("NewStorageSnapshot", mock.Anything).Return(nil) + ctx.executionState.On("NewStorageSnapshot", mock.Anything, mock.Anything, mock.Anything).Return(nil) ctx.providerEngine.On("BroadcastExecutionReceipt", mock.Anything, mock.Anything, mock.Anything).Return(false, nil) // receive block @@ -275,9 +278,10 @@ func TestExecuteNextBlockIfCollectionIsReady(t *testing.T) { // C2 is available in storage require.NoError(t, ctx.collections.Store(&col2)) + ctx.mockIsBlockExecuted(store) ctx.mockStateCommitmentByBlockID(store) ctx.mockGetExecutionResultID(store) - ctx.executionState.On("NewStorageSnapshot", mock.Anything).Return(nil) + ctx.executionState.On("NewStorageSnapshot", mock.Anything, mock.Anything, mock.Anything).Return(nil) // receiving block A and B will not trigger any execution // because A is missing collection C1, B is waiting for A to be executed @@ -319,9 +323,10 @@ func TestExecuteBlockOnlyOnce(t *testing.T) { blockA := makeBlockWithCollection(store.RootBlock, &col) store.CreateBlockAndMockResult(t, blockA) + ctx.mockIsBlockExecuted(store) ctx.mockStateCommitmentByBlockID(store) ctx.mockGetExecutionResultID(store) - ctx.executionState.On("NewStorageSnapshot", mock.Anything).Return(nil) + ctx.executionState.On("NewStorageSnapshot", mock.Anything, mock.Anything, mock.Anything).Return(nil) // receive block err := ctx.engine.handleBlock(context.Background(), blockA.Block) @@ -375,9 +380,10 @@ func TestExecuteForkConcurrently(t *testing.T) { store.CreateBlockAndMockResult(t, blockA) store.CreateBlockAndMockResult(t, blockB) + 
ctx.mockIsBlockExecuted(store) ctx.mockStateCommitmentByBlockID(store) ctx.mockGetExecutionResultID(store) - ctx.executionState.On("NewStorageSnapshot", mock.Anything).Return(nil) + ctx.executionState.On("NewStorageSnapshot", mock.Anything, mock.Anything, mock.Anything).Return(nil) // receive blocks err := ctx.engine.handleBlock(context.Background(), blockA.Block) @@ -425,9 +431,10 @@ func TestExecuteBlockInOrder(t *testing.T) { store.CreateBlockAndMockResult(t, blockB) store.CreateBlockAndMockResult(t, blockC) + ctx.mockIsBlockExecuted(store) ctx.mockStateCommitmentByBlockID(store) ctx.mockGetExecutionResultID(store) - ctx.executionState.On("NewStorageSnapshot", mock.Anything).Return(nil) + ctx.executionState.On("NewStorageSnapshot", mock.Anything, mock.Anything, mock.Anything).Return(nil) // receive blocks err := ctx.engine.handleBlock(context.Background(), blockA.Block) @@ -495,9 +502,10 @@ func TestStopAtHeightWhenFinalizedBeforeExecuted(t *testing.T) { }) require.NoError(t, err) + ctx.mockIsBlockExecuted(store) ctx.mockStateCommitmentByBlockID(store) ctx.mockGetExecutionResultID(store) - ctx.executionState.On("NewStorageSnapshot", mock.Anything).Return(nil) + ctx.executionState.On("NewStorageSnapshot", mock.Anything, mock.Anything, mock.Anything).Return(nil) // receive blocks err = ctx.engine.handleBlock(context.Background(), blockA.Block) @@ -562,9 +570,10 @@ func TestStopAtHeightWhenExecutedBeforeFinalized(t *testing.T) { }) require.NoError(t, err) + ctx.mockIsBlockExecuted(store) ctx.mockStateCommitmentByBlockID(store) ctx.mockGetExecutionResultID(store) - ctx.executionState.On("NewStorageSnapshot", mock.Anything).Return(nil) + ctx.executionState.On("NewStorageSnapshot", mock.Anything, mock.Anything, mock.Anything).Return(nil) ctx.providerEngine.On("BroadcastExecutionReceipt", mock.Anything, mock.Anything, mock.Anything).Return(false, nil) ctx.mockComputeBlock(store) @@ -623,9 +632,10 @@ func TestStopAtHeightWhenExecutionFinalization(t *testing.T) { }) require.NoError(t, err) + ctx.mockIsBlockExecuted(store) ctx.mockStateCommitmentByBlockID(store) ctx.mockGetExecutionResultID(store) - ctx.executionState.On("NewStorageSnapshot", mock.Anything).Return(nil) + ctx.executionState.On("NewStorageSnapshot", mock.Anything, mock.Anything, mock.Anything).Return(nil) ctx.providerEngine.On("BroadcastExecutionReceipt", mock.Anything, mock.Anything, mock.Anything).Return(false, nil) ctx.mockComputeBlock(store) @@ -678,9 +688,10 @@ func TestExecutedBlockUploadedFailureDoesntBlock(t *testing.T) { blockA := makeBlockWithCollection(store.RootBlock, &col) result := store.CreateBlockAndMockResult(t, blockA) + ctx.mockIsBlockExecuted(store) ctx.mockStateCommitmentByBlockID(store) ctx.mockGetExecutionResultID(store) - ctx.executionState.On("NewStorageSnapshot", mock.Anything).Return(nil) + ctx.executionState.On("NewStorageSnapshot", mock.Anything, mock.Anything, mock.Anything).Return(nil) // receive block err := ctx.engine.handleBlock(context.Background(), blockA.Block) @@ -741,63 +752,63 @@ func makeBlockWithCollection(parent *flow.Header, cols ...*flow.Collection) *ent return executableBlock } +func (ctx *testingContext) mockIsBlockExecuted(store *mocks.MockBlockStore) { + ctx.executionState.On("IsBlockExecuted", mock.Anything, mock.Anything). 
+ Return(func(height uint64, blockID flow.Identifier) (bool, error) { + _, err := store.GetExecuted(blockID) + if err != nil { + return false, nil + } + return true, nil + }) +} + func (ctx *testingContext) mockStateCommitmentByBlockID(store *mocks.MockBlockStore) { - mocked := ctx.executionState.On("StateCommitmentByBlockID", mock.Anything, mock.Anything) - // https://github.com/stretchr/testify/issues/350#issuecomment-570478958 - mocked.RunFn = func(args mock.Arguments) { - blockID := args[1].(flow.Identifier) - result, err := store.GetExecuted(blockID) - if err != nil { - mocked.ReturnArguments = mock.Arguments{flow.StateCommitment{}, storageerr.ErrNotFound} - return - } - mocked.ReturnArguments = mock.Arguments{result.Result.CurrentEndState(), nil} - } + ctx.executionState.On("StateCommitmentByBlockID", mock.Anything). + Return(func(blockID flow.Identifier) (flow.StateCommitment, error) { + result, err := store.GetExecuted(blockID) + if err != nil { + return flow.StateCommitment{}, storageerr.ErrNotFound + } + return result.Result.CurrentEndState(), nil + }) } func (ctx *testingContext) mockGetExecutionResultID(store *mocks.MockBlockStore) { - - mocked := ctx.executionState.On("GetExecutionResultID", mock.Anything, mock.Anything) - mocked.RunFn = func(args mock.Arguments) { - blockID := args[1].(flow.Identifier) - blockResult, err := store.GetExecuted(blockID) - if err != nil { - mocked.ReturnArguments = mock.Arguments{nil, storageerr.ErrNotFound} - return - } - - mocked.ReturnArguments = mock.Arguments{ - blockResult.Result.ExecutionReceipt.ExecutionResult.ID(), nil} - } + ctx.executionState.On("GetExecutionResultID", mock.Anything, mock.Anything). + Return(func(ctx context.Context, blockID flow.Identifier) (flow.Identifier, error) { + blockResult, err := store.GetExecuted(blockID) + if err != nil { + return flow.ZeroID, storageerr.ErrNotFound + } + + return blockResult.Result.ExecutionReceipt.ExecutionResult.ID(), nil + }) } func (ctx *testingContext) mockComputeBlock(store *mocks.MockBlockStore) { - mocked := ctx.computationManager.On("ComputeBlock", mock.Anything, mock.Anything, mock.Anything, mock.Anything) - mocked.RunFn = func(args mock.Arguments) { - block := args[2].(*entity.ExecutableBlock) - blockResult, ok := store.ResultByBlock[block.ID()] - if !ok { - mocked.ReturnArguments = mock.Arguments{nil, fmt.Errorf("block %s not found", block.ID())} - return - } - mocked.ReturnArguments = mock.Arguments{blockResult.Result, nil} - } + ctx.computationManager.On("ComputeBlock", mock.Anything, mock.Anything, mock.Anything, mock.Anything). + Return(func(ctx context.Context, + parentBlockExecutionResultID flow.Identifier, + block *entity.ExecutableBlock, + snapshot snapshot.StorageSnapshot) ( + *execution.ComputationResult, error) { + blockResult, ok := store.ResultByBlock[block.ID()] + if !ok { + return nil, fmt.Errorf("block %s not found", block.ID()) + } + return blockResult.Result, nil + }) } func (ctx *testingContext) mockSaveExecutionResults(store *mocks.MockBlockStore, wg *sync.WaitGroup) { - mocked := ctx.executionState. - On("SaveExecutionResults", mock.Anything, mock.Anything) - - mocked.RunFn = func(args mock.Arguments) { - result := args[1].(*execution.ComputationResult) - - err := store.MarkExecuted(result) - if err != nil { - mocked.ReturnArguments = mock.Arguments{err} - wg.Done() - return - } - mocked.ReturnArguments = mock.Arguments{nil} - wg.Done() - } + ctx.executionState.On("SaveExecutionResults", mock.Anything, mock.Anything). 
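Reviewer note: these engine_test.go hunks replace the older `RunFn` + `ReturnArguments` stubbing with closures passed straight to `Return`. That works because mockery-generated mocks (like the `ViewCommitter` mock earlier in this diff) type-assert the stored return value against a function with the mocked method's signature and invoke it per call. A small self-contained illustration of the mechanism with a hand-written mockery-style mock; the names are invented for the example.

```go
package example

import (
	"testing"

	"github.com/stretchr/testify/mock"
	"github.com/stretchr/testify/require"
)

// ExecutedChecker is a minimal hand-written mock in the mockery style.
type ExecutedChecker struct{ mock.Mock }

func (m *ExecutedChecker) IsBlockExecuted(height uint64, id string) (bool, error) {
	ret := m.Called(height, id)
	// This type assertion is what makes Return(func(...) (...)) work: if the
	// stored value is a function with the matching signature, call it.
	if rf, ok := ret.Get(0).(func(uint64, string) (bool, error)); ok {
		return rf(height, id)
	}
	b, _ := ret.Get(0).(bool)
	return b, ret.Error(1)
}

func TestReturnWithFunction(t *testing.T) {
	executed := map[string]bool{"abc": true}

	m := new(ExecutedChecker)
	// Instead of mutating ReturnArguments inside RunFn, hand Return a closure
	// with the method's signature; it is evaluated on every call.
	m.On("IsBlockExecuted", mock.Anything, mock.Anything).
		Return(func(height uint64, id string) (bool, error) {
			return executed[id], nil
		})

	ok, err := m.IsBlockExecuted(1, "abc")
	require.NoError(t, err)
	require.True(t, ok)
}
```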
+ Return(func(ctx context.Context, result *execution.ComputationResult) error { + defer wg.Done() + err := store.MarkExecuted(result) + if err != nil { + return err + } + return nil + }) } diff --git a/engine/execution/ingestion/loader/loader.go b/engine/execution/ingestion/loader/unexecuted_loader.go similarity index 84% rename from engine/execution/ingestion/loader/loader.go rename to engine/execution/ingestion/loader/unexecuted_loader.go index 7d34d7ea666..a9eba76115f 100644 --- a/engine/execution/ingestion/loader/loader.go +++ b/engine/execution/ingestion/loader/unexecuted_loader.go @@ -13,28 +13,31 @@ import ( "github.com/onflow/flow-go/utils/logging" ) -type Loader struct { +// deprecated. Storehouse is going to use unfinalized loader instead +type UnexecutedLoader struct { log zerolog.Logger state protocol.State headers storage.Headers // see comments on getHeaderByHeight for why we need it execState state.ExecutionState } -func NewLoader( +func NewUnexecutedLoader( log zerolog.Logger, state protocol.State, headers storage.Headers, execState state.ExecutionState, -) *Loader { - return &Loader{ - log: log.With().Str("component", "ingestion_engine_block_loader").Logger(), +) *UnexecutedLoader { + return &UnexecutedLoader{ + log: log.With().Str("component", "ingestion_engine_unexecuted_loader").Logger(), state: state, headers: headers, execState: execState, } } -func (e *Loader) LoadUnexecuted(ctx context.Context) ([]flow.Identifier, error) { +// LoadUnexecuted loads all unexecuted and validated blocks +// any error returned are exceptions +func (e *UnexecutedLoader) LoadUnexecuted(ctx context.Context) ([]flow.Identifier, error) { // saving an executed block is currently not transactional, so it's possible // the block is marked as executed but the receipt might not be saved during a crash. 
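Reviewer note: the loader and ingestion hunks in this diff all move from the old context-based `state.IsBlockExecuted(ctx, execState, blockID)` helper (and from probing `StateCommitmentByBlockID`) to the new `ExecutionState.IsBlockExecuted(height, blockID)` method. A sketch of the new call shape only; the helper name is illustrative and the reasoning in the comment is an assumption tied to the storehouse work this PR belongs to.

```go
package ingestion

import (
	"github.com/onflow/flow-go/engine/execution/state"
	"github.com/onflow/flow-go/model/flow"
)

// blockAlreadyExecuted shows the new executed-check call shape: callers pass the
// block height alongside the ID instead of a context.Context. Presumably the
// height lets the execution state answer from height-indexed data.
func blockAlreadyExecuted(execState state.ExecutionState, header *flow.Header) (bool, error) {
	return execState.IsBlockExecuted(header.Height, header.ID())
}
```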
// in order to mitigate this problem, we always re-execute the last executed and finalized @@ -63,7 +66,7 @@ func (e *Loader) LoadUnexecuted(ctx context.Context) ([]flow.Identifier, error) blockIDs := make([]flow.Identifier, 0) isRoot := rootBlock.ID() == last.ID() if !isRoot { - executed, err := state.IsBlockExecuted(ctx, e.execState, lastExecutedID) + executed, err := e.execState.IsBlockExecuted(lastExecutedHeight, lastExecutedID) if err != nil { return nil, fmt.Errorf("cannot check is last exeucted final block has been executed %v: %w", lastExecutedID, err) } @@ -104,7 +107,7 @@ func (e *Loader) LoadUnexecuted(ctx context.Context) ([]flow.Identifier, error) return blockIDs, nil } -func (e *Loader) unexecutedBlocks(ctx context.Context) ( +func (e *UnexecutedLoader) unexecutedBlocks(ctx context.Context) ( finalized []flow.Identifier, pending []flow.Identifier, err error, @@ -126,7 +129,7 @@ func (e *Loader) unexecutedBlocks(ctx context.Context) ( return finalized, pending, nil } -func (e *Loader) finalizedUnexecutedBlocks(ctx context.Context, finalized protocol.Snapshot) ( +func (e *UnexecutedLoader) finalizedUnexecutedBlocks(ctx context.Context, finalized protocol.Snapshot) ( []flow.Identifier, error, ) { @@ -159,7 +162,7 @@ func (e *Loader) finalizedUnexecutedBlocks(ctx context.Context, finalized protoc return nil, fmt.Errorf("could not get header at height: %v, %w", lastExecuted, err) } - executed, err := state.IsBlockExecuted(ctx, e.execState, header.ID()) + executed, err := e.execState.IsBlockExecuted(header.Height, header.ID()) if err != nil { return nil, fmt.Errorf("could not check whether block is executed: %w", err) } @@ -196,7 +199,7 @@ func (e *Loader) finalizedUnexecutedBlocks(ctx context.Context, finalized protoc return unexecuted, nil } -func (e *Loader) pendingUnexecutedBlocks(ctx context.Context, finalized protocol.Snapshot) ( +func (e *UnexecutedLoader) pendingUnexecutedBlocks(ctx context.Context, finalized protocol.Snapshot) ( []flow.Identifier, error, ) { @@ -208,7 +211,11 @@ func (e *Loader) pendingUnexecutedBlocks(ctx context.Context, finalized protocol unexecuted := make([]flow.Identifier, 0) for _, pending := range pendings { - executed, err := state.IsBlockExecuted(ctx, e.execState, pending) + p, err := e.headers.ByBlockID(pending) + if err != nil { + return nil, fmt.Errorf("could not get header by block id: %w", err) + } + executed, err := e.execState.IsBlockExecuted(p.Height, pending) if err != nil { return nil, fmt.Errorf("could not check block executed or not: %w", err) } @@ -224,7 +231,7 @@ func (e *Loader) pendingUnexecutedBlocks(ctx context.Context, finalized protocol // if the EN is dynamically bootstrapped, the finalized blocks at height range: // [ sealedRoot.Height, finalizedRoot.Height - 1] can not be retrieved from // protocol state, but only from headers -func (e *Loader) getHeaderByHeight(height uint64) (*flow.Header, error) { +func (e *UnexecutedLoader) getHeaderByHeight(height uint64) (*flow.Header, error) { // we don't use protocol state because for dynamic boostrapped execution node // the last executed and sealed block is below the finalized root block return e.headers.ByHeight(height) diff --git a/engine/execution/ingestion/loader/loader_test.go b/engine/execution/ingestion/loader/unexecuted_loader_test.go similarity index 83% rename from engine/execution/ingestion/loader/loader_test.go rename to engine/execution/ingestion/loader/unexecuted_loader_test.go index 5b61d155e8c..23779394c5b 100644 --- 
a/engine/execution/ingestion/loader/loader_test.go +++ b/engine/execution/ingestion/loader/unexecuted_loader_test.go @@ -11,7 +11,6 @@ import ( "github.com/onflow/flow-go/engine/execution/ingestion" "github.com/onflow/flow-go/engine/execution/ingestion/loader" - "github.com/onflow/flow-go/engine/execution/state" stateMock "github.com/onflow/flow-go/engine/execution/state/mock" "github.com/onflow/flow-go/model/flow" storageerr "github.com/onflow/flow-go/storage" @@ -20,7 +19,7 @@ import ( "github.com/onflow/flow-go/utils/unittest/mocks" ) -var _ ingestion.BlockLoader = (*loader.Loader)(nil) +var _ ingestion.BlockLoader = (*loader.UnexecutedLoader)(nil) // ExecutionState is a mocked version of execution state that // simulates some of its behavior for testing purpose @@ -41,7 +40,6 @@ func newMockExecutionState(seal *flow.Seal, genesis *flow.Header) *mockExecution } func (es *mockExecutionState) StateCommitmentByBlockID( - ctx context.Context, blockID flow.Identifier, ) ( flow.StateCommitment, @@ -57,10 +55,16 @@ func (es *mockExecutionState) StateCommitmentByBlockID( return commit, nil } +func (es *mockExecutionState) IsBlockExecuted(height uint64, blockID flow.Identifier) (bool, error) { + es.Lock() + defer es.Unlock() + _, ok := es.commits[blockID] + return ok, nil +} + func (es *mockExecutionState) ExecuteBlock(t *testing.T, block *flow.Block) { - parentExecuted, err := state.IsBlockExecuted( - context.Background(), - es, + parentExecuted, err := es.IsBlockExecuted( + block.Header.Height, block.Header.ParentID) require.NoError(t, err) require.True(t, parentExecuted, "parent block not executed") @@ -93,7 +97,7 @@ func TestLoadingUnexecutedBlocks(t *testing.T) { headers := storage.NewMockHeaders(ctrl) headers.EXPECT().ByBlockID(genesis.ID()).Return(genesis.Header, nil) log := unittest.Logger() - loader := loader.NewLoader(log, ps, headers, es) + loader := loader.NewUnexecutedLoader(log, ps, headers, es) unexecuted, err := loader.LoadUnexecuted(context.Background()) require.NoError(t, err) @@ -120,8 +124,12 @@ func TestLoadingUnexecutedBlocks(t *testing.T) { ctrl := gomock.NewController(t) headers := storage.NewMockHeaders(ctrl) headers.EXPECT().ByBlockID(genesis.ID()).Return(genesis.Header, nil) + headers.EXPECT().ByBlockID(blockA.ID()).Return(blockA.Header, nil) + headers.EXPECT().ByBlockID(blockB.ID()).Return(blockB.Header, nil) + headers.EXPECT().ByBlockID(blockC.ID()).Return(blockC.Header, nil) + headers.EXPECT().ByBlockID(blockD.ID()).Return(blockD.Header, nil) log := unittest.Logger() - loader := loader.NewLoader(log, ps, headers, es) + loader := loader.NewUnexecutedLoader(log, ps, headers, es) unexecuted, err := loader.LoadUnexecuted(context.Background()) require.NoError(t, err) @@ -148,8 +156,13 @@ func TestLoadingUnexecutedBlocks(t *testing.T) { ctrl := gomock.NewController(t) headers := storage.NewMockHeaders(ctrl) headers.EXPECT().ByBlockID(genesis.ID()).Return(genesis.Header, nil) + headers.EXPECT().ByBlockID(blockA.ID()).Return(blockA.Header, nil) + headers.EXPECT().ByBlockID(blockB.ID()).Return(blockB.Header, nil) + headers.EXPECT().ByBlockID(blockC.ID()).Return(blockC.Header, nil) + headers.EXPECT().ByBlockID(blockD.ID()).Return(blockD.Header, nil) + log := unittest.Logger() - loader := loader.NewLoader(log, ps, headers, es) + loader := loader.NewUnexecutedLoader(log, ps, headers, es) es.ExecuteBlock(t, blockA) es.ExecuteBlock(t, blockB) @@ -181,8 +194,10 @@ func TestLoadingUnexecutedBlocks(t *testing.T) { ctrl := gomock.NewController(t) headers := 
storage.NewMockHeaders(ctrl) headers.EXPECT().ByBlockID(genesis.ID()).Return(genesis.Header, nil) + headers.EXPECT().ByBlockID(blockD.ID()).Return(blockD.Header, nil) + log := unittest.Logger() - loader := loader.NewLoader(log, ps, headers, es) + loader := loader.NewUnexecutedLoader(log, ps, headers, es) // block C is the only finalized block, index its header by its height headers.EXPECT().ByHeight(blockC.Header.Height).Return(blockC.Header, nil) @@ -218,8 +233,9 @@ func TestLoadingUnexecutedBlocks(t *testing.T) { ctrl := gomock.NewController(t) headers := storage.NewMockHeaders(ctrl) headers.EXPECT().ByBlockID(genesis.ID()).Return(genesis.Header, nil) + headers.EXPECT().ByBlockID(blockD.ID()).Return(blockD.Header, nil) log := unittest.Logger() - loader := loader.NewLoader(log, ps, headers, es) + loader := loader.NewUnexecutedLoader(log, ps, headers, es) // block C is finalized, index its header by its height headers.EXPECT().ByHeight(blockC.Header.Height).Return(blockC.Header, nil) @@ -254,8 +270,12 @@ func TestLoadingUnexecutedBlocks(t *testing.T) { ctrl := gomock.NewController(t) headers := storage.NewMockHeaders(ctrl) headers.EXPECT().ByBlockID(genesis.ID()).Return(genesis.Header, nil) + headers.EXPECT().ByBlockID(blockB.ID()).Return(blockB.Header, nil) + headers.EXPECT().ByBlockID(blockC.ID()).Return(blockC.Header, nil) + headers.EXPECT().ByBlockID(blockD.ID()).Return(blockD.Header, nil) + log := unittest.Logger() - loader := loader.NewLoader(log, ps, headers, es) + loader := loader.NewUnexecutedLoader(log, ps, headers, es) // block A is finalized, index its header by its height headers.EXPECT().ByHeight(blockA.Header.Height).Return(blockA.Header, nil) @@ -315,8 +335,15 @@ func TestLoadingUnexecutedBlocks(t *testing.T) { ctrl := gomock.NewController(t) headers := storage.NewMockHeaders(ctrl) headers.EXPECT().ByBlockID(genesis.ID()).Return(genesis.Header, nil) + headers.EXPECT().ByBlockID(blockD.ID()).Return(blockD.Header, nil) + headers.EXPECT().ByBlockID(blockE.ID()).Return(blockE.Header, nil) + headers.EXPECT().ByBlockID(blockF.ID()).Return(blockF.Header, nil) + headers.EXPECT().ByBlockID(blockG.ID()).Return(blockG.Header, nil) + headers.EXPECT().ByBlockID(blockH.ID()).Return(blockH.Header, nil) + headers.EXPECT().ByBlockID(blockI.ID()).Return(blockI.Header, nil) + log := unittest.Logger() - loader := loader.NewLoader(log, ps, headers, es) + loader := loader.NewUnexecutedLoader(log, ps, headers, es) // block C is finalized, index its header by its height headers.EXPECT().ByHeight(blockC.Header.Height).Return(blockC.Header, nil) diff --git a/engine/execution/ingestion/loader/unfinalized_loader.go b/engine/execution/ingestion/loader/unfinalized_loader.go new file mode 100644 index 00000000000..bcfc699074a --- /dev/null +++ b/engine/execution/ingestion/loader/unfinalized_loader.go @@ -0,0 +1,91 @@ +package loader + +import ( + "context" + "fmt" + + "github.com/rs/zerolog" + + "github.com/onflow/flow-go/engine/execution/state" + "github.com/onflow/flow-go/model/flow" + "github.com/onflow/flow-go/state/protocol" + "github.com/onflow/flow-go/storage" +) + +type UnfinalizedLoader struct { + log zerolog.Logger + state protocol.State + headers storage.Headers // see comments on getHeaderByHeight for why we need it + execState state.FinalizedExecutionState +} + +// NewUnfinalizedLoader creates a new loader that loads all unfinalized and validated blocks +func NewUnfinalizedLoader( + log zerolog.Logger, + state protocol.State, + headers storage.Headers, + execState 
state.FinalizedExecutionState, +) *UnfinalizedLoader { + return &UnfinalizedLoader{ + log: log.With().Str("component", "ingestion_engine_unfinalized_loader").Logger(), + state: state, + headers: headers, + execState: execState, + } +} + +// LoadUnexecuted loads all unfinalized and validated blocks +// any error returned are exceptions +func (e *UnfinalizedLoader) LoadUnexecuted(ctx context.Context) ([]flow.Identifier, error) { + lastExecuted := e.execState.GetHighestFinalizedExecuted() + + // get finalized height + finalized := e.state.Final() + final, err := finalized.Head() + if err != nil { + return nil, fmt.Errorf("could not get finalized block: %w", err) + } + + // TODO: dynamically bootstrapped execution node will reload blocks from + unexecutedFinalized := make([]flow.Identifier, 0) + + // starting from the first unexecuted block, go through each unexecuted and finalized block + // reload its block to execution queues + // loading finalized blocks + for height := lastExecuted + 1; height <= final.Height; height++ { + header, err := e.getHeaderByHeight(height) + if err != nil { + return nil, fmt.Errorf("could not get header at height: %v, %w", height, err) + } + + unexecutedFinalized = append(unexecutedFinalized, header.ID()) + } + + // loaded all pending blocks + pendings, err := finalized.Descendants() + if err != nil { + return nil, fmt.Errorf("could not get descendants of finalized block: %w", err) + } + + unexecuted := append(unexecutedFinalized, pendings...) + + e.log.Info(). + Uint64("last_finalized", final.Height). + Uint64("last_finalized_executed", lastExecuted). + // Uint64("sealed_root_height", rootBlock.Height). + // Hex("sealed_root_id", logging.Entity(rootBlock)). + Int("total_finalized_unexecuted", len(unexecutedFinalized)). + Int("total_unexecuted", len(unexecuted)). 
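Reviewer note: both loaders satisfy `ingestion.BlockLoader`, so the ingestion engine can be wired with either. The new `UnfinalizedLoader` only needs the highest finalized-and-executed height, while the now-deprecated `UnexecutedLoader` re-checks execution status block by block. The sketch below mirrors the constructors in this diff; the `storehouseEnabled` switch is an assumption for illustration, the selection logic is not part of this PR.

```go
package exenode

import (
	"github.com/rs/zerolog"

	"github.com/onflow/flow-go/engine/execution/ingestion"
	"github.com/onflow/flow-go/engine/execution/ingestion/loader"
	"github.com/onflow/flow-go/engine/execution/state"
	"github.com/onflow/flow-go/state/protocol"
	"github.com/onflow/flow-go/storage"
)

// newBlockLoader wires one of the two loaders behind the common BlockLoader
// interface. The storehouseEnabled flag is hypothetical.
func newBlockLoader(
	log zerolog.Logger,
	ps protocol.State,
	headers storage.Headers,
	execState state.ExecutionState,
	finalizedExecState state.FinalizedExecutionState,
	storehouseEnabled bool,
) ingestion.BlockLoader {
	if storehouseEnabled {
		// reads only the highest finalized-and-executed height, then loads
		// everything above it plus the pending descendants
		return loader.NewUnfinalizedLoader(log, ps, headers, finalizedExecState)
	}
	// deprecated path: re-checks execution status block by block
	return loader.NewUnexecutedLoader(log, ps, headers, execState)
}
```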
+ Msgf("finalized unexecuted blocks") + + return unexecuted, nil +} + +// if the EN is dynamically bootstrapped, the finalized blocks at height range: +// [ sealedRoot.Height, finalizedRoot.Height - 1] can not be retrieved from +// protocol state, but only from headers +func (e *UnfinalizedLoader) getHeaderByHeight(height uint64) (*flow.Header, error) { + // we don't use protocol state because for dynamic boostrapped execution node + // the last executed and sealed block is below the finalized root block + return e.headers.ByHeight(height) +} diff --git a/engine/execution/ingestion/loader/unfinalized_loader_test.go b/engine/execution/ingestion/loader/unfinalized_loader_test.go new file mode 100644 index 00000000000..3c8b84aed40 --- /dev/null +++ b/engine/execution/ingestion/loader/unfinalized_loader_test.go @@ -0,0 +1,55 @@ +package loader_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/onflow/flow-go/engine/execution/ingestion" + "github.com/onflow/flow-go/engine/execution/ingestion/loader" + stateMock "github.com/onflow/flow-go/engine/execution/state/mock" + "github.com/onflow/flow-go/model/flow" + storage "github.com/onflow/flow-go/storage/mock" + "github.com/onflow/flow-go/utils/unittest" + "github.com/onflow/flow-go/utils/unittest/mocks" +) + +var _ ingestion.BlockLoader = (*loader.UnfinalizedLoader)(nil) + +func TestLoadingUnfinalizedBlocks(t *testing.T) { + ps := mocks.NewProtocolState() + + // Genesis <- A <- B <- C (finalized) <- D + chain, result, seal := unittest.ChainFixture(5) + genesis, blockA, blockB, blockC, blockD := + chain[0], chain[1], chain[2], chain[3], chain[4] + + logChain(chain) + + require.NoError(t, ps.Bootstrap(genesis, result, seal)) + require.NoError(t, ps.Extend(blockA)) + require.NoError(t, ps.Extend(blockB)) + require.NoError(t, ps.Extend(blockC)) + require.NoError(t, ps.Extend(blockD)) + require.NoError(t, ps.Finalize(blockC.ID())) + + es := new(stateMock.FinalizedExecutionState) + es.On("GetHighestFinalizedExecuted").Return(genesis.Header.Height) + headers := new(storage.Headers) + headers.On("ByHeight", blockA.Header.Height).Return(blockA.Header, nil) + headers.On("ByHeight", blockB.Header.Height).Return(blockB.Header, nil) + headers.On("ByHeight", blockC.Header.Height).Return(blockC.Header, nil) + + loader := loader.NewUnfinalizedLoader(unittest.Logger(), ps, headers, es) + + unexecuted, err := loader.LoadUnexecuted(context.Background()) + require.NoError(t, err) + + unittest.IDsEqual(t, []flow.Identifier{ + blockA.ID(), + blockB.ID(), + blockC.ID(), + blockD.ID(), + }, unexecuted) +} diff --git a/engine/execution/ingestion/stop/stop_control.go b/engine/execution/ingestion/stop/stop_control.go index 94ec0781191..bb14e8905d5 100644 --- a/engine/execution/ingestion/stop/stop_control.go +++ b/engine/execution/ingestion/stop/stop_control.go @@ -493,7 +493,7 @@ func (s *StopControl) blockFinalized( Msgf("Found ID of the block that should be executed last") // check if the parent block has been executed then stop right away - executed, err := state.IsBlockExecuted(ctx, s.exeState, h.ParentID) + executed, err := state.IsParentExecuted(s.exeState, h) if err != nil { handleErr(fmt.Errorf( "failed to check if the block has been executed: %w", diff --git a/engine/execution/ingestion/stop/stop_control_test.go b/engine/execution/ingestion/stop/stop_control_test.go index 829a1f65a0f..6698c3cc7b8 100644 --- a/engine/execution/ingestion/stop/stop_control_test.go +++ 
b/engine/execution/ingestion/stop/stop_control_test.go @@ -14,7 +14,6 @@ import ( "github.com/onflow/flow-go/engine/execution/state/mock" "github.com/onflow/flow-go/model/flow" "github.com/onflow/flow-go/module/irrecoverable" - "github.com/onflow/flow-go/storage" storageMock "github.com/onflow/flow-go/storage/mock" "github.com/onflow/flow-go/utils/unittest" ) @@ -92,7 +91,7 @@ func TestCannotSetNewValuesAfterStoppingCommenced(t *testing.T) { require.Equal(t, stop, sc.GetStopParameters()) // make execution check pretends block has been executed - execState.On("StateCommitmentByBlockID", testifyMock.Anything, testifyMock.Anything).Return(nil, nil) + execState.On("IsBlockExecuted", testifyMock.Anything, testifyMock.Anything).Return(true, nil) // no stopping has started yet, block below stop height header := unittest.BlockHeaderFixture(unittest.WithHeaderHeight(20)) @@ -145,9 +144,7 @@ func TestExecutionFallingBehind(t *testing.T) { require.NoError(t, err) require.Equal(t, stop, sc.GetStopParameters()) - execState. - On("StateCommitmentByBlockID", testifyMock.Anything, headerC.ParentID). - Return(nil, storage.ErrNotFound) + execState.On("IsBlockExecuted", headerC.Height-1, headerC.ParentID).Return(false, nil) // finalize blocks first sc.BlockFinalizedForTesting(headerA) @@ -214,9 +211,7 @@ func TestAddStopForPastBlocks(t *testing.T) { sc.OnBlockExecuted(headerC) // block is executed - execState. - On("StateCommitmentByBlockID", testifyMock.Anything, headerD.ParentID). - Return(nil, nil) + execState.On("IsBlockExecuted", headerD.Height-1, headerD.ParentID).Return(true, nil) // set stop at 22, but finalization and execution is at 23 // so stop right away @@ -261,9 +256,7 @@ func TestAddStopForPastBlocksExecutionFallingBehind(t *testing.T) { false, ) - execState. - On("StateCommitmentByBlockID", testifyMock.Anything, headerD.ParentID). - Return(nil, storage.ErrNotFound) + execState.On("IsBlockExecuted", headerD.Height-1, headerD.ParentID).Return(false, nil) // finalize blocks first sc.BlockFinalizedForTesting(headerA) @@ -317,9 +310,7 @@ func TestStopControlWithVersionControl(t *testing.T) { ) // setting this means all finalized blocks are considered already executed - execState. - On("StateCommitmentByBlockID", testifyMock.Anything, headerC.ParentID). - Return(nil, nil) + execState.On("IsBlockExecuted", headerC.Height-1, headerC.ParentID).Return(true, nil) versionBeacons. On("Highest", testifyMock.Anything). @@ -741,12 +732,8 @@ func Test_StopControlWorkers(t *testing.T) { Once() execState := mock.NewExecutionState(t) - execState.On( - "StateCommitmentByBlockID", - testifyMock.Anything, - headerA.ID(), - ).Return(flow.StateCommitment{}, nil). - Once() + + execState.On("IsBlockExecuted", headerA.Height, headerA.ID()).Return(true, nil).Once() headers := &stopControlMockHeaders{ headers: map[uint64]*flow.Header{ @@ -817,12 +804,7 @@ func Test_StopControlWorkers(t *testing.T) { Once() execState := mock.NewExecutionState(t) - execState.On( - "StateCommitmentByBlockID", - testifyMock.Anything, - headerB.ID(), - ).Return(flow.StateCommitment{}, nil). 
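Reviewer note: the stop-control change swaps the state-commitment probe for `state.IsParentExecuted(s.exeState, h)`. The updated test expectations (always `IsBlockExecuted` with `Height-1` and `ParentID`) imply the helper below; this is a reconstruction from those tests, not code included in this diff hunk.

```go
package state

import "github.com/onflow/flow-go/model/flow"

// IsParentExecuted, as implied by the stop_control tests above: a block's parent
// sits at height-1 on the same fork, so the check delegates to IsBlockExecuted.
func IsParentExecuted(execState ExecutionState, header *flow.Header) (bool, error) {
	return execState.IsBlockExecuted(header.Height-1, header.ParentID)
}
```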
- Once() + execState.On("IsBlockExecuted", headerB.Height, headerB.ID()).Return(true, nil).Once() headers := &stopControlMockHeaders{ headers: map[uint64]*flow.Header{ diff --git a/engine/execution/mock/executed_finalized_wal.go b/engine/execution/mock/executed_finalized_wal.go index faccfaec0cb..321467c9b49 100644 --- a/engine/execution/mock/executed_finalized_wal.go +++ b/engine/execution/mock/executed_finalized_wal.go @@ -15,11 +15,11 @@ type ExecutedFinalizedWAL struct { } // Append provides a mock function with given fields: height, registers -func (_m *ExecutedFinalizedWAL) Append(height uint64, registers []flow.RegisterEntry) error { +func (_m *ExecutedFinalizedWAL) Append(height uint64, registers flow.RegisterEntries) error { ret := _m.Called(height, registers) var r0 error - if rf, ok := ret.Get(0).(func(uint64, []flow.RegisterEntry) error); ok { + if rf, ok := ret.Get(0).(func(uint64, flow.RegisterEntries) error); ok { r0 = rf(height, registers) } else { r0 = ret.Error(0) diff --git a/engine/execution/mock/extendable_storage_snapshot.go b/engine/execution/mock/extendable_storage_snapshot.go new file mode 100644 index 00000000000..6b65c7ca52f --- /dev/null +++ b/engine/execution/mock/extendable_storage_snapshot.go @@ -0,0 +1,88 @@ +// Code generated by mockery v2.21.4. DO NOT EDIT. + +package mock + +import ( + execution "github.com/onflow/flow-go/engine/execution" + flow "github.com/onflow/flow-go/model/flow" + + mock "github.com/stretchr/testify/mock" +) + +// ExtendableStorageSnapshot is an autogenerated mock type for the ExtendableStorageSnapshot type +type ExtendableStorageSnapshot struct { + mock.Mock +} + +// Commitment provides a mock function with given fields: +func (_m *ExtendableStorageSnapshot) Commitment() flow.StateCommitment { + ret := _m.Called() + + var r0 flow.StateCommitment + if rf, ok := ret.Get(0).(func() flow.StateCommitment); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(flow.StateCommitment) + } + } + + return r0 +} + +// Extend provides a mock function with given fields: newCommit, updatedRegisters +func (_m *ExtendableStorageSnapshot) Extend(newCommit flow.StateCommitment, updatedRegisters map[flow.RegisterID][]byte) execution.ExtendableStorageSnapshot { + ret := _m.Called(newCommit, updatedRegisters) + + var r0 execution.ExtendableStorageSnapshot + if rf, ok := ret.Get(0).(func(flow.StateCommitment, map[flow.RegisterID][]byte) execution.ExtendableStorageSnapshot); ok { + r0 = rf(newCommit, updatedRegisters) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(execution.ExtendableStorageSnapshot) + } + } + + return r0 +} + +// Get provides a mock function with given fields: id +func (_m *ExtendableStorageSnapshot) Get(id flow.RegisterID) ([]byte, error) { + ret := _m.Called(id) + + var r0 []byte + var r1 error + if rf, ok := ret.Get(0).(func(flow.RegisterID) ([]byte, error)); ok { + return rf(id) + } + if rf, ok := ret.Get(0).(func(flow.RegisterID) []byte); ok { + r0 = rf(id) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]byte) + } + } + + if rf, ok := ret.Get(1).(func(flow.RegisterID) error); ok { + r1 = rf(id) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +type mockConstructorTestingTNewExtendableStorageSnapshot interface { + mock.TestingT + Cleanup(func()) +} + +// NewExtendableStorageSnapshot creates a new instance of ExtendableStorageSnapshot. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. 
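Reviewer note: the generated mock that follows implies the contract below. Restating it here for orientation only; the authoritative definition lives in the `engine/execution` package. A snapshot couples a register read view with the state commitment it represents, and `Extend` layers one collection's register updates on top to produce the next snapshot in the chain that `CommitView` now returns.

```go
package execution

import "github.com/onflow/flow-go/model/flow"

type ExtendableStorageSnapshot interface {
	// Get reads a register value from the snapshot.
	Get(id flow.RegisterID) ([]byte, error)

	// Commitment is the state commitment this snapshot corresponds to.
	Commitment() flow.StateCommitment

	// Extend returns a new snapshot with the given register updates applied,
	// carrying the commitment produced by committing those updates.
	Extend(newCommit flow.StateCommitment, updatedRegisters map[flow.RegisterID][]byte) ExtendableStorageSnapshot
}
```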
+func NewExtendableStorageSnapshot(t mockConstructorTestingTNewExtendableStorageSnapshot) *ExtendableStorageSnapshot { + mock := &ExtendableStorageSnapshot{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/engine/execution/mock/in_memory_register_store.go b/engine/execution/mock/in_memory_register_store.go index 606bbe32b05..561f39c8a4d 100644 --- a/engine/execution/mock/in_memory_register_store.go +++ b/engine/execution/mock/in_memory_register_store.go @@ -117,11 +117,11 @@ func (_m *InMemoryRegisterStore) PrunedHeight() uint64 { } // SaveRegisters provides a mock function with given fields: height, blockID, parentID, registers -func (_m *InMemoryRegisterStore) SaveRegisters(height uint64, blockID flow.Identifier, parentID flow.Identifier, registers []flow.RegisterEntry) error { +func (_m *InMemoryRegisterStore) SaveRegisters(height uint64, blockID flow.Identifier, parentID flow.Identifier, registers flow.RegisterEntries) error { ret := _m.Called(height, blockID, parentID, registers) var r0 error - if rf, ok := ret.Get(0).(func(uint64, flow.Identifier, flow.Identifier, []flow.RegisterEntry) error); ok { + if rf, ok := ret.Get(0).(func(uint64, flow.Identifier, flow.Identifier, flow.RegisterEntries) error); ok { r0 = rf(height, blockID, parentID, registers) } else { r0 = ret.Error(0) diff --git a/engine/execution/mock/register_store.go b/engine/execution/mock/register_store.go index 1e73de34a02..e2bd3dba400 100644 --- a/engine/execution/mock/register_store.go +++ b/engine/execution/mock/register_store.go @@ -91,11 +91,11 @@ func (_m *RegisterStore) OnBlockFinalized() error { } // SaveRegisters provides a mock function with given fields: header, registers -func (_m *RegisterStore) SaveRegisters(header *flow.Header, registers []flow.RegisterEntry) error { +func (_m *RegisterStore) SaveRegisters(header *flow.Header, registers flow.RegisterEntries) error { ret := _m.Called(header, registers) var r0 error - if rf, ok := ret.Get(0).(func(*flow.Header, []flow.RegisterEntry) error); ok { + if rf, ok := ret.Get(0).(func(*flow.Header, flow.RegisterEntries) error); ok { r0 = rf(header, registers) } else { r0 = ret.Error(0) diff --git a/engine/execution/mock/wal_reader.go b/engine/execution/mock/wal_reader.go index eb00b4643ef..f9917c8b520 100644 --- a/engine/execution/mock/wal_reader.go +++ b/engine/execution/mock/wal_reader.go @@ -13,13 +13,13 @@ type WALReader struct { } // Next provides a mock function with given fields: -func (_m *WALReader) Next() (uint64, []flow.RegisterEntry, error) { +func (_m *WALReader) Next() (uint64, flow.RegisterEntries, error) { ret := _m.Called() var r0 uint64 - var r1 []flow.RegisterEntry + var r1 flow.RegisterEntries var r2 error - if rf, ok := ret.Get(0).(func() (uint64, []flow.RegisterEntry, error)); ok { + if rf, ok := ret.Get(0).(func() (uint64, flow.RegisterEntries, error)); ok { return rf() } if rf, ok := ret.Get(0).(func() uint64); ok { @@ -28,11 +28,11 @@ func (_m *WALReader) Next() (uint64, []flow.RegisterEntry, error) { r0 = ret.Get(0).(uint64) } - if rf, ok := ret.Get(1).(func() []flow.RegisterEntry); ok { + if rf, ok := ret.Get(1).(func() flow.RegisterEntries); ok { r1 = rf() } else { if ret.Get(1) != nil { - r1 = ret.Get(1).([]flow.RegisterEntry) + r1 = ret.Get(1).(flow.RegisterEntries) } } diff --git a/engine/execution/rpc/engine.go b/engine/execution/rpc/engine.go index 4a0745fc3e1..a1015cc18e6 100644 --- a/engine/execution/rpc/engine.go +++ b/engine/execution/rpc/engine.go @@ -26,7 +26,7 @@ 
import ( "github.com/onflow/flow-go/engine/common/rpc" "github.com/onflow/flow-go/engine/common/rpc/convert" exeEng "github.com/onflow/flow-go/engine/execution" - "github.com/onflow/flow-go/engine/execution/scripts" + "github.com/onflow/flow-go/engine/execution/state" fvmerrors "github.com/onflow/flow-go/fvm/errors" "github.com/onflow/flow-go/model/flow" "github.com/onflow/flow-go/state/protocol" @@ -169,7 +169,7 @@ type handler struct { maxBlockRange int } -var _ execution.ExecutionAPIServer = &handler{} +var _ execution.ExecutionAPIServer = (*handler)(nil) // Ping responds to requests when the server is up. func (h *handler) Ping( @@ -497,6 +497,164 @@ func (h *handler) GetTransactionResultsByBlockID( }, nil } +// GetTransactionErrorMessage implements a grpc handler for getting a transaction error message by block ID and tx ID. +// Expected error codes during normal operations: +// - codes.InvalidArgument - invalid blockID, tx ID. +// - codes.NotFound - transaction result by tx ID not found. +func (h *handler) GetTransactionErrorMessage( + _ context.Context, + req *execution.GetTransactionErrorMessageRequest, +) (*execution.GetTransactionErrorMessageResponse, error) { + reqBlockID := req.GetBlockId() + blockID, err := convert.BlockID(reqBlockID) + if err != nil { + return nil, status.Errorf(codes.InvalidArgument, "invalid blockID: %v", err) + } + + reqTxID := req.GetTransactionId() + txID, err := convert.TransactionID(reqTxID) + if err != nil { + return nil, status.Errorf(codes.InvalidArgument, "invalid transactionID: %v", err) + } + + // lookup any transaction error that might have occurred + txResult, err := h.transactionResults.ByBlockIDTransactionID(blockID, txID) + if err != nil { + if errors.Is(err, storage.ErrNotFound) { + return nil, status.Error(codes.NotFound, "transaction result not found") + } + + return nil, status.Errorf(codes.Internal, "failed to get transaction result: %v", err) + } + + result := &execution.GetTransactionErrorMessageResponse{ + TransactionId: convert.IdentifierToMessage(txResult.TransactionID), + } + + if len(txResult.ErrorMessage) > 0 { + cadenceErrMessage := txResult.ErrorMessage + if !utf8.ValidString(cadenceErrMessage) { + h.log.Warn(). + Str("block_id", blockID.String()). + Str("transaction_id", txID.String()). + Str("error_mgs", fmt.Sprintf("%q", cadenceErrMessage)). + Msg("invalid character in Cadence error message") + // convert non UTF-8 string to a UTF-8 string for safe GRPC marshaling + cadenceErrMessage = strings.ToValidUTF8(txResult.ErrorMessage, "?") + } + result.ErrorMessage = cadenceErrMessage + } + return result, nil +} + +// GetTransactionErrorMessageByIndex implements a grpc handler for getting a transaction error message by block ID and tx index. +// Expected error codes during normal operations: +// - codes.InvalidArgument - invalid blockID. +// - codes.NotFound - transaction result at index not found. 
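Reviewer note: all three new error-message handlers repeat the same sanitisation step: gRPC string fields must be valid UTF-8, so invalid byte sequences in a Cadence error message are replaced before the response is built. A minimal, standard-library-only sketch of that step; extracting it into a helper like this is an illustration, the handlers above inline it (together with a warning log).

```go
package rpc

import (
	"strings"
	"unicode/utf8"
)

// safeErrorMessage returns msg unchanged when it is valid UTF-8, otherwise a
// copy with invalid sequences replaced by "?" so protobuf marshaling cannot fail.
func safeErrorMessage(msg string) string {
	if utf8.ValidString(msg) {
		return msg
	}
	return strings.ToValidUTF8(msg, "?")
}
```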
+func (h *handler) GetTransactionErrorMessageByIndex( + _ context.Context, + req *execution.GetTransactionErrorMessageByIndexRequest, +) (*execution.GetTransactionErrorMessageResponse, error) { + reqBlockID := req.GetBlockId() + blockID, err := convert.BlockID(reqBlockID) + if err != nil { + return nil, status.Errorf(codes.InvalidArgument, "invalid blockID: %v", err) + } + + index := req.GetIndex() + + // lookup any transaction error that might have occurred + txResult, err := h.transactionResults.ByBlockIDTransactionIndex(blockID, index) + if err != nil { + if errors.Is(err, storage.ErrNotFound) { + return nil, status.Error(codes.NotFound, "transaction result not found") + } + + return nil, status.Errorf(codes.Internal, "failed to get transaction result: %v", err) + } + + result := &execution.GetTransactionErrorMessageResponse{ + TransactionId: convert.IdentifierToMessage(txResult.TransactionID), + } + + if len(txResult.ErrorMessage) > 0 { + cadenceErrMessage := txResult.ErrorMessage + if !utf8.ValidString(cadenceErrMessage) { + h.log.Warn(). + Str("block_id", blockID.String()). + Str("transaction_id", txResult.TransactionID.String()). + Str("error_mgs", fmt.Sprintf("%q", cadenceErrMessage)). + Msg("invalid character in Cadence error message") + // convert non UTF-8 string to a UTF-8 string for safe GRPC marshaling + cadenceErrMessage = strings.ToValidUTF8(txResult.ErrorMessage, "?") + } + result.ErrorMessage = cadenceErrMessage + } + return result, nil +} + +// GetTransactionErrorMessagesByBlockID implements a grpc handler for getting transaction error messages by block ID. +// Only failed transactions will be returned. +// Expected error codes during normal operations: +// - codes.InvalidArgument - invalid blockID. +// - codes.NotFound - block was not executed or was pruned. +func (h *handler) GetTransactionErrorMessagesByBlockID( + _ context.Context, + req *execution.GetTransactionErrorMessagesByBlockIDRequest, +) (*execution.GetTransactionErrorMessagesResponse, error) { + reqBlockID := req.GetBlockId() + blockID, err := convert.BlockID(reqBlockID) + if err != nil { + return nil, status.Errorf(codes.InvalidArgument, "invalid blockID: %v", err) + } + + // must verify block was locally executed first since transactionResults.ByBlockID will return + // an empty slice if block does not exist + if _, err = h.commits.ByBlockID(blockID); err != nil { + if errors.Is(err, storage.ErrNotFound) { + return nil, status.Errorf(codes.NotFound, "block %s has not been executed by node or was pruned", blockID) + } + return nil, status.Errorf(codes.Internal, "state commitment for block ID %s could not be retrieved", blockID) + } + + // Get all tx results + txResults, err := h.transactionResults.ByBlockID(blockID) + if err != nil { + if errors.Is(err, storage.ErrNotFound) { + return nil, status.Error(codes.NotFound, "transaction results not found") + } + + return nil, status.Errorf(codes.Internal, "failed to get transaction results: %v", err) + } + + var results []*execution.GetTransactionErrorMessagesResponse_Result + for index, txResult := range txResults { + if len(txResult.ErrorMessage) == 0 { + continue + } + txIndex := uint32(index) + cadenceErrMessage := txResult.ErrorMessage + if !utf8.ValidString(cadenceErrMessage) { + h.log.Warn(). + Str("block_id", blockID.String()). + Uint32("index", txIndex). + Str("error_mgs", fmt.Sprintf("%q", cadenceErrMessage)). 
+ Msg("invalid character in Cadence error message") + // convert non UTF-8 string to a UTF-8 string for safe GRPC marshaling + cadenceErrMessage = strings.ToValidUTF8(txResult.ErrorMessage, "?") + } + results = append(results, &execution.GetTransactionErrorMessagesResponse_Result{ + TransactionId: convert.IdentifierToMessage(txResult.TransactionID), + Index: txIndex, + ErrorMessage: cadenceErrMessage, + }) + } + + return &execution.GetTransactionErrorMessagesResponse{ + Results: results, + }, nil +} + // eventResult creates EventsResponse_Result from flow.Event for the given blockID func (h *handler) eventResult( blockID flow.Identifier, @@ -545,11 +703,14 @@ func (h *handler) GetAccountAtBlockID( value, err := h.engine.GetAccount(ctx, flowAddress, blockFlowID) if err != nil { - if errors.Is(err, scripts.ErrStateCommitmentPruned) { + if errors.Is(err, state.ErrExecutionStatePruned) { return nil, status.Errorf(codes.OutOfRange, "state for block ID %s not available", blockFlowID) } + if errors.Is(err, state.ErrNotExecuted) { + return nil, status.Errorf(codes.NotFound, "block %s has not been executed by node or was pruned", blockFlowID) + } if errors.Is(err, storage.ErrNotFound) { - return nil, status.Errorf(codes.NotFound, "account with address %s not found", flowAddress) + return nil, status.Errorf(codes.NotFound, "block %s not found", blockFlowID) } if fvmerrors.IsAccountNotFoundError(err) { return nil, status.Errorf(codes.NotFound, "account not found") diff --git a/engine/execution/rpc/engine_test.go b/engine/execution/rpc/engine_test.go index 6aec7283fb4..d2f3913123a 100644 --- a/engine/execution/rpc/engine_test.go +++ b/engine/execution/rpc/engine_test.go @@ -20,7 +20,7 @@ import ( "github.com/onflow/flow-go/engine/common/rpc/convert" mockEng "github.com/onflow/flow-go/engine/execution/mock" - "github.com/onflow/flow-go/engine/execution/scripts" + "github.com/onflow/flow-go/engine/execution/state" "github.com/onflow/flow-go/model/flow" realstorage "github.com/onflow/flow-go/storage" storage "github.com/onflow/flow-go/storage/mock" @@ -335,7 +335,7 @@ func (suite *Suite) TestGetAccountAtBlockID() { "this error usually happens if the reference "+ "block for this script is not set to a recent block.", id, - scripts.ErrStateCommitmentPruned, + state.ErrExecutionStatePruned, unittest.IdentifierFixture(), ) @@ -651,7 +651,28 @@ func (suite *Suite) TestGetTransactionResult() { // expect a storage call for the invalid tx ID but return an error txResults := storage.NewTransactionResults(suite.T()) - txResults.On("ByBlockIDTransactionID", bID, wrongTxID).Return(nil, status.Error(codes.Internal, "")).Once() + txResults.On("ByBlockIDTransactionID", bID, wrongTxID).Return(nil, realstorage.ErrNotFound).Once() + + handler := createHandler(txResults) + + _, err := handler.GetTransactionResult(context.Background(), req) + + // check that an error was received + suite.Require().Error(err) + errors.Is(err, status.Error(codes.NotFound, "")) + }) + + // failure path - non-existent transaction ID in request results in an exception + suite.Run("request with non-existent transaction ID, exception", func() { + + wrongTxID := unittest.IdentifierFixture() + + // create an API request with the invalid transaction ID + req := concoctReq(bID[:], wrongTxID[:]) + + // expect a storage call for the invalid tx ID but return an exception + txResults := storage.NewTransactionResults(suite.T()) + txResults.On("ByBlockIDTransactionID", bID, wrongTxID).Return(nil, errors.New("internal-error")).Once() handler := 
createHandler(txResults) @@ -672,7 +693,28 @@ func (suite *Suite) TestGetTransactionResult() { // expect a storage call for the invalid tx ID but return an error txResults := storage.NewTransactionResults(suite.T()) - txResults.On("ByBlockIDTransactionIndex", bID, wrongTxIndex).Return(nil, status.Error(codes.Internal, "")).Once() + txResults.On("ByBlockIDTransactionIndex", bID, wrongTxIndex).Return(nil, realstorage.ErrNotFound).Once() + + handler := createHandler(txResults) + + _, err := handler.GetTransactionResultByIndex(context.Background(), req) + + // check that an error was received + suite.Require().Error(err) + errors.Is(err, status.Error(codes.NotFound, "")) + }) + + // failure path - non-existent transaction index in request results in an exception + suite.Run("request with non-existent transaction index, exception", func() { + + wrongTxIndex := txIndex + 1 + + // create an API request with the invalid transaction ID + req := concoctIndexReq(bID[:], wrongTxIndex) + + // expect a storage call for the invalid tx ID but return an exception + txResults := storage.NewTransactionResults(suite.T()) + txResults.On("ByBlockIDTransactionIndex", bID, wrongTxIndex).Return(nil, errors.New("internal-error")).Once() handler := createHandler(txResults) @@ -888,3 +930,453 @@ func (suite *Suite) TestGetTransactionResultsByBlockID() { errors.Is(err, status.Error(codes.NotFound, "")) }) } + +// TestGetTransactionErrorMessage tests the GetTransactionErrorMessage and GetTransactionErrorMessageByIndex API calls +func (suite *Suite) TestGetTransactionErrorMessage() { + block := unittest.BlockFixture() + tx := unittest.TransactionFixture() + bID := block.ID() + txID := tx.ID() + txIndex := rand.Uint32() + + // create the handler + createHandler := func(txResults *storage.TransactionResults) *handler { + handler := &handler{ + headers: suite.headers, + events: suite.events, + transactionResults: txResults, + commits: suite.commits, + chain: flow.Mainnet, + } + return handler + } + + // concoctReq creates a GetTransactionErrorMessageRequest + concoctReq := func(bID []byte, tID []byte) *execution.GetTransactionErrorMessageRequest { + return &execution.GetTransactionErrorMessageRequest{ + BlockId: bID, + TransactionId: tID, + } + } + + // concoctIndexReq creates a GetTransactionErrorMessageByIndexRequest + concoctIndexReq := func(bID []byte, tIndex uint32) *execution.GetTransactionErrorMessageByIndexRequest { + return &execution.GetTransactionErrorMessageByIndexRequest{ + BlockId: bID, + Index: tIndex, + } + } + + suite.Run("happy path - by tx id - no transaction error", func() { + + // create the expected result + expectedResult := &execution.GetTransactionErrorMessageResponse{ + TransactionId: convert.IdentifierToMessage(txID), + ErrorMessage: "", + } + + // expect a call to lookup transaction result by block ID and transaction ID, return a result with no error + txResults := storage.NewTransactionResults(suite.T()) + txResult := flow.TransactionResult{ + TransactionID: txID, + ErrorMessage: "", + } + txResults.On("ByBlockIDTransactionID", bID, txID).Return(&txResult, nil).Once() + + handler := createHandler(txResults) + + // create a valid API request + req := concoctReq(bID[:], txID[:]) + + // execute the GetTransactionErrorMessage call + actualResult, err := handler.GetTransactionErrorMessage(context.Background(), req) + + // check that a successful response is received + suite.Require().NoError(err) + + // check that all fields in response are as expected + suite.Equal(expectedResult, actualResult) + 
}) + + suite.Run("happy path - at index - no transaction error", func() { + + // create the expected result + expectedResult := &execution.GetTransactionErrorMessageResponse{ + TransactionId: convert.IdentifierToMessage(txID), + ErrorMessage: "", + } + + // expect a call to lookup transaction result by block ID and transaction ID, return a result with no error + txResults := storage.NewTransactionResults(suite.T()) + txResult := flow.TransactionResult{ + TransactionID: txID, + ErrorMessage: "", + } + txResults.On("ByBlockIDTransactionIndex", bID, txIndex).Return(&txResult, nil).Once() + + handler := createHandler(txResults) + + // create a valid API request + req := concoctIndexReq(bID[:], txIndex) + + // execute the GetTransactionResult call + actualResult, err := handler.GetTransactionErrorMessageByIndex(context.Background(), req) + + // check that a successful response is received + suite.Require().NoError(err) + + // check that all fields in response are as expected + suite.Equal(expectedResult, actualResult) + }) + + suite.Run("happy path - by tx id - transaction error", func() { + + // create the expected result + expectedResult := &execution.GetTransactionErrorMessageResponse{ + TransactionId: convert.IdentifierToMessage(txID), + ErrorMessage: "runtime error", + } + + // setup the storage to return a transaction error + txResults := storage.NewTransactionResults(suite.T()) + txResult := flow.TransactionResult{ + TransactionID: txID, + ErrorMessage: "runtime error", + } + txResults.On("ByBlockIDTransactionID", bID, txID).Return(&txResult, nil).Once() + + handler := createHandler(txResults) + + // create a valid API request + req := concoctReq(bID[:], txID[:]) + + // execute the GetTransactionErrorMessage call + actualResult, err := handler.GetTransactionErrorMessage(context.Background(), req) + + // check that a successful response is received + suite.Require().NoError(err) + + // check that all fields in response are as expected + suite.Equal(expectedResult, actualResult) + }) + + suite.Run("happy path - at index - transaction error", func() { + + // create the expected result + expectedResult := &execution.GetTransactionErrorMessageResponse{ + TransactionId: convert.IdentifierToMessage(txID), + ErrorMessage: "runtime error", + } + + // setup the storage to return a transaction error + txResults := storage.NewTransactionResults(suite.T()) + txResult := flow.TransactionResult{ + TransactionID: txID, + ErrorMessage: "runtime error", + } + txResults.On("ByBlockIDTransactionIndex", bID, txIndex).Return(&txResult, nil).Once() + + handler := createHandler(txResults) + + // create a valid API request + req := concoctIndexReq(bID[:], txIndex) + + // execute the GetTransactionErrorMessageByIndex call + actualResult, err := handler.GetTransactionErrorMessageByIndex(context.Background(), req) + + // check that a successful response is received + suite.Require().NoError(err) + + // check that all fields in response are as expected + suite.Equal(expectedResult, actualResult) + }) + + // failure path - nil transaction ID in the request results in an error + suite.Run("request with nil tx ID", func() { + + // create an API request with transaction ID as nil + req := concoctReq(bID[:], nil) + + txResults := storage.NewTransactionResults(suite.T()) + handler := createHandler(txResults) + + _, err := handler.GetTransactionErrorMessage(context.Background(), req) + + // check that an error was received + suite.Require().Error(err) + errors.Is(err, status.Error(codes.InvalidArgument, "")) + }) + + // 
failure path - nil block id in the request results in an error + suite.Run("request with nil block ID", func() { + + // create an API request with a nil block id + req := concoctReq(nil, txID[:]) + + txResults := storage.NewTransactionResults(suite.T()) + handler := createHandler(txResults) + + _, err := handler.GetTransactionErrorMessage(context.Background(), req) + + // check that an error was received + suite.Require().Error(err) + errors.Is(err, status.Error(codes.InvalidArgument, "")) + }) + + // failure path - nil block id in the index request results in an error + suite.Run("index request with nil block ID", func() { + + // create an API request with a nil block id + req := concoctIndexReq(nil, txIndex) + + txResults := storage.NewTransactionResults(suite.T()) + handler := createHandler(txResults) + + _, err := handler.GetTransactionErrorMessageByIndex(context.Background(), req) + + // check that an error was received + suite.Require().Error(err) + errors.Is(err, status.Error(codes.InvalidArgument, "")) + }) + + // failure path - non-existent transaction ID in request results in an error + suite.Run("request with non-existent transaction ID", func() { + + wrongTxID := unittest.IdentifierFixture() + + // create an API request with the invalid transaction ID + req := concoctReq(bID[:], wrongTxID[:]) + + // expect a storage call for the invalid tx ID but return an error + txResults := storage.NewTransactionResults(suite.T()) + txResults.On("ByBlockIDTransactionID", bID, wrongTxID).Return(nil, realstorage.ErrNotFound).Once() + + handler := createHandler(txResults) + + _, err := handler.GetTransactionErrorMessage(context.Background(), req) + + // check that an error was received + suite.Require().Error(err) + errors.Is(err, status.Error(codes.NotFound, "")) + }) + + // failure path - non-existent transaction ID in request results in an exception + suite.Run("request with non-existent transaction ID, exception", func() { + + wrongTxID := unittest.IdentifierFixture() + + // create an API request with the invalid transaction ID + req := concoctReq(bID[:], wrongTxID[:]) + + // expect a storage call for the invalid tx ID but return an exception + txResults := storage.NewTransactionResults(suite.T()) + txResults.On("ByBlockIDTransactionID", bID, wrongTxID).Return(nil, errors.New("internal-error")).Once() + + handler := createHandler(txResults) + + _, err := handler.GetTransactionErrorMessage(context.Background(), req) + + // check that an error was received + suite.Require().Error(err) + errors.Is(err, status.Error(codes.Internal, "")) + }) + + // failure path - non-existent transaction index in request results in an error + suite.Run("request with non-existent transaction index", func() { + + wrongTxIndex := txIndex + 1 + + // create an API request with the invalid transaction ID + req := concoctIndexReq(bID[:], wrongTxIndex) + + // expect a storage call for the invalid tx ID but return an error + txResults := storage.NewTransactionResults(suite.T()) + txResults.On("ByBlockIDTransactionIndex", bID, wrongTxIndex).Return(nil, realstorage.ErrNotFound).Once() + + handler := createHandler(txResults) + + _, err := handler.GetTransactionErrorMessageByIndex(context.Background(), req) + + // check that an error was received + suite.Require().Error(err) + errors.Is(err, status.Error(codes.NotFound, "")) + }) + + // failure path - non-existent transaction index in request results in an exception + suite.Run("request with non-existent transaction index, exception", func() { + + wrongTxIndex := txIndex + 
1 + + // create an API request with the invalid transaction ID + req := concoctIndexReq(bID[:], wrongTxIndex) + + // expect a storage call for the invalid tx ID but return an exception + txResults := storage.NewTransactionResults(suite.T()) + txResults.On("ByBlockIDTransactionIndex", bID, wrongTxIndex).Return(nil, errors.New("internal-error")).Once() + + handler := createHandler(txResults) + + _, err := handler.GetTransactionErrorMessageByIndex(context.Background(), req) + + // check that an error was received + suite.Require().Error(err) + errors.Is(err, status.Error(codes.Internal, "")) + }) +} + +// TestGetTransactionErrorMessagesByBlockID tests GetTransactionErrorMessagesByBlockID API calls +func (suite *Suite) TestGetTransactionErrorMessagesByBlockID() { + block := unittest.BlockFixture() + tx := unittest.TransactionFixture() + bID := block.ID() + nonexistingBlockID := unittest.IdentifierFixture() + tx1ID := tx.ID() + tx2ID := tx.ID() + tx3ID := tx.ID() + + // create the handler + createHandler := func(txResults *storage.TransactionResults) *handler { + handler := &handler{ + headers: suite.headers, + events: suite.events, + transactionResults: txResults, + commits: suite.commits, + chain: flow.Mainnet, + } + return handler + } + + // concoctReq creates a GetTransactionErrorMessagesByBlockIDRequest + concoctReq := func(bID []byte) *execution.GetTransactionErrorMessagesByBlockIDRequest { + return &execution.GetTransactionErrorMessagesByBlockIDRequest{ + BlockId: bID, + } + } + + // happy path - if no transaction errors are found, an empty list is returned + suite.Run("happy path with no transaction error", func() { + suite.commits.On("ByBlockID", bID).Return(nil, nil).Once() + + // create the expected result + expectedResult := &execution.GetTransactionErrorMessagesResponse{ + Results: []*execution.GetTransactionErrorMessagesResponse_Result{}, + } + + // expect a call to lookup transaction result by block ID return a result with no error + txResultsMock := storage.NewTransactionResults(suite.T()) + txResults := []flow.TransactionResult{ + { + TransactionID: tx1ID, + ErrorMessage: "", + }, + { + TransactionID: tx2ID, + ErrorMessage: "", + }, + } + txResultsMock.On("ByBlockID", bID).Return(txResults, nil).Once() + + handler := createHandler(txResultsMock) + + // create a valid API request + req := concoctReq(bID[:]) + + // execute the GetTransactionErrorMessagesByBlockID call + actualResult, err := handler.GetTransactionErrorMessagesByBlockID(context.Background(), req) + + // check that a successful response is received + suite.Require().NoError(err) + + // check that all fields in response are as expected + suite.Assert().ElementsMatch(expectedResult.Results, actualResult.Results) + }) + + // happy path - valid requests receives error messages for all failed transactions. 
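Editor's note on the failure-path cases above: they all encode the same convention, which is that a storage.ErrNotFound from the transaction-results store surfaces to the client as codes.NotFound, while any unexpected error is treated as an exception and surfaces as codes.Internal (nil block or transaction IDs are rejected earlier with codes.InvalidArgument). A minimal, illustrative sketch of that mapping is below; convertStorageError here is a hypothetical helper written only to show the convention the tests assert, not the handler's actual code.

package errmap

import (
	"errors"

	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"

	"github.com/onflow/flow-go/storage"
)

// convertStorageError mirrors the convention exercised by the tests above:
// a missing record maps to codes.NotFound, any other storage error is an
// exception and maps to codes.Internal.
func convertStorageError(err error) error {
	if err == nil {
		return nil
	}
	if errors.Is(err, storage.ErrNotFound) {
		return status.Error(codes.NotFound, "not found: "+err.Error())
	}
	return status.Error(codes.Internal, "failed to look up result: "+err.Error())
}
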
+ suite.Run("happy path with transaction errors", func() { + + suite.commits.On("ByBlockID", bID).Return(nil, nil).Once() + + // create the expected result + expectedResult := &execution.GetTransactionErrorMessagesResponse{ + Results: []*execution.GetTransactionErrorMessagesResponse_Result{ + { + TransactionId: convert.IdentifierToMessage(tx2ID), + Index: 1, + ErrorMessage: "runtime error", + }, + { + TransactionId: convert.IdentifierToMessage(tx3ID), + Index: 2, + ErrorMessage: "runtime error", + }, + }, + } + + // expect a call to lookup transaction result by block ID return a result with no error + txResultsMock := storage.NewTransactionResults(suite.T()) + txResults := []flow.TransactionResult{ + { + TransactionID: tx1ID, + ErrorMessage: "", + }, + { + TransactionID: tx2ID, + ErrorMessage: "runtime error", + }, + { + TransactionID: tx3ID, + ErrorMessage: "runtime error", + }, + } + txResultsMock.On("ByBlockID", bID).Return(txResults, nil).Once() + + handler := createHandler(txResultsMock) + + // create a valid API request + req := concoctReq(bID[:]) + + // execute the GetTransactionErrorMessagesByBlockID call + actualResult, err := handler.GetTransactionErrorMessagesByBlockID(context.Background(), req) + + // check that a successful response is received + suite.Require().NoError(err) + + // check that all fields in response are as expected + suite.Assert().ElementsMatch(expectedResult.Results, actualResult.Results) + }) + + // failure path - nil block id in the request results in an error + suite.Run("request with nil block ID", func() { + + // create an API request with a nil block id + req := concoctReq(nil) + + txResults := storage.NewTransactionResults(suite.T()) + handler := createHandler(txResults) + + _, err := handler.GetTransactionErrorMessagesByBlockID(context.Background(), req) + + // check that an error was received + suite.Require().Error(err) + errors.Is(err, status.Error(codes.InvalidArgument, "")) + }) + + // failure path - nonexisting block id in the request results in not found error + suite.Run("request with nonexisting block ID", func() { + + suite.commits.On("ByBlockID", nonexistingBlockID).Return(nil, realstorage.ErrNotFound).Once() + + txResultsMock := storage.NewTransactionResults(suite.T()) + handler := createHandler(txResultsMock) + + // create a valid API request + req := concoctReq(nonexistingBlockID[:]) + + // execute the GetTransactionResult call + _, err := handler.GetTransactionErrorMessagesByBlockID(context.Background(), req) + + // check that an error was received + suite.Require().Error(err) + errors.Is(err, status.Error(codes.NotFound, "")) + }) +} diff --git a/engine/execution/scripts/engine.go b/engine/execution/scripts/engine.go index 0f25cf409ab..689d7858223 100644 --- a/engine/execution/scripts/engine.go +++ b/engine/execution/scripts/engine.go @@ -2,7 +2,6 @@ package scripts import ( "context" - "encoding/hex" "fmt" "github.com/rs/zerolog" @@ -12,15 +11,11 @@ import ( "github.com/onflow/flow-go/engine/execution/computation/query" "github.com/onflow/flow-go/engine/execution/state" "github.com/onflow/flow-go/model/flow" - "github.com/onflow/flow-go/state/protocol" ) -var ErrStateCommitmentPruned = fmt.Errorf("state commitment not found") - type Engine struct { unit *engine.Unit log zerolog.Logger - state protocol.State queryExecutor query.Executor execState state.ScriptExecutionState } @@ -29,14 +24,12 @@ var _ execution.ScriptExecutor = (*Engine)(nil) func New( logger zerolog.Logger, - state protocol.State, queryExecutor query.Executor, 
execState state.ScriptExecutionState, ) *Engine { return &Engine{ unit: engine.NewUnit(), log: logger.With().Str("engine", "scripts").Logger(), - state: state, execState: execState, queryExecutor: queryExecutor, } @@ -57,31 +50,11 @@ func (e *Engine) ExecuteScriptAtBlockID( blockID flow.Identifier, ) ([]byte, error) { - stateCommit, err := e.execState.StateCommitmentByBlockID(ctx, blockID) + blockSnapshot, header, err := e.execState.CreateStorageSnapshot(blockID) if err != nil { - return nil, fmt.Errorf("failed to get state commitment for block (%s): %w", blockID, err) - } - - // return early if state with the given state commitment is not in memory - // and already purged. This reduces allocations for scripts targeting old blocks. - if !e.execState.HasState(stateCommit) { - return nil, fmt.Errorf( - "failed to execute script at block (%s): %w (%s). "+ - "this error usually happens if the reference "+ - "block for this script is not set to a recent block.", - blockID.String(), - ErrStateCommitmentPruned, - hex.EncodeToString(stateCommit[:]), - ) + return nil, fmt.Errorf("failed to create storage snapshot: %w", err) } - header, err := e.state.AtBlockID(blockID).Head() - if err != nil { - return nil, fmt.Errorf("failed to get header (%s): %w", blockID, err) - } - - blockSnapshot := e.execState.NewStorageSnapshot(stateCommit) - return e.queryExecutor.ExecuteScript( ctx, script, @@ -96,13 +69,11 @@ func (e *Engine) GetRegisterAtBlockID( blockID flow.Identifier, ) ([]byte, error) { - stateCommit, err := e.execState.StateCommitmentByBlockID(ctx, blockID) + blockSnapshot, _, err := e.execState.CreateStorageSnapshot(blockID) if err != nil { - return nil, fmt.Errorf("failed to get state commitment for block (%s): %w", blockID, err) + return nil, fmt.Errorf("failed to create storage snapshot: %w", err) } - blockSnapshot := e.execState.NewStorageSnapshot(stateCommit) - id := flow.NewRegisterID(string(owner), string(key)) data, err := blockSnapshot.Get(id) if err != nil { @@ -117,29 +88,10 @@ func (e *Engine) GetAccount( addr flow.Address, blockID flow.Identifier, ) (*flow.Account, error) { - stateCommit, err := e.execState.StateCommitmentByBlockID(ctx, blockID) + blockSnapshot, header, err := e.execState.CreateStorageSnapshot(blockID) if err != nil { - return nil, fmt.Errorf("failed to get state commitment for block (%s): %w", blockID, err) - } - - // return early if state with the given state commitment is not in memory - // and already purged. This reduces allocations for get accounts targeting old blocks. - if !e.execState.HasState(stateCommit) { - return nil, fmt.Errorf( - "failed to get account at block (%s): %w (%s). 
"+ - "this error usually happens if the reference "+ - "block for this script is not set to a recent block.", - blockID.String(), - ErrStateCommitmentPruned, - hex.EncodeToString(stateCommit[:])) + return nil, fmt.Errorf("failed to create storage snapshot: %w", err) } - block, err := e.state.AtBlockID(blockID).Head() - if err != nil { - return nil, fmt.Errorf("failed to get block (%s): %w", blockID, err) - } - - blockSnapshot := e.execState.NewStorageSnapshot(stateCommit) - - return e.queryExecutor.GetAccount(ctx, addr, block, blockSnapshot) + return e.queryExecutor.GetAccount(ctx, addr, header, blockSnapshot) } diff --git a/engine/execution/scripts/engine_test.go b/engine/execution/scripts/engine_test.go deleted file mode 100644 index 5b5c116830f..00000000000 --- a/engine/execution/scripts/engine_test.go +++ /dev/null @@ -1,114 +0,0 @@ -package scripts - -import ( - "context" - "strings" - "sync" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" - - queryMock "github.com/onflow/flow-go/engine/execution/computation/query/mock" - stateMock "github.com/onflow/flow-go/engine/execution/state/mock" - "github.com/onflow/flow-go/model/flow" - protocol "github.com/onflow/flow-go/state/protocol/mock" - "github.com/onflow/flow-go/utils/unittest" -) - -type testingContext struct { - t *testing.T - engine *Engine - state *protocol.State - executionState *stateMock.ExecutionState - queryExecutor *queryMock.Executor - mu *sync.Mutex -} - -func (ctx *testingContext) stateCommitmentExist(blockID flow.Identifier, commit flow.StateCommitment) { - ctx.executionState.On("StateCommitmentByBlockID", mock.Anything, blockID).Return(commit, nil) -} - -func runWithEngine(t *testing.T, fn func(ctx testingContext)) { - log := unittest.Logger() - - queryExecutor := new(queryMock.Executor) - protocolState := new(protocol.State) - execState := new(stateMock.ExecutionState) - - engine := New(log, protocolState, queryExecutor, execState) - fn(testingContext{ - t: t, - engine: engine, - queryExecutor: queryExecutor, - executionState: execState, - state: protocolState, - }) -} - -func TestExecuteScriptAtBlockID(t *testing.T) { - t.Run("happy path", func(t *testing.T) { - runWithEngine(t, func(ctx testingContext) { - // Meaningless script - script := []byte{1, 1, 2, 3, 5, 8, 11} - scriptResult := []byte{1} - - // Ensure block we're about to query against is executable - blockA := unittest.ExecutableBlockFixture(nil, unittest.StateCommitmentPointerFixture()) - - snapshot := new(protocol.Snapshot) - snapshot.On("Head").Return(blockA.Block.Header, nil) - - commits := make(map[flow.Identifier]flow.StateCommitment) - commits[blockA.ID()] = *blockA.StartState - - ctx.stateCommitmentExist(blockA.ID(), *blockA.StartState) - - ctx.state.On("AtBlockID", blockA.Block.ID()).Return(snapshot) - ctx.executionState.On("NewStorageSnapshot", *blockA.StartState).Return(nil) - - ctx.executionState.On("HasState", *blockA.StartState).Return(true) - - // Successful call to computation manager - ctx.queryExecutor. - On("ExecuteScript", mock.Anything, script, [][]byte(nil), blockA.Block.Header, nil). 
- Return(scriptResult, nil) - - // Execute our script and expect no error - res, err := ctx.engine.ExecuteScriptAtBlockID(context.Background(), script, nil, blockA.Block.ID()) - assert.NoError(t, err) - assert.Equal(t, scriptResult, res) - - // Assert other components were called as expected - ctx.queryExecutor.AssertExpectations(t) - ctx.executionState.AssertExpectations(t) - ctx.state.AssertExpectations(t) - }) - }) - - t.Run("return early when state commitment not exist", func(t *testing.T) { - runWithEngine(t, func(ctx testingContext) { - // Meaningless script - script := []byte{1, 1, 2, 3, 5, 8, 11} - - // Ensure block we're about to query against is executable - blockA := unittest.ExecutableBlockFixture(nil, unittest.StateCommitmentPointerFixture()) - - // make sure blockID to state commitment mapping exist - ctx.executionState.On("StateCommitmentByBlockID", mock.Anything, blockA.ID()).Return(*blockA.StartState, nil) - - // but the state commitment does not exist (e.g. purged) - ctx.executionState.On("HasState", *blockA.StartState).Return(false) - - // Execute our script and expect no error - _, err := ctx.engine.ExecuteScriptAtBlockID(context.Background(), script, nil, blockA.Block.ID()) - assert.Error(t, err) - assert.True(t, strings.Contains(err.Error(), "state commitment not found")) - - // Assert other components were called as expected - ctx.executionState.AssertExpectations(t) - ctx.state.AssertExpectations(t) - }) - }) - -} diff --git a/engine/execution/state/bootstrap/bootstrap.go b/engine/execution/state/bootstrap/bootstrap.go index 0addc1665d0..b8cdc1192e2 100644 --- a/engine/execution/state/bootstrap/bootstrap.go +++ b/engine/execution/state/bootstrap/bootstrap.go @@ -8,6 +8,7 @@ import ( "github.com/rs/zerolog" "github.com/onflow/flow-go/engine/execution/state" + "github.com/onflow/flow-go/engine/execution/storehouse" "github.com/onflow/flow-go/fvm" "github.com/onflow/flow-go/fvm/storage/snapshot" "github.com/onflow/flow-go/ledger" @@ -36,9 +37,10 @@ func (b *Bootstrapper) BootstrapLedger( chain flow.Chain, opts ...fvm.BootstrapProcedureOption, ) (flow.StateCommitment, error) { + startCommit := flow.StateCommitment(ledger.InitialState()) storageSnapshot := state.NewLedgerStorageSnapshot( ledger, - flow.StateCommitment(ledger.InitialState())) + startCommit) vm := fvm.NewVirtualMachine() @@ -58,10 +60,11 @@ func (b *Bootstrapper) BootstrapLedger( return flow.DummyStateCommitment, err } - newStateCommitment, _, err := state.CommitDelta( + newStateCommitment, _, _, err := state.CommitDelta( ledger, executionSnapshot, - flow.StateCommitment(ledger.InitialState())) + storehouse.NewExecutingBlockSnapshot(storageSnapshot, startCommit), + ) if err != nil { return flow.DummyStateCommitment, err } diff --git a/engine/execution/state/mock/execution_state.go b/engine/execution/state/mock/execution_state.go index f847632cd94..0750c4a3853 100644 --- a/engine/execution/state/mock/execution_state.go +++ b/engine/execution/state/mock/execution_state.go @@ -44,6 +44,41 @@ func (_m *ExecutionState) ChunkDataPackByChunkID(_a0 flow.Identifier) (*flow.Chu return r0, r1 } +// CreateStorageSnapshot provides a mock function with given fields: blockID +func (_m *ExecutionState) CreateStorageSnapshot(blockID flow.Identifier) (snapshot.StorageSnapshot, *flow.Header, error) { + ret := _m.Called(blockID) + + var r0 snapshot.StorageSnapshot + var r1 *flow.Header + var r2 error + if rf, ok := ret.Get(0).(func(flow.Identifier) (snapshot.StorageSnapshot, *flow.Header, error)); ok { + return rf(blockID) + 
} + if rf, ok := ret.Get(0).(func(flow.Identifier) snapshot.StorageSnapshot); ok { + r0 = rf(blockID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(snapshot.StorageSnapshot) + } + } + + if rf, ok := ret.Get(1).(func(flow.Identifier) *flow.Header); ok { + r1 = rf(blockID) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(*flow.Header) + } + } + + if rf, ok := ret.Get(2).(func(flow.Identifier) error); ok { + r2 = rf(blockID) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + // GetExecutionResultID provides a mock function with given fields: _a0, _a1 func (_m *ExecutionState) GetExecutionResultID(_a0 context.Context, _a1 flow.Identifier) (flow.Identifier, error) { ret := _m.Called(_a0, _a1) @@ -117,13 +152,37 @@ func (_m *ExecutionState) HasState(_a0 flow.StateCommitment) bool { return r0 } -// NewStorageSnapshot provides a mock function with given fields: _a0 -func (_m *ExecutionState) NewStorageSnapshot(_a0 flow.StateCommitment) snapshot.StorageSnapshot { - ret := _m.Called(_a0) +// IsBlockExecuted provides a mock function with given fields: height, blockID +func (_m *ExecutionState) IsBlockExecuted(height uint64, blockID flow.Identifier) (bool, error) { + ret := _m.Called(height, blockID) + + var r0 bool + var r1 error + if rf, ok := ret.Get(0).(func(uint64, flow.Identifier) (bool, error)); ok { + return rf(height, blockID) + } + if rf, ok := ret.Get(0).(func(uint64, flow.Identifier) bool); ok { + r0 = rf(height, blockID) + } else { + r0 = ret.Get(0).(bool) + } + + if rf, ok := ret.Get(1).(func(uint64, flow.Identifier) error); ok { + r1 = rf(height, blockID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// NewStorageSnapshot provides a mock function with given fields: commit, blockID, height +func (_m *ExecutionState) NewStorageSnapshot(commit flow.StateCommitment, blockID flow.Identifier, height uint64) snapshot.StorageSnapshot { + ret := _m.Called(commit, blockID, height) var r0 snapshot.StorageSnapshot - if rf, ok := ret.Get(0).(func(flow.StateCommitment) snapshot.StorageSnapshot); ok { - r0 = rf(_a0) + if rf, ok := ret.Get(0).(func(flow.StateCommitment, flow.Identifier, uint64) snapshot.StorageSnapshot); ok { + r0 = rf(commit, blockID, height) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(snapshot.StorageSnapshot) @@ -147,25 +206,25 @@ func (_m *ExecutionState) SaveExecutionResults(ctx context.Context, result *exec return r0 } -// StateCommitmentByBlockID provides a mock function with given fields: _a0, _a1 -func (_m *ExecutionState) StateCommitmentByBlockID(_a0 context.Context, _a1 flow.Identifier) (flow.StateCommitment, error) { - ret := _m.Called(_a0, _a1) +// StateCommitmentByBlockID provides a mock function with given fields: _a0 +func (_m *ExecutionState) StateCommitmentByBlockID(_a0 flow.Identifier) (flow.StateCommitment, error) { + ret := _m.Called(_a0) var r0 flow.StateCommitment var r1 error - if rf, ok := ret.Get(0).(func(context.Context, flow.Identifier) (flow.StateCommitment, error)); ok { - return rf(_a0, _a1) + if rf, ok := ret.Get(0).(func(flow.Identifier) (flow.StateCommitment, error)); ok { + return rf(_a0) } - if rf, ok := ret.Get(0).(func(context.Context, flow.Identifier) flow.StateCommitment); ok { - r0 = rf(_a0, _a1) + if rf, ok := ret.Get(0).(func(flow.Identifier) flow.StateCommitment); ok { + r0 = rf(_a0) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(flow.StateCommitment) } } - if rf, ok := ret.Get(1).(func(context.Context, flow.Identifier) error); ok { - r1 = rf(_a0, _a1) + if rf, ok := 
ret.Get(1).(func(flow.Identifier) error); ok { + r1 = rf(_a0) } else { r1 = ret.Error(1) } diff --git a/engine/execution/state/mock/finalized_execution_state.go b/engine/execution/state/mock/finalized_execution_state.go new file mode 100644 index 00000000000..ae878be58e0 --- /dev/null +++ b/engine/execution/state/mock/finalized_execution_state.go @@ -0,0 +1,39 @@ +// Code generated by mockery v2.21.4. DO NOT EDIT. + +package mock + +import mock "github.com/stretchr/testify/mock" + +// FinalizedExecutionState is an autogenerated mock type for the FinalizedExecutionState type +type FinalizedExecutionState struct { + mock.Mock +} + +// GetHighestFinalizedExecuted provides a mock function with given fields: +func (_m *FinalizedExecutionState) GetHighestFinalizedExecuted() uint64 { + ret := _m.Called() + + var r0 uint64 + if rf, ok := ret.Get(0).(func() uint64); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(uint64) + } + + return r0 +} + +type mockConstructorTestingTNewFinalizedExecutionState interface { + mock.TestingT + Cleanup(func()) +} + +// NewFinalizedExecutionState creates a new instance of FinalizedExecutionState. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +func NewFinalizedExecutionState(t mockConstructorTestingTNewFinalizedExecutionState) *FinalizedExecutionState { + mock := &FinalizedExecutionState{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/engine/execution/state/mock/read_only_execution_state.go b/engine/execution/state/mock/read_only_execution_state.go index 24f230ed316..108abf02bfe 100644 --- a/engine/execution/state/mock/read_only_execution_state.go +++ b/engine/execution/state/mock/read_only_execution_state.go @@ -42,6 +42,41 @@ func (_m *ReadOnlyExecutionState) ChunkDataPackByChunkID(_a0 flow.Identifier) (* return r0, r1 } +// CreateStorageSnapshot provides a mock function with given fields: blockID +func (_m *ReadOnlyExecutionState) CreateStorageSnapshot(blockID flow.Identifier) (snapshot.StorageSnapshot, *flow.Header, error) { + ret := _m.Called(blockID) + + var r0 snapshot.StorageSnapshot + var r1 *flow.Header + var r2 error + if rf, ok := ret.Get(0).(func(flow.Identifier) (snapshot.StorageSnapshot, *flow.Header, error)); ok { + return rf(blockID) + } + if rf, ok := ret.Get(0).(func(flow.Identifier) snapshot.StorageSnapshot); ok { + r0 = rf(blockID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(snapshot.StorageSnapshot) + } + } + + if rf, ok := ret.Get(1).(func(flow.Identifier) *flow.Header); ok { + r1 = rf(blockID) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(*flow.Header) + } + } + + if rf, ok := ret.Get(2).(func(flow.Identifier) error); ok { + r2 = rf(blockID) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + // GetExecutionResultID provides a mock function with given fields: _a0, _a1 func (_m *ReadOnlyExecutionState) GetExecutionResultID(_a0 context.Context, _a1 flow.Identifier) (flow.Identifier, error) { ret := _m.Called(_a0, _a1) @@ -115,13 +150,37 @@ func (_m *ReadOnlyExecutionState) HasState(_a0 flow.StateCommitment) bool { return r0 } -// NewStorageSnapshot provides a mock function with given fields: _a0 -func (_m *ReadOnlyExecutionState) NewStorageSnapshot(_a0 flow.StateCommitment) snapshot.StorageSnapshot { - ret := _m.Called(_a0) +// IsBlockExecuted provides a mock function with given fields: height, blockID +func (_m *ReadOnlyExecutionState) IsBlockExecuted(height uint64, blockID flow.Identifier) (bool, 
error) { + ret := _m.Called(height, blockID) + + var r0 bool + var r1 error + if rf, ok := ret.Get(0).(func(uint64, flow.Identifier) (bool, error)); ok { + return rf(height, blockID) + } + if rf, ok := ret.Get(0).(func(uint64, flow.Identifier) bool); ok { + r0 = rf(height, blockID) + } else { + r0 = ret.Get(0).(bool) + } + + if rf, ok := ret.Get(1).(func(uint64, flow.Identifier) error); ok { + r1 = rf(height, blockID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// NewStorageSnapshot provides a mock function with given fields: commit, blockID, height +func (_m *ReadOnlyExecutionState) NewStorageSnapshot(commit flow.StateCommitment, blockID flow.Identifier, height uint64) snapshot.StorageSnapshot { + ret := _m.Called(commit, blockID, height) var r0 snapshot.StorageSnapshot - if rf, ok := ret.Get(0).(func(flow.StateCommitment) snapshot.StorageSnapshot); ok { - r0 = rf(_a0) + if rf, ok := ret.Get(0).(func(flow.StateCommitment, flow.Identifier, uint64) snapshot.StorageSnapshot); ok { + r0 = rf(commit, blockID, height) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(snapshot.StorageSnapshot) @@ -131,25 +190,25 @@ func (_m *ReadOnlyExecutionState) NewStorageSnapshot(_a0 flow.StateCommitment) s return r0 } -// StateCommitmentByBlockID provides a mock function with given fields: _a0, _a1 -func (_m *ReadOnlyExecutionState) StateCommitmentByBlockID(_a0 context.Context, _a1 flow.Identifier) (flow.StateCommitment, error) { - ret := _m.Called(_a0, _a1) +// StateCommitmentByBlockID provides a mock function with given fields: _a0 +func (_m *ReadOnlyExecutionState) StateCommitmentByBlockID(_a0 flow.Identifier) (flow.StateCommitment, error) { + ret := _m.Called(_a0) var r0 flow.StateCommitment var r1 error - if rf, ok := ret.Get(0).(func(context.Context, flow.Identifier) (flow.StateCommitment, error)); ok { - return rf(_a0, _a1) + if rf, ok := ret.Get(0).(func(flow.Identifier) (flow.StateCommitment, error)); ok { + return rf(_a0) } - if rf, ok := ret.Get(0).(func(context.Context, flow.Identifier) flow.StateCommitment); ok { - r0 = rf(_a0, _a1) + if rf, ok := ret.Get(0).(func(flow.Identifier) flow.StateCommitment); ok { + r0 = rf(_a0) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(flow.StateCommitment) } } - if rf, ok := ret.Get(1).(func(context.Context, flow.Identifier) error); ok { - r1 = rf(_a0, _a1) + if rf, ok := ret.Get(1).(func(flow.Identifier) error); ok { + r1 = rf(_a0) } else { r1 = ret.Error(1) } diff --git a/engine/execution/state/mock/register_updates_holder.go b/engine/execution/state/mock/register_updates_holder.go index 69c58edf06f..dd4239d2f6d 100644 --- a/engine/execution/state/mock/register_updates_holder.go +++ b/engine/execution/state/mock/register_updates_holder.go @@ -12,6 +12,22 @@ type RegisterUpdatesHolder struct { mock.Mock } +// UpdatedRegisterSet provides a mock function with given fields: +func (_m *RegisterUpdatesHolder) UpdatedRegisterSet() map[flow.RegisterID][]byte { + ret := _m.Called() + + var r0 map[flow.RegisterID][]byte + if rf, ok := ret.Get(0).(func() map[flow.RegisterID][]byte); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(map[flow.RegisterID][]byte) + } + } + + return r0 +} + // UpdatedRegisters provides a mock function with given fields: func (_m *RegisterUpdatesHolder) UpdatedRegisters() flow.RegisterEntries { ret := _m.Called() diff --git a/engine/execution/state/mock/script_execution_state.go b/engine/execution/state/mock/script_execution_state.go index 904defab7fa..520d699b3bb 100644 --- 
a/engine/execution/state/mock/script_execution_state.go +++ b/engine/execution/state/mock/script_execution_state.go @@ -3,8 +3,6 @@ package mock import ( - context "context" - flow "github.com/onflow/flow-go/model/flow" mock "github.com/stretchr/testify/mock" @@ -16,6 +14,41 @@ type ScriptExecutionState struct { mock.Mock } +// CreateStorageSnapshot provides a mock function with given fields: blockID +func (_m *ScriptExecutionState) CreateStorageSnapshot(blockID flow.Identifier) (snapshot.StorageSnapshot, *flow.Header, error) { + ret := _m.Called(blockID) + + var r0 snapshot.StorageSnapshot + var r1 *flow.Header + var r2 error + if rf, ok := ret.Get(0).(func(flow.Identifier) (snapshot.StorageSnapshot, *flow.Header, error)); ok { + return rf(blockID) + } + if rf, ok := ret.Get(0).(func(flow.Identifier) snapshot.StorageSnapshot); ok { + r0 = rf(blockID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(snapshot.StorageSnapshot) + } + } + + if rf, ok := ret.Get(1).(func(flow.Identifier) *flow.Header); ok { + r1 = rf(blockID) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(*flow.Header) + } + } + + if rf, ok := ret.Get(2).(func(flow.Identifier) error); ok { + r2 = rf(blockID) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + // HasState provides a mock function with given fields: _a0 func (_m *ScriptExecutionState) HasState(_a0 flow.StateCommitment) bool { ret := _m.Called(_a0) @@ -30,13 +63,37 @@ func (_m *ScriptExecutionState) HasState(_a0 flow.StateCommitment) bool { return r0 } -// NewStorageSnapshot provides a mock function with given fields: _a0 -func (_m *ScriptExecutionState) NewStorageSnapshot(_a0 flow.StateCommitment) snapshot.StorageSnapshot { - ret := _m.Called(_a0) +// IsBlockExecuted provides a mock function with given fields: height, blockID +func (_m *ScriptExecutionState) IsBlockExecuted(height uint64, blockID flow.Identifier) (bool, error) { + ret := _m.Called(height, blockID) + + var r0 bool + var r1 error + if rf, ok := ret.Get(0).(func(uint64, flow.Identifier) (bool, error)); ok { + return rf(height, blockID) + } + if rf, ok := ret.Get(0).(func(uint64, flow.Identifier) bool); ok { + r0 = rf(height, blockID) + } else { + r0 = ret.Get(0).(bool) + } + + if rf, ok := ret.Get(1).(func(uint64, flow.Identifier) error); ok { + r1 = rf(height, blockID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// NewStorageSnapshot provides a mock function with given fields: commit, blockID, height +func (_m *ScriptExecutionState) NewStorageSnapshot(commit flow.StateCommitment, blockID flow.Identifier, height uint64) snapshot.StorageSnapshot { + ret := _m.Called(commit, blockID, height) var r0 snapshot.StorageSnapshot - if rf, ok := ret.Get(0).(func(flow.StateCommitment) snapshot.StorageSnapshot); ok { - r0 = rf(_a0) + if rf, ok := ret.Get(0).(func(flow.StateCommitment, flow.Identifier, uint64) snapshot.StorageSnapshot); ok { + r0 = rf(commit, blockID, height) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(snapshot.StorageSnapshot) @@ -46,25 +103,25 @@ func (_m *ScriptExecutionState) NewStorageSnapshot(_a0 flow.StateCommitment) sna return r0 } -// StateCommitmentByBlockID provides a mock function with given fields: _a0, _a1 -func (_m *ScriptExecutionState) StateCommitmentByBlockID(_a0 context.Context, _a1 flow.Identifier) (flow.StateCommitment, error) { - ret := _m.Called(_a0, _a1) +// StateCommitmentByBlockID provides a mock function with given fields: _a0 +func (_m *ScriptExecutionState) StateCommitmentByBlockID(_a0 flow.Identifier) 
(flow.StateCommitment, error) { + ret := _m.Called(_a0) var r0 flow.StateCommitment var r1 error - if rf, ok := ret.Get(0).(func(context.Context, flow.Identifier) (flow.StateCommitment, error)); ok { - return rf(_a0, _a1) + if rf, ok := ret.Get(0).(func(flow.Identifier) (flow.StateCommitment, error)); ok { + return rf(_a0) } - if rf, ok := ret.Get(0).(func(context.Context, flow.Identifier) flow.StateCommitment); ok { - r0 = rf(_a0, _a1) + if rf, ok := ret.Get(0).(func(flow.Identifier) flow.StateCommitment); ok { + r0 = rf(_a0) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(flow.StateCommitment) } } - if rf, ok := ret.Get(1).(func(context.Context, flow.Identifier) error); ok { - r1 = rf(_a0, _a1) + if rf, ok := ret.Get(1).(func(flow.Identifier) error); ok { + r1 = rf(_a0) } else { r1 = ret.Error(1) } diff --git a/engine/execution/state/state.go b/engine/execution/state/state.go index 3d996475ff9..5df8f473e54 100644 --- a/engine/execution/state/state.go +++ b/engine/execution/state/state.go @@ -21,6 +21,9 @@ import ( "github.com/onflow/flow-go/storage/badger/procedure" ) +var ErrExecutionStatePruned = fmt.Errorf("execution state is pruned") +var ErrNotExecuted = fmt.Errorf("block not executed") + // ReadOnlyExecutionState allows to read the execution state type ReadOnlyExecutionState interface { ScriptExecutionState @@ -36,14 +39,37 @@ type ReadOnlyExecutionState interface { // ScriptExecutionState is a subset of the `state.ExecutionState` interface purposed to only access the state // used for script execution and not mutate the execution state of the blockchain. type ScriptExecutionState interface { - // NewStorageSnapshot creates a new ready-only view at the given state commitment. - NewStorageSnapshot(flow.StateCommitment) snapshot.StorageSnapshot + // NewStorageSnapshot creates a new ready-only view at the given block. + NewStorageSnapshot(commit flow.StateCommitment, blockID flow.Identifier, height uint64) snapshot.StorageSnapshot + + // CreateStorageSnapshot creates a new ready-only view at the given block. + // It returns: + // - (nil, nil, storage.ErrNotFound) if block is unknown + // - (nil, nil, state.ErrNotExecuted) if block is not executed + // - (nil, nil, state.ErrExecutionStatePruned) if the execution state has been pruned + CreateStorageSnapshot(blockID flow.Identifier) (snapshot.StorageSnapshot, *flow.Header, error) // StateCommitmentByBlockID returns the final state commitment for the provided block ID. 
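Editor's note: CreateStorageSnapshot deliberately returns typed sentinel errors so callers can distinguish an unknown block from a known-but-unexecuted block and from an executed-but-pruned one. The sketch below is a hedged illustration of the caller-side pattern; snapshotAtBlock is a hypothetical wrapper, and the gRPC code mapping simply mirrors the GetAccountAtBlockID change earlier in this diff.

package example

import (
	"errors"

	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"

	"github.com/onflow/flow-go/engine/execution/state"
	"github.com/onflow/flow-go/fvm/storage/snapshot"
	"github.com/onflow/flow-go/model/flow"
	"github.com/onflow/flow-go/storage"
)

// snapshotAtBlock shows how the sentinel errors documented on
// CreateStorageSnapshot translate into gRPC status codes.
func snapshotAtBlock(es state.ScriptExecutionState, blockID flow.Identifier) (snapshot.StorageSnapshot, *flow.Header, error) {
	snap, header, err := es.CreateStorageSnapshot(blockID)
	switch {
	case err == nil:
		return snap, header, nil
	case errors.Is(err, storage.ErrNotFound):
		// the block itself is unknown to this node
		return nil, nil, status.Errorf(codes.NotFound, "block %v not found", blockID)
	case errors.Is(err, state.ErrNotExecuted):
		// the block is known but has not been executed yet
		return nil, nil, status.Errorf(codes.NotFound, "block %v has not been executed", blockID)
	case errors.Is(err, state.ErrExecutionStatePruned):
		// the block was executed, but its state has since been pruned
		return nil, nil, status.Errorf(codes.OutOfRange, "state for block %v is not available", blockID)
	default:
		// anything else is an exception
		return nil, nil, status.Errorf(codes.Internal, "failed to create storage snapshot: %v", err)
	}
}
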
- StateCommitmentByBlockID(context.Context, flow.Identifier) (flow.StateCommitment, error) + StateCommitmentByBlockID(flow.Identifier) (flow.StateCommitment, error) // HasState returns true if the state with the given state commitment exists in memory HasState(flow.StateCommitment) bool + + // Any error returned is exception + IsBlockExecuted(height uint64, blockID flow.Identifier) (bool, error) +} + +func IsParentExecuted(state ReadOnlyExecutionState, header *flow.Header) (bool, error) { + // sanity check, caller should not pass a root block + if header.Height == 0 { + return false, fmt.Errorf("root block does not have parent block") + } + return state.IsBlockExecuted(header.Height-1, header.ParentID) +} + +// FinalizedExecutionState is an interface used to access the finalized execution state +type FinalizedExecutionState interface { + GetHighestFinalizedExecuted() uint64 } // TODO Many operations here are should be transactional, so we need to refactor this @@ -213,36 +239,80 @@ func (storage *LedgerStorageSnapshot) Get( func (s *state) NewStorageSnapshot( commitment flow.StateCommitment, + blockID flow.Identifier, + height uint64, ) snapshot.StorageSnapshot { return NewLedgerStorageSnapshot(s.ls, commitment) } +func (s *state) CreateStorageSnapshot( + blockID flow.Identifier, +) (snapshot.StorageSnapshot, *flow.Header, error) { + header, err := s.headers.ByBlockID(blockID) + if err != nil { + return nil, nil, fmt.Errorf("cannot get header by block ID: %w", err) + } + + // make sure the block is executed + commit, err := s.commits.ByBlockID(blockID) + if err != nil { + // statecommitment not exists means the block hasn't been executed yet + if errors.Is(err, storage.ErrNotFound) { + return nil, nil, fmt.Errorf("block %v is not executed: %w", blockID, ErrNotExecuted) + } + + return nil, header, fmt.Errorf("cannot get commit by block ID: %w", err) + } + + // make sure we have trie state for this block + if !s.HasState(commit) { + return nil, header, fmt.Errorf("state not found for commit %x (block %v): %w", commit, blockID, ErrExecutionStatePruned) + } + + return s.NewStorageSnapshot(commit, blockID, header.Height), header, nil +} + type RegisterUpdatesHolder interface { UpdatedRegisters() flow.RegisterEntries + UpdatedRegisterSet() map[flow.RegisterID]flow.RegisterValue } -func CommitDelta(ldg ledger.Ledger, ruh RegisterUpdatesHolder, baseState flow.StateCommitment) (flow.StateCommitment, *ledger.TrieUpdate, error) { - keys, values := RegisterEntriesToKeysValues(ruh.UpdatedRegisters()) +// CommitDelta takes a base storage snapshot and creates a new storage snapshot +// with the register updates from the given RegisterUpdatesHolder +// a new statecommitment is returned from the ledger, along with the trie update +// any error returned are exceptions +func CommitDelta( + ldg ledger.Ledger, + ruh RegisterUpdatesHolder, + baseStorageSnapshot execution.ExtendableStorageSnapshot, +) (flow.StateCommitment, *ledger.TrieUpdate, execution.ExtendableStorageSnapshot, error) { + updatedRegisters := ruh.UpdatedRegisters() + keys, values := RegisterEntriesToKeysValues(updatedRegisters) + baseState := baseStorageSnapshot.Commitment() update, err := ledger.NewUpdate(ledger.State(baseState), keys, values) if err != nil { - return flow.DummyStateCommitment, nil, fmt.Errorf("cannot create ledger update: %w", err) + return flow.DummyStateCommitment, nil, nil, fmt.Errorf("cannot create ledger update: %w", err) } - commit, trieUpdate, err := ldg.Set(update) + newState, trieUpdate, err := ldg.Set(update) if err 
!= nil { - return flow.DummyStateCommitment, nil, err + return flow.DummyStateCommitment, nil, nil, fmt.Errorf("could not update ledger: %w", err) } - return flow.StateCommitment(commit), trieUpdate, nil + newCommit := flow.StateCommitment(newState) + + newStorageSnapshot := baseStorageSnapshot.Extend(newCommit, ruh.UpdatedRegisterSet()) + + return newCommit, trieUpdate, newStorageSnapshot, nil } func (s *state) HasState(commitment flow.StateCommitment) bool { return s.ls.HasState(ledger.State(commitment)) } -func (s *state) StateCommitmentByBlockID(ctx context.Context, blockID flow.Identifier) (flow.StateCommitment, error) { +func (s *state) StateCommitmentByBlockID(blockID flow.Identifier) (flow.StateCommitment, error) { return s.commits.ByBlockID(blockID) } @@ -390,10 +460,12 @@ func (s *state) GetHighestExecutedBlockID(ctx context.Context) (uint64, flow.Ide } // IsBlockExecuted returns true if the block is executed, which means registers, events, -// results, statecommitment etc are all stored. +// results, etc are all stored. // otherwise returns false -func IsBlockExecuted(ctx context.Context, state ReadOnlyExecutionState, block flow.Identifier) (bool, error) { - _, err := state.StateCommitmentByBlockID(ctx, block) +func (s *state) IsBlockExecuted(height uint64, blockID flow.Identifier) (bool, error) { + // ledger-based execution state uses commitment to determine if a block has been executed + // TODO: storehouse-based execution state will check its storage to determine if a block has been executed + _, err := s.StateCommitmentByBlockID(blockID) // statecommitment exists means the block has been executed if err == nil { diff --git a/engine/execution/state/state_test.go b/engine/execution/state/state_test.go index 6d6833837f0..31699566e48 100644 --- a/engine/execution/state/state_test.go +++ b/engine/execution/state/state_test.go @@ -1,31 +1,32 @@ package state_test import ( - "context" + "errors" + "fmt" "testing" "github.com/dgraph-io/badger/v2" - "github.com/golang/mock/gomock" "github.com/rs/zerolog" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - ledger2 "github.com/onflow/flow-go/ledger" + "github.com/onflow/flow-go/ledger/common/convert" "github.com/onflow/flow-go/ledger/common/pathfinder" "github.com/onflow/flow-go/engine/execution/state" + "github.com/onflow/flow-go/engine/execution/storehouse" "github.com/onflow/flow-go/fvm/storage/snapshot" ledger "github.com/onflow/flow-go/ledger/complete" "github.com/onflow/flow-go/ledger/complete/wal/fixtures" "github.com/onflow/flow-go/model/flow" "github.com/onflow/flow-go/module/metrics" "github.com/onflow/flow-go/module/trace" + storageerr "github.com/onflow/flow-go/storage" storage "github.com/onflow/flow-go/storage/mock" - "github.com/onflow/flow-go/storage/mocks" "github.com/onflow/flow-go/utils/unittest" ) -func prepareTest(f func(t *testing.T, es state.ExecutionState, l *ledger.Ledger)) func(*testing.T) { +func prepareTest(f func(t *testing.T, es state.ExecutionState, l *ledger.Ledger, headers *storage.Headers, commits *storage.Commits)) func(*testing.T) { return func(t *testing.T) { unittest.RunWithBadgerDB(t, func(badgerDB *badger.DB) { metricsCollector := &metrics.NoopCollector{} @@ -39,74 +40,70 @@ func prepareTest(f func(t *testing.T, es state.ExecutionState, l *ledger.Ledger) <-compactor.Done() }() - ctrl := gomock.NewController(t) - - stateCommitments := mocks.NewMockCommits(ctrl) - blocks := mocks.NewMockBlocks(ctrl) - headers := mocks.NewMockHeaders(ctrl) - collections := 
mocks.NewMockCollections(ctrl) - events := mocks.NewMockEvents(ctrl) - serviceEvents := mocks.NewMockServiceEvents(ctrl) - txResults := mocks.NewMockTransactionResults(ctrl) - - stateCommitment := ls.InitialState() - - stateCommitments.EXPECT().ByBlockID(gomock.Any()).Return(flow.StateCommitment(stateCommitment), nil) - - chunkDataPacks := new(storage.ChunkDataPacks) - - results := new(storage.ExecutionResults) - myReceipts := new(storage.MyExecutionReceipts) + stateCommitments := storage.NewCommits(t) + headers := storage.NewHeaders(t) + blocks := storage.NewBlocks(t) + collections := storage.NewCollections(t) + events := storage.NewEvents(t) + serviceEvents := storage.NewServiceEvents(t) + txResults := storage.NewTransactionResults(t) + chunkDataPacks := storage.NewChunkDataPacks(t) + results := storage.NewExecutionResults(t) + myReceipts := storage.NewMyExecutionReceipts(t) es := state.NewExecutionState( ls, stateCommitments, blocks, headers, collections, chunkDataPacks, results, myReceipts, events, serviceEvents, txResults, badgerDB, trace.NewNoopTracer(), ) - f(t, es, ls) + f(t, es, ls, headers, stateCommitments) }) } } func TestExecutionStateWithTrieStorage(t *testing.T) { - registerID1 := flow.NewRegisterID("fruit", "") - - registerID2 := flow.NewRegisterID("vegetable", "") - - t.Run("commit write and read new state", prepareTest(func(t *testing.T, es state.ExecutionState, l *ledger.Ledger) { - // TODO: use real block ID - sc1, err := es.StateCommitmentByBlockID(context.Background(), flow.Identifier{}) - assert.NoError(t, err) + t.Run("commit write and read new state", prepareTest(func( + t *testing.T, es state.ExecutionState, l *ledger.Ledger, headers *storage.Headers, stateCommitments *storage.Commits) { + header1 := unittest.BlockHeaderFixture() + sc1 := flow.StateCommitment(l.InitialState()) + reg1 := unittest.MakeOwnerReg("fruit", "apple") + reg2 := unittest.MakeOwnerReg("vegetable", "carrot") executionSnapshot := &snapshot.ExecutionSnapshot{ WriteSet: map[flow.RegisterID]flow.RegisterValue{ - registerID1: flow.RegisterValue("apple"), - registerID2: flow.RegisterValue("carrot"), + reg1.Key: reg1.Value, + reg2.Key: reg2.Value, }, } - sc2, update, err := state.CommitDelta(l, executionSnapshot, sc1) + sc2, update, sc2Snapshot, err := state.CommitDelta(l, executionSnapshot, + storehouse.NewExecutingBlockSnapshot(state.NewLedgerStorageSnapshot(l, sc1), sc1)) assert.NoError(t, err) + // validate new snapshot + val, err := sc2Snapshot.Get(reg1.Key) + require.NoError(t, err) + require.Equal(t, reg1.Value, val) + + val, err = sc2Snapshot.Get(reg2.Key) + require.NoError(t, err) + require.Equal(t, reg2.Value, val) + assert.Equal(t, sc1[:], update.RootHash[:]) assert.Len(t, update.Paths, 2) assert.Len(t, update.Payloads, 2) - key1 := ledger2.NewKey( - []ledger2.KeyPart{ - ledger2.NewKeyPart(0, []byte(registerID1.Owner)), - ledger2.NewKeyPart(2, []byte(registerID1.Key)), - }) + // validate sc2 + require.Equal(t, sc2, sc2Snapshot.Commitment()) + + key1 := convert.RegisterIDToLedgerKey(reg1.Key) path1, err := pathfinder.KeyToPath(key1, ledger.DefaultPathFinderVersion) assert.NoError(t, err) - key2 := ledger2.NewKey( - []ledger2.KeyPart{ - ledger2.NewKeyPart(0, []byte(registerID2.Owner)), - ledger2.NewKeyPart(2, []byte(registerID2.Key)), - }) + key2 := convert.RegisterIDToLedgerKey(reg2.Key) path2, err := pathfinder.KeyToPath(key2, ledger.DefaultPathFinderVersion) assert.NoError(t, err) + // validate update assert.Equal(t, path1, update.Paths[0]) assert.Equal(t, path2, update.Paths[1]) @@ 
-122,122 +119,205 @@ func TestExecutionStateWithTrieStorage(t *testing.T) { assert.Equal(t, []byte("apple"), []byte(update.Payloads[0].Value())) assert.Equal(t, []byte("carrot"), []byte(update.Payloads[1].Value())) - storageSnapshot := es.NewStorageSnapshot(sc2) + header2 := unittest.BlockHeaderWithParentFixture(header1) + storageSnapshot := es.NewStorageSnapshot(sc2, header2.ID(), header2.Height) - b1, err := storageSnapshot.Get(registerID1) + b1, err := storageSnapshot.Get(reg1.Key) assert.NoError(t, err) - b2, err := storageSnapshot.Get(registerID2) + b2, err := storageSnapshot.Get(reg2.Key) assert.NoError(t, err) assert.Equal(t, flow.RegisterValue("apple"), b1) assert.Equal(t, flow.RegisterValue("carrot"), b2) + + // verify has state + require.True(t, es.HasState(sc2)) + require.False(t, es.HasState(unittest.StateCommitmentFixture())) })) - t.Run("commit write and read previous state", prepareTest(func(t *testing.T, es state.ExecutionState, l *ledger.Ledger) { - // TODO: use real block ID - sc1, err := es.StateCommitmentByBlockID(context.Background(), flow.Identifier{}) - assert.NoError(t, err) + t.Run("commit write and read previous state", prepareTest(func( + t *testing.T, es state.ExecutionState, l *ledger.Ledger, headers *storage.Headers, stateCommitments *storage.Commits) { + header1 := unittest.BlockHeaderFixture() + sc1 := flow.StateCommitment(l.InitialState()) + reg1 := unittest.MakeOwnerReg("fruit", "apple") executionSnapshot1 := &snapshot.ExecutionSnapshot{ WriteSet: map[flow.RegisterID]flow.RegisterValue{ - registerID1: []byte("apple"), + reg1.Key: reg1.Value, }, } - sc2, _, err := state.CommitDelta(l, executionSnapshot1, sc1) + sc2, _, sc2Snapshot, err := state.CommitDelta(l, executionSnapshot1, + storehouse.NewExecutingBlockSnapshot(state.NewLedgerStorageSnapshot(l, sc1), sc1), + ) assert.NoError(t, err) // update value and get resulting state commitment executionSnapshot2 := &snapshot.ExecutionSnapshot{ WriteSet: map[flow.RegisterID]flow.RegisterValue{ - registerID1: []byte("orange"), + reg1.Key: flow.RegisterValue("orange"), }, } - sc3, _, err := state.CommitDelta(l, executionSnapshot2, sc2) + sc3, _, _, err := state.CommitDelta(l, executionSnapshot2, sc2Snapshot) assert.NoError(t, err) + header2 := unittest.BlockHeaderWithParentFixture(header1) // create a view for previous state version - storageSnapshot3 := es.NewStorageSnapshot(sc2) + storageSnapshot3 := es.NewStorageSnapshot(sc2, header2.ID(), header2.Height) + header3 := unittest.BlockHeaderWithParentFixture(header1) // create a view for new state version - storageSnapshot4 := es.NewStorageSnapshot(sc3) + storageSnapshot4 := es.NewStorageSnapshot(sc3, header3.ID(), header3.Height) + + // header2 and header3 are different blocks + assert.True(t, header2.ID() != (header3.ID())) // fetch the value at both versions - b1, err := storageSnapshot3.Get(registerID1) + b1, err := storageSnapshot3.Get(reg1.Key) assert.NoError(t, err) - b2, err := storageSnapshot4.Get(registerID1) + b2, err := storageSnapshot4.Get(reg1.Key) assert.NoError(t, err) assert.Equal(t, flow.RegisterValue("apple"), b1) assert.Equal(t, flow.RegisterValue("orange"), b2) })) - t.Run("commit delta and read new state", prepareTest(func(t *testing.T, es state.ExecutionState, l *ledger.Ledger) { - // TODO: use real block ID - sc1, err := es.StateCommitmentByBlockID(context.Background(), flow.Identifier{}) - assert.NoError(t, err) + t.Run("commit delta and read new state", prepareTest(func( + t *testing.T, es state.ExecutionState, l *ledger.Ledger, headers 
*storage.Headers, stateCommitments *storage.Commits) { + header1 := unittest.BlockHeaderFixture() + sc1 := flow.StateCommitment(l.InitialState()) + reg1 := unittest.MakeOwnerReg("fruit", "apple") + reg2 := unittest.MakeOwnerReg("vegetable", "carrot") // set initial value executionSnapshot1 := &snapshot.ExecutionSnapshot{ WriteSet: map[flow.RegisterID]flow.RegisterValue{ - registerID1: []byte("apple"), - registerID2: []byte("apple"), + reg1.Key: reg1.Value, + reg2.Key: reg2.Value, }, } - sc2, _, err := state.CommitDelta(l, executionSnapshot1, sc1) + sc2, _, sc2Snapshot, err := state.CommitDelta(l, executionSnapshot1, + storehouse.NewExecutingBlockSnapshot(state.NewLedgerStorageSnapshot(l, sc1), sc1), + ) assert.NoError(t, err) // update value and get resulting state commitment executionSnapshot2 := &snapshot.ExecutionSnapshot{ WriteSet: map[flow.RegisterID]flow.RegisterValue{ - registerID1: nil, + reg1.Key: nil, }, } - sc3, _, err := state.CommitDelta(l, executionSnapshot2, sc2) + sc3, _, _, err := state.CommitDelta(l, executionSnapshot2, sc2Snapshot) assert.NoError(t, err) + header2 := unittest.BlockHeaderWithParentFixture(header1) // create a view for previous state version - storageSnapshot3 := es.NewStorageSnapshot(sc2) + storageSnapshot3 := es.NewStorageSnapshot(sc2, header2.ID(), header2.Height) + header3 := unittest.BlockHeaderWithParentFixture(header2) // create a view for new state version - storageSnapshot4 := es.NewStorageSnapshot(sc3) + storageSnapshot4 := es.NewStorageSnapshot(sc3, header3.ID(), header3.Height) // fetch the value at both versions - b1, err := storageSnapshot3.Get(registerID1) + b1, err := storageSnapshot3.Get(reg1.Key) assert.NoError(t, err) - b2, err := storageSnapshot4.Get(registerID1) + b2, err := storageSnapshot4.Get(reg1.Key) assert.NoError(t, err) assert.Equal(t, flow.RegisterValue("apple"), b1) assert.Empty(t, b2) })) - t.Run("commit delta and persist state commit for the second time should be OK", prepareTest(func(t *testing.T, es state.ExecutionState, l *ledger.Ledger) { - // TODO: use real block ID - sc1, err := es.StateCommitmentByBlockID(context.Background(), flow.Identifier{}) - assert.NoError(t, err) + t.Run("commit delta and persist state commit for the second time should be OK", prepareTest(func( + t *testing.T, es state.ExecutionState, l *ledger.Ledger, headers *storage.Headers, stateCommitments *storage.Commits) { + sc1 := flow.StateCommitment(l.InitialState()) + reg1 := unittest.MakeOwnerReg("fruit", "apple") + reg2 := unittest.MakeOwnerReg("vegetable", "carrot") // set initial value executionSnapshot1 := &snapshot.ExecutionSnapshot{ WriteSet: map[flow.RegisterID]flow.RegisterValue{ - registerID1: flow.RegisterValue("apple"), - registerID2: flow.RegisterValue("apple"), + reg1.Key: reg1.Value, + reg2.Key: reg2.Value, }, } - sc2, _, err := state.CommitDelta(l, executionSnapshot1, sc1) + sc2, _, _, err := state.CommitDelta(l, executionSnapshot1, + storehouse.NewExecutingBlockSnapshot(state.NewLedgerStorageSnapshot(l, sc1), sc1), + ) assert.NoError(t, err) // committing for the second time should be OK - sc2Same, _, err := state.CommitDelta(l, executionSnapshot1, sc1) + sc2Same, _, _, err := state.CommitDelta(l, executionSnapshot1, + storehouse.NewExecutingBlockSnapshot(state.NewLedgerStorageSnapshot(l, sc1), sc1), + ) assert.NoError(t, err) require.Equal(t, sc2, sc2Same) })) + t.Run("commit write and create snapshot", prepareTest(func( + t *testing.T, es state.ExecutionState, l *ledger.Ledger, headers *storage.Headers, stateCommitments 
*storage.Commits) { + header1 := unittest.BlockHeaderFixture() + header2 := unittest.BlockHeaderWithParentFixture(header1) + sc1 := flow.StateCommitment(l.InitialState()) + + reg1 := unittest.MakeOwnerReg("fruit", "apple") + reg2 := unittest.MakeOwnerReg("vegetable", "carrot") + executionSnapshot := &snapshot.ExecutionSnapshot{ + WriteSet: map[flow.RegisterID]flow.RegisterValue{ + reg1.Key: reg1.Value, + reg2.Key: reg2.Value, + }, + } + + sc2, _, _, err := state.CommitDelta(l, executionSnapshot, + storehouse.NewExecutingBlockSnapshot(state.NewLedgerStorageSnapshot(l, sc1), sc1)) + assert.NoError(t, err) + + // test CreateStorageSnapshot for known and executed block + headers.On("ByBlockID", header2.ID()).Return(header2, nil) + stateCommitments.On("ByBlockID", header2.ID()).Return(sc2, nil) + snapshot2, h2, err := es.CreateStorageSnapshot(header2.ID()) + require.NoError(t, err) + require.Equal(t, header2.ID(), h2.ID()) + + val, err := snapshot2.Get(reg1.Key) + require.NoError(t, err) + require.Equal(t, val, reg1.Value) + + val, err = snapshot2.Get(reg2.Key) + require.NoError(t, err) + require.Equal(t, val, reg2.Value) + + // test CreateStorageSnapshot for unknown block + unknown := unittest.BlockHeaderFixture() + headers.On("ByBlockID", unknown.ID()).Return(nil, fmt.Errorf("unknown: %w", storageerr.ErrNotFound)) + _, _, err = es.CreateStorageSnapshot(unknown.ID()) + require.Error(t, err) + require.True(t, errors.Is(err, storageerr.ErrNotFound)) + + // test CreateStorageSnapshot for known and unexecuted block + unexecuted := unittest.BlockHeaderFixture() + headers.On("ByBlockID", unexecuted.ID()).Return(unexecuted, nil) + stateCommitments.On("ByBlockID", unexecuted.ID()).Return(nil, fmt.Errorf("not found: %w", storageerr.ErrNotFound)) + _, _, err = es.CreateStorageSnapshot(unexecuted.ID()) + require.Error(t, err) + require.True(t, errors.Is(err, state.ErrNotExecuted)) + + // test CreateStorageSnapshot for pruned block + pruned := unittest.BlockHeaderFixture() + prunedState := unittest.StateCommitmentFixture() + headers.On("ByBlockID", pruned.ID()).Return(pruned, nil) + stateCommitments.On("ByBlockID", pruned.ID()).Return(prunedState, nil) + _, _, err = es.CreateStorageSnapshot(pruned.ID()) + require.Error(t, err) + require.True(t, errors.Is(err, state.ErrExecutionStatePruned)) + })) + } diff --git a/engine/execution/storehouse.go b/engine/execution/storehouse.go index 21f2add53d9..ab3ebf66e90 100644 --- a/engine/execution/storehouse.go +++ b/engine/execution/storehouse.go @@ -1,8 +1,11 @@ package execution import ( + "github.com/onflow/flow-go/fvm/storage/snapshot" "github.com/onflow/flow-go/model/flow" + "github.com/onflow/flow-go/module/finalizedreader" "github.com/onflow/flow-go/storage" + "github.com/onflow/flow-go/storage/pebble" ) // RegisterStore is the interface for register store @@ -11,7 +14,7 @@ type RegisterStore interface { // GetRegister first try to get the register from InMemoryRegisterStore, then OnDiskRegisterStore // It returns: // - (value, nil) if the register value is found at the given block - // - (nil, storage.ErrNotFound) if the register is not found + // - (nil, nil) if the register is not found // - (nil, storage.ErrHeightNotIndexed) if the height is below the first height that is indexed. 
// - (nil, storehouse.ErrNotExecuted) if the block is not executed yet // - (nil, storehouse.ErrNotExecuted) if the block is conflicting iwth finalized block @@ -26,9 +29,9 @@ type RegisterStore interface { // - exception if the block is below the pruned height // - exception if the save block is saved again // - exception for any other exception - SaveRegisters(header *flow.Header, registers []flow.RegisterEntry) error + SaveRegisters(header *flow.Header, registers flow.RegisterEntries) error - // Depend on FinalizedReader's GetFinalizedBlockIDAtHeight + // Depend on FinalizedReader's FinalizedBlockIDAtHeight // Depend on ExecutedFinalizedWAL.Append // Depend on OnDiskRegisterStore.SaveRegisters // OnBlockFinalized trigger the check of whether a block at the next height becomes finalized and executed. @@ -55,9 +58,15 @@ type RegisterStore interface { } type FinalizedReader interface { + // FinalizedBlockIDAtHeight returns the block ID of the finalized block at the given height. + // It returns storage.ErrNotFound if the given height has not been finalized yet. + // Any other error returned is an exception. FinalizedBlockIDAtHeight(height uint64) (flow.Identifier, error) } +// finalizedreader.FinalizedReader is an implementation of FinalizedReader interface +var _ FinalizedReader = (*finalizedreader.FinalizedReader)(nil) + // see implementation in engine/execution/storehouse/in_memory_register_store.go type InMemoryRegisterStore interface { Prune(finalizedHeight uint64, finalizedBlockID flow.Identifier) error @@ -72,7 +81,7 @@ type InMemoryRegisterStore interface { height uint64, blockID flow.Identifier, parentID flow.Identifier, - registers []flow.RegisterEntry, + registers flow.RegisterEntries, ) error IsBlockExecuted(height uint64, blockID flow.Identifier) (bool, error) @@ -80,8 +89,11 @@ type InMemoryRegisterStore interface { type OnDiskRegisterStore = storage.RegisterIndex +// pebble.Registers is an implementation of OnDiskRegisterStore interface +var _ OnDiskRegisterStore = (*pebble.Registers)(nil) + type ExecutedFinalizedWAL interface { - Append(height uint64, registers []flow.RegisterEntry) error + Append(height uint64, registers flow.RegisterEntries) error // Latest returns the latest height in the WAL. Latest() (uint64, error) @@ -92,5 +104,11 @@ type ExecutedFinalizedWAL interface { type WALReader interface { // Next returns the next height and trie updates in the WAL. // It returns EOF when there are no more entries. - Next() (height uint64, registers []flow.RegisterEntry, err error) + Next() (height uint64, registers flow.RegisterEntries, err error) +} + +type ExtendableStorageSnapshot interface { + snapshot.StorageSnapshot + Extend(newCommit flow.StateCommitment, updatedRegisters map[flow.RegisterID]flow.RegisterValue) ExtendableStorageSnapshot + Commitment() flow.StateCommitment } diff --git a/engine/execution/storehouse/block_end_snapshot.go b/engine/execution/storehouse/block_end_snapshot.go new file mode 100644 index 00000000000..bf7718a9543 --- /dev/null +++ b/engine/execution/storehouse/block_end_snapshot.go @@ -0,0 +1,88 @@ +package storehouse + +import ( + "errors" + "sync" + + "github.com/onflow/flow-go/engine/execution" + "github.com/onflow/flow-go/fvm/storage/snapshot" + "github.com/onflow/flow-go/model/flow" + "github.com/onflow/flow-go/storage" +) + +var _ snapshot.StorageSnapshot = (*BlockEndStateSnapshot)(nil) + +// BlockEndStateSnapshot represents the storage at the end of a block.
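+//
+// Reads are delegated to the underlying RegisterStore and memoized in an internal
+// read cache, so repeated reads of the same register are served from the cache.
+//
+// Illustrative sketch of a caller (the identifiers `registerStore`, `blockID` and
+// `height` below are placeholders, not part of this change):
+//
+//    snap := NewBlockEndStateSnapshot(registerStore, blockID, height)
+//    value, err := snap.Get(flow.RegisterID{Owner: "owner", Key: "key"})
+//    // value == nil with a nil error means the register does not exist at this block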
+type BlockEndStateSnapshot struct { + storage execution.RegisterStore + + blockID flow.Identifier + height uint64 + + mutex sync.RWMutex + readCache map[flow.RegisterID]flow.RegisterValue // cache the reads from storage at baseBlock +} + +// the caller must ensure the block height is for the given block +func NewBlockEndStateSnapshot( + storage execution.RegisterStore, + blockID flow.Identifier, + height uint64, +) *BlockEndStateSnapshot { + return &BlockEndStateSnapshot{ + storage: storage, + blockID: blockID, + height: height, + readCache: make(map[flow.RegisterID]flow.RegisterValue), + } +} + +// Get returns the value of the register with the given register ID. +// It returns: +// - (value, nil) if the register exists +// - (nil, nil) if the register does not exist +// - (nil, storage.ErrHeightNotIndexed) if the height is below the first height that is indexed. +// - (nil, storehouse.ErrNotExecuted) if the block is not executed yet +// - (nil, storehouse.ErrNotExecuted) if the block is conflicting with finalized block +// - (nil, err) for any other exceptions +func (s *BlockEndStateSnapshot) Get(id flow.RegisterID) (flow.RegisterValue, error) { + value, ok := s.getFromCache(id) + if ok { + return value, nil + } + + value, err := s.getFromStorage(id) + if err != nil { + return nil, err + } + + s.mutex.Lock() + defer s.mutex.Unlock() + + // TODO: consider adding a limit/eviction policy for the cache + s.readCache[id] = value + return value, err +} + +func (s *BlockEndStateSnapshot) getFromCache(id flow.RegisterID) (flow.RegisterValue, bool) { + s.mutex.RLock() + defer s.mutex.RUnlock() + + value, ok := s.readCache[id] + return value, ok +} + +func (s *BlockEndStateSnapshot) getFromStorage(id flow.RegisterID) (flow.RegisterValue, error) { + value, err := s.storage.GetRegister(s.height, s.blockID, id) + if err != nil { + if errors.Is(err, storage.ErrNotFound) { + // if the error is not found, we return a nil RegisterValue, + // in this case, the nil value can be cached, because the storage will not change it + return nil, nil + } + // if the error is not ErrNotFound, such as storage.ErrHeightNotIndexed, storehouse.ErrNotExecuted + // we return the error without caching + return nil, err + } + return value, nil +} diff --git a/engine/execution/storehouse/block_end_snapshot_test.go b/engine/execution/storehouse/block_end_snapshot_test.go new file mode 100644 index 00000000000..3787ec2d552 --- /dev/null +++ b/engine/execution/storehouse/block_end_snapshot_test.go @@ -0,0 +1,102 @@ +package storehouse_test + +import ( + "errors" + "fmt" + "testing" + + "github.com/stretchr/testify/require" + "go.uber.org/atomic" + + executionMock "github.com/onflow/flow-go/engine/execution/mock" + "github.com/onflow/flow-go/engine/execution/storehouse" + "github.com/onflow/flow-go/model/flow" + "github.com/onflow/flow-go/storage" + "github.com/onflow/flow-go/utils/unittest" +) + +func TestBlockEndSnapshot(t *testing.T) { + t.Run("Get register", func(t *testing.T) { + header := unittest.BlockHeaderFixture() + + // create mock for storage + store := executionMock.NewRegisterStore(t) + reg := unittest.MakeOwnerReg("key", "value") + store.On("GetRegister", header.Height, header.ID(), reg.Key).Return(reg.Value, nil).Once() + snapshot := storehouse.NewBlockEndStateSnapshot(store, header.ID(), header.Height) + + // test get from storage + value, err := snapshot.Get(reg.Key) + require.NoError(t, err) + require.Equal(t, reg.Value, value) + + // test get from cache + value, err = snapshot.Get(reg.Key) + 
require.NoError(t, err) + require.Equal(t, reg.Value, value) + + // test get non existing register + unknownReg := unittest.MakeOwnerReg("unknown", "unknown") + store.On("GetRegister", header.Height, header.ID(), unknownReg.Key). + Return(nil, fmt.Errorf("fail: %w", storage.ErrNotFound)).Once() + + value, err = snapshot.Get(unknownReg.Key) + require.NoError(t, err) + require.Nil(t, value) + + // test get non existing register from cache + _, err = snapshot.Get(unknownReg.Key) + require.NoError(t, err) + require.Nil(t, value) + + // test getting storage.ErrHeightNotIndexed error + heightNotIndexed := unittest.MakeOwnerReg("height not index", "height not index") + store.On("GetRegister", header.Height, header.ID(), heightNotIndexed.Key). + Return(nil, fmt.Errorf("fail: %w", storage.ErrHeightNotIndexed)). + Twice() // to verify the result is not cached + + // verify getting the correct error + _, err = snapshot.Get(heightNotIndexed.Key) + require.Error(t, err) + require.True(t, errors.Is(err, storage.ErrHeightNotIndexed)) + + // verify result is not cached + _, err = snapshot.Get(heightNotIndexed.Key) + require.Error(t, err) + require.True(t, errors.Is(err, storage.ErrHeightNotIndexed)) + + // test getting storage.ErrNotExecuted error + heightNotExecuted := unittest.MakeOwnerReg("height not executed", "height not executed") + counter := atomic.NewInt32(0) + store. + On("GetRegister", header.Height, header.ID(), heightNotExecuted.Key). + Return(func(uint64, flow.Identifier, flow.RegisterID) (flow.RegisterValue, error) { + counter.Inc() + // the first call should return error + if counter.Load() == 1 { + return nil, fmt.Errorf("fail: %w", storehouse.ErrNotExecuted) + } + // the second call, it returns value + return heightNotExecuted.Value, nil + }). + Times(2) + + // first time should return error + _, err = snapshot.Get(heightNotExecuted.Key) + require.Error(t, err) + require.True(t, errors.Is(err, storehouse.ErrNotExecuted)) + + // second time should return value + value, err = snapshot.Get(heightNotExecuted.Key) + require.NoError(t, err) + require.Equal(t, heightNotExecuted.Value, value) + + // third time should be cached + value, err = snapshot.Get(heightNotExecuted.Key) + require.NoError(t, err) + require.Equal(t, heightNotExecuted.Value, value) + + store.AssertExpectations(t) + }) + +} diff --git a/engine/execution/storehouse/executing_block_snapshot.go b/engine/execution/storehouse/executing_block_snapshot.go new file mode 100644 index 00000000000..e9e9b97c32b --- /dev/null +++ b/engine/execution/storehouse/executing_block_snapshot.go @@ -0,0 +1,76 @@ +package storehouse + +import ( + "github.com/onflow/flow-go/engine/execution" + "github.com/onflow/flow-go/fvm/storage/snapshot" + "github.com/onflow/flow-go/model/flow" +) + +var _ execution.ExtendableStorageSnapshot = (*ExecutingBlockSnapshot)(nil) + +// ExecutingBlockSnapshot is a snapshot of the storage at an executed collection. +// It starts with a storage snapshot at the end of previous block, +// The register updates at the executed collection at baseHeight + 1 are cached in +// a map, such that retrieving register values at the snapshot will first check +// the cache, and then the storage. 
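+//
+// Illustrative sketch of how the snapshot is advanced while executing a block (the
+// identifiers `prev`, `commit0`, `commit1`, `updates1` and `regID` are placeholders):
+//
+//    var snap execution.ExtendableStorageSnapshot = NewExecutingBlockSnapshot(prev, commit0)
+//    // after executing a collection, extend with its end state commitment and register updates
+//    snap = snap.Extend(commit1, updates1)
+//    value, err := snap.Get(regID) // checks updates1 first, then falls back to prev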
+type ExecutingBlockSnapshot struct { + // the snapshot at the end of previous block + previous snapshot.StorageSnapshot + + commitment flow.StateCommitment + registerUpdates map[flow.RegisterID]flow.RegisterValue +} + +// create a new storage snapshot for an executed collection +// at the base block at height h - 1 +func NewExecutingBlockSnapshot( + previous snapshot.StorageSnapshot, + // the statecommitment of a block at height h + commitment flow.StateCommitment, +) *ExecutingBlockSnapshot { + return &ExecutingBlockSnapshot{ + previous: previous, + commitment: commitment, + registerUpdates: make(map[flow.RegisterID]flow.RegisterValue), + } +} + +// Get returns the register value at the snapshot. +func (s *ExecutingBlockSnapshot) Get(id flow.RegisterID) (flow.RegisterValue, error) { + // get from latest updates first + value, ok := s.getFromUpdates(id) + if ok { + return value, nil + } + + // get from BlockEndStateSnapshot at previous block + value, err := s.previous.Get(id) + return value, err +} + +func (s *ExecutingBlockSnapshot) getFromUpdates(id flow.RegisterID) (flow.RegisterValue, bool) { + value, ok := s.registerUpdates[id] + return value, ok +} + +// Extend returns a new storage snapshot at the same block but but for a different state commitment, +// which contains the given registerUpdates +// Usually it's used to create a new storage snapshot at the next executed collection. +// The registerUpdates contains the register updates at the executed collection. +func (s *ExecutingBlockSnapshot) Extend(newCommit flow.StateCommitment, updates map[flow.RegisterID]flow.RegisterValue) execution.ExtendableStorageSnapshot { + // if there is no update, we can return the original snapshot directly + // instead of wrapping it with a new ExecutingBlockSnapshot that has no update + if len(updates) == 0 { + return s + } + + return &ExecutingBlockSnapshot{ + previous: s, + commitment: newCommit, + registerUpdates: updates, + } +} + +func (s *ExecutingBlockSnapshot) Commitment() flow.StateCommitment { + return s.commitment +} diff --git a/engine/execution/storehouse/executing_block_snapshot_test.go b/engine/execution/storehouse/executing_block_snapshot_test.go new file mode 100644 index 00000000000..616430ec858 --- /dev/null +++ b/engine/execution/storehouse/executing_block_snapshot_test.go @@ -0,0 +1,92 @@ +package storehouse_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/onflow/flow-go/engine/execution/storehouse" + "github.com/onflow/flow-go/fvm/storage/snapshot" + "github.com/onflow/flow-go/model/flow" + "github.com/onflow/flow-go/utils/unittest" +) + +func TestExtendingBlockSnapshot(t *testing.T) { + t.Run("Get register", func(t *testing.T) { + reg1 := makeReg("key1", "val1") + base := snapshot.MapStorageSnapshot{ + reg1.Key: reg1.Value, + } + baseCommit := unittest.StateCommitmentFixture() + snap := storehouse.NewExecutingBlockSnapshot(base, baseCommit) + + // should get value + value, err := snap.Get(reg1.Key) + require.NoError(t, err) + require.Equal(t, reg1.Value, value) + + // should get nil for unknown register + unknown := makeReg("unknown", "unknownV") + value, err = snap.Get(unknown.Key) + require.NoError(t, err) + require.Equal(t, []byte(nil), value) + }) + + t.Run("Extend snapshot", func(t *testing.T) { + reg1 := makeReg("key1", "val1") + reg2 := makeReg("key2", "val2") + base := snapshot.MapStorageSnapshot{ + reg1.Key: reg1.Value, + reg2.Key: reg2.Value, + } + // snap1: { key1: val1, key2: val2 } + snap1 := 
storehouse.NewExecutingBlockSnapshot(base, unittest.StateCommitmentFixture()) + + updatedReg2 := makeReg("key2", "val22") + reg3 := makeReg("key3", "val3") + // snap2: { key1: val1, key2: val22, key3: val3 } + snap2 := snap1.Extend(unittest.StateCommitmentFixture(), map[flow.RegisterID]flow.RegisterValue{ + updatedReg2.Key: updatedReg2.Value, + reg3.Key: reg3.Value, + }) + + // should get un-changed value + value, err := snap2.Get(reg1.Key) + require.NoError(t, err) + require.Equal(t, []byte("val1"), value) + + value, err = snap2.Get(reg2.Key) + require.NoError(t, err) + require.Equal(t, []byte("val22"), value) + + value, err = snap2.Get(reg3.Key) + require.NoError(t, err) + require.Equal(t, []byte("val3"), value) + + // should get nil for unknown register + unknown := makeReg("unknown", "unknownV") + value, err = snap2.Get(unknown.Key) + require.NoError(t, err) + require.Equal(t, []byte(nil), value) + + // create snap3 with reg3 updated + // snap3: { key1: val1, key2: val22, key3: val33 } + updatedReg3 := makeReg("key3", "val33") + snap3 := snap2.Extend(unittest.StateCommitmentFixture(), map[flow.RegisterID]flow.RegisterValue{ + updatedReg3.Key: updatedReg3.Value, + }) + + // verify all keys + value, err = snap3.Get(reg1.Key) + require.NoError(t, err) + require.Equal(t, []byte("val1"), value) + + value, err = snap3.Get(reg2.Key) + require.NoError(t, err) + require.Equal(t, []byte("val22"), value) + + value, err = snap3.Get(reg3.Key) + require.NoError(t, err) + require.Equal(t, []byte("val33"), value) + }) +} diff --git a/engine/execution/storehouse/in_memory_register_store.go b/engine/execution/storehouse/in_memory_register_store.go new file mode 100644 index 00000000000..085e49271c6 --- /dev/null +++ b/engine/execution/storehouse/in_memory_register_store.go @@ -0,0 +1,333 @@ +package storehouse + +import ( + "errors" + "fmt" + "sync" + + "github.com/onflow/flow-go/engine/execution" + "github.com/onflow/flow-go/model/flow" +) + +var _ execution.InMemoryRegisterStore = (*InMemoryRegisterStore)(nil) + +var ErrNotExecuted = fmt.Errorf("block is not executed") + +type PrunedError struct { + PrunedHeight uint64 + Height uint64 +} + +func NewPrunedError(height uint64, prunedHeight uint64) error { + return PrunedError{Height: height, PrunedHeight: prunedHeight} +} + +func (e PrunedError) Error() string { + return fmt.Sprintf("block is pruned at height %d", e.Height) +} + +func IsPrunedError(err error) (PrunedError, bool) { + var e PrunedError + ok := errors.As(err, &e) + if ok { + return e, true + } + return PrunedError{}, false +} + +type InMemoryRegisterStore struct { + sync.RWMutex + registersByBlockID map[flow.Identifier]map[flow.RegisterID]flow.RegisterValue // for storing the registers + parentByBlockID map[flow.Identifier]flow.Identifier // for register updates to be fork-aware + blockIDsByHeight map[uint64]map[flow.Identifier]struct{} // for pruning + prunedHeight uint64 // registers at pruned height are pruned (not saved in registersByBlockID) + prunedID flow.Identifier // to ensure all blocks are extending from pruned block (last finalized and executed block) +} + +func NewInMemoryRegisterStore(lastHeight uint64, lastID flow.Identifier) *InMemoryRegisterStore { + return &InMemoryRegisterStore{ + registersByBlockID: make(map[flow.Identifier]map[flow.RegisterID]flow.RegisterValue), + parentByBlockID: make(map[flow.Identifier]flow.Identifier), + blockIDsByHeight: make(map[uint64]map[flow.Identifier]struct{}), + prunedHeight: lastHeight, + prunedID: lastID, + } +} + +// SaveRegisters saves 
the registers of a block to InMemoryRegisterStore +// It needs to ensure the block is above the pruned height and is connected to the pruned block +func (s *InMemoryRegisterStore) SaveRegisters( + height uint64, + blockID flow.Identifier, + parentID flow.Identifier, + registers flow.RegisterEntries, +) error { + // preprocess data before acquiring the lock + regs := make(map[flow.RegisterID]flow.RegisterValue, len(registers)) + for _, reg := range registers { + regs[reg.Key] = reg.Value + } + + s.Lock() + defer s.Unlock() + + // ensure all saved registers are above the pruned height + if height <= s.prunedHeight { + return fmt.Errorf("saving pruned registers height %v <= pruned height %v", height, s.prunedHeight) + } + + // ensure the block is not already saved + _, ok := s.registersByBlockID[blockID] + if ok { + // already exist + return fmt.Errorf("saving registers for block %s, but it already exists", blockID) + } + + // make sure parent is a known block or the pruned block, which forms a fork + _, ok = s.registersByBlockID[parentID] + if !ok && parentID != s.prunedID { + return fmt.Errorf("saving registers for block %s, but its parent %s is not saved", blockID, parentID) + } + + // update registers for the block + s.registersByBlockID[blockID] = regs + + // update index on parent + s.parentByBlockID[blockID] = parentID + + // update index on height + sameHeight, ok := s.blockIDsByHeight[height] + if !ok { + sameHeight = make(map[flow.Identifier]struct{}) + s.blockIDsByHeight[height] = sameHeight + } + + sameHeight[blockID] = struct{}{} + return nil +} + +// GetRegister will return the latest updated value of the given register +// since the pruned height. +// It returns PrunedError if the register is unknown or not updated since the pruned height +// Can't return ErrNotFound, since we can't distinguish between not found or not updated since the pruned height +func (s *InMemoryRegisterStore) GetRegister(height uint64, blockID flow.Identifier, register flow.RegisterID) (flow.RegisterValue, error) { + s.RLock() + defer s.RUnlock() + + if height <= s.prunedHeight { + return flow.RegisterValue{}, NewPrunedError(height, s.prunedHeight) + } + + _, ok := s.registersByBlockID[blockID] + if !ok { + return flow.RegisterValue{}, fmt.Errorf("cannot get register at height %d, block %v is not saved: %w", height, blockID, ErrNotExecuted) + } + + // traverse the fork to find the latest updated value of the given register + // if not found, it means the register is not updated from the pruned block to the given block + block := blockID + for { + // TODO: do not hold the read lock when reading register from the updated register map + reg, ok := s.readRegisterAtBlockID(block, register) + if ok { + return reg, nil + } + + // the register didn't get updated at this block, so check its parent + + parent, ok := s.parentByBlockID[block] + if !ok { + // if the parent doesn't exist because the block itself is the pruned block, + // then it means the register is not updated since the pruned height. + // since we can't distinguish whether the register is not updated or not exist at all, + // we just return PrunedError error along with the prunedHeight, so the + // caller could check with OnDiskRegisterStore to find if this register has a updated value + // at earlier height. 
+ if block == s.prunedID { + return flow.RegisterValue{}, NewPrunedError(height, s.prunedHeight) + } + + // in this case, it means the state of in-memory register store is inconsistent, + // because all saved block must have their parent saved in `parentByBlockID`, and traversing + // its parent should eventually reach the pruned block, otherwise it's a bug. + + return flow.RegisterValue{}, + fmt.Errorf("inconsistent parent block index in in-memory-register-store, ancient block %v is not found when getting register at block %v", + block, blockID) + } + + block = parent + } +} + +func (s *InMemoryRegisterStore) readRegisterAtBlockID(blockID flow.Identifier, register flow.RegisterID) (flow.RegisterValue, bool) { + registers, ok := s.registersByBlockID[blockID] + if !ok { + return flow.RegisterValue{}, false + } + + value, ok := registers[register] + return value, ok +} + +// GetUpdatedRegisters returns the updated registers of a block +func (s *InMemoryRegisterStore) GetUpdatedRegisters(height uint64, blockID flow.Identifier) (flow.RegisterEntries, error) { + registerUpdates, err := s.getUpdatedRegisters(height, blockID) + if err != nil { + return nil, err + } + + // since the registerUpdates won't be updated and registers for a block can only be set once, + // we don't need to hold the lock when converting it from map into slice. + registers := make(flow.RegisterEntries, 0, len(registerUpdates)) + for regID, reg := range registerUpdates { + registers = append(registers, flow.RegisterEntry{ + Key: regID, + Value: reg, + }) + } + + return registers, nil +} + +func (s *InMemoryRegisterStore) getUpdatedRegisters(height uint64, blockID flow.Identifier) (map[flow.RegisterID]flow.RegisterValue, error) { + s.RLock() + defer s.RUnlock() + if height <= s.prunedHeight { + return nil, fmt.Errorf("cannot get register at height %d, it is pruned %v", height, s.prunedHeight) + } + + registerUpdates, ok := s.registersByBlockID[blockID] + if !ok { + return nil, fmt.Errorf("cannot get register at height %d, block %s is not found: %w", height, blockID, ErrNotExecuted) + } + return registerUpdates, nil +} + +// Prune prunes the register store to the given height +// The pruned height must be an executed block, the caller should ensure that by calling SaveRegisters before. +// +// Pruning is done by walking up the finalized fork from `s.prunedHeight` to `height`. At each height, prune all +// other forks that begin at that height. This ensures that data for all conflicting forks are freed +// +// TODO: It does not block the caller, the pruning work is done async +func (s *InMemoryRegisterStore) Prune(height uint64, blockID flow.Identifier) error { + finalizedFork, err := s.findFinalizedFork(height, blockID) + if err != nil { + return fmt.Errorf("cannot find finalized fork: %w", err) + } + + s.Lock() + defer s.Unlock() + + // prune each height starting at the lowest height in the fork. this will remove all blocks + // below the new pruned height along with any conflicting forks. 
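+ // For example, if prunedHeight is 10 and Prune(12, C) is called on the finalized
+ // chain 10 <- A(11) <- C(12), then findFinalizedFork returns [C, A]; the loop below
+ // first prunes height 11 (removing A and any conflicting blocks at that height),
+ // and then height 12, leaving C as the new pruned block.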
+ for i := len(finalizedFork) - 1; i >= 0; i-- { + blockID := finalizedFork[i] + + err := s.pruneByHeight(s.prunedHeight+1, blockID) + if err != nil { + return fmt.Errorf("could not prune by height %v: %w", s.prunedHeight+1, err) + } + } + + return nil +} + +func (s *InMemoryRegisterStore) PrunedHeight() uint64 { + s.RLock() + defer s.RUnlock() + return s.prunedHeight +} + +func (s *InMemoryRegisterStore) IsBlockExecuted(height uint64, blockID flow.Identifier) (bool, error) { + s.RLock() + defer s.RUnlock() + + // finalized and executed blocks are pruned + if height <= s.prunedHeight { + return false, fmt.Errorf("below pruned height") + } + + _, ok := s.registersByBlockID[blockID] + return ok, nil +} + +// findFinalizedFork returns the finalized fork from higher height to lower height +// the last block's height is s.prunedHeight + 1 +func (s *InMemoryRegisterStore) findFinalizedFork(height uint64, blockID flow.Identifier) ([]flow.Identifier, error) { + s.RLock() + defer s.RUnlock() + + if height <= s.prunedHeight { + return nil, fmt.Errorf("cannot find finalized fork at height %d, it is pruned (prunedHeight: %v)", height, s.prunedHeight) + } + prunedHeight := height + block := blockID + + // walk backwards from the provided finalized block to the last pruned block + // the result must be a chain from height/blockID to s.prunedHeight/s.prunedID + fork := make([]flow.Identifier, 0, height-s.prunedHeight) + for { + fork = append(fork, block) + prunedHeight-- + + parent, ok := s.parentByBlockID[block] + if !ok { + return nil, fmt.Errorf("inconsistent parent block index in in-memory-register-store, ancient block %s is not found when finding finalized fork at height %v", block, height) + } + if parent == s.prunedID { + break + } + block = parent + } + + if prunedHeight != s.prunedHeight { + return nil, fmt.Errorf("inconsistent parent block index in in-memory-register-store, pruned height %d is not equal to %d", prunedHeight, s.prunedHeight) + } + + return fork, nil +} + +func (s *InMemoryRegisterStore) pruneByHeight(height uint64, finalized flow.Identifier) error { + s.removeBlock(height, finalized) + + // remove conflicting forks + for blockID := range s.blockIDsByHeight[height] { + s.pruneFork(height, blockID) + } + + if len(s.blockIDsByHeight[height]) > 0 { + return fmt.Errorf("all forks on the same height should have been pruend, but actually not: %v", len(s.blockIDsByHeight[height])) + } + + delete(s.blockIDsByHeight, height) + s.prunedHeight = height + s.prunedID = finalized + return nil +} + +func (s *InMemoryRegisterStore) removeBlock(height uint64, blockID flow.Identifier) { + delete(s.registersByBlockID, blockID) + delete(s.parentByBlockID, blockID) + delete(s.blockIDsByHeight[height], blockID) +} + +// pruneFork prunes the provided block and all of its children +func (s *InMemoryRegisterStore) pruneFork(height uint64, blockID flow.Identifier) { + s.removeBlock(height, blockID) + // all its children must be at height + 1, whose parent is blockID + + nextHeight := height + 1 + blocksAtNextHeight, ok := s.blockIDsByHeight[nextHeight] + if !ok { + return + } + + for block := range blocksAtNextHeight { + isChild := s.parentByBlockID[block] == blockID + if isChild { + s.pruneFork(nextHeight, block) + } + } +} diff --git a/engine/execution/storehouse/in_memory_register_store_test.go b/engine/execution/storehouse/in_memory_register_store_test.go new file mode 100644 index 00000000000..98ee7da64d1 --- /dev/null +++ b/engine/execution/storehouse/in_memory_register_store_test.go @@ -0,0 +1,627 
@@ +package storehouse + +import ( + "fmt" + "math/rand" + "sync" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/onflow/flow-go/model/flow" + "github.com/onflow/flow-go/utils/unittest" +) + +// 1. SaveRegisters should fail if height is below or equal to pruned height +func TestInMemoryRegisterStore(t *testing.T) { + t.Run("FailBelowOrEqualPrunedHeight", func(t *testing.T) { + t.Parallel() + // 1. + pruned := uint64(10) + lastID := unittest.IdentifierFixture() + store := NewInMemoryRegisterStore(pruned, lastID) + err := store.SaveRegisters( + pruned-1, // below pruned pruned, will fail + unittest.IdentifierFixture(), + unittest.IdentifierFixture(), + flow.RegisterEntries{}, + ) + require.Error(t, err) + require.Contains(t, err.Error(), "<= pruned height") + + err = store.SaveRegisters( + pruned, // equal to pruned height, will fail + lastID, + unittest.IdentifierFixture(), + flow.RegisterEntries{}, + ) + require.Error(t, err) + require.Contains(t, err.Error(), "<= pruned height") + }) + + // 2. SaveRegisters should fail if its parent block doesn't exist and it is not the pruned block + // SaveRegisters should succeed if height is above pruned height and block is not saved, + // the updates can be retrieved by GetUpdatedRegisters + // GetRegister should return PrunedError if the queried key is not updated since pruned height + // GetRegister should return PrunedError if the queried height is below pruned height + // GetRegister should return ErrNotExecuted if the block is unknown + t.Run("FailParentNotExist", func(t *testing.T) { + t.Parallel() + pruned := uint64(10) + lastID := unittest.IdentifierFixture() + store := NewInMemoryRegisterStore(pruned, lastID) + + height := pruned + 1 // above the pruned pruned + blockID := unittest.IdentifierFixture() + notExistParent := unittest.IdentifierFixture() + reg := unittest.RegisterEntryFixture() + err := store.SaveRegisters( + height, + blockID, + notExistParent, // should fail because parent doesn't exist + flow.RegisterEntries{reg}, + ) + require.Error(t, err) + require.Contains(t, err.Error(), "but its parent") + }) + + t.Run("StoreOK", func(t *testing.T) { + t.Parallel() + // 3. + pruned := uint64(10) + lastID := unittest.IdentifierFixture() + store := NewInMemoryRegisterStore(pruned, lastID) + + height := pruned + 1 // above the pruned pruned + blockID := unittest.IdentifierFixture() + reg := unittest.RegisterEntryFixture() + err := store.SaveRegisters( + height, + blockID, + lastID, + flow.RegisterEntries{reg}, + ) + require.NoError(t, err) + + val, err := store.GetRegister(height, blockID, reg.Key) + require.NoError(t, err) + require.Equal(t, reg.Value, val) + + // unknown key + _, err = store.GetRegister(height, blockID, unknownKey) + require.Error(t, err) + pe, ok := IsPrunedError(err) + require.True(t, ok) + require.Equal(t, pe.PrunedHeight, pruned) + require.Equal(t, pe.Height, height) + + // unknown block with unknown height + _, err = store.GetRegister(height+1, unknownBlock, reg.Key) + require.Error(t, err) + require.ErrorIs(t, err, ErrNotExecuted) + + // unknown block with known height + _, err = store.GetRegister(height, unknownBlock, reg.Key) + require.Error(t, err) + require.ErrorIs(t, err, ErrNotExecuted) + + // too low height + _, err = store.GetRegister(height-1, unknownBlock, reg.Key) + require.Error(t, err) + pe, ok = IsPrunedError(err) + require.True(t, ok) + require.Equal(t, pe.PrunedHeight, pruned) + require.Equal(t, pe.Height, height-1) + }) + + // 3. 
SaveRegisters should fail if the block is already saved + t.Run("StoreFailAlreadyExist", func(t *testing.T) { + t.Parallel() + pruned := uint64(10) + lastID := unittest.IdentifierFixture() + store := NewInMemoryRegisterStore(pruned, lastID) + + height := pruned + 1 // above the pruned pruned + blockID := unittest.IdentifierFixture() + reg := unittest.RegisterEntryFixture() + err := store.SaveRegisters( + height, + blockID, + lastID, + flow.RegisterEntries{reg}, + ) + require.NoError(t, err) + + // saving again should fail + err = store.SaveRegisters( + height, + blockID, + lastID, + flow.RegisterEntries{reg}, + ) + require.Error(t, err) + require.Contains(t, err.Error(), "already exists") + }) + + // 4. SaveRegisters should succeed if a different block at the same height was saved before, + // updates for different blocks can be retrieved by their blockID + t.Run("StoreOKDifferentBlockSameParent", func(t *testing.T) { + t.Parallel() + pruned := uint64(10) + lastID := unittest.IdentifierFixture() + store := NewInMemoryRegisterStore(pruned, lastID) + + // 10 <- A + // ^- B + height := pruned + 1 // above the pruned pruned + blockA := unittest.IdentifierFixture() + regA := unittest.RegisterEntryFixture() + err := store.SaveRegisters( + height, + blockA, + lastID, + flow.RegisterEntries{regA}, + ) + require.NoError(t, err) + + blockB := unittest.IdentifierFixture() + regB := unittest.RegisterEntryFixture() + err = store.SaveRegisters( + height, + blockB, // different block + lastID, // same parent + flow.RegisterEntries{regB}, + ) + require.NoError(t, err) + + valA, err := store.GetRegister(height, blockA, regA.Key) + require.NoError(t, err) + require.Equal(t, regA.Value, valA) + + valB, err := store.GetRegister(height, blockB, regB.Key) + require.NoError(t, err) + require.Equal(t, regB.Value, valB) + }) + + // 5. Given A(X: 1, Y: 2), GetRegister(A, X) should return 1, GetRegister(A, X) should return 2 + t.Run("GetRegistersOK", func(t *testing.T) { + t.Parallel() + pruned := uint64(10) + lastID := unittest.IdentifierFixture() + store := NewInMemoryRegisterStore(pruned, lastID) + + // 10 <- A (X: 1, Y: 2) + height := pruned + 1 // above the pruned pruned + blockA := unittest.IdentifierFixture() + regX := makeReg("X", "1") + regY := makeReg("Y", "2") + err := store.SaveRegisters( + height, + blockA, + lastID, + flow.RegisterEntries{regX, regY}, + ) + require.NoError(t, err) + + valX, err := store.GetRegister(height, blockA, regX.Key) + require.NoError(t, err) + require.Equal(t, regX.Value, valX) + + valY, err := store.GetRegister(height, blockA, regY.Key) + require.NoError(t, err) + require.Equal(t, regY.Value, valY) + }) + + // 6. 
Given A(X: 1, Y: 2) <- B(Y: 3), + // GetRegister(B, X) should return 1, because X is not updated in B + // GetRegister(B, Y) should return 3, because Y is updated in B + // GetRegister(A, Y) should return 2, because the query queries the value at A, not B + // GetRegister(B, Z) should return PrunedError, because register is unknown + // GetRegister(C, X) should return BlockNotExecuted, because block is not executed (unexecuted) + t.Run("GetLatestValueOK", func(t *testing.T) { + t.Parallel() + pruned := uint64(10) + lastID := unittest.IdentifierFixture() + store := NewInMemoryRegisterStore(pruned, lastID) + + // 10 <- A (X: 1, Y: 2) <- B (Y: 3) + blockA := unittest.IdentifierFixture() + regX := makeReg("X", "1") + regY := makeReg("Y", "2") + err := store.SaveRegisters( + pruned+1, + blockA, + lastID, + flow.RegisterEntries{regX, regY}, + ) + require.NoError(t, err) + + blockB := unittest.IdentifierFixture() + regY3 := makeReg("Y", "3") + err = store.SaveRegisters( + pruned+2, + blockB, + blockA, + flow.RegisterEntries{regY3}, + ) + require.NoError(t, err) + + val, err := store.GetRegister(pruned+2, blockB, regX.Key) + require.NoError(t, err) + require.Equal(t, regX.Value, val) // X is not updated in B + + val, err = store.GetRegister(pruned+2, blockB, regY.Key) + require.NoError(t, err) + require.Equal(t, regY3.Value, val) // Y is updated in B + + val, err = store.GetRegister(pruned+1, blockA, regY.Key) + require.NoError(t, err) + require.Equal(t, regY.Value, val) // Y's old value at A + + _, err = store.GetRegister(pruned+2, blockB, unknownKey) + require.Error(t, err) + pe, ok := IsPrunedError(err) + require.True(t, ok) + require.Equal(t, pe.PrunedHeight, pruned) + require.Equal(t, pe.Height, pruned+2) + + _, err = store.GetRegister(pruned+3, unittest.IdentifierFixture(), regX.Key) + require.Error(t, err) + require.ErrorIs(t, err, ErrNotExecuted) // unknown block + }) + + // 7. Given the following tree: + // Pruned <- A(X:1) <- B(Y:2) + // .......^- C(X:3) <- D(Y:4) + // GetRegister(D, X) should return 3 + t.Run("StoreMultiForkOK", func(t *testing.T) { + t.Parallel() + pruned := uint64(10) + lastID := unittest.IdentifierFixture() + store := NewInMemoryRegisterStore(pruned, lastID) + + // 10 <- A (X: 1) <- B (Y: 2) + // ^- C (X: 3) <- D (Y: 4) + blockA := unittest.IdentifierFixture() + blockB := unittest.IdentifierFixture() + blockC := unittest.IdentifierFixture() + blockD := unittest.IdentifierFixture() + + require.NoError(t, store.SaveRegisters( + pruned+1, + blockA, + lastID, + flow.RegisterEntries{makeReg("X", "1")}, + )) + + require.NoError(t, store.SaveRegisters( + pruned+2, + blockB, + blockA, + flow.RegisterEntries{makeReg("Y", "2")}, + )) + + require.NoError(t, store.SaveRegisters( + pruned+1, + blockC, + lastID, + flow.RegisterEntries{makeReg("X", "3")}, + )) + + require.NoError(t, store.SaveRegisters( + pruned+2, + blockD, + blockC, + flow.RegisterEntries{makeReg("Y", "4")}, + )) + + reg := makeReg("X", "3") + val, err := store.GetRegister(pruned+2, blockD, reg.Key) + require.NoError(t, err) + require.Equal(t, reg.Value, val) + }) + + // 8. 
Given the following tree: + // Pruned <- A(X:1) <- B(Y:2), B is not executed + // GetUpdatedRegisters(B) should return ErrNotExecuted + t.Run("GetUpdatedRegisters", func(t *testing.T) { + t.Parallel() + pruned := uint64(10) + lastID := unittest.IdentifierFixture() + store := NewInMemoryRegisterStore(pruned, lastID) + + // 10 <- A (X: 1) <- B (Y: 2) + blockA := unittest.IdentifierFixture() + blockB := unittest.IdentifierFixture() + + require.NoError(t, store.SaveRegisters( + pruned+1, + blockA, + lastID, + flow.RegisterEntries{makeReg("X", "1")}, + )) + + reg, err := store.GetUpdatedRegisters(pruned+1, blockA) + require.NoError(t, err) + require.Equal(t, flow.RegisterEntries{makeReg("X", "1")}, reg) + + _, err = store.GetUpdatedRegisters(pruned+2, blockB) + require.Error(t, err) + require.ErrorIs(t, err, ErrNotExecuted) + }) + + // 9. Prune should fail if the block is unknown + // Prune should succeed if the block is known, and GetUpdatedRegisters should return err + // Prune should prune up to the pruned height. + // Given Pruned <- A(X:1) <- B(X:2) <- C(X:3) <- D(X:4) + // after Prune(B), GetRegister(C, X) should return 3, GetRegister(B, X) should return err + t.Run("StorePrune", func(t *testing.T) { + t.Parallel() + pruned := uint64(10) + lastID := unittest.IdentifierFixture() + store := NewInMemoryRegisterStore(pruned, lastID) + + blockA := unittest.IdentifierFixture() + blockB := unittest.IdentifierFixture() + blockC := unittest.IdentifierFixture() + blockD := unittest.IdentifierFixture() + + require.NoError(t, store.SaveRegisters( + pruned+1, + blockA, + lastID, + flow.RegisterEntries{makeReg("X", "1")}, + )) + + require.NoError(t, store.SaveRegisters( + pruned+2, + blockB, + blockA, + flow.RegisterEntries{makeReg("X", "2")}, + )) + + require.NoError(t, store.SaveRegisters( + pruned+3, + blockC, + blockB, + flow.RegisterEntries{makeReg("X", "3")}, + )) + + require.NoError(t, store.SaveRegisters( + pruned+4, + blockD, + blockC, + flow.RegisterEntries{makeReg("X", "4")}, + )) + + err := store.Prune(pruned+1, unknownBlock) // block is unknown + require.Error(t, err) + + err = store.Prune(pruned+1, blockB) // block is known, but height is wrong + require.Error(t, err) + + err = store.Prune(pruned+4, unknownBlock) // height is unknown + require.Error(t, err) + + err = store.Prune(pruned+1, blockA) // prune next block + require.NoError(t, err) + + require.Equal(t, pruned+1, store.PrunedHeight()) + + reg := makeReg("X", "3") + val, err := store.GetRegister(pruned+3, blockC, reg.Key) + require.NoError(t, err) + require.Equal(t, reg.Value, val) + + _, err = store.GetRegister(pruned+1, blockA, reg.Key) // A is pruned + require.Error(t, err) + pe, ok := IsPrunedError(err) + require.True(t, ok) + require.Equal(t, pe.PrunedHeight, pruned+1) + require.Equal(t, pe.Height, pruned+1) + + err = store.Prune(pruned+3, blockC) // prune both B and C + require.NoError(t, err) + + require.Equal(t, pruned+3, store.PrunedHeight()) + + reg = makeReg("X", "4") + val, err = store.GetRegister(pruned+4, blockD, reg.Key) // can still get X at block D + require.NoError(t, err) + require.Equal(t, reg.Value, val) + }) + + // 10. Prune should prune conflicting forks + // Given Pruned <- A(X:1) <- B(X:2) + // .................. ^----- E(X:5) + // ............ 
^- C(X:3) <- D(X:4) + // Prune(A) should prune C and D, and GetUpdatedRegisters(C) should return out of range error, + // GetUpdatedRegisters(D) should return NotFound + t.Run("PruneConflictingForks", func(t *testing.T) { + t.Parallel() + pruned := uint64(10) + lastID := unittest.IdentifierFixture() + store := NewInMemoryRegisterStore(pruned, lastID) + + blockA := unittest.IdentifierFixture() + blockB := unittest.IdentifierFixture() + blockC := unittest.IdentifierFixture() + blockD := unittest.IdentifierFixture() + blockE := unittest.IdentifierFixture() + + require.NoError(t, store.SaveRegisters( + pruned+1, + blockA, + lastID, + flow.RegisterEntries{makeReg("X", "1")}, + )) + + require.NoError(t, store.SaveRegisters( + pruned+2, + blockB, + blockA, + flow.RegisterEntries{makeReg("X", "2")}, + )) + + require.NoError(t, store.SaveRegisters( + pruned+1, + blockC, + lastID, + flow.RegisterEntries{makeReg("X", "3")}, + )) + + require.NoError(t, store.SaveRegisters( + pruned+2, + blockD, + blockC, + flow.RegisterEntries{makeReg("X", "4")}, + )) + + require.NoError(t, store.SaveRegisters( + pruned+2, + blockE, + blockA, + flow.RegisterEntries{makeReg("X", "5")}, + )) + + err := store.Prune(pruned+1, blockA) // prune A should prune C and D + require.NoError(t, err) + + _, err = store.GetUpdatedRegisters(pruned+2, blockD) + require.Error(t, err) + require.Contains(t, err.Error(), "not found") + + _, err = store.GetUpdatedRegisters(pruned+2, blockE) + require.NoError(t, err) + }) + + // 11. Concurrency: SaveRegisters can happen concurrently with GetUpdatedRegisters, and GetRegister + t.Run("ConcurrentSaveAndGet", func(t *testing.T) { + t.Parallel() + pruned := uint64(10) + lastID := unittest.IdentifierFixture() + store := NewInMemoryRegisterStore(pruned, lastID) + + // prepare a chain of 101 blocks with the first as lastID + count := 100 + blocks := make(map[uint64]flow.Identifier, count) + blocks[pruned] = lastID + for i := 1; i < count; i++ { + block := unittest.IdentifierFixture() + blocks[pruned+uint64(i)] = block + } + + reg := makeReg("X", "0") + + var wg sync.WaitGroup + for i := 1; i < count; i++ { + height := pruned + uint64(i) + require.NoError(t, store.SaveRegisters( + height, + blocks[height], + blocks[height-1], + flow.RegisterEntries{makeReg("X", fmt.Sprintf("%v", height))}, + )) + + // concurrently query get registers for past registers + wg.Add(1) + go func(i int) { + defer wg.Done() + + rdHeight := randBetween(pruned+1, pruned+uint64(i)+1) + val, err := store.GetRegister(rdHeight, blocks[rdHeight], reg.Key) + require.NoError(t, err) + r := makeReg("X", fmt.Sprintf("%v", rdHeight)) + require.Equal(t, r.Value, val) + }(i) + + // concurrently query updated registers + wg.Add(1) + go func(i int) { + defer wg.Done() + + rdHeight := randBetween(pruned+1, pruned+uint64(i)+1) + vals, err := store.GetUpdatedRegisters(rdHeight, blocks[rdHeight]) + require.NoError(t, err) + r := makeReg("X", fmt.Sprintf("%v", rdHeight)) + require.Equal(t, flow.RegisterEntries{r}, vals) + }(i) + } + + wg.Wait() + }) + + // 12. 
Concurrency: Prune can happen concurrently with GetUpdatedRegisters, and GetRegister + t.Run("ConcurrentSaveAndPrune", func(t *testing.T) { + t.Parallel() + pruned := uint64(10) + lastID := unittest.IdentifierFixture() + store := NewInMemoryRegisterStore(pruned, lastID) + + // prepare a chain of 101 blocks with the first as lastID + count := 100 + blocks := make(map[uint64]flow.Identifier, count) + blocks[pruned] = lastID + for i := 1; i < count; i++ { + block := unittest.IdentifierFixture() + blocks[pruned+uint64(i)] = block + } + + var wg sync.WaitGroup + savedHeights := make(chan uint64, 100) + + wg.Add(1) + go func() { + defer wg.Done() + + lastPrunedHeight := pruned + for savedHeight := range savedHeights { + if savedHeight%10 != 0 { + continue + } + rdHeight := randBetween(lastPrunedHeight+1, savedHeight+1) + err := store.Prune(rdHeight, blocks[rdHeight]) + require.NoError(t, err) + lastPrunedHeight = rdHeight + } + }() + + // save 100 blocks + for i := 1; i < count; i++ { + height := pruned + uint64(i) + require.NoError(t, store.SaveRegisters( + height, + blocks[height], + blocks[height-1], + flow.RegisterEntries{makeReg("X", fmt.Sprintf("%v", i))}, + )) + savedHeights <- height + } + + close(savedHeights) + + wg.Wait() + }) + + t.Run("PrunedError", func(t *testing.T) { + e := NewPrunedError(1, 2) + pe, ok := IsPrunedError(e) + require.True(t, ok) + require.Equal(t, uint64(1), pe.Height) + require.Equal(t, uint64(2), pe.PrunedHeight) + }) +} + +func randBetween(min, max uint64) uint64 { + return uint64(rand.Intn(int(max)-int(min))) + min +} + +func makeReg(key string, value string) flow.RegisterEntry { + return unittest.MakeOwnerReg(key, value) +} + +var unknownBlock = unittest.IdentifierFixture() +var unknownKey = flow.RegisterID{ + Owner: "unknown", + Key: "unknown", +} diff --git a/engine/execution/storehouse/register_engine.go b/engine/execution/storehouse/register_engine.go new file mode 100644 index 00000000000..d34e28637e5 --- /dev/null +++ b/engine/execution/storehouse/register_engine.go @@ -0,0 +1,57 @@ +package storehouse + +import ( + "fmt" + + "github.com/onflow/flow-go/consensus/hotstuff/model" + "github.com/onflow/flow-go/engine" + "github.com/onflow/flow-go/module/component" + "github.com/onflow/flow-go/module/irrecoverable" +) + +// RegisterEngine is a wrapper for RegisterStore in order to make Block Finalization process +// non-blocking. +type RegisterEngine struct { + *component.ComponentManager + store *RegisterStore + finalizationNotifier engine.Notifier +} + +func NewRegisterEngine(store *RegisterStore) *RegisterEngine { + e := &RegisterEngine{ + store: store, + finalizationNotifier: engine.NewNotifier(), + } + + // Add workers + e.ComponentManager = component.NewComponentManagerBuilder(). + AddWorker(e.finalizationProcessingLoop). + Build() + return e +} + +// OnBlockFinalized will create a single goroutine to notify register store +// when a block is finalized. 
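+// Notifications may be coalesced by the underlying engine.Notifier, so if several
+// blocks are finalized in quick succession, a single pass of finalizationProcessingLoop
+// may end up processing all of them.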
+// This call is non-blocking in order to avoid blocking the consensus +func (e *RegisterEngine) OnBlockFinalized(*model.Block) { + e.finalizationNotifier.Notify() +} + +// finalizationProcessingLoop notify the register store when a block is finalized +// and handle the error if any +func (e *RegisterEngine) finalizationProcessingLoop(ctx irrecoverable.SignalerContext, ready component.ReadyFunc) { + ready() + notifier := e.finalizationNotifier.Channel() + + for { + select { + case <-ctx.Done(): + return + case <-notifier: + err := e.store.OnBlockFinalized() + if err != nil { + ctx.Throw(fmt.Errorf("could not process finalized block: %w", err)) + } + } + } +} diff --git a/engine/execution/storehouse/register_store.go b/engine/execution/storehouse/register_store.go new file mode 100644 index 00000000000..80738d75676 --- /dev/null +++ b/engine/execution/storehouse/register_store.go @@ -0,0 +1,266 @@ +package storehouse + +import ( + "errors" + "fmt" + + "go.uber.org/atomic" + + "github.com/rs/zerolog" + + "github.com/onflow/flow-go/engine/execution" + "github.com/onflow/flow-go/model/flow" + "github.com/onflow/flow-go/storage" +) + +type RegisterStore struct { + memStore *InMemoryRegisterStore + diskStore execution.OnDiskRegisterStore + wal execution.ExecutedFinalizedWAL + finalized execution.FinalizedReader + log zerolog.Logger + finalizing *atomic.Bool // making sure only one goroutine is finalizing at a time +} + +var _ execution.RegisterStore = (*RegisterStore)(nil) + +func NewRegisterStore( + diskStore execution.OnDiskRegisterStore, + wal execution.ExecutedFinalizedWAL, + finalized execution.FinalizedReader, + log zerolog.Logger, +) (*RegisterStore, error) { + // replay the executed and finalized blocks from the write ahead logs + // to the OnDiskRegisterStore + height, err := syncDiskStore(wal, diskStore, log) + if err != nil { + return nil, fmt.Errorf("cannot sync disk store: %w", err) + } + + // fetch the last executed and finalized block ID + finalizedID, err := finalized.FinalizedBlockIDAtHeight(height) + if err != nil { + return nil, fmt.Errorf("cannot get finalized block ID at height %d: %w", height, err) + } + + // init the memStore with the last executed and finalized block ID + memStore := NewInMemoryRegisterStore(height, finalizedID) + + return &RegisterStore{ + memStore: memStore, + diskStore: diskStore, + wal: wal, + finalized: finalized, + finalizing: atomic.NewBool(false), + log: log.With().Str("module", "register-store").Logger(), + }, nil +} + +// GetRegister first try to get the register from InMemoryRegisterStore, then OnDiskRegisterStore +// 1. below pruned height, and is conflicting +// 2. below pruned height, and is finalized +// 3. above pruned height, and is not executed +// 4. above pruned height, and is executed, and register is updated +// 5. above pruned height, and is executed, but register is not updated since pruned height +// It returns: +// - (value, nil) if the register value is found at the given block +// - (nil, nil) if the register is not found +// - (nil, storage.ErrHeightNotIndexed) if the height is below the first height that is indexed. 
+// - (nil, storehouse.ErrNotExecuted) if the block is not executed yet +// - (nil, storehouse.ErrNotExecuted) if the block is conflicting with finalized block +// - (nil, err) for any other exceptions +func (r *RegisterStore) GetRegister(height uint64, blockID flow.Identifier, register flow.RegisterID) (flow.RegisterValue, error) { + reg, err := r.memStore.GetRegister(height, blockID, register) + // the height might be lower than the lowest height in memStore, + // or the register might not be found in memStore. + if err == nil { + // this register was updated before its block is finalized + return reg, nil + } + + prunedError, ok := IsPrunedError(err) + if !ok { + // this means we ran into an exception. finding a register from the in-memory store should either + // return the register value or return a PrunedError. + return flow.RegisterValue{}, fmt.Errorf("cannot get register from memStore: %w", err) + } + + // if in memory store returns PrunedError, and register height is above the pruned height, + // then it means the block is connected to the pruned block of in memory store, which is + // a finalized block and executed block, so we can get its value from on disk store. + if height > prunedError.PrunedHeight { + return r.getAndConvertNotFoundErr(register, prunedError.PrunedHeight) + } + + // if the block is below the pruned height, then there are two cases: + // the block is a finalized block, or a conflicting block. + // In order to distinguish, we need to query the finalized block ID at that height + finalizedID, err := r.finalized.FinalizedBlockIDAtHeight(height) + if err != nil { + return nil, fmt.Errorf("cannot get finalized block ID at height %d: %w", height, err) + } + + isConflictingBlock := blockID != finalizedID + if isConflictingBlock { + // conflicting blocks are considered un-executed + return flow.RegisterValue{}, fmt.Errorf("getting registers from conflicting block %v at height %v: %w", blockID, height, ErrNotExecuted) + } + return r.getAndConvertNotFoundErr(register, height) +} + +// getAndConvertNotFoundErr returns nil if the register is not found in storage +func (r *RegisterStore) getAndConvertNotFoundErr(register flow.RegisterID, height uint64) (flow.RegisterValue, error) { + val, err := r.diskStore.Get(register, height) + if errors.Is(err, storage.ErrNotFound) { + // FVM expects the error to be nil when register is not found + return nil, nil + } + return val, err +} + +// SaveRegisters saves to InMemoryRegisterStore first, then triggers the same check as OnBlockFinalized +// Depend on InMemoryRegisterStore.SaveRegisters +// It returns: +// - nil if the registers are saved successfully +// - exception if the block is above the pruned height but does not connect to the pruned height (conflicting block). 
+// - exception if the block is below the pruned height +// - exception if the same block is saved again +// - exception for any other exception +func (r *RegisterStore) SaveRegisters(header *flow.Header, registers flow.RegisterEntries) error { + err := r.memStore.SaveRegisters(header.Height, header.ID(), header.ParentID, registers) + if err != nil { + return fmt.Errorf("cannot save register to memStore: %w", err) + } + + err = r.OnBlockFinalized() + if err != nil { + return fmt.Errorf("cannot trigger OnBlockFinalized: %w", err) + } + return nil +} + +// Depend on FinalizedReader's FinalizedBlockIDAtHeight +// Depend on ExecutedFinalizedWAL.Append +// Depend on OnDiskRegisterStore.SaveRegisters +// OnBlockFinalized triggers the check of whether a block at the next height becomes finalized and executed. +// The next height is the existing finalized and executed block's height + 1. +// If a block at next height becomes finalized and executed, then: +// 1. write the registers to write ahead logs +// 2. save the registers of the block to OnDiskRegisterStore +// 3. prune the height in InMemoryRegisterStore +func (r *RegisterStore) OnBlockFinalized() error { + // only one goroutine can execute OnBlockFinalized at a time + if !r.finalizing.CompareAndSwap(false, true) { + return nil + } + + defer r.finalizing.Store(false) + return r.onBlockFinalized() +} + +func (r *RegisterStore) onBlockFinalized() error { + latest := r.diskStore.LatestHeight() + next := latest + 1 + blockID, err := r.finalized.FinalizedBlockIDAtHeight(next) + if errors.Is(err, storage.ErrNotFound) { + // next block is not finalized yet + return nil + } + + regs, err := r.memStore.GetUpdatedRegisters(next, blockID) + if errors.Is(err, ErrNotExecuted) { + // next block is not executed yet + return nil + } + + // TODO: append WAL + // err = r.wal.Append(next, regs) + // if err != nil { + // return fmt.Errorf("cannot write %v registers to write ahead logs for height %v: %w", len(regs), next, err) + // } + + err = r.diskStore.Store(regs, next) + if err != nil { + return fmt.Errorf("cannot save %v registers to disk store for height %v: %w", len(regs), next, err) + } + + err = r.memStore.Prune(next, blockID) + if err != nil { + return fmt.Errorf("cannot prune memStore for height %v: %w", next, err) + } + + return r.onBlockFinalized() // check again until there is no more finalized block +} + +// LastFinalizedAndExecutedHeight returns the height of the last finalized and executed block, +// which has been saved in OnDiskRegisterStore +func (r *RegisterStore) LastFinalizedAndExecutedHeight() uint64 { + // diskStore caches the latest height in memory + return r.diskStore.LatestHeight() +} + +// IsBlockExecuted returns true if the block is executed, false if not executed +// Note: it returns (true, nil) even if the block has been pruned from on disk register store. +func (r *RegisterStore) IsBlockExecuted(height uint64, blockID flow.Identifier) (bool, error) { + executed, err := r.memStore.IsBlockExecuted(height, blockID) + if err != nil { + // the only error memStore would return is when the given height is lower than the pruned height in memStore. + // Since the pruned height in memStore is a finalized and executed height, in order to know if the block + // is executed, we just need to check if this block is the finalized block at the given height. 
+ executed, err = r.isBlockFinalized(height, blockID) + return executed, err + } + + return executed, nil +} + +func (r *RegisterStore) isBlockFinalized(height uint64, blockID flow.Identifier) (bool, error) { + finalizedID, err := r.finalized.FinalizedBlockIDAtHeight(height) + if err != nil { + return false, fmt.Errorf("cannot get finalized block ID at height %d: %w", height, err) + } + return finalizedID == blockID, nil +} + +// syncDiskStore replay WAL to disk store +func syncDiskStore( + wal execution.ExecutedFinalizedWAL, + diskStore execution.OnDiskRegisterStore, + log zerolog.Logger, +) (uint64, error) { + // TODO: replace diskStore.Latest with wal.Latest + // latest, err := r.wal.Latest() + var err error + latest := diskStore.LatestHeight() // tmp + if err != nil { + return 0, fmt.Errorf("cannot get latest height from write ahead logs: %w", err) + } + + stored := diskStore.LatestHeight() + + if stored > latest { + return 0, fmt.Errorf("latest height in storehouse %v is larger than latest height %v in write ahead logs", stored, latest) + } + + if stored < latest { + // replay + reader := wal.GetReader(stored + 1) + for { + height, registers, err := reader.Next() + // TODO: to rename + if errors.Is(err, storage.ErrNotFound) { + break + } + if err != nil { + return 0, fmt.Errorf("cannot read registers from write ahead logs: %w", err) + } + + err = diskStore.Store(registers, height) + if err != nil { + return 0, fmt.Errorf("cannot save registers to disk store at height %v : %w", height, err) + } + } + } + + return latest, nil +} diff --git a/engine/execution/storehouse/register_store_test.go b/engine/execution/storehouse/register_store_test.go new file mode 100644 index 00000000000..43718d419e8 --- /dev/null +++ b/engine/execution/storehouse/register_store_test.go @@ -0,0 +1,449 @@ +package storehouse_test + +import ( + "fmt" + "sync" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/onflow/flow-go/engine/execution" + "github.com/onflow/flow-go/engine/execution/storehouse" + "github.com/onflow/flow-go/model/flow" + "github.com/onflow/flow-go/storage/pebble" + "github.com/onflow/flow-go/utils/unittest" +) + +func withRegisterStore(t *testing.T, fn func( + t *testing.T, + rs *storehouse.RegisterStore, + diskStore execution.OnDiskRegisterStore, + finalized *mockFinalizedReader, + rootHeight uint64, + endHeight uint64, + headers map[uint64]*flow.Header, +)) { + pebble.RunWithRegistersStorageAtInitialHeights(t, 10, 10, func(diskStore *pebble.Registers) { + log := unittest.Logger() + var wal execution.ExecutedFinalizedWAL + finalized, headerByHeight, highest := newMockFinalizedReader(10, 100) + rs, err := storehouse.NewRegisterStore(diskStore, wal, finalized, log) + require.NoError(t, err) + fn(t, rs, diskStore, finalized, 10, highest, headerByHeight) + }) +} + +// GetRegister should fail for +// 1. unknown blockID +// 2. height lower than OnDiskRegisterStore's root height +// 3. height too high +// 4. 
known block, but unknown register
+func TestRegisterStoreGetRegisterFail(t *testing.T) {
+ t.Parallel()
+ withRegisterStore(t, func(
+ t *testing.T,
+ rs *storehouse.RegisterStore,
+ diskStore execution.OnDiskRegisterStore,
+ finalized *mockFinalizedReader,
+ rootHeight uint64,
+ endHeight uint64,
+ headerByHeight map[uint64]*flow.Header,
+ ) {
+ // unknown block
+ _, err := rs.GetRegister(rootHeight+1, unknownBlock, unknownReg.Key)
+ require.Error(t, err)
+ require.ErrorIs(t, err, storehouse.ErrNotExecuted)
+
+ // too high
+ block11 := headerByHeight[rootHeight+1]
+ _, err = rs.GetRegister(rootHeight+1, block11.ID(), unknownReg.Key)
+ require.Error(t, err)
+ require.ErrorIs(t, err, storehouse.ErrNotExecuted)
+
+ // lower than root height
+ _, err = rs.GetRegister(rootHeight-1, unknownBlock, unknownReg.Key)
+ require.Error(t, err)
+ // TODO: enable it once implemented
+ // require.ErrorIs(t, err, storehouse.ErrPruned)
+
+ // known block, unknown register
+ rootBlock := headerByHeight[rootHeight]
+ val, err := rs.GetRegister(rootHeight, rootBlock.ID(), unknownReg.Key)
+ require.NoError(t, err)
+ require.Nil(t, val)
+ })
+}
+
+// SaveRegisters should fail for
+// 1. mismatching parent
+// 2. an already-saved block
+func TestRegisterStoreSaveRegistersShouldFail(t *testing.T) {
+ t.Parallel()
+ withRegisterStore(t, func(
+ t *testing.T,
+ rs *storehouse.RegisterStore,
+ diskStore execution.OnDiskRegisterStore,
+ finalized *mockFinalizedReader,
+ rootHeight uint64,
+ endHeight uint64,
+ headerByHeight map[uint64]*flow.Header,
+ ) {
+ wrongParent := unittest.BlockHeaderFixture(unittest.WithHeaderHeight(rootHeight + 1))
+ err := rs.SaveRegisters(wrongParent, flow.RegisterEntries{})
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "parent")
+
+ err = rs.SaveRegisters(headerByHeight[rootHeight], flow.RegisterEntries{})
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "pruned")
+ })
+}
+
+// SaveRegisters should succeed, and
+// 1. GetRegister can get saved registers,
+// 2. IsBlockExecuted should return true
+//
+// if SaveRegisters is called with empty registers, then
+// 1. LastFinalizedAndExecutedHeight should be updated
+// 2.
IsBlockExecuted should return true
+func TestRegisterStoreSaveRegistersShouldOK(t *testing.T) {
+ t.Parallel()
+ withRegisterStore(t, func(
+ t *testing.T,
+ rs *storehouse.RegisterStore,
+ diskStore execution.OnDiskRegisterStore,
+ finalized *mockFinalizedReader,
+ rootHeight uint64,
+ endHeight uint64,
+ headerByHeight map[uint64]*flow.Header,
+ ) {
+ // not executed
+ executed, err := rs.IsBlockExecuted(rootHeight+1, headerByHeight[rootHeight+1].ID())
+ require.NoError(t, err)
+ require.False(t, executed)
+
+ // save block 11
+ reg := makeReg("X", "1")
+ err = rs.SaveRegisters(headerByHeight[rootHeight+1], flow.RegisterEntries{reg})
+ require.NoError(t, err)
+
+ // should get value
+ val, err := rs.GetRegister(rootHeight+1, headerByHeight[rootHeight+1].ID(), reg.Key)
+ require.NoError(t, err)
+ require.Equal(t, reg.Value, val)
+
+ // should become executed
+ executed, err = rs.IsBlockExecuted(rootHeight+1, headerByHeight[rootHeight+1].ID())
+ require.NoError(t, err)
+ require.True(t, executed)
+
+ // block 12 is empty
+ err = rs.SaveRegisters(headerByHeight[rootHeight+2], flow.RegisterEntries{})
+ require.NoError(t, err)
+
+ // should get same value
+ val, err = rs.GetRegister(rootHeight+1, headerByHeight[rootHeight+2].ID(), reg.Key)
+ require.NoError(t, err)
+ require.Equal(t, reg.Value, val)
+
+ // should become executed
+ executed, err = rs.IsBlockExecuted(rootHeight+1, headerByHeight[rootHeight+2].ID())
+ require.NoError(t, err)
+ require.True(t, executed)
+ })
+}
+
+// if 11 is latest finalized, then
+// 1. IsBlockExecuted should return true for finalized block 10
+// 2. IsBlockExecuted should return false for conflicting block 10
+// 3. IsBlockExecuted should return true for executed and unfinalized block 12
+// 4. IsBlockExecuted should return false for unexecuted block 13
+func TestRegisterStoreIsBlockExecuted(t *testing.T) {
+ t.Parallel()
+ withRegisterStore(t, func(
+ t *testing.T,
+ rs *storehouse.RegisterStore,
+ diskStore execution.OnDiskRegisterStore,
+ finalized *mockFinalizedReader,
+ rootHeight uint64,
+ endHeight uint64,
+ headerByHeight map[uint64]*flow.Header,
+ ) {
+ // save block 11
+ reg := makeReg("X", "1")
+ err := rs.SaveRegisters(headerByHeight[rootHeight+1], flow.RegisterEntries{reg})
+ require.NoError(t, err)
+
+ // save block 12
+ err = rs.SaveRegisters(headerByHeight[rootHeight+2], flow.RegisterEntries{makeReg("X", "2")})
+ require.NoError(t, err)
+
+ require.NoError(t, finalized.MockFinal(rootHeight+1))
+
+ require.NoError(t, rs.OnBlockFinalized()) // notify 11 is finalized
+
+ require.Equal(t, rootHeight+1, rs.LastFinalizedAndExecutedHeight())
+
+ executed, err := rs.IsBlockExecuted(rootHeight, headerByHeight[rootHeight].ID())
+ require.NoError(t, err)
+ require.True(t, executed)
+
+ executed, err = rs.IsBlockExecuted(rootHeight+1, headerByHeight[rootHeight+1].ID())
+ require.NoError(t, err)
+ require.True(t, executed)
+
+ executed, err = rs.IsBlockExecuted(rootHeight+2, headerByHeight[rootHeight+2].ID())
+ require.NoError(t, err)
+ require.True(t, executed)
+
+ executed, err = rs.IsBlockExecuted(rootHeight+3, headerByHeight[rootHeight+3].ID())
+ require.NoError(t, err)
+ require.False(t, executed)
+ })
+}
+
+// Test reading registers from finalized block
+func TestRegisterStoreReadingFromDisk(t *testing.T) {
+ t.Parallel()
+ withRegisterStore(t, func(
+ t *testing.T,
+ rs *storehouse.RegisterStore,
+ diskStore execution.OnDiskRegisterStore,
+ finalized *mockFinalizedReader,
+ rootHeight uint64,
+ endHeight uint64,
+ headerByHeight
map[uint64]*flow.Header, + ) { + + // R <- 11 (X: 1, Y: 2) <- 12 (Y: 3) <- 13 (X: 4) + // save block 11 + err := rs.SaveRegisters(headerByHeight[rootHeight+1], flow.RegisterEntries{makeReg("X", "1"), makeReg("Y", "2")}) + require.NoError(t, err) + + // save block 12 + err = rs.SaveRegisters(headerByHeight[rootHeight+2], flow.RegisterEntries{makeReg("Y", "3")}) + require.NoError(t, err) + + // save block 13 + err = rs.SaveRegisters(headerByHeight[rootHeight+3], flow.RegisterEntries{makeReg("X", "4")}) + require.NoError(t, err) + + require.NoError(t, finalized.MockFinal(rootHeight+2)) + require.NoError(t, rs.OnBlockFinalized()) // notify 12 is finalized + + val, err := rs.GetRegister(rootHeight+1, headerByHeight[rootHeight+1].ID(), makeReg("Y", "2").Key) + require.NoError(t, err) + // value at block 11 is now stored in OnDiskRegisterStore, which is 2 + require.Equal(t, makeReg("Y", "2").Value, val) + + val, err = rs.GetRegister(rootHeight+2, headerByHeight[rootHeight+2].ID(), makeReg("X", "1").Key) + require.NoError(t, err) + // value at block 12 is now stored in OnDiskRegisterStore, which is 1 + require.Equal(t, makeReg("X", "1").Value, val) + + val, err = rs.GetRegister(rootHeight+3, headerByHeight[rootHeight+3].ID(), makeReg("Y", "3").Key) + require.NoError(t, err) + // value at block 13 was stored in OnDiskRegisterStore at block 12, which is 3 + require.Equal(t, makeReg("Y", "3").Value, val) + + _, err = rs.GetRegister(rootHeight+4, headerByHeight[rootHeight+4].ID(), makeReg("Y", "3").Key) + require.Error(t, err) + }) +} + +func TestRegisterStoreReadingFromInMemStore(t *testing.T) { + t.Parallel() + withRegisterStore(t, func( + t *testing.T, + rs *storehouse.RegisterStore, + diskStore execution.OnDiskRegisterStore, + finalized *mockFinalizedReader, + rootHeight uint64, + endHeight uint64, + headerByHeight map[uint64]*flow.Header, + ) { + + // R <- 11 (X: 1, Y: 2) <- 12 (Y: 3) + // ^- 11 (X: 4) + + // save block 11 + err := rs.SaveRegisters(headerByHeight[rootHeight+1], flow.RegisterEntries{makeReg("X", "1"), makeReg("Y", "2")}) + require.NoError(t, err) + + // save block 12 + err = rs.SaveRegisters(headerByHeight[rootHeight+2], flow.RegisterEntries{makeReg("Y", "3")}) + require.NoError(t, err) + + // save block 11 fork + block11Fork := unittest.BlockWithParentFixture(headerByHeight[rootHeight]).Header + err = rs.SaveRegisters(block11Fork, flow.RegisterEntries{makeReg("X", "4")}) + require.NoError(t, err) + + val, err := rs.GetRegister(rootHeight+1, headerByHeight[rootHeight+1].ID(), makeReg("X", "1").Key) + require.NoError(t, err) + require.Equal(t, makeReg("X", "1").Value, val) + + val, err = rs.GetRegister(rootHeight+1, headerByHeight[rootHeight+1].ID(), makeReg("Y", "2").Key) + require.NoError(t, err) + require.Equal(t, makeReg("Y", "2").Value, val) + + val, err = rs.GetRegister(rootHeight+2, headerByHeight[rootHeight+2].ID(), makeReg("X", "1").Key) + require.NoError(t, err) + require.Equal(t, makeReg("X", "1").Value, val) + + val, err = rs.GetRegister(rootHeight+2, headerByHeight[rootHeight+2].ID(), makeReg("Y", "3").Key) + require.NoError(t, err) + require.Equal(t, makeReg("Y", "3").Value, val) + + val, err = rs.GetRegister(rootHeight+1, block11Fork.ID(), makeReg("X", "4").Key) + require.NoError(t, err) + require.Equal(t, makeReg("X", "4").Value, val) + + // finalizing 11 should prune block 11 fork, and won't be able to read register from block 11 fork + require.NoError(t, finalized.MockFinal(rootHeight+1)) + require.NoError(t, rs.OnBlockFinalized()) // notify 11 is finalized + + 
val, err = rs.GetRegister(rootHeight+1, block11Fork.ID(), makeReg("X", "4").Key) + require.Error(t, err, fmt.Sprintf("%v", val)) + // pruned conflicting forks are considered not executed + require.ErrorIs(t, err, storehouse.ErrNotExecuted) + }) +} + +// Execute first then finalize later +// SaveRegisters(1), SaveRegisters(2), SaveRegisters(3), then +// OnBlockFinalized(1), OnBlockFinalized(2), OnBlockFinalized(3) should +// 1. update LastFinalizedAndExecutedHeight +// 2. InMemoryRegisterStore should have correct pruned height +// 3. NewRegisterStore with the same OnDiskRegisterStore again should return correct LastFinalizedAndExecutedHeight +func TestRegisterStoreExecuteFirstFinalizeLater(t *testing.T) { + t.Parallel() + withRegisterStore(t, func( + t *testing.T, + rs *storehouse.RegisterStore, + diskStore execution.OnDiskRegisterStore, + finalized *mockFinalizedReader, + rootHeight uint64, + endHeight uint64, + headerByHeight map[uint64]*flow.Header, + ) { + // save block 11 + err := rs.SaveRegisters(headerByHeight[rootHeight+1], flow.RegisterEntries{makeReg("X", "1")}) + require.NoError(t, err) + require.Equal(t, rootHeight, rs.LastFinalizedAndExecutedHeight()) + + // save block 12 + err = rs.SaveRegisters(headerByHeight[rootHeight+2], flow.RegisterEntries{makeReg("X", "2")}) + require.NoError(t, err) + require.Equal(t, rootHeight, rs.LastFinalizedAndExecutedHeight()) + + // save block 13 + err = rs.SaveRegisters(headerByHeight[rootHeight+3], flow.RegisterEntries{makeReg("X", "3")}) + require.NoError(t, err) + require.Equal(t, rootHeight, rs.LastFinalizedAndExecutedHeight()) + + require.NoError(t, finalized.MockFinal(rootHeight+1)) + require.NoError(t, rs.OnBlockFinalized()) // notify 11 is finalized + require.Equal(t, rootHeight+1, rs.LastFinalizedAndExecutedHeight()) + + require.NoError(t, finalized.MockFinal(rootHeight+2)) + require.NoError(t, rs.OnBlockFinalized()) // notify 12 is finalized + require.Equal(t, rootHeight+2, rs.LastFinalizedAndExecutedHeight()) + + require.NoError(t, finalized.MockFinal(rootHeight+3)) + require.NoError(t, rs.OnBlockFinalized()) // notify 13 is finalized + require.Equal(t, rootHeight+3, rs.LastFinalizedAndExecutedHeight()) + }) +} + +// Finalize first then execute later +// OnBlockFinalized(1), OnBlockFinalized(2), OnBlockFinalized(3), then +// SaveRegisters(1), SaveRegisters(2), SaveRegisters(3) should +// 1. update LastFinalizedAndExecutedHeight +// 2. InMemoryRegisterStore should have correct pruned height +// 3. 
NewRegisterStore with the same OnDiskRegisterStore again should return correct LastFinalizedAndExecutedHeight +func TestRegisterStoreFinalizeFirstExecuteLater(t *testing.T) { + t.Parallel() + withRegisterStore(t, func( + t *testing.T, + rs *storehouse.RegisterStore, + diskStore execution.OnDiskRegisterStore, + finalized *mockFinalizedReader, + rootHeight uint64, + endHeight uint64, + headerByHeight map[uint64]*flow.Header, + ) { + require.NoError(t, finalized.MockFinal(rootHeight+1)) + require.NoError(t, rs.OnBlockFinalized()) // notify 11 is finalized + require.Equal(t, rootHeight, rs.LastFinalizedAndExecutedHeight(), fmt.Sprintf("LastFinalizedAndExecutedHeight: %d", rs.LastFinalizedAndExecutedHeight())) + + require.NoError(t, finalized.MockFinal(rootHeight+2)) + require.NoError(t, rs.OnBlockFinalized()) // notify 12 is finalized + require.Equal(t, rootHeight, rs.LastFinalizedAndExecutedHeight(), fmt.Sprintf("LastFinalizedAndExecutedHeight: %d", rs.LastFinalizedAndExecutedHeight())) + + require.NoError(t, finalized.MockFinal(rootHeight+3)) + require.NoError(t, rs.OnBlockFinalized()) // notify 13 is finalized + require.Equal(t, rootHeight, rs.LastFinalizedAndExecutedHeight()) + + // save block 11 + err := rs.SaveRegisters(headerByHeight[rootHeight+1], flow.RegisterEntries{makeReg("X", "1")}) + require.NoError(t, err) + require.Equal(t, rootHeight+1, rs.LastFinalizedAndExecutedHeight()) + + // save block 12 + err = rs.SaveRegisters(headerByHeight[rootHeight+2], flow.RegisterEntries{makeReg("X", "2")}) + require.NoError(t, err) + require.Equal(t, rootHeight+2, rs.LastFinalizedAndExecutedHeight()) + + // save block 13 + err = rs.SaveRegisters(headerByHeight[rootHeight+3], flow.RegisterEntries{makeReg("X", "3")}) + require.NoError(t, err) + require.Equal(t, rootHeight+3, rs.LastFinalizedAndExecutedHeight()) + }) +} + +// Finalize and Execute concurrently +// SaveRegisters(1), SaveRegisters(2), ... SaveRegisters(100), happen concurrently with +// OnBlockFinalized(1), OnBlockFinalized(2), ... 
OnBlockFinalized(100), should update LastFinalizedAndExecutedHeight +func TestRegisterStoreConcurrentFinalizeAndExecute(t *testing.T) { + t.Parallel() + withRegisterStore(t, func( + t *testing.T, + rs *storehouse.RegisterStore, + diskStore execution.OnDiskRegisterStore, + finalized *mockFinalizedReader, + rootHeight uint64, + endHeight uint64, + headerByHeight map[uint64]*flow.Header, + ) { + + var wg sync.WaitGroup + savedHeights := make(chan uint64, len(headerByHeight)) // enough buffer so that producer won't be blocked + + wg.Add(1) + go func() { + defer wg.Done() + + for savedHeight := range savedHeights { + err := finalized.MockFinal(savedHeight) + require.NoError(t, err) + require.NoError(t, rs.OnBlockFinalized(), fmt.Sprintf("saved height %v", savedHeight)) + } + }() + + for height := rootHeight + 1; height <= endHeight; height++ { + if height >= 50 { + savedHeights <- height + } + + err := rs.SaveRegisters(headerByHeight[height], flow.RegisterEntries{makeReg("X", fmt.Sprintf("%d", height))}) + require.NoError(t, err) + } + close(savedHeights) + + wg.Wait() // wait until all heights are finalized + + // after all heights are executed and finalized, the LastFinalizedAndExecutedHeight should be the last height + require.Equal(t, endHeight, rs.LastFinalizedAndExecutedHeight()) + }) +} diff --git a/engine/execution/storehouse/storehouse_test.go b/engine/execution/storehouse/storehouse_test.go new file mode 100644 index 00000000000..8589c88499d --- /dev/null +++ b/engine/execution/storehouse/storehouse_test.go @@ -0,0 +1,71 @@ +package storehouse_test + +import ( + "fmt" + + "go.uber.org/atomic" + + "github.com/onflow/flow-go/model/flow" + "github.com/onflow/flow-go/storage" + "github.com/onflow/flow-go/utils/unittest" +) + +var unknownBlock = unittest.IdentifierFixture() +var unknownReg = makeReg("unknown", "unknown") + +func makeReg(key string, value string) flow.RegisterEntry { + return flow.RegisterEntry{ + Key: flow.RegisterID{ + Owner: "owner", + Key: key, + }, + Value: []byte(value), + } +} + +type mockFinalizedReader struct { + headerByHeight map[uint64]*flow.Header + lowest uint64 + highest uint64 + finalizedHeight *atomic.Uint64 +} + +func newMockFinalizedReader(initHeight uint64, count int) (*mockFinalizedReader, map[uint64]*flow.Header, uint64) { + root := unittest.BlockHeaderFixture(unittest.WithHeaderHeight(initHeight)) + blocks := unittest.ChainFixtureFrom(count, root) + headerByHeight := make(map[uint64]*flow.Header, len(blocks)+1) + headerByHeight[root.Height] = root + + for _, b := range blocks { + headerByHeight[b.Header.Height] = b.Header + } + + highest := blocks[len(blocks)-1].Header.Height + return &mockFinalizedReader{ + headerByHeight: headerByHeight, + lowest: initHeight, + highest: highest, + finalizedHeight: atomic.NewUint64(initHeight), + }, headerByHeight, highest +} + +func (r *mockFinalizedReader) FinalizedBlockIDAtHeight(height uint64) (flow.Identifier, error) { + finalized := r.finalizedHeight.Load() + if height > finalized { + return flow.Identifier{}, storage.ErrNotFound + } + + if height < r.lowest { + return unknownBlock, nil + } + return r.headerByHeight[height].ID(), nil +} + +func (r *mockFinalizedReader) MockFinal(height uint64) error { + if height < r.lowest || height > r.highest { + return fmt.Errorf("height %d is out of range [%d, %d]", height, r.lowest, r.highest) + } + + r.finalizedHeight.Store(height) + return nil +} diff --git a/engine/testutil/nodes.go b/engine/testutil/nodes.go index c25f0e04a24..7ff2faf0a20 100644 --- 
a/engine/testutil/nodes.go +++ b/engine/testutil/nodes.go @@ -712,7 +712,7 @@ func ExecutionNode(t *testing.T, hub *stub.Hub, identity *flow.Identity, identit ) fetcher := exeFetcher.NewCollectionFetcher(node.Log, requestEngine, node.State, false) - loader := loader.NewLoader(node.Log, node.State, node.Headers, execState) + loader := loader.NewUnexecutedLoader(node.Log, node.State, node.Headers, execState) rootHead, rootQC := getRoot(t, &node) ingestionEngine, err := ingestion.New( unit, diff --git a/follower/follower_builder.go b/follower/follower_builder.go index 30d0d6abff1..cf7b19c7214 100644 --- a/follower/follower_builder.go +++ b/follower/follower_builder.go @@ -600,6 +600,7 @@ func (builder *FollowerServiceBuilder) initPublicLibp2pNode(networkKey crypto.Pr &builder.FlowConfig.NetworkConfig.ResourceManager, &builder.FlowConfig.NetworkConfig.GossipSubConfig.GossipSubRPCInspectorsConfig, p2pconfig.PeerManagerDisableConfig(), // disable peer manager for follower + &builder.FlowConfig.NetworkConfig.GossipSubConfig.SubscriptionProviderConfig, &p2p.DisallowListCacheConfig{ MaxSize: builder.FlowConfig.NetworkConfig.DisallowListNotificationCacheSize, Metrics: metrics.DisallowListCacheMetricsFactory(builder.HeroCacheMetricsFactory(), network.PublicNetwork), diff --git a/fvm/blueprints/contracts.go b/fvm/blueprints/contracts.go index bbe3ce422ab..34554be5a7a 100644 --- a/fvm/blueprints/contracts.go +++ b/fvm/blueprints/contracts.go @@ -3,8 +3,6 @@ package blueprints import ( _ "embed" - "encoding/hex" - "github.com/onflow/cadence" jsoncdc "github.com/onflow/cadence/encoding/json" "github.com/onflow/cadence/runtime/common" @@ -32,7 +30,7 @@ var setContractOperationAuthorizersTransactionTemplate string var setIsContractDeploymentRestrictedTransactionTemplate string //go:embed scripts/deployContractTransactionTemplate.cdc -var deployContractTransactionTemplate string +var DeployContractTransactionTemplate []byte // SetContractDeploymentAuthorizersTransaction returns a transaction for updating list of authorized accounts allowed to deploy/update contracts func SetContractDeploymentAuthorizersTransaction(serviceAccount flow.Address, authorized []flow.Address) (*flow.TransactionBody, error) { @@ -95,8 +93,8 @@ func SetIsContractDeploymentRestrictedTransaction(serviceAccount flow.Address, r // TODO (ramtin) get rid of authorizers func DeployContractTransaction(address flow.Address, contract []byte, contractName string) *flow.TransactionBody { return flow.NewTransactionBody(). - SetScript([]byte(deployContractTransactionTemplate)). + SetScript(DeployContractTransactionTemplate). AddArgument(jsoncdc.MustEncode(cadence.String(contractName))). - AddArgument(jsoncdc.MustEncode(cadence.String(hex.EncodeToString(contract)))). + AddArgument(jsoncdc.MustEncode(cadence.String(contract))). 
AddAuthorizer(address) } diff --git a/fvm/blueprints/scripts/deployContractTransactionTemplate.cdc b/fvm/blueprints/scripts/deployContractTransactionTemplate.cdc index 02573e4342b..4e24d39b7d4 100644 --- a/fvm/blueprints/scripts/deployContractTransactionTemplate.cdc +++ b/fvm/blueprints/scripts/deployContractTransactionTemplate.cdc @@ -1,5 +1,5 @@ transaction(name: String, code: String) { prepare(signer: AuthAccount) { - signer.contracts.add(name: name, code: code.decodeHex()) + signer.contracts.add(name: name, code: code.utf8) } } diff --git a/fvm/bootstrap.go b/fvm/bootstrap.go index a888952ccb4..52a326d0ba6 100644 --- a/fvm/bootstrap.go +++ b/fvm/bootstrap.go @@ -10,6 +10,7 @@ import ( "github.com/onflow/flow-go/fvm/blueprints" "github.com/onflow/flow-go/fvm/environment" "github.com/onflow/flow-go/fvm/errors" + "github.com/onflow/flow-go/fvm/evm/stdlib" "github.com/onflow/flow-go/fvm/meter" "github.com/onflow/flow-go/fvm/storage" "github.com/onflow/flow-go/fvm/storage/logical" @@ -75,6 +76,7 @@ type BootstrapParams struct { minimumStorageReservation cadence.UFix64 storagePerFlow cadence.UFix64 restrictedAccountCreationEnabled cadence.Bool + setupEVMEnabled cadence.Bool // versionFreezePeriod is the number of blocks in the future where the version // changes are frozen. The Node version beacon manages the freeze period, @@ -210,6 +212,13 @@ func WithRestrictedAccountCreationEnabled(enabled cadence.Bool) BootstrapProcedu } } +func WithSetupEVMEnabled(enabled cadence.Bool) BootstrapProcedureOption { + return func(bp *BootstrapProcedure) *BootstrapProcedure { + bp.setupEVMEnabled = enabled + return bp + } +} + func WithRestrictedContractDeployment(restricted *bool) BootstrapProcedureOption { return func(bp *BootstrapProcedure) *BootstrapProcedure { bp.restrictedContractDeployment = restricted @@ -380,6 +389,9 @@ func (b *bootstrapExecutor) Execute() error { // set the list of nodes which are allowed to stake in this network b.setStakingAllowlist(service, b.identities.NodeIDs()) + // sets up the EVM environment + b.setupEVM(service, flowToken) + return nil } @@ -776,6 +788,23 @@ func (b *bootstrapExecutor) setStakingAllowlist( panicOnMetaInvokeErrf("failed to set staking allow-list: %s", txError, err) } +func (b *bootstrapExecutor) setupEVM(serviceAddress, flowTokenAddress flow.Address) { + if b.setupEVMEnabled { + b.createAccount(nil) // account for storage + tx := blueprints.DeployContractTransaction( + serviceAddress, + stdlib.ContractCode(flowTokenAddress), + stdlib.ContractName, + ) + // WithEVMEnabled should only be used after we create an account for storage + txError, err := b.invokeMetaTransaction( + NewContextFromParent(b.ctx, WithEVMEnabled(true)), + Transaction(tx, 0), + ) + panicOnMetaInvokeErrf("failed to deploy EVM contract: %s", txError, err) + } +} + func (b *bootstrapExecutor) registerNodes(service, fungibleToken, flowToken flow.Address) { for _, id := range b.identities { diff --git a/fvm/context.go b/fvm/context.go index 44aecdd14ce..61a3f0c7268 100644 --- a/fvm/context.go +++ b/fvm/context.go @@ -28,6 +28,7 @@ type Context struct { // DisableMemoryAndInteractionLimits will override memory and interaction // limits and set them to MaxUint64, effectively disabling these limits. 
DisableMemoryAndInteractionLimits bool + EVMEnabled bool ComputationLimit uint64 MemoryLimit uint64 MaxStateKeySize uint64 @@ -366,3 +367,11 @@ func WithEventEncoder(encoder environment.EventEncoder) Option { return ctx } } + +// WithEVMEnabled enables access to the evm environment +func WithEVMEnabled(enabled bool) Option { + return func(ctx Context) Context { + ctx.EVMEnabled = enabled + return ctx + } +} diff --git a/fvm/environment/account_creator.go b/fvm/environment/account_creator.go index 07612384d2c..a9173a5c530 100644 --- a/fvm/environment/account_creator.go +++ b/fvm/environment/account_creator.go @@ -16,6 +16,7 @@ const ( FungibleTokenAccountIndex = 2 FlowTokenAccountIndex = 3 FlowFeesAccountIndex = 4 + EVMAccountIndex = 5 ) type AddressGenerator interface { diff --git a/fvm/environment/env.go b/fvm/environment/env.go index ac8ac32f3b7..200d3af7ea4 100644 --- a/fvm/environment/env.go +++ b/fvm/environment/env.go @@ -37,7 +37,7 @@ type Environment interface { // EventEmitter Events() flow.EventsList - EmitFlowEvent(etype flow.EventType, payload []byte) error + EmitRawEvent(etype flow.EventType, payload []byte) error ServiceEvents() flow.EventsList ConvertedServiceEvents() flow.ServiceEventList diff --git a/fvm/environment/event_emitter.go b/fvm/environment/event_emitter.go index 6a05fefe1f3..0579bb72833 100644 --- a/fvm/environment/event_emitter.go +++ b/fvm/environment/event_emitter.go @@ -38,16 +38,14 @@ func DefaultEventEmitterParams() EventEmitterParams { // Note that scripts do not emit events, but must expose the API in compliance // with the runtime environment interface. type EventEmitter interface { - // Cadence's runtime API. Note that the script variant will return - // OperationNotSupportedError. + // EmitEvent satisfies Cadence's runtime API. + // This will encode the cadence event and call EmitRawEvent. + // + // Note that the script variant will return OperationNotSupportedError. EmitEvent(event cadence.Event) error - // EmitFlowEvent is used to emit events that are not generated by - // Cadence runtime. - // Warning: current implementation of EmitFlowEvent does not support handling service events - // that functionality should be added if needed in the future - // TODO: we could merge this one with the EmitEvent endpoint - EmitFlowEvent(etype flow.EventType, payload []byte) error + // EmitRawEvent is used to emit events that are not Cadence events. + EmitRawEvent(eventType flow.EventType, payload []byte) error Events() flow.EventsList ServiceEvents() flow.EventsList @@ -79,12 +77,12 @@ func (emitter ParseRestrictedEventEmitter) EmitEvent(event cadence.Event) error event) } -func (emitter ParseRestrictedEventEmitter) EmitFlowEvent(etype flow.EventType, payload []byte) error { +func (emitter ParseRestrictedEventEmitter) EmitRawEvent(eventType flow.EventType, payload []byte) error { return parseRestrict2Arg( emitter.txnState, trace.FVMEnvEmitEvent, - emitter.impl.EmitFlowEvent, - etype, + emitter.impl.EmitRawEvent, + eventType, payload, ) } @@ -111,11 +109,11 @@ var _ EventEmitter = NoEventEmitter{} // where emitting an event does nothing. 
type NoEventEmitter struct{} -func (NoEventEmitter) EmitEvent(event cadence.Event) error { +func (NoEventEmitter) EmitEvent(cadence.Event) error { return nil } -func (NoEventEmitter) EmitFlowEvent(etype flow.EventType, payload []byte) error { +func (NoEventEmitter) EmitRawEvent(flow.EventType, []byte) error { return nil } @@ -180,12 +178,10 @@ func (emitter *eventEmitter) EventCollection() *EventCollection { } func (emitter *eventEmitter) EmitEvent(event cadence.Event) error { - defer emitter.tracer.StartExtensiveTracingChildSpan( - trace.FVMEnvEmitEvent).End() - - err := emitter.meter.MeterComputation(ComputationKindEmitEvent, 1) + defer emitter.tracer.StartExtensiveTracingChildSpan(trace.FVMEnvEncodeEvent).End() + err := emitter.meter.MeterComputation(ComputationKindEncodeEvent, 1) if err != nil { - return fmt.Errorf("emit event failed: %w", err) + return fmt.Errorf("emit event, event encoding failed: %w", err) } payload, err := emitter.EventEncoder.Encode(event) @@ -193,10 +189,18 @@ func (emitter *eventEmitter) EmitEvent(event cadence.Event) error { return errors.NewEventEncodingError(err) } - payloadSize := uint64(len(payload)) + return emitter.EmitRawEvent(flow.EventType(event.EventType.ID()), payload) +} +func (emitter *eventEmitter) EmitRawEvent(eventType flow.EventType, payload []byte) error { + defer emitter.tracer.StartExtensiveTracingChildSpan(trace.FVMEnvEmitEvent).End() + payloadSize := len(payload) + err := emitter.meter.MeterComputation(ComputationKindEmitEvent, uint(payloadSize)) + if err != nil { + return fmt.Errorf("emit event failed: %w", err) + } flowEvent := flow.Event{ - Type: flow.EventType(event.EventType.ID()), + Type: eventType, TransactionID: emitter.txID, TransactionIndex: emitter.txIndex, EventIndex: emitter.eventCollection.TotalEventCounter(), @@ -207,7 +211,7 @@ func (emitter *eventEmitter) EmitEvent(event cadence.Event) error { isServiceAccount := emitter.payer == emitter.chain.ServiceAddress() if emitter.ServiceEventCollectionEnabled { - ok, err := IsServiceEvent(event, emitter.chain.ChainID()) + ok, err := IsServiceEvent(eventType, emitter.chain.ChainID()) if err != nil { return fmt.Errorf("unable to check service event: %w", err) } @@ -215,7 +219,7 @@ func (emitter *eventEmitter) EmitEvent(event cadence.Event) error { eventEmitError := emitter.eventCollection.AppendServiceEvent( emitter.chain, flowEvent, - payloadSize) + uint64(payloadSize)) // skip limit if payer is service account // TODO skip only limit-related errors @@ -227,35 +231,14 @@ func (emitter *eventEmitter) EmitEvent(event cadence.Event) error { // as well. 
} - eventEmitError := emitter.eventCollection.AppendEvent(flowEvent, payloadSize) + eventEmitError := emitter.eventCollection.AppendEvent(flowEvent, uint64(payloadSize)) // skip limit if payer is service account if !isServiceAccount { return eventEmitError } return nil -} - -func (emitter *eventEmitter) EmitFlowEvent(etype flow.EventType, payload []byte) error { - defer emitter.tracer.StartExtensiveTracingChildSpan( - trace.FVMEnvEmitEvent).End() - - err := emitter.meter.MeterComputation(ComputationKindEmitEvent, 1) - if err != nil { - return fmt.Errorf("emit flow event failed: %w", err) - } - - eventSize := uint64(len(etype) + len(payload)) - - flowEvent := flow.Event{ - Type: etype, - TransactionID: emitter.txID, - TransactionIndex: emitter.txIndex, - EventIndex: emitter.eventCollection.TotalEventCounter(), - Payload: payload, - } - return emitter.eventCollection.AppendEvent(flowEvent, eventSize) } func (emitter *eventEmitter) Events() flow.EventsList { @@ -334,7 +317,7 @@ func (collection *EventCollection) TotalEventCounter() uint32 { // IsServiceEvent determines whether or not an emitted Cadence event is // considered a service event for the given chain. -func IsServiceEvent(event cadence.Event, chain flow.ChainID) (bool, error) { +func IsServiceEvent(eventType flow.EventType, chain flow.ChainID) (bool, error) { // retrieve the service event information for this chain events, err := systemcontracts.ServiceEventsForChain(chain) @@ -345,7 +328,6 @@ func IsServiceEvent(event cadence.Event, chain flow.ChainID) (bool, error) { err) } - eventType := flow.EventType(event.EventType.ID()) for _, serviceEvent := range events.All() { if serviceEvent.EventType() == eventType { return true, nil diff --git a/fvm/environment/event_emitter_test.go b/fvm/environment/event_emitter_test.go index 5057954680b..e681bfd6d98 100644 --- a/fvm/environment/event_emitter_test.go +++ b/fvm/environment/event_emitter_test.go @@ -27,7 +27,7 @@ func Test_IsServiceEvent(t *testing.T) { t.Run("correct", func(t *testing.T) { for _, event := range events.All() { - isServiceEvent, err := environment.IsServiceEvent(cadence.Event{ + event := cadence.Event{ EventType: &cadence.EventType{ Location: common.AddressLocation{ Address: common.MustBytesToAddress( @@ -35,14 +35,16 @@ func Test_IsServiceEvent(t *testing.T) { }, QualifiedIdentifier: event.QualifiedIdentifier(), }, - }, chain) + } + + isServiceEvent, err := environment.IsServiceEvent(flow.EventType(event.Type().ID()), chain) require.NoError(t, err) assert.True(t, isServiceEvent) } }) t.Run("wrong chain", func(t *testing.T) { - isServiceEvent, err := environment.IsServiceEvent(cadence.Event{ + event := cadence.Event{ EventType: &cadence.EventType{ Location: common.AddressLocation{ Address: common.MustBytesToAddress( @@ -50,13 +52,15 @@ func Test_IsServiceEvent(t *testing.T) { }, QualifiedIdentifier: events.EpochCommit.QualifiedIdentifier(), }, - }, chain) + } + + isServiceEvent, err := environment.IsServiceEvent(flow.EventType(event.Type().ID()), chain) require.NoError(t, err) assert.False(t, isServiceEvent) }) t.Run("wrong type", func(t *testing.T) { - isServiceEvent, err := environment.IsServiceEvent(cadence.Event{ + event := cadence.Event{ EventType: &cadence.EventType{ Location: common.AddressLocation{ Address: common.MustBytesToAddress( @@ -64,7 +68,9 @@ func Test_IsServiceEvent(t *testing.T) { }, QualifiedIdentifier: "SomeContract.SomeEvent", }, - }, chain) + } + + isServiceEvent, err := environment.IsServiceEvent(flow.EventType(event.Type().ID()), chain) 
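+ // note: IsServiceEvent now receives the fully-qualified event type (flow.EventType)
+ // instead of the cadence.Event value itself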
require.NoError(t, err) assert.False(t, isServiceEvent) }) @@ -150,19 +156,19 @@ func Test_EmitEvent_Limit(t *testing.T) { require.Error(t, err) }) - t.Run("emit flow event - exceeding limit", func(t *testing.T) { + t.Run("emit raw event - exceeding limit", func(t *testing.T) { flowEvent := flow.Event{ Type: "sometype", Payload: []byte{1, 2, 3, 4, 5}, } - eventSize := uint64(len(flowEvent.Type) + len(flowEvent.Payload)) + eventSize := uint64(len(flowEvent.Payload)) eventEmitter := createTestEventEmitterWithLimit( flow.Emulator, flow.Emulator.Chain().NewAddressGenerator().CurrentAddress(), eventSize-1) - err := eventEmitter.EmitFlowEvent(flowEvent.Type, flowEvent.Payload) + err := eventEmitter.EmitRawEvent(flowEvent.Type, flowEvent.Payload) require.Error(t, err) }) } diff --git a/fvm/environment/meter.go b/fvm/environment/meter.go index 757ec0ea8be..75250d1c4c7 100644 --- a/fvm/environment/meter.go +++ b/fvm/environment/meter.go @@ -51,6 +51,7 @@ const ( ComputationKindEVMGasUsage = 2037 ComputationKindRLPEncoding = 2038 ComputationKindRLPDecoding = 2039 + ComputationKindEncodeEvent = 2040 ) type Meter interface { diff --git a/fvm/environment/mock/environment.go b/fvm/environment/mock/environment.go index 11b9cda285c..bd5b606f5b4 100644 --- a/fvm/environment/mock/environment.go +++ b/fvm/environment/mock/environment.go @@ -440,8 +440,8 @@ func (_m *Environment) EmitEvent(_a0 cadence.Event) error { return r0 } -// EmitFlowEvent provides a mock function with given fields: etype, payload -func (_m *Environment) EmitFlowEvent(etype flow.EventType, payload []byte) error { +// EmitRawEvent provides a mock function with given fields: etype, payload +func (_m *Environment) EmitRawEvent(etype flow.EventType, payload []byte) error { ret := _m.Called(etype, payload) var r0 error diff --git a/fvm/environment/mock/event_emitter.go b/fvm/environment/mock/event_emitter.go index 018efa1f19b..ac83fe259a3 100644 --- a/fvm/environment/mock/event_emitter.go +++ b/fvm/environment/mock/event_emitter.go @@ -45,13 +45,13 @@ func (_m *EventEmitter) EmitEvent(event cadence.Event) error { return r0 } -// EmitFlowEvent provides a mock function with given fields: etype, payload -func (_m *EventEmitter) EmitFlowEvent(etype flow.EventType, payload []byte) error { - ret := _m.Called(etype, payload) +// EmitRawEvent provides a mock function with given fields: eventType, payload +func (_m *EventEmitter) EmitRawEvent(eventType flow.EventType, payload []byte) error { + ret := _m.Called(eventType, payload) var r0 error if rf, ok := ret.Get(0).(func(flow.EventType, []byte) error); ok { - r0 = rf(etype, payload) + r0 = rf(eventType, payload) } else { r0 = ret.Error(0) } diff --git a/fvm/errors/codes.go b/fvm/errors/codes.go index 3308b47fdd9..cdbc734bd3d 100644 --- a/fvm/errors/codes.go +++ b/fvm/errors/codes.go @@ -78,8 +78,8 @@ const ( ErrCodeComputationLimitExceededError ErrorCode = 1110 ErrCodeMemoryLimitExceededError ErrorCode = 1111 ErrCodeCouldNotDecodeExecutionParameterFromState ErrorCode = 1112 - ErrCodeScriptExecutionCancelledError ErrorCode = 1114 ErrCodeScriptExecutionTimedOutError ErrorCode = 1113 + ErrCodeScriptExecutionCancelledError ErrorCode = 1114 ErrCodeEventEncodingError ErrorCode = 1115 ErrCodeInvalidInternalStateAccessError ErrorCode = 1116 // 1117 was never deployed and is free to use diff --git a/fvm/evm/emulator/database/database.go b/fvm/evm/emulator/database/database.go index f47e32b7174..8b18e56e7bc 100644 --- a/fvm/evm/emulator/database/database.go +++ b/fvm/evm/emulator/database/database.go @@ -34,6 
+34,7 @@ type Database struct { flowEVMRootAddress flow.Address led atree.Ledger storage *atree.PersistentSlabStorage + baseStorage *atree.LedgerBaseStorage atreemap *atree.OrderedMap rootIDBytesToBeStored []byte // if is empty means we don't need to store anything // Ramtin: other database implementations for EVM uses a lock @@ -57,6 +58,7 @@ func NewDatabase(led atree.Ledger, flowEVMRootAddress flow.Address) (*Database, db := &Database{ led: led, + baseStorage: baseStorage, flowEVMRootAddress: flowEVMRootAddress, storage: storage, } @@ -237,7 +239,7 @@ func (db *Database) getRootHash() (gethCommon.Hash, error) { if len(data) == 0 { return gethTypes.EmptyRootHash, nil } - return gethCommon.Hash(data), nil + return gethCommon.BytesToHash(data), nil } // Commits the changes from atree into the underlying storage @@ -309,6 +311,19 @@ func (db *Database) Stat(property string) (string, error) { return "", types.ErrNotImplemented } +func (db *Database) BytesRetrieved() int { + return db.baseStorage.BytesRetrieved() +} + +func (db *Database) BytesStored() int { + return db.baseStorage.BytesStored() +} +func (db *Database) ResetReporter() { + db.baseStorage.ResetReporter() +} + +// Compact is not supported on a memory database, but there's no need either as +// a memory database doesn't waste space anyway. // Compact is a no op func (db *Database) Compact(start []byte, limit []byte) error { return nil @@ -357,6 +372,11 @@ func (b *batch) set(key []byte, value []byte, delete bool) error { return nil } +// DropCache drops the database read cache +func (db *Database) DropCache() { + db.storage.DropCache() +} + // ValueSize retrieves the amount of data queued up for writing. func (b *batch) ValueSize() int { return b.size diff --git a/fvm/evm/emulator/database/database_test.go b/fvm/evm/emulator/database/database_test.go index 62b7f9b63f3..a23e38d4295 100644 --- a/fvm/evm/emulator/database/database_test.go +++ b/fvm/evm/emulator/database/database_test.go @@ -24,7 +24,7 @@ func TestDatabase(t *testing.T) { value2 := []byte{9, 10, 11} t.Run("test basic database functionality", func(t *testing.T) { - testutils.RunWithTestBackend(t, func(backend types.Backend) { + testutils.RunWithTestBackend(t, func(backend *testutils.TestBackend) { testutils.RunWithTestFlowEVMRootAddress(t, backend, func(flowEVMRoot flow.Address) { db, err := database.NewDatabase(backend, flowEVMRoot) require.NoError(t, err) @@ -70,7 +70,7 @@ func TestDatabase(t *testing.T) { }) t.Run("test batch functionality", func(t *testing.T) { - testutils.RunWithTestBackend(t, func(backend types.Backend) { + testutils.RunWithTestBackend(t, func(backend *testutils.TestBackend) { testutils.RunWithTestFlowEVMRootAddress(t, backend, func(flowEVMRoot flow.Address) { db, err := database.NewDatabase(backend, flowEVMRoot) require.NoError(t, err) @@ -159,7 +159,7 @@ func TestDatabase(t *testing.T) { }) t.Run("test fatal error (not implemented methods)", func(t *testing.T) { - testutils.RunWithTestBackend(t, func(backend types.Backend) { + testutils.RunWithTestBackend(t, func(backend *testutils.TestBackend) { testutils.RunWithTestFlowEVMRootAddress(t, backend, func(flowEVMRoot flow.Address) { db, err := database.NewDatabase(backend, flowEVMRoot) require.NoError(t, err) diff --git a/fvm/evm/emulator/emulator.go b/fvm/evm/emulator/emulator.go index 9371562ef63..eef720912f7 100644 --- a/fvm/evm/emulator/emulator.go +++ b/fvm/evm/emulator/emulator.go @@ -148,7 +148,8 @@ func (bl *BlockView) newProcedure() (*procedure, error) { cfg.ChainConfig, cfg.EVMConfig, 
), - state: execState, + state: execState, + database: bl.database, }, nil } @@ -159,9 +160,10 @@ func (bl *BlockView) commit(rootHash gethCommon.Hash) error { } type procedure struct { - config *Config - evm *gethVM.EVM - state *gethState.StateDB + config *Config + evm *gethVM.EVM + state *gethState.StateDB + database types.Database } // commit commits the changes to the state. @@ -184,6 +186,12 @@ func (proc *procedure) commit() (gethCommon.Hash, error) { if err != nil { return gethTypes.EmptyRootHash, handleCommitError(err) } + + // // remove the read registers (no history tracking) + // err = proc.database.DeleteAndCleanReadKey() + // if err != nil { + // return gethTypes.EmptyRootHash, types.NewFatalError(err) + // } return newRoot, nil } @@ -235,8 +243,11 @@ func (proc *procedure) withdrawFrom(address types.Address, amount *big.Int) (*ty // while this method is only called from bridged accounts // it might be the case that someone creates a bridged account // and never transfer tokens to and call for withdraw + // TODO: we might revisit this apporach and + // return res, types.ErrAccountDoesNotExist + // instead if !proc.state.Exist(addr) { - return res, types.ErrAccountDoesNotExist + proc.state.CreateAccount(addr) } // check the source account balance diff --git a/fvm/evm/emulator/emulator_test.go b/fvm/evm/emulator/emulator_test.go index 2169b96e630..caaf5853cac 100644 --- a/fvm/evm/emulator/emulator_test.go +++ b/fvm/evm/emulator/emulator_test.go @@ -22,7 +22,7 @@ var blockNumber = big.NewInt(10) var defaultCtx = types.NewDefaultBlockContext(blockNumber.Uint64()) func RunWithTestDB(t testing.TB, f func(types.Database)) { - testutils.RunWithTestBackend(t, func(backend types.Backend) { + testutils.RunWithTestBackend(t, func(backend *testutils.TestBackend) { testutils.RunWithTestFlowEVMRootAddress(t, backend, func(flowEVMRoot flow.Address) { db, err := database.NewDatabase(backend, flowEVMRoot) require.NoError(t, err) @@ -62,7 +62,7 @@ func TestNativeTokenBridging(t *testing.T) { }) }) }) - t.Run("mint tokens withdraw", func(t *testing.T) { + t.Run("tokens withdraw", func(t *testing.T) { amount := big.NewInt(1000) RunWithNewEmulator(t, db, func(env *emulator.Emulator) { RunWithNewReadOnlyBlockView(t, env, func(blk types.ReadOnlyBlockView) { @@ -93,7 +93,7 @@ func TestNativeTokenBridging(t *testing.T) { func TestContractInteraction(t *testing.T) { RunWithTestDB(t, func(db types.Database) { - testContract := testutils.GetTestContract(t) + testContract := testutils.GetStorageTestContract(t) testAccount := types.NewAddressFromString("test") amount := big.NewInt(0).Mul(big.NewInt(1337), big.NewInt(gethParams.Ether)) @@ -148,7 +148,7 @@ func TestContractInteraction(t *testing.T) { types.NewContractCall( testAccount, contractAddr, - testContract.MakeStoreCallData(t, num), + testContract.MakeCallData(t, "store", num), 1_000_000, big.NewInt(0), // this should be zero because the contract doesn't have receiver ), @@ -164,7 +164,7 @@ func TestContractInteraction(t *testing.T) { types.NewContractCall( testAccount, contractAddr, - testContract.MakeRetrieveCallData(t), + testContract.MakeCallData(t, "retrieve"), 1_000_000, big.NewInt(0), // this should be zero because the contract doesn't have receiver ), @@ -183,7 +183,7 @@ func TestContractInteraction(t *testing.T) { types.NewContractCall( testAccount, contractAddr, - testContract.MakeBlockNumberCallData(t), + testContract.MakeCallData(t, "blockNumber"), 1_000_000, big.NewInt(0), // this should be zero because the contract doesn't have receiver 
), @@ -369,3 +369,30 @@ func TestDatabaseErrorHandling(t *testing.T) { }) }) } + +func TestStorageNoSideEffect(t *testing.T) { + t.Skip("we need to fix this issue ") + + testutils.RunWithTestBackend(t, func(backend *testutils.TestBackend) { + testutils.RunWithTestFlowEVMRootAddress(t, backend, func(flowEVMRoot flow.Address) { + db, err := database.NewDatabase(backend, flowEVMRoot) + require.NoError(t, err) + + em := emulator.NewEmulator(db) + testAccount := types.NewAddressFromString("test") + + amount := big.NewInt(100) + RunWithNewBlockView(t, em, func(blk types.BlockView) { + _, err = blk.DirectCall(types.NewDepositCall(testAccount, amount)) + require.NoError(t, err) + }) + + orgSize := backend.TotalStorageSize() + RunWithNewBlockView(t, em, func(blk types.BlockView) { + _, err = blk.DirectCall(types.NewDepositCall(testAccount, amount)) + require.NoError(t, err) + }) + require.Equal(t, orgSize, backend.TotalStorageSize()) + }) + }) +} diff --git a/fvm/evm/evm.go b/fvm/evm/evm.go new file mode 100644 index 00000000000..2d3e1288ae2 --- /dev/null +++ b/fvm/evm/evm.go @@ -0,0 +1,55 @@ +package evm + +import ( + "github.com/onflow/cadence/runtime" + "github.com/onflow/cadence/runtime/common" + + "github.com/onflow/flow-go/fvm/environment" + evm "github.com/onflow/flow-go/fvm/evm/emulator" + "github.com/onflow/flow-go/fvm/evm/emulator/database" + "github.com/onflow/flow-go/fvm/evm/handler" + "github.com/onflow/flow-go/fvm/evm/stdlib" + "github.com/onflow/flow-go/fvm/evm/types" + "github.com/onflow/flow-go/model/flow" +) + +func RootAccountAddress(chainID flow.ChainID) (flow.Address, error) { + return chainID.Chain().AddressAtIndex(environment.EVMAccountIndex) +} + +func SetupEnvironment( + chainID flow.ChainID, + backend types.Backend, + env runtime.Environment, + service flow.Address, + flowToken flow.Address, +) error { + // TODO: setup proper root address based on chainID + evmRootAddress, err := RootAccountAddress(chainID) + if err != nil { + return err + } + + db, err := database.NewDatabase(backend, evmRootAddress) + if err != nil { + return err + } + + em := evm.NewEmulator(db) + + bs, err := handler.NewBlockStore(backend, evmRootAddress) + if err != nil { + return err + } + + aa, err := handler.NewAddressAllocator(backend, evmRootAddress) + if err != nil { + return err + } + + contractHandler := handler.NewContractHandler(common.Address(flowToken), bs, aa, backend, em) + + stdlib.SetupEnvironment(env, contractHandler, service) + + return nil +} diff --git a/fvm/evm/evm_test.go b/fvm/evm/evm_test.go new file mode 100644 index 00000000000..9c46312ae58 --- /dev/null +++ b/fvm/evm/evm_test.go @@ -0,0 +1,281 @@ +package evm_test + +import ( + "fmt" + "math/big" + "testing" + + "github.com/onflow/cadence" + "github.com/onflow/cadence/encoding/json" + "github.com/stretchr/testify/require" + + "github.com/onflow/flow-go/fvm" + "github.com/onflow/flow-go/fvm/evm/stdlib" + "github.com/onflow/flow-go/fvm/evm/testutils" + . 
"github.com/onflow/flow-go/fvm/evm/testutils" + "github.com/onflow/flow-go/fvm/storage/snapshot" + "github.com/onflow/flow-go/model/flow" + "github.com/onflow/flow-go/utils/unittest" +) + +func TestEVMRun(t *testing.T) { + + t.Parallel() + + t.Run("testing EVM.run (happy case)", func(t *testing.T) { + RunWithTestBackend(t, func(backend *testutils.TestBackend) { + RunWithTestFlowEVMRootAddress(t, backend, func(rootAddr flow.Address) { + tc := GetStorageTestContract(t) + RunWithDeployedContract(t, tc, backend, rootAddr, func(testContract *TestContract) { + RunWithEOATestAccount(t, backend, rootAddr, func(testAccount *EOATestAccount) { + num := int64(12) + chain := flow.Emulator.Chain() + + RunWithNewTestVM(t, chain, func(ctx fvm.Context, vm fvm.VM, snapshot snapshot.SnapshotTree) { + code := []byte(fmt.Sprintf( + ` + import EVM from %s + + access(all) + fun main(tx: [UInt8], coinbaseBytes: [UInt8; 20]) { + let coinbase = EVM.EVMAddress(bytes: coinbaseBytes) + EVM.run(tx: tx, coinbase: coinbase) + } + `, + chain.ServiceAddress().HexWithPrefix(), + )) + + gasLimit := uint64(100_000) + + txBytes := testAccount.PrepareSignAndEncodeTx(t, + testContract.DeployedAt.ToCommon(), + testContract.MakeCallData(t, "store", big.NewInt(num)), + big.NewInt(0), + gasLimit, + big.NewInt(0), + ) + + tx := cadence.NewArray( + ConvertToCadence(txBytes), + ).WithType(stdlib.EVMTransactionBytesCadenceType) + + coinbase := cadence.NewArray( + ConvertToCadence(testAccount.Address().Bytes()), + ).WithType(stdlib.EVMAddressBytesCadenceType) + + script := fvm.Script(code).WithArguments( + json.MustEncode(tx), + json.MustEncode(coinbase), + ) + + _, output, err := vm.Run( + ctx, + script, + snapshot) + require.NoError(t, err) + require.NoError(t, output.Err) + }) + }) + }) + }) + }) + }) +} + +func RunWithNewTestVM(t *testing.T, chain flow.Chain, f func(fvm.Context, fvm.VM, snapshot.SnapshotTree)) { + opts := []fvm.Option{ + fvm.WithChain(chain), + fvm.WithAuthorizationChecksEnabled(false), + fvm.WithSequenceNumberCheckAndIncrementEnabled(false), + } + ctx := fvm.NewContext(opts...) + + vm := fvm.NewVirtualMachine() + snapshotTree := snapshot.NewSnapshotTree(nil) + + baseBootstrapOpts := []fvm.BootstrapProcedureOption{ + fvm.WithInitialTokenSupply(unittest.GenesisTokenSupply), + fvm.WithSetupEVMEnabled(true), + } + + executionSnapshot, _, err := vm.Run( + ctx, + fvm.Bootstrap(unittest.ServiceAccountPublicKey, baseBootstrapOpts...), + snapshotTree) + require.NoError(t, err) + + snapshotTree = snapshotTree.Append(executionSnapshot) + + f(fvm.NewContextFromParent(ctx, fvm.WithEVMEnabled(true)), vm, snapshotTree) +} + +func TestEVMAddressDeposit(t *testing.T) { + + t.Parallel() + + RunWithTestBackend(t, func(backend *testutils.TestBackend) { + RunWithTestFlowEVMRootAddress(t, backend, func(rootAddr flow.Address) { + tc := GetStorageTestContract(t) + RunWithDeployedContract(t, tc, backend, rootAddr, func(testContract *TestContract) { + RunWithEOATestAccount(t, backend, rootAddr, func(testAccount *EOATestAccount) { + chain := flow.Emulator.Chain() + RunWithNewTestVM(t, chain, func(ctx fvm.Context, vm fvm.VM, snapshot snapshot.SnapshotTree) { + + code := []byte(fmt.Sprintf( + ` + import EVM from %[1]s + import FlowToken from %[2]s + + access(all) + fun main() { + let admin = getAuthAccount(%[1]s) + .borrow<&FlowToken.Administrator>(from: /storage/flowTokenAdmin)! 
+ let minter <- admin.createNewMinter(allowedAmount: 1.23) + let vault <- minter.mintTokens(amount: 1.23) + destroy minter + + let address = EVM.EVMAddress( + bytes: [2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + ) + address.deposit(from: <-vault) + } + `, + chain.ServiceAddress().HexWithPrefix(), + fvm.FlowTokenAddress(chain).HexWithPrefix(), + )) + + script := fvm.Script(code) + + executionSnapshot, output, err := vm.Run( + ctx, + script, + snapshot) + require.NoError(t, err) + require.NoError(t, output.Err) + + // TODO: + _ = executionSnapshot + }) + }) + }) + }) + }) +} + +func TestBridgedAccountWithdraw(t *testing.T) { + + t.Parallel() + + RunWithTestBackend(t, func(backend *testutils.TestBackend) { + RunWithTestFlowEVMRootAddress(t, backend, func(rootAddr flow.Address) { + tc := GetStorageTestContract(t) + RunWithDeployedContract(t, tc, backend, rootAddr, func(testContract *TestContract) { + RunWithEOATestAccount(t, backend, rootAddr, func(testAccount *EOATestAccount) { + chain := flow.Emulator.Chain() + RunWithNewTestVM(t, chain, func(ctx fvm.Context, vm fvm.VM, snapshot snapshot.SnapshotTree) { + + code := []byte(fmt.Sprintf( + ` + import EVM from %[1]s + import FlowToken from %[2]s + + access(all) + fun main(): UFix64 { + let admin = getAuthAccount(%[1]s) + .borrow<&FlowToken.Administrator>(from: /storage/flowTokenAdmin)! + let minter <- admin.createNewMinter(allowedAmount: 2.34) + let vault <- minter.mintTokens(amount: 2.34) + destroy minter + + let bridgedAccount <- EVM.createBridgedAccount() + bridgedAccount.address().deposit(from: <-vault) + + let vault2 <- bridgedAccount.withdraw(balance: EVM.Balance(flow: 1.23)) + let balance = vault2.balance + destroy bridgedAccount + destroy vault2 + + return balance + } + `, + chain.ServiceAddress().HexWithPrefix(), + fvm.FlowTokenAddress(chain).HexWithPrefix(), + )) + + script := fvm.Script(code) + + executionSnapshot, output, err := vm.Run( + ctx, + script, + snapshot) + require.NoError(t, err) + require.NoError(t, output.Err) + + // TODO: + _ = executionSnapshot + }) + }) + }) + }) + }) +} + +// TODO: provide proper contract code +func TestBridgedAccountDeploy(t *testing.T) { + + t.Parallel() + + RunWithTestBackend(t, func(backend *testutils.TestBackend) { + RunWithTestFlowEVMRootAddress(t, backend, func(rootAddr flow.Address) { + tc := GetStorageTestContract(t) + RunWithDeployedContract(t, tc, backend, rootAddr, func(testContract *TestContract) { + RunWithEOATestAccount(t, backend, rootAddr, func(testAccount *EOATestAccount) { + chain := flow.Emulator.Chain() + RunWithNewTestVM(t, chain, func(ctx fvm.Context, vm fvm.VM, snapshot snapshot.SnapshotTree) { + + code := []byte(fmt.Sprintf( + ` + import EVM from %[1]s + import FlowToken from %[2]s + + access(all) + fun main(): [UInt8; 20] { + let admin = getAuthAccount(%[1]s) + .borrow<&FlowToken.Administrator>(from: /storage/flowTokenAdmin)! 
+ let minter <- admin.createNewMinter(allowedAmount: 2.34) + let vault <- minter.mintTokens(amount: 2.34) + destroy minter + + let bridgedAccount <- EVM.createBridgedAccount() + bridgedAccount.address().deposit(from: <-vault) + + let address = bridgedAccount.deploy( + code: [], + gasLimit: 53000, + value: EVM.Balance(flow: 1.23) + ) + destroy bridgedAccount + return address.bytes + } + `, + chain.ServiceAddress().HexWithPrefix(), + fvm.FlowTokenAddress(chain).HexWithPrefix(), + )) + + script := fvm.Script(code) + + executionSnapshot, output, err := vm.Run( + ctx, + script, + snapshot) + require.NoError(t, err) + require.NoError(t, output.Err) + + // TODO: + _ = executionSnapshot + }) + }) + }) + }) + }) +} diff --git a/fvm/evm/handler/addressAllocator_test.go b/fvm/evm/handler/addressAllocator_test.go index 6ff534ff221..ab8eb0de2b4 100644 --- a/fvm/evm/handler/addressAllocator_test.go +++ b/fvm/evm/handler/addressAllocator_test.go @@ -14,7 +14,7 @@ import ( func TestAddressAllocator(t *testing.T) { - testutils.RunWithTestBackend(t, func(backend types.Backend) { + testutils.RunWithTestBackend(t, func(backend *testutils.TestBackend) { testutils.RunWithTestFlowEVMRootAddress(t, backend, func(root flow.Address) { aa, err := handler.NewAddressAllocator(backend, root) require.NoError(t, err) diff --git a/fvm/evm/handler/blockstore_test.go b/fvm/evm/handler/blockstore_test.go index 77f80d947ff..77720b143a2 100644 --- a/fvm/evm/handler/blockstore_test.go +++ b/fvm/evm/handler/blockstore_test.go @@ -13,7 +13,7 @@ import ( func TestBlockStore(t *testing.T) { - testutils.RunWithTestBackend(t, func(backend types.Backend) { + testutils.RunWithTestBackend(t, func(backend *testutils.TestBackend) { testutils.RunWithTestFlowEVMRootAddress(t, backend, func(root flow.Address) { bs, err := handler.NewBlockStore(backend, root) require.NoError(t, err) diff --git a/fvm/evm/handler/handler.go b/fvm/evm/handler/handler.go index b11ef08b7ee..83f2633985b 100644 --- a/fvm/evm/handler/handler.go +++ b/fvm/evm/handler/handler.go @@ -5,6 +5,7 @@ import ( gethTypes "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/rlp" + "github.com/onflow/cadence/runtime/common" "github.com/onflow/flow-go/fvm/environment" "github.com/onflow/flow-go/fvm/errors" @@ -22,21 +23,28 @@ import ( // in the future we might benefit from a view style of access to db passed as // a param to the emulator. 
type ContractHandler struct { + flowTokenAddress common.Address blockstore types.BlockStore + addressAllocator types.AddressAllocator backend types.Backend emulator types.Emulator - addressAllocator types.AddressAllocator +} + +func (h *ContractHandler) FlowTokenAddress() common.Address { + return h.flowTokenAddress } var _ types.ContractHandler = &ContractHandler{} func NewContractHandler( + flowTokenAddress common.Address, blockstore types.BlockStore, addressAllocator types.AddressAllocator, backend types.Backend, emulator types.Emulator, ) *ContractHandler { return &ContractHandler{ + flowTokenAddress: flowTokenAddress, blockstore: blockstore, addressAllocator: addressAllocator, backend: backend, @@ -58,7 +66,7 @@ func (h *ContractHandler) AccountByAddress(addr types.Address, isAuthorized bool } // LastExecutedBlock returns the last executed block -func (h ContractHandler) LastExecutedBlock() *types.Block { +func (h *ContractHandler) LastExecutedBlock() *types.Block { block, err := h.blockstore.LatestBlock() handleError(err) return block @@ -66,7 +74,7 @@ func (h ContractHandler) LastExecutedBlock() *types.Block { // Run runs an rlpencoded evm transaction and // collects the gas fees and pay it to the coinbase address provided. -func (h ContractHandler) Run(rlpEncodedTx []byte, coinbase types.Address) { +func (h *ContractHandler) Run(rlpEncodedTx []byte, coinbase types.Address) { // step 1 - transaction decoding encodedLen := uint(len(rlpEncodedTx)) err := h.backend.MeterComputation(environment.ComputationKindRLPDecoding, encodedLen) @@ -113,14 +121,14 @@ func (h ContractHandler) Run(rlpEncodedTx []byte, coinbase types.Address) { handleError(err) } -func (h ContractHandler) checkGasLimit(limit types.GasLimit) { +func (h *ContractHandler) checkGasLimit(limit types.GasLimit) { // check gas limit against what has been left on the transaction side if !h.backend.ComputationAvailable(environment.ComputationKindEVMGasUsage, uint(limit)) { handleError(types.ErrInsufficientComputation) } } -func (h ContractHandler) meterGasUsage(res *types.Result) { +func (h *ContractHandler) meterGasUsage(res *types.Result) { if res != nil { err := h.backend.MeterComputation(environment.ComputationKindEVMGasUsage, uint(res.GasConsumed)) handleError(err) @@ -131,7 +139,7 @@ func (h *ContractHandler) emitEvent(event *types.Event) { // TODO add extra metering for rlp encoding encoded, err := event.Payload.Encode() handleError(err) - err = h.backend.EmitFlowEvent(event.Etype, encoded) + err = h.backend.EmitRawEvent(event.Etype, encoded) handleError(err) } diff --git a/fvm/evm/handler/handler_benchmark_test.go b/fvm/evm/handler/handler_benchmark_test.go new file mode 100644 index 00000000000..73f0f0ed59d --- /dev/null +++ b/fvm/evm/handler/handler_benchmark_test.go @@ -0,0 +1,82 @@ +package handler_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/onflow/flow-go/fvm/evm/testutils" + "github.com/onflow/flow-go/fvm/evm/types" + "github.com/onflow/flow-go/model/flow" +) + +func BenchmarkStorage(b *testing.B) { benchmarkStorageGrowth(b, 100, 100) } + +// benchmark +func benchmarkStorageGrowth(b *testing.B, accountCount, setupKittyCount int) { + testutils.RunWithTestBackend(b, func(backend *testutils.TestBackend) { + testutils.RunWithTestFlowEVMRootAddress(b, backend, func(rootAddr flow.Address) { + testutils.RunWithDeployedContract(b, + testutils.GetDummyKittyTestContract(b), + backend, + rootAddr, + func(tc *testutils.TestContract) { + db, handler := SetupHandler(b, backend, 
rootAddr) + numOfAccounts := 100000 + accounts := make([]types.Account, numOfAccounts) + // setup several of accounts + // note that trie growth is the function of number of accounts + for i := 0; i < numOfAccounts; i++ { + account := handler.AccountByAddress(handler.AllocateAddress(), true) + account.Deposit(types.NewFlowTokenVault(types.Balance(100))) + accounts[i] = account + } + backend.DropEvents() + // mint kitties + for i := 0; i < setupKittyCount; i++ { + account := accounts[i%accountCount] + matronId := testutils.RandomBigInt(1000) + sireId := testutils.RandomBigInt(1000) + generation := testutils.RandomBigInt(1000) + genes := testutils.RandomBigInt(1000) + require.NotNil(b, account) + account.Call( + tc.DeployedAt, + tc.MakeCallData(b, + "CreateKitty", + matronId, + sireId, + generation, + genes, + ), + 300_000_000, + types.Balance(0), + ) + require.Equal(b, 2, len(backend.Events())) + backend.DropEvents() // this would make things lighter + } + + // measure the impact of mint after the setup phase + db.ResetReporter() + db.DropCache() + + accounts[0].Call( + tc.DeployedAt, + tc.MakeCallData(b, + "CreateKitty", + testutils.RandomBigInt(1000), + testutils.RandomBigInt(1000), + testutils.RandomBigInt(1000), + testutils.RandomBigInt(1000), + ), + 300_000_000, + types.Balance(0), + ) + + b.ReportMetric(float64(db.BytesRetrieved()), "bytes_read") + b.ReportMetric(float64(db.BytesStored()), "bytes_written") + b.ReportMetric(float64(backend.TotalStorageSize()), "total_storage_size") + }) + }) + }) +} diff --git a/fvm/evm/handler/handler_test.go b/fvm/evm/handler/handler_test.go index 29fae749a74..64b608bf425 100644 --- a/fvm/evm/handler/handler_test.go +++ b/fvm/evm/handler/handler_test.go @@ -11,9 +11,11 @@ import ( gethTypes "github.com/ethereum/go-ethereum/core/types" gethParams "github.com/ethereum/go-ethereum/params" "github.com/ethereum/go-ethereum/rlp" + "github.com/onflow/cadence/runtime/common" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/onflow/flow-go/fvm" "github.com/onflow/flow-go/fvm/errors" "github.com/onflow/flow-go/fvm/evm/emulator" "github.com/onflow/flow-go/fvm/evm/emulator/database" @@ -25,11 +27,15 @@ import ( // TODO add test for fatal errors +var flowTokenAddress = common.Address(fvm.FlowTokenAddress(flow.Emulator.Chain())) + func TestHandler_TransactionRun(t *testing.T) { t.Parallel() t.Run("test - transaction run (happy case)", func(t *testing.T) { - testutils.RunWithTestBackend(t, func(backend types.Backend) { + t.Parallel() + + testutils.RunWithTestBackend(t, func(backend *testutils.TestBackend) { testutils.RunWithTestFlowEVMRootAddress(t, backend, func(rootAddr flow.Address) { testutils.RunWithEOATestAccount(t, backend, rootAddr, func(eoa *testutils.EOATestAccount) { @@ -55,8 +61,7 @@ func TestHandler_TransactionRun(t *testing.T) { return result, nil }, } - - handler := handler.NewContractHandler(bs, aa, backend, em) + handler := handler.NewContractHandler(flowTokenAddress, bs, aa, backend, em) coinbase := types.NewAddress(gethCommon.Address{}) @@ -104,7 +109,9 @@ func TestHandler_TransactionRun(t *testing.T) { }) t.Run("test - transaction run (unhappy cases)", func(t *testing.T) { - testutils.RunWithTestBackend(t, func(backend types.Backend) { + t.Parallel() + + testutils.RunWithTestBackend(t, func(backend *testutils.TestBackend) { testutils.RunWithTestFlowEVMRootAddress(t, backend, func(rootAddr flow.Address) { testutils.RunWithEOATestAccount(t, backend, rootAddr, func(eoa *testutils.EOATestAccount) { @@ -119,7 
+126,7 @@ func TestHandler_TransactionRun(t *testing.T) { return &types.Result{}, types.NewEVMExecutionError(fmt.Errorf("some sort of error")) }, } - handler := handler.NewContractHandler(bs, aa, backend, em) + handler := handler.NewContractHandler(flowTokenAddress, bs, aa, backend, em) coinbase := types.NewAddress(gethCommon.Address{}) @@ -164,21 +171,11 @@ func TestHandler_TransactionRun(t *testing.T) { }) t.Run("test running transaction (with integrated emulator)", func(t *testing.T) { - testutils.RunWithTestBackend(t, func(backend types.Backend) { - testutils.RunWithTestFlowEVMRootAddress(t, backend, func(rootAddr flow.Address) { - - bs, err := handler.NewBlockStore(backend, rootAddr) - require.NoError(t, err) + t.Parallel() - aa, err := handler.NewAddressAllocator(backend, rootAddr) - require.NoError(t, err) - - db, err := database.NewDatabase(backend, rootAddr) - require.NoError(t, err) - - emulator := emulator.NewEmulator(db) - - handler := handler.NewContractHandler(bs, aa, backend, emulator) + testutils.RunWithTestBackend(t, func(backend *testutils.TestBackend) { + testutils.RunWithTestFlowEVMRootAddress(t, backend, func(rootAddr flow.Address) { + _, handler := SetupHandler(t, backend, rootAddr) eoa := testutils.GetTestEOAAccount(t, testutils.EOATestAccount1KeyHex) @@ -230,19 +227,12 @@ func TestHandler_OpsWithoutEmulator(t *testing.T) { t.Parallel() t.Run("test last executed block call", func(t *testing.T) { - testutils.RunWithTestBackend(t, func(backend types.Backend) { - testutils.RunWithTestFlowEVMRootAddress(t, backend, func(rootAddr flow.Address) { - bs, err := handler.NewBlockStore(backend, rootAddr) - require.NoError(t, err) - - aa, err := handler.NewAddressAllocator(backend, rootAddr) - require.NoError(t, err) + t.Parallel() - db, err := database.NewDatabase(backend, testutils.TestFlowEVMRootAddress) - require.NoError(t, err) - emulator := emulator.NewEmulator(db) + testutils.RunWithTestBackend(t, func(backend *testutils.TestBackend) { + testutils.RunWithTestFlowEVMRootAddress(t, backend, func(rootAddr flow.Address) { + _, handler := SetupHandler(t, backend, rootAddr) - handler := handler.NewContractHandler(bs, aa, backend, emulator) // test call last executed block without initialization b := handler.LastExecutedBlock() require.Equal(t, types.GenesisBlock, b) @@ -262,7 +252,9 @@ func TestHandler_OpsWithoutEmulator(t *testing.T) { }) t.Run("test address allocation", func(t *testing.T) { - testutils.RunWithTestBackend(t, func(backend types.Backend) { + t.Parallel() + + testutils.RunWithTestBackend(t, func(backend *testutils.TestBackend) { testutils.RunWithTestFlowEVMRootAddress(t, backend, func(rootAddr flow.Address) { blockchain, err := handler.NewBlockStore(backend, rootAddr) require.NoError(t, err) @@ -270,7 +262,8 @@ func TestHandler_OpsWithoutEmulator(t *testing.T) { aa, err := handler.NewAddressAllocator(backend, rootAddr) require.NoError(t, err) - handler := handler.NewContractHandler(blockchain, aa, backend, nil) + handler := handler.NewContractHandler(flowTokenAddress, blockchain, aa, backend, nil) + foa := handler.AllocateAddress() require.NotNil(t, foa) @@ -284,20 +277,12 @@ func TestHandler_OpsWithoutEmulator(t *testing.T) { func TestHandler_BridgedAccount(t *testing.T) { t.Run("test deposit/withdraw (with integrated emulator)", func(t *testing.T) { - testutils.RunWithTestBackend(t, func(backend types.Backend) { - testutils.RunWithTestFlowEVMRootAddress(t, backend, func(rootAddr flow.Address) { - bs, err := handler.NewBlockStore(backend, rootAddr) - 
require.NoError(t, err) - - aa, err := handler.NewAddressAllocator(backend, rootAddr) - require.NoError(t, err) - - db, err := database.NewDatabase(backend, rootAddr) - require.NoError(t, err) + t.Parallel() - emulator := emulator.NewEmulator(db) + testutils.RunWithTestBackend(t, func(backend *testutils.TestBackend) { + testutils.RunWithTestFlowEVMRootAddress(t, backend, func(rootAddr flow.Address) { + _, handler := SetupHandler(t, backend, rootAddr) - handler := handler.NewContractHandler(bs, aa, backend, emulator) foa := handler.AccountByAddress(handler.AllocateAddress(), true) require.NotNil(t, foa) @@ -360,7 +345,9 @@ func TestHandler_BridgedAccount(t *testing.T) { }) t.Run("test withdraw (unhappy case)", func(t *testing.T) { - testutils.RunWithTestBackend(t, func(backend types.Backend) { + t.Parallel() + + testutils.RunWithTestBackend(t, func(backend *testutils.TestBackend) { testutils.RunWithTestFlowEVMRootAddress(t, backend, func(rootAddr flow.Address) { testutils.RunWithEOATestAccount(t, backend, rootAddr, func(eoa *testutils.EOATestAccount) { bs, err := handler.NewBlockStore(backend, rootAddr) @@ -372,7 +359,8 @@ func TestHandler_BridgedAccount(t *testing.T) { // Withdraw calls are only possible within FOA accounts assertPanic(t, types.IsAUnAuthroizedMethodCallError, func() { em := &testutils.TestEmulator{} - handler := handler.NewContractHandler(bs, aa, backend, em) + + handler := handler.NewContractHandler(flowTokenAddress, bs, aa, backend, em) account := handler.AccountByAddress(testutils.RandomAddress(t), false) account.Withdraw(types.Balance(1)) @@ -385,8 +373,10 @@ func TestHandler_BridgedAccount(t *testing.T) { return &types.Result{}, types.NewEVMExecutionError(fmt.Errorf("some sort of error")) }, } - handler := handler.NewContractHandler(bs, aa, backend, em) + + handler := handler.NewContractHandler(flowTokenAddress, bs, aa, backend, em) account := handler.AccountByAddress(testutils.RandomAddress(t), true) + account.Withdraw(types.Balance(1)) }) @@ -397,8 +387,10 @@ func TestHandler_BridgedAccount(t *testing.T) { return &types.Result{}, types.NewEVMExecutionError(fmt.Errorf("some sort of error")) }, } - handler := handler.NewContractHandler(bs, aa, backend, em) + + handler := handler.NewContractHandler(flowTokenAddress, bs, aa, backend, em) account := handler.AccountByAddress(testutils.RandomAddress(t), true) + account.Withdraw(types.Balance(0)) }) @@ -409,8 +401,10 @@ func TestHandler_BridgedAccount(t *testing.T) { return &types.Result{}, types.NewFatalError(fmt.Errorf("some sort of fatal error")) }, } - handler := handler.NewContractHandler(bs, aa, backend, em) + + handler := handler.NewContractHandler(flowTokenAddress, bs, aa, backend, em) account := handler.AccountByAddress(testutils.RandomAddress(t), true) + account.Withdraw(types.Balance(0)) }) }) @@ -419,7 +413,9 @@ func TestHandler_BridgedAccount(t *testing.T) { }) t.Run("test deposit (unhappy case)", func(t *testing.T) { - testutils.RunWithTestBackend(t, func(backend types.Backend) { + t.Parallel() + + testutils.RunWithTestBackend(t, func(backend *testutils.TestBackend) { testutils.RunWithTestFlowEVMRootAddress(t, backend, func(rootAddr flow.Address) { testutils.RunWithEOATestAccount(t, backend, rootAddr, func(eoa *testutils.EOATestAccount) { bs, err := handler.NewBlockStore(backend, rootAddr) @@ -435,8 +431,10 @@ func TestHandler_BridgedAccount(t *testing.T) { return &types.Result{}, types.NewEVMExecutionError(fmt.Errorf("some sort of error")) }, } - handler := handler.NewContractHandler(bs, aa, backend, em) + + 
handler := handler.NewContractHandler(flowTokenAddress, bs, aa, backend, em) account := handler.AccountByAddress(testutils.RandomAddress(t), true) + account.Deposit(types.NewFlowTokenVault(1)) }) @@ -447,8 +445,10 @@ func TestHandler_BridgedAccount(t *testing.T) { return &types.Result{}, types.NewFatalError(fmt.Errorf("some sort of fatal error")) }, } - handler := handler.NewContractHandler(bs, aa, backend, em) + + handler := handler.NewContractHandler(flowTokenAddress, bs, aa, backend, em) account := handler.AccountByAddress(testutils.RandomAddress(t), true) + account.Deposit(types.NewFlowTokenVault(1)) }) }) @@ -457,21 +457,13 @@ func TestHandler_BridgedAccount(t *testing.T) { }) t.Run("test deploy/call (with integrated emulator)", func(t *testing.T) { + t.Parallel() + // TODO update this test with events, gas metering, etc - testutils.RunWithTestBackend(t, func(backend types.Backend) { + testutils.RunWithTestBackend(t, func(backend *testutils.TestBackend) { testutils.RunWithTestFlowEVMRootAddress(t, backend, func(rootAddr flow.Address) { - bs, err := handler.NewBlockStore(backend, rootAddr) - require.NoError(t, err) - - aa, err := handler.NewAddressAllocator(backend, rootAddr) - require.NoError(t, err) - - db, err := database.NewDatabase(backend, rootAddr) - require.NoError(t, err) + _, handler := SetupHandler(t, backend, rootAddr) - emulator := emulator.NewEmulator(db) - - handler := handler.NewContractHandler(bs, aa, backend, emulator) foa := handler.AccountByAddress(handler.AllocateAddress(), true) require.NotNil(t, foa) @@ -481,7 +473,7 @@ func TestHandler_BridgedAccount(t *testing.T) { vault := types.NewFlowTokenVault(orgBalance) foa.Deposit(vault) - testContract := testutils.GetTestContract(t) + testContract := testutils.GetStorageTestContract(t) addr := foa.Deploy(testContract.ByteCode, math.MaxUint64, types.Balance(0)) require.NotNil(t, addr) @@ -489,13 +481,13 @@ func TestHandler_BridgedAccount(t *testing.T) { _ = foa.Call( addr, - testContract.MakeStoreCallData(t, num), + testContract.MakeCallData(t, "store", num), math.MaxUint64, types.Balance(0)) ret := foa.Call( addr, - testContract.MakeRetrieveCallData(t), + testContract.MakeCallData(t, "retrieve"), math.MaxUint64, types.Balance(0)) @@ -528,3 +520,19 @@ func assertPanic(t *testing.T, check checkError, f func()) { }() f() } + +func SetupHandler(t testing.TB, backend types.Backend, rootAddr flow.Address) (*database.Database, *handler.ContractHandler) { + bs, err := handler.NewBlockStore(backend, rootAddr) + require.NoError(t, err) + + aa, err := handler.NewAddressAllocator(backend, rootAddr) + require.NoError(t, err) + + db, err := database.NewDatabase(backend, rootAddr) + require.NoError(t, err) + + emulator := emulator.NewEmulator(db) + + handler := handler.NewContractHandler(flowTokenAddress, bs, aa, backend, emulator) + return db, handler +} diff --git a/fvm/evm/stdlib/contract.cdc b/fvm/evm/stdlib/contract.cdc new file mode 100644 index 00000000000..60f544a68b0 --- /dev/null +++ b/fvm/evm/stdlib/contract.cdc @@ -0,0 +1,134 @@ +import "FlowToken" + +access(all) +contract EVM { + + /// EVMAddress is an EVM-compatible address + access(all) + struct EVMAddress { + + /// Bytes of the address + access(all) + let bytes: [UInt8; 20] + + /// Constructs a new EVM address from the given byte representation + init(bytes: [UInt8; 20]) { + self.bytes = bytes + } + + /// Deposits the given vault into the EVM account with the given address + access(all) + fun deposit(from: @FlowToken.Vault) { + InternalEVM.deposit( + from: <-from, + 
to: self.bytes + ) + } + } + + access(all) + struct Balance { + + /// The balance in FLOW + access(all) + let flow: UFix64 + + /// Constructs a new balance, given the balance in FLOW + init(flow: UFix64) { + self.flow = flow + } + + // TODO: + // /// Returns the balance in terms of atto-FLOW. + // /// Atto-FLOW is the smallest denomination of FLOW inside EVM + // access(all) + // fun toAttoFlow(): UInt64 + } + + access(all) + resource BridgedAccount { + + access(self) + let addressBytes: [UInt8; 20] + + init(addressBytes: [UInt8; 20]) { + self.addressBytes = addressBytes + } + + /// The EVM address of the bridged account + access(all) + fun address(): EVMAddress { + // Always create a new EVMAddress instance + return EVMAddress(bytes: self.addressBytes) + } + + /// Deposits the given vault into the bridged account's balance + access(all) + fun deposit(from: @FlowToken.Vault) { + self.address().deposit(from: <-from) + } + + /// Withdraws the balance from the bridged account's balance + access(all) + fun withdraw(balance: Balance): @FlowToken.Vault { + let vault <- InternalEVM.withdraw( + from: self.addressBytes, + amount: balance.flow + ) as! @FlowToken.Vault + return <-vault + } + + /// Deploys a contract to the EVM environment. + /// Returns the address of the newly deployed contract + access(all) + fun deploy( + code: [UInt8], + gasLimit: UInt64, + value: Balance + ): EVMAddress { + let addressBytes = InternalEVM.deploy( + from: self.addressBytes, + code: code, + gasLimit: gasLimit, + value: value.flow + ) + return EVMAddress(bytes: addressBytes) + } + + /// Calls a function with the given data. + /// The execution is limited by the given amount of gas + access(all) + fun call( + to: EVMAddress, + data: [UInt8], + gasLimit: UInt64, + value: Balance + ): [UInt8] { + return InternalEVM.call( + from: self.addressBytes, + to: to.bytes, + data: data, + gasLimit: gasLimit, + value: value.flow + ) + } + } + + /// Creates a new bridged account + access(all) + fun createBridgedAccount(): @BridgedAccount { + return <-create BridgedAccount( + addressBytes: InternalEVM.createBridgedAccount() + ) + } + + /// Runs an a RLP-encoded EVM transaction, deducts the gas fees, + /// and deposits the gas fees into the provided coinbase address. 
+ /// + /// Returns true if the transaction was successful, + /// and returns false otherwise + access(all) + fun run(tx: [UInt8], coinbase: EVMAddress) { + InternalEVM.run(tx: tx, coinbase: coinbase.bytes) + } +} diff --git a/fvm/evm/stdlib/contract.go b/fvm/evm/stdlib/contract.go new file mode 100644 index 00000000000..e66fa8a6787 --- /dev/null +++ b/fvm/evm/stdlib/contract.go @@ -0,0 +1,658 @@ +package stdlib + +import ( + _ "embed" + "fmt" + "regexp" + + "github.com/onflow/cadence" + "github.com/onflow/cadence/runtime" + "github.com/onflow/cadence/runtime/common" + "github.com/onflow/cadence/runtime/errors" + "github.com/onflow/cadence/runtime/interpreter" + "github.com/onflow/cadence/runtime/sema" + "github.com/onflow/cadence/runtime/stdlib" + + "github.com/onflow/flow-go/fvm/evm/types" + "github.com/onflow/flow-go/model/flow" +) + +//go:embed contract.cdc +var contractCode string + +var flowTokenImportPattern = regexp.MustCompile(`^import "FlowToken"\n`) + +func ContractCode(flowTokenAddress flow.Address) []byte { + return []byte(flowTokenImportPattern.ReplaceAllString( + contractCode, + fmt.Sprintf("import FlowToken from %s", flowTokenAddress.HexWithPrefix()), + )) +} + +const ContractName = "EVM" + +var EVMTransactionBytesCadenceType = cadence.NewVariableSizedArrayType(cadence.TheUInt8Type) +var evmTransactionBytesType = sema.NewVariableSizedType(nil, sema.UInt8Type) + +var evmAddressBytesType = sema.NewConstantSizedType(nil, sema.UInt8Type, types.AddressLength) +var evmAddressBytesStaticType = interpreter.ConvertSemaArrayTypeToStaticArrayType(nil, evmAddressBytesType) +var EVMAddressBytesCadenceType = cadence.NewConstantSizedArrayType(types.AddressLength, cadence.TheUInt8Type) + +const internalEVMTypeRunFunctionName = "run" + +var internalEVMTypeRunFunctionType = &sema.FunctionType{ + Parameters: []sema.Parameter{ + { + Label: "tx", + TypeAnnotation: sema.NewTypeAnnotation(evmTransactionBytesType), + }, + { + Label: "coinbase", + TypeAnnotation: sema.NewTypeAnnotation(evmAddressBytesType), + }, + }, + ReturnTypeAnnotation: sema.NewTypeAnnotation(sema.BoolType), +} + +func newInternalEVMTypeRunFunction( + gauge common.MemoryGauge, + handler types.ContractHandler, +) *interpreter.HostFunctionValue { + return interpreter.NewHostFunctionValue( + gauge, + internalEVMTypeRunFunctionType, + func(invocation interpreter.Invocation) interpreter.Value { + inter := invocation.Interpreter + locationRange := invocation.LocationRange + + // Get transaction argument + + transactionValue, ok := invocation.Arguments[0].(*interpreter.ArrayValue) + if !ok { + panic(errors.NewUnreachableError()) + } + + transaction, err := interpreter.ByteArrayValueToByteSlice(inter, transactionValue, locationRange) + if err != nil { + panic(err) + } + + // Get coinbase argument + + coinbaseValue, ok := invocation.Arguments[1].(*interpreter.ArrayValue) + if !ok { + panic(errors.NewUnreachableError()) + } + + coinbase, err := interpreter.ByteArrayValueToByteSlice(inter, coinbaseValue, locationRange) + if err != nil { + panic(err) + } + + // Run + + cb := types.NewAddressFromBytes(coinbase) + handler.Run(transaction, cb) + + return interpreter.Void + }, + ) +} + +func EVMAddressToAddressBytesArrayValue( + inter *interpreter.Interpreter, + address types.Address, +) *interpreter.ArrayValue { + var index int + return interpreter.NewArrayValueWithIterator( + inter, + evmAddressBytesStaticType, + common.ZeroAddress, + types.AddressLength, + func() interpreter.Value { + if index >= types.AddressLength { + return nil + } + 
result := interpreter.NewUInt8Value(inter, func() uint8 { + return address[index] + }) + index++ + return result + }, + ) +} + +const internalEVMTypeCallFunctionName = "call" + +var internalEVMTypeCallFunctionType = &sema.FunctionType{ + Parameters: []sema.Parameter{ + { + Label: "from", + TypeAnnotation: sema.NewTypeAnnotation(evmAddressBytesType), + }, + { + Label: "to", + TypeAnnotation: sema.NewTypeAnnotation(evmAddressBytesType), + }, + { + Label: "data", + TypeAnnotation: sema.NewTypeAnnotation(sema.ByteArrayType), + }, + { + Label: "gasLimit", + TypeAnnotation: sema.NewTypeAnnotation(sema.UInt64Type), + }, + { + Label: "value", + TypeAnnotation: sema.NewTypeAnnotation(sema.UFix64Type), + }, + }, + ReturnTypeAnnotation: sema.NewTypeAnnotation(sema.ByteArrayType), +} + +func AddressBytesArrayValueToEVMAddress( + inter *interpreter.Interpreter, + locationRange interpreter.LocationRange, + addressBytesValue *interpreter.ArrayValue, +) ( + result types.Address, + err error, +) { + // Convert + + var bytes []byte + bytes, err = interpreter.ByteArrayValueToByteSlice( + inter, + addressBytesValue, + locationRange, + ) + if err != nil { + return result, err + } + + // Check length + + length := len(bytes) + const expectedLength = types.AddressLength + if length != expectedLength { + return result, errors.NewDefaultUserError( + "invalid address length: got %d, expected %d", + length, + expectedLength, + ) + } + + copy(result[:], bytes) + + return result, nil +} + +func newInternalEVMTypeCallFunction( + gauge common.MemoryGauge, + handler types.ContractHandler, +) *interpreter.HostFunctionValue { + return interpreter.NewHostFunctionValue( + gauge, + internalEVMTypeCallFunctionType, + func(invocation interpreter.Invocation) interpreter.Value { + inter := invocation.Interpreter + locationRange := invocation.LocationRange + + // Get from address + + fromAddressValue, ok := invocation.Arguments[0].(*interpreter.ArrayValue) + if !ok { + panic(errors.NewUnreachableError()) + } + + fromAddress, err := AddressBytesArrayValueToEVMAddress(inter, locationRange, fromAddressValue) + if err != nil { + panic(err) + } + + // Get to address + + toAddressValue, ok := invocation.Arguments[1].(*interpreter.ArrayValue) + if !ok { + panic(errors.NewUnreachableError()) + } + + toAddress, err := AddressBytesArrayValueToEVMAddress(inter, locationRange, toAddressValue) + if err != nil { + panic(err) + } + + // Get data + + dataValue, ok := invocation.Arguments[2].(*interpreter.ArrayValue) + if !ok { + panic(errors.NewUnreachableError()) + } + + data, err := interpreter.ByteArrayValueToByteSlice(inter, dataValue, locationRange) + if err != nil { + panic(err) + } + + // Get gas limit + + gasLimitValue, ok := invocation.Arguments[3].(interpreter.UInt64Value) + if !ok { + panic(errors.NewUnreachableError()) + } + + gasLimit := types.GasLimit(gasLimitValue) + + // Get balance + + balanceValue, ok := invocation.Arguments[4].(interpreter.UFix64Value) + if !ok { + panic(errors.NewUnreachableError()) + } + + balance := types.Balance(balanceValue) + + // Call + + const isAuthorized = true + account := handler.AccountByAddress(fromAddress, isAuthorized) + result := account.Call(toAddress, data, gasLimit, balance) + + return interpreter.ByteSliceToByteArrayValue(inter, result) + }, + ) +} + +const internalEVMTypeCreateBridgedAccountFunctionName = "createBridgedAccount" + +var internalEVMTypeCreateBridgedAccountFunctionType = &sema.FunctionType{ + ReturnTypeAnnotation: sema.NewTypeAnnotation(evmAddressBytesType), +} + +func 
newInternalEVMTypeCreateBridgedAccountFunction( + gauge common.MemoryGauge, + handler types.ContractHandler, +) *interpreter.HostFunctionValue { + return interpreter.NewHostFunctionValue( + gauge, + internalEVMTypeCreateBridgedAccountFunctionType, + func(invocation interpreter.Invocation) interpreter.Value { + inter := invocation.Interpreter + address := handler.AllocateAddress() + return EVMAddressToAddressBytesArrayValue(inter, address) + }, + ) +} + +const internalEVMTypeDepositFunctionName = "deposit" + +var internalEVMTypeDepositFunctionType = &sema.FunctionType{ + Parameters: []sema.Parameter{ + { + Label: "from", + TypeAnnotation: sema.NewTypeAnnotation(sema.AnyResourceType), + }, + { + Label: "to", + TypeAnnotation: sema.NewTypeAnnotation(evmAddressBytesType), + }, + }, + ReturnTypeAnnotation: sema.NewTypeAnnotation(sema.VoidType), +} + +const fungibleTokenVaultTypeBalanceFieldName = "balance" + +func newInternalEVMTypeDepositFunction( + gauge common.MemoryGauge, + handler types.ContractHandler, +) *interpreter.HostFunctionValue { + return interpreter.NewHostFunctionValue( + gauge, + internalEVMTypeCallFunctionType, + func(invocation interpreter.Invocation) interpreter.Value { + inter := invocation.Interpreter + locationRange := invocation.LocationRange + + // Get from vault + + fromValue, ok := invocation.Arguments[0].(*interpreter.CompositeValue) + if !ok { + panic(errors.NewUnreachableError()) + } + + amountValue, ok := fromValue.GetField( + inter, + locationRange, + fungibleTokenVaultTypeBalanceFieldName, + ).(interpreter.UFix64Value) + if !ok { + panic(errors.NewUnreachableError()) + } + + amount := types.Balance(amountValue) + + // Get to address + + toAddressValue, ok := invocation.Arguments[1].(*interpreter.ArrayValue) + if !ok { + panic(errors.NewUnreachableError()) + } + + toAddress, err := AddressBytesArrayValueToEVMAddress(inter, locationRange, toAddressValue) + if err != nil { + panic(err) + } + + // NOTE: We're intentionally not destroying the vault here, + // because the value of it is supposed to be "kept alive". + // Destroying would incorrectly be equivalent to a burn and decrease the total supply, + // and a withdrawal would then have to perform an actual mint of new tokens. 
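+            // Only the vault's balance field is read here; an equivalent FLOW token
+            // vault value is reconstructed below and deposited to the target address.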
+ + // Deposit + + const isAuthorized = false + account := handler.AccountByAddress(toAddress, isAuthorized) + account.Deposit(types.NewFlowTokenVault(amount)) + + return interpreter.Void + }, + ) +} + +const internalEVMTypeWithdrawFunctionName = "withdraw" + +var internalEVMTypeWithdrawFunctionType = &sema.FunctionType{ + Parameters: []sema.Parameter{ + { + Label: "from", + TypeAnnotation: sema.NewTypeAnnotation(evmAddressBytesType), + }, + { + Label: "amount", + TypeAnnotation: sema.NewTypeAnnotation(sema.UFix64Type), + }, + }, + ReturnTypeAnnotation: sema.NewTypeAnnotation(sema.AnyResourceType), +} + +func newInternalEVMTypeWithdrawFunction( + gauge common.MemoryGauge, + handler types.ContractHandler, +) *interpreter.HostFunctionValue { + return interpreter.NewHostFunctionValue( + gauge, + internalEVMTypeCallFunctionType, + func(invocation interpreter.Invocation) interpreter.Value { + inter := invocation.Interpreter + locationRange := invocation.LocationRange + + // Get from address + + fromAddressValue, ok := invocation.Arguments[0].(*interpreter.ArrayValue) + if !ok { + panic(errors.NewUnreachableError()) + } + + fromAddress, err := AddressBytesArrayValueToEVMAddress(inter, locationRange, fromAddressValue) + if err != nil { + panic(err) + } + + // Get amount + + amountValue, ok := invocation.Arguments[1].(interpreter.UFix64Value) + if !ok { + panic(errors.NewUnreachableError()) + } + + amount := types.Balance(amountValue) + + // Withdraw + + const isAuthorized = true + account := handler.AccountByAddress(fromAddress, isAuthorized) + vault := account.Withdraw(amount) + + // TODO: improve: maybe call actual constructor + return interpreter.NewCompositeValue( + inter, + locationRange, + common.NewAddressLocation(gauge, handler.FlowTokenAddress(), "FlowToken"), + "FlowToken.Vault", + common.CompositeKindResource, + []interpreter.CompositeField{ + { + Name: "balance", + Value: interpreter.NewUFix64Value(gauge, func() uint64 { + return uint64(vault.Balance()) + }), + }, + }, + common.ZeroAddress, + ) + }, + ) +} + +const internalEVMTypeDeployFunctionName = "deploy" + +var internalEVMTypeDeployFunctionType = &sema.FunctionType{ + Parameters: []sema.Parameter{ + { + Label: "from", + TypeAnnotation: sema.NewTypeAnnotation(evmAddressBytesType), + }, + { + Label: "code", + TypeAnnotation: sema.NewTypeAnnotation(sema.ByteArrayType), + }, + { + Label: "gasLimit", + TypeAnnotation: sema.NewTypeAnnotation(sema.UInt64Type), + }, + { + Label: "value", + TypeAnnotation: sema.NewTypeAnnotation(sema.UFix64Type), + }, + }, + ReturnTypeAnnotation: sema.NewTypeAnnotation(evmAddressBytesType), +} + +func newInternalEVMTypeDeployFunction( + gauge common.MemoryGauge, + handler types.ContractHandler, +) *interpreter.HostFunctionValue { + return interpreter.NewHostFunctionValue( + gauge, + internalEVMTypeCallFunctionType, + func(invocation interpreter.Invocation) interpreter.Value { + inter := invocation.Interpreter + locationRange := invocation.LocationRange + + // Get from address + + fromAddressValue, ok := invocation.Arguments[0].(*interpreter.ArrayValue) + if !ok { + panic(errors.NewUnreachableError()) + } + + fromAddress, err := AddressBytesArrayValueToEVMAddress(inter, locationRange, fromAddressValue) + if err != nil { + panic(err) + } + + // Get code + + codeValue, ok := invocation.Arguments[1].(*interpreter.ArrayValue) + if !ok { + panic(errors.NewUnreachableError()) + } + + code, err := interpreter.ByteArrayValueToByteSlice(inter, codeValue, locationRange) + if err != nil { + panic(err) + } + + // Get 
gas limit + + gasLimitValue, ok := invocation.Arguments[2].(interpreter.UInt64Value) + if !ok { + panic(errors.NewUnreachableError()) + } + + gasLimit := types.GasLimit(gasLimitValue) + + // Get value + + amountValue, ok := invocation.Arguments[3].(interpreter.UFix64Value) + if !ok { + panic(errors.NewUnreachableError()) + } + + amount := types.Balance(amountValue) + + // Deploy + + const isAuthorized = true + account := handler.AccountByAddress(fromAddress, isAuthorized) + address := account.Deploy(code, gasLimit, amount) + + return EVMAddressToAddressBytesArrayValue(inter, address) + }, + ) +} + +func NewInternalEVMContractValue( + gauge common.MemoryGauge, + handler types.ContractHandler, +) *interpreter.SimpleCompositeValue { + return interpreter.NewSimpleCompositeValue( + gauge, + InternalEVMContractType.ID(), + internalEVMContractStaticType, + InternalEVMContractType.Fields, + map[string]interpreter.Value{ + internalEVMTypeRunFunctionName: newInternalEVMTypeRunFunction(gauge, handler), + internalEVMTypeCreateBridgedAccountFunctionName: newInternalEVMTypeCreateBridgedAccountFunction(gauge, handler), + internalEVMTypeCallFunctionName: newInternalEVMTypeCallFunction(gauge, handler), + internalEVMTypeDepositFunctionName: newInternalEVMTypeDepositFunction(gauge, handler), + internalEVMTypeWithdrawFunctionName: newInternalEVMTypeWithdrawFunction(gauge, handler), + internalEVMTypeDeployFunctionName: newInternalEVMTypeDeployFunction(gauge, handler), + }, + nil, + nil, + nil, + ) +} + +const InternalEVMContractName = "InternalEVM" + +var InternalEVMContractType = func() *sema.CompositeType { + ty := &sema.CompositeType{ + Identifier: InternalEVMContractName, + Kind: common.CompositeKindContract, + } + + ty.Members = sema.MembersAsMap([]*sema.Member{ + sema.NewUnmeteredPublicFunctionMember( + ty, + internalEVMTypeRunFunctionName, + internalEVMTypeRunFunctionType, + "", + ), + sema.NewUnmeteredPublicFunctionMember( + ty, + internalEVMTypeCreateBridgedAccountFunctionName, + internalEVMTypeCreateBridgedAccountFunctionType, + "", + ), + sema.NewUnmeteredPublicFunctionMember( + ty, + internalEVMTypeCallFunctionName, + internalEVMTypeCallFunctionType, + "", + ), + sema.NewUnmeteredPublicFunctionMember( + ty, + internalEVMTypeDepositFunctionName, + internalEVMTypeDepositFunctionType, + "", + ), + sema.NewUnmeteredPublicFunctionMember( + ty, + internalEVMTypeWithdrawFunctionName, + internalEVMTypeWithdrawFunctionType, + "", + ), + sema.NewUnmeteredPublicFunctionMember( + ty, + internalEVMTypeDeployFunctionName, + internalEVMTypeDeployFunctionType, + "", + ), + }) + return ty +}() + +var internalEVMContractStaticType = interpreter.ConvertSemaCompositeTypeToStaticCompositeType( + nil, + InternalEVMContractType, +) + +func newInternalEVMStandardLibraryValue( + gauge common.MemoryGauge, + handler types.ContractHandler, +) stdlib.StandardLibraryValue { + return stdlib.StandardLibraryValue{ + Name: InternalEVMContractName, + Type: InternalEVMContractType, + Value: NewInternalEVMContractValue(gauge, handler), + Kind: common.DeclarationKindContract, + } +} + +var internalEVMStandardLibraryType = stdlib.StandardLibraryType{ + Name: InternalEVMContractName, + Type: InternalEVMContractType, + Kind: common.DeclarationKindContract, +} + +func SetupEnvironment(env runtime.Environment, handler types.ContractHandler, service flow.Address) { + location := common.NewAddressLocation(nil, common.Address(service), ContractName) + env.DeclareType( + internalEVMStandardLibraryType, + location, + ) + env.DeclareValue( + 
newInternalEVMStandardLibraryValue(nil, handler), + location, + ) +} + +func NewEVMAddressCadenceType(address common.Address) *cadence.StructType { + return cadence.NewStructType( + common.NewAddressLocation(nil, address, ContractName), + "EVM.EVMAddress", + []cadence.Field{ + { + Identifier: "bytes", + Type: EVMAddressBytesCadenceType, + }, + }, + nil, + ) +} + +func NewBalanceCadenceType(address common.Address) *cadence.StructType { + return cadence.NewStructType( + common.NewAddressLocation(nil, address, ContractName), + "EVM.Balance", + []cadence.Field{ + { + Identifier: "flow", + Type: cadence.UFix64Type{}, + }, + }, + nil, + ) +} diff --git a/fvm/evm/stdlib/contract_test.go b/fvm/evm/stdlib/contract_test.go new file mode 100644 index 00000000000..2ba15c3eb16 --- /dev/null +++ b/fvm/evm/stdlib/contract_test.go @@ -0,0 +1,1145 @@ +package stdlib_test + +import ( + "encoding/binary" + "testing" + + "github.com/onflow/cadence" + "github.com/onflow/cadence/encoding/json" + "github.com/onflow/cadence/runtime" + "github.com/onflow/cadence/runtime/common" + contracts2 "github.com/onflow/flow-core-contracts/lib/go/contracts" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/onflow/flow-go/fvm/blueprints" + "github.com/onflow/flow-go/fvm/evm/stdlib" + . "github.com/onflow/flow-go/fvm/evm/testutils" + "github.com/onflow/flow-go/fvm/evm/types" + "github.com/onflow/flow-go/model/flow" +) + +type testContractHandler struct { + flowTokenAddress common.Address + allocateAddress func() types.Address + addressIndex uint64 + accountByAddress func(types.Address, bool) types.Account + lastExecutedBlock func() *types.Block + run func(tx []byte, coinbase types.Address) +} + +func (t *testContractHandler) FlowTokenAddress() common.Address { + return t.flowTokenAddress +} + +var _ types.ContractHandler = &testContractHandler{} + +func (t *testContractHandler) AllocateAddress() types.Address { + if t.allocateAddress == nil { + t.addressIndex++ + var address types.Address + binary.LittleEndian.PutUint64(address[:], t.addressIndex) + return address + } + return t.allocateAddress() +} + +func (t *testContractHandler) AccountByAddress(addr types.Address, isAuthorized bool) types.Account { + if t.accountByAddress == nil { + panic("unexpected AccountByAddress") + } + return t.accountByAddress(addr, isAuthorized) +} + +func (t *testContractHandler) LastExecutedBlock() *types.Block { + if t.lastExecutedBlock == nil { + panic("unexpected LastExecutedBlock") + } + return t.lastExecutedBlock() +} + +func (t *testContractHandler) Run(tx []byte, coinbase types.Address) { + if t.run == nil { + panic("unexpected Run") + } + t.run(tx, coinbase) +} + +type testFlowAccount struct { + address types.Address + balance func() types.Balance + transfer func(address types.Address, balance types.Balance) + deposit func(vault *types.FLOWTokenVault) + withdraw func(balance types.Balance) *types.FLOWTokenVault + deploy func(code types.Code, limit types.GasLimit, balance types.Balance) types.Address + call func(address types.Address, data types.Data, limit types.GasLimit, balance types.Balance) types.Data +} + +var _ types.Account = &testFlowAccount{} + +func (t *testFlowAccount) Address() types.Address { + return t.address +} + +func (t *testFlowAccount) Balance() types.Balance { + if t.balance == nil { + return types.Balance(0) + } + return t.balance() +} + +func (t *testFlowAccount) Transfer(address types.Address, balance types.Balance) { + if t.transfer == nil { + panic("unexpected 
Transfer") + } + t.transfer(address, balance) +} + +func (t *testFlowAccount) Deposit(vault *types.FLOWTokenVault) { + if t.deposit == nil { + panic("unexpected Deposit") + } + t.deposit(vault) +} + +func (t *testFlowAccount) Withdraw(balance types.Balance) *types.FLOWTokenVault { + if t.withdraw == nil { + panic("unexpected Withdraw") + } + return t.withdraw(balance) +} + +func (t *testFlowAccount) Deploy(code types.Code, limit types.GasLimit, balance types.Balance) types.Address { + if t.deploy == nil { + panic("unexpected Deploy") + } + return t.deploy(code, limit, balance) +} + +func (t *testFlowAccount) Call(address types.Address, data types.Data, limit types.GasLimit, balance types.Balance) types.Data { + if t.call == nil { + panic("unexpected Call") + } + return t.call(address, data, limit, balance) +} + +func deployContracts( + t *testing.T, + rt runtime.Runtime, + contractsAddress flow.Address, + runtimeInterface *TestRuntimeInterface, + transactionEnvironment runtime.Environment, + nextTransactionLocation func() common.TransactionLocation, +) { + + contractsAddressHex := contractsAddress.Hex() + + contracts := []struct { + name string + code []byte + deployTx []byte + }{ + { + name: "FungibleToken", + code: contracts2.FungibleToken(), + }, + { + name: "NonFungibleToken", + code: contracts2.NonFungibleToken(), + }, + { + name: "MetadataViews", + code: contracts2.MetadataViews( + contractsAddressHex, + contractsAddressHex, + ), + }, + { + name: "FungibleTokenMetadataViews", + code: contracts2.FungibleTokenMetadataViews( + contractsAddressHex, + contractsAddressHex, + ), + }, + { + name: "ViewResolver", + code: contracts2.ViewResolver(), + }, + { + name: "FlowToken", + code: contracts2.FlowToken( + contractsAddressHex, + contractsAddressHex, + contractsAddressHex, + ), + deployTx: []byte(` + transaction(name: String, code: String) { + prepare(signer: AuthAccount) { + signer.contracts.add(name: name, code: code.utf8, signer) + } + } + `), + }, + { + name: stdlib.ContractName, + code: stdlib.ContractCode(contractsAddress), + }, + } + + for _, contract := range contracts { + + deployTx := contract.deployTx + if len(deployTx) == 0 { + deployTx = blueprints.DeployContractTransactionTemplate + } + + err := rt.ExecuteTransaction( + runtime.Script{ + Source: deployTx, + Arguments: EncodeArgs([]cadence.Value{ + cadence.String(contract.name), + cadence.String(contract.code), + }), + }, + runtime.Context{ + Interface: runtimeInterface, + Environment: transactionEnvironment, + Location: nextTransactionLocation(), + }, + ) + require.NoError(t, err) + } + +} + +func newEVMTransactionEnvironment(handler types.ContractHandler, service flow.Address) runtime.Environment { + transactionEnvironment := runtime.NewBaseInterpreterEnvironment(runtime.Config{}) + + stdlib.SetupEnvironment( + transactionEnvironment, + handler, + service, + ) + + return transactionEnvironment +} + +func newEVMScriptEnvironment(handler types.ContractHandler, service flow.Address) runtime.Environment { + scriptEnvironment := runtime.NewScriptInterpreterEnvironment(runtime.Config{}) + + stdlib.SetupEnvironment( + scriptEnvironment, + handler, + service, + ) + + return scriptEnvironment +} + +func TestEVMAddressConstructionAndReturn(t *testing.T) { + + t.Parallel() + + handler := &testContractHandler{} + + contractsAddress := flow.BytesToAddress([]byte{0x1}) + + transactionEnvironment := newEVMTransactionEnvironment(handler, contractsAddress) + scriptEnvironment := newEVMScriptEnvironment(handler, contractsAddress) + + rt := 
runtime.NewInterpreterRuntime(runtime.Config{}) + + script := []byte(` + import EVM from 0x1 + + access(all) + fun main(_ bytes: [UInt8; 20]): EVM.EVMAddress { + return EVM.EVMAddress(bytes: bytes) + } + `) + + accountCodes := map[common.Location][]byte{} + var events []cadence.Event + + runtimeInterface := &TestRuntimeInterface{ + Storage: NewTestLedger(nil, nil), + OnGetSigningAccounts: func() ([]runtime.Address, error) { + return []runtime.Address{runtime.Address(contractsAddress)}, nil + }, + OnResolveLocation: SingleIdentifierLocationResolver(t), + OnUpdateAccountContractCode: func(location common.AddressLocation, code []byte) error { + accountCodes[location] = code + return nil + }, + OnGetAccountContractCode: func(location common.AddressLocation) (code []byte, err error) { + code = accountCodes[location] + return code, nil + }, + OnEmitEvent: func(event cadence.Event) error { + events = append(events, event) + return nil + }, + OnDecodeArgument: func(b []byte, t cadence.Type) (cadence.Value, error) { + return json.Decode(nil, b) + }, + } + + addressBytesArray := cadence.NewArray([]cadence.Value{ + cadence.UInt8(1), cadence.UInt8(1), + cadence.UInt8(2), cadence.UInt8(2), + cadence.UInt8(3), cadence.UInt8(3), + cadence.UInt8(4), cadence.UInt8(4), + cadence.UInt8(5), cadence.UInt8(5), + cadence.UInt8(6), cadence.UInt8(6), + cadence.UInt8(7), cadence.UInt8(7), + cadence.UInt8(8), cadence.UInt8(8), + cadence.UInt8(9), cadence.UInt8(9), + cadence.UInt8(10), cadence.UInt8(10), + }).WithType(stdlib.EVMAddressBytesCadenceType) + + nextTransactionLocation := NewTransactionLocationGenerator() + nextScriptLocation := NewScriptLocationGenerator() + + // Deploy contracts + + deployContracts( + t, + rt, + contractsAddress, + runtimeInterface, + transactionEnvironment, + nextTransactionLocation, + ) + + // Run script + + result, err := rt.ExecuteScript( + runtime.Script{ + Source: script, + Arguments: EncodeArgs([]cadence.Value{ + addressBytesArray, + }), + }, + runtime.Context{ + Interface: runtimeInterface, + Environment: scriptEnvironment, + Location: nextScriptLocation(), + }, + ) + require.NoError(t, err) + + evmAddressCadenceType := stdlib.NewEVMAddressCadenceType(common.Address(contractsAddress)) + + assert.Equal(t, + cadence.Struct{ + StructType: evmAddressCadenceType, + Fields: []cadence.Value{ + addressBytesArray, + }, + }, + result, + ) +} + +func TestBalanceConstructionAndReturn(t *testing.T) { + + t.Parallel() + + handler := &testContractHandler{} + + contractsAddress := flow.BytesToAddress([]byte{0x1}) + + transactionEnvironment := newEVMTransactionEnvironment(handler, contractsAddress) + scriptEnvironment := newEVMScriptEnvironment(handler, contractsAddress) + + rt := runtime.NewInterpreterRuntime(runtime.Config{}) + + script := []byte(` + import EVM from 0x1 + + access(all) + fun main(_ flow: UFix64): EVM.Balance { + return EVM.Balance(flow: flow) + } + `) + + accountCodes := map[common.Location][]byte{} + var events []cadence.Event + + runtimeInterface := &TestRuntimeInterface{ + Storage: NewTestLedger(nil, nil), + OnGetSigningAccounts: func() ([]runtime.Address, error) { + return []runtime.Address{runtime.Address(contractsAddress)}, nil + }, + OnResolveLocation: SingleIdentifierLocationResolver(t), + OnUpdateAccountContractCode: func(location common.AddressLocation, code []byte) error { + accountCodes[location] = code + return nil + }, + OnGetAccountContractCode: func(location common.AddressLocation) (code []byte, err error) { + code = accountCodes[location] + return code, nil + 
}, + OnEmitEvent: func(event cadence.Event) error { + events = append(events, event) + return nil + }, + OnDecodeArgument: func(b []byte, t cadence.Type) (cadence.Value, error) { + return json.Decode(nil, b) + }, + } + + nextTransactionLocation := NewTransactionLocationGenerator() + nextScriptLocation := NewScriptLocationGenerator() + + // Deploy contracts + + deployContracts( + t, + rt, + contractsAddress, + runtimeInterface, + transactionEnvironment, + nextTransactionLocation, + ) + + // Run script + + flowValue, err := cadence.NewUFix64FromParts(1, 23000000) + require.NoError(t, err) + + result, err := rt.ExecuteScript( + runtime.Script{ + Source: script, + Arguments: EncodeArgs([]cadence.Value{ + flowValue, + }), + }, + runtime.Context{ + Interface: runtimeInterface, + Environment: scriptEnvironment, + Location: nextScriptLocation(), + }, + ) + require.NoError(t, err) + + evmBalanceCadenceType := stdlib.NewBalanceCadenceType(common.Address(contractsAddress)) + + assert.Equal(t, + cadence.Struct{ + StructType: evmBalanceCadenceType, + Fields: []cadence.Value{ + flowValue, + }, + }, + result, + ) +} + +func TestEVMRun(t *testing.T) { + + t.Parallel() + + evmTx := cadence.NewArray([]cadence.Value{ + cadence.UInt8(1), + cadence.UInt8(2), + cadence.UInt8(3), + }).WithType(stdlib.EVMTransactionBytesCadenceType) + + coinbase := cadence.NewArray([]cadence.Value{ + cadence.UInt8(1), cadence.UInt8(1), + cadence.UInt8(2), cadence.UInt8(2), + cadence.UInt8(3), cadence.UInt8(3), + cadence.UInt8(4), cadence.UInt8(4), + cadence.UInt8(5), cadence.UInt8(5), + cadence.UInt8(6), cadence.UInt8(6), + cadence.UInt8(7), cadence.UInt8(7), + cadence.UInt8(8), cadence.UInt8(8), + cadence.UInt8(9), cadence.UInt8(9), + cadence.UInt8(10), cadence.UInt8(10), + }).WithType(stdlib.EVMAddressBytesCadenceType) + + runCalled := false + + handler := &testContractHandler{ + run: func(tx []byte, coinbase types.Address) { + runCalled = true + + assert.Equal(t, []byte{1, 2, 3}, tx) + assert.Equal(t, + types.Address{ + 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9, 10, 10, + }, + coinbase, + ) + + }, + } + + contractsAddress := flow.BytesToAddress([]byte{0x1}) + + transactionEnvironment := newEVMTransactionEnvironment(handler, contractsAddress) + scriptEnvironment := newEVMScriptEnvironment(handler, contractsAddress) + + rt := runtime.NewInterpreterRuntime(runtime.Config{}) + + script := []byte(` + import EVM from 0x1 + + access(all) + fun main(tx: [UInt8], coinbaseBytes: [UInt8; 20]) { + let coinbase = EVM.EVMAddress(bytes: coinbaseBytes) + EVM.run(tx: tx, coinbase: coinbase) + } + `) + + accountCodes := map[common.Location][]byte{} + var events []cadence.Event + + runtimeInterface := &TestRuntimeInterface{ + Storage: NewTestLedger(nil, nil), + OnGetSigningAccounts: func() ([]runtime.Address, error) { + return []runtime.Address{runtime.Address(contractsAddress)}, nil + }, + OnResolveLocation: SingleIdentifierLocationResolver(t), + OnUpdateAccountContractCode: func(location common.AddressLocation, code []byte) error { + accountCodes[location] = code + return nil + }, + OnGetAccountContractCode: func(location common.AddressLocation) (code []byte, err error) { + code = accountCodes[location] + return code, nil + }, + OnEmitEvent: func(event cadence.Event) error { + events = append(events, event) + return nil + }, + OnDecodeArgument: func(b []byte, t cadence.Type) (cadence.Value, error) { + return json.Decode(nil, b) + }, + } + + nextTransactionLocation := NewTransactionLocationGenerator() + nextScriptLocation := 
NewScriptLocationGenerator() + + // Deploy contracts + + deployContracts( + t, + rt, + contractsAddress, + runtimeInterface, + transactionEnvironment, + nextTransactionLocation, + ) + + // Run script + + _, err := rt.ExecuteScript( + runtime.Script{ + Source: script, + Arguments: EncodeArgs([]cadence.Value{evmTx, coinbase}), + }, + runtime.Context{ + Interface: runtimeInterface, + Environment: scriptEnvironment, + Location: nextScriptLocation(), + }, + ) + require.NoError(t, err) + + assert.True(t, runCalled) +} + +func TestEVMCreateBridgedAccount(t *testing.T) { + + t.Parallel() + + handler := &testContractHandler{} + + contractsAddress := flow.BytesToAddress([]byte{0x1}) + + transactionEnvironment := newEVMTransactionEnvironment(handler, contractsAddress) + scriptEnvironment := newEVMScriptEnvironment(handler, contractsAddress) + + rt := runtime.NewInterpreterRuntime(runtime.Config{}) + + script := []byte(` + import EVM from 0x1 + + access(all) + fun main(): [UInt8; 20] { + let bridgedAccount1 <- EVM.createBridgedAccount() + destroy bridgedAccount1 + + let bridgedAccount2 <- EVM.createBridgedAccount() + let bytes = bridgedAccount2.address().bytes + destroy bridgedAccount2 + + return bytes + } + `) + + accountCodes := map[common.Location][]byte{} + var events []cadence.Event + + runtimeInterface := &TestRuntimeInterface{ + Storage: NewTestLedger(nil, nil), + OnGetSigningAccounts: func() ([]runtime.Address, error) { + return []runtime.Address{runtime.Address(contractsAddress)}, nil + }, + OnResolveLocation: SingleIdentifierLocationResolver(t), + OnUpdateAccountContractCode: func(location common.AddressLocation, code []byte) error { + accountCodes[location] = code + return nil + }, + OnGetAccountContractCode: func(location common.AddressLocation) (code []byte, err error) { + code = accountCodes[location] + return code, nil + }, + OnEmitEvent: func(event cadence.Event) error { + events = append(events, event) + return nil + }, + OnDecodeArgument: func(b []byte, t cadence.Type) (cadence.Value, error) { + return json.Decode(nil, b) + }, + } + + nextTransactionLocation := NewTransactionLocationGenerator() + nextScriptLocation := NewScriptLocationGenerator() + + // Deploy contracts + + deployContracts( + t, + rt, + contractsAddress, + runtimeInterface, + transactionEnvironment, + nextTransactionLocation, + ) + + // Run script + + actual, err := rt.ExecuteScript( + runtime.Script{ + Source: script, + }, + runtime.Context{ + Interface: runtimeInterface, + Environment: scriptEnvironment, + Location: nextScriptLocation(), + }, + ) + require.NoError(t, err) + + expected := cadence.NewArray([]cadence.Value{ + cadence.UInt8(2), cadence.UInt8(0), + cadence.UInt8(0), cadence.UInt8(0), + cadence.UInt8(0), cadence.UInt8(0), + cadence.UInt8(0), cadence.UInt8(0), + cadence.UInt8(0), cadence.UInt8(0), + cadence.UInt8(0), cadence.UInt8(0), + cadence.UInt8(0), cadence.UInt8(0), + cadence.UInt8(0), cadence.UInt8(0), + cadence.UInt8(0), cadence.UInt8(0), + cadence.UInt8(0), cadence.UInt8(0), + }).WithType(cadence.NewConstantSizedArrayType( + types.AddressLength, + cadence.UInt8Type{}, + )) + + require.Equal(t, expected, actual) +} + +func TestBridgedAccountCall(t *testing.T) { + + t.Parallel() + + expectedBalance, err := cadence.NewUFix64FromParts(1, 23000000) + require.NoError(t, err) + + handler := &testContractHandler{ + accountByAddress: func(fromAddress types.Address, isAuthorized bool) types.Account { + assert.Equal(t, types.Address{1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, fromAddress) + 
assert.True(t, isAuthorized) + + return &testFlowAccount{ + address: fromAddress, + call: func( + toAddress types.Address, + data types.Data, + limit types.GasLimit, + balance types.Balance, + ) types.Data { + assert.Equal(t, types.Address{2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, toAddress) + assert.Equal(t, types.Data{4, 5, 6}, data) + assert.Equal(t, types.GasLimit(9999), limit) + assert.Equal(t, types.Balance(expectedBalance), balance) + + return types.Data{3, 1, 4} + }, + } + }, + } + + contractsAddress := flow.BytesToAddress([]byte{0x1}) + + transactionEnvironment := newEVMTransactionEnvironment(handler, contractsAddress) + scriptEnvironment := newEVMScriptEnvironment(handler, contractsAddress) + + rt := runtime.NewInterpreterRuntime(runtime.Config{}) + + script := []byte(` + import EVM from 0x1 + + access(all) + fun main(): [UInt8] { + let bridgedAccount <- EVM.createBridgedAccount() + let response = bridgedAccount.call( + to: EVM.EVMAddress( + bytes: [2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + ), + data: [4, 5, 6], + gasLimit: 9999, + value: EVM.Balance(flow: 1.23) + ) + destroy bridgedAccount + return response + } + `) + + accountCodes := map[common.Location][]byte{} + var events []cadence.Event + + runtimeInterface := &TestRuntimeInterface{ + Storage: NewTestLedger(nil, nil), + OnGetSigningAccounts: func() ([]runtime.Address, error) { + return []runtime.Address{runtime.Address(contractsAddress)}, nil + }, + OnResolveLocation: SingleIdentifierLocationResolver(t), + OnUpdateAccountContractCode: func(location common.AddressLocation, code []byte) error { + accountCodes[location] = code + return nil + }, + OnGetAccountContractCode: func(location common.AddressLocation) (code []byte, err error) { + code = accountCodes[location] + return code, nil + }, + OnEmitEvent: func(event cadence.Event) error { + events = append(events, event) + return nil + }, + OnDecodeArgument: func(b []byte, t cadence.Type) (cadence.Value, error) { + return json.Decode(nil, b) + }, + } + + nextTransactionLocation := NewTransactionLocationGenerator() + nextScriptLocation := NewScriptLocationGenerator() + + // Deploy contracts + + deployContracts( + t, + rt, + contractsAddress, + runtimeInterface, + transactionEnvironment, + nextTransactionLocation, + ) + + // Run script + + actual, err := rt.ExecuteScript( + runtime.Script{ + Source: script, + }, + runtime.Context{ + Interface: runtimeInterface, + Environment: scriptEnvironment, + Location: nextScriptLocation(), + }, + ) + require.NoError(t, err) + + expected := cadence.NewArray([]cadence.Value{ + cadence.UInt8(3), + cadence.UInt8(1), + cadence.UInt8(4), + }).WithType(cadence.NewVariableSizedArrayType(cadence.UInt8Type{})) + + require.Equal(t, expected, actual) +} + +func TestEVMAddressDeposit(t *testing.T) { + + t.Parallel() + + expectedBalance, err := cadence.NewUFix64FromParts(1, 23000000) + require.NoError(t, err) + + var deposited bool + + handler := &testContractHandler{ + + accountByAddress: func(fromAddress types.Address, isAuthorized bool) types.Account { + assert.Equal(t, types.Address{2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, fromAddress) + assert.False(t, isAuthorized) + + return &testFlowAccount{ + address: fromAddress, + deposit: func(vault *types.FLOWTokenVault) { + deposited = true + assert.Equal( + t, + types.Balance(expectedBalance), + vault.Balance(), + ) + }, + } + }, + } + + contractsAddress := flow.BytesToAddress([]byte{0x1}) + + transactionEnvironment := 
newEVMTransactionEnvironment(handler, contractsAddress) + scriptEnvironment := newEVMScriptEnvironment(handler, contractsAddress) + + rt := runtime.NewInterpreterRuntime(runtime.Config{}) + + script := []byte(` + import EVM from 0x1 + import FlowToken from 0x1 + + access(all) + fun main() { + let admin = getAuthAccount(0x1) + .borrow<&FlowToken.Administrator>(from: /storage/flowTokenAdmin)! + let minter <- admin.createNewMinter(allowedAmount: 1.23) + let vault <- minter.mintTokens(amount: 1.23) + destroy minter + + let address = EVM.EVMAddress( + bytes: [2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + ) + address.deposit(from: <-vault) + } + `) + + accountCodes := map[common.Location][]byte{} + var events []cadence.Event + + runtimeInterface := &TestRuntimeInterface{ + Storage: NewTestLedger(nil, nil), + OnGetSigningAccounts: func() ([]runtime.Address, error) { + return []runtime.Address{runtime.Address(contractsAddress)}, nil + }, + OnResolveLocation: SingleIdentifierLocationResolver(t), + OnUpdateAccountContractCode: func(location common.AddressLocation, code []byte) error { + accountCodes[location] = code + return nil + }, + OnGetAccountContractCode: func(location common.AddressLocation) (code []byte, err error) { + code = accountCodes[location] + return code, nil + }, + OnEmitEvent: func(event cadence.Event) error { + events = append(events, event) + return nil + }, + OnDecodeArgument: func(b []byte, t cadence.Type) (cadence.Value, error) { + return json.Decode(nil, b) + }, + } + + nextTransactionLocation := NewTransactionLocationGenerator() + nextScriptLocation := NewScriptLocationGenerator() + + // Deploy contracts + + deployContracts( + t, + rt, + contractsAddress, + runtimeInterface, + transactionEnvironment, + nextTransactionLocation, + ) + + // Run script + + _, err = rt.ExecuteScript( + runtime.Script{ + Source: script, + }, + runtime.Context{ + Interface: runtimeInterface, + Environment: scriptEnvironment, + Location: nextScriptLocation(), + }, + ) + require.NoError(t, err) + + require.True(t, deposited) +} + +func TestBridgedAccountWithdraw(t *testing.T) { + + t.Parallel() + + expectedDepositBalance, err := cadence.NewUFix64FromParts(2, 34000000) + require.NoError(t, err) + + expectedWithdrawBalance, err := cadence.NewUFix64FromParts(1, 23000000) + require.NoError(t, err) + + var deposited bool + var withdrew bool + + contractsAddress := flow.BytesToAddress([]byte{0x1}) + + handler := &testContractHandler{ + flowTokenAddress: common.Address(contractsAddress), + accountByAddress: func(fromAddress types.Address, isAuthorized bool) types.Account { + assert.Equal(t, types.Address{1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, fromAddress) + assert.Equal(t, deposited, isAuthorized) + + return &testFlowAccount{ + address: fromAddress, + deposit: func(vault *types.FLOWTokenVault) { + deposited = true + assert.Equal(t, + types.Balance(expectedDepositBalance), + vault.Balance(), + ) + }, + withdraw: func(balance types.Balance) *types.FLOWTokenVault { + assert.Equal(t, + types.Balance(expectedWithdrawBalance), + balance, + ) + withdrew = true + return types.NewFlowTokenVault(balance) + }, + } + }, + } + + transactionEnvironment := newEVMTransactionEnvironment(handler, contractsAddress) + scriptEnvironment := newEVMScriptEnvironment(handler, contractsAddress) + + rt := runtime.NewInterpreterRuntime(runtime.Config{}) + + script := []byte(` + import EVM from 0x1 + import FlowToken from 0x1 + + access(all) + fun main(): UFix64 { + let admin = getAuthAccount(0x1) 
+ .borrow<&FlowToken.Administrator>(from: /storage/flowTokenAdmin)! + let minter <- admin.createNewMinter(allowedAmount: 2.34) + let vault <- minter.mintTokens(amount: 2.34) + destroy minter + + let bridgedAccount <- EVM.createBridgedAccount() + bridgedAccount.address().deposit(from: <-vault) + + let vault2 <- bridgedAccount.withdraw(balance: EVM.Balance(flow: 1.23)) + let balance = vault2.balance + destroy bridgedAccount + destroy vault2 + + return balance + } + `) + + accountCodes := map[common.Location][]byte{} + var events []cadence.Event + + runtimeInterface := &TestRuntimeInterface{ + Storage: NewTestLedger(nil, nil), + OnGetSigningAccounts: func() ([]runtime.Address, error) { + return []runtime.Address{runtime.Address(contractsAddress)}, nil + }, + OnResolveLocation: SingleIdentifierLocationResolver(t), + OnUpdateAccountContractCode: func(location common.AddressLocation, code []byte) error { + accountCodes[location] = code + return nil + }, + OnGetAccountContractCode: func(location common.AddressLocation) (code []byte, err error) { + code = accountCodes[location] + return code, nil + }, + OnEmitEvent: func(event cadence.Event) error { + events = append(events, event) + return nil + }, + OnDecodeArgument: func(b []byte, t cadence.Type) (cadence.Value, error) { + return json.Decode(nil, b) + }, + } + + nextTransactionLocation := NewTransactionLocationGenerator() + nextScriptLocation := NewScriptLocationGenerator() + + // Deploy contracts + + deployContracts( + t, + rt, + contractsAddress, + runtimeInterface, + transactionEnvironment, + nextTransactionLocation, + ) + + // Run script + + result, err := rt.ExecuteScript( + runtime.Script{ + Source: script, + }, + runtime.Context{ + Interface: runtimeInterface, + Environment: scriptEnvironment, + Location: nextScriptLocation(), + }, + ) + require.NoError(t, err) + + assert.True(t, deposited) + assert.True(t, withdrew) + assert.Equal(t, expectedWithdrawBalance, result) +} + +func TestBridgedAccountDeploy(t *testing.T) { + + t.Parallel() + + var deployed bool + + contractsAddress := flow.BytesToAddress([]byte{0x1}) + + expectedBalance, err := cadence.NewUFix64FromParts(1, 23000000) + require.NoError(t, err) + + var handler *testContractHandler + handler = &testContractHandler{ + flowTokenAddress: common.Address(contractsAddress), + accountByAddress: func(fromAddress types.Address, isAuthorized bool) types.Account { + assert.Equal(t, types.Address{1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, fromAddress) + assert.True(t, isAuthorized) + + return &testFlowAccount{ + address: fromAddress, + deploy: func(code types.Code, limit types.GasLimit, balance types.Balance) types.Address { + deployed = true + assert.Equal(t, types.Code{4, 5, 6}, code) + assert.Equal(t, types.GasLimit(9999), limit) + assert.Equal(t, types.Balance(expectedBalance), balance) + + return handler.AllocateAddress() + }, + } + }, + } + + transactionEnvironment := newEVMTransactionEnvironment(handler, contractsAddress) + scriptEnvironment := newEVMScriptEnvironment(handler, contractsAddress) + + rt := runtime.NewInterpreterRuntime(runtime.Config{}) + + script := []byte(` + import EVM from 0x1 + import FlowToken from 0x1 + + access(all) + fun main(): [UInt8; 20] { + let bridgedAccount <- EVM.createBridgedAccount() + let address = bridgedAccount.deploy( + code: [4, 5, 6], + gasLimit: 9999, + value: EVM.Balance(flow: 1.23) + ) + destroy bridgedAccount + return address.bytes + } + `) + + accountCodes := map[common.Location][]byte{} + var events []cadence.Event + + 
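+       // the runtime interface below stubs out storage, contract code, event collection, and argument decoding for this test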
runtimeInterface := &TestRuntimeInterface{ + Storage: NewTestLedger(nil, nil), + OnGetSigningAccounts: func() ([]runtime.Address, error) { + return []runtime.Address{runtime.Address(contractsAddress)}, nil + }, + OnResolveLocation: SingleIdentifierLocationResolver(t), + OnUpdateAccountContractCode: func(location common.AddressLocation, code []byte) error { + accountCodes[location] = code + return nil + }, + OnGetAccountContractCode: func(location common.AddressLocation) (code []byte, err error) { + code = accountCodes[location] + return code, nil + }, + OnEmitEvent: func(event cadence.Event) error { + events = append(events, event) + return nil + }, + OnDecodeArgument: func(b []byte, t cadence.Type) (cadence.Value, error) { + return json.Decode(nil, b) + }, + } + + nextTransactionLocation := NewTransactionLocationGenerator() + nextScriptLocation := NewScriptLocationGenerator() + + // Deploy contracts + + deployContracts( + t, + rt, + contractsAddress, + runtimeInterface, + transactionEnvironment, + nextTransactionLocation, + ) + + // Run script + + actual, err := rt.ExecuteScript( + runtime.Script{ + Source: script, + }, + runtime.Context{ + Interface: runtimeInterface, + Environment: scriptEnvironment, + Location: nextScriptLocation(), + }, + ) + require.NoError(t, err) + + expected := cadence.NewArray([]cadence.Value{ + cadence.UInt8(2), cadence.UInt8(0), + cadence.UInt8(0), cadence.UInt8(0), + cadence.UInt8(0), cadence.UInt8(0), + cadence.UInt8(0), cadence.UInt8(0), + cadence.UInt8(0), cadence.UInt8(0), + cadence.UInt8(0), cadence.UInt8(0), + cadence.UInt8(0), cadence.UInt8(0), + cadence.UInt8(0), cadence.UInt8(0), + cadence.UInt8(0), cadence.UInt8(0), + cadence.UInt8(0), cadence.UInt8(0), + }).WithType(cadence.NewConstantSizedArrayType( + types.AddressLength, + cadence.UInt8Type{}, + )) + + require.Equal(t, expected, actual) + + require.True(t, deployed) +} diff --git a/fvm/evm/testutils/accounts.go b/fvm/evm/testutils/accounts.go index 237474da400..77cee885204 100644 --- a/fvm/evm/testutils/accounts.go +++ b/fvm/evm/testutils/accounts.go @@ -110,7 +110,7 @@ func GetTestEOAAccount(t testing.TB, keyHex string) *EOATestAccount { } } -func RunWithEOATestAccount(t *testing.T, led atree.Ledger, flowEVMRootAddress flow.Address, f func(*EOATestAccount)) { +func RunWithEOATestAccount(t testing.TB, led atree.Ledger, flowEVMRootAddress flow.Address, f func(*EOATestAccount)) { account := GetTestEOAAccount(t, EOATestAccount1KeyHex) // fund account @@ -131,5 +131,12 @@ func RunWithEOATestAccount(t *testing.T, led atree.Ledger, flowEVMRootAddress fl ) require.NoError(t, err) + blk2, err := e.NewReadOnlyBlockView(types.NewDefaultBlockContext(2)) + require.NoError(t, err) + + bal, err := blk2.BalanceOf(account.Address()) + require.NoError(t, err) + require.Greater(t, bal.Uint64(), uint64(0)) + f(account) } diff --git a/fvm/evm/testutils/backend.go b/fvm/evm/testutils/backend.go index 477f1dc89fb..c73eb04b00b 100644 --- a/fvm/evm/testutils/backend.go +++ b/fvm/evm/testutils/backend.go @@ -12,7 +12,6 @@ import ( "github.com/stretchr/testify/require" "github.com/onflow/flow-go/fvm/environment" - "github.com/onflow/flow-go/fvm/evm/types" "github.com/onflow/flow-go/fvm/meter" "github.com/onflow/flow-go/model/flow" ) @@ -27,8 +26,8 @@ func RunWithTestFlowEVMRootAddress(t testing.TB, backend atree.Ledger, f func(fl f(TestFlowEVMRootAddress) } -func RunWithTestBackend(t testing.TB, f func(types.Backend)) { - tb := &testBackend{ +func RunWithTestBackend(t testing.TB, f func(*TestBackend)) { + tb := 
&TestBackend{ TestValueStore: GetSimpleValueStore(), testEventEmitter: getSimpleEventEmitter(), testMeter: getSimpleMeter(), @@ -71,19 +70,32 @@ func GetSimpleValueStore() *TestValueStore { binary.BigEndian.PutUint64(data[:], index) return atree.StorageIndex(data), nil }, + TotalStorageSizeFunc: func() int { + sum := 0 + for key, value := range data { + sum += len(key) + len(value) + } + for key := range allocator { + sum += len(key) + 8 + } + return sum + }, } } func getSimpleEventEmitter() *testEventEmitter { events := make(flow.EventsList, 0) return &testEventEmitter{ - emitFlowEvent: func(etype flow.EventType, payload []byte) error { + emitRawEvent: func(etype flow.EventType, payload []byte) error { events = append(events, flow.Event{Type: etype, Payload: payload}) return nil }, events: func() flow.EventsList { return events }, + reset: func() { + events = make(flow.EventsList, 0) + }, } } @@ -107,17 +119,32 @@ func getSimpleMeter() *testMeter { } } -type testBackend struct { +type TestBackend struct { *TestValueStore *testMeter *testEventEmitter } +func (tb *TestBackend) TotalStorageSize() int { + if tb.TotalStorageSizeFunc == nil { + panic("method not set") + } + return tb.TotalStorageSizeFunc() +} + +func (tb *TestBackend) DropEvents() { + if tb.reset == nil { + panic("method not set") + } + tb.reset() +} + type TestValueStore struct { GetValueFunc func(owner, key []byte) ([]byte, error) SetValueFunc func(owner, key, value []byte) error ValueExistsFunc func(owner, key []byte) (bool, error) AllocateStorageIndexFunc func(owner []byte) (atree.StorageIndex, error) + TotalStorageSizeFunc func() int } var _ environment.ValueStore = &TestValueStore{} @@ -150,6 +177,13 @@ func (vs *TestValueStore) AllocateStorageIndex(owner []byte) (atree.StorageIndex return vs.AllocateStorageIndexFunc(owner) } +func (vs *TestValueStore) TotalStorageSize() int { + if vs.TotalStorageSizeFunc == nil { + panic("method not set") + } + return vs.TotalStorageSizeFunc() +} + type testMeter struct { meterComputation func(common.ComputationKind, uint) error hasComputationCapacity func(common.ComputationKind, uint) bool @@ -238,7 +272,7 @@ func (m *testMeter) TotalEmittedEventBytes() uint64 { type testEventEmitter struct { emitEvent func(event cadence.Event) error - emitFlowEvent func(etype flow.EventType, payload []byte) error + emitRawEvent func(etype flow.EventType, payload []byte) error events func() flow.EventsList serviceEvents func() flow.EventsList convertedServiceEvents func() flow.ServiceEventList @@ -254,11 +288,11 @@ func (vs *testEventEmitter) EmitEvent(event cadence.Event) error { return vs.emitEvent(event) } -func (vs *testEventEmitter) EmitFlowEvent(etype flow.EventType, payload []byte) error { - if vs.emitFlowEvent == nil { +func (vs *testEventEmitter) EmitRawEvent(etype flow.EventType, payload []byte) error { + if vs.emitRawEvent == nil { panic("method not set") } - return vs.emitFlowEvent(etype, payload) + return vs.emitRawEvent(etype, payload) } func (vs *testEventEmitter) Events() flow.EventsList { diff --git a/fvm/evm/testutils/cadence.go b/fvm/evm/testutils/cadence.go new file mode 100644 index 00000000000..a35070c3f69 --- /dev/null +++ b/fvm/evm/testutils/cadence.go @@ -0,0 +1,691 @@ +package testutils + +import ( + "bytes" + "encoding/binary" + "encoding/hex" + "errors" + "fmt" + "strconv" + "strings" + "sync/atomic" + "testing" + "time" + + "github.com/onflow/atree" + "github.com/onflow/cadence" + "github.com/onflow/cadence/encoding/json" + "github.com/onflow/cadence/runtime" + 
"github.com/onflow/cadence/runtime/common" + "github.com/onflow/cadence/runtime/interpreter" + "github.com/onflow/cadence/runtime/sema" + cadenceStdlib "github.com/onflow/cadence/runtime/stdlib" + "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel/attribute" +) + +// TODO: replace with Cadence runtime testing utils once available https://github.com/onflow/cadence/pull/2800 + +func SingleIdentifierLocationResolver(t testing.TB) func( + identifiers []runtime.Identifier, + location runtime.Location, +) ( + []runtime.ResolvedLocation, + error, +) { + return func(identifiers []runtime.Identifier, location runtime.Location) ([]runtime.ResolvedLocation, error) { + require.Len(t, identifiers, 1) + require.IsType(t, common.AddressLocation{}, location) + + return []runtime.ResolvedLocation{ + { + Location: common.AddressLocation{ + Address: location.(common.AddressLocation).Address, + Name: identifiers[0].Identifier, + }, + Identifiers: identifiers, + }, + }, nil + } +} + +func newLocationGenerator[T ~[32]byte]() func() T { + var count uint64 + return func() T { + t := T{} + newCount := atomic.AddUint64(&count, 1) + binary.LittleEndian.PutUint64(t[:], newCount) + return t + } +} + +func NewTransactionLocationGenerator() func() common.TransactionLocation { + return newLocationGenerator[common.TransactionLocation]() +} + +func NewScriptLocationGenerator() func() common.ScriptLocation { + return newLocationGenerator[common.ScriptLocation]() +} + +func EncodeArgs(argValues []cadence.Value) [][]byte { + args := make([][]byte, len(argValues)) + for i, arg := range argValues { + var err error + args[i], err = json.Encode(arg) + if err != nil { + panic(fmt.Errorf("broken test: invalid argument: %w", err)) + } + } + return args +} + +type TestLedger struct { + StoredValues map[string][]byte + OnValueExists func(owner, key []byte) (exists bool, err error) + OnGetValue func(owner, key []byte) (value []byte, err error) + OnSetValue func(owner, key, value []byte) (err error) + OnAllocateStorageIndex func(owner []byte) (atree.StorageIndex, error) +} + +var _ atree.Ledger = TestLedger{} + +func (s TestLedger) GetValue(owner, key []byte) (value []byte, err error) { + return s.OnGetValue(owner, key) +} + +func (s TestLedger) SetValue(owner, key, value []byte) (err error) { + return s.OnSetValue(owner, key, value) +} + +func (s TestLedger) ValueExists(owner, key []byte) (exists bool, err error) { + return s.OnValueExists(owner, key) +} + +func (s TestLedger) AllocateStorageIndex(owner []byte) (atree.StorageIndex, error) { + return s.OnAllocateStorageIndex(owner) +} + +func (s TestLedger) Dump() { + // Only used for testing/debugging purposes + for key, data := range s.StoredValues { //nolint:maprange + fmt.Printf("%s:\n", strconv.Quote(key)) + fmt.Printf("%s\n", hex.Dump(data)) + println() + } +} + +func NewTestLedger( + onRead func(owner, key, value []byte), + onWrite func(owner, key, value []byte), +) TestLedger { + + storageKey := func(owner, key string) string { + return strings.Join([]string{owner, key}, "|") + } + + storedValues := map[string][]byte{} + + storageIndices := map[string]uint64{} + + return TestLedger{ + StoredValues: storedValues, + OnValueExists: func(owner, key []byte) (bool, error) { + value := storedValues[storageKey(string(owner), string(key))] + return len(value) > 0, nil + }, + OnGetValue: func(owner, key []byte) (value []byte, err error) { + value = storedValues[storageKey(string(owner), string(key))] + if onRead != nil { + onRead(owner, key, value) + } + return value, nil + 
}, + OnSetValue: func(owner, key, value []byte) (err error) { + storedValues[storageKey(string(owner), string(key))] = value + if onWrite != nil { + onWrite(owner, key, value) + } + return nil + }, + OnAllocateStorageIndex: func(owner []byte) (result atree.StorageIndex, err error) { + index := storageIndices[string(owner)] + 1 + storageIndices[string(owner)] = index + binary.BigEndian.PutUint64(result[:], index) + return + }, + } +} + +type TestRuntimeInterface struct { + Storage atree.Ledger + + OnResolveLocation func( + identifiers []runtime.Identifier, + location runtime.Location, + ) ( + []runtime.ResolvedLocation, + error, + ) + OnGetCode func(_ runtime.Location) ([]byte, error) + OnGetAndSetProgram func( + location runtime.Location, + load func() (*interpreter.Program, error), + ) (*interpreter.Program, error) + OnSetInterpreterSharedState func(state *interpreter.SharedState) + OnGetInterpreterSharedState func() *interpreter.SharedState + OnCreateAccount func(payer runtime.Address) (address runtime.Address, err error) + OnAddEncodedAccountKey func(address runtime.Address, publicKey []byte) error + OnRemoveEncodedAccountKey func(address runtime.Address, index int) (publicKey []byte, err error) + OnAddAccountKey func( + address runtime.Address, + publicKey *cadenceStdlib.PublicKey, + hashAlgo runtime.HashAlgorithm, + weight int, + ) (*cadenceStdlib.AccountKey, error) + OnGetAccountKey func(address runtime.Address, index int) (*cadenceStdlib.AccountKey, error) + OnRemoveAccountKey func(address runtime.Address, index int) (*cadenceStdlib.AccountKey, error) + OnAccountKeysCount func(address runtime.Address) (uint64, error) + OnUpdateAccountContractCode func(location common.AddressLocation, code []byte) error + OnGetAccountContractCode func(location common.AddressLocation) (code []byte, err error) + OnRemoveAccountContractCode func(location common.AddressLocation) (err error) + OnGetSigningAccounts func() ([]runtime.Address, error) + OnProgramLog func(string) + OnEmitEvent func(cadence.Event) error + OnResourceOwnerChanged func( + interpreter *interpreter.Interpreter, + resource *interpreter.CompositeValue, + oldAddress common.Address, + newAddress common.Address, + ) + OnGenerateUUID func() (uint64, error) + OnMeterComputation func(compKind common.ComputationKind, intensity uint) error + OnDecodeArgument func(b []byte, t cadence.Type) (cadence.Value, error) + OnProgramParsed func(location runtime.Location, duration time.Duration) + OnProgramChecked func(location runtime.Location, duration time.Duration) + OnProgramInterpreted func(location runtime.Location, duration time.Duration) + OnReadRandom func([]byte) error + OnVerifySignature func( + signature []byte, + tag string, + signedData []byte, + publicKey []byte, + signatureAlgorithm runtime.SignatureAlgorithm, + hashAlgorithm runtime.HashAlgorithm, + ) (bool, error) + OnHash func( + data []byte, + tag string, + hashAlgorithm runtime.HashAlgorithm, + ) ([]byte, error) + OnSetCadenceValue func(owner runtime.Address, key string, value cadence.Value) (err error) + OnGetAccountBalance func(_ runtime.Address) (uint64, error) + OnGetAccountAvailableBalance func(_ runtime.Address) (uint64, error) + OnGetStorageUsed func(_ runtime.Address) (uint64, error) + OnGetStorageCapacity func(_ runtime.Address) (uint64, error) + Programs map[runtime.Location]*interpreter.Program + OnImplementationDebugLog func(message string) error + OnValidatePublicKey func(publicKey *cadenceStdlib.PublicKey) error + OnBLSVerifyPOP func(pk *cadenceStdlib.PublicKey, s 
[]byte) (bool, error) + OnBLSAggregateSignatures func(sigs [][]byte) ([]byte, error) + OnBLSAggregatePublicKeys func(keys []*cadenceStdlib.PublicKey) (*cadenceStdlib.PublicKey, error) + OnGetAccountContractNames func(address runtime.Address) ([]string, error) + OnRecordTrace func( + operation string, + location runtime.Location, + duration time.Duration, + attrs []attribute.KeyValue, + ) + OnMeterMemory func(usage common.MemoryUsage) error + OnComputationUsed func() (uint64, error) + OnMemoryUsed func() (uint64, error) + OnInteractionUsed func() (uint64, error) + OnGenerateAccountID func(address common.Address) (uint64, error) + + lastUUID uint64 + accountIDs map[common.Address]uint64 + updatedContractCode bool +} + +// TestRuntimeInterface should implement Interface +var _ runtime.Interface = &TestRuntimeInterface{} + +func (i *TestRuntimeInterface) ResolveLocation( + identifiers []runtime.Identifier, + location runtime.Location, +) ([]runtime.ResolvedLocation, error) { + if i.OnResolveLocation == nil { + return []runtime.ResolvedLocation{ + { + Location: location, + Identifiers: identifiers, + }, + }, nil + } + return i.OnResolveLocation(identifiers, location) +} + +func (i *TestRuntimeInterface) GetCode(location runtime.Location) ([]byte, error) { + if i.OnGetCode == nil { + return nil, nil + } + return i.OnGetCode(location) +} + +func (i *TestRuntimeInterface) GetOrLoadProgram( + location runtime.Location, + load func() (*interpreter.Program, error), +) ( + program *interpreter.Program, + err error, +) { + if i.OnGetAndSetProgram == nil { + if i.Programs == nil { + i.Programs = map[runtime.Location]*interpreter.Program{} + } + + var ok bool + program, ok = i.Programs[location] + if ok { + return + } + + program, err = load() + + // NOTE: important: still set empty program, + // even if error occurred + + i.Programs[location] = program + + return + } + + return i.OnGetAndSetProgram(location, load) +} + +func (i *TestRuntimeInterface) SetInterpreterSharedState(state *interpreter.SharedState) { + if i.OnSetInterpreterSharedState == nil { + return + } + + i.OnSetInterpreterSharedState(state) +} + +func (i *TestRuntimeInterface) GetInterpreterSharedState() *interpreter.SharedState { + if i.OnGetInterpreterSharedState == nil { + return nil + } + + return i.OnGetInterpreterSharedState() +} + +func (i *TestRuntimeInterface) ValueExists(owner, key []byte) (exists bool, err error) { + return i.Storage.ValueExists(owner, key) +} + +func (i *TestRuntimeInterface) GetValue(owner, key []byte) (value []byte, err error) { + return i.Storage.GetValue(owner, key) +} + +func (i *TestRuntimeInterface) SetValue(owner, key, value []byte) (err error) { + return i.Storage.SetValue(owner, key, value) +} + +func (i *TestRuntimeInterface) AllocateStorageIndex(owner []byte) (atree.StorageIndex, error) { + return i.Storage.AllocateStorageIndex(owner) +} + +func (i *TestRuntimeInterface) CreateAccount(payer runtime.Address) (address runtime.Address, err error) { + if i.OnCreateAccount == nil { + panic("must specify TestRuntimeInterface.OnCreateAccount") + } + return i.OnCreateAccount(payer) +} + +func (i *TestRuntimeInterface) AddEncodedAccountKey(address runtime.Address, publicKey []byte) error { + if i.OnAddEncodedAccountKey == nil { + panic("must specify TestRuntimeInterface.OnAddEncodedAccountKey") + } + return i.OnAddEncodedAccountKey(address, publicKey) +} + +func (i *TestRuntimeInterface) RevokeEncodedAccountKey(address runtime.Address, index int) ([]byte, error) { + if i.OnRemoveEncodedAccountKey == nil { + 
panic("must specify TestRuntimeInterface.OnRemoveEncodedAccountKey") + } + return i.OnRemoveEncodedAccountKey(address, index) +} + +func (i *TestRuntimeInterface) AddAccountKey( + address runtime.Address, + publicKey *cadenceStdlib.PublicKey, + hashAlgo runtime.HashAlgorithm, + weight int, +) (*cadenceStdlib.AccountKey, error) { + if i.OnAddAccountKey == nil { + panic("must specify TestRuntimeInterface.OnAddAccountKey") + } + return i.OnAddAccountKey(address, publicKey, hashAlgo, weight) +} + +func (i *TestRuntimeInterface) GetAccountKey(address runtime.Address, index int) (*cadenceStdlib.AccountKey, error) { + if i.OnGetAccountKey == nil { + panic("must specify TestRuntimeInterface.OnGetAccountKey") + } + return i.OnGetAccountKey(address, index) +} + +func (i *TestRuntimeInterface) AccountKeysCount(address runtime.Address) (uint64, error) { + if i.OnAccountKeysCount == nil { + panic("must specify TestRuntimeInterface.OnAccountKeysCount") + } + return i.OnAccountKeysCount(address) +} + +func (i *TestRuntimeInterface) RevokeAccountKey(address runtime.Address, index int) (*cadenceStdlib.AccountKey, error) { + if i.OnRemoveAccountKey == nil { + panic("must specify TestRuntimeInterface.OnRemoveAccountKey") + } + return i.OnRemoveAccountKey(address, index) +} + +func (i *TestRuntimeInterface) UpdateAccountContractCode(location common.AddressLocation, code []byte) (err error) { + if i.OnUpdateAccountContractCode == nil { + panic("must specify TestRuntimeInterface.OnUpdateAccountContractCode") + } + + err = i.OnUpdateAccountContractCode(location, code) + if err != nil { + return err + } + + i.updatedContractCode = true + + return nil +} + +func (i *TestRuntimeInterface) GetAccountContractCode(location common.AddressLocation) (code []byte, err error) { + if i.OnGetAccountContractCode == nil { + panic("must specify TestRuntimeInterface.OnGetAccountContractCode") + } + return i.OnGetAccountContractCode(location) +} + +func (i *TestRuntimeInterface) RemoveAccountContractCode(location common.AddressLocation) (err error) { + if i.OnRemoveAccountContractCode == nil { + panic("must specify TestRuntimeInterface.OnRemoveAccountContractCode") + } + return i.OnRemoveAccountContractCode(location) +} + +func (i *TestRuntimeInterface) GetSigningAccounts() ([]runtime.Address, error) { + if i.OnGetSigningAccounts == nil { + return nil, nil + } + return i.OnGetSigningAccounts() +} + +func (i *TestRuntimeInterface) ProgramLog(message string) error { + i.OnProgramLog(message) + return nil +} + +func (i *TestRuntimeInterface) EmitEvent(event cadence.Event) error { + return i.OnEmitEvent(event) +} + +func (i *TestRuntimeInterface) ResourceOwnerChanged( + interpreter *interpreter.Interpreter, + resource *interpreter.CompositeValue, + oldOwner common.Address, + newOwner common.Address, +) { + if i.OnResourceOwnerChanged != nil { + i.OnResourceOwnerChanged( + interpreter, + resource, + oldOwner, + newOwner, + ) + } +} + +func (i *TestRuntimeInterface) GenerateUUID() (uint64, error) { + if i.OnGenerateUUID == nil { + i.lastUUID++ + return i.lastUUID, nil + } + return i.OnGenerateUUID() +} + +func (i *TestRuntimeInterface) MeterComputation(compKind common.ComputationKind, intensity uint) error { + if i.OnMeterComputation == nil { + return nil + } + return i.OnMeterComputation(compKind, intensity) +} + +func (i *TestRuntimeInterface) DecodeArgument(b []byte, t cadence.Type) (cadence.Value, error) { + if i.OnDecodeArgument == nil { + panic("must specify TestRuntimeInterface.OnDecodeArgument") + } + return 
i.OnDecodeArgument(b, t) +} + +func (i *TestRuntimeInterface) ProgramParsed(location runtime.Location, duration time.Duration) { + if i.OnProgramParsed == nil { + return + } + i.OnProgramParsed(location, duration) +} + +func (i *TestRuntimeInterface) ProgramChecked(location runtime.Location, duration time.Duration) { + if i.OnProgramChecked == nil { + return + } + i.OnProgramChecked(location, duration) +} + +func (i *TestRuntimeInterface) ProgramInterpreted(location runtime.Location, duration time.Duration) { + if i.OnProgramInterpreted == nil { + return + } + i.OnProgramInterpreted(location, duration) +} + +func (i *TestRuntimeInterface) GetCurrentBlockHeight() (uint64, error) { + return 1, nil +} + +func (i *TestRuntimeInterface) GetBlockAtHeight(height uint64) (block cadenceStdlib.Block, exists bool, err error) { + + buf := new(bytes.Buffer) + err = binary.Write(buf, binary.BigEndian, height) + if err != nil { + panic(err) + } + + encoded := buf.Bytes() + var hash cadenceStdlib.BlockHash + copy(hash[sema.BlockTypeIdFieldType.Size-int64(len(encoded)):], encoded) + + block = cadenceStdlib.Block{ + Height: height, + View: height, + Hash: hash, + Timestamp: time.Unix(int64(height), 0).UnixNano(), + } + return block, true, nil +} + +func (i *TestRuntimeInterface) ReadRandom(buffer []byte) error { + if i.OnReadRandom == nil { + return nil + } + return i.OnReadRandom(buffer) +} + +func (i *TestRuntimeInterface) VerifySignature( + signature []byte, + tag string, + signedData []byte, + publicKey []byte, + signatureAlgorithm runtime.SignatureAlgorithm, + hashAlgorithm runtime.HashAlgorithm, +) (bool, error) { + if i.OnVerifySignature == nil { + return false, nil + } + return i.OnVerifySignature( + signature, + tag, + signedData, + publicKey, + signatureAlgorithm, + hashAlgorithm, + ) +} + +func (i *TestRuntimeInterface) Hash(data []byte, tag string, hashAlgorithm runtime.HashAlgorithm) ([]byte, error) { + if i.OnHash == nil { + return nil, nil + } + return i.OnHash(data, tag, hashAlgorithm) +} + +func (i *TestRuntimeInterface) SetCadenceValue(owner common.Address, key string, value cadence.Value) (err error) { + if i.OnSetCadenceValue == nil { + panic("must specify TestRuntimeInterface.OnSetCadenceValue") + } + return i.OnSetCadenceValue(owner, key, value) +} + +func (i *TestRuntimeInterface) GetAccountBalance(address runtime.Address) (uint64, error) { + if i.OnGetAccountBalance == nil { + panic("must specify TestRuntimeInterface.OnGetAccountBalance") + } + return i.OnGetAccountBalance(address) +} + +func (i *TestRuntimeInterface) GetAccountAvailableBalance(address runtime.Address) (uint64, error) { + if i.OnGetAccountAvailableBalance == nil { + panic("must specify TestRuntimeInterface.OnGetAccountAvailableBalance") + } + return i.OnGetAccountAvailableBalance(address) +} + +func (i *TestRuntimeInterface) GetStorageUsed(address runtime.Address) (uint64, error) { + if i.OnGetStorageUsed == nil { + panic("must specify TestRuntimeInterface.OnGetStorageUsed") + } + return i.OnGetStorageUsed(address) +} + +func (i *TestRuntimeInterface) GetStorageCapacity(address runtime.Address) (uint64, error) { + if i.OnGetStorageCapacity == nil { + panic("must specify TestRuntimeInterface.OnGetStorageCapacity") + } + return i.OnGetStorageCapacity(address) +} + +func (i *TestRuntimeInterface) ImplementationDebugLog(message string) error { + if i.OnImplementationDebugLog == nil { + return nil + } + return i.OnImplementationDebugLog(message) +} + +func (i *TestRuntimeInterface) ValidatePublicKey(key 
*cadenceStdlib.PublicKey) error { + if i.OnValidatePublicKey == nil { + return errors.New("mock defaults to public key validation failure") + } + + return i.OnValidatePublicKey(key) +} + +func (i *TestRuntimeInterface) BLSVerifyPOP(key *cadenceStdlib.PublicKey, s []byte) (bool, error) { + if i.OnBLSVerifyPOP == nil { + return false, nil + } + + return i.OnBLSVerifyPOP(key, s) +} + +func (i *TestRuntimeInterface) BLSAggregateSignatures(sigs [][]byte) ([]byte, error) { + if i.OnBLSAggregateSignatures == nil { + return []byte{}, nil + } + + return i.OnBLSAggregateSignatures(sigs) +} + +func (i *TestRuntimeInterface) BLSAggregatePublicKeys(keys []*cadenceStdlib.PublicKey) (*cadenceStdlib.PublicKey, error) { + if i.OnBLSAggregatePublicKeys == nil { + return nil, nil + } + + return i.OnBLSAggregatePublicKeys(keys) +} + +func (i *TestRuntimeInterface) GetAccountContractNames(address runtime.Address) ([]string, error) { + if i.OnGetAccountContractNames == nil { + return []string{}, nil + } + + return i.OnGetAccountContractNames(address) +} + +func (i *TestRuntimeInterface) GenerateAccountID(address common.Address) (uint64, error) { + if i.OnGenerateAccountID == nil { + if i.accountIDs == nil { + i.accountIDs = map[common.Address]uint64{} + } + i.accountIDs[address]++ + return i.accountIDs[address], nil + } + + return i.OnGenerateAccountID(address) +} + +func (i *TestRuntimeInterface) RecordTrace( + operation string, + location runtime.Location, + duration time.Duration, + attrs []attribute.KeyValue, +) { + if i.OnRecordTrace == nil { + return + } + i.OnRecordTrace(operation, location, duration, attrs) +} + +func (i *TestRuntimeInterface) MeterMemory(usage common.MemoryUsage) error { + if i.OnMeterMemory == nil { + return nil + } + + return i.OnMeterMemory(usage) +} + +func (i *TestRuntimeInterface) ComputationUsed() (uint64, error) { + if i.OnComputationUsed == nil { + return 0, nil + } + + return i.OnComputationUsed() +} + +func (i *TestRuntimeInterface) MemoryUsed() (uint64, error) { + if i.OnMemoryUsed == nil { + return 0, nil + } + + return i.OnMemoryUsed() +} + +func (i *TestRuntimeInterface) InteractionUsed() (uint64, error) { + if i.OnInteractionUsed == nil { + return 0, nil + } + + return i.OnInteractionUsed() +} diff --git a/fvm/evm/testutils/contract.go b/fvm/evm/testutils/contract.go index a4984974455..688afe1f941 100644 --- a/fvm/evm/testutils/contract.go +++ b/fvm/evm/testutils/contract.go @@ -26,35 +26,19 @@ type TestContract struct { DeployedAt types.Address } -func (tc *TestContract) MakeStoreCallData(t *testing.T, num *big.Int) []byte { +func (tc *TestContract) MakeCallData(t testing.TB, name string, args ...interface{}) []byte { abi, err := gethABI.JSON(strings.NewReader(tc.ABI)) require.NoError(t, err) - store, err := abi.Pack("store", num) + call, err := abi.Pack(name, args...) 
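+       // Pack selects the named method from the ABI and encodes the call data (selector plus arguments)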
require.NoError(t, err) - return store -} - -func (tc *TestContract) MakeRetrieveCallData(t *testing.T) []byte { - abi, err := gethABI.JSON(strings.NewReader(tc.ABI)) - require.NoError(t, err) - retrieve, err := abi.Pack("retrieve") - require.NoError(t, err) - return retrieve -} - -func (tc *TestContract) MakeBlockNumberCallData(t *testing.T) []byte { - abi, err := gethABI.JSON(strings.NewReader(tc.ABI)) - require.NoError(t, err) - blockNum, err := abi.Pack("blockNumber") - require.NoError(t, err) - return blockNum + return call } func (tc *TestContract) SetDeployedAt(deployedAt types.Address) { tc.DeployedAt = deployedAt } -func GetTestContract(t *testing.T) *TestContract { +func GetStorageTestContract(t *testing.T) *TestContract { byteCodes, err := hex.DecodeString("608060405261022c806100136000396000f3fe608060405234801561001057600080fd5b50600436106100575760003560e01c80632e64cec11461005c57806348b151661461007a57806357e871e7146100985780636057361d146100b657806385df51fd146100d2575b600080fd5b610064610102565b6040516100719190610149565b60405180910390f35b61008261010b565b60405161008f9190610149565b60405180910390f35b6100a0610113565b6040516100ad9190610149565b60405180910390f35b6100d060048036038101906100cb9190610195565b61011b565b005b6100ec60048036038101906100e79190610195565b610125565b6040516100f991906101db565b60405180910390f35b60008054905090565b600042905090565b600043905090565b8060008190555050565b600081409050919050565b6000819050919050565b61014381610130565b82525050565b600060208201905061015e600083018461013a565b92915050565b600080fd5b61017281610130565b811461017d57600080fd5b50565b60008135905061018f81610169565b92915050565b6000602082840312156101ab576101aa610164565b5b60006101b984828501610180565b91505092915050565b6000819050919050565b6101d5816101c2565b82525050565b60006020820190506101f060008301846101cc565b9291505056fea26469706673582212203ee61567a25f0b1848386ae6b8fdbd7733c8a502c83b5ed305b921b7933f4e8164736f6c63430008120033") require.NoError(t, err) return &TestContract{ @@ -165,8 +149,238 @@ func GetTestContract(t *testing.T) *TestContract { } } -func RunWithDeployedContract(t *testing.T, led atree.Ledger, flowEVMRootAddress flow.Address, f func(*TestContract)) { - tc := GetTestContract(t) +func GetDummyKittyTestContract(t testing.TB) *TestContract { + byteCodes, err := 
hex.DecodeString("608060405234801561001057600080fd5b506107dd806100206000396000f3fe608060405234801561001057600080fd5b50600436106100415760003560e01c8063a45f4bfc14610046578063d0b169d114610076578063ddf252ad146100a6575b600080fd5b610060600480360381019061005b91906104e4565b6100c2565b60405161006d9190610552565b60405180910390f35b610090600480360381019061008b919061056d565b6100f5565b60405161009d91906105e3565b60405180910390f35b6100c060048036038101906100bb919061062a565b610338565b005b60026020528060005260406000206000915054906101000a900473ffffffffffffffffffffffffffffffffffffffff1681565b60008463ffffffff16851461010957600080fd5b8363ffffffff16841461011b57600080fd5b8261ffff16831461012b57600080fd5b60006040518060a001604052808481526020014267ffffffffffffffff1681526020018763ffffffff1681526020018663ffffffff1681526020018561ffff16815250905060018190806001815401808255809150506001900390600052602060002090600202016000909190919091506000820151816000015560208201518160010160006101000a81548167ffffffffffffffff021916908367ffffffffffffffff16021790555060408201518160010160086101000a81548163ffffffff021916908363ffffffff160217905550606082015181600101600c6101000a81548163ffffffff021916908363ffffffff16021790555060808201518160010160106101000a81548161ffff021916908361ffff16021790555050507fc1e409485f45287e73ab1623a8f2ef17af5eac1b4c792ee9ec466e8795e7c09133600054836040015163ffffffff16846060015163ffffffff16856000015160405161029995949392919061067d565b60405180910390a13073ffffffffffffffffffffffffffffffffffffffff1663ddf252ad6000336000546040518463ffffffff1660e01b81526004016102e1939291906106d0565b600060405180830381600087803b1580156102fb57600080fd5b505af115801561030f573d6000803e3d6000fd5b5050505060008081548092919061032590610736565b9190505550600054915050949350505050565b600360008373ffffffffffffffffffffffffffffffffffffffff1673ffffffffffffffffffffffffffffffffffffffff168152602001908152602001600020600081548092919061038890610736565b9190505550816002600083815260200190815260200160002060006101000a81548173ffffffffffffffffffffffffffffffffffffffff021916908373ffffffffffffffffffffffffffffffffffffffff160217905550600073ffffffffffffffffffffffffffffffffffffffff168373ffffffffffffffffffffffffffffffffffffffff161461046957600360008473ffffffffffffffffffffffffffffffffffffffff1673ffffffffffffffffffffffffffffffffffffffff16815260200190815260200160002060008154809291906104639061077e565b91905055505b7feaf1c4b3ce0f4f62a2bae7eb3e68225c75f7e6ff4422073b7437b9a78d25f17083838360405161049c939291906106d0565b60405180910390a1505050565b600080fd5b6000819050919050565b6104c1816104ae565b81146104cc57600080fd5b50565b6000813590506104de816104b8565b92915050565b6000602082840312156104fa576104f96104a9565b5b6000610508848285016104cf565b91505092915050565b600073ffffffffffffffffffffffffffffffffffffffff82169050919050565b600061053c82610511565b9050919050565b61054c81610531565b82525050565b60006020820190506105676000830184610543565b92915050565b60008060008060808587031215610587576105866104a9565b5b6000610595878288016104cf565b94505060206105a6878288016104cf565b93505060406105b7878288016104cf565b92505060606105c8878288016104cf565b91505092959194509250565b6105dd816104ae565b82525050565b60006020820190506105f860008301846105d4565b92915050565b61060781610531565b811461061257600080fd5b50565b600081359050610624816105fe565b92915050565b600080600060608486031215610643576106426104a9565b5b600061065186828701610615565b935050602061066286828701610615565b9250506040610673868287016104cf565b9150509250925092565b600060a0820190506106926000830188610543565b61069f60208301876105d4565b6106ac60408301866105d4565b6106b960608301856105d4565b6106c660808301846105d4565b969
5505050505050565b60006060820190506106e56000830186610543565b6106f26020830185610543565b6106ff60408301846105d4565b949350505050565b7f4e487b7100000000000000000000000000000000000000000000000000000000600052601160045260246000fd5b6000610741826104ae565b91507fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff820361077357610772610707565b5b600182019050919050565b6000610789826104ae565b91506000820361079c5761079b610707565b5b60018203905091905056fea2646970667358221220ab35c07ec72cc064a663de06ec7f5f919b1a499a25cf6ef0c63a45fdd4a1e91e64736f6c63430008120033") + require.NoError(t, err) + return &TestContract{ + Code: ` + contract DummyKitty { + + event BirthEvent(address owner, uint256 kittyId, uint256 matronId, uint256 sireId, uint256 genes); + event TransferEvent(address from, address to, uint256 tokenId); + + struct Kitty { + uint256 genes; + uint64 birthTime; + uint32 matronId; + uint32 sireId; + uint16 generation; + } + + uint256 idCounter; + + // @dev all kitties + Kitty[] kitties; + + /// @dev a mapping from cat IDs to the address that owns them. + mapping (uint256 => address) public kittyIndexToOwner; + + // @dev a mapping from owner address to count of tokens that address owns. + mapping (address => uint256) ownershipTokenCount; + + /// @dev a method to transfer kitty + function Transfer(address _from, address _to, uint256 _tokenId) external { + // Since the number of kittens is capped to 2^32 we can't overflow this + ownershipTokenCount[_to]++; + // transfer ownership + kittyIndexToOwner[_tokenId] = _to; + // When creating new kittens _from is 0x0, but we can't account that address. + if (_from != address(0)) { + ownershipTokenCount[_from]--; + } + // Emit the transfer event. + emit TransferEvent(_from, _to, _tokenId); + } + + /// @dev a method callable by anyone to create a kitty + function CreateKitty( + uint256 _matronId, + uint256 _sireId, + uint256 _generation, + uint256 _genes + ) + external + returns (uint) + { + + require(_matronId == uint256(uint32(_matronId))); + require(_sireId == uint256(uint32(_sireId))); + require(_generation == uint256(uint16(_generation))); + + Kitty memory _kitty = Kitty({ + genes: _genes, + birthTime: uint64(block.timestamp), + matronId: uint32(_matronId), + sireId: uint32(_sireId), + generation: uint16(_generation) + }); + + kitties.push(_kitty); + + emit BirthEvent( + msg.sender, + idCounter, + uint256(_kitty.matronId), + uint256(_kitty.sireId), + _kitty.genes + ); + + this.Transfer(address(0), msg.sender, idCounter); + + idCounter++; + + return idCounter; + } + } + `, + + ABI: ` + [ + { + "anonymous": false, + "inputs": [ + { + "indexed": false, + "internalType": "address", + "name": "owner", + "type": "address" + }, + { + "indexed": false, + "internalType": "uint256", + "name": "kittyId", + "type": "uint256" + }, + { + "indexed": false, + "internalType": "uint256", + "name": "matronId", + "type": "uint256" + }, + { + "indexed": false, + "internalType": "uint256", + "name": "sireId", + "type": "uint256" + }, + { + "indexed": false, + "internalType": "uint256", + "name": "genes", + "type": "uint256" + } + ], + "name": "BirthEvent", + "type": "event" + }, + { + "anonymous": false, + "inputs": [ + { + "indexed": false, + "internalType": "address", + "name": "from", + "type": "address" + }, + { + "indexed": false, + "internalType": "address", + "name": "to", + "type": "address" + }, + { + "indexed": false, + "internalType": "uint256", + "name": "tokenId", + "type": "uint256" + } + ], + "name": "TransferEvent", + "type": "event" + }, + { + "inputs": [ + { + 
"internalType": "uint256", + "name": "_matronId", + "type": "uint256" + }, + { + "internalType": "uint256", + "name": "_sireId", + "type": "uint256" + }, + { + "internalType": "uint256", + "name": "_generation", + "type": "uint256" + }, + { + "internalType": "uint256", + "name": "_genes", + "type": "uint256" + } + ], + "name": "CreateKitty", + "outputs": [ + { + "internalType": "uint256", + "name": "", + "type": "uint256" + } + ], + "stateMutability": "nonpayable", + "type": "function" + }, + { + "inputs": [ + { + "internalType": "address", + "name": "_from", + "type": "address" + }, + { + "internalType": "address", + "name": "_to", + "type": "address" + }, + { + "internalType": "uint256", + "name": "_tokenId", + "type": "uint256" + } + ], + "name": "Transfer", + "outputs": [], + "stateMutability": "nonpayable", + "type": "function" + }, + { + "inputs": [ + { + "internalType": "uint256", + "name": "", + "type": "uint256" + } + ], + "name": "kittyIndexToOwner", + "outputs": [ + { + "internalType": "address", + "name": "", + "type": "address" + } + ], + "stateMutability": "view", + "type": "function" + } + ] + `, + ByteCode: byteCodes, + } +} + +func RunWithDeployedContract(t testing.TB, tc *TestContract, led atree.Ledger, flowEVMRootAddress flow.Address, f func(*TestContract)) { // deploy contract db, err := database.NewDatabase(led, flowEVMRootAddress) require.NoError(t, err) @@ -185,7 +399,10 @@ func RunWithDeployedContract(t *testing.T, led atree.Ledger, flowEVMRootAddress ) require.NoError(t, err) - res, err := blk.DirectCall( + blk2, err := e.NewBlockView(types.NewDefaultBlockContext(3)) + require.NoError(t, err) + + res, err := blk2.DirectCall( types.NewDeployCall( caller, tc.ByteCode, diff --git a/fvm/evm/testutils/emulator.go b/fvm/evm/testutils/emulator.go index 48b3e2218d7..5f7f2ce3068 100644 --- a/fvm/evm/testutils/emulator.go +++ b/fvm/evm/testutils/emulator.go @@ -80,6 +80,10 @@ func RandomCommonHash(t testing.TB) gethCommon.Hash { return ret } +func RandomBigInt(limit int64) *big.Int { + return big.NewInt(rand.Int63n(limit) + 1) +} + func RandomAddress(t testing.TB) types.Address { return types.NewAddress(RandomCommonAddress(t)) } diff --git a/fvm/evm/types/address.go b/fvm/evm/types/address.go index afcaa72e246..134ae6c6cf8 100644 --- a/fvm/evm/types/address.go +++ b/fvm/evm/types/address.go @@ -17,6 +17,9 @@ func NewAddress(addr gethCommon.Address) Address { return Address(addr) } +// EmptyAddress is an empty evm address +var EmptyAddress = Address(gethCommon.Address{}) + // Bytes returns a byte slice for the address func (fa Address) Bytes() []byte { return fa[:] @@ -27,14 +30,16 @@ func (fa Address) ToCommon() gethCommon.Address { return gethCommon.Address(fa) } +// NewAddressFromBytes constructs a new address from bytes +func NewAddressFromBytes(inp []byte) Address { + return Address(gethCommon.BytesToAddress(inp)) +} + // NewAddressFromString constructs a new address from an string func NewAddressFromString(str string) Address { - return Address(gethCommon.BytesToAddress([]byte(str))) + return NewAddressFromBytes([]byte(str)) } -// EmptyAddress is an empty evm address -var EmptyAddress = Address(gethCommon.Address{}) - type GasLimit uint64 type Code []byte diff --git a/fvm/evm/types/events.go b/fvm/evm/types/events.go index fb3e802bb83..10490797775 100644 --- a/fvm/evm/types/events.go +++ b/fvm/evm/types/events.go @@ -21,6 +21,7 @@ type Event struct { Payload EventPayload } +// we might break this event into two (tx included /tx executed) if size becomes an issue type 
TransactionExecutedPayload struct { BlockHeight uint64 TxEncoded []byte diff --git a/fvm/evm/types/handler.go b/fvm/evm/types/handler.go index d3775f7987a..3badb5c6175 100644 --- a/fvm/evm/types/handler.go +++ b/fvm/evm/types/handler.go @@ -2,6 +2,7 @@ package types import ( gethCommon "github.com/ethereum/go-ethereum/common" + "github.com/onflow/cadence/runtime/common" "github.com/onflow/flow-go/fvm/environment" ) @@ -37,6 +38,8 @@ type ContractHandler interface { // Run runs a transaction in the evm environment, // collects the gas fees, and transfers the gas fees to the given coinbase account. Run(tx []byte, coinbase Address) + + FlowTokenAddress() common.Address } // Backend passes the FVM functionality needed inside the handler @@ -57,7 +60,7 @@ type BlockStore interface { // LatestBlock returns the latest appended block LatestBlock() (*Block, error) - // returns the hash of the block at the given height + // BlockHash returns the hash of the block at the given height BlockHash(height int) (gethCommon.Hash, error) // BlockProposal returns the block proposal @@ -66,6 +69,6 @@ type BlockStore interface { // CommitBlockProposal commits the block proposal and update the chain of blocks CommitBlockProposal() error - // Resets the block proposal + // ResetBlockProposal resets the block proposal ResetBlockProposal() error } diff --git a/fvm/evm/types/result.go b/fvm/evm/types/result.go index fb6f4087210..6e4248b2d58 100644 --- a/fvm/evm/types/result.go +++ b/fvm/evm/types/result.go @@ -5,7 +5,9 @@ import ( gethTypes "github.com/ethereum/go-ethereum/core/types" ) -// Result captures the result of an interaction with the emulator (direct call or evm tx) +// Result captures the result of an interaction with the emulator; +// it could be the output of a direct call or the output of running an +// evm transaction. // Its more comprehensive than typical evm receipt, usually // the receipt generation requires some extra calculation (e.g.
Deployed contract address) // but we take a different apporach here and include more data so that diff --git a/fvm/fvm_bench_test.go b/fvm/fvm_bench_test.go index 276c8cb69b8..2bbd9e9f311 100644 --- a/fvm/fvm_bench_test.go +++ b/fvm/fvm_bench_test.go @@ -11,18 +11,16 @@ import ( "github.com/ipfs/go-datastore" dssync "github.com/ipfs/go-datastore/sync" blockstore "github.com/ipfs/go-ipfs-blockstore" - "github.com/rs/zerolog" - "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/require" - "github.com/onflow/cadence" "github.com/onflow/cadence/encoding/ccf" jsoncdc "github.com/onflow/cadence/encoding/json" "github.com/onflow/cadence/runtime" + "github.com/rs/zerolog" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" flow2 "github.com/onflow/flow-go-sdk" "github.com/onflow/flow-go-sdk/templates" - "github.com/onflow/flow-go/engine/execution" "github.com/onflow/flow-go/engine/execution/computation" "github.com/onflow/flow-go/engine/execution/computation/committer" @@ -225,7 +223,7 @@ func NewBasicBlockExecutor(tb testing.TB, chain flow.Chain, logger zerolog.Logge me, prov, nil, - nil, + testutil.ProtocolStateWithSourceFixture(nil), 1) // We're interested in fvm's serial execution time require.NoError(tb, err) diff --git a/fvm/runtime/reusable_cadence_runtime.go b/fvm/runtime/reusable_cadence_runtime.go index 307d6959bb2..f2e5c941d77 100644 --- a/fvm/runtime/reusable_cadence_runtime.go +++ b/fvm/runtime/reusable_cadence_runtime.go @@ -28,15 +28,17 @@ var randomSourceFunctionType = &sema.FunctionType{ type ReusableCadenceRuntime struct { runtime.Runtime - runtime.Environment + TxRuntimeEnv runtime.Environment + ScriptRuntimeEnv runtime.Environment fvmEnv Environment } func NewReusableCadenceRuntime(rt runtime.Runtime, config runtime.Config) *ReusableCadenceRuntime { reusable := &ReusableCadenceRuntime{ - Runtime: rt, - Environment: runtime.NewBaseInterpreterEnvironment(config), + Runtime: rt, + TxRuntimeEnv: runtime.NewBaseInterpreterEnvironment(config), + ScriptRuntimeEnv: runtime.NewScriptInterpreterEnvironment(config), } // Declare the `randomSourceHistory` function. 
This function is **only** used by the @@ -78,7 +80,7 @@ func NewReusableCadenceRuntime(rt runtime.Runtime, config runtime.Config) *Reusa ), } - reusable.DeclareValue(blockRandomSource, nil) + reusable.TxRuntimeEnv.DeclareValue(blockRandomSource, nil) return reusable } @@ -99,7 +101,7 @@ func (reusable *ReusableCadenceRuntime) ReadStored( path, runtime.Context{ Interface: reusable.fvmEnv, - Environment: reusable.Environment, + Environment: reusable.TxRuntimeEnv, }, ) } @@ -120,7 +122,7 @@ func (reusable *ReusableCadenceRuntime) InvokeContractFunction( argumentTypes, runtime.Context{ Interface: reusable.fvmEnv, - Environment: reusable.Environment, + Environment: reusable.TxRuntimeEnv, }, ) } @@ -134,7 +136,7 @@ func (reusable *ReusableCadenceRuntime) NewTransactionExecutor( runtime.Context{ Interface: reusable.fvmEnv, Location: location, - Environment: reusable.Environment, + Environment: reusable.TxRuntimeEnv, }, ) } @@ -149,8 +151,9 @@ func (reusable *ReusableCadenceRuntime) ExecuteScript( return reusable.Runtime.ExecuteScript( script, runtime.Context{ - Interface: reusable.fvmEnv, - Location: location, + Interface: reusable.fvmEnv, + Location: location, + Environment: reusable.ScriptRuntimeEnv, }, ) } diff --git a/fvm/script.go b/fvm/script.go index 23d89027835..b8ae36c755e 100644 --- a/fvm/script.go +++ b/fvm/script.go @@ -10,6 +10,7 @@ import ( "github.com/onflow/flow-go/fvm/environment" "github.com/onflow/flow-go/fvm/errors" + "github.com/onflow/flow-go/fvm/evm" "github.com/onflow/flow-go/fvm/storage" "github.com/onflow/flow-go/fvm/storage/logical" "github.com/onflow/flow-go/model/flow" @@ -198,6 +199,20 @@ func (executor *scriptExecutor) executeScript() error { rt := executor.env.BorrowCadenceRuntime() defer executor.env.ReturnCadenceRuntime(rt) + if executor.ctx.EVMEnabled { + chain := executor.ctx.Chain + err := evm.SetupEnvironment( + chain.ChainID(), + executor.env, + rt.ScriptRuntimeEnv, + chain.ServiceAddress(), + FlowTokenAddress(chain), + ) + if err != nil { + return err + } + } + value, err := rt.ExecuteScript( runtime.Script{ Source: executor.proc.Script, diff --git a/fvm/storage/snapshot/execution_snapshot.go b/fvm/storage/snapshot/execution_snapshot.go index 89cabec443a..420c4ffccb4 100644 --- a/fvm/storage/snapshot/execution_snapshot.go +++ b/fvm/storage/snapshot/execution_snapshot.go @@ -37,6 +37,11 @@ func (snapshot *ExecutionSnapshot) UpdatedRegisters() flow.RegisterEntries { return entries } +// UpdatedRegisterSet returns all registers that were updated by this view. +func (snapshot *ExecutionSnapshot) UpdatedRegisterSet() map[flow.RegisterID]flow.RegisterValue { + return snapshot.WriteSet +} + // UpdatedRegisterIDs returns all register ids that were updated by this // view. The returned ids are unsorted. 
func (snapshot *ExecutionSnapshot) UpdatedRegisterIDs() []flow.RegisterID { diff --git a/fvm/transactionInvoker.go b/fvm/transactionInvoker.go index 68f3861c849..03ba76878e5 100644 --- a/fvm/transactionInvoker.go +++ b/fvm/transactionInvoker.go @@ -12,6 +12,7 @@ import ( "github.com/onflow/flow-go/fvm/environment" "github.com/onflow/flow-go/fvm/errors" + "github.com/onflow/flow-go/fvm/evm" reusableRuntime "github.com/onflow/flow-go/fvm/runtime" "github.com/onflow/flow-go/fvm/storage" "github.com/onflow/flow-go/fvm/storage/derived" @@ -224,6 +225,21 @@ func (executor *transactionExecutor) execute() error { } func (executor *transactionExecutor) ExecuteTransactionBody() error { + // setup evm + if executor.ctx.EVMEnabled { + chain := executor.ctx.Chain + err := evm.SetupEnvironment( + chain.ChainID(), + executor.env, + executor.cadenceRuntime.TxRuntimeEnv, + chain.ServiceAddress(), + FlowTokenAddress(chain), + ) + if err != nil { + return err + } + } + var invalidator derived.TransactionInvalidator if !executor.errs.CollectedError() { diff --git a/go.mod b/go.mod index bfc7cbcef83..7af3c3f6586 100644 --- a/go.mod +++ b/go.mod @@ -51,13 +51,13 @@ require ( github.com/multiformats/go-multiaddr-dns v0.3.1 github.com/multiformats/go-multihash v0.2.3 github.com/onflow/atree v0.6.0 - github.com/onflow/cadence v0.42.4 + github.com/onflow/cadence v0.42.5 github.com/onflow/flow v0.3.4 github.com/onflow/flow-core-contracts/lib/go/contracts v1.2.4-0.20231120143830-9e8417b56122 github.com/onflow/flow-core-contracts/lib/go/templates v1.2.4-0.20231111185227-240579784e9b - github.com/onflow/flow-go-sdk v0.41.14 + github.com/onflow/flow-go-sdk v0.41.16 github.com/onflow/flow-go/crypto v0.24.9 - github.com/onflow/flow/protobuf/go/flow v0.3.2-0.20231018182244-e72527c55c63 + github.com/onflow/flow/protobuf/go/flow v0.3.2-0.20231121210617-52ee94b830c2 github.com/onflow/go-bitswap v0.0.0-20230703214630-6d3db958c73d github.com/pbnjay/memory v0.0.0-20210728143218-7b4eea64cf58 github.com/pierrec/lz4 v2.6.1+incompatible diff --git a/go.sum b/go.sum index 94c29947bea..b07ae6a4f64 100644 --- a/go.sum +++ b/go.sum @@ -1325,8 +1325,8 @@ github.com/onflow/atree v0.1.0-beta1.0.20211027184039-559ee654ece9/go.mod h1:+6x github.com/onflow/atree v0.6.0 h1:j7nQ2r8npznx4NX39zPpBYHmdy45f4xwoi+dm37Jk7c= github.com/onflow/atree v0.6.0/go.mod h1:gBHU0M05qCbv9NN0kijLWMgC47gHVNBIp4KmsVFi0tc= github.com/onflow/cadence v0.20.1/go.mod h1:7mzUvPZUIJztIbr9eTvs+fQjWWHTF8veC+yk4ihcNIA= -github.com/onflow/cadence v0.42.4 h1:KoXnwPCMcjixZv+gHZwWkDiAyVExJhJJe6SebdnHNv8= -github.com/onflow/cadence v0.42.4/go.mod h1:raU8va8QRyTa/eUbhej4mbyW2ETePfSaywoo36MddgE= +github.com/onflow/cadence v0.42.5 h1:QCilotmJzfRToLd+02o3N62JIioSr8FfN7cujmR/IXQ= +github.com/onflow/cadence v0.42.5/go.mod h1:raU8va8QRyTa/eUbhej4mbyW2ETePfSaywoo36MddgE= github.com/onflow/flow v0.3.4 h1:FXUWVdYB90f/rjNcY0Owo30gL790tiYff9Pb/sycXYE= github.com/onflow/flow v0.3.4/go.mod h1:lzyAYmbu1HfkZ9cfnL5/sjrrsnJiUU8fRL26CqLP7+c= github.com/onflow/flow-core-contracts/lib/go/contracts v1.2.4-0.20231120143830-9e8417b56122 h1:yUzR59WUue8BN/bfwy0eN4YOLYXeQ3G9I53H0amxbDU= @@ -1336,16 +1336,16 @@ github.com/onflow/flow-core-contracts/lib/go/templates v1.2.4-0.20231111185227-2 github.com/onflow/flow-ft/lib/go/contracts v0.7.1-0.20230711213910-baad011d2b13 h1:B4ll7e3j+MqTJv2122Enq3RtDNzmIGRu9xjV7fo7un0= github.com/onflow/flow-ft/lib/go/contracts v0.7.1-0.20230711213910-baad011d2b13/go.mod h1:kTMFIySzEJJeupk+7EmXs0EJ6CBWY/MV9fv9iYQk+RU= github.com/onflow/flow-go-sdk v0.24.0/go.mod 
h1:IoptMLPyFXWvyd9yYA6/4EmSeeozl6nJoIv4FaEMg74= -github.com/onflow/flow-go-sdk v0.41.14 h1:Pe9hogrRWkICa34Gu8yLLCMlTnu5f3L7DNXeXREkkVw= -github.com/onflow/flow-go-sdk v0.41.14/go.mod h1:yOyiwL7bt8N6FpLAyNO2D46vnFlLnIu9Cyt5PKIgqAk= +github.com/onflow/flow-go-sdk v0.41.16 h1:HsmHwEVmj+iK+GszHbFseHh7Ii5W3PWOIRNAH/En08Q= +github.com/onflow/flow-go-sdk v0.41.16/go.mod h1:bVrVNoJKiwB6vW5Qbm5tFAfJBQ5we4uSQWnn9gNAFhQ= github.com/onflow/flow-go/crypto v0.21.3/go.mod h1:vI6V4CY3R6c4JKBxdcRiR/AnjBfL8OSD97bJc60cLuQ= github.com/onflow/flow-go/crypto v0.24.9 h1:0EQp+kSZYJepMIiSypfJVe7tzsPcb6UXOdOtsTCDhBs= github.com/onflow/flow-go/crypto v0.24.9/go.mod h1:fqCzkIBBMRRkciVrvW21rECKq1oD7Q6u+bCI78lfNX0= github.com/onflow/flow-nft/lib/go/contracts v1.1.0 h1:rhUDeD27jhLwOqQKI/23008CYfnqXErrJvc4EFRP2a0= github.com/onflow/flow-nft/lib/go/contracts v1.1.0/go.mod h1:YsvzYng4htDgRB9sa9jxdwoTuuhjK8WYWXTyLkIigZY= github.com/onflow/flow/protobuf/go/flow v0.2.2/go.mod h1:gQxYqCfkI8lpnKsmIjwtN2mV/N2PIwc1I+RUK4HPIc8= -github.com/onflow/flow/protobuf/go/flow v0.3.2-0.20231018182244-e72527c55c63 h1:SX8OhYbyKBExhy4qEDR/Hw6MVTBTzlDb8LfCHfFyte4= -github.com/onflow/flow/protobuf/go/flow v0.3.2-0.20231018182244-e72527c55c63/go.mod h1:NA2pX2nw8zuaxfKphhKsk00kWLwfd+tv8mS23YXO4Sk= +github.com/onflow/flow/protobuf/go/flow v0.3.2-0.20231121210617-52ee94b830c2 h1:qZjl4wSTG/E9znEjkHF0nNaEdlBLJoOEAtr7xUsTNqc= +github.com/onflow/flow/protobuf/go/flow v0.3.2-0.20231121210617-52ee94b830c2/go.mod h1:NA2pX2nw8zuaxfKphhKsk00kWLwfd+tv8mS23YXO4Sk= github.com/onflow/go-bitswap v0.0.0-20230703214630-6d3db958c73d h1:QcOAeEyF3iAUHv21LQ12sdcsr0yFrJGoGLyCAzYYtvI= github.com/onflow/go-bitswap v0.0.0-20230703214630-6d3db958c73d/go.mod h1:GCPpiyRoHncdqPj++zPr9ZOYBX4hpJ0pYZRYqSE8VKk= github.com/onflow/sdks v0.5.0 h1:2HCRibwqDaQ1c9oUApnkZtEAhWiNY2GTpRD5+ftdkN8= diff --git a/insecure/Makefile b/insecure/Makefile index 72a38cf4b4d..70eeff5a46f 100644 --- a/insecure/Makefile +++ b/insecure/Makefile @@ -12,3 +12,18 @@ endif .PHONY: test test: go test $(if $(VERBOSE),-v,) -coverprofile=$(COVER_PROFILE) $(RACE_FLAG) $(if $(JSON_OUTPUT),-json,) $(if $(NUM_RUNS),-count $(NUM_RUNS),) --tags relic ./... + +.PHONY: lint +lint: tidy + # revive -config revive.toml -exclude storage/ledger/trie ./... + golangci-lint run -v --build-tags relic ./... 
+ +# this ensures there is no unused dependency being added by accident +.PHONY: tidy +tidy: + go mod tidy -v + cd integration; go mod tidy -v + cd crypto; go mod tidy -v + cd cmd/testclient; go mod tidy -v + cd insecure; go mod tidy -v + git diff --exit-code \ No newline at end of file diff --git a/insecure/dependency_test.go b/insecure/dependency_test.go new file mode 100644 index 00000000000..a2375847be9 --- /dev/null +++ b/insecure/dependency_test.go @@ -0,0 +1,8 @@ +package insecure + +import "github.com/btcsuite/btcd/chaincfg/chainhash" + +// this is added to resolve the issue with chainhash ambiguous import, +// the code is not used, but it's needed to force go.mod specify and retain chainhash version +// workaround for issue: https://github.com/golang/go/issues/27899 +var _ = chainhash.Hash{} diff --git a/insecure/go.mod b/insecure/go.mod index 3f3a246ec33..228bad8a836 100644 --- a/insecure/go.mod +++ b/insecure/go.mod @@ -3,6 +3,7 @@ module github.com/onflow/flow-go/insecure go 1.20 require ( + github.com/btcsuite/btcd/chaincfg/chainhash v1.0.2 github.com/golang/protobuf v1.5.3 github.com/hashicorp/go-multierror v1.1.1 github.com/ipfs/go-datastore v0.6.0 @@ -27,6 +28,8 @@ require ( cloud.google.com/go/iam v1.1.1 // indirect cloud.google.com/go/storage v1.30.1 // indirect github.com/DataDog/zstd v1.5.2 // indirect + github.com/StackExchange/wmi v0.0.0-20180116203802-5d049714c4a6 // indirect + github.com/VictoriaMetrics/fastcache v1.6.0 // indirect github.com/aws/aws-sdk-go-v2 v1.17.7 // indirect github.com/aws/aws-sdk-go-v2/config v1.18.19 // indirect github.com/aws/aws-sdk-go-v2/credentials v1.13.18 // indirect @@ -61,6 +64,7 @@ require ( github.com/cskr/pubsub v1.0.2 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/davidlazar/go-crypto v0.0.0-20200604182044-b73af7476f6c // indirect + github.com/deckarep/golang-set/v2 v2.1.0 // indirect github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0 // indirect github.com/desertbit/timer v0.0.0-20180107155436-c41aec40b27f // indirect github.com/dgraph-io/badger/v2 v2.2007.4 // indirect @@ -89,8 +93,10 @@ require ( github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-playground/validator/v10 v10.14.1 // indirect + github.com/go-stack/stack v1.8.1 // indirect github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect github.com/godbus/dbus/v5 v5.1.0 // indirect + github.com/gofrs/flock v0.8.1 // indirect github.com/gogo/protobuf v1.3.2 // indirect github.com/golang/glog v1.1.0 // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect @@ -113,6 +119,7 @@ require ( github.com/hashicorp/golang-lru v0.5.4 // indirect github.com/hashicorp/golang-lru/v2 v2.0.2 // indirect github.com/hashicorp/hcl v1.0.0 // indirect + github.com/holiman/bloomfilter/v2 v2.0.3 // indirect github.com/holiman/uint256 v1.2.2-0.20230321075855-87b91420868c // indirect github.com/huin/goupnp v1.2.0 // indirect github.com/improbable-eng/grpc-web v0.15.0 // indirect @@ -189,14 +196,15 @@ require ( github.com/multiformats/go-multihash v0.2.3 // indirect github.com/multiformats/go-multistream v0.4.1 // indirect github.com/multiformats/go-varint v0.0.7 // indirect + github.com/olekukonko/tablewriter v0.0.5 // indirect github.com/onflow/atree v0.6.0 // indirect - github.com/onflow/cadence v0.42.4 // indirect + github.com/onflow/cadence v0.42.5 // indirect github.com/onflow/flow-core-contracts/lib/go/contracts 
v1.2.4-0.20231120143830-9e8417b56122 // indirect github.com/onflow/flow-core-contracts/lib/go/templates v1.2.4-0.20231111185227-240579784e9b // indirect github.com/onflow/flow-ft/lib/go/contracts v0.7.1-0.20230711213910-baad011d2b13 // indirect - github.com/onflow/flow-go-sdk v0.41.14 // indirect + github.com/onflow/flow-go-sdk v0.41.16 // indirect github.com/onflow/flow-nft/lib/go/contracts v1.1.0 // indirect - github.com/onflow/flow/protobuf/go/flow v0.3.2-0.20231018182244-e72527c55c63 // indirect + github.com/onflow/flow/protobuf/go/flow v0.3.2-0.20231121210617-52ee94b830c2 // indirect github.com/onflow/go-bitswap v0.0.0-20230703214630-6d3db958c73d // indirect github.com/onflow/sdks v0.5.0 // indirect github.com/onflow/wal v0.0.0-20230529184820-bc9f8244608d // indirect @@ -226,6 +234,7 @@ require ( github.com/rs/cors v1.8.0 // indirect github.com/schollz/progressbar/v3 v3.13.1 // indirect github.com/sethvargo/go-retry v0.2.3 // indirect + github.com/shirou/gopsutil v3.21.4-0.20210419000835-c7a38de76ee5+incompatible // indirect github.com/shirou/gopsutil/v3 v3.22.2 // indirect github.com/slok/go-http-metrics v0.10.0 // indirect github.com/sony/gobreaker v0.5.0 // indirect @@ -237,6 +246,7 @@ require ( github.com/spf13/viper v1.15.0 // indirect github.com/stretchr/objx v0.5.0 // indirect github.com/subosito/gotenv v1.4.2 // indirect + github.com/syndtr/goleveldb v1.0.1-0.20210819022825-2ae1ddf74ef7 // indirect github.com/texttheater/golang-levenshtein/levenshtein v0.0.0-20200805054039-cae8b0eaed6c // indirect github.com/tklauser/go-sysconf v0.3.9 // indirect github.com/tklauser/numcpus v0.3.0 // indirect @@ -281,6 +291,7 @@ require ( google.golang.org/genproto/googleapis/rpc v0.0.0-20230711160842-782d3b101e98 // indirect google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.2.0 // indirect gopkg.in/ini.v1 v1.67.0 // indirect + gopkg.in/natefinch/npipe.v2 v2.0.0-20160621034901-c1b8fa8bdcce // indirect gopkg.in/yaml.v3 v3.0.1 // indirect lukechampine.com/blake3 v1.2.1 // indirect nhooyr.io/websocket v1.8.7 // indirect diff --git a/insecure/go.sum b/insecure/go.sum index 613e5fffa3f..faa95c626a2 100644 --- a/insecure/go.sum +++ b/insecure/go.sum @@ -96,8 +96,11 @@ github.com/OneOfOne/xxhash v1.2.5/go.mod h1:eZbhyaAYD41SGSSsnmcpxVoRiQ/MPUTjUdII github.com/Shopify/goreferrer v0.0.0-20181106222321-ec9c9a553398/go.mod h1:a1uqRtAwp2Xwc6WNPJEufxJ7fx3npB4UV/JOLmbu5I0= github.com/Shopify/sarama v1.19.0/go.mod h1:FVkBWblsNy7DGZRfXLU0O9RCGt5g3g3yEuWXgklEdEo= github.com/Shopify/toxiproxy v2.1.4+incompatible/go.mod h1:OXgGpZ6Cli1/URJOF1DMxUHB2q5Ap20/P/eIdh4G0pI= +github.com/StackExchange/wmi v0.0.0-20180116203802-5d049714c4a6 h1:fLjPD/aNc3UIOA6tDi6QXUemppXK3P9BI7mr2hd6gx8= github.com/StackExchange/wmi v0.0.0-20180116203802-5d049714c4a6/go.mod h1:3eOhrUMpNV+6aFIbp5/iudMxNCF27Vw2OZgy4xEx0Fg= github.com/VictoriaMetrics/fastcache v1.5.3/go.mod h1:+jv9Ckb+za/P1ZRg/sulP5Ni1v49daAVERr0H3CuscE= +github.com/VictoriaMetrics/fastcache v1.6.0 h1:C/3Oi3EiBCqufydp1neRZkqcwmEiuRT9c3fqvvgKm5o= +github.com/VictoriaMetrics/fastcache v1.6.0/go.mod h1:0qHz5QP0GMX4pfmMA/zt5RgfNuXJrTP0zS7DqpHGGTw= github.com/VividCortex/gohistogram v1.0.0/go.mod h1:Pf5mBqqDxYaXu3hDrrU+w6nw50o/4+TcAqDqk/vUH7g= github.com/aead/siphash v1.0.1/go.mod h1:Nywa3cDsYNNK3gaciGTWPwHt0wlpNV15vwmswBAUSII= github.com/afex/hystrix-go v0.0.0-20180502004556-fa1af6a1f4f5/go.mod h1:SkGFH1ia65gfNATL8TAiHDNxPzPdmEL5uirI2Uyuz6c= @@ -108,6 +111,7 @@ github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuy github.com/alecthomas/units 
v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d/go.mod h1:rBZYJk541a8SKzHPHnH3zbiI+7dagKZ0cgpgrD7Fyho= +github.com/allegro/bigcache v1.2.1-0.20190218064605-e24eb225f156 h1:eMwmnE/GDgah4HI848JfFxHt+iPb26b4zyfspmqY0/8= github.com/allegro/bigcache v1.2.1-0.20190218064605-e24eb225f156/go.mod h1:Cb/ax3seSYIx7SuZdm2G2xzfwmv3TPSk2ucNfQESPXM= github.com/anmitsu/go-shlex v0.0.0-20161002113705-648efa622239/go.mod h1:2FmKhYUyUczH0OGQWaF5ceTx0UBShxjsH6f8oGKYe2c= github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY= @@ -185,6 +189,8 @@ github.com/btcsuite/btcd v0.20.1-beta/go.mod h1:wVuoA8VJLEcwgqHBwHmzLRazpKxTv13P github.com/btcsuite/btcd v0.21.0-beta/go.mod h1:ZSWyehm27aAuS9bvkATT+Xte3hjHZ+MRgMY/8NJ7K94= github.com/btcsuite/btcd/btcec/v2 v2.2.1 h1:xP60mv8fvp+0khmrN0zTdPC3cNm24rfeE6lh2R/Yv3E= github.com/btcsuite/btcd/btcec/v2 v2.2.1/go.mod h1:9/CSmJxmuvqzX9Wh2fXMWToLOHhPd11lSPuIupwTkI8= +github.com/btcsuite/btcd/chaincfg/chainhash v1.0.2 h1:KdUfX2zKommPRa+PD0sWZUyXe9w277ABlgELO7H04IM= +github.com/btcsuite/btcd/chaincfg/chainhash v1.0.2/go.mod h1:7SFka0XMvUgj3hfZtydOrQY2mwhPclbT2snogU7SQQc= github.com/btcsuite/btclog v0.0.0-20170628155309-84c8d2346e9f/go.mod h1:TdznJufoqS23FtqVCzL0ZqgP5MqXbb4fg/WgDys70nA= github.com/btcsuite/btcutil v0.0.0-20190207003914-4c204d697803/go.mod h1:+5NJ2+qvTyV9exUAL/rxXi3DcLg2Ts+ymUAY5y4NvMg= github.com/btcsuite/btcutil v0.0.0-20190425235716-9e5f4b9a998d/go.mod h1:+5NJ2+qvTyV9exUAL/rxXi3DcLg2Ts+ymUAY5y4NvMg= @@ -277,6 +283,8 @@ github.com/davidlazar/go-crypto v0.0.0-20170701192655-dcfb0a7ac018/go.mod h1:rQY github.com/davidlazar/go-crypto v0.0.0-20200604182044-b73af7476f6c h1:pFUpOrbxDR6AkioZ1ySsx5yxlDQZ8stG2b88gTPxgJU= github.com/davidlazar/go-crypto v0.0.0-20200604182044-b73af7476f6c/go.mod h1:6UhI8N9EjYm1c2odKpFpAYeR8dsBeM7PtzQhRgxRr9U= github.com/deckarep/golang-set v0.0.0-20180603214616-504e848d77ea/go.mod h1:93vsz/8Wt4joVM7c2AVqh+YRMiUSc14yDtF28KmMOgQ= +github.com/deckarep/golang-set/v2 v2.1.0 h1:g47V4Or+DUdzbs8FxCCmgb6VYd+ptPAngjM6dtGktsI= +github.com/deckarep/golang-set/v2 v2.1.0/go.mod h1:VAky9rY/yGXJOLEDv3OMci+7wtDpOF4IN+y82NBOac4= github.com/decred/dcrd/crypto/blake256 v1.0.1 h1:7PltbUIQB7u/FfZ39+DGa/ShuMyJ5ilcvdfma9wOH6Y= github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0 h1:8UrgZ3GkP4i/CLijOJx79Yu+etlyjdBU4sfcs2WYQMs= github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0/go.mod h1:v57UDF4pDQJcEfFUCRop3lJL149eHGSe9Jvczhzjo/0= @@ -421,6 +429,8 @@ github.com/go-playground/validator/v10 v10.14.1 h1:9c50NUPC30zyuKprjL3vNZ0m5oG+j github.com/go-playground/validator/v10 v10.14.1/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU= github.com/go-sql-driver/mysql v1.4.0/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= +github.com/go-stack/stack v1.8.1 h1:ntEHSVwIt7PNXNpgPmVfMrNhLtgjlmnZha2kOpuRiDw= +github.com/go-stack/stack v1.8.1/go.mod h1:dcoOX6HbPZSZptuspn9bctJ+N/CnF5gGygcUP3XYfe4= github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI= github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls= github.com/go-test/deep v1.0.5/go.mod h1:QV8Hv/iy04NyLBxAdO9njL0iVPN1S4d/A3NVv1V36o8= @@ -436,6 +446,8 @@ 
github.com/godbus/dbus/v5 v5.0.3/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5x github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/godbus/dbus/v5 v5.1.0 h1:4KLkAxT3aOY8Li4FRJe/KvhoNFFxo0m6fNuFUO8QJUk= github.com/godbus/dbus/v5 v5.1.0/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= +github.com/gofrs/flock v0.8.1 h1:+gYjHKf32LDeiEEFhQaotPbLuUXjY5ZqxKgXy7n59aw= +github.com/gofrs/flock v0.8.1/go.mod h1:F1TvTiK9OcQqauNUHlbJvyl9Qa1QvF/gOUDKA14jxHU= github.com/gogo/googleapis v0.0.0-20180223154316-0cd9801be74a/go.mod h1:gf4bu3Q80BeJ6H1S1vYPm8/ELATdvryBaNFGgqEef3s= github.com/gogo/googleapis v1.1.0/go.mod h1:gf4bu3Q80BeJ6H1S1vYPm8/ELATdvryBaNFGgqEef3s= github.com/gogo/googleapis v1.4.1/go.mod h1:2lpHqI5OcWCtVElxXnPt+s8oJvMpySlOyM6xDCrzib4= @@ -496,6 +508,7 @@ github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiu github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/golang/snappy v0.0.5-0.20220116011046-fa5810519dcb h1:PBC98N2aIaM3XXiurYmW7fx4GZkL8feAMVq7nEjURHk= github.com/golang/snappy v0.0.5-0.20220116011046-fa5810519dcb/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/gomodule/redigo v1.7.1-0.20190724094224-574c33c3df38/go.mod h1:B4C85qUVwatsJoIUNIfCRsp7qO0iAmpGFZ4EELWSbC4= @@ -631,6 +644,8 @@ github.com/hashicorp/logutils v1.0.0/go.mod h1:QIAnNjmIWmVIIkWDTG1z5v++HQmx9WQRO github.com/hashicorp/mdns v1.0.0/go.mod h1:tL+uN++7HEJ6SQLQ2/p+z2pH24WQKWjBPkE0mNTz8vQ= github.com/hashicorp/memberlist v0.1.3/go.mod h1:ajVTdAv/9Im8oMAAj5G31PhhMCZJV2pPBoIllUwCN7I= github.com/hashicorp/serf v0.8.2/go.mod h1:6hOLApaqBFA1NXqRQAsxw9QxuDEvNxSQRwA/JwenrHc= +github.com/holiman/bloomfilter/v2 v2.0.3 h1:73e0e/V0tCydx14a0SCYS/EWCxgwLZ18CZcZKVu0fao= +github.com/holiman/bloomfilter/v2 v2.0.3/go.mod h1:zpoh+gs7qcpqrHr3dB55AMiJwo0iURXE7ZOP9L9hSkA= github.com/holiman/uint256 v1.2.2-0.20230321075855-87b91420868c h1:DZfsyhDK1hnSS5lH8l+JggqzEleHteTYfutAiVlSUM8= github.com/holiman/uint256 v1.2.2-0.20230321075855-87b91420868c/go.mod h1:SC8Ryt4n+UBbPbIBKaG9zbbDlp4jOru9xFZmPzLUTxw= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= @@ -864,6 +879,7 @@ github.com/kr/pty v1.1.3/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/labstack/echo/v4 v4.5.0/go.mod h1:czIriw4a0C1dFun+ObrXp7ok03xON0N1awStJ6ArI7Y= github.com/labstack/gommon v0.3.0/go.mod h1:MULnywXg0yavhxWKc+lOruYdAhDwPK9wf0OL7NoOu+k= @@ -1271,6 +1287,7 @@ github.com/nats-io/nkeys v0.1.3/go.mod h1:xpnFELMwJABBLVhffcfd1MZx6VsNRFpEugbxzi github.com/nats-io/nuid v1.0.1/go.mod h1:19wcPz3Ph3q0Jbyiqsd0kePYG7A95tJPxeL+1OSON2c= github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86/go.mod h1:kHJEU3ofeGjhHklVoIGuVj85JJwZ6kWPaJwCIxgnFmo= github.com/neelance/sourcemap 
v0.0.0-20151028013722-8c68805598ab/go.mod h1:Qr6/a/Q4r9LP1IltGz7tA7iOK1WonHEYhu1HRBA7ZiM= +github.com/nxadm/tail v1.4.4 h1:DQuhQpB1tVlglWS2hLQ5OV6B5r8aGxSrPc5Qo6uTN78= github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A= github.com/oklog/oklog v0.3.2/go.mod h1:FCV+B7mhrz4o+ueLpx+KqkyXRGMWOYEvfiXtdGtbWGs= github.com/oklog/run v1.0.0/go.mod h1:dlhp/R75TPv97u0XWUtDeV/lRKWPKSdTuV0TZvrmrQA= @@ -1278,12 +1295,14 @@ github.com/oklog/ulid v1.3.1/go.mod h1:CirwcVhetQ6Lv90oh/F+FBtV6XMibvdAFo93nm5qn github.com/olekukonko/tablewriter v0.0.0-20170122224234-a0225b3f23b5/go.mod h1:vsDQFd/mU46D+Z4whnwzcISnGGzXWMclvtLoiIKAKIo= github.com/olekukonko/tablewriter v0.0.1/go.mod h1:vsDQFd/mU46D+Z4whnwzcISnGGzXWMclvtLoiIKAKIo= github.com/olekukonko/tablewriter v0.0.2-0.20190409134802-7e037d187b0c/go.mod h1:vsDQFd/mU46D+Z4whnwzcISnGGzXWMclvtLoiIKAKIo= +github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec= +github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY= github.com/onflow/atree v0.1.0-beta1.0.20211027184039-559ee654ece9/go.mod h1:+6x071HgCF/0v5hQcaE5qqjc2UqN5gCU8h5Mk6uqpOg= github.com/onflow/atree v0.6.0 h1:j7nQ2r8npznx4NX39zPpBYHmdy45f4xwoi+dm37Jk7c= github.com/onflow/atree v0.6.0/go.mod h1:gBHU0M05qCbv9NN0kijLWMgC47gHVNBIp4KmsVFi0tc= github.com/onflow/cadence v0.20.1/go.mod h1:7mzUvPZUIJztIbr9eTvs+fQjWWHTF8veC+yk4ihcNIA= -github.com/onflow/cadence v0.42.4 h1:KoXnwPCMcjixZv+gHZwWkDiAyVExJhJJe6SebdnHNv8= -github.com/onflow/cadence v0.42.4/go.mod h1:raU8va8QRyTa/eUbhej4mbyW2ETePfSaywoo36MddgE= +github.com/onflow/cadence v0.42.5 h1:QCilotmJzfRToLd+02o3N62JIioSr8FfN7cujmR/IXQ= +github.com/onflow/cadence v0.42.5/go.mod h1:raU8va8QRyTa/eUbhej4mbyW2ETePfSaywoo36MddgE= github.com/onflow/flow-core-contracts/lib/go/contracts v1.2.4-0.20231120143830-9e8417b56122 h1:yUzR59WUue8BN/bfwy0eN4YOLYXeQ3G9I53H0amxbDU= github.com/onflow/flow-core-contracts/lib/go/contracts v1.2.4-0.20231120143830-9e8417b56122/go.mod h1:jM6GMAL+m0hjusUgiYDNrixPQ6b9s8xjoJQoEu5bHQI= github.com/onflow/flow-core-contracts/lib/go/templates v1.2.4-0.20231111185227-240579784e9b h1:Q9iCekuCTeZU3CkVRTj5BhMBY/vR/uA2K63JTl5vCD8= @@ -1291,16 +1310,16 @@ github.com/onflow/flow-core-contracts/lib/go/templates v1.2.4-0.20231111185227-2 github.com/onflow/flow-ft/lib/go/contracts v0.7.1-0.20230711213910-baad011d2b13 h1:B4ll7e3j+MqTJv2122Enq3RtDNzmIGRu9xjV7fo7un0= github.com/onflow/flow-ft/lib/go/contracts v0.7.1-0.20230711213910-baad011d2b13/go.mod h1:kTMFIySzEJJeupk+7EmXs0EJ6CBWY/MV9fv9iYQk+RU= github.com/onflow/flow-go-sdk v0.24.0/go.mod h1:IoptMLPyFXWvyd9yYA6/4EmSeeozl6nJoIv4FaEMg74= -github.com/onflow/flow-go-sdk v0.41.14 h1:Pe9hogrRWkICa34Gu8yLLCMlTnu5f3L7DNXeXREkkVw= -github.com/onflow/flow-go-sdk v0.41.14/go.mod h1:yOyiwL7bt8N6FpLAyNO2D46vnFlLnIu9Cyt5PKIgqAk= +github.com/onflow/flow-go-sdk v0.41.16 h1:HsmHwEVmj+iK+GszHbFseHh7Ii5W3PWOIRNAH/En08Q= +github.com/onflow/flow-go-sdk v0.41.16/go.mod h1:bVrVNoJKiwB6vW5Qbm5tFAfJBQ5we4uSQWnn9gNAFhQ= github.com/onflow/flow-go/crypto v0.21.3/go.mod h1:vI6V4CY3R6c4JKBxdcRiR/AnjBfL8OSD97bJc60cLuQ= github.com/onflow/flow-go/crypto v0.24.9 h1:0EQp+kSZYJepMIiSypfJVe7tzsPcb6UXOdOtsTCDhBs= github.com/onflow/flow-go/crypto v0.24.9/go.mod h1:fqCzkIBBMRRkciVrvW21rECKq1oD7Q6u+bCI78lfNX0= github.com/onflow/flow-nft/lib/go/contracts v1.1.0 h1:rhUDeD27jhLwOqQKI/23008CYfnqXErrJvc4EFRP2a0= github.com/onflow/flow-nft/lib/go/contracts v1.1.0/go.mod h1:YsvzYng4htDgRB9sa9jxdwoTuuhjK8WYWXTyLkIigZY= 
github.com/onflow/flow/protobuf/go/flow v0.2.2/go.mod h1:gQxYqCfkI8lpnKsmIjwtN2mV/N2PIwc1I+RUK4HPIc8= -github.com/onflow/flow/protobuf/go/flow v0.3.2-0.20231018182244-e72527c55c63 h1:SX8OhYbyKBExhy4qEDR/Hw6MVTBTzlDb8LfCHfFyte4= -github.com/onflow/flow/protobuf/go/flow v0.3.2-0.20231018182244-e72527c55c63/go.mod h1:NA2pX2nw8zuaxfKphhKsk00kWLwfd+tv8mS23YXO4Sk= +github.com/onflow/flow/protobuf/go/flow v0.3.2-0.20231121210617-52ee94b830c2 h1:qZjl4wSTG/E9znEjkHF0nNaEdlBLJoOEAtr7xUsTNqc= +github.com/onflow/flow/protobuf/go/flow v0.3.2-0.20231121210617-52ee94b830c2/go.mod h1:NA2pX2nw8zuaxfKphhKsk00kWLwfd+tv8mS23YXO4Sk= github.com/onflow/go-bitswap v0.0.0-20230703214630-6d3db958c73d h1:QcOAeEyF3iAUHv21LQ12sdcsr0yFrJGoGLyCAzYYtvI= github.com/onflow/go-bitswap v0.0.0-20230703214630-6d3db958c73d/go.mod h1:GCPpiyRoHncdqPj++zPr9ZOYBX4hpJ0pYZRYqSE8VKk= github.com/onflow/sdks v0.5.0 h1:2HCRibwqDaQ1c9oUApnkZtEAhWiNY2GTpRD5+ftdkN8= @@ -1314,6 +1333,7 @@ github.com/onsi/ginkgo v1.10.3/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+ github.com/onsi/ginkgo v1.12.0/go.mod h1:oUhWkIvk5aDxtKvDDuw8gItl8pKl42LzjC9KZE0HfGg= github.com/onsi/ginkgo v1.12.1/go.mod h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108oapk= github.com/onsi/ginkgo v1.14.0/go.mod h1:iSB4RoI2tjJc9BBv4NKIKWKya62Rps+oPG/Lv9klQyY= +github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= github.com/onsi/ginkgo/v2 v2.9.7 h1:06xGQy5www2oN160RtEZoTvnP2sPhEfePYmCDc2szss= github.com/onsi/ginkgo/v2 v2.9.7/go.mod h1:cxrmXWykAwTwhQsJOPfdIDiJ+l2RYq7U8hFU+M/1uw0= github.com/onsi/gomega v1.4.1/go.mod h1:C1qb7wdrVGGVU+Z6iS04AVkA3Q65CEZX59MT0QO5uiA= @@ -1463,6 +1483,8 @@ github.com/sean-/seed v0.0.0-20170313163322-e2103e2c3529/go.mod h1:DxrIzT+xaE7yg github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo= github.com/sethvargo/go-retry v0.2.3 h1:oYlgvIvsju3jNbottWABtbnoLC+GDtLdBHxKWxQm/iU= github.com/sethvargo/go-retry v0.2.3/go.mod h1:1afjQuvh7s4gflMObvjLPaWgluLLyhA1wmVZ6KLpICw= +github.com/shirou/gopsutil v3.21.4-0.20210419000835-c7a38de76ee5+incompatible h1:Bn1aCHHRnjv4Bl16T8rcaFjYSrGrIZvpiGO6P3Q4GpU= +github.com/shirou/gopsutil v3.21.4-0.20210419000835-c7a38de76ee5+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA= github.com/shirou/gopsutil/v3 v3.22.2 h1:wCrArWFkHYIdDxx/FSfF5RB4dpJYW6t7rcp3+zL8uks= github.com/shirou/gopsutil/v3 v3.22.2/go.mod h1:WapW1AOOPlHyXr+yOyw3uYx36enocrtSoSBy0L5vUHY= github.com/shurcooL/component v0.0.0-20170202220835-f88ec8f54cc4/go.mod h1:XhFIlyj5a1fBNx5aJTbKoIq0mNaPvOagO+HjB3EtxrY= @@ -1565,6 +1587,8 @@ github.com/supranational/blst v0.3.4/go.mod h1:jZJtfjgudtNl4en1tzwPIV3KjUnQUvG3/ github.com/supranational/blst v0.3.11-0.20230406105308-e9dfc5ee724b h1:u49mjRnygnB34h8OKbnNJFVUtWSKIKb1KukdV8bILUM= github.com/syndtr/goleveldb v1.0.0/go.mod h1:ZVVdQEZoIme9iO1Ch2Jdy24qqXrMMOU6lpPAyBWyWuQ= github.com/syndtr/goleveldb v1.0.1-0.20190923125748-758128399b1d/go.mod h1:9OrXJhf154huy1nPWmuSrkgjPUtUNhA+Zmy+6AESzuA= +github.com/syndtr/goleveldb v1.0.1-0.20210819022825-2ae1ddf74ef7 h1:epCh84lMvA70Z7CTTCmYQn2CKbY8j86K7/FAIr141uY= +github.com/syndtr/goleveldb v1.0.1-0.20210819022825-2ae1ddf74ef7/go.mod h1:q4W45IWZaF22tdD+VEXcAWRA037jwmWEB5VWYORlTpc= github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07/go.mod h1:kDXzergiv9cbyO7IOYJZWg1U88JhDg3PB6klq9Hg2pA= github.com/texttheater/golang-levenshtein/levenshtein v0.0.0-20200805054039-cae8b0eaed6c h1:HelZ2kAFadG0La9d+4htN4HzQ68Bm2iM9qKMSMES6xg= github.com/texttheater/golang-levenshtein/levenshtein 
v0.0.0-20200805054039-cae8b0eaed6c/go.mod h1:JlzghshsemAMDGZLytTFY8C1JQxQPhnatWqNwUXjggo= @@ -1841,6 +1865,7 @@ golang.org/x/net v0.0.0-20200520004742-59133d7f0dd7/go.mod h1:qpuaurCH72eLCgpAm/ golang.org/x/net v0.0.0-20200520182314-0ba52f642ac2/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20200707034311-ab3426394381/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= +golang.org/x/net v0.0.0-20200813134508-3edf25e44fcc/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20201031054903-ff519b6c9102/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= @@ -1958,6 +1983,7 @@ golang.org/x/sys v0.0.0-20200602225109-6fdc65e7d980/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200625212154-ddb9806d33ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200803210538-64077c9b5642/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200814200057-3d37ad5750ed/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200905004654-be1d3432aa8f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200909081042-eff7692f9009/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200918174421-af09f7315aff/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -1977,6 +2003,7 @@ golang.org/x/sys v0.0.0-20210309074719-68d13333faf2/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210315160823-c6e025ad8005/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210317225723-c4fcb01b228e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210320140829-1e4c9ba3b0c4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210324051608-47abb6519492/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210403161142-5e06dd20ab57/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -2319,12 +2346,14 @@ gopkg.in/ini.v1 v1.51.1/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA= gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= gopkg.in/mgo.v2 v2.0.0-20180705113604-9856a29383ce/go.mod h1:yeKp02qBN3iKW1OzL3MGk2IdtZzaj7SFntXj72NppTA= +gopkg.in/natefinch/npipe.v2 v2.0.0-20160621034901-c1b8fa8bdcce h1:+JknDZhAj8YMt7GC73Ei8pv4MzjDUNPHgQWJdtMAaDU= gopkg.in/natefinch/npipe.v2 v2.0.0-20160621034901-c1b8fa8bdcce/go.mod h1:5AcXVHNjg+BDxry382+8OKon8SEWiKktQR07RKPsv1c= gopkg.in/olebedev/go-duktape.v3 v3.0.0-20190213234257-ec84240a7772/go.mod h1:uAJfkITjFhyEEuUfm7bsmCZRbW5WRq8s9EY8HZ6hCns= gopkg.in/resty.v1 v1.12.0/go.mod h1:mDo4pnntr5jdWRML875a/NmxYqAlA73dVijT2AXvQQo= gopkg.in/sourcemap.v1 v1.0.5/go.mod h1:2RlvNNSMglmRrcvhfuzp4hQHwOtjxlbjX7UPY/GXb78= gopkg.in/src-d/go-cli.v0 
v0.0.0-20181105080154-d492247bbc0d/go.mod h1:z+K8VcOYVYcSwSjGebuDL6176A1XskgbtNl64NSg+n8= gopkg.in/src-d/go-log.v1 v1.0.1/go.mod h1:GN34hKP0g305ysm2/hctJ0Y8nWP3zxXXJ8GFabTyABE= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= gopkg.in/urfave/cli.v1 v1.20.0/go.mod h1:vuBzUtMdQeixQj8LVd+/98pzhxNGQoyuPBlsXHOQNO0= gopkg.in/warnings.v0 v0.1.2/go.mod h1:jksf8JmL6Qr/oQM2OXTHunEvvTAsrWBLb6OOjuVWRNI= diff --git a/insecure/integration/functional/test/gossipsub/rpc_inspector/utils.go b/insecure/integration/functional/test/gossipsub/rpc_inspector/utils.go index 863fb36a898..555a06a6bba 100644 --- a/insecure/integration/functional/test/gossipsub/rpc_inspector/utils.go +++ b/insecure/integration/functional/test/gossipsub/rpc_inspector/utils.go @@ -58,22 +58,6 @@ func withExpectedNotificationDissemination(expectedNumOfTotalNotif int, f onNoti } } -// mockDistributorReadyDoneAware mocks the Ready and Done methods of the distributor to return a channel that is already closed, -// so that the distributor is considered ready and done when the test needs. -func mockDistributorReadyDoneAware(d *mockp2p.GossipSubInspectorNotificationDistributor) { - d.On("Start", mockery.Anything).Return().Maybe() - d.On("Ready").Return(func() <-chan struct{} { - ch := make(chan struct{}) - close(ch) - return ch - }()).Maybe() - d.On("Done").Return(func() <-chan struct{} { - ch := make(chan struct{}) - close(ch) - return ch - }()).Maybe() -} - func meshTracerFixture(flowConfig *config.FlowConfig, idProvider module.IdentityProvider) *tracer.GossipSubMeshTracer { meshTracerCfg := &tracer.GossipSubMeshTracerConfig{ Logger: unittest.Logger(), diff --git a/insecure/integration/functional/test/gossipsub/rpc_inspector/validation_inspector_test.go b/insecure/integration/functional/test/gossipsub/rpc_inspector/validation_inspector_test.go index 50434ad6de8..d7a53f6fce4 100644 --- a/insecure/integration/functional/test/gossipsub/rpc_inspector/validation_inspector_test.go +++ b/insecure/integration/functional/test/gossipsub/rpc_inspector/validation_inspector_test.go @@ -22,6 +22,7 @@ import ( "github.com/onflow/flow-go/module/irrecoverable" "github.com/onflow/flow-go/module/metrics" "github.com/onflow/flow-go/module/mock" + "github.com/onflow/flow-go/network" "github.com/onflow/flow-go/network/channels" "github.com/onflow/flow-go/network/p2p" "github.com/onflow/flow-go/network/p2p/inspector/validation" @@ -86,23 +87,33 @@ func TestValidationInspector_InvalidTopicId_Detection(t *testing.T) { signalerCtx := irrecoverable.NewMockSignalerContext(t, ctx) distributor := mockp2p.NewGossipSubInspectorNotificationDistributor(t) - mockDistributorReadyDoneAware(distributor) + p2ptest.MockInspectorNotificationDistributorReadyDoneAware(distributor) withExpectedNotificationDissemination(expectedNumOfTotalNotif, inspectDisseminatedNotifyFunc)(distributor, spammer) meshTracer := meshTracerFixture(flowConfig, idProvider) - - validationInspector, err := validation.NewControlMsgValidationInspector(signalerCtx, unittest.Logger(), sporkID, &inspectorConfig, distributor, metrics.NewNoopCollector(), metrics.NewNoopCollector(), idProvider, metrics.NewNoopCollector(), meshTracer) + topicProvider := newMockUpdatableTopicProvider() + validationInspector, err := validation.NewControlMsgValidationInspector(&validation.InspectorParams{ + Logger: unittest.Logger(), + SporkID: sporkID, + Config: &inspectorConfig, 
+ Distributor: distributor, + IdProvider: idProvider, + HeroCacheMetricsFactory: metrics.NewNoopHeroCacheMetricsFactory(), + InspectorMetrics: metrics.NewNoopCollector(), + RpcTracker: meshTracer, + NetworkingType: network.PrivateNetwork, + TopicOracle: func() p2p.TopicProvider { + return topicProvider + }, + }) require.NoError(t, err) corruptInspectorFunc := corruptlibp2p.CorruptInspectorFunc(validationInspector) - victimNode, victimIdentity := p2ptest.NodeFixture( - t, + victimNode, victimIdentity := p2ptest.NodeFixture(t, sporkID, t.Name(), idProvider, p2ptest.WithRole(role), p2ptest.WithGossipSubTracer(meshTracer), - internal.WithCorruptGossipSub(corruptlibp2p.CorruptGossipSubFactory(), - corruptlibp2p.CorruptGossipSubConfigFactoryWithInspector(corruptInspectorFunc)), - ) + internal.WithCorruptGossipSub(corruptlibp2p.CorruptGossipSubFactory(), corruptlibp2p.CorruptGossipSubConfigFactoryWithInspector(corruptInspectorFunc))) idProvider.On("ByPeerID", victimNode.ID()).Return(&victimIdentity, true).Maybe() idProvider.On("ByPeerID", spammer.SpammerNode.ID()).Return(&spammer.SpammerId, true).Maybe() @@ -114,9 +125,7 @@ func TestValidationInspector_InvalidTopicId_Detection(t *testing.T) { invalidSporkIDTopic := channels.Topic(fmt.Sprintf("%s/%s", channels.PushBlocks, unittest.IdentifierFixture())) // set topic oracle to return list with all topics to avoid hasSubscription failures and force topic validation - require.NoError(t, validationInspector.SetTopicOracle(func() []string { - return []string{unknownTopic.String(), malformedTopic.String(), invalidSporkIDTopic.String()} - })) + topicProvider.UpdateTopics([]string{unknownTopic.String(), malformedTopic.String(), invalidSporkIDTopic.String()}) validationInspector.Start(signalerCtx) nodes := []p2p.LibP2PNode{victimNode, spammer.SpammerNode} @@ -213,32 +222,41 @@ func TestValidationInspector_DuplicateTopicId_Detection(t *testing.T) { signalerCtx := irrecoverable.NewMockSignalerContext(t, ctx) distributor := mockp2p.NewGossipSubInspectorNotificationDistributor(t) - mockDistributorReadyDoneAware(distributor) + p2ptest.MockInspectorNotificationDistributorReadyDoneAware(distributor) withExpectedNotificationDissemination(expectedNumOfTotalNotif, inspectDisseminatedNotifyFunc)(distributor, spammer) meshTracer := meshTracerFixture(flowConfig, idProvider) - - validationInspector, err := validation.NewControlMsgValidationInspector(signalerCtx, unittest.Logger(), sporkID, &inspectorConfig, distributor, metrics.NewNoopCollector(), metrics.NewNoopCollector(), idProvider, metrics.NewNoopCollector(), meshTracer) + topicProvider := newMockUpdatableTopicProvider() + validationInspector, err := validation.NewControlMsgValidationInspector(&validation.InspectorParams{ + Logger: unittest.Logger(), + SporkID: sporkID, + Config: &inspectorConfig, + Distributor: distributor, + IdProvider: idProvider, + HeroCacheMetricsFactory: metrics.NewNoopHeroCacheMetricsFactory(), + InspectorMetrics: metrics.NewNoopCollector(), + RpcTracker: meshTracer, + NetworkingType: network.PrivateNetwork, + TopicOracle: func() p2p.TopicProvider { + return topicProvider + }, + }) require.NoError(t, err) + corruptInspectorFunc := corruptlibp2p.CorruptInspectorFunc(validationInspector) - victimNode, victimIdentity := p2ptest.NodeFixture( - t, + victimNode, victimIdentity := p2ptest.NodeFixture(t, sporkID, t.Name(), idProvider, p2ptest.WithRole(role), p2ptest.WithGossipSubTracer(meshTracer), - internal.WithCorruptGossipSub(corruptlibp2p.CorruptGossipSubFactory(), - 
corruptlibp2p.CorruptGossipSubConfigFactoryWithInspector(corruptInspectorFunc)), - ) + internal.WithCorruptGossipSub(corruptlibp2p.CorruptGossipSubFactory(), corruptlibp2p.CorruptGossipSubConfigFactoryWithInspector(corruptInspectorFunc))) idProvider.On("ByPeerID", victimNode.ID()).Return(&victimIdentity, true).Maybe() idProvider.On("ByPeerID", spammer.SpammerNode.ID()).Return(&spammer.SpammerId, true).Maybe() // a topics spork ID is considered invalid if it does not match the current spork ID duplicateTopic := channels.Topic(fmt.Sprintf("%s/%s", channels.PushBlocks, sporkID)) // set topic oracle to return list with all topics to avoid hasSubscription failures and force topic validation - require.NoError(t, validationInspector.SetTopicOracle(func() []string { - return []string{duplicateTopic.String()} - })) + topicProvider.UpdateTopics([]string{duplicateTopic.String()}) validationInspector.Start(signalerCtx) nodes := []p2p.LibP2PNode{victimNode, spammer.SpammerNode} @@ -286,7 +304,9 @@ func TestValidationInspector_IHaveDuplicateMessageId_Detection(t *testing.T) { require.True(t, ok) require.True(t, validation.IsDuplicateTopicErr(notification.Error)) require.Equal(t, spammer.SpammerNode.ID(), notification.PeerID) - require.True(t, notification.MsgType == p2pmsg.CtrlMsgIHave, fmt.Sprintf("unexpected control message type %s error: %s", notification.MsgType, notification.Error)) + require.True(t, + notification.MsgType == p2pmsg.CtrlMsgIHave, + fmt.Sprintf("unexpected control message type %s error: %s", notification.MsgType, notification.Error)) invIHaveNotifCount.Inc() if count.Load() == int64(expectedNumOfTotalNotif) { @@ -302,23 +322,35 @@ func TestValidationInspector_IHaveDuplicateMessageId_Detection(t *testing.T) { signalerCtx := irrecoverable.NewMockSignalerContext(t, ctx) distributor := mockp2p.NewGossipSubInspectorNotificationDistributor(t) - mockDistributorReadyDoneAware(distributor) + p2ptest.MockInspectorNotificationDistributorReadyDoneAware(distributor) withExpectedNotificationDissemination(expectedNumOfTotalNotif, inspectDisseminatedNotifyFunc)(distributor, spammer) meshTracer := meshTracerFixture(flowConfig, idProvider) - validationInspector, err := validation.NewControlMsgValidationInspector(signalerCtx, unittest.Logger(), sporkID, &inspectorConfig, distributor, metrics.NewNoopCollector(), metrics.NewNoopCollector(), idProvider, metrics.NewNoopCollector(), meshTracer) + topicProvider := newMockUpdatableTopicProvider() + validationInspector, err := validation.NewControlMsgValidationInspector(&validation.InspectorParams{ + Logger: unittest.Logger(), + SporkID: sporkID, + Config: &inspectorConfig, + Distributor: distributor, + IdProvider: idProvider, + HeroCacheMetricsFactory: metrics.NewNoopHeroCacheMetricsFactory(), + InspectorMetrics: metrics.NewNoopCollector(), + RpcTracker: meshTracer, + NetworkingType: network.PrivateNetwork, + TopicOracle: func() p2p.TopicProvider { + return topicProvider + }, + }) require.NoError(t, err) + corruptInspectorFunc := corruptlibp2p.CorruptInspectorFunc(validationInspector) - victimNode, victimIdentity := p2ptest.NodeFixture( - t, + victimNode, victimIdentity := p2ptest.NodeFixture(t, sporkID, t.Name(), idProvider, p2ptest.WithRole(role), p2ptest.WithGossipSubTracer(meshTracer), - internal.WithCorruptGossipSub(corruptlibp2p.CorruptGossipSubFactory(), - corruptlibp2p.CorruptGossipSubConfigFactoryWithInspector(corruptInspectorFunc)), - ) + internal.WithCorruptGossipSub(corruptlibp2p.CorruptGossipSubFactory(), 
corruptlibp2p.CorruptGossipSubConfigFactoryWithInspector(corruptInspectorFunc))) idProvider.On("ByPeerID", victimNode.ID()).Return(&victimIdentity, true).Maybe() idProvider.On("ByPeerID", spammer.SpammerNode.ID()).Return(&spammer.SpammerId, true).Maybe() @@ -326,9 +358,7 @@ func TestValidationInspector_IHaveDuplicateMessageId_Detection(t *testing.T) { pushBlocks := channels.Topic(fmt.Sprintf("%s/%s", channels.PushBlocks, sporkID)) reqChunks := channels.Topic(fmt.Sprintf("%s/%s", channels.RequestChunks, sporkID)) // set topic oracle to return list with all topics to avoid hasSubscription failures and force topic validation - require.NoError(t, validationInspector.SetTopicOracle(func() []string { - return []string{pushBlocks.String(), reqChunks.String()} - })) + topicProvider.UpdateTopics([]string{pushBlocks.String(), reqChunks.String()}) validationInspector.Start(signalerCtx) nodes := []p2p.LibP2PNode{victimNode, spammer.SpammerNode} @@ -406,32 +436,42 @@ func TestValidationInspector_UnknownClusterId_Detection(t *testing.T) { signalerCtx := irrecoverable.NewMockSignalerContext(t, ctx) distributor := mockp2p.NewGossipSubInspectorNotificationDistributor(t) - mockDistributorReadyDoneAware(distributor) + p2ptest.MockInspectorNotificationDistributorReadyDoneAware(distributor) withExpectedNotificationDissemination(expectedNumOfTotalNotif, inspectDisseminatedNotifyFunc)(distributor, spammer) meshTracer := meshTracerFixture(flowConfig, idProvider) - - validationInspector, err := validation.NewControlMsgValidationInspector(signalerCtx, unittest.Logger(), sporkID, &inspectorConfig, distributor, metrics.NewNoopCollector(), metrics.NewNoopCollector(), idProvider, metrics.NewNoopCollector(), meshTracer) + topicProvider := newMockUpdatableTopicProvider() + validationInspector, err := validation.NewControlMsgValidationInspector(&validation.InspectorParams{ + Logger: unittest.Logger(), + SporkID: sporkID, + Config: &inspectorConfig, + Distributor: distributor, + IdProvider: idProvider, + HeroCacheMetricsFactory: metrics.NewNoopHeroCacheMetricsFactory(), + InspectorMetrics: metrics.NewNoopCollector(), + RpcTracker: meshTracer, + NetworkingType: network.PrivateNetwork, + TopicOracle: func() p2p.TopicProvider { + return topicProvider + }, + }) require.NoError(t, err) + corruptInspectorFunc := corruptlibp2p.CorruptInspectorFunc(validationInspector) - victimNode, victimIdentity := p2ptest.NodeFixture( - t, + victimNode, victimIdentity := p2ptest.NodeFixture(t, sporkID, t.Name(), idProvider, p2ptest.WithRole(role), p2ptest.WithGossipSubTracer(meshTracer), - internal.WithCorruptGossipSub(corruptlibp2p.CorruptGossipSubFactory(), - corruptlibp2p.CorruptGossipSubConfigFactoryWithInspector(corruptInspectorFunc)), - ) + internal.WithCorruptGossipSub(corruptlibp2p.CorruptGossipSubFactory(), corruptlibp2p.CorruptGossipSubConfigFactoryWithInspector(corruptInspectorFunc))) idProvider.On("ByPeerID", victimNode.ID()).Return(&victimIdentity, true).Maybe() - idProvider.On("ByPeerID", spammer.SpammerNode.ID()).Return(&spammer.SpammerId, true).Times(3) + idProvider.On("ByPeerID", spammer.SpammerNode.ID()).Return(&spammer.SpammerId, true).Times(4) // setup cluster prefixed topic with an invalid cluster ID unknownClusterID := channels.Topic(channels.SyncCluster("unknown-cluster-ID")) // set topic oracle to return list with all topics to avoid hasSubscription failures and force topic validation - require.NoError(t, validationInspector.SetTopicOracle(func() []string { - return []string{unknownClusterID.String()} - })) + 
topicProvider.UpdateTopics([]string{unknownClusterID.String()}) + // consume cluster ID update so that active cluster IDs set validationInspector.ActiveClustersChanged(flow.ChainIDList{"known-cluster-id"}) @@ -485,36 +525,49 @@ func TestValidationInspector_ActiveClusterIdsNotSet_Graft_Detection(t *testing.T }) logger := zerolog.New(os.Stdout).Level(zerolog.WarnLevel).Hook(hook) + inspectorIdProvider := mock.NewIdentityProvider(t) idProvider := mock.NewIdentityProvider(t) spammer := corruptlibp2p.NewGossipSubRouterSpammer(t, sporkID, role, idProvider) ctx, cancel := context.WithCancel(context.Background()) signalerCtx := irrecoverable.NewMockSignalerContext(t, ctx) distributor := mockp2p.NewGossipSubInspectorNotificationDistributor(t) - mockDistributorReadyDoneAware(distributor) + p2ptest.MockInspectorNotificationDistributorReadyDoneAware(distributor) meshTracer := meshTracerFixture(flowConfig, idProvider) - validationInspector, err := validation.NewControlMsgValidationInspector(signalerCtx, logger, sporkID, &inspectorConfig, distributor, metrics.NewNoopCollector(), metrics.NewNoopCollector(), idProvider, metrics.NewNoopCollector(), meshTracer) + topicProvider := newMockUpdatableTopicProvider() + validationInspector, err := validation.NewControlMsgValidationInspector(&validation.InspectorParams{ + Logger: logger, + SporkID: sporkID, + Config: &inspectorConfig, + Distributor: distributor, + IdProvider: inspectorIdProvider, + HeroCacheMetricsFactory: metrics.NewNoopHeroCacheMetricsFactory(), + InspectorMetrics: metrics.NewNoopCollector(), + RpcTracker: meshTracer, + NetworkingType: network.PrivateNetwork, + TopicOracle: func() p2p.TopicProvider { + return topicProvider + }, + }) require.NoError(t, err) + corruptInspectorFunc := corruptlibp2p.CorruptInspectorFunc(validationInspector) - victimNode, victimIdentity := p2ptest.NodeFixture( - t, + victimNode, victimIdentity := p2ptest.NodeFixture(t, sporkID, t.Name(), idProvider, p2ptest.WithRole(role), p2ptest.WithGossipSubTracer(meshTracer), - internal.WithCorruptGossipSub(corruptlibp2p.CorruptGossipSubFactory(), - corruptlibp2p.CorruptGossipSubConfigFactoryWithInspector(corruptInspectorFunc)), - ) + internal.WithCorruptGossipSub(corruptlibp2p.CorruptGossipSubFactory(), corruptlibp2p.CorruptGossipSubConfigFactoryWithInspector(corruptInspectorFunc))) idProvider.On("ByPeerID", victimNode.ID()).Return(&victimIdentity, true).Maybe() - idProvider.On("ByPeerID", spammer.SpammerNode.ID()).Return(&spammer.SpammerId, true).Times(int(controlMessageCount + 1)) - + idProvider.On("ByPeerID", spammer.SpammerNode.ID()).Return(&spammer.SpammerId, true).Maybe() + // we expect controlMessageCount plus 1 extra call, this is due to messages that are exchanged when the nodes startup + inspectorIdProvider.On("ByPeerID", spammer.SpammerNode.ID()).Return(&spammer.SpammerId, true).Times(int(controlMessageCount + 1)) clusterPrefixedTopic := randomClusterPrefixedTopic() + // set topic oracle to return list with all topics to avoid hasSubscription failures and force topic validation - require.NoError(t, validationInspector.SetTopicOracle(func() []string { - return []string{clusterPrefixedTopic.String()} - })) + topicProvider.UpdateTopics([]string{clusterPrefixedTopic.String()}) // we deliberately avoid setting the cluster IDs so that we eventually receive errors after we have exceeded the allowed cluster // prefixed hard threshold @@ -525,9 +578,7 @@ func TestValidationInspector_ActiveClusterIdsNotSet_Graft_Detection(t *testing.T defer stopComponents(t, cancel, nodes, 
validationInspector) // generate multiple control messages with GRAFT's for randomly generated // cluster prefixed channels, this ensures we do not encounter duplicate topic ID errors - ctlMsgs := spammer.GenerateCtlMessages(int(controlMessageCount), - corruptlibp2p.WithGraft(1, clusterPrefixedTopic.String()), - ) + ctlMsgs := spammer.GenerateCtlMessages(int(controlMessageCount), corruptlibp2p.WithGraft(1, clusterPrefixedTopic.String())) // start spamming the victim peer spammer.SpamControlMessage(t, victimNode, ctlMsgs) @@ -569,30 +620,41 @@ func TestValidationInspector_ActiveClusterIdsNotSet_Prune_Detection(t *testing.T signalerCtx := irrecoverable.NewMockSignalerContext(t, ctx) distributor := mockp2p.NewGossipSubInspectorNotificationDistributor(t) - mockDistributorReadyDoneAware(distributor) + p2ptest.MockInspectorNotificationDistributorReadyDoneAware(distributor) meshTracer := meshTracerFixture(flowConfig, idProvider) - - validationInspector, err := validation.NewControlMsgValidationInspector(signalerCtx, logger, sporkID, &inspectorConfig, distributor, metrics.NewNoopCollector(), metrics.NewNoopCollector(), idProvider, metrics.NewNoopCollector(), meshTracer) + topicProvider := newMockUpdatableTopicProvider() + inspectorIdProvider := mock.NewIdentityProvider(t) + validationInspector, err := validation.NewControlMsgValidationInspector(&validation.InspectorParams{ + Logger: logger, + SporkID: sporkID, + Config: &inspectorConfig, + Distributor: distributor, + IdProvider: inspectorIdProvider, + HeroCacheMetricsFactory: metrics.NewNoopHeroCacheMetricsFactory(), + InspectorMetrics: metrics.NewNoopCollector(), + RpcTracker: meshTracer, + NetworkingType: network.PrivateNetwork, + TopicOracle: func() p2p.TopicProvider { + return topicProvider + }, + }) require.NoError(t, err) corruptInspectorFunc := corruptlibp2p.CorruptInspectorFunc(validationInspector) - victimNode, victimIdentity := p2ptest.NodeFixture( - t, + victimNode, victimIdentity := p2ptest.NodeFixture(t, sporkID, t.Name(), idProvider, p2ptest.WithRole(role), p2ptest.WithGossipSubTracer(meshTracer), - internal.WithCorruptGossipSub(corruptlibp2p.CorruptGossipSubFactory(), - corruptlibp2p.CorruptGossipSubConfigFactoryWithInspector(corruptInspectorFunc)), - ) + internal.WithCorruptGossipSub(corruptlibp2p.CorruptGossipSubFactory(), corruptlibp2p.CorruptGossipSubConfigFactoryWithInspector(corruptInspectorFunc))) idProvider.On("ByPeerID", victimNode.ID()).Return(&victimIdentity, true).Maybe() - idProvider.On("ByPeerID", spammer.SpammerNode.ID()).Return(&spammer.SpammerId, true).Times(int(controlMessageCount + 1)) + idProvider.On("ByPeerID", spammer.SpammerNode.ID()).Return(&spammer.SpammerId, true).Maybe() + // we expect controlMessageCount plus 1 extra call, this is due to messages that are exchanged when the nodes startup + inspectorIdProvider.On("ByPeerID", spammer.SpammerNode.ID()).Return(&spammer.SpammerId, true).Times(int(controlMessageCount + 1)) clusterPrefixedTopic := randomClusterPrefixedTopic() // set topic oracle to return list with all topics to avoid hasSubscription failures and force topic validation - require.NoError(t, validationInspector.SetTopicOracle(func() []string { - return []string{clusterPrefixedTopic.String()} - })) + topicProvider.UpdateTopics([]string{clusterPrefixedTopic.String()}) // we deliberately avoid setting the cluster IDs so that we eventually receive errors after we have exceeded the allowed cluster // prefixed hard threshold @@ -603,9 +665,7 @@ func 
TestValidationInspector_ActiveClusterIdsNotSet_Prune_Detection(t *testing.T defer stopComponents(t, cancel, nodes, validationInspector) // generate multiple control messages with GRAFT's for randomly generated // cluster prefixed channels, this ensures we do not encounter duplicate topic ID errors - ctlMsgs := spammer.GenerateCtlMessages(int(controlMessageCount), - corruptlibp2p.WithPrune(1, clusterPrefixedTopic.String()), - ) + ctlMsgs := spammer.GenerateCtlMessages(int(controlMessageCount), corruptlibp2p.WithPrune(1, clusterPrefixedTopic.String())) // start spamming the victim peer spammer.SpamControlMessage(t, victimNode, ctlMsgs) @@ -652,24 +712,38 @@ func TestValidationInspector_UnstakedNode_Detection(t *testing.T) { signalerCtx := irrecoverable.NewMockSignalerContext(t, ctx) distributor := mockp2p.NewGossipSubInspectorNotificationDistributor(t) - mockDistributorReadyDoneAware(distributor) + p2ptest.MockInspectorNotificationDistributorReadyDoneAware(distributor) meshTracer := meshTracerFixture(flowConfig, idProvider) - validationInspector, err := validation.NewControlMsgValidationInspector(signalerCtx, logger, sporkID, &inspectorConfig, distributor, metrics.NewNoopCollector(), metrics.NewNoopCollector(), idProvider, metrics.NewNoopCollector(), meshTracer) + topicProvider := newMockUpdatableTopicProvider() + inspectorIdProvider := mock.NewIdentityProvider(t) + validationInspector, err := validation.NewControlMsgValidationInspector(&validation.InspectorParams{ + Logger: logger, + SporkID: sporkID, + Config: &inspectorConfig, + Distributor: distributor, + IdProvider: inspectorIdProvider, + HeroCacheMetricsFactory: metrics.NewNoopHeroCacheMetricsFactory(), + InspectorMetrics: metrics.NewNoopCollector(), + RpcTracker: meshTracer, + NetworkingType: network.PrivateNetwork, + TopicOracle: func() p2p.TopicProvider { + return topicProvider + }, + }) require.NoError(t, err) corruptInspectorFunc := corruptlibp2p.CorruptInspectorFunc(validationInspector) - victimNode, victimIdentity := p2ptest.NodeFixture( - t, + victimNode, victimIdentity := p2ptest.NodeFixture(t, sporkID, t.Name(), idProvider, p2ptest.WithRole(role), p2ptest.WithGossipSubTracer(meshTracer), - internal.WithCorruptGossipSub(corruptlibp2p.CorruptGossipSubFactory(), - corruptlibp2p.CorruptGossipSubConfigFactoryWithInspector(corruptInspectorFunc)), - ) + internal.WithCorruptGossipSub(corruptlibp2p.CorruptGossipSubFactory(), corruptlibp2p.CorruptGossipSubConfigFactoryWithInspector(corruptInspectorFunc))) idProvider.On("ByPeerID", victimNode.ID()).Return(&victimIdentity, true).Maybe() - idProvider.On("ByPeerID", spammer.SpammerNode.ID()).Return(nil, false).Times(3) + idProvider.On("ByPeerID", spammer.SpammerNode.ID()).Return(&spammer.SpammerId, true).Maybe() + // we expect 2 calls from notification inspection plus 1 extra call, this is due to messages that are exchanged when the nodes startup + inspectorIdProvider.On("ByPeerID", spammer.SpammerNode.ID()).Return(nil, false).Times(3) // setup cluster prefixed topic with an invalid cluster ID clusterID := flow.ChainID("known-cluster-id") @@ -678,9 +752,7 @@ func TestValidationInspector_UnstakedNode_Detection(t *testing.T) { validationInspector.ActiveClustersChanged(flow.ChainIDList{clusterID}) // set topic oracle to return list with all topics to avoid hasSubscription failures and force topic validation - require.NoError(t, validationInspector.SetTopicOracle(func() []string { - return []string{clusterIDTopic.String()} - })) + topicProvider.UpdateTopics([]string{clusterIDTopic.String()}) 
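Throughout these tests, the removed SetTopicOracle callback is replaced by a topic provider injected at construction time through InspectorParams.TopicOracle, and each test mutates that provider with UpdateTopics. The following is a minimal sketch of such an updatable provider: UpdateTopics is the call the tests make, while GetTopics is an assumed read side for the inspector; the real helper (newMockUpdatableTopicProvider) lives in the flow-go test code and may satisfy more of p2p.TopicProvider.

```go
package example

import "sync"

// updatableTopicProvider is an illustrative stand-in for the mock topic
// provider used by the refactored inspector tests: the inspector reads the
// current topic list through the provider, and the test swaps that list at
// runtime instead of re-registering a callback on the inspector.
type updatableTopicProvider struct {
	mu     sync.RWMutex
	topics []string
}

// UpdateTopics replaces the topic list the provider reports.
func (p *updatableTopicProvider) UpdateTopics(topics []string) {
	p.mu.Lock()
	defer p.mu.Unlock()
	p.topics = topics
}

// GetTopics returns the currently configured topics.
func (p *updatableTopicProvider) GetTopics() []string {
	p.mu.RLock()
	defer p.mu.RUnlock()
	return p.topics
}
```

Injecting the oracle at construction time keeps it fixed on the inspector itself, which is why the tests no longer need to handle an error from SetTopicOracle.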
validationInspector.Start(signalerCtx) nodes := []p2p.LibP2PNode{victimNode, spammer.SpammerNode} @@ -710,7 +782,7 @@ func TestValidationInspector_InspectIWants_CacheMissThreshold(t *testing.T) { require.NoError(t, err) inspectorConfig := flowConfig.NetworkConfig.GossipSubConfig.GossipSubRPCInspectorsConfig.GossipSubRPCValidationInspectorConfigs // force all cache miss checks - inspectorConfig.IWantRPCInspectionConfig.CacheMissCheckSize = 0 + inspectorConfig.IWantRPCInspectionConfig.CacheMissCheckSize = 1 inspectorConfig.NumberOfWorkers = 1 inspectorConfig.IWantRPCInspectionConfig.CacheMissThreshold = .5 // set cache miss threshold to 50% messageCount := 1 @@ -723,7 +795,9 @@ func TestValidationInspector_InspectIWants_CacheMissThreshold(t *testing.T) { notification, ok := args[0].(*p2p.InvCtrlMsgNotif) require.True(t, ok) require.Equal(t, spammer.SpammerNode.ID(), notification.PeerID) - require.True(t, notification.MsgType == p2pmsg.CtrlMsgIWant, fmt.Sprintf("unexpected control message type %s error: %s", notification.MsgType, notification.Error)) + require.True(t, + notification.MsgType == p2pmsg.CtrlMsgIWant, + fmt.Sprintf("unexpected control message type %s error: %s", notification.MsgType, notification.Error)) require.True(t, validation.IsIWantCacheMissThresholdErr(notification.Error)) cacheMissThresholdNotifCount.Inc() @@ -740,23 +814,34 @@ func TestValidationInspector_InspectIWants_CacheMissThreshold(t *testing.T) { signalerCtx := irrecoverable.NewMockSignalerContext(t, ctx) distributor := mockp2p.NewGossipSubInspectorNotificationDistributor(t) - mockDistributorReadyDoneAware(distributor) + p2ptest.MockInspectorNotificationDistributorReadyDoneAware(distributor) withExpectedNotificationDissemination(1, inspectDisseminatedNotifyFunc)(distributor, spammer) meshTracer := meshTracerFixture(flowConfig, idProvider) - validationInspector, err := validation.NewControlMsgValidationInspector(signalerCtx, unittest.Logger(), sporkID, &inspectorConfig, distributor, metrics.NewNoopCollector(), metrics.NewNoopCollector(), idProvider, metrics.NewNoopCollector(), meshTracer) + topicProvider := newMockUpdatableTopicProvider() + validationInspector, err := validation.NewControlMsgValidationInspector(&validation.InspectorParams{ + Logger: unittest.Logger(), + SporkID: sporkID, + Config: &inspectorConfig, + Distributor: distributor, + IdProvider: idProvider, + HeroCacheMetricsFactory: metrics.NewNoopHeroCacheMetricsFactory(), + InspectorMetrics: metrics.NewNoopCollector(), + RpcTracker: meshTracer, + NetworkingType: network.PrivateNetwork, + TopicOracle: func() p2p.TopicProvider { + return topicProvider + }, + }) require.NoError(t, err) corruptInspectorFunc := corruptlibp2p.CorruptInspectorFunc(validationInspector) - victimNode, victimIdentity := p2ptest.NodeFixture( - t, + victimNode, victimIdentity := p2ptest.NodeFixture(t, sporkID, t.Name(), idProvider, p2ptest.WithRole(role), p2ptest.WithGossipSubTracer(meshTracer), - internal.WithCorruptGossipSub(corruptlibp2p.CorruptGossipSubFactory(), - corruptlibp2p.CorruptGossipSubConfigFactoryWithInspector(corruptInspectorFunc)), - ) + internal.WithCorruptGossipSub(corruptlibp2p.CorruptGossipSubFactory(), corruptlibp2p.CorruptGossipSubConfigFactoryWithInspector(corruptInspectorFunc))) idProvider.On("ByPeerID", victimNode.ID()).Return(&victimIdentity, true).Maybe() idProvider.On("ByPeerID", spammer.SpammerNode.ID()).Return(&spammer.SpammerId, true).Maybe() @@ -765,15 +850,12 @@ func TestValidationInspector_InspectIWants_CacheMissThreshold(t *testing.T) { // 
create control message with iWant that contains 5 message IDs that were not tracked ctlWithIWants := spammer.GenerateCtlMessages(int(controlMessageCount), corruptlibp2p.WithIWant(messageCount, messageCount)) ctlWithIWants[0].Iwant[0].MessageIDs = messageIDs // the first 5 message ids will not have a corresponding iHave - topic := channels.PushBlocks // create control message with iHave that contains only the last 4 message IDs, this will force cache misses for the other 6 message IDs ctlWithIhaves := spammer.GenerateCtlMessages(int(controlMessageCount), corruptlibp2p.WithIHave(messageCount, messageCount, topic.String())) ctlWithIhaves[0].Ihave[0].MessageIDs = messageIDs[6:] // set topic oracle - require.NoError(t, validationInspector.SetTopicOracle(func() []string { - return []string{topic.String()} - })) + topicProvider.UpdateTopics([]string{topic.String()}) validationInspector.Start(signalerCtx) nodes := []p2p.LibP2PNode{victimNode, spammer.SpammerNode} startNodesAndEnsureConnected(t, signalerCtx, nodes, sporkID) @@ -805,20 +887,59 @@ func TestValidationInspector_InspectRpcPublishMessages(t *testing.T) { require.NoError(t, err) inspectorConfig := flowConfig.NetworkConfig.GossipSubConfig.GossipSubRPCInspectorsConfig.GossipSubRPCValidationInspectorConfigs inspectorConfig.NumberOfWorkers = 1 - // after 5 errors encountered disseminate a notification - inspectorConfig.RpcMessageErrorThreshold = 4 + + idProvider := mock.NewIdentityProvider(t) + spammer := corruptlibp2p.NewGossipSubRouterSpammer(t, sporkID, role, idProvider) controlMessageCount := int64(1) notificationCount := atomic.NewUint64(0) done := make(chan struct{}) + validTopic := channels.Topic(fmt.Sprintf("%s/%s", channels.TestNetworkChannel.String(), sporkID)).String() + // create unknown topic + unknownTopic := channels.Topic(fmt.Sprintf("%s/%s", corruptlibp2p.GossipSubTopicIdFixture(), sporkID)).String() + // create malformed topic + malformedTopic := channels.Topic(unittest.RandomStringFixture(t, 100)).String() + // a topics spork ID is considered invalid if it does not match the current spork ID + invalidSporkIDTopic := channels.Topic(fmt.Sprintf("%s/%s", channels.PushBlocks, unittest.IdentifierFixture())).String() + + // unknown peer ID + unknownPeerID := unittest.PeerIdFixture(t) + + // ejected identity + ejectedIdentityPeerID := unittest.PeerIdFixture(t) + ejectedIdentity := unittest.IdentityFixture() + ejectedIdentity.Ejected = true + + // invalid messages this should force a notification to disseminate + invalidPublishMsgs := []*pb.Message{ + {Topic: &unknownTopic, From: []byte(spammer.SpammerNode.ID())}, + {Topic: &malformedTopic, From: []byte(spammer.SpammerNode.ID())}, + {Topic: &malformedTopic, From: []byte(spammer.SpammerNode.ID())}, + {Topic: &malformedTopic, From: []byte(spammer.SpammerNode.ID())}, + {Topic: &invalidSporkIDTopic, From: []byte(spammer.SpammerNode.ID())}, + {Topic: &validTopic, From: []byte(unknownPeerID)}, + {Topic: &validTopic, From: []byte(ejectedIdentityPeerID)}, + } + topic := channels.Topic(fmt.Sprintf("%s/%s", channels.PushBlocks, sporkID)) + // first create 4 valid messages + publishMsgs := unittest.GossipSubMessageFixtures(4, topic.String(), unittest.WithFrom(spammer.SpammerNode.ID())) + publishMsgs = append(publishMsgs, invalidPublishMsgs...) 
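The rebuilt fixture list above stamps a From peer on every publish message, and the assertions expect dedicated errors for a sender with no known identity (unstaked) and for one whose identity is ejected. The sketch below models that sender check in isolation; Identity and IdentityProvider are simplified stand-ins for flow-go's module types, and the error text mirrors the strings the test asserts on rather than the inspector's actual implementation.

```go
package example

import (
	"errors"
	"fmt"
)

// Identity models only what this sketch needs from a flow identity: the
// Ejected flag consulted during sender validation.
type Identity struct {
	Ejected bool
}

// IdentityProvider is a stand-in for module.IdentityProvider's lookup by peer ID.
type IdentityProvider interface {
	ByPeerID(peerID string) (*Identity, bool)
}

var (
	errUnstakedPeer = errors.New("received rpc publish message from unstaked peer")
	errEjectedPeer  = errors.New("received rpc publish message from ejected peer")
)

// validateSender rejects publish messages whose sender is unknown to the
// identity provider (unstaked) or whose identity has been ejected.
func validateSender(idProvider IdentityProvider, from string) error {
	identity, ok := idProvider.ByPeerID(from)
	if !ok {
		return fmt.Errorf("%w: %s", errUnstakedPeer, from)
	}
	if identity.Ejected {
		return fmt.Errorf("%w: %s", errEjectedPeer, from)
	}
	return nil
}
```

With seven invalid messages in invalidPublishMsgs and RpcMessageErrorThreshold raised to 6, the test expects exactly one aggregated notification carrying all seven errors.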
// ensure expected notifications are disseminated with expected error inspectDisseminatedNotifyFunc := func(spammer *corruptlibp2p.GossipSubRouterSpammer) func(args mockery.Arguments) { return func(args mockery.Arguments) { notification, ok := args[0].(*p2p.InvCtrlMsgNotif) require.True(t, ok) require.Equal(t, spammer.SpammerNode.ID(), notification.PeerID) - require.True(t, notification.MsgType == p2pmsg.RpcPublishMessage, fmt.Sprintf("unexpected control message type %s error: %s", notification.MsgType, notification.Error)) + require.True(t, + notification.MsgType == p2pmsg.RpcPublishMessage, + fmt.Sprintf("unexpected control message type %s error: %s", notification.MsgType, notification.Error)) require.True(t, validation.IsInvalidRpcPublishMessagesErr(notification.Error)) + require.Contains(t, + notification.Error.Error(), + fmt.Sprintf("%d error(s) encountered", len(invalidPublishMsgs)), + fmt.Sprintf("expected %d errors, an error for each invalid pubsub message", len(invalidPublishMsgs))) + require.Contains(t, notification.Error.Error(), fmt.Sprintf("received rpc publish message from unstaked peer: %s", unknownPeerID)) + require.Contains(t, notification.Error.Error(), fmt.Sprintf("received rpc publish message from ejected peer: %s", ejectedIdentityPeerID)) notificationCount.Inc() if notificationCount.Load() == 1 { close(done) @@ -826,55 +947,58 @@ func TestValidationInspector_InspectRpcPublishMessages(t *testing.T) { } } - idProvider := mock.NewIdentityProvider(t) - spammer := corruptlibp2p.NewGossipSubRouterSpammer(t, sporkID, role, idProvider) - ctx, cancel := context.WithCancel(context.Background()) signalerCtx := irrecoverable.NewMockSignalerContext(t, ctx) distributor := mockp2p.NewGossipSubInspectorNotificationDistributor(t) - mockDistributorReadyDoneAware(distributor) + p2ptest.MockInspectorNotificationDistributorReadyDoneAware(distributor) withExpectedNotificationDissemination(1, inspectDisseminatedNotifyFunc)(distributor, spammer) meshTracer := meshTracerFixture(flowConfig, idProvider) + topicProvider := newMockUpdatableTopicProvider() + validationInspector, err := validation.NewControlMsgValidationInspector(&validation.InspectorParams{ + Logger: unittest.Logger(), + SporkID: sporkID, + Config: &inspectorConfig, + Distributor: distributor, + IdProvider: idProvider, + HeroCacheMetricsFactory: metrics.NewNoopHeroCacheMetricsFactory(), + InspectorMetrics: metrics.NewNoopCollector(), + RpcTracker: meshTracer, + NetworkingType: network.PrivateNetwork, + TopicOracle: func() p2p.TopicProvider { + return topicProvider + }, + }) + require.NoError(t, err) + // set topic oracle to return list with all topics to avoid hasSubscription failures and force topic validation + topics := make([]string, len(publishMsgs)) + for i := 0; i < len(publishMsgs); i++ { + topics[i] = publishMsgs[i].GetTopic() + } + topicProvider.UpdateTopics(topics) + + // after 7 errors encountered disseminate a notification + inspectorConfig.RpcMessageErrorThreshold = 6 - validationInspector, err := validation.NewControlMsgValidationInspector(signalerCtx, unittest.Logger(), sporkID, &inspectorConfig, distributor, metrics.NewNoopCollector(), metrics.NewNoopCollector(), idProvider, metrics.NewNoopCollector(), meshTracer) require.NoError(t, err) corruptInspectorFunc := corruptlibp2p.CorruptInspectorFunc(validationInspector) - victimNode, victimIdentity := p2ptest.NodeFixture( - t, + victimNode, victimIdentity := p2ptest.NodeFixture(t, sporkID, t.Name(), idProvider, p2ptest.WithRole(role), 
p2ptest.WithGossipSubTracer(meshTracer), - internal.WithCorruptGossipSub(corruptlibp2p.CorruptGossipSubFactory(), - corruptlibp2p.CorruptGossipSubConfigFactoryWithInspector(corruptInspectorFunc)), - ) + internal.WithCorruptGossipSub(corruptlibp2p.CorruptGossipSubFactory(), corruptlibp2p.CorruptGossipSubConfigFactoryWithInspector(corruptInspectorFunc))) idProvider.On("ByPeerID", victimNode.ID()).Return(&victimIdentity, true).Maybe() idProvider.On("ByPeerID", spammer.SpammerNode.ID()).Return(&spammer.SpammerId, true).Maybe() - topic := channels.Topic(fmt.Sprintf("%s/%s", channels.PushBlocks, sporkID)) - // first create 4 valid messages - publishMsgs := unittest.GossipSubMessageFixtures(t, 4, topic.String()) - // create unknown topic - unknownTopic := channels.Topic(fmt.Sprintf("%s/%s", corruptlibp2p.GossipSubTopicIdFixture(), sporkID)).String() - // create malformed topic - malformedTopic := channels.Topic(unittest.RandomStringFixture(t, 100)).String() - // a topics spork ID is considered invalid if it does not match the current spork ID - invalidSporkIDTopic := channels.Topic(fmt.Sprintf("%s/%s", channels.PushBlocks, unittest.IdentifierFixture())).String() - // append 5 messages with invalid topics, this should force a notification to disseminate - publishMsgs = append(publishMsgs, []*pb.Message{ - {Topic: &unknownTopic}, - {Topic: &malformedTopic}, - {Topic: &malformedTopic}, - {Topic: &malformedTopic}, - {Topic: &invalidSporkIDTopic}, - }...) + // return nil for the unknown peer ID, indicating an unstaked peer + idProvider.On("ByPeerID", unknownPeerID).Return(nil, false).Once() + // returning an ejected identity for this peer ID will force a message validation failure + idProvider.On("ByPeerID", ejectedIdentityPeerID).Return(ejectedIdentity, true).Once() // set topic oracle to return list with all topics to avoid hasSubscription failures and force topic validation - require.NoError(t, validationInspector.SetTopicOracle(func() []string { - return []string{topic.String(), unknownTopic, malformedTopic, invalidSporkIDTopic} - })) + topicProvider.UpdateTopics([]string{topic.String(), unknownTopic, malformedTopic, invalidSporkIDTopic}) validationInspector.Start(signalerCtx) nodes := []p2p.LibP2PNode{victimNode, spammer.SpammerNode} @@ -885,7 +1009,6 @@ func TestValidationInspector_InspectRpcPublishMessages(t *testing.T) { // prepare to spam - generate control messages ctlMsg := spammer.GenerateCtlMessages(int(controlMessageCount)) - // start spamming the victim peer spammer.SpamControlMessage(t, victimNode, ctlMsg, publishMsgs...)
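// A minimal sketch of how the publish-message error threshold above is expected to behave: the fixture
// contains exactly 7 invalid publish messages (one unknown topic, three malformed topics, one invalid
// spork ID, one unstaked sender, one ejected sender) and RpcMessageErrorThreshold is set to 6, so a
// single notification aggregating all 7 errors should be disseminated. Illustrative only; the helper
// name is hypothetical, not the inspector's actual API.
func shouldDisseminateSketch(errCount, threshold int) bool {
	// disseminate once the number of accumulated errors exceeds the threshold (7 > 6 in this test)
	return errCount > threshold
}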
@@ -908,37 +1031,34 @@ func TestGossipSubSpamMitigationIntegration(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) signalerCtx := irrecoverable.NewMockSignalerContext(t, ctx) - victimNode, victimId := p2ptest.NodeFixture( - t, + victimNode, victimId := p2ptest.NodeFixture(t, sporkID, t.Name(), idProvider, p2ptest.WithRole(flow.RoleConsensus), - p2ptest.EnablePeerScoringWithOverride(p2p.PeerScoringConfigNoOverride), - ) + p2ptest.EnablePeerScoringWithOverride(p2p.PeerScoringConfigNoOverride)) ids := flow.IdentityList{&victimId, &spammer.SpammerId} - idProvider.On("ByPeerID", mockery.Anything).Return( - func(peerId peer.ID) *flow.Identity { - switch peerId { - case victimNode.ID(): - return &victimId - case spammer.SpammerNode.ID(): - return &spammer.SpammerId - default: - return nil - } + idProvider.On("ByPeerID", mockery.Anything).Return(func(peerId peer.ID) *flow.Identity { + switch peerId { + case victimNode.ID(): + return &victimId + case spammer.SpammerNode.ID(): + return &spammer.SpammerId + default: + return nil + } - }, func(peerId peer.ID) bool { - switch peerId { - case victimNode.ID(): - fallthrough - case spammer.SpammerNode.ID(): - return true - default: - return false - } - }) + }, func(peerId peer.ID) bool { + switch peerId { + case victimNode.ID(): + fallthrough + case spammer.SpammerNode.ID(): + return true + default: + return false + } + }) spamRpcCount := 10 // total number of individual rpc messages to send spamCtrlMsgCount := int64(10) // total number of control messages to send on each RPC @@ -947,7 +1067,7 @@ func TestGossipSubSpamMitigationIntegration(t *testing.T) { unknownTopic := channels.Topic(fmt.Sprintf("%s/%s", corruptlibp2p.GossipSubTopicIdFixture(), sporkID)) // malformedTopic is a topic that is not shaped like a valid topic (i.e., it does not have the correct prefix and spork ID). - malformedTopic := channels.Topic(unittest.RandomStringFixture(t, 100)) + malformedTopic := channels.Topic("!@#$%^&**((") // invalidSporkIDTopic is a topic that has a valid prefix but an invalid spork ID (i.e., not the current spork ID). invalidSporkIDTopic := channels.Topic(fmt.Sprintf("%s/%s", channels.PushBlocks, unittest.IdentifierFixture())) @@ -997,8 +1117,7 @@ func TestGossipSubSpamMitigationIntegration(t *testing.T) { // now we expect the detection and mitigation to kick in and the victim node to disconnect from the spammer node. // so the spammer and victim nodes should not be able to exchange messages on the topic. - p2ptest.EnsureNoPubsubExchangeBetweenGroups( - t, + p2ptest.EnsureNoPubsubExchangeBetweenGroups(t, ctx, []p2p.LibP2PNode{victimNode}, flow.IdentifierList{victimId.NodeID}, @@ -1010,3 +1129,36 @@ func TestGossipSubSpamMitigationIntegration(t *testing.T) { return unittest.ProposalFixture() }) } + +// mockUpdatableTopicProvider is a mock implementation of the TopicProvider interface. +// TODO: there is a duplicate implementation of this in the test package, we should consolidate them. +// The duplicate exists in network/p2p/inspector/internal/mockTopicProvider.go. The reason for duplication is that +// the inspector/validation package does not have a separate test package. Hence, sharing the mock implementation +// will cause a cyclic dependency. 
+type mockUpdatableTopicProvider struct { + topics []string + subscriptions map[string][]peer.ID +} + +func newMockUpdatableTopicProvider() *mockUpdatableTopicProvider { + return &mockUpdatableTopicProvider{ + topics: []string{}, + subscriptions: map[string][]peer.ID{}, + } +} + +func (m *mockUpdatableTopicProvider) GetTopics() []string { + return m.topics +} + +func (m *mockUpdatableTopicProvider) ListPeers(topic string) []peer.ID { + return m.subscriptions[topic] +} + +func (m *mockUpdatableTopicProvider) UpdateTopics(topics []string) { + m.topics = topics +} + +func (m *mockUpdatableTopicProvider) UpdateSubscriptions(topic string, peers []peer.ID) { + m.subscriptions[topic] = peers +} diff --git a/integration/go.mod b/integration/go.mod index 882f34bcde6..fe711b2257f 100644 --- a/integration/go.mod +++ b/integration/go.mod @@ -3,8 +3,9 @@ module github.com/onflow/flow-go/integration go 1.20 require ( - cloud.google.com/go/bigquery v1.52.0 + cloud.google.com/go/bigquery v1.53.0 github.com/VividCortex/ewma v1.2.0 + github.com/btcsuite/btcd/chaincfg/chainhash v1.0.2 github.com/coreos/go-semver v0.3.0 github.com/dapperlabs/testingdock v0.4.5-0.20231020233342-a2853fe18724 github.com/dgraph-io/badger/v2 v2.2007.4 @@ -18,15 +19,15 @@ require ( github.com/ipfs/go-datastore v0.6.0 github.com/ipfs/go-ds-badger2 v0.1.3 github.com/ipfs/go-ipfs-blockstore v1.3.0 - github.com/onflow/cadence v0.42.4 + github.com/onflow/cadence v0.42.5 github.com/onflow/flow-core-contracts/lib/go/contracts v1.2.4-0.20231120143830-9e8417b56122 github.com/onflow/flow-core-contracts/lib/go/templates v1.2.4-0.20231111185227-240579784e9b - github.com/onflow/flow-emulator v0.54.1-0.20231024204057-0273f8fe3807 - github.com/onflow/flow-go v0.32.3 - github.com/onflow/flow-go-sdk v0.41.14 + github.com/onflow/flow-emulator v0.54.1-0.20231110220143-28061d9b37e7 + github.com/onflow/flow-go v0.32.7 + github.com/onflow/flow-go-sdk v0.41.16 github.com/onflow/flow-go/crypto v0.24.9 github.com/onflow/flow-go/insecure v0.0.0-00010101000000-000000000000 - github.com/onflow/flow/protobuf/go/flow v0.3.2-0.20231018182244-e72527c55c63 + github.com/onflow/flow/protobuf/go/flow v0.3.2-0.20231121210617-52ee94b830c2 github.com/plus3it/gorecurcopy v0.0.1 github.com/prometheus/client_golang v1.16.0 github.com/prometheus/client_model v0.5.0 @@ -37,13 +38,13 @@ require ( go.uber.org/atomic v1.11.0 golang.org/x/exp v0.0.0-20230321023759-10a507213a29 golang.org/x/sync v0.3.0 - google.golang.org/grpc v1.58.3 + google.golang.org/grpc v1.59.0 google.golang.org/protobuf v1.31.0 ) require ( - cloud.google.com/go v0.110.4 // indirect - cloud.google.com/go/compute v1.21.0 // indirect + cloud.google.com/go v0.110.7 // indirect + cloud.google.com/go/compute v1.23.0 // indirect cloud.google.com/go/compute/metadata v0.2.3 // indirect cloud.google.com/go/iam v1.1.1 // indirect cloud.google.com/go/storage v1.30.1 // indirect @@ -51,6 +52,8 @@ require ( github.com/DataDog/zstd v1.5.2 // indirect github.com/Microsoft/go-winio v0.6.1 // indirect github.com/ProtonMail/go-crypto v0.0.0-20221026131551-cf6655e29de4 // indirect + github.com/StackExchange/wmi v0.0.0-20180116203802-5d049714c4a6 // indirect + github.com/VictoriaMetrics/fastcache v1.6.0 // indirect github.com/acomagu/bufpipe v1.0.3 // indirect github.com/andybalholm/brotli v1.0.4 // indirect github.com/apache/arrow/go/v12 v12.0.0 // indirect @@ -90,6 +93,7 @@ require ( github.com/cskr/pubsub v1.0.2 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/davidlazar/go-crypto 
v0.0.0-20200604182044-b73af7476f6c // indirect + github.com/deckarep/golang-set/v2 v2.1.0 // indirect github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0 // indirect github.com/dgraph-io/ristretto v0.1.0 // indirect github.com/dgryski/go-farm v0.0.0-20200201041132-a6ae2369ad13 // indirect @@ -127,21 +131,22 @@ require ( github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-playground/validator/v10 v10.14.1 // indirect github.com/go-redis/redis/v8 v8.11.5 // indirect + github.com/go-stack/stack v1.8.1 // indirect github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect github.com/goccy/go-json v0.9.11 // indirect github.com/godbus/dbus/v5 v5.1.0 // indirect + github.com/gofrs/flock v0.8.1 // indirect github.com/gogo/protobuf v1.3.2 // indirect - github.com/golang/glog v1.1.0 // indirect + github.com/golang/glog v1.1.2 // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/golang/mock v1.6.0 // indirect github.com/golang/protobuf v1.5.3 // indirect github.com/golang/snappy v0.0.5-0.20220116011046-fa5810519dcb // indirect github.com/google/flatbuffers v2.0.8+incompatible // indirect - github.com/google/go-cmp v0.5.9 // indirect github.com/google/gopacket v1.1.19 // indirect github.com/google/pprof v0.0.0-20230602150820-91b7bce49751 // indirect github.com/google/s2a-go v0.1.4 // indirect - github.com/google/uuid v1.3.0 // indirect + github.com/google/uuid v1.3.1 // indirect github.com/googleapis/enterprise-certificate-proxy v0.2.3 // indirect github.com/googleapis/gax-go/v2 v2.11.0 // indirect github.com/gorilla/mux v1.8.0 // indirect @@ -154,6 +159,7 @@ require ( github.com/hashicorp/golang-lru v0.5.5-0.20210104140557-80c98217689d // indirect github.com/hashicorp/golang-lru/v2 v2.0.2 // indirect github.com/hashicorp/hcl v1.0.0 // indirect + github.com/holiman/bloomfilter/v2 v2.0.3 // indirect github.com/holiman/uint256 v1.2.2-0.20230321075855-87b91420868c // indirect github.com/huin/goupnp v1.2.0 // indirect github.com/imdario/mergo v0.3.13 // indirect @@ -237,6 +243,7 @@ require ( github.com/multiformats/go-multihash v0.2.3 // indirect github.com/multiformats/go-multistream v0.4.1 // indirect github.com/multiformats/go-varint v0.0.7 // indirect + github.com/olekukonko/tablewriter v0.0.5 // indirect github.com/onflow/atree v0.6.0 // indirect github.com/onflow/flow-ft/lib/go/contracts v0.7.1-0.20230711213910-baad011d2b13 // indirect github.com/onflow/flow-nft/lib/go/contracts v1.1.0 // indirect @@ -274,6 +281,7 @@ require ( github.com/schollz/progressbar/v3 v3.13.1 // indirect github.com/sergi/go-diff v1.1.0 // indirect github.com/sethvargo/go-retry v0.2.3 // indirect + github.com/shirou/gopsutil v3.21.4-0.20210419000835-c7a38de76ee5+incompatible // indirect github.com/shirou/gopsutil/v3 v3.22.2 // indirect github.com/sirupsen/logrus v1.9.2 // indirect github.com/skeema/knownhosts v1.1.0 // indirect @@ -288,6 +296,7 @@ require ( github.com/spf13/viper v1.15.0 // indirect github.com/stretchr/objx v0.5.0 // indirect github.com/subosito/gotenv v1.4.2 // indirect + github.com/syndtr/goleveldb v1.0.1-0.20210819022825-2ae1ddf74ef7 // indirect github.com/texttheater/golang-levenshtein/levenshtein v0.0.0-20200805054039-cae8b0eaed6c // indirect github.com/tklauser/go-sysconf v0.3.9 // indirect github.com/tklauser/numcpus v0.3.0 // indirect @@ -317,8 +326,8 @@ require ( go.uber.org/zap v1.24.0 // indirect golang.org/x/crypto v0.12.0 // indirect golang.org/x/mod v0.10.0 // indirect - golang.org/x/net v0.12.0 // 
indirect - golang.org/x/oauth2 v0.10.0 // indirect + golang.org/x/net v0.14.0 // indirect + golang.org/x/oauth2 v0.11.0 // indirect golang.org/x/sys v0.12.0 // indirect golang.org/x/term v0.11.0 // indirect golang.org/x/text v0.12.0 // indirect @@ -328,11 +337,12 @@ require ( gonum.org/v1/gonum v0.13.0 // indirect google.golang.org/api v0.126.0 // indirect google.golang.org/appengine v1.6.7 // indirect - google.golang.org/genproto v0.0.0-20230711160842-782d3b101e98 // indirect - google.golang.org/genproto/googleapis/api v0.0.0-20230711160842-782d3b101e98 // indirect - google.golang.org/genproto/googleapis/rpc v0.0.0-20230711160842-782d3b101e98 // indirect + google.golang.org/genproto v0.0.0-20230822172742-b8732ec3820d // indirect + google.golang.org/genproto/googleapis/api v0.0.0-20230822172742-b8732ec3820d // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20230822172742-b8732ec3820d // indirect google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.2.0 // indirect gopkg.in/ini.v1 v1.67.0 // indirect + gopkg.in/natefinch/npipe.v2 v2.0.0-20160621034901-c1b8fa8bdcce // indirect gopkg.in/warnings.v0 v0.1.2 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect lukechampine.com/blake3 v1.2.1 // indirect diff --git a/integration/go.sum b/integration/go.sum index fed603cf410..7f8cf884913 100644 --- a/integration/go.sum +++ b/integration/go.sum @@ -29,21 +29,21 @@ cloud.google.com/go v0.90.0/go.mod h1:kRX0mNRHe0e2rC6oNakvwQqzyDmg57xJ+SZU1eT2aD cloud.google.com/go v0.93.3/go.mod h1:8utlLll2EF5XMAV15woO4lSbWQlk8rer9aLOfLh7+YI= cloud.google.com/go v0.94.1/go.mod h1:qAlAugsXlC+JWO+Bke5vCtc9ONxjQT3drlTTnAplMW4= cloud.google.com/go v0.97.0/go.mod h1:GF7l59pYBVlXQIBLx3a761cZ41F9bBH3JUlihCt2Udc= -cloud.google.com/go v0.110.4 h1:1JYyxKMN9hd5dR2MYTPWkGUgcoxVVhg0LKNKEo0qvmk= -cloud.google.com/go v0.110.4/go.mod h1:+EYjdK8e5RME/VY/qLCAtuyALQ9q67dvuum8i+H5xsI= +cloud.google.com/go v0.110.7 h1:rJyC7nWRg2jWGZ4wSJ5nY65GTdYJkg0cd/uXb+ACI6o= +cloud.google.com/go v0.110.7/go.mod h1:+EYjdK8e5RME/VY/qLCAtuyALQ9q67dvuum8i+H5xsI= cloud.google.com/go/bigquery v1.0.1/go.mod h1:i/xbL2UlR5RvWAURpBYZTtm/cXjCha9lbfbpx4poX+o= cloud.google.com/go/bigquery v1.3.0/go.mod h1:PjpwJnslEMmckchkHFfq+HTD2DmtT67aNFKH1/VBDHE= cloud.google.com/go/bigquery v1.4.0/go.mod h1:S8dzgnTigyfTmLBfrtrhyYhwRxG72rYxvftPBK2Dvzc= cloud.google.com/go/bigquery v1.5.0/go.mod h1:snEHRnqQbz117VIFhE8bmtwIDY80NLUZUMb4Nv6dBIg= cloud.google.com/go/bigquery v1.7.0/go.mod h1://okPTzCYNXSlb24MZs83e2Do+h+VXtc4gLoIoXIAPc= cloud.google.com/go/bigquery v1.8.0/go.mod h1:J5hqkt3O0uAFnINi6JXValWIb1v0goeZM77hZzJN/fQ= -cloud.google.com/go/bigquery v1.52.0 h1:JKLNdxI0N+TIUWD6t9KN646X27N5dQWq9dZbbTWZ8hc= -cloud.google.com/go/bigquery v1.52.0/go.mod h1:3b/iXjRQGU4nKa87cXeg6/gogLjO8C6PmuM8i5Bi/u4= -cloud.google.com/go/compute v1.21.0 h1:JNBsyXVoOoNJtTQcnEY5uYpZIbeCTYIeDe0Xh1bySMk= -cloud.google.com/go/compute v1.21.0/go.mod h1:4tCnrn48xsqlwSAiLf1HXMQk8CONslYbdiEZc9FEIbM= +cloud.google.com/go/bigquery v1.53.0 h1:K3wLbjbnSlxhuG5q4pntHv5AEbQM1QqHKGYgwFIqOTg= +cloud.google.com/go/bigquery v1.53.0/go.mod h1:3b/iXjRQGU4nKa87cXeg6/gogLjO8C6PmuM8i5Bi/u4= +cloud.google.com/go/compute v1.23.0 h1:tP41Zoavr8ptEqaW6j+LQOnyBBhO7OkOMAGrgLopTwY= +cloud.google.com/go/compute v1.23.0/go.mod h1:4tCnrn48xsqlwSAiLf1HXMQk8CONslYbdiEZc9FEIbM= cloud.google.com/go/compute/metadata v0.2.3 h1:mg4jlk7mCAj6xXp9UJ4fjI9VUI5rubuGBW5aJ7UnBMY= cloud.google.com/go/compute/metadata v0.2.3/go.mod h1:VAV5nSsACxMJvgaAuX6Pk2AawlZn8kiOGuCv6gTkwuA= -cloud.google.com/go/datacatalog v1.14.1 
h1:cFPBt8V5V2T3mu/96tc4nhcMB+5cYcpwjBfn79bZDI8= +cloud.google.com/go/datacatalog v1.16.0 h1:qVeQcw1Cz93/cGu2E7TYUPh8Lz5dn5Ws2siIuQ17Vng= cloud.google.com/go/datastore v1.0.0/go.mod h1:LXYbyblFSglQ5pkeyhO+Qmw7ukd3C+pD7TKLgZqpHYE= cloud.google.com/go/datastore v1.1.0/go.mod h1:umbIZjpQpHh4hmRpGhH4tLFup+FVzqBi1b3c64qFpCk= cloud.google.com/go/iam v1.1.1 h1:lW7fzj15aVIXYHREOqjRBV9PsH0Z6u8Y46a1YGvQP4Y= @@ -109,8 +109,11 @@ github.com/ProtonMail/go-crypto v0.0.0-20221026131551-cf6655e29de4/go.mod h1:UBY github.com/Shopify/goreferrer v0.0.0-20181106222321-ec9c9a553398/go.mod h1:a1uqRtAwp2Xwc6WNPJEufxJ7fx3npB4UV/JOLmbu5I0= github.com/Shopify/sarama v1.19.0/go.mod h1:FVkBWblsNy7DGZRfXLU0O9RCGt5g3g3yEuWXgklEdEo= github.com/Shopify/toxiproxy v2.1.4+incompatible/go.mod h1:OXgGpZ6Cli1/URJOF1DMxUHB2q5Ap20/P/eIdh4G0pI= +github.com/StackExchange/wmi v0.0.0-20180116203802-5d049714c4a6 h1:fLjPD/aNc3UIOA6tDi6QXUemppXK3P9BI7mr2hd6gx8= github.com/StackExchange/wmi v0.0.0-20180116203802-5d049714c4a6/go.mod h1:3eOhrUMpNV+6aFIbp5/iudMxNCF27Vw2OZgy4xEx0Fg= github.com/VictoriaMetrics/fastcache v1.5.3/go.mod h1:+jv9Ckb+za/P1ZRg/sulP5Ni1v49daAVERr0H3CuscE= +github.com/VictoriaMetrics/fastcache v1.6.0 h1:C/3Oi3EiBCqufydp1neRZkqcwmEiuRT9c3fqvvgKm5o= +github.com/VictoriaMetrics/fastcache v1.6.0/go.mod h1:0qHz5QP0GMX4pfmMA/zt5RgfNuXJrTP0zS7DqpHGGTw= github.com/VividCortex/ewma v1.2.0 h1:f58SaIzcDXrSy3kWaHNvuJgJ3Nmz59Zji6XoJR/q1ow= github.com/VividCortex/ewma v1.2.0/go.mod h1:nz4BbCtbLyFDeC9SUHbtcT5644juEuWfUAUnGx7j5l4= github.com/VividCortex/gohistogram v1.0.0/go.mod h1:Pf5mBqqDxYaXu3hDrrU+w6nw50o/4+TcAqDqk/vUH7g= @@ -125,6 +128,7 @@ github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuy github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d/go.mod h1:rBZYJk541a8SKzHPHnH3zbiI+7dagKZ0cgpgrD7Fyho= +github.com/allegro/bigcache v1.2.1-0.20190218064605-e24eb225f156 h1:eMwmnE/GDgah4HI848JfFxHt+iPb26b4zyfspmqY0/8= github.com/allegro/bigcache v1.2.1-0.20190218064605-e24eb225f156/go.mod h1:Cb/ax3seSYIx7SuZdm2G2xzfwmv3TPSk2ucNfQESPXM= github.com/andybalholm/brotli v1.0.4 h1:V7DdXeJtZscaqfNuAdSRuRFzuiKlHSC/Zh3zl9qY3JY= github.com/andybalholm/brotli v1.0.4/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= @@ -212,6 +216,8 @@ github.com/btcsuite/btcd v0.20.1-beta/go.mod h1:wVuoA8VJLEcwgqHBwHmzLRazpKxTv13P github.com/btcsuite/btcd v0.21.0-beta/go.mod h1:ZSWyehm27aAuS9bvkATT+Xte3hjHZ+MRgMY/8NJ7K94= github.com/btcsuite/btcd/btcec/v2 v2.2.1 h1:xP60mv8fvp+0khmrN0zTdPC3cNm24rfeE6lh2R/Yv3E= github.com/btcsuite/btcd/btcec/v2 v2.2.1/go.mod h1:9/CSmJxmuvqzX9Wh2fXMWToLOHhPd11lSPuIupwTkI8= +github.com/btcsuite/btcd/chaincfg/chainhash v1.0.2 h1:KdUfX2zKommPRa+PD0sWZUyXe9w277ABlgELO7H04IM= +github.com/btcsuite/btcd/chaincfg/chainhash v1.0.2/go.mod h1:7SFka0XMvUgj3hfZtydOrQY2mwhPclbT2snogU7SQQc= github.com/btcsuite/btclog v0.0.0-20170628155309-84c8d2346e9f/go.mod h1:TdznJufoqS23FtqVCzL0ZqgP5MqXbb4fg/WgDys70nA= github.com/btcsuite/btcutil v0.0.0-20190207003914-4c204d697803/go.mod h1:+5NJ2+qvTyV9exUAL/rxXi3DcLg2Ts+ymUAY5y4NvMg= github.com/btcsuite/btcutil v0.0.0-20190425235716-9e5f4b9a998d/go.mod h1:+5NJ2+qvTyV9exUAL/rxXi3DcLg2Ts+ymUAY5y4NvMg= @@ -313,6 +319,8 @@ github.com/davidlazar/go-crypto v0.0.0-20170701192655-dcfb0a7ac018/go.mod h1:rQY github.com/davidlazar/go-crypto 
v0.0.0-20200604182044-b73af7476f6c h1:pFUpOrbxDR6AkioZ1ySsx5yxlDQZ8stG2b88gTPxgJU= github.com/davidlazar/go-crypto v0.0.0-20200604182044-b73af7476f6c/go.mod h1:6UhI8N9EjYm1c2odKpFpAYeR8dsBeM7PtzQhRgxRr9U= github.com/deckarep/golang-set v0.0.0-20180603214616-504e848d77ea/go.mod h1:93vsz/8Wt4joVM7c2AVqh+YRMiUSc14yDtF28KmMOgQ= +github.com/deckarep/golang-set/v2 v2.1.0 h1:g47V4Or+DUdzbs8FxCCmgb6VYd+ptPAngjM6dtGktsI= +github.com/deckarep/golang-set/v2 v2.1.0/go.mod h1:VAky9rY/yGXJOLEDv3OMci+7wtDpOF4IN+y82NBOac4= github.com/decred/dcrd/crypto/blake256 v1.0.1 h1:7PltbUIQB7u/FfZ39+DGa/ShuMyJ5ilcvdfma9wOH6Y= github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0 h1:8UrgZ3GkP4i/CLijOJx79Yu+etlyjdBU4sfcs2WYQMs= github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0/go.mod h1:v57UDF4pDQJcEfFUCRop3lJL149eHGSe9Jvczhzjo/0= @@ -479,6 +487,8 @@ github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo= github.com/go-sql-driver/mysql v1.4.0/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= +github.com/go-stack/stack v1.8.1 h1:ntEHSVwIt7PNXNpgPmVfMrNhLtgjlmnZha2kOpuRiDw= +github.com/go-stack/stack v1.8.1/go.mod h1:dcoOX6HbPZSZptuspn9bctJ+N/CnF5gGygcUP3XYfe4= github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI= github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls= github.com/go-test/deep v1.0.5/go.mod h1:QV8Hv/iy04NyLBxAdO9njL0iVPN1S4d/A3NVv1V36o8= @@ -493,6 +503,7 @@ github.com/godbus/dbus/v5 v5.0.3/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5x github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/godbus/dbus/v5 v5.1.0 h1:4KLkAxT3aOY8Li4FRJe/KvhoNFFxo0m6fNuFUO8QJUk= github.com/godbus/dbus/v5 v5.1.0/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= +github.com/gofrs/flock v0.8.1 h1:+gYjHKf32LDeiEEFhQaotPbLuUXjY5ZqxKgXy7n59aw= github.com/gofrs/flock v0.8.1/go.mod h1:F1TvTiK9OcQqauNUHlbJvyl9Qa1QvF/gOUDKA14jxHU= github.com/gogo/googleapis v0.0.0-20180223154316-0cd9801be74a/go.mod h1:gf4bu3Q80BeJ6H1S1vYPm8/ELATdvryBaNFGgqEef3s= github.com/gogo/googleapis v1.1.0/go.mod h1:gf4bu3Q80BeJ6H1S1vYPm8/ELATdvryBaNFGgqEef3s= @@ -509,8 +520,8 @@ github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzq github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/glog v1.0.0/go.mod h1:EWib/APOK0SL3dFbYqvxE3UYd8E6s1ouQ7iEp/0LWV4= -github.com/golang/glog v1.1.0 h1:/d3pCKDPWNnvIWe0vVUpNP32qc8U3PDVxySP/y360qE= -github.com/golang/glog v1.1.0/go.mod h1:pfYeQZ3JWZoXTV5sFc986z3HTpwQs9At6P4ImfuP3NQ= +github.com/golang/glog v1.1.2 h1:DVjP2PbBOzHyzA+dn3WhHIq4NdVu3Q+pvivFICf/7fo= +github.com/golang/glog v1.1.2/go.mod h1:zR+okUeTbrL6EL3xHUDxZuEtGv04p5shwip1+mL/rLQ= github.com/golang/groupcache v0.0.0-20160516000752-02826c3e7903/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20190129154638-5b532d6fd5ef/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= @@ -554,6 +565,7 @@ github.com/golang/protobuf 
v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiu github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/golang/snappy v0.0.5-0.20220116011046-fa5810519dcb h1:PBC98N2aIaM3XXiurYmW7fx4GZkL8feAMVq7nEjURHk= github.com/golang/snappy v0.0.5-0.20220116011046-fa5810519dcb/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/gomodule/redigo v1.7.1-0.20190724094224-574c33c3df38/go.mod h1:B4C85qUVwatsJoIUNIfCRsp7qO0iAmpGFZ4EELWSbC4= @@ -613,8 +625,9 @@ github.com/google/uuid v1.0.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+ github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.2.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/uuid v1.3.1 h1:KjJaJ9iWZ3jOFZIf1Lqf4laDRCasjl0BCmnEGxkdLb4= +github.com/google/uuid v1.3.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/googleapis/enterprise-certificate-proxy v0.2.3 h1:yk9/cqRKtT9wXZSsRH9aurXEpJX+U6FLtpYTdC3R06k= github.com/googleapis/enterprise-certificate-proxy v0.2.3/go.mod h1:AwSRAtLfXpU5Nm3pW+v7rGDHp09LsPtGY9MduiEsR9k= github.com/googleapis/gax-go v2.0.0+incompatible/go.mod h1:SFVmujtThgffbyetf+mdk2eWhX2bMyUtNHzFKcPA9HY= @@ -692,6 +705,8 @@ github.com/hashicorp/logutils v1.0.0/go.mod h1:QIAnNjmIWmVIIkWDTG1z5v++HQmx9WQRO github.com/hashicorp/mdns v1.0.0/go.mod h1:tL+uN++7HEJ6SQLQ2/p+z2pH24WQKWjBPkE0mNTz8vQ= github.com/hashicorp/memberlist v0.1.3/go.mod h1:ajVTdAv/9Im8oMAAj5G31PhhMCZJV2pPBoIllUwCN7I= github.com/hashicorp/serf v0.8.2/go.mod h1:6hOLApaqBFA1NXqRQAsxw9QxuDEvNxSQRwA/JwenrHc= +github.com/holiman/bloomfilter/v2 v2.0.3 h1:73e0e/V0tCydx14a0SCYS/EWCxgwLZ18CZcZKVu0fao= +github.com/holiman/bloomfilter/v2 v2.0.3/go.mod h1:zpoh+gs7qcpqrHr3dB55AMiJwo0iURXE7ZOP9L9hSkA= github.com/holiman/uint256 v1.2.2-0.20230321075855-87b91420868c h1:DZfsyhDK1hnSS5lH8l+JggqzEleHteTYfutAiVlSUM8= github.com/holiman/uint256 v1.2.2-0.20230321075855-87b91420868c/go.mod h1:SC8Ryt4n+UBbPbIBKaG9zbbDlp4jOru9xFZmPzLUTxw= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= @@ -938,6 +953,7 @@ github.com/kr/pty v1.1.3/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/labstack/echo/v4 v4.5.0/go.mod h1:czIriw4a0C1dFun+ObrXp7ok03xON0N1awStJ6ArI7Y= github.com/labstack/gommon v0.3.0/go.mod h1:MULnywXg0yavhxWKc+lOruYdAhDwPK9wf0OL7NoOu+k= @@ -1370,31 +1386,33 @@ github.com/oklog/ulid v1.3.1/go.mod h1:CirwcVhetQ6Lv90oh/F+FBtV6XMibvdAFo93nm5qn github.com/olekukonko/tablewriter v0.0.0-20170122224234-a0225b3f23b5/go.mod 
h1:vsDQFd/mU46D+Z4whnwzcISnGGzXWMclvtLoiIKAKIo= github.com/olekukonko/tablewriter v0.0.1/go.mod h1:vsDQFd/mU46D+Z4whnwzcISnGGzXWMclvtLoiIKAKIo= github.com/olekukonko/tablewriter v0.0.2-0.20190409134802-7e037d187b0c/go.mod h1:vsDQFd/mU46D+Z4whnwzcISnGGzXWMclvtLoiIKAKIo= +github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec= +github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY= github.com/onflow/atree v0.1.0-beta1.0.20211027184039-559ee654ece9/go.mod h1:+6x071HgCF/0v5hQcaE5qqjc2UqN5gCU8h5Mk6uqpOg= github.com/onflow/atree v0.6.0 h1:j7nQ2r8npznx4NX39zPpBYHmdy45f4xwoi+dm37Jk7c= github.com/onflow/atree v0.6.0/go.mod h1:gBHU0M05qCbv9NN0kijLWMgC47gHVNBIp4KmsVFi0tc= github.com/onflow/cadence v0.20.1/go.mod h1:7mzUvPZUIJztIbr9eTvs+fQjWWHTF8veC+yk4ihcNIA= -github.com/onflow/cadence v0.42.4 h1:KoXnwPCMcjixZv+gHZwWkDiAyVExJhJJe6SebdnHNv8= -github.com/onflow/cadence v0.42.4/go.mod h1:raU8va8QRyTa/eUbhej4mbyW2ETePfSaywoo36MddgE= +github.com/onflow/cadence v0.42.5 h1:QCilotmJzfRToLd+02o3N62JIioSr8FfN7cujmR/IXQ= +github.com/onflow/cadence v0.42.5/go.mod h1:raU8va8QRyTa/eUbhej4mbyW2ETePfSaywoo36MddgE= github.com/onflow/flow-core-contracts/lib/go/contracts v1.2.4-0.20231120143830-9e8417b56122 h1:yUzR59WUue8BN/bfwy0eN4YOLYXeQ3G9I53H0amxbDU= github.com/onflow/flow-core-contracts/lib/go/contracts v1.2.4-0.20231120143830-9e8417b56122/go.mod h1:jM6GMAL+m0hjusUgiYDNrixPQ6b9s8xjoJQoEu5bHQI= github.com/onflow/flow-core-contracts/lib/go/templates v1.2.4-0.20231111185227-240579784e9b h1:Q9iCekuCTeZU3CkVRTj5BhMBY/vR/uA2K63JTl5vCD8= github.com/onflow/flow-core-contracts/lib/go/templates v1.2.4-0.20231111185227-240579784e9b/go.mod h1:ZeLxwaBkzuSInESGjL8/IPZWezF+YOYsYbMrZlhN+q4= -github.com/onflow/flow-emulator v0.54.1-0.20231024204057-0273f8fe3807 h1:/4jZ2oELdhKubgL97NGqhiuO80oMH/M+fIQoNPfGg+g= -github.com/onflow/flow-emulator v0.54.1-0.20231024204057-0273f8fe3807/go.mod h1:Qq1YmTDYlfpzfuzrFH8gwMgzzv80LCKFiS1Kqm8vFcY= +github.com/onflow/flow-emulator v0.54.1-0.20231110220143-28061d9b37e7 h1:rMqbd3ZYtemFCdP/QvLbHloQn4+ZOakQrPMZdojvKpE= +github.com/onflow/flow-emulator v0.54.1-0.20231110220143-28061d9b37e7/go.mod h1:AaHzJfv3jIanaty6RnSf0oEmvokSErtUsW8xotR2z2I= github.com/onflow/flow-ft/lib/go/contracts v0.7.1-0.20230711213910-baad011d2b13 h1:B4ll7e3j+MqTJv2122Enq3RtDNzmIGRu9xjV7fo7un0= github.com/onflow/flow-ft/lib/go/contracts v0.7.1-0.20230711213910-baad011d2b13/go.mod h1:kTMFIySzEJJeupk+7EmXs0EJ6CBWY/MV9fv9iYQk+RU= github.com/onflow/flow-go-sdk v0.24.0/go.mod h1:IoptMLPyFXWvyd9yYA6/4EmSeeozl6nJoIv4FaEMg74= -github.com/onflow/flow-go-sdk v0.41.14 h1:Pe9hogrRWkICa34Gu8yLLCMlTnu5f3L7DNXeXREkkVw= -github.com/onflow/flow-go-sdk v0.41.14/go.mod h1:yOyiwL7bt8N6FpLAyNO2D46vnFlLnIu9Cyt5PKIgqAk= +github.com/onflow/flow-go-sdk v0.41.16 h1:HsmHwEVmj+iK+GszHbFseHh7Ii5W3PWOIRNAH/En08Q= +github.com/onflow/flow-go-sdk v0.41.16/go.mod h1:bVrVNoJKiwB6vW5Qbm5tFAfJBQ5we4uSQWnn9gNAFhQ= github.com/onflow/flow-go/crypto v0.21.3/go.mod h1:vI6V4CY3R6c4JKBxdcRiR/AnjBfL8OSD97bJc60cLuQ= github.com/onflow/flow-go/crypto v0.24.9 h1:0EQp+kSZYJepMIiSypfJVe7tzsPcb6UXOdOtsTCDhBs= github.com/onflow/flow-go/crypto v0.24.9/go.mod h1:fqCzkIBBMRRkciVrvW21rECKq1oD7Q6u+bCI78lfNX0= github.com/onflow/flow-nft/lib/go/contracts v1.1.0 h1:rhUDeD27jhLwOqQKI/23008CYfnqXErrJvc4EFRP2a0= github.com/onflow/flow-nft/lib/go/contracts v1.1.0/go.mod h1:YsvzYng4htDgRB9sa9jxdwoTuuhjK8WYWXTyLkIigZY= github.com/onflow/flow/protobuf/go/flow v0.2.2/go.mod h1:gQxYqCfkI8lpnKsmIjwtN2mV/N2PIwc1I+RUK4HPIc8= 
-github.com/onflow/flow/protobuf/go/flow v0.3.2-0.20231018182244-e72527c55c63 h1:SX8OhYbyKBExhy4qEDR/Hw6MVTBTzlDb8LfCHfFyte4= -github.com/onflow/flow/protobuf/go/flow v0.3.2-0.20231018182244-e72527c55c63/go.mod h1:NA2pX2nw8zuaxfKphhKsk00kWLwfd+tv8mS23YXO4Sk= +github.com/onflow/flow/protobuf/go/flow v0.3.2-0.20231121210617-52ee94b830c2 h1:qZjl4wSTG/E9znEjkHF0nNaEdlBLJoOEAtr7xUsTNqc= +github.com/onflow/flow/protobuf/go/flow v0.3.2-0.20231121210617-52ee94b830c2/go.mod h1:NA2pX2nw8zuaxfKphhKsk00kWLwfd+tv8mS23YXO4Sk= github.com/onflow/go-bitswap v0.0.0-20230703214630-6d3db958c73d h1:QcOAeEyF3iAUHv21LQ12sdcsr0yFrJGoGLyCAzYYtvI= github.com/onflow/go-bitswap v0.0.0-20230703214630-6d3db958c73d/go.mod h1:GCPpiyRoHncdqPj++zPr9ZOYBX4hpJ0pYZRYqSE8VKk= github.com/onflow/nft-storefront/lib/go/contracts v0.0.0-20221222181731-14b90207cead h1:2j1Unqs76Z1b95Gu4C3Y28hzNUHBix7wL490e61SMSw= @@ -1578,6 +1596,8 @@ github.com/sergi/go-diff v1.1.0 h1:we8PVUC3FE2uYfodKH/nBHMSetSfHDR6scGdBi+erh0= github.com/sergi/go-diff v1.1.0/go.mod h1:STckp+ISIX8hZLjrqAeVduY0gWCT9IjLuqbuNXdaHfM= github.com/sethvargo/go-retry v0.2.3 h1:oYlgvIvsju3jNbottWABtbnoLC+GDtLdBHxKWxQm/iU= github.com/sethvargo/go-retry v0.2.3/go.mod h1:1afjQuvh7s4gflMObvjLPaWgluLLyhA1wmVZ6KLpICw= +github.com/shirou/gopsutil v3.21.4-0.20210419000835-c7a38de76ee5+incompatible h1:Bn1aCHHRnjv4Bl16T8rcaFjYSrGrIZvpiGO6P3Q4GpU= +github.com/shirou/gopsutil v3.21.4-0.20210419000835-c7a38de76ee5+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA= github.com/shirou/gopsutil/v3 v3.22.2 h1:wCrArWFkHYIdDxx/FSfF5RB4dpJYW6t7rcp3+zL8uks= github.com/shirou/gopsutil/v3 v3.22.2/go.mod h1:WapW1AOOPlHyXr+yOyw3uYx36enocrtSoSBy0L5vUHY= github.com/shurcooL/component v0.0.0-20170202220835-f88ec8f54cc4/go.mod h1:XhFIlyj5a1fBNx5aJTbKoIq0mNaPvOagO+HjB3EtxrY= @@ -1685,6 +1705,8 @@ github.com/supranational/blst v0.3.4/go.mod h1:jZJtfjgudtNl4en1tzwPIV3KjUnQUvG3/ github.com/supranational/blst v0.3.11-0.20230406105308-e9dfc5ee724b h1:u49mjRnygnB34h8OKbnNJFVUtWSKIKb1KukdV8bILUM= github.com/syndtr/goleveldb v1.0.0/go.mod h1:ZVVdQEZoIme9iO1Ch2Jdy24qqXrMMOU6lpPAyBWyWuQ= github.com/syndtr/goleveldb v1.0.1-0.20190923125748-758128399b1d/go.mod h1:9OrXJhf154huy1nPWmuSrkgjPUtUNhA+Zmy+6AESzuA= +github.com/syndtr/goleveldb v1.0.1-0.20210819022825-2ae1ddf74ef7 h1:epCh84lMvA70Z7CTTCmYQn2CKbY8j86K7/FAIr141uY= +github.com/syndtr/goleveldb v1.0.1-0.20210819022825-2ae1ddf74ef7/go.mod h1:q4W45IWZaF22tdD+VEXcAWRA037jwmWEB5VWYORlTpc= github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07/go.mod h1:kDXzergiv9cbyO7IOYJZWg1U88JhDg3PB6klq9Hg2pA= github.com/texttheater/golang-levenshtein/levenshtein v0.0.0-20200805054039-cae8b0eaed6c h1:HelZ2kAFadG0La9d+4htN4HzQ68Bm2iM9qKMSMES6xg= github.com/texttheater/golang-levenshtein/levenshtein v0.0.0-20200805054039-cae8b0eaed6c/go.mod h1:JlzghshsemAMDGZLytTFY8C1JQxQPhnatWqNwUXjggo= @@ -1976,6 +1998,7 @@ golang.org/x/net v0.0.0-20200520004742-59133d7f0dd7/go.mod h1:qpuaurCH72eLCgpAm/ golang.org/x/net v0.0.0-20200520182314-0ba52f642ac2/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20200707034311-ab3426394381/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= +golang.org/x/net v0.0.0-20200813134508-3edf25e44fcc/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net 
v0.0.0-20201010224723-4f7140c49acb/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= @@ -1998,8 +2021,8 @@ golang.org/x/net v0.2.0/go.mod h1:KqCZLdyyvdV855qA2rE3GC2aiw5xGR5TEjj8smXukLY= golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= -golang.org/x/net v0.12.0 h1:cfawfvKITfUsFCeJIHJrbSxpeu/E81khclypR0GVT50= -golang.org/x/net v0.12.0/go.mod h1:zEVYFnQC7m/vmpQFELhcD1EWkZlX69l4oqgmer6hfKA= +golang.org/x/net v0.14.0 h1:BONx9s002vGdD9umnlX1Po8vOZmrgH34qlHcD1MfK14= +golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20181017192945-9dcd33a902f4/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= @@ -2018,8 +2041,8 @@ golang.org/x/oauth2 v0.0.0-20210628180205-a41e5a781914/go.mod h1:KelEdhl1UZF7XfJ golang.org/x/oauth2 v0.0.0-20210805134026-6f1e6394065a/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= golang.org/x/oauth2 v0.0.0-20210819190943-2bc19b11175f/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= golang.org/x/oauth2 v0.0.0-20211104180415-d3ed0bb246c8/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= -golang.org/x/oauth2 v0.10.0 h1:zHCpF2Khkwy4mMB4bv0U37YtJdTGW8jI0glAApi0Kh8= -golang.org/x/oauth2 v0.10.0/go.mod h1:kTpgurOux7LqtuxjuyZa4Gj2gdezIt/jQtGnNFfypQI= +golang.org/x/oauth2 v0.11.0 h1:vPL4xzxBM4niKCW6g9whtaWVXTJf1U5e4aZxxFx/gbU= +golang.org/x/oauth2 v0.11.0/go.mod h1:LdF7O/8bLR/qWK9DrpXmbHLTouvRHK0SgJl0GmDBchk= golang.org/x/perf v0.0.0-20180704124530-6e6d33e29852/go.mod h1:JLpeXjPJfIyPr5TlbXLkXWLhP8nz10XfvxElABhCtcw= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -2101,6 +2124,7 @@ golang.org/x/sys v0.0.0-20200602225109-6fdc65e7d980/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200625212154-ddb9806d33ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200803210538-64077c9b5642/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200814200057-3d37ad5750ed/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200905004654-be1d3432aa8f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200909081042-eff7692f9009/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200918174421-af09f7315aff/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -2122,6 +2146,7 @@ golang.org/x/sys v0.0.0-20210309074719-68d13333faf2/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210315160823-c6e025ad8005/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210317225723-c4fcb01b228e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210320140829-1e4c9ba3b0c4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys 
v0.0.0-20210324051608-47abb6519492/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210403161142-5e06dd20ab57/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -2398,13 +2423,13 @@ google.golang.org/genproto v0.0.0-20210921142501-181ce0d877f6/go.mod h1:5CzLGKJ6 google.golang.org/genproto v0.0.0-20210924002016-3dee208752a0/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc= google.golang.org/genproto v0.0.0-20211007155348-82e027067bd4/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc= google.golang.org/genproto v0.0.0-20211118181313-81c1377c94b1/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc= -google.golang.org/genproto v0.0.0-20230711160842-782d3b101e98 h1:Z0hjGZePRE0ZBWotvtrwxFNrNE9CUAGtplaDK5NNI/g= -google.golang.org/genproto v0.0.0-20230711160842-782d3b101e98/go.mod h1:S7mY02OqCJTD0E1OiQy1F72PWFB4bZJ87cAtLPYgDR0= -google.golang.org/genproto/googleapis/api v0.0.0-20230711160842-782d3b101e98 h1:FmF5cCW94Ij59cfpoLiwTgodWmm60eEV0CjlsVg2fuw= -google.golang.org/genproto/googleapis/api v0.0.0-20230711160842-782d3b101e98/go.mod h1:rsr7RhLuwsDKL7RmgDDCUc6yaGr1iqceVb5Wv6f6YvQ= +google.golang.org/genproto v0.0.0-20230822172742-b8732ec3820d h1:VBu5YqKPv6XiJ199exd8Br+Aetz+o08F+PLMnwJQHAY= +google.golang.org/genproto v0.0.0-20230822172742-b8732ec3820d/go.mod h1:yZTlhN0tQnXo3h00fuXNCxJdLdIdnVFVBaRJ5LWBbw4= +google.golang.org/genproto/googleapis/api v0.0.0-20230822172742-b8732ec3820d h1:DoPTO70H+bcDXcd39vOqb2viZxgqeBeSGtZ55yZU4/Q= +google.golang.org/genproto/googleapis/api v0.0.0-20230822172742-b8732ec3820d/go.mod h1:KjSP20unUpOx5kyQUFa7k4OJg0qeJ7DEZflGDu2p6Bk= google.golang.org/genproto/googleapis/bytestream v0.0.0-20230530153820-e85fd2cbaebc h1:g3hIDl0jRNd9PPTs2uBzYuaD5mQuwOkZY0vSc0LR32o= -google.golang.org/genproto/googleapis/rpc v0.0.0-20230711160842-782d3b101e98 h1:bVf09lpb+OJbByTj913DRJioFFAjf/ZGxEz7MajTp2U= -google.golang.org/genproto/googleapis/rpc v0.0.0-20230711160842-782d3b101e98/go.mod h1:TUfxEVdsvPg18p6AslUXFoLdpED4oBnGwyqk3dV1XzM= +google.golang.org/genproto/googleapis/rpc v0.0.0-20230822172742-b8732ec3820d h1:uvYuEyMHKNt+lT4K3bN6fGswmK8qSvcreM3BwjDh+y4= +google.golang.org/genproto/googleapis/rpc v0.0.0-20230822172742-b8732ec3820d/go.mod h1:+Bk1OCOj40wS2hwAMA+aCW9ypzm63QTBBHp6lQ3p+9M= google.golang.org/grpc v1.12.0/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmEdcZw= google.golang.org/grpc v1.14.0/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmEdcZw= google.golang.org/grpc v1.16.0/go.mod h1:0JHn/cJsOMiMfNA9+DeHDlAU7KAAB5GDlYFpa9MZMio= @@ -2441,8 +2466,8 @@ google.golang.org/grpc v1.39.1/go.mod h1:PImNr+rS9TWYb2O4/emRugxiyHZ5JyHW5F+RPnD google.golang.org/grpc v1.40.0/go.mod h1:ogyxbiOoUXAkP+4+xa6PZSE9DZgIHtSpzjDTB9KAK34= google.golang.org/grpc v1.42.0/go.mod h1:k+4IHHFw41K8+bbowsex27ge2rCb65oeWqe4jJ590SU= google.golang.org/grpc v1.45.0/go.mod h1:lN7owxKUQEqMfSyQikvvk5tf/6zMPsrK+ONuO11+0rQ= -google.golang.org/grpc v1.58.3 h1:BjnpXut1btbtgN/6sp+brB2Kbm2LjNXnidYujAVbSoQ= -google.golang.org/grpc v1.58.3/go.mod h1:tgX3ZQDlNJGU96V6yHh1T/JeoBQ2TXdr43YbYSsCJk0= +google.golang.org/grpc v1.59.0 h1:Z5Iec2pjwb+LEOqzpB2MR12/eKFhDPhuqW91O+4bwUk= +google.golang.org/grpc v1.59.0/go.mod h1:aUPDwccQo6OTjy7Hct4AfBPD1GptF4fyUjIkQ9YtF98= google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.1.0/go.mod h1:6Kw0yEErY5E/yWrBtf03jp27GLLJujG4z/JK95pnjjw= 
google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.2.0 h1:TLkBREm4nIsEcexnCjgQd5GQWaHcqMzwQV0TX9pq8S0= google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.2.0/go.mod h1:DNq5QpG7LJqD2AamLZ7zvKE0DEpVl2BSEVjFycAAjRY= @@ -2479,6 +2504,7 @@ gopkg.in/ini.v1 v1.51.1/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA= gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= gopkg.in/mgo.v2 v2.0.0-20180705113604-9856a29383ce/go.mod h1:yeKp02qBN3iKW1OzL3MGk2IdtZzaj7SFntXj72NppTA= +gopkg.in/natefinch/npipe.v2 v2.0.0-20160621034901-c1b8fa8bdcce h1:+JknDZhAj8YMt7GC73Ei8pv4MzjDUNPHgQWJdtMAaDU= gopkg.in/natefinch/npipe.v2 v2.0.0-20160621034901-c1b8fa8bdcce/go.mod h1:5AcXVHNjg+BDxry382+8OKon8SEWiKktQR07RKPsv1c= gopkg.in/olebedev/go-duktape.v3 v3.0.0-20190213234257-ec84240a7772/go.mod h1:uAJfkITjFhyEEuUfm7bsmCZRbW5WRq8s9EY8HZ6hCns= gopkg.in/resty.v1 v1.12.0/go.mod h1:mDo4pnntr5jdWRML875a/NmxYqAlA73dVijT2AXvQQo= diff --git a/integration/utils/dependency_test.go b/integration/utils/dependency_test.go new file mode 100644 index 00000000000..8ac0fac8cc2 --- /dev/null +++ b/integration/utils/dependency_test.go @@ -0,0 +1,8 @@ +package utils + +import "github.com/btcsuite/btcd/chaincfg/chainhash" + +// this is added to resolve the issue with chainhash ambiguous import, +// the code is not used, but it's needed to force go.mod specify and retain chainhash version +// workaround for issue: https://github.com/golang/go/issues/27899 +var _ = chainhash.Hash{} diff --git a/ledger/common/convert/convert.go b/ledger/common/convert/convert.go index 6c028b1b5b2..d1c9d732570 100644 --- a/ledger/common/convert/convert.go +++ b/ledger/common/convert/convert.go @@ -61,5 +61,4 @@ func PayloadToRegister(payload *ledger.Payload) (flow.RegisterID, flow.RegisterV } return regID, payload.Value(), nil - } diff --git a/model/flow/constants.go b/model/flow/constants.go index 4f172c36528..6b03c36a6db 100644 --- a/model/flow/constants.go +++ b/model/flow/constants.go @@ -28,6 +28,10 @@ const DefaultTransactionExpiryBuffer = 30 // DefaultMaxTransactionGasLimit is the default maximum value for the transaction gas limit. const DefaultMaxTransactionGasLimit = 9999 +// EstimatedComputationPerMillisecond is the approximate number of computation units that can be performed in a millisecond. +// this was calibrated during the Variable Transaction Fees: Execution Effort FLIP https://github.com/onflow/flow/pull/753 +const EstimatedComputationPerMillisecond = 9999.0 / 200.0 + // DefaultMaxTransactionByteSize is the default maximum transaction byte size. 
(~1.5MB) const DefaultMaxTransactionByteSize = 1_500_000 diff --git a/module/execution/scripts.go b/module/execution/scripts.go index 35680b1ca84..471fee0c8a4 100644 --- a/module/execution/scripts.go +++ b/module/execution/scripts.go @@ -66,6 +66,7 @@ func NewScripts( entropy query.EntropyProviderPerBlock, header storage.Headers, registerAtHeight RegisterAtHeight, + queryConf query.QueryConfig, ) (*Scripts, error) { vm := fvm.NewVirtualMachine() @@ -80,7 +81,7 @@ func NewScripts( } queryExecutor := query.NewQueryExecutor( - query.NewDefaultConfig(), + queryConf, log, metrics, vm, diff --git a/module/execution/scripts_test.go b/module/execution/scripts_test.go index 97a63a20d1b..b73b905eb8f 100644 --- a/module/execution/scripts_test.go +++ b/module/execution/scripts_test.go @@ -15,6 +15,7 @@ import ( mocks "github.com/stretchr/testify/mock" "github.com/stretchr/testify/suite" + "github.com/onflow/flow-go/engine/execution/computation/query" "github.com/onflow/flow-go/engine/execution/computation/query/mock" "github.com/onflow/flow-go/engine/execution/testutil" "github.com/onflow/flow-go/fvm" @@ -165,6 +166,7 @@ func (s *scriptTestSuite) SetupTest() { entropyBlock, headers, index.RegisterValue, + query.NewDefaultConfig(), ) s.Require().NoError(err) s.scripts = scripts diff --git a/module/finalizedreader/finalizedreader.go b/module/finalizedreader/finalizedreader.go new file mode 100644 index 00000000000..18c44ac3705 --- /dev/null +++ b/module/finalizedreader/finalizedreader.go @@ -0,0 +1,25 @@ +package finalizedreader + +import ( + "github.com/onflow/flow-go/model/flow" + "github.com/onflow/flow-go/storage" +) + +type FinalizedReader struct { + headers storage.Headers +} + +func NewFinalizedReader(headers storage.Headers) *FinalizedReader { + return &FinalizedReader{ + headers: headers, + } +} + +func (r *FinalizedReader) FinalizedBlockIDAtHeight(height uint64) (flow.Identifier, error) { + header, err := r.headers.ByHeight(height) + if err != nil { + return flow.ZeroID, err + } + + return header.ID(), nil +} diff --git a/module/finalizedreader/finalizedreader_test.go b/module/finalizedreader/finalizedreader_test.go new file mode 100644 index 00000000000..9df17f026b5 --- /dev/null +++ b/module/finalizedreader/finalizedreader_test.go @@ -0,0 +1,44 @@ +package finalizedreader + +import ( + "errors" + "testing" + + "github.com/dgraph-io/badger/v2" + "github.com/stretchr/testify/require" + + "github.com/onflow/flow-go/module/metrics" + "github.com/onflow/flow-go/storage" + "github.com/onflow/flow-go/storage/badger/operation" + "github.com/onflow/flow-go/utils/unittest" + + badgerstorage "github.com/onflow/flow-go/storage/badger" +) + +func TestFinalizedReader(t *testing.T) { + unittest.RunWithBadgerDB(t, func(db *badger.DB) { + // prepare the storage.Headers instance + metrics := metrics.NewNoopCollector() + headers := badgerstorage.NewHeaders(metrics, db) + block := unittest.BlockFixture() + + // store header + err := headers.Store(block.Header) + require.NoError(t, err) + + // index the header + err = operation.RetryOnConflict(db.Update, operation.IndexBlockHeight(block.Header.Height, block.ID())) + require.NoError(t, err) + + // verify it is able to read the finalized block ID + reader := NewFinalizedReader(headers) + finalized, err := reader.FinalizedBlockIDAtHeight(block.Header.Height) + require.NoError(t, err) + require.Equal(t, block.ID(), finalized) + + // verify it returns storage.ErrNotFound when the height is not finalized + _, err =
reader.FinalizedBlockIDAtHeight(block.Header.Height + 1) + require.Error(t, err) + require.True(t, errors.Is(err, storage.ErrNotFound), err) + }) +} diff --git a/module/metrics/execution.go b/module/metrics/execution.go index f17692e3859..94c3e70e107 100644 --- a/module/metrics/execution.go +++ b/module/metrics/execution.go @@ -11,77 +11,78 @@ import ( ) type ExecutionCollector struct { - tracer module.Tracer - totalExecutedBlocksCounter prometheus.Counter - totalExecutedCollectionsCounter prometheus.Counter - totalExecutedTransactionsCounter prometheus.Counter - totalExecutedScriptsCounter prometheus.Counter - totalFailedTransactionsCounter prometheus.Counter - lastExecutedBlockHeightGauge prometheus.Gauge - stateStorageDiskTotal prometheus.Gauge - storageStateCommitment prometheus.Gauge - forestApproxMemorySize prometheus.Gauge - forestNumberOfTrees prometheus.Gauge - latestTrieRegCount prometheus.Gauge - latestTrieRegCountDiff prometheus.Gauge - latestTrieRegSize prometheus.Gauge - latestTrieRegSizeDiff prometheus.Gauge - latestTrieMaxDepthTouched prometheus.Gauge - updated prometheus.Counter - proofSize prometheus.Gauge - updatedValuesNumber prometheus.Counter - updatedValuesSize prometheus.Gauge - updatedDuration prometheus.Histogram - updatedDurationPerValue prometheus.Histogram - readValuesNumber prometheus.Counter - readValuesSize prometheus.Gauge - readDuration prometheus.Histogram - readDurationPerValue prometheus.Histogram - blockComputationUsed prometheus.Histogram - blockComputationVector *prometheus.GaugeVec - blockCachedPrograms prometheus.Gauge - blockMemoryUsed prometheus.Histogram - blockEventCounts prometheus.Histogram - blockEventSize prometheus.Histogram - blockExecutionTime prometheus.Histogram - blockTransactionCounts prometheus.Histogram - blockCollectionCounts prometheus.Histogram - collectionComputationUsed prometheus.Histogram - collectionMemoryUsed prometheus.Histogram - collectionEventSize prometheus.Histogram - collectionEventCounts prometheus.Histogram - collectionNumberOfRegistersTouched prometheus.Histogram - collectionTotalBytesWrittenToRegisters prometheus.Histogram - collectionExecutionTime prometheus.Histogram - collectionTransactionCounts prometheus.Histogram - collectionRequestSent prometheus.Counter - collectionRequestRetried prometheus.Counter - transactionParseTime prometheus.Histogram - transactionCheckTime prometheus.Histogram - transactionInterpretTime prometheus.Histogram - transactionExecutionTime prometheus.Histogram - transactionConflictRetries prometheus.Histogram - transactionMemoryEstimate prometheus.Histogram - transactionComputationUsed prometheus.Histogram - transactionEmittedEvents prometheus.Histogram - transactionEventSize prometheus.Histogram - scriptExecutionTime prometheus.Histogram - scriptComputationUsed prometheus.Histogram - scriptMemoryUsage prometheus.Histogram - scriptMemoryEstimate prometheus.Histogram - scriptMemoryDifference prometheus.Histogram - numberOfAccounts prometheus.Gauge - programsCacheMiss prometheus.Counter - programsCacheHit prometheus.Counter - chunkDataPackRequestProcessedTotal prometheus.Counter - chunkDataPackProofSize prometheus.Histogram - chunkDataPackCollectionSize prometheus.Histogram - stateSyncActive prometheus.Gauge - blockDataUploadsInProgress prometheus.Gauge - blockDataUploadsDuration prometheus.Histogram - maxCollectionHeight prometheus.Gauge - computationResultUploadedCount prometheus.Counter - computationResultUploadRetriedCount prometheus.Counter + tracer module.Tracer + 
totalExecutedBlocksCounter prometheus.Counter + totalExecutedCollectionsCounter prometheus.Counter + totalExecutedTransactionsCounter prometheus.Counter + totalExecutedScriptsCounter prometheus.Counter + totalFailedTransactionsCounter prometheus.Counter + lastExecutedBlockHeightGauge prometheus.Gauge + stateStorageDiskTotal prometheus.Gauge + storageStateCommitment prometheus.Gauge + forestApproxMemorySize prometheus.Gauge + forestNumberOfTrees prometheus.Gauge + latestTrieRegCount prometheus.Gauge + latestTrieRegCountDiff prometheus.Gauge + latestTrieRegSize prometheus.Gauge + latestTrieRegSizeDiff prometheus.Gauge + latestTrieMaxDepthTouched prometheus.Gauge + updated prometheus.Counter + proofSize prometheus.Gauge + updatedValuesNumber prometheus.Counter + updatedValuesSize prometheus.Gauge + updatedDuration prometheus.Histogram + updatedDurationPerValue prometheus.Histogram + readValuesNumber prometheus.Counter + readValuesSize prometheus.Gauge + readDuration prometheus.Histogram + readDurationPerValue prometheus.Histogram + blockComputationUsed prometheus.Histogram + blockComputationVector *prometheus.GaugeVec + blockCachedPrograms prometheus.Gauge + blockMemoryUsed prometheus.Histogram + blockEventCounts prometheus.Histogram + blockEventSize prometheus.Histogram + blockExecutionTime prometheus.Histogram + blockTransactionCounts prometheus.Histogram + blockCollectionCounts prometheus.Histogram + collectionComputationUsed prometheus.Histogram + collectionMemoryUsed prometheus.Histogram + collectionEventSize prometheus.Histogram + collectionEventCounts prometheus.Histogram + collectionNumberOfRegistersTouched prometheus.Histogram + collectionTotalBytesWrittenToRegisters prometheus.Histogram + collectionExecutionTime prometheus.Histogram + collectionTransactionCounts prometheus.Histogram + collectionRequestSent prometheus.Counter + collectionRequestRetried prometheus.Counter + transactionParseTime prometheus.Histogram + transactionCheckTime prometheus.Histogram + transactionInterpretTime prometheus.Histogram + transactionExecutionTime prometheus.Histogram + transactionConflictRetries prometheus.Histogram + transactionMemoryEstimate prometheus.Histogram + transactionComputationUsed prometheus.Histogram + transactionNormalizedTimePerComputation prometheus.Histogram + transactionEmittedEvents prometheus.Histogram + transactionEventSize prometheus.Histogram + scriptExecutionTime prometheus.Histogram + scriptComputationUsed prometheus.Histogram + scriptMemoryUsage prometheus.Histogram + scriptMemoryEstimate prometheus.Histogram + scriptMemoryDifference prometheus.Histogram + numberOfAccounts prometheus.Gauge + programsCacheMiss prometheus.Counter + programsCacheHit prometheus.Counter + chunkDataPackRequestProcessedTotal prometheus.Counter + chunkDataPackProofSize prometheus.Histogram + chunkDataPackCollectionSize prometheus.Histogram + stateSyncActive prometheus.Gauge + blockDataUploadsInProgress prometheus.Gauge + blockDataUploadsDuration prometheus.Histogram + maxCollectionHeight prometheus.Gauge + computationResultUploadedCount prometheus.Counter + computationResultUploadRetriedCount prometheus.Counter } func NewExecutionCollector(tracer module.Tracer) *ExecutionCollector { @@ -405,6 +406,14 @@ func NewExecutionCollector(tracer module.Tracer) *ExecutionCollector { Buckets: []float64{50, 100, 500, 1000, 5000, 10000}, }) + transactionNormalizedTimePerComputation := promauto.NewHistogram(prometheus.HistogramOpts{ + Namespace: namespaceExecution, + Subsystem: subsystemRuntime, + Name: 
"transaction_ms_per_computation", + Help: "The normalized ratio of millisecond of execution time per computation used. Value below 1 means the transaction was executed faster than estimated (is using less resources then estimated)", + Buckets: []float64{0.015625, 0.03125, 0.0625, 0.125, 0.25, 0.5, 1, 2, 4, 8, 16, 32, 64}, + }) + transactionMemoryEstimate := promauto.NewHistogram(prometheus.HistogramOpts{ Namespace: namespaceExecution, Subsystem: subsystemRuntime, @@ -524,63 +533,64 @@ func NewExecutionCollector(tracer module.Tracer) *ExecutionCollector { ec := &ExecutionCollector{ tracer: tracer, - forestApproxMemorySize: forestApproxMemorySize, - forestNumberOfTrees: forestNumberOfTrees, - latestTrieRegCount: latestTrieRegCount, - latestTrieRegCountDiff: latestTrieRegCountDiff, - latestTrieRegSize: latestTrieRegSize, - latestTrieRegSizeDiff: latestTrieRegSizeDiff, - latestTrieMaxDepthTouched: latestTrieMaxDepthTouched, - updated: updatedCount, - proofSize: proofSize, - updatedValuesNumber: updatedValuesNumber, - updatedValuesSize: updatedValuesSize, - updatedDuration: updatedDuration, - updatedDurationPerValue: updatedDurationPerValue, - readValuesNumber: readValuesNumber, - readValuesSize: readValuesSize, - readDuration: readDuration, - readDurationPerValue: readDurationPerValue, - blockExecutionTime: blockExecutionTime, - blockComputationUsed: blockComputationUsed, - blockComputationVector: blockComputationVector, - blockCachedPrograms: blockCachedPrograms, - blockMemoryUsed: blockMemoryUsed, - blockEventCounts: blockEventCounts, - blockEventSize: blockEventSize, - blockTransactionCounts: blockTransactionCounts, - blockCollectionCounts: blockCollectionCounts, - collectionExecutionTime: collectionExecutionTime, - collectionComputationUsed: collectionComputationUsed, - collectionMemoryUsed: collectionMemoryUsed, - collectionEventSize: collectionEventSize, - collectionEventCounts: collectionEventCounts, - collectionNumberOfRegistersTouched: collectionNumberOfRegistersTouched, - collectionTotalBytesWrittenToRegisters: collectionTotalBytesWrittenToRegisters, - collectionTransactionCounts: collectionTransactionCounts, - collectionRequestSent: collectionRequestsSent, - collectionRequestRetried: collectionRequestsRetries, - transactionParseTime: transactionParseTime, - transactionCheckTime: transactionCheckTime, - transactionInterpretTime: transactionInterpretTime, - transactionExecutionTime: transactionExecutionTime, - transactionConflictRetries: transactionConflictRetries, - transactionComputationUsed: transactionComputationUsed, - transactionMemoryEstimate: transactionMemoryEstimate, - transactionEmittedEvents: transactionEmittedEvents, - transactionEventSize: transactionEventSize, - scriptExecutionTime: scriptExecutionTime, - scriptComputationUsed: scriptComputationUsed, - scriptMemoryUsage: scriptMemoryUsage, - scriptMemoryEstimate: scriptMemoryEstimate, - scriptMemoryDifference: scriptMemoryDifference, - chunkDataPackRequestProcessedTotal: chunkDataPackRequestProcessedTotal, - chunkDataPackProofSize: chunkDataPackProofSize, - chunkDataPackCollectionSize: chunkDataPackCollectionSize, - blockDataUploadsInProgress: blockDataUploadsInProgress, - blockDataUploadsDuration: blockDataUploadsDuration, - computationResultUploadedCount: computationResultUploadedCount, - computationResultUploadRetriedCount: computationResultUploadRetriedCount, + forestApproxMemorySize: forestApproxMemorySize, + forestNumberOfTrees: forestNumberOfTrees, + latestTrieRegCount: latestTrieRegCount, + 
latestTrieRegCountDiff: latestTrieRegCountDiff, + latestTrieRegSize: latestTrieRegSize, + latestTrieRegSizeDiff: latestTrieRegSizeDiff, + latestTrieMaxDepthTouched: latestTrieMaxDepthTouched, + updated: updatedCount, + proofSize: proofSize, + updatedValuesNumber: updatedValuesNumber, + updatedValuesSize: updatedValuesSize, + updatedDuration: updatedDuration, + updatedDurationPerValue: updatedDurationPerValue, + readValuesNumber: readValuesNumber, + readValuesSize: readValuesSize, + readDuration: readDuration, + readDurationPerValue: readDurationPerValue, + blockExecutionTime: blockExecutionTime, + blockComputationUsed: blockComputationUsed, + blockComputationVector: blockComputationVector, + blockCachedPrograms: blockCachedPrograms, + blockMemoryUsed: blockMemoryUsed, + blockEventCounts: blockEventCounts, + blockEventSize: blockEventSize, + blockTransactionCounts: blockTransactionCounts, + blockCollectionCounts: blockCollectionCounts, + collectionExecutionTime: collectionExecutionTime, + collectionComputationUsed: collectionComputationUsed, + collectionMemoryUsed: collectionMemoryUsed, + collectionEventSize: collectionEventSize, + collectionEventCounts: collectionEventCounts, + collectionNumberOfRegistersTouched: collectionNumberOfRegistersTouched, + collectionTotalBytesWrittenToRegisters: collectionTotalBytesWrittenToRegisters, + collectionTransactionCounts: collectionTransactionCounts, + collectionRequestSent: collectionRequestsSent, + collectionRequestRetried: collectionRequestsRetries, + transactionParseTime: transactionParseTime, + transactionCheckTime: transactionCheckTime, + transactionInterpretTime: transactionInterpretTime, + transactionExecutionTime: transactionExecutionTime, + transactionConflictRetries: transactionConflictRetries, + transactionComputationUsed: transactionComputationUsed, + transactionNormalizedTimePerComputation: transactionNormalizedTimePerComputation, + transactionMemoryEstimate: transactionMemoryEstimate, + transactionEmittedEvents: transactionEmittedEvents, + transactionEventSize: transactionEventSize, + scriptExecutionTime: scriptExecutionTime, + scriptComputationUsed: scriptComputationUsed, + scriptMemoryUsage: scriptMemoryUsage, + scriptMemoryEstimate: scriptMemoryEstimate, + scriptMemoryDifference: scriptMemoryDifference, + chunkDataPackRequestProcessedTotal: chunkDataPackRequestProcessedTotal, + chunkDataPackProofSize: chunkDataPackProofSize, + chunkDataPackCollectionSize: chunkDataPackCollectionSize, + blockDataUploadsInProgress: blockDataUploadsInProgress, + blockDataUploadsDuration: blockDataUploadsDuration, + computationResultUploadedCount: computationResultUploadedCount, + computationResultUploadRetriedCount: computationResultUploadRetriedCount, totalExecutedBlocksCounter: promauto.NewCounter(prometheus.CounterOpts{ Namespace: namespaceExecution, Subsystem: subsystemRuntime, @@ -739,6 +749,11 @@ func (ec *ExecutionCollector) ExecutionTransactionExecuted( ec.transactionExecutionTime.Observe(float64(dur.Milliseconds())) ec.transactionConflictRetries.Observe(float64(numConflictRetries)) ec.transactionComputationUsed.Observe(float64(compUsed)) + if compUsed > 0 { + // normalize so the value should be around 1 + ec.transactionNormalizedTimePerComputation.Observe( + (float64(dur.Milliseconds()) / float64(compUsed)) * flow.EstimatedComputationPerMillisecond) + } ec.transactionMemoryEstimate.Observe(float64(memoryUsed)) ec.transactionEmittedEvents.Observe(float64(eventCounts)) ec.transactionEventSize.Observe(float64(eventSize)) diff --git 
a/module/metrics/herocache.go b/module/metrics/herocache.go index 586f6bbda75..59ddb0f2f36 100644 --- a/module/metrics/herocache.go +++ b/module/metrics/herocache.go @@ -72,6 +72,10 @@ func NetworkReceiveCacheMetricsFactory(f HeroCacheMetricsFactory, networkType ne return f(namespaceNetwork, r) } +func NewSubscriptionRecordCacheMetricsFactory(f HeroCacheMetricsFactory) module.HeroCacheMetrics { + return f(namespaceNetwork, ResourceNetworkingSubscriptionRecordsCache) +} + // DisallowListCacheMetricsFactory is the factory method for creating a new HeroCacheCollector for the disallow list cache. // The disallow-list cache is used to keep track of peers that are disallow-listed and the reasons for it. // Args: diff --git a/module/metrics/labels.go b/module/metrics/labels.go index 197fffb1a21..e58610bec35 100644 --- a/module/metrics/labels.go +++ b/module/metrics/labels.go @@ -85,6 +85,7 @@ const ( ResourceEpochCommit = "epoch_commit" ResourceEpochStatus = "epoch_status" ResourceNetworkingReceiveCache = "networking_received_message" // networking layer + ResourceNetworkingSubscriptionRecordsCache = "subscription_records_cache" // networking layer ResourceNetworkingDnsIpCache = "networking_dns_ip_cache" // networking layer ResourceNetworkingDnsTxtCache = "networking_dns_txt_cache" // networking layer ResourceNetworkingDisallowListNotificationQueue = "networking_disallow_list_notification_queue" diff --git a/module/metrics/unicast_manager.go b/module/metrics/unicast_manager.go index e621c44f460..4f1ef04ec52 100644 --- a/module/metrics/unicast_manager.go +++ b/module/metrics/unicast_manager.go @@ -120,6 +120,22 @@ func NewUnicastManagerMetrics(prefix string) *UnicastManagerMetrics { }, ) + uc.streamCreationRetryBudgetResetToDefault = promauto.NewCounter( + prometheus.CounterOpts{ + Namespace: namespaceNetwork, + Subsystem: subsystemGossip, + Name: uc.prefix + "stream_creation_retry_budget_reset_to_default_total", + Help: "the number of times the stream creation retry budget is reset to default by the unicast manager", + }) + + uc.dialRetryBudgetResetToDefault = promauto.NewCounter( + prometheus.CounterOpts{ + Namespace: namespaceNetwork, + Subsystem: subsystemGossip, + Name: uc.prefix + "dial_retry_budget_reset_to_default_total", + Help: "the number of times the dial retry budget is reset to default by the unicast manager", + }) + return uc } diff --git a/module/trace/constants.go b/module/trace/constants.go index 5cda4f10d33..2d333bdb5fc 100644 --- a/module/trace/constants.go +++ b/module/trace/constants.go @@ -167,6 +167,7 @@ const ( FVMEnvGetOrLoadProgram SpanName = "fvm.env.getOrLoadCachedProgram" FVMEnvProgramLog SpanName = "fvm.env.programLog" FVMEnvEmitEvent SpanName = "fvm.env.emitEvent" + FVMEnvEncodeEvent SpanName = "fvm.env.encodeEvent" FVMEnvGenerateUUID SpanName = "fvm.env.generateUUID" FVMEnvGenerateAccountLocalID SpanName = "fvm.env.generateAccountLocalID" FVMEnvDecodeArgument SpanName = "fvm.env.decodeArgument" diff --git a/network/alsp/internal/cache.go b/network/alsp/internal/cache.go index c29ae4bd988..6bc6f361593 100644 --- a/network/alsp/internal/cache.go +++ b/network/alsp/internal/cache.go @@ -81,7 +81,6 @@ func (s *SpamRecordCache) Adjust(originId flow.Identifier, adjustFunc model.Reco penalty, err := s.adjust(originId, adjustFunc) switch { - case err == ErrSpamRecordNotFound: // if the record does not exist, we initialize the record and try to adjust it again. // Note: there is an edge case where the record is initialized by another goroutine between the two calls. 
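As background for the SpamRecordCache.Adjust change above: the comment describes an optimistic adjust-then-initialize flow, where a failed adjust on a missing record triggers initialization and a single retry. Below is a minimal, self-contained sketch of that pattern; the recordCache type, field names, and helpers are illustrative stand-ins, not the actual flow-go implementation.

package sketch

import (
	"errors"
	"sync"
)

var ErrRecordNotFound = errors.New("record not found")

type record struct{ Penalty float64 }

type recordCache struct {
	mu      sync.Mutex
	records map[string]*record
}

func newRecordCache() *recordCache {
	return &recordCache{records: make(map[string]*record)}
}

// Adjust applies adjustFunc to the record for id, initializing the record first
// if it does not exist yet, and then retrying the adjustment once.
func (c *recordCache) Adjust(id string, adjustFunc func(*record)) (float64, error) {
	penalty, err := c.adjust(id, adjustFunc)
	if errors.Is(err, ErrRecordNotFound) {
		c.init(id)
		return c.adjust(id, adjustFunc)
	}
	return penalty, err
}

// adjust applies adjustFunc under the lock, or reports that the record is missing.
func (c *recordCache) adjust(id string, adjustFunc func(*record)) (float64, error) {
	c.mu.Lock()
	defer c.mu.Unlock()
	r, ok := c.records[id]
	if !ok {
		return 0, ErrRecordNotFound
	}
	adjustFunc(r)
	return r.Penalty, nil
}

// init creates the record if it is still absent; it is a no-op otherwise.
func (c *recordCache) init(id string) {
	c.mu.Lock()
	defer c.mu.Unlock()
	if _, ok := c.records[id]; !ok {
		c.records[id] = &record{}
	}
}

The lock is not held across the failed adjust, the init, and the retry, which is exactly the edge case the comment above calls out: another goroutine may initialize the record in between, and the idempotent init plus the retry simply operate on whichever record ends up in the cache.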
diff --git a/network/internal/p2pfixtures/fixtures.go b/network/internal/p2pfixtures/fixtures.go index 3873d2ec668..551870f2e4b 100644 --- a/network/internal/p2pfixtures/fixtures.go +++ b/network/internal/p2pfixtures/fixtures.go @@ -13,7 +13,6 @@ import ( pubsub "github.com/libp2p/go-libp2p-pubsub" "github.com/libp2p/go-libp2p/core/host" "github.com/libp2p/go-libp2p/core/network" - "github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/core/peerstore" "github.com/libp2p/go-libp2p/core/routing" "github.com/multiformats/go-multiaddr" @@ -32,7 +31,6 @@ import ( "github.com/onflow/flow-go/network/message" "github.com/onflow/flow-go/network/p2p" p2pdht "github.com/onflow/flow-go/network/p2p/dht" - "github.com/onflow/flow-go/network/p2p/keyutils" "github.com/onflow/flow-go/network/p2p/p2pbuilder" p2pconfig "github.com/onflow/flow-go/network/p2p/p2pbuilder/config" "github.com/onflow/flow-go/network/p2p/tracer" @@ -128,6 +126,7 @@ func CreateNode(t *testing.T, networkKey crypto.PrivateKey, sporkID flow.Identif &defaultFlowConfig.NetworkConfig.ResourceManager, &defaultFlowConfig.NetworkConfig.GossipSubRPCInspectorsConfig, p2pconfig.PeerManagerDisableConfig(), + &defaultFlowConfig.NetworkConfig.GossipSubConfig.SubscriptionProviderConfig, &p2p.DisallowListCacheConfig{ MaxSize: uint32(1000), Metrics: metrics.NewNoopCollector(), @@ -153,36 +152,6 @@ func CreateNode(t *testing.T, networkKey crypto.PrivateKey, sporkID flow.Identif return libp2pNode } -// PeerIdFixture creates a random and unique peer ID (libp2p node ID). -func PeerIdFixture(t *testing.T) peer.ID { - key, err := generateNetworkingKey(unittest.IdentifierFixture()) - require.NoError(t, err) - - pubKey, err := keyutils.LibP2PPublicKeyFromFlow(key.PublicKey()) - require.NoError(t, err) - - peerID, err := peer.IDFromPublicKey(pubKey) - require.NoError(t, err) - - return peerID -} - -// generateNetworkingKey generates a Flow ECDSA key using the given seed -func generateNetworkingKey(s flow.Identifier) (crypto.PrivateKey, error) { - seed := make([]byte, crypto.KeyGenSeedMinLen) - copy(seed, s[:]) - return crypto.GeneratePrivateKey(crypto.ECDSASecp256k1, seed) -} - -// PeerIdsFixture creates random and unique peer IDs (libp2p node IDs). -func PeerIdsFixture(t *testing.T, n int) []peer.ID { - peerIDs := make([]peer.ID, n) - for i := 0; i < n; i++ { - peerIDs[i] = PeerIdFixture(t) - } - return peerIDs -} - // SubMustNeverReceiveAnyMessage checks that the subscription never receives any message within the given timeout by the context. func SubMustNeverReceiveAnyMessage(t *testing.T, ctx context.Context, sub p2p.Subscription) { timeouted := make(chan struct{}) diff --git a/network/netconf/config.go b/network/netconf/config.go index 2eeccdce256..b9df868d281 100644 --- a/network/netconf/config.go +++ b/network/netconf/config.go @@ -33,19 +33,10 @@ type Config struct { type UnicastConfig struct { // UnicastRateLimitersConfig configuration for all unicast rate limiters. UnicastRateLimitersConfig `mapstructure:",squash"` + // CreateStreamBackoffDelay initial delay used in the exponential backoff for create stream retries. CreateStreamBackoffDelay time.Duration `validate:"gt=0s" mapstructure:"unicast-create-stream-retry-delay"` - // DialInProgressBackoffDelay is the backoff delay for parallel attempts on dialing to the same peer. 
- // When the unicast manager is invoked to create stream to the same peer concurrently while there is - // already an ongoing dialing attempt to the same peer, the unicast manager will wait for this backoff delay - // and retry creating the stream after the backoff delay has elapsed. This is to prevent the unicast manager - // from creating too many parallel dialing attempts to the same peer. - DialInProgressBackoffDelay time.Duration `validate:"gt=0s" mapstructure:"unicast-dial-in-progress-backoff-delay"` - - // DialBackoffDelay is the backoff delay between retrying connection to the same peer. - DialBackoffDelay time.Duration `validate:"gt=0s" mapstructure:"unicast-dial-backoff-delay"` - // StreamZeroRetryResetThreshold is the threshold that determines when to reset the stream creation retry budget to the default value. // // For example the default value of 100 means that if the stream creation retry budget is decreased to 0, then it will be reset to default value @@ -58,25 +49,11 @@ type UnicastConfig struct { // 100 stream creations are all successful. StreamZeroRetryResetThreshold uint64 `validate:"gt=0" mapstructure:"unicast-stream-zero-retry-reset-threshold"` - // DialZeroRetryResetThreshold is the threshold that determines when to reset the dial retry budget to the default value. - // For example the threshold of 1 hour means that if the dial retry budget is decreased to 0, then it will be reset to default value - // when it has been 1 hour since the last successful dial. - // - // This is to prevent the retry budget from being reset too frequently, as the retry budget is used to gauge the reliability of the dialing a remote peer. - // When the dial retry budget is reset to the default value, it means that the dialing is reliable enough to be trusted again. - // This parameter mandates when the dialing is reliable enough to be trusted again; i.e., when it has been 1 hour since the last successful dial. - // Note that the last dial attempt timestamp is reset to zero when the dial fails, so the value of for example 1 hour means that the dialing to the remote peer is reliable enough that the last - // successful dial attempt was 1 hour ago. - DialZeroRetryResetThreshold time.Duration `validate:"gt=0s" mapstructure:"unicast-dial-zero-retry-reset-threshold"` - - // MaxDialRetryAttemptTimes is the maximum number of attempts to be made to connect to a remote node to establish a unicast (1:1) connection before we give up. - MaxDialRetryAttemptTimes uint64 `validate:"gt=0" mapstructure:"unicast-max-dial-retry-attempt-times"` - // MaxStreamCreationRetryAttemptTimes is the maximum number of attempts to be made to create a stream to a remote node over a direct unicast (1:1) connection before we give up. MaxStreamCreationRetryAttemptTimes uint64 `validate:"gt=1" mapstructure:"unicast-max-stream-creation-retry-attempt-times"` - // DialConfigCacheSize is the cache size of the dial config cache that keeps the individual dial config for each peer. - DialConfigCacheSize uint32 `validate:"gt=0" mapstructure:"unicast-dial-config-cache-size"` + // ConfigCacheSize is the cache size of the dial config cache that keeps the individual dial config for each peer. + ConfigCacheSize uint32 `validate:"gt=0" mapstructure:"unicast-dial-config-cache-size"` } // UnicastRateLimitersConfig unicast rate limiter configuration for the message and bandwidth rate limiters. 
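To make the StreamZeroRetryResetThreshold semantics above concrete, here is a minimal sketch of how such a zero-budget reset is typically applied per peer. The dialConfig fields and the default budget value are assumptions for illustration only, not the unicast manager's actual types.

package sketch

const (
	defaultStreamCreationRetryBudget     = 3   // assumed default, for illustration only
	defaultStreamZeroRetryResetThreshold = 100 // matches the documented example above
)

type dialConfig struct {
	StreamCreationRetryAttemptBudget uint64
	ConsecutiveSuccessfulStreams     uint64
}

// recordStreamCreationOutcome updates the per-peer dial config after a stream
// creation attempt. Once the retry budget has been exhausted (reached 0), it is
// only restored to the default after the configured number of consecutive
// successful stream creations, i.e. once the peer has proven reliable again.
func recordStreamCreationOutcome(cfg *dialConfig, success bool) {
	if !success {
		cfg.ConsecutiveSuccessfulStreams = 0
		if cfg.StreamCreationRetryAttemptBudget > 0 {
			cfg.StreamCreationRetryAttemptBudget--
		}
		return
	}

	cfg.ConsecutiveSuccessfulStreams++
	if cfg.StreamCreationRetryAttemptBudget == 0 &&
		cfg.ConsecutiveSuccessfulStreams >= defaultStreamZeroRetryResetThreshold {
		cfg.StreamCreationRetryAttemptBudget = defaultStreamCreationRetryBudget
		cfg.ConsecutiveSuccessfulStreams = 0
	}
}

The budget is only restored once the peer has produced the configured run of consecutive successful stream creations, so a single lucky success does not immediately re-enable retries.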
diff --git a/network/netconf/flags.go b/network/netconf/flags.go index 3a90c0c1973..f40d755fbfc 100644 --- a/network/netconf/flags.go +++ b/network/netconf/flags.go @@ -20,11 +20,7 @@ const ( unicastMessageTimeout = "unicast-message-timeout" unicastCreateStreamRetryDelay = "unicast-create-stream-retry-delay" unicastStreamZeroRetryResetThreshold = "unicast-stream-zero-retry-reset-threshold" - unicastDialZeroRetryResetThreshold = "unicast-dial-zero-retry-reset-threshold" - unicastMaxDialRetryAttemptTimes = "unicast-max-dial-retry-attempt-times" unicastMaxStreamCreationRetryAttemptTimes = "unicast-max-stream-creation-retry-attempt-times" - unicastDialInProgressBackoffDelay = "unicast-dial-in-progress-backoff-delay" - unicastDialBackoffDelay = "unicast-dial-backoff-delay" unicastDialConfigCacheSize = "unicast-dial-config-cache-size" dnsCacheTTL = "dns-cache-ttl" disallowListNotificationCacheSize = "disallow-list-notification-cache-size" @@ -64,6 +60,9 @@ const ( rpcSentTrackerNumOfWorkers = "gossipsub-rpc-sent-tracker-workers" scoreTracerInterval = "gossipsub-score-tracer-interval" + gossipSubSubscriptionProviderUpdateInterval = "gossipsub-subscription-provider-update-interval" + gossipSubSubscriptionProviderCacheSize = "gossipsub-subscription-provider-cache-size" + // gossipsub validation inspector gossipSubRPCInspectorNotificationCacheSize = "gossipsub-rpc-inspector-notification-cache-size" validationInspectorNumberOfWorkers = "gossipsub-rpc-validation-inspector-workers" @@ -104,11 +103,7 @@ func AllFlagNames() []string { peerUpdateInterval, unicastMessageTimeout, unicastCreateStreamRetryDelay, - unicastDialInProgressBackoffDelay, - unicastDialBackoffDelay, unicastStreamZeroRetryResetThreshold, - unicastDialZeroRetryResetThreshold, - unicastMaxDialRetryAttemptTimes, unicastMaxStreamCreationRetryAttemptTimes, unicastDialConfigCacheSize, dnsCacheTTL, @@ -179,9 +174,13 @@ func AllFlagNames() []string { func InitializeNetworkFlags(flags *pflag.FlagSet, config *Config) { flags.Bool(networkingConnectionPruning, config.NetworkConnectionPruning, "enabling connection trimming") flags.Duration(dnsCacheTTL, config.DNSCacheTTL, "time-to-live for dns cache") - flags.StringSlice(preferredUnicastsProtocols, config.PreferredUnicastProtocols, "preferred unicast protocols in ascending order of preference") + flags.StringSlice( + preferredUnicastsProtocols, config.PreferredUnicastProtocols, "preferred unicast protocols in ascending order of preference") flags.Uint32(receivedMessageCacheSize, config.NetworkReceivedMessageCacheSize, "incoming message cache size at networking layer") - flags.Uint32(disallowListNotificationCacheSize, config.DisallowListNotificationCacheSize, "cache size for notification events from disallow list") + flags.Uint32( + disallowListNotificationCacheSize, + config.DisallowListNotificationCacheSize, + "cache size for notification events from disallow list") flags.Duration(peerUpdateInterval, config.PeerUpdateInterval, "how often to refresh the peer connections for the node") flags.Duration(unicastMessageTimeout, config.UnicastMessageTimeout, "how long a unicast transmission can take to complete") // unicast manager options @@ -189,24 +188,12 @@ func InitializeNetworkFlags(flags *pflag.FlagSet, config *Config) { config.UnicastConfig.CreateStreamBackoffDelay, "initial backoff delay between failing to establish a connection with another node and retrying, "+ "this delay increases exponentially with the number of subsequent failures to establish a connection.") - 
flags.Duration(unicastDialBackoffDelay, - config.UnicastConfig.DialInProgressBackoffDelay, - "initial backoff delay between failing to establish a connection with another node and retrying, "+ - "this delay increases exponentially with the number of subsequent failures to establish a connection.") - flags.Duration(unicastDialInProgressBackoffDelay, - config.UnicastConfig.DialInProgressBackoffDelay, - "initial backoff delay for concurrent stream creations to a remote peer when there is no exising connection and a dial is in progress. "+ - "this delay increases exponentially with the number of subsequent failure attempts") flags.Uint64(unicastStreamZeroRetryResetThreshold, config.UnicastConfig.StreamZeroRetryResetThreshold, "reset stream creation retry budget from zero to the maximum after consecutive successful streams reach this threshold.") - flags.Duration(unicastDialZeroRetryResetThreshold, - config.UnicastConfig.DialZeroRetryResetThreshold, - "reset dial retry budget if the last successful dial is longer than this threshold.") - flags.Uint64(unicastMaxDialRetryAttemptTimes, config.UnicastConfig.MaxDialRetryAttemptTimes, "maximum attempts to establish a unicast connection.") flags.Uint64(unicastMaxStreamCreationRetryAttemptTimes, config.UnicastConfig.MaxStreamCreationRetryAttemptTimes, "max attempts to create a unicast stream.") flags.Uint32(unicastDialConfigCacheSize, - config.UnicastConfig.DialConfigCacheSize, + config.UnicastConfig.ConfigCacheSize, "cache size of the dial config cache, recommended to be big enough to accommodate the entire nodes in the network.") // unicast stream handler rate limits @@ -229,10 +216,22 @@ func InitializeNetworkFlags(flags *pflag.FlagSet, config *Config) { flags.Duration(silencePeriod, config.ConnectionManagerConfig.SilencePeriod, "silence period for libp2p connection manager") flags.Bool(peerScoring, config.GossipSubConfig.PeerScoring, "enabling peer scoring on pubsub network") flags.Duration(localMeshLogInterval, config.GossipSubConfig.LocalMeshLogInterval, "logging interval for local mesh in gossipsub") - flags.Duration(scoreTracerInterval, config.GossipSubConfig.ScoreTracerInterval, "logging interval for peer score tracer in gossipsub, set to 0 to disable") - flags.Uint32(rpcSentTrackerCacheSize, config.GossipSubConfig.RPCSentTrackerCacheSize, "cache size of the rpc sent tracker used by the gossipsub mesh tracer.") - flags.Uint32(rpcSentTrackerQueueCacheSize, config.GossipSubConfig.RPCSentTrackerQueueCacheSize, "cache size of the rpc sent tracker worker queue.") - flags.Int(rpcSentTrackerNumOfWorkers, config.GossipSubConfig.RpcSentTrackerNumOfWorkers, "number of workers for the rpc sent tracker worker pool.") + flags.Duration( + scoreTracerInterval, + config.GossipSubConfig.ScoreTracerInterval, + "logging interval for peer score tracer in gossipsub, set to 0 to disable") + flags.Uint32( + rpcSentTrackerCacheSize, + config.GossipSubConfig.RPCSentTrackerCacheSize, + "cache size of the rpc sent tracker used by the gossipsub mesh tracer.") + flags.Uint32( + rpcSentTrackerQueueCacheSize, + config.GossipSubConfig.RPCSentTrackerQueueCacheSize, + "cache size of the rpc sent tracker worker queue.") + flags.Int( + rpcSentTrackerNumOfWorkers, + config.GossipSubConfig.RpcSentTrackerNumOfWorkers, + "number of workers for the rpc sent tracker worker pool.") // gossipsub RPC control message validation limits used for validation configuration and rate limiting flags.Int(validationInspectorNumberOfWorkers, 
config.GossipSubConfig.GossipSubRPCInspectorsConfig.GossipSubRPCValidationInspectorConfigs.NumberOfWorkers, @@ -300,12 +299,15 @@ func InitializeNetworkFlags(flags *pflag.FlagSet, config *Config) { config.GossipSubConfig.GossipSubRPCInspectorsConfig.GossipSubRPCValidationInspectorConfigs.IWantRPCInspectionConfig.DuplicateMsgIDThreshold, "max allowed duplicate message IDs in a single iWant control message") - flags.Int(rpcMessageMaxSampleSize, - config.GossipSubConfig.GossipSubRPCInspectorsConfig.GossipSubRPCValidationInspectorConfigs.RpcMessageMaxSampleSize, - "the max sample size used for RPC message validation. If the total number of RPC messages exceeds this value a sample will be taken but messages will not be truncated") - flags.Int(rpcMessageErrorThreshold, - config.GossipSubConfig.GossipSubRPCInspectorsConfig.GossipSubRPCValidationInspectorConfigs.RpcMessageErrorThreshold, - "the threshold at which an error will be returned if the number of invalid RPC messages exceeds this value") + flags.Int(rpcMessageMaxSampleSize, config.GossipSubConfig.GossipSubRPCInspectorsConfig.GossipSubRPCValidationInspectorConfigs.RpcMessageMaxSampleSize, "the max sample size used for RPC message validation. If the total number of RPC messages exceeds this value a sample will be taken but messages will not be truncated") + flags.Int(rpcMessageErrorThreshold, config.GossipSubConfig.GossipSubRPCInspectorsConfig.GossipSubRPCValidationInspectorConfigs.RpcMessageErrorThreshold, "the threshold at which an error will be returned if the number of invalid RPC messages exceeds this value") + flags.Duration( + gossipSubSubscriptionProviderUpdateInterval, config.GossipSubConfig.SubscriptionProviderConfig.SubscriptionUpdateInterval, + "interval for updating the list of subscribed topics for all peers in the gossipsub, recommended value is a few minutes") + flags.Uint32( + gossipSubSubscriptionProviderCacheSize, + config.GossipSubConfig.SubscriptionProviderConfig.CacheSize, + "size of the cache that keeps the list of topics each peer has subscribed to, recommended size is 10x the number of authorized nodes") } // LoadLibP2PResourceManagerFlags loads all CLI flags for the libp2p resource manager configuration on the provided pflag set. @@ -379,7 +381,8 @@ func SetAliases(conf *viper.Viper) error { for _, flagName := range AllFlagNames() { fullKey, ok := m[flagName] if !ok { - return fmt.Errorf("invalid network configuration missing configuration key flag name %s check config file and cli flags", flagName) + return fmt.Errorf( + "invalid network configuration missing configuration key flag name %s check config file and cli flags", flagName) } conf.RegisterAlias(fullKey, flagName) } diff --git a/network/p2p/builder.go b/network/p2p/builder.go index d38457674e9..31a7da024f5 100644 --- a/network/p2p/builder.go +++ b/network/p2p/builder.go @@ -110,6 +110,7 @@ type GossipSubRpcInspectorSuiteFactoryFunc func( metrics.HeroCacheMetricsFactory, flownet.NetworkingType, module.IdentityProvider, + func() TopicProvider, ) (GossipSubInspectorSuite, error) // NodeBuilder is a builder pattern for creating a libp2p Node instance. diff --git a/network/p2p/connection/connection_gater_test.go b/network/p2p/connection/connection_gater_test.go index 5a2c678b15c..59ef138758d 100644 --- a/network/p2p/connection/connection_gater_test.go +++ b/network/p2p/connection/connection_gater_test.go @@ -74,7 +74,7 @@ func TestConnectionGating(t *testing.T) { // although nodes have each other addresses, they are not in the allow-lists of each other. 
// so they should not be able to connect to each other. p2pfixtures.EnsureNoStreamCreationBetweenGroups(t, ctx, []p2p.LibP2PNode{node1}, []p2p.LibP2PNode{node2}, func(t *testing.T, err error) { - require.True(t, stream.IsErrGaterDisallowedConnection(err)) + require.Truef(t, stream.IsErrGaterDisallowedConnection(err), "expected ErrGaterDisallowedConnection, got: %v", err) }) }) @@ -89,7 +89,7 @@ func TestConnectionGating(t *testing.T) { // from node2 -> node1 should also NOT work, since node 1 is not in node2's allow list for dialing! p2pfixtures.EnsureNoStreamCreation(t, ctx, []p2p.LibP2PNode{node2}, []p2p.LibP2PNode{node1}, func(t *testing.T, err error) { // dialing node-1 by node-2 should fail locally at the connection gater of node-2. - require.True(t, stream.IsErrGaterDisallowedConnection(err)) + require.Truef(t, stream.IsErrGaterDisallowedConnection(err), "expected ErrGaterDisallowedConnection, got: %v", err) }) // now node2 should be able to connect to node1. diff --git a/network/p2p/connection/connector.go b/network/p2p/connection/connector.go index 69fbb5d4359..2e59c595bf7 100644 --- a/network/p2p/connection/connector.go +++ b/network/p2p/connection/connector.go @@ -59,7 +59,7 @@ var _ p2p.PeerUpdater = (*PeerUpdater)(nil) // - error: an error if there is any error while creating the connector. The errors are irrecoverable and unexpected. func NewPeerUpdater(cfg *PeerUpdaterConfig) (*PeerUpdater, error) { libP2PConnector := &PeerUpdater{ - log: cfg.Logger, + log: cfg.Logger.With().Str("component", "peer-updater").Logger(), connector: cfg.Connector, host: cfg.Host, pruneConnections: cfg.PruneConnections, diff --git a/network/p2p/consumers.go b/network/p2p/consumers.go index 85206b7f1df..c3f23205210 100644 --- a/network/p2p/consumers.go +++ b/network/p2p/consumers.go @@ -33,14 +33,25 @@ type InvCtrlMsgNotif struct { Error error // MsgType the control message type. MsgType p2pmsg.ControlMessageType + // Count the number of errors. + Count uint64 } // NewInvalidControlMessageNotification returns a new *InvCtrlMsgNotif -func NewInvalidControlMessageNotification(peerID peer.ID, ctlMsgType p2pmsg.ControlMessageType, err error) *InvCtrlMsgNotif { +// Args: +// - peerID: peer id of the offender. +// - ctlMsgType: the control message type of the rpc message that caused the error. +// - err: the error that occurred. +// - count: the number of occurrences of the error. +// +// Returns: +// - *InvCtlMsgNotif: invalid control message notification. +func NewInvalidControlMessageNotification(peerID peer.ID, ctlMsgType p2pmsg.ControlMessageType, err error, count uint64) *InvCtrlMsgNotif { return &InvCtrlMsgNotif{ PeerID: peerID, Error: err, MsgType: ctlMsgType, + Count: count, } } @@ -70,11 +81,4 @@ type GossipSubInspectorSuite interface { // pattern where the consumer is notified when a new notification is published. // A consumer is only notified once for each notification, and only receives notifications that were published after it was added. AddInvalidControlMessageConsumer(GossipSubInvCtrlMsgNotifConsumer) - - // SetTopicOracle sets the topic oracle of the gossipsub inspector suite. - // The topic oracle is used to determine the list of topics that the node is subscribed to. - // If an oracle is not set, the node will not be able to determine the list of topics that the node is subscribed to. - // This func is expected to be called once and will return an error on all subsequent calls. - // All errors returned from this func are considered irrecoverable. 
- SetTopicOracle(topicOracle func() []string) error } diff --git a/network/p2p/inspector/internal/mockTopicProvider.go b/network/p2p/inspector/internal/mockTopicProvider.go new file mode 100644 index 00000000000..33599a2fb97 --- /dev/null +++ b/network/p2p/inspector/internal/mockTopicProvider.go @@ -0,0 +1,35 @@ +package internal + +import ( + "github.com/libp2p/go-libp2p/core/peer" +) + +// MockUpdatableTopicProvider is a mock implementation of the TopicProvider interface. +// TODO: this should be moved to a common package (e.g. network/p2p/test). Currently, it is not possible to do so because of a circular dependency. +type MockUpdatableTopicProvider struct { + topics []string + subscriptions map[string][]peer.ID +} + +func NewMockUpdatableTopicProvider() *MockUpdatableTopicProvider { + return &MockUpdatableTopicProvider{ + topics: []string{}, + subscriptions: map[string][]peer.ID{}, + } +} + +func (m *MockUpdatableTopicProvider) GetTopics() []string { + return m.topics +} + +func (m *MockUpdatableTopicProvider) ListPeers(topic string) []peer.ID { + return m.subscriptions[topic] +} + +func (m *MockUpdatableTopicProvider) UpdateTopics(topics []string) { + m.topics = topics +} + +func (m *MockUpdatableTopicProvider) UpdateSubscriptions(topic string, peers []peer.ID) { + m.subscriptions[topic] = peers +} diff --git a/network/p2p/inspector/validation/control_message_validation_inspector.go b/network/p2p/inspector/validation/control_message_validation_inspector.go index dcb07027976..66cb4339f9d 100644 --- a/network/p2p/inspector/validation/control_message_validation_inspector.go +++ b/network/p2p/inspector/validation/control_message_validation_inspector.go @@ -4,6 +4,7 @@ import ( "fmt" "time" + "github.com/go-playground/validator/v10" "github.com/hashicorp/go-multierror" pubsub "github.com/libp2p/go-libp2p-pubsub" pubsub_pb "github.com/libp2p/go-libp2p-pubsub/pb" @@ -16,6 +17,8 @@ import ( "github.com/onflow/flow-go/module/component" "github.com/onflow/flow-go/module/irrecoverable" "github.com/onflow/flow-go/module/mempool/queue" + "github.com/onflow/flow-go/module/metrics" + "github.com/onflow/flow-go/network" "github.com/onflow/flow-go/network/channels" "github.com/onflow/flow-go/network/p2p" "github.com/onflow/flow-go/network/p2p/inspector/internal/cache" @@ -51,12 +54,37 @@ type ControlMsgValidationInspector struct { // 1. The cluster prefix topic is received while the inspector waits for the cluster IDs provider to be set (this can happen during the startup or epoch transitions). // 2. The node sends a cluster prefix topic where the cluster prefix does not match any of the active cluster IDs. // In such cases, the inspector will allow a configured number of these messages from the corresponding peer. - tracker *cache.ClusterPrefixedMessagesReceivedTracker - idProvider module.IdentityProvider - rateLimiters map[p2pmsg.ControlMessageType]p2p.BasicRateLimiter - rpcTracker p2p.RpcControlTracking + tracker *cache.ClusterPrefixedMessagesReceivedTracker + idProvider module.IdentityProvider + rpcTracker p2p.RpcControlTracking + // networkingType indicates public or private network, rpc publish messages are inspected for unstaked senders when running the private network. + networkingType network.NetworkingType // topicOracle callback used to retrieve the current subscribed topics of the libp2p node. - topicOracle func() []string + topicOracle func() p2p.TopicProvider +} + +type InspectorParams struct { + // Logger the logger used by the inspector. 
+ Logger zerolog.Logger `validate:"required"` + // SporkID the current spork ID. + SporkID flow.Identifier `validate:"required"` + // Config inspector configuration. + Config *p2pconf.GossipSubRPCValidationInspectorConfigs `validate:"required"` + // Distributor gossipsub inspector notification distributor. + Distributor p2p.GossipSubInspectorNotifDistributor `validate:"required"` + // HeroCacheMetricsFactory the metrics factory. + HeroCacheMetricsFactory metrics.HeroCacheMetricsFactory `validate:"required"` + // IdProvider identity provider is used to get the flow identifier for a peer. + IdProvider module.IdentityProvider `validate:"required"` + // InspectorMetrics metrics for the validation inspector. + InspectorMetrics module.GossipSubRpcValidationInspectorMetrics `validate:"required"` + // RpcTracker tracker used to track iHave RPC's sent and last size. + RpcTracker p2p.RpcControlTracking `validate:"required"` + // NetworkingType the networking type of the node. + NetworkingType network.NetworkingType `validate:"required"` + // TopicOracle callback used to retrieve the current subscribed topics of the libp2p node. + // It is set as a callback to avoid circular dependencies between the topic oracle and the inspector. + TopicOracle func() p2p.TopicProvider `validate:"required"` } var _ component.Component = (*ControlMsgValidationInspector)(nil) @@ -65,55 +93,64 @@ var _ protocol.Consumer = (*ControlMsgValidationInspector)(nil) // NewControlMsgValidationInspector returns new ControlMsgValidationInspector // Args: -// - logger: the logger used by the inspector. -// - sporkID: the current spork ID. -// - config: inspector configuration. -// - distributor: gossipsub inspector notification distributor. -// - clusterPrefixedCacheCollector: metrics collector for the underlying cluster prefix received tracker cache. -// - idProvider: identity provider is used to get the flow identifier for a peer. +// - *InspectorParams: params used to create the inspector. // // Returns: // - *ControlMsgValidationInspector: a new control message validation inspector. // - error: an error if there is any error while creating the inspector. All errors are irrecoverable and unexpected. 
-func NewControlMsgValidationInspector(ctx irrecoverable.SignalerContext, logger zerolog.Logger, sporkID flow.Identifier, config *p2pconf.GossipSubRPCValidationInspectorConfigs, distributor p2p.GossipSubInspectorNotifDistributor, inspectMsgQueueCacheCollector module.HeroCacheMetrics, clusterPrefixedCacheCollector module.HeroCacheMetrics, idProvider module.IdentityProvider, inspectorMetrics module.GossipSubRpcValidationInspectorMetrics, rpcTracker p2p.RpcControlTracking) (*ControlMsgValidationInspector, error) { - lg := logger.With().Str("component", "gossip_sub_rpc_validation_inspector").Logger() +func NewControlMsgValidationInspector(params *InspectorParams) (*ControlMsgValidationInspector, error) { + err := validator.New().Struct(params) + if err != nil { + return nil, fmt.Errorf("inspector params validation failed: %w", err) + } + lg := params.Logger.With().Str("component", "gossip_sub_rpc_validation_inspector").Logger() + + inspectMsgQueueCacheCollector := metrics.GossipSubRPCInspectorQueueMetricFactory(params.HeroCacheMetricsFactory, params.NetworkingType) + clusterPrefixedCacheCollector := metrics.GossipSubRPCInspectorClusterPrefixedCacheMetricFactory(params.HeroCacheMetricsFactory, params.NetworkingType) - clusterPrefixedTracker, err := cache.NewClusterPrefixedMessagesReceivedTracker(logger, config.ClusterPrefixedControlMsgsReceivedCacheSize, clusterPrefixedCacheCollector, config.ClusterPrefixedControlMsgsReceivedCacheDecay) + clusterPrefixedTracker, err := cache.NewClusterPrefixedMessagesReceivedTracker(params.Logger, + params.Config.ClusterPrefixedControlMsgsReceivedCacheSize, + clusterPrefixedCacheCollector, + params.Config.ClusterPrefixedControlMsgsReceivedCacheDecay) if err != nil { return nil, fmt.Errorf("failed to create cluster prefix topics received tracker") } - if config.RpcMessageMaxSampleSize < config.RpcMessageErrorThreshold { - return nil, fmt.Errorf("rpc message max sample size must be greater than or equal to rpc message error threshold, got %d and %d respectively", config.RpcMessageMaxSampleSize, config.RpcMessageErrorThreshold) + if params.Config.RpcMessageMaxSampleSize < params.Config.RpcMessageErrorThreshold { + return nil, fmt.Errorf("rpc message max sample size must be greater than or equal to rpc message error threshold, got %d and %d respectively", + params.Config.RpcMessageMaxSampleSize, + params.Config.RpcMessageErrorThreshold) } c := &ControlMsgValidationInspector{ - ctx: ctx, - logger: lg, - sporkID: sporkID, - config: config, - distributor: distributor, - tracker: clusterPrefixedTracker, - rpcTracker: rpcTracker, - idProvider: idProvider, - metrics: inspectorMetrics, - rateLimiters: make(map[p2pmsg.ControlMessageType]p2p.BasicRateLimiter), - } - - store := queue.NewHeroStore(config.CacheSize, logger, inspectMsgQueueCacheCollector) + logger: lg, + sporkID: params.SporkID, + config: params.Config, + distributor: params.Distributor, + tracker: clusterPrefixedTracker, + rpcTracker: params.RpcTracker, + idProvider: params.IdProvider, + metrics: params.InspectorMetrics, + networkingType: params.NetworkingType, + topicOracle: params.TopicOracle, + } + + store := queue.NewHeroStore(params.Config.CacheSize, params.Logger, inspectMsgQueueCacheCollector) + pool := worker.NewWorkerPoolBuilder[*InspectRPCRequest](lg, store, c.processInspectRPCReq).Build() c.workerPool = pool builder := component.NewComponentManagerBuilder() builder.AddWorker(func(ctx irrecoverable.SignalerContext, ready component.ReadyFunc) { - distributor.Start(ctx) + c.ctx = ctx + 
c.distributor.Start(ctx) select { case <-ctx.Done(): - case <-distributor.Ready(): + case <-c.distributor.Ready(): ready() } - <-distributor.Done() + <-c.distributor.Done() }) for i := 0; i < c.config.NumberOfWorkers; i++ { builder.AddWorker(pool.WorkerLogic()) @@ -124,20 +161,31 @@ func NewControlMsgValidationInspector(ctx irrecoverable.SignalerContext, logger func (c *ControlMsgValidationInspector) Start(parent irrecoverable.SignalerContext) { if c.topicOracle == nil { - parent.Throw(fmt.Errorf("topic oracle not set")) + parent.Throw(fmt.Errorf("control message validation inspector topic oracle not set")) } c.Component.Start(parent) } +// Name returns the name of the rpc inspector. +func (c *ControlMsgValidationInspector) Name() string { + return rpcInspectorComponentName +} + +// ActiveClustersChanged consumes cluster ID update protocol events. +func (c *ControlMsgValidationInspector) ActiveClustersChanged(clusterIDList flow.ChainIDList) { + c.tracker.StoreActiveClusterIds(clusterIDList) +} + // Inspect is called by gossipsub upon reception of a rpc from a remote node. // It creates a new InspectRPCRequest for the RPC to be inspected async by the worker pool. +// Args: +// - from: the sender. +// - rpc: the control message RPC. +// +// Returns: +// - error: if a new inspect rpc request cannot be created, all errors returned are considered irrecoverable. func (c *ControlMsgValidationInspector) Inspect(from peer.ID, rpc *pubsub.RPC) error { - // first truncate rpc - err := c.truncateRPC(from, rpc) - if err != nil { - // irrecoverable error encountered - c.logAndThrowError(fmt.Errorf("failed to get inspect RPC request could not perform truncation: %w", err)) - } + c.truncateRPC(from, rpc) // queue further async inspection req, err := NewInspectRPCRequest(from, rpc) if err != nil { @@ -154,6 +202,11 @@ func (c *ControlMsgValidationInspector) Inspect(from peer.ID, rpc *pubsub.RPC) e // processInspectRPCReq func used by component workers to perform further inspection of RPC control messages that will validate ensure all control message // types are valid in the RPC. +// Args: +// - req: the inspect rpc request. +// +// Returns: +// - error: no error is expected to be returned from this func as they are logged and distributed in invalid control message notifications. 
func (c *ControlMsgValidationInspector) processInspectRPCReq(req *InspectRPCRequest) error { c.metrics.AsyncProcessingStarted() start := time.Now() @@ -167,40 +220,63 @@ func (c *ControlMsgValidationInspector) processInspectRPCReq(req *InspectRPCRequ case p2pmsg.CtrlMsgGraft: err := c.inspectGraftMessages(req.Peer, req.rpc.GetControl().GetGraft(), activeClusterIDS) if err != nil { - c.logAndDistributeAsyncInspectErrs(req, p2pmsg.CtrlMsgGraft, err) + c.logAndDistributeAsyncInspectErrs(req, p2pmsg.CtrlMsgGraft, err, 1) return nil } case p2pmsg.CtrlMsgPrune: err := c.inspectPruneMessages(req.Peer, req.rpc.GetControl().GetPrune(), activeClusterIDS) if err != nil { - c.logAndDistributeAsyncInspectErrs(req, p2pmsg.CtrlMsgPrune, err) + c.logAndDistributeAsyncInspectErrs(req, p2pmsg.CtrlMsgPrune, err, 1) return nil } case p2pmsg.CtrlMsgIWant: err := c.inspectIWantMessages(req.Peer, req.rpc.GetControl().GetIwant()) if err != nil { - c.logAndDistributeAsyncInspectErrs(req, p2pmsg.CtrlMsgIWant, err) + c.logAndDistributeAsyncInspectErrs(req, p2pmsg.CtrlMsgIWant, err, 1) return nil } case p2pmsg.CtrlMsgIHave: err := c.inspectIHaveMessages(req.Peer, req.rpc.GetControl().GetIhave(), activeClusterIDS) if err != nil { - c.logAndDistributeAsyncInspectErrs(req, p2pmsg.CtrlMsgIHave, err) + c.logAndDistributeAsyncInspectErrs(req, p2pmsg.CtrlMsgIHave, err, 1) return nil } } } // inspect rpc publish messages after all control message validation has passed - err := c.inspectRpcPublishMessages(req.Peer, req.rpc.GetPublish(), activeClusterIDS) + err, errCount := c.inspectRpcPublishMessages(req.Peer, req.rpc.GetPublish(), activeClusterIDS) if err != nil { - c.logAndDistributeAsyncInspectErrs(req, p2pmsg.RpcPublishMessage, err) + c.logAndDistributeAsyncInspectErrs(req, p2pmsg.RpcPublishMessage, err, errCount) return nil } return nil } +// checkPubsubMessageSender checks the sender of the sender of pubsub message to ensure they are not unstaked, or ejected. +// This check is only required on private networks. +// Args: +// - message: the pubsub message. +// +// Returns: +// - error: if the peer ID cannot be created from bytes, sender is unknown or the identity is ejected. +// +// All errors returned from this function can be considered benign. +func (c *ControlMsgValidationInspector) checkPubsubMessageSender(message *pubsub_pb.Message) error { + pid, err := peer.IDFromBytes(message.GetFrom()) + if err != nil { + return fmt.Errorf("failed to get peer ID from bytes: %w", err) + } + if id, ok := c.idProvider.ByPeerID(pid); !ok { + return fmt.Errorf("received rpc publish message from unstaked peer: %s", pid) + } else if id.Ejected { + return fmt.Errorf("received rpc publish message from ejected peer: %s", pid) + } + + return nil +} + // inspectGraftMessages performs topic validation on all grafts in the control message using the provided validateTopic func while tracking duplicates. // Args: // - from: peer ID of the sender. @@ -322,7 +398,7 @@ func (c *ControlMsgValidationInspector) inspectIWantMessages(from peer.ID, iWant allowedCacheMissesThreshold := float64(sampleSize) * c.config.IWantRPCInspectionConfig.CacheMissThreshold duplicates := 0 allowedDuplicatesThreshold := float64(sampleSize) * c.config.IWantRPCInspectionConfig.DuplicateMsgIDThreshold - checkCacheMisses := len(iWants) > c.config.IWantRPCInspectionConfig.CacheMissCheckSize + checkCacheMisses := len(iWants) >= c.config.IWantRPCInspectionConfig.CacheMissCheckSize lg = lg.With(). Uint("iwant_sample_size", sampleSize). 
Float64("allowed_cache_misses_threshold", allowedCacheMissesThreshold). @@ -375,11 +451,12 @@ func (c *ControlMsgValidationInspector) inspectIWantMessages(from peer.ID, iWant // - messages: rpc publish messages. // - activeClusterIDS: the list of active cluster ids. // Returns: -// - InvalidRpcPublishMessagesErr: if the amount of invalid messages exceeds the configured RpcMessageErrorThreshold. -func (c *ControlMsgValidationInspector) inspectRpcPublishMessages(from peer.ID, messages []*pubsub_pb.Message, activeClusterIDS flow.ChainIDList) error { +// - InvalidRpcPublishMessagesErr: if the amount of invalid messages exceeds the configured RPCMessageErrorThreshold. +// - int: the number of invalid pubsub messages +func (c *ControlMsgValidationInspector) inspectRpcPublishMessages(from peer.ID, messages []*pubsub_pb.Message, activeClusterIDS flow.ChainIDList) (error, uint64) { totalMessages := len(messages) if totalMessages == 0 { - return nil + return nil, 0 } sampleSize := c.config.RpcMessageMaxSampleSize if sampleSize > totalMessages { @@ -389,7 +466,7 @@ func (c *ControlMsgValidationInspector) inspectRpcPublishMessages(from peer.ID, messages[i], messages[j] = messages[j], messages[i] }) - subscribedTopics := c.topicOracle() + subscribedTopics := c.topicOracle().GetTopics() hasSubscription := func(topic string) bool { for _, subscribedTopic := range subscribedTopics { if topic == subscribedTopic { @@ -398,30 +475,41 @@ func (c *ControlMsgValidationInspector) inspectRpcPublishMessages(from peer.ID, } return false } - var errs *multierror.Error for _, message := range messages[:sampleSize] { + if c.networkingType == network.PrivateNetwork { + err := c.checkPubsubMessageSender(message) + if err != nil { + errs = multierror.Append(errs, err) + continue + } + } topic := channels.Topic(message.GetTopic()) err := c.validateTopic(from, topic, activeClusterIDS) if err != nil { + // we can skip checking for subscription of topic that failed validation and continue errs = multierror.Append(errs, err) + continue } if !hasSubscription(topic.String()) { errs = multierror.Append(errs, fmt.Errorf("subscription for topic %s not found", topic)) } + } - // return an error when we exceed the error threshold - if errs != nil && errs.Len() > c.config.RpcMessageErrorThreshold { - return NewInvalidRpcPublishMessagesErr(errs.ErrorOrNil(), errs.Len()) - } + // return an error when we exceed the error threshold + if errs != nil && errs.Len() > c.config.RpcMessageErrorThreshold { + return NewInvalidRpcPublishMessagesErr(errs.ErrorOrNil(), errs.Len()), uint64(errs.Len()) } - return nil + return nil, 0 } // truncateRPC truncates the RPC by truncating each control message type using the configured max sample size values. -func (c *ControlMsgValidationInspector) truncateRPC(from peer.ID, rpc *pubsub.RPC) error { +// Args: +// - from: peer ID of the sender. +// - rpc: the pubsub RPC. +func (c *ControlMsgValidationInspector) truncateRPC(from peer.ID, rpc *pubsub.RPC) { for _, ctlMsgType := range p2pmsg.ControlMessageTypes() { switch ctlMsgType { case p2pmsg.CtrlMsgGraft: @@ -437,16 +525,12 @@ func (c *ControlMsgValidationInspector) truncateRPC(from peer.ID, rpc *pubsub.RP c.logAndThrowError(fmt.Errorf("unknown control message type encountered during RPC truncation")) } } - return nil } // truncateGraftMessages truncates the Graft control messages in the RPC. If the total number of Grafts in the RPC exceeds the configured // GraftPruneMessageMaxSampleSize the list of Grafts will be truncated. 
// Args: // - rpc: the rpc message to truncate. -// -// Returns: -// - error: if any error encountered while sampling the messages, all errors are considered irrecoverable. func (c *ControlMsgValidationInspector) truncateGraftMessages(rpc *pubsub.RPC) { grafts := rpc.GetControl().GetGraft() totalGrafts := len(grafts) @@ -467,9 +551,6 @@ func (c *ControlMsgValidationInspector) truncateGraftMessages(rpc *pubsub.RPC) { // GraftPruneMessageMaxSampleSize the list of Prunes will be truncated. // Args: // - rpc: the rpc message to truncate. -// -// Returns: -// - error: if any error encountered while sampling the messages, all errors are considered irrecoverable. func (c *ControlMsgValidationInspector) truncatePruneMessages(rpc *pubsub.RPC) { prunes := rpc.GetControl().GetPrune() totalPrunes := len(prunes) @@ -490,9 +571,6 @@ func (c *ControlMsgValidationInspector) truncatePruneMessages(rpc *pubsub.RPC) { // MaxSampleSize the list of iHaves will be truncated. // Args: // - rpc: the rpc message to truncate. -// -// Returns: -// - error: if any error encountered while sampling the messages, all errors are considered irrecoverable. func (c *ControlMsgValidationInspector) truncateIHaveMessages(rpc *pubsub.RPC) { ihaves := rpc.GetControl().GetIhave() totalIHaves := len(ihaves) @@ -537,9 +615,6 @@ func (c *ControlMsgValidationInspector) truncateIHaveMessageIds(rpc *pubsub.RPC) // MaxSampleSize the list of iWants will be truncated. // Args: // - rpc: the rpc message to truncate. -// -// Returns: -// - error: if any error encountered while sampling the messages, all errors are considered irrecoverable. func (c *ControlMsgValidationInspector) truncateIWantMessages(from peer.ID, rpc *pubsub.RPC) { iWants := rpc.GetControl().GetIwant() totalIWants := uint(len(iWants)) @@ -561,9 +636,6 @@ func (c *ControlMsgValidationInspector) truncateIWantMessages(from peer.ID, rpc // MaxMessageIDSampleSize the list of message ids will be truncated. Before message ids are truncated the iWant control messages should have been truncated themselves. // Args: // - rpc: the rpc message to truncate. -// -// Returns: -// - error: if any error encountered while sampling the messages, all errors are considered irrecoverable. func (c *ControlMsgValidationInspector) truncateIWantMessageIds(from peer.ID, rpc *pubsub.RPC) { lastHighest := c.rpcTracker.LastHighestIHaveRPCSize() lg := c.logger.With(). @@ -594,28 +666,6 @@ func (c *ControlMsgValidationInspector) truncateIWantMessageIds(from peer.ID, rp } } -// Name returns the name of the rpc inspector. -func (c *ControlMsgValidationInspector) Name() string { - return rpcInspectorComponentName -} - -// ActiveClustersChanged consumes cluster ID update protocol events. -func (c *ControlMsgValidationInspector) ActiveClustersChanged(clusterIDList flow.ChainIDList) { - c.tracker.StoreActiveClusterIds(clusterIDList) -} - -// SetTopicOracle Sets the topic oracle. The topic oracle is used to determine the list of topics that the node is subscribed to. -// If an oracle is not set, the node will not be able to determine the list of topics that the node is subscribed to. -// This func is expected to be called once and will return an error on all subsequent calls. -// All errors returned from this func are considered irrecoverable. 
-func (c *ControlMsgValidationInspector) SetTopicOracle(topicOracle func() []string) error { - if c.topicOracle != nil { - return fmt.Errorf("topic oracle already set") - } - c.topicOracle = topicOracle - return nil -} - // performSample performs sampling on the specified control message that will randomize // the items in the control message slice up to index sampleSize-1. Any error encountered during sampling is considered // irrecoverable and will cause the node to crash. @@ -745,29 +795,41 @@ func (c *ControlMsgValidationInspector) checkClusterPrefixHardThreshold(nodeID f } // logAndDistributeErr logs the provided error and attempts to disseminate an invalid control message validation notification for the error. -func (c *ControlMsgValidationInspector) logAndDistributeAsyncInspectErrs(req *InspectRPCRequest, ctlMsgType p2pmsg.ControlMessageType, err error) { +// Args: +// - req: inspect rpc request that failed validation. +// - ctlMsgType: the control message type of the rpc message that caused the error. +// - err: the error that occurred. +// - count: the number of occurrences of the error. +func (c *ControlMsgValidationInspector) logAndDistributeAsyncInspectErrs(req *InspectRPCRequest, ctlMsgType p2pmsg.ControlMessageType, err error, count uint64) { lg := c.logger.With(). + Err(err). + Str("control_message_type", ctlMsgType.String()). Bool(logging.KeySuspicious, true). Bool(logging.KeyNetworkingSecurity, true). + Uint64("error_count", count). Str("peer_id", p2plogging.PeerId(req.Peer)). Logger() switch { case IsErrActiveClusterIDsNotSet(err): - lg.Warn().Err(err).Msg("active cluster ids not set") + lg.Warn().Msg("active cluster ids not set") case IsErrUnstakedPeer(err): - lg.Warn().Err(err).Msg("control message received from unstaked peer") + lg.Warn().Msg("control message received from unstaked peer") default: - err = c.distributor.Distribute(p2p.NewInvalidControlMessageNotification(req.Peer, ctlMsgType, err)) - if err != nil { + distErr := c.distributor.Distribute(p2p.NewInvalidControlMessageNotification(req.Peer, ctlMsgType, err, count)) + if distErr != nil { lg.Error(). - Err(err). + Err(distErr). Msg("failed to distribute invalid control message notification") } - lg.Error().Err(err).Msg("rpc control message async inspection failed") + lg.Error().Msg("rpc control message async inspection failed") } } +// logAndThrowError logs and throws irrecoverable errors on the context. +// Args: +// +// err: the error encountered. func (c *ControlMsgValidationInspector) logAndThrowError(err error) { c.logger.Error(). Err(err). 
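The InspectorParams.TopicOracle comment above explains that the oracle is passed as a callback to avoid a circular dependency: the inspector is constructed before the libp2p node that ultimately provides the subscribed topics. Below is a minimal sketch of that late-binding pattern, with simplified stand-in types rather than the real flow-go interfaces.

package sketch

// TopicProvider is a stand-in for the interface that exposes subscribed topics.
type TopicProvider interface {
	GetTopics() []string
}

type inspector struct {
	topicOracle func() TopicProvider
}

// newInspector receives a callback instead of a TopicProvider so that it can be
// built before the component that eventually supplies the provider exists.
func newInspector(topicOracle func() TopicProvider) *inspector {
	return &inspector{topicOracle: topicOracle}
}

// subscribedTopics resolves the provider lazily, at inspection time.
func (i *inspector) subscribedTopics() []string {
	return i.topicOracle().GetTopics()
}

A builder can then hand the constructor a closure over a field that is only assigned once the pubsub instance exists, and the inspector resolves it on each call, which is why Start above throws if the oracle was never set.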
diff --git a/network/p2p/inspector/validation/control_message_validation_inspector_test.go b/network/p2p/inspector/validation/control_message_validation_inspector_test.go index 1a68e7d7a10..f076a2868f7 100644 --- a/network/p2p/inspector/validation/control_message_validation_inspector_test.go +++ b/network/p2p/inspector/validation/control_message_validation_inspector_test.go @@ -1,529 +1,837 @@ -package validation +package validation_test import ( "context" "fmt" "testing" + "time" pubsub_pb "github.com/libp2p/go-libp2p-pubsub/pb" "github.com/libp2p/go-libp2p/core/peer" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" "github.com/onflow/flow-go/config" "github.com/onflow/flow-go/model/flow" "github.com/onflow/flow-go/module/irrecoverable" "github.com/onflow/flow-go/module/metrics" mockmodule "github.com/onflow/flow-go/module/mock" + "github.com/onflow/flow-go/network" "github.com/onflow/flow-go/network/channels" "github.com/onflow/flow-go/network/p2p" + "github.com/onflow/flow-go/network/p2p/inspector/internal" + "github.com/onflow/flow-go/network/p2p/inspector/validation" p2pmsg "github.com/onflow/flow-go/network/p2p/message" mockp2p "github.com/onflow/flow-go/network/p2p/mock" + "github.com/onflow/flow-go/network/p2p/p2pconf" + p2ptest "github.com/onflow/flow-go/network/p2p/test" "github.com/onflow/flow-go/utils/unittest" ) +type ControlMsgValidationInspectorSuite struct { + suite.Suite + sporkID flow.Identifier + config *p2pconf.GossipSubRPCValidationInspectorConfigs + distributor *mockp2p.GossipSubInspectorNotificationDistributor + params *validation.InspectorParams + rpcTracker *mockp2p.RpcControlTracking + idProvider *mockmodule.IdentityProvider + inspector *validation.ControlMsgValidationInspector + signalerCtx *irrecoverable.MockSignalerContext + topicProviderOracle *internal.MockUpdatableTopicProvider + cancel context.CancelFunc +} + +func TestControlMsgValidationInspector(t *testing.T) { + suite.Run(t, new(ControlMsgValidationInspectorSuite)) +} + +func (suite *ControlMsgValidationInspectorSuite) SetupTest() { + suite.sporkID = unittest.IdentifierFixture() + flowConfig, err := config.DefaultConfig() + require.NoError(suite.T(), err, "failed to get default flow config") + suite.config = &flowConfig.NetworkConfig.GossipSubRPCValidationInspectorConfigs + distributor := mockp2p.NewGossipSubInspectorNotificationDistributor(suite.T()) + p2ptest.MockInspectorNotificationDistributorReadyDoneAware(distributor) + suite.distributor = distributor + suite.idProvider = mockmodule.NewIdentityProvider(suite.T()) + rpcTracker := mockp2p.NewRpcControlTracking(suite.T()) + suite.rpcTracker = rpcTracker + suite.topicProviderOracle = internal.NewMockUpdatableTopicProvider() + params := &validation.InspectorParams{ + Logger: unittest.Logger(), + SporkID: suite.sporkID, + Config: &flowConfig.NetworkConfig.GossipSubRPCValidationInspectorConfigs, + Distributor: distributor, + IdProvider: suite.idProvider, + HeroCacheMetricsFactory: metrics.NewNoopHeroCacheMetricsFactory(), + InspectorMetrics: metrics.NewNoopCollector(), + RpcTracker: rpcTracker, + NetworkingType: network.PublicNetwork, + TopicOracle: func() p2p.TopicProvider { + return suite.topicProviderOracle + }, + } + suite.params = params + inspector, err := validation.NewControlMsgValidationInspector(params) + require.NoError(suite.T(), err, "failed to create control message validation inspector fixture") + suite.inspector = inspector + ctx, cancel := 
context.WithCancel(context.Background()) + suite.cancel = cancel + suite.signalerCtx = irrecoverable.NewMockSignalerContext(suite.T(), ctx) +} + +func (suite *ControlMsgValidationInspectorSuite) StopInspector() { + suite.cancel() + unittest.RequireCloseBefore(suite.T(), suite.inspector.Done(), 500*time.Millisecond, "inspector did not stop") +} + +func TestNewControlMsgValidationInspector(t *testing.T) { + t.Run("should create validation inspector without error", func(t *testing.T) { + sporkID := unittest.IdentifierFixture() + flowConfig, err := config.DefaultConfig() + require.NoError(t, err, "failed to get default flow config") + distributor := mockp2p.NewGossipSubInspectorNotifDistributor(t) + idProvider := mockmodule.NewIdentityProvider(t) + topicProvider := internal.NewMockUpdatableTopicProvider() + inspector, err := validation.NewControlMsgValidationInspector(&validation.InspectorParams{ + Logger: unittest.Logger(), + SporkID: sporkID, + Config: &flowConfig.NetworkConfig.GossipSubRPCValidationInspectorConfigs, + Distributor: distributor, + IdProvider: idProvider, + HeroCacheMetricsFactory: metrics.NewNoopHeroCacheMetricsFactory(), + InspectorMetrics: metrics.NewNoopCollector(), + RpcTracker: mockp2p.NewRpcControlTracking(t), + NetworkingType: network.PublicNetwork, + TopicOracle: func() p2p.TopicProvider { + return topicProvider + }, + }) + require.NoError(t, err) + require.NotNil(t, inspector) + }) + t.Run("should return error if any of the params are nil", func(t *testing.T) { + inspector, err := validation.NewControlMsgValidationInspector(&validation.InspectorParams{ + Logger: unittest.Logger(), + SporkID: unittest.IdentifierFixture(), + Config: nil, + Distributor: nil, + IdProvider: nil, + HeroCacheMetricsFactory: nil, + InspectorMetrics: nil, + RpcTracker: nil, + TopicOracle: nil, + }) + require.Nil(t, inspector) + require.Error(t, err) + s := err.Error() + require.Contains(t, s, "validation for 'Config' failed on the 'required'") + require.Contains(t, s, "validation for 'Distributor' failed on the 'required'") + require.Contains(t, s, "validation for 'IdProvider' failed on the 'required'") + require.Contains(t, s, "validation for 'HeroCacheMetricsFactory' failed on the 'required'") + require.Contains(t, s, "validation for 'InspectorMetrics' failed on the 'required'") + require.Contains(t, s, "validation for 'RpcTracker' failed on the 'required'") + require.Contains(t, s, "validation for 'NetworkingType' failed on the 'required'") + require.Contains(t, s, "validation for 'TopicOracle' failed on the 'required'") + }) +} + // TestControlMessageValidationInspector_TruncateRPC verifies the expected truncation behavior of RPC control messages. // Message truncation for each control message type occurs when the count of control // messages exceeds the configured maximum sample size for that control message type. 
-func TestControlMessageValidationInspector_truncateRPC(t *testing.T) { - t.Run("truncateGraftMessages should truncate graft messages as expected", func(t *testing.T) { - inspector, _, _, _, _ := inspectorFixture(t) - inspector.config.GraftPruneMessageMaxSampleSize = 100 +func (suite *ControlMsgValidationInspectorSuite) TestControlMessageValidationInspector_truncateRPC() { + suite.T().Run("truncateGraftMessages should truncate graft messages as expected", func(t *testing.T) { + suite.SetupTest() + defer suite.StopInspector() + // topic validation is ignored; set any topic oracle + suite.distributor.On("Distribute", mock.AnythingOfType("*p2p.InvCtrlMsgNotif")).Return(nil).Maybe() + suite.rpcTracker.On("LastHighestIHaveRPCSize").Return(int64(100)).Maybe() + suite.rpcTracker.On("WasIHaveRPCSent", mock.AnythingOfType("string")).Return(true).Maybe() + suite.config.GraftPruneMessageMaxSampleSize = 100 + suite.inspector.Start(suite.signalerCtx) + // topic validation not performed, so we can use random strings graftsGreaterThanMaxSampleSize := unittest.P2PRPCFixture(unittest.WithGrafts(unittest.P2PRPCGraftFixtures(unittest.IdentifierListFixture(200).Strings()...)...)) - require.Greater(t, len(graftsGreaterThanMaxSampleSize.GetControl().GetGraft()), inspector.config.GraftPruneMessageMaxSampleSize) + require.Greater(t, len(graftsGreaterThanMaxSampleSize.GetControl().GetGraft()), suite.config.GraftPruneMessageMaxSampleSize) graftsLessThanMaxSampleSize := unittest.P2PRPCFixture(unittest.WithGrafts(unittest.P2PRPCGraftFixtures(unittest.IdentifierListFixture(50).Strings()...)...)) - require.Less(t, len(graftsLessThanMaxSampleSize.GetControl().GetGraft()), inspector.config.GraftPruneMessageMaxSampleSize) - inspector.truncateGraftMessages(graftsGreaterThanMaxSampleSize) - inspector.truncateGraftMessages(graftsLessThanMaxSampleSize) - // rpc with grafts greater than configured max sample size should be truncated to GraftPruneMessageMaxSampleSize - require.Len(t, graftsGreaterThanMaxSampleSize.GetControl().GetGraft(), inspector.config.GraftPruneMessageMaxSampleSize) - // rpc with grafts less than GraftPruneMessageMaxSampleSize should not be truncated - require.Len(t, graftsLessThanMaxSampleSize.GetControl().GetGraft(), 50) + require.Less(t, len(graftsLessThanMaxSampleSize.GetControl().GetGraft()), suite.config.GraftPruneMessageMaxSampleSize) + + from := unittest.PeerIdFixture(t) + require.NoError(t, suite.inspector.Inspect(from, graftsGreaterThanMaxSampleSize)) + require.NoError(t, suite.inspector.Inspect(from, graftsLessThanMaxSampleSize)) + require.Eventually(t, func() bool { + // rpc with grafts greater than configured max sample size should be truncated to GraftPruneMessageMaxSampleSize + shouldBeTruncated := len(graftsGreaterThanMaxSampleSize.GetControl().GetGraft()) == suite.config.GraftPruneMessageMaxSampleSize + // rpc with grafts less than GraftPruneMessageMaxSampleSize should not be truncated + shouldNotBeTruncated := len(graftsLessThanMaxSampleSize.GetControl().GetGraft()) == 50 + return shouldBeTruncated && shouldNotBeTruncated + }, time.Second, 500*time.Millisecond) }) - t.Run("truncatePruneMessages should truncate prune messages as expected", func(t *testing.T) { - inspector, _, _, _, _ := inspectorFixture(t) - inspector.config.GraftPruneMessageMaxSampleSize = 100 - // topic validation not performed so we can use random strings + suite.T().Run("truncatePruneMessages should truncate prune messages as expected", func(t *testing.T) { + suite.SetupTest() + defer suite.StopInspector() + 
suite.rpcTracker.On("LastHighestIHaveRPCSize").Return(int64(100)).Maybe() + suite.rpcTracker.On("WasIHaveRPCSent", mock.AnythingOfType("string")).Return(true).Maybe() + suite.distributor.On("Distribute", mock.AnythingOfType("*p2p.InvCtrlMsgNotif")).Return(nil).Twice() + suite.config.GraftPruneMessageMaxSampleSize = 100 + + suite.inspector.Start(suite.signalerCtx) + // unittest.RequireCloseBefore(t, inspector.Ready(), 100*time.Millisecond, "inspector did not start") + // topic validation not performed, so we can use random strings prunesGreaterThanMaxSampleSize := unittest.P2PRPCFixture(unittest.WithPrunes(unittest.P2PRPCPruneFixtures(unittest.IdentifierListFixture(200).Strings()...)...)) - require.Greater(t, len(prunesGreaterThanMaxSampleSize.GetControl().GetPrune()), inspector.config.GraftPruneMessageMaxSampleSize) + require.Greater(t, len(prunesGreaterThanMaxSampleSize.GetControl().GetPrune()), suite.config.GraftPruneMessageMaxSampleSize) prunesLessThanMaxSampleSize := unittest.P2PRPCFixture(unittest.WithPrunes(unittest.P2PRPCPruneFixtures(unittest.IdentifierListFixture(50).Strings()...)...)) - require.Less(t, len(prunesLessThanMaxSampleSize.GetControl().GetPrune()), inspector.config.GraftPruneMessageMaxSampleSize) - inspector.truncatePruneMessages(prunesGreaterThanMaxSampleSize) - inspector.truncatePruneMessages(prunesLessThanMaxSampleSize) - // rpc with prunes greater than configured max sample size should be truncated to GraftPruneMessageMaxSampleSize - require.Len(t, prunesGreaterThanMaxSampleSize.GetControl().GetPrune(), inspector.config.GraftPruneMessageMaxSampleSize) - // rpc with prunes less than GraftPruneMessageMaxSampleSize should not be truncated - require.Len(t, prunesLessThanMaxSampleSize.GetControl().GetPrune(), 50) + require.Less(t, len(prunesLessThanMaxSampleSize.GetControl().GetPrune()), suite.config.GraftPruneMessageMaxSampleSize) + from := unittest.PeerIdFixture(t) + require.NoError(t, suite.inspector.Inspect(from, prunesGreaterThanMaxSampleSize)) + require.NoError(t, suite.inspector.Inspect(from, prunesLessThanMaxSampleSize)) + require.Eventually(t, func() bool { + // rpc with prunes greater than configured max sample size should be truncated to GraftPruneMessageMaxSampleSize + shouldBeTruncated := len(prunesGreaterThanMaxSampleSize.GetControl().GetPrune()) == suite.config.GraftPruneMessageMaxSampleSize + // rpc with prunes less than GraftPruneMessageMaxSampleSize should not be truncated + shouldNotBeTruncated := len(prunesLessThanMaxSampleSize.GetControl().GetPrune()) == 50 + return shouldBeTruncated && shouldNotBeTruncated + }, time.Second, 500*time.Millisecond) }) - t.Run("truncateIHaveMessages should truncate iHave messages as expected", func(t *testing.T) { - inspector, _, _, _, _ := inspectorFixture(t) - inspector.config.IHaveRPCInspectionConfig.MaxSampleSize = 100 + suite.T().Run("truncateIHaveMessages should truncate iHave messages as expected", func(t *testing.T) { + suite.SetupTest() + defer suite.StopInspector() + suite.rpcTracker.On("LastHighestIHaveRPCSize").Return(int64(100)).Maybe() + suite.rpcTracker.On("WasIHaveRPCSent", mock.AnythingOfType("string")).Return(true).Maybe() + suite.distributor.On("Distribute", mock.AnythingOfType("*p2p.InvCtrlMsgNotif")).Return(nil).Twice() + suite.config.IHaveRPCInspectionConfig.MaxSampleSize = 100 + suite.inspector.Start(suite.signalerCtx) + // topic validation not performed so we can use random strings - iHavesGreaterThanMaxSampleSize := unittest.P2PRPCFixture(unittest.WithIHaves(unittest.P2PRPCIHaveFixtures(200, 
unittest.IdentifierListFixture(200).Strings()...)...)) - require.Greater(t, len(iHavesGreaterThanMaxSampleSize.GetControl().GetIhave()), inspector.config.IHaveRPCInspectionConfig.MaxSampleSize) + iHavesGreaterThanMaxSampleSize := unittest.P2PRPCFixture(unittest.WithIHaves(unittest.P2PRPCIHaveFixtures(200, + unittest.IdentifierListFixture(200).Strings()...)...)) + require.Greater(t, len(iHavesGreaterThanMaxSampleSize.GetControl().GetIhave()), suite.config.IHaveRPCInspectionConfig.MaxSampleSize) iHavesLessThanMaxSampleSize := unittest.P2PRPCFixture(unittest.WithIHaves(unittest.P2PRPCIHaveFixtures(200, unittest.IdentifierListFixture(50).Strings()...)...)) - require.Less(t, len(iHavesLessThanMaxSampleSize.GetControl().GetIhave()), inspector.config.IHaveRPCInspectionConfig.MaxSampleSize) - inspector.truncateIHaveMessages(iHavesGreaterThanMaxSampleSize) - inspector.truncateIHaveMessages(iHavesLessThanMaxSampleSize) - // rpc with iHaves greater than configured max sample size should be truncated to MaxSampleSize - require.Len(t, iHavesGreaterThanMaxSampleSize.GetControl().GetIhave(), inspector.config.IHaveRPCInspectionConfig.MaxSampleSize) - // rpc with iHaves less than MaxSampleSize should not be truncated - require.Len(t, iHavesLessThanMaxSampleSize.GetControl().GetIhave(), 50) + require.Less(t, len(iHavesLessThanMaxSampleSize.GetControl().GetIhave()), suite.config.IHaveRPCInspectionConfig.MaxSampleSize) + + from := unittest.PeerIdFixture(t) + require.NoError(t, suite.inspector.Inspect(from, iHavesGreaterThanMaxSampleSize)) + require.NoError(t, suite.inspector.Inspect(from, iHavesLessThanMaxSampleSize)) + require.Eventually(t, func() bool { + // rpc with iHaves greater than configured max sample size should be truncated to MaxSampleSize + shouldBeTruncated := len(iHavesGreaterThanMaxSampleSize.GetControl().GetIhave()) == suite.config.IHaveRPCInspectionConfig.MaxSampleSize + // rpc with iHaves less than MaxSampleSize should not be truncated + shouldNotBeTruncated := len(iHavesLessThanMaxSampleSize.GetControl().GetIhave()) == 50 + return shouldBeTruncated && shouldNotBeTruncated + }, time.Second, 500*time.Millisecond) }) - t.Run("truncateIHaveMessageIds should truncate iHave message ids as expected", func(t *testing.T) { - inspector, _, _, _, _ := inspectorFixture(t) - inspector.config.IHaveRPCInspectionConfig.MaxMessageIDSampleSize = 100 + suite.T().Run("truncateIHaveMessageIds should truncate iHave message ids as expected", func(t *testing.T) { + suite.SetupTest() + defer suite.StopInspector() + suite.rpcTracker.On("LastHighestIHaveRPCSize").Return(int64(100)).Maybe() + suite.rpcTracker.On("WasIHaveRPCSent", mock.AnythingOfType("string")).Return(true).Maybe() + suite.distributor.On("Distribute", mock.AnythingOfType("*p2p.InvCtrlMsgNotif")).Return(nil).Twice() + suite.config.IHaveRPCInspectionConfig.MaxMessageIDSampleSize = 100 + suite.inspector.Start(suite.signalerCtx) + // topic validation not performed so we can use random strings - iHavesGreaterThanMaxSampleSize := unittest.P2PRPCFixture(unittest.WithIHaves(unittest.P2PRPCIHaveFixtures(200, unittest.IdentifierListFixture(10).Strings()...)...)) + iHavesGreaterThanMaxSampleSize := unittest.P2PRPCFixture(unittest.WithIHaves(unittest.P2PRPCIHaveFixtures(200, + unittest.IdentifierListFixture(10).Strings()...)...)) iHavesLessThanMaxSampleSize := unittest.P2PRPCFixture(unittest.WithIHaves(unittest.P2PRPCIHaveFixtures(50, unittest.IdentifierListFixture(10).Strings()...)...)) - inspector.truncateIHaveMessageIds(iHavesGreaterThanMaxSampleSize) - 
inspector.truncateIHaveMessageIds(iHavesLessThanMaxSampleSize) - for _, iHave := range iHavesGreaterThanMaxSampleSize.GetControl().GetIhave() { - // rpc with iHaves message ids greater than configured max sample size should be truncated to MaxSampleSize - require.Len(t, iHave.GetMessageIDs(), inspector.config.IHaveRPCInspectionConfig.MaxMessageIDSampleSize) - } - for _, iHave := range iHavesLessThanMaxSampleSize.GetControl().GetIhave() { - // rpc with iHaves message ids less than MaxSampleSize should not be truncated - require.Len(t, iHave.GetMessageIDs(), 50) - } + + from := unittest.PeerIdFixture(t) + require.NoError(t, suite.inspector.Inspect(from, iHavesGreaterThanMaxSampleSize)) + require.NoError(t, suite.inspector.Inspect(from, iHavesLessThanMaxSampleSize)) + require.Eventually(t, func() bool { + for _, iHave := range iHavesGreaterThanMaxSampleSize.GetControl().GetIhave() { + // rpc with iHaves message ids greater than configured max sample size should be truncated to MaxSampleSize + if len(iHave.GetMessageIDs()) != suite.config.IHaveRPCInspectionConfig.MaxMessageIDSampleSize { + return false + } + } + for _, iHave := range iHavesLessThanMaxSampleSize.GetControl().GetIhave() { + // rpc with iHaves message ids less than MaxSampleSize should not be truncated + if len(iHave.GetMessageIDs()) != 50 { + return false + } + } + return true + }, time.Second, 500*time.Millisecond) }) - t.Run("truncateIWantMessages should truncate iWant messages as expected", func(t *testing.T) { - inspector, _, rpcTracker, _, _ := inspectorFixture(t) - inspector.config.IWantRPCInspectionConfig.MaxSampleSize = 100 + suite.T().Run("truncateIWantMessages should truncate iWant messages as expected", func(t *testing.T) { + suite.SetupTest() + defer suite.StopInspector() + suite.rpcTracker.On("LastHighestIHaveRPCSize").Return(int64(100)).Maybe() + suite.rpcTracker.On("WasIHaveRPCSent", mock.AnythingOfType("string")).Return(true).Maybe() + suite.config.IWantRPCInspectionConfig.MaxSampleSize = 100 + suite.inspector.Start(suite.signalerCtx) + iWantsGreaterThanMaxSampleSize := unittest.P2PRPCFixture(unittest.WithIWants(unittest.P2PRPCIWantFixtures(200, 200)...)) - require.Greater(t, uint(len(iWantsGreaterThanMaxSampleSize.GetControl().GetIwant())), inspector.config.IWantRPCInspectionConfig.MaxSampleSize) + require.Greater(t, uint(len(iWantsGreaterThanMaxSampleSize.GetControl().GetIwant())), suite.config.IWantRPCInspectionConfig.MaxSampleSize) iWantsLessThanMaxSampleSize := unittest.P2PRPCFixture(unittest.WithIWants(unittest.P2PRPCIWantFixtures(50, 200)...)) - require.Less(t, uint(len(iWantsLessThanMaxSampleSize.GetControl().GetIwant())), inspector.config.IWantRPCInspectionConfig.MaxSampleSize) - peerID := peer.ID("peer") - rpcTracker.On("LastHighestIHaveRPCSize").Return(int64(100)).Twice() - inspector.truncateIWantMessages(peerID, iWantsGreaterThanMaxSampleSize) - inspector.truncateIWantMessages(peerID, iWantsLessThanMaxSampleSize) - // rpc with iWants greater than configured max sample size should be truncated to MaxSampleSize - require.Len(t, iWantsGreaterThanMaxSampleSize.GetControl().GetIwant(), int(inspector.config.IWantRPCInspectionConfig.MaxSampleSize)) - // rpc with iWants less than MaxSampleSize should not be truncated - require.Len(t, iWantsLessThanMaxSampleSize.GetControl().GetIwant(), 50) + require.Less(t, uint(len(iWantsLessThanMaxSampleSize.GetControl().GetIwant())), suite.config.IWantRPCInspectionConfig.MaxSampleSize) + + from := unittest.PeerIdFixture(t) + require.NoError(t, suite.inspector.Inspect(from, 
iWantsGreaterThanMaxSampleSize)) + require.NoError(t, suite.inspector.Inspect(from, iWantsLessThanMaxSampleSize)) + require.Eventually(t, func() bool { + // rpc with iWants greater than configured max sample size should be truncated to MaxSampleSize + shouldBeTruncated := len(iWantsGreaterThanMaxSampleSize.GetControl().GetIwant()) == int(suite.config.IWantRPCInspectionConfig.MaxSampleSize) + // rpc with iWants less than MaxSampleSize should not be truncated + shouldNotBeTruncated := len(iWantsLessThanMaxSampleSize.GetControl().GetIwant()) == 50 + return shouldBeTruncated && shouldNotBeTruncated + }, time.Second, 500*time.Millisecond) }) - t.Run("truncateIWantMessageIds should truncate iWant message ids as expected", func(t *testing.T) { - inspector, _, rpcTracker, _, _ := inspectorFixture(t) - inspector.config.IWantRPCInspectionConfig.MaxMessageIDSampleSize = 100 + suite.T().Run("truncateIWantMessageIds should truncate iWant message ids as expected", func(t *testing.T) { + suite.SetupTest() + defer suite.StopInspector() + suite.rpcTracker.On("LastHighestIHaveRPCSize").Return(int64(100)).Maybe() + suite.rpcTracker.On("WasIHaveRPCSent", mock.AnythingOfType("string")).Return(true).Maybe() + suite.distributor.On("Distribute", mock.AnythingOfType("*p2p.InvCtrlMsgNotif")).Return(nil).Maybe() + suite.config.IWantRPCInspectionConfig.MaxMessageIDSampleSize = 100 + suite.inspector.Start(suite.signalerCtx) + iWantsGreaterThanMaxSampleSize := unittest.P2PRPCFixture(unittest.WithIWants(unittest.P2PRPCIWantFixtures(10, 200)...)) iWantsLessThanMaxSampleSize := unittest.P2PRPCFixture(unittest.WithIWants(unittest.P2PRPCIWantFixtures(10, 50)...)) - peerID := peer.ID("peer") - rpcTracker.On("LastHighestIHaveRPCSize").Return(int64(100)).Twice() - inspector.truncateIWantMessages(peerID, iWantsGreaterThanMaxSampleSize) - inspector.truncateIWantMessages(peerID, iWantsLessThanMaxSampleSize) - for _, iWant := range iWantsGreaterThanMaxSampleSize.GetControl().GetIwant() { - // rpc with iWants message ids greater than configured max sample size should be truncated to MaxSampleSize - require.Len(t, iWant.GetMessageIDs(), inspector.config.IWantRPCInspectionConfig.MaxMessageIDSampleSize) - } - for _, iWant := range iWantsLessThanMaxSampleSize.GetControl().GetIwant() { - // rpc with iWants less than MaxSampleSize should not be truncated - require.Len(t, iWant.GetMessageIDs(), 50) - } + + from := unittest.PeerIdFixture(t) + require.NoError(t, suite.inspector.Inspect(from, iWantsGreaterThanMaxSampleSize)) + require.NoError(t, suite.inspector.Inspect(from, iWantsLessThanMaxSampleSize)) + require.Eventually(t, func() bool { + for _, iWant := range iWantsGreaterThanMaxSampleSize.GetControl().GetIwant() { + // rpc with iWants message ids greater than configured max sample size should be truncated to MaxSampleSize + if len(iWant.GetMessageIDs()) != suite.config.IWantRPCInspectionConfig.MaxMessageIDSampleSize { + return false + } + } + for _, iWant := range iWantsLessThanMaxSampleSize.GetControl().GetIwant() { + // rpc with iWants less than MaxSampleSize should not be truncated + if len(iWant.GetMessageIDs()) != 50 { + return false + } + } + return true + }, time.Second, 500*time.Millisecond) }) } // TestControlMessageValidationInspector_processInspectRPCReq verifies the correct behavior of control message validation. // It ensures that valid RPC control messages do not trigger erroneous invalid control message notifications, // while all types of invalid control messages trigger expected notifications. 
-func TestControlMessageValidationInspector_processInspectRPCReq(t *testing.T) { - t.Run("processInspectRPCReq should not disseminate any invalid notification errors for valid RPC's", func(t *testing.T) { - inspector, distributor, rpcTracker, _, sporkID := inspectorFixture(t) - defer distributor.AssertNotCalled(t, "Distribute") +func (suite *ControlMsgValidationInspectorSuite) TestControlMessageValidationInspector_processInspectRPCReq() { + suite.T().Run("processInspectRPCReq should not disseminate any invalid notification errors for valid RPC's", func(t *testing.T) { + suite.SetupTest() + defer suite.StopInspector() + defer suite.distributor.AssertNotCalled(t, "Distribute") + topics := []string{ - fmt.Sprintf("%s/%s", channels.TestNetworkChannel, sporkID), - fmt.Sprintf("%s/%s", channels.PushBlocks, sporkID), - fmt.Sprintf("%s/%s", channels.SyncCommittee, sporkID), - fmt.Sprintf("%s/%s", channels.RequestChunks, sporkID), + fmt.Sprintf("%s/%s", channels.TestNetworkChannel, suite.sporkID), + fmt.Sprintf("%s/%s", channels.PushBlocks, suite.sporkID), + fmt.Sprintf("%s/%s", channels.SyncCommittee, suite.sporkID), + fmt.Sprintf("%s/%s", channels.RequestChunks, suite.sporkID), } - // set topic oracle to return list of topics excluding first topic sent - require.NoError(t, inspector.SetTopicOracle(func() []string { - return topics - })) + suite.topicProviderOracle.UpdateTopics(topics) + suite.inspector.Start(suite.signalerCtx) grafts := unittest.P2PRPCGraftFixtures(topics...) prunes := unittest.P2PRPCPruneFixtures(topics...) ihaves := unittest.P2PRPCIHaveFixtures(50, topics...) iwants := unittest.P2PRPCIWantFixtures(2, 5) - pubsubMsgs := unittest.GossipSubMessageFixtures(t, 10, topics[0]) + pubsubMsgs := unittest.GossipSubMessageFixtures(10, topics[0]) // avoid cache misses for iwant messages. iwants[0].MessageIDs = ihaves[0].MessageIDs[:10] iwants[1].MessageIDs = ihaves[1].MessageIDs[11:20] expectedMsgIds := make([]string, 0) - expectedMsgIds = append(expectedMsgIds, ihaves[0].MessageIDs[:10]...) - expectedMsgIds = append(expectedMsgIds, ihaves[1].MessageIDs[11:20]...) - expectedPeerID := unittest.PeerIdFixture(t) - req, err := NewInspectRPCRequest(expectedPeerID, unittest.P2PRPCFixture( + expectedMsgIds = append(expectedMsgIds, ihaves[0].MessageIDs...) + expectedMsgIds = append(expectedMsgIds, ihaves[1].MessageIDs...) 
+ rpc := unittest.P2PRPCFixture( unittest.WithGrafts(grafts...), unittest.WithPrunes(prunes...), unittest.WithIHaves(ihaves...), unittest.WithIWants(iwants...), - unittest.WithPubsubMessages(pubsubMsgs...)), - ) - require.NoError(t, err, "failed to get inspect message request") - rpcTracker.On("LastHighestIHaveRPCSize").Return(int64(100)).Maybe() - rpcTracker.On("WasIHaveRPCSent", mock.AnythingOfType("string")).Return(true).Run(func(args mock.Arguments) { + unittest.WithPubsubMessages(pubsubMsgs...)) + suite.rpcTracker.On("LastHighestIHaveRPCSize").Return(int64(100)).Maybe() + suite.rpcTracker.On("WasIHaveRPCSent", mock.AnythingOfType("string")).Return(true).Run(func(args mock.Arguments) { id, ok := args[0].(string) require.True(t, ok) require.Contains(t, expectedMsgIds, id) }) - require.NoError(t, inspector.processInspectRPCReq(req)) + + from := unittest.PeerIdFixture(t) + require.NoError(t, suite.inspector.Inspect(from, rpc)) + // sleep for 1 second to ensure rpc is processed + time.Sleep(time.Second) }) - t.Run("processInspectRPCReq should disseminate invalid control message notification for control messages with duplicate topics", func(t *testing.T) { - inspector, distributor, _, _, sporkID := inspectorFixture(t) - defer distributor.AssertNotCalled(t, "Distribute") - duplicateTopic := fmt.Sprintf("%s/%s", channels.TestNetworkChannel, sporkID) + suite.T().Run("processInspectRPCReq should disseminate invalid control message notification for control messages with duplicate topics", func(t *testing.T) { + suite.SetupTest() + defer suite.StopInspector() + duplicateTopic := fmt.Sprintf("%s/%s", channels.TestNetworkChannel, suite.sporkID) + // avoid unknown topics errors + suite.topicProviderOracle.UpdateTopics([]string{duplicateTopic}) // create control messages with duplicate topic grafts := []*pubsub_pb.ControlGraft{unittest.P2PRPCGraftFixture(&duplicateTopic), unittest.P2PRPCGraftFixture(&duplicateTopic)} prunes := []*pubsub_pb.ControlPrune{unittest.P2PRPCPruneFixture(&duplicateTopic), unittest.P2PRPCPruneFixture(&duplicateTopic)} - ihaves := []*pubsub_pb.ControlIHave{unittest.P2PRPCIHaveFixture(&duplicateTopic, unittest.IdentifierListFixture(20).Strings()...), unittest.P2PRPCIHaveFixture(&duplicateTopic, unittest.IdentifierListFixture(20).Strings()...)} - expectedPeerID := unittest.PeerIdFixture(t) - duplicateTopicGraftsReq, err := NewInspectRPCRequest(expectedPeerID, unittest.P2PRPCFixture(unittest.WithGrafts(grafts...))) - require.NoError(t, err, "failed to get inspect message request") - duplicateTopicPrunesReq, err := NewInspectRPCRequest(expectedPeerID, unittest.P2PRPCFixture(unittest.WithPrunes(prunes...))) - require.NoError(t, err, "failed to get inspect message request") - duplicateTopicIHavesReq, err := NewInspectRPCRequest(expectedPeerID, unittest.P2PRPCFixture(unittest.WithIHaves(ihaves...))) - require.NoError(t, err, "failed to get inspect message request") - distributor.On("Distribute", mock.AnythingOfType("*p2p.InvCtrlMsgNotif")).Return(nil).Times(3).Run(func(args mock.Arguments) { + ihaves := []*pubsub_pb.ControlIHave{unittest.P2PRPCIHaveFixture(&duplicateTopic, unittest.IdentifierListFixture(20).Strings()...), + unittest.P2PRPCIHaveFixture(&duplicateTopic, unittest.IdentifierListFixture(20).Strings()...)} + from := unittest.PeerIdFixture(t) + duplicateTopicGraftsRpc := unittest.P2PRPCFixture(unittest.WithGrafts(grafts...)) + duplicateTopicPrunesRpc := unittest.P2PRPCFixture(unittest.WithPrunes(prunes...)) + duplicateTopicIHavesRpc := 
unittest.P2PRPCFixture(unittest.WithIHaves(ihaves...)) + suite.distributor.On("Distribute", mock.AnythingOfType("*p2p.InvCtrlMsgNotif")).Return(nil).Times(3).Run(func(args mock.Arguments) { notification, ok := args[0].(*p2p.InvCtrlMsgNotif) require.True(t, ok) - require.Equal(t, expectedPeerID, notification.PeerID) + require.Equal(t, from, notification.PeerID) require.Contains(t, []p2pmsg.ControlMessageType{p2pmsg.CtrlMsgGraft, p2pmsg.CtrlMsgPrune, p2pmsg.CtrlMsgIHave}, notification.MsgType) - require.True(t, IsDuplicateTopicErr(notification.Error)) + require.True(t, validation.IsDuplicateTopicErr(notification.Error)) }) - require.NoError(t, inspector.processInspectRPCReq(duplicateTopicGraftsReq)) - require.NoError(t, inspector.processInspectRPCReq(duplicateTopicPrunesReq)) - require.NoError(t, inspector.processInspectRPCReq(duplicateTopicIHavesReq)) + suite.inspector.Start(suite.signalerCtx) + + require.NoError(t, suite.inspector.Inspect(from, duplicateTopicGraftsRpc)) + require.NoError(t, suite.inspector.Inspect(from, duplicateTopicPrunesRpc)) + require.NoError(t, suite.inspector.Inspect(from, duplicateTopicIHavesRpc)) + // sleep for 1 second to ensure rpc's is processed + time.Sleep(time.Second) }) - t.Run("inspectGraftMessages should disseminate invalid control message notification for invalid graft messages as expected", func(t *testing.T) { - inspector, distributor, _, _, sporkID := inspectorFixture(t) + suite.T().Run("inspectGraftMessages should disseminate invalid control message notification for invalid graft messages as expected", func(t *testing.T) { + suite.SetupTest() + defer suite.StopInspector() // create unknown topic - unknownTopic, malformedTopic, invalidSporkIDTopic := invalidTopics(t, sporkID) + unknownTopic, malformedTopic, invalidSporkIDTopic := invalidTopics(t, suite.sporkID) + // avoid unknown topics errors + suite.topicProviderOracle.UpdateTopics([]string{unknownTopic, malformedTopic, invalidSporkIDTopic}) unknownTopicGraft := unittest.P2PRPCGraftFixture(&unknownTopic) malformedTopicGraft := unittest.P2PRPCGraftFixture(&malformedTopic) invalidSporkIDTopicGraft := unittest.P2PRPCGraftFixture(&invalidSporkIDTopic) - expectedPeerID := unittest.PeerIdFixture(t) - unknownTopicReq, err := NewInspectRPCRequest(expectedPeerID, unittest.P2PRPCFixture(unittest.WithGrafts(unknownTopicGraft))) - require.NoError(t, err, "failed to get inspect message request") - malformedTopicReq, err := NewInspectRPCRequest(expectedPeerID, unittest.P2PRPCFixture(unittest.WithGrafts(malformedTopicGraft))) - require.NoError(t, err, "failed to get inspect message request") - invalidSporkIDTopicReq, err := NewInspectRPCRequest(expectedPeerID, unittest.P2PRPCFixture(unittest.WithGrafts(invalidSporkIDTopicGraft))) - require.NoError(t, err, "failed to get inspect message request") + unknownTopicReq := unittest.P2PRPCFixture(unittest.WithGrafts(unknownTopicGraft)) + malformedTopicReq := unittest.P2PRPCFixture(unittest.WithGrafts(malformedTopicGraft)) + invalidSporkIDTopicReq := unittest.P2PRPCFixture(unittest.WithGrafts(invalidSporkIDTopicGraft)) + + from := unittest.PeerIdFixture(t) + checkNotification := checkNotificationFunc(t, from, p2pmsg.CtrlMsgGraft, channels.IsInvalidTopicErr) + suite.distributor.On("Distribute", mock.AnythingOfType("*p2p.InvCtrlMsgNotif")).Return(nil).Times(3).Run(checkNotification) - checkNotification := checkNotificationFunc(t, expectedPeerID, p2pmsg.CtrlMsgGraft, channels.IsInvalidTopicErr) - distributor.On("Distribute", 
mock.AnythingOfType("*p2p.InvCtrlMsgNotif")).Return(nil).Times(3).Run(checkNotification) + suite.inspector.Start(suite.signalerCtx) - require.NoError(t, inspector.processInspectRPCReq(unknownTopicReq)) - require.NoError(t, inspector.processInspectRPCReq(malformedTopicReq)) - require.NoError(t, inspector.processInspectRPCReq(invalidSporkIDTopicReq)) + require.NoError(t, suite.inspector.Inspect(from, unknownTopicReq)) + require.NoError(t, suite.inspector.Inspect(from, malformedTopicReq)) + require.NoError(t, suite.inspector.Inspect(from, invalidSporkIDTopicReq)) + // sleep for 1 second to ensure rpc's is processed + time.Sleep(time.Second) }) - t.Run("inspectPruneMessages should disseminate invalid control message notification for invalid prune messages as expected", func(t *testing.T) { - inspector, distributor, _, _, sporkID := inspectorFixture(t) + suite.T().Run("inspectPruneMessages should disseminate invalid control message notification for invalid prune messages as expected", func(t *testing.T) { + suite.SetupTest() + defer suite.StopInspector() // create unknown topic - unknownTopic, malformedTopic, invalidSporkIDTopic := invalidTopics(t, sporkID) + unknownTopic, malformedTopic, invalidSporkIDTopic := invalidTopics(t, suite.sporkID) unknownTopicPrune := unittest.P2PRPCPruneFixture(&unknownTopic) malformedTopicPrune := unittest.P2PRPCPruneFixture(&malformedTopic) invalidSporkIDTopicPrune := unittest.P2PRPCPruneFixture(&invalidSporkIDTopic) - - expectedPeerID := unittest.PeerIdFixture(t) - unknownTopicReq, err := NewInspectRPCRequest(expectedPeerID, unittest.P2PRPCFixture(unittest.WithPrunes(unknownTopicPrune))) - require.NoError(t, err, "failed to get inspect message request") - malformedTopicReq, err := NewInspectRPCRequest(expectedPeerID, unittest.P2PRPCFixture(unittest.WithPrunes(malformedTopicPrune))) - require.NoError(t, err, "failed to get inspect message request") - invalidSporkIDTopicReq, err := NewInspectRPCRequest(expectedPeerID, unittest.P2PRPCFixture(unittest.WithPrunes(invalidSporkIDTopicPrune))) - require.NoError(t, err, "failed to get inspect message request") - - checkNotification := checkNotificationFunc(t, expectedPeerID, p2pmsg.CtrlMsgPrune, channels.IsInvalidTopicErr) - distributor.On("Distribute", mock.AnythingOfType("*p2p.InvCtrlMsgNotif")).Return(nil).Times(3).Run(checkNotification) - - require.NoError(t, inspector.processInspectRPCReq(unknownTopicReq)) - require.NoError(t, inspector.processInspectRPCReq(malformedTopicReq)) - require.NoError(t, inspector.processInspectRPCReq(invalidSporkIDTopicReq)) + // avoid unknown topics errors + suite.topicProviderOracle.UpdateTopics([]string{unknownTopic, malformedTopic, invalidSporkIDTopic}) + unknownTopicRpc := unittest.P2PRPCFixture(unittest.WithPrunes(unknownTopicPrune)) + malformedTopicRpc := unittest.P2PRPCFixture(unittest.WithPrunes(malformedTopicPrune)) + invalidSporkIDTopicRpc := unittest.P2PRPCFixture(unittest.WithPrunes(invalidSporkIDTopicPrune)) + + from := unittest.PeerIdFixture(t) + checkNotification := checkNotificationFunc(t, from, p2pmsg.CtrlMsgPrune, channels.IsInvalidTopicErr) + suite.distributor.On("Distribute", mock.AnythingOfType("*p2p.InvCtrlMsgNotif")).Return(nil).Times(3).Run(checkNotification) + + suite.inspector.Start(suite.signalerCtx) + + require.NoError(t, suite.inspector.Inspect(from, unknownTopicRpc)) + require.NoError(t, suite.inspector.Inspect(from, malformedTopicRpc)) + require.NoError(t, suite.inspector.Inspect(from, invalidSporkIDTopicRpc)) + // sleep for 1 second to ensure rpc's is 
processed + time.Sleep(time.Second) }) - t.Run("inspectIHaveMessages should disseminate invalid control message notification for iHave messages with invalid topics as expected", func(t *testing.T) { - inspector, distributor, _, _, sporkID := inspectorFixture(t) + suite.T().Run("inspectIHaveMessages should disseminate invalid control message notification for iHave messages with invalid topics as expected", func(t *testing.T) { + suite.SetupTest() + defer suite.StopInspector() // create unknown topic - unknownTopic, malformedTopic, invalidSporkIDTopic := invalidTopics(t, sporkID) + unknownTopic, malformedTopic, invalidSporkIDTopic := invalidTopics(t, suite.sporkID) + // avoid unknown topics errors + suite.topicProviderOracle.UpdateTopics([]string{unknownTopic, malformedTopic, invalidSporkIDTopic}) unknownTopicIhave := unittest.P2PRPCIHaveFixture(&unknownTopic, unittest.IdentifierListFixture(5).Strings()...) malformedTopicIhave := unittest.P2PRPCIHaveFixture(&malformedTopic, unittest.IdentifierListFixture(5).Strings()...) invalidSporkIDTopicIhave := unittest.P2PRPCIHaveFixture(&invalidSporkIDTopic, unittest.IdentifierListFixture(5).Strings()...) - expectedPeerID := unittest.PeerIdFixture(t) - unknownTopicReq, err := NewInspectRPCRequest(expectedPeerID, unittest.P2PRPCFixture(unittest.WithIHaves(unknownTopicIhave))) - require.NoError(t, err, "failed to get inspect message request") - malformedTopicReq, err := NewInspectRPCRequest(expectedPeerID, unittest.P2PRPCFixture(unittest.WithIHaves(malformedTopicIhave))) - require.NoError(t, err, "failed to get inspect message request") - invalidSporkIDTopicReq, err := NewInspectRPCRequest(expectedPeerID, unittest.P2PRPCFixture(unittest.WithIHaves(invalidSporkIDTopicIhave))) - require.NoError(t, err, "failed to get inspect message request") - - checkNotification := checkNotificationFunc(t, expectedPeerID, p2pmsg.CtrlMsgIHave, channels.IsInvalidTopicErr) - distributor.On("Distribute", mock.AnythingOfType("*p2p.InvCtrlMsgNotif")).Return(nil).Times(3).Run(checkNotification) - - require.NoError(t, inspector.processInspectRPCReq(unknownTopicReq)) - require.NoError(t, inspector.processInspectRPCReq(malformedTopicReq)) - require.NoError(t, inspector.processInspectRPCReq(invalidSporkIDTopicReq)) - }) + unknownTopicRpc := unittest.P2PRPCFixture(unittest.WithIHaves(unknownTopicIhave)) + malformedTopicRpc := unittest.P2PRPCFixture(unittest.WithIHaves(malformedTopicIhave)) + invalidSporkIDTopicRpc := unittest.P2PRPCFixture(unittest.WithIHaves(invalidSporkIDTopicIhave)) - t.Run("inspectIHaveMessages should disseminate invalid control message notification for iHave messages with duplicate message ids as expected", func(t *testing.T) { - inspector, distributor, _, _, sporkID := inspectorFixture(t) - validTopic := fmt.Sprintf("%s/%s", channels.PushBlocks.String(), sporkID) - duplicateMsgID := unittest.IdentifierFixture() - msgIds := flow.IdentifierList{duplicateMsgID, duplicateMsgID, duplicateMsgID} - duplicateMsgIDIHave := unittest.P2PRPCIHaveFixture(&validTopic, append(msgIds, unittest.IdentifierListFixture(5)...).Strings()...) 
+ from := unittest.PeerIdFixture(t) + checkNotification := checkNotificationFunc(t, from, p2pmsg.CtrlMsgIHave, channels.IsInvalidTopicErr) + suite.distributor.On("Distribute", mock.AnythingOfType("*p2p.InvCtrlMsgNotif")).Return(nil).Times(3).Run(checkNotification) + suite.inspector.Start(suite.signalerCtx) - expectedPeerID := unittest.PeerIdFixture(t) - duplicateMsgIDReq, err := NewInspectRPCRequest(expectedPeerID, unittest.P2PRPCFixture(unittest.WithIHaves(duplicateMsgIDIHave))) - require.NoError(t, err, "failed to get inspect message request") - - checkNotification := checkNotificationFunc(t, expectedPeerID, p2pmsg.CtrlMsgIHave, IsDuplicateTopicErr) - distributor.On("Distribute", mock.AnythingOfType("*p2p.InvCtrlMsgNotif")).Return(nil).Once().Run(checkNotification) - require.NoError(t, inspector.processInspectRPCReq(duplicateMsgIDReq)) + require.NoError(t, suite.inspector.Inspect(from, unknownTopicRpc)) + require.NoError(t, suite.inspector.Inspect(from, malformedTopicRpc)) + require.NoError(t, suite.inspector.Inspect(from, invalidSporkIDTopicRpc)) + // sleep for 1 second to ensure rpc's is processed + time.Sleep(time.Second) }) - t.Run("inspectIWantMessages should disseminate invalid control message notification for iWant messages when duplicate message ids exceeds the allowed threshold", func(t *testing.T) { - inspector, distributor, rpcTracker, _, _ := inspectorFixture(t) - duplicateMsgID := unittest.IdentifierFixture() - duplicates := flow.IdentifierList{duplicateMsgID, duplicateMsgID} - msgIds := append(duplicates, unittest.IdentifierListFixture(5)...).Strings() - duplicateMsgIDIWant := unittest.P2PRPCIWantFixture(msgIds...) - - expectedPeerID := unittest.PeerIdFixture(t) - duplicateMsgIDReq, err := NewInspectRPCRequest(expectedPeerID, unittest.P2PRPCFixture(unittest.WithIWants(duplicateMsgIDIWant))) - require.NoError(t, err, "failed to get inspect message request") + suite.T().Run("inspectIHaveMessages should disseminate invalid control message notification for iHave messages with duplicate message ids as expected", + func(t *testing.T) { + suite.SetupTest() + defer suite.StopInspector() + validTopic := fmt.Sprintf("%s/%s", channels.PushBlocks.String(), suite.sporkID) + // avoid unknown topics errors + suite.topicProviderOracle.UpdateTopics([]string{validTopic}) + duplicateMsgID := unittest.IdentifierFixture() + msgIds := flow.IdentifierList{duplicateMsgID, duplicateMsgID, duplicateMsgID} + duplicateMsgIDIHave := unittest.P2PRPCIHaveFixture(&validTopic, append(msgIds, unittest.IdentifierListFixture(5)...).Strings()...) 
+ + duplicateMsgIDRpc := unittest.P2PRPCFixture(unittest.WithIHaves(duplicateMsgIDIHave)) + + from := unittest.PeerIdFixture(t) + checkNotification := checkNotificationFunc(t, from, p2pmsg.CtrlMsgIHave, validation.IsDuplicateTopicErr) + suite.distributor.On("Distribute", mock.AnythingOfType("*p2p.InvCtrlMsgNotif")).Return(nil).Once().Run(checkNotification) + suite.inspector.Start(suite.signalerCtx) + + require.NoError(t, suite.inspector.Inspect(from, duplicateMsgIDRpc)) + // sleep for 1 second to ensure rpc's is processed + time.Sleep(time.Second) + }) - checkNotification := checkNotificationFunc(t, expectedPeerID, p2pmsg.CtrlMsgIWant, IsIWantDuplicateMsgIDThresholdErr) - distributor.On("Distribute", mock.AnythingOfType("*p2p.InvCtrlMsgNotif")).Return(nil).Once().Run(checkNotification) - rpcTracker.On("LastHighestIHaveRPCSize").Return(int64(100)).Maybe() - rpcTracker.On("WasIHaveRPCSent", mock.AnythingOfType("string")).Return(true).Run(func(args mock.Arguments) { - id, ok := args[0].(string) - require.True(t, ok) - require.Contains(t, msgIds, id) + suite.T().Run("inspectIWantMessages should disseminate invalid control message notification for iWant messages when duplicate message ids exceeds the allowed threshold", + func(t *testing.T) { + suite.SetupTest() + defer suite.StopInspector() + duplicateMsgID := unittest.IdentifierFixture() + duplicates := flow.IdentifierList{duplicateMsgID, duplicateMsgID} + msgIds := append(duplicates, unittest.IdentifierListFixture(5)...).Strings() + duplicateMsgIDIWant := unittest.P2PRPCIWantFixture(msgIds...) + + duplicateMsgIDRpc := unittest.P2PRPCFixture(unittest.WithIWants(duplicateMsgIDIWant)) + + from := unittest.PeerIdFixture(t) + checkNotification := checkNotificationFunc(t, from, p2pmsg.CtrlMsgIWant, validation.IsIWantDuplicateMsgIDThresholdErr) + suite.distributor.On("Distribute", mock.AnythingOfType("*p2p.InvCtrlMsgNotif")).Return(nil).Once().Run(checkNotification) + suite.rpcTracker.On("LastHighestIHaveRPCSize").Return(int64(100)).Maybe() + suite.rpcTracker.On("WasIHaveRPCSent", mock.AnythingOfType("string")).Return(true).Run(func(args mock.Arguments) { + id, ok := args[0].(string) + require.True(t, ok) + require.Contains(t, msgIds, id) + }) + + suite.inspector.Start(suite.signalerCtx) + + require.NoError(t, suite.inspector.Inspect(from, duplicateMsgIDRpc)) + // sleep for 1 second to ensure rpc's is processed + time.Sleep(time.Second) }) - require.NoError(t, inspector.processInspectRPCReq(duplicateMsgIDReq)) - }) - t.Run("inspectIWantMessages should disseminate invalid control message notification for iWant messages when cache misses exceeds allowed threshold", func(t *testing.T) { - inspector, distributor, rpcTracker, _, _ := inspectorFixture(t) - // set cache miss check size to 0 forcing the inspector to check the cache misses with only a single iWant - inspector.config.CacheMissCheckSize = 0 - // set high cache miss threshold to ensure we only disseminate notification when it is exceeded - inspector.config.IWantRPCInspectionConfig.CacheMissThreshold = .9 - msgIds := unittest.IdentifierListFixture(100).Strings() - expectedPeerID := unittest.PeerIdFixture(t) - inspectMsgReq, err := NewInspectRPCRequest(expectedPeerID, unittest.P2PRPCFixture(unittest.WithIWants(unittest.P2PRPCIWantFixture(msgIds...)))) - require.NoError(t, err, "failed to get inspect message request") + suite.T().Run("inspectIWantMessages should disseminate invalid control message notification for iWant messages when cache misses exceeds allowed threshold", + func(t *testing.T) { 
+ suite.SetupTest() + defer suite.StopInspector() + // set cache miss check size to 0 forcing the inspector to check the cache misses with only a single iWant + suite.config.CacheMissCheckSize = 0 + // set high cache miss threshold to ensure we only disseminate notification when it is exceeded + suite.config.IWantRPCInspectionConfig.CacheMissThreshold = .9 + msgIds := unittest.IdentifierListFixture(100).Strings() + // oracle must be set even though iWant messages do not have topic IDs + inspectMsgRpc := unittest.P2PRPCFixture(unittest.WithIWants(unittest.P2PRPCIWantFixture(msgIds...))) + + from := unittest.PeerIdFixture(t) + checkNotification := checkNotificationFunc(t, from, p2pmsg.CtrlMsgIWant, validation.IsIWantCacheMissThresholdErr) + suite.distributor.On("Distribute", mock.AnythingOfType("*p2p.InvCtrlMsgNotif")).Return(nil).Once().Run(checkNotification) + suite.rpcTracker.On("LastHighestIHaveRPCSize").Return(int64(100)).Maybe() + // return false each time to eventually force a notification to be disseminated when the cache miss count finally exceeds the 90% threshold + suite.rpcTracker.On("WasIHaveRPCSent", mock.AnythingOfType("string")).Return(false).Run(func(args mock.Arguments) { + id, ok := args[0].(string) + require.True(t, ok) + require.Contains(t, msgIds, id) + }) + + suite.inspector.Start(suite.signalerCtx) + + require.NoError(t, suite.inspector.Inspect(from, inspectMsgRpc)) + // sleep for 1 second to ensure rpc's is processed + time.Sleep(time.Second) + }) - checkNotification := checkNotificationFunc(t, expectedPeerID, p2pmsg.CtrlMsgIWant, IsIWantCacheMissThresholdErr) - distributor.On("Distribute", mock.AnythingOfType("*p2p.InvCtrlMsgNotif")).Return(nil).Once().Run(checkNotification) - rpcTracker.On("LastHighestIHaveRPCSize").Return(int64(100)).Maybe() - // return false each time to eventually force a notification to be disseminated when the cache miss count finally exceeds the 90% threshold - rpcTracker.On("WasIHaveRPCSent", mock.AnythingOfType("string")).Return(false).Run(func(args mock.Arguments) { - id, ok := args[0].(string) - require.True(t, ok) - require.Contains(t, msgIds, id) + suite.T().Run("inspectIWantMessages should not disseminate invalid control message notification for iWant messages when cache misses exceeds allowed threshold if cache miss check size not exceeded", + func(t *testing.T) { + suite.SetupTest() + defer suite.StopInspector() + defer suite.distributor.AssertNotCalled(t, "Distribute") + // if size of iwants not greater than 10 cache misses will not be checked + suite.config.CacheMissCheckSize = 10 + // set high cache miss threshold to ensure we only disseminate notification when it is exceeded + suite.config.IWantRPCInspectionConfig.CacheMissThreshold = .9 + msgIds := unittest.IdentifierListFixture(100).Strings() + inspectMsgRpc := unittest.P2PRPCFixture(unittest.WithIWants(unittest.P2PRPCIWantFixture(msgIds...))) + suite.rpcTracker.On("LastHighestIHaveRPCSize").Return(int64(100)).Maybe() + // return false each time to eventually force a notification to be disseminated when the cache miss count finally exceeds the 90% threshold + suite.rpcTracker.On("WasIHaveRPCSent", mock.AnythingOfType("string")).Return(false).Run(func(args mock.Arguments) { + id, ok := args[0].(string) + require.True(t, ok) + require.Contains(t, msgIds, id) + }) + + from := unittest.PeerIdFixture(t) + suite.inspector.Start(suite.signalerCtx) + + require.NoError(t, suite.inspector.Inspect(from, inspectMsgRpc)) + // sleep for 1 second to ensure rpc's is processed + 
time.Sleep(time.Second) }) - require.NoError(t, inspector.processInspectRPCReq(inspectMsgReq)) - }) - t.Run("inspectIWantMessages should not disseminate invalid control message notification for iWant messages when cache misses exceeds allowed threshold if cache miss check size not exceeded", func(t *testing.T) { - inspector, distributor, rpcTracker, _, _ := inspectorFixture(t) - defer distributor.AssertNotCalled(t, "Distribute") - // if size of iwants not greater than 10 cache misses will not be checked - inspector.config.CacheMissCheckSize = 10 - // set high cache miss threshold to ensure we only disseminate notification when it is exceeded - inspector.config.IWantRPCInspectionConfig.CacheMissThreshold = .9 - msgIds := unittest.IdentifierListFixture(100).Strings() - expectedPeerID := unittest.PeerIdFixture(t) - inspectMsgReq, err := NewInspectRPCRequest(expectedPeerID, unittest.P2PRPCFixture(unittest.WithIWants(unittest.P2PRPCIWantFixture(msgIds...)))) - require.NoError(t, err, "failed to get inspect message request") - rpcTracker.On("LastHighestIHaveRPCSize").Return(int64(100)).Maybe() - // return false each time to eventually force a notification to be disseminated when the cache miss count finally exceeds the 90% threshold - rpcTracker.On("WasIHaveRPCSent", mock.AnythingOfType("string")).Return(false).Run(func(args mock.Arguments) { - id, ok := args[0].(string) - require.True(t, ok) - require.Contains(t, msgIds, id) + suite.T().Run("inspectRpcPublishMessages should disseminate invalid control message notification when invalid pubsub messages count greater than configured RpcMessageErrorThreshold", + func(t *testing.T) { + suite.SetupTest() + defer suite.StopInspector() + // 5 invalid pubsub messages will force notification dissemination + suite.config.RpcMessageErrorThreshold = 4 + // create unknown topic + unknownTopic := channels.Topic(fmt.Sprintf("%s/%s", unittest.IdentifierFixture(), suite.sporkID)).String() + // create malformed topic + malformedTopic := channels.Topic("!@#$%^&**((").String() + // a topics spork ID is considered invalid if it does not match the current spork ID + invalidSporkIDTopic := channels.Topic(fmt.Sprintf("%s/%s", channels.PushBlocks, unittest.IdentifierFixture())).String() + + // create 10 normal messages + pubsubMsgs := unittest.GossipSubMessageFixtures(10, fmt.Sprintf("%s/%s", channels.TestNetworkChannel, suite.sporkID)) + // add 5 invalid messages to force notification dissemination + pubsubMsgs = append(pubsubMsgs, []*pubsub_pb.Message{ + {Topic: &unknownTopic}, + {Topic: &malformedTopic}, + {Topic: &malformedTopic}, + {Topic: &invalidSporkIDTopic}, + {Topic: &invalidSporkIDTopic}, + }...) 
+ rpc := unittest.P2PRPCFixture(unittest.WithPubsubMessages(pubsubMsgs...)) + topics := make([]string, len(pubsubMsgs)) + for i, msg := range pubsubMsgs { + topics[i] = *msg.Topic + } + // set topic oracle to return list of topics to avoid hasSubscription errors and force topic validation + suite.topicProviderOracle.UpdateTopics(topics) + from := unittest.PeerIdFixture(t) + checkNotification := checkNotificationFunc(t, from, p2pmsg.RpcPublishMessage, validation.IsInvalidRpcPublishMessagesErr) + suite.distributor.On("Distribute", mock.AnythingOfType("*p2p.InvCtrlMsgNotif")).Return(nil).Once().Run(checkNotification) + + suite.inspector.Start(suite.signalerCtx) + + require.NoError(t, suite.inspector.Inspect(from, rpc)) + // sleep for 1 second to ensure rpc's is processed + time.Sleep(time.Second) }) - require.NoError(t, inspector.processInspectRPCReq(inspectMsgReq)) - }) - t.Run("inspectRpcPublishMessages should disseminate invalid control message notification when invalid pubsub messages count greater than configured RpcMessageErrorThreshold", func(t *testing.T) { - inspector, distributor, _, _, sporkID := inspectorFixture(t) + suite.T().Run("inspectRpcPublishMessages should disseminate invalid control message notification when subscription missing for topic", func(t *testing.T) { + suite.SetupTest() + defer suite.StopInspector() // 5 invalid pubsub messages will force notification dissemination - inspector.config.RpcMessageErrorThreshold = 4 - // create unknown topic - unknownTopic := channels.Topic(fmt.Sprintf("%s/%s", unittest.IdentifierFixture(), sporkID)).String() - // create malformed topic - malformedTopic := channels.Topic("!@#$%^&**((").String() - // a topics spork ID is considered invalid if it does not match the current spork ID - invalidSporkIDTopic := channels.Topic(fmt.Sprintf("%s/%s", channels.PushBlocks, unittest.IdentifierFixture())).String() - - // create 10 normal messages - pubsubMsgs := unittest.GossipSubMessageFixtures(t, 10, fmt.Sprintf("%s/%s", channels.TestNetworkChannel, sporkID)) - // add 5 invalid messages to force notification dissemination - pubsubMsgs = append(pubsubMsgs, []*pubsub_pb.Message{ - {Topic: &unknownTopic}, - {Topic: &malformedTopic}, - {Topic: &malformedTopic}, - {Topic: &invalidSporkIDTopic}, - {Topic: &invalidSporkIDTopic}, - }...) 
- expectedPeerID := unittest.PeerIdFixture(t) - req, err := NewInspectRPCRequest(expectedPeerID, unittest.P2PRPCFixture(unittest.WithPubsubMessages(pubsubMsgs...))) - require.NoError(t, err, "failed to get inspect message request") - topics := make([]string, len(pubsubMsgs)) - for i, msg := range pubsubMsgs { - topics[i] = *msg.Topic - } - // set topic oracle to return list of topics to avoid hasSubscription errors and force topic validation - require.NoError(t, inspector.SetTopicOracle(func() []string { - return topics - })) - checkNotification := checkNotificationFunc(t, expectedPeerID, p2pmsg.RpcPublishMessage, IsInvalidRpcPublishMessagesErr) - distributor.On("Distribute", mock.AnythingOfType("*p2p.InvCtrlMsgNotif")).Return(nil).Once().Run(checkNotification) - require.NoError(t, inspector.processInspectRPCReq(req)) + suite.config.RpcMessageErrorThreshold = 4 + pubsubMsgs := unittest.GossipSubMessageFixtures(5, fmt.Sprintf("%s/%s", channels.TestNetworkChannel, suite.sporkID)) + from := unittest.PeerIdFixture(t) + rpc := unittest.P2PRPCFixture(unittest.WithPubsubMessages(pubsubMsgs...)) + checkNotification := checkNotificationFunc(t, from, p2pmsg.RpcPublishMessage, validation.IsInvalidRpcPublishMessagesErr) + suite.distributor.On("Distribute", mock.AnythingOfType("*p2p.InvCtrlMsgNotif")).Return(nil).Once().Run(checkNotification) + suite.inspector.Start(suite.signalerCtx) + require.NoError(t, suite.inspector.Inspect(from, rpc)) + // sleep for 1 second to ensure rpc's is processed + time.Sleep(time.Second) }) - t.Run("inspectRpcPublishMessages should disseminate invalid control message notification when subscription missing for topic", func(t *testing.T) { - inspector, distributor, _, _, sporkID := inspectorFixture(t) + suite.T().Run("inspectRpcPublishMessages should disseminate invalid control message notification when publish messages contain no topic", func(t *testing.T) { + suite.SetupTest() + defer suite.StopInspector() // 5 invalid pubsub messages will force notification dissemination - inspector.config.RpcMessageErrorThreshold = 4 - pubsubMsgs := unittest.GossipSubMessageFixtures(t, 5, fmt.Sprintf("%s/%s", channels.TestNetworkChannel, sporkID)) - expectedPeerID := unittest.PeerIdFixture(t) - req, err := NewInspectRPCRequest(expectedPeerID, unittest.P2PRPCFixture(unittest.WithPubsubMessages(pubsubMsgs...))) - require.NoError(t, err, "failed to get inspect message request") + suite.config.RpcMessageErrorThreshold = 4 + + pubsubMsgs := unittest.GossipSubMessageFixtures(10, "") + rpc := unittest.P2PRPCFixture(unittest.WithPubsubMessages(pubsubMsgs...)) topics := make([]string, len(pubsubMsgs)) for i, msg := range pubsubMsgs { topics[i] = *msg.Topic } - // set topic oracle to return list of topics excluding first topic sent - require.NoError(t, inspector.SetTopicOracle(func() []string { - return []string{} - })) - checkNotification := checkNotificationFunc(t, expectedPeerID, p2pmsg.RpcPublishMessage, IsInvalidRpcPublishMessagesErr) - distributor.On("Distribute", mock.AnythingOfType("*p2p.InvCtrlMsgNotif")).Return(nil).Once().Run(checkNotification) - require.NoError(t, inspector.processInspectRPCReq(req)) + from := unittest.PeerIdFixture(t) + checkNotification := checkNotificationFunc(t, from, p2pmsg.RpcPublishMessage, validation.IsInvalidRpcPublishMessagesErr) + suite.distributor.On("Distribute", mock.AnythingOfType("*p2p.InvCtrlMsgNotif")).Return(nil).Once().Run(checkNotification) + suite.inspector.Start(suite.signalerCtx) + require.NoError(t, suite.inspector.Inspect(from, rpc)) + // 
sleep for 1 second to ensure rpc's is processed + time.Sleep(time.Second) }) - - t.Run("inspectRpcPublishMessages should disseminate invalid control message notification when publish messages contain no topic", func(t *testing.T) { - inspector, distributor, _, _, _ := inspectorFixture(t) - // 5 invalid pubsub messages will force notification dissemination - inspector.config.RpcMessageErrorThreshold = 4 - pubsubMsgs := unittest.GossipSubMessageFixtures(t, 10, "") - expectedPeerID := unittest.PeerIdFixture(t) - req, err := NewInspectRPCRequest(expectedPeerID, unittest.P2PRPCFixture(unittest.WithPubsubMessages(pubsubMsgs...))) + suite.T().Run("inspectRpcPublishMessages should not inspect pubsub message sender on public networks", func(t *testing.T) { + suite.SetupTest() + defer suite.StopInspector() + from := unittest.PeerIdFixture(t) + defer suite.idProvider.AssertNotCalled(t, "ByPeerID", from) + topic := fmt.Sprintf("%s/%s", channels.TestNetworkChannel, suite.sporkID) + suite.topicProviderOracle.UpdateTopics([]string{topic}) + pubsubMsgs := unittest.GossipSubMessageFixtures(10, topic, unittest.WithFrom(from)) + rpc := unittest.P2PRPCFixture(unittest.WithPubsubMessages(pubsubMsgs...)) + suite.inspector.Start(suite.signalerCtx) + require.NoError(t, suite.inspector.Inspect(from, rpc)) + // sleep for 1 second to ensure rpc's is processed + time.Sleep(time.Second) + }) + suite.T().Run("inspectRpcPublishMessages should disseminate invalid control message notification when message is from unstaked peer", func(t *testing.T) { + suite.SetupTest() + defer suite.StopInspector() + + // override the inspector and params, run the inspector in private mode + suite.params.NetworkingType = network.PrivateNetwork + var err error + suite.inspector, err = validation.NewControlMsgValidationInspector(suite.params) + require.NoError(suite.T(), err, "failed to create control message validation inspector fixture") + + from := unittest.PeerIdFixture(t) + topic := fmt.Sprintf("%s/%s", channels.TestNetworkChannel, suite.sporkID) + suite.topicProviderOracle.UpdateTopics([]string{topic}) + // default RpcMessageErrorThreshold is 500, 501 messages should trigger a notification + pubsubMsgs := unittest.GossipSubMessageFixtures(501, topic, unittest.WithFrom(from)) + suite.idProvider.On("ByPeerID", from).Return(nil, false).Times(501) + rpc := unittest.P2PRPCFixture(unittest.WithPubsubMessages(pubsubMsgs...)) + checkNotification := checkNotificationFunc(t, from, p2pmsg.RpcPublishMessage, validation.IsInvalidRpcPublishMessagesErr) + suite.distributor.On("Distribute", mock.AnythingOfType("*p2p.InvCtrlMsgNotif")).Return(nil).Once().Run(checkNotification) + suite.inspector.Start(suite.signalerCtx) + require.NoError(t, suite.inspector.Inspect(from, rpc)) + // sleep for 1 second to ensure rpc's is processed + time.Sleep(time.Second) + }) + suite.T().Run("inspectRpcPublishMessages should disseminate invalid control message notification when message is from ejected peer", func(t *testing.T) { + suite.SetupTest() + defer suite.StopInspector() + + // override the inspector and params, run the inspector in private mode + suite.params.NetworkingType = network.PrivateNetwork + var err error + suite.inspector, err = validation.NewControlMsgValidationInspector(suite.params) + require.NoError(suite.T(), err, "failed to create control message validation inspector fixture") + + from := unittest.PeerIdFixture(t) + id := unittest.IdentityFixture() + id.Ejected = true + topic := fmt.Sprintf("%s/%s", channels.TestNetworkChannel, suite.sporkID) + 
suite.topicProviderOracle.UpdateTopics([]string{topic}) + pubsubMsgs := unittest.GossipSubMessageFixtures(501, topic, unittest.WithFrom(from)) + suite.idProvider.On("ByPeerID", from).Return(id, true).Times(501) + rpc := unittest.P2PRPCFixture(unittest.WithPubsubMessages(pubsubMsgs...)) require.NoError(t, err, "failed to get inspect message request") - topics := make([]string, len(pubsubMsgs)) - for i, msg := range pubsubMsgs { - topics[i] = *msg.Topic - } - // set topic oracle to return list of topics excluding first topic sent - require.NoError(t, inspector.SetTopicOracle(func() []string { - return []string{} - })) - checkNotification := checkNotificationFunc(t, expectedPeerID, p2pmsg.RpcPublishMessage, IsInvalidRpcPublishMessagesErr) - distributor.On("Distribute", mock.AnythingOfType("*p2p.InvCtrlMsgNotif")).Return(nil).Once().Run(checkNotification) - require.NoError(t, inspector.processInspectRPCReq(req)) + checkNotification := checkNotificationFunc(t, from, p2pmsg.RpcPublishMessage, validation.IsInvalidRpcPublishMessagesErr) + suite.distributor.On("Distribute", mock.AnythingOfType("*p2p.InvCtrlMsgNotif")).Return(nil).Once().Run(checkNotification) + suite.inspector.Start(suite.signalerCtx) + require.NoError(t, suite.inspector.Inspect(from, rpc)) + // sleep for 1 second to ensure rpc's is processed + time.Sleep(time.Second) }) } // TestNewControlMsgValidationInspector_validateClusterPrefixedTopic ensures cluster prefixed topics are validated as expected. -func TestNewControlMsgValidationInspector_validateClusterPrefixedTopic(t *testing.T) { - t.Run("validateClusterPrefixedTopic should not return an error for valid cluster prefixed topics", func(t *testing.T) { - inspector, distributor, _, idProvider, sporkID := inspectorFixture(t) - defer distributor.AssertNotCalled(t, "Distribute") +func (suite *ControlMsgValidationInspectorSuite) TestNewControlMsgValidationInspector_validateClusterPrefixedTopic() { + suite.T().Run("validateClusterPrefixedTopic should not return an error for valid cluster prefixed topics", func(t *testing.T) { + suite.SetupTest() + defer suite.StopInspector() + defer suite.distributor.AssertNotCalled(t, "Distribute") clusterID := flow.ChainID(unittest.IdentifierFixture().String()) - clusterPrefixedTopic := channels.Topic(fmt.Sprintf("%s/%s", channels.SyncCluster(clusterID), sporkID)) - from := peer.ID("peerID987654321") - idProvider.On("ByPeerID", from).Return(unittest.IdentityFixture(), true).Once() - require.NoError(t, inspector.validateClusterPrefixedTopic(from, clusterPrefixedTopic, flow.ChainIDList{clusterID, flow.ChainID(unittest.IdentifierFixture().String()), flow.ChainID(unittest.IdentifierFixture().String())})) + clusterPrefixedTopic := channels.Topic(fmt.Sprintf("%s/%s", channels.SyncCluster(clusterID), suite.sporkID)).String() + suite.topicProviderOracle.UpdateTopics([]string{clusterPrefixedTopic}) + from := unittest.PeerIdFixture(t) + suite.idProvider.On("ByPeerID", from).Return(unittest.IdentityFixture(), true).Once() + inspectMsgRpc := unittest.P2PRPCFixture(unittest.WithGrafts(unittest.P2PRPCGraftFixture(&clusterPrefixedTopic))) + suite.inspector.ActiveClustersChanged(flow.ChainIDList{clusterID, + flow.ChainID(unittest.IdentifierFixture().String()), + flow.ChainID(unittest.IdentifierFixture().String())}) + suite.inspector.Start(suite.signalerCtx) + require.NoError(t, suite.inspector.Inspect(from, inspectMsgRpc)) + // sleep for 1 second to ensure rpc's is processed + time.Sleep(time.Second) }) - t.Run("validateClusterPrefixedTopic should not return error 
if cluster prefixed hard threshold not exceeded for unknown cluster ids", func(t *testing.T) { - inspector, distributor, _, idProvider, sporkID := inspectorFixture(t) + suite.T().Run("validateClusterPrefixedTopic should not return error if cluster prefixed hard threshold not exceeded for unknown cluster ids", func(t *testing.T) { + suite.SetupTest() + defer suite.StopInspector() + defer suite.distributor.AssertNotCalled(t, "Distribute") // set hard threshold to small number , ensure that a single unknown cluster prefix id does not cause a notification to be disseminated - inspector.config.ClusterPrefixHardThreshold = 2 - defer distributor.AssertNotCalled(t, "Distribute") + suite.config.ClusterPrefixHardThreshold = 2 + defer suite.distributor.AssertNotCalled(t, "Distribute") clusterID := flow.ChainID(unittest.IdentifierFixture().String()) - clusterPrefixedTopic := channels.Topic(fmt.Sprintf("%s/%s", channels.SyncCluster(clusterID), sporkID)).String() - from := peer.ID("peerID987654321") - inspectMsgReq, err := NewInspectRPCRequest(from, unittest.P2PRPCFixture(unittest.WithGrafts(unittest.P2PRPCGraftFixture(&clusterPrefixedTopic)))) - require.NoError(t, err, "failed to get inspect message request") - idProvider.On("ByPeerID", from).Return(unittest.IdentityFixture(), true).Once() - require.NoError(t, inspector.processInspectRPCReq(inspectMsgReq)) + clusterPrefixedTopic := channels.Topic(fmt.Sprintf("%s/%s", channels.SyncCluster(clusterID), suite.sporkID)).String() + from := unittest.PeerIdFixture(t) + inspectMsgRpc := unittest.P2PRPCFixture(unittest.WithGrafts(unittest.P2PRPCGraftFixture(&clusterPrefixedTopic))) + suite.idProvider.On("ByPeerID", from).Return(unittest.IdentityFixture(), true).Once() + suite.inspector.Start(suite.signalerCtx) + require.NoError(t, suite.inspector.Inspect(from, inspectMsgRpc)) + // sleep for 1 second to ensure rpc's is processed + time.Sleep(time.Second) }) - t.Run("validateClusterPrefixedTopic should return an error when sender is unstaked", func(t *testing.T) { - inspector, distributor, _, idProvider, sporkID := inspectorFixture(t) - defer distributor.AssertNotCalled(t, "Distribute") + suite.T().Run("validateClusterPrefixedTopic should return an error when sender is unstaked", func(t *testing.T) { + suite.SetupTest() + defer suite.StopInspector() + defer suite.distributor.AssertNotCalled(t, "Distribute") clusterID := flow.ChainID(unittest.IdentifierFixture().String()) - clusterPrefixedTopic := channels.Topic(fmt.Sprintf("%s/%s", channels.SyncCluster(clusterID), sporkID)) - from := peer.ID("peerID987654321") - idProvider.On("ByPeerID", from).Return(nil, false).Once() - err := inspector.validateClusterPrefixedTopic(from, clusterPrefixedTopic, flow.ChainIDList{flow.ChainID(unittest.IdentifierFixture().String())}) - require.True(t, IsErrUnstakedPeer(err)) + clusterPrefixedTopic := channels.Topic(fmt.Sprintf("%s/%s", channels.SyncCluster(clusterID), suite.sporkID)).String() + suite.topicProviderOracle.UpdateTopics([]string{clusterPrefixedTopic}) + from := unittest.PeerIdFixture(t) + suite.idProvider.On("ByPeerID", from).Return(nil, false).Once() + inspectMsgRpc := unittest.P2PRPCFixture(unittest.WithGrafts(unittest.P2PRPCGraftFixture(&clusterPrefixedTopic))) + suite.inspector.ActiveClustersChanged(flow.ChainIDList{flow.ChainID(unittest.IdentifierFixture().String())}) + + suite.inspector.Start(suite.signalerCtx) + require.NoError(t, suite.inspector.Inspect(from, inspectMsgRpc)) + // sleep for 1 second to ensure rpc's is processed + time.Sleep(time.Second) }) - 
t.Run("validateClusterPrefixedTopic should return error if cluster prefixed hard threshold exceeded for unknown cluster ids", func(t *testing.T) { - inspector, distributor, _, idProvider, sporkID := inspectorFixture(t) - defer distributor.AssertNotCalled(t, "Distribute") + suite.T().Run("validateClusterPrefixedTopic should return error if cluster prefixed hard threshold exceeded for unknown cluster ids", func(t *testing.T) { + suite.SetupTest() + defer suite.StopInspector() clusterID := flow.ChainID(unittest.IdentifierFixture().String()) - clusterPrefixedTopic := channels.Topic(fmt.Sprintf("%s/%s", channels.SyncCluster(clusterID), sporkID)) - from := peer.ID("peerID987654321") + clusterPrefixedTopic := channels.Topic(fmt.Sprintf("%s/%s", channels.SyncCluster(clusterID), suite.sporkID)).String() + suite.topicProviderOracle.UpdateTopics([]string{clusterPrefixedTopic}) + // the 11th unknown cluster ID error should cause an error + suite.config.ClusterPrefixHardThreshold = 10 + from := unittest.PeerIdFixture(t) identity := unittest.IdentityFixture() - idProvider.On("ByPeerID", from).Return(identity, true).Once() - inspector.config.ClusterPrefixHardThreshold = 10 - for i := 0; i < 15; i++ { - _, err := inspector.tracker.Inc(identity.NodeID) - require.NoError(t, err) + suite.idProvider.On("ByPeerID", from).Return(identity, true).Times(11) + checkNotification := checkNotificationFunc(t, from, p2pmsg.CtrlMsgGraft, channels.IsUnknownClusterIDErr) + inspectMsgRpc := unittest.P2PRPCFixture(unittest.WithGrafts(unittest.P2PRPCGraftFixture(&clusterPrefixedTopic))) + suite.inspector.ActiveClustersChanged(flow.ChainIDList{flow.ChainID(unittest.IdentifierFixture().String())}) + suite.distributor.On("Distribute", mock.AnythingOfType("*p2p.InvCtrlMsgNotif")).Return(nil).Once().Run(checkNotification) + suite.inspector.Start(suite.signalerCtx) + for i := 0; i < 11; i++ { + require.NoError(t, suite.inspector.Inspect(from, inspectMsgRpc)) } - err := inspector.validateClusterPrefixedTopic(from, clusterPrefixedTopic, flow.ChainIDList{flow.ChainID(unittest.IdentifierFixture().String())}) - require.True(t, channels.IsUnknownClusterIDErr(err)) + // sleep for 1 second to ensure rpc's is processed + time.Sleep(time.Second) }) } // TestControlMessageValidationInspector_ActiveClustersChanged validates the expected update of the active cluster IDs list. 
-func TestControlMessageValidationInspector_ActiveClustersChanged(t *testing.T) { - sporkID := unittest.IdentifierFixture() - flowConfig, err := config.DefaultConfig() - require.NoError(t, err, "failed to get default flow config") - distributor := mockp2p.NewGossipSubInspectorNotifDistributor(t) - signalerCtx := irrecoverable.NewMockSignalerContext(t, context.Background()) - inspector, err := NewControlMsgValidationInspector(signalerCtx, unittest.Logger(), sporkID, &flowConfig.NetworkConfig.GossipSubRPCValidationInspectorConfigs, distributor, metrics.NewNoopCollector(), metrics.NewNoopCollector(), mockmodule.NewIdentityProvider(t), metrics.NewNoopCollector(), mockp2p.NewRpcControlTracking(t)) - require.NoError(t, err) +func (suite *ControlMsgValidationInspectorSuite) TestControlMessageValidationInspector_ActiveClustersChanged() { + suite.SetupTest() + defer suite.StopInspector() + defer suite.distributor.AssertNotCalled(suite.T(), "Distribute") + identity := unittest.IdentityFixture() + suite.idProvider.On("ByPeerID", mock.AnythingOfType("peer.ID")).Return(identity, true).Times(5) activeClusterIds := make(flow.ChainIDList, 0) for _, id := range unittest.IdentifierListFixture(5) { activeClusterIds = append(activeClusterIds, flow.ChainID(id.String())) } - - inspector.ActiveClustersChanged(activeClusterIds) - require.ElementsMatch(t, activeClusterIds, inspector.tracker.GetActiveClusterIds(), "mismatch active cluster ids list") -} - -// inspectorFixture returns a *ControlMsgValidationInspector fixture. -func inspectorFixture(t *testing.T) (*ControlMsgValidationInspector, *mockp2p.GossipSubInspectorNotifDistributor, *mockp2p.RpcControlTracking, *mockmodule.IdentityProvider, flow.Identifier) { - sporkID := unittest.IdentifierFixture() - flowConfig, err := config.DefaultConfig() - require.NoError(t, err, "failed to get default flow config") - distributor := mockp2p.NewGossipSubInspectorNotifDistributor(t) - idProvider := mockmodule.NewIdentityProvider(t) - signalerCtx := irrecoverable.NewMockSignalerContext(t, context.Background()) - inspector, err := NewControlMsgValidationInspector(signalerCtx, unittest.Logger(), sporkID, &flowConfig.NetworkConfig.GossipSubRPCValidationInspectorConfigs, distributor, metrics.NewNoopCollector(), metrics.NewNoopCollector(), idProvider, metrics.NewNoopCollector(), mockp2p.NewRpcControlTracking(t)) - require.NoError(t, err, "failed to create control message validation inspector fixture") - rpcTracker := mockp2p.NewRpcControlTracking(t) - inspector.rpcTracker = rpcTracker - return inspector, distributor, rpcTracker, idProvider, sporkID + suite.inspector.ActiveClustersChanged(activeClusterIds) + suite.inspector.Start(suite.signalerCtx) + from := unittest.PeerIdFixture(suite.T()) + for _, id := range activeClusterIds { + topic := channels.Topic(fmt.Sprintf("%s/%s", channels.SyncCluster(id), suite.sporkID)).String() + rpc := unittest.P2PRPCFixture(unittest.WithGrafts(unittest.P2PRPCGraftFixture(&topic))) + require.NoError(suite.T(), suite.inspector.Inspect(from, rpc)) + } + // sleep for 1 second to ensure rpc's is processed + time.Sleep(time.Second) } // invalidTopics returns 3 invalid topics. @@ -541,7 +849,10 @@ func invalidTopics(t *testing.T, sporkID flow.Identifier) (string, string, strin } // checkNotificationFunc returns util func used to ensure invalid control message notification disseminated contains expected information. 
-func checkNotificationFunc(t *testing.T, expectedPeerID peer.ID, expectedMsgType p2pmsg.ControlMessageType, isExpectedErr func(err error) bool) func(args mock.Arguments) { +func checkNotificationFunc(t *testing.T, + expectedPeerID peer.ID, + expectedMsgType p2pmsg.ControlMessageType, + isExpectedErr func(err error) bool) func(args mock.Arguments) { return func(args mock.Arguments) { notification, ok := args[0].(*p2p.InvCtrlMsgNotif) require.True(t, ok) diff --git a/network/p2p/mock/gossip_sub_inspector_suite.go b/network/p2p/mock/gossip_sub_inspector_suite.go index 59e4fa743f6..90c7e5b15d7 100644 --- a/network/p2p/mock/gossip_sub_inspector_suite.go +++ b/network/p2p/mock/gossip_sub_inspector_suite.go @@ -78,20 +78,6 @@ func (_m *GossipSubInspectorSuite) Ready() <-chan struct{} { return r0 } -// SetTopicOracle provides a mock function with given fields: topicOracle -func (_m *GossipSubInspectorSuite) SetTopicOracle(topicOracle func() []string) error { - ret := _m.Called(topicOracle) - - var r0 error - if rf, ok := ret.Get(0).(func(func() []string) error); ok { - r0 = rf(topicOracle) - } else { - r0 = ret.Error(0) - } - - return r0 -} - // Start provides a mock function with given fields: _a0 func (_m *GossipSubInspectorSuite) Start(_a0 irrecoverable.SignalerContext) { _m.Called(_a0) diff --git a/network/p2p/mock/gossip_sub_rpc_inspector_suite_factory_func.go b/network/p2p/mock/gossip_sub_rpc_inspector_suite_factory_func.go index ae84974031c..24c253b70d2 100644 --- a/network/p2p/mock/gossip_sub_rpc_inspector_suite_factory_func.go +++ b/network/p2p/mock/gossip_sub_rpc_inspector_suite_factory_func.go @@ -26,25 +26,25 @@ type GossipSubRpcInspectorSuiteFactoryFunc struct { mock.Mock } -// Execute provides a mock function with given fields: _a0, _a1, _a2, _a3, _a4, _a5, _a6, _a7 -func (_m *GossipSubRpcInspectorSuiteFactoryFunc) Execute(_a0 irrecoverable.SignalerContext, _a1 zerolog.Logger, _a2 flow.Identifier, _a3 *p2pconf.GossipSubRPCInspectorsConfig, _a4 module.GossipSubMetrics, _a5 metrics.HeroCacheMetricsFactory, _a6 network.NetworkingType, _a7 module.IdentityProvider) (p2p.GossipSubInspectorSuite, error) { - ret := _m.Called(_a0, _a1, _a2, _a3, _a4, _a5, _a6, _a7) +// Execute provides a mock function with given fields: _a0, _a1, _a2, _a3, _a4, _a5, _a6, _a7, _a8 +func (_m *GossipSubRpcInspectorSuiteFactoryFunc) Execute(_a0 irrecoverable.SignalerContext, _a1 zerolog.Logger, _a2 flow.Identifier, _a3 *p2pconf.GossipSubRPCInspectorsConfig, _a4 module.GossipSubMetrics, _a5 metrics.HeroCacheMetricsFactory, _a6 network.NetworkingType, _a7 module.IdentityProvider, _a8 func() p2p.TopicProvider) (p2p.GossipSubInspectorSuite, error) { + ret := _m.Called(_a0, _a1, _a2, _a3, _a4, _a5, _a6, _a7, _a8) var r0 p2p.GossipSubInspectorSuite var r1 error - if rf, ok := ret.Get(0).(func(irrecoverable.SignalerContext, zerolog.Logger, flow.Identifier, *p2pconf.GossipSubRPCInspectorsConfig, module.GossipSubMetrics, metrics.HeroCacheMetricsFactory, network.NetworkingType, module.IdentityProvider) (p2p.GossipSubInspectorSuite, error)); ok { - return rf(_a0, _a1, _a2, _a3, _a4, _a5, _a6, _a7) + if rf, ok := ret.Get(0).(func(irrecoverable.SignalerContext, zerolog.Logger, flow.Identifier, *p2pconf.GossipSubRPCInspectorsConfig, module.GossipSubMetrics, metrics.HeroCacheMetricsFactory, network.NetworkingType, module.IdentityProvider, func() p2p.TopicProvider) (p2p.GossipSubInspectorSuite, error)); ok { + return rf(_a0, _a1, _a2, _a3, _a4, _a5, _a6, _a7, _a8) } - if rf, ok := ret.Get(0).(func(irrecoverable.SignalerContext, 
zerolog.Logger, flow.Identifier, *p2pconf.GossipSubRPCInspectorsConfig, module.GossipSubMetrics, metrics.HeroCacheMetricsFactory, network.NetworkingType, module.IdentityProvider) p2p.GossipSubInspectorSuite); ok { - r0 = rf(_a0, _a1, _a2, _a3, _a4, _a5, _a6, _a7) + if rf, ok := ret.Get(0).(func(irrecoverable.SignalerContext, zerolog.Logger, flow.Identifier, *p2pconf.GossipSubRPCInspectorsConfig, module.GossipSubMetrics, metrics.HeroCacheMetricsFactory, network.NetworkingType, module.IdentityProvider, func() p2p.TopicProvider) p2p.GossipSubInspectorSuite); ok { + r0 = rf(_a0, _a1, _a2, _a3, _a4, _a5, _a6, _a7, _a8) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(p2p.GossipSubInspectorSuite) } } - if rf, ok := ret.Get(1).(func(irrecoverable.SignalerContext, zerolog.Logger, flow.Identifier, *p2pconf.GossipSubRPCInspectorsConfig, module.GossipSubMetrics, metrics.HeroCacheMetricsFactory, network.NetworkingType, module.IdentityProvider) error); ok { - r1 = rf(_a0, _a1, _a2, _a3, _a4, _a5, _a6, _a7) + if rf, ok := ret.Get(1).(func(irrecoverable.SignalerContext, zerolog.Logger, flow.Identifier, *p2pconf.GossipSubRPCInspectorsConfig, module.GossipSubMetrics, metrics.HeroCacheMetricsFactory, network.NetworkingType, module.IdentityProvider, func() p2p.TopicProvider) error); ok { + r1 = rf(_a0, _a1, _a2, _a3, _a4, _a5, _a6, _a7, _a8) } else { r1 = ret.Error(1) } diff --git a/network/p2p/mock/score_option_builder.go b/network/p2p/mock/score_option_builder.go index d0f437bfc12..e5698aadf94 100644 --- a/network/p2p/mock/score_option_builder.go +++ b/network/p2p/mock/score_option_builder.go @@ -3,6 +3,7 @@ package mockp2p import ( + irrecoverable "github.com/onflow/flow-go/module/irrecoverable" mock "github.com/stretchr/testify/mock" pubsub "github.com/libp2p/go-libp2p-pubsub" @@ -41,6 +42,43 @@ func (_m *ScoreOptionBuilder) BuildFlowPubSubScoreOption() (*pubsub.PeerScorePar return r0, r1 } +// Done provides a mock function with given fields: +func (_m *ScoreOptionBuilder) Done() <-chan struct{} { + ret := _m.Called() + + var r0 <-chan struct{} + if rf, ok := ret.Get(0).(func() <-chan struct{}); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(<-chan struct{}) + } + } + + return r0 +} + +// Ready provides a mock function with given fields: +func (_m *ScoreOptionBuilder) Ready() <-chan struct{} { + ret := _m.Called() + + var r0 <-chan struct{} + if rf, ok := ret.Get(0).(func() <-chan struct{}); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(<-chan struct{}) + } + } + + return r0 +} + +// Start provides a mock function with given fields: _a0 +func (_m *ScoreOptionBuilder) Start(_a0 irrecoverable.SignalerContext) { + _m.Called(_a0) +} + // TopicScoreParams provides a mock function with given fields: _a0 func (_m *ScoreOptionBuilder) TopicScoreParams(_a0 *pubsub.Topic) *pubsub.TopicScoreParams { ret := _m.Called(_a0) diff --git a/network/p2p/mock/stream_factory.go b/network/p2p/mock/stream_factory.go index 5b2192e703f..b95e52d3ff8 100644 --- a/network/p2p/mock/stream_factory.go +++ b/network/p2p/mock/stream_factory.go @@ -18,46 +18,25 @@ type StreamFactory struct { mock.Mock } -// Connect provides a mock function with given fields: _a0, _a1 -func (_m *StreamFactory) Connect(_a0 context.Context, _a1 peer.AddrInfo) error { - ret := _m.Called(_a0, _a1) - - var r0 error - if rf, ok := ret.Get(0).(func(context.Context, peer.AddrInfo) error); ok { - r0 = rf(_a0, _a1) - } else { - r0 = ret.Error(0) - } - - return r0 -} - // NewStream provides a mock function with 
given fields: _a0, _a1, _a2 -func (_m *StreamFactory) NewStream(_a0 context.Context, _a1 peer.ID, _a2 ...protocol.ID) (network.Stream, error) { - _va := make([]interface{}, len(_a2)) - for _i := range _a2 { - _va[_i] = _a2[_i] - } - var _ca []interface{} - _ca = append(_ca, _a0, _a1) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +func (_m *StreamFactory) NewStream(_a0 context.Context, _a1 peer.ID, _a2 protocol.ID) (network.Stream, error) { + ret := _m.Called(_a0, _a1, _a2) var r0 network.Stream var r1 error - if rf, ok := ret.Get(0).(func(context.Context, peer.ID, ...protocol.ID) (network.Stream, error)); ok { - return rf(_a0, _a1, _a2...) + if rf, ok := ret.Get(0).(func(context.Context, peer.ID, protocol.ID) (network.Stream, error)); ok { + return rf(_a0, _a1, _a2) } - if rf, ok := ret.Get(0).(func(context.Context, peer.ID, ...protocol.ID) network.Stream); ok { - r0 = rf(_a0, _a1, _a2...) + if rf, ok := ret.Get(0).(func(context.Context, peer.ID, protocol.ID) network.Stream); ok { + r0 = rf(_a0, _a1, _a2) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(network.Stream) } } - if rf, ok := ret.Get(1).(func(context.Context, peer.ID, ...protocol.ID) error); ok { - r1 = rf(_a0, _a1, _a2...) + if rf, ok := ret.Get(1).(func(context.Context, peer.ID, protocol.ID) error); ok { + r1 = rf(_a0, _a1, _a2) } else { r1 = ret.Error(1) } diff --git a/network/p2p/mock/subscription_provider.go b/network/p2p/mock/subscription_provider.go index bc119c00f02..3445a89e6a0 100644 --- a/network/p2p/mock/subscription_provider.go +++ b/network/p2p/mock/subscription_provider.go @@ -3,6 +3,7 @@ package mockp2p import ( + irrecoverable "github.com/onflow/flow-go/module/irrecoverable" mock "github.com/stretchr/testify/mock" peer "github.com/libp2p/go-libp2p/core/peer" @@ -13,6 +14,22 @@ type SubscriptionProvider struct { mock.Mock } +// Done provides a mock function with given fields: +func (_m *SubscriptionProvider) Done() <-chan struct{} { + ret := _m.Called() + + var r0 <-chan struct{} + if rf, ok := ret.Get(0).(func() <-chan struct{}); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(<-chan struct{}) + } + } + + return r0 +} + // GetSubscribedTopics provides a mock function with given fields: pid func (_m *SubscriptionProvider) GetSubscribedTopics(pid peer.ID) []string { ret := _m.Called(pid) @@ -29,6 +46,27 @@ func (_m *SubscriptionProvider) GetSubscribedTopics(pid peer.ID) []string { return r0 } +// Ready provides a mock function with given fields: +func (_m *SubscriptionProvider) Ready() <-chan struct{} { + ret := _m.Called() + + var r0 <-chan struct{} + if rf, ok := ret.Get(0).(func() <-chan struct{}); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(<-chan struct{}) + } + } + + return r0 +} + +// Start provides a mock function with given fields: _a0 +func (_m *SubscriptionProvider) Start(_a0 irrecoverable.SignalerContext) { + _m.Called(_a0) +} + type mockConstructorTestingTNewSubscriptionProvider interface { mock.TestingT Cleanup(func()) diff --git a/network/p2p/mock/subscription_validator.go b/network/p2p/mock/subscription_validator.go index b7f71843639..33c7d1f2d75 100644 --- a/network/p2p/mock/subscription_validator.go +++ b/network/p2p/mock/subscription_validator.go @@ -4,9 +4,9 @@ package mockp2p import ( flow "github.com/onflow/flow-go/model/flow" - mock "github.com/stretchr/testify/mock" + irrecoverable "github.com/onflow/flow-go/module/irrecoverable" - p2p "github.com/onflow/flow-go/network/p2p" + mock "github.com/stretchr/testify/mock" peer 
"github.com/libp2p/go-libp2p/core/peer" ) @@ -30,20 +30,43 @@ func (_m *SubscriptionValidator) CheckSubscribedToAllowedTopics(pid peer.ID, rol return r0 } -// RegisterSubscriptionProvider provides a mock function with given fields: provider -func (_m *SubscriptionValidator) RegisterSubscriptionProvider(provider p2p.SubscriptionProvider) error { - ret := _m.Called(provider) +// Done provides a mock function with given fields: +func (_m *SubscriptionValidator) Done() <-chan struct{} { + ret := _m.Called() - var r0 error - if rf, ok := ret.Get(0).(func(p2p.SubscriptionProvider) error); ok { - r0 = rf(provider) + var r0 <-chan struct{} + if rf, ok := ret.Get(0).(func() <-chan struct{}); ok { + r0 = rf() } else { - r0 = ret.Error(0) + if ret.Get(0) != nil { + r0 = ret.Get(0).(<-chan struct{}) + } } return r0 } +// Ready provides a mock function with given fields: +func (_m *SubscriptionValidator) Ready() <-chan struct{} { + ret := _m.Called() + + var r0 <-chan struct{} + if rf, ok := ret.Get(0).(func() <-chan struct{}); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(<-chan struct{}) + } + } + + return r0 +} + +// Start provides a mock function with given fields: _a0 +func (_m *SubscriptionValidator) Start(_a0 irrecoverable.SignalerContext) { + _m.Called(_a0) +} + type mockConstructorTestingTNewSubscriptionValidator interface { mock.TestingT Cleanup(func()) diff --git a/network/p2p/p2pbuilder/gossipsub/gossipSubBuilder.go b/network/p2p/p2pbuilder/gossipsub/gossipSubBuilder.go index f3221555b09..251e578185a 100644 --- a/network/p2p/p2pbuilder/gossipsub/gossipSubBuilder.go +++ b/network/p2p/p2pbuilder/gossipsub/gossipSubBuilder.go @@ -44,12 +44,13 @@ type Builder struct { gossipSubScoreTracerInterval time.Duration // the interval at which the gossipsub score tracer logs the peer scores. // gossipSubTracer is a callback interface that is called by the gossipsub implementation upon // certain events. Currently, we use it to log and observe the local mesh of the node. - gossipSubTracer p2p.PubSubTracer - scoreOptionConfig *scoring.ScoreOptionConfig - idProvider module.IdentityProvider - routingSystem routing.Routing - rpcInspectorConfig *p2pconf.GossipSubRPCInspectorsConfig - rpcInspectorSuiteFactory p2p.GossipSubRpcInspectorSuiteFactoryFunc + gossipSubTracer p2p.PubSubTracer + scoreOptionConfig *scoring.ScoreOptionConfig + subscriptionProviderParam *p2pconf.SubscriptionProviderParameters + idProvider module.IdentityProvider + routingSystem routing.Routing + rpcInspectorConfig *p2pconf.GossipSubRPCInspectorsConfig + rpcInspectorSuiteFactory p2p.GossipSubRpcInspectorSuiteFactoryFunc } var _ p2p.GossipSubBuilder = (*Builder)(nil) @@ -184,24 +185,25 @@ func NewGossipSubBuilder( sporkId flow.Identifier, idProvider module.IdentityProvider, rpcInspectorConfig *p2pconf.GossipSubRPCInspectorsConfig, - rpcTracker p2p.RpcControlTracking, -) *Builder { + subscriptionProviderPrams *p2pconf.SubscriptionProviderParameters, + rpcTracker p2p.RpcControlTracking) *Builder { lg := logger.With(). Str("component", "gossipsub"). Str("network-type", networkType.String()). 
Logger() b := &Builder{ - logger: lg, - metricsCfg: metricsCfg, - sporkId: sporkId, - networkType: networkType, - idProvider: idProvider, - gossipSubFactory: defaultGossipSubFactory(), - gossipSubConfigFunc: defaultGossipSubAdapterConfig(), - scoreOptionConfig: scoring.NewScoreOptionConfig(lg, idProvider), - rpcInspectorConfig: rpcInspectorConfig, - rpcInspectorSuiteFactory: defaultInspectorSuite(rpcTracker), + logger: lg, + metricsCfg: metricsCfg, + sporkId: sporkId, + networkType: networkType, + idProvider: idProvider, + gossipSubFactory: defaultGossipSubFactory(), + gossipSubConfigFunc: defaultGossipSubAdapterConfig(), + scoreOptionConfig: scoring.NewScoreOptionConfig(lg, idProvider), + rpcInspectorConfig: rpcInspectorConfig, + rpcInspectorSuiteFactory: defaultInspectorSuite(rpcTracker), + subscriptionProviderParam: subscriptionProviderPrams, } return b @@ -240,7 +242,8 @@ func defaultInspectorSuite(rpcTracker p2p.RpcControlTracking) p2p.GossipSubRpcIn gossipSubMetrics module.GossipSubMetrics, heroCacheMetricsFactory metrics.HeroCacheMetricsFactory, networkType network.NetworkingType, - idProvider module.IdentityProvider) (p2p.GossipSubInspectorSuite, error) { + idProvider module.IdentityProvider, + topicProvider func() p2p.TopicProvider) (p2p.GossipSubInspectorSuite, error) { metricsInspector := inspector.NewControlMsgMetricsInspector( logger, p2pnode.NewGossipSubControlMessageMetrics(gossipSubMetrics, logger), @@ -253,27 +256,23 @@ func defaultInspectorSuite(rpcTracker p2p.RpcControlTracking) p2p.GossipSubRpcIn networkType)), }...) notificationDistributor := distributor.DefaultGossipSubInspectorNotificationDistributor( - logger, - []queue.HeroStoreConfigOption{ + logger, []queue.HeroStoreConfigOption{ queue.WithHeroStoreSizeLimit(inspectorCfg.GossipSubRPCInspectorNotificationCacheSize), queue.WithHeroStoreCollector(metrics.RpcInspectorNotificationQueueMetricFactory(heroCacheMetricsFactory, networkType))}...) - inspectMsgQueueCacheCollector := metrics.GossipSubRPCInspectorQueueMetricFactory(heroCacheMetricsFactory, networkType) - clusterPrefixedCacheCollector := metrics.GossipSubRPCInspectorClusterPrefixedCacheMetricFactory( - heroCacheMetricsFactory, - networkType) - rpcValidationInspector, err := validation.NewControlMsgValidationInspector( - ctx, - logger, - sporkId, - &inspectorCfg.GossipSubRPCValidationInspectorConfigs, - notificationDistributor, - inspectMsgQueueCacheCollector, - clusterPrefixedCacheCollector, - idProvider, - gossipSubMetrics, - rpcTracker, - ) + params := &validation.InspectorParams{ + Logger: logger, + SporkID: sporkId, + Config: &inspectorCfg.GossipSubRPCValidationInspectorConfigs, + Distributor: notificationDistributor, + HeroCacheMetricsFactory: heroCacheMetricsFactory, + IdProvider: idProvider, + InspectorMetrics: gossipSubMetrics, + RpcTracker: rpcTracker, + NetworkingType: networkType, + TopicOracle: topicProvider, + } + rpcValidationInspector, err := validation.NewControlMsgValidationInspector(params) if err != nil { return nil, fmt.Errorf("failed to create new control message valiadation inspector: %w", err) } @@ -292,6 +291,10 @@ func defaultInspectorSuite(rpcTracker p2p.RpcControlTracking) p2p.GossipSubRpcIn // - error: if an error occurs during the creation of the GossipSub pubsub system, it is returned. Otherwise, nil is returned. // Note that on happy path, the returned error is nil. Any error returned is unexpected and should be handled as irrecoverable. 
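The note above that any error from Build "should be handled as irrecoverable" refers to the convention, visible later in this change where the routing-system error is thrown on the signaler context, of surfacing such errors via ctx.Throw inside a component worker rather than returning them further up. A rough sketch of that convention follows; addGossipSubWorker and the gossipSubBuilder interface are hypothetical names introduced only for illustration, assuming the flow-go component and irrecoverable packages.

package example

import (
	"fmt"

	"github.com/onflow/flow-go/module/component"
	"github.com/onflow/flow-go/module/irrecoverable"
	"github.com/onflow/flow-go/network/p2p"
)

// gossipSubBuilder is the minimal slice of the builder API this sketch needs.
type gossipSubBuilder interface {
	Build(irrecoverable.SignalerContext) (p2p.PubSubAdapter, error)
}

// addGossipSubWorker registers a worker that builds gossipsub during node startup; a build
// failure is thrown on the signaler context, which tears down the component tree.
func addGossipSubWorker(cm component.ComponentManagerBuilder, b gossipSubBuilder) {
	cm.AddWorker(func(ctx irrecoverable.SignalerContext, ready component.ReadyFunc) {
		pubSub, err := b.Build(ctx)
		if err != nil {
			ctx.Throw(fmt.Errorf("could not create gossipsub: %w", err))
			return
		}
		ready()
		_ = pubSub // in real wiring the adapter would be handed to the libp2p node here
	})
}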
func (g *Builder) Build(ctx irrecoverable.SignalerContext) (p2p.PubSubAdapter, error) { + // placeholder for the gossipsub pubsub system that will be created (so that it can be passed around even + // before it is created). + var gossipSub p2p.PubSubAdapter + gossipSubConfigs := g.gossipSubConfigFunc( &p2p.BasePubSubAdapterConfig{ MaxMessageSize: p2pnode.DefaultMaxPubSubMsgSize, @@ -314,7 +317,10 @@ func (g *Builder) Build(ctx irrecoverable.SignalerContext) (p2p.PubSubAdapter, e g.metricsCfg.Metrics, g.metricsCfg.HeroCacheFactory, g.networkType, - g.idProvider) + g.idProvider, + func() p2p.TopicProvider { + return gossipSub + }) if err != nil { return nil, fmt.Errorf("could not create gossipsub inspector suite: %w", err) } @@ -323,19 +329,31 @@ func (g *Builder) Build(ctx irrecoverable.SignalerContext) (p2p.PubSubAdapter, e var scoreOpt *scoring.ScoreOption var scoreTracer p2p.PeerScoreTracer if g.gossipSubPeerScoring { + // wires the gossipsub score option to the subscription provider. + subscriptionProvider, err := scoring.NewSubscriptionProvider(&scoring.SubscriptionProviderConfig{ + Logger: g.logger, + TopicProviderOracle: func() p2p.TopicProvider { + // gossipSub has not been created yet, hence instead of passing it directly, we pass a function that returns it. + // the cardinal assumption is this function is only invoked when the subscription provider is started, which is + // after the gossipsub is created. + return gossipSub + }, + IdProvider: g.idProvider, + Params: g.subscriptionProviderParam, + HeroCacheMetricsFactory: g.metricsCfg.HeroCacheFactory, + }) + if err != nil { + return nil, fmt.Errorf("could not create subscription provider: %w", err) + } + g.scoreOptionConfig.SetRegisterNotificationConsumerFunc(inspectorSuite.AddInvalidControlMessageConsumer) - scoreOpt = scoring.NewScoreOption(g.scoreOptionConfig) + scoreOpt = scoring.NewScoreOption(g.scoreOptionConfig, subscriptionProvider) gossipSubConfigs.WithScoreOption(scoreOpt) if g.gossipSubScoreTracerInterval > 0 { - scoreTracer = tracer.NewGossipSubScoreTracer( - g.logger, - g.idProvider, - g.metricsCfg.Metrics, - g.gossipSubScoreTracerInterval) + scoreTracer = tracer.NewGossipSubScoreTracer(g.logger, g.idProvider, g.metricsCfg.Metrics, g.gossipSubScoreTracerInterval) gossipSubConfigs.WithScoreTracer(scoreTracer) } - } else { g.logger.Warn(). Str(logging.KeyNetworkingSecurity, "true"). 
@@ -350,22 +368,10 @@ func (g *Builder) Build(ctx irrecoverable.SignalerContext) (p2p.PubSubAdapter, e return nil, fmt.Errorf("could not create gossipsub: host is nil") } - gossipSub, err := g.gossipSubFactory(ctx, g.logger, g.h, gossipSubConfigs, inspectorSuite) + gossipSub, err = g.gossipSubFactory(ctx, g.logger, g.h, gossipSubConfigs, inspectorSuite) if err != nil { return nil, fmt.Errorf("could not create gossipsub: %w", err) } - err = inspectorSuite.SetTopicOracle(gossipSub.GetTopics) - if err != nil { - return nil, fmt.Errorf("could not set topic oracle on inspector suite: %w", err) - } - - if scoreOpt != nil { - err := scoreOpt.SetSubscriptionProvider(scoring.NewSubscriptionProvider(g.logger, gossipSub)) - if err != nil { - return nil, fmt.Errorf("could not set subscription provider: %w", err) - } - } - return gossipSub, nil } diff --git a/network/p2p/p2pbuilder/inspector/suite.go b/network/p2p/p2pbuilder/inspector/suite.go index 167a227ca30..8fe6a1c4547 100644 --- a/network/p2p/p2pbuilder/inspector/suite.go +++ b/network/p2p/p2pbuilder/inspector/suite.go @@ -36,7 +36,9 @@ var _ p2p.GossipSubInspectorSuite = (*GossipSubInspectorSuite)(nil) // regarding gossipsub control messages is detected. // Returns: // - the new GossipSubInspectorSuite. -func NewGossipSubInspectorSuite(metricsInspector *inspector.ControlMsgMetricsInspector, validationInspector *validation.ControlMsgValidationInspector, ctrlMsgInspectDistributor p2p.GossipSubInspectorNotifDistributor) *GossipSubInspectorSuite { +func NewGossipSubInspectorSuite(metricsInspector *inspector.ControlMsgMetricsInspector, + validationInspector *validation.ControlMsgValidationInspector, + ctrlMsgInspectDistributor p2p.GossipSubInspectorNotifDistributor) *GossipSubInspectorSuite { inspectors := []p2p.GossipSubRPCInspector{metricsInspector, validationInspector} s := &GossipSubInspectorSuite{ ctrlMsgInspectDistributor: ctrlMsgInspectDistributor, @@ -89,13 +91,3 @@ func (s *GossipSubInspectorSuite) ActiveClustersChanged(list flow.ChainIDList) { } } } - -// SetTopicOracle sets the topic oracle of the gossipsub inspector suite. -// The topic oracle is used to determine the list of topics that the node is subscribed to. -// If an oracle is not set, the node will not be able to determine the list of topics that the node is subscribed to. -// Currently, the only inspector that utilizes the topic oracle is the validation inspector. -// This func is expected to be called once and will return an error on all subsequent calls. -// All errors returned from this func are considered irrecoverable. 
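With SetTopicOracle gone, the topic oracle is captured as a closure over the gossipSub variable declared at the top of Build, before the adapter exists, and only resolves once Build assigns it. A minimal, self-contained sketch of that lazy-capture pattern follows; TopicProvider, adapter, and provider here are simplified stand-ins rather than the flow-go types, and the sketch is only sound under the assumption stated in Build, namely that the closure is first invoked after the adapter has been assigned.

package main

import "fmt"

// TopicProvider is a simplified stand-in for the interface the oracle must eventually return.
type TopicProvider interface {
	GetTopics() []string
}

type adapter struct{ topics []string }

func (a *adapter) GetTopics() []string { return a.topics }

// provider captures the oracle at construction time but only invokes it once started.
type provider struct {
	oracle func() TopicProvider
}

func (p *provider) start() {
	// by the time start runs, the closure resolves to the fully built adapter.
	fmt.Println("subscribed topics:", p.oracle().GetTopics())
}

func main() {
	// declared up front so the closure below can capture it before it exists.
	var built TopicProvider

	p := &provider{oracle: func() TopicProvider { return built }}

	// the adapter is created later, exactly once, before any component is started.
	built = &adapter{topics: []string{"example-topic"}}

	p.start() // the oracle is only dereferenced here, after the assignment above.
}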
-func (s *GossipSubInspectorSuite) SetTopicOracle(topicOracle func() []string) error { - return s.validationInspector.SetTopicOracle(topicOracle) -} diff --git a/network/p2p/p2pbuilder/libp2pNodeBuilder.go b/network/p2p/p2pbuilder/libp2pNodeBuilder.go index 8d0380db23b..45baaa0a6d5 100644 --- a/network/p2p/p2pbuilder/libp2pNodeBuilder.go +++ b/network/p2p/p2pbuilder/libp2pNodeBuilder.go @@ -36,6 +36,7 @@ import ( p2pconfig "github.com/onflow/flow-go/network/p2p/p2pbuilder/config" gossipsubbuilder "github.com/onflow/flow-go/network/p2p/p2pbuilder/gossipsub" "github.com/onflow/flow-go/network/p2p/p2pconf" + "github.com/onflow/flow-go/network/p2p/p2plogging" "github.com/onflow/flow-go/network/p2p/p2pnode" "github.com/onflow/flow-go/network/p2p/subscription" "github.com/onflow/flow-go/network/p2p/tracer" @@ -86,6 +87,7 @@ func NewNodeBuilder( rCfg *p2pconf.ResourceManagerConfig, rpcInspectorCfg *p2pconf.GossipSubRPCInspectorsConfig, peerManagerConfig *p2pconfig.PeerManagerConfig, + subscriptionProviderParam *p2pconf.SubscriptionProviderParameters, disallowListCacheCfg *p2p.DisallowListCacheConfig, rpcTracker p2p.RpcControlTracking, unicastConfig *p2pconfig.UnicastConfig, @@ -105,7 +107,7 @@ func NewNodeBuilder( networkingType, sporkId, idProvider, - rpcInspectorCfg, + rpcInspectorCfg, subscriptionProviderParam, rpcTracker), peerManagerConfig: peerManagerConfig, unicastConfig: unicastConfig, @@ -237,6 +239,7 @@ func (builder *LibP2PNodeBuilder) Build() (p2p.LibP2PNode, error) { return nil, err } builder.gossipSubBuilder.SetHost(h) + lg := builder.logger.With().Str("local_peer_id", p2plogging.PeerId(h.ID())).Logger() pCache, err := p2pnode.NewProtocolPeerCache(builder.logger, h) if err != nil { @@ -252,7 +255,7 @@ func (builder *LibP2PNodeBuilder) Build() (p2p.LibP2PNode, error) { peerUpdater, err := connection.NewPeerUpdater( &connection.PeerUpdaterConfig{ PruneConnections: builder.peerManagerConfig.ConnectionPruning, - Logger: builder.logger, + Logger: lg, Host: connection.NewConnectorHost(h), Connector: connector, }) @@ -260,35 +263,30 @@ func (builder *LibP2PNodeBuilder) Build() (p2p.LibP2PNode, error) { return nil, fmt.Errorf("failed to create libp2p connector: %w", err) } - peerManager = connection.NewPeerManager(builder.logger, builder.peerManagerConfig.UpdateInterval, peerUpdater) + peerManager = connection.NewPeerManager(lg, builder.peerManagerConfig.UpdateInterval, peerUpdater) if builder.unicastConfig.RateLimiterDistributor != nil { builder.unicastConfig.RateLimiterDistributor.AddConsumer(peerManager) } } - node := builder.createNode(builder.logger, h, pCache, peerManager, builder.disallowListCacheCfg) + node := builder.createNode(lg, h, pCache, peerManager, builder.disallowListCacheCfg) if builder.connGater != nil { builder.connGater.SetDisallowListOracle(node) } unicastManager, err := unicast.NewUnicastManager(&unicast.ManagerConfig{ - Logger: builder.logger, + Logger: lg, StreamFactory: stream.NewLibP2PStreamFactory(h), SporkId: builder.sporkId, - ConnStatus: node, CreateStreamBackoffDelay: builder.unicastConfig.CreateStreamBackoffDelay, - DialBackoffDelay: builder.unicastConfig.DialBackoffDelay, - DialInProgressBackoffDelay: builder.unicastConfig.DialInProgressBackoffDelay, Metrics: builder.metricsConfig.Metrics, StreamZeroRetryResetThreshold: builder.unicastConfig.StreamZeroRetryResetThreshold, - DialZeroRetryResetThreshold: builder.unicastConfig.DialZeroRetryResetThreshold, MaxStreamCreationRetryAttemptTimes: builder.unicastConfig.MaxStreamCreationRetryAttemptTimes, - 
MaxDialRetryAttemptTimes: builder.unicastConfig.MaxDialRetryAttemptTimes, - DialConfigCacheFactory: func(configFactory func() unicast.DialConfig) unicast.DialConfigCache { - return unicastcache.NewDialConfigCache(builder.unicastConfig.DialConfigCacheSize, - builder.logger, + UnicastConfigCacheFactory: func(configFactory func() unicast.Config) unicast.ConfigCache { + return unicastcache.NewUnicastConfigCache(builder.unicastConfig.ConfigCacheSize, + lg, metrics.DialConfigCacheMetricFactory(builder.metricsConfig.HeroCacheFactory, builder.networkingType), configFactory) }, @@ -310,7 +308,7 @@ func (builder *LibP2PNodeBuilder) Build() (p2p.LibP2PNode, error) { ctx.Throw(fmt.Errorf("could not set routing system: %w", err)) } builder.gossipSubBuilder.SetRoutingSystem(routingSystem) - builder.logger.Debug().Msg("routing system created") + lg.Debug().Msg("routing system created") } // gossipsub is created here, because it needs to be created during the node startup. gossipSub, err := builder.gossipSubBuilder.Build(ctx) @@ -470,8 +468,7 @@ func DefaultNodeBuilder( sporkId, idProvider, rCfg, - rpcInspectorCfg, - peerManagerCfg, + rpcInspectorCfg, peerManagerCfg, &gossipCfg.SubscriptionProviderConfig, disallowListCacheCfg, meshTracer, uniCfg). diff --git a/network/p2p/p2pconf/gossipsub.go b/network/p2p/p2pconf/gossipsub.go index 1bf50263e4c..b8392c3268c 100644 --- a/network/p2p/p2pconf/gossipsub.go +++ b/network/p2p/p2pconf/gossipsub.go @@ -64,6 +64,20 @@ type GossipSubConfig struct { // PeerScoring is whether to enable GossipSub peer scoring. PeerScoring bool `mapstructure:"gossipsub-peer-scoring-enabled"` + + SubscriptionProviderConfig SubscriptionProviderParameters `mapstructure:",squash"` +} + +type SubscriptionProviderParameters struct { + // SubscriptionUpdateInterval is the interval for updating the list of topics the node have subscribed to; as well as the list of all + // peers subscribed to each of those topics. This is used to penalize peers that have an invalid subscription based on their role. + SubscriptionUpdateInterval time.Duration `validate:"gt=0s" mapstructure:"gossipsub-subscription-provider-update-interval"` + + // CacheSize is the size of the cache that keeps the list of peers subscribed to each topic as the local node. + // This is the local view of the current node towards the subscription status of other nodes in the system. + // The cache must be large enough to accommodate the maximum number of nodes in the system, otherwise the view of the local node will be incomplete + // due to cache eviction. + CacheSize uint32 `validate:"gt=0" mapstructure:"gossipsub-subscription-provider-cache-size"` } // GossipSubTracerConfig is the config for the gossipsub tracer. GossipSub tracer is used to trace the local mesh events and peer scores. diff --git a/network/p2p/p2pconf/gossipsub_rpc_inspectors.go b/network/p2p/p2pconf/gossipsub_rpc_inspectors.go index 497df0bf724..3d3cea79b21 100644 --- a/network/p2p/p2pconf/gossipsub_rpc_inspectors.go +++ b/network/p2p/p2pconf/gossipsub_rpc_inspectors.go @@ -40,7 +40,7 @@ type IWantRPCInspectionConfig struct { // If the cache miss threshold is exceeded an invalid control message notification is disseminated and the sender will be penalized. CacheMissThreshold float64 `validate:"gt=0" mapstructure:"gossipsub-rpc-iwant-cache-miss-threshold"` // CacheMissCheckSize the iWants size at which message id cache misses will be checked. 
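The SubscriptionProviderParameters introduced above reduce to two knobs: SubscriptionUpdateInterval, how often the local view of peer subscriptions is refreshed, and CacheSize, how many per-peer records that view can hold. The sketch below only illustrates how such an interval can drive a cycle-based cache; topicProvider, subscriptionCache, and refreshLoop are simplified stand-ins, not the actual scoring.SubscriptionProvider implementation.

package example

import (
	"context"
	"fmt"
	"time"
)

// topicProvider is a simplified stand-in for the pubsub view that is polled on every refresh.
type topicProvider interface {
	GetTopics() []string
	ListPeers(topic string) []string
}

// subscriptionCache is a simplified stand-in for the record cache: MoveToNextUpdateCycle marks
// everything recorded so far as stale, AddTopicForPeer records a peer/topic pair for the
// current cycle.
type subscriptionCache interface {
	MoveToNextUpdateCycle() uint64
	AddTopicForPeer(peerID string, topic string) ([]string, error)
}

// refreshLoop re-derives the subscription view once per interval; entries that are not
// re-observed in a tick are overwritten the next time their peer is updated.
func refreshLoop(ctx context.Context, interval time.Duration, tp topicProvider, cache subscriptionCache) {
	ticker := time.NewTicker(interval)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			cache.MoveToNextUpdateCycle()
			for _, topic := range tp.GetTopics() {
				for _, pid := range tp.ListPeers(topic) {
					if _, err := cache.AddTopicForPeer(pid, topic); err != nil {
						fmt.Println("failed to record subscription:", err)
					}
				}
			}
		}
	}
}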
- CacheMissCheckSize int `validate:"gte=1000" mapstructure:"gossipsub-rpc-iwant-cache-miss-check-size"` + CacheMissCheckSize int `validate:"gt=0" mapstructure:"gossipsub-rpc-iwant-cache-miss-check-size"` // DuplicateMsgIDThreshold maximum allowed duplicate message IDs in a single iWant control message. // If the duplicate message threshold is exceeded an invalid control message notification is disseminated and the sender will be penalized. DuplicateMsgIDThreshold float64 `validate:"gt=0" mapstructure:"gossipsub-rpc-iwant-duplicate-message-id-threshold"` @@ -63,7 +63,7 @@ type ClusterPrefixedMessageConfig struct { // when the cluster ID's provider is set asynchronously. It also allows processing of some stale messages that may be sent by nodes // that fall behind in the protocol. After the amount of cluster prefixed control messages processed exceeds this threshold the node // will be pushed to the edge of the network mesh. - ClusterPrefixHardThreshold float64 `validate:"gt=0" mapstructure:"gossipsub-rpc-cluster-prefixed-hard-threshold"` + ClusterPrefixHardThreshold float64 `validate:"gte=0" mapstructure:"gossipsub-rpc-cluster-prefixed-hard-threshold"` // ClusterPrefixedControlMsgsReceivedCacheSize size of the cache used to track the amount of cluster prefixed topics received by peers. ClusterPrefixedControlMsgsReceivedCacheSize uint32 `validate:"gt=0" mapstructure:"gossipsub-cluster-prefix-tracker-cache-size"` // ClusterPrefixedControlMsgsReceivedCacheDecay decay val used for the geometric decay of cache counters used to keep track of cluster prefixed topics received by peers. diff --git a/network/p2p/p2plogging/internal/peerIdCache_test.go b/network/p2p/p2plogging/internal/peerIdCache_test.go index 08d32ebb44f..e4e799d9d62 100644 --- a/network/p2p/p2plogging/internal/peerIdCache_test.go +++ b/network/p2p/p2plogging/internal/peerIdCache_test.go @@ -5,7 +5,6 @@ import ( "github.com/stretchr/testify/assert" - "github.com/onflow/flow-go/network/internal/p2pfixtures" "github.com/onflow/flow-go/network/p2p/p2plogging/internal" "github.com/onflow/flow-go/utils/unittest" ) @@ -65,9 +64,9 @@ func TestPeerIdCache_EjectionScenarios(t *testing.T) { assert.Equal(t, 0, cache.Size()) // add peer IDs to fill the cache - pid1 := p2pfixtures.PeerIdFixture(t) - pid2 := p2pfixtures.PeerIdFixture(t) - pid3 := p2pfixtures.PeerIdFixture(t) + pid1 := unittest.PeerIdFixture(t) + pid2 := unittest.PeerIdFixture(t) + pid3 := unittest.PeerIdFixture(t) cache.PeerIdString(pid1) assert.Equal(t, 1, cache.Size()) @@ -83,7 +82,7 @@ func TestPeerIdCache_EjectionScenarios(t *testing.T) { assert.Equal(t, 3, cache.Size()) // add a new peer ID - pid4 := p2pfixtures.PeerIdFixture(t) + pid4 := unittest.PeerIdFixture(t) cache.PeerIdString(pid4) assert.Equal(t, 3, cache.Size()) diff --git a/network/p2p/p2pnode/gossipSubAdapter.go b/network/p2p/p2pnode/gossipSubAdapter.go index 59bd2f2d65a..f2d1296b588 100644 --- a/network/p2p/p2pnode/gossipSubAdapter.go +++ b/network/p2p/p2pnode/gossipSubAdapter.go @@ -39,7 +39,11 @@ type GossipSubAdapter struct { var _ p2p.PubSubAdapter = (*GossipSubAdapter)(nil) -func NewGossipSubAdapter(ctx context.Context, logger zerolog.Logger, h host.Host, cfg p2p.PubSubAdapterConfig, clusterChangeConsumer p2p.CollectionClusterChangesConsumer) (p2p.PubSubAdapter, error) { +func NewGossipSubAdapter(ctx context.Context, + logger zerolog.Logger, + h host.Host, + cfg p2p.PubSubAdapterConfig, + clusterChangeConsumer p2p.CollectionClusterChangesConsumer) (p2p.PubSubAdapter, error) { gossipSubConfig, ok := 
cfg.(*GossipSubAdapterConfig) if !ok { return nil, fmt.Errorf("invalid gossipsub config type: %T", cfg) @@ -68,44 +72,78 @@ func NewGossipSubAdapter(ctx context.Context, logger zerolog.Logger, h host.Host if scoreTracer := gossipSubConfig.ScoreTracer(); scoreTracer != nil { builder.AddWorker(func(ctx irrecoverable.SignalerContext, ready component.ReadyFunc) { ready() - a.logger.Debug().Str("component", "gossipsub_score_tracer").Msg("starting score tracer") + a.logger.Info().Msg("starting score tracer") scoreTracer.Start(ctx) - a.logger.Debug().Str("component", "gossipsub_score_tracer").Msg("score tracer started") + select { + case <-ctx.Done(): + a.logger.Warn().Msg("aborting score tracer startup due to context done") + case <-scoreTracer.Ready(): + a.logger.Info().Msg("score tracer is ready") + } + ready() + <-ctx.Done() + a.logger.Info().Msg("stopping score tracer") <-scoreTracer.Done() - a.logger.Debug().Str("component", "gossipsub_score_tracer").Msg("score tracer stopped") + a.logger.Info().Msg("score tracer stopped") }) a.peerScoreExposer = scoreTracer } if tracer := gossipSubConfig.PubSubTracer(); tracer != nil { builder.AddWorker(func(ctx irrecoverable.SignalerContext, ready component.ReadyFunc) { - ready() - a.logger.Debug().Str("component", "gossipsub_tracer").Msg("starting tracer") + a.logger.Info().Msg("starting pubsub tracer") tracer.Start(ctx) - a.logger.Debug().Str("component", "gossipsub_tracer").Msg("tracer started") + select { + case <-ctx.Done(): + a.logger.Warn().Msg("aborting pubsub tracer startup due to context done") + case <-tracer.Ready(): + a.logger.Info().Msg("pubsub tracer is ready") + } + ready() + <-ctx.Done() + a.logger.Info().Msg("stopping pubsub tracer") <-tracer.Done() - a.logger.Debug().Str("component", "gossipsub_tracer").Msg("tracer stopped") + a.logger.Info().Msg("pubsub tracer stopped") }) } if inspectorSuite := gossipSubConfig.InspectorSuiteComponent(); inspectorSuite != nil { builder.AddWorker(func(ctx irrecoverable.SignalerContext, ready component.ReadyFunc) { - a.logger.Debug().Str("component", "gossipsub_inspector_suite").Msg("starting inspector suite") + a.logger.Info().Msg("starting inspector suite") inspectorSuite.Start(ctx) - a.logger.Debug().Str("component", "gossipsub_inspector_suite").Msg("inspector suite started") - select { case <-ctx.Done(): - a.logger.Debug().Str("component", "gossipsub_inspector_suite").Msg("inspector suite context done") + a.logger.Warn().Msg("aborting inspector suite startup due to context done") case <-inspectorSuite.Ready(): - ready() - a.logger.Debug().Str("component", "gossipsub_inspector_suite").Msg("inspector suite ready") + a.logger.Info().Msg("inspector suite is ready") } + ready() + <-ctx.Done() + a.logger.Info().Msg("stopping inspector suite") <-inspectorSuite.Done() - a.logger.Debug().Str("component", "gossipsub_inspector_suite").Msg("inspector suite stopped") + a.logger.Info().Msg("inspector suite stopped") + }) + } + + if scoringComponent := gossipSubConfig.ScoringComponent(); scoringComponent != nil { + builder.AddWorker(func(ctx irrecoverable.SignalerContext, ready component.ReadyFunc) { + a.logger.Info().Msg("starting gossipsub scoring component") + scoringComponent.Start(ctx) + select { + case <-ctx.Done(): + a.logger.Warn().Msg("aborting gossipsub scoring component startup due to context done") + case <-scoringComponent.Ready(): + a.logger.Info().Msg("gossipsub scoring component is ready") + } + ready() + + <-ctx.Done() + a.logger.Info().Msg("stopping gossipsub scoring component") + 
<-scoringComponent.Done() + a.logger.Info().Msg("gossipsub scoring component stopped") }) } diff --git a/network/p2p/p2pnode/gossipSubAdapterConfig.go b/network/p2p/p2pnode/gossipSubAdapterConfig.go index a2dbe59289f..f4069930612 100644 --- a/network/p2p/p2pnode/gossipSubAdapterConfig.go +++ b/network/p2p/p2pnode/gossipSubAdapterConfig.go @@ -116,6 +116,10 @@ func (g *GossipSubAdapterConfig) PubSubTracer() p2p.PubSubTracer { return g.pubsubTracer } +func (g *GossipSubAdapterConfig) ScoringComponent() component.Component { + return g.scoreOption +} + // InspectorSuiteComponent returns the component that manages the lifecycle of the inspector suite. // This is used to start and stop the inspector suite by the PubSubAdapter. // Args: diff --git a/network/p2p/pubsub.go b/network/p2p/pubsub.go index 1b45336bdfa..d0ceb33fe8c 100644 --- a/network/p2p/pubsub.go +++ b/network/p2p/pubsub.go @@ -128,6 +128,7 @@ type Topic interface { // ScoreOptionBuilder abstracts the configuration for the underlying pubsub score implementation. type ScoreOptionBuilder interface { + component.Component // BuildFlowPubSubScoreOption builds the pubsub score options as pubsub.Option for the Flow network. BuildFlowPubSubScoreOption() (*pubsub.PeerScoreParams, *pubsub.PeerScoreThresholds) // TopicScoreParams returns the topic score params for the given topic. diff --git a/network/p2p/scoring/internal/subscriptionCache.go b/network/p2p/scoring/internal/subscriptionCache.go new file mode 100644 index 00000000000..95acafdd422 --- /dev/null +++ b/network/p2p/scoring/internal/subscriptionCache.go @@ -0,0 +1,176 @@ +package internal + +import ( + "errors" + "fmt" + + "github.com/libp2p/go-libp2p/core/peer" + "github.com/rs/zerolog" + "go.uber.org/atomic" + + "github.com/onflow/flow-go/model/flow" + "github.com/onflow/flow-go/module" + herocache "github.com/onflow/flow-go/module/mempool/herocache/backdata" + "github.com/onflow/flow-go/module/mempool/herocache/backdata/heropool" + "github.com/onflow/flow-go/module/mempool/stdmap" +) + +var ErrTopicRecordNotFound = fmt.Errorf("topic record not found") + +// SubscriptionRecordCache manages the subscription records of peers in a network. +// It uses a currentCycle counter to track the update cycles of the cache, ensuring the relevance of subscription data. +type SubscriptionRecordCache struct { + c *stdmap.Backend + + // currentCycle is an atomic counter used to track the update cycles of the subscription cache. + // It plays a critical role in maintaining the cache's data relevance and coherence. + // Each increment of currentCycle represents a new update cycle, signifying the cache's transition to a new state + // where only the most recent and relevant subscriptions are maintained. This design choice ensures that the cache + // does not retain stale or outdated subscription information, thereby reflecting the dynamic nature of peer + // subscriptions in the network. It is incremented every time the subscription cache is updated, either with new + // topic subscriptions or other update operations. + // The currentCycle is incremented atomically and externally by calling the MoveToNextUpdateCycle() function. + // This is called by the module that uses the subscription provider cache signaling that whatever updates it has + // made to the cache so far can be considered out-of-date, and the new updates to the cache records should + // overwrite the old ones. + currentCycle atomic.Uint64 +} + +// NewSubscriptionRecordCache creates a new subscription cache with the given size limit. 
+// Args: +// - sizeLimit: the size limit of the cache. +// - logger: the logger to use for logging. +// - collector: the metrics collector to use for collecting metrics. +func NewSubscriptionRecordCache(sizeLimit uint32, + logger zerolog.Logger, + collector module.HeroCacheMetrics) *SubscriptionRecordCache { + backData := herocache.NewCache(sizeLimit, + herocache.DefaultOversizeFactor, + heropool.LRUEjection, + logger.With().Str("mempool", "subscription-records").Logger(), + collector) + + return &SubscriptionRecordCache{ + c: stdmap.NewBackend(stdmap.WithBackData(backData)), + currentCycle: *atomic.NewUint64(0), + } +} + +// GetSubscribedTopics returns the list of topics a peer is subscribed to. +// Returns: +// - []string: the list of topics the peer is subscribed to. +// - bool: true if there is a record for the peer, false otherwise. +func (s *SubscriptionRecordCache) GetSubscribedTopics(pid peer.ID) ([]string, bool) { + e, ok := s.c.ByID(flow.MakeID(pid)) + if !ok { + return nil, false + } + return e.(SubscriptionRecordEntity).Topics, true +} + +// MoveToNextUpdateCycle moves the subscription cache to the next update cycle. +// A new update cycle is started when the subscription cache is first created, and then every time the subscription cache +// is updated. The update cycle is used to keep track of the last time the subscription cache was updated. It is used to +// implement a notion of time in the subscription cache. +// When the update cycle is moved forward, it means that all the updates made to the subscription cache so far are +// considered out-of-date, and the new updates to the cache records should overwrite the old ones. +// The expected behavior is that the update cycle is moved forward by the module that uses the subscription provider once +// per each update on the "entire" cache (and not per each update on a single record). +// In other words, assume a cache with 3 records: A, B, and C. If the module updates record A, then record B, and then +// record C, the module should move the update cycle forward only once after updating record C, and then update record A +// B, and C again. If the module moves the update cycle forward after updating record A, then again after updating +// record B, and then again after updating record C, the cache will be in an inconsistent state. +// Returns: +// - uint64: the current update cycle. +func (s *SubscriptionRecordCache) MoveToNextUpdateCycle() uint64 { + s.currentCycle.Inc() + return s.currentCycle.Load() +} + +// AddTopicForPeer appends a topic to the list of topics a peer is subscribed to. If the peer is not subscribed to any +// topics yet, a new record is created. +// If the last update cycle is older than the current cycle, the list of topics for the peer is first cleared, and then +// the topic is added to the list. This is to ensure that the list of topics for a peer is always up to date. +// Args: +// - pid: the peer id of the peer. +// - topic: the topic to add. +// Returns: +// - []string: the list of topics the peer is subscribed to after the update. +// - error: an error if the update failed; any returned error is an irrecoverable error and indicates a bug or misconfiguration. +// Implementation must be thread-safe. +func (s *SubscriptionRecordCache) AddTopicForPeer(pid peer.ID, topic string) ([]string, error) { + // first, we try to optimistically adjust the record assuming that the record already exists. 
+ entityId := flow.MakeID(pid) + topics, err := s.addTopicForPeer(entityId, topic) + + switch { + case errors.Is(err, ErrTopicRecordNotFound): + // if the record does not exist, we initialize the record and try to adjust it again. + // Note: there is an edge case where the record is initialized by another goroutine between the two calls. + // In this case, the init function is invoked twice, but it is not a problem because the underlying + // cache is thread-safe. Hence, we do not need to synchronize the two calls. In such cases, one of the + // two calls returns false, and the other call returns true. We do not care which call returns false, hence, + // we ignore the return value of the init function. + _ = s.c.Add(SubscriptionRecordEntity{ + entityId: entityId, + PeerID: pid, + Topics: make([]string, 0), + LastUpdatedCycle: s.currentCycle.Load(), + }) + // as the record is initialized, the adjust attempt should not return an error, and any returned error + // is an irrecoverable error and indicates a bug. + return s.addTopicForPeer(entityId, topic) + case err != nil: + // if the adjust function returns an unexpected error on the first attempt, we return the error directly. + return nil, err + default: + // if the adjust function returns no error, we return the updated list of topics. + return topics, nil + } +} + +func (s *SubscriptionRecordCache) addTopicForPeer(entityId flow.Identifier, topic string) ([]string, error) { + var rErr error + updatedEntity, adjusted := s.c.Adjust(entityId, func(entity flow.Entity) flow.Entity { + record, ok := entity.(SubscriptionRecordEntity) + if !ok { + // sanity check + // This should never happen, because the cache only contains SubscriptionRecordEntity entities. + panic(fmt.Sprintf("invalid entity type, expected SubscriptionRecordEntity type, got: %T", entity)) + } + + currentCycle := s.currentCycle.Load() + if record.LastUpdatedCycle > currentCycle { + // sanity check + // This should never happen, because the update cycle must be moved forward before adding a topic. + panic(fmt.Sprintf("invalid last updated cycle, expected <= %d, got: %d", currentCycle, record.LastUpdatedCycle)) + } + if record.LastUpdatedCycle < currentCycle { + // This record was not updated in the current cycle, so we can wipe its topics list (topic list is only + // valid for the current cycle). + record.Topics = make([]string, 0) + } + // check if the topic already exists; if it does, we do not need to update the record. + for _, t := range record.Topics { + if t == topic { + // topic already exists + return record + } + } + record.LastUpdatedCycle = currentCycle + record.Topics = append(record.Topics, topic) + + // Return the adjusted record. 
+ return record + }) + + if rErr != nil { + return nil, fmt.Errorf("failed to adjust record: %w", rErr) + } + + if !adjusted { + return nil, ErrTopicRecordNotFound + } + + return updatedEntity.(SubscriptionRecordEntity).Topics, nil +} diff --git a/network/p2p/scoring/internal/subscriptionCache_test.go b/network/p2p/scoring/internal/subscriptionCache_test.go new file mode 100644 index 00000000000..a333c18bdd8 --- /dev/null +++ b/network/p2p/scoring/internal/subscriptionCache_test.go @@ -0,0 +1,319 @@ +package internal_test + +import ( + "sync" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/onflow/flow-go/module/metrics" + "github.com/onflow/flow-go/network/p2p/scoring/internal" + "github.com/onflow/flow-go/utils/unittest" +) + +// TestNewSubscriptionRecordCache tests that NewSubscriptionRecordCache returns a valid cache. +func TestNewSubscriptionRecordCache(t *testing.T) { + sizeLimit := uint32(100) + + cache := internal.NewSubscriptionRecordCache( + sizeLimit, + unittest.Logger(), + metrics.NewSubscriptionRecordCacheMetricsFactory(metrics.NewNoopHeroCacheMetricsFactory())) + + require.NotNil(t, cache, "cache should not be nil") + require.IsType(t, &internal.SubscriptionRecordCache{}, cache, "cache should be of type *SubscriptionRecordCache") +} + +// TestSubscriptionCache_GetSubscribedTopics tests the retrieval of subscribed topics for a peer. +func TestSubscriptionCache_GetSubscribedTopics(t *testing.T) { + sizeLimit := uint32(100) + cache := internal.NewSubscriptionRecordCache( + sizeLimit, + unittest.Logger(), + metrics.NewSubscriptionRecordCacheMetricsFactory(metrics.NewNoopHeroCacheMetricsFactory())) + + // create a dummy peer ID + peerID := unittest.PeerIdFixture(t) + + // case when the peer has a subscription + topics := []string{"topic1", "topic2"} + updatedTopics, err := cache.AddTopicForPeer(peerID, topics[0]) + require.NoError(t, err, "adding topic 1 should not produce an error") + require.Equal(t, topics[:1], updatedTopics, "updated topics should match the added topic") + updatedTopics, err = cache.AddTopicForPeer(peerID, topics[1]) + require.NoError(t, err, "adding topic 2 should not produce an error") + require.Equal(t, topics, updatedTopics, "updated topics should match the added topic") + + retrievedTopics, found := cache.GetSubscribedTopics(peerID) + require.True(t, found, "peer should be found") + require.ElementsMatch(t, topics, retrievedTopics, "retrieved topics should match the added topics") + + // case when the peer does not have a subscription + nonExistentPeerID := unittest.PeerIdFixture(t) + retrievedTopics, found = cache.GetSubscribedTopics(nonExistentPeerID) + require.False(t, found, "non-existent peer should not be found") + require.Nil(t, retrievedTopics, "retrieved topics for non-existent peer should be nil") +} + +// TestSubscriptionCache_MoveToNextUpdateCycle tests the increment of update cycles in SubscriptionRecordCache. +// The first increment should set the cycle to 1, and the second increment should set the cycle to 2. 
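For illustration, a minimal sketch (not part of this changeset) of how a caller is expected to drive the update-cycle contract documented above: advance the cycle exactly once per full sweep, then write every observed (peer, topic) pair. The helper name and the observed map are hypothetical; the cache value satisfies the scoring.SubscriptionCache interface introduced later in this diff.

package example

import (
	"github.com/libp2p/go-libp2p/core/peer"

	"github.com/onflow/flow-go/network/p2p/scoring"
)

// refreshSubscriptions performs one full update sweep over the cache.
func refreshSubscriptions(cache scoring.SubscriptionCache, observed map[peer.ID][]string) error {
	// advance the cycle once for the whole sweep; the first AddTopicForPeer call for each
	// peer in the new cycle wipes that peer's stale topic list before appending.
	cache.MoveToNextUpdateCycle()
	for pid, topics := range observed {
		for _, topic := range topics {
			if _, err := cache.AddTopicForPeer(pid, topic); err != nil {
				// per the contract above, any error here is irrecoverable and should crash the caller.
				return err
			}
		}
	}
	return nil
}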
+func TestSubscriptionCache_MoveToNextUpdateCycle(t *testing.T) { + sizeLimit := uint32(100) + cache := internal.NewSubscriptionRecordCache( + sizeLimit, + unittest.Logger(), + metrics.NewSubscriptionRecordCacheMetricsFactory(metrics.NewNoopHeroCacheMetricsFactory())) + + // initial cycle should be 0, so first increment sets it to 1 + firstCycle := cache.MoveToNextUpdateCycle() + require.Equal(t, uint64(1), firstCycle, "first cycle should be 1 after first increment") + + // increment cycle again and verify it's now 2 + secondCycle := cache.MoveToNextUpdateCycle() + require.Equal(t, uint64(2), secondCycle, "second cycle should be 2 after second increment") +} + +// TestSubscriptionCache_TestAddTopicForPeer tests adding a topic for a peer. +func TestSubscriptionCache_TestAddTopicForPeer(t *testing.T) { + sizeLimit := uint32(100) + cache := internal.NewSubscriptionRecordCache( + sizeLimit, + unittest.Logger(), + metrics.NewSubscriptionRecordCacheMetricsFactory(metrics.NewNoopHeroCacheMetricsFactory())) + + // case when adding a topic to an existing peer + existingPeerID := unittest.PeerIdFixture(t) + firstTopic := "topic1" + secondTopic := "topic2" + + // add first topic to the existing peer + _, err := cache.AddTopicForPeer(existingPeerID, firstTopic) + require.NoError(t, err, "adding first topic to existing peer should not produce an error") + + // add second topic to the same peer + updatedTopics, err := cache.AddTopicForPeer(existingPeerID, secondTopic) + require.NoError(t, err, "adding second topic to existing peer should not produce an error") + require.ElementsMatch(t, []string{firstTopic, secondTopic}, updatedTopics, "updated topics should match the added topics") + + // case when adding a topic to a new peer + newPeerID := unittest.PeerIdFixture(t) + newTopic := "newTopic" + + // add a topic to the new peer + updatedTopics, err = cache.AddTopicForPeer(newPeerID, newTopic) + require.NoError(t, err, "adding topic to new peer should not produce an error") + require.Equal(t, []string{newTopic}, updatedTopics, "updated topics for new peer should match the added topic") + + // sanity check that the topics for existing peer are still the same + retrievedTopics, found := cache.GetSubscribedTopics(existingPeerID) + require.True(t, found, "existing peer should be found") + require.ElementsMatch(t, []string{firstTopic, secondTopic}, retrievedTopics, "retrieved topics should match the added topics") +} + +// TestSubscriptionCache_DuplicateTopics tests adding a duplicate topic for a peer. The duplicate topic should not be added. 
+func TestSubscriptionCache_DuplicateTopics(t *testing.T) { + sizeLimit := uint32(100) + cache := internal.NewSubscriptionRecordCache( + sizeLimit, + unittest.Logger(), + metrics.NewSubscriptionRecordCacheMetricsFactory(metrics.NewNoopHeroCacheMetricsFactory())) + + peerID := unittest.PeerIdFixture(t) + topic := "topic1" + + // add first topic to the existing peer + _, err := cache.AddTopicForPeer(peerID, topic) + require.NoError(t, err, "adding first topic to existing peer should not produce an error") + + // add second topic to the same peer + updatedTopics, err := cache.AddTopicForPeer(peerID, topic) + require.NoError(t, err, "adding duplicate topic to existing peer should not produce an error") + require.Equal(t, []string{topic}, updatedTopics, "duplicate topic should not be added") +} + +// TestSubscriptionCache_MoveUpdateCycle tests that (1) within one update cycle, "AddTopicForPeer" calls append the topics to the list of +// subscribed topics for peer, (2) as long as there is no "AddTopicForPeer" call, moving to the next update cycle +// does not change the subscribed topics for a peer, and (3) calling "AddTopicForPeer" after moving to the next update +// cycle clears the subscribed topics for a peer and adds the new topic. +func TestSubscriptionCache_MoveUpdateCycle(t *testing.T) { + sizeLimit := uint32(100) + cache := internal.NewSubscriptionRecordCache( + sizeLimit, + unittest.Logger(), + metrics.NewSubscriptionRecordCacheMetricsFactory(metrics.NewNoopHeroCacheMetricsFactory())) + + peerID := unittest.PeerIdFixture(t) + topic1 := "topic1" + topic2 := "topic2" + topic3 := "topic3" + topic4 := "topic4" + + // adds topic1, topic2, and topic3 to the peer + topics, err := cache.AddTopicForPeer(peerID, topic1) + require.NoError(t, err, "adding first topic to existing peer should not produce an error") + require.Equal(t, []string{topic1}, topics, "updated topics should match the added topic") + topics, err = cache.AddTopicForPeer(peerID, topic2) + require.NoError(t, err, "adding second topic to existing peer should not produce an error") + require.Equal(t, []string{topic1, topic2}, topics, "updated topics should match the added topics") + topics, err = cache.AddTopicForPeer(peerID, topic3) + require.NoError(t, err, "adding third topic to existing peer should not produce an error") + require.Equal(t, []string{topic1, topic2, topic3}, topics, "updated topics should match the added topics") + + // move to next update cycle + cache.MoveToNextUpdateCycle() + topics, found := cache.GetSubscribedTopics(peerID) + require.True(t, found, "existing peer should be found") + require.ElementsMatch(t, []string{topic1, topic2, topic3}, topics, "retrieved topics should match the added topics") + + // add topic4 to the peer; since we moved to the next update cycle, the topics for the peer should be cleared + // and topic4 should be the only topic for the peer + topics, err = cache.AddTopicForPeer(peerID, topic4) + require.NoError(t, err, "adding fourth topic to existing peer should not produce an error") + require.Equal(t, []string{topic4}, topics, "updated topics should match the added topic") + + // move to next update cycle + cache.MoveToNextUpdateCycle() + + // since we did not add any topic to the peer, the topics for the peer should be the same as before + topics, found = cache.GetSubscribedTopics(peerID) + require.True(t, found, "existing peer should be found") + require.ElementsMatch(t, []string{topic4}, topics, "retrieved topics should match the added topics") +} + +// 
TestSubscriptionCache_MoveUpdateCycleWithDifferentPeers tests that moving to the next update cycle does not affect the subscribed +// topics for other peers. +func TestSubscriptionCache_MoveUpdateCycleWithDifferentPeers(t *testing.T) { + sizeLimit := uint32(100) + cache := internal.NewSubscriptionRecordCache( + sizeLimit, + unittest.Logger(), + metrics.NewSubscriptionRecordCacheMetricsFactory(metrics.NewNoopHeroCacheMetricsFactory())) + + peer1 := unittest.PeerIdFixture(t) + peer2 := unittest.PeerIdFixture(t) + topic1 := "topic1" + topic2 := "topic2" + + // add topic1 to peer1 + topics, err := cache.AddTopicForPeer(peer1, topic1) + require.NoError(t, err, "adding first topic to peer1 should not produce an error") + require.Equal(t, []string{topic1}, topics, "updated topics should match the added topic") + + // add topic2 to peer2 + topics, err = cache.AddTopicForPeer(peer2, topic2) + require.NoError(t, err, "adding first topic to peer2 should not produce an error") + require.Equal(t, []string{topic2}, topics, "updated topics should match the added topic") + + // move to next update cycle + cache.MoveToNextUpdateCycle() + + // since we did not add any topic to the peers, the topics for the peers should be the same as before + topics, found := cache.GetSubscribedTopics(peer1) + require.True(t, found, "peer1 should be found") + require.ElementsMatch(t, []string{topic1}, topics, "retrieved topics should match the added topics") + + topics, found = cache.GetSubscribedTopics(peer2) + require.True(t, found, "peer2 should be found") + require.ElementsMatch(t, []string{topic2}, topics, "retrieved topics should match the added topics") + + // now add topic2 to peer1; it should overwrite the previous topics for peer1, but not affect the topics for peer2 + topics, err = cache.AddTopicForPeer(peer1, topic2) + require.NoError(t, err, "adding second topic to peer1 should not produce an error") + require.Equal(t, []string{topic2}, topics, "updated topics should match the added topic") + + topics, found = cache.GetSubscribedTopics(peer2) + require.True(t, found, "peer2 should be found") + require.ElementsMatch(t, []string{topic2}, topics, "retrieved topics should match the added topics") +} + +// TestSubscriptionCache_ConcurrentUpdate tests subscription cache update in a concurrent environment. 
+func TestSubscriptionCache_ConcurrentUpdate(t *testing.T) { + unittest.SkipUnless(t, unittest.TEST_TODO, "this test requires atomic AdjustOrGet method to be implemented for backend") + sizeLimit := uint32(100) + cache := internal.NewSubscriptionRecordCache( + sizeLimit, + unittest.Logger(), + metrics.NewSubscriptionRecordCacheMetricsFactory(metrics.NewNoopHeroCacheMetricsFactory())) + + peerIds := unittest.PeerIdFixtures(t, 100) + topics := []string{"topic1", "topic2", "topic3"} + + allUpdatesDone := sync.WaitGroup{} + for _, pid := range peerIds { + for _, topic := range topics { + pid := pid + topic := topic + allUpdatesDone.Add(1) + go func() { + defer allUpdatesDone.Done() + _, err := cache.AddTopicForPeer(pid, topic) + require.NoError(t, err, "adding topic to peer should not produce an error") + }() + } + } + + unittest.RequireReturnsBefore(t, allUpdatesDone.Wait, 1*time.Second, "all updates did not finish in time") + + // verify that all peers have all topics; concurrently + allTopicsVerified := sync.WaitGroup{} + for _, pid := range peerIds { + pid := pid + allTopicsVerified.Add(1) + go func() { + defer allTopicsVerified.Done() + subscribedTopics, found := cache.GetSubscribedTopics(pid) + require.True(t, found, "peer should be found") + require.ElementsMatch(t, topics, subscribedTopics, "retrieved topics should match the added topics") + }() + } + + unittest.RequireReturnsBefore(t, allTopicsVerified.Wait, 1*time.Second, "all topics were not verified in time") +} + +// TestSubscriptionCache_TestSizeLimit tests that the cache evicts the least recently used peer when the cache size limit is reached. +func TestSubscriptionCache_TestSizeLimit(t *testing.T) { + sizeLimit := uint32(100) + cache := internal.NewSubscriptionRecordCache( + sizeLimit, + unittest.Logger(), + metrics.NewSubscriptionRecordCacheMetricsFactory(metrics.NewNoopHeroCacheMetricsFactory())) + + peerIds := unittest.PeerIdFixtures(t, 100) + topics := []string{"topic1", "topic2", "topic3"} + + // add topics to peers + for _, pid := range peerIds { + for _, topic := range topics { + _, err := cache.AddTopicForPeer(pid, topic) + require.NoError(t, err, "adding topic to peer should not produce an error") + } + } + + // verify that all peers have all topics + for _, pid := range peerIds { + subscribedTopics, found := cache.GetSubscribedTopics(pid) + require.True(t, found, "peer should be found") + require.ElementsMatch(t, topics, subscribedTopics, "retrieved topics should match the added topics") + } + + // add one more peer and verify that the first peer is evicted + newPeerID := unittest.PeerIdFixture(t) + _, err := cache.AddTopicForPeer(newPeerID, topics[0]) + require.NoError(t, err, "adding topic to peer should not produce an error") + + _, found := cache.GetSubscribedTopics(peerIds[0]) + require.False(t, found, "peer should not be found") + + // verify that all other peers still have all topics + for _, pid := range peerIds[1:] { + subscribedTopics, found := cache.GetSubscribedTopics(pid) + require.True(t, found, "peer should be found") + require.ElementsMatch(t, topics, subscribedTopics, "retrieved topics should match the added topics") + } + + // verify that the new peer has the topic + newPeerTopics, found := cache.GetSubscribedTopics(newPeerID) + require.True(t, found, "peer should be found") + require.ElementsMatch(t, []string{topics[0]}, newPeerTopics, "new peer should only have the topic added to it") +} diff --git a/network/p2p/scoring/internal/subscriptionRecord.go b/network/p2p/scoring/internal/subscriptionRecord.go new file mode 100644 index 00000000000..2ac6946c25b --- /dev/null +++ 
b/network/p2p/scoring/internal/subscriptionRecord.go @@ -0,0 +1,38 @@ +package internal + +import ( + "github.com/libp2p/go-libp2p/core/peer" + + "github.com/onflow/flow-go/model/flow" +) + +// SubscriptionRecordEntity is an entity that represents the list of topics a peer is subscribed to. +// It is internally used by the SubscriptionRecordCache to store the subscription records in the cache. +type SubscriptionRecordEntity struct { + // entityId is the key of the entity in the cache. It is the hash of the peer id. + // It is intentionally encoded as part of the struct to avoid recomputing it. + entityId flow.Identifier + + // PeerID is the peer id of the peer that is the owner of the subscription. + PeerID peer.ID + + // Topics is the list of topics the peer is subscribed to. + Topics []string + + // LastUpdatedCycle is the last cycle counter value at which this record was updated. + // This is used to clean up old records' topics upon update. + LastUpdatedCycle uint64 +} + +var _ flow.Entity = (*SubscriptionRecordEntity)(nil) + +// ID returns the entity id of the subscription record, which is the hash of the peer id. +func (s SubscriptionRecordEntity) ID() flow.Identifier { + return s.entityId +} + +// Checksum returns the entity id of the subscription record, which is the hash of the peer id. +// It is of no use in the cache, but it is implemented to satisfy the flow.Entity interface. +func (s SubscriptionRecordEntity) Checksum() flow.Identifier { + return s.ID() +} diff --git a/network/p2p/scoring/registry.go b/network/p2p/scoring/registry.go index f30ef63646a..bc4d81443ce 100644 --- a/network/p2p/scoring/registry.go +++ b/network/p2p/scoring/registry.go @@ -9,6 +9,8 @@ import ( "github.com/onflow/flow-go/model/flow" "github.com/onflow/flow-go/module" + "github.com/onflow/flow-go/module/component" + "github.com/onflow/flow-go/module/irrecoverable" "github.com/onflow/flow-go/network/p2p" netcache "github.com/onflow/flow-go/network/p2p/cache" p2pmsg "github.com/onflow/flow-go/network/p2p/message" @@ -80,6 +82,7 @@ func DefaultGossipSubCtrlMsgPenaltyValue() GossipSubCtrlMsgPenaltyValue { // Similar to the GossipSub score, the application specific score is meant to be private to the local peer, and is not // shared with other peers in the network. type GossipSubAppSpecificScoreRegistry struct { + component.Component logger zerolog.Logger idProvider module.IdentityProvider // spamScoreCache currently only holds the control message misbehaviour penalty (spam related penalty).
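As a side note on the keying scheme of the subscription record above: the record is keyed by the hash of the peer id, computed once and stored in entityId so that ID() and Checksum() are plain field reads on every cache access. A hypothetical constructor (the actual code inlines this in AddTopicForPeer) would look like:

package internal

import (
	"github.com/libp2p/go-libp2p/core/peer"

	"github.com/onflow/flow-go/model/flow"
)

// newSubscriptionRecord is an illustrative constructor; flow.MakeID(pid) is the HeroCache key.
func newSubscriptionRecord(pid peer.ID, currentCycle uint64) SubscriptionRecordEntity {
	return SubscriptionRecordEntity{
		entityId:         flow.MakeID(pid),
		PeerID:           pid,
		Topics:           make([]string, 0),
		LastUpdatedCycle: currentCycle,
	}
}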
@@ -133,6 +136,26 @@ func NewGossipSubAppSpecificScoreRegistry(config *GossipSubAppSpecificScoreRegis idProvider: config.IdProvider, } + builder := component.NewComponentManagerBuilder() + builder.AddWorker(func(ctx irrecoverable.SignalerContext, ready component.ReadyFunc) { + reg.logger.Info().Msg("starting subscription validator") + reg.validator.Start(ctx) + select { + case <-ctx.Done(): + reg.logger.Warn().Msg("aborting subscription validator startup, context cancelled") + case <-reg.validator.Ready(): + reg.logger.Info().Msg("subscription validator started") + ready() + reg.logger.Info().Msg("subscription validator is ready") + } + + <-ctx.Done() + reg.logger.Info().Msg("stopping subscription validator") + <-reg.validator.Done() + reg.logger.Info().Msg("subscription validator stopped") + }) + reg.Component = builder.Build() + return reg } diff --git a/network/p2p/scoring/score_option.go b/network/p2p/scoring/score_option.go index 0ae676005cb..b3585d108fb 100644 --- a/network/p2p/scoring/score_option.go +++ b/network/p2p/scoring/score_option.go @@ -9,6 +9,8 @@ import ( "github.com/rs/zerolog" "github.com/onflow/flow-go/module" + "github.com/onflow/flow-go/module/component" + "github.com/onflow/flow-go/module/irrecoverable" "github.com/onflow/flow-go/module/metrics" "github.com/onflow/flow-go/network/channels" "github.com/onflow/flow-go/network/p2p" @@ -292,7 +294,9 @@ const ( ) // ScoreOption is a functional option for configuring the peer scoring system. +// TODO: rename it to ScoreManager. type ScoreOption struct { + component.Component logger zerolog.Logger peerScoreParams *pubsub.PeerScoreParams @@ -378,7 +382,7 @@ func (c *ScoreOptionConfig) OverrideDecayInterval(interval time.Duration) { } // NewScoreOption creates a new penalty option with the given configuration. -func NewScoreOption(cfg *ScoreOptionConfig) *ScoreOption { +func NewScoreOption(cfg *ScoreOptionConfig, provider p2p.SubscriptionProvider) *ScoreOption { throttledSampler := logging.BurstSampler(MaxDebugLogs, time.Second) logger := cfg.logger.With(). Str("module", "pubsub_score_option"). 
@@ -387,7 +391,7 @@ func NewScoreOption(cfg *ScoreOptionConfig) *ScoreOption { TraceSampler: throttledSampler, DebugSampler: throttledSampler, }) - validator := NewSubscriptionValidator() + validator := NewSubscriptionValidator(cfg.logger, provider) scoreRegistry := NewGossipSubAppSpecificScoreRegistry(&GossipSubAppSpecificScoreRegistryConfig{ Logger: logger, Penalty: DefaultGossipSubCtrlMsgPenaltyValue(), @@ -436,11 +440,26 @@ func NewScoreOption(cfg *ScoreOptionConfig) *ScoreOption { for _, topicParams := range cfg.topicParams { topicParams(s.peerScoreParams.Topics) } - return s -} -func (s *ScoreOption) SetSubscriptionProvider(provider *SubscriptionProvider) error { - return s.validator.RegisterSubscriptionProvider(provider) + s.Component = component.NewComponentManagerBuilder().AddWorker(func(ctx irrecoverable.SignalerContext, ready component.ReadyFunc) { + s.logger.Info().Msg("starting score registry") + scoreRegistry.Start(ctx) + select { + case <-ctx.Done(): + s.logger.Warn().Msg("stopping score registry; context done") + case <-scoreRegistry.Ready(): + s.logger.Info().Msg("score registry started") + ready() + s.logger.Info().Msg("score registry ready") + } + + <-ctx.Done() + s.logger.Info().Msg("stopping score registry") + <-scoreRegistry.Done() + s.logger.Info().Msg("score registry stopped") + }).Build() + + return s } func (s *ScoreOption) BuildFlowPubSubScoreOption() (*pubsub.PeerScoreParams, *pubsub.PeerScoreThresholds) { diff --git a/network/p2p/scoring/scoring_test.go b/network/p2p/scoring/scoring_test.go index c30736ae994..906cc0b2fc6 100644 --- a/network/p2p/scoring/scoring_test.go +++ b/network/p2p/scoring/scoring_test.go @@ -46,10 +46,6 @@ func (m *mockInspectorSuite) AddInvalidControlMessageConsumer(consumer p2p.Gossi func (m *mockInspectorSuite) ActiveClustersChanged(_ flow.ChainIDList) { // no-op } -func (m *mockInspectorSuite) SetTopicOracle(_ func() []string) error { - // no-op - return nil -} // newMockInspectorSuite creates a new mockInspectorSuite. // Args: @@ -97,7 +93,8 @@ func TestInvalidCtrlMsgScoringIntegration(t *testing.T) { module.GossipSubMetrics, metrics.HeroCacheMetricsFactory, flownet.NetworkingType, - module.IdentityProvider) (p2p.GossipSubInspectorSuite, error) { + module.IdentityProvider, + func() p2p.TopicProvider) (p2p.GossipSubInspectorSuite, error) { // override the gossipsub rpc inspector suite factory to return the mock inspector suite return inspectorSuite1, nil } diff --git a/network/p2p/scoring/subscriptionCache.go b/network/p2p/scoring/subscriptionCache.go new file mode 100644 index 00000000000..8eae60bd385 --- /dev/null +++ b/network/p2p/scoring/subscriptionCache.go @@ -0,0 +1,35 @@ +package scoring + +import "github.com/libp2p/go-libp2p/core/peer" + +// SubscriptionCache implements an in-memory cache that keeps track of the topics a peer is subscribed to. +// The cache is modeled abstracted to be used in update cycles, i.e., every regular interval of time, the cache is updated for +// all peers. +type SubscriptionCache interface { + // GetSubscribedTopics returns the list of topics a peer is subscribed to. + // Returns: + // - []string: the list of topics the peer is subscribed to. + // - bool: true if there is a record for the peer, false otherwise. + GetSubscribedTopics(pid peer.ID) ([]string, bool) + + // MoveToNextUpdateCycle moves the subscription cache to the next update cycle. + // A new update cycle is started when the subscription cache is first created, and then every time the subscription cache + // is updated. 
The update cycle is used to keep track of the last time the subscription cache was updated. It is used to + // implement a notion of time in the subscription cache. + // Returns: + // - uint64: the current update cycle. + MoveToNextUpdateCycle() uint64 + + // AddTopicForPeer appends a topic to the list of topics a peer is subscribed to. If the peer is not subscribed to any + // topics yet, a new record is created. + // If the last update cycle is older than the current cycle, the list of topics for the peer is first cleared, and then + // the topic is added to the list. This is to ensure that the list of topics for a peer is always up to date. + // Args: + // - pid: the peer id of the peer. + // - topic: the topic to add. + // Returns: + // - []string: the list of topics the peer is subscribed to after the update. + // - error: an error if the update failed; any returned error is an irrecoverable error and indicates a bug or misconfiguration. + // Implementation must be thread-safe. + AddTopicForPeer(pid peer.ID, topic string) ([]string, error) +} diff --git a/network/p2p/scoring/subscription_provider.go b/network/p2p/scoring/subscription_provider.go index 23aea760de1..4f6918a81a0 100644 --- a/network/p2p/scoring/subscription_provider.go +++ b/network/p2p/scoring/subscription_provider.go @@ -1,123 +1,160 @@ package scoring import ( - "sync" + "fmt" + "time" + "github.com/go-playground/validator/v10" "github.com/libp2p/go-libp2p/core/peer" "github.com/rs/zerolog" "go.uber.org/atomic" + "github.com/onflow/flow-go/module" + "github.com/onflow/flow-go/module/component" + "github.com/onflow/flow-go/module/irrecoverable" + "github.com/onflow/flow-go/module/metrics" "github.com/onflow/flow-go/network/p2p" + "github.com/onflow/flow-go/network/p2p/p2pconf" + "github.com/onflow/flow-go/network/p2p/p2plogging" + "github.com/onflow/flow-go/network/p2p/scoring/internal" + "github.com/onflow/flow-go/utils/logging" ) // SubscriptionProvider provides a list of topics a peer is subscribed to. type SubscriptionProvider struct { - logger zerolog.Logger - tp p2p.TopicProvider + component.Component + logger zerolog.Logger + topicProviderOracle func() p2p.TopicProvider - // allTopics is a list of all topics in the pubsub network // TODO: we should add an expiry time to this cache and clean up the cache periodically // to avoid leakage of stale topics. - peersByTopic sync.Map // map[topic]peers - peersByTopicUpdating sync.Map // whether a goroutine is already updating the list of peers for a topic + cache SubscriptionCache + + // idProvider translates the peer ids to flow ids. + idProvider module.IdentityProvider // allTopics is a list of all topics in the pubsub network that this node is subscribed to. - allTopicsLock sync.RWMutex // protects allTopics - allTopics []string // list of all topics in the pubsub network that this node has subscribed to. - allTopicsUpdate atomic.Bool // whether a goroutine is already updating the list of topics. + allTopicsUpdate atomic.Bool // whether a goroutine is already updating the list of topics + allTopicsUpdateInterval time.Duration // the interval for updating the list of topics in the pubsub network that this node has subscribed to. 
} -func NewSubscriptionProvider(logger zerolog.Logger, tp p2p.TopicProvider) *SubscriptionProvider { - return &SubscriptionProvider{ - logger: logger.With().Str("module", "subscription_provider").Logger(), - tp: tp, - allTopics: make([]string, 0), - } +type SubscriptionProviderConfig struct { + Logger zerolog.Logger `validate:"required"` + TopicProviderOracle func() p2p.TopicProvider `validate:"required"` + IdProvider module.IdentityProvider `validate:"required"` + HeroCacheMetricsFactory metrics.HeroCacheMetricsFactory `validate:"required"` + Params *p2pconf.SubscriptionProviderParameters `validate:"required"` } -// GetSubscribedTopics returns all the subscriptions of a peer within the pubsub network. -// Note that the current node can only see peer subscriptions to topics that it has also subscribed to -// e.g., if current node has subscribed to topics A and B, and peer1 has subscribed to topics A, B, and C, -// then GetSubscribedTopics(peer1) will return A and B. Since this node has not subscribed to topic C, -// it will not be able to query for other peers subscribed to topic C. -func (s *SubscriptionProvider) GetSubscribedTopics(pid peer.ID) []string { - topics := s.getAllTopics() +var _ p2p.SubscriptionProvider = (*SubscriptionProvider)(nil) - // finds the topics that this peer is subscribed to. - subscriptions := make([]string, 0) - for _, topic := range topics { - peers := s.getPeersByTopic(topic) - for _, p := range peers { - if p == pid { - subscriptions = append(subscriptions, topic) - } - } +func NewSubscriptionProvider(cfg *SubscriptionProviderConfig) (*SubscriptionProvider, error) { + if err := validator.New().Struct(cfg); err != nil { + return nil, fmt.Errorf("invalid subscription provider config: %w", err) } - return subscriptions -} + cacheMetrics := metrics.NewSubscriptionRecordCacheMetricsFactory(cfg.HeroCacheMetricsFactory) + cache := internal.NewSubscriptionRecordCache(cfg.Params.CacheSize, cfg.Logger, cacheMetrics) -// getAllTopics returns all the topics in the pubsub network that this node (peer) has subscribed to. -// Note that this method always returns the cached version of the subscribed topics while querying the -// pubsub network for the list of topics in a goroutine. Hence, the first call to this method always returns an empty -// list. -func (s *SubscriptionProvider) getAllTopics() []string { - go func() { - // TODO: refactor this to a component manager worker once we have a startable libp2p node. - if updateInProgress := s.allTopicsUpdate.CompareAndSwap(false, true); updateInProgress { - // another goroutine is already updating the list of topics - return - } + p := &SubscriptionProvider{ + logger: cfg.Logger.With().Str("module", "subscription_provider").Logger(), + topicProviderOracle: cfg.TopicProviderOracle, + allTopicsUpdateInterval: cfg.Params.SubscriptionUpdateInterval, + idProvider: cfg.IdProvider, + cache: cache, + } + + builder := component.NewComponentManagerBuilder() + p.Component = builder.AddWorker( + func(ctx irrecoverable.SignalerContext, ready component.ReadyFunc) { + ready() + p.logger.Debug(). + Float64("update_interval_seconds", cfg.Params.SubscriptionUpdateInterval.Seconds()). 
+ Msg("subscription provider started; starting update topics loop") + p.updateTopicsLoop(ctx) - allTopics := s.tp.GetTopics() - s.atomicUpdateAllTopics(allTopics) + <-ctx.Done() + p.logger.Debug().Msg("subscription provider stopped; stopping update topics loop") + }).Build() - // remove the update flag - s.allTopicsUpdate.Store(false) + return p, nil +} - s.logger.Trace().Msgf("all topics updated: %v", allTopics) - }() +func (s *SubscriptionProvider) updateTopicsLoop(ctx irrecoverable.SignalerContext) { + ticker := time.NewTicker(s.allTopicsUpdateInterval) + defer ticker.Stop() - s.allTopicsLock.RLock() - defer s.allTopicsLock.RUnlock() - return s.allTopics + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + if err := s.updateTopics(); err != nil { + ctx.Throw(fmt.Errorf("update loop failed: %w", err)) + return + } + } + } } -// getPeersByTopic returns all the peers subscribed to a topic. -// Note that this method always returns the cached version of the subscribed peers while querying the +// updateTopics returns all the topics in the pubsub network that this node (peer) has subscribed to. +// Note that this method always returns the cached version of the subscribed topics while querying the // pubsub network for the list of topics in a goroutine. Hence, the first call to this method always returns an empty // list. -// As this method is injected into GossipSub, it is vital that it never block the caller, otherwise it causes a -// deadlock on the GossipSub. -// Also note that, this peer itself should be subscribed to the topic, otherwise, it cannot find the list of peers -// subscribed to the topic in the pubsub network due to an inherent limitation of GossipSub. -func (s *SubscriptionProvider) getPeersByTopic(topic string) []peer.ID { - go func() { - // TODO: refactor this to a component manager worker once we have a startable libp2p node. - if _, updateInProgress := s.peersByTopicUpdating.LoadOrStore(topic, true); updateInProgress { - // another goroutine is already updating the list of peers for this topic - return - } +// Args: +// - ctx: the context of the caller. +// Returns: +// - error on failure to update the list of topics. The returned error is irrecoverable and indicates an exception. +func (s *SubscriptionProvider) updateTopics() error { + if updateInProgress := s.allTopicsUpdate.CompareAndSwap(false, true); updateInProgress { + // another goroutine is already updating the list of topics + s.logger.Trace().Msg("skipping topic update; another update is already in progress") + return nil + } - subscribedPeers := s.tp.ListPeers(topic) - s.peersByTopic.Store(topic, subscribedPeers) + // start of critical section; protected by updateInProgress atomic flag + allTopics := s.topicProviderOracle().GetTopics() + s.logger.Trace().Msgf("all topics updated: %v", allTopics) - // remove the update flag - s.peersByTopicUpdating.Delete(topic) + // increments the update cycle of the cache; so that the previous cache entries are invalidated upon a read or write. + s.cache.MoveToNextUpdateCycle() + for _, topic := range allTopics { + peers := s.topicProviderOracle().ListPeers(topic) - s.logger.Trace().Str("topic", topic).Msgf("peers by topic updated: %v", subscribedPeers) - }() + for _, p := range peers { + if _, authorized := s.idProvider.ByPeerID(p); !authorized { + // peer is not authorized (staked); hence it does not have a valid role in the network; and + // we skip the topic update for this peer (also avoiding sybil attacks on the cache). + s.logger.Debug(). 
+ Str("remote_peer_id", p2plogging.PeerId(p)). + Bool(logging.KeyNetworkingSecurity, true). + Msg("skipping topic update for unauthorized peer") + continue + } - peerId, ok := s.peersByTopic.Load(topic) - if !ok { - return make([]peer.ID, 0) + updatedTopics, err := s.cache.AddTopicForPeer(p, topic) + if err != nil { + // this is an irrecoverable error; hence, we crash the node. + return fmt.Errorf("failed to update topics for peer %s: %w", p, err) + } + s.logger.Debug(). + Str("remote_peer_id", p2plogging.PeerId(p)). + Strs("updated_topics", updatedTopics). + Msg("updated topics for peer") + } } - return peerId.([]peer.ID) + + // remove the update flag; end of critical section + s.allTopicsUpdate.Store(false) + return nil } -// atomicUpdateAllTopics updates the list of all topics in the pubsub network that this node has subscribed to. -func (s *SubscriptionProvider) atomicUpdateAllTopics(allTopics []string) { - s.allTopicsLock.Lock() - s.allTopics = allTopics - s.allTopicsLock.Unlock() +// GetSubscribedTopics returns all the subscriptions of a peer within the pubsub network. +func (s *SubscriptionProvider) GetSubscribedTopics(pid peer.ID) []string { + topics, ok := s.cache.GetSubscribedTopics(pid) + if !ok { + s.logger.Trace().Str("peer_id", p2plogging.PeerId(pid)).Msg("no topics found for peer") + return nil + } + return topics } diff --git a/network/p2p/scoring/subscription_provider_test.go b/network/p2p/scoring/subscription_provider_test.go index 25d4be455c8..cb3b45ecbd1 100644 --- a/network/p2p/scoring/subscription_provider_test.go +++ b/network/p2p/scoring/subscription_provider_test.go @@ -1,13 +1,21 @@ package scoring_test import ( + "context" "testing" "time" "github.com/libp2p/go-libp2p/core/peer" "github.com/stretchr/testify/assert" + mockery "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" - "github.com/onflow/flow-go/network/internal/p2pfixtures" + "github.com/onflow/flow-go/config" + "github.com/onflow/flow-go/model/flow" + "github.com/onflow/flow-go/module/irrecoverable" + "github.com/onflow/flow-go/module/metrics" + "github.com/onflow/flow-go/module/mock" + "github.com/onflow/flow-go/network/p2p" mockp2p "github.com/onflow/flow-go/network/p2p/mock" "github.com/onflow/flow-go/network/p2p/scoring" "github.com/onflow/flow-go/utils/slices" @@ -18,20 +26,47 @@ import ( // list of topics a peer is subscribed to. 
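The new updateTopics path relies on only two calls from the oracle's p2p.TopicProvider: GetTopics and ListPeers. For illustration, a minimal static provider (hypothetical; the tests below use the generated mock instead) makes that consumed surface explicit and can serve as a stand-in before a pubsub-backed provider is available:

package example

import (
	"github.com/libp2p/go-libp2p/core/peer"
)

// staticTopicProvider returns a fixed view of topics and their subscribers.
type staticTopicProvider struct {
	peersByTopic map[string][]peer.ID
}

// GetTopics returns the topics known to this provider.
func (s *staticTopicProvider) GetTopics() []string {
	topics := make([]string, 0, len(s.peersByTopic))
	for topic := range s.peersByTopic {
		topics = append(topics, topic)
	}
	return topics
}

// ListPeers returns the peers subscribed to the given topic.
func (s *staticTopicProvider) ListPeers(topic string) []peer.ID {
	return s.peersByTopic[topic]
}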
func TestSubscriptionProvider_GetSubscribedTopics(t *testing.T) { tp := mockp2p.NewTopicProvider(t) - sp := scoring.NewSubscriptionProvider(unittest.Logger(), tp) + cfg, err := config.DefaultConfig() + require.NoError(t, err) + idProvider := mock.NewIdentityProvider(t) + + // set a low update interval to speed up the test + cfg.NetworkConfig.SubscriptionProviderConfig.SubscriptionUpdateInterval = 100 * time.Millisecond + + sp, err := scoring.NewSubscriptionProvider(&scoring.SubscriptionProviderConfig{ + Logger: unittest.Logger(), + TopicProviderOracle: func() p2p.TopicProvider { + return tp + }, + Params: &cfg.NetworkConfig.SubscriptionProviderConfig, + HeroCacheMetricsFactory: metrics.NewNoopHeroCacheMetricsFactory(), + IdProvider: idProvider, + }) + require.NoError(t, err) tp.On("GetTopics").Return([]string{"topic1", "topic2", "topic3"}).Maybe() - peer1 := p2pfixtures.PeerIdFixture(t) - peer2 := p2pfixtures.PeerIdFixture(t) - peer3 := p2pfixtures.PeerIdFixture(t) + peer1 := unittest.PeerIdFixture(t) + peer2 := unittest.PeerIdFixture(t) + peer3 := unittest.PeerIdFixture(t) + + idProvider.On("ByPeerID", mockery.Anything).Return(unittest.IdentityFixture(), true).Maybe() // mock peers 1 and 2 subscribed to topic 1 (along with other random peers) - tp.On("ListPeers", "topic1").Return(append([]peer.ID{peer1, peer2}, p2pfixtures.PeerIdsFixture(t, 10)...)) + tp.On("ListPeers", "topic1").Return(append([]peer.ID{peer1, peer2}, unittest.PeerIdFixtures(t, 10)...)) // mock peers 2 and 3 subscribed to topic 2 (along with other random peers) - tp.On("ListPeers", "topic2").Return(append([]peer.ID{peer2, peer3}, p2pfixtures.PeerIdsFixture(t, 10)...)) + tp.On("ListPeers", "topic2").Return(append([]peer.ID{peer2, peer3}, unittest.PeerIdFixtures(t, 10)...)) // mock peers 1 and 3 subscribed to topic 3 (along with other random peers) - tp.On("ListPeers", "topic3").Return(append([]peer.ID{peer1, peer3}, p2pfixtures.PeerIdsFixture(t, 10)...)) + tp.On("ListPeers", "topic3").Return(append([]peer.ID{peer1, peer3}, unittest.PeerIdFixtures(t, 10)...)) + + ctx, cancel := context.WithCancel(context.Background()) + defer func() { + cancel() + unittest.RequireCloseBefore(t, sp.Done(), 1*time.Second, "subscription provider did not stop in time") + }() + signalerCtx := irrecoverable.NewMockSignalerContext(t, ctx) + sp.Start(signalerCtx) + unittest.RequireCloseBefore(t, sp.Ready(), 1*time.Second, "subscription provider did not start in time") // As the calls to the TopicProvider are asynchronous, we need to wait for the goroutines to finish. assert.Eventually(t, func() bool { @@ -46,3 +81,76 @@ func TestSubscriptionProvider_GetSubscribedTopics(t *testing.T) { return slices.AreStringSlicesEqual([]string{"topic2", "topic3"}, sp.GetSubscribedTopics(peer3)) }, 1*time.Second, 100*time.Millisecond) } + +// TestSubscriptionProvider_GetSubscribedTopics_SkippingUnknownPeers tests that the SubscriptionProvider skips +// unknown peers when returning the list of topics a peer is subscribed to. In other words, if a peer is unknown, +// the SubscriptionProvider should not keep track of its subscriptions. 
+func TestSubscriptionProvider_GetSubscribedTopics_SkippingUnknownPeers(t *testing.T) { + tp := mockp2p.NewTopicProvider(t) + cfg, err := config.DefaultConfig() + require.NoError(t, err) + idProvider := mock.NewIdentityProvider(t) + + // set a low update interval to speed up the test + cfg.NetworkConfig.SubscriptionProviderConfig.SubscriptionUpdateInterval = 100 * time.Millisecond + + sp, err := scoring.NewSubscriptionProvider(&scoring.SubscriptionProviderConfig{ + Logger: unittest.Logger(), + TopicProviderOracle: func() p2p.TopicProvider { + return tp + }, + Params: &cfg.NetworkConfig.SubscriptionProviderConfig, + HeroCacheMetricsFactory: metrics.NewNoopHeroCacheMetricsFactory(), + IdProvider: idProvider, + }) + require.NoError(t, err) + + tp.On("GetTopics").Return([]string{"topic1", "topic2", "topic3"}).Maybe() + + peer1 := unittest.PeerIdFixture(t) + peer2 := unittest.PeerIdFixture(t) + peer3 := unittest.PeerIdFixture(t) + + // mock peers 1 and 2 as a known peer; peer 3 as an unknown peer + idProvider.On("ByPeerID", mockery.Anything). + Return(func(pid peer.ID) *flow.Identity { + if pid == peer1 || pid == peer2 { + return unittest.IdentityFixture() + } + return nil + }, func(pid peer.ID) bool { + if pid == peer1 || pid == peer2 { + return true + } + return false + }).Maybe() + + // mock peers 1 and 2 subscribed to topic 1 (along with other random peers) + tp.On("ListPeers", "topic1").Return(append([]peer.ID{peer1, peer2}, unittest.PeerIdFixtures(t, 10)...)) + // mock peers 2 and 3 subscribed to topic 2 (along with other random peers) + tp.On("ListPeers", "topic2").Return(append([]peer.ID{peer2, peer3}, unittest.PeerIdFixtures(t, 10)...)) + // mock peers 1 and 3 subscribed to topic 3 (along with other random peers) + tp.On("ListPeers", "topic3").Return(append([]peer.ID{peer1, peer3}, unittest.PeerIdFixtures(t, 10)...)) + + ctx, cancel := context.WithCancel(context.Background()) + defer func() { + cancel() + unittest.RequireCloseBefore(t, sp.Done(), 1*time.Second, "subscription provider did not stop in time") + }() + signalerCtx := irrecoverable.NewMockSignalerContext(t, ctx) + sp.Start(signalerCtx) + unittest.RequireCloseBefore(t, sp.Ready(), 1*time.Second, "subscription provider did not start in time") + + // As the calls to the TopicProvider are asynchronous, we need to wait for the goroutines to finish. 
+ // peer 1 should be eventually subscribed to topic 1 and topic 3; while peer 3 should not have any subscriptions record since it is unknown + assert.Eventually(t, func() bool { + return slices.AreStringSlicesEqual([]string{"topic1", "topic3"}, sp.GetSubscribedTopics(peer1)) && + slices.AreStringSlicesEqual([]string{}, sp.GetSubscribedTopics(peer3)) + }, 1*time.Second, 100*time.Millisecond) + + // peer 2 should be eventually subscribed to topic 1 and topic 2; while peer 3 should not have any subscriptions record since it is unknown + assert.Eventually(t, func() bool { + return slices.AreStringSlicesEqual([]string{"topic1", "topic2"}, sp.GetSubscribedTopics(peer2)) && + slices.AreStringSlicesEqual([]string{}, sp.GetSubscribedTopics(peer3)) + }, 1*time.Second, 100*time.Millisecond) +} diff --git a/network/p2p/scoring/subscription_validator.go b/network/p2p/scoring/subscription_validator.go index fbffe27752a..8c3fc048168 100644 --- a/network/p2p/scoring/subscription_validator.go +++ b/network/p2p/scoring/subscription_validator.go @@ -1,47 +1,55 @@ package scoring import ( - "fmt" - "github.com/libp2p/go-libp2p/core/peer" + "github.com/rs/zerolog" "github.com/onflow/flow-go/model/flow" + "github.com/onflow/flow-go/module/component" + "github.com/onflow/flow-go/module/irrecoverable" "github.com/onflow/flow-go/network/p2p" + "github.com/onflow/flow-go/network/p2p/p2plogging" p2putils "github.com/onflow/flow-go/network/p2p/utils" ) // SubscriptionValidator validates that a peer is subscribed to topics that it is allowed to subscribe to. // It is used to penalize peers that subscribe to topics that they are not allowed to subscribe to in GossipSub. type SubscriptionValidator struct { + component.Component + logger zerolog.Logger subscriptionProvider p2p.SubscriptionProvider } -func NewSubscriptionValidator() *SubscriptionValidator { - return &SubscriptionValidator{} -} +func NewSubscriptionValidator(logger zerolog.Logger, provider p2p.SubscriptionProvider) *SubscriptionValidator { + v := &SubscriptionValidator{ + logger: logger.With().Str("component", "subscription_validator").Logger(), + subscriptionProvider: provider, + } -var _ p2p.SubscriptionValidator = (*SubscriptionValidator)(nil) + v.Component = component.NewComponentManagerBuilder(). + AddWorker(func(ctx irrecoverable.SignalerContext, ready component.ReadyFunc) { + logger.Debug().Msg("starting subscription validator") + v.subscriptionProvider.Start(ctx) + select { + case <-ctx.Done(): + logger.Debug().Msg("subscription validator is stopping") + case <-v.subscriptionProvider.Ready(): + logger.Debug().Msg("subscription validator started") + ready() + logger.Debug().Msg("subscription validator is ready") + } -// RegisterSubscriptionProvider registers the subscription provider with the subscription validator. -// This follows a dependency injection pattern. -// Args: -// -// provider: the subscription provider -// -// Returns: -// -// error: if the subscription provider is nil, an error is returned. The error is irrecoverable, i.e., -// it indicates an illegal state in the execution of the code. We expect this error only when there is a bug in the code. -// Such errors should lead to a crash of the node. 
-func (v *SubscriptionValidator) RegisterSubscriptionProvider(provider p2p.SubscriptionProvider) error { - if v.subscriptionProvider != nil { - return fmt.Errorf("subscription provider already registered") - } - v.subscriptionProvider = provider + <-ctx.Done() + logger.Debug().Msg("subscription validator is stopping") + <-v.subscriptionProvider.Done() + logger.Debug().Msg("subscription validator stopped") + }).Build() - return nil + return v } +var _ p2p.SubscriptionValidator = (*SubscriptionValidator)(nil) + // CheckSubscribedToAllowedTopics checks if a peer is subscribed to topics that it is allowed to subscribe to. // Args: // @@ -53,7 +61,10 @@ func (v *SubscriptionValidator) RegisterSubscriptionProvider(provider p2p.Subscr // The error is benign, i.e., it does not indicate an illegal state in the execution of the code. We expect this error // when there are malicious peers in the network. But such errors should not lead to a crash of the node. func (v *SubscriptionValidator) CheckSubscribedToAllowedTopics(pid peer.ID, role flow.Role) error { + lg := v.logger.With().Str("remote_peer_id", p2plogging.PeerId(pid)).Logger() + topics := v.subscriptionProvider.GetSubscribedTopics(pid) + lg.Trace().Strs("topics", topics).Msg("checking subscription for remote peer id") for _, topic := range topics { if !p2putils.AllowedSubscription(role, topic) { @@ -61,5 +72,6 @@ func (v *SubscriptionValidator) CheckSubscribedToAllowedTopics(pid peer.ID, role } } + lg.Trace().Msg("subscription is valid") return nil } diff --git a/network/p2p/scoring/subscription_validator_test.go b/network/p2p/scoring/subscription_validator_test.go index 338d26d67c5..770f74cf146 100644 --- a/network/p2p/scoring/subscription_validator_test.go +++ b/network/p2p/scoring/subscription_validator_test.go @@ -7,6 +7,7 @@ import ( "testing" "time" + "github.com/onflow/flow-go/config" "github.com/onflow/flow-go/network/message" "github.com/onflow/flow-go/network/p2p" p2ptest "github.com/onflow/flow-go/network/p2p/test" @@ -32,12 +33,10 @@ import ( // any topic, the subscription validator returns no error. func TestSubscriptionValidator_NoSubscribedTopic(t *testing.T) { sp := mockp2p.NewSubscriptionProvider(t) - - sv := scoring.NewSubscriptionValidator() - require.NoError(t, sv.RegisterSubscriptionProvider(sp)) + sv := scoring.NewSubscriptionValidator(unittest.Logger(), sp) // mocks peer 1 not subscribed to any topic. - peer1 := p2pfixtures.PeerIdFixture(t) + peer1 := unittest.PeerIdFixture(t) sp.On("GetSubscribedTopics", peer1).Return([]string{}) // as peer 1 has not subscribed to any topic, the subscription validator should return no error regardless of the @@ -51,11 +50,10 @@ func TestSubscriptionValidator_NoSubscribedTopic(t *testing.T) { // topic, the subscription validator returns an error. func TestSubscriptionValidator_UnknownChannel(t *testing.T) { sp := mockp2p.NewSubscriptionProvider(t) - sv := scoring.NewSubscriptionValidator() - require.NoError(t, sv.RegisterSubscriptionProvider(sp)) + sv := scoring.NewSubscriptionValidator(unittest.Logger(), sp) // mocks peer 1 not subscribed to an unknown topic. - peer1 := p2pfixtures.PeerIdFixture(t) + peer1 := unittest.PeerIdFixture(t) sp.On("GetSubscribedTopics", peer1).Return([]string{"unknown-topic-1", "unknown-topic-2"}) // as peer 1 has subscribed to unknown topics, the subscription validator should return an error @@ -71,11 +69,10 @@ func TestSubscriptionValidator_UnknownChannel(t *testing.T) { // topics based on its Flow protocol role, the subscription validator returns no error. 
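Across this diff the same worker shape recurs (the registry starts the subscription validator, the validator starts the provider, and ScoreOption starts the registry): start the child component, become ready only once the child is ready, and hold shutdown until the child has fully stopped. A generic sketch of that parent-wraps-child pattern, with a hypothetical helper name:

package example

import (
	"github.com/onflow/flow-go/module/component"
	"github.com/onflow/flow-go/module/irrecoverable"
)

// wrapChild builds a component whose lifecycle mirrors that of the given child component.
func wrapChild(child component.Component) component.Component {
	return component.NewComponentManagerBuilder().
		AddWorker(func(ctx irrecoverable.SignalerContext, ready component.ReadyFunc) {
			child.Start(ctx)
			select {
			case <-ctx.Done():
				// shutting down before the child ever became ready
			case <-child.Ready():
				// only report readiness once the child is ready
				ready()
			}

			<-ctx.Done()
			// block shutdown completion until the child has fully stopped
			<-child.Done()
		}).Build()
}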
func TestSubscriptionValidator_ValidSubscriptions(t *testing.T) { sp := mockp2p.NewSubscriptionProvider(t) - sv := scoring.NewSubscriptionValidator() - require.NoError(t, sv.RegisterSubscriptionProvider(sp)) + sv := scoring.NewSubscriptionValidator(unittest.Logger(), sp) for _, role := range flow.Roles() { - peerId := p2pfixtures.PeerIdFixture(t) + peerId := unittest.PeerIdFixture(t) // allowed channels for the role excluding the test channels. allowedChannels := channels.ChannelsByRole(role).ExcludePattern(regexp.MustCompile("^(test).*")) sporkID := unittest.IdentifierFixture() @@ -102,8 +99,7 @@ func TestSubscriptionValidator_ValidSubscriptions(t *testing.T) { // is no longer true. func TestSubscriptionValidator_SubscribeToAllTopics(t *testing.T) { sp := mockp2p.NewSubscriptionProvider(t) - sv := scoring.NewSubscriptionValidator() - require.NoError(t, sv.RegisterSubscriptionProvider(sp)) + sv := scoring.NewSubscriptionValidator(unittest.Logger(), sp) allChannels := channels.Channels().ExcludePattern(regexp.MustCompile("^(test).*")) sporkID := unittest.IdentifierFixture() @@ -113,7 +109,7 @@ func TestSubscriptionValidator_SubscribeToAllTopics(t *testing.T) { } for _, role := range flow.Roles() { - peerId := p2pfixtures.PeerIdFixture(t) + peerId := unittest.PeerIdFixture(t) sp.On("GetSubscribedTopics", peerId).Return(allTopics) err := sv.CheckSubscribedToAllowedTopics(peerId, role) require.Error(t, err, role) @@ -125,11 +121,10 @@ func TestSubscriptionValidator_SubscribeToAllTopics(t *testing.T) { // topics based on its Flow protocol role, the subscription validator returns an error. func TestSubscriptionValidator_InvalidSubscriptions(t *testing.T) { sp := mockp2p.NewSubscriptionProvider(t) - sv := scoring.NewSubscriptionValidator() - require.NoError(t, sv.RegisterSubscriptionProvider(sp)) + sv := scoring.NewSubscriptionValidator(unittest.Logger(), sp) for _, role := range flow.Roles() { - peerId := p2pfixtures.PeerIdFixture(t) + peerId := unittest.PeerIdFixture(t) unauthorizedChannels := channels.Channels(). // all channels ExcludeChannels(channels.ChannelsByRole(role)). // excluding the channels for the role ExcludePattern(regexp.MustCompile("^(test).*")) // excluding the test channels. 
@@ -172,6 +167,11 @@ func TestSubscriptionValidator_Integration(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) signalerCtx := irrecoverable.NewMockSignalerContext(t, ctx) + cfg, err := config.DefaultConfig() + require.NoError(t, err) + // set a low update interval to speed up the test + cfg.NetworkConfig.SubscriptionProviderConfig.SubscriptionUpdateInterval = 100 * time.Millisecond + sporkId := unittest.IdentifierFixture() idProvider := mock.NewIdentityProvider(t) @@ -179,6 +179,7 @@ func TestSubscriptionValidator_Integration(t *testing.T) { conNode, conId := p2ptest.NodeFixture(t, sporkId, t.Name(), idProvider, p2ptest.WithLogger(unittest.Logger()), + p2ptest.OverrideFlowConfig(cfg), p2ptest.EnablePeerScoringWithOverride(p2p.PeerScoringConfigNoOverride), p2ptest.WithRole(flow.RoleConsensus)) @@ -186,12 +187,14 @@ func TestSubscriptionValidator_Integration(t *testing.T) { verNode1, verId1 := p2ptest.NodeFixture(t, sporkId, t.Name(), idProvider, p2ptest.WithLogger(unittest.Logger()), + p2ptest.OverrideFlowConfig(cfg), p2ptest.EnablePeerScoringWithOverride(p2p.PeerScoringConfigNoOverride), p2ptest.WithRole(flow.RoleVerification)) verNode2, verId2 := p2ptest.NodeFixture(t, sporkId, t.Name(), idProvider, p2ptest.WithLogger(unittest.Logger()), + p2ptest.OverrideFlowConfig(cfg), p2ptest.EnablePeerScoringWithOverride(p2p.PeerScoringConfigNoOverride), p2ptest.WithRole(flow.RoleVerification)) diff --git a/network/p2p/scoring/utils_test.go b/network/p2p/scoring/utils_test.go index 5a458e1a730..1da111bf748 100644 --- a/network/p2p/scoring/utils_test.go +++ b/network/p2p/scoring/utils_test.go @@ -6,14 +6,13 @@ import ( "github.com/stretchr/testify/require" "github.com/onflow/flow-go/module/mock" - "github.com/onflow/flow-go/network/internal/p2pfixtures" "github.com/onflow/flow-go/network/p2p/scoring" "github.com/onflow/flow-go/utils/unittest" ) // TestHasValidIdentity_Unknown tests that when a peer has an unknown identity, the HasValidIdentity returns InvalidPeerIDError func TestHasValidIdentity_Unknown(t *testing.T) { - peerId := p2pfixtures.PeerIdFixture(t) + peerId := unittest.PeerIdFixture(t) idProvider := mock.NewIdentityProvider(t) idProvider.On("ByPeerID", peerId).Return(nil, false) @@ -30,7 +29,7 @@ func TestHasValidIdentity_Ejected(t *testing.T) { ejectedIdentity := unittest.IdentityFixture() ejectedIdentity.Ejected = true - peerId := p2pfixtures.PeerIdFixture(t) + peerId := unittest.PeerIdFixture(t) idProvider.On("ByPeerID", peerId).Return(ejectedIdentity, true) identity, err := scoring.HasValidFlowIdentity(idProvider, peerId) @@ -45,7 +44,7 @@ func TestHasValidIdentity_Valid(t *testing.T) { idProvider := mock.NewIdentityProvider(t) trueID := unittest.IdentityFixture() - peerId := p2pfixtures.PeerIdFixture(t) + peerId := unittest.PeerIdFixture(t) idProvider.On("ByPeerID", peerId).Return(trueID, true) identity, err := scoring.HasValidFlowIdentity(idProvider, peerId) diff --git a/network/p2p/stream.go b/network/p2p/stream.go index 7b73187b100..a012ef8926c 100644 --- a/network/p2p/stream.go +++ b/network/p2p/stream.go @@ -12,14 +12,8 @@ import ( // it can create libp2p streams with finer granularity. type StreamFactory interface { SetStreamHandler(protocol.ID, network.StreamHandler) - // Connect connects host to peer with peerAddrInfo. - // All errors returned from this function can be considered benign. We expect the following errors during normal operations: - // - ErrSecurityProtocolNegotiationFailed this indicates there was an issue upgrading the connection. 
- // - ErrGaterDisallowedConnection this indicates the connection was disallowed by the gater. - // - There may be other unexpected errors from libp2p but they should be considered benign. - Connect(context.Context, peer.AddrInfo) error // NewStream creates a new stream on the libp2p host. // Expected errors during normal operations: // - ErrProtocolNotSupported this indicates remote node is running on a different spork. - NewStream(context.Context, peer.ID, ...protocol.ID) (network.Stream, error) + NewStream(context.Context, peer.ID, protocol.ID) (network.Stream, error) } diff --git a/network/p2p/subscription.go b/network/p2p/subscription.go index 9d4a117d0bc..99212b566d1 100644 --- a/network/p2p/subscription.go +++ b/network/p2p/subscription.go @@ -7,10 +7,12 @@ import ( "github.com/libp2p/go-libp2p/core/peer" "github.com/onflow/flow-go/model/flow" + "github.com/onflow/flow-go/module/component" ) // SubscriptionProvider provides a list of topics a peer is subscribed to. type SubscriptionProvider interface { + component.Component // GetSubscribedTopics returns all the subscriptions of a peer within the pubsub network. // Note that the current peer must be subscribed to the topic for it to the same topics in order // to query for other peers, e.g., if current peer has subscribed to topics A and B, and peer1 @@ -22,9 +24,7 @@ type SubscriptionProvider interface { // SubscriptionValidator validates the subscription of a peer to a topic. // It is used to ensure that a peer is only subscribed to topics that it is allowed to subscribe to. type SubscriptionValidator interface { - // RegisterSubscriptionProvider registers the subscription provider with the subscription validator. - // If there is a subscription provider already registered, it will be replaced by the new one. - RegisterSubscriptionProvider(provider SubscriptionProvider) error + component.Component // CheckSubscribedToAllowedTopics checks if a peer is subscribed to topics that it is allowed to subscribe to. // Args: // pid: the peer ID of the peer to check diff --git a/network/p2p/test/fixtures.go b/network/p2p/test/fixtures.go index c57928bda5b..ed9520d3ae1 100644 --- a/network/p2p/test/fixtures.go +++ b/network/p2p/test/fixtures.go @@ -17,6 +17,7 @@ import ( "github.com/libp2p/go-libp2p/core/routing" discoveryBackoff "github.com/libp2p/go-libp2p/p2p/discovery/backoff" "github.com/rs/zerolog" + mockery "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "github.com/onflow/flow-go/config" @@ -32,6 +33,7 @@ import ( "github.com/onflow/flow-go/network/p2p" "github.com/onflow/flow-go/network/p2p/connection" p2pdht "github.com/onflow/flow-go/network/p2p/dht" + mockp2p "github.com/onflow/flow-go/network/p2p/mock" "github.com/onflow/flow-go/network/p2p/p2pbuilder" p2pconfig "github.com/onflow/flow-go/network/p2p/p2pbuilder/config" "github.com/onflow/flow-go/network/p2p/tracer" @@ -67,18 +69,19 @@ func NetworkingKeyFixtures(t *testing.T) crypto.PrivateKey { // NodeFixture is a test fixture that creates a single libp2p node with the given key, spork id, and options. // It returns the node and its identity. 
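With Connect dropped and NewStream narrowed to a single protocol.ID, a conforming StreamFactory can be a thin adapter over the libp2p host. A minimal sketch (not the repository's actual implementation; the mapping of negotiation failures to ErrProtocolNotSupported required by the interface contract is omitted):

package example

import (
	"context"

	"github.com/libp2p/go-libp2p/core/host"
	"github.com/libp2p/go-libp2p/core/network"
	"github.com/libp2p/go-libp2p/core/peer"
	"github.com/libp2p/go-libp2p/core/protocol"
)

// hostStreamFactory adapts a libp2p host to the narrowed StreamFactory surface.
type hostStreamFactory struct {
	h host.Host
}

// SetStreamHandler registers the handler for inbound streams on the given protocol.
func (f *hostStreamFactory) SetStreamHandler(pid protocol.ID, handler network.StreamHandler) {
	f.h.SetStreamHandler(pid, handler)
}

// NewStream opens an outbound stream to the peer using exactly one protocol ID.
func (f *hostStreamFactory) NewStream(ctx context.Context, p peer.ID, pid protocol.ID) (network.Stream, error) {
	// a real implementation would translate a protocol mismatch (e.g. a remote node on a
	// different spork) into ErrProtocolNotSupported here.
	return f.h.NewStream(ctx, p, pid)
}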
-func NodeFixture( - t *testing.T, sporkID flow.Identifier, dhtPrefix string, idProvider module.IdentityProvider, opts ...NodeFixtureParameterOption, -) (p2p.LibP2PNode, flow.Identity) { +func NodeFixture(t *testing.T, + sporkID flow.Identifier, + dhtPrefix string, + idProvider module.IdentityProvider, + opts ...NodeFixtureParameterOption) (p2p.LibP2PNode, flow.Identity) { defaultFlowConfig, err := config.DefaultConfig() require.NoError(t, err) logger := unittest.Logger().Level(zerolog.WarnLevel) require.NotNil(t, idProvider) - connectionGater := NewConnectionGater( - idProvider, func(p peer.ID) error { - return nil - }) + connectionGater := NewConnectionGater(idProvider, func(p peer.ID) error { + return nil + }) require.NotNil(t, connectionGater) meshTracerCfg := &tracer.GossipSubMeshTracerConfig{ @@ -126,10 +129,7 @@ func NodeFixture( logger = parameters.Logger.With().Hex("node_id", logging.ID(identity.NodeID)).Logger() - connManager, err := connection.NewConnManager( - logger, - parameters.MetricsCfg.Metrics, - &defaultFlowConfig.NetworkConfig.ConnectionManagerConfig) + connManager, err := connection.NewConnManager(logger, parameters.MetricsCfg.Metrics, ¶meters.FlowConfig.NetworkConfig.ConnectionManagerConfig) require.NoError(t, err) builder := p2pbuilder.NewNodeBuilder(logger, @@ -139,9 +139,10 @@ func NodeFixture( parameters.Key, sporkID, parameters.IdProvider, - &defaultFlowConfig.NetworkConfig.ResourceManager, + ¶meters.FlowConfig.NetworkConfig.ResourceManager, ¶meters.FlowConfig.NetworkConfig.GossipSubRPCInspectorsConfig, parameters.PeerManagerConfig, + ¶meters.FlowConfig.NetworkConfig.GossipSubConfig.SubscriptionProviderConfig, &p2p.DisallowListCacheConfig{ MaxSize: uint32(1000), Metrics: metrics.NewNoopCollector(), @@ -160,16 +161,14 @@ func NodeFixture( // Only access and execution nodes need to run DHT; // Access nodes and execution nodes need DHT to run a blob service. // Moreover, access nodes run a DHT to let un-staked (public) access nodes find each other on the public network. - builder.SetRoutingSystem( - func(ctx context.Context, host host.Host) (routing.Routing, error) { - return p2pdht.NewDHT( - ctx, - host, - protocol.ID(protocols.FlowDHTProtocolIDPrefix+sporkID.String()+"/"+dhtPrefix), - logger, - parameters.MetricsCfg.Metrics, - parameters.DhtOptions...) - }) + builder.SetRoutingSystem(func(ctx context.Context, host host.Host) (routing.Routing, error) { + return p2pdht.NewDHT(ctx, + host, + protocol.ID(protocols.FlowDHTProtocolIDPrefix+sporkID.String()+"/"+dhtPrefix), + logger, + parameters.MetricsCfg.Metrics, + parameters.DhtOptions...) + }) } if parameters.GossipSubRpcInspectorSuiteFactory != nil { @@ -432,15 +431,12 @@ func WithZeroJitterAndZeroBackoff(t *testing.T) func(*p2pconfig.PeerManagerConfi // NodesFixture is a test fixture that creates a number of libp2p nodes with the given callback function for stream handling. // It returns the nodes and their identities. 
-func NodesFixture( - t *testing.T, +func NodesFixture(t *testing.T, sporkID flow.Identifier, dhtPrefix string, count int, idProvider module.IdentityProvider, - opts ...NodeFixtureParameterOption) ( - []p2p.LibP2PNode, - flow.IdentityList) { + opts ...NodeFixtureParameterOption) ([]p2p.LibP2PNode, flow.IdentityList) { var nodes []p2p.LibP2PNode // creating nodes @@ -563,23 +559,22 @@ func TryConnectionAndEnsureConnected(t *testing.T, ctx context.Context, nodes [] // - tick: the tick duration // - timeout: the timeout duration func RequireConnectedEventually(t *testing.T, nodes []p2p.LibP2PNode, tick time.Duration, timeout time.Duration) { - require.Eventually( - t, func() bool { - for _, node := range nodes { - for _, other := range nodes { - if node == other { - continue - } - if node.Host().Network().Connectedness(other.ID()) != network.Connected { - return false - } - if len(node.Host().Network().ConnsToPeer(other.ID())) == 0 { - return false - } + require.Eventually(t, func() bool { + for _, node := range nodes { + for _, other := range nodes { + if node == other { + continue + } + if node.Host().Network().Connectedness(other.ID()) != network.Connected { + return false + } + if len(node.Host().Network().ConnsToPeer(other.ID())) == 0 { + return false } } - return true - }, timeout, tick) + } + return true + }, timeout, tick) } // RequireEventuallyNotConnected ensures eventually that the given groups of nodes are not connected to each other. @@ -589,26 +584,20 @@ func RequireConnectedEventually(t *testing.T, nodes []p2p.LibP2PNode, tick time. // - groupB: the second group of nodes // - tick: the tick duration // - timeout: the timeout duration -func RequireEventuallyNotConnected( - t *testing.T, - groupA []p2p.LibP2PNode, - groupB []p2p.LibP2PNode, - tick time.Duration, - timeout time.Duration) { - require.Eventually( - t, func() bool { - for _, node := range groupA { - for _, other := range groupB { - if node.Host().Network().Connectedness(other.ID()) == network.Connected { - return false - } - if len(node.Host().Network().ConnsToPeer(other.ID())) > 0 { - return false - } +func RequireEventuallyNotConnected(t *testing.T, groupA []p2p.LibP2PNode, groupB []p2p.LibP2PNode, tick time.Duration, timeout time.Duration) { + require.Eventually(t, func() bool { + for _, node := range groupA { + for _, other := range groupB { + if node.Host().Network().Connectedness(other.ID()) == network.Connected { + return false + } + if len(node.Host().Network().ConnsToPeer(other.ID())) > 0 { + return false } } - return true - }, timeout, tick) + } + return true + }, timeout, tick) } // EnsureStreamCreationInBothDirections ensure that between each pair of nodes in the given list, a stream is created in both directions. @@ -619,12 +608,11 @@ func EnsureStreamCreationInBothDirections(t *testing.T, ctx context.Context, nod continue } // stream creation should pass without error - err := this.OpenProtectedStream( - ctx, other.ID(), t.Name(), func(stream network.Stream) error { - // do nothing - require.NotNil(t, stream) - return nil - }) + err := this.OpenProtectedStream(ctx, other.ID(), t.Name(), func(stream network.Stream) error { + // do nothing + require.NotNil(t, stream) + return nil + }) require.NoError(t, err) } @@ -642,13 +630,7 @@ func EnsureStreamCreationInBothDirections(t *testing.T, ctx context.Context, nod // // Note-1: this function assumes a timeout of 5 seconds for each message to be received. 
// Note-2: TryConnectionAndEnsureConnected() must be called to connect all nodes before calling this function. -func EnsurePubsubMessageExchange( - t *testing.T, - ctx context.Context, - nodes []p2p.LibP2PNode, - topic channels.Topic, - count int, - messageFactory func() interface{}) { +func EnsurePubsubMessageExchange(t *testing.T, ctx context.Context, nodes []p2p.LibP2PNode, topic channels.Topic, count int, messageFactory func() interface{}) { subs := make([]p2p.Subscription, len(nodes)) for i, node := range nodes { ps, err := node.Subscribe(topic, validator.TopicValidator(unittest.Logger(), unittest.AllowAllPeerFilter())) @@ -692,16 +674,14 @@ func EnsurePubsubMessageExchange( // - topic: the topic to exchange messages on. // - count: the number of messages to exchange from `sender` to `receiver`. // - messageFactory: a function that creates a unique message to be published by the node. -func EnsurePubsubMessageExchangeFromNode( - t *testing.T, +func EnsurePubsubMessageExchangeFromNode(t *testing.T, ctx context.Context, sender p2p.LibP2PNode, receiverNode p2p.LibP2PNode, receiverIdentifier flow.Identifier, topic channels.Topic, count int, - messageFactory func() interface{}, -) { + messageFactory func() interface{}) { _, err := sender.Subscribe(topic, validator.TopicValidator(unittest.Logger(), unittest.AllowAllPeerFilter())) require.NoError(t, err) @@ -747,16 +727,14 @@ func EnsureNotConnectedBetweenGroups(t *testing.T, ctx context.Context, groupA [ // - topic: the topic to exchange messages on. // - count: the number of messages to exchange from each node. // - messageFactory: a function that creates a unique message to be published by the node. -func EnsureNoPubsubMessageExchange( - t *testing.T, +func EnsureNoPubsubMessageExchange(t *testing.T, ctx context.Context, from []p2p.LibP2PNode, to []p2p.LibP2PNode, toIdentifiers flow.IdentifierList, topic channels.Topic, count int, - messageFactory func() interface{}, -) { + messageFactory func() interface{}) { subs := make([]p2p.Subscription, len(to)) tv := validator.TopicValidator(unittest.Logger(), unittest.AllowAllPeerFilter()) var err error @@ -811,8 +789,7 @@ func EnsureNoPubsubMessageExchange( // - topic: pubsub topic- no message should be exchanged on this topic. // - count: number of messages to be exchanged- no message should be exchanged. // - messageFactory: function to create a unique message to be published by the node. -func EnsureNoPubsubExchangeBetweenGroups( - t *testing.T, +func EnsureNoPubsubExchangeBetweenGroups(t *testing.T, ctx context.Context, groupANodes []p2p.LibP2PNode, groupAIdentifiers flow.IdentifierList, @@ -820,8 +797,7 @@ func EnsureNoPubsubExchangeBetweenGroups( groupBIdentifiers flow.IdentifierList, topic channels.Topic, count int, - messageFactory func() interface{}, -) { + messageFactory func() interface{}) { // ensure no message exchange from group A to group B EnsureNoPubsubMessageExchange(t, ctx, groupANodes, groupBNodes, groupBIdentifiers, topic, count, messageFactory) // ensure no message exchange from group B to group A @@ -846,9 +822,21 @@ func PeerIdSliceFixture(t *testing.T, n int) peer.IDSlice { // NewConnectionGater creates a new connection gater for testing with given allow listing filter. 
func NewConnectionGater(idProvider module.IdentityProvider, allowListFilter p2p.PeerFilter) p2p.ConnectionGater { filters := []p2p.PeerFilter{allowListFilter} - return connection.NewConnGater( - unittest.Logger(), - idProvider, - connection.WithOnInterceptPeerDialFilters(filters), - connection.WithOnInterceptSecuredFilters(filters)) + return connection.NewConnGater(unittest.Logger(), idProvider, connection.WithOnInterceptPeerDialFilters(filters), connection.WithOnInterceptSecuredFilters(filters)) +} + +// MockInspectorNotificationDistributorReadyDoneAware mocks the Ready and Done methods of the distributor to return a channel that is already closed, +// so that the distributor is considered ready and done when the test needs. +func MockInspectorNotificationDistributorReadyDoneAware(d *mockp2p.GossipSubInspectorNotificationDistributor) { + d.On("Start", mockery.Anything).Return().Maybe() + d.On("Ready").Return(func() <-chan struct{} { + ch := make(chan struct{}) + close(ch) + return ch + }()).Maybe() + d.On("Done").Return(func() <-chan struct{} { + ch := make(chan struct{}) + close(ch) + return ch + }()).Maybe() } diff --git a/network/p2p/unicast/README.MD b/network/p2p/unicast/README.MD index 8e094275dfd..47829c6c80f 100644 --- a/network/p2p/unicast/README.MD +++ b/network/p2p/unicast/README.MD @@ -19,15 +19,19 @@ functionalities, hence, it operates on the notion of the `peer` (rather than Flo than `flow.Identifier`. It is the responsibility of the caller to provide the correct `peer.ID` of the remote node. -If there is existing connection between the local node and the remote node, the manager will try to establish -a connection first, and then open a stream over the connection. The connection is assumed persistent, i.e., it -will be kept until certain events such as Flow node shutdown, restart, disallow-listing of either ends of the connection -by each other, etc. However, a stream is a one-time communication channel, i.e., it is assumed to be closed +The `UnicastManager` relies on the underlying libp2p node to establish the connection to the remote peer. Once the underlying +libp2p node receives a stream creation request from the `UnicastManager`, it will try to establish a connection to the remote peer if +there is no existing connection to the peer. Otherwise, it will pick and re-use the best existing connection to the remote peer. +Hence, the `UnicastManager` does not (and should not) care about the connection establishment, and rather relies on the underlying +libp2p node to establish the connection. The `UnicastManager` only cares about the stream creation, and will return an error +if the underlying libp2p node fails to establish a connection to the remote peer. + + +A stream is a one-time communication channel, i.e., it is assumed to be closed by the caller once the message is sent. The caller (i.e., the Flow node) does not necessarily re-use a stream, and the `Manager` creates one stream per request (i.e., `CreateStream` invocation), which is typically a single message. -However, we have certain safeguards in place to prevent nodes from establishing more than one connection to each other. -That is why the `Manager` establishes the connection only when there is no existing connection between the nodes, and otherwise -re-uses the existing connection. 
+ +Note: the limits on the number of streams and connections between nodes are set through the libp2p resource manager limits (see `config/default-config.yml`). Note: `pubsub` protocol also establishes connections between nodes to exchange gossip messages with each other. The connection type is the same between `pubsub` and `unicast` protocols, as they both consult the underlying LibP2P node to @@ -48,84 +52,65 @@ that the connection is persistent and will be kept open by the `PeerManager`. ## Backoff and Retry Attempts The flowchart below explains the abstract logic of the `UnicastManager` when it receives a `CreateStream` invocation. -One a happy path, the `UnicastManager` expects a connection to the remote peer exists and hence it can successfully open a stream to the peer. -However, there can be cases that the connection does not exist, the remote peer is not reliable for stream creation, or the remote peer acts -maliciously and does not respond to connection and stream creation requests. In order to distinguish between the cases that the remote peer +On the happy path, the `UnicastManager` successfully opens a stream to the peer. +However, there can be cases that the remote peer is not reliable for stream creation, or the remote peer acts +maliciously and does not respond to stream creation requests. In order to distinguish between the cases that the remote peer is not reliable and the cases that the remote peer is malicious, the `UnicastManager` uses a backoff and retry mechanism. ![retry.png](retry.png) ### Addressing Unreliable Remote Peer -To address the unreliability of remote peer, upon an unsuccessful attempt to establish a connection or stream, the `UnicastManager` will wait for a certain amount of time before it tries to establish (i.e., the backoff mechanism), -and will retry a certain number of times before it gives up (i.e., the retry mechanism). The backoff and retry parameters are configurable through runtime flags. +To address the unreliability of a remote peer, upon an unsuccessful attempt to establish a stream, the `UnicastManager` will wait for a certain +amount of time before it tries to establish (i.e., the backoff mechanism), and will retry a certain number of times before it gives up (i.e., the retry mechanism). +The backoff and retry parameters are configurable through runtime flags. If all backoff and retry attempts fail, the `UnicastManager` will return an error to the caller. The caller can then decide to retry the request or not. -By default, `UnicastManager` retries each connection (dialing) attempt as well as stream creation attempt 3 times. Also, the backoff intervals for dialing and stream creation are initialized to 1 second and progress -exponentially with a factor of 2, i.e., the `i-th` retry attempt is made after `t * 2^(i-1)`, where `t` is the backoff interval. The formulation is the same for dialing and -stream creation. For example, if the backoff interval is 1s, the first attempt is made right-away, the first (retry) attempt is made after 1s * 2^(1 - 1) = 1s, the third (retry) attempt is made +By default, `UnicastManager` retries each stream creation attempt 3 times. Also, the backoff interval for stream creation is initialized to 1 second and progresses +exponentially with a factor of 2, i.e., the `i-th` retry attempt is made after `t * 2^(i-1)`, where `t` is the backoff interval.
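As a rough editorial illustration of this backoff schedule (not code from this repository; the helper `retryDelay` is hypothetical and is only meant to mirror the `t * 2^(i-1)` formula above):

```go
package main

import (
	"fmt"
	"time"
)

// retryDelay returns the wait time before the i-th retry (i >= 1),
// assuming an initial backoff interval t and an exponential factor of 2: t * 2^(i-1).
func retryDelay(t time.Duration, i int) time.Duration {
	return t * time.Duration(1<<uint(i-1))
}

func main() {
	t := time.Second // e.g., unicast-create-stream-retry-delay: 1s
	for i := 1; i <= 3; i++ {
		fmt.Printf("retry %d waits %v\n", i, retryDelay(t, i))
	}
	// Prints: retry 1 waits 1s, retry 2 waits 2s, retry 3 waits 4s
}
```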
+For example, if the backoff interval is 1s, the first attempt is made right away, the first (retry) attempt is made after 1s * 2^(1 - 1) = 1s, the second (retry) attempt is made after `1s * 2^(2 - 1) = 2s`, and so on. These parameters are configured using the `config/default-config.yml` file: ```yaml # Unicast create stream retry delay is initial delay used in the exponential backoff for create stream retries unicast-create-stream-retry-delay: 1s - # The backoff delay used in the exponential backoff for consecutive failed unicast dial attempts to a remote peer. - unicast-dial-backoff-delay: 1s -``` - -#### Addressing Concurrent Stream Creation Attempts -There might be the case that multiple threads attempt to create a stream to the same remote peer concurrently, while there is no -existing connection to the remote peer. In such cases, the `UnicastManager` will let the first attempt to create a stream to the remote peer to proceed with dialing -while it will backoff the concurrent attempts for a certain amount of time. The backoff delay is configurable through the `config/default-config.yml` file: ```yaml - # The backoff delay used in the exponential backoff for backing off concurrent create stream attempts to the same remote peer - # when there is no available connections to that remote peer and a dial is in progress. - unicast-dial-in-progress-backoff-delay: 1s ``` -This is done for several reasons including: -- The resource manager of the remote peer may block the concurrent dial attempts if they exceed a certain threshold. -- As a convention in networking layer, we don't desire more than one connection to a remote peer, and there are hard reactive constraints in place. - However, as a soft proactive measure, we backoff concurrent dial attempts to the same remote peer to prevent multiple connections to the same peer. -- Dialing is a resource-intensive operation, and we don't want to waste resources on concurrent dial attempts to the same remote peer. ### Addressing Malicious Remote Peer The backoff and retry mechanism is used to address the cases that the remote peer is not reliable. -However, there can be cases that the remote peer is malicious and does not respond to connection and stream creation requests. -Such cases may cause the `UnicastManager` to wait for a long time before it gives up, resulting in a resource exhaustion and slow-down of the dialing node. -To mitigate such cases, the `UnicastManager` uses a retry budget for the stream creation and dialing. The retry budgets are initialized +However, there can be cases that the remote peer is malicious and does not respond to stream creation requests. +Such cases may cause the `UnicastManager` to wait for a long time before it gives up, resulting in resource exhaustion and slow-down of stream creation. +To mitigate such cases, the `UnicastManager` uses a retry budget for stream creation. The retry budget is initialized using the `config/default-config.yml` file: ```yaml # The maximum number of retry attempts for creating a unicast stream to a remote peer before giving up. If it is set to 3 for example, it means that if a peer fails to create # retry a unicast stream to a remote peer 3 times, the peer will give up and will not retry creating a unicast stream to that remote peer. # When it is set to zero it means that the peer will not retry creating a unicast stream to a remote peer if it fails.
unicast-max-stream-creation-retry-attempt-times: 3 - # The maximum number of retry attempts for dialing a remote peer before giving up. If it is set to 3 for example, it means that if a peer fails to dial a remote peer 3 times, - # the peer will give up and will not retry dialing that remote peer. - unicast-max-dial-retry-attempt-times: 3 ``` -As shown in the above snippet, both retry budgets for dialing and stream creation are set to 3 by default for every remote peer. -Each time the `UnicastManager` is invoked on `CreateStream` to `pid` (`peer.ID`), it loads the retry budgets for `pid` from the dial config cache. -If no dial config record exists for `pid`, one is created with the default retry budgets. The `UnicastManager` then uses the retry budgets to decide -whether to retry the dialing or stream creation attempt or not. If the retry budget for dialing or stream creation is exhausted, the `UnicastManager` -will not retry the dialing or stream creation attempt, respectively, and returns an error to the caller. The caller can then decide to retry the request or not. +As shown in the above snippet, the stream creation retry budget is set to 3 by default for every remote peer. +Each time the `UnicastManager` is invoked on `CreateStream` to `pid` (`peer.ID`), it loads the retry budgets for `pid` from the unicast config cache. +If no unicast config record exists for `pid`, one is created with the default retry budgets. The `UnicastManager` then uses the retry budgets to decide +whether to retry the stream creation attempt or not. If the retry budget for stream creation is exhausted, the `UnicastManager` +will not retry the stream creation attempt, and returns an error to the caller. The caller can then decide to retry the request or not. +Note that even when the retry budget is exhausted, the `UnicastManager` will try the stream creation attempt once, though it will not retry the attempt if it fails. #### Penalizing Malicious Remote Peer -Each time the `UnicastManager` fails to dial or create a stream to a remote peer and exhausts the retry budget, it penalizes the remote peer as follows: -- If the `UnicastManager` exhausts the retry budget for dialing, it will decrement the dial retry budget as well as the stream creation retry budget for the remote peer. +Each time the `UnicastManager` fails to create a stream to a remote peer and exhausts the retry budget, it penalizes the remote peer as follows: - If the `UnicastManager` exhausts the retry budget for stream creation, it will decrement the stream creation retry budget for the remote peer. -- If the retry budget reaches zero, the `UnicastManager` will only attempt once to dial or create a stream to the remote peer, and will not retry the attempt, and rather return an error to the caller. -- When any of the budgets reaches zero, the `UnicastManager` will not decrement the budget anymore. +- If the retry budget reaches zero, the `UnicastManager` will only attempt once to create a stream to the remote peer; it will not retry the attempt and will instead return an error to the caller. +- When the budget reaches zero, the `UnicastManager` will not decrement the budget anymore. **Note:** `UnicastManager` is part of the networking layer of the Flow node, which is a lower-order component than the Flow protocol engines who call the `UnicastManager` to send messages to remote peers. Hence, the `UnicastManager` _must not_ outsmart
This means that `UnicastManager` will attempt -to dial and create stream even to peers with zero retry budgets. However, `UnicastManager` does not retry attempts for the peers with zero budgets, and rather +the Flow protocol engines on deciding whether to _create a stream_ in the first place. This means that `UnicastManager` will attempt +to create a stream even to peers with zero retry budgets. However, `UnicastManager` does not retry attempts for the peers with zero budgets, and rather returns an error immediately upon a failure. This is the responsibility of the Flow protocol engines to decide whether to send a message to a remote peer or not after a certain number of failures. #### Restoring Retry Budgets -The `UnicastManager` may reset the dial and stream creation budgets for a remote peers _from zero to the default values_ in the following cases: +The `UnicastManager` may reset the stream creation budget for a remote peer _from zero to the default value_ in the following case: - **Restoring Stream Creation Retry Budget**: To restore the stream creation budget from zero to the default value, the `UnicastManager` keeps track of the _consecutive_ successful streams created to the remote peer. Everytime a stream is created successfully, the `UnicastManager` increments a counter for the remote peer. The counter is
\ No newline at end of file + Reaching the threshold means that the remote peer is reliable enough to regain the default retry budget for stream creation. \ No newline at end of file diff --git a/network/p2p/unicast/cache/dialConfigCache_test.go b/network/p2p/unicast/cache/dialConfigCache_test.go deleted file mode 100644 index 1945070d0c5..00000000000 --- a/network/p2p/unicast/cache/dialConfigCache_test.go +++ /dev/null @@ -1,281 +0,0 @@ -package unicastcache_test - -import ( - "fmt" - "sync" - "testing" - "time" - - "github.com/libp2p/go-libp2p/core/peer" - "github.com/rs/zerolog" - "github.com/stretchr/testify/require" - - "github.com/onflow/flow-go/module/metrics" - "github.com/onflow/flow-go/network/p2p/unicast" - unicastcache "github.com/onflow/flow-go/network/p2p/unicast/cache" - "github.com/onflow/flow-go/utils/unittest" -) - -// TestNewDialConfigCache tests the creation of a new DialConfigCache. -// It asserts that the cache is created and its size is 0. -func TestNewDialConfigCache(t *testing.T) { - sizeLimit := uint32(100) - logger := zerolog.Nop() - collector := metrics.NewNoopCollector() - cache := unicastcache.NewDialConfigCache(sizeLimit, logger, collector, dialConfigFixture) - require.NotNil(t, cache) - require.Equalf(t, uint(0), cache.Size(), "cache size must be 0") -} - -// dialConfigFixture returns a dial config fixture. -// The dial config is initialized with the default values. -func dialConfigFixture() unicast.DialConfig { - return unicast.DialConfig{ - DialRetryAttemptBudget: 3, - StreamCreationRetryAttemptBudget: 3, - } -} - -// TestDialConfigCache_Adjust tests the Adjust method of the DialConfigCache. It asserts that the dial config is initialized, adjusted, -// and stored in the cache. -func TestDialConfigCache_Adjust_Init(t *testing.T) { - sizeLimit := uint32(100) - logger := zerolog.Nop() - collector := metrics.NewNoopCollector() - - dialFactoryCalled := 0 - dialConfigFactory := func() unicast.DialConfig { - require.Less(t, dialFactoryCalled, 2, "dial config factory must be called at most twice") - dialFactoryCalled++ - return dialConfigFixture() - } - adjustFuncIncrement := func(cfg unicast.DialConfig) (unicast.DialConfig, error) { - cfg.DialRetryAttemptBudget++ - return cfg, nil - } - - cache := unicastcache.NewDialConfigCache(sizeLimit, logger, collector, dialConfigFactory) - require.NotNil(t, cache) - require.Zerof(t, cache.Size(), "cache size must be 0") - - peerID1 := unittest.PeerIdFixture(t) - peerID2 := unittest.PeerIdFixture(t) - - // Initializing the dial config for peerID1 through GetOrInit. - // dial config for peerID1 does not exist in the cache, so it must be initialized when using GetOrInit. - cfg, err := cache.GetOrInit(peerID1) - require.NoError(t, err) - require.NotNil(t, cfg, "dial config must not be nil") - require.Equal(t, dialConfigFixture(), *cfg, "dial config must be initialized with the default values") - require.Equal(t, uint(1), cache.Size(), "cache size must be 1") - - // Initializing and adjusting the dial config for peerID2 through Adjust. - // dial config for peerID2 does not exist in the cache, so it must be initialized when using Adjust. - cfg, err = cache.Adjust(peerID2, adjustFuncIncrement) - require.NoError(t, err) - // adjusting a non-existing dial config must not initialize the config. 
- require.Equal(t, uint(2), cache.Size(), "cache size must be 2") - require.Equal(t, cfg.LastSuccessfulDial, dialConfigFixture().LastSuccessfulDial, "last successful dial must be 0") - require.Equal(t, cfg.DialRetryAttemptBudget, dialConfigFixture().DialRetryAttemptBudget+1, "dial backoff must be adjusted") - require.Equal(t, cfg.StreamCreationRetryAttemptBudget, dialConfigFixture().StreamCreationRetryAttemptBudget, "stream backoff must be 1") - - // Retrieving the dial config of peerID2 through GetOrInit. - // retrieve the dial config for peerID2 and assert than it is initialized with the default values; and the adjust function is applied. - cfg, err = cache.GetOrInit(peerID2) - require.NoError(t, err, "dial config must exist in the cache") - require.NotNil(t, cfg, "dial config must not be nil") - // retrieving an existing dial config must not change the cache size. - require.Equal(t, uint(2), cache.Size(), "cache size must be 2") - // config should be the same as the one returned by Adjust. - require.Equal(t, cfg.LastSuccessfulDial, dialConfigFixture().LastSuccessfulDial, "last successful dial must be 0") - require.Equal(t, cfg.DialRetryAttemptBudget, dialConfigFixture().DialRetryAttemptBudget+1, "dial backoff must be adjusted") - require.Equal(t, cfg.StreamCreationRetryAttemptBudget, dialConfigFixture().StreamCreationRetryAttemptBudget, "stream backoff must be 1") - - // Adjusting the dial config of peerID1 through Adjust. - // dial config for peerID1 already exists in the cache, so it must be adjusted when using Adjust. - cfg, err = cache.Adjust(peerID1, adjustFuncIncrement) - require.NoError(t, err) - // adjusting an existing dial config must not change the cache size. - require.Equal(t, uint(2), cache.Size(), "cache size must be 2") - require.Equal(t, cfg.LastSuccessfulDial, dialConfigFixture().LastSuccessfulDial, "last successful dial must be 0") - require.Equal(t, cfg.DialRetryAttemptBudget, dialConfigFixture().DialRetryAttemptBudget+1, "dial backoff must be adjusted") - require.Equal(t, cfg.StreamCreationRetryAttemptBudget, dialConfigFixture().StreamCreationRetryAttemptBudget, "stream backoff must be 1") - - // Recurring adjustment of the dial config of peerID1 through Adjust. - // dial config for peerID1 already exists in the cache, so it must be adjusted when using Adjust. - cfg, err = cache.Adjust(peerID1, adjustFuncIncrement) - require.NoError(t, err) - // adjusting an existing dial config must not change the cache size. - require.Equal(t, uint(2), cache.Size(), "cache size must be 2") - require.Equal(t, cfg.LastSuccessfulDial, dialConfigFixture().LastSuccessfulDial, "last successful dial must be 0") - require.Equal(t, cfg.DialRetryAttemptBudget, dialConfigFixture().DialRetryAttemptBudget+2, "dial backoff must be adjusted") - require.Equal(t, cfg.StreamCreationRetryAttemptBudget, dialConfigFixture().StreamCreationRetryAttemptBudget, "stream backoff must be 1") -} - -// TestDialConfigCache_Adjust tests the Adjust method of the DialConfigCache. It asserts that the dial config is adjusted, -// and stored in the cache as expected under concurrent adjustments. 
-func TestDialConfigCache_Concurrent_Adjust(t *testing.T) { - sizeLimit := uint32(100) - logger := zerolog.Nop() - collector := metrics.NewNoopCollector() - - cache := unicastcache.NewDialConfigCache(sizeLimit, logger, collector, func() unicast.DialConfig { - return unicast.DialConfig{} // empty dial config - }) - require.NotNil(t, cache) - require.Zerof(t, cache.Size(), "cache size must be 0") - - peerIds := make([]peer.ID, sizeLimit) - for i := 0; i < int(sizeLimit); i++ { - peerId := unittest.PeerIdFixture(t) - require.NotContainsf(t, peerIds, peerId, "peer id must be unique") - peerIds[i] = peerId - } - - wg := sync.WaitGroup{} - for i := 0; i < int(sizeLimit); i++ { - // adjusts the ith dial config for peerID i times, concurrently. - for j := 0; j < i+1; j++ { - wg.Add(1) - go func(peerId peer.ID) { - defer wg.Done() - _, err := cache.Adjust(peerId, func(cfg unicast.DialConfig) (unicast.DialConfig, error) { - cfg.DialRetryAttemptBudget++ - return cfg, nil - }) - require.NoError(t, err) - }(peerIds[i]) - } - } - - unittest.RequireReturnsBefore(t, wg.Wait, time.Second*1, "adjustments must be done on time") - - // assert that the cache size is equal to the size limit. - require.Equal(t, uint(sizeLimit), cache.Size(), "cache size must be equal to the size limit") - - // assert that the dial config for each peer is adjusted i times, concurrently. - for i := 0; i < int(sizeLimit); i++ { - wg.Add(1) - go func(j int) { - wg.Done() - - peerID := peerIds[j] - cfg, err := cache.GetOrInit(peerID) - require.NoError(t, err) - require.Equal(t, uint64(j+1), cfg.DialRetryAttemptBudget, fmt.Sprintf("peerId %s dial backoff must be adjusted %d times got: %d", peerID, j+1, cfg.DialRetryAttemptBudget)) - }(i) - } - - unittest.RequireReturnsBefore(t, wg.Wait, time.Second*1, "retrievals must be done on time") -} - -// TestConcurrent_Adjust_And_Get_Is_Safe tests that concurrent adjustments and retrievals are safe, and do not cause error even if they cause eviction. The test stress tests the cache -// with 2 * SizeLimit concurrent operations (SizeLimit times concurrent adjustments and SizeLimit times concurrent retrievals). -// It asserts that the cache size is equal to the size limit, and the dial config for each peer is adjusted and retrieved correctly. -func TestConcurrent_Adjust_And_Get_Is_Safe(t *testing.T) { - sizeLimit := uint32(100) - logger := zerolog.Nop() - collector := metrics.NewNoopCollector() - - cache := unicastcache.NewDialConfigCache(sizeLimit, logger, collector, dialConfigFixture) - require.NotNil(t, cache) - require.Zerof(t, cache.Size(), "cache size must be 0") - - wg := sync.WaitGroup{} - for i := 0; i < int(sizeLimit); i++ { - // concurrently adjusts the dial configs. - wg.Add(1) - go func() { - defer wg.Done() - peerId := unittest.PeerIdFixture(t) - dialTime := time.Now() - updatedConfig, err := cache.Adjust(peerId, func(cfg unicast.DialConfig) (unicast.DialConfig, error) { - cfg.DialRetryAttemptBudget = 1 // some random adjustment - cfg.LastSuccessfulDial = dialTime - cfg.StreamCreationRetryAttemptBudget = 2 // some random adjustment - cfg.ConsecutiveSuccessfulStream = 3 // some random adjustment - return cfg, nil - }) - require.NoError(t, err) // concurrent adjustment must not fail. 
- require.Equal(t, uint64(1), updatedConfig.DialRetryAttemptBudget) // adjustment must be successful - require.Equal(t, uint64(2), updatedConfig.StreamCreationRetryAttemptBudget) - require.Equal(t, uint64(3), updatedConfig.ConsecutiveSuccessfulStream) - require.Equal(t, dialTime, updatedConfig.LastSuccessfulDial) - }() - } - - // assert that the dial config for each peer is adjusted i times, concurrently. - for i := 0; i < int(sizeLimit); i++ { - wg.Add(1) - go func() { - wg.Done() - peerId := unittest.PeerIdFixture(t) - cfg, err := cache.GetOrInit(peerId) - require.NoError(t, err) // concurrent retrieval must not fail. - require.Equal(t, dialConfigFixture().DialRetryAttemptBudget, cfg.DialRetryAttemptBudget) // dial config must be initialized with the default values. - require.Equal(t, dialConfigFixture().StreamCreationRetryAttemptBudget, cfg.StreamCreationRetryAttemptBudget) - require.Equal(t, uint64(0), cfg.ConsecutiveSuccessfulStream) - require.True(t, cfg.LastSuccessfulDial.IsZero()) - }() - } - - unittest.RequireReturnsBefore(t, wg.Wait, time.Second*1, "all operations must be done on time") - - // cache was stress-tested with 2 * SizeLimit concurrent operations. Nevertheless, the cache size must be equal to the size limit due to LRU eviction. - require.Equal(t, uint(sizeLimit), cache.Size(), "cache size must be equal to the size limit") -} - -// TestDialConfigCache_LRU_Eviction tests that the cache evicts the least recently used dial config when the cache size reaches the size limit. -func TestDialConfigCache_LRU_Eviction(t *testing.T) { - sizeLimit := uint32(100) - logger := zerolog.Nop() - collector := metrics.NewNoopCollector() - - cache := unicastcache.NewDialConfigCache(sizeLimit, logger, collector, dialConfigFixture) - require.NotNil(t, cache) - require.Zerof(t, cache.Size(), "cache size must be 0") - - peerIds := make([]peer.ID, sizeLimit+1) - for i := 0; i < int(sizeLimit+1); i++ { - peerId := unittest.PeerIdFixture(t) - require.NotContainsf(t, peerIds, peerId, "peer id must be unique") - peerIds[i] = peerId - } - for i := 0; i < int(sizeLimit+1); i++ { - dialTime := time.Now() - updatedConfig, err := cache.Adjust(peerIds[i], func(cfg unicast.DialConfig) (unicast.DialConfig, error) { - cfg.DialRetryAttemptBudget = 1 // some random adjustment - cfg.StreamCreationRetryAttemptBudget = 2 // some random adjustment - cfg.ConsecutiveSuccessfulStream = 3 // some random adjustment - cfg.LastSuccessfulDial = dialTime - return cfg, nil - }) - require.NoError(t, err) // concurrent adjustment must not fail. - require.Equal(t, uint64(1), updatedConfig.DialRetryAttemptBudget) // adjustment must be successful - require.Equal(t, uint64(2), updatedConfig.StreamCreationRetryAttemptBudget) - require.Equal(t, uint64(3), updatedConfig.ConsecutiveSuccessfulStream) - require.Equal(t, dialTime, updatedConfig.LastSuccessfulDial) - } - - // except the first peer id, all other peer ids should stay intact in the cache. 
- for i := 1; i < int(sizeLimit+1); i++ { - cfg, err := cache.GetOrInit(peerIds[i]) - require.NoError(t, err) - require.Equal(t, uint64(1), cfg.DialRetryAttemptBudget) - require.Equal(t, uint64(2), cfg.StreamCreationRetryAttemptBudget) - require.Equal(t, uint64(3), cfg.ConsecutiveSuccessfulStream) - require.False(t, cfg.LastSuccessfulDial.IsZero()) - } - - require.Equal(t, uint(sizeLimit), cache.Size(), "cache size must be equal to the size limit") - - // querying the first peer id should return a fresh dial config, since it should be evicted due to LRU eviction, and the initiated with the default values. - cfg, err := cache.GetOrInit(peerIds[0]) - require.NoError(t, err) - require.Equal(t, dialConfigFixture().DialRetryAttemptBudget, cfg.DialRetryAttemptBudget) - require.Equal(t, dialConfigFixture().StreamCreationRetryAttemptBudget, cfg.StreamCreationRetryAttemptBudget) - require.Equal(t, uint64(0), cfg.ConsecutiveSuccessfulStream) - require.True(t, cfg.LastSuccessfulDial.IsZero()) - - require.Equal(t, uint(sizeLimit), cache.Size(), "cache size must be equal to the size limit") -} diff --git a/network/p2p/unicast/cache/dialConfigCache.go b/network/p2p/unicast/cache/unicastConfigCache.go similarity index 54% rename from network/p2p/unicast/cache/dialConfigCache.go rename to network/p2p/unicast/cache/unicastConfigCache.go index d0e88e43786..13c000110fe 100644 --- a/network/p2p/unicast/cache/dialConfigCache.go +++ b/network/p2p/unicast/cache/unicastConfigCache.go @@ -15,74 +15,74 @@ import ( "github.com/onflow/flow-go/network/p2p/unicast" ) -// ErrDialConfigNotFound is a benign error that indicates that the dial config does not exist in the cache. It is not a fatal error. -var ErrDialConfigNotFound = fmt.Errorf("dial config not found") +// ErrUnicastConfigNotFound is a benign error that indicates that the unicast config does not exist in the cache. It is not a fatal error. +var ErrUnicastConfigNotFound = fmt.Errorf("unicast config not found") -type DialConfigCache struct { +type UnicastConfigCache struct { // mutex is temporarily protect the edge case in HeroCache that optimistic adjustment causes the cache to be full. // TODO: remove this mutex after the HeroCache is fixed. mutex sync.RWMutex peerCache *stdmap.Backend - cfgFactory func() unicast.DialConfig // factory function that creates a new dial config. + cfgFactory func() unicast.Config // factory function that creates a new unicast config. } -var _ unicast.DialConfigCache = (*DialConfigCache)(nil) +var _ unicast.ConfigCache = (*UnicastConfigCache)(nil) -// NewDialConfigCache creates a new DialConfigCache. +// NewUnicastConfigCache creates a new UnicastConfigCache. // Args: -// - size: the maximum number of dial configs that the cache can hold. +// - size: the maximum number of unicast configs that the cache can hold. // - logger: the logger used by the cache. // - collector: the metrics collector used by the cache. -// - cfgFactory: a factory function that creates a new dial config. +// - cfgFactory: a factory function that creates a new unicast config. // Returns: -// - *DialConfigCache, the created cache. -// Note that the cache is supposed to keep the dial config for all types of nodes. Since the number of such nodes is -// expected to be small, size must be large enough to hold all the dial configs of the authorized nodes. +// - *UnicastConfigCache, the created cache. +// Note that the cache is supposed to keep the unicast config for all types of nodes. 
Since the number of such nodes is +// expected to be small, size must be large enough to hold all the unicast configs of the authorized nodes. // To avoid any crash-failure, the cache is configured to eject the least recently used configs when the cache is full. // Hence, we recommend setting the size to a large value to minimize the ejections. -func NewDialConfigCache( +func NewUnicastConfigCache( size uint32, logger zerolog.Logger, collector module.HeroCacheMetrics, - cfgFactory func() unicast.DialConfig, -) *DialConfigCache { - return &DialConfigCache{ + cfgFactory func() unicast.Config, +) *UnicastConfigCache { + return &UnicastConfigCache{ peerCache: stdmap.NewBackend(stdmap.WithBackData(herocache.NewCache(size, herocache.DefaultOversizeFactor, heropool.LRUEjection, - logger.With().Str("module", "dial-config-cache").Logger(), + logger.With().Str("module", "unicast-config-cache").Logger(), collector))), cfgFactory: cfgFactory, } } -// Adjust applies the given adjust function to the dial config of the given peer ID, and stores the adjusted config in the cache. +// Adjust applies the given adjust function to the unicast config of the given peer ID, and stores the adjusted config in the cache. // It returns an error if the adjustFunc returns an error. // Note that if the Adjust is called when the config does not exist, the config is initialized and the // adjust function is applied to the initialized config again. In this case, the adjust function should not return an error. // Args: -// - peerID: the peer id of the dial config. -// - adjustFunc: the function that adjusts the dial config. +// - peerID: the peer id of the unicast config. +// - adjustFunc: the function that adjusts the unicast config. // Returns: // - error any returned error should be considered as an irrecoverable error and indicates a bug. -func (d *DialConfigCache) Adjust(peerID peer.ID, adjustFunc unicast.DialConfigAdjustFunc) (*unicast.DialConfig, error) { +func (d *UnicastConfigCache) Adjust(peerID peer.ID, adjustFunc unicast.UnicastConfigAdjustFunc) (*unicast.Config, error) { d.mutex.Lock() // making optimistic adjustment atomic. defer d.mutex.Unlock() // first we translate the peer id to a flow id (taking peerIdHash := PeerIdToFlowId(peerID) - adjustedDialCfg, err := d.adjust(peerIdHash, adjustFunc) + adjustedUnicastCfg, err := d.adjust(peerIdHash, adjustFunc) if err != nil { - if err == ErrDialConfigNotFound { + if err == ErrUnicastConfigNotFound { // if the config does not exist, we initialize the config and try to adjust it again. // Note: there is an edge case where the config is initialized by another goroutine between the two calls. // In this case, the init function is invoked twice, but it is not a problem because the underlying // cache is thread-safe. Hence, we do not need to synchronize the two calls. In such cases, one of the // two calls returns false, and the other call returns true. We do not care which call returns false, hence, // we ignore the return value of the init function. - e := DialConfigEntity{ - PeerId: peerID, - DialConfig: d.cfgFactory(), + e := UnicastConfigEntity{ + PeerId: peerID, + Config: d.cfgFactory(), } _ = d.peerCache.Add(e) @@ -93,39 +93,39 @@ func (d *DialConfigCache) Adjust(peerID peer.ID, adjustFunc unicast.DialConfigAd } // if the adjust function returns an unexpected error on the first attempt, we return the error directly. // any returned error should be considered as an irrecoverable error and indicates a bug. 
- return nil, fmt.Errorf("failed to adjust dial config: %w", err) + return nil, fmt.Errorf("failed to adjust unicast config: %w", err) } // if the adjust function returns no error on the first attempt, we return the adjusted config. - return adjustedDialCfg, nil + return adjustedUnicastCfg, nil } -// adjust applies the given adjust function to the dial config of the given origin id. +// adjust applies the given adjust function to the unicast config of the given origin id. // It returns an error if the adjustFunc returns an error or if the config does not exist. // Args: -// - peerIDHash: the hash value of the peer id of the dial config (i.e., the ID of the dial config entity). -// - adjustFunc: the function that adjusts the dial config. +// - peerIDHash: the hash value of the peer id of the unicast config (i.e., the ID of the unicast config entity). +// - adjustFunc: the function that adjusts the unicast config. // Returns: -// - error if the adjustFunc returns an error or if the config does not exist (ErrDialConfigNotFound). Except the ErrDialConfigNotFound, +// - error if the adjustFunc returns an error or if the config does not exist (ErrUnicastConfigNotFound). Except the ErrUnicastConfigNotFound, // any other error should be treated as an irrecoverable error and indicates a bug. -func (d *DialConfigCache) adjust(peerIdHash flow.Identifier, adjustFunc unicast.DialConfigAdjustFunc) (*unicast.DialConfig, error) { +func (d *UnicastConfigCache) adjust(peerIdHash flow.Identifier, adjustFunc unicast.UnicastConfigAdjustFunc) (*unicast.Config, error) { var rErr error adjustedEntity, adjusted := d.peerCache.Adjust(peerIdHash, func(entity flow.Entity) flow.Entity { - cfgEntity, ok := entity.(DialConfigEntity) + cfgEntity, ok := entity.(UnicastConfigEntity) if !ok { // sanity check - // This should never happen, because the cache only contains DialConfigEntity entities. - panic(fmt.Sprintf("invalid entity type, expected DialConfigEntity type, got: %T", entity)) + // This should never happen, because the cache only contains UnicastConfigEntity entities. + panic(fmt.Sprintf("invalid entity type, expected UnicastConfigEntity type, got: %T", entity)) } - // adjust the dial config. - adjustedCfg, err := adjustFunc(cfgEntity.DialConfig) + // adjust the unicast config. + adjustedCfg, err := adjustFunc(cfgEntity.Config) if err != nil { rErr = fmt.Errorf("adjust function failed: %w", err) return entity // returns the original entity (reverse the adjustment). } // Return the adjusted config. - cfgEntity.DialConfig = adjustedCfg + cfgEntity.Config = adjustedCfg return cfgEntity }) @@ -134,65 +134,61 @@ func (d *DialConfigCache) adjust(peerIdHash flow.Identifier, adjustFunc unicast. } if !adjusted { - return nil, ErrDialConfigNotFound + return nil, ErrUnicastConfigNotFound } - return &unicast.DialConfig{ - DialRetryAttemptBudget: adjustedEntity.(DialConfigEntity).DialRetryAttemptBudget, - StreamCreationRetryAttemptBudget: adjustedEntity.(DialConfigEntity).StreamCreationRetryAttemptBudget, - LastSuccessfulDial: adjustedEntity.(DialConfigEntity).LastSuccessfulDial, - ConsecutiveSuccessfulStream: adjustedEntity.(DialConfigEntity).ConsecutiveSuccessfulStream, + return &unicast.Config{ + StreamCreationRetryAttemptBudget: adjustedEntity.(UnicastConfigEntity).StreamCreationRetryAttemptBudget, + ConsecutiveSuccessfulStream: adjustedEntity.(UnicastConfigEntity).ConsecutiveSuccessfulStream, }, nil } -// GetOrInit returns the dial config for the given peer id. 
If the config does not exist, it creates a new config +// GetOrInit returns the unicast config for the given peer id. If the config does not exist, it creates a new config // using the factory function and stores it in the cache. // Args: -// - peerID: the peer id of the dial config. +// - peerID: the peer id of the unicast config. // Returns: -// - *DialConfig, the dial config for the given peer id. +// - *Config, the unicast config for the given peer id. // - error if the factory function returns an error. Any error should be treated as an irrecoverable error and indicates a bug. -func (d *DialConfigCache) GetOrInit(peerID peer.ID) (*unicast.DialConfig, error) { +func (d *UnicastConfigCache) GetOrInit(peerID peer.ID) (*unicast.Config, error) { // first we translate the peer id to a flow id (taking flowPeerId := PeerIdToFlowId(peerID) cfg, ok := d.get(flowPeerId) if !ok { - _ = d.peerCache.Add(DialConfigEntity{ - PeerId: peerID, - DialConfig: d.cfgFactory(), + _ = d.peerCache.Add(UnicastConfigEntity{ + PeerId: peerID, + Config: d.cfgFactory(), }) cfg, ok = d.get(flowPeerId) if !ok { - return nil, fmt.Errorf("failed to initialize dial config for peer %s", peerID) + return nil, fmt.Errorf("failed to initialize unicast config for peer %s", peerID) } } return cfg, nil } -// Get returns the dial config of the given peer ID. -func (d *DialConfigCache) get(peerIDHash flow.Identifier) (*unicast.DialConfig, bool) { +// Get returns the unicast config of the given peer ID. +func (d *UnicastConfigCache) get(peerIDHash flow.Identifier) (*unicast.Config, bool) { entity, ok := d.peerCache.ByID(peerIDHash) if !ok { return nil, false } - cfg, ok := entity.(DialConfigEntity) + cfg, ok := entity.(UnicastConfigEntity) if !ok { // sanity check - // This should never happen, because the cache only contains DialConfigEntity entities. - panic(fmt.Sprintf("invalid entity type, expected DialConfigEntity type, got: %T", entity)) + // This should never happen, because the cache only contains UnicastConfigEntity entities. + panic(fmt.Sprintf("invalid entity type, expected UnicastConfigEntity type, got: %T", entity)) } // return a copy of the config (we do not want the caller to modify the config). - return &unicast.DialConfig{ - DialRetryAttemptBudget: cfg.DialRetryAttemptBudget, + return &unicast.Config{ StreamCreationRetryAttemptBudget: cfg.StreamCreationRetryAttemptBudget, - LastSuccessfulDial: cfg.LastSuccessfulDial, ConsecutiveSuccessfulStream: cfg.ConsecutiveSuccessfulStream, }, true } -// Size returns the number of dial configs in the cache. -func (d *DialConfigCache) Size() uint { +// Size returns the number of unicast configs in the cache. +func (d *UnicastConfigCache) Size() uint { return d.peerCache.Size() } diff --git a/network/p2p/unicast/cache/unicastConfigCache_test.go b/network/p2p/unicast/cache/unicastConfigCache_test.go new file mode 100644 index 00000000000..4d07c9980d2 --- /dev/null +++ b/network/p2p/unicast/cache/unicastConfigCache_test.go @@ -0,0 +1,260 @@ +package unicastcache_test + +import ( + "fmt" + "sync" + "testing" + "time" + + "github.com/libp2p/go-libp2p/core/peer" + "github.com/rs/zerolog" + "github.com/stretchr/testify/require" + + "github.com/onflow/flow-go/module/metrics" + "github.com/onflow/flow-go/network/p2p/unicast" + unicastcache "github.com/onflow/flow-go/network/p2p/unicast/cache" + "github.com/onflow/flow-go/utils/unittest" +) + +// TestNewUnicastConfigCache tests the creation of a new UnicastConfigCache. +// It asserts that the cache is created and its size is 0. 
+func TestNewUnicastConfigCache(t *testing.T) { + sizeLimit := uint32(100) + logger := zerolog.Nop() + collector := metrics.NewNoopCollector() + cache := unicastcache.NewUnicastConfigCache(sizeLimit, logger, collector, unicastConfigFixture) + require.NotNil(t, cache) + require.Equalf(t, uint(0), cache.Size(), "cache size must be 0") +} + +// unicastConfigFixture returns a unicast config fixture. +// The unicast config is initialized with the default values. +func unicastConfigFixture() unicast.Config { + return unicast.Config{ + StreamCreationRetryAttemptBudget: 3, + } +} + +// TestUnicastConfigCache_Adjust tests the Adjust method of the UnicastConfigCache. It asserts that the unicast config is initialized, adjusted, +// and stored in the cache. +func TestUnicastConfigCache_Adjust_Init(t *testing.T) { + sizeLimit := uint32(100) + logger := zerolog.Nop() + collector := metrics.NewNoopCollector() + + unicastFactoryCalled := 0 + unicastConfigFactory := func() unicast.Config { + require.Less(t, unicastFactoryCalled, 2, "unicast config factory must be called at most twice") + unicastFactoryCalled++ + return unicastConfigFixture() + } + adjustFuncIncrement := func(cfg unicast.Config) (unicast.Config, error) { + cfg.StreamCreationRetryAttemptBudget++ + return cfg, nil + } + + cache := unicastcache.NewUnicastConfigCache(sizeLimit, logger, collector, unicastConfigFactory) + require.NotNil(t, cache) + require.Zerof(t, cache.Size(), "cache size must be 0") + + peerID1 := unittest.PeerIdFixture(t) + peerID2 := unittest.PeerIdFixture(t) + + // Initializing the unicast config for peerID1 through GetOrInit. + // unicast config for peerID1 does not exist in the cache, so it must be initialized when using GetOrInit. + cfg, err := cache.GetOrInit(peerID1) + require.NoError(t, err) + require.NotNil(t, cfg, "unicast config must not be nil") + require.Equal(t, unicastConfigFixture(), *cfg, "unicast config must be initialized with the default values") + require.Equal(t, uint(1), cache.Size(), "cache size must be 1") + + // Initializing and adjusting the unicast config for peerID2 through Adjust. + // unicast config for peerID2 does not exist in the cache, so it must be initialized when using Adjust. + cfg, err = cache.Adjust(peerID2, adjustFuncIncrement) + require.NoError(t, err) + // adjusting a non-existing unicast config must not initialize the config. + require.Equal(t, uint(2), cache.Size(), "cache size must be 2") + require.Equal(t, cfg.StreamCreationRetryAttemptBudget, unicastConfigFixture().StreamCreationRetryAttemptBudget+1, "stream backoff must be 2") + + // Retrieving the unicast config of peerID2 through GetOrInit. + // retrieve the unicast config for peerID2 and assert than it is initialized with the default values; and the adjust function is applied. + cfg, err = cache.GetOrInit(peerID2) + require.NoError(t, err, "unicast config must exist in the cache") + require.NotNil(t, cfg, "unicast config must not be nil") + // retrieving an existing unicast config must not change the cache size. + require.Equal(t, uint(2), cache.Size(), "cache size must be 2") + // config should be the same as the one returned by Adjust. + require.Equal(t, cfg.StreamCreationRetryAttemptBudget, unicastConfigFixture().StreamCreationRetryAttemptBudget+1, "stream backoff must be 2") + + // Adjusting the unicast config of peerID1 through Adjust. + // unicast config for peerID1 already exists in the cache, so it must be adjusted when using Adjust. 
+ cfg, err = cache.Adjust(peerID1, adjustFuncIncrement) + require.NoError(t, err) + // adjusting an existing unicast config must not change the cache size. + require.Equal(t, uint(2), cache.Size(), "cache size must be 2") + require.Equal(t, cfg.StreamCreationRetryAttemptBudget, unicastConfigFixture().StreamCreationRetryAttemptBudget+1, "stream backoff must be 2") + + // Recurring adjustment of the unicast config of peerID1 through Adjust. + // unicast config for peerID1 already exists in the cache, so it must be adjusted when using Adjust. + cfg, err = cache.Adjust(peerID1, adjustFuncIncrement) + require.NoError(t, err) + // adjusting an existing unicast config must not change the cache size. + require.Equal(t, uint(2), cache.Size(), "cache size must be 2") + require.Equal(t, cfg.StreamCreationRetryAttemptBudget, unicastConfigFixture().StreamCreationRetryAttemptBudget+2, "stream backoff must be 3") +} + +// TestUnicastConfigCache_Adjust tests the Adjust method of the UnicastConfigCache. It asserts that the unicast config is adjusted, +// and stored in the cache as expected under concurrent adjustments. +func TestUnicastConfigCache_Concurrent_Adjust(t *testing.T) { + sizeLimit := uint32(100) + logger := zerolog.Nop() + collector := metrics.NewNoopCollector() + + cache := unicastcache.NewUnicastConfigCache(sizeLimit, logger, collector, func() unicast.Config { + return unicast.Config{} // empty unicast config + }) + require.NotNil(t, cache) + require.Zerof(t, cache.Size(), "cache size must be 0") + + peerIds := make([]peer.ID, sizeLimit) + for i := 0; i < int(sizeLimit); i++ { + peerId := unittest.PeerIdFixture(t) + require.NotContainsf(t, peerIds, peerId, "peer id must be unique") + peerIds[i] = peerId + } + + wg := sync.WaitGroup{} + for i := 0; i < int(sizeLimit); i++ { + // adjusts the ith unicast config for peerID i times, concurrently. + for j := 0; j < i+1; j++ { + wg.Add(1) + go func(peerId peer.ID) { + defer wg.Done() + _, err := cache.Adjust(peerId, func(cfg unicast.Config) (unicast.Config, error) { + cfg.StreamCreationRetryAttemptBudget++ + return cfg, nil + }) + require.NoError(t, err) + }(peerIds[i]) + } + } + + unittest.RequireReturnsBefore(t, wg.Wait, time.Second*1, "adjustments must be done on time") + + // assert that the cache size is equal to the size limit. + require.Equal(t, uint(sizeLimit), cache.Size(), "cache size must be equal to the size limit") + + // assert that the unicast config for each peer is adjusted i times, concurrently. + for i := 0; i < int(sizeLimit); i++ { + wg.Add(1) + go func(j int) { + wg.Done() + + peerID := peerIds[j] + cfg, err := cache.GetOrInit(peerID) + require.NoError(t, err) + require.Equal(t, + uint64(j+1), + cfg.StreamCreationRetryAttemptBudget, + fmt.Sprintf("peerId %s unicast backoff must be adjusted %d times got: %d", peerID, j+1, cfg.StreamCreationRetryAttemptBudget)) + }(i) + } + + unittest.RequireReturnsBefore(t, wg.Wait, time.Second*1, "retrievals must be done on time") +} + +// TestConcurrent_Adjust_And_Get_Is_Safe tests that concurrent adjustments and retrievals are safe, and do not cause error even if they cause eviction. The test stress tests the cache +// with 2 * SizeLimit concurrent operations (SizeLimit times concurrent adjustments and SizeLimit times concurrent retrievals). +// It asserts that the cache size is equal to the size limit, and the unicast config for each peer is adjusted and retrieved correctly. 
+func TestConcurrent_Adjust_And_Get_Is_Safe(t *testing.T) {
+	sizeLimit := uint32(100)
+	logger := zerolog.Nop()
+	collector := metrics.NewNoopCollector()
+
+	cache := unicastcache.NewUnicastConfigCache(sizeLimit, logger, collector, unicastConfigFixture)
+	require.NotNil(t, cache)
+	require.Zerof(t, cache.Size(), "cache size must be 0")
+
+	wg := sync.WaitGroup{}
+	for i := 0; i < int(sizeLimit); i++ {
+		// concurrently adjusts the unicast configs.
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			peerId := unittest.PeerIdFixture(t)
+			updatedConfig, err := cache.Adjust(peerId, func(cfg unicast.Config) (unicast.Config, error) {
+				cfg.StreamCreationRetryAttemptBudget = 2 // some random adjustment
+				cfg.ConsecutiveSuccessfulStream = 3      // some random adjustment
+				return cfg, nil
+			})
+			require.NoError(t, err) // concurrent adjustment must not fail.
+			require.Equal(t, uint64(2), updatedConfig.StreamCreationRetryAttemptBudget)
+			require.Equal(t, uint64(3), updatedConfig.ConsecutiveSuccessfulStream)
+		}()
+	}
+
+	// concurrently retrieves the unicast config of fresh peer ids; each must be initialized with the default values.
+	for i := 0; i < int(sizeLimit); i++ {
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			peerId := unittest.PeerIdFixture(t)
+			cfg, err := cache.GetOrInit(peerId)
+			require.NoError(t, err) // concurrent retrieval must not fail.
+			require.Equal(t, unicastConfigFixture().StreamCreationRetryAttemptBudget, cfg.StreamCreationRetryAttemptBudget)
+			require.Equal(t, uint64(0), cfg.ConsecutiveSuccessfulStream)
+		}()
+	}
+
+	unittest.RequireReturnsBefore(t, wg.Wait, time.Second*1, "all operations must be done on time")
+
+	// cache was stress-tested with 2 * SizeLimit concurrent operations. Nevertheless, the cache size must be equal to the size limit due to LRU eviction.
+	require.Equal(t, uint(sizeLimit), cache.Size(), "cache size must be equal to the size limit")
+}
+
+// TestUnicastConfigCache_LRU_Eviction tests that the cache evicts the least recently used unicast config when the cache size reaches the size limit.
+func TestUnicastConfigCache_LRU_Eviction(t *testing.T) {
+	sizeLimit := uint32(100)
+	logger := zerolog.Nop()
+	collector := metrics.NewNoopCollector()
+
+	cache := unicastcache.NewUnicastConfigCache(sizeLimit, logger, collector, unicastConfigFixture)
+	require.NotNil(t, cache)
+	require.Zerof(t, cache.Size(), "cache size must be 0")
+
+	peerIds := make([]peer.ID, sizeLimit+1)
+	for i := 0; i < int(sizeLimit+1); i++ {
+		peerId := unittest.PeerIdFixture(t)
+		require.NotContainsf(t, peerIds, peerId, "peer id must be unique")
+		peerIds[i] = peerId
+	}
+	for i := 0; i < int(sizeLimit+1); i++ {
+		updatedConfig, err := cache.Adjust(peerIds[i], func(cfg unicast.Config) (unicast.Config, error) {
+			cfg.StreamCreationRetryAttemptBudget = 2 // some random adjustment
+			cfg.ConsecutiveSuccessfulStream = 3      // some random adjustment
+			return cfg, nil
+		})
+		require.NoError(t, err) // adjustment must not fail.
+		require.Equal(t, uint64(2), updatedConfig.StreamCreationRetryAttemptBudget)
+		require.Equal(t, uint64(3), updatedConfig.ConsecutiveSuccessfulStream)
+	}
+
+	// except the first peer id, all other peer ids should stay intact in the cache.
+	for i := 1; i < int(sizeLimit+1); i++ {
+		cfg, err := cache.GetOrInit(peerIds[i])
+		require.NoError(t, err)
+		require.Equal(t, uint64(2), cfg.StreamCreationRetryAttemptBudget)
+		require.Equal(t, uint64(3), cfg.ConsecutiveSuccessfulStream)
+	}
+
+	require.Equal(t, uint(sizeLimit), cache.Size(), "cache size must be equal to the size limit")
+
+	// querying the first peer id should return a fresh unicast config,
+	// since it should have been evicted due to LRU eviction and then re-initialized with the default values.
+	cfg, err := cache.GetOrInit(peerIds[0])
+	require.NoError(t, err)
+	require.Equal(t, unicastConfigFixture().StreamCreationRetryAttemptBudget, cfg.StreamCreationRetryAttemptBudget)
+	require.Equal(t, uint64(0), cfg.ConsecutiveSuccessfulStream)
+
+	require.Equal(t, uint(sizeLimit), cache.Size(), "cache size must be equal to the size limit")
+}
diff --git a/network/p2p/unicast/cache/dialConfigEntity.go b/network/p2p/unicast/cache/unicastConfigEntity.go
similarity index 60%
rename from network/p2p/unicast/cache/dialConfigEntity.go
rename to network/p2p/unicast/cache/unicastConfigEntity.go
index 71a9b6844c9..c1db31523fe 100644
--- a/network/p2p/unicast/cache/dialConfigEntity.go
+++ b/network/p2p/unicast/cache/unicastConfigEntity.go
@@ -7,18 +7,18 @@ import (
 	"github.com/onflow/flow-go/network/p2p/unicast"
 )
 
-// DialConfigEntity is a struct that represents a dial config entry for storing in the dial config cache.
+// UnicastConfigEntity is a struct that represents a unicast config entry for storing in the unicast config cache.
 // It implements the flow.Entity interface.
-type DialConfigEntity struct {
-	unicast.DialConfig
-	PeerId peer.ID // remote peer id; used as the "key" in the dial config cache.
+type UnicastConfigEntity struct {
+	unicast.Config
+	PeerId peer.ID // remote peer id; used as the "key" in the unicast config cache.
 	id flow.Identifier // cache the id for fast lookup (HeroCache).
 }
 
-var _ flow.Entity = (*DialConfigEntity)(nil)
+var _ flow.Entity = (*UnicastConfigEntity)(nil)
 
-// ID returns the ID of the dial config entity; it is hash value of the peer id.
-func (d DialConfigEntity) ID() flow.Identifier {
+// ID returns the ID of the unicast config entity; it is hash value of the peer id.
+func (d UnicastConfigEntity) ID() flow.Identifier {
 	if d.id == flow.ZeroID {
 		d.id = PeerIdToFlowId(d.PeerId)
 	}
@@ -26,7 +26,7 @@ func (d DialConfigEntity) ID() flow.Identifier {
 }
 
 // Checksum acts the same as ID.
-func (d DialConfigEntity) Checksum() flow.Identifier {
+func (d UnicastConfigEntity) Checksum() flow.Identifier {
 	return d.ID()
 }
diff --git a/network/p2p/unicast/cache/dialConfigEntity_test.go b/network/p2p/unicast/cache/unicastConfigEntity_test.go
similarity index 64%
rename from network/p2p/unicast/cache/dialConfigEntity_test.go
rename to network/p2p/unicast/cache/unicastConfigEntity_test.go
index 87da16248c4..d7bad635c04 100644
--- a/network/p2p/unicast/cache/dialConfigEntity_test.go
+++ b/network/p2p/unicast/cache/unicastConfigEntity_test.go
@@ -2,7 +2,6 @@ package unicastcache_test
 
 import (
 	"testing"
-	"time"
 
 	"github.com/stretchr/testify/require"
 
@@ -11,16 +10,14 @@ import (
 	"github.com/onflow/flow-go/utils/unittest"
 )
 
-// TestDialConfigEntity tests the DialConfigEntity struct and its methods.
+// TestUnicastConfigEntity tests the UnicastConfigEntity struct and its methods.
+func TestUnicastConfigEntity(t *testing.T) { peerID := unittest.PeerIdFixture(t) - d := &unicastcache.DialConfigEntity{ + d := &unicastcache.UnicastConfigEntity{ PeerId: peerID, - DialConfig: unicast.DialConfig{ - DialRetryAttemptBudget: 10, + Config: unicast.Config{ StreamCreationRetryAttemptBudget: 20, - LastSuccessfulDial: time.Now(), ConsecutiveSuccessfulStream: 30, }, } @@ -39,20 +36,18 @@ func TestDialConfigEntity(t *testing.T) { ) t.Run("ID is only calculated from peer.ID", func(t *testing.T) { - d2 := &unicastcache.DialConfigEntity{ - PeerId: unittest.PeerIdFixture(t), - DialConfig: d.DialConfig, + d2 := &unicastcache.UnicastConfigEntity{ + PeerId: unittest.PeerIdFixture(t), + Config: d.Config, } require.NotEqual(t, d.ID(), d2.ID()) // different peer id, different id. - d3 := &unicastcache.DialConfigEntity{ + d3 := &unicastcache.UnicastConfigEntity{ PeerId: d.PeerId, - DialConfig: unicast.DialConfig{ - DialRetryAttemptBudget: 100, + Config: unicast.Config{ StreamCreationRetryAttemptBudget: 200, - LastSuccessfulDial: time.Now(), }, } - require.Equal(t, d.ID(), d3.ID()) // same peer id, same id, even though the dial config is different. + require.Equal(t, d.ID(), d3.ID()) // same peer id, same id, even though the unicast config is different. }) } diff --git a/network/p2p/unicast/dialConfig.go b/network/p2p/unicast/dialConfig.go index 7d53e829a29..e88c4fd7554 100644 --- a/network/p2p/unicast/dialConfig.go +++ b/network/p2p/unicast/dialConfig.go @@ -1,17 +1,13 @@ package unicast -import "time" - -// DialConfig is a struct that represents the dial config for a peer. -type DialConfig struct { - DialRetryAttemptBudget uint64 // number of times we have to try to dial the peer before we give up. - StreamCreationRetryAttemptBudget uint64 // number of times we have to try to open a stream to the peer before we give up. - LastSuccessfulDial time.Time // timestamp of the last successful dial to the peer. - ConsecutiveSuccessfulStream uint64 // consecutive number of successful streams to the peer since the last time stream creation failed. +// Config is a struct that represents the dial config for a peer. +type Config struct { + StreamCreationRetryAttemptBudget uint64 // number of times we have to try to open a stream to the peer before we give up. + ConsecutiveSuccessfulStream uint64 // consecutive number of successful streams to the peer since the last time stream creation failed. } -// DialConfigAdjustFunc is a function that is used to adjust the fields of a DialConfigEntity. +// UnicastConfigAdjustFunc is a function that is used to adjust the fields of a DialConfigEntity. // The function is called with the current config and should return the adjusted record. // Returned error indicates that the adjustment is not applied, and the config should not be updated. // In BFT setup, the returned error should be treated as a fatal error. -type DialConfigAdjustFunc func(DialConfig) (DialConfig, error) +type UnicastConfigAdjustFunc func(Config) (Config, error) diff --git a/network/p2p/unicast/dialConfigCache.go b/network/p2p/unicast/dialConfigCache.go index 879e2756d49..fc4c3199b5b 100644 --- a/network/p2p/unicast/dialConfigCache.go +++ b/network/p2p/unicast/dialConfigCache.go @@ -4,17 +4,17 @@ import ( "github.com/libp2p/go-libp2p/core/peer" ) -// DialConfigCache is a thread-safe cache for dial configs. It is used by the unicast service to store +// ConfigCache is a thread-safe cache for dial configs. It is used by the unicast service to store // the dial configs for peers. 
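// NOTE (editorial illustration, not part of this change): Adjust on the ConfigCache below takes a
// UnicastConfigAdjustFunc, defined in dialConfig.go above; per its contract, a non-nil error means
// the adjustment is not applied and should be treated as fatal in a BFT setup. A minimal sketch of
// such an adjuster, assuming the Config and UnicastConfigAdjustFunc types from this diff (the
// variable name resetConsecutiveSuccess is hypothetical):
//
//	var resetConsecutiveSuccess UnicastConfigAdjustFunc = func(cfg Config) (Config, error) {
//		cfg.ConsecutiveSuccessfulStream = 0 // reset the success streak; the retry budget is left untouched
//		return cfg, nil
//	}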
-type DialConfigCache interface { +type ConfigCache interface { // GetOrInit returns the dial config for the given peer id. If the config does not exist, it creates a new config // using the factory function and stores it in the cache. // Args: // - peerID: the peer id of the dial config. // Returns: - // - *DialConfig, the dial config for the given peer id. + // - *Config, the dial config for the given peer id. // - error if the factory function returns an error. Any error should be treated as an irrecoverable error and indicates a bug. - GetOrInit(peerID peer.ID) (*DialConfig, error) + GetOrInit(peerID peer.ID) (*Config, error) // Adjust adjusts the dial config for the given peer id using the given adjustFunc. // It returns an error if the adjustFunc returns an error. @@ -23,7 +23,7 @@ type DialConfigCache interface { // - adjustFunc: the function that adjusts the dial config. // Returns: // - error if the adjustFunc returns an error. Any error should be treated as an irrecoverable error and indicates a bug. - Adjust(peerID peer.ID, adjustFunc DialConfigAdjustFunc) (*DialConfig, error) + Adjust(peerID peer.ID, adjustFunc UnicastConfigAdjustFunc) (*Config, error) // Size returns the number of dial configs in the cache. Size() uint diff --git a/network/p2p/unicast/errors.go b/network/p2p/unicast/errors.go index d8abb2624f7..99bb8bdeaed 100644 --- a/network/p2p/unicast/errors.go +++ b/network/p2p/unicast/errors.go @@ -3,32 +3,8 @@ package unicast import ( "errors" "fmt" - - "github.com/libp2p/go-libp2p/core/peer" - - "github.com/onflow/flow-go/network/p2p/p2plogging" ) -// ErrDialInProgress indicates that the libp2p node is currently dialing the peer. -type ErrDialInProgress struct { - pid peer.ID -} - -func (e ErrDialInProgress) Error() string { - return fmt.Sprintf("dialing to peer %s already in progress", p2plogging.PeerId(e.pid)) -} - -// NewDialInProgressErr returns a new ErrDialInProgress. -func NewDialInProgressErr(pid peer.ID) ErrDialInProgress { - return ErrDialInProgress{pid: pid} -} - -// IsErrDialInProgress returns whether an error is ErrDialInProgress -func IsErrDialInProgress(err error) bool { - var e ErrDialInProgress - return errors.As(err, &e) -} - // ErrMaxRetries indicates retries completed with max retries without a successful attempt. type ErrMaxRetries struct { attempts uint64 diff --git a/network/p2p/unicast/manager.go b/network/p2p/unicast/manager.go index b5dea5de02b..16b4dce703b 100644 --- a/network/p2p/unicast/manager.go +++ b/network/p2p/unicast/manager.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "sync" "time" "github.com/go-playground/validator/v10" @@ -34,7 +33,7 @@ var ( _ p2p.UnicastManager = (*Manager)(nil) ) -type DialConfigCacheFactory func(configFactory func() DialConfig) DialConfigCache +type DialConfigCacheFactory func(configFactory func() Config) ConfigCache // Manager manages libp2p stream negotiation and creation, which is utilized for unicast dispatches. type Manager struct { @@ -43,8 +42,6 @@ type Manager struct { protocols []protocols.Protocol defaultHandler libp2pnet.StreamHandler sporkId flow.Identifier - connStatus p2p.PeerConnections - peerDialing sync.Map metrics module.UnicastManagerMetrics // createStreamBackoffDelay is the delay between each stream creation retry attempt. @@ -52,19 +49,9 @@ type Manager struct { // is the initial delay between each retry attempt. The delay is doubled after each retry attempt. 
createStreamBackoffDelay time.Duration - // dialInProgressBackoffDelay is the backoff delay for parallel attempts on dialing to the same peer. - // When the unicast manager is invoked to create stream to the same peer concurrently while there is - // already an ongoing dialing attempt to the same peer, the unicast manager will wait for this backoff delay - // and retry creating the stream after the backoff delay has elapsed. This is to prevent the unicast manager - // from creating too many parallel dialing attempts to the same peer. - dialInProgressBackoffDelay time.Duration - - // dialBackoffDelay is the backoff delay between retrying connection to the same peer. - dialBackoffDelay time.Duration - // dialConfigCache is a cache to store the dial config for each peer. // TODO: encapsulation can be further improved by wrapping the dialConfigCache together with the dial config adjustment logic into a single struct. - dialConfigCache DialConfigCache + dialConfigCache ConfigCache // streamZeroBackoffResetThreshold is the threshold that determines when to reset the stream creation backoff budget to the default value. // @@ -78,21 +65,6 @@ type Manager struct { // 100 stream creations are all successful. streamZeroBackoffResetThreshold uint64 - // dialZeroBackoffResetThreshold is the threshold that determines when to reset the dial backoff budget to the default value. - // - // For example the threshold of 1 hour means that if the dial backoff budget is decreased to 0, then it will be reset to default value - // when it has been 1 hour since the last successful dial. - // - // This is to prevent the backoff budget from being reset too frequently, as the backoff budget is used to gauge the reliability of the dialing a remote peer. - // When the dial backoff budget is reset to the default value, it means that the dialing is reliable enough to be trusted again. - // This parameter mandates when the dialing is reliable enough to be trusted again; i.e., when it has been 1 hour since the last successful dial. - // Note that the last dial attempt timestamp is reset to zero when the dial fails, so the value of for example 1 hour means that the dialing to the remote peer is reliable enough that the last - // successful dial attempt was 1 hour ago. - dialZeroBackoffResetThreshold time.Duration - - // maxDialAttemptTimes is the maximum number of attempts to be made to connect to a remote node to establish a unicast (1:1) connection before we give up. - maxDialAttemptTimes uint64 - // maxStreamCreationAttemptTimes is the maximum number of attempts to be made to create a stream to a remote node over a direct unicast (1:1) connection before we give up. 
maxStreamCreationAttemptTimes uint64 } @@ -111,33 +83,23 @@ func NewUnicastManager(cfg *ManagerConfig) (*Manager, error) { m := &Manager{ logger: cfg.Logger.With().Str("module", "unicast-manager").Logger(), - dialConfigCache: cfg.DialConfigCacheFactory(func() DialConfig { - return DialConfig{ + dialConfigCache: cfg.UnicastConfigCacheFactory(func() Config { + return Config{ StreamCreationRetryAttemptBudget: cfg.MaxStreamCreationRetryAttemptTimes, - DialRetryAttemptBudget: cfg.MaxDialRetryAttemptTimes, } }), streamFactory: cfg.StreamFactory, sporkId: cfg.SporkId, - connStatus: cfg.ConnStatus, - peerDialing: sync.Map{}, metrics: cfg.Metrics, createStreamBackoffDelay: cfg.CreateStreamBackoffDelay, - dialBackoffDelay: cfg.DialBackoffDelay, - dialInProgressBackoffDelay: cfg.DialInProgressBackoffDelay, streamZeroBackoffResetThreshold: cfg.StreamZeroRetryResetThreshold, - dialZeroBackoffResetThreshold: cfg.DialZeroRetryResetThreshold, maxStreamCreationAttemptTimes: cfg.MaxStreamCreationRetryAttemptTimes, - maxDialAttemptTimes: cfg.MaxDialRetryAttemptTimes, } m.logger.Info(). Hex("spork_id", logging.ID(cfg.SporkId)). Dur("create_stream_backoff_delay", cfg.CreateStreamBackoffDelay). - Dur("dial_backoff_delay", cfg.DialBackoffDelay). - Dur("dial_in_progress_backoff_delay", cfg.DialInProgressBackoffDelay). Uint64("stream_zero_backoff_reset_threshold", cfg.StreamZeroRetryResetThreshold). - Dur("dial_zero_backoff_reset_threshold", cfg.DialZeroRetryResetThreshold). Msg("unicast manager created") return m, nil @@ -207,7 +169,7 @@ func (m *Manager) CreateStream(ctx context.Context, peerID peer.ID) (libp2pnet.S Msg("dial config for the peer retrieved") for i := len(m.protocols) - 1; i >= 0; i-- { - s, err := m.tryCreateStream(ctx, peerID, m.protocols[i], dialCfg) + s, err := m.createStream(ctx, peerID, m.protocols[i], dialCfg) if err != nil { errs = multierror.Append(errs, err) continue @@ -217,8 +179,7 @@ func (m *Manager) CreateStream(ctx context.Context, peerID peer.ID) (libp2pnet.S return s, nil } - connected, connErr := m.connStatus.IsConnected(peerID) // we don't check connErr as it indicates that the peer is not connected. - updatedCfg, err := m.adjustUnsuccessfulStreamAttempt(peerID, connErr == nil && connected) + updatedCfg, err := m.adjustUnsuccessfulStreamAttempt(peerID) if err != nil { // TODO: technically, we better to return an error here, but the error must be irrecoverable, and we cannot // guarantee a clear distinction between recoverable and irrecoverable errors at the moment with CreateStream. @@ -235,64 +196,45 @@ func (m *Manager) CreateStream(ctx context.Context, peerID peer.ID) (libp2pnet.S Bool(logging.KeySuspicious, true). Str("peer_id", p2plogging.PeerId(peerID)). Str("dial_config", fmt.Sprintf("%+v", updatedCfg)). - Bool("is_connected", err == nil && connected). Msg("failed to create stream to peer id, dial config adjusted") return nil, fmt.Errorf("could not create stream on any available unicast protocol: %w", errs) } -// tryCreateStream will retry createStream with the configured exponential backoff delay and maxAttempts. -// During retries, each error encountered is aggregated in a multierror. If max attempts are made before a -// stream can be successfully the multierror will be returned. During stream creation when IsErrDialInProgress -// is encountered during retries this would indicate that no connection to the peer exists yet. -// In this case we will retry creating the stream with a backoff until a connection is established. 
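// NOTE (editorial illustration, not part of this change): both the removed dial path and the new
// createStreamWithRetry below rely on the same retry pattern from github.com/sethvargo/go-retry:
// an exponential backoff capped by the per-peer retry budget, where only errors wrapped in
// retry.RetryableError are retried and all other errors abort immediately. A minimal sketch,
// assuming the surrounding imports (context, time, retry); the names exampleRetry, backoffDelay,
// budget, and attempt are hypothetical:

func exampleRetry(ctx context.Context, backoffDelay time.Duration, budget uint64, attempt func() error) error {
	backoff := retry.NewExponential(backoffDelay)
	// go-retry still performs the final attempt when retries == budget, i.e. budget+1 invocations in total.
	backoff = retry.WithMaxRetries(budget, backoff)
	return retry.Do(ctx, backoff, func(context.Context) error {
		if err := attempt(); err != nil {
			return retry.RetryableError(err) // benign failure: retry after the backoff delay
		}
		return nil
	})
}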
-func (m *Manager) tryCreateStream(ctx context.Context, peerID peer.ID, protocol protocols.Protocol, dialCfg *DialConfig) (libp2pnet.Stream, error) { +// createStream attempts to establish a new stream with a peer using the specified protocol. It employs +// exponential backoff with a maximum number of attempts defined by dialCfg.StreamCreationRetryAttemptBudget. +// If the stream cannot be established after the maximum attempts, it returns a compiled multierror of all +// encountered errors. Errors related to in-progress dials trigger a retry until a connection is established +// or the attempt budget is exhausted. +// +// The function increments the Config's ConsecutiveSuccessfulStream count upon success. In the case of +// adjustment errors in Config, a fatal error is logged indicating an issue that requires attention. +// Metrics are collected to monitor the duration and number of attempts for stream creation. +// +// Arguments: +// - ctx: Context to control the lifecycle of the stream creation. +// - peerID: The ID of the peer with which the stream is to be established. +// - protocol: The specific protocol used for the stream. +// - dialCfg: Configuration parameters for dialing and stream creation, including retry logic. +// +// Returns: +// - libp2pnet.Stream: The successfully created stream, or nil if the stream creation fails. +// - error: An aggregated multierror of all encountered errors during stream creation, or nil if successful; any returned error is benign and can be retried. +func (m *Manager) createStream(ctx context.Context, peerID peer.ID, protocol protocols.Protocol, dialCfg *Config) (libp2pnet.Stream, error) { var err error var s libp2pnet.Stream - // backoff delay for dial in progress errors; this backoff delay only kicks in if there is no connection to the peer - // and there is already a dial in progress to the peer. - backoff := retry.NewExponential(m.dialInProgressBackoffDelay) - // https://github.com/sethvargo/go-retry#maxretries retries counter starts at zero and library will make last attempt - // when retries == maxAttempts causing 1 more func invocation than expected. - maxRetries := dialCfg.StreamCreationRetryAttemptBudget - backoff = retry.WithMaxRetries(maxRetries, backoff) - - attempts := 0 - // retryable func will attempt to create the stream and only retry if dialing the peer is in progress - f := func(context.Context) error { - attempts++ - s, err = m.rawStreamWithProtocol(ctx, protocol.ProtocolId(), peerID, dialCfg) - if err != nil { - if IsErrDialInProgress(err) { - m.logger.Warn(). - Err(err). - Str("peer_id", p2plogging.PeerId(peerID)). - Int("attempt", attempts). - Uint64("max_retries", maxRetries). 
- Msg("retrying create stream, dial to peer in progress") - return retry.RetryableError(err) - } - return err - } - - s, err = protocol.UpgradeRawStream(s) - if err != nil { - return fmt.Errorf("failed to upgrade raw stream: %w", err) - } - - return nil + s, err = m.createStreamWithRetry(ctx, peerID, protocol.ProtocolId(), dialCfg) + if err != nil { + return nil, fmt.Errorf("failed to create a stream to peer: %w", err) } - start := time.Now() - err = retry.Do(ctx, backoff, f) - duration := time.Since(start) + + s, err = protocol.UpgradeRawStream(s) if err != nil { - m.metrics.OnStreamCreationFailure(duration, attempts) - return nil, err + return nil, fmt.Errorf("failed to upgrade raw stream: %w", err) } - m.metrics.OnStreamCreated(duration, attempts) - updatedConfig, err := m.dialConfigCache.Adjust(peerID, func(config DialConfig) (DialConfig, error) { + updatedConfig, err := m.dialConfigCache.Adjust(peerID, func(config Config) (Config, error) { config.ConsecutiveSuccessfulStream++ // increase consecutive successful stream count. return config, nil }) @@ -314,120 +256,26 @@ func (m *Manager) tryCreateStream(ctx context.Context, peerID peer.ID, protocol return s, nil } -// rawStreamWithProtocol creates a stream raw libp2p stream on specified protocol. +// createStreamWithRetry attempts to create a new stream to the specified peer using the given protocolID. +// This function is streamlined for use-cases where retries are managed externally or +// not required at all. // -// Note: a raw stream must be upgraded by the given unicast protocol id. +// Expected errors: +// - If the context expires before stream creation, it returns a context-related error with the number of attempts. +// - If the protocol ID is not supported, no retries are attempted and the error is returned immediately. // -// It makes at most `maxAttempts` to create a stream with the peer. -// This was put in as a fix for #2416. PubSub and 1-1 communication compete with each other when trying to connect to -// remote nodes and once in a while NewStream returns an error 'both yamux endpoints are clients'. +// Metrics are collected to monitor the duration and attempts of the stream creation process. // -// Note that in case an existing TCP connection underneath to `peerID` exists, that connection is utilized for creating a new stream. -// The multiaddr.Multiaddr return value represents the addresses of `peerID` we dial while trying to create a stream to it, the -// multiaddr is only returned when a peer is initially dialed. -// Expected errors during normal operations: -// - ErrDialInProgress if no connection to the peer exists and there is already a dial in progress to the peer. If a dial to -// the peer is already in progress the caller needs to wait until it is completed, a peer should be dialed only once. +// Arguments: +// - ctx: Context to control the lifecycle of the stream creation. +// - peerID: The ID of the peer with which the stream is to be established. +// - protocolID: The identifier for the protocol used for the stream. +// - dialCfg: Configuration parameters for dialing, including the retry attempt budget. // -// Unexpected errors during normal operations: -// - network.ErrIllegalConnectionState indicates bug in libpp2p when checking IsConnected status of peer. 
-func (m *Manager) rawStreamWithProtocol(ctx context.Context, protocolID protocol.ID, peerID peer.ID, dialCfg *DialConfig) (libp2pnet.Stream, error) { - isConnected, err := m.connStatus.IsConnected(peerID) - if err != nil { - return nil, err - } - - // check connection status and attempt to dial the peer if dialing is not in progress - if !isConnected { - // return error if we can't start dialing - if _, inProgress := m.peerDialing.LoadOrStore(peerID, struct{}{}); inProgress { - return nil, NewDialInProgressErr(peerID) - } - defer m.peerDialing.Delete(peerID) - err := m.dialPeer(ctx, peerID, dialCfg) - if err != nil { - return nil, err - } - } - - // at this point dialing should have completed, we are already connected we can attempt to create the stream - s, err := m.rawStream(ctx, peerID, protocolID, dialCfg) - if err != nil { - return nil, err - } - - return s, nil -} - -// dialPeer dial peer with retries. -// Expected errors during normal operations: -// - ErrMaxRetries if retry attempts are exhausted -func (m *Manager) dialPeer(ctx context.Context, peerID peer.ID, dialCfg *DialConfig) error { - // aggregated retryable errors that occur during retries, errs will be returned - // if retry context times out or maxAttempts have been made before a successful retry occurs - var errs error - dialAttempts := 0 - backoff := retryBackoff(dialCfg.DialRetryAttemptBudget, m.dialBackoffDelay) - f := func(context.Context) error { - dialAttempts++ - select { - case <-ctx.Done(): - return fmt.Errorf("context done before stream could be created (retry attempt: %d, errors: %w)", dialAttempts, errs) - default: - } - err := m.streamFactory.Connect(ctx, peer.AddrInfo{ID: peerID}) - if err != nil { - // if the connection was rejected due to invalid node id or - // if the connection was rejected due to connection gating skip the re-attempt - // if there is no address for the peer skip the re-attempt - if stream.IsErrSecurityProtocolNegotiationFailed(err) || stream.IsErrGaterDisallowedConnection(err) || errors.Is(err, swarm.ErrNoAddresses) { - return multierror.Append(errs, err) - } - m.logger.Warn(). - Err(err). - Str("peer_id", p2plogging.PeerId(peerID)). - Int("attempt", dialAttempts). - Uint64("max_attempts", dialCfg.DialRetryAttemptBudget). - Msg("retrying peer dialing") - return retry.RetryableError(multierror.Append(errs, err)) - } - updatedConfig, err := m.dialConfigCache.Adjust(peerID, func(config DialConfig) (DialConfig, error) { - config.LastSuccessfulDial = time.Now() // update last successful dial time - return config, nil - }) - if err != nil { - // This is not a connection retryable error, this is a fatal error. - // TODO: technically, we better to return an error here, but the error must be irrecoverable, and we cannot - // guarantee a clear distinction between recoverable and irrecoverable errors at the moment with CreateStream. - // We have to revisit this once we studied the error handling paths in the unicast manager. - m.logger.Fatal(). - Err(err). - Bool(logging.KeyNetworkingSecurity, true). - Str("peer_id", p2plogging.PeerId(peerID)). - Msg("failed to adjust dial config for peer id") - } - m.logger.Info(). - Str("peer_id", p2plogging.PeerId(peerID)). - Str("updated_dial_config", fmt.Sprintf("%+v", updatedConfig)). 
- Msg("peer dialed successfully") - return nil - } - - start := time.Now() - err := retry.Do(ctx, backoff, f) - duration := time.Since(start) - if err != nil { - m.metrics.OnPeerDialFailure(duration, dialAttempts) - return retryFailedError(uint64(dialAttempts), dialCfg.DialRetryAttemptBudget, fmt.Errorf("failed to dial peer %s: %w", p2plogging.PeerId(peerID), err)) - } - m.metrics.OnPeerDialed(duration, dialAttempts) - return nil -} - -// rawStream creates a stream to peer with retries. -// Expected errors during normal operations: -// - ErrMaxRetries if retry attempts are exhausted -func (m *Manager) rawStream(ctx context.Context, peerID peer.ID, protocolID protocol.ID, dialCfg *DialConfig) (libp2pnet.Stream, error) { +// Returns: +// - libp2pnet.Stream: The successfully created stream, or nil if an error occurs. +// - error: An error encountered during the stream creation, or nil if the stream is successfully established. +func (m *Manager) createStreamWithRetry(ctx context.Context, peerID peer.ID, protocolID protocol.ID, dialCfg *Config) (libp2pnet.Stream, error) { // aggregated retryable errors that occur during retries, errs will be returned // if retry context times out or maxAttempts have been made before a successful retry occurs var errs error @@ -442,14 +290,14 @@ func (m *Manager) rawStream(ctx context.Context, peerID peer.ID, protocolID prot } var err error - // add libp2p context value NoDial to prevent the underlying host from dialingComplete the peer while creating the stream - // we've already ensured that a connection already exists. - ctx = libp2pnet.WithNoDial(ctx, "application ensured connection to peer exists") // creates stream using stream factory s, err = m.streamFactory.NewStream(ctx, peerID, protocolID) if err != nil { - // if the stream creation failed due to invalid protocol id, skip the re-attempt - if stream.IsErrProtocolNotSupported(err) { + // if the stream creation failed due to invalid protocol id or no address, skip the re-attempt + if stream.IsErrProtocolNotSupported(err) || + errors.Is(err, swarm.ErrNoAddresses) || + stream.IsErrSecurityProtocolNegotiationFailed(err) || + stream.IsErrGaterDisallowedConnection(err) { return err } return retry.RetryableError(multierror.Append(errs, err)) @@ -504,7 +352,7 @@ func retryFailedError(dialAttempts, maxAttempts uint64, err error) error { // Returns: // - dial config for the given peer id. // - error if the dial config cannot be retrieved or adjusted; any error is irrecoverable and indicates a fatal error. -func (m *Manager) getDialConfig(peerID peer.ID) (*DialConfig, error) { +func (m *Manager) getDialConfig(peerID peer.ID) (*Config, error) { dialCfg, err := m.dialConfigCache.GetOrInit(peerID) if err != nil { return nil, fmt.Errorf("failed to get or init dial config for peer id: %w", err) @@ -513,7 +361,7 @@ func (m *Manager) getDialConfig(peerID peer.ID) (*DialConfig, error) { if dialCfg.StreamCreationRetryAttemptBudget == uint64(0) && dialCfg.ConsecutiveSuccessfulStream >= m.streamZeroBackoffResetThreshold { // reset the stream creation backoff budget to the default value if the number of consecutive successful streams reaches the threshold, // as the stream creation is reliable enough to be trusted again. 
- dialCfg, err = m.dialConfigCache.Adjust(peerID, func(config DialConfig) (DialConfig, error) { + dialCfg, err = m.dialConfigCache.Adjust(peerID, func(config Config) (Config, error) { config.StreamCreationRetryAttemptBudget = m.maxStreamCreationAttemptTimes m.metrics.OnStreamCreationRetryBudgetUpdated(config.StreamCreationRetryAttemptBudget) m.metrics.OnStreamCreationRetryBudgetResetToDefault() @@ -523,21 +371,6 @@ func (m *Manager) getDialConfig(peerID peer.ID) (*DialConfig, error) { return nil, fmt.Errorf("failed to adjust dial config for peer id (resetting stream creation attempt budget): %w", err) } } - if dialCfg.DialRetryAttemptBudget == uint64(0) && - !dialCfg.LastSuccessfulDial.IsZero() && // if the last successful dial time is zero, it means that we have never successfully dialed to the peer, so we should not reset the dial backoff budget. - time.Since(dialCfg.LastSuccessfulDial) >= m.dialZeroBackoffResetThreshold { - // reset the dial backoff budget to the default value if the last successful dial was long enough ago, - // as the dialing is reliable enough to be trusted again. - dialCfg, err = m.dialConfigCache.Adjust(peerID, func(config DialConfig) (DialConfig, error) { - config.DialRetryAttemptBudget = m.maxDialAttemptTimes - m.metrics.OnDialRetryBudgetUpdated(config.DialRetryAttemptBudget) - m.metrics.OnDialRetryBudgetResetToDefault() - return config, nil - }) - if err != nil { - return nil, fmt.Errorf("failed to adjust dial config for peer id (resetting dial attempt budget): %w", err) - } - } return dialCfg, nil } @@ -551,28 +384,18 @@ func (m *Manager) getDialConfig(peerID peer.ID) (*DialConfig, error) { // - dial config for the given peer id. // - connected indicates whether there is a connection to the peer. // - error if the dial config cannot be adjusted; any error is irrecoverable and indicates a fatal error. -func (m *Manager) adjustUnsuccessfulStreamAttempt(peerID peer.ID, connected bool) (*DialConfig, error) { - updatedCfg, err := m.dialConfigCache.Adjust(peerID, func(config DialConfig) (DialConfig, error) { +func (m *Manager) adjustUnsuccessfulStreamAttempt(peerID peer.ID) (*Config, error) { + updatedCfg, err := m.dialConfigCache.Adjust(peerID, func(config Config) (Config, error) { // consecutive successful stream count is reset to 0 if we fail to create a stream or connection to the peer. config.ConsecutiveSuccessfulStream = 0 - if !connected { - // if no connections could be established to the peer, we will try to dial with a more strict dial config next time. - if config.DialRetryAttemptBudget > 0 { - config.DialRetryAttemptBudget-- - m.metrics.OnDialRetryBudgetUpdated(config.DialRetryAttemptBudget) - } - // last successful dial time is reset to 0 if we fail to create a stream to the peer. - config.LastSuccessfulDial = time.Time{} - - } else { - // there is a connection to the peer it means that the stream creation failed, hence we decrease the stream backoff budget - // to try to create a stream with a more strict dial config next time. - if config.StreamCreationRetryAttemptBudget > 0 { - config.StreamCreationRetryAttemptBudget-- - m.metrics.OnStreamCreationRetryBudgetUpdated(config.StreamCreationRetryAttemptBudget) - } + // there is a connection to the peer it means that the stream creation failed, hence we decrease the stream backoff budget + // to try to create a stream with a more strict dial config next time. 
+ if config.StreamCreationRetryAttemptBudget > 0 { + config.StreamCreationRetryAttemptBudget-- + m.metrics.OnStreamCreationRetryBudgetUpdated(config.StreamCreationRetryAttemptBudget) } + return config, nil }) diff --git a/network/p2p/unicast/manager_config.go b/network/p2p/unicast/manager_config.go index 438e05daa75..eac00c76611 100644 --- a/network/p2p/unicast/manager_config.go +++ b/network/p2p/unicast/manager_config.go @@ -14,22 +14,11 @@ type ManagerConfig struct { Logger zerolog.Logger `validate:"required"` StreamFactory p2p.StreamFactory `validate:"required"` SporkId flow.Identifier `validate:"required"` - ConnStatus p2p.PeerConnections `validate:"required"` Metrics module.UnicastManagerMetrics `validate:"required"` // CreateStreamBackoffDelay is the backoff delay between retrying stream creations to the same peer. CreateStreamBackoffDelay time.Duration `validate:"gt=0"` - // DialInProgressBackoffDelay is the backoff delay for parallel attempts on dialing to the same peer. - // When the unicast manager is invoked to create stream to the same peer concurrently while there is - // already an ongoing dialing attempt to the same peer, the unicast manager will wait for this backoff delay - // and retry creating the stream after the backoff delay has elapsed. This is to prevent the unicast manager - // from creating too many parallel dialing attempts to the same peer. - DialInProgressBackoffDelay time.Duration `validate:"gt=0"` - - // DialBackoffDelay is the backoff delay between retrying connection to the same peer. - DialBackoffDelay time.Duration `validate:"gt=0"` - // StreamZeroRetryResetThreshold is the threshold that determines when to reset the stream creation retry budget to the default value. // // For example the default value of 100 means that if the stream creation retry budget is decreased to 0, then it will be reset to default value @@ -42,23 +31,9 @@ type ManagerConfig struct { // 100 stream creations are all successful. StreamZeroRetryResetThreshold uint64 `validate:"gt=0"` - // DialZeroRetryResetThreshold is the threshold that determines when to reset the dial retry budget to the default value. - // For example the threshold of 1 hour means that if the dial retry budget is decreased to 0, then it will be reset to default value - // when it has been 1 hour since the last successful dial. - // - // This is to prevent the retry budget from being reset too frequently, as the retry budget is used to gauge the reliability of the dialing a remote peer. - // When the dial retry budget is reset to the default value, it means that the dialing is reliable enough to be trusted again. - // This parameter mandates when the dialing is reliable enough to be trusted again; i.e., when it has been 1 hour since the last successful dial. - // Note that the last dial attempt timestamp is reset to zero when the dial fails, so the value of for example 1 hour means that the dialing to the remote peer is reliable enough that the last - // successful dial attempt was 1 hour ago. - DialZeroRetryResetThreshold time.Duration `validate:"gt=0"` - - // MaxDialRetryAttemptTimes is the maximum number of attempts to be made to connect to a remote node to establish a unicast (1:1) connection before we give up. - MaxDialRetryAttemptTimes uint64 `validate:"gt=0"` - // MaxStreamCreationRetryAttemptTimes is the maximum number of attempts to be made to create a stream to a remote node over a direct unicast (1:1) connection before we give up. 
MaxStreamCreationRetryAttemptTimes uint64 `validate:"gt=0"` - // DialConfigCacheFactory is a factory function to create a new dial config cache. - DialConfigCacheFactory DialConfigCacheFactory `validate:"required"` + // UnicastConfigCacheFactory is a factory function to create a new dial config cache. + UnicastConfigCacheFactory DialConfigCacheFactory `validate:"required"` } diff --git a/network/p2p/unicast/manager_test.go b/network/p2p/unicast/manager_test.go index ebaed0ccb71..32b35cd9dfc 100644 --- a/network/p2p/unicast/manager_test.go +++ b/network/p2p/unicast/manager_test.go @@ -4,10 +4,8 @@ import ( "context" "fmt" "testing" - "time" libp2pnet "github.com/libp2p/go-libp2p/core/network" - "github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/core/protocol" "github.com/libp2p/go-libp2p/p2p/net/swarm" "github.com/stretchr/testify/mock" @@ -15,7 +13,6 @@ import ( "github.com/onflow/flow-go/config" "github.com/onflow/flow-go/module/metrics" - mockmetrics "github.com/onflow/flow-go/module/mock" mockp2p "github.com/onflow/flow-go/network/p2p/mock" p2ptest "github.com/onflow/flow-go/network/p2p/test" "github.com/onflow/flow-go/network/p2p/unicast" @@ -24,20 +21,18 @@ import ( "github.com/onflow/flow-go/utils/unittest" ) -func unicastManagerFixture(t *testing.T) (*unicast.Manager, *mockp2p.StreamFactory, *mockp2p.PeerConnections, unicast.DialConfigCache) { +func unicastManagerFixture(t *testing.T) (*unicast.Manager, *mockp2p.StreamFactory, unicast.ConfigCache) { streamFactory := mockp2p.NewStreamFactory(t) streamFactory.On("SetStreamHandler", mock.AnythingOfType("protocol.ID"), mock.AnythingOfType("network.StreamHandler")).Return().Once() - connStatus := mockp2p.NewPeerConnections(t) cfg, err := config.DefaultConfig() require.NoError(t, err) - dialConfigCache := unicastcache.NewDialConfigCache(cfg.NetworkConfig.UnicastConfig.DialConfigCacheSize, + unicastConfigCache := unicastcache.NewUnicastConfigCache(cfg.NetworkConfig.UnicastConfig.ConfigCacheSize, unittest.Logger(), metrics.NewNoopCollector(), - func() unicast.DialConfig { - return unicast.DialConfig{ - DialRetryAttemptBudget: cfg.NetworkConfig.UnicastConfig.MaxDialRetryAttemptTimes, + func() unicast.Config { + return unicast.Config{ StreamCreationRetryAttemptBudget: cfg.NetworkConfig.UnicastConfig.MaxStreamCreationRetryAttemptTimes, } }) @@ -46,23 +41,18 @@ func unicastManagerFixture(t *testing.T) (*unicast.Manager, *mockp2p.StreamFacto Logger: unittest.Logger(), StreamFactory: streamFactory, SporkId: unittest.IdentifierFixture(), - ConnStatus: connStatus, CreateStreamBackoffDelay: cfg.NetworkConfig.UnicastConfig.CreateStreamBackoffDelay, Metrics: metrics.NewNoopCollector(), StreamZeroRetryResetThreshold: cfg.NetworkConfig.UnicastConfig.StreamZeroRetryResetThreshold, - DialZeroRetryResetThreshold: cfg.NetworkConfig.UnicastConfig.DialZeroRetryResetThreshold, MaxStreamCreationRetryAttemptTimes: cfg.NetworkConfig.UnicastConfig.MaxStreamCreationRetryAttemptTimes, - MaxDialRetryAttemptTimes: cfg.NetworkConfig.UnicastConfig.MaxDialRetryAttemptTimes, - DialInProgressBackoffDelay: cfg.NetworkConfig.UnicastConfig.DialInProgressBackoffDelay, - DialBackoffDelay: cfg.NetworkConfig.UnicastConfig.DialBackoffDelay, - DialConfigCacheFactory: func(func() unicast.DialConfig) unicast.DialConfigCache { - return dialConfigCache + UnicastConfigCacheFactory: func(func() unicast.Config) unicast.ConfigCache { + return unicastConfigCache }, }) require.NoError(t, err) mgr.SetDefaultHandler(func(libp2pnet.Stream) {}) // no-op handler, we don't care 
about the handler for this test - return mgr, streamFactory, connStatus, dialConfigCache + return mgr, streamFactory, unicastConfigCache } // TestManagerConfigValidation tests the validation of the unicast manager config. @@ -75,22 +65,16 @@ func TestManagerConfigValidation(t *testing.T) { Logger: unittest.Logger(), StreamFactory: mockp2p.NewStreamFactory(t), SporkId: unittest.IdentifierFixture(), - ConnStatus: mockp2p.NewPeerConnections(t), CreateStreamBackoffDelay: cfg.NetworkConfig.UnicastConfig.CreateStreamBackoffDelay, Metrics: metrics.NewNoopCollector(), StreamZeroRetryResetThreshold: cfg.NetworkConfig.UnicastConfig.StreamZeroRetryResetThreshold, - DialZeroRetryResetThreshold: cfg.NetworkConfig.UnicastConfig.DialZeroRetryResetThreshold, MaxStreamCreationRetryAttemptTimes: cfg.NetworkConfig.UnicastConfig.MaxStreamCreationRetryAttemptTimes, - MaxDialRetryAttemptTimes: cfg.NetworkConfig.UnicastConfig.MaxDialRetryAttemptTimes, - DialInProgressBackoffDelay: cfg.NetworkConfig.UnicastConfig.DialInProgressBackoffDelay, - DialBackoffDelay: cfg.NetworkConfig.UnicastConfig.DialBackoffDelay, - DialConfigCacheFactory: func(func() unicast.DialConfig) unicast.DialConfigCache { - return unicastcache.NewDialConfigCache(cfg.NetworkConfig.UnicastConfig.DialConfigCacheSize, + UnicastConfigCacheFactory: func(func() unicast.Config) unicast.ConfigCache { + return unicastcache.NewUnicastConfigCache(cfg.NetworkConfig.UnicastConfig.ConfigCacheSize, unittest.Logger(), metrics.NewNoopCollector(), - func() unicast.DialConfig { - return unicast.DialConfig{ - DialRetryAttemptBudget: cfg.NetworkConfig.UnicastConfig.MaxDialRetryAttemptTimes, + func() unicast.Config { + return unicast.Config{ StreamCreationRetryAttemptBudget: cfg.NetworkConfig.UnicastConfig.MaxStreamCreationRetryAttemptTimes, } }) @@ -118,22 +102,6 @@ func TestManagerConfigValidation(t *testing.T) { require.Nil(t, mgr) }) - t.Run("Invalid DialInProgressBackoffDelay", func(t *testing.T) { - cfg := validConfig - cfg.DialInProgressBackoffDelay = 0 - mgr, err := unicast.NewUnicastManager(&cfg) - require.Error(t, err) - require.Nil(t, mgr) - }) - - t.Run("Invalid DialBackoffDelay", func(t *testing.T) { - cfg := validConfig - cfg.DialBackoffDelay = 0 - mgr, err := unicast.NewUnicastManager(&cfg) - require.Error(t, err) - require.Nil(t, mgr) - }) - t.Run("Invalid StreamZeroRetryResetThreshold", func(t *testing.T) { cfg := validConfig cfg.StreamZeroRetryResetThreshold = 0 @@ -142,22 +110,6 @@ func TestManagerConfigValidation(t *testing.T) { require.Nil(t, mgr) }) - t.Run("Invalid DialZeroRetryResetThreshold", func(t *testing.T) { - cfg := validConfig - cfg.DialZeroRetryResetThreshold = 0 - mgr, err := unicast.NewUnicastManager(&cfg) - require.Error(t, err) - require.Nil(t, mgr) - }) - - t.Run("Invalid MaxDialRetryAttemptTimes", func(t *testing.T) { - cfg := validConfig - cfg.MaxDialRetryAttemptTimes = 0 - mgr, err := unicast.NewUnicastManager(&cfg) - require.Error(t, err) - require.Nil(t, mgr) - }) - t.Run("Invalid MaxStreamCreationRetryAttemptTimes", func(t *testing.T) { cfg := validConfig cfg.MaxStreamCreationRetryAttemptTimes = 0 @@ -166,9 +118,9 @@ func TestManagerConfigValidation(t *testing.T) { require.Nil(t, mgr) }) - t.Run("Invalid DialConfigCacheFactory", func(t *testing.T) { + t.Run("Invalid UnicastConfigCacheFactory", func(t *testing.T) { cfg := validConfig - cfg.DialConfigCacheFactory = nil + cfg.UnicastConfigCacheFactory = nil mgr, err := unicast.NewUnicastManager(&cfg) require.Error(t, err) require.Nil(t, mgr) @@ -182,14 +134,6 @@ func 
TestManagerConfigValidation(t *testing.T) { require.Nil(t, mgr) }) - t.Run("Missing ConnStatus", func(t *testing.T) { - cfg := validConfig - cfg.ConnStatus = nil - mgr, err := unicast.NewUnicastManager(&cfg) - require.Error(t, err) - require.Nil(t, mgr) - }) - t.Run("Missing Metrics", func(t *testing.T) { cfg := validConfig cfg.Metrics = nil @@ -199,140 +143,79 @@ func TestManagerConfigValidation(t *testing.T) { }) } -// TestUnicastManager_StreamFactory_ConnectionBackoff tests the backoff mechanism of the unicast manager for connection creation. -// It tests that when there is no connection, it tries to connect to the peer some number of times (unicastmodel.MaxDialAttemptTimes), before -// giving up. -func TestUnicastManager_Connection_ConnectionBackoff(t *testing.T) { - peerID := unittest.PeerIdFixture(t) - mgr, streamFactory, connStatus, dialConfigCache := unicastManagerFixture(t) - - cfg, err := config.DefaultConfig() - require.NoError(t, err) - - connStatus.On("IsConnected", peerID).Return(false, nil) // not connected - streamFactory.On("Connect", mock.Anything, peer.AddrInfo{ID: peerID}). - Return(fmt.Errorf("some error")).Times(int(cfg.NetworkConfig.UnicastConfig.MaxDialRetryAttemptTimes + 1)) // connect - - _, err = dialConfigCache.Adjust(peerID, func(dialConfig unicast.DialConfig) (unicast.DialConfig, error) { - // assumes that there was a successful connection to the peer before (2 minutes ago), and now the connection is lost. - dialConfig.LastSuccessfulDial = time.Now().Add(2 * time.Minute) - return dialConfig, nil - }) - require.NoError(t, err) - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - s, err := mgr.CreateStream(ctx, peerID) - require.Error(t, err) - require.Nil(t, s) - - // The dial config must be updated with the backoff budget decremented. - dialCfg, err := dialConfigCache.GetOrInit(peerID) - require.NoError(t, err) - require.Equal(t, cfg.NetworkConfig.UnicastConfig.MaxDialRetryAttemptTimes-1, dialCfg.DialRetryAttemptBudget) // dial backoff budget must be decremented by 1. - require.Equal(t, - cfg.NetworkConfig.UnicastConfig.MaxStreamCreationRetryAttemptTimes, - dialCfg.StreamCreationRetryAttemptBudget) // stream backoff budget must remain intact (no stream creation attempt yet). - // last successful dial is set back to zero, since although we have a successful dial in the past, the most recent dial failed. - require.True(t, dialCfg.LastSuccessfulDial.IsZero()) - require.Equal(t, uint64(0), dialCfg.ConsecutiveSuccessfulStream) // consecutive successful stream must be intact. -} - -// TestUnicastManager_StreamFactory_Connection_SuccessfulConnection_And_Stream tests that when there is no connection, and CreateStream is successful on the first attempt for connection and stream creation, -// it updates the last successful dial time and the consecutive successful stream counter. -func TestUnicastManager_Connection_SuccessfulConnection_And_Stream(t *testing.T) { +// TestUnicastManager_SuccessfulStream tests that when CreateStream is successful on the first attempt for stream creation, +// it updates the consecutive successful stream counter. 
+func TestUnicastManager_SuccessfulStream(t *testing.T) { peerID := unittest.PeerIdFixture(t) - mgr, streamFactory, connStatus, dialConfigCache := unicastManagerFixture(t) + mgr, streamFactory, configCache := unicastManagerFixture(t) cfg, err := config.DefaultConfig() require.NoError(t, err) - connStatus.On("IsConnected", peerID).Return(false, nil) // not connected - streamFactory.On("Connect", mock.Anything, peer.AddrInfo{ID: peerID}).Return(nil).Once() // connect on the first attempt. - // mocks that it attempts to create a stream once and succeeds. streamFactory.On("NewStream", mock.Anything, peerID, mock.Anything).Return(&p2ptest.MockStream{}, nil).Once() ctx, cancel := context.WithCancel(context.Background()) defer cancel() - dialTime := time.Now() s, err := mgr.CreateStream(ctx, peerID) require.NoError(t, err) require.NotNil(t, s) - // The dial config must be updated with the backoff budget decremented. - dialCfg, err := dialConfigCache.GetOrInit(peerID) + // The unicast config must be updated with the backoff budget decremented. + unicastCfg, err := configCache.GetOrInit(peerID) require.NoError(t, err) - require.Equal(t, cfg.NetworkConfig.UnicastConfig.MaxDialRetryAttemptTimes, dialCfg.DialRetryAttemptBudget) // dial backoff budget must be intact. require.Equal(t, cfg.NetworkConfig.UnicastConfig.MaxStreamCreationRetryAttemptTimes, - dialCfg.StreamCreationRetryAttemptBudget) // stream backoff budget must remain intact. - // last successful dial must be set AFTER the successful dial. - require.True(t, dialCfg.LastSuccessfulDial.After(dialTime)) - require.Equal(t, uint64(1), dialCfg.ConsecutiveSuccessfulStream) // consecutive successful stream must incremented. + unicastCfg.StreamCreationRetryAttemptBudget) // stream backoff budget must remain intact. + require.Equal(t, uint64(1), unicastCfg.ConsecutiveSuccessfulStream) // consecutive successful stream must incremented. } -// TestUnicastManager_StreamFactory_Connection_SuccessfulConnection_StreamBackoff tests the backoff mechanism of the unicast manager for stream creation. -// It tests the situation that there is no connection when CreateStream is called. The connection is created successfully, but the stream creation fails. -// It tests that when there is a connection, but no stream, it tries to create a stream some number of times (unicastmodel.MaxStreamCreationAttemptTimes), before -// giving up. -// It also checks the consecutive successful stream counter is reset when the stream creation fails, and the last successful dial time is updated. -func TestUnicastManager_Connection_SuccessfulConnection_StreamBackoff(t *testing.T) { +// TestUnicastManager_StreamBackoff tests the backoff mechanism of the unicast manager for stream creation. +// It tests the situation that CreateStream is called but the stream creation fails. +// It tests that it tries to create a stream some number of times (unicastmodel.MaxStreamCreationAttemptTimes), before giving up. +// It also checks the consecutive successful stream counter is reset when the stream creation fails. 
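// NOTE (editorial addition, not part of this change): the mocks below expect NewStream to be called
// MaxStreamCreationRetryAttemptTimes+1 times because go-retry counts retries from zero and still
// performs a final attempt when retries == maxRetries, so a retry budget of N yields N+1 invocations
// of the attempt function.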
+func TestUnicastManager_StreamBackoff(t *testing.T) { peerID := unittest.PeerIdFixture(t) - mgr, streamFactory, connStatus, dialConfigCache := unicastManagerFixture(t) + mgr, streamFactory, configCache := unicastManagerFixture(t) cfg, err := config.DefaultConfig() require.NoError(t, err) - isConnectedCalled := 0 - connStatus.On("IsConnected", peerID).Return(func(id peer.ID) bool { - if isConnectedCalled == 0 { - // we mock that the connection is not established on the first call, and is established on the second call and onwards. - isConnectedCalled++ - return false - } - return true - }, nil) - streamFactory.On("Connect", mock.Anything, peer.AddrInfo{ID: peerID}).Return(nil).Once() // connect on the first attempt. - streamFactory.On("NewStream", mock.Anything, peerID, mock.Anything).Return(nil, fmt.Errorf("some error")). - Times(int(cfg.NetworkConfig.UnicastConfig.MaxStreamCreationRetryAttemptTimes + 1)) // mocks that it attempts to create a stream some number of times, before giving up. + // mocks that it attempts to create a stream some number of times, before giving up. + streamFactory.On("NewStream", mock.Anything, peerID, mock.Anything). + Return(nil, fmt.Errorf("some error")). + Times(int(cfg.NetworkConfig.UnicastConfig.MaxStreamCreationRetryAttemptTimes + 1)) ctx, cancel := context.WithCancel(context.Background()) defer cancel() - dialTime := time.Now() s, err := mgr.CreateStream(ctx, peerID) require.Error(t, err) require.Nil(t, s) - // The dial config must be updated with the backoff budget decremented. - dialCfg, err := dialConfigCache.GetOrInit(peerID) + // The unicast config must be updated with the backoff budget decremented. + unicastCfg, err := configCache.GetOrInit(peerID) require.NoError(t, err) - require.Equal(t, - cfg.NetworkConfig.UnicastConfig.MaxDialRetryAttemptTimes, - dialCfg.DialRetryAttemptBudget) // dial backoff budget must be intact, since the connection is successful. - require.Equal(t, - cfg.NetworkConfig.UnicastConfig.MaxStreamCreationRetryAttemptTimes-1, - dialCfg.StreamCreationRetryAttemptBudget) // stream backoff budget must be decremented by 1 since all budget is used up. - // last successful dial must be set AFTER the successful dial. - require.True(t, dialCfg.LastSuccessfulDial.After(dialTime)) - require.Equal(t, uint64(0), dialCfg.ConsecutiveSuccessfulStream) // consecutive successful stream must be reset to zero, since the stream creation failed. + // stream backoff budget must be decremented by 1 since all budget is used up. + require.Equal(t, cfg.NetworkConfig.UnicastConfig.MaxStreamCreationRetryAttemptTimes-1, unicastCfg.StreamCreationRetryAttemptBudget) + // consecutive successful stream must be reset to zero, since the stream creation failed. + require.Equal(t, uint64(0), unicastCfg.ConsecutiveSuccessfulStream) } // TestUnicastManager_StreamFactory_StreamBackoff tests the backoff mechanism of the unicast manager for stream creation. // It tests when there is a connection, but no stream, it tries to create a stream some number of times (unicastmodel.MaxStreamCreationAttemptTimes), before // giving up. func TestUnicastManager_StreamFactory_StreamBackoff(t *testing.T) { - mgr, streamFactory, connStatus, dialConfigCache := unicastManagerFixture(t) + mgr, streamFactory, unicastConfigCache := unicastManagerFixture(t) peerID := unittest.PeerIdFixture(t) cfg, err := config.DefaultConfig() require.NoError(t, err) - connStatus.On("IsConnected", peerID).Return(true, nil) // connected. 
+ // mocks that it attempts to create a stream some number of times, before giving up. streamFactory.On("NewStream", mock.Anything, peerID, mock.Anything). Return(nil, fmt.Errorf("some error")). - Times(int(cfg.NetworkConfig.UnicastConfig.MaxStreamCreationRetryAttemptTimes + 1)) // mocks that it attempts to create a stream some number of times, before giving up. + Times(int(cfg.NetworkConfig.UnicastConfig.MaxStreamCreationRetryAttemptTimes + 1)) ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -340,22 +223,19 @@ func TestUnicastManager_StreamFactory_StreamBackoff(t *testing.T) { require.Error(t, err) require.Nil(t, s) - // The dial config must be updated with the stream backoff budget decremented. - dialCfg, err := dialConfigCache.GetOrInit(peerID) + // The unicast config must be updated with the stream backoff budget decremented. + unicastCfg, err := unicastConfigCache.GetOrInit(peerID) require.NoError(t, err) - require.Equal(t, cfg.NetworkConfig.UnicastConfig.MaxDialRetryAttemptTimes, dialCfg.DialRetryAttemptBudget) // dial backoff budget must be intact. - require.Equal(t, - cfg.NetworkConfig.UnicastConfig.MaxStreamCreationRetryAttemptTimes-1, - dialCfg.StreamCreationRetryAttemptBudget) // stream backoff budget must be decremented by 1. - require.Equal(t, - uint64(0), - dialCfg.ConsecutiveSuccessfulStream) // consecutive successful stream must be zero as we have not created a successful stream yet. + // stream backoff budget must be decremented by 1. + require.Equal(t, cfg.NetworkConfig.UnicastConfig.MaxStreamCreationRetryAttemptTimes-1, unicastCfg.StreamCreationRetryAttemptBudget) + // consecutive successful stream must be zero as we have not created a successful stream yet. + require.Equal(t, uint64(0), unicastCfg.ConsecutiveSuccessfulStream) } -// TestUnicastManager_Stream_ConsecutiveStreamCreation_Increment tests that when there is a connection, and the stream creation is successful, -// it increments the consecutive successful stream counter in the dial config. +// TestUnicastManager_Stream_ConsecutiveStreamCreation_Increment tests that when stream creation is successful, +// it increments the consecutive successful stream counter in the unicast config. func TestUnicastManager_Stream_ConsecutiveStreamCreation_Increment(t *testing.T) { - mgr, streamFactory, connStatus, dialConfigCache := unicastManagerFixture(t) + mgr, streamFactory, unicastConfigCache := unicastManagerFixture(t) peerID := unittest.PeerIdFixture(t) cfg, err := config.DefaultConfig() @@ -364,7 +244,6 @@ func TestUnicastManager_Stream_ConsecutiveStreamCreation_Increment(t *testing.T) // total times we successfully create a stream to the peer. totalSuccessAttempts := 10 - connStatus.On("IsConnected", peerID).Return(true, nil) // connected. // mocks that it attempts to create a stream 10 times, and each time it succeeds. streamFactory.On("NewStream", mock.Anything, peerID, mock.Anything).Return(&p2ptest.MockStream{}, nil).Times(totalSuccessAttempts) @@ -376,38 +255,37 @@ func TestUnicastManager_Stream_ConsecutiveStreamCreation_Increment(t *testing.T) require.NoError(t, err) require.NotNil(t, s) - // The dial config must be updated with the stream backoff budget decremented. - dialCfg, err := dialConfigCache.GetOrInit(peerID) + // The unicast config must be updated with the stream backoff budget decremented. 
+	unicastCfg, err := unicastConfigCache.GetOrInit(peerID)
 	require.NoError(t, err)
-	require.Equal(t, cfg.NetworkConfig.UnicastConfig.MaxDialRetryAttemptTimes, dialCfg.DialRetryAttemptBudget) // dial backoff budget must be intact.
 	// stream backoff budget must be intact (all stream creation attempts are successful).
-	require.Equal(t, cfg.NetworkConfig.UnicastConfig.MaxStreamCreationRetryAttemptTimes, dialCfg.StreamCreationRetryAttemptBudget)
-	require.Equal(t, uint64(i+1), dialCfg.ConsecutiveSuccessfulStream) // consecutive successful stream must be incremented.
+	require.Equal(t, cfg.NetworkConfig.UnicastConfig.MaxStreamCreationRetryAttemptTimes, unicastCfg.StreamCreationRetryAttemptBudget)
+	// consecutive successful stream must be incremented.
+	require.Equal(t, uint64(i+1), unicastCfg.ConsecutiveSuccessfulStream)
 	}
 }

-// TestUnicastManager_Stream_ConsecutiveStreamCreation_Reset tests that when there is a connection, and the stream creation fails, it resets
-// the consecutive successful stream counter in the dial config.
+// TestUnicastManager_Stream_ConsecutiveStreamCreation_Reset tests that when the stream creation fails, it resets
+// the consecutive successful stream counter in the unicast config.
 func TestUnicastManager_Stream_ConsecutiveStreamCreation_Reset(t *testing.T) {
-	mgr, streamFactory, connStatus, dialConfigCache := unicastManagerFixture(t)
+	mgr, streamFactory, unicastConfigCache := unicastManagerFixture(t)
 	peerID := unittest.PeerIdFixture(t)

-	cfg, err := config.DefaultConfig()
-	require.NoError(t, err)
-
+	// mocks that it attempts to create a stream once and fails.
 	streamFactory.On("NewStream", mock.Anything, peerID, mock.Anything).
 		Return(nil, fmt.Errorf("some error")).
-		Once() // mocks that it attempts to create a stream once and fails.
-	connStatus.On("IsConnected", peerID).Return(true, nil) // connected.
+		Once()

-	adjustedDialConfig, err := dialConfigCache.Adjust(peerID, func(dialConfig unicast.DialConfig) (unicast.DialConfig, error) {
-		dialConfig.ConsecutiveSuccessfulStream = 5 // sets the consecutive successful stream to 5 meaning that the last 5 stream creation attempts were successful.
-		dialConfig.StreamCreationRetryAttemptBudget = 0 // sets the stream back budget to 0 meaning that the stream backoff budget is exhausted.
+	adjustedUnicastConfig, err := unicastConfigCache.Adjust(peerID, func(unicastConfig unicast.Config) (unicast.Config, error) {
+		// sets the consecutive successful stream to 5, meaning that the last 5 stream creation attempts were successful.
+		unicastConfig.ConsecutiveSuccessfulStream = 5
+		// sets the stream backoff budget to 0, meaning that the stream backoff budget is exhausted.
+		unicastConfig.StreamCreationRetryAttemptBudget = 0

-		return dialConfig, nil
+		return unicastConfig, nil
 	})
 	require.NoError(t, err)
-	require.Equal(t, uint64(5), adjustedDialConfig.ConsecutiveSuccessfulStream)
+	require.Equal(t, uint64(5), adjustedUnicastConfig.ConsecutiveSuccessfulStream)

 	ctx, cancel := context.WithCancel(context.Background())
 	defer cancel()
@@ -416,25 +294,25 @@ func TestUnicastManager_Stream_ConsecutiveStreamCreation_Reset(t *testing.T) {
 	require.Error(t, err)
 	require.Nil(t, s)

-	// The dial config must be updated with the stream backoff budget decremented.
-	dialCfg, err := dialConfigCache.GetOrInit(peerID)
+	// The unicast config must reflect the failed stream creation attempt.
+ unicastCfg, err := unicastConfigCache.GetOrInit(peerID) require.NoError(t, err) - require.Equal(t, cfg.NetworkConfig.UnicastConfig.MaxDialRetryAttemptTimes, dialCfg.DialRetryAttemptBudget) // dial backoff budget must be intact. - require.Equal(t, - uint64(0), - dialCfg.StreamCreationRetryAttemptBudget) // stream backoff budget must be intact (we can't decrement it below 0). - require.Equal(t, uint64(0), dialCfg.ConsecutiveSuccessfulStream) // consecutive successful stream must be reset to 0. + + // stream backoff budget must be intact (we can't decrement it below 0). + require.Equal(t, uint64(0), unicastCfg.StreamCreationRetryAttemptBudget) + // consecutive successful stream must be reset to 0. + require.Equal(t, uint64(0), unicastCfg.ConsecutiveSuccessfulStream) } // TestUnicastManager_StreamFactory_ErrProtocolNotSupported tests that when there is a protocol not supported error, it does not retry creating a stream. func TestUnicastManager_StreamFactory_ErrProtocolNotSupported(t *testing.T) { - mgr, streamFactory, connStatus, _ := unicastManagerFixture(t) + mgr, streamFactory, _ := unicastManagerFixture(t) peerID := unittest.PeerIdFixture(t) - connStatus.On("IsConnected", peerID).Return(true, nil) // connected + // mocks that upon creating a stream, it returns a protocol not supported error, the mock is set to once, meaning that it won't retry stream creation again. streamFactory.On("NewStream", mock.Anything, peerID, mock.Anything). - Return(nil, stream.NewProtocolNotSupportedErr(peerID, []protocol.ID{"protocol-1"}, fmt.Errorf("some error"))). - Once() // mocks that upon creating a stream, it returns a protocol not supported error, the mock is set to once, meaning that it won't retry stream creation again. + Return(nil, stream.NewProtocolNotSupportedErr(peerID, protocol.ID("protocol-1"), fmt.Errorf("some error"))). + Once() ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -443,20 +321,19 @@ func TestUnicastManager_StreamFactory_ErrProtocolNotSupported(t *testing.T) { require.Nil(t, s) } -// TestUnicastManager_StreamFactory_ErrNoAddresses tests that when dialing returns a no addresses error, it does not retry dialing again and returns an error immediately. +// TestUnicastManager_StreamFactory_ErrNoAddresses tests that when stream creation returns a no addresses error, +// it does not retry stream creation again and returns an error immediately. func TestUnicastManager_StreamFactory_ErrNoAddresses(t *testing.T) { - mgr, streamFactory, connStatus, dialConfigCache := unicastManagerFixture(t) + mgr, streamFactory, unicastConfigCache := unicastManagerFixture(t) cfg, err := config.DefaultConfig() require.NoError(t, err) peerID := unittest.PeerIdFixture(t) - // mocks that the connection is not established. - connStatus.On("IsConnected", peerID).Return(false, nil) - // mocks that dialing the peer returns a no addresses error, and the mock is set to once, meaning that it won't retry dialing again. - streamFactory.On("Connect", mock.Anything, peer.AddrInfo{ID: peerID}). - Return(fmt.Errorf("some error to ensure wrapping works fine: %w", swarm.ErrNoAddresses)). + // mocks that stream creation returns a no addresses error, and the mock is set to once, meaning that it won't retry stream creation again. + streamFactory.On("NewStream", mock.Anything, peerID, mock.Anything). + Return(nil, fmt.Errorf("some error to ensure wrapping works fine: %w", swarm.ErrNoAddresses)). 
Once() ctx, cancel := context.WithCancel(context.Background()) @@ -465,32 +342,27 @@ func TestUnicastManager_StreamFactory_ErrNoAddresses(t *testing.T) { require.Error(t, err) require.Nil(t, s) - dialCfg, err := dialConfigCache.GetOrInit(peerID) + unicastCfg, err := unicastConfigCache.GetOrInit(peerID) require.NoError(t, err) - // dial backoff budget must be decremented by 1 (although we didn't have a backoff attempt, the connection was unsuccessful). - require.Equal(t, cfg.NetworkConfig.UnicastConfig.MaxDialRetryAttemptTimes-1, dialCfg.DialRetryAttemptBudget) - // stream backoff budget must remain intact, as we have not tried to create a stream yet. - require.Equal(t, cfg.NetworkConfig.UnicastConfig.MaxStreamCreationRetryAttemptTimes, dialCfg.StreamCreationRetryAttemptBudget) - // last successful dial must be set to zero. - require.True(t, dialCfg.LastSuccessfulDial.IsZero()) + + // stream backoff budget must be reduced by 1 due to failed stream creation. + require.Equal(t, cfg.NetworkConfig.UnicastConfig.MaxStreamCreationRetryAttemptTimes-1, unicastCfg.StreamCreationRetryAttemptBudget) // consecutive successful stream must be set to zero. - require.Equal(t, uint64(0), dialCfg.ConsecutiveSuccessfulStream) + require.Equal(t, uint64(0), unicastCfg.ConsecutiveSuccessfulStream) } -// TestUnicastManager_Dial_ErrSecurityProtocolNegotiationFailed tests that when there is a security protocol negotiation error, it does not retry dialing. -func TestUnicastManager_Dial_ErrSecurityProtocolNegotiationFailed(t *testing.T) { - mgr, streamFactory, connStatus, dialConfigCache := unicastManagerFixture(t) +// TestUnicastManager_Stream_ErrSecurityProtocolNegotiationFailed tests that when there is a security protocol negotiation error, it does not retry stream creation. +func TestUnicastManager_Stream_ErrSecurityProtocolNegotiationFailed(t *testing.T) { + mgr, streamFactory, unicastConfigCache := unicastManagerFixture(t) cfg, err := config.DefaultConfig() require.NoError(t, err) peerID := unittest.PeerIdFixture(t) - // mocks that the connection is not established. - connStatus.On("IsConnected", peerID).Return(false, nil) - // mocks that dialing the peer returns a security protocol negotiation error, and the mock is set to once, meaning that it won't retry dialing again. - streamFactory.On("Connect", mock.Anything, peer.AddrInfo{ID: peerID}). - Return(stream.NewSecurityProtocolNegotiationErr(peerID, fmt.Errorf("some error"))). + // mocks that stream creation returns a security protocol negotiation error, and the mock is set to once, meaning that it won't retry stream creation. + streamFactory.On("NewStream", mock.Anything, peerID, mock.Anything). + Return(nil, stream.NewSecurityProtocolNegotiationErr(peerID, fmt.Errorf("some error"))). Once() ctx, cancel := context.WithCancel(context.Background()) @@ -499,31 +371,25 @@ func TestUnicastManager_Dial_ErrSecurityProtocolNegotiationFailed(t *testing.T) require.Error(t, err) require.Nil(t, s) - dialCfg, err := dialConfigCache.GetOrInit(peerID) + unicastCfg, err := unicastConfigCache.GetOrInit(peerID) require.NoError(t, err) - // dial backoff budget must be decremented by 1 (although we didn't have a backoff attempt, the connection was unsuccessful). - require.Equal(t, cfg.NetworkConfig.UnicastConfig.MaxDialRetryAttemptTimes-1, dialCfg.DialRetryAttemptBudget) - // stream backoff budget must remain intact, as we have not tried to create a stream yet. 
-	require.Equal(t, cfg.NetworkConfig.UnicastConfig.MaxStreamCreationRetryAttemptTimes, dialCfg.StreamCreationRetryAttemptBudget)
-	// last successful dial must be set to zero.
-	require.True(t, dialCfg.LastSuccessfulDial.IsZero())
+	// stream retry budget must be decremented by 1 (since we didn't have a successful stream creation, the budget is decremented).
+	require.Equal(t, cfg.NetworkConfig.UnicastConfig.MaxStreamCreationRetryAttemptTimes-1, unicastCfg.StreamCreationRetryAttemptBudget)
 	// consecutive successful stream must be set to zero.
-	require.Equal(t, uint64(0), dialCfg.ConsecutiveSuccessfulStream)
+	require.Equal(t, uint64(0), unicastCfg.ConsecutiveSuccessfulStream)
 }

-// TestUnicastManager_Dial_ErrGaterDisallowedConnection tests that when there is a connection gater disallow listing error, it does not retry dialing.
-func TestUnicastManager_Dial_ErrGaterDisallowedConnection(t *testing.T) {
-	mgr, streamFactory, connStatus, dialConfigCache := unicastManagerFixture(t)
+// TestUnicastManager_StreamFactory_ErrGaterDisallowedConnection tests that when there is a connection-gater disallow-listing error, it does not retry stream creation.
+func TestUnicastManager_StreamFactory_ErrGaterDisallowedConnection(t *testing.T) {
+	mgr, streamFactory, unicastConfigCache := unicastManagerFixture(t)
 	peerID := unittest.PeerIdFixture(t)
-	// mocks that the connection is not established.
-	connStatus.On("IsConnected", peerID).Return(false, nil)
 	cfg, err := config.DefaultConfig()
 	require.NoError(t, err)

-	// mocks that dialing the peer returns a security protocol negotiation error, and the mock is set to once, meaning that it won't retry dialing again.
-	streamFactory.On("Connect", mock.Anything, peer.AddrInfo{ID: peerID}).
-		Return(stream.NewGaterDisallowedConnectionErr(fmt.Errorf("some error"))).
+	// mocks that stream creation to the peer returns a connection-gater disallow-listing error, and the mock is set to once, meaning that it won't retry stream creation.
+	streamFactory.On("NewStream", mock.Anything, peerID, mock.Anything).
+		Return(nil, stream.NewGaterDisallowedConnectionErr(fmt.Errorf("some error"))).
 		Once()

 	ctx, cancel := context.WithCancel(context.Background())
@@ -532,85 +398,18 @@ func TestUnicastManager_Dial_ErrGaterDisallowedConnection(t *testing.T) {
 	require.Error(t, err)
 	require.Nil(t, s)

-	dialCfg, err := dialConfigCache.GetOrInit(peerID)
+	unicastCfg, err := unicastConfigCache.GetOrInit(peerID)
 	require.NoError(t, err)
-	// dial backoff budget must be decremented by 1 (although we didn't have a backoff attempt, the connection was unsuccessful).
-	require.Equal(t, cfg.NetworkConfig.UnicastConfig.MaxDialRetryAttemptTimes-1, dialCfg.DialRetryAttemptBudget)
-	// stream backoff budget must remain intact, as we have not tried to create a stream yet.
-	require.Equal(t, cfg.NetworkConfig.UnicastConfig.MaxStreamCreationRetryAttemptTimes, dialCfg.StreamCreationRetryAttemptBudget)
-	// last successful dial must be set to zero.
-	require.True(t, dialCfg.LastSuccessfulDial.IsZero())
+	// stream backoff budget must be reduced by 1 due to failed stream creation.
+	require.Equal(t, cfg.NetworkConfig.UnicastConfig.MaxStreamCreationRetryAttemptTimes-1, unicastCfg.StreamCreationRetryAttemptBudget)
 	// consecutive successful stream must be set to zero.
- require.Equal(t, uint64(0), dialCfg.ConsecutiveSuccessfulStream) + require.Equal(t, uint64(0), unicastCfg.ConsecutiveSuccessfulStream) } -// TestUnicastManager_Connection_BackoffBudgetDecremented tests that everytime the unicast manger gives up on creating a connection (after retrials), -// it decrements the backoff budget for the remote peer. -func TestUnicastManager_Connection_BackoffBudgetDecremented(t *testing.T) { - mgr, streamFactory, connStatus, dialConfigCache := unicastManagerFixture(t) - peerID := unittest.PeerIdFixture(t) - - cfg, err := config.DefaultConfig() - require.NoError(t, err) - - // totalAttempts is the total number of times that unicast manager calls Connect on the stream factory to dial the peer. - // Let's consider x = unicastmodel.MaxDialRetryAttemptTimes + 1. Then the test tries x times CreateStream. With dynamic backoffs, - // the first CreateStream call will try to Connect x times, the second CreateStream call will try to Connect x-1 times, - // and so on. So the total number of Connect calls is x + (x-1) + (x-2) + ... + 1 = x(x+1)/2. - // However, we also attempt one more time at the end of the test to CreateStream, when the backoff budget is 0. - maxDialRetryAttemptBudget := int(cfg.NetworkConfig.UnicastConfig.MaxDialRetryAttemptTimes) - attemptTimes := maxDialRetryAttemptBudget + 1 // 1 attempt + retry times - totalAttempts := attemptTimes * (attemptTimes + 1) / 2 - - connStatus.On("IsConnected", peerID).Return(false, nil) // not connected - streamFactory.On("Connect", mock.Anything, peer.AddrInfo{ID: peerID}). - Return(fmt.Errorf("some error")). - Times(int(totalAttempts)) - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - for i := 0; i < maxDialRetryAttemptBudget; i++ { - s, err := mgr.CreateStream(ctx, peerID) - require.Error(t, err) - require.Nil(t, s) - - dialCfg, err := dialConfigCache.GetOrInit(peerID) - require.NoError(t, err) - - if i == maxDialRetryAttemptBudget-1 { - require.Equal(t, uint64(0), dialCfg.DialRetryAttemptBudget) - } else { - require.Equal(t, uint64(maxDialRetryAttemptBudget-i-1), dialCfg.DialRetryAttemptBudget) - } - - // The stream backoff budget must remain intact, as we have not tried to create a stream yet. - require.Equal(t, cfg.NetworkConfig.UnicastConfig.MaxStreamCreationRetryAttemptTimes, dialCfg.StreamCreationRetryAttemptBudget) - } - // At this time the backoff budget for connection must be 0. - dialCfg, err := dialConfigCache.GetOrInit(peerID) - require.NoError(t, err) - - require.Equal(t, uint64(0), dialCfg.DialRetryAttemptBudget) - // The stream backoff budget must remain intact, as we have not tried to create a stream yet. - require.Equal(t, cfg.NetworkConfig.UnicastConfig.MaxStreamCreationRetryAttemptTimes, dialCfg.StreamCreationRetryAttemptBudget) - - // After all the backoff budget is used up, it should stay at 0. - s, err := mgr.CreateStream(ctx, peerID) - require.Error(t, err) - require.Nil(t, s) - - dialCfg, err = dialConfigCache.GetOrInit(peerID) - require.NoError(t, err) - require.Equal(t, uint64(0), dialCfg.DialRetryAttemptBudget) - - // The stream backoff budget must remain intact, as we have not tried to create a stream yet. 
-	require.Equal(t, cfg.NetworkConfig.UnicastConfig.MaxStreamCreationRetryAttemptTimes, dialCfg.StreamCreationRetryAttemptBudget)
-}
-
-// TestUnicastManager_Connection_BackoffBudgetDecremented tests that everytime the unicast manger gives up on creating a connection (after retrials),
+// TestUnicastManager_Stream_BackoffBudgetDecremented tests that every time the unicast manager gives up on creating a stream (after retries),
 // it decrements the backoff budget for the remote peer.
 func TestUnicastManager_Stream_BackoffBudgetDecremented(t *testing.T) {
-	mgr, streamFactory, connStatus, dialConfigCache := unicastManagerFixture(t)
+	mgr, streamFactory, unicastConfigCache := unicastManagerFixture(t)
 	peerID := unittest.PeerIdFixture(t)

 	cfg, err := config.DefaultConfig()
@@ -625,7 +424,6 @@ func TestUnicastManager_Stream_BackoffBudgetDecremented(t *testing.T) {
 	maxStreamAttempt := maxStreamRetryBudget + 1 // 1 attempt + retry times
 	totalAttempts := maxStreamAttempt * (maxStreamAttempt + 1) / 2

-	connStatus.On("IsConnected", peerID).Return(true, nil) // not connected
 	streamFactory.On("NewStream", mock.Anything, peerID, mock.Anything).
 		Return(nil, fmt.Errorf("some error")).
 		Times(int(totalAttempts))
@@ -637,57 +435,47 @@ func TestUnicastManager_Stream_BackoffBudgetDecremented(t *testing.T) {
 		require.Error(t, err)
 		require.Nil(t, s)

-		dialCfg, err := dialConfigCache.GetOrInit(peerID)
+		unicastCfg, err := unicastConfigCache.GetOrInit(peerID)
 		require.NoError(t, err)

 		if i == int(maxStreamRetryBudget)-1 {
-			require.Equal(t, uint64(0), dialCfg.StreamCreationRetryAttemptBudget)
+			require.Equal(t, uint64(0), unicastCfg.StreamCreationRetryAttemptBudget)
 		} else {
-			require.Equal(t, maxStreamRetryBudget-uint64(i)-1, dialCfg.StreamCreationRetryAttemptBudget)
+			require.Equal(t, maxStreamRetryBudget-uint64(i)-1, unicastCfg.StreamCreationRetryAttemptBudget)
 		}
-
-		// The dial backoff budget must remain intact, as we have not tried to create a stream yet.
-		require.Equal(t, cfg.NetworkConfig.UnicastConfig.MaxDialRetryAttemptTimes, dialCfg.DialRetryAttemptBudget)
 	}
 	// At this time the backoff budget for connection must be 0.
-	dialCfg, err := dialConfigCache.GetOrInit(peerID)
+	unicastCfg, err := unicastConfigCache.GetOrInit(peerID)
 	require.NoError(t, err)
-
-	require.Equal(t, uint64(0), dialCfg.StreamCreationRetryAttemptBudget)
-	// The dial backoff budget must remain intact, as we have not tried to create a stream yet.
-	require.Equal(t, cfg.NetworkConfig.UnicastConfig.MaxDialRetryAttemptTimes, dialCfg.DialRetryAttemptBudget)
+	require.Equal(t, uint64(0), unicastCfg.StreamCreationRetryAttemptBudget)

 	// After all the backoff budget is used up, it should stay at 0.
 	s, err := mgr.CreateStream(ctx, peerID)
 	require.Error(t, err)
 	require.Nil(t, s)

-	dialCfg, err = dialConfigCache.GetOrInit(peerID)
+	unicastCfg, err = unicastConfigCache.GetOrInit(peerID)
 	require.NoError(t, err)
-	require.Equal(t, uint64(0), dialCfg.StreamCreationRetryAttemptBudget)
-
-	// The dial backoff budget must remain intact, as we have not tried to create a stream yet.
- require.Equal(t, cfg.NetworkConfig.UnicastConfig.MaxDialRetryAttemptTimes, dialCfg.DialRetryAttemptBudget) + require.Equal(t, uint64(0), unicastCfg.StreamCreationRetryAttemptBudget) } -// TestUnicastManager_StreamFactory_Connection_SuccessfulConnection_And_Stream tests that when there is no connection, and CreateStream is successful on the first attempt for connection and stream creation, -// it updates the last successful dial time and the consecutive successful stream counter. +// TestUnicastManager_Stream_BackoffBudgetResetToDefault tests that when the stream retry attempt budget is zero, and the consecutive successful stream counter is above the reset threshold, +// it resets the stream retry attempt budget to the default value and increments the consecutive successful stream counter. func TestUnicastManager_Stream_BackoffBudgetResetToDefault(t *testing.T) { - mgr, streamFactory, connStatus, dialConfigCache := unicastManagerFixture(t) + mgr, streamFactory, unicastConfigCache := unicastManagerFixture(t) peerID := unittest.PeerIdFixture(t) cfg, err := config.DefaultConfig() require.NoError(t, err) - connStatus.On("IsConnected", peerID).Return(true, nil) // there is a connection. // mocks that it attempts to create a stream once and succeeds. streamFactory.On("NewStream", mock.Anything, peerID, mock.Anything).Return(&p2ptest.MockStream{}, nil).Once() - // update the dial config of the peer to have a zero stream backoff budget but a consecutive successful stream counter above the reset threshold. - adjustedCfg, err := dialConfigCache.Adjust(peerID, func(dialConfig unicast.DialConfig) (unicast.DialConfig, error) { - dialConfig.StreamCreationRetryAttemptBudget = 0 - dialConfig.ConsecutiveSuccessfulStream = cfg.NetworkConfig.UnicastConfig.StreamZeroRetryResetThreshold + 1 - return dialConfig, nil + // update the unicast config of the peer to have a zero stream backoff budget but a consecutive successful stream counter above the reset threshold. + adjustedCfg, err := unicastConfigCache.Adjust(peerID, func(unicastConfig unicast.Config) (unicast.Config, error) { + unicastConfig.StreamCreationRetryAttemptBudget = 0 + unicastConfig.ConsecutiveSuccessfulStream = cfg.NetworkConfig.UnicastConfig.StreamZeroRetryResetThreshold + 1 + return unicastConfig, nil }) require.NoError(t, err) require.Equal(t, uint64(0), adjustedCfg.StreamCreationRetryAttemptBudget) @@ -700,135 +488,30 @@ func TestUnicastManager_Stream_BackoffBudgetResetToDefault(t *testing.T) { require.NoError(t, err) require.NotNil(t, s) - // The dial config must be updated with the backoff budget decremented. - dialCfg, err := dialConfigCache.GetOrInit(peerID) + unicastCfg, err := unicastConfigCache.GetOrInit(peerID) require.NoError(t, err) - require.Equal(t, cfg.NetworkConfig.UnicastConfig.MaxDialRetryAttemptTimes, dialCfg.DialRetryAttemptBudget) // dial backoff budget must be intact. - require.Equal(t, - cfg.NetworkConfig.UnicastConfig.MaxStreamCreationRetryAttemptTimes, - dialCfg.StreamCreationRetryAttemptBudget) // stream backoff budget must reset to default. - require.True(t, dialCfg.LastSuccessfulDial.IsZero()) // last successful dial must be intact. + // stream backoff budget must reset to default. + require.Equal(t, cfg.NetworkConfig.UnicastConfig.MaxStreamCreationRetryAttemptTimes, unicastCfg.StreamCreationRetryAttemptBudget) // consecutive successful stream must increment by 1 (it was threshold + 1 before). 
- require.Equal(t, cfg.NetworkConfig.UnicastConfig.StreamZeroRetryResetThreshold+1+1, dialCfg.ConsecutiveSuccessfulStream) + require.Equal(t, cfg.NetworkConfig.UnicastConfig.StreamZeroRetryResetThreshold+1+1, unicastCfg.ConsecutiveSuccessfulStream) } -// TestUnicastManager_StreamFactory_Connection_SuccessfulConnection_And_Stream tests that when there is no connection, and CreateStream is successful on the first attempt for connection and stream creation, -// it updates the last successful dial time and the consecutive successful stream counter. -func TestUnicastManager_Stream_BackoffConnectionBudgetResetToDefault(t *testing.T) { - mgr, streamFactory, connStatus, dialConfigCache := unicastManagerFixture(t) - peerID := unittest.PeerIdFixture(t) - - cfg, err := config.DefaultConfig() - require.NoError(t, err) - - connStatus.On("IsConnected", peerID).Return(false, nil) // there is no connection. - streamFactory.On("Connect", mock.Anything, peer.AddrInfo{ID: peerID}).Return(nil).Once() // connect on the first attempt. - // mocks that it attempts to create a stream once and succeeds. - streamFactory.On("NewStream", mock.Anything, peerID, mock.Anything).Return(&p2ptest.MockStream{}, nil).Once() - - // update the dial config of the peer to have a zero dial backoff budget but it has not been long enough since the last successful dial. - adjustedCfg, err := dialConfigCache.Adjust(peerID, func(dialConfig unicast.DialConfig) (unicast.DialConfig, error) { - dialConfig.DialRetryAttemptBudget = 0 - dialConfig.LastSuccessfulDial = time.Now().Add(-cfg.NetworkConfig.UnicastConfig.DialZeroRetryResetThreshold) - return dialConfig, nil - }) - require.NoError(t, err) - require.Equal(t, uint64(0), adjustedCfg.DialRetryAttemptBudget) - require.True(t, - adjustedCfg.LastSuccessfulDial.Before(time.Now().Add(-cfg.NetworkConfig.UnicastConfig.DialZeroRetryResetThreshold))) // last successful dial must be within the threshold. - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - dialTime := time.Now() - s, err := mgr.CreateStream(ctx, peerID) - require.NoError(t, err) - require.NotNil(t, s) - - // The dial config must be updated with the backoff budget decremented. - dialCfg, err := dialConfigCache.GetOrInit(peerID) - require.NoError(t, err) - require.Equal(t, cfg.NetworkConfig.UnicastConfig.MaxDialRetryAttemptTimes, dialCfg.DialRetryAttemptBudget) // dial backoff budget must be reset to default. - require.Equal(t, - cfg.NetworkConfig.UnicastConfig.MaxStreamCreationRetryAttemptTimes, - dialCfg.StreamCreationRetryAttemptBudget) // stream backoff budget must be intact. - require.True(t, - dialCfg.LastSuccessfulDial.After(dialTime)) // last successful dial must be updated when the dial was successful. - require.Equal(t, - uint64(1), - dialCfg.ConsecutiveSuccessfulStream) // consecutive successful stream must be incremented by 1 (0 -> 1). -} - -// TestUnicastManager_Connection_NoBackoff_When_Budget_Is_Zero tests that when there is no connection, and the dial backoff budget is zero and last successful dial is not within the zero reset threshold -// the unicast manager does not backoff if the dial attempt fails. -func TestUnicastManager_Connection_NoBackoff_When_Budget_Is_Zero(t *testing.T) { - mgr, streamFactory, connStatus, dialConfigCache := unicastManagerFixture(t) - peerID := unittest.PeerIdFixture(t) - - cfg, err := config.DefaultConfig() - require.NoError(t, err) - - connStatus.On("IsConnected", peerID).Return(false, nil) // there is no connection. 
- streamFactory.On("Connect", mock.Anything, peer.AddrInfo{ID: peerID}).Return(fmt.Errorf("some error")).Once() // connection is tried only once and fails. - - // update the dial config of the peer to have a zero dial backoff, and the last successful dial is not within the threshold. - adjustedCfg, err := dialConfigCache.Adjust(peerID, func(dialConfig unicast.DialConfig) (unicast.DialConfig, error) { - dialConfig.DialRetryAttemptBudget = 0 // set the dial backoff budget to 0, meaning that the dial backoff budget is exhausted. - dialConfig.LastSuccessfulDial = time.Now().Add(-10 * time.Minute) // last successful dial is not within the threshold. - dialConfig.ConsecutiveSuccessfulStream = 2 // set the consecutive successful stream to 2, meaning that the last 2 stream creation attempts were successful. - return dialConfig, nil - }) - require.NoError(t, err) - require.Equal(t, uint64(0), adjustedCfg.DialRetryAttemptBudget) - require.False(t, - adjustedCfg.LastSuccessfulDial.Before(time.Now().Add(-cfg.NetworkConfig.UnicastConfig.DialZeroRetryResetThreshold))) // last successful dial must not be within the threshold. - require.Equal(t, uint64(2), adjustedCfg.ConsecutiveSuccessfulStream) - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - s, err := mgr.CreateStream(ctx, peerID) - require.Error(t, err) - require.Nil(t, s) - - dialCfg, err := dialConfigCache.GetOrInit(peerID) - require.NoError(t, err) - require.Equal(t, uint64(0), dialCfg.DialRetryAttemptBudget) // dial backoff budget must remain at 0. - require.Equal(t, - cfg.NetworkConfig.UnicastConfig.MaxStreamCreationRetryAttemptTimes, - dialCfg.StreamCreationRetryAttemptBudget) // stream backoff budget must be intact. - require.True(t, - dialCfg.LastSuccessfulDial.IsZero()) // last successful dial must be set to zero. - require.Equal(t, - uint64(0), - dialCfg.ConsecutiveSuccessfulStream) // consecutive successful stream must be set to zero. -} - -// TestUnicastManager_Stream_NoBackoff_When_Budget_Is_Zero tests that when there is a connection, and the stream backoff budget is zero and the consecutive successful stream counter is not above the -// zero rest threshold, the unicast manager does not backoff if the dial attempt fails. +// TestUnicastManager_Stream_NoBackoff_When_Budget_Is_Zero tests that when the stream backoff budget is zero and the consecutive successful stream counter is not above the +// zero rest threshold, the unicast manager does not backoff if the stream creation attempt fails. func TestUnicastManager_Stream_NoBackoff_When_Budget_Is_Zero(t *testing.T) { - mgr, streamFactory, connStatus, dialConfigCache := unicastManagerFixture(t) + mgr, streamFactory, unicastConfigCache := unicastManagerFixture(t) peerID := unittest.PeerIdFixture(t) - cfg, err := config.DefaultConfig() - require.NoError(t, err) - - connStatus.On("IsConnected", peerID).Return(true, nil) // there is a connection. // mocks that it attempts to create a stream once and fails, and does not retry. streamFactory.On("NewStream", mock.Anything, peerID, mock.Anything).Return(nil, fmt.Errorf("some error")).Once() - // update the dial config of the peer to have a zero dial backoff, and the last successful dial is not within the threshold. - lastSuccessfulDial := time.Now().Add(-10 * time.Minute) - adjustedCfg, err := dialConfigCache.Adjust(peerID, func(dialConfig unicast.DialConfig) (unicast.DialConfig, error) { - dialConfig.LastSuccessfulDial = lastSuccessfulDial // last successful dial is not within the threshold. 
- dialConfig.ConsecutiveSuccessfulStream = 2 // set the consecutive successful stream to 2, which is below the reset threshold. - dialConfig.StreamCreationRetryAttemptBudget = 0 // set the stream backoff budget to 0, meaning that the stream backoff budget is exhausted. - return dialConfig, nil + adjustedCfg, err := unicastConfigCache.Adjust(peerID, func(unicastConfig unicast.Config) (unicast.Config, error) { + unicastConfig.ConsecutiveSuccessfulStream = 2 // set the consecutive successful stream to 2, which is below the reset threshold. + unicastConfig.StreamCreationRetryAttemptBudget = 0 // set the stream backoff budget to 0, meaning that the stream backoff budget is exhausted. + return unicastConfig, nil }) require.NoError(t, err) require.Equal(t, uint64(0), adjustedCfg.StreamCreationRetryAttemptBudget) - require.False(t, - adjustedCfg.LastSuccessfulDial.Before(time.Now().Add(-cfg.NetworkConfig.UnicastConfig.DialZeroRetryResetThreshold))) // last successful dial must not be within the threshold. require.Equal(t, uint64(2), adjustedCfg.ConsecutiveSuccessfulStream) ctx, cancel := context.WithCancel(context.Background()) @@ -838,112 +521,8 @@ func TestUnicastManager_Stream_NoBackoff_When_Budget_Is_Zero(t *testing.T) { require.Error(t, err) require.Nil(t, s) - dialCfg, err := dialConfigCache.GetOrInit(peerID) - require.NoError(t, err) - require.Equal(t, cfg.NetworkConfig.UnicastConfig.MaxDialRetryAttemptTimes, dialCfg.DialRetryAttemptBudget) // dial backoff budget must remain intact. - require.Equal(t, uint64(0), dialCfg.StreamCreationRetryAttemptBudget) // stream backoff budget must remain zero. - require.Equal(t, lastSuccessfulDial, dialCfg.LastSuccessfulDial) // last successful dial must be intact. - require.Equal(t, uint64(0), dialCfg.ConsecutiveSuccessfulStream) // consecutive successful stream must be set to zero. -} - -// TestUnicastManager_Dial_In_Progress_Backoff tests that when there is a dial in progress, the unicast manager back-offs concurrent CreateStream calls. -func TestUnicastManager_Dial_In_Progress_Backoff(t *testing.T) { - streamFactory := mockp2p.NewStreamFactory(t) - streamFactory.On("SetStreamHandler", mock.Anything, mock.Anything).Return().Once() - connStatus := mockp2p.NewPeerConnections(t) - - cfg, err := config.DefaultConfig() - require.NoError(t, err) - - dialConfigCache := unicastcache.NewDialConfigCache(cfg.NetworkConfig.UnicastConfig.DialConfigCacheSize, - unittest.Logger(), - metrics.NewNoopCollector(), - func() unicast.DialConfig { - return unicast.DialConfig{ - DialRetryAttemptBudget: cfg.NetworkConfig.UnicastConfig.MaxDialRetryAttemptTimes, - StreamCreationRetryAttemptBudget: cfg.NetworkConfig.UnicastConfig.MaxStreamCreationRetryAttemptTimes, - } - }) - collector := mockmetrics.NewNetworkMetrics(t) - mgr, err := unicast.NewUnicastManager(&unicast.ManagerConfig{ - Logger: unittest.Logger(), - StreamFactory: streamFactory, - SporkId: unittest.IdentifierFixture(), - ConnStatus: connStatus, - CreateStreamBackoffDelay: 1 * time.Millisecond, // overrides the default backoff delay to 1 millisecond to speed up the test. 
- Metrics: collector, - StreamZeroRetryResetThreshold: cfg.NetworkConfig.UnicastConfig.StreamZeroRetryResetThreshold, - DialZeroRetryResetThreshold: cfg.NetworkConfig.UnicastConfig.DialZeroRetryResetThreshold, - MaxStreamCreationRetryAttemptTimes: cfg.NetworkConfig.UnicastConfig.MaxStreamCreationRetryAttemptTimes, - MaxDialRetryAttemptTimes: cfg.NetworkConfig.UnicastConfig.MaxDialRetryAttemptTimes, - DialInProgressBackoffDelay: 1 * time.Millisecond, // overrides the default backoff delay to 1 millisecond to speed up the test. - DialBackoffDelay: cfg.NetworkConfig.UnicastConfig.DialBackoffDelay, - DialConfigCacheFactory: func(func() unicast.DialConfig) unicast.DialConfigCache { - return dialConfigCache - }, - }) - require.NoError(t, err) - mgr.SetDefaultHandler(func(libp2pnet.Stream) {}) // no-op handler, we don't care about the handler for this test - - testSucceeds := make(chan struct{}) - - // indicates whether OnStreamCreationFailure called with 1 attempt (this happens when dial fails), as the dial budget is 0, - // hence dial attempt is not retried after the first attempt. - streamCreationCalledFor1 := false - // indicates whether OnStreamCreationFailure called with 4 attempts (this happens when stream creation fails due to all backoff budget - // exhausted when there is another dial in progress). The stream creation retry budget is 3, so it will be called 4 times (1 attempt + 3 retries). - streamCreationCalledFor4 := false - - blockingDial := make(chan struct{}) - collector.On("OnStreamCreationFailure", mock.Anything, mock.Anything).Run(func(args mock.Arguments) { - attempts := args.Get(1).(int) - if attempts == 1 && !streamCreationCalledFor1 { // dial attempt is not retried after the first attempt. - streamCreationCalledFor1 = true - } else if attempts == 4 && !streamCreationCalledFor4 { // stream creation attempt is retried 3 times, and exhausts the budget. - close(blockingDial) // close the blocking dial to allow the dial to fail with an error. - streamCreationCalledFor4 = true - } else { - require.Fail(t, "unexpected attempt", "expected 1 or 4 (each once), got %d (maybe twice)", attempts) - } - if streamCreationCalledFor1 && streamCreationCalledFor4 { - close(testSucceeds) - } - }).Twice() - collector.On("OnPeerDialFailure", mock.Anything, mock.Anything).Once() - - peerID := unittest.PeerIdFixture(t) - adjustedCfg, err := dialConfigCache.Adjust(peerID, func(dialConfig unicast.DialConfig) (unicast.DialConfig, error) { - dialConfig.DialRetryAttemptBudget = 0 // set the dial backoff budget to 0, meaning that the dial backoff budget is exhausted. - dialConfig.StreamCreationRetryAttemptBudget = 3 // set the stream backoff budget to 3, meaning that the stream backoff budget is exhausted after 1 attempt + 3 retries. - return dialConfig, nil - }) + unicastCfg, err := unicastConfigCache.GetOrInit(peerID) require.NoError(t, err) - require.Equal(t, uint64(0), adjustedCfg.DialRetryAttemptBudget) - require.Equal(t, uint64(3), adjustedCfg.StreamCreationRetryAttemptBudget) - - connStatus.On("IsConnected", peerID).Return(false, nil) - streamFactory.On("Connect", mock.Anything, peer.AddrInfo{ID: peerID}). - Return(func(ctx context.Context, info peer.AddrInfo) error { - <-blockingDial // blocks the call to Connect until the test unblocks it, this is to simulate a dial in progress. - return fmt.Errorf("some error") // dial fails with an error when it is unblocked. - }). 
- Once() - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - // create 2 streams concurrently, the first one will block the dial and the second one will fail after 1 + 3 backoff attempts (4 attempts). - go func() { - s, err := mgr.CreateStream(ctx, peerID) - require.Error(t, err) - require.Nil(t, s) - }() - - go func() { - s, err := mgr.CreateStream(ctx, peerID) - require.Error(t, err) - require.Nil(t, s) - }() - - unittest.RequireCloseBefore(t, testSucceeds, 1*time.Second, "test timed out") + require.Equal(t, uint64(0), unicastCfg.StreamCreationRetryAttemptBudget) // stream backoff budget must remain zero. + require.Equal(t, uint64(0), unicastCfg.ConsecutiveSuccessfulStream) // consecutive successful stream must be set to zero. } diff --git a/network/p2p/unicast/retry.png b/network/p2p/unicast/retry.png index 72aa18752a3..c86edb3ce5f 100644 Binary files a/network/p2p/unicast/retry.png and b/network/p2p/unicast/retry.png differ diff --git a/network/p2p/unicast/stream/errors.go b/network/p2p/unicast/stream/errors.go index 9c73294c52b..725b5e45247 100644 --- a/network/p2p/unicast/stream/errors.go +++ b/network/p2p/unicast/stream/errors.go @@ -20,31 +20,34 @@ func (e ErrSecurityProtocolNegotiationFailed) Error() string { return fmt.Errorf("failed to dial remote peer %s in stream factory invalid node ID: %w", p2plogging.PeerId(e.pid), e.err).Error() } -// NewSecurityProtocolNegotiationErr returns a new ErrSecurityProtocolNegotiationFailed. -func NewSecurityProtocolNegotiationErr(pid peer.ID, err error) ErrSecurityProtocolNegotiationFailed { - return ErrSecurityProtocolNegotiationFailed{pid: pid, err: err} -} - // IsErrSecurityProtocolNegotiationFailed returns whether an error is ErrSecurityProtocolNegotiationFailed. func IsErrSecurityProtocolNegotiationFailed(err error) bool { var e ErrSecurityProtocolNegotiationFailed return errors.As(err, &e) } +// NewSecurityProtocolNegotiationErr returns a new ErrSecurityProtocolNegotiationFailed. +func NewSecurityProtocolNegotiationErr(pid peer.ID, err error) ErrSecurityProtocolNegotiationFailed { + return ErrSecurityProtocolNegotiationFailed{pid: pid, err: err} +} + // ErrProtocolNotSupported indicates node is running on a different spork. type ErrProtocolNotSupported struct { - peerID peer.ID - protocolIDS []protocol.ID - err error + peerID peer.ID + protocolID protocol.ID + err error } func (e ErrProtocolNotSupported) Error() string { - return fmt.Errorf("failed to dial remote peer %s remote node is running on a different spork: %w, protocol attempted: %s", p2plogging.PeerId(e.peerID), e.err, e.protocolIDS).Error() + return fmt.Errorf("failed to dial remote peer %s remote node is running on a different spork: %w, protocol attempted: %s", + p2plogging.PeerId(e.peerID), + e.err, + e.protocolID).Error() } // NewProtocolNotSupportedErr returns a new ErrSecurityProtocolNegotiationFailed. -func NewProtocolNotSupportedErr(peerID peer.ID, protocolIDS []protocol.ID, err error) ErrProtocolNotSupported { - return ErrProtocolNotSupported{peerID: peerID, protocolIDS: protocolIDS, err: err} +func NewProtocolNotSupportedErr(peerID peer.ID, protocolID protocol.ID, err error) ErrProtocolNotSupported { + return ErrProtocolNotSupported{peerID: peerID, protocolID: protocolID, err: err} } // IsErrProtocolNotSupported returns whether an error is ErrProtocolNotSupported. 
diff --git a/network/p2p/unicast/stream/factory.go b/network/p2p/unicast/stream/factory.go index 6e3b1804b4a..8336836d3a7 100644 --- a/network/p2p/unicast/stream/factory.go +++ b/network/p2p/unicast/stream/factory.go @@ -3,6 +3,7 @@ package stream import ( "context" "errors" + "fmt" "strings" "github.com/libp2p/go-libp2p/core/host" @@ -33,44 +34,59 @@ func (l *LibP2PStreamFactory) SetStreamHandler(pid protocol.ID, handler network. l.host.SetStreamHandler(pid, handler) } -// Connect connects host to peer with peerAddrInfo. -// All errors returned from this function can be considered benign. We expect the following errors during normal operations: +// NewStream establishes a new stream with the given peer using the provided protocol.ID on the libp2p host. +// This function is a critical part of the network communication, facilitating the creation of a dedicated +// bidirectional channel (stream) between two nodes in the network. +// If there exists no connection between the two nodes, the function attempts to establish one before creating the stream. +// If there are multiple connections between the two nodes, the function selects the best one (based on libp2p internal criteria) to create the stream. +// +// Usage: +// The function is intended to be used when there is a need to initiate a direct communication stream with a peer. +// It is typically invoked in scenarios where a node wants to send a message or start a series of messages to another +// node using a specific protocol. The protocol ID is used to ensure that both nodes communicate over the same +// protocol, which defines the structure and semantics of the communication. +// +// Expected errors: +// During normal operation, the function may encounter specific expected errors, which are handled as follows: +// +// - ErrProtocolNotSupported: This error occurs when the remote node does not support the specified protocol ID, +// which may indicate that the remote node is running a different version of the software or a different spork. +// The error contains details about the peer ID and the unsupported protocol, and it is generated when the +// underlying error message indicates a protocol mismatch. This is a critical error as it signifies that the +// two nodes cannot communicate using the requested protocol, and it must be handled by either retrying with +// a different protocol ID or by performing some form of negotiation or fallback. +// // - ErrSecurityProtocolNegotiationFailed this indicates there was an issue upgrading the connection. +// // - ErrGaterDisallowedConnection this indicates the connection was disallowed by the gater. -// - There may be other unexpected errors from libp2p but they should be considered benign. -func (l *LibP2PStreamFactory) Connect(ctx context.Context, peerAddrInfo peer.AddrInfo) error { - // libp2p internally uses swarm dial - https://github.com/libp2p/go-libp2p-swarm/blob/master/swarm_dial.go - // to connect to a peer. Swarm dial adds a back off each time it fails connecting to a peer. While this is - // the desired behaviour for pub-sub (1-k style of communication) for 1-1 style we want to retry the connection - // immediately without backing off and fail-fast. 
- // Hence, explicitly cancel the dial back off (if any) and try connecting again - if swm, ok := l.host.Network().(*swarm.Swarm); ok { - swm.Backoff().Clear(peerAddrInfo.ID) - } - - err := l.host.Connect(ctx, peerAddrInfo) +// +// - Any other error returned by the libp2p host: This error indicates that the stream creation failed due to +// some unexpected error, which may be caused by a variety of reasons. This is NOT a critical error, and it +// can be handled by retrying the stream creation or by performing some other action. Crashing node upon this +// error is NOT recommended. +// +// Arguments: +// - ctx: A context.Context that governs the lifetime of the stream creation. It can be used to cancel the +// operation or to set deadlines. +// - p: The peer.ID of the target node with which the stream is to be established. +// - pid: The protocol.ID that specifies the communication protocol to be used for the stream. +// +// Returns: +// - network.Stream: The successfully created stream, ready for reading and writing, or nil if an error occurs. +// - error: An error encountered during stream creation, wrapped in a contextually appropriate error type when necessary, +// or nil if the operation is successful. +func (l *LibP2PStreamFactory) NewStream(ctx context.Context, p peer.ID, pid protocol.ID) (network.Stream, error) { + s, err := l.host.NewStream(ctx, p, pid) switch { case err == nil: - return nil + return s, nil + case strings.Contains(err.Error(), protocolNotSupportedStr): + return nil, NewProtocolNotSupportedErr(p, pid, err) case strings.Contains(err.Error(), protocolNegotiationFailedStr): - return NewSecurityProtocolNegotiationErr(peerAddrInfo.ID, err) + return nil, NewSecurityProtocolNegotiationErr(p, err) case errors.Is(err, swarm.ErrGaterDisallowedConnection): - return NewGaterDisallowedConnectionErr(err) + return nil, NewGaterDisallowedConnectionErr(err) default: - return err - } -} - -// NewStream creates a new stream on the libp2p host. -// Expected errors during normal operations: -// - ErrProtocolNotSupported this indicates remote node is running on a different spork. -func (l *LibP2PStreamFactory) NewStream(ctx context.Context, p peer.ID, pids ...protocol.ID) (network.Stream, error) { - s, err := l.host.NewStream(ctx, p, pids...) - if err != nil { - if strings.Contains(err.Error(), protocolNotSupportedStr) { - return nil, NewProtocolNotSupportedErr(p, pids, err) - } - return nil, err + return nil, fmt.Errorf("failed to create stream: %w", err) } - return s, err } diff --git a/network/validator/target_validator.go b/network/validator/target_validator.go index 5a9b1ab73f9..d02901b166e 100644 --- a/network/validator/target_validator.go +++ b/network/validator/target_validator.go @@ -35,7 +35,8 @@ func (tv *TargetValidator) Validate(msg network.IncomingMessageScope) bool { } } tv.log.Debug(). - Hex("target", logging.ID(tv.target)). + Hex("message_target_id", logging.ID(tv.target)). + Hex("local_node_id", logging.ID(tv.target)). Hex("event_id", msg.EventID()). Msg("message not intended for target") return false diff --git a/state/protocol/events.go b/state/protocol/events.go index e97c4f7c84c..c8dcf460159 100644 --- a/state/protocol/events.go +++ b/state/protocol/events.go @@ -56,15 +56,14 @@ type Consumer interface { // the current epoch. This is equivalent to the end of the epoch staking // phase for the current epoch. // - // Referencing the diagram below, the event is emitted when block c is incorporated. 
- // The block parameter is the first block of the epoch setup phase (block c). + // Referencing the diagram below, the event is emitted when block b is finalized. + // The block parameter is the first block of the epoch setup phase (block b). // // |<-- Epoch N ------------------------------------------------->| // |<-- StakingPhase -->|<-- SetupPhase -->|<-- CommittedPhase -->| // ^--- block A - this block's execution result contains an EpochSetup event - // ^--- block b - contains seal for block A - // ^--- block c - contains qc for block b, first block of Setup phase - // ^--- block d - finalizes block c, triggers EpochSetupPhaseStarted event + // ^--- block b - contains seal for block A, first block of Setup phase + // ^--- block c - finalizes block b, triggers EpochSetupPhaseStarted event // // NOTE: Only called once the phase transition has been finalized. EpochSetupPhaseStarted(currentEpochCounter uint64, first *flow.Header) @@ -73,16 +72,14 @@ type Consumer interface { // for the current epoch. This is equivalent to the end of the epoch setup // phase for the current epoch. // - // Referencing the diagram below, the event is emitted when block f is received. - // The block parameter is the first block of the epoch committed phase (block f). + // Referencing the diagram below, the event is emitted when block e is finalized. + // The block parameter is the first block of the epoch committed phase (block e). // // |<-- Epoch N ------------------------------------------------->| // |<-- StakingPhase -->|<-- SetupPhase -->|<-- CommittedPhase -->| // ^--- block D - this block's execution result contains an EpochCommit event - // ^--- block e - contains seal for block D - // ^--- block f - contains qc for block e, first block of Committed phase - // ^--- block g - finalizes block f, triggers EpochCommittedPhaseStarted event - /// + // ^--- block e - contains seal for block D, first block of Committed phase + // ^--- block f - finalizes block e, triggers EpochCommittedPhaseStarted event // // NOTE: Only called once the phase transition has been finalized. EpochCommittedPhaseStarted(currentEpochCounter uint64, first *flow.Header) diff --git a/utils/unittest/execution_state.go b/utils/unittest/execution_state.go index 8c993b28c16..25d180ae80b 100644 --- a/utils/unittest/execution_state.go +++ b/utils/unittest/execution_state.go @@ -90,8 +90,5 @@ func genesisCommitHexByChainID(chainID flow.ChainID) string { if chainID == flow.Testnet { return "7192a942310f70b21579f3e3bbf6f381a2c350e23130b465aa6a25f4c7612d87" } - if chainID == flow.Sandboxnet { - return "e1c08b17f9e5896f03fe28dd37ca396c19b26628161506924fbf785834646ea1" - } return "93fc7a1d086c25794822a034288f33580bd1eca485ffab1a36691590518606fa" } diff --git a/utils/unittest/fixtures.go b/utils/unittest/fixtures.go index 34088313df5..0483e6ce406 100644 --- a/utils/unittest/fixtures.go +++ b/utils/unittest/fixtures.go @@ -2705,14 +2705,14 @@ func P2PRPCFixture(opts ...RPCFixtureOpt) *pubsub.RPC { return rpc } -func WithTopic(topic string) func(*pubsub_pb.Message) { +func WithFrom(pid peer.ID) func(*pubsub_pb.Message) { return func(msg *pubsub_pb.Message) { - msg.Topic = &topic + msg.From = []byte(pid) } } // GossipSubMessageFixture returns a gossip sub message fixture for the specified topic. 
-func GossipSubMessageFixture(t *testing.T, s string, opts ...func(*pubsub_pb.Message)) *pubsub_pb.Message { +func GossipSubMessageFixture(s string, opts ...func(*pubsub_pb.Message)) *pubsub_pb.Message { pb := &pubsub_pb.Message{ From: RandomBytes(32), Data: RandomBytes(32), @@ -2730,10 +2730,10 @@ func GossipSubMessageFixture(t *testing.T, s string, opts ...func(*pubsub_pb.Mes } // GossipSubMessageFixtures returns a list of gossipsub message fixtures. -func GossipSubMessageFixtures(t *testing.T, n int, topic string, opts ...func(*pubsub_pb.Message)) []*pubsub_pb.Message { +func GossipSubMessageFixtures(n int, topic string, opts ...func(*pubsub_pb.Message)) []*pubsub_pb.Message { msgs := make([]*pubsub_pb.Message, n) for i := 0; i < n; i++ { - msgs[i] = GossipSubMessageFixture(t, topic, opts...) + msgs[i] = GossipSubMessageFixture(topic, opts...) } return msgs } @@ -2752,3 +2752,25 @@ func LibP2PResourceLimitOverrideFixture() p2pconf.ResourceManagerOverrideLimit { Memory: rand.Intn(1000), } } + +func RegisterEntryFixture() flow.RegisterEntry { + val := make([]byte, 4) + _, _ = crand.Read(val) + return flow.RegisterEntry{ + Key: flow.RegisterID{ + Owner: "owner", + Key: "key1", + }, + Value: val, + } +} + +func MakeOwnerReg(key string, value string) flow.RegisterEntry { + return flow.RegisterEntry{ + Key: flow.RegisterID{ + Owner: "owner", + Key: key, + }, + Value: []byte(value), + } +} diff --git a/utils/unittest/logging.go b/utils/unittest/logging.go index ee9dd762b77..a200a61525e 100644 --- a/utils/unittest/logging.go +++ b/utils/unittest/logging.go @@ -30,7 +30,7 @@ func Logger() zerolog.Logger { writer = os.Stderr } - return LoggerWithWriterAndLevel(writer, zerolog.DebugLevel) + return LoggerWithWriterAndLevel(writer, zerolog.TraceLevel) } func LoggerWithWriterAndLevel(writer io.Writer, level zerolog.Level) zerolog.Logger { diff --git a/utils/unittest/unittest.go b/utils/unittest/unittest.go index 0ad7a8736e4..9fba23ccd69 100644 --- a/utils/unittest/unittest.go +++ b/utils/unittest/unittest.go @@ -1,7 +1,6 @@ package unittest import ( - crand "crypto/rand" "encoding/json" "math" "math/rand" @@ -18,15 +17,16 @@ import ( "github.com/cockroachdb/pebble" "github.com/dgraph-io/badger/v2" "github.com/libp2p/go-libp2p/core/peer" - "github.com/multiformats/go-multihash" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/onflow/flow-go/crypto" "github.com/onflow/flow-go/model/flow" "github.com/onflow/flow-go/module" "github.com/onflow/flow-go/module/util" "github.com/onflow/flow-go/network" cborcodec "github.com/onflow/flow-go/network/codec/cbor" + "github.com/onflow/flow-go/network/p2p/keyutils" "github.com/onflow/flow-go/network/topology" ) @@ -453,15 +453,43 @@ func GenerateRandomStringWithLen(commentLen uint) string { return string(bytes) } -// PeerIdFixture returns a random peer ID for testing. -// peer ID is the identifier of a node on the libp2p network. -func PeerIdFixture(t testing.TB) peer.ID { - buf := make([]byte, 16) - n, err := crand.Read(buf) - require.NoError(t, err) - require.Equal(t, 16, n) - h, err := multihash.Sum(buf, multihash.SHA2_256, -1) - require.NoError(t, err) +// PeerIdFixture creates a random and unique peer ID (libp2p node ID). 
+func PeerIdFixture(tb testing.TB) peer.ID { + peerID, err := peerIDFixture() + require.NoError(tb, err) + return peerID +} + +func peerIDFixture() (peer.ID, error) { + key, err := generateNetworkingKey(IdentifierFixture()) + if err != nil { + return "", err + } + pubKey, err := keyutils.LibP2PPublicKeyFromFlow(key.PublicKey()) + if err != nil { + return "", err + } + + peerID, err := peer.IDFromPublicKey(pubKey) + if err != nil { + return "", err + } + + return peerID, nil +} - return peer.ID(h) +// generateNetworkingKey generates a Flow ECDSA key using the given seed +func generateNetworkingKey(s flow.Identifier) (crypto.PrivateKey, error) { + seed := make([]byte, crypto.KeyGenSeedMinLen) + copy(seed, s[:]) + return crypto.GeneratePrivateKey(crypto.ECDSASecp256k1, seed) +} + +// PeerIdFixtures creates random and unique peer IDs (libp2p node IDs). +func PeerIdFixtures(t *testing.T, n int) []peer.ID { + peerIDs := make([]peer.ID, n) + for i := 0; i < n; i++ { + peerIDs[i] = PeerIdFixture(t) + } + return peerIDs }
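As a usage illustration, the new peer ID fixtures can be called directly from any test. The sketch below is hypothetical; only PeerIdFixtures and its signature come from this change, and the uniqueness assertion simply exercises the behavior documented on the fixture.

// Hypothetical usage example of the new PeerIdFixtures helper.
package unittest_test

import (
	"testing"

	"github.com/libp2p/go-libp2p/core/peer"
	"github.com/stretchr/testify/require"

	"github.com/onflow/flow-go/utils/unittest"
)

// TestPeerIdFixturesAreUnique shows how a test consumes the fixture and checks
// that the generated libp2p peer IDs are pairwise distinct.
func TestPeerIdFixturesAreUnique(t *testing.T) {
	peerIDs := unittest.PeerIdFixtures(t, 5)
	seen := make(map[peer.ID]struct{}, len(peerIDs))
	for _, pid := range peerIDs {
		_, duplicate := seen[pid]
		require.False(t, duplicate, "fixture peer IDs are expected to be unique")
		seen[pid] = struct{}{}
	}
}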