diff --git a/README.md b/README.md index 69f43236391..22b612af93f 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,7 @@ [![](https://img.shields.io/badge/project-MultiversX%20Mainnet-blue.svg)](https://explorer.multiversx.com/) [![Go Report Card](https://goreportcard.com/badge/github.com/multiversx/mx-chain-go)](https://goreportcard.com/report/github.com/multiversx/mx-chain-go) [![codecov](https://codecov.io/gh/multiversx/mx-chain-go/branch/master/graph/badge.svg?token=MYS5EDASOJ)](https://codecov.io/gh/multiversx/mx-chain-go) +[![Contributors](https://img.shields.io/github/contributors/multiversx/mx-chain-go)](https://github.com/multiversx/mx-chain-go/graphs/contributors) # mx-chain-go diff --git a/cmd/keygenerator/converter/pidPubkeyConverter.go b/cmd/keygenerator/converter/pidPubkeyConverter.go index cacda3d1bcd..eba89f5d9c0 100644 --- a/cmd/keygenerator/converter/pidPubkeyConverter.go +++ b/cmd/keygenerator/converter/pidPubkeyConverter.go @@ -7,6 +7,7 @@ import ( crypto "github.com/multiversx/mx-chain-crypto-go" "github.com/multiversx/mx-chain-crypto-go/signing" "github.com/multiversx/mx-chain-crypto-go/signing/secp256k1" + "github.com/multiversx/mx-chain-go/p2p" "github.com/multiversx/mx-chain-go/p2p/factory" logger "github.com/multiversx/mx-chain-logger-go" ) @@ -14,13 +15,15 @@ import ( var log = logger.GetOrCreate("cmd/keygenerator/converter") type pidPubkeyConverter struct { - keyGen crypto.KeyGenerator + keyGen crypto.KeyGenerator + p2PKeyConverter p2p.P2PKeyConverter } // NewPidPubkeyConverter creates a new instance of a public key converter that can handle conversions involving core.PeerID string representations func NewPidPubkeyConverter() *pidPubkeyConverter { return &pidPubkeyConverter{ - keyGen: signing.NewKeyGenerator(secp256k1.NewSecp256k1()), + keyGen: signing.NewKeyGenerator(secp256k1.NewSecp256k1()), + p2PKeyConverter: factory.NewP2PKeyConverter(), } } @@ -51,7 +54,7 @@ func (converter *pidPubkeyConverter) encode(pkBytes []byte) (string, error) { return "", err } - pid, err := factory.ConvertPublicKeyToPeerID(pk) + pid, err := converter.p2PKeyConverter.ConvertPublicKeyToPeerID(pk) if err != nil { return "", err } diff --git a/cmd/node/config/config.toml b/cmd/node/config/config.toml index c26aaa60d96..23667a34af3 100644 --- a/cmd/node/config/config.toml +++ b/cmd/node/config/config.toml @@ -668,7 +668,6 @@ [StateTriesConfig] CheckpointRoundsModulus = 100 CheckpointsEnabled = false - SnapshotsEnabled = true AccountsStatePruningEnabled = false PeerStatePruningEnabled = true MaxStateTrieLevelInMemory = 5 @@ -933,6 +932,7 @@ HardforkTimeBetweenSendsInSec = 60 # 1min # time between hardfork messages TimeBetweenConnectionsMetricsUpdateInSec = 30 # 30sec # time between consecutive connections metrics updates TimeToReadDirectConnectionsInSec = 15 # 15sec # time between consecutive peer shard mapper updates with direct connections + PeerAuthenticationTimeBetweenChecksInSec = 6 # 6sec [HeartbeatV2.HeartbeatPool] Name = "HeartbeatPool" Capacity = 50000 diff --git a/cmd/node/config/enableEpochs.toml b/cmd/node/config/enableEpochs.toml index fe102090767..7bd4ffbcd41 100644 --- a/cmd/node/config/enableEpochs.toml +++ b/cmd/node/config/enableEpochs.toml @@ -239,6 +239,9 @@ # AlwaysSaveTokenMetaDataEnableEpoch represents the epoch when the token metadata is always saved AlwaysSaveTokenMetaDataEnableEpoch = 1 + # RuntimeCodeSizeFixEnableEpoch represents the epoch when the code size fix in the VM is enabled + RuntimeCodeSizeFixEnableEpoch = 2 + # BLSMultiSignerEnableEpoch 
represents the activation epoch for different types of BLS multi-signers BLSMultiSignerEnableEpoch = [ { EnableEpoch = 0, Type = "no-KOSK"}, diff --git a/cmd/node/config/prefs.toml b/cmd/node/config/prefs.toml index e82064b9029..d6339fca6ab 100644 --- a/cmd/node/config/prefs.toml +++ b/cmd/node/config/prefs.toml @@ -4,10 +4,12 @@ # if "disabled" is provided then the node will start in the corresponding shard for its public key or 0 otherwise DestinationShardAsObserver = "disabled" - # NodeDisplayName represents the friendly name a user can pick for his node in the status monitor + # NodeDisplayName represents the friendly name a user can pick for his node in the status monitor when the node does not run in multikey mode + # In multikey mode, all bls keys not mentioned in NamedIdentity section will use this one as default NodeDisplayName = "" - # Identity represents the keybase's identity + # Identity represents the keybase's identity when the node does not run in multikey mode + # In multikey mode, all bls keys not mentioned in NamedIdentity section will use this one as default Identity = "" # RedundancyLevel represents the level of redundancy used by the node (-1 = disabled, 0 = main instance (default), @@ -47,3 +49,16 @@ # { File = "config.toml", Path = "MiniBlocksStorage.Cache.Name", Value = "MiniBlocksStorage" }, # { File = "external.toml", Path = "ElasticSearchConnector.Enabled", Value = "true" } #] + +# NamedIdentity represents an identity that runs nodes on the multikey +# There can be multiple identities set on the same node, each one of them having different bls keys, just by duplicating the NamedIdentity +[[NamedIdentity]] + # Identity represents the keybase identity for the current NamedIdentity + Identity = "" + # NodeName represents the name that will be given to the names of the current identity + NodeName = "" + # BLSKeys represents the BLS keys assigned to the current NamedIdentity + BLSKeys = [ + "", + "" + ] diff --git a/cmd/node/flags.go b/cmd/node/flags.go index fb50a34c337..0cb32cb937e 100644 --- a/cmd/node/flags.go +++ b/cmd/node/flags.go @@ -208,6 +208,13 @@ var ( Usage: "The `filepath` for the PEM file which contains the secret keys for the validator key.", Value: "./config/validatorKey.pem", } + // allValidatorKeysPemFile defines a flag for the path to the file that hold all validator keys used in block signing + // managed by the current node + allValidatorKeysPemFile = cli.StringFlag{ + Name: "all-validator-keys-pem-file", + Usage: "The `filepath` for the PEM file which contains all the secret keys managed by the current node.", + Value: "./config/allValidatorsKeys.pem", + } // logLevel defines the logger level logLevel = cli.StringFlag{ @@ -365,6 +372,12 @@ var ( Value: "./config/p2pKey.pem", } + // snapshotsEnabled is used to enable snapshots, if it is not set it defaults to true, it will be set to false if it is set specifically + snapshotsEnabled = cli.BoolTFlag{ + Name: "snapshots-enabled", + Usage: "Boolean option for enabling state snapshots. 
If it is not set it defaults to true, it will be set to false if it is set specifically as --snapshots-enabled=false", + } + // operationMode defines the flag for specifying how configs should be altered depending on the node's intent operationMode = cli.StringFlag{ Name: "operation-mode", @@ -391,6 +404,7 @@ func getFlags() []cli.Flag { gasScheduleConfigurationDirectory, validatorKeyIndex, validatorKeyPemFile, + allValidatorKeysPemFile, port, profileMode, useHealthService, @@ -425,6 +439,7 @@ func getFlags() []cli.Flag { serializeSnapshots, noKey, p2pKeyPemFile, + snapshotsEnabled, dbDirectory, logsDirectory, operationMode, @@ -455,6 +470,7 @@ func getFlagsConfig(ctx *cli.Context, log logger.Logger) *config.ContextFlagsCon flagsConfig.DisableConsensusWatchdog = ctx.GlobalBool(disableConsensusWatchdog.Name) flagsConfig.SerializeSnapshots = ctx.GlobalBool(serializeSnapshots.Name) flagsConfig.NoKeyProvided = ctx.GlobalBool(noKey.Name) + flagsConfig.SnapshotsEnabled = ctx.GlobalBool(snapshotsEnabled.Name) flagsConfig.OperationMode = ctx.GlobalString(operationMode.Name) return flagsConfig @@ -467,6 +483,7 @@ func applyFlags(ctx *cli.Context, cfgs *config.Configs, flagsConfig *config.Cont cfgs.ConfigurationPathsHolder.GasScheduleDirectoryName = ctx.GlobalString(gasScheduleConfigurationDirectory.Name) cfgs.ConfigurationPathsHolder.SmartContracts = ctx.GlobalString(smartContractsFile.Name) cfgs.ConfigurationPathsHolder.ValidatorKey = ctx.GlobalString(validatorKeyPemFile.Name) + cfgs.ConfigurationPathsHolder.AllValidatorKeys = ctx.GlobalString(allValidatorKeysPemFile.Name) cfgs.ConfigurationPathsHolder.P2pKey = ctx.GlobalString(p2pKeyPemFile.Name) if ctx.IsSet(startInEpoch.Name) { @@ -627,12 +644,12 @@ func processDbLookupExtensionMode(log logger.Logger, configs *config.Configs) { func processLiteObserverMode(log logger.Logger, configs *config.Configs) { configs.GeneralConfig.StoragePruning.ObserverCleanOldEpochsData = true - configs.GeneralConfig.StateTriesConfig.SnapshotsEnabled = false + configs.FlagsConfig.SnapshotsEnabled = false configs.GeneralConfig.StateTriesConfig.AccountsStatePruningEnabled = true log.Warn("the node is in snapshotless observer mode! 
Will auto-set some config values", "StoragePruning.ObserverCleanOldEpochsData", configs.GeneralConfig.StoragePruning.ObserverCleanOldEpochsData, - "StateTriesConfig.SnapshotsEnabled", configs.GeneralConfig.StateTriesConfig.SnapshotsEnabled, + "FlagsConfig.SnapshotsEnabled", configs.FlagsConfig.SnapshotsEnabled, "StateTriesConfig.AccountsStatePruningEnabled", configs.GeneralConfig.StateTriesConfig.AccountsStatePruningEnabled, ) } diff --git a/cmd/node/main.go b/cmd/node/main.go index 2609eb87def..f89702cb3c3 100644 --- a/cmd/node/main.go +++ b/cmd/node/main.go @@ -63,7 +63,8 @@ func main() { app.Name = "MultiversX Node CLI App" machineID := core.GetAnonymizedMachineID(app.Name) - app.Version = fmt.Sprintf("%s/%s/%s-%s/%s", appVersion, runtime.Version(), runtime.GOOS, runtime.GOARCH, machineID) + baseVersion := fmt.Sprintf("%s/%s/%s-%s", appVersion, runtime.Version(), runtime.GOOS, runtime.GOARCH) + app.Version = fmt.Sprintf("%s/%s", baseVersion, machineID) app.Usage = "This is the entry point for starting a new MultiversX node - the app will start after the genesis timestamp" app.Flags = getFlags() app.Authors = []cli.Author{ @@ -74,7 +75,7 @@ func main() { } app.Action = func(c *cli.Context) error { - return startNodeRunner(c, log, app.Version) + return startNodeRunner(c, log, baseVersion, app.Version) } err := app.Run(os.Args) @@ -84,7 +85,7 @@ func main() { } } -func startNodeRunner(c *cli.Context, log logger.Logger, version string) error { +func startNodeRunner(c *cli.Context, log logger.Logger, baseVersion string, version string) error { flagsConfig := getFlagsConfig(c, log) fileLogging, errLogger := attachFileLogger(log, flagsConfig) @@ -125,6 +126,7 @@ func startNodeRunner(c *cli.Context, log logger.Logger, version string) error { log.Debug("initialized memory ballast object", "size", core.ConvertBytes(uint64(len(memoryBallastObject)))) } + cfgs.FlagsConfig.BaseVersion = baseVersion cfgs.FlagsConfig.Version = version nodeRunner, errRunner := node.NewNodeRunner(cfgs) diff --git a/common/constants.go b/common/constants.go index 016b6003726..dcf620095ac 100644 --- a/common/constants.go +++ b/common/constants.go @@ -301,7 +301,7 @@ const MetricRedundancyIsMainActive = "erd_redundancy_is_main_active" // MetricValueNA represents the value to be used when a metric is not available/applicable const MetricValueNA = "N/A" -//MetricProcessedProposedBlock is the metric that specify the percent of the block subround used for header and body +// MetricProcessedProposedBlock is the metric that specify the percent of the block subround used for header and body // processing (0 meaning that the block was processed in no-time and 100 meaning that the block processing used all the // subround spare duration) const MetricProcessedProposedBlock = "erd_consensus_processed_proposed_block" diff --git a/common/enablers/enableEpochsHandler.go b/common/enablers/enableEpochsHandler.go index 85a9e115056..b9017bdcd4e 100644 --- a/common/enablers/enableEpochsHandler.go +++ b/common/enablers/enableEpochsHandler.go @@ -114,6 +114,7 @@ func (handler *enableEpochsHandler) EpochConfirmed(epoch uint32, _ uint64) { handler.setFlagValue(epoch >= handler.enableEpochsConfig.FixAsyncCallBackArgsListEnableEpoch, handler.fixAsyncCallBackArgsList, "fixAsyncCallBackArgsList") handler.setFlagValue(epoch >= handler.enableEpochsConfig.FixOldTokenLiquidityEnableEpoch, handler.fixOldTokenLiquidity, "fixOldTokenLiquidity") handler.setFlagValue(epoch >= handler.enableEpochsConfig.RuntimeMemStoreLimitEnableEpoch, 
handler.runtimeMemStoreLimitFlag, "runtimeMemStoreLimitFlag") + handler.setFlagValue(epoch >= handler.enableEpochsConfig.RuntimeCodeSizeFixEnableEpoch, handler.runtimeCodeSizeFixFlag, "runtimeCodeSizeFixFlag") handler.setFlagValue(epoch >= handler.enableEpochsConfig.MaxBlockchainHookCountersEnableEpoch, handler.maxBlockchainHookCountersFlag, "maxBlockchainHookCountersFlag") handler.setFlagValue(epoch >= handler.enableEpochsConfig.WipeSingleNFTLiquidityDecreaseEnableEpoch, handler.wipeSingleNFTLiquidityDecreaseFlag, "wipeSingleNFTLiquidityDecreaseFlag") handler.setFlagValue(epoch >= handler.enableEpochsConfig.AlwaysSaveTokenMetaDataEnableEpoch, handler.alwaysSaveTokenMetaDataFlag, "alwaysSaveTokenMetaDataFlag") diff --git a/common/enablers/enableEpochsHandler_test.go b/common/enablers/enableEpochsHandler_test.go index c4721025939..661d684f010 100644 --- a/common/enablers/enableEpochsHandler_test.go +++ b/common/enablers/enableEpochsHandler_test.go @@ -90,6 +90,7 @@ func createEnableEpochsConfig() config.EnableEpochs { MaxBlockchainHookCountersEnableEpoch: 74, WipeSingleNFTLiquidityDecreaseEnableEpoch: 75, AlwaysSaveTokenMetaDataEnableEpoch: 76, + RuntimeCodeSizeFixEnableEpoch: 77, } } @@ -128,7 +129,7 @@ func TestNewEnableEpochsHandler_EpochConfirmed(t *testing.T) { handler, _ := NewEnableEpochsHandler(cfg, &epochNotifier.EpochNotifierStub{}) require.False(t, check.IfNil(handler)) - handler.EpochConfirmed(76, 0) + handler.EpochConfirmed(77, 0) assert.Equal(t, cfg.BlockGasAndFeesReCheckEnableEpoch, handler.BlockGasAndFeesReCheckEnableEpoch()) assert.True(t, handler.IsSCDeployFlagEnabled()) @@ -211,11 +212,12 @@ func TestNewEnableEpochsHandler_EpochConfirmed(t *testing.T) { assert.True(t, handler.IsRuntimeMemStoreLimitEnabled()) assert.True(t, handler.IsMaxBlockchainHookCountersFlagEnabled()) assert.True(t, handler.IsAlwaysSaveTokenMetaDataEnabled()) + assert.True(t, handler.IsRuntimeCodeSizeFixEnabled()) }) t.Run("flags with == condition should be set, along with all >=", func(t *testing.T) { t.Parallel() - epoch := uint32(77) + epoch := uint32(78) cfg := createEnableEpochsConfig() cfg.StakingV2EnableEpoch = epoch cfg.ESDTEnableEpoch = epoch @@ -310,6 +312,7 @@ func TestNewEnableEpochsHandler_EpochConfirmed(t *testing.T) { assert.True(t, handler.IsMaxBlockchainHookCountersFlagEnabled()) assert.True(t, handler.IsWipeSingleNFTLiquidityDecreaseEnabled()) assert.True(t, handler.IsAlwaysSaveTokenMetaDataEnabled()) + assert.True(t, handler.IsRuntimeCodeSizeFixEnabled()) }) t.Run("flags with < should be set", func(t *testing.T) { t.Parallel() @@ -404,5 +407,6 @@ func TestNewEnableEpochsHandler_EpochConfirmed(t *testing.T) { assert.False(t, handler.IsMaxBlockchainHookCountersFlagEnabled()) assert.False(t, handler.IsWipeSingleNFTLiquidityDecreaseEnabled()) assert.False(t, handler.IsAlwaysSaveTokenMetaDataEnabled()) + assert.False(t, handler.IsRuntimeCodeSizeFixEnabled()) }) } diff --git a/common/enablers/epochFlags.go b/common/enablers/epochFlags.go index eaa720cafe7..fe11469f4bb 100644 --- a/common/enablers/epochFlags.go +++ b/common/enablers/epochFlags.go @@ -86,6 +86,7 @@ type epochFlagsHolder struct { fixAsyncCallBackArgsList *atomic.Flag fixOldTokenLiquidity *atomic.Flag runtimeMemStoreLimitFlag *atomic.Flag + runtimeCodeSizeFixFlag *atomic.Flag maxBlockchainHookCountersFlag *atomic.Flag wipeSingleNFTLiquidityDecreaseFlag *atomic.Flag alwaysSaveTokenMetaDataFlag *atomic.Flag @@ -174,6 +175,7 @@ func newEpochFlagsHolder() *epochFlagsHolder { fixAsyncCallBackArgsList: &atomic.Flag{}, 
fixOldTokenLiquidity: &atomic.Flag{}, runtimeMemStoreLimitFlag: &atomic.Flag{}, + runtimeCodeSizeFixFlag: &atomic.Flag{}, maxBlockchainHookCountersFlag: &atomic.Flag{}, wipeSingleNFTLiquidityDecreaseFlag: &atomic.Flag{}, alwaysSaveTokenMetaDataFlag: &atomic.Flag{}, @@ -638,6 +640,11 @@ func (holder *epochFlagsHolder) IsRuntimeMemStoreLimitEnabled() bool { return holder.runtimeMemStoreLimitFlag.IsSet() } +// IsRuntimeCodeSizeFixEnabled returns true if runtimeCodeSizeFixFlag is enabled +func (holder *epochFlagsHolder) IsRuntimeCodeSizeFixEnabled() bool { + return holder.runtimeCodeSizeFixFlag.IsSet() +} + // IsMaxBlockchainHookCountersFlagEnabled returns true if maxBlockchainHookCountersFlagEnabled is enabled func (holder *epochFlagsHolder) IsMaxBlockchainHookCountersFlagEnabled() bool { return holder.maxBlockchainHookCountersFlag.IsSet() diff --git a/common/interface.go b/common/interface.go index 597359d4209..33137234db7 100644 --- a/common/interface.go +++ b/common/interface.go @@ -7,18 +7,10 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/block" + crypto "github.com/multiversx/mx-chain-crypto-go" "github.com/multiversx/mx-chain-go/trie/statistics" ) -// NumNodesDTO represents the DTO structure that will hold the number of nodes split by category and other -// trie structure relevant data such as maximum number of trie levels including the roothash node and all leaves -type NumNodesDTO struct { - Leaves int - Extensions int - Branches int - MaxLevel int -} - // TrieIteratorChannels defines the channels that are being used when iterating the trie nodes type TrieIteratorChannels struct { LeavesChan chan core.KeyValueHolder @@ -40,7 +32,6 @@ type Trie interface { GetOldRoot() []byte GetSerializedNodes([]byte, uint64) ([][]byte, uint64, error) GetSerializedNode([]byte) ([]byte, error) - GetNumNodes() NumNodesDTO GetAllLeavesOnChannel(allLeavesChan *TrieIteratorChannels, ctx context.Context, rootHash []byte, keyBuilder KeyBuilder) error GetAllHashes() ([][]byte, error) GetProof(key []byte) ([][]byte, []byte, error) @@ -333,9 +324,31 @@ type EnableEpochsHandler interface { IsFixAsyncCallBackArgsListFlagEnabled() bool IsFixOldTokenLiquidityEnabled() bool IsRuntimeMemStoreLimitEnabled() bool + IsRuntimeCodeSizeFixEnabled() bool IsMaxBlockchainHookCountersFlagEnabled() bool IsWipeSingleNFTLiquidityDecreaseEnabled() bool IsAlwaysSaveTokenMetaDataEnabled() bool IsInterfaceNil() bool } + +// ManagedPeersHolder defines the operations of an entity that holds managed identities for a node +type ManagedPeersHolder interface { + AddManagedPeer(privateKeyBytes []byte) error + GetPrivateKey(pkBytes []byte) (crypto.PrivateKey, error) + GetP2PIdentity(pkBytes []byte) ([]byte, core.PeerID, error) + GetMachineID(pkBytes []byte) (string, error) + GetNameAndIdentity(pkBytes []byte) (string, string, error) + IncrementRoundsWithoutReceivedMessages(pkBytes []byte) + ResetRoundsWithoutReceivedMessages(pkBytes []byte) + GetManagedKeysByCurrentNode() map[string]crypto.PrivateKey + IsKeyManagedByCurrentNode(pkBytes []byte) bool + IsKeyRegistered(pkBytes []byte) bool + IsPidManagedByCurrentNode(pid core.PeerID) bool + IsKeyValidator(pkBytes []byte) bool + SetValidatorState(pkBytes []byte, state bool) + GetNextPeerAuthenticationTime(pkBytes []byte) (time.Time, error) + SetNextPeerAuthenticationTime(pkBytes []byte, nextTime time.Time) + IsMultiKeyMode() bool + IsInterfaceNil() bool +} diff --git a/config/config.go 
b/config/config.go index fe7aa82c157..6dfc1170a2f 100644 --- a/config/config.go +++ b/config/config.go @@ -128,6 +128,7 @@ type HeartbeatV2Config struct { HardforkTimeBetweenSendsInSec int64 TimeBetweenConnectionsMetricsUpdateInSec int64 TimeToReadDirectConnectionsInSec int64 + PeerAuthenticationTimeBetweenChecksInSec int64 } // Config will hold the entire application configuration parameters @@ -290,7 +291,6 @@ type FacadeConfig struct { type StateTriesConfig struct { CheckpointRoundsModulus uint CheckpointsEnabled bool - SnapshotsEnabled bool AccountsStatePruningEnabled bool PeerStatePruningEnabled bool MaxStateTrieLevelInMemory uint @@ -587,6 +587,7 @@ type ConfigurationPathsHolder struct { Genesis string SmartContracts string ValidatorKey string + AllValidatorKeys string Epoch string RoundActivation string P2pKey string diff --git a/config/contextFlagsConfig.go b/config/contextFlagsConfig.go index 9bab4ea3d9e..360eeabf349 100644 --- a/config/contextFlagsConfig.go +++ b/config/contextFlagsConfig.go @@ -20,11 +20,13 @@ type ContextFlagsConfig struct { UseLogView bool ValidatorKeyIndex int EnableRestAPIServerDebugMode bool + BaseVersion string Version string ForceStartFromNetwork bool DisableConsensusWatchdog bool SerializeSnapshots bool NoKeyProvided bool + SnapshotsEnabled bool OperationMode string } diff --git a/config/epochConfig.go b/config/epochConfig.go index 9d24ff990d3..e729f362d91 100644 --- a/config/epochConfig.go +++ b/config/epochConfig.go @@ -87,6 +87,7 @@ type EnableEpochs struct { FixAsyncCallBackArgsListEnableEpoch uint32 FixOldTokenLiquidityEnableEpoch uint32 RuntimeMemStoreLimitEnableEpoch uint32 + RuntimeCodeSizeFixEnableEpoch uint32 SetSenderInEeiOutputTransferEnableEpoch uint32 RefactorPeersMiniBlocksEnableEpoch uint32 MaxBlockchainHookCountersEnableEpoch uint32 diff --git a/config/prefsConfig.go b/config/prefsConfig.go index 0c21eb775dc..4a6df0c9a73 100644 --- a/config/prefsConfig.go +++ b/config/prefsConfig.go @@ -2,7 +2,8 @@ package config // Preferences will hold the configuration related to node's preferences type Preferences struct { - Preferences PreferencesConfig + Preferences PreferencesConfig + NamedIdentity []NamedIdentity } // PreferencesConfig will hold the fields which are node specific such as the display name @@ -23,3 +24,10 @@ type OverridableConfig struct { Path string Value string } + +// NamedIdentity will hold the fields which are node named identities +type NamedIdentity struct { + Identity string + NodeName string + BLSKeys []string +} diff --git a/config/tomlConfig_test.go b/config/tomlConfig_test.go index 92802c97d02..611fd342d12 100644 --- a/config/tomlConfig_test.go +++ b/config/tomlConfig_test.go @@ -139,7 +139,7 @@ func TestTomlParser(t *testing.T) { [MiniBlocksStorage.Cache] Capacity = ` + strconv.Itoa(txBlockBodyStorageSize) + ` Type = "` + txBlockBodyStorageType + `" - Shards = ` + strconv.Itoa(txBlockBodyStorageShards) + ` + Shards = ` + strconv.Itoa(txBlockBodyStorageShards) + ` [MiniBlocksStorage.DB] FilePath = "` + txBlockBodyStorageFile + `" Type = "` + txBlockBodyStorageTypeDB + `" @@ -173,13 +173,13 @@ func TestTomlParser(t *testing.T) { Type = "` + accountsStorageTypeDB + `" [Hasher] - Type = "` + hasherType + `" + Type = "` + hasherType + `" [MultisigHasher] - Type = "` + multiSigHasherType + `" + Type = "` + multiSigHasherType + `" [Consensus] - Type = "` + consensusType + `" + Type = "` + consensusType + `" [VirtualMachine] [VirtualMachine.Execution] @@ -197,15 +197,15 @@ func TestTomlParser(t *testing.T) { { StartEpoch = 
88, Version = "v1.2" }, ] - [VirtualMachine.GasConfig] - ShardMaxGasPerVmQuery = 1500000000 - MetaMaxGasPerVmQuery = 0 + [VirtualMachine.GasConfig] + ShardMaxGasPerVmQuery = 1500000000 + MetaMaxGasPerVmQuery = 0 [Debug] [Debug.InterceptorResolver] Enabled = true CacheSize = 10000 - EnablePrint = true + EnablePrint = true IntervalAutoPrintInSeconds = 20 NumRequestsThreshold = 9 NumResolveFailureThreshold = 3 @@ -277,22 +277,22 @@ func TestTomlEconomicsParser(t *testing.T) { [GlobalSettings] Denomination = ` + fmt.Sprintf("%d", denomination) + ` [RewardsSettings] - [[RewardsSettings.RewardsConfigByEpoch]] - EpochEnable = ` + fmt.Sprintf("%d", epoch0) + ` - LeaderPercentage = ` + fmt.Sprintf("%.6f", leaderPercentage1) + ` - DeveloperPercentage = ` + fmt.Sprintf("%.6f", developerPercentage) + ` - ProtocolSustainabilityPercentage = ` + fmt.Sprintf("%.6f", protocolSustainabilityPercentage) + ` #fraction of value 0.1 - 10% - ProtocolSustainabilityAddress = "` + protocolSustainabilityAddress + `" - - [[RewardsSettings.RewardsConfigByEpoch]] - EpochEnable = ` + fmt.Sprintf("%d", epoch1) + ` - LeaderPercentage = ` + fmt.Sprintf("%.6f", leaderPercentage2) + ` + [[RewardsSettings.RewardsConfigByEpoch]] + EpochEnable = ` + fmt.Sprintf("%d", epoch0) + ` + LeaderPercentage = ` + fmt.Sprintf("%.6f", leaderPercentage1) + ` + DeveloperPercentage = ` + fmt.Sprintf("%.6f", developerPercentage) + ` + ProtocolSustainabilityPercentage = ` + fmt.Sprintf("%.6f", protocolSustainabilityPercentage) + ` #fraction of value 0.1 - 10% + ProtocolSustainabilityAddress = "` + protocolSustainabilityAddress + `" + + [[RewardsSettings.RewardsConfigByEpoch]] + EpochEnable = ` + fmt.Sprintf("%d", epoch1) + ` + LeaderPercentage = ` + fmt.Sprintf("%.6f", leaderPercentage2) + ` DeveloperPercentage = ` + fmt.Sprintf("%.6f", developerPercentage) + ` ProtocolSustainabilityPercentage = ` + fmt.Sprintf("%.6f", protocolSustainabilityPercentage) + ` #fraction of value 0.1 - 10% ProtocolSustainabilityAddress = "` + protocolSustainabilityAddress + `" [FeeSettings] - GasLimitSettings = [{EnableEpoch = 0, MaxGasLimitPerBlock = "` + maxGasLimitPerBlock + `", MaxGasLimitPerMiniBlock = "", MaxGasLimitPerMetaBlock = "", MaxGasLimitPerMetaMiniBlock = "", MaxGasLimitPerTx = "", MinGasLimit = "` + minGasLimit + `"}] + GasLimitSettings = [{EnableEpoch = 0, MaxGasLimitPerBlock = "` + maxGasLimitPerBlock + `", MaxGasLimitPerMiniBlock = "", MaxGasLimitPerMetaBlock = "", MaxGasLimitPerMetaMiniBlock = "", MaxGasLimitPerTx = "", MinGasLimit = "` + minGasLimit + `"}] MinGasPrice = "` + minGasPrice + `" ` cfg := EconomicsConfig{} @@ -323,14 +323,14 @@ func TestTomlPreferencesParser(t *testing.T) { testString := ` [Preferences] - NodeDisplayName = "` + nodeDisplayName + `" - DestinationShardAsObserver = "` + destinationShardAsObs + `" - Identity = "` + identity + `" - RedundancyLevel = ` + fmt.Sprintf("%d", redundancyLevel) + ` - PreferredConnections = [ - "` + prefPubKey0 + `", - "` + prefPubKey1 + `" - ] + NodeDisplayName = "` + nodeDisplayName + `" + DestinationShardAsObserver = "` + destinationShardAsObs + `" + Identity = "` + identity + `" + RedundancyLevel = ` + fmt.Sprintf("%d", redundancyLevel) + ` + PreferredConnections = [ + "` + prefPubKey0 + `", + "` + prefPubKey1 + `" + ] ` cfg := Preferences{} @@ -408,16 +408,16 @@ func TestAPIRoutesToml(t *testing.T) { [APIPackages] [APIPackages.` + package0 + `] - Routes = [ + Routes = [ # test comment { Name = "` + route0 + `", Open = true }, # test comment { Name = "` + route1 + `", Open = true }, - ] + ] 
[APIPackages.` + package1 + `] - Routes = [ + Routes = [ # test comment { Name = "` + route2 + `", Open = false } ] @@ -493,202 +493,143 @@ func TestEnableEpochConfig(t *testing.T) { [EnableEpochs] # SCDeployEnableEpoch represents the epoch when the deployment of smart contracts will be enabled SCDeployEnableEpoch = 1 - # BuiltInFunctionsEnableEpoch represents the epoch when the built in functions will be enabled BuiltInFunctionsEnableEpoch = 2 - # RelayedTransactionsEnableEpoch represents the epoch when the relayed transactions will be enabled RelayedTransactionsEnableEpoch = 3 - # PenalizedTooMuchGasEnableEpoch represents the epoch when the penalization for using too much gas will be enabled PenalizedTooMuchGasEnableEpoch = 4 - # SwitchJailWaitingEnableEpoch represents the epoch when the system smart contract processing at end of epoch is enabled SwitchJailWaitingEnableEpoch = 5 - # BelowSignedThresholdEnableEpoch represents the epoch when the change for computing rating for validators below signed rating is enabled BelowSignedThresholdEnableEpoch = 6 - # SwitchHysteresisForMinNodesEnableEpoch represents the epoch when the system smart contract changes its config to consider # also (minimum) hysteresis nodes for the minimum number of nodes SwitchHysteresisForMinNodesEnableEpoch = 7 - # TransactionSignedWithTxHashEnableEpoch represents the epoch when the node will also accept transactions that are # signed with the hash of transaction TransactionSignedWithTxHashEnableEpoch = 8 - # MetaProtectionEnableEpoch represents the epoch when the transactions to the metachain are checked to have enough gas MetaProtectionEnableEpoch = 9 - # AheadOfTimeGasUsageEnableEpoch represents the epoch when the cost of smart contract prepare changes from compiler per byte to ahead of time prepare per byte AheadOfTimeGasUsageEnableEpoch = 10 - # GasPriceModifierEnableEpoch represents the epoch when the gas price modifier in fee computation is enabled GasPriceModifierEnableEpoch = 11 - # RepairCallbackEnableEpoch represents the epoch when the callback repair is activated for scrs RepairCallbackEnableEpoch = 12 - # BlockGasAndFeesReCheckEnableEpoch represents the epoch when gas and fees used in each created or processed block are re-checked BlockGasAndFeesReCheckEnableEpoch = 13 - # BalanceWaitingListsEnableEpoch represents the epoch when the shard waiting lists are balanced at the start of an epoch BalanceWaitingListsEnableEpoch = 14 - # ReturnDataToLastTransferEnableEpoch represents the epoch when returned data is added to last output transfer for callbacks ReturnDataToLastTransferEnableEpoch = 15 - # SenderInOutTransferEnableEpoch represents the epoch when the feature of having different senders in output transfer is enabled SenderInOutTransferEnableEpoch = 16 - # StakeEnableEpoch represents the epoch when staking is enabled StakeEnableEpoch = 17 - # StakingV2EnableEpoch represents the epoch when staking v2 is enabled StakingV2EnableEpoch = 18 - DoubleKeyProtectionEnableEpoch = 19 - # ESDTEnableEpoch represents the epoch when ESDT is enabled ESDTEnableEpoch = 20 - # GovernanceEnableEpoch represents the epoch when governance is enabled GovernanceEnableEpoch = 21 - # DelegationManagerEnableEpoch represents the epoch when the delegation manager is enabled # epoch should not be 0 DelegationManagerEnableEpoch = 22 - # DelegationSmartContractEnableEpoch represents the epoch when delegation smart contract is enabled # epoch should not be 0 DelegationSmartContractEnableEpoch = 23 - # CorrectLastUnjailedEnableEpoch 
represents the epoch when the fix regaring the last unjailed node should apply CorrectLastUnjailedEnableEpoch = 24 - # RelayedTransactionsV2EnableEpoch represents the epoch when the relayed transactions V2 will be enabled RelayedTransactionsV2EnableEpoch = 25 - # UnbondTokensV2EnableEpoch represents the epoch when the new implementation of the unbond tokens function is available UnbondTokensV2EnableEpoch = 26 - # SaveJailedAlwaysEnableEpoch represents the epoch when saving jailed status at end of epoch will happen in all cases SaveJailedAlwaysEnableEpoch = 27 - # ReDelegateBelowMinCheckEnableEpoch represents the epoch when the check for the re-delegated value will be enabled ReDelegateBelowMinCheckEnableEpoch = 28 - # ValidatorToDelegationEnableEpoch represents the epoch when the validator-to-delegation feature will be enabled ValidatorToDelegationEnableEpoch = 29 - # WaitingListFixEnableEpoch represents the epoch when the 6 epoch waiting list fix is enabled WaitingListFixEnableEpoch = 30 - # IncrementSCRNonceInMultiTransferEnableEpoch represents the epoch when the fix for preventing the generation of the same SCRs # is enabled. The fix is done by adding an extra increment. IncrementSCRNonceInMultiTransferEnableEpoch = 31 - # ESDTMultiTransferEnableEpoch represents the epoch when esdt multitransfer built in function is enabled ESDTMultiTransferEnableEpoch = 32 - # GlobalMintBurnDisableEpoch represents the epoch when the global mint and burn functions are disabled GlobalMintBurnDisableEpoch = 33 - # ESDTTransferRoleEnableEpoch represents the epoch when esdt transfer role set is enabled ESDTTransferRoleEnableEpoch = 34 - # BuiltInFunctionOnMetaEnableEpoch represents the epoch when built in function processing on metachain is enabled BuiltInFunctionOnMetaEnableEpoch = 35 - # ComputeRewardCheckpointEnableEpoch represents the epoch when compute rewards checkpoint epoch is enabled ComputeRewardCheckpointEnableEpoch = 36 - # SCRSizeInvariantCheckEnableEpoch represents the epoch when the scr size invariant check is enabled SCRSizeInvariantCheckEnableEpoch = 37 - # BackwardCompSaveKeyValueEnableEpoch represents the epoch when backward compatibility save key value is enabled BackwardCompSaveKeyValueEnableEpoch = 38 - # ESDTNFTCreateOnMultiShardEnableEpoch represents the epoch when esdt nft creation on multiple shards is enabled ESDTNFTCreateOnMultiShardEnableEpoch = 39 - # MetaESDTSetEnableEpoch represents the epoch when the backward compatibility for save key value error is enabled MetaESDTSetEnableEpoch = 40 - # AddTokensToDelegationEnableEpoch represents the epoch when adding tokens to delegation is enabled for whitelisted address AddTokensToDelegationEnableEpoch = 41 - # MultiESDTTransferFixOnCallBackOnEnableEpoch represents the epoch when multi esdt transfer on callback fix is enabled MultiESDTTransferFixOnCallBackOnEnableEpoch = 42 - # OptimizeGasUsedInCrossMiniBlocksEnableEpoch represents the epoch when gas used in cross shard mini blocks will be optimized OptimizeGasUsedInCrossMiniBlocksEnableEpoch = 43 - # FixOOGReturnCodeEnableEpoch represents the epoch when the backward compatibility returning out of gas error is enabled FixOOGReturnCodeEnableEpoch = 44 - # RemoveNonUpdatedStorageEnableEpoch represents the epoch when the backward compatibility for removing non updated storage is enabled RemoveNonUpdatedStorageEnableEpoch = 45 - # OptimizeNFTStoreEnableEpoch represents the epoch when optimizations on NFT metadata store and send are enabled OptimizeNFTStoreEnableEpoch = 46 - # 
CreateNFTThroughExecByCallerEnableEpoch represents the epoch when nft creation through execution on destination by caller is enabled CreateNFTThroughExecByCallerEnableEpoch = 47 - # IsPayableBySCEnableEpoch represents the epoch when a new flag isPayable by SC is enabled IsPayableBySCEnableEpoch = 48 - - # CleanUpInformativeSCRsEnableEpoch represents the epoch when the scrs which contain only information are cleaned from miniblocks and logs are created from it - CleanUpInformativeSCRsEnableEpoch = 49 - + # CleanUpInformativeSCRsEnableEpoch represents the epoch when the scrs which contain only information are cleaned from miniblocks and logs are created from it + CleanUpInformativeSCRsEnableEpoch = 49 # StorageAPICostOptimizationEnableEpoch represents the epoch when new storage helper functions are enabled and cost is reduced in Wasm VM StorageAPICostOptimizationEnableEpoch = 50 - # TransformToMultiShardCreateEnableEpoch represents the epoch when the new function on esdt system sc is enabled to transfer create role into multishard - TransformToMultiShardCreateEnableEpoch = 51 - + TransformToMultiShardCreateEnableEpoch = 51 # ESDTRegisterAndSetAllRolesEnableEpoch represents the epoch when new function to register tickerID and set all roles is enabled ESDTRegisterAndSetAllRolesEnableEpoch = 52 - - # FailExecutionOnEveryAPIErrorEnableEpoch represent the epoch when new protection in VM is enabled to fail all wrong API calls - FailExecutionOnEveryAPIErrorEnableEpoch = 53 - - # ManagedCryptoAPIsEnableEpoch represents the epoch when the new managed crypto APIs are enabled - ManagedCryptoAPIsEnableEpoch = 54 - - # ESDTMetadataContinuousCleanupEnableEpoch represents the epoch when esdt metadata is automatically deleted according to inshard liquidity - ESDTMetadataContinuousCleanupEnableEpoch = 55 - + # FailExecutionOnEveryAPIErrorEnableEpoch represent the epoch when new protection in VM is enabled to fail all wrong API calls + FailExecutionOnEveryAPIErrorEnableEpoch = 53 + # ManagedCryptoAPIsEnableEpoch represents the epoch when the new managed crypto APIs are enabled + ManagedCryptoAPIsEnableEpoch = 54 + # ESDTMetadataContinuousCleanupEnableEpoch represents the epoch when esdt metadata is automatically deleted according to inshard liquidity + ESDTMetadataContinuousCleanupEnableEpoch = 55 # FixAsyncCallBackArgsListEnableEpoch represents the epoch when the async callback arguments lists fix will be enabled FixAsyncCallBackArgsListEnableEpoch = 56 - - # FixOldTokenLiquidityEnableEpoch represents the epoch when the fix for old token liquidity is enabled - FixOldTokenLiquidityEnableEpoch = 57 - - # SetSenderInEeiOutputTransferEnableEpoch represents the epoch when setting the sender in eei output transfers will be enabled + # FixOldTokenLiquidityEnableEpoch represents the epoch when the fix for old token liquidity is enabled + FixOldTokenLiquidityEnableEpoch = 57 + # SetSenderInEeiOutputTransferEnableEpoch represents the epoch when setting the sender in eei output transfers will be enabled SetSenderInEeiOutputTransferEnableEpoch = 58 - - # MaxBlockchainHookCountersEnableEpoch represents the epoch when the max blockchainhook counters are enabled - MaxBlockchainHookCountersEnableEpoch = 59 - + # MaxBlockchainHookCountersEnableEpoch represents the epoch when the max blockchainhook counters are enabled + MaxBlockchainHookCountersEnableEpoch = 59 # WipeSingleNFTLiquidityDecreaseEnableEpoch represents the epoch when the system account liquidity is decreased for wipeSingleNFT as well 
WipeSingleNFTLiquidityDecreaseEnableEpoch = 60 - # AlwaysSaveTokenMetaDataEnableEpoch represents the epoch when the token metadata is always saved AlwaysSaveTokenMetaDataEnableEpoch = 61 - + # RuntimeCodeSizeFixEnableEpoch represents the epoch when the code size fix in the VM is enabled + RuntimeCodeSizeFixEnableEpoch = 62 + # RuntimeMemStoreLimitEnableEpoch represents the epoch when the condition for Runtime MemStore is enabled + RuntimeMemStoreLimitEnableEpoch = 63 # MaxNodesChangeEnableEpoch holds configuration for changing the maximum number of nodes and the enabling epoch MaxNodesChangeEnableEpoch = [ { EpochEnable = 44, MaxNumNodes = 2169, NodesToShufflePerShard = 80 }, { EpochEnable = 45, MaxNumNodes = 3200, NodesToShufflePerShard = 80 } ] - - BLSMultiSignerEnableEpoch = [ - {EnableEpoch = 0, Type = "no-KOSK"}, - {EnableEpoch = 3, Type = "KOSK"} - ] - + BLSMultiSignerEnableEpoch = [ + {EnableEpoch = 0, Type = "no-KOSK"}, + {EnableEpoch = 3, Type = "KOSK"} + ] [GasSchedule] GasScheduleByEpochs = [ { StartEpoch = 46, FileName = "gasScheduleV1.toml" }, @@ -771,6 +712,8 @@ func TestEnableEpochConfig(t *testing.T) { MaxBlockchainHookCountersEnableEpoch: 59, WipeSingleNFTLiquidityDecreaseEnableEpoch: 60, AlwaysSaveTokenMetaDataEnableEpoch: 61, + RuntimeCodeSizeFixEnableEpoch: 62, + RuntimeMemStoreLimitEnableEpoch: 63, BLSMultiSignerEnableEpoch: []MultiSignerConfig{ { EnableEpoch: 0, diff --git a/consensus/broadcast/commonMessenger.go b/consensus/broadcast/commonMessenger.go index 79a190d130f..60c59e01145 100644 --- a/consensus/broadcast/commonMessenger.go +++ b/consensus/broadcast/commonMessenger.go @@ -28,9 +28,9 @@ type delayedBroadcaster interface { SetValidatorData(data *delayedBroadcastData) error SetHeaderForValidator(vData *validatorHeaderBroadcastData) error SetBroadcastHandlers( - mbBroadcast func(mbData map[uint32][]byte) error, - txBroadcast func(txData map[string][][]byte) error, - headerBroadcast func(header data.HeaderHandler) error, + mbBroadcast func(mbData map[uint32][]byte, pkBytes []byte) error, + txBroadcast func(txData map[string][][]byte, pkBytes []byte) error, + headerBroadcast func(header data.HeaderHandler, pkBytes []byte) error, ) error Close() } @@ -39,10 +39,10 @@ type commonMessenger struct { marshalizer marshal.Marshalizer hasher hashing.Hasher messenger consensus.P2PMessenger - privateKey crypto.PrivateKey shardCoordinator sharding.Coordinator peerSignatureHandler crypto.PeerSignatureHandler delayedBlockBroadcaster delayedBroadcaster + keysHandler consensus.KeysHandler } // CommonMessengerArgs holds the arguments for creating commonMessenger instance @@ -50,7 +50,6 @@ type CommonMessengerArgs struct { Marshalizer marshal.Marshalizer Hasher hashing.Hasher Messenger consensus.P2PMessenger - PrivateKey crypto.PrivateKey ShardCoordinator sharding.Coordinator PeerSignatureHandler crypto.PeerSignatureHandler HeadersSubscriber consensus.HeadersPoolSubscriber @@ -58,6 +57,7 @@ type CommonMessengerArgs struct { MaxDelayCacheSize uint32 MaxValidatorDelayCacheSize uint32 AlarmScheduler core.TimersScheduler + KeysHandler consensus.KeysHandler } func checkCommonMessengerNilParameters( @@ -72,9 +72,6 @@ func checkCommonMessengerNilParameters( if check.IfNil(args.Messenger) { return spos.ErrNilMessenger } - if check.IfNil(args.PrivateKey) { - return spos.ErrNilPrivateKey - } if check.IfNil(args.ShardCoordinator) { return spos.ErrNilShardCoordinator } @@ -93,13 +90,17 @@ func checkCommonMessengerNilParameters( if args.MaxDelayCacheSize == 0 || 
args.MaxValidatorDelayCacheSize == 0 { return spos.ErrInvalidCacheSize } + if check.IfNil(args.KeysHandler) { + return ErrNilKeysHandler + } return nil } // BroadcastConsensusMessage will send on consensus topic the consensus message func (cm *commonMessenger) BroadcastConsensusMessage(message *consensus.Message) error { - signature, err := cm.peerSignatureHandler.GetPeerSignature(cm.privateKey, message.OriginatorPid) + privateKey := cm.keysHandler.GetHandledPrivateKey(message.PubKey) + signature, err := cm.peerSignatureHandler.GetPeerSignature(privateKey, message.OriginatorPid) if err != nil { return err } @@ -114,18 +115,18 @@ func (cm *commonMessenger) BroadcastConsensusMessage(message *consensus.Message) consensusTopic := common.ConsensusTopic + cm.shardCoordinator.CommunicationIdentifier(cm.shardCoordinator.SelfId()) - cm.messenger.Broadcast(consensusTopic, buff) + cm.broadcast(consensusTopic, buff, message.PubKey) return nil } // BroadcastMiniBlocks will send on miniblocks topic the cross-shard miniblocks -func (cm *commonMessenger) BroadcastMiniBlocks(miniBlocks map[uint32][]byte) error { +func (cm *commonMessenger) BroadcastMiniBlocks(miniBlocks map[uint32][]byte, pkBytes []byte) error { for k, v := range miniBlocks { miniBlocksTopic := factory.MiniBlocksTopic + cm.shardCoordinator.CommunicationIdentifier(k) - cm.messenger.Broadcast(miniBlocksTopic, v) + cm.broadcast(miniBlocksTopic, v, pkBytes) } if len(miniBlocks) > 0 { @@ -138,7 +139,7 @@ func (cm *commonMessenger) BroadcastMiniBlocks(miniBlocks map[uint32][]byte) err } // BroadcastTransactions will send on transaction topic the transactions -func (cm *commonMessenger) BroadcastTransactions(transactions map[string][][]byte) error { +func (cm *commonMessenger) BroadcastTransactions(transactions map[string][][]byte, pkBytes []byte) error { dataPacker, err := partitioning.NewSimpleDataPacker(cm.marshalizer) if err != nil { return err @@ -155,7 +156,7 @@ func (cm *commonMessenger) BroadcastTransactions(transactions map[string][][]byt } for _, buff := range packets { - cm.messenger.Broadcast(topic, buff) + cm.broadcast(topic, buff, pkBytes) } } @@ -172,12 +173,13 @@ func (cm *commonMessenger) BroadcastTransactions(transactions map[string][][]byt func (cm *commonMessenger) BroadcastBlockData( miniBlocks map[uint32][]byte, transactions map[string][][]byte, + pkBytes []byte, extraDelayForBroadcast time.Duration, ) { time.Sleep(extraDelayForBroadcast) if len(miniBlocks) > 0 { - err := cm.BroadcastMiniBlocks(miniBlocks) + err := cm.BroadcastMiniBlocks(miniBlocks, pkBytes) if err != nil { log.Warn("commonMessenger.BroadcastBlockData: broadcast miniblocks", "error", err.Error()) } @@ -186,7 +188,7 @@ func (cm *commonMessenger) BroadcastBlockData( time.Sleep(common.ExtraDelayBetweenBroadcastMbsAndTxs) if len(transactions) > 0 { - err := cm.BroadcastTransactions(transactions) + err := cm.BroadcastTransactions(transactions, pkBytes) if err != nil { log.Warn("commonMessenger.BroadcastBlockData: broadcast transactions", "error", err.Error()) } @@ -223,3 +225,19 @@ func (cm *commonMessenger) extractMetaMiniBlocksAndTransactions( return metaMiniBlocks, metaTransactions } + +func (cm *commonMessenger) broadcast(topic string, data []byte, pkBytes []byte) { + if cm.keysHandler.IsOriginalPublicKeyOfTheNode(pkBytes) { + cm.messenger.Broadcast(topic, data) + return + } + + skBytes, pid, err := cm.keysHandler.GetP2PIdentity(pkBytes) + if err != nil { + log.Error("setup error in commonMessenger.broadcast - public key is managed but does not contain p2p 
sign info", + "pk", pkBytes, "error", err) + return + } + + cm.messenger.BroadcastUsingPrivateKey(topic, data, pid, skBytes) +} diff --git a/consensus/broadcast/commonMessenger_test.go b/consensus/broadcast/commonMessenger_test.go index 75ac560b809..939f2854ff3 100644 --- a/consensus/broadcast/commonMessenger_test.go +++ b/consensus/broadcast/commonMessenger_test.go @@ -1,6 +1,7 @@ package broadcast_test import ( + "bytes" "sync" "testing" "time" @@ -11,12 +12,18 @@ import ( "github.com/multiversx/mx-chain-go/consensus" "github.com/multiversx/mx-chain-go/consensus/broadcast" "github.com/multiversx/mx-chain-go/consensus/mock" + "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/p2pmocks" "github.com/pkg/errors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) +const ( + broadcastMethodPrefix = "broadcast" + broadcastUsingPrivateKeyCalledMethodPrefix = "broadcastUsingPrivateKeyCalled" +) + func newTestBlockBody() *block.Body { return &block.Body{ MiniBlocks: []*block.MiniBlock{ @@ -34,7 +41,6 @@ func TestCommonMessenger_BroadcastConsensusMessageShouldErrWhenSignMessageFail(t err := errors.New("sign message error") marshalizerMock := &mock.MarshalizerMock{} messengerMock := &p2pmocks.MessengerStub{} - privateKeyMock := &mock.PrivateKeyMock{} shardCoordinatorMock := &mock.ShardCoordinatorMock{} singleSignerMock := &mock.SingleSignerMock{ SignStub: func(private crypto.PrivateKey, msg []byte) ([]byte, error) { @@ -46,9 +52,9 @@ func TestCommonMessenger_BroadcastConsensusMessageShouldErrWhenSignMessageFail(t cm, _ := broadcast.NewCommonMessenger( marshalizerMock, messengerMock, - privateKeyMock, shardCoordinatorMock, peerSigHandler, + &testscommon.KeysHandlerStub{}, ) msg := &consensus.Message{} @@ -62,7 +68,6 @@ func TestCommonMessenger_BroadcastConsensusMessageShouldWork(t *testing.T) { BroadcastCalled: func(topic string, buff []byte) { }, } - privateKeyMock := &mock.PrivateKeyMock{} shardCoordinatorMock := &mock.ShardCoordinatorMock{} singleSignerMock := &mock.SingleSignerMock{ SignStub: func(private crypto.PrivateKey, msg []byte) ([]byte, error) { @@ -74,9 +79,9 @@ func TestCommonMessenger_BroadcastConsensusMessageShouldWork(t *testing.T) { cm, _ := broadcast.NewCommonMessenger( marshalizerMock, messengerMock, - privateKeyMock, shardCoordinatorMock, peerSigHandler, + &testscommon.KeysHandlerStub{}, ) msg := &consensus.Message{} @@ -84,32 +89,6 @@ func TestCommonMessenger_BroadcastConsensusMessageShouldWork(t *testing.T) { assert.Nil(t, err) } -func TestCommonMessenger_SignMessageShouldErrWhenSignFail(t *testing.T) { - err := errors.New("sign message error") - marshalizerMock := &mock.MarshalizerMock{} - messengerMock := &p2pmocks.MessengerStub{} - privateKeyMock := &mock.PrivateKeyMock{} - shardCoordinatorMock := &mock.ShardCoordinatorMock{} - singleSignerMock := &mock.SingleSignerMock{ - SignStub: func(private crypto.PrivateKey, msg []byte) ([]byte, error) { - return nil, err - }, - } - peerSigHandler := &mock.PeerSignatureHandler{Signer: singleSignerMock} - - cm, _ := broadcast.NewCommonMessenger( - marshalizerMock, - messengerMock, - privateKeyMock, - shardCoordinatorMock, - peerSigHandler, - ) - - msg := &consensus.Message{} - _, err2 := cm.SignMessage(msg) - assert.Equal(t, err, err2) -} - func TestSubroundEndRound_ExtractMiniBlocksAndTransactionsShouldWork(t *testing.T) { t.Parallel() @@ -143,7 +122,6 @@ func TestSubroundEndRound_ExtractMiniBlocksAndTransactionsShouldWork(t *testing. 
BroadcastCalled: func(topic string, buff []byte) { }, } - privateKeyMock := &mock.PrivateKeyMock{} shardCoordinatorMock := &mock.ShardCoordinatorMock{} singleSignerMock := &mock.SingleSignerMock{ SignStub: func(private crypto.PrivateKey, msg []byte) ([]byte, error) { @@ -155,9 +133,9 @@ func TestSubroundEndRound_ExtractMiniBlocksAndTransactionsShouldWork(t *testing. cm, _ := broadcast.NewCommonMessenger( marshalizerMock, messengerMock, - privateKeyMock, shardCoordinatorMock, peerSigHandler, + &testscommon.KeysHandlerStub{}, ) metaMiniBlocks, metaTransactions := cm.ExtractMetaMiniBlocksAndTransactions(miniBlocks, transactions) @@ -181,11 +159,15 @@ func TestCommonMessenger_BroadcastBlockData(t *testing.T) { messengerMock := &p2pmocks.MessengerStub{ BroadcastCalled: func(topic string, buff []byte) { mutCounters.Lock() - countersBroadcast[topic]++ + countersBroadcast[broadcastMethodPrefix+topic]++ + mutCounters.Unlock() + }, + BroadcastUsingPrivateKeyCalled: func(topic string, buff []byte, pid core.PeerID, skBytes []byte) { + mutCounters.Lock() + countersBroadcast[broadcastUsingPrivateKeyCalledMethodPrefix+topic]++ mutCounters.Unlock() }, } - privateKeyMock := &mock.PrivateKeyMock{} shardCoordinatorMock := &mock.ShardCoordinatorMock{} singleSignerMock := &mock.SingleSignerMock{ SignStub: func(private crypto.PrivateKey, msg []byte) ([]byte, error) { @@ -197,20 +179,161 @@ func TestCommonMessenger_BroadcastBlockData(t *testing.T) { cm, _ := broadcast.NewCommonMessenger( marshalizerMock, messengerMock, - privateKeyMock, shardCoordinatorMock, peerSigHandler, + &testscommon.KeysHandlerStub{ + IsOriginalPublicKeyOfTheNodeCalled: func(pkBytes []byte) bool { + return bytes.Equal(pkBytes, nodePkBytes) + }, + }, ) miniBlocks := map[uint32][]byte{0: []byte("mbs data1"), 1: []byte("mbs data2")} transactions := map[string][][]byte{"topic1": {[]byte("txdata1"), []byte("txdata2")}, "topic2": {[]byte("txdata3")}} delay := time.Millisecond * 10 - cm.BroadcastBlockData(miniBlocks, transactions, delay) - time.Sleep(delay * 2) - mutCounters.Lock() - defer mutCounters.Unlock() + t.Run("original public key of the node", func(t *testing.T) { + mutCounters.Lock() + countersBroadcast = make(map[string]int) + mutCounters.Unlock() + + cm.BroadcastBlockData(miniBlocks, transactions, nodePkBytes, delay) + time.Sleep(delay * 2) + + mutCounters.Lock() + defer mutCounters.Unlock() + + numBroadcast := countersBroadcast[broadcastMethodPrefix+"txBlockBodies_0"] + numBroadcast += countersBroadcast[broadcastMethodPrefix+"txBlockBodies_0_1"] + assert.Equal(t, len(miniBlocks), numBroadcast) + + numBroadcast = countersBroadcast[broadcastMethodPrefix+"topic1"] + numBroadcast += countersBroadcast[broadcastMethodPrefix+"topic2"] + assert.Equal(t, len(transactions), numBroadcast) + }) + t.Run("managed key", func(t *testing.T) { + mutCounters.Lock() + countersBroadcast = make(map[string]int) + mutCounters.Unlock() + + cm.BroadcastBlockData(miniBlocks, transactions, []byte("managed key"), delay) + time.Sleep(delay * 2) + + mutCounters.Lock() + defer mutCounters.Unlock() + + numBroadcast := countersBroadcast[broadcastUsingPrivateKeyCalledMethodPrefix+"txBlockBodies_0"] + numBroadcast += countersBroadcast[broadcastUsingPrivateKeyCalledMethodPrefix+"txBlockBodies_0_1"] + assert.Equal(t, len(miniBlocks), numBroadcast) + + numBroadcast = countersBroadcast[broadcastUsingPrivateKeyCalledMethodPrefix+"topic1"] + numBroadcast += countersBroadcast[broadcastUsingPrivateKeyCalledMethodPrefix+"topic2"] + assert.Equal(t, len(transactions), 
numBroadcast) + }) +} + +func TestCommonMessenger_broadcast(t *testing.T) { + t.Parallel() + + testTopic := "test" + expectedErr := errors.New("expected error") + marshallerMock := &mock.MarshalizerMock{} + countersBroadcast := make(map[string]int) + mutCounters := &sync.Mutex{} + + messengerMock := &p2pmocks.MessengerStub{ + BroadcastCalled: func(topic string, buff []byte) { + mutCounters.Lock() + countersBroadcast[broadcastMethodPrefix+topic]++ + mutCounters.Unlock() + }, + BroadcastUsingPrivateKeyCalled: func(topic string, buff []byte, pid core.PeerID, skBytes []byte) { + mutCounters.Lock() + countersBroadcast[broadcastUsingPrivateKeyCalledMethodPrefix+topic]++ + mutCounters.Unlock() + }, + } + shardCoordinatorMock := &mock.ShardCoordinatorMock{} + singleSignerMock := &mock.SingleSignerMock{ + SignStub: func(private crypto.PrivateKey, msg []byte) ([]byte, error) { + return []byte(""), nil + }, + } + peerSigHandler := &mock.PeerSignatureHandler{Signer: singleSignerMock} + + t.Run("using the original public key bytes of the node", func(t *testing.T) { + mutCounters.Lock() + countersBroadcast = make(map[string]int) + mutCounters.Unlock() + + cm, _ := broadcast.NewCommonMessenger( + marshallerMock, + messengerMock, + shardCoordinatorMock, + peerSigHandler, + &testscommon.KeysHandlerStub{ + IsOriginalPublicKeyOfTheNodeCalled: func(pkBytes []byte) bool { + return bytes.Equal(nodePkBytes, pkBytes) + }, + }, + ) + + cm.Broadcast(testTopic, []byte("data"), nodePkBytes) + + mutCounters.Lock() + assert.Equal(t, 1, countersBroadcast[broadcastMethodPrefix+testTopic]) + assert.Equal(t, 0, countersBroadcast[broadcastUsingPrivateKeyCalledMethodPrefix+testTopic]) + mutCounters.Unlock() + }) + t.Run("using a managed key", func(t *testing.T) { + mutCounters.Lock() + countersBroadcast = make(map[string]int) + mutCounters.Unlock() + + cm, _ := broadcast.NewCommonMessenger( + marshallerMock, + messengerMock, + shardCoordinatorMock, + peerSigHandler, + &testscommon.KeysHandlerStub{ + IsOriginalPublicKeyOfTheNodeCalled: func(pkBytes []byte) bool { + return false + }, + }, + ) + + cm.Broadcast(testTopic, []byte("data"), []byte("managed key")) + + mutCounters.Lock() + assert.Equal(t, 1, countersBroadcast[broadcastUsingPrivateKeyCalledMethodPrefix+testTopic]) + assert.Equal(t, 0, countersBroadcast[broadcastMethodPrefix+testTopic]) + mutCounters.Unlock() + }) + t.Run("managed key and the keys handler fails", func(t *testing.T) { + mutCounters.Lock() + countersBroadcast = make(map[string]int) + mutCounters.Unlock() + + cm, _ := broadcast.NewCommonMessenger( + marshallerMock, + messengerMock, + shardCoordinatorMock, + peerSigHandler, + &testscommon.KeysHandlerStub{ + GetP2PIdentityCalled: func(pkBytes []byte) ([]byte, core.PeerID, error) { + return nil, "", expectedErr + }, + IsOriginalPublicKeyOfTheNodeCalled: func(pkBytes []byte) bool { + return false + }, + }, + ) + + cm.Broadcast(testTopic, []byte("data"), []byte("managed key")) - assert.Equal(t, len(miniBlocks), countersBroadcast["txBlockBodies_0"]+countersBroadcast["txBlockBodies_0_1"]) - assert.Equal(t, len(transactions), countersBroadcast["topic1"]+countersBroadcast["topic2"]) + mutCounters.Lock() + assert.Equal(t, 0, countersBroadcast[broadcastUsingPrivateKeyCalledMethodPrefix+testTopic]) + assert.Equal(t, 0, countersBroadcast[broadcastMethodPrefix+testTopic]) + mutCounters.Unlock() + }) } diff --git a/consensus/broadcast/delayedBroadcast.go b/consensus/broadcast/delayedBroadcast.go index 234134d81dd..955a81f0f73 100644 --- 
a/consensus/broadcast/delayedBroadcast.go +++ b/consensus/broadcast/delayedBroadcast.go @@ -41,6 +41,7 @@ type validatorHeaderBroadcastData struct { metaMiniBlocksData map[uint32][]byte metaTransactionsData map[string][][]byte order uint32 + pkBytes []byte } type delayedBroadcastData struct { @@ -50,6 +51,7 @@ type delayedBroadcastData struct { miniBlockHashes map[string]map[string]struct{} transactions map[string][][]byte order uint32 + pkBytes []byte } // timersScheduler exposes functionality for scheduling multiple timers @@ -76,9 +78,9 @@ type delayedBlockBroadcaster struct { maxDelayCacheSize uint32 maxValidatorDelayCacheSize uint32 mutDataForBroadcast sync.RWMutex - broadcastMiniblocksData func(mbData map[uint32][]byte) error - broadcastTxsData func(txData map[string][][]byte) error - broadcastHeader func(header data.HeaderHandler) error + broadcastMiniblocksData func(mbData map[uint32][]byte, pkBytes []byte) error + broadcastTxsData func(txData map[string][][]byte, pkBytes []byte) error + broadcastHeader func(header data.HeaderHandler, pkBytes []byte) error cacheHeaders storage.Cacher mutHeadersCache sync.RWMutex } @@ -246,9 +248,9 @@ func (dbb *delayedBlockBroadcaster) SetValidatorData(broadcastData *delayedBroad // SetBroadcastHandlers sets the broadcast handlers for miniBlocks and transactions func (dbb *delayedBlockBroadcaster) SetBroadcastHandlers( - mbBroadcast func(mbData map[uint32][]byte) error, - txBroadcast func(txData map[string][][]byte) error, - headerBroadcast func(header data.HeaderHandler) error, + mbBroadcast func(mbData map[uint32][]byte, pkBytes []byte) error, + txBroadcast func(txData map[string][][]byte, pkBytes []byte) error, + headerBroadcast func(header data.HeaderHandler, pkBytes []byte) error, ) error { if mbBroadcast == nil || txBroadcast == nil || headerBroadcast == nil { return spos.ErrNilParameter @@ -461,7 +463,7 @@ func (dbb *delayedBlockBroadcaster) headerAlarmExpired(alarmID string) { "alarmID", alarmID, ) // broadcast header - err = dbb.broadcastHeader(vHeader.header) + err = dbb.broadcastHeader(vHeader.header, vHeader.pkBytes) if err != nil { log.Warn("delayedBlockBroadcaster.headerAlarmExpired", "error", err.Error(), "headerHash", headerHash, @@ -475,33 +477,34 @@ func (dbb *delayedBlockBroadcaster) headerAlarmExpired(alarmID string) { "headerHash", headerHash, "alarmID", alarmID, ) - go dbb.broadcastBlockData(vHeader.metaMiniBlocksData, vHeader.metaTransactionsData, common.ExtraDelayForBroadcastBlockInfo) + go dbb.broadcastBlockData(vHeader.metaMiniBlocksData, vHeader.metaTransactionsData, vHeader.pkBytes, common.ExtraDelayForBroadcastBlockInfo) } } func (dbb *delayedBlockBroadcaster) broadcastDelayedData(broadcastData []*delayedBroadcastData) { for _, bData := range broadcastData { - go func(miniBlocks map[uint32][]byte, transactions map[string][][]byte) { - dbb.broadcastBlockData(miniBlocks, transactions, 0) - }(bData.miniBlocksData, bData.transactions) + go func(miniBlocks map[uint32][]byte, transactions map[string][][]byte, pkBytes []byte) { + dbb.broadcastBlockData(miniBlocks, transactions, pkBytes, 0) + }(bData.miniBlocksData, bData.transactions, bData.pkBytes) } } func (dbb *delayedBlockBroadcaster) broadcastBlockData( miniBlocks map[uint32][]byte, transactions map[string][][]byte, + pkBytes []byte, delay time.Duration, ) { time.Sleep(delay) - err := dbb.broadcastMiniblocksData(miniBlocks) + err := dbb.broadcastMiniblocksData(miniBlocks, pkBytes) if err != nil { log.Error("broadcastBlockData.broadcastMiniblocksData", "error", 
err.Error()) } time.Sleep(common.ExtraDelayBetweenBroadcastMbsAndTxs) - err = dbb.broadcastTxsData(transactions) + err = dbb.broadcastTxsData(transactions, pkBytes) if err != nil { log.Error("broadcastBlockData.broadcastTxsData", "error", err.Error()) } diff --git a/consensus/broadcast/delayedBroadcast_test.go b/consensus/broadcast/delayedBroadcast_test.go index d601e6a7848..0f22e8a5157 100644 --- a/consensus/broadcast/delayedBroadcast_test.go +++ b/consensus/broadcast/delayedBroadcast_test.go @@ -166,15 +166,15 @@ func TestDelayedBlockBroadcaster_HeaderReceivedNoDelayedDataRegistered(t *testin mbBroadcastCalled := atomic.Flag{} txBroadcastCalled := atomic.Flag{} - broadcastMiniBlocks := func(mbData map[uint32][]byte) error { + broadcastMiniBlocks := func(mbData map[uint32][]byte, pk []byte) error { _ = mbBroadcastCalled.SetReturningPrevious() return nil } - broadcastTransactions := func(txData map[string][][]byte) error { + broadcastTransactions := func(txData map[string][][]byte, pk []byte) error { _ = txBroadcastCalled.SetReturningPrevious() return nil } - broadcastHeader := func(header data.HeaderHandler) error { + broadcastHeader := func(header data.HeaderHandler, pk []byte) error { return nil } @@ -199,15 +199,15 @@ func TestDelayedBlockBroadcaster_HeaderReceivedForRegisteredDelayedDataShouldBro mbBroadcastCalled := atomic.Flag{} txBroadcastCalled := atomic.Flag{} - broadcastMiniBlocks := func(mbData map[uint32][]byte) error { + broadcastMiniBlocks := func(mbData map[uint32][]byte, pk []byte) error { _ = mbBroadcastCalled.SetReturningPrevious() return nil } - broadcastTransactions := func(txData map[string][][]byte) error { + broadcastTransactions := func(txData map[string][][]byte, pk []byte) error { _ = txBroadcastCalled.SetReturningPrevious() return nil } - broadcastHeader := func(header data.HeaderHandler) error { + broadcastHeader := func(header data.HeaderHandler, pk []byte) error { return nil } @@ -245,15 +245,15 @@ func TestDelayedBlockBroadcaster_HeaderReceivedForNotRegisteredDelayedDataShould mbBroadcastCalled := atomic.Flag{} txBroadcastCalled := atomic.Flag{} - broadcastMiniBlocks := func(mbData map[uint32][]byte) error { + broadcastMiniBlocks := func(mbData map[uint32][]byte, pk []byte) error { _ = mbBroadcastCalled.SetReturningPrevious() return nil } - broadcastTransactions := func(txData map[string][][]byte) error { + broadcastTransactions := func(txData map[string][][]byte, pk []byte) error { _ = txBroadcastCalled.SetReturningPrevious() return nil } - broadcastHeader := func(header data.HeaderHandler) error { + broadcastHeader := func(header data.HeaderHandler, pk []byte) error { return nil } @@ -290,15 +290,15 @@ func TestDelayedBlockBroadcaster_HeaderReceivedForNextRegisteredDelayedDataShoul mbBroadcastCalled := atomic.Counter{} txBroadcastCalled := atomic.Counter{} - broadcastMiniBlocks := func(mbData map[uint32][]byte) error { + broadcastMiniBlocks := func(mbData map[uint32][]byte, pk []byte) error { mbBroadcastCalled.Increment() return nil } - broadcastTransactions := func(txData map[string][][]byte) error { + broadcastTransactions := func(txData map[string][][]byte, pk []byte) error { txBroadcastCalled.Increment() return nil } - broadcastHeader := func(header data.HeaderHandler) error { + broadcastHeader := func(header data.HeaderHandler, pk []byte) error { return nil } @@ -412,15 +412,15 @@ func TestDelayedBlockBroadcaster_SetHeaderForValidatorShouldSetAlarmAndBroadcast txBroadcastCalled := atomic.Counter{} headerBroadcastCalled := atomic.Counter{} - 
broadcastMiniBlocks := func(mbData map[uint32][]byte) error { + broadcastMiniBlocks := func(mbData map[uint32][]byte, pk []byte) error { mbBroadcastCalled.Increment() return nil } - broadcastTransactions := func(txData map[string][][]byte) error { + broadcastTransactions := func(txData map[string][][]byte, pk []byte) error { txBroadcastCalled.Increment() return nil } - broadcastHeader := func(header data.HeaderHandler) error { + broadcastHeader := func(header data.HeaderHandler, pk []byte) error { headerBroadcastCalled.Increment() return nil } @@ -472,15 +472,15 @@ func TestDelayedBlockBroadcaster_SetValidatorDataFinalizedMetaHeaderShouldSetAla txBroadcastCalled := atomic.Counter{} headerBroadcastCalled := atomic.Counter{} - broadcastMiniBlocks := func(mbData map[uint32][]byte) error { + broadcastMiniBlocks := func(mbData map[uint32][]byte, pk []byte) error { mbBroadcastCalled.Increment() return nil } - broadcastTransactions := func(txData map[string][][]byte) error { + broadcastTransactions := func(txData map[string][][]byte, pk []byte) error { txBroadcastCalled.Increment() return nil } - broadcastHeader := func(header data.HeaderHandler) error { + broadcastHeader := func(header data.HeaderHandler, pk []byte) error { headerBroadcastCalled.Increment() return nil } @@ -540,15 +540,15 @@ func TestDelayedBlockBroadcaster_InterceptedHeaderShouldCancelAlarm(t *testing.T txBroadcastCalled := atomic.Counter{} headerBroadcastCalled := atomic.Counter{} - broadcastMiniBlocks := func(mbData map[uint32][]byte) error { + broadcastMiniBlocks := func(mbData map[uint32][]byte, pk []byte) error { mbBroadcastCalled.Increment() return nil } - broadcastTransactions := func(txData map[string][][]byte) error { + broadcastTransactions := func(txData map[string][][]byte, pk []byte) error { txBroadcastCalled.Increment() return nil } - broadcastHeader := func(header data.HeaderHandler) error { + broadcastHeader := func(header data.HeaderHandler, pk []byte) error { headerBroadcastCalled.Increment() return nil } @@ -609,15 +609,15 @@ func TestDelayedBlockBroadcaster_InterceptedHeaderShouldCancelAlarmForHeaderBroa txBroadcastCalled := atomic.Counter{} headerBroadcastCalled := atomic.Counter{} - broadcastMiniBlocks := func(mbData map[uint32][]byte) error { + broadcastMiniBlocks := func(mbData map[uint32][]byte, pk []byte) error { mbBroadcastCalled.Increment() return nil } - broadcastTransactions := func(txData map[string][][]byte) error { + broadcastTransactions := func(txData map[string][][]byte, pk []byte) error { txBroadcastCalled.Increment() return nil } - broadcastHeader := func(header data.HeaderHandler) error { + broadcastHeader := func(header data.HeaderHandler, pk []byte) error { headerBroadcastCalled.Increment() return nil } @@ -677,15 +677,15 @@ func TestDelayedBlockBroadcaster_InterceptedHeaderInvalidOrDifferentShouldIgnore txBroadcastCalled := atomic.Counter{} headerBroadcastCalled := atomic.Counter{} - broadcastMiniBlocks := func(mbData map[uint32][]byte) error { + broadcastMiniBlocks := func(mbData map[uint32][]byte, pk []byte) error { mbBroadcastCalled.Increment() return nil } - broadcastTransactions := func(txData map[string][][]byte) error { + broadcastTransactions := func(txData map[string][][]byte, pk []byte) error { txBroadcastCalled.Increment() return nil } - broadcastHeader := func(header data.HeaderHandler) error { + broadcastHeader := func(header data.HeaderHandler, pk []byte) error { headerBroadcastCalled.Increment() return nil } @@ -791,15 +791,15 @@ func 
TestDelayedBlockBroadcaster_ScheduleValidatorBroadcastDifferentHeaderRoundS mbBroadcastCalled := atomic.Counter{} txBroadcastCalled := atomic.Counter{} - broadcastMiniBlocks := func(mbData map[uint32][]byte) error { + broadcastMiniBlocks := func(mbData map[uint32][]byte, pk []byte) error { mbBroadcastCalled.Increment() return nil } - broadcastTransactions := func(txData map[string][][]byte) error { + broadcastTransactions := func(txData map[string][][]byte, pk []byte) error { txBroadcastCalled.Increment() return nil } - broadcastHeader := func(header data.HeaderHandler) error { + broadcastHeader := func(header data.HeaderHandler, pk []byte) error { return nil } @@ -848,15 +848,15 @@ func TestDelayedBlockBroadcaster_ScheduleValidatorBroadcastDifferentPrevRandShou mbBroadcastCalled := atomic.Counter{} txBroadcastCalled := atomic.Counter{} - broadcastMiniBlocks := func(mbData map[uint32][]byte) error { + broadcastMiniBlocks := func(mbData map[uint32][]byte, pk []byte) error { mbBroadcastCalled.Increment() return nil } - broadcastTransactions := func(txData map[string][][]byte) error { + broadcastTransactions := func(txData map[string][][]byte, pk []byte) error { txBroadcastCalled.Increment() return nil } - broadcastHeader := func(header data.HeaderHandler) error { + broadcastHeader := func(header data.HeaderHandler, pk []byte) error { return nil } @@ -908,15 +908,15 @@ func TestDelayedBlockBroadcaster_ScheduleValidatorBroadcastSameRoundAndPrevRandS mbBroadcastCalled := atomic.Counter{} txBroadcastCalled := atomic.Counter{} - broadcastMiniBlocks := func(mbData map[uint32][]byte) error { + broadcastMiniBlocks := func(mbData map[uint32][]byte, pk []byte) error { mbBroadcastCalled.Increment() return nil } - broadcastTransactions := func(txData map[string][][]byte) error { + broadcastTransactions := func(txData map[string][][]byte, pk []byte) error { txBroadcastCalled.Increment() return nil } - broadcastHeader := func(header data.HeaderHandler) error { + broadcastHeader := func(header data.HeaderHandler, pk []byte) error { return nil } @@ -968,15 +968,15 @@ func TestDelayedBlockBroadcaster_AlarmExpiredShouldBroadcastTheDataForRegistered mbBroadcastCalled := atomic.Counter{} txBroadcastCalled := atomic.Counter{} - broadcastMiniBlocks := func(mbData map[uint32][]byte) error { + broadcastMiniBlocks := func(mbData map[uint32][]byte, pk []byte) error { mbBroadcastCalled.Increment() return nil } - broadcastTransactions := func(txData map[string][][]byte) error { + broadcastTransactions := func(txData map[string][][]byte, pk []byte) error { txBroadcastCalled.Increment() return nil } - broadcastHeader := func(header data.HeaderHandler) error { + broadcastHeader := func(header data.HeaderHandler, pk []byte) error { return nil } @@ -1021,15 +1021,15 @@ func TestDelayedBlockBroadcaster_AlarmExpiredShouldDoNothingForNotRegisteredData mbBroadcastCalled := atomic.Counter{} txBroadcastCalled := atomic.Counter{} - broadcastMiniBlocks := func(mbData map[uint32][]byte) error { + broadcastMiniBlocks := func(mbData map[uint32][]byte, pk []byte) error { mbBroadcastCalled.Increment() return nil } - broadcastTransactions := func(txData map[string][][]byte) error { + broadcastTransactions := func(txData map[string][][]byte, pk []byte) error { txBroadcastCalled.Increment() return nil } - broadcastHeader := func(header data.HeaderHandler) error { + broadcastHeader := func(header data.HeaderHandler, pk []byte) error { return nil } @@ -1169,15 +1169,15 @@ func 
TestDelayedBlockBroadcaster_InterceptedMiniBlockForNotSetValDataShouldBroad mbBroadcastCalled := atomic.Counter{} txBroadcastCalled := atomic.Counter{} - broadcastMiniBlocks := func(mbData map[uint32][]byte) error { + broadcastMiniBlocks := func(mbData map[uint32][]byte, pk []byte) error { mbBroadcastCalled.Increment() return nil } - broadcastTransactions := func(txData map[string][][]byte) error { + broadcastTransactions := func(txData map[string][][]byte, pk []byte) error { txBroadcastCalled.Increment() return nil } - broadcastHeader := func(header data.HeaderHandler) error { + broadcastHeader := func(header data.HeaderHandler, pk []byte) error { return nil } @@ -1232,15 +1232,15 @@ func TestDelayedBlockBroadcaster_InterceptedMiniBlockOutOfManyForSetValDataShoul mbBroadcastCalled := atomic.Counter{} txBroadcastCalled := atomic.Counter{} - broadcastMiniBlocks := func(mbData map[uint32][]byte) error { + broadcastMiniBlocks := func(mbData map[uint32][]byte, pk []byte) error { mbBroadcastCalled.Increment() return nil } - broadcastTransactions := func(txData map[string][][]byte) error { + broadcastTransactions := func(txData map[string][][]byte, pk []byte) error { txBroadcastCalled.Increment() return nil } - broadcastHeader := func(header data.HeaderHandler) error { + broadcastHeader := func(header data.HeaderHandler, pk []byte) error { return nil } @@ -1296,15 +1296,15 @@ func TestDelayedBlockBroadcaster_InterceptedMiniBlockFinalForSetValDataShouldNot mbBroadcastCalled := atomic.Counter{} txBroadcastCalled := atomic.Counter{} - broadcastMiniBlocks := func(mbData map[uint32][]byte) error { + broadcastMiniBlocks := func(mbData map[uint32][]byte, pk []byte) error { mbBroadcastCalled.Increment() return nil } - broadcastTransactions := func(txData map[string][][]byte) error { + broadcastTransactions := func(txData map[string][][]byte, pk []byte) error { txBroadcastCalled.Increment() return nil } - broadcastHeader := func(header data.HeaderHandler) error { + broadcastHeader := func(header data.HeaderHandler, pk []byte) error { return nil } @@ -1360,15 +1360,15 @@ func TestDelayedBlockBroadcaster_Close(t *testing.T) { mbBroadcastCalled := atomic.Counter{} txBroadcastCalled := atomic.Counter{} - broadcastMiniBlocks := func(mbData map[uint32][]byte) error { + broadcastMiniBlocks := func(mbData map[uint32][]byte, pk []byte) error { mbBroadcastCalled.Increment() return nil } - broadcastTransactions := func(txData map[string][][]byte) error { + broadcastTransactions := func(txData map[string][][]byte, pk []byte) error { txBroadcastCalled.Increment() return nil } - broadcastHeader := func(header data.HeaderHandler) error { + broadcastHeader := func(header data.HeaderHandler, pk []byte) error { return nil } diff --git a/consensus/broadcast/errors.go b/consensus/broadcast/errors.go new file mode 100644 index 00000000000..86acef6937b --- /dev/null +++ b/consensus/broadcast/errors.go @@ -0,0 +1,6 @@ +package broadcast + +import "errors" + +// ErrNilKeysHandler signals that a nil keys handler was provided +var ErrNilKeysHandler = errors.New("nil keys handler") diff --git a/consensus/broadcast/export.go b/consensus/broadcast/export.go index fd41f943c2a..e7b0e4dfa80 100644 --- a/consensus/broadcast/export.go +++ b/consensus/broadcast/export.go @@ -16,11 +16,6 @@ type HeaderDataForValidator struct { PrevRandSeed []byte } -// SignMessage will sign and return the given message -func (cm *commonMessenger) SignMessage(message *consensus.Message) ([]byte, error) { - return 
cm.peerSignatureHandler.GetPeerSignature(cm.privateKey, message.OriginatorPid) -} - // ExtractMetaMiniBlocksAndTransactions - func (cm *commonMessenger) ExtractMetaMiniBlocksAndTransactions( miniBlocks map[uint32][]byte, @@ -175,16 +170,21 @@ func (dbb *delayedBlockBroadcaster) InterceptedHeaderData(topic string, hash []b func NewCommonMessenger( marshalizer marshal.Marshalizer, messenger consensus.P2PMessenger, - privateKey crypto.PrivateKey, shardCoordinator sharding.Coordinator, peerSigHandler crypto.PeerSignatureHandler, + keysHandler consensus.KeysHandler, ) (*commonMessenger, error) { return &commonMessenger{ marshalizer: marshalizer, messenger: messenger, - privateKey: privateKey, shardCoordinator: shardCoordinator, peerSignatureHandler: peerSigHandler, + keysHandler: keysHandler, }, nil } + +// Broadcast - +func (cm *commonMessenger) Broadcast(topic string, data []byte, pkBytes []byte) { + cm.broadcast(topic, data, pkBytes) +} diff --git a/consensus/broadcast/metaChainMessenger.go b/consensus/broadcast/metaChainMessenger.go index cf3c177ec30..daca3b436a5 100644 --- a/consensus/broadcast/metaChainMessenger.go +++ b/consensus/broadcast/metaChainMessenger.go @@ -50,10 +50,10 @@ func NewMetaChainMessenger( marshalizer: args.Marshalizer, hasher: args.Hasher, messenger: args.Messenger, - privateKey: args.PrivateKey, shardCoordinator: args.ShardCoordinator, peerSignatureHandler: args.PeerSignatureHandler, delayedBlockBroadcaster: dbb, + keysHandler: args.KeysHandler, } mcm := &metaChainMessenger{ @@ -109,7 +109,7 @@ func (mcm *metaChainMessenger) BroadcastBlock(blockBody data.BodyHandler, header } // BroadcastHeader will send on metachain blocks topic the header -func (mcm *metaChainMessenger) BroadcastHeader(header data.HeaderHandler) error { +func (mcm *metaChainMessenger) BroadcastHeader(header data.HeaderHandler, pkBytes []byte) error { if check.IfNil(header) { return spos.ErrNilHeader } @@ -119,7 +119,7 @@ func (mcm *metaChainMessenger) BroadcastHeader(header data.HeaderHandler) error return err } - mcm.messenger.Broadcast(factory.MetachainBlocksTopic, msgHeader) + mcm.broadcast(factory.MetachainBlocksTopic, msgHeader, pkBytes) return nil } @@ -129,8 +129,9 @@ func (mcm *metaChainMessenger) BroadcastBlockDataLeader( _ data.HeaderHandler, miniBlocks map[uint32][]byte, transactions map[string][][]byte, + pkBytes []byte, ) error { - go mcm.BroadcastBlockData(miniBlocks, transactions, common.ExtraDelayForBroadcastBlockInfo) + go mcm.BroadcastBlockData(miniBlocks, transactions, pkBytes, common.ExtraDelayForBroadcastBlockInfo) return nil } @@ -140,6 +141,7 @@ func (mcm *metaChainMessenger) PrepareBroadcastHeaderValidator( miniBlocks map[uint32][]byte, transactions map[string][][]byte, idx int, + pkBytes []byte, ) { if check.IfNil(header) { log.Error("metaChainMessenger.PrepareBroadcastHeaderValidator", "error", spos.ErrNilHeader) @@ -158,6 +160,7 @@ func (mcm *metaChainMessenger) PrepareBroadcastHeaderValidator( metaMiniBlocksData: miniBlocks, metaTransactionsData: transactions, order: uint32(idx), + pkBytes: pkBytes, } err = mcm.delayedBlockBroadcaster.SetHeaderForValidator(vData) @@ -173,6 +176,7 @@ func (mcm *metaChainMessenger) PrepareBroadcastBlockDataValidator( _ map[uint32][]byte, _ map[string][][]byte, _ int, + _ []byte, ) { } diff --git a/consensus/broadcast/metaChainMessenger_test.go b/consensus/broadcast/metaChainMessenger_test.go index 06eb908a865..01cbb6a151d 100644 --- a/consensus/broadcast/metaChainMessenger_test.go +++ b/consensus/broadcast/metaChainMessenger_test.go @@ 
-1,25 +1,29 @@ package broadcast_test import ( + "bytes" "sync" "testing" "time" + "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/data/block" "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/consensus/broadcast" "github.com/multiversx/mx-chain-go/consensus/mock" "github.com/multiversx/mx-chain-go/consensus/spos" + "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" "github.com/multiversx/mx-chain-go/testscommon/p2pmocks" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) +var nodePkBytes = []byte("node public key bytes") + func createDefaultMetaChainArgs() broadcast.MetaChainMessengerArgs { marshalizerMock := &mock.MarshalizerMock{} messengerMock := &p2pmocks.MessengerStub{} - privateKeyMock := &mock.PrivateKeyMock{} shardCoordinatorMock := &mock.ShardCoordinatorMock{} singleSignerMock := &mock.SingleSignerMock{} hasher := &hashingMocks.HasherMock{} @@ -33,7 +37,6 @@ func createDefaultMetaChainArgs() broadcast.MetaChainMessengerArgs { Marshalizer: marshalizerMock, Hasher: hasher, Messenger: messengerMock, - PrivateKey: privateKeyMock, ShardCoordinator: shardCoordinatorMock, PeerSignatureHandler: peerSigHandler, HeadersSubscriber: headersSubscriber, @@ -41,6 +44,7 @@ func createDefaultMetaChainArgs() broadcast.MetaChainMessengerArgs { MaxValidatorDelayCacheSize: 2, MaxDelayCacheSize: 2, AlarmScheduler: alarmScheduler, + KeysHandler: &testscommon.KeysHandlerStub{}, }, } } @@ -63,15 +67,6 @@ func TestMetaChainMessenger_NewMetaChainMessengerNilMessengerShouldFail(t *testi assert.Equal(t, spos.ErrNilMessenger, err) } -func TestMetaChainMessenger_NewMetaChainMessengerNilPrivateKeyShouldFail(t *testing.T) { - args := createDefaultMetaChainArgs() - args.PrivateKey = nil - mcm, err := broadcast.NewMetaChainMessenger(args) - - assert.Nil(t, mcm) - assert.Equal(t, spos.ErrNilPrivateKey, err) -} - func TestMetaChainMessenger_NewMetaChainMessengerNilShardCoordinatorShouldFail(t *testing.T) { args := createDefaultMetaChainArgs() args.ShardCoordinator = nil @@ -90,6 +85,15 @@ func TestMetaChainMessenger_NewMetaChainMessengerNilPeerSignatureHandlerShouldFa assert.Equal(t, spos.ErrNilPeerSignatureHandler, err) } +func TestMetaChainMessenger_NilKeysHandlerShouldError(t *testing.T) { + args := createDefaultMetaChainArgs() + args.KeysHandler = nil + mcm, err := broadcast.NewMetaChainMessenger(args) + + assert.Nil(t, mcm) + assert.Equal(t, broadcast.ErrNilKeysHandler, err) +} + func TestMetaChainMessenger_NewMetaChainMessengerShouldWork(t *testing.T) { args := createDefaultMetaChainArgs() mcm, err := broadcast.NewMetaChainMessenger(args) @@ -136,7 +140,7 @@ func TestMetaChainMessenger_BroadcastMiniBlocksShouldWork(t *testing.T) { args := createDefaultMetaChainArgs() mcm, _ := broadcast.NewMetaChainMessenger(args) - err := mcm.BroadcastMiniBlocks(nil) + err := mcm.BroadcastMiniBlocks(nil, []byte("pk bytes")) assert.Nil(t, err) } @@ -144,7 +148,7 @@ func TestMetaChainMessenger_BroadcastTransactionsShouldWork(t *testing.T) { args := createDefaultMetaChainArgs() mcm, _ := broadcast.NewMetaChainMessenger(args) - err := mcm.BroadcastTransactions(nil) + err := mcm.BroadcastTransactions(nil, []byte("pk bytes")) assert.Nil(t, err) } @@ -152,19 +156,28 @@ func TestMetaChainMessenger_BroadcastHeaderNilHeaderShouldErr(t *testing.T) { args := createDefaultMetaChainArgs() mcm, _ := broadcast.NewMetaChainMessenger(args) - err := mcm.BroadcastHeader(nil) + err := 
mcm.BroadcastHeader(nil, []byte("pk bytes")) assert.Equal(t, spos.ErrNilHeader, err) } func TestMetaChainMessenger_BroadcastHeaderOkHeaderShouldWork(t *testing.T) { - channelCalled := make(chan bool, 1) + channelBroadcastCalled := make(chan bool, 1) + channelBroadcastUsingPrivateKeyCalled := make(chan bool, 1) messenger := &p2pmocks.MessengerStub{ BroadcastCalled: func(topic string, buff []byte) { - channelCalled <- true + channelBroadcastCalled <- true + }, + BroadcastUsingPrivateKeyCalled: func(topic string, buff []byte, pid core.PeerID, skBytes []byte) { + channelBroadcastUsingPrivateKeyCalled <- true }, } args := createDefaultMetaChainArgs() + args.KeysHandler = &testscommon.KeysHandlerStub{ + IsOriginalPublicKeyOfTheNodeCalled: func(pkBytes []byte) bool { + return bytes.Equal(pkBytes, nodePkBytes) + }, + } args.Messenger = messenger mcm, _ := broadcast.NewMetaChainMessenger(args) @@ -172,18 +185,35 @@ func TestMetaChainMessenger_BroadcastHeaderOkHeaderShouldWork(t *testing.T) { Nonce: 10, } - err := mcm.BroadcastHeader(&hdr) - assert.Nil(t, err) + t.Run("original public key of the node", func(t *testing.T) { + err := mcm.BroadcastHeader(&hdr, nodePkBytes) + assert.Nil(t, err) + + wasCalled := false + select { + case <-channelBroadcastCalled: + wasCalled = true + case <-time.After(time.Millisecond * 100): + } + + assert.Nil(t, err) + assert.True(t, wasCalled) + }) + t.Run("managed key", func(t *testing.T) { + err := mcm.BroadcastHeader(&hdr, []byte("managed key")) + assert.Nil(t, err) + + wasCalled := false + select { + case <-channelBroadcastUsingPrivateKeyCalled: + wasCalled = true + case <-time.After(time.Millisecond * 100): + } + + assert.Nil(t, err) + assert.True(t, wasCalled) + }) - wasCalled := false - select { - case <-channelCalled: - wasCalled = true - case <-time.After(time.Millisecond * 100): - } - - assert.Nil(t, err) - assert.True(t, wasCalled) } func TestMetaChainMessenger_BroadcastBlockDataLeader(t *testing.T) { @@ -193,28 +223,72 @@ func TestMetaChainMessenger_BroadcastBlockDataLeader(t *testing.T) { messengerMock := &p2pmocks.MessengerStub{ BroadcastCalled: func(topic string, buff []byte) { mutCounters.Lock() - countersBroadcast[topic]++ + countersBroadcast[broadcastMethodPrefix+topic]++ + mutCounters.Unlock() + }, + BroadcastUsingPrivateKeyCalled: func(topic string, buff []byte, pid core.PeerID, skBytes []byte) { + mutCounters.Lock() + countersBroadcast[broadcastUsingPrivateKeyCalledMethodPrefix+topic]++ mutCounters.Unlock() }, } args := createDefaultMetaChainArgs() + args.KeysHandler = &testscommon.KeysHandlerStub{ + IsOriginalPublicKeyOfTheNodeCalled: func(pkBytes []byte) bool { + return bytes.Equal(pkBytes, nodePkBytes) + }, + } args.Messenger = messengerMock mcm, _ := broadcast.NewMetaChainMessenger(args) miniBlocks := map[uint32][]byte{0: []byte("mbs data1"), 1: []byte("mbs data2")} transactions := map[string][][]byte{"topic1": {[]byte("txdata1"), []byte("txdata2")}, "topic2": {[]byte("txdata3")}} - err := mcm.BroadcastBlockDataLeader(nil, miniBlocks, transactions) - require.Nil(t, err) - sleepTime := common.ExtraDelayBetweenBroadcastMbsAndTxs + - common.ExtraDelayForBroadcastBlockInfo + - time.Millisecond*100 - time.Sleep(sleepTime) - - mutCounters.Lock() - defer mutCounters.Unlock() - - assert.Equal(t, len(miniBlocks), countersBroadcast["txBlockBodies_0"]+countersBroadcast["txBlockBodies_0_1"]) - assert.Equal(t, len(transactions), countersBroadcast["topic1"]+countersBroadcast["topic2"]) + t.Run("original public key of the node", func(t *testing.T) { + 
mutCounters.Lock() + countersBroadcast = make(map[string]int) + mutCounters.Unlock() + + err := mcm.BroadcastBlockDataLeader(nil, miniBlocks, transactions, nodePkBytes) + require.Nil(t, err) + sleepTime := common.ExtraDelayBetweenBroadcastMbsAndTxs + + common.ExtraDelayForBroadcastBlockInfo + + time.Millisecond*100 + time.Sleep(sleepTime) + + mutCounters.Lock() + defer mutCounters.Unlock() + + numBroadcast := countersBroadcast[broadcastMethodPrefix+"txBlockBodies_0"] + numBroadcast += countersBroadcast[broadcastMethodPrefix+"txBlockBodies_0_1"] + assert.Equal(t, len(miniBlocks), numBroadcast) + + numBroadcast = countersBroadcast[broadcastMethodPrefix+"topic1"] + numBroadcast += countersBroadcast[broadcastMethodPrefix+"topic2"] + assert.Equal(t, len(transactions), numBroadcast) + }) + t.Run("managed key", func(t *testing.T) { + mutCounters.Lock() + countersBroadcast = make(map[string]int) + mutCounters.Unlock() + + err := mcm.BroadcastBlockDataLeader(nil, miniBlocks, transactions, []byte("pk bytes")) + require.Nil(t, err) + sleepTime := common.ExtraDelayBetweenBroadcastMbsAndTxs + + common.ExtraDelayForBroadcastBlockInfo + + time.Millisecond*100 + time.Sleep(sleepTime) + + mutCounters.Lock() + defer mutCounters.Unlock() + + numBroadcast := countersBroadcast[broadcastUsingPrivateKeyCalledMethodPrefix+"txBlockBodies_0"] + numBroadcast += countersBroadcast[broadcastUsingPrivateKeyCalledMethodPrefix+"txBlockBodies_0_1"] + assert.Equal(t, len(miniBlocks), numBroadcast) + + numBroadcast = countersBroadcast[broadcastUsingPrivateKeyCalledMethodPrefix+"topic1"] + numBroadcast += countersBroadcast[broadcastUsingPrivateKeyCalledMethodPrefix+"topic2"] + assert.Equal(t, len(transactions), numBroadcast) + }) } diff --git a/consensus/broadcast/shardChainMessenger.go b/consensus/broadcast/shardChainMessenger.go index d744ee01cf8..ac7485a8d1f 100644 --- a/consensus/broadcast/shardChainMessenger.go +++ b/consensus/broadcast/shardChainMessenger.go @@ -40,9 +40,9 @@ func NewShardChainMessenger( marshalizer: args.Marshalizer, hasher: args.Hasher, messenger: args.Messenger, - privateKey: args.PrivateKey, shardCoordinator: args.ShardCoordinator, peerSignatureHandler: args.PeerSignatureHandler, + keysHandler: args.KeysHandler, } dbbArgs := &ArgsDelayedBlockBroadcaster{ @@ -120,7 +120,7 @@ func (scm *shardChainMessenger) BroadcastBlock(blockBody data.BodyHandler, heade } // BroadcastHeader will send on in-shard headers topic the header -func (scm *shardChainMessenger) BroadcastHeader(header data.HeaderHandler) error { +func (scm *shardChainMessenger) BroadcastHeader(header data.HeaderHandler, pkBytes []byte) error { if check.IfNil(header) { return spos.ErrNilHeader } @@ -131,7 +131,7 @@ func (scm *shardChainMessenger) BroadcastHeader(header data.HeaderHandler) error } shardIdentifier := scm.shardCoordinator.CommunicationIdentifier(core.MetachainShardId) - scm.messenger.Broadcast(factory.ShardBlocksTopic+shardIdentifier, msgHeader) + scm.broadcast(factory.ShardBlocksTopic+shardIdentifier, msgHeader, pkBytes) return nil } @@ -141,6 +141,7 @@ func (scm *shardChainMessenger) BroadcastBlockDataLeader( header data.HeaderHandler, miniBlocks map[uint32][]byte, transactions map[string][][]byte, + pkBytes []byte, ) error { if check.IfNil(header) { return spos.ErrNilHeader @@ -160,6 +161,7 @@ func (scm *shardChainMessenger) BroadcastBlockDataLeader( headerHash: headerHash, miniBlocksData: miniBlocks, transactions: transactions, + pkBytes: pkBytes, } err = scm.delayedBlockBroadcaster.SetLeaderData(broadcastData) @@ -167,7 
+169,7 @@ func (scm *shardChainMessenger) BroadcastBlockDataLeader( return err } - go scm.BroadcastBlockData(metaMiniBlocks, metaTransactions, common.ExtraDelayForBroadcastBlockInfo) + go scm.BroadcastBlockData(metaMiniBlocks, metaTransactions, pkBytes, common.ExtraDelayForBroadcastBlockInfo) return nil } @@ -177,6 +179,7 @@ func (scm *shardChainMessenger) PrepareBroadcastHeaderValidator( _ map[uint32][]byte, _ map[string][][]byte, idx int, + pkBytes []byte, ) { if check.IfNil(header) { log.Error("shardChainMessenger.PrepareBroadcastHeaderValidator", "error", spos.ErrNilHeader) @@ -193,6 +196,7 @@ func (scm *shardChainMessenger) PrepareBroadcastHeaderValidator( headerHash: headerHash, header: header, order: uint32(idx), + pkBytes: pkBytes, } err = scm.delayedBlockBroadcaster.SetHeaderForValidator(vData) @@ -208,6 +212,7 @@ func (scm *shardChainMessenger) PrepareBroadcastBlockDataValidator( miniBlocks map[uint32][]byte, transactions map[string][][]byte, idx int, + pkBytes []byte, ) { if check.IfNil(header) { log.Error("shardChainMessenger.PrepareBroadcastBlockDataValidator", "error", spos.ErrNilHeader) @@ -229,6 +234,7 @@ func (scm *shardChainMessenger) PrepareBroadcastBlockDataValidator( miniBlocksData: miniBlocks, transactions: transactions, order: uint32(idx), + pkBytes: pkBytes, } err = scm.delayedBlockBroadcaster.SetValidatorData(broadcastData) diff --git a/consensus/broadcast/shardChainMessenger_test.go b/consensus/broadcast/shardChainMessenger_test.go index e6c0f8957b5..c81d2d98c28 100644 --- a/consensus/broadcast/shardChainMessenger_test.go +++ b/consensus/broadcast/shardChainMessenger_test.go @@ -1,9 +1,11 @@ package broadcast_test import ( + "bytes" "testing" "time" + "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/atomic" "github.com/multiversx/mx-chain-core-go/data/block" "github.com/multiversx/mx-chain-go/consensus/broadcast" @@ -54,7 +56,6 @@ func createDefaultShardChainArgs() broadcast.ShardChainMessengerArgs { marshalizerMock := &mock.MarshalizerMock{} hasher := &hashingMocks.HasherMock{} messengerMock := &p2pmocks.MessengerStub{} - privateKeyMock := &mock.PrivateKeyMock{} shardCoordinatorMock := &mock.ShardCoordinatorMock{} singleSignerMock := &mock.SingleSignerMock{} headersSubscriber := &mock.HeadersCacherStub{} @@ -69,7 +70,6 @@ func createDefaultShardChainArgs() broadcast.ShardChainMessengerArgs { Marshalizer: marshalizerMock, Hasher: hasher, Messenger: messengerMock, - PrivateKey: privateKeyMock, ShardCoordinator: shardCoordinatorMock, PeerSignatureHandler: peerSigHandler, HeadersSubscriber: headersSubscriber, @@ -77,6 +77,7 @@ func createDefaultShardChainArgs() broadcast.ShardChainMessengerArgs { MaxDelayCacheSize: 1, MaxValidatorDelayCacheSize: 1, AlarmScheduler: alarmScheduler, + KeysHandler: &testscommon.KeysHandlerStub{}, }, } } @@ -99,15 +100,6 @@ func TestShardChainMessenger_NewShardChainMessengerNilMessengerShouldFail(t *tes assert.Equal(t, spos.ErrNilMessenger, err) } -func TestShardChainMessenger_NewShardChainMessengerNilPrivateKeyShouldFail(t *testing.T) { - args := createDefaultShardChainArgs() - args.PrivateKey = nil - scm, err := broadcast.NewShardChainMessenger(args) - - assert.Nil(t, scm) - assert.Equal(t, spos.ErrNilPrivateKey, err) -} - func TestShardChainMessenger_NewShardChainMessengerNilShardCoordinatorShouldFail(t *testing.T) { args := createDefaultShardChainArgs() args.ShardCoordinator = nil @@ -144,6 +136,15 @@ func TestShardChainMessenger_NewShardChainMessengerNilHeadersSubscriberShouldFai 
assert.Equal(t, spos.ErrNilHeadersSubscriber, err) } +func TestShardChainMessenger_NilKeysHandlerShouldError(t *testing.T) { + args := createDefaultShardChainArgs() + args.KeysHandler = nil + scm, err := broadcast.NewShardChainMessenger(args) + + assert.Nil(t, scm) + assert.Equal(t, broadcast.ErrNilKeysHandler, err) +} + func TestShardChainMessenger_NewShardChainMessengerShouldWork(t *testing.T) { args := createDefaultShardChainArgs() scm, err := broadcast.NewShardChainMessenger(args) @@ -195,15 +196,24 @@ func TestShardChainMessenger_BroadcastBlockShouldWork(t *testing.T) { } func TestShardChainMessenger_BroadcastMiniBlocksShouldBeDone(t *testing.T) { - channelCalled := make(chan bool, 100) + channelBroadcastCalled := make(chan bool, 100) + channelBroadcastUsingPrivateKeyCalled := make(chan bool, 100) messenger := &p2pmocks.MessengerStub{ BroadcastCalled: func(topic string, buff []byte) { - channelCalled <- true + channelBroadcastCalled <- true + }, + BroadcastUsingPrivateKeyCalled: func(topic string, buff []byte, pid core.PeerID, skBytes []byte) { + channelBroadcastUsingPrivateKeyCalled <- true }, } args := createDefaultShardChainArgs() args.Messenger = messenger + args.KeysHandler = &testscommon.KeysHandlerStub{ + IsOriginalPublicKeyOfTheNodeCalled: func(pkBytes []byte) bool { + return bytes.Equal(nodePkBytes, pkBytes) + }, + } scm, _ := broadcast.NewShardChainMessenger(args) miniBlocks := make(map[uint32][]byte) @@ -211,20 +221,39 @@ func TestShardChainMessenger_BroadcastMiniBlocksShouldBeDone(t *testing.T) { miniBlocks[1] = make([]byte, 0) miniBlocks[2] = make([]byte, 0) miniBlocks[3] = make([]byte, 0) - err := scm.BroadcastMiniBlocks(miniBlocks) - - called := 0 - for i := 0; i < 4; i++ { - select { - case <-channelCalled: - called++ - case <-time.After(time.Millisecond * 100): - break + + t.Run("original public key of the node", func(t *testing.T) { + err := scm.BroadcastMiniBlocks(miniBlocks, nodePkBytes) + + called := 0 + for i := 0; i < 4; i++ { + select { + case <-channelBroadcastCalled: + called++ + case <-time.After(time.Millisecond * 100): + break + } } - } - assert.Nil(t, err) - assert.Equal(t, 4, called) + assert.Nil(t, err) + assert.Equal(t, 4, called) + }) + t.Run("managed key", func(t *testing.T) { + err := scm.BroadcastMiniBlocks(miniBlocks, []byte("managed key")) + + called := 0 + for i := 0; i < 4; i++ { + select { + case <-channelBroadcastUsingPrivateKeyCalled: + called++ + case <-time.After(time.Millisecond * 100): + break + } + } + + assert.Nil(t, err) + assert.Equal(t, 4, called) + }) } func TestShardChainMessenger_BroadcastTransactionsShouldNotBeCalled(t *testing.T) { @@ -240,7 +269,7 @@ func TestShardChainMessenger_BroadcastTransactionsShouldNotBeCalled(t *testing.T scm, _ := broadcast.NewShardChainMessenger(args) transactions := make(map[string][][]byte) - err := scm.BroadcastTransactions(transactions) + err := scm.BroadcastTransactions(transactions, []byte("pk bytes")) wasCalled := false select { @@ -253,7 +282,7 @@ func TestShardChainMessenger_BroadcastTransactionsShouldNotBeCalled(t *testing.T assert.False(t, wasCalled) transactions[factory.TransactionTopic] = make([][]byte, 0) - err = scm.BroadcastTransactions(transactions) + err = scm.BroadcastTransactions(transactions, []byte("pk bytes")) wasCalled = false select { @@ -267,67 +296,127 @@ func TestShardChainMessenger_BroadcastTransactionsShouldNotBeCalled(t *testing.T } func TestShardChainMessenger_BroadcastTransactionsShouldBeCalled(t *testing.T) { - channelCalled := make(chan bool, 1) + 
channelBroadcastCalled := make(chan bool, 1) + channelBroadcastUsingPrivateKeyCalled := make(chan bool, 1) messenger := &p2pmocks.MessengerStub{ BroadcastCalled: func(topic string, buff []byte) { - channelCalled <- true + channelBroadcastCalled <- true + }, + BroadcastUsingPrivateKeyCalled: func(topic string, buff []byte, pid core.PeerID, skBytes []byte) { + channelBroadcastUsingPrivateKeyCalled <- true }, } args := createDefaultShardChainArgs() args.Messenger = messenger + args.KeysHandler = &testscommon.KeysHandlerStub{ + IsOriginalPublicKeyOfTheNodeCalled: func(pkBytes []byte) bool { + return bytes.Equal(pkBytes, nodePkBytes) + }, + } scm, _ := broadcast.NewShardChainMessenger(args) transactions := make(map[string][][]byte) txs := make([][]byte, 0) txs = append(txs, []byte("")) transactions[factory.TransactionTopic] = txs - err := scm.BroadcastTransactions(transactions) + t.Run("original public key of the node", func(t *testing.T) { + err := scm.BroadcastTransactions(transactions, nodePkBytes) + + wasCalled := false + for i := 0; i < 4; i++ { + select { + case <-channelBroadcastCalled: + wasCalled = true + case <-time.After(time.Millisecond * 100): + break + } + } - wasCalled := false - select { - case <-channelCalled: - wasCalled = true - case <-time.After(time.Millisecond * 100): - } + assert.Nil(t, err) + assert.True(t, wasCalled) + }) + t.Run("managed key", func(t *testing.T) { + err := scm.BroadcastTransactions(transactions, []byte("managed key")) + + wasCalled := false + for i := 0; i < 4; i++ { + select { + case <-channelBroadcastUsingPrivateKeyCalled: + wasCalled = true + case <-time.After(time.Millisecond * 100): + break + } + } - assert.Nil(t, err) - assert.True(t, wasCalled) + assert.Nil(t, err) + assert.True(t, wasCalled) + }) } func TestShardChainMessenger_BroadcastHeaderNilHeaderShouldErr(t *testing.T) { args := createDefaultShardChainArgs() scm, _ := broadcast.NewShardChainMessenger(args) - err := scm.BroadcastHeader(nil) + err := scm.BroadcastHeader(nil, []byte("pk bytes")) assert.Equal(t, spos.ErrNilHeader, err) } func TestShardChainMessenger_BroadcastHeaderShouldWork(t *testing.T) { - channelCalled := make(chan bool, 1) + channelBroadcastCalled := make(chan bool, 1) + channelBroadcastUsingPrivateKeyCalled := make(chan bool, 1) messenger := &p2pmocks.MessengerStub{ BroadcastCalled: func(topic string, buff []byte) { - channelCalled <- true + channelBroadcastCalled <- true + }, + BroadcastUsingPrivateKeyCalled: func(topic string, buff []byte, pid core.PeerID, skBytes []byte) { + channelBroadcastUsingPrivateKeyCalled <- true }, } args := createDefaultShardChainArgs() + args.KeysHandler = &testscommon.KeysHandlerStub{ + IsOriginalPublicKeyOfTheNodeCalled: func(pkBytes []byte) bool { + return bytes.Equal(pkBytes, nodePkBytes) + }, + } args.Messenger = messenger scm, _ := broadcast.NewShardChainMessenger(args) - hdr := block.MetaBlock{Nonce: 10} - err := scm.BroadcastHeader(&hdr) + hdr := &block.MetaBlock{Nonce: 10} + t.Run("original public key of the node", func(t *testing.T) { + err := scm.BroadcastHeader(hdr, nodePkBytes) + + wasCalled := false + for i := 0; i < 4; i++ { + select { + case <-channelBroadcastCalled: + wasCalled = true + case <-time.After(time.Millisecond * 100): + break + } + } - wasCalled := false - select { - case <-channelCalled: - wasCalled = true - case <-time.After(time.Millisecond * 100): - } + assert.Nil(t, err) + assert.True(t, wasCalled) + }) + t.Run("managed key", func(t *testing.T) { + err := scm.BroadcastHeader(hdr, []byte("managed key")) + + 
wasCalled := false + for i := 0; i < 4; i++ { + select { + case <-channelBroadcastUsingPrivateKeyCalled: + wasCalled = true + case <-time.After(time.Millisecond * 100): + break + } + } - assert.Nil(t, err) - assert.True(t, wasCalled) + assert.Nil(t, err) + assert.True(t, wasCalled) + }) } func TestShardChainMessenger_BroadcastBlockDataLeaderNilHeaderShouldErr(t *testing.T) { @@ -336,7 +425,7 @@ func TestShardChainMessenger_BroadcastBlockDataLeaderNilHeaderShouldErr(t *testi _, _, miniblocks, transactions := createDelayData("1") - err := scm.BroadcastBlockDataLeader(nil, miniblocks, transactions) + err := scm.BroadcastBlockDataLeader(nil, miniblocks, transactions, []byte("pk bytes")) assert.Equal(t, spos.ErrNilHeader, err) } @@ -346,31 +435,56 @@ func TestShardChainMessenger_BroadcastBlockDataLeaderNilMiniblocksShouldReturnNi _, header, _, transactions := createDelayData("1") - err := scm.BroadcastBlockDataLeader(header, nil, transactions) + err := scm.BroadcastBlockDataLeader(header, nil, transactions, []byte("pk bytes")) assert.Nil(t, err) } func TestShardChainMessenger_BroadcastBlockDataLeaderShouldTriggerWaitingDelayedMessage(t *testing.T) { - wasCalled := atomic.Flag{} + broadcastWasCalled := atomic.Flag{} + broadcastUsingPrivateKeyWasCalled := atomic.Flag{} messenger := &p2pmocks.MessengerStub{ BroadcastCalled: func(topic string, buff []byte) { - _ = wasCalled.SetReturningPrevious() + broadcastWasCalled.SetValue(true) + }, + BroadcastUsingPrivateKeyCalled: func(topic string, buff []byte, pid core.PeerID, skBytes []byte) { + broadcastUsingPrivateKeyWasCalled.SetValue(true) }, } args := createDefaultShardChainArgs() args.Messenger = messenger + args.KeysHandler = &testscommon.KeysHandlerStub{ + IsOriginalPublicKeyOfTheNodeCalled: func(pkBytes []byte) bool { + return bytes.Equal(pkBytes, nodePkBytes) + }, + } scm, _ := broadcast.NewShardChainMessenger(args) - _, header, miniBlocksMarshalled, transactions := createDelayData("1") - err := scm.BroadcastBlockDataLeader(header, miniBlocksMarshalled, transactions) - time.Sleep(10 * time.Millisecond) - assert.Nil(t, err) - assert.False(t, wasCalled.IsSet()) - - wasCalled.Reset() - _, header2, miniBlocksMarshalled2, transactions2 := createDelayData("2") - err = scm.BroadcastBlockDataLeader(header2, miniBlocksMarshalled2, transactions2) - time.Sleep(10 * time.Millisecond) - assert.Nil(t, err) - assert.True(t, wasCalled.IsSet()) + t.Run("original public key of the node", func(t *testing.T) { + _, header, miniBlocksMarshalled, transactions := createDelayData("1") + err := scm.BroadcastBlockDataLeader(header, miniBlocksMarshalled, transactions, nodePkBytes) + time.Sleep(10 * time.Millisecond) + assert.Nil(t, err) + assert.False(t, broadcastWasCalled.IsSet()) + + broadcastWasCalled.Reset() + _, header2, miniBlocksMarshalled2, transactions2 := createDelayData("2") + err = scm.BroadcastBlockDataLeader(header2, miniBlocksMarshalled2, transactions2, nodePkBytes) + time.Sleep(10 * time.Millisecond) + assert.Nil(t, err) + assert.True(t, broadcastWasCalled.IsSet()) + }) + t.Run("managed key", func(t *testing.T) { + _, header, miniBlocksMarshalled, transactions := createDelayData("1") + err := scm.BroadcastBlockDataLeader(header, miniBlocksMarshalled, transactions, []byte("managed key")) + time.Sleep(10 * time.Millisecond) + assert.Nil(t, err) + assert.False(t, broadcastUsingPrivateKeyWasCalled.IsSet()) + + broadcastWasCalled.Reset() + _, header2, miniBlocksMarshalled2, transactions2 := createDelayData("2") + err = scm.BroadcastBlockDataLeader(header2, 
miniBlocksMarshalled2, transactions2, []byte("managed key")) + time.Sleep(10 * time.Millisecond) + assert.Nil(t, err) + assert.True(t, broadcastUsingPrivateKeyWasCalled.IsSet()) + }) } diff --git a/consensus/interface.go b/consensus/interface.go index 9a2891a35e4..00b173a2eb3 100644 --- a/consensus/interface.go +++ b/consensus/interface.go @@ -60,19 +60,20 @@ type ChronologyHandler interface { // BroadcastMessenger defines the behaviour of the broadcast messages by the consensus group type BroadcastMessenger interface { BroadcastBlock(data.BodyHandler, data.HeaderHandler) error - BroadcastHeader(data.HeaderHandler) error - BroadcastMiniBlocks(map[uint32][]byte) error - BroadcastTransactions(map[string][][]byte) error + BroadcastHeader(data.HeaderHandler, []byte) error + BroadcastMiniBlocks(map[uint32][]byte, []byte) error + BroadcastTransactions(map[string][][]byte, []byte) error BroadcastConsensusMessage(*Message) error - BroadcastBlockDataLeader(header data.HeaderHandler, miniBlocks map[uint32][]byte, transactions map[string][][]byte) error - PrepareBroadcastHeaderValidator(header data.HeaderHandler, miniBlocks map[uint32][]byte, transactions map[string][][]byte, order int) - PrepareBroadcastBlockDataValidator(header data.HeaderHandler, miniBlocks map[uint32][]byte, transactions map[string][][]byte, idx int) + BroadcastBlockDataLeader(header data.HeaderHandler, miniBlocks map[uint32][]byte, transactions map[string][][]byte, pkBytes []byte) error + PrepareBroadcastHeaderValidator(header data.HeaderHandler, miniBlocks map[uint32][]byte, transactions map[string][][]byte, idx int, pkBytes []byte) + PrepareBroadcastBlockDataValidator(header data.HeaderHandler, miniBlocks map[uint32][]byte, transactions map[string][][]byte, idx int, pkBytes []byte) IsInterfaceNil() bool } // P2PMessenger defines a subset of the p2p.Messenger interface type P2PMessenger interface { Broadcast(topic string, buff []byte) + BroadcastUsingPrivateKey(topic string, buff []byte, pid core.PeerID, skBytes []byte) IsInterfaceNil() bool } @@ -164,10 +165,12 @@ type PeerBlacklistHandler interface { IsInterfaceNil() bool } -// SignatureHandler defines the behaviour of a component that handles signatures in consensus -type SignatureHandler interface { +// SigningHandler defines the behaviour of a component that handles multi and single signatures used in consensus operations +type SigningHandler interface { Reset(pubKeys []string) error - CreateSignatureShare(msg []byte, index uint16, epoch uint32) ([]byte, error) + CreateSignatureShareForPublicKey(message []byte, index uint16, epoch uint32, publicKeyBytes []byte) ([]byte, error) + CreateSignatureForPublicKey(message []byte, publicKeyBytes []byte) ([]byte, error) + VerifySingleSignature(publicKeyBytes []byte, message []byte, signature []byte) error StoreSignatureShare(index uint16, sig []byte) error SignatureShare(index uint16) ([]byte, error) VerifySignatureShare(index uint16, sig []byte, msg []byte, epoch uint32) error @@ -176,3 +179,16 @@ type SignatureHandler interface { Verify(msg []byte, bitmap []byte, epoch uint32) error IsInterfaceNil() bool } + +// KeysHandler defines the operations implemented by a component that will manage all keys, +// including the single signer keys or the set of multi-keys +type KeysHandler interface { + GetHandledPrivateKey(pkBytes []byte) crypto.PrivateKey + GetP2PIdentity(pkBytes []byte) ([]byte, core.PeerID, error) + IsKeyManagedByCurrentNode(pkBytes []byte) bool + IncrementRoundsWithoutReceivedMessages(pkBytes []byte) + 
GetAssociatedPid(pkBytes []byte) core.PeerID + IsOriginalPublicKeyOfTheNode(pkBytes []byte) bool + UpdatePublicKeyLiveness(pkBytes []byte, pid core.PeerID) + IsInterfaceNil() bool +} diff --git a/consensus/mock/broadcastMessangerMock.go b/consensus/mock/broadcastMessangerMock.go index 1584c7f405a..2d659490725 100644 --- a/consensus/mock/broadcastMessangerMock.go +++ b/consensus/mock/broadcastMessangerMock.go @@ -8,13 +8,13 @@ import ( // BroadcastMessengerMock - type BroadcastMessengerMock struct { BroadcastBlockCalled func(data.BodyHandler, data.HeaderHandler) error - BroadcastHeaderCalled func(data.HeaderHandler) error - PrepareBroadcastBlockDataValidatorCalled func(h data.HeaderHandler, mbs map[uint32][]byte, txs map[string][][]byte, idx int) error - PrepareBroadcastHeaderValidatorCalled func(h data.HeaderHandler, mbs map[uint32][]byte, txs map[string][][]byte, idx int) - BroadcastMiniBlocksCalled func(map[uint32][]byte) error - BroadcastTransactionsCalled func(map[string][][]byte) error + BroadcastHeaderCalled func(data.HeaderHandler, []byte) error + PrepareBroadcastBlockDataValidatorCalled func(h data.HeaderHandler, mbs map[uint32][]byte, txs map[string][][]byte, idx int, pkBytes []byte) error + PrepareBroadcastHeaderValidatorCalled func(h data.HeaderHandler, mbs map[uint32][]byte, txs map[string][][]byte, idx int, pkBytes []byte) + BroadcastMiniBlocksCalled func(map[uint32][]byte, []byte) error + BroadcastTransactionsCalled func(map[string][][]byte, []byte) error BroadcastConsensusMessageCalled func(*consensus.Message) error - BroadcastBlockDataLeaderCalled func(h data.HeaderHandler, mbs map[uint32][]byte, txs map[string][][]byte) error + BroadcastBlockDataLeaderCalled func(h data.HeaderHandler, mbs map[uint32][]byte, txs map[string][][]byte, pkBytes []byte) error } // BroadcastBlock - @@ -26,9 +26,9 @@ func (bmm *BroadcastMessengerMock) BroadcastBlock(bodyHandler data.BodyHandler, } // BroadcastMiniBlocks - -func (bmm *BroadcastMessengerMock) BroadcastMiniBlocks(miniBlocks map[uint32][]byte) error { +func (bmm *BroadcastMessengerMock) BroadcastMiniBlocks(miniBlocks map[uint32][]byte, pkBytes []byte) error { if bmm.BroadcastMiniBlocksCalled != nil { - return bmm.BroadcastMiniBlocksCalled(miniBlocks) + return bmm.BroadcastMiniBlocksCalled(miniBlocks, pkBytes) } return nil } @@ -38,17 +38,18 @@ func (bmm *BroadcastMessengerMock) BroadcastBlockDataLeader( header data.HeaderHandler, miniBlocks map[uint32][]byte, transactions map[string][][]byte, + pkBytes []byte, ) error { if bmm.BroadcastBlockDataLeaderCalled != nil { - return bmm.BroadcastBlockDataLeaderCalled(header, miniBlocks, transactions) + return bmm.BroadcastBlockDataLeaderCalled(header, miniBlocks, transactions, pkBytes) } - err := bmm.BroadcastMiniBlocks(miniBlocks) + err := bmm.BroadcastMiniBlocks(miniBlocks, pkBytes) if err != nil { return err } - return bmm.BroadcastTransactions(transactions) + return bmm.BroadcastTransactions(transactions, pkBytes) } // PrepareBroadcastBlockDataValidator - @@ -57,6 +58,7 @@ func (bmm *BroadcastMessengerMock) PrepareBroadcastBlockDataValidator( miniBlocks map[uint32][]byte, transactions map[string][][]byte, idx int, + pkBytes []byte, ) { if bmm.PrepareBroadcastBlockDataValidatorCalled != nil { _ = bmm.PrepareBroadcastBlockDataValidatorCalled( @@ -64,6 +66,7 @@ func (bmm *BroadcastMessengerMock) PrepareBroadcastBlockDataValidator( miniBlocks, transactions, idx, + pkBytes, ) } } @@ -74,6 +77,7 @@ func (bmm *BroadcastMessengerMock) PrepareBroadcastHeaderValidator( miniBlocks map[uint32][]byte, 
transactions map[string][][]byte, order int, + pkBytes []byte, ) { if bmm.PrepareBroadcastHeaderValidatorCalled != nil { bmm.PrepareBroadcastHeaderValidatorCalled( @@ -81,14 +85,15 @@ func (bmm *BroadcastMessengerMock) PrepareBroadcastHeaderValidator( miniBlocks, transactions, order, + pkBytes, ) } } // BroadcastTransactions - -func (bmm *BroadcastMessengerMock) BroadcastTransactions(transactions map[string][][]byte) error { +func (bmm *BroadcastMessengerMock) BroadcastTransactions(transactions map[string][][]byte, pkBytes []byte) error { if bmm.BroadcastTransactionsCalled != nil { - return bmm.BroadcastTransactionsCalled(transactions) + return bmm.BroadcastTransactionsCalled(transactions, pkBytes) } return nil } @@ -102,9 +107,9 @@ func (bmm *BroadcastMessengerMock) BroadcastConsensusMessage(message *consensus. } // BroadcastHeader - -func (bmm *BroadcastMessengerMock) BroadcastHeader(headerhandler data.HeaderHandler) error { +func (bmm *BroadcastMessengerMock) BroadcastHeader(headerhandler data.HeaderHandler, pkBytes []byte) error { if bmm.BroadcastHeaderCalled != nil { - return bmm.BroadcastHeaderCalled(headerhandler) + return bmm.BroadcastHeaderCalled(headerhandler, pkBytes) } return nil } diff --git a/consensus/mock/consensusDataContainerMock.go b/consensus/mock/consensusDataContainerMock.go index dedf3228660..88f837b1da1 100644 --- a/consensus/mock/consensusDataContainerMock.go +++ b/consensus/mock/consensusDataContainerMock.go @@ -4,7 +4,6 @@ import ( "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" - crypto "github.com/multiversx/mx-chain-crypto-go" cryptoCommon "github.com/multiversx/mx-chain-go/common/crypto" "github.com/multiversx/mx-chain-go/consensus" "github.com/multiversx/mx-chain-go/epochStart" @@ -24,9 +23,6 @@ type ConsensusCoreMock struct { chronologyHandler consensus.ChronologyHandler hasher hashing.Hasher marshalizer marshal.Marshalizer - blsPrivateKey crypto.PrivateKey - blsSingleSigner crypto.SingleSigner - keyGenerator crypto.KeyGenerator multiSignerContainer cryptoCommon.MultiSignerContainer roundHandler consensus.RoundHandler shardCoordinator sharding.Coordinator @@ -41,7 +37,7 @@ type ConsensusCoreMock struct { scheduledProcessor consensus.ScheduledProcessor messageSigningHandler consensus.P2PSigningHandler peerBlacklistHandler consensus.PeerBlacklistHandler - signatureHandler consensus.SignatureHandler + signingHandler consensus.SigningHandler } // GetAntiFloodHandler - @@ -124,11 +120,6 @@ func (ccm *ConsensusCoreMock) SetBlockchain(blockChain data.ChainHandler) { ccm.blockChain = blockChain } -// SetSingleSigner - -func (ccm *ConsensusCoreMock) SetSingleSigner(signer crypto.SingleSigner) { - ccm.blsSingleSigner = signer -} - // SetBlockProcessor - func (ccm *ConsensusCoreMock) SetBlockProcessor(blockProcessor process.BlockProcessor) { ccm.blockProcessor = blockProcessor @@ -184,21 +175,6 @@ func (ccm *ConsensusCoreMock) SetValidatorGroupSelector(validatorGroupSelector n ccm.validatorGroupSelector = validatorGroupSelector } -// PrivateKey - -func (ccm *ConsensusCoreMock) PrivateKey() crypto.PrivateKey { - return ccm.blsPrivateKey -} - -// SingleSigner returns the bls single signer stored in the ConsensusCore -func (ccm *ConsensusCoreMock) SingleSigner() crypto.SingleSigner { - return ccm.blsSingleSigner -} - -// KeyGenerator - -func (ccm *ConsensusCoreMock) KeyGenerator() crypto.KeyGenerator { - return ccm.keyGenerator -} - // PeerHonestyHandler - func (ccm 
*ConsensusCoreMock) PeerHonestyHandler() consensus.PeerHonestyHandler { return ccm.peerHonestyHandler @@ -249,24 +225,19 @@ func (ccm *ConsensusCoreMock) SetMessageSigningHandler(messageSigningHandler con ccm.messageSigningHandler = messageSigningHandler } -// SetKeyGenerator - -func (ccm *ConsensusCoreMock) SetKeyGenerator(keyGenerator crypto.KeyGenerator) { - ccm.keyGenerator = keyGenerator -} - // PeerBlacklistHandler will return the peer blacklist handler func (ccm *ConsensusCoreMock) PeerBlacklistHandler() consensus.PeerBlacklistHandler { return ccm.peerBlacklistHandler } -// SignatureHandler - -func (ccm *ConsensusCoreMock) SignatureHandler() consensus.SignatureHandler { - return ccm.signatureHandler +// SigningHandler - +func (ccm *ConsensusCoreMock) SigningHandler() consensus.SigningHandler { + return ccm.signingHandler } -// SetSignatureHandler - -func (ccm *ConsensusCoreMock) SetSignatureHandler(signatureHandler consensus.SignatureHandler) { - ccm.signatureHandler = signatureHandler +// SetSigningHandler - +func (ccm *ConsensusCoreMock) SetSigningHandler(signingHandler consensus.SigningHandler) { + ccm.signingHandler = signingHandler } // IsInterfaceNil returns true if there is no value under the interface diff --git a/consensus/mock/mockTestInitializer.go b/consensus/mock/mockTestInitializer.go index 5802744aae0..4468c3d338b 100644 --- a/consensus/mock/mockTestInitializer.go +++ b/consensus/mock/mockTestInitializer.go @@ -175,14 +175,7 @@ func InitConsensusCoreWithMultiSigner(multiSigner crypto.MultiSigner) *Consensus chronologyHandlerMock := InitChronologyHandlerMock() hasherMock := &hashingMocks.HasherMock{} - blsPrivateKeyMock := &PrivateKeyMock{} - blsSingleSignerMock := &SingleSignerMock{ - SignStub: func(private crypto.PrivateKey, msg []byte) (bytes []byte, e error) { - return make([]byte, 0), nil - }, - } roundHandlerMock := &RoundHandlerMock{} - keyGen := &KeyGenMock{} shardCoordinatorMock := ShardCoordinatorMock{} syncTimerMock := &SyncTimerMock{} validatorGroupSelector := &shardingMocks.NodesCoordinatorMock{ @@ -212,7 +205,7 @@ func InitConsensusCoreWithMultiSigner(multiSigner crypto.MultiSigner) *Consensus messageSigningHandler := &MessageSigningHandlerStub{} peerBlacklistHandler := &PeerBlacklistHandlerStub{} multiSignerContainer := cryptoMocks.NewMultiSignerContainerMock(multiSigner) - signatureHandler := &SignatureHandlerStub{} + signingHandler := &SigningHandlerStub{} container := &ConsensusCoreMock{ blockChain: blockChain, @@ -223,9 +216,6 @@ func InitConsensusCoreWithMultiSigner(multiSigner crypto.MultiSigner) *Consensus chronologyHandler: chronologyHandlerMock, hasher: hasherMock, marshalizer: marshalizerMock, - blsPrivateKey: blsPrivateKeyMock, - blsSingleSigner: blsSingleSignerMock, - keyGenerator: keyGen, multiSignerContainer: multiSignerContainer, roundHandler: roundHandlerMock, shardCoordinator: shardCoordinatorMock, @@ -240,7 +230,7 @@ func InitConsensusCoreWithMultiSigner(multiSigner crypto.MultiSigner) *Consensus scheduledProcessor: scheduledProcessor, messageSigningHandler: messageSigningHandler, peerBlacklistHandler: peerBlacklistHandler, - signatureHandler: signatureHandler, + signingHandler: signingHandler, } return container diff --git a/consensus/mock/signatureHandlerStub.go b/consensus/mock/signatureHandlerStub.go deleted file mode 100644 index 802b4ae908f..00000000000 --- a/consensus/mock/signatureHandlerStub.go +++ /dev/null @@ -1,90 +0,0 @@ -package mock - -// SignatureHandlerStub implements SignatureHandler interface -type 
SignatureHandlerStub struct { - ResetCalled func(pubKeys []string) error - CreateSignatureShareCalled func(msg []byte, index uint16, epoch uint32) ([]byte, error) - StoreSignatureShareCalled func(index uint16, sig []byte) error - SignatureShareCalled func(index uint16) ([]byte, error) - VerifySignatureShareCalled func(index uint16, sig []byte, msg []byte, epoch uint32) error - AggregateSigsCalled func(bitmap []byte, epoch uint32) ([]byte, error) - SetAggregatedSigCalled func(_ []byte) error - VerifyCalled func(msg []byte, bitmap []byte, epoch uint32) error -} - -// Reset - -func (stub *SignatureHandlerStub) Reset(pubKeys []string) error { - if stub.ResetCalled != nil { - return stub.ResetCalled(pubKeys) - } - - return nil -} - -// CreateSignatureShare - -func (stub *SignatureHandlerStub) CreateSignatureShare(msg []byte, index uint16, epoch uint32) ([]byte, error) { - if stub.CreateSignatureShareCalled != nil { - return stub.CreateSignatureShareCalled(msg, index, epoch) - } - - return []byte("sigShare"), nil -} - -// StoreSignatureShare - -func (stub *SignatureHandlerStub) StoreSignatureShare(index uint16, sig []byte) error { - if stub.StoreSignatureShareCalled != nil { - return stub.StoreSignatureShareCalled(index, sig) - } - - return nil -} - -// SignatureShare - -func (stub *SignatureHandlerStub) SignatureShare(index uint16) ([]byte, error) { - if stub.SignatureShareCalled != nil { - return stub.SignatureShareCalled(index) - } - - return []byte("sigShare"), nil -} - -// VerifySignatureShare - -func (stub *SignatureHandlerStub) VerifySignatureShare(index uint16, sig []byte, msg []byte, epoch uint32) error { - if stub.VerifySignatureShareCalled != nil { - return stub.VerifySignatureShareCalled(index, sig, msg, epoch) - } - - return nil -} - -// AggregateSigs - -func (stub *SignatureHandlerStub) AggregateSigs(bitmap []byte, epoch uint32) ([]byte, error) { - if stub.AggregateSigsCalled != nil { - return stub.AggregateSigsCalled(bitmap, epoch) - } - - return []byte("aggSigs"), nil -} - -// SetAggregatedSig - -func (stub *SignatureHandlerStub) SetAggregatedSig(sig []byte) error { - if stub.SetAggregatedSigCalled != nil { - return stub.SetAggregatedSigCalled(sig) - } - - return nil -} - -// Verify - -func (stub *SignatureHandlerStub) Verify(msg []byte, bitmap []byte, epoch uint32) error { - if stub.VerifyCalled != nil { - return stub.VerifyCalled(msg, bitmap, epoch) - } - - return nil -} - -// IsInterfaceNil - -func (stub *SignatureHandlerStub) IsInterfaceNil() bool { - return stub == nil -} diff --git a/consensus/mock/signingHandlerStub.go b/consensus/mock/signingHandlerStub.go new file mode 100644 index 00000000000..33c4121d74c --- /dev/null +++ b/consensus/mock/signingHandlerStub.go @@ -0,0 +1,110 @@ +package mock + +// SigningHandlerStub implements SigningHandler interface +type SigningHandlerStub struct { + ResetCalled func(pubKeys []string) error + CreateSignatureShareForPublicKeyCalled func(message []byte, index uint16, epoch uint32, publicKeyBytes []byte) ([]byte, error) + CreateSignatureForPublicKeyCalled func(message []byte, publicKeyBytes []byte) ([]byte, error) + VerifySingleSignatureCalled func(publicKeyBytes []byte, message []byte, signature []byte) error + StoreSignatureShareCalled func(index uint16, sig []byte) error + SignatureShareCalled func(index uint16) ([]byte, error) + VerifySignatureShareCalled func(index uint16, sig []byte, msg []byte, epoch uint32) error + AggregateSigsCalled func(bitmap []byte, epoch uint32) ([]byte, error) + SetAggregatedSigCalled func(_ []byte) 
error + VerifyCalled func(msg []byte, bitmap []byte, epoch uint32) error +} + +// Reset - +func (stub *SigningHandlerStub) Reset(pubKeys []string) error { + if stub.ResetCalled != nil { + return stub.ResetCalled(pubKeys) + } + + return nil +} + +// CreateSignatureShareForPublicKey - +func (stub *SigningHandlerStub) CreateSignatureShareForPublicKey(message []byte, index uint16, epoch uint32, publicKeyBytes []byte) ([]byte, error) { + if stub.CreateSignatureShareForPublicKeyCalled != nil { + return stub.CreateSignatureShareForPublicKeyCalled(message, index, epoch, publicKeyBytes) + } + + return make([]byte, 0), nil +} + +// CreateSignatureForPublicKey - +func (stub *SigningHandlerStub) CreateSignatureForPublicKey(message []byte, publicKeyBytes []byte) ([]byte, error) { + if stub.CreateSignatureForPublicKeyCalled != nil { + return stub.CreateSignatureForPublicKeyCalled(message, publicKeyBytes) + } + + return make([]byte, 0), nil +} + +// VerifySingleSignature - +func (stub *SigningHandlerStub) VerifySingleSignature(publicKeyBytes []byte, message []byte, signature []byte) error { + if stub.VerifySingleSignatureCalled != nil { + return stub.VerifySingleSignatureCalled(publicKeyBytes, message, signature) + } + + return nil +} + +// StoreSignatureShare - +func (stub *SigningHandlerStub) StoreSignatureShare(index uint16, sig []byte) error { + if stub.StoreSignatureShareCalled != nil { + return stub.StoreSignatureShareCalled(index, sig) + } + + return nil +} + +// SignatureShare - +func (stub *SigningHandlerStub) SignatureShare(index uint16) ([]byte, error) { + if stub.SignatureShareCalled != nil { + return stub.SignatureShareCalled(index) + } + + return []byte("sigShare"), nil +} + +// VerifySignatureShare - +func (stub *SigningHandlerStub) VerifySignatureShare(index uint16, sig []byte, msg []byte, epoch uint32) error { + if stub.VerifySignatureShareCalled != nil { + return stub.VerifySignatureShareCalled(index, sig, msg, epoch) + } + + return nil +} + +// AggregateSigs - +func (stub *SigningHandlerStub) AggregateSigs(bitmap []byte, epoch uint32) ([]byte, error) { + if stub.AggregateSigsCalled != nil { + return stub.AggregateSigsCalled(bitmap, epoch) + } + + return []byte("aggSigs"), nil +} + +// SetAggregatedSig - +func (stub *SigningHandlerStub) SetAggregatedSig(sig []byte) error { + if stub.SetAggregatedSigCalled != nil { + return stub.SetAggregatedSigCalled(sig) + } + + return nil +} + +// Verify - +func (stub *SigningHandlerStub) Verify(msg []byte, bitmap []byte, epoch uint32) error { + if stub.VerifyCalled != nil { + return stub.VerifyCalled(msg, bitmap, epoch) + } + + return nil +} + +// IsInterfaceNil - +func (stub *SigningHandlerStub) IsInterfaceNil() bool { + return stub == nil +} diff --git a/consensus/spos/bls/blsSubroundsFactory.go b/consensus/spos/bls/blsSubroundsFactory.go index 59ed8933143..f8d58ab81b0 100644 --- a/consensus/spos/bls/blsSubroundsFactory.go +++ b/consensus/spos/bls/blsSubroundsFactory.go @@ -139,7 +139,7 @@ func (fct *factory) generateStartRoundSubround() error { return err } - subroundStartRound, err := NewSubroundStartRound( + subroundStartRoundInstance, err := NewSubroundStartRound( subround, fct.worker.Extend, processingThresholdPercent, @@ -150,12 +150,12 @@ func (fct *factory) generateStartRoundSubround() error { return err } - err = subroundStartRound.SetOutportHandler(fct.outportHandler) + err = subroundStartRoundInstance.SetOutportHandler(fct.outportHandler) if err != nil { return err } - fct.consensusCore.Chronology().AddSubround(subroundStartRound) + 
fct.consensusCore.Chronology().AddSubround(subroundStartRoundInstance) return nil } @@ -180,7 +180,7 @@ func (fct *factory) generateBlockSubround() error { return err } - subroundBlock, err := NewSubroundBlock( + subroundBlockInstance, err := NewSubroundBlock( subround, fct.worker.Extend, processingThresholdPercent, @@ -189,10 +189,10 @@ func (fct *factory) generateBlockSubround() error { return err } - fct.worker.AddReceivedMessageCall(MtBlockBodyAndHeader, subroundBlock.receivedBlockBodyAndHeader) - fct.worker.AddReceivedMessageCall(MtBlockBody, subroundBlock.receivedBlockBody) - fct.worker.AddReceivedMessageCall(MtBlockHeader, subroundBlock.receivedBlockHeader) - fct.consensusCore.Chronology().AddSubround(subroundBlock) + fct.worker.AddReceivedMessageCall(MtBlockBodyAndHeader, subroundBlockInstance.receivedBlockBodyAndHeader) + fct.worker.AddReceivedMessageCall(MtBlockBody, subroundBlockInstance.receivedBlockBody) + fct.worker.AddReceivedMessageCall(MtBlockHeader, subroundBlockInstance.receivedBlockHeader) + fct.consensusCore.Chronology().AddSubround(subroundBlockInstance) return nil } diff --git a/consensus/spos/bls/blsWorker_test.go b/consensus/spos/bls/blsWorker_test.go index ca38b8a6147..6786b96cde8 100644 --- a/consensus/spos/bls/blsWorker_test.go +++ b/consensus/spos/bls/blsWorker_test.go @@ -7,6 +7,7 @@ import ( "github.com/multiversx/mx-chain-go/consensus" "github.com/multiversx/mx-chain-go/consensus/spos" "github.com/multiversx/mx-chain-go/consensus/spos/bls" + "github.com/multiversx/mx-chain-go/testscommon" "github.com/stretchr/testify/assert" ) @@ -19,6 +20,10 @@ func createEligibleList(size int) []string { } func initConsensusState() *spos.ConsensusState { + return initConsensusStateWithKeysHandler(&testscommon.KeysHandlerStub{}) +} + +func initConsensusStateWithKeysHandler(keysHandler consensus.KeysHandler) *spos.ConsensusState { consensusGroupSize := 9 eligibleList := createEligibleList(consensusGroupSize) @@ -28,10 +33,12 @@ func initConsensusState() *spos.ConsensusState { } indexLeader := 1 - rcns := spos.NewRoundConsensus( + rcns, _ := spos.NewRoundConsensus( eligibleNodesPubKeys, consensusGroupSize, - eligibleList[indexLeader]) + eligibleList[indexLeader], + keysHandler, + ) rcns.SetConsensusGroup(eligibleList) rcns.ResetRoundState() diff --git a/consensus/spos/bls/export_test.go b/consensus/spos/bls/export_test.go index aa776b77909..413e2495c38 100644 --- a/consensus/spos/bls/export_test.go +++ b/consensus/spos/bls/export_test.go @@ -315,6 +315,11 @@ func (sr *subroundEndRound) VerifyInvalidSigners(invalidSigners []byte) error { return sr.verifyInvalidSigners(invalidSigners) } +// GetMinConsensusGroupIndexOfManagedKeys - +func (sr *subroundEndRound) GetMinConsensusGroupIndexOfManagedKeys() int { + return sr.getMinConsensusGroupIndexOfManagedKeys() +} + // GetStringValue gets the name of the message type func GetStringValue(messageType consensus.MessageType) string { return getStringValue(messageType) diff --git a/consensus/spos/bls/subroundBlock.go b/consensus/spos/bls/subroundBlock.go index 02629853ccf..d032a04eb63 100644 --- a/consensus/spos/bls/subroundBlock.go +++ b/consensus/spos/bls/subroundBlock.go @@ -63,7 +63,7 @@ func checkNewSubroundBlockParams( // doBlockJob method does the job of the subround Block func (sr *subroundBlock) doBlockJob(ctx context.Context) bool { - if !sr.IsSelfLeaderInCurrentRound() { // is NOT self leader in this round? + if !sr.IsSelfLeaderInCurrentRound() && !sr.IsMultiKeyLeaderInCurrentRound() { // is NOT self leader in this round? 
return false } @@ -71,7 +71,7 @@ func (sr *subroundBlock) doBlockJob(ctx context.Context) bool { return false } - if sr.IsSelfJobDone(sr.Current()) { + if sr.IsLeaderJobDone(sr.Current()) { return false } @@ -99,7 +99,13 @@ func (sr *subroundBlock) doBlockJob(ctx context.Context) bool { return false } - err = sr.SetSelfJobDone(sr.Current(), true) + leader, errGetLeader := sr.GetLeader() + if errGetLeader != nil { + log.Debug("doBlockJob.GetLeader", "error", errGetLeader) + return false + } + + err = sr.SetJobDone(leader, sr.Current(), true) if err != nil { log.Debug("doBlockJob.SetSelfJobDone", "error", err.Error()) return false @@ -182,12 +188,18 @@ func (sr *subroundBlock) sendHeaderAndBlockBody( ) bool { headerHash := sr.Hasher().Compute(string(marshalizedHeader)) + leader, errGetLeader := sr.GetLeader() + if errGetLeader != nil { + log.Debug("sendBlockBodyAndHeader.GetLeader", "error", errGetLeader) + return false + } + cnsMsg := consensus.NewConsensusMessage( headerHash, nil, marshalizedBody, marshalizedHeader, - []byte(sr.SelfPubKey()), + []byte(leader), nil, int(MtBlockBodyAndHeader), sr.RoundHandler().Index(), @@ -195,7 +207,7 @@ func (sr *subroundBlock) sendHeaderAndBlockBody( nil, nil, nil, - sr.CurrentPid(), + sr.GetAssociatedPid([]byte(leader)), nil, ) @@ -218,12 +230,18 @@ func (sr *subroundBlock) sendHeaderAndBlockBody( // sendBlockBody method sends the proposed block body in the subround Block func (sr *subroundBlock) sendBlockBody(bodyHandler data.BodyHandler, marshalizedBody []byte) bool { + leader, errGetLeader := sr.GetLeader() + if errGetLeader != nil { + log.Debug("sendBlockBody.GetLeader", "error", errGetLeader) + return false + } + cnsMsg := consensus.NewConsensusMessage( nil, nil, marshalizedBody, nil, - []byte(sr.SelfPubKey()), + []byte(leader), nil, int(MtBlockBody), sr.RoundHandler().Index(), @@ -231,7 +249,7 @@ func (sr *subroundBlock) sendBlockBody(bodyHandler data.BodyHandler, marshalized nil, nil, nil, - sr.CurrentPid(), + sr.GetAssociatedPid([]byte(leader)), nil, ) @@ -252,12 +270,18 @@ func (sr *subroundBlock) sendBlockBody(bodyHandler data.BodyHandler, marshalized func (sr *subroundBlock) sendBlockHeader(headerHandler data.HeaderHandler, marshalizedHeader []byte) bool { headerHash := sr.Hasher().Compute(string(marshalizedHeader)) + leader, errGetLeader := sr.GetLeader() + if errGetLeader != nil { + log.Debug("sendBlockBody.GetLeader", "error", errGetLeader) + return false + } + cnsMsg := consensus.NewConsensusMessage( headerHash, nil, nil, marshalizedHeader, - []byte(sr.SelfPubKey()), + []byte(leader), nil, int(MtBlockHeader), sr.RoundHandler().Index(), @@ -265,7 +289,7 @@ func (sr *subroundBlock) sendBlockHeader(headerHandler data.HeaderHandler, marsh nil, nil, nil, - sr.CurrentPid(), + sr.GetAssociatedPid([]byte(leader)), nil, ) @@ -312,7 +336,12 @@ func (sr *subroundBlock) createHeader() (data.HeaderHandler, error) { return nil, err } - randSeed, err := sr.SingleSigner().Sign(sr.PrivateKey(), prevRandSeed) + leader, errGetLeader := sr.GetLeader() + if errGetLeader != nil { + return nil, errGetLeader + } + + randSeed, err := sr.SigningHandler().CreateSignatureForPublicKey(prevRandSeed, []byte(leader)) if err != nil { return nil, err } diff --git a/consensus/spos/bls/subroundBlock_test.go b/consensus/spos/bls/subroundBlock_test.go index e86b5b8f3ef..2354ab92b11 100644 --- a/consensus/spos/bls/subroundBlock_test.go +++ b/consensus/spos/bls/subroundBlock_test.go @@ -901,8 +901,6 @@ func TestSubroundBlock_CreateHeaderNilCurrentHeader(t *testing.T) { _ = 
sr.SendBlockBody(body, marshalizedBody) _ = sr.SendBlockHeader(header, marshalizedHeader) - oldRand := sr.BlockChain().GetGenesisHeader().GetRandSeed() - newRand, _ := sr.SingleSigner().Sign(sr.PrivateKey(), oldRand) expectedHeader, _ := container.BlockProcessor().CreateNewHeader(uint64(sr.RoundHandler().Index()), uint64(1)) err := expectedHeader.SetTimeStamp(uint64(sr.RoundHandler().TimeStamp().Unix())) require.Nil(t, err) @@ -912,7 +910,7 @@ func TestSubroundBlock_CreateHeaderNilCurrentHeader(t *testing.T) { require.Nil(t, err) err = expectedHeader.SetPrevRandSeed(sr.BlockChain().GetGenesisHeader().GetRandSeed()) require.Nil(t, err) - err = expectedHeader.SetRandSeed(newRand) + err = expectedHeader.SetRandSeed(make([]byte, 0)) require.Nil(t, err) err = expectedHeader.SetMiniBlockHeaderHandlers(header.GetMiniBlockHeaderHandlers()) require.Nil(t, err) @@ -937,8 +935,6 @@ func TestSubroundBlock_CreateHeaderNotNilCurrentHeader(t *testing.T) { _ = sr.SendBlockBody(body, marshalizedBody) _ = sr.SendBlockHeader(header, marshalizedHeader) - oldRand := sr.BlockChain().GetGenesisHeader().GetRandSeed() - newRand, _ := sr.SingleSigner().Sign(sr.PrivateKey(), oldRand) expectedHeader, _ := container.BlockProcessor().CreateNewHeader( uint64(sr.RoundHandler().Index()), sr.BlockChain().GetCurrentBlockHeader().GetNonce()+1) @@ -948,7 +944,7 @@ func TestSubroundBlock_CreateHeaderNotNilCurrentHeader(t *testing.T) { require.Nil(t, err) err = expectedHeader.SetPrevHash(sr.BlockChain().GetCurrentBlockHeaderHash()) require.Nil(t, err) - err = expectedHeader.SetRandSeed(newRand) + err = expectedHeader.SetRandSeed(make([]byte, 0)) require.Nil(t, err) err = expectedHeader.SetMiniBlockHeaderHandlers(header.GetMiniBlockHeaderHandlers()) require.Nil(t, err) @@ -990,15 +986,13 @@ func TestSubroundBlock_CreateHeaderMultipleMiniBlocks(t *testing.T) { _ = sr.SendBlockBody(body, marshalizedBody) _ = sr.SendBlockHeader(header, marshalizedHeader) - oldRand := sr.BlockChain().GetCurrentBlockHeader().GetRandSeed() - newRand, _ := sr.SingleSigner().Sign(sr.PrivateKey(), oldRand) expectedHeader := &block.Header{ Round: uint64(sr.RoundHandler().Index()), TimeStamp: uint64(sr.RoundHandler().TimeStamp().Unix()), RootHash: []byte{}, Nonce: sr.BlockChain().GetCurrentBlockHeader().GetNonce() + 1, PrevHash: sr.BlockChain().GetCurrentBlockHeaderHash(), - RandSeed: newRand, + RandSeed: make([]byte, 0), MiniBlockHeaders: mbHeaders, ChainID: chainID, } diff --git a/consensus/spos/bls/subroundEndRound.go b/consensus/spos/bls/subroundEndRound.go index cde2fd82d87..b4913c04792 100644 --- a/consensus/spos/bls/subroundEndRound.go +++ b/consensus/spos/bls/subroundEndRound.go @@ -87,7 +87,7 @@ func (sr *subroundEndRound) receivedBlockHeaderFinalInfo(_ context.Context, cnsD return false } - if sr.IsSelfLeaderInCurrentRound() { + if sr.IsSelfLeaderInCurrentRound() || sr.IsMultiKeyLeaderInCurrentRound() { return false } @@ -214,7 +214,7 @@ func (sr *subroundEndRound) verifyInvalidSigners(invalidSigners []byte) error { } for _, msg := range messages { - err := sr.verifyInvalidSigner(msg) + err = sr.verifyInvalidSigner(msg) if err != nil { return err } @@ -235,13 +235,7 @@ func (sr *subroundEndRound) verifyInvalidSigner(msg p2p.MessageP2P) error { return err } - pubKey, err := sr.KeyGenerator().PublicKeyFromByteArray(cnsMsg.PubKey) - if err != nil { - return err - } - - singleSigner := sr.SingleSigner() - err = singleSigner.Verify(pubKey, cnsMsg.BlockHeaderHash, cnsMsg.SignatureShare) + err = sr.SigningHandler().VerifySingleSignature(cnsMsg.PubKey, 
cnsMsg.BlockHeaderHash, cnsMsg.SignatureShare) if err != nil { log.Debug("verifyInvalidSigner: confirmed that node provided invalid signature", "pubKey", cnsMsg.PubKey, @@ -259,7 +253,7 @@ func (sr *subroundEndRound) applyBlacklistOnNode(peer core.PeerID) { } func (sr *subroundEndRound) receivedHeader(headerHandler data.HeaderHandler) { - if sr.ConsensusGroup() == nil || sr.IsSelfLeaderInCurrentRound() { + if sr.ConsensusGroup() == nil || sr.IsSelfLeaderInCurrentRound() || sr.IsMultiKeyLeaderInCurrentRound() { return } @@ -270,8 +264,8 @@ func (sr *subroundEndRound) receivedHeader(headerHandler data.HeaderHandler) { // doEndRoundJob method does the job of the subround EndRound func (sr *subroundEndRound) doEndRoundJob(_ context.Context) bool { - if !sr.IsSelfLeaderInCurrentRound() { - if sr.IsNodeInConsensusGroup(sr.SelfPubKey()) { + if !sr.IsSelfLeaderInCurrentRound() && !sr.IsMultiKeyLeaderInCurrentRound() { + if sr.IsNodeInConsensusGroup(sr.SelfPubKey()) || sr.IsMultiKeyInConsensusGroup() { err := sr.prepareBroadcastBlockDataForValidator() if err != nil { log.Warn("validator in consensus group preparing for delayed broadcast", @@ -346,11 +340,16 @@ func (sr *subroundEndRound) doEndRoundJobByLeader() bool { // broadcast header and final info section - // create and broadcast header final info sr.createAndBroadcastHeaderFinalInfo() + leader, errGetLeader := sr.GetLeader() + if errGetLeader != nil { + log.Debug("doEndRoundJobByLeader.GetLeader", "error", errGetLeader) + return false + } + // broadcast header - err = sr.BroadcastMessenger().BroadcastHeader(sr.Header) + err = sr.BroadcastMessenger().BroadcastHeader(sr.Header, []byte(leader)) if err != nil { log.Debug("doEndRoundJobByLeader.BroadcastHeader", "error", err.Error()) } @@ -390,20 +389,20 @@ func (sr *subroundEndRound) doEndRoundJobByLeader() bool { } func (sr *subroundEndRound) aggregateSigsAndHandleInvalidSigners(bitmap []byte) ([]byte, []byte, error) { - sig, err := sr.SignatureHandler().AggregateSigs(bitmap, sr.Header.GetEpoch()) + sig, err := sr.SigningHandler().AggregateSigs(bitmap, sr.Header.GetEpoch()) if err != nil { log.Debug("doEndRoundJobByLeader.AggregateSigs", "error", err.Error()) return sr.handleInvalidSignersOnAggSigFail() } - err = sr.SignatureHandler().SetAggregatedSig(sig) + err = sr.SigningHandler().SetAggregatedSig(sig) if err != nil { log.Debug("doEndRoundJobByLeader.SetAggregatedSig", "error", err.Error()) return nil, nil, err } - err = sr.SignatureHandler().Verify(sr.GetData(), bitmap, sr.Header.GetEpoch()) + err = sr.SigningHandler().Verify(sr.GetData(), bitmap, sr.Header.GetEpoch()) if err != nil { log.Debug("doEndRoundJobByLeader.Verify", "error", err.Error()) @@ -427,13 +426,13 @@ func (sr *subroundEndRound) verifyNodesOnAggSigFail() ([]string, error) { continue } - sigShare, err := sr.SignatureHandler().SignatureShare(uint16(i)) + sigShare, err := sr.SigningHandler().SignatureShare(uint16(i)) if err != nil { return nil, err } isSuccessfull := true - err = sr.SignatureHandler().VerifySignatureShare(uint16(i), sigShare, sr.GetData(), sr.Header.GetEpoch()) + err = sr.SigningHandler().VerifySignatureShare(uint16(i), sigShare, sr.GetData(), sr.Header.GetEpoch()) if err != nil { isSuccessfull = false @@ -525,12 +524,12 @@ func (sr *subroundEndRound) computeAggSigOnValidNodes() ([]byte, []byte, error) return nil, nil, err } - sig, err := sr.SignatureHandler().AggregateSigs(bitmap, sr.Header.GetEpoch()) + sig, err := sr.SigningHandler().AggregateSigs(bitmap, sr.Header.GetEpoch()) if err != nil { return nil, 
nil, err } - err = sr.SignatureHandler().SetAggregatedSig(sig) + err = sr.SigningHandler().SetAggregatedSig(sig) if err != nil { return nil, nil, err } @@ -539,12 +538,18 @@ func (sr *subroundEndRound) computeAggSigOnValidNodes() ([]byte, []byte, error) } func (sr *subroundEndRound) createAndBroadcastHeaderFinalInfo() { + leader, errGetLeader := sr.GetLeader() + if errGetLeader != nil { + log.Debug("createAndBroadcastHeaderFinalInfo.GetLeader", "error", errGetLeader) + return + } + cnsMsg := consensus.NewConsensusMessage( sr.GetData(), nil, nil, nil, - []byte(sr.SelfPubKey()), + []byte(leader), nil, int(MtBlockHeaderFinalInfo), sr.RoundHandler().Index(), @@ -552,7 +557,7 @@ func (sr *subroundEndRound) createAndBroadcastHeaderFinalInfo() { sr.Header.GetPubKeysBitmap(), sr.Header.GetSignature(), sr.Header.GetLeaderSignature(), - sr.CurrentPid(), + sr.GetAssociatedPid([]byte(leader)), nil, ) @@ -660,7 +665,7 @@ func (sr *subroundEndRound) doEndRoundJobByParticipant(cnsDta *consensus.Message sr.SetStatus(sr.Current(), spos.SsFinished) - if sr.IsNodeInConsensusGroup(sr.SelfPubKey()) { + if sr.IsNodeInConsensusGroup(sr.SelfPubKey()) || sr.IsMultiKeyInConsensusGroup() { err = sr.setHeaderForValidator(header) if err != nil { log.Warn("doEndRoundJobByParticipant", "error", err.Error()) @@ -769,7 +774,12 @@ func (sr *subroundEndRound) signBlockHeader() ([]byte, error) { return nil, err } - return sr.SingleSigner().Sign(sr.PrivateKey(), marshalizedHdr) + leader, errGetLeader := sr.GetLeader() + if errGetLeader != nil { + return nil, errGetLeader + } + + return sr.SigningHandler().CreateSignatureForPublicKey(marshalizedHdr, []byte(leader)) } func (sr *subroundEndRound) updateMetricsForLeader() { @@ -784,38 +794,33 @@ func (sr *subroundEndRound) broadcastBlockDataLeader() error { return err } - return sr.BroadcastMessenger().BroadcastBlockDataLeader(sr.Header, miniBlocks, transactions) + leader, errGetLeader := sr.GetLeader() + if errGetLeader != nil { + log.Debug("broadcastBlockDataLeader.GetLeader", "error", errGetLeader) + return errGetLeader + } + + return sr.BroadcastMessenger().BroadcastBlockDataLeader(sr.Header, miniBlocks, transactions, []byte(leader)) } func (sr *subroundEndRound) setHeaderForValidator(header data.HeaderHandler) error { - idx, err := sr.SelfConsensusGroupIndex() + idx, pk, miniBlocks, transactions, err := sr.getIndexPkAndDataToBroadcast() if err != nil { return err } - // todo: avoid calling MarshalizeDataToBroadcast twice for validators - miniBlocks, transactions, err := sr.BlockProcessor().MarshalizedDataToBroadcast(sr.Header, sr.Body) - if err != nil { - return err - } - - go sr.BroadcastMessenger().PrepareBroadcastHeaderValidator(header, miniBlocks, transactions, idx) + go sr.BroadcastMessenger().PrepareBroadcastHeaderValidator(header, miniBlocks, transactions, idx, pk) return nil } func (sr *subroundEndRound) prepareBroadcastBlockDataForValidator() error { - idx, err := sr.SelfConsensusGroupIndex() + idx, pk, miniBlocks, transactions, err := sr.getIndexPkAndDataToBroadcast() if err != nil { return err } - miniBlocks, transactions, err := sr.BlockProcessor().MarshalizedDataToBroadcast(sr.Header, sr.Body) - if err != nil { - return err - } - - go sr.BroadcastMessenger().PrepareBroadcastBlockDataValidator(sr.Header, miniBlocks, transactions, idx) + go sr.BroadcastMessenger().PrepareBroadcastBlockDataValidator(sr.Header, miniBlocks, transactions, idx, pk) return nil } @@ -878,6 +883,47 @@ func (sr *subroundEndRound) isOutOfTime() bool { return false } +func (sr 
*subroundEndRound) getIndexPkAndDataToBroadcast() (int, []byte, map[uint32][]byte, map[string][][]byte, error) { + minIdx := sr.getMinConsensusGroupIndexOfManagedKeys() + + idx, err := sr.SelfConsensusGroupIndex() + if err == nil { + if idx < minIdx { + minIdx = idx + } + } + + if minIdx == sr.ConsensusGroupSize() { + return -1, nil, nil, nil, err + } + + miniBlocks, transactions, err := sr.BlockProcessor().MarshalizedDataToBroadcast(sr.Header, sr.Body) + if err != nil { + return -1, nil, nil, nil, err + } + + consensusGroup := sr.ConsensusGroup() + pk := []byte(consensusGroup[minIdx]) + + return minIdx, pk, miniBlocks, transactions, nil +} + +func (sr *subroundEndRound) getMinConsensusGroupIndexOfManagedKeys() int { + minIdx := sr.ConsensusGroupSize() + + for idx, validator := range sr.ConsensusGroup() { + if !sr.IsKeyManagedByCurrentNode([]byte(validator)) { + continue + } + + if idx < minIdx { + minIdx = idx + } + } + + return minIdx +} + // IsInterfaceNil returns true if there is no value under the interface func (sr *subroundEndRound) IsInterfaceNil() bool { return sr == nil diff --git a/consensus/spos/bls/subroundEndRound_test.go b/consensus/spos/bls/subroundEndRound_test.go index a27262e3e57..6899649f84b 100644 --- a/consensus/spos/bls/subroundEndRound_test.go +++ b/consensus/spos/bls/subroundEndRound_test.go @@ -1,6 +1,7 @@ package bls_test import ( + "bytes" "errors" "sync" "testing" @@ -11,14 +12,15 @@ import ( "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/block" crypto "github.com/multiversx/mx-chain-crypto-go" - "github.com/multiversx/mx-chain-p2p-go/message" "github.com/multiversx/mx-chain-go/consensus" "github.com/multiversx/mx-chain-go/consensus/mock" "github.com/multiversx/mx-chain-go/consensus/spos" "github.com/multiversx/mx-chain-go/consensus/spos/bls" "github.com/multiversx/mx-chain-go/dataRetriever/blockchain" "github.com/multiversx/mx-chain-go/p2p" + "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/statusHandler" + "github.com/multiversx/mx-chain-p2p-go/message" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -326,12 +328,12 @@ func TestSubroundEndRound_DoEndRoundJobErrAggregatingSigShouldFail(t *testing.T) container := mock.InitConsensusCore() sr := *initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) - signatureHandler := &mock.SignatureHandlerStub{ + signingHandler := &mock.SigningHandlerStub{ AggregateSigsCalled: func(bitmap []byte, epoch uint32) ([]byte, error) { return nil, crypto.ErrNilHasher }, } - container.SetSignatureHandler(signatureHandler) + container.SetSigningHandler(signingHandler) sr.Header = &block.Header{} @@ -426,10 +428,10 @@ func TestSubroundEndRound_DoEndRoundJobErrMarshalizedDataToBroadcastOK(t *testin BroadcastBlockCalled: func(handler data.BodyHandler, handler2 data.HeaderHandler) error { return nil }, - BroadcastMiniBlocksCalled: func(bytes map[uint32][]byte) error { + BroadcastMiniBlocksCalled: func(bytes map[uint32][]byte, pkBytes []byte) error { return nil }, - BroadcastTransactionsCalled: func(bytes map[string][][]byte) error { + BroadcastTransactionsCalled: func(bytes map[string][][]byte, pkBytes []byte) error { return nil }, } @@ -460,11 +462,11 @@ func TestSubroundEndRound_DoEndRoundJobErrBroadcastMiniBlocksOK(t *testing.T) { BroadcastBlockCalled: func(handler data.BodyHandler, handler2 data.HeaderHandler) error { return nil }, - BroadcastMiniBlocksCalled: func(bytes 
map[uint32][]byte) error { + BroadcastMiniBlocksCalled: func(bytes map[uint32][]byte, pkBytes []byte) error { err = errors.New("error broadcast miniblocks") return err }, - BroadcastTransactionsCalled: func(bytes map[string][][]byte) error { + BroadcastTransactionsCalled: func(bytes map[string][][]byte, pkBytes []byte) error { return nil }, } @@ -496,10 +498,10 @@ func TestSubroundEndRound_DoEndRoundJobErrBroadcastTransactionsOK(t *testing.T) BroadcastBlockCalled: func(handler data.BodyHandler, handler2 data.HeaderHandler) error { return nil }, - BroadcastMiniBlocksCalled: func(bytes map[uint32][]byte) error { + BroadcastMiniBlocksCalled: func(bytes map[uint32][]byte, pkBytes []byte) error { return nil }, - BroadcastTransactionsCalled: func(bytes map[string][][]byte) error { + BroadcastTransactionsCalled: func(bytes map[string][][]byte, pkBytes []byte) error { err = errors.New("error broadcast transactions") return err }, @@ -540,14 +542,14 @@ func TestSubroundEndRound_CheckIfSignatureIsFilled(t *testing.T) { expectedSignature := []byte("signature") container := mock.InitConsensusCore() - singleSigner := &mock.SingleSignerMock{ - SignStub: func(private crypto.PrivateKey, msg []byte) ([]byte, error) { + signingHandler := &mock.SigningHandlerStub{ + CreateSignatureForPublicKeyCalled: func(publicKeyBytes []byte, msg []byte) ([]byte, error) { var receivedHdr block.Header _ = container.Marshalizer().Unmarshal(&receivedHdr, msg) return expectedSignature, nil }, } - container.SetSingleSigner(singleSigner) + container.SetSigningHandler(signingHandler) bm := &mock.BroadcastMessengerMock{ BroadcastBlockCalled: func(handler data.BodyHandler, handler2 data.HeaderHandler) error { return errors.New("error") @@ -954,13 +956,13 @@ func TestVerifyNodesOnAggSigVerificationFail(t *testing.T) { sr := *initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) expectedErr := errors.New("exptected error") - signatureHandler := &mock.SignatureHandlerStub{ + signingHandler := &mock.SigningHandlerStub{ SignatureShareCalled: func(index uint16) ([]byte, error) { return nil, expectedErr }, } - container.SetSignatureHandler(signatureHandler) + container.SetSigningHandler(signingHandler) sr.Header = &block.Header{} _ = sr.SetJobDone(sr.ConsensusGroup()[0], bls.SrSignature, true) @@ -976,7 +978,7 @@ func TestVerifyNodesOnAggSigVerificationFail(t *testing.T) { sr := *initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) expectedErr := errors.New("exptected error") - signatureHandler := &mock.SignatureHandlerStub{ + signingHandler := &mock.SigningHandlerStub{ SignatureShareCalled: func(index uint16) ([]byte, error) { return nil, nil }, @@ -987,7 +989,7 @@ func TestVerifyNodesOnAggSigVerificationFail(t *testing.T) { sr.Header = &block.Header{} _ = sr.SetJobDone(sr.ConsensusGroup()[0], bls.SrSignature, true) - container.SetSignatureHandler(signatureHandler) + container.SetSigningHandler(signingHandler) _, err := sr.VerifyNodesOnAggSigFail() require.Nil(t, err) @@ -1002,7 +1004,7 @@ func TestVerifyNodesOnAggSigVerificationFail(t *testing.T) { container := mock.InitConsensusCore() sr := *initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) - signatureHandler := &mock.SignatureHandlerStub{ + signingHandler := &mock.SigningHandlerStub{ SignatureShareCalled: func(index uint16) ([]byte, error) { return nil, nil }, @@ -1013,7 +1015,7 @@ func TestVerifyNodesOnAggSigVerificationFail(t *testing.T) { return nil }, } - 
container.SetSignatureHandler(signatureHandler) + container.SetSigningHandler(signingHandler) sr.Header = &block.Header{} _ = sr.SetJobDone(sr.ConsensusGroup()[0], bls.SrSignature, true) @@ -1047,12 +1049,12 @@ func TestComputeAddSigOnValidNodes(t *testing.T) { sr := *initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) expectedErr := errors.New("exptected error") - signatureHandler := &mock.SignatureHandlerStub{ + signingHandler := &mock.SigningHandlerStub{ AggregateSigsCalled: func(bitmap []byte, epoch uint32) ([]byte, error) { return nil, expectedErr }, } - container.SetSignatureHandler(signatureHandler) + container.SetSigningHandler(signingHandler) sr.Header = &block.Header{} _ = sr.SetJobDone(sr.ConsensusGroup()[0], bls.SrSignature, true) @@ -1068,12 +1070,12 @@ func TestComputeAddSigOnValidNodes(t *testing.T) { sr := *initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) expectedErr := errors.New("exptected error") - signatureHandler := &mock.SignatureHandlerStub{ + signingHandler := &mock.SigningHandlerStub{ SetAggregatedSigCalled: func(_ []byte) error { return expectedErr }, } - container.SetSignatureHandler(signatureHandler) + container.SetSigningHandler(signingHandler) sr.Header = &block.Header{} _ = sr.SetJobDone(sr.ConsensusGroup()[0], bls.SrSignature, true) @@ -1107,7 +1109,7 @@ func TestSubroundEndRound_DoEndRoundJobByLeaderVerificationFail(t *testing.T) { verifySigShareNumCalls := 0 verifyFirstCall := true - signatureHandler := &mock.SignatureHandlerStub{ + signingHandler := &mock.SigningHandlerStub{ SignatureShareCalled: func(index uint16) ([]byte, error) { return nil, nil }, @@ -1130,7 +1132,7 @@ func TestSubroundEndRound_DoEndRoundJobByLeaderVerificationFail(t *testing.T) { }, } - container.SetSignatureHandler(signatureHandler) + container.SetSigningHandler(signingHandler) sr.SetThreshold(bls.SrEndRound, 2) @@ -1154,7 +1156,7 @@ func TestSubroundEndRound_DoEndRoundJobByLeaderVerificationFail(t *testing.T) { verifySigShareNumCalls := 0 verifyFirstCall := true - signatureHandler := &mock.SignatureHandlerStub{ + signingHandler := &mock.SigningHandlerStub{ SignatureShareCalled: func(index uint16) ([]byte, error) { return nil, nil }, @@ -1177,7 +1179,7 @@ func TestSubroundEndRound_DoEndRoundJobByLeaderVerificationFail(t *testing.T) { }, } - container.SetSignatureHandler(signatureHandler) + container.SetSigningHandler(signingHandler) sr.SetThreshold(bls.SrEndRound, 2) @@ -1416,14 +1418,15 @@ func TestVerifyInvalidSigners(t *testing.T) { }, } - singleSignerMock := &mock.SingleSignerMock{} wasCalled := false - singleSignerMock.VerifyStub = func(public crypto.PublicKey, msg, sig []byte) error { - wasCalled = true - return errors.New("expected err") + signingHandler := &mock.SigningHandlerStub{ + VerifySingleSignatureCalled: func(publicKeyBytes []byte, message []byte, signature []byte) error { + wasCalled = true + return errors.New("expected err") + }, } - container.SetSingleSigner(singleSignerMock) + container.SetSigningHandler(signingHandler) container.SetMessageSigningHandler(messageSigningHandler) sr := *initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) @@ -1543,3 +1546,64 @@ func TestGetFullMessagesForInvalidSigners(t *testing.T) { require.Equal(t, expectedInvalidSigners, invalidSignersBytes) }) } + +func TestSubroundEndRound_getMinConsensusGroupIndexOfManagedKeys(t *testing.T) { + t.Parallel() + + container := mock.InitConsensusCore() + keysHandler := &testscommon.KeysHandlerStub{} 
+ ch := make(chan bool, 1) + consensusState := initConsensusStateWithKeysHandler(keysHandler) + sr, _ := spos.NewSubround( + bls.SrSignature, + bls.SrEndRound, + -1, + int64(85*roundTimeDuration/100), + int64(95*roundTimeDuration/100), + "(END_ROUND)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + ) + + srEndRound, _ := bls.NewSubroundEndRound( + sr, + extend, + bls.ProcessingThresholdPercent, + displayStatistics, + &statusHandler.AppStatusHandlerStub{}, + ) + + t.Run("no managed keys from consensus group", func(t *testing.T) { + keysHandler.IsKeyManagedByCurrentNodeCalled = func(pkBytes []byte) bool { + return false + } + + assert.Equal(t, 9, srEndRound.GetMinConsensusGroupIndexOfManagedKeys()) + }) + t.Run("first managed key in consensus group should return 0", func(t *testing.T) { + keysHandler.IsKeyManagedByCurrentNodeCalled = func(pkBytes []byte) bool { + return bytes.Equal([]byte("A"), pkBytes) + } + + assert.Equal(t, 0, srEndRound.GetMinConsensusGroupIndexOfManagedKeys()) + }) + t.Run("third managed key in consensus group should return 2", func(t *testing.T) { + keysHandler.IsKeyManagedByCurrentNodeCalled = func(pkBytes []byte) bool { + return bytes.Equal([]byte("C"), pkBytes) + } + + assert.Equal(t, 2, srEndRound.GetMinConsensusGroupIndexOfManagedKeys()) + }) + t.Run("last managed key in consensus group should return 8", func(t *testing.T) { + keysHandler.IsKeyManagedByCurrentNodeCalled = func(pkBytes []byte) bool { + return bytes.Equal([]byte("I"), pkBytes) + } + + assert.Equal(t, 8, srEndRound.GetMinConsensusGroupIndexOfManagedKeys()) + }) +} diff --git a/consensus/spos/bls/subroundSignature.go b/consensus/spos/bls/subroundSignature.go index 1d79ea0aff7..e58c1786d27 100644 --- a/consensus/spos/bls/subroundSignature.go +++ b/consensus/spos/bls/subroundSignature.go @@ -62,7 +62,7 @@ func checkNewSubroundSignatureParams( // doSignatureJob method does the job of the subround Signature func (sr *subroundSignature) doSignatureJob(_ context.Context) bool { - if !sr.IsNodeInConsensusGroup(sr.SelfPubKey()) { + if !sr.IsNodeInConsensusGroup(sr.SelfPubKey()) && !sr.IsMultiKeyInConsensusGroup() { return true } if !sr.CanDoSubroundJob(sr.Current()) { @@ -73,57 +73,85 @@ func (sr *subroundSignature) doSignatureJob(_ context.Context) bool { return false } - selfIndex, err := sr.SelfConsensusGroupIndex() - if err != nil { - log.Debug("doSignatureJob.SelfConsensusGroupIndex: not in consensus group") - return false - } - - signatureShare, err := sr.SignatureHandler().CreateSignatureShare(sr.GetData(), uint16(selfIndex), sr.Header.GetEpoch()) - if err != nil { - log.Debug("doSignatureJob.CreateSignatureShare", "error", err.Error()) - return false - } - isSelfLeader := sr.IsSelfLeaderInCurrentRound() - if !isSelfLeader { - // TODO: Analyze it is possible to send message only to leader with O(1) instead of O(n) - cnsMsg := consensus.NewConsensusMessage( + if isSelfLeader || sr.IsNodeInConsensusGroup(sr.SelfPubKey()) { + selfIndex, err := sr.SelfConsensusGroupIndex() + if err != nil { + log.Debug("doSignatureJob.SelfConsensusGroupIndex: not in consensus group") + return false + } + + signatureShare, err := sr.SigningHandler().CreateSignatureShareForPublicKey( sr.GetData(), - signatureShare, - nil, - nil, + uint16(selfIndex), + sr.Header.GetEpoch(), []byte(sr.SelfPubKey()), - nil, - int(MtSignature), - sr.RoundHandler().Index(), - sr.ChainID(), - nil, - nil, - nil, - sr.CurrentPid(), - nil, ) - - err = 
sr.BroadcastMessenger().BroadcastConsensusMessage(cnsMsg) if err != nil { - log.Debug("doSignatureJob.BroadcastConsensusMessage", "error", err.Error()) + log.Debug("doSignatureJob.CreateSignatureShareForPublicKey", "error", err.Error()) + return false + } + + if !isSelfLeader { + ok := sr.createAndSendSignatureMessage(signatureShare, []byte(sr.SelfPubKey())) + if !ok { + return false + } + } + + ok := sr.completeSignatureSubRound(sr.SelfPubKey(), isSelfLeader) + if !ok { return false } + } + + return sr.doSignatureJobForManagedKeys() +} + +func (sr *subroundSignature) createAndSendSignatureMessage(signatureShare []byte, pkBytes []byte) bool { + // TODO: Analyze it is possible to send message only to leader with O(1) instead of O(n) + cnsMsg := consensus.NewConsensusMessage( + sr.GetData(), + signatureShare, + nil, + nil, + pkBytes, + nil, + int(MtSignature), + sr.RoundHandler().Index(), + sr.ChainID(), + nil, + nil, + nil, + sr.GetAssociatedPid(pkBytes), + nil, + ) - log.Debug("step 2: signature has been sent") + err := sr.BroadcastMessenger().BroadcastConsensusMessage(cnsMsg) + if err != nil { + log.Debug("createAndSendSignatureMessage.BroadcastConsensusMessage", + "error", err.Error(), "pk", pkBytes) + return false } - err = sr.SetSelfJobDone(sr.Current(), true) + log.Debug("step 2: signature has been sent", "pk", pkBytes) + + return true +} + +func (sr *subroundSignature) completeSignatureSubRound(pk string, shouldWaitForAllSigsAsync bool) bool { + err := sr.SetJobDone(pk, sr.Current(), true) if err != nil { log.Debug("doSignatureJob.SetSelfJobDone", "subround", sr.Name(), - "error", err.Error()) + "error", err.Error(), + "pk", []byte(pk), + ) return false } - if isSelfLeader { + if shouldWaitForAllSigsAsync { go sr.waitAllSignatures() } @@ -151,7 +179,7 @@ func (sr *subroundSignature) receivedSignature(_ context.Context, cnsDta *consen return false } - if !sr.IsSelfLeaderInCurrentRound() { + if !sr.IsSelfLeaderInCurrentRound() && !sr.IsMultiKeyLeaderInCurrentRound() { return false } @@ -171,7 +199,7 @@ func (sr *subroundSignature) receivedSignature(_ context.Context, cnsDta *consen return false } - err = sr.SignatureHandler().StoreSignatureShare(uint16(index), cnsDta.SignatureShare) + err = sr.SigningHandler().StoreSignatureShare(uint16(index), cnsDta.SignatureShare) if err != nil { log.Debug("receivedSignature.StoreSignatureShare", "node", pkForLogs, @@ -211,8 +239,8 @@ func (sr *subroundSignature) doSignatureConsensusCheck() bool { return true } - isSelfLeader := sr.IsSelfLeaderInCurrentRound() - isSelfInConsensusGroup := sr.IsNodeInConsensusGroup(sr.SelfPubKey()) + isSelfLeader := sr.IsSelfLeaderInCurrentRound() || sr.IsMultiKeyLeaderInCurrentRound() + isSelfInConsensusGroup := sr.IsNodeInConsensusGroup(sr.SelfPubKey()) || sr.IsMultiKeyInConsensusGroup() threshold := sr.Threshold(sr.Current()) if sr.FallbackHeaderValidator().ShouldApplyFallbackValidation(sr.Header) { @@ -227,7 +255,16 @@ func (sr *subroundSignature) doSignatureConsensusCheck() bool { areAllSignaturesCollected := numSigs == sr.ConsensusGroupSize() isJobDoneByLeader := isSelfLeader && (areAllSignaturesCollected || (areSignaturesCollected && sr.WaitingAllSignaturesTimeOut)) - isJobDoneByConsensusNode := !isSelfLeader && isSelfInConsensusGroup && sr.IsSelfJobDone(sr.Current()) + + selfJobDone := true + if sr.IsNodeInConsensusGroup(sr.SelfPubKey()) { + selfJobDone = sr.IsSelfJobDone(sr.Current()) + } + multiKeyJobDone := true + if sr.IsMultiKeyInConsensusGroup() { + multiKeyJobDone = sr.IsMultiKeyJobDone(sr.Current()) + 
} + isJobDoneByConsensusNode := !isSelfLeader && isSelfInConsensusGroup && selfJobDone && multiKeyJobDone isSubroundFinished := !isSelfInConsensusGroup || isJobDoneByConsensusNode || isJobDoneByLeader @@ -304,6 +341,59 @@ func (sr *subroundSignature) remainingTime() time.Duration { return remainigTime } +func (sr *subroundSignature) doSignatureJobForManagedKeys() bool { + isMultiKeyLeader := sr.IsMultiKeyLeaderInCurrentRound() + + numMultiKeysSignaturesSent := 0 + for idx, pk := range sr.ConsensusGroup() { + pkBytes := []byte(pk) + if sr.IsJobDone(pk, sr.Current()) { + continue + } + if !sr.IsKeyManagedByCurrentNode(pkBytes) { + continue + } + + selfIndex, err := sr.ConsensusGroupIndex(pk) + if err != nil { + log.Warn("doSignatureJobForManagedKeys: index not found", "pk", pkBytes) + continue + } + + signatureShare, err := sr.SigningHandler().CreateSignatureShareForPublicKey( + sr.GetData(), + uint16(selfIndex), + sr.Header.GetEpoch(), + pkBytes, + ) + if err != nil { + log.Debug("doSignatureJobForManagedKeys.CreateSignatureShareForPublicKey", "error", err.Error()) + return false + } + + if !isMultiKeyLeader { + ok := sr.createAndSendSignatureMessage(signatureShare, pkBytes) + if !ok { + return false + } + + numMultiKeysSignaturesSent++ + } + + isLeader := idx == spos.IndexOfLeaderInConsensusGroup + ok := sr.completeSignatureSubRound(pk, isLeader) + if !ok { + return false + } + } + + if numMultiKeysSignaturesSent > 0 { + log.Debug("step 2: multi keys signatures have been sent", "num", numMultiKeysSignaturesSent) + } + + return true +} + // IsInterfaceNil returns true if there is no value under the interface func (sr *subroundSignature) IsInterfaceNil() bool { return sr == nil diff --git a/consensus/spos/bls/subroundSignature_test.go b/consensus/spos/bls/subroundSignature_test.go index 62f38ee92fe..613d1f315e8 100644 --- a/consensus/spos/bls/subroundSignature_test.go +++ b/consensus/spos/bls/subroundSignature_test.go @@ -277,22 +277,22 @@ func TestSubroundSignature_DoSignatureJob(t *testing.T) { sr.Data = []byte("X") err := errors.New("create signature share error") - signatureHandler := &mock.SignatureHandlerStub{ - CreateSignatureShareCalled: func(msg []byte, index uint16, epoch uint32) ([]byte, error) { + signingHandler := &mock.SigningHandlerStub{ + CreateSignatureShareForPublicKeyCalled: func(msg []byte, index uint16, epoch uint32, publicKeyBytes []byte) ([]byte, error) { return nil, err }, } - container.SetSignatureHandler(signatureHandler) + container.SetSigningHandler(signingHandler) r = sr.DoSignatureJob() assert.False(t, r) - signatureHandler = &mock.SignatureHandlerStub{ - CreateSignatureShareCalled: func(msg []byte, index uint16, epoch uint32) ([]byte, error) { + signingHandler = &mock.SigningHandlerStub{ + CreateSignatureShareForPublicKeyCalled: func(msg []byte, index uint16, epoch uint32, publicKeyBytes []byte) ([]byte, error) { return []byte("SIG"), nil }, } - container.SetSignatureHandler(signatureHandler) + container.SetSigningHandler(signingHandler) r = sr.DoSignatureJob() assert.True(t, r) @@ -367,7 +367,7 @@ func TestSubroundSignature_ReceivedSignatureStoreShareFailed(t *testing.T) { errStore := errors.New("signature share store failed") storeSigShareCalled := false - signatureHandler := &mock.SignatureHandlerStub{ + signingHandler := &mock.SigningHandlerStub{ VerifySignatureShareCalled: func(index uint16, sig, msg []byte, epoch uint32) error { return nil }, @@ -378,7 +378,7 @@ func TestSubroundSignature_ReceivedSignatureStoreShareFailed(t *testing.T) { } container := 
mock.InitConsensusCore() - container.SetSignatureHandler(signatureHandler) + container.SetSigningHandler(signingHandler) sr := *initSubroundSignatureWithContainer(container) sr.Header = &block.Header{} diff --git a/consensus/spos/bls/subroundStartRound.go b/consensus/spos/bls/subroundStartRound.go index bce049a8783..a5c1f179609 100644 --- a/consensus/spos/bls/subroundStartRound.go +++ b/consensus/spos/bls/subroundStartRound.go @@ -155,6 +155,9 @@ func (sr *subroundStartRound) initCurrentRound() bool { } msg := "" + if sr.IsKeyManagedByCurrentNode([]byte(leader)) { + msg = " (my turn in multi-key)" + } if leader == sr.SelfPubKey() { sr.AppStatusHandler().Increment(common.MetricCountLeader) sr.AppStatusHandler().SetStringValue(common.MetricConsensusRoundState, "proposed") @@ -167,12 +170,15 @@ func (sr *subroundStartRound) initCurrentRound() bool { "messsage", msg) pubKeys := sr.ConsensusGroup() + numMultiKeysInConsensusGroup := sr.computeNumManagedKeysInConsensusGroup(pubKeys) sr.indexRoundIfNeeded(pubKeys) _, err = sr.SelfConsensusGroupIndex() if err != nil { - log.Debug("not in consensus group") + if numMultiKeysInConsensusGroup == 0 { + log.Debug("not in consensus group") + } sr.AppStatusHandler().SetStringValue(common.MetricConsensusState, "not in consensus group") } else { if leader != sr.SelfPubKey() { @@ -181,7 +187,7 @@ func (sr *subroundStartRound) initCurrentRound() bool { sr.AppStatusHandler().SetStringValue(common.MetricConsensusState, "participant") } - err = sr.SignatureHandler().Reset(pubKeys) + err = sr.SigningHandler().Reset(pubKeys) if err != nil { log.Debug("initCurrentRound.Reset", "error", err.Error()) @@ -210,6 +216,25 @@ func (sr *subroundStartRound) initCurrentRound() bool { return true } +func (sr *subroundStartRound) computeNumManagedKeysInConsensusGroup(pubKeys []string) int { + numMultiKeysInConsensusGroup := 0 + for _, pk := range pubKeys { + pkBytes := []byte(pk) + if sr.IsKeyManagedByCurrentNode(pkBytes) { + sr.IncrementRoundsWithoutReceivedMessages(pkBytes) + numMultiKeysInConsensusGroup++ + log.Trace("in consensus group with multi key", + "pk", core.GetTrimmedPk(hex.EncodeToString(pkBytes))) + } + } + + if numMultiKeysInConsensusGroup > 0 { + log.Debug("in consensus group with multi keys identities", "num", numMultiKeysInConsensusGroup) + } + + return numMultiKeysInConsensusGroup +} + func (sr *subroundStartRound) indexRoundIfNeeded(pubKeys []string) { sr.outportMutex.RLock() defer sr.outportMutex.RUnlock() diff --git a/consensus/spos/consensusCore.go b/consensus/spos/consensusCore.go index e460aa2cf66..1edfb09b5fc 100644 --- a/consensus/spos/consensusCore.go +++ b/consensus/spos/consensusCore.go @@ -4,7 +4,6 @@ import ( "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" - crypto "github.com/multiversx/mx-chain-crypto-go" cryptoCommon "github.com/multiversx/mx-chain-go/common/crypto" "github.com/multiversx/mx-chain-go/consensus" "github.com/multiversx/mx-chain-go/epochStart" @@ -24,9 +23,6 @@ type ConsensusCore struct { chronologyHandler consensus.ChronologyHandler hasher hashing.Hasher marshalizer marshal.Marshalizer - blsPrivateKey crypto.PrivateKey - blsSingleSigner crypto.SingleSigner - keyGenerator crypto.KeyGenerator multiSignerContainer cryptoCommon.MultiSignerContainer roundHandler consensus.RoundHandler shardCoordinator sharding.Coordinator @@ -41,7 +37,7 @@ type ConsensusCore struct { scheduledProcessor consensus.ScheduledProcessor messageSigningHandler 
consensus.P2PSigningHandler peerBlacklistHandler consensus.PeerBlacklistHandler - signatureHandler consensus.SignatureHandler + signingHandler consensus.SigningHandler } // ConsensusCoreArgs store all arguments that are needed to create a ConsensusCore object @@ -53,9 +49,6 @@ type ConsensusCoreArgs struct { ChronologyHandler consensus.ChronologyHandler Hasher hashing.Hasher Marshalizer marshal.Marshalizer - BlsPrivateKey crypto.PrivateKey - BlsSingleSigner crypto.SingleSigner - KeyGenerator crypto.KeyGenerator MultiSignerContainer cryptoCommon.MultiSignerContainer RoundHandler consensus.RoundHandler ShardCoordinator sharding.Coordinator @@ -70,7 +63,7 @@ type ConsensusCoreArgs struct { ScheduledProcessor consensus.ScheduledProcessor MessageSigningHandler consensus.P2PSigningHandler PeerBlacklistHandler consensus.PeerBlacklistHandler - SignatureHandler consensus.SignatureHandler + SigningHandler consensus.SigningHandler } // NewConsensusCore creates a new ConsensusCore instance @@ -85,9 +78,6 @@ func NewConsensusCore( chronologyHandler: args.ChronologyHandler, hasher: args.Hasher, marshalizer: args.Marshalizer, - blsPrivateKey: args.BlsPrivateKey, - blsSingleSigner: args.BlsSingleSigner, - keyGenerator: args.KeyGenerator, multiSignerContainer: args.MultiSignerContainer, roundHandler: args.RoundHandler, shardCoordinator: args.ShardCoordinator, @@ -102,7 +92,7 @@ func NewConsensusCore( scheduledProcessor: args.ScheduledProcessor, messageSigningHandler: args.MessageSigningHandler, peerBlacklistHandler: args.PeerBlacklistHandler, - signatureHandler: args.SignatureHandler, + signingHandler: args.SigningHandler, } err := ValidateConsensusCore(consensusCore) @@ -183,21 +173,6 @@ func (cc *ConsensusCore) EpochStartRegistrationHandler() epochStart.Registration return cc.epochStartRegistrationHandler } -// PrivateKey returns the BLS private key stored in the ConsensusCore -func (cc *ConsensusCore) PrivateKey() crypto.PrivateKey { - return cc.blsPrivateKey -} - -// SingleSigner returns the bls single signer stored in the ConsensusCore -func (cc *ConsensusCore) SingleSigner() crypto.SingleSigner { - return cc.blsSingleSigner -} - -// KeyGenerator returns the bls key generator stored in the ConsensusCore -func (cc *ConsensusCore) KeyGenerator() crypto.KeyGenerator { - return cc.keyGenerator -} - // PeerHonestyHandler will return the peer honesty handler which will be used in subrounds func (cc *ConsensusCore) PeerHonestyHandler() consensus.PeerHonestyHandler { return cc.peerHonestyHandler @@ -233,9 +208,9 @@ func (cc *ConsensusCore) PeerBlacklistHandler() consensus.PeerBlacklistHandler { return cc.peerBlacklistHandler } -// SignatureHandler will return the signature handler component -func (cc *ConsensusCore) SignatureHandler() consensus.SignatureHandler { - return cc.signatureHandler +// SigningHandler will return the signing handler component +func (cc *ConsensusCore) SigningHandler() consensus.SigningHandler { + return cc.signingHandler } // IsInterfaceNil returns true if there is no value under the interface diff --git a/consensus/spos/consensusCoreValidator.go b/consensus/spos/consensusCoreValidator.go index 0d9c04ab476..239c762f6d3 100644 --- a/consensus/spos/consensusCoreValidator.go +++ b/consensus/spos/consensusCoreValidator.go @@ -47,15 +47,6 @@ func ValidateConsensusCore(container ConsensusCoreHandler) error { if check.IfNil(container.NodesCoordinator()) { return ErrNilNodesCoordinator } - if check.IfNil(container.PrivateKey()) { - return ErrNilBlsPrivateKey - } - if 
check.IfNil(container.SingleSigner()) { - return ErrNilBlsSingleSigner - } - if check.IfNil(container.KeyGenerator()) { - return ErrNilKeyGenerator - } if check.IfNil(container.GetAntiFloodHandler()) { return ErrNilAntifloodHandler } @@ -80,8 +71,8 @@ func ValidateConsensusCore(container ConsensusCoreHandler) error { if check.IfNil(container.PeerBlacklistHandler()) { return ErrNilPeerBlacklistHandler } - if check.IfNil(container.SignatureHandler()) { - return ErrNilSignatureHandler + if check.IfNil(container.SigningHandler()) { + return ErrNilSigningHandler } return nil diff --git a/consensus/spos/consensusCoreValidator_test.go b/consensus/spos/consensusCoreValidator_test.go index e82360eb416..41b965887b1 100644 --- a/consensus/spos/consensusCoreValidator_test.go +++ b/consensus/spos/consensusCoreValidator_test.go @@ -19,10 +19,7 @@ func initConsensusDataContainer() *ConsensusCore { bootstrapperMock := &mock.BootstrapperStub{} broadcastMessengerMock := &mock.BroadcastMessengerMock{} chronologyHandlerMock := mock.InitChronologyHandlerMock() - blsPrivateKeyMock := &mock.PrivateKeyMock{} - blsSingleSignerMock := &mock.SingleSignerMock{} multiSignerMock := cryptoMocks.NewMultiSigner() - keyGenerator := &mock.KeyGenMock{} hasherMock := &hashingMocks.HasherMock{} roundHandlerMock := &mock.RoundHandlerMock{} shardCoordinatorMock := mock.ShardCoordinatorMock{} @@ -37,7 +34,7 @@ func initConsensusDataContainer() *ConsensusCore { messageSigningHandler := &mock.MessageSigningHandlerStub{} peerBlacklistHandler := &mock.PeerBlacklistHandlerStub{} multiSignerContainer := cryptoMocks.NewMultiSignerContainerMock(multiSignerMock) - signatureHandler := &mock.SignatureHandlerStub{} + signingHandler := &mock.SigningHandlerStub{} return &ConsensusCore{ blockChain: blockChain, @@ -47,9 +44,6 @@ func initConsensusDataContainer() *ConsensusCore { chronologyHandler: chronologyHandlerMock, hasher: hasherMock, marshalizer: marshalizerMock, - blsPrivateKey: blsPrivateKeyMock, - blsSingleSigner: blsSingleSignerMock, - keyGenerator: keyGenerator, multiSignerContainer: multiSignerContainer, roundHandler: roundHandlerMock, shardCoordinator: shardCoordinatorMock, @@ -63,7 +57,7 @@ func initConsensusDataContainer() *ConsensusCore { scheduledProcessor: scheduledProcessor, messageSigningHandler: messageSigningHandler, peerBlacklistHandler: peerBlacklistHandler, - signatureHandler: signatureHandler, + signingHandler: signingHandler, } } @@ -258,11 +252,11 @@ func TestConsensusContainerValidator_ValidateNilSignatureHandlerShouldFail(t *te t.Parallel() container := initConsensusDataContainer() - container.signatureHandler = nil + container.signingHandler = nil err := ValidateConsensusCore(container) - assert.Equal(t, ErrNilSignatureHandler, err) + assert.Equal(t, ErrNilSigningHandler, err) } func TestConsensusContainerValidator_ShouldWork(t *testing.T) { diff --git a/consensus/spos/consensusCore_test.go b/consensus/spos/consensusCore_test.go index d77203c36db..2fd67a2cb63 100644 --- a/consensus/spos/consensusCore_test.go +++ b/consensus/spos/consensusCore_test.go @@ -23,9 +23,6 @@ func createDefaultConsensusCoreArgs() *spos.ConsensusCoreArgs { ChronologyHandler: consensusCoreMock.Chronology(), Hasher: consensusCoreMock.Hasher(), Marshalizer: consensusCoreMock.Marshalizer(), - BlsPrivateKey: consensusCoreMock.PrivateKey(), - BlsSingleSigner: consensusCoreMock.SingleSigner(), - KeyGenerator: consensusCoreMock.KeyGenerator(), MultiSignerContainer: consensusCoreMock.MultiSignerContainer(), RoundHandler: 
consensusCoreMock.RoundHandler(), ShardCoordinator: consensusCoreMock.ShardCoordinator(), @@ -40,7 +37,7 @@ func createDefaultConsensusCoreArgs() *spos.ConsensusCoreArgs { ScheduledProcessor: scheduledProcessor, MessageSigningHandler: consensusCoreMock.MessageSigningHandler(), PeerBlacklistHandler: consensusCoreMock.PeerBlacklistHandler(), - SignatureHandler: consensusCoreMock.SignatureHandler(), + SigningHandler: consensusCoreMock.SigningHandler(), } return args } @@ -141,34 +138,6 @@ func TestConsensusCore_WithNilMarshalizerShouldFail(t *testing.T) { assert.Equal(t, spos.ErrNilMarshalizer, err) } -func TestConsensusCore_WithNilBlsPrivateKeyShouldFail(t *testing.T) { - t.Parallel() - - args := createDefaultConsensusCoreArgs() - args.BlsPrivateKey = nil - - consensusCore, err := spos.NewConsensusCore( - args, - ) - - assert.Nil(t, consensusCore) - assert.Equal(t, spos.ErrNilBlsPrivateKey, err) -} - -func TestConsensusCore_WithNilBlsSingleSignerShouldFail(t *testing.T) { - t.Parallel() - - args := createDefaultConsensusCoreArgs() - args.BlsSingleSigner = nil - - consensusCore, err := spos.NewConsensusCore( - args, - ) - - assert.Nil(t, consensusCore) - assert.Equal(t, spos.ErrNilBlsSingleSigner, err) -} - func TestConsensusCore_WithNilMultiSignerContainerShouldFail(t *testing.T) { t.Parallel() @@ -197,20 +166,6 @@ func TestConsensusCore_WithNilMultiSignerShouldFail(t *testing.T) { assert.Equal(t, spos.ErrNilMultiSigner, err) } -func TestConsensusCore_WithNilKeyGeneratorShouldFail(t *testing.T) { - t.Parallel() - - args := createDefaultConsensusCoreArgs() - args.KeyGenerator = nil - - consensusCore, err := spos.NewConsensusCore( - args, - ) - - assert.Nil(t, consensusCore) - assert.Equal(t, spos.ErrNilKeyGenerator, err) -} - func TestConsensusCore_WithNilRoundHandlerShouldFail(t *testing.T) { t.Parallel() diff --git a/consensus/spos/consensusState.go b/consensus/spos/consensusState.go index 9db954a97d8..5da416b0e09 100644 --- a/consensus/spos/consensusState.go +++ b/consensus/spos/consensusState.go @@ -5,6 +5,7 @@ import ( "sync" "time" + "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-go/consensus" "github.com/multiversx/mx-chain-go/p2p" @@ -12,6 +13,9 @@ import ( logger "github.com/multiversx/mx-chain-logger-go" ) +// IndexOfLeaderInConsensusGroup represents the index of the leader in the consensus group +const IndexOfLeaderInConsensusGroup = 0 + var log = logger.GetOrCreate("consensus/spos") // ConsensusState defines the data needed by spos to do the consensus in each round @@ -125,7 +129,7 @@ func (cns *ConsensusState) GetMessageWithSignature(key string) (p2p.MessageP2P, func (cns *ConsensusState) IsNodeLeaderInCurrentRound(node string) bool { leader, err := cns.GetLeader() if err != nil { - log.Debug("GetLeader", "error", err.Error()) + log.Debug("IsNodeLeaderInCurrentRound.GetLeader", "error", err.Error()) return false } @@ -147,7 +151,7 @@ func (cns *ConsensusState) GetLeader() (string, error) { return "", ErrEmptyConsensusGroup } - return cns.consensusGroup[0], nil + return cns.consensusGroup[IndexOfLeaderInConsensusGroup], nil } // GetNextConsensusGroup gets the new consensus group for the current round based on current eligible list and a random @@ -247,7 +251,16 @@ func (cns *ConsensusState) CanDoSubroundJob(currentSubroundId int) bool { return false } - if cns.IsSelfJobDone(currentSubroundId) { + selfJobDone := true + if cns.IsNodeInConsensusGroup(cns.SelfPubKey()) { + selfJobDone = 
cns.IsSelfJobDone(currentSubroundId) + } + multiKeyJobDone := true + if cns.IsMultiKeyInConsensusGroup() { + multiKeyJobDone = cns.IsMultiKeyJobDone(currentSubroundId) + } + + if selfJobDone && multiKeyJobDone { return false } @@ -327,3 +340,47 @@ func (cns *ConsensusState) SetProcessingBlock(processingBlock bool) { func (cns *ConsensusState) GetData() []byte { return cns.Data } + +// IsMultiKeyLeaderInCurrentRound method checks if one of the nodes which are controlled by this instance +// is leader in the current round +func (cns *ConsensusState) IsMultiKeyLeaderInCurrentRound() bool { + leader, err := cns.GetLeader() + if err != nil { + log.Debug("IsMultiKeyLeaderInCurrentRound.GetLeader", "error", err.Error()) + return false + } + + return cns.IsKeyManagedByCurrentNode([]byte(leader)) +} + +// IsLeaderJobDone method returns true if the leader job for the current subround is done and false otherwise +func (cns *ConsensusState) IsLeaderJobDone(currentSubroundId int) bool { + leader, err := cns.GetLeader() + if err != nil { + log.Debug("GetLeader", "error", err.Error()) + return false + } + + return cns.IsJobDone(leader, currentSubroundId) +} + +// IsMultiKeyJobDone method returns true if all the nodes controlled by this instance finished the current job for +// the current subround and false otherwise +func (cns *ConsensusState) IsMultiKeyJobDone(currentSubroundId int) bool { + for _, validator := range cns.consensusGroup { + if !cns.keysHandler.IsKeyManagedByCurrentNode([]byte(validator)) { + continue + } + + if !cns.IsJobDone(validator, currentSubroundId) { + return false + } + } + + return true +} + +// UpdatePublicKeyLiveness will update the public key's liveness in the network +func (cns *ConsensusState) UpdatePublicKeyLiveness(pkBytes []byte, pid core.PeerID) { + cns.keysHandler.UpdatePublicKeyLiveness(pkBytes, pid) +} diff --git a/consensus/spos/consensusState_test.go b/consensus/spos/consensusState_test.go index 2296b6b43de..74c8426f197 100644 --- a/consensus/spos/consensusState_test.go +++ b/consensus/spos/consensusState_test.go @@ -1,6 +1,7 @@ package spos_test import ( + "bytes" "errors" "testing" @@ -9,11 +10,16 @@ import ( "github.com/multiversx/mx-chain-go/consensus/spos" "github.com/multiversx/mx-chain-go/consensus/spos/bls" "github.com/multiversx/mx-chain-go/sharding/nodesCoordinator" + "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/shardingMocks" "github.com/stretchr/testify/assert" ) func internalInitConsensusState() *spos.ConsensusState { + return internalInitConsensusStateWithKeysHandler(&testscommon.KeysHandlerStub{}) +} + +func internalInitConsensusStateWithKeysHandler(keysHandler consensus.KeysHandler) *spos.ConsensusState { eligibleList := []string{"1", "2", "3"} eligibleNodesPubKeys := make(map[string]struct{}) @@ -21,10 +27,12 @@ func internalInitConsensusState() *spos.ConsensusState { eligibleNodesPubKeys[key] = struct{}{} } - rcns := spos.NewRoundConsensus( + rcns, _ := spos.NewRoundConsensus( eligibleNodesPubKeys, 3, - "2") + "2", + keysHandler, + ) rcns.SetConsensusGroup(eligibleList) rcns.ResetRoundState() @@ -496,3 +504,81 @@ func TestConsensusState_SetAndGetProcessingBlockShouldWork(t *testing.T) { assert.Equal(t, true, cns.ProcessingBlock()) } + +func TestConsensusState_IsMultiKeyLeaderInCurrentRound(t *testing.T) { + t.Parallel() + + keysHandler := &testscommon.KeysHandlerStub{} + cns := internalInitConsensusStateWithKeysHandler(keysHandler) + t.Run("no managed keys from consensus group should return 
false", func(t *testing.T) { + keysHandler.IsKeyManagedByCurrentNodeCalled = func(pkBytes []byte) bool { + return false + } + assert.False(t, cns.IsMultiKeyLeaderInCurrentRound()) + }) + t.Run("node has managed keys but no managed key is leader should return false", func(t *testing.T) { + keysHandler.IsKeyManagedByCurrentNodeCalled = func(pkBytes []byte) bool { + return bytes.Equal([]byte("2"), pkBytes) + } + + assert.False(t, cns.IsMultiKeyLeaderInCurrentRound()) + }) + t.Run("node has managed keys and one key is leader should return true", func(t *testing.T) { + keysHandler.IsKeyManagedByCurrentNodeCalled = func(pkBytes []byte) bool { + return bytes.Equal([]byte("1"), pkBytes) + } + + assert.True(t, cns.IsMultiKeyLeaderInCurrentRound()) + }) +} + +func TestConsensusState_IsLeaderJobDone(t *testing.T) { + t.Parallel() + + keysHandler := &testscommon.KeysHandlerStub{} + cns := internalInitConsensusStateWithKeysHandler(keysHandler) + t.Run("should work", func(t *testing.T) { + assert.False(t, cns.IsLeaderJobDone(0)) + leader, _ := cns.GetLeader() + _ = cns.SetJobDone(leader, 0, true) + assert.True(t, cns.IsLeaderJobDone(0)) + }) + t.Run("GetLeader errors should return false", func(t *testing.T) { + leader, _ := cns.GetLeader() + _ = cns.SetJobDone(leader, 0, true) + cns.SetConsensusGroup(make([]string, 0)) + assert.False(t, cns.IsLeaderJobDone(0)) + }) +} + +func TestConsensusState_IsMultiKeyJobDone(t *testing.T) { + t.Parallel() + + keysHandler := &testscommon.KeysHandlerStub{} + cns := internalInitConsensusStateWithKeysHandler(keysHandler) + managedKeyInConsensus := "1" + managedKeyNotInConsensus := "managed key not in consensus group" + t.Run("no managed keys should return true", func(t *testing.T) { + keysHandler.IsKeyManagedByCurrentNodeCalled = func(pkBytes []byte) bool { + return false + } + + assert.True(t, cns.IsMultiKeyJobDone(0)) + }) + t.Run("node has managed keys but no key is in consensus group should return true", func(t *testing.T) { + keysHandler.IsKeyManagedByCurrentNodeCalled = func(pkBytes []byte) bool { + return bytes.Equal([]byte(managedKeyNotInConsensus), pkBytes) + } + + assert.True(t, cns.IsMultiKeyJobDone(0)) + }) + t.Run("node has managed keys and one key is in consensus group", func(t *testing.T) { + keysHandler.IsKeyManagedByCurrentNodeCalled = func(pkBytes []byte) bool { + return bytes.Equal([]byte(managedKeyInConsensus), pkBytes) + } + + assert.False(t, cns.IsMultiKeyJobDone(0)) + _ = cns.SetJobDone(managedKeyInConsensus, 0, true) + assert.True(t, cns.IsMultiKeyJobDone(0)) + }) +} diff --git a/consensus/spos/errors.go b/consensus/spos/errors.go index 5185778d87d..1944ef5e3ba 100644 --- a/consensus/spos/errors.go +++ b/consensus/spos/errors.go @@ -13,9 +13,6 @@ var ErrEmptyConsensusGroup = errors.New("consensusGroup is empty") // ErrNotFoundInConsensus is raised when self expected in consensus group but not found var ErrNotFoundInConsensus = errors.New("self not found in consensus group") -// ErrNilPrivateKey is raised when a valid private key was expected but nil was used -var ErrNilPrivateKey = errors.New("private key is nil") - // ErrNilSignature is raised when a valid signature was expected but nil was used var ErrNilSignature = errors.New("signature is nil") @@ -127,12 +124,6 @@ var ErrInvalidHeader = errors.New("header is invalid") // ErrMessageFromItself is raised when a message from itself is received var ErrMessageFromItself = errors.New("message is from itself") -// ErrNilBlsPrivateKey is raised when the bls private key is nil -var ErrNilBlsPrivateKey 
= errors.New("BLS private key should not be nil") - -// ErrNilBlsSingleSigner is raised when a message from itself is received -var ErrNilBlsSingleSigner = errors.New("BLS single signer should not be nil") - // ErrNilHeader is raised when an expected header is nil var ErrNilHeader = errors.New("header is nil") @@ -226,9 +217,6 @@ var ErrNilScheduledProcessor = errors.New("nil scheduled processor") // ErrInvalidNumSigShares signals that an invalid number of signature shares has been provided var ErrInvalidNumSigShares = errors.New("invalid number of sig shares") -// ErrWrongTypeAssertion signals that a wrong type assertion has been triggered -var ErrWrongTypeAssertion = errors.New("wrong type assertion") - // ErrNilMessageSigningHandler signals that the provided message signing handler is nil var ErrNilMessageSigningHandler = errors.New("nil message signing handler") @@ -238,14 +226,14 @@ var ErrNilPeerBlacklistHandler = errors.New("nil peer blacklist handler") // ErrNilPeerBlacklistCacher signals that a nil peer blacklist cacher has been provided var ErrNilPeerBlacklistCacher = errors.New("nil peer blacklist cacher") -// ErrNilKeyGenerator signals that a nil key generator has been provided -var ErrNilKeyGenerator = errors.New("nil key generator") - // ErrBlacklistedConsensusPeer signals that a consensus message has been received from a blacklisted peer var ErrBlacklistedConsensusPeer = errors.New("blacklisted consensus peer") // ErrNilSignatureOnP2PMessage signals that a p2p message without signature was received var ErrNilSignatureOnP2PMessage = errors.New("nil signature on the p2p message") -// ErrNilSignatureHandler signals that provided signature handler is nil -var ErrNilSignatureHandler = errors.New("nil signature handler") +// ErrNilSigningHandler signals that provided signing handler is nil +var ErrNilSigningHandler = errors.New("nil signing handler") + +// ErrNilKeysHandler signals that a nil keys handler was provided +var ErrNilKeysHandler = errors.New("nil keys handler") diff --git a/consensus/spos/interface.go b/consensus/spos/interface.go index 1d7ae2348d6..e9e31f6d202 100644 --- a/consensus/spos/interface.go +++ b/consensus/spos/interface.go @@ -9,7 +9,6 @@ import ( "github.com/multiversx/mx-chain-core-go/data/outport" "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" - crypto "github.com/multiversx/mx-chain-crypto-go" cryptoCommon "github.com/multiversx/mx-chain-go/common/crypto" "github.com/multiversx/mx-chain-go/consensus" "github.com/multiversx/mx-chain-go/epochStart" @@ -50,12 +49,6 @@ type ConsensusCoreHandler interface { NodesCoordinator() nodesCoordinator.NodesCoordinator // EpochStartRegistrationHandler gets the RegistrationHandler stored in the ConsensusCore EpochStartRegistrationHandler() epochStart.RegistrationHandler - // PrivateKey returns the private key stored in the ConsensusCore used for randomness and leader's signature generation - PrivateKey() crypto.PrivateKey - // SingleSigner returns the single signer stored in the ConsensusCore used for randomness and leader's signature generation - SingleSigner() crypto.SingleSigner - // KeyGenerator returns the key generator stored in the ConsensusCore - KeyGenerator() crypto.KeyGenerator // PeerHonestyHandler returns the peer honesty handler which will be used in subrounds PeerHonestyHandler() consensus.PeerHonestyHandler // HeaderSigVerifier returns the sig verifier handler which will be used in subrounds @@ -66,12 +59,12 @@ type ConsensusCoreHandler interface { 
NodeRedundancyHandler() consensus.NodeRedundancyHandler // ScheduledProcessor returns the scheduled txs processor ScheduledProcessor() consensus.ScheduledProcessor - // MessageSignerHandler returns the p2p signing handler + // MessageSigningHandler returns the p2p signing handler MessageSigningHandler() consensus.P2PSigningHandler - // PeerBlackListHandler return the peer blacklist handler + // PeerBlacklistHandler return the peer blacklist handler PeerBlacklistHandler() consensus.PeerBlacklistHandler - // SignatureHandler returns the signature handler component - SignatureHandler() consensus.SignatureHandler + // SigningHandler returns the signing handler component + SigningHandler() consensus.SigningHandler // IsInterfaceNil returns true if there is no value under the interface IsInterfaceNil() bool } diff --git a/consensus/spos/roundConsensus.go b/consensus/spos/roundConsensus.go index 4d9e5363be1..b230e124a15 100644 --- a/consensus/spos/roundConsensus.go +++ b/consensus/spos/roundConsensus.go @@ -2,6 +2,9 @@ package spos import ( "sync" + + "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-go/consensus" ) // roundConsensus defines the data needed by spos to do the consensus in each round @@ -13,6 +16,7 @@ type roundConsensus struct { selfPubKey string validatorRoundStates map[string]*roundState mut sync.RWMutex + keysHandler consensus.KeysHandler } // NewRoundConsensus creates a new roundConsensus object @@ -20,18 +24,20 @@ func NewRoundConsensus( eligibleNodes map[string]struct{}, consensusGroupSize int, selfId string, -) *roundConsensus { - - rcns := roundConsensus{ - eligibleNodes: eligibleNodes, - consensusGroupSize: consensusGroupSize, - selfPubKey: selfId, - mutEligible: sync.RWMutex{}, + keysHandler consensus.KeysHandler, +) (*roundConsensus, error) { + if check.IfNil(keysHandler) { + return nil, ErrNilKeysHandler } - rcns.validatorRoundStates = make(map[string]*roundState) - - return &rcns + return &roundConsensus{ + eligibleNodes: eligibleNodes, + consensusGroupSize: consensusGroupSize, + selfPubKey: selfId, + mutEligible: sync.RWMutex{}, + validatorRoundStates: make(map[string]*roundState), + keysHandler: keysHandler, + }, nil } // ConsensusGroupIndex returns the index of given public key in the current consensus group @@ -136,11 +142,6 @@ func (rcns *roundConsensus) SelfJobDone(subroundId int) (bool, error) { return rcns.JobDone(rcns.selfPubKey, subroundId) } -// SetSelfJobDone set the self state of the action done in subround given by the subroundId parameter -func (rcns *roundConsensus) SetSelfJobDone(subroundId int, value bool) error { - return rcns.SetJobDone(rcns.selfPubKey, subroundId, value) -} - // IsNodeInConsensusGroup method checks if the node is part of consensus group of the current round func (rcns *roundConsensus) IsNodeInConsensusGroup(node string) bool { for i := 0; i < len(rcns.consensusGroup); i++ { @@ -199,3 +200,25 @@ func (rcns *roundConsensus) ResetRoundState() { rcns.mut.Unlock() } + +// IsMultiKeyInConsensusGroup method checks if one of the nodes which are controlled by this instance +// is in consensus group in the current round +func (rcns *roundConsensus) IsMultiKeyInConsensusGroup() bool { + for i := 0; i < len(rcns.consensusGroup); i++ { + if rcns.IsKeyManagedByCurrentNode([]byte(rcns.consensusGroup[i])) { + return true + } + } + + return false +} + +// IsKeyManagedByCurrentNode returns true if the key is managed by the current node +func (rcns *roundConsensus) IsKeyManagedByCurrentNode(pkBytes []byte) bool { 
+ return rcns.keysHandler.IsKeyManagedByCurrentNode(pkBytes) +} + +// IncrementRoundsWithoutReceivedMessages increments the number of rounds without received messages on a provided public key +func (rcns *roundConsensus) IncrementRoundsWithoutReceivedMessages(pkBytes []byte) { + rcns.keysHandler.IncrementRoundsWithoutReceivedMessages(pkBytes) +} diff --git a/consensus/spos/roundConsensus_test.go b/consensus/spos/roundConsensus_test.go index 316ed9f8f84..4ba8f7e47fe 100644 --- a/consensus/spos/roundConsensus_test.go +++ b/consensus/spos/roundConsensus_test.go @@ -1,14 +1,21 @@ package spos_test import ( + "bytes" "testing" + "github.com/multiversx/mx-chain-go/consensus" "github.com/multiversx/mx-chain-go/consensus/spos" "github.com/multiversx/mx-chain-go/consensus/spos/bls" + "github.com/multiversx/mx-chain-go/testscommon" "github.com/stretchr/testify/assert" ) func initRoundConsensus() *spos.RoundConsensus { + return initRoundConsensusWithKeysHandler(&testscommon.KeysHandlerStub{}) +} + +func initRoundConsensusWithKeysHandler(keysHandler consensus.KeysHandler) *spos.RoundConsensus { pubKeys := []string{"1", "2", "3"} eligibleNodes := make(map[string]struct{}) @@ -16,10 +23,12 @@ func initRoundConsensus() *spos.RoundConsensus { eligibleNodes[pubKeys[i]] = struct{}{} } - rcns := spos.NewRoundConsensus( + rcns, _ := spos.NewRoundConsensus( eligibleNodes, len(eligibleNodes), - "2") + "2", + keysHandler, + ) rcns.SetConsensusGroup(pubKeys) @@ -49,7 +58,7 @@ func TestRoundConsensus_ConsensusGroupIndexFound(t *testing.T) { eligibleNodes[pubKeys[i]] = struct{}{} } - rcns := spos.NewRoundConsensus(eligibleNodes, 3, "key3") + rcns, _ := spos.NewRoundConsensus(eligibleNodes, 3, "key3", &testscommon.KeysHandlerStub{}) rcns.SetConsensusGroup(pubKeys) index, err := rcns.ConsensusGroupIndex("key3") @@ -67,7 +76,7 @@ func TestRoundConsensus_ConsensusGroupIndexNotFound(t *testing.T) { eligibleNodes[pubKeys[i]] = struct{}{} } - rcns := spos.NewRoundConsensus(eligibleNodes, 3, "key4") + rcns, _ := spos.NewRoundConsensus(eligibleNodes, 3, "key4", &testscommon.KeysHandlerStub{}) rcns.SetConsensusGroup(pubKeys) index, err := rcns.ConsensusGroupIndex("key4") @@ -85,7 +94,7 @@ func TestRoundConsensus_IndexSelfConsensusGroupInConsesus(t *testing.T) { eligibleNodes[pubKeys[i]] = struct{}{} } - rcns := spos.NewRoundConsensus(eligibleNodes, 3, "key2") + rcns, _ := spos.NewRoundConsensus(eligibleNodes, 3, "key2", &testscommon.KeysHandlerStub{}) rcns.SetConsensusGroup(pubKeys) index, err := rcns.SelfConsensusGroupIndex() @@ -103,7 +112,7 @@ func TestRoundConsensus_IndexSelfConsensusGroupNotFound(t *testing.T) { eligibleNodes[pubKeys[i]] = struct{}{} } - rcns := spos.NewRoundConsensus(eligibleNodes, 3, "key4") + rcns, _ := spos.NewRoundConsensus(eligibleNodes, 3, "key4", &testscommon.KeysHandlerStub{}) rcns.SetConsensusGroup(pubKeys) index, err := rcns.SelfConsensusGroupIndex() @@ -217,7 +226,7 @@ func TestRoundConsensus_SetSelfJobDoneShouldWork(t *testing.T) { rcns := *initRoundConsensus() - _ = rcns.SetSelfJobDone(bls.SrBlock, true) + _ = rcns.SetJobDone(rcns.SelfPubKey(), bls.SrBlock, true) jobDone, _ := rcns.JobDone("2", bls.SrBlock) assert.True(t, jobDone) @@ -267,3 +276,55 @@ func TestRoundConsensus_ResetValidationMap(t *testing.T) { assert.Equal(t, false, jobDone) assert.Nil(t, err) } + +func TestRoundConsensus_IsMultiKeyInConsensusGroup(t *testing.T) { + t.Parallel() + + keysHandler := &testscommon.KeysHandlerStub{} + roundConsensus := initRoundConsensusWithKeysHandler(keysHandler) + t.Run("no consensus key is 
managed by current node should return false", func(t *testing.T) { + keysHandler.IsKeyManagedByCurrentNodeCalled = func(pkBytes []byte) bool { + return false + } + assert.False(t, roundConsensus.IsMultiKeyInConsensusGroup()) + }) + t.Run("consensus key is managed by current node should return true", func(t *testing.T) { + keysHandler.IsKeyManagedByCurrentNodeCalled = func(pkBytes []byte) bool { + return bytes.Equal([]byte("2"), pkBytes) + } + assert.True(t, roundConsensus.IsMultiKeyInConsensusGroup()) + }) +} + +func TestRoundConsensus_IsKeyManagedByCurrentNode(t *testing.T) { + t.Parallel() + + managedPkBytes := []byte("managed pk bytes") + wasCalled := false + keysHandler := &testscommon.KeysHandlerStub{ + IsKeyManagedByCurrentNodeCalled: func(pkBytes []byte) bool { + assert.Equal(t, managedPkBytes, pkBytes) + wasCalled = true + return true + }, + } + roundConsensus := initRoundConsensusWithKeysHandler(keysHandler) + assert.True(t, roundConsensus.IsKeyManagedByCurrentNode(managedPkBytes)) + assert.True(t, wasCalled) +} + +func TestRoundConsensus_IncrementRoundsWithoutReceivedMessages(t *testing.T) { + t.Parallel() + + managedPkBytes := []byte("managed pk bytes") + wasCalled := false + keysHandler := &testscommon.KeysHandlerStub{ + IncrementRoundsWithoutReceivedMessagesCalled: func(pkBytes []byte) { + assert.Equal(t, managedPkBytes, pkBytes) + wasCalled = true + }, + } + roundConsensus := initRoundConsensusWithKeysHandler(keysHandler) + roundConsensus.IncrementRoundsWithoutReceivedMessages(managedPkBytes) + assert.True(t, wasCalled) +} diff --git a/consensus/spos/sposFactory/sposFactory.go b/consensus/spos/sposFactory/sposFactory.go index a4c51cb6f1d..8fcc249c791 100644 --- a/consensus/spos/sposFactory/sposFactory.go +++ b/consensus/spos/sposFactory/sposFactory.go @@ -64,11 +64,11 @@ func GetBroadcastMessenger( hasher hashing.Hasher, messenger consensus.P2PMessenger, shardCoordinator sharding.Coordinator, - privateKey crypto.PrivateKey, peerSignatureHandler crypto.PeerSignatureHandler, headersSubscriber consensus.HeadersPoolSubscriber, interceptorsContainer process.InterceptorsContainer, alarmScheduler core.TimersScheduler, + keysHandler consensus.KeysHandler, ) (consensus.BroadcastMessenger, error) { if check.IfNil(shardCoordinator) { @@ -79,7 +79,6 @@ func GetBroadcastMessenger( Marshalizer: marshalizer, Hasher: hasher, Messenger: messenger, - PrivateKey: privateKey, ShardCoordinator: shardCoordinator, PeerSignatureHandler: peerSignatureHandler, HeadersSubscriber: headersSubscriber, @@ -87,6 +86,7 @@ func GetBroadcastMessenger( MaxValidatorDelayCacheSize: maxDelayCacheSize, InterceptorsContainer: interceptorsContainer, AlarmScheduler: alarmScheduler, + KeysHandler: keysHandler, } if shardCoordinator.SelfId() < shardCoordinator.NumberOfShards() { diff --git a/consensus/spos/sposFactory/sposFactory_test.go b/consensus/spos/sposFactory/sposFactory_test.go index ae01b73458b..8cf38e9d0dc 100644 --- a/consensus/spos/sposFactory/sposFactory_test.go +++ b/consensus/spos/sposFactory/sposFactory_test.go @@ -135,7 +135,6 @@ func TestGetBroadcastMessenger_ShardShouldWork(t *testing.T) { shardCoord.SelfIDCalled = func() uint32 { return 0 } - privateKey := &mock.PrivateKeyMock{} peerSigHandler := &mock.PeerSignatureHandler{} headersSubscriber := &mock.HeadersCacherStub{} interceptosContainer := &testscommon.InterceptorsContainerStub{} @@ -146,11 +145,11 @@ func TestGetBroadcastMessenger_ShardShouldWork(t *testing.T) { hasher, messenger, shardCoord, - privateKey, peerSigHandler, headersSubscriber, 
interceptosContainer, alarmSchedulerStub, + &testscommon.KeysHandlerStub{}, ) assert.Nil(t, err) @@ -167,7 +166,6 @@ func TestGetBroadcastMessenger_MetachainShouldWork(t *testing.T) { shardCoord.SelfIDCalled = func() uint32 { return core.MetachainShardId } - privateKey := &mock.PrivateKeyMock{} peerSigHandler := &mock.PeerSignatureHandler{} headersSubscriber := &mock.HeadersCacherStub{} interceptosContainer := &testscommon.InterceptorsContainerStub{} @@ -178,11 +176,11 @@ func TestGetBroadcastMessenger_MetachainShouldWork(t *testing.T) { hasher, messenger, shardCoord, - privateKey, peerSigHandler, headersSubscriber, interceptosContainer, alarmSchedulerStub, + &testscommon.KeysHandlerStub{}, ) assert.Nil(t, err) @@ -202,10 +200,10 @@ func TestGetBroadcastMessenger_NilShardCoordinatorShouldErr(t *testing.T) { nil, nil, nil, - nil, headersSubscriber, interceptosContainer, alarmSchedulerStub, + &testscommon.KeysHandlerStub{}, ) assert.Nil(t, bm) @@ -229,10 +227,10 @@ func TestGetBroadcastMessenger_InvalidShardIdShouldErr(t *testing.T) { nil, shardCoord, nil, - nil, headersSubscriber, interceptosContainer, alarmSchedulerStub, + &testscommon.KeysHandlerStub{}, ) assert.Nil(t, bm) diff --git a/consensus/spos/subround.go b/consensus/spos/subround.go index 8d5daaeb0d8..94d906488df 100644 --- a/consensus/spos/subround.go +++ b/consensus/spos/subround.go @@ -204,6 +204,11 @@ func (sr *Subround) ConsensusChannel() chan bool { return sr.consensusStateChangedChannel } +// GetAssociatedPid returns the associated PeerID to the provided public key bytes +func (sr *Subround) GetAssociatedPid(pkBytes []byte) core.PeerID { + return sr.keysHandler.GetAssociatedPid(pkBytes) +} + // IsInterfaceNil returns true if there is no value under the interface func (sr *Subround) IsInterfaceNil() bool { return sr == nil diff --git a/consensus/spos/subround_test.go b/consensus/spos/subround_test.go index 74e5a0f8e04..202899e1a24 100644 --- a/consensus/spos/subround_test.go +++ b/consensus/spos/subround_test.go @@ -10,6 +10,7 @@ import ( "github.com/multiversx/mx-chain-go/consensus/mock" "github.com/multiversx/mx-chain-go/consensus/spos" "github.com/multiversx/mx-chain-go/consensus/spos/bls" + "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/cryptoMocks" "github.com/multiversx/mx-chain-go/testscommon/statusHandler" "github.com/stretchr/testify/assert" @@ -48,10 +49,12 @@ func initConsensusState() *spos.ConsensusState { } indexLeader := 1 - rcns := spos.NewRoundConsensus( + rcns, _ := spos.NewRoundConsensus( eligibleNodesKeys, consensusGroupSize, - eligibleList[indexLeader]) + eligibleList[indexLeader], + &testscommon.KeysHandlerStub{}, + ) rcns.SetConsensusGroup(eligibleList) rcns.ResetRoundState() @@ -931,3 +934,40 @@ func TestSubround_Name(t *testing.T) { assert.Equal(t, "(BLOCK)", sr.Name()) } + +func TestSubround_GetAssociatedPid(t *testing.T) { + t.Parallel() + + keysHandler := &testscommon.KeysHandlerStub{} + consensusState := internalInitConsensusStateWithKeysHandler(keysHandler) + ch := make(chan bool, 1) + container := mock.InitConsensusCore() + + subround, _ := spos.NewSubround( + bls.SrStartRound, + bls.SrBlock, + bls.SrSignature, + int64(5*roundTimeDuration/100), + int64(25*roundTimeDuration/100), + "(BLOCK)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + ) + + wasCalled := false + pid := core.PeerID("a pid") + providedPkBytes := []byte("pk bytes") + keysHandler.GetAssociatedPidCalled 
= func(pkBytes []byte) core.PeerID { + assert.Equal(t, providedPkBytes, pkBytes) + wasCalled = true + return pid + } + + assert.Equal(t, pid, subround.GetAssociatedPid(providedPkBytes)) + assert.True(t, wasCalled) +} diff --git a/consensus/spos/worker.go b/consensus/spos/worker.go index c1da088af15..e91ac9c2bda 100644 --- a/consensus/spos/worker.go +++ b/consensus/spos/worker.go @@ -376,6 +376,8 @@ func (wrk *Worker) ProcessReceivedMessage(message p2p.MessageP2P, fromConnectedP return err } + wrk.consensusState.UpdatePublicKeyLiveness(cnsMsg.GetPubKey(), message.Peer()) + if wrk.nodeRedundancyHandler.IsRedundancyNode() { wrk.nodeRedundancyHandler.ResetInactivityIfNeeded( wrk.consensusState.SelfPubKey(), @@ -547,7 +549,8 @@ func (wrk *Worker) processReceivedHeaderMetric(cnsDta *consensus.Message) { } func (wrk *Worker) checkSelfState(cnsDta *consensus.Message) error { - if wrk.consensusState.SelfPubKey() == string(cnsDta.PubKey) { + isMultiKeyManagedBySelf := wrk.consensusState.keysHandler.IsKeyManagedByCurrentNode(cnsDta.PubKey) + if wrk.consensusState.SelfPubKey() == string(cnsDta.PubKey) || isMultiKeyManagedBySelf { return ErrMessageFromItself } diff --git a/dataRetriever/common.go b/dataRetriever/common.go index 785c9f2bd14..a9efaffb09d 100644 --- a/dataRetriever/common.go +++ b/dataRetriever/common.go @@ -28,6 +28,29 @@ func SetEpochHandlerToHdrResolver( return nil } +// SetEpochHandlerToHdrRequester sets the epoch handler to the metablock hdr requester +func SetEpochHandlerToHdrRequester( + requestersContainer RequestersContainer, + epochHandler EpochHandler, +) error { + requester, err := requestersContainer.Get(factory.MetachainBlocksTopic) + if err != nil { + return err + } + + hdrRequester, ok := requester.(HeaderRequester) + if !ok { + return ErrWrongTypeInContainer + } + + err = hdrRequester.SetEpochHandler(epochHandler) + if err != nil { + return err + } + + return nil +} + // GetHdrNonceHashDataUnit gets the HdrNonceHashDataUnit by shard func GetHdrNonceHashDataUnit(shard uint32) UnitType { if shard == core.MetachainShardId { diff --git a/dataRetriever/common_test.go b/dataRetriever/common_test.go index 2abcdc74e69..14017661ddd 100644 --- a/dataRetriever/common_test.go +++ b/dataRetriever/common_test.go @@ -7,69 +7,132 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/dataRetriever/mock" + dataRetrieverMock "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" "github.com/stretchr/testify/require" ) -func TestSetEpochHandlerToHdrResolver_GetErr(t *testing.T) { - t.Parallel() - - localErr := errors.New("err") - resolverContainer := &mock.ResolversContainerStub{ - GetCalled: func(key string) (resolver dataRetriever.Resolver, err error) { - return nil, localErr - }, - } - epochHandler := &mock.EpochHandlerStub{} - - err := dataRetriever.SetEpochHandlerToHdrResolver(resolverContainer, epochHandler) - require.Equal(t, localErr, err) -} - -func TestSetEpochHandlerToHdrResolver_CannotSetEpoch(t *testing.T) { - t.Parallel() - - localErr := errors.New("err") - resolverContainer := &mock.ResolversContainerStub{ - GetCalled: func(key string) (resolver dataRetriever.Resolver, err error) { - return &mock.HeaderResolverStub{ - SetEpochHandlerCalled: func(epochHandler dataRetriever.EpochHandler) error { - return localErr - }, - }, nil - }, - } - epochHandler := &mock.EpochHandlerStub{} - - err := dataRetriever.SetEpochHandlerToHdrResolver(resolverContainer, epochHandler) - require.Equal(t, 
localErr, err) -} +var expectedErr = errors.New("err") -func TestSetEpochHandlerToHdrResolver_WrongType(t *testing.T) { +func TestSetEpochHandlerToHdrResolver(t *testing.T) { t.Parallel() - resolverContainer := &mock.ResolversContainerStub{ - GetCalled: func(key string) (resolver dataRetriever.Resolver, err error) { - return nil, nil - }, - } - epochHandler := &mock.EpochHandlerStub{} - - err := dataRetriever.SetEpochHandlerToHdrResolver(resolverContainer, epochHandler) - require.Equal(t, dataRetriever.ErrWrongTypeInContainer, err) + t.Run("get function errors should return error", func(t *testing.T) { + t.Parallel() + + resolverContainer := &dataRetrieverMock.ResolversContainerStub{ + GetCalled: func(key string) (resolver dataRetriever.Resolver, err error) { + return nil, expectedErr + }, + } + epochHandler := &mock.EpochHandlerStub{} + + err := dataRetriever.SetEpochHandlerToHdrResolver(resolverContainer, epochHandler) + require.Equal(t, expectedErr, err) + }) + t.Run("set epoch handler errors should return error", func(t *testing.T) { + t.Parallel() + + resolverContainer := &dataRetrieverMock.ResolversContainerStub{ + GetCalled: func(key string) (resolver dataRetriever.Resolver, err error) { + return &mock.HeaderResolverStub{ + SetEpochHandlerCalled: func(epochHandler dataRetriever.EpochHandler) error { + return expectedErr + }, + }, nil + }, + } + epochHandler := &mock.EpochHandlerStub{} + + err := dataRetriever.SetEpochHandlerToHdrResolver(resolverContainer, epochHandler) + require.Equal(t, expectedErr, err) + }) + t.Run("wrong type should return error", func(t *testing.T) { + t.Parallel() + + resolverContainer := &dataRetrieverMock.ResolversContainerStub{ + GetCalled: func(key string) (resolver dataRetriever.Resolver, err error) { + return nil, nil + }, + } + epochHandler := &mock.EpochHandlerStub{} + + err := dataRetriever.SetEpochHandlerToHdrResolver(resolverContainer, epochHandler) + require.Equal(t, dataRetriever.ErrWrongTypeInContainer, err) + }) + t.Run("should work", func(t *testing.T) { + t.Parallel() + + resolverContainer := &dataRetrieverMock.ResolversContainerStub{ + GetCalled: func(key string) (resolver dataRetriever.Resolver, err error) { + return &mock.HeaderResolverStub{}, nil + }, + } + epochHandler := &mock.EpochHandlerStub{} + + err := dataRetriever.SetEpochHandlerToHdrResolver(resolverContainer, epochHandler) + require.Nil(t, err) + }) } -func TestSetEpochHandlerToHdrResolver_Ok(t *testing.T) { +func TestSetEpochHandlerToHdrRequester(t *testing.T) { t.Parallel() - resolverContainer := &mock.ResolversContainerStub{ - GetCalled: func(key string) (resolver dataRetriever.Resolver, err error) { - return &mock.HeaderResolverStub{}, nil - }, - } - epochHandler := &mock.EpochHandlerStub{} - - err := dataRetriever.SetEpochHandlerToHdrResolver(resolverContainer, epochHandler) - require.Nil(t, err) + t.Run("get function errors should return error", func(t *testing.T) { + t.Parallel() + + requestersContainer := &dataRetrieverMock.RequestersContainerStub{ + GetCalled: func(key string) (requester dataRetriever.Requester, err error) { + return nil, expectedErr + }, + } + epochHandler := &mock.EpochHandlerStub{} + + err := dataRetriever.SetEpochHandlerToHdrRequester(requestersContainer, epochHandler) + require.Equal(t, expectedErr, err) + }) + t.Run("set epoch handler errors should return error", func(t *testing.T) { + t.Parallel() + + requestersContainer := &dataRetrieverMock.RequestersContainerStub{ + GetCalled: func(key string) (requester dataRetriever.Requester, err error) { + 
return &mock.HeaderRequesterStub{ + SetEpochHandlerCalled: func(epochHandler dataRetriever.EpochHandler) error { + return expectedErr + }, + }, nil + }, + } + epochHandler := &mock.EpochHandlerStub{} + + err := dataRetriever.SetEpochHandlerToHdrRequester(requestersContainer, epochHandler) + require.Equal(t, expectedErr, err) + }) + t.Run("wrong type should return error", func(t *testing.T) { + t.Parallel() + + requestersContainer := &dataRetrieverMock.RequestersContainerStub{ + GetCalled: func(key string) (requester dataRetriever.Requester, err error) { + return nil, nil + }, + } + epochHandler := &mock.EpochHandlerStub{} + + err := dataRetriever.SetEpochHandlerToHdrRequester(requestersContainer, epochHandler) + require.Equal(t, dataRetriever.ErrWrongTypeInContainer, err) + }) + t.Run("should work", func(t *testing.T) { + t.Parallel() + + requestersContainer := &dataRetrieverMock.RequestersContainerStub{ + GetCalled: func(key string) (resolver dataRetriever.Requester, err error) { + return &mock.HeaderRequesterStub{}, nil + }, + } + epochHandler := &mock.EpochHandlerStub{} + + err := dataRetriever.SetEpochHandlerToHdrRequester(requestersContainer, epochHandler) + require.Nil(t, err) + }) } func TestGetHdrNonceHashDataUnit(t *testing.T) { diff --git a/dataRetriever/factory/storageRequestersContainer/args.go b/dataRetriever/factory/storageRequestersContainer/args.go index a455c1145f8..70a2db6501e 100644 --- a/dataRetriever/factory/storageRequestersContainer/args.go +++ b/dataRetriever/factory/storageRequestersContainer/args.go @@ -25,4 +25,5 @@ type FactoryArgs struct { DataPacker dataRetriever.DataPacker ManualEpochStartNotifier dataRetriever.ManualEpochStartNotifier ChanGracefullyClose chan endProcess.ArgEndProcess + SnapshotsEnabled bool } diff --git a/dataRetriever/factory/storageRequestersContainer/baseRequestersContainerFactory.go b/dataRetriever/factory/storageRequestersContainer/baseRequestersContainerFactory.go index 0ec963e58ff..0157ca5c634 100644 --- a/dataRetriever/factory/storageRequestersContainer/baseRequestersContainerFactory.go +++ b/dataRetriever/factory/storageRequestersContainer/baseRequestersContainerFactory.go @@ -39,6 +39,7 @@ type baseRequestersContainerFactory struct { shardIDForTries uint32 chainID string workingDir string + snapshotsEnabled bool } func (brcf *baseRequestersContainerFactory) checkParams() error { @@ -258,7 +259,7 @@ func (brcf *baseRequestersContainerFactory) newImportDBTrieStorage( PruningEnabled: brcf.generalConfig.StateTriesConfig.AccountsStatePruningEnabled, CheckpointsEnabled: brcf.generalConfig.StateTriesConfig.CheckpointsEnabled, MaxTrieLevelInMem: brcf.generalConfig.StateTriesConfig.MaxStateTrieLevelInMemory, - SnapshotsEnabled: brcf.generalConfig.StateTriesConfig.SnapshotsEnabled, + SnapshotsEnabled: brcf.snapshotsEnabled, IdleProvider: disabled.NewProcessStatusHandler(), } return trieFactoryInstance.Create(args) diff --git a/dataRetriever/factory/storageRequestersContainer/metaRequestersContainerFactory.go b/dataRetriever/factory/storageRequestersContainer/metaRequestersContainerFactory.go index 28501fbe438..498d02cc1b3 100644 --- a/dataRetriever/factory/storageRequestersContainer/metaRequestersContainerFactory.go +++ b/dataRetriever/factory/storageRequestersContainer/metaRequestersContainerFactory.go @@ -36,6 +36,7 @@ func NewMetaRequestersContainerFactory( shardIDForTries: args.ShardIDForTries, chainID: args.ChainID, workingDir: args.WorkingDirectory, + snapshotsEnabled: args.SnapshotsEnabled, } err := base.checkParams() diff --git 
a/dataRetriever/factory/storageRequestersContainer/metaRequestersContainerFactory_test.go b/dataRetriever/factory/storageRequestersContainer/metaRequestersContainerFactory_test.go index 199f2c940d1..a53aca90aaf 100644 --- a/dataRetriever/factory/storageRequestersContainer/metaRequestersContainerFactory_test.go +++ b/dataRetriever/factory/storageRequestersContainer/metaRequestersContainerFactory_test.go @@ -205,7 +205,6 @@ func getArgumentsMeta() storagerequesterscontainer.FactoryArgs { }, StateTriesConfig: config.StateTriesConfig{ CheckpointRoundsModulus: 100, - SnapshotsEnabled: true, AccountsStatePruningEnabled: false, PeerStatePruningEnabled: false, MaxStateTrieLevelInMemory: 5, @@ -224,5 +223,6 @@ func getArgumentsMeta() storagerequesterscontainer.FactoryArgs { DataPacker: &mock.DataPackerStub{}, ManualEpochStartNotifier: &mock.ManualEpochStartNotifierStub{}, ChanGracefullyClose: make(chan endProcess.ArgEndProcess), + SnapshotsEnabled: true, } } diff --git a/dataRetriever/factory/storageRequestersContainer/shardRequestersContainerFactory.go b/dataRetriever/factory/storageRequestersContainer/shardRequestersContainerFactory.go index b982f218b8d..f1298ae1391 100644 --- a/dataRetriever/factory/storageRequestersContainer/shardRequestersContainerFactory.go +++ b/dataRetriever/factory/storageRequestersContainer/shardRequestersContainerFactory.go @@ -6,7 +6,7 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/dataRetriever/factory/containers" - "github.com/multiversx/mx-chain-go/dataRetriever/storageRequesters" + storagerequesters "github.com/multiversx/mx-chain-go/dataRetriever/storageRequesters" "github.com/multiversx/mx-chain-go/process/factory" ) @@ -36,6 +36,7 @@ func NewShardRequestersContainerFactory( shardIDForTries: args.ShardIDForTries, chainID: args.ChainID, workingDir: args.WorkingDirectory, + snapshotsEnabled: args.SnapshotsEnabled, } err := base.checkParams() diff --git a/dataRetriever/factory/storageRequestersContainer/shardRequestersContainerFactory_test.go b/dataRetriever/factory/storageRequestersContainer/shardRequestersContainerFactory_test.go index c41ee5b8dec..71319735278 100644 --- a/dataRetriever/factory/storageRequestersContainer/shardRequestersContainerFactory_test.go +++ b/dataRetriever/factory/storageRequestersContainer/shardRequestersContainerFactory_test.go @@ -190,7 +190,6 @@ func getArgumentsShard() storagerequesterscontainer.FactoryArgs { }, StateTriesConfig: config.StateTriesConfig{ CheckpointRoundsModulus: 100, - SnapshotsEnabled: true, AccountsStatePruningEnabled: false, PeerStatePruningEnabled: false, MaxStateTrieLevelInMemory: 5, @@ -209,5 +208,6 @@ func getArgumentsShard() storagerequesterscontainer.FactoryArgs { DataPacker: &mock.DataPackerStub{}, ManualEpochStartNotifier: &mock.ManualEpochStartNotifierStub{}, ChanGracefullyClose: make(chan endProcess.ArgEndProcess), + SnapshotsEnabled: true, } } diff --git a/dataRetriever/interface.go b/dataRetriever/interface.go index a691cb15cda..5f2a5850067 100644 --- a/dataRetriever/interface.go +++ b/dataRetriever/interface.go @@ -42,6 +42,12 @@ type HeaderResolver interface { SetEpochHandler(epochHandler EpochHandler) error } +// HeaderRequester defines what a block header requester should do +type HeaderRequester interface { + Requester + SetEpochHandler(epochHandler EpochHandler) error +} + // TopicResolverSender defines what sending operations are allowed for a topic resolver type TopicResolverSender interface { Send(buff 
[]byte, peer core.PeerID) error diff --git a/dataRetriever/mock/headerRequesterStub.go b/dataRetriever/mock/headerRequesterStub.go new file mode 100644 index 00000000000..23336b9c79c --- /dev/null +++ b/dataRetriever/mock/headerRequesterStub.go @@ -0,0 +1,60 @@ +package mock + +import "github.com/multiversx/mx-chain-go/dataRetriever" + +// HeaderRequesterStub - +type HeaderRequesterStub struct { + RequestDataFromHashCalled func(hash []byte, epoch uint32) error + SetNumPeersToQueryCalled func(intra int, cross int) + NumPeersToQueryCalled func() (int, int) + SetDebugHandlerCalled func(handler dataRetriever.DebugHandler) error + SetEpochHandlerCalled func(epochHandler dataRetriever.EpochHandler) error +} + +// RequestDataFromHash - +func (stub *HeaderRequesterStub) RequestDataFromHash(hash []byte, epoch uint32) error { + if stub.RequestDataFromHashCalled != nil { + return stub.RequestDataFromHashCalled(hash, epoch) + } + + return nil +} + +// SetNumPeersToQuery - +func (stub *HeaderRequesterStub) SetNumPeersToQuery(intra int, cross int) { + if stub.SetNumPeersToQueryCalled != nil { + stub.SetNumPeersToQueryCalled(intra, cross) + } +} + +// NumPeersToQuery - +func (stub *HeaderRequesterStub) NumPeersToQuery() (int, int) { + if stub.NumPeersToQueryCalled != nil { + return stub.NumPeersToQueryCalled() + } + + return 0, 0 +} + +// SetDebugHandler - +func (stub *HeaderRequesterStub) SetDebugHandler(handler dataRetriever.DebugHandler) error { + if stub.SetDebugHandlerCalled != nil { + return stub.SetDebugHandlerCalled(handler) + } + + return nil +} + +// SetEpochHandler - +func (stub *HeaderRequesterStub) SetEpochHandler(epochHandler dataRetriever.EpochHandler) error { + if stub.SetEpochHandlerCalled != nil { + return stub.SetEpochHandlerCalled(epochHandler) + } + + return nil +} + +// IsInterfaceNil - +func (stub *HeaderRequesterStub) IsInterfaceNil() bool { + return stub == nil +} diff --git a/dataRetriever/mock/throttlerStub.go b/dataRetriever/mock/throttlerStub.go index 24ab94c45c3..19155681a71 100644 --- a/dataRetriever/mock/throttlerStub.go +++ b/dataRetriever/mock/throttlerStub.go @@ -1,12 +1,15 @@ package mock +import "sync" + // ThrottlerStub - type ThrottlerStub struct { CanProcessCalled func() bool StartProcessingCalled func() EndProcessingCalled func() - StartWasCalled bool - EndWasCalled bool + mutState sync.RWMutex + startWasCalled bool + endWasCalled bool } // CanProcess - @@ -20,7 +23,10 @@ func (ts *ThrottlerStub) CanProcess() bool { // StartProcessing - func (ts *ThrottlerStub) StartProcessing() { - ts.StartWasCalled = true + ts.mutState.Lock() + ts.startWasCalled = true + ts.mutState.Unlock() + if ts.StartProcessingCalled != nil { ts.StartProcessingCalled() } @@ -28,12 +34,31 @@ func (ts *ThrottlerStub) StartProcessing() { // EndProcessing - func (ts *ThrottlerStub) EndProcessing() { - ts.EndWasCalled = true + ts.mutState.Lock() + ts.endWasCalled = true + ts.mutState.Unlock() + if ts.EndProcessingCalled != nil { ts.EndProcessingCalled() } } +// StartWasCalled - +func (ts *ThrottlerStub) StartWasCalled() bool { + ts.mutState.RLock() + defer ts.mutState.RUnlock() + + return ts.startWasCalled +} + +// EndWasCalled - +func (ts *ThrottlerStub) EndWasCalled() bool { + ts.mutState.RLock() + defer ts.mutState.RUnlock() + + return ts.endWasCalled +} + // IsInterfaceNil - func (ts *ThrottlerStub) IsInterfaceNil() bool { return ts == nil diff --git a/dataRetriever/requestHandlers/requesters/headerRequester.go b/dataRetriever/requestHandlers/requesters/headerRequester.go index 
7c0c0e5a687..06b07ad4fbb 100644 --- a/dataRetriever/requestHandlers/requesters/headerRequester.go +++ b/dataRetriever/requestHandlers/requesters/headerRequester.go @@ -6,6 +6,8 @@ import ( "github.com/multiversx/mx-chain-go/dataRetriever" ) +var _ dataRetriever.HeaderRequester = (*headerRequester)(nil) + // ArgHeaderRequester is the argument structure used to create a new header requester instance type ArgHeaderRequester struct { ArgBaseRequester @@ -65,6 +67,11 @@ func (requester *headerRequester) RequestDataFromEpoch(identifier []byte) error ) } +// SetEpochHandler does nothing and returns nil +func (requester *headerRequester) SetEpochHandler(_ dataRetriever.EpochHandler) error { + return nil +} + // IsInterfaceNil returns true if there is no value under the interface func (requester *headerRequester) IsInterfaceNil() bool { return requester == nil diff --git a/dataRetriever/requestHandlers/requesters/headerRequester_test.go b/dataRetriever/requestHandlers/requesters/headerRequester_test.go index 07d87dbd092..7389b0a2fac 100644 --- a/dataRetriever/requestHandlers/requesters/headerRequester_test.go +++ b/dataRetriever/requestHandlers/requesters/headerRequester_test.go @@ -99,3 +99,12 @@ func TestHeaderRequester_RequestDataFromEpoch(t *testing.T) { assert.Nil(t, requester.RequestDataFromEpoch(providedIdentifier)) assert.True(t, wasCalled) } + +func TestHeaderRequester_SetEpochHandler(t *testing.T) { + t.Parallel() + + argBase := createMockArgBaseRequester() + requester, _ := NewHeaderRequester(createMockArgHeaderRequester(argBase)) + + assert.Nil(t, requester.SetEpochHandler(nil)) +} diff --git a/dataRetriever/resolvers/disabled/resolver.go b/dataRetriever/resolvers/disabled/resolver.go index adbee5ff4b1..077c98d8f97 100644 --- a/dataRetriever/resolvers/disabled/resolver.go +++ b/dataRetriever/resolvers/disabled/resolver.go @@ -24,6 +24,11 @@ func (r *resolver) SetDebugHandler(_ dataRetriever.DebugHandler) error { return nil } +// SetEpochHandler does nothing and returns nil +func (r *resolver) SetEpochHandler(_ dataRetriever.EpochHandler) error { + return nil +} + // Close returns nil as it is disabled func (r *resolver) Close() error { return nil diff --git a/dataRetriever/resolvers/headerResolver.go b/dataRetriever/resolvers/headerResolver.go index c6bdb96dd58..59216068c2f 100644 --- a/dataRetriever/resolvers/headerResolver.go +++ b/dataRetriever/resolvers/headerResolver.go @@ -1,6 +1,8 @@ package resolvers import ( + "sync" + "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data/typeConverters" @@ -35,6 +37,7 @@ type HeaderResolver struct { headers dataRetriever.HeadersPool hdrNoncesStorage storage.Storer nonceConverter typeConverters.Uint64ByteSliceConverter + mutEpochHandler sync.RWMutex epochHandler dataRetriever.EpochHandler shardCoordinator sharding.Coordinator } @@ -97,7 +100,10 @@ func (hdrRes *HeaderResolver) SetEpochHandler(epochHandler dataRetriever.EpochHa return dataRetriever.ErrNilEpochHandler } + hdrRes.mutEpochHandler.Lock() hdrRes.epochHandler = epochHandler + hdrRes.mutEpochHandler.Unlock() + return nil } @@ -221,7 +227,11 @@ func (hdrRes *HeaderResolver) resolveHeaderFromEpoch(key []byte) ([]byte, error) return nil, err } if isUnknownEpoch { - actualKey = []byte(core.EpochStartIdentifier(hdrRes.epochHandler.MetaEpoch())) + hdrRes.mutEpochHandler.RLock() + metaEpoch := hdrRes.epochHandler.MetaEpoch() + hdrRes.mutEpochHandler.RUnlock() + + actualKey = 
[]byte(core.EpochStartIdentifier(metaEpoch)) } return hdrRes.searchFirst(actualKey) diff --git a/dataRetriever/resolvers/headerResolver_test.go b/dataRetriever/resolvers/headerResolver_test.go index e28c7275c46..5743e20ef75 100644 --- a/dataRetriever/resolvers/headerResolver_test.go +++ b/dataRetriever/resolvers/headerResolver_test.go @@ -3,6 +3,7 @@ package resolvers_test import ( "bytes" "errors" + "sync" "testing" "github.com/multiversx/mx-chain-core-go/core" @@ -163,8 +164,8 @@ func TestHeaderResolver_ProcessReceivedCanProcessMessageErrorsShouldErr(t *testi err := hdrRes.ProcessReceivedMessage(createRequestMsg(dataRetriever.NonceType, nil), fromConnectedPeerId) assert.True(t, errors.Is(err, expectedErr)) - assert.False(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled) - assert.False(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled) + assert.False(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) + assert.False(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) } func TestHeaderResolver_ProcessReceivedMessageNilValueShouldErr(t *testing.T) { @@ -175,8 +176,8 @@ func TestHeaderResolver_ProcessReceivedMessageNilValueShouldErr(t *testing.T) { err := hdrRes.ProcessReceivedMessage(createRequestMsg(dataRetriever.NonceType, nil), fromConnectedPeerId) assert.Equal(t, dataRetriever.ErrNilValue, err) - assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled) - assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled) + assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) + assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) } func TestHeaderResolver_ProcessReceivedMessage_WrongIdentifierStartBlock(t *testing.T) { @@ -188,8 +189,8 @@ func TestHeaderResolver_ProcessReceivedMessage_WrongIdentifierStartBlock(t *test requestedData := []byte("request") err := hdrRes.ProcessReceivedMessage(createRequestMsg(dataRetriever.EpochType, requestedData), "") assert.Equal(t, core.ErrInvalidIdentifierForEpochStartBlockRequest, err) - assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled) - assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled) + assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) + assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) } func TestHeaderResolver_ProcessReceivedMessage_Ok(t *testing.T) { @@ -206,8 +207,8 @@ func TestHeaderResolver_ProcessReceivedMessage_Ok(t *testing.T) { requestedData := []byte("request_1") err := hdrRes.ProcessReceivedMessage(createRequestMsg(dataRetriever.EpochType, requestedData), "") assert.Nil(t, err) - assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled) - assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled) + assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) + assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) } func TestHeaderResolver_ProcessReceivedMessageRequestUnknownTypeShouldErr(t *testing.T) { @@ -218,8 +219,8 @@ func TestHeaderResolver_ProcessReceivedMessageRequestUnknownTypeShouldErr(t *tes err := hdrRes.ProcessReceivedMessage(createRequestMsg(254, make([]byte, 0)), fromConnectedPeerId) assert.Equal(t, dataRetriever.ErrResolveTypeUnknown, err) - assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled) - assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled) + assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) + assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) } func 
TestHeaderResolver_ValidateRequestHashTypeFoundInHdrPoolShouldSearchAndSend(t *testing.T) { @@ -254,8 +255,8 @@ func TestHeaderResolver_ValidateRequestHashTypeFoundInHdrPoolShouldSearchAndSend assert.Nil(t, err) assert.True(t, searchWasCalled) assert.True(t, sendWasCalled) - assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled) - assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled) + assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) + assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) } func TestHeaderResolver_ProcessReceivedMessageRequestHashTypeFoundInHdrPoolMarshalizerFailsShouldErr(t *testing.T) { @@ -295,8 +296,8 @@ func TestHeaderResolver_ProcessReceivedMessageRequestHashTypeFoundInHdrPoolMarsh err := hdrRes.ProcessReceivedMessage(createRequestMsg(dataRetriever.HashType, requestedData), fromConnectedPeerId) assert.Equal(t, errExpected, err) - assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled) - assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled) + assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) + assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) } func TestHeaderResolver_ProcessReceivedMessageRequestRetFromStorageShouldRetValAndSend(t *testing.T) { @@ -337,8 +338,8 @@ func TestHeaderResolver_ProcessReceivedMessageRequestRetFromStorageShouldRetValA assert.Nil(t, err) assert.True(t, wasGotFromStorage) assert.True(t, wasSent) - assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled) - assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled) + assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) + assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) } func TestHeaderResolver_ProcessReceivedMessageRequestNonceTypeInvalidSliceShouldErr(t *testing.T) { @@ -354,8 +355,8 @@ func TestHeaderResolver_ProcessReceivedMessageRequestNonceTypeInvalidSliceShould err := hdrRes.ProcessReceivedMessage(createRequestMsg(dataRetriever.NonceType, []byte("aaa")), fromConnectedPeerId) assert.Equal(t, dataRetriever.ErrInvalidNonceByteSlice, err) - assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled) - assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled) + assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) + assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) } func TestHeaderResolver_ProcessReceivedMessageRequestNonceShouldCallWithTheCorrectEpoch(t *testing.T) { @@ -380,8 +381,8 @@ func TestHeaderResolver_ProcessReceivedMessageRequestNonceShouldCallWithTheCorre ) msg := &mock.P2PMessageMock{DataField: buff} _ = hdrRes.ProcessReceivedMessage(msg, "") - assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled) - assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled) + assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) + assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) } func TestHeaderResolver_ProcessReceivedMessageRequestNonceTypeNotFoundInHdrNoncePoolAndStorageShouldRetNilAndNotSend(t *testing.T) { @@ -427,8 +428,8 @@ func TestHeaderResolver_ProcessReceivedMessageRequestNonceTypeNotFoundInHdrNonce ) assert.Equal(t, expectedErr, err) assert.False(t, wasSent) - assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled) - assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled) + assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) + assert.True(t, 
arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) } func TestHeaderResolver_ProcessReceivedMessageRequestNonceTypeFoundInHdrNoncePoolShouldRetFromPoolAndSend(t *testing.T) { @@ -474,8 +475,8 @@ func TestHeaderResolver_ProcessReceivedMessageRequestNonceTypeFoundInHdrNoncePoo assert.Nil(t, err) assert.True(t, wasResolved) assert.True(t, wasSent) - assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled) - assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled) + assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) + assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) } func TestHeaderResolver_ProcessReceivedMessageRequestNonceTypeFoundInHdrNoncePoolShouldRetFromStorageAndSend(t *testing.T) { @@ -536,8 +537,8 @@ func TestHeaderResolver_ProcessReceivedMessageRequestNonceTypeFoundInHdrNoncePoo assert.Nil(t, err) assert.True(t, wasResolved) assert.True(t, wasSend) - assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled) - assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled) + assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) + assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) } func TestHeaderResolver_ProcessReceivedMessageRequestNonceTypeFoundInHdrNoncePoolCheckRetErr(t *testing.T) { @@ -591,8 +592,8 @@ func TestHeaderResolver_ProcessReceivedMessageRequestNonceTypeFoundInHdrNoncePoo ) assert.Equal(t, errExpected, err) - assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled) - assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled) + assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) + assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) } func TestHeaderResolver_SetEpochHandlerNilShouldErr(t *testing.T) { @@ -627,3 +628,29 @@ func TestHeaderResolver_Close(t *testing.T) { assert.Nil(t, hdrRes.Close()) } + +func TestHeaderResolver_SetEpochHandlerConcurrency(t *testing.T) { + t.Parallel() + + arg := createMockArgHeaderResolver() + hdrRes, _ := resolvers.NewHeaderResolver(arg) + + eh := &mock.EpochHandlerStub{} + var wg sync.WaitGroup + numCalls := 1000 + wg.Add(numCalls) + for i := 0; i < numCalls; i++ { + go func(idx int) { + defer wg.Done() + + if idx == 0 { + err := hdrRes.SetEpochHandler(eh) + assert.Nil(t, err) + return + } + err := hdrRes.ProcessReceivedMessage(createRequestMsg(dataRetriever.EpochType, []byte("request_1")), fromConnectedPeerId) + assert.Nil(t, err) + }(i) + } + wg.Wait() +} diff --git a/dataRetriever/resolvers/miniblockResolver_test.go b/dataRetriever/resolvers/miniblockResolver_test.go index 4e98fc75c3c..94d82e2bf92 100644 --- a/dataRetriever/resolvers/miniblockResolver_test.go +++ b/dataRetriever/resolvers/miniblockResolver_test.go @@ -118,8 +118,8 @@ func TestMiniblockResolver_ProcessReceivedAntifloodErrorsShouldErr(t *testing.T) err := mbRes.ProcessReceivedMessage(createRequestMsg(dataRetriever.HashType, nil), fromConnectedPeerId) assert.True(t, errors.Is(err, expectedErr)) - assert.False(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled) - assert.False(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled) + assert.False(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) + assert.False(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) } func TestMiniblockResolver_ProcessReceivedMessageNilValueShouldErr(t *testing.T) { @@ -130,8 +130,8 @@ func TestMiniblockResolver_ProcessReceivedMessageNilValueShouldErr(t *testing.T) err := mbRes.ProcessReceivedMessage(createRequestMsg(dataRetriever.HashType, 
nil), fromConnectedPeerId) assert.Equal(t, dataRetriever.ErrNilValue, err) - assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled) - assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled) + assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) + assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) } func TestMiniblockResolver_ProcessReceivedMessageWrongTypeShouldErr(t *testing.T) { @@ -143,8 +143,8 @@ func TestMiniblockResolver_ProcessReceivedMessageWrongTypeShouldErr(t *testing.T err := mbRes.ProcessReceivedMessage(createRequestMsg(dataRetriever.NonceType, make([]byte, 0)), fromConnectedPeerId) assert.True(t, errors.Is(err, dataRetriever.ErrRequestTypeNotImplemented)) - assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled) - assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled) + assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) + assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) } func TestMiniblockResolver_ProcessReceivedMessageFoundInPoolShouldRetValAndSend(t *testing.T) { @@ -194,8 +194,8 @@ func TestMiniblockResolver_ProcessReceivedMessageFoundInPoolShouldRetValAndSend( assert.Nil(t, err) assert.True(t, wasResolved) assert.True(t, wasSent) - assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled) - assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled) + assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) + assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) } func TestMiniblockResolver_ProcessReceivedMessageFoundInPoolMarshalizerFailShouldErr(t *testing.T) { @@ -246,8 +246,8 @@ func TestMiniblockResolver_ProcessReceivedMessageFoundInPoolMarshalizerFailShoul ) assert.True(t, errors.Is(err, errExpected)) - assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled) - assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled) + assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) + assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) } func TestMiniblockResolver_ProcessReceivedMessageNotFoundInPoolShouldRetFromStorageAndSend(t *testing.T) { @@ -293,8 +293,8 @@ func TestMiniblockResolver_ProcessReceivedMessageNotFoundInPoolShouldRetFromStor assert.Nil(t, err) assert.True(t, wasResolved) assert.True(t, wasSend) - assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled) - assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled) + assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) + assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) } func TestMiniblockResolver_ProcessReceivedMessageMissingDataShouldNotSend(t *testing.T) { @@ -335,8 +335,8 @@ func TestMiniblockResolver_ProcessReceivedMessageMissingDataShouldNotSend(t *tes ) assert.False(t, wasSent) - assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled) - assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled) + assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) + assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) } func TestMiniblockResolver_Close(t *testing.T) { diff --git a/dataRetriever/resolvers/peerAuthenticationResolver_test.go b/dataRetriever/resolvers/peerAuthenticationResolver_test.go index 25ef0cbb9ec..22b75093a4a 100644 --- a/dataRetriever/resolvers/peerAuthenticationResolver_test.go +++ b/dataRetriever/resolvers/peerAuthenticationResolver_test.go @@ -181,8 +181,8 @@ func 
TestPeerAuthenticationResolver_ProcessReceivedMessage(t *testing.T) { err = res.ProcessReceivedMessage(createRequestMsg(dataRetriever.ChunkType, nil), fromConnectedPeer) assert.True(t, errors.Is(err, expectedErr)) - assert.False(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled) - assert.False(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled) + assert.False(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) + assert.False(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) }) t.Run("parseReceivedMessage returns error due to marshaller error", func(t *testing.T) { t.Parallel() diff --git a/dataRetriever/resolvers/transactionResolver_test.go b/dataRetriever/resolvers/transactionResolver_test.go index 986b37dce41..a1fb37413f6 100644 --- a/dataRetriever/resolvers/transactionResolver_test.go +++ b/dataRetriever/resolvers/transactionResolver_test.go @@ -133,8 +133,8 @@ func TestTxResolver_ProcessReceivedMessageCanProcessMessageErrorsShouldErr(t *te err := txRes.ProcessReceivedMessage(&mock.P2PMessageMock{}, connectedPeerId) assert.True(t, errors.Is(err, expectedErr)) - assert.False(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled) - assert.False(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled) + assert.False(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) + assert.False(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) } func TestTxResolver_ProcessReceivedMessageNilMessageShouldErr(t *testing.T) { @@ -146,8 +146,8 @@ func TestTxResolver_ProcessReceivedMessageNilMessageShouldErr(t *testing.T) { err := txRes.ProcessReceivedMessage(nil, connectedPeerId) assert.Equal(t, dataRetriever.ErrNilMessage, err) - assert.False(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled) - assert.False(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled) + assert.False(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) + assert.False(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) } func TestTxResolver_ProcessReceivedMessageWrongTypeShouldErr(t *testing.T) { @@ -163,8 +163,8 @@ func TestTxResolver_ProcessReceivedMessageWrongTypeShouldErr(t *testing.T) { err := txRes.ProcessReceivedMessage(msg, connectedPeerId) assert.True(t, errors.Is(err, dataRetriever.ErrRequestTypeNotImplemented)) - assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled) - assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled) + assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) + assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) } func TestTxResolver_ProcessReceivedMessageNilValueShouldErr(t *testing.T) { @@ -180,8 +180,8 @@ func TestTxResolver_ProcessReceivedMessageNilValueShouldErr(t *testing.T) { err := txRes.ProcessReceivedMessage(msg, connectedPeerId) assert.Equal(t, dataRetriever.ErrNilValue, err) - assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled) - assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled) + assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) + assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) } func TestTxResolver_ProcessReceivedMessageFoundInTxPoolShouldSearchAndSend(t *testing.T) { @@ -222,8 +222,8 @@ func TestTxResolver_ProcessReceivedMessageFoundInTxPoolShouldSearchAndSend(t *te assert.Nil(t, err) assert.True(t, searchWasCalled) assert.True(t, sendWasCalled) - assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled) - assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled) + assert.True(t, 
arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) + assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) } func TestTxResolver_ProcessReceivedMessageFoundInTxPoolMarshalizerFailShouldRetNilAndErr(t *testing.T) { @@ -264,8 +264,8 @@ func TestTxResolver_ProcessReceivedMessageFoundInTxPoolMarshalizerFailShouldRetN err := txRes.ProcessReceivedMessage(msg, connectedPeerId) assert.True(t, errors.Is(err, errExpected)) - assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled) - assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled) + assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) + assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) } func TestTxResolver_ProcessReceivedMessageFoundInTxStorageShouldRetValAndSend(t *testing.T) { @@ -314,8 +314,8 @@ func TestTxResolver_ProcessReceivedMessageFoundInTxStorageShouldRetValAndSend(t assert.Nil(t, err) assert.True(t, searchWasCalled) assert.True(t, sendWasCalled) - assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled) - assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled) + assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) + assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) } func TestTxResolver_ProcessReceivedMessageFoundInTxStorageCheckRetError(t *testing.T) { @@ -352,8 +352,8 @@ func TestTxResolver_ProcessReceivedMessageFoundInTxStorageCheckRetError(t *testi err := txRes.ProcessReceivedMessage(msg, connectedPeerId) assert.True(t, errors.Is(err, errExpected)) - assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled) - assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled) + assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) + assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) } func TestTxResolver_ProcessReceivedMessageRequestedTwoSmallTransactionsShouldCallSliceSplitter(t *testing.T) { @@ -414,8 +414,8 @@ func TestTxResolver_ProcessReceivedMessageRequestedTwoSmallTransactionsShouldCal assert.Nil(t, err) assert.True(t, splitSliceWasCalled) assert.True(t, sendWasCalled) - assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled) - assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled) + assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) + assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) } func TestTxResolver_ProcessReceivedMessageRequestedTwoSmallTransactionsFoundOnlyOneShouldWork(t *testing.T) { @@ -475,8 +475,8 @@ func TestTxResolver_ProcessReceivedMessageRequestedTwoSmallTransactionsFoundOnly assert.NotNil(t, err) assert.True(t, splitSliceWasCalled) assert.True(t, sendWasCalled) - assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled) - assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled) + assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) + assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) } func TestTxResolver_Close(t *testing.T) { diff --git a/dataRetriever/resolvers/trieNodeResolver_test.go b/dataRetriever/resolvers/trieNodeResolver_test.go index ae6829e9fff..ac6e252d248 100644 --- a/dataRetriever/resolvers/trieNodeResolver_test.go +++ b/dataRetriever/resolvers/trieNodeResolver_test.go @@ -109,8 +109,8 @@ func TestTrieNodeResolver_ProcessReceivedAntiflooderCanProcessMessageErrShouldEr err := tnRes.ProcessReceivedMessage(&mock.P2PMessageMock{}, fromConnectedPeer) assert.True(t, errors.Is(err, expectedErr)) - assert.False(t, 
arg.Throttler.(*mock.ThrottlerStub).StartWasCalled) - assert.False(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled) + assert.False(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) + assert.False(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) } func TestTrieNodeResolver_ProcessReceivedMessageNilMessageShouldErr(t *testing.T) { @@ -121,8 +121,8 @@ func TestTrieNodeResolver_ProcessReceivedMessageNilMessageShouldErr(t *testing.T err := tnRes.ProcessReceivedMessage(nil, fromConnectedPeer) assert.Equal(t, dataRetriever.ErrNilMessage, err) - assert.False(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled) - assert.False(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled) + assert.False(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) + assert.False(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) } func TestTrieNodeResolver_ProcessReceivedMessageWrongTypeShouldErr(t *testing.T) { @@ -138,8 +138,8 @@ func TestTrieNodeResolver_ProcessReceivedMessageWrongTypeShouldErr(t *testing.T) err := tnRes.ProcessReceivedMessage(msg, fromConnectedPeer) assert.Equal(t, dataRetriever.ErrRequestTypeNotImplemented, err) - assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled) - assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled) + assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) + assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) } func TestTrieNodeResolver_ProcessReceivedMessageNilValueShouldErr(t *testing.T) { @@ -155,8 +155,8 @@ func TestTrieNodeResolver_ProcessReceivedMessageNilValueShouldErr(t *testing.T) err := tnRes.ProcessReceivedMessage(msg, fromConnectedPeer) assert.Equal(t, dataRetriever.ErrNilValue, err) - assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled) - assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled) + assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) + assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) } //TODO in this PR: add more unit tests @@ -197,8 +197,8 @@ func TestTrieNodeResolver_ProcessReceivedMessageShouldGetFromTrieAndSend(t *test assert.Nil(t, err) assert.True(t, getSerializedNodesWasCalled) assert.True(t, sendWasCalled) - assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled) - assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled) + assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) + assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) } func TestTrieNodeResolver_ProcessReceivedMessageShouldGetFromTrieAndMarshalizerFailShouldRetNilAndErr(t *testing.T) { @@ -224,8 +224,8 @@ func TestTrieNodeResolver_ProcessReceivedMessageShouldGetFromTrieAndMarshalizerF err := tnRes.ProcessReceivedMessage(msg, fromConnectedPeer) assert.Equal(t, errExpected, err) - assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled) - assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled) + assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) + assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) } func TestTrieNodeResolver_ProcessReceivedMessageTrieErrorsShouldErr(t *testing.T) { @@ -244,8 +244,8 @@ func TestTrieNodeResolver_ProcessReceivedMessageTrieErrorsShouldErr(t *testing.T err := tnRes.ProcessReceivedMessage(msg, fromConnectedPeer) assert.Equal(t, expectedErr, err) - assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled) - assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled) + assert.True(t, 
arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) + assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) } func TestTrieNodeResolver_ProcessReceivedMessageMultipleHashesGetSerializedNodeErrorsShouldNotSend(t *testing.T) { @@ -280,8 +280,8 @@ func TestTrieNodeResolver_ProcessReceivedMessageMultipleHashesGetSerializedNodeE err := tnRes.ProcessReceivedMessage(msg, fromConnectedPeer) assert.Nil(t, err) - assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled) - assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled) + assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) + assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) } func TestTrieNodeResolver_ProcessReceivedMessageMultipleHashesGetSerializedNodesErrorsShouldNotSendSubtrie(t *testing.T) { @@ -333,8 +333,8 @@ func TestTrieNodeResolver_ProcessReceivedMessageMultipleHashesGetSerializedNodes err := tnRes.ProcessReceivedMessage(msg, fromConnectedPeer) assert.Nil(t, err) - assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled) - assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled) + assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) + assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) require.Equal(t, 1, len(receivedNodes)) assert.Equal(t, nodes[0], receivedNodes[0]) } @@ -389,8 +389,8 @@ func TestTrieNodeResolver_ProcessReceivedMessageMultipleHashesNotEnoughSpaceShou err := tnRes.ProcessReceivedMessage(msg, fromConnectedPeer) assert.Nil(t, err) - assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled) - assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled) + assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) + assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) require.Equal(t, 1, len(receivedNodes)) assert.Equal(t, nodes[0], receivedNodes[0]) } @@ -450,8 +450,8 @@ func TestTrieNodeResolver_ProcessReceivedMessageMultipleHashesShouldWorkWithSubt err := tnRes.ProcessReceivedMessage(msg, fromConnectedPeer) assert.Nil(t, err) - assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled) - assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled) + assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) + assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) require.Equal(t, 4, len(receivedNodes)) for _, n := range nodes { assert.True(t, buffInSlice(n, receivedNodes)) @@ -516,8 +516,8 @@ func testTrieNodeResolverProcessReceivedMessageLargeTrieNode( err := tnRes.ProcessReceivedMessage(msg, fromConnectedPeer) assert.Nil(t, err) - assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled) - assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled) + assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) + assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) require.True(t, sendWasCalled) } diff --git a/dataRetriever/resolvers/validatorInfoResolver_test.go b/dataRetriever/resolvers/validatorInfoResolver_test.go index b25c3abe533..88d115de3cb 100644 --- a/dataRetriever/resolvers/validatorInfoResolver_test.go +++ b/dataRetriever/resolvers/validatorInfoResolver_test.go @@ -145,7 +145,6 @@ func TestValidatorInfoResolver_ProcessReceivedMessage(t *testing.T) { t.Run("canProcessMessage due to antiflood handler error", func(t *testing.T) { t.Parallel() - expectedErr := errors.New("expected err") args := createMockArgValidatorInfoResolver() args.AntifloodHandler = &mock.P2PAntifloodHandlerStub{ 
CanProcessMessageCalled: func(message p2p.MessageP2P, fromConnectedPeer core.PeerID) error { @@ -157,13 +156,12 @@ func TestValidatorInfoResolver_ProcessReceivedMessage(t *testing.T) { err := res.ProcessReceivedMessage(createRequestMsg(dataRetriever.HashType, nil), fromConnectedPeer) assert.True(t, errors.Is(err, expectedErr)) - assert.False(t, args.Throttler.(*mock.ThrottlerStub).StartWasCalled) - assert.False(t, args.Throttler.(*mock.ThrottlerStub).EndWasCalled) + assert.False(t, args.Throttler.(*mock.ThrottlerStub).StartWasCalled()) + assert.False(t, args.Throttler.(*mock.ThrottlerStub).EndWasCalled()) }) t.Run("parseReceivedMessage returns error due to marshalizer error", func(t *testing.T) { t.Parallel() - expectedErr := errors.New("expected err") args := createMockArgValidatorInfoResolver() args.Marshaller = &mock.MarshalizerStub{ UnmarshalCalled: func(obj interface{}, buff []byte) error { @@ -191,7 +189,6 @@ func TestValidatorInfoResolver_ProcessReceivedMessage(t *testing.T) { t.Run("data not found in cache and fetchValidatorInfoByteSlice fails when getting data from storage", func(t *testing.T) { t.Parallel() - expectedErr := errors.New("expected err") args := createMockArgValidatorInfoResolver() args.ValidatorInfoPool = &testscommon.ShardedDataStub{ SearchFirstDataCalled: func(key []byte) (value interface{}, ok bool) { @@ -212,7 +209,6 @@ func TestValidatorInfoResolver_ProcessReceivedMessage(t *testing.T) { t.Run("data found in cache but marshal fails", func(t *testing.T) { t.Parallel() - expectedErr := errors.New("expected err") marshallerMock := testscommon.MarshalizerMock{} args := createMockArgValidatorInfoResolver() args.ValidatorInfoPool = &testscommon.ShardedDataStub{ @@ -237,7 +233,6 @@ func TestValidatorInfoResolver_ProcessReceivedMessage(t *testing.T) { t.Run("data found in storage but marshal fails", func(t *testing.T) { t.Parallel() - expectedErr := errors.New("expected err") marshallerMock := testscommon.MarshalizerMock{} args := createMockArgValidatorInfoResolver() args.ValidatorInfoPool = &testscommon.ShardedDataStub{ @@ -341,7 +336,6 @@ func TestValidatorInfoResolver_ProcessReceivedMessage(t *testing.T) { t.Run("unmarshal fails", func(t *testing.T) { t.Parallel() - expectedErr := errors.New("expected err") args := createMockArgValidatorInfoResolver() args.Marshaller = &testscommon.MarshalizerStub{ UnmarshalCalled: func(obj interface{}, buff []byte) error { @@ -388,7 +382,6 @@ func TestValidatorInfoResolver_ProcessReceivedMessage(t *testing.T) { t.Run("pack data in chuncks returns error", func(t *testing.T) { t.Parallel() - expectedErr := errors.New("expected err") args := createMockArgValidatorInfoResolver() args.ValidatorInfoPool = &testscommon.ShardedDataStub{ SearchFirstDataCalled: func(key []byte) (value interface{}, ok bool) { @@ -507,9 +500,9 @@ func TestValidatorInfoResolver_ProcessReceivedMessage(t *testing.T) { _ = marshallerMock.Unmarshal(vi, b.Data[i]) // remove this info from the provided map - buff, err := testMarshaller.Marshal(vi) + validatorInfoBuff, err := testMarshaller.Marshal(vi) require.Nil(t, err) - hash := testHasher.Compute(string(buff)) + hash := testHasher.Compute(string(validatorInfoBuff)) delete(providedDataMap, string(hash)) } diff --git a/dataRetriever/storageRequesters/headerRequester.go b/dataRetriever/storageRequesters/headerRequester.go index b9da91de630..0545ff5bccf 100644 --- a/dataRetriever/storageRequesters/headerRequester.go +++ b/dataRetriever/storageRequesters/headerRequester.go @@ -14,6 +14,8 @@ import ( logger 
"github.com/multiversx/mx-chain-logger-go" ) +var _ dataRetriever.HeaderRequester = (*headerRequester)(nil) + var log = logger.GetOrCreate("dataretriever/storagerequesters") // ArgHeaderRequester is the argument structure used to create new headerRequester instance diff --git a/epochStart/bootstrap/common.go b/epochStart/bootstrap/common.go index 0db2481a289..7b1bcec3c88 100644 --- a/epochStart/bootstrap/common.go +++ b/epochStart/bootstrap/common.go @@ -109,6 +109,9 @@ func checkArguments(args ArgsEpochStartBootstrap) error { if args.GeneralConfig.TrieSync.NumConcurrentTrieSyncers < 1 { return fmt.Errorf("%s: %w", baseErrorMessage, epochStart.ErrInvalidNumConcurrentTrieSyncers) } + if check.IfNil(args.CryptoComponentsHolder.ManagedPeersHolder()) { + return fmt.Errorf("%s: %w", baseErrorMessage, epochStart.ErrNilManagedPeersHolder) + } return nil } diff --git a/epochStart/bootstrap/fromLocalStorage.go b/epochStart/bootstrap/fromLocalStorage.go index 8b90d0d21a8..6905812d935 100644 --- a/epochStart/bootstrap/fromLocalStorage.go +++ b/epochStart/bootstrap/fromLocalStorage.go @@ -111,6 +111,7 @@ func (e *epochStartBootstrap) prepareEpochFromStorage() (Parameters, error) { e.closeTrieComponents() e.storageService = disabled.NewChainStorer() triesContainer, trieStorageManagers, err := factory.CreateTriesComponentsForShardId( + e.flagsConfig.SnapshotsEnabled, e.generalConfig, e.coreComponentsHolder, e.storageService, diff --git a/epochStart/bootstrap/interface.go b/epochStart/bootstrap/interface.go index 34442890b5e..e934e450f7c 100644 --- a/epochStart/bootstrap/interface.go +++ b/epochStart/bootstrap/interface.go @@ -41,7 +41,9 @@ type Messenger interface { ConnectedPeers() []core.PeerID Verify(payload []byte, pid core.PeerID, signature []byte) error Broadcast(topic string, buff []byte) + BroadcastUsingPrivateKey(topic string, buff []byte, pid core.PeerID, skBytes []byte) Sign(payload []byte) ([]byte, error) + SignUsingPrivateKey(skBytes []byte, payload []byte) ([]byte, error) } // RequestHandler defines which methods a request handler should implement diff --git a/epochStart/bootstrap/metaStorageHandler.go b/epochStart/bootstrap/metaStorageHandler.go index 53c77c23fd0..6872afbfdb6 100644 --- a/epochStart/bootstrap/metaStorageHandler.go +++ b/epochStart/bootstrap/metaStorageHandler.go @@ -36,6 +36,8 @@ func NewMetaStorageHandler( currentEpoch uint32, uint64Converter typeConverters.Uint64ByteSliceConverter, nodeTypeProvider NodeTypeProviderHandler, + snapshotsEnabled bool, + managedPeersHolder common.ManagedPeersHolder, ) (*metaStorageHandler, error) { epochStartNotifier := &disabled.EpochStartNotifier{} storageFactory, err := factory.NewStorageServiceFactory( @@ -49,6 +51,8 @@ func NewMetaStorageHandler( CurrentEpoch: currentEpoch, StorageType: factory.BootstrapStorageService, CreateTrieEpochRootHashStorer: false, + SnapshotsEnabled: snapshotsEnabled, + ManagedPeersHolder: managedPeersHolder, }, ) if err != nil { diff --git a/epochStart/bootstrap/metaStorageHandler_test.go b/epochStart/bootstrap/metaStorageHandler_test.go index 732c617304f..b5e897e3c16 100644 --- a/epochStart/bootstrap/metaStorageHandler_test.go +++ b/epochStart/bootstrap/metaStorageHandler_test.go @@ -32,8 +32,9 @@ func TestNewMetaStorageHandler_InvalidConfigErr(t *testing.T) { hasher := &hashingMocks.HasherMock{} uit64Cvt := &mock.Uint64ByteSliceConverterMock{} nodeTypeProvider := &nodeTypeProviderMock.NodeTypeProviderStub{} + managedPeersHolder := &testscommon.ManagedPeersHolderStub{} - mtStrHandler, err := 
NewMetaStorageHandler(gCfg, prefsConfig, coordinator, pathManager, marshalizer, hasher, 1, uit64Cvt, nodeTypeProvider) + mtStrHandler, err := NewMetaStorageHandler(gCfg, prefsConfig, coordinator, pathManager, marshalizer, hasher, 1, uit64Cvt, nodeTypeProvider, false, managedPeersHolder) assert.True(t, check.IfNil(mtStrHandler)) assert.NotNil(t, err) } @@ -51,7 +52,8 @@ func TestNewMetaStorageHandler_CreateForMetaErr(t *testing.T) { hasher := &hashingMocks.HasherMock{} uit64Cvt := &mock.Uint64ByteSliceConverterMock{} nodeTypeProvider := &nodeTypeProviderMock.NodeTypeProviderStub{} - mtStrHandler, err := NewMetaStorageHandler(gCfg, prefsConfig, coordinator, pathManager, marshalizer, hasher, 1, uit64Cvt, nodeTypeProvider) + managedPeersHolder := &testscommon.ManagedPeersHolderStub{} + mtStrHandler, err := NewMetaStorageHandler(gCfg, prefsConfig, coordinator, pathManager, marshalizer, hasher, 1, uit64Cvt, nodeTypeProvider, false, managedPeersHolder) assert.False(t, check.IfNil(mtStrHandler)) assert.Nil(t, err) } @@ -69,8 +71,9 @@ func TestMetaStorageHandler_saveLastHeader(t *testing.T) { hasher := &hashingMocks.HasherMock{} uit64Cvt := &mock.Uint64ByteSliceConverterMock{} nodeTypeProvider := &nodeTypeProviderMock.NodeTypeProviderStub{} + managedPeersHolder := &testscommon.ManagedPeersHolderStub{} - mtStrHandler, _ := NewMetaStorageHandler(gCfg, prefsConfig, coordinator, pathManager, marshalizer, hasher, 1, uit64Cvt, nodeTypeProvider) + mtStrHandler, _ := NewMetaStorageHandler(gCfg, prefsConfig, coordinator, pathManager, marshalizer, hasher, 1, uit64Cvt, nodeTypeProvider, false, managedPeersHolder) header := &block.MetaBlock{Nonce: 0} @@ -97,8 +100,9 @@ func TestMetaStorageHandler_saveLastCrossNotarizedHeaders(t *testing.T) { hasher := &hashingMocks.HasherMock{} uit64Cvt := &mock.Uint64ByteSliceConverterMock{} nodeTypeProvider := &nodeTypeProviderMock.NodeTypeProviderStub{} + managedPeersHolder := &testscommon.ManagedPeersHolderStub{} - mtStrHandler, _ := NewMetaStorageHandler(gCfg, prefsConfig, coordinator, pathManager, marshalizer, hasher, 1, uit64Cvt, nodeTypeProvider) + mtStrHandler, _ := NewMetaStorageHandler(gCfg, prefsConfig, coordinator, pathManager, marshalizer, hasher, 1, uit64Cvt, nodeTypeProvider, false, managedPeersHolder) hdr1 := &block.Header{Nonce: 1} hdr2 := &block.Header{Nonce: 2} @@ -131,8 +135,9 @@ func TestMetaStorageHandler_saveTriggerRegistry(t *testing.T) { hasher := &hashingMocks.HasherMock{} uit64Cvt := &mock.Uint64ByteSliceConverterMock{} nodeTypeProvider := &nodeTypeProviderMock.NodeTypeProviderStub{} + managedPeersHolder := &testscommon.ManagedPeersHolderStub{} - mtStrHandler, _ := NewMetaStorageHandler(gCfg, prefsConfig, coordinator, pathManager, marshalizer, hasher, 1, uit64Cvt, nodeTypeProvider) + mtStrHandler, _ := NewMetaStorageHandler(gCfg, prefsConfig, coordinator, pathManager, marshalizer, hasher, 1, uit64Cvt, nodeTypeProvider, false, managedPeersHolder) components := &ComponentsNeededForBootstrap{ EpochStartMetaBlock: &block.MetaBlock{Nonce: 3}, @@ -156,8 +161,9 @@ func TestMetaStorageHandler_saveDataToStorage(t *testing.T) { hasher := &hashingMocks.HasherMock{} uit64Cvt := &mock.Uint64ByteSliceConverterMock{} nodeTypeProvider := &nodeTypeProviderMock.NodeTypeProviderStub{} + managedPeersHolder := &testscommon.ManagedPeersHolderStub{} - mtStrHandler, _ := NewMetaStorageHandler(gCfg, prefsConfig, coordinator, pathManager, marshalizer, hasher, 1, uit64Cvt, nodeTypeProvider) + mtStrHandler, _ := NewMetaStorageHandler(gCfg, prefsConfig, coordinator, 
pathManager, marshalizer, hasher, 1, uit64Cvt, nodeTypeProvider, false, managedPeersHolder) components := &ComponentsNeededForBootstrap{ EpochStartMetaBlock: &block.MetaBlock{Nonce: 3}, @@ -198,8 +204,9 @@ func testMetaWithMissingStorer(missingUnit dataRetriever.UnitType, atCallNumber hasher := &hashingMocks.HasherMock{} uit64Cvt := &mock.Uint64ByteSliceConverterMock{} nodeTypeProvider := &nodeTypeProviderMock.NodeTypeProviderStub{} + managedPeersHolder := &testscommon.ManagedPeersHolderStub{} - mtStrHandler, _ := NewMetaStorageHandler(gCfg, prefsConfig, coordinator, pathManager, marshalizer, hasher, 1, uit64Cvt, nodeTypeProvider) + mtStrHandler, _ := NewMetaStorageHandler(gCfg, prefsConfig, coordinator, pathManager, marshalizer, hasher, 1, uit64Cvt, nodeTypeProvider, false, managedPeersHolder) counter := 0 mtStrHandler.storageService = &storageStubs.ChainStorerStub{ GetStorerCalled: func(unitType dataRetriever.UnitType) (storage.Storer, error) { diff --git a/epochStart/bootstrap/process.go b/epochStart/bootstrap/process.go index 3a90bfdab3a..5f98fe6d80b 100644 --- a/epochStart/bootstrap/process.go +++ b/epochStart/bootstrap/process.go @@ -483,6 +483,7 @@ func (e *epochStartBootstrap) prepareComponentsToSyncFromNetwork() error { e.closeTrieComponents() e.storageService = disabled.NewChainStorer() triesContainer, trieStorageManagers, err := factory.CreateTriesComponentsForShardId( + e.flagsConfig.SnapshotsEnabled, e.generalConfig, e.coreComponentsHolder, e.storageService, @@ -762,6 +763,8 @@ func (e *epochStartBootstrap) requestAndProcessForMeta(peerMiniBlocks []*block.M e.epochStartMeta.GetEpoch(), e.coreComponentsHolder.Uint64ByteSliceConverter(), e.coreComponentsHolder.NodeTypeProvider(), + e.flagsConfig.SnapshotsEnabled, + e.cryptoComponentsHolder.ManagedPeersHolder(), ) if err != nil { return err @@ -771,6 +774,7 @@ func (e *epochStartBootstrap) requestAndProcessForMeta(peerMiniBlocks []*block.M e.closeTrieComponents() triesContainer, trieStorageManagers, err := factory.CreateTriesComponentsForShardId( + e.flagsConfig.SnapshotsEnabled, e.generalConfig, e.coreComponentsHolder, storageHandlerComponent.storageService, @@ -928,6 +932,8 @@ func (e *epochStartBootstrap) requestAndProcessForShard(peerMiniBlocks []*block. e.baseData.lastEpoch, e.coreComponentsHolder.Uint64ByteSliceConverter(), e.coreComponentsHolder.NodeTypeProvider(), + e.flagsConfig.SnapshotsEnabled, + e.cryptoComponentsHolder.ManagedPeersHolder(), ) if err != nil { return err @@ -937,6 +943,7 @@ func (e *epochStartBootstrap) requestAndProcessForShard(peerMiniBlocks []*block. 
e.closeTrieComponents() triesContainer, trieStorageManagers, err := factory.CreateTriesComponentsForShardId( + e.flagsConfig.SnapshotsEnabled, e.generalConfig, e.coreComponentsHolder, storageHandlerComponent.storageService, @@ -1109,6 +1116,8 @@ func (e *epochStartBootstrap) createStorageService( CurrentEpoch: startEpoch, StorageType: storageFactory.BootstrapStorageService, CreateTrieEpochRootHashStorer: createTrieEpochRootHashStorer, + SnapshotsEnabled: e.flagsConfig.SnapshotsEnabled, + ManagedPeersHolder: e.cryptoComponentsHolder.ManagedPeersHolder(), }) if err != nil { return nil, err diff --git a/epochStart/bootstrap/process_test.go b/epochStart/bootstrap/process_test.go index 7426494c896..6ac7b8899ee 100644 --- a/epochStart/bootstrap/process_test.go +++ b/epochStart/bootstrap/process_test.go @@ -42,7 +42,7 @@ import ( statusHandlerMock "github.com/multiversx/mx-chain-go/testscommon/statusHandler" storageMocks "github.com/multiversx/mx-chain-go/testscommon/storage" "github.com/multiversx/mx-chain-go/testscommon/syncer" - "github.com/multiversx/mx-chain-go/testscommon/validatorInfoCacher" + validatorInfoCacherStub "github.com/multiversx/mx-chain-go/testscommon/validatorInfoCacher" "github.com/multiversx/mx-chain-go/trie/factory" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -87,6 +87,7 @@ func createComponentsForEpochStart() (*mock.CoreComponentsMock, *mock.CryptoComp BlKeyGen: &cryptoMocks.KeyGenStub{}, TxKeyGen: &cryptoMocks.KeyGenStub{}, PeerSignHandler: &cryptoMocks.PeerSignatureHandlerStub{}, + ManagedPeers: &testscommon.ManagedPeersHolderStub{}, } } @@ -139,7 +140,6 @@ func createMockEpochStartBootstrapArgs( }, StateTriesConfig: config.StateTriesConfig{ CheckpointRoundsModulus: 5, - SnapshotsEnabled: true, AccountsStatePruningEnabled: true, PeerStatePruningEnabled: true, MaxStateTrieLevelInMemory: 5, @@ -584,6 +584,17 @@ func TestNewEpochStartBootstrap_NilArgsChecks(t *testing.T) { assert.Equal(t, storage.ErrNotSupportedCacheType, err) assert.Nil(t, epochStartProvider) }) + t.Run("nil managed peers holder", func(t *testing.T) { + t.Parallel() + + coreComp, cryptoComp := createComponentsForEpochStart() + cryptoComp.ManagedPeers = nil + args := createMockEpochStartBootstrapArgs(coreComp, cryptoComp) + + epochStartProvider, err := NewEpochStartBootstrap(args) + require.Nil(t, epochStartProvider) + require.True(t, errors.Is(err, epochStart.ErrNilManagedPeersHolder)) + }) } func TestNewEpochStartBootstrap(t *testing.T) { @@ -992,6 +1003,7 @@ func TestSyncValidatorAccountsState_NilRequestHandlerErr(t *testing.T) { }, } triesContainer, trieStorageManagers, err := factory.CreateTriesComponentsForShardId( + false, args.GeneralConfig, coreComp, disabled.NewChainStorer(), @@ -1011,6 +1023,7 @@ func TestCreateTriesForNewShardID(t *testing.T) { args.GeneralConfig = testscommon.GetGeneralConfig() triesContainer, trieStorageManagers, err := factory.CreateTriesComponentsForShardId( + false, args.GeneralConfig, coreComp, disabled.NewChainStorer(), @@ -1037,6 +1050,7 @@ func TestSyncUserAccountsState(t *testing.T) { } triesContainer, trieStorageManagers, err := factory.CreateTriesComponentsForShardId( + false, args.GeneralConfig, coreComp, disabled.NewChainStorer(), diff --git a/epochStart/bootstrap/shardStorageHandler.go b/epochStart/bootstrap/shardStorageHandler.go index 79d2993d204..44fedbc8bf8 100644 --- a/epochStart/bootstrap/shardStorageHandler.go +++ b/epochStart/bootstrap/shardStorageHandler.go @@ -40,6 +40,8 @@ func NewShardStorageHandler( currentEpoch 
uint32, uint64Converter typeConverters.Uint64ByteSliceConverter, nodeTypeProvider core.NodeTypeProviderHandler, + snapshotsEnabled bool, + managedPeersHolder common.ManagedPeersHolder, ) (*shardStorageHandler, error) { epochStartNotifier := &disabled.EpochStartNotifier{} storageFactory, err := factory.NewStorageServiceFactory( @@ -53,6 +55,8 @@ func NewShardStorageHandler( CurrentEpoch: currentEpoch, StorageType: factory.BootstrapStorageService, CreateTrieEpochRootHashStorer: false, + SnapshotsEnabled: snapshotsEnabled, + ManagedPeersHolder: managedPeersHolder, }, ) if err != nil { diff --git a/epochStart/bootstrap/shardStorageHandler_test.go b/epochStart/bootstrap/shardStorageHandler_test.go index 61adf0d4921..3405cca7f57 100644 --- a/epochStart/bootstrap/shardStorageHandler_test.go +++ b/epochStart/bootstrap/shardStorageHandler_test.go @@ -16,6 +16,7 @@ import ( "github.com/multiversx/mx-chain-core-go/data/typeConverters" "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/epochStart" @@ -39,7 +40,7 @@ func TestNewShardStorageHandler_ShouldWork(t *testing.T) { }() args := createDefaultShardStorageArgs() - shardStorage, err := NewShardStorageHandler(args.generalConfig, args.prefsConfig, args.shardCoordinator, args.pathManagerHandler, args.marshalizer, args.hasher, 1, args.uint64Converter, args.nodeTypeProvider) + shardStorage, err := NewShardStorageHandler(args.generalConfig, args.prefsConfig, args.shardCoordinator, args.pathManagerHandler, args.marshalizer, args.hasher, 1, args.uint64Converter, args.nodeTypeProvider, false, args.managedPeersHolder) assert.False(t, check.IfNil(shardStorage)) assert.Nil(t, err) @@ -51,7 +52,7 @@ func TestShardStorageHandler_SaveDataToStorageShardDataNotFound(t *testing.T) { }() args := createDefaultShardStorageArgs() - shardStorage, _ := NewShardStorageHandler(args.generalConfig, args.prefsConfig, args.shardCoordinator, args.pathManagerHandler, args.marshalizer, args.hasher, 1, args.uint64Converter, args.nodeTypeProvider) + shardStorage, _ := NewShardStorageHandler(args.generalConfig, args.prefsConfig, args.shardCoordinator, args.pathManagerHandler, args.marshalizer, args.hasher, 1, args.uint64Converter, args.nodeTypeProvider, false, args.managedPeersHolder) components := &ComponentsNeededForBootstrap{ EpochStartMetaBlock: &block.MetaBlock{Epoch: 1}, @@ -69,7 +70,7 @@ func TestShardStorageHandler_SaveDataToStorageMissingHeader(t *testing.T) { }() args := createDefaultShardStorageArgs() - shardStorage, _ := NewShardStorageHandler(args.generalConfig, args.prefsConfig, args.shardCoordinator, args.pathManagerHandler, args.marshalizer, args.hasher, 1, args.uint64Converter, args.nodeTypeProvider) + shardStorage, _ := NewShardStorageHandler(args.generalConfig, args.prefsConfig, args.shardCoordinator, args.pathManagerHandler, args.marshalizer, args.hasher, 1, args.uint64Converter, args.nodeTypeProvider, false, args.managedPeersHolder) components := &ComponentsNeededForBootstrap{ EpochStartMetaBlock: &block.MetaBlock{ @@ -110,7 +111,7 @@ func testShardWithMissingStorer(missingUnit dataRetriever.UnitType, atCallNumber counter := 0 args := createDefaultShardStorageArgs() - shardStorage, _ := NewShardStorageHandler(args.generalConfig, args.prefsConfig, args.shardCoordinator, args.pathManagerHandler, args.marshalizer, args.hasher, 1, 
args.uint64Converter, args.nodeTypeProvider) + shardStorage, _ := NewShardStorageHandler(args.generalConfig, args.prefsConfig, args.shardCoordinator, args.pathManagerHandler, args.marshalizer, args.hasher, 1, args.uint64Converter, args.nodeTypeProvider, false, args.managedPeersHolder) shardStorage.storageService = &storageStubs.ChainStorerStub{ GetStorerCalled: func(unitType dataRetriever.UnitType) (storage.Storer, error) { counter++ @@ -152,7 +153,7 @@ func TestShardStorageHandler_SaveDataToStorage(t *testing.T) { }() args := createDefaultShardStorageArgs() - shardStorage, _ := NewShardStorageHandler(args.generalConfig, args.prefsConfig, args.shardCoordinator, args.pathManagerHandler, args.marshalizer, args.hasher, 1, args.uint64Converter, args.nodeTypeProvider) + shardStorage, _ := NewShardStorageHandler(args.generalConfig, args.prefsConfig, args.shardCoordinator, args.pathManagerHandler, args.marshalizer, args.hasher, 1, args.uint64Converter, args.nodeTypeProvider, false, args.managedPeersHolder) hash1 := []byte("hash1") hdr1 := block.MetaBlock{ @@ -251,7 +252,7 @@ func TestShardStorageHandler_getCrossProcessedMiniBlockHeadersDestMe(t *testing. mbs := append(intraMbs, crossMbs...) args := createDefaultShardStorageArgs() - shardStorage, _ := NewShardStorageHandler(args.generalConfig, args.prefsConfig, args.shardCoordinator, args.pathManagerHandler, args.marshalizer, args.hasher, 1, args.uint64Converter, args.nodeTypeProvider) + shardStorage, _ := NewShardStorageHandler(args.generalConfig, args.prefsConfig, args.shardCoordinator, args.pathManagerHandler, args.marshalizer, args.hasher, 1, args.uint64Converter, args.nodeTypeProvider, false, args.managedPeersHolder) shardHeader := &block.Header{ Nonce: 100, MiniBlockHeaders: mbs, @@ -271,7 +272,7 @@ func TestShardStorageHandler_getProcessedAndPendingMiniBlocksWithScheduledErrorG t.Parallel() args := createDefaultShardStorageArgs() - shardStorage, _ := NewShardStorageHandler(args.generalConfig, args.prefsConfig, args.shardCoordinator, args.pathManagerHandler, args.marshalizer, args.hasher, 1, args.uint64Converter, args.nodeTypeProvider) + shardStorage, _ := NewShardStorageHandler(args.generalConfig, args.prefsConfig, args.shardCoordinator, args.pathManagerHandler, args.marshalizer, args.hasher, 1, args.uint64Converter, args.nodeTypeProvider, false, args.managedPeersHolder) meta := &block.MetaBlock{ Nonce: 100, EpochStart: block.EpochStart{}, @@ -289,7 +290,7 @@ func TestShardStorageHandler_getProcessedAndPendingMiniBlocksWithScheduledNoSche t.Parallel() args := createDefaultShardStorageArgs() - shardStorage, _ := NewShardStorageHandler(args.generalConfig, args.prefsConfig, args.shardCoordinator, args.pathManagerHandler, args.marshalizer, args.hasher, 1, args.uint64Converter, args.nodeTypeProvider) + shardStorage, _ := NewShardStorageHandler(args.generalConfig, args.prefsConfig, args.shardCoordinator, args.pathManagerHandler, args.marshalizer, args.hasher, 1, args.uint64Converter, args.nodeTypeProvider, false, args.managedPeersHolder) scenario := createPendingAndProcessedMiniBlocksScenario() processedMiniBlocks, pendingMiniBlocks, err := shardStorage.getProcessedAndPendingMiniBlocksWithScheduled(scenario.metaBlock, scenario.headers, scenario.shardHeader, false) @@ -304,7 +305,7 @@ func TestShardStorageHandler_getProcessedAndPendingMiniBlocksWithScheduledWrongH t.Parallel() args := createDefaultShardStorageArgs() - shardStorage, _ := NewShardStorageHandler(args.generalConfig, args.prefsConfig, args.shardCoordinator, args.pathManagerHandler, 
args.marshalizer, args.hasher, 1, args.uint64Converter, args.nodeTypeProvider) + shardStorage, _ := NewShardStorageHandler(args.generalConfig, args.prefsConfig, args.shardCoordinator, args.pathManagerHandler, args.marshalizer, args.hasher, 1, args.uint64Converter, args.nodeTypeProvider, false, args.managedPeersHolder) scenario := createPendingAndProcessedMiniBlocksScenario() wrongShardHeader := &block.MetaBlock{} @@ -326,7 +327,7 @@ func TestShardStorageHandler_getProcessedAndPendingMiniBlocksWithScheduled(t *te t.Parallel() args := createDefaultShardStorageArgs() - shardStorage, _ := NewShardStorageHandler(args.generalConfig, args.prefsConfig, args.shardCoordinator, args.pathManagerHandler, args.marshalizer, args.hasher, 1, args.uint64Converter, args.nodeTypeProvider) + shardStorage, _ := NewShardStorageHandler(args.generalConfig, args.prefsConfig, args.shardCoordinator, args.pathManagerHandler, args.marshalizer, args.hasher, 1, args.uint64Converter, args.nodeTypeProvider, false, args.managedPeersHolder) scenario := createPendingAndProcessedMiniBlocksScenario() processedMiniBlocks, pendingMiniBlocks, err := shardStorage.getProcessedAndPendingMiniBlocksWithScheduled(scenario.metaBlock, scenario.headers, scenario.shardHeader, true) @@ -494,7 +495,7 @@ func TestShardStorageHandler_getProcessedAndPendingMiniBlocksErrorGettingEpochSt }() args := createDefaultShardStorageArgs() - shardStorage, _ := NewShardStorageHandler(args.generalConfig, args.prefsConfig, args.shardCoordinator, args.pathManagerHandler, args.marshalizer, args.hasher, 1, args.uint64Converter, args.nodeTypeProvider) + shardStorage, _ := NewShardStorageHandler(args.generalConfig, args.prefsConfig, args.shardCoordinator, args.pathManagerHandler, args.marshalizer, args.hasher, 1, args.uint64Converter, args.nodeTypeProvider, false, args.managedPeersHolder) meta := &block.MetaBlock{ Nonce: 100, EpochStart: block.EpochStart{}, @@ -517,7 +518,7 @@ func TestShardStorageHandler_getProcessedAndPendingMiniBlocksMissingHeader(t *te lastFinishedMetaBlock := "last finished meta block" args := createDefaultShardStorageArgs() - shardStorage, _ := NewShardStorageHandler(args.generalConfig, args.prefsConfig, args.shardCoordinator, args.pathManagerHandler, args.marshalizer, args.hasher, 1, args.uint64Converter, args.nodeTypeProvider) + shardStorage, _ := NewShardStorageHandler(args.generalConfig, args.prefsConfig, args.shardCoordinator, args.pathManagerHandler, args.marshalizer, args.hasher, 1, args.uint64Converter, args.nodeTypeProvider, false, args.managedPeersHolder) meta := &block.MetaBlock{ Nonce: 100, EpochStart: block.EpochStart{ @@ -543,7 +544,7 @@ func TestShardStorageHandler_getProcessedAndPendingMiniBlocksWrongHeader(t *test lastFinishedMetaBlockHash := "last finished meta block" firstPendingMeta := "first pending meta" args := createDefaultShardStorageArgs() - shardStorage, _ := NewShardStorageHandler(args.generalConfig, args.prefsConfig, args.shardCoordinator, args.pathManagerHandler, args.marshalizer, args.hasher, 1, args.uint64Converter, args.nodeTypeProvider) + shardStorage, _ := NewShardStorageHandler(args.generalConfig, args.prefsConfig, args.shardCoordinator, args.pathManagerHandler, args.marshalizer, args.hasher, 1, args.uint64Converter, args.nodeTypeProvider, false, args.managedPeersHolder) lastFinishedHeaders := createDefaultEpochStartShardData([]byte(lastFinishedMetaBlockHash), []byte("headerHash")) lastFinishedHeaders[0].FirstPendingMetaBlock = []byte(firstPendingMeta) meta := &block.MetaBlock{ @@ -574,7 +575,7 @@ func 
TestShardStorageHandler_getProcessedAndPendingMiniBlocksNilMetaBlock(t *tes lastFinishedMetaBlockHash := "last finished meta block" firstPendingMeta := "first pending meta" args := createDefaultShardStorageArgs() - shardStorage, _ := NewShardStorageHandler(args.generalConfig, args.prefsConfig, args.shardCoordinator, args.pathManagerHandler, args.marshalizer, args.hasher, 1, args.uint64Converter, args.nodeTypeProvider) + shardStorage, _ := NewShardStorageHandler(args.generalConfig, args.prefsConfig, args.shardCoordinator, args.pathManagerHandler, args.marshalizer, args.hasher, 1, args.uint64Converter, args.nodeTypeProvider, false, args.managedPeersHolder) lastFinishedHeaders := createDefaultEpochStartShardData([]byte(lastFinishedMetaBlockHash), []byte("headerHash")) lastFinishedHeaders[0].FirstPendingMetaBlock = []byte(firstPendingMeta) meta := &block.MetaBlock{ @@ -607,7 +608,7 @@ func TestShardStorageHandler_getProcessedAndPendingMiniBlocksNoProcessedNoPendin lastFinishedMetaBlockHash := "last finished meta block" firstPendingMeta := "first pending meta" args := createDefaultShardStorageArgs() - shardStorage, _ := NewShardStorageHandler(args.generalConfig, args.prefsConfig, args.shardCoordinator, args.pathManagerHandler, args.marshalizer, args.hasher, 1, args.uint64Converter, args.nodeTypeProvider) + shardStorage, _ := NewShardStorageHandler(args.generalConfig, args.prefsConfig, args.shardCoordinator, args.pathManagerHandler, args.marshalizer, args.hasher, 1, args.uint64Converter, args.nodeTypeProvider, false, args.managedPeersHolder) lastFinishedHeaders := createDefaultEpochStartShardData([]byte(lastFinishedMetaBlockHash), []byte("headerHash")) lastFinishedHeaders[0].FirstPendingMetaBlock = []byte(firstPendingMeta) lastFinishedHeaders[0].PendingMiniBlockHeaders = nil @@ -636,7 +637,7 @@ func TestShardStorageHandler_getProcessedAndPendingMiniBlocksWithProcessedAndPen t.Parallel() args := createDefaultShardStorageArgs() - shardStorage, _ := NewShardStorageHandler(args.generalConfig, args.prefsConfig, args.shardCoordinator, args.pathManagerHandler, args.marshalizer, args.hasher, 1, args.uint64Converter, args.nodeTypeProvider) + shardStorage, _ := NewShardStorageHandler(args.generalConfig, args.prefsConfig, args.shardCoordinator, args.pathManagerHandler, args.marshalizer, args.hasher, 1, args.uint64Converter, args.nodeTypeProvider, false, args.managedPeersHolder) scenario := createPendingAndProcessedMiniBlocksScenario() processedMiniBlocks, pendingMiniBlocks, firstPendingMetaBlockHash, err := shardStorage.getProcessedAndPendingMiniBlocks(scenario.metaBlock, scenario.headers) @@ -654,7 +655,7 @@ func TestShardStorageHandler_saveLastCrossNotarizedHeadersWithoutScheduledGetSha }() args := createDefaultShardStorageArgs() - shardStorage, _ := NewShardStorageHandler(args.generalConfig, args.prefsConfig, args.shardCoordinator, args.pathManagerHandler, args.marshalizer, args.hasher, 1, args.uint64Converter, args.nodeTypeProvider) + shardStorage, _ := NewShardStorageHandler(args.generalConfig, args.prefsConfig, args.shardCoordinator, args.pathManagerHandler, args.marshalizer, args.hasher, 1, args.uint64Converter, args.nodeTypeProvider, false, args.managedPeersHolder) headers := map[string]data.HeaderHandler{} meta := &block.MetaBlock{ @@ -675,7 +676,7 @@ func TestShardStorageHandler_saveLastCrossNotarizedHeadersWithoutScheduledMissin }() args := createDefaultShardStorageArgs() - shardStorage, _ := NewShardStorageHandler(args.generalConfig, args.prefsConfig, args.shardCoordinator, 
args.pathManagerHandler, args.marshalizer, args.hasher, 1, args.uint64Converter, args.nodeTypeProvider) + shardStorage, _ := NewShardStorageHandler(args.generalConfig, args.prefsConfig, args.shardCoordinator, args.pathManagerHandler, args.marshalizer, args.hasher, 1, args.uint64Converter, args.nodeTypeProvider, false, args.managedPeersHolder) shard0HeaderHash := "shard0 header hash" lastFinishedMetaBlock := "last finished meta block" @@ -704,7 +705,7 @@ func TestShardStorageHandler_saveLastCrossNotarizedHeadersWithoutScheduledWrongT }() args := createDefaultShardStorageArgs() - shardStorage, _ := NewShardStorageHandler(args.generalConfig, args.prefsConfig, args.shardCoordinator, args.pathManagerHandler, args.marshalizer, args.hasher, 1, args.uint64Converter, args.nodeTypeProvider) + shardStorage, _ := NewShardStorageHandler(args.generalConfig, args.prefsConfig, args.shardCoordinator, args.pathManagerHandler, args.marshalizer, args.hasher, 1, args.uint64Converter, args.nodeTypeProvider, false, args.managedPeersHolder) shard0HeaderHash := "shard0 header hash" lastFinishedMetaBlock := "last finished meta block" @@ -740,7 +741,7 @@ func TestShardStorageHandler_saveLastCrossNotarizedHeadersWithoutScheduledErrorW args.marshalizer = &testscommon.MarshalizerStub{MarshalCalled: func(obj interface{}) ([]byte, error) { return nil, expectedErr }} - shardStorage, _ := NewShardStorageHandler(args.generalConfig, args.prefsConfig, args.shardCoordinator, args.pathManagerHandler, args.marshalizer, args.hasher, 1, args.uint64Converter, args.nodeTypeProvider) + shardStorage, _ := NewShardStorageHandler(args.generalConfig, args.prefsConfig, args.shardCoordinator, args.pathManagerHandler, args.marshalizer, args.hasher, 1, args.uint64Converter, args.nodeTypeProvider, false, args.managedPeersHolder) shard0HeaderHash := "shard0 header hash" lastFinishedMetaBlock := "last finished meta block" @@ -771,7 +772,7 @@ func TestShardStorageHandler_saveLastCrossNotarizedHeadersWithoutScheduled(t *te }() args := createDefaultShardStorageArgs() - shardStorage, _ := NewShardStorageHandler(args.generalConfig, args.prefsConfig, args.shardCoordinator, args.pathManagerHandler, args.marshalizer, args.hasher, 1, args.uint64Converter, args.nodeTypeProvider) + shardStorage, _ := NewShardStorageHandler(args.generalConfig, args.prefsConfig, args.shardCoordinator, args.pathManagerHandler, args.marshalizer, args.hasher, 1, args.uint64Converter, args.nodeTypeProvider, false, args.managedPeersHolder) shard0HeaderHash := "shard0 header hash" lastFinishedMetaBlock := "last finished meta block" @@ -807,7 +808,7 @@ func TestShardStorageHandler_saveLastCrossNotarizedHeadersWithScheduledErrorUpda }() args := createDefaultShardStorageArgs() - shardStorage, _ := NewShardStorageHandler(args.generalConfig, args.prefsConfig, args.shardCoordinator, args.pathManagerHandler, args.marshalizer, args.hasher, 1, args.uint64Converter, args.nodeTypeProvider) + shardStorage, _ := NewShardStorageHandler(args.generalConfig, args.prefsConfig, args.shardCoordinator, args.pathManagerHandler, args.marshalizer, args.hasher, 1, args.uint64Converter, args.nodeTypeProvider, false, args.managedPeersHolder) shard0HeaderHash := "shard0 header hash" lastFinishedMetaBlock := "last finished meta block" @@ -837,7 +838,7 @@ func TestShardStorageHandler_saveLastCrossNotarizedHeadersWithScheduled(t *testi }() args := createDefaultShardStorageArgs() - shardStorage, _ := NewShardStorageHandler(args.generalConfig, args.prefsConfig, args.shardCoordinator, args.pathManagerHandler, 
args.marshalizer, args.hasher, 1, args.uint64Converter, args.nodeTypeProvider) + shardStorage, _ := NewShardStorageHandler(args.generalConfig, args.prefsConfig, args.shardCoordinator, args.pathManagerHandler, args.marshalizer, args.hasher, 1, args.uint64Converter, args.nodeTypeProvider, false, args.managedPeersHolder) shard0HeaderHash := "shard0 header hash" lastFinishedMetaBlock := "last finished meta block" prevMetaHash := "prev metaHlock hash" @@ -1059,6 +1060,7 @@ type shardStorageArgs struct { currentEpoch uint32 uint64Converter typeConverters.Uint64ByteSliceConverter nodeTypeProvider core.NodeTypeProviderHandler + managedPeersHolder common.ManagedPeersHolder } func createDefaultShardStorageArgs() shardStorageArgs { @@ -1072,6 +1074,7 @@ func createDefaultShardStorageArgs() shardStorageArgs { currentEpoch: 0, uint64Converter: &mock.Uint64ByteSliceConverterMock{}, nodeTypeProvider: &nodeTypeProviderMock.NodeTypeProviderStub{}, + managedPeersHolder: &testscommon.ManagedPeersHolderStub{}, } } diff --git a/epochStart/bootstrap/startInEpochScheduled.go b/epochStart/bootstrap/startInEpochScheduled.go index d9aba778f34..807d81729f5 100644 --- a/epochStart/bootstrap/startInEpochScheduled.go +++ b/epochStart/bootstrap/startInEpochScheduled.go @@ -237,7 +237,7 @@ func (ses *startInEpochWithScheduledDataSyncer) prepareScheduledIntermediateTxs( header data.HeaderHandler, miniBlocks map[string]*block.MiniBlock, ) error { - scheduledTxHashes, err := ses.getScheduledTransactionHashes(prevHeader) + scheduledTxHashes, prevHeaderMiniblocks, err := ses.getScheduledTransactionHashes(prevHeader) if err != nil { return err } @@ -272,6 +272,14 @@ func (ses *startInEpochWithScheduledDataSyncer) prepareScheduledIntermediateTxs( ses.saveScheduledInfo(header.GetPrevHash(), scheduledInfo) } + if miniBlocks == nil { + miniBlocks = make(map[string]*block.MiniBlock) + } + + for hash, mb := range prevHeaderMiniblocks { + miniBlocks[hash] = mb + } + return nil } @@ -430,11 +438,11 @@ func (ses *startInEpochWithScheduledDataSyncer) getScheduledMiniBlockHeaders(hea return schMiniBlockHeaders } -func (ses *startInEpochWithScheduledDataSyncer) getScheduledTransactionHashes(header data.HeaderHandler) (map[string]uint32, error) { +func (ses *startInEpochWithScheduledDataSyncer) getScheduledTransactionHashes(header data.HeaderHandler) (map[string]uint32, map[string]*block.MiniBlock, error) { miniBlockHeaders := ses.getScheduledMiniBlockHeaders(header) miniBlocks, err := ses.getRequiredMiniBlocksByMbHeader(miniBlockHeaders) if err != nil { - return nil, err + return nil, nil, err } scheduledTxsForShard := make(map[string]uint32) @@ -447,7 +455,7 @@ func (ses *startInEpochWithScheduledDataSyncer) getScheduledTransactionHashes(he createScheduledTxsForShardMap(pi, miniBlock, miniBlockHash, scheduledTxsForShard) } - return scheduledTxsForShard, nil + return scheduledTxsForShard, miniBlocks, nil } func getMiniBlockAndProcessedIndexes( diff --git a/epochStart/bootstrap/startInEpochScheduled_test.go b/epochStart/bootstrap/startInEpochScheduled_test.go index 7260a9587a5..fe1e23afcae 100644 --- a/epochStart/bootstrap/startInEpochScheduled_test.go +++ b/epochStart/bootstrap/startInEpochScheduled_test.go @@ -630,9 +630,10 @@ func TestStartInEpochWithScheduledDataSyncer_getScheduledTransactionHashesWithDe }, } - scheduledTxHashes, err := sds.getScheduledTransactionHashes(header) + scheduledTxHashes, scheduledMBs, err := sds.getScheduledTransactionHashes(header) require.Nil(t, err) require.Equal(t, expectedScheduledTxHashes, 
scheduledTxHashes) + require.Len(t, scheduledMBs, 2) } func Test_getShardIDAndHashesForIncludedMetaBlocks(t *testing.T) { diff --git a/epochStart/bootstrap/storageProcess.go b/epochStart/bootstrap/storageProcess.go index 852fa6257c3..1b4578c88c9 100644 --- a/epochStart/bootstrap/storageProcess.go +++ b/epochStart/bootstrap/storageProcess.go @@ -149,6 +149,7 @@ func (sesb *storageEpochStartBootstrap) prepareComponentsToSync() error { sesb.closeTrieComponents() sesb.storageService = disabled.NewChainStorer() triesContainer, trieStorageManagers, err := factory.CreateTriesComponentsForShardId( + sesb.flagsConfig.SnapshotsEnabled, sesb.generalConfig, sesb.coreComponentsHolder, sesb.storageService, @@ -251,6 +252,7 @@ func (sesb *storageEpochStartBootstrap) createStorageRequesters() error { DataPacker: dataPacker, ManualEpochStartNotifier: mesn, ChanGracefullyClose: sesb.chanGracefullyClose, + SnapshotsEnabled: sesb.flagsConfig.SnapshotsEnabled, } var requestersContainerFactory dataRetriever.RequestersContainerFactory diff --git a/epochStart/errors.go b/epochStart/errors.go index 4d30f202d74..bb3c68e0e43 100644 --- a/epochStart/errors.go +++ b/epochStart/errors.go @@ -325,3 +325,6 @@ var ErrNilValidatorInfoStorage = errors.New("nil validator info storage") // ErrNilTrieSyncStatistics signals that nil trie sync statistics has been provided var ErrNilTrieSyncStatistics = errors.New("nil trie sync statistics") + +// ErrNilManagedPeersHolder signals that a nil managed peers holder has been provided +var ErrNilManagedPeersHolder = errors.New("nil managed peers holder") diff --git a/epochStart/mock/cryptoComponentsMock.go b/epochStart/mock/cryptoComponentsMock.go index fc41fa173dc..4dfbe6d91a5 100644 --- a/epochStart/mock/cryptoComponentsMock.go +++ b/epochStart/mock/cryptoComponentsMock.go @@ -5,6 +5,7 @@ import ( "sync" "github.com/multiversx/mx-chain-crypto-go" + "github.com/multiversx/mx-chain-go/common" cryptoCommon "github.com/multiversx/mx-chain-go/common/crypto" ) @@ -18,6 +19,7 @@ type CryptoComponentsMock struct { PeerSignHandler crypto.PeerSignatureHandler BlKeyGen crypto.KeyGenerator TxKeyGen crypto.KeyGenerator + ManagedPeers common.ManagedPeersHolder mutCrypto sync.RWMutex } @@ -85,6 +87,11 @@ func (ccm *CryptoComponentsMock) TxSignKeyGen() crypto.KeyGenerator { return ccm.TxKeyGen } +// ManagedPeersHolder - +func (ccm *CryptoComponentsMock) ManagedPeersHolder() common.ManagedPeersHolder { + return ccm.ManagedPeers +} + // Clone - func (ccm *CryptoComponentsMock) Clone() interface{} { return &CryptoComponentsMock{ diff --git a/errors/errors.go b/errors/errors.go index c03015d9aac..1030b789082 100644 --- a/errors/errors.go +++ b/errors/errors.go @@ -482,9 +482,6 @@ var ErrNilProcessStatusHandler = errors.New("nil process status handler") // ErrNilESDTDataStorage signals that a nil esdt data storage has been provided var ErrNilESDTDataStorage = errors.New("nil esdt data storage") -// ErrDBIsClosed is raised when the DB is closed -var ErrDBIsClosed = errors.New("DB is closed") - // ErrNilEnableEpochsHandler signals that a nil enable epochs handler was provided var ErrNilEnableEpochsHandler = errors.New("nil enable epochs handler") @@ -544,3 +541,9 @@ var ErrNilPersistentHandler = errors.New("nil persistent handler") // ErrNilGenesisNodesSetupHandler signals that a nil genesis nodes setup handler has been provided var ErrNilGenesisNodesSetupHandler = errors.New("nil genesis nodes setup handler") + +// ErrNilManagedPeersHolder signals that a nil managed peers holder has been provided +var 
ErrNilManagedPeersHolder = errors.New("nil managed peers holder") + +// ErrEmptyPeerID signals that an empty peer ID has been provided +var ErrEmptyPeerID = errors.New("empty peer ID") diff --git a/factory/consensus/consensusComponents.go b/factory/consensus/consensusComponents.go index a9befa51657..a4060921bb8 100644 --- a/factory/consensus/consensusComponents.go +++ b/factory/consensus/consensusComponents.go @@ -8,7 +8,6 @@ import ( "github.com/multiversx/mx-chain-core-go/core/throttler" "github.com/multiversx/mx-chain-core-go/core/watchdog" "github.com/multiversx/mx-chain-core-go/marshal" - "github.com/multiversx/mx-chain-storage-go/timecache" "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/common/disabled" "github.com/multiversx/mx-chain-go/config" @@ -30,6 +29,7 @@ import ( "github.com/multiversx/mx-chain-go/trie/storageMarker" "github.com/multiversx/mx-chain-go/update" logger "github.com/multiversx/mx-chain-logger-go" + "github.com/multiversx/mx-chain-storage-go/timecache" ) var log = logger.GetOrCreate("factory") @@ -178,11 +178,11 @@ func (ccf *consensusComponentsFactory) Create() (*consensusComponents, error) { ccf.coreComponents.Hasher(), ccf.networkComponents.NetworkMessenger(), ccf.processComponents.ShardCoordinator(), - ccf.cryptoComponents.PrivateKey(), ccf.cryptoComponents.PeerSignatureHandler(), ccf.dataComponents.Datapool().Headers(), ccf.processComponents.InterceptorsContainer(), ccf.coreComponents.AlarmScheduler(), + ccf.cryptoComponents.KeysHandler(), ) if err != nil { return nil, err @@ -258,9 +258,6 @@ func (ccf *consensusComponentsFactory) Create() (*consensusComponents, error) { ChronologyHandler: cc.chronology, Hasher: ccf.coreComponents.Hasher(), Marshalizer: ccf.coreComponents.InternalMarshalizer(), - BlsPrivateKey: ccf.cryptoComponents.PrivateKey(), - BlsSingleSigner: ccf.cryptoComponents.BlockSigner(), - KeyGenerator: ccf.cryptoComponents.BlockSignKeyGen(), MultiSignerContainer: ccf.cryptoComponents.MultiSignerContainer(), RoundHandler: ccf.processComponents.RoundHandler(), ShardCoordinator: ccf.processComponents.ShardCoordinator(), @@ -275,7 +272,7 @@ func (ccf *consensusComponentsFactory) Create() (*consensusComponents, error) { ScheduledProcessor: ccf.scheduledProcessor, MessageSigningHandler: p2pSigningHandler, PeerBlacklistHandler: cc.peerBlacklistHandler, - SignatureHandler: ccf.cryptoComponents.ConsensusSigHandler(), + SigningHandler: ccf.cryptoComponents.ConsensusSigningHandler(), } consensusDataContainer, err := spos.NewConsensusCore( @@ -399,11 +396,16 @@ func (ccf *consensusComponentsFactory) createConsensusState(epoch uint32, consen return nil, err } - roundConsensus := spos.NewRoundConsensus( + roundConsensus, err := spos.NewRoundConsensus( eligibleNodesPubKeys, // TODO: move the consensus data from nodesSetup json to config consensusGroupSize, - string(selfId)) + string(selfId), + ccf.cryptoComponents.KeysHandler(), + ) + if err != nil { + return nil, err + } roundConsensus.ResetRoundState() diff --git a/factory/crypto/cryptoComponents.go b/factory/crypto/cryptoComponents.go index 853bf8e84e2..6191ed2ab32 100644 --- a/factory/crypto/cryptoComponents.go +++ b/factory/crypto/cryptoComponents.go @@ -5,6 +5,7 @@ import ( "encoding/hex" "fmt" + "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" crypto "github.com/multiversx/mx-chain-crypto-go" "github.com/multiversx/mx-chain-crypto-go/signing" @@ -24,6 +25,8 @@ import ( "github.com/multiversx/mx-chain-go/factory" 
"github.com/multiversx/mx-chain-go/factory/peerSignatureHandler" "github.com/multiversx/mx-chain-go/genesis/process/disabled" + "github.com/multiversx/mx-chain-go/keysManagement" + p2pFactory "github.com/multiversx/mx-chain-go/p2p/factory" storageFactory "github.com/multiversx/mx-chain-go/storage/factory" "github.com/multiversx/mx-chain-go/storage/storageunit" "github.com/multiversx/mx-chain-go/vm" @@ -31,14 +34,19 @@ import ( logger "github.com/multiversx/mx-chain-logger-go" ) -const disabledSigChecking = "disabled" +const ( + disabledSigChecking = "disabled" + mainMachineRedundancyLevel = 0 +) // CryptoComponentsFactoryArgs holds the arguments needed for creating crypto components type CryptoComponentsFactoryArgs struct { ValidatorKeyPemFileName string + AllValidatorKeysPemFileName string SkIndex int Config config.Config EnableEpochs config.EnableEpochs + PrefsConfig config.Preferences CoreComponentsHolder factory.CoreComponentsHolder KeyLoader factory.KeyLoaderHandler ActivateBLSPubKeyMessageVerification bool @@ -51,10 +59,12 @@ type CryptoComponentsFactoryArgs struct { type cryptoComponentsFactory struct { consensusType string validatorKeyPemFileName string + allValidatorKeysPemFileName string skIndex int config config.Config enableEpochs config.EnableEpochs - coreComponentsHolder factory.CoreComponentsHolder + prefsConfig config.Preferences + validatorPubKeyConverter core.PubkeyConverter activateBLSPubKeyMessageVerification bool keyLoader factory.KeyLoaderHandler isInImportMode bool @@ -65,11 +75,12 @@ type cryptoComponentsFactory struct { // cryptoParams holds the node public/private key data type cryptoParams struct { - publicKey crypto.PublicKey - privateKey crypto.PrivateKey - publicKeyString string - publicKeyBytes []byte - privateKeyBytes []byte + publicKey crypto.PublicKey + privateKey crypto.PrivateKey + publicKeyString string + publicKeyBytes []byte + privateKeyBytes []byte + handledPrivateKeys [][]byte } // p2pCryptoParams holds the p2p public/private key data @@ -80,16 +91,18 @@ type p2pCryptoParams struct { // cryptoComponents struct holds the crypto components type cryptoComponents struct { - txSingleSigner crypto.SingleSigner - blockSingleSigner crypto.SingleSigner - p2pSingleSigner crypto.SingleSigner - multiSignerContainer cryptoCommon.MultiSignerContainer - peerSignHandler crypto.PeerSignatureHandler - blockSignKeyGen crypto.KeyGenerator - txSignKeyGen crypto.KeyGenerator - p2pKeyGen crypto.KeyGenerator - messageSignVerifier vm.MessageSignVerifier - consensusSigHandler consensus.SignatureHandler + txSingleSigner crypto.SingleSigner + blockSingleSigner crypto.SingleSigner + p2pSingleSigner crypto.SingleSigner + multiSignerContainer cryptoCommon.MultiSignerContainer + peerSignHandler crypto.PeerSignatureHandler + blockSignKeyGen crypto.KeyGenerator + txSignKeyGen crypto.KeyGenerator + p2pKeyGen crypto.KeyGenerator + messageSignVerifier vm.MessageSignVerifier + consensusSigningHandler consensus.SigningHandler + managedPeersHolder common.ManagedPeersHolder + keysHandler consensus.KeysHandler cryptoParams p2pCryptoParams } @@ -101,6 +114,9 @@ func NewCryptoComponentsFactory(args CryptoComponentsFactoryArgs) (*cryptoCompon if check.IfNil(args.CoreComponentsHolder) { return nil, errors.ErrNilCoreComponents } + if check.IfNil(args.CoreComponentsHolder.ValidatorPubKeyConverter()) { + return nil, errors.ErrNilPubKeyConverter + } if len(args.ValidatorKeyPemFileName) == 0 { return nil, errors.ErrNilPath } @@ -113,7 +129,8 @@ func NewCryptoComponentsFactory(args 
CryptoComponentsFactoryArgs) (*cryptoCompon validatorKeyPemFileName: args.ValidatorKeyPemFileName, skIndex: args.SkIndex, config: args.Config, - coreComponentsHolder: args.CoreComponentsHolder, + prefsConfig: args.PrefsConfig, + validatorPubKeyConverter: args.CoreComponentsHolder.ValidatorPubKeyConverter(), activateBLSPubKeyMessageVerification: args.ActivateBLSPubKeyMessageVerification, keyLoader: args.KeyLoader, isInImportMode: args.IsInImportMode, @@ -121,6 +138,7 @@ func NewCryptoComponentsFactory(args CryptoComponentsFactoryArgs) (*cryptoCompon enableEpochs: args.EnableEpochs, noKeyProvided: args.NoKeyProvided, p2pKeyPemFileName: args.P2pKeyPemFileName, + allValidatorKeysPemFileName: args.AllValidatorKeysPemFileName, } return ccf, nil @@ -188,32 +206,73 @@ func (ccf *cryptoComponentsFactory) Create() (*cryptoComponents, error) { return nil, err } - signatureHolderArgs := ArgsSignatureHolder{ + // TODO: refactor the logic for isMainMachine + redundancyLevel := int(ccf.prefsConfig.Preferences.RedundancyLevel) + isMainMachine := redundancyLevel == mainMachineRedundancyLevel + argsManagedPeersHolder := keysManagement.ArgsManagedPeersHolder{ + KeyGenerator: blockSignKeyGen, + P2PKeyGenerator: p2pKeyGenerator, + IsMainMachine: isMainMachine, + MaxRoundsWithoutReceivedMessages: redundancyLevel, + PrefsConfig: ccf.prefsConfig, + P2PKeyConverter: p2pFactory.NewP2PKeyConverter(), + } + managedPeersHolder, err := keysManagement.NewManagedPeersHolder(argsManagedPeersHolder) + if err != nil { + return nil, err + } + + for _, skBytes := range cp.handledPrivateKeys { + errAddManagedPeer := managedPeersHolder.AddManagedPeer(skBytes) + if errAddManagedPeer != nil { + return nil, errAddManagedPeer + } + } + + log.Debug("block sign pubkey", "value", cp.publicKeyString) + + currentPid, err := argsManagedPeersHolder.P2PKeyConverter.ConvertPublicKeyToPeerID(p2pCryptoParamsInstance.p2pPublicKey) + if err != nil { + return nil, err + } + + argsKeysHandler := keysManagement.ArgsKeysHandler{ + ManagedPeersHolder: managedPeersHolder, + PrivateKey: cp.privateKey, + Pid: currentPid, + } + keysHandler, err := keysManagement.NewKeysHandler(argsKeysHandler) + if err != nil { + return nil, err + } + + signingHandlerArgs := ArgsSigningHandler{ PubKeys: []string{cp.publicKeyString}, - PrivKeyBytes: cp.privateKeyBytes, MultiSignerContainer: multiSigner, KeyGenerator: blockSignKeyGen, + SingleSigner: interceptSingleSigner, + KeysHandler: keysHandler, } - consensusSigHandler, err := NewSignatureHolder(signatureHolderArgs) + consensusSigningHandler, err := NewSigningHandler(signingHandlerArgs) if err != nil { return nil, err } - log.Debug("block sign pubkey", "value", cp.publicKeyString) - return &cryptoComponents{ - txSingleSigner: txSingleSigner, - blockSingleSigner: interceptSingleSigner, - multiSignerContainer: multiSigner, - peerSignHandler: peerSigHandler, - blockSignKeyGen: blockSignKeyGen, - txSignKeyGen: txSignKeyGen, - p2pKeyGen: p2pKeyGenerator, - messageSignVerifier: messageSignVerifier, - consensusSigHandler: consensusSigHandler, - cryptoParams: *cp, - p2pCryptoParams: *p2pCryptoParamsInstance, - p2pSingleSigner: p2pSingleSigner, + txSingleSigner: txSingleSigner, + blockSingleSigner: interceptSingleSigner, + multiSignerContainer: multiSigner, + peerSignHandler: peerSigHandler, + blockSignKeyGen: blockSignKeyGen, + txSignKeyGen: txSignKeyGen, + p2pKeyGen: p2pKeyGenerator, + messageSignVerifier: messageSignVerifier, + consensusSigningHandler: consensusSigningHandler, + managedPeersHolder: managedPeersHolder, + 
keysHandler: keysHandler, + cryptoParams: *cp, + p2pCryptoParams: *p2pCryptoParamsInstance, + p2pSingleSigner: p2pSingleSigner, }, nil } @@ -264,9 +323,23 @@ func (ccf *cryptoComponentsFactory) createCryptoParams( keygen crypto.KeyGenerator, ) (*cryptoParams, error) { - shouldGenerateCryptoParams := ccf.isInImportMode || ccf.noKeyProvided - if shouldGenerateCryptoParams { - return ccf.generateCryptoParams(keygen) + handledPrivateKeys, err := ccf.processAllHandledKeys(keygen) + if err != nil { + return nil, err + } + + if ccf.isInImportMode { + if len(handledPrivateKeys) > 0 { + return nil, fmt.Errorf("invalid node configuration: import-db mode and allValidatorsKeys.pem file provided") + } + + return ccf.generateCryptoParams(keygen, "in import mode", handledPrivateKeys) + } + if ccf.noKeyProvided { + return ccf.generateCryptoParams(keygen, "with no-key flag enabled", make([][]byte, 0)) + } + if len(handledPrivateKeys) > 0 { + return ccf.generateCryptoParams(keygen, "running with a provided allValidatorsKeys.pem", handledPrivateKeys) } return ccf.readCryptoParams(keygen) @@ -301,8 +374,7 @@ func (ccf *cryptoComponentsFactory) readCryptoParams(keygen crypto.KeyGenerator) } } - validatorKeyConverter := ccf.coreComponentsHolder.ValidatorPubKeyConverter() - cp.publicKeyString, err = validatorKeyConverter.Encode(cp.publicKeyBytes) + cp.publicKeyString, err = ccf.validatorPubKeyConverter.Encode(cp.publicKeyBytes) if err != nil { return nil, err } @@ -310,15 +382,12 @@ func (ccf *cryptoComponentsFactory) readCryptoParams(keygen crypto.KeyGenerator) return cp, nil } -func (ccf *cryptoComponentsFactory) generateCryptoParams(keygen crypto.KeyGenerator) (*cryptoParams, error) { - var message string - if ccf.noKeyProvided { - message = "with no-key flag enabled" - } else { - message = "in import mode" - } - - log.Warn(fmt.Sprintf("the node is %s! Will generate a fresh new BLS key", message)) +func (ccf *cryptoComponentsFactory) generateCryptoParams( + keygen crypto.KeyGenerator, + reason string, + handledPrivateKeys [][]byte, +) (*cryptoParams, error) { + log.Warn(fmt.Sprintf("the node is %s! 
Will generate a fresh new BLS key", reason)) cp := &cryptoParams{} cp.privateKey, cp.publicKey = keygen.GeneratePair() @@ -333,11 +402,11 @@ func (ccf *cryptoComponentsFactory) generateCryptoParams(keygen crypto.KeyGenera return nil, err } - validatorKeyConverter := ccf.coreComponentsHolder.ValidatorPubKeyConverter() - cp.publicKeyString, err = validatorKeyConverter.Encode(cp.publicKeyBytes) + cp.publicKeyString, err = ccf.validatorPubKeyConverter.Encode(cp.publicKeyBytes) if err != nil { return nil, err } + cp.handledPrivateKeys = handledPrivateKeys return cp, nil } @@ -353,8 +422,7 @@ func (ccf *cryptoComponentsFactory) getSkPk() ([]byte, []byte, error) { return nil, nil, fmt.Errorf("%w for encoded secret key", err) } - validatorKeyConverter := ccf.coreComponentsHolder.ValidatorPubKeyConverter() - pkBytes, err := validatorKeyConverter.Decode(pkString) + pkBytes, err := ccf.validatorPubKeyConverter.Decode(pkString) if err != nil { return nil, nil, fmt.Errorf("%w for encoded public key %s", err, pkString) } @@ -406,6 +474,65 @@ func CreateP2pKeyPair( return privKey, privKey.GeneratePublic(), nil } +func (ccf *cryptoComponentsFactory) processAllHandledKeys(keygen crypto.KeyGenerator) ([][]byte, error) { + privateKeys, publicKeys, err := ccf.keyLoader.LoadAllKeys(ccf.allValidatorKeysPemFileName) + if err != nil { + log.Debug("allValidatorsKeys could not be loaded", "reason", err) + return make([][]byte, 0), nil + } + + if len(privateKeys) != len(publicKeys) { + return nil, fmt.Errorf("key loading error for the allValidatorsKeys file: mismatch number of private and public keys") + } + + handledPrivateKeys := make([][]byte, 0, len(privateKeys)) + for i, pkString := range publicKeys { + sk := privateKeys[i] + processedSkBytes, errCheck := ccf.processPrivatePublicKey(keygen, sk, pkString, i) + if errCheck != nil { + return nil, errCheck + } + + log.Debug("loaded handled node key", "public key", pkString) + handledPrivateKeys = append(handledPrivateKeys, processedSkBytes) + } + + return handledPrivateKeys, nil +} + +func (ccf *cryptoComponentsFactory) processPrivatePublicKey(keygen crypto.KeyGenerator, encodedSk []byte, pkString string, index int) ([]byte, error) { + skBytes, err := hex.DecodeString(string(encodedSk)) + if err != nil { + return nil, fmt.Errorf("%w for encoded secret key, key index %d", err, index) + } + + pkBytes, err := ccf.validatorPubKeyConverter.Decode(pkString) + if err != nil { + return nil, fmt.Errorf("%w for encoded public key %s, key index %d", err, pkString, index) + } + + sk, err := keygen.PrivateKeyFromByteArray(skBytes) + if err != nil { + return nil, fmt.Errorf("%w secret key, key index %d", err, index) + } + + pk := sk.GeneratePublic() + pkGeneratedBytes, err := pk.ToByteArray() + if err != nil { + return nil, fmt.Errorf("%w while generating public key bytes, key index %d", err, index) + } + + if !bytes.Equal(pkGeneratedBytes, pkBytes) { + return nil, fmt.Errorf("public keys mismatch, read %s, generated %s, key index %d", + pkString, + ccf.validatorPubKeyConverter.SilentEncode(pkBytes, log), + index, + ) + } + + return skBytes, nil +} + // Close closes all underlying components that need closing func (cc *cryptoComponents) Close() error { return nil diff --git a/factory/crypto/cryptoComponentsHandler.go b/factory/crypto/cryptoComponentsHandler.go index f17066a959a..7238e2153c4 100644 --- a/factory/crypto/cryptoComponentsHandler.go +++ b/factory/crypto/cryptoComponentsHandler.go @@ -6,6 +6,7 @@ import ( "github.com/multiversx/mx-chain-core-go/core/check" crypto 
"github.com/multiversx/mx-chain-crypto-go" + "github.com/multiversx/mx-chain-go/common" cryptoCommon "github.com/multiversx/mx-chain-go/common/crypto" "github.com/multiversx/mx-chain-go/consensus" "github.com/multiversx/mx-chain-go/errors" @@ -119,6 +120,9 @@ func (mcc *managedCryptoComponents) CheckSubcomponents() error { if check.IfNil(mcc.cryptoComponents.messageSignVerifier) { return errors.ErrNilMessageSignVerifier } + if check.IfNil(mcc.cryptoComponents.managedPeersHolder) { + return errors.ErrNilManagedPeersHolder + } return nil } @@ -339,8 +343,32 @@ func (mcc *managedCryptoComponents) MessageSignVerifier() vm.MessageSignVerifier return mcc.cryptoComponents.messageSignVerifier } -// ConsensusSigHandler returns the consensus signature handler -func (mcc *managedCryptoComponents) ConsensusSigHandler() consensus.SignatureHandler { +// ConsensusSigningHandler returns the consensus signing handler +func (mcc *managedCryptoComponents) ConsensusSigningHandler() consensus.SigningHandler { + mcc.mutCryptoComponents.RLock() + defer mcc.mutCryptoComponents.RUnlock() + + if mcc.cryptoComponents == nil { + return nil + } + + return mcc.cryptoComponents.consensusSigningHandler +} + +// ManagedPeersHolder returns the managed peers holder +func (mcc *managedCryptoComponents) ManagedPeersHolder() common.ManagedPeersHolder { + mcc.mutCryptoComponents.RLock() + defer mcc.mutCryptoComponents.RUnlock() + + if mcc.cryptoComponents == nil { + return nil + } + + return mcc.cryptoComponents.managedPeersHolder +} + +// KeysHandler returns the handler that manages keys either in single sign mode or multi key mode +func (mcc *managedCryptoComponents) KeysHandler() consensus.KeysHandler { mcc.mutCryptoComponents.RLock() defer mcc.mutCryptoComponents.RUnlock() @@ -348,7 +376,7 @@ func (mcc *managedCryptoComponents) ConsensusSigHandler() consensus.SignatureHan return nil } - return mcc.cryptoComponents.consensusSigHandler + return mcc.cryptoComponents.keysHandler } // Clone creates a shallow clone of a managedCryptoComponents @@ -356,18 +384,20 @@ func (mcc *managedCryptoComponents) Clone() interface{} { cryptoComp := (*cryptoComponents)(nil) if mcc.cryptoComponents != nil { cryptoComp = &cryptoComponents{ - txSingleSigner: mcc.TxSingleSigner(), - blockSingleSigner: mcc.BlockSigner(), - p2pSingleSigner: mcc.P2pSingleSigner(), - multiSignerContainer: mcc.MultiSignerContainer(), - peerSignHandler: mcc.PeerSignatureHandler(), - blockSignKeyGen: mcc.BlockSignKeyGen(), - txSignKeyGen: mcc.TxSignKeyGen(), - p2pKeyGen: mcc.P2pKeyGen(), - messageSignVerifier: mcc.MessageSignVerifier(), - consensusSigHandler: mcc.ConsensusSigHandler(), - cryptoParams: mcc.cryptoParams, - p2pCryptoParams: mcc.p2pCryptoParams, + txSingleSigner: mcc.TxSingleSigner(), + blockSingleSigner: mcc.BlockSigner(), + p2pSingleSigner: mcc.P2pSingleSigner(), + multiSignerContainer: mcc.MultiSignerContainer(), + peerSignHandler: mcc.PeerSignatureHandler(), + blockSignKeyGen: mcc.BlockSignKeyGen(), + txSignKeyGen: mcc.TxSignKeyGen(), + p2pKeyGen: mcc.P2pKeyGen(), + messageSignVerifier: mcc.MessageSignVerifier(), + consensusSigningHandler: mcc.ConsensusSigningHandler(), + managedPeersHolder: mcc.ManagedPeersHolder(), + keysHandler: mcc.KeysHandler(), + cryptoParams: mcc.cryptoParams, + p2pCryptoParams: mcc.p2pCryptoParams, } } diff --git a/factory/crypto/cryptoComponentsHandler_test.go b/factory/crypto/cryptoComponentsHandler_test.go index 218f84c43f2..45aed193e93 100644 --- a/factory/crypto/cryptoComponentsHandler_test.go +++ 
b/factory/crypto/cryptoComponentsHandler_test.go @@ -45,6 +45,7 @@ func TestManagedCryptoComponents_CreateShouldWork(t *testing.T) { require.Nil(t, managedCryptoComponents.BlockSignKeyGen()) require.Nil(t, managedCryptoComponents.TxSignKeyGen()) require.Nil(t, managedCryptoComponents.MessageSignVerifier()) + require.Nil(t, managedCryptoComponents.ManagedPeersHolder()) err = managedCryptoComponents.Create() require.NoError(t, err) @@ -57,6 +58,7 @@ func TestManagedCryptoComponents_CreateShouldWork(t *testing.T) { require.NotNil(t, managedCryptoComponents.BlockSignKeyGen()) require.NotNil(t, managedCryptoComponents.TxSignKeyGen()) require.NotNil(t, managedCryptoComponents.MessageSignVerifier()) + require.NotNil(t, managedCryptoComponents.ManagedPeersHolder()) } func TestManagedCryptoComponents_CheckSubcomponents(t *testing.T) { diff --git a/factory/crypto/cryptoComponents_test.go b/factory/crypto/cryptoComponents_test.go index 1f0496b129f..dc3b1541a79 100644 --- a/factory/crypto/cryptoComponents_test.go +++ b/factory/crypto/cryptoComponents_test.go @@ -11,11 +11,13 @@ import ( errErd "github.com/multiversx/mx-chain-go/errors" cryptoComp "github.com/multiversx/mx-chain-go/factory/crypto" "github.com/multiversx/mx-chain-go/factory/mock" + integrationTestsMock "github.com/multiversx/mx-chain-go/integrationTests/mock" componentsMock "github.com/multiversx/mx-chain-go/testscommon/components" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -func TestNewCryptoComponentsFactory_NiCoreComponentsHandlerShouldErr(t *testing.T) { +func TestNewCryptoComponentsFactory_NilCoreComponentsHandlerShouldErr(t *testing.T) { t.Parallel() if testing.Short() { t.Skip("this is not a short test") @@ -27,6 +29,22 @@ func TestNewCryptoComponentsFactory_NiCoreComponentsHandlerShouldErr(t *testing. 
require.Equal(t, errErd.ErrNilCoreComponents, err) } +func TestNewCryptoComponentsFactory_NilValidatorPublicKeyConverterShouldErr(t *testing.T) { + t.Parallel() + if testing.Short() { + t.Skip("this is not a short test") + } + + args := componentsMock.GetCryptoArgs(nil) + args.CoreComponentsHolder = &integrationTestsMock.CoreComponentsStub{ + ValidatorPubKeyConverterField: nil, + } + + ccf, err := cryptoComp.NewCryptoComponentsFactory(args) + require.Nil(t, ccf) + require.Equal(t, errErd.ErrNilPubKeyConverter, err) +} + func TestNewCryptoComponentsFactory_NilPemFileShouldErr(t *testing.T) { t.Parallel() if testing.Short() { @@ -152,6 +170,7 @@ func TestCryptoComponentsFactory_CreateOK(t *testing.T) { cc, err := ccf.Create() require.NoError(t, err) require.NotNil(t, cc) + assert.Equal(t, 0, len(cc.GetManagedPeersHolder().GetManagedKeysByCurrentNode())) } func TestCryptoComponentsFactory_CreateWithDisabledSig(t *testing.T) { @@ -168,6 +187,7 @@ func TestCryptoComponentsFactory_CreateWithDisabledSig(t *testing.T) { cc, err := ccf.Create() require.NoError(t, err) require.NotNil(t, cc) + assert.Equal(t, 0, len(cc.GetManagedPeersHolder().GetManagedKeysByCurrentNode())) } func TestCryptoComponentsFactory_CreateWithAutoGenerateKey(t *testing.T) { @@ -184,6 +204,7 @@ func TestCryptoComponentsFactory_CreateWithAutoGenerateKey(t *testing.T) { cc, err := ccf.Create() require.NoError(t, err) require.NotNil(t, cc) + assert.Equal(t, 0, len(cc.GetManagedPeersHolder().GetManagedKeysByCurrentNode())) } func TestCryptoComponentsFactory_CreateSingleSignerInvalidConsensusTypeShouldErr(t *testing.T) { @@ -411,3 +432,174 @@ func TestCryptoComponentsFactory_GetSkPkOK(t *testing.T) { require.Equal(t, expectedSk, sk) require.Equal(t, expectedPk, pk) } + +func TestCryptoComponentsFactory_MultiKey(t *testing.T) { + t.Parallel() + + t.Run("internal error, LoadAllKeys returns different lengths for private and public keys", func(t *testing.T) { + t.Parallel() + + coreComponents := componentsMock.GetCoreComponents() + args := componentsMock.GetCryptoArgs(coreComponents) + + privateKeys, publicKeys := createBLSPrivatePublicKeys() + + args.KeyLoader = &mock.KeyLoaderStub{ + LoadKeyCalled: func(relativePath string, skIndex int) ([]byte, string, error) { + return privateKeys[0], publicKeys[0], nil + }, + LoadAllKeysCalled: func(path string) ([][]byte, []string, error) { + return privateKeys[2:], publicKeys[1:], nil + }, + } + + ccf, err := cryptoComp.NewCryptoComponentsFactory(args) + require.Nil(t, err) + + cc, err := ccf.Create() + assert.Nil(t, cc) + assert.Contains(t, err.Error(), "key loading error for the allValidatorsKeys file: mismatch number of private and public keys") + }) + t.Run("encoded private key can not be hex decoded", func(t *testing.T) { + t.Parallel() + + coreComponents := componentsMock.GetCoreComponents() + args := componentsMock.GetCryptoArgs(coreComponents) + privateKeys, publicKeys := createBLSPrivatePublicKeys() + + args.KeyLoader = &mock.KeyLoaderStub{ + LoadKeyCalled: func(relativePath string, skIndex int) ([]byte, string, error) { + return privateKeys[0], publicKeys[0], nil + }, + LoadAllKeysCalled: func(path string) ([][]byte, []string, error) { + return [][]byte{[]byte("not a hex")}, []string{"a"}, nil + }, + } + + ccf, err := cryptoComp.NewCryptoComponentsFactory(args) + require.Nil(t, err) + + cc, err := ccf.Create() + assert.Nil(t, cc) + assert.Contains(t, err.Error(), "for encoded secret key, key index 0") + }) + t.Run("encoded public key can not be hex decoded", func(t *testing.T) { + 
t.Parallel() + + coreComponents := componentsMock.GetCoreComponents() + args := componentsMock.GetCryptoArgs(coreComponents) + privateKeys, publicKeys := createBLSPrivatePublicKeys() + + args.KeyLoader = &mock.KeyLoaderStub{ + LoadKeyCalled: func(relativePath string, skIndex int) ([]byte, string, error) { + return privateKeys[0], publicKeys[0], nil + }, + LoadAllKeysCalled: func(path string) ([][]byte, []string, error) { + return [][]byte{[]byte("aa")}, []string{"not hex"}, nil + }, + } + + ccf, err := cryptoComp.NewCryptoComponentsFactory(args) + require.Nil(t, err) + + cc, err := ccf.Create() + assert.Nil(t, cc) + assert.Contains(t, err.Error(), "for encoded public key not hex, key index 0") + }) + t.Run("not a valid private key", func(t *testing.T) { + t.Parallel() + + coreComponents := componentsMock.GetCoreComponents() + args := componentsMock.GetCryptoArgs(coreComponents) + privateKeys, publicKeys := createBLSPrivatePublicKeys() + + args.KeyLoader = &mock.KeyLoaderStub{ + LoadKeyCalled: func(relativePath string, skIndex int) ([]byte, string, error) { + return privateKeys[0], publicKeys[0], nil + }, + LoadAllKeysCalled: func(path string) ([][]byte, []string, error) { + return [][]byte{[]byte("aa")}, []string{publicKeys[0]}, nil + }, + } + + ccf, err := cryptoComp.NewCryptoComponentsFactory(args) + require.Nil(t, err) + + cc, err := ccf.Create() + assert.Nil(t, cc) + assert.Contains(t, err.Error(), "secret key, key index 0") + }) + t.Run("wrong public string read from file", func(t *testing.T) { + t.Parallel() + + coreComponents := componentsMock.GetCoreComponents() + args := componentsMock.GetCryptoArgs(coreComponents) + privateKeys, publicKeys := createBLSPrivatePublicKeys() + + args.KeyLoader = &mock.KeyLoaderStub{ + LoadKeyCalled: func(relativePath string, skIndex int) ([]byte, string, error) { + return privateKeys[0], publicKeys[0], nil + }, + LoadAllKeysCalled: func(path string) ([][]byte, []string, error) { + return [][]byte{privateKeys[0]}, []string{publicKeys[1]}, nil + }, + } + + ccf, err := cryptoComp.NewCryptoComponentsFactory(args) + require.Nil(t, err) + + cc, err := ccf.Create() + assert.Nil(t, cc) + assert.Contains(t, err.Error(), "public keys mismatch, read "+ + "ae33fdd47ca6eed4b4e33a87beb580a20e908898a88c1d91a8b376cc35e8240e5083696ba6a1eeaa4cf50431980c38086dd7acc535c7571fb952d5c025d27c422fca999eaeaa13451946504d2b0a0c5b08958da236a4877b08abbd8059218f05"+ + ", generated ae33fdd47ca6eed4b4e33a87beb580a20e908898a88c1d91a8b376cc35e8240e5083696ba6a1eeaa4cf50431980c38086dd7acc535c7571fb952d5c025d27c422fca999eaeaa13451946504d2b0a0c5b08958da236a4877b08abbd8059218f05"+ + ", key index 0") + }) + t.Run("should work", func(t *testing.T) { + t.Parallel() + + coreComponents := componentsMock.GetCoreComponents() + args := componentsMock.GetCryptoArgs(coreComponents) + + privateKeys, publicKeys := createBLSPrivatePublicKeys() + + args.KeyLoader = &mock.KeyLoaderStub{ + LoadKeyCalled: func(relativePath string, skIndex int) ([]byte, string, error) { + return privateKeys[0], publicKeys[0], nil + }, + LoadAllKeysCalled: func(path string) ([][]byte, []string, error) { + return privateKeys[1:], publicKeys[1:], nil + }, + } + + ccf, err := cryptoComp.NewCryptoComponentsFactory(args) + require.Nil(t, err) + + cc, err := ccf.Create() + require.Nil(t, err) + // should hold num(publicKeys) - 1 as the first was used as the main key of the node + assert.Equal(t, len(publicKeys)-1, len(cc.GetManagedPeersHolder().GetManagedKeysByCurrentNode())) + skBytes, _, _ := ccf.GetSkPk() + assert.NotEqual(t, 
skBytes, privateKeys[0]) // should generate another private key and not use the one loaded with LoadKey call + }) +} + +func createBLSPrivatePublicKeys() ([][]byte, []string) { + privateKeys := [][]byte{ + []byte("13508f73f4bac43014ca5cdf16903bed4dcfd60f74123346f933e1cd0042ca52"), + []byte("4195470da0224c832b8cb3227fdfa2431fac50efce332dadc2e970c0977f6d3b"), + []byte("a65c1bbef47b833c2c2262cc341ad34745a10034b68e117e60c3391ed3503b49"), + []byte("1bccb646905256d20da48c875ddfa7db92a182bd0b7738560d7df6ead909892e"), + []byte("f9c165744018c13a098b7e4915d24d86379f8a06fab28bd7970092ab5a19fd41"), + } + + publicKeys := []string{ + "c1135f9b0fcae055218cc7916f626f3da33e2ccc0252fd8036be35d4e2d93b9b54a6c355b3e6520b49d32ca005a757156e2a8dc1b14e5c7773a294f6ea1faecae6739e5b3d832eab7f36ff9a6c200ca471a948dcf7671291347b79c3f1b63e93", + "ae33fdd47ca6eed4b4e33a87beb580a20e908898a88c1d91a8b376cc35e8240e5083696ba6a1eeaa4cf50431980c38086dd7acc535c7571fb952d5c025d27c422fca999eaeaa13451946504d2b0a0c5b08958da236a4877b08abbd8059218f05", + "85c447d1a50ac4dd6c8b38dd39ade1126caf16654d59ea8ef7dbec658b36ccd5cbcd946f9c4acccacdde564739f224002ea0a2ce083abd60cafbcae9d817f674966e49c3b3322a2028c64fa74b01610e25e3f9ceb7c2b2077eaed83ca6e08090", + "c550ac126ce520d7a6bbd7d5f375273df8f1a8c6d74f44d3e3b71872fc65e436ef52696674883ff27de8f674bc4c7713c4b7bdabe2292c069e81e5c8e131ea8a90215ee4038a882687e532a8c27dea94c5c2ca9f7f6072163bb0b6151c93b00a", + "1ecad7660e5a77a09661207c7a22f4a84f9be98ac520d4cc875d0caea2fc98f0ab67c2b4966a3ba1cefaa6013517b30f8c43327b46896111886fe2ba10e66d2589aea52e9bf3d72f630c051733fddba412e7c7768b80c8fb7b7e104156db0a0a", + } + + return privateKeys, publicKeys +} diff --git a/factory/crypto/errors.go b/factory/crypto/errors.go index 845d761e423..6d349d4eb2b 100644 --- a/factory/crypto/errors.go +++ b/factory/crypto/errors.go @@ -8,14 +8,11 @@ var ErrInvalidSignature = errors.New("invalid signature was provided") // ErrNilElement is raised when searching for a specific element but found nil var ErrNilElement = errors.New("element is nil") -// ErrIndexNotSelected is raised when a not selected index is used for multi-signing -var ErrIndexNotSelected = errors.New("index is not selected") - // ErrNilBitmap is raised when a nil bitmap is used var ErrNilBitmap = errors.New("bitmap is nil") -// ErrNoPrivateKeySet is raised when no private key was set -var ErrNoPrivateKeySet = errors.New("no private key was set") +// ErrNilKeysHandler is raised when a nil keys handler was provided +var ErrNilKeysHandler = errors.New("nil keys handler") // ErrNoPublicKeySet is raised when no public key was set for a multisignature var ErrNoPublicKeySet = errors.New("no public key was set") @@ -29,6 +26,9 @@ var ErrNilPublicKeys = errors.New("public keys are nil") // ErrNilMultiSignerContainer is raised when a nil multi signer container has been provided var ErrNilMultiSignerContainer = errors.New("multi signer container is nil") +// ErrNilSingleSigner signals that a nil single signer was provided +var ErrNilSingleSigner = errors.New("nil single signer") + // ErrIndexOutOfBounds is raised when an out of bound index is used var ErrIndexOutOfBounds = errors.New("index is out of bounds") diff --git a/factory/crypto/export_test.go b/factory/crypto/export_test.go index 1cac0e0d50f..1a62070f3a7 100644 --- a/factory/crypto/export_test.go +++ b/factory/crypto/export_test.go @@ -2,6 +2,7 @@ package crypto import ( crypto "github.com/multiversx/mx-chain-crypto-go" + "github.com/multiversx/mx-chain-go/common" cryptoCommon 
"github.com/multiversx/mx-chain-go/common/crypto" ) @@ -37,3 +38,8 @@ func (ccf *cryptoComponentsFactory) CreateMultiSignerContainer( func (ccf *cryptoComponentsFactory) GetSuite() (crypto.Suite, error) { return ccf.getSuite() } + +// GetManagedPeersHolder - +func (cc *cryptoComponents) GetManagedPeersHolder() common.ManagedPeersHolder { + return cc.managedPeersHolder +} diff --git a/factory/crypto/signatureHolder.go b/factory/crypto/signingHandler.go similarity index 65% rename from factory/crypto/signatureHolder.go rename to factory/crypto/signingHandler.go index 2e0d1ddf673..8b921fc1447 100644 --- a/factory/crypto/signatureHolder.go +++ b/factory/crypto/signingHandler.go @@ -6,32 +6,35 @@ import ( "github.com/multiversx/mx-chain-core-go/core/check" crypto "github.com/multiversx/mx-chain-crypto-go" cryptoCommon "github.com/multiversx/mx-chain-go/common/crypto" + "github.com/multiversx/mx-chain-go/consensus" ) -// ArgsSignatureHolder defines the arguments needed to create a new signature holder component -type ArgsSignatureHolder struct { +// ArgsSigningHandler defines the arguments needed to create a new signing handler component +type ArgsSigningHandler struct { PubKeys []string - PrivKeyBytes []byte MultiSignerContainer cryptoCommon.MultiSignerContainer + SingleSigner crypto.SingleSigner KeyGenerator crypto.KeyGenerator + KeysHandler consensus.KeysHandler } type signatureHolderData struct { pubKeys [][]byte - privKey []byte sigShares [][]byte aggSig []byte } -type signatureHolder struct { +type signingHandler struct { data *signatureHolderData mutSigningData sync.RWMutex multiSignerContainer cryptoCommon.MultiSignerContainer + singleSigner crypto.SingleSigner keyGen crypto.KeyGenerator + keysHandler consensus.KeysHandler } -// NewSignatureHolder will create a new signature holder component -func NewSignatureHolder(args ArgsSignatureHolder) (*signatureHolder, error) { +// NewSigningHandler will create a new signing handler component +func NewSigningHandler(args ArgsSigningHandler) (*signingHandler, error) { err := checkArgs(args) if err != nil { return nil, err @@ -47,24 +50,28 @@ func NewSignatureHolder(args ArgsSignatureHolder) (*signatureHolder, error) { data := &signatureHolderData{ pubKeys: pubKeysBytes, - privKey: args.PrivKeyBytes, sigShares: sigShares, } - return &signatureHolder{ + return &signingHandler{ data: data, mutSigningData: sync.RWMutex{}, multiSignerContainer: args.MultiSignerContainer, + singleSigner: args.SingleSigner, keyGen: args.KeyGenerator, + keysHandler: args.KeysHandler, }, nil } -func checkArgs(args ArgsSignatureHolder) error { +func checkArgs(args ArgsSigningHandler) error { if check.IfNil(args.MultiSignerContainer) { return ErrNilMultiSignerContainer } - if len(args.PrivKeyBytes) == 0 { - return ErrNoPrivateKeySet + if check.IfNil(args.SingleSigner) { + return ErrNilSingleSigner + } + if check.IfNil(args.KeysHandler) { + return ErrNilKeysHandler } if check.IfNil(args.KeyGenerator) { return ErrNilKeyGenerator @@ -77,22 +84,19 @@ func checkArgs(args ArgsSignatureHolder) error { } // Create generates a signature holder component and initializes corresponding fields -func (sh *signatureHolder) Create(pubKeys []string) (*signatureHolder, error) { - sh.mutSigningData.RLock() - privKey := sh.data.privKey - sh.mutSigningData.RUnlock() - - args := ArgsSignatureHolder{ +func (sh *signingHandler) Create(pubKeys []string) (*signingHandler, error) { + args := ArgsSigningHandler{ PubKeys: pubKeys, - PrivKeyBytes: privKey, + KeysHandler: sh.keysHandler, 
MultiSignerContainer: sh.multiSignerContainer, + SingleSigner: sh.singleSigner, KeyGenerator: sh.keyGen, } - return NewSignatureHolder(args) + return NewSigningHandler(args) } // Reset resets the data inside the signature holder component -func (sh *signatureHolder) Reset(pubKeys []string) error { +func (sh *signingHandler) Reset(pubKeys []string) error { if pubKeys == nil { return ErrNilPublicKeys } @@ -107,11 +111,8 @@ func (sh *signatureHolder) Reset(pubKeys []string) error { sh.mutSigningData.Lock() defer sh.mutSigningData.Unlock() - privKey := sh.data.privKey - data := &signatureHolderData{ pubKeys: pubKeysBytes, - privKey: privKey, sigShares: sigShares, } @@ -120,12 +121,19 @@ func (sh *signatureHolder) Reset(pubKeys []string) error { return nil } -// CreateSignatureShare returns a signature over a message -func (sh *signatureHolder) CreateSignatureShare(message []byte, selfIndex uint16, epoch uint32) ([]byte, error) { +// CreateSignatureShareForPublicKey returns a signature over a message using the managed private key that was selected based on the provided +// publicKeyBytes argument +func (sh *signingHandler) CreateSignatureShareForPublicKey(message []byte, index uint16, epoch uint32, publicKeyBytes []byte) ([]byte, error) { if message == nil { return nil, ErrNilMessage } + privateKey := sh.keysHandler.GetHandledPrivateKey(publicKeyBytes) + privateKeyBytes, err := privateKey.ToByteArray() + if err != nil { + return nil, err + } + sh.mutSigningData.Lock() defer sh.mutSigningData.Unlock() @@ -134,18 +142,36 @@ func (sh *signatureHolder) CreateSignatureShare(message []byte, selfIndex uint16 return nil, err } - sigShareBytes, err := multiSigner.CreateSignatureShare(sh.data.privKey, message) + sigShareBytes, err := multiSigner.CreateSignatureShare(privateKeyBytes, message) if err != nil { return nil, err } - sh.data.sigShares[selfIndex] = sigShareBytes + sh.data.sigShares[index] = sigShareBytes return sigShareBytes, nil } +// CreateSignatureForPublicKey returns a signature over a message using the managed private key that was selected based on the provided +// publicKeyBytes argument +func (sh *signingHandler) CreateSignatureForPublicKey(message []byte, publicKeyBytes []byte) ([]byte, error) { + privateKey := sh.keysHandler.GetHandledPrivateKey(publicKeyBytes) + + return sh.singleSigner.Sign(privateKey, message) +} + +// VerifySingleSignature returns an error if the public key bytes & message provided doesn't match with the signature +func (sh *signingHandler) VerifySingleSignature(publicKeyBytes []byte, message []byte, signature []byte) error { + pk, err := sh.keyGen.PublicKeyFromByteArray(publicKeyBytes) + if err != nil { + return err + } + + return sh.singleSigner.Verify(pk, message, signature) +} + // VerifySignatureShare will verify the signature share based on the specified index -func (sh *signatureHolder) VerifySignatureShare(index uint16, sig []byte, message []byte, epoch uint32) error { +func (sh *signingHandler) VerifySignatureShare(index uint16, sig []byte, message []byte, epoch uint32) error { if len(sig) == 0 { return ErrInvalidSignature } @@ -169,7 +195,7 @@ func (sh *signatureHolder) VerifySignatureShare(index uint16, sig []byte, messag } // StoreSignatureShare stores the partial signature of the signer with specified position -func (sh *signatureHolder) StoreSignatureShare(index uint16, sig []byte) error { +func (sh *signingHandler) StoreSignatureShare(index uint16, sig []byte) error { if len(sig) == 0 { return ErrInvalidSignature } @@ -187,7 +213,7 @@ func (sh 
*signatureHolder) StoreSignatureShare(index uint16, sig []byte) error { } // SignatureShare returns the partial signature set for given index -func (sh *signatureHolder) SignatureShare(index uint16) ([]byte, error) { +func (sh *signingHandler) SignatureShare(index uint16) ([]byte, error) { sh.mutSigningData.RLock() defer sh.mutSigningData.RUnlock() @@ -203,7 +229,7 @@ func (sh *signatureHolder) SignatureShare(index uint16) ([]byte, error) { } // not concurrent safe, should be used under RLock mutex -func (sh *signatureHolder) isIndexInBitmap(index uint16, bitmap []byte) bool { +func (sh *signingHandler) isIndexInBitmap(index uint16, bitmap []byte) bool { indexOutOfBounds := index >= uint16(len(sh.data.pubKeys)) if indexOutOfBounds { return false @@ -215,7 +241,7 @@ func (sh *signatureHolder) isIndexInBitmap(index uint16, bitmap []byte) bool { } // AggregateSigs aggregates all collected partial signatures -func (sh *signatureHolder) AggregateSigs(bitmap []byte, epoch uint32) ([]byte, error) { +func (sh *signingHandler) AggregateSigs(bitmap []byte, epoch uint32) ([]byte, error) { if bitmap == nil { return nil, ErrNilBitmap } @@ -250,7 +276,7 @@ func (sh *signatureHolder) AggregateSigs(bitmap []byte, epoch uint32) ([]byte, e } // SetAggregatedSig sets the aggregated signature -func (sh *signatureHolder) SetAggregatedSig(aggSig []byte) error { +func (sh *signingHandler) SetAggregatedSig(aggSig []byte) error { sh.mutSigningData.Lock() defer sh.mutSigningData.Unlock() @@ -261,7 +287,7 @@ func (sh *signatureHolder) SetAggregatedSig(aggSig []byte) error { // Verify verifies the aggregated signature by checking that aggregated signature is valid with respect // to aggregated public keys. -func (sh *signatureHolder) Verify(message []byte, bitmap []byte, epoch uint32) error { +func (sh *signingHandler) Verify(message []byte, bitmap []byte, epoch uint32) error { if bitmap == nil { return ErrNilBitmap } @@ -308,6 +334,6 @@ func convertStringsToPubKeysBytes(pubKeys []string) ([][]byte, error) { } // IsInterfaceNil returns true if there is no value under the interface -func (sh *signatureHolder) IsInterfaceNil() bool { +func (sh *signingHandler) IsInterfaceNil() bool { return sh == nil } diff --git a/factory/crypto/signatureHolder_test.go b/factory/crypto/signingHandler_test.go similarity index 61% rename from factory/crypto/signatureHolder_test.go rename to factory/crypto/signingHandler_test.go index da99c646d2e..41033b904ab 100644 --- a/factory/crypto/signatureHolder_test.go +++ b/factory/crypto/signingHandler_test.go @@ -7,85 +7,94 @@ import ( "github.com/multiversx/mx-chain-core-go/core/check" crypto "github.com/multiversx/mx-chain-crypto-go" cryptoFactory "github.com/multiversx/mx-chain-go/factory/crypto" + "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/cryptoMocks" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -func createMockArgsSignatureHolder() cryptoFactory.ArgsSignatureHolder { - return cryptoFactory.ArgsSignatureHolder{ +func createMockArgsSigningHandler() cryptoFactory.ArgsSigningHandler { + return cryptoFactory.ArgsSigningHandler{ PubKeys: []string{"pubkey1"}, - PrivKeyBytes: []byte("privKey"), + KeysHandler: &testscommon.KeysHandlerStub{}, MultiSignerContainer: &cryptoMocks.MultiSignerContainerMock{}, KeyGenerator: &cryptoMocks.KeyGenStub{}, + SingleSigner: &cryptoMocks.SingleSignerStub{}, } } -func TestNewSigner(t *testing.T) { +func TestNewSigningHandler(t *testing.T) { t.Parallel() t.Run("nil multi signer", 
func(t *testing.T) { t.Parallel() - args := createMockArgsSignatureHolder() + args := createMockArgsSigningHandler() args.MultiSignerContainer = nil - signer, err := cryptoFactory.NewSignatureHolder(args) + signer, err := cryptoFactory.NewSigningHandler(args) require.Nil(t, signer) require.Equal(t, cryptoFactory.ErrNilMultiSignerContainer, err) }) + t.Run("nil single signer", func(t *testing.T) { + t.Parallel() + + args := createMockArgsSigningHandler() + args.SingleSigner = nil + signer, err := cryptoFactory.NewSigningHandler(args) + require.Nil(t, signer) + require.Equal(t, cryptoFactory.ErrNilSingleSigner, err) + }) t.Run("nil key generator", func(t *testing.T) { t.Parallel() - args := createMockArgsSignatureHolder() + args := createMockArgsSigningHandler() args.KeyGenerator = nil - signer, err := cryptoFactory.NewSignatureHolder(args) + signer, err := cryptoFactory.NewSigningHandler(args) require.Nil(t, signer) require.Equal(t, cryptoFactory.ErrNilKeyGenerator, err) }) - - t.Run("nil private key", func(t *testing.T) { + t.Run("nil keys handler", func(t *testing.T) { t.Parallel() - args := createMockArgsSignatureHolder() - args.PrivKeyBytes = nil + args := createMockArgsSigningHandler() + args.KeysHandler = nil - signer, err := cryptoFactory.NewSignatureHolder(args) + signer, err := cryptoFactory.NewSigningHandler(args) require.Nil(t, signer) - require.Equal(t, cryptoFactory.ErrNoPrivateKeySet, err) + require.Equal(t, cryptoFactory.ErrNilKeysHandler, err) }) - t.Run("no public keys", func(t *testing.T) { t.Parallel() - args := createMockArgsSignatureHolder() + args := createMockArgsSigningHandler() args.PubKeys = []string{} - signer, err := cryptoFactory.NewSignatureHolder(args) + signer, err := cryptoFactory.NewSigningHandler(args) require.Nil(t, signer) require.Equal(t, cryptoFactory.ErrNoPublicKeySet, err) }) - t.Run("should work", func(t *testing.T) { t.Parallel() - args := createMockArgsSignatureHolder() - signer, err := cryptoFactory.NewSignatureHolder(args) + args := createMockArgsSigningHandler() + signer, err := cryptoFactory.NewSigningHandler(args) require.Nil(t, err) require.False(t, check.IfNil(signer)) }) } -func TestSignatureHolder_Create(t *testing.T) { +func TestSigningHandler_Create(t *testing.T) { t.Parallel() t.Run("empty pubkeys in list", func(t *testing.T) { t.Parallel() - args := createMockArgsSignatureHolder() + args := createMockArgsSigningHandler() - signer, err := cryptoFactory.NewSignatureHolder(args) + signer, err := cryptoFactory.NewSigningHandler(args) require.Nil(t, err) require.NotNil(t, signer) @@ -94,13 +103,12 @@ func TestSignatureHolder_Create(t *testing.T) { require.Nil(t, createdSigner) require.Equal(t, cryptoFactory.ErrEmptyPubKeyString, err) }) - t.Run("should work", func(t *testing.T) { t.Parallel() - args := createMockArgsSignatureHolder() + args := createMockArgsSigningHandler() - signer, err := cryptoFactory.NewSignatureHolder(args) + signer, err := cryptoFactory.NewSigningHandler(args) require.Nil(t, err) require.NotNil(t, signer) @@ -111,59 +119,57 @@ func TestSignatureHolder_Create(t *testing.T) { }) } -func TestSignatureHolder_Reset(t *testing.T) { +func TestSigningHandler_Reset(t *testing.T) { t.Parallel() t.Run("nil public keys", func(t *testing.T) { t.Parallel() - args := createMockArgsSignatureHolder() + args := createMockArgsSigningHandler() - signer, _ := cryptoFactory.NewSignatureHolder(args) + signer, _ := cryptoFactory.NewSigningHandler(args) err := signer.Reset(nil) require.Equal(t, cryptoFactory.ErrNilPublicKeys, err) }) - 
t.Run("empty pubkeys in list", func(t *testing.T) { t.Parallel() - args := createMockArgsSignatureHolder() + args := createMockArgsSigningHandler() - signer, _ := cryptoFactory.NewSignatureHolder(args) + signer, _ := cryptoFactory.NewSigningHandler(args) err := signer.Reset([]string{"pubKey1", ""}) require.Equal(t, cryptoFactory.ErrEmptyPubKeyString, err) }) - t.Run("should work", func(t *testing.T) { t.Parallel() - args := createMockArgsSignatureHolder() + args := createMockArgsSigningHandler() - signer, _ := cryptoFactory.NewSignatureHolder(args) + signer, _ := cryptoFactory.NewSigningHandler(args) err := signer.Reset([]string{"pubKey1", "pubKey2"}) require.Nil(t, err) }) } -func TestSignatureHolder_CreateSignatureShare(t *testing.T) { +func TestSigningHandler_CreateSignatureShareForPublicKey(t *testing.T) { t.Parallel() selfIndex := uint16(0) epoch := uint32(0) + pkBytes := []byte("public key bytes") t.Run("nil message", func(t *testing.T) { t.Parallel() - signer, _ := cryptoFactory.NewSignatureHolder(createMockArgsSignatureHolder()) - sigShare, err := signer.CreateSignatureShare(nil, selfIndex, epoch) + signer, _ := cryptoFactory.NewSigningHandler(createMockArgsSigningHandler()) + sigShare, err := signer.CreateSignatureShareForPublicKey(nil, selfIndex, epoch, pkBytes) require.Nil(t, sigShare) require.Equal(t, cryptoFactory.ErrNilMessage, err) }) - t.Run("create sig share failed", func(t *testing.T) { t.Parallel() - args := createMockArgsSignatureHolder() + args := createMockArgsSigningHandler() expectedErr := errors.New("expected error") multiSigner := &cryptoMocks.MultiSignerStub{ @@ -173,16 +179,34 @@ func TestSignatureHolder_CreateSignatureShare(t *testing.T) { } args.MultiSignerContainer = cryptoMocks.NewMultiSignerContainerMock(multiSigner) - signer, _ := cryptoFactory.NewSignatureHolder(args) - sigShare, err := signer.CreateSignatureShare([]byte("msg1"), selfIndex, epoch) + signer, _ := cryptoFactory.NewSigningHandler(args) + sigShare, err := signer.CreateSignatureShareForPublicKey([]byte("msg1"), selfIndex, epoch, pkBytes) require.Nil(t, sigShare) require.Equal(t, expectedErr, err) }) + t.Run("failed to get current multi signer", func(t *testing.T) { + t.Parallel() + args := createMockArgsSigningHandler() + + expectedErr := errors.New("expected error") + args.MultiSignerContainer = &cryptoMocks.MultiSignerContainerStub{ + GetMultiSignerCalled: func(epoch uint32) (crypto.MultiSigner, error) { + return nil, expectedErr + }, + } + + signer, _ := cryptoFactory.NewSigningHandler(args) + + sigShare, err := signer.CreateSignatureShareForPublicKey([]byte("message"), uint16(0), epoch, pkBytes) + require.Nil(t, sigShare) + require.Equal(t, expectedErr, err) + }) t.Run("should work", func(t *testing.T) { t.Parallel() - args := createMockArgsSignatureHolder() + args := createMockArgsSigningHandler() + getHandledPrivateKeyCalled := false expectedSigShare := []byte("sigShare") multiSigner := &cryptoMocks.MultiSignerStub{ @@ -190,16 +214,25 @@ func TestSignatureHolder_CreateSignatureShare(t *testing.T) { return expectedSigShare, nil }, } + args.KeysHandler = &testscommon.KeysHandlerStub{ + GetHandledPrivateKeyCalled: func(providedPkBytes []byte) crypto.PrivateKey { + assert.Equal(t, pkBytes, providedPkBytes) + getHandledPrivateKeyCalled = true + + return &cryptoMocks.PrivateKeyStub{} + }, + } args.MultiSignerContainer = cryptoMocks.NewMultiSignerContainerMock(multiSigner) - signer, _ := cryptoFactory.NewSignatureHolder(args) - sigShare, err := signer.CreateSignatureShare([]byte("msg1"), 
selfIndex, epoch) + signer, _ := cryptoFactory.NewSigningHandler(args) + sigShare, err := signer.CreateSignatureShareForPublicKey([]byte("msg1"), selfIndex, epoch, pkBytes) require.Nil(t, err) require.Equal(t, expectedSigShare, sigShare) + assert.True(t, getHandledPrivateKeyCalled) }) } -func TestSignatureHolder_VerifySignatureShare(t *testing.T) { +func TestSigningHandler_VerifySignatureShare(t *testing.T) { t.Parallel() ownIndex := uint16(1) @@ -209,23 +242,21 @@ func TestSignatureHolder_VerifySignatureShare(t *testing.T) { t.Run("invalid signature share", func(t *testing.T) { t.Parallel() - signer, _ := cryptoFactory.NewSignatureHolder(createMockArgsSignatureHolder()) + signer, _ := cryptoFactory.NewSigningHandler(createMockArgsSigningHandler()) err := signer.VerifySignatureShare(ownIndex, nil, msg, epoch) require.Equal(t, cryptoFactory.ErrInvalidSignature, err) }) - t.Run("index out of bounds", func(t *testing.T) { t.Parallel() - signer, _ := cryptoFactory.NewSignatureHolder(createMockArgsSignatureHolder()) + signer, _ := cryptoFactory.NewSigningHandler(createMockArgsSigningHandler()) err := signer.VerifySignatureShare(uint16(3), []byte("sigShare"), msg, epoch) require.Equal(t, cryptoFactory.ErrIndexOutOfBounds, err) }) - t.Run("signature share verification failed", func(t *testing.T) { t.Parallel() - args := createMockArgsSignatureHolder() + args := createMockArgsSigningHandler() args.PubKeys = []string{"pk1", "pk2"} expectedErr := errors.New("expected error") @@ -236,16 +267,15 @@ func TestSignatureHolder_VerifySignatureShare(t *testing.T) { } args.MultiSignerContainer = cryptoMocks.NewMultiSignerContainerMock(multiSigner) - signer, _ := cryptoFactory.NewSignatureHolder(args) + signer, _ := cryptoFactory.NewSigningHandler(args) err := signer.VerifySignatureShare(uint16(1), []byte("sigShare"), msg, epoch) require.Equal(t, expectedErr, err) }) - t.Run("failed to get current multi signer", func(t *testing.T) { t.Parallel() - args := createMockArgsSignatureHolder() + args := createMockArgsSigningHandler() args.PubKeys = []string{"pk1", "pk2"} expectedErr := errors.New("expected error") @@ -255,16 +285,15 @@ func TestSignatureHolder_VerifySignatureShare(t *testing.T) { }, } - signer, _ := cryptoFactory.NewSignatureHolder(args) + signer, _ := cryptoFactory.NewSigningHandler(args) err := signer.VerifySignatureShare(uint16(1), []byte("sigShare"), msg, epoch) require.Equal(t, expectedErr, err) }) - t.Run("should work", func(t *testing.T) { t.Parallel() - args := createMockArgsSignatureHolder() + args := createMockArgsSigningHandler() args.PubKeys = []string{"pk1", "pk2"} multiSigner := &cryptoMocks.MultiSignerStub{ @@ -274,51 +303,31 @@ func TestSignatureHolder_VerifySignatureShare(t *testing.T) { } args.MultiSignerContainer = cryptoMocks.NewMultiSignerContainerMock(multiSigner) - signer, _ := cryptoFactory.NewSignatureHolder(args) + signer, _ := cryptoFactory.NewSigningHandler(args) err := signer.VerifySignatureShare(uint16(1), []byte("sigShare"), msg, epoch) require.Nil(t, err) }) } -func TestSignatureHolder_StoreSignatureShare(t *testing.T) { +func TestSigningHandler_StoreSignatureShare(t *testing.T) { t.Parallel() ownIndex := uint16(2) - epoch := uint32(0) - msg := []byte("message") t.Run("index out of bounds", func(t *testing.T) { t.Parallel() - signer, _ := cryptoFactory.NewSignatureHolder(createMockArgsSignatureHolder()) - err := signer.StoreSignatureShare(uint16(2), []byte("sigShare")) - require.Equal(t, cryptoFactory.ErrIndexOutOfBounds, err) - }) - - t.Run("failed to get current 
multi signer", func(t *testing.T) { - t.Parallel() - - args := createMockArgsSignatureHolder() - - expectedErr := errors.New("expected error") - args.MultiSignerContainer = &cryptoMocks.MultiSignerContainerStub{ - GetMultiSignerCalled: func(epoch uint32) (crypto.MultiSigner, error) { - return nil, expectedErr - }, - } - - signer, _ := cryptoFactory.NewSignatureHolder(args) + signer, err := cryptoFactory.NewSigningHandler(createMockArgsSigningHandler()) + require.Nil(t, err) - sigShare, err := signer.CreateSignatureShare(msg, uint16(0), epoch) - require.Nil(t, sigShare) - require.Equal(t, expectedErr, err) + err = signer.StoreSignatureShare(uint16(2), []byte("sigShare")) + require.Equal(t, cryptoFactory.ErrIndexOutOfBounds, err) }) - t.Run("should work", func(t *testing.T) { t.Parallel() - args := createMockArgsSignatureHolder() + args := createMockArgsSigningHandler() args.PubKeys = []string{"pk1", "pk2", "pk3", "pk4"} multiSigner := &cryptoMocks.MultiSignerStub{ @@ -328,12 +337,11 @@ func TestSignatureHolder_StoreSignatureShare(t *testing.T) { } args.MultiSignerContainer = cryptoMocks.NewMultiSignerContainerMock(multiSigner) - signer, _ := cryptoFactory.NewSignatureHolder(args) + signer, _ := cryptoFactory.NewSigningHandler(args) - sigShare, err := signer.CreateSignatureShare(msg, uint16(0), epoch) - require.Nil(t, err) + sigShare := []byte("signature share") - err = signer.StoreSignatureShare(ownIndex, sigShare) + err := signer.StoreSignatureShare(ownIndex, sigShare) require.Nil(t, err) sigShareRead, err := signer.SignatureShare(ownIndex) @@ -342,7 +350,7 @@ func TestSignatureHolder_StoreSignatureShare(t *testing.T) { }) } -func TestSignatureHolder_SignatureShare(t *testing.T) { +func TestSigningHandler_SignatureShare(t *testing.T) { t.Parallel() t.Run("index out of bounds", func(t *testing.T) { @@ -351,9 +359,9 @@ func TestSignatureHolder_SignatureShare(t *testing.T) { index := uint16(1) sigShare := []byte("sig share") - args := createMockArgsSignatureHolder() + args := createMockArgsSigningHandler() - signer, _ := cryptoFactory.NewSignatureHolder(args) + signer, _ := cryptoFactory.NewSigningHandler(args) _ = signer.StoreSignatureShare(index, sigShare) @@ -361,16 +369,15 @@ func TestSignatureHolder_SignatureShare(t *testing.T) { require.Nil(t, sigShareRead) require.Equal(t, cryptoFactory.ErrIndexOutOfBounds, err) }) - t.Run("nil element at index", func(t *testing.T) { t.Parallel() ownIndex := uint16(1) - args := createMockArgsSignatureHolder() + args := createMockArgsSigningHandler() args.PubKeys = []string{"pk1", "pk2", "pk3", "pk4"} - signer, _ := cryptoFactory.NewSignatureHolder(args) + signer, _ := cryptoFactory.NewSigningHandler(args) _ = signer.StoreSignatureShare(ownIndex, nil) @@ -378,17 +385,16 @@ func TestSignatureHolder_SignatureShare(t *testing.T) { require.Nil(t, sigShareRead) require.Equal(t, cryptoFactory.ErrNilElement, err) }) - t.Run("should work", func(t *testing.T) { t.Parallel() ownIndex := uint16(1) sigShare := []byte("sig share") - args := createMockArgsSignatureHolder() + args := createMockArgsSigningHandler() args.PubKeys = []string{"pk1", "pk2", "pk3", "pk4"} - signer, _ := cryptoFactory.NewSignatureHolder(args) + signer, _ := cryptoFactory.NewSigningHandler(args) _ = signer.StoreSignatureShare(ownIndex, sigShare) @@ -398,7 +404,7 @@ func TestSignatureHolder_SignatureShare(t *testing.T) { }) } -func TestSignatureHolder_AggregateSigs(t *testing.T) { +func TestSigningHandler_AggregateSigs(t *testing.T) { t.Parallel() epoch := uint32(0) @@ -406,40 +412,38 @@ func 
TestSignatureHolder_AggregateSigs(t *testing.T) { t.Run("nil bitmap", func(t *testing.T) { t.Parallel() - args := createMockArgsSignatureHolder() + args := createMockArgsSigningHandler() args.PubKeys = []string{"pk1", "pk2", "pk3", "pk4"} - signer, _ := cryptoFactory.NewSignatureHolder(args) + signer, _ := cryptoFactory.NewSigningHandler(args) aggSig, err := signer.AggregateSigs(nil, epoch) require.Nil(t, aggSig) require.Equal(t, cryptoFactory.ErrNilBitmap, err) }) - t.Run("bitmap mismatch", func(t *testing.T) { t.Parallel() bitmap := make([]byte, 1) bitmap[0] = 0x07 - args := createMockArgsSignatureHolder() + args := createMockArgsSigningHandler() args.PubKeys = []string{"pk1", "pk2", "pk3", "pk4", "pk5", "pk6", "pk7", "pk8", "pk9"} - signer, _ := cryptoFactory.NewSignatureHolder(args) + signer, _ := cryptoFactory.NewSigningHandler(args) aggSig, err := signer.AggregateSigs(bitmap, epoch) require.Nil(t, aggSig) require.Equal(t, cryptoFactory.ErrBitmapMismatch, err) }) - t.Run("failed to get aggregated sig", func(t *testing.T) { t.Parallel() bitmap := make([]byte, 1) bitmap[0] = 0x07 - args := createMockArgsSignatureHolder() + args := createMockArgsSigningHandler() args.PubKeys = []string{"pk1", "pk2", "pk3", "pk4"} expectedErr := errors.New("expected error") @@ -450,7 +454,7 @@ func TestSignatureHolder_AggregateSigs(t *testing.T) { } args.MultiSignerContainer = cryptoMocks.NewMultiSignerContainerMock(multiSigner) - signer, _ := cryptoFactory.NewSignatureHolder(args) + signer, _ := cryptoFactory.NewSigningHandler(args) for i := 0; i < len(args.PubKeys); i++ { _ = signer.StoreSignatureShare(uint16(i), []byte("sigShare")) @@ -460,14 +464,13 @@ func TestSignatureHolder_AggregateSigs(t *testing.T) { require.Nil(t, aggSig) require.Equal(t, expectedErr, err) }) - t.Run("failed to get current multi signer", func(t *testing.T) { t.Parallel() bitmap := make([]byte, 1) bitmap[0] = 0x07 - args := createMockArgsSignatureHolder() + args := createMockArgsSigningHandler() expectedErr := errors.New("expected error") args.MultiSignerContainer = &cryptoMocks.MultiSignerContainerStub{ @@ -476,20 +479,19 @@ func TestSignatureHolder_AggregateSigs(t *testing.T) { }, } - signer, _ := cryptoFactory.NewSignatureHolder(args) + signer, _ := cryptoFactory.NewSigningHandler(args) aggSig, err := signer.AggregateSigs(bitmap, epoch) require.Nil(t, aggSig) require.Equal(t, expectedErr, err) }) - t.Run("should work", func(t *testing.T) { t.Parallel() bitmap := make([]byte, 1) bitmap[0] = 0x07 - args := createMockArgsSignatureHolder() + args := createMockArgsSigningHandler() args.PubKeys = []string{"pk1", "pk2", "pk3", "pk4"} expectedAggSig := []byte("agg sig") @@ -502,7 +504,7 @@ func TestSignatureHolder_AggregateSigs(t *testing.T) { } args.MultiSignerContainer = cryptoMocks.NewMultiSignerContainerMock(multiSigner) - signer, _ := cryptoFactory.NewSignatureHolder(args) + signer, _ := cryptoFactory.NewSigningHandler(args) for i := 0; i < len(args.PubKeys); i++ { _ = signer.StoreSignatureShare(uint16(i), []byte("sigShare")) @@ -514,7 +516,7 @@ func TestSignatureHolder_AggregateSigs(t *testing.T) { }) } -func TestSignatureHolder_Verify(t *testing.T) { +func TestSigningHandler_Verify(t *testing.T) { t.Parallel() message := []byte("message") @@ -523,38 +525,36 @@ func TestSignatureHolder_Verify(t *testing.T) { t.Run("verify agg sig should fail", func(t *testing.T) { t.Parallel() - args := createMockArgsSignatureHolder() + args := createMockArgsSigningHandler() args.PubKeys = []string{"pk1", "pk2", "pk3", "pk4"} - signer, _ := 
cryptoFactory.NewSignatureHolder(args) + signer, _ := cryptoFactory.NewSigningHandler(args) err := signer.Verify(message, nil, epoch) require.Equal(t, cryptoFactory.ErrNilBitmap, err) }) - t.Run("bitmap mismatch", func(t *testing.T) { t.Parallel() bitmap := make([]byte, 1) bitmap[0] = 0x07 - args := createMockArgsSignatureHolder() + args := createMockArgsSigningHandler() args.PubKeys = []string{"pk1", "pk2", "pk3", "pk4", "pk5", "pk6", "pk7", "pk8", "pk9"} - signer, _ := cryptoFactory.NewSignatureHolder(args) + signer, _ := cryptoFactory.NewSigningHandler(args) err := signer.Verify(message, bitmap, epoch) require.Equal(t, cryptoFactory.ErrBitmapMismatch, err) }) - t.Run("verify agg sig should fail", func(t *testing.T) { t.Parallel() bitmap := make([]byte, 1) bitmap[0] = 0x07 - args := createMockArgsSignatureHolder() + args := createMockArgsSigningHandler() args.PubKeys = []string{"pk1", "pk2", "pk3", "pk4"} expectedErr := errors.New("expected error") @@ -565,19 +565,18 @@ func TestSignatureHolder_Verify(t *testing.T) { } args.MultiSignerContainer = cryptoMocks.NewMultiSignerContainerMock(multiSigner) - signer, _ := cryptoFactory.NewSignatureHolder(args) + signer, _ := cryptoFactory.NewSigningHandler(args) err := signer.Verify(message, bitmap, epoch) require.Equal(t, expectedErr, err) }) - t.Run("failed to get current multi signer", func(t *testing.T) { t.Parallel() bitmap := make([]byte, 1) bitmap[0] = 0x07 - args := createMockArgsSignatureHolder() + args := createMockArgsSigningHandler() expectedErr := errors.New("expected error") args.MultiSignerContainer = &cryptoMocks.MultiSignerContainerStub{ @@ -586,19 +585,18 @@ func TestSignatureHolder_Verify(t *testing.T) { }, } - signer, _ := cryptoFactory.NewSignatureHolder(args) + signer, _ := cryptoFactory.NewSigningHandler(args) err := signer.Verify(message, bitmap, epoch) require.Equal(t, expectedErr, err) }) - t.Run("should work", func(t *testing.T) { t.Parallel() bitmap := make([]byte, 1) bitmap[0] = 0x07 - args := createMockArgsSignatureHolder() + args := createMockArgsSigningHandler() args.PubKeys = []string{"pk1", "pk2", "pk3", "pk4"} expAggSig := []byte("aggSig") @@ -612,7 +610,7 @@ func TestSignatureHolder_Verify(t *testing.T) { } args.MultiSignerContainer = cryptoMocks.NewMultiSignerContainerMock(multiSigner) - signer, _ := cryptoFactory.NewSignatureHolder(args) + signer, _ := cryptoFactory.NewSigningHandler(args) _ = signer.SetAggregatedSig(expAggSig) @@ -620,3 +618,86 @@ func TestSignatureHolder_Verify(t *testing.T) { require.Nil(t, err) }) } + +func TestSigningHandler_CreateSignatureForPublicKey(t *testing.T) { + t.Parallel() + + args := createMockArgsSigningHandler() + getHandledPrivateKeyCalled := false + pkBytes := []byte("public key bytes") + + expectedSigShare := []byte("sigShare") + args.KeysHandler = &testscommon.KeysHandlerStub{ + GetHandledPrivateKeyCalled: func(providedPkBytes []byte) crypto.PrivateKey { + assert.Equal(t, pkBytes, providedPkBytes) + getHandledPrivateKeyCalled = true + + return &cryptoMocks.PrivateKeyStub{} + }, + } + args.SingleSigner = &cryptoMocks.SingleSignerStub{ + SignCalled: func(private crypto.PrivateKey, msg []byte) ([]byte, error) { + return expectedSigShare, nil + }, + } + + signer, _ := cryptoFactory.NewSigningHandler(args) + sigShare, err := signer.CreateSignatureForPublicKey([]byte("msg1"), pkBytes) + require.Nil(t, err) + require.Equal(t, expectedSigShare, sigShare) + assert.True(t, getHandledPrivateKeyCalled) +} + +func TestSigningHandler_VerifySingleSignature(t *testing.T) { + 
t.Parallel() + + t.Run("not a valid public key should error", func(t *testing.T) { + t.Parallel() + + expectedErr := errors.New("expected error") + args := createMockArgsSigningHandler() + args.KeyGenerator = &cryptoMocks.KeyGenStub{ + PublicKeyFromByteArrayStub: func(b []byte) (crypto.PublicKey, error) { + return nil, expectedErr + }, + } + + signer, _ := cryptoFactory.NewSigningHandler(args) + + err := signer.VerifySingleSignature([]byte("pk"), []byte("msg"), []byte("sig")) + assert.Equal(t, expectedErr, err) + }) + t.Run("should work", func(t *testing.T) { + t.Parallel() + + providedPkBytes := []byte("pk") + providedMsg := []byte("msg") + providedSig := []byte("sig") + pk := &cryptoMocks.PublicKeyStub{} + + verifyCalled := false + args := createMockArgsSigningHandler() + args.KeyGenerator = &cryptoMocks.KeyGenStub{ + PublicKeyFromByteArrayStub: func(b []byte) (crypto.PublicKey, error) { + assert.Equal(t, providedPkBytes, b) + return pk, nil + }, + } + args.SingleSigner = &cryptoMocks.SingleSignerStub{ + VerifyCalled: func(public crypto.PublicKey, msg []byte, sig []byte) error { + assert.Equal(t, pk, public) + assert.Equal(t, providedMsg, msg) + assert.Equal(t, providedSig, sig) + verifyCalled = true + + return nil + }, + } + + signer, _ := cryptoFactory.NewSigningHandler(args) + + err := signer.VerifySingleSignature(providedPkBytes, providedMsg, providedSig) + assert.Nil(t, err) + assert.True(t, verifyCalled) + }) +} diff --git a/factory/data/dataComponents.go b/factory/data/dataComponents.go index 70348f9ff81..8179e5db715 100644 --- a/factory/data/dataComponents.go +++ b/factory/data/dataComponents.go @@ -25,9 +25,10 @@ type DataComponentsFactoryArgs struct { ShardCoordinator sharding.Coordinator Core factory.CoreComponentsHolder StatusCore factory.StatusCoreComponentsHolder - EpochStartNotifier factory.EpochStartNotifier + Crypto factory.CryptoComponentsHolder CurrentEpoch uint32 CreateTrieEpochRootHashStorer bool + SnapshotsEnabled bool } type dataComponentsFactory struct { @@ -35,10 +36,11 @@ type dataComponentsFactory struct { prefsConfig config.PreferencesConfig shardCoordinator sharding.Coordinator core factory.CoreComponentsHolder - epochStartNotifier factory.EpochStartNotifier statusCore factory.StatusCoreComponentsHolder + crypto factory.CryptoComponentsHolder currentEpoch uint32 createTrieEpochRootHashStorer bool + snapshotsEnabled bool } // dataComponents struct holds the data components @@ -62,7 +64,7 @@ func NewDataComponentsFactory(args DataComponentsFactoryArgs) (*dataComponentsFa if check.IfNil(args.Core.PathHandler()) { return nil, errors.ErrNilPathHandler } - if check.IfNil(args.EpochStartNotifier) { + if check.IfNil(args.Core.EpochStartNotifierWithConfirm()) { return nil, errors.ErrNilEpochStartNotifier } if check.IfNil(args.Core.EconomicsData()) { @@ -74,6 +76,12 @@ func NewDataComponentsFactory(args DataComponentsFactoryArgs) (*dataComponentsFa if check.IfNil(args.StatusCore.AppStatusHandler()) { return nil, errors.ErrNilAppStatusHandler } + if check.IfNil(args.Crypto) { + return nil, errors.ErrNilCryptoComponents + } + if check.IfNil(args.Crypto.ManagedPeersHolder()) { + return nil, errors.ErrNilManagedPeersHolder + } return &dataComponentsFactory{ config: args.Config, @@ -81,9 +89,10 @@ func NewDataComponentsFactory(args DataComponentsFactoryArgs) (*dataComponentsFa shardCoordinator: args.ShardCoordinator, core: args.Core, statusCore: args.StatusCore, - epochStartNotifier: args.EpochStartNotifier, currentEpoch: args.CurrentEpoch, createTrieEpochRootHashStorer: 
args.CreateTrieEpochRootHashStorer, + snapshotsEnabled: args.SnapshotsEnabled, + crypto: args.Crypto, }, nil } @@ -167,11 +176,13 @@ func (dcf *dataComponentsFactory) createDataStoreFromConfig() (dataRetriever.Sto PrefsConfig: dcf.prefsConfig, ShardCoordinator: dcf.shardCoordinator, PathManager: dcf.core.PathHandler(), - EpochStartNotifier: dcf.epochStartNotifier, + EpochStartNotifier: dcf.core.EpochStartNotifierWithConfirm(), NodeTypeProvider: dcf.core.NodeTypeProvider(), CurrentEpoch: dcf.currentEpoch, StorageType: storageFactory.ProcessStorageService, CreateTrieEpochRootHashStorer: dcf.createTrieEpochRootHashStorer, + SnapshotsEnabled: dcf.snapshotsEnabled, + ManagedPeersHolder: dcf.crypto.ManagedPeersHolder(), }) if err != nil { return nil, err diff --git a/factory/data/dataComponents_test.go b/factory/data/dataComponents_test.go index a5de1d1d442..2eac8430020 100644 --- a/factory/data/dataComponents_test.go +++ b/factory/data/dataComponents_test.go @@ -8,6 +8,7 @@ import ( "github.com/multiversx/mx-chain-go/errors" dataComp "github.com/multiversx/mx-chain-go/factory/data" "github.com/multiversx/mx-chain-go/factory/mock" + "github.com/multiversx/mx-chain-go/testscommon" componentsMock "github.com/multiversx/mx-chain-go/testscommon/components" "github.com/stretchr/testify/require" ) @@ -43,6 +44,56 @@ func TestNewDataComponentsFactory_NilCoreComponentsShouldErr(t *testing.T) { require.Equal(t, errors.ErrNilCoreComponents, err) } +func TestNewDataComponentsFactory_NilCryptoComponentsShouldErr(t *testing.T) { + t.Parallel() + if testing.Short() { + t.Skip("this is not a short test") + } + + shardCoordinator := mock.NewMultiShardsCoordinatorMock(2) + args := componentsMock.GetDataArgs(componentsMock.GetCoreComponents(), shardCoordinator) + args.Crypto = nil + + dcf, err := dataComp.NewDataComponentsFactory(args) + require.Nil(t, dcf) + require.Equal(t, errors.ErrNilCryptoComponents, err) +} + +func TestNewDataComponentsFactory_NilManagedPeersHolderShouldErr(t *testing.T) { + t.Parallel() + if testing.Short() { + t.Skip("this is not a short test") + } + + shardCoordinator := mock.NewMultiShardsCoordinatorMock(2) + args := componentsMock.GetDataArgs(componentsMock.GetCoreComponents(), shardCoordinator) + args.Crypto = &mock.CryptoComponentsMock{ + ManagedPeersHolderField: nil, + } + + dcf, err := dataComp.NewDataComponentsFactory(args) + require.Nil(t, dcf) + require.Equal(t, errors.ErrNilManagedPeersHolder, err) +} + +func TestNewDataComponentsFactory_NilPathHandlerShouldErr(t *testing.T) { + t.Parallel() + if testing.Short() { + t.Skip("this is not a short test") + } + + shardCoordinator := mock.NewMultiShardsCoordinatorMock(2) + coreComponents := &mock.CoreComponentsMock{ + PathHdl: nil, + } + + args := componentsMock.GetDataArgs(coreComponents, shardCoordinator) + + dcf, err := dataComp.NewDataComponentsFactory(args) + require.Nil(t, dcf) + require.Equal(t, errors.ErrNilPathHandler, err) +} + func TestNewDataComponentsFactory_NilEpochStartNotifierShouldErr(t *testing.T) { t.Parallel() if testing.Short() { @@ -50,9 +101,12 @@ func TestNewDataComponentsFactory_NilEpochStartNotifierShouldErr(t *testing.T) { } shardCoordinator := mock.NewMultiShardsCoordinatorMock(2) - coreComponents := componentsMock.GetCoreComponents() + coreComponents := &mock.CoreComponentsMock{ + PathHdl: &testscommon.PathManagerStub{}, + EpochNotifierWithConfirm: nil, + } + args := componentsMock.GetDataArgs(coreComponents, shardCoordinator) - args.EpochStartNotifier = nil dcf, err := 
dataComp.NewDataComponentsFactory(args) require.Nil(t, dcf) diff --git a/factory/heartbeat/heartbeatV2Components.go b/factory/heartbeat/heartbeatV2Components.go index 178e935b962..c8f62080c9b 100644 --- a/factory/heartbeat/heartbeatV2Components.go +++ b/factory/heartbeat/heartbeatV2Components.go @@ -27,6 +27,7 @@ var log = logger.GetOrCreate("factory") type ArgHeartbeatV2ComponentsFactory struct { Config config.Config Prefs config.Preferences + BaseVersion string AppVersion string BootstrapComponents factory.BootstrapComponentsHolder CoreComponents factory.CoreComponentsHolder @@ -40,6 +41,7 @@ type ArgHeartbeatV2ComponentsFactory struct { type heartbeatV2ComponentsFactory struct { config config.Config prefs config.Preferences + baseVersion string version string bootstrapComponents factory.BootstrapComponentsHolder coreComponents factory.CoreComponentsHolder @@ -69,6 +71,7 @@ func NewHeartbeatV2ComponentsFactory(args ArgHeartbeatV2ComponentsFactory) (*hea return &heartbeatV2ComponentsFactory{ config: args.Config, prefs: args.Prefs, + baseVersion: args.BaseVersion, version: args.AppVersion, bootstrapComponents: args.BootstrapComponents, coreComponents: args.CoreComponents, @@ -162,6 +165,7 @@ func (hcf *heartbeatV2ComponentsFactory) Create() (*heartbeatV2Components, error HeartbeatTimeBetweenSends: time.Second * time.Duration(cfg.HeartbeatTimeBetweenSendsInSec), HeartbeatTimeBetweenSendsWhenError: time.Second * time.Duration(cfg.HeartbeatTimeBetweenSendsWhenErrorInSec), HeartbeatTimeThresholdBetweenSends: cfg.HeartbeatTimeThresholdBetweenSends, + BaseVersionNumber: hcf.baseVersion, VersionNumber: hcf.version, NodeDisplayName: hcf.prefs.Preferences.NodeDisplayName, Identity: hcf.prefs.Preferences.Identity, @@ -175,6 +179,9 @@ func (hcf *heartbeatV2ComponentsFactory) Create() (*heartbeatV2Components, error HardforkTimeBetweenSends: time.Second * time.Duration(cfg.HardforkTimeBetweenSendsInSec), HardforkTriggerPubKey: hcf.coreComponents.HardforkTriggerPubKey(), PeerTypeProvider: peerTypeProvider, + ManagedPeersHolder: hcf.cryptoComponents.ManagedPeersHolder(), + PeerAuthenticationTimeBetweenChecks: time.Second * time.Duration(cfg.PeerAuthenticationTimeBetweenChecksInSec), + ShardCoordinator: hcf.processComponents.ShardCoordinator(), } heartbeatV2Sender, err := sender.NewSender(argsSender) if err != nil { diff --git a/factory/heartbeat/heartbeatV2Components_test.go b/factory/heartbeat/heartbeatV2Components_test.go index 667115e5040..b6102a67c05 100644 --- a/factory/heartbeat/heartbeatV2Components_test.go +++ b/factory/heartbeat/heartbeatV2Components_test.go @@ -52,6 +52,7 @@ func createMockHeartbeatV2ComponentsFactoryArgs() heartbeatComp.ArgHeartbeatV2Co HardforkTimeBetweenSendsInSec: 5, TimeBetweenConnectionsMetricsUpdateInSec: 10, TimeToReadDirectConnectionsInSec: 15, + PeerAuthenticationTimeBetweenChecksInSec: 6, HeartbeatPool: config.CacheConfig{ Type: "LRU", Capacity: 1000, @@ -68,6 +69,7 @@ func createMockHeartbeatV2ComponentsFactoryArgs() heartbeatComp.ArgHeartbeatV2Co Identity: "identity", }, }, + BaseVersion: "test-base", AppVersion: "test", BootstrapComponents: bootstrapC, CoreComponents: coreC, diff --git a/factory/interface.go b/factory/interface.go index 1f2c36f7a5b..7b26670b8db 100644 --- a/factory/interface.go +++ b/factory/interface.go @@ -184,7 +184,9 @@ type CryptoComponentsHolder interface { TxSignKeyGen() crypto.KeyGenerator P2pKeyGen() crypto.KeyGenerator MessageSignVerifier() vm.MessageSignVerifier - ConsensusSigHandler() consensus.SignatureHandler + 
ConsensusSigningHandler() consensus.SigningHandler + ManagedPeersHolder() common.ManagedPeersHolder + KeysHandler() consensus.KeysHandler Clone() interface{} IsInterfaceNil() bool } @@ -192,6 +194,8 @@ type CryptoComponentsHolder interface { // KeyLoaderHandler defines the loading of a key from a pem file and index type KeyLoaderHandler interface { LoadKey(string, int) ([]byte, string, error) + LoadAllKeys(path string) ([][]byte, []string, error) + IsInterfaceNil() bool } // CryptoComponentsHandler defines the crypto components handler actions diff --git a/factory/mock/cryptoComponentsMock.go b/factory/mock/cryptoComponentsMock.go index 863f18ebebe..fce869adbd7 100644 --- a/factory/mock/cryptoComponentsMock.go +++ b/factory/mock/cryptoComponentsMock.go @@ -5,6 +5,7 @@ import ( "sync" crypto "github.com/multiversx/mx-chain-crypto-go" + "github.com/multiversx/mx-chain-go/common" cryptoCommon "github.com/multiversx/mx-chain-go/common/crypto" "github.com/multiversx/mx-chain-go/consensus" "github.com/multiversx/mx-chain-go/vm" @@ -12,24 +13,26 @@ import ( // CryptoComponentsMock - type CryptoComponentsMock struct { - PubKey crypto.PublicKey - PrivKey crypto.PrivateKey - P2pPubKey crypto.PublicKey - P2pPrivKey crypto.PrivateKey - P2pSig crypto.SingleSigner - PubKeyString string - PrivKeyBytes []byte - PubKeyBytes []byte - BlockSig crypto.SingleSigner - TxSig crypto.SingleSigner - MultiSigContainer cryptoCommon.MultiSignerContainer - PeerSignHandler crypto.PeerSignatureHandler - BlKeyGen crypto.KeyGenerator - TxKeyGen crypto.KeyGenerator - P2PKeyGen crypto.KeyGenerator - MsgSigVerifier vm.MessageSignVerifier - SigHandler consensus.SignatureHandler - mutMultiSig sync.RWMutex + PubKey crypto.PublicKey + PrivKey crypto.PrivateKey + P2pPubKey crypto.PublicKey + P2pPrivKey crypto.PrivateKey + P2pSig crypto.SingleSigner + PubKeyString string + PrivKeyBytes []byte + PubKeyBytes []byte + BlockSig crypto.SingleSigner + TxSig crypto.SingleSigner + MultiSigContainer cryptoCommon.MultiSignerContainer + PeerSignHandler crypto.PeerSignatureHandler + BlKeyGen crypto.KeyGenerator + TxKeyGen crypto.KeyGenerator + P2PKeyGen crypto.KeyGenerator + MsgSigVerifier vm.MessageSignVerifier + SigHandler consensus.SigningHandler + ManagedPeersHolderField common.ManagedPeersHolder + KeysHandlerField consensus.KeysHandler + mutMultiSig sync.RWMutex } // PublicKey - @@ -139,28 +142,40 @@ func (ccm *CryptoComponentsMock) MessageSignVerifier() vm.MessageSignVerifier { return ccm.MsgSigVerifier } -// ConsensusSigHandler - -func (ccm *CryptoComponentsMock) ConsensusSigHandler() consensus.SignatureHandler { +// ConsensusSigningHandler - +func (ccm *CryptoComponentsMock) ConsensusSigningHandler() consensus.SigningHandler { return ccm.SigHandler } +// ManagedPeersHolder - +func (ccm *CryptoComponentsMock) ManagedPeersHolder() common.ManagedPeersHolder { + return ccm.ManagedPeersHolderField +} + +// KeysHandler - +func (ccm *CryptoComponentsMock) KeysHandler() consensus.KeysHandler { + return ccm.KeysHandlerField +} + // Clone - func (ccm *CryptoComponentsMock) Clone() interface{} { return &CryptoComponentsMock{ - PubKey: ccm.PubKey, - PrivKey: ccm.PrivKey, - PubKeyString: ccm.PubKeyString, - PrivKeyBytes: ccm.PrivKeyBytes, - PubKeyBytes: ccm.PubKeyBytes, - BlockSig: ccm.BlockSig, - TxSig: ccm.TxSig, - MultiSigContainer: ccm.MultiSigContainer, - PeerSignHandler: ccm.PeerSignHandler, - BlKeyGen: ccm.BlKeyGen, - TxKeyGen: ccm.TxKeyGen, - P2PKeyGen: ccm.P2PKeyGen, - MsgSigVerifier: ccm.MsgSigVerifier, - mutMultiSig: sync.RWMutex{}, + 
PubKey: ccm.PubKey, + PrivKey: ccm.PrivKey, + PubKeyString: ccm.PubKeyString, + PrivKeyBytes: ccm.PrivKeyBytes, + PubKeyBytes: ccm.PubKeyBytes, + BlockSig: ccm.BlockSig, + TxSig: ccm.TxSig, + MultiSigContainer: ccm.MultiSigContainer, + PeerSignHandler: ccm.PeerSignHandler, + BlKeyGen: ccm.BlKeyGen, + TxKeyGen: ccm.TxKeyGen, + P2PKeyGen: ccm.P2PKeyGen, + MsgSigVerifier: ccm.MsgSigVerifier, + ManagedPeersHolderField: ccm.ManagedPeersHolderField, + KeysHandlerField: ccm.KeysHandlerField, + mutMultiSig: sync.RWMutex{}, } } diff --git a/factory/mock/keyLoaderStub.go b/factory/mock/keyLoaderStub.go index 7adc49a24c6..33d0e4955ff 100644 --- a/factory/mock/keyLoaderStub.go +++ b/factory/mock/keyLoaderStub.go @@ -1,8 +1,11 @@ package mock +import "errors" + // KeyLoaderStub - type KeyLoaderStub struct { - LoadKeyCalled func(relativePath string, skIndex int) ([]byte, string, error) + LoadKeyCalled func(relativePath string, skIndex int) ([]byte, string, error) + LoadAllKeysCalled func(path string) ([][]byte, []string, error) } // LoadKey - @@ -13,3 +16,17 @@ func (kl *KeyLoaderStub) LoadKey(relativePath string, skIndex int) ([]byte, stri return nil, "", nil } + +// LoadAllKeys - +func (kl *KeyLoaderStub) LoadAllKeys(path string) ([][]byte, []string, error) { + if kl.LoadAllKeysCalled != nil { + return kl.LoadAllKeysCalled(path) + } + + return nil, nil, errors.New("not implemented") +} + +// IsInterfaceNil - +func (kl *KeyLoaderStub) IsInterfaceNil() bool { + return kl == nil +} diff --git a/factory/processing/processComponents.go b/factory/processing/processComponents.go index 0cccdb27d6e..24d35bbd61a 100644 --- a/factory/processing/processComponents.go +++ b/factory/processing/processComponents.go @@ -21,10 +21,10 @@ import ( "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/dataRetriever/factory/containers" "github.com/multiversx/mx-chain-go/dataRetriever/factory/epochProviders" - "github.com/multiversx/mx-chain-go/dataRetriever/factory/requestersContainer" + requesterscontainer "github.com/multiversx/mx-chain-go/dataRetriever/factory/requestersContainer" "github.com/multiversx/mx-chain-go/dataRetriever/factory/resolverscontainer" disabledResolversContainer "github.com/multiversx/mx-chain-go/dataRetriever/factory/resolverscontainer/disabled" - "github.com/multiversx/mx-chain-go/dataRetriever/factory/storageRequestersContainer" + storagerequesterscontainer "github.com/multiversx/mx-chain-go/dataRetriever/factory/storageRequestersContainer" "github.com/multiversx/mx-chain-go/dataRetriever/requestHandlers" "github.com/multiversx/mx-chain-go/dblookupext" "github.com/multiversx/mx-chain-go/epochStart" @@ -143,6 +143,7 @@ type ProcessComponentsFactoryArgs struct { ImportStartHandler update.ImportStartHandler WorkingDir string HistoryRepo dblookupext.HistoryRepository + SnapshotsEnabled bool Data factory.DataComponentsHolder CoreData factory.CoreComponentsHolder @@ -175,6 +176,7 @@ type processComponentsFactory struct { historyRepo dblookupext.HistoryRepository epochNotifier process.EpochNotifier importHandler update.ImportHandler + snapshotsEnabled bool esdtNftStorage vmcommon.ESDTNFTStorageHandler data factory.DataComponentsHolder @@ -221,6 +223,7 @@ func NewProcessComponentsFactory(args ProcessComponentsFactoryArgs) (*processCom historyRepo: args.HistoryRepo, epochNotifier: args.CoreData.EpochNotifier(), statusCoreComponents: args.StatusCoreComponents, + snapshotsEnabled: args.SnapshotsEnabled, }, nil } @@ -401,6 +404,10 @@ func (pcf *processComponentsFactory) 
Create() (*processComponents, error) { if err != nil { return nil, err } + err = dataRetriever.SetEpochHandlerToHdrRequester(requestersContainer, epochStartTrigger) + if err != nil { + return nil, err + } log.Debug("Validator stats created", "validatorStatsRootHash", validatorStatsRootHash) @@ -1496,6 +1503,8 @@ func (pcf *processComponentsFactory) newStorageRequesters() (dataRetriever.Reque CurrentEpoch: pcf.bootstrapComponents.EpochBootstrapParams().Epoch(), StorageType: storageFactory.ProcessStorageService, CreateTrieEpochRootHashStorer: false, + SnapshotsEnabled: pcf.snapshotsEnabled, + ManagedPeersHolder: pcf.crypto.ManagedPeersHolder(), }, ) if err != nil { @@ -1548,6 +1557,7 @@ func (pcf *processComponentsFactory) createStorageRequestersForMeta( DataPacker: dataPacker, ManualEpochStartNotifier: manualEpochStartNotifier, ChanGracefullyClose: pcf.coreData.ChanStopNodeProcess(), + SnapshotsEnabled: pcf.snapshotsEnabled, } requestersContainerFactory, err := storagerequesterscontainer.NewMetaRequestersContainerFactory(requestersContainerFactoryArgs) if err != nil { @@ -1580,6 +1590,7 @@ func (pcf *processComponentsFactory) createStorageRequestersForShard( DataPacker: dataPacker, ManualEpochStartNotifier: manualEpochStartNotifier, ChanGracefullyClose: pcf.coreData.ChanStopNodeProcess(), + SnapshotsEnabled: pcf.snapshotsEnabled, } requestersContainerFactory, err := storagerequesterscontainer.NewShardRequestersContainerFactory(requestersContainerFactoryArgs) if err != nil { @@ -1927,6 +1938,9 @@ func checkProcessComponentsArgs(args ProcessComponentsFactoryArgs) error { if check.IfNil(args.StatusCoreComponents.AppStatusHandler()) { return fmt.Errorf("%s: %w", baseErrMessage, errErd.ErrNilAppStatusHandler) } + if check.IfNil(args.Crypto.ManagedPeersHolder()) { + return fmt.Errorf("%s: %w", baseErrMessage, errErd.ErrNilManagedPeersHolder) + } return nil } diff --git a/factory/state/stateComponents.go b/factory/state/stateComponents.go index 98afa8a2d38..6d7bb36bbd6 100644 --- a/factory/state/stateComponents.go +++ b/factory/state/stateComponents.go @@ -29,6 +29,7 @@ type StateComponentsFactoryArgs struct { StorageService dataRetriever.StorageService ProcessingMode common.NodeProcessingMode ShouldSerializeSnapshots bool + SnapshotsEnabled bool ChainHandler chainData.ChainHandler } @@ -40,6 +41,7 @@ type stateComponentsFactory struct { storageService dataRetriever.StorageService processingMode common.NodeProcessingMode shouldSerializeSnapshots bool + snapshotsEnabled bool chainHandler chainData.ChainHandler } @@ -92,12 +94,14 @@ func NewStateComponentsFactory(args StateComponentsFactoryArgs) (*stateComponent processingMode: args.ProcessingMode, shouldSerializeSnapshots: args.ShouldSerializeSnapshots, chainHandler: args.ChainHandler, + snapshotsEnabled: args.SnapshotsEnabled, }, nil } // Create creates the state components func (scf *stateComponentsFactory) Create() (*stateComponents, error) { triesContainer, trieStorageManagers, err := trieFactory.CreateTriesComponentsForShardId( + scf.snapshotsEnabled, scf.config, scf.core, scf.storageService, diff --git a/go.mod b/go.mod index dab5fc8ac3c..c31a163b798 100644 --- a/go.mod +++ b/go.mod @@ -7,28 +7,28 @@ require ( github.com/davecgh/go-spew v1.1.1 github.com/gin-contrib/cors v1.4.0 github.com/gin-contrib/pprof v1.4.0 - github.com/gin-gonic/gin v1.8.1 + github.com/gin-gonic/gin v1.9.0 github.com/gizak/termui/v3 v3.1.0 github.com/gogo/protobuf v1.3.2 github.com/google/gops v0.3.18 github.com/gorilla/websocket v1.5.0 
github.com/mitchellh/mapstructure v1.5.0 - github.com/multiversx/mx-chain-core-go v1.1.34-0.20230215164556-5ec8f51e96d3 + github.com/multiversx/mx-chain-core-go v1.1.36-0.20230308081722-5262fb09cb9a github.com/multiversx/mx-chain-crypto-go v1.2.5 github.com/multiversx/mx-chain-es-indexer-go v1.3.13-0.20230216143122-cee11ea0d0f5 github.com/multiversx/mx-chain-logger-go v1.0.11 - github.com/multiversx/mx-chain-p2p-go v1.0.11 + github.com/multiversx/mx-chain-p2p-go v1.0.13 github.com/multiversx/mx-chain-storage-go v1.0.7 - github.com/multiversx/mx-chain-vm-common-go v1.3.37-0.20230216122352-6dd6ff58ca2d + github.com/multiversx/mx-chain-vm-common-go v1.3.38-0.20230308082221-60b880366741 github.com/multiversx/mx-chain-vm-v1_2-go v1.2.50 github.com/multiversx/mx-chain-vm-v1_3-go v1.3.51 - github.com/multiversx/mx-chain-vm-v1_4-go v1.4.76 + github.com/multiversx/mx-chain-vm-v1_4-go v1.4.77 github.com/pelletier/go-toml v1.9.3 github.com/pkg/errors v0.9.1 github.com/shirou/gopsutil v3.21.11+incompatible github.com/stretchr/testify v1.8.1 github.com/urfave/cli v1.22.10 - golang.org/x/crypto v0.3.0 + golang.org/x/crypto v0.5.0 gopkg.in/go-playground/validator.v8 v8.18.2 ) @@ -38,8 +38,10 @@ require ( github.com/btcsuite/btcd/btcec/v2 v2.3.2 // indirect github.com/btcsuite/btcd/btcutil v1.1.3 // indirect github.com/btcsuite/btcd/chaincfg/chainhash v1.0.1 // indirect + github.com/bytedance/sonic v1.8.0 // indirect github.com/cespare/xxhash/v2 v2.1.2 // indirect github.com/cheekybits/genny v1.0.0 // indirect + github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect github.com/containerd/cgroups v1.0.4 // indirect github.com/coreos/go-systemd/v22 v22.3.2 // indirect github.com/cpuguy83/go-md2man/v2 v2.0.0 // indirect @@ -54,11 +56,11 @@ require ( github.com/fsnotify/fsnotify v1.5.4 // indirect github.com/gin-contrib/sse v0.1.0 // indirect github.com/go-ole/go-ole v1.2.6 // indirect - github.com/go-playground/locales v0.14.0 // indirect - github.com/go-playground/universal-translator v0.18.0 // indirect - github.com/go-playground/validator/v10 v10.10.0 // indirect + github.com/go-playground/locales v0.14.1 // indirect + github.com/go-playground/universal-translator v0.18.1 // indirect + github.com/go-playground/validator/v10 v10.11.2 // indirect github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 // indirect - github.com/goccy/go-json v0.9.7 // indirect + github.com/goccy/go-json v0.10.0 // indirect github.com/godbus/dbus/v5 v5.1.0 // indirect github.com/golang/protobuf v1.5.2 // indirect github.com/golang/snappy v0.0.4 // indirect @@ -107,7 +109,7 @@ require ( github.com/marten-seemann/qtls-go1-18 v0.1.2 // indirect github.com/marten-seemann/qtls-go1-19 v0.1.0 // indirect github.com/marten-seemann/tcp v0.0.0-20210406111302-dfbc87cc63fd // indirect - github.com/mattn/go-isatty v0.0.16 // indirect + github.com/mattn/go-isatty v0.0.17 // indirect github.com/mattn/go-pointer v0.0.1 // indirect github.com/mattn/go-runewidth v0.0.2 // indirect github.com/matttproud/golang_protobuf_extensions v1.0.1 // indirect @@ -137,7 +139,7 @@ require ( github.com/opencontainers/runtime-spec v1.0.2 // indirect github.com/opentracing/opentracing-go v1.2.0 // indirect github.com/pbnjay/memory v0.0.0-20210728143218-7b4eea64cf58 // indirect - github.com/pelletier/go-toml/v2 v2.0.1 // indirect + github.com/pelletier/go-toml/v2 v2.0.6 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/polydawn/refmt v0.0.0-20190807091052-3d65705ee9f1 // indirect 
github.com/prometheus/client_golang v1.12.1 // indirect @@ -155,7 +157,8 @@ require ( github.com/tidwall/pretty v1.2.0 // indirect github.com/tklauser/go-sysconf v0.3.4 // indirect github.com/tklauser/numcpus v0.2.1 // indirect - github.com/ugorji/go/codec v1.2.7 // indirect + github.com/twitchyliquid64/golang-asm v0.15.1 // indirect + github.com/ugorji/go/codec v1.2.9 // indirect github.com/whyrusleeping/go-keyspace v0.0.0-20160322163242-5b898ac5add1 // indirect github.com/whyrusleeping/timecache v0.0.0-20160911033111-cfcb2f1abfee // indirect github.com/yusufpapurcu/wmi v1.2.2 // indirect @@ -163,16 +166,16 @@ require ( go.uber.org/atomic v1.10.0 // indirect go.uber.org/multierr v1.8.0 // indirect go.uber.org/zap v1.22.0 // indirect + golang.org/x/arch v0.0.0-20210923205945-b76863e36670 // indirect golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 // indirect - golang.org/x/net v0.2.0 // indirect + golang.org/x/net v0.7.0 // indirect golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4 // indirect - golang.org/x/sys v0.2.0 // indirect - golang.org/x/text v0.4.0 // indirect + golang.org/x/sys v0.5.0 // indirect + golang.org/x/text v0.7.0 // indirect golang.org/x/tools v0.1.12 // indirect google.golang.org/protobuf v1.28.1 // indirect gopkg.in/go-playground/assert.v1 v1.2.1 // indirect gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect - gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect lukechampine.com/blake3 v1.1.7 // indirect ) diff --git a/go.sum b/go.sum index 7398c961c33..1efac90c48c 100644 --- a/go.sum +++ b/go.sum @@ -91,6 +91,9 @@ github.com/btcsuite/snappy-go v1.0.0/go.mod h1:8woku9dyThutzjeg+3xrA5iCpBRH8XEEg github.com/btcsuite/websocket v0.0.0-20150119174127-31079b680792/go.mod h1:ghJtEyQwv5/p4Mg4C0fgbePVuGr935/5ddU9Z3TmDRY= github.com/btcsuite/winsvc v1.0.0/go.mod h1:jsenWakMcC0zFBFurPLEAyrnc/teJEM1O46fmI40EZs= github.com/buger/jsonparser v0.0.0-20181115193947-bf1c66bbce23/go.mod h1:bbYlZJ7hK1yFx9hf58LP0zeX7UjIGs20ufpu3evjr+s= +github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM= +github.com/bytedance/sonic v1.8.0 h1:ea0Xadu+sHlu7x5O3gKhRpQ1IKiMrSiHttPF0ybECuA= +github.com/bytedance/sonic v1.8.0/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/cespare/xxhash v1.1.0 h1:a6HrQnmkObjyL+Gs60czilIUGqrzKutQD6XZog3p+ko= github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc= @@ -99,6 +102,9 @@ github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cb github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cheekybits/genny v1.0.0 h1:uGGa4nei+j20rOSeDeP5Of12XVm7TGUd4dJA9RDitfE= github.com/cheekybits/genny v1.0.0/go.mod h1:+tQajlRqAUrPI7DOSpB0XAqZYtQakVtB7wXkRAgjxjQ= +github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY= +github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams= +github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk= github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod 
h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= @@ -172,8 +178,9 @@ github.com/gin-contrib/pprof v1.4.0 h1:XxiBSf5jWZ5i16lNOPbMTVdgHBdhfGRD5PZ1LWazz github.com/gin-contrib/pprof v1.4.0/go.mod h1:RrehPJasUVBPK6yTUwOl8/NP6i0vbUgmxtis+Z5KE90= github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= -github.com/gin-gonic/gin v1.8.1 h1:4+fr/el88TOO3ewCmQr8cx/CtZ/umlIRIs5M4NTNjf8= github.com/gin-gonic/gin v1.8.1/go.mod h1:ji8BvRH1azfM+SYow9zQ6SZMvR8qOMZHmsCuWR9tTTk= +github.com/gin-gonic/gin v1.9.0 h1:OjyFBKICoexlu99ctXNR2gg+c5pKrKMuyjgARg9qeY8= +github.com/gin-gonic/gin v1.9.0/go.mod h1:W1Me9+hsUSyj3CePGrd1/QrKJMSJ1Tu/0hFEH89961k= github.com/gizak/termui/v3 v3.1.0 h1:ZZmVDgwHl7gR7elfKf1xc4IudXZ5qqfDh4wExk4Iajc= github.com/gizak/termui/v3 v3.1.0/go.mod h1:bXQEBkJpzxUAKf0+xq9MSWAvWZlE7c+aidmyFlkYTrY= github.com/gliderlabs/ssh v0.1.1/go.mod h1:U7qILu1NlMHj9FlMhZLlkCdDnU1DBEAqr0aevW3Awn0= @@ -196,19 +203,24 @@ github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre github.com/go-ole/go-ole v1.2.4/go.mod h1:XCwSNxSkXRo4vlyPy93sltvi/qJq0jqQhjqQNIwKuxM= github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY= github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= -github.com/go-playground/assert/v2 v2.0.1 h1:MsBgLAaY856+nPRTKrp3/OZK38U/wa0CcBYNjji3q3A= github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= -github.com/go-playground/locales v0.14.0 h1:u50s323jtVGugKlcYeyzC0etD1HifMjqmJqb8WugfUU= +github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= +github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= github.com/go-playground/locales v0.14.0/go.mod h1:sawfccIbzZTqEDETgFXqTho0QybSa7l++s0DH+LDiLs= -github.com/go-playground/universal-translator v0.18.0 h1:82dyy6p4OuJq4/CByFNOn/jYrnRPArHwAcmLoJZxyho= +github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= +github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= github.com/go-playground/universal-translator v0.18.0/go.mod h1:UvRDBj+xPUEGrFYl+lu/H90nyDXpg0fqeB/AQUGNTVA= -github.com/go-playground/validator/v10 v10.10.0 h1:I7mrTYv78z8k8VXa/qJlOlEXn/nBh+BF8dHX5nt/dr0= +github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= +github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= github.com/go-playground/validator/v10 v10.10.0/go.mod h1:74x4gJWsvQexRdW8Pn3dXSGrTK4nAUsbPlLADvpJkos= +github.com/go-playground/validator/v10 v10.11.2 h1:q3SHpufmypg+erIExEKUmsgmhDTyhcJ38oeKGACXohU= +github.com/go-playground/validator/v10 v10.11.2/go.mod h1:NieE624vt4SCTJtD87arVLvdmjPAeV8BQlHtMnw9D7s= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 h1:p104kn46Q8WdvHunIJ9dAyjPVtrBPhSr3KT2yUst43I= github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE= -github.com/goccy/go-json v0.9.7 h1:IcB+Aqpx/iMHu5Yooh7jEzJk1JZ7Pjtmys2ukPr7EeM= github.com/goccy/go-json v0.9.7/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= +github.com/goccy/go-json v0.10.0 h1:mXKd9Qw4NuzShiRlOXKews24ufknHO7gx30lsDyokKA= +github.com/goccy/go-json v0.10.0/go.mod 
h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= github.com/godbus/dbus/v5 v5.0.3/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/godbus/dbus/v5 v5.1.0 h1:4KLkAxT3aOY8Li4FRJe/KvhoNFFxo0m6fNuFUO8QJUk= @@ -308,7 +320,6 @@ github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+l github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= -github.com/hashicorp/golang-lru v0.5.4 h1:YDjusn29QI/Das2iO9M0BHnIbxPeyuCHsjMW+lJfyTc= github.com/hashicorp/golang-lru v0.5.4/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4= github.com/hashicorp/golang-lru v0.6.0 h1:uL2shRDx7RTrOrTCUZEGP/wJUFiUI8QT6E7z5o8jga4= github.com/hashicorp/golang-lru v0.6.0/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4= @@ -505,8 +516,9 @@ github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcncea github.com/mattn/go-isatty v0.0.4/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNxMWT7Zi4= github.com/mattn/go-isatty v0.0.5/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= -github.com/mattn/go-isatty v0.0.16 h1:bq3VjFmv/sOjHtdEhmkEV4x1AJtvUvOJ2PFAZ5+peKQ= github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-isatty v0.0.17 h1:BTarxUcIeDqL27Mc+vyvdWYSL28zpIhv3RoTdsLMPng= +github.com/mattn/go-isatty v0.0.17/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-pointer v0.0.1 h1:n+XhsuGeVO6MEAp7xyEukFINEa+Quek5psIR/ylA6o0= github.com/mattn/go-pointer v0.0.1/go.mod h1:2zXcozF6qYGgmsG+SeTZz3oAbFLdD3OWqnUbNvJZAlc= github.com/mattn/go-runewidth v0.0.2 h1:UnlwIPBGaTZfPQ6T1IGzPI0EkYAQmT9fAEJ/poFC63o= @@ -599,27 +611,30 @@ github.com/multiversx/concurrent-map v0.1.4 h1:hdnbM8VE4b0KYJaGY5yJS2aNIW9TFFsUY github.com/multiversx/concurrent-map v0.1.4/go.mod h1:8cWFRJDOrWHOTNSqgYCUvwT7c7eFQ4U2vKMOp4A/9+o= github.com/multiversx/mx-chain-core-go v1.1.30/go.mod h1:8gGEQv6BWuuJwhd25qqhCOZbBSv9mk+hLeKvinSaSMk= github.com/multiversx/mx-chain-core-go v1.1.31/go.mod h1:8gGEQv6BWuuJwhd25qqhCOZbBSv9mk+hLeKvinSaSMk= -github.com/multiversx/mx-chain-core-go v1.1.34-0.20230215164556-5ec8f51e96d3 h1:68ooIEnoUO69QGWXDa0VdDarOA4CaF9IiEMg9q/kisc= github.com/multiversx/mx-chain-core-go v1.1.34-0.20230215164556-5ec8f51e96d3/go.mod h1:8gGEQv6BWuuJwhd25qqhCOZbBSv9mk+hLeKvinSaSMk= +github.com/multiversx/mx-chain-core-go v1.1.36-0.20230308081722-5262fb09cb9a h1:EBmifJAGrqqj5mPVIMBlxWMrE6yl7lL2VfFzvkfDWiM= +github.com/multiversx/mx-chain-core-go v1.1.36-0.20230308081722-5262fb09cb9a/go.mod h1:8gGEQv6BWuuJwhd25qqhCOZbBSv9mk+hLeKvinSaSMk= github.com/multiversx/mx-chain-crypto-go v1.2.5 h1:tuq3BUNMhKud5DQbZi9DiVAAHUXypizy8zPH0NpTGZk= github.com/multiversx/mx-chain-crypto-go v1.2.5/go.mod h1:teqhNyWEqfMPgNn8sgWXlgtJ1a36jGCnhs/tRpXW6r4= github.com/multiversx/mx-chain-es-indexer-go v1.3.13-0.20230216143122-cee11ea0d0f5 h1:zpLIFFryEqlPVwkE8IA5CvsEpy+S+QklJCpjQ6HfDBU= github.com/multiversx/mx-chain-es-indexer-go v1.3.13-0.20230216143122-cee11ea0d0f5/go.mod h1:lp2IyG55Y2NOhJGkmCbMAsYizU5Wff4+x8lHsTO3gUc= github.com/multiversx/mx-chain-logger-go v1.0.11 h1:DFsHa+sc5fKwhDR50I8uBM99RTDTEW68ESyr5ALRDwE= 
github.com/multiversx/mx-chain-logger-go v1.0.11/go.mod h1:1srDkP0DQucWQ+rYfaq0BX2qLnULsUdRPADpYUTM6dA= -github.com/multiversx/mx-chain-p2p-go v1.0.11 h1:lzGLvE/liPmNtJxbmEAWKWP9iy2XB6DNjCw0ZogLtTk= -github.com/multiversx/mx-chain-p2p-go v1.0.11/go.mod h1:j9Ueo2ptCnL7TQvQg6KS/KWAoJEJpjkPgE5ZTaqEAn4= +github.com/multiversx/mx-chain-p2p-go v1.0.13 h1:woIlYkDFCKYyJQ5urDcOzz8HUFGsSEhTfUXDDxNI2zM= +github.com/multiversx/mx-chain-p2p-go v1.0.13/go.mod h1:j9Ueo2ptCnL7TQvQg6KS/KWAoJEJpjkPgE5ZTaqEAn4= github.com/multiversx/mx-chain-storage-go v1.0.7 h1:UqLo/OLTD3IHiE/TB/SEdNRV1GG2f1R6vIP5ehHwCNw= github.com/multiversx/mx-chain-storage-go v1.0.7/go.mod h1:gtKoV32Cg2Uy8deHzF8Ud0qAl0zv92FvWgPSYIP0Zmg= github.com/multiversx/mx-chain-vm-common-go v1.3.36/go.mod h1:sZ2COLCxvf2GxAAJHGmGqWybObLtFuk2tZUyGqnMXE8= -github.com/multiversx/mx-chain-vm-common-go v1.3.37-0.20230216122352-6dd6ff58ca2d h1:n2qx7CceoqbEqZXWOsjLccS34zLPh5KNriIXTlNSIc4= github.com/multiversx/mx-chain-vm-common-go v1.3.37-0.20230216122352-6dd6ff58ca2d/go.mod h1:Y9ggiJtjGCPK/8WEzFO91JLlppMa/uUMobMmsogeiRw= +github.com/multiversx/mx-chain-vm-common-go v1.3.37/go.mod h1:sZ2COLCxvf2GxAAJHGmGqWybObLtFuk2tZUyGqnMXE8= +github.com/multiversx/mx-chain-vm-common-go v1.3.38-0.20230308082221-60b880366741 h1:whMuG9WFJVWfnzeKFD91x8TaN1DFvyS8onRhZFSU5BY= +github.com/multiversx/mx-chain-vm-common-go v1.3.38-0.20230308082221-60b880366741/go.mod h1:LZzIn2H3/hn95/hL0HhZ0Ql9Zpe94CzEw8ySfG9hnwg= github.com/multiversx/mx-chain-vm-v1_2-go v1.2.50 h1:ScUq7/wq78vthMTQ6v5Ux1DvSMQMHxQ2Sl7aPP26q1w= github.com/multiversx/mx-chain-vm-v1_2-go v1.2.50/go.mod h1:e3uYdgoKzs3puaznbmSjDcRisJc5Do4tpg7VqyYwoek= github.com/multiversx/mx-chain-vm-v1_3-go v1.3.51 h1:axtp5/mpA+xYJ1cu4KtAGETV4t6v6/tNfQh0HCclBYY= github.com/multiversx/mx-chain-vm-v1_3-go v1.3.51/go.mod h1:oKj32V2nkd+KGNOL6emnwVkDRPpciwHHDqBmeorcL8k= -github.com/multiversx/mx-chain-vm-v1_4-go v1.4.76 h1:HFTf/GuLt68UiDKDZL3GD/YMFdDBaOv31OeP7n0ICQc= -github.com/multiversx/mx-chain-vm-v1_4-go v1.4.76/go.mod h1:iyAWamHL3voN/T2sjx7UZ8L4kXLImkgMNh27WUFFhxE= +github.com/multiversx/mx-chain-vm-v1_4-go v1.4.77 h1:3Yh4brS5/Jye24l5AKy+Q6Yci6Rv55pHyj9/GR3AYos= +github.com/multiversx/mx-chain-vm-v1_4-go v1.4.77/go.mod h1:3IaAOHc1JfxL5ywQZIrcaHQu5+CVdZNDaoY64NGOtUE= github.com/multiversx/mx-components-big-int v0.1.1 h1:695mYPKYOrmGEGgRH4/pZruDoe3CPP1LHrBxKfvj5l4= github.com/multiversx/mx-components-big-int v0.1.1/go.mod h1:0QrcFdfeLgJ/am10HGBeH0G0DNF+0Qx1E4DS/iozQls= github.com/multiversx/protobuf v1.3.2 h1:RaNkxvGTGbA0lMcnHAN24qE1G1i+Xs5yHA6MDvQ4mSM= @@ -659,8 +674,9 @@ github.com/pbnjay/memory v0.0.0-20210728143218-7b4eea64cf58/go.mod h1:DXv8WO4yhM github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic= github.com/pelletier/go-toml v1.9.3 h1:zeC5b1GviRUyKYd6OJPvBU/mcVDVoL1OhT17FCt5dSQ= github.com/pelletier/go-toml v1.9.3/go.mod h1:u1nR/EPcESfeI/szUZKdtJ0xRNbUoANCkoOuaOx1Y+c= -github.com/pelletier/go-toml/v2 v2.0.1 h1:8e3L2cCQzLFi2CR4g7vGFuFxX7Jl1kKX8gW+iV0GUKU= github.com/pelletier/go-toml/v2 v2.0.1/go.mod h1:r9LEWfGN8R5k0VXJ+0BkIe7MYkRdwZOjgMj2KwnJFUo= +github.com/pelletier/go-toml/v2 v2.0.6 h1:nrzqCb7j9cDFj2coyLNLaZuJTLjWjlaz6nvTvIwycIU= +github.com/pelletier/go-toml/v2 v2.0.6/go.mod h1:eumQOmlWiOPt5WriQQqoM5y18pDHwha2N+QD+EUNTek= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.1/go.mod 
h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= @@ -787,11 +803,14 @@ github.com/tklauser/go-sysconf v0.3.4 h1:HT8SVixZd3IzLdfs/xlpq0jeSfTX57g1v6wB1Eu github.com/tklauser/go-sysconf v0.3.4/go.mod h1:Cl2c8ZRWfHD5IrfHo9VN+FX9kCFjIOyVklgXycLB6ek= github.com/tklauser/numcpus v0.2.1 h1:ct88eFm+Q7m2ZfXJdan1xYoXKlmwsfP+k88q05KvlZc= github.com/tklauser/numcpus v0.2.1/go.mod h1:9aU+wOc6WjUIZEwWMP62PL/41d65P+iks1gBkr4QyP8= +github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= +github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= github.com/ugorji/go v1.2.7 h1:qYhyWUUd6WbiM+C6JZAUkIJt/1WrjzNHY9+KCIjVqTo= github.com/ugorji/go v1.2.7/go.mod h1:nF9osbDWLy6bDVv/Rtoh6QgnvNDpmCalQV5urGCCS6M= github.com/ugorji/go/codec v0.0.0-20181204163529-d75b2dcb6bc8/go.mod h1:VFNgLljTbGfSG7qAOspJ7OScBnGdDN/yBr0sguwnwf0= -github.com/ugorji/go/codec v1.2.7 h1:YPXUKf7fYbp/y8xloBqZOw2qaVggbfwMlI8WM3wZUJ0= github.com/ugorji/go/codec v1.2.7/go.mod h1:WGN1fab3R1fzQlVQTkfxVtIBhWDRqOviHU95kRgeqEY= +github.com/ugorji/go/codec v1.2.9 h1:rmenucSohSTiyL09Y+l2OCk+FrMxGMzho2+tjr5ticU= +github.com/ugorji/go/codec v1.2.9/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= github.com/urfave/cli v1.22.2/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtXRu0= github.com/urfave/cli v1.22.5/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtXRu0= github.com/urfave/cli v1.22.10 h1:p8Fspmz3iTctJstry1PYS3HVdllxnEzTEsgIgtxTrCk= @@ -852,6 +871,8 @@ go.uber.org/zap v1.19.1/go.mod h1:j3DNczoxDZroyBnOT1L/Q79cfUMGZxlv/9dzN7SM1rI= go.uber.org/zap v1.22.0 h1:Zcye5DUgBloQ9BaT4qc9BnjOFog5TvBSAGkJ3Nf70c0= go.uber.org/zap v1.22.0/go.mod h1:H4siCOZOrAolnUPJEkfaSjDqyP+BDS0DdDWzwcgt3+U= go4.org v0.0.0-20180809161055-417644f6feb5/go.mod h1:MkTOUMDaeVYJUOUsaDXIhWPZYa1yOyC1qaOBpL57BhE= +golang.org/x/arch v0.0.0-20210923205945-b76863e36670 h1:18EFjUmQOcUvxNYSkA6jO9VAiXCnxFY6NyDX0bHDmkU= +golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/build v0.0.0-20190111050920-041ab4dc3f9d/go.mod h1:OWs+y06UdEOHN4y+MfF/py+xQ/tYqIWW03b70/CG9Rw= golang.org/x/crypto v0.0.0-20170930174604-9419663f5a44/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= @@ -875,8 +896,9 @@ golang.org/x/crypto v0.0.0-20210506145944-38f3c27a63bf/go.mod h1:P+XmwS30IXTQdn5 golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20220525230936-793ad666bf5e/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= -golang.org/x/crypto v0.3.0 h1:a06MkbcxBrEFc0w0QIZWXrH/9cCX6KJyWbBOIwAn+7A= golang.org/x/crypto v0.3.0/go.mod h1:hebNnKkNXi2UzZN1eVRvBB7co0a+JxK6XbPiWVs/3J4= +golang.org/x/crypto v0.5.0 h1:U/0M97KRkSFvyD/3FSmdP5W5swImpNgle/EHFhOsQPE= +golang.org/x/crypto v0.5.0/go.mod h1:NK/OQwhpMQP3MwtdjgLlYHnH9ebylxKWv3e0fK+mkQU= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= @@ -968,8 +990,10 @@ golang.org/x/net v0.0.0-20220425223048-2871e0cb64e4/go.mod h1:CfG3xpIq0wQ8r1q4Su 
golang.org/x/net v0.0.0-20220624214902-1bab6f366d9e/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.0.0-20220812174116-3211cb980234/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk= -golang.org/x/net v0.2.0 h1:sZfSu1wtKLGlWI4ZZayP0ck9Y73K1ynO6gqzTdBVdPU= golang.org/x/net v0.2.0/go.mod h1:KqCZLdyyvdV855qA2rE3GC2aiw5xGR5TEjj8smXukLY= +golang.org/x/net v0.5.0/go.mod h1:DivGGAXEgPSlEBzxGzZI+ZLohi+xUj054jfeKui00ws= +golang.org/x/net v0.7.0 h1:rJrUqqhjsgNp7KqAIc25s9pZnjU7TUcSY7HcVZjdn1g= +golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20181017192945-9dcd33a902f4/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= @@ -1069,12 +1093,16 @@ golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.2.0 h1:ljd4t30dBnAvMZaQCevtY0xLLD0A+bRZXbgLMLU1F/A= golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.4.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.5.0 h1:MUK/U/4lj1t1oPg0HfuXDN/Z1wv31ZJ/YcPiGccS4DU= +golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.2.0/go.mod h1:TVmDHMZPmdnySmBfhjOoOdhjzdE1h4u1VwSiw2l1Nuc= +golang.org/x/term v0.4.0/go.mod h1:9P2UbLfCdcvo3p/nzKvsmas4TnlujnuoV9hGgYzW1lQ= +golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -1082,8 +1110,11 @@ golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/text v0.4.0 h1:BrVqGRd7+k1DiOgtnFvAkoQEWQvBc25ouMJM6429SFg= +golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= +golang.org/x/text v0.6.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= +golang.org/x/text v0.7.0 h1:4BRB4x83lYWy72KwLD/qYDuTu7q9PjSagHvijDw7cLo= +golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/time 
v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= @@ -1281,6 +1312,7 @@ lukechampine.com/blake3 v1.1.7 h1:GgRMhmdsuK8+ii6UZFDL8Nb+VyMwadAgcJyfYHxG6n0= lukechampine.com/blake3 v1.1.7/go.mod h1:tkKEOtDkNtklkXtLNEOGNq5tcV90tJiA1vAA12R78LA= rsc.io/binaryregexp v0.2.0/go.mod h1:qTv7/COck+e2FymRvadv62gMdZztPaShugOCi3I+8D8= rsc.io/goversion v1.2.0/go.mod h1:Eih9y/uIBS3ulggl7KNJ09xGSLcuNaLgmvvqa07sgfo= +rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= rsc.io/quote/v3 v3.1.0/go.mod h1:yEA65RcK8LyAZtP9Kv3t0HmxON59tX3rD+tICJqUlj0= rsc.io/sampler v1.3.0/go.mod h1:T1hPZKmBbMNahiBKFy5HrXp6adAjACjK9JXDnKaTXpA= sourcegraph.com/sourcegraph/go-diff v0.5.0/go.mod h1:kuch7UrkMzY0X+p9CRK03kfuPQ2zzQcaEFbx8wA8rck= diff --git a/heartbeat/errors.go b/heartbeat/errors.go index 5f3886f90e7..b8a47032325 100644 --- a/heartbeat/errors.go +++ b/heartbeat/errors.go @@ -79,3 +79,9 @@ var ErrNilShardCoordinator = errors.New("nil shard coordinator") // ErrNilTrieSyncStatisticsProvider signals that a nil trie sync statistics provider was provided var ErrNilTrieSyncStatisticsProvider = errors.New("nil trie sync statistics provider") + +// ErrNilManagedPeersHolder signals that a nil managed peers holder has been provided +var ErrNilManagedPeersHolder = errors.New("nil managed peers holder") + +// ErrInvalidConfiguration signals that an invalid configuration has been provided +var ErrInvalidConfiguration = errors.New("invalid configuration") diff --git a/heartbeat/interface.go b/heartbeat/interface.go index 2fc050db17a..d791a9b6ed0 100644 --- a/heartbeat/interface.go +++ b/heartbeat/interface.go @@ -1,6 +1,8 @@ package heartbeat import ( + "time" + "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/data" crypto "github.com/multiversx/mx-chain-crypto-go" @@ -12,9 +14,11 @@ import ( // P2PMessenger defines a subset of the p2p.Messenger interface type P2PMessenger interface { Broadcast(topic string, buff []byte) + BroadcastUsingPrivateKey(topic string, buff []byte, pid core.PeerID, skBytes []byte) ID() core.PeerID Sign(payload []byte) ([]byte, error) ConnectedPeersOnTopic(topic string) []core.PeerID + SignUsingPrivateKey(skBytes []byte, payload []byte) ([]byte, error) IsInterfaceNil() bool } @@ -68,3 +72,31 @@ type TrieSyncStatisticsProvider interface { NumProcessed() int IsInterfaceNil() bool } + +// ManagedPeersHolder defines the operations of an entity that holds managed identities for a node +type ManagedPeersHolder interface { + AddManagedPeer(privateKeyBytes []byte) error + GetPrivateKey(pkBytes []byte) (crypto.PrivateKey, error) + GetP2PIdentity(pkBytes []byte) ([]byte, core.PeerID, error) + GetMachineID(pkBytes []byte) (string, error) + GetNameAndIdentity(pkBytes []byte) (string, string, error) + IncrementRoundsWithoutReceivedMessages(pkBytes []byte) + ResetRoundsWithoutReceivedMessages(pkBytes []byte) + GetManagedKeysByCurrentNode() map[string]crypto.PrivateKey + IsKeyManagedByCurrentNode(pkBytes []byte) bool + IsKeyRegistered(pkBytes []byte) bool + IsPidManagedByCurrentNode(pid core.PeerID) bool + IsKeyValidator(pkBytes []byte) bool + SetValidatorState(pkBytes []byte, state bool) + GetNextPeerAuthenticationTime(pkBytes []byte) (time.Time, error) + SetNextPeerAuthenticationTime(pkBytes []byte, nextTime time.Time) 
+ IsMultiKeyMode() bool + IsInterfaceNil() bool +} + +// ShardCoordinator defines the operations of a shard coordinator +type ShardCoordinator interface { + SelfId() uint32 + ComputeId(address []byte) uint32 + IsInterfaceNil() bool +} diff --git a/heartbeat/mock/heartbeatSenderInfoProviderStub.go b/heartbeat/mock/heartbeatSenderInfoProviderStub.go index 33d02b4d85b..bf9f313f5a0 100644 --- a/heartbeat/mock/heartbeatSenderInfoProviderStub.go +++ b/heartbeat/mock/heartbeatSenderInfoProviderStub.go @@ -4,13 +4,13 @@ import "github.com/multiversx/mx-chain-core-go/core" // HeartbeatSenderInfoProviderStub - type HeartbeatSenderInfoProviderStub struct { - GetSenderInfoCalled func() (string, core.P2PPeerSubType, error) + GetCurrentNodeTypeCalled func() (string, core.P2PPeerSubType, error) } -// GetSenderInfo - -func (stub *HeartbeatSenderInfoProviderStub) GetSenderInfo() (string, core.P2PPeerSubType, error) { - if stub.GetSenderInfoCalled != nil { - return stub.GetSenderInfoCalled() +// GetCurrentNodeType - +func (stub *HeartbeatSenderInfoProviderStub) GetCurrentNodeType() (string, core.P2PPeerSubType, error) { + if stub.GetCurrentNodeTypeCalled != nil { + return stub.GetCurrentNodeTypeCalled() } return "", 0, nil diff --git a/heartbeat/sender/commonHeartbeatSender.go b/heartbeat/sender/commonHeartbeatSender.go new file mode 100644 index 00000000000..e428127f459 --- /dev/null +++ b/heartbeat/sender/commonHeartbeatSender.go @@ -0,0 +1,93 @@ +package sender + +import ( + "fmt" + "time" + + "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/heartbeat" +) + +type commonHeartbeatSender struct { + baseSender + versionNumber string + nodeDisplayName string + identity string + peerSubType core.P2PPeerSubType + currentBlockProvider heartbeat.CurrentBlockProvider + peerTypeProvider heartbeat.PeerTypeProviderHandler +} + +func (chs *commonHeartbeatSender) generateMessageBytes( + versionNumber string, + nodeDisplayName string, + identity string, + peerSubType uint32, + pkBytes []byte, + numProcessedTrieNodes uint64, +) ([]byte, error) { + if len(versionNumber) > maxSizeInBytes { + return nil, fmt.Errorf("%w for versionNumber, received %s of size %d, max size allowed %d", + heartbeat.ErrPropertyTooLong, versionNumber, len(versionNumber), maxSizeInBytes) + } + if len(nodeDisplayName) > maxSizeInBytes { + return nil, fmt.Errorf("%w for nodeDisplayName, received %s of size %d, max size allowed %d", + heartbeat.ErrPropertyTooLong, nodeDisplayName, len(nodeDisplayName), maxSizeInBytes) + } + if len(identity) > maxSizeInBytes { + return nil, fmt.Errorf("%w for identity, received %s of size %d, max size allowed %d", + heartbeat.ErrPropertyTooLong, identity, len(identity), maxSizeInBytes) + } + + payload := &heartbeat.Payload{ + Timestamp: time.Now().Unix(), + HardforkMessage: "", // sent through peer authentication message + } + payloadBytes, err := chs.marshaller.Marshal(payload) + if err != nil { + return nil, err + } + + nonce := uint64(0) + currentBlock := chs.currentBlockProvider.GetCurrentBlockHeader() + if currentBlock != nil { + nonce = currentBlock.GetNonce() + } + + msg := &heartbeat.HeartbeatV2{ + Payload: payloadBytes, + VersionNumber: versionNumber, + NodeDisplayName: nodeDisplayName, + Identity: identity, + Nonce: nonce, + PeerSubType: peerSubType, + Pubkey: pkBytes, + NumTrieNodesSynced: numProcessedTrieNodes, + } + + return chs.marshaller.Marshal(msg) +} + +// GetCurrentNodeType will return the current sender type and subtype 
+func (chs *commonHeartbeatSender) GetCurrentNodeType() (string, core.P2PPeerSubType, error) { + _, pk := chs.getCurrentPrivateAndPublicKeys() + pkBytes, err := pk.ToByteArray() + if err != nil { + return "", 0, err + } + + peerType := chs.computePeerList(pkBytes) + + return peerType, chs.peerSubType, nil +} + +func (chs *commonHeartbeatSender) computePeerList(pubkey []byte) string { + peerType, _, err := chs.peerTypeProvider.ComputeForPubKey(pubkey) + if err != nil { + log.Warn("heartbeatSender: compute peer type", "error", err) + return string(common.ObserverList) + } + + return string(peerType) +} diff --git a/heartbeat/sender/commonPeerAuthenticationSender.go b/heartbeat/sender/commonPeerAuthenticationSender.go new file mode 100644 index 00000000000..f1cf2e41eed --- /dev/null +++ b/heartbeat/sender/commonPeerAuthenticationSender.go @@ -0,0 +1,93 @@ +package sender + +import ( + "bytes" + "time" + + "github.com/multiversx/mx-chain-core-go/data/batch" + crypto "github.com/multiversx/mx-chain-crypto-go" + "github.com/multiversx/mx-chain-go/heartbeat" +) + +type commonPeerAuthenticationSender struct { + baseSender + nodesCoordinator heartbeat.NodesCoordinator + peerSignatureHandler crypto.PeerSignatureHandler + hardforkTrigger heartbeat.HardforkTrigger + hardforkTriggerPubKey []byte +} + +func (cpas *commonPeerAuthenticationSender) generateMessageBytes( + pkBytes []byte, + privateKey crypto.PrivateKey, + p2pSkBytes []byte, + pidBytes []byte, +) ([]byte, bool, int64, error) { + msg := &heartbeat.PeerAuthentication{ + Pid: pidBytes, + Pubkey: pkBytes, + } + + hardforkPayload, isTriggered := cpas.getHardforkPayload() + payload := &heartbeat.Payload{ + Timestamp: time.Now().Unix(), + HardforkMessage: string(hardforkPayload), + } + payloadBytes, err := cpas.marshaller.Marshal(payload) + if err != nil { + return nil, isTriggered, 0, err + } + msg.Payload = payloadBytes + + if p2pSkBytes != nil { + msg.PayloadSignature, err = cpas.messenger.SignUsingPrivateKey(p2pSkBytes, payloadBytes) + if err != nil { + return nil, isTriggered, 0, err + } + } else { + msg.PayloadSignature, err = cpas.messenger.Sign(payloadBytes) + if err != nil { + return nil, isTriggered, 0, err + } + } + + msg.Signature, err = cpas.peerSignatureHandler.GetPeerSignature(privateKey, msg.Pid) + if err != nil { + return nil, isTriggered, 0, err + } + + msgBytes, err := cpas.marshaller.Marshal(msg) + if err != nil { + return nil, isTriggered, 0, err + } + + b := &batch.Batch{ + Data: make([][]byte, 1), + } + b.Data[0] = msgBytes + data, err := cpas.marshaller.Marshal(b) + if err != nil { + return nil, isTriggered, 0, err + } + + return data, isTriggered, payload.Timestamp, nil +} + +func (cpas *commonPeerAuthenticationSender) isValidator(pkBytes []byte) bool { + _, _, err := cpas.nodesCoordinator.GetValidatorWithPublicKey(pkBytes) + return err == nil +} + +func (cpas *commonPeerAuthenticationSender) isHardforkSource(pkBytes []byte) bool { + return bytes.Equal(pkBytes, cpas.hardforkTriggerPubKey) +} + +func (cpas *commonPeerAuthenticationSender) getHardforkPayload() ([]byte, bool) { + payload := make([]byte, 0) + _, isTriggered := cpas.hardforkTrigger.RecordedTriggerMessage() + if isTriggered { + payload = cpas.hardforkTrigger.CreateData() + } + + return payload, isTriggered +} diff --git a/heartbeat/sender/heartbeatSender.go b/heartbeat/sender/heartbeatSender.go index d582d7a54d7..77c52cd96ee 100644 --- a/heartbeat/sender/heartbeatSender.go +++ b/heartbeat/sender/heartbeatSender.go @@ -2,11 +2,9 @@ package sender import ( "fmt" - 
"time" "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" - "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/heartbeat" ) @@ -25,13 +23,7 @@ type argHeartbeatSender struct { } type heartbeatSender struct { - baseSender - versionNumber string - nodeDisplayName string - identity string - peerSubType core.P2PPeerSubType - currentBlockProvider heartbeat.CurrentBlockProvider - peerTypeProvider heartbeat.PeerTypeProviderHandler + commonHeartbeatSender trieSyncStatisticsProvider heartbeat.TrieSyncStatisticsProvider } @@ -43,13 +35,15 @@ func newHeartbeatSender(args argHeartbeatSender) (*heartbeatSender, error) { } return &heartbeatSender{ - baseSender: createBaseSender(args.argBaseSender), - versionNumber: args.versionNumber, - nodeDisplayName: args.nodeDisplayName, - identity: args.identity, - peerSubType: args.peerSubType, - currentBlockProvider: args.currentBlockProvider, - peerTypeProvider: args.peerTypeProvider, + commonHeartbeatSender: commonHeartbeatSender{ + baseSender: createBaseSender(args.argBaseSender), + currentBlockProvider: args.currentBlockProvider, + peerTypeProvider: args.peerTypeProvider, + versionNumber: args.versionNumber, + nodeDisplayName: args.nodeDisplayName, + identity: args.identity, + peerSubType: args.peerSubType, + }, trieSyncStatisticsProvider: args.trieSyncStatisticsProvider, }, nil } @@ -99,40 +93,21 @@ func (sender *heartbeatSender) Execute() { } func (sender *heartbeatSender) execute() error { - payload := &heartbeat.Payload{ - Timestamp: time.Now().Unix(), - HardforkMessage: "", // sent through peer authentication message - } - payloadBytes, err := sender.marshaller.Marshal(payload) - if err != nil { - return err - } - - nonce := uint64(0) - currentBlock := sender.currentBlockProvider.GetCurrentBlockHeader() - if currentBlock != nil { - nonce = currentBlock.GetNonce() - } - _, pk := sender.getCurrentPrivateAndPublicKeys() pkBytes, err := pk.ToByteArray() if err != nil { return err } - trieNodesReceived := sender.trieSyncStatisticsProvider.NumProcessed() - msg := &heartbeat.HeartbeatV2{ - Payload: payloadBytes, - VersionNumber: sender.versionNumber, - NodeDisplayName: sender.nodeDisplayName, - Identity: sender.identity, - Nonce: nonce, - PeerSubType: uint32(sender.peerSubType), - Pubkey: pkBytes, - NumTrieNodesSynced: uint64(trieNodesReceived), - } - - msgBytes, err := sender.marshaller.Marshal(msg) + trieNodesReceived := uint64(sender.trieSyncStatisticsProvider.NumProcessed()) + msgBytes, err := sender.generateMessageBytes( + sender.versionNumber, + sender.nodeDisplayName, + sender.identity, + uint32(sender.peerSubType), + pkBytes, + trieNodesReceived, + ) if err != nil { return err } @@ -142,29 +117,6 @@ func (sender *heartbeatSender) execute() error { return nil } -// getSenderInfo will return the current sender info -func (sender *heartbeatSender) getSenderInfo() (string, core.P2PPeerSubType, error) { - _, pk := sender.getCurrentPrivateAndPublicKeys() - pkBytes, err := pk.ToByteArray() - if err != nil { - return "", 0, err - } - - peerType := sender.computePeerList(pkBytes) - - return peerType, sender.peerSubType, nil -} - -func (sender *heartbeatSender) computePeerList(pubkey []byte) string { - peerType, _, err := sender.peerTypeProvider.ComputeForPubKey(pubkey) - if err != nil { - log.Warn("heartbeatSender: compute peer type", "error", err) - return string(common.ObserverList) - } - - return string(peerType) -} - // IsInterfaceNil returns true if there is no value under the 
interface func (sender *heartbeatSender) IsInterfaceNil() bool { return sender == nil diff --git a/heartbeat/sender/heartbeatSenderFactory.go b/heartbeat/sender/heartbeatSenderFactory.go new file mode 100644 index 00000000000..487bd623924 --- /dev/null +++ b/heartbeat/sender/heartbeatSenderFactory.go @@ -0,0 +1,105 @@ +package sender + +import ( + "fmt" + + "github.com/multiversx/mx-chain-core-go/core" + crypto "github.com/multiversx/mx-chain-crypto-go" + "github.com/multiversx/mx-chain-go/heartbeat" +) + +type argHeartbeatSenderFactory struct { + argBaseSender + baseVersionNumber string + versionNumber string + nodeDisplayName string + identity string + peerSubType core.P2PPeerSubType + currentBlockProvider heartbeat.CurrentBlockProvider + peerTypeProvider heartbeat.PeerTypeProviderHandler + managedPeersHolder heartbeat.ManagedPeersHolder + shardCoordinator heartbeat.ShardCoordinator + nodesCoordinator heartbeat.NodesCoordinator + trieSyncStatisticsProvider heartbeat.TrieSyncStatisticsProvider +} + +func createHeartbeatSender(args argHeartbeatSenderFactory) (heartbeatSenderHandler, error) { + isMultikey, err := isMultikeyMode(args.privKey, args.managedPeersHolder, args.nodesCoordinator) + if err != nil { + return nil, fmt.Errorf("%w while creating heartbeat sender", err) + } + + if isMultikey { + return createMultikeyHeartbeatSender(args) + } + + return createRegularHeartbeatSender(args) +} + +func createRegularHeartbeatSender(args argHeartbeatSenderFactory) (*heartbeatSender, error) { + argsSender := argHeartbeatSender{ + argBaseSender: argBaseSender{ + messenger: args.messenger, + marshaller: args.marshaller, + topic: args.topic, + timeBetweenSends: args.timeBetweenSends, + timeBetweenSendsWhenError: args.timeBetweenSendsWhenError, + thresholdBetweenSends: args.thresholdBetweenSends, + redundancyHandler: args.redundancyHandler, + privKey: args.privKey, + }, + versionNumber: args.versionNumber, + nodeDisplayName: args.nodeDisplayName, + identity: args.identity, + peerSubType: args.peerSubType, + currentBlockProvider: args.currentBlockProvider, + peerTypeProvider: args.peerTypeProvider, + trieSyncStatisticsProvider: args.trieSyncStatisticsProvider, + } + + return newHeartbeatSender(argsSender) +} + +func createMultikeyHeartbeatSender(args argHeartbeatSenderFactory) (*multikeyHeartbeatSender, error) { + argsSender := argMultikeyHeartbeatSender{ + argBaseSender: argBaseSender{ + messenger: args.messenger, + marshaller: args.marshaller, + topic: args.topic, + timeBetweenSends: args.timeBetweenSends, + timeBetweenSendsWhenError: args.timeBetweenSendsWhenError, + thresholdBetweenSends: args.thresholdBetweenSends, + redundancyHandler: args.redundancyHandler, + privKey: args.privKey, + }, + peerTypeProvider: args.peerTypeProvider, + versionNumber: args.versionNumber, + baseVersionNumber: args.baseVersionNumber, + nodeDisplayName: args.nodeDisplayName, + identity: args.identity, + peerSubType: args.peerSubType, + currentBlockProvider: args.currentBlockProvider, + managedPeersHolder: args.managedPeersHolder, + shardCoordinator: args.shardCoordinator, + trieSyncStatisticsProvider: args.trieSyncStatisticsProvider, + } + + return newMultikeyHeartbeatSender(argsSender) +} + +func isMultikeyMode(privKey crypto.PrivateKey, managedPeersHolder heartbeat.ManagedPeersHolder, nodesCoordinator heartbeat.NodesCoordinator) (bool, error) { + pk := privKey.GeneratePublic() + pkBytes, err := pk.ToByteArray() + if err != nil { + return false, err + } + + isMultikey := managedPeersHolder.IsMultiKeyMode() + + _, _, 
err = nodesCoordinator.GetValidatorWithPublicKey(pkBytes) + if err == nil && isMultikey { + return false, fmt.Errorf("%w, isMultikey = %t, isValidator = %v", heartbeat.ErrInvalidConfiguration, isMultikey, err == nil) + } + + return isMultikey, nil +} diff --git a/heartbeat/sender/heartbeatSenderFactory_test.go b/heartbeat/sender/heartbeatSenderFactory_test.go new file mode 100644 index 00000000000..467be4b890f --- /dev/null +++ b/heartbeat/sender/heartbeatSenderFactory_test.go @@ -0,0 +1,124 @@ +package sender + +import ( + "errors" + "fmt" + "strings" + "testing" + + "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-core-go/core/check" + crypto "github.com/multiversx/mx-chain-crypto-go" + "github.com/multiversx/mx-chain-go/heartbeat" + "github.com/multiversx/mx-chain-go/heartbeat/mock" + "github.com/multiversx/mx-chain-go/sharding/nodesCoordinator" + "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cryptoMocks" + "github.com/multiversx/mx-chain-go/testscommon/shardingMocks" + "github.com/stretchr/testify/assert" +) + +func createMockHeartbeatSenderFactoryArgs() argHeartbeatSenderFactory { + return argHeartbeatSenderFactory{ + argBaseSender: createMockBaseArgs(), + baseVersionNumber: "base version number", + versionNumber: "version number", + nodeDisplayName: "node name", + identity: "identity", + peerSubType: core.RegularPeer, + currentBlockProvider: &mock.CurrentBlockProviderStub{}, + peerTypeProvider: &mock.PeerTypeProviderStub{}, + managedPeersHolder: &testscommon.ManagedPeersHolderStub{}, + shardCoordinator: createShardCoordinatorInShard(0), + nodesCoordinator: &shardingMocks.NodesCoordinatorStub{}, + trieSyncStatisticsProvider: &testscommon.SizeSyncStatisticsHandlerStub{}, + } +} + +func TestHeartbeatSenderFactory_createHeartbeatSender(t *testing.T) { + t.Parallel() + + t.Run("ToByteArray fails should error", func(t *testing.T) { + t.Parallel() + + args := createMockHeartbeatSenderFactoryArgs() + args.privKey = &cryptoMocks.PrivateKeyStub{ + GeneratePublicStub: func() crypto.PublicKey { + return &cryptoMocks.PublicKeyStub{ + ToByteArrayStub: func() ([]byte, error) { + return nil, expectedErr + }, + } + }, + } + hbSender, err := createHeartbeatSender(args) + assert.True(t, errors.Is(err, expectedErr)) + assert.True(t, check.IfNil(hbSender)) + }) + t.Run("validator with keys managed should error", func(t *testing.T) { + t.Parallel() + + args := createMockHeartbeatSenderFactoryArgs() + args.nodesCoordinator = &shardingMocks.NodesCoordinatorStub{ + GetValidatorWithPublicKeyCalled: func(publicKey []byte) (validator nodesCoordinator.Validator, shardId uint32, err error) { + return nil, 0, nil + }, + } + args.managedPeersHolder = &testscommon.ManagedPeersHolderStub{ + IsMultiKeyModeCalled: func() bool { + return true + }, + } + hbSender, err := createHeartbeatSender(args) + assert.True(t, errors.Is(err, heartbeat.ErrInvalidConfiguration)) + assert.True(t, strings.Contains(err.Error(), "isValidator")) + assert.True(t, check.IfNil(hbSender)) + }) + t.Run("validator should create regular sender", func(t *testing.T) { + t.Parallel() + + args := createMockHeartbeatSenderFactoryArgs() + args.nodesCoordinator = &shardingMocks.NodesCoordinatorStub{ + GetValidatorWithPublicKeyCalled: func(publicKey []byte) (validator nodesCoordinator.Validator, shardId uint32, err error) { + return nil, 0, nil + }, + } + hbSender, err := createHeartbeatSender(args) + assert.Nil(t, err) + assert.False(t, check.IfNil(hbSender)) + 
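			// a key reported as validator by the nodes coordinator (with multikey mode off on the
			// managed peers holder) must yield the single-key heartbeat sender; per isMultikeyMode
			// above, the multikey variant is only built when IsMultiKeyMode is true and the node's
			// own key is not a validator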
assert.Equal(t, "*sender.heartbeatSender", fmt.Sprintf("%T", hbSender)) + }) + t.Run("regular observer should create regular sender", func(t *testing.T) { + t.Parallel() + + args := createMockHeartbeatSenderFactoryArgs() + args.nodesCoordinator = &shardingMocks.NodesCoordinatorStub{ + GetValidatorWithPublicKeyCalled: func(publicKey []byte) (validator nodesCoordinator.Validator, shardId uint32, err error) { + return nil, 0, errors.New("not validator") + }, + } + hbSender, err := createHeartbeatSender(args) + assert.Nil(t, err) + assert.False(t, check.IfNil(hbSender)) + assert.Equal(t, "*sender.heartbeatSender", fmt.Sprintf("%T", hbSender)) + }) + t.Run("not validator with keys managed should create multikey sender", func(t *testing.T) { + t.Parallel() + + args := createMockHeartbeatSenderFactoryArgs() + args.nodesCoordinator = &shardingMocks.NodesCoordinatorStub{ + GetValidatorWithPublicKeyCalled: func(publicKey []byte) (validator nodesCoordinator.Validator, shardId uint32, err error) { + return nil, 0, errors.New("not validator") + }, + } + args.managedPeersHolder = &testscommon.ManagedPeersHolderStub{ + IsMultiKeyModeCalled: func() bool { + return true + }, + } + hbSender, err := createHeartbeatSender(args) + assert.Nil(t, err) + assert.False(t, check.IfNil(hbSender)) + assert.Equal(t, "*sender.multikeyHeartbeatSender", fmt.Sprintf("%T", hbSender)) + }) +} diff --git a/heartbeat/sender/heartbeatSender_test.go b/heartbeat/sender/heartbeatSender_test.go index b4411c6d068..40d8f41db30 100644 --- a/heartbeat/sender/heartbeatSender_test.go +++ b/heartbeat/sender/heartbeatSender_test.go @@ -351,7 +351,7 @@ func TestHeartbeatSender_execute(t *testing.T) { }) } -func TestHeartbeatSender_getSenderInfo(t *testing.T) { +func TestHeartbeatSender_GetCurrentNodeType(t *testing.T) { t.Parallel() args := createMockHeartbeatSenderArgs(createMockBaseArgs()) @@ -363,7 +363,7 @@ func TestHeartbeatSender_getSenderInfo(t *testing.T) { } senderInstance, _ := newHeartbeatSender(args) - peerType, subType, err := senderInstance.getSenderInfo() + peerType, subType, err := senderInstance.GetCurrentNodeType() assert.Nil(t, err) assert.Equal(t, string(common.EligibleList), peerType) assert.Equal(t, core.FullHistoryObserver, subType) diff --git a/heartbeat/sender/interface.go b/heartbeat/sender/interface.go index f7fa9a7482a..bfdb3973e59 100644 --- a/heartbeat/sender/interface.go +++ b/heartbeat/sender/interface.go @@ -1,6 +1,10 @@ package sender -import "time" +import ( + "time" + + "github.com/multiversx/mx-chain-core-go/core" +) type senderHandler interface { ExecutionReadyChannel() <-chan time.Time @@ -15,6 +19,16 @@ type hardforkHandler interface { Close() } +type peerAuthenticationSenderHandler interface { + senderHandler + hardforkHandler +} + +type heartbeatSenderHandler interface { + senderHandler + GetCurrentNodeType() (string, core.P2PPeerSubType, error) +} + type timerHandler interface { CreateNewTimer(duration time.Duration) ExecutionReadyChannel() <-chan time.Time diff --git a/heartbeat/sender/multikeyHeartbeatSender.go b/heartbeat/sender/multikeyHeartbeatSender.go new file mode 100644 index 00000000000..6e147dd8e47 --- /dev/null +++ b/heartbeat/sender/multikeyHeartbeatSender.go @@ -0,0 +1,216 @@ +package sender + +import ( + "fmt" + "time" + + "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-go/heartbeat" +) + +// argMultikeyHeartbeatSender represents the arguments for the heartbeat sender +type 
argMultikeyHeartbeatSender struct { + argBaseSender + peerTypeProvider heartbeat.PeerTypeProviderHandler + versionNumber string + baseVersionNumber string + nodeDisplayName string + identity string + peerSubType core.P2PPeerSubType + currentBlockProvider heartbeat.CurrentBlockProvider + managedPeersHolder heartbeat.ManagedPeersHolder + shardCoordinator heartbeat.ShardCoordinator + trieSyncStatisticsProvider heartbeat.TrieSyncStatisticsProvider +} + +type multikeyHeartbeatSender struct { + commonHeartbeatSender + baseVersionNumber string + managedPeersHolder heartbeat.ManagedPeersHolder + shardCoordinator heartbeat.ShardCoordinator + trieSyncStatisticsProvider heartbeat.TrieSyncStatisticsProvider +} + +// newMultikeyHeartbeatSender creates a new instance of type multikeyHeartbeatSender +func newMultikeyHeartbeatSender(args argMultikeyHeartbeatSender) (*multikeyHeartbeatSender, error) { + err := checkMultikeyHeartbeatSenderArgs(args) + if err != nil { + return nil, err + } + + return &multikeyHeartbeatSender{ + commonHeartbeatSender: commonHeartbeatSender{ + baseSender: createBaseSender(args.argBaseSender), + currentBlockProvider: args.currentBlockProvider, + peerTypeProvider: args.peerTypeProvider, + versionNumber: args.versionNumber, + nodeDisplayName: args.nodeDisplayName, + identity: args.identity, + peerSubType: args.peerSubType, + }, + baseVersionNumber: args.baseVersionNumber, + managedPeersHolder: args.managedPeersHolder, + shardCoordinator: args.shardCoordinator, + trieSyncStatisticsProvider: args.trieSyncStatisticsProvider, + }, nil +} + +func checkMultikeyHeartbeatSenderArgs(args argMultikeyHeartbeatSender) error { + err := checkBaseSenderArgs(args.argBaseSender) + if err != nil { + return err + } + if check.IfNil(args.peerTypeProvider) { + return heartbeat.ErrNilPeerTypeProvider + } + if len(args.versionNumber) > maxSizeInBytes { + return fmt.Errorf("%w for versionNumber, received %s of size %d, max size allowed %d", + heartbeat.ErrPropertyTooLong, args.versionNumber, len(args.versionNumber), maxSizeInBytes) + } + if len(args.baseVersionNumber) > maxSizeInBytes { + return fmt.Errorf("%w for baseVersionNumber, received %s of size %d, max size allowed %d", + heartbeat.ErrPropertyTooLong, args.baseVersionNumber, len(args.baseVersionNumber), maxSizeInBytes) + } + if len(args.nodeDisplayName) > maxSizeInBytes { + return fmt.Errorf("%w for nodeDisplayName, received %s of size %d, max size allowed %d", + heartbeat.ErrPropertyTooLong, args.nodeDisplayName, len(args.nodeDisplayName), maxSizeInBytes) + } + if len(args.identity) > maxSizeInBytes { + return fmt.Errorf("%w for identity, received %s of size %d, max size allowed %d", + heartbeat.ErrPropertyTooLong, args.identity, len(args.identity), maxSizeInBytes) + } + if check.IfNil(args.currentBlockProvider) { + return heartbeat.ErrNilCurrentBlockProvider + } + if check.IfNil(args.managedPeersHolder) { + return heartbeat.ErrNilManagedPeersHolder + } + if check.IfNil(args.shardCoordinator) { + return heartbeat.ErrNilShardCoordinator + } + if check.IfNil(args.trieSyncStatisticsProvider) { + return heartbeat.ErrNilTrieSyncStatisticsProvider + } + + return nil +} + +// Execute will handle the execution of a cycle in which the heartbeat message will be sent +func (sender *multikeyHeartbeatSender) Execute() { + duration := sender.computeRandomDuration(sender.timeBetweenSends) + err := sender.execute() + if err != nil { + duration = sender.timeBetweenSendsWhenError + log.Error("error sending heartbeat messages", "error", err, "next send will be 
in", duration) + } else { + log.Debug("heartbeat messages sent", "next send will be in", duration) + } + + sender.CreateNewTimer(duration) +} + +func (sender *multikeyHeartbeatSender) execute() error { + _, pk := sender.getCurrentPrivateAndPublicKeys() + pkBytes, err := pk.ToByteArray() + if err != nil { + return err + } + + trieNodesReceived := uint64(sender.trieSyncStatisticsProvider.NumProcessed()) + buff, err := sender.generateMessageBytes( + sender.versionNumber, + sender.nodeDisplayName, + sender.identity, + uint32(sender.peerSubType), + pkBytes, + trieNodesReceived, + ) + if err != nil { + return err + } + + sender.messenger.Broadcast(sender.topic, buff) + + return sender.sendMultiKeysInfo() +} + +func (sender *multikeyHeartbeatSender) sendMultiKeysInfo() error { + managedKeys := sender.managedPeersHolder.GetManagedKeysByCurrentNode() + for pk := range managedKeys { + pkBytes := []byte(pk) + shouldSend := sender.processIfShouldSend(pkBytes) + if !shouldSend { + continue + } + + err := sender.sendMessageForKey(pkBytes) + if err != nil { + log.Warn("could not broadcast for pk", "pk", pkBytes, "error", err) + } + } + + return nil +} + +func (sender *multikeyHeartbeatSender) sendMessageForKey(pkBytes []byte) error { + time.Sleep(delayedBroadcast) + + name, identity, err := sender.managedPeersHolder.GetNameAndIdentity(pkBytes) + if err != nil { + return err + } + + machineID, err := sender.managedPeersHolder.GetMachineID(pkBytes) + if err != nil { + return err + } + versionNumber := fmt.Sprintf("%s/%s", sender.baseVersionNumber, machineID) + + buff, err := sender.generateMessageBytes( + versionNumber, + name, + identity, + uint32(core.RegularPeer), // force multi key handled peers to be of type regular peers + pkBytes, + 0, // hardcode this to 0, the virtual peers do not handle the trie sync + ) + if err != nil { + return err + } + + p2pSk, pid, err := sender.managedPeersHolder.GetP2PIdentity(pkBytes) + if err != nil { + return err + } + + sender.messenger.BroadcastUsingPrivateKey(sender.topic, buff, pid, p2pSk) + + return nil +} + +func (sender *multikeyHeartbeatSender) processIfShouldSend(pk []byte) bool { + if !sender.managedPeersHolder.IsKeyManagedByCurrentNode(pk) { + return false + } + _, shardID, err := sender.peerTypeProvider.ComputeForPubKey(pk) + if err != nil { + log.Debug("processIfShouldSend.ComputeForPubKey", "error", err) + return false + } + + if shardID != sender.shardCoordinator.SelfId() { + log.Trace("processIfShouldSend: shard id does not match", + "pk", pk, + "self shard", sender.shardCoordinator.SelfId(), + "pk shard", shardID) + return false + } + + return true +} + +// IsInterfaceNil returns true if there is no value under the interface +func (sender *multikeyHeartbeatSender) IsInterfaceNil() bool { + return sender == nil +} diff --git a/heartbeat/sender/multikeyHeartbeatSender_test.go b/heartbeat/sender/multikeyHeartbeatSender_test.go new file mode 100644 index 00000000000..0d46c8facf2 --- /dev/null +++ b/heartbeat/sender/multikeyHeartbeatSender_test.go @@ -0,0 +1,480 @@ +package sender + +import ( + "errors" + "strings" + "testing" + "time" + + "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-core-go/data" + crypto "github.com/multiversx/mx-chain-crypto-go" + "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/heartbeat" + "github.com/multiversx/mx-chain-go/heartbeat/mock" + "github.com/multiversx/mx-chain-go/testscommon" + 
"github.com/multiversx/mx-chain-go/testscommon/p2pmocks" + "github.com/stretchr/testify/assert" +) + +func createMockMultikeyHeartbeatSenderArgs(argBase argBaseSender) argMultikeyHeartbeatSender { + return argMultikeyHeartbeatSender{ + argBaseSender: argBase, + peerTypeProvider: &mock.PeerTypeProviderStub{ + ComputeForPubKeyCalled: func(pubKey []byte) (common.PeerType, uint32, error) { + return common.ObserverList, 0, nil + }, + }, + versionNumber: "version", + baseVersionNumber: "base version", + nodeDisplayName: "default name", + identity: "default identity", + peerSubType: core.FullHistoryObserver, + currentBlockProvider: &mock.CurrentBlockProviderStub{ + GetCurrentBlockHeaderCalled: func() data.HeaderHandler { + return &testscommon.HeaderHandlerStub{} + }, + }, + managedPeersHolder: &testscommon.ManagedPeersHolderStub{ + IsKeyManagedByCurrentNodeCalled: func(pkBytes []byte) bool { + return true + }, + }, + shardCoordinator: createShardCoordinatorInShard(0), + trieSyncStatisticsProvider: &testscommon.SizeSyncStatisticsHandlerStub{}, + } +} + +func TestNewMultikeyHeartbeatSender(t *testing.T) { + t.Parallel() + + t.Run("nil messenger should error", func(t *testing.T) { + t.Parallel() + + args := createMockMultikeyHeartbeatSenderArgs(createMockBaseArgs()) + args.messenger = nil + + senderInstance, err := newMultikeyHeartbeatSender(args) + assert.True(t, check.IfNil(senderInstance)) + assert.Equal(t, heartbeat.ErrNilMessenger, err) + }) + t.Run("nil marshaller should error", func(t *testing.T) { + t.Parallel() + + args := createMockMultikeyHeartbeatSenderArgs(createMockBaseArgs()) + args.marshaller = nil + + senderInstance, err := newMultikeyHeartbeatSender(args) + assert.True(t, check.IfNil(senderInstance)) + assert.Equal(t, heartbeat.ErrNilMarshaller, err) + }) + t.Run("empty topic should error", func(t *testing.T) { + t.Parallel() + + argsBase := createMockBaseArgs() + argsBase.topic = "" + + args := createMockMultikeyHeartbeatSenderArgs(argsBase) + senderInstance, err := newMultikeyHeartbeatSender(args) + + assert.True(t, check.IfNil(senderInstance)) + assert.Equal(t, heartbeat.ErrEmptySendTopic, err) + }) + t.Run("invalid time between sends should error", func(t *testing.T) { + t.Parallel() + + argsBase := createMockBaseArgs() + argsBase.timeBetweenSends = time.Second - time.Nanosecond + + args := createMockMultikeyHeartbeatSenderArgs(argsBase) + senderInstance, err := newMultikeyHeartbeatSender(args) + + assert.True(t, check.IfNil(senderInstance)) + assert.True(t, errors.Is(err, heartbeat.ErrInvalidTimeDuration)) + assert.True(t, strings.Contains(err.Error(), "timeBetweenSends")) + assert.False(t, strings.Contains(err.Error(), "timeBetweenSendsWhenError")) + }) + t.Run("invalid time between sends should error", func(t *testing.T) { + t.Parallel() + + argsBase := createMockBaseArgs() + argsBase.timeBetweenSendsWhenError = time.Second - time.Nanosecond + + args := createMockMultikeyHeartbeatSenderArgs(argsBase) + senderInstance, err := newMultikeyHeartbeatSender(args) + + assert.True(t, check.IfNil(senderInstance)) + assert.True(t, errors.Is(err, heartbeat.ErrInvalidTimeDuration)) + assert.True(t, strings.Contains(err.Error(), "timeBetweenSendsWhenError")) + }) + t.Run("threshold too small should error", func(t *testing.T) { + t.Parallel() + + args := createMockMultikeyHeartbeatSenderArgs(createMockBaseArgs()) + args.thresholdBetweenSends = 0.001 + senderInstance, err := newMultikeyHeartbeatSender(args) + + assert.Nil(t, senderInstance) + assert.True(t, errors.Is(err, 
heartbeat.ErrInvalidThreshold)) + assert.True(t, strings.Contains(err.Error(), "thresholdBetweenSends")) + }) + t.Run("threshold too big should error", func(t *testing.T) { + t.Parallel() + + args := createMockMultikeyHeartbeatSenderArgs(createMockBaseArgs()) + args.thresholdBetweenSends = 1.001 + senderInstance, err := newMultikeyHeartbeatSender(args) + + assert.Nil(t, senderInstance) + assert.True(t, errors.Is(err, heartbeat.ErrInvalidThreshold)) + assert.True(t, strings.Contains(err.Error(), "thresholdBetweenSends")) + }) + t.Run("nil peerTypeProvider should error", func(t *testing.T) { + t.Parallel() + + args := createMockMultikeyHeartbeatSenderArgs(createMockBaseArgs()) + args.peerTypeProvider = nil + + senderInstance, err := newMultikeyHeartbeatSender(args) + assert.True(t, check.IfNil(senderInstance)) + assert.Equal(t, heartbeat.ErrNilPeerTypeProvider, err) + }) + t.Run("version number too long should error", func(t *testing.T) { + t.Parallel() + + args := createMockMultikeyHeartbeatSenderArgs(createMockBaseArgs()) + args.versionNumber = strings.Repeat("a", maxSizeInBytes+1) + senderInstance, err := newMultikeyHeartbeatSender(args) + + assert.Nil(t, senderInstance) + assert.True(t, errors.Is(err, heartbeat.ErrPropertyTooLong)) + assert.True(t, strings.Contains(err.Error(), "versionNumber")) + }) + t.Run("base version number too long should error", func(t *testing.T) { + t.Parallel() + + args := createMockMultikeyHeartbeatSenderArgs(createMockBaseArgs()) + args.baseVersionNumber = strings.Repeat("a", maxSizeInBytes+1) + senderInstance, err := newMultikeyHeartbeatSender(args) + + assert.Nil(t, senderInstance) + assert.True(t, errors.Is(err, heartbeat.ErrPropertyTooLong)) + assert.True(t, strings.Contains(err.Error(), "baseVersionNumber")) + }) + t.Run("node display name too long should error", func(t *testing.T) { + t.Parallel() + + args := createMockMultikeyHeartbeatSenderArgs(createMockBaseArgs()) + args.nodeDisplayName = strings.Repeat("a", maxSizeInBytes+1) + senderInstance, err := newMultikeyHeartbeatSender(args) + + assert.Nil(t, senderInstance) + assert.True(t, errors.Is(err, heartbeat.ErrPropertyTooLong)) + assert.True(t, strings.Contains(err.Error(), "nodeDisplayName")) + }) + t.Run("identity too long should error", func(t *testing.T) { + t.Parallel() + + args := createMockMultikeyHeartbeatSenderArgs(createMockBaseArgs()) + args.identity = strings.Repeat("a", maxSizeInBytes+1) + senderInstance, err := newMultikeyHeartbeatSender(args) + + assert.Nil(t, senderInstance) + assert.True(t, errors.Is(err, heartbeat.ErrPropertyTooLong)) + assert.True(t, strings.Contains(err.Error(), "identity")) + }) + t.Run("nil currentBlockProvider should error", func(t *testing.T) { + t.Parallel() + + args := createMockMultikeyHeartbeatSenderArgs(createMockBaseArgs()) + args.currentBlockProvider = nil + + senderInstance, err := newMultikeyHeartbeatSender(args) + assert.True(t, check.IfNil(senderInstance)) + assert.Equal(t, heartbeat.ErrNilCurrentBlockProvider, err) + }) + t.Run("nil managed peers holder should error", func(t *testing.T) { + t.Parallel() + + args := createMockMultikeyHeartbeatSenderArgs(createMockBaseArgs()) + args.managedPeersHolder = nil + senderInstance, err := newMultikeyHeartbeatSender(args) + + assert.True(t, check.IfNil(senderInstance)) + assert.Equal(t, heartbeat.ErrNilManagedPeersHolder, err) + }) + t.Run("nil shard coordinator should error", func(t *testing.T) { + t.Parallel() + + args := createMockMultikeyHeartbeatSenderArgs(createMockBaseArgs()) + args.shardCoordinator 
= nil + senderInstance, err := newMultikeyHeartbeatSender(args) + + assert.True(t, check.IfNil(senderInstance)) + assert.Equal(t, heartbeat.ErrNilShardCoordinator, err) + }) + + t.Run("should work", func(t *testing.T) { + t.Parallel() + + args := createMockMultikeyHeartbeatSenderArgs(createMockBaseArgs()) + + senderInstance, err := newMultikeyHeartbeatSender(args) + assert.False(t, check.IfNil(senderInstance)) + assert.Nil(t, err) + }) +} + +func TestMultikeyHeartbeatSender_Execute(t *testing.T) { + t.Parallel() + + t.Run("execute errors, should set the error time duration value", func(t *testing.T) { + t.Parallel() + + wasCalled := false + argsBase := createMockBaseArgs() + argsBase.timeBetweenSendsWhenError = time.Second * 3 + argsBase.timeBetweenSends = time.Second * 2 + argsBase.marshaller = &testscommon.MarshalizerStub{ + MarshalCalled: func(obj interface{}) ([]byte, error) { + return nil, expectedErr + }, + } + + args := createMockMultikeyHeartbeatSenderArgs(argsBase) + senderInstance, _ := newMultikeyHeartbeatSender(args) + senderInstance.timerHandler = &mock.TimerHandlerStub{ + CreateNewTimerCalled: func(duration time.Duration) { + assert.Equal(t, argsBase.timeBetweenSendsWhenError, duration) + wasCalled = true + }, + } + + senderInstance.Execute() + assert.True(t, wasCalled) + }) + t.Run("execute worked, should set the normal time duration value", func(t *testing.T) { + t.Parallel() + + wasCalled := false + argsBase := createMockBaseArgs() + argsBase.timeBetweenSendsWhenError = time.Second * 3 + argsBase.timeBetweenSends = time.Second * 2 + + args := createMockMultikeyHeartbeatSenderArgs(argsBase) + senderInstance, _ := newMultikeyHeartbeatSender(args) + senderInstance.timerHandler = &mock.TimerHandlerStub{ + CreateNewTimerCalled: func(duration time.Duration) { + floatTBS := float64(argsBase.timeBetweenSends.Nanoseconds()) + maxDuration := floatTBS + floatTBS*argsBase.thresholdBetweenSends + assert.True(t, time.Duration(maxDuration) > duration) + assert.True(t, argsBase.timeBetweenSends <= duration) + wasCalled = true + }, + } + + senderInstance.Execute() + assert.True(t, wasCalled) + }) +} + +func TestMultikeyHeartbeatSender_execute(t *testing.T) { + t.Parallel() + + t.Run("should send the current node heartbeat", func(t *testing.T) { + t.Parallel() + + args := createMockMultikeyHeartbeatSenderArgs(createMockBaseArgs()) + broadcastCalled := false + recordedMessages := make(map[core.PeerID][][]byte) + args.messenger = &p2pmocks.MessengerStub{ + BroadcastCalled: func(topic string, buff []byte) { + assert.Equal(t, args.topic, topic) + recordedMessages[args.messenger.ID()] = append(recordedMessages[args.messenger.ID()], buff) + broadcastCalled = true + }, + } + + senderInstance, _ := newMultikeyHeartbeatSender(args) + + err := senderInstance.execute() + assert.Nil(t, err) + assert.True(t, broadcastCalled) + assert.Equal(t, 1, len(recordedMessages)) + checkRecordedMessages(t, recordedMessages, args, args.versionNumber, args.nodeDisplayName, args.messenger.ID(), core.FullHistoryObserver) + assert.Equal(t, uint64(1), args.currentBlockProvider.GetCurrentBlockHeader().GetNonce()) + }) + t.Run("should send the current node heartbeat and some multikey heartbeats", func(t *testing.T) { + t.Parallel() + + args := createMockMultikeyHeartbeatSenderArgs(createMockBaseArgs()) + recordedMessages := make(map[core.PeerID][][]byte) + args.messenger = &p2pmocks.MessengerStub{ + BroadcastCalled: func(topic string, buff []byte) { + assert.Equal(t, args.topic, topic) + 
recordedMessages[args.messenger.ID()] = append(recordedMessages[args.messenger.ID()], buff) + }, + BroadcastUsingPrivateKeyCalled: func(topic string, buff []byte, pid core.PeerID, skBytes []byte) { + assert.Equal(t, args.topic, topic) + recordedMessages[pid] = append(recordedMessages[pid], buff) + }, + } + args.managedPeersHolder = &testscommon.ManagedPeersHolderStub{ + IsKeyManagedByCurrentNodeCalled: func(pkBytes []byte) bool { + return string(pkBytes) != "dd" + }, + GetManagedKeysByCurrentNodeCalled: func() map[string]crypto.PrivateKey { + return map[string]crypto.PrivateKey{ + "aa": &mock.PrivateKeyStub{}, // in shard 1, handled by current node + "bb": &mock.PrivateKeyStub{}, // in shard 1, handled by current node + "cc": &mock.PrivateKeyStub{}, // in shard 1, handled by current node + "dd": &mock.PrivateKeyStub{}, // in shard 1, not handled by current node + "ee": &mock.PrivateKeyStub{}, // in shard 2, handled by current node + } + }, + GetP2PIdentityCalled: func(pkBytes []byte) ([]byte, core.PeerID, error) { + return []byte(string(pkBytes) + "_p2p"), core.PeerID(string(pkBytes) + "_pid"), nil + }, + GetMachineIDCalled: func(pkBytes []byte) (string, error) { + return string(pkBytes) + "_machineID", nil + }, + GetNameAndIdentityCalled: func(pkBytes []byte) (string, string, error) { + return string(pkBytes) + "_name", string(pkBytes) + "_identity", nil + }, + } + args.peerTypeProvider = &mock.PeerTypeProviderStub{ + ComputeForPubKeyCalled: func(pubKey []byte) (common.PeerType, uint32, error) { + if string(pubKey) == "ee" { + return "", 2, nil + } + return "", 1, nil + }, + } + args.shardCoordinator = createShardCoordinatorInShard(1) + + senderInstance, _ := newMultikeyHeartbeatSender(args) + + err := senderInstance.execute() + assert.Nil(t, err) + assert.Equal(t, 4, len(recordedMessages)) // current pid, aa, bb, cc + + checkRecordedMessages(t, + recordedMessages, + args, + args.versionNumber, + args.nodeDisplayName, + args.messenger.ID(), + core.FullHistoryObserver) + + checkRecordedMessages(t, + recordedMessages, + args, + args.baseVersionNumber+"/aa_machineID", + "aa_name", + "aa_pid", + core.RegularPeer) + + checkRecordedMessages(t, + recordedMessages, + args, + args.baseVersionNumber+"/bb_machineID", + "bb_name", + "bb_pid", + core.RegularPeer) + + checkRecordedMessages(t, + recordedMessages, + args, + args.baseVersionNumber+"/cc_machineID", + "cc_name", + "cc_pid", + core.RegularPeer) + + assert.Equal(t, uint64(1), args.currentBlockProvider.GetCurrentBlockHeader().GetNonce()) + }) +} + +func TestMultikeyHeartbeatSender_generateMessageBytes(t *testing.T) { + t.Parallel() + + t.Run("version too long should error", func(t *testing.T) { + t.Parallel() + + args := createMockMultikeyHeartbeatSenderArgs(createMockBaseArgs()) + senderInstance, _ := newMultikeyHeartbeatSender(args) + + versionNumber := strings.Repeat("a", maxSizeInBytes+1) + nodeDisplayName := "a" + identity := "b" + buff, err := senderInstance.generateMessageBytes(versionNumber, nodeDisplayName, identity, 0, []byte("public key"), 0) + + assert.True(t, errors.Is(err, heartbeat.ErrPropertyTooLong)) + assert.True(t, strings.Contains(err.Error(), "versionNumber")) + assert.Nil(t, buff) + }) + t.Run("node name too long should error", func(t *testing.T) { + t.Parallel() + + args := createMockMultikeyHeartbeatSenderArgs(createMockBaseArgs()) + senderInstance, _ := newMultikeyHeartbeatSender(args) + + versionNumber := "a" + nodeDisplayName := strings.Repeat("a", maxSizeInBytes+1) + identity := "b" + buff, err := 
senderInstance.generateMessageBytes(versionNumber, nodeDisplayName, identity, 0, []byte("public key"), 0) + + assert.True(t, errors.Is(err, heartbeat.ErrPropertyTooLong)) + assert.True(t, strings.Contains(err.Error(), "nodeDisplayName")) + assert.Nil(t, buff) + }) + t.Run("identity too long should error", func(t *testing.T) { + t.Parallel() + + args := createMockMultikeyHeartbeatSenderArgs(createMockBaseArgs()) + senderInstance, _ := newMultikeyHeartbeatSender(args) + + versionNumber := "a" + nodeDisplayName := "b" + identity := strings.Repeat("a", maxSizeInBytes+1) + buff, err := senderInstance.generateMessageBytes(versionNumber, nodeDisplayName, identity, 0, []byte("public key"), 0) + + assert.True(t, errors.Is(err, heartbeat.ErrPropertyTooLong)) + assert.True(t, strings.Contains(err.Error(), "identity")) + assert.Nil(t, buff) + }) +} + +func checkRecordedMessages( + tb testing.TB, + recordedMessages map[core.PeerID][][]byte, + args argMultikeyHeartbeatSender, + version string, + nodeDisplayName string, + pid core.PeerID, + peerSubType core.P2PPeerSubType, +) { + messages := recordedMessages[pid] + assert.True(tb, len(messages) > 0) + + for _, message := range messages { + checkRecordedMessage(tb, message, args, version, nodeDisplayName, peerSubType) + } +} + +func checkRecordedMessage( + tb testing.TB, + recordedMessage []byte, + args argMultikeyHeartbeatSender, + version string, + nodeDisplayName string, + peerSubType core.P2PPeerSubType, +) { + msg := &heartbeat.HeartbeatV2{} + err := args.marshaller.Unmarshal(msg, recordedMessage) + assert.Nil(tb, err) + + assert.Equal(tb, version, msg.VersionNumber) + assert.Equal(tb, nodeDisplayName, msg.NodeDisplayName) + assert.Equal(tb, uint32(peerSubType), msg.PeerSubType) +} diff --git a/heartbeat/sender/multikeyPeerAuthenticationSender.go b/heartbeat/sender/multikeyPeerAuthenticationSender.go new file mode 100644 index 00000000000..ac6d03b849b --- /dev/null +++ b/heartbeat/sender/multikeyPeerAuthenticationSender.go @@ -0,0 +1,225 @@ +package sender + +import ( + "encoding/hex" + "fmt" + "time" + + "github.com/multiversx/mx-chain-core-go/core/check" + crypto "github.com/multiversx/mx-chain-crypto-go" + "github.com/multiversx/mx-chain-go/heartbeat" +) + +const delayedBroadcast = 200 * time.Millisecond + +// argMultikeyPeerAuthenticationSender represents the arguments for the peer authentication sender +type argMultikeyPeerAuthenticationSender struct { + argBaseSender + nodesCoordinator heartbeat.NodesCoordinator + peerSignatureHandler crypto.PeerSignatureHandler + hardforkTrigger heartbeat.HardforkTrigger + hardforkTimeBetweenSends time.Duration + hardforkTriggerPubKey []byte + managedPeersHolder heartbeat.ManagedPeersHolder + timeBetweenChecks time.Duration + shardCoordinator heartbeat.ShardCoordinator +} + +type multikeyPeerAuthenticationSender struct { + commonPeerAuthenticationSender + hardforkTimeBetweenSends time.Duration + managedPeersHolder heartbeat.ManagedPeersHolder + timeBetweenChecks time.Duration + shardCoordinator heartbeat.ShardCoordinator + getCurrentTimeHandler func() time.Time +} + +// newMultikeyPeerAuthenticationSender will create a new instance of type multikeyPeerAuthenticationSender +func newMultikeyPeerAuthenticationSender(args argMultikeyPeerAuthenticationSender) (*multikeyPeerAuthenticationSender, error) { + err := checkMultikeyPeerAuthenticationSenderArgs(args) + if err != nil { + return nil, err + } + + senderInstance := &multikeyPeerAuthenticationSender{ + commonPeerAuthenticationSender: 
commonPeerAuthenticationSender{ + baseSender: createBaseSender(args.argBaseSender), + nodesCoordinator: args.nodesCoordinator, + peerSignatureHandler: args.peerSignatureHandler, + hardforkTrigger: args.hardforkTrigger, + hardforkTriggerPubKey: args.hardforkTriggerPubKey, + }, + hardforkTimeBetweenSends: args.hardforkTimeBetweenSends, + managedPeersHolder: args.managedPeersHolder, + timeBetweenChecks: args.timeBetweenChecks, + shardCoordinator: args.shardCoordinator, + getCurrentTimeHandler: getCurrentTime, + } + + return senderInstance, nil +} + +func getCurrentTime() time.Time { + return time.Now() +} + +func checkMultikeyPeerAuthenticationSenderArgs(args argMultikeyPeerAuthenticationSender) error { + err := checkBaseSenderArgs(args.argBaseSender) + if err != nil { + return err + } + if check.IfNil(args.nodesCoordinator) { + return heartbeat.ErrNilNodesCoordinator + } + if check.IfNil(args.peerSignatureHandler) { + return heartbeat.ErrNilPeerSignatureHandler + } + if check.IfNil(args.hardforkTrigger) { + return heartbeat.ErrNilHardforkTrigger + } + if args.hardforkTimeBetweenSends < minTimeBetweenSends { + return fmt.Errorf("%w for hardforkTimeBetweenSends", heartbeat.ErrInvalidTimeDuration) + } + if len(args.hardforkTriggerPubKey) == 0 { + return fmt.Errorf("%w hardfork trigger public key bytes length is 0", heartbeat.ErrInvalidValue) + } + if check.IfNil(args.managedPeersHolder) { + return heartbeat.ErrNilManagedPeersHolder + } + if args.timeBetweenChecks < minTimeBetweenSends { + return fmt.Errorf("%w for timeBetweenChecks", heartbeat.ErrInvalidTimeDuration) + } + if check.IfNil(args.shardCoordinator) { + return heartbeat.ErrNilShardCoordinator + } + + return nil +} + +// Execute will handle the execution of a cycle in which the peer authentication message will be sent +func (sender *multikeyPeerAuthenticationSender) Execute() { + currentTimeAsUnix := sender.getCurrentTimeHandler().Unix() + managedKeys := sender.managedPeersHolder.GetManagedKeysByCurrentNode() + for pk, sk := range managedKeys { + err := sender.process(pk, sk, currentTimeAsUnix) + if err != nil { + nextTimeToCheck, errNextPeerAuth := sender.managedPeersHolder.GetNextPeerAuthenticationTime([]byte(pk)) + if errNextPeerAuth != nil { + log.Error("could not get next peer authentication time for pk", "pk", pk, "process error", err, "GetNextPeerAuthenticationTime error", errNextPeerAuth) + return + } + + log.Error("error sending peer authentication message", "bls pk", pk, + "next send is scheduled on", nextTimeToCheck, "error", err) + } + } + + sender.CreateNewTimer(sender.timeBetweenChecks) +} + +func (sender *multikeyPeerAuthenticationSender) process(pk string, sk crypto.PrivateKey, currentTimeAsUnix int64) error { + pkBytes := []byte(pk) + if !sender.processIfShouldSend(pkBytes, currentTimeAsUnix) { + return nil + } + + currentTimeStamp := time.Unix(currentTimeAsUnix, 0) + + data, isHardforkTriggered, _, err := sender.prepareMessage([]byte(pk), sk) + if err != nil { + sender.managedPeersHolder.SetNextPeerAuthenticationTime(pkBytes, currentTimeStamp.Add(sender.timeBetweenSendsWhenError)) + return err + } + if isHardforkTriggered { + nextTimeStamp := currentTimeStamp.Add(sender.computeRandomDuration(sender.hardforkTimeBetweenSends)) + sender.managedPeersHolder.SetNextPeerAuthenticationTime(pkBytes, nextTimeStamp) + } else { + nextTimeStamp := currentTimeStamp.Add(sender.computeRandomDuration(sender.timeBetweenSends)) + sender.managedPeersHolder.SetNextPeerAuthenticationTime(pkBytes, nextTimeStamp) + 
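		// a minimal sketch of the scheduling done just above, assuming computeRandomDuration (from
		// baseSender) jitters the base interval by thresholdBetweenSends so that many managed keys
		// do not all resend at the same moment:
		//   nextSend (regular)  = currentTimeStamp + jitter(timeBetweenSends)
		//   nextSend (hardfork) = currentTimeStamp + jitter(hardforkTimeBetweenSends)
		// the call below then marks the key as validator after a successful regular send, so the
		// "just became validator" fast path in processIfShouldSend does not fire another immediate send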
sender.managedPeersHolder.SetValidatorState(pkBytes, true) + } + + sender.sendData(pkBytes, data, isHardforkTriggered) + + return nil +} + +func (sender *multikeyPeerAuthenticationSender) processIfShouldSend(pkBytes []byte, currentTimeAsUnix int64) bool { + if !sender.managedPeersHolder.IsKeyManagedByCurrentNode(pkBytes) { + return false + } + isValidatorNow, shardID := sender.getIsValidatorStatusAndShardID(pkBytes) + isHardforkSource := sender.isHardforkSource(pkBytes) + oldIsValidator := sender.managedPeersHolder.IsKeyValidator(pkBytes) + sender.managedPeersHolder.SetValidatorState(pkBytes, isValidatorNow) + + if !isValidatorNow && !isHardforkSource { + return false + } + if shardID != sender.shardCoordinator.SelfId() { + return false + } + + nextTimeToCheck, err := sender.managedPeersHolder.GetNextPeerAuthenticationTime(pkBytes) + if err != nil { + return false + } + + timeToCheck := nextTimeToCheck.Unix() < currentTimeAsUnix + if timeToCheck { + return true + } + if !oldIsValidator && isValidatorNow { + return true + } + + return false +} + +func (sender *multikeyPeerAuthenticationSender) prepareMessage(pkBytes []byte, privateKey crypto.PrivateKey) ([]byte, bool, int64, error) { + p2pSkBytes, pid, err := sender.managedPeersHolder.GetP2PIdentity(pkBytes) + if err != nil { + return nil, false, 0, err + } + + return sender.generateMessageBytes(pkBytes, privateKey, p2pSkBytes, pid.Bytes()) +} + +func (sender *multikeyPeerAuthenticationSender) sendData(pkBytes []byte, data []byte, isHardforkTriggered bool) { + // extra delay as to avoid sending a lot of messages in the same time + time.Sleep(delayedBroadcast) + + p2pSk, pid, err := sender.managedPeersHolder.GetP2PIdentity(pkBytes) + if err != nil { + log.Error("could not get identity for pk", "pk", hex.EncodeToString(pkBytes), "error", err) + return + } + sender.messenger.BroadcastUsingPrivateKey(sender.topic, data, pid, p2pSk) + + nextTimeToCheck, err := sender.managedPeersHolder.GetNextPeerAuthenticationTime(pkBytes) + if err != nil { + log.Error("could not get next peer authentication time for pk", "pk", hex.EncodeToString(pkBytes), "error", err) + return + } + + log.Debug("peer authentication message sent", + "bls pk", pkBytes, + "pid", pid.Pretty(), + "is hardfork triggered", isHardforkTriggered, + "next send is scheduled on", nextTimeToCheck) +} + +// ShouldTriggerHardfork signals when hardfork message should be sent +func (sender *multikeyPeerAuthenticationSender) ShouldTriggerHardfork() <-chan struct{} { + return sender.hardforkTrigger.NotifyTriggerReceivedV2() +} + +func (sender *multikeyPeerAuthenticationSender) getIsValidatorStatusAndShardID(pkBytes []byte) (bool, uint32) { + _, shardID, err := sender.nodesCoordinator.GetValidatorWithPublicKey(pkBytes) + return err == nil, shardID +} + +// IsInterfaceNil returns true if there is no value under the interface +func (sender *multikeyPeerAuthenticationSender) IsInterfaceNil() bool { + return sender == nil +} diff --git a/heartbeat/sender/multikeyPeerAuthenticationSender_test.go b/heartbeat/sender/multikeyPeerAuthenticationSender_test.go new file mode 100644 index 00000000000..37107f3b6e3 --- /dev/null +++ b/heartbeat/sender/multikeyPeerAuthenticationSender_test.go @@ -0,0 +1,709 @@ +package sender + +import ( + "errors" + "strings" + "sync" + "testing" + "time" + + "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-core-go/data/batch" + crypto "github.com/multiversx/mx-chain-crypto-go" + 
"github.com/multiversx/mx-chain-crypto-go/signing" + "github.com/multiversx/mx-chain-crypto-go/signing/ed25519" + ed25519SingleSig "github.com/multiversx/mx-chain-crypto-go/signing/ed25519/singlesig" + "github.com/multiversx/mx-chain-crypto-go/signing/mcl" + "github.com/multiversx/mx-chain-crypto-go/signing/mcl/singlesig" + "github.com/multiversx/mx-chain-go/heartbeat" + "github.com/multiversx/mx-chain-go/heartbeat/mock" + "github.com/multiversx/mx-chain-go/sharding/nodesCoordinator" + "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cryptoMocks" + "github.com/multiversx/mx-chain-go/testscommon/p2pmocks" + "github.com/multiversx/mx-chain-go/testscommon/shardingMocks" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func createShardCoordinatorInShard(shardID uint32) *testscommon.ShardsCoordinatorMock { + shardCoordinator := testscommon.NewMultiShardsCoordinatorMock(3) + shardCoordinator.CurrentShard = shardID + + return shardCoordinator +} + +func createMockMultikeyPeerAuthenticationSenderArgs(argBase argBaseSender) argMultikeyPeerAuthenticationSender { + return argMultikeyPeerAuthenticationSender{ + argBaseSender: argBase, + nodesCoordinator: &shardingMocks.NodesCoordinatorStub{}, + peerSignatureHandler: &cryptoMocks.PeerSignatureHandlerStub{}, + hardforkTrigger: &testscommon.HardforkTriggerStub{}, + hardforkTimeBetweenSends: time.Second, + hardforkTriggerPubKey: providedHardforkPubKey, + timeBetweenChecks: time.Second, + managedPeersHolder: &testscommon.ManagedPeersHolderStub{ + IsKeyManagedByCurrentNodeCalled: func(pkBytes []byte) bool { + return true + }, + }, + shardCoordinator: createShardCoordinatorInShard(0), + } +} + +func createMockMultikeyPeerAuthenticationSenderArgsSemiIntegrationTests( + numKeys int, +) (argMultikeyPeerAuthenticationSender, *p2pmocks.MessengerStub) { + keyGenForBLS := signing.NewKeyGenerator(mcl.NewSuiteBLS12()) + keyGenForP2P := signing.NewKeyGenerator(ed25519.NewEd25519()) + signerMessenger := ed25519SingleSig.Ed25519Signer{} + + keyMap := make(map[string]crypto.PrivateKey) + for i := 0; i < numKeys; i++ { + sk, pk := keyGenForBLS.GeneratePair() + pkBytes, _ := pk.ToByteArray() + + keyMap[string(pkBytes)] = sk + } + + p2pSkPkMap := make(map[string][]byte) + peerIdPkMap := make(map[string]core.PeerID) + for pk := range keyMap { + p2pSk, p2pPk := keyGenForP2P.GeneratePair() + p2pSkBytes, _ := p2pSk.ToByteArray() + p2pPkBytes, _ := p2pPk.ToByteArray() + + p2pSkPkMap[pk] = p2pSkBytes + peerIdPkMap[pk] = core.PeerID(p2pPkBytes) + } + + mutTimeMap := sync.RWMutex{} + peerAuthTimeMap := make(map[string]time.Time) + managedPeersHolder := &testscommon.ManagedPeersHolderStub{ + GetP2PIdentityCalled: func(pkBytes []byte) ([]byte, core.PeerID, error) { + return p2pSkPkMap[string(pkBytes)], peerIdPkMap[string(pkBytes)], nil + }, + GetManagedKeysByCurrentNodeCalled: func() map[string]crypto.PrivateKey { + return keyMap + }, + IsKeyManagedByCurrentNodeCalled: func(pkBytes []byte) bool { + return true + }, + IsKeyValidatorCalled: func(pkBytes []byte) bool { + return true + }, + GetNextPeerAuthenticationTimeCalled: func(pkBytes []byte) (time.Time, error) { + mutTimeMap.RLock() + defer mutTimeMap.RUnlock() + return peerAuthTimeMap[string(pkBytes)], nil + }, + SetNextPeerAuthenticationTimeCalled: func(pkBytes []byte, nextTime time.Time) { + mutTimeMap.Lock() + defer mutTimeMap.Unlock() + peerAuthTimeMap[string(pkBytes)] = nextTime + }, + } + + singleSigner := singlesig.NewBlsSigner() + + 
baseArgs := createMockBaseArgs() + args := argMultikeyPeerAuthenticationSender{ + argBaseSender: baseArgs, + nodesCoordinator: &shardingMocks.NodesCoordinatorStub{}, + peerSignatureHandler: &mock.PeerSignatureHandlerStub{ + VerifyPeerSignatureCalled: func(pk []byte, pid core.PeerID, signature []byte) error { + senderPubKey, err := keyGenForBLS.PublicKeyFromByteArray(pk) + if err != nil { + return err + } + return singleSigner.Verify(senderPubKey, pid.Bytes(), signature) + }, + GetPeerSignatureCalled: func(privateKey crypto.PrivateKey, pid []byte) ([]byte, error) { + return singleSigner.Sign(privateKey, pid) + }, + }, + hardforkTrigger: &testscommon.HardforkTriggerStub{}, + hardforkTimeBetweenSends: time.Second, + hardforkTriggerPubKey: providedHardforkPubKey, + timeBetweenChecks: time.Second, + managedPeersHolder: managedPeersHolder, + shardCoordinator: createShardCoordinatorInShard(0), + } + + messenger := &p2pmocks.MessengerStub{ + SignUsingPrivateKeyCalled: func(skBytes []byte, payload []byte) ([]byte, error) { + p2pSk, _ := keyGenForP2P.PrivateKeyFromByteArray(skBytes) + + return signerMessenger.Sign(p2pSk, payload) + }, + VerifyCalled: func(payload []byte, pid core.PeerID, signature []byte) error { + pk, _ := keyGenForP2P.PublicKeyFromByteArray(pid.Bytes()) + + return signerMessenger.Verify(pk, payload, signature) + }, + } + + args.messenger = messenger + + return args, messenger +} + +func TestNewMultikeyPeerAuthenticationSender(t *testing.T) { + t.Parallel() + + t.Run("nil peer messenger should error", func(t *testing.T) { + t.Parallel() + + argsBase := createMockBaseArgs() + argsBase.messenger = nil + + args := createMockMultikeyPeerAuthenticationSenderArgs(argsBase) + senderInstance, err := newMultikeyPeerAuthenticationSender(args) + + assert.True(t, check.IfNil(senderInstance)) + assert.Equal(t, heartbeat.ErrNilMessenger, err) + }) + t.Run("nil marshaller should error", func(t *testing.T) { + t.Parallel() + + argsBase := createMockBaseArgs() + argsBase.marshaller = nil + + args := createMockMultikeyPeerAuthenticationSenderArgs(argsBase) + senderInstance, err := newMultikeyPeerAuthenticationSender(args) + + assert.True(t, check.IfNil(senderInstance)) + assert.Equal(t, heartbeat.ErrNilMarshaller, err) + }) + t.Run("empty topic should error", func(t *testing.T) { + t.Parallel() + + argsBase := createMockBaseArgs() + argsBase.topic = "" + + args := createMockMultikeyPeerAuthenticationSenderArgs(argsBase) + senderInstance, err := newMultikeyPeerAuthenticationSender(args) + + assert.True(t, check.IfNil(senderInstance)) + assert.Equal(t, heartbeat.ErrEmptySendTopic, err) + }) + t.Run("invalid time between sends should error", func(t *testing.T) { + t.Parallel() + + argsBase := createMockBaseArgs() + argsBase.timeBetweenSends = time.Second - time.Nanosecond + + args := createMockMultikeyPeerAuthenticationSenderArgs(argsBase) + senderInstance, err := newMultikeyPeerAuthenticationSender(args) + + assert.True(t, check.IfNil(senderInstance)) + assert.True(t, errors.Is(err, heartbeat.ErrInvalidTimeDuration)) + assert.True(t, strings.Contains(err.Error(), "timeBetweenSends")) + assert.False(t, strings.Contains(err.Error(), "timeBetweenSendsWhenError")) + }) + t.Run("invalid time between sends should error", func(t *testing.T) { + t.Parallel() + + argsBase := createMockBaseArgs() + argsBase.timeBetweenSendsWhenError = time.Second - time.Nanosecond + + args := createMockMultikeyPeerAuthenticationSenderArgs(argsBase) + senderInstance, err := newMultikeyPeerAuthenticationSender(args) + + 
assert.True(t, check.IfNil(senderInstance)) + assert.True(t, errors.Is(err, heartbeat.ErrInvalidTimeDuration)) + assert.True(t, strings.Contains(err.Error(), "timeBetweenSendsWhenError")) + }) + t.Run("threshold too small should error", func(t *testing.T) { + t.Parallel() + + args := createMockMultikeyPeerAuthenticationSenderArgs(createMockBaseArgs()) + args.thresholdBetweenSends = 0.001 + senderInstance, err := newMultikeyPeerAuthenticationSender(args) + + assert.Nil(t, senderInstance) + assert.True(t, errors.Is(err, heartbeat.ErrInvalidThreshold)) + assert.True(t, strings.Contains(err.Error(), "thresholdBetweenSends")) + }) + t.Run("threshold too big should error", func(t *testing.T) { + t.Parallel() + + args := createMockMultikeyPeerAuthenticationSenderArgs(createMockBaseArgs()) + args.thresholdBetweenSends = 1.001 + senderInstance, err := newMultikeyPeerAuthenticationSender(args) + + assert.Nil(t, senderInstance) + assert.True(t, errors.Is(err, heartbeat.ErrInvalidThreshold)) + assert.True(t, strings.Contains(err.Error(), "thresholdBetweenSends")) + }) + t.Run("nil nodes coordinator should error", func(t *testing.T) { + t.Parallel() + + args := createMockMultikeyPeerAuthenticationSenderArgs(createMockBaseArgs()) + args.nodesCoordinator = nil + senderInstance, err := newMultikeyPeerAuthenticationSender(args) + + assert.True(t, check.IfNil(senderInstance)) + assert.Equal(t, heartbeat.ErrNilNodesCoordinator, err) + }) + t.Run("nil peer signature handler should error", func(t *testing.T) { + t.Parallel() + + args := createMockMultikeyPeerAuthenticationSenderArgs(createMockBaseArgs()) + args.peerSignatureHandler = nil + senderInstance, err := newMultikeyPeerAuthenticationSender(args) + + assert.True(t, check.IfNil(senderInstance)) + assert.Equal(t, heartbeat.ErrNilPeerSignatureHandler, err) + }) + t.Run("nil hardfork trigger should error", func(t *testing.T) { + t.Parallel() + + args := createMockMultikeyPeerAuthenticationSenderArgs(createMockBaseArgs()) + args.hardforkTrigger = nil + senderInstance, err := newMultikeyPeerAuthenticationSender(args) + + assert.True(t, check.IfNil(senderInstance)) + assert.Equal(t, heartbeat.ErrNilHardforkTrigger, err) + }) + t.Run("invalid time between hardforks should error", func(t *testing.T) { + t.Parallel() + + args := createMockMultikeyPeerAuthenticationSenderArgs(createMockBaseArgs()) + args.hardforkTimeBetweenSends = time.Second - time.Nanosecond + senderInstance, err := newMultikeyPeerAuthenticationSender(args) + + assert.True(t, check.IfNil(senderInstance)) + assert.True(t, errors.Is(err, heartbeat.ErrInvalidTimeDuration)) + assert.True(t, strings.Contains(err.Error(), "hardforkTimeBetweenSends")) + }) + t.Run("nil managed peers holder should error", func(t *testing.T) { + t.Parallel() + + args := createMockMultikeyPeerAuthenticationSenderArgs(createMockBaseArgs()) + args.managedPeersHolder = nil + senderInstance, err := newMultikeyPeerAuthenticationSender(args) + + assert.True(t, check.IfNil(senderInstance)) + assert.Equal(t, heartbeat.ErrNilManagedPeersHolder, err) + }) + t.Run("invalid time between checks should error", func(t *testing.T) { + t.Parallel() + + args := createMockMultikeyPeerAuthenticationSenderArgs(createMockBaseArgs()) + args.timeBetweenChecks = time.Second - time.Nanosecond + senderInstance, err := newMultikeyPeerAuthenticationSender(args) + + assert.True(t, check.IfNil(senderInstance)) + assert.True(t, errors.Is(err, heartbeat.ErrInvalidTimeDuration)) + assert.True(t, strings.Contains(err.Error(), "timeBetweenChecks")) + }) + 
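	// the shard coordinator is a hard dependency: processIfShouldSend skips every managed key whose
	// computed shard differs from shardCoordinator.SelfId(), so a nil coordinator has to be rejected here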
t.Run("nil shard coordinator should error", func(t *testing.T) { + t.Parallel() + + args := createMockMultikeyPeerAuthenticationSenderArgs(createMockBaseArgs()) + args.shardCoordinator = nil + senderInstance, err := newMultikeyPeerAuthenticationSender(args) + + assert.True(t, check.IfNil(senderInstance)) + assert.Equal(t, heartbeat.ErrNilShardCoordinator, err) + }) + t.Run("should work", func(t *testing.T) { + t.Parallel() + + args := createMockMultikeyPeerAuthenticationSenderArgs(createMockBaseArgs()) + senderInstance, err := newMultikeyPeerAuthenticationSender(args) + + assert.False(t, check.IfNil(senderInstance)) + assert.Nil(t, err) + }) +} + +func TestNewMultikeyPeerAuthenticationSender_Execute(t *testing.T) { + t.Parallel() + + t.Run("should work for the first time with some real components", func(t *testing.T) { + t.Parallel() + + numKeys := 3 + mutData := sync.Mutex{} + var buffResulted [][]byte + var pids []core.PeerID + var skBytesBroadcast [][]byte + + args, messenger := createMockMultikeyPeerAuthenticationSenderArgsSemiIntegrationTests(numKeys) + messenger.BroadcastUsingPrivateKeyCalled = func(topic string, buff []byte, pid core.PeerID, skBytes []byte) { + assert.Equal(t, args.topic, topic) + + mutData.Lock() + buffResulted = append(buffResulted, buff) + pids = append(pids, pid) + skBytesBroadcast = append(skBytesBroadcast, skBytes) + mutData.Unlock() + } + + senderInstance, _ := newMultikeyPeerAuthenticationSender(args) + senderInstance.Execute() + + mutData.Lock() + testRecoveredMessages(t, args, buffResulted, pids, skBytesBroadcast, numKeys, "") + mutData.Unlock() + }) + t.Run("should work with some real components", func(t *testing.T) { + t.Parallel() + + numKeys := 3 + mutData := sync.Mutex{} + var buffResulted [][]byte + var pids []core.PeerID + var skBytesBroadcast [][]byte + + args, messenger := createMockMultikeyPeerAuthenticationSenderArgsSemiIntegrationTests(numKeys) + messenger.BroadcastUsingPrivateKeyCalled = func(topic string, buff []byte, pid core.PeerID, skBytes []byte) { + assert.Equal(t, args.topic, topic) + + mutData.Lock() + buffResulted = append(buffResulted, buff) + pids = append(pids, pid) + skBytesBroadcast = append(skBytesBroadcast, skBytes) + mutData.Unlock() + } + args.timeBetweenSends = time.Second * 3 + args.thresholdBetweenSends = 0.20 + + senderInstance, _ := newMultikeyPeerAuthenticationSender(args) + senderInstance.Execute() + + // reset data from initialization + mutData.Lock() + buffResulted = make([][]byte, 0) + pids = make([]core.PeerID, 0) + skBytesBroadcast = make([][]byte, 0) + mutData.Unlock() + + senderInstance.Execute() // this will not add messages + + mutData.Lock() + testRecoveredMessages(t, args, buffResulted, pids, skBytesBroadcast, 0, "") + mutData.Unlock() + + time.Sleep(time.Second * 5) // allow the resending of the messages + senderInstance.Execute() + + mutData.Lock() + testRecoveredMessages(t, args, buffResulted, pids, skBytesBroadcast, numKeys, "") + mutData.Unlock() + }) + t.Run("should work with some real components and one key not handled by the current node", func(t *testing.T) { + t.Parallel() + + numKeys := 3 + mutData := sync.Mutex{} + var buffResulted [][]byte + var pids []core.PeerID + var skBytesBroadcast [][]byte + + args, messenger := createMockMultikeyPeerAuthenticationSenderArgsSemiIntegrationTests(numKeys) + messenger.BroadcastUsingPrivateKeyCalled = func(topic string, buff []byte, pid core.PeerID, skBytes []byte) { + assert.Equal(t, args.topic, topic) + + mutData.Lock() + buffResulted = append(buffResulted, 
buff) + pids = append(pids, pid) + skBytesBroadcast = append(skBytesBroadcast, skBytes) + mutData.Unlock() + } + args.timeBetweenSends = time.Second * 3 + args.thresholdBetweenSends = 0.20 + firstKeyFound := "" + pkSkMap := args.managedPeersHolder.GetManagedKeysByCurrentNode() + for key := range pkSkMap { + firstKeyFound = key + break + } + stub := args.managedPeersHolder.(*testscommon.ManagedPeersHolderStub) + stub.IsKeyManagedByCurrentNodeCalled = func(pkBytes []byte) bool { + return firstKeyFound != string(pkBytes) + } + + senderInstance, _ := newMultikeyPeerAuthenticationSender(args) + senderInstance.Execute() + + mutData.Lock() + testRecoveredMessages(t, args, buffResulted, pids, skBytesBroadcast, numKeys-1, "") + mutData.Unlock() + + // reset data from initialization + mutData.Lock() + buffResulted = make([][]byte, 0) + pids = make([]core.PeerID, 0) + skBytesBroadcast = make([][]byte, 0) + mutData.Unlock() + + senderInstance.Execute() // this will not add messages + + mutData.Lock() + testRecoveredMessages(t, args, buffResulted, pids, skBytesBroadcast, 0, "") + mutData.Unlock() + + time.Sleep(time.Second * 5) // allow the resending of the messages + senderInstance.Execute() + + mutData.Lock() + testRecoveredMessages(t, args, buffResulted, pids, skBytesBroadcast, numKeys-1, "") + mutData.Unlock() + }) + t.Run("should work with some real components one key not a validator", func(t *testing.T) { + t.Parallel() + + numKeys := 3 + mutData := sync.Mutex{} + var buffResulted [][]byte + var pids []core.PeerID + var skBytesBroadcast [][]byte + + args, messenger := createMockMultikeyPeerAuthenticationSenderArgsSemiIntegrationTests(numKeys) + messenger.BroadcastUsingPrivateKeyCalled = func(topic string, buff []byte, pid core.PeerID, skBytes []byte) { + assert.Equal(t, args.topic, topic) + + mutData.Lock() + buffResulted = append(buffResulted, buff) + pids = append(pids, pid) + skBytesBroadcast = append(skBytesBroadcast, skBytes) + mutData.Unlock() + } + args.timeBetweenSends = time.Second * 3 + args.thresholdBetweenSends = 0.20 + firstKeyFound := "" + pkSkMap := args.managedPeersHolder.GetManagedKeysByCurrentNode() + for key := range pkSkMap { + firstKeyFound = key + break + } + args.nodesCoordinator = &shardingMocks.NodesCoordinatorStub{ + GetValidatorWithPublicKeyCalled: func(publicKey []byte) (validator nodesCoordinator.Validator, shardId uint32, err error) { + if firstKeyFound == string(publicKey) { + return nil, 0, errors.New("not a validator") + } + + return nil, 0, nil + }, + } + senderInstance, _ := newMultikeyPeerAuthenticationSender(args) + senderInstance.Execute() + + mutData.Lock() + testRecoveredMessages(t, args, buffResulted, pids, skBytesBroadcast, numKeys-1, "") + mutData.Unlock() + + // reset data from initialization + mutData.Lock() + buffResulted = make([][]byte, 0) + pids = make([]core.PeerID, 0) + skBytesBroadcast = make([][]byte, 0) + mutData.Unlock() + + senderInstance.Execute() // this will not add messages + + mutData.Lock() + testRecoveredMessages(t, args, buffResulted, pids, skBytesBroadcast, 0, "") + mutData.Unlock() + + time.Sleep(time.Second * 5) // allow the resending of the messages + senderInstance.Execute() + + mutData.Lock() + testRecoveredMessages(t, args, buffResulted, pids, skBytesBroadcast, numKeys-1, "") + mutData.Unlock() + }) + t.Run("should work with some real components one key is on a different shard", func(t *testing.T) { + t.Parallel() + + numKeys := 3 + mutData := sync.Mutex{} + var buffResulted [][]byte + var pids []core.PeerID + var 
skBytesBroadcast [][]byte + + args, messenger := createMockMultikeyPeerAuthenticationSenderArgsSemiIntegrationTests(numKeys) + messenger.BroadcastUsingPrivateKeyCalled = func(topic string, buff []byte, pid core.PeerID, skBytes []byte) { + assert.Equal(t, args.topic, topic) + + mutData.Lock() + buffResulted = append(buffResulted, buff) + pids = append(pids, pid) + skBytesBroadcast = append(skBytesBroadcast, skBytes) + mutData.Unlock() + } + args.timeBetweenSends = time.Second * 3 + args.thresholdBetweenSends = 0.20 + firstKeyFound := "" + pkSkMap := args.managedPeersHolder.GetManagedKeysByCurrentNode() + for key := range pkSkMap { + firstKeyFound = key + break + } + args.nodesCoordinator = &shardingMocks.NodesCoordinatorStub{ + GetValidatorWithPublicKeyCalled: func(publicKey []byte) (validator nodesCoordinator.Validator, shardId uint32, err error) { + if firstKeyFound == string(publicKey) { + return nil, 1, nil + } + + return nil, 0, nil + }, + } + senderInstance, _ := newMultikeyPeerAuthenticationSender(args) + senderInstance.Execute() + + mutData.Lock() + testRecoveredMessages(t, args, buffResulted, pids, skBytesBroadcast, numKeys-1, "") + mutData.Unlock() + + // reset data from initialization + mutData.Lock() + buffResulted = make([][]byte, 0) + pids = make([]core.PeerID, 0) + skBytesBroadcast = make([][]byte, 0) + mutData.Unlock() + + senderInstance.Execute() // this will not add messages + + mutData.Lock() + testRecoveredMessages(t, args, buffResulted, pids, skBytesBroadcast, 0, "") + mutData.Unlock() + + time.Sleep(time.Second * 5) // allow the resending of the messages + senderInstance.Execute() + + mutData.Lock() + testRecoveredMessages(t, args, buffResulted, pids, skBytesBroadcast, numKeys-1, "") + mutData.Unlock() + }) + t.Run("should work with some real components and hardfork trigger", func(t *testing.T) { + t.Parallel() + + numKeys := 3 + mutData := sync.Mutex{} + var buffResulted [][]byte + var pids []core.PeerID + var skBytesBroadcast [][]byte + + args, messenger := createMockMultikeyPeerAuthenticationSenderArgsSemiIntegrationTests(numKeys) + messenger.BroadcastUsingPrivateKeyCalled = func(topic string, buff []byte, pid core.PeerID, skBytes []byte) { + assert.Equal(t, args.topic, topic) + + mutData.Lock() + buffResulted = append(buffResulted, buff) + pids = append(pids, pid) + skBytesBroadcast = append(skBytesBroadcast, skBytes) + mutData.Unlock() + } + args.timeBetweenSends = time.Second * 3 + args.thresholdBetweenSends = 0.20 + hardforkTriggerPayload := []byte("hardfork payload") + args.hardforkTrigger = &testscommon.HardforkTriggerStub{ + RecordedTriggerMessageCalled: func() ([]byte, bool) { + return make([]byte, 0), true + }, + CreateDataCalled: func() []byte { + return hardforkTriggerPayload + }, + } + + senderInstance, _ := newMultikeyPeerAuthenticationSender(args) + senderInstance.Execute() + + mutData.Lock() + testRecoveredMessages(t, args, buffResulted, pids, skBytesBroadcast, numKeys, string(hardforkTriggerPayload)) + mutData.Unlock() + + // reset data from initialization + mutData.Lock() + buffResulted = make([][]byte, 0) + pids = make([]core.PeerID, 0) + skBytesBroadcast = make([][]byte, 0) + mutData.Unlock() + + time.Sleep(time.Second * 2) + senderInstance.Execute() // this will add messages because we are in hardfork mode + + mutData.Lock() + testRecoveredMessages(t, args, buffResulted, pids, skBytesBroadcast, numKeys, string(hardforkTriggerPayload)) + mutData.Unlock() + + // reset data + mutData.Lock() + buffResulted = make([][]byte, 0) + pids = 
make([]core.PeerID, 0) + skBytesBroadcast = make([][]byte, 0) + mutData.Unlock() + + time.Sleep(time.Second * 5) // allow the resending of the messages + senderInstance.Execute() + + mutData.Lock() + testRecoveredMessages(t, args, buffResulted, pids, skBytesBroadcast, numKeys, string(hardforkTriggerPayload)) + mutData.Unlock() + }) +} + +func testRecoveredMessages( + tb testing.TB, + args argMultikeyPeerAuthenticationSender, + payloads [][]byte, + pids []core.PeerID, + skBytesBroadcast [][]byte, + numExpected int, + + hardforkPayload string, +) { + require.Equal(tb, numExpected, len(payloads)) + require.Equal(tb, numExpected, len(pids)) + require.Equal(tb, numExpected, len(skBytesBroadcast)) + + for i := 0; i < len(payloads); i++ { + payload := payloads[i] + pid := pids[i] + + testSingleMessage(tb, args, payload, pid, hardforkPayload) + } +} + +func testSingleMessage( + tb testing.TB, + args argMultikeyPeerAuthenticationSender, + payload []byte, + pid core.PeerID, + hardforkPayload string, +) { + recoveredBatch := batch.Batch{} + err := args.marshaller.Unmarshal(&recoveredBatch, payload) + assert.Nil(tb, err) + + recoveredMessage := &heartbeat.PeerAuthentication{} + err = args.marshaller.Unmarshal(recoveredMessage, recoveredBatch.Data[0]) + assert.Nil(tb, err) + + _, correspondingPid, err := args.managedPeersHolder.GetP2PIdentity(recoveredMessage.Pubkey) + assert.Nil(tb, err) + assert.Equal(tb, correspondingPid.Pretty(), core.PeerID(recoveredMessage.Pid).Pretty()) + assert.Equal(tb, correspondingPid, pid) + + errVerify := args.peerSignatureHandler.VerifyPeerSignature(recoveredMessage.Pubkey, core.PeerID(recoveredMessage.Pid), recoveredMessage.Signature) + assert.Nil(tb, errVerify) + + messenger := args.messenger.(*p2pmocks.MessengerStub) + errVerify = messenger.Verify(recoveredMessage.Payload, core.PeerID(recoveredMessage.Pid), recoveredMessage.PayloadSignature) + assert.Nil(tb, errVerify) + + recoveredPayload := &heartbeat.Payload{} + err = args.marshaller.Unmarshal(recoveredPayload, recoveredMessage.Payload) + assert.Nil(tb, err) + + endTime := time.Now() + + messageTime := time.Unix(recoveredPayload.Timestamp, 0) + assert.True(tb, messageTime.Unix() <= endTime.Unix()) + assert.Equal(tb, hardforkPayload, recoveredPayload.HardforkMessage) +} diff --git a/heartbeat/sender/peerAuthenticationSender.go b/heartbeat/sender/peerAuthenticationSender.go index cf67e73d0a4..6151177c8af 100644 --- a/heartbeat/sender/peerAuthenticationSender.go +++ b/heartbeat/sender/peerAuthenticationSender.go @@ -1,12 +1,10 @@ package sender import ( - "bytes" "fmt" "time" "github.com/multiversx/mx-chain-core-go/core/check" - "github.com/multiversx/mx-chain-core-go/data/batch" crypto "github.com/multiversx/mx-chain-crypto-go" "github.com/multiversx/mx-chain-go/heartbeat" ) @@ -22,12 +20,12 @@ type argPeerAuthenticationSender struct { } type peerAuthenticationSender struct { - baseSender - nodesCoordinator heartbeat.NodesCoordinator - peerSignatureHandler crypto.PeerSignatureHandler - hardforkTrigger heartbeat.HardforkTrigger + commonPeerAuthenticationSender + redundancy heartbeat.NodeRedundancyHandler + privKey crypto.PrivateKey + publicKey crypto.PublicKey + observerPublicKey crypto.PublicKey hardforkTimeBetweenSends time.Duration - hardforkTriggerPubKey []byte } // newPeerAuthenticationSender will create a new instance of type peerAuthenticationSender @@ -37,13 +35,20 @@ func newPeerAuthenticationSender(args argPeerAuthenticationSender) (*peerAuthent return nil, err } + redundancyHandler := args.redundancyHandler 
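+	// note: both the main public key and the observer (redundancy) public key are derived once here,
+	// via GeneratePublic(), so getCurrentPrivateAndPublicKeys can switch between the two key pairs on
+	// every send without re-deriving them from the private keys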
senderInstance := &peerAuthenticationSender{ - baseSender: createBaseSender(args.argBaseSender), - nodesCoordinator: args.nodesCoordinator, - peerSignatureHandler: args.peerSignatureHandler, - hardforkTrigger: args.hardforkTrigger, + commonPeerAuthenticationSender: commonPeerAuthenticationSender{ + baseSender: createBaseSender(args.argBaseSender), + nodesCoordinator: args.nodesCoordinator, + peerSignatureHandler: args.peerSignatureHandler, + hardforkTrigger: args.hardforkTrigger, + hardforkTriggerPubKey: args.hardforkTriggerPubKey, + }, + redundancy: redundancyHandler, + privKey: args.privKey, + publicKey: args.privKey.GeneratePublic(), + observerPublicKey: redundancyHandler.ObserverPrivateKey().GeneratePublic(), hardforkTimeBetweenSends: args.hardforkTimeBetweenSends, - hardforkTriggerPubKey: args.hardforkTriggerPubKey, } return senderInstance, nil @@ -110,52 +115,19 @@ func (sender *peerAuthenticationSender) Execute() { func (sender *peerAuthenticationSender) execute() (error, bool) { sk, pk := sender.getCurrentPrivateAndPublicKeys() - msg := &heartbeat.PeerAuthentication{ - Pid: sender.messenger.ID().Bytes(), - } - - hardforkPayload, isTriggered := sender.getHardforkPayload() - payload := &heartbeat.Payload{ - Timestamp: time.Now().Unix(), - HardforkMessage: string(hardforkPayload), - } - payloadBytes, err := sender.marshaller.Marshal(payload) - if err != nil { - return err, isTriggered - } - msg.Payload = payloadBytes - msg.PayloadSignature, err = sender.messenger.Sign(payloadBytes) - if err != nil { - return err, isTriggered - } - - msg.Pubkey, err = pk.ToByteArray() - if err != nil { - return err, isTriggered - } - - msg.Signature, err = sender.peerSignatureHandler.GetPeerSignature(sk, msg.Pid) - if err != nil { - return err, isTriggered - } - - msgBytes, err := sender.marshaller.Marshal(msg) + pkBytes, err := pk.ToByteArray() if err != nil { - return err, isTriggered + return err, false } - b := &batch.Batch{ - Data: make([][]byte, 1), - } - b.Data[0] = msgBytes - data, err := sender.marshaller.Marshal(b) + data, isTriggered, msgTimestamp, err := sender.generateMessageBytes(pkBytes, sk, nil, sender.messenger.ID().Bytes()) if err != nil { return err, isTriggered } log.Debug("sending peer authentication message", - "public key", msg.Pubkey, "pid", sender.messenger.ID().Pretty(), - "timestamp", payload.Timestamp) + "public key", pkBytes, "pid", sender.messenger.ID().Pretty(), + "timestamp", msgTimestamp) sender.messenger.Broadcast(sender.topic, data) return nil, isTriggered @@ -166,23 +138,13 @@ func (sender *peerAuthenticationSender) ShouldTriggerHardfork() <-chan struct{} return sender.hardforkTrigger.NotifyTriggerReceivedV2() } -func (sender *peerAuthenticationSender) isValidator(pkBytes []byte) bool { - _, _, err := sender.nodesCoordinator.GetValidatorWithPublicKey(pkBytes) - return err == nil -} - -func (sender *peerAuthenticationSender) isHardforkSource(pkBytes []byte) bool { - return bytes.Equal(pkBytes, sender.hardforkTriggerPubKey) -} - -func (sender *peerAuthenticationSender) getHardforkPayload() ([]byte, bool) { - payload := make([]byte, 0) - _, isTriggered := sender.hardforkTrigger.RecordedTriggerMessage() - if isTriggered { - payload = sender.hardforkTrigger.CreateData() +func (sender *peerAuthenticationSender) getCurrentPrivateAndPublicKeys() (crypto.PrivateKey, crypto.PublicKey) { + shouldUseOriginalKeys := !sender.redundancy.IsRedundancyNode() || (sender.redundancy.IsRedundancyNode() && !sender.redundancy.IsMainMachineActive()) + if shouldUseOriginalKeys { + return 
sender.privKey, sender.publicKey } - return payload, isTriggered + return sender.redundancy.ObserverPrivateKey(), sender.observerPublicKey } // IsInterfaceNil returns true if there is no value under the interface diff --git a/heartbeat/sender/peerAuthenticationSenderFactory.go b/heartbeat/sender/peerAuthenticationSenderFactory.go new file mode 100644 index 00000000000..6467e1a36e0 --- /dev/null +++ b/heartbeat/sender/peerAuthenticationSenderFactory.go @@ -0,0 +1,52 @@ +package sender + +import ( + "fmt" + "time" + + crypto "github.com/multiversx/mx-chain-crypto-go" + "github.com/multiversx/mx-chain-go/heartbeat" +) + +type argPeerAuthenticationSenderFactory struct { + argBaseSender + nodesCoordinator heartbeat.NodesCoordinator + peerSignatureHandler crypto.PeerSignatureHandler + hardforkTrigger heartbeat.HardforkTrigger + hardforkTimeBetweenSends time.Duration + hardforkTriggerPubKey []byte + managedPeersHolder heartbeat.ManagedPeersHolder + timeBetweenChecks time.Duration + shardCoordinator heartbeat.ShardCoordinator +} + +func createPeerAuthenticationSender(args argPeerAuthenticationSenderFactory) (peerAuthenticationSenderHandler, error) { + isMultikey, err := isMultikeyMode(args.privKey, args.managedPeersHolder, args.nodesCoordinator) + if err != nil { + return nil, fmt.Errorf("%w while creating peer authentication sender", err) + } + + if isMultikey { + return createMultikeyPeerAuthenticationSender(args) + } + + return createRegularPeerAuthenticationSender(args) +} + +func createRegularPeerAuthenticationSender(args argPeerAuthenticationSenderFactory) (*peerAuthenticationSender, error) { + argsSender := argPeerAuthenticationSender{ + argBaseSender: args.argBaseSender, + nodesCoordinator: args.nodesCoordinator, + peerSignatureHandler: args.peerSignatureHandler, + hardforkTrigger: args.hardforkTrigger, + hardforkTimeBetweenSends: args.hardforkTimeBetweenSends, + hardforkTriggerPubKey: args.hardforkTriggerPubKey, + } + + return newPeerAuthenticationSender(argsSender) +} + +func createMultikeyPeerAuthenticationSender(args argPeerAuthenticationSenderFactory) (*multikeyPeerAuthenticationSender, error) { + argsSender := argMultikeyPeerAuthenticationSender(args) + return newMultikeyPeerAuthenticationSender(argsSender) +} diff --git a/heartbeat/sender/peerAuthenticationSenderFactory_test.go b/heartbeat/sender/peerAuthenticationSenderFactory_test.go new file mode 100644 index 00000000000..46c60f39ddf --- /dev/null +++ b/heartbeat/sender/peerAuthenticationSenderFactory_test.go @@ -0,0 +1,123 @@ +package sender + +import ( + "errors" + "fmt" + "testing" + "time" + + "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-crypto-go" + "github.com/multiversx/mx-chain-go/heartbeat" + "github.com/multiversx/mx-chain-go/sharding/nodesCoordinator" + "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cryptoMocks" + "github.com/multiversx/mx-chain-go/testscommon/shardingMocks" + "github.com/stretchr/testify/assert" +) + +func createMockPeerAuthenticationSenderFactoryArgs() argPeerAuthenticationSenderFactory { + return argPeerAuthenticationSenderFactory{ + argBaseSender: createMockBaseArgs(), + nodesCoordinator: &shardingMocks.NodesCoordinatorStub{}, + peerSignatureHandler: &cryptoMocks.PeerSignatureHandlerStub{}, + hardforkTrigger: &testscommon.HardforkTriggerStub{}, + hardforkTimeBetweenSends: time.Second, + hardforkTriggerPubKey: providedHardforkPubKey, + managedPeersHolder: &testscommon.ManagedPeersHolderStub{}, + 
timeBetweenChecks: time.Second, + shardCoordinator: createShardCoordinatorInShard(0), + } +} + +func TestPeerAuthenticationSenderFactory_createPeerAuthenticationSender(t *testing.T) { + t.Parallel() + + t.Run("ToByteArray fails should error", func(t *testing.T) { + t.Parallel() + + args := createMockPeerAuthenticationSenderFactoryArgs() + args.privKey = &cryptoMocks.PrivateKeyStub{ + GeneratePublicStub: func() crypto.PublicKey { + return &cryptoMocks.PublicKeyStub{ + ToByteArrayStub: func() ([]byte, error) { + return nil, expectedErr + }, + } + }, + } + peerAuthSender, err := createPeerAuthenticationSender(args) + assert.True(t, errors.Is(err, expectedErr)) + assert.True(t, check.IfNil(peerAuthSender)) + }) + t.Run("validator with keys managed should error", func(t *testing.T) { + t.Parallel() + + args := createMockPeerAuthenticationSenderFactoryArgs() + args.nodesCoordinator = &shardingMocks.NodesCoordinatorStub{ + GetValidatorWithPublicKeyCalled: func(publicKey []byte) (validator nodesCoordinator.Validator, shardId uint32, err error) { + return nil, 0, nil + }, + } + args.managedPeersHolder = &testscommon.ManagedPeersHolderStub{ + IsMultiKeyModeCalled: func() bool { + return true + }, + } + peerAuthSender, err := createPeerAuthenticationSender(args) + assert.True(t, errors.Is(err, heartbeat.ErrInvalidConfiguration)) + assert.True(t, check.IfNil(peerAuthSender)) + }) + t.Run("validator should create regular sender", func(t *testing.T) { + t.Parallel() + + args := createMockPeerAuthenticationSenderFactoryArgs() + args.nodesCoordinator = &shardingMocks.NodesCoordinatorStub{ + GetValidatorWithPublicKeyCalled: func(publicKey []byte) (validator nodesCoordinator.Validator, shardId uint32, err error) { + return nil, 0, nil + }, + } + peerAuthSender, err := createPeerAuthenticationSender(args) + assert.Nil(t, err) + assert.False(t, check.IfNil(peerAuthSender)) + assert.Equal(t, "*sender.peerAuthenticationSender", fmt.Sprintf("%T", peerAuthSender)) + }) + t.Run("regular observer should create regular sender", func(t *testing.T) { + t.Parallel() + + args := createMockPeerAuthenticationSenderFactoryArgs() + args.nodesCoordinator = &shardingMocks.NodesCoordinatorStub{ + GetValidatorWithPublicKeyCalled: func(publicKey []byte) (validator nodesCoordinator.Validator, shardId uint32, err error) { + return nil, 0, errors.New("not validator") + }, + } + args.managedPeersHolder = &testscommon.ManagedPeersHolderStub{ + GetManagedKeysByCurrentNodeCalled: func() map[string]crypto.PrivateKey { + return make(map[string]crypto.PrivateKey) + }, + } + peerAuthSender, err := createPeerAuthenticationSender(args) + assert.Nil(t, err) + assert.False(t, check.IfNil(peerAuthSender)) + assert.Equal(t, "*sender.peerAuthenticationSender", fmt.Sprintf("%T", peerAuthSender)) + }) + t.Run("not validator with keys managed should create multikey sender", func(t *testing.T) { + t.Parallel() + + args := createMockPeerAuthenticationSenderFactoryArgs() + args.nodesCoordinator = &shardingMocks.NodesCoordinatorStub{ + GetValidatorWithPublicKeyCalled: func(publicKey []byte) (validator nodesCoordinator.Validator, shardId uint32, err error) { + return nil, 0, errors.New("not validator") + }, + } + args.managedPeersHolder = &testscommon.ManagedPeersHolderStub{ + IsMultiKeyModeCalled: func() bool { + return true + }, + } + peerAuthSender, err := createPeerAuthenticationSender(args) + assert.Nil(t, err) + assert.False(t, check.IfNil(peerAuthSender)) + assert.Equal(t, "*sender.multikeyPeerAuthenticationSender", fmt.Sprintf("%T", 
peerAuthSender)) + }) +} diff --git a/heartbeat/sender/sender.go b/heartbeat/sender/sender.go index eed51c3bead..fbc5525be26 100644 --- a/heartbeat/sender/sender.go +++ b/heartbeat/sender/sender.go @@ -22,6 +22,7 @@ type ArgSender struct { HeartbeatTimeBetweenSends time.Duration HeartbeatTimeBetweenSendsWhenError time.Duration HeartbeatTimeThresholdBetweenSends float64 + BaseVersionNumber string VersionNumber string NodeDisplayName string Identity string @@ -35,11 +36,14 @@ type ArgSender struct { HardforkTimeBetweenSends time.Duration HardforkTriggerPubKey []byte PeerTypeProvider heartbeat.PeerTypeProviderHandler + ManagedPeersHolder heartbeat.ManagedPeersHolder + PeerAuthenticationTimeBetweenChecks time.Duration + ShardCoordinator heartbeat.ShardCoordinator } // sender defines the component which sends authentication and heartbeat messages type sender struct { - heartbeatSender *heartbeatSender + heartbeatSender heartbeatSenderHandler routineHandler *routineHandler } @@ -50,7 +54,7 @@ func NewSender(args ArgSender) (*sender, error) { return nil, err } - pas, err := newPeerAuthenticationSender(argPeerAuthenticationSender{ + pas, err := createPeerAuthenticationSender(argPeerAuthenticationSenderFactory{ argBaseSender: argBaseSender{ messenger: args.Messenger, marshaller: args.Marshaller, @@ -66,12 +70,15 @@ func NewSender(args ArgSender) (*sender, error) { hardforkTrigger: args.HardforkTrigger, hardforkTimeBetweenSends: args.HardforkTimeBetweenSends, hardforkTriggerPubKey: args.HardforkTriggerPubKey, + managedPeersHolder: args.ManagedPeersHolder, + timeBetweenChecks: args.PeerAuthenticationTimeBetweenChecks, + shardCoordinator: args.ShardCoordinator, }) if err != nil { return nil, err } - hbs, err := newHeartbeatSender(argHeartbeatSender{ + hbs, err := createHeartbeatSender(argHeartbeatSenderFactory{ argBaseSender: argBaseSender{ messenger: args.Messenger, marshaller: args.Marshaller, @@ -82,12 +89,16 @@ func NewSender(args ArgSender) (*sender, error) { privKey: args.PrivateKey, redundancyHandler: args.RedundancyHandler, }, + baseVersionNumber: args.BaseVersionNumber, versionNumber: args.VersionNumber, nodeDisplayName: args.NodeDisplayName, identity: args.Identity, peerSubType: args.PeerSubType, currentBlockProvider: args.CurrentBlockProvider, peerTypeProvider: args.PeerTypeProvider, + managedPeersHolder: args.ManagedPeersHolder, + shardCoordinator: args.ShardCoordinator, + nodesCoordinator: args.NodesCoordinator, trieSyncStatisticsProvider: disabled.NewTrieSyncStatisticsProvider(), }) if err != nil { @@ -101,24 +112,41 @@ func NewSender(args ArgSender) (*sender, error) { } func checkSenderArgs(args ArgSender) error { - pasArg := argPeerAuthenticationSender{ - argBaseSender: argBaseSender{ - messenger: args.Messenger, - marshaller: args.Marshaller, - topic: args.PeerAuthenticationTopic, - timeBetweenSends: args.PeerAuthenticationTimeBetweenSends, - timeBetweenSendsWhenError: args.PeerAuthenticationTimeBetweenSendsWhenError, - thresholdBetweenSends: args.PeerAuthenticationTimeThresholdBetweenSends, - privKey: args.PrivateKey, - redundancyHandler: args.RedundancyHandler, - }, + basePeerAuthSenderArgs := argBaseSender{ + messenger: args.Messenger, + marshaller: args.Marshaller, + topic: args.PeerAuthenticationTopic, + timeBetweenSends: args.PeerAuthenticationTimeBetweenSends, + timeBetweenSendsWhenError: args.PeerAuthenticationTimeBetweenSendsWhenError, + thresholdBetweenSends: args.PeerAuthenticationTimeThresholdBetweenSends, + privKey: args.PrivateKey, + redundancyHandler: 
args.RedundancyHandler, + } + pasArgs := argPeerAuthenticationSender{ + argBaseSender: basePeerAuthSenderArgs, nodesCoordinator: args.NodesCoordinator, peerSignatureHandler: args.PeerSignatureHandler, hardforkTrigger: args.HardforkTrigger, hardforkTimeBetweenSends: args.HardforkTimeBetweenSends, hardforkTriggerPubKey: args.HardforkTriggerPubKey, } - err := checkPeerAuthenticationSenderArgs(pasArg) + err := checkPeerAuthenticationSenderArgs(pasArgs) + if err != nil { + return err + } + + mpasArgs := argMultikeyPeerAuthenticationSender{ + argBaseSender: basePeerAuthSenderArgs, + nodesCoordinator: args.NodesCoordinator, + peerSignatureHandler: args.PeerSignatureHandler, + hardforkTrigger: args.HardforkTrigger, + hardforkTimeBetweenSends: args.HardforkTimeBetweenSends, + hardforkTriggerPubKey: args.HardforkTriggerPubKey, + managedPeersHolder: args.ManagedPeersHolder, + timeBetweenChecks: args.PeerAuthenticationTimeBetweenChecks, + shardCoordinator: args.ShardCoordinator, + } + err = checkMultikeyPeerAuthenticationSenderArgs(mpasArgs) if err != nil { return err } @@ -142,7 +170,35 @@ func checkSenderArgs(args ArgSender) error { peerTypeProvider: args.PeerTypeProvider, trieSyncStatisticsProvider: disabled.NewTrieSyncStatisticsProvider(), } - return checkHeartbeatSenderArgs(hbsArgs) + err = checkHeartbeatSenderArgs(hbsArgs) + if err != nil { + return err + } + + mhbsArgs := argMultikeyHeartbeatSender{ + argBaseSender: argBaseSender{ + messenger: args.Messenger, + marshaller: args.Marshaller, + topic: args.HeartbeatTopic, + timeBetweenSends: args.HeartbeatTimeBetweenSends, + timeBetweenSendsWhenError: args.HeartbeatTimeBetweenSendsWhenError, + thresholdBetweenSends: args.HeartbeatTimeThresholdBetweenSends, + privKey: args.PrivateKey, + redundancyHandler: args.RedundancyHandler, + }, + peerTypeProvider: args.PeerTypeProvider, + versionNumber: args.VersionNumber, + baseVersionNumber: args.BaseVersionNumber, + nodeDisplayName: args.NodeDisplayName, + identity: args.Identity, + peerSubType: args.PeerSubType, + currentBlockProvider: args.CurrentBlockProvider, + managedPeersHolder: args.ManagedPeersHolder, + shardCoordinator: args.ShardCoordinator, + trieSyncStatisticsProvider: disabled.NewTrieSyncStatisticsProvider(), + } + + return checkMultikeyHeartbeatSenderArgs(mhbsArgs) } // Close closes the internal components @@ -152,9 +208,9 @@ func (sender *sender) Close() error { return nil } -// GetSenderInfo will return the current sender info -func (sender *sender) GetSenderInfo() (string, core.P2PPeerSubType, error) { - return sender.heartbeatSender.getSenderInfo() +// GetCurrentNodeType will return the current peer details +func (sender *sender) GetCurrentNodeType() (string, core.P2PPeerSubType, error) { + return sender.heartbeatSender.GetCurrentNodeType() } // IsInterfaceNil returns true if there is no value under the interface diff --git a/heartbeat/sender/sender_test.go b/heartbeat/sender/sender_test.go index 393ef092ea2..b6402821e77 100644 --- a/heartbeat/sender/sender_test.go +++ b/heartbeat/sender/sender_test.go @@ -16,6 +16,7 @@ import ( "github.com/multiversx/mx-chain-go/testscommon/p2pmocks" "github.com/multiversx/mx-chain-go/testscommon/shardingMocks" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func createMockSenderArgs() ArgSender { @@ -30,6 +31,7 @@ func createMockSenderArgs() ArgSender { HeartbeatTimeBetweenSends: time.Second, HeartbeatTimeBetweenSendsWhenError: time.Second, HeartbeatTimeThresholdBetweenSends: 0.1, + BaseVersionNumber: "v1-base", 
VersionNumber: "v1", NodeDisplayName: "node", Identity: "identity", @@ -43,6 +45,9 @@ func createMockSenderArgs() ArgSender { HardforkTimeBetweenSends: time.Second, HardforkTriggerPubKey: providedHardforkPubKey, PeerTypeProvider: &mock.PeerTypeProviderStub{}, + ManagedPeersHolder: &testscommon.ManagedPeersHolderStub{}, + PeerAuthenticationTimeBetweenChecks: time.Second, + ShardCoordinator: createShardCoordinatorInShard(0), } } @@ -260,6 +265,37 @@ func TestNewSender(t *testing.T) { assert.Nil(t, senderInstance) assert.Equal(t, heartbeat.ErrNilPeerTypeProvider, err) }) + t.Run("nil managed peers holder should error", func(t *testing.T) { + t.Parallel() + + args := createMockSenderArgs() + args.ManagedPeersHolder = nil + senderInstance, err := NewSender(args) + + assert.Nil(t, senderInstance) + assert.True(t, errors.Is(err, heartbeat.ErrNilManagedPeersHolder)) + }) + t.Run("invalid time between checks should error", func(t *testing.T) { + t.Parallel() + + args := createMockSenderArgs() + args.PeerAuthenticationTimeBetweenChecks = time.Second - time.Nanosecond + senderInstance, err := NewSender(args) + + assert.Nil(t, senderInstance) + assert.True(t, errors.Is(err, heartbeat.ErrInvalidTimeDuration)) + assert.True(t, strings.Contains(err.Error(), "timeBetweenChecks")) + }) + t.Run("nil shard coordinator should error", func(t *testing.T) { + t.Parallel() + + args := createMockSenderArgs() + args.ShardCoordinator = nil + senderInstance, err := NewSender(args) + + assert.Nil(t, senderInstance) + assert.True(t, errors.Is(err, heartbeat.ErrNilShardCoordinator)) + }) t.Run("should work", func(t *testing.T) { t.Parallel() @@ -287,7 +323,7 @@ func TestSender_Close(t *testing.T) { assert.Nil(t, err) } -func TestSender_GetSenderInfoShouldNotPanic(t *testing.T) { +func TestSender_GetCurrentNodeTypeShouldNotPanic(t *testing.T) { t.Parallel() defer func() { @@ -298,9 +334,10 @@ func TestSender_GetSenderInfoShouldNotPanic(t *testing.T) { }() args := createMockSenderArgs() - senderInstance, _ := NewSender(args) + senderInstance, err := NewSender(args) + require.Nil(t, err) - _, _, err := senderInstance.GetSenderInfo() + _, _, err = senderInstance.GetCurrentNodeType() assert.Nil(t, err) _ = senderInstance.Close() diff --git a/heartbeat/status/interface.go b/heartbeat/status/interface.go index dfd959f5eac..47ecff9b737 100644 --- a/heartbeat/status/interface.go +++ b/heartbeat/status/interface.go @@ -13,6 +13,6 @@ type HeartbeatMonitor interface { // HeartbeatSenderInfoProvider is able to provide correct information about the current sender type HeartbeatSenderInfoProvider interface { - GetSenderInfo() (string, core.P2PPeerSubType, error) + GetCurrentNodeType() (string, core.P2PPeerSubType, error) IsInterfaceNil() bool } diff --git a/heartbeat/status/metricsUpdater.go b/heartbeat/status/metricsUpdater.go index 12d6478d3a1..ad599c56dcc 100644 --- a/heartbeat/status/metricsUpdater.go +++ b/heartbeat/status/metricsUpdater.go @@ -122,7 +122,7 @@ func (updater *metricsUpdater) updateConnectionsMetrics() { } func (updater *metricsUpdater) updateSenderMetrics() { - result, subType, err := updater.heartbeatSenderInfoProvider.GetSenderInfo() + result, subType, err := updater.heartbeatSenderInfoProvider.GetCurrentNodeType() if err != nil { log.Warn("error while updating metrics in heartbeat v2 metricsUpdater", "error", err) return diff --git a/heartbeat/status/metricsUpdater_test.go b/heartbeat/status/metricsUpdater_test.go index d9875921af3..645f4edb0dd 100644 --- a/heartbeat/status/metricsUpdater_test.go +++ 
b/heartbeat/status/metricsUpdater_test.go @@ -161,7 +161,7 @@ func TestMetricsUpdater_updateMetrics(t *testing.T) { t.Run("should send sender metrics", func(t *testing.T) { t.Run("eligible node", func(t *testing.T) { args.HeartbeatSenderInfoProvider = &mock.HeartbeatSenderInfoProviderStub{ - GetSenderInfoCalled: func() (string, core.P2PPeerSubType, error) { + GetCurrentNodeTypeCalled: func() (string, core.P2PPeerSubType, error) { return string(common.EligibleList), core.FullHistoryObserver, nil }, } @@ -174,7 +174,7 @@ func TestMetricsUpdater_updateMetrics(t *testing.T) { }) t.Run("waiting node", func(t *testing.T) { args.HeartbeatSenderInfoProvider = &mock.HeartbeatSenderInfoProviderStub{ - GetSenderInfoCalled: func() (string, core.P2PPeerSubType, error) { + GetCurrentNodeTypeCalled: func() (string, core.P2PPeerSubType, error) { return string(common.WaitingList), core.FullHistoryObserver, nil }, } @@ -187,7 +187,7 @@ func TestMetricsUpdater_updateMetrics(t *testing.T) { }) t.Run("observer node", func(t *testing.T) { args.HeartbeatSenderInfoProvider = &mock.HeartbeatSenderInfoProviderStub{ - GetSenderInfoCalled: func() (string, core.P2PPeerSubType, error) { + GetCurrentNodeTypeCalled: func() (string, core.P2PPeerSubType, error) { return string(common.ObserverList), core.FullHistoryObserver, nil }, } @@ -201,7 +201,7 @@ func TestMetricsUpdater_updateMetrics(t *testing.T) { }) t.Run("GetSenderInfo errors", func(t *testing.T) { args.HeartbeatSenderInfoProvider = &mock.HeartbeatSenderInfoProviderStub{ - GetSenderInfoCalled: func() (string, core.P2PPeerSubType, error) { + GetCurrentNodeTypeCalled: func() (string, core.P2PPeerSubType, error) { return "", 0, errors.New("expected error") }, } diff --git a/integrationTests/consensus/consensusSigning_test.go b/integrationTests/consensus/consensusSigning_test.go index 32d528dcd54..7566828ada1 100644 --- a/integrationTests/consensus/consensusSigning_test.go +++ b/integrationTests/consensus/consensusSigning_test.go @@ -29,6 +29,7 @@ func initNodesWithTestSigner( int(consensusSize), roundTime, consensusType, + 1, ) for shardID, nodesList := range nodes { @@ -42,10 +43,6 @@ func initNodesWithTestSigner( for i := uint32(0); i < numInvalid; i++ { ii := numNodes - i - 1 nodes[shardID][ii].MultiSigner.CreateSignatureShareCalled = func(privateKeyBytes, message []byte) ([]byte, error) { - fmt.Println("invalid sig share from ", - getPkEncoded(nodes[shardID][ii].NodeKeys.Pk), - ) - var invalidSigShare []byte if i%2 == 0 { // invalid sig share but with valid format @@ -54,6 +51,7 @@ func initNodesWithTestSigner( // sig share with invalid size invalidSigShare = bytes.Repeat([]byte("a"), 3) } + log.Warn("invalid sig share from ", "pk", getPkEncoded(nodes[shardID][ii].NodeKeys.Pk), "sig", invalidSigShare) return invalidSigShare, nil } @@ -107,7 +105,7 @@ func TestConsensusWithInvalidSigners(t *testing.T) { case <-chDone: case <-time.After(endTime): mutex.Lock() - fmt.Println("currently saved nonces for rounds: \n", nonceForRoundMap) + log.Error("currently saved nonces for rounds", "nonceForRoundMap", nonceForRoundMap) assert.Fail(t, "consensus too slow, not working.") mutex.Unlock() return diff --git a/integrationTests/consensus/consensus_test.go b/integrationTests/consensus/consensus_test.go index 6f99ffe3f33..cbc52fcc855 100644 --- a/integrationTests/consensus/consensus_test.go +++ b/integrationTests/consensus/consensus_test.go @@ -20,8 +20,6 @@ import ( "github.com/stretchr/testify/assert" ) -var log = logger.GetOrCreate("integrationtests/consensus") - const ( 
consensusTimeBetweenRounds = time.Second blsConsensusType = "bls" @@ -30,6 +28,7 @@ const ( var ( p2pBootstrapDelay = time.Second * 5 testPubkeyConverter, _ = pubkeyConverter.NewHexPubkeyConverter(32) + log = logger.GetOrCreate("integrationtests/consensus") ) func encodeAddress(address []byte) string { @@ -52,6 +51,7 @@ func initNodesAndTest( numInvalid uint32, roundTime uint64, consensusType string, + numKeysOnEachNode int, ) (map[uint32][]*integrationTests.TestConsensusNode, error) { fmt.Println("Step 1. Setup nodes...") @@ -62,6 +62,7 @@ func initNodesAndTest( int(consensusSize), roundTime, consensusType, + numKeysOnEachNode, ) for shardID, nodesList := range nodes { @@ -113,6 +114,8 @@ func startNodesWithCommitBlock(nodes []*integrationTests.TestConsensusNode, mute nCopy.ChainHandler.SetCurrentBlockHeaderHash(headerHash) _ = nCopy.ChainHandler.SetCurrentBlockHeaderAndRootHash(header, header.GetRootHash()) + log.Info("BlockProcessor.CommitBlockCalled", "shard", header.GetShardID(), "nonce", header.GetNonce(), "round", header.GetRound()) + mutex.Lock() nonceForRoundMap[header.GetRound()] = header.GetNonce() *totalCalled += 1 @@ -195,7 +198,7 @@ func checkBlockProposedEveryRound(numCommBlock uint64, nonceForRoundMap map[uint for i := minRound; i <= maxRound; i++ { if _, ok := nonceForRoundMap[i]; !ok { assert.Fail(t, "consensus not reached in each round") - fmt.Println("currently saved nonces for rounds: \n", nonceForRoundMap) + log.Error("currently saved nonces for rounds", "nonceForRoundMap", nonceForRoundMap) mutex.Unlock() return } @@ -212,15 +215,15 @@ func checkBlockProposedEveryRound(numCommBlock uint64, nonceForRoundMap map[uint } } -func runFullConsensusTest(t *testing.T, consensusType string) { +func runFullConsensusTest(t *testing.T, consensusType string, numKeysOnEachNode int) { numMetaNodes := uint32(4) numNodes := uint32(4) - consensusSize := uint32(4) + consensusSize := uint32(4 * numKeysOnEachNode) numInvalid := uint32(0) roundTime := uint64(1000) numCommBlock := uint64(8) - nodes, err := initNodesAndTest(numMetaNodes, numNodes, consensusSize, numInvalid, roundTime, consensusType) + nodes, err := initNodesAndTest(numMetaNodes, numNodes, consensusSize, numInvalid, roundTime, consensusType, numKeysOnEachNode) if err != nil { assert.Nil(t, err) } @@ -252,6 +255,7 @@ func runFullConsensusTest(t *testing.T, consensusType string) { endTime := time.Duration(roundTime)*time.Duration(numCommBlock+extraTime)*time.Millisecond + time.Minute select { case <-chDone: + log.Info("consensus done", "shard", shardID) case <-time.After(endTime): mutex.Lock() fmt.Println("currently saved nonces for rounds: \n", nonceForRoundMap) @@ -262,12 +266,20 @@ func runFullConsensusTest(t *testing.T, consensusType string) { } } -func TestConsensusBLSFullTest(t *testing.T) { +func TestConsensusBLSFullTestSingleKeys(t *testing.T) { + if testing.Short() { + t.Skip("this is not a short test") + } + + runFullConsensusTest(t, blsConsensusType, 1) +} + +func TestConsensusBLSFullTestMultiKeys(t *testing.T) { if testing.Short() { t.Skip("this is not a short test") } - runFullConsensusTest(t, blsConsensusType) + runFullConsensusTest(t, blsConsensusType, 5) } func runConsensusWithNotEnoughValidators(t *testing.T, consensusType string) { @@ -276,7 +288,7 @@ func runConsensusWithNotEnoughValidators(t *testing.T, consensusType string) { consensusSize := uint32(4) numInvalid := uint32(2) roundTime := uint64(1000) - nodes, err := initNodesAndTest(numMetaNodes, numNodes, consensusSize, numInvalid, roundTime, 
consensusType) + nodes, err := initNodesAndTest(numMetaNodes, numNodes, consensusSize, numInvalid, roundTime, consensusType, 1) if err != nil { assert.Nil(t, err) } diff --git a/integrationTests/factory/componentsHelper.go b/integrationTests/factory/componentsHelper.go index c901d3a06eb..64f70e6bb8c 100644 --- a/integrationTests/factory/componentsHelper.go +++ b/integrationTests/factory/componentsHelper.go @@ -55,11 +55,12 @@ func CreateDefaultConfig(tb testing.TB) *config.Configs { configs.EpochConfig = epochConfig configs.RoundConfig = roundConfig configs.FlagsConfig = &config.ContextFlagsConfig{ - WorkingDir: tb.TempDir(), - DbDir: "dbDir", - LogsDir: "logsDir", - UseLogView: true, - Version: Version, + WorkingDir: tb.TempDir(), + DbDir: "dbDir", + LogsDir: "logsDir", + UseLogView: true, + BaseVersion: BaseVersion, + Version: Version, } configs.ConfigurationPathsHolder = configPathsHolder configs.ImportDbConfig = &config.ImportDbConfig{} diff --git a/integrationTests/factory/consensusComponents/consensusComponents_test.go b/integrationTests/factory/consensusComponents/consensusComponents_test.go index 1f8404f7e5c..96b7afec65d 100644 --- a/integrationTests/factory/consensusComponents/consensusComponents_test.go +++ b/integrationTests/factory/consensusComponents/consensusComponents_test.go @@ -41,7 +41,7 @@ func TestConsensusComponents_Close_ShouldWork(t *testing.T) { require.Nil(t, err) managedBootstrapComponents, err := nr.CreateManagedBootstrapComponents(managedStatusCoreComponents, managedCoreComponents, managedCryptoComponents, managedNetworkComponents) require.Nil(t, err) - managedDataComponents, err := nr.CreateManagedDataComponents(managedStatusCoreComponents, managedCoreComponents, managedBootstrapComponents) + managedDataComponents, err := nr.CreateManagedDataComponents(managedStatusCoreComponents, managedCoreComponents, managedBootstrapComponents, managedCryptoComponents) require.Nil(t, err) managedStateComponents, err := nr.CreateManagedStateComponents(managedCoreComponents, managedBootstrapComponents, managedDataComponents, managedStatusCoreComponents) require.Nil(t, err) diff --git a/integrationTests/factory/constants.go b/integrationTests/factory/constants.go index 761fcb1239a..1db46e07547 100644 --- a/integrationTests/factory/constants.go +++ b/integrationTests/factory/constants.go @@ -16,6 +16,7 @@ const ( GenesisPath = "../testdata/genesis.json" GenesisSmartContracts = "../testdata/genesisSmartContracts.json" ValidatorKeyPemPath = "../validatorKey.pem" + BaseVersion = "v1.1.6.1-0-gbae61225f/go1.14.2/linux-amd64" Version = "v1.1.6.1-0-gbae61225f/go1.14.2/linux-amd64/a72b5f2eff" RoundActivationPath = "enableRounds.toml" P2pKeyPath = "../p2pKey.pem" diff --git a/integrationTests/factory/cryptoComponents/cryptoComponents_test.go b/integrationTests/factory/cryptoComponents/cryptoComponents_test.go index cc2d6012e84..0fbecfd0f4b 100644 --- a/integrationTests/factory/cryptoComponents/cryptoComponents_test.go +++ b/integrationTests/factory/cryptoComponents/cryptoComponents_test.go @@ -31,14 +31,22 @@ func TestCryptoComponents_Create_Close_ShouldWork(t *testing.T) { managedCoreComponents, err := nr.CreateManagedCoreComponents(chanStopNodeProcess) require.Nil(t, err) + managedStatusCoreComponents, err := nr.CreateManagedStatusCoreComponents(managedCoreComponents) + require.Nil(t, err) managedCryptoComponents, err := nr.CreateManagedCryptoComponents(managedCoreComponents) require.Nil(t, err) + managedNetworkComponents, err := nr.CreateManagedNetworkComponents(managedCoreComponents, 
managedStatusCoreComponents, managedCryptoComponents) + require.Nil(t, err) require.NotNil(t, managedCryptoComponents) time.Sleep(5 * time.Second) + err = managedNetworkComponents.Close() + require.Nil(t, err) err = managedCryptoComponents.Close() require.Nil(t, err) + err = managedStatusCoreComponents.Close() + require.Nil(t, err) err = managedCoreComponents.Close() require.Nil(t, err) diff --git a/integrationTests/factory/dataComponents/dataComponents_test.go b/integrationTests/factory/dataComponents/dataComponents_test.go index c2eb1ac91fb..9ebc4a49fc5 100644 --- a/integrationTests/factory/dataComponents/dataComponents_test.go +++ b/integrationTests/factory/dataComponents/dataComponents_test.go @@ -34,7 +34,7 @@ func TestDataComponents_Create_Close_ShouldWork(t *testing.T) { require.Nil(t, err) managedBootstrapComponents, err := nr.CreateManagedBootstrapComponents(managedStatusCoreComponents, managedCoreComponents, managedCryptoComponents, managedNetworkComponents) require.Nil(t, err) - managedDataComponents, err := nr.CreateManagedDataComponents(managedStatusCoreComponents, managedCoreComponents, managedBootstrapComponents) + managedDataComponents, err := nr.CreateManagedDataComponents(managedStatusCoreComponents, managedCoreComponents, managedBootstrapComponents, managedCryptoComponents) require.Nil(t, err) require.NotNil(t, managedDataComponents) diff --git a/integrationTests/factory/heartbeatComponents/heartbeatComponents_test.go b/integrationTests/factory/heartbeatComponents/heartbeatComponents_test.go index bcae8ad1344..bf9ce94dda0 100644 --- a/integrationTests/factory/heartbeatComponents/heartbeatComponents_test.go +++ b/integrationTests/factory/heartbeatComponents/heartbeatComponents_test.go @@ -41,7 +41,7 @@ func TestHeartbeatComponents_Close_ShouldWork(t *testing.T) { require.Nil(t, err) managedBootstrapComponents, err := nr.CreateManagedBootstrapComponents(managedStatusCoreComponents, managedCoreComponents, managedCryptoComponents, managedNetworkComponents) require.Nil(t, err) - managedDataComponents, err := nr.CreateManagedDataComponents(managedStatusCoreComponents, managedCoreComponents, managedBootstrapComponents) + managedDataComponents, err := nr.CreateManagedDataComponents(managedStatusCoreComponents, managedCoreComponents, managedBootstrapComponents, managedCryptoComponents) require.Nil(t, err) managedStateComponents, err := nr.CreateManagedStateComponents(managedCoreComponents, managedBootstrapComponents, managedDataComponents, managedStatusCoreComponents) require.Nil(t, err) diff --git a/integrationTests/factory/processComponents/processComponents_test.go b/integrationTests/factory/processComponents/processComponents_test.go index 0a2802df441..87ddd17f644 100644 --- a/integrationTests/factory/processComponents/processComponents_test.go +++ b/integrationTests/factory/processComponents/processComponents_test.go @@ -42,7 +42,7 @@ func TestProcessComponents_Close_ShouldWork(t *testing.T) { require.Nil(t, err) managedBootstrapComponents, err := nr.CreateManagedBootstrapComponents(managedStatusCoreComponents, managedCoreComponents, managedCryptoComponents, managedNetworkComponents) require.Nil(t, err) - managedDataComponents, err := nr.CreateManagedDataComponents(managedStatusCoreComponents, managedCoreComponents, managedBootstrapComponents) + managedDataComponents, err := nr.CreateManagedDataComponents(managedStatusCoreComponents, managedCoreComponents, managedBootstrapComponents, managedCryptoComponents) require.Nil(t, err) managedStateComponents, err := 
nr.CreateManagedStateComponents(managedCoreComponents, managedBootstrapComponents, managedDataComponents, managedStatusCoreComponents) require.Nil(t, err) diff --git a/integrationTests/factory/stateComponents/stateComponents_test.go b/integrationTests/factory/stateComponents/stateComponents_test.go index afc6c5b1983..6056fcc2126 100644 --- a/integrationTests/factory/stateComponents/stateComponents_test.go +++ b/integrationTests/factory/stateComponents/stateComponents_test.go @@ -38,7 +38,7 @@ func TestStateComponents_Create_Close_ShouldWork(t *testing.T) { require.Nil(t, err) managedBootstrapComponents, err := nr.CreateManagedBootstrapComponents(managedStatusCoreComponents, managedCoreComponents, managedCryptoComponents, managedNetworkComponents) require.Nil(t, err) - managedDataComponents, err := nr.CreateManagedDataComponents(managedStatusCoreComponents, managedCoreComponents, managedBootstrapComponents) + managedDataComponents, err := nr.CreateManagedDataComponents(managedStatusCoreComponents, managedCoreComponents, managedBootstrapComponents, managedCryptoComponents) require.Nil(t, err) managedStateComponents, err := nr.CreateManagedStateComponents(managedCoreComponents, managedBootstrapComponents, managedDataComponents, managedStatusCoreComponents) require.Nil(t, err) diff --git a/integrationTests/factory/statusComponents/statusComponents_test.go b/integrationTests/factory/statusComponents/statusComponents_test.go index 3d642cfd5cd..df8c26ef00a 100644 --- a/integrationTests/factory/statusComponents/statusComponents_test.go +++ b/integrationTests/factory/statusComponents/statusComponents_test.go @@ -42,7 +42,7 @@ func TestStatusComponents_Create_Close_ShouldWork(t *testing.T) { require.Nil(t, err) managedBootstrapComponents, err := nr.CreateManagedBootstrapComponents(managedStatusCoreComponents, managedCoreComponents, managedCryptoComponents, managedNetworkComponents) require.Nil(t, err) - managedDataComponents, err := nr.CreateManagedDataComponents(managedStatusCoreComponents, managedCoreComponents, managedBootstrapComponents) + managedDataComponents, err := nr.CreateManagedDataComponents(managedStatusCoreComponents, managedCoreComponents, managedBootstrapComponents, managedCryptoComponents) require.Nil(t, err) managedStateComponents, err := nr.CreateManagedStateComponents(managedCoreComponents, managedBootstrapComponents, managedDataComponents, managedStatusCoreComponents) require.Nil(t, err) diff --git a/integrationTests/mock/cryptoComponentsStub.go b/integrationTests/mock/cryptoComponentsStub.go index b57ba941348..9d927d8d33a 100644 --- a/integrationTests/mock/cryptoComponentsStub.go +++ b/integrationTests/mock/cryptoComponentsStub.go @@ -5,6 +5,7 @@ import ( "sync" crypto "github.com/multiversx/mx-chain-crypto-go" + "github.com/multiversx/mx-chain-go/common" cryptoCommon "github.com/multiversx/mx-chain-go/common/crypto" "github.com/multiversx/mx-chain-go/consensus" "github.com/multiversx/mx-chain-go/vm" @@ -12,24 +13,26 @@ import ( // CryptoComponentsStub - type CryptoComponentsStub struct { - PubKey crypto.PublicKey - PrivKey crypto.PrivateKey - P2pPubKey crypto.PublicKey - P2pPrivKey crypto.PrivateKey - P2pSig crypto.SingleSigner - PubKeyString string - PrivKeyBytes []byte - PubKeyBytes []byte - BlockSig crypto.SingleSigner - TxSig crypto.SingleSigner - MultiSigContainer cryptoCommon.MultiSignerContainer - PeerSignHandler crypto.PeerSignatureHandler - BlKeyGen crypto.KeyGenerator - TxKeyGen crypto.KeyGenerator - P2PKeyGen crypto.KeyGenerator - MsgSigVerifier 
vm.MessageSignVerifier - SigHandler consensus.SignatureHandler - mutMultiSig sync.RWMutex + PubKey crypto.PublicKey + PrivKey crypto.PrivateKey + P2pPubKey crypto.PublicKey + P2pPrivKey crypto.PrivateKey + PrivKeyBytes []byte + PubKeyBytes []byte + PubKeyString string + BlockSig crypto.SingleSigner + TxSig crypto.SingleSigner + P2pSig crypto.SingleSigner + MultiSigContainer cryptoCommon.MultiSignerContainer + PeerSignHandler crypto.PeerSignatureHandler + BlKeyGen crypto.KeyGenerator + TxKeyGen crypto.KeyGenerator + P2PKeyGen crypto.KeyGenerator + MsgSigVerifier vm.MessageSignVerifier + ManagedPeersHolderField common.ManagedPeersHolder + KeysHandlerField consensus.KeysHandler + SigHandler consensus.SigningHandler + mutMultiSig sync.RWMutex } // Create - @@ -154,30 +157,42 @@ func (ccs *CryptoComponentsStub) MessageSignVerifier() vm.MessageSignVerifier { return ccs.MsgSigVerifier } -// ConsensusSigHandler - -func (ccs *CryptoComponentsStub) ConsensusSigHandler() consensus.SignatureHandler { +// ConsensusSigningHandler - +func (ccs *CryptoComponentsStub) ConsensusSigningHandler() consensus.SigningHandler { return ccs.SigHandler } +// ManagedPeersHolder - +func (ccs *CryptoComponentsStub) ManagedPeersHolder() common.ManagedPeersHolder { + return ccs.ManagedPeersHolderField +} + +// KeysHandler - +func (ccs *CryptoComponentsStub) KeysHandler() consensus.KeysHandler { + return ccs.KeysHandlerField +} + // Clone - func (ccs *CryptoComponentsStub) Clone() interface{} { return &CryptoComponentsStub{ - PubKey: ccs.PubKey, - P2pPubKey: ccs.P2pPubKey, - PrivKey: ccs.PrivKey, - P2pPrivKey: ccs.P2pPrivKey, - PubKeyString: ccs.PubKeyString, - PrivKeyBytes: ccs.PrivKeyBytes, - PubKeyBytes: ccs.PubKeyBytes, - BlockSig: ccs.BlockSig, - TxSig: ccs.TxSig, - MultiSigContainer: ccs.MultiSigContainer, - PeerSignHandler: ccs.PeerSignHandler, - BlKeyGen: ccs.BlKeyGen, - TxKeyGen: ccs.TxKeyGen, - P2PKeyGen: ccs.P2PKeyGen, - MsgSigVerifier: ccs.MsgSigVerifier, - mutMultiSig: sync.RWMutex{}, + PubKey: ccs.PubKey, + P2pPubKey: ccs.P2pPubKey, + PrivKey: ccs.PrivKey, + P2pPrivKey: ccs.P2pPrivKey, + PubKeyString: ccs.PubKeyString, + PrivKeyBytes: ccs.PrivKeyBytes, + PubKeyBytes: ccs.PubKeyBytes, + BlockSig: ccs.BlockSig, + TxSig: ccs.TxSig, + MultiSigContainer: ccs.MultiSigContainer, + PeerSignHandler: ccs.PeerSignHandler, + BlKeyGen: ccs.BlKeyGen, + TxKeyGen: ccs.TxKeyGen, + P2PKeyGen: ccs.P2PKeyGen, + MsgSigVerifier: ccs.MsgSigVerifier, + ManagedPeersHolderField: ccs.ManagedPeersHolderField, + KeysHandlerField: ccs.KeysHandlerField, + mutMultiSig: sync.RWMutex{}, } } diff --git a/integrationTests/multiShard/block/interceptedHeadersSigVerification/interceptedHeadersSigVerification_test.go b/integrationTests/multiShard/block/interceptedHeadersSigVerification/interceptedHeadersSigVerification_test.go index ada034da39b..82eca349947 100644 --- a/integrationTests/multiShard/block/interceptedHeadersSigVerification/interceptedHeadersSigVerification_test.go +++ b/integrationTests/multiShard/block/interceptedHeadersSigVerification/interceptedHeadersSigVerification_test.go @@ -61,7 +61,8 @@ func TestInterceptedShardBlockHeaderVerifiedWithCorrectConsensusGroup(t *testing header, err = fillHeaderFields(nodesMap[0][0], header, singleSigner) assert.Nil(t, err) - nodesMap[0][0].BroadcastBlock(body, header) + pk := nodesMap[0][0].NodeKeys.MainKey.Pk + nodesMap[0][0].BroadcastBlock(body, header, pk) time.Sleep(broadcastDelay) @@ -130,7 +131,8 @@ func TestInterceptedMetaBlockVerifiedWithCorrectConsensusGroup(t *testing.T) { 0, ) - 
nodesMap[core.MetachainShardId][0].BroadcastBlock(body, header) + pk := nodesMap[core.MetachainShardId][0].NodeKeys.MainKey.Pk + nodesMap[core.MetachainShardId][0].BroadcastBlock(body, header, pk) time.Sleep(broadcastDelay) @@ -203,7 +205,8 @@ func TestInterceptedShardBlockHeaderWithLeaderSignatureAndRandSeedChecks(t *test header, err = fillHeaderFields(nodeToSendFrom, header, singleSigner) assert.Nil(t, err) - nodeToSendFrom.BroadcastBlock(body, header) + pk := nodeToSendFrom.NodeKeys.MainKey.Pk + nodeToSendFrom.BroadcastBlock(body, header, pk) time.Sleep(broadcastDelay) @@ -267,7 +270,8 @@ func TestInterceptedShardHeaderBlockWithWrongPreviousRandSeedShouldNotBeAccepted nonce := uint64(2) body, header, _, _ := integrationTests.ProposeBlockWithConsensusSignature(0, nodesMap, round, nonce, wrongRandomness, 0) - nodesMap[0][0].BroadcastBlock(body, header) + pk := nodesMap[0][0].NodeKeys.MainKey.Pk + nodesMap[0][0].BroadcastBlock(body, header, pk) time.Sleep(broadcastDelay) @@ -288,7 +292,7 @@ func TestInterceptedShardHeaderBlockWithWrongPreviousRandSeedShouldNotBeAccepted } func fillHeaderFields(proposer *integrationTests.TestProcessorNode, hdr data.HeaderHandler, signer crypto.SingleSigner) (data.HeaderHandler, error) { - leaderSk := proposer.NodeKeys.Sk + leaderSk := proposer.NodeKeys.MainKey.Sk randSeed, _ := signer.Sign(leaderSk, hdr.GetPrevRandSeed()) err := hdr.SetRandSeed(randSeed) diff --git a/integrationTests/multiShard/endOfEpoch/startInEpoch/startInEpoch_test.go b/integrationTests/multiShard/endOfEpoch/startInEpoch/startInEpoch_test.go index 9a54cbaded2..d78a9d4145b 100644 --- a/integrationTests/multiShard/endOfEpoch/startInEpoch/startInEpoch_test.go +++ b/integrationTests/multiShard/endOfEpoch/startInEpoch/startInEpoch_test.go @@ -133,6 +133,7 @@ func testNodeStartsInEpoch(t *testing.T, shardID uint32, expectedHighestRound ui } for _, node := range nodes { _ = dataRetriever.SetEpochHandlerToHdrResolver(node.ResolversContainer, epochHandler) + _ = dataRetriever.SetEpochHandlerToHdrRequester(node.RequestersContainer, epochHandler) } generalConfig := getGeneralConfig() @@ -205,7 +206,7 @@ func testNodeStartsInEpoch(t *testing.T, shardID uint32, expectedHighestRound ui roundHandler := &mock.RoundHandlerMock{IndexField: int64(round)} cryptoComponents := integrationTests.GetDefaultCryptoComponents() - cryptoComponents.PubKey = nodeToJoinLate.NodeKeys.Pk + cryptoComponents.PubKey = nodeToJoinLate.NodeKeys.MainKey.Pk cryptoComponents.BlockSig = &mock.SignerMock{} cryptoComponents.TxSig = &mock.SignerMock{} cryptoComponents.BlKeyGen = &mock.KeyGenMock{} @@ -285,6 +286,8 @@ func testNodeStartsInEpoch(t *testing.T, shardID uint32, expectedHighestRound ui CurrentEpoch: 0, StorageType: factory.ProcessStorageService, CreateTrieEpochRootHashStorer: false, + SnapshotsEnabled: false, + ManagedPeersHolder: &testscommon.ManagedPeersHolderStub{}, }, ) assert.NoError(t, err) diff --git a/integrationTests/node/heartbeatV2/heartbeatV2_test.go b/integrationTests/node/heartbeatV2/heartbeatV2_test.go index e6398cc2732..2b9051534df 100644 --- a/integrationTests/node/heartbeatV2/heartbeatV2_test.go +++ b/integrationTests/node/heartbeatV2/heartbeatV2_test.go @@ -114,7 +114,7 @@ func TestHeartbeatV2_PeerAuthenticationMessageExpiration(t *testing.T) { requestHashes := make([][]byte, 0) for i := 1; i < len(nodes); i++ { - pkBytes, err := nodes[i].NodeKeys.Pk.ToByteArray() + pkBytes, err := nodes[i].NodeKeys.MainKey.Pk.ToByteArray() assert.Nil(t, err) requestHashes = append(requestHashes, pkBytes) @@ -151,7 
+151,7 @@ func checkMessages(t *testing.T, nodes []*integrationTests.TestHeartbeatNode, ma // Check this node received messages from all peers for _, node := range nodes { - pkBytes, err := node.NodeKeys.Pk.ToByteArray() + pkBytes, err := node.NodeKeys.MainKey.Pk.ToByteArray() assert.Nil(t, err) assert.True(t, paCache.Has(pkBytes)) diff --git a/integrationTests/nodesCoordinatorFactory.go b/integrationTests/nodesCoordinatorFactory.go index e96dad46080..543f7966595 100644 --- a/integrationTests/nodesCoordinatorFactory.go +++ b/integrationTests/nodesCoordinatorFactory.go @@ -39,8 +39,8 @@ type IndexHashedNodesCoordinatorFactory struct { // CreateNodesCoordinator - func (tpn *IndexHashedNodesCoordinatorFactory) CreateNodesCoordinator(arg ArgIndexHashedNodesCoordinatorFactory) nodesCoordinator.NodesCoordinator { - keys := arg.cp.Keys[arg.shardId][arg.keyIndex] - pubKeyBytes, _ := keys.Pk.ToByteArray() + keys := arg.cp.NodesKeys[arg.shardId][arg.keyIndex] + pubKeyBytes, _ := keys.MainKey.Pk.ToByteArray() nodeShufflerArgs := &nodesCoordinator.NodesShufflerArgs{ NodesShard: uint32(arg.nodesPerShard), @@ -93,8 +93,8 @@ type IndexHashedNodesCoordinatorWithRaterFactory struct { func (ihncrf *IndexHashedNodesCoordinatorWithRaterFactory) CreateNodesCoordinator( arg ArgIndexHashedNodesCoordinatorFactory, ) nodesCoordinator.NodesCoordinator { - keys := arg.cp.Keys[arg.shardId][arg.keyIndex] - pubKeyBytes, _ := keys.Pk.ToByteArray() + keys := arg.cp.NodesKeys[arg.shardId][arg.keyIndex] + pubKeyBytes, _ := keys.MainKey.Pk.ToByteArray() shufflerArgs := &nodesCoordinator.NodesShufflerArgs{ NodesShard: uint32(arg.nodesPerShard), diff --git a/integrationTests/singleShard/block/consensusNotAchieved/consensusNotAchieved_test.go b/integrationTests/singleShard/block/consensusNotAchieved/consensusNotAchieved_test.go index 59b0969475d..560e8f0ae74 100644 --- a/integrationTests/singleShard/block/consensusNotAchieved/consensusNotAchieved_test.go +++ b/integrationTests/singleShard/block/consensusNotAchieved/consensusNotAchieved_test.go @@ -78,7 +78,8 @@ func TestConsensus_BlockWithoutTwoThirdsPlusOneSignaturesOrWrongBitmapShouldNotB assert.NotNil(t, body) assert.NotNil(t, hdr) - nodesMap[0][0].BroadcastBlock(body, hdr) + pk := nodesMap[0][0].NodeKeys.MainKey.Pk + nodesMap[0][0].BroadcastBlock(body, hdr, pk) time.Sleep(testBlock.StepDelay) // the block should have not pass the interceptor @@ -95,7 +96,7 @@ func TestConsensus_BlockWithoutTwoThirdsPlusOneSignaturesOrWrongBitmapShouldNotB assert.NotNil(t, body) assert.NotNil(t, hdr) - nodesMap[0][0].BroadcastBlock(body, hdr) + nodesMap[0][0].BroadcastBlock(body, hdr, pk) time.Sleep(testBlock.StepDelay) // this block should have not passed the interceptor @@ -112,7 +113,7 @@ func TestConsensus_BlockWithoutTwoThirdsPlusOneSignaturesOrWrongBitmapShouldNotB assert.NotNil(t, body) assert.NotNil(t, hdr) - nodesMap[0][0].BroadcastBlock(body, hdr) + nodesMap[0][0].BroadcastBlock(body, hdr, pk) time.Sleep(testBlock.StepDelay) // this block should have passed the interceptor diff --git a/integrationTests/singleShard/block/executingMiniblocks/executingMiniblocks_test.go b/integrationTests/singleShard/block/executingMiniblocks/executingMiniblocks_test.go index 0f6bf0e1f4f..8301679001f 100644 --- a/integrationTests/singleShard/block/executingMiniblocks/executingMiniblocks_test.go +++ b/integrationTests/singleShard/block/executingMiniblocks/executingMiniblocks_test.go @@ -266,7 +266,8 @@ func proposeAndCommitBlock(node *integrationTests.TestProcessorNode, round uint6 return err } - 
node.BroadcastBlock(body, hdr) + pk := node.NodeKeys.MainKey.Pk + node.BroadcastBlock(body, hdr, pk) time.Sleep(testBlock.StepDelay) return nil } diff --git a/integrationTests/state/stateTrieSync/stateTrieSync_test.go b/integrationTests/state/stateTrieSync/stateTrieSync_test.go index 692faa4f77b..0cf62d8db64 100644 --- a/integrationTests/state/stateTrieSync/stateTrieSync_test.go +++ b/integrationTests/state/stateTrieSync/stateTrieSync_test.go @@ -101,15 +101,6 @@ func testNodeRequestInterceptTrieNodesWithMessenger(t *testing.T, version int) { _ = resolverTrie.Update([]byte(strconv.Itoa(i)), []byte(strconv.Itoa(i))) } - nodes := resolverTrie.GetNumNodes() - log.Info("trie nodes", - "total", nodes.Branches+nodes.Extensions+nodes.Leaves, - "branches", nodes.Branches, - "extensions", nodes.Extensions, - "leaves", nodes.Leaves, - "max level", nodes.MaxLevel, - ) - _ = resolverTrie.Commit() rootHash, _ := resolverTrie.RootHash() @@ -231,15 +222,6 @@ func testNodeRequestInterceptTrieNodesWithMessengerNotSyncingShouldErr(t *testin _ = resolverTrie.Update([]byte(strconv.Itoa(i)), []byte(strconv.Itoa(i))) } - nodes := resolverTrie.GetNumNodes() - log.Info("trie nodes", - "total", nodes.Branches+nodes.Extensions+nodes.Leaves, - "branches", nodes.Branches, - "extensions", nodes.Extensions, - "leaves", nodes.Leaves, - "max level", nodes.MaxLevel, - ) - _ = resolverTrie.Commit() rootHash, _ := resolverTrie.RootHash() diff --git a/integrationTests/sync/basicSync/basicSync_test.go b/integrationTests/sync/basicSync/basicSync_test.go index 9f51bf5e1a8..52cc2c7af79 100644 --- a/integrationTests/sync/basicSync/basicSync_test.go +++ b/integrationTests/sync/basicSync/basicSync_test.go @@ -168,7 +168,9 @@ func proposeBlockWithPubKeyBitmap(n *integrationTests.TestProcessorNode, round u if err != nil { log.Error("header.SetPubKeysBitmap", "error", err.Error()) } - n.BroadcastBlock(body, header) + + pk := n.NodeKeys.MainKey.Pk + n.BroadcastBlock(body, header, pk) n.CommitBlock(body, header) } diff --git a/integrationTests/testConsensusNode.go b/integrationTests/testConsensusNode.go index eb973d1744d..c4b6f89c673 100644 --- a/integrationTests/testConsensusNode.go +++ b/integrationTests/testConsensusNode.go @@ -24,9 +24,11 @@ import ( cryptoFactory "github.com/multiversx/mx-chain-go/factory/crypto" "github.com/multiversx/mx-chain-go/factory/peerSignatureHandler" "github.com/multiversx/mx-chain-go/integrationTests/mock" + "github.com/multiversx/mx-chain-go/keysManagement" "github.com/multiversx/mx-chain-go/node" "github.com/multiversx/mx-chain-go/ntp" "github.com/multiversx/mx-chain-go/p2p" + p2pFactory "github.com/multiversx/mx-chain-go/p2p/factory" "github.com/multiversx/mx-chain-go/process/factory" syncFork "github.com/multiversx/mx-chain-go/process/sync" "github.com/multiversx/mx-chain-go/sharding" @@ -52,11 +54,25 @@ const ( signatureSize = 48 publicKeySize = 96 maxShards = 1 - nodeShardId = 0 ) var testPubkeyConverter, _ = pubkeyConverter.NewHexPubkeyConverter(32) +// ArgsTestConsensusNode represents the arguments for the test consensus node constructor(s) +type ArgsTestConsensusNode struct { + ShardID uint32 + ConsensusSize int + RoundTime uint64 + ConsensusType string + NodeKeys *TestNodeKeys + EligibleMap map[uint32][]nodesCoordinator.Validator + WaitingMap map[uint32][]nodesCoordinator.Validator + KeyGen crypto.KeyGenerator + P2PKeyGen crypto.KeyGenerator + MultiSigner *cryptoMocks.MultisignerMock + StartTime int64 +} + // TestConsensusNode represents a structure used in integration tests used for consensus 
tests type TestConsensusNode struct { Node *node.Node @@ -67,32 +83,21 @@ type TestConsensusNode struct { BlockProcessor *mock.BlockProcessorMock RequestersFinder dataRetriever.RequestersFinder AccountsDB *state.AccountsDB - NodeKeys TestKeyPair - MultiSigner cryptoMocks.MultisignerMock + NodeKeys *TestKeyPair + MultiSigner *cryptoMocks.MultisignerMock } // NewTestConsensusNode returns a new TestConsensusNode -func NewTestConsensusNode( - shardID uint32, - consensusSize int, - roundTime uint64, - consensusType string, - nodeKeys TestKeyPair, - eligibleMap map[uint32][]nodesCoordinator.Validator, - waitingMap map[uint32][]nodesCoordinator.Validator, - keyGen crypto.KeyGenerator, - startTime int64, - multiSigner cryptoMocks.MultisignerMock, -) *TestConsensusNode { +func NewTestConsensusNode(args ArgsTestConsensusNode) *TestConsensusNode { - shardCoordinator, _ := sharding.NewMultiShardCoordinator(maxShards, shardID) + shardCoordinator, _ := sharding.NewMultiShardCoordinator(maxShards, args.ShardID) tcn := &TestConsensusNode{ - NodeKeys: nodeKeys, + NodeKeys: args.NodeKeys.MainKey, ShardCoordinator: shardCoordinator, - MultiSigner: multiSigner, + MultiSigner: args.MultiSigner, } - tcn.initNode(consensusSize, roundTime, consensusType, eligibleMap, waitingMap, keyGen, startTime) + tcn.initNode(args) return tcn } @@ -104,11 +109,12 @@ func CreateNodesWithTestConsensusNode( consensusSize int, roundTime uint64, consensusType string, + numKeysOnEachNode int, ) map[uint32][]*TestConsensusNode { nodes := make(map[uint32][]*TestConsensusNode, nodesPerShard) - cp := CreateCryptoParams(nodesPerShard, numMetaNodes, maxShards) - keysMap := PubKeysMapFromKeysMap(cp.Keys) + cp := CreateCryptoParams(nodesPerShard, numMetaNodes, maxShards, numKeysOnEachNode) + keysMap := PubKeysMapFromNodesKeysMap(cp.NodesKeys) validatorsMap := GenValidatorsFromPubKeys(keysMap, maxShards) eligibleMap, _ := nodesCoordinator.NodesInfoToValidators(validatorsMap) waitingMap := make(map[uint32][]nodesCoordinator.Validator) @@ -116,23 +122,27 @@ func CreateNodesWithTestConsensusNode( startTime := time.Now().Unix() testHasher := createHasher(consensusType) - multiSigner, _ := multisig.NewBLSMultisig(&mclMultiSig.BlsMultiSigner{Hasher: testHasher}, cp.KeyGen) - multiSignerMock := createCustomMultiSignerMock(multiSigner) - - for shardID := range cp.Keys { - for _, keysPair := range cp.Keys[shardID] { - tcn := NewTestConsensusNode( - shardID, - consensusSize, - roundTime, - consensusType, - *keysPair, - eligibleMap, - waitingMap, - cp.KeyGen, - startTime, - multiSignerMock, - ) + + for shardID := range cp.NodesKeys { + for _, keysPair := range cp.NodesKeys[shardID] { + multiSigner, _ := multisig.NewBLSMultisig(&mclMultiSig.BlsMultiSigner{Hasher: testHasher}, cp.KeyGen) + multiSignerMock := createCustomMultiSignerMock(multiSigner) + + args := ArgsTestConsensusNode{ + ShardID: shardID, + ConsensusSize: consensusSize, + RoundTime: roundTime, + ConsensusType: consensusType, + NodeKeys: keysPair, + EligibleMap: eligibleMap, + WaitingMap: waitingMap, + KeyGen: cp.KeyGen, + P2PKeyGen: cp.P2PKeyGen, + MultiSigner: multiSignerMock, + StartTime: startTime, + } + + tcn := NewTestConsensusNode(args) nodes[shardID] = append(nodes[shardID], tcn) connectableNodes[shardID] = append(connectableNodes[shardID], tcn) } @@ -145,8 +155,8 @@ func CreateNodesWithTestConsensusNode( return nodes } -func createCustomMultiSignerMock(multiSigner crypto.MultiSigner) cryptoMocks.MultisignerMock { - multiSignerMock := cryptoMocks.MultisignerMock{} +func 
createCustomMultiSignerMock(multiSigner crypto.MultiSigner) *cryptoMocks.MultisignerMock { + multiSignerMock := &cryptoMocks.MultisignerMock{} multiSignerMock.CreateSignatureShareCalled = func(privateKeyBytes, message []byte) ([]byte, error) { return multiSigner.CreateSignatureShare(privateKeyBytes, message) } @@ -163,22 +173,13 @@ func createCustomMultiSignerMock(multiSigner crypto.MultiSigner) cryptoMocks.Mul return multiSignerMock } -func (tcn *TestConsensusNode) initNode( - consensusSize int, - roundTime uint64, - consensusType string, - eligibleMap map[uint32][]nodesCoordinator.Validator, - waitingMap map[uint32][]nodesCoordinator.Validator, - keyGen crypto.KeyGenerator, - startTime int64, -) { - - testHasher := createHasher(consensusType) +func (tcn *TestConsensusNode) initNode(args ArgsTestConsensusNode) { + testHasher := createHasher(args.ConsensusType) epochStartRegistrationHandler := notifier.NewEpochStartSubscriptionHandler() consensusCache, _ := cache.NewLRUCache(10000) pkBytes, _ := tcn.NodeKeys.Pk.ToByteArray() - tcn.initNodesCoordinator(consensusSize, testHasher, epochStartRegistrationHandler, eligibleMap, waitingMap, pkBytes, consensusCache) + tcn.initNodesCoordinator(args.ConsensusSize, testHasher, epochStartRegistrationHandler, args.EligibleMap, args.WaitingMap, pkBytes, consensusCache) tcn.Messenger = CreateMessengerWithNoDiscovery() tcn.initBlockChain(testHasher) tcn.initBlockProcessor() @@ -187,20 +188,20 @@ func (tcn *TestConsensusNode) initNode( syncer.StartSyncingTime() roundHandler, _ := round.NewRound( - time.Unix(startTime, 0), + time.Unix(args.StartTime, 0), syncer.CurrentTime(), - time.Millisecond*time.Duration(roundTime), + time.Millisecond*time.Duration(args.RoundTime), syncer, 0) dataPool := dataRetrieverMock.CreatePoolsHolder(1, 0) argsNewMetaEpochStart := &metachain.ArgsNewMetaEpochStartTrigger{ - GenesisTime: time.Unix(startTime, 0), + GenesisTime: time.Unix(args.StartTime, 0), EpochStartNotifier: notifier.NewEpochStartSubscriptionHandler(), Settings: &config.EpochStartConfig{ MinRoundsBetweenEpochs: 1, - RoundsPerEpoch: 3, + RoundsPerEpoch: 1000, }, Epoch: 0, Storage: createTestStore(), @@ -215,15 +216,15 @@ func (tcn *TestConsensusNode) initNode( roundHandler, cache.NewTimeCache(time.Second), &mock.BlockTrackerStub{}, - startTime, + args.StartTime, ) tcn.initRequestersFinder() peerSigCache, _ := storageunit.NewCache(storageunit.CacheConfig{Type: storageunit.LRUCache, Capacity: 1000}) - peerSigHandler, _ := peerSignatureHandler.NewPeerSignatureHandler(peerSigCache, TestSingleBlsSigner, keyGen) + peerSigHandler, _ := peerSignatureHandler.NewPeerSignatureHandler(peerSigCache, TestSingleBlsSigner, args.KeyGen) - multiSigContainer := cryptoMocks.NewMultiSignerContainerMock(&tcn.MultiSigner) + multiSigContainer := cryptoMocks.NewMultiSignerContainerMock(tcn.MultiSigner) privKey := tcn.NodeKeys.Sk pubKey := tcn.NodeKeys.Sk.GeneratePublic() @@ -238,26 +239,54 @@ func (tcn *TestConsensusNode) initNode( coreComponents.ChainIdCalled = func() string { return string(ChainID) } - coreComponents.GenesisTimeField = time.Unix(startTime, 0) + coreComponents.GenesisTimeField = time.Unix(args.StartTime, 0) coreComponents.GenesisNodesSetupField = &testscommon.NodesSetupStub{ GetShardConsensusGroupSizeCalled: func() uint32 { - return uint32(consensusSize) + return uint32(args.ConsensusSize) }, GetMetaConsensusGroupSizeCalled: func() uint32 { - return uint32(consensusSize) + return uint32(args.ConsensusSize) }, } + argsKeysHolder := keysManagement.ArgsManagedPeersHolder{ + 
KeyGenerator: args.KeyGen, + P2PKeyGenerator: args.P2PKeyGen, + IsMainMachine: true, + MaxRoundsWithoutReceivedMessages: 10, + PrefsConfig: config.Preferences{}, + P2PKeyConverter: p2pFactory.NewP2PKeyConverter(), + } + keysHolder, _ := keysManagement.NewManagedPeersHolder(argsKeysHolder) + + // adding provided handled keys + for _, key := range args.NodeKeys.HandledKeys { + skBytes, _ := key.Sk.ToByteArray() + _ = keysHolder.AddManagedPeer(skBytes) + } + pubKeyBytes, _ := pubKey.ToByteArray() pubKeyString := coreComponents.ValidatorPubKeyConverterField.SilentEncode(pubKeyBytes, log) - privKeyBytes, _ := privKey.ToByteArray() - signatureHolderArgs := cryptoFactory.ArgsSignatureHolder{ + argsKeysHandler := keysManagement.ArgsKeysHandler{ + ManagedPeersHolder: keysHolder, + PrivateKey: tcn.NodeKeys.Sk, + Pid: tcn.Messenger.ID(), + } + keysHandler, _ := keysManagement.NewKeysHandler(argsKeysHandler) + + signingHandlerArgs := cryptoFactory.ArgsSigningHandler{ PubKeys: []string{pubKeyString}, - PrivKeyBytes: privKeyBytes, MultiSignerContainer: multiSigContainer, - KeyGenerator: keyGen, + KeyGenerator: args.KeyGen, + KeysHandler: keysHandler, + SingleSigner: TestSingleBlsSigner, } - sigHandler, _ := cryptoFactory.NewSignatureHolder(signatureHolderArgs) + sigHandler, _ := cryptoFactory.NewSigningHandler(signingHandlerArgs) + + networkComponents := GetDefaultNetworkComponents() + networkComponents.Messenger = tcn.Messenger + networkComponents.InputAntiFlood = &mock.NilAntifloodHandler{} + networkComponents.PeerHonesty = &mock.PeerHonestyHandlerStub{} cryptoComponents := GetDefaultCryptoComponents() cryptoComponents.PrivKey = privKey @@ -265,9 +294,10 @@ func (tcn *TestConsensusNode) initNode( cryptoComponents.BlockSig = TestSingleBlsSigner cryptoComponents.TxSig = TestSingleSigner cryptoComponents.MultiSigContainer = multiSigContainer - cryptoComponents.BlKeyGen = keyGen + cryptoComponents.BlKeyGen = args.KeyGen cryptoComponents.PeerSignHandler = peerSigHandler cryptoComponents.SigHandler = sigHandler + cryptoComponents.KeysHandlerField = keysHandler processComponents := GetDefaultProcessComponents() processComponents.ForkDetect = forkDetector @@ -296,11 +326,6 @@ func (tcn *TestConsensusNode) initNode( stateComponents.Accounts = tcn.AccountsDB stateComponents.AccountsAPI = tcn.AccountsDB - networkComponents := GetDefaultNetworkComponents() - networkComponents.Messenger = tcn.Messenger - networkComponents.InputAntiFlood = &mock.NilAntifloodHandler{} - networkComponents.PeerHonesty = &mock.PeerHonestyHandlerStub{} - statusCoreComponents := &testFactory.StatusCoreComponentsStub{ AppStatusHandlerField: &statusHandlerMock.AppStatusHandlerStub{}, } @@ -314,10 +339,10 @@ func (tcn *TestConsensusNode) initNode( node.WithDataComponents(dataComponents), node.WithStateComponents(stateComponents), node.WithNetworkComponents(networkComponents), - node.WithRoundDuration(roundTime), - node.WithConsensusGroupSize(consensusSize), - node.WithConsensusType(consensusType), - node.WithGenesisTime(time.Unix(startTime, 0)), + node.WithRoundDuration(args.RoundTime), + node.WithConsensusGroupSize(args.ConsensusSize), + node.WithConsensusType(args.ConsensusType), + node.WithGenesisTime(time.Unix(args.StartTime, 0)), node.WithValidatorSignatureSize(signatureSize), node.WithPublicKeySize(publicKeySize), ) @@ -357,6 +382,7 @@ func (tcn *TestConsensusNode) initNodesCoordinator( IsWaitingListFixFlagEnabledField: true, }, ValidatorInfoCacher: &vic.ValidatorInfoCacherStub{}, + ShardIDAsObserver: tcn.ShardCoordinator.SelfId(), } 
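// Illustrative sketch (not part of this patch): with NewTestConsensusNode refactored to take a
// single args struct, a test would now assemble a node roughly as below. The field names come
// from ArgsTestConsensusNode in this file; the concrete values (counts, round time, "bls"
// consensus type) are arbitrary example choices, not values mandated by the patch.
cpExample := CreateCryptoParams(1, 1, maxShards, 1)
keysMapExample := PubKeysMapFromNodesKeysMap(cpExample.NodesKeys)
validatorsMapExample := GenValidatorsFromPubKeys(keysMapExample, maxShards)
eligibleMapExample, _ := nodesCoordinator.NodesInfoToValidators(validatorsMapExample)

argsExample := ArgsTestConsensusNode{
	ShardID:       0,
	ConsensusSize: 1,
	RoundTime:     1000,
	ConsensusType: "bls", // example value
	NodeKeys:      cpExample.NodesKeys[0][0],
	EligibleMap:   eligibleMapExample,
	WaitingMap:    make(map[uint32][]nodesCoordinator.Validator),
	KeyGen:        cpExample.KeyGen,
	P2PKeyGen:     cpExample.P2PKeyGen,
	MultiSigner:   &cryptoMocks.MultisignerMock{},
	StartTime:     time.Now().Unix(),
}
exampleNode := NewTestConsensusNode(argsExample)
_ = exampleNode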
tcn.NodesCoordinator, _ = nodesCoordinator.NewIndexHashedNodesCoordinator(argumentsNodesCoordinator) diff --git a/integrationTests/testHeartbeatNode.go b/integrationTests/testHeartbeatNode.go index c2f160d2169..c56468b73f3 100644 --- a/integrationTests/testHeartbeatNode.go +++ b/integrationTests/testHeartbeatNode.go @@ -87,7 +87,7 @@ type TestHeartbeatNode struct { NodesCoordinator nodesCoordinator.NodesCoordinator PeerShardMapper process.NetworkShardingCollector Messenger p2p.Messenger - NodeKeys TestKeyPair + NodeKeys *TestNodeKeys DataPool dataRetriever.PoolsHolder Sender update.Closer PeerAuthInterceptor *interceptors.MultiDataInterceptor @@ -190,9 +190,11 @@ func NewTestHeartbeatNode( pkBytes, _ := pk.ToByteArray() thn.PeerShardMapper.UpdatePeerIDInfo(localId, pkBytes, shardCoordinator.SelfId()) - thn.NodeKeys = TestKeyPair{ - Sk: sk, - Pk: pk, + thn.NodeKeys = &TestNodeKeys{ + MainKey: &TestKeyPair{ + Sk: sk, + Pk: pk, + }, } // start a go routine in order to allow peers to connect first @@ -208,7 +210,7 @@ func NewTestHeartbeatNodeWithCoordinator( nodeShardId uint32, p2pConfig p2pConfig.P2PConfig, coordinator nodesCoordinator.NodesCoordinator, - keys TestKeyPair, + keys *TestNodeKeys, ) *TestHeartbeatNode { keygen := signing.NewKeyGenerator(mcl.NewSuiteBLS12()) singleSigner := singlesig.NewBlsSigner() @@ -278,8 +280,8 @@ func CreateNodesWithTestHeartbeatNode( p2pConfig p2pConfig.P2PConfig, ) map[uint32][]*TestHeartbeatNode { - cp := CreateCryptoParams(nodesPerShard, numMetaNodes, uint32(numShards)) - pubKeys := PubKeysMapFromKeysMap(cp.Keys) + cp := CreateCryptoParams(nodesPerShard, numMetaNodes, uint32(numShards), 1) + pubKeys := PubKeysMapFromNodesKeysMap(cp.NodesKeys) validatorsMap := GenValidatorsFromPubKeys(pubKeys, uint32(numShards)) validatorsForNodesCoordinator, _ := nodesCoordinator.NodesInfoToValidators(validatorsMap) nodesMap := make(map[uint32][]*TestHeartbeatNode) @@ -313,13 +315,13 @@ func CreateNodesWithTestHeartbeatNode( nodesList := make([]*TestHeartbeatNode, len(validatorList)) for i := range validatorList { - kp := cp.Keys[shardId][i] + kp := cp.NodesKeys[shardId][i] nodesList[i] = NewTestHeartbeatNodeWithCoordinator( uint32(numShards), shardId, p2pConfig, nodesCoordinatorInstance, - *kp, + kp, ) } nodesMap[shardId] = nodesList @@ -357,12 +359,15 @@ func CreateNodesWithTestHeartbeatNode( nodesCoordinatorInstance, err := nodesCoordinator.NewIndexHashedNodesCoordinator(argumentsNodesCoordinator) log.LogIfError(err) + nodeKeysInstance := &TestNodeKeys{ + MainKey: createCryptoPair(), + } n := NewTestHeartbeatNodeWithCoordinator( uint32(numShards), shardId, p2pConfig, nodesCoordinatorInstance, - createCryptoPair(), + nodeKeysInstance, ) nodesMap[shardId] = append(nodesMap[shardId], n) @@ -410,18 +415,21 @@ func (thn *TestHeartbeatNode) initSender() { Marshaller: TestMarshaller, PeerAuthenticationTopic: common.PeerAuthenticationTopic, HeartbeatTopic: identifierHeartbeat, + BaseVersionNumber: "v01-base", VersionNumber: "v01", NodeDisplayName: defaultNodeName, Identity: defaultNodeName + "_identity", PeerSubType: core.RegularPeer, CurrentBlockProvider: &testscommon.ChainHandlerStub{}, PeerSignatureHandler: thn.PeerSigHandler, - PrivateKey: thn.NodeKeys.Sk, + PrivateKey: thn.NodeKeys.MainKey.Sk, RedundancyHandler: &mock.RedundancyHandlerStub{}, NodesCoordinator: thn.NodesCoordinator, HardforkTrigger: &testscommon.HardforkTriggerStub{}, HardforkTriggerPubKey: []byte(providedHardforkPubKey), PeerTypeProvider: &mock.PeerTypeProviderStub{}, + ManagedPeersHolder: 
&testscommon.ManagedPeersHolderStub{}, + ShardCoordinator: thn.ShardCoordinator, PeerAuthenticationTimeBetweenSends: timeBetweenPeerAuths, PeerAuthenticationTimeBetweenSendsWhenError: timeBetweenSendsWhenError, @@ -430,6 +438,7 @@ func (thn *TestHeartbeatNode) initSender() { HeartbeatTimeBetweenSendsWhenError: timeBetweenSendsWhenError, HeartbeatTimeThresholdBetweenSends: thresholdBetweenSends, HardforkTimeBetweenSends: timeBetweenHardforks, + PeerAuthenticationTimeBetweenChecks: time.Second * 2, } thn.Sender, _ = sender.NewSender(argsSender) @@ -718,7 +727,7 @@ func MakeDisplayTableForHeartbeatNodes(nodes map[uint32][]*TestHeartbeatNode) st for shardId, nodesList := range nodes { for _, n := range nodesList { - buffPk, _ := n.NodeKeys.Pk.ToByteArray() + buffPk, _ := n.NodeKeys.MainKey.Pk.ToByteArray() peerInfo := n.Messenger.GetConnectedPeersInfo() @@ -830,11 +839,11 @@ func (thn *TestHeartbeatNode) IsInterfaceNil() bool { return thn == nil } -func createCryptoPair() TestKeyPair { +func createCryptoPair() *TestKeyPair { suite := mcl.NewSuiteBLS12() keyGen := signing.NewKeyGenerator(suite) - kp := TestKeyPair{} + kp := &TestKeyPair{} kp.Sk, kp.Pk = keyGen.GeneratePair() return kp diff --git a/integrationTests/testInitializer.go b/integrationTests/testInitializer.go index ccbc67b3875..7ff1c51bdb1 100644 --- a/integrationTests/testInitializer.go +++ b/integrationTests/testInitializer.go @@ -27,6 +27,7 @@ import ( "github.com/multiversx/mx-chain-crypto-go/signing" "github.com/multiversx/mx-chain-crypto-go/signing/ed25519" "github.com/multiversx/mx-chain-crypto-go/signing/mcl" + "github.com/multiversx/mx-chain-crypto-go/signing/secp256k1" "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/dataRetriever" @@ -1102,7 +1103,8 @@ func ProposeBlock(nodes []*TestProcessorNode, idxProposers []int, round uint64, body, header, _ := n.ProposeBlock(round, nonce) n.WhiteListBody(nodes, body) - n.BroadcastBlock(body, header) + pk := n.NodeKeys.MainKey.Pk + n.BroadcastBlock(body, header, pk) n.CommitBlock(body, header) } @@ -1481,7 +1483,7 @@ func CreateNodesWithFullGenesis( NodeShardId: shardId, TxSignPrivKeyShardId: shardId, GenesisFile: genesisFile, - HardforkPk: hardforkStarter.NodeKeys.Pk, + HardforkPk: hardforkStarter.NodeKeys.MainKey.Pk, EpochsConfig: enableEpochsConfig, EconomicsConfig: economicsConfig, }) @@ -1497,7 +1499,7 @@ func CreateNodesWithFullGenesis( NodeShardId: core.MetachainShardId, TxSignPrivKeyShardId: 0, GenesisFile: genesisFile, - HardforkPk: hardforkStarter.NodeKeys.Pk, + HardforkPk: hardforkStarter.NodeKeys.MainKey.Pk, EpochsConfig: enableEpochsConfig, EconomicsConfig: economicsConfig, }) @@ -1988,7 +1990,8 @@ func ProposeBlockSignalsEmptyBlock( log.Info("Proposing block without commit...") body, header, txHashes := node.ProposeBlock(round, nonce) - node.BroadcastBlock(body, header) + pk := node.NodeKeys.MainKey.Pk + node.BroadcastBlock(body, header, pk) isEmptyBlock := len(txHashes) == 0 log.Info("Delaying for disseminating headers and miniblocks...") @@ -2184,8 +2187,8 @@ func ProposeAndSyncOneBlock( return round, nonce } -// PubKeysMapFromKeysMap returns a map of public keys per shard from the key pairs per shard map. -func PubKeysMapFromKeysMap(keyPairMap map[uint32][]*TestKeyPair) map[uint32][]string { +// PubKeysMapFromTxKeysMap returns a map of public keys per shard from the key pairs per shard map. 
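// Illustrative sketch (not part of this patch): the new numKeysOnEachNode parameter on
// CreateCryptoParams lets a test build either classic single-key nodes or multikey nodes that
// handle several BLS keys, and PubKeysMapFromNodesKeysMap flattens whichever keys a node
// exposes. The counts below are arbitrary example values.
cpSingle := CreateCryptoParams(3, 1, 2, 1) // 1 key per node -> only MainKey is populated
cpMulti := CreateCryptoParams(3, 1, 2, 4)  // 4 keys per node -> HandledKeys is populated as well

pubKeysSingle := PubKeysMapFromNodesKeysMap(cpSingle.NodesKeys) // one entry per node (its main key)
pubKeysMulti := PubKeysMapFromNodesKeysMap(cpMulti.NodesKeys)   // one entry per handled key
_, _ = pubKeysSingle, pubKeysMulti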
+func PubKeysMapFromTxKeysMap(keyPairMap map[uint32][]*TestKeyPair) map[uint32][]string { keysMap := make(map[uint32][]string) for shardId, pairList := range keyPairMap { @@ -2200,6 +2203,36 @@ func PubKeysMapFromKeysMap(keyPairMap map[uint32][]*TestKeyPair) map[uint32][]st return keysMap } +// PubKeysMapFromNodesKeysMap returns a map of public keys per shard from the key pairs per shard map. +func PubKeysMapFromNodesKeysMap(keyPairMap map[uint32][]*TestNodeKeys) map[uint32][]string { + keysMap := make(map[uint32][]string) + + for shardId, keys := range keyPairMap { + addAllKeysOnShard(keysMap, shardId, keys) + } + + return keysMap +} + +func addAllKeysOnShard(m map[uint32][]string, shardID uint32, keys []*TestNodeKeys) { + for _, keyOfTheNode := range keys { + addKeysToMap(m, shardID, keyOfTheNode) + } +} + +func addKeysToMap(m map[uint32][]string, shardID uint32, keysOfTheNode *TestNodeKeys) { + if len(keysOfTheNode.HandledKeys) == 0 { + b, _ := keysOfTheNode.MainKey.Pk.ToByteArray() + m[shardID] = append(m[shardID], string(b)) + return + } + + for _, handledKey := range keysOfTheNode.HandledKeys { + b, _ := handledKey.Pk.ToByteArray() + m[shardID] = append(m[shardID], string(b)) + } +} + // GenValidatorsFromPubKeys generates a map of validators per shard out of public keys map func GenValidatorsFromPubKeys(pubKeysMap map[uint32][]string, _ uint32) map[uint32][]nodesCoordinator.GenesisNodeInfoHandler { validatorsMap := make(map[uint32][]nodesCoordinator.GenesisNodeInfoHandler) @@ -2236,48 +2269,31 @@ func GenValidatorsFromPubKeysAndTxPubKeys( } // CreateCryptoParams generates the crypto parameters (key pairs, key generator and suite) for multiple nodes -func CreateCryptoParams(nodesPerShard int, nbMetaNodes int, nbShards uint32) *CryptoParams { +func CreateCryptoParams(nodesPerShard int, nbMetaNodes int, nbShards uint32, numKeysOnEachNode int) *CryptoParams { txSuite := ed25519.NewEd25519() txKeyGen := signing.NewKeyGenerator(txSuite) suite := mcl.NewSuiteBLS12() singleSigner := TestSingleSigner keyGen := signing.NewKeyGenerator(suite) + p2pSuite := secp256k1.NewSecp256k1() + p2pKeyGen := signing.NewKeyGenerator(p2pSuite) + nodesKeysMap := make(map[uint32][]*TestNodeKeys) txKeysMap := make(map[uint32][]*TestKeyPair) - keysMap := make(map[uint32][]*TestKeyPair) for shardId := uint32(0); shardId < nbShards; shardId++ { - txKeyPairs := make([]*TestKeyPair, nodesPerShard) - keyPairs := make([]*TestKeyPair, nodesPerShard) for n := 0; n < nodesPerShard; n++ { - kp := &TestKeyPair{} - kp.Sk, kp.Pk = keyGen.GeneratePair() - keyPairs[n] = kp - - txKp := &TestKeyPair{} - txKp.Sk, txKp.Pk = txKeyGen.GeneratePair() - txKeyPairs[n] = txKp + createAndAddKeys(keyGen, txKeyGen, shardId, nodesKeysMap, txKeysMap, numKeysOnEachNode) } - keysMap[shardId] = keyPairs - txKeysMap[shardId] = txKeyPairs } - txKeyPairs := make([]*TestKeyPair, nbMetaNodes) - keyPairs := make([]*TestKeyPair, nbMetaNodes) for n := 0; n < nbMetaNodes; n++ { - kp := &TestKeyPair{} - kp.Sk, kp.Pk = keyGen.GeneratePair() - keyPairs[n] = kp - - txKp := &TestKeyPair{} - txKp.Sk, txKp.Pk = txKeyGen.GeneratePair() - txKeyPairs[n] = txKp + createAndAddKeys(keyGen, txKeyGen, core.MetachainShardId, nodesKeysMap, txKeysMap, numKeysOnEachNode) } - keysMap[core.MetachainShardId] = keyPairs - txKeysMap[core.MetachainShardId] = txKeyPairs params := &CryptoParams{ - Keys: keysMap, + NodesKeys: nodesKeysMap, KeyGen: keyGen, + P2PKeyGen: p2pKeyGen, SingleSigner: singleSigner, TxKeyGen: txKeyGen, TxKeys: txKeysMap, @@ -2286,6 +2302,38 @@ func 
CreateCryptoParams(nodesPerShard int, nbMetaNodes int, nbShards uint32) *Cr return params } +func createAndAddKeys( + keyGen crypto.KeyGenerator, + txKeyGen crypto.KeyGenerator, + shardId uint32, + nodeKeysMap map[uint32][]*TestNodeKeys, + txKeysMap map[uint32][]*TestKeyPair, + numKeysOnEachNode int, +) { + kp := &TestKeyPair{} + kp.Sk, kp.Pk = keyGen.GeneratePair() + + txKp := &TestKeyPair{} + txKp.Sk, txKp.Pk = txKeyGen.GeneratePair() + + nodeKey := &TestNodeKeys{ + MainKey: kp, + } + + txKeysMap[shardId] = append(txKeysMap[shardId], txKp) + nodeKeysMap[shardId] = append(nodeKeysMap[shardId], nodeKey) + if numKeysOnEachNode == 1 { + return + } + + for i := 0; i < numKeysOnEachNode; i++ { + validatorKp := &TestKeyPair{} + validatorKp.Sk, validatorKp.Pk = keyGen.GeneratePair() + + nodeKey.HandledKeys = append(nodeKey.HandledKeys, validatorKp) + } +} + // CloseProcessorNodes closes the used TestProcessorNodes and advertiser func CloseProcessorNodes(nodes []*TestProcessorNode) { for _, n := range nodes { diff --git a/integrationTests/testProcessorNode.go b/integrationTests/testProcessorNode.go index 8ee9725f386..a242624c051 100644 --- a/integrationTests/testProcessorNode.go +++ b/integrationTests/testProcessorNode.go @@ -40,7 +40,7 @@ import ( "github.com/multiversx/mx-chain-go/consensus/spos/sposFactory" "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/dataRetriever/factory/containers" - "github.com/multiversx/mx-chain-go/dataRetriever/factory/requestersContainer" + requesterscontainer "github.com/multiversx/mx-chain-go/dataRetriever/factory/requestersContainer" "github.com/multiversx/mx-chain-go/dataRetriever/factory/resolverscontainer" "github.com/multiversx/mx-chain-go/dataRetriever/requestHandlers" "github.com/multiversx/mx-chain-go/dblookupext" @@ -231,10 +231,17 @@ type TestKeyPair struct { Pk crypto.PublicKey } -// CryptoParams holds crypto parametres +// TestNodeKeys will hold the main key along the handled keys of a node +type TestNodeKeys struct { + MainKey *TestKeyPair + HandledKeys []*TestKeyPair +} + +// CryptoParams holds crypto parameters type CryptoParams struct { KeyGen crypto.KeyGenerator - Keys map[uint32][]*TestKeyPair + P2PKeyGen crypto.KeyGenerator + NodesKeys map[uint32][]*TestNodeKeys SingleSigner crypto.SingleSigner TxKeyGen crypto.KeyGenerator TxKeys map[uint32][]*TestKeyPair @@ -264,7 +271,7 @@ type ArgTestProcessorNode struct { HardforkPk crypto.PublicKey GenesisFile string StateCheckpointModulus *IntWrapper - NodeKeys *TestKeyPair + NodeKeys *TestNodeKeys NodesSetup sharding.GenesisNodesSetupHandler NodesCoordinator nodesCoordinator.NodesCoordinator MultiSigner crypto.MultiSigner @@ -285,7 +292,7 @@ type TestProcessorNode struct { Messenger p2p.Messenger OwnAccount *TestWalletAccount - NodeKeys *TestKeyPair + NodeKeys *TestNodeKeys ExportFolder string DataPool dataRetriever.PoolsHolder @@ -454,8 +461,11 @@ func newBaseTestProcessorNode(args ArgTestProcessorNode) *TestProcessorNode { tpn.NodeKeys = args.NodeKeys if tpn.NodeKeys == nil { kg := &mock.KeyGenMock{} - tpn.NodeKeys = &TestKeyPair{} - tpn.NodeKeys.Sk, tpn.NodeKeys.Pk = kg.GeneratePair() + kp := &TestKeyPair{} + kp.Sk, kp.Pk = kg.GeneratePair() + tpn.NodeKeys = &TestNodeKeys{ + MainKey: kp, + } } tpn.MultiSigner = TestMultiSig @@ -712,11 +722,14 @@ func (tpn *TestProcessorNode) initTestNodeWithArgs(args ArgTestProcessorNode) { TestHasher, tpn.Messenger, tpn.ShardCoordinator, - tpn.OwnAccount.SkTxSign, tpn.OwnAccount.PeerSigHandler, tpn.DataPool.Headers(), 
tpn.InterceptorsContainer, &testscommon.AlarmSchedulerStub{}, + testscommon.NewKeysHandlerSingleSignerMock( + tpn.NodeKeys.MainKey.Sk, + tpn.Messenger.ID(), + ), ) if args.WithSync { @@ -893,11 +906,14 @@ func (tpn *TestProcessorNode) InitializeProcessors(gasMap map[string]map[string] TestHasher, tpn.Messenger, tpn.ShardCoordinator, - tpn.OwnAccount.SkTxSign, tpn.OwnAccount.PeerSigHandler, tpn.DataPool.Headers(), tpn.InterceptorsContainer, &testscommon.AlarmSchedulerStub{}, + testscommon.NewKeysHandlerSingleSignerMock( + tpn.NodeKeys.MainKey.Sk, + tpn.Messenger.ID(), + ), ) tpn.setGenesisBlock() tpn.initNode() @@ -1236,7 +1252,7 @@ func (tpn *TestProcessorNode) initInterceptors(heartbeatPk string) { } func (tpn *TestProcessorNode) createHardforkTrigger(heartbeatPk string) []byte { - pkBytes, _ := tpn.NodeKeys.Pk.ToByteArray() + pkBytes, _ := tpn.NodeKeys.MainKey.Pk.ToByteArray() argHardforkTrigger := trigger.ArgHardforkTrigger{ TriggerPubKeyBytes: pkBytes, Enabled: true, @@ -1971,7 +1987,6 @@ func (tpn *TestProcessorNode) initBlockProcessor(stateCheckpointModulus uint) { triesConfig := config.Config{ StateTriesConfig: config.StateTriesConfig{ - SnapshotsEnabled: true, CheckpointRoundsModulus: stateCheckpointModulus, }, } @@ -2262,8 +2277,8 @@ func (tpn *TestProcessorNode) initNode() { processComponents.HardforkTriggerField = tpn.HardforkTrigger cryptoComponents := GetDefaultCryptoComponents() - cryptoComponents.PrivKey = tpn.NodeKeys.Sk - cryptoComponents.PubKey = tpn.NodeKeys.Pk + cryptoComponents.PrivKey = tpn.NodeKeys.MainKey.Sk + cryptoComponents.PubKey = tpn.NodeKeys.MainKey.Pk cryptoComponents.TxSig = tpn.OwnAccount.SingleSigner cryptoComponents.BlockSig = tpn.OwnAccount.SingleSigner cryptoComponents.MultiSigContainer = cryptoMocks.NewMultiSignerContainerMock(tpn.MultiSigner) @@ -2520,14 +2535,16 @@ func (tpn *TestProcessorNode) ProposeBlock(round uint64, nonce uint64) (data.Bod } // BroadcastBlock broadcasts the block and body to the connected peers -func (tpn *TestProcessorNode) BroadcastBlock(body data.BodyHandler, header data.HeaderHandler) { +func (tpn *TestProcessorNode) BroadcastBlock(body data.BodyHandler, header data.HeaderHandler, publicKey crypto.PublicKey) { _ = tpn.BroadcastMessenger.BroadcastBlock(body, header) time.Sleep(tpn.WaitTime) + pkBytes, _ := publicKey.ToByteArray() + miniBlocks, transactions, _ := tpn.BlockProcessor.MarshalizedDataToBroadcast(header, body) - _ = tpn.BroadcastMessenger.BroadcastMiniBlocks(miniBlocks) - _ = tpn.BroadcastMessenger.BroadcastTransactions(transactions) + _ = tpn.BroadcastMessenger.BroadcastMiniBlocks(miniBlocks, pkBytes) + _ = tpn.BroadcastMessenger.BroadcastTransactions(transactions, pkBytes) } // WhiteListBody will whitelist all miniblocks from the given body for all the given nodes @@ -2816,8 +2833,8 @@ func (tpn *TestProcessorNode) createHeartbeatWithHardforkTrigger() { log.LogIfError(err) cryptoComponents := GetDefaultCryptoComponents() - cryptoComponents.PrivKey = tpn.NodeKeys.Sk - cryptoComponents.PubKey = tpn.NodeKeys.Pk + cryptoComponents.PrivKey = tpn.NodeKeys.MainKey.Sk + cryptoComponents.PubKey = tpn.NodeKeys.MainKey.Pk cryptoComponents.TxSig = tpn.OwnAccount.SingleSigner cryptoComponents.BlockSig = tpn.OwnAccount.SingleSigner cryptoComponents.MultiSigContainer = cryptoMocks.NewMultiSignerContainerMock(tpn.MultiSigner) @@ -2886,6 +2903,7 @@ func (tpn *TestProcessorNode) createHeartbeatWithHardforkTrigger() { HideInactiveValidatorIntervalInSec: 60, HardforkTimeBetweenSendsInSec: 2, 
TimeBetweenConnectionsMetricsUpdateInSec: 10, + PeerAuthenticationTimeBetweenChecksInSec: 1, HeartbeatPool: config.CacheConfig{ Type: "LRU", Capacity: 1000, @@ -3093,18 +3111,20 @@ func GetDefaultDataComponents() *mock.DataComponentsStub { // GetDefaultCryptoComponents - func GetDefaultCryptoComponents() *mock.CryptoComponentsStub { return &mock.CryptoComponentsStub{ - PubKey: &mock.PublicKeyMock{}, - PrivKey: &mock.PrivateKeyMock{}, - PubKeyString: "pubKey", - PrivKeyBytes: []byte("privKey"), - PubKeyBytes: []byte("pubKey"), - BlockSig: &mock.SignerMock{}, - TxSig: &mock.SignerMock{}, - MultiSigContainer: cryptoMocks.NewMultiSignerContainerMock(TestMultiSig), - PeerSignHandler: &mock.PeerSignatureHandler{}, - BlKeyGen: &mock.KeyGenMock{}, - TxKeyGen: &mock.KeyGenMock{}, - MsgSigVerifier: &testscommon.MessageSignVerifierMock{}, + PubKey: &mock.PublicKeyMock{}, + PrivKey: &mock.PrivateKeyMock{}, + PubKeyString: "pubKey", + PrivKeyBytes: []byte("privKey"), + PubKeyBytes: []byte("pubKey"), + BlockSig: &mock.SignerMock{}, + TxSig: &mock.SignerMock{}, + MultiSigContainer: cryptoMocks.NewMultiSignerContainerMock(TestMultiSig), + PeerSignHandler: &mock.PeerSignatureHandler{}, + BlKeyGen: &mock.KeyGenMock{}, + TxKeyGen: &mock.KeyGenMock{}, + MsgSigVerifier: &testscommon.MessageSignVerifierMock{}, + ManagedPeersHolderField: &testscommon.ManagedPeersHolderStub{}, + KeysHandlerField: &testscommon.KeysHandlerStub{}, } } diff --git a/integrationTests/testProcessorNodeWithCoordinator.go b/integrationTests/testProcessorNodeWithCoordinator.go index 843c2001dcb..7ab761e960f 100644 --- a/integrationTests/testProcessorNodeWithCoordinator.go +++ b/integrationTests/testProcessorNodeWithCoordinator.go @@ -37,13 +37,13 @@ func CreateProcessorNodesWithNodesCoordinator( ) (map[uint32][]*TestProcessorNode, uint32) { ncp, nbShards := createNodesCryptoParams(rewardsAddrsAssignments) - cp := CreateCryptoParams(len(ncp[0]), len(ncp[core.MetachainShardId]), nbShards) - pubKeys := PubKeysMapFromKeysMap(cp.Keys) + cp := CreateCryptoParams(len(ncp[0]), len(ncp[core.MetachainShardId]), nbShards, 1) + pubKeys := PubKeysMapFromNodesKeysMap(cp.NodesKeys) validatorsMap := GenValidatorsFromPubKeys(pubKeys, nbShards) validatorsMapForNodesCoordinator, _ := nodesCoordinator.NodesInfoToValidators(validatorsMap) - cpWaiting := CreateCryptoParams(1, 1, nbShards) - pubKeysWaiting := PubKeysMapFromKeysMap(cpWaiting.Keys) + cpWaiting := CreateCryptoParams(1, 1, nbShards, 1) + pubKeysWaiting := PubKeysMapFromNodesKeysMap(cpWaiting.NodesKeys) waitingMap := GenValidatorsFromPubKeys(pubKeysWaiting, nbShards) waitingMapForNodesCoordinator, _ := nodesCoordinator.NodesInfoToValidators(waitingMap) @@ -107,9 +107,11 @@ func CreateProcessorNodesWithNodesCoordinator( MaxShards: numShards, NodeShardId: shardId, TxSignPrivKeyShardId: shardId, - NodeKeys: &TestKeyPair{ - Sk: kp.BlockSignSk, - Pk: kp.BlockSignPk, + NodeKeys: &TestNodeKeys{ + MainKey: &TestKeyPair{ + Sk: kp.BlockSignSk, + Pk: kp.BlockSignPk, + }, }, NodesSetup: nodesSetup, NodesCoordinator: nodesCoordinatorInstance, diff --git a/integrationTests/testProcessorNodeWithMultisigner.go b/integrationTests/testProcessorNodeWithMultisigner.go index 6eac46681d2..16011396d99 100644 --- a/integrationTests/testProcessorNodeWithMultisigner.go +++ b/integrationTests/testProcessorNodeWithMultisigner.go @@ -70,9 +70,9 @@ func CreateNodesWithNodesCoordinatorAndTxKeys( coordinatorFactory := &IndexHashedNodesCoordinatorWithRaterFactory{ PeerAccountListAndRatingHandler: rater, } - cp := 
CreateCryptoParams(nodesPerShard, nbMetaNodes, uint32(nbShards)) - blsPubKeys := PubKeysMapFromKeysMap(cp.Keys) - txPubKeys := PubKeysMapFromKeysMap(cp.TxKeys) + cp := CreateCryptoParams(nodesPerShard, nbMetaNodes, uint32(nbShards), 1) + blsPubKeys := PubKeysMapFromNodesKeysMap(cp.NodesKeys) + txPubKeys := PubKeysMapFromTxKeysMap(cp.TxKeys) validatorsMap := GenValidatorsFromPubKeysAndTxPubKeys(blsPubKeys, txPubKeys) validatorsMapForNodesCoordinator, _ := nodesCoordinator.NodesInfoToValidators(validatorsMap) @@ -208,13 +208,13 @@ func CreateNodesWithNodesCoordinatorFactory( metaConsensusGroupSize int, nodesCoordinatorFactory NodesCoordinatorFactory, ) map[uint32][]*TestProcessorNode { - cp := CreateCryptoParams(nodesPerShard, nbMetaNodes, uint32(nbShards)) - pubKeys := PubKeysMapFromKeysMap(cp.Keys) + cp := CreateCryptoParams(nodesPerShard, nbMetaNodes, uint32(nbShards), 1) + pubKeys := PubKeysMapFromNodesKeysMap(cp.NodesKeys) validatorsMap := GenValidatorsFromPubKeys(pubKeys, uint32(nbShards)) validatorsMapForNodesCoordinator, _ := nodesCoordinator.NodesInfoToValidators(validatorsMap) - cpWaiting := CreateCryptoParams(1, 1, uint32(nbShards)) - pubKeysWaiting := PubKeysMapFromKeysMap(cpWaiting.Keys) + cpWaiting := CreateCryptoParams(1, 1, uint32(nbShards), 1) + pubKeysWaiting := PubKeysMapFromNodesKeysMap(cpWaiting.NodesKeys) waitingMap := GenValidatorsFromPubKeys(pubKeysWaiting, uint32(nbShards)) waitingMapForNodesCoordinator, _ := nodesCoordinator.NodesInfoToValidators(waitingMap) @@ -355,7 +355,7 @@ func CreateNode( NodeShardId: shardId, TxSignPrivKeyShardId: txSignPrivKeyShardId, EpochsConfig: &epochsConfig, - NodeKeys: cp.Keys[shardId][keyIndex], + NodeKeys: cp.NodesKeys[shardId][keyIndex], NodesSetup: nodesSetup, NodesCoordinator: nodesCoordinatorInstance, RatingsData: ratingsData, @@ -385,8 +385,8 @@ func CreateNodesWithNodesCoordinatorAndHeaderSigVerifier( signer crypto.SingleSigner, keyGen crypto.KeyGenerator, ) map[uint32][]*TestProcessorNode { - cp := CreateCryptoParams(nodesPerShard, nbMetaNodes, uint32(nbShards)) - pubKeys := PubKeysMapFromKeysMap(cp.Keys) + cp := CreateCryptoParams(nodesPerShard, nbMetaNodes, uint32(nbShards), 1) + pubKeys := PubKeysMapFromNodesKeysMap(cp.NodesKeys) validatorsMap := GenValidatorsFromPubKeys(pubKeys, uint32(nbShards)) validatorsMapForNodesCoordinator, _ := nodesCoordinator.NodesInfoToValidators(validatorsMap) nodesMap := make(map[uint32][]*TestProcessorNode) @@ -456,9 +456,9 @@ func CreateNodesWithNodesCoordinatorAndHeaderSigVerifier( } for i := range validatorList { - multiSigner, err := createMultiSigner(*cp) - if err != nil { - log.Error("error generating multisigner: %s\n", err) + multiSigner, errCreate := createMultiSigner(*cp) + if errCreate != nil { + log.Error("error generating multisigner: %s\n", errCreate) return nil } @@ -471,7 +471,7 @@ func CreateNodesWithNodesCoordinatorAndHeaderSigVerifier( ScheduledMiniBlocksEnableEpoch: UnreachableEpoch, MiniBlockPartialExecutionEnableEpoch: UnreachableEpoch, }, - NodeKeys: cp.Keys[shardId][i], + NodeKeys: cp.NodesKeys[shardId][i], NodesSetup: nodesSetup, NodesCoordinator: nodesCoordinatorInstance, MultiSigner: multiSigner, @@ -502,13 +502,13 @@ func CreateNodesWithNodesCoordinatorKeygenAndSingleSigner( singleSigner crypto.SingleSigner, keyGenForBlocks crypto.KeyGenerator, ) map[uint32][]*TestProcessorNode { - cp := CreateCryptoParams(nodesPerShard, nbMetaNodes, uint32(nbShards)) - pubKeys := PubKeysMapFromKeysMap(cp.Keys) + cp := CreateCryptoParams(nodesPerShard, nbMetaNodes, uint32(nbShards), 
1) + pubKeys := PubKeysMapFromNodesKeysMap(cp.NodesKeys) validatorsMap := GenValidatorsFromPubKeys(pubKeys, uint32(nbShards)) validatorsMapForNodesCoordinator, _ := nodesCoordinator.NodesInfoToValidators(validatorsMap) - cpWaiting := CreateCryptoParams(2, 2, uint32(nbShards)) - pubKeysWaiting := PubKeysMapFromKeysMap(cpWaiting.Keys) + cpWaiting := CreateCryptoParams(2, 2, uint32(nbShards), 1) + pubKeysWaiting := PubKeysMapFromNodesKeysMap(cpWaiting.NodesKeys) waitingMap := GenValidatorsFromPubKeys(pubKeysWaiting, uint32(nbShards)) waitingMapForNodesCoordinator, _ := nodesCoordinator.NodesInfoToValidators(waitingMap) @@ -580,9 +580,9 @@ func CreateNodesWithNodesCoordinatorKeygenAndSingleSigner( headerSig, _ := headerCheck.NewHeaderSigVerifier(&args) - multiSigner, err := createMultiSigner(*cp) - if err != nil { - log.Error("error generating multisigner: %s\n", err) + multiSigner, errCreate := createMultiSigner(*cp) + if errCreate != nil { + log.Error("error generating multisigner: %s\n", errCreate) return nil } @@ -595,7 +595,7 @@ func CreateNodesWithNodesCoordinatorKeygenAndSingleSigner( ScheduledMiniBlocksEnableEpoch: UnreachableEpoch, MiniBlockPartialExecutionEnableEpoch: UnreachableEpoch, }, - NodeKeys: cp.Keys[shardId][i], + NodeKeys: cp.NodesKeys[shardId][i], NodesSetup: nodesSetup, NodesCoordinator: nodesCoord, MultiSigner: multiSigner, @@ -652,7 +652,7 @@ func selectTestNodesForPubKeys(nodes []*TestProcessorNode, pubKeys []string) []* for i, pk := range pubKeys { for _, node := range nodes { - pubKeyBytes, _ := node.NodeKeys.Pk.ToByteArray() + pubKeyBytes, _ := node.NodeKeys.MainKey.Pk.ToByteArray() if bytes.Equal(pubKeyBytes, []byte(pk)) { selectedNodes[i] = node cntNodes++ @@ -704,7 +704,7 @@ func DoConsensusSigningOnBlock( for i := 0; i < len(consensusNodes); i++ { pubKeysBytes[i] = []byte(pubKeys[i]) - sk, _ := consensusNodes[i].NodeKeys.Sk.ToByteArray() + sk, _ := consensusNodes[i].NodeKeys.MainKey.Sk.ToByteArray() sigShares[i], _ = msig.CreateSignatureShare(sk, blockHeaderHash) } @@ -767,7 +767,8 @@ func AllShardsProposeBlock( // propagate blocks for i := range nodesMap { - consensusNodes[i][0].BroadcastBlock(body[i], header[i]) + pk := consensusNodes[i][0].NodeKeys.MainKey.Pk + consensusNodes[i][0].BroadcastBlock(body[i], header[i], pk) consensusNodes[i][0].CommitBlock(body[i], header[i]) } diff --git a/integrationTests/testSyncNode.go b/integrationTests/testSyncNode.go index 3243b21f81d..8b2b72d5419 100644 --- a/integrationTests/testSyncNode.go +++ b/integrationTests/testSyncNode.go @@ -63,7 +63,6 @@ func (tpn *TestProcessorNode) initBlockProcessorWithSync() { triesConfig := config.Config{ StateTriesConfig: config.StateTriesConfig{ - SnapshotsEnabled: true, CheckpointRoundsModulus: stateCheckpointModulus, }, } diff --git a/integrationTests/vm/testInitializer.go b/integrationTests/vm/testInitializer.go index e5fb473ac9f..1acb1994d02 100644 --- a/integrationTests/vm/testInitializer.go +++ b/integrationTests/vm/testInitializer.go @@ -478,7 +478,7 @@ func CreateOneSCExecutorMockVM(accnts state.AccountsAdapter) vmcommon.VMExecutio ConfigSCStorage: *defaultStorageConfig(), EpochNotifier: &epochNotifier.EpochNotifierStub{}, EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{}, - GasSchedule: createMockGasScheduleNotifier(), + GasSchedule: CreateMockGasScheduleNotifier(), Counter: &testscommon.BlockChainHookCounterStub{}, } blockChainHook, _ := hooks.NewBlockChainHookImpl(args) @@ -1033,7 +1033,8 @@ func CreatePreparedTxProcessorAndAccountsWithVMs( }, nil } -func 
createMockGasScheduleNotifier() *mock.GasScheduleNotifierMock { +// CreateMockGasScheduleNotifier will create a mock gas schedule notifier to be used in tests +func CreateMockGasScheduleNotifier() *mock.GasScheduleNotifierMock { return createMockGasScheduleNotifierWithCustomGasSchedule(func(gasMap wasmConfig.GasScheduleMap) {}) } @@ -1067,7 +1068,7 @@ func CreatePreparedTxProcessorWithVMsWithShardCoordinator(enableEpochsConfig con enableEpochsConfig, shardCoordinator, integrationtests.CreateMemUnit(), - createMockGasScheduleNotifier(), + CreateMockGasScheduleNotifier(), ) } diff --git a/integrationTests/vm/txsFee/asyncCall_test.go b/integrationTests/vm/txsFee/asyncCall_test.go index d7afad10ad0..20c635f0a94 100644 --- a/integrationTests/vm/txsFee/asyncCall_test.go +++ b/integrationTests/vm/txsFee/asyncCall_test.go @@ -9,20 +9,28 @@ import ( "encoding/hex" "fmt" "math/big" + "strings" "testing" "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/data/scheduled" + "github.com/multiversx/mx-chain-core-go/data/smartContractResult" "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/integrationTests/vm" "github.com/multiversx/mx-chain-go/integrationTests/vm/txsFee/utils" + "github.com/multiversx/mx-chain-go/integrationTests/vm/wasm" "github.com/multiversx/mx-chain-go/process" + "github.com/multiversx/mx-chain-go/sharding" + "github.com/multiversx/mx-chain-go/testscommon/integrationtests" vmcommon "github.com/multiversx/mx-chain-vm-common-go" wasmConfig "github.com/multiversx/mx-chain-vm-v1_4-go/config" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) +const upgradeContractFunction = "upgradeContract" + func TestAsyncCallShouldWork(t *testing.T) { testContext, err := vm.CreatePreparedTxProcessorWithVMs(config.EnableEpochs{}) require.Nil(t, err) @@ -132,3 +140,248 @@ func TestMinterContractWithAsyncCalls(t *testing.T) { require.Equal(t, "internalVMErrors", string(event.GetIdentifier())) require.Contains(t, string(event.GetData()), process.ErrMaxCallsReached.Error()) } + +func TestAsyncCallsOnInitFunctionOnUpgrade(t *testing.T) { + t.Run("backwards compatibility for unset flag", func(t *testing.T) { + gasScheduleNotifier := vm.CreateMockGasScheduleNotifier() + + firstContractCode := wasm.GetSCCode("./testdata/first/first.wasm") + + expectedGasLimit := gasScheduleNotifier.LatestGasSchedule()[common.BaseOpsAPICost][common.AsyncCallbackGasLockField] + + gasScheduleNotifier.LatestGasSchedule()[common.BaseOpsAPICost][common.AsyncCallStepField] + + gasScheduleNotifier.LatestGasSchedule()[common.BaseOperationCost]["AoTPreparePerByte"]*uint64(len(firstContractCode))/2 + + enableEpoch := config.EnableEpochs{ + RuntimeCodeSizeFixEnableEpoch: 100000, // fix not activated + } + + testAsyncCallsOnInitFunctionOnUpgrade(t, enableEpoch, expectedGasLimit, gasScheduleNotifier) + }) + t.Run("fix activated", func(t *testing.T) { + gasScheduleNotifier := vm.CreateMockGasScheduleNotifier() + + newContractCode := wasm.GetSCCode("./testdata/asyncOnInit/asyncOnInit.wasm") + + expectedGasLimit := gasScheduleNotifier.LatestGasSchedule()[common.BaseOpsAPICost][common.AsyncCallbackGasLockField] + + gasScheduleNotifier.LatestGasSchedule()[common.BaseOpsAPICost][common.AsyncCallStepField] + + gasScheduleNotifier.LatestGasSchedule()[common.BaseOperationCost]["AoTPreparePerByte"]*uint64(len(newContractCode))/2 + + enableEpoch := config.EnableEpochs{ + RuntimeCodeSizeFixEnableEpoch: 0, // fix 
activated + } + + testAsyncCallsOnInitFunctionOnUpgrade(t, enableEpoch, expectedGasLimit, gasScheduleNotifier) + }) +} + +func testAsyncCallsOnInitFunctionOnUpgrade(t *testing.T, enableEpochs config.EnableEpochs, expectedGasLimit uint64, gasScheduleNotifier core.GasScheduleNotifier) { + shardCoordinatorForShard0, _ := sharding.NewMultiShardCoordinator(3, 1) + shardCoordinatorForShardMeta, _ := sharding.NewMultiShardCoordinator(3, core.MetachainShardId) + + testContextShard0, err := vm.CreatePreparedTxProcessorWithVMsWithShardCoordinatorDBAndGas( + enableEpochs, + shardCoordinatorForShard0, + integrationtests.CreateMemUnit(), + gasScheduleNotifier, + ) + require.Nil(t, err) + testContextShardMeta, err := vm.CreatePreparedTxProcessorWithVMsWithShardCoordinatorDBAndGas( + enableEpochs, + shardCoordinatorForShardMeta, + integrationtests.CreateMemUnit(), + gasScheduleNotifier, + ) + require.Nil(t, err) + + // step 1. deploy the first contract + scAddress, owner := utils.DoDeployWithCustomParams( + t, + testContextShard0, + "./testdata/first/first.wasm", + big.NewInt(100000000000), + 2000, + nil, + ) + assert.Equal(t, 32, len(owner)) + assert.Equal(t, 32, len(scAddress)) + + intermediates := testContextShard0.GetIntermediateTransactions(t) + assert.Equal(t, 1, len(intermediates)) + testContextShard0.CleanIntermediateTransactions(t) + + // step 2. call a dummy function on the first version of the contract + + tx := utils.CreateSmartContractCall(1, owner, scAddress, 10, 2000, "callMe", nil) + code, err := testContextShard0.TxProcessor.ProcessTransaction(tx) + require.Nil(t, err) + require.Equal(t, vmcommon.Ok, code) + + intermediates = testContextShard0.GetIntermediateTransactions(t) + assert.Equal(t, 1, len(intermediates)) + testContextShard0.CleanIntermediateTransactions(t) + + // step 3. upgrade to the second contract + + newScCode := wasm.GetSCCode("./testdata/asyncOnInit/asyncOnInit.wasm") + txData := strings.Join([]string{ + upgradeContractFunction, + newScCode, + wasm.VMTypeHex, + hex.EncodeToString(core.ESDTSCAddress), + hex.EncodeToString([]byte("nonExistentFunction")), + hex.EncodeToString([]byte("dummyArg")), + }, "@") + tx = utils.CreateSmartContractCall(2, owner, scAddress, 10, 10000000, txData, nil) + code, err = testContextShard0.TxProcessor.ProcessTransaction(tx) + assert.Nil(t, err) + assert.Equal(t, vmcommon.Ok, code) + + intermediates = testContextShard0.GetIntermediateTransactions(t) + assert.Equal(t, 1, len(intermediates)) + testContextShard0.CleanIntermediateTransactions(t) + + // step 4. execute scr on metachain, should fail + + scr := intermediates[0].(*smartContractResult.SmartContractResult) + code, err = testContextShardMeta.ScProcessor.ProcessSmartContractResult(scr) + assert.Nil(t, err) + assert.Equal(t, vmcommon.UserError, code) + + intermediates = testContextShardMeta.GetIntermediateTransactions(t) + assert.Equal(t, 1, len(intermediates)) + testContextShardMeta.CleanIntermediateTransactions(t) + + // step 5. 
execute generated metachain scr on the contract + scr = intermediates[0].(*smartContractResult.SmartContractResult) + code, err = testContextShard0.ScProcessor.ProcessSmartContractResult(scr) + assert.Nil(t, err) + assert.Equal(t, vmcommon.Ok, code) + + assert.Equal(t, 1, len(intermediates)) + testContextShardMeta.CleanIntermediateTransactions(t) + + assert.Equal(t, expectedGasLimit, intermediates[0].GetGasLimit()) +} + +func TestAsyncCallsOnInitFunctionOnDeploy(t *testing.T) { + t.Run("backwards compatibility for unset flag", func(t *testing.T) { + gasScheduleNotifier := vm.CreateMockGasScheduleNotifier() + + firstContractCode := wasm.GetSCCode("./testdata/first/first.wasm") + + expectedGasLimit := gasScheduleNotifier.LatestGasSchedule()[common.BaseOpsAPICost][common.AsyncCallbackGasLockField] + + gasScheduleNotifier.LatestGasSchedule()[common.BaseOpsAPICost][common.AsyncCallStepField] + + gasScheduleNotifier.LatestGasSchedule()[common.BaseOperationCost]["AoTPreparePerByte"]*uint64(len(firstContractCode))/2 + + enableEpoch := config.EnableEpochs{ + RuntimeCodeSizeFixEnableEpoch: 100000, // fix not activated + } + + testAsyncCallsOnInitFunctionOnDeploy(t, enableEpoch, expectedGasLimit, gasScheduleNotifier) + }) + t.Run("fix activated", func(t *testing.T) { + gasScheduleNotifier := vm.CreateMockGasScheduleNotifier() + + newContractCode := wasm.GetSCCode("./testdata/asyncOnInit/asyncOnInit.wasm") + + expectedGasLimit := gasScheduleNotifier.LatestGasSchedule()[common.BaseOpsAPICost][common.AsyncCallbackGasLockField] + + gasScheduleNotifier.LatestGasSchedule()[common.BaseOpsAPICost][common.AsyncCallStepField] + + gasScheduleNotifier.LatestGasSchedule()[common.BaseOperationCost]["AoTPreparePerByte"]*uint64(len(newContractCode))/2 + + enableEpoch := config.EnableEpochs{ + RuntimeCodeSizeFixEnableEpoch: 0, // fix activated + } + + testAsyncCallsOnInitFunctionOnDeploy(t, enableEpoch, expectedGasLimit, gasScheduleNotifier) + }) +} + +func testAsyncCallsOnInitFunctionOnDeploy(t *testing.T, enableEpochs config.EnableEpochs, expectedGasLimit uint64, gasScheduleNotifier core.GasScheduleNotifier) { + shardCoordinatorForShard0, _ := sharding.NewMultiShardCoordinator(3, 1) + shardCoordinatorForShardMeta, _ := sharding.NewMultiShardCoordinator(3, core.MetachainShardId) + + testContextShard0, err := vm.CreatePreparedTxProcessorWithVMsWithShardCoordinatorDBAndGas( + enableEpochs, + shardCoordinatorForShard0, + integrationtests.CreateMemUnit(), + gasScheduleNotifier, + ) + require.Nil(t, err) + testContextShardMeta, err := vm.CreatePreparedTxProcessorWithVMsWithShardCoordinatorDBAndGas( + enableEpochs, + shardCoordinatorForShardMeta, + integrationtests.CreateMemUnit(), + gasScheduleNotifier, + ) + require.Nil(t, err) + + // step 1. deploy the first contract + scAddressFirst, firstOwner := utils.DoDeployWithCustomParams( + t, + testContextShard0, + "./testdata/first/first.wasm", + big.NewInt(100000000000), + 2000, + nil, + ) + assert.Equal(t, 32, len(firstOwner)) + assert.Equal(t, 32, len(scAddressFirst)) + + intermediates := testContextShard0.GetIntermediateTransactions(t) + assert.Equal(t, 1, len(intermediates)) + testContextShard0.CleanIntermediateTransactions(t) + + // step 2. 
call a dummy function on the first contract + + tx := utils.CreateSmartContractCall(1, firstOwner, scAddressFirst, 10, 2000, "callMe", nil) + code, err := testContextShard0.TxProcessor.ProcessTransaction(tx) + require.Nil(t, err) + require.Equal(t, vmcommon.Ok, code) + + intermediates = testContextShard0.GetIntermediateTransactions(t) + assert.Equal(t, 1, len(intermediates)) + testContextShard0.CleanIntermediateTransactions(t) + + // step 3. deploy the second contract that does an async on init function + + scAddressSecond, secondOwner := utils.DoDeployWithCustomParams( + t, + testContextShard0, + "./testdata/asyncOnInit/asyncOnInit.wasm", + big.NewInt(100000000000), + 10000000, + []string{ + hex.EncodeToString(core.ESDTSCAddress), + hex.EncodeToString([]byte("nonExistentFunction")), + hex.EncodeToString([]byte("dummyArg")), + }, + ) + assert.Equal(t, 32, len(secondOwner)) + assert.Equal(t, 32, len(scAddressSecond)) + + intermediates = testContextShard0.GetIntermediateTransactions(t) + assert.Equal(t, 1, len(intermediates)) + testContextShard0.CleanIntermediateTransactions(t) + + // step 4. execute scr on metachain, should fail + + scr := intermediates[0].(*smartContractResult.SmartContractResult) + code, err = testContextShardMeta.ScProcessor.ProcessSmartContractResult(scr) + assert.Nil(t, err) + assert.Equal(t, vmcommon.UserError, code) + + intermediates = testContextShardMeta.GetIntermediateTransactions(t) + assert.Equal(t, 1, len(intermediates)) + testContextShardMeta.CleanIntermediateTransactions(t) + + // step 5. execute generated metachain scr on the contract + scr = intermediates[0].(*smartContractResult.SmartContractResult) + code, err = testContextShard0.ScProcessor.ProcessSmartContractResult(scr) + assert.Nil(t, err) + assert.Equal(t, vmcommon.Ok, code) + + assert.Equal(t, 1, len(intermediates)) + testContextShardMeta.CleanIntermediateTransactions(t) + + assert.Equal(t, expectedGasLimit, intermediates[0].GetGasLimit()) +} diff --git a/integrationTests/vm/txsFee/testdata/asyncOnInit/asyncOnInit.wasm b/integrationTests/vm/txsFee/testdata/asyncOnInit/asyncOnInit.wasm new file mode 100644 index 00000000000..8e0ff9d45cd Binary files /dev/null and b/integrationTests/vm/txsFee/testdata/asyncOnInit/asyncOnInit.wasm differ diff --git a/integrationTests/vm/txsFee/testdata/second/elrond.json b/integrationTests/vm/txsFee/testdata/second/elrond.json deleted file mode 100644 index d9a84b2b6db..00000000000 --- a/integrationTests/vm/txsFee/testdata/second/elrond.json +++ /dev/null @@ -1,3 +0,0 @@ -{ - "language": "clang" -} diff --git a/integrationTests/vm/txsFee/utils/utils.go b/integrationTests/vm/txsFee/utils/utils.go index 02a8ff7c435..1d9e114d8c2 100644 --- a/integrationTests/vm/txsFee/utils/utils.go +++ b/integrationTests/vm/txsFee/utils/utils.go @@ -80,7 +80,9 @@ func DoDeployWithCustomParams( contractHexParams []string, ) (scAddr []byte, owner []byte) { owner = []byte("12345678901234567890123456789011") - senderNonce := uint64(0) + account, err := testContext.Accounts.LoadAccount(owner) + require.Nil(tb, err) + senderNonce := account.GetNonce() gasPrice := uint64(10) _, _ = vm.CreateAccount(testContext.Accounts, owner, 0, senderBalance) diff --git a/keysManagement/errors.go b/keysManagement/errors.go new file mode 100644 index 00000000000..5c2338d28f3 --- /dev/null +++ b/keysManagement/errors.go @@ -0,0 +1,30 @@ +package keysManagement + +import "errors" + +// ErrDuplicatedKey signals that a key is already managed by the node +var ErrDuplicatedKey = errors.New("duplicated key found") + 
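// Illustrative sketch (not part of this patch), written from a caller's perspective: a node
// wiring the multikey holder registers each handled BLS secret key exactly once. The holder
// construction mirrors the integration-test wiring earlier in this patch; whether a second
// AddManagedPeer call for the same key surfaces ErrDuplicatedKey is an assumption based on the
// error's name, and blsKeyGen, p2pKeyGen and skBytes are hypothetical placeholders.
argsHolderExample := keysManagement.ArgsManagedPeersHolder{
	KeyGenerator:                     blsKeyGen,
	P2PKeyGenerator:                  p2pKeyGen,
	IsMainMachine:                    true,
	MaxRoundsWithoutReceivedMessages: 10,
	PrefsConfig:                      config.Preferences{},
	P2PKeyConverter:                  p2pFactory.NewP2PKeyConverter(),
}
holderExample, _ := keysManagement.NewManagedPeersHolder(argsHolderExample)
_ = holderExample.AddManagedPeer(skBytes)
errAdd := holderExample.AddManagedPeer(skBytes)
if errors.Is(errAdd, keysManagement.ErrDuplicatedKey) {
	// the key was already registered on this node, so it is skipped
}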
+// ErrMissingPublicKeyDefinition signals that a public key definition is missing +var ErrMissingPublicKeyDefinition = errors.New("missing public key definition") + +// ErrNilKeyGenerator signals that a nil key generator was provided +var ErrNilKeyGenerator = errors.New("nil key generator") + +// ErrInvalidValue signals that an invalid value was provided +var ErrInvalidValue = errors.New("invalid value") + +// ErrInvalidKey signals that an invalid key was provided +var ErrInvalidKey = errors.New("invalid key") + +// ErrNilManagedPeersHolder signals that a nil managed peers holder was provided +var ErrNilManagedPeersHolder = errors.New("nil managed peers holder") + +// ErrNilPrivateKey signals that a nil private key was provided +var ErrNilPrivateKey = errors.New("nil private key") + +// ErrEmptyPeerID signals that an empty peer ID was provided +var ErrEmptyPeerID = errors.New("empty peer ID") + +// ErrNilP2PKeyConverter signals that a nil p2p key converter has been provided +var ErrNilP2PKeyConverter = errors.New("nil p2p key converter") diff --git a/keysManagement/export_test.go b/keysManagement/export_test.go new file mode 100644 index 00000000000..db42feed8b6 --- /dev/null +++ b/keysManagement/export_test.go @@ -0,0 +1,75 @@ +package keysManagement + +import ( + "github.com/multiversx/mx-chain-core-go/core" + crypto "github.com/multiversx/mx-chain-crypto-go" + "github.com/multiversx/mx-chain-go/common" +) + +// GetRoundsWithoutReceivedMessages - +func (pInfo *peerInfo) GetRoundsWithoutReceivedMessages() int { + pInfo.mutChangeableData.RLock() + defer pInfo.mutChangeableData.RUnlock() + + return pInfo.roundsWithoutReceivedMessages +} + +// Pid - +func (pInfo *peerInfo) Pid() core.PeerID { + return pInfo.pid +} + +// P2pPrivateKeyBytes - +func (pInfo *peerInfo) P2pPrivateKeyBytes() []byte { + return pInfo.p2pPrivateKeyBytes +} + +// PrivateKey - +func (pInfo *peerInfo) PrivateKey() crypto.PrivateKey { + return pInfo.privateKey +} + +// MachineID - +func (pInfo *peerInfo) MachineID() string { + return pInfo.machineID +} + +// NodeName - +func (pInfo *peerInfo) NodeName() string { + return pInfo.nodeName +} + +// NodeIdentity - +func (pInfo *peerInfo) NodeIdentity() string { + return pInfo.nodeIdentity +} + +// GetPeerInfo - +func (holder *managedPeersHolder) GetPeerInfo(pkBytes []byte) *peerInfo { + return holder.getPeerInfo(pkBytes) +} + +// ManagedPeersHolder - +func (handler *keysHandler) ManagedPeersHolder() common.ManagedPeersHolder { + return handler.managedPeersHolder +} + +// PrivateKey - +func (handler *keysHandler) PrivateKey() crypto.PrivateKey { + return handler.privateKey +} + +// PublicKey - +func (handler *keysHandler) PublicKey() crypto.PublicKey { + return handler.publicKey +} + +// PublicKeyBytes - +func (handler *keysHandler) PublicKeyBytes() []byte { + return handler.publicKeyBytes +} + +// Pid - +func (handler *keysHandler) Pid() core.PeerID { + return handler.pid +} diff --git a/keysManagement/keysHandler.go b/keysManagement/keysHandler.go new file mode 100644 index 00000000000..6414d2f8a2e --- /dev/null +++ b/keysManagement/keysHandler.go @@ -0,0 +1,130 @@ +package keysManagement + +import ( + "bytes" + + "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-core-go/core/check" + crypto "github.com/multiversx/mx-chain-crypto-go" + "github.com/multiversx/mx-chain-go/common" +) + +// ArgsKeysHandler is the argument DTO struct for the NewKeysHandler constructor function +type ArgsKeysHandler struct { + ManagedPeersHolder 
common.ManagedPeersHolder + PrivateKey crypto.PrivateKey + Pid core.PeerID +} + +// keysHandler will manage all keys available on the node either in single signer mode or multi key mode +type keysHandler struct { + managedPeersHolder common.ManagedPeersHolder + privateKey crypto.PrivateKey + publicKey crypto.PublicKey + publicKeyBytes []byte + pid core.PeerID +} + +// NewKeysHandler will create a new instance of type keysHandler +func NewKeysHandler(args ArgsKeysHandler) (*keysHandler, error) { + err := checkArgsKeysHandler(args) + if err != nil { + return nil, err + } + + pk := args.PrivateKey.GeneratePublic() + pkBytes, err := pk.ToByteArray() + if err != nil { + return nil, err + } + + return &keysHandler{ + managedPeersHolder: args.ManagedPeersHolder, + privateKey: args.PrivateKey, + publicKey: pk, + publicKeyBytes: pkBytes, + pid: args.Pid, + }, nil +} + +func checkArgsKeysHandler(args ArgsKeysHandler) error { + if check.IfNil(args.ManagedPeersHolder) { + return ErrNilManagedPeersHolder + } + if check.IfNil(args.PrivateKey) { + return ErrNilPrivateKey + } + if len(args.Pid) == 0 { + return ErrEmptyPeerID + } + + return nil +} + +// GetHandledPrivateKey will return the correct private key by using the provided pkBytes to select from a stored one +// Returns the current private key if the pkBytes is not handled by the current node +func (handler *keysHandler) GetHandledPrivateKey(pkBytes []byte) crypto.PrivateKey { + if handler.IsOriginalPublicKeyOfTheNode(pkBytes) { + return handler.privateKey + } + + privateKey, err := handler.managedPeersHolder.GetPrivateKey(pkBytes) + if err != nil { + log.Warn("setup error in keysHandler.GetHandledPrivateKey, returning original private key", "error", err) + + return handler.privateKey + } + + return privateKey +} + +// GetP2PIdentity returns the associated p2p identity with the provided public key bytes: the private key and the peer ID +func (handler *keysHandler) GetP2PIdentity(pkBytes []byte) ([]byte, core.PeerID, error) { + return handler.managedPeersHolder.GetP2PIdentity(pkBytes) +} + +// IsKeyManagedByCurrentNode will return if the provided key is a managed one and the current node should use it +func (handler *keysHandler) IsKeyManagedByCurrentNode(pkBytes []byte) bool { + return handler.managedPeersHolder.IsKeyManagedByCurrentNode(pkBytes) +} + +// IncrementRoundsWithoutReceivedMessages increments the provided rounds without received messages counter on the provided public key +func (handler *keysHandler) IncrementRoundsWithoutReceivedMessages(pkBytes []byte) { + handler.managedPeersHolder.IncrementRoundsWithoutReceivedMessages(pkBytes) +} + +// GetAssociatedPid will return the associated peer ID from the provided public key bytes. 
Will search in the managed keys mapping +// if the public key is not the original public key of the node +func (handler *keysHandler) GetAssociatedPid(pkBytes []byte) core.PeerID { + if handler.IsOriginalPublicKeyOfTheNode(pkBytes) { + return handler.pid + } + + _, pid, err := handler.managedPeersHolder.GetP2PIdentity(pkBytes) + if err != nil { + log.Warn("setup error in keysHandler.GetAssociatedPid, returning original pid", "error", err) + + return handler.pid + } + + return pid +} + +// IsOriginalPublicKeyOfTheNode returns true if the provided public key bytes are the original ones used by the node +func (handler *keysHandler) IsOriginalPublicKeyOfTheNode(pkBytes []byte) bool { + return bytes.Equal(pkBytes, handler.publicKeyBytes) +} + +// UpdatePublicKeyLiveness update the provided public key liveness if the provided pid is not managed by the current node +func (handler *keysHandler) UpdatePublicKeyLiveness(pkBytes []byte, pid core.PeerID) { + if bytes.Equal(handler.pid.Bytes(), pid.Bytes()) { + return + } + + handler.managedPeersHolder.ResetRoundsWithoutReceivedMessages(pkBytes) +} + +// IsInterfaceNil returns true if there is no value under the interface +func (handler *keysHandler) IsInterfaceNil() bool { + return handler == nil +} diff --git a/keysManagement/keysHandler_test.go b/keysManagement/keysHandler_test.go new file mode 100644 index 00000000000..b6c490fb448 --- /dev/null +++ b/keysManagement/keysHandler_test.go @@ -0,0 +1,276 @@ +package keysManagement_test + +import ( + "errors" + "testing" + + "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-crypto-go" + "github.com/multiversx/mx-chain-go/keysManagement" + "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cryptoMocks" + "github.com/stretchr/testify/assert" +) + +var testPrivateKeyBytes = []byte("private key bytes") +var testPublicKeyBytes = []byte("public key bytes") +var randomPublicKeyBytes = []byte("random key bytes") + +func createMockArgsKeysHandler() keysManagement.ArgsKeysHandler { + return keysManagement.ArgsKeysHandler{ + ManagedPeersHolder: &testscommon.ManagedPeersHolderStub{}, + PrivateKey: &cryptoMocks.PrivateKeyStub{ + ToByteArrayStub: func() ([]byte, error) { + return testPrivateKeyBytes, nil + }, + GeneratePublicStub: func() crypto.PublicKey { + return &cryptoMocks.PublicKeyStub{ + ToByteArrayStub: func() ([]byte, error) { + return testPublicKeyBytes, nil + }, + } + }, + }, + Pid: pid, + } +} + +func TestNewKeysHandler(t *testing.T) { + t.Parallel() + + t.Run("nil managed peers holder should error", func(t *testing.T) { + t.Parallel() + + args := createMockArgsKeysHandler() + args.ManagedPeersHolder = nil + handler, err := keysManagement.NewKeysHandler(args) + + assert.True(t, check.IfNil(handler)) + assert.Equal(t, keysManagement.ErrNilManagedPeersHolder, err) + }) + t.Run("nil private key should error", func(t *testing.T) { + t.Parallel() + + args := createMockArgsKeysHandler() + args.PrivateKey = nil + handler, err := keysManagement.NewKeysHandler(args) + + assert.True(t, check.IfNil(handler)) + assert.Equal(t, keysManagement.ErrNilPrivateKey, err) + }) + t.Run("empty pid should error", func(t *testing.T) { + t.Parallel() + + args := createMockArgsKeysHandler() + args.Pid = "" + handler, err := keysManagement.NewKeysHandler(args) + + assert.True(t, check.IfNil(handler)) + assert.Equal(t, keysManagement.ErrEmptyPeerID, err) + }) + t.Run("public key bytes generation errors 
should error", func(t *testing.T) { + t.Parallel() + + expectedErr := errors.New("expected error") + args := createMockArgsKeysHandler() + args.PrivateKey = &cryptoMocks.PrivateKeyStub{ + GeneratePublicStub: func() crypto.PublicKey { + return &cryptoMocks.PublicKeyStub{ + ToByteArrayStub: func() ([]byte, error) { + return nil, expectedErr + }, + } + }, + } + handler, err := keysManagement.NewKeysHandler(args) + + assert.True(t, check.IfNil(handler)) + assert.Equal(t, expectedErr, err) + }) + t.Run("should work", func(t *testing.T) { + t.Parallel() + + args := createMockArgsKeysHandler() + handler, err := keysManagement.NewKeysHandler(args) + + assert.False(t, check.IfNil(handler)) + assert.Nil(t, err) + assert.Equal(t, testPublicKeyBytes, handler.PublicKeyBytes()) + assert.Equal(t, pid, handler.Pid()) + assert.False(t, check.IfNil(handler.PrivateKey())) + assert.False(t, check.IfNil(handler.ManagedPeersHolder())) + assert.False(t, check.IfNil(handler.PrivateKey())) + }) +} + +func TestKeysHandler_GetHandledPrivateKey(t *testing.T) { + t.Parallel() + + t.Run("is original public key of the node", func(t *testing.T) { + t.Parallel() + + args := createMockArgsKeysHandler() + handler, _ := keysManagement.NewKeysHandler(args) + + sk := handler.GetHandledPrivateKey(testPublicKeyBytes) + assert.True(t, sk == handler.PrivateKey()) // pointer testing + }) + t.Run("managedPeersHolder.GetPrivateKey errors", func(t *testing.T) { + t.Parallel() + + args := createMockArgsKeysHandler() + args.ManagedPeersHolder = &testscommon.ManagedPeersHolderStub{ + GetPrivateKeyCalled: func(pkBytes []byte) (crypto.PrivateKey, error) { + return nil, errors.New("private key not found") + }, + } + handler, _ := keysManagement.NewKeysHandler(args) + + sk := handler.GetHandledPrivateKey(randomPublicKeyBytes) + assert.True(t, sk == handler.PrivateKey()) // pointer testing + }) + t.Run("managedPeersHolder.GetPrivateKey returns the private key", func(t *testing.T) { + t.Parallel() + + args := createMockArgsKeysHandler() + handledPrivateKey := &cryptoMocks.PrivateKeyStub{} + args.ManagedPeersHolder = &testscommon.ManagedPeersHolderStub{ + GetPrivateKeyCalled: func(pkBytes []byte) (crypto.PrivateKey, error) { + assert.Equal(t, randomPublicKeyBytes, pkBytes) + return handledPrivateKey, nil + }, + } + handler, _ := keysManagement.NewKeysHandler(args) + + sk := handler.GetHandledPrivateKey(randomPublicKeyBytes) + assert.True(t, sk == handledPrivateKey) // pointer testing + }) +} + +func TestKeysHandler_GetP2PIdentity(t *testing.T) { + t.Parallel() + + p2pPrivateKeyBytes := []byte("p2p private key bytes") + wasCalled := false + args := createMockArgsKeysHandler() + args.ManagedPeersHolder = &testscommon.ManagedPeersHolderStub{ + GetP2PIdentityCalled: func(pkBytes []byte) ([]byte, core.PeerID, error) { + assert.Equal(t, randomPublicKeyBytes, pkBytes) + wasCalled = true + + return p2pPrivateKeyBytes, pid, nil + }, + } + handler, _ := keysManagement.NewKeysHandler(args) + + recoveredPrivateKey, recoveredPid, err := handler.GetP2PIdentity(randomPublicKeyBytes) + assert.Nil(t, err) + assert.True(t, wasCalled) + assert.Equal(t, p2pPrivateKeyBytes, recoveredPrivateKey) + assert.Equal(t, pid, recoveredPid) +} + +func TestKeysHandler_IsKeyManagedByCurrentNode(t *testing.T) { + t.Parallel() + + wasCalled := false + args := createMockArgsKeysHandler() + args.ManagedPeersHolder = &testscommon.ManagedPeersHolderStub{ + IsKeyManagedByCurrentNodeCalled: func(pkBytes []byte) bool { + assert.Equal(t, randomPublicKeyBytes, pkBytes) + wasCalled = true 
+ return true + }, + } + handler, _ := keysManagement.NewKeysHandler(args) + + isManaged := handler.IsKeyManagedByCurrentNode(randomPublicKeyBytes) + assert.True(t, wasCalled) + assert.True(t, isManaged) +} + +func TestKeysHandler_IncrementRoundsWithoutReceivedMessages(t *testing.T) { + t.Parallel() + + wasCalled := false + args := createMockArgsKeysHandler() + args.ManagedPeersHolder = &testscommon.ManagedPeersHolderStub{ + IncrementRoundsWithoutReceivedMessagesCalled: func(pkBytes []byte) { + assert.Equal(t, randomPublicKeyBytes, pkBytes) + wasCalled = true + }, + } + handler, _ := keysManagement.NewKeysHandler(args) + + handler.IncrementRoundsWithoutReceivedMessages(randomPublicKeyBytes) + assert.True(t, wasCalled) +} + +func TestKeysHandler_GetAssociatedPid(t *testing.T) { + t.Parallel() + + t.Run("is original public key of the node", func(t *testing.T) { + t.Parallel() + + args := createMockArgsKeysHandler() + handler, _ := keysManagement.NewKeysHandler(args) + + recoveredPid := handler.GetAssociatedPid(testPublicKeyBytes) + assert.True(t, recoveredPid == args.Pid) + }) + t.Run("managedPeersHolder.GetP2PIdentity errors", func(t *testing.T) { + t.Parallel() + + args := createMockArgsKeysHandler() + args.ManagedPeersHolder = &testscommon.ManagedPeersHolderStub{ + GetP2PIdentityCalled: func(pkBytes []byte) ([]byte, core.PeerID, error) { + return nil, "", errors.New("identity not found") + }, + } + handler, _ := keysManagement.NewKeysHandler(args) + + recoveredPid := handler.GetAssociatedPid(randomPublicKeyBytes) + assert.True(t, recoveredPid == args.Pid) + }) + t.Run("managedPeersHolder.GetP2PIdentity returns the identity", func(t *testing.T) { + t.Parallel() + + args := createMockArgsKeysHandler() + args.ManagedPeersHolder = &testscommon.ManagedPeersHolderStub{ + GetP2PIdentityCalled: func(pkBytes []byte) ([]byte, core.PeerID, error) { + assert.Equal(t, randomPublicKeyBytes, pkBytes) + + return make([]byte, 0), pid, nil + }, + } + handler, _ := keysManagement.NewKeysHandler(args) + + recoveredPid := handler.GetAssociatedPid(randomPublicKeyBytes) + assert.Equal(t, pid, recoveredPid) + }) +} + +func TestKeysHandler_UpdatePublicKeyLiveness(t *testing.T) { + t.Parallel() + + mapResetCalled := make(map[string]int) + args := createMockArgsKeysHandler() + args.ManagedPeersHolder = &testscommon.ManagedPeersHolderStub{ + ResetRoundsWithoutReceivedMessagesCalled: func(pkBytes []byte) { + mapResetCalled[string(pkBytes)]++ + }, + } + handler, _ := keysManagement.NewKeysHandler(args) + + t.Run("same pid should not call reset", func(t *testing.T) { + handler.UpdatePublicKeyLiveness(randomPublicKeyBytes, pid) + assert.Zero(t, len(mapResetCalled)) + }) + t.Run("another pid should call reset", func(t *testing.T) { + randomPid := core.PeerID("random pid") + handler.UpdatePublicKeyLiveness(randomPublicKeyBytes, randomPid) + assert.Equal(t, 1, len(mapResetCalled)) + assert.Equal(t, 1, mapResetCalled[string(randomPublicKeyBytes)]) + }) +} diff --git a/keysManagement/managedPeersHolder.go b/keysManagement/managedPeersHolder.go new file mode 100644 index 00000000000..0cc7ea8c9e6 --- /dev/null +++ b/keysManagement/managedPeersHolder.go @@ -0,0 +1,356 @@ +package keysManagement + +import ( + "crypto/rand" + "encoding/hex" + "fmt" + "sync" + "time" + + "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-core-go/core/check" + crypto "github.com/multiversx/mx-chain-crypto-go" + "github.com/multiversx/mx-chain-go/config" + "github.com/multiversx/mx-chain-go/p2p" + logger 
"github.com/multiversx/mx-chain-logger-go" +) + +const minRoundsWithoutReceivedMessages = -1 + +var log = logger.GetOrCreate("keysManagement") + +type managedPeersHolder struct { + mut sync.RWMutex + data map[string]*peerInfo + pids map[core.PeerID]struct{} + keyGenerator crypto.KeyGenerator + p2pKeyGenerator crypto.KeyGenerator + isMainMachine bool + maxRoundsWithoutReceivedMessages int + defaultName string + defaultIdentity string + p2pKeyConverter p2p.P2PKeyConverter +} + +// ArgsManagedPeersHolder represents the argument for the managed peers holder +type ArgsManagedPeersHolder struct { + KeyGenerator crypto.KeyGenerator + P2PKeyGenerator crypto.KeyGenerator + IsMainMachine bool + MaxRoundsWithoutReceivedMessages int + PrefsConfig config.Preferences + P2PKeyConverter p2p.P2PKeyConverter +} + +// NewManagedPeersHolder creates a new instance of a managed peers holder +func NewManagedPeersHolder(args ArgsManagedPeersHolder) (*managedPeersHolder, error) { + err := checkManagedPeersHolderArgs(args) + if err != nil { + return nil, err + } + + dataMap, err := createDataMap(args.PrefsConfig.NamedIdentity) + if err != nil { + return nil, err + } + + holder := &managedPeersHolder{ + data: dataMap, + pids: make(map[core.PeerID]struct{}), + keyGenerator: args.KeyGenerator, + p2pKeyGenerator: args.P2PKeyGenerator, + isMainMachine: args.IsMainMachine, + maxRoundsWithoutReceivedMessages: args.MaxRoundsWithoutReceivedMessages, + defaultName: args.PrefsConfig.Preferences.NodeDisplayName, + defaultIdentity: args.PrefsConfig.Preferences.Identity, + p2pKeyConverter: args.P2PKeyConverter, + } + + return holder, nil +} + +func checkManagedPeersHolderArgs(args ArgsManagedPeersHolder) error { + if check.IfNil(args.KeyGenerator) { + return fmt.Errorf("%w for args.KeyGenerator", ErrNilKeyGenerator) + } + if check.IfNil(args.P2PKeyGenerator) { + return fmt.Errorf("%w for args.P2PKeyGenerator", ErrNilKeyGenerator) + } + if args.MaxRoundsWithoutReceivedMessages < minRoundsWithoutReceivedMessages { + return fmt.Errorf("%w for MaxRoundsWithoutReceivedMessages, minimum %d, got %d", + ErrInvalidValue, minRoundsWithoutReceivedMessages, args.MaxRoundsWithoutReceivedMessages) + } + if check.IfNil(args.P2PKeyConverter) { + return fmt.Errorf("%w for args.P2PKeyConverter", ErrNilP2PKeyConverter) + } + + return nil +} + +func createDataMap(namedIdentities []config.NamedIdentity) (map[string]*peerInfo, error) { + dataMap := make(map[string]*peerInfo) + + for _, identity := range namedIdentities { + for _, blsKey := range identity.BLSKeys { + bls, err := hex.DecodeString(blsKey) + if err != nil { + return nil, fmt.Errorf("%w for key %s", ErrInvalidKey, blsKey) + } + + blsStr := string(bls) + dataMap[blsStr] = &peerInfo{ + machineID: generateRandomMachineID(), + nodeName: identity.NodeName, + nodeIdentity: identity.Identity, + } + } + } + + return dataMap, nil +} + +// AddManagedPeer will try to add a new managed peer providing the private key bytes. 
+// It errors if the generated public key is already contained by the struct +// It will auto-generate some fields like the machineID and pid +func (holder *managedPeersHolder) AddManagedPeer(privateKeyBytes []byte) error { + privateKey, err := holder.keyGenerator.PrivateKeyFromByteArray(privateKeyBytes) + if err != nil { + return fmt.Errorf("%w for provided bytes %s", err, hex.EncodeToString(privateKeyBytes)) + } + + publicKey := privateKey.GeneratePublic() + publicKeyBytes, err := publicKey.ToByteArray() + if err != nil { + return fmt.Errorf("%w for provided bytes %s", err, hex.EncodeToString(privateKeyBytes)) + } + + p2pPrivateKey, p2pPublicKey := holder.p2pKeyGenerator.GeneratePair() + + p2pPrivateKeyBytes, err := p2pPrivateKey.ToByteArray() + if err != nil { + return err + } + + pid, err := holder.p2pKeyConverter.ConvertPublicKeyToPeerID(p2pPublicKey) + if err != nil { + return err + } + + holder.mut.Lock() + defer holder.mut.Unlock() + + pInfo, found := holder.data[string(publicKeyBytes)] + if found && len(pInfo.pid.Bytes()) != 0 { + return fmt.Errorf("%w for provided bytes %s and generated public key %s", + ErrDuplicatedKey, hex.EncodeToString(privateKeyBytes), hex.EncodeToString(publicKeyBytes)) + } + + if !found { + pInfo = &peerInfo{ + machineID: generateRandomMachineID(), + nodeName: holder.defaultName, + nodeIdentity: holder.defaultIdentity, + } + } + + pInfo.pid = pid + pInfo.p2pPrivateKeyBytes = p2pPrivateKeyBytes + pInfo.privateKey = privateKey + holder.data[string(publicKeyBytes)] = pInfo + holder.pids[pid] = struct{}{} + + log.Debug("added new key definition", + "hex public key", hex.EncodeToString(publicKeyBytes), + "pid", pid.Pretty(), + "machine ID", pInfo.machineID, + "name", pInfo.nodeName, + "identity", pInfo.nodeIdentity) + + return nil +} + +func (holder *managedPeersHolder) getPeerInfo(pkBytes []byte) *peerInfo { + holder.mut.RLock() + defer holder.mut.RUnlock() + + return holder.data[string(pkBytes)] +} + +func generateRandomMachineID() string { + buff := make([]byte, core.MaxMachineIDLen/2) + _, _ = rand.Read(buff) + + return hex.EncodeToString(buff) +} + +// GetPrivateKey returns the associated private key with the provided public key bytes. 
Errors if the key is not found +func (holder *managedPeersHolder) GetPrivateKey(pkBytes []byte) (crypto.PrivateKey, error) { + pInfo := holder.getPeerInfo(pkBytes) + if pInfo == nil { + return nil, fmt.Errorf("%w in GetPrivateKey for public key %s", + ErrMissingPublicKeyDefinition, hex.EncodeToString(pkBytes)) + } + + return pInfo.privateKey, nil +} + +// GetP2PIdentity returns the associated p2p identity with the provided public key bytes: the private key and the peer ID +func (holder *managedPeersHolder) GetP2PIdentity(pkBytes []byte) ([]byte, core.PeerID, error) { + pInfo := holder.getPeerInfo(pkBytes) + if pInfo == nil { + return nil, "", fmt.Errorf("%w in GetP2PIdentity for public key %s", + ErrMissingPublicKeyDefinition, hex.EncodeToString(pkBytes)) + } + + return pInfo.p2pPrivateKeyBytes, pInfo.pid, nil +} + +// GetMachineID returns the associated machine ID with the provided public key bytes +func (holder *managedPeersHolder) GetMachineID(pkBytes []byte) (string, error) { + pInfo := holder.getPeerInfo(pkBytes) + if pInfo == nil { + return "", fmt.Errorf("%w in GetMachineID for public key %s", + ErrMissingPublicKeyDefinition, hex.EncodeToString(pkBytes)) + } + + return pInfo.machineID, nil +} + +// GetNameAndIdentity returns the associated name and identity with the provided public key bytes +func (holder *managedPeersHolder) GetNameAndIdentity(pkBytes []byte) (string, string, error) { + pInfo := holder.getPeerInfo(pkBytes) + if pInfo == nil { + return "", "", fmt.Errorf("%w in GetNameAndIdentity for public key %s", + ErrMissingPublicKeyDefinition, hex.EncodeToString(pkBytes)) + } + + return pInfo.nodeName, pInfo.nodeIdentity, nil +} + +// IncrementRoundsWithoutReceivedMessages increments the number of rounds without received messages on a provided public key +func (holder *managedPeersHolder) IncrementRoundsWithoutReceivedMessages(pkBytes []byte) { + if holder.isMainMachine { + return + } + + pInfo := holder.getPeerInfo(pkBytes) + if pInfo == nil { + return + } + + pInfo.incrementRoundsWithoutReceivedMessages() +} + +// ResetRoundsWithoutReceivedMessages resets the number of rounds without received messages on a provided public key +func (holder *managedPeersHolder) ResetRoundsWithoutReceivedMessages(pkBytes []byte) { + if holder.isMainMachine { + return + } + + pInfo := holder.getPeerInfo(pkBytes) + if pInfo == nil { + return + } + + pInfo.resetRoundsWithoutReceivedMessages() +} + +// GetManagedKeysByCurrentNode returns all keys that will be managed by this node +func (holder *managedPeersHolder) GetManagedKeysByCurrentNode() map[string]crypto.PrivateKey { + holder.mut.RLock() + defer holder.mut.RUnlock() + + allManagedKeys := make(map[string]crypto.PrivateKey) + for pk, pInfo := range holder.data { + isSlaveAndMainFailed := !holder.isMainMachine && !pInfo.isNodeActiveOnMainMachine(holder.maxRoundsWithoutReceivedMessages) + shouldAddToMap := holder.isMainMachine || isSlaveAndMainFailed + if !shouldAddToMap { + continue + } + + allManagedKeys[pk] = pInfo.privateKey + } + + return allManagedKeys +} + +// IsKeyManagedByCurrentNode returns true if the key is managed by the current node +func (holder *managedPeersHolder) IsKeyManagedByCurrentNode(pkBytes []byte) bool { + pInfo := holder.getPeerInfo(pkBytes) + if pInfo == nil { + return false + } + + if holder.isMainMachine { + return true + } + + return !pInfo.isNodeActiveOnMainMachine(holder.maxRoundsWithoutReceivedMessages) +} + +// IsKeyRegistered returns true if the key is registered (not necessarily managed by the current node) 
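+// A key is registered if it was listed in the NamedIdentity preferences or added through AddManagedPeer.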
+func (holder *managedPeersHolder) IsKeyRegistered(pkBytes []byte) bool { + pInfo := holder.getPeerInfo(pkBytes) + return pInfo != nil +} + +// IsPidManagedByCurrentNode returns true if the peer id is managed by the current node +func (holder *managedPeersHolder) IsPidManagedByCurrentNode(pid core.PeerID) bool { + holder.mut.RLock() + defer holder.mut.RUnlock() + + _, found := holder.pids[pid] + + return found +} + +// IsKeyValidator returns true if the key validator status was set to true +func (holder *managedPeersHolder) IsKeyValidator(pkBytes []byte) bool { + pInfo := holder.getPeerInfo(pkBytes) + if pInfo == nil { + return false + } + + return pInfo.isNodeValidator() +} + +// SetValidatorState sets the provided validator status for the key +func (holder *managedPeersHolder) SetValidatorState(pkBytes []byte, state bool) { + pInfo := holder.getPeerInfo(pkBytes) + if pInfo == nil { + return + } + + pInfo.setNodeValidator(state) +} + +// GetNextPeerAuthenticationTime returns the next time the key should try to send peer authentication again +func (holder *managedPeersHolder) GetNextPeerAuthenticationTime(pkBytes []byte) (time.Time, error) { + pInfo := holder.getPeerInfo(pkBytes) + if pInfo == nil { + return time.Now(), fmt.Errorf("%w in GetNextPeerAuthenticationTime for public key %s", + ErrMissingPublicKeyDefinition, hex.EncodeToString(pkBytes)) + } + + return pInfo.getNextPeerAuthenticationTime(), nil +} + +// SetNextPeerAuthenticationTime sets the next time the key should try to send peer authentication +func (holder *managedPeersHolder) SetNextPeerAuthenticationTime(pkBytes []byte, nextTime time.Time) { + pInfo := holder.getPeerInfo(pkBytes) + if pInfo == nil { + return + } + + pInfo.setNextPeerAuthenticationTime(nextTime) +} + +// IsMultiKeyMode returns true if the node has at least one managed key +func (holder *managedPeersHolder) IsMultiKeyMode() bool { + return len(holder.GetManagedKeysByCurrentNode()) > 0 +} + +// IsInterfaceNil returns true if there is no value under the interface +func (holder *managedPeersHolder) IsInterfaceNil() bool { + return holder == nil +} diff --git a/keysManagement/managedPeersHolder_test.go b/keysManagement/managedPeersHolder_test.go new file mode 100644 index 00000000000..a69aa8be68b --- /dev/null +++ b/keysManagement/managedPeersHolder_test.go @@ -0,0 +1,824 @@ +package keysManagement_test + +import ( + "bytes" + "crypto/rand" + "encoding/hex" + "errors" + "fmt" + "strings" + "sync" + "testing" + "time" + + "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-crypto-go" + "github.com/multiversx/mx-chain-go/config" + "github.com/multiversx/mx-chain-go/keysManagement" + "github.com/multiversx/mx-chain-go/testscommon/cryptoMocks" + "github.com/multiversx/mx-chain-go/testscommon/p2pmocks" + "github.com/stretchr/testify/assert" +) + +var ( + p2pPrivateKey = []byte("p2p private key") + + skBytes0 = []byte("private key 0") + skBytes1 = []byte("private key 1") + pkBytes0 = []byte("public key 0") + pkBytes1 = []byte("public key 1") + defaultName = "default node name" + defaultIdentity = "default node identity" + p2pPkHex = "03ca8ec5bd3b84d05e59d5c9cecd548059106649d9e3465f628498628732ab23c9" + p2pPidString = "16Uiu2HAmSHgyTYyawhsZv9opxTHX77vKjoPeGkyCYS5fYVMssHjN" + pid, _ = core.NewPeerID(p2pPidString) +) + +func createMockArgsManagedPeersHolder() keysManagement.ArgsManagedPeersHolder { + return keysManagement.ArgsManagedPeersHolder{ + KeyGenerator: createMockKeyGenerator(), + 
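+		// stubbed p2p key generator: returns a fixed private key and the public key decoded from p2pPkHex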
P2PKeyGenerator: &cryptoMocks.KeyGenStub{ + GeneratePairStub: func() (crypto.PrivateKey, crypto.PublicKey) { + return &cryptoMocks.PrivateKeyStub{ + ToByteArrayStub: func() ([]byte, error) { + return []byte("p2p private key"), nil + }, + }, + &cryptoMocks.PublicKeyStub{ + ToByteArrayStub: func() ([]byte, error) { + return hex.DecodeString(p2pPkHex) + }, + } + }, + }, + IsMainMachine: true, + MaxRoundsWithoutReceivedMessages: 1, + PrefsConfig: config.Preferences{ + Preferences: config.PreferencesConfig{ + Identity: defaultIdentity, + NodeDisplayName: defaultName, + }, + }, + P2PKeyConverter: &p2pmocks.P2PKeyConverterStub{ + ConvertPublicKeyToPeerIDCalled: func(pk crypto.PublicKey) (core.PeerID, error) { + return pid, nil + }, + }, + } +} + +func createMockKeyGenerator() crypto.KeyGenerator { + return &cryptoMocks.KeyGenStub{ + PrivateKeyFromByteArrayStub: func(b []byte) (crypto.PrivateKey, error) { + return &cryptoMocks.PrivateKeyStub{ + GeneratePublicStub: func() crypto.PublicKey { + pk := &cryptoMocks.PublicKeyStub{ + ToByteArrayStub: func() ([]byte, error) { + pkBytes := bytes.Replace(b, []byte("private"), []byte("public"), -1) + + return pkBytes, nil + }, + } + + return pk + }, + ToByteArrayStub: func() ([]byte, error) { + return b, nil + }, + }, nil + }, + } +} + +func testManagedKeys(tb testing.TB, result map[string]crypto.PrivateKey, pkBytes ...[]byte) { + assert.Equal(tb, len(pkBytes), len(result)) + + for _, pk := range pkBytes { + _, found := result[string(pk)] + assert.True(tb, found) + } +} + +func TestNewManagedPeersHolder(t *testing.T) { + t.Parallel() + + t.Run("nil key generator should error", func(t *testing.T) { + t.Parallel() + + args := createMockArgsManagedPeersHolder() + args.KeyGenerator = nil + holder, err := keysManagement.NewManagedPeersHolder(args) + + assert.ErrorIs(t, err, keysManagement.ErrNilKeyGenerator) + assert.Contains(t, err.Error(), "for args.KeyGenerator") + assert.True(t, check.IfNil(holder)) + }) + t.Run("nil p2p key generator should error", func(t *testing.T) { + t.Parallel() + + args := createMockArgsManagedPeersHolder() + args.P2PKeyGenerator = nil + holder, err := keysManagement.NewManagedPeersHolder(args) + + assert.ErrorIs(t, err, keysManagement.ErrNilKeyGenerator) + assert.Contains(t, err.Error(), "for args.P2PKeyGenerator") + assert.True(t, check.IfNil(holder)) + }) + t.Run("invalid MaxRoundsWithoutReceivedMessages should error", func(t *testing.T) { + t.Parallel() + + args := createMockArgsManagedPeersHolder() + args.MaxRoundsWithoutReceivedMessages = -2 + holder, err := keysManagement.NewManagedPeersHolder(args) + + assert.True(t, errors.Is(err, keysManagement.ErrInvalidValue)) + assert.True(t, strings.Contains(err.Error(), "MaxRoundsWithoutReceivedMessages")) + assert.True(t, check.IfNil(holder)) + }) + t.Run("invalid key from config should error", func(t *testing.T) { + t.Parallel() + + providedInvalidKey := "invalid key" + args := createMockArgsManagedPeersHolder() + args.PrefsConfig.NamedIdentity = []config.NamedIdentity{ + { + BLSKeys: []string{providedInvalidKey}, + }, + } + holder, err := keysManagement.NewManagedPeersHolder(args) + + assert.True(t, errors.Is(err, keysManagement.ErrInvalidKey)) + assert.True(t, strings.Contains(err.Error(), providedInvalidKey)) + assert.True(t, check.IfNil(holder)) + }) + t.Run("nil p2p key converter should error", func(t *testing.T) { + t.Parallel() + + args := createMockArgsManagedPeersHolder() + args.P2PKeyConverter = nil + holder, err := keysManagement.NewManagedPeersHolder(args) + + 
assert.True(t, errors.Is(err, keysManagement.ErrNilP2PKeyConverter)) + assert.True(t, check.IfNil(holder)) + }) + t.Run("valid arguments should work", func(t *testing.T) { + t.Parallel() + + args := createMockArgsManagedPeersHolder() + holder, err := keysManagement.NewManagedPeersHolder(args) + + assert.Nil(t, err) + assert.False(t, check.IfNil(holder)) + }) +} + +func TestManagedPeersHolder_AddManagedPeer(t *testing.T) { + t.Parallel() + + expectedErr := errors.New("expected error") + t.Run("private key from byte array errors", func(t *testing.T) { + args := createMockArgsManagedPeersHolder() + args.KeyGenerator = &cryptoMocks.KeyGenStub{ + PrivateKeyFromByteArrayStub: func(b []byte) (crypto.PrivateKey, error) { + return nil, expectedErr + }, + } + + holder, _ := keysManagement.NewManagedPeersHolder(args) + err := holder.AddManagedPeer([]byte("private key")) + + assert.True(t, errors.Is(err, expectedErr)) + }) + t.Run("public key from byte array errors", func(t *testing.T) { + args := createMockArgsManagedPeersHolder() + + pk := &cryptoMocks.PublicKeyStub{ + ToByteArrayStub: func() ([]byte, error) { + return nil, expectedErr + }, + } + + sk := &cryptoMocks.PrivateKeyStub{ + GeneratePublicStub: func() crypto.PublicKey { + return pk + }, + } + + args.KeyGenerator = &cryptoMocks.KeyGenStub{ + PrivateKeyFromByteArrayStub: func(b []byte) (crypto.PrivateKey, error) { + return sk, nil + }, + } + + holder, _ := keysManagement.NewManagedPeersHolder(args) + err := holder.AddManagedPeer([]byte("private key")) + + assert.True(t, errors.Is(err, expectedErr)) + }) + t.Run("p2p key generation returns an invalid private key creation errors", func(t *testing.T) { + args := createMockArgsManagedPeersHolder() + args.P2PKeyGenerator = &cryptoMocks.KeyGenStub{ + GeneratePairStub: func() (crypto.PrivateKey, crypto.PublicKey) { + return &cryptoMocks.PrivateKeyStub{ + ToByteArrayStub: func() ([]byte, error) { + return nil, expectedErr + }, + }, &cryptoMocks.PublicKeyStub{} + }, + } + + holder, _ := keysManagement.NewManagedPeersHolder(args) + err := holder.AddManagedPeer([]byte("private key")) + + assert.True(t, errors.Is(err, expectedErr)) + }) + t.Run("should work for a new pk", func(t *testing.T) { + args := createMockArgsManagedPeersHolder() + + holder, _ := keysManagement.NewManagedPeersHolder(args) + err := holder.AddManagedPeer(skBytes0) + assert.Nil(t, err) + + pInfo := holder.GetPeerInfo(pkBytes0) + assert.NotNil(t, pInfo) + assert.Equal(t, pid, pInfo.Pid()) + assert.Equal(t, p2pPrivateKey, pInfo.P2pPrivateKeyBytes()) + skBytesRecovered, _ := pInfo.PrivateKey().ToByteArray() + assert.Equal(t, skBytes0, skBytesRecovered) + assert.Equal(t, 10, len(pInfo.MachineID())) + assert.Equal(t, defaultIdentity, pInfo.NodeIdentity()) + assert.Equal(t, defaultName, pInfo.NodeName()) + }) + t.Run("should work for a new pk with identity from config", func(t *testing.T) { + providedAddress := []byte("erd1qyu5wthldzr8wx5c9ucg8kjagg0jfs53s8nr3zpz3hypefsdd8ssycr6th") + providedHex := hex.EncodeToString(providedAddress) + providedName := "provided name" + providedIdentity := "provided identity" + args := createMockArgsManagedPeersHolder() + args.KeyGenerator = &cryptoMocks.KeyGenStub{ + PrivateKeyFromByteArrayStub: func(b []byte) (crypto.PrivateKey, error) { + return &cryptoMocks.PrivateKeyStub{ + GeneratePublicStub: func() crypto.PublicKey { + return &cryptoMocks.PublicKeyStub{ + ToByteArrayStub: func() ([]byte, error) { + return providedAddress, nil + }, + } + }, + ToByteArrayStub: func() ([]byte, error) { + return 
providedAddress, nil + }, + }, nil + }, + } + args.PrefsConfig.NamedIdentity = []config.NamedIdentity{ + { + Identity: providedIdentity, + NodeName: providedName, + BLSKeys: []string{providedHex}, + }, + } + + holder, _ := keysManagement.NewManagedPeersHolder(args) + err := holder.AddManagedPeer(skBytes0) + assert.Nil(t, err) + + pInfo := holder.GetPeerInfo(providedAddress) + assert.NotNil(t, pInfo) + assert.Equal(t, pid, pInfo.Pid()) + assert.Equal(t, p2pPrivateKey, pInfo.P2pPrivateKeyBytes()) + skBytesRecovered, _ := pInfo.PrivateKey().ToByteArray() + assert.Equal(t, providedAddress, skBytesRecovered) + assert.Equal(t, 10, len(pInfo.MachineID())) + assert.Equal(t, providedIdentity, pInfo.NodeIdentity()) + assert.Equal(t, providedName, pInfo.NodeName()) + }) + t.Run("should error when trying to add the same pk", func(t *testing.T) { + args := createMockArgsManagedPeersHolder() + + holder, _ := keysManagement.NewManagedPeersHolder(args) + err := holder.AddManagedPeer(skBytes0) + assert.Nil(t, err) + + err = holder.AddManagedPeer(skBytes0) + assert.True(t, errors.Is(err, keysManagement.ErrDuplicatedKey)) + }) + t.Run("should work for 2 new public keys", func(t *testing.T) { + args := createMockArgsManagedPeersHolder() + + holder, _ := keysManagement.NewManagedPeersHolder(args) + err := holder.AddManagedPeer(skBytes0) + assert.Nil(t, err) + + err = holder.AddManagedPeer(skBytes1) + assert.Nil(t, err) + + pInfo0 := holder.GetPeerInfo(pkBytes0) + assert.NotNil(t, pInfo0) + + pInfo1 := holder.GetPeerInfo(pkBytes1) + assert.NotNil(t, pInfo1) + + assert.Equal(t, p2pPrivateKey, pInfo0.P2pPrivateKeyBytes()) + assert.Equal(t, p2pPrivateKey, pInfo1.P2pPrivateKeyBytes()) + + assert.Equal(t, pid, pInfo0.Pid()) + assert.Equal(t, pid, pInfo1.Pid()) + + skBytesRecovered0, _ := pInfo0.PrivateKey().ToByteArray() + assert.Equal(t, skBytes0, skBytesRecovered0) + + skBytesRecovered1, _ := pInfo1.PrivateKey().ToByteArray() + assert.Equal(t, skBytes1, skBytesRecovered1) + + assert.NotEqual(t, pInfo0.MachineID(), pInfo1.MachineID()) + assert.Equal(t, 10, len(pInfo0.MachineID())) + assert.Equal(t, 10, len(pInfo1.MachineID())) + }) +} + +func TestManagedPeersHolder_GetPrivateKey(t *testing.T) { + t.Parallel() + + args := createMockArgsManagedPeersHolder() + + holder, _ := keysManagement.NewManagedPeersHolder(args) + _ = holder.AddManagedPeer(skBytes0) + t.Run("public key not added should error", func(t *testing.T) { + skRecovered, err := holder.GetPrivateKey(pkBytes1) + assert.Nil(t, skRecovered) + assert.True(t, errors.Is(err, keysManagement.ErrMissingPublicKeyDefinition)) + }) + t.Run("public key exists should return the private key", func(t *testing.T) { + skRecovered, err := holder.GetPrivateKey(pkBytes0) + assert.Nil(t, err) + + skBytesRecovered, _ := skRecovered.ToByteArray() + assert.Equal(t, skBytes0, skBytesRecovered) + }) +} + +func TestManagedPeersHolder_GetP2PIdentity(t *testing.T) { + t.Parallel() + + args := createMockArgsManagedPeersHolder() + + holder, _ := keysManagement.NewManagedPeersHolder(args) + _ = holder.AddManagedPeer(skBytes0) + t.Run("public key not added should error", func(t *testing.T) { + p2pPrivateKeyRecovered, pidRecovered, err := holder.GetP2PIdentity(pkBytes1) + assert.Nil(t, p2pPrivateKeyRecovered) + assert.Empty(t, pidRecovered) + assert.True(t, errors.Is(err, keysManagement.ErrMissingPublicKeyDefinition)) + }) + t.Run("public key exists should return the p2p identity", func(t *testing.T) { + p2pPrivateKeyRecovered, pidRecovered, err := holder.GetP2PIdentity(pkBytes0) + 
assert.Nil(t, err) + assert.Equal(t, p2pPrivateKey, p2pPrivateKeyRecovered) + assert.Equal(t, pid, pidRecovered) + }) +} + +func TestManagedPeersHolder_GetMachineID(t *testing.T) { + t.Parallel() + + args := createMockArgsManagedPeersHolder() + holder, _ := keysManagement.NewManagedPeersHolder(args) + _ = holder.AddManagedPeer(skBytes0) + t.Run("public key not added should error", func(t *testing.T) { + machineIDRecovered, err := holder.GetMachineID(pkBytes1) + assert.Empty(t, machineIDRecovered) + assert.True(t, errors.Is(err, keysManagement.ErrMissingPublicKeyDefinition)) + }) + t.Run("public key exists should return machine ID", func(t *testing.T) { + machineIDRecovered, err := holder.GetMachineID(pkBytes0) + assert.Nil(t, err) + assert.Equal(t, 10, len(machineIDRecovered)) + }) +} + +func TestManagedPeersHolder_GetNameAndIdentity(t *testing.T) { + t.Parallel() + + args := createMockArgsManagedPeersHolder() + holder, _ := keysManagement.NewManagedPeersHolder(args) + _ = holder.AddManagedPeer(skBytes0) + t.Run("public key not added should error", func(t *testing.T) { + name, identity, err := holder.GetNameAndIdentity(pkBytes1) + assert.Empty(t, name) + assert.Empty(t, identity) + assert.True(t, errors.Is(err, keysManagement.ErrMissingPublicKeyDefinition)) + }) + t.Run("public key exists should return name and identity", func(t *testing.T) { + name, identity, err := holder.GetNameAndIdentity(pkBytes0) + assert.Nil(t, err) + assert.Equal(t, defaultName, name) + assert.Equal(t, defaultIdentity, identity) + }) +} + +func TestManagedPeersHolder_IncrementRoundsWithoutReceivedMessages(t *testing.T) { + t.Parallel() + + t.Run("is main machine should ignore the call", func(t *testing.T) { + args := createMockArgsManagedPeersHolder() + args.IsMainMachine = true + holder, _ := keysManagement.NewManagedPeersHolder(args) + _ = holder.AddManagedPeer(skBytes0) + + t.Run("missing public key should not panic", func(t *testing.T) { + defer func() { + r := recover() + if r != nil { + assert.Fail(t, fmt.Sprintf("should have not panicked %v", r)) + } + }() + + holder.IncrementRoundsWithoutReceivedMessages(pkBytes1) + pInfoRecovered := holder.GetPeerInfo(pkBytes1) + assert.Nil(t, pInfoRecovered) + }) + t.Run("existing public key", func(t *testing.T) { + holder.IncrementRoundsWithoutReceivedMessages(pkBytes0) + + pInfoRecovered := holder.GetPeerInfo(pkBytes0) + assert.Zero(t, pInfoRecovered.GetRoundsWithoutReceivedMessages()) + }) + }) + t.Run("is secondary machine should increment, if existing", func(t *testing.T) { + args := createMockArgsManagedPeersHolder() + args.IsMainMachine = false + holder, _ := keysManagement.NewManagedPeersHolder(args) + _ = holder.AddManagedPeer(skBytes0) + + t.Run("missing public key should not panic", func(t *testing.T) { + defer func() { + r := recover() + if r != nil { + assert.Fail(t, fmt.Sprintf("should have not panicked %v", r)) + } + }() + + holder.IncrementRoundsWithoutReceivedMessages(pkBytes1) + pInfoRecovered := holder.GetPeerInfo(pkBytes1) + assert.Nil(t, pInfoRecovered) + }) + t.Run("existing public key should increment", func(t *testing.T) { + pInfoRecovered := holder.GetPeerInfo(pkBytes0) + assert.Zero(t, pInfoRecovered.GetRoundsWithoutReceivedMessages()) + + holder.IncrementRoundsWithoutReceivedMessages(pkBytes0) + + pInfoRecovered = holder.GetPeerInfo(pkBytes0) + assert.Equal(t, 1, pInfoRecovered.GetRoundsWithoutReceivedMessages()) + + holder.IncrementRoundsWithoutReceivedMessages(pkBytes0) + + pInfoRecovered = holder.GetPeerInfo(pkBytes0) + assert.Equal(t, 2, 
pInfoRecovered.GetRoundsWithoutReceivedMessages()) + }) + }) +} + +func TestManagedPeersHolder_ResetRoundsWithoutReceivedMessages(t *testing.T) { + t.Parallel() + + t.Run("is main machine should ignore the call", func(t *testing.T) { + args := createMockArgsManagedPeersHolder() + args.IsMainMachine = true + holder, _ := keysManagement.NewManagedPeersHolder(args) + _ = holder.AddManagedPeer(skBytes0) + + t.Run("missing public key should not panic", func(t *testing.T) { + defer func() { + r := recover() + if r != nil { + assert.Fail(t, fmt.Sprintf("should have not panicked %v", r)) + } + }() + + holder.ResetRoundsWithoutReceivedMessages(pkBytes1) + }) + t.Run("existing public key", func(t *testing.T) { + holder.ResetRoundsWithoutReceivedMessages(pkBytes0) + + pInfoRecovered := holder.GetPeerInfo(pkBytes0) + assert.Zero(t, pInfoRecovered.GetRoundsWithoutReceivedMessages()) + }) + }) + t.Run("is secondary machine should reset, if existing", func(t *testing.T) { + args := createMockArgsManagedPeersHolder() + args.IsMainMachine = false + holder, _ := keysManagement.NewManagedPeersHolder(args) + _ = holder.AddManagedPeer(skBytes0) + + t.Run("missing public key should not panic", func(t *testing.T) { + defer func() { + r := recover() + if r != nil { + assert.Fail(t, fmt.Sprintf("should have not panicked %v", r)) + } + }() + + holder.ResetRoundsWithoutReceivedMessages(pkBytes1) + }) + t.Run("existing public key should reset", func(t *testing.T) { + pInfoRecovered := holder.GetPeerInfo(pkBytes0) + assert.Zero(t, pInfoRecovered.GetRoundsWithoutReceivedMessages()) + + holder.IncrementRoundsWithoutReceivedMessages(pkBytes0) + pInfoRecovered = holder.GetPeerInfo(pkBytes0) + assert.Equal(t, 1, pInfoRecovered.GetRoundsWithoutReceivedMessages()) + + holder.ResetRoundsWithoutReceivedMessages(pkBytes0) + + pInfoRecovered = holder.GetPeerInfo(pkBytes0) + assert.Equal(t, 0, pInfoRecovered.GetRoundsWithoutReceivedMessages()) + + holder.IncrementRoundsWithoutReceivedMessages(pkBytes0) + holder.IncrementRoundsWithoutReceivedMessages(pkBytes0) + holder.IncrementRoundsWithoutReceivedMessages(pkBytes0) + pInfoRecovered = holder.GetPeerInfo(pkBytes0) + assert.Equal(t, 3, pInfoRecovered.GetRoundsWithoutReceivedMessages()) + + holder.ResetRoundsWithoutReceivedMessages(pkBytes0) + + pInfoRecovered = holder.GetPeerInfo(pkBytes0) + assert.Equal(t, 0, pInfoRecovered.GetRoundsWithoutReceivedMessages()) + }) + }) +} + +func TestManagedPeersHolder_GetManagedKeysByCurrentNode(t *testing.T) { + t.Parallel() + + t.Run("main machine should return all keys, always", func(t *testing.T) { + args := createMockArgsManagedPeersHolder() + args.IsMainMachine = true + holder, _ := keysManagement.NewManagedPeersHolder(args) + _ = holder.AddManagedPeer(skBytes0) + _ = holder.AddManagedPeer(skBytes1) + + for i := 0; i < 10; i++ { + holder.IncrementRoundsWithoutReceivedMessages(pkBytes0) + } + + result := holder.GetManagedKeysByCurrentNode() + testManagedKeys(t, result, pkBytes0, pkBytes1) + }) + t.Run("is secondary machine should return managed keys", func(t *testing.T) { + args := createMockArgsManagedPeersHolder() + args.IsMainMachine = false + args.MaxRoundsWithoutReceivedMessages = 2 + holder, _ := keysManagement.NewManagedPeersHolder(args) + _ = holder.AddManagedPeer(skBytes0) + _ = holder.AddManagedPeer(skBytes1) + + t.Run("MaxRoundsWithoutReceivedMessages not reached should return none", func(t *testing.T) { + result := holder.GetManagedKeysByCurrentNode() + testManagedKeys(t, result) + + 
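+			// a single missed round keeps the key active on the main machine (the maximum allowed here is 2), so the secondary machine still returns no keys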
holder.IncrementRoundsWithoutReceivedMessages(pkBytes0) + + result = holder.GetManagedKeysByCurrentNode() + testManagedKeys(t, result) + }) + t.Run("MaxRoundsWithoutReceivedMessages reached, should return failed pk", func(t *testing.T) { + holder.IncrementRoundsWithoutReceivedMessages(pkBytes0) + + result := holder.GetManagedKeysByCurrentNode() + testManagedKeys(t, result, pkBytes0) + + holder.IncrementRoundsWithoutReceivedMessages(pkBytes0) + result = holder.GetManagedKeysByCurrentNode() + testManagedKeys(t, result, pkBytes0) + }) + }) +} + +func TestManagedPeersHolder_IsKeyManagedByCurrentNode(t *testing.T) { + t.Parallel() + + t.Run("main machine", func(t *testing.T) { + args := createMockArgsManagedPeersHolder() + args.IsMainMachine = true + holder, _ := keysManagement.NewManagedPeersHolder(args) + _ = holder.AddManagedPeer(skBytes0) + + t.Run("foreign public key should return false", func(t *testing.T) { + isManaged := holder.IsKeyManagedByCurrentNode(pkBytes1) + assert.False(t, isManaged) + }) + t.Run("managed key should return true", func(t *testing.T) { + isManaged := holder.IsKeyManagedByCurrentNode(pkBytes0) + assert.True(t, isManaged) + }) + }) + t.Run("secondary machine", func(t *testing.T) { + args := createMockArgsManagedPeersHolder() + args.IsMainMachine = false + args.MaxRoundsWithoutReceivedMessages = 2 + holder, _ := keysManagement.NewManagedPeersHolder(args) + _ = holder.AddManagedPeer(skBytes0) + + t.Run("foreign public key should return false", func(t *testing.T) { + isManaged := holder.IsKeyManagedByCurrentNode(pkBytes1) + assert.False(t, isManaged) + }) + t.Run("managed key should return false while MaxRoundsWithoutReceivedMessages is not reached", func(t *testing.T) { + isManaged := holder.IsKeyManagedByCurrentNode(pkBytes0) + assert.False(t, isManaged) + + holder.IncrementRoundsWithoutReceivedMessages(pkBytes0) + isManaged = holder.IsKeyManagedByCurrentNode(pkBytes0) + assert.False(t, isManaged) + + holder.IncrementRoundsWithoutReceivedMessages(pkBytes0) + isManaged = holder.IsKeyManagedByCurrentNode(pkBytes0) + assert.True(t, isManaged) + + holder.ResetRoundsWithoutReceivedMessages(pkBytes0) + isManaged = holder.IsKeyManagedByCurrentNode(pkBytes0) + assert.False(t, isManaged) + }) + }) +} + +func TestManagedPeersHolder_IsKeyRegistered(t *testing.T) { + t.Parallel() + + t.Run("main machine", func(t *testing.T) { + args := createMockArgsManagedPeersHolder() + args.IsMainMachine = true + holder, _ := keysManagement.NewManagedPeersHolder(args) + _ = holder.AddManagedPeer(skBytes0) + + t.Run("foreign public key should return false", func(t *testing.T) { + isManaged := holder.IsKeyRegistered(pkBytes1) + assert.False(t, isManaged) + }) + t.Run("registered key should return true", func(t *testing.T) { + isManaged := holder.IsKeyRegistered(pkBytes0) + assert.True(t, isManaged) + }) + }) + t.Run("secondary machine", func(t *testing.T) { + args := createMockArgsManagedPeersHolder() + args.IsMainMachine = false + holder, _ := keysManagement.NewManagedPeersHolder(args) + _ = holder.AddManagedPeer(skBytes0) + + t.Run("foreign public key should return false", func(t *testing.T) { + isManaged := holder.IsKeyRegistered(pkBytes1) + assert.False(t, isManaged) + }) + t.Run("registered key should return true", func(t *testing.T) { + isManaged := holder.IsKeyRegistered(pkBytes0) + assert.True(t, isManaged) + }) + }) +} + +func TestManagedPeersHolder_IsPidManagedByCurrentNode(t *testing.T) { + t.Parallel() + + args := createMockArgsManagedPeersHolder() + args.IsMainMachine = true + 
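+	// a peer ID becomes managed only after AddManagedPeer registers the generated p2p identity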
holder, _ := keysManagement.NewManagedPeersHolder(args) + + t.Run("empty holder should return false", func(t *testing.T) { + isManaged := holder.IsPidManagedByCurrentNode(pid) + assert.False(t, isManaged) + }) + + _ = holder.AddManagedPeer(skBytes0) + + t.Run("pid not managed by current should return false", func(t *testing.T) { + isManaged := holder.IsPidManagedByCurrentNode("other pid") + assert.False(t, isManaged) + }) + t.Run("pid managed by current should return true", func(t *testing.T) { + isManaged := holder.IsPidManagedByCurrentNode(pid) + assert.True(t, isManaged) + }) +} + +func TestManagedPeersHolder_IsKeyValidator(t *testing.T) { + t.Parallel() + + holder, _ := keysManagement.NewManagedPeersHolder(createMockArgsManagedPeersHolder()) + + t.Run("missing key should return false", func(t *testing.T) { + isValidator := holder.IsKeyValidator(pkBytes0) + assert.False(t, isValidator) + }) + + _ = holder.AddManagedPeer(skBytes0) + + t.Run("key found, but not validator should return false", func(t *testing.T) { + isValidator := holder.IsKeyValidator(pkBytes0) + assert.False(t, isValidator) + }) + t.Run("key found and validator should return true", func(t *testing.T) { + holder.SetValidatorState(pkBytes0, true) + isValidator := holder.IsKeyValidator(pkBytes0) + assert.True(t, isValidator) + }) +} + +func TestManagedPeersHolder_GetNextPeerAuthenticationTime(t *testing.T) { + t.Parallel() + + holder, _ := keysManagement.NewManagedPeersHolder(createMockArgsManagedPeersHolder()) + + t.Run("missing key should return error", func(t *testing.T) { + timeBefore := time.Now() + nextTime, err := holder.GetNextPeerAuthenticationTime(pkBytes0) + timeAfter := time.Now() + assert.NotNil(t, err) + assert.True(t, errors.Is(err, keysManagement.ErrMissingPublicKeyDefinition)) + assert.True(t, strings.Contains(err.Error(), hex.EncodeToString(pkBytes0))) + assert.LessOrEqual(t, nextTime, timeAfter) + assert.Greater(t, nextTime, timeBefore) + }) + + _ = holder.AddManagedPeer(skBytes0) + + t.Run("key found should work", func(t *testing.T) { + expectedNextTime := time.Now().Add(time.Hour) + holder.SetNextPeerAuthenticationTime(pkBytes0, expectedNextTime) + nextTime, err := holder.GetNextPeerAuthenticationTime(pkBytes0) + assert.Nil(t, err) + assert.Equal(t, expectedNextTime, nextTime) + }) +} + +func TestManagedPeersHolder_IsMultiKeyMode(t *testing.T) { + t.Parallel() + + args := createMockArgsManagedPeersHolder() + holder, _ := keysManagement.NewManagedPeersHolder(args) + assert.False(t, holder.IsMultiKeyMode()) + + _ = holder.AddManagedPeer(skBytes0) + assert.True(t, holder.IsMultiKeyMode()) +} + +func TestManagedPeersHolder_ParallelOperationsShouldNotPanic(t *testing.T) { + defer func() { + r := recover() + if r != nil { + assert.Fail(t, fmt.Sprintf("should have not panicked %v", r)) + } + }() + + args := createMockArgsManagedPeersHolder() + holder, _ := keysManagement.NewManagedPeersHolder(args) + + numOperations := 1500 + wg := sync.WaitGroup{} + wg.Add(numOperations) + for i := 0; i < numOperations; i++ { + go func(idOperation int) { + time.Sleep(time.Millisecond * 10) // increase the chance of concurrent operations + + switch idOperation { + case 0: + randomBytes := make([]byte, 32) + _, _ = rand.Read(randomBytes) + _ = holder.AddManagedPeer(randomBytes) + case 1: + _, _ = holder.GetMachineID(pkBytes1) + case 2: + _, _, _ = holder.GetP2PIdentity(pkBytes1) + case 3: + _, _ = holder.GetPrivateKey(pkBytes1) + case 4: + holder.IncrementRoundsWithoutReceivedMessages(pkBytes0) + case 5: + 
holder.ResetRoundsWithoutReceivedMessages(pkBytes0) + case 6: + _ = holder.GetManagedKeysByCurrentNode() + case 7: + _ = holder.IsKeyManagedByCurrentNode(pkBytes0) + case 8: + _ = holder.IsKeyRegistered(pkBytes0) + case 9: + _ = holder.IsPidManagedByCurrentNode("pid") + case 10: + _ = holder.IsKeyValidator(pkBytes0) + case 11: + holder.SetValidatorState(pkBytes0, true) + case 12: + _, _ = holder.GetNextPeerAuthenticationTime(pkBytes0) + case 13: + holder.SetNextPeerAuthenticationTime(pkBytes0, time.Now()) + } + + wg.Done() + }(i % 14) + } + + wg.Wait() +} diff --git a/keysManagement/peerInfo.go b/keysManagement/peerInfo.go new file mode 100644 index 00000000000..98bf994b55b --- /dev/null +++ b/keysManagement/peerInfo.go @@ -0,0 +1,70 @@ +package keysManagement + +import ( + "sync" + "time" + + "github.com/multiversx/mx-chain-core-go/core" + crypto "github.com/multiversx/mx-chain-crypto-go" +) + +type peerInfo struct { + pid core.PeerID + p2pPrivateKeyBytes []byte + privateKey crypto.PrivateKey + machineID string + nodeName string + nodeIdentity string + + mutChangeableData sync.RWMutex + roundsWithoutReceivedMessages int + nextPeerAuthenticationTime time.Time + isValidator bool +} + +func (pInfo *peerInfo) incrementRoundsWithoutReceivedMessages() { + pInfo.mutChangeableData.Lock() + pInfo.roundsWithoutReceivedMessages++ + pInfo.mutChangeableData.Unlock() +} + +func (pInfo *peerInfo) resetRoundsWithoutReceivedMessages() { + pInfo.mutChangeableData.Lock() + pInfo.roundsWithoutReceivedMessages = 0 + pInfo.mutChangeableData.Unlock() +} + +func (pInfo *peerInfo) isNodeActiveOnMainMachine(maxRoundsWithoutReceivedMessages int) bool { + pInfo.mutChangeableData.RLock() + defer pInfo.mutChangeableData.RUnlock() + + return pInfo.roundsWithoutReceivedMessages < maxRoundsWithoutReceivedMessages +} + +func (pInfo *peerInfo) isNodeValidator() bool { + pInfo.mutChangeableData.RLock() + defer pInfo.mutChangeableData.RUnlock() + + return pInfo.isValidator +} + +func (pInfo *peerInfo) setNodeValidator(value bool) { + pInfo.mutChangeableData.Lock() + defer pInfo.mutChangeableData.Unlock() + + pInfo.isValidator = value +} + +func (pInfo *peerInfo) getNextPeerAuthenticationTime() time.Time { + pInfo.mutChangeableData.RLock() + defer pInfo.mutChangeableData.RUnlock() + + return pInfo.nextPeerAuthenticationTime +} + +func (pInfo *peerInfo) setNextPeerAuthenticationTime(value time.Time) { + pInfo.mutChangeableData.Lock() + defer pInfo.mutChangeableData.Unlock() + + pInfo.nextPeerAuthenticationTime = value +} diff --git a/node/external/transactionAPI/apiTransactionProcessor.go b/node/external/transactionAPI/apiTransactionProcessor.go index 48391904536..30f0ff35843 100644 --- a/node/external/transactionAPI/apiTransactionProcessor.go +++ b/node/external/transactionAPI/apiTransactionProcessor.go @@ -64,7 +64,7 @@ func NewAPITransactionProcessor(args *ArgAPITransactionProcessor) (*apiTransacti ) refundDetector := newRefundDetector() - gasUsedAndFeeProc := newGasUsedAndFeeProcessor(args.FeeComputer) + gasUsedAndFeeProc := newGasUsedAndFeeProcessor(args.FeeComputer, args.AddressPubKeyConverter) return &apiTransactionProcessor{ roundDuration: args.RoundDuration, @@ -102,6 +102,10 @@ func (atp *apiTransactionProcessor) GetTransaction(txHash string, withResults bo atp.PopulateComputedFields(tx) atp.gasUsedAndFeeProcessor.computeAndAttachGasUsedAndFee(tx) + if withResults { + atp.gasUsedAndFeeProcessor.computeAndAttachGasUsedAndFee(tx) + } + return tx, nil } @@ -333,7 +337,7 @@ func (atp *apiTransactionProcessor) 
extractRequestedTxInfo(wrappedTx *txcache.Wr tx.TxFields[dataField] = wrappedTx.Tx.GetData() } if requestedFieldsHandler.HasValue { - tx.TxFields[valueField] = wrappedTx.Tx.GetValue() + tx.TxFields[valueField] = getTxValue(wrappedTx) } return tx @@ -671,6 +675,14 @@ func (atp *apiTransactionProcessor) castObjToTransaction(txObj interface{}, txTy return &transaction.ApiTransactionResult{Type: string(transaction.TxTypeInvalid)} } +func getTxValue(wrappedTx *txcache.WrappedTransaction) string { + txValue := wrappedTx.Tx.GetValue() + if txValue != nil { + return txValue.String() + } + return "0" +} + // UnmarshalTransaction will try to unmarshal the transaction bytes based on the transaction type func (atp *apiTransactionProcessor) UnmarshalTransaction(txBytes []byte, txType transaction.TxType) (*transaction.ApiTransactionResult, error) { tx, err := atp.txUnmarshaller.unmarshalTransaction(txBytes, txType) diff --git a/node/external/transactionAPI/apiTransactionProcessor_test.go b/node/external/transactionAPI/apiTransactionProcessor_test.go index 0c81f96b489..a469cb7aef0 100644 --- a/node/external/transactionAPI/apiTransactionProcessor_test.go +++ b/node/external/transactionAPI/apiTransactionProcessor_test.go @@ -748,6 +748,7 @@ func createTx(hash []byte, sender string, nonce uint64) *txcache.WrappedTransact tx := &transaction.Transaction{ SndAddr: []byte(sender), Nonce: nonce, + Value: big.NewInt(100000 + int64(nonce)), } return &txcache.WrappedTransaction{ @@ -821,11 +822,13 @@ func TestApiTransactionProcessor_GetTransactionsPoolForSender(t *testing.T) { require.NoError(t, err) require.NotNil(t, atp) - res, err := atp.GetTransactionsPoolForSender(sender, "sender") + res, err := atp.GetTransactionsPoolForSender(sender, "sender,value") require.NoError(t, err) expectedHashes := []string{hex.EncodeToString(txHash0), hex.EncodeToString(txHash1), hex.EncodeToString(txHash2), hex.EncodeToString(txHash3), hex.EncodeToString(txHash4)} + expectedValues := []string{"100001", "100002", "100003", "100004", "100005"} for i, tx := range res.Transactions { require.Equal(t, expectedHashes[i], tx.TxFields[hashField]) + require.Equal(t, expectedValues[i], tx.TxFields[valueField]) require.Equal(t, sender, tx.TxFields["sender"]) } diff --git a/node/external/transactionAPI/gasUsedAndFeeProcessor.go b/node/external/transactionAPI/gasUsedAndFeeProcessor.go index f5addcedaed..42290a7db19 100644 --- a/node/external/transactionAPI/gasUsedAndFeeProcessor.go +++ b/node/external/transactionAPI/gasUsedAndFeeProcessor.go @@ -8,12 +8,14 @@ import ( ) type gasUsedAndFeeProcessor struct { - feeComputer feeComputer + feeComputer feeComputer + pubKeyConverter core.PubkeyConverter } -func newGasUsedAndFeeProcessor(txFeeCalculator feeComputer) *gasUsedAndFeeProcessor { +func newGasUsedAndFeeProcessor(txFeeCalculator feeComputer, pubKeyConverter core.PubkeyConverter) *gasUsedAndFeeProcessor { return &gasUsedAndFeeProcessor{ - feeComputer: txFeeCalculator, + feeComputer: txFeeCalculator, + pubKeyConverter: pubKeyConverter, } } @@ -24,7 +26,7 @@ func (gfp *gasUsedAndFeeProcessor) computeAndAttachGasUsedAndFee(tx *transaction tx.GasUsed = gasUsed tx.Fee = fee.String() - if tx.IsRelayed { + if tx.IsRelayed || gfp.isESDTOperationWithSCCall(tx) { tx.GasUsed = tx.GasLimit tx.Fee = tx.InitiallyPaidFee } @@ -74,3 +76,33 @@ func (gfp *gasUsedAndFeeProcessor) setGasUsedAndFeeBaseOnRefundValue(tx *transac tx.GasUsed = gasUsed tx.Fee = fee.String() } + +func (gfp *gasUsedAndFeeProcessor) isESDTOperationWithSCCall(tx 
*transaction.ApiTransactionResult) bool { + isESDTTransferOperation := tx.Operation == core.BuiltInFunctionESDTTransfer || + tx.Operation == core.BuiltInFunctionESDTNFTTransfer || tx.Operation == core.BuiltInFunctionMultiESDTNFTTransfer + + isReceiverSC := core.IsSmartContractAddress(tx.Tx.GetRcvAddr()) + hasFunction := tx.Function != "" + if !hasFunction { + return false + } + + if tx.Sender != tx.Receiver { + return isESDTTransferOperation && isReceiverSC && hasFunction + } + + if len(tx.Receivers) == 0 { + return false + } + + receiver := tx.Receivers[0] + decodedReceiver, err := gfp.pubKeyConverter.Decode(receiver) + if err != nil { + log.Warn("gasUsedAndFeeProcessor.isESDTOperationWithSCCall cannot decode receiver address", "error", err.Error()) + return false + } + + isReceiverSC = core.IsSmartContractAddress(decodedReceiver) + + return isESDTTransferOperation && isReceiverSC && hasFunction +} diff --git a/node/external/transactionAPI/gasUsedAndFeeProcessor_test.go b/node/external/transactionAPI/gasUsedAndFeeProcessor_test.go index f79ab3c2d4e..c6f73e074d8 100644 --- a/node/external/transactionAPI/gasUsedAndFeeProcessor_test.go +++ b/node/external/transactionAPI/gasUsedAndFeeProcessor_test.go @@ -24,7 +24,7 @@ func TestComputeTransactionGasUsedAndFeeMoveBalance(t *testing.T) { }) computer := fee.NewTestFeeComputer(feeComp) - gasUsedAndFeeProc := newGasUsedAndFeeProcessor(computer) + gasUsedAndFeeProc := newGasUsedAndFeeProcessor(computer, pubKeyConverter) sender := "erd1wc3uh22g2aved3qeehkz9kzgrjwxhg9mkkxp2ee7jj7ph34p2csq0n2y5x" receiver := "erd1wc3uh22g2aved3qeehkz9kzgrjwxhg9mkkxp2ee7jj7ph34p2csq0n2y5x" @@ -53,7 +53,7 @@ func TestComputeTransactionGasUsedAndFeeLogWithError(t *testing.T) { }) computer := fee.NewTestFeeComputer(feeComp) - gasUsedAndFeeProc := newGasUsedAndFeeProcessor(computer) + gasUsedAndFeeProc := newGasUsedAndFeeProcessor(computer, pubKeyConverter) sender := "erd1wc3uh22g2aved3qeehkz9kzgrjwxhg9mkkxp2ee7jj7ph34p2csq0n2y5x" receiver := "erd1wc3uh22g2aved3qeehkz9kzgrjwxhg9mkkxp2ee7jj7ph34p2csq0n2y5x" @@ -95,7 +95,7 @@ func TestComputeTransactionGasUsedAndFeeRelayedTxWithWriteLog(t *testing.T) { }) computer := fee.NewTestFeeComputer(feeComp) - gasUsedAndFeeProc := newGasUsedAndFeeProcessor(computer) + gasUsedAndFeeProc := newGasUsedAndFeeProcessor(computer, pubKeyConverter) sender := "erd1wc3uh22g2aved3qeehkz9kzgrjwxhg9mkkxp2ee7jj7ph34p2csq0n2y5x" receiver := "erd1wc3uh22g2aved3qeehkz9kzgrjwxhg9mkkxp2ee7jj7ph34p2csq0n2y5x" @@ -132,7 +132,7 @@ func TestComputeTransactionGasUsedAndFeeTransactionWithScrWithRefund(t *testing. }) computer := fee.NewTestFeeComputer(feeComp) - gasUsedAndFeeProc := newGasUsedAndFeeProcessor(computer) + gasUsedAndFeeProc := newGasUsedAndFeeProcessor(computer, pubKeyConverter) sender := "erd1wc3uh22g2aved3qeehkz9kzgrjwxhg9mkkxp2ee7jj7ph34p2csq0n2y5x" receiver := "erd1wc3uh22g2aved3qeehkz9kzgrjwxhg9mkkxp2ee7jj7ph34p2csq0n2y5x" @@ -166,3 +166,36 @@ func TestComputeTransactionGasUsedAndFeeTransactionWithScrWithRefund(t *testing. 
require.Equal(uint64(3_365_000), txWithSRefundSCR.GasUsed) require.Equal("98000000000000", txWithSRefundSCR.Fee) } + +func TestNFTTransferWithScCall(t *testing.T) { + require := require.New(t) + feeComp, _ := fee.NewFeeComputer(fee.ArgsNewFeeComputer{ + BuiltInFunctionsCostHandler: &testscommon.BuiltInCostHandlerStub{}, + EconomicsConfig: testscommon.GetEconomicsConfig(), + }) + computer := fee.NewTestFeeComputer(feeComp) + + gasUsedAndFeeProc := newGasUsedAndFeeProcessor(computer, pubKeyConverter) + + sender := "erd1wc3uh22g2aved3qeehkz9kzgrjwxhg9mkkxp2ee7jj7ph34p2csq0n2y5x" + receiver := "erd1wc3uh22g2aved3qeehkz9kzgrjwxhg9mkkxp2ee7jj7ph34p2csq0n2y5x" + + tx := &transaction.ApiTransactionResult{ + Tx: &transaction.Transaction{ + GasLimit: 55_000_000, + GasPrice: 1000000000, + SndAddr: silentDecodeAddress(sender), + RcvAddr: silentDecodeAddress(receiver), + Data: []byte("ESDTNFTTransfer@434f572d636434363364@080c@01@00000000000000000500d3b28828d62052124f07dcd50ed31b0825f60eee1526@616363657074476c6f62616c4f66666572@c3e5q"), + }, + GasLimit: 55_000_000, + Receivers: []string{"erd1qqqqqqqqqqqqqpgq6wegs2xkypfpync8mn2sa5cmpqjlvrhwz5nqgepyg8"}, + Function: "acceptGlobalOffer", + Operation: "ESDTNFTTransfer", + } + tx.InitiallyPaidFee = feeComp.ComputeTransactionFee(tx).String() + + gasUsedAndFeeProc.computeAndAttachGasUsedAndFee(tx) + require.Equal(uint64(55_000_000), tx.GasUsed) + require.Equal("822250000000000", tx.Fee) +} diff --git a/node/mock/factory/cryptoComponentsStub.go b/node/mock/factory/cryptoComponentsStub.go index 573319def67..f7e61374bcf 100644 --- a/node/mock/factory/cryptoComponentsStub.go +++ b/node/mock/factory/cryptoComponentsStub.go @@ -5,6 +5,7 @@ import ( "sync" crypto "github.com/multiversx/mx-chain-crypto-go" + "github.com/multiversx/mx-chain-go/common" cryptoCommon "github.com/multiversx/mx-chain-go/common/crypto" "github.com/multiversx/mx-chain-go/consensus" "github.com/multiversx/mx-chain-go/vm" @@ -12,24 +13,26 @@ import ( // CryptoComponentsMock - type CryptoComponentsMock struct { - PubKey crypto.PublicKey - PrivKey crypto.PrivateKey - P2pPubKey crypto.PublicKey - P2pPrivKey crypto.PrivateKey - P2pSig crypto.SingleSigner - PubKeyString string - PrivKeyBytes []byte - PubKeyBytes []byte - BlockSig crypto.SingleSigner - TxSig crypto.SingleSigner - MultiSigContainer cryptoCommon.MultiSignerContainer - PeerSignHandler crypto.PeerSignatureHandler - BlKeyGen crypto.KeyGenerator - TxKeyGen crypto.KeyGenerator - P2PKeyGen crypto.KeyGenerator - MsgSigVerifier vm.MessageSignVerifier - SigHandler consensus.SignatureHandler - mutMultiSig sync.RWMutex + PubKey crypto.PublicKey + PrivKey crypto.PrivateKey + P2pPubKey crypto.PublicKey + P2pPrivKey crypto.PrivateKey + P2pSig crypto.SingleSigner + PubKeyString string + PrivKeyBytes []byte + PubKeyBytes []byte + BlockSig crypto.SingleSigner + TxSig crypto.SingleSigner + MultiSigContainer cryptoCommon.MultiSignerContainer + PeerSignHandler crypto.PeerSignatureHandler + BlKeyGen crypto.KeyGenerator + TxKeyGen crypto.KeyGenerator + P2PKeyGen crypto.KeyGenerator + MsgSigVerifier vm.MessageSignVerifier + SigHandler consensus.SigningHandler + ManagedPeersHolderField common.ManagedPeersHolder + KeysHandlerField consensus.KeysHandler + mutMultiSig sync.RWMutex } // Create - @@ -154,30 +157,42 @@ func (ccm *CryptoComponentsMock) MessageSignVerifier() vm.MessageSignVerifier { return ccm.MsgSigVerifier } -// ConsensusSigHandler - -func (ccm *CryptoComponentsMock) ConsensusSigHandler() consensus.SignatureHandler { +// ConsensusSigningHandler - 
+func (ccm *CryptoComponentsMock) ConsensusSigningHandler() consensus.SigningHandler { return ccm.SigHandler } +// ManagedPeersHolder - +func (ccm *CryptoComponentsMock) ManagedPeersHolder() common.ManagedPeersHolder { + return ccm.ManagedPeersHolderField +} + +// KeysHandler - +func (ccm *CryptoComponentsMock) KeysHandler() consensus.KeysHandler { + return ccm.KeysHandlerField +} + // Clone - func (ccm *CryptoComponentsMock) Clone() interface{} { return &CryptoComponentsMock{ - PubKey: ccm.PubKey, - P2pPubKey: ccm.P2pPubKey, - PrivKey: ccm.PrivKey, - P2pPrivKey: ccm.P2pPrivKey, - PubKeyString: ccm.PubKeyString, - PrivKeyBytes: ccm.PrivKeyBytes, - PubKeyBytes: ccm.PubKeyBytes, - BlockSig: ccm.BlockSig, - TxSig: ccm.TxSig, - MultiSigContainer: ccm.MultiSigContainer, - PeerSignHandler: ccm.PeerSignHandler, - BlKeyGen: ccm.BlKeyGen, - TxKeyGen: ccm.TxKeyGen, - P2PKeyGen: ccm.P2PKeyGen, - MsgSigVerifier: ccm.MsgSigVerifier, - mutMultiSig: sync.RWMutex{}, + PubKey: ccm.PubKey, + P2pPubKey: ccm.P2pPubKey, + PrivKey: ccm.PrivKey, + P2pPrivKey: ccm.P2pPrivKey, + PubKeyString: ccm.PubKeyString, + PrivKeyBytes: ccm.PrivKeyBytes, + PubKeyBytes: ccm.PubKeyBytes, + BlockSig: ccm.BlockSig, + TxSig: ccm.TxSig, + MultiSigContainer: ccm.MultiSigContainer, + PeerSignHandler: ccm.PeerSignHandler, + BlKeyGen: ccm.BlKeyGen, + TxKeyGen: ccm.TxKeyGen, + P2PKeyGen: ccm.P2PKeyGen, + MsgSigVerifier: ccm.MsgSigVerifier, + KeysHandlerField: ccm.KeysHandlerField, + ManagedPeersHolderField: ccm.ManagedPeersHolderField, + mutMultiSig: sync.RWMutex{}, } } diff --git a/node/nodeDebugFactory/interceptedDebugHandler_test.go b/node/nodeDebugFactory/interceptedDebugHandler_test.go index e56382302cf..786d46e5d3a 100644 --- a/node/nodeDebugFactory/interceptedDebugHandler_test.go +++ b/node/nodeDebugFactory/interceptedDebugHandler_test.go @@ -22,7 +22,7 @@ func TestCreateInterceptedDebugHandler_NilNodeWrapperShouldErr(t *testing.T) { err := CreateInterceptedDebugHandler( nil, &testscommon.InterceptorsContainerStub{}, - &dataRetrieverMocks.ResolversContainerStub{}, + &dataRetrieverTests.ResolversContainerStub{}, &dataRetrieverTests.RequestersContainerStub{}, config.InterceptorResolverDebugConfig{}, ) @@ -36,7 +36,7 @@ func TestCreateInterceptedDebugHandler_NilInterceptorsShouldErr(t *testing.T) { err := CreateInterceptedDebugHandler( &mock.NodeWrapperStub{}, nil, - &dataRetrieverMocks.ResolversContainerStub{}, + &dataRetrieverTests.ResolversContainerStub{}, &dataRetrieverTests.RequestersFinderStub{}, config.InterceptorResolverDebugConfig{}, ) @@ -64,7 +64,7 @@ func TestCreateInterceptedDebugHandler_NilRequestersShouldErr(t *testing.T) { err := CreateInterceptedDebugHandler( &mock.NodeWrapperStub{}, &testscommon.InterceptorsContainerStub{}, - &dataRetrieverMocks.ResolversContainerStub{}, + &dataRetrieverTests.ResolversContainerStub{}, nil, config.InterceptorResolverDebugConfig{}, ) @@ -78,7 +78,7 @@ func TestCreateInterceptedDebugHandler_InvalidDebugConfigShouldErr(t *testing.T) err := CreateInterceptedDebugHandler( &mock.NodeWrapperStub{}, &testscommon.InterceptorsContainerStub{}, - &dataRetrieverMocks.ResolversContainerStub{}, + &dataRetrieverTests.ResolversContainerStub{}, &dataRetrieverTests.RequestersFinderStub{}, config.InterceptorResolverDebugConfig{ Enabled: true, @@ -115,7 +115,7 @@ func TestCreateInterceptedDebugHandler_SettingOnInterceptorsErrShouldErr(t *test interceptorsIterateCalled = true }, }, - &dataRetrieverMocks.ResolversContainerStub{ + &dataRetrieverTests.ResolversContainerStub{ IterateCalled: func(handler 
func(key string, resolver dataRetriever.Resolver) bool) { resolversIterateCalled = true }, @@ -158,7 +158,7 @@ func TestCreateInterceptedDebugHandler_SettingOnResolverErrShouldErr(t *testing. interceptorsIterateCalled = true }, }, - &dataRetrieverMocks.ResolversContainerStub{ + &dataRetrieverTests.ResolversContainerStub{ IterateCalled: func(handler func(key string, resolver dataRetriever.Resolver) bool) { handler("key", &dataRetrieverMocks.HeaderResolverStub{ SetDebugHandlerCalled: func(handler dataRetriever.DebugHandler) error { @@ -205,7 +205,7 @@ func TestCreateInterceptedDebugHandler_ShouldWork(t *testing.T) { interceptorsIterateCalled = true }, }, - &dataRetrieverMocks.ResolversContainerStub{ + &dataRetrieverTests.ResolversContainerStub{ IterateCalled: func(handler func(key string, resolver dataRetriever.Resolver) bool) { handler("key", &dataRetrieverMocks.HeaderResolverStub{}) resolversIterateCalled = true diff --git a/node/nodeRunner.go b/node/nodeRunner.go index ce87af025d8..de0a7e65bd4 100644 --- a/node/nodeRunner.go +++ b/node/nodeRunner.go @@ -314,7 +314,7 @@ func (nr *nodeRunner) executeOneComponentCreationCycle( nr.logInformation(managedCoreComponents, managedCryptoComponents, managedBootstrapComponents) log.Debug("creating data components") - managedDataComponents, err := nr.CreateManagedDataComponents(managedStatusCoreComponents, managedCoreComponents, managedBootstrapComponents) + managedDataComponents, err := nr.CreateManagedDataComponents(managedStatusCoreComponents, managedCoreComponents, managedBootstrapComponents, managedCryptoComponents) if err != nil { return true, err } @@ -892,6 +892,7 @@ func (nr *nodeRunner) CreateManagedHeartbeatV2Components( heartbeatV2Args := heartbeatComp.ArgHeartbeatV2ComponentsFactory{ Config: *nr.configs.GeneralConfig, Prefs: *nr.configs.PreferencesConfig, + BaseVersion: nr.configs.FlagsConfig.BaseVersion, AppVersion: nr.configs.FlagsConfig.Version, BootstrapComponents: bootstrapComponents, CoreComponents: coreComponents, @@ -1217,6 +1218,7 @@ func (nr *nodeRunner) CreateManagedProcessComponents( ImportStartHandler: importStartHandler, WorkingDir: configs.FlagsConfig.WorkingDir, HistoryRepo: historyRepository, + SnapshotsEnabled: configs.FlagsConfig.SnapshotsEnabled, } processComponentsFactory, err := processComp.NewProcessComponentsFactory(processArgs) if err != nil { @@ -1241,6 +1243,7 @@ func (nr *nodeRunner) CreateManagedDataComponents( statusCoreComponents mainFactory.StatusCoreComponentsHolder, coreComponents mainFactory.CoreComponentsHolder, bootstrapComponents mainFactory.BootstrapComponentsHolder, + crypto mainFactory.CryptoComponentsHolder, ) (mainFactory.DataComponentsHandler, error) { configs := nr.configs storerEpoch := bootstrapComponents.EpochBootstrapParams().Epoch() @@ -1256,9 +1259,10 @@ func (nr *nodeRunner) CreateManagedDataComponents( ShardCoordinator: bootstrapComponents.ShardCoordinator(), Core: coreComponents, StatusCore: statusCoreComponents, - EpochStartNotifier: coreComponents.EpochStartNotifierWithConfirm(), + Crypto: crypto, CurrentEpoch: storerEpoch, CreateTrieEpochRootHashStorer: configs.ImportDbConfig.ImportDbSaveTrieEpochRootHash, + SnapshotsEnabled: configs.FlagsConfig.SnapshotsEnabled, } dataComponentsFactory, err := dataComp.NewDataComponentsFactory(dataArgs) @@ -1307,6 +1311,7 @@ func (nr *nodeRunner) CreateManagedStateComponents( StorageService: dataComponents.StorageService(), ProcessingMode: processingMode, ShouldSerializeSnapshots: nr.configs.FlagsConfig.SerializeSnapshots, + SnapshotsEnabled: 
nr.configs.FlagsConfig.SnapshotsEnabled, ChainHandler: dataComponents.Blockchain(), } @@ -1478,13 +1483,15 @@ func (nr *nodeRunner) CreateManagedCryptoComponents( ) (mainFactory.CryptoComponentsHandler, error) { configs := nr.configs validatorKeyPemFileName := configs.ConfigurationPathsHolder.ValidatorKey + allValidatorKeysPemFileName := configs.ConfigurationPathsHolder.AllValidatorKeys cryptoComponentsHandlerArgs := cryptoComp.CryptoComponentsFactoryArgs{ ValidatorKeyPemFileName: validatorKeyPemFileName, + AllValidatorKeysPemFileName: allValidatorKeysPemFileName, SkIndex: configs.FlagsConfig.ValidatorKeyIndex, Config: *configs.GeneralConfig, CoreComponentsHolder: coreComponents, ActivateBLSPubKeyMessageVerification: configs.SystemSCConfig.StakingSystemSCConfig.ActivateBLSPubKeyMessageVerification, - KeyLoader: &core.KeyLoader{}, + KeyLoader: core.NewKeyLoader(), ImportModeNoSigCheck: configs.ImportDbConfig.ImportDbNoSigCheckFlag, IsInImportMode: configs.ImportDbConfig.IsImportDBMode, EnableEpochs: configs.EpochConfig.EnableEpochs, diff --git a/node/nodeTesting_test.go b/node/nodeTesting_test.go index fc00e1474fa..89aa392bfc4 100644 --- a/node/nodeTesting_test.go +++ b/node/nodeTesting_test.go @@ -388,21 +388,23 @@ func TestGenerateAndSendBulkTransactions_ShouldWork(t *testing.T) { func getDefaultCryptoComponents() *factoryMock.CryptoComponentsMock { return &factoryMock.CryptoComponentsMock{ - PubKey: &mock.PublicKeyMock{}, - P2pPubKey: &mock.PublicKeyMock{}, - PrivKey: &mock.PrivateKeyStub{}, - P2pPrivKey: &mock.PrivateKeyStub{}, - PubKeyString: "pubKey", - PrivKeyBytes: []byte("privKey"), - PubKeyBytes: []byte("pubKey"), - BlockSig: &mock.SingleSignerMock{}, - TxSig: &mock.SingleSignerMock{}, - MultiSigContainer: cryptoMocks.NewMultiSignerContainerMock(cryptoMocks.NewMultiSigner()), - PeerSignHandler: &mock.PeerSignatureHandler{}, - BlKeyGen: &mock.KeyGenMock{}, - TxKeyGen: &mock.KeyGenMock{}, - P2PKeyGen: &mock.KeyGenMock{}, - MsgSigVerifier: &testscommon.MessageSignVerifierMock{}, + PubKey: &mock.PublicKeyMock{}, + P2pPubKey: &mock.PublicKeyMock{}, + PrivKey: &mock.PrivateKeyStub{}, + P2pPrivKey: &mock.PrivateKeyStub{}, + PubKeyString: "pubKey", + PrivKeyBytes: []byte("privKey"), + PubKeyBytes: []byte("pubKey"), + BlockSig: &mock.SingleSignerMock{}, + TxSig: &mock.SingleSignerMock{}, + MultiSigContainer: cryptoMocks.NewMultiSignerContainerMock( cryptoMocks.NewMultiSigner()), + PeerSignHandler: &mock.PeerSignatureHandler{}, + BlKeyGen: &mock.KeyGenMock{}, + TxKeyGen: &mock.KeyGenMock{}, + P2PKeyGen: &mock.KeyGenMock{}, + MsgSigVerifier: &testscommon.MessageSignVerifierMock{}, + KeysHandlerField: &testscommon.KeysHandlerStub{}, + ManagedPeersHolderField: &testscommon.ManagedPeersHolderStub{}, } } diff --git a/outport/process/factory/check_test.go b/outport/process/factory/check_test.go index d82e4d42479..dcd5c3cbbdc 100644 --- a/outport/process/factory/check_test.go +++ b/outport/process/factory/check_test.go @@ -17,7 +17,7 @@ import ( func createArgOutportDataProviderFactory() ArgOutportDataProviderFactory { return ArgOutportDataProviderFactory{ HasDrivers: false, - AddressConverter: &testscommon.PubkeyConverterMock{}, + AddressConverter: testscommon.NewPubkeyConverterMock(32), AccountsDB: &state.AccountsStub{}, Marshaller: &testscommon.MarshalizerMock{}, EsdtDataStorageHandler: &testscommon.EsdtStorageHandlerStub{}, diff --git a/outport/process/factory/outportDataProviderFactory.go b/outport/process/factory/outportDataProviderFactory.go index 52cf264fa71..268972b5584 100644 --- 
a/outport/process/factory/outportDataProviderFactory.go +++ b/outport/process/factory/outportDataProviderFactory.go @@ -64,6 +64,7 @@ func CreateOutportDataProvider(arg ArgOutportDataProviderFactory) (outport.DataP TransactionsStorer: arg.TransactionsStorer, ShardCoordinator: arg.ShardCoordinator, TxFeeCalculator: arg.EconomicsData, + PubKeyConverter: arg.AddressConverter, }) if err != nil { return nil, err diff --git a/outport/process/transactionsfee/interface.go b/outport/process/transactionsfee/interface.go index a8e8dfbbc5f..fa09f18076a 100644 --- a/outport/process/transactionsfee/interface.go +++ b/outport/process/transactionsfee/interface.go @@ -5,6 +5,7 @@ import ( "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/transaction" + datafield "github.com/multiversx/mx-chain-vm-common-go/parsers/dataField" ) // FeesProcessorHandler defines the interface for the transaction fees processor @@ -18,3 +19,7 @@ type FeesProcessorHandler interface { type transactionGetter interface { GetTxByHash(txHash []byte) (*transaction.Transaction, error) } + +type dataFieldParser interface { + Parse(dataField []byte, sender, receiver []byte, numOfShards uint32) *datafield.ResponseParseData +} diff --git a/outport/process/transactionsfee/transactionChecker.go b/outport/process/transactionsfee/transactionChecker.go index 5fc16189e77..593ab51a08c 100644 --- a/outport/process/transactionsfee/transactionChecker.go +++ b/outport/process/transactionsfee/transactionChecker.go @@ -12,6 +12,31 @@ import ( vmcommon "github.com/multiversx/mx-chain-vm-common-go" ) +func (tep *transactionsFeeProcessor) isESDTOperationWithSCCall(tx data.TransactionHandlerWithGasUsedAndFee) bool { + res := tep.dataFieldParser.Parse(tx.GetData(), tx.GetSndAddr(), tx.GetRcvAddr(), tep.shardCoordinator.NumberOfShards()) + + isESDTTransferOperation := res.Operation == core.BuiltInFunctionESDTTransfer || + res.Operation == core.BuiltInFunctionESDTNFTTransfer || res.Operation == core.BuiltInFunctionMultiESDTNFTTransfer + + isReceiverSC := core.IsSmartContractAddress(tx.GetRcvAddr()) + hasFunction := res.Function != "" + if !hasFunction { + return false + } + + if !bytes.Equal(tx.GetSndAddr(), tx.GetRcvAddr()) { + return isESDTTransferOperation && isReceiverSC && hasFunction + } + + if len(res.Receivers) == 0 { + return false + } + + isReceiverSC = core.IsSmartContractAddress(res.Receivers[0]) + + return isESDTTransferOperation && isReceiverSC && hasFunction +} + func isSCRForSenderWithRefund(scr *smartContractResult.SmartContractResult, txHash []byte, tx data.TransactionHandlerWithGasUsedAndFee) bool { isForSender := bytes.Equal(scr.RcvAddr, tx.GetSndAddr()) isRightNonce := scr.Nonce == tx.GetNonce()+1 diff --git a/outport/process/transactionsfee/transactionsFeeProcessor.go b/outport/process/transactionsfee/transactionsFeeProcessor.go index 8952ecb72e1..64b0b6aaa7f 100644 --- a/outport/process/transactionsfee/transactionsFeeProcessor.go +++ b/outport/process/transactionsfee/transactionsFeeProcessor.go @@ -10,6 +10,7 @@ import ( "github.com/multiversx/mx-chain-core-go/marshal" "github.com/multiversx/mx-chain-go/sharding" "github.com/multiversx/mx-chain-go/storage" + datafield "github.com/multiversx/mx-chain-vm-common-go/parsers/dataField" ) // ArgTransactionsFeeProcessor holds the arguments needed for creating a new instance of transactionsFeeProcessor @@ -18,12 +19,14 @@ type ArgTransactionsFeeProcessor struct { TransactionsStorer storage.Storer ShardCoordinator sharding.Coordinator TxFeeCalculator 
FeesProcessorHandler + PubKeyConverter core.PubkeyConverter } type transactionsFeeProcessor struct { txGetter transactionGetter txFeeCalculator FeesProcessorHandler shardCoordinator sharding.Coordinator + dataFieldParser dataFieldParser } // NewTransactionsFeeProcessor will create a new instance of transactionsFeeProcessor @@ -33,10 +36,19 @@ func NewTransactionsFeeProcessor(arg ArgTransactionsFeeProcessor) (*transactions return nil, err } + parser, err := datafield.NewOperationDataFieldParser(&datafield.ArgsOperationDataFieldParser{ + AddressLength: arg.PubKeyConverter.Len(), + Marshalizer: arg.Marshaller, + }) + if err != nil { + return nil, err + } + return &transactionsFeeProcessor{ txFeeCalculator: arg.TxFeeCalculator, shardCoordinator: arg.ShardCoordinator, txGetter: newTxGetter(arg.TransactionsStorer, arg.Marshaller), + dataFieldParser: parser, }, nil } @@ -53,6 +65,9 @@ func checkArg(arg ArgTransactionsFeeProcessor) error { if check.IfNil(arg.Marshaller) { return ErrNilMarshaller } + if check.IfNil(arg.PubKeyConverter) { + return core.ErrNilPubkeyConverter + } return nil } @@ -86,7 +101,7 @@ func (tep *transactionsFeeProcessor) prepareNormalTxs(transactionsAndScrs *trans txWithResult.SetFee(fee) txWithResult.SetInitialPaidFee(initialPaidFee) - if isRelayedTx(txWithResult) { + if isRelayedTx(txWithResult) || tep.isESDTOperationWithSCCall(txWithResult) { txWithResult.SetGasUsed(txWithResult.GetGasLimit()) txWithResult.SetFee(initialPaidFee) } diff --git a/outport/process/transactionsfee/transactionsFeeProcessor_test.go b/outport/process/transactionsfee/transactionsFeeProcessor_test.go index 73d74ec012e..2744a968c72 100644 --- a/outport/process/transactionsfee/transactionsFeeProcessor_test.go +++ b/outport/process/transactionsfee/transactionsFeeProcessor_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-core-go/core/pubkeyConverter" coreData "github.com/multiversx/mx-chain-core-go/data" outportcore "github.com/multiversx/mx-chain-core-go/data/outport" "github.com/multiversx/mx-chain-core-go/data/smartContractResult" @@ -15,12 +16,15 @@ import ( "github.com/stretchr/testify/require" ) +var pubKeyConverter, _ = pubkeyConverter.NewBech32PubkeyConverter(32, "erd") + func prepareMockArg() ArgTransactionsFeeProcessor { return ArgTransactionsFeeProcessor{ Marshaller: testscommon.MarshalizerMock{}, TransactionsStorer: genericMocks.NewStorerMock(), ShardCoordinator: &testscommon.ShardsCoordinatorMock{}, TxFeeCalculator: &mock.EconomicsHandlerMock{}, + PubKeyConverter: pubKeyConverter, } } @@ -298,3 +302,40 @@ func TestPutFeeAndGasUsedWrongRelayedTx(t *testing.T) { require.Equal(t, uint64(550000000), initialTx.GetGasUsed()) require.Equal(t, "6103405000000000", initialTx.GetInitialPaidFee().String()) } + +func TestPutFeeAndGasUsedESDTWithScCall(t *testing.T) { + t.Parallel() + + txHash := []byte("tx") + tx := outportcore.NewTransactionHandlerWithGasAndFee(&transaction.Transaction{ + Nonce: 1011, + SndAddr: silentDecodeAddress("erd1dglncxk6sl9a3xumj78n6z2xux4ghp5c92cstv5zsn56tjgtdwpsk46qrs"), + RcvAddr: silentDecodeAddress("erd1dglncxk6sl9a3xumj78n6z2xux4ghp5c92cstv5zsn56tjgtdwpsk46qrs"), + GasLimit: 55_000_000, + GasPrice: 1000000000, + Data: []byte("ESDTNFTTransfer@434f572d636434363364@080c@01@00000000000000000500d3b28828d62052124f07dcd50ed31b0825f60eee1526@616363657074476c6f62616c4f66666572@c3e5"), + Value: big.NewInt(0), + }, 0, big.NewInt(0)) + + pool := &outportcore.Pool{ + Txs: 
map[string]coreData.TransactionHandlerWithGasUsedAndFee{ + string(txHash): tx, + }, + } + + arg := prepareMockArg() + txsFeeProc, err := NewTransactionsFeeProcessor(arg) + require.NotNil(t, txsFeeProc) + require.Nil(t, err) + + err = txsFeeProc.PutFeeAndGasUsed(pool) + require.Nil(t, err) + require.Equal(t, big.NewInt(820765000000000), tx.GetFee()) + require.Equal(t, uint64(55_000_000), tx.GetGasUsed()) + require.Equal(t, "820765000000000", tx.GetInitialPaidFee().String()) +} + +func silentDecodeAddress(address string) []byte { + decoded, _ := pubKeyConverter.Decode(address) + return decoded +} diff --git a/p2p/factory/factory.go b/p2p/factory/factory.go index 591059bbec0..40cdd00d8e3 100644 --- a/p2p/factory/factory.go +++ b/p2p/factory/factory.go @@ -1,8 +1,6 @@ package factory import ( - "github.com/multiversx/mx-chain-core-go/core" - "github.com/multiversx/mx-chain-crypto-go" "github.com/multiversx/mx-chain-go/p2p" "github.com/multiversx/mx-chain-p2p-go/libp2p" p2pCrypto "github.com/multiversx/mx-chain-p2p-go/libp2p/crypto" @@ -45,9 +43,9 @@ func NewPeersHolder(preferredConnectionAddresses []string) (p2p.PreferredPeersHo return peersHolder.NewPeersHolder(preferredConnectionAddresses) } -// ConvertPublicKeyToPeerID will convert a public key to core.PeerID -func ConvertPublicKeyToPeerID(pk crypto.PublicKey) (core.PeerID, error) { - return p2pCrypto.ConvertPublicKeyToPeerID(pk) +// NewP2PKeyConverter returns a new instance of p2pKeyConverter +func NewP2PKeyConverter() p2p.P2PKeyConverter { + return p2pCrypto.NewP2PKeyConverter() } // NewMessageVerifier will return a new instance of messages verifier diff --git a/p2p/interface.go b/p2p/interface.go index 7496daaa682..3799f1afeee 100644 --- a/p2p/interface.go +++ b/p2p/interface.go @@ -5,6 +5,7 @@ import ( "time" "github.com/multiversx/mx-chain-core-go/core" + crypto "github.com/multiversx/mx-chain-crypto-go" p2p "github.com/multiversx/mx-chain-p2p-go" ) @@ -112,3 +113,10 @@ type IdentityGenerator interface { CreateRandomP2PIdentity() ([]byte, core.PeerID, error) IsInterfaceNil() bool } + +// P2PKeyConverter defines what a p2p key converter can do +type P2PKeyConverter interface { + ConvertPeerIDToPublicKey(keyGen crypto.KeyGenerator, pid core.PeerID) (crypto.PublicKey, error) + ConvertPublicKeyToPeerID(pk crypto.PublicKey) (core.PeerID, error) + IsInterfaceNil() bool +} diff --git a/process/block/baseProcess.go b/process/block/baseProcess.go index a9c47516a55..160f4cccfa7 100644 --- a/process/block/baseProcess.go +++ b/process/block/baseProcess.go @@ -1611,7 +1611,7 @@ func trimSliceBootstrapHeaderInfo(in []bootstrapStorage.BootstrapHeaderInfo) []b return ret } -func (bp *baseProcessor) restoreBlockBody(bodyHandler data.BodyHandler) { +func (bp *baseProcessor) restoreBlockBody(headerHandler data.HeaderHandler, bodyHandler data.BodyHandler) { if check.IfNil(bodyHandler) { log.Debug("restoreMiniblocks nil bodyHandler") return @@ -1623,12 +1623,12 @@ func (bp *baseProcessor) restoreBlockBody(bodyHandler data.BodyHandler) { return } - restoredTxNr, errNotCritical := bp.txCoordinator.RestoreBlockDataFromStorage(body) + _, errNotCritical := bp.txCoordinator.RestoreBlockDataFromStorage(body) if errNotCritical != nil { log.Debug("restoreBlockBody RestoreBlockDataFromStorage", "error", errNotCritical.Error()) } - go bp.txCounter.subtractRestoredTxs(restoredTxNr) + go bp.txCounter.headerReverted(headerHandler) } // RestoreBlockBodyIntoPools restores the block body into associated pools diff --git a/process/block/displayBlock.go 
b/process/block/displayBlock.go index c5b24a20e23..3b1ab7410cc 100644 --- a/process/block/displayBlock.go +++ b/process/block/displayBlock.go @@ -20,32 +20,44 @@ import ( ) type transactionCounter struct { - mutex sync.RWMutex - currentBlockTxs uint64 - totalTxs uint64 - hasher hashing.Hasher - marshalizer marshal.Marshalizer + mutex sync.RWMutex + currentBlockTxs uint64 + totalTxs uint64 + hasher hashing.Hasher + marshalizer marshal.Marshalizer + appStatusHandler core.AppStatusHandler + shardID uint32 +} + +// ArgsTransactionCounter represents the arguments needed to create a new transaction counter +type ArgsTransactionCounter struct { + AppStatusHandler core.AppStatusHandler + Hasher hashing.Hasher + Marshalizer marshal.Marshalizer + ShardID uint32 } // NewTransactionCounter returns a new object that keeps track of how many transactions // were executed in total, and in the current block -func NewTransactionCounter( - hasher hashing.Hasher, - marshalizer marshal.Marshalizer, -) (*transactionCounter, error) { - if check.IfNil(hasher) { +func NewTransactionCounter(args ArgsTransactionCounter) (*transactionCounter, error) { + if check.IfNil(args.AppStatusHandler) { + return nil, process.ErrNilAppStatusHandler + } + if check.IfNil(args.Hasher) { return nil, process.ErrNilHasher } - if check.IfNil(marshalizer) { + if check.IfNil(args.Marshalizer) { return nil, process.ErrNilMarshalizer } return &transactionCounter{ - mutex: sync.RWMutex{}, - currentBlockTxs: 0, - totalTxs: 0, - hasher: hasher, - marshalizer: marshalizer, + mutex: sync.RWMutex{}, + appStatusHandler: args.AppStatusHandler, + currentBlockTxs: 0, + totalTxs: 0, + hasher: args.Hasher, + marshalizer: args.Marshalizer, + shardID: args.ShardID, }, nil } @@ -56,16 +68,63 @@ func (txc *transactionCounter) getPoolCounts(poolsHolder dataRetriever.PoolsHold return } -// subtractRestoredTxs updated the total processed txs in case of restore -func (txc *transactionCounter) subtractRestoredTxs(txsNr int) { +// headerReverted updates the total processed txs in case of restore. 
It also sets the current block txs to 0 +func (txc *transactionCounter) headerReverted(hdr data.HeaderHandler) { + if check.IfNil(hdr) { + log.Warn("programming error: nil header in transactionCounter.headerReverted function") + return + } + + currentBlockTxs := txc.getProcessedTxCount(hdr) + txc.mutex.Lock() - defer txc.mutex.Unlock() - if txc.totalTxs < uint64(txsNr) { + txc.currentBlockTxs = 0 + txc.safeSubtractTotalTxs(uint64(currentBlockTxs)) + txc.appStatusHandler.SetUInt64Value(common.MetricNumProcessedTxs, txc.totalTxs) + txc.mutex.Unlock() +} + +func (txc *transactionCounter) safeSubtractTotalTxs(delta uint64) { + if txc.totalTxs < delta { txc.totalTxs = 0 return } - txc.totalTxs -= uint64(txsNr) + txc.totalTxs -= delta +} + +func (txc *transactionCounter) headerExecuted(hdr data.HeaderHandler) { + if check.IfNil(hdr) { + log.Warn("programming error: nil header in transactionCounter.headerExecuted function") + return + } + + currentBlockTxs := txc.getProcessedTxCount(hdr) + + txc.mutex.Lock() + txc.currentBlockTxs = uint64(currentBlockTxs) + txc.totalTxs += uint64(currentBlockTxs) + txc.appStatusHandler.SetUInt64Value(common.MetricNumProcessedTxs, txc.totalTxs) + txc.mutex.Unlock() +} + +func (txc *transactionCounter) getProcessedTxCount(hdr data.HeaderHandler) int32 { + currentBlockTxs := int32(0) + for _, miniBlockHeaderHandler := range hdr.GetMiniBlockHeaderHandlers() { + if miniBlockHeaderHandler.GetTypeInt32() == int32(block.PeerBlock) { + continue + } + + isMiniblockScheduledFromMe := miniBlockHeaderHandler.GetSenderShardID() == txc.shardID && + miniBlockHeaderHandler.GetProcessingType() == int32(block.Scheduled) + if isMiniblockScheduledFromMe { + continue + } + + currentBlockTxs += miniBlockHeaderHandler.GetIndexOfLastTxProcessed() - miniBlockHeaderHandler.GetIndexOfFirstTxProcessed() + 1 + } + + return currentBlockTxs } // displayLogInfo writes to the output information about the block and transactions @@ -76,15 +135,10 @@ func (txc *transactionCounter) displayLogInfo( numShards uint32, selfId uint32, _ dataRetriever.PoolsHolder, - appStatusHandler core.AppStatusHandler, blockTracker process.BlockTracker, ) { dispHeader, dispLines := txc.createDisplayableShardHeaderAndBlockBody(header, body) - txc.mutex.RLock() - appStatusHandler.SetUInt64Value(common.MetricNumProcessedTxs, txc.totalTxs) - txc.mutex.RUnlock() - tblString, err := display.CreateTableString(dispHeader, dispLines) if err != nil { log.Debug("CreateTableString", "error", err.Error()) @@ -187,8 +241,6 @@ func (txc *transactionCounter) displayTxBlockBody( header data.HeaderHandler, body *block.Body, ) []*display.LineData { - currentBlockTxs := 0 - miniBlockHeaders := header.GetMiniBlockHeaderHandlers() for i := 0; i < len(body.MiniBlocks); i++ { miniBlock := body.MiniBlocks[i] @@ -227,8 +279,6 @@ func (txc *transactionCounter) displayTxBlockBody( lines = append(lines, display.NewLineData(false, []string{"", "TxsProcessedRange", strProcessedRange})) } - currentBlockTxs += len(miniBlock.TxHashes) - for j := 0; j < len(miniBlock.TxHashes); j++ { if j == 0 || j >= len(miniBlock.TxHashes)-1 { lines = append(lines, display.NewLineData(false, []string{ @@ -251,11 +301,6 @@ func (txc *transactionCounter) displayTxBlockBody( lines[len(lines)-1].HorizontalRuleAfter = true } - txc.mutex.Lock() - txc.currentBlockTxs = uint64(currentBlockTxs) - txc.totalTxs += uint64(currentBlockTxs) - txc.mutex.Unlock() - return lines } @@ -310,3 +355,24 @@ func DisplayLastNotarized( "nonce", lastNotarizedHdrForShard.GetNonce(), "hash", 
lastNotarizedHdrHashForShard) } + +// CurrentBlockTxs returns the current block's number of processed transactions +func (txc *transactionCounter) CurrentBlockTxs() uint64 { + txc.mutex.RLock() + defer txc.mutex.RUnlock() + + return txc.currentBlockTxs +} + +// TotalTxs returns the total number of processed transactions +func (txc *transactionCounter) TotalTxs() uint64 { + txc.mutex.RLock() + defer txc.mutex.RUnlock() + + return txc.totalTxs +} + +// IsInterfaceNil returns true if there is no value under the interface +func (txc *transactionCounter) IsInterfaceNil() bool { + return txc == nil +} diff --git a/process/block/displayBlock_test.go b/process/block/displayBlock_test.go index 538e32f2b5d..ccc6eced1a0 100644 --- a/process/block/displayBlock_test.go +++ b/process/block/displayBlock_test.go @@ -1,13 +1,19 @@ package block import ( + "fmt" + "sync" "testing" + "time" + "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data/block" "github.com/multiversx/mx-chain-core-go/display" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/statusHandler" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func createGenesisBlock(shardId uint32) *block.Header { @@ -26,10 +32,21 @@ func createGenesisBlock(shardId uint32) *block.Header { } } +func createMockArgsTransactionCounter() ArgsTransactionCounter { + return ArgsTransactionCounter{ + AppStatusHandler: &statusHandler.AppStatusHandlerStub{}, + Hasher: &testscommon.HasherStub{}, + Marshalizer: &testscommon.MarshalizerMock{}, + ShardID: 0, + } +} + func TestDisplayBlock_NewTransactionCounterShouldErrWhenHasherIsNil(t *testing.T) { t.Parallel() - txCounter, err := NewTransactionCounter(nil, &testscommon.MarshalizerMock{}) + args := createMockArgsTransactionCounter() + args.Hasher = nil + txCounter, err := NewTransactionCounter(args) assert.Nil(t, txCounter) assert.Equal(t, process.ErrNilHasher, err) @@ -38,16 +55,30 @@ func TestDisplayBlock_NewTransactionCounterShouldErrWhenHasherIsNil(t *testing.T func TestDisplayBlock_NewTransactionCounterShouldErrWhenMarshalizerIsNil(t *testing.T) { t.Parallel() - txCounter, err := NewTransactionCounter(&testscommon.HasherStub{}, nil) + args := createMockArgsTransactionCounter() + args.Marshalizer = nil + txCounter, err := NewTransactionCounter(args) assert.Nil(t, txCounter) assert.Equal(t, process.ErrNilMarshalizer, err) } +func TestDisplayBlock_NewTransactionCounterShouldErrWhenAppStatusHandlerIsNil(t *testing.T) { + t.Parallel() + + args := createMockArgsTransactionCounter() + args.AppStatusHandler = nil + txCounter, err := NewTransactionCounter(args) + + assert.Nil(t, txCounter) + assert.Equal(t, process.ErrNilAppStatusHandler, err) +} + func TestDisplayBlock_NewTransactionCounterShouldWork(t *testing.T) { t.Parallel() - txCounter, err := NewTransactionCounter(&testscommon.HasherStub{}, &testscommon.MarshalizerMock{}) + args := createMockArgsTransactionCounter() + txCounter, err := NewTransactionCounter(args) assert.NotNil(t, txCounter) assert.Nil(t, err) @@ -58,7 +89,8 @@ func TestDisplayBlock_DisplayMetaHashesIncluded(t *testing.T) { shardLines := make([]*display.LineData, 0) header := createGenesisBlock(0) - txCounter, _ := NewTransactionCounter(&testscommon.HasherStub{}, &testscommon.MarshalizerMock{}) + args := createMockArgsTransactionCounter() + txCounter, _ := NewTransactionCounter(args) lines := txCounter.displayMetaHashesIncluded( 
shardLines, header, @@ -79,7 +111,8 @@ func TestDisplayBlock_DisplayTxBlockBody(t *testing.T) { TxHashes: [][]byte{[]byte("hash1"), []byte("hash2"), []byte("hash3")}, } body.MiniBlocks = append(body.MiniBlocks, &miniblock) - txCounter, _ := NewTransactionCounter(&testscommon.HasherStub{}, &testscommon.MarshalizerMock{}) + args := createMockArgsTransactionCounter() + txCounter, _ := NewTransactionCounter(args) lines := txCounter.displayTxBlockBody( shardLines, &block.Header{}, @@ -105,3 +138,176 @@ func TestDisplayBlock_GetConstructionStateAsString(t *testing.T) { str = getConstructionStateAsString(miniBlockHeader) assert.Equal(t, "", str) } + +func TestDisplayBlock_ConcurrencyTestForTotalTxs(t *testing.T) { + t.Parallel() + + numCalls := 100 + wg := sync.WaitGroup{} + wg.Add(numCalls) + + args := createMockArgsTransactionCounter() + txCounter, _ := NewTransactionCounter(args) + + mbh1 := block.MiniBlockHeader{} + _ = mbh1.SetIndexOfLastTxProcessed(0) + _ = mbh1.SetIndexOfLastTxProcessed(37) + header := &block.Header{ + MiniBlockHeaders: []block.MiniBlockHeader{mbh1}, + } + + for i := 0; i < numCalls; i++ { + go func(idx int) { + time.Sleep(time.Millisecond * 10) + defer wg.Done() + + switch idx % 4 { + case 0: + txCounter.headerReverted(header) + case 1: + txCounter.headerExecuted(header) + case 2: + _ = txCounter.TotalTxs() + case 3: + _ = txCounter.CurrentBlockTxs() + } + }(i) + } + + wg.Wait() +} + +func TestTransactionCounter_HeaderExecutedAndReverted(t *testing.T) { + t.Parallel() + + args := createMockArgsTransactionCounter() + + mbhPeer := block.MiniBlockHeader{} + _ = mbhPeer.SetTypeInt32(int32(block.PeerBlock)) + _ = mbhPeer.SetIndexOfFirstTxProcessed(0) + _ = mbhPeer.SetIndexOfLastTxProcessed(99) + + mbhRwd := block.MiniBlockHeader{} + _ = mbhRwd.SetTypeInt32(int32(block.RewardsBlock)) + _ = mbhRwd.SetIndexOfFirstTxProcessed(0) + _ = mbhRwd.SetIndexOfLastTxProcessed(199) + + mbhScheduledFromShard0 := block.MiniBlockHeader{} + _ = mbhScheduledFromShard0.SetTypeInt32(int32(block.TxBlock)) + _ = mbhScheduledFromShard0.SetProcessingType(int32(block.Scheduled)) + _ = mbhScheduledFromShard0.SetIndexOfFirstTxProcessed(0) + _ = mbhScheduledFromShard0.SetIndexOfLastTxProcessed(399) + + mbhScheduledFromShard1 := block.MiniBlockHeader{ + SenderShardID: 1, + } + _ = mbhScheduledFromShard1.SetTypeInt32(int32(block.TxBlock)) + _ = mbhScheduledFromShard1.SetProcessingType(int32(block.Scheduled)) + _ = mbhScheduledFromShard1.SetIndexOfFirstTxProcessed(0) + _ = mbhScheduledFromShard1.SetIndexOfLastTxProcessed(499) + + t.Run("headerExecuted", func(t *testing.T) { + t.Parallel() + + txCounter, _ := NewTransactionCounter(args) + require.False(t, check.IfNil(txCounter)) + t.Run("nil header should not panic", func(t *testing.T) { + defer func() { + r := recover() + if r != nil { + assert.Fail(t, fmt.Sprintf("should have not panicked: %v", r)) + } + }() + + txCounter.headerExecuted(nil) + }) + t.Run("empty header", func(t *testing.T) { + txCounter.totalTxs = 1000 // initial value + txCounter.headerExecuted(&block.Header{}) + assert.Equal(t, uint64(1000), txCounter.TotalTxs()) + assert.Equal(t, uint64(0), txCounter.CurrentBlockTxs()) + }) + t.Run("header with peer miniblocks & rewards miniblocks", func(t *testing.T) { + txCounter.totalTxs = 1000 // initial value + + blk := &block.Header{ + MiniBlockHeaders: []block.MiniBlockHeader{mbhPeer, mbhRwd}, + } + + txCounter.headerExecuted(blk) + assert.Equal(t, uint64(1200), txCounter.TotalTxs()) + assert.Equal(t, uint64(200), txCounter.CurrentBlockTxs()) + }) 
+ t.Run("header with scheduled from self and shard 1", func(t *testing.T) { + txCounter.totalTxs = 1000 // initial value + + blk := &block.Header{ + MiniBlockHeaders: []block.MiniBlockHeader{mbhScheduledFromShard0, mbhScheduledFromShard1}, + } + + txCounter.headerExecuted(blk) + assert.Equal(t, uint64(1500), txCounter.TotalTxs()) + assert.Equal(t, uint64(500), txCounter.CurrentBlockTxs()) + }) + }) + t.Run("headerReverted", func(t *testing.T) { + t.Parallel() + + txCounter, _ := NewTransactionCounter(args) + require.False(t, check.IfNil(txCounter)) + t.Run("nil header should not panic", func(t *testing.T) { + defer func() { + r := recover() + if r != nil { + assert.Fail(t, fmt.Sprintf("should have not panicked: %v", r)) + } + }() + + txCounter.headerReverted(nil) + }) + t.Run("empty header", func(t *testing.T) { + txCounter.totalTxs = 1000 // initial value + txCounter.headerReverted(&block.Header{}) + assert.Equal(t, uint64(1000), txCounter.TotalTxs()) + assert.Equal(t, uint64(0), txCounter.CurrentBlockTxs()) + }) + t.Run("header with peer miniblocks & rewards miniblocks", func(t *testing.T) { + txCounter.totalTxs = 1000 // initial value + blk := &block.Header{ + MiniBlockHeaders: []block.MiniBlockHeader{mbhPeer, mbhRwd}, + } + + txCounter.headerReverted(blk) + assert.Equal(t, uint64(800), txCounter.TotalTxs()) // 1000 - 200 + assert.Equal(t, uint64(0), txCounter.CurrentBlockTxs()) // unable to revert to the last executed block, so hardcoded to 0 + }) + t.Run("header with scheduled from self and shard 1", func(t *testing.T) { + txCounter.totalTxs = 1000 // initial value + blk := &block.Header{ + MiniBlockHeaders: []block.MiniBlockHeader{mbhScheduledFromShard0, mbhScheduledFromShard1}, + } + + txCounter.headerReverted(blk) + assert.Equal(t, uint64(500), txCounter.TotalTxs()) // 1000 - 500 + assert.Equal(t, uint64(0), txCounter.CurrentBlockTxs()) // unable to revert to the last executed block, so hardcoded to 0 + }) + }) + t.Run("headerExecuted then headerReverted", func(t *testing.T) { + t.Parallel() + + txCounter, _ := NewTransactionCounter(args) + require.False(t, check.IfNil(txCounter)) + txCounter.totalTxs = 1000 // initial value + blk := &block.Header{ + MiniBlockHeaders: []block.MiniBlockHeader{mbhPeer, mbhRwd, mbhScheduledFromShard0, mbhScheduledFromShard1}, + } + + txCounter.headerExecuted(blk) + assert.Equal(t, uint64(1700), txCounter.TotalTxs()) + assert.Equal(t, uint64(700), txCounter.CurrentBlockTxs()) + + txCounter.headerReverted(blk) + assert.Equal(t, uint64(1000), txCounter.TotalTxs()) + assert.Equal(t, uint64(0), txCounter.CurrentBlockTxs()) + }) +} diff --git a/process/block/displayMetaBlock.go b/process/block/displayMetaBlock.go index e2755c6c8bf..2018b819925 100644 --- a/process/block/displayMetaBlock.go +++ b/process/block/displayMetaBlock.go @@ -4,6 +4,7 @@ import ( "fmt" "sync" + "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/block" "github.com/multiversx/mx-chain-core-go/display" @@ -11,11 +12,16 @@ import ( "github.com/multiversx/mx-chain-logger-go" ) +type transactionCountersProvider interface { + CurrentBlockTxs() uint64 + TotalTxs() uint64 + IsInterfaceNil() bool +} + type headersCounter struct { shardMBHeaderCounterMutex sync.RWMutex shardMBHeadersCurrentBlockProcessed uint64 shardMBHeadersTotalProcessed uint64 - peakTPS uint64 } // NewHeaderCounter returns a new object that keeps track of how many headers @@ -25,7 +31,6 @@ func NewHeaderCounter() *headersCounter { 
shardMBHeaderCounterMutex: sync.RWMutex{}, shardMBHeadersCurrentBlockProcessed: 0, shardMBHeadersTotalProcessed: 0, - peakTPS: 0, } } @@ -59,13 +64,18 @@ func (hc *headersCounter) calculateNumOfShardMBHeaders(header *block.MetaBlock) } func (hc *headersCounter) displayLogInfo( + countersProvider transactionCountersProvider, header *block.MetaBlock, body *block.Body, headerHash []byte, numShardHeadersFromPool int, blockTracker process.BlockTracker, - roundDuration uint64, ) { + if check.IfNil(countersProvider) { + log.Warn("programming error in headersCounter.displayLogInfo - nil countersProvider") + return + } + hc.calculateNumOfShardMBHeaders(header) dispHeader, dispLines := hc.createDisplayableMetaHeader(header) @@ -88,19 +98,11 @@ func (hc *headersCounter) displayLogInfo( log.Debug(message, arguments...) - numTxs := getNumTxs(header, body) - tps := numTxs / roundDuration - if tps > hc.peakTPS { - hc.peakTPS = tps - } - - log.Debug("tps info", - "shard", header.GetShardID(), - "round", header.GetRound(), - "nonce", header.GetNonce(), - "num txs", numTxs, - "tps", tps, - "peak tps", hc.peakTPS) + log.Debug("metablock metrics info", + "total txs processed", countersProvider.TotalTxs(), + "block txs processed", countersProvider.CurrentBlockTxs(), + "hash", headerHash, + ) blockTracker.DisplayTrackedHeaders() } @@ -320,38 +322,3 @@ func displayEconomicsData(economics block.Economics) []*display.LineData { fmt.Sprintf("%d", economics.PrevEpochStartRound)}), } } - -func getNumTxs(metaBlock *block.MetaBlock, body *block.Body) uint64 { - shardInfo := metaBlock.ShardInfo - numTxs := uint64(0) - for i := 0; i < len(shardInfo); i++ { - shardMiniBlockHeaders := shardInfo[i].ShardMiniBlockHeaders - numTxsInShardHeader := uint64(0) - for j := 0; j < len(shardMiniBlockHeaders); j++ { - numTxsInShardHeader += uint64(shardMiniBlockHeaders[j].TxCount) - } - - log.Trace("txs info", - "shard", shardInfo[i].GetShardID(), - "round", shardInfo[i].GetRound(), - "nonce", shardInfo[i].GetNonce(), - "num txs", numTxsInShardHeader) - - numTxs += numTxsInShardHeader - } - - numTxsInMetaBlock := uint64(0) - for i := 0; i < len(body.MiniBlocks); i++ { - numTxsInMetaBlock += uint64(len(body.MiniBlocks[i].TxHashes)) - } - - log.Trace("txs info", - "shard", metaBlock.GetShardID(), - "round", metaBlock.GetRound(), - "nonce", metaBlock.GetNonce(), - "num txs", numTxsInMetaBlock) - - numTxs += numTxsInMetaBlock - - return numTxs -} diff --git a/process/block/export_test.go b/process/block/export_test.go index 3507ff0c02c..a47d9851500 100644 --- a/process/block/export_test.go +++ b/process/block/export_test.go @@ -344,10 +344,9 @@ func (sp *shardProcessor) DisplayLogInfo( numShards uint32, selfId uint32, dataPool dataRetriever.PoolsHolder, - statusHandler core.AppStatusHandler, blockTracker process.BlockTracker, ) { - sp.txCounter.displayLogInfo(header, body, headerHash, numShards, selfId, dataPool, statusHandler, blockTracker) + sp.txCounter.displayLogInfo(header, body, headerHash, numShards, selfId, dataPool, blockTracker) } func (sp *shardProcessor) GetHighestHdrForOwnShardFromMetachain(processedHdrs []data.HeaderHandler) ([]data.HeaderHandler, [][]byte, error) { diff --git a/process/block/metablock.go b/process/block/metablock.go index 310072acafd..17dd596bcf8 100644 --- a/process/block/metablock.go +++ b/process/block/metablock.go @@ -148,7 +148,13 @@ func NewMetaProcessor(arguments ArgMetaProcessor) (*metaProcessor, error) { epochSystemSCProcessor: arguments.EpochSystemSCProcessor, } - mp.txCounter, err = 
NewTransactionCounter(mp.hasher, mp.marshalizer) + argsTransactionCounter := ArgsTransactionCounter{ + AppStatusHandler: mp.appStatusHandler, + Hasher: mp.hasher, + Marshalizer: mp.marshalizer, + ShardID: core.MetachainShardId, + } + mp.txCounter, err = NewTransactionCounter(argsTransactionCounter) if err != nil { return nil, err } @@ -681,7 +687,7 @@ func (mp *metaProcessor) RestoreBlockIntoPools(headerHandler data.HeaderHandler, mp.headersCounter.subtractRestoredMBHeaders(len(shardHeader.GetMiniBlockHeaderHandlers())) } - mp.restoreBlockBody(bodyHandler) + mp.restoreBlockBody(headerHandler, bodyHandler) mp.blockTracker.RemoveLastNotarizedHeaders() @@ -1282,14 +1288,17 @@ func (mp *metaProcessor) CommitBlock( numShardHeadersFromPool += headersPool.GetNumHeaders(shardID) } - go mp.headersCounter.displayLogInfo( - header, - body, - headerHash, - numShardHeadersFromPool, - mp.blockTracker, - uint64(mp.roundHandler.TimeDuration().Seconds()), - ) + go func() { + mp.txCounter.headerExecuted(header) + mp.headersCounter.displayLogInfo( + mp.txCounter, + header, + body, + headerHash, + numShardHeadersFromPool, + mp.blockTracker, + ) + }() headerInfo := bootstrapStorage.BootstrapHeaderInfo{ ShardId: header.GetShardID(), diff --git a/process/block/shardblock.go b/process/block/shardblock.go index 16b88d5bf9d..42eed2856b5 100644 --- a/process/block/shardblock.go +++ b/process/block/shardblock.go @@ -124,7 +124,13 @@ func NewShardProcessor(arguments ArgShardProcessor) (*shardProcessor, error) { baseProcessor: base, } - sp.txCounter, err = NewTransactionCounter(sp.hasher, sp.marshalizer) + argsTransactionCounter := ArgsTransactionCounter{ + AppStatusHandler: sp.appStatusHandler, + Hasher: sp.hasher, + Marshalizer: sp.marshalizer, + ShardID: sp.shardCoordinator.SelfId(), + } + sp.txCounter, err = NewTransactionCounter(argsTransactionCounter) if err != nil { return nil, err } @@ -623,7 +629,7 @@ func (sp *shardProcessor) RestoreBlockIntoPools(headerHandler data.HeaderHandler return err } - sp.restoreBlockBody(bodyHandler) + sp.restoreBlockBody(headerHandler, bodyHandler) sp.blockTracker.RemoveLastNotarizedHeaders() @@ -807,6 +813,16 @@ func (sp *shardProcessor) CreateBlock( if err != nil { return nil, nil, err } + + epoch := sp.epochStartTrigger.MetaEpoch() + if initialHdr.GetEpoch() != epoch { + log.Debug("shardProcessor.CreateBlock: epoch from header is not the same as epoch from epoch start trigger, overwriting", + "epoch from header", initialHdr.GetEpoch(), "epoch from epoch start trigger", epoch) + err = shardHdr.SetEpoch(epoch) + if err != nil { + return nil, nil, err + } + } } sp.epochNotifier.CheckEpoch(shardHdr) @@ -1037,16 +1053,18 @@ func (sp *shardProcessor) CommitBlock( sp.prepareDataForBootStorer(args) // write data to log - go sp.txCounter.displayLogInfo( - header, - body, - headerHash, - sp.shardCoordinator.NumberOfShards(), - sp.shardCoordinator.SelfId(), - sp.dataPool, - sp.appStatusHandler, - sp.blockTracker, - ) + go func() { + sp.txCounter.headerExecuted(header) + sp.txCounter.displayLogInfo( + header, + body, + headerHash, + sp.shardCoordinator.NumberOfShards(), + sp.shardCoordinator.SelfId(), + sp.dataPool, + sp.blockTracker, + ) + }() sp.blockSizeThrottler.Succeed(header.GetRound()) diff --git a/process/block/shardblock_test.go b/process/block/shardblock_test.go index 07411d109fb..1dce393b2d2 100644 --- a/process/block/shardblock_test.go +++ b/process/block/shardblock_test.go @@ -15,6 +15,7 @@ import ( "github.com/multiversx/mx-chain-core-go/core" atomicCore 
"github.com/multiversx/mx-chain-core-go/core/atomic" + "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/block" outportcore "github.com/multiversx/mx-chain-core-go/data/outport" @@ -2425,10 +2426,6 @@ func TestShardProcessor_DisplayLogInfo(t *testing.T) { hasher := hashingMocks.HasherMock{} hdr, txBlock := createTestHdrTxBlockBody() shardCoordinator := mock.NewMultiShardsCoordinatorMock(3) - statusHandler := &statusHandlerMock.AppStatusHandlerStub{ - SetUInt64ValueHandler: func(key string, value uint64) { - }, - } coreComponents, dataComponents, bootstrapComponents, statusComponents := createComponentHolderMocks() dataComponents.DataPool = tdp @@ -2437,7 +2434,7 @@ func TestShardProcessor_DisplayLogInfo(t *testing.T) { sp, _ := blproc.NewShardProcessor(arguments) assert.NotNil(t, sp) hdr.PrevHash = hasher.Compute("prev hash") - sp.DisplayLogInfo(hdr, txBlock, []byte("tx_hash1"), shardCoordinator.NumberOfShards(), shardCoordinator.SelfId(), tdp, statusHandler, &mock.BlockTrackerMock{}) + sp.DisplayLogInfo(hdr, txBlock, []byte("tx_hash1"), shardCoordinator.NumberOfShards(), shardCoordinator.SelfId(), tdp, &mock.BlockTrackerMock{}) } func TestBlockProcessor_ApplyBodyToHeaderNilBodyError(t *testing.T) { @@ -5241,3 +5238,147 @@ func TestShardProcessor_RollBackProcessedMiniBlocksInfo(t *testing.T) { assert.False(t, processedMbInfo.FullyProcessed) assert.Equal(t, indexOfFirstTxProcessed-1, processedMbInfo.IndexOfLastTxProcessed) } + +func TestShardProcessor_CreateBlock(t *testing.T) { + t.Parallel() + + arguments := CreateMockArguments(createComponentHolderMocks()) + processHandler := arguments.CoreComponents.ProcessStatusHandler() + mockProcessHandler := processHandler.(*testscommon.ProcessStatusHandlerStub) + busyIdleCalled := make([]string, 0) + mockProcessHandler.SetIdleCalled = func() { + busyIdleCalled = append(busyIdleCalled, idleIdentifier) + } + mockProcessHandler.SetBusyCalled = func(reason string) { + busyIdleCalled = append(busyIdleCalled, busyIdentifier) + } + + expectedBusyIdleSequencePerCall := []string{busyIdentifier, idleIdentifier} + sp, errConstructor := blproc.NewShardProcessor(arguments) + assert.Nil(t, errConstructor) + + doesHaveTime := func() bool { + return true + } + t.Run("nil block should error", func(t *testing.T) { + hdr, body, err := sp.CreateBlock(nil, doesHaveTime) + assert.True(t, check.IfNil(body)) + assert.True(t, check.IfNil(hdr)) + assert.Equal(t, process.ErrNilBlockHeader, err) + assert.Zero(t, len(busyIdleCalled)) + }) + t.Run("wrong block type should error", func(t *testing.T) { + meta := &block.MetaBlock{} + + hdr, body, err := sp.CreateBlock(meta, doesHaveTime) + assert.True(t, check.IfNil(body)) + assert.True(t, check.IfNil(hdr)) + assert.Equal(t, process.ErrWrongTypeAssertion, err) + assert.Zero(t, len(busyIdleCalled)) + }) + t.Run("should work with empty header v1", func(t *testing.T) { + header := &block.Header{ + Nonce: 37, + Round: 38, + Epoch: 1, + } + + expectedHeader := &block.Header{ + Nonce: 37, + Round: 38, + Epoch: 1, + ReceiptsHash: []byte("receiptHash"), + DeveloperFees: big.NewInt(0), + AccumulatedFees: big.NewInt(0), + } + + // reset the slice, do not call these tests in parallel + busyIdleCalled = make([]string, 0) + hdr, bodyHandler, err := sp.CreateBlock(header, doesHaveTime) + assert.False(t, check.IfNil(bodyHandler)) + body, ok := bodyHandler.(*block.Body) + assert.True(t, ok) + + assert.Zero(t, len(body.MiniBlocks)) + assert.False(t, 
check.IfNil(hdr)) + assert.Equal(t, expectedHeader, header) + assert.Nil(t, err) + assert.Equal(t, expectedBusyIdleSequencePerCall, busyIdleCalled) + }) + t.Run("should work with empty header v2", func(t *testing.T) { + header := &block.HeaderV2{ + Header: &block.Header{ + Nonce: 37, + Round: 38, + Epoch: 1, + }, + } + + expectedHeader := &block.HeaderV2{ + Header: &block.Header{ + Nonce: 37, + Round: 38, + Epoch: 1, + ReceiptsHash: []byte("receiptHash"), + DeveloperFees: big.NewInt(0), + AccumulatedFees: big.NewInt(0), + }, + } + + // reset the slice, do not call these tests in parallel + busyIdleCalled = make([]string, 0) + hdr, bodyHandler, err := sp.CreateBlock(header, doesHaveTime) + assert.False(t, check.IfNil(bodyHandler)) + body, ok := bodyHandler.(*block.Body) + assert.True(t, ok) + + assert.Zero(t, len(body.MiniBlocks)) + assert.False(t, check.IfNil(hdr)) + assert.Equal(t, expectedHeader, header) + assert.Nil(t, err) + assert.Equal(t, expectedBusyIdleSequencePerCall, busyIdleCalled) + }) + t.Run("should work with empty header v2 and epoch start rewriting the epoch value", func(t *testing.T) { + argumentsLocal := CreateMockArguments(createComponentHolderMocks()) + argumentsLocal.EpochStartTrigger = &mock.EpochStartTriggerStub{ + IsEpochStartCalled: func() bool { + return true + }, + MetaEpochCalled: func() uint32 { + return 2 + }, + } + + spLocal, err := blproc.NewShardProcessor(argumentsLocal) + assert.Nil(t, err) + + header := &block.HeaderV2{ + Header: &block.Header{ + Nonce: 37, + Round: 38, + Epoch: 1, + }, + } + + expectedHeader := &block.HeaderV2{ + Header: &block.Header{ + Nonce: 37, + Round: 38, + Epoch: 2, // epoch should be re-written + ReceiptsHash: []byte("receiptHash"), + DeveloperFees: big.NewInt(0), + AccumulatedFees: big.NewInt(0), + }, + } + + hdr, bodyHandler, err := spLocal.CreateBlock(header, doesHaveTime) + assert.False(t, check.IfNil(bodyHandler)) + body, ok := bodyHandler.(*block.Body) + assert.True(t, ok) + + assert.Zero(t, len(body.MiniBlocks)) + assert.False(t, check.IfNil(hdr)) + assert.Equal(t, expectedHeader, header) + assert.Nil(t, err) + }) +} diff --git a/process/interface.go b/process/interface.go index 7a8444b1719..8c324e2d7fe 100644 --- a/process/interface.go +++ b/process/interface.go @@ -1147,6 +1147,7 @@ type CryptoComponentsHolder interface { PublicKey() crypto.PublicKey PrivateKey() crypto.PrivateKey Clone() interface{} + ManagedPeersHolder() common.ManagedPeersHolder IsInterfaceNil() bool } diff --git a/process/mock/cryptoComponentsMock.go b/process/mock/cryptoComponentsMock.go index 606c919fdda..4b02784b69b 100644 --- a/process/mock/cryptoComponentsMock.go +++ b/process/mock/cryptoComponentsMock.go @@ -5,6 +5,7 @@ import ( "sync" "github.com/multiversx/mx-chain-crypto-go" + "github.com/multiversx/mx-chain-go/common" cryptoCommon "github.com/multiversx/mx-chain-go/common/crypto" ) @@ -18,6 +19,7 @@ type CryptoComponentsMock struct { TxKeyGen crypto.KeyGenerator PubKey crypto.PublicKey PrivKey crypto.PrivateKey + ManagedPeers common.ManagedPeersHolder mutMultiSig sync.RWMutex } @@ -88,6 +90,11 @@ func (ccm *CryptoComponentsMock) PrivateKey() crypto.PrivateKey { return ccm.PrivKey } +// ManagedPeersHolder - +func (ccm *CryptoComponentsMock) ManagedPeersHolder() common.ManagedPeersHolder { + return ccm.ManagedPeers +} + // Clone - func (ccm *CryptoComponentsMock) Clone() interface{} { return &CryptoComponentsMock{ diff --git a/scripts/testnet/include/config.sh b/scripts/testnet/include/config.sh index b2f7a40e3be..425339f99c6 100644 --- 
a/scripts/testnet/include/config.sh +++ b/scripts/testnet/include/config.sh @@ -1,16 +1,23 @@ generateConfig() { echo "Generating configuration using values from scripts/variables.sh..." + TMP_SHARD_OBSERVERCOUNT=$SHARD_OBSERVERCOUNT + TMP_META_OBSERVERCOUNT=$META_OBSERVERCOUNT + if [[ $MULTI_KEY_NODES -eq 1 ]]; then + TMP_SHARD_OBSERVERCOUNT=0 + TMP_META_OBSERVERCOUNT=0 + fi + pushd $TESTNETDIR/filegen ./filegen \ - -output-directory $CONFIGGENERATOROUTPUTDIR \ - -num-of-shards $SHARDCOUNT \ - -num-of-nodes-in-each-shard $SHARD_VALIDATORCOUNT \ - -num-of-observers-in-each-shard $SHARD_OBSERVERCOUNT \ - -consensus-group-size $SHARD_CONSENSUS_SIZE \ - -num-of-metachain-nodes $META_VALIDATORCOUNT \ - -num-of-observers-in-metachain $META_OBSERVERCOUNT \ - -metachain-consensus-group-size $META_CONSENSUS_SIZE \ + -output-directory $CONFIGGENERATOROUTPUTDIR \ + -num-of-shards $SHARDCOUNT \ + -num-of-nodes-in-each-shard $SHARD_VALIDATORCOUNT \ + -num-of-observers-in-each-shard $TMP_SHARD_OBSERVERCOUNT \ + -consensus-group-size $SHARD_CONSENSUS_SIZE \ + -num-of-metachain-nodes $META_VALIDATORCOUNT \ + -num-of-observers-in-metachain $TMP_META_OBSERVERCOUNT \ + -metachain-consensus-group-size $META_CONSENSUS_SIZE \ -stake-type $GENESIS_STAKE_TYPE \ -hysteresis $HYSTERESIS popd @@ -22,6 +29,9 @@ copyConfig() { cp ./filegen/"$CONFIGGENERATOROUTPUTDIR"/genesis.json ./node/config cp ./filegen/"$CONFIGGENERATOROUTPUTDIR"/nodesSetup.json ./node/config cp ./filegen/"$CONFIGGENERATOROUTPUTDIR"/*.pem ./node/config #there might be more .pem files there + if [[ $MULTI_KEY_NODES -eq 1 ]]; then + mv ./node/config/"$VALIDATOR_KEY_PEM_FILE" ./node/config/"$MULTI_KEY_PEM_FILE" + fi echo "Configuration files copied from the configuration generator to the working directories of the executables." 
popd } diff --git a/scripts/testnet/include/validators.sh b/scripts/testnet/include/validators.sh index dcea595d684..b19bad12525 100644 --- a/scripts/testnet/include/validators.sh +++ b/scripts/testnet/include/validators.sh @@ -4,7 +4,11 @@ startValidators() { setTerminalSession "multiversx-nodes" setTerminalLayout "tiled" setWorkdirForNextCommands "$TESTNETDIR/node" - iterateOverValidators startSingleValidator + if [[ $MULTI_KEY_NODES -eq 1 ]]; then + iterateOverValidatorsMultiKey startSingleValidator + else + iterateOverValidators startSingleValidator + fi } pauseValidators() { @@ -46,15 +50,44 @@ iterateOverValidators() { done } +iterateOverValidatorsMultiKey() { + local callback=$1 + local VALIDATOR_INDEX=0 + + # Iterate over shards and start validators + (( max_shard_id=$SHARDCOUNT - 1 )) + for SHARD in $(seq 0 1 $max_shard_id); do + if [ $VALIDATOR_INDEX -ne $SKIP_VALIDATOR_IDX ]; then + $callback $SHARD $VALIDATOR_INDEX + sleep 0.5 + fi + (( VALIDATOR_INDEX++ )) + done + + # Start Metachain Validator + SHARD="metachain" + if [ $VALIDATOR_INDEX -ne $SKIP_VALIDATOR_IDX ]; then + $callback $SHARD $VALIDATOR_INDEX + sleep 0.5 + fi + (( VALIDATOR_INDEX++ )) +} + startSingleValidator() { local SHARD=$1 local VALIDATOR_INDEX=$2 + + local DIR_NAME="validator" + if [[ $MULTI_KEY_NODES -eq 1 ]]; then + DIR_NAME="multikey" + fi + local startCommand="" if [ "$NODE_WATCHER" -eq 1 ]; then - setWorkdirForNextCommands "$TESTNETDIR/node_working_dirs/validator$VALIDATOR_INDEX" - startCommand="$(assembleCommand_startValidatorNodeWithWatcher $VALIDATOR_INDEX)" + setWorkdirForNextCommands "$TESTNETDIR/node_working_dirs/$DIR_NAME$VALIDATOR_INDEX" + startCommand="$(assembleCommand_startValidatorNodeWithWatcher $VALIDATOR_INDEX $DIR_NAME)" else - startCommand="$(assembleCommand_startValidatorNode $VALIDATOR_INDEX)" + startCommand="$(assembleCommand_startValidatorNode $VALIDATOR_INDEX $DIR_NAME)" fi runCommandInTerminal "$startCommand" } @@ -77,8 +110,13 @@ stopSingleValidator() { local SHARD=$1 local VALIDATOR_INDEX=$2 + local DIR_NAME="validator" + if [[ $MULTI_KEY_NODES -eq 1 ]]; then + DIR_NAME="multikey" + fi + if [ "$NODE_WATCHER" == "1" ]; then - WORKING_DIR=$TESTNETDIR/node_working_dirs/validator$VALIDATOR_INDEX + WORKING_DIR=$TESTNETDIR/node_working_dirs/$DIR_NAME$VALIDATOR_INDEX mkdir -p $WORKING_DIR touch $WORKING_DIR/norestart fi @@ -90,12 +128,13 @@ stopSingleValidator() { assembleCommand_startValidatorNodeWithWatcher() { VALIDATOR_INDEX=$1 + DIR_NAME=$2 (( PORT=$PORT_ORIGIN_VALIDATOR + $VALIDATOR_INDEX )) - WORKING_DIR=$TESTNETDIR/node_working_dirs/validator$VALIDATOR_INDEX + WORKING_DIR=$TESTNETDIR/node_working_dirs/$DIR_NAME$VALIDATOR_INDEX local source_command="source $MULTIVERSXTESTNETSCRIPTSDIR/include/watcher.sh" local watcher_command="node-start-with-watcher $VALIDATOR_INDEX $PORT &" - local node_command=$(assembleCommand_startValidatorNode $VALIDATOR_INDEX) + local node_command=$(assembleCommand_startValidatorNode $VALIDATOR_INDEX $DIR_NAME) mkdir -p $WORKING_DIR echo "$node_command" > $WORKING_DIR/node-command echo "$PORT" > $WORKING_DIR/node-port @@ -105,10 +144,11 @@ assembleCommand_startValidatorNodeWithWatcher() { assembleCommand_startValidatorNode() { VALIDATOR_INDEX=$1 + DIR_NAME=$2 (( PORT=$PORT_ORIGIN_VALIDATOR + $VALIDATOR_INDEX )) (( RESTAPIPORT=$PORT_ORIGIN_VALIDATOR_REST + $VALIDATOR_INDEX )) (( KEY_INDEX=$VALIDATOR_INDEX )) - WORKING_DIR=$TESTNETDIR/node_working_dirs/validator$VALIDATOR_INDEX + WORKING_DIR=$TESTNETDIR/node_working_dirs/$DIR_NAME$VALIDATOR_INDEX local 
node_command="./node \ -port $PORT --profile-mode -log-save -log-level $LOGLEVEL --log-logger-name --log-correlation --use-health-service -rest-api-interface localhost:$RESTAPIPORT \ diff --git a/scripts/testnet/variables.sh b/scripts/testnet/variables.sh index 94aac4c0e8c..14eff94e7e9 100644 --- a/scripts/testnet/variables.sh +++ b/scripts/testnet/variables.sh @@ -62,6 +62,9 @@ export META_VALIDATORCOUNT=3 export META_OBSERVERCOUNT=1 export META_CONSENSUS_SIZE=$META_VALIDATORCOUNT +# MULTI_KEY_NODES if set to 1, one observer will be generated on each shard that will handle all generated keys +export MULTI_KEY_NODES=0 + # ALWAYS_NEW_CHAINID will generate a fresh new chain ID each time start.sh/config.sh is called export ALWAYS_NEW_CHAINID=1 @@ -161,6 +164,16 @@ export TOTAL_OBSERVERCOUNT=$total_observer_count # to enable the full archive feature on the observers, please use the --full-archive flag export EXTRA_OBSERVERS_FLAGS="-operation-mode db-lookup-extension" +if [[ $MULTI_KEY_NODES -eq 1 ]]; then + EXTRA_OBSERVERS_FLAGS="--no-key" +fi + # Leave unchanged. let "total_node_count = $SHARD_VALIDATORCOUNT * $SHARDCOUNT + $META_VALIDATORCOUNT + $TOTAL_OBSERVERCOUNT" export TOTAL_NODECOUNT=$total_node_count + +# VALIDATOR_KEY_PEM_FILE is the pem file name when running single key mode, with all nodes' keys +export VALIDATOR_KEY_PEM_FILE="validatorKey.pem" + +# MULTI_KEY_PEM_FILE is the pem file name when running multi key mode, with all managed +export MULTI_KEY_PEM_FILE="allValidatorsKeys.pem" diff --git a/sharding/mock/enableEpochsHandlerMock.go b/sharding/mock/enableEpochsHandlerMock.go index a1c15518988..6173c091e32 100644 --- a/sharding/mock/enableEpochsHandlerMock.go +++ b/sharding/mock/enableEpochsHandlerMock.go @@ -546,6 +546,11 @@ func (mock *EnableEpochsHandlerMock) IsRuntimeMemStoreLimitEnabled() bool { return false } +// IsRuntimeCodeSizeFixEnabled - +func (mock *EnableEpochsHandlerMock) IsRuntimeCodeSizeFixEnabled() bool { + return false +} + // IsMaxBlockchainHookCountersFlagEnabled - func (mock *EnableEpochsHandlerMock) IsMaxBlockchainHookCountersFlagEnabled() bool { return false diff --git a/storage/clean/oldDataCleanerProvider.go b/storage/clean/oldDataCleanerProvider.go index 5c0285cc1d7..7d01c9005ac 100644 --- a/storage/clean/oldDataCleanerProvider.go +++ b/storage/clean/oldDataCleanerProvider.go @@ -7,27 +7,46 @@ import ( "github.com/multiversx/mx-chain-go/storage" ) +// ArgOldDataCleanerProvider is the argument used to create a new oldDataCleanerProvider instance +type ArgOldDataCleanerProvider struct { + NodeTypeProvider NodeTypeProviderHandler + PruningStorerConfig config.StoragePruningConfig + ManagedPeersHolder storage.ManagedPeersHolder +} + type oldDataCleanerProvider struct { nodeTypeProvider NodeTypeProviderHandler + managedPeersHolder storage.ManagedPeersHolder validatorCleanOldEpochsData bool observerCleanOldEpochsData bool } // NewOldDataCleanerProvider returns a new instance of oldDataCleanerProvider -func NewOldDataCleanerProvider( - nodeTypeProvider NodeTypeProviderHandler, - pruningStorerConfig config.StoragePruningConfig, -) (*oldDataCleanerProvider, error) { - if check.IfNil(nodeTypeProvider) { - return nil, storage.ErrNilNodeTypeProvider +func NewOldDataCleanerProvider(args ArgOldDataCleanerProvider) (*oldDataCleanerProvider, error) { + err := checkArgs(args) + if err != nil { + return nil, err } + return &oldDataCleanerProvider{ - nodeTypeProvider: nodeTypeProvider, - validatorCleanOldEpochsData: pruningStorerConfig.ValidatorCleanOldEpochsData, - 
observerCleanOldEpochsData: pruningStorerConfig.ObserverCleanOldEpochsData, + nodeTypeProvider: args.NodeTypeProvider, + validatorCleanOldEpochsData: args.PruningStorerConfig.ValidatorCleanOldEpochsData, + observerCleanOldEpochsData: args.PruningStorerConfig.ObserverCleanOldEpochsData, + managedPeersHolder: args.ManagedPeersHolder, }, nil } +func checkArgs(args ArgOldDataCleanerProvider) error { + if check.IfNil(args.NodeTypeProvider) { + return storage.ErrNilNodeTypeProvider + } + if check.IfNil(args.ManagedPeersHolder) { + return storage.ErrNilManagedPeersHolder + } + + return nil +} + // ShouldClean returns true if old data can be cleaned, based on current configuration, func (odcp *oldDataCleanerProvider) ShouldClean() bool { nodeType := odcp.nodeTypeProvider.GetType() @@ -41,7 +60,12 @@ func (odcp *oldDataCleanerProvider) ShouldClean() bool { shouldClean = odcp.observerCleanOldEpochsData } - log.Debug("oldDataCleanerProvider.ShouldClean", "node type", nodeType, "value", shouldClean) + isMultiKey := odcp.managedPeersHolder.IsMultiKeyMode() + if isMultiKey { + shouldClean = odcp.validatorCleanOldEpochsData + } + + log.Debug("oldDataCleanerProvider.ShouldClean", "node type", nodeType, "is multi key", isMultiKey, "value", shouldClean) return shouldClean } diff --git a/storage/clean/oldDataCleanerProvider_test.go b/storage/clean/oldDataCleanerProvider_test.go index 8bd4f873556..efdec71a9da 100644 --- a/storage/clean/oldDataCleanerProvider_test.go +++ b/storage/clean/oldDataCleanerProvider_test.go @@ -7,66 +7,132 @@ import ( "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/storage" + "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/nodeTypeProviderMock" "github.com/stretchr/testify/require" ) -func TestNewOldDataCleanerProvider_NilNodeTypeProviderShouldErr(t *testing.T) { - t.Parallel() - - odcp, err := NewOldDataCleanerProvider(nil, config.StoragePruningConfig{}) - require.True(t, check.IfNil(odcp)) - require.Equal(t, storage.ErrNilNodeTypeProvider, err) -} - -func TestNewOldDataCleanerProvider_ShouldWork(t *testing.T) { - t.Parallel() - - odcp, err := NewOldDataCleanerProvider(&nodeTypeProviderMock.NodeTypeProviderStub{}, config.StoragePruningConfig{}) - require.NoError(t, err) - require.False(t, check.IfNil(odcp)) +func createMockArgOldDataCleanerProvider() ArgOldDataCleanerProvider { + return ArgOldDataCleanerProvider{ + NodeTypeProvider: &nodeTypeProviderMock.NodeTypeProviderStub{}, + PruningStorerConfig: config.StoragePruningConfig{}, + ManagedPeersHolder: &testscommon.ManagedPeersHolderStub{}, + } } -func TestOldDataCleanerProvider_ShouldCleanShouldReturnObserverIfInvalidNodeType(t *testing.T) { +func TestNewOldDataCleanerProvider(t *testing.T) { t.Parallel() - ntp := &nodeTypeProviderMock.NodeTypeProviderStub{ - GetTypeCalled: func() core.NodeType { - return "invalid" - }, - } + t.Run("nil NodeTypeProvider should error", func(t *testing.T) { + t.Parallel() - odcp, _ := NewOldDataCleanerProvider(ntp, config.StoragePruningConfig{ - ObserverCleanOldEpochsData: true, - ValidatorCleanOldEpochsData: true, + args := createMockArgOldDataCleanerProvider() + args.NodeTypeProvider = nil + odcp, err := NewOldDataCleanerProvider(args) + require.True(t, check.IfNil(odcp)) + require.Equal(t, storage.ErrNilNodeTypeProvider, err) + }) + t.Run("nil ManagedPeersHolder should error", func(t *testing.T) { + t.Parallel() + + args := createMockArgOldDataCleanerProvider() + 
args.ManagedPeersHolder = nil + odcp, err := NewOldDataCleanerProvider(args) + require.True(t, check.IfNil(odcp)) + require.Equal(t, storage.ErrNilManagedPeersHolder, err) }) + t.Run("should work", func(t *testing.T) { + t.Parallel() - require.False(t, odcp.ShouldClean()) + odcp, err := NewOldDataCleanerProvider(createMockArgOldDataCleanerProvider()) + require.NoError(t, err) + require.False(t, check.IfNil(odcp)) + }) } func TestOldDataCleanerProvider_ShouldClean(t *testing.T) { t.Parallel() - storagePruningConfig := config.StoragePruningConfig{ - ObserverCleanOldEpochsData: false, - ValidatorCleanOldEpochsData: true, - } + t.Run("invalid type should not clean", func(t *testing.T) { + t.Parallel() + + args := createMockArgOldDataCleanerProvider() + args.NodeTypeProvider = &nodeTypeProviderMock.NodeTypeProviderStub{ + GetTypeCalled: func() core.NodeType { + return "invalid" + }, + } + args.PruningStorerConfig = config.StoragePruningConfig{ + ObserverCleanOldEpochsData: true, + ValidatorCleanOldEpochsData: true, + } + odcp, _ := NewOldDataCleanerProvider(args) + + require.False(t, odcp.ShouldClean()) + }) + t.Run("observer should clean", func(t *testing.T) { + t.Parallel() - ntp := &nodeTypeProviderMock.NodeTypeProviderStub{ - GetTypeCalled: func() core.NodeType { - return core.NodeTypeValidator - }, - } + args := createMockArgOldDataCleanerProvider() + args.PruningStorerConfig = config.StoragePruningConfig{ + ObserverCleanOldEpochsData: true, + ValidatorCleanOldEpochsData: false, + } - odcp, _ := NewOldDataCleanerProvider(ntp, storagePruningConfig) - require.NotNil(t, odcp) + args.NodeTypeProvider = &nodeTypeProviderMock.NodeTypeProviderStub{ + GetTypeCalled: func() core.NodeType { + return core.NodeTypeObserver + }, + } - require.True(t, odcp.ShouldClean()) + odcp, _ := NewOldDataCleanerProvider(args) + require.NotNil(t, odcp) - odcp.nodeTypeProvider = &nodeTypeProviderMock.NodeTypeProviderStub{ - GetTypeCalled: func() core.NodeType { - return core.NodeTypeObserver - }, - } - require.False(t, odcp.ShouldClean()) + require.True(t, odcp.ShouldClean()) + }) + t.Run("validator should clean", func(t *testing.T) { + t.Parallel() + + args := createMockArgOldDataCleanerProvider() + args.PruningStorerConfig = config.StoragePruningConfig{ + ObserverCleanOldEpochsData: false, + ValidatorCleanOldEpochsData: true, + } + + args.NodeTypeProvider = &nodeTypeProviderMock.NodeTypeProviderStub{ + GetTypeCalled: func() core.NodeType { + return core.NodeTypeValidator + }, + } + + odcp, _ := NewOldDataCleanerProvider(args) + require.NotNil(t, odcp) + + require.True(t, odcp.ShouldClean()) + }) + t.Run("multi key observer should clean", func(t *testing.T) { + t.Parallel() + + args := createMockArgOldDataCleanerProvider() + args.ManagedPeersHolder = &testscommon.ManagedPeersHolderStub{ + IsMultiKeyModeCalled: func() bool { + return true + }, + } + args.PruningStorerConfig = config.StoragePruningConfig{ + ObserverCleanOldEpochsData: false, + ValidatorCleanOldEpochsData: true, + } + + args.NodeTypeProvider = &nodeTypeProviderMock.NodeTypeProviderStub{ + GetTypeCalled: func() core.NodeType { + return core.NodeTypeObserver + }, + } + + odcp, _ := NewOldDataCleanerProvider(args) + require.NotNil(t, odcp) + + require.True(t, odcp.ShouldClean()) + }) } diff --git a/storage/errors.go b/storage/errors.go index fdf1e7075d8..f33ace22458 100644 --- a/storage/errors.go +++ b/storage/errors.go @@ -88,6 +88,9 @@ var ErrEpochKeepIsLowerThanNumActive = errors.New("num epochs to keep is lower t // ErrNilPersistersTracker signals 
that a nil persisters tracker has been provided var ErrNilPersistersTracker = errors.New("nil persisters tracker provided") +// ErrNilManagedPeersHolder signals that a nil managed peers holder has been provided +var ErrNilManagedPeersHolder = errors.New("nil managed peers holder") + // IsNotFoundInStorageErr returns whether an error is a "not found in storage" error. // Currently, "item not found" storage errors are untyped (thus not distinguishable from others). E.g. see "pruningStorer.go". // As a workaround, we test the error message for a match. diff --git a/storage/factory/pruningStorerFactory.go b/storage/factory/pruningStorerFactory.go index 34293aaf9f3..d42f8db6ce6 100644 --- a/storage/factory/pruningStorerFactory.go +++ b/storage/factory/pruningStorerFactory.go @@ -48,6 +48,7 @@ type StorageServiceFactory struct { createTrieEpochRootHashStorer bool currentEpoch uint32 storageType StorageServiceType + snapshotsEnabled bool } // StorageServiceFactoryArgs holds the arguments needed for creating a new storage service factory @@ -59,8 +60,10 @@ type StorageServiceFactoryArgs struct { EpochStartNotifier epochStart.EpochStartNotifier NodeTypeProvider NodeTypeProviderHandler StorageType StorageServiceType + ManagedPeersHolder storage.ManagedPeersHolder CurrentEpoch uint32 CreateTrieEpochRootHashStorer bool + SnapshotsEnabled bool } // NewStorageServiceFactory will return a new instance of StorageServiceFactory @@ -70,10 +73,12 @@ func NewStorageServiceFactory(args StorageServiceFactoryArgs) (*StorageServiceFa return nil, err } - oldDataCleanProvider, err := clean.NewOldDataCleanerProvider( - args.NodeTypeProvider, - args.Config.StoragePruning, - ) + argsOldDataCleanerProvider := clean.ArgOldDataCleanerProvider{ + NodeTypeProvider: args.NodeTypeProvider, + PruningStorerConfig: args.Config.StoragePruning, + ManagedPeersHolder: args.ManagedPeersHolder, + } + oldDataCleanProvider, err := clean.NewOldDataCleanerProvider(argsOldDataCleanerProvider) if err != nil { return nil, err } @@ -91,6 +96,7 @@ func NewStorageServiceFactory(args StorageServiceFactoryArgs) (*StorageServiceFa createTrieEpochRootHashStorer: args.CreateTrieEpochRootHashStorer, oldDataCleanerProvider: oldDataCleanProvider, storageType: args.StorageType, + snapshotsEnabled: args.SnapshotsEnabled, }, nil } @@ -375,7 +381,7 @@ func (psf *StorageServiceFactory) createTrieUnit( storageConfig config.StorageConfig, pruningStorageArgs pruning.StorerArgs, ) (storage.Storer, error) { - if !psf.generalConfig.StateTriesConfig.SnapshotsEnabled { + if !psf.snapshotsEnabled { return psf.createTriePersister(storageConfig) } diff --git a/storage/interface.go b/storage/interface.go index a65d4b837d7..8f84cb400d6 100644 --- a/storage/interface.go +++ b/storage/interface.go @@ -201,3 +201,9 @@ type AdaptedSizedLRUCache interface { AddSizedAndReturnEvicted(key, value interface{}, sizeInBytes int64) map[interface{}]interface{} IsInterfaceNil() bool } + +// ManagedPeersHolder defines the operations of an entity that holds managed identities for a node +type ManagedPeersHolder interface { + IsMultiKeyMode() bool + IsInterfaceNil() bool +} diff --git a/testscommon/components/components.go b/testscommon/components/components.go index af691ea04d7..c451d18df21 100644 --- a/testscommon/components/components.go +++ b/testscommon/components/components.go @@ -215,9 +215,10 @@ func GetDataArgs(coreComponents factory.CoreComponentsHolder, shardCoordinator s ShardCoordinator: shardCoordinator, Core: coreComponents, StatusCore: GetStatusCoreComponents(), - 
EpochStartNotifier: &mock.EpochStartNotifierStub{}, + Crypto: GetCryptoComponents(coreComponents), CurrentEpoch: 0, CreateTrieEpochRootHashStorer: false, + SnapshotsEnabled: false, } } @@ -565,8 +566,9 @@ func GetProcessArgs( MaxServiceFee: 100, }, }, - Version: "v1.0.0", - HistoryRepo: &dblookupext.HistoryRepositoryStub{}, + Version: "v1.0.0", + HistoryRepo: &dblookupext.HistoryRepositoryStub{}, + SnapshotsEnabled: false, } } diff --git a/testscommon/components/default.go b/testscommon/components/default.go index e64653dd42d..bee3de0ac47 100644 --- a/testscommon/components/default.go +++ b/testscommon/components/default.go @@ -6,7 +6,6 @@ import ( crypto "github.com/multiversx/mx-chain-crypto-go" "github.com/multiversx/mx-chain-go/common" consensusMocks "github.com/multiversx/mx-chain-go/consensus/mock" - dataRetrieverMock "github.com/multiversx/mx-chain-go/dataRetriever/mock" "github.com/multiversx/mx-chain-go/factory/mock" "github.com/multiversx/mx-chain-go/sharding" "github.com/multiversx/mx-chain-go/testscommon" @@ -71,7 +70,7 @@ func GetDefaultCryptoComponents() *mock.CryptoComponentsMock { TxKeyGen: &mock.KeyGenMock{}, P2PKeyGen: &mock.KeyGenMock{}, MsgSigVerifier: &testscommon.MessageSignVerifierMock{}, - SigHandler: &consensusMocks.SignatureHandlerStub{}, + SigHandler: &consensusMocks.SigningHandlerStub{}, } } @@ -115,7 +114,7 @@ func GetDefaultProcessComponents(shardCoordinator sharding.Coordinator) *mock.Pr NodesCoord: &shardingMocks.NodesCoordinatorMock{}, ShardCoord: shardCoordinator, IntContainer: &testscommon.InterceptorsContainerStub{}, - ResContainer: &dataRetrieverMock.ResolversContainerStub{}, + ResContainer: &dataRetrieverTests.ResolversContainerStub{}, ReqFinder: &dataRetrieverTests.RequestersFinderStub{}, RoundHandlerField: &testscommon.RoundHandlerMock{}, EpochTrigger: &testscommon.EpochStartTriggerStub{}, diff --git a/dataRetriever/mock/resolversContainerStub.go b/testscommon/dataRetriever/resolversContainerStub.go similarity index 98% rename from dataRetriever/mock/resolversContainerStub.go rename to testscommon/dataRetriever/resolversContainerStub.go index dbeec5f98e5..7f3ac79094e 100644 --- a/dataRetriever/mock/resolversContainerStub.go +++ b/testscommon/dataRetriever/resolversContainerStub.go @@ -1,4 +1,4 @@ -package mock +package dataRetriever import ( "github.com/multiversx/mx-chain-go/dataRetriever" diff --git a/testscommon/enableEpochsHandlerStub.go b/testscommon/enableEpochsHandlerStub.go index 5fa9c87a3a4..092131f8ebc 100644 --- a/testscommon/enableEpochsHandlerStub.go +++ b/testscommon/enableEpochsHandlerStub.go @@ -113,6 +113,7 @@ type EnableEpochsHandlerStub struct { IsFixAsyncCallBackArgsListFlagEnabledField bool IsFixOldTokenLiquidityEnabledField bool IsRuntimeMemStoreLimitEnabledField bool + IsRuntimeCodeSizeFixEnabledField bool IsMaxBlockchainHookCountersFlagEnabledField bool IsWipeSingleNFTLiquidityDecreaseEnabledField bool IsAlwaysSaveTokenMetaDataEnabledField bool @@ -981,6 +982,14 @@ func (stub *EnableEpochsHandlerStub) IsRuntimeMemStoreLimitEnabled() bool { return stub.IsRuntimeMemStoreLimitEnabledField } +// IsRuntimeCodeSizeFixEnabled - +func (stub *EnableEpochsHandlerStub) IsRuntimeCodeSizeFixEnabled() bool { + stub.RLock() + defer stub.RUnlock() + + return stub.IsRuntimeCodeSizeFixEnabledField +} + // IsMaxBlockchainHookCountersFlagEnabled - func (stub *EnableEpochsHandlerStub) IsMaxBlockchainHookCountersFlagEnabled() bool { stub.RLock() diff --git a/testscommon/generalConfig.go b/testscommon/generalConfig.go index 50d20dd1001..132effecc4e 
100644 --- a/testscommon/generalConfig.go +++ b/testscommon/generalConfig.go @@ -124,7 +124,6 @@ func GetGeneralConfig() config.Config { }, StateTriesConfig: config.StateTriesConfig{ CheckpointRoundsModulus: 100, - SnapshotsEnabled: false, CheckpointsEnabled: false, AccountsStatePruningEnabled: false, PeerStatePruningEnabled: false, @@ -273,6 +272,7 @@ func GetGeneralConfig() config.Config { HideInactiveValidatorIntervalInSec: 60, HardforkTimeBetweenSendsInSec: 5, TimeBetweenConnectionsMetricsUpdateInSec: 10, + PeerAuthenticationTimeBetweenChecksInSec: 1, HeartbeatPool: getLRUCacheConfig(), }, StatusMetricsStorage: config.StorageConfig{ diff --git a/testscommon/keysHandlerSingleSignerMock.go b/testscommon/keysHandlerSingleSignerMock.go new file mode 100644 index 00000000000..c70bc381dc8 --- /dev/null +++ b/testscommon/keysHandlerSingleSignerMock.go @@ -0,0 +1,73 @@ +package testscommon + +import ( + "bytes" + + "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-crypto-go" +) + +type keysHandlerSingleSignerMock struct { + privateKey crypto.PrivateKey + publicKey crypto.PublicKey + pkBytes []byte + pid core.PeerID +} + +// NewKeysHandlerSingleSignerMock - +func NewKeysHandlerSingleSignerMock( + privateKey crypto.PrivateKey, + pid core.PeerID, +) *keysHandlerSingleSignerMock { + pk := privateKey.GeneratePublic() + pkBytes, _ := pk.ToByteArray() + + return &keysHandlerSingleSignerMock{ + privateKey: privateKey, + publicKey: pk, + pkBytes: pkBytes, + pid: pid, + } +} + +// GetHandledPrivateKey - +func (mock *keysHandlerSingleSignerMock) GetHandledPrivateKey(_ []byte) crypto.PrivateKey { + return mock.privateKey +} + +// GetP2PIdentity - +func (mock *keysHandlerSingleSignerMock) GetP2PIdentity(_ []byte) ([]byte, core.PeerID, error) { + return make([]byte, 0), "", nil +} + +// IsKeyManagedByCurrentNode - +func (mock *keysHandlerSingleSignerMock) IsKeyManagedByCurrentNode(_ []byte) bool { + return false +} + +// IncrementRoundsWithoutReceivedMessages - +func (mock *keysHandlerSingleSignerMock) IncrementRoundsWithoutReceivedMessages(_ []byte) { +} + +// GetAssociatedPid - +func (mock *keysHandlerSingleSignerMock) GetAssociatedPid(pkBytes []byte) core.PeerID { + if bytes.Equal(mock.pkBytes, pkBytes) { + return mock.pid + } + + return "" +} + +// IsOriginalPublicKeyOfTheNode - +func (mock *keysHandlerSingleSignerMock) IsOriginalPublicKeyOfTheNode(pkBytes []byte) bool { + return bytes.Equal(mock.pkBytes, pkBytes) +} + +// UpdatePublicKeyLiveness - +func (mock *keysHandlerSingleSignerMock) UpdatePublicKeyLiveness(_ []byte, _ core.PeerID) { +} + +// IsInterfaceNil - +func (mock *keysHandlerSingleSignerMock) IsInterfaceNil() bool { + return mock == nil +} diff --git a/testscommon/keysHandlerStub.go b/testscommon/keysHandlerStub.go new file mode 100644 index 00000000000..616f6d3c3db --- /dev/null +++ b/testscommon/keysHandlerStub.go @@ -0,0 +1,82 @@ +package testscommon + +import ( + "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-crypto-go" + "github.com/multiversx/mx-chain-go/testscommon/cryptoMocks" +) + +// KeysHandlerStub - +type KeysHandlerStub struct { + GetHandledPrivateKeyCalled func(pkBytes []byte) crypto.PrivateKey + GetP2PIdentityCalled func(pkBytes []byte) ([]byte, core.PeerID, error) + IsKeyManagedByCurrentNodeCalled func(pkBytes []byte) bool + IncrementRoundsWithoutReceivedMessagesCalled func(pkBytes []byte) + GetAssociatedPidCalled func(pkBytes []byte) core.PeerID + IsOriginalPublicKeyOfTheNodeCalled func(pkBytes []byte) bool + 
UpdatePublicKeyLivenessCalled func(pkBytes []byte, pid core.PeerID) +} + +// GetHandledPrivateKey - +func (stub *KeysHandlerStub) GetHandledPrivateKey(pkBytes []byte) crypto.PrivateKey { + if stub.GetHandledPrivateKeyCalled != nil { + return stub.GetHandledPrivateKeyCalled(pkBytes) + } + + return &cryptoMocks.PrivateKeyStub{} +} + +// GetP2PIdentity - +func (stub *KeysHandlerStub) GetP2PIdentity(pkBytes []byte) ([]byte, core.PeerID, error) { + if stub.GetP2PIdentityCalled != nil { + return stub.GetP2PIdentityCalled(pkBytes) + } + + return make([]byte, 0), "", nil +} + +// IsKeyManagedByCurrentNode - +func (stub *KeysHandlerStub) IsKeyManagedByCurrentNode(pkBytes []byte) bool { + if stub.IsKeyManagedByCurrentNodeCalled != nil { + return stub.IsKeyManagedByCurrentNodeCalled(pkBytes) + } + + return false +} + +// IncrementRoundsWithoutReceivedMessages - +func (stub *KeysHandlerStub) IncrementRoundsWithoutReceivedMessages(pkBytes []byte) { + if stub.IncrementRoundsWithoutReceivedMessagesCalled != nil { + stub.IncrementRoundsWithoutReceivedMessagesCalled(pkBytes) + } +} + +// GetAssociatedPid - +func (stub *KeysHandlerStub) GetAssociatedPid(pkBytes []byte) core.PeerID { + if stub.GetAssociatedPidCalled != nil { + return stub.GetAssociatedPidCalled(pkBytes) + } + + return "" +} + +// IsOriginalPublicKeyOfTheNode - +func (stub *KeysHandlerStub) IsOriginalPublicKeyOfTheNode(pkBytes []byte) bool { + if stub.IsOriginalPublicKeyOfTheNodeCalled != nil { + return stub.IsOriginalPublicKeyOfTheNodeCalled(pkBytes) + } + + return true +} + +// UpdatePublicKeyLiveness - +func (stub *KeysHandlerStub) UpdatePublicKeyLiveness(pkBytes []byte, pid core.PeerID) { + if stub.UpdatePublicKeyLivenessCalled != nil { + stub.UpdatePublicKeyLivenessCalled(pkBytes, pid) + } +} + +// IsInterfaceNil - +func (stub *KeysHandlerStub) IsInterfaceNil() bool { + return stub == nil +} diff --git a/testscommon/managedPeersHolderStub.go b/testscommon/managedPeersHolderStub.go new file mode 100644 index 00000000000..ad7bf309c91 --- /dev/null +++ b/testscommon/managedPeersHolderStub.go @@ -0,0 +1,157 @@ +package testscommon + +import ( + "time" + + "github.com/multiversx/mx-chain-core-go/core" + crypto "github.com/multiversx/mx-chain-crypto-go" +) + +// ManagedPeersHolderStub - +type ManagedPeersHolderStub struct { + AddManagedPeerCalled func(privateKeyBytes []byte) error + GetPrivateKeyCalled func(pkBytes []byte) (crypto.PrivateKey, error) + GetP2PIdentityCalled func(pkBytes []byte) ([]byte, core.PeerID, error) + GetMachineIDCalled func(pkBytes []byte) (string, error) + GetNameAndIdentityCalled func(pkBytes []byte) (string, string, error) + IncrementRoundsWithoutReceivedMessagesCalled func(pkBytes []byte) + ResetRoundsWithoutReceivedMessagesCalled func(pkBytes []byte) + GetManagedKeysByCurrentNodeCalled func() map[string]crypto.PrivateKey + IsKeyManagedByCurrentNodeCalled func(pkBytes []byte) bool + IsKeyRegisteredCalled func(pkBytes []byte) bool + IsPidManagedByCurrentNodeCalled func(pid core.PeerID) bool + IsKeyValidatorCalled func(pkBytes []byte) bool + SetValidatorStateCalled func(pkBytes []byte, state bool) + GetNextPeerAuthenticationTimeCalled func(pkBytes []byte) (time.Time, error) + SetNextPeerAuthenticationTimeCalled func(pkBytes []byte, nextTime time.Time) + IsMultiKeyModeCalled func() bool +} + +// AddManagedPeer - +func (stub *ManagedPeersHolderStub) AddManagedPeer(privateKeyBytes []byte) error { + if stub.AddManagedPeerCalled != nil { + return stub.AddManagedPeerCalled(privateKeyBytes) + } + return nil +} + +// 
GetPrivateKey - +func (stub *ManagedPeersHolderStub) GetPrivateKey(pkBytes []byte) (crypto.PrivateKey, error) { + if stub.GetPrivateKeyCalled != nil { + return stub.GetPrivateKeyCalled(pkBytes) + } + return nil, nil +} + +// GetP2PIdentity - +func (stub *ManagedPeersHolderStub) GetP2PIdentity(pkBytes []byte) ([]byte, core.PeerID, error) { + if stub.GetP2PIdentityCalled != nil { + return stub.GetP2PIdentityCalled(pkBytes) + } + return nil, "", nil +} + +// GetMachineID - +func (stub *ManagedPeersHolderStub) GetMachineID(pkBytes []byte) (string, error) { + if stub.GetMachineIDCalled != nil { + return stub.GetMachineIDCalled(pkBytes) + } + return "", nil +} + +// GetNameAndIdentity - +func (stub *ManagedPeersHolderStub) GetNameAndIdentity(pkBytes []byte) (string, string, error) { + if stub.GetNameAndIdentityCalled != nil { + return stub.GetNameAndIdentityCalled(pkBytes) + } + return "", "", nil +} + +// IncrementRoundsWithoutReceivedMessages - +func (stub *ManagedPeersHolderStub) IncrementRoundsWithoutReceivedMessages(pkBytes []byte) { + if stub.IncrementRoundsWithoutReceivedMessagesCalled != nil { + stub.IncrementRoundsWithoutReceivedMessagesCalled(pkBytes) + } +} + +// ResetRoundsWithoutReceivedMessages - +func (stub *ManagedPeersHolderStub) ResetRoundsWithoutReceivedMessages(pkBytes []byte) { + if stub.ResetRoundsWithoutReceivedMessagesCalled != nil { + stub.ResetRoundsWithoutReceivedMessagesCalled(pkBytes) + } +} + +// GetManagedKeysByCurrentNode - +func (stub *ManagedPeersHolderStub) GetManagedKeysByCurrentNode() map[string]crypto.PrivateKey { + if stub.GetManagedKeysByCurrentNodeCalled != nil { + return stub.GetManagedKeysByCurrentNodeCalled() + } + return nil +} + +// IsKeyManagedByCurrentNode - +func (stub *ManagedPeersHolderStub) IsKeyManagedByCurrentNode(pkBytes []byte) bool { + if stub.IsKeyManagedByCurrentNodeCalled != nil { + return stub.IsKeyManagedByCurrentNodeCalled(pkBytes) + } + return false +} + +// IsKeyRegistered - +func (stub *ManagedPeersHolderStub) IsKeyRegistered(pkBytes []byte) bool { + if stub.IsKeyRegisteredCalled != nil { + return stub.IsKeyRegisteredCalled(pkBytes) + } + return false +} + +// IsPidManagedByCurrentNode - +func (stub *ManagedPeersHolderStub) IsPidManagedByCurrentNode(pid core.PeerID) bool { + if stub.IsPidManagedByCurrentNodeCalled != nil { + return stub.IsPidManagedByCurrentNodeCalled(pid) + } + return false +} + +// IsKeyValidator - +func (stub *ManagedPeersHolderStub) IsKeyValidator(pkBytes []byte) bool { + if stub.IsKeyValidatorCalled != nil { + return stub.IsKeyValidatorCalled(pkBytes) + } + return false +} + +// SetValidatorState - +func (stub *ManagedPeersHolderStub) SetValidatorState(pkBytes []byte, state bool) { + if stub.SetValidatorStateCalled != nil { + stub.SetValidatorStateCalled(pkBytes, state) + } +} + +// GetNextPeerAuthenticationTime - +func (stub *ManagedPeersHolderStub) GetNextPeerAuthenticationTime(pkBytes []byte) (time.Time, error) { + if stub.GetNextPeerAuthenticationTimeCalled != nil { + return stub.GetNextPeerAuthenticationTimeCalled(pkBytes) + } + return time.Time{}, nil +} + +// SetNextPeerAuthenticationTime - +func (stub *ManagedPeersHolderStub) SetNextPeerAuthenticationTime(pkBytes []byte, nextTime time.Time) { + if stub.SetNextPeerAuthenticationTimeCalled != nil { + stub.SetNextPeerAuthenticationTimeCalled(pkBytes, nextTime) + } +} + +// IsMultiKeyMode - +func (stub *ManagedPeersHolderStub) IsMultiKeyMode() bool { + if stub.IsMultiKeyModeCalled != nil { + return stub.IsMultiKeyModeCalled() + } + return false +} + +// 
IsInterfaceNil - +func (stub *ManagedPeersHolderStub) IsInterfaceNil() bool { + return stub == nil +} diff --git a/testscommon/p2pmocks/p2pKeyConverterStub.go b/testscommon/p2pmocks/p2pKeyConverterStub.go new file mode 100644 index 00000000000..aa56127b927 --- /dev/null +++ b/testscommon/p2pmocks/p2pKeyConverterStub.go @@ -0,0 +1,36 @@ +package p2pmocks + +import ( + "github.com/multiversx/mx-chain-core-go/core" + crypto "github.com/multiversx/mx-chain-crypto-go" + "github.com/multiversx/mx-chain-go/testscommon/cryptoMocks" +) + +// P2PKeyConverterStub - +type P2PKeyConverterStub struct { + ConvertPeerIDToPublicKeyCalled func(keyGen crypto.KeyGenerator, pid core.PeerID) (crypto.PublicKey, error) + ConvertPublicKeyToPeerIDCalled func(pk crypto.PublicKey) (core.PeerID, error) +} + +// ConvertPeerIDToPublicKey - +func (stub *P2PKeyConverterStub) ConvertPeerIDToPublicKey(keyGen crypto.KeyGenerator, pid core.PeerID) (crypto.PublicKey, error) { + if stub.ConvertPeerIDToPublicKeyCalled != nil { + return stub.ConvertPeerIDToPublicKeyCalled(keyGen, pid) + } + + return &cryptoMocks.PublicKeyStub{}, nil +} + +// ConvertPublicKeyToPeerID - +func (stub *P2PKeyConverterStub) ConvertPublicKeyToPeerID(pk crypto.PublicKey) (core.PeerID, error) { + if stub.ConvertPublicKeyToPeerIDCalled != nil { + return stub.ConvertPublicKeyToPeerIDCalled(pk) + } + + return "", nil +} + +// IsInterfaceNil - +func (stub *P2PKeyConverterStub) IsInterfaceNil() bool { + return stub == nil +} diff --git a/testscommon/trie/trieStub.go b/testscommon/trie/trieStub.go index a82b695da41..b6707e2752e 100644 --- a/testscommon/trie/trieStub.go +++ b/testscommon/trie/trieStub.go @@ -27,7 +27,6 @@ type TrieStub struct { VerifyProofCalled func(rootHash []byte, key []byte, proof [][]byte) (bool, error) GetStorageManagerCalled func() common.StorageManager GetSerializedNodeCalled func(bytes []byte) ([]byte, error) - GetNumNodesCalled func() common.NumNodesDTO GetOldRootCalled func() []byte CloseCalled func() error } @@ -185,15 +184,6 @@ func (ts *TrieStub) GetSerializedNode(bytes []byte) ([]byte, error) { return nil, nil } -// GetNumNodes - -func (ts *TrieStub) GetNumNodes() common.NumNodesDTO { - if ts.GetNumNodesCalled != nil { - return ts.GetNumNodesCalled() - } - - return common.NumNodesDTO{} -} - // GetOldRoot - func (ts *TrieStub) GetOldRoot() []byte { if ts.GetOldRootCalled != nil { diff --git a/trie/branchNode.go b/trie/branchNode.go index 6adc4c180dd..3e6f26768b5 100644 --- a/trie/branchNode.go +++ b/trie/branchNode.go @@ -855,34 +855,6 @@ func (bn *branchNode) getAllHashes(db common.DBWriteCacher) ([][]byte, error) { return hashes, nil } -func (bn *branchNode) getNumNodes() common.NumNodesDTO { - if check.IfNil(bn) { - return common.NumNodesDTO{} - } - - currentNumNodes := common.NumNodesDTO{ - Branches: 1, - } - - for _, n := range bn.children { - if check.IfNil(n) { - continue - } - - childNumNodes := n.getNumNodes() - currentNumNodes.Branches += childNumNodes.Branches - currentNumNodes.Leaves += childNumNodes.Leaves - currentNumNodes.Extensions += childNumNodes.Extensions - if childNumNodes.MaxLevel > currentNumNodes.MaxLevel { - currentNumNodes.MaxLevel = childNumNodes.MaxLevel - } - } - - currentNumNodes.MaxLevel++ - - return currentNumNodes -} - func (bn *branchNode) getNextHashAndKey(key []byte) (bool, []byte, []byte) { if len(key) == 0 || check.IfNil(bn) { return false, nil, nil diff --git a/trie/branchNode_test.go b/trie/branchNode_test.go index f79544b403c..a121e8b21aa 100644 --- a/trie/branchNode_test.go +++ 
b/trie/branchNode_test.go @@ -1319,15 +1319,6 @@ func TestBranchNode_getNextHashAndKeyNilNode(t *testing.T) { assert.Nil(t, nextKey) } -func TestBranchNode_GetNumNodesNilSelfShouldErr(t *testing.T) { - t.Parallel() - - var bn *branchNode - numNodes := bn.getNumNodes() - - assert.Equal(t, common.NumNodesDTO{}, numNodes) -} - func TestBranchNode_SizeInBytes(t *testing.T) { t.Parallel() diff --git a/trie/extensionNode.go b/trie/extensionNode.go index 72de749f475..8130a761233 100644 --- a/trie/extensionNode.go +++ b/trie/extensionNode.go @@ -601,18 +601,6 @@ func (en *extensionNode) getChildren(db common.DBWriteCacher) ([]node, error) { return nextNodes, nil } -func (en *extensionNode) getNumNodes() common.NumNodesDTO { - if check.IfNil(en) { - return common.NumNodesDTO{} - } - - childNumNodes := en.child.getNumNodes() - childNumNodes.Extensions++ - childNumNodes.MaxLevel++ - - return childNumNodes -} - func (en *extensionNode) isValid() bool { if len(en.EncodedChild) == 0 && en.child == nil { return false diff --git a/trie/extensionNode_test.go b/trie/extensionNode_test.go index 32d36a33222..cc8dd806d2c 100644 --- a/trie/extensionNode_test.go +++ b/trie/extensionNode_test.go @@ -987,15 +987,6 @@ func TestExtensionNode_getNextHashAndKeyNilNode(t *testing.T) { assert.Nil(t, nextKey) } -func TestExtensionNode_GetNumNodesNilSelfShouldErr(t *testing.T) { - t.Parallel() - - var en *extensionNode - numNodes := en.getNumNodes() - - assert.Equal(t, common.NumNodesDTO{}, numNodes) -} - func TestExtensionNode_SizeInBytes(t *testing.T) { t.Parallel() diff --git a/trie/factory/trieCreator.go b/trie/factory/trieCreator.go index 1a66b9f8359..8d90ba47844 100644 --- a/trie/factory/trieCreator.go +++ b/trie/factory/trieCreator.go @@ -106,6 +106,7 @@ func (tc *trieCreator) IsInterfaceNil() bool { // CreateTriesComponentsForShardId creates the user and peer tries and trieStorageManagers func CreateTriesComponentsForShardId( + snapshotsEnabled bool, generalConfig config.Config, coreComponentsHolder coreComponentsHandler, storageService dataRetriever.StorageService, @@ -137,7 +138,7 @@ func CreateTriesComponentsForShardId( PruningEnabled: generalConfig.StateTriesConfig.AccountsStatePruningEnabled, CheckpointsEnabled: generalConfig.StateTriesConfig.CheckpointsEnabled, MaxTrieLevelInMem: generalConfig.StateTriesConfig.MaxStateTrieLevelInMemory, - SnapshotsEnabled: generalConfig.StateTriesConfig.SnapshotsEnabled, + SnapshotsEnabled: snapshotsEnabled, IdleProvider: coreComponentsHolder.ProcessStatusHandler(), } userStorageManager, userAccountTrie, err := trFactory.Create(args) @@ -167,7 +168,7 @@ func CreateTriesComponentsForShardId( PruningEnabled: generalConfig.StateTriesConfig.PeerStatePruningEnabled, CheckpointsEnabled: generalConfig.StateTriesConfig.CheckpointsEnabled, MaxTrieLevelInMem: generalConfig.StateTriesConfig.MaxPeerTrieLevelInMemory, - SnapshotsEnabled: generalConfig.StateTriesConfig.SnapshotsEnabled, + SnapshotsEnabled: snapshotsEnabled, IdleProvider: coreComponentsHolder.ProcessStatusHandler(), } peerStorageManager, peerAccountsTrie, err := trFactory.Create(args) diff --git a/trie/factory/trieCreator_test.go b/trie/factory/trieCreator_test.go index 36652eca941..375969eb070 100644 --- a/trie/factory/trieCreator_test.go +++ b/trie/factory/trieCreator_test.go @@ -179,6 +179,7 @@ func testWithMissingStorer(missingUnit dataRetriever.UnitType) func(t *testing.T t.Parallel() holder, storageManager, err := factory.CreateTriesComponentsForShardId( + false, testscommon.GetGeneralConfig(), 
&mock.CoreComponentsStub{ InternalMarshalizerField: &testscommon.MarshalizerMock{}, diff --git a/trie/interface.go b/trie/interface.go index 72be5dea54b..4c6ffb45572 100644 --- a/trie/interface.go +++ b/trie/interface.go @@ -42,7 +42,6 @@ type node interface { getAllLeavesOnChannel(chan core.KeyValueHolder, common.KeyBuilder, common.DBWriteCacher, marshal.Marshalizer, chan struct{}, context.Context) error getAllHashes(db common.DBWriteCacher) ([][]byte, error) getNextHashAndKey([]byte) (bool, []byte, []byte) - getNumNodes() common.NumNodesDTO getValue() []byte commitDirty(level byte, maxTrieLevelInMemory uint, originDb common.DBWriteCacher, targetDb common.DBWriteCacher) error diff --git a/trie/leafNode.go b/trie/leafNode.go index a12d827f162..cb4c4bfdc76 100644 --- a/trie/leafNode.go +++ b/trie/leafNode.go @@ -413,13 +413,6 @@ func (ln *leafNode) getChildren(_ common.DBWriteCacher) ([]node, error) { return nil, nil } -func (ln *leafNode) getNumNodes() common.NumNodesDTO { - return common.NumNodesDTO{ - Leaves: 1, - MaxLevel: 1, - } -} - func (ln *leafNode) isValid() bool { return len(ln.Value) > 0 } diff --git a/trie/patriciaMerkleTrie.go b/trie/patriciaMerkleTrie.go index 300388e6cfc..a424b8a8bea 100644 --- a/trie/patriciaMerkleTrie.go +++ b/trie/patriciaMerkleTrie.go @@ -589,19 +589,6 @@ func (tr *patriciaMerkleTrie) VerifyProof(rootHash []byte, key []byte, proof [][ return false, nil } -// GetNumNodes will return the trie nodes statistics DTO -func (tr *patriciaMerkleTrie) GetNumNodes() common.NumNodesDTO { - tr.mutOperation.Lock() - defer tr.mutOperation.Unlock() - - n := tr.root - if check.IfNil(n) { - return common.NumNodesDTO{} - } - - return n.getNumNodes() -} - // GetStorageManager returns the storage manager for the trie func (tr *patriciaMerkleTrie) GetStorageManager() common.StorageManager { return tr.trieStorage diff --git a/trie/patriciaMerkleTrie_test.go b/trie/patriciaMerkleTrie_test.go index 6847403164d..fc9a23a1843 100644 --- a/trie/patriciaMerkleTrie_test.go +++ b/trie/patriciaMerkleTrie_test.go @@ -890,15 +890,6 @@ func dumpTrieContents(tr common.Trie, values [][]byte) { } } -func TestPatriciaMerkleTrie_GetNumNodesNilRootShouldReturnEmpty(t *testing.T) { - t.Parallel() - - tr := emptyTrie() - - numNodes := tr.GetNumNodes() - assert.Equal(t, common.NumNodesDTO{}, numNodes) -} - func TestPatriciaMerkleTrie_GetTrieStats(t *testing.T) { t.Parallel() @@ -928,21 +919,6 @@ func TestPatriciaMerkleTrie_GetTrieStats(t *testing.T) { assert.Equal(t, uint32(3), stats.MaxTrieDepth) } -func TestPatriciaMerkleTrie_GetNumNodes(t *testing.T) { - t.Parallel() - - tr := emptyTrie() - _ = tr.Update([]byte("eod"), []byte("reindeer")) - _ = tr.Update([]byte("god"), []byte("puppy")) - _ = tr.Update([]byte("eggod"), []byte("cat")) - - numNodes := tr.GetNumNodes() - assert.Equal(t, 5, numNodes.MaxLevel) - assert.Equal(t, 3, numNodes.Leaves) - assert.Equal(t, 2, numNodes.Extensions) - assert.Equal(t, 2, numNodes.Branches) -} - func TestPatriciaMerkleTrie_GetOldRoot(t *testing.T) { t.Parallel() @@ -979,7 +955,7 @@ func TestPatriciaMerkleTrie_ConcurrentOperations(t *testing.T) { numOperations := 1000 wg := sync.WaitGroup{} wg.Add(numOperations) - numFunctions := 20 + numFunctions := 19 initialRootHash, _ := tr.RootHash() @@ -1051,14 +1027,11 @@ func TestPatriciaMerkleTrie_ConcurrentOperations(t *testing.T) { // extremely hard to compute an existing hash due to concurrent changes. 
_, _ = tr.VerifyProof([]byte("dog"), []byte("puppy"), [][]byte{[]byte("proof1")}) // this might error due to concurrent operations that change the roothash case 16: - numNodes := tr.GetNumNodes() - assert.Equal(t, 4, numNodes.MaxLevel) - case 17: sm := tr.GetStorageManager() assert.NotNil(t, sm) - case 18: + case 17: _ = tr.GetOldRoot() - case 19: + case 18: trieStatsHandler := tr.(common.TrieStats) _, err := trieStatsHandler.GetTrieStats("address", initialRootHash) assert.Nil(t, err)
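For reference, the reworked old-data cleaner introduced in this change set can be exercised roughly as shown below. This is only an illustrative sketch assembled from the hunks above (ArgOldDataCleanerProvider, the ManagedPeersHolderStub and NodeTypeProviderStub test doubles, and the multi-key override in ShouldClean); the Example-test scaffolding and the exact file placement are assumptions and are not part of the change set itself.

```go
package clean_test

import (
	"fmt"

	"github.com/multiversx/mx-chain-core-go/core"
	"github.com/multiversx/mx-chain-go/config"
	"github.com/multiversx/mx-chain-go/storage/clean"
	"github.com/multiversx/mx-chain-go/testscommon"
	"github.com/multiversx/mx-chain-go/testscommon/nodeTypeProviderMock"
)

// ExampleNewOldDataCleanerProvider sketches how a multi-key observer ends up
// using the validator cleaning policy (hypothetical example file, assumed to
// sit next to the storage/clean package).
func ExampleNewOldDataCleanerProvider() {
	args := clean.ArgOldDataCleanerProvider{
		// the node reports itself as an observer...
		NodeTypeProvider: &nodeTypeProviderMock.NodeTypeProviderStub{
			GetTypeCalled: func() core.NodeType {
				return core.NodeTypeObserver
			},
		},
		// ...with observer cleaning disabled and validator cleaning enabled...
		PruningStorerConfig: config.StoragePruningConfig{
			ObserverCleanOldEpochsData:  false,
			ValidatorCleanOldEpochsData: true,
		},
		// ...but it runs in multi-key mode
		ManagedPeersHolder: &testscommon.ManagedPeersHolderStub{
			IsMultiKeyModeCalled: func() bool {
				return true
			},
		},
	}

	odcp, err := clean.NewOldDataCleanerProvider(args)
	if err != nil {
		panic(err)
	}

	// multi-key mode overrides the observer setting with the validator one
	fmt.Println(odcp.ShouldClean())
	// Output: true
}
```

In other words, an observer running in multi-key mode adopts the validator cleaning policy, which is exactly what the new "multi key observer should clean" case in oldDataCleanerProvider_test.go asserts.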