diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md
index 2bad3c860d5..f7d1741b59c 100644
--- a/.github/PULL_REQUEST_TEMPLATE.md
+++ b/.github/PULL_REQUEST_TEMPLATE.md
@@ -1,9 +1,9 @@
-## Description of the reasoning behind the pull request (what feature was missing / how the problem was manifesting itself / what was the motive behind the refactoring)
+## Reasoning behind the pull request
 -
 -
 -
-## Proposed Changes
+## Proposed changes
 -
 -
 -

diff --git a/.github/release.yml b/.github/release.yml
index 5d8e225d1c2..c54b26c2739 100644
--- a/.github/release.yml
+++ b/.github/release.yml
@@ -1,7 +1,7 @@
 changelog:
   exclude:
     labels:
-      - ignore-for-release
+      - ignore-for-release-notes
   categories:
     - title: Breaking Changes
       labels:
diff --git a/api/shared/logging/logging.go b/api/shared/logging/logging.go
index 829b74fb2dc..ede27e35f5f 100644
--- a/api/shared/logging/logging.go
+++ b/api/shared/logging/logging.go
@@ -8,6 +8,7 @@ import (
 )

 var log = logger.GetOrCreate("api/shared/logging")
+
 const thresholdMinAPICallDurationToLog = 200 * time.Millisecond

 // LogAPIActionDurationIfNeeded will log the duration of an action triggered by an API call if it's above a threshold
diff --git a/cmd/keygenerator/main.go b/cmd/keygenerator/main.go
index ba480c29915..f608ceb7ecd 100644
--- a/cmd/keygenerator/main.go
+++ b/cmd/keygenerator/main.go
@@ -2,6 +2,7 @@ package main

 import (
 	"bytes"
+	"crypto/rand"
 	"encoding/hex"
 	"encoding/pem"
 	"fmt"
@@ -19,6 +20,7 @@ import (
 	"github.com/ElrondNetwork/elrond-go-crypto/signing/ed25519"
 	"github.com/ElrondNetwork/elrond-go-crypto/signing/mcl"
 	logger "github.com/ElrondNetwork/elrond-go-logger"
+	libp2pCrypto "github.com/libp2p/go-libp2p-core/crypto"
 	"github.com/urfave/cli"
 )

@@ -33,6 +35,7 @@ type cfg struct {

 const validatorType = "validator"
 const walletType = "wallet"
+const p2pType = "p2p"
 const bothType = "both"
 const minedWalletPrefixKeys = "mined-wallet"
 const nopattern = "nopattern"
@@ -76,9 +79,10 @@ VERSION:
 	keyType = cli.StringFlag{
 		Name: "key-type",
 		Usage: fmt.Sprintf(
-			"What kind of keys should generate. Available options: %s, %s, %s, %s",
+			"What kind of keys to generate. Available options: %s, %s, %s, %s, %s",
 			validatorType,
 			walletType,
+			p2pType,
 			bothType,
 			minedWalletPrefixKeys),
 		Value: "validator",
@@ -116,10 +120,12 @@ VERSION:

 	walletKeyFilenameTemplate    = "walletKey%s.pem"
 	validatorKeyFilenameTemplate = "validatorKey%s.pem"
+	p2pKeyFilenameTemplate       = "p2pKey%s.pem"

 	log = logger.GetOrCreate("keygenerator")

 	validatorPubKeyConverter, _ = pubkeyConverter.NewHexPubkeyConverter(blsPubkeyLen)
+	p2pPubKeyConverter          = NewP2pConverter()
 	walletPubKeyConverter, _    = pubkeyConverter.NewBech32PubkeyConverter(txSignPubkeyLen, log)
 )
@@ -157,21 +163,22 @@ func main() {
 }

 func process() error {
-	validatorKeys, walletKeys, err := generateKeys(argsConfig.keyType, argsConfig.numKeys, argsConfig.prefixPattern, argsConfig.shardIDByte)
+	validatorKeys, walletKeys, p2pKeys, err := generateKeys(argsConfig.keyType, argsConfig.numKeys, argsConfig.prefixPattern, argsConfig.shardIDByte)
 	if err != nil {
 		return err
 	}

-	return outputKeys(validatorKeys, walletKeys, argsConfig.consoleOut, argsConfig.noSplit)
+	return outputKeys(validatorKeys, walletKeys, p2pKeys, argsConfig.consoleOut, argsConfig.noSplit)
 }

-func generateKeys(typeKey string, numKeys int, prefix string, shardID int) ([]key, []key, error) {
+func generateKeys(typeKey string, numKeys int, prefix string, shardID int) ([]key, []key, []key, error) {
 	if numKeys < 1 {
-		return nil, nil, fmt.Errorf("number of keys should be a number greater or equal to 1")
+		return nil, nil, nil, fmt.Errorf("number of keys should be a number greater or equal to 1")
 	}

 	validatorKeys := make([]key, 0)
 	walletKeys := make([]key, 0)
+	p2pKeys := make([]key, 0)

 	var err error
 	blockSigningGenerator := signing.NewKeyGenerator(mcl.NewSuiteBLS12())
@@ -182,35 +189,68 @@ func generateKeys(typeKey string, numKeys int, prefix string, shardID int) ([]ke
 		case validatorType:
 			validatorKeys, err = generateKey(blockSigningGenerator, validatorKeys)
 			if err != nil {
-				return nil, nil, err
+				return nil, nil, nil, err
 			}
 		case walletType:
 			walletKeys, err = generateKey(txSigningGenerator, walletKeys)
 			if err != nil {
-				return nil, nil, err
+				return nil, nil, nil, err
 			}
+		case p2pType:
+			p2pKeys, err = generateP2pKey(p2pKeys)
+			if err != nil {
+				return nil, nil, nil, err
+			}
+		// TODO: change this behaviour, maybe list of options instead of both type
 		case bothType:
 			validatorKeys, err = generateKey(blockSigningGenerator, validatorKeys)
 			if err != nil {
-				return nil, nil, err
+				return nil, nil, nil, err
 			}
 			walletKeys, err = generateKey(txSigningGenerator, walletKeys)
 			if err != nil {
-				return nil, nil, err
+				return nil, nil, nil, err
 			}
 		case minedWalletPrefixKeys:
 			walletKeys, err = generateMinedWalletKeys(txSigningGenerator, walletKeys, prefix, shardID)
 			if err != nil {
-				return nil, nil, err
+				return nil, nil, nil, err
 			}
 		default:
-			return nil, nil, fmt.Errorf("unknown key type %s", argsConfig.keyType)
+			return nil, nil, nil, fmt.Errorf("unknown key type %s", argsConfig.keyType)
 		}
 	}

-	return validatorKeys, walletKeys, nil
+	return validatorKeys, walletKeys, p2pKeys, nil
+}
+
+func generateP2pKey(list []key) ([]key, error) {
+	privateKey, publicKey, err := libp2pCrypto.GenerateSecp256k1Key(rand.Reader)
+	if err != nil {
+		return nil, err
+	}
+
+	skBytes, err := privateKey.Raw()
+	if err != nil {
+		return nil, err
+	}
+
+	pkBytes, err := publicKey.Raw()
+	if err != nil {
+		return nil, err
+	}
+
+	list = append(
+		list,
+		key{
+			skBytes: skBytes,
+			pkBytes: pkBytes,
+		},
+	)
+
+	return list, nil
 }

 func generateKey(keyGen crypto.KeyGenerator, list []key) ([]key, error) {
@@ -284,18 +324,19 @@ func keyInShard(keyBytes []byte, shardID byte) bool {
 func outputKeys(
 	validatorKeys []key,
 	walletKeys []key,
+	p2pKeys []key,
 	consoleOut bool,
 	noSplit bool,
 ) error {
 	if consoleOut {
-		return printKeys(validatorKeys, walletKeys)
+		return printKeys(validatorKeys, walletKeys, p2pKeys)
 	}

-	return saveKeys(validatorKeys, walletKeys, noSplit)
+	return saveKeys(validatorKeys, walletKeys, p2pKeys, noSplit)
 }

-func printKeys(validatorKeys []key, walletKeys []key) error {
-	if len(validatorKeys)+len(walletKeys) == 0 {
+func printKeys(validatorKeys, walletKeys, p2pKeys []key) error {
+	if len(validatorKeys)+len(walletKeys)+len(p2pKeys) == 0 {
 		return fmt.Errorf("internal error: no keys to print")
 	}
@@ -312,6 +353,12 @@ func printKeys(validatorKeys []key, walletKeys []key) error {
 			errFound = err
 		}
 	}
+	if len(p2pKeys) > 0 {
+		err := printSliceKeys("P2p keys:", p2pKeys, p2pPubKeyConverter)
+		if err != nil {
+			errFound = err
+		}
+	}

 	return errFound
 }
@@ -348,8 +395,8 @@ func writeKeyToStream(writer io.Writer, key key, pubkeyConverter core.PubkeyConv
 	return pem.Encode(writer, &blk)
 }

-func saveKeys(validatorKeys []key, walletKeys []key, noSplit bool) error {
-	if len(validatorKeys)+len(walletKeys) == 0 {
+func saveKeys(validatorKeys, walletKeys, p2pKeys []key, noSplit bool) error {
+	if len(validatorKeys)+len(walletKeys)+len(p2pKeys) == 0 {
 		return fmt.Errorf("internal error: no keys to save")
 	}
@@ -366,6 +413,12 @@ func saveKeys(validatorKeys []key, walletKeys []key, noSplit bool) error {
 			errFound = err
 		}
 	}
+	if len(p2pKeys) > 0 {
+		err := saveSliceKeys(p2pKeyFilenameTemplate, p2pKeys, p2pPubKeyConverter, noSplit)
+		if err != nil {
+			errFound = err
+		}
+	}

 	return errFound
 }
diff --git a/cmd/keygenerator/p2pConverter.go b/cmd/keygenerator/p2pConverter.go
new file mode 100644
index 00000000000..b771f5d5342
--- /dev/null
+++ b/cmd/keygenerator/p2pConverter.go
@@ -0,0 +1,45 @@
+package main
+
+import (
+	"fmt"
+
+	libp2pCrypto "github.com/libp2p/go-libp2p-core/crypto"
+	"github.com/libp2p/go-libp2p-core/peer"
+)
+
+type p2pConverter struct{}
+
+// NewP2pConverter creates a new instance of p2p converter
+func NewP2pConverter() *p2pConverter {
+	return &p2pConverter{}
+}
+
+// Len returns zero
+func (p *p2pConverter) Len() int {
+	return 0
+}
+
+// Decode does nothing
+func (p *p2pConverter) Decode(humanReadable string) ([]byte, error) {
+	return nil, fmt.Errorf("not implemented")
+}
+
+// Encode encodes a byte array representing a public key as a peer ID string
+func (p *p2pConverter) Encode(pkBytes []byte) string {
+	pubKey, err := libp2pCrypto.UnmarshalSecp256k1PublicKey(pkBytes)
+	if err != nil {
+		return ""
+	}
+
+	id, err := peer.IDFromPublicKey(pubKey)
+	if err != nil {
+		return ""
+	}
+
+	return id.Pretty()
+}
+
+// IsInterfaceNil returns true if there is no value under the interface
+func (p *p2pConverter) IsInterfaceNil() bool {
+	return p == nil
+}
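The converter above and generateP2pKey in main.go together cover the whole p2p identity flow: generate a secp256k1 key pair, keep the raw key bytes for the PEM file, and encode the public key as a libp2p peer ID. A minimal standalone sketch of that derivation, assuming only the go-libp2p-core APIs this diff already imports:

package main

import (
	"crypto/rand"
	"fmt"

	libp2pCrypto "github.com/libp2p/go-libp2p-core/crypto"
	"github.com/libp2p/go-libp2p-core/peer"
)

func main() {
	// generate a fresh secp256k1 key pair, as generateP2pKey does
	privateKey, publicKey, err := libp2pCrypto.GenerateSecp256k1Key(rand.Reader)
	if err != nil {
		panic(err)
	}

	// the raw secret key bytes are what end up in the p2pKey.pem file
	skBytes, err := privateKey.Raw()
	if err != nil {
		panic(err)
	}

	// the peer ID is what p2pConverter.Encode prints for the public key
	id, err := peer.IDFromPublicKey(publicKey)
	if err != nil {
		panic(err)
	}

	fmt.Printf("secret key: %d bytes, peer ID: %s\n", len(skBytes), id.Pretty())
}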
diff --git a/cmd/node/config/config.toml b/cmd/node/config/config.toml
index 4782d37742b..070c387984f 100644
--- a/cmd/node/config/config.toml
+++ b/cmd/node/config/config.toml
@@ -815,6 +815,11 @@
     [Debug.EpochStart]
         GoRoutineAnalyserEnabled = true
         ProcessDataTrieOnCommitEpoch = true
+    [Debug.Process]
+        Enabled = true
+        DebuggingLogLevel = "*:DEBUG,p2p:TRACE,debug:DEBUG,process:TRACE,intercept:TRACE"
+        GoRoutinesDump = true
+        PollingTimeInSeconds = 240 # 4 minutes

 [Health]
     IntervalVerifyMemoryInSeconds = 30
diff --git a/cmd/node/config/enableEpochs.toml b/cmd/node/config/enableEpochs.toml
index 48ab61d8ee9..fdf1e0a8771 100644
--- a/cmd/node/config/enableEpochs.toml
+++ b/cmd/node/config/enableEpochs.toml
@@ -210,13 +210,16 @@
     CheckExecuteOnReadOnlyEnableEpoch = 1

     # ESDTMetadataContinuousCleanupEnableEpoch represents the epoch when esdt metadata is automatically deleted according to inshard liquidity
-    ESDTMetadataContinuousCleanupEnableEpoch = 4
+    ESDTMetadataContinuousCleanupEnableEpoch = 1

     # HeartbeatDisableEpoch represents the epoch when heartbeat v1 messages stop being sent and processed
-    HeartbeatDisableEpoch = 2
+    HeartbeatDisableEpoch = 1

     # MiniBlockPartialExecutionEnableEpoch represents the epoch when mini block partial execution will be enabled
-    MiniBlockPartialExecutionEnableEpoch = 3
+    MiniBlockPartialExecutionEnableEpoch = 1
+
+    # FixAsyncCallBackArgsListEnableEpoch represents the epoch when the async callback arguments lists fix will be enabled
+    FixAsyncCallBackArgsListEnableEpoch = 1

     # SetSenderInEeiOutputTransferEnableEpoch represents the epoch when setting the sender in eei output transfers will be enabled
     SetSenderInEeiOutputTransferEnableEpoch = 4
@@ -224,6 +227,12 @@
     # RefactorPeersMiniBlocksEnableEpoch represents the epoch when refactor of the peers mini blocks will be enabled
     RefactorPeersMiniBlocksEnableEpoch = 5

+    # BLSMultiSignerEnableEpoch represents the activation epoch for different types of BLS multi-signers
+    BLSMultiSignerEnableEpoch = [
+        { EnableEpoch = 0, Type = "no-KOSK"},
+        { EnableEpoch = 3, Type = "KOSK"}
+    ]
+
     # MaxNodesChangeEnableEpoch holds configuration for changing the maximum number of nodes and the enabling epoch
     MaxNodesChangeEnableEpoch = [
         { EpochEnable = 0, MaxNumNodes = 36, NodesToShufflePerShard = 4 },
diff --git a/cmd/node/config/external.toml b/cmd/node/config/external.toml
index 897219be8b0..1fbf249fe71 100644
--- a/cmd/node/config/external.toml
+++ b/cmd/node/config/external.toml
@@ -34,6 +34,9 @@
         # Password is used to authorize an observer to push event data
        Password = ""

+        # RequestTimeoutSec defines the timeout in seconds for the http client
+        RequestTimeoutSec = 60
+
     # CovalentConnector defines settings related to covalent indexer
     [CovalentConnector]
         # This flag shall only be used for observer nodes
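The new RequestTimeoutSec value bounds the events notifier's HTTP calls. A hedged sketch of how a whole-seconds timeout is typically applied to a net/http client (the helper name and wiring are illustrative, not taken from the notifier code):

package main

import (
	"net/http"
	"time"
)

// EventNotifierConfig mirrors config.EventNotifierConfig from this diff.
type EventNotifierConfig struct {
	Enabled           bool
	UseAuthorization  bool
	ProxyUrl          string
	Username          string
	Password          string
	RequestTimeoutSec int
}

// newNotifierHTTPClient is a hypothetical helper: it converts the configured
// number of seconds into a time.Duration and sets it as the client timeout,
// so a stalled notifier endpoint cannot block the node indefinitely.
func newNotifierHTTPClient(cfg EventNotifierConfig) *http.Client {
	return &http.Client{
		Timeout: time.Duration(cfg.RequestTimeoutSec) * time.Second,
	}
}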
diff --git a/cmd/node/config/p2p.toml b/cmd/node/config/p2p.toml
index 890c887a03a..2ce99da3ba0 100644
--- a/cmd/node/config/p2p.toml
+++ b/cmd/node/config/p2p.toml
@@ -6,20 +6,6 @@
     #If the port = 0, the node will search for a free port on the machine and use it
     Port = "37373-38383"

-    #Seed represents the seed string generator for p2p identity (used during authentication and message passing).
-    #An empty Seed value will mean that the identity will be generated randomly in a secure cryptographically manner.
-    #The seed provided in this string can be of any length.
-    #########################################################################################
-    # WARNING! FOR SECURITY REASONS, ONE MIGHT USE A GENERATED STRING AS LONG AS POSSIBLE! #
-    # IT IS RECOMMENDED THAT THE SEED FIELD SHOULD REMAIN EMPTY (NO CHARACTERS BETWEEN "") #
-    # THIS SEED WILL BE USED FOR P2P'S PRIVATE KEY GENERATION. SAME SEED USED WILL LEAD TO #
-    # THE GENERATION OF THE SAME P2P IDENTITY.                                              #
-    # SPECIFY A SEED VALUE IF YOU KNOW WHAT YOU ARE DOING!                                  #
-    #########################################################################################
-    #The seed provided will be hashed using SHA256 and the resulting 32 byte length byte array will be used in
-    #p2p identity generation
-    Seed = ""
-
     #ThresholdMinConnectedPeers represents the minimum number of connections a node should have before it can start
     #the sync and consensus mechanisms
     ThresholdMinConnectedPeers = 3
diff --git a/cmd/node/flags.go b/cmd/node/flags.go
index eed68edd002..e276d39ac33 100644
--- a/cmd/node/flags.go
+++ b/cmd/node/flags.go
@@ -344,6 +344,13 @@ var (
 		Usage: "Boolean flag for enabling the node to generate a signing key when it starts (if the validatorKey.pem" +
 			" file is present, setting this flag to true will overwrite the BLS key used by the node)",
 	}
+
+	// p2pKeyPemFile defines the flag for the path to the key pem file used for p2p signing
+	p2pKeyPemFile = cli.StringFlag{
+		Name:  "p2p-key-pem-file",
+		Usage: "The `filepath` for the PEM file which contains the secret keys for the p2p key. If this is not specified a new key will be generated (internally) by default.",
+		Value: "./config/p2pKey.pem",
+	}
 )

 func getFlags() []cli.Flag {
@@ -397,6 +404,7 @@ func getFlags() []cli.Flag {
 		disableConsensusWatchdog,
 		serializeSnapshots,
 		noKey,
+		p2pKeyPemFile,
 	}
 }
@@ -434,6 +442,7 @@ func applyFlags(ctx *cli.Context, cfgs *config.Configs, flagsConfig *config.Cont
 	cfgs.ConfigurationPathsHolder.GasScheduleDirectoryName = ctx.GlobalString(gasScheduleConfigurationDirectory.Name)
 	cfgs.ConfigurationPathsHolder.SmartContracts = ctx.GlobalString(smartContractsFile.Name)
 	cfgs.ConfigurationPathsHolder.ValidatorKey = ctx.GlobalString(validatorKeyPemFile.Name)
+	cfgs.ConfigurationPathsHolder.P2pKey = ctx.GlobalString(p2pKeyPemFile.Name)

 	if ctx.IsSet(startInEpoch.Name) {
 		log.Debug("start in epoch is enabled")
diff --git a/cmd/seednode/config/p2p.toml b/cmd/seednode/config/p2p.toml
index b5e599479c2..5e13f92574f 100644
--- a/cmd/seednode/config/p2p.toml
+++ b/cmd/seednode/config/p2p.toml
@@ -6,20 +6,6 @@
     #Can use single values such as 0, 10230, 15670 or a range such as 5000-10000
     Port = "10000"

-    #Seed represents the seed string generator for p2p identity (used during authentication and message passing).
-    #An empty Seed value will mean that the identity will be generated randomly in a secure cryptographically manner.
-    #The seed provided in this string can be of any length.
-    #########################################################################################
-    # WARNING! FOR SECURITY REASONS, ONE MIGHT USE A GENERATED STRING AS LONG AS POSSIBLE! #
-    # IT IS RECOMMENDED THAT THE SEED FIELD SHOULD REMAIN EMPTY (NO CHARACTERS BETWEEN "") #
-    # THIS SEED WILL BE USED FOR P2P'S PRIVATE KEY GENERATION. SAME SEED USED WILL LEAD TO #
-    # THE GENERATION OF THE SAME P2P IDENTITY.                                              #
-    # SPECIFY A SEED VALUE IF YOU KNOW WHAT YOU ARE DOING!                                  #
-    #########################################################################################
-    #The seed provided will be hashed using SHA256 and the resulting 32 byte length byte array will be used in
-    #p2p identity generation
-    Seed = "seed"
-
     # The maximum peers that will connect to this node
     MaximumExpectedPeerCount = 1024
diff --git a/cmd/seednode/main.go b/cmd/seednode/main.go
index ad66dc6ee4a..ad777f11c42 100644
--- a/cmd/seednode/main.go
+++ b/cmd/seednode/main.go
@@ -23,7 +23,7 @@ import (
 	"github.com/ElrondNetwork/elrond-go/epochStart/bootstrap/disabled"
 	"github.com/ElrondNetwork/elrond-go/facade"
 	"github.com/ElrondNetwork/elrond-go/p2p"
-	"github.com/ElrondNetwork/elrond-go/p2p/libp2p"
+	p2pConfig "github.com/ElrondNetwork/elrond-go/p2p/config"
 	"github.com/urfave/cli"
 )
@@ -63,12 +63,6 @@ VERSION:
 			"To bind to all available interfaces, set this flag to :8080. If set to `off` then the API won't be available",
 		Value: facade.DefaultRestInterface,
 	}
-	// p2pSeed defines a flag to be used as a seed when generating P2P credentials. Useful for seed nodes.
-	p2pSeed = cli.StringFlag{
-		Name:  "p2p-seed",
-		Usage: "P2P seed will be used when generating credentials for p2p component. Can be any string.",
-		Value: "seed",
-	}
 	// logLevel defines the logger level
 	logLevel = cli.StringFlag{
 		Name: "log-level",
@@ -90,6 +84,13 @@ VERSION:
 			"configurations such as the marshalizer type",
 		Value: "./config/config.toml",
 	}
+	// p2pKeyPemFile defines the flag for the path to the key pem file used for p2p signing
+	p2pKeyPemFile = cli.StringFlag{
+		Name:  "p2p-key-pem-file",
+		Usage: "The `filepath` for the PEM file which contains the secret keys for the p2p key. If this is not specified a new key will be generated (internally) by default.",
+		Value: "./config/p2pKey.pem",
+	}
+
 	p2pConfigurationFile = "./config/p2p.toml"
 )
@@ -103,10 +104,10 @@ func main() {
 	app.Flags = []cli.Flag{
 		port,
 		restApiInterfaceFlag,
-		p2pSeed,
 		logLevel,
 		logSaveFile,
 		configurationFile,
+		p2pKeyPemFile,
 	}
 	app.Version = "v0.0.1"
 	app.Authors = []cli.Author{
@@ -176,7 +177,7 @@ func startNode(ctx *cli.Context) error {
 	sigs := make(chan os.Signal, 1)
 	signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)

-	p2pConfig, err := common.LoadP2PConfig(p2pConfigurationFile)
+	p2pCfg, err := common.LoadP2PConfig(p2pConfigurationFile)
 	if err != nil {
 		return err
 	}
@@ -184,18 +185,21 @@ func startNode(ctx *cli.Context) error {
 		"filename", p2pConfigurationFile,
 	)
 	if ctx.IsSet(port.Name) {
-		p2pConfig.Node.Port = ctx.GlobalString(port.Name)
+		p2pCfg.Node.Port = ctx.GlobalString(port.Name)
 	}
-	if ctx.IsSet(p2pSeed.Name) {
-		p2pConfig.Node.Seed = ctx.GlobalString(p2pSeed.Name)
+
+	err = checkExpectedPeerCount(*p2pCfg)
+	if err != nil {
+		return err
 	}

-	err = checkExpectedPeerCount(*p2pConfig)
+	p2pKeyPemFileName := ctx.GlobalString(p2pKeyPemFile.Name)
+	p2pKeyBytes, err := common.GetSkBytesFromP2pKey(p2pKeyPemFileName)
 	if err != nil {
 		return err
 	}

-	messenger, err := createNode(*p2pConfig, internalMarshalizer)
+	messenger, err := createNode(*p2pCfg, internalMarshalizer, p2pKeyBytes)
 	if err != nil {
 		return err
 	}
@@ -240,19 +244,20 @@ func loadMainConfig(filepath string) (*config.Config, error) {
 	return cfg, nil
 }

-func createNode(p2pConfig config.P2PConfig, marshalizer marshal.Marshalizer) (p2p.Messenger, error) {
-	arg := libp2p.ArgsNetworkMessenger{
+func createNode(p2pConfig p2pConfig.P2PConfig, marshalizer marshal.Marshalizer, p2pKeyBytes []byte) (p2p.Messenger, error) {
+	arg := p2p.ArgsNetworkMessenger{
 		Marshalizer:           marshalizer,
-		ListenAddress:         libp2p.ListenAddrWithIp4AndTcp,
+		ListenAddress:         p2p.ListenAddrWithIp4AndTcp,
 		P2pConfig:             p2pConfig,
-		SyncTimer:             &libp2p.LocalSyncTimer{},
+		SyncTimer:             &p2p.LocalSyncTimer{},
 		PreferredPeersHolder:  disabled.NewPreferredPeersHolder(),
 		NodeOperationMode:     p2p.NormalOperation,
 		PeersRatingHandler:    disabled.NewDisabledPeersRatingHandler(),
 		ConnectionWatcherType: "disabled",
+		P2pPrivateKeyBytes:    p2pKeyBytes,
 	}

-	return libp2p.NewNetworkMessenger(arg)
+	return p2p.NewNetworkMessenger(arg)
 }

 func displayMessengerInfo(messenger p2p.Messenger) {
@@ -295,7 +300,7 @@ func getWorkingDir(log logger.Logger) string {
 	return workingDir
 }

-func checkExpectedPeerCount(p2pConfig config.P2PConfig) error {
+func checkExpectedPeerCount(p2pConfig p2pConfig.P2PConfig) error {
 	maxExpectedPeerCount := p2pConfig.Node.MaximumExpectedPeerCount

 	var rLimit syscall.Rlimit
diff --git a/common/configParser.go b/common/configParser.go
index 76dd74a99b7..d9444ef714d 100644
--- a/common/configParser.go
+++ b/common/configParser.go
@@ -1,15 +1,18 @@
 package common

 import (
+	"encoding/hex"
 	"fmt"
+	"os"

 	"github.com/ElrondNetwork/elrond-go-core/core"
 	"github.com/ElrondNetwork/elrond-go/config"
+	p2pConfig "github.com/ElrondNetwork/elrond-go/p2p/config"
 )

 // LoadP2PConfig returns a P2PConfig by reading the config file provided
-func LoadP2PConfig(filepath string) (*config.P2PConfig, error) {
-	cfg := &config.P2PConfig{}
+func LoadP2PConfig(filepath string) (*p2pConfig.P2PConfig, error) {
+	cfg := &p2pConfig.P2PConfig{}
 	err := core.LoadTomlFile(cfg, filepath)
 	if err != nil {
 		return nil, err
@@ -134,3 +137,29 @@ func LoadRoundConfig(filePath string) (*config.RoundConfig, error) {

 	return cfg, nil
 }
+
+// GetSkBytesFromP2pKey will read the key file at the provided path. If no valid filename
+// is provided, it will return an empty byte array; otherwise it will try to fetch the
+// private key and return the decoded byte array.
+func GetSkBytesFromP2pKey(p2pKeyFilename string) ([]byte, error) {
+	if len(p2pKeyFilename) == 0 {
+		return []byte{}, nil
+	}
+
+	skIndex := 0
+	encodedSk, _, err := core.LoadSkPkFromPemFile(p2pKeyFilename, skIndex)
+	if err != nil {
+		if os.IsNotExist(err) {
+			return []byte{}, nil
+		}
+
+		return nil, err
+	}
+
+	skBytes, err := hex.DecodeString(string(encodedSk))
+	if err != nil {
+		return nil, fmt.Errorf("%w for encoded secret key", err)
+	}
+
+	return skBytes, nil
+}
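The loader above expects a PEM file whose block body is the hex encoding of the secret key (that is what it passes through hex.DecodeString). A minimal round-trip sketch using only the standard library; the block type and file name are illustrative assumptions, and the real files are produced by writeKeyToStream in cmd/keygenerator:

package main

import (
	"encoding/hex"
	"encoding/pem"
	"fmt"
	"os"
)

func main() {
	secretKey := []byte{0x01, 0x02, 0x03, 0x04} // placeholder key bytes

	// writing: the PEM body carries the hex encoding of the secret key
	blk := pem.Block{
		Type:  "PRIVATE KEY", // illustrative; the real block type also names the key's identity
		Bytes: []byte(hex.EncodeToString(secretKey)),
	}
	if err := os.WriteFile("p2pKey.pem", pem.EncodeToMemory(&blk), 0600); err != nil {
		panic(err)
	}

	// reading: mirrors GetSkBytesFromP2pKey without the elrond-go-core helper
	raw, err := os.ReadFile("p2pKey.pem")
	if err != nil {
		panic(err)
	}
	decoded, _ := pem.Decode(raw)
	skBytes, err := hex.DecodeString(string(decoded.Bytes))
	if err != nil {
		panic(err)
	}

	fmt.Printf("round-tripped %d key bytes\n", len(skBytes))
}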
diff --git a/common/constants.go b/common/constants.go
index 13573db3b80..5881cd94676 100644
--- a/common/constants.go
+++ b/common/constants.go
@@ -75,15 +75,6 @@ const ConnectionTopic = "connection"
 // ValidatorInfoTopic is the topic used for validatorInfo signaling
 const ValidatorInfoTopic = "validatorInfo"

-// PathShardPlaceholder represents the placeholder for the shard ID in paths
-const PathShardPlaceholder = "[S]"
-
-// PathEpochPlaceholder represents the placeholder for the epoch number in paths
-const PathEpochPlaceholder = "[E]"
-
-// PathIdentifierPlaceholder represents the placeholder for the identifier in paths
-const PathIdentifierPlaceholder = "[I]"
-
 // MetricCurrentRound is the metric for monitoring the current round of a node
 const MetricCurrentRound = "erd_current_round"
@@ -301,7 +292,7 @@ const MetricCreatedProposedBlock = "erd_consensus_created_proposed_block"
 // MetricRedundancyLevel is the metric that specifies the redundancy level of the current node
 const MetricRedundancyLevel = "erd_redundancy_level"

-// MetricRedundancyIsMainActive is the metrics that specifies data about the redundancy main machine
+// MetricRedundancyIsMainActive is the metric that specifies data about the redundancy main machine
 const MetricRedundancyIsMainActive = "erd_redundancy_is_main_active"

 // MetricValueNA represents the value to be used when a metric is not available/applicable
@@ -683,27 +674,12 @@ const WrongConfiguration = "wrongConfiguration"
 // ImportComplete signals that a node restart will be done because the import did complete
 const ImportComplete = "importComplete"

-// MaxRetriesToCreateDB represents the maximum number of times to try to create DB if it failed
-const MaxRetriesToCreateDB = 10
-
-// SleepTimeBetweenCreateDBRetries represents the number of seconds to sleep between DB creates
-const SleepTimeBetweenCreateDBRetries = 5 * time.Second
-
 // DefaultStatsPath is the default path where the node stats are logged
 const DefaultStatsPath = "stats"

 // DefaultDBPath is the default path for nodes databases
 const DefaultDBPath = "db"

-// DefaultEpochString is the default folder root name for node per epoch databases
-const DefaultEpochString = "Epoch"
-
-// DefaultStaticDbString is the default name for the static databases (not changing with epoch)
-const DefaultStaticDbString = "Static"
-
-// DefaultShardString is the default folder root name for per shard databases
-const DefaultShardString = "Shard"
-
 // MetachainShardName is the string identifier of the metachain shard
 const MetachainShardName = "metachain"
diff --git a/common/crypto/interface.go b/common/crypto/interface.go
new file mode 100644
index 00000000000..3e7f1bbeb6f
--- /dev/null
+++ b/common/crypto/interface.go
@@ -0,0 +1,9 @@
+package crypto
+
+import crypto "github.com/ElrondNetwork/elrond-go-crypto"
+
+// MultiSignerContainer defines the container for different versioned multiSigner instances
+type MultiSignerContainer interface {
+	GetMultiSigner(epoch uint32) (crypto.MultiSigner, error)
+	IsInterfaceNil() bool
+}
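A hedged sketch of how such a container can map the BLSMultiSignerEnableEpoch entries from enableEpochs.toml onto concrete signers: the container type and its fields below are illustrative assumptions; only the MultiSignerContainer interface comes from the diff.

package crypto

import (
	"errors"

	crypto "github.com/ElrondNetwork/elrond-go-crypto"
)

// epochMultiSigner pairs a multi-signer with the epoch it activates in
// (hypothetical helper type, not part of the diff)
type epochMultiSigner struct {
	enableEpoch uint32
	signer      crypto.MultiSigner
}

// container keeps signers sorted ascending by activation epoch
type container struct {
	signers []epochMultiSigner
}

// GetMultiSigner returns the signer with the highest activation epoch
// that is still <= the requested epoch, matching the usual
// enable-epochs selection rule
func (c *container) GetMultiSigner(epoch uint32) (crypto.MultiSigner, error) {
	for i := len(c.signers) - 1; i >= 0; i-- {
		if c.signers[i].enableEpoch <= epoch {
			return c.signers[i].signer, nil
		}
	}

	return nil, errors.New("no multi-signer defined for the provided epoch")
}

// IsInterfaceNil returns true if there is no value under the interface
func (c *container) IsInterfaceNil() bool {
	return c == nil
}

With the config shown earlier ({0, "no-KOSK"} and {3, "KOSK"}), epochs 0-2 would resolve to the no-KOSK signer and epochs 3+ to the KOSK one.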
diff --git a/common/enablers/enableEpochsHandler.go b/common/enablers/enableEpochsHandler.go
index 4b4877f102c..ec7ee27d8a6 100644
--- a/common/enablers/enableEpochsHandler.go
+++ b/common/enablers/enableEpochsHandler.go
@@ -112,6 +112,7 @@ func (handler *enableEpochsHandler) EpochConfirmed(epoch uint32, _ uint64) {
 	handler.setFlagValue(epoch >= handler.enableEpochsConfig.SetSenderInEeiOutputTransferEnableEpoch, handler.setSenderInEeiOutputTransferFlag, "setSenderInEeiOutputTransferFlag")
 	handler.setFlagValue(epoch >= handler.enableEpochsConfig.ESDTMetadataContinuousCleanupEnableEpoch, handler.changeDelegationOwnerFlag, "changeDelegationOwnerFlag")
 	handler.setFlagValue(epoch >= handler.enableEpochsConfig.RefactorPeersMiniBlocksEnableEpoch, handler.refactorPeersMiniBlocksFlag, "refactorPeersMiniBlocksFlag")
+	handler.setFlagValue(epoch >= handler.enableEpochsConfig.FixAsyncCallBackArgsListEnableEpoch, handler.fixAsyncCallBackArgsList, "fixAsyncCallBackArgsList")
 }

 func (handler *enableEpochsHandler) setFlagValue(value bool, flag *atomic.Flag, flagName string) {
diff --git a/common/enablers/enableEpochsHandler_test.go b/common/enablers/enableEpochsHandler_test.go
index 86f050aee30..e4fa7a3f930 100644
--- a/common/enablers/enableEpochsHandler_test.go
+++ b/common/enablers/enableEpochsHandler_test.go
@@ -85,6 +85,7 @@ func createEnableEpochsConfig() config.EnableEpochs {
 		RefactorContextEnableEpoch:               69,
 		CheckFunctionArgumentEnableEpoch:         70,
 		CheckExecuteOnReadOnlyEnableEpoch:        71,
+		FixAsyncCallBackArgsListEnableEpoch:      72,
 	}
 }
@@ -123,7 +124,7 @@ func TestNewEnableEpochsHandler_EpochConfirmed(t *testing.T) {
 		handler, _ := NewEnableEpochsHandler(cfg, &epochNotifier.EpochNotifierStub{})
 		require.False(t, check.IfNil(handler))

-		handler.EpochConfirmed(73, 0)
+		handler.EpochConfirmed(75, 0)

 		assert.Equal(t, cfg.BlockGasAndFeesReCheckEnableEpoch, handler.BlockGasAndFeesReCheckEnableEpoch())
 		assert.True(t, handler.IsSCDeployFlagEnabled())
@@ -297,6 +298,7 @@ func TestNewEnableEpochsHandler_EpochConfirmed(t *testing.T) {
 		assert.True(t, handler.IsCheckFunctionArgumentFlagEnabled())
 		assert.True(t, handler.IsCheckExecuteOnReadOnlyFlagEnabled())
 		assert.True(t, handler.IsChangeDelegationOwnerFlagEnabled())
+		assert.True(t, handler.IsFixAsyncCallBackArgsListFlagEnabled())
 	})
 	t.Run("flags with < should be set", func(t *testing.T) {
 		t.Parallel()
@@ -386,5 +388,6 @@ func TestNewEnableEpochsHandler_EpochConfirmed(t *testing.T) {
 		assert.False(t, handler.IsCheckFunctionArgumentFlagEnabled())
 		assert.False(t, handler.IsCheckExecuteOnReadOnlyFlagEnabled())
 		assert.False(t, handler.IsChangeDelegationOwnerFlagEnabled())
+		assert.False(t, handler.IsFixAsyncCallBackArgsListFlagEnabled())
 	})
 }
diff --git a/common/enablers/epochFlags.go b/common/enablers/epochFlags.go
index 3960d990ea3..65550b8d89b 100644
--- a/common/enablers/epochFlags.go
+++ b/common/enablers/epochFlags.go
@@ -82,6 +82,7 @@ type epochFlagsHolder struct {
 	setSenderInEeiOutputTransferFlag *atomic.Flag
 	changeDelegationOwnerFlag        *atomic.Flag
 	refactorPeersMiniBlocksFlag      *atomic.Flag
+	fixAsyncCallBackArgsList         *atomic.Flag
 }

 func newEpochFlagsHolder() *epochFlagsHolder {
@@ -165,6 +166,7 @@ func newEpochFlagsHolder() *epochFlagsHolder {
 		setSenderInEeiOutputTransferFlag: &atomic.Flag{},
 		changeDelegationOwnerFlag:        &atomic.Flag{},
 		refactorPeersMiniBlocksFlag:      &atomic.Flag{},
+		fixAsyncCallBackArgsList:         &atomic.Flag{},
 	}
 }
@@ -615,3 +617,8 @@ func (holder *epochFlagsHolder) IsChangeDelegationOwnerFlagEnabled() bool {
 func (holder *epochFlagsHolder) IsRefactorPeersMiniBlocksFlagEnabled() bool {
 	return holder.refactorPeersMiniBlocksFlag.IsSet()
 }
+
+// IsFixAsyncCallBackArgsListFlagEnabled returns true if fixAsyncCallBackArgsList is enabled
+func (holder *epochFlagsHolder) IsFixAsyncCallBackArgsListFlagEnabled() bool {
+	return holder.fixAsyncCallBackArgsList.IsSet()
+}
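The pattern above is uniform across all enable-epochs features: on every confirmed epoch, the configured activation epoch is compared against the current one and the result is mirrored into an atomic flag that the rest of the code reads lock-free. A minimal standalone sketch of that gating, using sync/atomic's Bool (Go 1.19+) instead of the ElrondNetwork atomic.Flag so the example stays dependency-free:

package main

import (
	"fmt"
	"sync/atomic"
)

type enableEpochsConfig struct {
	FixAsyncCallBackArgsListEnableEpoch uint32
}

type flagsHolder struct {
	fixAsyncCallBackArgsList atomic.Bool
}

// EpochConfirmed mirrors the handler above: the flag is set exactly when
// the confirmed epoch has reached the configured activation epoch
func (f *flagsHolder) EpochConfirmed(cfg enableEpochsConfig, epoch uint32) {
	f.fixAsyncCallBackArgsList.Store(epoch >= cfg.FixAsyncCallBackArgsListEnableEpoch)
}

func main() {
	cfg := enableEpochsConfig{FixAsyncCallBackArgsListEnableEpoch: 1}
	flags := &flagsHolder{}

	flags.EpochConfirmed(cfg, 0)
	fmt.Println(flags.fixAsyncCallBackArgsList.Load()) // false

	flags.EpochConfirmed(cfg, 1)
	fmt.Println(flags.fixAsyncCallBackArgsList.Load()) // true
}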
diff --git a/common/interface.go b/common/interface.go
index 78b57d4a71e..e002db2e372 100644
--- a/common/interface.go
+++ b/common/interface.go
@@ -34,16 +34,29 @@ type Trie interface {
 	GetSerializedNodes([]byte, uint64) ([][]byte, uint64, error)
 	GetSerializedNode([]byte) ([]byte, error)
 	GetNumNodes() NumNodesDTO
-	GetAllLeavesOnChannel(leavesChannel chan core.KeyValueHolder, ctx context.Context, rootHash []byte) error
+	GetAllLeavesOnChannel(leavesChannel chan core.KeyValueHolder, ctx context.Context, rootHash []byte, keyBuilder KeyBuilder) error
 	GetAllHashes() ([][]byte, error)
 	GetProof(key []byte) ([][]byte, []byte, error)
 	VerifyProof(rootHash []byte, key []byte, proof [][]byte) (bool, error)
 	GetStorageManager() StorageManager
-	MarkStorerAsSyncedAndActive()
 	Close() error
 	IsInterfaceNil() bool
 }

+// KeyBuilder is used for building trie keys as you traverse the trie
+type KeyBuilder interface {
+	BuildKey(keyPart []byte)
+	GetKey() ([]byte, error)
+	Clone() KeyBuilder
+}
+
+// DataTrieHandler is an interface that declares the methods used for dataTries
+type DataTrieHandler interface {
+	RootHash() ([]byte, error)
+	GetAllLeavesOnChannel(leavesChannel chan core.KeyValueHolder, ctx context.Context, rootHash []byte, keyBuilder KeyBuilder) error
+	IsInterfaceNil() bool
+}
+
 // StorageManager manages all trie storage operations
 type StorageManager interface {
 	Get(key []byte) ([]byte, error)
@@ -51,8 +64,8 @@ type StorageManager interface {
 	Put(key []byte, val []byte) error
 	PutInEpoch(key []byte, val []byte, epoch uint32) error
 	PutInEpochWithoutCache(key []byte, val []byte, epoch uint32) error
-	TakeSnapshot(rootHash []byte, mainTrieRootHash []byte, leavesChan chan core.KeyValueHolder, errChan chan error, stats SnapshotStatisticsHandler, epoch uint32)
-	SetCheckpoint(rootHash []byte, mainTrieRootHash []byte, leavesChan chan core.KeyValueHolder, errChan chan error, stats SnapshotStatisticsHandler)
+	TakeSnapshot(rootHash []byte, mainTrieRootHash []byte, leavesChan chan core.KeyValueHolder, missingNodesChan chan []byte, errChan chan error, stats SnapshotStatisticsHandler, epoch uint32)
+	SetCheckpoint(rootHash []byte, mainTrieRootHash []byte, leavesChan chan core.KeyValueHolder, missingNodesChan chan []byte, errChan chan error, stats SnapshotStatisticsHandler)
 	GetLatestStorageEpoch() (uint32, error)
 	IsPruningEnabled() bool
 	IsPruningBlocked() bool
@@ -291,6 +304,7 @@ type EnableEpochsHandler interface {
 	IsSetSenderInEeiOutputTransferFlagEnabled() bool
 	IsChangeDelegationOwnerFlagEnabled() bool
 	IsRefactorPeersMiniBlocksFlagEnabled() bool
+	IsFixAsyncCallBackArgsListFlagEnabled() bool
 	IsInterfaceNil() bool
 }
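A sketch of one possible implementation of the new KeyBuilder contract. The plain byte-concatenation semantics are an illustrative assumption (the real trie key builder may work differently, e.g. on hex nibbles); only the interface itself is from the diff:

package common

import "errors"

// simpleKeyBuilder is an illustrative KeyBuilder: it accumulates the key
// parts met while traversing the trie and returns their concatenation
type simpleKeyBuilder struct {
	key []byte
}

// BuildKey appends one more key part to the key under construction
func (skb *simpleKeyBuilder) BuildKey(keyPart []byte) {
	skb.key = append(skb.key, keyPart...)
}

// GetKey returns the assembled key
func (skb *simpleKeyBuilder) GetKey() ([]byte, error) {
	if len(skb.key) == 0 {
		return nil, errors.New("no key part was provided")
	}

	return skb.key, nil
}

// Clone returns an independent copy so each trie branch can extend
// its own key without affecting its siblings
func (skb *simpleKeyBuilder) Clone() KeyBuilder {
	keyCopy := make([]byte, len(skb.key))
	copy(keyCopy, skb.key)

	return &simpleKeyBuilder{key: keyCopy}
}

The Clone method is the interesting design point: GetAllLeavesOnChannel hands each subtree its own builder copy, so concurrent traversal of sibling branches cannot corrupt a shared key buffer.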
diff --git a/config/config.go b/config/config.go
index 90ee12ed86e..05c2b21b60c 100644
--- a/config/config.go
+++ b/config/config.go
@@ -1,5 +1,7 @@
 package config

+import p2pConfig "github.com/ElrondNetwork/elrond-go/p2p/config"
+
 // CacheConfig will map the cache configuration
 type CacheConfig struct {
 	Name string
@@ -266,6 +268,12 @@ type MaxNodesChangeConfig struct {
 	NodesToShufflePerShard uint32
 }

+// MultiSignerConfig defines a config tuple for a BLS multi-signer that activates in a certain epoch
+type MultiSignerConfig struct {
+	EnableEpoch uint32
+	Type        string
+}
+
 // GeneralSettingsConfig will hold the general settings for a node
 type GeneralSettingsConfig struct {
 	StatusPollingIntervalSec int
@@ -466,6 +474,7 @@ type DebugConfig struct {
 	Antiflood  AntifloodDebugConfig
 	ShuffleOut ShuffleOutDebugConfig
 	EpochStart EpochStartDebugConfig
+	Process    ProcessDebugConfig
 }

 // HealthServiceConfig will hold health service (monitoring) configuration
@@ -509,6 +518,14 @@ type EpochStartDebugConfig struct {
 	ProcessDataTrieOnCommitEpoch bool
 }

+// ProcessDebugConfig will hold the process debug configuration
+type ProcessDebugConfig struct {
+	Enabled              bool
+	GoRoutinesDump       bool
+	DebuggingLogLevel    string
+	PollingTimeInSeconds int
+}
+
 // ApiRoutesConfig holds the configuration related to Rest API routes
 type ApiRoutesConfig struct {
 	Logging ApiLoggingConfig
@@ -554,7 +571,7 @@ type Configs struct {
 	RatingsConfig            *RatingsConfig
 	PreferencesConfig        *Preferences
 	ExternalConfig           *ExternalConfig
-	P2pConfig                *P2PConfig
+	P2pConfig                *p2pConfig.P2PConfig
 	FlagsConfig              *ContextFlagsConfig
 	ImportDbConfig           *ImportDbConfig
 	ConfigurationPathsHolder *ConfigurationPathsHolder
@@ -579,6 +596,7 @@ type ConfigurationPathsHolder struct {
 	ValidatorKey    string
 	Epoch           string
 	RoundActivation string
+	P2pKey          string
 }

 // TrieSyncConfig represents the trie synchronization configuration area
diff --git a/config/epochConfig.go b/config/epochConfig.go
index 2496a3250d3..d2fd5f7f44c 100644
--- a/config/epochConfig.go
+++ b/config/epochConfig.go
@@ -85,8 +85,10 @@ type EnableEpochs struct {
 	HeartbeatDisableEpoch                    uint32
 	MiniBlockPartialExecutionEnableEpoch     uint32
 	ESDTMetadataContinuousCleanupEnableEpoch uint32
+	FixAsyncCallBackArgsListEnableEpoch      uint32
 	SetSenderInEeiOutputTransferEnableEpoch  uint32
 	RefactorPeersMiniBlocksEnableEpoch       uint32
+	BLSMultiSignerEnableEpoch                []MultiSignerConfig
 }

 // GasScheduleByEpochs represents a gas schedule toml entry that will be applied from the provided epoch
diff --git a/config/externalConfig.go b/config/externalConfig.go
index d4a869bdf4c..47f7b8d270c 100644
--- a/config/externalConfig.go
+++ b/config/externalConfig.go
@@ -21,11 +21,12 @@ type ElasticSearchConfig struct {

 // EventNotifierConfig will hold the configuration for the events notifier driver
 type EventNotifierConfig struct {
-	Enabled          bool
-	UseAuthorization bool
-	ProxyUrl         string
-	Username         string
-	Password         string
+	Enabled           bool
+	UseAuthorization  bool
+	ProxyUrl          string
+	Username          string
+	Password          string
+	RequestTimeoutSec int
 }

 // CovalentConfig will hold the configurations for covalent indexer
diff --git a/config/p2pConfig.go b/config/p2pConfig.go
deleted file mode 100644
index adfb0976b68..00000000000
--- a/config/p2pConfig.go
+++ /dev/null
@@ -1,46 +0,0 @@
-package config
-
-// P2PConfig will hold all the P2P settings
-type P2PConfig struct {
-	Node                NodeConfig
-	KadDhtPeerDiscovery KadDhtPeerDiscoveryConfig
-	Sharding            ShardingConfig
-}
-
-// NodeConfig will hold basic p2p settings
-type NodeConfig struct {
-	Port                            string
-	Seed                            string
-	MaximumExpectedPeerCount        uint64
-	ThresholdMinConnectedPeers      uint32
-	MinNumPeersToWaitForOnBootstrap uint32
-}
-
-// KadDhtPeerDiscoveryConfig will hold the kad-dht discovery config settings
-type KadDhtPeerDiscoveryConfig struct {
-	Enabled                          bool
-	Type                             string
-	RefreshIntervalInSec             uint32
-	ProtocolID                       string
-	InitialPeerList                  []string
-	BucketSize                       uint32
-	RoutingTableRefreshIntervalInSec uint32
-}
-
-// ShardingConfig will hold the network sharding config settings
-type ShardingConfig struct {
-	TargetPeerCount         uint32
-	MaxIntraShardValidators uint32
-	MaxCrossShardValidators uint32
-	MaxIntraShardObservers  uint32
-	MaxCrossShardObservers  uint32
-	MaxSeeders              uint32
-	Type                    string
-	AdditionalConnections   AdditionalConnectionsConfig
-}
-
-// AdditionalConnectionsConfig will hold the additional connections that will be open when certain conditions are met
-// All these values should be added to the maximum target peer count value
-type AdditionalConnectionsConfig struct {
-	MaxFullHistoryObservers uint32
-}
diff --git a/config/tomlConfig_test.go b/config/tomlConfig_test.go
index 58dbe73a39b..ae2952592cd 100644
--- a/config/tomlConfig_test.go
+++ b/config/tomlConfig_test.go
@@ -5,6 +5,7 @@ import (
 	"strconv"
 	"testing"

+	p2pConfig "github.com/ElrondNetwork/elrond-go/p2p/config"
 	"github.com/pelletier/go-toml"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
@@ -434,14 +435,12 @@ func TestP2pConfig(t *testing.T) {
 	initialPeersList := "/ip4/127.0.0.1/tcp/9999/p2p/16Uiu2HAkw5SNNtSvH1zJiQ6Gc3WoGNSxiyNueRKe6fuAuh57G3Bk"
 	protocolID := "test protocol id"
 	shardingType := "ListSharder"
-	seed := "test seed"
 	port := "37373-38383"

 	testString := `
 #P2P config file
 [Node]
     Port = "` + port + `"
-    Seed = "` + seed + `"
     ThresholdMinConnectedPeers = 0

 [KadDhtPeerDiscovery]
@@ -469,20 +468,19 @@ func TestP2pConfig(t *testing.T) {
     [AdditionalConnections]
         MaxFullHistoryObservers = 0`

-	expectedCfg := P2PConfig{
-		Node: NodeConfig{
+	expectedCfg := p2pConfig.P2PConfig{
+		Node: p2pConfig.NodeConfig{
 			Port: port,
-			Seed: seed,
 		},
-		KadDhtPeerDiscovery: KadDhtPeerDiscoveryConfig{
+		KadDhtPeerDiscovery: p2pConfig.KadDhtPeerDiscoveryConfig{
 			ProtocolID:      protocolID,
 			InitialPeerList: []string{initialPeersList},
 		},
-		Sharding: ShardingConfig{
+		Sharding: p2pConfig.ShardingConfig{
 			Type: shardingType,
 		},
 	}
-	cfg := P2PConfig{}
+	cfg := p2pConfig.P2PConfig{}

 	err := toml.Unmarshal([]byte(testString), &cfg)
@@ -665,8 +663,11 @@ func TestEnableEpochConfig(t *testing.T) {
     # ESDTMetadataContinuousCleanupEnableEpoch represents the epoch when esdt metadata is automatically deleted according to inshard liquidity
     ESDTMetadataContinuousCleanupEnableEpoch = 56

+    # FixAsyncCallBackArgsListEnableEpoch represents the epoch when the async callback arguments lists fix will be enabled
+    FixAsyncCallBackArgsListEnableEpoch = 57
+
     # SetSenderInEeiOutputTransferEnableEpoch represents the epoch when setting the sender in eei output transfers will be enabled
-    SetSenderInEeiOutputTransferEnableEpoch = 57
+    SetSenderInEeiOutputTransferEnableEpoch = 58

     # MaxNodesChangeEnableEpoch holds configuration for changing the maximum number of nodes and the enabling epoch
     MaxNodesChangeEnableEpoch = [
@@ -674,6 +675,11 @@ func TestEnableEpochConfig(t *testing.T) {
         { EpochEnable = 45, MaxNumNodes = 3200, NodesToShufflePerShard = 80 }
     ]

+    BLSMultiSignerEnableEpoch = [
+        {EnableEpoch = 0, Type = "no-KOSK"},
+        {EnableEpoch = 3, Type = "KOSK"}
+    ]
+
     [GasSchedule]
     GasScheduleByEpochs = [
         { StartEpoch = 46, FileName = "gasScheduleV1.toml" },
@@ -751,8 +757,20 @@ func TestEnableEpochConfig(t *testing.T) {
 			ManagedCryptoAPIsEnableEpoch:             54,
 			HeartbeatDisableEpoch:                    55,
 			ESDTMetadataContinuousCleanupEnableEpoch: 56,
-			SetSenderInEeiOutputTransferEnableEpoch:  57,
+			FixAsyncCallBackArgsListEnableEpoch:      57,
+			SetSenderInEeiOutputTransferEnableEpoch:  58,
+			BLSMultiSignerEnableEpoch: []MultiSignerConfig{
+				{
+					EnableEpoch: 0,
+					Type:        "no-KOSK",
+				},
+				{
+					EnableEpoch: 3,
+					Type:        "KOSK",
+				},
+			},
 		},
+
 		GasSchedule: GasScheduleConfig{
 			GasScheduleByEpochs: []GasScheduleByEpochs{
 				{
diff --git a/consensus/broadcast/delayedBroadcast.go b/consensus/broadcast/delayedBroadcast.go
index cdff7bbff31..4a70f42ea99 100644
--- a/consensus/broadcast/delayedBroadcast.go
+++ b/consensus/broadcast/delayedBroadcast.go
@@ -18,7 +18,7 @@ import (
 	"github.com/ElrondNetwork/elrond-go/process/factory"
 	"github.com/ElrondNetwork/elrond-go/sharding"
 	"github.com/ElrondNetwork/elrond-go/storage"
-	"github.com/ElrondNetwork/elrond-go/storage/lrucache"
+	"github.com/ElrondNetwork/elrond-go/storage/cache"
 )

 const prefixHeaderAlarm = "header_"
@@ -100,7 +100,7 @@ func NewDelayedBlockBroadcaster(args *ArgsDelayedBlockBroadcaster) (*delayedBloc
 		return nil, spos.ErrNilAlarmScheduler
 	}

-	cacheHeaders, err := lrucache.NewCache(sizeHeadersCache)
+	cacheHeaders, err := cache.NewLRUCache(sizeHeadersCache)
 	if err != nil {
 		return nil, err
 	}
diff --git a/consensus/interface.go b/consensus/interface.go
index 5b666c5e1fa..0ef9c0cda88 100644
--- a/consensus/interface.go
+++ b/consensus/interface.go
@@ -6,7 +6,7 @@ import (

 	"github.com/ElrondNetwork/elrond-go-core/core"
 	"github.com/ElrondNetwork/elrond-go-core/data"
-	"github.com/ElrondNetwork/elrond-go-crypto"
+	crypto "github.com/ElrondNetwork/elrond-go-crypto"
 	"github.com/ElrondNetwork/elrond-go/p2p"
 )
@@ -149,6 +149,20 @@ type ScheduledProcessor interface {
 	IsInterfaceNil() bool
 }

+// SignatureHandler defines the behaviour of a component that handles signatures in consensus
+type SignatureHandler interface {
+	Reset(pubKeys []string) error
+	CreateSignatureShare(msg []byte, index uint16, epoch uint32) ([]byte, error)
+	CreateSignatureShareWithPrivateKey(message []byte, index uint16, epoch uint32, privateKeyBytes []byte) ([]byte, error)
+	StoreSignatureShare(index uint16, sig []byte) error
+	SignatureShare(index uint16) ([]byte, error)
+	VerifySignatureShare(index uint16, sig []byte, msg []byte, epoch uint32) error
+	AggregateSigs(bitmap []byte, epoch uint32) ([]byte, error)
+	SetAggregatedSig([]byte) error
+	Verify(msg []byte, bitmap []byte, epoch uint32) error
+	IsInterfaceNil() bool
+}
+
 // KeysHandler defines the operations implemented by a component that will manage all keys,
 // including the single signer keys or the set of multi-keys
 type KeysHandler interface {
diff --git a/consensus/mock/consensusDataContainerMock.go b/consensus/mock/consensusDataContainerMock.go
index adbeaaf2c86..eb93be5e2eb 100644
--- a/consensus/mock/consensusDataContainerMock.go
+++ b/consensus/mock/consensusDataContainerMock.go
@@ -5,6 +5,7 @@ import (
 	"github.com/ElrondNetwork/elrond-go-core/hashing"
 	"github.com/ElrondNetwork/elrond-go-core/marshal"
 	crypto "github.com/ElrondNetwork/elrond-go-crypto"
+	cryptoCommon "github.com/ElrondNetwork/elrond-go/common/crypto"
 	"github.com/ElrondNetwork/elrond-go/consensus"
 	"github.com/ElrondNetwork/elrond-go/epochStart"
 	"github.com/ElrondNetwork/elrond-go/ntp"
@@ -25,7 +26,7 @@ type ConsensusCoreMock struct {
 	marshalizer             marshal.Marshalizer
 	blsPrivateKey           crypto.PrivateKey
 	blsSingleSigner         crypto.SingleSigner
-	multiSigner             crypto.MultiSigner
+	multiSignerContainer    cryptoCommon.MultiSignerContainer
 	roundHandler            consensus.RoundHandler
 	shardCoordinator        sharding.Coordinator
 	syncTimer               ntp.SyncTimer
@@ -37,6 +38,7 @@ type ConsensusCoreMock struct {
 	fallbackHeaderValidator consensus.FallbackHeaderValidator
 	nodeRedundancyHandler   consensus.NodeRedundancyHandler
 	scheduledProcessor      consensus.ScheduledProcessor
+	signatureHandler        consensus.SignatureHandler
 }

 // GetAntiFloodHandler -
@@ -84,9 +86,9 @@ func (ccm *ConsensusCoreMock) Marshalizer() marshal.Marshalizer {
 	return ccm.marshalizer
 }

-// MultiSigner -
-func (ccm *ConsensusCoreMock) MultiSigner() crypto.MultiSigner {
-	return ccm.multiSigner
+// MultiSignerContainer -
+func (ccm *ConsensusCoreMock) MultiSignerContainer() cryptoCommon.MultiSignerContainer {
+	return ccm.multiSignerContainer
 }

 // RoundHandler -
@@ -154,9 +156,9 @@ func (ccm *ConsensusCoreMock) SetMarshalizer(marshalizer marshal.Marshalizer) {
 	ccm.marshalizer = marshalizer
 }

-// SetMultiSigner -
-func (ccm *ConsensusCoreMock) SetMultiSigner(multiSigner crypto.MultiSigner) {
-	ccm.multiSigner = multiSigner
+// SetMultiSignerContainer -
+func (ccm *ConsensusCoreMock) SetMultiSignerContainer(multiSignerContainer cryptoCommon.MultiSignerContainer) {
+	ccm.multiSignerContainer = multiSignerContainer
 }

 // SetRoundHandler -
@@ -229,6 +231,16 @@ func (ccm *ConsensusCoreMock) SetNodeRedundancyHandler(nodeRedundancyHandler con
 	ccm.nodeRedundancyHandler = nodeRedundancyHandler
 }

+// SignatureHandler -
+func (ccm *ConsensusCoreMock) SignatureHandler() consensus.SignatureHandler {
+	return ccm.signatureHandler
+}
+
+// SetSignatureHandler -
+func (ccm *ConsensusCoreMock) SetSignatureHandler(signatureHandler consensus.SignatureHandler) {
+	ccm.signatureHandler = signatureHandler
+}
+
 // IsInterfaceNil returns true if there is no value under the interface
 func (ccm *ConsensusCoreMock) IsInterfaceNil() bool {
 	return ccm == nil
diff --git a/consensus/mock/mockTestInitializer.go b/consensus/mock/mockTestInitializer.go
index 18a3c006025..efc416d059f 100644
--- a/consensus/mock/mockTestInitializer.go
+++ b/consensus/mock/mockTestInitializer.go
@@ -98,17 +98,17 @@ func InitBlockProcessorHeaderV2Mock() *BlockProcessorMock {

 // InitMultiSignerMock -
 func InitMultiSignerMock() *cryptoMocks.MultisignerMock {
-	multiSigner := cryptoMocks.NewMultiSigner(21)
-	multiSigner.VerifySignatureShareCalled = func(index uint16, sig []byte, msg []byte, bitmap []byte) error {
+	multiSigner := cryptoMocks.NewMultiSigner()
+	multiSigner.VerifySignatureShareCalled = func(publicKey []byte, message []byte, sig []byte) error {
 		return nil
 	}
-	multiSigner.VerifyCalled = func(msg []byte, bitmap []byte) error {
+	multiSigner.VerifyAggregatedSigCalled = func(pubKeysSigners [][]byte, message []byte, aggSig []byte) error {
 		return nil
 	}
-	multiSigner.AggregateSigsCalled = func(bitmap []byte) ([]byte, error) {
+	multiSigner.AggregateSigsCalled = func(pubKeysSigners [][]byte, signatures [][]byte) ([]byte, error) {
 		return []byte("aggregatedSig"), nil
 	}
-	multiSigner.CreateSignatureShareCalled = func(msg []byte, bitmap []byte) ([]byte, error) {
+	multiSigner.CreateSignatureShareCalled = func(privateKeyBytes []byte, message []byte) ([]byte, error) {
 		return []byte("partialSign"), nil
 	}
 	return multiSigner
@@ -204,6 +204,8 @@ func InitConsensusCoreWithMultiSigner(multiSigner crypto.MultiSigner) *Consensus
 	fallbackHeaderValidator := &testscommon.FallBackHeaderValidatorStub{}
 	nodeRedundancyHandler := &NodeRedundancyHandlerStub{}
 	scheduledProcessor := &consensusMocks.ScheduledProcessorStub{}
+	multiSignerContainer := cryptoMocks.NewMultiSignerContainerMock(multiSigner)
+	signatureHandler := &SignatureHandlerStub{}

 	container := &ConsensusCoreMock{
 		blockChain:              blockChain,
@@ -216,7 +218,7 @@ func InitConsensusCoreWithMultiSigner(multiSigner crypto.MultiSigner) *Consensus
 		marshalizer:             marshalizerMock,
 		blsPrivateKey:           blsPrivateKeyMock,
 		blsSingleSigner:         blsSingleSignerMock,
-		multiSigner:             multiSigner,
+		multiSignerContainer:    multiSignerContainer,
 		roundHandler:            roundHandlerMock,
 		shardCoordinator:        shardCoordinatorMock,
 		syncTimer:               syncTimerMock,
@@ -228,6 +230,7 @@ func InitConsensusCoreWithMultiSigner(multiSigner crypto.MultiSigner) *Consensus
 		fallbackHeaderValidator: fallbackHeaderValidator,
 		nodeRedundancyHandler:   nodeRedundancyHandler,
 		scheduledProcessor:      scheduledProcessor,
+		signatureHandler:        signatureHandler,
 	}

 	return container
diff --git a/consensus/mock/signatureHandlerStub.go b/consensus/mock/signatureHandlerStub.go
new file mode 100644
index 00000000000..c1d25eb0195
--- /dev/null
+++ b/consensus/mock/signatureHandlerStub.go
@@ -0,0 +1,100 @@
+package mock
+
+// SignatureHandlerStub implements SignatureHandler interface
+type SignatureHandlerStub struct {
+	ResetCalled                              func(pubKeys []string) error
+	CreateSignatureShareCalled               func(msg []byte, index uint16, epoch uint32) ([]byte, error)
+	CreateSignatureShareWithPrivateKeyCalled func(message []byte, index uint16, epoch uint32, privateKeyBytes []byte) ([]byte, error)
+	StoreSignatureShareCalled                func(index uint16, sig []byte) error
+	SignatureShareCalled                     func(index uint16) ([]byte, error)
+	VerifySignatureShareCalled               func(index uint16, sig []byte, msg []byte, epoch uint32) error
+	AggregateSigsCalled                      func(bitmap []byte, epoch uint32) ([]byte, error)
+	SetAggregatedSigCalled                   func(_ []byte) error
+	VerifyCalled                             func(msg []byte, bitmap []byte, epoch uint32) error
+}
+
+// Reset -
+func (stub *SignatureHandlerStub) Reset(pubKeys []string) error {
+	if stub.ResetCalled != nil {
+		return stub.ResetCalled(pubKeys)
+	}
+
+	return nil
+}
+
+// CreateSignatureShare -
+func (stub *SignatureHandlerStub) CreateSignatureShare(msg []byte, index uint16, epoch uint32) ([]byte, error) {
+	if stub.CreateSignatureShareCalled != nil {
+		return stub.CreateSignatureShareCalled(msg, index, epoch)
+	}
+
+	return []byte("sigShare"), nil
+}
+
+// CreateSignatureShareWithPrivateKey -
+func (stub *SignatureHandlerStub) CreateSignatureShareWithPrivateKey(message []byte, index uint16, epoch uint32, privateKeyBytes []byte) ([]byte, error) {
+	if stub.CreateSignatureShareWithPrivateKeyCalled != nil {
+		return stub.CreateSignatureShareWithPrivateKeyCalled(message, index, epoch, privateKeyBytes)
+	}
+
+	return make([]byte, 0), nil
+}
+
+// StoreSignatureShare -
+func (stub *SignatureHandlerStub) StoreSignatureShare(index uint16, sig []byte) error {
+	if stub.StoreSignatureShareCalled != nil {
+		return stub.StoreSignatureShareCalled(index, sig)
+	}
+
+	return nil
+}
+
+// SignatureShare -
+func (stub *SignatureHandlerStub) SignatureShare(index uint16) ([]byte, error) {
+	if stub.SignatureShareCalled != nil {
+		return stub.SignatureShareCalled(index)
+	}
+
+	return []byte("sigShare"), nil
+}
+
+// VerifySignatureShare -
+func (stub *SignatureHandlerStub) VerifySignatureShare(index uint16, sig []byte, msg []byte, epoch uint32) error {
+	if stub.VerifySignatureShareCalled != nil {
+		return stub.VerifySignatureShareCalled(index, sig, msg, epoch)
+	}
+
+	return nil
+}
+
+// AggregateSigs -
+func (stub *SignatureHandlerStub) AggregateSigs(bitmap []byte, epoch uint32) ([]byte, error) {
+	if stub.AggregateSigsCalled != nil {
+		return stub.AggregateSigsCalled(bitmap, epoch)
+	}
+
+	return []byte("aggSigs"), nil
+}
+
+// SetAggregatedSig -
+func (stub *SignatureHandlerStub) SetAggregatedSig(sig []byte) error {
+	if stub.SetAggregatedSigCalled != nil {
+		return stub.SetAggregatedSigCalled(sig)
+	}
+
+	return nil
+}
+
+// Verify -
+func (stub *SignatureHandlerStub) Verify(msg []byte, bitmap []byte, epoch uint32) error {
+	if stub.VerifyCalled != nil {
+		return stub.VerifyCalled(msg, bitmap, epoch)
+	}
+
+	return nil
+}
+
+// IsInterfaceNil -
+func (stub *SignatureHandlerStub) IsInterfaceNil() bool {
+	return stub == nil
+}
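The stub mirrors the SignatureHandler flow a consensus round drives. A hedged usage sketch (indices, keys and the bitmap are placeholder values; the stub's defaults stand in for real BLS operations):

package main

import (
	"fmt"

	"github.com/ElrondNetwork/elrond-go/consensus/mock"
)

func main() {
	handler := &mock.SignatureHandlerStub{}

	// start of a round: bind the handler to the consensus group's public keys
	_ = handler.Reset([]string{"pubKey1", "pubKey2"})

	msg := []byte("block header hash")

	// each signer produces and stores its own share
	sigShare, _ := handler.CreateSignatureShare(msg, 0, 0)
	_ = handler.StoreSignatureShare(0, sigShare)

	// shares received from other signers are verified before being stored
	_ = handler.VerifySignatureShare(1, []byte("otherShare"), msg, 0)

	// once enough shares arrived, aggregate over the signers in the bitmap
	bitmap := []byte{0x03} // both signers included
	aggSig, _ := handler.AggregateSigs(bitmap, 0)
	_ = handler.SetAggregatedSig(aggSig)

	fmt.Println(handler.Verify(msg, bitmap, 0)) // nil on success
}

Note that every signing and verification call carries the epoch, which is what lets the holder pick the right multi-signer (no-KOSK vs KOSK) from the MultiSignerContainer.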
diff --git a/consensus/signing/errors.go b/consensus/signing/errors.go
new file mode 100644
index 00000000000..01af6e1773b
--- /dev/null
+++ b/consensus/signing/errors.go
@@ -0,0 +1,42 @@
+package signing
+
+import "errors"
+
+// ErrInvalidSignature is raised for an invalid signature
+var ErrInvalidSignature = errors.New("invalid signature was provided")
+
+// ErrNilElement is raised when searching for a specific element but found nil
+var ErrNilElement = errors.New("element is nil")
+
+// ErrIndexNotSelected is raised when an index that was not selected is used for multi-signing
+var ErrIndexNotSelected = errors.New("index is not selected")
+
+// ErrNilBitmap is raised when a nil bitmap is used
+var ErrNilBitmap = errors.New("bitmap is nil")
+
+// ErrNoPrivateKeySet is raised when no private key was set
+var ErrNoPrivateKeySet = errors.New("no private key was set")
+
+// ErrNoPublicKeySet is raised when no public key was set for a multisignature
+var ErrNoPublicKeySet = errors.New("no public key was set")
+
+// ErrNilKeyGenerator is raised when a valid key generator is expected but nil used
+var ErrNilKeyGenerator = errors.New("key generator is nil")
+
+// ErrNilPublicKeys is raised when public keys are expected but received nil
+var ErrNilPublicKeys = errors.New("public keys are nil")
+
+// ErrNilMultiSignerContainer is raised when a nil multi signer container has been provided
+var ErrNilMultiSignerContainer = errors.New("multi signer container is nil")
+
+// ErrIndexOutOfBounds is raised when an out of bound index is used
+var ErrIndexOutOfBounds = errors.New("index is out of bounds")
+
+// ErrEmptyPubKeyString is raised when an empty public key string is used
+var ErrEmptyPubKeyString = errors.New("public key string is empty")
+
+// ErrNilMessage is raised when trying to verify a nil signed message or trying to sign a nil message
+var ErrNilMessage = errors.New("message to be signed or to be verified is nil")
+
+// ErrBitmapMismatch is raised when an invalid bitmap is passed to the multisigner
+var ErrBitmapMismatch = errors.New("multi signer reported a mismatch in used bitmap")
diff --git a/consensus/signing/signing.go b/consensus/signing/signing.go
new file mode 100644
index 00000000000..3a84c841754
--- /dev/null
+++ b/consensus/signing/signing.go
@@ -0,0 +1,322 @@
+package signing
+
+import (
+	"sync"
+
+	"github.com/ElrondNetwork/elrond-go-core/core/check"
+	crypto "github.com/ElrondNetwork/elrond-go-crypto"
+	cryptoCommon "github.com/ElrondNetwork/elrond-go/common/crypto"
+)
+
+// ArgsSignatureHolder defines the arguments needed to create a new signature holder component
+type ArgsSignatureHolder struct {
+	PubKeys              []string
+	PrivKeyBytes         []byte
+	MultiSignerContainer cryptoCommon.MultiSignerContainer
+	KeyGenerator         crypto.KeyGenerator
+}
+
+type signatureHolderData struct {
+	pubKeys   [][]byte
+	privKey   []byte
+	sigShares [][]byte
+	aggSig    []byte
+}
+
+type signatureHolder struct {
+	data                 *signatureHolderData
+	mutSigningData       sync.RWMutex
+	multiSignerContainer cryptoCommon.MultiSignerContainer
+	keyGen               crypto.KeyGenerator
+}
+
+// NewSignatureHolder will create a new signature holder component
+func NewSignatureHolder(args ArgsSignatureHolder) (*signatureHolder, error) {
+	err := checkArgs(args)
+	if err != nil {
+		return nil, err
+	}
+
+	sigSharesSize := uint16(len(args.PubKeys))
+	sigShares := make([][]byte, sigSharesSize)
+
+	pubKeysBytes, err := convertStringsToPubKeysBytes(args.PubKeys)
+	if err != nil {
+		return nil, err
+	}
+
+	data := &signatureHolderData{
+		pubKeys:   pubKeysBytes,
+		privKey:   args.PrivKeyBytes,
+		sigShares: sigShares,
+	}
+
+	return &signatureHolder{
+		data:                 data,
+		mutSigningData:       sync.RWMutex{},
+		multiSignerContainer: args.MultiSignerContainer,
+		keyGen:               args.KeyGenerator,
+	}, nil
+}
+
+func checkArgs(args ArgsSignatureHolder) error {
+	if check.IfNil(args.MultiSignerContainer) {
+		return ErrNilMultiSignerContainer
+	}
+	if len(args.PrivKeyBytes) == 0 {
+		return ErrNoPrivateKeySet
+	}
+	if check.IfNil(args.KeyGenerator) {
+		return ErrNilKeyGenerator
+	}
+	if len(args.PubKeys) == 0 {
+		return ErrNoPublicKeySet
+	}
+
+	return nil
+}
+
+// Create generates a signature holder component and initializes corresponding fields
+func (sh *signatureHolder) Create(pubKeys []string) (*signatureHolder, error) {
+	sh.mutSigningData.RLock()
+	privKey := sh.data.privKey
+	sh.mutSigningData.RUnlock()
+
+	args := ArgsSignatureHolder{
+		PubKeys:              pubKeys,
+		PrivKeyBytes:         privKey,
+		MultiSignerContainer: sh.multiSignerContainer,
+		KeyGenerator:         sh.keyGen,
+	}
+	return NewSignatureHolder(args)
+}
+
+// Reset resets the data inside the signature holder component
+func (sh *signatureHolder) Reset(pubKeys []string) error {
+	if pubKeys == nil {
+		return ErrNilPublicKeys
+	}
+
+	sigSharesSize := uint16(len(pubKeys))
+	sigShares := make([][]byte, sigSharesSize)
+	pubKeysBytes, err := convertStringsToPubKeysBytes(pubKeys)
+	if err != nil {
+		return err
+	}
+
+	sh.mutSigningData.Lock()
+	defer sh.mutSigningData.Unlock()
+
+	privKey := sh.data.privKey
+
+	data := &signatureHolderData{
+		pubKeys:   pubKeysBytes,
+		privKey:   privKey,
+		sigShares: sigShares,
+	}
+
+	sh.data = data
+
+	return nil
+}
+
+// CreateSignatureShare returns a signature over a message
+func (sh *signatureHolder) CreateSignatureShare(message []byte, selfIndex uint16, epoch uint32) ([]byte, error) {
+	sh.mutSigningData.RLock()
+	privateKeyBytes := sh.data.privKey
+	sh.mutSigningData.RUnlock()
+
+	return sh.CreateSignatureShareWithPrivateKey(message, selfIndex, epoch, privateKeyBytes)
+}
+
+// CreateSignatureShareWithPrivateKey returns a signature over a message providing the private key bytes
+func (sh *signatureHolder) CreateSignatureShareWithPrivateKey(message []byte, index uint16, epoch uint32, privateKeyBytes []byte) ([]byte, error) {
+	if message == nil {
+		return nil, ErrNilMessage
+	}
+
+	sh.mutSigningData.Lock()
+	defer sh.mutSigningData.Unlock()
+
+	multiSigner, err := sh.multiSignerContainer.GetMultiSigner(epoch)
+	if err != nil {
+		return nil, err
+	}
+
+	sigShareBytes, err := multiSigner.CreateSignatureShare(privateKeyBytes, message)
+	if err != nil {
+		return nil, err
+	}
+
+	sh.data.sigShares[index] = sigShareBytes
+
+	return sigShareBytes, nil
+}
+
+// VerifySignatureShare will verify the signature share based on the specified index
+func (sh *signatureHolder) VerifySignatureShare(index uint16, sig []byte, message []byte, epoch uint32) error {
+	if len(sig) == 0 {
+		return ErrInvalidSignature
+	}
+
+	sh.mutSigningData.RLock()
+	defer sh.mutSigningData.RUnlock()
+
+	indexOutOfBounds := index >= uint16(len(sh.data.pubKeys))
+	if indexOutOfBounds {
+		return ErrIndexOutOfBounds
+	}
+
+	pubKey := sh.data.pubKeys[index]
+
+	multiSigner, err := sh.multiSignerContainer.GetMultiSigner(epoch)
+	if err != nil {
+		return err
+	}
+
+	return multiSigner.VerifySignatureShare(pubKey, message, sig)
+}
+
+// StoreSignatureShare stores the partial signature of the signer with the specified position
+func (sh *signatureHolder) StoreSignatureShare(index uint16, sig []byte) error {
+	if len(sig) == 0 {
+		return ErrInvalidSignature
+	}
+
+	sh.mutSigningData.Lock()
+	defer sh.mutSigningData.Unlock()
+
+	if int(index) >= len(sh.data.sigShares) {
+		return ErrIndexOutOfBounds
+	}
+
+	sh.data.sigShares[index] = sig
+
+	return nil
+}
+
+// SignatureShare returns the partial signature set for given index
+func (sh *signatureHolder) SignatureShare(index uint16) ([]byte, error) {
+	sh.mutSigningData.RLock()
+	defer sh.mutSigningData.RUnlock()
+
+	if int(index) >= len(sh.data.sigShares) {
+		return nil, ErrIndexOutOfBounds
+	}
+
+	if sh.data.sigShares[index] == nil {
+		return nil, ErrNilElement
+	}
+
+	return sh.data.sigShares[index], nil
+}
+
+// not concurrent safe, should be used under RLock mutex
+func (sh *signatureHolder) isIndexInBitmap(index uint16, bitmap []byte) bool {
+	indexOutOfBounds := index >= uint16(len(sh.data.pubKeys))
+	if indexOutOfBounds {
+		return false
+	}
+
+	indexNotInBitmap := bitmap[index/8]&(1< 0
+	if isFirstCommit {
+		log.Debug("processor debugger: first committed block", "round", debugger.lastCommittedBlockRound)
+		return false
+	}
+
+	isNodeRunning := debugger.lastCheckedBlockRound < debugger.lastCommittedBlockRound
+	if isNodeRunning {
+		log.Debug("processor debugger: node is running, nothing to do", "round", debugger.lastCommittedBlockRound)
+		return false
+	}
+
+	return true
+}
+
+func (debugger *processDebugger) trigger() {
+	debugger.mut.RLock()
+	lastCommittedBlockRound := debugger.lastCommittedBlockRound
+	debugger.mut.RUnlock()
+
+	log.Warn("processor debugger: node is stuck",
+		"last committed round", lastCommittedBlockRound)
+
+	debugger.logChangeHandler()
+
+	if debugger.dumpGoRoutines {
+		debugger.goRoutinesDumpHandler()
+	}
+}
+
+// SetLastCommittedBlockRound sets the last committed block's round
+func (debugger *processDebugger) SetLastCommittedBlockRound(round uint64) {
+	debugger.mut.Lock()
+	defer debugger.mut.Unlock()
+
+	log.Debug("processor debugger: updated last committed block round", "round", round)
+	debugger.lastCommittedBlockRound = int64(round)
+}
+
+// Close stops any started go routines
+func (debugger *processDebugger) Close() error {
+	debugger.cancel()
+
+	return nil
+}
+
+func dumpGoRoutines() {
+	buff := make([]byte, buffSize)
+	numBytes := runtime.Stack(buff, true)
+	log.Debug(string(buff[:numBytes]))
+}
+
+func (debugger *processDebugger) changeLog() {
+	errSetLogLevel := logger.SetLogLevel(debugger.debuggingLogLevel)
+	if errSetLogLevel != nil {
+		log.Error("debugger.changeLog: cannot change log level", "error", errSetLogLevel)
+	}
+}
+
+// IsInterfaceNil returns true if there is no value under the interface
+func (debugger *processDebugger) IsInterfaceNil() bool {
+	return debugger == nil
+}
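A span of this patch is missing between isIndexInBitmap above and the debug/process/debugger.go methods that follow (including that file's diff header), so only the tail of the debugger is visible. The surviving methods describe a polling watchdog: it periodically checks whether new block rounds were committed and, if not, raises log verbosity and dumps go routines. A self-contained sketch of that technique; names, timings and the stuck-check are illustrative assumptions, not the file's actual code:

package main

import (
	"context"
	"fmt"
	"sync"
	"time"
)

// stuckDetector polls periodically and fires when no new block round was
// committed since the previous check (illustrative reimplementation)
type stuckDetector struct {
	mut           sync.Mutex
	lastChecked   int64
	lastCommitted int64
	cancel        context.CancelFunc
}

func newStuckDetector(pollingTime time.Duration) *stuckDetector {
	ctx, cancel := context.WithCancel(context.Background())
	d := &stuckDetector{cancel: cancel}
	go d.loop(ctx, pollingTime)

	return d
}

func (d *stuckDetector) loop(ctx context.Context, pollingTime time.Duration) {
	ticker := time.NewTicker(pollingTime)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			d.checkProgress()
		}
	}
}

func (d *stuckDetector) checkProgress() {
	d.mut.Lock()
	defer d.mut.Unlock()

	// no commit seen yet, or new blocks since the last check -> healthy
	isStuck := d.lastCommitted > 0 && d.lastChecked >= d.lastCommitted
	d.lastChecked = d.lastCommitted
	if isStuck {
		fmt.Println("node seems stuck: raising log verbosity, dumping go routines")
	}
}

// SetLastCommittedBlockRound is called by the block processor on every commit
func (d *stuckDetector) SetLastCommittedBlockRound(round uint64) {
	d.mut.Lock()
	defer d.mut.Unlock()
	d.lastCommitted = int64(round)
}

// Close stops the polling go routine
func (d *stuckDetector) Close() error {
	d.cancel()
	return nil
}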
*testing.T) { + t.Parallel() + + t.Run("invalid PollingTimeInSeconds", func(t *testing.T) { + t.Parallel() + + configs := createMockProcessDebugConfig() + configs.PollingTimeInSeconds = minAcceptedValue - 1 + + debuggerInstance, err := NewProcessDebugger(configs) + + assert.True(t, check.IfNil(debuggerInstance)) + assert.True(t, errors.Is(err, errInvalidValue)) + assert.True(t, strings.Contains(err.Error(), "PollingTimeInSeconds")) + }) + t.Run("should work", func(t *testing.T) { + t.Parallel() + + configs := createMockProcessDebugConfig() + debuggerInstance, err := NewProcessDebugger(configs) + + assert.False(t, check.IfNil(debuggerInstance)) + assert.Nil(t, err) + + _ = debuggerInstance.Close() + }) +} + +func TestDebugger_ProcessLoopAndClose(t *testing.T) { + t.Parallel() + + t.Run("node is starting, go routines dump active, should not trigger", func(t *testing.T) { + t.Parallel() + + configs := createMockProcessDebugConfig() + + numGoRoutinesDumpHandlerCalls := int32(0) + numLogChangeHandlerCalls := int32(0) + + debuggerInstance, _ := NewProcessDebugger(configs) + debuggerInstance.goRoutinesDumpHandler = func() { + atomic.AddInt32(&numGoRoutinesDumpHandlerCalls, 1) + } + debuggerInstance.logChangeHandler = func() { + atomic.AddInt32(&numLogChangeHandlerCalls, 1) + } + + time.Sleep(time.Second*3 + time.Millisecond*500) + + assert.Zero(t, atomic.LoadInt32(&numLogChangeHandlerCalls)) + assert.Zero(t, atomic.LoadInt32(&numGoRoutinesDumpHandlerCalls)) + + time.Sleep(time.Second * 3) + + assert.Zero(t, atomic.LoadInt32(&numLogChangeHandlerCalls)) + assert.Zero(t, atomic.LoadInt32(&numGoRoutinesDumpHandlerCalls)) + + err := debuggerInstance.Close() + assert.Nil(t, err) + + time.Sleep(time.Second * 3) + + assert.Zero(t, atomic.LoadInt32(&numLogChangeHandlerCalls)) + assert.Zero(t, atomic.LoadInt32(&numGoRoutinesDumpHandlerCalls)) + }) + t.Run("node is syncing, go routines dump active, should not trigger", func(t *testing.T) { + t.Parallel() + + configs := createMockProcessDebugConfig() + + numGoRoutinesDumpHandlerCalls := int32(0) + numLogChangeHandlerCalls := int32(0) + + debuggerInstance, _ := NewProcessDebugger(configs) + debuggerInstance.goRoutinesDumpHandler = func() { + atomic.AddInt32(&numGoRoutinesDumpHandlerCalls, 1) + } + debuggerInstance.logChangeHandler = func() { + atomic.AddInt32(&numLogChangeHandlerCalls, 1) + } + debuggerInstance.SetLastCommittedBlockRound(223) + + time.Sleep(time.Second*1 + time.Millisecond*500) + + assert.Zero(t, atomic.LoadInt32(&numLogChangeHandlerCalls)) + assert.Zero(t, atomic.LoadInt32(&numGoRoutinesDumpHandlerCalls)) + + err := debuggerInstance.Close() + assert.Nil(t, err) + + time.Sleep(time.Second * 3) + + assert.Zero(t, atomic.LoadInt32(&numLogChangeHandlerCalls)) + assert.Zero(t, atomic.LoadInt32(&numGoRoutinesDumpHandlerCalls)) + }) + t.Run("node is running, go routines dump active, should not trigger", func(t *testing.T) { + t.Parallel() + + configs := createMockProcessDebugConfig() + + numGoRoutinesDumpHandlerCalls := int32(0) + numLogChangeHandlerCalls := int32(0) + + debuggerInstance, _ := NewProcessDebugger(configs) + debuggerInstance.goRoutinesDumpHandler = func() { + atomic.AddInt32(&numGoRoutinesDumpHandlerCalls, 1) + } + debuggerInstance.logChangeHandler = func() { + atomic.AddInt32(&numLogChangeHandlerCalls, 1) + } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { + for i := uint64(0); ; i++ { + select { + case <-ctx.Done(): + return + case <-time.After(time.Millisecond * 100): + 
debuggerInstance.SetLastCommittedBlockRound(i) + } + } + }() + + time.Sleep(time.Second*3 + time.Millisecond*500) + + assert.Equal(t, int32(0), atomic.LoadInt32(&numLogChangeHandlerCalls)) + assert.Equal(t, int32(0), atomic.LoadInt32(&numGoRoutinesDumpHandlerCalls)) + + err := debuggerInstance.Close() + assert.Nil(t, err) + + time.Sleep(time.Second * 3) + + assert.Equal(t, int32(0), atomic.LoadInt32(&numLogChangeHandlerCalls)) + assert.Equal(t, int32(0), atomic.LoadInt32(&numGoRoutinesDumpHandlerCalls)) + }) + t.Run("node is stuck, go routines dump active, should trigger", func(t *testing.T) { + t.Parallel() + + configs := createMockProcessDebugConfig() + + numGoRoutinesDumpHandlerCalls := int32(0) + numLogChangeHandlerCalls := int32(0) + + debuggerInstance, _ := NewProcessDebugger(configs) + debuggerInstance.goRoutinesDumpHandler = func() { + atomic.AddInt32(&numGoRoutinesDumpHandlerCalls, 1) + } + debuggerInstance.logChangeHandler = func() { + atomic.AddInt32(&numLogChangeHandlerCalls, 1) + } + debuggerInstance.SetLastCommittedBlockRound(223) + + time.Sleep(time.Second*3 + time.Millisecond*500) + + assert.Equal(t, int32(2), atomic.LoadInt32(&numLogChangeHandlerCalls)) + assert.Equal(t, int32(2), atomic.LoadInt32(&numGoRoutinesDumpHandlerCalls)) + + err := debuggerInstance.Close() + assert.Nil(t, err) + + time.Sleep(time.Second * 3) + + assert.Equal(t, int32(2), atomic.LoadInt32(&numLogChangeHandlerCalls)) + assert.Equal(t, int32(2), atomic.LoadInt32(&numGoRoutinesDumpHandlerCalls)) + }) + t.Run("node is stuck, go routines dump inactive, should trigger", func(t *testing.T) { + t.Parallel() + + configs := createMockProcessDebugConfig() + configs.GoRoutinesDump = false + + numGoRoutinesDumpHandlerCalls := int32(0) + numLogChangeHandlerCalls := int32(0) + + debuggerInstance, _ := NewProcessDebugger(configs) + debuggerInstance.goRoutinesDumpHandler = func() { + atomic.AddInt32(&numGoRoutinesDumpHandlerCalls, 1) + } + debuggerInstance.logChangeHandler = func() { + atomic.AddInt32(&numLogChangeHandlerCalls, 1) + } + debuggerInstance.SetLastCommittedBlockRound(223) + + time.Sleep(time.Second*3 + time.Millisecond*500) + + assert.Equal(t, int32(2), atomic.LoadInt32(&numLogChangeHandlerCalls)) + assert.Equal(t, int32(0), atomic.LoadInt32(&numGoRoutinesDumpHandlerCalls)) + + err := debuggerInstance.Close() + assert.Nil(t, err) + + time.Sleep(time.Second * 3) + + assert.Equal(t, int32(2), atomic.LoadInt32(&numLogChangeHandlerCalls)) + assert.Equal(t, int32(0), atomic.LoadInt32(&numGoRoutinesDumpHandlerCalls)) + }) +} diff --git a/debug/process/disabledDebugger.go b/debug/process/disabledDebugger.go new file mode 100644 index 00000000000..980bb87c6f0 --- /dev/null +++ b/debug/process/disabledDebugger.go @@ -0,0 +1,23 @@ +package process + +type disabledDebugger struct { +} + +// NewDisabledDebugger creates a disabled process debugger instance +func NewDisabledDebugger() *disabledDebugger { + return &disabledDebugger{} +} + +// SetLastCommittedBlockRound does nothing +func (debugger *disabledDebugger) SetLastCommittedBlockRound(_ uint64) { +} + +// Close does nothing and returns nil +func (debugger *disabledDebugger) Close() error { + return nil +} + +// IsInterfaceNil returns true if there is no value under the interface +func (debugger *disabledDebugger) IsInterfaceNil() bool { + return debugger == nil +} diff --git a/debug/process/disabledDebugger_test.go b/debug/process/disabledDebugger_test.go new file mode 100644 index 00000000000..de4118e5c72 --- /dev/null +++ 
b/debug/process/disabledDebugger_test.go @@ -0,0 +1,35 @@ +package process + +import ( + "fmt" + "testing" + + "github.com/ElrondNetwork/elrond-go-core/core/check" + "github.com/stretchr/testify/assert" +) + +func TestNewDisabledDebugger(t *testing.T) { + t.Parallel() + + debugger := NewDisabledDebugger() + assert.False(t, check.IfNil(debugger)) +} + +func TestDisabledDebugger_MethodsShouldNotPanic(t *testing.T) { + t.Parallel() + + defer func() { + r := recover() + if r != nil { + assert.Fail(t, fmt.Sprintf("should not have failed %v", r)) + } + }() + + debugger := NewDisabledDebugger() + debugger.SetLastCommittedBlockRound(0) + debugger.SetLastCommittedBlockRound(1) + err := debugger.Close() + assert.Nil(t, err) + + debugger.SetLastCommittedBlockRound(1) +} diff --git a/debug/process/errors.go b/debug/process/errors.go new file mode 100644 index 00000000000..158541c3308 --- /dev/null +++ b/debug/process/errors.go @@ -0,0 +1,5 @@ +package process + +import "errors" + +var errInvalidValue = errors.New("invalid value") diff --git a/debug/resolver/interceptorResolver.go b/debug/resolver/interceptorResolver.go index 903b4de41b4..e6f659e041d 100644 --- a/debug/resolver/interceptorResolver.go +++ b/debug/resolver/interceptorResolver.go @@ -13,7 +13,7 @@ import ( "github.com/ElrondNetwork/elrond-go/config" "github.com/ElrondNetwork/elrond-go/debug" "github.com/ElrondNetwork/elrond-go/storage" - "github.com/ElrondNetwork/elrond-go/storage/lrucache" + "github.com/ElrondNetwork/elrond-go/storage/cache" ) const requestEvent = "request" @@ -91,13 +91,13 @@ type interceptorResolver struct { // NewInterceptorResolver creates a new interceptorResolver able to hold requested-intercepted information func NewInterceptorResolver(config config.InterceptorResolverDebugConfig) (*interceptorResolver, error) { - cache, err := lrucache.NewCache(config.CacheSize) + lruCache, err := cache.NewLRUCache(config.CacheSize) if err != nil { return nil, fmt.Errorf("%w when creating NewInterceptorResolver", err) } ir := &interceptorResolver{ - cache: cache, + cache: lruCache, timestampHandler: getCurrentTimeStamp, } diff --git a/epochStart/bootstrap/disabled/disabledAccountsAdapter.go b/epochStart/bootstrap/disabled/disabledAccountsAdapter.go index da86e186d00..965b72812ab 100644 --- a/epochStart/bootstrap/disabled/disabledAccountsAdapter.go +++ b/epochStart/bootstrap/disabled/disabledAccountsAdapter.go @@ -17,6 +17,16 @@ func NewAccountsAdapter() *accountsAdapter { return &accountsAdapter{} } +// SetSyncer - +func (a *accountsAdapter) SetSyncer(_ state.AccountsDBSyncer) error { + return nil +} + +// StartSnapshotIfNeeded - +func (a *accountsAdapter) StartSnapshotIfNeeded() error { + return nil +} + // GetTrie - func (a *accountsAdapter) GetTrie(_ []byte) (common.Trie, error) { return nil, nil diff --git a/epochStart/bootstrap/disabled/disabledMultiSigner.go b/epochStart/bootstrap/disabled/disabledMultiSigner.go index 46c5cff28ff..dd7c5e4d7b6 100644 --- a/epochStart/bootstrap/disabled/disabledMultiSigner.go +++ b/epochStart/bootstrap/disabled/disabledMultiSigner.go @@ -1,9 +1,5 @@ package disabled -import ( - "github.com/ElrondNetwork/elrond-go-crypto" -) - type multiSigner struct { } @@ -12,54 +8,24 @@ func NewMultiSigner() *multiSigner { return &multiSigner{} } -// Create returns a nil instance and a nil error -func (m *multiSigner) Create(_ []string, _ uint16) (crypto.MultiSigner, error) { - return nil, nil -} - -// SetAggregatedSig returns nil -func (m *multiSigner) SetAggregatedSig([]byte) error { - return nil -} - -
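The disabled stub shrinks because the MultiSigner contract itself changed: the stateful methods (Create, Reset, SetAggregatedSig, StoreSignatureShare, SignatureShare) are gone, and signing now operates on explicit key and share slices. A sketch of the new round trip; `aggregateRoundTrip` is a hypothetical helper, the CreateSignatureShare and VerifySignatureShare argument order follows the signatureHolder call sites earlier in this diff, and the AggregateSigs and VerifyAggregatedSig order is an assumption based on the stub's shapes:

```go
package signing

import crypto "github.com/ElrondNetwork/elrond-go-crypto"

// aggregateRoundTrip exercises the stateless MultiSigner flow: privKeys and
// pubKeys hold raw key bytes, msg is the payload being co-signed.
func aggregateRoundTrip(ms crypto.MultiSigner, privKeys, pubKeys [][]byte, msg []byte) ([]byte, error) {
	sigShares := make([][]byte, 0, len(privKeys))
	for i, sk := range privKeys {
		share, err := ms.CreateSignatureShare(sk, msg)
		if err != nil {
			return nil, err
		}

		// each share is verifiable in isolation against its own public key
		err = ms.VerifySignatureShare(pubKeys[i], msg, share)
		if err != nil {
			return nil, err
		}

		sigShares = append(sigShares, share)
	}

	aggSig, err := ms.AggregateSigs(pubKeys, sigShares)
	if err != nil {
		return nil, err
	}

	// the aggregated signature is checked against the whole public key set
	return aggSig, ms.VerifyAggregatedSig(pubKeys, msg, aggSig)
}
```

-// 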
Verify returns nil -func (m *multiSigner) Verify(_ []byte, _ []byte) error { - return nil -} - -// Reset returns nil and does nothing -func (m *multiSigner) Reset(_ []string, _ uint16) error { - return nil -} - // CreateSignatureShare returns nil byte slice and nil error func (m *multiSigner) CreateSignatureShare(_ []byte, _ []byte) ([]byte, error) { return nil, nil } -// StoreSignatureShare returns nil -func (m *multiSigner) StoreSignatureShare(_ uint16, _ []byte) error { - return nil -} - -// SignatureShare returns nil byte slice and a nil error -func (m *multiSigner) SignatureShare(_ uint16) ([]byte, error) { - return nil, nil -} - // VerifySignatureShare returns nil -func (m *multiSigner) VerifySignatureShare(_ uint16, _ []byte, _ []byte, _ []byte) error { +func (m *multiSigner) VerifySignatureShare(_ []byte, _ []byte, _ []byte) error { return nil } // AggregateSigs returns nil byte slice and nil error -func (m *multiSigner) AggregateSigs(_ []byte) ([]byte, error) { +func (m *multiSigner) AggregateSigs(_ [][]byte, _ [][]byte) ([]byte, error) { return nil, nil } -// CreateAndAddSignatureShareForKey will return an empty slice and a nil error -func (m *multiSigner) CreateAndAddSignatureShareForKey(_ []byte, _ crypto.PrivateKey, _ []byte) ([]byte, error) { - return make([]byte, 0), nil +// VerifyAggregatedSig returns nil +func (m *multiSigner) VerifyAggregatedSig(_ [][]byte, _ []byte, _ []byte) error { + return nil } // IsInterfaceNil returns true if there is no value under the interface diff --git a/epochStart/bootstrap/disabled/disabledMultiSignerContainer.go b/epochStart/bootstrap/disabled/disabledMultiSignerContainer.go new file mode 100644 index 00000000000..cb5a735fe13 --- /dev/null +++ b/epochStart/bootstrap/disabled/disabledMultiSignerContainer.go @@ -0,0 +1,24 @@ +package disabled + +import crypto "github.com/ElrondNetwork/elrond-go-crypto" + +type disabledMultiSignerContainer struct { + multiSigner crypto.MultiSigner +} + +// NewMultiSignerContainer creates a disabled multi signer container +func NewMultiSignerContainer() *disabledMultiSignerContainer { + return &disabledMultiSignerContainer{ + multiSigner: NewMultiSigner(), + } +} + +// GetMultiSigner returns a disabled multi signer as this is a disabled component +func (dmsc *disabledMultiSignerContainer) GetMultiSigner(_ uint32) (crypto.MultiSigner, error) { + return dmsc.multiSigner, nil +} + +// IsInterfaceNil returns true if the underlying object is nil +func (dmsc *disabledMultiSignerContainer) IsInterfaceNil() bool { + return dmsc == nil +} diff --git a/epochStart/bootstrap/disabled/disabledStorer.go b/epochStart/bootstrap/disabled/disabledStorer.go index 0e2350bf814..f0864b313cc 100644 --- a/epochStart/bootstrap/disabled/disabledStorer.go +++ b/epochStart/bootstrap/disabled/disabledStorer.go @@ -2,8 +2,8 @@ package disabled import ( "github.com/ElrondNetwork/elrond-go/storage" - "github.com/ElrondNetwork/elrond-go/storage/memorydb" - "github.com/ElrondNetwork/elrond-go/storage/storageUnit" + "github.com/ElrondNetwork/elrond-go/storage/database" + "github.com/ElrondNetwork/elrond-go/storage/storageunit" ) const defaultCapacity = 10000 @@ -12,12 +12,12 @@ const zeroSize = 0 // CreateMemUnit creates an in-memory storer unit using maps func CreateMemUnit() storage.Storer { - cache, err := storageUnit.NewCache(storageUnit.CacheConfig{Type: storageUnit.LRUCache, Capacity: defaultCapacity, Shards: defaultNumShards, SizeInBytes: zeroSize}) + cache, err := storageunit.NewCache(storageunit.CacheConfig{Type: storageunit.LRUCache, 
Capacity: defaultCapacity, Shards: defaultNumShards, SizeInBytes: zeroSize}) if err != nil { return nil } - unit, err := storageUnit.NewStorageUnit(cache, memorydb.New()) + unit, err := storageunit.NewStorageUnit(cache, database.NewMemDB()) if err != nil { return nil } diff --git a/epochStart/bootstrap/factory/epochStartInterceptorsContainerFactory.go b/epochStart/bootstrap/factory/epochStartInterceptorsContainerFactory.go index deb1204c3d5..d6a41588f8f 100644 --- a/epochStart/bootstrap/factory/epochStartInterceptorsContainerFactory.go +++ b/epochStart/bootstrap/factory/epochStartInterceptorsContainerFactory.go @@ -16,7 +16,7 @@ import ( "github.com/ElrondNetwork/elrond-go/process" "github.com/ElrondNetwork/elrond-go/process/factory/interceptorscontainer" "github.com/ElrondNetwork/elrond-go/sharding" - "github.com/ElrondNetwork/elrond-go/storage/timecache" + "github.com/ElrondNetwork/elrond-go/storage/cache" "github.com/ElrondNetwork/elrond-go/update" ) @@ -56,7 +56,7 @@ func NewEpochStartInterceptorsContainer(args ArgsEpochStartInterceptorContainer) } cryptoComponents := args.CryptoComponents.Clone().(process.CryptoComponentsHolder) - err := cryptoComponents.SetMultiSigner(disabled.NewMultiSigner()) + err := cryptoComponents.SetMultiSignerContainer(disabled.NewMultiSignerContainer()) if err != nil { return nil, err } @@ -65,7 +65,7 @@ func NewEpochStartInterceptorsContainer(args ArgsEpochStartInterceptorContainer) storer := disabled.NewChainStorer() antiFloodHandler := disabled.NewAntiFloodHandler() accountsAdapter := disabled.NewAccountsAdapter() - blackListHandler := timecache.NewTimeCache(timeSpanForBadHeaders) + blackListHandler := cache.NewTimeCache(timeSpanForBadHeaders) feeHandler := &disabledGenesis.FeeHandler{} headerSigVerifier := disabled.NewHeaderSigVerifier() sizeCheckDelta := 0 diff --git a/epochStart/bootstrap/metaStorageHandler_test.go b/epochStart/bootstrap/metaStorageHandler_test.go index 07b27f7a392..14db0aef369 100644 --- a/epochStart/bootstrap/metaStorageHandler_test.go +++ b/epochStart/bootstrap/metaStorageHandler_test.go @@ -223,6 +223,7 @@ func testMetaWithMissingStorer(missingUnit dataRetriever.UnitType, atCallNumber } err := mtStrHandler.SaveDataToStorage(components) + require.NotNil(t, err) require.True(t, strings.Contains(err.Error(), storage.ErrKeyNotFound.Error())) require.True(t, strings.Contains(err.Error(), missingUnit.String())) } diff --git a/epochStart/bootstrap/process.go b/epochStart/bootstrap/process.go index 54e363f912f..ce60feacd2a 100644 --- a/epochStart/bootstrap/process.go +++ b/epochStart/bootstrap/process.go @@ -37,10 +37,11 @@ import ( "github.com/ElrondNetwork/elrond-go/state" "github.com/ElrondNetwork/elrond-go/state/syncer" "github.com/ElrondNetwork/elrond-go/storage" + "github.com/ElrondNetwork/elrond-go/storage/cache" storageFactory "github.com/ElrondNetwork/elrond-go/storage/factory" - "github.com/ElrondNetwork/elrond-go/storage/storageUnit" - "github.com/ElrondNetwork/elrond-go/storage/timecache" + "github.com/ElrondNetwork/elrond-go/storage/storageunit" "github.com/ElrondNetwork/elrond-go/trie/factory" + "github.com/ElrondNetwork/elrond-go/trie/storageMarker" "github.com/ElrondNetwork/elrond-go/update" updateSync "github.com/ElrondNetwork/elrond-go/update/sync" ) @@ -213,7 +214,7 @@ func NewEpochStartBootstrap(args ArgsEpochStartBootstrap) (*epochStartBootstrap, shardCoordinator: args.GenesisShardCoordinator, } - whiteListCache, err := 
storageUnit.NewCache(storageFactory.GetCacherFromConfig(epochStartProvider.generalConfig.WhiteListPool)) + whiteListCache, err := storageunit.NewCache(storageFactory.GetCacherFromConfig(epochStartProvider.generalConfig.WhiteListPool)) if err != nil { return nil, err } @@ -1044,6 +1045,7 @@ func (e *epochStartBootstrap) syncUserAccountsState(rootHash []byte) error { MaxHardCapForMissingNodes: e.maxHardCapForMissingNodes, TrieSyncerVersion: e.trieSyncerVersion, CheckNodesOnDisk: e.checkNodesOnDisk, + StorageMarker: storageMarker.NewTrieStorageMarker(), }, ShardId: e.shardCoordinator.SelfId(), Throttler: thr, @@ -1065,7 +1067,7 @@ func (e *epochStartBootstrap) syncUserAccountsState(rootHash []byte) error { func (e *epochStartBootstrap) createStorageService( shardCoordinator sharding.Coordinator, pathManager storage.PathManagerHandler, - epochStartNotifier storage.EpochStartNotifier, + epochStartNotifier epochStart.EpochStartNotifier, startEpoch uint32, createTrieEpochRootHashStorer bool, targetShardId uint32, @@ -1109,6 +1111,7 @@ func (e *epochStartBootstrap) syncValidatorAccountsState(rootHash []byte) error MaxHardCapForMissingNodes: e.maxHardCapForMissingNodes, TrieSyncerVersion: e.trieSyncerVersion, CheckNodesOnDisk: e.checkNodesOnDisk, + StorageMarker: storageMarker.NewTrieStorageMarker(), }, } accountsDBSyncer, err := syncer.NewValidatorAccountsSyncer(argsValidatorAccountsSyncer) @@ -1179,7 +1182,7 @@ func (e *epochStartBootstrap) createRequestHandler() error { return err } - requestedItemsHandler := timecache.NewTimeCache(timeBetweenRequests) + requestedItemsHandler := cache.NewTimeCache(timeBetweenRequests) e.requestHandler, err = requestHandlers.NewResolverRequestHandler( finder, requestedItemsHandler, diff --git a/epochStart/bootstrap/shardStorageHandler_test.go b/epochStart/bootstrap/shardStorageHandler_test.go index db5096f8e4c..fc8b741914b 100644 --- a/epochStart/bootstrap/shardStorageHandler_test.go +++ b/epochStart/bootstrap/shardStorageHandler_test.go @@ -140,6 +140,7 @@ func testShardWithMissingStorer(missingUnit dataRetriever.UnitType, atCallNumber } err := shardStorage.SaveDataToStorage(components, components.ShardHeader, false) + require.NotNil(t, err) require.True(t, strings.Contains(err.Error(), storage.ErrKeyNotFound.Error())) require.True(t, strings.Contains(err.Error(), missingUnit.String())) } @@ -1144,7 +1145,7 @@ func createPendingAndProcessedMiniBlocksScenario() scenarioData { expectedPendingMbsWithScheduled := []bootstrapStorage.PendingMiniBlocksInfo{ {ShardID: 0, MiniBlocksHashes: [][]byte{crossMbHeaders[1].Hash, crossMbHeaders[2].Hash, crossMbHeaders[3].Hash, crossMbHeaders[4].Hash, crossMbHeaders[0].Hash}}, } - expectedProcessedMbsWithScheduled := []bootstrapStorage.MiniBlocksInMeta{} + expectedProcessedMbsWithScheduled := make([]bootstrapStorage.MiniBlocksInMeta, 0) headers := map[string]data.HeaderHandler{ lastFinishedMetaBlockHash: &block.MetaBlock{ diff --git a/epochStart/bootstrap/storageProcess.go b/epochStart/bootstrap/storageProcess.go index 249fb3dc71d..9bb14106976 100644 --- a/epochStart/bootstrap/storageProcess.go +++ b/epochStart/bootstrap/storageProcess.go @@ -22,8 +22,8 @@ import ( "github.com/ElrondNetwork/elrond-go/epochStart/bootstrap/disabled" "github.com/ElrondNetwork/elrond-go/epochStart/notifier" "github.com/ElrondNetwork/elrond-go/sharding" + "github.com/ElrondNetwork/elrond-go/storage/cache" storageFactory "github.com/ElrondNetwork/elrond-go/storage/factory" - "github.com/ElrondNetwork/elrond-go/storage/timecache" 
"github.com/ElrondNetwork/elrond-go/trie/factory" ) @@ -206,7 +206,7 @@ func (sesb *storageEpochStartBootstrap) createStorageRequestHandler() error { return err } - requestedItemsHandler := timecache.NewTimeCache(timeBetweenRequests) + requestedItemsHandler := cache.NewTimeCache(timeBetweenRequests) sesb.requestHandler, err = requestHandlers.NewResolverRequestHandler( finder, requestedItemsHandler, diff --git a/epochStart/bootstrap/syncValidatorStatus.go b/epochStart/bootstrap/syncValidatorStatus.go index 23ac78a3841..6bd71e4c142 100644 --- a/epochStart/bootstrap/syncValidatorStatus.go +++ b/epochStart/bootstrap/syncValidatorStatus.go @@ -17,7 +17,7 @@ import ( "github.com/ElrondNetwork/elrond-go/sharding" "github.com/ElrondNetwork/elrond-go/sharding/nodesCoordinator" "github.com/ElrondNetwork/elrond-go/storage" - "github.com/ElrondNetwork/elrond-go/storage/lrucache" + "github.com/ElrondNetwork/elrond-go/storage/cache" "github.com/ElrondNetwork/elrond-go/update" "github.com/ElrondNetwork/elrond-go/update/sync" ) @@ -103,7 +103,7 @@ func NewSyncValidatorStatus(args ArgsNewSyncValidatorStatus) (*syncValidatorStat return nil, err } - consensusGroupCache, err := lrucache.NewCache(consensusGroupCacheSize) + consensusGroupCache, err := cache.NewLRUCache(consensusGroupCacheSize) if err != nil { return nil, err } diff --git a/epochStart/interface.go b/epochStart/interface.go index f02f0b39bca..b7662e960ff 100644 --- a/epochStart/interface.go +++ b/epochStart/interface.go @@ -214,3 +214,9 @@ type EpochNotifier interface { CheckEpoch(epoch uint32) IsInterfaceNil() bool } + +// EpochStartNotifier defines which actions should be done for handling new epoch's events +type EpochStartNotifier interface { + RegisterHandler(handler ActionHandler) + IsInterfaceNil() bool +} diff --git a/epochStart/metachain/baseRewards.go b/epochStart/metachain/baseRewards.go index 8b8a96ddcf2..7c1a18795bb 100644 --- a/epochStart/metachain/baseRewards.go +++ b/epochStart/metachain/baseRewards.go @@ -331,7 +331,7 @@ func (brc *baseRewardsCreator) isSystemDelegationSC(address []byte) bool { return false } - val, err := userAcc.DataTrieTracker().RetrieveValue([]byte(core.DelegationSystemSCKey)) + val, err := userAcc.RetrieveValue([]byte(core.DelegationSystemSCKey)) if err != nil { return false } diff --git a/epochStart/metachain/baseRewards_test.go b/epochStart/metachain/baseRewards_test.go index 3787a9af4c1..f3bcaf61e11 100644 --- a/epochStart/metachain/baseRewards_test.go +++ b/epochStart/metachain/baseRewards_test.go @@ -855,15 +855,12 @@ func TestBaseRewardsCreator_isSystemDelegationSCTrue(t *testing.T) { args.UserAccountsDB = &stateMock.AccountsStub{ GetExistingAccountCalled: func(address []byte) (vmcommon.AccountHandler, error) { return &stateMock.UserAccountStub{ - DataTrieTrackerCalled: func() state.DataTrieTracker { - return &mock.DataTrieTrackerStub{ - RetrieveValueCalled: func(key []byte) ([]byte, error) { - if bytes.Equal(key, []byte("delegation")) { - return []byte("value"), nil - } - return nil, fmt.Errorf("error") - }, + RetrieveValueCalled: func(key []byte) ([]byte, error) { + if bytes.Equal(key, []byte("delegation")) { + return []byte("value"), nil } + + return nil, fmt.Errorf("error") }, }, nil }, diff --git a/epochStart/metachain/epochStartData_test.go b/epochStart/metachain/epochStartData_test.go index d0bcaa48f62..eb1fa61ebfa 100644 --- a/epochStart/metachain/epochStartData_test.go +++ b/epochStart/metachain/epochStartData_test.go @@ -14,8 +14,8 @@ import ( "github.com/ElrondNetwork/elrond-go/process/mock" 
"github.com/ElrondNetwork/elrond-go/sharding" "github.com/ElrondNetwork/elrond-go/storage" - "github.com/ElrondNetwork/elrond-go/storage/memorydb" - "github.com/ElrondNetwork/elrond-go/storage/storageUnit" + "github.com/ElrondNetwork/elrond-go/storage/database" + "github.com/ElrondNetwork/elrond-go/storage/storageunit" "github.com/ElrondNetwork/elrond-go/testscommon" dataRetrieverMock "github.com/ElrondNetwork/elrond-go/testscommon/dataRetriever" "github.com/ElrondNetwork/elrond-go/testscommon/hashingMocks" @@ -84,9 +84,9 @@ func createMemUnit() storage.Storer { capacity := uint32(10) shards := uint32(1) sizeInBytes := uint64(0) - cache, _ := storageUnit.NewCache(storageUnit.CacheConfig{Type: storageUnit.LRUCache, Capacity: capacity, Shards: shards, SizeInBytes: sizeInBytes}) - persist, _ := memorydb.NewlruDB(100000) - unit, _ := storageUnit.NewStorageUnit(cache, persist) + cache, _ := storageunit.NewCache(storageunit.CacheConfig{Type: storageunit.LRUCache, Capacity: capacity, Shards: shards, SizeInBytes: sizeInBytes}) + persist, _ := database.NewlruDB(100000) + unit, _ := storageunit.NewStorageUnit(cache, persist) return unit } diff --git a/epochStart/metachain/rewards_test.go b/epochStart/metachain/rewards_test.go index 517ccc7eb03..5b71e950e4e 100644 --- a/epochStart/metachain/rewards_test.go +++ b/epochStart/metachain/rewards_test.go @@ -767,7 +767,7 @@ func TestRewardsCreator_ValidatorInfoWithMetaAddressAddedToProtocolSustainabilit acc, _ := args.UserAccountsDB.LoadAccount(vm.FirstDelegationSCAddress) userAcc, _ := acc.(state.UserAccountHandler) - _ = userAcc.DataTrieTracker().SaveKeyValue([]byte(core.DelegationSystemSCKey), []byte(core.DelegationSystemSCKey)) + _ = userAcc.SaveKeyValue([]byte(core.DelegationSystemSCKey), []byte(core.DelegationSystemSCKey)) _ = args.UserAccountsDB.SaveAccount(userAcc) miniBlocks, err := rwdc.CreateRewardsMiniBlocks(metaBlk, valInfo, &metaBlk.EpochStart.Economics) diff --git a/epochStart/metachain/stakingDataProvider_test.go b/epochStart/metachain/stakingDataProvider_test.go index a8c6099d5ba..2dacc431d2c 100644 --- a/epochStart/metachain/stakingDataProvider_test.go +++ b/epochStart/metachain/stakingDataProvider_test.go @@ -450,7 +450,7 @@ func saveOutputAccounts(t *testing.T, accountsDB state.AccountsAdapter, vmOutput userAccount, _ := account.(state.UserAccountHandler) for _, storeUpdate := range outputAccount.StorageUpdates { - _ = userAccount.DataTrieTracker().SaveKeyValue(storeUpdate.Offset, storeUpdate.Data) + _ = userAccount.SaveKeyValue(storeUpdate.Offset, storeUpdate.Data) } err := accountsDB.SaveAccount(account) diff --git a/epochStart/metachain/systemSCs.go b/epochStart/metachain/systemSCs.go index 20e82f7899f..26592bb70a0 100644 --- a/epochStart/metachain/systemSCs.go +++ b/epochStart/metachain/systemSCs.go @@ -23,6 +23,7 @@ import ( "github.com/ElrondNetwork/elrond-go/sharding" "github.com/ElrondNetwork/elrond-go/sharding/nodesCoordinator" "github.com/ElrondNetwork/elrond-go/state" + "github.com/ElrondNetwork/elrond-go/trie/keyBuilder" "github.com/ElrondNetwork/elrond-go/vm" "github.com/ElrondNetwork/elrond-go/vm/systemSmartContracts" vmcommon "github.com/ElrondNetwork/elrond-vm-common" @@ -890,7 +891,7 @@ func (s *systemSCProcessor) processSCOutputAccounts( storageUpdates := process.GetSortedStorageUpdates(outAcc) for _, storeUpdate := range storageUpdates { - err = acc.DataTrieTracker().SaveKeyValue(storeUpdate.Offset, storeUpdate.Data) + err = acc.SaveKeyValue(storeUpdate.Offset, storeUpdate.Data) if err != nil { return err } @@ -1100,7 
+1101,7 @@ func (s *systemSCProcessor) getArgumentsForSetOwnerFunctionality(userValidatorAc } chLeaves := make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity) - err = userValidatorAccount.DataTrie().GetAllLeavesOnChannel(chLeaves, context.Background(), rootHash) + err = userValidatorAccount.DataTrie().GetAllLeavesOnChannel(chLeaves, context.Background(), rootHash, keyBuilder.NewKeyBuilder()) if err != nil { return nil, err } diff --git a/epochStart/metachain/systemSCs_test.go b/epochStart/metachain/systemSCs_test.go index ddcf47f874f..057b0c04a43 100644 --- a/epochStart/metachain/systemSCs_test.go +++ b/epochStart/metachain/systemSCs_test.go @@ -41,7 +41,7 @@ import ( "github.com/ElrondNetwork/elrond-go/state/storagePruningManager" "github.com/ElrondNetwork/elrond-go/state/storagePruningManager/evictionWaitingList" "github.com/ElrondNetwork/elrond-go/storage" - "github.com/ElrondNetwork/elrond-go/storage/storageUnit" + "github.com/ElrondNetwork/elrond-go/storage/storageunit" "github.com/ElrondNetwork/elrond-go/testscommon" "github.com/ElrondNetwork/elrond-go/testscommon/cryptoMocks" dataRetrieverMock "github.com/ElrondNetwork/elrond-go/testscommon/dataRetriever" @@ -66,7 +66,7 @@ type testKeyPair struct { } func createPhysicalUnit(t *testing.T) (storage.Storer, string) { - cacheConfig := storageUnit.CacheConfig{ + cacheConfig := storageunit.CacheConfig{ Name: "test", Type: "SizeLRU", SizeInBytes: 314572800, @@ -76,7 +76,7 @@ func createPhysicalUnit(t *testing.T) (storage.Storer, string) { Shards: 0, } dir := t.TempDir() - persisterConfig := storageUnit.ArgDB{ + persisterConfig := storageunit.ArgDB{ Path: dir, DBType: "LvlDBSerial", BatchDelaySeconds: 2, @@ -84,9 +84,9 @@ func createPhysicalUnit(t *testing.T) (storage.Storer, string) { MaxOpenFiles: 10, } - cache, _ := storageUnit.NewCache(cacheConfig) - persist, _ := storageUnit.NewDB(persisterConfig) - unit, _ := storageUnit.NewStorageUnit(cache, persist) + cache, _ := storageunit.NewCache(cacheConfig) + persist, _ := storageunit.NewDB(persisterConfig) + unit, _ := storageunit.NewStorageUnit(cache, persist) return unit, dir } @@ -265,7 +265,7 @@ func checkNodesStatusInSystemSCDataTrie(t *testing.T, nodes []*state.ValidatorIn systemScAccount, ok := account.(state.UserAccountHandler) require.True(t, ok) for _, nodeInfo := range nodes { - buff, err = systemScAccount.DataTrieTracker().RetrieveValue(nodeInfo.PublicKey) + buff, err = systemScAccount.RetrieveValue(nodeInfo.PublicKey) require.Nil(t, err) require.True(t, len(buff) > 0) @@ -547,7 +547,7 @@ func createEligibleNodes(numNodes int, stakingSCAcc state.UserAccountHandler, ma StakeValue: big.NewInt(100), } marshaledData, _ := marshalizer.Marshal(stakedData) - _ = stakingSCAcc.DataTrieTracker().SaveKeyValue([]byte(fmt.Sprintf("waiting_%d", i)), marshaledData) + _ = stakingSCAcc.SaveKeyValue([]byte(fmt.Sprintf("waiting_%d", i)), marshaledData) } } @@ -562,7 +562,7 @@ func createJailedNodes(numNodes int, stakingSCAcc state.UserAccountHandler, user OwnerAddress: []byte("ownerForAll"), } marshaledData, _ := marshalizer.Marshal(stakedData) - _ = stakingSCAcc.DataTrieTracker().SaveKeyValue([]byte(fmt.Sprintf("jailed__%d", i)), marshaledData) + _ = stakingSCAcc.SaveKeyValue([]byte(fmt.Sprintf("jailed__%d", i)), marshaledData) _ = userAccounts.SaveAccount(stakingSCAcc) @@ -629,11 +629,11 @@ func addValidatorDataWithUnStakedKey( OwnerAddress: ownerKey, } marshaledData, _ := marshalizer.Marshal(stakingData) - _ = stakingAccount.DataTrieTracker().SaveKeyValue(bls, marshaledData) + 
_ = stakingAccount.SaveKeyValue(bls, marshaledData) } marshaledData, _ := marshalizer.Marshal(validatorData) - _ = validatorAccount.DataTrieTracker().SaveKeyValue(ownerKey, marshaledData) + _ = validatorAccount.SaveKeyValue(ownerKey, marshaledData) _ = accountsDB.SaveAccount(validatorAccount) _ = accountsDB.SaveAccount(stakingAccount) @@ -650,7 +650,7 @@ func createWaitingNodes(numNodes int, stakingSCAcc state.UserAccountHandler, use StakeValue: big.NewInt(100), } marshaledData, _ := marshalizer.Marshal(stakedData) - _ = stakingSCAcc.DataTrieTracker().SaveKeyValue([]byte(fmt.Sprintf("waiting_%d", i)), marshaledData) + _ = stakingSCAcc.SaveKeyValue([]byte(fmt.Sprintf("waiting_%d", i)), marshaledData) previousKey := string(waitingKeyInList) waitingKeyInList = []byte("w_" + fmt.Sprintf("waiting_%d", i)) waitingListHead := &systemSmartContracts.WaitingList{ @@ -659,7 +659,7 @@ func createWaitingNodes(numNodes int, stakingSCAcc state.UserAccountHandler, use Length: uint32(numNodes), } marshaledData, _ = marshalizer.Marshal(waitingListHead) - _ = stakingSCAcc.DataTrieTracker().SaveKeyValue([]byte("waitingList"), marshaledData) + _ = stakingSCAcc.SaveKeyValue([]byte("waitingList"), marshaledData) waitingListElement := &systemSmartContracts.ElementInList{ BLSPublicKey: []byte(fmt.Sprintf("waiting_%d", i)), @@ -674,7 +674,7 @@ func createWaitingNodes(numNodes int, stakingSCAcc state.UserAccountHandler, use } marshaledData, _ = marshalizer.Marshal(waitingListElement) - _ = stakingSCAcc.DataTrieTracker().SaveKeyValue(waitingKeyInList, marshaledData) + _ = stakingSCAcc.SaveKeyValue(waitingKeyInList, marshaledData) vInfo := &state.ValidatorInfo{ PublicKey: []byte(fmt.Sprintf("waiting_%d", i)), @@ -713,7 +713,7 @@ func addValidatorData( } marshaledData, _ := marshalizer.Marshal(validatorData) - _ = validatorSC.DataTrieTracker().SaveKeyValue(ownerKey, marshaledData) + _ = validatorSC.SaveKeyValue(ownerKey, marshaledData) _ = accountsDB.SaveAccount(validatorSC) } @@ -732,7 +732,7 @@ func addStakedData( StakeValue: big.NewInt(0), } marshaledData, _ := marshalizer.Marshal(stakedData) - _ = stakingSCAcc.DataTrieTracker().SaveKeyValue(stakedKey, marshaledData) + _ = stakingSCAcc.SaveKeyValue(stakedKey, marshaledData) _ = accountsDB.SaveAccount(stakingSCAcc) } @@ -754,7 +754,7 @@ func prepareStakingContractWithData( StakeValue: big.NewInt(100), } marshaledData, _ := marshalizer.Marshal(stakedData) - _ = stakingSCAcc.DataTrieTracker().SaveKeyValue(stakedKey, marshaledData) + _ = stakingSCAcc.SaveKeyValue(stakedKey, marshaledData) _ = accountsDB.SaveAccount(stakingSCAcc) saveOneKeyToWaitingList(accountsDB, waitingKey, marshalizer, rewardAddress, ownerAddress) @@ -772,7 +772,7 @@ func prepareStakingContractWithData( } marshaledData, _ = marshalizer.Marshal(validatorData) - _ = validatorSC.DataTrieTracker().SaveKeyValue(rewardAddress, marshaledData) + _ = validatorSC.SaveKeyValue(rewardAddress, marshaledData) _ = accountsDB.SaveAccount(validatorSC) _, err := accountsDB.Commit() @@ -794,7 +794,7 @@ func saveOneKeyToWaitingList( StakeValue: big.NewInt(100), } marshaledData, _ := marshalizer.Marshal(stakedData) - _ = stakingSCAcc.DataTrieTracker().SaveKeyValue(waitingKey, marshaledData) + _ = stakingSCAcc.SaveKeyValue(waitingKey, marshaledData) waitingKeyInList := []byte("w_" + string(waitingKey)) waitingListHead := &systemSmartContracts.WaitingList{ @@ -803,7 +803,7 @@ func saveOneKeyToWaitingList( Length: 1, } marshaledData, _ = marshalizer.Marshal(waitingListHead) - _ = 
stakingSCAcc.DataTrieTracker().SaveKeyValue([]byte("waitingList"), marshaledData) + _ = stakingSCAcc.SaveKeyValue([]byte("waitingList"), marshaledData) waitingListElement := &systemSmartContracts.ElementInList{ BLSPublicKey: waitingKey, @@ -811,7 +811,7 @@ func saveOneKeyToWaitingList( NextKey: make([]byte, 0), } marshaledData, _ = marshalizer.Marshal(waitingListElement) - _ = stakingSCAcc.DataTrieTracker().SaveKeyValue(waitingKeyInList, marshaledData) + _ = stakingSCAcc.SaveKeyValue(waitingKeyInList, marshaledData) _ = accountsDB.SaveAccount(stakingSCAcc) } @@ -833,10 +833,10 @@ func addKeysToWaitingList( StakeValue: big.NewInt(100), } marshaledData, _ := marshalizer.Marshal(stakedData) - _ = stakingSCAcc.DataTrieTracker().SaveKeyValue(waitingKey, marshaledData) + _ = stakingSCAcc.SaveKeyValue(waitingKey, marshaledData) } - marshaledData, _ := stakingSCAcc.DataTrieTracker().RetrieveValue([]byte("waitingList")) + marshaledData, _ := stakingSCAcc.RetrieveValue([]byte("waitingList")) waitingListHead := &systemSmartContracts.WaitingList{} _ = marshalizer.Unmarshal(waitingListHead, marshaledData) waitingListHead.Length += uint32(len(waitingKeys)) @@ -844,7 +844,7 @@ func addKeysToWaitingList( waitingListHead.LastKey = lastKeyInList marshaledData, _ = marshalizer.Marshal(waitingListHead) - _ = stakingSCAcc.DataTrieTracker().SaveKeyValue([]byte("waitingList"), marshaledData) + _ = stakingSCAcc.SaveKeyValue([]byte("waitingList"), marshaledData) numWaitingKeys := len(waitingKeys) previousKey := waitingListHead.FirstKey @@ -863,17 +863,17 @@ func addKeysToWaitingList( } marshaledData, _ = marshalizer.Marshal(waitingListElement) - _ = stakingSCAcc.DataTrieTracker().SaveKeyValue(waitingKeyInList, marshaledData) + _ = stakingSCAcc.SaveKeyValue(waitingKeyInList, marshaledData) previousKey = waitingKeyInList } - marshaledData, _ = stakingSCAcc.DataTrieTracker().RetrieveValue(waitingListHead.FirstKey) + marshaledData, _ = stakingSCAcc.RetrieveValue(waitingListHead.FirstKey) waitingListElement := &systemSmartContracts.ElementInList{} _ = marshalizer.Unmarshal(waitingListElement, marshaledData) waitingListElement.NextKey = []byte("w_" + string(waitingKeys[0])) marshaledData, _ = marshalizer.Marshal(waitingListElement) - _ = stakingSCAcc.DataTrieTracker().SaveKeyValue(waitingListHead.FirstKey, marshaledData) + _ = stakingSCAcc.SaveKeyValue(waitingListHead.FirstKey, marshaledData) _ = accountsDB.SaveAccount(stakingSCAcc) } @@ -1484,7 +1484,7 @@ func addDelegationData( } marshaledData, _ := marshalizer.Marshal(dStatus) - _ = delegatorSC.DataTrieTracker().SaveKeyValue([]byte("delegationStatus"), marshaledData) + _ = delegatorSC.SaveKeyValue([]byte("delegationStatus"), marshaledData) _ = accountsDB.SaveAccount(delegatorSC) } @@ -1563,7 +1563,7 @@ func TestSystemSCProcessor_ProcessSystemSmartContractUnStakeFromDelegationContra assert.Equal(t, 4, len(validatorInfos[0])) delegationSC := loadSCAccount(args.UserAccountsDB, delegationAddr) - marshalledData, err := delegationSC.DataTrie().Get([]byte("delegationStatus")) + marshalledData, err := delegationSC.DataTrie().(common.Trie).Get([]byte("delegationStatus")) assert.Nil(t, err) dStatus := &systemSmartContracts.DelegationContractStatus{ StakedKeys: make([]*systemSmartContracts.NodesData, 0), @@ -1652,7 +1652,7 @@ func TestSystemSCProcessor_ProcessSystemSmartContractShouldUnStakeFromAdditional } delegationSC := loadSCAccount(args.UserAccountsDB, delegationAddr) - marshalledData, err := delegationSC.DataTrie().Get([]byte("delegationStatus")) + marshalledData, err := 
delegationSC.DataTrie().(common.Trie).Get([]byte("delegationStatus")) assert.Nil(t, err) dStatus := &systemSmartContracts.DelegationContractStatus{ StakedKeys: make([]*systemSmartContracts.NodesData, 0), @@ -1742,7 +1742,7 @@ func TestSystemSCProcessor_ProcessSystemSmartContractUnStakeFromAdditionalQueue( assert.Nil(t, err) delegationSC := loadSCAccount(args.UserAccountsDB, delegationAddr2) - marshalledData, err := delegationSC.DataTrie().Get([]byte("delegationStatus")) + marshalledData, err := delegationSC.DataTrie().(common.Trie).Get([]byte("delegationStatus")) assert.Nil(t, err) dStatus := &systemSmartContracts.DelegationContractStatus{ StakedKeys: make([]*systemSmartContracts.NodesData, 0), @@ -1758,7 +1758,7 @@ func TestSystemSCProcessor_ProcessSystemSmartContractUnStakeFromAdditionalQueue( assert.Equal(t, []byte("waitingPubKe3"), dStatus.UnStakedKeys[1].BLSKey) stakingSCAcc := loadSCAccount(args.UserAccountsDB, vm.StakingSCAddress) - marshaledData, _ := stakingSCAcc.DataTrieTracker().RetrieveValue([]byte("waitingList")) + marshaledData, _ := stakingSCAcc.RetrieveValue([]byte("waitingList")) waitingListHead := &systemSmartContracts.WaitingList{} _ = args.Marshalizer.Unmarshal(waitingListHead, marshaledData) assert.Equal(t, uint32(3), waitingListHead.Length) @@ -1827,14 +1827,14 @@ func TestSystemSCProcessor_TogglePauseUnPause(t *testing.T) { assert.Nil(t, err) validatorSC := loadSCAccount(s.userAccountsDB, vm.ValidatorSCAddress) - value, _ := validatorSC.DataTrie().Get([]byte("unStakeUnBondPause")) + value, _ := validatorSC.DataTrie().(common.Trie).Get([]byte("unStakeUnBondPause")) assert.True(t, value[0] == 1) err = s.ToggleUnStakeUnBond(false) assert.Nil(t, err) validatorSC = loadSCAccount(s.userAccountsDB, vm.ValidatorSCAddress) - value, _ = validatorSC.DataTrie().Get([]byte("unStakeUnBondPause")) + value, _ = validatorSC.DataTrie().(common.Trie).Get([]byte("unStakeUnBondPause")) assert.True(t, value[0] == 0) } diff --git a/epochStart/mock/cryptoComponentsMock.go b/epochStart/mock/cryptoComponentsMock.go index afbcb00a382..1d65646728c 100644 --- a/epochStart/mock/cryptoComponentsMock.go +++ b/epochStart/mock/cryptoComponentsMock.go @@ -1,21 +1,23 @@ package mock import ( + "errors" "sync" "github.com/ElrondNetwork/elrond-go-crypto" + cryptoCommon "github.com/ElrondNetwork/elrond-go/common/crypto" ) // CryptoComponentsMock - type CryptoComponentsMock struct { - PubKey crypto.PublicKey - BlockSig crypto.SingleSigner - TxSig crypto.SingleSigner - MultiSig crypto.MultiSigner - PeerSignHandler crypto.PeerSignatureHandler - BlKeyGen crypto.KeyGenerator - TxKeyGen crypto.KeyGenerator - mutCrypto sync.RWMutex + PubKey crypto.PublicKey + BlockSig crypto.SingleSigner + TxSig crypto.SingleSigner + MultiSigContainer cryptoCommon.MultiSignerContainer + PeerSignHandler crypto.PeerSignatureHandler + BlKeyGen crypto.KeyGenerator + TxKeyGen crypto.KeyGenerator + mutCrypto sync.RWMutex } // PublicKey - @@ -33,20 +35,32 @@ func (ccm *CryptoComponentsMock) TxSingleSigner() crypto.SingleSigner { return ccm.TxSig } -// MultiSigner - -func (ccm *CryptoComponentsMock) MultiSigner() crypto.MultiSigner { +// GetMultiSigner - +func (ccm *CryptoComponentsMock) GetMultiSigner(epoch uint32) (crypto.MultiSigner, error) { ccm.mutCrypto.RLock() defer ccm.mutCrypto.RUnlock() - return ccm.MultiSig + if ccm.MultiSigContainer == nil { + return nil, errors.New("multisigner container is nil") + } + + return ccm.MultiSigContainer.GetMultiSigner(epoch) +} + +// MultiSignerContainer - +func (ccm *CryptoComponentsMock) 
MultiSignerContainer() cryptoCommon.MultiSignerContainer { + ccm.mutCrypto.RLock() + defer ccm.mutCrypto.RUnlock() + + return ccm.MultiSigContainer } -// SetMultiSigner - -func (ccm *CryptoComponentsMock) SetMultiSigner(m crypto.MultiSigner) error { +// SetMultiSignerContainer - +func (ccm *CryptoComponentsMock) SetMultiSignerContainer(msc cryptoCommon.MultiSignerContainer) error { ccm.mutCrypto.Lock() - ccm.MultiSig = m - ccm.mutCrypto.Unlock() + defer ccm.mutCrypto.Unlock() + ccm.MultiSigContainer = msc return nil } @@ -68,14 +82,14 @@ func (ccm *CryptoComponentsMock) TxSignKeyGen() crypto.KeyGenerator { // Clone - func (ccm *CryptoComponentsMock) Clone() interface{} { return &CryptoComponentsMock{ - PubKey: ccm.PubKey, - BlockSig: ccm.BlockSig, - TxSig: ccm.TxSig, - MultiSig: ccm.MultiSig, - PeerSignHandler: ccm.PeerSignHandler, - BlKeyGen: ccm.BlKeyGen, - TxKeyGen: ccm.TxKeyGen, - mutCrypto: sync.RWMutex{}, + PubKey: ccm.PubKey, + BlockSig: ccm.BlockSig, + TxSig: ccm.TxSig, + MultiSigContainer: ccm.MultiSigContainer, + PeerSignHandler: ccm.PeerSignHandler, + BlKeyGen: ccm.BlKeyGen, + TxKeyGen: ccm.TxKeyGen, + mutCrypto: sync.RWMutex{}, } } diff --git a/epochStart/mock/dataTrieTrackerStub.go b/epochStart/mock/dataTrieTrackerStub.go deleted file mode 100644 index 2616d46468e..00000000000 --- a/epochStart/mock/dataTrieTrackerStub.go +++ /dev/null @@ -1,66 +0,0 @@ -package mock - -import ( - "github.com/ElrondNetwork/elrond-go/common" -) - -// DataTrieTrackerStub - -type DataTrieTrackerStub struct { - ClearDataCachesCalled func() - DirtyDataCalled func() map[string][]byte - RetrieveValueCalled func(key []byte) ([]byte, error) - SaveKeyValueCalled func(key []byte, value []byte) error - SetDataTrieCalled func(tr common.Trie) - DataTrieCalled func() common.Trie -} - -// ClearDataCaches - -func (dtts *DataTrieTrackerStub) ClearDataCaches() { - if dtts.ClearDataCachesCalled != nil { - dtts.ClearDataCachesCalled() - } -} - -// DirtyData - -func (dtts *DataTrieTrackerStub) DirtyData() map[string][]byte { - if dtts.DirtyDataCalled != nil { - return dtts.DirtyDataCalled() - } - return nil -} - -// RetrieveValue - -func (dtts *DataTrieTrackerStub) RetrieveValue(key []byte) ([]byte, error) { - if dtts.RetrieveValueCalled != nil { - return dtts.RetrieveValueCalled(key) - } - return nil, nil -} - -// SaveKeyValue - -func (dtts *DataTrieTrackerStub) SaveKeyValue(key []byte, value []byte) error { - if dtts.SaveKeyValueCalled != nil { - return dtts.SaveKeyValueCalled(key, value) - } - return nil -} - -// SetDataTrie - -func (dtts *DataTrieTrackerStub) SetDataTrie(tr common.Trie) { - if dtts.SetDataTrieCalled != nil { - dtts.SetDataTrieCalled(tr) - } -} - -// DataTrie - -func (dtts *DataTrieTrackerStub) DataTrie() common.Trie { - if dtts.DataTrieCalled != nil { - return dtts.DataTrieCalled() - } - return nil -} - -// IsInterfaceNil returns true if there is no value under the interface -func (dtts *DataTrieTrackerStub) IsInterfaceNil() bool { - return dtts == nil -} diff --git a/epochStart/shardchain/trigger_test.go b/epochStart/shardchain/trigger_test.go index 04f761071a4..3ba04c0953a 100644 --- a/epochStart/shardchain/trigger_test.go +++ b/epochStart/shardchain/trigger_test.go @@ -189,6 +189,7 @@ func testWithMissingStorer(missingUnit dataRetriever.UnitType) func(t *testing.T } epochStartTrigger, err := NewEpochStartTrigger(args) + require.NotNil(t, err) require.True(t, strings.Contains(err.Error(), storage.ErrKeyNotFound.Error())) require.True(t, strings.Contains(err.Error(), missingUnit.String())) 
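An aside on the `require.NotNil(t, err)` guards added in this test and in the storage handler tests earlier: calling `err.Error()` on a nil error panics and aborts the whole test binary instead of reporting a clean failure. A minimal illustration, with `doSomething` a hypothetical function under test:

```go
package example

import (
	"errors"
	"strings"
	"testing"

	"github.com/stretchr/testify/require"
)

// doSomething stands in for the call whose error message is inspected
func doSomething() error {
	return errors.New("boom: key not found")
}

func TestErrorMessageContainsFragment(t *testing.T) {
	err := doSomething()

	// guard first: without it, a nil error would panic inside err.Error()
	// below and kill the run rather than fail this one test
	require.NotNil(t, err)
	require.True(t, strings.Contains(err.Error(), "key not found"))
}
```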
require.True(t, check.IfNil(epochStartTrigger)) diff --git a/errors/closingError.go b/errors/closingError.go index 8360e28e202..2afc8f1bf47 100644 --- a/errors/closingError.go +++ b/errors/closingError.go @@ -2,6 +2,8 @@ package errors import ( "strings" + + "github.com/ElrondNetwork/elrond-go/storage" ) // IsClosingError returns true if the provided error is used whenever the node is in the closing process @@ -10,6 +12,6 @@ func IsClosingError(err error) bool { return false } - return strings.Contains(err.Error(), ErrDBIsClosed.Error()) || + return strings.Contains(err.Error(), storage.ErrDBIsClosed.Error()) || strings.Contains(err.Error(), ErrContextClosing.Error()) } diff --git a/errors/closingError_test.go b/errors/closingError_test.go index ee8c70ea8be..95c6b5f1927 100644 --- a/errors/closingError_test.go +++ b/errors/closingError_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/ElrondNetwork/elrond-go/errors" + "github.com/ElrondNetwork/elrond-go/storage" "github.com/stretchr/testify/assert" ) @@ -24,7 +25,7 @@ func TestIsClosingError(t *testing.T) { t.Run("DB closed error should return true", func(t *testing.T) { t.Parallel() - assert.True(t, errors.IsClosingError(fmt.Errorf("%w random string", errors.ErrDBIsClosed))) + assert.True(t, errors.IsClosingError(fmt.Errorf("%w random string", storage.ErrDBIsClosed))) }) t.Run("contains 'DB is closed' should return true", func(t *testing.T) { t.Parallel() diff --git a/errors/errors.go b/errors/errors.go index 6fe1d8bf0e7..af6889d578f 100644 --- a/errors/errors.go +++ b/errors/errors.go @@ -77,6 +77,9 @@ var ErrNilBlockTracker = errors.New("trying to set nil block tracker") // ErrNilBootStorer signals that the provided boot storer is nil var ErrNilBootStorer = errors.New("nil boot storer") +// ErrNilBootstrapComponents signals that the provided instance of bootstrap components is nil +var ErrNilBootstrapComponents = errors.New("nil bootstrap components") + // ErrNilBootstrapComponentsHolder signals that the provided bootstrap components holder is nil var ErrNilBootstrapComponentsHolder = errors.New("nil bootstrap components holder") @@ -476,11 +479,44 @@ var ErrNilTxsSender = errors.New("nil transactions sender has been provided") // ErrNilProcessStatusHandler signals that a nil process status handler was provided var ErrNilProcessStatusHandler = errors.New("nil process status handler") -// ErrDBIsClosed is raised when the DB is closed -var ErrDBIsClosed = errors.New("DB is closed") - // ErrNilEnableEpochsHandler signals that a nil enable epochs handler was provided var ErrNilEnableEpochsHandler = errors.New("nil enable epochs handler") +// ErrSignerNotSupported signals that a not supported signer was provided +var ErrSignerNotSupported = errors.New("signer not supported") + +// ErrMissingMultiSignerConfig signals that the multisigner config is missing +var ErrMissingMultiSignerConfig = errors.New("multisigner configuration missing") + +// ErrMissingMultiSigner signals that there is no multisigner instance available +var ErrMissingMultiSigner = errors.New("multisigner instance missing") + +// ErrMissingEpochZeroMultiSignerConfig signals that the multisigner config for epoch zero is missing +var ErrMissingEpochZeroMultiSignerConfig = errors.New("multisigner configuration missing for epoch zero") + +// ErrNilMultiSignerContainer signals that the multisigner container is nil +var ErrNilMultiSignerContainer = errors.New("multisigner container is nil") + +// ErrNilCacher signals that a nil cacher has been provided +var ErrNilCacher = 
errors.New("nil cacher") + +// ErrNilSingleSigner is raised when a valid singleSigner is expected but nil used +var ErrNilSingleSigner = errors.New("singleSigner is nil") + +// ErrPIDMismatch signals that the pid from the message is different from the cached pid associated to a certain pk +var ErrPIDMismatch = errors.New("pid mismatch") + +// ErrSignatureMismatch signals that the signature from the message is different from the cached signature associated to a certain pk +var ErrSignatureMismatch = errors.New("signature mismatch") + +// ErrInvalidPID signals that given PID is invalid +var ErrInvalidPID = errors.New("invalid PID") + +// ErrInvalidSignature signals that the given signature is invalid +var ErrInvalidSignature = errors.New("invalid signature") + +// ErrInvalidHeartbeatV2Config signals that an invalid heartbeat v2 configuration has been provided +var ErrInvalidHeartbeatV2Config = errors.New("invalid heartbeat v2 configuration") + // ErrNilManagedPeersHolder signals that a nil managed peers holder has been provided var ErrNilManagedPeersHolder = errors.New("nil managed peers holder") diff --git a/factory/addressDecoder.go b/factory/addressDecoder.go index 83508413a4d..b8ed6050a50 100644 --- a/factory/addressDecoder.go +++ b/factory/addressDecoder.go @@ -6,7 +6,8 @@ import ( "github.com/ElrondNetwork/elrond-go/errors" ) -func decodeAddresses(pkConverter core.PubkeyConverter, stringAddresses []string) ([][]byte, error) { +// DecodeAddresses will decode the provided string addresses +func DecodeAddresses(pkConverter core.PubkeyConverter, stringAddresses []string) ([][]byte, error) { if check.IfNil(pkConverter) { return nil, errors.ErrNilPubKeyConverter } diff --git a/factory/apiResolverFactory.go b/factory/api/apiResolverFactory.go similarity index 93% rename from factory/apiResolverFactory.go rename to factory/api/apiResolverFactory.go index 2d4d8b5f6c6..d388ef8e9ed 100644 --- a/factory/apiResolverFactory.go +++ b/factory/api/apiResolverFactory.go @@ -1,4 +1,4 @@ -package factory +package api import ( "errors" @@ -8,9 +8,11 @@ import ( "github.com/ElrondNetwork/elrond-go-core/core" "github.com/ElrondNetwork/elrond-go-core/marshal" + logger "github.com/ElrondNetwork/elrond-go-logger" "github.com/ElrondNetwork/elrond-go/common" "github.com/ElrondNetwork/elrond-go/config" "github.com/ElrondNetwork/elrond-go/facade" + "github.com/ElrondNetwork/elrond-go/factory" "github.com/ElrondNetwork/elrond-go/node/external" "github.com/ElrondNetwork/elrond-go/node/external/blockAPI" "github.com/ElrondNetwork/elrond-go/node/external/logs" @@ -31,22 +33,24 @@ import ( "github.com/ElrondNetwork/elrond-go/sharding" "github.com/ElrondNetwork/elrond-go/state" storageFactory "github.com/ElrondNetwork/elrond-go/storage/factory" - "github.com/ElrondNetwork/elrond-go/storage/storageUnit" + "github.com/ElrondNetwork/elrond-go/storage/storageunit" "github.com/ElrondNetwork/elrond-go/vm" vmcommon "github.com/ElrondNetwork/elrond-vm-common" "github.com/ElrondNetwork/elrond-vm-common/parsers" datafield "github.com/ElrondNetwork/elrond-vm-common/parsers/dataField" ) +var log = logger.GetOrCreate("factory") + // ApiResolverArgs holds the argument needed to create an API resolver type ApiResolverArgs struct { Configs *config.Configs - CoreComponents CoreComponentsHolder - DataComponents DataComponentsHolder - StateComponents StateComponentsHolder - BootstrapComponents BootstrapComponentsHolder - CryptoComponents CryptoComponentsHolder - ProcessComponents ProcessComponentsHolder + CoreComponents 
factory.CoreComponentsHolder + DataComponents factory.DataComponentsHolder + StateComponents factory.StateComponentsHolder + BootstrapComponents factory.BootstrapComponentsHolder + CryptoComponents factory.CryptoComponentsHolder + ProcessComponents factory.ProcessComponentsHolder GasScheduleNotifier common.GasScheduleNotifierAPI Bootstrapper process.Bootstrapper AllowVMQueriesChan chan struct{} @@ -55,10 +59,10 @@ type ApiResolverArgs struct { type scQueryServiceArgs struct { generalConfig *config.Config epochConfig *config.EpochConfig - coreComponents CoreComponentsHolder - stateComponents StateComponentsHolder - dataComponents DataComponentsHolder - processComponents ProcessComponentsHolder + coreComponents factory.CoreComponentsHolder + stateComponents factory.StateComponentsHolder + dataComponents factory.DataComponentsHolder + processComponents factory.ProcessComponentsHolder gasScheduleNotifier core.GasScheduleNotifier messageSigVerifier vm.MessageSignVerifier systemSCConfig *config.SystemSmartContractsConfig @@ -70,10 +74,10 @@ type scQueryServiceArgs struct { type scQueryElementArgs struct { generalConfig *config.Config epochConfig *config.EpochConfig - coreComponents CoreComponentsHolder - stateComponents StateComponentsHolder - dataComponents DataComponentsHolder - processComponents ProcessComponentsHolder + coreComponents factory.CoreComponentsHolder + stateComponents factory.StateComponentsHolder + dataComponents factory.DataComponentsHolder + processComponents factory.ProcessComponentsHolder gasScheduleNotifier core.GasScheduleNotifier messageSigVerifier vm.MessageSignVerifier systemSCConfig *config.SystemSmartContractsConfig @@ -109,7 +113,7 @@ func CreateApiResolver(args *ApiResolverArgs) (facade.ApiResolver, error) { pkConverter := args.CoreComponents.AddressPubKeyConverter() automaticCrawlerAddressesStrings := args.Configs.GeneralConfig.BuiltInFunctions.AutomaticCrawlerAddresses - convertedAddresses, errDecode := decodeAddresses(pkConverter, automaticCrawlerAddressesStrings) + convertedAddresses, errDecode := factory.DecodeAddresses(pkConverter, automaticCrawlerAddressesStrings) if errDecode != nil { return nil, errDecode } @@ -319,7 +323,7 @@ func createScQueryElement( pkConverter := args.coreComponents.AddressPubKeyConverter() automaticCrawlerAddressesStrings := args.generalConfig.BuiltInFunctions.AutomaticCrawlerAddresses - convertedAddresses, errDecode := decodeAddresses(pkConverter, automaticCrawlerAddressesStrings) + convertedAddresses, errDecode := factory.DecodeAddresses(pkConverter, automaticCrawlerAddressesStrings) if errDecode != nil { return nil, errDecode } @@ -339,7 +343,7 @@ func createScQueryElement( } cacherCfg := storageFactory.GetCacherFromConfig(args.generalConfig.SmartContractDataPool) - smartContractsCache, err := storageUnit.NewCache(cacherCfg) + smartContractsCache, err := storageunit.NewCache(cacherCfg) if err != nil { return nil, err } @@ -528,7 +532,7 @@ func createAPIBlockProcessorArgs(args *ApiResolverArgs, apiTransactionHandler ex return blockApiArgs, nil } -func createLogsFacade(args *ApiResolverArgs) (LogsFacade, error) { +func createLogsFacade(args *ApiResolverArgs) (factory.LogsFacade, error) { return logs.NewLogsFacade(logs.ArgsNewLogsFacade{ StorageService: args.DataComponents.StorageService(), Marshaller: args.CoreComponents.InternalMarshalizer(), diff --git a/factory/apiResolverFactory_test.go b/factory/api/apiResolverFactory_test.go similarity index 54% rename from factory/apiResolverFactory_test.go rename to 
factory/api/apiResolverFactory_test.go index 567ccda0e04..f39e5a68561 100644 --- a/factory/apiResolverFactory_test.go +++ b/factory/api/apiResolverFactory_test.go @@ -1,14 +1,16 @@ -package factory_test +package api_test import ( "testing" "github.com/ElrondNetwork/elrond-go/common" "github.com/ElrondNetwork/elrond-go/config" - "github.com/ElrondNetwork/elrond-go/factory" + "github.com/ElrondNetwork/elrond-go/factory/api" + "github.com/ElrondNetwork/elrond-go/factory/bootstrap" "github.com/ElrondNetwork/elrond-go/factory/mock" "github.com/ElrondNetwork/elrond-go/process/sync/disabled" "github.com/ElrondNetwork/elrond-go/testscommon" + componentsMock "github.com/ElrondNetwork/elrond-go/testscommon/components" "github.com/stretchr/testify/require" ) @@ -18,25 +20,25 @@ func TestCreateApiResolver(t *testing.T) { } shardCoordinator := mock.NewMultiShardsCoordinatorMock(1) - coreComponents := getCoreComponents() + coreComponents := componentsMock.GetCoreComponents() coreComponents.StatusHandlerUtils().Metrics() - networkComponents := getNetworkComponents() - dataComponents := getDataComponents(coreComponents, shardCoordinator) - cryptoComponents := getCryptoComponents(coreComponents) - stateComponents := getStateComponents(coreComponents, shardCoordinator) - processComponents := getProcessComponents(shardCoordinator, coreComponents, networkComponents, dataComponents, cryptoComponents, stateComponents) - argsB := getBootStrapArgs() + networkComponents := componentsMock.GetNetworkComponents() + dataComponents := componentsMock.GetDataComponents(coreComponents, shardCoordinator) + cryptoComponents := componentsMock.GetCryptoComponents(coreComponents) + stateComponents := componentsMock.GetStateComponents(coreComponents, shardCoordinator) + processComponents := componentsMock.GetProcessComponents(shardCoordinator, coreComponents, networkComponents, dataComponents, cryptoComponents, stateComponents) + argsB := componentsMock.GetBootStrapFactoryArgs() - bcf, _ := factory.NewBootstrapComponentsFactory(argsB) - mbc, err := factory.NewManagedBootstrapComponents(bcf) + bcf, _ := bootstrap.NewBootstrapComponentsFactory(argsB) + mbc, err := bootstrap.NewManagedBootstrapComponents(bcf) require.Nil(t, err) err = mbc.Create() require.Nil(t, err) - gasSchedule, _ := common.LoadGasScheduleConfig("../cmd/node/config/gasSchedules/gasScheduleV1.toml") + gasSchedule, _ := common.LoadGasScheduleConfig("../../cmd/node/config/gasSchedules/gasScheduleV1.toml") economicsConfig := testscommon.GetEconomicsConfig() - cfg := getGeneralConfig() - args := &factory.ApiResolverArgs{ + cfg := componentsMock.GetGeneralConfig() + args := &api.ApiResolverArgs{ Configs: &config.Configs{ FlagsConfig: &config.ContextFlagsConfig{ WorkingDir: "", @@ -58,7 +60,7 @@ func TestCreateApiResolver(t *testing.T) { AllowVMQueriesChan: common.GetClosedUnbufferedChannel(), } - apiResolver, err := factory.CreateApiResolver(args) + apiResolver, err := api.CreateApiResolver(args) require.Nil(t, err) require.NotNil(t, apiResolver) } diff --git a/factory/bootstrapComponents.go b/factory/bootstrap/bootstrapComponents.go similarity index 85% rename from factory/bootstrapComponents.go rename to factory/bootstrap/bootstrapComponents.go index 34a4d459364..2995e46af77 100644 --- a/factory/bootstrapComponents.go +++ b/factory/bootstrap/bootstrapComponents.go @@ -1,4 +1,4 @@ -package factory +package bootstrap import ( "fmt" @@ -6,22 +6,26 @@ import ( "github.com/ElrondNetwork/elrond-go-core/core" "github.com/ElrondNetwork/elrond-go-core/core/check" - 
"github.com/ElrondNetwork/elrond-go/cmd/node/factory" + logger "github.com/ElrondNetwork/elrond-go-logger" + nodeFactory "github.com/ElrondNetwork/elrond-go/cmd/node/factory" "github.com/ElrondNetwork/elrond-go/common" "github.com/ElrondNetwork/elrond-go/config" "github.com/ElrondNetwork/elrond-go/epochStart/bootstrap" "github.com/ElrondNetwork/elrond-go/errors" + "github.com/ElrondNetwork/elrond-go/factory" "github.com/ElrondNetwork/elrond-go/factory/block" "github.com/ElrondNetwork/elrond-go/process/headerCheck" "github.com/ElrondNetwork/elrond-go/process/smartContract" "github.com/ElrondNetwork/elrond-go/sharding" "github.com/ElrondNetwork/elrond-go/storage" + "github.com/ElrondNetwork/elrond-go/storage/directoryhandler" storageFactory "github.com/ElrondNetwork/elrond-go/storage/factory" - "github.com/ElrondNetwork/elrond-go/storage/factory/directoryhandler" "github.com/ElrondNetwork/elrond-go/storage/latestData" - "github.com/ElrondNetwork/elrond-go/storage/storageUnit" + "github.com/ElrondNetwork/elrond-go/storage/storageunit" ) +var log = logger.GetOrCreate("factory") + // BootstrapComponentsFactoryArgs holds the arguments needed to create a botstrap components factory type BootstrapComponentsFactoryArgs struct { Config config.Config @@ -30,9 +34,9 @@ type BootstrapComponentsFactoryArgs struct { ImportDbConfig config.ImportDbConfig FlagsConfig config.ContextFlagsConfig WorkingDir string - CoreComponents CoreComponentsHolder - CryptoComponents CryptoComponentsHolder - NetworkComponents NetworkComponentsHolder + CoreComponents factory.CoreComponentsHolder + CryptoComponents factory.CryptoComponentsHolder + NetworkComponents factory.NetworkComponentsHolder } type bootstrapComponentsFactory struct { @@ -41,19 +45,19 @@ type bootstrapComponentsFactory struct { importDbConfig config.ImportDbConfig flagsConfig config.ContextFlagsConfig workingDir string - coreComponents CoreComponentsHolder - cryptoComponents CryptoComponentsHolder - networkComponents NetworkComponentsHolder + coreComponents factory.CoreComponentsHolder + cryptoComponents factory.CryptoComponentsHolder + networkComponents factory.NetworkComponentsHolder } type bootstrapComponents struct { - epochStartBootstrapper EpochStartBootstrapper - bootstrapParamsHolder BootstrapParamsHolder + epochStartBootstrapper factory.EpochStartBootstrapper + bootstrapParamsHolder factory.BootstrapParamsHolder nodeType core.NodeType shardCoordinator sharding.Coordinator - headerVersionHandler factory.HeaderVersionHandler - versionedHeaderFactory factory.VersionedHeaderFactory - headerIntegrityVerifier factory.HeaderIntegrityVerifierHandler + headerVersionHandler nodeFactory.HeaderVersionHandler + versionedHeaderFactory nodeFactory.VersionedHeaderFactory + headerIntegrityVerifier nodeFactory.HeaderIntegrityVerifierHandler } // NewBootstrapComponentsFactory creates an instance of bootstrapComponentsFactory @@ -90,7 +94,7 @@ func (bcf *bootstrapComponentsFactory) Create() (*bootstrapComponents, error) { return nil, err } - versionsCache, err := storageUnit.NewCache(storageFactory.GetCacherFromConfig(bcf.config.Versions.Cache)) + versionsCache, err := storageunit.NewCache(storageFactory.GetCacherFromConfig(bcf.config.Versions.Cache)) if err != nil { return nil, err } @@ -136,8 +140,8 @@ func (bcf *bootstrapComponentsFactory) Create() (*bootstrapComponents, error) { bootstrapDataProvider, bcf.config, parentDir, - common.DefaultEpochString, - common.DefaultShardString, + storage.DefaultEpochString, + storage.DefaultShardString, ) if err != nil { 
return nil, err @@ -146,8 +150,8 @@ func (bcf *bootstrapComponentsFactory) Create() (*bootstrapComponents, error) { unitOpener, err := createUnitOpener( bootstrapDataProvider, latestStorageDataProvider, - common.DefaultEpochString, - common.DefaultShardString, + storage.DefaultEpochString, + storage.DefaultShardString, ) if err != nil { return nil, err @@ -178,7 +182,7 @@ func (bcf *bootstrapComponentsFactory) Create() (*bootstrapComponents, error) { ScheduledSCRsStorer: nil, // will be updated after sync from network } - var epochStartBootstrapper EpochStartBootstrapper + var epochStartBootstrapper factory.EpochStartBootstrapper if bcf.importDbConfig.IsImportDBMode { storageArg := bootstrap.ArgsStorageEpochStartBootstrap{ ArgsEpochStartBootstrap: epochStartBootstrapArgs, @@ -234,7 +238,7 @@ func (bcf *bootstrapComponentsFactory) Create() (*bootstrapComponents, error) { }, nil } -func (bcf *bootstrapComponentsFactory) createHeaderFactory(handler factory.HeaderVersionHandler, shardID uint32) (factory.VersionedHeaderFactory, error) { +func (bcf *bootstrapComponentsFactory) createHeaderFactory(handler nodeFactory.HeaderVersionHandler, shardID uint32) (nodeFactory.VersionedHeaderFactory, error) { if shardID == core.MetachainShardId { return block.NewMetaHeaderFactory(handler) } @@ -262,17 +266,17 @@ func (bc *bootstrapComponents) ShardCoordinator() sharding.Coordinator { } // HeaderVersionHandler returns the header version handler -func (bc *bootstrapComponents) HeaderVersionHandler() factory.HeaderVersionHandler { +func (bc *bootstrapComponents) HeaderVersionHandler() nodeFactory.HeaderVersionHandler { return bc.headerVersionHandler } // VersionedHeaderFactory returns the versioned header factory -func (bc *bootstrapComponents) VersionedHeaderFactory() factory.VersionedHeaderFactory { +func (bc *bootstrapComponents) VersionedHeaderFactory() nodeFactory.VersionedHeaderFactory { return bc.versionedHeaderFactory } // HeaderIntegrityVerifier returns the header integrity verifier -func (bc *bootstrapComponents) HeaderIntegrityVerifier() factory.HeaderIntegrityVerifierHandler { +func (bc *bootstrapComponents) HeaderIntegrityVerifier() nodeFactory.HeaderIntegrityVerifierHandler { return bc.headerIntegrityVerifier } diff --git a/factory/bootstrapComponentsHandler.go b/factory/bootstrap/bootstrapComponentsHandler.go similarity index 86% rename from factory/bootstrapComponentsHandler.go rename to factory/bootstrap/bootstrapComponentsHandler.go index cecf4da142f..6f81af99834 100644 --- a/factory/bootstrapComponentsHandler.go +++ b/factory/bootstrap/bootstrapComponentsHandler.go @@ -1,4 +1,4 @@ -package factory +package bootstrap import ( "fmt" @@ -6,11 +6,12 @@ import ( "github.com/ElrondNetwork/elrond-go-core/core/check" "github.com/ElrondNetwork/elrond-go/errors" + "github.com/ElrondNetwork/elrond-go/factory" ) -var _ ComponentHandler = (*managedBootstrapComponents)(nil) -var _ BootstrapComponentsHolder = (*managedBootstrapComponents)(nil) -var _ BootstrapComponentsHandler = (*managedBootstrapComponents)(nil) +var _ factory.ComponentHandler = (*managedBootstrapComponents)(nil) +var _ factory.BootstrapComponentsHolder = (*managedBootstrapComponents)(nil) +var _ factory.BootstrapComponentsHandler = (*managedBootstrapComponents)(nil) type managedBootstrapComponents struct { *bootstrapComponents @@ -81,7 +82,7 @@ func (mbf *managedBootstrapComponents) CheckSubcomponents() error { } // EpochStartBootstrapper returns the epoch start bootstrapper -func (mbf *managedBootstrapComponents) 
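A pattern repeats across these renamed files: the holder and handler interfaces stay behind in the root factory package, the implementations move into subpackages such as factory/bootstrap, and each implementation keeps blank-identifier assignments as compile-time contract checks. A small self-contained illustration of that check, with names invented for the sketch:

```go
package main

import "fmt"

// componentsHolder plays the role of an interface that stays behind in the
// root factory package (e.g. factory.BootstrapComponentsHolder).
type componentsHolder interface {
	String() string
}

// managedComponents plays the role of the implementation that moved into a
// subpackage (e.g. the managed bootstrap components in factory/bootstrap).
type managedComponents struct{}

// String returns the component name; in the diff this value now comes from an
// exported constant in the root factory package.
func (m *managedComponents) String() string { return "managedBootstrapComponents" }

// The compile-time check: the build breaks as soon as *managedComponents
// stops satisfying componentsHolder, with no runtime cost.
var _ componentsHolder = (*managedComponents)(nil)

func main() {
	fmt.Println((&managedComponents{}).String())
}
```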
EpochStartBootstrapper() EpochStartBootstrapper { +func (mbf *managedBootstrapComponents) EpochStartBootstrapper() factory.EpochStartBootstrapper { mbf.mutBootstrapComponents.RLock() defer mbf.mutBootstrapComponents.RUnlock() @@ -93,7 +94,7 @@ func (mbf *managedBootstrapComponents) EpochStartBootstrapper() EpochStartBootst } // EpochBootstrapParams returns the epoch start bootstrap parameters handler -func (mbf *managedBootstrapComponents) EpochBootstrapParams() BootstrapParamsHolder { +func (mbf *managedBootstrapComponents) EpochBootstrapParams() factory.BootstrapParamsHolder { mbf.mutBootstrapComponents.RLock() defer mbf.mutBootstrapComponents.RUnlock() @@ -111,5 +112,5 @@ func (mbf *managedBootstrapComponents) IsInterfaceNil() bool { // String returns the name of the component func (mbf *managedBootstrapComponents) String() string { - return bootstrapComponentsName + return factory.BootstrapComponentsName } diff --git a/factory/bootstrapComponentsHandler_test.go b/factory/bootstrap/bootstrapComponentsHandler_test.go similarity index 61% rename from factory/bootstrapComponentsHandler_test.go rename to factory/bootstrap/bootstrapComponentsHandler_test.go index f60d645ec17..a7a071dc765 100644 --- a/factory/bootstrapComponentsHandler_test.go +++ b/factory/bootstrap/bootstrapComponentsHandler_test.go @@ -1,11 +1,12 @@ -package factory_test +package bootstrap_test import ( "errors" "testing" errorsErd "github.com/ElrondNetwork/elrond-go/errors" - "github.com/ElrondNetwork/elrond-go/factory" + "github.com/ElrondNetwork/elrond-go/factory/bootstrap" + componentsMock "github.com/ElrondNetwork/elrond-go/testscommon/components" "github.com/stretchr/testify/require" ) @@ -16,9 +17,9 @@ func TestNewManagedBootstrapComponents(t *testing.T) { t.Skip("this is not a short test") } - args := getBootStrapArgs() - bcf, _ := factory.NewBootstrapComponentsFactory(args) - mbc, err := factory.NewManagedBootstrapComponents(bcf) + args := componentsMock.GetBootStrapFactoryArgs() + bcf, _ := bootstrap.NewBootstrapComponentsFactory(args) + mbc, err := bootstrap.NewManagedBootstrapComponents(bcf) require.NotNil(t, mbc) require.Nil(t, err) @@ -30,7 +31,7 @@ func TestNewBootstrapComponentsFactory_NilFactory(t *testing.T) { t.Skip("this is not a short test") } - mbc, err := factory.NewManagedBootstrapComponents(nil) + mbc, err := bootstrap.NewManagedBootstrapComponents(nil) require.Nil(t, mbc) require.Equal(t, errorsErd.ErrNilBootstrapComponentsFactory, err) @@ -42,9 +43,9 @@ func TestManagedBootstrapComponents_CheckSubcomponentsNoCreate(t *testing.T) { t.Skip("this is not a short test") } - args := getBootStrapArgs() - bcf, _ := factory.NewBootstrapComponentsFactory(args) - mbc, _ := factory.NewManagedBootstrapComponents(bcf) + args := componentsMock.GetBootStrapFactoryArgs() + bcf, _ := bootstrap.NewBootstrapComponentsFactory(args) + mbc, _ := bootstrap.NewManagedBootstrapComponents(bcf) err := mbc.CheckSubcomponents() require.Equal(t, errorsErd.ErrNilBootstrapComponentsHolder, err) @@ -56,9 +57,9 @@ func TestManagedBootstrapComponents_Create(t *testing.T) { t.Skip("this is not a short test") } - args := getBootStrapArgs() - bcf, _ := factory.NewBootstrapComponentsFactory(args) - mbc, _ := factory.NewManagedBootstrapComponents(bcf) + args := componentsMock.GetBootStrapFactoryArgs() + bcf, _ := bootstrap.NewBootstrapComponentsFactory(args) + mbc, _ := bootstrap.NewManagedBootstrapComponents(bcf) err := mbc.Create() require.Nil(t, err) @@ -73,11 +74,11 @@ func 
TestManagedBootstrapComponents_CreateNilInternalMarshalizer(t *testing.T) { t.Skip("this is not a short test") } - args := getBootStrapArgs() - coreComponents := getDefaultCoreComponents() + args := componentsMock.GetBootStrapFactoryArgs() + coreComponents := componentsMock.GetDefaultCoreComponents() args.CoreComponents = coreComponents - bcf, _ := factory.NewBootstrapComponentsFactory(args) - mbc, _ := factory.NewManagedBootstrapComponents(bcf) + bcf, _ := bootstrap.NewBootstrapComponentsFactory(args) + mbc, _ := bootstrap.NewManagedBootstrapComponents(bcf) coreComponents.IntMarsh = nil err := mbc.Create() @@ -90,10 +91,10 @@ func TestManagedBootstrapComponents_Close(t *testing.T) { t.Skip("this is not a short test") } - args := getBootStrapArgs() + args := componentsMock.GetBootStrapFactoryArgs() - bcf, _ := factory.NewBootstrapComponentsFactory(args) - mbc, _ := factory.NewManagedBootstrapComponents(bcf) + bcf, _ := bootstrap.NewBootstrapComponentsFactory(args) + mbc, _ := bootstrap.NewManagedBootstrapComponents(bcf) _ = mbc.Create() require.NotNil(t, mbc.EpochBootstrapParams()) diff --git a/factory/bootstrap/bootstrapComponents_test.go b/factory/bootstrap/bootstrapComponents_test.go new file mode 100644 index 00000000000..52795453f40 --- /dev/null +++ b/factory/bootstrap/bootstrapComponents_test.go @@ -0,0 +1,140 @@ +package bootstrap_test + +import ( + "errors" + "testing" + + errorsErd "github.com/ElrondNetwork/elrond-go/errors" + "github.com/ElrondNetwork/elrond-go/factory/bootstrap" + componentsMock "github.com/ElrondNetwork/elrond-go/testscommon/components" + "github.com/stretchr/testify/require" +) + +// ------------ Test BootstrapComponentsFactory -------------------- +func TestNewBootstrapComponentsFactory_OkValuesShouldWork(t *testing.T) { + t.Parallel() + if testing.Short() { + t.Skip("this is not a short test") + } + + args := componentsMock.GetBootStrapFactoryArgs() + + bcf, err := bootstrap.NewBootstrapComponentsFactory(args) + + require.NotNil(t, bcf) + require.Nil(t, err) +} + +func TestNewBootstrapComponentsFactory_NilCoreComponents(t *testing.T) { + t.Parallel() + if testing.Short() { + t.Skip("this is not a short test") + } + + args := componentsMock.GetBootStrapFactoryArgs() + args.CoreComponents = nil + + bcf, err := bootstrap.NewBootstrapComponentsFactory(args) + + require.Nil(t, bcf) + require.Equal(t, errorsErd.ErrNilCoreComponentsHolder, err) +} + +func TestNewBootstrapComponentsFactory_NilCryptoComponents(t *testing.T) { + t.Parallel() + if testing.Short() { + t.Skip("this is not a short test") + } + + args := componentsMock.GetBootStrapFactoryArgs() + args.CryptoComponents = nil + + bcf, err := bootstrap.NewBootstrapComponentsFactory(args) + + require.Nil(t, bcf) + require.Equal(t, errorsErd.ErrNilCryptoComponentsHolder, err) +} + +func TestNewBootstrapComponentsFactory_NilNetworkComponents(t *testing.T) { + t.Parallel() + if testing.Short() { + t.Skip("this is not a short test") + } + + args := componentsMock.GetBootStrapFactoryArgs() + args.NetworkComponents = nil + + bcf, err := bootstrap.NewBootstrapComponentsFactory(args) + + require.Nil(t, bcf) + require.Equal(t, errorsErd.ErrNilNetworkComponentsHolder, err) +} + +func TestNewBootstrapComponentsFactory_NilWorkingDir(t *testing.T) { + t.Parallel() + if testing.Short() { + t.Skip("this is not a short test") + } + + args := componentsMock.GetBootStrapFactoryArgs() + args.WorkingDir = "" + + bcf, err := bootstrap.NewBootstrapComponentsFactory(args) + + require.Nil(t, bcf) + require.Equal(t, 
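The rewritten tests above also consolidate fixtures: builders such as getBootStrapArgs and getDefaultCoreComponents, previously duplicated in each factory test file, now live once in testscommon/components and are imported as componentsMock. A sketch of that builder style, using invented types to stay self-contained:

```go
package main

import "fmt"

// factoryArgs stands in for an args struct such as
// BootstrapComponentsFactoryArgs; the real one carries the full config and
// the core, crypto, and network component holders.
type factoryArgs struct {
	WorkingDir string
	ChainID    string
}

// getBootStrapFactoryArgs mirrors the shared fixture builder: it returns a
// fully populated default that individual tests then mutate, for example
// clearing WorkingDir to trigger ErrInvalidWorkingDir.
func getBootStrapFactoryArgs() factoryArgs {
	return factoryArgs{
		WorkingDir: "home",
		ChainID:    "chainID",
	}
}

func main() {
	args := getBootStrapFactoryArgs()
	args.WorkingDir = "" // per-test tweak, as in TestNewBootstrapComponentsFactory_NilWorkingDir
	fmt.Printf("%+v\n", args)
}
```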
errorsErd.ErrInvalidWorkingDir, err) +} + +func TestBootstrapComponentsFactory_CreateShouldWork(t *testing.T) { + t.Parallel() + if testing.Short() { + t.Skip("this is not a short test") + } + + args := componentsMock.GetBootStrapFactoryArgs() + + bcf, _ := bootstrap.NewBootstrapComponentsFactory(args) + + bc, err := bcf.Create() + + require.Nil(t, err) + require.NotNil(t, bc) +} + +func TestBootstrapComponentsFactory_CreateBootstrapDataProviderCreationFail(t *testing.T) { + t.Parallel() + if testing.Short() { + t.Skip("this is not a short test") + } + + args := componentsMock.GetBootStrapFactoryArgs() + coreComponents := componentsMock.GetDefaultCoreComponents() + args.CoreComponents = coreComponents + + bcf, _ := bootstrap.NewBootstrapComponentsFactory(args) + + coreComponents.IntMarsh = nil + bc, err := bcf.Create() + + require.Nil(t, bc) + require.True(t, errors.Is(err, errorsErd.ErrNewBootstrapDataProvider)) +} + +func TestBootstrapComponentsFactory_CreateEpochStartBootstrapCreationFail(t *testing.T) { + t.Parallel() + if testing.Short() { + t.Skip("this is not a short test") + } + + args := componentsMock.GetBootStrapFactoryArgs() + coreComponents := componentsMock.GetDefaultCoreComponents() + args.CoreComponents = coreComponents + + bcf, _ := bootstrap.NewBootstrapComponentsFactory(args) + + coreComponents.RatingHandler = nil + bc, err := bcf.Create() + + require.Nil(t, bc) + require.True(t, errors.Is(err, errorsErd.ErrNewEpochStartBootstrap)) +} diff --git a/factory/bootstrapParameters.go b/factory/bootstrap/bootstrapParameters.go similarity index 98% rename from factory/bootstrapParameters.go rename to factory/bootstrap/bootstrapParameters.go index 9ce700a08a0..881013b5c02 100644 --- a/factory/bootstrapParameters.go +++ b/factory/bootstrap/bootstrapParameters.go @@ -1,4 +1,4 @@ -package factory +package bootstrap import ( "github.com/ElrondNetwork/elrond-go/epochStart/bootstrap" diff --git a/factory/shardingFactory.go b/factory/bootstrap/shardingFactory.go similarity index 96% rename from factory/shardingFactory.go rename to factory/bootstrap/shardingFactory.go index 9dc44483fd6..562e8bf0eeb 100644 --- a/factory/shardingFactory.go +++ b/factory/bootstrap/shardingFactory.go @@ -1,4 +1,4 @@ -package factory +package bootstrap import ( "errors" @@ -15,10 +15,11 @@ import ( "github.com/ElrondNetwork/elrond-go/common" "github.com/ElrondNetwork/elrond-go/config" "github.com/ElrondNetwork/elrond-go/epochStart" + "github.com/ElrondNetwork/elrond-go/factory" "github.com/ElrondNetwork/elrond-go/sharding" "github.com/ElrondNetwork/elrond-go/sharding/nodesCoordinator" "github.com/ElrondNetwork/elrond-go/storage" - "github.com/ElrondNetwork/elrond-go/storage/lrucache" + "github.com/ElrondNetwork/elrond-go/storage/cache" ) // CreateShardCoordinator is the shard coordinator factory @@ -89,7 +90,7 @@ func getShardIdFromNodePubKey(pubKey crypto.PublicKey, nodesConfig sharding.Gene // CreateNodesCoordinator is the nodes coordinator factory func CreateNodesCoordinator( - nodeShufflerOut ShuffleOutCloser, + nodeShufflerOut factory.ShuffleOutCloser, nodesConfig sharding.GenesisNodesSetupHandler, prefsConfig config.PreferencesConfig, epochStartNotifier epochStart.RegistrationHandler, @@ -100,7 +101,7 @@ func CreateNodesCoordinator( bootStorer storage.Storer, nodeShuffler nodesCoordinator.NodesShuffler, currentShardID uint32, - bootstrapParameters BootstrapParamsHolder, + bootstrapParameters factory.BootstrapParamsHolder, startEpoch uint32, chanNodeStop chan endProcess.ArgEndProcess, nodeTypeProvider 
core.NodeTypeProviderHandler, @@ -164,7 +165,7 @@ func CreateNodesCoordinator( return nil, err } - consensusGroupCache, err := lrucache.NewCache(25000) + consensusGroupCache, err := cache.NewLRUCache(25000) if err != nil { return nil, err } @@ -216,7 +217,7 @@ func CreateNodesShuffleOut( nodesConfig sharding.GenesisNodesSetupHandler, epochConfig config.EpochStartConfig, chanStopNodeProcess chan endProcess.ArgEndProcess, -) (ShuffleOutCloser, error) { +) (factory.ShuffleOutCloser, error) { maxThresholdEpochDuration := epochConfig.MaxShuffledOutRestartThreshold if !(maxThresholdEpochDuration >= 0.0 && maxThresholdEpochDuration <= 1.0) { diff --git a/factory/bootstrap/testBootstrapComponentsHandler.go b/factory/bootstrap/testBootstrapComponentsHandler.go new file mode 100644 index 00000000000..e9a6f7fe244 --- /dev/null +++ b/factory/bootstrap/testBootstrapComponentsHandler.go @@ -0,0 +1,34 @@ +package bootstrap + +import ( + "github.com/ElrondNetwork/elrond-go/errors" + "github.com/ElrondNetwork/elrond-go/sharding" +) + +type testManagedBootstrapComponents struct { + *managedBootstrapComponents +} + +// NewTestManagedBootstrapComponents creates an instance of a test managed bootstrap components +func NewTestManagedBootstrapComponents(bootstrapComponentsFactory *bootstrapComponentsFactory) (*testManagedBootstrapComponents, error) { + bc, err := NewManagedBootstrapComponents(bootstrapComponentsFactory) + if err != nil { + return nil, err + } + return &testManagedBootstrapComponents{ + managedBootstrapComponents: bc, + }, nil +} + +// SetShardCoordinator sets the shard coordinator +func (mbf *testManagedBootstrapComponents) SetShardCoordinator(shardCoordinator sharding.Coordinator) error { + mbf.mutBootstrapComponents.Lock() + defer mbf.mutBootstrapComponents.Unlock() + + if mbf.bootstrapComponents == nil { + return errors.ErrNilBootstrapComponents + } + + mbf.bootstrapComponents.shardCoordinator = shardCoordinator + return nil +} diff --git a/factory/bootstrapComponents_test.go b/factory/bootstrapComponents_test.go deleted file mode 100644 index fc793fb24e5..00000000000 --- a/factory/bootstrapComponents_test.go +++ /dev/null @@ -1,201 +0,0 @@ -package factory_test - -import ( - "errors" - "testing" - "time" - - "github.com/ElrondNetwork/elrond-go/config" - errorsErd "github.com/ElrondNetwork/elrond-go/errors" - "github.com/ElrondNetwork/elrond-go/factory" - "github.com/ElrondNetwork/elrond-go/factory/mock" - "github.com/ElrondNetwork/elrond-go/testscommon" - "github.com/ElrondNetwork/elrond-go/testscommon/economicsmocks" - "github.com/ElrondNetwork/elrond-go/testscommon/nodeTypeProviderMock" - "github.com/ElrondNetwork/elrond-go/testscommon/statusHandler" - "github.com/stretchr/testify/require" -) - -// ------------ Test BootstrapComponentsFactory -------------------- -func TestNewBootstrapComponentsFactory_OkValuesShouldWork(t *testing.T) { - t.Parallel() - if testing.Short() { - t.Skip("this is not a short test") - } - - args := getBootStrapArgs() - - bcf, err := factory.NewBootstrapComponentsFactory(args) - - require.NotNil(t, bcf) - require.Nil(t, err) -} - -func TestNewBootstrapComponentsFactory_NilCoreComponents(t *testing.T) { - t.Parallel() - if testing.Short() { - t.Skip("this is not a short test") - } - - args := getBootStrapArgs() - args.CoreComponents = nil - - bcf, err := factory.NewBootstrapComponentsFactory(args) - - require.Nil(t, bcf) - require.Equal(t, errorsErd.ErrNilCoreComponentsHolder, err) -} - -func TestNewBootstrapComponentsFactory_NilCryptoComponents(t *testing.T) { 
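The new testBootstrapComponentsHandler.go above uses struct embedding to graft a test-only mutator onto the production wrapper without widening its public API. A self-contained sketch of the same shape follows (all types invented for the example); note that the setter takes the exclusive write lock because it mutates shared state:

```go
package main

import (
	"errors"
	"fmt"
	"sync"
)

type bootstrapComponents struct {
	shardCoordinator string
}

type managedBootstrapComponents struct {
	mutBootstrapComponents sync.RWMutex
	bootstrapComponents    *bootstrapComponents
}

// testManagedBootstrapComponents embeds the production handler and adds the
// setter that only tests need, mirroring the new file above.
type testManagedBootstrapComponents struct {
	*managedBootstrapComponents
}

// SetShardCoordinator swaps the shard coordinator under the write lock, since
// it mutates state that readers guard with RLock.
func (mbf *testManagedBootstrapComponents) SetShardCoordinator(sc string) error {
	mbf.mutBootstrapComponents.Lock()
	defer mbf.mutBootstrapComponents.Unlock()

	if mbf.bootstrapComponents == nil {
		return errors.New("nil bootstrap components")
	}

	mbf.bootstrapComponents.shardCoordinator = sc
	return nil
}

func main() {
	mbc := &testManagedBootstrapComponents{
		managedBootstrapComponents: &managedBootstrapComponents{
			bootstrapComponents: &bootstrapComponents{},
		},
	}
	fmt.Println(mbc.SetShardCoordinator("shard-1"))
}
```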
- t.Parallel() - if testing.Short() { - t.Skip("this is not a short test") - } - - args := getBootStrapArgs() - args.CryptoComponents = nil - - bcf, err := factory.NewBootstrapComponentsFactory(args) - - require.Nil(t, bcf) - require.Equal(t, errorsErd.ErrNilCryptoComponentsHolder, err) -} - -func TestNewBootstrapComponentsFactory_NilNetworkComponents(t *testing.T) { - t.Parallel() - if testing.Short() { - t.Skip("this is not a short test") - } - - args := getBootStrapArgs() - args.NetworkComponents = nil - - bcf, err := factory.NewBootstrapComponentsFactory(args) - - require.Nil(t, bcf) - require.Equal(t, errorsErd.ErrNilNetworkComponentsHolder, err) -} - -func TestNewBootstrapComponentsFactory_NilWorkingDir(t *testing.T) { - t.Parallel() - if testing.Short() { - t.Skip("this is not a short test") - } - - args := getBootStrapArgs() - args.WorkingDir = "" - - bcf, err := factory.NewBootstrapComponentsFactory(args) - - require.Nil(t, bcf) - require.Equal(t, errorsErd.ErrInvalidWorkingDir, err) -} - -func TestBootstrapComponentsFactory_CreateShouldWork(t *testing.T) { - t.Parallel() - if testing.Short() { - t.Skip("this is not a short test") - } - - args := getBootStrapArgs() - - bcf, _ := factory.NewBootstrapComponentsFactory(args) - - bc, err := bcf.Create() - - require.Nil(t, err) - require.NotNil(t, bc) -} - -func TestBootstrapComponentsFactory_CreateBootstrapDataProviderCreationFail(t *testing.T) { - t.Parallel() - if testing.Short() { - t.Skip("this is not a short test") - } - - args := getBootStrapArgs() - coreComponents := getDefaultCoreComponents() - args.CoreComponents = coreComponents - - bcf, _ := factory.NewBootstrapComponentsFactory(args) - - coreComponents.IntMarsh = nil - bc, err := bcf.Create() - - require.Nil(t, bc) - require.True(t, errors.Is(err, errorsErd.ErrNewBootstrapDataProvider)) -} - -func TestBootstrapComponentsFactory_CreateEpochStartBootstrapCreationFail(t *testing.T) { - t.Parallel() - if testing.Short() { - t.Skip("this is not a short test") - } - - args := getBootStrapArgs() - coreComponents := getDefaultCoreComponents() - args.CoreComponents = coreComponents - - bcf, _ := factory.NewBootstrapComponentsFactory(args) - - coreComponents.RatingHandler = nil - bc, err := bcf.Create() - - require.Nil(t, bc) - require.True(t, errors.Is(err, errorsErd.ErrNewEpochStartBootstrap)) -} - -func getBootStrapArgs() factory.BootstrapComponentsFactoryArgs { - coreComponents := getCoreComponents() - networkComponents := getNetworkComponents() - cryptoComponents := getCryptoComponents(coreComponents) - return factory.BootstrapComponentsFactoryArgs{ - Config: testscommon.GetGeneralConfig(), - WorkingDir: "home", - CoreComponents: coreComponents, - CryptoComponents: cryptoComponents, - NetworkComponents: networkComponents, - PrefConfig: config.Preferences{ - Preferences: config.PreferencesConfig{ - DestinationShardAsObserver: "0", - }, - }, - ImportDbConfig: config.ImportDbConfig{ - IsImportDBMode: false, - }, - RoundConfig: config.RoundConfig{}, - FlagsConfig: config.ContextFlagsConfig{ - ForceStartFromNetwork: false, - }, - } -} - -func getDefaultCoreComponents() *mock.CoreComponentsMock { - return &mock.CoreComponentsMock{ - IntMarsh: &testscommon.MarshalizerMock{}, - TxMarsh: &testscommon.MarshalizerMock{}, - VmMarsh: &testscommon.MarshalizerMock{}, - Hash: &testscommon.HasherStub{}, - UInt64ByteSliceConv: testscommon.NewNonceHashConverterMock(), - AddrPubKeyConv: testscommon.NewPubkeyConverterMock(32), - ValPubKeyConv: testscommon.NewPubkeyConverterMock(32), - PathHdl: 
&testscommon.PathManagerStub{}, - ChainIdCalled: func() string { - return "chainID" - }, - MinTransactionVersionCalled: func() uint32 { - return 1 - }, - AppStatusHdl: &statusHandler.AppStatusHandlerStub{}, - WatchdogTimer: &testscommon.WatchdogMock{}, - AlarmSch: &testscommon.AlarmSchedulerStub{}, - NtpSyncTimer: &testscommon.SyncTimerStub{}, - RoundHandlerField: &testscommon.RoundHandlerMock{}, - EconomicsHandler: &economicsmocks.EconomicsHandlerStub{}, - RatingsConfig: &testscommon.RatingsInfoMock{}, - RatingHandler: &testscommon.RaterMock{}, - NodesConfig: &testscommon.NodesSetupStub{}, - StartTime: time.Time{}, - NodeTypeProviderField: &nodeTypeProviderMock.NodeTypeProviderStub{}, - } -} diff --git a/factory/consensusComponents.go b/factory/consensus/consensusComponents.go similarity index 92% rename from factory/consensusComponents.go rename to factory/consensus/consensusComponents.go index 8236e986c69..a4bb57a657e 100644 --- a/factory/consensusComponents.go +++ b/factory/consensus/consensusComponents.go @@ -1,4 +1,4 @@ -package factory +package consensus import ( "time" @@ -8,33 +8,39 @@ import ( "github.com/ElrondNetwork/elrond-go-core/core/throttler" "github.com/ElrondNetwork/elrond-go-core/core/watchdog" "github.com/ElrondNetwork/elrond-go-core/marshal" + logger "github.com/ElrondNetwork/elrond-go-logger" "github.com/ElrondNetwork/elrond-go/common" "github.com/ElrondNetwork/elrond-go/config" "github.com/ElrondNetwork/elrond-go/consensus" "github.com/ElrondNetwork/elrond-go/consensus/chronology" + "github.com/ElrondNetwork/elrond-go/consensus/signing" "github.com/ElrondNetwork/elrond-go/consensus/spos" "github.com/ElrondNetwork/elrond-go/consensus/spos/sposFactory" "github.com/ElrondNetwork/elrond-go/errors" + factory "github.com/ElrondNetwork/elrond-go/factory" "github.com/ElrondNetwork/elrond-go/process" "github.com/ElrondNetwork/elrond-go/process/sync" "github.com/ElrondNetwork/elrond-go/process/sync/storageBootstrap" "github.com/ElrondNetwork/elrond-go/sharding" "github.com/ElrondNetwork/elrond-go/state/syncer" - "github.com/ElrondNetwork/elrond-go/trie/factory" + trieFactory "github.com/ElrondNetwork/elrond-go/trie/factory" + "github.com/ElrondNetwork/elrond-go/trie/storageMarker" "github.com/ElrondNetwork/elrond-go/update" ) +var log = logger.GetOrCreate("factory") + // ConsensusComponentsFactoryArgs holds the arguments needed to create a consensus components factory type ConsensusComponentsFactoryArgs struct { Config config.Config BootstrapRoundIndex uint64 - CoreComponents CoreComponentsHolder - NetworkComponents NetworkComponentsHolder - CryptoComponents CryptoComponentsHolder - DataComponents DataComponentsHolder - ProcessComponents ProcessComponentsHolder - StateComponents StateComponentsHolder - StatusComponents StatusComponentsHolder + CoreComponents factory.CoreComponentsHolder + NetworkComponents factory.NetworkComponentsHolder + CryptoComponents factory.CryptoComponentsHolder + DataComponents factory.DataComponentsHolder + ProcessComponents factory.ProcessComponentsHolder + StateComponents factory.StateComponentsHolder + StatusComponents factory.StatusComponentsHolder ScheduledProcessor consensus.ScheduledProcessor IsInImportMode bool ShouldDisableWatchdog bool @@ -43,13 +49,13 @@ type ConsensusComponentsFactoryArgs struct { type consensusComponentsFactory struct { config config.Config bootstrapRoundIndex uint64 - coreComponents CoreComponentsHolder - networkComponents NetworkComponentsHolder - cryptoComponents CryptoComponentsHolder - dataComponents 
DataComponentsHolder - processComponents ProcessComponentsHolder - stateComponents StateComponentsHolder - statusComponents StatusComponentsHolder + coreComponents factory.CoreComponentsHolder + networkComponents factory.NetworkComponentsHolder + cryptoComponents factory.CryptoComponentsHolder + dataComponents factory.DataComponentsHolder + processComponents factory.ProcessComponentsHolder + stateComponents factory.StateComponentsHolder + statusComponents factory.StatusComponentsHolder scheduledProcessor consensus.ScheduledProcessor isInImportMode bool shouldDisableWatchdog bool @@ -59,7 +65,7 @@ type consensusComponents struct { chronology consensus.ChronologyHandler bootstrapper process.Bootstrapper broadcastMessenger consensus.BroadcastMessenger - worker ConsensusWorker + worker factory.ConsensusWorker consensusTopic string consensusGroupSize int } @@ -220,6 +226,11 @@ func (ccf *consensusComponentsFactory) Create() (*consensusComponents, error) { return nil, err } + signatureHandler, err := ccf.createBlsSignatureHandler() + if err != nil { + return nil, err + } + consensusArgs := &spos.ConsensusCoreArgs{ BlockChain: ccf.dataComponents.Blockchain(), BlockProcessor: ccf.processComponents.BlockProcessor(), @@ -230,7 +241,7 @@ func (ccf *consensusComponentsFactory) Create() (*consensusComponents, error) { Marshalizer: ccf.coreComponents.InternalMarshalizer(), BlsPrivateKey: ccf.cryptoComponents.PrivateKey(), BlsSingleSigner: ccf.cryptoComponents.BlockSigner(), - MultiSigner: ccf.cryptoComponents.MultiSigner(), + MultiSignerContainer: ccf.cryptoComponents.MultiSignerContainer(), RoundHandler: ccf.processComponents.RoundHandler(), ShardCoordinator: ccf.processComponents.ShardCoordinator(), NodesCoordinator: ccf.processComponents.NodesCoordinator(), @@ -242,6 +253,7 @@ func (ccf *consensusComponentsFactory) Create() (*consensusComponents, error) { FallbackHeaderValidator: ccf.processComponents.FallbackHeaderValidator(), NodeRedundancyHandler: ccf.processComponents.NodeRedundancyHandler(), ScheduledProcessor: ccf.scheduledProcessor, + SignatureHandler: signatureHandler, } consensusDataContainer, err := spos.NewConsensusCore( @@ -494,11 +506,12 @@ func (ccf *consensusComponentsFactory) createArgsBaseAccountsSyncer(trieStorageM MaxHardCapForMissingNodes: ccf.config.TrieSync.MaxHardCapForMissingNodes, TrieSyncerVersion: ccf.config.TrieSync.TrieSyncerVersion, CheckNodesOnDisk: ccf.config.TrieSync.CheckNodesOnDisk, + StorageMarker: storageMarker.NewTrieStorageMarker(), } } func (ccf *consensusComponentsFactory) createValidatorAccountsSyncer() (process.AccountsDBSyncer, error) { - trieStorageManager, ok := ccf.stateComponents.TrieStorageManagers()[factory.PeerAccountTrie] + trieStorageManager, ok := ccf.stateComponents.TrieStorageManagers()[trieFactory.PeerAccountTrie] if !ok { return nil, errors.ErrNilTrieStorageManager } @@ -510,7 +523,7 @@ func (ccf *consensusComponentsFactory) createValidatorAccountsSyncer() (process. 
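One behavioral-looking change hides in the consensus wiring above: ConsensusCoreArgs now takes MultiSignerContainer instead of a single MultiSigner. A container presumably lets the consensus pick the signer that matches a given context; the epoch-based lookup in this sketch is an assumption for illustration, not something the diff states.

```go
package main

import (
	"errors"
	"fmt"
)

// multiSigner is a stand-in for the crypto multi-signer interface.
type multiSigner interface {
	Name() string
}

type dummySigner struct{ name string }

func (d dummySigner) Name() string { return d.name }

// multiSignerContainer sketches why a container replaces one fixed instance:
// it can hand out the signer appropriate for a given epoch.
type multiSignerContainer struct {
	signers map[uint32]multiSigner
}

func (c *multiSignerContainer) GetMultiSigner(epoch uint32) (multiSigner, error) {
	signer, ok := c.signers[epoch]
	if !ok {
		return nil, errors.New("no multi signer for epoch")
	}
	return signer, nil
}

func main() {
	container := &multiSignerContainer{signers: map[uint32]multiSigner{
		0: dummySigner{name: "bls-multisig-v1"},
	}}
	signer, err := container.GetMultiSigner(0)
	fmt.Println(signer, err)
}
```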
} func (ccf *consensusComponentsFactory) createUserAccountsSyncer() (process.AccountsDBSyncer, error) { - trieStorageManager, ok := ccf.stateComponents.TrieStorageManagers()[factory.UserAccountTrie] + trieStorageManager, ok := ccf.stateComponents.TrieStorageManagers()[trieFactory.UserAccountTrie] if !ok { return nil, errors.ErrNilTrieStorageManager } @@ -638,6 +651,22 @@ func (ccf *consensusComponentsFactory) createConsensusTopic(cc *consensusCompone return ccf.networkComponents.NetworkMessenger().RegisterMessageProcessor(cc.consensusTopic, common.DefaultInterceptorsIdentifier, cc.worker) } +func (ccf *consensusComponentsFactory) createBlsSignatureHandler() (consensus.SignatureHandler, error) { + privKeyBytes, err := ccf.cryptoComponents.PrivateKey().ToByteArray() + if err != nil { + return nil, err + } + + signatureHolderArgs := signing.ArgsSignatureHolder{ + PubKeys: []string{ccf.cryptoComponents.PublicKeyString()}, + PrivKeyBytes: privKeyBytes, + MultiSignerContainer: ccf.cryptoComponents.MultiSignerContainer(), + KeyGenerator: ccf.cryptoComponents.BlockSignKeyGen(), + } + + return signing.NewSignatureHolder(signatureHolderArgs) +} + func (ccf *consensusComponentsFactory) addCloserInstances(closers ...update.Closer) error { hardforkTrigger := ccf.processComponents.HardforkTrigger() for _, c := range closers { diff --git a/factory/consensusComponentsHandler.go b/factory/consensus/consensusComponentsHandler.go similarity index 90% rename from factory/consensusComponentsHandler.go rename to factory/consensus/consensusComponentsHandler.go index 7bbc649719e..68a99fb2b63 100644 --- a/factory/consensusComponentsHandler.go +++ b/factory/consensus/consensusComponentsHandler.go @@ -1,4 +1,4 @@ -package factory +package consensus import ( "fmt" @@ -7,12 +7,13 @@ import ( "github.com/ElrondNetwork/elrond-go-core/core/check" "github.com/ElrondNetwork/elrond-go/consensus" "github.com/ElrondNetwork/elrond-go/errors" + "github.com/ElrondNetwork/elrond-go/factory" "github.com/ElrondNetwork/elrond-go/process" ) -var _ ComponentHandler = (*managedConsensusComponents)(nil) -var _ ConsensusComponentsHolder = (*managedConsensusComponents)(nil) -var _ ConsensusComponentsHandler = (*managedConsensusComponents)(nil) +var _ factory.ComponentHandler = (*managedConsensusComponents)(nil) +var _ factory.ConsensusComponentsHolder = (*managedConsensusComponents)(nil) +var _ factory.ConsensusComponentsHandler = (*managedConsensusComponents)(nil) type managedConsensusComponents struct { *consensusComponents @@ -77,7 +78,7 @@ func (mcc *managedConsensusComponents) Chronology() consensus.ChronologyHandler } // ConsensusWorker returns the consensus worker -func (mcc *managedConsensusComponents) ConsensusWorker() ConsensusWorker { +func (mcc *managedConsensusComponents) ConsensusWorker() factory.ConsensusWorker { mcc.mutConsensusComponents.RLock() defer mcc.mutConsensusComponents.RUnlock() @@ -152,5 +153,5 @@ func (mcc *managedConsensusComponents) IsInterfaceNil() bool { // String returns the name of the component func (mcc *managedConsensusComponents) String() string { - return consensusComponentsName + return factory.ConsensusComponentsName } diff --git a/factory/consensusComponentsHandler_test.go b/factory/consensus/consensusComponentsHandler_test.go similarity index 65% rename from factory/consensusComponentsHandler_test.go rename to factory/consensus/consensusComponentsHandler_test.go index 4c110531f82..3cfaf73b2f8 100644 --- a/factory/consensusComponentsHandler_test.go +++ 
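The createBlsSignatureHandler method added above glues the crypto components to the new consensus/signing holder: it serializes the node's private key and passes it, along with the public key, multi-signer container, and key generator, to signing.NewSignatureHolder. A reduced sketch of that wiring with stand-in types (the real ArgsSignatureHolder also carries MultiSignerContainer and KeyGenerator, omitted here):

```go
package main

import (
	"errors"
	"fmt"
)

// argsSignatureHolder is a stand-in echoing the field shape of
// signing.ArgsSignatureHolder from the hunk above.
type argsSignatureHolder struct {
	PubKeys      []string
	PrivKeyBytes []byte
}

type signatureHolder struct {
	pubKeys      []string
	privKeyBytes []byte
}

// newSignatureHolder validates its inputs the way constructors in this
// codebase typically do, returning an error on bad arguments.
func newSignatureHolder(args argsSignatureHolder) (*signatureHolder, error) {
	if len(args.PubKeys) == 0 {
		return nil, errors.New("empty public keys list")
	}
	if len(args.PrivKeyBytes) == 0 {
		return nil, errors.New("empty private key bytes")
	}
	return &signatureHolder{pubKeys: args.PubKeys, privKeyBytes: args.PrivKeyBytes}, nil
}

// createBlsSignatureHandler mirrors the factory method: gather the key
// material from the crypto components, then build the holder.
func createBlsSignatureHandler(pubKeyString string, privKeyBytes []byte) (*signatureHolder, error) {
	return newSignatureHolder(argsSignatureHolder{
		PubKeys:      []string{pubKeyString},
		PrivKeyBytes: privKeyBytes,
	})
}

func main() {
	holder, err := createBlsSignatureHandler("self", []byte{0x01})
	fmt.Println(holder != nil, err)
}
```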
b/factory/consensus/consensusComponentsHandler_test.go @@ -1,10 +1,11 @@ -package factory_test +package consensus_test import ( "testing" - "github.com/ElrondNetwork/elrond-go/factory" + consensusComp "github.com/ElrondNetwork/elrond-go/factory/consensus" "github.com/ElrondNetwork/elrond-go/factory/mock" + componentsMock "github.com/ElrondNetwork/elrond-go/testscommon/components" "github.com/stretchr/testify/require" ) @@ -16,11 +17,11 @@ func TestManagedConsensusComponents_CreateWithInvalidArgsShouldErr(t *testing.T) } shardCoordinator := mock.NewMultiShardsCoordinatorMock(2) - args := getConsensusArgs(shardCoordinator) - coreComponents := getDefaultCoreComponents() + args := componentsMock.GetConsensusArgs(shardCoordinator) + coreComponents := componentsMock.GetDefaultCoreComponents() args.CoreComponents = coreComponents - consensusComponentsFactory, _ := factory.NewConsensusComponentsFactory(args) - managedConsensusComponents, err := factory.NewManagedConsensusComponents(consensusComponentsFactory) + consensusComponentsFactory, _ := consensusComp.NewConsensusComponentsFactory(args) + managedConsensusComponents, err := consensusComp.NewManagedConsensusComponents(consensusComponentsFactory) require.NoError(t, err) coreComponents.AppStatusHdl = nil @@ -36,10 +37,10 @@ func TestManagedConsensusComponents_CreateShouldWork(t *testing.T) { } shardCoordinator := mock.NewMultiShardsCoordinatorMock(2) - args := getConsensusArgs(shardCoordinator) + args := componentsMock.GetConsensusArgs(shardCoordinator) - consensusComponentsFactory, _ := factory.NewConsensusComponentsFactory(args) - managedConsensusComponents, err := factory.NewManagedConsensusComponents(consensusComponentsFactory) + consensusComponentsFactory, _ := consensusComp.NewConsensusComponentsFactory(args) + managedConsensusComponents, err := consensusComp.NewManagedConsensusComponents(consensusComponentsFactory) require.NoError(t, err) require.Nil(t, managedConsensusComponents.BroadcastMessenger()) @@ -62,9 +63,9 @@ func TestManagedConsensusComponents_Close(t *testing.T) { } shardCoordinator := mock.NewMultiShardsCoordinatorMock(2) - consensusArgs := getConsensusArgs(shardCoordinator) - consensusComponentsFactory, _ := factory.NewConsensusComponentsFactory(consensusArgs) - managedConsensusComponents, _ := factory.NewManagedConsensusComponents(consensusComponentsFactory) + consensusArgs := componentsMock.GetConsensusArgs(shardCoordinator) + consensusComponentsFactory, _ := consensusComp.NewConsensusComponentsFactory(consensusArgs) + managedConsensusComponents, _ := consensusComp.NewManagedConsensusComponents(consensusComponentsFactory) err := managedConsensusComponents.Create() require.NoError(t, err) diff --git a/factory/consensusComponents_test.go b/factory/consensus/consensusComponents_test.go similarity index 53% rename from factory/consensusComponents_test.go rename to factory/consensus/consensusComponents_test.go index 572862fd1ae..c9ba95baae0 100644 --- a/factory/consensusComponents_test.go +++ b/factory/consensus/consensusComponents_test.go @@ -1,4 +1,4 @@ -package factory_test +package consensus_test import ( "errors" @@ -7,24 +7,17 @@ import ( "github.com/ElrondNetwork/elrond-go-core/core" "github.com/ElrondNetwork/elrond-go-core/data" - crypto "github.com/ElrondNetwork/elrond-go-crypto" - "github.com/ElrondNetwork/elrond-go/common" "github.com/ElrondNetwork/elrond-go/consensus/chronology" - "github.com/ElrondNetwork/elrond-go/consensus/spos" "github.com/ElrondNetwork/elrond-go/consensus/spos/sposFactory" errorsErd 
"github.com/ElrondNetwork/elrond-go/errors" "github.com/ElrondNetwork/elrond-go/factory" + consensusComp "github.com/ElrondNetwork/elrond-go/factory/consensus" "github.com/ElrondNetwork/elrond-go/factory/mock" "github.com/ElrondNetwork/elrond-go/process" "github.com/ElrondNetwork/elrond-go/sharding" "github.com/ElrondNetwork/elrond-go/testscommon" - "github.com/ElrondNetwork/elrond-go/testscommon/cryptoMocks" - dataRetrieverMock "github.com/ElrondNetwork/elrond-go/testscommon/dataRetriever" + componentsMock "github.com/ElrondNetwork/elrond-go/testscommon/components" "github.com/ElrondNetwork/elrond-go/testscommon/p2pmocks" - "github.com/ElrondNetwork/elrond-go/testscommon/shardingMocks" - stateMock "github.com/ElrondNetwork/elrond-go/testscommon/state" - storageStubs "github.com/ElrondNetwork/elrond-go/testscommon/storage" - trieFactory "github.com/ElrondNetwork/elrond-go/trie/factory" "github.com/stretchr/testify/require" ) @@ -36,9 +29,9 @@ func TestNewConsensusComponentsFactory_OkValuesShouldWork(t *testing.T) { } shardCoordinator := mock.NewMultiShardsCoordinatorMock(2) - args := getConsensusArgs(shardCoordinator) + args := componentsMock.GetConsensusArgs(shardCoordinator) - bcf, err := factory.NewConsensusComponentsFactory(args) + bcf, err := consensusComp.NewConsensusComponentsFactory(args) require.NotNil(t, bcf) require.Nil(t, err) @@ -51,10 +44,10 @@ func TestNewConsensusComponentsFactory_NilCoreComponents(t *testing.T) { } shardCoordinator := mock.NewMultiShardsCoordinatorMock(2) - args := getConsensusArgs(shardCoordinator) + args := componentsMock.GetConsensusArgs(shardCoordinator) args.CoreComponents = nil - bcf, err := factory.NewConsensusComponentsFactory(args) + bcf, err := consensusComp.NewConsensusComponentsFactory(args) require.Nil(t, bcf) require.Equal(t, errorsErd.ErrNilCoreComponentsHolder, err) @@ -67,10 +60,10 @@ func TestNewConsensusComponentsFactory_NilDataComponents(t *testing.T) { } shardCoordinator := mock.NewMultiShardsCoordinatorMock(2) - args := getConsensusArgs(shardCoordinator) + args := componentsMock.GetConsensusArgs(shardCoordinator) args.DataComponents = nil - bcf, err := factory.NewConsensusComponentsFactory(args) + bcf, err := consensusComp.NewConsensusComponentsFactory(args) require.Nil(t, bcf) require.Equal(t, errorsErd.ErrNilDataComponentsHolder, err) @@ -83,10 +76,10 @@ func TestNewConsensusComponentsFactory_NilCryptoComponents(t *testing.T) { } shardCoordinator := mock.NewMultiShardsCoordinatorMock(2) - args := getConsensusArgs(shardCoordinator) + args := componentsMock.GetConsensusArgs(shardCoordinator) args.CryptoComponents = nil - bcf, err := factory.NewConsensusComponentsFactory(args) + bcf, err := consensusComp.NewConsensusComponentsFactory(args) require.Nil(t, bcf) require.Equal(t, errorsErd.ErrNilCryptoComponentsHolder, err) @@ -99,10 +92,10 @@ func TestNewConsensusComponentsFactory_NilNetworkComponents(t *testing.T) { } shardCoordinator := mock.NewMultiShardsCoordinatorMock(2) - args := getConsensusArgs(shardCoordinator) + args := componentsMock.GetConsensusArgs(shardCoordinator) args.NetworkComponents = nil - bcf, err := factory.NewConsensusComponentsFactory(args) + bcf, err := consensusComp.NewConsensusComponentsFactory(args) require.Nil(t, bcf) require.Equal(t, errorsErd.ErrNilNetworkComponentsHolder, err) @@ -115,10 +108,10 @@ func TestNewConsensusComponentsFactory_NilProcessComponents(t *testing.T) { } shardCoordinator := mock.NewMultiShardsCoordinatorMock(2) - args := getConsensusArgs(shardCoordinator) + args := 
componentsMock.GetConsensusArgs(shardCoordinator) args.ProcessComponents = nil - bcf, err := factory.NewConsensusComponentsFactory(args) + bcf, err := consensusComp.NewConsensusComponentsFactory(args) require.Nil(t, bcf) require.Equal(t, errorsErd.ErrNilProcessComponentsHolder, err) @@ -131,10 +124,10 @@ func TestNewConsensusComponentsFactory_NilStateComponents(t *testing.T) { } shardCoordinator := mock.NewMultiShardsCoordinatorMock(2) - args := getConsensusArgs(shardCoordinator) + args := componentsMock.GetConsensusArgs(shardCoordinator) args.StateComponents = nil - bcf, err := factory.NewConsensusComponentsFactory(args) + bcf, err := consensusComp.NewConsensusComponentsFactory(args) require.Nil(t, bcf) require.Equal(t, errorsErd.ErrNilStateComponentsHolder, err) @@ -148,9 +141,9 @@ func TestConsensusComponentsFactory_CreateGenesisBlockNotInitializedShouldErr(t } shardCoordinator := mock.NewMultiShardsCoordinatorMock(2) - consensusArgs := getConsensusArgs(shardCoordinator) - consensusComponentsFactory, _ := factory.NewConsensusComponentsFactory(consensusArgs) - managedConsensusComponents, _ := factory.NewManagedConsensusComponents(consensusComponentsFactory) + consensusArgs := componentsMock.GetConsensusArgs(shardCoordinator) + consensusComponentsFactory, _ := consensusComp.NewConsensusComponentsFactory(consensusArgs) + managedConsensusComponents, _ := consensusComp.NewManagedConsensusComponents(consensusComponentsFactory) dataComponents := consensusArgs.DataComponents @@ -175,8 +168,8 @@ func TestConsensusComponentsFactory_CreateForShard(t *testing.T) { } shardCoordinator := mock.NewMultiShardsCoordinatorMock(2) - args := getConsensusArgs(shardCoordinator) - ccf, _ := factory.NewConsensusComponentsFactory(args) + args := componentsMock.GetConsensusArgs(shardCoordinator) + ccf, _ := consensusComp.NewConsensusComponentsFactory(args) require.NotNil(t, ccf) cc, err := ccf.Create() @@ -204,12 +197,12 @@ func TestConsensusComponentsFactory_CreateForMeta(t *testing.T) { } shardCoordinator := mock.NewMultiShardsCoordinatorMock(2) - args := getConsensusArgs(shardCoordinator) + args := componentsMock.GetConsensusArgs(shardCoordinator) args.ProcessComponents = &wrappedProcessComponents{ ProcessComponentsHolder: args.ProcessComponents, } - ccf, _ := factory.NewConsensusComponentsFactory(args) + ccf, _ := consensusComp.NewConsensusComponentsFactory(args) require.NotNil(t, ccf) cc, err := ccf.Create() @@ -224,10 +217,10 @@ func TestConsensusComponentsFactory_CreateNilShardCoordinator(t *testing.T) { } shardCoordinator := mock.NewMultiShardsCoordinatorMock(2) - consensusArgs := getConsensusArgs(shardCoordinator) + consensusArgs := componentsMock.GetConsensusArgs(shardCoordinator) processComponents := &mock.ProcessComponentsMock{} consensusArgs.ProcessComponents = processComponents - consensusComponentsFactory, _ := factory.NewConsensusComponentsFactory(consensusArgs) + consensusComponentsFactory, _ := consensusComp.NewConsensusComponentsFactory(consensusArgs) cc, err := consensusComponentsFactory.Create() @@ -243,8 +236,8 @@ func TestConsensusComponentsFactory_CreateConsensusTopicCreateTopicError(t *test localError := errors.New("error") shardCoordinator := mock.NewMultiShardsCoordinatorMock(2) - args := getConsensusArgs(shardCoordinator) - networkComponents := getDefaultNetworkComponents() + args := componentsMock.GetConsensusArgs(shardCoordinator) + networkComponents := componentsMock.GetDefaultNetworkComponents() networkComponents.Messenger = &p2pmocks.MessengerStub{ HasTopicValidatorCalled: 
func(name string) bool { return false @@ -258,7 +251,7 @@ func TestConsensusComponentsFactory_CreateConsensusTopicCreateTopicError(t *test } args.NetworkComponents = networkComponents - bcf, _ := factory.NewConsensusComponentsFactory(args) + bcf, _ := consensusComp.NewConsensusComponentsFactory(args) cc, err := bcf.Create() require.Nil(t, cc) @@ -272,12 +265,12 @@ func TestConsensusComponentsFactory_CreateConsensusTopicNilMessageProcessor(t *t } shardCoordinator := mock.NewMultiShardsCoordinatorMock(2) - args := getConsensusArgs(shardCoordinator) - networkComponents := getDefaultNetworkComponents() + args := componentsMock.GetConsensusArgs(shardCoordinator) + networkComponents := componentsMock.GetDefaultNetworkComponents() networkComponents.Messenger = nil args.NetworkComponents = networkComponents - bcf, _ := factory.NewConsensusComponentsFactory(args) + bcf, _ := consensusComp.NewConsensusComponentsFactory(args) cc, err := bcf.Create() require.Nil(t, cc) @@ -291,11 +284,11 @@ func TestConsensusComponentsFactory_CreateNilSyncTimer(t *testing.T) { } shardCoordinator := mock.NewMultiShardsCoordinatorMock(2) - args := getConsensusArgs(shardCoordinator) - coreComponents := getDefaultCoreComponents() + args := componentsMock.GetConsensusArgs(shardCoordinator) + coreComponents := componentsMock.GetDefaultCoreComponents() coreComponents.NtpSyncTimer = nil args.CoreComponents = coreComponents - bcf, _ := factory.NewConsensusComponentsFactory(args) + bcf, _ := consensusComp.NewConsensusComponentsFactory(args) cc, err := bcf.Create() require.Nil(t, cc) @@ -309,11 +302,11 @@ func TestStartConsensus_ShardBootstrapperNilAccounts(t *testing.T) { } shardCoordinator := mock.NewMultiShardsCoordinatorMock(2) - args := getConsensusArgs(shardCoordinator) - stateComponents := getDefaultStateComponents() + args := componentsMock.GetConsensusArgs(shardCoordinator) + stateComponents := componentsMock.GetDefaultStateComponents() stateComponents.Accounts = nil args.StateComponents = stateComponents - bcf, _ := factory.NewConsensusComponentsFactory(args) + bcf, _ := consensusComp.NewConsensusComponentsFactory(args) cc, err := bcf.Create() require.Nil(t, cc) @@ -328,13 +321,13 @@ func TestStartConsensus_ShardBootstrapperNilPoolHolder(t *testing.T) { shardCoordinator := mock.NewMultiShardsCoordinatorMock(1) shardCoordinator.CurrentShard = 0 - args := getConsensusArgs(shardCoordinator) - dataComponents := getDefaultDataComponents() + args := componentsMock.GetConsensusArgs(shardCoordinator) + dataComponents := componentsMock.GetDefaultDataComponents() dataComponents.DataPool = nil args.DataComponents = dataComponents - processComponents := getDefaultProcessComponents(shardCoordinator) + processComponents := componentsMock.GetDefaultProcessComponents(shardCoordinator) args.ProcessComponents = processComponents - bcf, _ := factory.NewConsensusComponentsFactory(args) + bcf, _ := consensusComp.NewConsensusComponentsFactory(args) cc, err := bcf.Create() require.Nil(t, cc) @@ -356,12 +349,12 @@ func TestStartConsensus_MetaBootstrapperNilPoolHolder(t *testing.T) { return 0 } - args := getConsensusArgs(shardCoordinator) - dataComponents := getDefaultDataComponents() + args := componentsMock.GetConsensusArgs(shardCoordinator) + dataComponents := componentsMock.GetDefaultDataComponents() dataComponents.DataPool = nil args.DataComponents = dataComponents - args.ProcessComponents = getDefaultProcessComponents(shardCoordinator) - bcf, err := factory.NewConsensusComponentsFactory(args) + args.ProcessComponents = 
componentsMock.GetDefaultProcessComponents(shardCoordinator) + bcf, err := consensusComp.NewConsensusComponentsFactory(args) require.Nil(t, err) require.NotNil(t, bcf) cc, err := bcf.Create() @@ -377,10 +370,10 @@ func TestStartConsensus_MetaBootstrapperWrongNumberShards(t *testing.T) { } shardCoordinator := mock.NewMultiShardsCoordinatorMock(1) - args := getConsensusArgs(shardCoordinator) - processComponents := getDefaultProcessComponents(shardCoordinator) + args := componentsMock.GetConsensusArgs(shardCoordinator) + processComponents := componentsMock.GetDefaultProcessComponents(shardCoordinator) args.ProcessComponents = processComponents - bcf, err := factory.NewConsensusComponentsFactory(args) + bcf, err := consensusComp.NewConsensusComponentsFactory(args) require.Nil(t, err) shardCoordinator.CurrentShard = 2 cc, err := bcf.Create() @@ -397,15 +390,15 @@ func TestStartConsensus_ShardBootstrapperPubKeyToByteArrayError(t *testing.T) { localErr := errors.New("err") shardCoordinator := mock.NewMultiShardsCoordinatorMock(2) - args := getConsensusArgs(shardCoordinator) - cryptoParams := getDefaultCryptoComponents() + args := componentsMock.GetConsensusArgs(shardCoordinator) + cryptoParams := componentsMock.GetDefaultCryptoComponents() cryptoParams.PubKey = &mock.PublicKeyMock{ ToByteArrayHandler: func() (i []byte, err error) { return []byte("nil"), localErr }, } args.CryptoComponents = cryptoParams - bcf, _ := factory.NewConsensusComponentsFactory(args) + bcf, _ := consensusComp.NewConsensusComponentsFactory(args) cc, err := bcf.Create() require.Nil(t, cc) require.Equal(t, localErr, err) @@ -418,142 +411,11 @@ func TestStartConsensus_ShardBootstrapperInvalidConsensusType(t *testing.T) { } shardCoordinator := mock.NewMultiShardsCoordinatorMock(2) - args := getConsensusArgs(shardCoordinator) + args := componentsMock.GetConsensusArgs(shardCoordinator) args.Config.Consensus.Type = "invalid" - bcf, err := factory.NewConsensusComponentsFactory(args) + bcf, err := consensusComp.NewConsensusComponentsFactory(args) require.Nil(t, err) cc, err := bcf.Create() require.Nil(t, cc) require.Equal(t, sposFactory.ErrInvalidConsensusType, err) } - -func getConsensusArgs(shardCoordinator sharding.Coordinator) factory.ConsensusComponentsFactoryArgs { - coreComponents := getCoreComponents() - networkComponents := getNetworkComponents() - stateComponents := getStateComponents(coreComponents, shardCoordinator) - cryptoComponents := getCryptoComponents(coreComponents) - dataComponents := getDataComponents(coreComponents, shardCoordinator) - processComponents := getProcessComponents( - shardCoordinator, - coreComponents, - networkComponents, - dataComponents, - cryptoComponents, - stateComponents, - ) - statusComponents := getStatusComponents( - coreComponents, - networkComponents, - dataComponents, - stateComponents, - shardCoordinator, - processComponents.NodesCoordinator(), - ) - - args := spos.ScheduledProcessorWrapperArgs{ - SyncTimer: coreComponents.SyncTimer(), - Processor: processComponents.BlockProcessor(), - RoundTimeDurationHandler: coreComponents.RoundHandler(), - } - scheduledProcessor, _ := spos.NewScheduledProcessorWrapper(args) - - return factory.ConsensusComponentsFactoryArgs{ - Config: testscommon.GetGeneralConfig(), - BootstrapRoundIndex: 0, - CoreComponents: coreComponents, - NetworkComponents: networkComponents, - CryptoComponents: cryptoComponents, - DataComponents: dataComponents, - ProcessComponents: processComponents, - StateComponents: stateComponents, - StatusComponents: statusComponents, 
- ScheduledProcessor: scheduledProcessor, - } -} - -func getDefaultNetworkComponents() *mock.NetworkComponentsMock { - return &mock.NetworkComponentsMock{ - Messenger: &p2pmocks.MessengerStub{}, - InputAntiFlood: &mock.P2PAntifloodHandlerStub{}, - OutputAntiFlood: &mock.P2PAntifloodHandlerStub{}, - PeerBlackList: &mock.PeerBlackListHandlerStub{}, - } -} - -func getDefaultStateComponents() *testscommon.StateComponentsMock { - return &testscommon.StateComponentsMock{ - PeersAcc: &stateMock.AccountsStub{}, - Accounts: &stateMock.AccountsStub{}, - Tries: &mock.TriesHolderStub{}, - StorageManagers: map[string]common.StorageManager{ - "0": &testscommon.StorageManagerStub{}, - trieFactory.UserAccountTrie: &testscommon.StorageManagerStub{}, - trieFactory.PeerAccountTrie: &testscommon.StorageManagerStub{}, - }, - } -} - -func getDefaultDataComponents() *mock.DataComponentsMock { - return &mock.DataComponentsMock{ - Blkc: &testscommon.ChainHandlerStub{}, - Storage: &storageStubs.ChainStorerStub{}, - DataPool: &dataRetrieverMock.PoolsHolderMock{}, - MiniBlockProvider: &mock.MiniBlocksProviderStub{}, - } -} - -func getDefaultProcessComponents(shardCoordinator sharding.Coordinator) *mock.ProcessComponentsMock { - return &mock.ProcessComponentsMock{ - NodesCoord: &shardingMocks.NodesCoordinatorMock{}, - ShardCoord: shardCoordinator, - IntContainer: &testscommon.InterceptorsContainerStub{}, - ResFinder: &mock.ResolversFinderStub{}, - RoundHandlerField: &testscommon.RoundHandlerMock{}, - EpochTrigger: &testscommon.EpochStartTriggerStub{}, - EpochNotifier: &mock.EpochStartNotifierStub{}, - ForkDetect: &mock.ForkDetectorMock{}, - BlockProcess: &mock.BlockProcessorStub{}, - BlackListHdl: &testscommon.TimeCacheStub{}, - BootSore: &mock.BootstrapStorerMock{}, - HeaderSigVerif: &mock.HeaderSigVerifierStub{}, - HeaderIntegrVerif: &mock.HeaderIntegrityVerifierStub{}, - ValidatorStatistics: &mock.ValidatorStatisticsProcessorStub{}, - ValidatorProvider: &mock.ValidatorsProviderStub{}, - BlockTrack: &mock.BlockTrackerStub{}, - PendingMiniBlocksHdl: &mock.PendingMiniBlocksHandlerStub{}, - ReqHandler: &testscommon.RequestHandlerStub{}, - TxLogsProcess: &mock.TxLogProcessorMock{}, - HeaderConstructValidator: &mock.HeaderValidatorStub{}, - PeerMapper: &p2pmocks.NetworkShardingCollectorStub{}, - FallbackHdrValidator: &testscommon.FallBackHeaderValidatorStub{}, - NodeRedundancyHandlerInternal: &mock.RedundancyHandlerStub{ - IsRedundancyNodeCalled: func() bool { - return false - }, - IsMainMachineActiveCalled: func() bool { - return false - }, - ObserverPrivateKeyCalled: func() crypto.PrivateKey { - return &mock.PrivateKeyStub{} - }, - }, - HardforkTriggerField: &testscommon.HardforkTriggerStub{}, - } -} - -func getDefaultCryptoComponents() *mock.CryptoComponentsMock { - return &mock.CryptoComponentsMock{ - PubKey: &mock.PublicKeyMock{}, - PrivKey: &mock.PrivateKeyStub{}, - PubKeyString: "pubKey", - PrivKeyBytes: []byte("privKey"), - PubKeyBytes: []byte("pubKey"), - BlockSig: &mock.SinglesignMock{}, - TxSig: &mock.SinglesignMock{}, - MultiSig: &cryptoMocks.MultisignerStub{}, - PeerSignHandler: &mock.PeerSignatureHandler{}, - BlKeyGen: &mock.KeyGenMock{}, - TxKeyGen: &mock.KeyGenMock{}, - MsgSigVerifier: &testscommon.MessageSignVerifierMock{}, - } -} diff --git a/factory/constants.go b/factory/constants.go index 95d2eb61b30..347727cfbde 100644 --- a/factory/constants.go +++ b/factory/constants.go @@ -1,15 +1,26 @@ package factory const ( - bootstrapComponentsName = "managedBootstrapComponents" - consensusComponentsName = 
"managedConsensusComponents" - coreComponentsName = "managedCoreComponents" - cryptoComponentsName = "managedCryptoComponents" - dataComponentsName = "managedDataComponents" - heartbeatComponentsName = "managedHeartbeatComponents" - heartbeatV2ComponentsName = "managedHeartbeatV2Components" - networkComponentsName = "managedNetworkComponents" - processComponentsName = "managedProcessComponents" - stateComponentsName = "managedStateComponents" - statusComponentsName = "managedStatusComponents" + // BootstrapComponentsName is the bootstrap components identifier + BootstrapComponentsName = "managedBootstrapComponents" + // ConsensusComponentsName is the consensus components identifier + ConsensusComponentsName = "managedConsensusComponents" + // CoreComponentsName is the core components identifier + CoreComponentsName = "managedCoreComponents" + // CryptoComponentsName is the crypto components identifier + CryptoComponentsName = "managedCryptoComponents" + // DataComponentsName is the data components identifier + DataComponentsName = "managedDataComponents" + // HeartbeatComponentsName is the heartbeat components identifier + HeartbeatComponentsName = "managedHeartbeatComponents" + // HeartbeatV2ComponentsName is the heartbeat V2 components identifier + HeartbeatV2ComponentsName = "managedHeartbeatV2Components" + // NetworkComponentsName is the network components identifier + NetworkComponentsName = "managedNetworkComponents" + // ProcessComponentsName is the process components identifier + ProcessComponentsName = "managedProcessComponents" + // StateComponentsName is the state components identifier + StateComponentsName = "managedStateComponents" + // StatusComponentsName is the status components identifier + StatusComponentsName = "managedStatusComponents" ) diff --git a/factory/coreComponents.go b/factory/core/coreComponents.go similarity index 96% rename from factory/coreComponents.go rename to factory/core/coreComponents.go index 85bdbe2a15c..0d59bef8077 100644 --- a/factory/coreComponents.go +++ b/factory/core/coreComponents.go @@ -1,4 +1,4 @@ -package factory +package core import ( "bytes" @@ -19,7 +19,8 @@ import ( hasherFactory "github.com/ElrondNetwork/elrond-go-core/hashing/factory" "github.com/ElrondNetwork/elrond-go-core/marshal" marshalizerFactory "github.com/ElrondNetwork/elrond-go-core/marshal/factory" - "github.com/ElrondNetwork/elrond-go/cmd/node/factory" + logger "github.com/ElrondNetwork/elrond-go-logger" + nodeFactory "github.com/ElrondNetwork/elrond-go/cmd/node/factory" "github.com/ElrondNetwork/elrond-go/common" "github.com/ElrondNetwork/elrond-go/common/enablers" commonFactory "github.com/ElrondNetwork/elrond-go/common/factory" @@ -29,6 +30,7 @@ import ( "github.com/ElrondNetwork/elrond-go/consensus/round" "github.com/ElrondNetwork/elrond-go/epochStart/notifier" "github.com/ElrondNetwork/elrond-go/errors" + factory "github.com/ElrondNetwork/elrond-go/factory" "github.com/ElrondNetwork/elrond-go/node/metrics" "github.com/ElrondNetwork/elrond-go/ntp" "github.com/ElrondNetwork/elrond-go/process" @@ -42,6 +44,8 @@ import ( storageFactory "github.com/ElrondNetwork/elrond-go/storage/factory" ) +var log = logger.GetOrCreate("factory") + // CoreComponentsFactoryArgs holds the arguments needed for creating a core components factory type CoreComponentsFactoryArgs struct { Config config.Config @@ -54,7 +58,7 @@ type CoreComponentsFactoryArgs struct { NodesFilename string WorkingDirectory string ChanStopNodeProcess chan endProcess.ArgEndProcess - StatusHandlersFactory 
factory.StatusHandlerUtilsFactory + StatusHandlersFactory nodeFactory.StatusHandlerUtilsFactory } // coreComponentsFactory is responsible for creating the core components @@ -69,7 +73,7 @@ type coreComponentsFactory struct { nodesFilename string workingDir string chanStopNodeProcess chan endProcess.ArgEndProcess - statusHandlersFactory factory.StatusHandlerUtilsFactory + statusHandlersFactory nodeFactory.StatusHandlerUtilsFactory } // coreComponents is the DTO used for core components @@ -82,7 +86,7 @@ type coreComponents struct { uint64ByteSliceConverter typeConverters.Uint64ByteSliceConverter addressPubKeyConverter core.PubkeyConverter validatorPubKeyConverter core.PubkeyConverter - statusHandlersUtils factory.StatusHandlersUtils + statusHandlersUtils nodeFactory.StatusHandlersUtils pathHandler storage.PathManagerHandler syncTimer ntp.SyncTimer roundHandler consensus.RoundHandler @@ -100,7 +104,7 @@ type coreComponents struct { minTransactionVersion uint32 epochNotifier process.EpochNotifier enableRoundsHandler process.EnableRoundsHandler - epochStartNotifierWithConfirm EpochStartNotifierWithConfirm + epochStartNotifierWithConfirm factory.EpochStartNotifierWithConfirm chanStopNodeProcess chan endProcess.ArgEndProcess nodeTypeProvider core.NodeTypeProviderHandler encodedAddressLen uint32 diff --git a/factory/coreComponentsHandler.go b/factory/core/coreComponentsHandler.go similarity index 96% rename from factory/coreComponentsHandler.go rename to factory/core/coreComponentsHandler.go index a4ce3134818..b52deeb4735 100644 --- a/factory/coreComponentsHandler.go +++ b/factory/core/coreComponentsHandler.go @@ -1,4 +1,4 @@ -package factory +package core import ( "fmt" @@ -11,10 +11,11 @@ import ( "github.com/ElrondNetwork/elrond-go-core/data/typeConverters" "github.com/ElrondNetwork/elrond-go-core/hashing" "github.com/ElrondNetwork/elrond-go-core/marshal" - "github.com/ElrondNetwork/elrond-go/cmd/node/factory" + nodeFactory "github.com/ElrondNetwork/elrond-go/cmd/node/factory" "github.com/ElrondNetwork/elrond-go/common" "github.com/ElrondNetwork/elrond-go/consensus" "github.com/ElrondNetwork/elrond-go/errors" + "github.com/ElrondNetwork/elrond-go/factory" "github.com/ElrondNetwork/elrond-go/ntp" "github.com/ElrondNetwork/elrond-go/process" "github.com/ElrondNetwork/elrond-go/sharding" @@ -22,9 +23,9 @@ import ( "github.com/ElrondNetwork/elrond-go/storage" ) -var _ ComponentHandler = (*managedCoreComponents)(nil) -var _ CoreComponentsHolder = (*managedCoreComponents)(nil) -var _ CoreComponentsHandler = (*managedCoreComponents)(nil) +var _ factory.ComponentHandler = (*managedCoreComponents)(nil) +var _ factory.CoreComponentsHolder = (*managedCoreComponents)(nil) +var _ factory.CoreComponentsHandler = (*managedCoreComponents)(nil) // managedCoreComponents is an implementation of core components handler that can create, close and access the core components type managedCoreComponents struct { @@ -270,7 +271,7 @@ func (mcc *managedCoreComponents) ValidatorPubKeyConverter() core.PubkeyConverte } // StatusHandlerUtils returns the core components status handler utils -func (mcc *managedCoreComponents) StatusHandlerUtils() factory.StatusHandlersUtils { +func (mcc *managedCoreComponents) StatusHandlerUtils() nodeFactory.StatusHandlersUtils { mcc.mutCoreComponents.RLock() defer mcc.mutCoreComponents.RUnlock() @@ -510,7 +511,7 @@ func (mcc *managedCoreComponents) EnableRoundsHandler() process.EnableRoundsHand } // EpochStartNotifierWithConfirm returns the epoch notifier with confirm -func (mcc 
*managedCoreComponents) EpochStartNotifierWithConfirm() EpochStartNotifierWithConfirm { +func (mcc *managedCoreComponents) EpochStartNotifierWithConfirm() factory.EpochStartNotifierWithConfirm { mcc.mutCoreComponents.RLock() defer mcc.mutCoreComponents.RUnlock() @@ -600,5 +601,5 @@ func (mcc *managedCoreComponents) IsInterfaceNil() bool { // String returns the name of the component func (mcc *managedCoreComponents) String() string { - return coreComponentsName + return factory.CoreComponentsName } diff --git a/factory/coreComponentsHandler_test.go b/factory/core/coreComponentsHandler_test.go similarity index 79% rename from factory/coreComponentsHandler_test.go rename to factory/core/coreComponentsHandler_test.go index 52f6ac58a51..daa1ddc8dc0 100644 --- a/factory/coreComponentsHandler_test.go +++ b/factory/core/coreComponentsHandler_test.go @@ -1,10 +1,11 @@ -package factory_test +package core_test import ( "testing" "github.com/ElrondNetwork/elrond-go/config" - "github.com/ElrondNetwork/elrond-go/factory" + coreComp "github.com/ElrondNetwork/elrond-go/factory/core" + componentsMock "github.com/ElrondNetwork/elrond-go/testscommon/components" "github.com/stretchr/testify/require" ) @@ -15,13 +16,13 @@ func TestManagedCoreComponents_CreateWithInvalidArgsShouldErr(t *testing.T) { t.Skip("this is not a short test") } - coreArgs := getCoreArgs() + coreArgs := componentsMock.GetCoreArgs() coreArgs.Config.Marshalizer = config.MarshalizerConfig{ Type: "invalid_marshalizer_type", SizeCheckDelta: 0, } - coreComponentsFactory, _ := factory.NewCoreComponentsFactory(coreArgs) - managedCoreComponents, err := factory.NewManagedCoreComponents(coreComponentsFactory) + coreComponentsFactory, _ := coreComp.NewCoreComponentsFactory(coreArgs) + managedCoreComponents, err := coreComp.NewManagedCoreComponents(coreComponentsFactory) require.NoError(t, err) err = managedCoreComponents.Create() require.Error(t, err) @@ -34,9 +35,9 @@ func TestManagedCoreComponents_CreateShouldWork(t *testing.T) { t.Skip("this is not a short test") } - coreArgs := getCoreArgs() - coreComponentsFactory, _ := factory.NewCoreComponentsFactory(coreArgs) - managedCoreComponents, err := factory.NewManagedCoreComponents(coreComponentsFactory) + coreArgs := componentsMock.GetCoreArgs() + coreComponentsFactory, _ := coreComp.NewCoreComponentsFactory(coreArgs) + managedCoreComponents, err := coreComp.NewManagedCoreComponents(coreComponentsFactory) require.NoError(t, err) require.Nil(t, managedCoreComponents.Hasher()) require.Nil(t, managedCoreComponents.InternalMarshalizer()) @@ -70,7 +71,7 @@ func TestManagedCoreComponents_CreateShouldWork(t *testing.T) { require.NotNil(t, managedCoreComponents.EnableRoundsHandler()) require.NotNil(t, managedCoreComponents.ArwenChangeLocker()) require.NotNil(t, managedCoreComponents.ProcessStatusHandler()) - expectedBytes, _ := managedCoreComponents.ValidatorPubKeyConverter().Decode(dummyPk) + expectedBytes, _ := managedCoreComponents.ValidatorPubKeyConverter().Decode(componentsMock.DummyPk) require.Equal(t, expectedBytes, managedCoreComponents.HardforkTriggerPubKey()) } @@ -80,9 +81,9 @@ func TestManagedCoreComponents_Close(t *testing.T) { t.Skip("this is not a short test") } - coreArgs := getCoreArgs() - coreComponentsFactory, _ := factory.NewCoreComponentsFactory(coreArgs) - managedCoreComponents, _ := factory.NewManagedCoreComponents(coreComponentsFactory) + coreArgs := componentsMock.GetCoreArgs() + coreComponentsFactory, _ := coreComp.NewCoreComponentsFactory(coreArgs) + managedCoreComponents, _ := 
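Every accessor on the managed wrappers follows the same locking discipline seen in EpochStartNotifierWithConfirm above: take the read lock, return a zero value when the inner components have not been created or were already closed, and otherwise delegate. A stripped-down sketch of that idiom; managedThing, innerThing and Value are hypothetical names, not part of this codebase:

package main

import "sync"

type innerThing struct {
	value int
}

type managedThing struct {
	mut   sync.RWMutex
	inner *innerThing
}

// Value mirrors the managed-component accessor pattern: read-lock, tolerate
// a nil inner struct (component not yet created), otherwise delegate.
func (mt *managedThing) Value() int {
	mt.mut.RLock()
	defer mt.mut.RUnlock()

	if mt.inner == nil {
		return 0
	}
	return mt.inner.value
}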
coreComp.NewManagedCoreComponents(coreComponentsFactory) err := managedCoreComponents.Create() require.NoError(t, err) diff --git a/factory/core/coreComponents_test.go b/factory/core/coreComponents_test.go new file mode 100644 index 00000000000..bac99fb8034 --- /dev/null +++ b/factory/core/coreComponents_test.go @@ -0,0 +1,273 @@ +package core_test + +import ( + "errors" + "testing" + + "github.com/ElrondNetwork/elrond-go/config" + errorsErd "github.com/ElrondNetwork/elrond-go/errors" + coreComp "github.com/ElrondNetwork/elrond-go/factory/core" + "github.com/ElrondNetwork/elrond-go/state" + componentsMock "github.com/ElrondNetwork/elrond-go/testscommon/components" + "github.com/stretchr/testify/require" +) + +func TestNewCoreComponentsFactory_OkValuesShouldWork(t *testing.T) { + t.Parallel() + if testing.Short() { + t.Skip("this is not a short test") + } + + args := componentsMock.GetCoreArgs() + ccf, _ := coreComp.NewCoreComponentsFactory(args) + + require.NotNil(t, ccf) +} + +func TestCoreComponentsFactory_CreateCoreComponentsNoHasherConfigShouldErr(t *testing.T) { + t.Parallel() + if testing.Short() { + t.Skip("this is not a short test") + } + + args := componentsMock.GetCoreArgs() + args.Config = config.Config{ + Marshalizer: config.MarshalizerConfig{ + Type: componentsMock.TestMarshalizer, + SizeCheckDelta: 0, + }, + } + ccf, _ := coreComp.NewCoreComponentsFactory(args) + + cc, err := ccf.Create() + require.Nil(t, cc) + require.True(t, errors.Is(err, errorsErd.ErrHasherCreation)) +} + +func TestCoreComponentsFactory_CreateCoreComponentsInvalidHasherConfigShouldErr(t *testing.T) { + t.Parallel() + if testing.Short() { + t.Skip("this is not a short test") + } + + args := componentsMock.GetCoreArgs() + args.Config = config.Config{ + Marshalizer: config.MarshalizerConfig{ + Type: componentsMock.TestMarshalizer, + SizeCheckDelta: 0, + }, + Hasher: config.TypeConfig{ + Type: "invalid_type", + }, + } + ccf, _ := coreComp.NewCoreComponentsFactory(args) + + cc, err := ccf.Create() + require.Nil(t, cc) + require.True(t, errors.Is(err, errorsErd.ErrHasherCreation)) +} + +func TestCoreComponentsFactory_CreateCoreComponentsNoInternalMarshallerConfigShouldErr(t *testing.T) { + t.Parallel() + if testing.Short() { + t.Skip("this is not a short test") + } + + args := componentsMock.GetCoreArgs() + args.Config = config.Config{ + Hasher: config.TypeConfig{ + Type: componentsMock.TestHasher, + }, + } + ccf, _ := coreComp.NewCoreComponentsFactory(args) + + cc, err := ccf.Create() + require.Nil(t, cc) + require.True(t, errors.Is(err, errorsErd.ErrMarshalizerCreation)) +} + +func TestCoreComponentsFactory_CreateCoreComponentsInvalidInternalMarshallerConfigShouldErr(t *testing.T) { + t.Parallel() + if testing.Short() { + t.Skip("this is not a short test") + } + + args := componentsMock.GetCoreArgs() + args.Config = config.Config{ + Marshalizer: config.MarshalizerConfig{ + Type: "invalid_marshalizer_type", + SizeCheckDelta: 0, + }, + Hasher: config.TypeConfig{ + Type: componentsMock.TestHasher, + }, + } + ccf, _ := coreComp.NewCoreComponentsFactory(args) + + cc, err := ccf.Create() + require.Nil(t, cc) + require.True(t, errors.Is(err, errorsErd.ErrMarshalizerCreation)) +} + +func TestCoreComponentsFactory_CreateCoreComponentsNoVmMarshallerConfigShouldErr(t *testing.T) { + t.Parallel() + if testing.Short() { + t.Skip("this is not a short test") + } + + args := componentsMock.GetCoreArgs() + args.Config = config.Config{ + Hasher: config.TypeConfig{ + Type: componentsMock.TestHasher, + }, + Marshalizer: 
config.MarshalizerConfig{ + Type: componentsMock.TestMarshalizer, + SizeCheckDelta: 0, + }, + } + ccf, _ := coreComp.NewCoreComponentsFactory(args) + + cc, err := ccf.Create() + require.Nil(t, cc) + require.True(t, errors.Is(err, errorsErd.ErrMarshalizerCreation)) +} + +func TestCoreComponentsFactory_CreateCoreComponentsInvalidVmMarshallerConfigShouldErr(t *testing.T) { + t.Parallel() + if testing.Short() { + t.Skip("this is not a short test") + } + + args := componentsMock.GetCoreArgs() + args.Config = config.Config{ + Marshalizer: config.MarshalizerConfig{ + Type: componentsMock.TestMarshalizer, + SizeCheckDelta: 0, + }, + Hasher: config.TypeConfig{ + Type: componentsMock.TestHasher, + }, + VmMarshalizer: config.TypeConfig{ + Type: "invalid", + }, + } + ccf, _ := coreComp.NewCoreComponentsFactory(args) + + cc, err := ccf.Create() + require.Nil(t, cc) + require.True(t, errors.Is(err, errorsErd.ErrMarshalizerCreation)) +} + +func TestCoreComponentsFactory_CreateCoreComponentsNoTxSignMarshallerConfigShouldErr(t *testing.T) { + t.Parallel() + if testing.Short() { + t.Skip("this is not a short test") + } + + args := componentsMock.GetCoreArgs() + args.Config = config.Config{ + Hasher: config.TypeConfig{ + Type: componentsMock.TestHasher, + }, + Marshalizer: config.MarshalizerConfig{ + Type: componentsMock.TestMarshalizer, + SizeCheckDelta: 0, + }, + VmMarshalizer: config.TypeConfig{ + Type: componentsMock.TestMarshalizer, + }, + } + ccf, _ := coreComp.NewCoreComponentsFactory(args) + + cc, err := ccf.Create() + require.Nil(t, cc) + require.True(t, errors.Is(err, errorsErd.ErrMarshalizerCreation)) +} + +func TestCoreComponentsFactory_CreateCoreComponentsInvalidTxSignMarshallerConfigShouldErr(t *testing.T) { + t.Parallel() + if testing.Short() { + t.Skip("this is not a short test") + } + + args := componentsMock.GetCoreArgs() + args.Config = config.Config{ + Marshalizer: config.MarshalizerConfig{ + Type: componentsMock.TestMarshalizer, + SizeCheckDelta: 0, + }, + Hasher: config.TypeConfig{ + Type: componentsMock.TestHasher, + }, + VmMarshalizer: config.TypeConfig{ + Type: componentsMock.TestMarshalizer, + }, + TxSignMarshalizer: config.TypeConfig{ + Type: "invalid", + }, + } + ccf, _ := coreComp.NewCoreComponentsFactory(args) + + cc, err := ccf.Create() + require.Nil(t, cc) + require.True(t, errors.Is(err, errorsErd.ErrMarshalizerCreation)) +} + +func TestCoreComponentsFactory_CreateCoreComponentsInvalidValPubKeyConverterShouldErr(t *testing.T) { + t.Parallel() + if testing.Short() { + t.Skip("this is not a short test") + } + + args := componentsMock.GetCoreArgs() + args.Config.ValidatorPubkeyConverter.Type = "invalid" + ccf, _ := coreComp.NewCoreComponentsFactory(args) + + cc, err := ccf.Create() + require.Nil(t, cc) + require.True(t, errors.Is(err, state.ErrInvalidPubkeyConverterType)) +} + +func TestCoreComponentsFactory_CreateCoreComponentsInvalidAddrPubKeyConverterShouldErr(t *testing.T) { + t.Parallel() + if testing.Short() { + t.Skip("this is not a short test") + } + + args := componentsMock.GetCoreArgs() + args.Config.AddressPubkeyConverter.Type = "invalid" + ccf, _ := coreComp.NewCoreComponentsFactory(args) + + cc, err := ccf.Create() + require.Nil(t, cc) + require.True(t, errors.Is(err, state.ErrInvalidPubkeyConverterType)) +} + +func TestCoreComponentsFactory_CreateCoreComponentsShouldWork(t *testing.T) { + t.Parallel() + if testing.Short() { + t.Skip("this is not a short test") + } + + args := componentsMock.GetCoreArgs() + ccf, _ := coreComp.NewCoreComponentsFactory(args) + + cc, 
err := ccf.Create() + require.NoError(t, err) + require.NotNil(t, cc) +} + +// ------------ Test CoreComponents -------------------- +func TestCoreComponents_CloseShouldWork(t *testing.T) { + t.Parallel() + if testing.Short() { + t.Skip("this is not a short test") + } + + args := componentsMock.GetCoreArgs() + ccf, _ := coreComp.NewCoreComponentsFactory(args) + cc, _ := ccf.Create() + err := cc.Close() + + require.NoError(t, err) +} diff --git a/factory/coreComponents_test.go b/factory/coreComponents_test.go deleted file mode 100644 index 97faadee26d..00000000000 --- a/factory/coreComponents_test.go +++ /dev/null @@ -1,483 +0,0 @@ -package factory_test - -import ( - "errors" - "testing" - - "github.com/ElrondNetwork/elrond-go-core/data/endProcess" - "github.com/ElrondNetwork/elrond-go/config" - errorsErd "github.com/ElrondNetwork/elrond-go/errors" - "github.com/ElrondNetwork/elrond-go/factory" - "github.com/ElrondNetwork/elrond-go/state" - "github.com/ElrondNetwork/elrond-go/testscommon/statusHandler" - "github.com/stretchr/testify/require" -) - -const testHasher = "blake2b" -const testMarshalizer = "json" -const signedBlocksThreshold = 0.025 -const consecutiveMissedBlocksPenalty = 1.1 - -func TestNewCoreComponentsFactory_OkValuesShouldWork(t *testing.T) { - t.Parallel() - if testing.Short() { - t.Skip("this is not a short test") - } - - args := getCoreArgs() - ccf, _ := factory.NewCoreComponentsFactory(args) - - require.NotNil(t, ccf) -} - -func TestCoreComponentsFactory_CreateCoreComponentsNoHasherConfigShouldErr(t *testing.T) { - t.Parallel() - if testing.Short() { - t.Skip("this is not a short test") - } - - args := getCoreArgs() - args.Config = config.Config{ - Marshalizer: config.MarshalizerConfig{ - Type: testMarshalizer, - SizeCheckDelta: 0, - }, - } - ccf, _ := factory.NewCoreComponentsFactory(args) - - cc, err := ccf.Create() - require.Nil(t, cc) - require.True(t, errors.Is(err, errorsErd.ErrHasherCreation)) -} - -func TestCoreComponentsFactory_CreateCoreComponentsInvalidHasherConfigShouldErr(t *testing.T) { - t.Parallel() - if testing.Short() { - t.Skip("this is not a short test") - } - - args := getCoreArgs() - args.Config = config.Config{ - Marshalizer: config.MarshalizerConfig{ - Type: testMarshalizer, - SizeCheckDelta: 0, - }, - Hasher: config.TypeConfig{ - Type: "invalid_type", - }, - } - ccf, _ := factory.NewCoreComponentsFactory(args) - - cc, err := ccf.Create() - require.Nil(t, cc) - require.True(t, errors.Is(err, errorsErd.ErrHasherCreation)) -} - -func TestCoreComponentsFactory_CreateCoreComponentsNoInternalMarshallerConfigShouldErr(t *testing.T) { - t.Parallel() - if testing.Short() { - t.Skip("this is not a short test") - } - - args := getCoreArgs() - args.Config = config.Config{ - Hasher: config.TypeConfig{ - Type: testHasher, - }, - } - ccf, _ := factory.NewCoreComponentsFactory(args) - - cc, err := ccf.Create() - require.Nil(t, cc) - require.True(t, errors.Is(err, errorsErd.ErrMarshalizerCreation)) -} - -func TestCoreComponentsFactory_CreateCoreComponentsInvalidInternalMarshallerConfigShouldErr(t *testing.T) { - t.Parallel() - if testing.Short() { - t.Skip("this is not a short test") - } - - args := getCoreArgs() - args.Config = config.Config{ - Marshalizer: config.MarshalizerConfig{ - Type: "invalid_marshalizer_type", - SizeCheckDelta: 0, - }, - Hasher: config.TypeConfig{ - Type: testHasher, - }, - } - ccf, _ := factory.NewCoreComponentsFactory(args) - - cc, err := ccf.Create() - require.Nil(t, cc) - require.True(t, errors.Is(err, 
errorsErd.ErrMarshalizerCreation)) -} - -func TestCoreComponentsFactory_CreateCoreComponentsNoVmMarshallerConfigShouldErr(t *testing.T) { - t.Parallel() - if testing.Short() { - t.Skip("this is not a short test") - } - - args := getCoreArgs() - args.Config = config.Config{ - Hasher: config.TypeConfig{ - Type: testHasher, - }, - Marshalizer: config.MarshalizerConfig{ - Type: testMarshalizer, - SizeCheckDelta: 0, - }, - } - ccf, _ := factory.NewCoreComponentsFactory(args) - - cc, err := ccf.Create() - require.Nil(t, cc) - require.True(t, errors.Is(err, errorsErd.ErrMarshalizerCreation)) -} - -func TestCoreComponentsFactory_CreateCoreComponentsInvalidVmMarshallerConfigShouldErr(t *testing.T) { - t.Parallel() - if testing.Short() { - t.Skip("this is not a short test") - } - - args := getCoreArgs() - args.Config = config.Config{ - Marshalizer: config.MarshalizerConfig{ - Type: testMarshalizer, - SizeCheckDelta: 0, - }, - Hasher: config.TypeConfig{ - Type: testHasher, - }, - VmMarshalizer: config.TypeConfig{ - Type: "invalid", - }, - } - ccf, _ := factory.NewCoreComponentsFactory(args) - - cc, err := ccf.Create() - require.Nil(t, cc) - require.True(t, errors.Is(err, errorsErd.ErrMarshalizerCreation)) -} - -func TestCoreComponentsFactory_CreateCoreComponentsNoTxSignMarshallerConfigShouldErr(t *testing.T) { - t.Parallel() - if testing.Short() { - t.Skip("this is not a short test") - } - - args := getCoreArgs() - args.Config = config.Config{ - Hasher: config.TypeConfig{ - Type: testHasher, - }, - Marshalizer: config.MarshalizerConfig{ - Type: testMarshalizer, - SizeCheckDelta: 0, - }, - VmMarshalizer: config.TypeConfig{ - Type: testMarshalizer, - }, - } - ccf, _ := factory.NewCoreComponentsFactory(args) - - cc, err := ccf.Create() - require.Nil(t, cc) - require.True(t, errors.Is(err, errorsErd.ErrMarshalizerCreation)) -} - -func TestCoreComponentsFactory_CreateCoreComponentsInvalidTxSignMarshallerConfigShouldErr(t *testing.T) { - t.Parallel() - if testing.Short() { - t.Skip("this is not a short test") - } - - args := getCoreArgs() - args.Config = config.Config{ - Marshalizer: config.MarshalizerConfig{ - Type: testMarshalizer, - SizeCheckDelta: 0, - }, - Hasher: config.TypeConfig{ - Type: testHasher, - }, - VmMarshalizer: config.TypeConfig{ - Type: testMarshalizer, - }, - TxSignMarshalizer: config.TypeConfig{ - Type: "invalid", - }, - } - ccf, _ := factory.NewCoreComponentsFactory(args) - - cc, err := ccf.Create() - require.Nil(t, cc) - require.True(t, errors.Is(err, errorsErd.ErrMarshalizerCreation)) -} - -func TestCoreComponentsFactory_CreateCoreComponentsInvalidValPubKeyConverterShouldErr(t *testing.T) { - t.Parallel() - if testing.Short() { - t.Skip("this is not a short test") - } - - args := getCoreArgs() - args.Config.ValidatorPubkeyConverter.Type = "invalid" - ccf, _ := factory.NewCoreComponentsFactory(args) - - cc, err := ccf.Create() - require.Nil(t, cc) - require.True(t, errors.Is(err, state.ErrInvalidPubkeyConverterType)) -} - -func TestCoreComponentsFactory_CreateCoreComponentsInvalidAddrPubKeyConverterShouldErr(t *testing.T) { - t.Parallel() - if testing.Short() { - t.Skip("this is not a short test") - } - - args := getCoreArgs() - args.Config.AddressPubkeyConverter.Type = "invalid" - ccf, _ := factory.NewCoreComponentsFactory(args) - - cc, err := ccf.Create() - require.Nil(t, cc) - require.True(t, errors.Is(err, state.ErrInvalidPubkeyConverterType)) -} - -func TestCoreComponentsFactory_CreateCoreComponentsShouldWork(t *testing.T) { - t.Parallel() - if testing.Short() { - t.Skip("this 
is not a short test") - } - - args := getCoreArgs() - ccf, _ := factory.NewCoreComponentsFactory(args) - - cc, err := ccf.Create() - require.NoError(t, err) - require.NotNil(t, cc) -} - -// ------------ Test CoreComponents -------------------- -func TestCoreComponents_CloseShouldWork(t *testing.T) { - t.Parallel() - if testing.Short() { - t.Skip("this is not a short test") - } - - args := getCoreArgs() - ccf, _ := factory.NewCoreComponentsFactory(args) - cc, _ := ccf.Create() - err := cc.Close() - - require.NoError(t, err) -} - -func getEpochStartConfig() config.EpochStartConfig { - return config.EpochStartConfig{ - MinRoundsBetweenEpochs: 20, - RoundsPerEpoch: 20, - MaxShuffledOutRestartThreshold: 0.2, - MinShuffledOutRestartThreshold: 0.1, - MinNumConnectedPeersToStart: 2, - MinNumOfPeersToConsiderBlockValid: 2, - } -} - -func getCoreArgs() factory.CoreComponentsFactoryArgs { - return factory.CoreComponentsFactoryArgs{ - Config: config.Config{ - EpochStartConfig: getEpochStartConfig(), - PublicKeyPeerId: config.CacheConfig{ - Type: "LRU", - Capacity: 5000, - Shards: 16, - }, - PublicKeyShardId: config.CacheConfig{ - Type: "LRU", - Capacity: 5000, - Shards: 16, - }, - PeerIdShardId: config.CacheConfig{ - Type: "LRU", - Capacity: 5000, - Shards: 16, - }, - PeerHonesty: config.CacheConfig{ - Type: "LRU", - Capacity: 5000, - Shards: 16, - }, - GeneralSettings: config.GeneralSettingsConfig{ - ChainID: "undefined", - MinTransactionVersion: 1, - GenesisMaxNumberOfShards: 3, - }, - Marshalizer: config.MarshalizerConfig{ - Type: testMarshalizer, - SizeCheckDelta: 0, - }, - Hasher: config.TypeConfig{ - Type: testHasher, - }, - VmMarshalizer: config.TypeConfig{ - Type: testMarshalizer, - }, - TxSignMarshalizer: config.TypeConfig{ - Type: testMarshalizer, - }, - TxSignHasher: config.TypeConfig{ - Type: testHasher, - }, - AddressPubkeyConverter: config.PubkeyConfig{ - Length: 32, - Type: "bech32", - SignatureLength: 0, - }, - ValidatorPubkeyConverter: config.PubkeyConfig{ - Length: 96, - Type: "hex", - SignatureLength: 48, - }, - Consensus: config.ConsensusConfig{ - Type: "bls", - }, - ValidatorStatistics: config.ValidatorStatisticsConfig{ - CacheRefreshIntervalInSec: uint32(100), - }, - SoftwareVersionConfig: config.SoftwareVersionConfig{ - PollingIntervalInMinutes: 30, - }, - Versions: config.VersionsConfig{ - DefaultVersion: "1", - VersionsByEpochs: nil, - Cache: config.CacheConfig{ - Type: "LRU", - Capacity: 1000, - Shards: 1, - }, - }, - PeersRatingConfig: config.PeersRatingConfig{ - TopRatedCacheCapacity: 1000, - BadRatedCacheCapacity: 1000, - }, - PoolsCleanersConfig: config.PoolsCleanersConfig{ - MaxRoundsToKeepUnprocessedMiniBlocks: 50, - MaxRoundsToKeepUnprocessedTransactions: 50, - }, - Hardfork: config.HardforkConfig{ - PublicKeyToListenFrom: dummyPk, - }, - HeartbeatV2: config.HeartbeatV2Config{ - HeartbeatExpiryTimespanInSec: 10, - }, - }, - ConfigPathsHolder: config.ConfigurationPathsHolder{ - GasScheduleDirectoryName: "../cmd/node/config/gasSchedules", - }, - RatingsConfig: createDummyRatingsConfig(), - EconomicsConfig: createDummyEconomicsConfig(), - NodesFilename: "mock/testdata/nodesSetupMock.json", - WorkingDirectory: "home", - ChanStopNodeProcess: make(chan endProcess.ArgEndProcess), - StatusHandlersFactory: &statusHandler.StatusHandlersFactoryMock{}, - EpochConfig: config.EpochConfig{ - GasSchedule: config.GasScheduleConfig{ - GasScheduleByEpochs: []config.GasScheduleByEpochs{ - { - StartEpoch: 0, - FileName: "gasScheduleV1.toml", - }, - }, - }, - }, - RoundConfig: 
config.RoundConfig{ - RoundActivations: map[string]config.ActivationRoundByName{ - "Example": { - Round: "18446744073709551615", - }, - }, - }, - } -} - -func createDummyEconomicsConfig() config.EconomicsConfig { - return config.EconomicsConfig{ - GlobalSettings: config.GlobalSettings{ - GenesisTotalSupply: "20000000000000000000000000", - MinimumInflation: 0, - YearSettings: []*config.YearSetting{ - { - Year: 0, - MaximumInflation: 0.01, - }, - }, - }, - RewardsSettings: config.RewardsSettings{ - RewardsConfigByEpoch: []config.EpochRewardSettings{ - { - LeaderPercentage: 0.1, - ProtocolSustainabilityPercentage: 0.1, - ProtocolSustainabilityAddress: "erd1932eft30w753xyvme8d49qejgkjc09n5e49w4mwdjtm0neld797su0dlxp", - TopUpFactor: 0.25, - TopUpGradientPoint: "3000000000000000000000000", - }, - }, - }, - FeeSettings: config.FeeSettings{ - GasLimitSettings: []config.GasLimitSetting{ - { - MaxGasLimitPerBlock: "1500000000", - MaxGasLimitPerMiniBlock: "1500000000", - MaxGasLimitPerMetaBlock: "15000000000", - MaxGasLimitPerMetaMiniBlock: "15000000000", - MaxGasLimitPerTx: "1500000000", - MinGasLimit: "50000", - }, - }, - MinGasPrice: "1000000000", - GasPerDataByte: "1500", - GasPriceModifier: 1, - }, - } -} - -func createDummyRatingsConfig() config.RatingsConfig { - return config.RatingsConfig{ - General: config.General{ - StartRating: 5000001, - MaxRating: 10000000, - MinRating: 1, - SignedBlocksThreshold: signedBlocksThreshold, - SelectionChances: []*config.SelectionChance{ - {MaxThreshold: 0, ChancePercent: 5}, - {MaxThreshold: 2500000, ChancePercent: 19}, - {MaxThreshold: 7500000, ChancePercent: 20}, - {MaxThreshold: 10000000, ChancePercent: 21}, - }, - }, - ShardChain: config.ShardChain{ - RatingSteps: config.RatingSteps{ - HoursToMaxRatingFromStartRating: 2, - ProposerValidatorImportance: 1, - ProposerDecreaseFactor: -4, - ValidatorDecreaseFactor: -4, - ConsecutiveMissedBlocksPenalty: consecutiveMissedBlocksPenalty, - }, - }, - MetaChain: config.MetaChain{ - RatingSteps: config.RatingSteps{ - HoursToMaxRatingFromStartRating: 2, - ProposerValidatorImportance: 1, - ProposerDecreaseFactor: -4, - ValidatorDecreaseFactor: -4, - ConsecutiveMissedBlocksPenalty: consecutiveMissedBlocksPenalty, - }, - }, - } -} diff --git a/factory/cryptoComponents.go b/factory/crypto/cryptoComponents.go similarity index 74% rename from factory/cryptoComponents.go rename to factory/crypto/cryptoComponents.go index 5f272fa7b3f..7d5d2c7a893 100644 --- a/factory/cryptoComponents.go +++ b/factory/crypto/cryptoComponents.go @@ -1,4 +1,4 @@ -package factory +package crypto import ( "bytes" @@ -6,30 +6,27 @@ import ( "fmt" "github.com/ElrondNetwork/elrond-go-core/core/check" - "github.com/ElrondNetwork/elrond-go-core/hashing" - "github.com/ElrondNetwork/elrond-go-core/hashing/blake2b" - "github.com/ElrondNetwork/elrond-go-core/hashing/sha256" "github.com/ElrondNetwork/elrond-go-crypto" "github.com/ElrondNetwork/elrond-go-crypto/signing" disabledCrypto "github.com/ElrondNetwork/elrond-go-crypto/signing/disabled" - disabledMultiSig "github.com/ElrondNetwork/elrond-go-crypto/signing/disabled/multisig" disabledSig "github.com/ElrondNetwork/elrond-go-crypto/signing/disabled/singlesig" "github.com/ElrondNetwork/elrond-go-crypto/signing/ed25519" "github.com/ElrondNetwork/elrond-go-crypto/signing/ed25519/singlesig" "github.com/ElrondNetwork/elrond-go-crypto/signing/mcl" - mclMultiSig "github.com/ElrondNetwork/elrond-go-crypto/signing/mcl/multisig" mclSig "github.com/ElrondNetwork/elrond-go-crypto/signing/mcl/singlesig" - 
"github.com/ElrondNetwork/elrond-go-crypto/signing/multisig" + logger "github.com/ElrondNetwork/elrond-go-logger" + cryptoCommon "github.com/ElrondNetwork/elrond-go/common/crypto" "github.com/ElrondNetwork/elrond-go/config" "github.com/ElrondNetwork/elrond-go/consensus" "github.com/ElrondNetwork/elrond-go/errors" + "github.com/ElrondNetwork/elrond-go/factory" "github.com/ElrondNetwork/elrond-go/factory/peerSignatureHandler" "github.com/ElrondNetwork/elrond-go/genesis/process/disabled" "github.com/ElrondNetwork/elrond-go/heartbeat" "github.com/ElrondNetwork/elrond-go/keysManagement" - p2pCrypto "github.com/ElrondNetwork/elrond-go/p2p/crypto" + "github.com/ElrondNetwork/elrond-go/p2p" storageFactory "github.com/ElrondNetwork/elrond-go/storage/factory" - "github.com/ElrondNetwork/elrond-go/storage/storageUnit" + "github.com/ElrondNetwork/elrond-go/storage/storageunit" "github.com/ElrondNetwork/elrond-go/vm" systemVM "github.com/ElrondNetwork/elrond-go/vm/process" ) @@ -44,9 +41,10 @@ type CryptoComponentsFactoryArgs struct { ValidatorKeyPemFileName string SkIndex int Config config.Config + EnableEpochs config.EnableEpochs PrefsConfig config.Preferences - CoreComponentsHolder CoreComponentsHolder - KeyLoader KeyLoaderHandler + CoreComponentsHolder factory.CoreComponentsHolder + KeyLoader factory.KeyLoaderHandler ActivateBLSPubKeyMessageVerification bool IsInImportMode bool ImportModeNoSigCheck bool @@ -58,10 +56,11 @@ type cryptoComponentsFactory struct { validatorKeyPemFileName string skIndex int config config.Config + enableEpochs config.EnableEpochs prefsConfig config.Preferences - coreComponentsHolder CoreComponentsHolder + coreComponentsHolder factory.CoreComponentsHolder activateBLSPubKeyMessageVerification bool - keyLoader KeyLoaderHandler + keyLoader factory.KeyLoaderHandler isInImportMode bool importModeNoSigCheck bool noKeyProvided bool @@ -78,18 +77,20 @@ type cryptoParams struct { // cryptoComponents struct holds the crypto components type cryptoComponents struct { - txSingleSigner crypto.SingleSigner - blockSingleSigner crypto.SingleSigner - multiSigner crypto.MultiSigner - peerSignHandler crypto.PeerSignatureHandler - blockSignKeyGen crypto.KeyGenerator - txSignKeyGen crypto.KeyGenerator - messageSignVerifier vm.MessageSignVerifier - managedPeersHolder heartbeat.ManagedPeersHolder - keysHandler consensus.KeysHandler + txSingleSigner crypto.SingleSigner + blockSingleSigner crypto.SingleSigner + multiSignerContainer cryptoCommon.MultiSignerContainer + peerSignHandler crypto.PeerSignatureHandler + blockSignKeyGen crypto.KeyGenerator + txSignKeyGen crypto.KeyGenerator + messageSignVerifier vm.MessageSignVerifier + managedPeersHolder heartbeat.ManagedPeersHolder + keysHandler consensus.KeysHandler cryptoParams } +var log = logger.GetOrCreate("factory") + // NewCryptoComponentsFactory returns a new crypto components factory func NewCryptoComponentsFactory(args CryptoComponentsFactoryArgs) (*cryptoComponentsFactory, error) { if check.IfNil(args.CoreComponentsHolder) { @@ -113,6 +114,7 @@ func NewCryptoComponentsFactory(args CryptoComponentsFactoryArgs) (*cryptoCompon keyLoader: args.KeyLoader, isInImportMode: args.IsInImportMode, importModeNoSigCheck: args.ImportModeNoSigCheck, + enableEpochs: args.EnableEpochs, noKeyProvided: args.NoKeyProvided, } @@ -144,12 +146,7 @@ func (ccf *cryptoComponentsFactory) Create() (*cryptoComponents, error) { return nil, err } - multisigHasher, err := ccf.getMultiSigHasherFromConfig() - if err != nil { - return nil, err - } - - multiSigner, err := 
ccf.createMultiSigner(multisigHasher, cp, blockSignKeyGen, ccf.importModeNoSigCheck) + multiSigner, err := ccf.createMultiSignerContainer(blockSignKeyGen, ccf.importModeNoSigCheck) if err != nil { return nil, err } @@ -168,7 +165,7 @@ func (ccf *cryptoComponentsFactory) Create() (*cryptoComponents, error) { } cacheConfig := ccf.config.PublicKeyPIDSignature - cachePkPIDSignature, err := storageUnit.NewCache(storageFactory.GetCacherFromConfig(cacheConfig)) + cachePkPIDSignature, err := storageunit.NewCache(storageFactory.GetCacherFromConfig(cacheConfig)) if err != nil { return nil, err } @@ -183,7 +180,7 @@ func (ccf *cryptoComponentsFactory) Create() (*cryptoComponents, error) { isMainMachine := redundancyLevel == mainMachineRedundancyLevel argsManagedPeersHolder := keysManagement.ArgsManagedPeersHolder{ KeyGenerator: blockSignKeyGen, - P2PIdentityGenerator: p2pCrypto.NewIdentityGenerator(), + P2PIdentityGenerator: p2p.NewRandomP2PIdentityGenerator(), IsMainMachine: isMainMachine, MaxRoundsWithoutReceivedMessages: redundancyLevel, PrefsConfig: ccf.prefsConfig, @@ -196,16 +193,16 @@ func (ccf *cryptoComponentsFactory) Create() (*cryptoComponents, error) { log.Debug("block sign pubkey", "value", cp.publicKeyString) return &cryptoComponents{ - txSingleSigner: txSingleSigner, - blockSingleSigner: interceptSingleSigner, - multiSigner: multiSigner, - peerSignHandler: peerSigHandler, - blockSignKeyGen: blockSignKeyGen, - txSignKeyGen: txSignKeyGen, - messageSignVerifier: messageSignVerifier, - managedPeersHolder: managedPeersHolder, - keysHandler: keysManagement.NewKeysHandler(), - cryptoParams: *cp, + txSingleSigner: txSingleSigner, + blockSingleSigner: interceptSingleSigner, + multiSignerContainer: multiSigner, + peerSignHandler: peerSigHandler, + blockSignKeyGen: blockSignKeyGen, + txSignKeyGen: txSignKeyGen, + messageSignVerifier: messageSignVerifier, + managedPeersHolder: managedPeersHolder, + keysHandler: keysManagement.NewKeysHandler(), + cryptoParams: *cp, }, nil } @@ -226,45 +223,18 @@ func (ccf *cryptoComponentsFactory) createSingleSigner(importModeNoSigCheck bool } } -func (ccf *cryptoComponentsFactory) getMultiSigHasherFromConfig() (hashing.Hasher, error) { - if ccf.consensusType == consensus.BlsConsensusType && ccf.config.MultisigHasher.Type != "blake2b" { - return nil, errors.ErrMultiSigHasherMissmatch - } - - switch ccf.config.MultisigHasher.Type { - case "sha256": - return sha256.NewSha256(), nil - case "blake2b": - if ccf.consensusType == consensus.BlsConsensusType { - return blake2b.NewBlake2bWithSize(multisig.BlsHashSize) - } - return blake2b.NewBlake2b(), nil - } - - return nil, errors.ErrMissingMultiHasherConfig -} - -func (ccf *cryptoComponentsFactory) createMultiSigner( - hasher hashing.Hasher, - cp *cryptoParams, +func (ccf *cryptoComponentsFactory) createMultiSignerContainer( blSignKeyGen crypto.KeyGenerator, importModeNoSigCheck bool, -) (crypto.MultiSigner, error) { - if importModeNoSigCheck { - log.Warn("using disabled multi signer because the node is running in import-db 'turbo mode'") - return &disabledMultiSig.DisabledMultiSig{}, nil - } +) (cryptoCommon.MultiSignerContainer, error) { - switch ccf.consensusType { - case consensus.BlsConsensusType: - blsSigner := &mclMultiSig.BlsMultiSigner{Hasher: hasher} - return multisig.NewBLSMultisig(blsSigner, []string{string(cp.publicKeyBytes)}, cp.privateKey, blSignKeyGen, uint16(0)) - case disabledSigChecking: - log.Warn("using disabled multi signer") - return &disabledMultiSig.DisabledMultiSig{}, nil - default: - return nil, 
errors.ErrInvalidConsensusConfig + args := MultiSigArgs{ + MultiSigHasherType: ccf.config.MultisigHasher.Type, + BlSignKeyGen: blSignKeyGen, + ConsensusType: ccf.consensusType, + ImportModeNoSigCheck: importModeNoSigCheck, } + return NewMultiSignerContainer(args, ccf.enableEpochs.BLSMultiSignerEnableEpoch) } func (ccf *cryptoComponentsFactory) getSuite() (crypto.Suite, error) { diff --git a/factory/cryptoComponentsHandler.go b/factory/crypto/cryptoComponentsHandler.go similarity index 79% rename from factory/cryptoComponentsHandler.go rename to factory/crypto/cryptoComponentsHandler.go index df0a0590323..de44d33ee5a 100644 --- a/factory/cryptoComponentsHandler.go +++ b/factory/crypto/cryptoComponentsHandler.go @@ -1,21 +1,23 @@ -package factory +package crypto import ( "fmt" "sync" "github.com/ElrondNetwork/elrond-go-core/core/check" - "github.com/ElrondNetwork/elrond-go-crypto" + crypto "github.com/ElrondNetwork/elrond-go-crypto" + cryptoCommon "github.com/ElrondNetwork/elrond-go/common/crypto" "github.com/ElrondNetwork/elrond-go/consensus" "github.com/ElrondNetwork/elrond-go/errors" + "github.com/ElrondNetwork/elrond-go/factory" "github.com/ElrondNetwork/elrond-go/heartbeat" "github.com/ElrondNetwork/elrond-go/vm" ) -var _ ComponentHandler = (*managedCryptoComponents)(nil) -var _ CryptoParamsHolder = (*managedCryptoComponents)(nil) -var _ CryptoComponentsHolder = (*managedCryptoComponents)(nil) -var _ CryptoComponentsHandler = (*managedCryptoComponents)(nil) +var _ factory.ComponentHandler = (*managedCryptoComponents)(nil) +var _ factory.CryptoParamsHolder = (*managedCryptoComponents)(nil) +var _ factory.CryptoComponentsHolder = (*managedCryptoComponents)(nil) +var _ factory.CryptoComponentsHandler = (*managedCryptoComponents)(nil) // CryptoComponentsHandlerArgs holds the arguments required to create a crypto components handler type CryptoComponentsHandlerArgs CryptoComponentsFactoryArgs @@ -91,7 +93,7 @@ func (mcc *managedCryptoComponents) CheckSubcomponents() error { if check.IfNil(mcc.cryptoComponents.blockSingleSigner) { return errors.ErrNilBlockSigner } - if check.IfNil(mcc.cryptoComponents.multiSigner) { + if check.IfNil(mcc.cryptoComponents.multiSignerContainer) { return errors.ErrNilMultiSigner } if check.IfNil(mcc.cryptoComponents.peerSignHandler) { @@ -197,41 +199,52 @@ func (mcc *managedCryptoComponents) BlockSigner() crypto.SingleSigner { return mcc.cryptoComponents.blockSingleSigner } -// MultiSigner returns the block multi-signer -func (mcc *managedCryptoComponents) MultiSigner() crypto.MultiSigner { +// MultiSignerContainer returns the multiSigner container holding the multiSigner versions +func (mcc *managedCryptoComponents) MultiSignerContainer() cryptoCommon.MultiSignerContainer { mcc.mutCryptoComponents.RLock() defer mcc.mutCryptoComponents.RUnlock() - if mcc.cryptoComponents == nil { return nil } - return mcc.cryptoComponents.multiSigner + return mcc.cryptoComponents.multiSignerContainer } -// PeerSignatureHandler returns the peer signature handler -func (mcc *managedCryptoComponents) PeerSignatureHandler() crypto.PeerSignatureHandler { +// SetMultiSignerContainer sets the multiSigner container in the crypto components +func (mcc *managedCryptoComponents) SetMultiSignerContainer(ms cryptoCommon.MultiSignerContainer) error { + mcc.mutCryptoComponents.Lock() + mcc.multiSignerContainer = ms + mcc.mutCryptoComponents.Unlock() + + return nil +} + +// GetMultiSigner returns the multiSigner configured in the multiSigner container for the given epoch +func (mcc 
*managedCryptoComponents) GetMultiSigner(epoch uint32) (crypto.MultiSigner, error) { mcc.mutCryptoComponents.RLock() defer mcc.mutCryptoComponents.RUnlock() if mcc.cryptoComponents == nil { - return nil + return nil, errors.ErrNilCryptoComponentsHolder } - return mcc.cryptoComponents.peerSignHandler + if mcc.multiSignerContainer == nil { + return nil, errors.ErrNilMultiSignerContainer + } + + return mcc.MultiSignerContainer().GetMultiSigner(epoch) } -// SetMultiSigner sets the block multi-signer -func (mcc *managedCryptoComponents) SetMultiSigner(ms crypto.MultiSigner) error { - mcc.mutCryptoComponents.Lock() - defer mcc.mutCryptoComponents.Unlock() +// PeerSignatureHandler returns the peer signature handler +func (mcc *managedCryptoComponents) PeerSignatureHandler() crypto.PeerSignatureHandler { + mcc.mutCryptoComponents.RLock() + defer mcc.mutCryptoComponents.RUnlock() if mcc.cryptoComponents == nil { - return errors.ErrNilCryptoComponents + return nil } - mcc.cryptoComponents.multiSigner = ms - return nil + return mcc.cryptoComponents.peerSignHandler } // BlockSignKeyGen returns the block signer key generator @@ -299,16 +312,16 @@ func (mcc *managedCryptoComponents) Clone() interface{} { cryptoComp := (*cryptoComponents)(nil) if mcc.cryptoComponents != nil { cryptoComp = &cryptoComponents{ - txSingleSigner: mcc.TxSingleSigner(), - blockSingleSigner: mcc.BlockSigner(), - multiSigner: mcc.MultiSigner(), - peerSignHandler: mcc.PeerSignatureHandler(), - blockSignKeyGen: mcc.BlockSignKeyGen(), - txSignKeyGen: mcc.TxSignKeyGen(), - messageSignVerifier: mcc.MessageSignVerifier(), - managedPeersHolder: mcc.ManagedPeersHolder(), - keysHandler: mcc.KeysHandler(), - cryptoParams: mcc.cryptoParams, + txSingleSigner: mcc.TxSingleSigner(), + blockSingleSigner: mcc.BlockSigner(), + multiSignerContainer: mcc.MultiSignerContainer(), + peerSignHandler: mcc.PeerSignatureHandler(), + blockSignKeyGen: mcc.BlockSignKeyGen(), + txSignKeyGen: mcc.TxSignKeyGen(), + messageSignVerifier: mcc.MessageSignVerifier(), + managedPeersHolder: mcc.ManagedPeersHolder(), + keysHandler: mcc.KeysHandler(), + cryptoParams: mcc.cryptoParams, } } @@ -326,5 +339,5 @@ func (mcc *managedCryptoComponents) IsInterfaceNil() bool { // String returns the name of the component func (mcc *managedCryptoComponents) String() string { - return cryptoComponentsName + return factory.CryptoComponentsName } diff --git a/factory/cryptoComponentsHandler_test.go b/factory/crypto/cryptoComponentsHandler_test.go similarity index 62% rename from factory/cryptoComponentsHandler_test.go rename to factory/crypto/cryptoComponentsHandler_test.go index b7a3047e58e..1dc869b164b 100644 --- a/factory/cryptoComponentsHandler_test.go +++ b/factory/crypto/cryptoComponentsHandler_test.go @@ -1,10 +1,12 @@ -package factory_test +package crypto_test import ( "testing" + "github.com/ElrondNetwork/elrond-go/errors" "github.com/ElrondNetwork/elrond-go/factory" - "github.com/ElrondNetwork/elrond-go/testscommon/cryptoMocks" + cryptoComp "github.com/ElrondNetwork/elrond-go/factory/crypto" + componentsMock "github.com/ElrondNetwork/elrond-go/testscommon/components" "github.com/stretchr/testify/require" ) @@ -15,11 +17,11 @@ func TestManagedCryptoComponents_CreateWithInvalidArgsShouldErr(t *testing.T) { t.Skip("this is not a short test") } - coreComponents := getCoreComponents() - args := getCryptoArgs(coreComponents) + coreComponents := componentsMock.GetCoreComponents() + args := componentsMock.GetCryptoArgs(coreComponents) args.Config.Consensus.Type = "invalid" - 
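GetMultiSigner above replaces the old single-instance getter: callers now ask for the multi-signer that is active in a given epoch rather than caching one at startup. A hedged caller-side sketch, assuming the factory.CryptoComponentsHolder interface exposes GetMultiSigner as the managed implementation suggests; verifyForEpoch and its package are illustrative only:

package caller // hypothetical package, for illustration

import (
	"github.com/ElrondNetwork/elrond-go/factory"
)

// verifyForEpoch resolves the multi-signer active in the epoch being
// processed instead of holding on to a single instance.
func verifyForEpoch(components factory.CryptoComponentsHolder, epoch uint32) error {
	multiSigner, err := components.GetMultiSigner(epoch)
	if err != nil {
		// a nil holder (components closed) or an unset container surfaces
		// here as the typed errors returned by the managed wrapper above
		return err
	}

	// aggregated-signature verification with multiSigner would go here
	_ = multiSigner
	return nil
}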
cryptoComponentsFactory, _ := factory.NewCryptoComponentsFactory(args) - managedCryptoComponents, err := factory.NewManagedCryptoComponents(cryptoComponentsFactory) + cryptoComponentsFactory, _ := cryptoComp.NewCryptoComponentsFactory(args) + managedCryptoComponents, err := cryptoComp.NewManagedCryptoComponents(cryptoComponentsFactory) require.NoError(t, err) err = managedCryptoComponents.Create() require.Error(t, err) @@ -32,14 +34,14 @@ func TestManagedCryptoComponents_CreateShouldWork(t *testing.T) { t.Skip("this is not a short test") } - coreComponents := getCoreComponents() - args := getCryptoArgs(coreComponents) - cryptoComponentsFactory, _ := factory.NewCryptoComponentsFactory(args) - managedCryptoComponents, err := factory.NewManagedCryptoComponents(cryptoComponentsFactory) + coreComponents := componentsMock.GetCoreComponents() + args := componentsMock.GetCryptoArgs(coreComponents) + cryptoComponentsFactory, _ := cryptoComp.NewCryptoComponentsFactory(args) + managedCryptoComponents, err := cryptoComp.NewManagedCryptoComponents(cryptoComponentsFactory) require.NoError(t, err) require.Nil(t, managedCryptoComponents.TxSingleSigner()) require.Nil(t, managedCryptoComponents.BlockSigner()) - require.Nil(t, managedCryptoComponents.MultiSigner()) + require.Nil(t, managedCryptoComponents.MultiSignerContainer()) require.Nil(t, managedCryptoComponents.BlockSignKeyGen()) require.Nil(t, managedCryptoComponents.TxSignKeyGen()) require.Nil(t, managedCryptoComponents.MessageSignVerifier()) @@ -49,7 +51,10 @@ func TestManagedCryptoComponents_CreateShouldWork(t *testing.T) { require.NoError(t, err) require.NotNil(t, managedCryptoComponents.TxSingleSigner()) require.NotNil(t, managedCryptoComponents.BlockSigner()) - require.NotNil(t, managedCryptoComponents.MultiSigner()) + require.NotNil(t, managedCryptoComponents.MultiSignerContainer()) + multiSigner, errGet := managedCryptoComponents.MultiSignerContainer().GetMultiSigner(0) + require.NotNil(t, multiSigner) + require.Nil(t, errGet) require.NotNil(t, managedCryptoComponents.BlockSignKeyGen()) require.NotNil(t, managedCryptoComponents.TxSignKeyGen()) require.NotNil(t, managedCryptoComponents.MessageSignVerifier()) @@ -68,21 +73,6 @@ func TestManagedCryptoComponents_CheckSubcomponents(t *testing.T) { require.NoError(t, err) } -func TestManagedCryptoComponents_SetMultiSigner(t *testing.T) { - t.Parallel() - if testing.Short() { - t.Skip("this is not a short test") - } - - managedCryptoComponents := getManagedCryptoComponents(t) - - ms := &cryptoMocks.MultisignerMock{} - err := managedCryptoComponents.SetMultiSigner(ms) - require.NoError(t, err) - - require.Equal(t, managedCryptoComponents.MultiSigner(), ms) -} - func TestManagedCryptoComponents_Close(t *testing.T) { t.Parallel() if testing.Short() { @@ -93,15 +83,17 @@ func TestManagedCryptoComponents_Close(t *testing.T) { err := managedCryptoComponents.Close() require.NoError(t, err) - require.Nil(t, managedCryptoComponents.MultiSigner()) + multiSigner, errGet := managedCryptoComponents.GetMultiSigner(0) + require.Nil(t, multiSigner) + require.Equal(t, errors.ErrNilCryptoComponentsHolder, errGet) } func getManagedCryptoComponents(t *testing.T) factory.CryptoComponentsHandler { - coreComponents := getCoreComponents() - args := getCryptoArgs(coreComponents) - cryptoComponentsFactory, _ := factory.NewCryptoComponentsFactory(args) + coreComponents := componentsMock.GetCoreComponents() + args := componentsMock.GetCryptoArgs(coreComponents) + cryptoComponentsFactory, _ := 
cryptoComp.NewCryptoComponentsFactory(args) require.NotNil(t, cryptoComponentsFactory) - managedCryptoComponents, _ := factory.NewManagedCryptoComponents(cryptoComponentsFactory) + managedCryptoComponents, _ := cryptoComp.NewManagedCryptoComponents(cryptoComponentsFactory) require.NotNil(t, managedCryptoComponents) err := managedCryptoComponents.Create() require.NoError(t, err) @@ -115,10 +107,10 @@ func TestManagedCryptoComponents_Clone(t *testing.T) { t.Skip("this is not a short test") } - coreComponents := getCoreComponents() - args := getCryptoArgs(coreComponents) - cryptoComponentsFactory, _ := factory.NewCryptoComponentsFactory(args) - managedCryptoComponents, _ := factory.NewManagedCryptoComponents(cryptoComponentsFactory) + coreComponents := componentsMock.GetCoreComponents() + args := componentsMock.GetCryptoArgs(coreComponents) + cryptoComponentsFactory, _ := cryptoComp.NewCryptoComponentsFactory(args) + managedCryptoComponents, _ := cryptoComp.NewManagedCryptoComponents(cryptoComponentsFactory) err := managedCryptoComponents.Create() require.NoError(t, err) diff --git a/factory/cryptoComponents_test.go b/factory/crypto/cryptoComponents_test.go similarity index 50% rename from factory/cryptoComponents_test.go rename to factory/crypto/cryptoComponents_test.go index 58fbeb9236d..313ddc6f871 100644 --- a/factory/cryptoComponents_test.go +++ b/factory/crypto/cryptoComponents_test.go @@ -1,4 +1,4 @@ -package factory_test +package crypto_test import ( "encoding/hex" @@ -9,26 +9,20 @@ import ( "github.com/ElrondNetwork/elrond-go-crypto/signing" "github.com/ElrondNetwork/elrond-go/config" errErd "github.com/ElrondNetwork/elrond-go/errors" - "github.com/ElrondNetwork/elrond-go/factory" + cryptoComp "github.com/ElrondNetwork/elrond-go/factory/crypto" "github.com/ElrondNetwork/elrond-go/factory/mock" - "github.com/ElrondNetwork/elrond-go/testscommon/hashingMocks" + componentsMock "github.com/ElrondNetwork/elrond-go/testscommon/components" "github.com/stretchr/testify/require" ) -const dummyPk = "629e1245577afb7717ccb46b6ff3649bdd6a1311514ad4a7695da13f801cc277ee24e730a7fa8aa6c612159b4328db17" + - "35692d0bded3a2264ba621d6bda47a981d60e17dd306d608e0875a0ba19639fb0844661f519472a175ca9ed2f33fbe16" -const dummySk = "cea01c0bf060187d90394802ff223078e47527dc8aa33a922744fb1d06029c4b" - -type LoadKeysFunc func(string, int) ([]byte, string, error) - func TestNewCryptoComponentsFactory_NiCoreComponentsHandlerShouldErr(t *testing.T) { t.Parallel() if testing.Short() { t.Skip("this is not a short test") } - args := getCryptoArgs(nil) - ccf, err := factory.NewCryptoComponentsFactory(args) + args := componentsMock.GetCryptoArgs(nil) + ccf, err := cryptoComp.NewCryptoComponentsFactory(args) require.Nil(t, ccf) require.Equal(t, errErd.ErrNilCoreComponents, err) } @@ -39,10 +33,10 @@ func TestNewCryptoComponentsFactory_NilPemFileShouldErr(t *testing.T) { t.Skip("this is not a short test") } - coreComponents := getCoreComponents() - args := getCryptoArgs(coreComponents) + coreComponents := componentsMock.GetCoreComponents() + args := componentsMock.GetCryptoArgs(coreComponents) args.ValidatorKeyPemFileName = "" - ccf, err := factory.NewCryptoComponentsFactory(args) + ccf, err := cryptoComp.NewCryptoComponentsFactory(args) require.Nil(t, ccf) require.Equal(t, errErd.ErrNilPath, err) } @@ -53,10 +47,10 @@ func TestCryptoComponentsFactory_CreateCryptoParamsNilKeyLoaderShouldErr(t *test t.Skip("this is not a short test") } - coreComponents := getCoreComponents() - args := getCryptoArgs(coreComponents) + 
coreComponents := componentsMock.GetCoreComponents() + args := componentsMock.GetCryptoArgs(coreComponents) args.KeyLoader = nil - ccf, err := factory.NewCryptoComponentsFactory(args) + ccf, err := cryptoComp.NewCryptoComponentsFactory(args) require.Nil(t, ccf) require.Equal(t, errErd.ErrNilKeyLoader, err) @@ -68,9 +62,9 @@ func TestNewCryptoComponentsFactory_OkValsShouldWork(t *testing.T) { t.Skip("this is not a short test") } - coreComponents := getCoreComponents() - args := getCryptoArgs(coreComponents) - ccf, err := factory.NewCryptoComponentsFactory(args) + coreComponents := componentsMock.GetCoreComponents() + args := componentsMock.GetCryptoArgs(coreComponents) + ccf, err := cryptoComp.NewCryptoComponentsFactory(args) require.NoError(t, err) require.NotNil(t, ccf) } @@ -81,10 +75,10 @@ func TestNewCryptoComponentsFactory_DisabledSigShouldWork(t *testing.T) { t.Skip("this is not a short test") } - coreComponents := getCoreComponents() - args := getCryptoArgs(coreComponents) + coreComponents := componentsMock.GetCoreComponents() + args := componentsMock.GetCryptoArgs(coreComponents) args.ImportModeNoSigCheck = true - ccf, err := factory.NewCryptoComponentsFactory(args) + ccf, err := cryptoComp.NewCryptoComponentsFactory(args) require.NoError(t, err) require.NotNil(t, ccf) } @@ -95,10 +89,10 @@ func TestNewCryptoComponentsFactory_CreateInvalidConsensusTypeShouldErr(t *testi t.Skip("this is not a short test") } - coreComponents := getCoreComponents() - args := getCryptoArgs(coreComponents) + coreComponents := componentsMock.GetCoreComponents() + args := componentsMock.GetCryptoArgs(coreComponents) args.Config.Consensus.Type = "invalid" - ccf, _ := factory.NewCryptoComponentsFactory(args) + ccf, _ := cryptoComp.NewCryptoComponentsFactory(args) cc, err := ccf.Create() require.Nil(t, cc) @@ -111,8 +105,8 @@ func TestCryptoComponentsFactory_CreateShouldErrDueToMissingConfig(t *testing.T) t.Skip("this is not a short test") } - coreComponents := getCoreComponents() - args := getCryptoArgs(coreComponents) + coreComponents := componentsMock.GetCoreComponents() + args := componentsMock.GetCryptoArgs(coreComponents) args.Config = config.Config{ ValidatorPubkeyConverter: config.PubkeyConfig{ Length: 8, @@ -120,7 +114,7 @@ func TestCryptoComponentsFactory_CreateShouldErrDueToMissingConfig(t *testing.T) SignatureLength: 48, }} - ccf, err := factory.NewCryptoComponentsFactory(args) + ccf, err := cryptoComp.NewCryptoComponentsFactory(args) require.Nil(t, err) cc, err := ccf.Create() @@ -134,10 +128,10 @@ func TestCryptoComponentsFactory_CreateInvalidMultiSigHasherShouldErr(t *testing t.Skip("this is not a short test") } - coreComponents := getCoreComponents() - args := getCryptoArgs(coreComponents) + coreComponents := componentsMock.GetCoreComponents() + args := componentsMock.GetCryptoArgs(coreComponents) args.Config.MultisigHasher.Type = "invalid" - ccf, err := factory.NewCryptoComponentsFactory(args) + ccf, err := cryptoComp.NewCryptoComponentsFactory(args) require.Nil(t, err) cspf, err := ccf.Create() @@ -151,9 +145,9 @@ func TestCryptoComponentsFactory_CreateOK(t *testing.T) { t.Skip("this is not a short test") } - coreComponents := getCoreComponents() - args := getCryptoArgs(coreComponents) - ccf, _ := factory.NewCryptoComponentsFactory(args) + coreComponents := componentsMock.GetCoreComponents() + args := componentsMock.GetCryptoArgs(coreComponents) + ccf, _ := cryptoComp.NewCryptoComponentsFactory(args) cc, err := ccf.Create() require.NoError(t, err) @@ -166,10 +160,10 @@ func 
TestCryptoComponentsFactory_CreateWithDisabledSig(t *testing.T) { t.Skip("this is not a short test") } - coreComponents := getCoreComponents() - args := getCryptoArgs(coreComponents) + coreComponents := componentsMock.GetCoreComponents() + args := componentsMock.GetCryptoArgs(coreComponents) args.IsInImportMode = true - ccf, _ := factory.NewCryptoComponentsFactory(args) + ccf, _ := cryptoComp.NewCryptoComponentsFactory(args) cc, err := ccf.Create() require.NoError(t, err) @@ -182,10 +176,10 @@ func TestCryptoComponentsFactory_CreateWithAutoGenerateKey(t *testing.T) { t.Skip("this is not a short test") } - coreComponents := getCoreComponents() - args := getCryptoArgs(coreComponents) + coreComponents := componentsMock.GetCoreComponents() + args := componentsMock.GetCryptoArgs(coreComponents) args.NoKeyProvided = true - ccf, _ := factory.NewCryptoComponentsFactory(args) + ccf, _ := cryptoComp.NewCryptoComponentsFactory(args) cc, err := ccf.Create() require.NoError(t, err) @@ -198,10 +192,10 @@ func TestCryptoComponentsFactory_CreateSingleSignerInvalidConsensusTypeShouldErr t.Skip("this is not a short test") } - coreComponents := getCoreComponents() - args := getCryptoArgs(coreComponents) + coreComponents := componentsMock.GetCoreComponents() + args := componentsMock.GetCryptoArgs(coreComponents) args.Config.Consensus.Type = "invalid" - ccf, err := factory.NewCryptoComponentsFactory(args) + ccf, err := cryptoComp.NewCryptoComponentsFactory(args) require.NotNil(t, ccf) require.Nil(t, err) @@ -216,9 +210,9 @@ func TestCryptoComponentsFactory_CreateSingleSignerOK(t *testing.T) { t.Skip("this is not a short test") } - coreComponents := getCoreComponents() - args := getCryptoArgs(coreComponents) - ccf, err := factory.NewCryptoComponentsFactory(args) + coreComponents := componentsMock.GetCoreComponents() + args := componentsMock.GetCryptoArgs(coreComponents) + ccf, err := cryptoComp.NewCryptoComponentsFactory(args) require.NotNil(t, ccf) require.Nil(t, err) @@ -227,77 +221,20 @@ func TestCryptoComponentsFactory_CreateSingleSignerOK(t *testing.T) { require.NotNil(t, singleSigner) } -func TestCryptoComponentsFactory_GetMultiSigHasherFromConfigInvalidHasherShouldErr(t *testing.T) { - t.Parallel() - if testing.Short() { - t.Skip("this is not a short test") - } - - coreComponents := getCoreComponents() - args := getCryptoArgs(coreComponents) - args.Config.Consensus.Type = "" - args.Config.MultisigHasher.Type = "" - ccf, err := factory.NewCryptoComponentsFactory(args) - require.NotNil(t, ccf) - require.Nil(t, err) - - multiSigHasher, err := ccf.GetMultiSigHasherFromConfig() - require.Nil(t, multiSigHasher) - require.Equal(t, errErd.ErrMissingMultiHasherConfig, err) -} - -func TestCryptoComponentsFactory_GetMultiSigHasherFromConfigMismatchConsensusTypeMultiSigHasher(t *testing.T) { - t.Parallel() - if testing.Short() { - t.Skip("this is not a short test") - } - - coreComponents := getCoreComponents() - args := getCryptoArgs(coreComponents) - args.Config.MultisigHasher.Type = "sha256" - ccf, err := factory.NewCryptoComponentsFactory(args) - require.NotNil(t, ccf) - require.Nil(t, err) - - multiSigHasher, err := ccf.GetMultiSigHasherFromConfig() - require.Nil(t, multiSigHasher) - require.Equal(t, errErd.ErrMultiSigHasherMissmatch, err) -} - -func TestCryptoComponentsFactory_GetMultiSigHasherFromConfigOK(t *testing.T) { - t.Parallel() - if testing.Short() { - t.Skip("this is not a short test") - } - - coreComponents := getCoreComponents() - args := getCryptoArgs(coreComponents) - args.Config.Consensus.Type = 
"bls" - args.Config.MultisigHasher.Type = "blake2b" - ccf, err := factory.NewCryptoComponentsFactory(args) - require.NotNil(t, ccf) - require.Nil(t, err) - - multiSigHasher, err := ccf.GetMultiSigHasherFromConfig() - require.Nil(t, err) - require.NotNil(t, multiSigHasher) -} - func TestCryptoComponentsFactory_CreateMultiSignerInvalidConsensusTypeShouldErr(t *testing.T) { t.Parallel() if testing.Short() { t.Skip("this is not a short test") } - coreComponents := getCoreComponents() - args := getCryptoArgs(coreComponents) + coreComponents := componentsMock.GetCoreComponents() + args := componentsMock.GetCryptoArgs(coreComponents) args.Config.Consensus.Type = "other" - ccf, err := factory.NewCryptoComponentsFactory(args) + ccf, err := cryptoComp.NewCryptoComponentsFactory(args) require.NotNil(t, ccf) require.Nil(t, err) - cp := ccf.CreateDummyCryptoParams() - multiSigner, err := ccf.CreateMultiSigner(&hashingMocks.HasherMock{}, cp, &mock.KeyGenMock{}, false) + multiSigner, err := ccf.CreateMultiSignerContainer(&mock.KeyGenMock{}, false) require.Nil(t, multiSigner) require.Equal(t, errErd.ErrInvalidConsensusConfig, err) } @@ -308,18 +245,16 @@ func TestCryptoComponentsFactory_CreateMultiSignerOK(t *testing.T) { t.Skip("this is not a short test") } - coreComponents := getCoreComponents() - args := getCryptoArgs(coreComponents) - ccf, err := factory.NewCryptoComponentsFactory(args) + coreComponents := componentsMock.GetCoreComponents() + args := componentsMock.GetCryptoArgs(coreComponents) + ccf, err := cryptoComp.NewCryptoComponentsFactory(args) require.NotNil(t, ccf) require.Nil(t, err) suite, _ := ccf.GetSuite() blockSignKeyGen := signing.NewKeyGenerator(suite) - cp, _ := ccf.CreateCryptoParams(blockSignKeyGen) - multisigHasher, _ := ccf.GetMultiSigHasherFromConfig() - multiSigner, err := ccf.CreateMultiSigner(multisigHasher, cp, blockSignKeyGen, false) + multiSigner, err := ccf.CreateMultiSignerContainer(blockSignKeyGen, false) require.Nil(t, err) require.NotNil(t, multiSigner) } @@ -330,10 +265,10 @@ func TestCryptoComponentsFactory_GetSuiteInvalidConsensusTypeShouldErr(t *testin t.Skip("this is not a short test") } - coreComponents := getCoreComponents() - args := getCryptoArgs(coreComponents) + coreComponents := componentsMock.GetCoreComponents() + args := componentsMock.GetCryptoArgs(coreComponents) args.Config.Consensus.Type = "" - ccf, err := factory.NewCryptoComponentsFactory(args) + ccf, err := cryptoComp.NewCryptoComponentsFactory(args) require.NotNil(t, ccf) require.Nil(t, err) @@ -348,10 +283,10 @@ func TestCryptoComponentsFactory_GetSuiteOK(t *testing.T) { t.Skip("this is not a short test") } - coreComponents := getCoreComponents() - args := getCryptoArgs(coreComponents) + coreComponents := componentsMock.GetCoreComponents() + args := componentsMock.GetCryptoArgs(coreComponents) args.Config.Consensus.Type = "bls" - ccf, err := factory.NewCryptoComponentsFactory(args) + ccf, err := cryptoComp.NewCryptoComponentsFactory(args) require.NotNil(t, ccf) require.Nil(t, err) @@ -366,10 +301,10 @@ func TestCryptoComponentsFactory_CreateCryptoParamsInvalidPrivateKeyByteArraySho t.Skip("this is not a short test") } - coreComponents := getCoreComponents() - args := getCryptoArgs(coreComponents) - args.KeyLoader = &mock.KeyLoaderStub{LoadKeyCalled: dummyLoadSkPkFromPemFile([]byte{}, dummyPk, nil)} - ccf, _ := factory.NewCryptoComponentsFactory(args) + coreComponents := componentsMock.GetCoreComponents() + args := componentsMock.GetCryptoArgs(coreComponents) + args.KeyLoader = 
&mock.KeyLoaderStub{LoadKeyCalled: componentsMock.DummyLoadSkPkFromPemFile([]byte{}, componentsMock.DummyPk, nil)} + ccf, _ := cryptoComp.NewCryptoComponentsFactory(args) suite, _ := ccf.GetSuite() blockSignKeyGen := signing.NewKeyGenerator(suite) @@ -387,10 +322,10 @@ func TestCryptoComponentsFactory_CreateCryptoParamsLoadKeysFailShouldErr(t *test expectedError := errors.New("expected error") - coreComponents := getCoreComponents() - args := getCryptoArgs(coreComponents) - args.KeyLoader = &mock.KeyLoaderStub{LoadKeyCalled: dummyLoadSkPkFromPemFile([]byte{}, "", expectedError)} - ccf, _ := factory.NewCryptoComponentsFactory(args) + coreComponents := componentsMock.GetCoreComponents() + args := componentsMock.GetCryptoArgs(coreComponents) + args.KeyLoader = &mock.KeyLoaderStub{LoadKeyCalled: componentsMock.DummyLoadSkPkFromPemFile([]byte{}, "", expectedError)} + ccf, _ := cryptoComp.NewCryptoComponentsFactory(args) suite, _ := ccf.GetSuite() blockSignKeyGen := signing.NewKeyGenerator(suite) @@ -406,9 +341,9 @@ func TestCryptoComponentsFactory_CreateCryptoParamsOK(t *testing.T) { t.Skip("this is not a short test") } - coreComponents := getCoreComponents() - args := getCryptoArgs(coreComponents) - ccf, _ := factory.NewCryptoComponentsFactory(args) + coreComponents := componentsMock.GetCoreComponents() + args := componentsMock.GetCryptoArgs(coreComponents) + ccf, _ := cryptoComp.NewCryptoComponentsFactory(args) suite, _ := ccf.GetSuite() blockSignKeyGen := signing.NewKeyGenerator(suite) @@ -425,11 +360,11 @@ func TestCryptoComponentsFactory_GetSkPkInvalidSkBytesShouldErr(t *testing.T) { } setSk := []byte("zxwY") - setPk := []byte(dummyPk) - coreComponents := getCoreComponents() - args := getCryptoArgs(coreComponents) - args.KeyLoader = &mock.KeyLoaderStub{LoadKeyCalled: dummyLoadSkPkFromPemFile(setSk, string(setPk), nil)} - ccf, _ := factory.NewCryptoComponentsFactory(args) + setPk := []byte(componentsMock.DummyPk) + coreComponents := componentsMock.GetCoreComponents() + args := componentsMock.GetCryptoArgs(coreComponents) + args.KeyLoader = &mock.KeyLoaderStub{LoadKeyCalled: componentsMock.DummyLoadSkPkFromPemFile(setSk, string(setPk), nil)} + ccf, _ := cryptoComp.NewCryptoComponentsFactory(args) sk, pk, err := ccf.GetSkPk() require.NotNil(t, err) @@ -443,13 +378,13 @@ func TestCryptoComponentsFactory_GetSkPkInvalidPkBytesShouldErr(t *testing.T) { t.Skip("this is not a short test") } - setSk := []byte(dummySk) + setSk := []byte(componentsMock.DummySk) setPk := "0" - coreComponents := getCoreComponents() - args := getCryptoArgs(coreComponents) - args.KeyLoader = &mock.KeyLoaderStub{LoadKeyCalled: dummyLoadSkPkFromPemFile(setSk, setPk, nil)} - ccf, _ := factory.NewCryptoComponentsFactory(args) + coreComponents := componentsMock.GetCoreComponents() + args := componentsMock.GetCryptoArgs(coreComponents) + args.KeyLoader = &mock.KeyLoaderStub{LoadKeyCalled: componentsMock.DummyLoadSkPkFromPemFile(setSk, setPk, nil)} + ccf, _ := cryptoComp.NewCryptoComponentsFactory(args) sk, pk, err := ccf.GetSkPk() require.NotNil(t, err) @@ -463,57 +398,16 @@ func TestCryptoComponentsFactory_GetSkPkOK(t *testing.T) { t.Skip("this is not a short test") } - coreComponents := getCoreComponents() - args := getCryptoArgs(coreComponents) - ccf, err := factory.NewCryptoComponentsFactory(args) + coreComponents := componentsMock.GetCoreComponents() + args := componentsMock.GetCryptoArgs(coreComponents) + ccf, err := cryptoComp.NewCryptoComponentsFactory(args) require.Nil(t, err) - expectedSk, _ := 
hex.DecodeString(dummySk) - expectedPk, _ := hex.DecodeString(dummyPk) + expectedSk, _ := hex.DecodeString(componentsMock.DummySk) + expectedPk, _ := hex.DecodeString(componentsMock.DummyPk) sk, pk, err := ccf.GetSkPk() require.Nil(t, err) require.Equal(t, expectedSk, sk) require.Equal(t, expectedPk, pk) } - -func getCryptoArgs(coreComponents factory.CoreComponentsHolder) factory.CryptoComponentsFactoryArgs { - args := factory.CryptoComponentsFactoryArgs{ - Config: config.Config{ - GeneralSettings: config.GeneralSettingsConfig{ChainID: "undefined"}, - Consensus: config.ConsensusConfig{ - Type: "bls", - }, - MultisigHasher: config.TypeConfig{Type: "blake2b"}, - PublicKeyPIDSignature: config.CacheConfig{ - Capacity: 1000, - Type: "LRU", - }, - Hasher: config.TypeConfig{Type: "blake2b"}, - }, - PrefsConfig: config.Preferences{ - Preferences: config.PreferencesConfig{ - DestinationShardAsObserver: "", - NodeDisplayName: "node name", - Identity: "identity", - RedundancyLevel: 1, - }, - NamedIdentity: nil, - }, - SkIndex: 0, - ValidatorKeyPemFileName: "validatorKey.pem", - CoreComponentsHolder: coreComponents, - ActivateBLSPubKeyMessageVerification: false, - KeyLoader: &mock.KeyLoaderStub{ - LoadKeyCalled: dummyLoadSkPkFromPemFile([]byte(dummySk), dummyPk, nil), - }, - } - - return args -} - -func dummyLoadSkPkFromPemFile(sk []byte, pk string, err error) LoadKeysFunc { - return func(_ string, _ int) ([]byte, string, error) { - return sk, pk, err - } -} diff --git a/factory/crypto/export_test.go b/factory/crypto/export_test.go new file mode 100644 index 00000000000..9e006c825f3 --- /dev/null +++ b/factory/crypto/export_test.go @@ -0,0 +1,39 @@ +package crypto + +import ( + crypto "github.com/ElrondNetwork/elrond-go-crypto" + cryptoCommon "github.com/ElrondNetwork/elrond-go/common/crypto" +) + +// GetSkPk - +func (ccf *cryptoComponentsFactory) GetSkPk() ([]byte, []byte, error) { + return ccf.getSkPk() +} + +// CreateSingleSigner - +func (ccf *cryptoComponentsFactory) CreateSingleSigner(importModeNoSigCheck bool) (crypto.SingleSigner, error) { + return ccf.createSingleSigner(importModeNoSigCheck) +} + +// CreateDummyCryptoParams - +func (ccf *cryptoComponentsFactory) CreateDummyCryptoParams() *cryptoParams { + return &cryptoParams{} +} + +// CreateCryptoParams - +func (ccf *cryptoComponentsFactory) CreateCryptoParams(blockSignKeyGen crypto.KeyGenerator) (*cryptoParams, error) { + return ccf.createCryptoParams(blockSignKeyGen) +} + +// CreateMultiSignerContainer - +func (ccf *cryptoComponentsFactory) CreateMultiSignerContainer( + blSignKeyGen crypto.KeyGenerator, + importModeNoSigCheck bool, +) (cryptoCommon.MultiSignerContainer, error) { + return ccf.createMultiSignerContainer(blSignKeyGen, importModeNoSigCheck) +} + +// GetSuite - +func (ccf *cryptoComponentsFactory) GetSuite() (crypto.Suite, error) { + return ccf.getSuite() +} diff --git a/factory/crypto/multiSignerContainer.go b/factory/crypto/multiSignerContainer.go new file mode 100644 index 00000000000..b886bfa0ad4 --- /dev/null +++ b/factory/crypto/multiSignerContainer.go @@ -0,0 +1,156 @@ +package crypto + +import ( + "sort" + "sync" + + "github.com/ElrondNetwork/elrond-go-core/core/check" + "github.com/ElrondNetwork/elrond-go-core/hashing" + "github.com/ElrondNetwork/elrond-go-core/hashing/blake2b" + "github.com/ElrondNetwork/elrond-go-core/hashing/sha256" + crypto "github.com/ElrondNetwork/elrond-go-crypto" + disabledMultiSig "github.com/ElrondNetwork/elrond-go-crypto/signing/disabled/multisig" + mclMultiSig 
"github.com/ElrondNetwork/elrond-go-crypto/signing/mcl/multisig" + "github.com/ElrondNetwork/elrond-go-crypto/signing/multisig" + "github.com/ElrondNetwork/elrond-go/config" + "github.com/ElrondNetwork/elrond-go/consensus" + "github.com/ElrondNetwork/elrond-go/errors" +) + +const ( + blsNoKOSK = "no-KOSK" + blsKOSK = "KOSK" +) + +type epochMultiSigner struct { + epoch uint32 + multiSigner crypto.MultiSigner +} + +type container struct { + multiSigners []*epochMultiSigner + mutSigners sync.RWMutex +} + +// MultiSigArgs holds the arguments for creating the multi-signer container +type MultiSigArgs struct { + MultiSigHasherType string + BlSignKeyGen crypto.KeyGenerator + ConsensusType string + ImportModeNoSigCheck bool +} + +// NewMultiSignerContainer creates the multi-signer container +func NewMultiSignerContainer(args MultiSigArgs, multiSignerConfig []config.MultiSignerConfig) (*container, error) { + if len(multiSignerConfig) == 0 { + return nil, errors.ErrMissingMultiSignerConfig + } + + c := &container{ + multiSigners: make([]*epochMultiSigner, len(multiSignerConfig)), + } + + sortedMultiSignerConfig := sortMultiSignerConfig(multiSignerConfig) + if sortedMultiSignerConfig[0].EnableEpoch != 0 { + return nil, errors.ErrMissingEpochZeroMultiSignerConfig + } + + for i, mConfig := range sortedMultiSignerConfig { + multiSigner, err := createMultiSigner(mConfig.Type, args) + if err != nil { + return nil, err + } + + c.multiSigners[i] = &epochMultiSigner{ + multiSigner: multiSigner, + epoch: mConfig.EnableEpoch, + } + } + + return c, nil +} + +// GetMultiSigner returns the multiSigner configured for the given epoch +func (c *container) GetMultiSigner(epoch uint32) (crypto.MultiSigner, error) { + c.mutSigners.RLock() + defer c.mutSigners.RUnlock() + + for i := len(c.multiSigners) - 1; i >= 0; i-- { + if epoch >= c.multiSigners[i].epoch { + return c.multiSigners[i].multiSigner, nil + } + } + return nil, errors.ErrMissingMultiSigner +} + +// IsInterfaceNil returns true if the underlying object is nil +func (c *container) IsInterfaceNil() bool { + return c == nil +} + +func createMultiSigner(multiSigType string, args MultiSigArgs) (crypto.MultiSigner, error) { + if args.ImportModeNoSigCheck { + log.Warn("using disabled multi signer because the node is running in import-db 'turbo mode'") + return &disabledMultiSig.DisabledMultiSig{}, nil + } + + switch args.ConsensusType { + case consensus.BlsConsensusType: + hasher, err := getMultiSigHasherFromConfig(args) + if err != nil { + return nil, err + } + blsSigner, err := createLowLevelSigner(multiSigType, hasher) + if err != nil { + return nil, err + } + return multisig.NewBLSMultisig(blsSigner, args.BlSignKeyGen) + case disabledSigChecking: + log.Warn("using disabled multi signer") + return &disabledMultiSig.DisabledMultiSig{}, nil + default: + return nil, errors.ErrInvalidConsensusConfig + } +} + +func createLowLevelSigner(multiSigType string, hasher hashing.Hasher) (crypto.LowLevelSignerBLS, error) { + if check.IfNil(hasher) { + return nil, errors.ErrNilHasher + } + + switch multiSigType { + case blsNoKOSK: + return &mclMultiSig.BlsMultiSigner{Hasher: hasher}, nil + case blsKOSK: + return &mclMultiSig.BlsMultiSignerKOSK{}, nil + default: + return nil, errors.ErrSignerNotSupported + } +} + +func getMultiSigHasherFromConfig(args MultiSigArgs) (hashing.Hasher, error) { + if args.ConsensusType == consensus.BlsConsensusType && args.MultiSigHasherType != "blake2b" { + return nil, errors.ErrMultiSigHasherMissmatch + } + + switch
args.MultiSigHasherType { + case "sha256": + return sha256.NewSha256(), nil + case "blake2b": + if args.ConsensusType == consensus.BlsConsensusType { + return blake2b.NewBlake2bWithSize(mclMultiSig.HasherOutputSize) + } + return blake2b.NewBlake2b(), nil + } + + return nil, errors.ErrMissingMultiHasherConfig +} + +func sortMultiSignerConfig(multiSignerConfig []config.MultiSignerConfig) []config.MultiSignerConfig { + sortedMultiSignerConfig := append([]config.MultiSignerConfig{}, multiSignerConfig...) + sort.Slice(sortedMultiSignerConfig, func(i, j int) bool { + return sortedMultiSignerConfig[i].EnableEpoch < sortedMultiSignerConfig[j].EnableEpoch + }) + + return sortedMultiSignerConfig +} diff --git a/factory/crypto/multiSignerContainer_test.go b/factory/crypto/multiSignerContainer_test.go new file mode 100644 index 00000000000..1d3b52d0941 --- /dev/null +++ b/factory/crypto/multiSignerContainer_test.go @@ -0,0 +1,334 @@ +package crypto + +import ( + "math/rand" + "testing" + + "github.com/ElrondNetwork/elrond-go-core/core/check" + disabledMultiSig "github.com/ElrondNetwork/elrond-go-crypto/signing/disabled/multisig" + mclMultiSig "github.com/ElrondNetwork/elrond-go-crypto/signing/mcl/multisig" + "github.com/ElrondNetwork/elrond-go/config" + "github.com/ElrondNetwork/elrond-go/consensus" + "github.com/ElrondNetwork/elrond-go/errors" + "github.com/ElrondNetwork/elrond-go/testscommon/cryptoMocks" + "github.com/ElrondNetwork/elrond-go/testscommon/hashingMocks" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_NewMultiSignerContainer(t *testing.T) { + t.Parallel() + + args := createDefaultMultiSignerArgs() + multiSigConfig := createDefaultMultiSignerConfig() + + t.Run("nil multiSigner config should err", func(t *testing.T) { + multiSigContainer, err := NewMultiSignerContainer(args, nil) + + require.Nil(t, multiSigContainer) + require.Equal(t, errors.ErrMissingMultiSignerConfig, err) + }) + t.Run("missing epoch 0 config should err", func(t *testing.T) { + multiSigConfigClone := append([]config.MultiSignerConfig{}, multiSigConfig...) + multiSigConfigClone[0].EnableEpoch = 1 + multiSigContainer, err := NewMultiSignerContainer(args, multiSigConfigClone) + + require.Nil(t, multiSigContainer) + require.Equal(t, errors.ErrMissingEpochZeroMultiSignerConfig, err) + }) + t.Run("invalid multiSigner type should err", func(t *testing.T) { + multiSigConfigClone := append([]config.MultiSignerConfig{}, multiSigConfig...) 
+ multiSigConfigClone[1].Type = "invalid type" + multiSigContainer, err := NewMultiSignerContainer(args, multiSigConfigClone) + + require.Nil(t, multiSigContainer) + require.Equal(t, errors.ErrSignerNotSupported, err) + }) + t.Run("valid params", func(t *testing.T) { + multiSigContainer, err := NewMultiSignerContainer(args, multiSigConfig) + + require.Nil(t, err) + require.NotNil(t, multiSigContainer) + }) +} + +func TestContainer_GetMultiSigner(t *testing.T) { + t.Parallel() + + args := createDefaultMultiSignerArgs() + multiSigConfig := createDefaultMultiSignerConfig() + + t.Run("missing epoch config should err (can only happen if epoch 0 is missing)", func(t *testing.T) { + multiSigContainer, _ := NewMultiSignerContainer(args, multiSigConfig) + multiSigContainer.multiSigners[0].epoch = 1 + + multiSigner, err := multiSigContainer.GetMultiSigner(0) + require.Nil(t, multiSigner) + require.Equal(t, errors.ErrMissingMultiSigner, err) + }) + t.Run("get multi signer OK", func(t *testing.T) { + multiSigContainer, _ := NewMultiSignerContainer(args, multiSigConfig) + + for i := uint32(0); i < 10; i++ { + multiSigner, err := multiSigContainer.GetMultiSigner(i) + require.Nil(t, err) + require.Equal(t, multiSigContainer.multiSigners[0].multiSigner, multiSigner) + } + for i := uint32(10); i < 30; i++ { + multiSigner, err := multiSigContainer.GetMultiSigner(i) + require.Nil(t, err) + require.Equal(t, multiSigContainer.multiSigners[1].multiSigner, multiSigner) + } + }) +} + +func TestContainer_IsInterfaceNil(t *testing.T) { + t.Parallel() + + var msc *container + assert.True(t, check.IfNil(msc)) + + args := createDefaultMultiSignerArgs() + multiSigConfig := createDefaultMultiSignerConfig() + + msc, _ = NewMultiSignerContainer(args, multiSigConfig) + assert.False(t, check.IfNil(msc)) +} + +func TestContainer_createMultiSigner(t *testing.T) { + t.Parallel() + + t.Run("create disabled multi signer", func(t *testing.T) { + args := createDefaultMultiSignerArgs() + args.ImportModeNoSigCheck = true + multiSigType := "KOSK" + multiSigner, err := createMultiSigner(multiSigType, args) + require.Nil(t, err) + _, ok := multiSigner.(*disabledMultiSig.DisabledMultiSig) + require.True(t, ok) + }) + t.Run("invalid consensus config", func(t *testing.T) { + args := createDefaultMultiSignerArgs() + args.ConsensusType = "invalid" + multiSigType := "KOSK" + multiSigner, err := createMultiSigner(multiSigType, args) + require.Nil(t, multiSigner) + require.Equal(t, errors.ErrInvalidConsensusConfig, err) + }) + t.Run("bls consensus type invalid hasher config", func(t *testing.T) { + args := createDefaultMultiSignerArgs() + args.ConsensusType = consensus.BlsConsensusType + args.MultiSigHasherType = "sha256" + multiSigType := "KOSK" + multiSigner, err := createMultiSigner(multiSigType, args) + require.Nil(t, multiSigner) + require.Equal(t, errors.ErrMultiSigHasherMissmatch, err) + }) + t.Run("bls consensus type signer not supported", func(t *testing.T) { + args := createDefaultMultiSignerArgs() + args.ConsensusType = consensus.BlsConsensusType + args.MultiSigHasherType = "blake2b" + multiSigType := "not supported" + multiSigner, err := createMultiSigner(multiSigType, args) + require.Nil(t, multiSigner) + require.Equal(t, errors.ErrSignerNotSupported, err) + }) + t.Run("bls consensus type KOSK OK", func(t *testing.T) { + args := createDefaultMultiSignerArgs() + args.ConsensusType = consensus.BlsConsensusType + args.MultiSigHasherType = "blake2b" + multiSigType := blsKOSK + multiSigner, err := createMultiSigner(multiSigType, args) 
+ require.Nil(t, err) + require.NotNil(t, multiSigner) + }) + t.Run("bls consensus type no-KOSK OK", func(t *testing.T) { + args := createDefaultMultiSignerArgs() + args.ConsensusType = consensus.BlsConsensusType + args.MultiSigHasherType = "blake2b" + multiSigType := blsNoKOSK + multiSigner, err := createMultiSigner(multiSigType, args) + require.Nil(t, err) + require.NotNil(t, multiSigner) + }) + t.Run("disabledSigChecking", func(t *testing.T) { + args := createDefaultMultiSignerArgs() + args.ConsensusType = disabledSigChecking + multiSigType := blsNoKOSK + multiSigner, err := createMultiSigner(multiSigType, args) + require.Nil(t, err) + require.NotNil(t, multiSigner) + + _, ok := multiSigner.(*disabledMultiSig.DisabledMultiSig) + require.True(t, ok) + }) +} + +func TestContainer_createLowLevelSigner(t *testing.T) { + t.Parallel() + + hasher := &hashingMocks.HasherMock{} + t.Run("nil hasher should err", func(t *testing.T) { + llSig, err := createLowLevelSigner(blsKOSK, nil) + require.Nil(t, llSig) + require.Equal(t, errors.ErrNilHasher, err) + }) + t.Run("not supported multiSig type should err", func(t *testing.T) { + llSig, err := createLowLevelSigner("not supported", hasher) + require.Nil(t, llSig) + require.Equal(t, errors.ErrSignerNotSupported, err) + }) + t.Run("multiSig of type no KOSK", func(t *testing.T) { + llSig, err := createLowLevelSigner(blsNoKOSK, hasher) + require.Nil(t, err) + _, ok := llSig.(*mclMultiSig.BlsMultiSigner) + require.True(t, ok) + }) + t.Run("multiSig of type KOSK", func(t *testing.T) { + llSig, err := createLowLevelSigner(blsKOSK, hasher) + require.Nil(t, err) + _, ok := llSig.(*mclMultiSig.BlsMultiSignerKOSK) + require.True(t, ok) + }) +} + +func TestContainer_getMultiSigHasherFromConfig(t *testing.T) { + t.Parallel() + + t.Run("mismatch config consensus type and hasher type", func(t *testing.T) { + args := createDefaultMultiSignerArgs() + args.ConsensusType = consensus.BlsConsensusType + args.MultiSigHasherType = "sha256" + hasher, err := getMultiSigHasherFromConfig(args) + require.Nil(t, hasher) + require.Equal(t, errors.ErrMultiSigHasherMissmatch, err) + }) + t.Run("sha256 config", func(t *testing.T) { + args := createDefaultMultiSignerArgs() + args.ConsensusType = "dummy config" + args.MultiSigHasherType = "sha256" + hasher, err := getMultiSigHasherFromConfig(args) + require.Nil(t, err) + require.NotNil(t, hasher) + }) + t.Run("invalid hasher config", func(t *testing.T) { + args := createDefaultMultiSignerArgs() + args.ConsensusType = "dummy config" + args.MultiSigHasherType = "unknown" + hasher, err := getMultiSigHasherFromConfig(args) + require.Nil(t, hasher) + require.Equal(t, errors.ErrMissingMultiHasherConfig, err) + }) + t.Run("blake2b config and bls consensus", func(t *testing.T) { + args := createDefaultMultiSignerArgs() + args.ConsensusType = consensus.BlsConsensusType + args.MultiSigHasherType = "blake2b" + hasher, err := getMultiSigHasherFromConfig(args) + require.Nil(t, err) + require.NotNil(t, hasher) + }) + t.Run("blake2b config and non-bls consensus", func(t *testing.T) { + args := createDefaultMultiSignerArgs() + args.ConsensusType = "dummy config" + args.MultiSigHasherType = "blake2b" + hasher, err := getMultiSigHasherFromConfig(args) + require.Nil(t, err) + require.NotNil(t, hasher) + }) +} + +func TestContainer_sortMultiSignerConfig(t *testing.T) { + multiSignersOrderedConfig := []config.MultiSignerConfig{ + { + EnableEpoch: 2, + Type: "KOSK", + }, + { + EnableEpoch: 10, + Type: "no-KOSK", + }, + { + EnableEpoch: 100, + Type: "BN", + 
}, + { + EnableEpoch: 200, + Type: "DUMMY", + }, + } + + for i := 0; i < 20; i++ { + shuffledConfig := append([]config.MultiSignerConfig{}, multiSignersOrderedConfig...) + rand.Shuffle(len(shuffledConfig), func(i, j int) { + shuffledConfig[i], shuffledConfig[j] = shuffledConfig[j], shuffledConfig[i] + }) + sortedConfig := sortMultiSignerConfig(shuffledConfig) + require.Equal(t, multiSignersOrderedConfig, sortedConfig) + } +} + +func Test_getMultiSigHasherFromConfigInvalidHasherShouldErr(t *testing.T) { + t.Parallel() + if testing.Short() { + t.Skip("this is not a short test") + } + + args := createDefaultMultiSignerArgs() + args.ConsensusType = "" + args.MultiSigHasherType = "" + + multiSigHasher, err := getMultiSigHasherFromConfig(args) + require.Nil(t, multiSigHasher) + require.Equal(t, errors.ErrMissingMultiHasherConfig, err) +} + +func Test_getMultiSigHasherFromConfigMismatchConsensusTypeMultiSigHasher(t *testing.T) { + t.Parallel() + if testing.Short() { + t.Skip("this is not a short test") + } + + args := createDefaultMultiSignerArgs() + args.MultiSigHasherType = "sha256" + + multiSigHasher, err := getMultiSigHasherFromConfig(args) + require.Nil(t, multiSigHasher) + require.Equal(t, errors.ErrMultiSigHasherMissmatch, err) +} + +func Test_getMultiSigHasherFromConfigOK(t *testing.T) { + t.Parallel() + if testing.Short() { + t.Skip("this is not a short test") + } + + args := createDefaultMultiSignerArgs() + args.ConsensusType = "bls" + args.MultiSigHasherType = "blake2b" + + multiSigHasher, err := getMultiSigHasherFromConfig(args) + require.Nil(t, err) + require.NotNil(t, multiSigHasher) +} + +func createDefaultMultiSignerArgs() MultiSigArgs { + return MultiSigArgs{ + MultiSigHasherType: "blake2b", + BlSignKeyGen: &cryptoMocks.KeyGenStub{}, + ConsensusType: "bls", + ImportModeNoSigCheck: false, + } +} + +func createDefaultMultiSignerConfig() []config.MultiSignerConfig { + return []config.MultiSignerConfig{ + { + EnableEpoch: 0, + Type: "no-KOSK", + }, + { + EnableEpoch: 10, + Type: "KOSK", + }, + } +} diff --git a/factory/dataComponents.go b/factory/data/dataComponents.go similarity index 90% rename from factory/dataComponents.go rename to factory/data/dataComponents.go index ff1df68fb5c..b2056f10da6 100644 --- a/factory/dataComponents.go +++ b/factory/data/dataComponents.go @@ -1,4 +1,4 @@ -package factory +package data import ( "fmt" @@ -6,14 +6,16 @@ import ( "github.com/ElrondNetwork/elrond-go-core/core" "github.com/ElrondNetwork/elrond-go-core/core/check" "github.com/ElrondNetwork/elrond-go-core/data" + logger "github.com/ElrondNetwork/elrond-go-logger" "github.com/ElrondNetwork/elrond-go/config" "github.com/ElrondNetwork/elrond-go/dataRetriever" "github.com/ElrondNetwork/elrond-go/dataRetriever/blockchain" dataRetrieverFactory "github.com/ElrondNetwork/elrond-go/dataRetriever/factory" "github.com/ElrondNetwork/elrond-go/dataRetriever/provider" "github.com/ElrondNetwork/elrond-go/errors" + "github.com/ElrondNetwork/elrond-go/factory" "github.com/ElrondNetwork/elrond-go/sharding" - "github.com/ElrondNetwork/elrond-go/storage/factory" + storageFactory "github.com/ElrondNetwork/elrond-go/storage/factory" ) // DataComponentsFactoryArgs holds the arguments needed for creating a data components factory @@ -21,8 +23,8 @@ type DataComponentsFactoryArgs struct { Config config.Config PrefsConfig config.PreferencesConfig ShardCoordinator sharding.Coordinator - Core CoreComponentsHolder - EpochStartNotifier EpochStartNotifier + Core factory.CoreComponentsHolder + EpochStartNotifier 
factory.EpochStartNotifier CurrentEpoch uint32 CreateTrieEpochRootHashStorer bool } @@ -31,8 +33,8 @@ type dataComponentsFactory struct { config config.Config prefsConfig config.PreferencesConfig shardCoordinator sharding.Coordinator - core CoreComponentsHolder - epochStartNotifier EpochStartNotifier + core factory.CoreComponentsHolder + epochStartNotifier factory.EpochStartNotifier currentEpoch uint32 createTrieEpochRootHashStorer bool } @@ -42,9 +44,11 @@ type dataComponents struct { blkc data.ChainHandler store dataRetriever.StorageService datapool dataRetriever.PoolsHolder - miniBlocksProvider MiniBlockProvider + miniBlocksProvider factory.MiniBlockProvider } +var log = logger.GetOrCreate("factory") + // NewDataComponentsFactory will return a new instance of dataComponentsFactory func NewDataComponentsFactory(args DataComponentsFactoryArgs) (*dataComponentsFactory, error) { if check.IfNil(args.ShardCoordinator) { @@ -148,7 +152,7 @@ func (dcf *dataComponentsFactory) createBlockChainFromConfig() (data.ChainHandle } func (dcf *dataComponentsFactory) createDataStoreFromConfig() (dataRetriever.StorageService, error) { - storageServiceFactory, err := factory.NewStorageServiceFactory( + storageServiceFactory, err := storageFactory.NewStorageServiceFactory( &dcf.config, &dcf.prefsConfig, dcf.shardCoordinator, @@ -157,7 +161,7 @@ func (dcf *dataComponentsFactory) createDataStoreFromConfig() (dataRetriever.Sto dcf.core.NodeTypeProvider(), dcf.currentEpoch, dcf.createTrieEpochRootHashStorer, - factory.ProcessStorageService, + storageFactory.ProcessStorageService, ) if err != nil { return nil, err diff --git a/factory/dataComponentsHandler.go b/factory/data/dataComponentsHandler.go similarity index 91% rename from factory/dataComponentsHandler.go rename to factory/data/dataComponentsHandler.go index 7bc4acf0b00..189eaea1b75 100644 --- a/factory/dataComponentsHandler.go +++ b/factory/data/dataComponentsHandler.go @@ -1,4 +1,4 @@ -package factory +package data import ( "fmt" @@ -8,11 +8,12 @@ import ( "github.com/ElrondNetwork/elrond-go-core/data" "github.com/ElrondNetwork/elrond-go/dataRetriever" "github.com/ElrondNetwork/elrond-go/errors" + "github.com/ElrondNetwork/elrond-go/factory" ) -var _ ComponentHandler = (*managedDataComponents)(nil) -var _ DataComponentsHolder = (*managedDataComponents)(nil) -var _ DataComponentsHandler = (*managedDataComponents)(nil) +var _ factory.ComponentHandler = (*managedDataComponents)(nil) +var _ factory.DataComponentsHolder = (*managedDataComponents)(nil) +var _ factory.DataComponentsHandler = (*managedDataComponents)(nil) // managedDataComponents creates the data components handler that can create, close and access the data components type managedDataComponents struct { @@ -133,7 +134,7 @@ func (mdc *managedDataComponents) Datapool() dataRetriever.PoolsHolder { } // MiniBlocksProvider returns the MiniBlockProvider -func (mdc *managedDataComponents) MiniBlocksProvider() MiniBlockProvider { +func (mdc *managedDataComponents) MiniBlocksProvider() factory.MiniBlockProvider { mdc.mutDataComponents.RLock() defer mdc.mutDataComponents.RUnlock() @@ -170,5 +171,5 @@ func (mdc *managedDataComponents) IsInterfaceNil() bool { // String returns the name of the component func (mdc *managedDataComponents) String() string { - return dataComponentsName + return factory.DataComponentsName } diff --git a/factory/dataComponentsHandler_test.go b/factory/data/dataComponentsHandler_test.go similarity index 63% rename from factory/dataComponentsHandler_test.go rename to 
factory/data/dataComponentsHandler_test.go index 9c4bfa7f2de..51e8b522426 100644 --- a/factory/dataComponentsHandler_test.go +++ b/factory/data/dataComponentsHandler_test.go @@ -1,11 +1,12 @@ -package factory_test +package data_test import ( "testing" "github.com/ElrondNetwork/elrond-go/config" - "github.com/ElrondNetwork/elrond-go/factory" + dataComp "github.com/ElrondNetwork/elrond-go/factory/data" "github.com/ElrondNetwork/elrond-go/factory/mock" + componentsMock "github.com/ElrondNetwork/elrond-go/testscommon/components" "github.com/stretchr/testify/require" ) @@ -16,12 +17,12 @@ func TestManagedDataComponents_CreateWithInvalidArgsShouldErr(t *testing.T) { t.Skip("this is not a short test") } - coreComponents := getCoreComponents() + coreComponents := componentsMock.GetCoreComponents() shardCoordinator := mock.NewMultiShardsCoordinatorMock(2) - args := getDataArgs(coreComponents, shardCoordinator) + args := componentsMock.GetDataArgs(coreComponents, shardCoordinator) args.Config.ShardHdrNonceHashStorage = config.StorageConfig{} - dataComponentsFactory, _ := factory.NewDataComponentsFactory(args) - managedDataComponents, err := factory.NewManagedDataComponents(dataComponentsFactory) + dataComponentsFactory, _ := dataComp.NewDataComponentsFactory(args) + managedDataComponents, err := dataComp.NewManagedDataComponents(dataComponentsFactory) require.NoError(t, err) err = managedDataComponents.Create() require.Error(t, err) @@ -34,11 +35,11 @@ func TestManagedDataComponents_CreateShouldWork(t *testing.T) { t.Skip("this is not a short test") } - coreComponents := getCoreComponents() + coreComponents := componentsMock.GetCoreComponents() shardCoordinator := mock.NewMultiShardsCoordinatorMock(2) - args := getDataArgs(coreComponents, shardCoordinator) - dataComponentsFactory, _ := factory.NewDataComponentsFactory(args) - managedDataComponents, err := factory.NewManagedDataComponents(dataComponentsFactory) + args := componentsMock.GetDataArgs(coreComponents, shardCoordinator) + dataComponentsFactory, _ := dataComp.NewDataComponentsFactory(args) + managedDataComponents, err := dataComp.NewManagedDataComponents(dataComponentsFactory) require.NoError(t, err) require.Nil(t, managedDataComponents.Blockchain()) require.Nil(t, managedDataComponents.StorageService()) @@ -57,11 +58,11 @@ func TestManagedDataComponents_Close(t *testing.T) { t.Skip("this is not a short test") } - coreComponents := getCoreComponents() + coreComponents := componentsMock.GetCoreComponents() shardCoordinator := mock.NewMultiShardsCoordinatorMock(2) - args := getDataArgs(coreComponents, shardCoordinator) - dataComponentsFactory, _ := factory.NewDataComponentsFactory(args) - managedDataComponents, _ := factory.NewManagedDataComponents(dataComponentsFactory) + args := componentsMock.GetDataArgs(coreComponents, shardCoordinator) + dataComponentsFactory, _ := dataComp.NewDataComponentsFactory(args) + managedDataComponents, _ := dataComp.NewManagedDataComponents(dataComponentsFactory) err := managedDataComponents.Create() require.NoError(t, err) @@ -76,11 +77,11 @@ func TestManagedDataComponents_Clone(t *testing.T) { t.Skip("this is not a short test") } - coreComponents := getCoreComponents() + coreComponents := componentsMock.GetCoreComponents() shardCoordinator := mock.NewMultiShardsCoordinatorMock(2) - args := getDataArgs(coreComponents, shardCoordinator) - dataComponentsFactory, _ := factory.NewDataComponentsFactory(args) - managedDataComponents, _ := factory.NewManagedDataComponents(dataComponentsFactory) + args := 
componentsMock.GetDataArgs(coreComponents, shardCoordinator) + dataComponentsFactory, _ := dataComp.NewDataComponentsFactory(args) + managedDataComponents, _ := dataComp.NewManagedDataComponents(dataComponentsFactory) clonedBeforeCreate := managedDataComponents.Clone() require.Equal(t, managedDataComponents, clonedBeforeCreate) diff --git a/factory/dataComponents_test.go b/factory/data/dataComponents_test.go similarity index 60% rename from factory/dataComponents_test.go rename to factory/data/dataComponents_test.go index 5717f55402e..8c9440339b2 100644 --- a/factory/dataComponents_test.go +++ b/factory/data/dataComponents_test.go @@ -1,4 +1,4 @@ -package factory_test +package data_test import ( "testing" @@ -6,10 +6,9 @@ import ( "github.com/ElrondNetwork/elrond-go-core/core" "github.com/ElrondNetwork/elrond-go/config" "github.com/ElrondNetwork/elrond-go/errors" - "github.com/ElrondNetwork/elrond-go/factory" + dataComp "github.com/ElrondNetwork/elrond-go/factory/data" "github.com/ElrondNetwork/elrond-go/factory/mock" - "github.com/ElrondNetwork/elrond-go/sharding" - "github.com/ElrondNetwork/elrond-go/testscommon" + componentsMock "github.com/ElrondNetwork/elrond-go/testscommon/components" "github.com/stretchr/testify/require" ) @@ -20,11 +19,11 @@ func TestNewDataComponentsFactory_NilShardCoordinatorShouldErr(t *testing.T) { } shardCoordinator := mock.NewMultiShardsCoordinatorMock(2) - coreComponents := getCoreComponents() - args := getDataArgs(coreComponents, shardCoordinator) + coreComponents := componentsMock.GetCoreComponents() + args := componentsMock.GetDataArgs(coreComponents, shardCoordinator) args.ShardCoordinator = nil - dcf, err := factory.NewDataComponentsFactory(args) + dcf, err := dataComp.NewDataComponentsFactory(args) require.Nil(t, dcf) require.Equal(t, errors.ErrNilShardCoordinator, err) } @@ -36,10 +35,10 @@ func TestNewDataComponentsFactory_NilCoreComponentsShouldErr(t *testing.T) { } shardCoordinator := mock.NewMultiShardsCoordinatorMock(2) - args := getDataArgs(nil, shardCoordinator) + args := componentsMock.GetDataArgs(nil, shardCoordinator) args.Core = nil - dcf, err := factory.NewDataComponentsFactory(args) + dcf, err := dataComp.NewDataComponentsFactory(args) require.Nil(t, dcf) require.Equal(t, errors.ErrNilCoreComponents, err) } @@ -51,11 +50,11 @@ func TestNewDataComponentsFactory_NilEpochStartNotifierShouldErr(t *testing.T) { } shardCoordinator := mock.NewMultiShardsCoordinatorMock(2) - coreComponents := getCoreComponents() - args := getDataArgs(coreComponents, shardCoordinator) + coreComponents := componentsMock.GetCoreComponents() + args := componentsMock.GetDataArgs(coreComponents, shardCoordinator) args.EpochStartNotifier = nil - dcf, err := factory.NewDataComponentsFactory(args) + dcf, err := dataComp.NewDataComponentsFactory(args) require.Nil(t, dcf) require.Equal(t, errors.ErrNilEpochStartNotifier, err) } @@ -67,9 +66,9 @@ func TestNewDataComponentsFactory_OkValsShouldWork(t *testing.T) { } shardCoordinator := mock.NewMultiShardsCoordinatorMock(2) - coreComponents := getCoreComponents() - args := getDataArgs(coreComponents, shardCoordinator) - dcf, err := factory.NewDataComponentsFactory(args) + coreComponents := componentsMock.GetCoreComponents() + args := componentsMock.GetDataArgs(coreComponents, shardCoordinator) + dcf, err := dataComp.NewDataComponentsFactory(args) require.NoError(t, err) require.NotNil(t, dcf) } @@ -81,10 +80,10 @@ func TestDataComponentsFactory_CreateShouldErrDueBadConfig(t *testing.T) { } shardCoordinator := 
mock.NewMultiShardsCoordinatorMock(2) - coreComponents := getCoreComponents() - args := getDataArgs(coreComponents, shardCoordinator) + coreComponents := componentsMock.GetCoreComponents() + args := componentsMock.GetDataArgs(coreComponents, shardCoordinator) args.Config.ShardHdrNonceHashStorage = config.StorageConfig{} - dcf, err := factory.NewDataComponentsFactory(args) + dcf, err := dataComp.NewDataComponentsFactory(args) require.NoError(t, err) dc, err := dcf.Create() @@ -98,10 +97,10 @@ func TestDataComponentsFactory_CreateForShardShouldWork(t *testing.T) { t.Skip("this is not a short test") } - coreComponents := getCoreComponents() + coreComponents := componentsMock.GetCoreComponents() shardCoordinator := mock.NewMultiShardsCoordinatorMock(2) - args := getDataArgs(coreComponents, shardCoordinator) - dcf, err := factory.NewDataComponentsFactory(args) + args := componentsMock.GetDataArgs(coreComponents, shardCoordinator) + dcf, err := dataComp.NewDataComponentsFactory(args) require.NoError(t, err) dc, err := dcf.Create() @@ -115,12 +114,12 @@ func TestDataComponentsFactory_CreateForMetaShouldWork(t *testing.T) { t.Skip("this is not a short test") } - coreComponents := getCoreComponents() + coreComponents := componentsMock.GetCoreComponents() shardCoordinator := mock.NewMultiShardsCoordinatorMock(2) shardCoordinator.CurrentShard = core.MetachainShardId - args := getDataArgs(coreComponents, shardCoordinator) + args := componentsMock.GetDataArgs(coreComponents, shardCoordinator) - dcf, err := factory.NewDataComponentsFactory(args) + dcf, err := dataComp.NewDataComponentsFactory(args) require.NoError(t, err) dc, err := dcf.Create() require.NoError(t, err) @@ -134,27 +133,13 @@ func TestManagedDataComponents_CloseShouldWork(t *testing.T) { t.Skip("this is not a short test") } - coreComponents := getCoreComponents() + coreComponents := componentsMock.GetCoreComponents() shardCoordinator := mock.NewMultiShardsCoordinatorMock(2) - args := getDataArgs(coreComponents, shardCoordinator) - dcf, _ := factory.NewDataComponentsFactory(args) + args := componentsMock.GetDataArgs(coreComponents, shardCoordinator) + dcf, _ := dataComp.NewDataComponentsFactory(args) dc, _ := dcf.Create() err := dc.Close() require.NoError(t, err) } - -func getDataArgs(coreComponents factory.CoreComponentsHolder, shardCoordinator sharding.Coordinator) factory.DataComponentsFactoryArgs { - return factory.DataComponentsFactoryArgs{ - Config: testscommon.GetGeneralConfig(), - PrefsConfig: config.PreferencesConfig{ - FullArchive: false, - }, - ShardCoordinator: shardCoordinator, - Core: coreComponents, - EpochStartNotifier: &mock.EpochStartNotifierStub{}, - CurrentEpoch: 0, - CreateTrieEpochRootHashStorer: false, - } -} diff --git a/factory/heartbeatComponents.go b/factory/heartbeat/heartbeatComponents.go similarity index 92% rename from factory/heartbeatComponents.go rename to factory/heartbeat/heartbeatComponents.go index f08078c3128..63935321ad9 100644 --- a/factory/heartbeatComponents.go +++ b/factory/heartbeat/heartbeatComponents.go @@ -1,4 +1,4 @@ -package factory +package heartbeat import ( "context" @@ -9,10 +9,12 @@ import ( "github.com/ElrondNetwork/elrond-go-core/core" "github.com/ElrondNetwork/elrond-go-core/core/check" "github.com/ElrondNetwork/elrond-go-core/marshal" + logger "github.com/ElrondNetwork/elrond-go-logger" "github.com/ElrondNetwork/elrond-go/common" "github.com/ElrondNetwork/elrond-go/config" "github.com/ElrondNetwork/elrond-go/dataRetriever" "github.com/ElrondNetwork/elrond-go/errors" + 
"github.com/ElrondNetwork/elrond-go/factory" "github.com/ElrondNetwork/elrond-go/heartbeat" heartbeatProcess "github.com/ElrondNetwork/elrond-go/heartbeat/process" heartbeatStorage "github.com/ElrondNetwork/elrond-go/heartbeat/storage" @@ -27,11 +29,11 @@ type HeartbeatComponentsFactoryArgs struct { AppVersion string GenesisTime time.Time RedundancyHandler heartbeat.NodeRedundancyHandler - CoreComponents CoreComponentsHolder - DataComponents DataComponentsHolder - NetworkComponents NetworkComponentsHolder - CryptoComponents CryptoComponentsHolder - ProcessComponents ProcessComponentsHolder + CoreComponents factory.CoreComponentsHolder + DataComponents factory.DataComponentsHolder + NetworkComponents factory.NetworkComponentsHolder + CryptoComponents factory.CryptoComponentsHolder + ProcessComponents factory.ProcessComponentsHolder } type heartbeatComponentsFactory struct { @@ -40,21 +42,23 @@ type heartbeatComponentsFactory struct { version string GenesisTime time.Time redundancyHandler heartbeat.NodeRedundancyHandler - coreComponents CoreComponentsHolder - dataComponents DataComponentsHolder - networkComponents NetworkComponentsHolder - cryptoComponents CryptoComponentsHolder - processComponents ProcessComponentsHolder + coreComponents factory.CoreComponentsHolder + dataComponents factory.DataComponentsHolder + networkComponents factory.NetworkComponentsHolder + cryptoComponents factory.CryptoComponentsHolder + processComponents factory.ProcessComponentsHolder } type heartbeatComponents struct { messageHandler heartbeat.MessageHandler - monitor HeartbeatMonitor - sender HeartbeatSender - storer HeartbeatStorer + monitor factory.HeartbeatMonitor + sender factory.HeartbeatSender + storer factory.HeartbeatStorer cancelFunc context.CancelFunc } +var log = logger.GetOrCreate("factory") + // NewHeartbeatComponentsFactory creates the heartbeat components factory func NewHeartbeatComponentsFactory(args HeartbeatComponentsFactoryArgs) (*heartbeatComponentsFactory, error) { @@ -253,7 +257,7 @@ func (hcf *heartbeatComponentsFactory) IsInterfaceNil() bool { return hcf == nil } -func (hcf *heartbeatComponentsFactory) startSendingHeartbeats(ctx context.Context, sender HeartbeatSender, monitor HeartbeatMonitor) { +func (hcf *heartbeatComponentsFactory) startSendingHeartbeats(ctx context.Context, sender factory.HeartbeatSender, monitor factory.HeartbeatMonitor) { r := rand.New(rand.NewSource(time.Now().Unix())) cfg := hcf.config.Heartbeat diff --git a/factory/heartbeatComponentsHandler.go b/factory/heartbeat/heartbeatComponentsHandler.go similarity index 85% rename from factory/heartbeatComponentsHandler.go rename to factory/heartbeat/heartbeatComponentsHandler.go index 4edd75cb2a6..13e49cb9622 100644 --- a/factory/heartbeatComponentsHandler.go +++ b/factory/heartbeat/heartbeatComponentsHandler.go @@ -1,4 +1,4 @@ -package factory +package heartbeat import ( "fmt" @@ -6,12 +6,13 @@ import ( "github.com/ElrondNetwork/elrond-go-core/core/check" "github.com/ElrondNetwork/elrond-go/errors" + "github.com/ElrondNetwork/elrond-go/factory" "github.com/ElrondNetwork/elrond-go/heartbeat" ) -var _ ComponentHandler = (*managedHeartbeatComponents)(nil) -var _ HeartbeatComponentsHolder = (*managedHeartbeatComponents)(nil) -var _ HeartbeatComponentsHandler = (*managedHeartbeatComponents)(nil) +var _ factory.ComponentHandler = (*managedHeartbeatComponents)(nil) +var _ factory.HeartbeatComponentsHolder = (*managedHeartbeatComponents)(nil) +var _ factory.HeartbeatComponentsHandler = (*managedHeartbeatComponents)(nil) type 
managedHeartbeatComponents struct { *heartbeatComponents @@ -100,7 +101,7 @@ func (mhc *managedHeartbeatComponents) MessageHandler() heartbeat.MessageHandler } // Monitor returns the heartbeat monitor -func (mhc *managedHeartbeatComponents) Monitor() HeartbeatMonitor { +func (mhc *managedHeartbeatComponents) Monitor() factory.HeartbeatMonitor { mhc.mutHeartbeatComponents.RLock() defer mhc.mutHeartbeatComponents.RUnlock() @@ -112,7 +113,7 @@ func (mhc *managedHeartbeatComponents) Monitor() HeartbeatMonitor { } // Sender returns the heartbeat sender -func (mhc *managedHeartbeatComponents) Sender() HeartbeatSender { +func (mhc *managedHeartbeatComponents) Sender() factory.HeartbeatSender { mhc.mutHeartbeatComponents.RLock() defer mhc.mutHeartbeatComponents.RUnlock() @@ -124,7 +125,7 @@ func (mhc *managedHeartbeatComponents) Sender() HeartbeatSender { } // Storer returns the heartbeat storer -func (mhc *managedHeartbeatComponents) Storer() HeartbeatStorer { +func (mhc *managedHeartbeatComponents) Storer() factory.HeartbeatStorer { mhc.mutHeartbeatComponents.RLock() defer mhc.mutHeartbeatComponents.RUnlock() @@ -142,5 +143,5 @@ func (mhc *managedHeartbeatComponents) IsInterfaceNil() bool { // String returns the name of the component func (mhc *managedHeartbeatComponents) String() string { - return heartbeatComponentsName + return factory.HeartbeatComponentsName } diff --git a/factory/heartbeatComponentsHandler_test.go b/factory/heartbeat/heartbeatComponentsHandler_test.go similarity index 63% rename from factory/heartbeatComponentsHandler_test.go rename to factory/heartbeat/heartbeatComponentsHandler_test.go index 1b422076e70..336b30e96f7 100644 --- a/factory/heartbeatComponentsHandler_test.go +++ b/factory/heartbeat/heartbeatComponentsHandler_test.go @@ -1,10 +1,11 @@ -package factory_test +package heartbeat_test import ( "testing" - "github.com/ElrondNetwork/elrond-go/factory" + heartbeatComp "github.com/ElrondNetwork/elrond-go/factory/heartbeat" "github.com/ElrondNetwork/elrond-go/factory/mock" + componentsMock "github.com/ElrondNetwork/elrond-go/testscommon/components" "github.com/stretchr/testify/require" ) @@ -16,10 +17,10 @@ func TestManagedHeartbeatComponents_CreateWithInvalidArgsShouldErr(t *testing.T) } shardCoordinator := mock.NewMultiShardsCoordinatorMock(2) - heartbeatArgs := getDefaultHeartbeatComponents(shardCoordinator) + heartbeatArgs := componentsMock.GetHeartbeatFactoryArgs(shardCoordinator) heartbeatArgs.Config.Heartbeat.MaxTimeToWaitBetweenBroadcastsInSec = 0 - heartbeatComponentsFactory, _ := factory.NewHeartbeatComponentsFactory(heartbeatArgs) - managedHeartbeatComponents, err := factory.NewManagedHeartbeatComponents(heartbeatComponentsFactory) + heartbeatComponentsFactory, _ := heartbeatComp.NewHeartbeatComponentsFactory(heartbeatArgs) + managedHeartbeatComponents, err := heartbeatComp.NewManagedHeartbeatComponents(heartbeatComponentsFactory) require.NoError(t, err) err = managedHeartbeatComponents.Create() require.Error(t, err) @@ -33,9 +34,9 @@ func TestManagedHeartbeatComponents_CreateShouldWork(t *testing.T) { } shardCoordinator := mock.NewMultiShardsCoordinatorMock(2) - heartbeatArgs := getDefaultHeartbeatComponents(shardCoordinator) - heartbeatComponentsFactory, _ := factory.NewHeartbeatComponentsFactory(heartbeatArgs) - managedHeartbeatComponents, err := factory.NewManagedHeartbeatComponents(heartbeatComponentsFactory) + heartbeatArgs := componentsMock.GetHeartbeatFactoryArgs(shardCoordinator) + heartbeatComponentsFactory, _ := 
heartbeatComp.NewHeartbeatComponentsFactory(heartbeatArgs) + managedHeartbeatComponents, err := heartbeatComp.NewManagedHeartbeatComponents(heartbeatComponentsFactory) require.NoError(t, err) require.Nil(t, managedHeartbeatComponents.Monitor()) require.Nil(t, managedHeartbeatComponents.MessageHandler()) @@ -57,9 +58,9 @@ func TestManagedHeartbeatComponents_Close(t *testing.T) { } shardCoordinator := mock.NewMultiShardsCoordinatorMock(2) - heartbeatArgs := getDefaultHeartbeatComponents(shardCoordinator) - heartbeatComponentsFactory, _ := factory.NewHeartbeatComponentsFactory(heartbeatArgs) - managedHeartbeatComponents, _ := factory.NewManagedHeartbeatComponents(heartbeatComponentsFactory) + heartbeatArgs := componentsMock.GetHeartbeatFactoryArgs(shardCoordinator) + heartbeatComponentsFactory, _ := heartbeatComp.NewHeartbeatComponentsFactory(heartbeatArgs) + managedHeartbeatComponents, _ := heartbeatComp.NewManagedHeartbeatComponents(heartbeatComponentsFactory) err := managedHeartbeatComponents.Create() require.NoError(t, err) diff --git a/factory/heartbeat/heartbeatComponents_test.go b/factory/heartbeat/heartbeatComponents_test.go new file mode 100644 index 00000000000..aaa6c90e8fa --- /dev/null +++ b/factory/heartbeat/heartbeatComponents_test.go @@ -0,0 +1,28 @@ +package heartbeat_test + +import ( + "testing" + + heartbeatComp "github.com/ElrondNetwork/elrond-go/factory/heartbeat" + "github.com/ElrondNetwork/elrond-go/factory/mock" + componentsMock "github.com/ElrondNetwork/elrond-go/testscommon/components" + "github.com/stretchr/testify/require" +) + +// ------------ Test HeartbeatComponents -------------------- +func TestHeartbeatComponents_CloseShouldWork(t *testing.T) { + t.Parallel() + if testing.Short() { + t.Skip("this is not a short test") + } + + shardCoordinator := mock.NewMultiShardsCoordinatorMock(2) + heartbeatArgs := componentsMock.GetHeartbeatFactoryArgs(shardCoordinator) + hcf, err := heartbeatComp.NewHeartbeatComponentsFactory(heartbeatArgs) + require.Nil(t, err) + cc, err := hcf.Create() + require.Nil(t, err) + + err = cc.Close() + require.NoError(t, err) +} diff --git a/factory/heartbeatV2Components.go b/factory/heartbeat/heartbeatV2Components.go similarity index 91% rename from factory/heartbeatV2Components.go rename to factory/heartbeat/heartbeatV2Components.go index 8eff99f2a88..6bf2b632564 100644 --- a/factory/heartbeatV2Components.go +++ b/factory/heartbeat/heartbeatV2Components.go @@ -1,6 +1,7 @@ -package factory +package heartbeat import ( + "fmt" "time" "github.com/ElrondNetwork/elrond-go-core/core" @@ -9,6 +10,7 @@ import ( "github.com/ElrondNetwork/elrond-go/common" "github.com/ElrondNetwork/elrond-go/config" "github.com/ElrondNetwork/elrond-go/errors" + "github.com/ElrondNetwork/elrond-go/factory" "github.com/ElrondNetwork/elrond-go/heartbeat/monitor" "github.com/ElrondNetwork/elrond-go/heartbeat/processor" "github.com/ElrondNetwork/elrond-go/heartbeat/sender" @@ -23,12 +25,12 @@ type ArgHeartbeatV2ComponentsFactory struct { Prefs config.Preferences BaseVersion string AppVersion string - BoostrapComponents BootstrapComponentsHolder - CoreComponents CoreComponentsHolder - DataComponents DataComponentsHolder - NetworkComponents NetworkComponentsHolder - CryptoComponents CryptoComponentsHolder - ProcessComponents ProcessComponentsHolder + BoostrapComponents factory.BootstrapComponentsHolder + CoreComponents factory.CoreComponentsHolder + DataComponents factory.DataComponentsHolder + NetworkComponents factory.NetworkComponentsHolder + CryptoComponents 
factory.CryptoComponentsHolder + ProcessComponents factory.ProcessComponentsHolder } type heartbeatV2ComponentsFactory struct { @@ -36,19 +38,19 @@ type heartbeatV2ComponentsFactory struct { prefs config.Preferences baseVersion string version string - boostrapComponents BootstrapComponentsHolder - coreComponents CoreComponentsHolder - dataComponents DataComponentsHolder - networkComponents NetworkComponentsHolder - cryptoComponents CryptoComponentsHolder - processComponents ProcessComponentsHolder + boostrapComponents factory.BootstrapComponentsHolder + coreComponents factory.CoreComponentsHolder + dataComponents factory.DataComponentsHolder + networkComponents factory.NetworkComponentsHolder + cryptoComponents factory.CryptoComponentsHolder + processComponents factory.ProcessComponentsHolder } type heartbeatV2Components struct { sender update.Closer peerAuthRequestsProcessor update.Closer directConnectionsProcessor update.Closer - monitor HeartbeatV2Monitor + monitor factory.HeartbeatV2Monitor statusHandler update.Closer } @@ -115,6 +117,11 @@ func (hcf *heartbeatV2ComponentsFactory) Create() (*heartbeatV2Components, error } } + cfg := hcf.config.HeartbeatV2 + if cfg.HeartbeatExpiryTimespanInSec <= cfg.PeerAuthenticationTimeBetweenSendsInSec { + return nil, fmt.Errorf("%w, HeartbeatExpiryTimespanInSec must be greater than PeerAuthenticationTimeBetweenSendsInSec", errors.ErrInvalidHeartbeatV2Config) + } + peerSubType := core.RegularPeer if hcf.prefs.Preferences.FullArchive { peerSubType = core.FullHistoryObserver @@ -123,8 +130,6 @@ func (hcf *heartbeatV2ComponentsFactory) Create() (*heartbeatV2Components, error shardC := hcf.boostrapComponents.ShardCoordinator() heartbeatTopic := common.HeartbeatV2Topic + shardC.CommunicationIdentifier(shardC.SelfId()) - cfg := hcf.config.HeartbeatV2 - argPeerTypeProvider := peer.ArgPeerTypeProvider{ NodesCoordinator: hcf.processComponents.NodesCoordinator(), StartEpoch: hcf.processComponents.EpochStartTrigger().MetaEpoch(), diff --git a/factory/heartbeatV2ComponentsHandler.go b/factory/heartbeat/heartbeatV2ComponentsHandler.go similarity index 92% rename from factory/heartbeatV2ComponentsHandler.go rename to factory/heartbeat/heartbeatV2ComponentsHandler.go index 2841f7cff05..6c2130b4e4f 100644 --- a/factory/heartbeatV2ComponentsHandler.go +++ b/factory/heartbeat/heartbeatV2ComponentsHandler.go @@ -1,10 +1,11 @@ -package factory +package heartbeat import ( "sync" "github.com/ElrondNetwork/elrond-go-core/core/check" "github.com/ElrondNetwork/elrond-go/errors" + "github.com/ElrondNetwork/elrond-go/factory" ) type managedHeartbeatV2Components struct { @@ -56,11 +57,11 @@ func (mhc *managedHeartbeatV2Components) CheckSubcomponents() error { // String returns the name of the component func (mhc *managedHeartbeatV2Components) String() string { - return heartbeatV2ComponentsName + return factory.HeartbeatV2ComponentsName } // Monitor returns the heartbeatV2 monitor -func (mhc *managedHeartbeatV2Components) Monitor() HeartbeatV2Monitor { +func (mhc *managedHeartbeatV2Components) Monitor() factory.HeartbeatV2Monitor { mhc.mutHeartbeatV2Components.Lock() defer mhc.mutHeartbeatV2Components.Unlock() diff --git a/factory/heartbeatV2ComponentsHandler_test.go b/factory/heartbeat/heartbeatV2ComponentsHandler_test.go similarity index 72% rename from factory/heartbeatV2ComponentsHandler_test.go rename to factory/heartbeat/heartbeatV2ComponentsHandler_test.go index 816421ad120..3a094594d62 100644 --- a/factory/heartbeatV2ComponentsHandler_test.go +++ 
b/factory/heartbeat/heartbeatV2ComponentsHandler_test.go @@ -1,11 +1,11 @@ -package factory_test +package heartbeat_test import ( "testing" "github.com/ElrondNetwork/elrond-go-core/core/check" "github.com/ElrondNetwork/elrond-go/errors" - "github.com/ElrondNetwork/elrond-go/factory" + heartbeatComp "github.com/ElrondNetwork/elrond-go/factory/heartbeat" "github.com/stretchr/testify/assert" ) @@ -19,13 +19,13 @@ func TestManagedHeartbeatV2Components(t *testing.T) { } }() - mhc, err := factory.NewManagedHeartbeatV2Components(nil) + mhc, err := heartbeatComp.NewManagedHeartbeatV2Components(nil) assert.True(t, check.IfNil(mhc)) assert.Equal(t, errors.ErrNilHeartbeatV2ComponentsFactory, err) args := createMockHeartbeatV2ComponentsFactoryArgs() - hcf, _ := factory.NewHeartbeatV2ComponentsFactory(args) - mhc, err = factory.NewManagedHeartbeatV2Components(hcf) + hcf, _ := heartbeatComp.NewHeartbeatV2ComponentsFactory(args) + mhc, err = heartbeatComp.NewManagedHeartbeatV2Components(hcf) assert.False(t, check.IfNil(mhc)) assert.Nil(t, err) diff --git a/factory/heartbeat/heartbeatV2Components_test.go b/factory/heartbeat/heartbeatV2Components_test.go new file mode 100644 index 00000000000..cccfd279954 --- /dev/null +++ b/factory/heartbeat/heartbeatV2Components_test.go @@ -0,0 +1,117 @@ +package heartbeat_test + +import ( + "errors" + "testing" + + "github.com/ElrondNetwork/elrond-go-core/core/check" + "github.com/ElrondNetwork/elrond-go/config" + errErd "github.com/ElrondNetwork/elrond-go/errors" + bootstrapComp "github.com/ElrondNetwork/elrond-go/factory/bootstrap" + heartbeatComp "github.com/ElrondNetwork/elrond-go/factory/heartbeat" + "github.com/ElrondNetwork/elrond-go/factory/mock" + componentsMock "github.com/ElrondNetwork/elrond-go/testscommon/components" + "github.com/stretchr/testify/assert" +) + +func createMockHeartbeatV2ComponentsFactoryArgs() heartbeatComp.ArgHeartbeatV2ComponentsFactory { + shardCoordinator := mock.NewMultiShardsCoordinatorMock(2) + bootStrapArgs := componentsMock.GetBootStrapFactoryArgs() + bootstrapComponentsFactory, _ := bootstrapComp.NewBootstrapComponentsFactory(bootStrapArgs) + bootstrapC, _ := bootstrapComp.NewTestManagedBootstrapComponents(bootstrapComponentsFactory) + _ = bootstrapC.Create() + + _ = bootstrapC.SetShardCoordinator(shardCoordinator) + + coreC := componentsMock.GetCoreComponents() + networkC := componentsMock.GetNetworkComponents() + dataC := componentsMock.GetDataComponents(coreC, shardCoordinator) + cryptoC := componentsMock.GetCryptoComponents(coreC) + stateC := componentsMock.GetStateComponents(coreC, shardCoordinator) + processC := componentsMock.GetProcessComponents(shardCoordinator, coreC, networkC, dataC, cryptoC, stateC) + return heartbeatComp.ArgHeartbeatV2ComponentsFactory{ + Config: config.Config{ + HeartbeatV2: config.HeartbeatV2Config{ + PeerAuthenticationTimeBetweenSendsInSec: 1, + PeerAuthenticationTimeBetweenSendsWhenErrorInSec: 1, + PeerAuthenticationThresholdBetweenSends: 0.1, + HeartbeatTimeBetweenSendsInSec: 1, + HeartbeatTimeBetweenSendsWhenErrorInSec: 1, + HeartbeatThresholdBetweenSends: 0.1, + HeartbeatExpiryTimespanInSec: 30, + MinPeersThreshold: 0.8, + DelayBetweenRequestsInSec: 10, + MaxTimeoutInSec: 60, + DelayBetweenConnectionNotificationsInSec: 5, + MaxMissingKeysInRequest: 100, + MaxDurationPeerUnresponsiveInSec: 10, + HideInactiveValidatorIntervalInSec: 60, + HardforkTimeBetweenSendsInSec: 5, + TimeBetweenConnectionsMetricsUpdateInSec: 10, + PeerAuthenticationTimeBetweenChecksInSec: 6, + HeartbeatPool: 
config.CacheConfig{ + Type: "LRU", + Capacity: 1000, + Shards: 1, + }, + }, + Hardfork: config.HardforkConfig{ + PublicKeyToListenFrom: componentsMock.DummyPk, + }, + }, + Prefs: config.Preferences{ + Preferences: config.PreferencesConfig{ + NodeDisplayName: "node", + Identity: "identity", + }, + }, + BaseVersion: "test-base", + AppVersion: "test", + BoostrapComponents: bootstrapC, + CoreComponents: coreC, + DataComponents: dataC, + NetworkComponents: networkC, + CryptoComponents: cryptoC, + ProcessComponents: processC, + } +} + +func Test_heartbeatV2Components_Create(t *testing.T) { + t.Parallel() + + t.Run("invalid config should error", func(t *testing.T) { + t.Parallel() + + args := createMockHeartbeatV2ComponentsFactoryArgs() + args.Config.HeartbeatV2.HeartbeatExpiryTimespanInSec = args.Config.HeartbeatV2.PeerAuthenticationTimeBetweenSendsInSec + hcf, err := heartbeatComp.NewHeartbeatV2ComponentsFactory(args) + assert.False(t, check.IfNil(hcf)) + assert.Nil(t, err) + + hc, err := hcf.Create() + assert.Nil(t, hc) + assert.True(t, errors.Is(err, errErd.ErrInvalidHeartbeatV2Config)) + }) + t.Run("should work", func(t *testing.T) { + t.Parallel() + + defer func() { + r := recover() + if r != nil { + assert.Fail(t, "should not panic") + } + }() + + args := createMockHeartbeatV2ComponentsFactoryArgs() + hcf, err := heartbeatComp.NewHeartbeatV2ComponentsFactory(args) + assert.False(t, check.IfNil(hcf)) + assert.Nil(t, err) + + hc, err := hcf.Create() + assert.NotNil(t, hc) + assert.Nil(t, err) + + err = hc.Close() + assert.Nil(t, err) + }) +} diff --git a/factory/heartbeatComponents_test.go b/factory/heartbeatComponents_test.go deleted file mode 100644 index 3bff2a064f7..00000000000 --- a/factory/heartbeatComponents_test.go +++ /dev/null @@ -1,93 +0,0 @@ -package factory_test - -import ( - "testing" - "time" - - "github.com/ElrondNetwork/elrond-go-crypto" - "github.com/ElrondNetwork/elrond-go/config" - "github.com/ElrondNetwork/elrond-go/factory" - "github.com/ElrondNetwork/elrond-go/factory/mock" - "github.com/ElrondNetwork/elrond-go/sharding" - "github.com/stretchr/testify/require" -) - -// ------------ Test HeartbeatComponents -------------------- -func TestHeartbeatComponents_CloseShouldWork(t *testing.T) { - t.Parallel() - if testing.Short() { - t.Skip("this is not a short test") - } - - shardCoordinator := mock.NewMultiShardsCoordinatorMock(2) - heartbeatArgs := getDefaultHeartbeatComponents(shardCoordinator) - hcf, err := factory.NewHeartbeatComponentsFactory(heartbeatArgs) - require.Nil(t, err) - cc, err := hcf.Create() - require.Nil(t, err) - - err = cc.Close() - require.NoError(t, err) -} - -func getDefaultHeartbeatComponents(shardCoordinator sharding.Coordinator) factory.HeartbeatComponentsFactoryArgs { - coreComponents := getCoreComponents() - networkComponents := getNetworkComponents() - dataComponents := getDataComponents(coreComponents, shardCoordinator) - cryptoComponents := getCryptoComponents(coreComponents) - stateComponents := getStateComponents(coreComponents, shardCoordinator) - processComponents := getProcessComponents( - shardCoordinator, - coreComponents, - networkComponents, - dataComponents, - cryptoComponents, - stateComponents, - ) - - return factory.HeartbeatComponentsFactoryArgs{ - Config: config.Config{ - Heartbeat: config.HeartbeatConfig{ - MinTimeToWaitBetweenBroadcastsInSec: 20, - MaxTimeToWaitBetweenBroadcastsInSec: 25, - HeartbeatRefreshIntervalInSec: 60, - HideInactiveValidatorIntervalInSec: 3600, - DurationToConsiderUnresponsiveInSec: 60, - 
diff --git a/factory/heartbeatComponents_test.go b/factory/heartbeatComponents_test.go
deleted file mode 100644
index 3bff2a064f7..00000000000
--- a/factory/heartbeatComponents_test.go
+++ /dev/null
@@ -1,93 +0,0 @@
-package factory_test
-
-import (
-	"testing"
-	"time"
-
-	"github.com/ElrondNetwork/elrond-go-crypto"
-	"github.com/ElrondNetwork/elrond-go/config"
-	"github.com/ElrondNetwork/elrond-go/factory"
-	"github.com/ElrondNetwork/elrond-go/factory/mock"
-	"github.com/ElrondNetwork/elrond-go/sharding"
-	"github.com/stretchr/testify/require"
-)
-
-// ------------ Test HeartbeatComponents --------------------
-func TestHeartbeatComponents_CloseShouldWork(t *testing.T) {
-	t.Parallel()
-	if testing.Short() {
-		t.Skip("this is not a short test")
-	}
-
-	shardCoordinator := mock.NewMultiShardsCoordinatorMock(2)
-	heartbeatArgs := getDefaultHeartbeatComponents(shardCoordinator)
-	hcf, err := factory.NewHeartbeatComponentsFactory(heartbeatArgs)
-	require.Nil(t, err)
-	cc, err := hcf.Create()
-	require.Nil(t, err)
-
-	err = cc.Close()
-	require.NoError(t, err)
-}
-
-func getDefaultHeartbeatComponents(shardCoordinator sharding.Coordinator) factory.HeartbeatComponentsFactoryArgs {
-	coreComponents := getCoreComponents()
-	networkComponents := getNetworkComponents()
-	dataComponents := getDataComponents(coreComponents, shardCoordinator)
-	cryptoComponents := getCryptoComponents(coreComponents)
-	stateComponents := getStateComponents(coreComponents, shardCoordinator)
-	processComponents := getProcessComponents(
-		shardCoordinator,
-		coreComponents,
-		networkComponents,
-		dataComponents,
-		cryptoComponents,
-		stateComponents,
-	)
-
-	return factory.HeartbeatComponentsFactoryArgs{
-		Config: config.Config{
-			Heartbeat: config.HeartbeatConfig{
-				MinTimeToWaitBetweenBroadcastsInSec: 20,
-				MaxTimeToWaitBetweenBroadcastsInSec: 25,
-				HeartbeatRefreshIntervalInSec: 60,
-				HideInactiveValidatorIntervalInSec: 3600,
-				DurationToConsiderUnresponsiveInSec: 60,
-				HeartbeatStorage: config.StorageConfig{
-					Cache: config.CacheConfig{
-						Capacity: 10000,
-						Type: "LRU",
-						Shards: 1,
-					},
-					DB: config.DBConfig{
-						FilePath: "HeartbeatStorage",
-						Type: "MemoryDB",
-						BatchDelaySeconds: 30,
-						MaxBatchSize: 6,
-						MaxOpenFiles: 10,
-					},
-				},
-			},
-			ValidatorStatistics: config.ValidatorStatisticsConfig{
-				CacheRefreshIntervalInSec: uint32(100),
-			},
-		},
-		Prefs: config.Preferences{},
-		AppVersion: "test",
-		GenesisTime: time.Time{},
-		RedundancyHandler: &mock.RedundancyHandlerStub{
-			ObserverPrivateKeyCalled: func() crypto.PrivateKey {
-				return &mock.PrivateKeyStub{
-					GeneratePublicHandler: func() crypto.PublicKey {
-						return &mock.PublicKeyMock{}
-					},
-				}
-			},
-		},
-		CoreComponents: coreComponents,
-		DataComponents: dataComponents,
-		NetworkComponents: networkComponents,
-		CryptoComponents: cryptoComponents,
-		ProcessComponents: processComponents,
-	}
-}
diff --git a/factory/heartbeatV2Components_test.go b/factory/heartbeatV2Components_test.go
deleted file mode 100644
index 5e216406af7..00000000000
--- a/factory/heartbeatV2Components_test.go
+++ /dev/null
@@ -1,95 +0,0 @@
-package factory_test
-
-import (
-	"testing"
-
-	"github.com/ElrondNetwork/elrond-go-core/core/check"
-	"github.com/ElrondNetwork/elrond-go/config"
-	"github.com/ElrondNetwork/elrond-go/factory"
-	"github.com/ElrondNetwork/elrond-go/factory/mock"
-	"github.com/stretchr/testify/assert"
-)
-
-func createMockHeartbeatV2ComponentsFactoryArgs() factory.ArgHeartbeatV2ComponentsFactory {
-	shardCoordinator := mock.NewMultiShardsCoordinatorMock(2)
-	bootStrapArgs := getBootStrapArgs()
-	bootstrapComponentsFactory, _ := factory.NewBootstrapComponentsFactory(bootStrapArgs)
-	bootstrapC, _ := factory.NewManagedBootstrapComponents(bootstrapComponentsFactory)
-	_ = bootstrapC.Create()
-	factory.SetShardCoordinator(shardCoordinator, bootstrapC)
-
-	coreC := getCoreComponents()
-	networkC := getNetworkComponents()
-	dataC := getDataComponents(coreC, shardCoordinator)
-	cryptoC := getCryptoComponents(coreC)
-	stateC := getStateComponents(coreC, shardCoordinator)
-	processC := getProcessComponents(shardCoordinator, coreC, networkC, dataC, cryptoC, stateC)
-	return factory.ArgHeartbeatV2ComponentsFactory{
-		Config: config.Config{
-			HeartbeatV2: config.HeartbeatV2Config{
-				PeerAuthenticationTimeBetweenSendsInSec: 1,
-				PeerAuthenticationTimeBetweenSendsWhenErrorInSec: 1,
-				PeerAuthenticationThresholdBetweenSends: 0.1,
-				HeartbeatTimeBetweenSendsInSec: 1,
-				HeartbeatTimeBetweenSendsWhenErrorInSec: 1,
-				HeartbeatThresholdBetweenSends: 0.1,
-				HeartbeatExpiryTimespanInSec: 30,
-				MinPeersThreshold: 0.8,
-				DelayBetweenRequestsInSec: 10,
-				MaxTimeoutInSec: 60,
-				DelayBetweenConnectionNotificationsInSec: 5,
-				MaxMissingKeysInRequest: 100,
-				MaxDurationPeerUnresponsiveInSec: 10,
-				HideInactiveValidatorIntervalInSec: 60,
-				HardforkTimeBetweenSendsInSec: 5,
-				TimeBetweenConnectionsMetricsUpdateInSec: 10,
-				PeerAuthenticationTimeBetweenChecksInSec: 1,
-				HeartbeatPool: config.CacheConfig{
-					Type: "LRU",
-					Capacity: 1000,
-					Shards: 1,
-				},
-			},
-			Hardfork: config.HardforkConfig{
-				PublicKeyToListenFrom: dummyPk,
-			},
-		},
-		Prefs: config.Preferences{
-			Preferences: config.PreferencesConfig{
-				NodeDisplayName: "node",
-				Identity: "identity",
-			},
-		},
-		BaseVersion: "test-base",
-		AppVersion: "test",
-		BoostrapComponents: bootstrapC,
-		CoreComponents: coreC,
-		DataComponents: dataC,
-		NetworkComponents: networkC,
-		CryptoComponents: cryptoC,
-		ProcessComponents: processC,
-	}
-}
-
-func Test_heartbeatV2Components_Create_ShouldWork(t *testing.T) {
-	t.Parallel()
-
-	defer func() {
-		r := recover()
-		if r != nil {
-			assert.Fail(t, "should not panic")
-		}
-	}()
-
-	args := createMockHeartbeatV2ComponentsFactoryArgs()
-	hcf, err := factory.NewHeartbeatV2ComponentsFactory(args)
-	assert.False(t, check.IfNil(hcf))
-	assert.Nil(t, err)
-
-	hc, err := hcf.Create()
-	assert.NotNil(t, hc)
-	assert.Nil(t, err)
-
-	err = hc.Close()
-	assert.Nil(t, err)
-}
diff --git a/factory/interface.go b/factory/interface.go
index 0361b32b798..d3380c1e20c 100644
--- a/factory/interface.go
+++ b/factory/interface.go
@@ -16,6 +16,7 @@ import (
 	crypto "github.com/ElrondNetwork/elrond-go-crypto"
 	"github.com/ElrondNetwork/elrond-go/cmd/node/factory"
 	"github.com/ElrondNetwork/elrond-go/common"
+	cryptoCommon "github.com/ElrondNetwork/elrond-go/common/crypto"
 	"github.com/ElrondNetwork/elrond-go/common/statistics"
 	"github.com/ElrondNetwork/elrond-go/consensus"
 	"github.com/ElrondNetwork/elrond-go/dataRetriever"
@@ -156,9 +157,10 @@ type CryptoComponentsHolder interface {
 	CryptoParamsHolder
 	TxSingleSigner() crypto.SingleSigner
 	BlockSigner() crypto.SingleSigner
-	MultiSigner() crypto.MultiSigner
+	SetMultiSignerContainer(container cryptoCommon.MultiSignerContainer) error
+	MultiSignerContainer() cryptoCommon.MultiSignerContainer
+	GetMultiSigner(epoch uint32) (crypto.MultiSigner, error)
 	PeerSignatureHandler() crypto.PeerSignatureHandler
-	SetMultiSigner(ms crypto.MultiSigner) error
 	BlockSignKeyGen() crypto.KeyGenerator
 	TxSignKeyGen() crypto.KeyGenerator
 	MessageSignVerifier() vm.MessageSignVerifier
@@ -507,3 +509,8 @@ type ReceiptsRepository interface {
 	LoadReceipts(header data.HeaderHandler, headerHash []byte) (common.ReceiptsHolder, error)
 	IsInterfaceNil() bool
 }
+
+// ProcessDebuggerSetter allows setting a debugger on the process component
+type ProcessDebuggerSetter interface {
+	SetProcessDebugger(debugger process.Debugger) error
+}
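The CryptoComponentsHolder change above replaces the single MultiSigner with an epoch-aware container, so callers now ask for the multi-signer that is active in a given epoch. A minimal sketch of what such a container can look like, with illustrative names (the repository's own implementation lives under common/crypto and is not shown in this diff):

package example

import (
	"errors"
	"sort"
	"sync"

	crypto "github.com/ElrondNetwork/elrond-go-crypto"
)

// epochMultiSignerContainer keeps multi-signers keyed by the epoch from
// which each becomes active and returns the newest one at or below the
// requested epoch.
type epochMultiSignerContainer struct {
	mut     sync.RWMutex
	epochs  []uint32                      // sorted activation epochs
	signers map[uint32]crypto.MultiSigner // activation epoch -> signer
}

// GetMultiSigner returns the multi-signer active in the given epoch
func (c *epochMultiSignerContainer) GetMultiSigner(epoch uint32) (crypto.MultiSigner, error) {
	c.mut.RLock()
	defer c.mut.RUnlock()

	// find the last activation epoch that is <= the requested epoch
	idx := sort.Search(len(c.epochs), func(i int) bool { return c.epochs[i] > epoch }) - 1
	if idx < 0 {
		return nil, errors.New("no multi-signer defined for the requested epoch")
	}

	return c.signers[c.epochs[idx]], nil
}

// IsInterfaceNil returns true if there is no value under the interface
func (c *epochMultiSignerContainer) IsInterfaceNil() bool {
	return c == nil
}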
diff --git a/factory/mock/cryptoComponentsMock.go b/factory/mock/cryptoComponentsMock.go
index d97fa241a22..8941f1bac30 100644
--- a/factory/mock/cryptoComponentsMock.go
+++ b/factory/mock/cryptoComponentsMock.go
@@ -1,9 +1,11 @@
 package mock
 
 import (
+	"errors"
 	"sync"
 
 	"github.com/ElrondNetwork/elrond-go-crypto"
+	cryptoCommon "github.com/ElrondNetwork/elrond-go/common/crypto"
 	"github.com/ElrondNetwork/elrond-go/consensus"
 	"github.com/ElrondNetwork/elrond-go/heartbeat"
 	"github.com/ElrondNetwork/elrond-go/vm"
@@ -18,7 +20,7 @@ type CryptoComponentsMock struct {
 	PubKeyBytes []byte
 	BlockSig crypto.SingleSigner
 	TxSig crypto.SingleSigner
-	MultiSig crypto.MultiSigner
+	MultiSigContainer cryptoCommon.MultiSignerContainer
 	PeerSignHandler crypto.PeerSignatureHandler
 	BlKeyGen crypto.KeyGenerator
 	TxKeyGen crypto.KeyGenerator
@@ -63,29 +65,41 @@ func (ccm *CryptoComponentsMock) TxSingleSigner() crypto.SingleSigner {
 	return ccm.TxSig
 }
 
-// MultiSigner -
-func (ccm *CryptoComponentsMock) MultiSigner() crypto.MultiSigner {
+// MultiSignerContainer -
+func (ccm *CryptoComponentsMock) MultiSignerContainer() cryptoCommon.MultiSignerContainer {
 	ccm.mutMultiSig.RLock()
 	defer ccm.mutMultiSig.RUnlock()
 
-	return ccm.MultiSig
+	return ccm.MultiSigContainer
 }
 
-// PeerSignatureHandler -
-func (ccm *CryptoComponentsMock) PeerSignatureHandler() crypto.PeerSignatureHandler {
+// SetMultiSignerContainer -
+func (ccm *CryptoComponentsMock) SetMultiSignerContainer(ms cryptoCommon.MultiSignerContainer) error {
+	ccm.mutMultiSig.Lock()
+	ccm.MultiSigContainer = ms
+	ccm.mutMultiSig.Unlock()
+
+	return nil
+}
+
+// GetMultiSigner -
+func (ccm *CryptoComponentsMock) GetMultiSigner(epoch uint32) (crypto.MultiSigner, error) {
 	ccm.mutMultiSig.RLock()
 	defer ccm.mutMultiSig.RUnlock()
 
-	return ccm.PeerSignHandler
+	if ccm.MultiSigContainer == nil {
+		return nil, errors.New("nil multi sig container")
+	}
+
+	return ccm.MultiSigContainer.GetMultiSigner(epoch)
 }
 
-// SetMultiSigner -
-func (ccm *CryptoComponentsMock) SetMultiSigner(ms crypto.MultiSigner) error {
-	ccm.mutMultiSig.Lock()
-	ccm.MultiSig = ms
-	ccm.mutMultiSig.Unlock()
+// PeerSignatureHandler -
+func (ccm *CryptoComponentsMock) PeerSignatureHandler() crypto.PeerSignatureHandler {
+	ccm.mutMultiSig.RLock()
+	defer ccm.mutMultiSig.RUnlock()
 
-	return nil
+	return ccm.PeerSignHandler
 }
 
 // BlockSignKeyGen -
@@ -123,7 +137,7 @@ func (ccm *CryptoComponentsMock) Clone() interface{} {
 		PubKeyBytes: ccm.PubKeyBytes,
 		BlockSig: ccm.BlockSig,
 		TxSig: ccm.TxSig,
-		MultiSig: ccm.MultiSig,
+		MultiSigContainer: ccm.MultiSigContainer,
 		PeerSignHandler: ccm.PeerSignHandler,
 		BlKeyGen: ccm.BlKeyGen,
 		TxKeyGen: ccm.TxKeyGen,
diff --git a/factory/mock/triesHolderStub.go b/factory/mock/triesHolderStub.go
deleted file mode 100644
index 42ba4079a6c..00000000000
--- a/factory/mock/triesHolderStub.go
+++ /dev/null
@@ -1,57 +0,0 @@
-package mock
-
-import (
-	"github.com/ElrondNetwork/elrond-go/common"
-	trieMock "github.com/ElrondNetwork/elrond-go/testscommon/trie"
-)
-
-// TriesHolderStub -
-type TriesHolderStub struct {
-	PutCalled func([]byte, common.Trie)
-	RemoveCalled func([]byte, common.Trie)
-	GetCalled func([]byte) common.Trie
-	GetAllCalled func() []common.Trie
-	ResetCalled func()
-}
-
-// Put -
-func (ths *TriesHolderStub) Put(key []byte, trie common.Trie) {
-	if ths.PutCalled != nil {
-		ths.PutCalled(key, trie)
-	}
-}
-
-// Replace -
-func (ths *TriesHolderStub) Replace(key []byte, trie common.Trie) {
-	if ths.RemoveCalled != nil {
-		ths.RemoveCalled(key, trie)
-	}
-}
-
-// Get -
-func (ths *TriesHolderStub) Get(key []byte) common.Trie {
-	if ths.GetCalled != nil {
-		return ths.GetCalled(key)
-	}
-	return &trieMock.TrieStub{}
-}
-
-// GetAll -
-func (ths *TriesHolderStub) GetAll() []common.Trie {
-	if ths.GetAllCalled != nil {
-		return ths.GetAllCalled()
-	}
-	return nil
-}
-
-// Reset -
-func (ths *TriesHolderStub) Reset() {
-	if ths.ResetCalled != nil {
-		ths.ResetCalled()
-	}
-}
-
-// IsInterfaceNil returns true if there is no value under the interface
-func (ths *TriesHolderStub) IsInterfaceNil() bool {
-	return ths == nil
-}
diff --git a/factory/network/export_test.go b/factory/network/export_test.go
new file mode 100644
index 00000000000..da06513cb6f
--- /dev/null
+++ b/factory/network/export_test.go
@@ -0,0 +1,6 @@
+package network
+
+// SetListenAddress -
+func (ncf *networkComponentsFactory) SetListenAddress(address string) {
+	ncf.listenAddress = address
+}
diff --git a/factory/networkComponents.go b/factory/network/networkComponents.go
similarity index 82%
rename from factory/networkComponents.go
rename to factory/network/networkComponents.go
index 40ed370d7e3..e57ec119095 100644
--- a/factory/networkComponents.go
+++ b/factory/network/networkComponents.go
@@ -1,4 +1,4 @@
-package factory
+package network
 
 import (
 	"context"
@@ -8,25 +8,26 @@ import (
 	"github.com/ElrondNetwork/elrond-go-core/core"
 	"github.com/ElrondNetwork/elrond-go-core/core/check"
 	"github.com/ElrondNetwork/elrond-go-core/marshal"
+	logger "github.com/ElrondNetwork/elrond-go-logger"
+	"github.com/ElrondNetwork/elrond-go/common"
 	"github.com/ElrondNetwork/elrond-go/config"
 	"github.com/ElrondNetwork/elrond-go/consensus"
 	"github.com/ElrondNetwork/elrond-go/debug/antiflood"
 	"github.com/ElrondNetwork/elrond-go/errors"
+	"github.com/ElrondNetwork/elrond-go/factory"
 	"github.com/ElrondNetwork/elrond-go/p2p"
-	"github.com/ElrondNetwork/elrond-go/p2p/libp2p"
-	peersHolder "github.com/ElrondNetwork/elrond-go/p2p/peersHolder"
-	"github.com/ElrondNetwork/elrond-go/p2p/rating"
+	p2pConfig "github.com/ElrondNetwork/elrond-go/p2p/config"
 	"github.com/ElrondNetwork/elrond-go/process"
 	"github.com/ElrondNetwork/elrond-go/process/rating/peerHonesty"
 	antifloodFactory "github.com/ElrondNetwork/elrond-go/process/throttle/antiflood/factory"
+	"github.com/ElrondNetwork/elrond-go/storage/cache"
 	storageFactory "github.com/ElrondNetwork/elrond-go/storage/factory"
-	"github.com/ElrondNetwork/elrond-go/storage/lrucache"
-	"github.com/ElrondNetwork/elrond-go/storage/storageUnit"
+	"github.com/ElrondNetwork/elrond-go/storage/storageunit"
 )
 
 // NetworkComponentsFactoryArgs holds the arguments to create a network component handler instance
 type NetworkComponentsFactoryArgs struct {
-	P2pConfig config.P2PConfig
+	P2pConfig p2pConfig.P2PConfig
 	MainConfig config.Config
 	RatingsConfig config.RatingsConfig
 	StatusHandler core.AppStatusHandler
@@ -36,10 +37,11 @@ type NetworkComponentsFactoryArgs struct {
 	BootstrapWaitTime time.Duration
 	NodeOperationMode p2p.NodeOperation
 	ConnectionWatcherType string
+	P2pKeyPemFileName string
}
 
 type networkComponentsFactory struct {
-	p2pConfig config.P2PConfig
+	p2pConfig p2pConfig.P2PConfig
 	mainConfig config.Config
 	ratingsConfig config.RatingsConfig
 	statusHandler core.AppStatusHandler
@@ -50,24 +52,27 @@ type networkComponentsFactory struct {
 	bootstrapWaitTime time.Duration
 	nodeOperationMode p2p.NodeOperation
 	connectionWatcherType string
+	p2pKeyPemFileName string
 }
 
 // networkComponents struct holds the network components
 type networkComponents struct {
 	netMessenger p2p.Messenger
-	inputAntifloodHandler P2PAntifloodHandler
-	outputAntifloodHandler P2PAntifloodHandler
+	inputAntifloodHandler factory.P2PAntifloodHandler
+	outputAntifloodHandler factory.P2PAntifloodHandler
 	pubKeyTimeCacher process.TimeCacher
 	topicFloodPreventer process.TopicFloodPreventer
 	floodPreventers []process.FloodPreventer
 	peerBlackListHandler process.PeerBlackListCacher
 	antifloodConfig config.AntifloodConfig
 	peerHonestyHandler consensus.PeerHonestyHandler
-	peersHolder PreferredPeersHolderHandler
+	peersHolder factory.PreferredPeersHolderHandler
 	peersRatingHandler p2p.PeersRatingHandler
 	closeFunc context.CancelFunc
 }
 
+var log = logger.GetOrCreate("factory")
+
 // NewNetworkComponentsFactory returns a new instance of a network components factory
 func NewNetworkComponentsFactory(
 	args NetworkComponentsFactoryArgs,
@@ -88,40 +93,46 @@ func NewNetworkComponentsFactory(
 		marshalizer: args.Marshalizer,
 		mainConfig: args.MainConfig,
 		statusHandler: args.StatusHandler,
-		listenAddress: libp2p.ListenAddrWithIp4AndTcp,
+		listenAddress: p2p.ListenAddrWithIp4AndTcp,
 		syncer: args.Syncer,
 		bootstrapWaitTime: args.BootstrapWaitTime,
 		preferredPeersSlices: args.PreferredPeersSlices,
 		nodeOperationMode: args.NodeOperationMode,
 		connectionWatcherType: args.ConnectionWatcherType,
+		p2pKeyPemFileName: args.P2pKeyPemFileName,
 	}, nil
 }
 
 // Create creates and returns the network components
 func (ncf *networkComponentsFactory) Create() (*networkComponents, error) {
-	ph, err := peersHolder.NewPeersHolder(ncf.preferredPeersSlices)
+	ph, err := p2p.NewPeersHolder(ncf.preferredPeersSlices)
 	if err != nil {
 		return nil, err
 	}
 
-	topRatedCache, err := lrucache.NewCache(ncf.mainConfig.PeersRatingConfig.TopRatedCacheCapacity)
+	topRatedCache, err := cache.NewLRUCache(ncf.mainConfig.PeersRatingConfig.TopRatedCacheCapacity)
 	if err != nil {
 		return nil, err
 	}
 
-	badRatedCache, err := lrucache.NewCache(ncf.mainConfig.PeersRatingConfig.BadRatedCacheCapacity)
+	badRatedCache, err := cache.NewLRUCache(ncf.mainConfig.PeersRatingConfig.BadRatedCacheCapacity)
 	if err != nil {
 		return nil, err
 	}
 
-	argsPeersRatingHandler := rating.ArgPeersRatingHandler{
+	argsPeersRatingHandler := p2p.ArgPeersRatingHandler{
 		TopRatedCache: topRatedCache,
 		BadRatedCache: badRatedCache,
 	}
-	peersRatingHandler, err := rating.NewPeersRatingHandler(argsPeersRatingHandler)
+	peersRatingHandler, err := p2p.NewPeersRatingHandler(argsPeersRatingHandler)
+	if err != nil {
+		return nil, err
+	}
+
+	p2pPrivateKeyBytes, err := common.GetSkBytesFromP2pKey(ncf.p2pKeyPemFileName)
 	if err != nil {
 		return nil, err
 	}
 
-	arg := libp2p.ArgsNetworkMessenger{
+	arg := p2p.ArgsNetworkMessenger{
 		Marshalizer: ncf.marshalizer,
 		ListenAddress: ncf.listenAddress,
 		P2pConfig: ncf.p2pConfig,
@@ -130,8 +141,9 @@ func (ncf *networkComponentsFactory) Create() (*networkComponents, error) {
 		NodeOperationMode: ncf.nodeOperationMode,
 		PeersRatingHandler: peersRatingHandler,
 		ConnectionWatcherType: ncf.connectionWatcherType,
+		P2pPrivateKeyBytes: p2pPrivateKeyBytes,
 	}
-	netMessenger, err := libp2p.NewNetworkMessenger(arg)
+	netMessenger, err := p2p.NewNetworkMessenger(arg)
 	if err != nil {
 		return nil, err
 	}
@@ -163,7 +175,7 @@ func (ncf *networkComponentsFactory) Create() (*networkComponents, error) {
 		}
 	}
 
-	inputAntifloodHandler, ok := antiFloodComponents.AntiFloodHandler.(P2PAntifloodHandler)
+	inputAntifloodHandler, ok := antiFloodComponents.AntiFloodHandler.(factory.P2PAntifloodHandler)
 	if !ok {
 		err = errors.ErrWrongTypeAssertion
 		return nil, fmt.Errorf("%w when casting input antiflood handler to P2PAntifloodHandler", err)
@@ -175,7 +187,7 @@ func (ncf *networkComponentsFactory) Create() (*networkComponents, error) {
 		return nil, err
 	}
 
-	outputAntifloodHandler, ok := outAntifloodHandler.(P2PAntifloodHandler)
+	outputAntifloodHandler, ok := outAntifloodHandler.(factory.P2PAntifloodHandler)
 	if !ok {
 		err = errors.ErrWrongTypeAssertion
 		return nil, fmt.Errorf("%w when casting output antiflood handler to P2PAntifloodHandler", err)
@@ -220,12 +232,12 @@ func (ncf *networkComponentsFactory) createPeerHonestyHandler(
 	pkTimeCache process.TimeCacher,
 ) (consensus.PeerHonestyHandler, error) {
 
-	cache, err := storageUnit.NewCache(storageFactory.GetCacherFromConfig(config.PeerHonesty))
+	suCache, err := storageunit.NewCache(storageFactory.GetCacherFromConfig(config.PeerHonesty))
 	if err != nil {
 		return nil, err
 	}
 
-	return peerHonesty.NewP2pPeerHonesty(ratingConfig.PeerHonesty, pkTimeCache, cache)
+	return peerHonesty.NewP2pPeerHonesty(ratingConfig.PeerHonesty, pkTimeCache, suCache)
 }
 
 // Close closes all underlying components that need closing
diff --git a/factory/networkComponentsHandler.go b/factory/network/networkComponentsHandler.go
similarity index 88%
rename from factory/networkComponentsHandler.go
rename to factory/network/networkComponentsHandler.go
index 987a04d1a06..3328f9388ab 100644
--- a/factory/networkComponentsHandler.go
+++ b/factory/network/networkComponentsHandler.go
@@ -1,4 +1,4 @@
-package factory
+package network
 
 import (
 	"fmt"
@@ -6,13 +6,14 @@ import (
 	"github.com/ElrondNetwork/elrond-go-core/core/check"
 	"github.com/ElrondNetwork/elrond-go/errors"
+	"github.com/ElrondNetwork/elrond-go/factory"
 	"github.com/ElrondNetwork/elrond-go/p2p"
 	"github.com/ElrondNetwork/elrond-go/process"
 )
 
-var _ ComponentHandler = (*managedNetworkComponents)(nil)
-var _ NetworkComponentsHolder = (*managedNetworkComponents)(nil)
-var _ NetworkComponentsHandler = (*managedNetworkComponents)(nil)
+var _ factory.ComponentHandler = (*managedNetworkComponents)(nil)
+var _ factory.NetworkComponentsHolder = (*managedNetworkComponents)(nil)
+var _ factory.NetworkComponentsHandler = (*managedNetworkComponents)(nil)
 
 // managedNetworkComponents creates the data components handler that can create, close and access the data components
 type managedNetworkComponents struct {
@@ -105,7 +106,7 @@ func (mnc *managedNetworkComponents) NetworkMessenger() p2p.Messenger {
 }
 
 // InputAntiFloodHandler returns the input p2p anti-flood handler
-func (mnc *managedNetworkComponents) InputAntiFloodHandler() P2PAntifloodHandler {
+func (mnc *managedNetworkComponents) InputAntiFloodHandler() factory.P2PAntifloodHandler {
 	mnc.mutNetworkComponents.RLock()
 	defer mnc.mutNetworkComponents.RUnlock()
 
@@ -117,7 +118,7 @@ func (mnc *managedNetworkComponents) InputAntiFloodHandler() P2PAntifloodHandler
 }
 
 // OutputAntiFloodHandler returns the output p2p anti-flood handler
-func (mnc *managedNetworkComponents) OutputAntiFloodHandler() P2PAntifloodHandler {
+func (mnc *managedNetworkComponents) OutputAntiFloodHandler() factory.P2PAntifloodHandler {
 	mnc.mutNetworkComponents.RLock()
 	defer mnc.mutNetworkComponents.RUnlock()
 
@@ -153,7 +154,7 @@ func (mnc *managedNetworkComponents) PeerBlackListHandler() process.PeerBlackLis
 }
 
 // PeerHonestyHandler returns the blacklist handler
-func (mnc *managedNetworkComponents) PeerHonestyHandler() PeerHonestyHandler {
+func (mnc *managedNetworkComponents) PeerHonestyHandler() factory.PeerHonestyHandler {
 	mnc.mutNetworkComponents.RLock()
 	defer mnc.mutNetworkComponents.RUnlock()
 
@@ -165,7 +166,7 @@ func (mnc *managedNetworkComponents) PeerHonestyHandler() PeerHonestyHandler {
 }
 
 // PreferredPeersHolderHandler returns the preferred peers holder
-func (mnc *managedNetworkComponents) PreferredPeersHolderHandler() PreferredPeersHolderHandler {
+func (mnc *managedNetworkComponents) PreferredPeersHolderHandler() factory.PreferredPeersHolderHandler {
 	mnc.mutNetworkComponents.RLock()
 	defer mnc.mutNetworkComponents.RUnlock()
 
@@ -195,5 +196,5 @@ func (mnc *managedNetworkComponents) IsInterfaceNil() bool {
 // String returns the name of the component
 func (mnc *managedNetworkComponents) String() string {
-	return networkComponentsName
+	return factory.NetworkComponentsName
 }
diff --git a/factory/networkComponentsHandler_test.go b/factory/network/networkComponentsHandler_test.go
similarity index 68%
rename from factory/networkComponentsHandler_test.go
rename to factory/network/networkComponentsHandler_test.go
index 173d19615d2..ea6f9b91cb6 100644
--- a/factory/networkComponentsHandler_test.go
+++ b/factory/network/networkComponentsHandler_test.go
@@ -1,10 +1,11 @@
-package factory_test
+package network_test
 
 import (
 	"testing"
 
 	"github.com/ElrondNetwork/elrond-go-core/core/check"
-	"github.com/ElrondNetwork/elrond-go/factory"
+	networkComp "github.com/ElrondNetwork/elrond-go/factory/network"
+	componentsMock "github.com/ElrondNetwork/elrond-go/testscommon/components"
 	"github.com/stretchr/testify/require"
 )
 
@@ -15,10 +16,10 @@ func TestManagedNetworkComponents_CreateWithInvalidArgsShouldErr(t *testing.T) {
 		t.Skip("this is not a short test")
 	}
 
-	networkArgs := getNetworkArgs()
+	networkArgs := componentsMock.GetNetworkFactoryArgs()
 	networkArgs.P2pConfig.Node.Port = "invalid"
-	networkComponentsFactory, _ := factory.NewNetworkComponentsFactory(networkArgs)
-	managedNetworkComponents, err := factory.NewManagedNetworkComponents(networkComponentsFactory)
+	networkComponentsFactory, _ := networkComp.NewNetworkComponentsFactory(networkArgs)
+	managedNetworkComponents, err := networkComp.NewManagedNetworkComponents(networkComponentsFactory)
 	require.NoError(t, err)
 	err = managedNetworkComponents.Create()
 	require.Error(t, err)
@@ -31,9 +32,9 @@ func TestManagedNetworkComponents_CreateShouldWork(t *testing.T) {
 		t.Skip("this is not a short test")
 	}
 
-	networkArgs := getNetworkArgs()
-	networkComponentsFactory, _ := factory.NewNetworkComponentsFactory(networkArgs)
-	managedNetworkComponents, err := factory.NewManagedNetworkComponents(networkComponentsFactory)
+	networkArgs := componentsMock.GetNetworkFactoryArgs()
+	networkComponentsFactory, _ := networkComp.NewNetworkComponentsFactory(networkArgs)
+	managedNetworkComponents, err := networkComp.NewManagedNetworkComponents(networkComponentsFactory)
 	require.NoError(t, err)
 	require.False(t, check.IfNil(managedNetworkComponents))
 	require.Nil(t, managedNetworkComponents.NetworkMessenger())
@@ -61,9 +62,9 @@ func TestManagedNetworkComponents_CheckSubcomponents(t *testing.T) {
 		t.Skip("this is not a short test")
 	}
 
-	networkArgs := getNetworkArgs()
-	networkComponentsFactory, _ := factory.NewNetworkComponentsFactory(networkArgs)
-	managedNetworkComponents, err := factory.NewManagedNetworkComponents(networkComponentsFactory)
+	networkArgs := componentsMock.GetNetworkFactoryArgs()
+	networkComponentsFactory, _ := networkComp.NewNetworkComponentsFactory(networkArgs)
+	managedNetworkComponents, err := networkComp.NewManagedNetworkComponents(networkComponentsFactory)
 	require.NoError(t, err)
 
 	require.Error(t, managedNetworkComponents.CheckSubcomponents())
@@ -79,9 +80,9 @@ func TestManagedNetworkComponents_Close(t *testing.T) {
 		t.Skip("this is not a short test")
 	}
 
-	networkArgs := getNetworkArgs()
-	networkComponentsFactory, _ := factory.NewNetworkComponentsFactory(networkArgs)
-	managedNetworkComponents, _ := factory.NewManagedNetworkComponents(networkComponentsFactory)
+	networkArgs := componentsMock.GetNetworkFactoryArgs()
+	networkComponentsFactory, _ := networkComp.NewNetworkComponentsFactory(networkArgs)
+	managedNetworkComponents, _ := networkComp.NewManagedNetworkComponents(networkComponentsFactory)
 	err := managedNetworkComponents.Create()
 	require.NoError(t, err)
 
diff --git a/factory/network/networkComponents_test.go b/factory/network/networkComponents_test.go
new file mode 100644
index 00000000000..88999442a8a
--- /dev/null
+++ b/factory/network/networkComponents_test.go
@@ -0,0 +1,100 @@
+package network_test
+
+import (
+	"errors"
+	"testing"
+
+	"github.com/ElrondNetwork/elrond-go/config"
+	errErd "github.com/ElrondNetwork/elrond-go/errors"
+	networkComp "github.com/ElrondNetwork/elrond-go/factory/network"
+	"github.com/ElrondNetwork/elrond-go/p2p"
+	p2pConfig "github.com/ElrondNetwork/elrond-go/p2p/config"
+	componentsMock "github.com/ElrondNetwork/elrond-go/testscommon/components"
+	"github.com/stretchr/testify/require"
+)
+
+func TestNewNetworkComponentsFactory_NilStatusHandlerShouldErr(t *testing.T) {
+	t.Parallel()
+	if testing.Short() {
+		t.Skip("this is not a short test")
+	}
+
+	args := componentsMock.GetNetworkFactoryArgs()
+	args.StatusHandler = nil
+	ncf, err := networkComp.NewNetworkComponentsFactory(args)
+	require.Nil(t, ncf)
+	require.Equal(t, errErd.ErrNilStatusHandler, err)
+}
+
+func TestNewNetworkComponentsFactory_NilMarshalizerShouldErr(t *testing.T) {
+	t.Parallel()
+	if testing.Short() {
+		t.Skip("this is not a short test")
+	}
+
+	args := componentsMock.GetNetworkFactoryArgs()
+	args.Marshalizer = nil
+	ncf, err := networkComp.NewNetworkComponentsFactory(args)
+	require.Nil(t, ncf)
+	require.True(t, errors.Is(err, errErd.ErrNilMarshalizer))
+}
+
+func TestNewNetworkComponentsFactory_OkValsShouldWork(t *testing.T) {
+	t.Parallel()
+	if testing.Short() {
+		t.Skip("this is not a short test")
+	}
+
+	args := componentsMock.GetNetworkFactoryArgs()
+	ncf, err := networkComp.NewNetworkComponentsFactory(args)
+	require.NoError(t, err)
+	require.NotNil(t, ncf)
+}
+
+func TestNetworkComponentsFactory_CreateShouldErrDueToBadConfig(t *testing.T) {
+	t.Parallel()
+	if testing.Short() {
+		t.Skip("this is not a short test")
+	}
+
+	args := componentsMock.GetNetworkFactoryArgs()
+	args.MainConfig = config.Config{}
+	args.P2pConfig = p2pConfig.P2PConfig{}
+
+	ncf, _ := networkComp.NewNetworkComponentsFactory(args)
+
+	nc, err := ncf.Create()
+	require.Error(t, err)
+	require.Nil(t, nc)
+}
+
+func TestNetworkComponentsFactory_CreateShouldWork(t *testing.T) {
+	t.Parallel()
+	if testing.Short() {
+		t.Skip("this is not a short test")
+	}
+
+	args := componentsMock.GetNetworkFactoryArgs()
+	ncf, _ := networkComp.NewNetworkComponentsFactory(args)
+	ncf.SetListenAddress(p2p.ListenLocalhostAddrWithIp4AndTcp)
+
+	nc, err := ncf.Create()
+	require.NoError(t, err)
+	require.NotNil(t, nc)
+}
+
+// ------------ Test NetworkComponents --------------------
+func TestNetworkComponents_CloseShouldWork(t *testing.T) {
+	t.Parallel()
+	if testing.Short() {
+		t.Skip("this is not a short test")
+	}
+
+	args := componentsMock.GetNetworkFactoryArgs()
+	ncf, _ := networkComp.NewNetworkComponentsFactory(args)
+
+	nc, _ := ncf.Create()
+
+	err := nc.Close()
+	require.NoError(t, err)
+}
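The managed-components tests above all walk the same lifecycle: build the factory, wrap it, Create(), after which the getters return non-nil values, and Close() tears everything down. The wrapper pattern they exercise boils down to a mutex-guarded lazy holder; a condensed sketch with illustrative names:

package example

import "sync"

type concreteComponents struct {
	messengerName string
}

// managedComponent mirrors the shape of managedNetworkComponents: getters
// return zero values until Create() has run, and all access is
// lock-protected.
type managedComponent struct {
	mut     sync.RWMutex
	inner   *concreteComponents
	factory func() (*concreteComponents, error)
}

func (mc *managedComponent) Create() error {
	inner, err := mc.factory()
	if err != nil {
		return err
	}

	mc.mut.Lock()
	mc.inner = inner
	mc.mut.Unlock()

	return nil
}

func (mc *managedComponent) MessengerName() string {
	mc.mut.RLock()
	defer mc.mut.RUnlock()

	if mc.inner == nil {
		return "" // not created yet: zero value, just like the tests assert
	}

	return mc.inner.messengerName
}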
diff --git a/factory/networkComponents_test.go b/factory/networkComponents_test.go
deleted file mode 100644
index db2fba1669f..00000000000
--- a/factory/networkComponents_test.go
+++ /dev/null
@@ -1,179 +0,0 @@
-package factory_test
-
-import (
-	"errors"
-	"testing"
-
-	"github.com/ElrondNetwork/elrond-go/config"
-	errErd "github.com/ElrondNetwork/elrond-go/errors"
-	"github.com/ElrondNetwork/elrond-go/factory"
-	"github.com/ElrondNetwork/elrond-go/factory/mock"
-	"github.com/ElrondNetwork/elrond-go/p2p"
-	"github.com/ElrondNetwork/elrond-go/p2p/libp2p"
-	statusHandlerMock "github.com/ElrondNetwork/elrond-go/testscommon/statusHandler"
-	"github.com/stretchr/testify/require"
-)
-
-func TestNewNetworkComponentsFactory_NilStatusHandlerShouldErr(t *testing.T) {
-	t.Parallel()
-	if testing.Short() {
-		t.Skip("this is not a short test")
-	}
-
-	args := getNetworkArgs()
-	args.StatusHandler = nil
-	ncf, err := factory.NewNetworkComponentsFactory(args)
-	require.Nil(t, ncf)
-	require.Equal(t, errErd.ErrNilStatusHandler, err)
-}
-
-func TestNewNetworkComponentsFactory_NilMarshalizerShouldErr(t *testing.T) {
-	t.Parallel()
-	if testing.Short() {
-		t.Skip("this is not a short test")
-	}
-
-	args := getNetworkArgs()
-	args.Marshalizer = nil
-	ncf, err := factory.NewNetworkComponentsFactory(args)
-	require.Nil(t, ncf)
-	require.True(t, errors.Is(err, errErd.ErrNilMarshalizer))
-}
-
-func TestNewNetworkComponentsFactory_OkValsShouldWork(t *testing.T) {
-	t.Parallel()
-	if testing.Short() {
-		t.Skip("this is not a short test")
-	}
-
-	args := getNetworkArgs()
-	ncf, err := factory.NewNetworkComponentsFactory(args)
-	require.NoError(t, err)
-	require.NotNil(t, ncf)
-}
-
-func TestNetworkComponentsFactory_CreateShouldErrDueToBadConfig(t *testing.T) {
-	t.Parallel()
-	if testing.Short() {
-		t.Skip("this is not a short test")
-	}
-
-	args := getNetworkArgs()
-	args.MainConfig = config.Config{}
-	args.P2pConfig = config.P2PConfig{}
-
-	ncf, _ := factory.NewNetworkComponentsFactory(args)
-
-	nc, err := ncf.Create()
-	require.Error(t, err)
-	require.Nil(t, nc)
-}
-
-func TestNetworkComponentsFactory_CreateShouldWork(t *testing.T) {
-	t.Parallel()
-	if testing.Short() {
-		t.Skip("this is not a short test")
-	}
-
-	args := getNetworkArgs()
-	ncf, _ := factory.NewNetworkComponentsFactory(args)
-	ncf.SetListenAddress(libp2p.ListenLocalhostAddrWithIp4AndTcp)
-
-	nc, err := ncf.Create()
-	require.NoError(t, err)
-	require.NotNil(t, nc)
-}
-
-// ------------ Test NetworkComponents --------------------
-func TestNetworkComponents_CloseShouldWork(t *testing.T) {
-	t.Parallel()
-	if testing.Short() {
-		t.Skip("this is not a short test")
-	}
-
-	args := getNetworkArgs()
-	ncf, _ := factory.NewNetworkComponentsFactory(args)
-
-	nc, _ := ncf.Create()
-
-	err := nc.Close()
-	require.NoError(t, err)
-}
-
-func getNetworkArgs() factory.NetworkComponentsFactoryArgs {
-	p2pConfig := config.P2PConfig{
-		Node: config.NodeConfig{
-			Port: "0",
-			Seed: "seed",
-		},
-		KadDhtPeerDiscovery: config.KadDhtPeerDiscoveryConfig{
-			Enabled: false,
-			Type: "optimized",
-			RefreshIntervalInSec: 10,
-			ProtocolID: "erd/kad/1.0.0",
-			InitialPeerList: []string{"peer0", "peer1"},
-			BucketSize: 10,
-			RoutingTableRefreshIntervalInSec: 5,
-		},
-		Sharding: config.ShardingConfig{
-			TargetPeerCount: 10,
-			MaxIntraShardValidators: 10,
-			MaxCrossShardValidators: 10,
-			MaxIntraShardObservers: 10,
-			MaxCrossShardObservers: 10,
-			MaxSeeders: 2,
-			Type: "NilListSharder",
-			AdditionalConnections: config.AdditionalConnectionsConfig{
-				MaxFullHistoryObservers: 10,
-			},
-		},
-	}
-
-	mainConfig := config.Config{
-		PeerHonesty: config.CacheConfig{
-			Type: "LRU",
-			Capacity: 5000,
-			Shards: 16,
-		},
-		Debug: config.DebugConfig{
-			Antiflood: config.AntifloodDebugConfig{
-				Enabled: true,
-				CacheSize: 100,
-				IntervalAutoPrintInSeconds: 1,
-			},
-		},
-		PeersRatingConfig: config.PeersRatingConfig{
-			TopRatedCacheCapacity: 1000,
-			BadRatedCacheCapacity: 1000,
-		},
-		PoolsCleanersConfig: config.PoolsCleanersConfig{
-			MaxRoundsToKeepUnprocessedMiniBlocks: 50,
-			MaxRoundsToKeepUnprocessedTransactions: 50,
-		},
-	}
-
-	appStatusHandler := statusHandlerMock.NewAppStatusHandlerMock()
-
-	return factory.NetworkComponentsFactoryArgs{
-		P2pConfig: p2pConfig,
-		MainConfig: mainConfig,
-		StatusHandler: appStatusHandler,
-		Marshalizer: &mock.MarshalizerMock{},
-		RatingsConfig: config.RatingsConfig{
-			General: config.General{},
-			ShardChain: config.ShardChain{},
-			MetaChain: config.MetaChain{},
-			PeerHonesty: config.PeerHonestyConfig{
-				DecayCoefficient: 0.9779,
-				DecayUpdateIntervalInSeconds: 10,
-				MaxScore: 100,
-				MinScore: -100,
-				BadPeerThreshold: -80,
-				UnitValue: 1.0,
-			},
-		},
-		Syncer: &libp2p.LocalSyncTimer{},
-		NodeOperationMode: p2p.NormalOperation,
-		ConnectionWatcherType: p2p.ConnectionWatcherTypePrint,
-	}
-}
diff --git a/factory/peerSignatureHandler/export_test.go b/factory/peerSignatureHandler/export_test.go
index 89ecb48ff94..6132b794ed1 100644
--- a/factory/peerSignatureHandler/export_test.go
+++ b/factory/peerSignatureHandler/export_test.go
@@ -2,13 +2,13 @@ package peerSignatureHandler
 
 import (
 	"github.com/ElrondNetwork/elrond-go-core/core"
-	"github.com/ElrondNetwork/elrond-go-crypto"
+	"github.com/ElrondNetwork/elrond-go/errors"
 )
 
 func (psh *peerSignatureHandler) GetPIDAndSig(entry interface{}) (core.PeerID, []byte, error) {
 	pidSig, ok := entry.(*pidSignature)
 	if !ok {
-		return "", nil, crypto.ErrWrongTypeAssertion
+		return "", nil, errors.ErrWrongTypeAssertion
 	}
 
 	return pidSig.pid, pidSig.signature, nil
diff --git a/factory/peerSignatureHandler/peerSignatureHandler.go b/factory/peerSignatureHandler/peerSignatureHandler.go
index b169d140720..6fa1f622915 100644
--- a/factory/peerSignatureHandler/peerSignatureHandler.go
+++ b/factory/peerSignatureHandler/peerSignatureHandler.go
@@ -6,6 +6,7 @@ import (
 	"github.com/ElrondNetwork/elrond-go-core/core"
 	"github.com/ElrondNetwork/elrond-go-core/core/check"
 	"github.com/ElrondNetwork/elrond-go-crypto"
+	"github.com/ElrondNetwork/elrond-go/errors"
 	"github.com/ElrondNetwork/elrond-go/storage"
 )
 
@@ -29,10 +30,10 @@ func NewPeerSignatureHandler(
 	keygen crypto.KeyGenerator,
 ) (*peerSignatureHandler, error) {
 	if check.IfNil(pkPIDSignature) {
-		return nil, crypto.ErrNilCacher
+		return nil, errors.ErrNilCacher
 	}
 	if check.IfNil(singleSigner) {
-		return nil, crypto.ErrNilSingleSigner
+		return nil, errors.ErrNilSingleSigner
 	}
 	if check.IfNil(keygen) {
 		return nil, crypto.ErrNilKeyGenerator
@@ -52,10 +53,10 @@ func (psh *peerSignatureHandler) VerifyPeerSignature(pk []byte, pid core.PeerID,
 		return crypto.ErrInvalidPublicKey
 	}
 	if len(pid) == 0 {
-		return crypto.ErrInvalidPID
+		return errors.ErrInvalidPID
 	}
 	if len(signature) == 0 {
-		return crypto.ErrInvalidSignature
+		return errors.ErrInvalidSignature
 	}
 
 	senderPubKey, err := psh.keygen.PublicKeyFromByteArray(pk)
@@ -77,11 +78,11 @@ func (psh *peerSignatureHandler) VerifyPeerSignature(pk []byte, pid core.PeerID,
 	}
 
 	if retrievedPID != pid {
-		return crypto.ErrPIDMismatch
+		return errors.ErrPIDMismatch
 	}
 
 	if !bytes.Equal(retrievedSig, signature) {
-		return crypto.ErrSignatureMismatch
+		return errors.ErrSignatureMismatch
 	}
 
 	return nil
diff --git a/factory/peerSignatureHandler/peerSignatureHandler_test.go b/factory/peerSignatureHandler/peerSignatureHandler_test.go
index b4f47b0dd3f..3ab7717221b 100644
--- a/factory/peerSignatureHandler/peerSignatureHandler_test.go
+++ b/factory/peerSignatureHandler/peerSignatureHandler_test.go
@@ -7,6 +7,7 @@ import (
 	"github.com/ElrondNetwork/elrond-go-core/core"
 	"github.com/ElrondNetwork/elrond-go-core/core/check"
 	"github.com/ElrondNetwork/elrond-go-crypto"
+	errorsErd "github.com/ElrondNetwork/elrond-go/errors"
 	"github.com/ElrondNetwork/elrond-go/factory/peerSignatureHandler"
 	"github.com/ElrondNetwork/elrond-go/testscommon"
 	"github.com/ElrondNetwork/elrond-go/testscommon/cryptoMocks"
@@ -23,7 +24,7 @@ func TestNewPeerSignatureHandler_NilCacherShouldErr(t *testing.T) {
 	)
 
 	assert.True(t, check.IfNil(peerSigHandler))
-	assert.Equal(t, crypto.ErrNilCacher, err)
+	assert.Equal(t, errorsErd.ErrNilCacher, err)
 }
 
 func TestNewPeerSignatureHandler_NilSingleSignerShouldErr(t *testing.T) {
@@ -36,7 +37,7 @@ func TestNewPeerSignatureHandler_NilSingleSignerShouldErr(t *testing.T) {
 	)
 
 	assert.True(t, check.IfNil(peerSigHandler))
-	assert.Equal(t, crypto.ErrNilSingleSigner, err)
+	assert.Equal(t, errorsErd.ErrNilSingleSigner, err)
 }
 
 func TestNewPeerSignatureHandler_NilKeyGenShouldErr(t *testing.T) {
@@ -88,7 +89,7 @@ func TestPeerSignatureHandler_VerifyPeerSignatureInvalidPID(t *testing.T) {
 	)
 
 	err := peerSigHandler.VerifyPeerSignature([]byte("public key"), "", []byte("signature"))
-	assert.Equal(t, crypto.ErrInvalidPID, err)
+	assert.Equal(t, errorsErd.ErrInvalidPID, err)
 }
 
 func TestPeerSignatureHandler_VerifyPeerSignatureInvalidSignature(t *testing.T) {
@@ -101,7 +102,7 @@ func TestPeerSignatureHandler_VerifyPeerSignatureInvalidSignature(t *testing.T)
 	)
 
 	err := peerSigHandler.VerifyPeerSignature([]byte("public key"), "dummy peer", nil)
-	assert.Equal(t, crypto.ErrInvalidSignature, err)
+	assert.Equal(t, errorsErd.ErrInvalidSignature, err)
 }
 
 func TestPeerSignatureHandler_VerifyPeerSignatureCantGetPubKeyBytes(t *testing.T) {
@@ -303,7 +304,7 @@ func TestPeerSignatureHandler_VerifyPeerSignatureDifferentPid(t *testing.T) {
 	cache.Put(pk, cacheEntry, len(pid)+len(sig))
 
 	err := peerSigHandler.VerifyPeerSignature(pk, newPid, sig)
-	assert.Equal(t, crypto.ErrPIDMismatch, err)
+	assert.Equal(t, errorsErd.ErrPIDMismatch, err)
 	assert.False(t, verifyCalled)
 }
 
@@ -343,7 +344,7 @@ func TestPeerSignatureHandler_VerifyPeerSignatureDifferentSig(t *testing.T) {
 	cache.Put(pk, cacheEntry, len(pid)+len(sig))
 
 	err := peerSigHandler.VerifyPeerSignature(pk, pid, newSig)
-	assert.Equal(t, crypto.ErrSignatureMismatch, err)
+	assert.Equal(t, errorsErd.ErrSignatureMismatch, err)
 	assert.False(t, verifyCalled)
 }
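The DifferentPid/DifferentSig tests above pin down the cache fast-path of VerifyPeerSignature: a cached (pid, signature) pair that matches means the signature was already verified, while any mismatch fails immediately without re-running BLS verification (which is why verifyCalled stays false). The decision itself condenses to:

package example

import (
	"bytes"
	"errors"
)

var (
	errPIDMismatch       = errors.New("pid mismatch")
	errSignatureMismatch = errors.New("signature mismatch")
)

// checkCachedPeerSignature is an illustrative condensation of the cache
// fast-path: equality on both fields short-circuits to success, any
// mismatch short-circuits to failure without touching the single signer.
func checkCachedPeerSignature(cachedPid, pid string, cachedSig, sig []byte) error {
	if cachedPid != pid {
		return errPIDMismatch
	}
	if !bytes.Equal(cachedSig, sig) {
		return errSignatureMismatch
	}

	return nil // already verified for this exact pair
}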
diff --git a/factory/processComponents_test.go b/factory/processComponents_test.go
deleted file mode 100644
index 43361d1fb7e..00000000000
--- a/factory/processComponents_test.go
+++ /dev/null
@@ -1,327 +0,0 @@
-package factory_test
-
-import (
-	"math/big"
-	"strings"
-	"sync"
-	"testing"
-
-	arwenConfig "github.com/ElrondNetwork/arwen-wasm-vm/v1_4/config"
-	coreData "github.com/ElrondNetwork/elrond-go-core/data"
-	"github.com/ElrondNetwork/elrond-go-core/data/block"
-	dataBlock "github.com/ElrondNetwork/elrond-go-core/data/block"
-	"github.com/ElrondNetwork/elrond-go-core/data/indexer"
-	"github.com/ElrondNetwork/elrond-go/common"
-	commonFactory "github.com/ElrondNetwork/elrond-go/common/factory"
-	"github.com/ElrondNetwork/elrond-go/config"
-	"github.com/ElrondNetwork/elrond-go/factory"
-	"github.com/ElrondNetwork/elrond-go/factory/mock"
-	"github.com/ElrondNetwork/elrond-go/genesis"
-	"github.com/ElrondNetwork/elrond-go/genesis/data"
-	"github.com/ElrondNetwork/elrond-go/process"
-	"github.com/ElrondNetwork/elrond-go/sharding"
-	"github.com/ElrondNetwork/elrond-go/testscommon"
-	"github.com/ElrondNetwork/elrond-go/testscommon/dblookupext"
-	"github.com/ElrondNetwork/elrond-go/testscommon/mainFactoryMocks"
-	"github.com/ElrondNetwork/elrond-go/testscommon/shardingMocks"
-	storageStubs "github.com/ElrondNetwork/elrond-go/testscommon/storage"
-	"github.com/stretchr/testify/assert"
-	"github.com/stretchr/testify/require"
-)
-
-// ------------ Test TestProcessComponents --------------------
-func TestProcessComponents_CloseShouldWork(t *testing.T) {
-	t.Parallel()
-	if testing.Short() {
-		t.Skip("this is not a short test")
-	}
-
-	shardCoordinator := mock.NewMultiShardsCoordinatorMock(2)
-	processArgs := getProcessComponentsArgs(shardCoordinator)
-	pcf, err := factory.NewProcessComponentsFactory(processArgs)
-	require.Nil(t, err)
-
-	pc, err := pcf.Create()
-	require.Nil(t, err)
-
-	err = pc.Close()
-	require.NoError(t, err)
-}
-
-func TestProcessComponentsFactory_CreateWithInvalidTxAccumulatorTimeExpectError(t *testing.T) {
-	t.Parallel()
-	if testing.Short() {
-		t.Skip("this is not a short test")
-	}
-
-	shardCoordinator := mock.NewMultiShardsCoordinatorMock(2)
-	processArgs := getProcessComponentsArgs(shardCoordinator)
-	processArgs.Config.Antiflood.TxAccumulator.MaxAllowedTimeInMilliseconds = 0
-	pcf, err := factory.NewProcessComponentsFactory(processArgs)
-	require.Nil(t, err)
-
-	instance, err := pcf.Create()
-	require.Nil(t, instance)
-	require.Error(t, err)
-	require.True(t, strings.Contains(err.Error(), process.ErrInvalidValue.Error()))
-}
-
-func getProcessComponentsArgs(shardCoordinator sharding.Coordinator) factory.ProcessComponentsFactoryArgs {
-	coreComponents := getCoreComponents()
-	networkComponents := getNetworkComponents()
-	dataComponents := getDataComponents(coreComponents, shardCoordinator)
-	cryptoComponents := getCryptoComponents(coreComponents)
-	stateComponents := getStateComponents(coreComponents, shardCoordinator)
-	processArgs := getProcessArgs(
-		shardCoordinator,
-		coreComponents,
-		dataComponents,
-		cryptoComponents,
-		stateComponents,
-		networkComponents,
-	)
-	return processArgs
-}
-
-func getProcessArgs(
-	shardCoordinator sharding.Coordinator,
-	coreComponents factory.CoreComponentsHolder,
-	dataComponents factory.DataComponentsHolder,
-	cryptoComponents factory.CryptoComponentsHolder,
-	stateComponents factory.StateComponentsHolder,
-	networkComponents factory.NetworkComponentsHolder,
-) factory.ProcessComponentsFactoryArgs {
-
-	gasSchedule := arwenConfig.MakeGasMapForTests()
-	// TODO: check if these could be initialized by MakeGasMapForTests()
-	gasSchedule["BuiltInCost"]["SaveUserName"] = 1
-	gasSchedule["BuiltInCost"]["SaveKeyValue"] = 1
-	gasSchedule["BuiltInCost"]["ESDTTransfer"] = 1
-	gasSchedule["BuiltInCost"]["ESDTBurn"] = 1
-	gasSchedule[common.MetaChainSystemSCsCost] = FillGasMapMetaChainSystemSCsCosts(1)
-
-	gasScheduleNotifier := &testscommon.GasScheduleNotifierMock{
-		GasSchedule: gasSchedule,
-	}
-
-	nodesCoordinator := &shardingMocks.NodesCoordinatorMock{}
-	statusComponents := getStatusComponents(
-		coreComponents,
-		networkComponents,
-		dataComponents,
-		stateComponents,
-		shardCoordinator,
-		nodesCoordinator,
-	)
-
-	bootstrapComponentsFactoryArgs := getBootStrapArgs()
-
-	bootstrapComponentsFactory, _ := factory.NewBootstrapComponentsFactory(bootstrapComponentsFactoryArgs)
-	bootstrapComponents, _ := factory.NewManagedBootstrapComponents(bootstrapComponentsFactory)
-	_ = bootstrapComponents.Create()
-	factory.SetShardCoordinator(shardCoordinator, bootstrapComponents)
-
-	return factory.ProcessComponentsFactoryArgs{
-		Config: testscommon.GetGeneralConfig(),
-		AccountsParser: &mock.AccountsParserStub{
-			InitialAccountsCalled: func() []genesis.InitialAccountHandler {
-				addrConverter, _ := commonFactory.NewPubkeyConverter(config.PubkeyConfig{
-					Length: 32,
-					Type: "bech32",
-					SignatureLength: 0,
-				})
-				balance := big.NewInt(0)
-				acc1 := data.InitialAccount{
-					Address: "erd1ulhw20j7jvgfgak5p05kv667k5k9f320sgef5ayxkt9784ql0zssrzyhjp",
-					Supply: big.NewInt(0).Mul(big.NewInt(2500000000), big.NewInt(1000000000000)),
-					Balance: balance,
-					StakingValue: big.NewInt(0).Mul(big.NewInt(2500000000), big.NewInt(1000000000000)),
-					Delegation: &data.DelegationData{
-						Address: "",
-						Value: big.NewInt(0),
-					},
-				}
-				acc2 := data.InitialAccount{
-					Address: "erd17c4fs6mz2aa2hcvva2jfxdsrdknu4220496jmswer9njznt22eds0rxlr4",
-					Supply: big.NewInt(0).Mul(big.NewInt(2500000000), big.NewInt(1000000000000)),
-					Balance: balance,
-					StakingValue: big.NewInt(0).Mul(big.NewInt(2500000000), big.NewInt(1000000000000)),
-					Delegation: &data.DelegationData{
-						Address: "",
-						Value: big.NewInt(0),
-					},
-				}
-				acc3 := data.InitialAccount{
-					Address: "erd10d2gufxesrp8g409tzxljlaefhs0rsgjle3l7nq38de59txxt8csj54cd3",
-					Supply: big.NewInt(0).Mul(big.NewInt(2500000000), big.NewInt(1000000000000)),
-					Balance: balance,
-					StakingValue: big.NewInt(0).Mul(big.NewInt(2500000000), big.NewInt(1000000000000)),
-					Delegation: &data.DelegationData{
-						Address: "",
-						Value: big.NewInt(0),
-					},
-				}
-
-				acc1Bytes, _ := addrConverter.Decode(acc1.Address)
-				acc1.SetAddressBytes(acc1Bytes)
-				acc2Bytes, _ := addrConverter.Decode(acc2.Address)
-				acc2.SetAddressBytes(acc2Bytes)
-				acc3Bytes, _ := addrConverter.Decode(acc3.Address)
-				acc3.SetAddressBytes(acc3Bytes)
-				initialAccounts := []genesis.InitialAccountHandler{&acc1, &acc2, &acc3}
-
-				return initialAccounts
-			},
-			GenerateInitialTransactionsCalled: func(shardCoordinator sharding.Coordinator, initialIndexingData map[uint32]*genesis.IndexingData) ([]*block.MiniBlock, map[uint32]*indexer.Pool, error) {
-				txsPool := make(map[uint32]*indexer.Pool)
-				for i := uint32(0); i < shardCoordinator.NumberOfShards(); i++ {
-					txsPool[i] = &indexer.Pool{}
-				}
-
-				return make([]*block.MiniBlock, 4), txsPool, nil
-			},
-		},
-		SmartContractParser: &mock.SmartContractParserStub{},
-		GasSchedule: gasScheduleNotifier,
-		NodesCoordinator: nodesCoordinator,
-		Data: dataComponents,
-		CoreData: coreComponents,
-		Crypto: cryptoComponents,
-		State: stateComponents,
-		Network: networkComponents,
-		StatusComponents: statusComponents,
-		BootstrapComponents: bootstrapComponents,
-		RequestedItemsHandler: &testscommon.RequestedItemsHandlerStub{},
-		WhiteListHandler: &testscommon.WhiteListHandlerStub{},
-		WhiteListerVerifiedTxs: &testscommon.WhiteListHandlerStub{},
-		MaxRating: 100,
-		ImportStartHandler: &testscommon.ImportStartHandlerStub{},
-		SystemSCConfig: &config.SystemSmartContractsConfig{
-			ESDTSystemSCConfig: config.ESDTSystemSCConfig{
-				BaseIssuingCost: "1000",
-				OwnerAddress: "erd1fpkcgel4gcmh8zqqdt043yfcn5tyx8373kg6q2qmkxzu4dqamc0swts65c",
-			},
-			GovernanceSystemSCConfig: config.GovernanceSystemSCConfig{
-				V1: config.GovernanceSystemSCConfigV1{
-					ProposalCost: "500",
-					NumNodes: 100,
-					MinQuorum: 50,
-					MinPassThreshold: 50,
-					MinVetoThreshold: 50,
-				},
-				Active: config.GovernanceSystemSCConfigActive{
-					ProposalCost: "500",
-					MinQuorum: "50",
-					MinPassThreshold: "50",
-					MinVetoThreshold: "50",
-				},
-				FirstWhitelistedAddress: "erd1vxy22x0fj4zv6hktmydg8vpfh6euv02cz4yg0aaws6rrad5a5awqgqky80",
-			},
-			StakingSystemSCConfig: config.StakingSystemSCConfig{
-				GenesisNodePrice: "2500000000000000000000",
-				MinStakeValue: "1",
-				UnJailValue: "1",
-				MinStepValue: "1",
-				UnBondPeriod: 0,
-				NumRoundsWithoutBleed: 0,
-				MaximumPercentageToBleed: 0,
-				BleedPercentagePerRound: 0,
-				MaxNumberOfNodesForStake: 10,
-				ActivateBLSPubKeyMessageVerification: false,
-				MinUnstakeTokensValue: "1",
-			},
-			DelegationManagerSystemSCConfig: config.DelegationManagerSystemSCConfig{
-				MinCreationDeposit: "100",
-				MinStakeAmount: "100",
-				ConfigChangeAddress: "erd1vxy22x0fj4zv6hktmydg8vpfh6euv02cz4yg0aaws6rrad5a5awqgqky80",
-			},
-			DelegationSystemSCConfig: config.DelegationSystemSCConfig{
-				MinServiceFee: 0,
-				MaxServiceFee: 100,
-			},
-		},
-		Version: "v1.0.0",
-		HistoryRepo: &dblookupext.HistoryRepositoryStub{},
-	}
-}
-
-// FillGasMapMetaChainSystemSCsCosts -
-func FillGasMapMetaChainSystemSCsCosts(value uint64) map[string]uint64 {
-	gasMap := make(map[string]uint64)
-	gasMap["Stake"] = value
-	gasMap["UnStake"] = value
-	gasMap["UnBond"] = value
-	gasMap["Claim"] = value
-	gasMap["Get"] = value
-	gasMap["ChangeRewardAddress"] = value
-	gasMap["ChangeValidatorKeys"] = value
-	gasMap["UnJail"] = value
-	gasMap["ESDTIssue"] = value
-	gasMap["ESDTOperations"] = value
-	gasMap["Proposal"] = value
-	gasMap["Vote"] = value
-	gasMap["DelegateVote"] = value
-	gasMap["RevokeVote"] = value
-	gasMap["CloseProposal"] = value
-	gasMap["DelegationOps"] = value
-	gasMap["UnStakeTokens"] = value
-	gasMap["UnBondTokens"] = value
-	gasMap["DelegationMgrOps"] = value
-	gasMap["GetAllNodeStates"] = value
-	gasMap["ValidatorToDelegation"] = value
-	gasMap["FixWaitingListSize"] = value
-
-	return gasMap
-}
-
-func TestProcessComponents_IndexGenesisBlocks(t *testing.T) {
-	t.Parallel()
-	if testing.Short() {
-		t.Skip("this is not a short test")
-	}
-
-	shardCoordinator := mock.NewMultiShardsCoordinatorMock(1)
-	processArgs := getProcessComponentsArgs(shardCoordinator)
-	processArgs.Data = &mock.DataComponentsMock{
-		Storage: &storageStubs.ChainStorerStub{},
-	}
-
-	saveBlockCalledMutex := sync.Mutex{}
-
-	outportHandler := &testscommon.OutportStub{
-		HasDriversCalled: func() bool {
-			return true
-		},
-		SaveBlockCalled: func(args *indexer.ArgsSaveBlockData) {
-			saveBlockCalledMutex.Lock()
-			require.NotNil(t, args)
-
-			bodyRequired := &dataBlock.Body{
-				MiniBlocks: make([]*block.MiniBlock, 4),
-			}
-
-			txsPoolRequired := &indexer.Pool{}
-
-			assert.Equal(t, txsPoolRequired, args.TransactionsPool)
-			assert.Equal(t, bodyRequired, args.Body)
-			saveBlockCalledMutex.Unlock()
-		},
-	}
-
-	processArgs.StatusComponents = &mainFactoryMocks.StatusComponentsStub{
-		Outport: outportHandler,
-	}
-
-	pcf, err := factory.NewProcessComponentsFactory(processArgs)
-	require.Nil(t, err)
-
-	genesisBlocks := make(map[uint32]coreData.HeaderHandler)
-	indexingData := make(map[uint32]*genesis.IndexingData)
-
-	for i := uint32(0); i < shardCoordinator.NumberOfShards(); i++ {
-		genesisBlocks[i] = &block.Header{}
-	}
-
-	err = pcf.IndexGenesisBlocks(genesisBlocks, indexingData)
-	require.Nil(t, err)
-}
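The fixtures deleted above are not lost: the renamed test packages fetch equivalent helpers from testscommon/components, as the hunks below show. A typical new-style call site, mirroring blockProcessorCreator_test.go below (the test name here is invented for illustration):

package processing_test

import (
	"testing"

	"github.com/ElrondNetwork/elrond-go/factory/mock"
	processComp "github.com/ElrondNetwork/elrond-go/factory/processing"
	componentsMock "github.com/ElrondNetwork/elrond-go/testscommon/components"
	"github.com/stretchr/testify/require"
)

func TestProcessComponentsFactory_FromSharedFixtures(t *testing.T) {
	// the shared fixture replaces the deleted getProcessComponentsArgs helper
	shardCoordinator := mock.NewMultiShardsCoordinatorMock(2)
	args := componentsMock.GetProcessComponentsFactoryArgs(shardCoordinator)

	pcf, err := processComp.NewProcessComponentsFactory(args)
	require.Nil(t, err)
	require.NotNil(t, pcf)
}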
diff --git a/factory/blockProcessorCreator.go b/factory/processing/blockProcessorCreator.go
similarity index 97%
rename from factory/blockProcessorCreator.go
rename to factory/processing/blockProcessorCreator.go
index d9fcfd9b7ca..d81ca45acdd 100644
--- a/factory/blockProcessorCreator.go
+++ b/factory/processing/blockProcessorCreator.go
@@ -1,4 +1,4 @@
-package factory
+package processing
 
 import (
 	"errors"
@@ -6,12 +6,15 @@ import (
 	"github.com/ElrondNetwork/elrond-go-core/core"
 	dataBlock "github.com/ElrondNetwork/elrond-go-core/data/block"
+	logger "github.com/ElrondNetwork/elrond-go-logger"
 	"github.com/ElrondNetwork/elrond-go/common"
 	"github.com/ElrondNetwork/elrond-go/config"
 	"github.com/ElrondNetwork/elrond-go/dataRetriever"
+	debugFactory "github.com/ElrondNetwork/elrond-go/debug/factory"
 	"github.com/ElrondNetwork/elrond-go/epochStart"
 	"github.com/ElrondNetwork/elrond-go/epochStart/bootstrap/disabled"
 	metachainEpochStart "github.com/ElrondNetwork/elrond-go/epochStart/metachain"
+	mainFactory "github.com/ElrondNetwork/elrond-go/factory"
 	"github.com/ElrondNetwork/elrond-go/genesis"
 	processDisabled "github.com/ElrondNetwork/elrond-go/genesis/process/disabled"
 	"github.com/ElrondNetwork/elrond-go/process"
@@ -56,7 +59,7 @@ func (pcf *processComponentsFactory) newBlockProcessor(
 	arwenChangeLocker common.Locker,
 	scheduledTxsExecutionHandler process.ScheduledTxsExecutionHandler,
 	processedMiniBlocksTracker process.ProcessedMiniBlocksTracker,
-	receiptsRepository ReceiptsRepository,
+	receiptsRepository mainFactory.ReceiptsRepository,
 ) (*blockProcessorAndVmFactories, error) {
 	if pcf.bootstrapComponents.ShardCoordinator().SelfId() < pcf.bootstrapComponents.ShardCoordinator().NumberOfShards() {
 		return pcf.newShardBlockProcessor(
@@ -95,6 +98,8 @@ func (pcf *processComponentsFactory) newBlockProcessor(
 	return nil, errors.New("could not create block processor")
 }
 
+var log = logger.GetOrCreate("factory")
+
 func (pcf *processComponentsFactory) newShardBlockProcessor(
 	requestHandler process.RequestHandler,
 	forkDetector process.ForkDetector,
@@ -107,7 +112,7 @@ func (pcf *processComponentsFactory) newShardBlockProcessor(
 	arwenChangeLocker common.Locker,
 	scheduledTxsExecutionHandler process.ScheduledTxsExecutionHandler,
 	processedMiniBlocksTracker process.ProcessedMiniBlocksTracker,
-	receiptsRepository ReceiptsRepository,
+	receiptsRepository mainFactory.ReceiptsRepository,
 ) (*blockProcessorAndVmFactories, error) {
 	argsParser := smartContract.NewArgumentParser()
 
@@ -412,6 +417,11 @@ func (pcf *processComponentsFactory) newShardBlockProcessor(
 		return nil, errors.New("could not create block statisticsProcessor: " + err.Error())
 	}
 
+	err = pcf.attachProcessDebugger(blockProcessor, pcf.config.Debug.Process)
+	if err != nil {
+		return nil, err
+	}
+
 	blockProcessorComponents := &blockProcessorAndVmFactories{
 		blockProcessor: blockProcessor,
 		vmFactoryForTxSimulate: vmFactoryTxSimulator,
@@ -434,7 +444,7 @@ func (pcf *processComponentsFactory) newMetaBlockProcessor(
 	arwenChangeLocker common.Locker,
 	scheduledTxsExecutionHandler process.ScheduledTxsExecutionHandler,
 	processedMiniBlocksTracker process.ProcessedMiniBlocksTracker,
-	receiptsRepository ReceiptsRepository,
+	receiptsRepository mainFactory.ReceiptsRepository,
 ) (*blockProcessorAndVmFactories, error) {
 	builtInFuncFactory, err := pcf.createBuiltInFunctionContainer(pcf.state.AccountsAdapter(), make(map[string]struct{}))
 	if err != nil {
@@ -854,6 +864,11 @@ func (pcf *processComponentsFactory) newMetaBlockProcessor(
 		return nil, errors.New("could not create block processor: " + err.Error())
 	}
 
+	err = pcf.attachProcessDebugger(metaProcessor, pcf.config.Debug.Process)
+	if err != nil {
+		return nil, err
+	}
+
 	blockProcessorComponents := &blockProcessorAndVmFactories{
 		blockProcessor: metaProcessor,
 		vmFactoryForTxSimulate: vmFactoryTxSimulator,
@@ -863,6 +878,18 @@ func (pcf *processComponentsFactory) newMetaBlockProcessor(
 	return blockProcessorComponents, nil
 }
 
+func (pcf *processComponentsFactory) attachProcessDebugger(
+	processor mainFactory.ProcessDebuggerSetter,
+	configs config.ProcessDebugConfig,
+) error {
+	processDebugger, err := debugFactory.CreateProcessDebugger(configs)
+	if err != nil {
+		return err
+	}
+
+	return processor.SetProcessDebugger(processDebugger)
+}
+
 func (pcf *processComponentsFactory) createShardTxSimulatorProcessor(
 	txSimulatorProcessorArgs *txsimulator.ArgsTxSimulator,
 	scProcArgs smartContract.ArgsNewSmartContractProcessor,
@@ -1162,7 +1189,7 @@ func (pcf *processComponentsFactory) createBuiltInFunctionContainer(
 	accounts state.AccountsAdapter,
 	mapDNSAddresses map[string]struct{},
 ) (vmcommon.BuiltInFunctionFactory, error) {
-	convertedAddresses, err := decodeAddresses(
+	convertedAddresses, err := mainFactory.DecodeAddresses(
 		pcf.coreData.AddressPubKeyConverter(),
 		pcf.config.BuiltInFunctions.AutomaticCrawlerAddresses,
 	)
@@ -1178,7 +1205,7 @@ func (pcf *processComponentsFactory) createBuiltInFunctionContainer(
 		ShardCoordinator: pcf.bootstrapComponents.ShardCoordinator(),
 		EpochNotifier: pcf.coreData.EpochNotifier(),
 		EnableEpochsHandler: pcf.coreData.EnableEpochsHandler(),
-		AutomaticCrawlerAddresses: convertedAddresses,
+		AutomaticCrawlerAddresses: convertedAddresses,
 		MaxNumNodesInTransferRole: pcf.config.BuiltInFunctions.MaxNumAddressesInTransferRole,
 	}
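attachProcessDebugger above only needs the narrow ProcessDebuggerSetter contract added to factory/interface.go, which keeps the debugger wiring decoupled from the concrete shard and meta processors. What such a debugger typically does is watch committed rounds and raise an alarm when progress stalls; a sketch of that idea under stated assumptions (the real product of debugFactory.CreateProcessDebugger is configured via config.ProcessDebugConfig and is not shown in this diff):

package example

import (
	"sync"
	"time"
)

// stallDebugger records the round of every committed block and reports the
// last known round if no commit has been seen within the timeout window.
// This is an assumed behaviour, not the repository's implementation.
type stallDebugger struct {
	mut       sync.Mutex
	lastRound uint64
	lastSeen  time.Time
	timeout   time.Duration
	alarm     func(round uint64)
}

// SetLastCommittedBlockRound is the hook a block processor would call on
// each successful commit.
func (d *stallDebugger) SetLastCommittedBlockRound(round uint64) {
	d.mut.Lock()
	defer d.mut.Unlock()

	d.lastRound = round
	d.lastSeen = time.Now()
}

// checkStall would run periodically in a background goroutine.
func (d *stallDebugger) checkStall() {
	d.mut.Lock()
	defer d.mut.Unlock()

	if time.Since(d.lastSeen) > d.timeout {
		d.alarm(d.lastRound)
	}
}

// Close satisfies the usual component lifecycle.
func (d *stallDebugger) Close() error {
	return nil
}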
diff --git a/factory/blockProcessorCreator_test.go b/factory/processing/blockProcessorCreator_test.go
similarity index 83%
rename from factory/blockProcessorCreator_test.go
rename to factory/processing/blockProcessorCreator_test.go
index d9e503fbfc8..a2a507a74f8 100644
--- a/factory/blockProcessorCreator_test.go
+++ b/factory/processing/blockProcessorCreator_test.go
@@ -1,4 +1,4 @@
-package factory_test
+package processing_test
 
 import (
 	"sync"
@@ -8,17 +8,20 @@ import (
 	"github.com/ElrondNetwork/elrond-go-core/hashing"
 	"github.com/ElrondNetwork/elrond-go-core/marshal"
 	"github.com/ElrondNetwork/elrond-go/common"
-	"github.com/ElrondNetwork/elrond-go/factory"
+	dataComp "github.com/ElrondNetwork/elrond-go/factory/data"
 	"github.com/ElrondNetwork/elrond-go/factory/mock"
+	processComp "github.com/ElrondNetwork/elrond-go/factory/processing"
 	"github.com/ElrondNetwork/elrond-go/process/txsimulator"
 	"github.com/ElrondNetwork/elrond-go/state"
 	factoryState "github.com/ElrondNetwork/elrond-go/state/factory"
 	"github.com/ElrondNetwork/elrond-go/state/storagePruningManager/disabled"
 	"github.com/ElrondNetwork/elrond-go/storage/txcache"
 	"github.com/ElrondNetwork/elrond-go/testscommon"
+	componentsMock "github.com/ElrondNetwork/elrond-go/testscommon/components"
 	"github.com/ElrondNetwork/elrond-go/testscommon/hashingMocks"
 	stateMock "github.com/ElrondNetwork/elrond-go/testscommon/state"
 	storageManager "github.com/ElrondNetwork/elrond-go/testscommon/storage"
+	trieMock "github.com/ElrondNetwork/elrond-go/testscommon/trie"
 	"github.com/ElrondNetwork/elrond-go/trie"
 	trieFactory "github.com/ElrondNetwork/elrond-go/trie/factory"
 	vmcommon "github.com/ElrondNetwork/elrond-vm-common"
@@ -32,7 +35,7 @@ func Test_newBlockProcessorCreatorForShard(t *testing.T) {
 	}
 
 	shardCoordinator := mock.NewMultiShardsCoordinatorMock(2)
-	pcf, _ := factory.NewProcessComponentsFactory(getProcessComponentsArgs(shardCoordinator))
+	pcf, _ := processComp.NewProcessComponentsFactory(componentsMock.GetProcessComponentsFactoryArgs(shardCoordinator))
 	require.NotNil(t, pcf)
 
 	_, err := pcf.Create()
@@ -67,7 +70,7 @@ func Test_newBlockProcessorCreatorForMeta(t *testing.T) {
 		t.Skip("this is not a short test")
 	}
 
-	coreComponents := getCoreComponents()
+	coreComponents := componentsMock.GetCoreComponents()
 	shardC := mock.NewMultiShardsCoordinatorMock(1)
 	shardC.SelfIDCalled = func() uint32 {
 		return core.MetachainShardId
@@ -81,13 +84,13 @@ func Test_newBlockProcessorCreatorForMeta(t *testing.T) {
 	}
 	shardC.CurrentShard = core.MetachainShardId
 
-	dataArgs := getDataArgs(coreComponents, shardC)
-	dataComponentsFactory, _ := factory.NewDataComponentsFactory(dataArgs)
-	dataComponents, _ := factory.NewManagedDataComponents(dataComponentsFactory)
+	dataArgs := componentsMock.GetDataArgs(coreComponents, shardC)
+	dataComponentsFactory, _ := dataComp.NewDataComponentsFactory(dataArgs)
+	dataComponents, _ := dataComp.NewManagedDataComponents(dataComponentsFactory)
 	_ = dataComponents.Create()
 
-	networkComponents := getNetworkComponents()
-	cryptoComponents := getCryptoComponents(coreComponents)
+	networkComponents := componentsMock.GetNetworkComponents()
+	cryptoComponents := componentsMock.GetCryptoComponents(coreComponents)
 
 	storageManagerArgs, options := storageManager.GetStorageManagerArgsAndOptions()
 	storageManagerArgs.Marshalizer = coreComponents.InternalMarshalizer()
@@ -134,13 +137,17 @@ func Test_newBlockProcessorCreatorForMeta(t *testing.T) {
 			return accounts
 		},
 		TriesContainerCalled: func() common.TriesHolder {
-			return &mock.TriesHolderStub{}
+			return &trieMock.TriesHolderStub{
+				GetCalled: func(bytes []byte) common.Trie {
+					return &trieMock.TrieStub{}
+				},
+			}
 		},
 		TrieStorageManagersCalled: func() map[string]common.StorageManager {
 			return trieStorageManagers
 		},
 	}
 
-	args := getProcessArgs(
+	args := componentsMock.GetProcessArgs(
 		shardC,
 		coreComponents,
 		dataComponents,
@@ -149,9 +156,9 @@ func Test_newBlockProcessorCreatorForMeta(t *testing.T) {
 		networkComponents,
 	)
 
-	factory.SetShardCoordinator(shardC, args.BootstrapComponents)
+	componentsMock.SetShardCoordinator(t, args.BootstrapComponents, shardC)
 
-	pcf, _ := factory.NewProcessComponentsFactory(args)
+	pcf, _ := processComp.NewProcessComponentsFactory(args)
 	require.NotNil(t, pcf)
 
 	_, err = pcf.Create()
diff --git a/factory/export_test.go b/factory/processing/export_test.go
similarity index 51%
rename from factory/export_test.go
rename to factory/processing/export_test.go
index 676fc2b115e..f2a57413fde 100644
--- a/factory/export_test.go
+++ b/factory/processing/export_test.go
@@ -1,60 +1,16 @@
-package factory
+package processing
 
 import (
 	"github.com/ElrondNetwork/elrond-go-core/core"
 	"github.com/ElrondNetwork/elrond-go-core/data"
-	"github.com/ElrondNetwork/elrond-go-core/hashing"
-	crypto "github.com/ElrondNetwork/elrond-go-crypto"
 	"github.com/ElrondNetwork/elrond-go/common"
 	"github.com/ElrondNetwork/elrond-go/epochStart"
+	"github.com/ElrondNetwork/elrond-go/factory"
 	"github.com/ElrondNetwork/elrond-go/genesis"
 	"github.com/ElrondNetwork/elrond-go/process"
 	"github.com/ElrondNetwork/elrond-go/process/txsimulator"
-	"github.com/ElrondNetwork/elrond-go/sharding"
 )
 
-// GetSkPk -
-func (ccf *cryptoComponentsFactory) GetSkPk() ([]byte, []byte, error) {
-	return ccf.getSkPk()
-}
-
-// CreateSingleSigner -
-func (ccf *cryptoComponentsFactory) CreateSingleSigner(importModeNoSigCheck bool) (crypto.SingleSigner, error) {
-	return ccf.createSingleSigner(importModeNoSigCheck)
-}
-
-// GetMultiSigHasherFromConfig -
-func (ccf *cryptoComponentsFactory) GetMultiSigHasherFromConfig() (hashing.Hasher, error) {
-	return ccf.getMultiSigHasherFromConfig()
-}
-
-// CreateDummyCryptoParams
-func (ccf *cryptoComponentsFactory) CreateDummyCryptoParams() *cryptoParams {
-	return &cryptoParams{}
-}
-
-// CreateCryptoParams -
-func (ccf *cryptoComponentsFactory) CreateCryptoParams(blockSignKeyGen crypto.KeyGenerator) (*cryptoParams, error) {
-	return ccf.createCryptoParams(blockSignKeyGen)
-}
-
-// CreateMultiSigner -
-func (ccf *cryptoComponentsFactory) CreateMultiSigner(
-	h hashing.Hasher, cp *cryptoParams, blSignKeyGen crypto.KeyGenerator, importModeNoSigCheck bool,
-) (crypto.MultiSigner, error) {
-	return ccf.createMultiSigner(h, cp, blSignKeyGen, importModeNoSigCheck)
-}
-
-// GetSuite -
-func (ccf *cryptoComponentsFactory) GetSuite() (crypto.Suite, error) {
-	return ccf.getSuite()
-}
-
-// SetListenAddress -
-func (ncf *networkComponentsFactory) SetListenAddress(address string) {
-	ncf.listenAddress = address
-}
-
 // NewBlockProcessor calls the unexported method with the same name in order to use it in tests
 func (pcf *processComponentsFactory) NewBlockProcessor(
 	requestHandler process.RequestHandler,
@@ -69,7 +25,7 @@ func (pcf *processComponentsFactory) NewBlockProcessor(
 	arwenChangeLocker common.Locker,
 	scheduledTxsExecutionHandler process.ScheduledTxsExecutionHandler,
 	processedMiniBlocksTracker process.ProcessedMiniBlocksTracker,
-	receiptsRepository ReceiptsRepository,
+	receiptsRepository factory.ReceiptsRepository,
 ) (process.BlockProcessor, process.VirtualMachinesContainerFactory, error) {
 	blockProcessorComponents, err := pcf.newBlockProcessor(
 		requestHandler,
@@ -93,16 +49,6 @@ func (pcf *processComponentsFactory) NewBlockProcessor(
 	return blockProcessorComponents.blockProcessor, blockProcessorComponents.vmFactoryForTxSimulate, nil
 }
 
-// SetShardCoordinator -
-func SetShardCoordinator(shardCoordinator sharding.Coordinator, holder BootstrapComponentsHolder) {
-	mbf := holder.(*managedBootstrapComponents)
-
-	mbf.mutBootstrapComponents.Lock()
-	defer mbf.mutBootstrapComponents.Unlock()
-
-	mbf.bootstrapComponents.shardCoordinator = shardCoordinator
-}
-
 // IndexGenesisBlocks -
 func (pcf *processComponentsFactory) IndexGenesisBlocks(genesisBlocks map[uint32]data.HeaderHandler, indexingData map[uint32]*genesis.IndexingData) error {
 	return pcf.indexGenesisBlocks(genesisBlocks, indexingData)
@@ -110,5 +56,5 @@ func (pcf *processComponentsFactory) IndexGenesisBlocks(genesisBlocks map[uint32
 
 // DecodeAddresses -
 func DecodeAddresses(pkConverter core.PubkeyConverter, automaticCrawlerAddressesStrings []string) ([][]byte, error) {
-	return decodeAddresses(pkConverter, automaticCrawlerAddressesStrings)
+	return factory.DecodeAddresses(pkConverter, automaticCrawlerAddressesStrings)
 }
diff --git a/factory/processComponents.go b/factory/processing/processComponents.go
similarity index 96%
rename from factory/processComponents.go
rename to factory/processing/processComponents.go
index 55279eec21d..ee9303614fb 100644
--- a/factory/processComponents.go
+++ b/factory/processing/processComponents.go
@@ -1,4 +1,4 @@
-package factory
+package processing
 
 import (
 	"context"
@@ -14,8 +14,7 @@ import (
 	"github.com/ElrondNetwork/elrond-go-core/data"
 	dataBlock "github.com/ElrondNetwork/elrond-go-core/data/block"
 	"github.com/ElrondNetwork/elrond-go-core/data/indexer"
-	logger "github.com/ElrondNetwork/elrond-go-logger"
-	"github.com/ElrondNetwork/elrond-go/cmd/node/factory"
+	nodeFactory "github.com/ElrondNetwork/elrond-go/cmd/node/factory"
 	"github.com/ElrondNetwork/elrond-go/common"
 	"github.com/ElrondNetwork/elrond-go/config"
 	"github.com/ElrondNetwork/elrond-go/consensus"
@@ -31,6 +30,8 @@ import (
 	"github.com/ElrondNetwork/elrond-go/epochStart/notifier"
 	"github.com/ElrondNetwork/elrond-go/epochStart/shardchain"
 	errErd "github.com/ElrondNetwork/elrond-go/errors"
+	"github.com/ElrondNetwork/elrond-go/factory"
+	mainFactory "github.com/ElrondNetwork/elrond-go/factory"
 	"github.com/ElrondNetwork/elrond-go/factory/disabled"
 	"github.com/ElrondNetwork/elrond-go/fallback"
 	"github.com/ElrondNetwork/elrond-go/genesis"
@@ -60,17 +61,15 @@ import (
 	"github.com/ElrondNetwork/elrond-go/sharding/nodesCoordinator"
 	"github.com/ElrondNetwork/elrond-go/state"
 	"github.com/ElrondNetwork/elrond-go/storage"
+	"github.com/ElrondNetwork/elrond-go/storage/cache"
 	storageFactory "github.com/ElrondNetwork/elrond-go/storage/factory"
-	"github.com/ElrondNetwork/elrond-go/storage/storageUnit"
-	"github.com/ElrondNetwork/elrond-go/storage/timecache"
+	"github.com/ElrondNetwork/elrond-go/storage/storageunit"
 	"github.com/ElrondNetwork/elrond-go/update"
 	updateDisabled "github.com/ElrondNetwork/elrond-go/update/disabled"
 	updateFactory "github.com/ElrondNetwork/elrond-go/update/factory"
 	"github.com/ElrondNetwork/elrond-go/update/trigger"
 )
 
-var log = logger.GetOrCreate("factory")
-
 // timeSpanForBadHeaders is the expiry time for an added block header hash
 var timeSpanForBadHeaders = time.Minute * 2
 
@@ -82,13 +81,13 @@ type processComponents struct {
 	resolversFinder dataRetriever.ResolversFinder
 	roundHandler consensus.RoundHandler
 	epochStartTrigger epochStart.TriggerHandler
-	epochStartNotifier EpochStartNotifier
+	epochStartNotifier factory.EpochStartNotifier
 	forkDetector process.ForkDetector
 	blockProcessor process.BlockProcessor
 	blackListHandler process.TimeCacher
 	bootStorer process.BootStorer
 	headerSigVerifier process.InterceptedHeaderSigVerifier
-	headerIntegrityVerifier factory.HeaderIntegrityVerifierHandler
+	headerIntegrityVerifier nodeFactory.HeaderIntegrityVerifierHandler
 	validatorsStatistics process.ValidatorStatisticsProcessor
 	validatorsProvider process.ValidatorsProvider
 	blockTracker process.BlockTracker
@@ -97,7 +96,7 @@ type processComponents struct {
 	txLogsProcessor process.TransactionLogProcessorDatabase
 	headerConstructionValidator process.HeaderConstructionValidator
 	peerShardMapper process.NetworkShardingCollector
-	txSimulatorProcessor TransactionSimulatorProcessor
+	txSimulatorProcessor factory.TransactionSimulatorProcessor
 	miniBlocksPoolCleaner process.PoolsCleaner
 	txsPoolCleaner process.PoolsCleaner
 	fallbackHeaderValidator process.FallbackHeaderValidator
@@ -113,10 +112,10 @@ type processComponents struct {
 	vmFactoryForProcessing process.VirtualMachinesContainerFactory
 	scheduledTxsExecutionHandler process.ScheduledTxsExecutionHandler
 	txsSender process.TxsSenderHandler
-	hardforkTrigger HardforkTrigger
+	hardforkTrigger factory.HardforkTrigger
 	processedMiniBlocksTracker process.ProcessedMiniBlocksTracker
 	accountsParser genesis.AccountsParser
-	receiptsRepository ReceiptsRepository
+	receiptsRepository mainFactory.ReceiptsRepository
 }
 
 // ProcessComponentsFactoryArgs holds the arguments needed to create a process components factory
@@ -139,13 +138,13 @@ type ProcessComponentsFactoryArgs struct {
 	WorkingDir string
 	HistoryRepo dblookupext.HistoryRepository
 
-	Data DataComponentsHolder
-	CoreData CoreComponentsHolder
-	Crypto CryptoComponentsHolder
-	State StateComponentsHolder
-	Network NetworkComponentsHolder
-	BootstrapComponents BootstrapComponentsHolder
-	StatusComponents StatusComponentsHolder
+	Data factory.DataComponentsHolder
+	CoreData factory.CoreComponentsHolder
+	Crypto factory.CryptoComponentsHolder
+	State factory.StateComponentsHolder
+	Network factory.NetworkComponentsHolder
+	BootstrapComponents factory.BootstrapComponentsHolder
+	StatusComponents factory.StatusComponentsHolder
 }
 
 type processComponentsFactory struct {
@@ -170,13 +169,13 @@ type processComponentsFactory struct {
 	epochNotifier process.EpochNotifier
 	importHandler update.ImportHandler
 
-	data DataComponentsHolder
-	coreData CoreComponentsHolder
-	crypto CryptoComponentsHolder
-	state StateComponentsHolder
-	network NetworkComponentsHolder
-	bootstrapComponents BootstrapComponentsHolder
-	statusComponents StatusComponentsHolder
+	data factory.DataComponentsHolder
+	coreData factory.CoreComponentsHolder
+	crypto factory.CryptoComponentsHolder
+	state factory.StateComponentsHolder
+	network factory.NetworkComponentsHolder
+	bootstrapComponents factory.BootstrapComponentsHolder
+	statusComponents factory.StatusComponentsHolder
 }
 
 // NewProcessComponentsFactory will return a new instance of processComponentsFactory
@@ -244,7 +243,7 @@ func (pcf *processComponentsFactory) Create() (*processComponents, error) {
 		Marshalizer: pcf.coreData.InternalMarshalizer(),
 		Hasher: pcf.coreData.Hasher(),
 		NodesCoordinator: pcf.nodesCoordinator,
-		MultiSigVerifier:
pcf.crypto.MultiSigner(), + MultiSigContainer: pcf.crypto.MultiSignerContainer(), SingleSigVerifier: pcf.crypto.BlockSigner(), KeyGen: pcf.crypto.BlockSignKeyGen(), FallbackHeaderValidator: fallbackHeaderValidator, @@ -517,7 +516,7 @@ func (pcf *processComponentsFactory) Create() (*processComponents, error) { } vmOutputCacherConfig := storageFactory.GetCacherFromConfig(pcf.config.VMOutputCacher) - vmOutputCacher, err := storageUnit.NewCache(vmOutputCacherConfig) + vmOutputCacher, err := storageunit.NewCache(vmOutputCacherConfig) if err != nil { return nil, err } @@ -1169,12 +1168,12 @@ func (pcf *processComponentsFactory) newMetaResolverContainerFactory( func (pcf *processComponentsFactory) newInterceptorContainerFactory( headerSigVerifier process.InterceptedHeaderSigVerifier, - headerIntegrityVerifier factory.HeaderIntegrityVerifierHandler, + headerIntegrityVerifier nodeFactory.HeaderIntegrityVerifierHandler, validityAttester process.ValidityAttester, epochStartTrigger process.EpochStartTriggerHandler, requestHandler process.RequestHandler, peerShardMapper *networksharding.PeerShardMapper, - hardforkTrigger HardforkTrigger, + hardforkTrigger factory.HardforkTrigger, ) (process.InterceptorsContainerFactory, process.TimeCacher, error) { if pcf.bootstrapComponents.ShardCoordinator().SelfId() < pcf.bootstrapComponents.ShardCoordinator().NumberOfShards() { return pcf.newShardInterceptorContainerFactory( @@ -1325,14 +1324,14 @@ func (pcf *processComponentsFactory) createStorageResolversForShard( func (pcf *processComponentsFactory) newShardInterceptorContainerFactory( headerSigVerifier process.InterceptedHeaderSigVerifier, - headerIntegrityVerifier factory.HeaderIntegrityVerifierHandler, + headerIntegrityVerifier nodeFactory.HeaderIntegrityVerifierHandler, validityAttester process.ValidityAttester, epochStartTrigger process.EpochStartTriggerHandler, requestHandler process.RequestHandler, peerShardMapper *networksharding.PeerShardMapper, - hardforkTrigger HardforkTrigger, + hardforkTrigger factory.HardforkTrigger, ) (process.InterceptorsContainerFactory, process.TimeCacher, error) { - headerBlackList := timecache.NewTimeCache(timeSpanForBadHeaders) + headerBlackList := cache.NewTimeCache(timeSpanForBadHeaders) shardInterceptorsContainerFactoryArgs := interceptorscontainer.CommonInterceptorsContainerFactoryArgs{ CoreComponents: pcf.coreData, CryptoComponents: pcf.crypto, @@ -1373,14 +1372,14 @@ func (pcf *processComponentsFactory) newShardInterceptorContainerFactory( func (pcf *processComponentsFactory) newMetaInterceptorContainerFactory( headerSigVerifier process.InterceptedHeaderSigVerifier, - headerIntegrityVerifier factory.HeaderIntegrityVerifierHandler, + headerIntegrityVerifier nodeFactory.HeaderIntegrityVerifierHandler, validityAttester process.ValidityAttester, epochStartTrigger process.EpochStartTriggerHandler, requestHandler process.RequestHandler, peerShardMapper *networksharding.PeerShardMapper, - hardforkTrigger HardforkTrigger, + hardforkTrigger factory.HardforkTrigger, ) (process.InterceptorsContainerFactory, process.TimeCacher, error) { - headerBlackList := timecache.NewTimeCache(timeSpanForBadHeaders) + headerBlackList := cache.NewTimeCache(timeSpanForBadHeaders) metaInterceptorsContainerFactoryArgs := interceptorscontainer.CommonInterceptorsContainerFactoryArgs{ CoreComponents: pcf.coreData, CryptoComponents: pcf.crypto, @@ -1509,7 +1508,7 @@ func (pcf *processComponentsFactory) createExportFactoryHandler( return updateFactory.NewExportHandlerFactory(argsExporter) } -func (pcf 
*processComponentsFactory) createHardforkTrigger(epochStartTrigger update.EpochHandler) (HardforkTrigger, error) { +func (pcf *processComponentsFactory) createHardforkTrigger(epochStartTrigger update.EpochHandler) (factory.HardforkTrigger, error) { hardforkConfig := pcf.config.Hardfork selfPubKeyBytes := pcf.crypto.PublicKeyBytes() triggerPubKeyBytes, err := pcf.coreData.ValidatorPubKeyConverter().Decode(hardforkConfig.PublicKeyToListenFrom) @@ -1538,7 +1537,7 @@ func (pcf *processComponentsFactory) createHardforkTrigger(epochStartTrigger upd func createNetworkShardingCollector( config *config.Config, nodesCoordinator nodesCoordinator.NodesCoordinator, - preferredPeersHolder PreferredPeersHolderHandler, + preferredPeersHolder factory.PreferredPeersHolderHandler, ) (*networksharding.PeerShardMapper, error) { cacheConfig := config.PublicKeyPeerId @@ -1575,7 +1574,7 @@ func createNetworkShardingCollector( } func createCache(cacheConfig config.CacheConfig) (storage.Cacher, error) { - return storageUnit.NewCache(storageFactory.GetCacherFromConfig(cacheConfig)) + return storageunit.NewCache(storageFactory.GetCacherFromConfig(cacheConfig)) } func checkProcessComponentsArgs(args ProcessComponentsFactoryArgs) error { diff --git a/factory/processComponentsHandler.go b/factory/processing/processComponentsHandler.go similarity index 96% rename from factory/processComponentsHandler.go rename to factory/processing/processComponentsHandler.go index f31fda36a68..2b9bb8d6080 100644 --- a/factory/processComponentsHandler.go +++ b/factory/processing/processComponentsHandler.go @@ -1,4 +1,4 @@ -package factory +package processing import ( "sync" @@ -9,6 +9,7 @@ import ( "github.com/ElrondNetwork/elrond-go/dblookupext" "github.com/ElrondNetwork/elrond-go/epochStart" "github.com/ElrondNetwork/elrond-go/errors" + "github.com/ElrondNetwork/elrond-go/factory" "github.com/ElrondNetwork/elrond-go/genesis" "github.com/ElrondNetwork/elrond-go/process" "github.com/ElrondNetwork/elrond-go/sharding" @@ -16,9 +17,9 @@ import ( "github.com/ElrondNetwork/elrond-go/update" ) -var _ ComponentHandler = (*managedProcessComponents)(nil) -var _ ProcessComponentsHolder = (*managedProcessComponents)(nil) -var _ ProcessComponentsHandler = (*managedProcessComponents)(nil) +var _ factory.ComponentHandler = (*managedProcessComponents)(nil) +var _ factory.ProcessComponentsHolder = (*managedProcessComponents)(nil) +var _ factory.ProcessComponentsHandler = (*managedProcessComponents)(nil) type managedProcessComponents struct { *processComponents @@ -235,7 +236,7 @@ func (m *managedProcessComponents) EpochStartTrigger() epochStart.TriggerHandler } // EpochStartNotifier returns the epoch start notifier -func (m *managedProcessComponents) EpochStartNotifier() EpochStartNotifier { +func (m *managedProcessComponents) EpochStartNotifier() factory.EpochStartNotifier { m.mutProcessComponents.RLock() defer m.mutProcessComponents.RUnlock() @@ -427,7 +428,7 @@ func (m *managedProcessComponents) FallbackHeaderValidator() process.FallbackHea } // TransactionSimulatorProcessor returns the transaction simulator processor -func (m *managedProcessComponents) TransactionSimulatorProcessor() TransactionSimulatorProcessor { +func (m *managedProcessComponents) TransactionSimulatorProcessor() factory.TransactionSimulatorProcessor { m.mutProcessComponents.RLock() defer m.mutProcessComponents.RUnlock() @@ -559,7 +560,7 @@ func (m *managedProcessComponents) TxsSenderHandler() process.TxsSenderHandler { } // HardforkTrigger returns the hardfork trigger -func (m 
*managedProcessComponents) HardforkTrigger() HardforkTrigger { +func (m *managedProcessComponents) HardforkTrigger() factory.HardforkTrigger { m.mutProcessComponents.RLock() defer m.mutProcessComponents.RUnlock() @@ -583,7 +584,7 @@ func (m *managedProcessComponents) ProcessedMiniBlocksTracker() process.Processe } // ReceiptsRepository returns the receipts repository -func (m *managedProcessComponents) ReceiptsRepository() ReceiptsRepository { +func (m *managedProcessComponents) ReceiptsRepository() factory.ReceiptsRepository { m.mutProcessComponents.RLock() defer m.mutProcessComponents.RUnlock() @@ -601,5 +602,5 @@ func (m *managedProcessComponents) IsInterfaceNil() bool { // String returns the name of the component func (m *managedProcessComponents) String() string { - return processComponentsName + return factory.ProcessComponentsName } diff --git a/factory/processComponentsHandler_test.go b/factory/processing/processComponentsHandler_test.go similarity index 83% rename from factory/processComponentsHandler_test.go rename to factory/processing/processComponentsHandler_test.go index 1f06e0ca35d..603dff6ac7b 100644 --- a/factory/processComponentsHandler_test.go +++ b/factory/processing/processComponentsHandler_test.go @@ -1,12 +1,13 @@ -package factory_test +package processing_test import ( "testing" "github.com/ElrondNetwork/elrond-go-core/core" "github.com/ElrondNetwork/elrond-go-core/core/check" - "github.com/ElrondNetwork/elrond-go/factory" "github.com/ElrondNetwork/elrond-go/factory/mock" + processComp "github.com/ElrondNetwork/elrond-go/factory/processing" + componentsMock "github.com/ElrondNetwork/elrond-go/testscommon/components" "github.com/stretchr/testify/require" ) @@ -18,10 +19,10 @@ func TestManagedProcessComponents_CreateWithInvalidArgsShouldErr(t *testing.T) { } shardCoordinator := mock.NewMultiShardsCoordinatorMock(2) - processArgs := getProcessComponentsArgs(shardCoordinator) + processArgs := componentsMock.GetProcessComponentsFactoryArgs(shardCoordinator) _ = processArgs.CoreData.SetInternalMarshalizer(nil) - processComponentsFactory, _ := factory.NewProcessComponentsFactory(processArgs) - managedProcessComponents, err := factory.NewManagedProcessComponents(processComponentsFactory) + processComponentsFactory, _ := processComp.NewProcessComponentsFactory(processArgs) + managedProcessComponents, err := processComp.NewManagedProcessComponents(processComponentsFactory) require.NoError(t, err) err = managedProcessComponents.Create() require.Error(t, err) @@ -34,7 +35,7 @@ func TestManagedProcessComponents_CreateShouldWork(t *testing.T) { t.Skip("this is not a short test") } - coreComponents := getCoreComponents() + coreComponents := componentsMock.GetCoreComponents() shardCoordinator := mock.NewMultiShardsCoordinatorMock(1) shardCoordinator.SelfIDCalled = func() uint32 { return core.MetachainShardId @@ -48,11 +49,11 @@ func TestManagedProcessComponents_CreateShouldWork(t *testing.T) { } shardCoordinator.CurrentShard = core.MetachainShardId - dataComponents := getDataComponents(coreComponents, shardCoordinator) - networkComponents := getNetworkComponents() - cryptoComponents := getCryptoComponents(coreComponents) - stateComponents := getStateComponents(coreComponents, shardCoordinator) - processArgs := getProcessArgs( + dataComponents := componentsMock.GetDataComponents(coreComponents, shardCoordinator) + networkComponents := componentsMock.GetNetworkComponents() + cryptoComponents := componentsMock.GetCryptoComponents(coreComponents) + stateComponents := 
componentsMock.GetStateComponents(coreComponents, shardCoordinator) + processArgs := componentsMock.GetProcessArgs( shardCoordinator, coreComponents, dataComponents, @@ -61,11 +62,11 @@ func TestManagedProcessComponents_CreateShouldWork(t *testing.T) { networkComponents, ) - factory.SetShardCoordinator(shardCoordinator, processArgs.BootstrapComponents) + componentsMock.SetShardCoordinator(t, processArgs.BootstrapComponents, shardCoordinator) - processComponentsFactory, err := factory.NewProcessComponentsFactory(processArgs) + processComponentsFactory, err := processComp.NewProcessComponentsFactory(processArgs) require.Nil(t, err) - managedProcessComponents, err := factory.NewManagedProcessComponents(processComponentsFactory) + managedProcessComponents, err := processComp.NewManagedProcessComponents(processComponentsFactory) require.NoError(t, err) require.True(t, check.IfNil(managedProcessComponents.NodesCoordinator())) require.True(t, check.IfNil(managedProcessComponents.InterceptorsContainer())) @@ -151,9 +152,9 @@ func TestManagedProcessComponents_Close(t *testing.T) { } shardCoordinator := mock.NewMultiShardsCoordinatorMock(2) - processArgs := getProcessComponentsArgs(shardCoordinator) - processComponentsFactory, _ := factory.NewProcessComponentsFactory(processArgs) - managedProcessComponents, _ := factory.NewManagedProcessComponents(processComponentsFactory) + processArgs := componentsMock.GetProcessComponentsFactoryArgs(shardCoordinator) + processComponentsFactory, _ := processComp.NewProcessComponentsFactory(processArgs) + managedProcessComponents, _ := processComp.NewManagedProcessComponents(processComponentsFactory) err := managedProcessComponents.Create() require.NoError(t, err) diff --git a/factory/processing/processComponents_test.go b/factory/processing/processComponents_test.go new file mode 100644 index 00000000000..ee3334f7554 --- /dev/null +++ b/factory/processing/processComponents_test.go @@ -0,0 +1,111 @@ +package processing_test + +import ( + "strings" + "sync" + "testing" + + coreData "github.com/ElrondNetwork/elrond-go-core/data" + "github.com/ElrondNetwork/elrond-go-core/data/block" + dataBlock "github.com/ElrondNetwork/elrond-go-core/data/block" + "github.com/ElrondNetwork/elrond-go-core/data/indexer" + "github.com/ElrondNetwork/elrond-go/factory/mock" + processComp "github.com/ElrondNetwork/elrond-go/factory/processing" + "github.com/ElrondNetwork/elrond-go/genesis" + "github.com/ElrondNetwork/elrond-go/process" + "github.com/ElrondNetwork/elrond-go/testscommon" + componentsMock "github.com/ElrondNetwork/elrond-go/testscommon/components" + "github.com/ElrondNetwork/elrond-go/testscommon/mainFactoryMocks" + storageStubs "github.com/ElrondNetwork/elrond-go/testscommon/storage" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// ------------ Test TestProcessComponents -------------------- +func TestProcessComponents_CloseShouldWork(t *testing.T) { + t.Parallel() + if testing.Short() { + t.Skip("this is not a short test") + } + + shardCoordinator := mock.NewMultiShardsCoordinatorMock(2) + processArgs := componentsMock.GetProcessComponentsFactoryArgs(shardCoordinator) + pcf, err := processComp.NewProcessComponentsFactory(processArgs) + require.Nil(t, err) + + pc, err := pcf.Create() + require.Nil(t, err) + + err = pc.Close() + require.NoError(t, err) +} + +func TestProcessComponentsFactory_CreateWithInvalidTxAccumulatorTimeExpectError(t *testing.T) { + t.Parallel() + if testing.Short() { + t.Skip("this is not a short test") + } + + 
shardCoordinator := mock.NewMultiShardsCoordinatorMock(2) + processArgs := componentsMock.GetProcessComponentsFactoryArgs(shardCoordinator) + processArgs.Config.Antiflood.TxAccumulator.MaxAllowedTimeInMilliseconds = 0 + pcf, err := processComp.NewProcessComponentsFactory(processArgs) + require.Nil(t, err) + + instance, err := pcf.Create() + require.Nil(t, instance) + require.Error(t, err) + require.True(t, strings.Contains(err.Error(), process.ErrInvalidValue.Error())) +} + +func TestProcessComponents_IndexGenesisBlocks(t *testing.T) { + t.Parallel() + if testing.Short() { + t.Skip("this is not a short test") + } + + shardCoordinator := mock.NewMultiShardsCoordinatorMock(1) + processArgs := componentsMock.GetProcessComponentsFactoryArgs(shardCoordinator) + processArgs.Data = &mock.DataComponentsMock{ + Storage: &storageStubs.ChainStorerStub{}, + } + + saveBlockCalledMutex := sync.Mutex{} + + outportHandler := &testscommon.OutportStub{ + HasDriversCalled: func() bool { + return true + }, + SaveBlockCalled: func(args *indexer.ArgsSaveBlockData) { + saveBlockCalledMutex.Lock() + require.NotNil(t, args) + + bodyRequired := &dataBlock.Body{ + MiniBlocks: make([]*block.MiniBlock, 4), + } + + txsPoolRequired := &indexer.Pool{} + + assert.Equal(t, txsPoolRequired, args.TransactionsPool) + assert.Equal(t, bodyRequired, args.Body) + saveBlockCalledMutex.Unlock() + }, + } + + processArgs.StatusComponents = &mainFactoryMocks.StatusComponentsStub{ + Outport: outportHandler, + } + + pcf, err := processComp.NewProcessComponentsFactory(processArgs) + require.Nil(t, err) + + genesisBlocks := make(map[uint32]coreData.HeaderHandler) + indexingData := make(map[uint32]*genesis.IndexingData) + + for i := uint32(0); i < shardCoordinator.NumberOfShards(); i++ { + genesisBlocks[i] = &block.Header{} + } + + err = pcf.IndexGenesisBlocks(genesisBlocks, indexingData) + require.Nil(t, err) +} diff --git a/factory/stateComponents.go b/factory/state/stateComponents.go similarity index 98% rename from factory/stateComponents.go rename to factory/state/stateComponents.go index 7798dd0b1fb..09606e5d6b3 100644 --- a/factory/stateComponents.go +++ b/factory/state/stateComponents.go @@ -1,4 +1,4 @@ -package factory +package state import ( "fmt" @@ -9,6 +9,7 @@ import ( "github.com/ElrondNetwork/elrond-go/config" "github.com/ElrondNetwork/elrond-go/dataRetriever" "github.com/ElrondNetwork/elrond-go/errors" + "github.com/ElrondNetwork/elrond-go/factory" "github.com/ElrondNetwork/elrond-go/sharding" "github.com/ElrondNetwork/elrond-go/state" factoryState "github.com/ElrondNetwork/elrond-go/state/factory" @@ -23,7 +24,7 @@ import ( type StateComponentsFactoryArgs struct { Config config.Config ShardCoordinator sharding.Coordinator - Core CoreComponentsHolder + Core factory.CoreComponentsHolder StorageService dataRetriever.StorageService ProcessingMode common.NodeProcessingMode ShouldSerializeSnapshots bool @@ -33,7 +34,7 @@ type StateComponentsFactoryArgs struct { type stateComponentsFactory struct { config config.Config shardCoordinator sharding.Coordinator - core CoreComponentsHolder + core factory.CoreComponentsHolder storageService dataRetriever.StorageService processingMode common.NodeProcessingMode shouldSerializeSnapshots bool diff --git a/factory/stateComponentsHandler.go b/factory/state/stateComponentsHandler.go similarity index 94% rename from factory/stateComponentsHandler.go rename to factory/state/stateComponentsHandler.go index 00cbc6f9378..1550322e0a5 100644 --- a/factory/stateComponentsHandler.go +++ 
b/factory/state/stateComponentsHandler.go @@ -1,4 +1,4 @@ -package factory +package state import ( "fmt" @@ -7,12 +7,13 @@ import ( "github.com/ElrondNetwork/elrond-go-core/core/check" "github.com/ElrondNetwork/elrond-go/common" "github.com/ElrondNetwork/elrond-go/errors" + "github.com/ElrondNetwork/elrond-go/factory" "github.com/ElrondNetwork/elrond-go/state" ) -var _ ComponentHandler = (*managedStateComponents)(nil) -var _ StateComponentsHolder = (*managedStateComponents)(nil) -var _ StateComponentsHandler = (*managedStateComponents)(nil) +var _ factory.ComponentHandler = (*managedStateComponents)(nil) +var _ factory.StateComponentsHolder = (*managedStateComponents)(nil) +var _ factory.StateComponentsHandler = (*managedStateComponents)(nil) type managedStateComponents struct { *stateComponents @@ -205,5 +206,5 @@ func (msc *managedStateComponents) IsInterfaceNil() bool { // String returns the name of the component func (msc *managedStateComponents) String() string { - return stateComponentsName + return factory.StateComponentsName } diff --git a/factory/stateComponentsHandler_test.go b/factory/state/stateComponentsHandler_test.go similarity index 63% rename from factory/stateComponentsHandler_test.go rename to factory/state/stateComponentsHandler_test.go index 32e91b0f978..37d0206d6e2 100644 --- a/factory/stateComponentsHandler_test.go +++ b/factory/state/stateComponentsHandler_test.go @@ -1,12 +1,14 @@ -package factory_test +package state_test import ( "testing" "github.com/ElrondNetwork/elrond-go/common" - "github.com/ElrondNetwork/elrond-go/factory" "github.com/ElrondNetwork/elrond-go/factory/mock" + stateComp "github.com/ElrondNetwork/elrond-go/factory/state" "github.com/ElrondNetwork/elrond-go/testscommon" + componentsMock "github.com/ElrondNetwork/elrond-go/testscommon/components" + trieMock "github.com/ElrondNetwork/elrond-go/testscommon/trie" "github.com/stretchr/testify/require" ) @@ -17,11 +19,11 @@ func TestManagedStateComponents_CreateWithInvalidArgsShouldErr(t *testing.T) { t.Skip("this is not a short test") } - coreComponents := getCoreComponents() + coreComponents := componentsMock.GetCoreComponents() shardCoordinator := mock.NewMultiShardsCoordinatorMock(2) - args := getStateArgs(coreComponents, shardCoordinator) - stateComponentsFactory, _ := factory.NewStateComponentsFactory(args) - managedStateComponents, err := factory.NewManagedStateComponents(stateComponentsFactory) + args := componentsMock.GetStateFactoryArgs(coreComponents, shardCoordinator) + stateComponentsFactory, _ := stateComp.NewStateComponentsFactory(args) + managedStateComponents, err := stateComp.NewManagedStateComponents(stateComponentsFactory) require.NoError(t, err) _ = args.Core.SetInternalMarshalizer(nil) err = managedStateComponents.Create() @@ -35,11 +37,11 @@ func TestManagedStateComponents_CreateShouldWork(t *testing.T) { t.Skip("this is not a short test") } - coreComponents := getCoreComponents() + coreComponents := componentsMock.GetCoreComponents() shardCoordinator := mock.NewMultiShardsCoordinatorMock(2) - args := getStateArgs(coreComponents, shardCoordinator) - stateComponentsFactory, _ := factory.NewStateComponentsFactory(args) - managedStateComponents, err := factory.NewManagedStateComponents(stateComponentsFactory) + args := componentsMock.GetStateFactoryArgs(coreComponents, shardCoordinator) + stateComponentsFactory, _ := stateComp.NewStateComponentsFactory(args) + managedStateComponents, err := stateComp.NewManagedStateComponents(stateComponentsFactory) require.NoError(t, err) 
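
A note on the `var _ factory.ComponentHandler = (*managedStateComponents)(nil)` lines retargeted above: this is Go's compile-time interface assertion idiom, so the build breaks if a managed component ever stops satisfying its factory interface, at zero runtime cost. A minimal, self-contained sketch of the idiom; the Greeter interface and englishGreeter type are hypothetical stand-ins, not types from this repository:

package main

import "fmt"

// Greeter stands in for an interface such as factory.StateComponentsHandler.
type Greeter interface {
	Greet() string
}

type englishGreeter struct{}

func (g *englishGreeter) Greet() string { return "hello" }

// Compile-time assertion: a typed nil pointer is assigned to the blank
// identifier, so nothing runs or allocates at runtime, but the compiler
// must prove that *englishGreeter implements Greeter.
var _ Greeter = (*englishGreeter)(nil)

func main() {
	var g Greeter = &englishGreeter{}
	fmt.Println(g.Greet())
}
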
require.Nil(t, managedStateComponents.AccountsAdapter()) require.Nil(t, managedStateComponents.PeerAccounts()) @@ -60,11 +62,11 @@ func TestManagedStateComponents_Close(t *testing.T) { t.Skip("this is not a short test") } - coreComponents := getCoreComponents() + coreComponents := componentsMock.GetCoreComponents() shardCoordinator := mock.NewMultiShardsCoordinatorMock(2) - args := getStateArgs(coreComponents, shardCoordinator) - stateComponentsFactory, _ := factory.NewStateComponentsFactory(args) - managedStateComponents, _ := factory.NewManagedStateComponents(stateComponentsFactory) + args := componentsMock.GetStateFactoryArgs(coreComponents, shardCoordinator) + stateComponentsFactory, _ := stateComp.NewStateComponentsFactory(args) + managedStateComponents, _ := stateComp.NewManagedStateComponents(stateComponentsFactory) err := managedStateComponents.Create() require.NoError(t, err) @@ -79,11 +81,11 @@ func TestManagedStateComponents_CheckSubcomponents(t *testing.T) { t.Skip("this is not a short test") } - coreComponents := getCoreComponents() + coreComponents := componentsMock.GetCoreComponents() shardCoordinator := mock.NewMultiShardsCoordinatorMock(2) - args := getStateArgs(coreComponents, shardCoordinator) - stateComponentsFactory, _ := factory.NewStateComponentsFactory(args) - managedStateComponents, _ := factory.NewManagedStateComponents(stateComponentsFactory) + args := componentsMock.GetStateFactoryArgs(coreComponents, shardCoordinator) + stateComponentsFactory, _ := stateComp.NewStateComponentsFactory(args) + managedStateComponents, _ := stateComp.NewManagedStateComponents(stateComponentsFactory) err := managedStateComponents.Create() require.NoError(t, err) @@ -97,15 +99,15 @@ func TestManagedStateComponents_Setters(t *testing.T) { t.Skip("this is not a short test") } - coreComponents := getCoreComponents() + coreComponents := componentsMock.GetCoreComponents() shardCoordinator := mock.NewMultiShardsCoordinatorMock(2) - args := getStateArgs(coreComponents, shardCoordinator) - stateComponentsFactory, _ := factory.NewStateComponentsFactory(args) - managedStateComponents, _ := factory.NewManagedStateComponents(stateComponentsFactory) + args := componentsMock.GetStateFactoryArgs(coreComponents, shardCoordinator) + stateComponentsFactory, _ := stateComp.NewStateComponentsFactory(args) + managedStateComponents, _ := stateComp.NewManagedStateComponents(stateComponentsFactory) err := managedStateComponents.Create() require.NoError(t, err) - triesContainer := &mock.TriesHolderStub{} + triesContainer := &trieMock.TriesHolderStub{} triesStorageManagers := map[string]common.StorageManager{"a": &testscommon.StorageManagerStub{}} err = managedStateComponents.SetTriesContainer(triesContainer) diff --git a/factory/state/stateComponents_test.go b/factory/state/stateComponents_test.go new file mode 100644 index 00000000000..1579456acb8 --- /dev/null +++ b/factory/state/stateComponents_test.go @@ -0,0 +1,93 @@ +package state_test + +import ( + "testing" + + "github.com/ElrondNetwork/elrond-go/errors" + "github.com/ElrondNetwork/elrond-go/factory/mock" + stateComp "github.com/ElrondNetwork/elrond-go/factory/state" + componentsMock "github.com/ElrondNetwork/elrond-go/testscommon/components" + "github.com/stretchr/testify/require" +) + +func TestNewStateComponentsFactory_NilShardCoordinatorShouldErr(t *testing.T) { + t.Parallel() + if testing.Short() { + t.Skip("this is not a short test") + } + + coreComponents := componentsMock.GetCoreComponents() + shardCoordinator := 
mock.NewMultiShardsCoordinatorMock(2) + args := componentsMock.GetStateFactoryArgs(coreComponents, shardCoordinator) + args.ShardCoordinator = nil + + scf, err := stateComp.NewStateComponentsFactory(args) + require.Nil(t, scf) + require.Equal(t, errors.ErrNilShardCoordinator, err) +} + +func TestNewStateComponentsFactory_NilCoreComponents(t *testing.T) { + t.Parallel() + if testing.Short() { + t.Skip("this is not a short test") + } + + coreComponents := componentsMock.GetCoreComponents() + shardCoordinator := mock.NewMultiShardsCoordinatorMock(2) + args := componentsMock.GetStateFactoryArgs(coreComponents, shardCoordinator) + args.Core = nil + + scf, err := stateComp.NewStateComponentsFactory(args) + require.Nil(t, scf) + require.Equal(t, errors.ErrNilCoreComponents, err) +} + +func TestNewStateComponentsFactory_ShouldWork(t *testing.T) { + t.Parallel() + if testing.Short() { + t.Skip("this is not a short test") + } + + coreComponents := componentsMock.GetCoreComponents() + shardCoordinator := mock.NewMultiShardsCoordinatorMock(2) + args := componentsMock.GetStateFactoryArgs(coreComponents, shardCoordinator) + + scf, err := stateComp.NewStateComponentsFactory(args) + require.NoError(t, err) + require.NotNil(t, scf) +} + +func TestStateComponentsFactory_CreateShouldWork(t *testing.T) { + t.Parallel() + if testing.Short() { + t.Skip("this is not a short test") + } + + coreComponents := componentsMock.GetCoreComponents() + shardCoordinator := mock.NewMultiShardsCoordinatorMock(2) + args := componentsMock.GetStateFactoryArgs(coreComponents, shardCoordinator) + + scf, _ := stateComp.NewStateComponentsFactory(args) + + res, err := scf.Create() + require.NoError(t, err) + require.NotNil(t, res) +} + +// ------------ Test StateComponents -------------------- +func TestStateComponents_CloseShouldWork(t *testing.T) { + t.Parallel() + if testing.Short() { + t.Skip("this is not a short test") + } + + coreComponents := componentsMock.GetCoreComponents() + shardCoordinator := mock.NewMultiShardsCoordinatorMock(2) + args := componentsMock.GetStateFactoryArgs(coreComponents, shardCoordinator) + scf, _ := stateComp.NewStateComponentsFactory(args) + + sc, _ := scf.Create() + + err := sc.Close() + require.NoError(t, err) +} diff --git a/factory/stateComponents_test.go b/factory/stateComponents_test.go deleted file mode 100644 index 8a19e9ddb5b..00000000000 --- a/factory/stateComponents_test.go +++ /dev/null @@ -1,303 +0,0 @@ -package factory_test - -import ( - "fmt" - "testing" - - "github.com/ElrondNetwork/elrond-go/common" - "github.com/ElrondNetwork/elrond-go/config" - "github.com/ElrondNetwork/elrond-go/epochStart/bootstrap/disabled" - "github.com/ElrondNetwork/elrond-go/errors" - "github.com/ElrondNetwork/elrond-go/factory" - "github.com/ElrondNetwork/elrond-go/factory/mock" - "github.com/ElrondNetwork/elrond-go/sharding" - "github.com/ElrondNetwork/elrond-go/state" - "github.com/ElrondNetwork/elrond-go/testscommon" - stateMock "github.com/ElrondNetwork/elrond-go/testscommon/storage" - "github.com/ElrondNetwork/elrond-go/trie" - trieFactory "github.com/ElrondNetwork/elrond-go/trie/factory" - "github.com/stretchr/testify/require" -) - -func TestNewStateComponentsFactory_NilShardCoordinatorShouldErr(t *testing.T) { - t.Parallel() - if testing.Short() { - t.Skip("this is not a short test") - } - - coreComponents := getCoreComponents() - shardCoordinator := mock.NewMultiShardsCoordinatorMock(2) - args := getStateArgs(coreComponents, shardCoordinator) - args.ShardCoordinator = nil - - scf, err := 
factory.NewStateComponentsFactory(args) - require.Nil(t, scf) - require.Equal(t, errors.ErrNilShardCoordinator, err) -} - -func TestNewStateComponentsFactory_NilCoreComponents(t *testing.T) { - t.Parallel() - if testing.Short() { - t.Skip("this is not a short test") - } - - coreComponents := getCoreComponents() - shardCoordinator := mock.NewMultiShardsCoordinatorMock(2) - args := getStateArgs(coreComponents, shardCoordinator) - args.Core = nil - - scf, err := factory.NewStateComponentsFactory(args) - require.Nil(t, scf) - require.Equal(t, errors.ErrNilCoreComponents, err) -} - -func TestNewStateComponentsFactory_ShouldWork(t *testing.T) { - t.Parallel() - if testing.Short() { - t.Skip("this is not a short test") - } - - coreComponents := getCoreComponents() - shardCoordinator := mock.NewMultiShardsCoordinatorMock(2) - args := getStateArgs(coreComponents, shardCoordinator) - - scf, err := factory.NewStateComponentsFactory(args) - require.NoError(t, err) - require.NotNil(t, scf) -} - -func TestStateComponentsFactory_CreateShouldWork(t *testing.T) { - t.Parallel() - if testing.Short() { - t.Skip("this is not a short test") - } - - coreComponents := getCoreComponents() - shardCoordinator := mock.NewMultiShardsCoordinatorMock(2) - args := getStateArgs(coreComponents, shardCoordinator) - - scf, _ := factory.NewStateComponentsFactory(args) - - res, err := scf.Create() - require.NoError(t, err) - require.NotNil(t, res) -} - -// ------------ Test StateComponents -------------------- -func TestStateComponents_CloseShouldWork(t *testing.T) { - t.Parallel() - if testing.Short() { - t.Skip("this is not a short test") - } - - coreComponents := getCoreComponents() - shardCoordinator := mock.NewMultiShardsCoordinatorMock(2) - args := getStateArgs(coreComponents, shardCoordinator) - scf, _ := factory.NewStateComponentsFactory(args) - - sc, _ := scf.Create() - - err := sc.Close() - require.NoError(t, err) -} - -func getStateArgs(coreComponents factory.CoreComponentsHolder, shardCoordinator sharding.Coordinator) factory.StateComponentsFactoryArgs { - memDBUsers := mock.NewMemDbMock() - memdbPeers := mock.NewMemDbMock() - generalConfig := getGeneralConfig() - - storageManagerArgs, options := stateMock.GetStorageManagerArgsAndOptions() - storageManagerArgs.Marshalizer = coreComponents.InternalMarshalizer() - storageManagerArgs.Hasher = coreComponents.Hasher() - storageManagerArgs.MainStorer = memDBUsers - storageManagerArgs.CheckpointsStorer = memDBUsers - storageManagerArgs.GeneralConfig = generalConfig.TrieStorageManagerConfig - options.PruningEnabled = generalConfig.StateTriesConfig.AccountsStatePruningEnabled - options.SnapshotsEnabled = generalConfig.StateTriesConfig.SnapshotsEnabled - options.CheckpointsEnabled = generalConfig.StateTriesConfig.CheckpointsEnabled - - storageManagerUser, _ := trie.CreateTrieStorageManager(storageManagerArgs, options) - - storageManagerArgs.MainStorer = memdbPeers - storageManagerArgs.CheckpointsStorer = memdbPeers - options.PruningEnabled = generalConfig.StateTriesConfig.PeerStatePruningEnabled - storageManagerPeer, _ := trie.CreateTrieStorageManager(storageManagerArgs, options) - - trieStorageManagers := make(map[string]common.StorageManager) - trieStorageManagers[trieFactory.UserAccountTrie] = storageManagerUser - trieStorageManagers[trieFactory.PeerAccountTrie] = storageManagerPeer - - triesHolder := state.NewDataTriesHolder() - trieUsers, _ := trie.NewTrie(storageManagerUser, coreComponents.InternalMarshalizer(), coreComponents.Hasher(), 5) - triePeers, _ := 
trie.NewTrie(storageManagerPeer, coreComponents.InternalMarshalizer(), coreComponents.Hasher(), 5) - triesHolder.Put([]byte(trieFactory.UserAccountTrie), trieUsers) - triesHolder.Put([]byte(trieFactory.PeerAccountTrie), triePeers) - - stateComponentsFactoryArgs := factory.StateComponentsFactoryArgs{ - Config: generalConfig, - ShardCoordinator: shardCoordinator, - Core: coreComponents, - StorageService: disabled.NewChainStorer(), - ProcessingMode: common.Normal, - ChainHandler: &testscommon.ChainHandlerStub{}, - } - - return stateComponentsFactoryArgs -} - -func getGeneralConfig() config.Config { - return config.Config{ - AddressPubkeyConverter: config.PubkeyConfig{ - Length: 32, - Type: "hex", - SignatureLength: 0, - }, - ValidatorPubkeyConverter: config.PubkeyConfig{ - Length: 96, - Type: "hex", - SignatureLength: 0, - }, - StateTriesConfig: config.StateTriesConfig{ - CheckpointRoundsModulus: 5, - SnapshotsEnabled: true, - AccountsStatePruningEnabled: true, - PeerStatePruningEnabled: true, - MaxStateTrieLevelInMemory: 5, - MaxPeerTrieLevelInMemory: 5, - }, - EvictionWaitingList: config.EvictionWaitingListConfig{ - HashesSize: 100, - RootHashesSize: 100, - DB: config.DBConfig{ - FilePath: "EvictionWaitingList", - Type: "MemoryDB", - BatchDelaySeconds: 30, - MaxBatchSize: 6, - MaxOpenFiles: 10, - }, - }, - AccountsTrieStorage: config.StorageConfig{ - Cache: config.CacheConfig{ - Capacity: 10000, - Type: "LRU", - Shards: 1, - }, - DB: config.DBConfig{ - FilePath: "AccountsTrie/MainDB", - Type: "MemoryDB", - BatchDelaySeconds: 30, - MaxBatchSize: 6, - MaxOpenFiles: 10, - }, - }, - AccountsTrieCheckpointsStorage: config.StorageConfig{ - Cache: config.CacheConfig{ - Capacity: 10000, - Type: "LRU", - Shards: 1, - }, - DB: config.DBConfig{ - FilePath: "AccountsTrieCheckpoints", - Type: "MemoryDB", - BatchDelaySeconds: 30, - MaxBatchSize: 6, - MaxOpenFiles: 10, - }, - }, - PeerAccountsTrieStorage: config.StorageConfig{ - Cache: config.CacheConfig{ - Capacity: 10000, - Type: "LRU", - Shards: 1, - }, - DB: config.DBConfig{ - FilePath: "PeerAccountsTrie/MainDB", - Type: "MemoryDB", - BatchDelaySeconds: 30, - MaxBatchSize: 6, - MaxOpenFiles: 10, - }, - }, - PeerAccountsTrieCheckpointsStorage: config.StorageConfig{ - Cache: config.CacheConfig{ - Capacity: 10000, - Type: "LRU", - Shards: 1, - }, - DB: config.DBConfig{ - FilePath: "PeerAccountsTrieCheckpoints", - Type: "MemoryDB", - BatchDelaySeconds: 30, - MaxBatchSize: 6, - MaxOpenFiles: 10, - }, - }, - TrieStorageManagerConfig: config.TrieStorageManagerConfig{ - PruningBufferLen: 1000, - SnapshotsBufferLen: 10, - SnapshotsGoroutineNum: 1, - }, - VirtualMachine: config.VirtualMachineServicesConfig{ - Querying: config.QueryVirtualMachineConfig{ - NumConcurrentVMs: 1, - VirtualMachineConfig: config.VirtualMachineConfig{ - ArwenVersions: []config.ArwenVersionByEpoch{ - {StartEpoch: 0, Version: "v0.3"}, - }, - }, - }, - Execution: config.VirtualMachineConfig{ - ArwenVersions: []config.ArwenVersionByEpoch{ - {StartEpoch: 0, Version: "v0.3"}, - }, - }, - GasConfig: config.VirtualMachineGasConfig{ - ShardMaxGasPerVmQuery: 1_500_000_000, - MetaMaxGasPerVmQuery: 0, - }, - }, - SmartContractsStorageForSCQuery: config.StorageConfig{ - Cache: config.CacheConfig{ - Capacity: 10000, - Type: "LRU", - Shards: 1, - }, - }, - SmartContractDataPool: config.CacheConfig{ - Capacity: 10000, - Type: "LRU", - Shards: 1, - }, - PeersRatingConfig: config.PeersRatingConfig{ - TopRatedCacheCapacity: 1000, - BadRatedCacheCapacity: 1000, - }, - PoolsCleanersConfig: 
config.PoolsCleanersConfig{ - MaxRoundsToKeepUnprocessedMiniBlocks: 50, - MaxRoundsToKeepUnprocessedTransactions: 50, - }, - BuiltInFunctions: config.BuiltInFunctionsConfig{ - AutomaticCrawlerAddresses: []string{ - "erd1he8wwxn4az3j82p7wwqsdk794dm7hcrwny6f8dfegkfla34udx7qrf7xje", //shard 0 - "erd1fpkcgel4gcmh8zqqdt043yfcn5tyx8373kg6q2qmkxzu4dqamc0swts65c", //shard 1 - "erd1najnxxweyw6plhg8efql330nttrj6l5cf87wqsuym85s9ha0hmdqnqgenp", //shard 2 - }, - MaxNumAddressesInTransferRole: 100, - }, - } -} - -func getCoreComponents() factory.CoreComponentsHolder { - coreArgs := getCoreArgs() - coreComponentsFactory, _ := factory.NewCoreComponentsFactory(coreArgs) - coreComponents, err := factory.NewManagedCoreComponents(coreComponentsFactory) - if err != nil { - fmt.Println("getCoreComponents NewManagedCoreComponents", "error", err.Error()) - return nil - } - err = coreComponents.Create() - if err != nil { - fmt.Println("getCoreComponents Create", "error", err.Error()) - } - return coreComponents -} diff --git a/factory/statusComponents.go b/factory/status/statusComponents.go similarity index 87% rename from factory/statusComponents.go rename to factory/status/statusComponents.go index 60408c6e801..a782e393e4e 100644 --- a/factory/statusComponents.go +++ b/factory/status/statusComponents.go @@ -1,4 +1,4 @@ -package factory +package status import ( "context" @@ -9,13 +9,15 @@ import ( "github.com/ElrondNetwork/elrond-go-core/core" "github.com/ElrondNetwork/elrond-go-core/core/check" nodeData "github.com/ElrondNetwork/elrond-go-core/data" + logger "github.com/ElrondNetwork/elrond-go-logger" "github.com/ElrondNetwork/elrond-go/common" "github.com/ElrondNetwork/elrond-go/common/statistics" - "github.com/ElrondNetwork/elrond-go/common/statistics/softwareVersion/factory" + swVersionFactory "github.com/ElrondNetwork/elrond-go/common/statistics/softwareVersion/factory" "github.com/ElrondNetwork/elrond-go/config" "github.com/ElrondNetwork/elrond-go/epochStart" "github.com/ElrondNetwork/elrond-go/epochStart/notifier" "github.com/ElrondNetwork/elrond-go/errors" + "github.com/ElrondNetwork/elrond-go/factory" "github.com/ElrondNetwork/elrond-go/outport" outportDriverFactory "github.com/ElrondNetwork/elrond-go/outport/factory" "github.com/ElrondNetwork/elrond-go/process" @@ -42,11 +44,11 @@ type StatusComponentsFactoryArgs struct { EconomicsConfig config.EconomicsConfig ShardCoordinator sharding.Coordinator NodesCoordinator nodesCoordinator.NodesCoordinator - EpochStartNotifier EpochStartNotifier - CoreComponents CoreComponentsHolder - DataComponents DataComponentsHolder - NetworkComponents NetworkComponentsHolder - StateComponents StateComponentsHolder + EpochStartNotifier factory.EpochStartNotifier + CoreComponents factory.CoreComponentsHolder + DataComponents factory.DataComponentsHolder + NetworkComponents factory.NetworkComponentsHolder + StateComponents factory.StateComponentsHolder IsInImportMode bool } @@ -56,15 +58,17 @@ type statusComponentsFactory struct { economicsConfig config.EconomicsConfig shardCoordinator sharding.Coordinator nodesCoordinator nodesCoordinator.NodesCoordinator - epochStartNotifier EpochStartNotifier + epochStartNotifier factory.EpochStartNotifier forkDetector process.ForkDetector - coreComponents CoreComponentsHolder - dataComponents DataComponentsHolder - networkComponents NetworkComponentsHolder - stateComponents StateComponentsHolder + coreComponents factory.CoreComponentsHolder + dataComponents factory.DataComponentsHolder + networkComponents factory.NetworkComponentsHolder + 
stateComponents factory.StateComponentsHolder isInImportMode bool } +var log = logger.GetOrCreate("factory") + // NewStatusComponentsFactory will return a status components factory func NewStatusComponentsFactory(args StatusComponentsFactoryArgs) (*statusComponentsFactory, error) { if check.IfNil(args.CoreComponents) { @@ -123,7 +127,7 @@ func (scf *statusComponentsFactory) Create() (*statusComponents, error) { } log.Trace("creating software checker structure") - softwareVersionCheckerFactory, err := factory.NewSoftwareVersionFactory( + softwareVersionCheckerFactory, err := swVersionFactory.NewSoftwareVersionFactory( scf.coreComponents.StatusHandler(), scf.config.SoftwareVersionConfig, ) @@ -243,14 +247,15 @@ func (scf *statusComponentsFactory) makeElasticIndexerArgs() *indexerFactory.Arg func (scf *statusComponentsFactory) makeEventNotifierArgs() *outportDriverFactory.EventNotifierFactoryArgs { eventNotifierConfig := scf.externalConfig.EventNotifierConnector return &outportDriverFactory.EventNotifierFactoryArgs{ - Enabled: eventNotifierConfig.Enabled, - UseAuthorization: eventNotifierConfig.UseAuthorization, - ProxyUrl: eventNotifierConfig.ProxyUrl, - Username: eventNotifierConfig.Username, - Password: eventNotifierConfig.Password, - Marshaller: scf.coreComponents.InternalMarshalizer(), - Hasher: scf.coreComponents.Hasher(), - PubKeyConverter: scf.coreComponents.AddressPubKeyConverter(), + Enabled: eventNotifierConfig.Enabled, + UseAuthorization: eventNotifierConfig.UseAuthorization, + ProxyUrl: eventNotifierConfig.ProxyUrl, + Username: eventNotifierConfig.Username, + Password: eventNotifierConfig.Password, + RequestTimeoutSec: eventNotifierConfig.RequestTimeoutSec, + Marshaller: scf.coreComponents.InternalMarshalizer(), + Hasher: scf.coreComponents.Hasher(), + PubKeyConverter: scf.coreComponents.AddressPubKeyConverter(), } } diff --git a/factory/statusComponentsHandler.go b/factory/status/statusComponentsHandler.go similarity index 97% rename from factory/statusComponentsHandler.go rename to factory/status/statusComponentsHandler.go index ccf6ae773ba..ea7919a1ff3 100644 --- a/factory/statusComponentsHandler.go +++ b/factory/status/statusComponentsHandler.go @@ -1,4 +1,4 @@ -package factory +package status import ( "context" @@ -18,6 +18,7 @@ import ( "github.com/ElrondNetwork/elrond-go/debug/goroutine" "github.com/ElrondNetwork/elrond-go/epochStart/notifier" "github.com/ElrondNetwork/elrond-go/errors" + "github.com/ElrondNetwork/elrond-go/factory" "github.com/ElrondNetwork/elrond-go/outport" "github.com/ElrondNetwork/elrond-go/p2p" "github.com/ElrondNetwork/elrond-go/process" @@ -25,9 +26,9 @@ import ( "github.com/ElrondNetwork/elrond-go/sharding/nodesCoordinator" ) -var _ ComponentHandler = (*managedStatusComponents)(nil) -var _ StatusComponentsHolder = (*managedStatusComponents)(nil) -var _ StatusComponentsHandler = (*managedStatusComponents)(nil) +var _ factory.ComponentHandler = (*managedStatusComponents)(nil) +var _ factory.StatusComponentsHolder = (*managedStatusComponents)(nil) +var _ factory.StatusComponentsHandler = (*managedStatusComponents)(nil) type managedStatusComponents struct { *statusComponents @@ -197,7 +198,7 @@ func (msc *managedStatusComponents) startStatusPolling(ctx context.Context) erro func registerPollConnectedPeers( appStatusPollingHandler *appStatusPolling.AppStatusPolling, - networkComponents NetworkComponentsHolder, + networkComponents factory.NetworkComponentsHolder, ) error { p2pMetricsHandlerFunc := func(appStatusHandler core.AppStatusHandler) { @@ -420,7 
+421,7 @@ func registerCpuStatistics(ctx context.Context, appStatusPollingHandler *appStat // String returns the name of the component func (msc *managedStatusComponents) String() string { - return statusComponentsName + return factory.StatusComponentsName } func (msc *managedStatusComponents) attachEpochGoRoutineAnalyser() { diff --git a/factory/statusComponentsHandler_test.go b/factory/status/statusComponentsHandler_test.go similarity index 57% rename from factory/statusComponentsHandler_test.go rename to factory/status/statusComponentsHandler_test.go index d144ca9ceb5..844dcce236d 100644 --- a/factory/statusComponentsHandler_test.go +++ b/factory/status/statusComponentsHandler_test.go @@ -1,10 +1,11 @@ -package factory_test +package status_test import ( "testing" - "github.com/ElrondNetwork/elrond-go/factory" "github.com/ElrondNetwork/elrond-go/factory/mock" + statusComp "github.com/ElrondNetwork/elrond-go/factory/status" + componentsMock "github.com/ElrondNetwork/elrond-go/testscommon/components" "github.com/stretchr/testify/require" ) @@ -16,12 +17,12 @@ func TestManagedStatusComponents_CreateWithInvalidArgsShouldErr(t *testing.T) { } shardCoordinator := mock.NewMultiShardsCoordinatorMock(2) - statusArgs, _ := getStatusComponentsFactoryArgsAndProcessComponents(shardCoordinator) - coreComponents := getDefaultCoreComponents() + statusArgs, _ := componentsMock.GetStatusComponentsFactoryArgsAndProcessComponents(shardCoordinator) + coreComponents := componentsMock.GetDefaultCoreComponents() statusArgs.CoreComponents = coreComponents - statusComponentsFactory, _ := factory.NewStatusComponentsFactory(statusArgs) - managedStatusComponents, err := factory.NewManagedStatusComponents(statusComponentsFactory) + statusComponentsFactory, _ := statusComp.NewStatusComponentsFactory(statusArgs) + managedStatusComponents, err := statusComp.NewManagedStatusComponents(statusComponentsFactory) require.NoError(t, err) coreComponents.AppStatusHdl = nil @@ -36,9 +37,9 @@ func TestManagedStatusComponents_CreateShouldWork(t *testing.T) { } shardCoordinator := mock.NewMultiShardsCoordinatorMock(2) - statusArgs, _ := getStatusComponentsFactoryArgsAndProcessComponents(shardCoordinator) - statusComponentsFactory, _ := factory.NewStatusComponentsFactory(statusArgs) - managedStatusComponents, err := factory.NewManagedStatusComponents(statusComponentsFactory) + statusArgs, _ := componentsMock.GetStatusComponentsFactoryArgsAndProcessComponents(shardCoordinator) + statusComponentsFactory, _ := statusComp.NewStatusComponentsFactory(statusArgs) + managedStatusComponents, err := statusComp.NewManagedStatusComponents(statusComponentsFactory) require.NoError(t, err) require.Nil(t, managedStatusComponents.OutportHandler()) require.Nil(t, managedStatusComponents.SoftwareVersionChecker()) @@ -56,9 +57,9 @@ func TestManagedStatusComponents_Close(t *testing.T) { } shardCoordinator := mock.NewMultiShardsCoordinatorMock(2) - statusArgs, _ := getStatusComponentsFactoryArgsAndProcessComponents(shardCoordinator) - statusComponentsFactory, _ := factory.NewStatusComponentsFactory(statusArgs) - managedStatusComponents, _ := factory.NewManagedStatusComponents(statusComponentsFactory) + statusArgs, _ := componentsMock.GetStatusComponentsFactoryArgsAndProcessComponents(shardCoordinator) + statusComponentsFactory, _ := statusComp.NewStatusComponentsFactory(statusArgs) + managedStatusComponents, _ := statusComp.NewManagedStatusComponents(statusComponentsFactory) err := managedStatusComponents.Create() require.NoError(t, err) @@ -73,9 +74,9 
@@ func TestManagedStatusComponents_CheckSubcomponents(t *testing.T) { } shardCoordinator := mock.NewMultiShardsCoordinatorMock(2) - statusArgs, _ := getStatusComponentsFactoryArgsAndProcessComponents(shardCoordinator) - statusComponentsFactory, _ := factory.NewStatusComponentsFactory(statusArgs) - managedStatusComponents, _ := factory.NewManagedStatusComponents(statusComponentsFactory) + statusArgs, _ := componentsMock.GetStatusComponentsFactoryArgsAndProcessComponents(shardCoordinator) + statusComponentsFactory, _ := statusComp.NewStatusComponentsFactory(statusArgs) + managedStatusComponents, _ := statusComp.NewManagedStatusComponents(statusComponentsFactory) err := managedStatusComponents.Create() require.NoError(t, err) diff --git a/factory/status/statusComponents_test.go b/factory/status/statusComponents_test.go new file mode 100644 index 00000000000..f0dbca1c0bc --- /dev/null +++ b/factory/status/statusComponents_test.go @@ -0,0 +1,174 @@ +package status_test + +import ( + "testing" + + "github.com/ElrondNetwork/elrond-go-core/core/check" + "github.com/ElrondNetwork/elrond-go/config" + "github.com/ElrondNetwork/elrond-go/errors" + coreComp "github.com/ElrondNetwork/elrond-go/factory/core" + "github.com/ElrondNetwork/elrond-go/factory/mock" + statusComp "github.com/ElrondNetwork/elrond-go/factory/status" + "github.com/ElrondNetwork/elrond-go/testscommon" + componentsMock "github.com/ElrondNetwork/elrond-go/testscommon/components" + "github.com/ElrondNetwork/elrond-go/testscommon/shardingMocks" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewStatusComponentsFactory_NilCoreComponentsShouldErr(t *testing.T) { + t.Parallel() + if testing.Short() { + t.Skip("this is not a short test") + } + + shardCoordinator := mock.NewMultiShardsCoordinatorMock(2) + args, _ := componentsMock.GetStatusComponentsFactoryArgsAndProcessComponents(shardCoordinator) + args.CoreComponents = nil + scf, err := statusComp.NewStatusComponentsFactory(args) + assert.True(t, check.IfNil(scf)) + assert.Equal(t, errors.ErrNilCoreComponentsHolder, err) +} + +func TestNewStatusComponentsFactory_NilNodesCoordinatorShouldErr(t *testing.T) { + t.Parallel() + if testing.Short() { + t.Skip("this is not a short test") + } + + shardCoordinator := mock.NewMultiShardsCoordinatorMock(2) + args, _ := componentsMock.GetStatusComponentsFactoryArgsAndProcessComponents(shardCoordinator) + args.NodesCoordinator = nil + scf, err := statusComp.NewStatusComponentsFactory(args) + assert.True(t, check.IfNil(scf)) + assert.Equal(t, errors.ErrNilNodesCoordinator, err) +} + +func TestNewStatusComponentsFactory_NilEpochStartNotifierShouldErr(t *testing.T) { + t.Parallel() + if testing.Short() { + t.Skip("this is not a short test") + } + + shardCoordinator := mock.NewMultiShardsCoordinatorMock(2) + args, _ := componentsMock.GetStatusComponentsFactoryArgsAndProcessComponents(shardCoordinator) + args.EpochStartNotifier = nil + scf, err := statusComp.NewStatusComponentsFactory(args) + assert.True(t, check.IfNil(scf)) + assert.Equal(t, errors.ErrNilEpochStartNotifier, err) +} + +func TestNewStatusComponentsFactory_NilNetworkComponentsShouldErr(t *testing.T) { + t.Parallel() + if testing.Short() { + t.Skip("this is not a short test") + } + + shardCoordinator := mock.NewMultiShardsCoordinatorMock(2) + args, _ := componentsMock.GetStatusComponentsFactoryArgsAndProcessComponents(shardCoordinator) + args.NetworkComponents = nil + scf, err := statusComp.NewStatusComponentsFactory(args) + assert.True(t, 
check.IfNil(scf))
+	assert.Equal(t, errors.ErrNilNetworkComponentsHolder, err)
+}
+
+func TestNewStatusComponentsFactory_NilShardCoordinatorShouldErr(t *testing.T) {
+	t.Parallel()
+	if testing.Short() {
+		t.Skip("this is not a short test")
+	}
+
+	shardCoordinator := mock.NewMultiShardsCoordinatorMock(2)
+	args, _ := componentsMock.GetStatusComponentsFactoryArgsAndProcessComponents(shardCoordinator)
+	args.ShardCoordinator = nil
+	scf, err := statusComp.NewStatusComponentsFactory(args)
+	assert.True(t, check.IfNil(scf))
+	assert.Equal(t, errors.ErrNilShardCoordinator, err)
+}
+
+func TestNewStatusComponents_InvalidRoundDurationShouldErr(t *testing.T) {
+	t.Parallel()
+	if testing.Short() {
+		t.Skip("this is not a short test")
+	}
+
+	shardCoordinator := mock.NewMultiShardsCoordinatorMock(2)
+	coreArgs := componentsMock.GetCoreArgs()
+	coreArgs.NodesFilename = "../mock/testdata/nodesSetupMockInvalidRound.json"
+	coreComponentsFactory, _ := coreComp.NewCoreComponentsFactory(coreArgs)
+	coreComponents, err := coreComp.NewManagedCoreComponents(coreComponentsFactory)
+	require.Nil(t, err)
+	require.NotNil(t, coreComponents)
+	err = coreComponents.Create()
+	require.Nil(t, err)
+	networkComponents := componentsMock.GetNetworkComponents()
+	dataComponents := componentsMock.GetDataComponents(coreComponents, shardCoordinator)
+	stateComponents := componentsMock.GetStateComponents(coreComponents, shardCoordinator)
+
+	statusArgs := statusComp.StatusComponentsFactoryArgs{
+		Config:             testscommon.GetGeneralConfig(),
+		ExternalConfig:     config.ExternalConfig{},
+		ShardCoordinator:   shardCoordinator,
+		NodesCoordinator:   &shardingMocks.NodesCoordinatorMock{},
+		EpochStartNotifier: &mock.EpochStartNotifierStub{},
+		CoreComponents:     coreComponents,
+		DataComponents:     dataComponents,
+		NetworkComponents:  networkComponents,
+		StateComponents:    stateComponents,
+		IsInImportMode:     false,
+		EconomicsConfig:    config.EconomicsConfig{},
+	}
+	scf, err := statusComp.NewStatusComponentsFactory(statusArgs)
+	assert.Nil(t, err)
+	assert.NotNil(t, scf)
+
+	statusComponents, err := scf.Create()
+	assert.Nil(t, statusComponents)
+	assert.Equal(t, errors.ErrInvalidRoundDuration, err)
+}
+
+func TestNewStatusComponentsFactory_ShouldWork(t *testing.T) {
+	t.Parallel()
+	if testing.Short() {
+		t.Skip("this is not a short test")
+	}
+
+	shardCoordinator := mock.NewMultiShardsCoordinatorMock(2)
+	args, _ := componentsMock.GetStatusComponentsFactoryArgsAndProcessComponents(shardCoordinator)
+	scf, err := statusComp.NewStatusComponentsFactory(args)
+	require.NoError(t, err)
+	require.False(t, check.IfNil(scf))
+}
+
+func TestStatusComponentsFactory_Create(t *testing.T) {
+	t.Parallel()
+	if testing.Short() {
+		t.Skip("this is not a short test")
+	}
+
+	shardCoordinator := mock.NewMultiShardsCoordinatorMock(2)
+	args, _ := componentsMock.GetStatusComponentsFactoryArgsAndProcessComponents(shardCoordinator)
+	scf, err := statusComp.NewStatusComponentsFactory(args)
+	require.Nil(t, err)
+
+	res, err := scf.Create()
+	require.NoError(t, err)
+	require.NotNil(t, res)
+}
+
+// ------------ Test StatusComponents --------------------
+func TestStatusComponents_CloseShouldWork(t *testing.T) {
+	t.Parallel()
+	if testing.Short() {
+		t.Skip("this is not a short test")
+	}
+
+	shardCoordinator := mock.NewMultiShardsCoordinatorMock(2)
+	statusArgs, _ := componentsMock.GetStatusComponentsFactoryArgsAndProcessComponents(shardCoordinator)
+	scf, _ := statusComp.NewStatusComponentsFactory(statusArgs)
+	cc, err := scf.Create()
+	require.Nil(t, err)
+
+	err = cc.Close()
+	require.NoError(t, err)
+}
diff --git a/factory/statusComponents_test.go b/factory/statusComponents_test.go
deleted file mode 100644
index c3676b2e5e2..00000000000
--- a/factory/statusComponents_test.go
+++ /dev/null
@@ -1,351 +0,0 @@
-package factory_test
-
-import (
-	"testing"
-
-	"github.com/ElrondNetwork/elrond-go-core/core/check"
-	logger "github.com/ElrondNetwork/elrond-go-logger"
-	"github.com/ElrondNetwork/elrond-go/config"
-	"github.com/ElrondNetwork/elrond-go/errors"
-	"github.com/ElrondNetwork/elrond-go/factory"
-	"github.com/ElrondNetwork/elrond-go/factory/mock"
-	"github.com/ElrondNetwork/elrond-go/sharding"
-	"github.com/ElrondNetwork/elrond-go/sharding/nodesCoordinator"
-	"github.com/ElrondNetwork/elrond-go/testscommon"
-	"github.com/ElrondNetwork/elrond-go/testscommon/shardingMocks"
-	"github.com/stretchr/testify/assert"
-	"github.com/stretchr/testify/require"
-)
-
-var log = logger.GetOrCreate("factory/factory_test")
-
-func TestNewStatusComponentsFactory_NilCoreComponentsShouldErr(t *testing.T) {
-	t.Parallel()
-	if testing.Short() {
-		t.Skip("this is not a short test")
-	}
-
-	shardCoordinator := mock.NewMultiShardsCoordinatorMock(2)
-	args, _ := getStatusComponentsFactoryArgsAndProcessComponents(shardCoordinator)
-	args.CoreComponents = nil
-	scf, err := factory.NewStatusComponentsFactory(args)
-	assert.True(t, check.IfNil(scf))
-	assert.Equal(t, errors.ErrNilCoreComponentsHolder, err)
-}
-
-func TestNewStatusComponentsFactory_NilNodesCoordinatorShouldErr(t *testing.T) {
-	t.Parallel()
-	if testing.Short() {
-		t.Skip("this is not a short test")
-	}
-
-	shardCoordinator := mock.NewMultiShardsCoordinatorMock(2)
-	args, _ := getStatusComponentsFactoryArgsAndProcessComponents(shardCoordinator)
-	args.NodesCoordinator = nil
-	scf, err := factory.NewStatusComponentsFactory(args)
-	assert.True(t, check.IfNil(scf))
-	assert.Equal(t, errors.ErrNilNodesCoordinator, err)
-}
-
-func TestNewStatusComponentsFactory_NilEpochStartNotifierShouldErr(t *testing.T) {
-	t.Parallel()
-	if testing.Short() {
-		t.Skip("this is not a short test")
-	}
-
-	shardCoordinator := mock.NewMultiShardsCoordinatorMock(2)
-	args, _ := getStatusComponentsFactoryArgsAndProcessComponents(shardCoordinator)
-	args.EpochStartNotifier = nil
-	scf, err := factory.NewStatusComponentsFactory(args)
-	assert.True(t, check.IfNil(scf))
-	assert.Equal(t, errors.ErrNilEpochStartNotifier, err)
-}
-
-func TestNewStatusComponentsFactory_NilNetworkComponentsShouldErr(t *testing.T) {
-	t.Parallel()
-	if testing.Short() {
-		t.Skip("this is not a short test")
-	}
-
-	shardCoordinator := mock.NewMultiShardsCoordinatorMock(2)
-	args, _ := getStatusComponentsFactoryArgsAndProcessComponents(shardCoordinator)
-	args.NetworkComponents = nil
-	scf, err := factory.NewStatusComponentsFactory(args)
-	assert.True(t, check.IfNil(scf))
-	assert.Equal(t, errors.ErrNilNetworkComponentsHolder, err)
-}
-
-func TestNewStatusComponentsFactory_NilShardCoordinatorShouldErr(t *testing.T) {
-	t.Parallel()
-	if testing.Short() {
-		t.Skip("this is not a short test")
-	}
-
-	shardCoordinator := mock.NewMultiShardsCoordinatorMock(2)
-	args, _ := getStatusComponentsFactoryArgsAndProcessComponents(shardCoordinator)
-	args.ShardCoordinator = nil
-	scf, err := factory.NewStatusComponentsFactory(args)
-	assert.True(t, check.IfNil(scf))
-	assert.Equal(t, errors.ErrNilShardCoordinator, err)
-}
-
-func TestNewStatusComponents_InvalidRoundDurationShouldErr(t *testing.T) {
-	t.Parallel()
-	if testing.Short() {
-		t.Skip("this is not a short test")
-	}
-
-	shardCoordinator := mock.NewMultiShardsCoordinatorMock(2)
-	coreArgs := getCoreArgs()
-	coreArgs.NodesFilename = "mock/testdata/nodesSetupMockInvalidRound.json"
-	coreComponentsFactory, _ := factory.NewCoreComponentsFactory(coreArgs)
-	coreComponents, err := factory.NewManagedCoreComponents(coreComponentsFactory)
-	require.Nil(t, err)
-	require.NotNil(t, coreComponents)
-	err = coreComponents.Create()
-	require.Nil(t, err)
-	networkComponents := getNetworkComponents()
-	dataComponents := getDataComponents(coreComponents, shardCoordinator)
-	stateComponents := getStateComponents(coreComponents, shardCoordinator)
-
-	statusArgs := factory.StatusComponentsFactoryArgs{
-		Config:             testscommon.GetGeneralConfig(),
-		ExternalConfig:     config.ExternalConfig{},
-		ShardCoordinator:   shardCoordinator,
-		NodesCoordinator:   &shardingMocks.NodesCoordinatorMock{},
-		EpochStartNotifier: &mock.EpochStartNotifierStub{},
-		CoreComponents:     coreComponents,
-		DataComponents:     dataComponents,
-		NetworkComponents:  networkComponents,
-		StateComponents:    stateComponents,
-		IsInImportMode:     false,
-		EconomicsConfig:    config.EconomicsConfig{},
-	}
-	scf, err := factory.NewStatusComponentsFactory(statusArgs)
-	assert.Nil(t, err)
-	assert.NotNil(t, scf)
-
-	statusComponents, err := scf.Create()
-	assert.Nil(t, statusComponents)
-	assert.Equal(t, errors.ErrInvalidRoundDuration, err)
-}
-
-func TestNewStatusComponentsFactory_ShouldWork(t *testing.T) {
-	t.Parallel()
-	if testing.Short() {
-		t.Skip("this is not a short test")
-	}
-
-	shardCoordinator := mock.NewMultiShardsCoordinatorMock(2)
-	args, _ := getStatusComponentsFactoryArgsAndProcessComponents(shardCoordinator)
-	scf, err := factory.NewStatusComponentsFactory(args)
-	require.NoError(t, err)
-	require.False(t, check.IfNil(scf))
-}
-
-func TestStatusComponentsFactory_Create(t *testing.T) {
-	t.Parallel()
-	if testing.Short() {
-		t.Skip("this is not a short test")
-	}
-
-	shardCoordinator := mock.NewMultiShardsCoordinatorMock(2)
-	args, _ := getStatusComponentsFactoryArgsAndProcessComponents(shardCoordinator)
-	scf, err := factory.NewStatusComponentsFactory(args)
-	require.Nil(t, err)
-
-	res, err := scf.Create()
-	require.NoError(t, err)
-	require.NotNil(t, res)
-}
-
-// ------------ Test StatusComponents --------------------
-func TestStatusComponents_CloseShouldWork(t *testing.T) {
-	t.Parallel()
-	if testing.Short() {
-		t.Skip("this is not a short test")
-	}
-
-	shardCoordinator := mock.NewMultiShardsCoordinatorMock(2)
-	statusArgs, _ := getStatusComponentsFactoryArgsAndProcessComponents(shardCoordinator)
-	scf, _ := factory.NewStatusComponentsFactory(statusArgs)
-	cc, err := scf.Create()
-	require.Nil(t, err)
-
-	err = cc.Close()
-	require.NoError(t, err)
-}
-
-func getStatusComponents(
-	coreComponents factory.CoreComponentsHolder,
-	networkComponents factory.NetworkComponentsHolder,
-	dataComponents factory.DataComponentsHolder,
-	stateComponents factory.StateComponentsHolder,
-	shardCoordinator sharding.Coordinator,
-	nodesCoordinator nodesCoordinator.NodesCoordinator,
-) factory.StatusComponentsHandler {
-	indexerURL := "url"
-	elasticUsername := "user"
-	elasticPassword := "pass"
-	statusArgs := factory.StatusComponentsFactoryArgs{
-		Config: testscommon.GetGeneralConfig(),
-		ExternalConfig: config.ExternalConfig{
-			ElasticSearchConnector: config.ElasticSearchConfig{
-				Enabled:        false,
-				URL:            indexerURL,
-				Username:       elasticUsername,
-				Password:       elasticPassword,
-				EnabledIndexes: []string{"transactions", "blocks"},
-			},
-		},
-		EconomicsConfig:    config.EconomicsConfig{},
-		ShardCoordinator:   shardCoordinator,
-		NodesCoordinator:   nodesCoordinator,
-		EpochStartNotifier: coreComponents.EpochStartNotifierWithConfirm(),
-		CoreComponents:     coreComponents,
-		DataComponents:     dataComponents,
-		NetworkComponents:  networkComponents,
-		StateComponents:    stateComponents,
-		IsInImportMode:     false,
-	}
-
-	statusComponentsFactory, _ := factory.NewStatusComponentsFactory(statusArgs)
-	managedStatusComponents, err := factory.NewManagedStatusComponents(statusComponentsFactory)
-	if err != nil {
-		log.Error("getStatusComponents NewManagedStatusComponents", "error", err.Error())
-		return nil
-	}
-	err = managedStatusComponents.Create()
-	if err != nil {
-		log.Error("getStatusComponents Create", "error", err.Error())
-		return nil
-	}
-	return managedStatusComponents
-}
-
-func getStatusComponentsFactoryArgsAndProcessComponents(shardCoordinator sharding.Coordinator) (factory.StatusComponentsFactoryArgs, factory.ProcessComponentsHolder) {
-	coreComponents := getCoreComponents()
-	networkComponents := getNetworkComponents()
-	dataComponents := getDataComponents(coreComponents, shardCoordinator)
-	cryptoComponents := getCryptoComponents(coreComponents)
-	stateComponents := getStateComponents(coreComponents, shardCoordinator)
-	processComponents := getProcessComponents(
-		shardCoordinator,
-		coreComponents,
-		networkComponents,
-		dataComponents,
-		cryptoComponents,
-		stateComponents,
-	)
-
-	indexerURL := "url"
-	elasticUsername := "user"
-	elasticPassword := "pass"
-	return factory.StatusComponentsFactoryArgs{
-		Config: testscommon.GetGeneralConfig(),
-		ExternalConfig: config.ExternalConfig{
-			ElasticSearchConnector: config.ElasticSearchConfig{
-				Enabled:        false,
-				URL:            indexerURL,
-				Username:       elasticUsername,
-				Password:       elasticPassword,
-				EnabledIndexes: []string{"transactions", "blocks"},
-			},
-		},
-		EconomicsConfig:    config.EconomicsConfig{},
-		ShardCoordinator:   mock.NewMultiShardsCoordinatorMock(2),
-		NodesCoordinator:   &shardingMocks.NodesCoordinatorMock{},
-		EpochStartNotifier: &mock.EpochStartNotifierStub{},
-		CoreComponents:     coreComponents,
-		DataComponents:     dataComponents,
-		NetworkComponents:  networkComponents,
-		StateComponents:    stateComponents,
-		IsInImportMode:     false,
-	}, processComponents
-}
-
-func getNetworkComponents() factory.NetworkComponentsHolder {
-	networkArgs := getNetworkArgs()
-	networkComponentsFactory, _ := factory.NewNetworkComponentsFactory(networkArgs)
-	networkComponents, _ := factory.NewManagedNetworkComponents(networkComponentsFactory)
-
-	_ = networkComponents.Create()
-
-	return networkComponents
-}
-
-func getDataComponents(coreComponents factory.CoreComponentsHolder, shardCoordinator sharding.Coordinator) factory.DataComponentsHolder {
-	dataArgs := getDataArgs(coreComponents, shardCoordinator)
-	dataComponentsFactory, _ := factory.NewDataComponentsFactory(dataArgs)
-	dataComponents, _ := factory.NewManagedDataComponents(dataComponentsFactory)
-	_ = dataComponents.Create()
-	return dataComponents
-}
-
-func getCryptoComponents(coreComponents factory.CoreComponentsHolder) factory.CryptoComponentsHolder {
-	cryptoArgs := getCryptoArgs(coreComponents)
-	cryptoComponentsFactory, _ := factory.NewCryptoComponentsFactory(cryptoArgs)
-	cryptoComponents, err := factory.NewManagedCryptoComponents(cryptoComponentsFactory)
-	if err != nil {
-		log.Error("getCryptoComponents NewManagedCryptoComponents", "error", err.Error())
-		return nil
-	}
-
-	err = cryptoComponents.Create()
-	if err != nil {
-		log.Error("getCryptoComponents Create", "error", err.Error())
-		return nil
-	}
-	return cryptoComponents
-}
-
-func getStateComponents(coreComponents factory.CoreComponentsHolder, shardCoordinator sharding.Coordinator) factory.StateComponentsHolder {
-	stateArgs := getStateArgs(coreComponents, shardCoordinator)
-	stateComponentsFactory, err := factory.NewStateComponentsFactory(stateArgs)
-	if err != nil {
-		log.Error("getStateComponents NewStateComponentsFactory", "error", err.Error())
-		return nil
-	}
-
-	stateComponents, err := factory.NewManagedStateComponents(stateComponentsFactory)
-	if err != nil {
-		log.Error("getStateComponents NewManagedStateComponents", "error", err.Error())
-		return nil
-	}
-	err = stateComponents.Create()
-	if err != nil {
-		log.Error("getStateComponents Create", "error", err.Error())
-		return nil
-	}
-	return stateComponents
-}
-
-func getProcessComponents(
-	shardCoordinator sharding.Coordinator,
-	coreComponents factory.CoreComponentsHolder,
-	networkComponents factory.NetworkComponentsHolder,
-	dataComponents factory.DataComponentsHolder,
-	cryptoComponents factory.CryptoComponentsHolder,
-	stateComponents factory.StateComponentsHolder,
-) factory.ProcessComponentsHolder {
-	processArgs := getProcessArgs(
-		shardCoordinator,
-		coreComponents,
-		dataComponents,
-		cryptoComponents,
-		stateComponents,
-		networkComponents,
-	)
-	processComponentsFactory, _ := factory.NewProcessComponentsFactory(processArgs)
-	managedProcessComponents, err := factory.NewManagedProcessComponents(processComponentsFactory)
-	if err != nil {
-		log.Error("getProcessComponents NewManagedProcessComponents", "error", err.Error())
-		return nil
-	}
-	err = managedProcessComponents.Create()
-	if err != nil {
-		log.Error("getProcessComponents Create", "error", err.Error())
-		return nil
-	}
-	return managedProcessComponents
-}
diff --git a/genesis/mock/userAccountMock.go b/genesis/mock/userAccountMock.go
index 90d1aa77a33..f2060c98e00 100644
--- a/genesis/mock/userAccountMock.go
+++ b/genesis/mock/userAccountMock.go
@@ -5,7 +5,6 @@ import (
 	"math/big"
 
 	"github.com/ElrondNetwork/elrond-go/common"
-	"github.com/ElrondNetwork/elrond-go/state"
 )
 
 // ErrNegativeValue -
@@ -47,7 +46,7 @@ func (uam *UserAccountMock) SetCodeHash(bytes []byte) {
 }
 
 // GetCodeHash -
-func (uam UserAccountMock) GetCodeHash() []byte {
+func (uam *UserAccountMock) GetCodeHash() []byte {
 	return uam.codeHash
 }
 
@@ -66,17 +65,17 @@ func (uam *UserAccountMock) SetDataTrie(_ common.Trie) {
 }
 
 // DataTrie -
-func (uam *UserAccountMock) DataTrie() common.Trie {
+func (uam *UserAccountMock) DataTrie() common.DataTrieHandler {
 	return nil
}
 
-// RetrieveValueFromDataTrieTracker -
-func (uam *UserAccountMock) RetrieveValueFromDataTrieTracker(_ []byte) ([]byte, error) {
+// RetrieveValue -
+func (uam *UserAccountMock) RetrieveValue(_ []byte) ([]byte, error) {
 	return nil, nil
 }
 
-// DataTrieTracker -
-func (uam *UserAccountMock) DataTrieTracker() state.DataTrieTracker {
+// SaveKeyValue -
+func (uam *UserAccountMock) SaveKeyValue(_ []byte, _ []byte) error {
 	return nil
 }
 
@@ -139,3 +138,8 @@ func (uam *UserAccountMock) SetUserName(_ []byte) {
 func (uam *UserAccountMock) GetUserName() []byte {
 	return nil
 }
+
+// SaveDirtyData -
+func (uam *UserAccountMock) SaveDirtyData(_ common.Trie) (map[string][]byte, error) {
+	return nil, nil
+}
"github.com/ElrondNetwork/elrond-go/statusHandler" "github.com/ElrondNetwork/elrond-go/storage" "github.com/ElrondNetwork/elrond-go/storage/factory" - "github.com/ElrondNetwork/elrond-go/storage/storageUnit" + "github.com/ElrondNetwork/elrond-go/storage/storageunit" triesFactory "github.com/ElrondNetwork/elrond-go/trie/factory" "github.com/ElrondNetwork/elrond-go/update" + hardfork "github.com/ElrondNetwork/elrond-go/update/genesis" hardForkProcess "github.com/ElrondNetwork/elrond-go/update/process" "github.com/ElrondNetwork/elrond-go/update/storing" vmcommonBuiltInFunctions "github.com/ElrondNetwork/elrond-vm-common/builtInFunctions" - - hardfork "github.com/ElrondNetwork/elrond-go/update/genesis" ) const accountStartNonce = uint64(0) @@ -128,7 +127,7 @@ func (gbc *genesisBlockCreator) createHardForkImportHandler() error { func createStorer(storageConfig config.StorageConfig, folder string) (storage.Storer, error) { dbConfig := factory.GetDBFromConfig(storageConfig.DB) dbConfig.FilePath = path.Join(folder, storageConfig.DB.FilePath) - store, err := storageUnit.NewStorageUnitFromConf( + store, err := storageunit.NewStorageUnitFromConf( factory.GetCacherFromConfig(storageConfig.Cache), dbConfig, ) diff --git a/go.mod b/go.mod index 81e62c96fab..54a0387c0cd 100644 --- a/go.mod +++ b/go.mod @@ -6,41 +6,31 @@ require ( github.com/ElrondNetwork/arwen-wasm-vm/v1_2 v1.2.42-0.20220729115258-b9f2fb2f6568 github.com/ElrondNetwork/arwen-wasm-vm/v1_3 v1.3.42-0.20220729115131-85ecca868e90 github.com/ElrondNetwork/arwen-wasm-vm/v1_4 v1.4.59-0.20220729115431-a6c93119bdda - github.com/ElrondNetwork/concurrent-map v0.1.3 github.com/ElrondNetwork/covalent-indexer-go v1.0.6 - github.com/ElrondNetwork/elastic-indexer-go v1.2.39 - github.com/ElrondNetwork/elrond-go-core v1.1.19 - github.com/ElrondNetwork/elrond-go-crypto v1.0.1 + github.com/ElrondNetwork/elastic-indexer-go v1.2.42 + github.com/ElrondNetwork/elrond-go-core v1.1.20 + github.com/ElrondNetwork/elrond-go-crypto v1.2.1 github.com/ElrondNetwork/elrond-go-logger v1.0.7 - github.com/ElrondNetwork/elrond-vm-common v1.3.16 - github.com/ElrondNetwork/go-libp2p-pubsub v0.6.1-rc1 + github.com/ElrondNetwork/elrond-go-p2p v1.0.1 + github.com/ElrondNetwork/elrond-go-storage v1.0.1 + github.com/ElrondNetwork/elrond-vm-common v1.3.17 github.com/beevik/ntp v0.3.0 - github.com/btcsuite/btcd v0.22.0-beta github.com/davecgh/go-spew v1.1.1 github.com/elastic/go-elasticsearch/v7 v7.12.0 github.com/gin-contrib/cors v0.0.0-20190301062745-f9e10995c85a - github.com/gin-contrib/pprof v1.3.0 - github.com/gin-gonic/gin v1.8.0 + github.com/gin-contrib/pprof v1.4.0 + github.com/gin-gonic/gin v1.8.1 github.com/gizak/termui/v3 v3.1.0 github.com/gogo/protobuf v1.3.2 github.com/google/gops v0.3.18 github.com/gorilla/websocket v1.5.0 - github.com/hashicorp/golang-lru v0.5.4 - github.com/ipfs/go-log v1.0.5 - github.com/jbenet/goprocess v0.1.4 - github.com/libp2p/go-libp2p v0.19.3 github.com/libp2p/go-libp2p-core v0.15.1 - github.com/libp2p/go-libp2p-kad-dht v0.15.0 - github.com/libp2p/go-libp2p-kbucket v0.4.7 github.com/mitchellh/mapstructure v1.5.0 - github.com/multiformats/go-multiaddr v0.5.0 github.com/pelletier/go-toml v1.9.3 github.com/pkg/errors v0.9.1 github.com/shirou/gopsutil v3.21.11+incompatible github.com/stretchr/testify v1.7.1 - github.com/syndtr/goleveldb v1.0.1-0.20190318030020-c3a204f8e965 github.com/urfave/cli v1.22.10 - github.com/whyrusleeping/timecache v0.0.0-20160911033111-cfcb2f1abfee github.com/yusufpapurcu/wmi v1.2.2 // indirect golang.org/x/crypto 
v0.0.0-20220411220226-7b82a4e95df4 golang.org/x/net v0.0.0-20220418201149-a630d4f3e7a2 diff --git a/go.sum b/go.sum index 5debb7db97f..ab847f64344 100644 --- a/go.sum +++ b/go.sum @@ -55,27 +55,32 @@ github.com/ElrondNetwork/concurrent-map v0.1.3 h1:j2LtPrNJuerannC1cQDE79STvi/P04 github.com/ElrondNetwork/concurrent-map v0.1.3/go.mod h1:3XwSwn4JHI0lrKxWLZvtp53Emr8BXYTmNQGwcukHJEE= github.com/ElrondNetwork/covalent-indexer-go v1.0.6 h1:+LNKItUc+Pb7WuTbil3VuiLMmdQ1AY7lBJM476PtVNE= github.com/ElrondNetwork/covalent-indexer-go v1.0.6/go.mod h1:j3h2g96vqhJAuj3aEX2PWhomae2/o7YfXGEfweNXEeQ= -github.com/ElrondNetwork/elastic-indexer-go v1.2.39 h1:NnhTF6yVnzAQNC7JibeGvR3anUSiA1I5UbWU9sn/U5E= -github.com/ElrondNetwork/elastic-indexer-go v1.2.39/go.mod h1:w+J48ssy1kxOawG2lwiOUR4JYPA092g8Zjk88kRVDNA= +github.com/ElrondNetwork/elastic-indexer-go v1.2.42 h1:PZi4XupvTrHWqaDiNpRrCuH0h9SHz8BiBb9KDFtGvi4= +github.com/ElrondNetwork/elastic-indexer-go v1.2.42/go.mod h1:q0SJzaMI5kTjDoi+sSqBrLmu8XUbyTOt5k865EyczPo= github.com/ElrondNetwork/elrond-go-core v1.0.0/go.mod h1:FQMem7fFF4+8pQ6lVsBZq6yO+smD0nV23P4bJpmPjTo= github.com/ElrondNetwork/elrond-go-core v1.1.7/go.mod h1:O9FkkTT2H9kxCzfn40TbhoCDXzGmUrRVusMomhK/Y3g= github.com/ElrondNetwork/elrond-go-core v1.1.13/go.mod h1:Yz8JK5sGBctw7+gU8j2mZHbzQ09Ek4XHJ4Uinq1N6nM= -github.com/ElrondNetwork/elrond-go-core v1.1.16-0.20220414130405-e3cc29bc7711/go.mod h1:Yz8JK5sGBctw7+gU8j2mZHbzQ09Ek4XHJ4Uinq1N6nM= -github.com/ElrondNetwork/elrond-go-core v1.1.19 h1:1wRYaG/eb7vtPIYqULwhj9ANPfmPM6yX40OYgI5h2nk= +github.com/ElrondNetwork/elrond-go-core v1.1.18/go.mod h1:Yz8JK5sGBctw7+gU8j2mZHbzQ09Ek4XHJ4Uinq1N6nM= github.com/ElrondNetwork/elrond-go-core v1.1.19/go.mod h1:Yz8JK5sGBctw7+gU8j2mZHbzQ09Ek4XHJ4Uinq1N6nM= +github.com/ElrondNetwork/elrond-go-core v1.1.20 h1:2JbXK7BrgbMm+q9xgyW80H8Ljn+HcGfbdXZp8YuFPXc= +github.com/ElrondNetwork/elrond-go-core v1.1.20/go.mod h1:Yz8JK5sGBctw7+gU8j2mZHbzQ09Ek4XHJ4Uinq1N6nM= github.com/ElrondNetwork/elrond-go-crypto v1.0.0/go.mod h1:DGiR7/j1xv729Xg8SsjYaUzWXL5svMd44REXjWS/gAc= -github.com/ElrondNetwork/elrond-go-crypto v1.0.1 h1:xJUUshIZQ7h+rG7Art/9QHVyaPRV1wEjrxXYBdpmRlM= -github.com/ElrondNetwork/elrond-go-crypto v1.0.1/go.mod h1:uunsvweBrrhVojL8uiQSaTPsl3YIQ9iBqtYGM6xs4s0= +github.com/ElrondNetwork/elrond-go-crypto v1.2.1 h1:5wWCBEZp5SMFO2+Nal8UaJNJcG9G9J4PHNNZvQpEeUE= +github.com/ElrondNetwork/elrond-go-crypto v1.2.1/go.mod h1:UNmpDaJjLTKxfzUcwua2R7Mh9bicw/L3ICJt5V7zqMo= github.com/ElrondNetwork/elrond-go-logger v1.0.4/go.mod h1:e5D+c97lKUfFdAzFX7rrI2Igl/z4Y0RkKYKWyzprTGk= github.com/ElrondNetwork/elrond-go-logger v1.0.5/go.mod h1:cBfgx0ST/CJx8jrxJSC5aiSrvkGzcnF7sK06RD8mFxQ= github.com/ElrondNetwork/elrond-go-logger v1.0.7 h1:Ldl1rVS0RGKc1IsW8jIaGCb6Zwei04gsMvyjL05X6mE= github.com/ElrondNetwork/elrond-go-logger v1.0.7/go.mod h1:cBfgx0ST/CJx8jrxJSC5aiSrvkGzcnF7sK06RD8mFxQ= +github.com/ElrondNetwork/elrond-go-p2p v1.0.1 h1:1ZwkIL3LVBUt1oPDvl1VKNA3f7muW2D1Wh3AW4YokwY= +github.com/ElrondNetwork/elrond-go-p2p v1.0.1/go.mod h1:cJWOF4Ek2hBq7LOLS9zMoybuOXblBnWPcsV6dBjsTyc= +github.com/ElrondNetwork/elrond-go-storage v1.0.1 h1:T5pmTAu97aFNbUPpqxJprBEOs+uWsTaJSbCwY9xWPRA= +github.com/ElrondNetwork/elrond-go-storage v1.0.1/go.mod h1:Dht8Vt0BJvyUrr+mDSjYo2pu2fIMZfmUa0yznPG9zGw= github.com/ElrondNetwork/elrond-vm-common v1.1.0/go.mod h1:w3i6f8uiuRkE68Ie/gebRcLgTuHqvruJSYrFyZWuLrE= github.com/ElrondNetwork/elrond-vm-common v1.2.9/go.mod h1:B/Y8WiqHyDd7xsjNYsaYbVMp1jQgQ+z4jTJkFvj/EWI= github.com/ElrondNetwork/elrond-vm-common v1.3.7/go.mod 
h1:seROQuR7RJCoCS7mgRXVAlvjztltY1c+UroAgWr/USE= github.com/ElrondNetwork/elrond-vm-common v1.3.15-0.20220729115029-e70fd191b2f0/go.mod h1:seROQuR7RJCoCS7mgRXVAlvjztltY1c+UroAgWr/USE= -github.com/ElrondNetwork/elrond-vm-common v1.3.16 h1:/pLt3ckAhi5vE6Lde6tog7VNUg5BBm5sTDXnJBSvj7E= -github.com/ElrondNetwork/elrond-vm-common v1.3.16/go.mod h1:seROQuR7RJCoCS7mgRXVAlvjztltY1c+UroAgWr/USE= +github.com/ElrondNetwork/elrond-vm-common v1.3.17 h1:oeZ8AuVETpBv2mmaQg7MT9m3eAFF9ro50WGjQrQFGUI= +github.com/ElrondNetwork/elrond-vm-common v1.3.17/go.mod h1:seROQuR7RJCoCS7mgRXVAlvjztltY1c+UroAgWr/USE= github.com/ElrondNetwork/go-libp2p-pubsub v0.6.1-rc1 h1:Nu/uwYQg/QbfoQ0uD6GahYTwgtAkAwtzsB0HVfSP58I= github.com/ElrondNetwork/go-libp2p-pubsub v0.6.1-rc1/go.mod h1:pJfaShe+i5aWZx8NhSkQjvOYQYLoqPztmFUlKjToOzM= github.com/ElrondNetwork/protobuf v1.3.2 h1:qoCSYiO+8GtXBEZWEjw0WPcZfM3g7QuuJrwpN+y6Mvg= @@ -245,17 +250,16 @@ github.com/fsnotify/fsnotify v1.5.1/go.mod h1:T3375wBYaZdLLcVNkcVbzGHY7f1l/uK5T5 github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= github.com/gin-contrib/cors v0.0.0-20190301062745-f9e10995c85a h1:zBycVvXa03SIX+jdMv8wGu9TMDMWdN8EhaR1FoeKHNo= github.com/gin-contrib/cors v0.0.0-20190301062745-f9e10995c85a/go.mod h1:pL2kNE+DgDU+eQ+dary5bX0Z6LPP8nR6Mqs1iejILw4= -github.com/gin-contrib/pprof v1.3.0 h1:G9eK6HnbkSqDZBYbzG4wrjCsA4e+cvYAHUZw6W+W9K0= -github.com/gin-contrib/pprof v1.3.0/go.mod h1:waMjT1H9b179t3CxuG1cV3DHpga6ybizwfBaM5OXaB0= +github.com/gin-contrib/pprof v1.4.0 h1:XxiBSf5jWZ5i16lNOPbMTVdgHBdhfGRD5PZ1LWazzvg= +github.com/gin-contrib/pprof v1.4.0/go.mod h1:RrehPJasUVBPK6yTUwOl8/NP6i0vbUgmxtis+Z5KE90= github.com/gin-contrib/sse v0.0.0-20170109093832-22d885f9ecc7/go.mod h1:VJ0WA2NBN22VlZ2dKZQPAPnyWw5XTlK1KymzLKsr59s= github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= github.com/gin-gonic/gin v1.3.0/go.mod h1:7cKuhb5qV2ggCFctp2fJQ+ErvciLZrIeoOSOm6mUr7Y= -github.com/gin-gonic/gin v1.6.2/go.mod h1:75u5sXoLsGZoRN5Sgbi1eraJ4GU3++wFwWzhwvtwp4M= github.com/gin-gonic/gin v1.7.1/go.mod h1:jD2toBW3GZUr5UMcdrwQA10I7RuaFOl/SGeDjXkfUtY= github.com/gin-gonic/gin v1.7.6/go.mod h1:jD2toBW3GZUr5UMcdrwQA10I7RuaFOl/SGeDjXkfUtY= -github.com/gin-gonic/gin v1.8.0 h1:4WFH5yycBMA3za5Hnl425yd9ymdw1XPm4666oab+hv4= -github.com/gin-gonic/gin v1.8.0/go.mod h1:ji8BvRH1azfM+SYow9zQ6SZMvR8qOMZHmsCuWR9tTTk= +github.com/gin-gonic/gin v1.8.1 h1:4+fr/el88TOO3ewCmQr8cx/CtZ/umlIRIs5M4NTNjf8= +github.com/gin-gonic/gin v1.8.1/go.mod h1:ji8BvRH1azfM+SYow9zQ6SZMvR8qOMZHmsCuWR9tTTk= github.com/gizak/termui/v3 v3.1.0 h1:ZZmVDgwHl7gR7elfKf1xc4IudXZ5qqfDh4wExk4Iajc= github.com/gizak/termui/v3 v3.1.0/go.mod h1:bXQEBkJpzxUAKf0+xq9MSWAvWZlE7c+aidmyFlkYTrY= github.com/gliderlabs/ssh v0.1.1/go.mod h1:U7qILu1NlMHj9FlMhZLlkCdDnU1DBEAqr0aevW3Awn0= @@ -285,7 +289,6 @@ github.com/go-playground/locales v0.14.0/go.mod h1:sawfccIbzZTqEDETgFXqTho0QybSa github.com/go-playground/universal-translator v0.17.0/go.mod h1:UkSxE5sNxxRwHyU+Scu5vgOQjsIJAF8j9muTVoKLVtA= github.com/go-playground/universal-translator v0.18.0 h1:82dyy6p4OuJq4/CByFNOn/jYrnRPArHwAcmLoJZxyho= github.com/go-playground/universal-translator v0.18.0/go.mod h1:UvRDBj+xPUEGrFYl+lu/H90nyDXpg0fqeB/AQUGNTVA= -github.com/go-playground/validator/v10 v10.2.0/go.mod h1:uOYAAleCW8F/7oMFd6aG0GOhaH6EGOAJShg8Id5JGkI= github.com/go-playground/validator/v10 v10.4.1/go.mod h1:nlOn6nFhuKACm19sB/8EGNn9GlaMV7XkbRSipzJ0Ii4= 
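The go.mod and go.sum changes above fold the old storage and networking helper dependencies into the new elrond-go-storage and elrond-go-p2p modules; inside this repository that mostly surfaces as import renames, as in the heartbeat monitor hunk directly below. A minimal sketch of that call-site change, assuming cache.NewTimeCache keeps exactly the signature the hunk shows; the package and function names in the sketch are illustrative:

package heartbeatsketch

import (
	"time"

	// was: "github.com/ElrondNetwork/elrond-go/storage/timecache"
	"github.com/ElrondNetwork/elrond-go/storage/cache"
)

// trackUnresponsivePeer mirrors the monitor.go change below: only the package
// qualifier moves (timecache.NewTimeCache -> cache.NewTimeCache); the
// constructor argument and the Add(pid) error contract stay the same.
func trackUnresponsivePeer(pid string, maxDurationPeerUnresponsive time.Duration) error {
	tc := cache.NewTimeCache(maxDurationPeerUnresponsive)
	return tc.Add(pid)
}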
diff --git a/heartbeat/process/monitor.go b/heartbeat/process/monitor.go
index 109a9c37021..cef4de65273 100644
--- a/heartbeat/process/monitor.go
+++ b/heartbeat/process/monitor.go
@@ -18,7 +18,7 @@ import (
 	"github.com/ElrondNetwork/elrond-go/heartbeat/data"
 	"github.com/ElrondNetwork/elrond-go/p2p"
 	"github.com/ElrondNetwork/elrond-go/process"
-	"github.com/ElrondNetwork/elrond-go/storage/timecache"
+	"github.com/ElrondNetwork/elrond-go/storage/cache"
 )
 
 var log = logger.GetOrCreate("heartbeat/process")
@@ -596,7 +596,7 @@ func (m *Monitor) addDoubleSignerPeers(hb *data.Heartbeat) {
 	pubKeyStr := string(hb.Pubkey)
 	tc, ok := m.doubleSignerPeers[pubKeyStr]
 	if !ok {
-		tc = timecache.NewTimeCache(m.maxDurationPeerUnresponsive)
+		tc = cache.NewTimeCache(m.maxDurationPeerUnresponsive)
 		err := tc.Add(string(hb.Pid))
 		if err != nil {
 			log.Warn("cannot add heartbeat in cache", "peer id", hb.Pid, "error", err)
diff --git a/heartbeat/processor/directConnectionsProcessor.go b/heartbeat/processor/directConnectionsProcessor.go
index 13460d06416..a8b0ed77cad 100644
--- a/heartbeat/processor/directConnectionsProcessor.go
+++ b/heartbeat/processor/directConnectionsProcessor.go
@@ -11,7 +11,6 @@ import (
 	"github.com/ElrondNetwork/elrond-go/common"
 	"github.com/ElrondNetwork/elrond-go/heartbeat"
 	"github.com/ElrondNetwork/elrond-go/p2p"
-	"github.com/ElrondNetwork/elrond-go/p2p/message"
 	"github.com/ElrondNetwork/elrond-go/process"
 	"github.com/ElrondNetwork/elrond-go/sharding"
 )
@@ -130,7 +129,7 @@ func (dcp *directConnectionsProcessor) computeNewPeers(connectedPeers []core.Pee
 }
 
 func (dcp *directConnectionsProcessor) notifyNewPeers(newPeers []core.PeerID) {
-	shardValidatorInfo := &message.DirectConnectionInfo{
+	shardValidatorInfo := &p2p.DirectConnectionInfo{
 		ShardId: fmt.Sprintf("%d", dcp.shardCoordinator.SelfId()),
 	}
diff --git a/heartbeat/processor/directConnectionsProcessor_test.go b/heartbeat/processor/directConnectionsProcessor_test.go
index b228248b635..9301e42d0b2 100644
--- a/heartbeat/processor/directConnectionsProcessor_test.go
+++ b/heartbeat/processor/directConnectionsProcessor_test.go
@@ -14,7 +14,7 @@ import (
 	"github.com/ElrondNetwork/elrond-go-core/marshal"
 	"github.com/ElrondNetwork/elrond-go/common"
 	"github.com/ElrondNetwork/elrond-go/heartbeat"
-	"github.com/ElrondNetwork/elrond-go/p2p/message"
+	"github.com/ElrondNetwork/elrond-go/p2p"
 	"github.com/ElrondNetwork/elrond-go/process"
 	"github.com/ElrondNetwork/elrond-go/sharding/nodesCoordinator"
 	"github.com/ElrondNetwork/elrond-go/testscommon"
@@ -111,7 +111,7 @@ func TestNewDirectConnectionsProcessor(t *testing.T) {
 				mutNotifiedPeers.Lock()
 				defer mutNotifiedPeers.Unlock()
 
-				shardValidatorInfo := &message.DirectConnectionInfo{}
+				shardValidatorInfo := &p2p.DirectConnectionInfo{}
 				err := args.Marshaller.Unmarshal(shardValidatorInfo, buff)
 				assert.Nil(t, err)
 				assert.Equal(t, expectedShard, shardValidatorInfo.ShardId)
@@ -261,7 +261,7 @@ func Test_directConnectionsProcessor_notifyNewPeers(t *testing.T) {
 		expectedShard := fmt.Sprintf("%d", args.ShardCoordinator.SelfId())
 		args.Messenger = &p2pmocks.MessengerStub{
 			SendToConnectedPeerCalled: func(topic string, buff []byte, peerID core.PeerID) error {
-				shardValidatorInfo := &message.DirectConnectionInfo{}
+				shardValidatorInfo := &p2p.DirectConnectionInfo{}
 				err := args.Marshaller.Unmarshal(shardValidatorInfo, buff)
 				assert.Nil(t, err)
 				assert.Equal(t, expectedShard, shardValidatorInfo.ShardId)
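For orientation, the directConnectionsProcessor hunks above move DirectConnectionInfo from the p2p/message package to p2p, but the notify flow itself is unchanged: build the payload, marshal it, push it to one connected peer. A condensed sketch of that flow using only calls visible in the hunks; the messenger interface here is narrowed to the single method used and is an assumption for illustration:

package p2psketch

import (
	"fmt"

	"github.com/ElrondNetwork/elrond-go-core/core"
	"github.com/ElrondNetwork/elrond-go-core/marshal"
	"github.com/ElrondNetwork/elrond-go/p2p"
)

// messenger is reduced to the one method the sketch needs (assumption); the
// real p2p.Messenger interface is much larger.
type messenger interface {
	SendToConnectedPeer(topic string, buff []byte, peerID core.PeerID) error
}

// notifyPeer mirrors notifyNewPeers above after the message -> p2p move.
func notifyPeer(m messenger, marshaller marshal.Marshalizer, topic string, selfShardID uint32, pid core.PeerID) error {
	shardValidatorInfo := &p2p.DirectConnectionInfo{
		ShardId: fmt.Sprintf("%d", selfShardID),
	}

	buff, err := marshaller.Marshal(shardValidatorInfo)
	if err != nil {
		return err
	}

	return m.SendToConnectedPeer(topic, buff, pid)
}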
Setup nodes...") - concMap := &sync.Map{} - - nodes := createNodes( + nodes := integrationTests.CreateNodesWithTestConsensusNode( + 1, int(numNodes), int(consensusSize), roundTime, @@ -60,7 +68,7 @@ func initNodesAndTest( if numInvalid < numNodes { for i := uint32(0); i < numInvalid; i++ { iCopy := i - nodes[0][i].blkProcessor.ProcessBlockCalled = func( + nodes[0][i].BlockProcessor.ProcessBlockCalled = func( header data.HeaderHandler, body data.BodyHandler, haveTime func() time.Duration, @@ -69,11 +77,11 @@ func initNodesAndTest( "process block invalid ", header.GetRound(), header.GetNonce(), - getPkEncoded(nodes[0][iCopy].pk), + getPkEncoded(nodes[0][iCopy].NodeKeys.Pk), ) return process.ErrBlockHashDoesNotMatch } - nodes[0][i].blkProcessor.CreateBlockCalled = func( + nodes[0][i].BlockProcessor.CreateBlockCalled = func( header data.HeaderHandler, haveTime func() bool, ) (data.HeaderHandler, data.BodyHandler, error) { @@ -82,15 +90,15 @@ func initNodesAndTest( } } - return nodes[0], concMap + return nodes[0] } -func startNodesWithCommitBlock(nodes []*testNode, mutex *sync.Mutex, nonceForRoundMap map[uint64]uint64, totalCalled *int) error { +func startNodesWithCommitBlock(nodes []*integrationTests.TestConsensusNode, mutex *sync.Mutex, nonceForRoundMap map[uint64]uint64, totalCalled *int) error { for _, n := range nodes { nCopy := n - n.blkProcessor.CommitBlockCalled = func(header data.HeaderHandler, body data.BodyHandler) error { - nCopy.blkProcessor.NrCommitBlockCalled++ - _ = nCopy.blkc.SetCurrentBlockHeaderAndRootHash(header, header.GetRootHash()) + n.BlockProcessor.CommitBlockCalled = func(header data.HeaderHandler, body data.BodyHandler) error { + nCopy.BlockProcessor.NrCommitBlockCalled++ + _ = nCopy.ChainHandler.SetCurrentBlockHeaderAndRootHash(header, header.GetRootHash()) mutex.Lock() nonceForRoundMap[header.GetRound()] = header.GetNonce() @@ -102,7 +110,7 @@ func startNodesWithCommitBlock(nodes []*testNode, mutex *sync.Mutex, nonceForRou statusComponents := integrationTests.GetDefaultStatusComponents() - consensusArgs := factory.ConsensusComponentsFactoryArgs{ + consensusArgs := consensusComp.ConsensusComponentsFactoryArgs{ Config: config.Config{ Consensus: config.ConsensusConfig{ Type: blsConsensusType, @@ -123,23 +131,23 @@ func startNodesWithCommitBlock(nodes []*testNode, mutex *sync.Mutex, nonceForRou }, }, BootstrapRoundIndex: 0, - CoreComponents: n.node.GetCoreComponents(), - NetworkComponents: n.node.GetNetworkComponents(), - CryptoComponents: n.node.GetCryptoComponents(), - DataComponents: n.node.GetDataComponents(), - ProcessComponents: n.node.GetProcessComponents(), - StateComponents: n.node.GetStateComponents(), + CoreComponents: n.Node.GetCoreComponents(), + NetworkComponents: n.Node.GetNetworkComponents(), + CryptoComponents: n.Node.GetCryptoComponents(), + DataComponents: n.Node.GetDataComponents(), + ProcessComponents: n.Node.GetProcessComponents(), + StateComponents: n.Node.GetStateComponents(), StatusComponents: statusComponents, ScheduledProcessor: &consensusMocks.ScheduledProcessorStub{}, - IsInImportMode: n.node.IsInImportMode(), + IsInImportMode: n.Node.IsInImportMode(), } - consensusFactory, err := factory.NewConsensusComponentsFactory(consensusArgs) + consensusFactory, err := consensusComp.NewConsensusComponentsFactory(consensusArgs) if err != nil { return fmt.Errorf("NewConsensusComponentsFactory failed: %w", err) } - managedConsensusComponents, err := factory.NewManagedConsensusComponents(consensusFactory) + managedConsensusComponents, err := 
consensusComp.NewManagedConsensusComponents(consensusFactory) if err != nil { return err } @@ -196,12 +204,12 @@ func runFullConsensusTest(t *testing.T, consensusType string) { roundTime := uint64(1000) numCommBlock := uint64(8) - nodes, _ := initNodesAndTest(numNodes, consensusSize, numInvalid, roundTime, consensusType) + nodes := initNodesAndTest(numNodes, consensusSize, numInvalid, roundTime, consensusType) mutex := &sync.Mutex{} defer func() { for _, n := range nodes { - _ = n.messenger.Close() + _ = n.Messenger.Close() } }() @@ -243,12 +251,12 @@ func runConsensusWithNotEnoughValidators(t *testing.T, consensusType string) { consensusSize := uint32(4) numInvalid := uint32(2) roundTime := uint64(1000) - nodes, _ := initNodesAndTest(numNodes, consensusSize, numInvalid, roundTime, consensusType) + nodes := initNodesAndTest(numNodes, consensusSize, numInvalid, roundTime, consensusType) mutex := &sync.Mutex{} defer func() { for _, n := range nodes { - _ = n.messenger.Close() + _ = n.Messenger.Close() } }() @@ -277,3 +285,16 @@ func TestConsensusBLSNotEnoughValidators(t *testing.T) { runConsensusWithNotEnoughValidators(t, blsConsensusType) } + +func displayAndStartNodes(nodes []*integrationTests.TestConsensusNode) { + for _, n := range nodes { + skBuff, _ := n.NodeKeys.Sk.ToByteArray() + pkBuff, _ := n.NodeKeys.Pk.ToByteArray() + + fmt.Printf("Shard ID: %v, sk: %s, pk: %s\n", + n.ShardCoordinator.SelfId(), + hex.EncodeToString(skBuff), + testPubkeyConverter.Encode(pkBuff), + ) + } +} diff --git a/integrationTests/consensus/messengerWrapper.go b/integrationTests/consensus/messengerWrapper.go deleted file mode 100644 index 1c1d72c3f5e..00000000000 --- a/integrationTests/consensus/messengerWrapper.go +++ /dev/null @@ -1,32 +0,0 @@ -package consensus - -import ( - "fmt" - - "github.com/ElrondNetwork/elrond-go-core/core/check" - "github.com/ElrondNetwork/elrond-go/integrationTests" - "github.com/ElrondNetwork/elrond-go/p2p" -) - -//TODO refactor consensus node so this wrapper will not be required -type messengerWrapper struct { - p2p.Messenger -} - -// ConnectTo will try to initiate a connection to the provided parameter -func (mw *messengerWrapper) ConnectTo(connectable integrationTests.Connectable) error { - if check.IfNil(connectable) { - return fmt.Errorf("trying to connect to a nil Connectable parameter") - } - - return mw.ConnectToPeer(connectable.GetConnectableAddress()) -} - -// GetConnectableAddress returns a non circuit, non windows default connectable p2p address -func (mw *messengerWrapper) GetConnectableAddress() string { - if mw == nil { - return "nil" - } - - return integrationTests.GetConnectableAddress(mw) -} diff --git a/integrationTests/consensus/testInitializer.go b/integrationTests/consensus/testInitializer.go deleted file mode 100644 index dc6242ebff8..00000000000 --- a/integrationTests/consensus/testInitializer.go +++ /dev/null @@ -1,576 +0,0 @@ -package consensus - -import ( - "encoding/hex" - "fmt" - "strconv" - "time" - - "github.com/ElrondNetwork/elrond-go-core/core" - "github.com/ElrondNetwork/elrond-go-core/core/pubkeyConverter" - "github.com/ElrondNetwork/elrond-go-core/data" - dataBlock "github.com/ElrondNetwork/elrond-go-core/data/block" - "github.com/ElrondNetwork/elrond-go-core/data/endProcess" - "github.com/ElrondNetwork/elrond-go-core/hashing" - "github.com/ElrondNetwork/elrond-go-core/hashing/blake2b" - "github.com/ElrondNetwork/elrond-go-core/hashing/sha256" - "github.com/ElrondNetwork/elrond-go-core/marshal" - crypto 
"github.com/ElrondNetwork/elrond-go-crypto" - "github.com/ElrondNetwork/elrond-go-crypto/signing" - ed25519SingleSig "github.com/ElrondNetwork/elrond-go-crypto/signing/ed25519/singlesig" - "github.com/ElrondNetwork/elrond-go-crypto/signing/mcl" - mclsinglesig "github.com/ElrondNetwork/elrond-go-crypto/signing/mcl/singlesig" - "github.com/ElrondNetwork/elrond-go/common" - "github.com/ElrondNetwork/elrond-go/config" - "github.com/ElrondNetwork/elrond-go/consensus/round" - "github.com/ElrondNetwork/elrond-go/dataRetriever" - "github.com/ElrondNetwork/elrond-go/dataRetriever/blockchain" - "github.com/ElrondNetwork/elrond-go/epochStart/metachain" - "github.com/ElrondNetwork/elrond-go/epochStart/notifier" - mainFactory "github.com/ElrondNetwork/elrond-go/factory" - "github.com/ElrondNetwork/elrond-go/factory/peerSignatureHandler" - "github.com/ElrondNetwork/elrond-go/integrationTests" - "github.com/ElrondNetwork/elrond-go/integrationTests/mock" - "github.com/ElrondNetwork/elrond-go/node" - "github.com/ElrondNetwork/elrond-go/ntp" - "github.com/ElrondNetwork/elrond-go/p2p" - "github.com/ElrondNetwork/elrond-go/process/factory" - syncFork "github.com/ElrondNetwork/elrond-go/process/sync" - "github.com/ElrondNetwork/elrond-go/sharding" - "github.com/ElrondNetwork/elrond-go/sharding/nodesCoordinator" - "github.com/ElrondNetwork/elrond-go/state" - "github.com/ElrondNetwork/elrond-go/state/storagePruningManager" - "github.com/ElrondNetwork/elrond-go/state/storagePruningManager/evictionWaitingList" - "github.com/ElrondNetwork/elrond-go/storage" - "github.com/ElrondNetwork/elrond-go/storage/lrucache" - "github.com/ElrondNetwork/elrond-go/storage/memorydb" - "github.com/ElrondNetwork/elrond-go/storage/storageUnit" - "github.com/ElrondNetwork/elrond-go/storage/timecache" - "github.com/ElrondNetwork/elrond-go/testscommon" - "github.com/ElrondNetwork/elrond-go/testscommon/cryptoMocks" - dataRetrieverMock "github.com/ElrondNetwork/elrond-go/testscommon/dataRetriever" - "github.com/ElrondNetwork/elrond-go/testscommon/nodeTypeProviderMock" - "github.com/ElrondNetwork/elrond-go/testscommon/shardingMocks" - statusHandlerMock "github.com/ElrondNetwork/elrond-go/testscommon/statusHandler" - vic "github.com/ElrondNetwork/elrond-go/testscommon/validatorInfoCacher" - "github.com/ElrondNetwork/elrond-go/trie" - "github.com/ElrondNetwork/elrond-go/trie/hashesHolder" - vmcommon "github.com/ElrondNetwork/elrond-vm-common" -) - -const blsConsensusType = "bls" -const signatureSize = 48 -const publicKeySize = 96 - -var p2pBootstrapDelay = time.Second * 5 -var testPubkeyConverter, _ = pubkeyConverter.NewHexPubkeyConverter(32) - -type testNode struct { - node *node.Node - messenger p2p.Messenger - blkc data.ChainHandler - blkProcessor *mock.BlockProcessorMock - sk crypto.PrivateKey - pk crypto.PublicKey - shardId uint32 -} - -type keyPair struct { - sk crypto.PrivateKey - pk crypto.PublicKey -} - -type cryptoParams struct { - keyGen crypto.KeyGenerator - keys map[uint32][]*keyPair - txSingleSigner crypto.SingleSigner - singleSigner crypto.SingleSigner -} - -func genValidatorsFromPubKeys(pubKeysMap map[uint32][]string) map[uint32][]nodesCoordinator.Validator { - validatorsMap := make(map[uint32][]nodesCoordinator.Validator) - - for shardId, shardNodesPks := range pubKeysMap { - shardValidators := make([]nodesCoordinator.Validator, 0) - for i := 0; i < len(shardNodesPks); i++ { - v, _ := nodesCoordinator.NewValidator([]byte(shardNodesPks[i]), 1, uint32(i)) - shardValidators = append(shardValidators, v) - } - 
validatorsMap[shardId] = shardValidators - } - - return validatorsMap -} - -func pubKeysMapFromKeysMap(keyPairMap map[uint32][]*keyPair) map[uint32][]string { - keysMap := make(map[uint32][]string) - - for shardId, pairList := range keyPairMap { - shardKeys := make([]string, len(pairList)) - for i, pair := range pairList { - b, _ := pair.pk.ToByteArray() - shardKeys[i] = string(b) - } - keysMap[shardId] = shardKeys - } - - return keysMap -} - -func displayAndStartNodes(nodes []*testNode) { - for _, n := range nodes { - skBuff, _ := n.sk.ToByteArray() - pkBuff, _ := n.pk.ToByteArray() - - fmt.Printf("Shard ID: %v, sk: %s, pk: %s\n", - n.shardId, - hex.EncodeToString(skBuff), - testPubkeyConverter.Encode(pkBuff), - ) - } -} - -func createTestBlockChain() data.ChainHandler { - blockChain, _ := blockchain.NewBlockChain(&statusHandlerMock.AppStatusHandlerStub{}) - _ = blockChain.SetGenesisHeader(&dataBlock.Header{}) - - return blockChain -} - -func createMemUnit() storage.Storer { - cache, _ := storageUnit.NewCache(storageUnit.CacheConfig{Type: storageUnit.LRUCache, Capacity: 10, Shards: 1, SizeInBytes: 0}) - - unit, _ := storageUnit.NewStorageUnit(cache, memorydb.New()) - return unit -} - -func createTestStore() dataRetriever.StorageService { - store := dataRetriever.NewChainStorer() - store.AddStorer(dataRetriever.TransactionUnit, createMemUnit()) - store.AddStorer(dataRetriever.MiniBlockUnit, createMemUnit()) - store.AddStorer(dataRetriever.RewardTransactionUnit, createMemUnit()) - store.AddStorer(dataRetriever.MetaBlockUnit, createMemUnit()) - store.AddStorer(dataRetriever.PeerChangesUnit, createMemUnit()) - store.AddStorer(dataRetriever.BlockHeaderUnit, createMemUnit()) - store.AddStorer(dataRetriever.BootstrapUnit, createMemUnit()) - store.AddStorer(dataRetriever.ReceiptsUnit, createMemUnit()) - store.AddStorer(dataRetriever.ScheduledSCRsUnit, createMemUnit()) - store.AddStorer(dataRetriever.ShardHdrNonceHashDataUnit, createMemUnit()) - - return store -} - -func createAccountsDB(marshaller marshal.Marshalizer) state.AccountsAdapter { - marsh := &marshal.GogoProtoMarshalizer{} - hasher := sha256.NewSha256() - evictionWaitListSize := uint(100) - ewl, _ := evictionWaitingList.NewEvictionWaitingList(evictionWaitListSize, memorydb.New(), marsh) - - // TODO change this implementation with a factory - generalCfg := config.TrieStorageManagerConfig{ - PruningBufferLen: 1000, - SnapshotsBufferLen: 10, - SnapshotsGoroutineNum: 1, - } - args := trie.NewTrieStorageManagerArgs{ - MainStorer: createMemUnit(), - CheckpointsStorer: createMemUnit(), - Marshalizer: marshaller, - Hasher: hasher, - GeneralConfig: generalCfg, - CheckpointHashesHolder: hashesHolder.NewCheckpointHashesHolder(10000000, uint64(hasher.Size())), - IdleProvider: &testscommon.ProcessStatusHandlerStub{}, - } - trieStorage, _ := trie.NewTrieStorageManager(args) - - maxTrieLevelInMemory := uint(5) - tr, _ := trie.NewTrie(trieStorage, marsh, hasher, maxTrieLevelInMemory) - storagePruning, _ := storagePruningManager.NewStoragePruningManager( - ewl, - generalCfg.PruningBufferLen, - ) - - argsAccountsDB := state.ArgsAccountsDB{ - Trie: tr, - Hasher: sha256.NewSha256(), - Marshaller: marshaller, - AccountFactory: &mock.AccountsFactoryStub{ - CreateAccountCalled: func(address []byte) (wrapper vmcommon.AccountHandler, e error) { - return state.NewUserAccount(address) - }, - }, - StoragePruningManager: storagePruning, - ProcessingMode: common.Normal, - ProcessStatusHandler: &testscommon.ProcessStatusHandlerStub{}, - } - - adb, _ := 
state.NewAccountsDB(argsAccountsDB) - return adb -} - -func createCryptoParams(nodesPerShard int, nbMetaNodes int, nbShards int) *cryptoParams { - suite := mcl.NewSuiteBLS12() - txSingleSigner := &ed25519SingleSig.Ed25519Signer{} - singleSigner := &mclsinglesig.BlsSingleSigner{} - keyGen := signing.NewKeyGenerator(suite) - - keysMap := make(map[uint32][]*keyPair) - keyPairs := make([]*keyPair, nodesPerShard) - for shardId := 0; shardId < nbShards; shardId++ { - for n := 0; n < nodesPerShard; n++ { - kp := &keyPair{} - kp.sk, kp.pk = keyGen.GeneratePair() - keyPairs[n] = kp - } - keysMap[uint32(shardId)] = keyPairs - } - - keyPairs = make([]*keyPair, nbMetaNodes) - for n := 0; n < nbMetaNodes; n++ { - kp := &keyPair{} - kp.sk, kp.pk = keyGen.GeneratePair() - keyPairs[n] = kp - } - keysMap[core.MetachainShardId] = keyPairs - - params := &cryptoParams{ - keys: keysMap, - keyGen: keyGen, - txSingleSigner: txSingleSigner, - singleSigner: singleSigner, - } - - return params -} - -func createHasher(consensusType string) hashing.Hasher { - if consensusType == blsConsensusType { - hasher, _ := blake2b.NewBlake2bWithSize(32) - return hasher - } - return blake2b.NewBlake2b() -} - -func createConsensusOnlyNode( - shardCoordinator sharding.Coordinator, - nodesCoordinator nodesCoordinator.NodesCoordinator, - shardId uint32, - selfId uint32, - consensusSize uint32, - roundTime uint64, - privKey crypto.PrivateKey, - pubKeys []crypto.PublicKey, - testKeyGen crypto.KeyGenerator, - consensusType string, - epochStartRegistrationHandler mainFactory.EpochStartNotifier, -) ( - *node.Node, - p2p.Messenger, - *mock.BlockProcessorMock, - data.ChainHandler) { - - testHasher := createHasher(consensusType) - testMarshalizer := &marshal.GogoProtoMarshalizer{} - - messenger := integrationTests.CreateMessengerWithNoDiscovery() - rootHash := []byte("roothash") - - blockChain := createTestBlockChain() - blockProcessor := &mock.BlockProcessorMock{ - ProcessBlockCalled: func(header data.HeaderHandler, body data.BodyHandler, haveTime func() time.Duration) error { - _ = blockChain.SetCurrentBlockHeaderAndRootHash(header, header.GetRootHash()) - return nil - }, - RevertCurrentBlockCalled: func() { - }, - CreateBlockCalled: func(header data.HeaderHandler, haveTime func() bool) (data.HeaderHandler, data.BodyHandler, error) { - return header, &dataBlock.Body{}, nil - }, - MarshalizedDataToBroadcastCalled: func(header data.HeaderHandler, body data.BodyHandler) (map[uint32][]byte, map[string][][]byte, error) { - mrsData := make(map[uint32][]byte) - mrsTxs := make(map[string][][]byte) - return mrsData, mrsTxs, nil - }, - CreateNewHeaderCalled: func(round uint64, nonce uint64) (data.HeaderHandler, error) { - return &dataBlock.Header{ - Round: round, - Nonce: nonce, - SoftwareVersion: []byte("version"), - }, nil - }, - } - - blockProcessor.CommitBlockCalled = func(header data.HeaderHandler, body data.BodyHandler) error { - blockProcessor.NrCommitBlockCalled++ - _ = blockChain.SetCurrentBlockHeaderAndRootHash(header, header.GetRootHash()) - return nil - } - blockProcessor.Marshalizer = testMarshalizer - - header := &dataBlock.Header{ - Nonce: 0, - ShardID: shardId, - BlockBodyType: dataBlock.StateBlock, - Signature: rootHash, - RootHash: rootHash, - PrevRandSeed: rootHash, - RandSeed: rootHash, - } - - _ = blockChain.SetGenesisHeader(header) - hdrMarshalized, _ := testMarshalizer.Marshal(header) - blockChain.SetGenesisHeaderHash(testHasher.Compute(string(hdrMarshalized))) - - startTime := time.Now().Unix() - - singlesigner := 
&ed25519SingleSig.Ed25519Signer{} - singleBlsSigner := &mclsinglesig.BlsSingleSigner{} - - syncer := ntp.NewSyncTime(ntp.NewNTPGoogleConfig(), nil) - syncer.StartSyncingTime() - - roundHandler, _ := round.NewRound( - time.Unix(startTime, 0), - syncer.CurrentTime(), - time.Millisecond*time.Duration(roundTime), - syncer, - 0) - - dataPool := dataRetrieverMock.CreatePoolsHolder(1, 0) - - argsNewMetaEpochStart := &metachain.ArgsNewMetaEpochStartTrigger{ - GenesisTime: time.Unix(startTime, 0), - EpochStartNotifier: notifier.NewEpochStartSubscriptionHandler(), - Settings: &config.EpochStartConfig{ - MinRoundsBetweenEpochs: 1, - RoundsPerEpoch: 3, - }, - Epoch: 0, - Storage: createTestStore(), - Marshalizer: testMarshalizer, - Hasher: testHasher, - AppStatusHandler: &statusHandlerMock.AppStatusHandlerStub{}, - DataPool: dataPool, - } - epochStartTrigger, _ := metachain.NewEpochStartTrigger(argsNewMetaEpochStart) - - forkDetector, _ := syncFork.NewShardForkDetector( - roundHandler, - timecache.NewTimeCache(time.Second), - &mock.BlockTrackerStub{}, - 0, - ) - - hdrResolver := &mock.HeaderResolverStub{} - mbResolver := &mock.MiniBlocksResolverStub{} - resolverFinder := &mock.ResolversFinderStub{ - IntraShardResolverCalled: func(baseTopic string) (resolver dataRetriever.Resolver, e error) { - if baseTopic == factory.MiniBlocksTopic { - return mbResolver, nil - } - return nil, nil - }, - CrossShardResolverCalled: func(baseTopic string, crossShard uint32) (resolver dataRetriever.Resolver, err error) { - if baseTopic == factory.ShardBlocksTopic { - return hdrResolver, nil - } - return nil, nil - }, - } - - inPubKeys := make(map[uint32][]string) - for _, val := range pubKeys { - sPubKey, _ := val.ToByteArray() - inPubKeys[shardId] = append(inPubKeys[shardId], string(sPubKey)) - } - - testMultiSig := cryptoMocks.NewMultiSigner(consensusSize) - _ = testMultiSig.Reset(inPubKeys[shardId], uint16(selfId)) - - peerSigCache, _ := storageUnit.NewCache(storageUnit.CacheConfig{Type: storageUnit.LRUCache, Capacity: 1000}) - peerSigHandler, _ := peerSignatureHandler.NewPeerSignatureHandler(peerSigCache, singleBlsSigner, testKeyGen) - accntAdapter := createAccountsDB(testMarshalizer) - networkShardingCollector := mock.NewNetworkShardingCollectorMock() - - coreComponents := integrationTests.GetDefaultCoreComponents() - coreComponents.SyncTimerField = syncer - coreComponents.RoundHandlerField = roundHandler - coreComponents.InternalMarshalizerField = testMarshalizer - coreComponents.VmMarshalizerField = &marshal.JsonMarshalizer{} - coreComponents.TxMarshalizerField = &marshal.JsonMarshalizer{} - coreComponents.HasherField = testHasher - coreComponents.AddressPubKeyConverterField = testPubkeyConverter - coreComponents.ChainIdCalled = func() string { - return string(integrationTests.ChainID) - } - coreComponents.Uint64ByteSliceConverterField = &mock.Uint64ByteSliceConverterMock{} - coreComponents.WatchdogField = &mock.WatchdogMock{} - coreComponents.GenesisTimeField = time.Unix(startTime, 0) - coreComponents.GenesisNodesSetupField = &testscommon.NodesSetupStub{ - GetShardConsensusGroupSizeCalled: func() uint32 { - return consensusSize - }, - GetMetaConsensusGroupSizeCalled: func() uint32 { - return consensusSize - }, - } - - networkComponents := integrationTests.GetDefaultNetworkComponents() - networkComponents.Messenger = messenger - networkComponents.InputAntiFlood = &mock.NilAntifloodHandler{} - networkComponents.PeerHonesty = &mock.PeerHonestyHandlerStub{} - - cryptoComponents := 
integrationTests.GetDefaultCryptoComponents() - cryptoComponents.PrivKey = privKey - cryptoComponents.PubKey = privKey.GeneratePublic() - cryptoComponents.BlockSig = singleBlsSigner - cryptoComponents.TxSig = singlesigner - cryptoComponents.MultiSig = testMultiSig - cryptoComponents.BlKeyGen = testKeyGen - cryptoComponents.PeerSignHandler = peerSigHandler - cryptoComponents.KeysHandlerField = testscommon.NewKeysHandlerSingleSignerMock( - cryptoComponents.PrivKey, - networkComponents.Messenger.ID(), - ) - - processComponents := integrationTests.GetDefaultProcessComponents() - processComponents.ForkDetect = forkDetector - processComponents.ShardCoord = shardCoordinator - processComponents.NodesCoord = nodesCoordinator - processComponents.BlockProcess = blockProcessor - processComponents.BlockTrack = &mock.BlockTrackerStub{} - processComponents.IntContainer = &testscommon.InterceptorsContainerStub{} - processComponents.ResFinder = resolverFinder - processComponents.EpochTrigger = epochStartTrigger - processComponents.EpochNotifier = epochStartRegistrationHandler - processComponents.BlackListHdl = &testscommon.TimeCacheStub{} - processComponents.BootSore = &mock.BoostrapStorerMock{} - processComponents.HeaderSigVerif = &mock.HeaderSigVerifierStub{} - processComponents.HeaderIntegrVerif = &mock.HeaderIntegrityVerifierStub{} - processComponents.ReqHandler = &testscommon.RequestHandlerStub{} - processComponents.PeerMapper = networkShardingCollector - processComponents.RoundHandlerField = roundHandler - processComponents.ScheduledTxsExecutionHandlerInternal = &testscommon.ScheduledTxsExecutionStub{} - processComponents.ProcessedMiniBlocksTrackerInternal = &testscommon.ProcessedMiniBlocksTrackerStub{} - - dataComponents := integrationTests.GetDefaultDataComponents() - dataComponents.BlockChain = blockChain - dataComponents.DataPool = dataPool - dataComponents.Store = createTestStore() - - stateComponents := integrationTests.GetDefaultStateComponents() - stateComponents.Accounts = accntAdapter - stateComponents.AccountsAPI = accntAdapter - - n, err := node.NewNode( - node.WithCoreComponents(coreComponents), - node.WithCryptoComponents(cryptoComponents), - node.WithProcessComponents(processComponents), - node.WithDataComponents(dataComponents), - node.WithStateComponents(stateComponents), - node.WithNetworkComponents(networkComponents), - node.WithInitialNodesPubKeys(inPubKeys), - node.WithRoundDuration(roundTime), - node.WithConsensusGroupSize(int(consensusSize)), - node.WithConsensusType(consensusType), - node.WithGenesisTime(time.Unix(startTime, 0)), - node.WithPeerDenialEvaluator(&mock.PeerDenialEvaluatorStub{}), - node.WithRequestedItemsHandler(&mock.RequestedItemsHandlerStub{}), - node.WithValidatorSignatureSize(signatureSize), - node.WithPublicKeySize(publicKeySize), - ) - - if err != nil { - fmt.Println(err.Error()) - } - - return n, messenger, blockProcessor, blockChain -} - -func createNodes( - nodesPerShard int, - consensusSize int, - roundTime uint64, - consensusType string, -) map[uint32][]*testNode { - - nodes := make(map[uint32][]*testNode) - cp := createCryptoParams(nodesPerShard, 1, 1) - keysMap := pubKeysMapFromKeysMap(cp.keys) - eligibleMap := genValidatorsFromPubKeys(keysMap) - waitingMap := make(map[uint32][]nodesCoordinator.Validator) - nodesList := make([]*testNode, nodesPerShard) - connectableNodes := make([]integrationTests.Connectable, 0) - - nodeShuffler := &shardingMocks.NodeShufflerMock{} - - pubKeys := make([]crypto.PublicKey, len(cp.keys[0])) - for idx, keyPairShard := 
range cp.keys[0] { - pubKeys[idx] = keyPairShard.pk - } - - for i := 0; i < nodesPerShard; i++ { - testNodeObject := &testNode{ - shardId: uint32(0), - } - - kp := cp.keys[0][i] - shardCoordinator, _ := sharding.NewMultiShardCoordinator(uint32(1), uint32(0)) - epochStartRegistrationHandler := notifier.NewEpochStartSubscriptionHandler() - bootStorer := integrationTests.CreateMemUnit() - consensusCache, _ := lrucache.NewCache(10000) - - argumentsNodesCoordinator := nodesCoordinator.ArgNodesCoordinator{ - ShardConsensusGroupSize: consensusSize, - MetaConsensusGroupSize: 1, - Marshalizer: integrationTests.TestMarshalizer, - Hasher: createHasher(consensusType), - Shuffler: nodeShuffler, - EpochStartNotifier: epochStartRegistrationHandler, - BootStorer: bootStorer, - NbShards: 1, - EligibleNodes: eligibleMap, - WaitingNodes: waitingMap, - SelfPublicKey: []byte(strconv.Itoa(i)), - ConsensusGroupCache: consensusCache, - ShuffledOutHandler: &mock.ShuffledOutHandlerStub{}, - ChanStopNode: endProcess.GetDummyEndProcessChannel(), - NodeTypeProvider: &nodeTypeProviderMock.NodeTypeProviderStub{}, - IsFullArchive: false, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{ - IsWaitingListFixFlagEnabledField: true, - }, - ValidatorInfoCacher: &vic.ValidatorInfoCacherStub{}, - } - nodesCoord, _ := nodesCoordinator.NewIndexHashedNodesCoordinator(argumentsNodesCoordinator) - - n, mes, blkProcessor, blkc := createConsensusOnlyNode( - shardCoordinator, - nodesCoord, - testNodeObject.shardId, - uint32(i), - uint32(consensusSize), - roundTime, - kp.sk, - pubKeys, - cp.keyGen, - consensusType, - epochStartRegistrationHandler, - ) - - testNodeObject.node = n - testNodeObject.sk = kp.sk - testNodeObject.messenger = mes - testNodeObject.pk = kp.pk - testNodeObject.blkProcessor = blkProcessor - testNodeObject.blkc = blkc - - nodesList[i] = testNodeObject - connectableNodes = append(connectableNodes, &messengerWrapper{mes}) - } - nodes[0] = nodesList - - integrationTests.ConnectNodes(connectableNodes) - - return nodes -} diff --git a/integrationTests/factory/componentsHelper.go b/integrationTests/factory/componentsHelper.go index e9a348e425a..6c1138d7614 100644 --- a/integrationTests/factory/componentsHelper.go +++ b/integrationTests/factory/componentsHelper.go @@ -98,5 +98,6 @@ func createConfigurationsPathsHolder() *config.ConfigurationPathsHolder { SmartContracts: GenesisSmartContracts, ValidatorKey: ValidatorKeyPemPath, ApiRoutes: "", + P2pKey: P2pKeyPath, } } diff --git a/integrationTests/factory/consensusComponents/consensusComponents_test.go b/integrationTests/factory/consensusComponents/consensusComponents_test.go index 826947140f6..d34f2d085d9 100644 --- a/integrationTests/factory/consensusComponents/consensusComponents_test.go +++ b/integrationTests/factory/consensusComponents/consensusComponents_test.go @@ -8,7 +8,7 @@ import ( "github.com/ElrondNetwork/elrond-go-core/data/endProcess" "github.com/ElrondNetwork/elrond-go/common/forking" "github.com/ElrondNetwork/elrond-go/dataRetriever" - mainFactory "github.com/ElrondNetwork/elrond-go/factory" + bootstrapComp "github.com/ElrondNetwork/elrond-go/factory/bootstrap" "github.com/ElrondNetwork/elrond-go/integrationTests/factory" "github.com/ElrondNetwork/elrond-go/node" "github.com/ElrondNetwork/elrond-go/testscommon/goroutines" @@ -46,11 +46,11 @@ func TestConsensusComponents_Close_ShouldWork(t *testing.T) { require.Nil(t, err) managedStateComponents, err := nr.CreateManagedStateComponents(managedCoreComponents, managedBootstrapComponents, 
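The componentsHelper.go hunk above wires the new p2p key file into the test configuration paths. A hypothetical, self-contained illustration of the resulting holder; the struct and field names come from the hunk, while the literal path values stand in for the test constants (the P2pKeyPath constant added further down resolves to "../p2pKey.pem"):

package configsketch

import "github.com/ElrondNetwork/elrond-go/config"

// newPathsHolder sketches the post-change configuration: P2pKey now points at
// a dedicated p2p key PEM file alongside the validator key.
func newPathsHolder() *config.ConfigurationPathsHolder {
	return &config.ConfigurationPathsHolder{
		SmartContracts: "genesisSmartContracts.json", // stand-in for GenesisSmartContracts
		ValidatorKey:   "./validatorKey.pem",         // stand-in for ValidatorKeyPemPath
		ApiRoutes:      "",
		P2pKey:         "../p2pKey.pem", // value of the new P2pKeyPath constant
	}
}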
managedDataComponents) require.Nil(t, err) - nodesShufflerOut, err := mainFactory.CreateNodesShuffleOut(managedCoreComponents.GenesisNodesSetup(), configs.GeneralConfig.EpochStartConfig, managedCoreComponents.ChanStopNodeProcess()) + nodesShufflerOut, err := bootstrapComp.CreateNodesShuffleOut(managedCoreComponents.GenesisNodesSetup(), configs.GeneralConfig.EpochStartConfig, managedCoreComponents.ChanStopNodeProcess()) require.Nil(t, err) storer, err := managedDataComponents.StorageService().GetStorer(dataRetriever.BootstrapUnit) require.Nil(t, err) - nodesCoordinator, err := mainFactory.CreateNodesCoordinator( + nodesCoordinator, err := bootstrapComp.CreateNodesCoordinator( nodesShufflerOut, managedCoreComponents.GenesisNodesSetup(), configs.PreferencesConfig.Preferences, diff --git a/integrationTests/factory/constants.go b/integrationTests/factory/constants.go index b392bf32e05..846445f0c46 100644 --- a/integrationTests/factory/constants.go +++ b/integrationTests/factory/constants.go @@ -19,4 +19,5 @@ const ( Version = "v1.1.6.1-0-gbae61225f/go1.14.2/linux-amd64/a72b5f2eff" WorkingDir = "workingDir" RoundActivationPath = "enableRounds.toml" + P2pKeyPath = "../p2pKey.pem" ) diff --git a/integrationTests/factory/processComponents/processComponents_test.go b/integrationTests/factory/processComponents/processComponents_test.go index 6581c618da2..6338e02bf8c 100644 --- a/integrationTests/factory/processComponents/processComponents_test.go +++ b/integrationTests/factory/processComponents/processComponents_test.go @@ -8,7 +8,7 @@ import ( "github.com/ElrondNetwork/elrond-go-core/data/endProcess" "github.com/ElrondNetwork/elrond-go/common/forking" "github.com/ElrondNetwork/elrond-go/dataRetriever" - mainFactory "github.com/ElrondNetwork/elrond-go/factory" + bootstrapComp "github.com/ElrondNetwork/elrond-go/factory/bootstrap" "github.com/ElrondNetwork/elrond-go/integrationTests/factory" "github.com/ElrondNetwork/elrond-go/node" "github.com/ElrondNetwork/elrond-go/testscommon/goroutines" @@ -47,11 +47,11 @@ func TestProcessComponents_Close_ShouldWork(t *testing.T) { require.Nil(t, err) managedStateComponents, err := nr.CreateManagedStateComponents(managedCoreComponents, managedBootstrapComponents, managedDataComponents) require.Nil(t, err) - nodesShufflerOut, err := mainFactory.CreateNodesShuffleOut(managedCoreComponents.GenesisNodesSetup(), configs.GeneralConfig.EpochStartConfig, managedCoreComponents.ChanStopNodeProcess()) + nodesShufflerOut, err := bootstrapComp.CreateNodesShuffleOut(managedCoreComponents.GenesisNodesSetup(), configs.GeneralConfig.EpochStartConfig, managedCoreComponents.ChanStopNodeProcess()) require.Nil(t, err) storer, err := managedDataComponents.StorageService().GetStorer(dataRetriever.BootstrapUnit) require.Nil(t, err) - nodesCoordinator, err := mainFactory.CreateNodesCoordinator( + nodesCoordinator, err := bootstrapComp.CreateNodesCoordinator( nodesShufflerOut, managedCoreComponents.GenesisNodesSetup(), configs.PreferencesConfig.Preferences, diff --git a/integrationTests/factory/statusComponents/statusComponents_test.go b/integrationTests/factory/statusComponents/statusComponents_test.go index 5f167c8291d..d37ef738141 100644 --- a/integrationTests/factory/statusComponents/statusComponents_test.go +++ b/integrationTests/factory/statusComponents/statusComponents_test.go @@ -8,7 +8,7 @@ import ( "github.com/ElrondNetwork/elrond-go-core/data/endProcess" "github.com/ElrondNetwork/elrond-go/common/forking" "github.com/ElrondNetwork/elrond-go/dataRetriever" - mainFactory 
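// createConfigurationsPathsHolder above gains the new P2pKey entry alongside
// the existing validator key path. A trimmed sketch of the holder, limited to
// the fields visible in this diff (all other fields elided; the function name
// is illustrative):
func testPathsHolderSketch() *config.ConfigurationPathsHolder {
	return &config.ConfigurationPathsHolder{
		SmartContracts: GenesisSmartContracts,
		ValidatorKey:   ValidatorKeyPemPath,
		ApiRoutes:      "",
		P2pKey:         P2pKeyPath, // resolves to "../p2pKey.pem" per constants.go above
	}
}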
"github.com/ElrondNetwork/elrond-go/factory" + bootstrapComp "github.com/ElrondNetwork/elrond-go/factory/bootstrap" "github.com/ElrondNetwork/elrond-go/integrationTests/factory" "github.com/ElrondNetwork/elrond-go/node" "github.com/ElrondNetwork/elrond-go/testscommon/goroutines" @@ -47,11 +47,11 @@ func TestStatusComponents_Create_Close_ShouldWork(t *testing.T) { require.Nil(t, err) managedStateComponents, err := nr.CreateManagedStateComponents(managedCoreComponents, managedBootstrapComponents, managedDataComponents) require.Nil(t, err) - nodesShufflerOut, err := mainFactory.CreateNodesShuffleOut(managedCoreComponents.GenesisNodesSetup(), configs.GeneralConfig.EpochStartConfig, managedCoreComponents.ChanStopNodeProcess()) + nodesShufflerOut, err := bootstrapComp.CreateNodesShuffleOut(managedCoreComponents.GenesisNodesSetup(), configs.GeneralConfig.EpochStartConfig, managedCoreComponents.ChanStopNodeProcess()) require.Nil(t, err) storer, err := managedDataComponents.StorageService().GetStorer(dataRetriever.BootstrapUnit) require.Nil(t, err) - nodesCoordinator, err := mainFactory.CreateNodesCoordinator( + nodesCoordinator, err := bootstrapComp.CreateNodesCoordinator( nodesShufflerOut, managedCoreComponents.GenesisNodesSetup(), configs.PreferencesConfig.Preferences, diff --git a/integrationTests/frontend/staking/staking_test.go b/integrationTests/frontend/staking/staking_test.go index e88f1cf2a5b..ac7320f3100 100644 --- a/integrationTests/frontend/staking/staking_test.go +++ b/integrationTests/frontend/staking/staking_test.go @@ -8,8 +8,6 @@ import ( "github.com/ElrondNetwork/elrond-go-core/core" "github.com/ElrondNetwork/elrond-go-core/data/block" - ed25519SingleSig "github.com/ElrondNetwork/elrond-go-crypto/signing/ed25519/singlesig" - "github.com/ElrondNetwork/elrond-go-crypto/signing/mcl/singlesig" logger "github.com/ElrondNetwork/elrond-go-logger" "github.com/ElrondNetwork/elrond-go/integrationTests" "github.com/ElrondNetwork/elrond-go/process" @@ -38,8 +36,8 @@ func TestSignatureOnStaking(t *testing.T) { require.Nil(t, err) stakingWalletAccount := &integrationTests.TestWalletAccount{ - SingleSigner: &ed25519SingleSig.Ed25519Signer{}, - BlockSingleSigner: &singlesig.BlsSingleSigner{}, + SingleSigner: integrationTests.TestSingleSigner, + BlockSingleSigner: integrationTests.TestSingleBlsSigner, SkTxSign: skStaking, PkTxSign: pkStaking, PkTxSignBytes: pkBuff, diff --git a/integrationTests/frontend/wallet/dataField_test.go b/integrationTests/frontend/wallet/dataField_test.go index 51b2c0b251c..cd7beaf075f 100644 --- a/integrationTests/frontend/wallet/dataField_test.go +++ b/integrationTests/frontend/wallet/dataField_test.go @@ -14,14 +14,14 @@ import ( "github.com/ElrondNetwork/elrond-go-crypto" "github.com/ElrondNetwork/elrond-go-crypto/signing" "github.com/ElrondNetwork/elrond-go-crypto/signing/ed25519" - "github.com/ElrondNetwork/elrond-go-crypto/signing/ed25519/singlesig" + "github.com/ElrondNetwork/elrond-go/integrationTests" "github.com/stretchr/testify/assert" ) func TestTxDataFieldContainingUTF8Characters(t *testing.T) { suite := ed25519.NewEd25519() keyGen := signing.NewKeyGenerator(suite) - singleSigner := &singlesig.Ed25519Signer{} + singleSigner := integrationTests.TestSingleSigner sk, pk := keyGen.GeneratePair() pkBytes, _ := pk.ToByteArray() diff --git a/integrationTests/longTests/storage/storagePutRemove_test.go b/integrationTests/longTests/storage/storagePutRemove_test.go index 041a7d17f84..0d714ad0ab7 100644 --- a/integrationTests/longTests/storage/storagePutRemove_test.go 
+++ b/integrationTests/longTests/storage/storagePutRemove_test.go @@ -7,8 +7,8 @@ import ( logger "github.com/ElrondNetwork/elrond-go-logger" "github.com/ElrondNetwork/elrond-go/storage" - "github.com/ElrondNetwork/elrond-go/storage/leveldb" - "github.com/ElrondNetwork/elrond-go/storage/storageUnit" + "github.com/ElrondNetwork/elrond-go/storage/database" + "github.com/ElrondNetwork/elrond-go/storage/storageunit" "github.com/stretchr/testify/assert" ) @@ -17,17 +17,17 @@ var log = logger.GetOrCreate("integrationTests/longTests/storage") func TestPutRemove(t *testing.T) { t.Skip("this is a long test") - cache, _ := storageUnit.NewCache(storageUnit.CacheConfig{Type: storageUnit.LRUCache, Capacity: 5000, Shards: 16, SizeInBytes: 0}) + cache, _ := storageunit.NewCache(storageunit.CacheConfig{Type: storageunit.LRUCache, Capacity: 5000, Shards: 16, SizeInBytes: 0}) dir := t.TempDir() log.Info("opened in", "directory", dir) - lvdb1, err := leveldb.NewDB(dir, 2, 1000, 10) + lvdb1, err := database.NewLevelDB(dir, 2, 1000, 10) assert.NoError(t, err) defer func() { _ = lvdb1.Close() }() - store, err := storageUnit.NewStorageUnit(cache, lvdb1) + store, err := storageunit.NewStorageUnit(cache, lvdb1) log.LogIfError(err) numPuts := 800 diff --git a/integrationTests/mock/blockProcessorMock.go b/integrationTests/mock/blockProcessorMock.go index a85851ba42a..f9c14d7bbf2 100644 --- a/integrationTests/mock/blockProcessorMock.go +++ b/integrationTests/mock/blockProcessorMock.go @@ -39,37 +39,63 @@ func (bpm *BlockProcessorMock) RestoreLastNotarizedHrdsToGenesis() { // ProcessBlock mocks processing a block func (bpm *BlockProcessorMock) ProcessBlock(header data.HeaderHandler, body data.BodyHandler, haveTime func() time.Duration) error { - return bpm.ProcessBlockCalled(header, body, haveTime) + if bpm.ProcessBlockCalled != nil { + return bpm.ProcessBlockCalled(header, body, haveTime) + } + + return nil } // ProcessScheduledBlock mocks processing a scheduled block func (bpm *BlockProcessorMock) ProcessScheduledBlock(header data.HeaderHandler, body data.BodyHandler, haveTime func() time.Duration) error { - return bpm.ProcessScheduledBlockCalled(header, body, haveTime) + if bpm.ProcessScheduledBlockCalled != nil { + return bpm.ProcessScheduledBlockCalled(header, body, haveTime) + } + + return nil } // CommitBlock mocks the commit of a block func (bpm *BlockProcessorMock) CommitBlock(header data.HeaderHandler, body data.BodyHandler) error { - return bpm.CommitBlockCalled(header, body) + if bpm.CommitBlockCalled != nil { + return bpm.CommitBlockCalled(header, body) + } + + return nil } // RevertCurrentBlock mocks revert of the current block func (bpm *BlockProcessorMock) RevertCurrentBlock() { - bpm.RevertCurrentBlockCalled() + if bpm.RevertCurrentBlockCalled != nil { + bpm.RevertCurrentBlockCalled() + } } // CreateNewHeader - func (bpm *BlockProcessorMock) CreateNewHeader(round uint64, nonce uint64) (data.HeaderHandler, error) { - return bpm.CreateNewHeaderCalled(round, nonce) + if bpm.CreateNewHeaderCalled != nil { + return bpm.CreateNewHeaderCalled(round, nonce) + } + + return nil, nil } // CreateBlock - func (bpm *BlockProcessorMock) CreateBlock(initialHdrData data.HeaderHandler, haveTime func() bool) (data.HeaderHandler, data.BodyHandler, error) { - return bpm.CreateBlockCalled(initialHdrData, haveTime) + if bpm.CreateBlockCalled != nil { + return bpm.CreateBlockCalled(initialHdrData, haveTime) + } + + return nil, nil, nil } // RestoreBlockIntoPools - func (bpm *BlockProcessorMock) RestoreBlockIntoPools(header 
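// Every BlockProcessorMock method above is reworked to the same defensive
// shape: invoke the injected function only when the test actually set it,
// otherwise return zero values instead of panicking on a nil callback. The
// same pattern on a reduced, hypothetical stub (the names here are
// illustrative, not part of this diff):
type actionStub struct {
	DoActionCalled func(payload []byte) error
}

// DoAction runs the configured callback when present and succeeds silently
// otherwise, so tests only wire the behaviour they assert on.
func (as *actionStub) DoAction(payload []byte) error {
	if as.DoActionCalled != nil {
		return as.DoActionCalled(payload)
	}

	return nil
}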
data.HeaderHandler, body data.BodyHandler) error { - return bpm.RestoreBlockIntoPoolsCalled(header, body) + if bpm.RestoreBlockIntoPoolsCalled != nil { + return bpm.RestoreBlockIntoPoolsCalled(header, body) + } + + return nil } // RestoreBlockBodyIntoPools - diff --git a/integrationTests/mock/countingDB.go b/integrationTests/mock/countingDB.go index 62bcd87dc03..a7228998e32 100644 --- a/integrationTests/mock/countingDB.go +++ b/integrationTests/mock/countingDB.go @@ -2,19 +2,19 @@ package mock import ( "github.com/ElrondNetwork/elrond-go/storage" - "github.com/ElrondNetwork/elrond-go/storage/memorydb" + "github.com/ElrondNetwork/elrond-go/storage/database" ) var _ storage.Persister = (*countingDB)(nil) type countingDB struct { - db *memorydb.DB + db *database.MemDB nrOfPut int } // NewCountingDB returns a new instance of countingDB func NewCountingDB() *countingDB { - return &countingDB{memorydb.New(), 0} + return &countingDB{database.NewMemDB(), 0} } // Put will add the given key-value pair in the db diff --git a/integrationTests/mock/cryptoComponentsStub.go b/integrationTests/mock/cryptoComponentsStub.go index 77f0619721f..b30fce753fc 100644 --- a/integrationTests/mock/cryptoComponentsStub.go +++ b/integrationTests/mock/cryptoComponentsStub.go @@ -1,9 +1,11 @@ package mock import ( + "errors" "sync" "github.com/ElrondNetwork/elrond-go-crypto" + cryptoCommon "github.com/ElrondNetwork/elrond-go/common/crypto" "github.com/ElrondNetwork/elrond-go/consensus" "github.com/ElrondNetwork/elrond-go/heartbeat" "github.com/ElrondNetwork/elrond-go/vm" @@ -18,7 +20,7 @@ type CryptoComponentsStub struct { PubKeyBytes []byte BlockSig crypto.SingleSigner TxSig crypto.SingleSigner - MultiSig crypto.MultiSigner + MultiSigContainer cryptoCommon.MultiSignerContainer PeerSignHandler crypto.PeerSignatureHandler BlKeyGen crypto.KeyGenerator TxKeyGen crypto.KeyGenerator @@ -78,31 +80,43 @@ func (ccs *CryptoComponentsStub) TxSingleSigner() crypto.SingleSigner { return ccs.TxSig } -// MultiSigner - -func (ccs *CryptoComponentsStub) MultiSigner() crypto.MultiSigner { +// PeerSignatureHandler - +func (ccs *CryptoComponentsStub) PeerSignatureHandler() crypto.PeerSignatureHandler { ccs.mutMultiSig.RLock() defer ccs.mutMultiSig.RUnlock() - return ccs.MultiSig + return ccs.PeerSignHandler } -// PeerSignatureHandler - -func (ccs *CryptoComponentsStub) PeerSignatureHandler() crypto.PeerSignatureHandler { +// MultiSignerContainer - +func (ccs *CryptoComponentsStub) MultiSignerContainer() cryptoCommon.MultiSignerContainer { ccs.mutMultiSig.RLock() defer ccs.mutMultiSig.RUnlock() - return ccs.PeerSignHandler + return ccs.MultiSigContainer } -// SetMultiSigner - -func (ccs *CryptoComponentsStub) SetMultiSigner(ms crypto.MultiSigner) error { +// SetMultiSignerContainer - +func (ccs *CryptoComponentsStub) SetMultiSignerContainer(ms cryptoCommon.MultiSignerContainer) error { ccs.mutMultiSig.Lock() - ccs.MultiSig = ms + ccs.MultiSigContainer = ms ccs.mutMultiSig.Unlock() return nil } +// GetMultiSigner - +func (ccs *CryptoComponentsStub) GetMultiSigner(epoch uint32) (crypto.MultiSigner, error) { + ccs.mutMultiSig.RLock() + defer ccs.mutMultiSig.RUnlock() + + if ccs.MultiSigContainer == nil { + return nil, errors.New("nil multi sig container") + } + + return ccs.MultiSigContainer.GetMultiSigner(epoch) +} + // BlockSignKeyGen - func (ccs *CryptoComponentsStub) BlockSignKeyGen() crypto.KeyGenerator { return ccs.BlKeyGen @@ -138,7 +152,7 @@ func (ccs *CryptoComponentsStub) Clone() interface{} { PubKeyBytes: ccs.PubKeyBytes, BlockSig: 
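// CryptoComponentsStub now hands out multi-signers per epoch through a
// cryptoCommon.MultiSignerContainer, as GetMultiSigner above shows. A minimal
// sketch of a container that always serves one signer, enough for
// single-epoch tests (an assumption sketch: only GetMultiSigner is confirmed
// by this diff, and IsInterfaceNil follows the repository-wide convention;
// imports: errors and elrond-go-crypto):
type fixedMultiSignerContainer struct {
	multiSigner crypto.MultiSigner
}

// GetMultiSigner ignores the epoch and returns the single configured signer.
func (fc *fixedMultiSignerContainer) GetMultiSigner(_ uint32) (crypto.MultiSigner, error) {
	if fc.multiSigner == nil {
		return nil, errors.New("nil multi signer")
	}

	return fc.multiSigner, nil
}

// IsInterfaceNil returns true if there is no value under the interface
func (fc *fixedMultiSignerContainer) IsInterfaceNil() bool {
	return fc == nil
}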
ccs.BlockSig, TxSig: ccs.TxSig, - MultiSig: ccs.MultiSig, + MultiSigContainer: ccs.MultiSigContainer, PeerSignHandler: ccs.PeerSignHandler, BlKeyGen: ccs.BlKeyGen, TxKeyGen: ccs.TxKeyGen, diff --git a/integrationTests/mock/triesHolderStub.go b/integrationTests/mock/triesHolderStub.go deleted file mode 100644 index 085a027cea3..00000000000 --- a/integrationTests/mock/triesHolderStub.go +++ /dev/null @@ -1,56 +0,0 @@ -package mock - -import ( - "github.com/ElrondNetwork/elrond-go/common" -) - -// TriesHolderStub - -type TriesHolderStub struct { - PutCalled func([]byte, common.Trie) - RemoveCalled func([]byte, common.Trie) - GetCalled func([]byte) common.Trie - GetAllCalled func() []common.Trie - ResetCalled func() -} - -// Put - -func (ths *TriesHolderStub) Put(key []byte, trie common.Trie) { - if ths.PutCalled != nil { - ths.PutCalled(key, trie) - } -} - -// Replace - -func (ths *TriesHolderStub) Replace(key []byte, trie common.Trie) { - if ths.RemoveCalled != nil { - ths.RemoveCalled(key, trie) - } -} - -// Get - -func (ths *TriesHolderStub) Get(key []byte) common.Trie { - if ths.GetCalled != nil { - return ths.GetCalled(key) - } - return nil -} - -// GetAll - -func (ths *TriesHolderStub) GetAll() []common.Trie { - if ths.GetAllCalled != nil { - return ths.GetAllCalled() - } - return nil -} - -// Reset - -func (ths *TriesHolderStub) Reset() { - if ths.ResetCalled != nil { - ths.ResetCalled() - } -} - -// IsInterfaceNil returns true if there is no value under the interface -func (ths *TriesHolderStub) IsInterfaceNil() bool { - return ths == nil -} diff --git a/integrationTests/multiShard/block/interceptedHeadersSigVerification/interceptedHeadersSigVerification_test.go b/integrationTests/multiShard/block/interceptedHeadersSigVerification/interceptedHeadersSigVerification_test.go index 83cade83727..9d472cc8dba 100644 --- a/integrationTests/multiShard/block/interceptedHeadersSigVerification/interceptedHeadersSigVerification_test.go +++ b/integrationTests/multiShard/block/interceptedHeadersSigVerification/interceptedHeadersSigVerification_test.go @@ -11,7 +11,6 @@ import ( "github.com/ElrondNetwork/elrond-go-crypto" "github.com/ElrondNetwork/elrond-go-crypto/signing" "github.com/ElrondNetwork/elrond-go-crypto/signing/mcl" - mclsinglesig "github.com/ElrondNetwork/elrond-go-crypto/signing/mcl/singlesig" "github.com/ElrondNetwork/elrond-go/integrationTests" "github.com/stretchr/testify/assert" ) @@ -27,7 +26,7 @@ func TestInterceptedShardBlockHeaderVerifiedWithCorrectConsensusGroup(t *testing nbMetaNodes := 4 nbShards := 1 consensusGroupSize := 3 - singleSigner := &mclsinglesig.BlsSingleSigner{} + singleSigner := integrationTests.TestSingleBlsSigner // create map of shard - testNodeProcessors for metachain and shard chain nodesMap := integrationTests.CreateNodesWithNodesCoordinator( @@ -166,7 +165,7 @@ func TestInterceptedShardBlockHeaderWithLeaderSignatureAndRandSeedChecks(t *test nbShards := 1 consensusGroupSize := 3 - singleSigner := &mclsinglesig.BlsSingleSigner{} + singleSigner := integrationTests.TestSingleBlsSigner keyGen := signing.NewKeyGenerator(mcl.NewSuiteBLS12()) // create map of shard - testNodeProcessors for metachain and shard chain nodesMap := integrationTests.CreateNodesWithNodesCoordinatorKeygenAndSingleSigner( @@ -239,7 +238,7 @@ func TestInterceptedShardHeaderBlockWithWrongPreviousRandSeedShouldNotBeAccepted nbShards := 1 consensusGroupSize := 3 - singleSigner := &mclsinglesig.BlsSingleSigner{} + singleSigner := integrationTests.TestSingleBlsSigner keyGen := 
signing.NewKeyGenerator(mcl.NewSuiteBLS12()) // create map of shard - testNodeProcessors for metachain and shard chain nodesMap := integrationTests.CreateNodesWithNodesCoordinatorKeygenAndSingleSigner( diff --git a/integrationTests/multiShard/endOfEpoch/startInEpoch/startInEpoch_test.go b/integrationTests/multiShard/endOfEpoch/startInEpoch/startInEpoch_test.go index 281f073524d..aa01a12389e 100644 --- a/integrationTests/multiShard/endOfEpoch/startInEpoch/startInEpoch_test.go +++ b/integrationTests/multiShard/endOfEpoch/startInEpoch/startInEpoch_test.go @@ -27,7 +27,7 @@ import ( "github.com/ElrondNetwork/elrond-go/sharding" "github.com/ElrondNetwork/elrond-go/sharding/nodesCoordinator" "github.com/ElrondNetwork/elrond-go/storage/factory" - "github.com/ElrondNetwork/elrond-go/storage/storageUnit" + "github.com/ElrondNetwork/elrond-go/storage/storageunit" "github.com/ElrondNetwork/elrond-go/testscommon" epochNotifierMock "github.com/ElrondNetwork/elrond-go/testscommon/epochNotifier" "github.com/ElrondNetwork/elrond-go/testscommon/genericMocks" @@ -352,14 +352,14 @@ func getBootstrapper(shardID uint32, baseArgs storageBootstrap.ArgsBaseStorageBo func getGeneralConfig() config.Config { generalConfig := testscommon.GetGeneralConfig() - generalConfig.MiniBlocksStorage.DB.Type = string(storageUnit.LvlDBSerial) - generalConfig.ShardHdrNonceHashStorage.DB.Type = string(storageUnit.LvlDBSerial) - generalConfig.MetaBlockStorage.DB.Type = string(storageUnit.LvlDBSerial) - generalConfig.MetaHdrNonceHashStorage.DB.Type = string(storageUnit.LvlDBSerial) - generalConfig.BlockHeaderStorage.DB.Type = string(storageUnit.LvlDBSerial) - generalConfig.BootstrapStorage.DB.Type = string(storageUnit.LvlDBSerial) - generalConfig.ReceiptsStorage.DB.Type = string(storageUnit.LvlDBSerial) - generalConfig.ScheduledSCRsStorage.DB.Type = string(storageUnit.LvlDBSerial) + generalConfig.MiniBlocksStorage.DB.Type = string(storageunit.LvlDBSerial) + generalConfig.ShardHdrNonceHashStorage.DB.Type = string(storageunit.LvlDBSerial) + generalConfig.MetaBlockStorage.DB.Type = string(storageunit.LvlDBSerial) + generalConfig.MetaHdrNonceHashStorage.DB.Type = string(storageunit.LvlDBSerial) + generalConfig.BlockHeaderStorage.DB.Type = string(storageunit.LvlDBSerial) + generalConfig.BootstrapStorage.DB.Type = string(storageunit.LvlDBSerial) + generalConfig.ReceiptsStorage.DB.Type = string(storageunit.LvlDBSerial) + generalConfig.ScheduledSCRsStorage.DB.Type = string(storageunit.LvlDBSerial) return generalConfig } diff --git a/integrationTests/multiShard/hardFork/hardFork_test.go b/integrationTests/multiShard/hardFork/hardFork_test.go index 68237e1f7df..3f708fd1f92 100644 --- a/integrationTests/multiShard/hardFork/hardFork_test.go +++ b/integrationTests/multiShard/hardFork/hardFork_test.go @@ -22,6 +22,7 @@ import ( "github.com/ElrondNetwork/elrond-go/integrationTests/vm/arwen" vmFactory "github.com/ElrondNetwork/elrond-go/process/factory" "github.com/ElrondNetwork/elrond-go/state" + "github.com/ElrondNetwork/elrond-go/testscommon/cryptoMocks" "github.com/ElrondNetwork/elrond-go/testscommon/genesisMocks" "github.com/ElrondNetwork/elrond-go/update/factory" "github.com/ElrondNetwork/elrond-go/vm/systemSmartContracts/defaults" @@ -564,7 +565,7 @@ func createHardForkExporter( cryptoComponents := integrationTests.GetDefaultCryptoComponents() cryptoComponents.BlockSig = node.OwnAccount.BlockSingleSigner cryptoComponents.TxSig = node.OwnAccount.SingleSigner - cryptoComponents.MultiSig = node.MultiSigner + cryptoComponents.MultiSigContainer = 
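// getGeneralConfig above re-points every DB type at the renamed storageunit
// constants, and the long-running storage test earlier switches to
// database.NewLevelDB. A compact sketch of the renamed wiring, reusing the
// constructors and numeric parameters that appear in this diff (the function
// name and the storage.Storer return type are assumptions):
func newLevelDBStorerSketch(dir string) (storage.Storer, error) {
	cacher, err := storageunit.NewCache(storageunit.CacheConfig{
		Type:        storageunit.LRUCache,
		Capacity:    5000,
		Shards:      16,
		SizeInBytes: 0,
	})
	if err != nil {
		return nil, err
	}

	// database.NewLevelDB replaces the old leveldb.NewDB constructor
	db, err := database.NewLevelDB(dir, 2, 1000, 10)
	if err != nil {
		return nil, err
	}

	unit, err := storageunit.NewStorageUnit(cacher, db)
	if err != nil {
		return nil, err
	}

	return unit, nil
}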
cryptoMocks.NewMultiSignerContainerMock(node.MultiSigner) cryptoComponents.BlKeyGen = node.OwnAccount.KeygenBlockSign cryptoComponents.TxKeyGen = node.OwnAccount.KeygenTxSign diff --git a/integrationTests/multiShard/relayedTx/relayedTx_test.go b/integrationTests/multiShard/relayedTx/relayedTx_test.go index 8166c128505..4af579039a0 100644 --- a/integrationTests/multiShard/relayedTx/relayedTx_test.go +++ b/integrationTests/multiShard/relayedTx/relayedTx_test.go @@ -408,7 +408,7 @@ func CheckAddressHasTokens( func getESDTDataFromKey(userAcnt state.UserAccountHandler, key []byte) (*esdt.ESDigitalToken, error) { esdtData := &esdt.ESDigitalToken{Value: big.NewInt(0)} - marshaledData, err := userAcnt.DataTrieTracker().RetrieveValue(key) + marshaledData, err := userAcnt.RetrieveValue(key) if err != nil { return esdtData, nil } diff --git a/integrationTests/multiShard/smartContract/dns/dns_test.go b/integrationTests/multiShard/smartContract/dns/dns_test.go index 70896e105d9..78435ea7441 100644 --- a/integrationTests/multiShard/smartContract/dns/dns_test.go +++ b/integrationTests/multiShard/smartContract/dns/dns_test.go @@ -12,6 +12,7 @@ import ( "github.com/ElrondNetwork/elrond-go-core/data/api" "github.com/ElrondNetwork/elrond-go-core/hashing/keccak" + "github.com/ElrondNetwork/elrond-go/common" "github.com/ElrondNetwork/elrond-go/genesis" "github.com/ElrondNetwork/elrond-go/integrationTests" "github.com/ElrondNetwork/elrond-go/integrationTests/multiShard/relayedTx" @@ -299,7 +300,7 @@ func checkUserNamesAreDeleted( dnsAcc, _ := acnt.(state.UserAccountHandler) keyFromTrie := "value_state" + string(keccak.NewKeccak().Compute(userName)) - value, err := dnsAcc.DataTrie().Get([]byte(keyFromTrie)) + value, err := dnsAcc.DataTrie().(common.Trie).Get([]byte(keyFromTrie)) assert.Nil(t, err) assert.Nil(t, value) } diff --git a/integrationTests/multiShard/smartContract/polynetworkbridge/bridge_test.go b/integrationTests/multiShard/smartContract/polynetworkbridge/bridge_test.go index 39253c0efd9..12c65306fdf 100644 --- a/integrationTests/multiShard/smartContract/polynetworkbridge/bridge_test.go +++ b/integrationTests/multiShard/smartContract/polynetworkbridge/bridge_test.go @@ -28,8 +28,9 @@ func TestBridgeSetupAndBurn(t *testing.T) { numMetachainNodes := 1 enableEpochs := config.EnableEpochs{ - GlobalMintBurnDisableEpoch: integrationTests.UnreachableEpoch, - BuiltInFunctionOnMetaEnableEpoch: integrationTests.UnreachableEpoch, + GlobalMintBurnDisableEpoch: integrationTests.UnreachableEpoch, + BuiltInFunctionOnMetaEnableEpoch: integrationTests.UnreachableEpoch, + FixAsyncCallBackArgsListEnableEpoch: integrationTests.UnreachableEpoch, } nodes := integrationTests.CreateNodesWithEnableEpochs( numOfShards, @@ -141,7 +142,7 @@ func TestBridgeSetupAndBurn(t *testing.T) { func checkBurnedOnESDTContract(t *testing.T, nodes []*integrationTests.TestProcessorNode, tokenIdentifier []byte, burntValue *big.Int) { esdtSCAcc := getUserAccountWithAddress(t, vm.ESDTSCAddress, nodes) - retrievedData, _ := esdtSCAcc.DataTrieTracker().RetrieveValue(tokenIdentifier) + retrievedData, _ := esdtSCAcc.RetrieveValue(tokenIdentifier) tokenInSystemSC := &systemSmartContracts.ESDTDataV2{} _ = integrationTests.TestMarshalizer.Unmarshal(tokenInSystemSC, retrievedData) diff --git a/integrationTests/multiShard/transaction/txRouting/txRouting_test.go b/integrationTests/multiShard/transaction/txRouting/txRouting_test.go index fdbe87b4ffa..c2afef7aa0b 100644 --- a/integrationTests/multiShard/transaction/txRouting/txRouting_test.go +++ 
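// The account API change running through these tests flattens data-trie
// access: userAcnt.DataTrieTracker().RetrieveValue(key) becomes
// userAcnt.RetrieveValue(key), and the write side, SaveKeyValue, moves onto
// the account the same way in the validatorToDelegation test below. A sketch
// of the new read path, modelled on getESDTDataFromKey above (imports:
// math/big, elrond-go-core data/esdt and marshal, state; the function name is
// illustrative):
func readESDTSketch(acnt state.UserAccountHandler, key []byte, marshalizer marshal.Marshalizer) (*esdt.ESDigitalToken, error) {
	token := &esdt.ESDigitalToken{Value: big.NewInt(0)}

	raw, err := acnt.RetrieveValue(key) // was: acnt.DataTrieTracker().RetrieveValue(key)
	if err != nil {
		// mirror the helper above: a missing key yields an empty token
		return token, nil
	}

	if errUnmarshal := marshalizer.Unmarshal(token, raw); errUnmarshal != nil {
		return nil, errUnmarshal
	}

	return token, nil
}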
b/integrationTests/multiShard/transaction/txRouting/txRouting_test.go @@ -10,7 +10,6 @@ import ( "github.com/ElrondNetwork/elrond-go-core/core" "github.com/ElrondNetwork/elrond-go-core/data/transaction" "github.com/ElrondNetwork/elrond-go-crypto" - ed25519SingleSig "github.com/ElrondNetwork/elrond-go-crypto/signing/ed25519/singlesig" "github.com/ElrondNetwork/elrond-go/integrationTests" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -149,7 +148,7 @@ func generateTx(sender crypto.PrivateKey, receiver crypto.PublicKey, nonce uint6 Version: integrationTests.MinTransactionVersion, } marshalizedTxBeforeSigning, _ := tx.GetDataForSigning(integrationTests.TestAddressPubkeyConverter, integrationTests.TestTxSignMarshalizer) - signer := ed25519SingleSig.Ed25519Signer{} + signer := integrationTests.TestSingleSigner signature, _ := signer.Sign(sender, marshalizedTxBeforeSigning) tx.Signature = signature diff --git a/integrationTests/multiShard/validatorToDelegation/validatorToDelegation_test.go b/integrationTests/multiShard/validatorToDelegation/validatorToDelegation_test.go index 4529260a14a..1cf42464377 100644 --- a/integrationTests/multiShard/validatorToDelegation/validatorToDelegation_test.go +++ b/integrationTests/multiShard/validatorToDelegation/validatorToDelegation_test.go @@ -227,12 +227,12 @@ func jailNodes(nodes []*integrationTests.TestProcessorNode, blsKeys [][]byte) { stakingAcc := acc.(state.UserAccountHandler) for _, blsKey := range blsKeys { - marshaledData, _ := stakingAcc.DataTrieTracker().RetrieveValue(blsKey) + marshaledData, _ := stakingAcc.RetrieveValue(blsKey) stakingData := &systemSmartContracts.StakedDataV2_0{} _ = integrationTests.TestMarshalizer.Unmarshal(stakingData, marshaledData) stakingData.Jailed = true marshaledData, _ = integrationTests.TestMarshalizer.Marshal(stakingData) - _ = stakingAcc.DataTrieTracker().SaveKeyValue(blsKey, marshaledData) + _ = stakingAcc.SaveKeyValue(blsKey, marshaledData) } _ = node.AccntState.SaveAccount(stakingAcc) @@ -367,7 +367,7 @@ func testBLSKeyOwnerIsAddress(t *testing.T, nodes []*integrationTests.TestProces acnt, _ := n.AccntState.GetExistingAccount(vm.StakingSCAddress) userAcc, _ := acnt.(state.UserAccountHandler) - marshaledData, _ := userAcc.DataTrieTracker().RetrieveValue(blsKey) + marshaledData, _ := userAcc.RetrieveValue(blsKey) stakingData := &systemSmartContracts.StakedDataV2_0{} _ = integrationTests.TestMarshalizer.Unmarshal(stakingData, marshaledData) assert.Equal(t, stakingData.OwnerAddress, address) diff --git a/integrationTests/multisig/blsMultisig_test.go b/integrationTests/multisig/blsMultisig_test.go index ed7a001ddb8..31cef69fa84 100644 --- a/integrationTests/multisig/blsMultisig_test.go +++ b/integrationTests/multisig/blsMultisig_test.go @@ -1,8 +1,6 @@ package multisig import ( - "bytes" - "errors" "fmt" "testing" @@ -16,123 +14,37 @@ import ( "github.com/stretchr/testify/assert" ) -func createMultiSignersBls( - numOfSigners uint16, +func createKeysAndMultiSignerBls( grSize uint16, hasher hashing.Hasher, suite crypto.Suite, -) ([]string, []crypto.MultiSigner) { +) ([][]byte, [][]byte, crypto.MultiSigner) { kg := signing.NewKeyGenerator(suite) - - var pubKeyBytes []byte - - privKeys := make([]crypto.PrivateKey, grSize) - pubKeys := make([]crypto.PublicKey, grSize) - pubKeysStr := make([]string, grSize) + privKeys := make([][]byte, grSize) + pubKeys := make([][]byte, grSize) for i := uint16(0); i < grSize; i++ { sk, pk := kg.GeneratePair() - privKeys[i] = sk - pubKeys[i] = pk - - 
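// The signer consolidation continues here: tests drop their own
// ed25519SingleSig.Ed25519Signer / singlesig.BlsSingleSigner literals in
// favour of the shared integrationTests.TestSingleSigner and
// integrationTests.TestSingleBlsSigner instances. A minimal sketch of the
// ed25519 variant (function name illustrative; crypto is elrond-go-crypto):
func signWithSharedSignerSketch(sk crypto.PrivateKey, msg []byte) ([]byte, error) {
	// was: signer := ed25519SingleSig.Ed25519Signer{}
	signer := integrationTests.TestSingleSigner

	return signer.Sign(sk, msg)
}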
pubKeyBytes, _ = pk.ToByteArray() - pubKeysStr[i] = string(pubKeyBytes) + privKeys[i], _ = sk.ToByteArray() + pubKeys[i], _ = pk.ToByteArray() } - - multiSigners := make([]crypto.MultiSigner, numOfSigners) llSigner := &llsig.BlsMultiSigner{Hasher: hasher} - for i := uint16(0); i < numOfSigners; i++ { - multiSigners[i], _ = multisig.NewBLSMultisig(llSigner, pubKeysStr, privKeys[i], kg, i) - } + multiSigner, _ := multisig.NewBLSMultisig(llSigner, kg) - return pubKeysStr, multiSigners + return privKeys, pubKeys, multiSigner } -func createSignaturesShares(numOfSigners uint16, multiSigners []crypto.MultiSigner, message []byte) [][]byte { - sigShares := make([][]byte, numOfSigners) - for i := uint16(0); i < numOfSigners; i++ { - sigShares[i], _ = multiSigners[i].CreateSignatureShare(message, []byte("")) +func createSignaturesShares(privKeys [][]byte, multiSigner crypto.MultiSigner, message []byte) [][]byte { + sigShares := make([][]byte, len(privKeys)) + for i := uint16(0); i < uint16(len(privKeys)); i++ { + sigShares[i], _ = multiSigner.CreateSignatureShare(privKeys[i], message) } return sigShares } -func setSignatureSharesAllSignersBls(multiSigners []crypto.MultiSigner, sigs [][]byte) error { - grSize := uint16(len(multiSigners)) - var err error - - for i := uint16(0); i < grSize; i++ { - for j := uint16(0); j < grSize; j++ { - err = multiSigners[j].StoreSignatureShare(i, sigs[i]) - if err != nil { - return err - } - } - } - - return nil -} - -func verifySigAllSignersBls( - multiSigners []crypto.MultiSigner, - message []byte, - signature []byte, - pubKeys []string, - bitmap []byte, - grSize uint16, -) error { - - var err error - var muSig crypto.MultiSigner - - for i := uint16(0); i < grSize; i++ { - muSig, err = multiSigners[i].Create(pubKeys, i) - if err != nil { - return err - } - - multiSigners[i] = muSig - err = multiSigners[i].SetAggregatedSig(signature) - if err != nil { - return err - } - - err = multiSigners[i].Verify(message, bitmap) - if err != nil { - return err - } - } - - return nil -} - -func aggregateSignatureSharesAllSignersBls(multiSigners []crypto.MultiSigner, bitmap []byte, grSize uint16) ( - signature []byte, - err error, -) { - aggSig, err := multiSigners[0].AggregateSigs(bitmap) - - if err != nil { - return nil, err - } - - for i := uint16(1); i < grSize; i++ { - aggSig2, err1 := multiSigners[i].AggregateSigs(bitmap) - - if err1 != nil { - return nil, err1 - } - - if !bytes.Equal(aggSig, aggSig2) { - return nil, errors.New("aggregated signatures not equal") - } - } - - return aggSig, nil -} - func TestMultiSig_Bls(t *testing.T) { if testing.Short() { t.Skip("this is not a short test") @@ -140,37 +52,24 @@ func TestMultiSig_Bls(t *testing.T) { t.Parallel() - consensusGroupSize := uint16(6) - numOfSigners := uint16(6) + numSigners := uint16(6) message := "message" - - bitmapSize := consensusGroupSize/8 + 1 - // set bitmap to select all members - bitmap := make([]byte, bitmapSize) - byteMask := 0xFF - for i := uint16(0); i < bitmapSize; i++ { - bitmap[i] = byte((((1 << consensusGroupSize) - 1) >> i) & byteMask) - } - hashSize := 16 hasher, _ := blake2b.NewBlake2bWithSize(hashSize) suite := mcl.NewSuiteBLS12() - pubKeysStr, multiSigners := createMultiSignersBls(numOfSigners, consensusGroupSize, hasher, suite) + privKeys, pubKeys, multiSigner := createKeysAndMultiSignerBls(numSigners, hasher, suite) numOfTimesToRepeatTests := 100 for currentIdx := 0; currentIdx < numOfTimesToRepeatTests; currentIdx++ { message = fmt.Sprintf("%s%d", message, currentIdx) - signatures := 
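// The multisig refactoring above replaces one stateful signer per participant
// with a single shared crypto.MultiSigner operating on raw key bytes. An
// end-to-end sketch of the new flow, using only calls that appear in this
// test file (the hasher size of 16 matches the test; the function name is
// illustrative):
func blsAggregationSketch() error {
	hasher, err := blake2b.NewBlake2bWithSize(16)
	if err != nil {
		return err
	}

	kg := signing.NewKeyGenerator(mcl.NewSuiteBLS12())
	llSigner := &llsig.BlsMultiSigner{Hasher: hasher}
	multiSigner, err := multisig.NewBLSMultisig(llSigner, kg)
	if err != nil {
		return err
	}

	numSigners := 4
	privKeys := make([][]byte, numSigners)
	pubKeys := make([][]byte, numSigners)
	for i := 0; i < numSigners; i++ {
		sk, pk := kg.GeneratePair()
		privKeys[i], _ = sk.ToByteArray()
		pubKeys[i], _ = pk.ToByteArray()
	}

	// each participant signs with its own private key bytes
	message := []byte("message")
	sigShares := make([][]byte, numSigners)
	for i := 0; i < numSigners; i++ {
		sigShares[i], _ = multiSigner.CreateSignatureShare(privKeys[i], message)
	}

	// aggregation and verification need only the public keys and the shares
	aggSig, err := multiSigner.AggregateSigs(pubKeys, sigShares)
	if err != nil {
		return err
	}

	return multiSigner.VerifyAggregatedSig(pubKeys, message, aggSig)
}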
createSignaturesShares(numOfSigners, multiSigners, []byte(message)) - - err := setSignatureSharesAllSignersBls(multiSigners, signatures) - assert.Nil(t, err) + signatures := createSignaturesShares(privKeys, multiSigner, []byte(message)) - aggSig, err := aggregateSignatureSharesAllSignersBls(multiSigners, bitmap, consensusGroupSize) + aggSig, err := multiSigner.AggregateSigs(pubKeys, signatures) assert.Nil(t, err) assert.NotNil(t, aggSig) - err = verifySigAllSignersBls(multiSigners, []byte(message), aggSig, pubKeysStr, bitmap, consensusGroupSize) + err = multiSigner.VerifyAggregatedSig(pubKeys, []byte(message), aggSig) assert.Nil(t, err) } } diff --git a/integrationTests/node/heartbeat/heartbeat_test.go b/integrationTests/node/heartbeat/heartbeat_test.go index 38bb9aa3f7a..a736088f60b 100644 --- a/integrationTests/node/heartbeat/heartbeat_test.go +++ b/integrationTests/node/heartbeat/heartbeat_test.go @@ -11,7 +11,6 @@ import ( "github.com/ElrondNetwork/elrond-go-crypto" "github.com/ElrondNetwork/elrond-go-crypto/signing" "github.com/ElrondNetwork/elrond-go-crypto/signing/mcl" - mclsig "github.com/ElrondNetwork/elrond-go-crypto/signing/mcl/singlesig" logger "github.com/ElrondNetwork/elrond-go-logger" "github.com/ElrondNetwork/elrond-go/common" "github.com/ElrondNetwork/elrond-go/common/enablers" @@ -325,7 +324,7 @@ func isMessageCorrectLen(heartbeats []data.PubKeyHeartbeat, pk crypto.PublicKey, func createSenderWithName(messenger p2p.Messenger, topic string, nodeName string, enableEpochsHandler common.EnableEpochsHandler) (*process.Sender, crypto.PublicKey) { suite := mcl.NewSuiteBLS12() - signer := &mclsig.BlsSingleSigner{} + signer := integrationTests.TestSingleBlsSigner keyGen := signing.NewKeyGenerator(suite) sk, pk := keyGen.GeneratePair() version := "v01" @@ -353,7 +352,7 @@ func createSenderWithName(messenger p2p.Messenger, topic string, nodeName string func createMonitor(maxDurationPeerUnresponsive time.Duration, enableEpochsHandler common.EnableEpochsHandler) *process.Monitor { suite := mcl.NewSuiteBLS12() - singlesigner := &mclsig.BlsSingleSigner{} + singlesigner := integrationTests.TestSingleBlsSigner keyGen := signing.NewKeyGenerator(suite) marshalizer := &marshal.GogoProtoMarshalizer{} diff --git a/integrationTests/p2p/antiflood/blacklist/blacklist_test.go b/integrationTests/p2p/antiflood/blacklist/blacklist_test.go index ebc715f4e24..9b0ee005689 100644 --- a/integrationTests/p2p/antiflood/blacklist/blacklist_test.go +++ b/integrationTests/p2p/antiflood/blacklist/blacklist_test.go @@ -14,8 +14,7 @@ import ( "github.com/ElrondNetwork/elrond-go/process" "github.com/ElrondNetwork/elrond-go/process/throttle/antiflood/blackList" "github.com/ElrondNetwork/elrond-go/process/throttle/antiflood/floodPreventers" - "github.com/ElrondNetwork/elrond-go/storage/lrucache" - "github.com/ElrondNetwork/elrond-go/storage/timecache" + "github.com/ElrondNetwork/elrond-go/storage/cache" "github.com/ElrondNetwork/elrond-go/testscommon" "github.com/stretchr/testify/assert" ) @@ -165,8 +164,8 @@ func createBlacklistHandlersAndProcessors( blacklistProcessors := make([]floodPreventers.QuotaStatusHandler, len(peers)) blacklistCachers := make([]process.PeerBlackListCacher, len(peers)) for i := range peers { - blacklistCache, _ := lrucache.NewCache(5000) - blacklistCachers[i], _ = timecache.NewPeerTimeCache(timecache.NewTimeCache(time.Minute * 5)) + blacklistCache, _ := cache.NewLRUCache(5000) + blacklistCachers[i], _ = cache.NewPeerTimeCache(cache.NewTimeCache(time.Minute * 5)) blacklistProcessors[i], err 
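// The blacklist test above also picks up the cache consolidation: the former
// lrucache and timecache packages are merged into storage/cache. A minimal
// sketch of the renamed constructors with the same parameters as the test
// (function name illustrative; imports: time, process and storage/cache):
func newPeerBlacklistCacherSketch() (process.PeerBlackListCacher, error) {
	// was: lrucache.NewCache(5000)
	blacklistCache, err := cache.NewLRUCache(5000)
	if err != nil {
		return nil, err
	}
	_ = blacklistCache // handed to blackList.NewP2PBlackListProcessor in the test

	// was: timecache.NewPeerTimeCache(timecache.NewTimeCache(time.Minute * 5))
	peerCacher, err := cache.NewPeerTimeCache(cache.NewTimeCache(time.Minute * 5))
	if err != nil {
		return nil, err
	}

	return peerCacher, nil
}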
= blackList.NewP2PBlackListProcessor( blacklistCache, diff --git a/integrationTests/p2p/antiflood/common.go b/integrationTests/p2p/antiflood/common.go index 4fa90aaa6af..4aa395a6507 100644 --- a/integrationTests/p2p/antiflood/common.go +++ b/integrationTests/p2p/antiflood/common.go @@ -7,7 +7,7 @@ import ( "github.com/ElrondNetwork/elrond-go/p2p" "github.com/ElrondNetwork/elrond-go/process/throttle/antiflood/floodPreventers" - "github.com/ElrondNetwork/elrond-go/storage/storageUnit" + "github.com/ElrondNetwork/elrond-go/storage/storageunit" ) // DurationBootstrapingTime - @@ -41,8 +41,8 @@ func CreateTopicsAndMockInterceptors( return nil, fmt.Errorf("%w, pid: %s", err, p.ID()) } - cacherCfg := storageUnit.CacheConfig{Capacity: 100, Type: storageUnit.LRUCache, Shards: 1} - antifloodPool, _ := storageUnit.NewCache(cacherCfg) + cacherCfg := storageunit.CacheConfig{Capacity: 100, Type: storageunit.LRUCache, Shards: 1} + antifloodPool, _ := storageunit.NewCache(cacherCfg) interceptors[idx] = newMessageProcessor() statusHandlers := []floodPreventers.QuotaStatusHandler{&nilQuotaStatusHandler{}} diff --git a/integrationTests/p2p/networkSharding-hbv2/networkSharding_test.go b/integrationTests/p2p/networkSharding-hbv2/networkSharding_test.go index 35ca11d4ba8..7cad6459b54 100644 --- a/integrationTests/p2p/networkSharding-hbv2/networkSharding_test.go +++ b/integrationTests/p2p/networkSharding-hbv2/networkSharding_test.go @@ -5,20 +5,20 @@ import ( "testing" "time" - "github.com/ElrondNetwork/elrond-go/config" "github.com/ElrondNetwork/elrond-go/integrationTests" "github.com/ElrondNetwork/elrond-go/p2p" + p2pConfig "github.com/ElrondNetwork/elrond-go/p2p/config" "github.com/stretchr/testify/assert" ) var p2pBootstrapStepDelay = 2 * time.Second -func createDefaultConfig() config.P2PConfig { - return config.P2PConfig{ - Node: config.NodeConfig{ +func createDefaultConfig() p2pConfig.P2PConfig { + return p2pConfig.P2PConfig{ + Node: p2pConfig.NodeConfig{ Port: "0", }, - KadDhtPeerDiscovery: config.KadDhtPeerDiscoveryConfig{ + KadDhtPeerDiscovery: p2pConfig.KadDhtPeerDiscoveryConfig{ Enabled: true, Type: "optimized", RefreshIntervalInSec: 1, @@ -31,8 +31,8 @@ func createDefaultConfig() config.P2PConfig { } func TestConnectionsInNetworkShardingWithShardingWithLists(t *testing.T) { - p2pConfig := createDefaultConfig() - p2pConfig.Sharding = config.ShardingConfig{ + p2pCfg := createDefaultConfig() + p2pCfg.Sharding = p2pConfig.ShardingConfig{ TargetPeerCount: 12, MaxIntraShardValidators: 6, MaxCrossShardValidators: 1, @@ -40,15 +40,15 @@ func TestConnectionsInNetworkShardingWithShardingWithLists(t *testing.T) { MaxCrossShardObservers: 1, MaxSeeders: 1, Type: p2p.ListsSharder, - AdditionalConnections: config.AdditionalConnectionsConfig{ + AdditionalConnections: p2pConfig.AdditionalConnectionsConfig{ MaxFullHistoryObservers: 1, }, } - testConnectionsInNetworkSharding(t, p2pConfig) + testConnectionsInNetworkSharding(t, p2pCfg) } -func testConnectionsInNetworkSharding(t *testing.T, p2pConfig config.P2PConfig) { +func testConnectionsInNetworkSharding(t *testing.T, p2pConfig p2pConfig.P2PConfig) { if testing.Short() { t.Skip("this is not a short test") } diff --git a/integrationTests/p2p/networkSharding/networkSharding_test.go b/integrationTests/p2p/networkSharding/networkSharding_test.go index a11a649e248..79e7a7aca07 100644 --- a/integrationTests/p2p/networkSharding/networkSharding_test.go +++ b/integrationTests/p2p/networkSharding/networkSharding_test.go @@ -5,20 +5,20 @@ import ( "testing" "time" - 
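// createDefaultConfig above now builds its settings from the dedicated
// p2p/config package (aliased p2pConfig) rather than the node-level config
// package. A trimmed sketch limited to the fields visible in this diff
// (remaining discovery and sharding fields elided):
func defaultP2PConfigSketch() p2pConfig.P2PConfig {
	return p2pConfig.P2PConfig{
		Node: p2pConfig.NodeConfig{
			Port: "0", // port 0 lets the OS pick a free port in tests
		},
		KadDhtPeerDiscovery: p2pConfig.KadDhtPeerDiscoveryConfig{
			Enabled:              true,
			Type:                 "optimized",
			RefreshIntervalInSec: 1,
		},
	}
}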
"github.com/ElrondNetwork/elrond-go/config" "github.com/ElrondNetwork/elrond-go/integrationTests" "github.com/ElrondNetwork/elrond-go/p2p" + p2pConfig "github.com/ElrondNetwork/elrond-go/p2p/config" "github.com/stretchr/testify/assert" ) var p2pBootstrapStepDelay = 2 * time.Second -func createDefaultConfig() config.P2PConfig { - return config.P2PConfig{ - Node: config.NodeConfig{ +func createDefaultConfig() p2pConfig.P2PConfig { + return p2pConfig.P2PConfig{ + Node: p2pConfig.NodeConfig{ Port: "0", }, - KadDhtPeerDiscovery: config.KadDhtPeerDiscoveryConfig{ + KadDhtPeerDiscovery: p2pConfig.KadDhtPeerDiscoveryConfig{ Enabled: true, Type: "optimized", RefreshIntervalInSec: 1, @@ -31,8 +31,8 @@ func createDefaultConfig() config.P2PConfig { } func TestConnectionsInNetworkShardingWithShardingWithLists(t *testing.T) { - p2pConfig := createDefaultConfig() - p2pConfig.Sharding = config.ShardingConfig{ + p2pCfg := createDefaultConfig() + p2pCfg.Sharding = p2pConfig.ShardingConfig{ TargetPeerCount: 12, MaxIntraShardValidators: 6, MaxCrossShardValidators: 1, @@ -40,15 +40,15 @@ func TestConnectionsInNetworkShardingWithShardingWithLists(t *testing.T) { MaxCrossShardObservers: 1, MaxSeeders: 1, Type: p2p.ListsSharder, - AdditionalConnections: config.AdditionalConnectionsConfig{ + AdditionalConnections: p2pConfig.AdditionalConnectionsConfig{ MaxFullHistoryObservers: 1, }, } - testConnectionsInNetworkSharding(t, p2pConfig) + testConnectionsInNetworkSharding(t, p2pCfg) } -func testConnectionsInNetworkSharding(t *testing.T, p2pConfig config.P2PConfig) { +func testConnectionsInNetworkSharding(t *testing.T, p2pConfig p2pConfig.P2PConfig) { if testing.Short() { t.Skip("this is not a short test") } diff --git a/integrationTests/p2p/peerDisconnecting/peerDisconnecting_test.go b/integrationTests/p2p/peerDisconnecting/peerDisconnecting_test.go deleted file mode 100644 index 03519f2a813..00000000000 --- a/integrationTests/p2p/peerDisconnecting/peerDisconnecting_test.go +++ /dev/null @@ -1,159 +0,0 @@ -package peerDisconnecting - -import ( - "fmt" - "testing" - - "github.com/ElrondNetwork/elrond-go/config" - "github.com/ElrondNetwork/elrond-go/integrationTests" - "github.com/ElrondNetwork/elrond-go/p2p" - "github.com/ElrondNetwork/elrond-go/p2p/libp2p" - "github.com/ElrondNetwork/elrond-go/testscommon" - "github.com/ElrondNetwork/elrond-go/testscommon/p2pmocks" - "github.com/libp2p/go-libp2p-core/peer" - mocknet "github.com/libp2p/go-libp2p/p2p/net/mock" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func createDefaultConfig() config.P2PConfig { - return config.P2PConfig{ - Node: config.NodeConfig{}, - KadDhtPeerDiscovery: config.KadDhtPeerDiscoveryConfig{ - Enabled: true, - Type: "optimized", - RefreshIntervalInSec: 1, - RoutingTableRefreshIntervalInSec: 1, - ProtocolID: "/erd/kad/1.0.0", - InitialPeerList: nil, - BucketSize: 100, - }, - } -} - -func TestPeerDisconnectionWithOneAdvertiserWithShardingWithLists(t *testing.T) { - p2pConfig := createDefaultConfig() - p2pConfig.Sharding = config.ShardingConfig{ - TargetPeerCount: 100, - MaxIntraShardValidators: 40, - MaxCrossShardValidators: 40, - MaxIntraShardObservers: 1, - MaxCrossShardObservers: 1, - MaxSeeders: 1, - Type: p2p.ListsSharder, - AdditionalConnections: config.AdditionalConnectionsConfig{ - MaxFullHistoryObservers: 1, - }, - } - p2pConfig.Node.ThresholdMinConnectedPeers = 3 - - testPeerDisconnectionWithOneAdvertiser(t, p2pConfig) -} - -func testPeerDisconnectionWithOneAdvertiser(t *testing.T, p2pConfig 
config.P2PConfig) { - if testing.Short() { - t.Skip("this is not a short test") - } - - numOfPeers := 20 - netw := mocknet.New() - - p2pConfigSeeder := p2pConfig - argSeeder := libp2p.ArgsNetworkMessenger{ - ListenAddress: libp2p.ListenLocalhostAddrWithIp4AndTcp, - P2pConfig: p2pConfigSeeder, - PreferredPeersHolder: &p2pmocks.PeersHolderStub{}, - NodeOperationMode: p2p.NormalOperation, - Marshalizer: &testscommon.MarshalizerMock{}, - SyncTimer: &testscommon.SyncTimerStub{}, - PeersRatingHandler: &p2pmocks.PeersRatingHandlerStub{}, - ConnectionWatcherType: p2p.ConnectionWatcherTypePrint, - } - // Step 1. Create advertiser - advertiser, err := libp2p.NewMockMessenger(argSeeder, netw) - require.Nil(t, err) - p2pConfig.KadDhtPeerDiscovery.InitialPeerList = []string{integrationTests.GetConnectableAddress(advertiser)} - - // Step 2. Create noOfPeers instances of messenger type and call bootstrap - peers := make([]p2p.Messenger, numOfPeers) - for i := 0; i < numOfPeers; i++ { - arg := libp2p.ArgsNetworkMessenger{ - ListenAddress: libp2p.ListenLocalhostAddrWithIp4AndTcp, - P2pConfig: p2pConfig, - PreferredPeersHolder: &p2pmocks.PeersHolderStub{}, - NodeOperationMode: p2p.NormalOperation, - Marshalizer: &testscommon.MarshalizerMock{}, - SyncTimer: &testscommon.SyncTimerStub{}, - PeersRatingHandler: &p2pmocks.PeersRatingHandlerStub{}, - ConnectionWatcherType: p2p.ConnectionWatcherTypePrint, - } - node, errCreate := libp2p.NewMockMessenger(arg, netw) - require.Nil(t, errCreate) - peers[i] = node - } - - // cleanup function that closes all messengers - defer func() { - for i := 0; i < numOfPeers; i++ { - if peers[i] != nil { - _ = peers[i].Close() - } - } - - if advertiser != nil { - _ = advertiser.Close() - } - }() - - // link all peers so they can connect to each other - _ = netw.LinkAll() - - // Step 3. Call bootstrap on all peers - _ = advertiser.Bootstrap() - for _, p := range peers { - _ = p.Bootstrap() - } - integrationTests.WaitForBootstrapAndShowConnected(peers, integrationTests.P2pBootstrapDelay) - - // Step 4. Disconnect one peer - disconnectedPeer := peers[5] - fmt.Printf("--- Diconnecting peer: %v ---\n", disconnectedPeer.ID().Pretty()) - _ = netw.UnlinkPeers(getPeerId(advertiser), getPeerId(disconnectedPeer)) - _ = netw.DisconnectPeers(getPeerId(advertiser), getPeerId(disconnectedPeer)) - _ = netw.DisconnectPeers(getPeerId(disconnectedPeer), getPeerId(advertiser)) - for _, p := range peers { - if p != disconnectedPeer { - _ = netw.UnlinkPeers(getPeerId(p), getPeerId(disconnectedPeer)) - _ = netw.DisconnectPeers(getPeerId(p), getPeerId(disconnectedPeer)) - _ = netw.DisconnectPeers(getPeerId(disconnectedPeer), getPeerId(p)) - } - } - for i := 0; i < 5; i++ { - integrationTests.WaitForBootstrapAndShowConnected(peers, integrationTests.P2pBootstrapDelay) - } - - // Step 4.1. Test that the peer is disconnected - for _, p := range peers { - if p != disconnectedPeer { - assert.Equal(t, numOfPeers-1, len(p.ConnectedPeers())) - } else { - assert.Equal(t, 0, len(p.ConnectedPeers())) - } - } - - // Step 5. Re-link and test connections - fmt.Println("--- Re-linking ---") - _ = netw.LinkAll() - for i := 0; i < 5; i++ { - integrationTests.WaitForBootstrapAndShowConnected(peers, integrationTests.P2pBootstrapDelay) - } - - // Step 5.1. 
Test that the peer is reconnected - for _, p := range peers { - assert.Equal(t, numOfPeers, len(p.ConnectedPeers())) - } -} - -func getPeerId(netMessenger p2p.Messenger) peer.ID { - return peer.ID(netMessenger.ID().Bytes()) -} diff --git a/integrationTests/p2p/peerDisconnecting/seedersDisconnecting_test.go b/integrationTests/p2p/peerDisconnecting/seedersDisconnecting_test.go deleted file mode 100644 index 74846937994..00000000000 --- a/integrationTests/p2p/peerDisconnecting/seedersDisconnecting_test.go +++ /dev/null @@ -1,174 +0,0 @@ -package peerDisconnecting - -import ( - "testing" - - logger "github.com/ElrondNetwork/elrond-go-logger" - "github.com/ElrondNetwork/elrond-go/config" - "github.com/ElrondNetwork/elrond-go/integrationTests" - "github.com/ElrondNetwork/elrond-go/p2p" - "github.com/ElrondNetwork/elrond-go/p2p/libp2p" - "github.com/ElrondNetwork/elrond-go/testscommon" - "github.com/ElrondNetwork/elrond-go/testscommon/p2pmocks" - mocknet "github.com/libp2p/go-libp2p/p2p/net/mock" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -var log = logger.GetOrCreate("integrationtests/p2p/peerdisconnecting") - -func TestSeedersDisconnectionWith2AdvertiserAnd3Peers(t *testing.T) { - if testing.Short() { - t.Skip("this is not a short test") - } - - netw := mocknet.New() - p2pConfig := createDefaultConfig() - p2pConfig.KadDhtPeerDiscovery.RefreshIntervalInSec = 1 - - p2pConfig.Sharding = config.ShardingConfig{ - TargetPeerCount: 100, - MaxIntraShardValidators: 40, - MaxCrossShardValidators: 40, - MaxIntraShardObservers: 1, - MaxCrossShardObservers: 1, - MaxSeeders: 3, - Type: p2p.ListsSharder, - AdditionalConnections: config.AdditionalConnectionsConfig{ - MaxFullHistoryObservers: 0, - }, - } - p2pConfig.Node.ThresholdMinConnectedPeers = 3 - - numOfPeers := 3 - seeders, seedersList := createBootstrappedSeeders(p2pConfig, 2, netw) - - integrationTests.WaitForBootstrapAndShowConnected(seeders, integrationTests.P2pBootstrapDelay) - - // Step 2. Create noOfPeers instances of messenger type and call bootstrap - p2pConfig.KadDhtPeerDiscovery.InitialPeerList = seedersList - peers := make([]p2p.Messenger, numOfPeers) - for i := 0; i < numOfPeers; i++ { - arg := libp2p.ArgsNetworkMessenger{ - ListenAddress: libp2p.ListenLocalhostAddrWithIp4AndTcp, - P2pConfig: p2pConfig, - PreferredPeersHolder: &p2pmocks.PeersHolderStub{}, - NodeOperationMode: p2p.NormalOperation, - Marshalizer: &testscommon.MarshalizerMock{}, - SyncTimer: &testscommon.SyncTimerStub{}, - PeersRatingHandler: &p2pmocks.PeersRatingHandlerStub{}, - ConnectionWatcherType: p2p.ConnectionWatcherTypePrint, - } - node, err := libp2p.NewMockMessenger(arg, netw) - require.Nil(t, err) - peers[i] = node - } - - // cleanup function that closes all messengers - defer func() { - for i := 0; i < numOfPeers; i++ { - if peers[i] != nil { - _ = peers[i].Close() - } - } - - for i := 0; i < len(seeders); i++ { - if seeders[i] != nil { - _ = seeders[i].Close() - } - } - }() - - // link all peers so they can connect to each other - _ = netw.LinkAll() - - // Step 3. Call bootstrap on all peers - for _, p := range peers { - _ = p.Bootstrap() - } - integrationTests.WaitForBootstrapAndShowConnected(append(seeders, peers...), integrationTests.P2pBootstrapDelay) - - // Step 4. 
Disconnect the seeders - log.Info("--- Disconnecting seeders: %v ---\n", seeders) - disconnectSeedersFromPeers(seeders, peers, netw) - - for i := 0; i < 2; i++ { - integrationTests.WaitForBootstrapAndShowConnected(append(seeders, peers...), integrationTests.P2pBootstrapDelay) - } - - // Step 4.1. Test that the peers are disconnected - for _, p := range peers { - assert.Equal(t, numOfPeers-1, len(p.ConnectedPeers())) - } - - for _, s := range seeders { - assert.Equal(t, len(seeders)-1, len(s.ConnectedPeers())) - } - - // Step 5. Re-link and test connections - log.Info("--- Re-linking ---") - _ = netw.LinkAll() - for i := 0; i < 2; i++ { - integrationTests.WaitForBootstrapAndShowConnected(append(seeders, peers...), integrationTests.P2pBootstrapDelay) - } - - // Step 5.1. Test that the peers got reconnected - for _, p := range append(peers, seeders...) { - assert.Equal(t, numOfPeers+len(seeders)-1, len(p.ConnectedPeers())) - } -} - -func createBootstrappedSeeders(baseP2PConfig config.P2PConfig, numSeeders int, netw mocknet.Mocknet) ([]p2p.Messenger, []string) { - seeders := make([]p2p.Messenger, numSeeders) - seedersAddresses := make([]string, numSeeders) - - p2pConfigSeeder := baseP2PConfig - argSeeder := libp2p.ArgsNetworkMessenger{ - ListenAddress: libp2p.ListenLocalhostAddrWithIp4AndTcp, - P2pConfig: p2pConfigSeeder, - PreferredPeersHolder: &p2pmocks.PeersHolderStub{}, - NodeOperationMode: p2p.NormalOperation, - Marshalizer: &testscommon.MarshalizerMock{}, - SyncTimer: &testscommon.SyncTimerStub{}, - PeersRatingHandler: &p2pmocks.PeersRatingHandlerStub{}, - ConnectionWatcherType: p2p.ConnectionWatcherTypePrint, - } - seeders[0], _ = libp2p.NewMockMessenger(argSeeder, netw) - _ = seeders[0].Bootstrap() - seedersAddresses[0] = integrationTests.GetConnectableAddress(seeders[0]) - - for i := 1; i < numSeeders; i++ { - p2pConfigSeeder = baseP2PConfig - p2pConfigSeeder.KadDhtPeerDiscovery.InitialPeerList = []string{integrationTests.GetConnectableAddress(seeders[0])} - argSeeder = libp2p.ArgsNetworkMessenger{ - ListenAddress: libp2p.ListenLocalhostAddrWithIp4AndTcp, - P2pConfig: p2pConfigSeeder, - PreferredPeersHolder: &p2pmocks.PeersHolderStub{}, - NodeOperationMode: p2p.NormalOperation, - Marshalizer: &testscommon.MarshalizerMock{}, - SyncTimer: &testscommon.SyncTimerStub{}, - PeersRatingHandler: &p2pmocks.PeersRatingHandlerStub{}, - ConnectionWatcherType: p2p.ConnectionWatcherTypePrint, - } - seeders[i], _ = libp2p.NewMockMessenger(argSeeder, netw) - _ = netw.LinkAll() - _ = seeders[i].Bootstrap() - seedersAddresses[i] = integrationTests.GetConnectableAddress(seeders[i]) - } - - return seeders, seedersAddresses -} - -func disconnectSeedersFromPeers(seeders []p2p.Messenger, peers []p2p.Messenger, netw mocknet.Mocknet) { - for _, p := range peers { - for _, s := range seeders { - disconnectPeers(p, s, netw) - } - } -} - -func disconnectPeers(peer1 p2p.Messenger, peer2 p2p.Messenger, netw mocknet.Mocknet) { - _ = netw.UnlinkPeers(getPeerId(peer1), getPeerId(peer2)) - _ = netw.DisconnectPeers(getPeerId(peer1), getPeerId(peer2)) - _ = netw.DisconnectPeers(getPeerId(peer2), getPeerId(peer1)) -} diff --git a/integrationTests/p2p/peerDiscovery/kadDht/peerDiscovery_test.go b/integrationTests/p2p/peerDiscovery/kadDht/peerDiscovery_test.go deleted file mode 100644 index f5926aae166..00000000000 --- a/integrationTests/p2p/peerDiscovery/kadDht/peerDiscovery_test.go +++ /dev/null @@ -1,218 +0,0 @@ -package kadDht - -import ( - "fmt" - "testing" - "time" - - 
"github.com/ElrondNetwork/elrond-go/integrationTests" - "github.com/ElrondNetwork/elrond-go/integrationTests/p2p/peerDiscovery" - "github.com/ElrondNetwork/elrond-go/p2p" - "github.com/stretchr/testify/assert" -) - -var durationTopicAnnounceTime = 2 * time.Second - -func TestPeerDiscoveryAndMessageSendingWithOneAdvertiser(t *testing.T) { - if testing.Short() { - t.Skip("this is not a short test") - } - - numOfPeers := 20 - - //Step 1. Create advertiser - advertiser := integrationTests.CreateMessengerWithKadDht("") - _ = advertiser.Bootstrap() - - //Step 2. Create numOfPeers instances of messenger type and call bootstrap - peers := make([]p2p.Messenger, numOfPeers) - - for i := 0; i < numOfPeers; i++ { - peers[i] = integrationTests.CreateMessengerWithKadDht(integrationTests.GetConnectableAddress(advertiser)) - - _ = peers[i].Bootstrap() - } - - //cleanup function that closes all messengers - defer func() { - for i := 0; i < numOfPeers; i++ { - if peers[i] != nil { - _ = peers[i].Close() - } - } - - if advertiser != nil { - _ = advertiser.Close() - } - }() - - integrationTests.WaitForBootstrapAndShowConnected(peers, integrationTests.P2pBootstrapDelay) - - //Step 3. Create a test topic, add receiving handlers - createTestTopicAndWaitForAnnouncements(t, peers) - - //Step 4. run the test for a couple of times as peer discovering and topic announcing - // are not deterministic nor instant processes - - numOfTests := 5 - for i := 0; i < numOfTests; i++ { - testResult := peerDiscovery.RunTest(peers, i, "test topic") - - if testResult { - return - } - } - - assert.Fail(t, "test failed. Discovery/message passing are not validated") -} - -func TestPeerDiscoveryAndMessageSendingWithThreeAdvertisers(t *testing.T) { - if testing.Short() { - t.Skip("this is not a short test") - } - - numOfPeers := 20 - numOfAdvertisers := 3 - - //Step 1. Create 3 advertisers and connect them together - advertisers := make([]p2p.Messenger, numOfAdvertisers) - advertisers[0] = integrationTests.CreateMessengerWithKadDht("") - _ = advertisers[0].Bootstrap() - - for idx := 1; idx < numOfAdvertisers; idx++ { - advertisers[idx] = integrationTests.CreateMessengerWithKadDht(integrationTests.GetConnectableAddress(advertisers[0])) - _ = advertisers[idx].Bootstrap() - } - - //Step 2. Create numOfPeers instances of messenger type and call bootstrap - peers := make([]p2p.Messenger, numOfPeers) - - for i := 0; i < numOfPeers; i++ { - peers[i] = integrationTests.CreateMessengerWithKadDht(integrationTests.GetConnectableAddress(advertisers[i%numOfAdvertisers])) - _ = peers[i].Bootstrap() - } - - //cleanup function that closes all messengers - defer func() { - for i := 0; i < numOfPeers; i++ { - if peers[i] != nil { - _ = peers[i].Close() - } - } - - for i := 0; i < numOfAdvertisers; i++ { - if advertisers[i] != nil { - _ = advertisers[i].Close() - } - } - }() - - integrationTests.WaitForBootstrapAndShowConnected(peers, integrationTests.P2pBootstrapDelay) - - //Step 3. Create a test topic, add receiving handlers - createTestTopicAndWaitForAnnouncements(t, peers) - - //Step 4. run the test for a couple of times as peer discovering and topic announcing - // are not deterministic nor instant processes - - noOfTests := 5 - for i := 0; i < noOfTests; i++ { - testResult := peerDiscovery.RunTest(peers, i, "test topic") - - if testResult { - return - } - } - - assert.Fail(t, "test failed. 
Discovery/message passing are not validated") -} - -func TestPeerDiscoveryAndMessageSendingWithOneAdvertiserAndProtocolID(t *testing.T) { - if testing.Short() { - t.Skip("this is not a short test") - } - - advertiser := integrationTests.CreateMessengerWithKadDht("") - _ = advertiser.Bootstrap() - - protocolID1 := "/erd/kad/1.0.0" - protocolID2 := "/amony/kad/0.0.0" - - peer1 := integrationTests.CreateMessengerWithKadDhtAndProtocolID( - integrationTests.GetConnectableAddress(advertiser), - protocolID1, - ) - peer2 := integrationTests.CreateMessengerWithKadDhtAndProtocolID( - integrationTests.GetConnectableAddress(advertiser), - protocolID1, - ) - peer3 := integrationTests.CreateMessengerWithKadDhtAndProtocolID( - integrationTests.GetConnectableAddress(advertiser), - protocolID2, - ) - - peers := []p2p.Messenger{peer1, peer2, peer3} - - for _, peer := range peers { - _ = peer.Bootstrap() - } - - //cleanup function that closes all messengers - defer func() { - for i := 0; i < len(peers); i++ { - if peers[i] != nil { - _ = peers[i].Close() - } - } - - if advertiser != nil { - _ = advertiser.Close() - } - }() - - integrationTests.WaitForBootstrapAndShowConnected(peers, integrationTests.P2pBootstrapDelay) - - createTestTopicAndWaitForAnnouncements(t, peers) - - topic := "test topic" - message := []byte("message") - messageProcessors := assignProcessors(peers, topic) - - peer1.Broadcast(topic, message) - time.Sleep(time.Second * 2) - - assert.Equal(t, message, messageProcessors[0].GetLastMessage()) - assert.Equal(t, message, messageProcessors[1].GetLastMessage()) - assert.Nil(t, messageProcessors[2].GetLastMessage()) - - assert.Equal(t, 2, len(peer1.ConnectedPeers())) - assert.Equal(t, 2, len(peer2.ConnectedPeers())) - assert.Equal(t, 1, len(peer3.ConnectedPeers())) -} - -func assignProcessors(peers []p2p.Messenger, topic string) []*peerDiscovery.SimpleMessageProcessor { - processors := make([]*peerDiscovery.SimpleMessageProcessor, 0, len(peers)) - for _, peer := range peers { - proc := &peerDiscovery.SimpleMessageProcessor{} - processors = append(processors, proc) - - err := peer.RegisterMessageProcessor(topic, "test", proc) - if err != nil { - fmt.Println(err.Error()) - } - } - - return processors -} - -func createTestTopicAndWaitForAnnouncements(t *testing.T, peers []p2p.Messenger) { - for _, peer := range peers { - err := peer.CreateTopic("test topic", true) - if err != nil { - assert.Fail(t, "test fail while creating topic") - } - } - - fmt.Printf("Waiting %v for topic announcement...\n", durationTopicAnnounceTime) - time.Sleep(durationTopicAnnounceTime) -} diff --git a/integrationTests/p2p/peerDiscovery/messageProcessor.go b/integrationTests/p2p/peerDiscovery/messageProcessor.go deleted file mode 100644 index a9e8f342b04..00000000000 --- a/integrationTests/p2p/peerDiscovery/messageProcessor.go +++ /dev/null @@ -1,51 +0,0 @@ -package peerDiscovery - -import ( - "bytes" - "sync" - - "github.com/ElrondNetwork/elrond-go-core/core" - "github.com/ElrondNetwork/elrond-go/p2p" -) - -// MessageProcesssor - -type MessageProcesssor struct { - RequiredValue []byte - chanDone chan struct{} - mutDataReceived sync.Mutex - wasDataReceived bool -} - -// NewMessageProcessor - -func NewMessageProcessor(chanDone chan struct{}, requiredVal []byte) *MessageProcesssor { - return &MessageProcesssor{ - RequiredValue: requiredVal, - chanDone: chanDone, - } -} - -// ProcessReceivedMessage - -func (mp *MessageProcesssor) ProcessReceivedMessage(message p2p.MessageP2P, _ core.PeerID) error { - if 
bytes.Equal(mp.RequiredValue, message.Data()) { - mp.mutDataReceived.Lock() - mp.wasDataReceived = true - mp.mutDataReceived.Unlock() - - mp.chanDone <- struct{}{} - } - - return nil -} - -// WasDataReceived - -func (mp *MessageProcesssor) WasDataReceived() bool { - mp.mutDataReceived.Lock() - defer mp.mutDataReceived.Unlock() - - return mp.wasDataReceived -} - -// IsInterfaceNil returns true if there is no value under the interface -func (mp *MessageProcesssor) IsInterfaceNil() bool { - return mp == nil -} diff --git a/integrationTests/p2p/peerDiscovery/simpleMessageProcessor.go b/integrationTests/p2p/peerDiscovery/simpleMessageProcessor.go deleted file mode 100644 index fe54584e675..00000000000 --- a/integrationTests/p2p/peerDiscovery/simpleMessageProcessor.go +++ /dev/null @@ -1,36 +0,0 @@ -package peerDiscovery - -import ( - "sync" - - "github.com/ElrondNetwork/elrond-go-core/core" - "github.com/ElrondNetwork/elrond-go/p2p" -) - -// SimpleMessageProcessor records the last received message -type SimpleMessageProcessor struct { - mutMessage sync.RWMutex - message []byte -} - -// ProcessReceivedMessage records the message -func (smp *SimpleMessageProcessor) ProcessReceivedMessage(message p2p.MessageP2P, _ core.PeerID) error { - smp.mutMessage.Lock() - smp.message = message.Data() - smp.mutMessage.Unlock() - - return nil -} - -// GetLastMessage returns the last message received -func (smp *SimpleMessageProcessor) GetLastMessage() []byte { - smp.mutMessage.RLock() - defer smp.mutMessage.RUnlock() - - return smp.message -} - -// IsInterfaceNil returns true if there is no value under the interface -func (smp *SimpleMessageProcessor) IsInterfaceNil() bool { - return smp == nil -} diff --git a/integrationTests/p2p/peerDiscovery/testRunnner.go b/integrationTests/p2p/peerDiscovery/testRunnner.go deleted file mode 100644 index a75f2b4311d..00000000000 --- a/integrationTests/p2p/peerDiscovery/testRunnner.go +++ /dev/null @@ -1,78 +0,0 @@ -package peerDiscovery - -import ( - "fmt" - "strconv" - "sync/atomic" - "time" - - "github.com/ElrondNetwork/elrond-go/p2p" -) - -var durationMsgReceived = 2 * time.Second - -// RunTest will test if all the peers receive a message -func RunTest(peers []p2p.Messenger, testIndex int, topic string) bool { - fmt.Printf("Running test %v\n", testIndex) - - testMessage := "test " + strconv.Itoa(testIndex) - messageProcessors := make([]*MessageProcesssor, len(peers)) - - chanDone := make(chan struct{}) - chanMessageProcessor := make(chan struct{}, len(peers)) - - //add a new message processor for each messenger - for i, peer := range peers { - mp := NewMessageProcessor(chanMessageProcessor, []byte(testMessage)) - - messageProcessors[i] = mp - err := peer.RegisterMessageProcessor(topic, "test", mp) - if err != nil { - fmt.Println(err.Error()) - return false - } - } - - var msgReceived int32 = 0 - - go func() { - - for { - <-chanMessageProcessor - - completelyRecv := true - - atomic.StoreInt32(&msgReceived, 0) - - //to be 100% all peers received the messages, iterate all message processors and check received flag - for _, mp := range messageProcessors { - if !mp.WasDataReceived() { - completelyRecv = false - continue - } - - atomic.AddInt32(&msgReceived, 1) - } - - if !completelyRecv { - continue - } - - //all messengers got the message - chanDone <- struct{}{} - return - } - }() - - //write the message on topic - peers[0].Broadcast(topic, []byte(testMessage)) - - select { - case <-chanDone: - return true - case <-time.After(durationMsgReceived): - fmt.Printf("timeout 
fetching all messages. Got %d from %d\n", - atomic.LoadInt32(&msgReceived), len(peers)) - return false - } -} diff --git a/integrationTests/p2p/pubsub/messageProcessor.go b/integrationTests/p2p/pubsub/messageProcessor.go deleted file mode 100644 index b27bd690ad3..00000000000 --- a/integrationTests/p2p/pubsub/messageProcessor.go +++ /dev/null @@ -1,56 +0,0 @@ -package peerDisconnecting - -import ( - "sync" - - "github.com/ElrondNetwork/elrond-go-core/core" - "github.com/ElrondNetwork/elrond-go/p2p" -) - -type messageProcessor struct { - mutMessages sync.Mutex - messages map[core.PeerID][]p2p.MessageP2P -} - -func newMessageProcessor() *messageProcessor { - return &messageProcessor{ - messages: make(map[core.PeerID][]p2p.MessageP2P), - } -} - -// ProcessReceivedMessage - -func (mp *messageProcessor) ProcessReceivedMessage(message p2p.MessageP2P, fromConnectedPeer core.PeerID) error { - mp.mutMessages.Lock() - defer mp.mutMessages.Unlock() - - mp.messages[fromConnectedPeer] = append(mp.messages[fromConnectedPeer], message) - - return nil -} - -// Messages - -func (mp *messageProcessor) Messages(pid core.PeerID) []p2p.MessageP2P { - mp.mutMessages.Lock() - defer mp.mutMessages.Unlock() - - return mp.messages[pid] -} - -// AllMessages - -func (mp *messageProcessor) AllMessages() []p2p.MessageP2P { - result := make([]p2p.MessageP2P, 0) - - mp.mutMessages.Lock() - defer mp.mutMessages.Unlock() - - for _, messages := range mp.messages { - result = append(result, messages...) - } - - return result -} - -// IsInterfaceNil - -func (mp *messageProcessor) IsInterfaceNil() bool { - return mp == nil -} diff --git a/integrationTests/p2p/pubsub/peerReceivingMessages_test.go b/integrationTests/p2p/pubsub/peerReceivingMessages_test.go deleted file mode 100644 index c93ea3dce11..00000000000 --- a/integrationTests/p2p/pubsub/peerReceivingMessages_test.go +++ /dev/null @@ -1,187 +0,0 @@ -package peerDisconnecting - -import ( - "encoding/hex" - "fmt" - "sync" - "testing" - "time" - - "github.com/ElrondNetwork/elrond-go-core/core" - "github.com/ElrondNetwork/elrond-go/integrationTests" - "github.com/ElrondNetwork/elrond-go/p2p" - "github.com/stretchr/testify/assert" -) - -var durationTest = 30 * time.Second - -type messageProcessorStub struct { - ProcessReceivedMessageCalled func(message p2p.MessageP2P) error -} - -// ProcessReceivedMessage - -func (mps *messageProcessorStub) ProcessReceivedMessage(message p2p.MessageP2P, _ core.PeerID) error { - return mps.ProcessReceivedMessageCalled(message) -} - -// IsInterfaceNil returns true if there is no value under the interface -func (mps *messageProcessorStub) IsInterfaceNil() bool { - return mps == nil -} - -func TestPeerReceivesTheSameMessageMultipleTimesShouldNotHappen(t *testing.T) { - if testing.Short() { - t.Skip("this is not a short test") - } - - numOfPeers := 20 - - //Step 1. Create advertiser - advertiser := integrationTests.CreateMessengerWithKadDht("") - - //Step 2. Create numOfPeers instances of messenger type and call bootstrap - peers := make([]p2p.Messenger, numOfPeers) - for i := 0; i < numOfPeers; i++ { - node := integrationTests.CreateMessengerWithKadDht(integrationTests.GetConnectableAddress(advertiser)) - peers[i] = node - } - - //cleanup function that closes all messengers - defer func() { - for i := 0; i < numOfPeers; i++ { - if peers[i] != nil { - _ = peers[i].Close() - } - } - - if advertiser != nil { - _ = advertiser.Close() - } - }() - - chanStop := make(chan struct{}) - - //Step 3. 
Register pubsub validators - mutMapMessages := sync.Mutex{} - mapMessages := make(map[int]map[string]struct{}) - testTopic := "test" - - for i := 0; i < numOfPeers; i++ { - idx := i - mapMessages[idx] = make(map[string]struct{}) - err := peers[idx].CreateTopic(testTopic, true) - if err != nil { - fmt.Println("CreateTopic failed:", err.Error()) - continue - } - - err = peers[idx].RegisterMessageProcessor(testTopic, "test", &messageProcessorStub{ - ProcessReceivedMessageCalled: func(message p2p.MessageP2P) error { - time.Sleep(time.Second) - - mutMapMessages.Lock() - defer mutMapMessages.Unlock() - - msgId := "peer: " + message.Peer().Pretty() + " - seqNo: 0x" + hex.EncodeToString(message.SeqNo()) - _, ok := mapMessages[idx][msgId] - if ok { - assert.Fail(t, "message %s received twice", msgId) - chanStop <- struct{}{} - } - - mapMessages[idx][msgId] = struct{}{} - return nil - }, - }) - if err != nil { - fmt.Println("RegisterMessageProcessor:", err.Error()) - } - } - - //Step 4. Call bootstrap on all peers - err := advertiser.Bootstrap() - if err != nil { - fmt.Println("Bootstrap failed:", err.Error()) - } - for _, p := range peers { - err = p.Bootstrap() - if err != nil { - fmt.Printf("Bootstrap() for peer id %s failed:%s\n", p.ID(), err.Error()) - } - } - integrationTests.WaitForBootstrapAndShowConnected(peers, integrationTests.P2pBootstrapDelay) - - //Step 5. Continuously send messages from one peer - for timeStart := time.Now(); timeStart.Add(durationTest).Unix() > time.Now().Unix(); { - peers[0].Broadcast(testTopic, []byte("test buff")) - select { - case <-chanStop: - return - default: - } - time.Sleep(time.Millisecond) - } -} - -// TestBroadcastMessageComesFormTheConnectedPeers tests what happens in a network when a message comes through pubsub -// The receiving peer should get the message only from one of the connected peers -func TestBroadcastMessageComesFormTheConnectedPeers(t *testing.T) { - if testing.Short() { - t.Skip("this is not a short test") - } - - topic := "test_topic" - broadcastMessageDuration := time.Second * 2 - peers, err := integrationTests.CreateFixedNetworkOf8Peers() - assert.Nil(t, err) - - defer func() { - integrationTests.ClosePeers(peers) - }() - - //node 0 is connected only to 1 and 3 (check integrationTests.CreateFixedNetworkOf7Peers function) - //a broadcast message from 6 should be received on node 0 only through peers 1 and 3 - - interceptors, err := createTopicsAndMockInterceptors(peers, topic) - assert.Nil(t, err) - - fmt.Println("bootstrapping nodes") - time.Sleep(integrationTests.P2pBootstrapDelay) - - broadcastIdx := 6 - receiverIdx := 0 - shouldReceiveFrom := []int{1, 3} - - broadcastPeer := peers[broadcastIdx] - fmt.Printf("broadcasting message from pid %s\n", broadcastPeer.ID().Pretty()) - broadcastPeer.Broadcast(topic, []byte("dummy")) - time.Sleep(broadcastMessageDuration) - - countReceivedMessages := 0 - receiverInterceptor := interceptors[receiverIdx] - for _, idx := range shouldReceiveFrom { - connectedPid := peers[idx].ID() - countReceivedMessages += len(receiverInterceptor.Messages(connectedPid)) - } - - assert.Equal(t, 1, countReceivedMessages) -} - -func createTopicsAndMockInterceptors(peers []p2p.Messenger, topic string) ([]*messageProcessor, error) { - interceptors := make([]*messageProcessor, len(peers)) - - for idx, p := range peers { - err := p.CreateTopic(topic, true) - if err != nil { - return nil, fmt.Errorf("%w, pid: %s", err, p.ID()) - } - - interceptors[idx] = newMessageProcessor() - err = p.RegisterMessageProcessor(topic, 
"test", interceptors[idx]) - if err != nil { - return nil, fmt.Errorf("%w, pid: %s", err, p.ID()) - } - } - - return interceptors, nil -} diff --git a/integrationTests/p2p/pubsub/unjoin_test.go b/integrationTests/p2p/pubsub/unjoin_test.go deleted file mode 100644 index a9ca66ed5aa..00000000000 --- a/integrationTests/p2p/pubsub/unjoin_test.go +++ /dev/null @@ -1,71 +0,0 @@ -package peerDisconnecting - -import ( - "fmt" - "testing" - "time" - - "github.com/ElrondNetwork/elrond-go/integrationTests" - "github.com/stretchr/testify/assert" -) - -const durationBootstrapping = time.Second * 2 -const durationTraverseNetwork = time.Second * 2 -const durationUnjoin = time.Second * 2 - -func TestPubsubUnjoinShouldWork(t *testing.T) { - if testing.Short() { - t.Skip("this is not a short test") - } - - peers, _ := integrationTests.CreateFixedNetworkOf8Peers() - defer func() { - integrationTests.ClosePeers(peers) - }() - - topic := "test_topic" - processors := make([]*messageProcessor, 0, len(peers)) - for idx, p := range peers { - _ = p.CreateTopic(topic, true) - processors = append(processors, newMessageProcessor()) - _ = p.RegisterMessageProcessor(topic, "test", processors[idx]) - } - - fmt.Println("bootstrapping nodes") - time.Sleep(durationBootstrapping) - - //a message should traverse the network - fmt.Println("sending the message that should traverse the whole network") - sender := peers[4] - sender.Broadcast(topic, []byte("message 1")) - - time.Sleep(durationTraverseNetwork) - - for _, mp := range processors { - assert.Equal(t, 1, len(mp.AllMessages())) - } - - blockedIdxs := []int{3, 6, 2, 5} - //node 3 unjoins the topic, which should prevent the propagation of the messages on peers 3, 6, 2 and 5 - err := peers[3].UnregisterAllMessageProcessors() - assert.Nil(t, err) - - err = peers[3].UnjoinAllTopics() - assert.Nil(t, err) - - time.Sleep(durationUnjoin) - - fmt.Println("sending the message that should traverse half the network") - sender.Broadcast(topic, []byte("message 2")) - - time.Sleep(durationTraverseNetwork) - - for idx, mp := range processors { - if integrationTests.IsIntInSlice(idx, blockedIdxs) { - assert.Equal(t, 1, len(mp.AllMessages())) - continue - } - - assert.Equal(t, 2, len(mp.AllMessages())) - } -} diff --git a/integrationTests/singleShard/transaction/interceptedResolvedTx/interceptedResolvedTx_test.go b/integrationTests/singleShard/transaction/interceptedResolvedTx/interceptedResolvedTx_test.go index 24ad7c1f70b..d261cef2c0f 100644 --- a/integrationTests/singleShard/transaction/interceptedResolvedTx/interceptedResolvedTx_test.go +++ b/integrationTests/singleShard/transaction/interceptedResolvedTx/interceptedResolvedTx_test.go @@ -10,7 +10,6 @@ import ( "github.com/ElrondNetwork/elrond-go-core/core" "github.com/ElrondNetwork/elrond-go-core/data/rewardTx" "github.com/ElrondNetwork/elrond-go-core/data/transaction" - ed25519SingleSig "github.com/ElrondNetwork/elrond-go-crypto/signing/ed25519/singlesig" "github.com/ElrondNetwork/elrond-go/integrationTests" "github.com/ElrondNetwork/elrond-go/process" "github.com/ElrondNetwork/elrond-go/process/factory" @@ -73,7 +72,7 @@ func TestNode_RequestInterceptTransactionWithMessengerAndWhitelist(t *testing.T) } txBuff, _ := tx.GetDataForSigning(integrationTests.TestAddressPubkeyConverter, integrationTests.TestTxSignMarshalizer) - signer := &ed25519SingleSig.Ed25519Signer{} + signer := integrationTests.TestSingleSigner tx.Signature, _ = signer.Sign(nRequester.OwnAccount.SkTxSign, txBuff) signedTxBuff, _ := 
integrationTests.TestMarshalizer.Marshal(&tx) diff --git a/integrationTests/state/stateTrie/stateTrie_test.go b/integrationTests/state/stateTrie/stateTrie_test.go index 8ae566dd5b5..88175afd2fc 100644 --- a/integrationTests/state/stateTrie/stateTrie_test.go +++ b/integrationTests/state/stateTrie/stateTrie_test.go @@ -33,8 +33,8 @@ import ( "github.com/ElrondNetwork/elrond-go/state/storagePruningManager" "github.com/ElrondNetwork/elrond-go/state/storagePruningManager/evictionWaitingList" "github.com/ElrondNetwork/elrond-go/storage" - "github.com/ElrondNetwork/elrond-go/storage/memorydb" - "github.com/ElrondNetwork/elrond-go/storage/storageUnit" + "github.com/ElrondNetwork/elrond-go/storage/database" + "github.com/ElrondNetwork/elrond-go/storage/storageunit" "github.com/ElrondNetwork/elrond-go/testscommon" trieMock "github.com/ElrondNetwork/elrond-go/testscommon/trie" "github.com/ElrondNetwork/elrond-go/trie" @@ -70,8 +70,8 @@ func TestAccountsDB_RetrieveDataWithSomeValuesShouldWork(t *testing.T) { val2 := []byte("456") _, account, adb := integrationTests.GenerateAddressJournalAccountAccountsDB() - _ = account.DataTrieTracker().SaveKeyValue(key1, val1) - _ = account.DataTrieTracker().SaveKeyValue(key2, val2) + _ = account.SaveKeyValue(key1, val1) + _ = account.SaveKeyValue(key2, val2) err := adb.SaveAccount(account) require.Nil(t, err) @@ -84,11 +84,11 @@ func TestAccountsDB_RetrieveDataWithSomeValuesShouldWork(t *testing.T) { recoveredAccount := acc.(state.UserAccountHandler) // verify data - dataRecovered, err := recoveredAccount.DataTrieTracker().RetrieveValue(key1) + dataRecovered, err := recoveredAccount.RetrieveValue(key1) require.Nil(t, err) require.Equal(t, val1, dataRecovered) - dataRecovered, err = recoveredAccount.DataTrieTracker().RetrieveValue(key2) + dataRecovered, err = recoveredAccount.RetrieveValue(key2) require.Nil(t, err) require.Equal(t, val2, dataRecovered) } @@ -230,7 +230,7 @@ func TestAccountsDB_CommitTwoOkAccountsShouldWork(t *testing.T) { key := []byte("ABC") val := []byte("123") - _ = stateMock.DataTrieTracker().SaveKeyValue(key, val) + _ = stateMock.SaveKeyValue(key, val) _ = adb.SaveAccount(state1) _ = adb.SaveAccount(stateMock) @@ -261,7 +261,7 @@ func TestAccountsDB_CommitTwoOkAccountsShouldWork(t *testing.T) { require.Nil(t, err) require.Equal(t, balance2, newState2.(state.UserAccountHandler).GetBalance()) require.NotNil(t, newState2.(state.UserAccountHandler).GetRootHash()) - valRecovered, err := newState2.(state.UserAccountHandler).DataTrieTracker().RetrieveValue(key) + valRecovered, err := newState2.(state.UserAccountHandler).RetrieveValue(key) require.Nil(t, err) require.Equal(t, val, valRecovered) } @@ -319,7 +319,7 @@ func TestAccountsDB_CommitTwoOkAccountsWithRecreationFromStorageShouldWork(t *te key := []byte("ABC") val := []byte("123") - _ = stateMock.DataTrieTracker().SaveKeyValue(key, val) + _ = stateMock.SaveKeyValue(key, val) _ = adb.SaveAccount(state1) _ = adb.SaveAccount(stateMock) @@ -349,7 +349,7 @@ func TestAccountsDB_CommitTwoOkAccountsWithRecreationFromStorageShouldWork(t *te newState2 := acc2.(state.UserAccountHandler) require.Equal(t, balance2, newState2.GetBalance()) require.NotNil(t, newState2.GetRootHash()) - valRecovered, err := newState2.DataTrieTracker().RetrieveValue(key) + valRecovered, err := newState2.RetrieveValue(key) require.Nil(t, err) require.Equal(t, val, valRecovered) } @@ -674,7 +674,7 @@ func TestAccountsDB_RevertDataStepByStepAccountDataShouldWork(t *testing.T) { // Step 2. 
create 2 new accounts state1, err := adb.LoadAccount(adr1) require.Nil(t, err) - _ = state1.(state.UserAccountHandler).DataTrieTracker().SaveKeyValue(key, val) + _ = state1.(state.UserAccountHandler).SaveKeyValue(key, val) err = adb.SaveAccount(state1) require.Nil(t, err) snapshotCreated1 := adb.JournalLen() @@ -690,7 +690,7 @@ func TestAccountsDB_RevertDataStepByStepAccountDataShouldWork(t *testing.T) { stateMock, err := adb.LoadAccount(adr2) require.Nil(t, err) - _ = stateMock.(state.UserAccountHandler).DataTrieTracker().SaveKeyValue(key, val) + _ = stateMock.(state.UserAccountHandler).SaveKeyValue(key, val) err = adb.SaveAccount(stateMock) require.Nil(t, err) snapshotCreated2 := adb.JournalLen() @@ -753,7 +753,7 @@ func TestAccountsDB_RevertDataStepByStepWithCommitsAccountDataShouldWork(t *test // Step 2. create 2 new accounts state1, err := adb.LoadAccount(adr1) require.Nil(t, err) - _ = state1.(state.UserAccountHandler).DataTrieTracker().SaveKeyValue(key, val) + _ = state1.(state.UserAccountHandler).SaveKeyValue(key, val) err = adb.SaveAccount(state1) require.Nil(t, err) snapshotCreated1 := adb.JournalLen() @@ -769,7 +769,7 @@ func TestAccountsDB_RevertDataStepByStepWithCommitsAccountDataShouldWork(t *test stateMock, err := adb.LoadAccount(adr2) require.Nil(t, err) - _ = stateMock.(state.UserAccountHandler).DataTrieTracker().SaveKeyValue(key, val) + _ = stateMock.(state.UserAccountHandler).SaveKeyValue(key, val) err = adb.SaveAccount(stateMock) require.Nil(t, err) snapshotCreated2 := adb.JournalLen() @@ -800,7 +800,7 @@ func TestAccountsDB_RevertDataStepByStepWithCommitsAccountDataShouldWork(t *test stateMock, err = adb.LoadAccount(adr2) require.Nil(t, err) - _ = stateMock.(state.UserAccountHandler).DataTrieTracker().SaveKeyValue(key, newVal) + _ = stateMock.(state.UserAccountHandler).SaveKeyValue(key, newVal) err = adb.SaveAccount(stateMock) require.Nil(t, err) rootHash, err = adb.RootHash() @@ -1045,11 +1045,11 @@ func createAccounts( balance int, persist storage.Persister, ) (*state.AccountsDB, [][]byte, common.Trie) { - cache, _ := storageUnit.NewCache(storageUnit.CacheConfig{Type: storageUnit.LRUCache, Capacity: 10, Shards: 1, SizeInBytes: 0}) - store, _ := storageUnit.NewStorageUnit(cache, persist) + cache, _ := storageunit.NewCache(storageunit.CacheConfig{Type: storageunit.LRUCache, Capacity: 10, Shards: 1, SizeInBytes: 0}) + store, _ := storageunit.NewStorageUnit(cache, persist) evictionWaitListSize := uint(100) - ewl, _ := evictionWaitingList.NewEvictionWaitingList(evictionWaitListSize, memorydb.New(), integrationTests.TestMarshalizer) + ewl, _ := evictionWaitingList.NewEvictionWaitingList(evictionWaitListSize, database.NewMemDB(), integrationTests.TestMarshalizer) args := getNewTrieStorageManagerArgs() args.MainStorer = store trieStorage, _ := trie.NewTrieStorageManager(args) @@ -1181,22 +1181,22 @@ func TestAccountsDB_RecreateTrieInvalidatesDataTriesCache(t *testing.T) { acc1, _ := adb.LoadAccount(address1) state1 := acc1.(state.UserAccountHandler) - _ = state1.DataTrieTracker().SaveKeyValue(key1, value1) - _ = state1.DataTrieTracker().SaveKeyValue(key2, value1) + _ = state1.SaveKeyValue(key1, value1) + _ = state1.SaveKeyValue(key2, value1) _ = adb.SaveAccount(state1) rootHash, err := adb.Commit() require.Nil(t, err) acc1, _ = adb.LoadAccount(address1) state1 = acc1.(state.UserAccountHandler) - _ = state1.DataTrieTracker().SaveKeyValue(key1, value2) + _ = state1.SaveKeyValue(key1, value2) _ = adb.SaveAccount(state1) _, err = adb.Commit() require.Nil(t, err) acc1, _ = 
adb.LoadAccount(address1) state1 = acc1.(state.UserAccountHandler) - _ = state1.DataTrieTracker().SaveKeyValue(key2, value2) + _ = state1.SaveKeyValue(key2, value2) _ = adb.SaveAccount(state1) err = adb.RevertToSnapshot(0) require.Nil(t, err) @@ -1206,7 +1206,7 @@ func TestAccountsDB_RecreateTrieInvalidatesDataTriesCache(t *testing.T) { acc1, _ = adb.LoadAccount(address1) state1 = acc1.(state.UserAccountHandler) - retrievedVal, _ := state1.DataTrieTracker().RetrieveValue(key1) + retrievedVal, _ := state1.RetrieveValue(key1) require.Equal(t, value1, retrievedVal) } @@ -1226,21 +1226,21 @@ func TestTrieDbPruning_GetDataTrieTrackerAfterPruning(t *testing.T) { acc1, _ := adb.LoadAccount(address1) state1 := acc1.(state.UserAccountHandler) - _ = state1.DataTrieTracker().SaveKeyValue(key1, value1) - _ = state1.DataTrieTracker().SaveKeyValue(key2, value1) + _ = state1.SaveKeyValue(key1, value1) + _ = state1.SaveKeyValue(key2, value1) _ = adb.SaveAccount(state1) acc2, _ := adb.LoadAccount(address2) stateMock := acc2.(state.UserAccountHandler) - _ = stateMock.DataTrieTracker().SaveKeyValue(key1, value1) - _ = stateMock.DataTrieTracker().SaveKeyValue(key2, value1) + _ = stateMock.SaveKeyValue(key1, value1) + _ = stateMock.SaveKeyValue(key2, value1) _ = adb.SaveAccount(stateMock) oldRootHash, _ := adb.Commit() acc2, _ = adb.LoadAccount(address2) stateMock = acc2.(state.UserAccountHandler) - _ = stateMock.DataTrieTracker().SaveKeyValue(key1, value2) + _ = stateMock.SaveKeyValue(key1, value2) _ = adb.SaveAccount(stateMock) newRootHash, _ := adb.Commit() @@ -1255,22 +1255,22 @@ func TestTrieDbPruning_GetDataTrieTrackerAfterPruning(t *testing.T) { collapseTrie(state1, t) collapseTrie(stateMock, t) - val, err := state1.DataTrieTracker().RetrieveValue(key1) + val, err := state1.RetrieveValue(key1) require.Nil(t, err) require.Equal(t, value1, val) - val, err = stateMock.DataTrieTracker().RetrieveValue(key2) + val, err = stateMock.RetrieveValue(key2) require.Nil(t, err) require.Equal(t, value1, val) } func collapseTrie(state state.UserAccountHandler, t *testing.T) { stateRootHash := state.GetRootHash() - stateTrie := state.DataTrieTracker().DataTrie() + stateTrie := state.DataTrie().(common.Trie) stateNewTrie, _ := stateTrie.Recreate(stateRootHash) require.NotNil(t, stateNewTrie) - state.DataTrieTracker().SetDataTrie(stateNewTrie) + state.SetDataTrie(stateNewTrie) } func TestRollbackBlockAndCheckThatPruningIsCancelledOnAccountsTrie(t *testing.T) { @@ -2063,7 +2063,7 @@ func generateAccounts( codeMap[string(code)]++ for j := 0; j < dataTrieSize; j++ { - _ = account.(state.UserAccountHandler).DataTrieTracker().SaveKeyValue(getDataTrieEntry()) + _ = account.(state.UserAccountHandler).SaveKeyValue(getDataTrieEntry()) } _ = shardNode.AccntState.SaveAccount(account) @@ -2131,7 +2131,7 @@ func TestProofAndVerifyProofDataTrie(t *testing.T) { key := []byte("key" + index) value := []byte("value" + index) - err := account.(state.UserAccountHandler).DataTrieTracker().SaveKeyValue(key, value) + err := account.(state.UserAccountHandler).SaveKeyValue(key, value) assert.Nil(t, err) } @@ -2329,7 +2329,7 @@ func addValuesInAccountDataTrie(index uint32, numKeys uint32, adb *state.Account accState := acc.(state.UserAccountHandler) for i := 0; i < int(numKeys); i++ { k, v := createDummyKeyValue(i) - _ = accState.DataTrieTracker().SaveKeyValue(k, v) + _ = accState.SaveKeyValue(k, v) } _ = adb.SaveAccount(accState) } @@ -2339,7 +2339,7 @@ func removeValuesFromAccountDataTrie(index uint32, numKeys uint32, adb *state.Ac accState := 
acc.(state.UserAccountHandler) for i := 0; i < int(numKeys); i++ { k, _ := createDummyKeyValue(i) - _ = accState.DataTrieTracker().SaveKeyValue(k, nil) + _ = accState.SaveKeyValue(k, nil) } _ = adb.SaveAccount(accState) } @@ -2358,7 +2358,7 @@ func checkAccountsDataTrie(t *testing.T, index uint32, startingKey uint32, adb * accState := acc.(state.UserAccountHandler) for i := int(startingKey); i < numKeys; i++ { k, v := createDummyKeyValue(i) - actualValue, errKey := accState.RetrieveValueFromDataTrieTracker(k) + actualValue, errKey := accState.RetrieveValue(k) require.Nil(t, errKey) require.Equal(t, v, actualValue) } @@ -2392,7 +2392,7 @@ func createAccountsDBTestSetup() *state.AccountsDB { SnapshotsGoroutineNum: 1, } evictionWaitListSize := uint(100) - ewl, _ := evictionWaitingList.NewEvictionWaitingList(evictionWaitListSize, memorydb.New(), integrationTests.TestMarshalizer) + ewl, _ := evictionWaitingList.NewEvictionWaitingList(evictionWaitListSize, database.NewMemDB(), integrationTests.TestMarshalizer) args := getNewTrieStorageManagerArgs() args.GeneralConfig = generalCfg trieStorage, _ := trie.NewTrieStorageManager(args) diff --git a/integrationTests/state/stateTrieClose/stateTrieClose_test.go b/integrationTests/state/stateTrieClose/stateTrieClose_test.go index 2e307398df9..cb92d47fefc 100644 --- a/integrationTests/state/stateTrieClose/stateTrieClose_test.go +++ b/integrationTests/state/stateTrieClose/stateTrieClose_test.go @@ -16,6 +16,7 @@ import ( "github.com/ElrondNetwork/elrond-go/testscommon/hashingMocks" "github.com/ElrondNetwork/elrond-go/trie" "github.com/ElrondNetwork/elrond-go/trie/hashesHolder" + "github.com/ElrondNetwork/elrond-go/trie/keyBuilder" "github.com/stretchr/testify/assert" ) @@ -34,14 +35,14 @@ func TestPatriciaMerkleTrie_Close(t *testing.T) { idxInitial, _ := gc.Snapshot() rootHash, _ := tr.RootHash() leavesChannel1 := make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity) - _ = tr.GetAllLeavesOnChannel(leavesChannel1, context.Background(), rootHash) + _ = tr.GetAllLeavesOnChannel(leavesChannel1, context.Background(), rootHash, keyBuilder.NewDisabledKeyBuilder()) time.Sleep(time.Second) // allow the go routine to start idx, _ := gc.Snapshot() diff := gc.DiffGoRoutines(idxInitial, idx) assert.True(t, len(diff) <= 1) // can be 0 on a fast running host leavesChannel1 = make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity) - _ = tr.GetAllLeavesOnChannel(leavesChannel1, context.Background(), rootHash) + _ = tr.GetAllLeavesOnChannel(leavesChannel1, context.Background(), rootHash, keyBuilder.NewDisabledKeyBuilder()) idx, _ = gc.Snapshot() diff = gc.DiffGoRoutines(idxInitial, idx) assert.True(t, len(diff) <= 2) @@ -51,7 +52,7 @@ func TestPatriciaMerkleTrie_Close(t *testing.T) { rootHash, _ = tr.RootHash() leavesChannel1 = make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity) - _ = tr.GetAllLeavesOnChannel(leavesChannel1, context.Background(), rootHash) + _ = tr.GetAllLeavesOnChannel(leavesChannel1, context.Background(), rootHash, keyBuilder.NewDisabledKeyBuilder()) idx, _ = gc.Snapshot() diff = gc.DiffGoRoutines(idxInitial, idx) assert.Equal(t, 3, len(diff), fmt.Sprintf("%v", diff)) @@ -61,7 +62,7 @@ func TestPatriciaMerkleTrie_Close(t *testing.T) { rootHash, _ = tr.RootHash() leavesChannel2 := make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity) - _ = tr.GetAllLeavesOnChannel(leavesChannel2, context.Background(), rootHash) + _ = tr.GetAllLeavesOnChannel(leavesChannel2, context.Background(), 
rootHash, keyBuilder.NewDisabledKeyBuilder()) time.Sleep(time.Second) // allow the go routine to start idx, _ = gc.Snapshot() diff = gc.DiffGoRoutines(idxInitial, idx) diff --git a/integrationTests/state/stateTrieSync/stateTrieSync_test.go b/integrationTests/state/stateTrieSync/stateTrieSync_test.go index 607554cc7aa..33a6c2eab3e 100644 --- a/integrationTests/state/stateTrieSync/stateTrieSync_test.go +++ b/integrationTests/state/stateTrieSync/stateTrieSync_test.go @@ -23,7 +23,9 @@ import ( testStorage "github.com/ElrondNetwork/elrond-go/testscommon/state" "github.com/ElrondNetwork/elrond-go/trie" trieFactory "github.com/ElrondNetwork/elrond-go/trie/factory" + "github.com/ElrondNetwork/elrond-go/trie/keyBuilder" "github.com/ElrondNetwork/elrond-go/trie/statistics" + "github.com/ElrondNetwork/elrond-go/trie/storageMarker" "github.com/ElrondNetwork/elrond-go/vm/systemSmartContracts/defaults" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -102,12 +104,7 @@ func TestNode_RequestInterceptTrieNodesWithMessenger(t *testing.T) { _ = resolverTrie.Commit() rootHash, _ := resolverTrie.RootHash() - leavesChannel := make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity) - _ = resolverTrie.GetAllLeavesOnChannel(leavesChannel, context.Background(), rootHash) - numLeaves := 0 - for range leavesChannel { - numLeaves++ - } + numLeaves := getNumLeaves(t, resolverTrie, rootHash) assert.Equal(t, numTrieLeaves, numLeaves) requesterTrie := nRequester.TrieContainer.Get([]byte(trieFactory.UserAccountTrie)) @@ -143,12 +140,7 @@ func TestNode_RequestInterceptTrieNodesWithMessenger(t *testing.T) { assert.NotEqual(t, nilRootHash, newRootHash) assert.Equal(t, rootHash, newRootHash) - leavesChannel = make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity) - _ = requesterTrie.GetAllLeavesOnChannel(leavesChannel, context.Background(), newRootHash) - numLeaves = 0 - for range leavesChannel { - numLeaves++ - } + numLeaves = getNumLeaves(t, requesterTrie, rootHash) assert.Equal(t, numTrieLeaves, numLeaves) } @@ -233,12 +225,7 @@ func TestNode_RequestInterceptTrieNodesWithMessengerNotSyncingShouldErr(t *testi _ = resolverTrie.Commit() rootHash, _ := resolverTrie.RootHash() - leavesChannel := make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity) - _ = resolverTrie.GetAllLeavesOnChannel(leavesChannel, context.Background(), rootHash) - numLeaves := 0 - for range leavesChannel { - numLeaves++ - } + numLeaves := getNumLeaves(t, resolverTrie, rootHash) assert.Equal(t, numTrieLeaves, numLeaves) requesterTrie := nRequester.TrieContainer.Get([]byte(trieFactory.UserAccountTrie)) @@ -327,17 +314,7 @@ func testMultipleDataTriesSync(t *testing.T, numAccounts int, numDataTrieLeaves time.Sleep(integrationTests.SyncDelay) accState := nResolver.AccntState - dataTrieRootHashes := make([][]byte, numAccounts) - - for i := 0; i < numAccounts; i++ { - address := integrationTests.CreateAccount(accState, 1, big.NewInt(100)) - account, _ := accState.LoadAccount(address) - userAcc, ok := account.(state.UserAccountHandler) - assert.True(t, ok) - - rootHash := addValuesToDataTrie(t, accState, userAcc, numDataTrieLeaves, valSize) - dataTrieRootHashes[i] = rootHash - } + dataTrieRootHashes := addAccountsToState(t, numAccounts, numDataTrieLeaves, accState, valSize) rootHash, _ := accState.RootHash() leavesChannel := make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity) @@ -349,24 +326,7 @@ func testMultipleDataTriesSync(t *testing.T, numAccounts int, 
numDataTrieLeaves requesterTrie := nRequester.TrieContainer.Get([]byte(trieFactory.UserAccountTrie)) nilRootHash, _ := requesterTrie.RootHash() - thr, _ := throttler.NewNumGoRoutinesThrottler(50) - syncerArgs := syncer.ArgsNewUserAccountsSyncer{ - ArgsNewBaseAccountsSyncer: syncer.ArgsNewBaseAccountsSyncer{ - Hasher: integrationTests.TestHasher, - Marshalizer: integrationTests.TestMarshalizer, - TrieStorageManager: nRequester.TrieStorageManagers[trieFactory.UserAccountTrie], - RequestHandler: nRequester.RequestHandler, - Timeout: common.TimeoutGettingTrieNodes, - Cacher: nRequester.DataPool.TrieNodes(), - MaxTrieLevelInMemory: 200, - MaxHardCapForMissingNodes: 5000, - TrieSyncerVersion: 2, - CheckNodesOnDisk: false, - }, - ShardId: shardID, - Throttler: thr, - AddressPubKeyConverter: integrationTests.TestAddressPubkeyConverter, - } + syncerArgs := getUserAccountSyncerArgs(nRequester) userAccSyncer, err := syncer.NewUserAccountsSyncer(syncerArgs) assert.Nil(t, err) @@ -388,18 +348,12 @@ func testMultipleDataTriesSync(t *testing.T, numAccounts int, numDataTrieLeaves numLeaves++ } assert.Equal(t, numAccounts, numLeaves) - checkAllDataTriesAreSynced(t, numDataTrieLeaves, nRequester.AccntState, dataTrieRootHashes) + checkAllDataTriesAreSynced(t, numDataTrieLeaves, requesterTrie, dataTrieRootHashes) } -func checkAllDataTriesAreSynced(t *testing.T, numDataTrieLeaves int, adb state.AccountsAdapter, dataTriesRootHashes [][]byte) { +func checkAllDataTriesAreSynced(t *testing.T, numDataTrieLeaves int, tr common.Trie, dataTriesRootHashes [][]byte) { for i := range dataTriesRootHashes { - dataTrieLeaves := make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity) - err := adb.GetAllLeaves(dataTrieLeaves, context.Background(), dataTriesRootHashes[i]) - assert.Nil(t, err) - numLeaves := 0 - for range dataTrieLeaves { - numLeaves++ - } + numLeaves := getNumLeaves(t, tr, dataTriesRootHashes[i]) assert.Equal(t, numDataTrieLeaves, numLeaves) } } @@ -408,7 +362,7 @@ func addValuesToDataTrie(t *testing.T, adb state.AccountsAdapter, acc state.User for i := 0; i < numVals; i++ { keyRandBytes := integrationTests.CreateRandomBytes(32) valRandBytes := integrationTests.CreateRandomBytes(valSize) - _ = acc.DataTrieTracker().SaveKeyValue(keyRandBytes, valRandBytes) + _ = acc.SaveKeyValue(keyRandBytes, valRandBytes) } err := adb.SaveAccount(acc) @@ -419,3 +373,188 @@ func addValuesToDataTrie(t *testing.T, adb state.AccountsAdapter, acc state.User return acc.GetRootHash() } + +func TestSyncMissingSnapshotNodes(t *testing.T) { + if testing.Short() { + t.Skip("this is not a short test") + } + + numSystemAccounts := 1 + numAccounts := 1000 + numDataTrieLeaves := 50 + valSize := 32 + roundsPerEpoch := uint64(5) + numOfShards := 1 + nodesPerShard := 2 + numMetachainNodes := 1 + + enableEpochsConfig := integrationTests.GetDefaultEnableEpochsConfig() + enableEpochsConfig.StakingV2EnableEpoch = integrationTests.UnreachableEpoch + nodes := integrationTests.CreateNodesWithEnableEpochsConfig( + numOfShards, + nodesPerShard, + numMetachainNodes, + enableEpochsConfig, + ) + + for _, node := range nodes { + node.EpochStartTrigger.SetRoundsPerEpoch(roundsPerEpoch) + } + + idxProposers := make([]int, numOfShards+1) + for i := 0; i < numOfShards; i++ { + idxProposers[i] = i * nodesPerShard + } + idxProposers[numOfShards] = numOfShards * nodesPerShard + + integrationTests.DisplayAndStartNodes(nodes) + + defer func() { + for _, n := range nodes { + n.Close() + } + }() + + nRequester := nodes[0] + nResolver := nodes[1] + + 
err := nRequester.ConnectTo(nResolver) + require.Nil(t, err) + time.Sleep(integrationTests.SyncDelay) + + round := uint64(0) + nonce := uint64(0) + round = integrationTests.IncrementAndPrintRound(round) + nonce++ + numDelayRounds := uint32(10) + for i := uint64(0); i < uint64(numDelayRounds); i++ { + round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, idxProposers, round, nonce) + time.Sleep(integrationTests.StepDelay) + } + + resolverTrie := nResolver.TrieContainer.Get([]byte(trieFactory.UserAccountTrie)) + accState := nResolver.AccntState + dataTrieRootHashes := addAccountsToState(t, numAccounts, numDataTrieLeaves, accState, valSize) + rootHash, _ := accState.RootHash() + numLeaves := getNumLeaves(t, resolverTrie, rootHash) + require.Equal(t, numAccounts+numSystemAccounts, numLeaves) + + requesterTrie := nRequester.TrieContainer.Get([]byte(trieFactory.UserAccountTrie)) + nilRootHash, _ := requesterTrie.RootHash() + + copyPartialState(t, nResolver, nRequester, dataTrieRootHashes) + + syncerArgs := getUserAccountSyncerArgs(nRequester) + userAccSyncer, err := syncer.NewUserAccountsSyncer(syncerArgs) + assert.Nil(t, err) + + err = nRequester.AccntState.SetSyncer(userAccSyncer) + assert.Nil(t, err) + err = nRequester.AccntState.StartSnapshotIfNeeded() + assert.Nil(t, err) + + tsm := nRequester.TrieStorageManagers[trieFactory.UserAccountTrie] + _ = tsm.PutInEpoch([]byte(common.ActiveDBKey), []byte(common.ActiveDBVal), 0) + nRequester.AccntState.SnapshotState(rootHash) + for tsm.IsPruningBlocked() { + time.Sleep(time.Millisecond * 100) + } + _ = nRequester.AccntState.RecreateTrie(rootHash) + + newRootHash, _ := nRequester.AccntState.RootHash() + assert.NotEqual(t, nilRootHash, newRootHash) + assert.Equal(t, rootHash, newRootHash) + + numLeaves = getNumLeaves(t, requesterTrie, rootHash) + assert.Equal(t, numAccounts+numSystemAccounts, numLeaves) + checkAllDataTriesAreSynced(t, numDataTrieLeaves, requesterTrie, dataTrieRootHashes) +} + +func copyPartialState(t *testing.T, sourceNode, destinationNode *integrationTests.TestProcessorNode, dataTriesRootHashes [][]byte) { + resolverTrie := sourceNode.TrieContainer.Get([]byte(trieFactory.UserAccountTrie)) + hashes, _ := resolverTrie.GetAllHashes() + assert.NotEqual(t, 0, len(hashes)) + + hashes = append(hashes, getDataTriesHashes(t, resolverTrie, dataTriesRootHashes)...) + destStorage := destinationNode.TrieContainer.Get([]byte(trieFactory.UserAccountTrie)).GetStorageManager() + + // copy all but every 1000th hash so the destination trie is left with missing nodes that must be synced + for i, hash := range hashes { + if i%1000 == 0 { + continue + } + + val, err := resolverTrie.GetStorageManager().Get(hash) + assert.Nil(t, err) + + err = destStorage.Put(hash, val) + assert.Nil(t, err) + } + +} + +func getDataTriesHashes(t *testing.T, tr common.Trie, dataTriesRootHashes [][]byte) [][]byte { + hashes := make([][]byte, 0) + for _, rh := range dataTriesRootHashes { + dt, err := tr.Recreate(rh) + assert.Nil(t, err) + + dtHashes, err := dt.GetAllHashes() + assert.Nil(t, err) + + hashes = append(hashes, dtHashes...) 
+ } + + return hashes +} + +func addAccountsToState(t *testing.T, numAccounts int, numDataTrieLeaves int, accState state.AccountsAdapter, valSize int) [][]byte { + dataTrieRootHashes := make([][]byte, numAccounts) + + for i := 0; i < numAccounts; i++ { + address := integrationTests.CreateAccount(accState, 1, big.NewInt(100)) + account, _ := accState.LoadAccount(address) + userAcc, ok := account.(state.UserAccountHandler) + assert.True(t, ok) + + rootHash := addValuesToDataTrie(t, accState, userAcc, numDataTrieLeaves, valSize) + dataTrieRootHashes[i] = rootHash + } + + return dataTrieRootHashes +} + +func getNumLeaves(t *testing.T, tr common.Trie, rootHash []byte) int { + leavesChannel := make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity) + err := tr.GetAllLeavesOnChannel(leavesChannel, context.Background(), rootHash, keyBuilder.NewDisabledKeyBuilder()) + require.Nil(t, err) + + numLeaves := 0 + for range leavesChannel { + numLeaves++ + } + + return numLeaves +} + +func getUserAccountSyncerArgs(node *integrationTests.TestProcessorNode) syncer.ArgsNewUserAccountsSyncer { + thr, _ := throttler.NewNumGoRoutinesThrottler(50) + syncerArgs := syncer.ArgsNewUserAccountsSyncer{ + ArgsNewBaseAccountsSyncer: syncer.ArgsNewBaseAccountsSyncer{ + Hasher: integrationTests.TestHasher, + Marshalizer: integrationTests.TestMarshalizer, + TrieStorageManager: node.TrieStorageManagers[trieFactory.UserAccountTrie], + RequestHandler: node.RequestHandler, + Timeout: common.TimeoutGettingTrieNodes, + Cacher: node.DataPool.TrieNodes(), + MaxTrieLevelInMemory: 200, + MaxHardCapForMissingNodes: 5000, + TrieSyncerVersion: 2, + StorageMarker: storageMarker.NewTrieStorageMarker(), + }, + ShardId: 0, + Throttler: thr, + AddressPubKeyConverter: integrationTests.TestAddressPubkeyConverter, + } + + return syncerArgs +} diff --git a/integrationTests/testConsensusNode.go b/integrationTests/testConsensusNode.go new file mode 100644 index 00000000000..00f9e60e21c --- /dev/null +++ b/integrationTests/testConsensusNode.go @@ -0,0 +1,442 @@ +package integrationTests + +import ( + "fmt" + "time" + + "github.com/ElrondNetwork/elrond-go-core/core" + "github.com/ElrondNetwork/elrond-go-core/core/check" + "github.com/ElrondNetwork/elrond-go-core/core/pubkeyConverter" + "github.com/ElrondNetwork/elrond-go-core/data" + dataBlock "github.com/ElrondNetwork/elrond-go-core/data/block" + "github.com/ElrondNetwork/elrond-go-core/data/endProcess" + "github.com/ElrondNetwork/elrond-go-core/hashing" + "github.com/ElrondNetwork/elrond-go-core/hashing/blake2b" + crypto "github.com/ElrondNetwork/elrond-go-crypto" + "github.com/ElrondNetwork/elrond-go/config" + "github.com/ElrondNetwork/elrond-go/consensus/round" + "github.com/ElrondNetwork/elrond-go/dataRetriever" + "github.com/ElrondNetwork/elrond-go/epochStart/metachain" + "github.com/ElrondNetwork/elrond-go/epochStart/notifier" + "github.com/ElrondNetwork/elrond-go/factory/peerSignatureHandler" + "github.com/ElrondNetwork/elrond-go/integrationTests/mock" + "github.com/ElrondNetwork/elrond-go/node" + "github.com/ElrondNetwork/elrond-go/ntp" + "github.com/ElrondNetwork/elrond-go/p2p" + "github.com/ElrondNetwork/elrond-go/process/factory" + syncFork "github.com/ElrondNetwork/elrond-go/process/sync" + "github.com/ElrondNetwork/elrond-go/sharding" + elrondShardingMocks "github.com/ElrondNetwork/elrond-go/sharding/mock" + "github.com/ElrondNetwork/elrond-go/sharding/nodesCoordinator" + "github.com/ElrondNetwork/elrond-go/state" + "github.com/ElrondNetwork/elrond-go/storage" + 
"github.com/ElrondNetwork/elrond-go/storage/cache" + "github.com/ElrondNetwork/elrond-go/storage/storageunit" + "github.com/ElrondNetwork/elrond-go/testscommon" + "github.com/ElrondNetwork/elrond-go/testscommon/cryptoMocks" + dataRetrieverMock "github.com/ElrondNetwork/elrond-go/testscommon/dataRetriever" + "github.com/ElrondNetwork/elrond-go/testscommon/nodeTypeProviderMock" + "github.com/ElrondNetwork/elrond-go/testscommon/shardingMocks" + stateMock "github.com/ElrondNetwork/elrond-go/testscommon/state" + statusHandlerMock "github.com/ElrondNetwork/elrond-go/testscommon/statusHandler" + vic "github.com/ElrondNetwork/elrond-go/testscommon/validatorInfoCacher" +) + +const ( + blsConsensusType = "bls" + signatureSize = 48 + publicKeySize = 96 + maxShards = 1 + nodeShardId = 0 +) + +var testPubkeyConverter, _ = pubkeyConverter.NewHexPubkeyConverter(32) + +// TestConsensusNode represents a structure used in integration tests used for consensus tests +type TestConsensusNode struct { + Node *node.Node + Messenger p2p.Messenger + NodesCoordinator nodesCoordinator.NodesCoordinator + ShardCoordinator sharding.Coordinator + ChainHandler data.ChainHandler + BlockProcessor *mock.BlockProcessorMock + ResolverFinder dataRetriever.ResolversFinder + AccountsDB *state.AccountsDB + NodeKeys TestKeyPair +} + +// NewTestConsensusNode returns a new TestConsensusNode +func NewTestConsensusNode( + consensusSize int, + roundTime uint64, + consensusType string, + nodeKeys TestKeyPair, + eligibleMap map[uint32][]nodesCoordinator.Validator, + waitingMap map[uint32][]nodesCoordinator.Validator, + keyGen crypto.KeyGenerator, +) *TestConsensusNode { + + shardCoordinator, _ := sharding.NewMultiShardCoordinator(maxShards, nodeShardId) + + tcn := &TestConsensusNode{ + NodeKeys: nodeKeys, + ShardCoordinator: shardCoordinator, + } + tcn.initNode(consensusSize, roundTime, consensusType, eligibleMap, waitingMap, keyGen) + + return tcn +} + +// CreateNodesWithTestConsensusNode returns a map with nodes per shard each using TestConsensusNode +func CreateNodesWithTestConsensusNode( + numMetaNodes int, + nodesPerShard int, + consensusSize int, + roundTime uint64, + consensusType string, +) map[uint32][]*TestConsensusNode { + + nodes := make(map[uint32][]*TestConsensusNode, nodesPerShard) + cp := CreateCryptoParams(nodesPerShard, numMetaNodes, maxShards) + keysMap := PubKeysMapFromKeysMap(cp.Keys) + validatorsMap := GenValidatorsFromPubKeys(keysMap, maxShards) + eligibleMap, _ := nodesCoordinator.NodesInfoToValidators(validatorsMap) + waitingMap := make(map[uint32][]nodesCoordinator.Validator) + connectableNodes := make([]Connectable, 0) + + for _, keysPair := range cp.Keys[0] { + tcn := NewTestConsensusNode( + consensusSize, + roundTime, + consensusType, + *keysPair, + eligibleMap, + waitingMap, + cp.KeyGen) + nodes[nodeShardId] = append(nodes[nodeShardId], tcn) + connectableNodes = append(connectableNodes, tcn) + } + + ConnectNodes(connectableNodes) + + return nodes +} + +func (tcn *TestConsensusNode) initNode( + consensusSize int, + roundTime uint64, + consensusType string, + eligibleMap map[uint32][]nodesCoordinator.Validator, + waitingMap map[uint32][]nodesCoordinator.Validator, + keyGen crypto.KeyGenerator, +) { + + testHasher := createHasher(consensusType) + epochStartRegistrationHandler := notifier.NewEpochStartSubscriptionHandler() + consensusCache, _ := cache.NewLRUCache(10000) + pkBytes, _ := tcn.NodeKeys.Pk.ToByteArray() + + tcn.initNodesCoordinator(consensusSize, testHasher, epochStartRegistrationHandler, eligibleMap, 
waitingMap, pkBytes, consensusCache) + tcn.Messenger = CreateMessengerWithNoDiscovery() + tcn.initBlockChain(testHasher) + tcn.initBlockProcessor() + + startTime := time.Now().Unix() + + syncer := ntp.NewSyncTime(ntp.NewNTPGoogleConfig(), nil) + syncer.StartSyncingTime() + + roundHandler, _ := round.NewRound( + time.Unix(startTime, 0), + syncer.CurrentTime(), + time.Millisecond*time.Duration(roundTime), + syncer, + 0) + + dataPool := dataRetrieverMock.CreatePoolsHolder(1, 0) + + argsNewMetaEpochStart := &metachain.ArgsNewMetaEpochStartTrigger{ + GenesisTime: time.Unix(startTime, 0), + EpochStartNotifier: notifier.NewEpochStartSubscriptionHandler(), + Settings: &config.EpochStartConfig{ + MinRoundsBetweenEpochs: 1, + RoundsPerEpoch: 3, + }, + Epoch: 0, + Storage: createTestStore(), + Marshalizer: TestMarshalizer, + Hasher: testHasher, + AppStatusHandler: &statusHandlerMock.AppStatusHandlerStub{}, + DataPool: dataPool, + } + epochStartTrigger, _ := metachain.NewEpochStartTrigger(argsNewMetaEpochStart) + + forkDetector, _ := syncFork.NewShardForkDetector( + roundHandler, + cache.NewTimeCache(time.Second), + &mock.BlockTrackerStub{}, + startTime, + ) + + tcn.initResolverFinder() + + testMultiSig := cryptoMocks.NewMultiSigner() + + peerSigCache, _ := storageunit.NewCache(storageunit.CacheConfig{Type: storageunit.LRUCache, Capacity: 1000}) + peerSigHandler, _ := peerSignatureHandler.NewPeerSignatureHandler(peerSigCache, TestSingleBlsSigner, keyGen) + + tcn.initAccountsDB() + + coreComponents := GetDefaultCoreComponents() + coreComponents.SyncTimerField = syncer + coreComponents.RoundHandlerField = roundHandler + coreComponents.InternalMarshalizerField = TestMarshalizer + coreComponents.HasherField = testHasher + coreComponents.AddressPubKeyConverterField = testPubkeyConverter + coreComponents.ChainIdCalled = func() string { + return string(ChainID) + } + coreComponents.GenesisTimeField = time.Unix(startTime, 0) + coreComponents.GenesisNodesSetupField = &testscommon.NodesSetupStub{ + GetShardConsensusGroupSizeCalled: func() uint32 { + return uint32(consensusSize) + }, + GetMetaConsensusGroupSizeCalled: func() uint32 { + return uint32(consensusSize) + }, + } + + networkComponents := GetDefaultNetworkComponents() + networkComponents.Messenger = tcn.Messenger + networkComponents.InputAntiFlood = &mock.NilAntifloodHandler{} + networkComponents.PeerHonesty = &mock.PeerHonestyHandlerStub{} + + cryptoComponents := GetDefaultCryptoComponents() + cryptoComponents.PrivKey = tcn.NodeKeys.Sk + cryptoComponents.PubKey = tcn.NodeKeys.Sk.GeneratePublic() + cryptoComponents.BlockSig = TestSingleBlsSigner + cryptoComponents.TxSig = TestSingleSigner + cryptoComponents.MultiSigContainer = cryptoMocks.NewMultiSignerContainerMock(testMultiSig) + cryptoComponents.BlKeyGen = keyGen + cryptoComponents.PeerSignHandler = peerSigHandler + cryptoComponents.KeysHandlerField = &testscommon.KeysHandlerStub{ + GetHandledPrivateKeyCalled: func(pkBytes []byte) crypto.PrivateKey { + return tcn.NodeKeys.Sk + }, + GetAssociatedPidCalled: func(pkBytes []byte) core.PeerID { + return tcn.Messenger.ID() + }, + } + + processComponents := GetDefaultProcessComponents() + processComponents.ForkDetect = forkDetector + processComponents.ShardCoord = tcn.ShardCoordinator + processComponents.NodesCoord = tcn.NodesCoordinator + processComponents.BlockProcess = tcn.BlockProcessor + processComponents.ResFinder = tcn.ResolverFinder + processComponents.EpochTrigger = epochStartTrigger + processComponents.EpochNotifier = epochStartRegistrationHandler 
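+ // most of the remaining process components below are lightweight stubs, as these consensus tests do not exercise them directly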
+ processComponents.BlackListHdl = &testscommon.TimeCacheStub{} + processComponents.BootSore = &mock.BoostrapStorerMock{} + processComponents.HeaderSigVerif = &mock.HeaderSigVerifierStub{} + processComponents.HeaderIntegrVerif = &mock.HeaderIntegrityVerifierStub{} + processComponents.ReqHandler = &testscommon.RequestHandlerStub{} + processComponents.PeerMapper = mock.NewNetworkShardingCollectorMock() + processComponents.RoundHandlerField = roundHandler + processComponents.ScheduledTxsExecutionHandlerInternal = &testscommon.ScheduledTxsExecutionStub{} + processComponents.ProcessedMiniBlocksTrackerInternal = &testscommon.ProcessedMiniBlocksTrackerStub{} + + dataComponents := GetDefaultDataComponents() + dataComponents.BlockChain = tcn.ChainHandler + dataComponents.DataPool = dataPool + dataComponents.Store = createTestStore() + + stateComponents := GetDefaultStateComponents() + stateComponents.Accounts = tcn.AccountsDB + stateComponents.AccountsAPI = tcn.AccountsDB + + var err error + tcn.Node, err = node.NewNode( + node.WithCoreComponents(coreComponents), + node.WithCryptoComponents(cryptoComponents), + node.WithProcessComponents(processComponents), + node.WithDataComponents(dataComponents), + node.WithStateComponents(stateComponents), + node.WithNetworkComponents(networkComponents), + node.WithRoundDuration(roundTime), + node.WithConsensusGroupSize(consensusSize), + node.WithConsensusType(consensusType), + node.WithGenesisTime(time.Unix(startTime, 0)), + node.WithValidatorSignatureSize(signatureSize), + node.WithPublicKeySize(publicKeySize), + ) + + if err != nil { + fmt.Println(err.Error()) + } +} + +func (tcn *TestConsensusNode) initNodesCoordinator( + consensusSize int, + hasher hashing.Hasher, + epochStartRegistrationHandler notifier.EpochStartNotifier, + eligibleMap map[uint32][]nodesCoordinator.Validator, + waitingMap map[uint32][]nodesCoordinator.Validator, + pkBytes []byte, + cache storage.Cacher, +) { + argumentsNodesCoordinator := nodesCoordinator.ArgNodesCoordinator{ + ShardConsensusGroupSize: consensusSize, + MetaConsensusGroupSize: 1, + Marshalizer: TestMarshalizer, + Hasher: hasher, + Shuffler: &shardingMocks.NodeShufflerMock{}, + EpochStartNotifier: epochStartRegistrationHandler, + BootStorer: CreateMemUnit(), + NbShards: maxShards, + EligibleNodes: eligibleMap, + WaitingNodes: waitingMap, + SelfPublicKey: pkBytes, + ConsensusGroupCache: cache, + ShuffledOutHandler: &elrondShardingMocks.ShuffledOutHandlerStub{}, + ChanStopNode: endProcess.GetDummyEndProcessChannel(), + NodeTypeProvider: &nodeTypeProviderMock.NodeTypeProviderStub{}, + IsFullArchive: false, + EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{ + IsWaitingListFixFlagEnabledField: true, + }, + ValidatorInfoCacher: &vic.ValidatorInfoCacherStub{}, + } + + tcn.NodesCoordinator, _ = nodesCoordinator.NewIndexHashedNodesCoordinator(argumentsNodesCoordinator) +} + +func (tcn *TestConsensusNode) initBlockChain(hasher hashing.Hasher) { + if tcn.ShardCoordinator.SelfId() == core.MetachainShardId { + tcn.ChainHandler = CreateMetaChain() + } else { + tcn.ChainHandler = CreateShardChain() + } + + rootHash := []byte("roothash") + header := &dataBlock.Header{ + Nonce: 0, + ShardID: tcn.ShardCoordinator.SelfId(), + BlockBodyType: dataBlock.StateBlock, + Signature: rootHash, + RootHash: rootHash, + PrevRandSeed: rootHash, + RandSeed: rootHash, + } + + _ = tcn.ChainHandler.SetGenesisHeader(header) + hdrMarshalized, _ := TestMarshalizer.Marshal(header) + 
tcn.ChainHandler.SetGenesisHeaderHash(hasher.Compute(string(hdrMarshalized))) +} + +func (tcn *TestConsensusNode) initBlockProcessor() { + tcn.BlockProcessor = &mock.BlockProcessorMock{ + // CommitBlockCalled also counts the committed blocks so tests can assert on consensus progress + CommitBlockCalled: func(header data.HeaderHandler, body data.BodyHandler) error { + tcn.BlockProcessor.NrCommitBlockCalled++ + _ = tcn.ChainHandler.SetCurrentBlockHeaderAndRootHash(header, header.GetRootHash()) + return nil + }, + CreateBlockCalled: func(header data.HeaderHandler, haveTime func() bool) (data.HeaderHandler, data.BodyHandler, error) { + return header, &dataBlock.Body{}, nil + }, + MarshalizedDataToBroadcastCalled: func(header data.HeaderHandler, body data.BodyHandler) (map[uint32][]byte, map[string][][]byte, error) { + mrsData := make(map[uint32][]byte) + mrsTxs := make(map[string][][]byte) + return mrsData, mrsTxs, nil + }, + CreateNewHeaderCalled: func(round uint64, nonce uint64) (data.HeaderHandler, error) { + return &dataBlock.Header{ + Round: round, + Nonce: nonce, + SoftwareVersion: []byte("version"), + }, nil + }, + } + + tcn.BlockProcessor.Marshalizer = TestMarshalizer +} + +func (tcn *TestConsensusNode) initResolverFinder() { + hdrResolver := &mock.HeaderResolverStub{} + mbResolver := &mock.MiniBlocksResolverStub{} + tcn.ResolverFinder = &mock.ResolversFinderStub{ + IntraShardResolverCalled: func(baseTopic string) (resolver dataRetriever.Resolver, e error) { + if baseTopic == factory.MiniBlocksTopic { + return mbResolver, nil + } + return nil, nil + }, + CrossShardResolverCalled: func(baseTopic string, crossShard uint32) (resolver dataRetriever.Resolver, err error) { + if baseTopic == factory.ShardBlocksTopic { + return hdrResolver, nil + } + return nil, nil + }, + } +} + +func (tcn *TestConsensusNode) initAccountsDB() { + storer, _, err := stateMock.CreateTestingTriePruningStorer(tcn.ShardCoordinator, notifier.NewEpochStartSubscriptionHandler()) + if err != nil { + log.Error("initAccountsDB", "error", err.Error()) + } + trieStorage, _ := CreateTrieStorageManager(storer) + + tcn.AccountsDB, _ = CreateAccountsDB(UserAccount, trieStorage) +} + +func createHasher(consensusType string) hashing.Hasher { + if consensusType == blsConsensusType { + hasher, _ := blake2b.NewBlake2bWithSize(32) + return hasher + } + return blake2b.NewBlake2b() +} + +func createTestStore() dataRetriever.StorageService { + store := dataRetriever.NewChainStorer() + store.AddStorer(dataRetriever.TransactionUnit, CreateMemUnit()) + store.AddStorer(dataRetriever.MiniBlockUnit, CreateMemUnit()) + store.AddStorer(dataRetriever.RewardTransactionUnit, CreateMemUnit()) + store.AddStorer(dataRetriever.MetaBlockUnit, CreateMemUnit()) + store.AddStorer(dataRetriever.PeerChangesUnit, CreateMemUnit()) + store.AddStorer(dataRetriever.BlockHeaderUnit, CreateMemUnit()) + store.AddStorer(dataRetriever.BootstrapUnit, CreateMemUnit()) + store.AddStorer(dataRetriever.ReceiptsUnit, CreateMemUnit()) + store.AddStorer(dataRetriever.ScheduledSCRsUnit, CreateMemUnit()) + store.AddStorer(dataRetriever.ShardHdrNonceHashDataUnit, CreateMemUnit()) + + return store +} + +// ConnectTo will try to initiate a connection to the provided parameter +func (tcn *TestConsensusNode) ConnectTo(connectable Connectable) error { + if check.IfNil(connectable) { + return fmt.Errorf("trying to connect to a nil Connectable parameter") + } + + return 
tcn.Messenger.ConnectToPeer(connectable.GetConnectableAddress()) +} + +// GetConnectableAddress returns a non circuit, non windows default connectable p2p address +func (tcn *TestConsensusNode) GetConnectableAddress() string { + if tcn == nil { + return "nil" + } + + return GetConnectableAddress(tcn.Messenger) +} + +// IsInterfaceNil returns true if there is no value under the interface +func (tcn *TestConsensusNode) IsInterfaceNil() bool { + return tcn == nil +} diff --git a/integrationTests/testHeartbeatNode.go b/integrationTests/testHeartbeatNode.go index 8b87691d92a..83342bd782a 100644 --- a/integrationTests/testHeartbeatNode.go +++ b/integrationTests/testHeartbeatNode.go @@ -29,6 +29,7 @@ import ( "github.com/ElrondNetwork/elrond-go/heartbeat/sender" "github.com/ElrondNetwork/elrond-go/integrationTests/mock" "github.com/ElrondNetwork/elrond-go/p2p" + p2pConfig "github.com/ElrondNetwork/elrond-go/p2p/config" "github.com/ElrondNetwork/elrond-go/process" "github.com/ElrondNetwork/elrond-go/process/heartbeat/validator" "github.com/ElrondNetwork/elrond-go/process/interceptors" @@ -38,8 +39,8 @@ import ( "github.com/ElrondNetwork/elrond-go/sharding" "github.com/ElrondNetwork/elrond-go/sharding/networksharding" "github.com/ElrondNetwork/elrond-go/sharding/nodesCoordinator" - "github.com/ElrondNetwork/elrond-go/storage/storageUnit" - "github.com/ElrondNetwork/elrond-go/storage/timecache" + "github.com/ElrondNetwork/elrond-go/storage/cache" + "github.com/ElrondNetwork/elrond-go/storage/storageunit" "github.com/ElrondNetwork/elrond-go/testscommon" "github.com/ElrondNetwork/elrond-go/testscommon/cryptoMocks" dataRetrieverMock "github.com/ElrondNetwork/elrond-go/testscommon/dataRetriever" @@ -109,7 +110,7 @@ func NewTestHeartbeatNode( maxShards uint32, nodeShardId uint32, minPeersWaiting int, - p2pConfig config.P2PConfig, + p2pConfig p2pConfig.P2PConfig, heartbeatExpiryTimespanInSec int64, ) *TestHeartbeatNode { keygen := signing.NewKeyGenerator(mcl.NewSuiteBLS12()) @@ -153,9 +154,9 @@ func NewTestHeartbeatNode( shardCoordinator, _ := sharding.NewMultiShardCoordinator(maxShards, nodeShardId) messenger := CreateMessengerFromConfig(p2pConfig) - pidPk, _ := storageUnit.NewCache(storageUnit.CacheConfig{Type: storageUnit.LRUCache, Capacity: 1000}) - pkShardId, _ := storageUnit.NewCache(storageUnit.CacheConfig{Type: storageUnit.LRUCache, Capacity: 1000}) - pidShardId, _ := storageUnit.NewCache(storageUnit.CacheConfig{Type: storageUnit.LRUCache, Capacity: 1000}) + pidPk, _ := storageunit.NewCache(storageunit.CacheConfig{Type: storageunit.LRUCache, Capacity: 1000}) + pkShardId, _ := storageunit.NewCache(storageunit.CacheConfig{Type: storageunit.LRUCache, Capacity: 1000}) + pidShardId, _ := storageunit.NewCache(storageunit.CacheConfig{Type: storageunit.LRUCache, Capacity: 1000}) arg := networksharding.ArgPeerShardMapper{ PeerIdPkCache: pidPk, FallbackPkShardCache: pkShardId, @@ -201,7 +202,7 @@ func NewTestHeartbeatNode( func NewTestHeartbeatNodeWithCoordinator( maxShards uint32, nodeShardId uint32, - p2pConfig config.P2PConfig, + p2pConfig p2pConfig.P2PConfig, coordinator nodesCoordinator.NodesCoordinator, keys TestKeyPair, ) *TestHeartbeatNode { @@ -224,9 +225,9 @@ func NewTestHeartbeatNodeWithCoordinator( shardCoordinator, _ := sharding.NewMultiShardCoordinator(maxShards, nodeShardId) messenger := CreateMessengerFromConfig(p2pConfig) - pidPk, _ := storageUnit.NewCache(storageUnit.CacheConfig{Type: storageUnit.LRUCache, Capacity: 1000}) - pkShardId, _ := 
storageUnit.NewCache(storageUnit.CacheConfig{Type: storageUnit.LRUCache, Capacity: 1000}) - pidShardId, _ := storageUnit.NewCache(storageUnit.CacheConfig{Type: storageUnit.LRUCache, Capacity: 1000}) + pidPk, _ := storageunit.NewCache(storageunit.CacheConfig{Type: storageunit.LRUCache, Capacity: 1000}) + pkShardId, _ := storageunit.NewCache(storageunit.CacheConfig{Type: storageunit.LRUCache, Capacity: 1000}) + pidShardId, _ := storageunit.NewCache(storageunit.CacheConfig{Type: storageunit.LRUCache, Capacity: 1000}) arg := networksharding.ArgPeerShardMapper{ PeerIdPkCache: pidPk, FallbackPkShardCache: pkShardId, @@ -270,7 +271,7 @@ func CreateNodesWithTestHeartbeatNode( shardConsensusGroupSize int, metaConsensusGroupSize int, numObserversOnShard int, - p2pConfig config.P2PConfig, + p2pConfig p2pConfig.P2PConfig, ) map[uint32][]*TestHeartbeatNode { cp := CreateCryptoParams(nodesPerShard, numMetaNodes, uint32(numShards)) @@ -278,8 +279,8 @@ func CreateNodesWithTestHeartbeatNode( validatorsMap := GenValidatorsFromPubKeys(pubKeys, uint32(numShards)) validatorsForNodesCoordinator, _ := nodesCoordinator.NodesInfoToValidators(validatorsMap) nodesMap := make(map[uint32][]*TestHeartbeatNode) - cacherCfg := storageUnit.CacheConfig{Capacity: 10000, Type: storageUnit.LRUCache, Shards: 1} - cache, _ := storageUnit.NewCache(cacherCfg) + cacherCfg := storageunit.CacheConfig{Capacity: 10000, Type: storageunit.LRUCache, Shards: 1} + suCache, _ := storageunit.NewCache(cacherCfg) for shardId, validatorList := range validatorsMap { argumentsNodesCoordinator := nodesCoordinator.ArgNodesCoordinator{ ShardConsensusGroupSize: shardConsensusGroupSize, @@ -290,7 +291,7 @@ func CreateNodesWithTestHeartbeatNode( NbShards: uint32(numShards), EligibleNodes: validatorsForNodesCoordinator, SelfPublicKey: []byte(strconv.Itoa(int(shardId))), - ConsensusGroupCache: cache, + ConsensusGroupCache: suCache, Shuffler: &shardingMocks.NodeShufflerMock{}, BootStorer: CreateMemUnit(), WaitingNodes: make(map[uint32][]nodesCoordinator.Validator), @@ -336,7 +337,7 @@ func CreateNodesWithTestHeartbeatNode( NbShards: uint32(numShards), EligibleNodes: validatorsForNodesCoordinator, SelfPublicKey: []byte(strconv.Itoa(int(shardId))), - ConsensusGroupCache: cache, + ConsensusGroupCache: suCache, Shuffler: &shardingMocks.NodeShufflerMock{}, BootStorer: CreateMemUnit(), WaitingNodes: make(map[uint32][]nodesCoordinator.Validator), @@ -387,9 +388,9 @@ func (thn *TestHeartbeatNode) InitTestHeartbeatNode(tb testing.TB, minPeersWaiti func (thn *TestHeartbeatNode) initDataPools() { thn.DataPool = dataRetrieverMock.CreatePoolsHolder(1, thn.ShardCoordinator.SelfId()) - cacherCfg := storageUnit.CacheConfig{Capacity: 10000, Type: storageUnit.LRUCache, Shards: 1} - cache, _ := storageUnit.NewCache(cacherCfg) - thn.WhiteListHandler, _ = interceptors.NewWhiteListDataVerifier(cache) + cacherCfg := storageunit.CacheConfig{Capacity: 10000, Type: storageunit.LRUCache, Shards: 1} + suCache, _ := storageunit.NewCache(cacherCfg) + thn.WhiteListHandler, _ = interceptors.NewWhiteListDataVerifier(suCache) } func (thn *TestHeartbeatNode) initStorage() { @@ -446,7 +447,7 @@ func (thn *TestHeartbeatNode) initResolvers() { DataPools: thn.DataPool, Uint64ByteSliceConverter: TestUint64Converter, DataPacker: dataPacker, - TriesContainer: &mock.TriesHolderStub{ + TriesContainer: &trieMock.TriesHolderStub{ GetCalled: func(bytes []byte) common.Trie { return &trieMock.TrieStub{} }, @@ -506,7 +507,7 @@ func (thn *TestHeartbeatNode) createRequestHandler() { } func (thn 
*TestHeartbeatNode) initRequestedItemsHandler() { - thn.RequestedItemsHandler = timecache.NewTimeCache(roundDuration) + thn.RequestedItemsHandler = cache.NewTimeCache(roundDuration) } func (thn *TestHeartbeatNode) initInterceptors() { diff --git a/integrationTests/testInitializer.go b/integrationTests/testInitializer.go index 37e27d25e7e..28c9a2e3e4f 100644 --- a/integrationTests/testInitializer.go +++ b/integrationTests/testInitializer.go @@ -27,7 +27,6 @@ import ( crypto "github.com/ElrondNetwork/elrond-go-crypto" "github.com/ElrondNetwork/elrond-go-crypto/signing" "github.com/ElrondNetwork/elrond-go-crypto/signing/ed25519" - ed25519SingleSig "github.com/ElrondNetwork/elrond-go-crypto/signing/ed25519/singlesig" "github.com/ElrondNetwork/elrond-go-crypto/signing/mcl" logger "github.com/ElrondNetwork/elrond-go-logger" "github.com/ElrondNetwork/elrond-go/common" @@ -39,7 +38,7 @@ import ( "github.com/ElrondNetwork/elrond-go/integrationTests/mock" "github.com/ElrondNetwork/elrond-go/node" "github.com/ElrondNetwork/elrond-go/p2p" - "github.com/ElrondNetwork/elrond-go/p2p/libp2p" + p2pConfig "github.com/ElrondNetwork/elrond-go/p2p/config" "github.com/ElrondNetwork/elrond-go/process" procFactory "github.com/ElrondNetwork/elrond-go/process/factory" "github.com/ElrondNetwork/elrond-go/process/headerCheck" @@ -52,9 +51,9 @@ import ( "github.com/ElrondNetwork/elrond-go/state/storagePruningManager" "github.com/ElrondNetwork/elrond-go/state/storagePruningManager/evictionWaitingList" "github.com/ElrondNetwork/elrond-go/storage" - "github.com/ElrondNetwork/elrond-go/storage/memorydb" + "github.com/ElrondNetwork/elrond-go/storage/database" "github.com/ElrondNetwork/elrond-go/storage/pruning" - "github.com/ElrondNetwork/elrond-go/storage/storageUnit" + "github.com/ElrondNetwork/elrond-go/storage/storageunit" "github.com/ElrondNetwork/elrond-go/testscommon" dataRetrieverMock "github.com/ElrondNetwork/elrond-go/testscommon/dataRetriever" "github.com/ElrondNetwork/elrond-go/testscommon/p2pmocks" @@ -122,12 +121,12 @@ func GetConnectableAddress(mes p2p.Messenger) string { return "" } -func createP2PConfig(initialPeerList []string) config.P2PConfig { - return config.P2PConfig{ - Node: config.NodeConfig{ +func createP2PConfig(initialPeerList []string) p2pConfig.P2PConfig { + return p2pConfig.P2PConfig{ + Node: p2pConfig.NodeConfig{ Port: "0", }, - KadDhtPeerDiscovery: config.KadDhtPeerDiscoveryConfig{ + KadDhtPeerDiscovery: p2pConfig.KadDhtPeerDiscoveryConfig{ Enabled: true, Type: "optimized", RefreshIntervalInSec: 2, @@ -136,7 +135,7 @@ func createP2PConfig(initialPeerList []string) config.P2PConfig { BucketSize: 100, RoutingTableRefreshIntervalInSec: 100, }, - Sharding: config.ShardingConfig{ + Sharding: p2pConfig.ShardingConfig{ Type: p2p.NilListSharder, }, } @@ -148,55 +147,30 @@ func CreateMessengerWithKadDht(initialAddr string) p2p.Messenger { if len(initialAddr) > 0 { initialAddresses = append(initialAddresses, initialAddr) } - arg := libp2p.ArgsNetworkMessenger{ + arg := p2p.ArgsNetworkMessenger{ Marshalizer: TestMarshalizer, - ListenAddress: libp2p.ListenLocalhostAddrWithIp4AndTcp, + ListenAddress: p2p.ListenLocalhostAddrWithIp4AndTcp, P2pConfig: createP2PConfig(initialAddresses), - SyncTimer: &libp2p.LocalSyncTimer{}, + SyncTimer: &p2p.LocalSyncTimer{}, PreferredPeersHolder: &p2pmocks.PeersHolderStub{}, NodeOperationMode: p2p.NormalOperation, PeersRatingHandler: &p2pmocks.PeersRatingHandlerStub{}, ConnectionWatcherType: p2p.ConnectionWatcherTypePrint, } - libP2PMes, err := 
libp2p.NewNetworkMessenger(arg) - log.LogIfError(err) - - return libP2PMes -} - -// CreateMessengerWithKadDhtAndProtocolID creates a new libp2p messenger with kad-dht peer discovery and peer ID -func CreateMessengerWithKadDhtAndProtocolID(initialAddr string, protocolID string) p2p.Messenger { - initialAddresses := make([]string, 0) - if len(initialAddr) > 0 { - initialAddresses = append(initialAddresses, initialAddr) - } - p2pConfig := createP2PConfig(initialAddresses) - p2pConfig.KadDhtPeerDiscovery.ProtocolID = protocolID - arg := libp2p.ArgsNetworkMessenger{ - Marshalizer: TestMarshalizer, - ListenAddress: libp2p.ListenLocalhostAddrWithIp4AndTcp, - P2pConfig: p2pConfig, - SyncTimer: &libp2p.LocalSyncTimer{}, - PreferredPeersHolder: &p2pmocks.PeersHolderStub{}, - NodeOperationMode: p2p.NormalOperation, - PeersRatingHandler: &p2pmocks.PeersRatingHandlerStub{}, - ConnectionWatcherType: p2p.ConnectionWatcherTypePrint, - } - - libP2PMes, err := libp2p.NewNetworkMessenger(arg) + libP2PMes, err := p2p.NewNetworkMessenger(arg) log.LogIfError(err) return libP2PMes } // CreateMessengerFromConfig creates a new libp2p messenger with provided configuration -func CreateMessengerFromConfig(p2pConfig config.P2PConfig) p2p.Messenger { - arg := libp2p.ArgsNetworkMessenger{ +func CreateMessengerFromConfig(p2pConfig p2pConfig.P2PConfig) p2p.Messenger { + arg := p2p.ArgsNetworkMessenger{ Marshalizer: TestMarshalizer, - ListenAddress: libp2p.ListenLocalhostAddrWithIp4AndTcp, + ListenAddress: p2p.ListenLocalhostAddrWithIp4AndTcp, P2pConfig: p2pConfig, - SyncTimer: &libp2p.LocalSyncTimer{}, + SyncTimer: &p2p.LocalSyncTimer{}, PreferredPeersHolder: &p2pmocks.PeersHolderStub{}, NodeOperationMode: p2p.NormalOperation, PeersRatingHandler: &p2pmocks.PeersRatingHandlerStub{}, @@ -208,19 +182,19 @@ func CreateMessengerFromConfig(p2pConfig config.P2PConfig) p2p.Messenger { arg.NodeOperationMode = p2p.FullArchiveMode } - libP2PMes, err := libp2p.NewNetworkMessenger(arg) + libP2PMes, err := p2p.NewNetworkMessenger(arg) log.LogIfError(err) return libP2PMes } // CreateMessengerFromConfigWithPeersRatingHandler creates a new libp2p messenger with provided configuration -func CreateMessengerFromConfigWithPeersRatingHandler(p2pConfig config.P2PConfig, peersRatingHandler p2p.PeersRatingHandler) p2p.Messenger { - arg := libp2p.ArgsNetworkMessenger{ +func CreateMessengerFromConfigWithPeersRatingHandler(p2pConfig p2pConfig.P2PConfig, peersRatingHandler p2p.PeersRatingHandler) p2p.Messenger { + arg := p2p.ArgsNetworkMessenger{ Marshalizer: TestMarshalizer, - ListenAddress: libp2p.ListenLocalhostAddrWithIp4AndTcp, + ListenAddress: p2p.ListenLocalhostAddrWithIp4AndTcp, P2pConfig: p2pConfig, - SyncTimer: &libp2p.LocalSyncTimer{}, + SyncTimer: &p2p.LocalSyncTimer{}, PreferredPeersHolder: &p2pmocks.PeersHolderStub{}, NodeOperationMode: p2p.NormalOperation, PeersRatingHandler: peersRatingHandler, @@ -232,23 +206,22 @@ func CreateMessengerFromConfigWithPeersRatingHandler(p2pConfig config.P2PConfig, arg.NodeOperationMode = p2p.FullArchiveMode } - libP2PMes, err := libp2p.NewNetworkMessenger(arg) + libP2PMes, err := p2p.NewNetworkMessenger(arg) log.LogIfError(err) return libP2PMes } // CreateP2PConfigWithNoDiscovery creates a new libp2p messenger with no peer discovery -func CreateP2PConfigWithNoDiscovery() config.P2PConfig { - return config.P2PConfig{ - Node: config.NodeConfig{ +func CreateP2PConfigWithNoDiscovery() p2pConfig.P2PConfig { + return p2pConfig.P2PConfig{ + Node: p2pConfig.NodeConfig{ Port: "0", - Seed: "", }, - 
KadDhtPeerDiscovery: config.KadDhtPeerDiscoveryConfig{ + KadDhtPeerDiscovery: p2pConfig.KadDhtPeerDiscoveryConfig{ Enabled: false, }, - Sharding: config.ShardingConfig{ + Sharding: p2pConfig.ShardingConfig{ Type: p2p.NilListSharder, }, } @@ -256,27 +229,26 @@ func CreateP2PConfigWithNoDiscovery() config.P2PConfig { // CreateMessengerWithNoDiscovery creates a new libp2p messenger with no peer discovery func CreateMessengerWithNoDiscovery() p2p.Messenger { - p2pConfig := CreateP2PConfigWithNoDiscovery() + p2pCfg := CreateP2PConfigWithNoDiscovery() - return CreateMessengerFromConfig(p2pConfig) + return CreateMessengerFromConfig(p2pCfg) } // CreateMessengerWithNoDiscoveryAndPeersRatingHandler creates a new libp2p messenger with no peer discovery func CreateMessengerWithNoDiscoveryAndPeersRatingHandler(peersRatingHanlder p2p.PeersRatingHandler) p2p.Messenger { - p2pConfig := config.P2PConfig{ - Node: config.NodeConfig{ + p2pCfg := p2pConfig.P2PConfig{ + Node: p2pConfig.NodeConfig{ Port: "0", - Seed: "", }, - KadDhtPeerDiscovery: config.KadDhtPeerDiscoveryConfig{ + KadDhtPeerDiscovery: p2pConfig.KadDhtPeerDiscoveryConfig{ Enabled: false, }, - Sharding: config.ShardingConfig{ + Sharding: p2pConfig.ShardingConfig{ Type: p2p.NilListSharder, }, } - return CreateMessengerFromConfigWithPeersRatingHandler(p2pConfig, peersRatingHanlder) + return CreateMessengerFromConfigWithPeersRatingHandler(p2pCfg, peersRatingHanlder) } // CreateFixedNetworkOf8Peers assembles a network as following: @@ -376,9 +348,9 @@ func CreateMemUnit() storage.Storer { capacity := uint32(10) shards := uint32(1) sizeInBytes := uint64(0) - cache, _ := storageUnit.NewCache(storageUnit.CacheConfig{Type: storageUnit.LRUCache, Capacity: capacity, Shards: shards, SizeInBytes: sizeInBytes}) - persist, _ := memorydb.NewlruDB(10000000) - unit, _ := storageUnit.NewStorageUnit(cache, persist) + cache, _ := storageunit.NewCache(storageunit.CacheConfig{Type: storageunit.LRUCache, Capacity: capacity, Shards: shards, SizeInBytes: sizeInBytes}) + persist, _ := database.NewlruDB(10000000) + unit, _ := storageunit.NewStorageUnit(cache, persist) return unit } @@ -466,7 +438,7 @@ func CreateAccountsDB( ) (*state.AccountsDB, common.Trie) { tr, _ := trie.NewTrie(trieStorageManager, TestMarshalizer, TestHasher, maxTrieLevelInMemory) - ewl, _ := evictionWaitingList.NewEvictionWaitingList(100, memorydb.New(), TestMarshalizer) + ewl, _ := evictionWaitingList.NewEvictionWaitingList(100, database.NewMemDB(), TestMarshalizer) accountFactory := getAccountFactory(accountType) spm, _ := storagePruningManager.NewStoragePruningManager(ewl, 10) args := state.ArgsAccountsDB{ @@ -1280,6 +1252,25 @@ func CreateNodes( numOfShards int, nodesPerShard int, numMetaChainNodes int, +) []*TestProcessorNode { + return createNodesWithEpochsConfig(numOfShards, nodesPerShard, numMetaChainNodes, GetDefaultEnableEpochsConfig()) +} + +// CreateNodesWithEnableEpochsConfig creates multiple nodes in different shards but with custom enable epochs config +func CreateNodesWithEnableEpochsConfig( + numOfShards int, + nodesPerShard int, + numMetaChainNodes int, + enableEpochsConfig *config.EnableEpochs, +) []*TestProcessorNode { + return createNodesWithEpochsConfig(numOfShards, nodesPerShard, numMetaChainNodes, enableEpochsConfig) +} + +func createNodesWithEpochsConfig( + numOfShards int, + nodesPerShard int, + numMetaChainNodes int, + enableEpochsConfig *config.EnableEpochs, ) []*TestProcessorNode { nodes := make([]*TestProcessorNode, numOfShards*nodesPerShard+numMetaChainNodes) 
connectableNodes := make([]Connectable, len(nodes)) @@ -1291,6 +1282,7 @@ func CreateNodes( MaxShards: uint32(numOfShards), NodeShardId: shardId, TxSignPrivKeyShardId: shardId, + EpochsConfig: enableEpochsConfig, }) nodes[idx] = n connectableNodes[idx] = n @@ -1303,6 +1295,7 @@ func CreateNodes( MaxShards: uint32(numOfShards), NodeShardId: core.MetachainShardId, TxSignPrivKeyShardId: 0, + EpochsConfig: enableEpochsConfig, }) idx = i + numOfShards*nodesPerShard nodes[idx] = metaNode @@ -1444,7 +1437,7 @@ func CreateNodesWithFullGenesis( nodes := make([]*TestProcessorNode, numOfShards*nodesPerShard+numMetaChainNodes) connectableNodes := make([]Connectable, len(nodes)) - enableEpochsConfig := getDefaultEnableEpochsConfig() + enableEpochsConfig := GetDefaultEnableEpochsConfig() enableEpochsConfig.StakingV2EnableEpoch = UnreachableEpoch economicsConfig := createDefaultEconomicsConfig() @@ -1512,7 +1505,7 @@ func CreateNodesWithCustomStateCheckpointModulus( nodes := make([]*TestProcessorNode, numOfShards*nodesPerShard+numMetaChainNodes) connectableNodes := make([]Connectable, len(nodes)) - enableEpochsConfig := getDefaultEnableEpochsConfig() + enableEpochsConfig := GetDefaultEnableEpochsConfig() enableEpochsConfig.StakingV2EnableEpoch = UnreachableEpoch scm := &IntWrapper{ @@ -1747,7 +1740,7 @@ func GenerateTransferTx( Version: version, } txBuff, _ := tx.GetDataForSigning(TestAddressPubkeyConverter, TestTxSignMarshalizer) - signer := &ed25519SingleSig.Ed25519Signer{} + signer := TestSingleSigner tx.Signature, _ = signer.Sign(senderPrivateKey, txBuff) return &tx @@ -2116,7 +2109,7 @@ func generateValidTx( coreComponents.ValidatorPubKeyConverterField = TestValidatorPubkeyConverter cryptoComponents := GetDefaultCryptoComponents() - cryptoComponents.TxSig = &ed25519SingleSig.Ed25519Signer{} + cryptoComponents.TxSig = TestSingleSigner cryptoComponents.TxKeyGen = signing.NewKeyGenerator(ed25519.NewEd25519()) cryptoComponents.BlKeyGen = signing.NewKeyGenerator(ed25519.NewEd25519()) @@ -2167,20 +2160,6 @@ func ProposeAndSyncOneBlock( return round, nonce } -// WaitForBootstrapAndShowConnected will delay a given duration in order to wait for bootstraping and print the -// number of peers that each node is connected to -func WaitForBootstrapAndShowConnected(peers []p2p.Messenger, durationBootstrapingTime time.Duration) { - log.Info("Waiting for peer discovery...", "time", durationBootstrapingTime) - time.Sleep(durationBootstrapingTime) - - strs := []string{"Connected peers:"} - for _, peer := range peers { - strs = append(strs, fmt.Sprintf("Peer %s is connected to %d peers", peer.ID().Pretty(), len(peer.ConnectedPeers()))) - } - - log.Info(strings.Join(strs, "\n")) -} - // PubKeysMapFromKeysMap returns a map of public keys per shard from the key pairs per shard map. 
func PubKeysMapFromKeysMap(keyPairMap map[uint32][]*TestKeyPair) map[uint32][]string { keysMap := make(map[uint32][]string) @@ -2237,7 +2216,7 @@ func CreateCryptoParams(nodesPerShard int, nbMetaNodes int, nbShards uint32) *Cr txSuite := ed25519.NewEd25519() txKeyGen := signing.NewKeyGenerator(txSuite) suite := mcl.NewSuiteBLS12() - singleSigner := &ed25519SingleSig.Ed25519Signer{} + singleSigner := TestSingleSigner keyGen := signing.NewKeyGenerator(suite) txKeysMap := make(map[uint32][]*TestKeyPair) @@ -2517,7 +2496,7 @@ func SaveDelegationManagerConfig(nodes []*TestProcessorNode) { MinDelegationAmount: big.NewInt(1), } marshaledData, _ := TestMarshalizer.Marshal(managementData) - _ = userAcc.DataTrieTracker().SaveKeyValue([]byte(delegationManagementKey), marshaledData) + _ = userAcc.SaveKeyValue([]byte(delegationManagementKey), marshaledData) _ = n.AccntState.SaveAccount(userAcc) _, _ = n.AccntState.Commit() } @@ -2537,7 +2516,7 @@ func SaveDelegationContractsList(nodes []*TestProcessorNode) { Addresses: [][]byte{[]byte("addr")}, } marshaledData, _ := TestMarshalizer.Marshal(managementData) - _ = userAcc.DataTrieTracker().SaveKeyValue([]byte(delegationContractsList), marshaledData) + _ = userAcc.SaveKeyValue([]byte(delegationContractsList), marshaledData) _ = n.AccntState.SaveAccount(userAcc) _, _ = n.AccntState.Commit() } diff --git a/integrationTests/testP2PNode.go b/integrationTests/testP2PNode.go index b3354b1af61..cbbee008a78 100644 --- a/integrationTests/testP2PNode.go +++ b/integrationTests/testP2PNode.go @@ -12,22 +12,22 @@ import ( crypto "github.com/ElrondNetwork/elrond-go-crypto" "github.com/ElrondNetwork/elrond-go-crypto/signing" "github.com/ElrondNetwork/elrond-go-crypto/signing/mcl" - mclsig "github.com/ElrondNetwork/elrond-go-crypto/signing/mcl/singlesig" "github.com/ElrondNetwork/elrond-go/common/enablers" "github.com/ElrondNetwork/elrond-go/config" "github.com/ElrondNetwork/elrond-go/dataRetriever" "github.com/ElrondNetwork/elrond-go/epochStart/notifier" - "github.com/ElrondNetwork/elrond-go/factory" + heartbeatComp "github.com/ElrondNetwork/elrond-go/factory/heartbeat" "github.com/ElrondNetwork/elrond-go/factory/peerSignatureHandler" "github.com/ElrondNetwork/elrond-go/integrationTests/mock" "github.com/ElrondNetwork/elrond-go/node" "github.com/ElrondNetwork/elrond-go/p2p" + p2pConfig "github.com/ElrondNetwork/elrond-go/p2p/config" "github.com/ElrondNetwork/elrond-go/process/smartContract" "github.com/ElrondNetwork/elrond-go/sharding" "github.com/ElrondNetwork/elrond-go/sharding/networksharding" "github.com/ElrondNetwork/elrond-go/sharding/nodesCoordinator" "github.com/ElrondNetwork/elrond-go/state" - "github.com/ElrondNetwork/elrond-go/storage/storageUnit" + "github.com/ElrondNetwork/elrond-go/storage/storageunit" "github.com/ElrondNetwork/elrond-go/testscommon" "github.com/ElrondNetwork/elrond-go/testscommon/nodeTypeProviderMock" "github.com/ElrondNetwork/elrond-go/testscommon/p2pmocks" @@ -62,7 +62,7 @@ type TestP2PNode struct { func NewTestP2PNode( maxShards uint32, nodeShardId uint32, - p2pConfig config.P2PConfig, + p2pConfig p2pConfig.P2PConfig, coordinator nodesCoordinator.NodesCoordinator, keys TestKeyPair, ) *TestP2PNode { @@ -80,9 +80,9 @@ func NewTestP2PNode( tP2pNode.ShardCoordinator = shardCoordinator - pidPk, _ := storageUnit.NewCache(storageUnit.CacheConfig{Type: storageUnit.LRUCache, Capacity: 1000}) - pkShardId, _ := storageUnit.NewCache(storageUnit.CacheConfig{Type: storageUnit.LRUCache, Capacity: 1000}) - pidShardId, _ := 
storageUnit.NewCache(storageUnit.CacheConfig{Type: storageUnit.LRUCache, Capacity: 1000}) + pidPk, _ := storageunit.NewCache(storageunit.CacheConfig{Type: storageunit.LRUCache, Capacity: 1000}) + pkShardId, _ := storageunit.NewCache(storageunit.CacheConfig{Type: storageunit.LRUCache, Capacity: 1000}) + pidShardId, _ := storageunit.NewCache(storageunit.CacheConfig{Type: storageunit.LRUCache, Capacity: 1000}) arg := networksharding.ArgPeerShardMapper{ PeerIdPkCache: pidPk, FallbackPkShardCache: pkShardId, @@ -116,7 +116,7 @@ func (tP2pNode *TestP2PNode) initStorage() { } func (tP2pNode *TestP2PNode) initCrypto() { - tP2pNode.SingleSigner = &mclsig.BlsSingleSigner{} + tP2pNode.SingleSigner = TestSingleBlsSigner suite := mcl.NewSuiteBLS12() tP2pNode.KeyGen = signing.NewKeyGenerator(suite) } @@ -216,7 +216,7 @@ func (tP2pNode *TestP2PNode) initNode() { HideInactiveValidatorIntervalInSec: 600, } - hbCompArgs := factory.HeartbeatComponentsFactoryArgs{ + hbCompArgs := heartbeatComp.HeartbeatComponentsFactoryArgs{ Config: config.Config{ Heartbeat: hbConfig, }, @@ -230,8 +230,8 @@ func (tP2pNode *TestP2PNode) initNode() { CryptoComponents: cryptoComponents, ProcessComponents: processComponents, } - heartbeatComponentsFactory, _ := factory.NewHeartbeatComponentsFactory(hbCompArgs) - managedHBComponents, err := factory.NewManagedHeartbeatComponents(heartbeatComponentsFactory) + heartbeatComponentsFactory, _ := heartbeatComp.NewHeartbeatComponentsFactory(hbCompArgs) + managedHBComponents, err := heartbeatComp.NewManagedHeartbeatComponents(heartbeatComponentsFactory) log.LogIfError(err) err = managedHBComponents.Create() @@ -322,7 +322,7 @@ func CreateNodesWithTestP2PNodes( shardConsensusGroupSize int, metaConsensusGroupSize int, numObserversOnShard int, - p2pConfig config.P2PConfig, + p2pConfig p2pConfig.P2PConfig, ) map[uint32][]*TestP2PNode { cp := CreateCryptoParams(nodesPerShard, numMetaNodes, uint32(numShards)) @@ -330,9 +330,8 @@ func CreateNodesWithTestP2PNodes( validatorsMap := GenValidatorsFromPubKeys(pubKeys, uint32(numShards)) validatorsForNodesCoordinator, _ := nodesCoordinator.NodesInfoToValidators(validatorsMap) nodesMap := make(map[uint32][]*TestP2PNode) - cacherCfg := storageUnit.CacheConfig{Capacity: 10000, Type: storageUnit.LRUCache, Shards: 1} - cache, _ := storageUnit.NewCache(cacherCfg) - + cacherCfg := storageunit.CacheConfig{Capacity: 10000, Type: storageunit.LRUCache, Shards: 1} + cache, _ := storageunit.NewCache(cacherCfg) for shardId, validatorList := range validatorsMap { argumentsNodesCoordinator := nodesCoordinator.ArgNodesCoordinator{ ShardConsensusGroupSize: shardConsensusGroupSize, diff --git a/integrationTests/testProcessorNode.go b/integrationTests/testProcessorNode.go index 59b45f32a14..be9097ed439 100644 --- a/integrationTests/testProcessorNode.go +++ b/integrationTests/testProcessorNode.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "encoding/hex" + "errors" "fmt" "math/big" "strconv" @@ -28,6 +29,7 @@ import ( crypto "github.com/ElrondNetwork/elrond-go-crypto" "github.com/ElrondNetwork/elrond-go-crypto/signing" "github.com/ElrondNetwork/elrond-go-crypto/signing/ed25519" + ed25519SingleSig "github.com/ElrondNetwork/elrond-go-crypto/signing/ed25519/singlesig" "github.com/ElrondNetwork/elrond-go-crypto/signing/mcl" mclsig "github.com/ElrondNetwork/elrond-go-crypto/signing/mcl/singlesig" nodeFactory "github.com/ElrondNetwork/elrond-go/cmd/node/factory" @@ -45,8 +47,8 @@ import ( "github.com/ElrondNetwork/elrond-go/epochStart/metachain" 
"github.com/ElrondNetwork/elrond-go/epochStart/notifier" "github.com/ElrondNetwork/elrond-go/epochStart/shardchain" - mainFactory "github.com/ElrondNetwork/elrond-go/factory" hdrFactory "github.com/ElrondNetwork/elrond-go/factory/block" + heartbeatComp "github.com/ElrondNetwork/elrond-go/factory/heartbeat" "github.com/ElrondNetwork/elrond-go/factory/peerSignatureHandler" "github.com/ElrondNetwork/elrond-go/genesis" "github.com/ElrondNetwork/elrond-go/genesis/parsing" @@ -56,7 +58,6 @@ import ( "github.com/ElrondNetwork/elrond-go/node/external" "github.com/ElrondNetwork/elrond-go/node/nodeDebugFactory" "github.com/ElrondNetwork/elrond-go/p2p" - p2pRating "github.com/ElrondNetwork/elrond-go/p2p/rating" "github.com/ElrondNetwork/elrond-go/process" "github.com/ElrondNetwork/elrond-go/process/block" "github.com/ElrondNetwork/elrond-go/process/block/bootstrapStorage" @@ -90,8 +91,8 @@ import ( "github.com/ElrondNetwork/elrond-go/state" "github.com/ElrondNetwork/elrond-go/state/blockInfoProviders" "github.com/ElrondNetwork/elrond-go/storage" - "github.com/ElrondNetwork/elrond-go/storage/storageUnit" - "github.com/ElrondNetwork/elrond-go/storage/timecache" + "github.com/ElrondNetwork/elrond-go/storage/cache" + "github.com/ElrondNetwork/elrond-go/storage/storageunit" "github.com/ElrondNetwork/elrond-go/storage/txcache" "github.com/ElrondNetwork/elrond-go/testscommon" "github.com/ElrondNetwork/elrond-go/testscommon/bootstrapMocks" @@ -106,7 +107,9 @@ import ( stateMock "github.com/ElrondNetwork/elrond-go/testscommon/state" statusHandlerMock "github.com/ElrondNetwork/elrond-go/testscommon/statusHandler" storageStubs "github.com/ElrondNetwork/elrond-go/testscommon/storage" + trieMock "github.com/ElrondNetwork/elrond-go/testscommon/trie" trieFactory "github.com/ElrondNetwork/elrond-go/trie/factory" + "github.com/ElrondNetwork/elrond-go/trie/keyBuilder" "github.com/ElrondNetwork/elrond-go/update" "github.com/ElrondNetwork/elrond-go/update/trigger" "github.com/ElrondNetwork/elrond-go/vm" @@ -114,7 +117,6 @@ import ( "github.com/ElrondNetwork/elrond-go/vm/systemSmartContracts/defaults" vmcommon "github.com/ElrondNetwork/elrond-vm-common" "github.com/ElrondNetwork/elrond-vm-common/parsers" - "github.com/pkg/errors" ) var zero = big.NewInt(0) @@ -143,7 +145,7 @@ var TestAddressPubkeyConverter, _ = pubkeyConverter.NewBech32PubkeyConverter(32, var TestValidatorPubkeyConverter, _ = pubkeyConverter.NewHexPubkeyConverter(96) // TestMultiSig represents a mock multisig -var TestMultiSig = cryptoMocks.NewMultiSigner(1) +var TestMultiSig = cryptoMocks.NewMultiSigner() // TestKeyGenForAccounts represents a mock key generator for balances var TestKeyGenForAccounts = signing.NewKeyGenerator(ed25519.NewEd25519()) @@ -212,6 +214,12 @@ const stateCheckpointModulus = uint(100) // UnreachableEpoch defines an unreachable epoch for integration tests const UnreachableEpoch = uint32(1000000) +// TestSingleSigner defines a Ed25519Signer +var TestSingleSigner = &ed25519SingleSig.Ed25519Signer{} + +// TestSingleBlsSigner defines a BlsSingleSigner +var TestSingleBlsSigner = &mclsig.BlsSingleSigner{} + // TestKeyPair holds a pair of private/public Keys type TestKeyPair struct { Sk crypto.PrivateKey @@ -398,8 +406,8 @@ func newBaseTestProcessorNode(args ArgTestProcessorNode) *TestProcessorNode { nodesCoordinatorInstance = getDefaultNodesCoordinator(args.MaxShards, pksBytes) } - peersRatingHandler, _ := p2pRating.NewPeersRatingHandler( - p2pRating.ArgPeersRatingHandler{ + peersRatingHandler, _ := p2p.NewPeersRatingHandler( + 
p2p.ArgPeersRatingHandler{ TopRatedCache: testscommon.NewCacherMock(), BadRatedCache: testscommon.NewCacherMock(), }) @@ -409,7 +417,7 @@ func newBaseTestProcessorNode(args ArgTestProcessorNode) *TestProcessorNode { genericEpochNotifier := forking.NewGenericEpochNotifier() epochsConfig := args.EpochsConfig if epochsConfig == nil { - epochsConfig = getDefaultEnableEpochsConfig() + epochsConfig = GetDefaultEnableEpochsConfig() } enableEpochsHandler, _ := enablers.NewEnableEpochsHandler(*epochsConfig, genericEpochNotifier) @@ -725,14 +733,13 @@ func (tpn *TestProcessorNode) createFullSCQueryService(gasMap map[string]map[str gasSchedule := mock.NewGasScheduleNotifierMock(gasMap) argsBuiltIn := builtInFunctions.ArgsCreateBuiltInFunctionContainer{ - GasSchedule: gasSchedule, - MapDNSAddresses: make(map[string]struct{}), - Marshalizer: TestMarshalizer, - Accounts: tpn.AccntState, - ShardCoordinator: tpn.ShardCoordinator, - EpochNotifier: tpn.EpochNotifier, - EnableEpochsHandler: tpn.EnableEpochsHandler, - + GasSchedule: gasSchedule, + MapDNSAddresses: make(map[string]struct{}), + Marshalizer: TestMarshalizer, + Accounts: tpn.AccntState, + ShardCoordinator: tpn.ShardCoordinator, + EpochNotifier: tpn.EpochNotifier, + EnableEpochsHandler: tpn.EnableEpochsHandler, MaxNumNodesInTransferRole: 100, } argsBuiltIn.AutomaticCrawlerAddresses = GenerateOneAddressPerShard(argsBuiltIn.ShardCoordinator) @@ -896,12 +903,12 @@ func (tpn *TestProcessorNode) InitializeProcessors(gasMap map[string]map[string] func (tpn *TestProcessorNode) initDataPools() { tpn.DataPool = dataRetrieverMock.CreatePoolsHolder(1, tpn.ShardCoordinator.SelfId()) - cacherCfg := storageUnit.CacheConfig{Capacity: 10000, Type: storageUnit.LRUCache, Shards: 1} - cache, _ := storageUnit.NewCache(cacherCfg) - tpn.WhiteListHandler, _ = interceptors.NewWhiteListDataVerifier(cache) + cacherCfg := storageunit.CacheConfig{Capacity: 10000, Type: storageunit.LRUCache, Shards: 1} + suCache, _ := storageunit.NewCache(cacherCfg) + tpn.WhiteListHandler, _ = interceptors.NewWhiteListDataVerifier(suCache) - cacherVerifiedCfg := storageUnit.CacheConfig{Capacity: 5000, Type: storageUnit.LRUCache, Shards: 1} - cacheVerified, _ := storageUnit.NewCache(cacherVerifiedCfg) + cacherVerifiedCfg := storageunit.CacheConfig{Capacity: 5000, Type: storageunit.LRUCache, Shards: 1} + cacheVerified, _ := storageunit.NewCache(cacherVerifiedCfg) tpn.WhiteListerVerifiedTxs, _ = interceptors.NewWhiteListDataVerifier(cacheVerified) } @@ -1071,7 +1078,7 @@ func CreateRatingsData() *rating.RatingsData { func (tpn *TestProcessorNode) initInterceptors(heartbeatPk string) { var err error - tpn.BlockBlackListHandler = timecache.NewTimeCache(TimeSpanForBadHeaders) + tpn.BlockBlackListHandler = cache.NewTimeCache(TimeSpanForBadHeaders) if check.IfNil(tpn.EpochStartNotifier) { tpn.EpochStartNotifier = notifier.NewEpochStartSubscriptionHandler() } @@ -1096,7 +1103,7 @@ func (tpn *TestProcessorNode) initInterceptors(heartbeatPk string) { cryptoComponents.PubKey = nil cryptoComponents.BlockSig = tpn.OwnAccount.BlockSingleSigner cryptoComponents.TxSig = tpn.OwnAccount.SingleSigner - cryptoComponents.MultiSig = TestMultiSig + cryptoComponents.MultiSigContainer = cryptoMocks.NewMultiSignerContainerMock(TestMultiSig) cryptoComponents.BlKeyGen = tpn.OwnAccount.KeygenBlockSign cryptoComponents.TxKeyGen = tpn.OwnAccount.KeygenTxSign @@ -1364,14 +1371,13 @@ func (tpn *TestProcessorNode) initInnerProcessors(gasMap map[string]map[string]u gasSchedule := mock.NewGasScheduleNotifierMock(gasMap) 
argsBuiltIn := builtInFunctions.ArgsCreateBuiltInFunctionContainer{ - GasSchedule: gasSchedule, - MapDNSAddresses: mapDNSAddresses, - Marshalizer: TestMarshalizer, - Accounts: tpn.AccntState, - ShardCoordinator: tpn.ShardCoordinator, - EpochNotifier: tpn.EpochNotifier, - EnableEpochsHandler: tpn.EnableEpochsHandler, - + GasSchedule: gasSchedule, + MapDNSAddresses: mapDNSAddresses, + Marshalizer: TestMarshalizer, + Accounts: tpn.AccntState, + ShardCoordinator: tpn.ShardCoordinator, + EpochNotifier: tpn.EpochNotifier, + EnableEpochsHandler: tpn.EnableEpochsHandler, MaxNumNodesInTransferRole: 100, } argsBuiltIn.AutomaticCrawlerAddresses = GenerateOneAddressPerShard(argsBuiltIn.ShardCoordinator) @@ -1577,14 +1583,13 @@ func (tpn *TestProcessorNode) initMetaInnerProcessors(gasMap map[string]map[stri gasSchedule := mock.NewGasScheduleNotifierMock(gasMap) argsBuiltIn := builtInFunctions.ArgsCreateBuiltInFunctionContainer{ - GasSchedule: gasSchedule, - MapDNSAddresses: make(map[string]struct{}), - Marshalizer: TestMarshalizer, - Accounts: tpn.AccntState, - ShardCoordinator: tpn.ShardCoordinator, - EpochNotifier: tpn.EpochNotifier, - EnableEpochsHandler: tpn.EnableEpochsHandler, - + GasSchedule: gasSchedule, + MapDNSAddresses: make(map[string]struct{}), + Marshalizer: TestMarshalizer, + Accounts: tpn.AccntState, + ShardCoordinator: tpn.ShardCoordinator, + EpochNotifier: tpn.EpochNotifier, + EnableEpochsHandler: tpn.EnableEpochsHandler, MaxNumNodesInTransferRole: 100, } argsBuiltIn.AutomaticCrawlerAddresses = GenerateOneAddressPerShard(argsBuiltIn.ShardCoordinator) @@ -1849,7 +1854,7 @@ func (tpn *TestProcessorNode) processSCOutputAccounts(vmOutput *vmcommon.VMOutpu storageUpdates := process.GetSortedStorageUpdates(outAcc) for _, storeUpdate := range storageUpdates { - err = acc.DataTrieTracker().SaveKeyValue(storeUpdate.Offset, storeUpdate.Data) + err = acc.SaveKeyValue(storeUpdate.Offset, storeUpdate.Data) if err != nil { return err } @@ -2217,7 +2222,7 @@ func (tpn *TestProcessorNode) initNode() { cryptoComponents.PubKey = tpn.NodeKeys.Pk cryptoComponents.TxSig = tpn.OwnAccount.SingleSigner cryptoComponents.BlockSig = tpn.OwnAccount.SingleSigner - cryptoComponents.MultiSig = tpn.MultiSigner + cryptoComponents.MultiSigContainer = cryptoMocks.NewMultiSignerContainerMock(tpn.MultiSigner) cryptoComponents.BlKeyGen = tpn.OwnAccount.KeygenTxSign cryptoComponents.TxKeyGen = TestKeyGenForAccounts @@ -2395,8 +2400,7 @@ func (tpn *TestProcessorNode) ProposeBlock(round uint64, nonce uint64) (data.Bod log.Warn("blockHeader.SetPrevRandSeed", "error", err.Error()) return nil, nil, nil } - - sig, _ := TestMultiSig.AggregateSigs(nil) + sig := []byte("aggregated signature") err = blockHeader.SetSignature(sig) if err != nil { log.Warn("blockHeader.SetSignature", "error", err.Error()) @@ -2514,14 +2518,14 @@ func (tpn *TestProcessorNode) GetShardHeader(nonce uint64) (data.HeaderHandler, headerObjects, _, err := tpn.DataPool.Headers().GetHeadersByNonceAndShardId(nonce, tpn.ShardCoordinator.SelfId()) if err != nil { - return nil, errors.New(fmt.Sprintf("no headers found for nonce %d and shard id %d %s", nonce, tpn.ShardCoordinator.SelfId(), err.Error())) + return nil, fmt.Errorf("%w no headers found for nonce %d and shard id %d", err, nonce, tpn.ShardCoordinator.SelfId()) } headerObject := headerObjects[len(headerObjects)-1] header, ok := headerObject.(*dataBlock.Header) if !ok { - return nil, errors.New(fmt.Sprintf("not a *dataBlock.Header stored in headers found for nonce and shard id %d %d", nonce, 
tpn.ShardCoordinator.SelfId())) + return nil, fmt.Errorf("not a *dataBlock.Header stored in headers found for nonce and shard id %d %d", nonce, tpn.ShardCoordinator.SelfId()) } return header, nil @@ -2540,12 +2544,12 @@ func (tpn *TestProcessorNode) GetBlockBody(header data.HeaderHandler) (*dataBloc mbObject, ok := tpn.DataPool.MiniBlocks().Get(miniBlockHash) if !ok { - return nil, errors.New(fmt.Sprintf("no miniblock found for hash %s", hex.EncodeToString(miniBlockHash))) + return nil, fmt.Errorf("no miniblock found for hash %s", hex.EncodeToString(miniBlockHash)) } mb, ok := mbObject.(*dataBlock.MiniBlock) if !ok { - return nil, errors.New(fmt.Sprintf("not a *dataBlock.MiniBlock stored in miniblocks found for hash %s", hex.EncodeToString(miniBlockHash))) + return nil, fmt.Errorf("not a *dataBlock.MiniBlock stored in miniblocks found for hash %s", hex.EncodeToString(miniBlockHash)) } body.MiniBlocks = append(body.MiniBlocks, mb) @@ -2567,12 +2571,12 @@ func (tpn *TestProcessorNode) GetMetaBlockBody(header *dataBlock.MetaBlock) (*da mbObject, ok := tpn.DataPool.MiniBlocks().Get(miniBlockHash) if !ok { - return nil, errors.New(fmt.Sprintf("no miniblock found for hash %s", hex.EncodeToString(miniBlockHash))) + return nil, fmt.Errorf("no miniblock found for hash %s", hex.EncodeToString(miniBlockHash)) } mb, ok := mbObject.(*dataBlock.MiniBlock) if !ok { - return nil, errors.New(fmt.Sprintf("not a *dataBlock.MiniBlock stored in miniblocks found for hash %s", hex.EncodeToString(miniBlockHash))) + return nil, fmt.Errorf("not a *dataBlock.MiniBlock stored in miniblocks found for hash %s", hex.EncodeToString(miniBlockHash)) } body.MiniBlocks = append(body.MiniBlocks, mb) @@ -2590,14 +2594,14 @@ func (tpn *TestProcessorNode) GetMetaHeader(nonce uint64) (*dataBlock.MetaBlock, headerObjects, _, err := tpn.DataPool.Headers().GetHeadersByNonceAndShardId(nonce, core.MetachainShardId) if err != nil { - return nil, errors.New(fmt.Sprintf("no headers found for nonce and shard id %d %d %s", nonce, core.MetachainShardId, err.Error())) + return nil, fmt.Errorf("%w no headers found for nonce and shard id %d %d", err, nonce, core.MetachainShardId) } headerObject := headerObjects[len(headerObjects)-1] header, ok := headerObject.(*dataBlock.MetaBlock) if !ok { - return nil, errors.New(fmt.Sprintf("not a *dataBlock.MetaBlock stored in headers found for nonce and shard id %d %d", nonce, core.MetachainShardId)) + return nil, fmt.Errorf("not a *dataBlock.MetaBlock stored in headers found for nonce and shard id %d %d", nonce, core.MetachainShardId) } return header, nil @@ -2708,7 +2712,7 @@ func (tpn *TestProcessorNode) initRoundHandler() { } func (tpn *TestProcessorNode) initRequestedItemsHandler() { - tpn.RequestedItemsHandler = timecache.NewTimeCache(roundDuration) + tpn.RequestedItemsHandler = cache.NewTimeCache(roundDuration) } func (tpn *TestProcessorNode) initBlockTracker() { @@ -2764,7 +2768,7 @@ func (tpn *TestProcessorNode) createHeartbeatWithHardforkTrigger() { cryptoComponents.PubKey = tpn.NodeKeys.Pk cryptoComponents.TxSig = tpn.OwnAccount.SingleSigner cryptoComponents.BlockSig = tpn.OwnAccount.SingleSigner - cryptoComponents.MultiSig = tpn.MultiSigner + cryptoComponents.MultiSigContainer = cryptoMocks.NewMultiSignerContainerMock(tpn.MultiSigner) cryptoComponents.BlKeyGen = tpn.OwnAccount.KeygenTxSign cryptoComponents.TxKeyGen = TestKeyGenForAccounts cryptoComponents.PeerSignHandler = psh @@ -2817,7 +2821,7 @@ func (tpn *TestProcessorNode) createHeartbeatWithHardforkTrigger() { 
HideInactiveValidatorIntervalInSec: 600, } - hbFactoryArgs := mainFactory.HeartbeatComponentsFactoryArgs{ + hbFactoryArgs := heartbeatComp.HeartbeatComponentsFactoryArgs{ Config: config.Config{ Heartbeat: hbConfig, }, @@ -2830,10 +2834,10 @@ func (tpn *TestProcessorNode) createHeartbeatWithHardforkTrigger() { ProcessComponents: tpn.Node.GetProcessComponents(), } - heartbeatFactory, err := mainFactory.NewHeartbeatComponentsFactory(hbFactoryArgs) + heartbeatFactory, err := heartbeatComp.NewHeartbeatComponentsFactory(hbFactoryArgs) log.LogIfError(err) - managedHeartbeatComponents, err := mainFactory.NewManagedHeartbeatComponents(heartbeatFactory) + managedHeartbeatComponents, err := heartbeatComp.NewManagedHeartbeatComponents(heartbeatFactory) log.LogIfError(err) err = managedHeartbeatComponents.Create() @@ -2870,7 +2874,7 @@ func (tpn *TestProcessorNode) createHeartbeatWithHardforkTrigger() { }, } - hbv2FactoryArgs := mainFactory.ArgHeartbeatV2ComponentsFactory{ + hbv2FactoryArgs := heartbeatComp.ArgHeartbeatV2ComponentsFactory{ Config: config.Config{ HeartbeatV2: hbv2Config, Hardfork: config.HardforkConfig{ @@ -2885,10 +2889,10 @@ func (tpn *TestProcessorNode) createHeartbeatWithHardforkTrigger() { ProcessComponents: tpn.Node.GetProcessComponents(), } - heartbeatV2Factory, err := mainFactory.NewHeartbeatV2ComponentsFactory(hbv2FactoryArgs) + heartbeatV2Factory, err := heartbeatComp.NewHeartbeatV2ComponentsFactory(hbv2FactoryArgs) log.LogIfError(err) - managedHeartbeatV2Components, err := mainFactory.NewManagedHeartbeatV2Components(heartbeatV2Factory) + managedHeartbeatV2Components, err := heartbeatComp.NewManagedHeartbeatV2Components(heartbeatV2Factory) log.LogIfError(err) err = managedHeartbeatV2Components.Create() @@ -3076,12 +3080,13 @@ func GetDefaultCryptoComponents() *mock.CryptoComponentsStub { PubKeyBytes: []byte("pubKey"), BlockSig: &mock.SignerMock{}, TxSig: &mock.SignerMock{}, - MultiSig: TestMultiSig, + MultiSigContainer: cryptoMocks.NewMultiSignerContainerMock(TestMultiSig), PeerSignHandler: &mock.PeerSignatureHandler{}, BlKeyGen: &mock.KeyGenMock{}, TxKeyGen: &mock.KeyGenMock{}, MsgSigVerifier: &testscommon.MessageSignVerifierMock{}, ManagedPeersHolderField: &testscommon.ManagedPeersHolderStub{}, + KeysHandlerField: &testscommon.KeysHandlerStub{}, } } @@ -3091,7 +3096,7 @@ func GetDefaultStateComponents() *testscommon.StateComponentsMock { PeersAcc: &stateMock.AccountsStub{}, Accounts: &stateMock.AccountsStub{}, AccountsRepo: &stateMock.AccountsRepositoryStub{}, - Tries: &mock.TriesHolderStub{}, + Tries: &trieMock.TriesHolderStub{}, StorageManagers: map[string]common.StorageManager{ "0": &testscommon.StorageManagerStub{}, trieFactory.UserAccountTrie: &testscommon.StorageManagerStub{}, @@ -3132,7 +3137,7 @@ func getDefaultBootstrapComponents(shardCoordinator sharding.Coordinator) *mainF return &mainFactoryMocks.BootstrapComponentsStub{ Bootstrapper: &bootstrapMocks.EpochStartBootstrapperStub{ - TrieHolder: &mock.TriesHolderStub{}, + TrieHolder: &trieMock.TriesHolderStub{}, StorageManagers: map[string]common.StorageManager{"0": &testscommon.StorageManagerStub{}}, BootstrapCalled: nil, }, @@ -3162,7 +3167,7 @@ func GetTokenIdentifier(nodes []*TestProcessorNode, ticker []byte) []byte { rootHash, _ := userAcc.DataTrie().RootHash() chLeaves := make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity) - _ = userAcc.DataTrie().GetAllLeavesOnChannel(chLeaves, context.Background(), rootHash) + _ = userAcc.DataTrie().GetAllLeavesOnChannel(chLeaves, context.Background(), 
rootHash, keyBuilder.NewKeyBuilder()) for leaf := range chLeaves { if !bytes.HasPrefix(leaf.Key(), ticker) { continue @@ -3250,7 +3255,8 @@ func getDefaultNodesCoordinator(maxShards uint32, pksBytes map[uint32][]byte) no } } -func getDefaultEnableEpochsConfig() *config.EnableEpochs { +// GetDefaultEnableEpochsConfig returns a default EnableEpochs config +func GetDefaultEnableEpochsConfig() *config.EnableEpochs { return &config.EnableEpochs{ OptimizeGasUsedInCrossMiniBlocksEnableEpoch: UnreachableEpoch, ScheduledMiniBlocksEnableEpoch: UnreachableEpoch, diff --git a/integrationTests/testProcessorNodeWithCoordinator.go b/integrationTests/testProcessorNodeWithCoordinator.go index e573900ef86..43171a395e0 100644 --- a/integrationTests/testProcessorNodeWithCoordinator.go +++ b/integrationTests/testProcessorNodeWithCoordinator.go @@ -8,12 +8,11 @@ import ( crypto "github.com/ElrondNetwork/elrond-go-crypto" "github.com/ElrondNetwork/elrond-go-crypto/signing" "github.com/ElrondNetwork/elrond-go-crypto/signing/ed25519" - ed25519SingleSig "github.com/ElrondNetwork/elrond-go-crypto/signing/ed25519/singlesig" "github.com/ElrondNetwork/elrond-go-crypto/signing/mcl" "github.com/ElrondNetwork/elrond-go/integrationTests/mock" "github.com/ElrondNetwork/elrond-go/sharding" "github.com/ElrondNetwork/elrond-go/sharding/nodesCoordinator" - "github.com/ElrondNetwork/elrond-go/storage/lrucache" + "github.com/ElrondNetwork/elrond-go/storage/cache" "github.com/ElrondNetwork/elrond-go/testscommon" vic "github.com/ElrondNetwork/elrond-go/testscommon/validatorInfoCacher" ) @@ -59,7 +58,7 @@ func CreateProcessorNodesWithNodesCoordinator( for shardId, validatorList := range validatorsMap { nodesList := make([]*TestProcessorNode, len(validatorList)) for i, v := range validatorList { - cache, _ := lrucache.NewCache(10000) + lruCache, _ := cache.NewLRUCache(10000) argumentsNodesCoordinator := nodesCoordinator.ArgNodesCoordinator{ ShardConsensusGroupSize: shardConsensusGroupSize, MetaConsensusGroupSize: metaConsensusGroupSize, @@ -70,7 +69,7 @@ func CreateProcessorNodesWithNodesCoordinator( EligibleNodes: validatorsMapForNodesCoordinator, WaitingNodes: waitingMapForNodesCoordinator, SelfPublicKey: v.PubKeyBytes(), - ConsensusGroupCache: cache, + ConsensusGroupCache: lruCache, ShuffledOutHandler: &mock.ShuffledOutHandlerStub{}, ChanStopNode: endProcess.GetDummyEndProcessChannel(), IsFullArchive: false, @@ -83,7 +82,7 @@ func CreateProcessorNodesWithNodesCoordinator( fmt.Println("error creating node coordinator") } - multiSigner, err := createMultiSigner(*cp, shardId, i) + multiSigner, err := createMultiSigner(*cp) if err != nil { log.Error("error generating multisigner: %s\n", err) return nil, 0 @@ -92,8 +91,8 @@ func CreateProcessorNodesWithNodesCoordinator( kp := ncp[shardId][i] ownAccount := &TestWalletAccount{ - SingleSigner: createTestSingleSigner(), - BlockSingleSigner: createTestSingleSigner(), + SingleSigner: TestSingleSigner, + BlockSingleSigner: TestSingleSigner, SkTxSign: kp.TxSignSk, PkTxSign: kp.TxSignPk, PkTxSignBytes: kp.TxSignPkBytes, @@ -128,10 +127,6 @@ func CreateProcessorNodesWithNodesCoordinator( return nodesMap, numShards } -func createTestSingleSigner() crypto.SingleSigner { - return &ed25519SingleSig.Ed25519Signer{} -} - func createNodesCryptoParams(rewardsAddrsAssignments map[uint32][]uint32) (map[uint32][]*nodeKeys, uint32) { numShards := uint32(0) suiteBlock := mcl.NewSuiteBLS12() diff --git a/integrationTests/testProcessorNodeWithMultisigner.go 
b/integrationTests/testProcessorNodeWithMultisigner.go index 7294bafc560..bb7a6e08282 100644 --- a/integrationTests/testProcessorNodeWithMultisigner.go +++ b/integrationTests/testProcessorNodeWithMultisigner.go @@ -26,9 +26,10 @@ import ( "github.com/ElrondNetwork/elrond-go/process/rating" "github.com/ElrondNetwork/elrond-go/sharding" "github.com/ElrondNetwork/elrond-go/sharding/nodesCoordinator" - "github.com/ElrondNetwork/elrond-go/storage/lrucache" - "github.com/ElrondNetwork/elrond-go/storage/storageUnit" + "github.com/ElrondNetwork/elrond-go/storage/cache" + "github.com/ElrondNetwork/elrond-go/storage/storageunit" "github.com/ElrondNetwork/elrond-go/testscommon" + "github.com/ElrondNetwork/elrond-go/testscommon/cryptoMocks" "github.com/ElrondNetwork/elrond-go/testscommon/nodeTypeProviderMock" "github.com/ElrondNetwork/elrond-go/testscommon/shardingMocks" vic "github.com/ElrondNetwork/elrond-go/testscommon/validatorInfoCacher" @@ -97,7 +98,7 @@ func CreateNodesWithNodesCoordinatorAndTxKeys( nodesList := make([]*TestProcessorNode, len(validatorList)) for i := range validatorList { - dataCache, _ := lrucache.NewCache(10000) + dataCache, _ := cache.NewLRUCache(10000) tpn := CreateNodeWithBLSAndTxKeys( nodesPerShard, nbMetaNodes, @@ -166,7 +167,7 @@ func CreateNodeWithBLSAndTxKeys( twa.KeygenBlockSign = &mock.KeyGenMock{} twa.Address = twa.PkTxSignBytes - peerSigCache, _ := storageUnit.NewCache(storageUnit.CacheConfig{Type: storageUnit.LRUCache, Capacity: 1000}) + peerSigCache, _ := storageunit.NewCache(storageunit.CacheConfig{Type: storageunit.LRUCache, Capacity: 1000}) twa.PeerSigHandler, _ = peerSignatureHandler.NewPeerSignatureHandler(peerSigCache, twa.SingleSigner, keyGen) epochsConfig := config.EnableEpochs{ @@ -242,7 +243,7 @@ func CreateNodesWithNodesCoordinatorFactory( nodesListWaiting := make([]*TestProcessorNode, len(waitingMap[shardId])) for i := range validatorList { - dataCache, _ := lrucache.NewCache(10000) + dataCache, _ := cache.NewLRUCache(10000) tpn := CreateNode( nodesPerShard, nbMetaNodes, @@ -266,7 +267,7 @@ func CreateNodesWithNodesCoordinatorFactory( } for i := range waitingMap[shardId] { - dataCache, _ := lrucache.NewCache(10000) + dataCache, _ := cache.NewLRUCache(10000) tpn := CreateNode( nodesPerShard, nbMetaNodes, @@ -343,7 +344,7 @@ func CreateNode( txSignPrivKeyShardId = 0 } - multiSigner, err := createMultiSigner(*cp, shardId, keyIndex) + multiSigner, err := createMultiSigner(*cp) if err != nil { log.Error("error generating multisigner: %s\n", err) return nil @@ -409,7 +410,7 @@ func CreateNodesWithNodesCoordinatorAndHeaderSigVerifier( completeNodesList := make([]Connectable, 0) for shardId, validatorList := range validatorsMap { - consensusCache, _ := lrucache.NewCache(10000) + consensusCache, _ := cache.NewLRUCache(10000) argumentsNodesCoordinator := nodesCoordinator.ArgNodesCoordinator{ ShardConsensusGroupSize: shardConsensusGroupSize, MetaConsensusGroupSize: metaConsensusGroupSize, @@ -442,7 +443,7 @@ func CreateNodesWithNodesCoordinatorAndHeaderSigVerifier( Marshalizer: TestMarshalizer, Hasher: TestHasher, NodesCoordinator: nodesCoordinatorInstance, - MultiSigVerifier: TestMultiSig, + MultiSigContainer: cryptoMocks.NewMultiSignerContainerMock(TestMultiSig), SingleSigVerifier: signer, KeyGen: keyGen, FallbackHeaderValidator: &testscommon.FallBackHeaderValidatorStub{}, @@ -455,7 +456,7 @@ func CreateNodesWithNodesCoordinatorAndHeaderSigVerifier( } for i := range validatorList { - multiSigner, err := createMultiSigner(*cp, shardId, i) + multiSigner, err := 
createMultiSigner(*cp) if err != nil { log.Error("error generating multisigner: %s\n", err) return nil @@ -524,7 +525,7 @@ func CreateNodesWithNodesCoordinatorKeygenAndSingleSigner( completeNodesList := make([]Connectable, 0) for shardId, validatorList := range validatorsMap { bootStorer := CreateMemUnit() - cache, _ := lrucache.NewCache(10000) + lruCache, _ := cache.NewLRUCache(10000) argumentsNodesCoordinator := nodesCoordinator.ArgNodesCoordinator{ ShardConsensusGroupSize: shardConsensusGroupSize, MetaConsensusGroupSize: metaConsensusGroupSize, @@ -538,7 +539,7 @@ func CreateNodesWithNodesCoordinatorKeygenAndSingleSigner( EligibleNodes: validatorsMapForNodesCoordinator, WaitingNodes: waitingMapForNodesCoordinator, SelfPublicKey: []byte(strconv.Itoa(int(shardId))), - ConsensusGroupCache: cache, + ConsensusGroupCache: lruCache, ShuffledOutHandler: &mock.ShuffledOutHandlerStub{}, ChanStopNode: endProcess.GetDummyEndProcessChannel(), NodeTypeProvider: &nodeTypeProviderMock.NodeTypeProviderStub{}, @@ -571,7 +572,7 @@ func CreateNodesWithNodesCoordinatorKeygenAndSingleSigner( Marshalizer: TestMarshalizer, Hasher: TestHasher, NodesCoordinator: nodesCoord, - MultiSigVerifier: TestMultiSig, + MultiSigContainer: cryptoMocks.NewMultiSignerContainerMock(TestMultiSig), SingleSigVerifier: singleSigner, KeyGen: keyGenForBlocks, FallbackHeaderValidator: &testscommon.FallBackHeaderValidatorStub{}, @@ -579,7 +580,7 @@ func CreateNodesWithNodesCoordinatorKeygenAndSingleSigner( headerSig, _ := headerCheck.NewHeaderSigVerifier(&args) - multiSigner, err := createMultiSigner(*cp, shardId, i) + multiSigner, err := createMultiSigner(*cp) if err != nil { log.Error("error generating multisigner: %s\n", err) return nil @@ -697,17 +698,17 @@ func DoConsensusSigningOnBlock( blockHeaderHash, _ := core.CalculateHash(TestMarshalizer, TestHasher, blockHeader) - var msig crypto.MultiSigner - msigProposer, _ := consensusNodes[0].MultiSigner.Create(pubKeys, 0) - _, _ = msigProposer.CreateSignatureShare(blockHeaderHash, bitmap) + pubKeysBytes := make([][]byte, len(consensusNodes)) + sigShares := make([][]byte, len(consensusNodes)) + msig := consensusNodes[0].MultiSigner - for i := 1; i < len(consensusNodes); i++ { - msig, _ = consensusNodes[i].MultiSigner.Create(pubKeys, uint16(i)) - sigShare, _ := msig.CreateSignatureShare(blockHeaderHash, bitmap) - _ = msigProposer.StoreSignatureShare(uint16(i), sigShare) + for i := 0; i < len(consensusNodes); i++ { + pubKeysBytes[i] = []byte(pubKeys[i]) + sk, _ := consensusNodes[i].NodeKeys.Sk.ToByteArray() + sigShares[i], _ = msig.CreateSignatureShare(sk, blockHeaderHash) } - sig, _ := msigProposer.AggregateSigs(bitmap) + sig, _ := msig.AggregateSigs(pubKeysBytes, sigShares) err = blockHeader.SetSignature(sig) if err != nil { log.Error("blockHeader.SetSignature", "error", err) @@ -789,17 +790,11 @@ func SyncAllShardsWithRoundBlock( time.Sleep(4 * StepDelay) } -func createMultiSigner(cp CryptoParams, shardId uint32, ownKeyIndex int) (crypto.MultiSigner, error) { +func createMultiSigner(cp CryptoParams) (crypto.MultiSigner, error) { blsHasher, _ := blake2b.NewBlake2bWithSize(hashing.BlsHashSize) llsig := &mclmultisig.BlsMultiSigner{Hasher: blsHasher} - - pubKeysMap := PubKeysMapFromKeysMap(cp.Keys) - return multisig.NewBLSMultisig( llsig, - pubKeysMap[shardId], - cp.Keys[shardId][ownKeyIndex].Sk, cp.KeyGen, - 0, ) } diff --git a/integrationTests/testStorage.go b/integrationTests/testStorage.go index 277212066f2..17bdf5034e9 100644 --- a/integrationTests/testStorage.go +++ 
b/integrationTests/testStorage.go @@ -13,9 +13,9 @@ import ( "github.com/ElrondNetwork/elrond-go-core/core" "github.com/ElrondNetwork/elrond-go-core/data/transaction" "github.com/ElrondNetwork/elrond-go/storage" - "github.com/ElrondNetwork/elrond-go/storage/leveldb" - "github.com/ElrondNetwork/elrond-go/storage/lrucache" - "github.com/ElrondNetwork/elrond-go/storage/storageUnit" + "github.com/ElrondNetwork/elrond-go/storage/cache" + "github.com/ElrondNetwork/elrond-go/storage/database" + "github.com/ElrondNetwork/elrond-go/storage/storageunit" ) const batchDelaySeconds = 10 @@ -77,9 +77,9 @@ func (ts *TestStorage) CreateStoredData(nonce uint64) ([]byte, []byte) { // CreateStorageLevelDB creates a storage levelDB func (ts *TestStorage) CreateStorageLevelDB() storage.Storer { - db, _ := leveldb.NewDB("Transactions", batchDelaySeconds, maxBatchSize, maxOpenFiles) - cacher, _ := lrucache.NewCache(50000) - store, _ := storageUnit.NewStorageUnit( + db, _ := database.NewLevelDB("Transactions", batchDelaySeconds, maxBatchSize, maxOpenFiles) + cacher, _ := cache.NewLRUCache(50000) + store, _ := storageunit.NewStorageUnit( cacher, db, ) @@ -89,9 +89,9 @@ func (ts *TestStorage) CreateStorageLevelDB() storage.Storer { // CreateStorageLevelDBSerial creates a storage levelDB serial func (ts *TestStorage) CreateStorageLevelDBSerial() storage.Storer { - db, _ := leveldb.NewSerialDB("Transactions", batchDelaySeconds, maxBatchSize, maxOpenFiles) - cacher, _ := lrucache.NewCache(50000) - store, _ := storageUnit.NewStorageUnit( + db, _ := database.NewSerialDB("Transactions", batchDelaySeconds, maxBatchSize, maxOpenFiles) + cacher, _ := cache.NewLRUCache(50000) + store, _ := storageunit.NewStorageUnit( cacher, db, ) diff --git a/integrationTests/testWalletAccount.go b/integrationTests/testWalletAccount.go index 3cf07cf40ce..e5cd0257b98 100644 --- a/integrationTests/testWalletAccount.go +++ b/integrationTests/testWalletAccount.go @@ -6,11 +6,10 @@ import ( "math/big" "github.com/ElrondNetwork/elrond-go-crypto" - ed25519SingleSig "github.com/ElrondNetwork/elrond-go-crypto/signing/ed25519/singlesig" "github.com/ElrondNetwork/elrond-go/factory/peerSignatureHandler" "github.com/ElrondNetwork/elrond-go/integrationTests/mock" "github.com/ElrondNetwork/elrond-go/sharding" - "github.com/ElrondNetwork/elrond-go/storage/storageUnit" + "github.com/ElrondNetwork/elrond-go/storage/storageunit" ) // TestWalletAccount creates and account with balance and crypto necessary to sign transactions @@ -53,7 +52,7 @@ func CreateTestWalletAccountWithKeygenAndSingleSigner( // initCrypto initializes the crypto for the account func (twa *TestWalletAccount) initCrypto(coordinator sharding.Coordinator, shardId uint32) { - twa.SingleSigner = &ed25519SingleSig.Ed25519Signer{} + twa.SingleSigner = TestSingleSigner twa.BlockSingleSigner = &mock.SignerMock{ VerifyStub: func(public crypto.PublicKey, msg []byte, sig []byte) error { return nil @@ -71,7 +70,7 @@ func (twa *TestWalletAccount) initCrypto(coordinator sharding.Coordinator, shard twa.KeygenBlockSign = &mock.KeyGenMock{} twa.Address = twa.PkTxSignBytes - peerSigCache, _ := storageUnit.NewCache(storageUnit.CacheConfig{Type: storageUnit.LRUCache, Capacity: 1000}) + peerSigCache, _ := storageunit.NewCache(storageunit.CacheConfig{Type: storageunit.LRUCache, Capacity: 1000}) twa.PeerSigHandler, _ = peerSignatureHandler.NewPeerSignatureHandler(peerSigCache, twa.SingleSigner, keyGen) } diff --git a/integrationTests/vm/arwen/arwenvm/mandosConverter/mandosConverter.go 
b/integrationTests/vm/arwen/arwenvm/mandosConverter/mandosConverter.go index 209e817381a..6279cc94735 100644 --- a/integrationTests/vm/arwen/arwenvm/mandosConverter/mandosConverter.go +++ b/integrationTests/vm/arwen/arwenvm/mandosConverter/mandosConverter.go @@ -33,7 +33,7 @@ func CreateAccountsFromMandosAccs(tc *vm.VMTestContext, mandosUserAccounts []*mg mandosAccStorage := mandosAcc.GetStorage() for key, value := range mandosAccStorage { - err = account.DataTrieTracker().SaveKeyValue([]byte(key), value) + err = account.SaveKeyValue([]byte(key), value) if err != nil { return err } diff --git a/integrationTests/vm/arwen/arwenvm/mandosConverter/mandosConverterUtils.go b/integrationTests/vm/arwen/arwenvm/mandosConverter/mandosConverterUtils.go index 05106b33704..ed3d3c025e7 100644 --- a/integrationTests/vm/arwen/arwenvm/mandosConverter/mandosConverterUtils.go +++ b/integrationTests/vm/arwen/arwenvm/mandosConverter/mandosConverterUtils.go @@ -39,13 +39,12 @@ func CheckAccounts(t *testing.T, accAdapter state.AccountsAdapter, mandosAccount require.Equal(t, len(mandosAcc.GetCode()), len(code)) mandosAccStorage := mandosAcc.GetStorage() - accStorage := account.DataTrieTracker() - CheckStorage(t, accStorage, mandosAccStorage) + CheckStorage(t, account, mandosAccStorage) } } // CheckStorage checks if the dataTrie of an account equals with the storage of the corresponding mandosAccount -func CheckStorage(t *testing.T, dataTrie state.DataTrieTracker, mandosAccStorage map[string][]byte) { +func CheckStorage(t *testing.T, dataTrie state.UserAccountHandler, mandosAccStorage map[string][]byte) { for key := range mandosAccStorage { dataTrieValue, err := dataTrie.RetrieveValue([]byte(key)) require.Nil(t, err) diff --git a/integrationTests/vm/arwen/delegation/testRunner.go b/integrationTests/vm/arwen/delegation/testRunner.go index 67bdb743e3b..fb41205636c 100644 --- a/integrationTests/vm/arwen/delegation/testRunner.go +++ b/integrationTests/vm/arwen/delegation/testRunner.go @@ -18,7 +18,7 @@ import ( "github.com/ElrondNetwork/elrond-go/process" "github.com/ElrondNetwork/elrond-go/state" "github.com/ElrondNetwork/elrond-go/storage/factory" - "github.com/ElrondNetwork/elrond-go/storage/storageUnit" + "github.com/ElrondNetwork/elrond-go/storage/storageunit" systemVm "github.com/ElrondNetwork/elrond-go/vm" vmcommon "github.com/ElrondNetwork/elrond-vm-common" ) @@ -35,13 +35,13 @@ func RunDelegationStressTest( gasSchedule map[string]map[string]uint64, ) ([]time.Duration, error) { - cacheConfig := storageUnit.CacheConfig{ + cacheConfig := storageunit.CacheConfig{ Name: "trie", Type: "SizeLRU", SizeInBytes: 314572800, // 300MB Capacity: 500000, } - trieCache, err := storageUnit.NewCache(cacheConfig) + trieCache, err := storageunit.NewCache(cacheConfig) if err != nil { return nil, err } @@ -64,7 +64,7 @@ func RunDelegationStressTest( return nil, err } - trieStorage, err := storageUnit.NewStorageUnit(trieCache, triePersister) + trieStorage, err := storageunit.NewStorageUnit(trieCache, triePersister) if err != nil { return nil, err } diff --git a/integrationTests/vm/esdt/localFuncs/esdtLocalFunsSC_test.go b/integrationTests/vm/esdt/localFuncs/esdtLocalFunsSC_test.go index 0d6d47ec113..aace8753701 100644 --- a/integrationTests/vm/esdt/localFuncs/esdtLocalFunsSC_test.go +++ b/integrationTests/vm/esdt/localFuncs/esdtLocalFunsSC_test.go @@ -380,6 +380,6 @@ func checkDataFromAccountAndKey( expectedData []byte, ) { userAcc := esdtCommon.GetUserAccountWithAddress(t, address, nodes) - val, _ := 
userAcc.DataTrieTracker().RetrieveValue(key) + val, _ := userAcc.RetrieveValue(key) assert.Equal(t, expectedData, val) } diff --git a/integrationTests/vm/esdt/process/esdtProcess_test.go b/integrationTests/vm/esdt/process/esdtProcess_test.go index dc57b3041be..15ee84372f4 100644 --- a/integrationTests/vm/esdt/process/esdtProcess_test.go +++ b/integrationTests/vm/esdt/process/esdtProcess_test.go @@ -144,11 +144,11 @@ func TestESDTIssueAndTransactionsOnMultiShardEnvironment(t *testing.T) { wipedAcc := esdtCommon.GetUserAccountWithAddress(t, nodes[2].OwnAccount.Address, nodes) tokenKey := []byte(core.ElrondProtectedKeyPrefix + "esdt" + tokenIdentifier) - retrievedData, _ := wipedAcc.DataTrieTracker().RetrieveValue(tokenKey) + retrievedData, _ := wipedAcc.RetrieveValue(tokenKey) require.Equal(t, 0, len(retrievedData)) systemSCAcc := esdtCommon.GetUserAccountWithAddress(t, core.SystemAccountAddress, nodes) - retrievedData, _ = systemSCAcc.DataTrieTracker().RetrieveValue(tokenKey) + retrievedData, _ = systemSCAcc.RetrieveValue(tokenKey) esdtGlobalMetaData := vmcommonBuiltInFunctions.ESDTGlobalMetadataFromBytes(retrievedData) require.True(t, esdtGlobalMetaData.Paused) @@ -156,7 +156,7 @@ func TestESDTIssueAndTransactionsOnMultiShardEnvironment(t *testing.T) { esdtCommon.CheckAddressHasTokens(t, tokenIssuer.OwnAccount.Address, nodes, []byte(tokenIdentifier), 0, finalSupply) esdtSCAcc := esdtCommon.GetUserAccountWithAddress(t, vm.ESDTSCAddress, nodes) - retrievedData, _ = esdtSCAcc.DataTrieTracker().RetrieveValue([]byte(tokenIdentifier)) + retrievedData, _ = esdtSCAcc.RetrieveValue([]byte(tokenIdentifier)) tokenInSystemSC := &systemSmartContracts.ESDTDataV2{} _ = integrationTests.TestMarshalizer.Unmarshal(tokenInSystemSC, retrievedData) require.Zero(t, tokenInSystemSC.MintedValue.Cmp(big.NewInt(initialSupply+mintValue))) @@ -258,7 +258,7 @@ func TestESDTCallBurnOnANonBurnableToken(t *testing.T) { time.Sleep(time.Second) esdtSCAcc := esdtCommon.GetUserAccountWithAddress(t, vm.ESDTSCAddress, nodes) - retrievedData, _ := esdtSCAcc.DataTrieTracker().RetrieveValue([]byte(tokenIdentifier)) + retrievedData, _ := esdtSCAcc.RetrieveValue([]byte(tokenIdentifier)) tokenInSystemSC := &systemSmartContracts.ESDTDataV2{} _ = integrationTests.TestMarshalizer.Unmarshal(tokenInSystemSC, retrievedData) require.Equal(t, initialSupply, tokenInSystemSC.MintedValue.Int64()) diff --git a/integrationTests/vm/testInitializer.go b/integrationTests/vm/testInitializer.go index a92820c7be2..dc76d4418f5 100644 --- a/integrationTests/vm/testInitializer.go +++ b/integrationTests/vm/testInitializer.go @@ -51,8 +51,8 @@ import ( "github.com/ElrondNetwork/elrond-go/state/storagePruningManager" "github.com/ElrondNetwork/elrond-go/state/storagePruningManager/evictionWaitingList" "github.com/ElrondNetwork/elrond-go/storage" - "github.com/ElrondNetwork/elrond-go/storage/memorydb" - "github.com/ElrondNetwork/elrond-go/storage/storageUnit" + "github.com/ElrondNetwork/elrond-go/storage/database" + "github.com/ElrondNetwork/elrond-go/storage/storageunit" "github.com/ElrondNetwork/elrond-go/storage/txcache" "github.com/ElrondNetwork/elrond-go/testscommon" dataRetrieverMock "github.com/ElrondNetwork/elrond-go/testscommon/dataRetriever" @@ -281,16 +281,16 @@ func CreateMemUnit() storage.Storer { capacity := uint32(10) shards := uint32(1) sizeInBytes := uint64(0) - cache, _ := storageUnit.NewCache(storageUnit.CacheConfig{Type: storageUnit.LRUCache, Capacity: capacity, Shards: shards, SizeInBytes: sizeInBytes}) + cache, _ := 
storageunit.NewCache(storageunit.CacheConfig{Type: storageunit.LRUCache, Capacity: capacity, Shards: shards, SizeInBytes: sizeInBytes}) - unit, _ := storageUnit.NewStorageUnit(cache, memorydb.New()) + unit, _ := storageunit.NewStorageUnit(cache, database.NewMemDB()) return unit } // CreateInMemoryShardAccountsDB - func CreateInMemoryShardAccountsDB() *state.AccountsDB { marshaller := &marshal.GogoProtoMarshalizer{} - ewl, _ := evictionWaitingList.NewEvictionWaitingList(100, memorydb.New(), marshaller) + ewl, _ := evictionWaitingList.NewEvictionWaitingList(100, database.NewMemDB(), marshaller) generalCfg := config.TrieStorageManagerConfig{ PruningBufferLen: 1000, SnapshotsBufferLen: 10, @@ -920,8 +920,8 @@ func CreateTxProcessorWithOneSCExecutorWithVMs( argsNewSCProcessor.AccountsDB = readOnlyAccountsDB - vmOutputCacher, _ := storageUnit.NewCache(storageUnit.CacheConfig{ - Type: storageUnit.LRUCache, + vmOutputCacher, _ := storageunit.NewCache(storageunit.CacheConfig{ + Type: storageunit.LRUCache, Capacity: 10000, }) txSimulatorProcessorArgs := txsimulator.ArgsTxSimulator{ @@ -1003,7 +1003,7 @@ func TestDeployedContractContents( assert.NotNil(t, destinationRecovShardAccount.GetRootHash()) for variable, requiredVal := range dataValues { - contractVariableData, err := destinationRecovShardAccount.DataTrieTracker().RetrieveValue([]byte(variable)) + contractVariableData, err := destinationRecovShardAccount.RetrieveValue([]byte(variable)) assert.Nil(t, err) assert.NotNil(t, contractVariableData) diff --git a/integrationTests/vm/txsFee/utils/utilsESDT.go b/integrationTests/vm/txsFee/utils/utilsESDT.go index 042d71df482..6d1b2286e42 100644 --- a/integrationTests/vm/txsFee/utils/utilsESDT.go +++ b/integrationTests/vm/txsFee/utils/utilsESDT.go @@ -55,7 +55,7 @@ func CreateAccountWithESDTBalance( key = append(key, big.NewInt(0).SetUint64(esdtNonce).Bytes()...) } - err = userAccount.DataTrieTracker().SaveKeyValue(key, esdtDataBytes) + err = userAccount.SaveKeyValue(key, esdtDataBytes) require.Nil(t, err) err = accnts.SaveAccount(account) @@ -82,7 +82,7 @@ func saveNewTokenOnSystemAccount(t *testing.T, accnts state.AccountsAdapter, tok sysUserAccount, ok := sysAccount.(state.UserAccountHandler) require.True(t, ok) - err = sysUserAccount.DataTrieTracker().SaveKeyValue(tokenKey, esdtDataBytes) + err = sysUserAccount.SaveKeyValue(tokenKey, esdtDataBytes) require.Nil(t, err) err = accnts.SaveAccount(sysAccount) @@ -122,7 +122,7 @@ func SetESDTRoles( key = append(key, tokenIdentifier...) if len(roles) == 0 { - err = userAccount.DataTrieTracker().SaveKeyValue(key, []byte{}) + err = userAccount.SaveKeyValue(key, []byte{}) require.Nil(t, err) return @@ -135,7 +135,7 @@ func SetESDTRoles( rolesDataBytes, err := protoMarshalizer.Marshal(rolesData) require.Nil(t, err) - err = userAccount.DataTrieTracker().SaveKeyValue(key, rolesDataBytes) + err = userAccount.SaveKeyValue(key, rolesDataBytes) require.Nil(t, err) err = accnts.SaveAccount(account) @@ -162,7 +162,7 @@ func SetLastNFTNonce( key := append([]byte(core.ElrondProtectedKeyPrefix), []byte(core.ESDTNFTLatestNonceIdentifier)...) key = append(key, tokenIdentifier...) 
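The hunks above and below this point all make the same substitution: the DataTrieTracker() indirection is dropped because SaveKeyValue and RetrieveValue now live directly on state.UserAccountHandler. A minimal sketch of the new call shape (roundTripKeyValue is a hypothetical helper, not part of this patch):

package example

import "github.com/ElrondNetwork/elrond-go/state"

// roundTripKeyValue writes a key-value pair into the account's data trie and
// reads it back through the post-refactor accessors on UserAccountHandler.
func roundTripKeyValue(account state.UserAccountHandler, key, value []byte) ([]byte, error) {
	// before this patch: account.DataTrieTracker().SaveKeyValue(key, value)
	if err := account.SaveKeyValue(key, value); err != nil {
		return nil, err
	}
	// before this patch: account.DataTrieTracker().RetrieveValue(key)
	return account.RetrieveValue(key)
}
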
- err = userAccount.DataTrieTracker().SaveKeyValue(key, big.NewInt(int64(lastNonce)).Bytes()) + err = userAccount.SaveKeyValue(key, big.NewInt(int64(lastNonce)).Bytes()) require.Nil(t, err) err = accnts.SaveAccount(account) diff --git a/integrationTests/vm/txsFee/validatorSC_test.go b/integrationTests/vm/txsFee/validatorSC_test.go index 762f71d87c8..b59f6561e1d 100644 --- a/integrationTests/vm/txsFee/validatorSC_test.go +++ b/integrationTests/vm/txsFee/validatorSC_test.go @@ -44,7 +44,7 @@ func saveDelegationManagerConfig(testContext *vm.VMTestContext) { managementData := &systemSmartContracts.DelegationManagement{MinDelegationAmount: big.NewInt(1)} marshaledData, _ := testContext.Marshalizer.Marshal(managementData) - _ = userAcc.DataTrieTracker().SaveKeyValue([]byte(delegationManagementKey), marshaledData) + _ = userAcc.SaveKeyValue([]byte(delegationManagementKey), marshaledData) _ = testContext.Accounts.SaveAccount(userAcc) } @@ -298,7 +298,7 @@ func saveNodesConfig(t *testing.T, testContext *vm.VMTestContext, stakedNodes, m } nodesDataBytes, _ := protoMarshalizer.Marshal(nodesConfigData) - _ = userAccount.DataTrieTracker().SaveKeyValue([]byte("nodesConfig"), nodesDataBytes) + _ = userAccount.SaveKeyValue([]byte("nodesConfig"), nodesDataBytes) _ = testContext.Accounts.SaveAccount(account) _, _ = testContext.Accounts.Commit() } diff --git a/node/external/blockAPI/baseBlock_test.go b/node/external/blockAPI/baseBlock_test.go index c19db5899d1..a3007c838cd 100644 --- a/node/external/blockAPI/baseBlock_test.go +++ b/node/external/blockAPI/baseBlock_test.go @@ -115,13 +115,13 @@ func TestBaseBlockGetIntraMiniblocksReceipts(t *testing.T) { TxHashes: [][]byte{receiptHash}, } - receipt := &receipt.Receipt{ + receiptObj := &receipt.Receipt{ Value: big.NewInt(1000), SndAddr: []byte("sndAddr"), Data: []byte("refund"), TxHash: []byte("hash"), } - receiptBytes, _ := baseAPIBlockProc.marshalizer.Marshal(receipt) + receiptBytes, _ := baseAPIBlockProc.marshalizer.Marshal(receiptObj) baseAPIBlockProc.store = genericMocks.NewChainStorerMock(0) storer, _ := baseAPIBlockProc.store.GetStorer(dataRetriever.UnsignedTransactionUnit) @@ -136,10 +136,10 @@ func TestBaseBlockGetIntraMiniblocksReceipts(t *testing.T) { baseAPIBlockProc.apiTransactionHandler = &mock.TransactionAPIHandlerStub{ UnmarshalReceiptCalled: func(receiptBytes []byte) (*transaction.ApiReceipt, error) { return &transaction.ApiReceipt{ - Value: receipt.Value, - SndAddr: baseAPIBlockProc.addressPubKeyConverter.Encode(receipt.SndAddr), - Data: string(receipt.Data), - TxHash: hex.EncodeToString(receipt.TxHash), + Value: receiptObj.Value, + SndAddr: baseAPIBlockProc.addressPubKeyConverter.Encode(receiptObj.SndAddr), + Data: string(receiptObj.Data), + TxHash: hex.EncodeToString(receiptObj.TxHash), }, nil }, } diff --git a/node/external/timemachine/fee/memory_test.go b/node/external/timemachine/fee/memoryFootprint/memory_test.go similarity index 82% rename from node/external/timemachine/fee/memory_test.go rename to node/external/timemachine/fee/memoryFootprint/memory_test.go index e41c34f33df..5316cc3932a 100644 --- a/node/external/timemachine/fee/memory_test.go +++ b/node/external/timemachine/fee/memoryFootprint/memory_test.go @@ -1,4 +1,4 @@ -package fee +package memoryFootprint import ( "fmt" @@ -7,21 +7,24 @@ import ( "github.com/ElrondNetwork/elrond-go-core/core" "github.com/ElrondNetwork/elrond-go-core/data/transaction" + "github.com/ElrondNetwork/elrond-go/node/external/timemachine/fee" "github.com/ElrondNetwork/elrond-go/testscommon" 
"github.com/stretchr/testify/require" ) +// keep this test in a separate package as to not be influenced by other the tests from the same package func TestFeeComputer_MemoryFootprint(t *testing.T) { numEpochs := 10000 - maxFootprintNumBytes := 20_000_000 + maxFootprintNumBytes := 30_000_000 journal := &memoryFootprintJournal{} journal.before = getMemStats() - computer, _ := NewFeeComputer(ArgsNewFeeComputer{ + feeComputer, _ := fee.NewFeeComputer(fee.ArgsNewFeeComputer{ BuiltInFunctionsCostHandler: &testscommon.BuiltInCostHandlerStub{}, EconomicsConfig: testscommon.GetEconomicsConfig(), }) + computer := fee.NewTestFeeComputer(feeComputer) tx := &transaction.Transaction{ GasLimit: 50000, @@ -43,7 +46,7 @@ func TestFeeComputer_MemoryFootprint(t *testing.T) { _ = computer.ComputeTransactionFee(&transaction.ApiTransactionResult{Epoch: uint32(0), Tx: tx}) journal.display() - require.Len(t, computer.economicsInstances, numEpochs) + require.Equal(t, numEpochs, computer.LenEconomicsInstances()) require.Less(t, journal.footprint(), uint64(maxFootprintNumBytes)) } diff --git a/node/external/timemachine/fee/testFeeComputer.go b/node/external/timemachine/fee/testFeeComputer.go new file mode 100644 index 00000000000..fc003effb6d --- /dev/null +++ b/node/external/timemachine/fee/testFeeComputer.go @@ -0,0 +1,26 @@ +package fee + +// testFeeComputer is an exported struct that should be used only in tests +type testFeeComputer struct { + *feeComputer +} + +// NewTestFeeComputer creates a new instance of type testFeeComputer +func NewTestFeeComputer(feeComputerInstance *feeComputer) *testFeeComputer { + return &testFeeComputer{ + feeComputer: feeComputerInstance, + } +} + +// LenEconomicsInstances returns the number of economic instances +func (computer *testFeeComputer) LenEconomicsInstances() int { + computer.mutex.RLock() + defer computer.mutex.RUnlock() + + return len(computer.economicsInstances) +} + +// IsInterfaceNil returns true if there is no value under the interface +func (computer *testFeeComputer) IsInterfaceNil() bool { + return computer == nil +} diff --git a/node/external/transactionAPI/apiTransactionProcessor_test.go b/node/external/transactionAPI/apiTransactionProcessor_test.go index 450a9a1aed8..457bb8aa981 100644 --- a/node/external/transactionAPI/apiTransactionProcessor_test.go +++ b/node/external/transactionAPI/apiTransactionProcessor_test.go @@ -374,6 +374,7 @@ func testWithMissingStorer(missingUnit dataRetriever.UnitType) func(t *testing.T apiTransactionProc, _ := NewAPITransactionProcessor(args) _, err := apiTransactionProc.getTransactionFromStorage([]byte("txHash")) + require.NotNil(t, err) require.True(t, strings.Contains(err.Error(), ErrTransactionNotFound.Error())) } } diff --git a/node/mock/accountsWrapperMock.go b/node/mock/accountsWrapperMock.go deleted file mode 100644 index b2d80d80976..00000000000 --- a/node/mock/accountsWrapperMock.go +++ /dev/null @@ -1,165 +0,0 @@ -package mock - -import ( - "math/big" - - "github.com/ElrondNetwork/elrond-go/common" - "github.com/ElrondNetwork/elrond-go/state" -) - -// AccountWrapMock - -type AccountWrapMock struct { - MockValue int - dataTrie common.Trie - nonce uint64 - code []byte - codeMetadata []byte - codeHash []byte - rootHash []byte - address []byte - trackableDataTrie state.DataTrieTracker - - SetNonceWithJournalCalled func(nonce uint64) error `json:"-"` - SetCodeHashWithJournalCalled func(codeHash []byte) error `json:"-"` - SetCodeWithJournalCalled func(codeHash []byte) error `json:"-"` -} - -// SetTrackableDataTrie - -func 
(awm *AccountWrapMock) SetTrackableDataTrie(tdt state.DataTrieTracker) { - awm.trackableDataTrie = tdt -} - -// HasNewCode - -func (awm *AccountWrapMock) HasNewCode() bool { - return false -} - -// SetUserName - -func (awm *AccountWrapMock) SetUserName(_ []byte) { -} - -// GetUserName - -func (awm *AccountWrapMock) GetUserName() []byte { - return nil -} - -// AddToBalance - -func (awm *AccountWrapMock) AddToBalance(_ *big.Int) error { - return nil -} - -// SubFromBalance - -func (awm *AccountWrapMock) SubFromBalance(_ *big.Int) error { - return nil -} - -// GetBalance - -func (awm *AccountWrapMock) GetBalance() *big.Int { - return nil -} - -// ClaimDeveloperRewards - -func (awm *AccountWrapMock) ClaimDeveloperRewards([]byte) (*big.Int, error) { - return nil, nil -} - -// AddToDeveloperReward - -func (awm *AccountWrapMock) AddToDeveloperReward(*big.Int) { - -} - -// GetDeveloperReward - -func (awm *AccountWrapMock) GetDeveloperReward() *big.Int { - return nil -} - -// ChangeOwnerAddress - -func (awm *AccountWrapMock) ChangeOwnerAddress([]byte, []byte) error { - return nil -} - -// SetOwnerAddress - -func (awm *AccountWrapMock) SetOwnerAddress([]byte) { - -} - -// GetOwnerAddress - -func (awm *AccountWrapMock) GetOwnerAddress() []byte { - return nil -} - -// GetCodeHash - -func (awm *AccountWrapMock) GetCodeHash() []byte { - return awm.codeHash -} - -// SetCodeHash - -func (awm *AccountWrapMock) SetCodeHash(codeHash []byte) { - awm.codeHash = codeHash -} - -// SetCode - -func (awm *AccountWrapMock) SetCode(code []byte) { - awm.code = code -} - -// SetCodeMetadata - -func (awm *AccountWrapMock) SetCodeMetadata(codeMetadata []byte) { - awm.codeMetadata = codeMetadata -} - -// RetrieveValueFromDataTrieTracker - -func (awm *AccountWrapMock) RetrieveValueFromDataTrieTracker(key []byte) ([]byte, error) { - return awm.trackableDataTrie.RetrieveValue(key) -} - -// GetCodeMetadata - -func (awm *AccountWrapMock) GetCodeMetadata() []byte { - return awm.codeMetadata -} - -// GetRootHash - -func (awm *AccountWrapMock) GetRootHash() []byte { - return awm.rootHash -} - -// SetRootHash - -func (awm *AccountWrapMock) SetRootHash(rootHash []byte) { - awm.rootHash = rootHash -} - -// AddressBytes - -func (awm *AccountWrapMock) AddressBytes() []byte { - return awm.address -} - -// DataTrie - -func (awm *AccountWrapMock) DataTrie() common.Trie { - return awm.dataTrie -} - -// SetDataTrie - -func (awm *AccountWrapMock) SetDataTrie(trie common.Trie) { - awm.dataTrie = trie - awm.trackableDataTrie.SetDataTrie(trie) -} - -// DataTrieTracker - -func (awm *AccountWrapMock) DataTrieTracker() state.DataTrieTracker { - return awm.trackableDataTrie -} - -// IncreaseNonce - -func (awm *AccountWrapMock) IncreaseNonce(val uint64) { - awm.nonce = awm.nonce + val -} - -// GetNonce - -func (awm *AccountWrapMock) GetNonce() uint64 { - return awm.nonce -} - -// IsInterfaceNil - -func (awm *AccountWrapMock) IsInterfaceNil() bool { - return awm == nil -} diff --git a/node/mock/factory/cryptoComponentsStub.go b/node/mock/factory/cryptoComponentsStub.go index f7b1c93821d..368a93d67b1 100644 --- a/node/mock/factory/cryptoComponentsStub.go +++ b/node/mock/factory/cryptoComponentsStub.go @@ -1,9 +1,11 @@ package factory import ( + "errors" "sync" "github.com/ElrondNetwork/elrond-go-crypto" + cryptoCommon "github.com/ElrondNetwork/elrond-go/common/crypto" "github.com/ElrondNetwork/elrond-go/consensus" "github.com/ElrondNetwork/elrond-go/heartbeat" "github.com/ElrondNetwork/elrond-go/vm" @@ -18,7 +20,7 @@ type CryptoComponentsMock struct 
{ PubKeyBytes []byte BlockSig crypto.SingleSigner TxSig crypto.SingleSigner - MultiSig crypto.MultiSigner + MultiSigContainer cryptoCommon.MultiSignerContainer PeerSignHandler crypto.PeerSignatureHandler BlKeyGen crypto.KeyGenerator TxKeyGen crypto.KeyGenerator @@ -78,29 +80,41 @@ func (ccm *CryptoComponentsMock) TxSingleSigner() crypto.SingleSigner { return ccm.TxSig } -// MultiSigner - -func (ccm *CryptoComponentsMock) MultiSigner() crypto.MultiSigner { +// MultiSignerContainer - +func (ccm *CryptoComponentsMock) MultiSignerContainer() cryptoCommon.MultiSignerContainer { ccm.mutMultiSig.RLock() defer ccm.mutMultiSig.RUnlock() - return ccm.MultiSig + return ccm.MultiSigContainer } -// PeerSignatureHandler - -func (ccm *CryptoComponentsMock) PeerSignatureHandler() crypto.PeerSignatureHandler { +// SetMultiSignerContainer - +func (ccm *CryptoComponentsMock) SetMultiSignerContainer(ms cryptoCommon.MultiSignerContainer) error { + ccm.mutMultiSig.Lock() + ccm.MultiSigContainer = ms + ccm.mutMultiSig.Unlock() + + return nil +} + +// GetMultiSigner - +func (ccm *CryptoComponentsMock) GetMultiSigner(epoch uint32) (crypto.MultiSigner, error) { ccm.mutMultiSig.RLock() defer ccm.mutMultiSig.RUnlock() - return ccm.PeerSignHandler + if ccm.MultiSigContainer == nil { + return nil, errors.New("nil multi sig container") + } + + return ccm.MultiSigContainer.GetMultiSigner(epoch) } -// SetMultiSigner - -func (ccm *CryptoComponentsMock) SetMultiSigner(ms crypto.MultiSigner) error { - ccm.mutMultiSig.Lock() - ccm.MultiSig = ms - ccm.mutMultiSig.Unlock() +// PeerSignatureHandler - +func (ccm *CryptoComponentsMock) PeerSignatureHandler() crypto.PeerSignatureHandler { + ccm.mutMultiSig.RLock() + defer ccm.mutMultiSig.RUnlock() - return nil + return ccm.PeerSignHandler } // BlockSignKeyGen - @@ -138,7 +152,7 @@ func (ccm *CryptoComponentsMock) Clone() interface{} { PubKeyBytes: ccm.PubKeyBytes, BlockSig: ccm.BlockSig, TxSig: ccm.TxSig, - MultiSig: ccm.MultiSig, + MultiSigContainer: ccm.MultiSigContainer, PeerSignHandler: ccm.PeerSignHandler, BlKeyGen: ccm.BlKeyGen, TxKeyGen: ccm.TxKeyGen, diff --git a/node/mock/trieHolderStub.go b/node/mock/trieHolderStub.go deleted file mode 100644 index 085a027cea3..00000000000 --- a/node/mock/trieHolderStub.go +++ /dev/null @@ -1,56 +0,0 @@ -package mock - -import ( - "github.com/ElrondNetwork/elrond-go/common" -) - -// TriesHolderStub - -type TriesHolderStub struct { - PutCalled func([]byte, common.Trie) - RemoveCalled func([]byte, common.Trie) - GetCalled func([]byte) common.Trie - GetAllCalled func() []common.Trie - ResetCalled func() -} - -// Put - -func (ths *TriesHolderStub) Put(key []byte, trie common.Trie) { - if ths.PutCalled != nil { - ths.PutCalled(key, trie) - } -} - -// Replace - -func (ths *TriesHolderStub) Replace(key []byte, trie common.Trie) { - if ths.RemoveCalled != nil { - ths.RemoveCalled(key, trie) - } -} - -// Get - -func (ths *TriesHolderStub) Get(key []byte) common.Trie { - if ths.GetCalled != nil { - return ths.GetCalled(key) - } - return nil -} - -// GetAll - -func (ths *TriesHolderStub) GetAll() []common.Trie { - if ths.GetAllCalled != nil { - return ths.GetAllCalled() - } - return nil -} - -// Reset - -func (ths *TriesHolderStub) Reset() { - if ths.ResetCalled != nil { - ths.ResetCalled() - } -} - -// IsInterfaceNil returns true if there is no value under the interface -func (ths *TriesHolderStub) IsInterfaceNil() bool { - return ths == nil -} diff --git a/node/node.go b/node/node.go index 80e893b6366..9ad75e49d03 100644 --- a/node/node.go +++ 
b/node/node.go @@ -34,6 +34,7 @@ import ( procTx "github.com/ElrondNetwork/elrond-go/process/transaction" "github.com/ElrondNetwork/elrond-go/state" "github.com/ElrondNetwork/elrond-go/trie" + "github.com/ElrondNetwork/elrond-go/trie/keyBuilder" "github.com/ElrondNetwork/elrond-go/vm" "github.com/ElrondNetwork/elrond-go/vm/systemSmartContracts" vmcommon "github.com/ElrondNetwork/elrond-vm-common" @@ -216,7 +217,7 @@ func (n *Node) GetAllIssuedESDTs(tokenType string, ctx context.Context) ([]strin } chLeaves := make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity) - err = userAccount.DataTrie().GetAllLeavesOnChannel(chLeaves, ctx, rootHash) + err = userAccount.DataTrie().GetAllLeavesOnChannel(chLeaves, ctx, rootHash, keyBuilder.NewKeyBuilder()) if err != nil { return nil, err } @@ -284,7 +285,7 @@ func (n *Node) GetKeyValuePairs(address string, options api.AccountQueryOptions, } chLeaves := make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity) - err = userAccount.DataTrie().GetAllLeavesOnChannel(chLeaves, ctx, rootHash) + err = userAccount.DataTrie().GetAllLeavesOnChannel(chLeaves, ctx, rootHash, keyBuilder.NewKeyBuilder()) if err != nil { return nil, api.BlockInfo{}, err } @@ -320,7 +321,7 @@ func (n *Node) GetValueForKey(address string, key string, options api.AccountQue return "", api.BlockInfo{}, err } - valueBytes, err := userAccount.DataTrieTracker().RetrieveValue(keyBytes) + valueBytes, err := userAccount.RetrieveValue(keyBytes) if err != nil { return "", api.BlockInfo{}, fmt.Errorf("fetching value error: %w", err) } @@ -378,7 +379,7 @@ func (n *Node) getTokensIDsWithFilter( } chLeaves := make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity) - err = userAccount.DataTrie().GetAllLeavesOnChannel(chLeaves, ctx, rootHash) + err = userAccount.DataTrie().GetAllLeavesOnChannel(chLeaves, ctx, rootHash, keyBuilder.NewKeyBuilder()) if err != nil { return nil, api.BlockInfo{}, err } @@ -500,7 +501,7 @@ func (n *Node) GetAllESDTTokens(address string, options api.AccountQueryOptions, } chLeaves := make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity) - err = userAccount.DataTrie().GetAllLeavesOnChannel(chLeaves, ctx, rootHash) + err = userAccount.DataTrie().GetAllLeavesOnChannel(chLeaves, ctx, rootHash, keyBuilder.NewKeyBuilder()) if err != nil { return nil, api.BlockInfo{}, err } @@ -1171,7 +1172,7 @@ func (n *Node) getAccountRootHashAndVal(address []byte, accBytes []byte, key []b return nil, nil, fmt.Errorf("empty dataTrie rootHash") } - retrievedVal, err := userAccount.RetrieveValueFromDataTrieTracker(key) + retrievedVal, err := userAccount.RetrieveValue(key) if err != nil { return nil, nil, err } diff --git a/node/nodeRunner.go b/node/nodeRunner.go index 9671094005a..1997194432e 100644 --- a/node/nodeRunner.go +++ b/node/nodeRunner.go @@ -16,6 +16,7 @@ import ( "github.com/ElrondNetwork/elrond-go-core/core" "github.com/ElrondNetwork/elrond-go-core/core/closing" + "github.com/ElrondNetwork/elrond-go-core/core/throttler" "github.com/ElrondNetwork/elrond-go-core/data/endProcess" logger "github.com/ElrondNetwork/elrond-go-logger" "github.com/ElrondNetwork/elrond-go/api/gin" @@ -32,6 +33,17 @@ import ( "github.com/ElrondNetwork/elrond-go/facade" "github.com/ElrondNetwork/elrond-go/facade/initial" mainFactory "github.com/ElrondNetwork/elrond-go/factory" + apiComp "github.com/ElrondNetwork/elrond-go/factory/api" + bootstrapComp "github.com/ElrondNetwork/elrond-go/factory/bootstrap" + consensusComp 
"github.com/ElrondNetwork/elrond-go/factory/consensus" + coreComp "github.com/ElrondNetwork/elrond-go/factory/core" + cryptoComp "github.com/ElrondNetwork/elrond-go/factory/crypto" + dataComp "github.com/ElrondNetwork/elrond-go/factory/data" + heartbeatComp "github.com/ElrondNetwork/elrond-go/factory/heartbeat" + networkComp "github.com/ElrondNetwork/elrond-go/factory/network" + processComp "github.com/ElrondNetwork/elrond-go/factory/processing" + stateComp "github.com/ElrondNetwork/elrond-go/factory/state" + statusComp "github.com/ElrondNetwork/elrond-go/factory/status" "github.com/ElrondNetwork/elrond-go/genesis" "github.com/ElrondNetwork/elrond-go/genesis/parsing" "github.com/ElrondNetwork/elrond-go/health" @@ -41,9 +53,12 @@ import ( "github.com/ElrondNetwork/elrond-go/process" "github.com/ElrondNetwork/elrond-go/process/interceptors" "github.com/ElrondNetwork/elrond-go/sharding/nodesCoordinator" + "github.com/ElrondNetwork/elrond-go/state/syncer" + "github.com/ElrondNetwork/elrond-go/storage/cache" storageFactory "github.com/ElrondNetwork/elrond-go/storage/factory" - "github.com/ElrondNetwork/elrond-go/storage/storageUnit" - "github.com/ElrondNetwork/elrond-go/storage/timecache" + "github.com/ElrondNetwork/elrond-go/storage/storageunit" + trieFactory "github.com/ElrondNetwork/elrond-go/trie/factory" + "github.com/ElrondNetwork/elrond-go/trie/storageMarker" "github.com/ElrondNetwork/elrond-go/update/trigger" "github.com/google/gops/agent" ) @@ -182,6 +197,7 @@ func printEnableEpochs(configs *config.Configs) { log.Debug(readEpochFor("refactor contexts"), "epoch", enableEpochs.RefactorContextEnableEpoch) log.Debug(readEpochFor("disable heartbeat v1"), "epoch", enableEpochs.HeartbeatDisableEpoch) log.Debug(readEpochFor("mini block partial execution"), "epoch", enableEpochs.MiniBlockPartialExecutionEnableEpoch) + log.Debug(readEpochFor("fix async callback arguments list"), "epoch", enableEpochs.FixAsyncCallBackArgsListEnableEpoch) log.Debug(readEpochFor("set sender in eei output transfer"), "epoch", enableEpochs.SetSenderInEeiOutputTransferEnableEpoch) log.Debug(readEpochFor("refactor peers mini blocks"), "epoch", enableEpochs.RefactorPeersMiniBlocksEnableEpoch) gasSchedule := configs.EpochConfig.GasSchedule @@ -312,7 +328,7 @@ func (nr *nodeRunner) executeOneComponentCreationCycle( log.Debug("registering components in healthService") nr.registerDataComponentsInHealthService(healthService, managedDataComponents) - nodesShufflerOut, err := mainFactory.CreateNodesShuffleOut( + nodesShufflerOut, err := bootstrapComp.CreateNodesShuffleOut( managedCoreComponents.GenesisNodesSetup(), configs.GeneralConfig.EpochStartConfig, managedCoreComponents.ChanStopNodeProcess(), @@ -327,7 +343,7 @@ func (nr *nodeRunner) executeOneComponentCreationCycle( } log.Debug("creating nodes coordinator") - nodesCoordinatorInstance, err := mainFactory.CreateNodesCoordinator( + nodesCoordinatorInstance, err := bootstrapComp.CreateNodesCoordinator( nodesShufflerOut, managedCoreComponents.GenesisNodesSetup(), configs.PreferencesConfig.Preferences, @@ -391,6 +407,18 @@ func (nr *nodeRunner) executeOneComponentCreationCycle( return true, err } + err = addSyncersToAccountsDB( + configs.GeneralConfig, + managedCoreComponents, + managedDataComponents, + managedStateComponents, + managedBootstrapComponents, + managedProcessComponents, + ) + if err != nil { + return true, err + } + hardforkTrigger := managedProcessComponents.HardforkTrigger() err = hardforkTrigger.AddCloser(nodesShufflerOut) if err != nil { @@ -511,6 +539,139 
@@ func (nr *nodeRunner) executeOneComponentCreationCycle( return false, nil } +func addSyncersToAccountsDB( + config *config.Config, + coreComponents mainFactory.CoreComponentsHolder, + dataComponents mainFactory.DataComponentsHolder, + stateComponents mainFactory.StateComponentsHolder, + bootstrapComponents mainFactory.BootstrapComponentsHolder, + processComponents mainFactory.ProcessComponentsHolder, +) error { + selfId := bootstrapComponents.ShardCoordinator().SelfId() + if selfId == core.MetachainShardId { + stateSyncer, err := getValidatorAccountSyncer( + config, + coreComponents, + dataComponents, + stateComponents, + processComponents, + ) + if err != nil { + return err + } + + err = stateComponents.PeerAccounts().SetSyncer(stateSyncer) + if err != nil { + return err + } + + err = stateComponents.PeerAccounts().StartSnapshotIfNeeded() + if err != nil { + return err + } + } + + stateSyncer, err := getUserAccountSyncer( + config, + coreComponents, + dataComponents, + stateComponents, + bootstrapComponents, + processComponents, + ) + if err != nil { + return err + } + err = stateComponents.AccountsAdapter().SetSyncer(stateSyncer) + if err != nil { + return err + } + + return stateComponents.AccountsAdapter().StartSnapshotIfNeeded() +} + +func getUserAccountSyncer( + config *config.Config, + coreComponents mainFactory.CoreComponentsHolder, + dataComponents mainFactory.DataComponentsHolder, + stateComponents mainFactory.StateComponentsHolder, + bootstrapComponents mainFactory.BootstrapComponentsHolder, + processComponents mainFactory.ProcessComponentsHolder, +) (process.AccountsDBSyncer, error) { + maxTrieLevelInMemory := config.StateTriesConfig.MaxStateTrieLevelInMemory + userTrie := stateComponents.TriesContainer().Get([]byte(trieFactory.UserAccountTrie)) + storageManager := userTrie.GetStorageManager() + + thr, err := throttler.NewNumGoRoutinesThrottler(int32(config.TrieSync.NumConcurrentTrieSyncers)) + if err != nil { + return nil, err + } + + args := syncer.ArgsNewUserAccountsSyncer{ + ArgsNewBaseAccountsSyncer: getBaseAccountSyncerArgs( + config, + coreComponents, + dataComponents, + processComponents, + storageManager, + maxTrieLevelInMemory, + ), + ShardId: bootstrapComponents.ShardCoordinator().SelfId(), + Throttler: thr, + AddressPubKeyConverter: coreComponents.AddressPubKeyConverter(), + } + + return syncer.NewUserAccountsSyncer(args) +} + +func getValidatorAccountSyncer( + config *config.Config, + coreComponents mainFactory.CoreComponentsHolder, + dataComponents mainFactory.DataComponentsHolder, + stateComponents mainFactory.StateComponentsHolder, + processComponents mainFactory.ProcessComponentsHolder, +) (process.AccountsDBSyncer, error) { + maxTrieLevelInMemory := config.StateTriesConfig.MaxPeerTrieLevelInMemory + peerTrie := stateComponents.TriesContainer().Get([]byte(trieFactory.PeerAccountTrie)) + storageManager := peerTrie.GetStorageManager() + + args := syncer.ArgsNewValidatorAccountsSyncer{ + ArgsNewBaseAccountsSyncer: getBaseAccountSyncerArgs( + config, + coreComponents, + dataComponents, + processComponents, + storageManager, + maxTrieLevelInMemory, + ), + } + + return syncer.NewValidatorAccountsSyncer(args) +} + +func getBaseAccountSyncerArgs( + config *config.Config, + coreComponents mainFactory.CoreComponentsHolder, + dataComponents mainFactory.DataComponentsHolder, + processComponents mainFactory.ProcessComponentsHolder, + storageManager common.StorageManager, + maxTrieLevelInMemory uint, +) syncer.ArgsNewBaseAccountsSyncer { + return 
syncer.ArgsNewBaseAccountsSyncer{ + Hasher: coreComponents.Hasher(), + Marshalizer: coreComponents.InternalMarshalizer(), + TrieStorageManager: storageManager, + RequestHandler: processComponents.RequestHandler(), + Timeout: common.TimeoutGettingTrieNodes, + Cacher: dataComponents.Datapool().TrieNodes(), + MaxTrieLevelInMemory: maxTrieLevelInMemory, + MaxHardCapForMissingNodes: config.TrieSync.MaxHardCapForMissingNodes, + TrieSyncerVersion: config.TrieSync.TrieSyncerVersion, + StorageMarker: storageMarker.NewDisabledStorageMarker(), + CheckNodesOnDisk: true, + } +} + func (nr *nodeRunner) createApiFacade( currentNode *Node, upgradableHttpServer shared.UpgradeableHttpServerHandler, @@ -521,7 +682,7 @@ func (nr *nodeRunner) createApiFacade( log.Debug("creating api resolver structure") - apiResolverArgs := &mainFactory.ApiResolverArgs{ + apiResolverArgs := &apiComp.ApiResolverArgs{ Configs: configs, CoreComponents: currentNode.coreComponents, DataComponents: currentNode.dataComponents, @@ -534,7 +695,7 @@ func (nr *nodeRunner) createApiFacade( AllowVMQueriesChan: allowVMQueriesChan, } - apiResolver, err := mainFactory.CreateApiResolver(apiResolverArgs) + apiResolver, err := apiComp.CreateApiResolver(apiResolverArgs) if err != nil { return nil, err } @@ -669,7 +830,7 @@ func (nr *nodeRunner) CreateManagedConsensusComponents( return nil, err } - consensusArgs := mainFactory.ConsensusComponentsFactoryArgs{ + consensusArgs := consensusComp.ConsensusComponentsFactoryArgs{ Config: *nr.configs.GeneralConfig, BootstrapRoundIndex: nr.configs.FlagsConfig.BootstrapRoundIndex, CoreComponents: coreComponents, @@ -684,12 +845,12 @@ func (nr *nodeRunner) CreateManagedConsensusComponents( ShouldDisableWatchdog: nr.configs.FlagsConfig.DisableConsensusWatchdog, } - consensusFactory, err := mainFactory.NewConsensusComponentsFactory(consensusArgs) + consensusFactory, err := consensusComp.NewConsensusComponentsFactory(consensusArgs) if err != nil { return nil, fmt.Errorf("NewConsensusComponentsFactory failed: %w", err) } - managedConsensusComponents, err := mainFactory.NewManagedConsensusComponents(consensusFactory) + managedConsensusComponents, err := consensusComp.NewManagedConsensusComponents(consensusFactory) if err != nil { return nil, err } @@ -712,7 +873,7 @@ func (nr *nodeRunner) CreateManagedHeartbeatComponents( ) (mainFactory.HeartbeatComponentsHandler, error) { genesisTime := time.Unix(coreComponents.GenesisNodesSetup().GetStartTime(), 0) - heartbeatArgs := mainFactory.HeartbeatComponentsFactoryArgs{ + heartbeatArgs := heartbeatComp.HeartbeatComponentsFactoryArgs{ Config: *nr.configs.GeneralConfig, Prefs: *nr.configs.PreferencesConfig, AppVersion: nr.configs.FlagsConfig.Version, @@ -725,12 +886,12 @@ func (nr *nodeRunner) CreateManagedHeartbeatComponents( ProcessComponents: processComponents, } - heartbeatComponentsFactory, err := mainFactory.NewHeartbeatComponentsFactory(heartbeatArgs) + heartbeatComponentsFactory, err := heartbeatComp.NewHeartbeatComponentsFactory(heartbeatArgs) if err != nil { return nil, fmt.Errorf("NewHeartbeatComponentsFactory failed: %w", err) } - managedHeartbeatComponents, err := mainFactory.NewManagedHeartbeatComponents(heartbeatComponentsFactory) + managedHeartbeatComponents, err := heartbeatComp.NewManagedHeartbeatComponents(heartbeatComponentsFactory) if err != nil { return nil, err } @@ -751,7 +912,7 @@ func (nr *nodeRunner) CreateManagedHeartbeatV2Components( dataComponents mainFactory.DataComponentsHolder, processComponents mainFactory.ProcessComponentsHolder, ) 
(mainFactory.HeartbeatV2ComponentsHandler, error) { - heartbeatV2Args := mainFactory.ArgHeartbeatV2ComponentsFactory{ + heartbeatV2Args := heartbeatComp.ArgHeartbeatV2ComponentsFactory{ Config: *nr.configs.GeneralConfig, Prefs: *nr.configs.PreferencesConfig, BaseVersion: nr.configs.FlagsConfig.BaseVersion, @@ -764,12 +925,12 @@ func (nr *nodeRunner) CreateManagedHeartbeatV2Components( ProcessComponents: processComponents, } - heartbeatV2ComponentsFactory, err := mainFactory.NewHeartbeatV2ComponentsFactory(heartbeatV2Args) + heartbeatV2ComponentsFactory, err := heartbeatComp.NewHeartbeatV2ComponentsFactory(heartbeatV2Args) if err != nil { return nil, fmt.Errorf("NewHeartbeatV2ComponentsFactory failed: %w", err) } - managedHeartbeatV2Components, err := mainFactory.NewManagedHeartbeatV2Components(heartbeatV2ComponentsFactory) + managedHeartbeatV2Components, err := heartbeatComp.NewManagedHeartbeatV2Components(heartbeatV2ComponentsFactory) if err != nil { return nil, err } @@ -896,7 +1057,7 @@ func (nr *nodeRunner) CreateManagedStatusComponents( nodesCoordinator nodesCoordinator.NodesCoordinator, isInImportMode bool, ) (mainFactory.StatusComponentsHandler, error) { - statArgs := mainFactory.StatusComponentsFactoryArgs{ + statArgs := statusComp.StatusComponentsFactoryArgs{ Config: *nr.configs.GeneralConfig, ExternalConfig: *nr.configs.ExternalConfig, EconomicsConfig: *nr.configs.EconomicsConfig, @@ -910,12 +1071,12 @@ func (nr *nodeRunner) CreateManagedStatusComponents( IsInImportMode: isInImportMode, } - statusComponentsFactory, err := mainFactory.NewStatusComponentsFactory(statArgs) + statusComponentsFactory, err := statusComp.NewStatusComponentsFactory(statArgs) if err != nil { return nil, fmt.Errorf("NewStatusComponentsFactory failed: %w", err) } - managedStatusComponents, err := mainFactory.NewManagedStatusComponents(statusComponentsFactory) + managedStatusComponents, err := statusComp.NewManagedStatusComponents(statusComponentsFactory) if err != nil { return nil, err } @@ -1023,7 +1184,7 @@ func (nr *nodeRunner) CreateManagedProcessComponents( return nil, err } - whiteListCache, err := storageUnit.NewCache(storageFactory.GetCacherFromConfig(configs.GeneralConfig.WhiteListPool)) + whiteListCache, err := storageunit.NewCache(storageFactory.GetCacherFromConfig(configs.GeneralConfig.WhiteListPool)) if err != nil { return nil, err } @@ -1043,10 +1204,10 @@ func (nr *nodeRunner) CreateManagedProcessComponents( } log.Trace("creating time cache for requested items components") - requestedItemsHandler := timecache.NewTimeCache( + requestedItemsHandler := cache.NewTimeCache( time.Duration(uint64(time.Millisecond) * coreComponents.GenesisNodesSetup().GetRoundDuration())) - processArgs := mainFactory.ProcessComponentsFactoryArgs{ + processArgs := processComp.ProcessComponentsFactoryArgs{ Config: *configs.GeneralConfig, EpochConfig: *configs.EpochConfig, PrefConfigs: configs.PreferencesConfig.Preferences, @@ -1072,12 +1233,12 @@ func (nr *nodeRunner) CreateManagedProcessComponents( WorkingDir: configs.FlagsConfig.WorkingDir, HistoryRepo: historyRepository, } - processComponentsFactory, err := mainFactory.NewProcessComponentsFactory(processArgs) + processComponentsFactory, err := processComp.NewProcessComponentsFactory(processArgs) if err != nil { return nil, fmt.Errorf("NewProcessComponentsFactory failed: %w", err) } - managedProcessComponents, err := mainFactory.NewManagedProcessComponents(processComponentsFactory) + managedProcessComponents, err := 
processComp.NewManagedProcessComponents(processComponentsFactory) if err != nil { return nil, err } @@ -1103,7 +1264,7 @@ func (nr *nodeRunner) CreateManagedDataComponents( storerEpoch = 0 } - dataArgs := mainFactory.DataComponentsFactoryArgs{ + dataArgs := dataComp.DataComponentsFactoryArgs{ Config: *configs.GeneralConfig, PrefsConfig: configs.PreferencesConfig.Preferences, ShardCoordinator: bootstrapComponents.ShardCoordinator(), @@ -1113,11 +1274,11 @@ func (nr *nodeRunner) CreateManagedDataComponents( CreateTrieEpochRootHashStorer: configs.ImportDbConfig.ImportDbSaveTrieEpochRootHash, } - dataComponentsFactory, err := mainFactory.NewDataComponentsFactory(dataArgs) + dataComponentsFactory, err := dataComp.NewDataComponentsFactory(dataArgs) if err != nil { return nil, fmt.Errorf("NewDataComponentsFactory failed: %w", err) } - managedDataComponents, err := mainFactory.NewManagedDataComponents(dataComponentsFactory) + managedDataComponents, err := dataComp.NewManagedDataComponents(dataComponentsFactory) if err != nil { return nil, err } @@ -1150,7 +1311,7 @@ func (nr *nodeRunner) CreateManagedStateComponents( if nr.configs.ImportDbConfig.IsImportDBMode { processingMode = common.ImportDb } - stateArgs := mainFactory.StateComponentsFactoryArgs{ + stateArgs := stateComp.StateComponentsFactoryArgs{ Config: *nr.configs.GeneralConfig, ShardCoordinator: bootstrapComponents.ShardCoordinator(), Core: coreComponents, @@ -1160,12 +1321,12 @@ func (nr *nodeRunner) CreateManagedStateComponents( ChainHandler: dataComponents.Blockchain(), } - stateComponentsFactory, err := mainFactory.NewStateComponentsFactory(stateArgs) + stateComponentsFactory, err := stateComp.NewStateComponentsFactory(stateArgs) if err != nil { return nil, fmt.Errorf("NewStateComponentsFactory failed: %w", err) } - managedStateComponents, err := mainFactory.NewManagedStateComponents(stateComponentsFactory) + managedStateComponents, err := stateComp.NewManagedStateComponents(stateComponentsFactory) if err != nil { return nil, err } @@ -1184,7 +1345,7 @@ func (nr *nodeRunner) CreateManagedBootstrapComponents( networkComponents mainFactory.NetworkComponentsHolder, ) (mainFactory.BootstrapComponentsHandler, error) { - bootstrapComponentsFactoryArgs := mainFactory.BootstrapComponentsFactoryArgs{ + bootstrapComponentsFactoryArgs := bootstrapComp.BootstrapComponentsFactoryArgs{ Config: *nr.configs.GeneralConfig, PrefConfig: *nr.configs.PreferencesConfig, ImportDbConfig: *nr.configs.ImportDbConfig, @@ -1195,12 +1356,12 @@ func (nr *nodeRunner) CreateManagedBootstrapComponents( NetworkComponents: networkComponents, } - bootstrapComponentsFactory, err := mainFactory.NewBootstrapComponentsFactory(bootstrapComponentsFactoryArgs) + bootstrapComponentsFactory, err := bootstrapComp.NewBootstrapComponentsFactory(bootstrapComponentsFactoryArgs) if err != nil { return nil, fmt.Errorf("NewBootstrapComponentsFactory failed: %w", err) } - managedBootstrapComponents, err := mainFactory.NewManagedBootstrapComponents(bootstrapComponentsFactory) + managedBootstrapComponents, err := bootstrapComp.NewManagedBootstrapComponents(bootstrapComponentsFactory) if err != nil { return nil, err } @@ -1222,7 +1383,7 @@ func (nr *nodeRunner) CreateManagedNetworkComponents( return nil, err } - networkComponentsFactoryArgs := mainFactory.NetworkComponentsFactoryArgs{ + networkComponentsFactoryArgs := networkComp.NetworkComponentsFactoryArgs{ P2pConfig: *nr.configs.P2pConfig, MainConfig: *nr.configs.GeneralConfig, RatingsConfig: *nr.configs.RatingsConfig, @@ -1233,6 +1394,7 
@@ func (nr *nodeRunner) CreateManagedNetworkComponents( BootstrapWaitTime: common.TimeToWaitForP2PBootstrap, NodeOperationMode: p2p.NormalOperation, ConnectionWatcherType: nr.configs.PreferencesConfig.Preferences.ConnectionWatcherType, + P2pKeyPemFileName: nr.configs.ConfigurationPathsHolder.P2pKey, } if nr.configs.ImportDbConfig.IsImportDBMode { networkComponentsFactoryArgs.BootstrapWaitTime = 0 @@ -1241,12 +1403,12 @@ func (nr *nodeRunner) CreateManagedNetworkComponents( networkComponentsFactoryArgs.NodeOperationMode = p2p.FullArchiveMode } - networkComponentsFactory, err := mainFactory.NewNetworkComponentsFactory(networkComponentsFactoryArgs) + networkComponentsFactory, err := networkComp.NewNetworkComponentsFactory(networkComponentsFactoryArgs) if err != nil { return nil, fmt.Errorf("NewNetworkComponentsFactory failed: %w", err) } - managedNetworkComponents, err := mainFactory.NewManagedNetworkComponents(networkComponentsFactory) + managedNetworkComponents, err := networkComp.NewManagedNetworkComponents(networkComponentsFactory) if err != nil { return nil, err } @@ -1266,7 +1428,7 @@ func (nr *nodeRunner) CreateManagedCoreComponents( return nil, err } - coreArgs := mainFactory.CoreComponentsFactoryArgs{ + coreArgs := coreComp.CoreComponentsFactoryArgs{ Config: *nr.configs.GeneralConfig, ConfigPathsHolder: *nr.configs.ConfigurationPathsHolder, EpochConfig: *nr.configs.EpochConfig, @@ -1280,12 +1442,12 @@ func (nr *nodeRunner) CreateManagedCoreComponents( StatusHandlersFactory: statusHandlersFactory, } - coreComponentsFactory, err := mainFactory.NewCoreComponentsFactory(coreArgs) + coreComponentsFactory, err := coreComp.NewCoreComponentsFactory(coreArgs) if err != nil { return nil, fmt.Errorf("NewCoreComponentsFactory failed: %w", err) } - managedCoreComponents, err := mainFactory.NewManagedCoreComponents(coreComponentsFactory) + managedCoreComponents, err := coreComp.NewManagedCoreComponents(coreComponentsFactory) if err != nil { return nil, err } @@ -1304,7 +1466,7 @@ func (nr *nodeRunner) CreateManagedCryptoComponents( ) (mainFactory.CryptoComponentsHandler, error) { configs := nr.configs validatorKeyPemFileName := configs.ConfigurationPathsHolder.ValidatorKey - cryptoComponentsHandlerArgs := mainFactory.CryptoComponentsFactoryArgs{ + cryptoComponentsHandlerArgs := cryptoComp.CryptoComponentsFactoryArgs{ ValidatorKeyPemFileName: validatorKeyPemFileName, SkIndex: configs.FlagsConfig.ValidatorKeyIndex, Config: *configs.GeneralConfig, @@ -1313,15 +1475,16 @@ func (nr *nodeRunner) CreateManagedCryptoComponents( KeyLoader: &core.KeyLoader{}, ImportModeNoSigCheck: configs.ImportDbConfig.ImportDbNoSigCheckFlag, IsInImportMode: configs.ImportDbConfig.IsImportDBMode, + EnableEpochs: configs.EpochConfig.EnableEpochs, NoKeyProvided: configs.FlagsConfig.NoKeyProvided, } - cryptoComponentsFactory, err := mainFactory.NewCryptoComponentsFactory(cryptoComponentsHandlerArgs) + cryptoComponentsFactory, err := cryptoComp.NewCryptoComponentsFactory(cryptoComponentsHandlerArgs) if err != nil { return nil, fmt.Errorf("NewCryptoComponentsFactory failed: %w", err) } - managedCryptoComponents, err := mainFactory.NewManagedCryptoComponents(cryptoComponentsFactory) + managedCryptoComponents, err := cryptoComp.NewManagedCryptoComponents(cryptoComponentsFactory) if err != nil { return nil, err } @@ -1514,7 +1677,7 @@ func decodePreferredPeers(prefConfig config.Preferences, validatorPubKeyConverte } func createWhiteListerVerifiedTxs(generalConfig *config.Config) (process.WhiteListHandler, error) { - 
whiteListCacheVerified, err := storageUnit.NewCache(storageFactory.GetCacherFromConfig(generalConfig.WhiteListerVerifiedTxs)) + whiteListCacheVerified, err := storageunit.NewCache(storageFactory.GetCacherFromConfig(generalConfig.WhiteListerVerifiedTxs)) if err != nil { return nil, err } diff --git a/node/nodeTesting_test.go b/node/nodeTesting_test.go index 2ae49205cf6..8be5b772ce9 100644 --- a/node/nodeTesting_test.go +++ b/node/nodeTesting_test.go @@ -23,6 +23,7 @@ import ( dataRetrieverMock "github.com/ElrondNetwork/elrond-go/testscommon/dataRetriever" "github.com/ElrondNetwork/elrond-go/testscommon/p2pmocks" stateMock "github.com/ElrondNetwork/elrond-go/testscommon/state" + trieMock "github.com/ElrondNetwork/elrond-go/testscommon/trie" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -395,7 +396,7 @@ func getDefaultCryptoComponents() *factoryMock.CryptoComponentsMock { PubKeyBytes: []byte("pubKey"), BlockSig: &mock.SingleSignerMock{}, TxSig: &mock.SingleSignerMock{}, - MultiSig: cryptoMocks.NewMultiSigner(1), + MultiSigContainer: cryptoMocks.NewMultiSignerContainerMock( cryptoMocks.NewMultiSigner()), PeerSignHandler: &mock.PeerSignatureHandler{}, BlKeyGen: &mock.KeyGenMock{}, TxKeyGen: &mock.KeyGenMock{}, @@ -411,7 +412,7 @@ func getDefaultStateComponents() *testscommon.StateComponentsMock { Accounts: &stateMock.AccountsStub{}, AccountsAPI: &stateMock.AccountsStub{}, AccountsRepo: &stateMock.AccountsRepositoryStub{}, - Tries: &mock.TriesHolderStub{}, + Tries: &trieMock.TriesHolderStub{}, StorageManagers: map[string]common.StorageManager{"0": &testscommon.StorageManagerStub{}}, } } diff --git a/node/node_test.go b/node/node_test.go index f104bec548d..1b3fa4f5a48 100644 --- a/node/node_test.go +++ b/node/node_test.go @@ -282,9 +282,9 @@ func TestNode_GetKeyValuePairs(t *testing.T) { k2, v2 := []byte("key2"), []byte("value2") accDB := &stateMock.AccountsStub{} - acc.DataTrieTracker().SetDataTrie( + acc.SetDataTrie( &trieMock.TrieStub{ - GetAllLeavesOnChannelCalled: func(ch chan core.KeyValueHolder, ctx context.Context, rootHash []byte) error { + GetAllLeavesOnChannelCalled: func(ch chan core.KeyValueHolder, ctx context.Context, rootHash []byte, _ common.KeyBuilder) error { go func() { suffix := append(k1, acc.AddressBytes()...) 
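Every TrieStub in the tests below gains a trailing common.KeyBuilder parameter because Trie.GetAllLeavesOnChannel changed its signature, matching the node.go call sites earlier in this patch that now pass keyBuilder.NewKeyBuilder(). A minimal sketch of the consumer side under the new signature (collectLeaves is a hypothetical helper, not part of this patch):

package example

import (
	"context"

	"github.com/ElrondNetwork/elrond-go-core/core"
	"github.com/ElrondNetwork/elrond-go/common"
	"github.com/ElrondNetwork/elrond-go/trie/keyBuilder"
)

// collectLeaves drains all leaves of a data trie into a map; the trie
// implementation closes the channel when iteration ends, terminating the range.
func collectLeaves(ctx context.Context, dataTrie common.Trie, rootHash []byte) (map[string][]byte, error) {
	chLeaves := make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity)
	err := dataTrie.GetAllLeavesOnChannel(chLeaves, ctx, rootHash, keyBuilder.NewKeyBuilder())
	if err != nil {
		return nil, err
	}

	leaves := make(map[string][]byte)
	for leaf := range chLeaves {
		leaves[string(leaf.Key())] = leaf.Value()
	}
	return leaves, nil
}
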
trieLeaf := keyValStorage.NewKeyValStorage(k1, append(v1, suffix...)) @@ -344,9 +344,9 @@ func TestNode_GetKeyValuePairsContextShouldTimeout(t *testing.T) { acc, _ := state.NewUserAccount([]byte("newaddress")) accDB := &stateMock.AccountsStub{} - acc.DataTrieTracker().SetDataTrie( + acc.SetDataTrie( &trieMock.TrieStub{ - GetAllLeavesOnChannelCalled: func(ch chan core.KeyValueHolder, ctx context.Context, rootHash []byte) error { + GetAllLeavesOnChannelCalled: func(ch chan core.KeyValueHolder, ctx context.Context, rootHash []byte, _ common.KeyBuilder) error { go func() { time.Sleep(time.Second) close(ch) @@ -398,7 +398,7 @@ func TestNode_GetValueForKey(t *testing.T) { acc, _ := state.NewUserAccount([]byte("newaddress")) k1, v1 := []byte("key1"), []byte("value1") - _ = acc.DataTrieTracker().SaveKeyValue(k1, v1) + _ = acc.SaveKeyValue(k1, v1) accDB := &stateMock.AccountsStub{ GetAccountWithBlockInfoCalled: func(address []byte, options common.RootHashHolder) (vmcommon.AccountHandler, common.BlockInfo, error) { @@ -537,9 +537,9 @@ func TestNode_GetAllESDTTokens(t *testing.T) { }, } - acc.DataTrieTracker().SetDataTrie( + acc.SetDataTrie( &trieMock.TrieStub{ - GetAllLeavesOnChannelCalled: func(ch chan core.KeyValueHolder, ctx context.Context, rootHash []byte) error { + GetAllLeavesOnChannelCalled: func(ch chan core.KeyValueHolder, ctx context.Context, rootHash []byte, _ common.KeyBuilder) error { go func() { trieLeaf := keyValStorage.NewKeyValStorage(esdtKey, nil) ch <- trieLeaf @@ -592,9 +592,9 @@ func TestNode_GetAllESDTTokens(t *testing.T) { func TestNode_GetAllESDTTokensContextShouldTimeout(t *testing.T) { acc, _ := state.NewUserAccount(testscommon.TestPubKeyAlice) - acc.DataTrieTracker().SetDataTrie( + acc.SetDataTrie( &trieMock.TrieStub{ - GetAllLeavesOnChannelCalled: func(ch chan core.KeyValueHolder, ctx context.Context, rootHash []byte) error { + GetAllLeavesOnChannelCalled: func(ch chan core.KeyValueHolder, ctx context.Context, rootHash []byte, _ common.KeyBuilder) error { go func() { time.Sleep(time.Second) close(ch) @@ -676,9 +676,9 @@ func TestNode_GetAllESDTTokensShouldReturnEsdtAndFormattedNft(t *testing.T) { } }, } - acc.DataTrieTracker().SetDataTrie( + acc.SetDataTrie( &trieMock.TrieStub{ - GetAllLeavesOnChannelCalled: func(ch chan core.KeyValueHolder, ctx context.Context, rootHash []byte) error { + GetAllLeavesOnChannelCalled: func(ch chan core.KeyValueHolder, ctx context.Context, rootHash []byte, _ common.KeyBuilder) error { wg := &sync.WaitGroup{} wg.Add(1) go func() { @@ -745,23 +745,23 @@ func TestNode_GetAllIssuedESDTs(t *testing.T) { esdtData := &systemSmartContracts.ESDTDataV2{TokenName: []byte("fungible"), TokenType: []byte(core.FungibleESDT)} marshalledData, _ := getMarshalizer().Marshal(esdtData) - _ = acc.DataTrieTracker().SaveKeyValue(esdtToken, marshalledData) + _ = acc.SaveKeyValue(esdtToken, marshalledData) sftData := &systemSmartContracts.ESDTDataV2{TokenName: []byte("semi fungible"), TokenType: []byte(core.SemiFungibleESDT)} sftMarshalledData, _ := getMarshalizer().Marshal(sftData) - _ = acc.DataTrieTracker().SaveKeyValue(sftToken, sftMarshalledData) + _ = acc.SaveKeyValue(sftToken, sftMarshalledData) nftData := &systemSmartContracts.ESDTDataV2{TokenName: []byte("non fungible"), TokenType: []byte(core.NonFungibleESDT)} nftMarshalledData, _ := getMarshalizer().Marshal(nftData) - _ = acc.DataTrieTracker().SaveKeyValue(nftToken, nftMarshalledData) + _ = acc.SaveKeyValue(nftToken, nftMarshalledData) esdtSuffix := append(esdtToken, acc.AddressBytes()...) 
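The stubbed tries in these tests all follow one producer pattern: ignore the new KeyBuilder argument, stream the canned leaves from a goroutine, and close the channel so the consumer's range loop terminates. A generic form of that pattern (leavesProducerStub is a hypothetical test helper, not part of this patch):

package example

import (
	"context"

	"github.com/ElrondNetwork/elrond-go-core/core"
	"github.com/ElrondNetwork/elrond-go-core/core/keyValStorage"
	"github.com/ElrondNetwork/elrond-go/common"
)

// leavesProducerStub builds a GetAllLeavesOnChannelCalled body that pushes the
// given leaves asynchronously and closes the channel when done.
func leavesProducerStub(leaves map[string][]byte) func(chan core.KeyValueHolder, context.Context, []byte, common.KeyBuilder) error {
	return func(ch chan core.KeyValueHolder, _ context.Context, _ []byte, _ common.KeyBuilder) error {
		go func() {
			for k, v := range leaves {
				ch <- keyValStorage.NewKeyValStorage([]byte(k), v)
			}
			close(ch)
		}()
		return nil
	}
}
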
nftSuffix := append(nftToken, acc.AddressBytes()...) sftSuffix := append(sftToken, acc.AddressBytes()...) - acc.DataTrieTracker().SetDataTrie( + acc.SetDataTrie( &trieMock.TrieStub{ - GetAllLeavesOnChannelCalled: func(ch chan core.KeyValueHolder, ctx context.Context, rootHash []byte) error { + GetAllLeavesOnChannelCalled: func(ch chan core.KeyValueHolder, ctx context.Context, rootHash []byte, _ common.KeyBuilder) error { go func() { trieLeaf := keyValStorage.NewKeyValStorage(esdtToken, append(marshalledData, esdtSuffix...)) ch <- trieLeaf @@ -844,13 +844,13 @@ func TestNode_GetESDTsWithRole(t *testing.T) { esdtData := &systemSmartContracts.ESDTDataV2{TokenName: []byte("fungible"), TokenType: []byte(core.FungibleESDT), SpecialRoles: specialRoles} marshalledData, _ := getMarshalizer().Marshal(esdtData) - _ = acc.DataTrieTracker().SaveKeyValue(esdtToken, marshalledData) + _ = acc.SaveKeyValue(esdtToken, marshalledData) esdtSuffix := append(esdtToken, acc.AddressBytes()...) - acc.DataTrieTracker().SetDataTrie( + acc.SetDataTrie( &trieMock.TrieStub{ - GetAllLeavesOnChannelCalled: func(ch chan core.KeyValueHolder, ctx context.Context, rootHash []byte) error { + GetAllLeavesOnChannelCalled: func(ch chan core.KeyValueHolder, ctx context.Context, rootHash []byte, _ common.KeyBuilder) error { go func() { trieLeaf := keyValStorage.NewKeyValStorage(esdtToken, append(marshalledData, esdtSuffix...)) ch <- trieLeaf @@ -921,13 +921,13 @@ func TestNode_GetESDTsRoles(t *testing.T) { esdtData := &systemSmartContracts.ESDTDataV2{TokenName: []byte("fungible"), TokenType: []byte(core.FungibleESDT), SpecialRoles: specialRoles} marshalledData, _ := getMarshalizer().Marshal(esdtData) - _ = acc.DataTrieTracker().SaveKeyValue(esdtToken, marshalledData) + _ = acc.SaveKeyValue(esdtToken, marshalledData) esdtSuffix := append(esdtToken, acc.AddressBytes()...) - acc.DataTrieTracker().SetDataTrie( + acc.SetDataTrie( &trieMock.TrieStub{ - GetAllLeavesOnChannelCalled: func(ch chan core.KeyValueHolder, ctx context.Context, rootHash []byte) error { + GetAllLeavesOnChannelCalled: func(ch chan core.KeyValueHolder, ctx context.Context, rootHash []byte, _ common.KeyBuilder) error { go func() { trieLeaf := keyValStorage.NewKeyValStorage(esdtToken, append(marshalledData, esdtSuffix...)) ch <- trieLeaf @@ -983,13 +983,13 @@ func TestNode_GetNFTTokenIDsRegisteredByAddress(t *testing.T) { esdtData := &systemSmartContracts.ESDTDataV2{TokenName: []byte("fungible"), TokenType: []byte(core.SemiFungibleESDT), OwnerAddress: addrBytes} marshalledData, _ := getMarshalizer().Marshal(esdtData) - _ = acc.DataTrieTracker().SaveKeyValue(esdtToken, marshalledData) + _ = acc.SaveKeyValue(esdtToken, marshalledData) esdtSuffix := append(esdtToken, acc.AddressBytes()...) 
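The storage renames threaded through this patch (storageUnit → storageunit, lrucache → cache, leveldb/memorydb → database) change import paths and constructor names but not call shapes. A minimal sketch of the renamed API, modeled on the CreateMemUnit helper earlier in this patch (newMemStorer is hypothetical; error handling is added where the test helpers discard it):

package example

import (
	"github.com/ElrondNetwork/elrond-go/storage"
	"github.com/ElrondNetwork/elrond-go/storage/database"
	"github.com/ElrondNetwork/elrond-go/storage/storageunit"
)

// newMemStorer wires an LRU cache over an in-memory database, using the
// renamed storageunit and database packages.
func newMemStorer(capacity uint32) (storage.Storer, error) {
	cacher, err := storageunit.NewCache(storageunit.CacheConfig{
		Type:     storageunit.LRUCache,
		Capacity: capacity,
		Shards:   1,
	})
	if err != nil {
		return nil, err
	}

	unit, err := storageunit.NewStorageUnit(cacher, database.NewMemDB())
	if err != nil {
		return nil, err
	}
	return unit, nil
}
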
- acc.DataTrieTracker().SetDataTrie( + acc.SetDataTrie( &trieMock.TrieStub{ - GetAllLeavesOnChannelCalled: func(ch chan core.KeyValueHolder, ctx context.Context, rootHash []byte) error { + GetAllLeavesOnChannelCalled: func(ch chan core.KeyValueHolder, ctx context.Context, rootHash []byte, _ common.KeyBuilder) error { go func() { trieLeaf := keyValStorage.NewKeyValStorage(esdtToken, append(marshalledData, esdtSuffix...)) ch <- trieLeaf @@ -1042,9 +1042,9 @@ func TestNode_GetNFTTokenIDsRegisteredByAddressContextShouldTimeout(t *testing.T addrBytes := testscommon.TestPubKeyAlice acc, _ := state.NewUserAccount(addrBytes) - acc.DataTrieTracker().SetDataTrie( + acc.SetDataTrie( &trieMock.TrieStub{ - GetAllLeavesOnChannelCalled: func(ch chan core.KeyValueHolder, ctx context.Context, rootHash []byte) error { + GetAllLeavesOnChannelCalled: func(ch chan core.KeyValueHolder, ctx context.Context, rootHash []byte, _ common.KeyBuilder) error { go func() { time.Sleep(time.Second) close(ch) @@ -3506,7 +3506,7 @@ func TestNode_GetProofDataTrieShouldWork(t *testing.T) { }, nil }, GetAccountFromBytesCalled: func(address []byte, accountBytes []byte) (vmcommon.AccountHandler, error) { - acc := &mock.AccountWrapMock{} + acc := &stateMock.AccountWrapMock{} acc.SetTrackableDataTrie(&trieMock.DataTrieTrackerStub{ RetrieveValueCalled: func(key []byte) ([]byte, error) { assert.Equal(t, dataTrieKey, hex.EncodeToString(key)) @@ -3990,7 +3990,7 @@ func getDefaultDataComponents() *nodeMockFactory.DataComponentsMock { func getDefaultBootstrapComponents() *mainFactoryMocks.BootstrapComponentsStub { return &mainFactoryMocks.BootstrapComponentsStub{ Bootstrapper: &bootstrapMocks.EpochStartBootstrapperStub{ - TrieHolder: &mock.TriesHolderStub{}, + TrieHolder: &trieMock.TriesHolderStub{}, StorageManagers: map[string]common.StorageManager{"0": &testscommon.StorageManagerStub{}}, BootstrapCalled: nil, }, diff --git a/node/trieIterators/delegatedListProcessor.go b/node/trieIterators/delegatedListProcessor.go index 34506d01ccf..1cb0d643abc 100644 --- a/node/trieIterators/delegatedListProcessor.go +++ b/node/trieIterators/delegatedListProcessor.go @@ -13,6 +13,7 @@ import ( "github.com/ElrondNetwork/elrond-go/common" "github.com/ElrondNetwork/elrond-go/epochStart" "github.com/ElrondNetwork/elrond-go/process" + "github.com/ElrondNetwork/elrond-go/trie/keyBuilder" "github.com/ElrondNetwork/elrond-go/vm" vmcommon "github.com/ElrondNetwork/elrond-vm-common" ) @@ -127,7 +128,7 @@ func (dlp *delegatedListProcessor) getDelegatorsList(delegationSC []byte, ctx co } chLeaves := make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity) - err = delegatorAccount.DataTrie().GetAllLeavesOnChannel(chLeaves, ctx, rootHash) + err = delegatorAccount.DataTrie().GetAllLeavesOnChannel(chLeaves, ctx, rootHash, keyBuilder.NewKeyBuilder()) if err != nil { return nil, err } diff --git a/node/trieIterators/delegatedListProcessor_test.go b/node/trieIterators/delegatedListProcessor_test.go index fd29704092c..c8db6d5628b 100644 --- a/node/trieIterators/delegatedListProcessor_test.go +++ b/node/trieIterators/delegatedListProcessor_test.go @@ -13,6 +13,7 @@ import ( "github.com/ElrondNetwork/elrond-go-core/core/check" "github.com/ElrondNetwork/elrond-go-core/core/keyValStorage" "github.com/ElrondNetwork/elrond-go-core/data/api" + "github.com/ElrondNetwork/elrond-go/common" "github.com/ElrondNetwork/elrond-go/epochStart" "github.com/ElrondNetwork/elrond-go/node/mock" "github.com/ElrondNetwork/elrond-go/process" @@ -223,7 +224,7 @@ func 
createDelegationScAccount(address []byte, leaves [][]byte, rootHash []byte, RootCalled: func() ([]byte, error) { return rootHash, nil }, - GetAllLeavesOnChannelCalled: func(ch chan core.KeyValueHolder, ctx context.Context, rootHash []byte) error { + GetAllLeavesOnChannelCalled: func(ch chan core.KeyValueHolder, ctx context.Context, rootHash []byte, _ common.KeyBuilder) error { go func() { time.Sleep(timeSleep) for _, leafBuff := range leaves { diff --git a/node/trieIterators/directStakedListProcessor.go b/node/trieIterators/directStakedListProcessor.go index 087039b5e78..1219a2d94ed 100644 --- a/node/trieIterators/directStakedListProcessor.go +++ b/node/trieIterators/directStakedListProcessor.go @@ -8,6 +8,7 @@ import ( "github.com/ElrondNetwork/elrond-go-core/data/api" "github.com/ElrondNetwork/elrond-go/common" "github.com/ElrondNetwork/elrond-go/state" + "github.com/ElrondNetwork/elrond-go/trie/keyBuilder" "github.com/ElrondNetwork/elrond-go/vm" ) @@ -54,7 +55,7 @@ func (dslp *directStakedListProcessor) getAllStakedAccounts(validatorAccount sta } chLeaves := make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity) - err = validatorAccount.DataTrie().GetAllLeavesOnChannel(chLeaves, ctx, rootHash) + err = validatorAccount.DataTrie().GetAllLeavesOnChannel(chLeaves, ctx, rootHash, keyBuilder.NewKeyBuilder()) if err != nil { return nil, err } diff --git a/node/trieIterators/directStakedListProcessor_test.go b/node/trieIterators/directStakedListProcessor_test.go index 330b5bbe478..c1efb004a80 100644 --- a/node/trieIterators/directStakedListProcessor_test.go +++ b/node/trieIterators/directStakedListProcessor_test.go @@ -13,6 +13,7 @@ import ( "github.com/ElrondNetwork/elrond-go-core/core/check" "github.com/ElrondNetwork/elrond-go-core/core/keyValStorage" "github.com/ElrondNetwork/elrond-go-core/data/api" + "github.com/ElrondNetwork/elrond-go/common" "github.com/ElrondNetwork/elrond-go/node/mock" "github.com/ElrondNetwork/elrond-go/process" "github.com/ElrondNetwork/elrond-go/state" @@ -153,7 +154,7 @@ func createValidatorScAccount(address []byte, leaves [][]byte, rootHash []byte, RootCalled: func() ([]byte, error) { return rootHash, nil }, - GetAllLeavesOnChannelCalled: func(ch chan core.KeyValueHolder, ctx context.Context, rootHash []byte) error { + GetAllLeavesOnChannelCalled: func(ch chan core.KeyValueHolder, ctx context.Context, rootHash []byte, _ common.KeyBuilder) error { go func() { time.Sleep(timeSleep) for _, leafBuff := range leaves { diff --git a/node/trieIterators/stakeValuesProcessor.go b/node/trieIterators/stakeValuesProcessor.go index 7a1d7ac37f3..f7be26f572f 100644 --- a/node/trieIterators/stakeValuesProcessor.go +++ b/node/trieIterators/stakeValuesProcessor.go @@ -12,6 +12,7 @@ import ( "github.com/ElrondNetwork/elrond-go/common" "github.com/ElrondNetwork/elrond-go/process" "github.com/ElrondNetwork/elrond-go/state" + "github.com/ElrondNetwork/elrond-go/trie/keyBuilder" "github.com/ElrondNetwork/elrond-go/vm" ) @@ -96,7 +97,7 @@ func (svp *stakedValuesProcessor) computeBaseStakedAndTopUp(ctx context.Context) // TODO investigate if a call to GetAllLeavesKeysOnChannel (without values) might increase performance chLeaves := make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity) - err = validatorAccount.DataTrie().GetAllLeavesOnChannel(chLeaves, ctx, rootHash) + err = validatorAccount.DataTrie().GetAllLeavesOnChannel(chLeaves, ctx, rootHash, keyBuilder.NewKeyBuilder()) if err != nil { return nil, nil, err } diff --git 
a/node/trieIterators/stakeValuesProcessor_test.go b/node/trieIterators/stakeValuesProcessor_test.go index dd0daf1eb0f..8d38c933c12 100644 --- a/node/trieIterators/stakeValuesProcessor_test.go +++ b/node/trieIterators/stakeValuesProcessor_test.go @@ -11,6 +11,7 @@ import ( "github.com/ElrondNetwork/elrond-go-core/core" "github.com/ElrondNetwork/elrond-go-core/core/keyValStorage" "github.com/ElrondNetwork/elrond-go-core/data/api" + "github.com/ElrondNetwork/elrond-go/common" "github.com/ElrondNetwork/elrond-go/node/mock" "github.com/ElrondNetwork/elrond-go/process" "github.com/ElrondNetwork/elrond-go/state" @@ -192,7 +193,7 @@ func TestTotalStakedValueProcessor_GetTotalStakedValue_ContextShouldTimeout(t *t acc, _ := state.NewUserAccount([]byte("newaddress")) acc.SetDataTrie(&trieMock.TrieStub{ - GetAllLeavesOnChannelCalled: func(chLeaves chan core.KeyValueHolder, _ context.Context, _ []byte) error { + GetAllLeavesOnChannelCalled: func(chLeaves chan core.KeyValueHolder, _ context.Context, _ []byte, _ common.KeyBuilder) error { time.Sleep(time.Second) close(chLeaves) return nil @@ -227,7 +228,7 @@ func TestTotalStakedValueProcessor_GetTotalStakedValue_CannotGetAllLeaves(t *tes expectedErr := errors.New("expected error") acc, _ := state.NewUserAccount([]byte("newaddress")) acc.SetDataTrie(&trieMock.TrieStub{ - GetAllLeavesOnChannelCalled: func(_ chan core.KeyValueHolder, _ context.Context, _ []byte) error { + GetAllLeavesOnChannelCalled: func(_ chan core.KeyValueHolder, _ context.Context, _ []byte, _ common.KeyBuilder) error { return expectedErr }, RootCalled: func() ([]byte, error) { @@ -275,7 +276,7 @@ func TestTotalStakedValueProcessor_GetTotalStakedValue(t *testing.T) { RootCalled: func() ([]byte, error) { return rootHash, nil }, - GetAllLeavesOnChannelCalled: func(ch chan core.KeyValueHolder, ctx context.Context, rootHash []byte) error { + GetAllLeavesOnChannelCalled: func(ch chan core.KeyValueHolder, ctx context.Context, rootHash []byte, _ common.KeyBuilder) error { go func() { leaf1 := keyValStorage.NewKeyValStorage(rootHash, append(marshalledData, suffix...)) ch <- leaf1 diff --git a/outport/factory/notifierFactory.go b/outport/factory/notifierFactory.go index 60d21da7967..ec4aa3e6b24 100644 --- a/outport/factory/notifierFactory.go +++ b/outport/factory/notifierFactory.go @@ -11,14 +11,15 @@ import ( // EventNotifierFactoryArgs defines the args needed for event notifier creation type EventNotifierFactoryArgs struct { - Enabled bool - UseAuthorization bool - ProxyUrl string - Username string - Password string - Marshaller marshal.Marshalizer - Hasher hashing.Hasher - PubKeyConverter core.PubkeyConverter + Enabled bool + UseAuthorization bool + ProxyUrl string + Username string + Password string + RequestTimeoutSec int + Marshaller marshal.Marshalizer + Hasher hashing.Hasher + PubKeyConverter core.PubkeyConverter } // CreateEventNotifier will create a new event notifier client instance @@ -27,16 +28,21 @@ func CreateEventNotifier(args *EventNotifierFactoryArgs) (outport.Driver, error) return nil, err } - httpClient := notifier.NewHttpClient(notifier.HttpClientArgs{ - UseAuthorization: args.UseAuthorization, - Username: args.Username, - Password: args.Password, - BaseUrl: args.ProxyUrl, - }) + httpClientArgs := notifier.HTTPClientWrapperArgs{ + UseAuthorization: args.UseAuthorization, + Username: args.Username, + Password: args.Password, + BaseUrl: args.ProxyUrl, + RequestTimeoutSec: args.RequestTimeoutSec, + } + httpClient, err := notifier.NewHTTPWrapperClient(httpClientArgs) + if err != 
nil { + return nil, err + } notifierArgs := notifier.ArgsEventNotifier{ HttpClient: httpClient, - Marshalizer: args.Marshaller, + Marshaller: args.Marshaller, Hasher: args.Hasher, PubKeyConverter: args.PubKeyConverter, } diff --git a/outport/factory/notifierFactory_test.go b/outport/factory/notifierFactory_test.go index 1c673aac63d..18e76e3faa4 100644 --- a/outport/factory/notifierFactory_test.go +++ b/outport/factory/notifierFactory_test.go @@ -13,14 +13,15 @@ import ( func createMockNotifierFactoryArgs() *factory.EventNotifierFactoryArgs { return &factory.EventNotifierFactoryArgs{ - Enabled: true, - UseAuthorization: true, - ProxyUrl: "http://localhost:5000", - Username: "", - Password: "", - Marshaller: &testscommon.MarshalizerMock{}, - Hasher: &hashingMocks.HasherMock{}, - PubKeyConverter: &testscommon.PubkeyConverterMock{}, + Enabled: true, + UseAuthorization: true, + ProxyUrl: "http://localhost:5000", + Username: "", + Password: "", + RequestTimeoutSec: 1, + Marshaller: &testscommon.MarshalizerMock{}, + Hasher: &hashingMocks.HasherMock{}, + PubKeyConverter: &testscommon.PubkeyConverterMock{}, } } diff --git a/outport/factory/outportFactory_test.go b/outport/factory/outportFactory_test.go index 09aab09216b..e6c3931d35b 100644 --- a/outport/factory/outportFactory_test.go +++ b/outport/factory/outportFactory_test.go @@ -127,6 +127,7 @@ func TestCreateOutport_SubscribeNotifierDriver(t *testing.T) { args.EventNotifierFactoryArgs.Marshaller = &mock.MarshalizerMock{} args.EventNotifierFactoryArgs.Hasher = &hashingMocks.HasherMock{} args.EventNotifierFactoryArgs.PubKeyConverter = &mock.PubkeyConverterMock{} + args.EventNotifierFactoryArgs.RequestTimeoutSec = 1 outPort, err := factory.CreateOutport(args) defer func(c outport.OutportHandler) { diff --git a/outport/mock/httpClientStub.go b/outport/mock/httpClientStub.go index f93eb04854b..eea305de0da 100644 --- a/outport/mock/httpClientStub.go +++ b/outport/mock/httpClientStub.go @@ -2,13 +2,13 @@ package mock // HTTPClientStub - type HTTPClientStub struct { - PostCalled func(route string, payload interface{}, response interface{}) error + PostCalled func(route string, payload interface{}) error } // Post - -func (stub *HTTPClientStub) Post(route string, payload interface{}, response interface{}) error { +func (stub *HTTPClientStub) Post(route string, payload interface{}) error { if stub.PostCalled != nil { - return stub.PostCalled(route, payload, response) + return stub.PostCalled(route, payload) } return nil diff --git a/outport/notifier/errors.go b/outport/notifier/errors.go index 40467bb1842..7c6fff363ac 100644 --- a/outport/notifier/errors.go +++ b/outport/notifier/errors.go @@ -6,3 +6,18 @@ import ( // ErrNilTransactionsPool signals that a nil transactions pool was provided var ErrNilTransactionsPool = errors.New("nil transactions pool") + +// ErrInvalidValue signals that an invalid value has been provided +var ErrInvalidValue = errors.New("invalid value") + +// ErrNilHTTPClientWrapper signals that a nil http client wrapper has been provided +var ErrNilHTTPClientWrapper = errors.New("nil http client wrapper") + +// ErrNilMarshaller signals that a nil marshaller has been provided +var ErrNilMarshaller = errors.New("nil marshaller") + +// ErrNilPubKeyConverter signals that a nil pubkey converter has been provided +var ErrNilPubKeyConverter = errors.New("nil pub key converter") + +// ErrNilHasher is raised when a valid hasher is expected but nil used +var ErrNilHasher = errors.New("hasher is nil") diff --git 
a/outport/notifier/eventNotifier.go b/outport/notifier/eventNotifier.go index bc52880b31d..7c4694951c7 100644 --- a/outport/notifier/eventNotifier.go +++ b/outport/notifier/eventNotifier.go @@ -67,7 +67,7 @@ type logEvent struct { // ArgsEventNotifier defines the arguments needed for event notifier creation type ArgsEventNotifier struct { HttpClient httpClientHandler - Marshalizer marshal.Marshalizer + Marshaller marshal.Marshalizer Hasher hashing.Hasher PubKeyConverter core.PubkeyConverter } @@ -75,14 +75,36 @@ type ArgsEventNotifier struct { // NewEventNotifier creates a new instance of the eventNotifier // It implements all methods of process.Indexer func NewEventNotifier(args ArgsEventNotifier) (*eventNotifier, error) { + err := checkEventNotifierArgs(args) + if err != nil { + return nil, err + } + return &eventNotifier{ httpClient: args.HttpClient, - marshalizer: args.Marshalizer, + marshalizer: args.Marshaller, hasher: args.Hasher, pubKeyConverter: args.PubKeyConverter, }, nil } +func checkEventNotifierArgs(args ArgsEventNotifier) error { + if check.IfNil(args.HttpClient) { + return ErrNilHTTPClientWrapper + } + if check.IfNil(args.Marshaller) { + return ErrNilMarshaller + } + if check.IfNil(args.Hasher) { + return ErrNilHasher + } + if check.IfNil(args.PubKeyConverter) { + return ErrNilPubKeyConverter + } + + return nil +} + // SaveBlock converts block data in order to be pushed to subscribers func (en *eventNotifier) SaveBlock(args *indexer.ArgsSaveBlockData) error { log.Debug("eventNotifier: SaveBlock called at block", "block hash", args.HeaderHash) @@ -103,7 +125,7 @@ func (en *eventNotifier) SaveBlock(args *indexer.ArgsSaveBlockData) error { LogEvents: events, } - err := en.httpClient.Post(pushEventEndpoint, blockData, nil) + err := en.httpClient.Post(pushEventEndpoint, blockData) if err != nil { return fmt.Errorf("%w in eventNotifier.SaveBlock while posting block data", err) } @@ -175,7 +197,7 @@ func (en *eventNotifier) RevertIndexedBlock(header nodeData.HeaderHandler, _ nod Epoch: header.GetEpoch(), } - err = en.httpClient.Post(revertEventsEndpoint, revertBlock, nil) + err = en.httpClient.Post(revertEventsEndpoint, revertBlock) if err != nil { return fmt.Errorf("%w in eventNotifier.RevertIndexedBlock while posting event data", err) } @@ -189,7 +211,7 @@ func (en *eventNotifier) FinalizedBlock(headerHash []byte) error { Hash: hex.EncodeToString(headerHash), } - err := en.httpClient.Post(finalizedEventsEndpoint, finalizedBlock, nil) + err := en.httpClient.Post(finalizedEventsEndpoint, finalizedBlock) if err != nil { return fmt.Errorf("%w in eventNotifier.FinalizedBlock while posting event data", err) } diff --git a/outport/notifier/eventNotifier_test.go b/outport/notifier/eventNotifier_test.go index b9f76fa7483..033c00faf94 100644 --- a/outport/notifier/eventNotifier_test.go +++ b/outport/notifier/eventNotifier_test.go @@ -20,7 +20,7 @@ import ( func createMockEventNotifierArgs() notifier.ArgsEventNotifier { return notifier.ArgsEventNotifier{ HttpClient: &mock.HTTPClientStub{}, - Marshalizer: &testscommon.MarshalizerMock{}, + Marshaller: &testscommon.MarshalizerMock{}, Hasher: &hashingMocks.HasherMock{}, PubKeyConverter: &testscommon.PubkeyConverterMock{}, } @@ -29,9 +29,57 @@ func createMockEventNotifierArgs() notifier.ArgsEventNotifier { func TestNewEventNotifier(t *testing.T) { t.Parallel() - en, err := notifier.NewEventNotifier(createMockEventNotifierArgs()) - require.Nil(t, err) - require.NotNil(t, en) + t.Run("nil http client", func(t *testing.T) { + t.Parallel() + + args 
:= createMockEventNotifierArgs() + args.HttpClient = nil + + en, err := notifier.NewEventNotifier(args) + require.Nil(t, en) + require.Equal(t, notifier.ErrNilHTTPClientWrapper, err) + }) + + t.Run("nil marshaller", func(t *testing.T) { + t.Parallel() + + args := createMockEventNotifierArgs() + args.Marshaller = nil + + en, err := notifier.NewEventNotifier(args) + require.Nil(t, en) + require.Equal(t, notifier.ErrNilMarshaller, err) + }) + + t.Run("nil hasher", func(t *testing.T) { + t.Parallel() + + args := createMockEventNotifierArgs() + args.Hasher = nil + + en, err := notifier.NewEventNotifier(args) + require.Nil(t, en) + require.Equal(t, notifier.ErrNilHasher, err) + }) + + t.Run("nil pub key converter", func(t *testing.T) { + t.Parallel() + + args := createMockEventNotifierArgs() + args.PubKeyConverter = nil + + en, err := notifier.NewEventNotifier(args) + require.Nil(t, en) + require.Equal(t, notifier.ErrNilPubKeyConverter, err) + }) + + t.Run("should work", func(t *testing.T) { + t.Parallel() + + en, err := notifier.NewEventNotifier(createMockEventNotifierArgs()) + require.Nil(t, err) + require.NotNil(t, en) + }) } func TestSaveBlock(t *testing.T) { @@ -41,7 +89,7 @@ func TestSaveBlock(t *testing.T) { wasCalled := false args.HttpClient = &mock.HTTPClientStub{ - PostCalled: func(route string, payload, response interface{}) error { + PostCalled: func(route string, payload interface{}) error { wasCalled = true return nil }, @@ -75,7 +123,7 @@ func TestRevertIndexedBlock(t *testing.T) { wasCalled := false args.HttpClient = &mock.HTTPClientStub{ - PostCalled: func(route string, payload, response interface{}) error { + PostCalled: func(route string, payload interface{}) error { wasCalled = true return nil }, @@ -101,7 +149,7 @@ func TestFinalizedBlock(t *testing.T) { wasCalled := false args.HttpClient = &mock.HTTPClientStub{ - PostCalled: func(route string, payload, response interface{}) error { + PostCalled: func(route string, payload interface{}) error { wasCalled = true return nil }, diff --git a/outport/notifier/httpClient.go b/outport/notifier/httpClientWrapper.go similarity index 52% rename from outport/notifier/httpClient.go rename to outport/notifier/httpClientWrapper.go index 7ae732cc181..6fc38bd2841 100644 --- a/outport/notifier/httpClient.go +++ b/outport/notifier/httpClientWrapper.go @@ -6,54 +6,69 @@ import ( "fmt" "io/ioutil" "net/http" + "time" ) const ( - contentTypeKey = "Content-Type" - contentTypeValue = "application/json" + minRequestTimeoutSec = 1 + contentTypeKey = "Content-Type" + contentTypeValue = "application/json" ) -type httpClientHandler interface { - Post(route string, payload interface{}, response interface{}) error -} - -type httpClient struct { +type httpClientWrapper struct { + httpClient *http.Client useAuthorization bool username string password string baseUrl string } -// HttpClientArgs defines the arguments needed for http client creation -type HttpClientArgs struct { - UseAuthorization bool - Username string - Password string - BaseUrl string +// HTTPClientWrapperArgs defines the arguments needed for http client creation +type HTTPClientWrapperArgs struct { + UseAuthorization bool + Username string + Password string + BaseUrl string + RequestTimeoutSec int } -// NewHttpClient creates an instance of httpClient which is a wrapper for http.Client -func NewHttpClient(args HttpClientArgs) *httpClient { - return &httpClient{ +// NewHTTPWrapperClient creates an instance of httpClient which is a wrapper for http.Client +func NewHTTPWrapperClient(args 
HTTPClientWrapperArgs) (*httpClientWrapper, error) { + err := checkArgs(args) + if err != nil { + return nil, err + } + + httpClient := &http.Client{} + httpClient.Timeout = time.Duration(args.RequestTimeoutSec) * time.Second + + return &httpClientWrapper{ + httpClient: httpClient, useAuthorization: args.UseAuthorization, username: args.Username, password: args.Password, baseUrl: args.BaseUrl, + }, nil +} + +func checkArgs(args HTTPClientWrapperArgs) error { + if args.RequestTimeoutSec < minRequestTimeoutSec { + return fmt.Errorf("%w, provided: %v, minimum: %v", ErrInvalidValue, args.RequestTimeoutSec, minRequestTimeoutSec) } + + return nil } // Post can be used to send POST requests. It handles marshalling to/from json -func (h *httpClient) Post( +func (h *httpClientWrapper) Post( route string, payload interface{}, - response interface{}, ) error { jsonData, err := json.Marshal(payload) if err != nil { return err } - client := &http.Client{} url := fmt.Sprintf("%s%s", h.baseUrl, route) req, err := http.NewRequest(http.MethodPost, url, bytes.NewReader(jsonData)) if err != nil { @@ -66,7 +81,7 @@ func (h *httpClient) Post( req.SetBasicAuth(h.username, h.password) } - resp, err := client.Do(req) + resp, err := h.httpClient.Do(req) if err != nil { return err } @@ -89,5 +104,10 @@ func (h *httpClient) Post( return fmt.Errorf("HTTP status code: %d, %s", resp.StatusCode, http.StatusText(resp.StatusCode)) } - return json.Unmarshal(resBody, &response) + return nil +} + +// IsInterfaceNil returns true if there is no value under the interface +func (h *httpClientWrapper) IsInterfaceNil() bool { + return h == nil } diff --git a/outport/notifier/httpClient_test.go b/outport/notifier/httpClientWrapper_test.go similarity index 55% rename from outport/notifier/httpClient_test.go rename to outport/notifier/httpClientWrapper_test.go index d7bbada6a1f..17040a33548 100644 --- a/outport/notifier/httpClient_test.go +++ b/outport/notifier/httpClientWrapper_test.go @@ -2,11 +2,13 @@ package notifier_test import ( "encoding/json" + "errors" "net/http" "net/http/httptest" "strings" "testing" + "github.com/ElrondNetwork/elrond-go-core/core/check" "github.com/ElrondNetwork/elrond-go/outport/notifier" "github.com/stretchr/testify/require" ) @@ -15,21 +17,37 @@ type testStruct struct { Hash string `json:"hash"` } -func createMockHTTPClientArgs() notifier.HttpClientArgs { - return notifier.HttpClientArgs{ - UseAuthorization: false, - Username: "user", - Password: "pass", - BaseUrl: "http://localhost:8080", +func createMockHTTPClientArgs() notifier.HTTPClientWrapperArgs { + return notifier.HTTPClientWrapperArgs{ + UseAuthorization: false, + Username: "user", + Password: "pass", + BaseUrl: "http://localhost:8080", + RequestTimeoutSec: 60, } } func TestNewHTTPClient(t *testing.T) { t.Parallel() - args := createMockHTTPClientArgs() - client := notifier.NewHttpClient(args) - require.NotNil(t, client) + t.Run("invalid request timeout, should fail", func(t *testing.T) { + t.Parallel() + + args := createMockHTTPClientArgs() + args.RequestTimeoutSec = 0 + client, err := notifier.NewHTTPWrapperClient(args) + require.True(t, check.IfNil(client)) + require.True(t, errors.Is(err, notifier.ErrInvalidValue)) + }) + + t.Run("should work", func(t *testing.T) { + t.Parallel() + + args := createMockHTTPClientArgs() + client, err := notifier.NewHTTPWrapperClient(args) + require.Nil(t, err) + require.False(t, check.IfNil(client)) + }) } func TestPOST(t *testing.T) { @@ -51,10 +69,11 @@ func TestPOST(t *testing.T) { args := 
createMockHTTPClientArgs() args.BaseUrl = ws.URL - client := notifier.NewHttpClient(args) + client, err := notifier.NewHTTPWrapperClient(args) + require.Nil(t, err) require.NotNil(t, client) - err := client.Post("/events/push", testPayload, nil) + err = client.Post("/events/push", testPayload) require.Nil(t, err) require.True(t, wasCalled) @@ -82,10 +101,11 @@ func TestPOSTShouldFail(t *testing.T) { args := createMockHTTPClientArgs() args.BaseUrl = ws.URL - client := notifier.NewHttpClient(args) + client, err := notifier.NewHTTPWrapperClient(args) + require.Nil(t, err) require.NotNil(t, client) - err := client.Post("/events/push", testPayload, nil) + err = client.Post("/events/push", testPayload) require.True(t, strings.Contains(err.Error(), http.StatusText(statusCode))) require.True(t, wasCalled) diff --git a/outport/notifier/interface.go b/outport/notifier/interface.go new file mode 100644 index 00000000000..52bdf53eb52 --- /dev/null +++ b/outport/notifier/interface.go @@ -0,0 +1,6 @@ +package notifier + +type httpClientHandler interface { + Post(route string, payload interface{}) error + IsInterfaceNil() bool +} diff --git a/outport/outport.go b/outport/outport.go index c3fd9dd99e5..0ee8204f0e2 100644 --- a/outport/outport.go +++ b/outport/outport.go @@ -3,6 +3,7 @@ package outport import ( "fmt" "sync" + "sync/atomic" "time" "github.com/ElrondNetwork/elrond-go-core/core/check" @@ -13,13 +14,17 @@ import ( var log = logger.GetOrCreate("outport") +const maxTimeForDriverCall = time.Second * 30 const minimumRetrialInterval = time.Millisecond * 10 type outport struct { - mutex sync.RWMutex - drivers []Driver - retrialInterval time.Duration - chanClose chan struct{} + mutex sync.RWMutex + drivers []Driver + retrialInterval time.Duration + chanClose chan struct{} + logHandler func(logLevel logger.LogLevel, message string, args ...interface{}) + timeForDriverCall time.Duration + messageCounter uint64 } // NewOutport will create a new instance of proxy @@ -29,10 +34,12 @@ func NewOutport(retrialInterval time.Duration) (*outport, error) { } return &outport{ - drivers: make([]Driver, 0), - mutex: sync.RWMutex{}, - retrialInterval: retrialInterval, - chanClose: make(chan struct{}), + drivers: make([]Driver, 0), + mutex: sync.RWMutex{}, + retrialInterval: retrialInterval, + chanClose: make(chan struct{}), + logHandler: log.Log, + timeForDriverCall: maxTimeForDriverCall, }, nil } @@ -46,7 +53,34 @@ func (o *outport) SaveBlock(args *indexer.ArgsSaveBlockData) { } } +func (o *outport) monitorCompletionOnDriver(function string, driver Driver) chan struct{} { + counter := atomic.AddUint64(&o.messageCounter, 1) + + o.logHandler(logger.LogDebug, "outport.monitorCompletionOnDriver starting", + "function", function, "driver", driverString(driver), "message counter", counter) + ch := make(chan struct{}) + go func() { + timer := time.NewTimer(o.timeForDriverCall) + + select { + case <-ch: + o.logHandler(logger.LogDebug, "outport.monitorCompletionOnDriver ended", + "function", function, "driver", driverString(driver), "message counter", counter) + case <-timer.C: + o.logHandler(logger.LogWarning, "outport.monitorCompletionOnDriver took too long", + "function", function, "driver", driverString(driver), "message counter", counter, "time", o.timeForDriverCall) + } + + timer.Stop() + }() + + return ch +} + func (o *outport) saveBlockBlocking(args *indexer.ArgsSaveBlockData, driver Driver) { + ch := o.monitorCompletionOnDriver("saveBlockBlocking", driver) + defer close(ch) + for { err := driver.SaveBlock(args) if 
err == nil { @@ -84,6 +118,9 @@ func (o *outport) RevertIndexedBlock(header data.HeaderHandler, body data.BodyHa } func (o *outport) revertIndexedBlockBlocking(header data.HeaderHandler, body data.BodyHandler, driver Driver) { + ch := o.monitorCompletionOnDriver("revertIndexedBlockBlocking", driver) + defer close(ch) + for { err := driver.RevertIndexedBlock(header, body) if err == nil { @@ -112,6 +149,9 @@ func (o *outport) SaveRoundsInfo(roundsInfo []*indexer.RoundInfo) { } func (o *outport) saveRoundsInfoBlocking(roundsInfo []*indexer.RoundInfo, driver Driver) { + ch := o.monitorCompletionOnDriver("saveRoundsInfoBlocking", driver) + defer close(ch) + for { err := driver.SaveRoundsInfo(roundsInfo) if err == nil { @@ -140,6 +180,9 @@ func (o *outport) SaveValidatorsPubKeys(validatorsPubKeys map[uint32][][]byte, e } func (o *outport) saveValidatorsPubKeysBlocking(validatorsPubKeys map[uint32][][]byte, epoch uint32, driver Driver) { + ch := o.monitorCompletionOnDriver("saveValidatorsPubKeysBlocking", driver) + defer close(ch) + for { err := driver.SaveValidatorsPubKeys(validatorsPubKeys, epoch) if err == nil { @@ -168,6 +211,9 @@ func (o *outport) SaveValidatorsRating(indexID string, infoRating []*indexer.Val } func (o *outport) saveValidatorsRatingBlocking(indexID string, infoRating []*indexer.ValidatorRatingInfo, driver Driver) { + ch := o.monitorCompletionOnDriver("saveValidatorsRatingBlocking", driver) + defer close(ch) + for { err := driver.SaveValidatorsRating(indexID, infoRating) if err == nil { @@ -196,6 +242,9 @@ func (o *outport) SaveAccounts(blockTimestamp uint64, acc []data.UserAccountHand } func (o *outport) saveAccountsBlocking(blockTimestamp uint64, acc []data.UserAccountHandler, driver Driver) { + ch := o.monitorCompletionOnDriver("saveAccountsBlocking", driver) + defer close(ch) + for { err := driver.SaveAccounts(blockTimestamp, acc) if err == nil { @@ -224,6 +273,9 @@ func (o *outport) FinalizedBlock(headerHash []byte) { } func (o *outport) finalizedBlockBlocking(headerHash []byte, driver Driver) { + ch := o.monitorCompletionOnDriver("finalizedBlockBlocking", driver) + defer close(ch) + for { err := driver.FinalizedBlock(headerHash) if err == nil { diff --git a/outport/outport_test.go b/outport/outport_test.go index ee10153809f..c1fe8bbbed4 100644 --- a/outport/outport_test.go +++ b/outport/outport_test.go @@ -3,17 +3,22 @@ package outport import ( "errors" "sync" + atomicGo "sync/atomic" "testing" "time" + "github.com/ElrondNetwork/elrond-go-core/core/atomic" "github.com/ElrondNetwork/elrond-go-core/core/check" "github.com/ElrondNetwork/elrond-go-core/data" "github.com/ElrondNetwork/elrond-go-core/data/indexer" + logger "github.com/ElrondNetwork/elrond-go-logger" "github.com/ElrondNetwork/elrond-go/outport/mock" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) +const counterPositionInLogMessage = 5 + func TestNewOutport(t *testing.T) { t.Parallel() @@ -54,13 +59,27 @@ func TestOutport_SaveAccounts(t *testing.T) { }, } outportHandler, _ := NewOutport(minimumRetrialInterval) + numLogDebugCalled := uint32(0) + outportHandler.logHandler = func(logLevel logger.LogLevel, message string, args ...interface{}) { + if logLevel == logger.LogError { + assert.Fail(t, "should have not called log error") + } + if logLevel == logger.LogDebug { + atomicGo.AddUint32(&numLogDebugCalled, 1) + } + } + outportHandler.SaveAccounts(0, []data.UserAccountHandler{}) + time.Sleep(time.Second) _ = outportHandler.SubscribeDriver(driver1) _ = 
outportHandler.SubscribeDriver(driver2) outportHandler.SaveAccounts(0, []data.UserAccountHandler{}) + time.Sleep(time.Second) + assert.Equal(t, 10, numCalled1) assert.Equal(t, 1, numCalled2) + assert.Equal(t, uint32(4), atomicGo.LoadUint32(&numLogDebugCalled)) } func TestOutport_SaveBlock(t *testing.T) { @@ -86,13 +105,26 @@ func TestOutport_SaveBlock(t *testing.T) { }, } outportHandler, _ := NewOutport(minimumRetrialInterval) + numLogDebugCalled := uint32(0) + outportHandler.logHandler = func(logLevel logger.LogLevel, message string, args ...interface{}) { + if logLevel == logger.LogError { + assert.Fail(t, "should have not called log error") + } + if logLevel == logger.LogDebug { + atomicGo.AddUint32(&numLogDebugCalled, 1) + } + } + outportHandler.SaveBlock(nil) _ = outportHandler.SubscribeDriver(driver1) _ = outportHandler.SubscribeDriver(driver2) outportHandler.SaveBlock(nil) + time.Sleep(time.Second) + assert.Equal(t, 10, numCalled1) assert.Equal(t, 1, numCalled2) + assert.Equal(t, uint32(4), atomicGo.LoadUint32(&numLogDebugCalled)) } func TestOutport_SaveRoundsInfo(t *testing.T) { @@ -118,13 +150,26 @@ func TestOutport_SaveRoundsInfo(t *testing.T) { }, } outportHandler, _ := NewOutport(minimumRetrialInterval) + numLogDebugCalled := uint32(0) + outportHandler.logHandler = func(logLevel logger.LogLevel, message string, args ...interface{}) { + if logLevel == logger.LogError { + assert.Fail(t, "should have not called log error") + } + if logLevel == logger.LogDebug { + atomicGo.AddUint32(&numLogDebugCalled, 1) + } + } + outportHandler.SaveRoundsInfo(nil) _ = outportHandler.SubscribeDriver(driver1) _ = outportHandler.SubscribeDriver(driver2) outportHandler.SaveRoundsInfo(nil) + + time.Sleep(time.Second) assert.Equal(t, 10, numCalled1) assert.Equal(t, 1, numCalled2) + assert.Equal(t, uint32(4), atomicGo.LoadUint32(&numLogDebugCalled)) } func TestOutport_SaveValidatorsPubKeys(t *testing.T) { @@ -150,13 +195,28 @@ func TestOutport_SaveValidatorsPubKeys(t *testing.T) { }, } outportHandler, _ := NewOutport(minimumRetrialInterval) + numLogDebugCalled := uint32(0) + outportHandler.logHandler = func(logLevel logger.LogLevel, message string, args ...interface{}) { + if logLevel == logger.LogError { + assert.Fail(t, "should have not called log error") + } + if logLevel == logger.LogDebug { + atomicGo.AddUint32(&numLogDebugCalled, 1) + } + } + outportHandler.SaveValidatorsPubKeys(nil, 0) + time.Sleep(time.Second) + _ = outportHandler.SubscribeDriver(driver1) _ = outportHandler.SubscribeDriver(driver2) outportHandler.SaveValidatorsPubKeys(nil, 0) + time.Sleep(time.Second) + assert.Equal(t, 10, numCalled1) assert.Equal(t, 1, numCalled2) + assert.Equal(t, uint32(4), atomicGo.LoadUint32(&numLogDebugCalled)) } func TestOutport_SaveValidatorsRating(t *testing.T) { @@ -182,13 +242,28 @@ func TestOutport_SaveValidatorsRating(t *testing.T) { }, } outportHandler, _ := NewOutport(minimumRetrialInterval) + numLogDebugCalled := uint32(0) + outportHandler.logHandler = func(logLevel logger.LogLevel, message string, args ...interface{}) { + if logLevel == logger.LogError { + assert.Fail(t, "should have not called log error") + } + if logLevel == logger.LogDebug { + atomicGo.AddUint32(&numLogDebugCalled, 1) + } + } + outportHandler.SaveValidatorsRating("", nil) + time.Sleep(time.Second) + _ = outportHandler.SubscribeDriver(driver1) _ = outportHandler.SubscribeDriver(driver2) outportHandler.SaveValidatorsRating("", nil) + time.Sleep(time.Second) + assert.Equal(t, 10, numCalled1) assert.Equal(t, 1, numCalled2) + 
assert.Equal(t, uint32(4), atomicGo.LoadUint32(&numLogDebugCalled)) } func TestOutport_RevertIndexedBlock(t *testing.T) { @@ -214,13 +289,28 @@ func TestOutport_RevertIndexedBlock(t *testing.T) { }, } outportHandler, _ := NewOutport(minimumRetrialInterval) + numLogDebugCalled := uint32(0) + outportHandler.logHandler = func(logLevel logger.LogLevel, message string, args ...interface{}) { + if logLevel == logger.LogError { + assert.Fail(t, "should have not called log error") + } + if logLevel == logger.LogDebug { + atomicGo.AddUint32(&numLogDebugCalled, 1) + } + } + outportHandler.RevertIndexedBlock(nil, nil) + time.Sleep(time.Second) + _ = outportHandler.SubscribeDriver(driver1) _ = outportHandler.SubscribeDriver(driver2) outportHandler.RevertIndexedBlock(nil, nil) + time.Sleep(time.Second) + assert.Equal(t, 10, numCalled1) assert.Equal(t, 1, numCalled2) + assert.Equal(t, uint32(4), atomicGo.LoadUint32(&numLogDebugCalled)) } func TestOutport_FinalizedBlock(t *testing.T) { @@ -246,13 +336,28 @@ func TestOutport_FinalizedBlock(t *testing.T) { }, } outportHandler, _ := NewOutport(minimumRetrialInterval) + numLogDebugCalled := uint32(0) + outportHandler.logHandler = func(logLevel logger.LogLevel, message string, args ...interface{}) { + if logLevel == logger.LogError { + assert.Fail(t, "should have not called log error") + } + if logLevel == logger.LogDebug { + atomicGo.AddUint32(&numLogDebugCalled, 1) + } + } + outportHandler.FinalizedBlock(nil) + time.Sleep(time.Second) + _ = outportHandler.SubscribeDriver(driver1) _ = outportHandler.SubscribeDriver(driver2) outportHandler.FinalizedBlock(nil) + time.Sleep(time.Second) + assert.Equal(t, 10, numCalled1) assert.Equal(t, 1, numCalled2) + assert.Equal(t, uint32(4), atomicGo.LoadUint32(&numLogDebugCalled)) } func TestOutport_SubscribeDriver(t *testing.T) { @@ -388,3 +493,75 @@ func TestOutport_CloseWhileDriverIsStuckInContinuousErrors(t *testing.T) { require.Fail(t, "unable to close all drivers because of a stuck driver") } } + +func TestOutport_SaveBlockDriverStuck(t *testing.T) { + t.Parallel() + + currentCounter := uint64(778) + outportHandler, _ := NewOutport(minimumRetrialInterval) + outportHandler.messageCounter = currentCounter + outportHandler.timeForDriverCall = time.Second + logErrorCalled := atomic.Flag{} + numLogDebugCalled := uint32(0) + outportHandler.logHandler = func(logLevel logger.LogLevel, message string, args ...interface{}) { + if logLevel == logger.LogWarning { + logErrorCalled.SetValue(true) + assert.Equal(t, "outport.monitorCompletionOnDriver took too long", message) + assert.Equal(t, currentCounter+1, args[counterPositionInLogMessage]) + } + if logLevel == logger.LogDebug { + atomicGo.AddUint32(&numLogDebugCalled, 1) + assert.Equal(t, currentCounter+1, args[counterPositionInLogMessage]) + } + } + + _ = outportHandler.SubscribeDriver(&mock.DriverStub{ + SaveBlockCalled: func(args *indexer.ArgsSaveBlockData) error { + time.Sleep(time.Second * 5) + return nil + }, + }) + + outportHandler.SaveBlock(nil) + + assert.True(t, logErrorCalled.IsSet()) + assert.Equal(t, uint32(1), atomicGo.LoadUint32(&numLogDebugCalled)) +} + +func TestOutport_SaveBlockDriverIsNotStuck(t *testing.T) { + t.Parallel() + + currentCounter := uint64(778) + outportHandler, _ := NewOutport(minimumRetrialInterval) + outportHandler.messageCounter = currentCounter + outportHandler.timeForDriverCall = time.Second + numLogDebugCalled := uint32(0) + outportHandler.logHandler = func(logLevel logger.LogLevel, message string, args ...interface{}) { + if logLevel == 
logger.LogError { + assert.Fail(t, "should have not called log error") + } + if logLevel == logger.LogDebug { + if atomicGo.LoadUint32(&numLogDebugCalled) == 0 { + assert.Equal(t, "outport.monitorCompletionOnDriver starting", message) + assert.Equal(t, currentCounter+1, args[counterPositionInLogMessage]) + } + if atomicGo.LoadUint32(&numLogDebugCalled) == 1 { + assert.Equal(t, "outport.monitorCompletionOnDriver ended", message) + assert.Equal(t, currentCounter+1, args[counterPositionInLogMessage]) + } + + atomicGo.AddUint32(&numLogDebugCalled, 1) + } + } + + _ = outportHandler.SubscribeDriver(&mock.DriverStub{ + SaveBlockCalled: func(args *indexer.ArgsSaveBlockData) error { + return nil + }, + }) + + outportHandler.SaveBlock(nil) + time.Sleep(time.Second) + + assert.Equal(t, uint32(2), atomicGo.LoadUint32(&numLogDebugCalled)) +} diff --git a/p2p/config/config.go b/p2p/config/config.go new file mode 100644 index 00000000000..d3c8311441d --- /dev/null +++ b/p2p/config/config.go @@ -0,0 +1,19 @@ +package config + +import "github.com/ElrondNetwork/elrond-go-p2p/config" + +// P2PConfig will hold all the P2P settings +type P2PConfig = config.P2PConfig + +// NodeConfig will hold basic p2p settings +type NodeConfig = config.NodeConfig + +// KadDhtPeerDiscoveryConfig will hold the kad-dht discovery config settings +type KadDhtPeerDiscoveryConfig = config.KadDhtPeerDiscoveryConfig + +// ShardingConfig will hold the network sharding config settings +type ShardingConfig = config.ShardingConfig + +// AdditionalConnectionsConfig will hold the additional connections that will be open when certain conditions are met +// All these values should be added to the maximum target peer count value +type AdditionalConnectionsConfig = config.AdditionalConnectionsConfig diff --git a/p2p/constants.go b/p2p/constants.go new file mode 100644 index 00000000000..755c82d1669 --- /dev/null +++ b/p2p/constants.go @@ -0,0 +1,30 @@ +package p2p + +import ( + p2p "github.com/ElrondNetwork/elrond-go-p2p" + "github.com/ElrondNetwork/elrond-go-p2p/libp2p" +) + +// NodeOperation defines the p2p node operation +type NodeOperation = p2p.NodeOperation + +// NormalOperation defines the normal mode operation: either seeder, observer or validator +const NormalOperation = p2p.NormalOperation + +// FullArchiveMode defines the node operation as a full archive mode +const FullArchiveMode = p2p.FullArchiveMode + +// ListsSharder is the variant that uses lists +const ListsSharder = p2p.ListsSharder + +// NilListSharder is the variant that will not do connection trimming +const NilListSharder = p2p.NilListSharder + +// ConnectionWatcherTypePrint - new connection found will be printed in the log file +const ConnectionWatcherTypePrint = p2p.ConnectionWatcherTypePrint + +// ListenAddrWithIp4AndTcp defines the listening address with ip v.4 and TCP +const ListenAddrWithIp4AndTcp = libp2p.ListenAddrWithIp4AndTcp + +// ListenLocalhostAddrWithIp4AndTcp defines the local host listening ip v.4 address and TCP +const ListenLocalhostAddrWithIp4AndTcp = libp2p.ListenLocalhostAddrWithIp4AndTcp diff --git a/p2p/crypto/errors.go b/p2p/crypto/errors.go deleted file mode 100644 index f6ca563a14a..00000000000 --- a/p2p/crypto/errors.go +++ /dev/null @@ -1,5 +0,0 @@ -package crypto - -import "errors" - -var errNilPrivateKey = errors.New("nil private key") diff --git a/p2p/crypto/identityGenerator.go b/p2p/crypto/identityGenerator.go deleted file mode 100644 index 4389636f7a6..00000000000 --- a/p2p/crypto/identityGenerator.go +++ /dev/null @@ -1,60 +0,0 @@ 
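The new `p2p/config/config.go` and `p2p/constants.go` above keep the old `elrond-go` import paths compiling while the implementations move to the extracted `elrond-go-p2p` module; the `p2p/crypto` files being deleted around this point made the same trip. The key mechanism is the type alias (`type T = pkg.T`), which re-exports the identical type rather than defining a new one. A standalone illustration using a stdlib type purely as a stand-in for the re-exported p2p types:

```go
package main

import (
	"fmt"
	"io"
	"strings"
)

// Reader is an alias (note the '='), not a new definition: it is the very same
// type as io.Reader, so existing callers need no conversion or wrapper. This is
// how P2PConfig, NodeConfig, etc. keep their old elrond-go identities.
type Reader = io.Reader

func consume(r Reader) (string, error) {
	b, err := io.ReadAll(r) // accepts any io.Reader, because Reader *is* io.Reader
	return string(b), err
}

func main() {
	s, _ := consume(strings.NewReader("aliases preserve type identity"))
	fmt.Println(s)
}
```

Had these been new named types instead of aliases, every caller would have needed explicit conversions; the alias keeps the extraction source-compatible.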
-package crypto - -import ( - "crypto/ecdsa" - - "github.com/ElrondNetwork/elrond-go-core/core" - randFactory "github.com/ElrondNetwork/elrond-go/p2p/libp2p/rand/factory" - "github.com/btcsuite/btcd/btcec" - "github.com/libp2p/go-libp2p-core/crypto" - "github.com/libp2p/go-libp2p-core/peer" -) - -const emptySeed = "" - -type identityGenerator struct { -} - -// NewIdentityGenerator creates a new identity generator -func NewIdentityGenerator() *identityGenerator { - return &identityGenerator{} -} - -// CreateRandomP2PIdentity creates a valid random p2p identity to sign messages on the behalf of other identity -func (generator *identityGenerator) CreateRandomP2PIdentity() ([]byte, core.PeerID, error) { - sk, err := generator.CreateP2PPrivateKey(emptySeed) - if err != nil { - return nil, "", err - } - - skBuff, err := crypto.MarshalPrivateKey(sk) - if err != nil { - return nil, "", err - } - - pid, err := peer.IDFromPublicKey(sk.GetPublic()) - if err != nil { - return nil, "", err - } - - return skBuff, core.PeerID(pid), nil -} - -// CreateP2PPrivateKey will create a new P2P private key based on the provided seed. If the seed is the empty string -// it will use the crypto's random generator to provide a random one. Otherwise, it will create a deterministic private -// key. This is useful when we want a private key that never changes, such as in the network seeders -func (generator *identityGenerator) CreateP2PPrivateKey(seed string) (*crypto.Secp256k1PrivateKey, error) { - randReader, err := randFactory.NewRandFactory(seed) - if err != nil { - return nil, err - } - - prvKey, _ := ecdsa.GenerateKey(btcec.S256(), randReader) - - return (*crypto.Secp256k1PrivateKey)(prvKey), nil -} - -// IsInterfaceNil returns true if there is no value under the interface -func (generator *identityGenerator) IsInterfaceNil() bool { - return generator == nil -} diff --git a/p2p/crypto/identityGenerator_test.go b/p2p/crypto/identityGenerator_test.go deleted file mode 100644 index 88ff9023656..00000000000 --- a/p2p/crypto/identityGenerator_test.go +++ /dev/null @@ -1,69 +0,0 @@ -package crypto - -import ( - "testing" - - "github.com/ElrondNetwork/elrond-go-core/core/check" - "github.com/stretchr/testify/assert" -) - -func TestNewIdentityGenerator(t *testing.T) { - t.Parallel() - - generator := NewIdentityGenerator() - assert.False(t, check.IfNil(generator)) -} - -func TestIdentityGenerator_CreateP2PPrivateKey(t *testing.T) { - t.Parallel() - - generator := NewIdentityGenerator() - seed1 := "secret seed 1" - seed2 := "secret seed 2" - t.Run("same seed should produce the same private key", func(t *testing.T) { - - sk1, err := generator.CreateP2PPrivateKey(seed1) - assert.Nil(t, err) - - sk2, err := generator.CreateP2PPrivateKey(seed1) - assert.Nil(t, err) - - assert.Equal(t, sk1, sk2) - }) - t.Run("different seed should produce different private key", func(t *testing.T) { - sk1, err := generator.CreateP2PPrivateKey(seed1) - assert.Nil(t, err) - - sk2, err := generator.CreateP2PPrivateKey(seed2) - assert.Nil(t, err) - - assert.NotEqual(t, sk1, sk2) - }) - t.Run("empty seed should produce different private key", func(t *testing.T) { - sk1, err := generator.CreateP2PPrivateKey("") - assert.Nil(t, err) - - sk2, err := generator.CreateP2PPrivateKey("") - assert.Nil(t, err) - - assert.NotEqual(t, sk1, sk2) - }) -} - -func TestIdentityGenerator_CreateRandomP2PIdentity(t *testing.T) { - t.Parallel() - - generator := NewIdentityGenerator() - sk1, pid1, err := generator.CreateRandomP2PIdentity() - assert.Nil(t, err) - - sk2, 
pid2, err := generator.CreateRandomP2PIdentity() - assert.Nil(t, err) - - assert.NotEqual(t, sk1, sk2) - assert.NotEqual(t, pid1, pid2) - assert.Equal(t, 36, len(sk1)) - assert.Equal(t, 39, len(pid1)) - assert.Equal(t, 36, len(sk2)) - assert.Equal(t, 39, len(pid2)) -} diff --git a/p2p/crypto/p2pSigner.go b/p2p/crypto/p2pSigner.go deleted file mode 100644 index b7531db5bf2..00000000000 --- a/p2p/crypto/p2pSigner.go +++ /dev/null @@ -1,63 +0,0 @@ -package crypto - -import ( - "fmt" - - "github.com/ElrondNetwork/elrond-go-core/core" - crypto "github.com/ElrondNetwork/elrond-go-crypto" - libp2pCrypto "github.com/libp2p/go-libp2p-core/crypto" - "github.com/libp2p/go-libp2p-core/peer" -) - -type p2pSigner struct { - privateKey *libp2pCrypto.Secp256k1PrivateKey -} - -// NewP2PSigner creates a new p2pSigner instance -func NewP2PSigner(privateKey *libp2pCrypto.Secp256k1PrivateKey) (*p2pSigner, error) { - if privateKey == nil { - return nil, errNilPrivateKey - } - - return &p2pSigner{ - privateKey: privateKey, - }, nil -} - -// Sign will sign a payload with the internal private key -func (signer *p2pSigner) Sign(payload []byte) ([]byte, error) { - return signer.privateKey.Sign(payload) -} - -// Verify will check that the (payload, peer ID, signature) tuple is valid or not -func (signer *p2pSigner) Verify(payload []byte, pid core.PeerID, signature []byte) error { - libp2pPid, err := peer.IDFromBytes(pid.Bytes()) - if err != nil { - return err - } - - pubk, err := libp2pPid.ExtractPublicKey() - if err != nil { - return fmt.Errorf("cannot extract signing key: %s", err.Error()) - } - - sigOk, err := pubk.Verify(payload, signature) - if err != nil { - return err - } - if !sigOk { - return crypto.ErrInvalidSignature - } - - return nil -} - -// SignUsingPrivateKey will sign the payload with provided private key bytes -func (signer *p2pSigner) SignUsingPrivateKey(skBytes []byte, payload []byte) ([]byte, error) { - sk, err := libp2pCrypto.UnmarshalPrivateKey(skBytes) - if err != nil { - return nil, err - } - - return sk.Sign(payload) -} diff --git a/p2p/crypto/p2pSigner_test.go b/p2p/crypto/p2pSigner_test.go deleted file mode 100644 index d2be94f088b..00000000000 --- a/p2p/crypto/p2pSigner_test.go +++ /dev/null @@ -1,177 +0,0 @@ -package crypto - -import ( - "crypto/ecdsa" - cryptoRand "crypto/rand" - "sync" - "testing" - "time" - - "github.com/ElrondNetwork/elrond-go-core/core" - crypto "github.com/ElrondNetwork/elrond-go-crypto" - "github.com/btcsuite/btcd/btcec" - libp2pCrypto "github.com/libp2p/go-libp2p-core/crypto" - "github.com/libp2p/go-libp2p-core/peer" - "github.com/stretchr/testify/assert" -) - -func generatePrivateKey() *libp2pCrypto.Secp256k1PrivateKey { - prvKey, _ := ecdsa.GenerateKey(btcec.S256(), cryptoRand.Reader) - - return (*libp2pCrypto.Secp256k1PrivateKey)(prvKey) -} - -func TestNewP2PSigner(t *testing.T) { - t.Parallel() - - t.Run("nil private key should error", func(t *testing.T) { - t.Parallel() - - signer, err := NewP2PSigner(nil) - - assert.Nil(t, signer) - assert.Equal(t, errNilPrivateKey, err) - }) - t.Run("valid private key should work", func(t *testing.T) { - t.Parallel() - - signer, err := NewP2PSigner(generatePrivateKey()) - - assert.NotNil(t, signer) - assert.Nil(t, err) - }) -} - -func TestP2pSigner_Sign(t *testing.T) { - t.Parallel() - - signer, _ := NewP2PSigner(generatePrivateKey()) - - sig, err := signer.Sign([]byte("payload")) - assert.Nil(t, err) - assert.NotNil(t, sig) -} - -func TestP2pSigner_Verify(t *testing.T) { - t.Parallel() - - sk := generatePrivateKey() - pk 
:= sk.GetPublic() - payload := []byte("payload") - signer, _ := NewP2PSigner(sk) - libp2pPid, _ := peer.IDFromPublicKey(pk) - - t.Run("invalid public key should error", func(t *testing.T) { - t.Parallel() - - sig, err := signer.Sign(payload) - assert.Nil(t, err) - - err = signer.Verify(payload, "invalid PK", sig) - assert.NotNil(t, err) - assert.Equal(t, "length greater than remaining number of bytes in buffer", err.Error()) - }) - t.Run("malformed signature header should error", func(t *testing.T) { - t.Parallel() - - sig, err := signer.Sign(payload) - assert.Nil(t, err) - - sig[0] = sig[0] ^ sig[1] ^ sig[2] - - err = signer.Verify(payload, core.PeerID(libp2pPid), sig) - assert.NotNil(t, err) - assert.Equal(t, "malformed signature: no header magic", err.Error()) - }) - t.Run("altered signature should error", func(t *testing.T) { - t.Parallel() - - sig, err := signer.Sign(payload) - assert.Nil(t, err) - - sig[len(sig)-1] = sig[0] ^ sig[1] ^ sig[2] - - err = signer.Verify(payload, core.PeerID(libp2pPid), sig) - assert.Equal(t, crypto.ErrInvalidSignature, err) - }) - t.Run("sign and verify should work", func(t *testing.T) { - t.Parallel() - - sig, err := signer.Sign(payload) - assert.Nil(t, err) - - err = signer.Verify(payload, core.PeerID(libp2pPid), sig) - assert.Nil(t, err) - }) -} - -func TestP2PSigner_SignUsingPrivateKey(t *testing.T) { - t.Parallel() - - payload := []byte("payload") - - generator := NewIdentityGenerator() - skBytes1, pid1, err := generator.CreateRandomP2PIdentity() - assert.Nil(t, err) - - skBytes2, pid2, err := generator.CreateRandomP2PIdentity() - assert.Nil(t, err) - assert.NotEqual(t, skBytes1, skBytes2) - - p2pSigner := &p2pSigner{} - - sig1, err := p2pSigner.SignUsingPrivateKey(skBytes1, payload) - assert.Nil(t, err) - - sig2, err := p2pSigner.SignUsingPrivateKey(skBytes2, payload) - assert.Nil(t, err) - assert.NotEqual(t, sig1, sig2) - - assert.Nil(t, p2pSigner.Verify(payload, pid1, sig1)) - assert.Nil(t, p2pSigner.Verify(payload, pid2, sig2)) -} - -func TestP2pSigner_ConcurrentOperations(t *testing.T) { - t.Parallel() - - numOps := 1000 - wg := sync.WaitGroup{} - wg.Add(numOps) - - sk := generatePrivateKey() - pk := sk.GetPublic() - payload1 := []byte("payload1") - payload2 := []byte("payload2") - payload3 := []byte("payload3") - signer, _ := NewP2PSigner(sk) - libp2pPid, _ := peer.IDFromPublicKey(pk) - pid := core.PeerID(libp2pPid) - - sig1, _ := signer.Sign(payload1) - - generator := NewIdentityGenerator() - skBytes, _, err := generator.CreateRandomP2PIdentity() - assert.Nil(t, err) - - for i := 0; i < numOps; i++ { - go func(idx int) { - time.Sleep(time.Millisecond * 10) - - switch idx { - case 0: - _, errSign := signer.Sign(payload2) - assert.Nil(t, errSign) - case 1: - errVerify := signer.Verify(payload1, pid, sig1) - assert.Nil(t, errVerify) - case 2: - _, errSignWithSK := signer.SignUsingPrivateKey(skBytes, payload3) - assert.Nil(t, errSignWithSK) - } - - wg.Done() - }(i % 3) - } - - wg.Wait() -} diff --git a/p2p/data/generate.go b/p2p/data/generate.go deleted file mode 100644 index f803ab5d297..00000000000 --- a/p2p/data/generate.go +++ /dev/null @@ -1,2 +0,0 @@ -//go:generate protoc -I=. -I=$GOPATH/src -I=$GOPATH/src/github.com/ElrondNetwork/protobuf/protobuf --gogoslick_out=. topicMessage.proto -package data diff --git a/p2p/data/topicMessage.pb.go b/p2p/data/topicMessage.pb.go deleted file mode 100644 index 5aa47fb65ee..00000000000 --- a/p2p/data/topicMessage.pb.go +++ /dev/null @@ -1,580 +0,0 @@ -// Code generated by protoc-gen-gogo. DO NOT EDIT. 
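The `p2p/crypto` files deleted above (identity generator and p2p signer) also moved into `elrond-go-p2p`. Condensed from that deleted code, the essence of the scheme is that the verifier never needs the public key shipped separately: it is recoverable from the libp2p peer ID. A sketch assuming the same `go-libp2p-core` and `btcec` modules this repository already imports:

```go
package main

import (
	"crypto/ecdsa"
	cryptoRand "crypto/rand"
	"fmt"

	"github.com/btcsuite/btcd/btcec"
	libp2pCrypto "github.com/libp2p/go-libp2p-core/crypto"
	"github.com/libp2p/go-libp2p-core/peer"
)

func main() {
	// Generate a secp256k1 key the same way the deleted identityGenerator did
	// (errors elided for brevity in this sketch).
	prv, _ := ecdsa.GenerateKey(btcec.S256(), cryptoRand.Reader)
	sk := (*libp2pCrypto.Secp256k1PrivateKey)(prv)

	payload := []byte("payload")
	sig, _ := sk.Sign(payload)

	// Verification path from the deleted p2pSigner.Verify: derive the peer ID,
	// extract the public key back out of it, then check the signature.
	pid, _ := peer.IDFromPublicKey(sk.GetPublic())
	pub, _ := pid.ExtractPublicKey()
	ok, _ := pub.Verify(payload, sig)
	fmt.Println("signature valid:", ok) // true
}
```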
-// source: topicMessage.proto - -package data - -import ( - bytes "bytes" - fmt "fmt" - _ "github.com/gogo/protobuf/gogoproto" - proto "github.com/gogo/protobuf/proto" - io "io" - math "math" - math_bits "math/bits" - reflect "reflect" - strings "strings" -) - -// Reference imports to suppress errors if they are not otherwise used. -var _ = proto.Marshal -var _ = fmt.Errorf -var _ = math.Inf - -// This is a compile-time assertion to ensure that this generated file -// is compatible with the proto package it is being compiled against. -// A compilation error at this line likely means your copy of the -// proto package needs to be updated. -const _ = proto.GoGoProtoPackageIsVersion3 // please upgrade the proto package - -type TopicMessage struct { - Version uint32 `protobuf:"varint,1,opt,name=Version,proto3" json:"Version,omitempty"` - Payload []byte `protobuf:"bytes,2,opt,name=Payload,proto3" json:"Payload,omitempty"` - Timestamp int64 `protobuf:"varint,3,opt,name=Timestamp,proto3" json:"Timestamp,omitempty"` - Pk []byte `protobuf:"bytes,4,opt,name=Pk,proto3" json:"Pk,omitempty"` - SignatureOnPid []byte `protobuf:"bytes,5,opt,name=SignatureOnPid,proto3" json:"SignatureOnPid,omitempty"` -} - -func (m *TopicMessage) Reset() { *m = TopicMessage{} } -func (*TopicMessage) ProtoMessage() {} -func (*TopicMessage) Descriptor() ([]byte, []int) { - return fileDescriptor_131cdede10b420b6, []int{0} -} -func (m *TopicMessage) XXX_Unmarshal(b []byte) error { - return m.Unmarshal(b) -} -func (m *TopicMessage) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { - b = b[:cap(b)] - n, err := m.MarshalToSizedBuffer(b) - if err != nil { - return nil, err - } - return b[:n], nil -} -func (m *TopicMessage) XXX_Merge(src proto.Message) { - xxx_messageInfo_TopicMessage.Merge(m, src) -} -func (m *TopicMessage) XXX_Size() int { - return m.Size() -} -func (m *TopicMessage) XXX_DiscardUnknown() { - xxx_messageInfo_TopicMessage.DiscardUnknown(m) -} - -var xxx_messageInfo_TopicMessage proto.InternalMessageInfo - -func (m *TopicMessage) GetVersion() uint32 { - if m != nil { - return m.Version - } - return 0 -} - -func (m *TopicMessage) GetPayload() []byte { - if m != nil { - return m.Payload - } - return nil -} - -func (m *TopicMessage) GetTimestamp() int64 { - if m != nil { - return m.Timestamp - } - return 0 -} - -func (m *TopicMessage) GetPk() []byte { - if m != nil { - return m.Pk - } - return nil -} - -func (m *TopicMessage) GetSignatureOnPid() []byte { - if m != nil { - return m.SignatureOnPid - } - return nil -} - -func init() { - proto.RegisterType((*TopicMessage)(nil), "proto.TopicMessage") -} - -func init() { proto.RegisterFile("topicMessage.proto", fileDescriptor_131cdede10b420b6) } - -var fileDescriptor_131cdede10b420b6 = []byte{ - // 255 bytes of a gzipped FileDescriptorProto - 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0x12, 0x2a, 0xc9, 0x2f, 0xc8, - 0x4c, 0xf6, 0x4d, 0x2d, 0x2e, 0x4e, 0x4c, 0x4f, 0xd5, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x62, - 0x05, 0x53, 0x52, 0xba, 0xe9, 0x99, 0x25, 0x19, 0xa5, 0x49, 0x7a, 0xc9, 0xf9, 0xb9, 0xfa, 0xe9, - 0xf9, 0xe9, 0xf9, 0xfa, 0x60, 0xe1, 0xa4, 0xd2, 0x34, 0x30, 0x0f, 0xcc, 0x01, 0xb3, 0x20, 0xba, - 0x94, 0x66, 0x30, 0x72, 0xf1, 0x84, 0x20, 0x19, 0x26, 0x24, 0xc1, 0xc5, 0x1e, 0x96, 0x5a, 0x54, - 0x9c, 0x99, 0x9f, 0x27, 0xc1, 0xa8, 0xc0, 0xa8, 0xc1, 0x1b, 0x04, 0xe3, 0x82, 0x64, 0x02, 0x12, - 0x2b, 0x73, 0xf2, 0x13, 0x53, 0x24, 0x98, 0x14, 0x18, 0x35, 0x78, 0x82, 0x60, 0x5c, 0x21, 0x19, - 0x2e, 0xce, 0x90, 0xcc, 0xdc, 0xd4, 0xe2, 0x92, 
0xc4, 0xdc, 0x02, 0x09, 0x66, 0x05, 0x46, 0x0d, - 0xe6, 0x20, 0x84, 0x80, 0x10, 0x1f, 0x17, 0x53, 0x40, 0xb6, 0x04, 0x0b, 0x58, 0x0b, 0x53, 0x40, - 0xb6, 0x90, 0x1a, 0x17, 0x5f, 0x70, 0x66, 0x7a, 0x5e, 0x62, 0x49, 0x69, 0x51, 0xaa, 0x7f, 0x5e, - 0x40, 0x66, 0x8a, 0x04, 0x2b, 0x58, 0x0e, 0x4d, 0xd4, 0xc9, 0xee, 0xc2, 0x43, 0x39, 0x86, 0x1b, - 0x0f, 0xe5, 0x18, 0x3e, 0x3c, 0x94, 0x63, 0x6c, 0x78, 0x24, 0xc7, 0xb8, 0xe2, 0x91, 0x1c, 0xe3, - 0x89, 0x47, 0x72, 0x8c, 0x17, 0x1e, 0xc9, 0x31, 0xde, 0x78, 0x24, 0xc7, 0xf8, 0xe0, 0x91, 0x1c, - 0xe3, 0x8b, 0x47, 0x72, 0x0c, 0x1f, 0x1e, 0xc9, 0x31, 0x4e, 0x78, 0x2c, 0xc7, 0x70, 0xe1, 0xb1, - 0x1c, 0xc3, 0x8d, 0xc7, 0x72, 0x0c, 0x51, 0x2c, 0x29, 0x89, 0x25, 0x89, 0x49, 0x6c, 0x60, 0x1f, - 0x1a, 0x03, 0x02, 0x00, 0x00, 0xff, 0xff, 0xa0, 0x51, 0x72, 0x2b, 0x2d, 0x01, 0x00, 0x00, -} - -func (this *TopicMessage) Equal(that interface{}) bool { - if that == nil { - return this == nil - } - - that1, ok := that.(*TopicMessage) - if !ok { - that2, ok := that.(TopicMessage) - if ok { - that1 = &that2 - } else { - return false - } - } - if that1 == nil { - return this == nil - } else if this == nil { - return false - } - if this.Version != that1.Version { - return false - } - if !bytes.Equal(this.Payload, that1.Payload) { - return false - } - if this.Timestamp != that1.Timestamp { - return false - } - if !bytes.Equal(this.Pk, that1.Pk) { - return false - } - if !bytes.Equal(this.SignatureOnPid, that1.SignatureOnPid) { - return false - } - return true -} -func (this *TopicMessage) GoString() string { - if this == nil { - return "nil" - } - s := make([]string, 0, 9) - s = append(s, "&data.TopicMessage{") - s = append(s, "Version: "+fmt.Sprintf("%#v", this.Version)+",\n") - s = append(s, "Payload: "+fmt.Sprintf("%#v", this.Payload)+",\n") - s = append(s, "Timestamp: "+fmt.Sprintf("%#v", this.Timestamp)+",\n") - s = append(s, "Pk: "+fmt.Sprintf("%#v", this.Pk)+",\n") - s = append(s, "SignatureOnPid: "+fmt.Sprintf("%#v", this.SignatureOnPid)+",\n") - s = append(s, "}") - return strings.Join(s, "") -} -func valueToGoStringTopicMessage(v interface{}, typ string) string { - rv := reflect.ValueOf(v) - if rv.IsNil() { - return "nil" - } - pv := reflect.Indirect(rv).Interface() - return fmt.Sprintf("func(v %v) *%v { return &v } ( %#v )", typ, typ, pv) -} -func (m *TopicMessage) Marshal() (dAtA []byte, err error) { - size := m.Size() - dAtA = make([]byte, size) - n, err := m.MarshalToSizedBuffer(dAtA[:size]) - if err != nil { - return nil, err - } - return dAtA[:n], nil -} - -func (m *TopicMessage) MarshalTo(dAtA []byte) (int, error) { - size := m.Size() - return m.MarshalToSizedBuffer(dAtA[:size]) -} - -func (m *TopicMessage) MarshalToSizedBuffer(dAtA []byte) (int, error) { - i := len(dAtA) - _ = i - var l int - _ = l - if len(m.SignatureOnPid) > 0 { - i -= len(m.SignatureOnPid) - copy(dAtA[i:], m.SignatureOnPid) - i = encodeVarintTopicMessage(dAtA, i, uint64(len(m.SignatureOnPid))) - i-- - dAtA[i] = 0x2a - } - if len(m.Pk) > 0 { - i -= len(m.Pk) - copy(dAtA[i:], m.Pk) - i = encodeVarintTopicMessage(dAtA, i, uint64(len(m.Pk))) - i-- - dAtA[i] = 0x22 - } - if m.Timestamp != 0 { - i = encodeVarintTopicMessage(dAtA, i, uint64(m.Timestamp)) - i-- - dAtA[i] = 0x18 - } - if len(m.Payload) > 0 { - i -= len(m.Payload) - copy(dAtA[i:], m.Payload) - i = encodeVarintTopicMessage(dAtA, i, uint64(len(m.Payload))) - i-- - dAtA[i] = 0x12 - } - if m.Version != 0 { - i = encodeVarintTopicMessage(dAtA, i, uint64(m.Version)) - i-- - dAtA[i] = 0x8 - } - return len(dAtA) - i, nil -} - -func 
encodeVarintTopicMessage(dAtA []byte, offset int, v uint64) int { - offset -= sovTopicMessage(v) - base := offset - for v >= 1<<7 { - dAtA[offset] = uint8(v&0x7f | 0x80) - v >>= 7 - offset++ - } - dAtA[offset] = uint8(v) - return base -} -func (m *TopicMessage) Size() (n int) { - if m == nil { - return 0 - } - var l int - _ = l - if m.Version != 0 { - n += 1 + sovTopicMessage(uint64(m.Version)) - } - l = len(m.Payload) - if l > 0 { - n += 1 + l + sovTopicMessage(uint64(l)) - } - if m.Timestamp != 0 { - n += 1 + sovTopicMessage(uint64(m.Timestamp)) - } - l = len(m.Pk) - if l > 0 { - n += 1 + l + sovTopicMessage(uint64(l)) - } - l = len(m.SignatureOnPid) - if l > 0 { - n += 1 + l + sovTopicMessage(uint64(l)) - } - return n -} - -func sovTopicMessage(x uint64) (n int) { - return (math_bits.Len64(x|1) + 6) / 7 -} -func sozTopicMessage(x uint64) (n int) { - return sovTopicMessage(uint64((x << 1) ^ uint64((int64(x) >> 63)))) -} -func (this *TopicMessage) String() string { - if this == nil { - return "nil" - } - s := strings.Join([]string{`&TopicMessage{`, - `Version:` + fmt.Sprintf("%v", this.Version) + `,`, - `Payload:` + fmt.Sprintf("%v", this.Payload) + `,`, - `Timestamp:` + fmt.Sprintf("%v", this.Timestamp) + `,`, - `Pk:` + fmt.Sprintf("%v", this.Pk) + `,`, - `SignatureOnPid:` + fmt.Sprintf("%v", this.SignatureOnPid) + `,`, - `}`, - }, "") - return s -} -func valueToStringTopicMessage(v interface{}) string { - rv := reflect.ValueOf(v) - if rv.IsNil() { - return "nil" - } - pv := reflect.Indirect(rv).Interface() - return fmt.Sprintf("*%v", pv) -} -func (m *TopicMessage) Unmarshal(dAtA []byte) error { - l := len(dAtA) - iNdEx := 0 - for iNdEx < l { - preIndex := iNdEx - var wire uint64 - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return ErrIntOverflowTopicMessage - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := dAtA[iNdEx] - iNdEx++ - wire |= uint64(b&0x7F) << shift - if b < 0x80 { - break - } - } - fieldNum := int32(wire >> 3) - wireType := int(wire & 0x7) - if wireType == 4 { - return fmt.Errorf("proto: TopicMessage: wiretype end group for non-group") - } - if fieldNum <= 0 { - return fmt.Errorf("proto: TopicMessage: illegal tag %d (wire type %d)", fieldNum, wire) - } - switch fieldNum { - case 1: - if wireType != 0 { - return fmt.Errorf("proto: wrong wireType = %d for field Version", wireType) - } - m.Version = 0 - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return ErrIntOverflowTopicMessage - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := dAtA[iNdEx] - iNdEx++ - m.Version |= uint32(b&0x7F) << shift - if b < 0x80 { - break - } - } - case 2: - if wireType != 2 { - return fmt.Errorf("proto: wrong wireType = %d for field Payload", wireType) - } - var byteLen int - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return ErrIntOverflowTopicMessage - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := dAtA[iNdEx] - iNdEx++ - byteLen |= int(b&0x7F) << shift - if b < 0x80 { - break - } - } - if byteLen < 0 { - return ErrInvalidLengthTopicMessage - } - postIndex := iNdEx + byteLen - if postIndex < 0 { - return ErrInvalidLengthTopicMessage - } - if postIndex > l { - return io.ErrUnexpectedEOF - } - m.Payload = append(m.Payload[:0], dAtA[iNdEx:postIndex]...) 
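`encodeVarintTopicMessage` and `sovTopicMessage` in the generated code being deleted here are the standard protobuf base-128 varint routines: seven payload bits per byte, high bit set while more bytes follow. A minimal forward-writing sketch of the same encoding (the generated marshaller writes backwards into a pre-sized buffer, but the byte layout is identical):

```go
package main

import "fmt"

// putUvarint mirrors the loop in the deleted encodeVarintTopicMessage:
// emit the low 7 bits per byte and set the continuation bit (0x80)
// while more significant bits remain.
func putUvarint(dst []byte, v uint64) int {
	i := 0
	for v >= 1<<7 {
		dst[i] = uint8(v&0x7f | 0x80)
		v >>= 7
		i++
	}
	dst[i] = uint8(v)
	return i + 1
}

func main() {
	buf := make([]byte, 10) // 10 bytes always suffice for a 64-bit varint
	n := putUvarint(buf, 300)
	fmt.Printf("% x\n", buf[:n]) // ac 02 — the same bytes encoding/binary.PutUvarint yields
}
```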
- if m.Payload == nil { - m.Payload = []byte{} - } - iNdEx = postIndex - case 3: - if wireType != 0 { - return fmt.Errorf("proto: wrong wireType = %d for field Timestamp", wireType) - } - m.Timestamp = 0 - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return ErrIntOverflowTopicMessage - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := dAtA[iNdEx] - iNdEx++ - m.Timestamp |= int64(b&0x7F) << shift - if b < 0x80 { - break - } - } - case 4: - if wireType != 2 { - return fmt.Errorf("proto: wrong wireType = %d for field Pk", wireType) - } - var byteLen int - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return ErrIntOverflowTopicMessage - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := dAtA[iNdEx] - iNdEx++ - byteLen |= int(b&0x7F) << shift - if b < 0x80 { - break - } - } - if byteLen < 0 { - return ErrInvalidLengthTopicMessage - } - postIndex := iNdEx + byteLen - if postIndex < 0 { - return ErrInvalidLengthTopicMessage - } - if postIndex > l { - return io.ErrUnexpectedEOF - } - m.Pk = append(m.Pk[:0], dAtA[iNdEx:postIndex]...) - if m.Pk == nil { - m.Pk = []byte{} - } - iNdEx = postIndex - case 5: - if wireType != 2 { - return fmt.Errorf("proto: wrong wireType = %d for field SignatureOnPid", wireType) - } - var byteLen int - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return ErrIntOverflowTopicMessage - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := dAtA[iNdEx] - iNdEx++ - byteLen |= int(b&0x7F) << shift - if b < 0x80 { - break - } - } - if byteLen < 0 { - return ErrInvalidLengthTopicMessage - } - postIndex := iNdEx + byteLen - if postIndex < 0 { - return ErrInvalidLengthTopicMessage - } - if postIndex > l { - return io.ErrUnexpectedEOF - } - m.SignatureOnPid = append(m.SignatureOnPid[:0], dAtA[iNdEx:postIndex]...) 
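(Aside: the Unmarshal loop above is the consuming side of the generated Marshal/MarshalToSizedBuffer pair. Before this PR moved the type out of the repository, callers round-tripped TopicMessage through these generated methods; a sketch of that typical usage against the pre-move import path — the package is deleted by this diff and now lives in the ElrondNetwork/elrond-go-p2p module:)

package main

import (
	"fmt"

	// removed by this PR; shown as it existed before the move
	"github.com/ElrondNetwork/elrond-go/p2p/data"
)

func main() {
	msg := &data.TopicMessage{
		Version:   1,
		Payload:   []byte("hello"),
		Timestamp: 1650000000, // sample value
	}

	buff, err := msg.Marshal() // generated, calls MarshalToSizedBuffer under the hood
	if err != nil {
		panic(err)
	}

	decoded := &data.TopicMessage{}
	if err = decoded.Unmarshal(buff); err != nil {
		panic(err)
	}

	fmt.Printf("payload after round-trip: %s\n", decoded.Payload)
}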
- if m.SignatureOnPid == nil { - m.SignatureOnPid = []byte{} - } - iNdEx = postIndex - default: - iNdEx = preIndex - skippy, err := skipTopicMessage(dAtA[iNdEx:]) - if err != nil { - return err - } - if skippy < 0 { - return ErrInvalidLengthTopicMessage - } - if (iNdEx + skippy) < 0 { - return ErrInvalidLengthTopicMessage - } - if (iNdEx + skippy) > l { - return io.ErrUnexpectedEOF - } - iNdEx += skippy - } - } - - if iNdEx > l { - return io.ErrUnexpectedEOF - } - return nil -} -func skipTopicMessage(dAtA []byte) (n int, err error) { - l := len(dAtA) - iNdEx := 0 - depth := 0 - for iNdEx < l { - var wire uint64 - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return 0, ErrIntOverflowTopicMessage - } - if iNdEx >= l { - return 0, io.ErrUnexpectedEOF - } - b := dAtA[iNdEx] - iNdEx++ - wire |= (uint64(b) & 0x7F) << shift - if b < 0x80 { - break - } - } - wireType := int(wire & 0x7) - switch wireType { - case 0: - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return 0, ErrIntOverflowTopicMessage - } - if iNdEx >= l { - return 0, io.ErrUnexpectedEOF - } - iNdEx++ - if dAtA[iNdEx-1] < 0x80 { - break - } - } - case 1: - iNdEx += 8 - case 2: - var length int - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return 0, ErrIntOverflowTopicMessage - } - if iNdEx >= l { - return 0, io.ErrUnexpectedEOF - } - b := dAtA[iNdEx] - iNdEx++ - length |= (int(b) & 0x7F) << shift - if b < 0x80 { - break - } - } - if length < 0 { - return 0, ErrInvalidLengthTopicMessage - } - iNdEx += length - case 3: - depth++ - case 4: - if depth == 0 { - return 0, ErrUnexpectedEndOfGroupTopicMessage - } - depth-- - case 5: - iNdEx += 4 - default: - return 0, fmt.Errorf("proto: illegal wireType %d", wireType) - } - if iNdEx < 0 { - return 0, ErrInvalidLengthTopicMessage - } - if depth == 0 { - return iNdEx, nil - } - } - return 0, io.ErrUnexpectedEOF -} - -var ( - ErrInvalidLengthTopicMessage = fmt.Errorf("proto: negative length found during unmarshaling") - ErrIntOverflowTopicMessage = fmt.Errorf("proto: integer overflow") - ErrUnexpectedEndOfGroupTopicMessage = fmt.Errorf("proto: unexpected end of group") -) diff --git a/p2p/data/topicMessage.proto b/p2p/data/topicMessage.proto deleted file mode 100644 index be1927476de..00000000000 --- a/p2p/data/topicMessage.proto +++ /dev/null @@ -1,16 +0,0 @@ -syntax = "proto3"; - -package proto; - -option go_package = "data"; -option (gogoproto.stable_marshaler_all) = true; - -import "github.com/gogo/protobuf/gogoproto/gogo.proto"; - -message TopicMessage{ - uint32 Version = 1; - bytes Payload = 2; - int64 Timestamp = 3; - bytes Pk = 4; - bytes SignatureOnPid = 5; -} diff --git a/p2p/errors.go b/p2p/errors.go index fba838283db..fc3e7870b16 100644 --- a/p2p/errors.go +++ b/p2p/errors.go @@ -2,159 +2,15 @@ package p2p import ( "errors" -) - -// ErrNilContext signals that a nil context was provided -var ErrNilContext = errors.New("nil context") - -// ErrNilMockNet signals that a nil mocknet was provided. Should occur only in testing!!! 
-var ErrNilMockNet = errors.New("nil mocknet provided") - -// ErrNilTopic signals that a nil topic has been provided -var ErrNilTopic = errors.New("nil topic") - -// ErrTopicAlreadyExists signals that a topic already exists -var ErrTopicAlreadyExists = errors.New("topic already exists") - -// ErrTopicValidatorOperationNotSupported signals that an unsupported validator operation occurred -var ErrTopicValidatorOperationNotSupported = errors.New("topic validator operation is not supported") - -// ErrChannelDoesNotExist signals that a requested channel does not exist -var ErrChannelDoesNotExist = errors.New("channel does not exist") -// ErrChannelCanNotBeDeleted signals that a channel can not be deleted (might be the default channel) -var ErrChannelCanNotBeDeleted = errors.New("channel can not be deleted") - -// ErrChannelCanNotBeReAdded signals that a channel can not be re added as it is the default channel -var ErrChannelCanNotBeReAdded = errors.New("channel can not be re added") + p2p "github.com/ElrondNetwork/elrond-go-p2p" +) // ErrNilMessage signals that a nil message has been received -var ErrNilMessage = errors.New("nil message") - -// ErrAlreadySeenMessage signals that the message has already been seen -var ErrAlreadySeenMessage = errors.New("already seen this message") - -// ErrMessageTooNew signals that a message has a timestamp that is in the future relative to self -var ErrMessageTooNew = errors.New("message is too new") - -// ErrMessageTooOld signals that a message has a timestamp that is in the past relative to self -var ErrMessageTooOld = errors.New("message is too old") - -// ErrNilDirectSendMessageHandler signals that the message handler for new message has not been wired -var ErrNilDirectSendMessageHandler = errors.New("nil direct sender message handler") - -// ErrPeerNotDirectlyConnected signals that the peer is not directly connected to self -var ErrPeerNotDirectlyConnected = errors.New("peer is not directly connected") - -// ErrNilHost signals that a nil host has been provided -var ErrNilHost = errors.New("nil host") - -// ErrNilValidator signals that a validator hasn't been set for the required topic -var ErrNilValidator = errors.New("no validator has been set for this topic") - -// ErrPeerDiscoveryProcessAlreadyStarted signals that a peer discovery is already turned on -var ErrPeerDiscoveryProcessAlreadyStarted = errors.New("peer discovery is already turned on") - -// ErrMessageTooLarge signals that the message provided is too large -var ErrMessageTooLarge = errors.New("buffer too large") - -// ErrEmptyBufferToSend signals that an empty buffer was provided for sending to other peers -var ErrEmptyBufferToSend = errors.New("empty buffer to send") - -// ErrNilFetchPeersOnTopicHandler signals that a nil handler was provided -var ErrNilFetchPeersOnTopicHandler = errors.New("nil fetch peers on topic handler") - -// ErrInvalidDurationProvided signals that an invalid time.Duration has been provided -var ErrInvalidDurationProvided = errors.New("invalid time.Duration provided") - -// ErrTooManyGoroutines is raised when the number of goroutines has exceeded a threshold -var ErrTooManyGoroutines = errors.New(" number of goroutines exceeded") - -// ErrInvalidValue signals that an invalid value has been provided -var ErrInvalidValue = errors.New("invalid value") - -// ErrInvalidPortValue signals that an invalid port value has been provided -var ErrInvalidPortValue = errors.New("invalid port value") - -// ErrInvalidPortsRangeString signals that an invalid ports range string has been 
provided -var ErrInvalidPortsRangeString = errors.New("invalid ports range string") - -// ErrInvalidStartingPortValue signals that an invalid starting port value has been provided -var ErrInvalidStartingPortValue = errors.New("invalid starting port value") +var ErrNilMessage = p2p.ErrNilMessage -// ErrInvalidEndingPortValue signals that an invalid ending port value has been provided -var ErrInvalidEndingPortValue = errors.New("invalid ending port value") - -// ErrEndPortIsSmallerThanStartPort signals that the ending port value is smaller than the starting port value -var ErrEndPortIsSmallerThanStartPort = errors.New("ending port value is smaller than the starting port value") - -// ErrNoFreePortInRange signals that no free port was found from provided range -var ErrNoFreePortInRange = errors.New("no free port in range") - -// ErrNilSharder signals that the provided sharder is nil -var ErrNilSharder = errors.New("nil sharder") - -// ErrNilPeerShardResolver signals that the peer shard resolver provided is nil -var ErrNilPeerShardResolver = errors.New("nil PeerShardResolver") - -// ErrNilMarshalizer signals that an operation has been attempted to or with a nil marshalizer implementation -var ErrNilMarshalizer = errors.New("nil marshalizer") - -// ErrWrongTypeAssertion signals that a wrong type assertion occurred -var ErrWrongTypeAssertion = errors.New("wrong type assertion") - -// ErrNilReconnecter signals that a nil reconnecter has been provided -var ErrNilReconnecter = errors.New("nil reconnecter") - -// ErrUnwantedPeer signals that the provided peer has a longer kademlia distance in respect with the already connected -// peers and a connection to this peer will result in an immediate disconnection -var ErrUnwantedPeer = errors.New("unwanted peer: will not initiate connection as it will get disconnected") - -// ErrEmptySeed signals that an empty seed has been provided -var ErrEmptySeed = errors.New("empty seed") - -// ErrEmptyBuffer signals that an empty buffer has been provided -var ErrEmptyBuffer = errors.New("empty buffer") - -// ErrNilPeerDenialEvaluator signals that a nil peer denial evaluator was provided -var ErrNilPeerDenialEvaluator = errors.New("nil peer denial evaluator") +// ErrNilPreferredPeersHolder signals that a nil preferred peers holder was provided +var ErrNilPreferredPeersHolder = p2p.ErrNilPreferredPeersHolder // ErrNilStatusHandler signals that a nil status handler has been provided var ErrNilStatusHandler = errors.New("nil status handler") - -// ErrMessageUnmarshalError signals that an invalid message was received from a peer. 
There is no way to communicate -// with such a peer as it does not respect the protocol -var ErrMessageUnmarshalError = errors.New("message unmarshal error") - -// ErrUnsupportedFields signals that unsupported fields are provided -var ErrUnsupportedFields = errors.New("unsupported fields") - -// ErrUnsupportedMessageVersion signals that an unsupported message version was detected -var ErrUnsupportedMessageVersion = errors.New("unsupported message version") - -// ErrNilSyncTimer signals that a nil sync timer was provided -var ErrNilSyncTimer = errors.New("nil sync timer") - -// ErrNilPreferredPeersHolder signals that a nil preferred peers holder was provided -var ErrNilPreferredPeersHolder = errors.New("nil peers holder") - -// ErrInvalidSeedersReconnectionInterval signals that an invalid seeders reconnection interval error occurred -var ErrInvalidSeedersReconnectionInterval = errors.New("invalid seeders reconnection interval") - -// ErrMessageProcessorAlreadyDefined signals that a message processor was already defined on the provided topic and identifier -var ErrMessageProcessorAlreadyDefined = errors.New("message processor already defined") - -// ErrMessageProcessorDoesNotExists signals that a message processor does not exist on the provided topic and identifier -var ErrMessageProcessorDoesNotExists = errors.New("message processor does not exists") - -// ErrWrongTypeAssertions signals that a wrong type assertion occurred -var ErrWrongTypeAssertions = errors.New("wrong type assertion") - -// ErrNilConnectionsWatcher signals that a nil connections watcher has been provided -var ErrNilConnectionsWatcher = errors.New("nil connections watcher") - -// ErrNilPeersRatingHandler signals that a nil peers rating handler has been provided -var ErrNilPeersRatingHandler = errors.New("nil peers rating handler") - -// ErrNilCacher signals that a nil cacher has been provided -var ErrNilCacher = errors.New("nil cacher") diff --git a/p2p/interface.go b/p2p/interface.go new file mode 100644 index 00000000000..a49b2bc787c --- /dev/null +++ b/p2p/interface.go @@ -0,0 +1,103 @@ +package p2p + +import ( + "encoding/hex" + "time" + + "github.com/ElrondNetwork/elrond-go-core/core" + p2p "github.com/ElrondNetwork/elrond-go-p2p" +) + +// MessageProcessor is the interface used to describe what a receive message processor should do +// All implementations that will be called from Messenger implementation will need to satisfy this interface +// If the function returns a non nil value, the received message will not be propagated to its connected peers +type MessageProcessor = p2p.MessageProcessor + +// SendableData represents the struct used in data throttler implementation +type SendableData struct { + Buff []byte + Topic string +} + +// Messenger is the main struct used for communication with other peers +type Messenger = p2p.Messenger + +// MessageP2P defines what a p2p message can do (should return) +type MessageP2P = p2p.MessageP2P + +// ChannelLoadBalancer defines what a load balancer that uses chans should do +type ChannelLoadBalancer interface { + AddChannel(channel string) error + RemoveChannel(channel string) error + GetChannelOrDefault(channel string) chan *SendableData + CollectOneElementFromChannels() *SendableData + Close() error + IsInterfaceNil() bool +} + +// MessageOriginatorPid will output the message peer id in a pretty format +// If it can, it will display the last displayLastPidChars (12) characters from the pid +func MessageOriginatorPid(msg MessageP2P) string { + return 
PeerIdToShortString(msg.Peer()) +} + +// PeerIdToShortString trims the first displayLastPidChars characters of the provided peer ID after +// converting the peer ID to string using the Pretty functionality +func PeerIdToShortString(pid core.PeerID) string { + return p2p.PeerIdToShortString(pid) +} + +// MessageOriginatorSeq will output the sequence number as hex +func MessageOriginatorSeq(msg MessageP2P) string { + return hex.EncodeToString(msg.SeqNo()) +} + +// PeerShardResolver is able to resolve the link between the provided PeerID and the shardID +type PeerShardResolver = p2p.PeerShardResolver + +// ConnectedPeersInfo represents the DTO structure used to output the metrics for connected peers +type ConnectedPeersInfo = p2p.ConnectedPeersInfo + +// NetworkShardingCollector defines the updating methods used by the network sharding component +// The interface assures that the collected data will be used by the p2p network sharding components +type NetworkShardingCollector interface { + UpdatePeerIDInfo(pid core.PeerID, pk []byte, shardID uint32) + IsInterfaceNil() bool +} + +// PreferredPeersHolderHandler defines the behavior of a component able to handle preferred peers operations +type PreferredPeersHolderHandler interface { + PutConnectionAddress(peerID core.PeerID, address string) + PutShardID(peerID core.PeerID, shardID uint32) + Get() map[uint32][]core.PeerID + Contains(peerID core.PeerID) bool + Remove(peerID core.PeerID) + Clear() + IsInterfaceNil() bool +} + +// PeerDenialEvaluator defines the behavior of a component that is able to decide if a peer ID is black listed or not +// TODO merge this interface with the PeerShardResolver => P2PProtocolHandler ? +// TODO move antiflooding inside network messenger +type PeerDenialEvaluator = p2p.PeerDenialEvaluator + +// SyncTimer represent an entity able to tell the current time +type SyncTimer interface { + CurrentTime() time.Time + IsInterfaceNil() bool +} + +// PeersRatingHandler represent an entity able to handle peers ratings +type PeersRatingHandler interface { + AddPeer(pid core.PeerID) + IncreaseRating(pid core.PeerID) + DecreaseRating(pid core.PeerID) + GetTopRatedPeersFromList(peers []core.PeerID, minNumOfPeersExpected int) []core.PeerID + IsInterfaceNil() bool +} + +// RandomP2PIdentityGenerator defines an entity that is able to generate a random p2p identity +type RandomP2PIdentityGenerator interface { + CreateRandomP2PIdentity() ([]byte, core.PeerID, error) + IsInterfaceNil() bool +} diff --git a/p2p/issues/pubsub_349/main.go b/p2p/issues/pubsub_349/main.go deleted file mode 100644 index 1644011fbae..00000000000 --- a/p2p/issues/pubsub_349/main.go +++ /dev/null @@ -1,157 +0,0 @@ -package main - -import ( - "context" - "crypto/ecdsa" - "crypto/rand" - "fmt" - "time" - - pubsub "github.com/ElrondNetwork/go-libp2p-pubsub" - "github.com/btcsuite/btcd/btcec" - "github.com/libp2p/go-libp2p" - libp2pCrypto "github.com/libp2p/go-libp2p-core/crypto" - "github.com/libp2p/go-libp2p-core/host" - "github.com/libp2p/go-libp2p-core/peer" -) - -type messenger struct { - host host.Host - pb *pubsub.PubSub - topic *pubsub.Topic - subscr *pubsub.Subscription -} - -func newMessenger() *messenger { - address := "/ip4/0.0.0.0/tcp/0" - opts := []libp2p.Option{ - libp2p.ListenAddrStrings(address), - libp2p.Identity(createP2PPrivKey()), - libp2p.DefaultMuxers, - libp2p.DefaultSecurity, - libp2p.DefaultTransports, - // we need the disable relay option in order to save the node's bandwidth as much as possible - libp2p.DisableRelay(), - 
libp2p.NATPortMap(), - } - - h, _ := libp2p.New(opts...) - optsPS := make([]pubsub.Option, 0) - pb, _ := pubsub.NewGossipSub(context.Background(), h, optsPS...) - - return &messenger{ - host: h, - pb: pb, - } -} - -func createP2PPrivKey() *libp2pCrypto.Secp256k1PrivateKey { - prvKey, _ := ecdsa.GenerateKey(btcec.S256(), rand.Reader) - return (*libp2pCrypto.Secp256k1PrivateKey)(prvKey) -} - -func (m *messenger) connectTo(target *messenger) { - addr := peer.AddrInfo{ - ID: target.host.ID(), - Addrs: target.host.Addrs(), - } - - err := m.host.Connect(context.Background(), addr) - if err != nil { - fmt.Println("error connecting to peer: " + err.Error()) - } -} - -func (m *messenger) joinTopic(topic string) { - m.topic, _ = m.pb.Join(topic) - m.subscr, _ = m.topic.Subscribe() - - go func() { - for { - msg, err := m.subscr.Next(context.Background()) - if err != nil { - return - } - - fmt.Printf("%s: got message %s\n", m.host.ID().Pretty(), string(msg.Data)) - } - }() - -} - -func main() { - fmt.Println("creating 8 host connected statically...") - peers := create8ConnectedPeers() - - defer func() { - for _, p := range peers { - _ = p.host.Close() - } - }() - - fmt.Println() - - for _, p := range peers { - p.joinTopic("test") - } - - go func() { - time.Sleep(time.Second * 2) - // TODO uncomment these 2 lines to make the pubsub create connections - // peers[3].subscr.Cancel() - // _ = peers[3].topic.Close() - }() - - for i := 0; i < 10; i++ { - printConnections(peers) - fmt.Println() - time.Sleep(time.Second) - } -} - -func printConnections(peers []*messenger) { - for _, p := range peers { - fmt.Printf(" %s is connected to %d peers\n", p.host.ID().Pretty(), len(p.host.Network().Peers())) - } -} - -// create8ConnectedPeers assembles a network as following: -// -// 0------------------- 1 -// | | -// 2 ------------------ 3 ------------------ 4 -// | | | -// 5 6 7 -func create8ConnectedPeers() []*messenger { - peers := make([]*messenger, 0) - for i := 0; i < 8; i++ { - p := newMessenger() - fmt.Printf("%d - created peer %s\n", i, p.host.ID().Pretty()) - - peers = append(peers, p) - } - - connections := map[int][]int{ - 0: {1, 3}, - 1: {4}, - 2: {5, 3}, - 3: {4, 6}, - 4: {7}, - } - - createConnections(peers, connections) - - return peers -} - -func createConnections(peers []*messenger, connections map[int][]int) { - for pid, connectTo := range connections { - connectPeerToOthers(peers, pid, connectTo) - } -} - -func connectPeerToOthers(peers []*messenger, idx int, connectToIdxes []int) { - for _, connectToIdx := range connectToIdxes { - peers[idx].connectTo(peers[connectToIdx]) - } -} diff --git a/p2p/libp2p/connectableHost.go b/p2p/libp2p/connectableHost.go deleted file mode 100644 index 2b59ba6aee6..00000000000 --- a/p2p/libp2p/connectableHost.go +++ /dev/null @@ -1,56 +0,0 @@ -package libp2p - -import ( - "context" - - "github.com/libp2p/go-libp2p-core/host" - "github.com/libp2p/go-libp2p-core/peer" - "github.com/multiformats/go-multiaddr" -) - -// PeerInfoHandler is the signature of the handler that gets called whenever an action for a peerInfo is triggered -type PeerInfoHandler func(pInfo peer.AddrInfo) - -// ConnectableHost is an enhanced Host interface that has the ability to connect to a string address -type ConnectableHost interface { - host.Host - ConnectToPeer(ctx context.Context, address string) error - AddressToPeerInfo(address string) (*peer.AddrInfo, error) - IsInterfaceNil() bool -} - -type connectableHost struct { - host.Host -} - -// NewConnectableHost creates a new connectable host 
implementation -func NewConnectableHost(h host.Host) *connectableHost { - return &connectableHost{ - Host: h, - } -} - -// AddressToPeerInfo converts the unified string address into libp2p address components: PeerID and multi-address slice -func (connHost *connectableHost) AddressToPeerInfo(address string) (*peer.AddrInfo, error) { - multiAddr, err := multiaddr.NewMultiaddr(address) - if err != nil { - return nil, err - } - - return peer.AddrInfoFromP2pAddr(multiAddr) -} - -// ConnectToPeer connects to a peer by knowing its string address -func (connHost *connectableHost) ConnectToPeer(ctx context.Context, address string) error { - pInfo, err := connHost.AddressToPeerInfo(address) - if err != nil { - return err - } - - return connHost.Connect(ctx, *pInfo) -} - -// IsInterfaceNil returns true if there is no value under the interface -func (connHost *connectableHost) IsInterfaceNil() bool { - return connHost == nil -} diff --git a/p2p/libp2p/connectableHost_test.go b/p2p/libp2p/connectableHost_test.go deleted file mode 100644 index 629a2b5c76b..00000000000 --- a/p2p/libp2p/connectableHost_test.go +++ /dev/null @@ -1,39 +0,0 @@ -package libp2p - -import ( - "context" - "testing" - - "github.com/ElrondNetwork/elrond-go/p2p/mock" - "github.com/libp2p/go-libp2p-core/peer" - "github.com/stretchr/testify/assert" -) - -func TestConnectableHost_ConnectToPeerWrongAddressShouldErr(t *testing.T) { - uhs := &mock.ConnectableHostStub{} - // we can safely use an upgraded instead of a real host as to not create another (useless) stub - uh := NewConnectableHost(uhs) - - err := uh.ConnectToPeer(context.Background(), "invalid address") - - assert.NotNil(t, err) -} - -func TestConnectableHost_ConnectToPeerShouldWork(t *testing.T) { - wasCalled := false - - uhs := &mock.ConnectableHostStub{ - ConnectCalled: func(ctx context.Context, pi peer.AddrInfo) error { - wasCalled = true - return nil - }, - } - // we can safely use an upgraded instead of a real host as to not create another (useless) stub - uh := NewConnectableHost(uhs) - - validAddress := "/ip4/82.5.34.12/tcp/23000/p2p/16Uiu2HAkyqtHSEJDkYhVWTtm9j58Mq5xQJgrApBYXMwS6sdamXuE" - err := uh.ConnectToPeer(context.Background(), validAddress) - - assert.Nil(t, err) - assert.True(t, wasCalled) -} diff --git a/p2p/libp2p/connectionMonitor/interface.go b/p2p/libp2p/connectionMonitor/interface.go deleted file mode 100644 index 9ecf7adb0a1..00000000000 --- a/p2p/libp2p/connectionMonitor/interface.go +++ /dev/null @@ -1,15 +0,0 @@ -package connectionMonitor - -import ( - "github.com/ElrondNetwork/elrond-go-core/core" - "github.com/libp2p/go-libp2p-core/peer" -) - -// Sharder defines the eviction computing process of unwanted peers -type Sharder interface { - ComputeEvictionList(pidList []peer.ID) []peer.ID - Has(pid peer.ID, list []peer.ID) bool - SetSeeders(addresses []string) - IsSeeder(pid core.PeerID) bool - IsInterfaceNil() bool -} diff --git a/p2p/libp2p/connectionMonitor/libp2pConnectionMonitorSimple.go b/p2p/libp2p/connectionMonitor/libp2pConnectionMonitorSimple.go deleted file mode 100644 index e67359400fd..00000000000 --- a/p2p/libp2p/connectionMonitor/libp2pConnectionMonitorSimple.go +++ /dev/null @@ -1,171 +0,0 @@ -package connectionMonitor - -import ( - "context" - "time" - - "github.com/ElrondNetwork/elrond-go-core/core" - "github.com/ElrondNetwork/elrond-go-core/core/check" - logger "github.com/ElrondNetwork/elrond-go-logger" - "github.com/ElrondNetwork/elrond-go/p2p" - "github.com/libp2p/go-libp2p-core/network" - "github.com/multiformats/go-multiaddr" 
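(Aside: the connectableHost wrapper deleted above folds multiaddr parsing and dialing into a single ConnectToPeer call. The same two steps expressed directly against the libp2p API look roughly like this — a sketch reusing the sample address from the deleted test and the import paths already seen in this diff:)

package main

import (
	"context"
	"fmt"

	"github.com/libp2p/go-libp2p"
	"github.com/libp2p/go-libp2p-core/peer"
	"github.com/multiformats/go-multiaddr"
)

func main() {
	h, err := libp2p.New() // default options
	if err != nil {
		panic(err)
	}
	defer func() { _ = h.Close() }()

	addr := "/ip4/82.5.34.12/tcp/23000/p2p/16Uiu2HAkyqtHSEJDkYhVWTtm9j58Mq5xQJgrApBYXMwS6sdamXuE"
	multiAddr, err := multiaddr.NewMultiaddr(addr)
	if err != nil {
		panic(err)
	}

	// step 1 of AddressToPeerInfo: split the unified address into PeerID + addresses
	pInfo, err := peer.AddrInfoFromP2pAddr(multiAddr)
	if err != nil {
		panic(err)
	}

	// step 2, the body of ConnectToPeer: dial using the parsed AddrInfo
	err = h.Connect(context.Background(), *pInfo)
	fmt.Println("connect result:", err)
}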
-) - -// DurationBetweenReconnectAttempts is used as to not call reconnecter.ReconnectToNetwork() too often -// when there are a lot of peers disconnecting and reconnection to initial nodes succeeds -var DurationBetweenReconnectAttempts = time.Second * 5 -var log = logger.GetOrCreate("p2p/libp2p/connectionmonitor") - -type libp2pConnectionMonitorSimple struct { - chDoReconnect chan struct{} - reconnecter p2p.Reconnecter - thresholdMinConnectedPeers int - sharder Sharder - preferredPeersHolder p2p.PreferredPeersHolderHandler - cancelFunc context.CancelFunc - connectionsWatcher p2p.ConnectionsWatcher -} - -// ArgsConnectionMonitorSimple is the DTO used in the NewLibp2pConnectionMonitorSimple constructor function -type ArgsConnectionMonitorSimple struct { - Reconnecter p2p.Reconnecter - ThresholdMinConnectedPeers uint32 - Sharder Sharder - PreferredPeersHolder p2p.PreferredPeersHolderHandler - ConnectionsWatcher p2p.ConnectionsWatcher -} - -// NewLibp2pConnectionMonitorSimple creates a new connection monitor (version 2 that is more streamlined and does not care -// about pausing and resuming the discovery process) -func NewLibp2pConnectionMonitorSimple(args ArgsConnectionMonitorSimple) (*libp2pConnectionMonitorSimple, error) { - if check.IfNil(args.Reconnecter) { - return nil, p2p.ErrNilReconnecter - } - if check.IfNil(args.Sharder) { - return nil, p2p.ErrNilSharder - } - if check.IfNil(args.PreferredPeersHolder) { - return nil, p2p.ErrNilPreferredPeersHolder - } - if check.IfNil(args.ConnectionsWatcher) { - return nil, p2p.ErrNilConnectionsWatcher - } - - ctx, cancelFunc := context.WithCancel(context.Background()) - - cm := &libp2pConnectionMonitorSimple{ - reconnecter: args.Reconnecter, - chDoReconnect: make(chan struct{}), - thresholdMinConnectedPeers: int(args.ThresholdMinConnectedPeers), - sharder: args.Sharder, - cancelFunc: cancelFunc, - preferredPeersHolder: args.PreferredPeersHolder, - connectionsWatcher: args.ConnectionsWatcher, - } - - go cm.doReconnection(ctx) - - return cm, nil -} - -// Listen is called when network starts listening on an addr -func (lcms *libp2pConnectionMonitorSimple) Listen(network.Network, multiaddr.Multiaddr) {} - -// ListenClose is called when network stops listening on an addr -func (lcms *libp2pConnectionMonitorSimple) ListenClose(network.Network, multiaddr.Multiaddr) {} - -// Request a reconnect to initial list -func (lcms *libp2pConnectionMonitorSimple) doReconn() { - select { - case lcms.chDoReconnect <- struct{}{}: - default: - } -} - -// Connected is called when a connection opened -func (lcms *libp2pConnectionMonitorSimple) Connected(netw network.Network, conn network.Conn) { - allPeers := netw.Peers() - - peerId := core.PeerID(conn.RemotePeer()) - connectionStr := conn.RemoteMultiaddr().String() - lcms.connectionsWatcher.NewKnownConnection(peerId, connectionStr) - lcms.preferredPeersHolder.PutConnectionAddress(peerId, connectionStr) - - evicted := lcms.sharder.ComputeEvictionList(allPeers) - for _, pid := range evicted { - _ = netw.ClosePeer(pid) - } -} - -// Disconnected is called when a connection closed -func (lcms *libp2pConnectionMonitorSimple) Disconnected(netw network.Network, conn network.Conn) { - if conn != nil { - lcms.preferredPeersHolder.Remove(core.PeerID(conn.ID())) - } - - lcms.doReconnectionIfNeeded(netw) -} - -func (lcms *libp2pConnectionMonitorSimple) doReconnectionIfNeeded(netw network.Network) { - if !lcms.IsConnectedToTheNetwork(netw) { - lcms.doReconn() - } -} - -// OpenedStream is called when a stream opened -func (lcms 
*libp2pConnectionMonitorSimple) OpenedStream(network.Network, network.Stream) {} - -// ClosedStream is called when a stream closed -func (lcms *libp2pConnectionMonitorSimple) ClosedStream(network.Network, network.Stream) {} - -func (lcms *libp2pConnectionMonitorSimple) doReconnection(ctx context.Context) { - defer func() { - log.Debug("closing the connection monitor main loop") - }() - - for { - select { - case <-lcms.chDoReconnect: - case <-ctx.Done(): - return - } - lcms.reconnecter.ReconnectToNetwork(ctx) - - select { - case <-time.After(DurationBetweenReconnectAttempts): - case <-ctx.Done(): - return - } - } -} - -// IsConnectedToTheNetwork returns true if the number of connected peer is at least equal with thresholdMinConnectedPeers -func (lcms *libp2pConnectionMonitorSimple) IsConnectedToTheNetwork(netw network.Network) bool { - return len(netw.Peers()) >= lcms.thresholdMinConnectedPeers -} - -// SetThresholdMinConnectedPeers sets the minimum connected peers number when the node is considered connected on the network -func (lcms *libp2pConnectionMonitorSimple) SetThresholdMinConnectedPeers(thresholdMinConnectedPeers int, netw network.Network) { - if check.IfNilReflect(netw) { - return - } - lcms.thresholdMinConnectedPeers = thresholdMinConnectedPeers - lcms.doReconnectionIfNeeded(netw) -} - -// ThresholdMinConnectedPeers returns the minimum connected peers number when the node is considered connected on the network -func (lcms *libp2pConnectionMonitorSimple) ThresholdMinConnectedPeers() int { - return lcms.thresholdMinConnectedPeers -} - -// Close closes all underlying components -func (lcms *libp2pConnectionMonitorSimple) Close() error { - lcms.cancelFunc() - return nil -} - -// IsInterfaceNil returns true if there is no value under the interface -func (lcms *libp2pConnectionMonitorSimple) IsInterfaceNil() bool { - return lcms == nil -} diff --git a/p2p/libp2p/connectionMonitor/libp2pConnectionMonitorSimple_test.go b/p2p/libp2p/connectionMonitor/libp2pConnectionMonitorSimple_test.go deleted file mode 100644 index 74183699c1e..00000000000 --- a/p2p/libp2p/connectionMonitor/libp2pConnectionMonitorSimple_test.go +++ /dev/null @@ -1,257 +0,0 @@ -package connectionMonitor - -import ( - "context" - "testing" - "time" - - "github.com/ElrondNetwork/elrond-go-core/core" - "github.com/ElrondNetwork/elrond-go-core/core/check" - "github.com/ElrondNetwork/elrond-go/p2p" - "github.com/ElrondNetwork/elrond-go/p2p/mock" - "github.com/ElrondNetwork/elrond-go/testscommon/p2pmocks" - "github.com/libp2p/go-libp2p-core/peer" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -const durationTimeoutWaiting = time.Second * 2 -const durationStartGoRoutine = time.Second - -func createMockArgsConnectionMonitorSimple() ArgsConnectionMonitorSimple { - return ArgsConnectionMonitorSimple{ - Reconnecter: &mock.ReconnecterStub{}, - ThresholdMinConnectedPeers: 3, - Sharder: &mock.KadSharderStub{}, - PreferredPeersHolder: &p2pmocks.PeersHolderStub{}, - ConnectionsWatcher: &mock.ConnectionsWatcherStub{}, - } -} - -func TestNewLibp2pConnectionMonitorSimple(t *testing.T) { - t.Parallel() - - t.Run("nil reconnecter should error", func(t *testing.T) { - t.Parallel() - - args := createMockArgsConnectionMonitorSimple() - args.Reconnecter = nil - lcms, err := NewLibp2pConnectionMonitorSimple(args) - - assert.Equal(t, p2p.ErrNilReconnecter, err) - assert.True(t, check.IfNil(lcms)) - }) - t.Run("nil sharder should error", func(t *testing.T) { - t.Parallel() - - args := 
createMockArgsConnectionMonitorSimple() - args.Sharder = nil - lcms, err := NewLibp2pConnectionMonitorSimple(args) - - assert.Equal(t, p2p.ErrNilSharder, err) - assert.True(t, check.IfNil(lcms)) - }) - t.Run("nil preferred peers holder should error", func(t *testing.T) { - t.Parallel() - - args := createMockArgsConnectionMonitorSimple() - args.PreferredPeersHolder = nil - lcms, err := NewLibp2pConnectionMonitorSimple(args) - - assert.Equal(t, p2p.ErrNilPreferredPeersHolder, err) - assert.True(t, check.IfNil(lcms)) - }) - t.Run("nil connections watcher should error", func(t *testing.T) { - t.Parallel() - - args := createMockArgsConnectionMonitorSimple() - args.ConnectionsWatcher = nil - lcms, err := NewLibp2pConnectionMonitorSimple(args) - - assert.Equal(t, p2p.ErrNilConnectionsWatcher, err) - assert.True(t, check.IfNil(lcms)) - }) - t.Run("should work", func(t *testing.T) { - t.Parallel() - - args := createMockArgsConnectionMonitorSimple() - lcms, err := NewLibp2pConnectionMonitorSimple(args) - - assert.Nil(t, err) - assert.False(t, check.IfNil(lcms)) - }) -} - -func TestNewLibp2pConnectionMonitorSimple_OnDisconnectedUnderThresholdShouldCallReconnect(t *testing.T) { - t.Parallel() - - chReconnectCalled := make(chan struct{}, 1) - - rs := &mock.ReconnecterStub{ - ReconnectToNetworkCalled: func(ctx context.Context) { - chReconnectCalled <- struct{}{} - }, - } - - ns := mock.NetworkStub{ - PeersCall: func() []peer.ID { - // only one connection which is under the threshold - return []peer.ID{"mock"} - }, - } - - args := createMockArgsConnectionMonitorSimple() - args.Reconnecter = rs - lcms, _ := NewLibp2pConnectionMonitorSimple(args) - time.Sleep(durationStartGoRoutine) - lcms.Disconnected(&ns, nil) - - select { - case <-chReconnectCalled: - case <-time.After(durationTimeoutWaiting): - assert.Fail(t, "timeout waiting to call reconnect") - } -} - -func TestLibp2pConnectionMonitorSimple_ConnectedWithSharderShouldCallEvictAndClosePeer(t *testing.T) { - t.Parallel() - - evictedPid := []peer.ID{"evicted"} - numComputeWasCalled := 0 - numClosedWasCalled := 0 - args := createMockArgsConnectionMonitorSimple() - args.Sharder = &mock.KadSharderStub{ - ComputeEvictListCalled: func(pidList []peer.ID) []peer.ID { - numComputeWasCalled++ - return evictedPid - }, - } - knownConnectionCalled := false - args.ConnectionsWatcher = &mock.ConnectionsWatcherStub{ - NewKnownConnectionCalled: func(pid core.PeerID, connection string) { - knownConnectionCalled = true - }, - } - putConnectionAddressCalled := false - args.PreferredPeersHolder = &p2pmocks.PeersHolderStub{ - PutConnectionAddressCalled: func(peerID core.PeerID, addressSlice string) { - putConnectionAddressCalled = true - }, - } - lcms, _ := NewLibp2pConnectionMonitorSimple(args) - - lcms.Connected( - &mock.NetworkStub{ - ClosePeerCall: func(id peer.ID) error { - numClosedWasCalled++ - return nil - }, - PeersCall: func() []peer.ID { - return nil - }, - }, - &mock.ConnStub{ - RemotePeerCalled: func() peer.ID { - return evictedPid[0] - }, - }, - ) - - assert.Equal(t, 1, numClosedWasCalled) - assert.Equal(t, 1, numComputeWasCalled) - assert.True(t, knownConnectionCalled) - assert.True(t, putConnectionAddressCalled) -} - -func TestNewLibp2pConnectionMonitorSimple_DisconnectedShouldRemovePeerFromPreferredPeers(t *testing.T) { - t.Parallel() - - prefPeerID := "preferred peer 0" - chRemoveCalled := make(chan struct{}, 1) - - ns := mock.NetworkStub{ - PeersCall: func() []peer.ID { - // only one connection which is under the threshold - return []peer.ID{"mock"} - }, 
- } - - removeCalled := false - prefPeersHolder := &p2pmocks.PeersHolderStub{ - RemoveCalled: func(peerID core.PeerID) { - removeCalled = true - require.Equal(t, core.PeerID(prefPeerID), peerID) - chRemoveCalled <- struct{}{} - }, - } - - args := createMockArgsConnectionMonitorSimple() - args.PreferredPeersHolder = prefPeersHolder - lcms, _ := NewLibp2pConnectionMonitorSimple(args) - lcms.Disconnected(&ns, &mock.ConnStub{ - IDCalled: func() string { - return prefPeerID - }, - }) - - require.True(t, removeCalled) - select { - case <-chRemoveCalled: - case <-time.After(durationTimeoutWaiting): - assert.Fail(t, "timeout waiting to call reconnect") - } -} - -func TestLibp2pConnectionMonitorSimple_EmptyFuncsShouldNotPanic(t *testing.T) { - t.Parallel() - - defer func() { - r := recover() - if r != nil { - assert.Fail(t, "should not have panic") - } - }() - - netw := &mock.NetworkStub{ - PeersCall: func() []peer.ID { - return make([]peer.ID, 0) - }, - } - - args := createMockArgsConnectionMonitorSimple() - lcms, _ := NewLibp2pConnectionMonitorSimple(args) - - lcms.ClosedStream(netw, nil) - lcms.Disconnected(netw, nil) - lcms.Listen(netw, nil) - lcms.ListenClose(netw, nil) - lcms.OpenedStream(netw, nil) -} - -func TestLibp2pConnectionMonitorSimple_SetThresholdMinConnectedPeers(t *testing.T) { - t.Parallel() - - args := createMockArgsConnectionMonitorSimple() - lcms, _ := NewLibp2pConnectionMonitorSimple(args) - - thr := 10 - lcms.SetThresholdMinConnectedPeers(thr, &mock.NetworkStub{}) - thrSet := lcms.ThresholdMinConnectedPeers() - - assert.Equal(t, thr, thrSet) -} - -func TestLibp2pConnectionMonitorSimple_SetThresholdMinConnectedPeersNilNetwShouldDoNothing(t *testing.T) { - t.Parallel() - - minConnPeers := uint32(3) - args := createMockArgsConnectionMonitorSimple() - args.ThresholdMinConnectedPeers = minConnPeers - lcms, _ := NewLibp2pConnectionMonitorSimple(args) - - thr := 10 - lcms.SetThresholdMinConnectedPeers(thr, nil) - thrSet := lcms.ThresholdMinConnectedPeers() - - assert.Equal(t, uint32(thrSet), minConnPeers) -} diff --git a/p2p/libp2p/connectionMonitorWrapper.go b/p2p/libp2p/connectionMonitorWrapper.go deleted file mode 100644 index dcb67630fff..00000000000 --- a/p2p/libp2p/connectionMonitorWrapper.go +++ /dev/null @@ -1,122 +0,0 @@ -package libp2p - -import ( - "sync" - - "github.com/ElrondNetwork/elrond-go-core/core" - "github.com/ElrondNetwork/elrond-go-core/core/check" - "github.com/ElrondNetwork/elrond-go/p2p" - "github.com/libp2p/go-libp2p-core/network" - "github.com/multiformats/go-multiaddr" -) - -var _ ConnectionMonitor = (*connectionMonitorWrapper)(nil) - -// connectionMonitorWrapper is a wrapper over p2p.ConnectionMonitor that satisfies the Notifiee interface -// and is able to be notified by the current running host (connection status changes) -// it handles black list peers -type connectionMonitorWrapper struct { - ConnectionMonitor - network network.Network - mutPeerBlackList sync.RWMutex - peerDenialEvaluator p2p.PeerDenialEvaluator -} - -func newConnectionMonitorWrapper( - network network.Network, - connMonitor ConnectionMonitor, - peerDenialEvaluator p2p.PeerDenialEvaluator, -) *connectionMonitorWrapper { - return &connectionMonitorWrapper{ - ConnectionMonitor: connMonitor, - network: network, - peerDenialEvaluator: peerDenialEvaluator, - } -} - -// Listen is called when network starts listening on an addr -func (cmw *connectionMonitorWrapper) Listen(netw network.Network, ma multiaddr.Multiaddr) { - cmw.ConnectionMonitor.Listen(netw, ma) -} - -// ListenClose is called 
when network stops listening on an addr -func (cmw *connectionMonitorWrapper) ListenClose(netw network.Network, ma multiaddr.Multiaddr) { - cmw.ConnectionMonitor.ListenClose(netw, ma) -} - -// Connected is called when a connection opened -func (cmw *connectionMonitorWrapper) Connected(netw network.Network, conn network.Conn) { - cmw.mutPeerBlackList.RLock() - peerBlackList := cmw.peerDenialEvaluator - cmw.mutPeerBlackList.RUnlock() - - pid := conn.RemotePeer() - if peerBlackList.IsDenied(core.PeerID(pid)) { - log.Trace("dropping connection to blacklisted peer", - "pid", pid.Pretty(), - ) - _ = conn.Close() - - return - } - - cmw.ConnectionMonitor.Connected(netw, conn) -} - -// Disconnected is called when a connection closed -func (cmw *connectionMonitorWrapper) Disconnected(netw network.Network, conn network.Conn) { - cmw.ConnectionMonitor.Disconnected(netw, conn) -} - -// OpenedStream is called when a stream opened -func (cmw *connectionMonitorWrapper) OpenedStream(netw network.Network, stream network.Stream) { - cmw.ConnectionMonitor.OpenedStream(netw, stream) -} - -// ClosedStream is called when a stream closed -func (cmw *connectionMonitorWrapper) ClosedStream(netw network.Network, stream network.Stream) { - cmw.ConnectionMonitor.ClosedStream(netw, stream) -} - -// CheckConnectionsBlocking does a peer sweep, calling Close on those peers that are black listed -func (cmw *connectionMonitorWrapper) CheckConnectionsBlocking() { - peers := cmw.network.Peers() - cmw.mutPeerBlackList.RLock() - peerDenialEvaluator := cmw.peerDenialEvaluator - cmw.mutPeerBlackList.RUnlock() - - for _, pid := range peers { - if peerDenialEvaluator.IsDenied(core.PeerID(pid)) { - log.Trace("dropping connection to blacklisted peer", - "pid", pid.Pretty(), - ) - _ = cmw.network.ClosePeer(pid) - } - } -} - -// SetPeerDenialEvaluator sets the handler that is able to tell if a peer can connect to self or not (is or not blacklisted) -func (cmw *connectionMonitorWrapper) SetPeerDenialEvaluator(handler p2p.PeerDenialEvaluator) error { - if check.IfNil(handler) { - return p2p.ErrNilPeerDenialEvaluator - } - - cmw.mutPeerBlackList.Lock() - cmw.peerDenialEvaluator = handler - cmw.mutPeerBlackList.Unlock() - - return nil -} - -// PeerDenialEvaluator gets the peer denial evauator -func (cmw *connectionMonitorWrapper) PeerDenialEvaluator() p2p.PeerDenialEvaluator { - cmw.mutPeerBlackList.RLock() - defer cmw.mutPeerBlackList.RUnlock() - - return cmw.peerDenialEvaluator -} - -// IsInterfaceNil returns true if there is no value under the interface -func (cmw *connectionMonitorWrapper) IsInterfaceNil() bool { - return cmw == nil -} diff --git a/p2p/libp2p/connectionMonitorWrapper_test.go b/p2p/libp2p/connectionMonitorWrapper_test.go deleted file mode 100644 index 5ac9a0f4e07..00000000000 --- a/p2p/libp2p/connectionMonitorWrapper_test.go +++ /dev/null @@ -1,200 +0,0 @@ -package libp2p - -import ( - "bytes" - "testing" - - "github.com/ElrondNetwork/elrond-go-core/core" - "github.com/ElrondNetwork/elrond-go-core/core/check" - "github.com/ElrondNetwork/elrond-go/p2p" - "github.com/ElrondNetwork/elrond-go/p2p/mock" - "github.com/libp2p/go-libp2p-core/network" - "github.com/libp2p/go-libp2p-core/peer" - "github.com/multiformats/go-multiaddr" - "github.com/stretchr/testify/assert" -) - -func createStubConn() *mock.ConnStub { - return &mock.ConnStub{ - RemotePeerCalled: func() peer.ID { - return "remote peer" - }, - } -} - -func TestNewConnectionMonitorWrapper_ShouldWork(t *testing.T) { - t.Parallel() - - cmw := 
newConnectionMonitorWrapper( - &mock.NetworkStub{}, - &mock.ConnectionMonitorStub{}, - &mock.PeerDenialEvaluatorStub{}, - ) - - assert.False(t, check.IfNil(cmw)) -} - -// ------- Connected - -func TestConnectionMonitorNotifier_ConnectedBlackListedShouldCallClose(t *testing.T) { - t.Parallel() - - peerCloseCalled := false - conn := createStubConn() - conn.CloseCalled = func() error { - peerCloseCalled = true - - return nil - } - cmw := newConnectionMonitorWrapper( - &mock.NetworkStub{}, - &mock.ConnectionMonitorStub{}, - &mock.PeerDenialEvaluatorStub{ - IsDeniedCalled: func(pid core.PeerID) bool { - return true - }, - }, - ) - - cmw.Connected(cmw.network, conn) - - assert.True(t, peerCloseCalled) -} - -func TestConnectionMonitorNotifier_ConnectedNotBlackListedShouldCallConnected(t *testing.T) { - t.Parallel() - - peerConnectedCalled := false - conn := createStubConn() - cmw := newConnectionMonitorWrapper( - &mock.NetworkStub{}, - &mock.ConnectionMonitorStub{ - ConnectedCalled: func(netw network.Network, conn network.Conn) { - peerConnectedCalled = true - }, - }, - &mock.PeerDenialEvaluatorStub{ - IsDeniedCalled: func(pid core.PeerID) bool { - return false - }, - }, - ) - - cmw.Connected(cmw.network, conn) - - assert.True(t, peerConnectedCalled) -} - -// ------- Functions - -func TestConnectionMonitorNotifier_FunctionsShouldCallHandler(t *testing.T) { - t.Parallel() - - listenCalled := false - listenCloseCalled := false - disconnectCalled := false - openedCalled := false - closedCalled := false - cmw := newConnectionMonitorWrapper( - &mock.NetworkStub{}, - &mock.ConnectionMonitorStub{ - ListenCalled: func(network.Network, multiaddr.Multiaddr) { - listenCalled = true - }, - ListenCloseCalled: func(network.Network, multiaddr.Multiaddr) { - listenCloseCalled = true - }, - DisconnectedCalled: func(network.Network, network.Conn) { - disconnectCalled = true - }, - OpenedStreamCalled: func(network.Network, network.Stream) { - openedCalled = true - }, - ClosedStreamCalled: func(network.Network, network.Stream) { - closedCalled = true - }, - }, - &mock.PeerDenialEvaluatorStub{}, - ) - - cmw.Listen(nil, nil) - cmw.ListenClose(nil, nil) - cmw.Disconnected(nil, nil) - cmw.OpenedStream(nil, nil) - cmw.ClosedStream(nil, nil) - - assert.True(t, listenCalled) - assert.True(t, listenCloseCalled) - assert.True(t, disconnectCalled) - assert.True(t, openedCalled) - assert.True(t, closedCalled) -} - -// ------- SetBlackListHandler - -func TestConnectionMonitorWrapper_SetBlackListHandlerNilHandlerShouldErr(t *testing.T) { - t.Parallel() - - cmw := newConnectionMonitorWrapper( - &mock.NetworkStub{}, - &mock.ConnectionMonitorStub{}, - &mock.PeerDenialEvaluatorStub{}, - ) - - err := cmw.SetPeerDenialEvaluator(nil) - - assert.Equal(t, p2p.ErrNilPeerDenialEvaluator, err) -} - -func TestConnectionMonitorWrapper_SetBlackListHandlerShouldWork(t *testing.T) { - t.Parallel() - - cmw := newConnectionMonitorWrapper( - &mock.NetworkStub{}, - &mock.ConnectionMonitorStub{}, - &mock.PeerDenialEvaluatorStub{}, - ) - newPeerDenialEvaluator := &mock.PeerDenialEvaluatorStub{} - - err := cmw.SetPeerDenialEvaluator(newPeerDenialEvaluator) - - assert.Nil(t, err) - // pointer testing - assert.True(t, newPeerDenialEvaluator == cmw.peerDenialEvaluator) - assert.True(t, newPeerDenialEvaluator == cmw.PeerDenialEvaluator()) -} - -// ------- CheckConnectionsBlocking - -func TestConnectionMonitorWrapper_CheckConnectionsBlockingShouldWork(t *testing.T) { - t.Parallel() - - whiteListPeer := peer.ID("whitelisted") - blackListPeer := 
peer.ID("blacklisted") - closeCalled := 0 - cmw := newConnectionMonitorWrapper( - &mock.NetworkStub{ - PeersCall: func() []peer.ID { - return []peer.ID{whiteListPeer, blackListPeer} - }, - ClosePeerCall: func(id peer.ID) error { - if id == blackListPeer { - closeCalled++ - return nil - } - assert.Fail(t, "should have called only the black listed peer ") - - return nil - }, - }, - &mock.ConnectionMonitorStub{}, - &mock.PeerDenialEvaluatorStub{ - IsDeniedCalled: func(pid core.PeerID) bool { - return bytes.Equal(core.PeerID(blackListPeer).Bytes(), pid.Bytes()) - }, - }, - ) - - cmw.CheckConnectionsBlocking() - assert.Equal(t, 1, closeCalled) -} diff --git a/p2p/libp2p/directSender.go b/p2p/libp2p/directSender.go deleted file mode 100644 index 19d6f840a24..00000000000 --- a/p2p/libp2p/directSender.go +++ /dev/null @@ -1,254 +0,0 @@ -package libp2p - -import ( - "bufio" - "bytes" - "context" - "encoding/binary" - "fmt" - "io" - "sync" - "sync/atomic" - "time" - - "github.com/ElrondNetwork/elrond-go-core/core" - "github.com/ElrondNetwork/elrond-go/p2p" - pubsub "github.com/ElrondNetwork/go-libp2p-pubsub" - pubsubPb "github.com/ElrondNetwork/go-libp2p-pubsub/pb" - ggio "github.com/gogo/protobuf/io" - "github.com/libp2p/go-libp2p-core/host" - "github.com/libp2p/go-libp2p-core/network" - "github.com/libp2p/go-libp2p-core/peer" - "github.com/whyrusleeping/timecache" -) - -var _ p2p.DirectSender = (*directSender)(nil) - -const timeSeenMessages = time.Second * 120 -const maxMutexes = 10000 - -type directSender struct { - counter uint64 - ctx context.Context - hostP2P host.Host - messageHandler func(msg *pubsub.Message, fromConnectedPeer core.PeerID) error - mutSeenMessages sync.Mutex - seenMessages *timecache.TimeCache - mutexForPeer *MutexHolder -} - -// NewDirectSender returns a new instance of direct sender object -func NewDirectSender( - ctx context.Context, - h host.Host, - messageHandler func(msg *pubsub.Message, fromConnectedPeer core.PeerID) error, -) (*directSender, error) { - - if h == nil { - return nil, p2p.ErrNilHost - } - if ctx == nil { - return nil, p2p.ErrNilContext - } - if messageHandler == nil { - return nil, p2p.ErrNilDirectSendMessageHandler - } - - mutexForPeer, err := NewMutexHolder(maxMutexes) - if err != nil { - return nil, err - } - - ds := &directSender{ - counter: uint64(time.Now().UnixNano()), - ctx: ctx, - hostP2P: h, - seenMessages: timecache.NewTimeCache(timeSeenMessages), - messageHandler: messageHandler, - mutexForPeer: mutexForPeer, - } - - // wire-up a handler for direct messages - h.SetStreamHandler(DirectSendID, ds.directStreamHandler) - - return ds, nil -} - -func (ds *directSender) directStreamHandler(s network.Stream) { - reader := ggio.NewDelimitedReader(s, maxSendBuffSize) - - go func(r ggio.ReadCloser) { - for { - msg := &pubsubPb.Message{} - - err := reader.ReadMsg(msg) - if err != nil { - // stream has encountered an error, close this go routine - - if err != io.EOF { - _ = s.Reset() - log.Trace("error reading rpc", - "from", s.Conn().RemotePeer(), - "error", err.Error(), - ) - } else { - // Just be nice. They probably won't read this - // but it doesn't hurt to send it. 
- _ = s.Close() - } - return - } - - err = ds.processReceivedDirectMessage(msg, s.Conn().RemotePeer()) - if err != nil { - log.Trace("p2p processReceivedDirectMessage", "error", err.Error()) - } - } - }(reader) -} - -func (ds *directSender) processReceivedDirectMessage(message *pubsubPb.Message, fromConnectedPeer peer.ID) error { - if message == nil { - return p2p.ErrNilMessage - } - if message.Topic == nil { - return p2p.ErrNilTopic - } - if !bytes.Equal(message.GetFrom(), []byte(fromConnectedPeer)) { - return fmt.Errorf("%w mismatch between From and fromConnectedPeer values", p2p.ErrInvalidValue) - } - if ds.checkAndSetSeenMessage(message) { - return p2p.ErrAlreadySeenMessage - } - - pbMessage := &pubsub.Message{ - Message: message, - } - - return ds.messageHandler(pbMessage, core.PeerID(fromConnectedPeer)) -} - -func (ds *directSender) checkAndSetSeenMessage(msg *pubsubPb.Message) bool { - msgId := string(msg.GetFrom()) + string(msg.GetSeqno()) - - ds.mutSeenMessages.Lock() - defer ds.mutSeenMessages.Unlock() - - if ds.seenMessages.Has(msgId) { - return true - } - - ds.seenMessages.Add(msgId) - return false -} - -// NextSeqno returns the next uint64 found in *counter as byte slice -func (ds *directSender) NextSeqno() []byte { - seqno := make([]byte, 8) - newVal := atomic.AddUint64(&ds.counter, 1) - binary.BigEndian.PutUint64(seqno, newVal) - return seqno -} - -// Send will send a direct message to the connected peer -func (ds *directSender) Send(topic string, buff []byte, peer core.PeerID) error { - if len(buff) >= maxSendBuffSize { - return fmt.Errorf("%w, to be sent: %d, maximum: %d", p2p.ErrMessageTooLarge, len(buff), maxSendBuffSize) - } - - mut := ds.mutexForPeer.Get(string(peer)) - mut.Lock() - defer mut.Unlock() - - conn, err := ds.getConnection(peer) - if err != nil { - return err - } - - stream, err := ds.getOrCreateStream(conn) - if err != nil { - return err - } - - msg := ds.createMessage(topic, buff, conn) - - bufw := bufio.NewWriter(stream) - w := ggio.NewDelimitedWriter(bufw) - - err = w.WriteMsg(msg) - if err != nil { - _ = stream.Reset() - _ = stream.Close() - return err - } - - err = bufw.Flush() - if err != nil { - _ = stream.Reset() - _ = stream.Close() - return err - } - - return nil -} - -func (ds *directSender) getConnection(p core.PeerID) (network.Conn, error) { - conns := ds.hostP2P.Network().ConnsToPeer(peer.ID(p)) - if len(conns) == 0 { - return nil, p2p.ErrPeerNotDirectlyConnected - } - - // return the connection that has the highest number of streams - lStreams := 0 - var conn network.Conn - for _, c := range conns { - length := len(c.GetStreams()) - if length >= lStreams { - lStreams = length - conn = c - } - } - - return conn, nil -} - -func (ds *directSender) getOrCreateStream(conn network.Conn) (network.Stream, error) { - streams := conn.GetStreams() - var foundStream network.Stream - for i := 0; i < len(streams); i++ { - isExpectedStream := streams[i].Protocol() == DirectSendID - isSendableStream := streams[i].Stat().Direction == network.DirOutbound - - if isExpectedStream && isSendableStream { - foundStream = streams[i] - break - } - } - - var err error - - if foundStream == nil { - foundStream, err = ds.hostP2P.NewStream(ds.ctx, conn.RemotePeer(), DirectSendID) - if err != nil { - return nil, err - } - } - - return foundStream, nil -} - -func (ds *directSender) createMessage(topic string, buff []byte, conn network.Conn) *pubsubPb.Message { - seqno := ds.NextSeqno() - mes := pubsubPb.Message{} - mes.Data = buff - mes.Topic = &topic - mes.From = 
[]byte(conn.LocalPeer()) - mes.Seqno = seqno - - return &mes -} - -// IsInterfaceNil returns true if there is no value under the interface -func (ds *directSender) IsInterfaceNil() bool { - return ds == nil -} diff --git a/p2p/libp2p/directSender_test.go b/p2p/libp2p/directSender_test.go deleted file mode 100644 index 6b04e6dbc58..00000000000 --- a/p2p/libp2p/directSender_test.go +++ /dev/null @@ -1,569 +0,0 @@ -package libp2p_test - -import ( - "bytes" - "context" - "crypto/ecdsa" - "crypto/rand" - "errors" - "fmt" - "testing" - "time" - - "github.com/ElrondNetwork/elrond-go-core/core" - "github.com/ElrondNetwork/elrond-go-core/core/check" - "github.com/ElrondNetwork/elrond-go/p2p" - "github.com/ElrondNetwork/elrond-go/p2p/libp2p" - "github.com/ElrondNetwork/elrond-go/p2p/mock" - "github.com/ElrondNetwork/go-libp2p-pubsub" - pb "github.com/ElrondNetwork/go-libp2p-pubsub/pb" - "github.com/btcsuite/btcd/btcec" - ggio "github.com/gogo/protobuf/io" - libp2pCrypto "github.com/libp2p/go-libp2p-core/crypto" - "github.com/libp2p/go-libp2p-core/network" - "github.com/libp2p/go-libp2p-core/peer" - "github.com/libp2p/go-libp2p-core/protocol" - "github.com/stretchr/testify/assert" -) - -const timeout = time.Second * 5 -const testMaxSize = 1 << 21 - -var blankMessageHandler = func(msg *pubsub.Message, fromConnectedPeer core.PeerID) error { - return nil -} - -func generateHostStub() *mock.ConnectableHostStub { - return &mock.ConnectableHostStub{ - SetStreamHandlerCalled: func(pid protocol.ID, handler network.StreamHandler) {}, - } -} - -func createConnStub(stream network.Stream, id peer.ID, sk libp2pCrypto.PrivKey, remotePeer peer.ID) *mock.ConnStub { - return &mock.ConnStub{ - GetStreamsCalled: func() []network.Stream { - if stream == nil { - return make([]network.Stream, 0) - } - - return []network.Stream{stream} - }, - LocalPeerCalled: func() peer.ID { - return id - }, - LocalPrivateKeyCalled: func() libp2pCrypto.PrivKey { - return sk - }, - RemotePeerCalled: func() peer.ID { - return remotePeer - }, - } -} - -func createLibP2PCredentialsDirectSender() (peer.ID, libp2pCrypto.PrivKey) { - prvKey, _ := ecdsa.GenerateKey(btcec.S256(), rand.Reader) - sk := (*libp2pCrypto.Secp256k1PrivateKey)(prvKey) - id, _ := peer.IDFromPublicKey(sk.GetPublic()) - - return id, sk -} - -// ------- NewDirectSender - -func TestNewDirectSender_NilContextShouldErr(t *testing.T) { - hs := &mock.ConnectableHostStub{} - - var ctx context.Context = nil - ds, err := libp2p.NewDirectSender(ctx, hs, func(msg *pubsub.Message, fromConnectedPeer core.PeerID) error { - return nil - }) - - assert.True(t, check.IfNil(ds)) - assert.Equal(t, p2p.ErrNilContext, err) -} - -func TestNewDirectSender_NilHostShouldErr(t *testing.T) { - ds, err := libp2p.NewDirectSender(context.Background(), nil, func(msg *pubsub.Message, fromConnectedPeer core.PeerID) error { - return nil - }) - - assert.True(t, check.IfNil(ds)) - assert.Equal(t, p2p.ErrNilHost, err) -} - -func TestNewDirectSender_NilMessageHandlerShouldErr(t *testing.T) { - ds, err := libp2p.NewDirectSender(context.Background(), generateHostStub(), nil) - - assert.True(t, check.IfNil(ds)) - assert.Equal(t, p2p.ErrNilDirectSendMessageHandler, err) -} - -func TestNewDirectSender_OkValsShouldWork(t *testing.T) { - ds, err := libp2p.NewDirectSender(context.Background(), generateHostStub(), func(msg *pubsub.Message, fromConnectedPeer core.PeerID) error { - return nil - }) - - assert.False(t, check.IfNil(ds)) - assert.Nil(t, err) -} - -func 
TestNewDirectSender_OkValsShouldCallSetStreamHandlerWithCorrectValues(t *testing.T) { - var pidCalled protocol.ID - var handlerCalled network.StreamHandler - - hs := &mock.ConnectableHostStub{ - SetStreamHandlerCalled: func(pid protocol.ID, handler network.StreamHandler) { - pidCalled = pid - handlerCalled = handler - }, - } - - _, _ = libp2p.NewDirectSender(context.Background(), hs, func(msg *pubsub.Message, fromConnectedPeer core.PeerID) error { - return nil - }) - - assert.NotNil(t, handlerCalled) - assert.Equal(t, libp2p.DirectSendID, pidCalled) -} - -// ------- ProcessReceivedDirectMessage - -func TestDirectSender_ProcessReceivedDirectMessageNilMessageShouldErr(t *testing.T) { - ds, _ := libp2p.NewDirectSender( - context.Background(), - generateHostStub(), - blankMessageHandler, - ) - - err := ds.ProcessReceivedDirectMessage(nil, "peer id") - - assert.Equal(t, p2p.ErrNilMessage, err) -} - -func TestDirectSender_ProcessReceivedDirectMessageNilTopicIdsShouldErr(t *testing.T) { - ds, _ := libp2p.NewDirectSender( - context.Background(), - generateHostStub(), - blankMessageHandler, - ) - - id, _ := createLibP2PCredentialsDirectSender() - - msg := &pb.Message{} - msg.Data = []byte("data") - msg.Seqno = []byte("111") - msg.From = []byte(id) - msg.Topic = nil - - err := ds.ProcessReceivedDirectMessage(msg, id) - - assert.Equal(t, p2p.ErrNilTopic, err) -} - -func TestDirectSender_ProcessReceivedDirectMessageAlreadySeenMsgShouldErr(t *testing.T) { - ds, _ := libp2p.NewDirectSender( - context.Background(), - generateHostStub(), - blankMessageHandler, - ) - - id, _ := createLibP2PCredentialsDirectSender() - - msg := &pb.Message{} - msg.Data = []byte("data") - msg.Seqno = []byte("111") - msg.From = []byte(id) - topic := "topic" - msg.Topic = &topic - - msgId := string(msg.GetFrom()) + string(msg.GetSeqno()) - ds.SeenMessages().Add(msgId) - - err := ds.ProcessReceivedDirectMessage(msg, id) - - assert.Equal(t, p2p.ErrAlreadySeenMessage, err) -} - -func TestDirectSender_ProcessReceivedDirectMessageShouldWork(t *testing.T) { - ds, _ := libp2p.NewDirectSender( - context.Background(), - generateHostStub(), - blankMessageHandler, - ) - - id, _ := createLibP2PCredentialsDirectSender() - - msg := &pb.Message{} - msg.Data = []byte("data") - msg.Seqno = []byte("111") - msg.From = []byte(id) - topic := "topic" - msg.Topic = &topic - - err := ds.ProcessReceivedDirectMessage(msg, id) - - assert.Nil(t, err) -} - -func TestDirectSender_ProcessReceivedDirectMessageShouldCallMessageHandler(t *testing.T) { - wasCalled := false - - ds, _ := libp2p.NewDirectSender( - context.Background(), - generateHostStub(), - func(msg *pubsub.Message, fromConnectedPeer core.PeerID) error { - wasCalled = true - return nil - }, - ) - - id, _ := createLibP2PCredentialsDirectSender() - - msg := &pb.Message{} - msg.Data = []byte("data") - msg.Seqno = []byte("111") - msg.From = []byte(id) - topic := "topic" - msg.Topic = &topic - - _ = ds.ProcessReceivedDirectMessage(msg, id) - - assert.True(t, wasCalled) -} - -func TestDirectSender_ProcessReceivedDirectMessageShouldReturnHandlersError(t *testing.T) { - checkErr := errors.New("checking error") - - ds, _ := libp2p.NewDirectSender( - context.Background(), - generateHostStub(), - func(msg *pubsub.Message, fromConnectedPeer core.PeerID) error { - return checkErr - }, - ) - - id, _ := createLibP2PCredentialsDirectSender() - - msg := &pb.Message{} - msg.Data = []byte("data") - msg.Seqno = []byte("111") - msg.From = []byte(id) - topic := "topic" - msg.Topic = &topic - - err := 
ds.ProcessReceivedDirectMessage(msg, id) - - assert.Equal(t, checkErr, err) -} - -// ------- SendDirectToConnectedPeer - -func TestDirectSender_SendDirectToConnectedPeerBufferToLargeShouldErr(t *testing.T) { - netw := &mock.NetworkStub{} - - id, sk := createLibP2PCredentialsDirectSender() - remotePeer := peer.ID("remote peer") - - stream := mock.NewStreamMock() - _ = stream.SetProtocol(libp2p.DirectSendID) - - cs := createConnStub(stream, id, sk, remotePeer) - - netw.ConnsToPeerCalled = func(p peer.ID) []network.Conn { - return []network.Conn{cs} - } - - ds, _ := libp2p.NewDirectSender( - context.Background(), - &mock.ConnectableHostStub{ - SetStreamHandlerCalled: func(pid protocol.ID, handler network.StreamHandler) {}, - NetworkCalled: func() network.Network { - return netw - }, - }, - blankMessageHandler, - ) - - messageTooLarge := bytes.Repeat([]byte{65}, libp2p.MaxSendBuffSize) - - err := ds.Send("topic", messageTooLarge, core.PeerID(cs.RemotePeer())) - - assert.True(t, errors.Is(err, p2p.ErrMessageTooLarge)) -} - -func TestDirectSender_SendDirectToConnectedPeerNotConnectedPeerShouldErr(t *testing.T) { - netw := &mock.NetworkStub{ - ConnsToPeerCalled: func(p peer.ID) []network.Conn { - return make([]network.Conn, 0) - }, - } - - ds, _ := libp2p.NewDirectSender( - context.Background(), - &mock.ConnectableHostStub{ - SetStreamHandlerCalled: func(pid protocol.ID, handler network.StreamHandler) {}, - NetworkCalled: func() network.Network { - return netw - }, - }, - blankMessageHandler, - ) - - err := ds.Send("topic", []byte("data"), "not connected peer") - - assert.Equal(t, p2p.ErrPeerNotDirectlyConnected, err) -} - -func TestDirectSender_SendDirectToConnectedPeerNewStreamErrorsShouldErr(t *testing.T) { - t.Parallel() - - netw := &mock.NetworkStub{} - - hs := &mock.ConnectableHostStub{ - SetStreamHandlerCalled: func(pid protocol.ID, handler network.StreamHandler) {}, - NetworkCalled: func() network.Network { - return netw - }, - } - - ds, _ := libp2p.NewDirectSender( - context.Background(), - hs, - blankMessageHandler, - ) - - id, sk := createLibP2PCredentialsDirectSender() - remotePeer := peer.ID("remote peer") - errNewStream := errors.New("new stream error") - - cs := createConnStub(nil, id, sk, remotePeer) - - netw.ConnsToPeerCalled = func(p peer.ID) []network.Conn { - return []network.Conn{cs} - } - - hs.NewStreamCalled = func(ctx context.Context, p peer.ID, pids ...protocol.ID) (network.Stream, error) { - return nil, errNewStream - } - - data := []byte("data") - topic := "topic" - err := ds.Send(topic, data, core.PeerID(cs.RemotePeer())) - - assert.Equal(t, errNewStream, err) -} - -func TestDirectSender_SendDirectToConnectedPeerExistingStreamShouldSendToStream(t *testing.T) { - netw := &mock.NetworkStub{} - - ds, _ := libp2p.NewDirectSender( - context.Background(), - &mock.ConnectableHostStub{ - SetStreamHandlerCalled: func(pid protocol.ID, handler network.StreamHandler) {}, - NetworkCalled: func() network.Network { - return netw - }, - }, - blankMessageHandler, - ) - - id, sk := createLibP2PCredentialsDirectSender() - remotePeer := peer.ID("remote peer") - - stream := mock.NewStreamMock() - err := stream.SetProtocol(libp2p.DirectSendID) - assert.Nil(t, err) - - cs := createConnStub(stream, id, sk, remotePeer) - - netw.ConnsToPeerCalled = func(p peer.ID) []network.Conn { - return []network.Conn{cs} - } - - receivedMsg := &pb.Message{} - chanDone := make(chan bool) - - go func(s network.Stream) { - reader := ggio.NewDelimitedReader(s, testMaxSize) - for { - errRead := 
reader.ReadMsg(receivedMsg) - if errRead != nil { - fmt.Println(errRead.Error()) - return - } - - chanDone <- true - } - }(stream) - - data := []byte("data") - topic := "topic" - err = ds.Send(topic, data, core.PeerID(cs.RemotePeer())) - assert.Nil(t, err) - - select { - case <-chanDone: - case <-time.After(timeout): - assert.Fail(t, "timeout getting data from stream") - return - } - - assert.Nil(t, err) - assert.Equal(t, data, receivedMsg.Data) - assert.Equal(t, topic, *receivedMsg.Topic) -} - -func TestDirectSender_SendDirectToConnectedPeerNewStreamShouldSendToStream(t *testing.T) { - netw := &mock.NetworkStub{} - - hs := &mock.ConnectableHostStub{ - SetStreamHandlerCalled: func(pid protocol.ID, handler network.StreamHandler) {}, - NetworkCalled: func() network.Network { - return netw - }, - } - - ds, _ := libp2p.NewDirectSender( - context.Background(), - hs, - blankMessageHandler, - ) - - id, sk := createLibP2PCredentialsDirectSender() - remotePeer := peer.ID("remote peer") - - stream := mock.NewStreamMock() - _ = stream.SetProtocol(libp2p.DirectSendID) - - cs := createConnStub(stream, id, sk, remotePeer) - - netw.ConnsToPeerCalled = func(p peer.ID) []network.Conn { - return []network.Conn{cs} - } - - hs.NewStreamCalled = func(ctx context.Context, p peer.ID, pids ...protocol.ID) (network.Stream, error) { - if p == remotePeer && pids[0] == libp2p.DirectSendID { - return stream, nil - } - return nil, errors.New("wrong parameters") - } - - receivedMsg := &pb.Message{} - chanDone := make(chan bool) - - go func(s network.Stream) { - reader := ggio.NewDelimitedReader(s, testMaxSize) - for { - err := reader.ReadMsg(receivedMsg) - if err != nil { - fmt.Println(err.Error()) - return - } - - chanDone <- true - } - }(stream) - - data := []byte("data") - topic := "topic" - err := ds.Send(topic, data, core.PeerID(cs.RemotePeer())) - - select { - case <-chanDone: - case <-time.After(timeout): - assert.Fail(t, "timeout getting data from stream") - return - } - - assert.Nil(t, err) - assert.Equal(t, data, receivedMsg.Data) - assert.Equal(t, topic, *receivedMsg.Topic) -} - -// ------- received messages tests - -func TestDirectSender_ReceivedSentMessageShouldCallMessageHandlerTestFullCycle(t *testing.T) { - var streamHandler network.StreamHandler - netw := &mock.NetworkStub{} - - hs := &mock.ConnectableHostStub{ - SetStreamHandlerCalled: func(pid protocol.ID, handler network.StreamHandler) { - streamHandler = handler - }, - NetworkCalled: func() network.Network { - return netw - }, - } - - var receivedMsg *pubsub.Message - chanDone := make(chan bool) - - ds, _ := libp2p.NewDirectSender( - context.Background(), - hs, - func(msg *pubsub.Message, fromConnectedPeer core.PeerID) error { - receivedMsg = msg - chanDone <- true - return nil - }, - ) - - id, sk := createLibP2PCredentialsDirectSender() - remotePeer := peer.ID("remote peer") - - stream := mock.NewStreamMock() - stream.SetConn( - &mock.ConnStub{ - RemotePeerCalled: func() peer.ID { - return remotePeer - }, - }) - _ = stream.SetProtocol(libp2p.DirectSendID) - - streamHandler(stream) - - cs := createConnStub(stream, id, sk, remotePeer) - - netw.ConnsToPeerCalled = func(p peer.ID) []network.Conn { - return []network.Conn{cs} - } - cs.LocalPeerCalled = func() peer.ID { - return cs.RemotePeer() - } - - data := []byte("data") - topic := "topic" - _ = ds.Send(topic, data, core.PeerID(cs.RemotePeer())) - - select { - case <-chanDone: - case <-time.After(timeout): - assert.Fail(t, "timeout") - return - } - - assert.NotNil(t, receivedMsg) - assert.Equal(t, 
data, receivedMsg.Data) - assert.Equal(t, topic, *receivedMsg.Topic) -} - -func TestDirectSender_ProcessReceivedDirectMessageFromMismatchesFromConnectedPeerShouldErr(t *testing.T) { - ds, _ := libp2p.NewDirectSender( - context.Background(), - generateHostStub(), - blankMessageHandler, - ) - - id, _ := createLibP2PCredentialsDirectSender() - - msg := &pb.Message{} - msg.Data = []byte("data") - msg.Seqno = []byte("111") - msg.From = []byte(id) - topic := "topic" - msg.Topic = &topic - - err := ds.ProcessReceivedDirectMessage(msg, "not the same peer id") - - assert.True(t, errors.Is(err, p2p.ErrInvalidValue)) -} diff --git a/p2p/libp2p/disabled/peerDenialEvaluator.go b/p2p/libp2p/disabled/peerDenialEvaluator.go deleted file mode 100644 index e4203127e66..00000000000 --- a/p2p/libp2p/disabled/peerDenialEvaluator.go +++ /dev/null @@ -1,27 +0,0 @@ -package disabled - -import ( - "time" - - "github.com/ElrondNetwork/elrond-go-core/core" -) - -// PeerDenialEvaluator is a disabled implementation of PeerDenialEvaluator that does not manage black listed keys -// (all keys [peers] are whitelisted) -type PeerDenialEvaluator struct { -} - -// IsDenied outputs false (all peers are white listed) -func (pde *PeerDenialEvaluator) IsDenied(_ core.PeerID) bool { - return false -} - -// UpsertPeerID returns nil and does nothing -func (pde *PeerDenialEvaluator) UpsertPeerID(_ core.PeerID, _ time.Duration) error { - return nil -} - -// IsInterfaceNil returns true if there is no value under the interface -func (pde *PeerDenialEvaluator) IsInterfaceNil() bool { - return pde == nil -} diff --git a/p2p/libp2p/disabled/peerDenialEvaluator_test.go b/p2p/libp2p/disabled/peerDenialEvaluator_test.go deleted file mode 100644 index 7e2964be69e..00000000000 --- a/p2p/libp2p/disabled/peerDenialEvaluator_test.go +++ /dev/null @@ -1,19 +0,0 @@ -package disabled - -import ( - "testing" - "time" - - "github.com/ElrondNetwork/elrond-go-core/core/check" - "github.com/stretchr/testify/assert" -) - -func TestPeerDenialEvaluator_ShouldWork(t *testing.T) { - t.Parallel() - - pde := &PeerDenialEvaluator{} - - assert.False(t, check.IfNil(pde)) - assert.Nil(t, pde.UpsertPeerID("", time.Second)) - assert.False(t, pde.IsDenied("")) -} diff --git a/p2p/libp2p/discovery/continuousKadDhtDiscoverer.go b/p2p/libp2p/discovery/continuousKadDhtDiscoverer.go deleted file mode 100644 index 86394740ac2..00000000000 --- a/p2p/libp2p/discovery/continuousKadDhtDiscoverer.go +++ /dev/null @@ -1,298 +0,0 @@ -package discovery - -import ( - "context" - "errors" - "fmt" - "sync" - "time" - - "github.com/ElrondNetwork/elrond-go-core/core/check" - logger "github.com/ElrondNetwork/elrond-go-logger" - "github.com/ElrondNetwork/elrond-go/p2p" - "github.com/libp2p/go-libp2p-core/protocol" - dht "github.com/libp2p/go-libp2p-kad-dht" - kbucket "github.com/libp2p/go-libp2p-kbucket" -) - -var _ p2p.PeerDiscoverer = (*ContinuousKadDhtDiscoverer)(nil) -var _ p2p.Reconnecter = (*ContinuousKadDhtDiscoverer)(nil) - -var log = logger.GetOrCreate("p2p/libp2p/kaddht") - -const kadDhtName = "kad-dht discovery" - -// ArgKadDht represents the kad-dht config argument DTO -type ArgKadDht struct { - Context context.Context - Host ConnectableHost - PeersRefreshInterval time.Duration - SeedersReconnectionInterval time.Duration - ProtocolID string - InitialPeersList []string - BucketSize uint32 - RoutingTableRefresh time.Duration - KddSharder p2p.Sharder - ConnectionWatcher p2p.ConnectionsWatcher -} - -// ContinuousKadDhtDiscoverer is the kad-dht discovery type implementation -// 
This implementation does not support pausing and resuming of the discovery process -type ContinuousKadDhtDiscoverer struct { - host ConnectableHost - context context.Context - mutKadDht sync.RWMutex - kadDHT *dht.IpfsDHT - refreshCancel context.CancelFunc - - peersRefreshInterval time.Duration - protocolID string - initialPeersList []string - bucketSize uint32 - routingTableRefresh time.Duration - hostConnManagement *hostWithConnectionManagement - sharder Sharder - connectionWatcher p2p.ConnectionsWatcher -} - -// NewContinuousKadDhtDiscoverer creates a new kad-dht discovery type implementation -// initialPeersList can be nil or empty, no initial connection will be attempted, a warning message will appear -func NewContinuousKadDhtDiscoverer(arg ArgKadDht) (*ContinuousKadDhtDiscoverer, error) { - sharder, err := prepareArguments(arg) - if err != nil { - return nil, err - } - - sharder.SetSeeders(arg.InitialPeersList) - - return &ContinuousKadDhtDiscoverer{ - context: arg.Context, - host: arg.Host, - sharder: sharder, - peersRefreshInterval: arg.PeersRefreshInterval, - protocolID: arg.ProtocolID, - initialPeersList: arg.InitialPeersList, - bucketSize: arg.BucketSize, - routingTableRefresh: arg.RoutingTableRefresh, - connectionWatcher: arg.ConnectionWatcher, - }, nil -} - -func prepareArguments(arg ArgKadDht) (Sharder, error) { - if check.IfNilReflect(arg.Context) { - return nil, p2p.ErrNilContext - } - if check.IfNilReflect(arg.Host) { - return nil, p2p.ErrNilHost - } - if check.IfNil(arg.KddSharder) { - return nil, p2p.ErrNilSharder - } - if check.IfNil(arg.ConnectionWatcher) { - return nil, p2p.ErrNilConnectionsWatcher - } - sharder, ok := arg.KddSharder.(Sharder) - if !ok { - return nil, fmt.Errorf("%w for sharder: expected discovery.Sharder type of interface", p2p.ErrWrongTypeAssertion) - } - if arg.PeersRefreshInterval < time.Second { - return nil, fmt.Errorf("%w, PeersRefreshInterval should have been at least 1 second", p2p.ErrInvalidValue) - } - if arg.RoutingTableRefresh < time.Second { - return nil, fmt.Errorf("%w, RoutingTableRefresh should have been at least 1 second", p2p.ErrInvalidValue) - } - isListNilOrEmpty := len(arg.InitialPeersList) == 0 - if isListNilOrEmpty { - log.Warn("nil or empty initial peers list provided to kad dht implementation. 
" + - "No initial connection will be done") - } - - return sharder, nil -} - -// Bootstrap will start the bootstrapping new peers process -func (ckdd *ContinuousKadDhtDiscoverer) Bootstrap() error { - ckdd.mutKadDht.Lock() - defer ckdd.mutKadDht.Unlock() - - if ckdd.kadDHT != nil { - return p2p.ErrPeerDiscoveryProcessAlreadyStarted - } - - return ckdd.startDHT() -} - -func (ckdd *ContinuousKadDhtDiscoverer) startDHT() error { - ctxrun, cancel := context.WithCancel(ckdd.context) - var err error - args := ArgsHostWithConnectionManagement{ - ConnectableHost: ckdd.host, - Sharder: ckdd.sharder, - ConnectionsWatcher: ckdd.connectionWatcher, - } - ckdd.hostConnManagement, err = NewHostWithConnectionManagement(args) - if err != nil { - cancel() - return err - } - - protocolID := protocol.ID(ckdd.protocolID) - kademliaDHT, err := dht.New( - ckdd.context, - ckdd.hostConnManagement, - dht.ProtocolPrefix(protocolID), - dht.RoutingTableRefreshPeriod(ckdd.routingTableRefresh), - dht.Mode(dht.ModeServer), - ) - if err != nil { - cancel() - return err - } - - go ckdd.connectToInitialAndBootstrap(ctxrun) - - ckdd.kadDHT = kademliaDHT - ckdd.refreshCancel = cancel - return nil -} - -func (ckdd *ContinuousKadDhtDiscoverer) stopDHT() error { - if ckdd.refreshCancel == nil { - return nil - } - - ckdd.refreshCancel() - ckdd.refreshCancel = nil - - protocolID := protocol.ID(ckdd.protocolID) - ckdd.host.RemoveStreamHandler(protocolID) - - err := ckdd.kadDHT.Close() - - ckdd.kadDHT = nil - - return err -} - -func (ckdd *ContinuousKadDhtDiscoverer) connectToInitialAndBootstrap(ctx context.Context) { - chanStartBootstrap := ckdd.connectToOnePeerFromInitialPeersList( - ckdd.peersRefreshInterval, - ckdd.initialPeersList, - ) - - // TODO: needs refactor - go func() { - select { - case <-chanStartBootstrap: - case <-ctx.Done(): - return - } - ckdd.bootstrap(ctx) - }() -} - -func (ckdd *ContinuousKadDhtDiscoverer) bootstrap(ctx context.Context) { - log.Debug("starting the p2p bootstrapping process") - for { - ckdd.mutKadDht.RLock() - kadDht := ckdd.kadDHT - ckdd.mutKadDht.RUnlock() - - shouldReconnect := kadDht != nil && kbucket.ErrLookupFailure == kadDht.Bootstrap(ckdd.context) - if shouldReconnect { - log.Debug("pausing the p2p bootstrapping process") - ckdd.ReconnectToNetwork(ctx) - log.Debug("resuming the p2p bootstrapping process") - } - - select { - case <-time.After(ckdd.peersRefreshInterval): - case <-ctx.Done(): - log.Debug("closing the p2p bootstrapping process") - return - } - } -} - -func (ckdd *ContinuousKadDhtDiscoverer) connectToOnePeerFromInitialPeersList( - intervalBetweenAttempts time.Duration, - initialPeersList []string, -) <-chan struct{} { - - chanDone := make(chan struct{}, 1) - - if len(initialPeersList) == 0 { - chanDone <- struct{}{} - return chanDone - } - - go ckdd.tryConnectToSeeder(intervalBetweenAttempts, initialPeersList, chanDone) - - return chanDone -} - -func (ckdd *ContinuousKadDhtDiscoverer) tryConnectToSeeder( - intervalBetweenAttempts time.Duration, - initialPeersList []string, - chanDone chan struct{}, -) { - - startIndex := 0 - - for { - initialPeer := initialPeersList[startIndex] - err := ckdd.host.ConnectToPeer(ckdd.context, initialPeer) - if err != nil { - printConnectionErrorToSeeder(initialPeer, err) - startIndex++ - startIndex = startIndex % len(initialPeersList) - select { - case <-ckdd.context.Done(): - log.Debug("context done in ContinuousKadDhtDiscoverer") - return - case <-time.After(intervalBetweenAttempts): - continue - } - } else { - log.Debug("connected to seeder", 
"address", initialPeer) - } - - break - } - chanDone <- struct{}{} -} - -func printConnectionErrorToSeeder(peer string, err error) { - if errors.Is(err, p2p.ErrUnwantedPeer) { - log.Trace("tryConnectToSeeder: unwanted peer", - "seeder", peer, - "error", err.Error(), - ) - - return - } - - log.Debug("error connecting to seeder", - "seeder", peer, - "error", err.Error(), - ) -} - -// Name returns the name of the kad dht peer discovery implementation -func (ckdd *ContinuousKadDhtDiscoverer) Name() string { - return kadDhtName -} - -// ReconnectToNetwork will try to connect to one peer from the initial peer list -func (ckdd *ContinuousKadDhtDiscoverer) ReconnectToNetwork(ctx context.Context) { - select { - case <-ckdd.connectToOnePeerFromInitialPeersList(ckdd.peersRefreshInterval, ckdd.initialPeersList): - case <-ctx.Done(): - return - } -} - -// IsInterfaceNil returns true if there is no value under the interface -func (ckdd *ContinuousKadDhtDiscoverer) IsInterfaceNil() bool { - return ckdd == nil -} diff --git a/p2p/libp2p/discovery/continuousKadDhtDiscoverer_test.go b/p2p/libp2p/discovery/continuousKadDhtDiscoverer_test.go deleted file mode 100644 index 4c6c2ec8391..00000000000 --- a/p2p/libp2p/discovery/continuousKadDhtDiscoverer_test.go +++ /dev/null @@ -1,326 +0,0 @@ -package discovery_test - -import ( - "context" - "errors" - "sync/atomic" - "testing" - "time" - - "github.com/ElrondNetwork/elrond-go-core/core/check" - "github.com/ElrondNetwork/elrond-go/p2p" - "github.com/ElrondNetwork/elrond-go/p2p/libp2p/discovery" - "github.com/ElrondNetwork/elrond-go/p2p/mock" - "github.com/libp2p/go-libp2p-core/event" - "github.com/stretchr/testify/assert" -) - -var timeoutWaitResponses = 2 * time.Second - -func createTestArgument() discovery.ArgKadDht { - return discovery.ArgKadDht{ - Context: context.Background(), - Host: &mock.ConnectableHostStub{}, - KddSharder: &mock.KadSharderStub{}, - PeersRefreshInterval: time.Second, - ProtocolID: "/erd/test/0.0.0", - InitialPeersList: []string{"peer1", "peer2"}, - BucketSize: 100, - RoutingTableRefresh: 5 * time.Second, - SeedersReconnectionInterval: time.Second * 5, - ConnectionWatcher: &mock.ConnectionsWatcherStub{}, - } -} - -func TestNewContinuousKadDhtDiscoverer(t *testing.T) { - t.Parallel() - - t.Run("nil context should error", func(t *testing.T) { - t.Parallel() - - arg := createTestArgument() - arg.Context = nil - - kdd, err := discovery.NewContinuousKadDhtDiscoverer(arg) - - assert.True(t, check.IfNil(kdd)) - assert.True(t, errors.Is(err, p2p.ErrNilContext)) - }) - t.Run("nil host should error", func(t *testing.T) { - t.Parallel() - - arg := createTestArgument() - arg.Host = nil - - kdd, err := discovery.NewContinuousKadDhtDiscoverer(arg) - - assert.True(t, check.IfNil(kdd)) - assert.True(t, errors.Is(err, p2p.ErrNilHost)) - }) - t.Run("nil sharder should error", func(t *testing.T) { - t.Parallel() - - arg := createTestArgument() - arg.KddSharder = nil - - kdd, err := discovery.NewContinuousKadDhtDiscoverer(arg) - - assert.True(t, check.IfNil(kdd)) - assert.True(t, errors.Is(err, p2p.ErrNilSharder)) - }) - t.Run("wrong sharder should error", func(t *testing.T) { - t.Parallel() - - arg := createTestArgument() - arg.KddSharder = &mock.SharderStub{} - - kdd, err := discovery.NewContinuousKadDhtDiscoverer(arg) - - assert.True(t, check.IfNil(kdd)) - assert.True(t, errors.Is(err, p2p.ErrWrongTypeAssertion)) - }) - t.Run("invalid peers refresh interval should error", func(t *testing.T) { - t.Parallel() - - arg := createTestArgument() - 
arg.PeersRefreshInterval = time.Second - time.Microsecond - - kdd, err := discovery.NewContinuousKadDhtDiscoverer(arg) - - assert.Nil(t, kdd) - assert.True(t, errors.Is(err, p2p.ErrInvalidValue)) - }) - t.Run("invalid routing table refresh interval should error", func(t *testing.T) { - t.Parallel() - - arg := createTestArgument() - arg.RoutingTableRefresh = time.Second - time.Microsecond - - kdd, err := discovery.NewContinuousKadDhtDiscoverer(arg) - - assert.Nil(t, kdd) - assert.True(t, errors.Is(err, p2p.ErrInvalidValue)) - }) - t.Run("nil connections watcher should error", func(t *testing.T) { - t.Parallel() - - arg := createTestArgument() - arg.ConnectionWatcher = nil - - kdd, err := discovery.NewContinuousKadDhtDiscoverer(arg) - - assert.Nil(t, kdd) - assert.True(t, errors.Is(err, p2p.ErrNilConnectionsWatcher)) - }) - t.Run("should work", func(t *testing.T) { - t.Parallel() - - arg := createTestArgument() - - kdd, err := discovery.NewContinuousKadDhtDiscoverer(arg) - - assert.False(t, check.IfNil(kdd)) - assert.Nil(t, err) - }) -} - -func TestNewContinuousKadDhtDiscoverer_EmptyInitialPeersShouldWork(t *testing.T) { - t.Parallel() - - arg := createTestArgument() - arg.InitialPeersList = nil - - kdd, err := discovery.NewContinuousKadDhtDiscoverer(arg) - - assert.False(t, check.IfNil(kdd)) - assert.Nil(t, err) -} - -// ------- Bootstrap - -func TestContinuousKadDhtDiscoverer_BootstrapCalledOnceShouldWork(t *testing.T) { - t.Parallel() - - arg := createTestArgument() - ckdd, _ := discovery.NewContinuousKadDhtDiscoverer(arg) - - err := ckdd.Bootstrap() - - assert.Nil(t, err) - time.Sleep(arg.PeersRefreshInterval * 2) -} - -func TestContinuousKadDhtDiscoverer_BootstrapCalledTwiceShouldErr(t *testing.T) { - t.Parallel() - - arg := createTestArgument() - ckdd, _ := discovery.NewContinuousKadDhtDiscoverer(arg) - - _ = ckdd.Bootstrap() - err := ckdd.Bootstrap() - - assert.Equal(t, p2p.ErrPeerDiscoveryProcessAlreadyStarted, err) -} - -// ------- connectToOnePeerFromInitialPeersList - -func TestContinuousKadDhtDiscoverer_ConnectToOnePeerFromInitialPeersListNilListShouldRetWithChanFull(t *testing.T) { - t.Parallel() - - arg := createTestArgument() - ckdd, _ := discovery.NewContinuousKadDhtDiscoverer(arg) - - chanDone := ckdd.ConnectToOnePeerFromInitialPeersList(time.Second, nil) - - assert.Equal(t, 1, len(chanDone)) -} - -func TestContinuousKadDhtDiscoverer_ConnectToOnePeerFromInitialPeersListEmptyListShouldRetWithChanFull(t *testing.T) { - t.Parallel() - - arg := createTestArgument() - ckdd, _ := discovery.NewContinuousKadDhtDiscoverer(arg) - - chanDone := ckdd.ConnectToOnePeerFromInitialPeersList(time.Second, make([]string, 0)) - - assert.Equal(t, 1, len(chanDone)) -} - -func TestContinuousKadDhtDiscoverer_ConnectToOnePeerFromInitialPeersOnePeerShouldTryToConnect(t *testing.T) { - t.Parallel() - - arg := createTestArgument() - peerID := "peer" - wasConnectCalled := int32(0) - - arg.Host = &mock.ConnectableHostStub{ - ConnectToPeerCalled: func(ctx context.Context, address string) error { - if peerID == address { - atomic.AddInt32(&wasConnectCalled, 1) - } - - return nil - }, - EventBusCalled: func() event.Bus { - return &mock.EventBusStub{ - SubscribeCalled: func(eventType interface{}, opts ...event.SubscriptionOpt) (event.Subscription, error) { - return &mock.EventSubscriptionStub{}, nil - }, - } - }, - } - ckdd, _ := discovery.NewContinuousKadDhtDiscoverer(arg) - chanDone := ckdd.ConnectToOnePeerFromInitialPeersList(time.Second, []string{peerID}) - - select { - case <-chanDone: - assert.Equal(t, 
int32(1), atomic.LoadInt32(&wasConnectCalled)) - case <-time.After(timeoutWaitResponses): - assert.Fail(t, "timeout") - } -} - -func TestContinuousKadDhtDiscoverer_ConnectToOnePeerFromInitialPeersOnePeerShouldTryToConnectContinously(t *testing.T) { - t.Parallel() - - arg := createTestArgument() - peerID := "peer" - wasConnectCalled := int32(0) - - errDidNotConnect := errors.New("did not connect") - noOfTimesToRefuseConnection := 5 - arg.Host = &mock.ConnectableHostStub{ - ConnectToPeerCalled: func(ctx context.Context, address string) error { - if peerID != address { - assert.Fail(t, "should have tried to connect to the same ID") - } - - atomic.AddInt32(&wasConnectCalled, 1) - - if atomic.LoadInt32(&wasConnectCalled) < int32(noOfTimesToRefuseConnection) { - return errDidNotConnect - } - - return nil - }, - EventBusCalled: func() event.Bus { - return &mock.EventBusStub{ - SubscribeCalled: func(eventType interface{}, opts ...event.SubscriptionOpt) (event.Subscription, error) { - return &mock.EventSubscriptionStub{}, nil - }, - } - }, - } - ckdd, _ := discovery.NewContinuousKadDhtDiscoverer(arg) - - chanDone := ckdd.ConnectToOnePeerFromInitialPeersList(time.Millisecond*10, []string{peerID}) - - select { - case <-chanDone: - assert.Equal(t, int32(noOfTimesToRefuseConnection), atomic.LoadInt32(&wasConnectCalled)) - case <-time.After(timeoutWaitResponses): - assert.Fail(t, "timeout") - } -} - -func TestContinuousKadDhtDiscoverer_ConnectToOnePeerFromInitialPeersTwoPeersShouldAlternate(t *testing.T) { - t.Parallel() - - arg := createTestArgument() - peerID1 := "peer1" - peerID2 := "peer2" - wasConnectCalled := int32(0) - errDidNotConnect := errors.New("did not connect") - noOfTimesToRefuseConnection := 5 - arg.Host = &mock.ConnectableHostStub{ - ConnectToPeerCalled: func(ctx context.Context, address string) error { - connCalled := atomic.LoadInt32(&wasConnectCalled) - - atomic.AddInt32(&wasConnectCalled, 1) - - if connCalled >= int32(noOfTimesToRefuseConnection) { - return nil - } - - connCalled = connCalled % 2 - if connCalled == 0 { - if peerID1 != address { - assert.Fail(t, "should have tried to connect to "+peerID1) - } - } - - if connCalled == 1 { - if peerID2 != address { - assert.Fail(t, "should have tried to connect to "+peerID2) - } - } - - return errDidNotConnect - }, - EventBusCalled: func() event.Bus { - return &mock.EventBusStub{ - SubscribeCalled: func(eventType interface{}, opts ...event.SubscriptionOpt) (event.Subscription, error) { - return &mock.EventSubscriptionStub{}, nil - }, - } - }, - } - ckdd, _ := discovery.NewContinuousKadDhtDiscoverer(arg) - - chanDone := ckdd.ConnectToOnePeerFromInitialPeersList(time.Millisecond*10, []string{peerID1, peerID2}) - - select { - case <-chanDone: - case <-time.After(timeoutWaitResponses): - assert.Fail(t, "timeout") - } -} - -func TestContinuousKadDhtDiscoverer_Name(t *testing.T) { - t.Parallel() - - arg := createTestArgument() - kdd, _ := discovery.NewContinuousKadDhtDiscoverer(arg) - - assert.Equal(t, discovery.KadDhtName, kdd.Name()) -} diff --git a/p2p/libp2p/discovery/export_test.go b/p2p/libp2p/discovery/export_test.go deleted file mode 100644 index 2478c5cc0ef..00000000000 --- a/p2p/libp2p/discovery/export_test.go +++ /dev/null @@ -1,73 +0,0 @@ -package discovery - -import ( - "context" - "time" - - "github.com/ElrondNetwork/elrond-go/p2p" -) - -const KadDhtName = kadDhtName -const OptimizedKadDhtName = optimizedKadDhtName -const NullName = nilName - -// ------- ContinuousKadDhtDiscoverer - -func (ckdd *ContinuousKadDhtDiscoverer) 
ConnectToOnePeerFromInitialPeersList( - durationBetweenAttempts time.Duration, - initialPeersList []string) <-chan struct{} { - - return ckdd.connectToOnePeerFromInitialPeersList(durationBetweenAttempts, initialPeersList) -} - -func (ckdd *ContinuousKadDhtDiscoverer) StopDHT() error { - ckdd.mutKadDht.Lock() - err := ckdd.stopDHT() - ckdd.mutKadDht.Unlock() - - return err -} - -// NewOptimizedKadDhtDiscovererWithInitFunc - -func NewOptimizedKadDhtDiscovererWithInitFunc( - arg ArgKadDht, - createFunc func(ctx context.Context) (KadDhtHandler, error), -) (*optimizedKadDhtDiscoverer, error) { - sharder, err := prepareArguments(arg) - if err != nil { - return nil, err - } - - if arg.SeedersReconnectionInterval < minIntervalForSeedersReconnection { - return nil, p2p.ErrInvalidSeedersReconnectionInterval - } - - okdd := &optimizedKadDhtDiscoverer{ - sharder: sharder, - peersRefreshInterval: arg.PeersRefreshInterval, - seedersReconnectionInterval: arg.SeedersReconnectionInterval, - protocolID: arg.ProtocolID, - initialPeersList: arg.InitialPeersList, - bucketSize: arg.BucketSize, - routingTableRefresh: arg.RoutingTableRefresh, - status: statNotInitialized, - chanInit: make(chan struct{}), - errChanInit: make(chan error), - chanConnectToSeeders: make(chan struct{}), - } - - okdd.createKadDhtHandler = createFunc - argConnectionManagement := ArgsHostWithConnectionManagement{ - ConnectableHost: arg.Host, - Sharder: okdd.sharder, - ConnectionsWatcher: arg.ConnectionWatcher, - } - okdd.hostConnManagement, err = NewHostWithConnectionManagement(argConnectionManagement) - if err != nil { - return nil, err - } - - go okdd.processLoop(arg.Context) - - return okdd, nil -} diff --git a/p2p/libp2p/discovery/factory/peerDiscovererFactory.go b/p2p/libp2p/discovery/factory/peerDiscovererFactory.go deleted file mode 100644 index 0c883a2a554..00000000000 --- a/p2p/libp2p/discovery/factory/peerDiscovererFactory.go +++ /dev/null @@ -1,75 +0,0 @@ -package factory - -import ( - "context" - "fmt" - "time" - - logger "github.com/ElrondNetwork/elrond-go-logger" - "github.com/ElrondNetwork/elrond-go/config" - "github.com/ElrondNetwork/elrond-go/p2p" - "github.com/ElrondNetwork/elrond-go/p2p/libp2p/discovery" -) - -const typeLegacy = "legacy" -const typeOptimized = "optimized" -const defaultSeedersReconnectionInterval = time.Minute * 5 - -var log = logger.GetOrCreate("p2p/discovery/factory") - -// ArgsPeerDiscoverer is the DTO struct used in the NewPeerDiscoverer function -type ArgsPeerDiscoverer struct { - Context context.Context - Host discovery.ConnectableHost - Sharder p2p.Sharder - P2pConfig config.P2PConfig - ConnectionsWatcher p2p.ConnectionsWatcher -} - -// NewPeerDiscoverer generates an implementation of PeerDiscoverer by parsing the p2pConfig struct -// Errors if config is badly formatted -func NewPeerDiscoverer(args ArgsPeerDiscoverer) (p2p.PeerDiscoverer, error) { - if args.P2pConfig.KadDhtPeerDiscovery.Enabled { - return createKadDhtPeerDiscoverer(args) - } - - log.Debug("using nil discoverer") - return discovery.NewNilDiscoverer(), nil -} - -func createKadDhtPeerDiscoverer(args ArgsPeerDiscoverer) (p2p.PeerDiscoverer, error) { - arg := discovery.ArgKadDht{ - Context: args.Context, - Host: args.Host, - KddSharder: args.Sharder, - PeersRefreshInterval: time.Second * time.Duration(args.P2pConfig.KadDhtPeerDiscovery.RefreshIntervalInSec), - SeedersReconnectionInterval: defaultSeedersReconnectionInterval, - ProtocolID: args.P2pConfig.KadDhtPeerDiscovery.ProtocolID, - InitialPeersList: 
args.P2pConfig.KadDhtPeerDiscovery.InitialPeerList, - BucketSize: args.P2pConfig.KadDhtPeerDiscovery.BucketSize, - RoutingTableRefresh: time.Second * time.Duration(args.P2pConfig.KadDhtPeerDiscovery.RoutingTableRefreshIntervalInSec), - ConnectionWatcher: args.ConnectionsWatcher, - } - - switch args.P2pConfig.Sharding.Type { - case p2p.ListsSharder, p2p.OneListSharder, p2p.NilListSharder: - return createKadDhtDiscoverer(args.P2pConfig, arg) - default: - return nil, fmt.Errorf("%w unable to select peer discoverer based on "+ - "selected sharder: unknown sharder '%s'", p2p.ErrInvalidValue, args.P2pConfig.Sharding.Type) - } -} - -func createKadDhtDiscoverer(p2pConfig config.P2PConfig, arg discovery.ArgKadDht) (p2p.PeerDiscoverer, error) { - switch p2pConfig.KadDhtPeerDiscovery.Type { - case typeLegacy: - log.Debug("using continuous (legacy) kad dht discoverer") - return discovery.NewContinuousKadDhtDiscoverer(arg) - case typeOptimized: - log.Debug("using optimized kad dht discoverer") - return discovery.NewOptimizedKadDhtDiscoverer(arg) - default: - return nil, fmt.Errorf("%w unable to select peer discoverer based on type '%s'", - p2p.ErrInvalidValue, p2pConfig.KadDhtPeerDiscovery.Type) - } -} diff --git a/p2p/libp2p/discovery/factory/peerDiscovererFactory_test.go b/p2p/libp2p/discovery/factory/peerDiscovererFactory_test.go deleted file mode 100644 index 9fbe8386354..00000000000 --- a/p2p/libp2p/discovery/factory/peerDiscovererFactory_test.go +++ /dev/null @@ -1,144 +0,0 @@ -package factory_test - -import ( - "context" - "errors" - "testing" - - "github.com/ElrondNetwork/elrond-go-core/core/check" - "github.com/ElrondNetwork/elrond-go/config" - "github.com/ElrondNetwork/elrond-go/p2p" - "github.com/ElrondNetwork/elrond-go/p2p/libp2p/discovery" - "github.com/ElrondNetwork/elrond-go/p2p/libp2p/discovery/factory" - "github.com/ElrondNetwork/elrond-go/p2p/mock" - "github.com/stretchr/testify/assert" -) - -func TestNewPeerDiscoverer_NoDiscoveryEnabledShouldRetNullDiscoverer(t *testing.T) { - t.Parallel() - - args := factory.ArgsPeerDiscoverer{ - Context: context.Background(), - Host: &mock.ConnectableHostStub{}, - Sharder: &mock.SharderStub{}, - P2pConfig: config.P2PConfig{ - KadDhtPeerDiscovery: config.KadDhtPeerDiscoveryConfig{ - Enabled: false, - }, - }, - ConnectionsWatcher: &mock.ConnectionsWatcherStub{}, - } - pDiscoverer, err := factory.NewPeerDiscoverer(args) - _, ok := pDiscoverer.(*discovery.NilDiscoverer) - - assert.True(t, ok) - assert.Nil(t, err) -} - -func TestNewPeerDiscoverer_ListsSharderShouldWork(t *testing.T) { - t.Parallel() - - args := factory.ArgsPeerDiscoverer{ - Context: context.Background(), - Host: &mock.ConnectableHostStub{}, - Sharder: &mock.KadSharderStub{}, - P2pConfig: config.P2PConfig{ - KadDhtPeerDiscovery: config.KadDhtPeerDiscoveryConfig{ - Enabled: true, - RefreshIntervalInSec: 1, - RoutingTableRefreshIntervalInSec: 300, - Type: "legacy", - }, - Sharding: config.ShardingConfig{ - Type: p2p.ListsSharder, - }, - }, - ConnectionsWatcher: &mock.ConnectionsWatcherStub{}, - } - - pDiscoverer, err := factory.NewPeerDiscoverer(args) - _, ok := pDiscoverer.(*discovery.ContinuousKadDhtDiscoverer) - - assert.NotNil(t, pDiscoverer) - assert.True(t, ok) - assert.Nil(t, err) -} - -func TestNewPeerDiscoverer_OptimizedKadDhtShouldWork(t *testing.T) { - t.Parallel() - - args := factory.ArgsPeerDiscoverer{ - Context: context.Background(), - Host: &mock.ConnectableHostStub{}, - Sharder: &mock.KadSharderStub{}, - P2pConfig: config.P2PConfig{ - KadDhtPeerDiscovery: 
config.KadDhtPeerDiscoveryConfig{ - Enabled: true, - RefreshIntervalInSec: 1, - RoutingTableRefreshIntervalInSec: 300, - Type: "optimized", - }, - Sharding: config.ShardingConfig{ - Type: p2p.ListsSharder, - }, - }, - ConnectionsWatcher: &mock.ConnectionsWatcherStub{}, - } - pDiscoverer, err := factory.NewPeerDiscoverer(args) - - assert.Nil(t, err) - assert.NotNil(t, pDiscoverer) - assert.Equal(t, "optimized kad-dht discovery", pDiscoverer.Name()) -} - -func TestNewPeerDiscoverer_UnknownSharderShouldErr(t *testing.T) { - t.Parallel() - - args := factory.ArgsPeerDiscoverer{ - Context: context.Background(), - Host: &mock.ConnectableHostStub{}, - Sharder: &mock.SharderStub{}, - P2pConfig: config.P2PConfig{ - KadDhtPeerDiscovery: config.KadDhtPeerDiscoveryConfig{ - Enabled: true, - RefreshIntervalInSec: 1, - }, - Sharding: config.ShardingConfig{ - Type: "unknown", - }, - }, - ConnectionsWatcher: &mock.ConnectionsWatcherStub{}, - } - - pDiscoverer, err := factory.NewPeerDiscoverer(args) - - assert.True(t, check.IfNil(pDiscoverer)) - assert.True(t, errors.Is(err, p2p.ErrInvalidValue)) -} - -func TestNewPeerDiscoverer_UnknownKadDhtShouldErr(t *testing.T) { - t.Parallel() - - args := factory.ArgsPeerDiscoverer{ - Context: context.Background(), - Host: &mock.ConnectableHostStub{}, - Sharder: &mock.SharderStub{}, - P2pConfig: config.P2PConfig{ - KadDhtPeerDiscovery: config.KadDhtPeerDiscoveryConfig{ - Enabled: true, - RefreshIntervalInSec: 1, - RoutingTableRefreshIntervalInSec: 300, - Type: "unknown", - }, - Sharding: config.ShardingConfig{ - Type: p2p.ListsSharder, - }, - }, - ConnectionsWatcher: &mock.ConnectionsWatcherStub{}, - } - - pDiscoverer, err := factory.NewPeerDiscoverer(args) - - assert.True(t, errors.Is(err, p2p.ErrInvalidValue)) - assert.True(t, check.IfNil(pDiscoverer)) -} diff --git a/p2p/libp2p/discovery/hostWithConnectionManagement.go b/p2p/libp2p/discovery/hostWithConnectionManagement.go deleted file mode 100644 index 746be53b9bb..00000000000 --- a/p2p/libp2p/discovery/hostWithConnectionManagement.go +++ /dev/null @@ -1,92 +0,0 @@ -package discovery - -import ( - "context" - "fmt" - "strings" - - "github.com/ElrondNetwork/elrond-go-core/core" - "github.com/ElrondNetwork/elrond-go-core/core/check" - "github.com/ElrondNetwork/elrond-go/p2p" - "github.com/libp2p/go-libp2p-core/network" - "github.com/libp2p/go-libp2p-core/peer" - "github.com/multiformats/go-multiaddr" -) - -// ArgsHostWithConnectionManagement is the argument DTO used in the NewHostWithConnectionManagement function -type ArgsHostWithConnectionManagement struct { - ConnectableHost ConnectableHost - Sharder Sharder - ConnectionsWatcher p2p.ConnectionsWatcher -} - -type hostWithConnectionManagement struct { - ConnectableHost - sharder Sharder - connectionsWatcher p2p.ConnectionsWatcher -} - -// NewHostWithConnectionManagement returns a host wrapper able to decide if connection initiated to a peer -// will actually be kept or not -func NewHostWithConnectionManagement(args ArgsHostWithConnectionManagement) (*hostWithConnectionManagement, error) { - if check.IfNil(args.ConnectableHost) { - return nil, p2p.ErrNilHost - } - if check.IfNil(args.Sharder) { - return nil, p2p.ErrNilSharder - } - if check.IfNil(args.ConnectionsWatcher) { - return nil, p2p.ErrNilConnectionsWatcher - } - - return &hostWithConnectionManagement{ - ConnectableHost: args.ConnectableHost, - sharder: args.Sharder, - connectionsWatcher: args.ConnectionsWatcher, - }, nil -} - -// Connect tries to connect to the provided address info if the sharder allows it
-func (hwcm *hostWithConnectionManagement) Connect(ctx context.Context, pi peer.AddrInfo) error { - addresses := concatenateAddresses(pi.Addrs) - hwcm.connectionsWatcher.NewKnownConnection(core.PeerID(pi.ID), addresses) - err := hwcm.canConnectToPeer(pi.ID) - if err != nil { - return err - } - - return hwcm.ConnectableHost.Connect(ctx, pi) -} - -func concatenateAddresses(addresses []multiaddr.Multiaddr) string { - sb := strings.Builder{} - for _, ma := range addresses { - sb.WriteString(ma.String() + " ") - } - - return sb.String() -} - -func (hwcm *hostWithConnectionManagement) canConnectToPeer(pid peer.ID) error { - allPeers := hwcm.ConnectableHost.Network().Peers() - if !hwcm.sharder.Has(pid, allPeers) { - allPeers = append(allPeers, pid) - } - - evicted := hwcm.sharder.ComputeEvictionList(allPeers) - if hwcm.sharder.Has(pid, evicted) { - return fmt.Errorf("%w, pid: %s", p2p.ErrUnwantedPeer, pid.Pretty()) - } - - return nil -} - -// IsConnected returns true if the current host is connected to the provided peer info -func (hwcm *hostWithConnectionManagement) IsConnected(pi peer.AddrInfo) bool { - return hwcm.Network().Connectedness(pi.ID) == network.Connected -} - -// IsInterfaceNil returns true if there is no value under the interface -func (hwcm *hostWithConnectionManagement) IsInterfaceNil() bool { - return hwcm == nil || check.IfNil(hwcm.ConnectableHost) -} diff --git a/p2p/libp2p/discovery/hostWithConnectionManagement_test.go b/p2p/libp2p/discovery/hostWithConnectionManagement_test.go deleted file mode 100644 index 7b71e856623..00000000000 --- a/p2p/libp2p/discovery/hostWithConnectionManagement_test.go +++ /dev/null @@ -1,149 +0,0 @@ -package discovery_test - -import ( - "context" - "testing" - - "github.com/ElrondNetwork/elrond-go-core/core" - "github.com/ElrondNetwork/elrond-go-core/core/check" - "github.com/ElrondNetwork/elrond-go/p2p" - "github.com/ElrondNetwork/elrond-go/p2p/libp2p/discovery" - "github.com/ElrondNetwork/elrond-go/p2p/mock" - "github.com/libp2p/go-libp2p-core/network" - "github.com/libp2p/go-libp2p-core/peer" - "github.com/stretchr/testify/assert" -) - -func createStubNetwork() network.Network { - return &mock.NetworkStub{ - PeersCall: func() []peer.ID { - return make([]peer.ID, 0) - }, - } -} - -func createMockArgsHostWithConnectionManagement() discovery.ArgsHostWithConnectionManagement { - return discovery.ArgsHostWithConnectionManagement{ - ConnectableHost: &mock.ConnectableHostStub{}, - Sharder: &mock.KadSharderStub{}, - ConnectionsWatcher: &mock.ConnectionsWatcherStub{}, - } -} - -func TestNewHostWithConnectionManagement(t *testing.T) { - t.Parallel() - - t.Run("nil connectable host should error", func(t *testing.T) { - t.Parallel() - - args := createMockArgsHostWithConnectionManagement() - args.ConnectableHost = nil - hwcm, err := discovery.NewHostWithConnectionManagement(args) - - assert.True(t, check.IfNil(hwcm)) - assert.Equal(t, p2p.ErrNilHost, err) - }) - t.Run("nil sharder should error", func(t *testing.T) { - t.Parallel() - - args := createMockArgsHostWithConnectionManagement() - args.Sharder = nil - hwcm, err := discovery.NewHostWithConnectionManagement(args) - - assert.True(t, check.IfNil(hwcm)) - assert.Equal(t, p2p.ErrNilSharder, err) - }) - t.Run("nil connection watcher should error", func(t *testing.T) { - t.Parallel() - - args := createMockArgsHostWithConnectionManagement() - args.ConnectionsWatcher = nil - hwcm, err := discovery.NewHostWithConnectionManagement(args) - - assert.True(t, check.IfNil(hwcm)) - assert.Equal(t, 
p2p.ErrNilConnectionsWatcher, err) - }) - t.Run("should work", func(t *testing.T) { - t.Parallel() - - args := createMockArgsHostWithConnectionManagement() - hwcm, err := discovery.NewHostWithConnectionManagement(args) - - assert.False(t, check.IfNil(hwcm)) - assert.Nil(t, err) - }) -} - -// ------- Connect - -func TestHostWithConnectionManagement_ConnectWithSharderNotEvictedShouldCallConnect(t *testing.T) { - t.Parallel() - - connectCalled := false - args := createMockArgsHostWithConnectionManagement() - args.ConnectableHost = &mock.ConnectableHostStub{ - ConnectCalled: func(_ context.Context, _ peer.AddrInfo) error { - connectCalled = true - return nil - }, - NetworkCalled: func() network.Network { - return createStubNetwork() - }, - } - args.Sharder = &mock.KadSharderStub{ - ComputeEvictListCalled: func(pidList []peer.ID) []peer.ID { - return make([]peer.ID, 0) - }, - HasCalled: func(pid peer.ID, list []peer.ID) bool { - return false - }, - } - newKnownConnectionCalled := false - args.ConnectionsWatcher = &mock.ConnectionsWatcherStub{ - NewKnownConnectionCalled: func(pid core.PeerID, connection string) { - newKnownConnectionCalled = true - }, - } - hwcm, _ := discovery.NewHostWithConnectionManagement(args) - - _ = hwcm.Connect(context.Background(), peer.AddrInfo{}) - - assert.True(t, connectCalled) - assert.True(t, newKnownConnectionCalled) -} - -func TestHostWithConnectionManagement_ConnectWithSharderEvictedShouldNotCallConnect(t *testing.T) { - t.Parallel() - - connectCalled := false - args := createMockArgsHostWithConnectionManagement() - args.ConnectableHost = &mock.ConnectableHostStub{ - ConnectCalled: func(_ context.Context, _ peer.AddrInfo) error { - connectCalled = true - return nil - }, - NetworkCalled: func() network.Network { - return createStubNetwork() - }, - } - args.Sharder = &mock.KadSharderStub{ - ComputeEvictListCalled: func(pidList []peer.ID) []peer.ID { - return make([]peer.ID, 0) - }, - HasCalled: func(pid peer.ID, list []peer.ID) bool { - return true - }, - } - newKnownConnectionCalled := false - args.ConnectionsWatcher = &mock.ConnectionsWatcherStub{ - NewKnownConnectionCalled: func(pid core.PeerID, connection string) { - newKnownConnectionCalled = true - }, - } - hwcm, _ := discovery.NewHostWithConnectionManagement(args) - - _ = hwcm.Connect(context.Background(), peer.AddrInfo{}) - - assert.False(t, connectCalled) - assert.True(t, newKnownConnectionCalled) -} diff --git a/p2p/libp2p/discovery/interface.go b/p2p/libp2p/discovery/interface.go deleted file mode 100644 index 6fec58bda38..00000000000 --- a/p2p/libp2p/discovery/interface.go +++ /dev/null @@ -1,31 +0,0 @@ -package discovery - -import ( - "context" - - "github.com/ElrondNetwork/elrond-go-core/core" - "github.com/libp2p/go-libp2p-core/host" - "github.com/libp2p/go-libp2p-core/peer" -) - -// ConnectableHost is an enhanced Host interface that has the ability to connect to a string address -type ConnectableHost interface { - host.Host - ConnectToPeer(ctx context.Context, address string) error - AddressToPeerInfo(address string) (*peer.AddrInfo, error) - IsInterfaceNil() bool -} - -// Sharder defines the eviction computing process of unwanted peers -type Sharder interface { - ComputeEvictionList(pidList []peer.ID) []peer.ID - Has(pid peer.ID, list []peer.ID) bool - SetSeeders(addresses []string) - IsSeeder(pid core.PeerID) bool - IsInterfaceNil() bool -} - -// KadDhtHandler defines the behavior of a component that can find new peers in a p2p network through kad dht mechanism -type KadDhtHandler interface { - 
Bootstrap(ctx context.Context) error -} diff --git a/p2p/libp2p/discovery/nilDiscoverer.go b/p2p/libp2p/discovery/nilDiscoverer.go deleted file mode 100644 index 2785b562de0..00000000000 --- a/p2p/libp2p/discovery/nilDiscoverer.go +++ /dev/null @@ -1,40 +0,0 @@ -package discovery - -import ( - "context" - - "github.com/ElrondNetwork/elrond-go/p2p" -) - -var _ p2p.PeerDiscoverer = (*NilDiscoverer)(nil) -var _ p2p.Reconnecter = (*NilDiscoverer)(nil) - -const nilName = "no peer discovery" - -// NilDiscoverer is the non-functional peer discoverer aimed to be used when peer discovery options are all disabled -type NilDiscoverer struct { -} - -// NewNilDiscoverer creates a new NullDiscoverer implementation -func NewNilDiscoverer() *NilDiscoverer { - return &NilDiscoverer{} -} - -// Bootstrap will return nil. There is no implementation. -func (nd *NilDiscoverer) Bootstrap() error { - return nil -} - -// Name returns a message which says no peer discovery mechanism is used -func (nd *NilDiscoverer) Name() string { - return nilName -} - -// ReconnectToNetwork does nothing -func (nd *NilDiscoverer) ReconnectToNetwork(_ context.Context) { -} - -// IsInterfaceNil returns true if there is no value under the interface -func (nd *NilDiscoverer) IsInterfaceNil() bool { - return nd == nil -} diff --git a/p2p/libp2p/discovery/nilDiscoverer_test.go b/p2p/libp2p/discovery/nilDiscoverer_test.go deleted file mode 100644 index 4e592d7b05c..00000000000 --- a/p2p/libp2p/discovery/nilDiscoverer_test.go +++ /dev/null @@ -1,19 +0,0 @@ -package discovery_test - -import ( - "testing" - - "github.com/ElrondNetwork/elrond-go-core/core/check" - "github.com/ElrondNetwork/elrond-go/p2p/libp2p/discovery" - "github.com/stretchr/testify/assert" -) - -func TestNilDiscoverer(t *testing.T) { - t.Parallel() - - nd := discovery.NewNilDiscoverer() - - assert.False(t, check.IfNil(nd)) - assert.Equal(t, discovery.NullName, nd.Name()) - assert.Nil(t, nd.Bootstrap()) -} diff --git a/p2p/libp2p/discovery/optimizedKadDhtDiscoverer.go b/p2p/libp2p/discovery/optimizedKadDhtDiscoverer.go deleted file mode 100644 index c6246babeef..00000000000 --- a/p2p/libp2p/discovery/optimizedKadDhtDiscoverer.go +++ /dev/null @@ -1,252 +0,0 @@ -package discovery - -import ( - "context" - "time" - - "github.com/ElrondNetwork/elrond-go/p2p" - "github.com/libp2p/go-libp2p-core/protocol" - dht "github.com/libp2p/go-libp2p-kad-dht" -) - -type discovererStatus string - -const statNotInitialized discovererStatus = "not initialized" -const statInitialized discovererStatus = "initialized" -const minIntervalForSeedersReconnection = time.Second -const optimizedKadDhtName = "optimized kad-dht discovery" - -type optimizedKadDhtDiscoverer struct { - kadDHT KadDhtHandler - peersRefreshInterval time.Duration - seedersReconnectionInterval time.Duration - protocolID string - initialPeersList []string - bucketSize uint32 - routingTableRefresh time.Duration - hostConnManagement *hostWithConnectionManagement - sharder Sharder - status discovererStatus - chanInit chan struct{} - errChanInit chan error - chanConnectToSeeders chan struct{} - createKadDhtHandler func(ctx context.Context) (KadDhtHandler, error) - connectionWatcher p2p.ConnectionsWatcher -} - -// NewOptimizedKadDhtDiscoverer creates an optimized kad-dht discovery type implementation -// initialPeersList can be nil or empty, no initial connection will be attempted, a warning message will appear -// This implementation uses a single process loop function able to carry multiple tasks synchronously -func 
NewOptimizedKadDhtDiscoverer(arg ArgKadDht) (*optimizedKadDhtDiscoverer, error) { - sharder, err := prepareArguments(arg) - if err != nil { - return nil, err - } - - if arg.SeedersReconnectionInterval < minIntervalForSeedersReconnection { - return nil, p2p.ErrInvalidSeedersReconnectionInterval - } - - sharder.SetSeeders(arg.InitialPeersList) - - okdd := &optimizedKadDhtDiscoverer{ - sharder: sharder, - peersRefreshInterval: arg.PeersRefreshInterval, - seedersReconnectionInterval: arg.SeedersReconnectionInterval, - protocolID: arg.ProtocolID, - initialPeersList: arg.InitialPeersList, - bucketSize: arg.BucketSize, - routingTableRefresh: arg.RoutingTableRefresh, - status: statNotInitialized, - chanInit: make(chan struct{}), - errChanInit: make(chan error), - chanConnectToSeeders: make(chan struct{}), - connectionWatcher: arg.ConnectionWatcher, - } - - okdd.createKadDhtHandler = okdd.createKadDht - args := ArgsHostWithConnectionManagement{ - ConnectableHost: arg.Host, - Sharder: okdd.sharder, - ConnectionsWatcher: okdd.connectionWatcher, - } - okdd.hostConnManagement, err = NewHostWithConnectionManagement(args) - if err != nil { - return nil, err - } - - go okdd.processLoop(arg.Context) - - return okdd, nil -} - -// Bootstrap will start the bootstrapping new peers process -func (okdd *optimizedKadDhtDiscoverer) Bootstrap() error { - okdd.chanInit <- struct{}{} - return <-okdd.errChanInit -} - -func (okdd *optimizedKadDhtDiscoverer) processLoop(ctx context.Context) { - chTimeSeedersReconnect := time.After(okdd.seedersReconnectionInterval) - chTimeFindPeers := time.After(okdd.peersRefreshInterval) - - for { - select { - case <-okdd.chanInit: - chTimeSeedersReconnect = okdd.processInit(ctx) - - case <-chTimeSeedersReconnect: - chTimeSeedersReconnect = okdd.processSeedersReconnect(ctx) - - case <-okdd.chanConnectToSeeders: - chTimeSeedersReconnect = okdd.processSeedersReconnect(ctx) - - case <-chTimeFindPeers: - okdd.findPeers(ctx) - chTimeFindPeers = time.After(okdd.peersRefreshInterval) - - case <-ctx.Done(): - log.Debug("closing the p2p bootstrapping process") - - okdd.finishMainLoopProcessing(ctx) - return - } - } -} - -func (okdd *optimizedKadDhtDiscoverer) processInit(ctx context.Context) <-chan time.Time { - err := okdd.init(ctx) - okdd.errChanInit <- err - if err != nil { - return okdd.createChTimeSeedersReconnect(false) - } - - ch := okdd.processSeedersReconnect(ctx) - okdd.findPeers(ctx) - - return ch -} - -func (okdd *optimizedKadDhtDiscoverer) processSeedersReconnect(ctx context.Context) <-chan time.Time { - isConnectedToSeeders := okdd.tryToReconnectAtLeastToASeeder(ctx) - return okdd.createChTimeSeedersReconnect(isConnectedToSeeders) -} - -func (okdd *optimizedKadDhtDiscoverer) finishMainLoopProcessing(ctx context.Context) { - select { - case okdd.errChanInit <- ctx.Err(): - default: - } -} - -func (okdd *optimizedKadDhtDiscoverer) createChTimeSeedersReconnect(isConnectedToSeeders bool) <-chan time.Time { - if isConnectedToSeeders { - // the reconnection will be done less often - return time.After(okdd.seedersReconnectionInterval) - } - - // no connection to seeders, let's try a little bit faster - return time.After(okdd.peersRefreshInterval) - } - -func (okdd *optimizedKadDhtDiscoverer) init(ctx context.Context) error { - if okdd.status != statNotInitialized { - return p2p.ErrPeerDiscoveryProcessAlreadyStarted - } - - kadDhtHandler, err := okdd.createKadDhtHandler(ctx) - if err != nil { - return err - } - - okdd.kadDHT = kadDhtHandler - okdd.status = statInitialized - - return nil -}
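For context on the deleted optimizedKadDhtDiscoverer above: it serializes all state changes through a single process-loop goroutine, so Bootstrap needs no lock; it posts to chanInit and waits on errChanInit, while two time.After channels pace seeder reconnection and peer refresh. Below is a minimal, runnable sketch of that channel-driven loop pattern only; every identifier in it (loopDiscoverer, newLoopDiscoverer, refresh) is invented for illustration and does not come from the sources in this diff.

package main

import (
	"context"
	"fmt"
	"time"
)

// loopDiscoverer is a hypothetical stand-in for the discoverer: one goroutine
// owns all mutable state and everything else talks to it over channels.
type loopDiscoverer struct {
	chanInit    chan struct{}
	errChanInit chan error
	refresh     time.Duration
}

func newLoopDiscoverer(ctx context.Context, refresh time.Duration) *loopDiscoverer {
	ld := &loopDiscoverer{
		chanInit:    make(chan struct{}),
		errChanInit: make(chan error),
		refresh:     refresh,
	}
	go ld.processLoop(ctx) // the single goroutine that owns the state
	return ld
}

// Bootstrap hands the request to the loop goroutine and waits for its answer,
// mirroring the send-on-chanInit / receive-on-errChanInit handshake above.
func (ld *loopDiscoverer) Bootstrap() error {
	ld.chanInit <- struct{}{}
	return <-ld.errChanInit
}

func (ld *loopDiscoverer) processLoop(ctx context.Context) {
	chTimeFindPeers := time.After(ld.refresh)
	initialized := false

	for {
		select {
		case <-ld.chanInit:
			initialized = true
			ld.errChanInit <- nil // report the init result back to Bootstrap

		case <-chTimeFindPeers:
			if initialized {
				fmt.Println("refreshing peers...") // placeholder for kadDHT.Bootstrap(ctx)
			}
			chTimeFindPeers = time.After(ld.refresh) // re-arm the timer

		case <-ctx.Done():
			return // cancellation ends the loop, as in the original processLoop
		}
	}
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	ld := newLoopDiscoverer(ctx, 100*time.Millisecond)
	if err := ld.Bootstrap(); err != nil {
		fmt.Println("bootstrap failed:", err)
		return
	}
	time.Sleep(350 * time.Millisecond) // let a few refresh ticks fire
}

The point of the pattern is that a single select loop owns the state, so re-arming the reconnection timer faster when no seeder is reachable (createChTimeSeedersReconnect above) is just a matter of returning a different time.After channel from the case handler.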
- -func (okdd *optimizedKadDhtDiscoverer) createKadDht(ctx context.Context) (KadDhtHandler, error) { - protocolID := protocol.ID(okdd.protocolID) - return dht.New( - ctx, - okdd.hostConnManagement, - dht.ProtocolPrefix(protocolID), - dht.RoutingTableRefreshPeriod(okdd.routingTableRefresh), - dht.Mode(dht.ModeServer), - ) -} - -func (okdd *optimizedKadDhtDiscoverer) tryToReconnectAtLeastToASeeder(ctx context.Context) bool { - if okdd.status != statInitialized { - return false - } - - if len(okdd.initialPeersList) == 0 { - return true - } - - connectedToOneSeeder := false - for _, seederAddress := range okdd.initialPeersList { - err := okdd.connectToSeeder(ctx, seederAddress) - if err != nil { - printConnectionErrorToSeeder(seederAddress, err) - } else { - connectedToOneSeeder = true - } - - select { - case <-ctx.Done(): - log.Debug("optimizedKadDhtDiscoverer.tryToReconnectAtLeastToASeeder", - "num seeders", len(okdd.initialPeersList), "connected to a seeder", true, "context", "done") - return true - default: - } - } - - log.Debug("optimizedKadDhtDiscoverer.tryToReconnectAtLeastToASeeder", - "num seeders", len(okdd.initialPeersList), "connected to a seeder", connectedToOneSeeder) - - return connectedToOneSeeder -} - -func (okdd *optimizedKadDhtDiscoverer) connectToSeeder(ctx context.Context, seederAddress string) error { - seederInfo, err := okdd.hostConnManagement.AddressToPeerInfo(seederAddress) - if err != nil { - return err - } - - if okdd.hostConnManagement.IsConnected(*seederInfo) { - return nil - } - - return okdd.hostConnManagement.Connect(ctx, *seederInfo) -} - -func (okdd *optimizedKadDhtDiscoverer) findPeers(ctx context.Context) { - if okdd.status != statInitialized { - return - } - - err := okdd.kadDHT.Bootstrap(ctx) - if err != nil { - log.Debug("kad dht bootstrap", "error", err) - } -} - -// Name returns the name of the kad dht peer discovery implementation -func (okdd *optimizedKadDhtDiscoverer) Name() string { - return optimizedKadDhtName -} - -// ReconnectToNetwork will try to connect to one peer from the initial peer list -func (okdd *optimizedKadDhtDiscoverer) ReconnectToNetwork(_ context.Context) { - select { - case okdd.chanConnectToSeeders <- struct{}{}: - default: - } -} - -// IsInterfaceNil returns true if there is no value under the interface -func (okdd *optimizedKadDhtDiscoverer) IsInterfaceNil() bool { - return okdd == nil -} diff --git a/p2p/libp2p/discovery/optimizedKadDhtDiscoverer_test.go b/p2p/libp2p/discovery/optimizedKadDhtDiscoverer_test.go deleted file mode 100644 index bb8304ce444..00000000000 --- a/p2p/libp2p/discovery/optimizedKadDhtDiscoverer_test.go +++ /dev/null @@ -1,263 +0,0 @@ -package discovery_test - -import ( - "context" - "errors" - "sync" - "sync/atomic" - "testing" - "time" - - "github.com/ElrondNetwork/elrond-go-core/core/check" - "github.com/ElrondNetwork/elrond-go/p2p" - "github.com/ElrondNetwork/elrond-go/p2p/libp2p/discovery" - "github.com/ElrondNetwork/elrond-go/p2p/mock" - "github.com/libp2p/go-libp2p-core/peer" - "github.com/stretchr/testify/assert" -) - -func TestNewOptimizedKadDhtDiscoverer(t *testing.T) { - t.Parallel() - - t.Run("invalid argument should error", func(t *testing.T) { - t.Parallel() - - arg := createTestArgument() - arg.Host = nil - okdd, err := discovery.NewOptimizedKadDhtDiscoverer(arg) - assert.Equal(t, p2p.ErrNilHost, err) - assert.True(t, check.IfNil(okdd)) - - arg = createTestArgument() - arg.SeedersReconnectionInterval = 0 - okdd, err = discovery.NewOptimizedKadDhtDiscoverer(arg) - assert.Equal(t, 
p2p.ErrInvalidSeedersReconnectionInterval, err) - assert.True(t, check.IfNil(okdd)) - }) - t.Run("should work", func(t *testing.T) { - t.Parallel() - - arg := createTestArgument() - var cancelFunc func() - arg.Context, cancelFunc = context.WithCancel(context.Background()) - okdd, err := discovery.NewOptimizedKadDhtDiscoverer(arg) - - assert.Nil(t, err) - assert.False(t, check.IfNil(okdd)) - cancelFunc() - - assert.Equal(t, discovery.OptimizedKadDhtName, okdd.Name()) - }) -} - -func TestOptimizedKadDhtDiscoverer_BootstrapWithRealKadDhtFuncShouldNotError(t *testing.T) { - t.Parallel() - - arg := createTestArgument() - arg.InitialPeersList = make([]string, 0) - var cancelFunc func() - arg.Context, cancelFunc = context.WithCancel(context.Background()) - okdd, _ := discovery.NewOptimizedKadDhtDiscoverer(arg) - - err := okdd.Bootstrap() - - assert.Nil(t, err) - cancelFunc() -} - -func TestOptimizedKadDhtDiscoverer_BootstrapEmptyPeerListShouldStartBootstrap(t *testing.T) { - t.Parallel() - - arg := createTestArgument() - arg.InitialPeersList = make([]string, 0) - var cancelFunc func() - arg.Context, cancelFunc = context.WithCancel(context.Background()) - bootstrapCalled := uint32(0) - kadDhtStub := &mock.KadDhtHandlerStub{ - BootstrapCalled: func(ctx context.Context) error { - atomic.AddUint32(&bootstrapCalled, 1) - return nil - }, - } - - okdd, _ := discovery.NewOptimizedKadDhtDiscovererWithInitFunc( - arg, - func(ctx context.Context) (discovery.KadDhtHandler, error) { - return kadDhtStub, nil - }, - ) - - err := okdd.Bootstrap() - // a little delay as the bootstrap returns immediately after init. The seeder reconnection and bootstrap part - // are called async - time.Sleep(time.Second + time.Millisecond*500) // the value is chosen as such as to avoid edgecases on select statement - - assert.Nil(t, err) - assert.Equal(t, uint32(2), atomic.LoadUint32(&bootstrapCalled)) - cancelFunc() -} - -func TestOptimizedKadDhtDiscoverer_BootstrapWithPeerListShouldStartBootstrap(t *testing.T) { - t.Parallel() - - arg := createTestArgument() - arg.SeedersReconnectionInterval = time.Second - bootstrapCalled := uint32(0) - connectCalled := uint32(0) - arg.Host = &mock.ConnectableHostStub{ - ConnectCalled: func(ctx context.Context, pi peer.AddrInfo) error { - atomic.AddUint32(&connectCalled, 1) - return nil - }, - AddressToPeerInfoCalled: func(address string) (*peer.AddrInfo, error) { - return &peer.AddrInfo{}, nil - }, - } - var cancelFunc func() - arg.Context, cancelFunc = context.WithCancel(context.Background()) - - kadDhtStub := &mock.KadDhtHandlerStub{ - BootstrapCalled: func(ctx context.Context) error { - atomic.AddUint32(&bootstrapCalled, 1) - return nil - }, - } - - okdd, _ := discovery.NewOptimizedKadDhtDiscovererWithInitFunc( - arg, - func(ctx context.Context) (discovery.KadDhtHandler, error) { - return kadDhtStub, nil - }, - ) - - err := okdd.Bootstrap() - time.Sleep(time.Second*4 + time.Millisecond*500) // the value is chosen as such as to avoid edgecases on select statement - cancelFunc() - - assert.Nil(t, err) - assert.Equal(t, uint32(5), atomic.LoadUint32(&bootstrapCalled)) - assert.Equal(t, uint32(10), atomic.LoadUint32(&connectCalled)) -} - -func TestOptimizedKadDhtDiscoverer_BootstrapErrorsShouldKeepRetrying(t *testing.T) { - t.Parallel() - - arg := createTestArgument() - var cancelFunc func() - arg.Context, cancelFunc = context.WithCancel(context.Background()) - bootstrapCalled := uint32(0) - expectedErr := errors.New("expected error") - kadDhtStub := &mock.KadDhtHandlerStub{ - BootstrapCalled: 
func(ctx context.Context) error { - atomic.AddUint32(&bootstrapCalled, 1) - return expectedErr - }, - } - - okdd, _ := discovery.NewOptimizedKadDhtDiscovererWithInitFunc( - arg, - func(ctx context.Context) (discovery.KadDhtHandler, error) { - return kadDhtStub, nil - }, - ) - - err := okdd.Bootstrap() - // a little delay as the bootstrap returns immediately after init. The seeder reconnection and bootstrap part - // are called async - time.Sleep(time.Second*4 + time.Millisecond*500) // the value is chosen as such as to avoid edgecases on select statement - - assert.Nil(t, err) - assert.Equal(t, uint32(5), atomic.LoadUint32(&bootstrapCalled)) - cancelFunc() -} - -func TestOptimizedKadDhtDiscoverer_BootstrapErrorsForSeedersShouldRetryFast(t *testing.T) { - t.Parallel() - - numConnectCalls := uint32(0) - arg := createTestArgument() - arg.Host = &mock.ConnectableHostStub{ - ConnectCalled: func(ctx context.Context, pi peer.AddrInfo) error { - atomic.AddUint32(&numConnectCalls, 1) - return errors.New("cannot connect") - }, - } - arg.InitialPeersList = []string{"/ip4/127.0.0.1/tcp/9999/p2p/16Uiu2HAkw5SNNtSvH1zJiQ6Gc3WoGNSxiyNueRKe6fuAuh57G3Bk"} - var cancelFunc func() - arg.Context, cancelFunc = context.WithCancel(context.Background()) - kadDhtStub := &mock.KadDhtHandlerStub{ - BootstrapCalled: func(ctx context.Context) error { - return nil - }, - } - - okdd, _ := discovery.NewOptimizedKadDhtDiscovererWithInitFunc( - arg, - func(ctx context.Context) (discovery.KadDhtHandler, error) { - return kadDhtStub, nil - }, - ) - - err := okdd.Bootstrap() - // a little delay as the bootstrap returns immediately after init. The seeder reconnection and bootstrap part - // are called async - time.Sleep(time.Second*4 + time.Millisecond*500) // the value is chosen as such as to avoid edgecases on select statement - - assert.Nil(t, err) - assert.True(t, atomic.LoadUint32(&numConnectCalls) > 1) - cancelFunc() -} - -func TestOptimizedKadDhtDiscoverer_ReconnectToNetwork(t *testing.T) { - if testing.Short() { - t.Skip("this is not a short test") - } - - t.Parallel() - - arg := createTestArgument() - var cancelFunc func() - arg.Context, cancelFunc = context.WithCancel(context.Background()) - bootstrapCalled := uint32(0) - expectedErr := errors.New("expected error") - mutConnect := sync.Mutex{} - connectCalled := 0 - arg.Host = &mock.ConnectableHostStub{ - ConnectCalled: func(ctx context.Context, pi peer.AddrInfo) error { - mutConnect.Lock() - defer mutConnect.Unlock() - - connectCalled++ - - return nil - }, - AddressToPeerInfoCalled: func(address string) (*peer.AddrInfo, error) { - return &peer.AddrInfo{}, nil - }, - } - kadDhtStub := &mock.KadDhtHandlerStub{ - BootstrapCalled: func(ctx context.Context) error { - atomic.AddUint32(&bootstrapCalled, 1) - return expectedErr - }, - } - - okdd, _ := discovery.NewOptimizedKadDhtDiscovererWithInitFunc( - arg, - func(ctx context.Context) (discovery.KadDhtHandler, error) { - return kadDhtStub, nil - }, - ) - - err := okdd.Bootstrap() - time.Sleep(time.Second) - okdd.ReconnectToNetwork(context.Background()) - time.Sleep(time.Millisecond * 500) // the value is chosen as such as to avoid edge cases on select statement - cancelFunc() - - assert.Nil(t, err) - assert.Equal(t, uint32(2), atomic.LoadUint32(&bootstrapCalled)) - mutConnect.Lock() - assert.True(t, connectCalled > 0) - mutConnect.Unlock() -} diff --git a/p2p/libp2p/export_test.go b/p2p/libp2p/export_test.go deleted file mode 100644 index 320766d2111..00000000000 --- a/p2p/libp2p/export_test.go +++ /dev/null @@ -1,94 +0,0 
@@ -package libp2p - -import ( - "context" - - "github.com/ElrondNetwork/elrond-go/p2p" - "github.com/ElrondNetwork/elrond-go/storage" - "github.com/ElrondNetwork/go-libp2p-pubsub" - pb "github.com/ElrondNetwork/go-libp2p-pubsub/pb" - "github.com/libp2p/go-libp2p-core/peer" - "github.com/whyrusleeping/timecache" -) - -var MaxSendBuffSize = maxSendBuffSize -var BroadcastGoRoutines = broadcastGoRoutines -var PubsubTimeCacheDuration = pubsubTimeCacheDuration -var AcceptMessagesInAdvanceDuration = acceptMessagesInAdvanceDuration - -const CurrentTopicMessageVersion = currentTopicMessageVersion -const PollWaitForConnectionsInterval = pollWaitForConnectionsInterval - -// SetHost - -func (netMes *networkMessenger) SetHost(newHost ConnectableHost) { - netMes.p2pHost = newHost -} - -// SetLoadBalancer - -func (netMes *networkMessenger) SetLoadBalancer(outgoingPLB p2p.ChannelLoadBalancer) { - netMes.outgoingPLB = outgoingPLB -} - -// SetPeerDiscoverer - -func (netMes *networkMessenger) SetPeerDiscoverer(discoverer p2p.PeerDiscoverer) { - netMes.peerDiscoverer = discoverer -} - -// PubsubCallback - -func (netMes *networkMessenger) PubsubCallback(handler p2p.MessageProcessor, topic string) func(ctx context.Context, pid peer.ID, message *pubsub.Message) bool { - topicProcs := newTopicProcessors() - _ = topicProcs.addTopicProcessor("identifier", handler) - - return netMes.pubsubCallback(topicProcs, topic) -} - -// ValidMessageByTimestamp - -func (netMes *networkMessenger) ValidMessageByTimestamp(msg p2p.MessageP2P) error { - return netMes.validMessageByTimestamp(msg) -} - -// MapHistogram - -func (netMes *networkMessenger) MapHistogram(input map[uint32]int) string { - return netMes.mapHistogram(input) -} - -// PubsubHasTopic - -func (netMes *networkMessenger) PubsubHasTopic(expectedTopic string) bool { - netMes.mutTopics.RLock() - topics := netMes.pb.GetTopics() - netMes.mutTopics.RUnlock() - - for _, topic := range topics { - if topic == expectedTopic { - return true - } - } - return false -} - -// HasProcessorForTopic - -func (netMes *networkMessenger) HasProcessorForTopic(expectedTopic string) bool { - processor, found := netMes.processors[expectedTopic] - - return found && processor != nil -} - -// ProcessReceivedDirectMessage - -func (ds *directSender) ProcessReceivedDirectMessage(message *pb.Message, fromConnectedPeer peer.ID) error { - return ds.processReceivedDirectMessage(message, fromConnectedPeer) -} - -// SeenMessages - -func (ds *directSender) SeenMessages() *timecache.TimeCache { - return ds.seenMessages -} - -// Counter - -func (ds *directSender) Counter() uint64 { - return ds.counter -} - -// Mutexes - -func (mh *MutexHolder) Mutexes() storage.Cacher { - return mh.mutexes -} diff --git a/p2p/libp2p/interface.go b/p2p/libp2p/interface.go deleted file mode 100644 index fae9fb214df..00000000000 --- a/p2p/libp2p/interface.go +++ /dev/null @@ -1,29 +0,0 @@ -package libp2p - -import ( - "github.com/ElrondNetwork/elrond-go-core/core" - "github.com/ElrondNetwork/elrond-go/p2p" - "github.com/libp2p/go-libp2p-core/network" -) - -// ConnectionMonitor defines the behavior of a connection monitor -type ConnectionMonitor interface { - network.Notifiee - IsConnectedToTheNetwork(netw network.Network) bool - SetThresholdMinConnectedPeers(thresholdMinConnectedPeers int, netw network.Network) - ThresholdMinConnectedPeers() int - Close() error - IsInterfaceNil() bool -} - -// PeerDiscovererWithSharder extends the PeerDiscoverer with the possibility to set the sharder -type PeerDiscovererWithSharder interface 
{ - p2p.PeerDiscoverer - SetSharder(sharder p2p.Sharder) error -} - -type p2pSigner interface { - Sign(payload []byte) ([]byte, error) - Verify(payload []byte, pid core.PeerID, signature []byte) error - SignUsingPrivateKey(skBytes []byte, payload []byte) ([]byte, error) -} diff --git a/p2p/libp2p/issues_test.go b/p2p/libp2p/issues_test.go deleted file mode 100644 index 4335d744d36..00000000000 --- a/p2p/libp2p/issues_test.go +++ /dev/null @@ -1,113 +0,0 @@ -package libp2p_test - -import ( - "bytes" - "fmt" - "sync/atomic" - "testing" - "time" - - "github.com/ElrondNetwork/elrond-go-core/core" - "github.com/ElrondNetwork/elrond-go/config" - "github.com/ElrondNetwork/elrond-go/p2p" - "github.com/ElrondNetwork/elrond-go/p2p/libp2p" - "github.com/ElrondNetwork/elrond-go/p2p/mock" - "github.com/ElrondNetwork/elrond-go/testscommon" - "github.com/ElrondNetwork/elrond-go/testscommon/p2pmocks" - "github.com/stretchr/testify/assert" -) - -func createMessenger() p2p.Messenger { - args := libp2p.ArgsNetworkMessenger{ - Marshalizer: &testscommon.ProtoMarshalizerMock{}, - ListenAddress: libp2p.ListenLocalhostAddrWithIp4AndTcp, - P2pConfig: config.P2PConfig{ - Node: config.NodeConfig{ - Port: "0", - }, - KadDhtPeerDiscovery: config.KadDhtPeerDiscoveryConfig{ - Enabled: false, - }, - Sharding: config.ShardingConfig{ - Type: p2p.NilListSharder, - }, - }, - SyncTimer: &libp2p.LocalSyncTimer{}, - PreferredPeersHolder: &p2pmocks.PeersHolderStub{}, - NodeOperationMode: p2p.NormalOperation, - PeersRatingHandler: &p2pmocks.PeersRatingHandlerStub{}, - ConnectionWatcherType: p2p.ConnectionWatcherTypePrint, - } - - libP2PMes, err := libp2p.NewNetworkMessenger(args) - if err != nil { - fmt.Println(err.Error()) - } - - return libP2PMes -} - -// TestIssueEN898_StreamResetError emphasizes what happens if direct sender writes to a stream that has been reset -// Testing is done by writing a large buffer that will cause the recipient to reset its inbound stream -// Sender will then be notified that the stream writing did not succeed but it will only log the error -// Next message that the sender tries to send will cause a new error to be logged and no data to be sent -// The fix consists in the full stream closing when an error occurs during writing. 
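The deleted comment above captures the reasoning behind the EN898 fix, which the test that follows exercises: once a write on an outbound stream fails, the stream must be torn down completely so that the next direct send opens a fresh stream instead of reusing the broken one. A minimal sketch of that pattern, assuming a go-libp2p network.Stream; sendOnStream is a hypothetical helper for illustration, not part of the deleted directSender code:

package libp2p

import (
	"github.com/libp2p/go-libp2p-core/network"
)

// sendOnStream writes buff on the stream and resets the stream on any write
// error, forcing the next send to open a new, healthy stream.
func sendOnStream(s network.Stream, buff []byte) error {
	if _, err := s.Write(buff); err != nil {
		// Reset tears down both directions of the stream; returning the
		// error without resetting would leave a half-broken stream cached
		// for reuse, which is exactly the failure mode described above.
		_ = s.Reset()
		return err
	}
	return nil
}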
-func TestIssueEN898_StreamResetError(t *testing.T) { - if testing.Short() { - t.Skip("this is not a short test") - } - - mes1 := createMessenger() - mes2 := createMessenger() - - defer func() { - _ = mes1.Close() - _ = mes2.Close() - }() - - _ = mes1.ConnectToPeer(getConnectableAddress(mes2)) - - topic := "test topic" - - size4MB := 1 << 22 - size4kB := 1 << 12 - - // a 4MB slice containing character A - largePacket := bytes.Repeat([]byte{65}, size4MB) - // a 4kB slice containing character B - smallPacket := bytes.Repeat([]byte{66}, size4kB) - - largePacketReceived := &atomic.Value{} - largePacketReceived.Store(false) - - smallPacketReceived := &atomic.Value{} - smallPacketReceived.Store(false) - - _ = mes2.CreateTopic(topic, false) - _ = mes2.RegisterMessageProcessor(topic, "identifier", &mock.MessageProcessorStub{ - ProcessMessageCalled: func(message p2p.MessageP2P, _ core.PeerID) error { - if bytes.Equal(message.Data(), largePacket) { - largePacketReceived.Store(true) - } - - if bytes.Equal(message.Data(), smallPacket) { - smallPacketReceived.Store(true) - } - - return nil - }, - }) - - fmt.Println("sending the large packet...") - _ = mes1.SendToConnectedPeer(topic, largePacket, mes2.ID()) - - time.Sleep(time.Second) - - fmt.Println("sending the small packet...") - _ = mes1.SendToConnectedPeer(topic, smallPacket, mes2.ID()) - - time.Sleep(time.Second) - - assert.False(t, largePacketReceived.Load().(bool)) - assert.True(t, smallPacketReceived.Load().(bool)) -} diff --git a/p2p/libp2p/localSyncTimer.go b/p2p/libp2p/localSyncTimer.go deleted file mode 100644 index 9b3644a982c..00000000000 --- a/p2p/libp2p/localSyncTimer.go +++ /dev/null @@ -1,17 +0,0 @@ -package libp2p - -import "time" - -// LocalSyncTimer uses the local system to provide the current time -type LocalSyncTimer struct { -} - -// CurrentTime returns the local current time -func (lst *LocalSyncTimer) CurrentTime() time.Time { - return time.Now() -} - -// IsInterfaceNil returns true if there is no value under the interface -func (lst *LocalSyncTimer) IsInterfaceNil() bool { - return lst == nil -} diff --git a/p2p/libp2p/message.go b/p2p/libp2p/message.go deleted file mode 100644 index 2cff6bd3c47..00000000000 --- a/p2p/libp2p/message.go +++ /dev/null @@ -1,65 +0,0 @@ -package libp2p - -import ( - "fmt" - - "github.com/ElrondNetwork/elrond-go-core/core" - "github.com/ElrondNetwork/elrond-go-core/core/check" - "github.com/ElrondNetwork/elrond-go/p2p" - "github.com/ElrondNetwork/elrond-go/p2p/data" - "github.com/ElrondNetwork/elrond-go/p2p/message" - pubsub "github.com/ElrondNetwork/go-libp2p-pubsub" - "github.com/libp2p/go-libp2p-core/peer" -) - -const currentTopicMessageVersion = uint32(1) - -// NewMessage returns a new instance of a Message object -func NewMessage(msg *pubsub.Message, marshalizer p2p.Marshalizer) (*message.Message, error) { - if check.IfNil(marshalizer) { - return nil, p2p.ErrNilMarshalizer - } - if msg == nil { - return nil, p2p.ErrNilMessage - } - if msg.Topic == nil { - return nil, p2p.ErrNilTopic - } - - newMsg := &message.Message{ - FromField: msg.From, - PayloadField: msg.Data, - SeqNoField: msg.Seqno, - TopicField: *msg.Topic, - SignatureField: msg.Signature, - KeyField: msg.Key, - } - - topicMessage := &data.TopicMessage{} - err := marshalizer.Unmarshal(topicMessage, msg.Data) - if err != nil { - return nil, fmt.Errorf("%w error: %s", p2p.ErrMessageUnmarshalError, err.Error()) - } - - // TODO change this area when new versions of the message will need to be implemented - if topicMessage.Version != 
currentTopicMessageVersion { - return nil, fmt.Errorf("%w, supported %d, got %d", - p2p.ErrUnsupportedMessageVersion, currentTopicMessageVersion, topicMessage.Version) - } - - if len(topicMessage.SignatureOnPid)+len(topicMessage.Pk) > 0 { - return nil, fmt.Errorf("%w for topicMessage.SignatureOnPid and topicMessage.Pk", - p2p.ErrUnsupportedFields) - } - - newMsg.DataField = topicMessage.Payload - newMsg.TimestampField = topicMessage.Timestamp - - id, err := peer.IDFromBytes(newMsg.From()) - if err != nil { - return nil, err - } - - newMsg.PeerField = core.PeerID(id) - return newMsg, nil -} diff --git a/p2p/libp2p/message_test.go b/p2p/libp2p/message_test.go deleted file mode 100644 index 3a8bfb63ef0..00000000000 --- a/p2p/libp2p/message_test.go +++ /dev/null @@ -1,251 +0,0 @@ -package libp2p_test - -import ( - "crypto/ecdsa" - "crypto/rand" - "errors" - "testing" - "time" - - "github.com/ElrondNetwork/elrond-go-core/core" - "github.com/ElrondNetwork/elrond-go-core/core/check" - "github.com/ElrondNetwork/elrond-go/p2p" - "github.com/ElrondNetwork/elrond-go/p2p/data" - "github.com/ElrondNetwork/elrond-go/p2p/libp2p" - "github.com/ElrondNetwork/elrond-go/testscommon" - "github.com/ElrondNetwork/go-libp2p-pubsub" - pb "github.com/ElrondNetwork/go-libp2p-pubsub/pb" - "github.com/btcsuite/btcd/btcec" - libp2pCrypto "github.com/libp2p/go-libp2p-core/crypto" - "github.com/libp2p/go-libp2p-core/peer" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func getRandomID() []byte { - prvKey, _ := ecdsa.GenerateKey(btcec.S256(), rand.Reader) - sk := (*libp2pCrypto.Secp256k1PrivateKey)(prvKey) - id, _ := peer.IDFromPublicKey(sk.GetPublic()) - - return []byte(id) -} - -func TestMessage_NilMarshalizerShouldErr(t *testing.T) { - t.Parallel() - - pMes := &pubsub.Message{} - m, err := libp2p.NewMessage(pMes, nil) - - assert.True(t, check.IfNil(m)) - assert.True(t, errors.Is(err, p2p.ErrNilMarshalizer)) -} - -func TestMessage_ShouldErrBecauseOfFromField(t *testing.T) { - t.Parallel() - - from := []byte("dummy from") - marshalizer := &testscommon.ProtoMarshalizerMock{} - - topicMessage := &data.TopicMessage{ - Version: libp2p.CurrentTopicMessageVersion, - Timestamp: time.Now().Unix(), - Payload: []byte("data"), - } - buff, _ := marshalizer.Marshal(topicMessage) - topic := "topic" - mes := &pb.Message{ - From: from, - Data: buff, - Topic: &topic, - } - pMes := &pubsub.Message{Message: mes} - m, err := libp2p.NewMessage(pMes, marshalizer) - - assert.True(t, check.IfNil(m)) - assert.NotNil(t, err) -} - -func TestMessage_ShouldWork(t *testing.T) { - t.Parallel() - - marshalizer := &testscommon.ProtoMarshalizerMock{} - topicMessage := &data.TopicMessage{ - Version: libp2p.CurrentTopicMessageVersion, - Timestamp: time.Now().Unix(), - Payload: []byte("data"), - } - buff, _ := marshalizer.Marshal(topicMessage) - topic := "topic" - mes := &pb.Message{ - From: getRandomID(), - Data: buff, - Topic: &topic, - } - - pMes := &pubsub.Message{Message: mes} - m, err := libp2p.NewMessage(pMes, marshalizer) - - require.Nil(t, err) - assert.False(t, check.IfNil(m)) -} - -func TestMessage_From(t *testing.T) { - t.Parallel() - - from := getRandomID() - marshalizer := &testscommon.ProtoMarshalizerMock{} - topicMessage := &data.TopicMessage{ - Version: libp2p.CurrentTopicMessageVersion, - Timestamp: time.Now().Unix(), - Payload: []byte("data"), - } - buff, _ := marshalizer.Marshal(topicMessage) - topic := "topic" - mes := &pb.Message{ - From: from, - Data: buff, - Topic: &topic, - } - pMes := 
&pubsub.Message{Message: mes} - m, err := libp2p.NewMessage(pMes, marshalizer) - - require.Nil(t, err) - assert.Equal(t, m.From(), from) -} - -func TestMessage_Peer(t *testing.T) { - t.Parallel() - - id := getRandomID() - marshalizer := &testscommon.ProtoMarshalizerMock{} - - topicMessage := &data.TopicMessage{ - Version: libp2p.CurrentTopicMessageVersion, - Timestamp: time.Now().Unix(), - Payload: []byte("data"), - } - buff, _ := marshalizer.Marshal(topicMessage) - topic := "topic" - mes := &pb.Message{ - From: id, - Data: buff, - Topic: &topic, - } - pMes := &pubsub.Message{Message: mes} - m, err := libp2p.NewMessage(pMes, marshalizer) - - require.Nil(t, err) - assert.Equal(t, core.PeerID(id), m.Peer()) -} - -func TestMessage_WrongVersionShouldErr(t *testing.T) { - t.Parallel() - - marshalizer := &testscommon.ProtoMarshalizerMock{} - - topicMessage := &data.TopicMessage{ - Version: libp2p.CurrentTopicMessageVersion + 1, - Timestamp: time.Now().Unix(), - Payload: []byte("data"), - } - buff, _ := marshalizer.Marshal(topicMessage) - topic := "topic" - mes := &pb.Message{ - From: getRandomID(), - Data: buff, - Topic: &topic, - } - - pMes := &pubsub.Message{Message: mes} - m, err := libp2p.NewMessage(pMes, marshalizer) - - assert.True(t, check.IfNil(m)) - assert.True(t, errors.Is(err, p2p.ErrUnsupportedMessageVersion)) -} - -func TestMessage_PopulatedPkFieldShouldErr(t *testing.T) { - t.Parallel() - - marshalizer := &testscommon.ProtoMarshalizerMock{} - - topicMessage := &data.TopicMessage{ - Version: libp2p.CurrentTopicMessageVersion, - Timestamp: time.Now().Unix(), - Payload: []byte("data"), - Pk: []byte("p"), - } - buff, _ := marshalizer.Marshal(topicMessage) - topic := "topic" - mes := &pb.Message{ - From: getRandomID(), - Data: buff, - Topic: &topic, - } - - pMes := &pubsub.Message{Message: mes} - m, err := libp2p.NewMessage(pMes, marshalizer) - - assert.True(t, check.IfNil(m)) - assert.True(t, errors.Is(err, p2p.ErrUnsupportedFields)) -} - -func TestMessage_PopulatedSigFieldShouldErr(t *testing.T) { - t.Parallel() - - marshalizer := &testscommon.ProtoMarshalizerMock{} - - topicMessage := &data.TopicMessage{ - Version: libp2p.CurrentTopicMessageVersion, - Timestamp: time.Now().Unix(), - Payload: []byte("data"), - SignatureOnPid: []byte("s"), - } - buff, _ := marshalizer.Marshal(topicMessage) - topic := "topic" - mes := &pb.Message{ - From: getRandomID(), - Data: buff, - Topic: &topic, - } - - pMes := &pubsub.Message{Message: mes} - m, err := libp2p.NewMessage(pMes, marshalizer) - - assert.True(t, check.IfNil(m)) - assert.True(t, errors.Is(err, p2p.ErrUnsupportedFields)) -} - -func TestMessage_NilTopic(t *testing.T) { - t.Parallel() - - id := getRandomID() - marshalizer := &testscommon.ProtoMarshalizerMock{} - - topicMessage := &data.TopicMessage{ - Version: libp2p.CurrentTopicMessageVersion, - Timestamp: time.Now().Unix(), - Payload: []byte("data"), - } - buff, _ := marshalizer.Marshal(topicMessage) - mes := &pb.Message{ - From: id, - Data: buff, - Topic: nil, - } - pMes := &pubsub.Message{Message: mes} - m, err := libp2p.NewMessage(pMes, marshalizer) - - assert.Equal(t, p2p.ErrNilTopic, err) - assert.True(t, check.IfNil(m)) -} - -func TestMessage_NilMessage(t *testing.T) { - t.Parallel() - - marshalizer := &testscommon.ProtoMarshalizerMock{} - - m, err := libp2p.NewMessage(nil, marshalizer) - - assert.Equal(t, p2p.ErrNilMessage, err) - assert.True(t, check.IfNil(m)) -} diff --git a/p2p/libp2p/metrics/connections.go b/p2p/libp2p/metrics/connections.go deleted file mode 100644 index 
865a7ab267c..00000000000 --- a/p2p/libp2p/metrics/connections.go +++ /dev/null @@ -1,54 +0,0 @@ -package metrics - -import ( - "sync/atomic" - - "github.com/libp2p/go-libp2p-core/network" - "github.com/multiformats/go-multiaddr" -) - -// Connections is a metric that counts connections and disconnections done by the host implementation -type Connections struct { - numConnections uint32 - numDisconnections uint32 -} - -// NewConnections returns a new connsDisconnsMetric instance -func NewConnections() *Connections { - return &Connections{ - numConnections: 0, - numDisconnections: 0, - } -} - -// Listen is called when network starts listening on an addr -func (conns *Connections) Listen(network.Network, multiaddr.Multiaddr) {} - -// ListenClose is called when network stops listening on an addr -func (conns *Connections) ListenClose(network.Network, multiaddr.Multiaddr) {} - -// Connected is called when a connection opened. It increments the numConnections counter -func (conns *Connections) Connected(network.Network, network.Conn) { - atomic.AddUint32(&conns.numConnections, 1) -} - -// Disconnected is called when a connection closed it increments the numDisconnections counter -func (conns *Connections) Disconnected(network.Network, network.Conn) { - atomic.AddUint32(&conns.numDisconnections, 1) -} - -// OpenedStream is called when a stream opened -func (conns *Connections) OpenedStream(network.Network, network.Stream) {} - -// ClosedStream is called when a stream closed -func (conns *Connections) ClosedStream(network.Network, network.Stream) {} - -// ResetNumConnections resets the numConnections counter returning the previous value -func (conns *Connections) ResetNumConnections() uint32 { - return atomic.SwapUint32(&conns.numConnections, 0) -} - -// ResetNumDisconnections resets the numDisconnections counter returning the previous value -func (conns *Connections) ResetNumDisconnections() uint32 { - return atomic.SwapUint32(&conns.numDisconnections, 0) -} diff --git a/p2p/libp2p/metrics/connections_test.go b/p2p/libp2p/metrics/connections_test.go deleted file mode 100644 index 0d3a177c034..00000000000 --- a/p2p/libp2p/metrics/connections_test.go +++ /dev/null @@ -1,56 +0,0 @@ -package metrics_test - -import ( - "testing" - - "github.com/ElrondNetwork/elrond-go/p2p/libp2p/metrics" - "github.com/stretchr/testify/assert" -) - -func TestConnections_EmptyFunctionsDoNotPanicWhenCalled(t *testing.T) { - t.Parallel() - - defer func() { - r := recover() - if r != nil { - assert.Fail(t, "test should not have failed") - } - }() - - cdm := metrics.NewConnections() - - cdm.ClosedStream(nil, nil) - cdm.Listen(nil, nil) - cdm.ListenClose(nil, nil) - cdm.OpenedStream(nil, nil) -} - -func TestConnections_ResetNumConnectionsShouldWork(t *testing.T) { - t.Parallel() - - cdm := metrics.NewConnections() - - cdm.Connected(nil, nil) - cdm.Connected(nil, nil) - - existing := cdm.ResetNumConnections() - assert.Equal(t, uint32(2), existing) - - existing = cdm.ResetNumConnections() - assert.Equal(t, uint32(0), existing) -} - -func TestConnsDisconnsMetric_ResetNumDisconnectionsShouldWork(t *testing.T) { - t.Parallel() - - cdm := metrics.NewConnections() - - cdm.Disconnected(nil, nil) - cdm.Disconnected(nil, nil) - - existing := cdm.ResetNumDisconnections() - assert.Equal(t, uint32(2), existing) - - existing = cdm.ResetNumDisconnections() - assert.Equal(t, uint32(0), existing) -} diff --git a/p2p/libp2p/metrics/disabledConnectionsWatcher.go b/p2p/libp2p/metrics/disabledConnectionsWatcher.go deleted file mode 100644 index 
63689b6508d..00000000000 --- a/p2p/libp2p/metrics/disabledConnectionsWatcher.go +++ /dev/null @@ -1,23 +0,0 @@ -package metrics - -import "github.com/ElrondNetwork/elrond-go-core/core" - -type disabledConnectionsWatcher struct{} - -// NewDisabledConnectionsWatcher returns a disabled ConnectionWatcher implementation -func NewDisabledConnectionsWatcher() *disabledConnectionsWatcher { - return &disabledConnectionsWatcher{} -} - -// NewKnownConnection does nothing -func (dcw *disabledConnectionsWatcher) NewKnownConnection(_ core.PeerID, _ string) {} - -// Close does nothing and returns nil -func (dcw *disabledConnectionsWatcher) Close() error { - return nil -} - -// IsInterfaceNil returns true if there is no value under the interface -func (dcw *disabledConnectionsWatcher) IsInterfaceNil() bool { - return dcw == nil -} diff --git a/p2p/libp2p/metrics/disabledConnectionsWatcher_test.go b/p2p/libp2p/metrics/disabledConnectionsWatcher_test.go deleted file mode 100644 index e910c49ebdc..00000000000 --- a/p2p/libp2p/metrics/disabledConnectionsWatcher_test.go +++ /dev/null @@ -1,26 +0,0 @@ -package metrics - -import ( - "fmt" - "testing" - - "github.com/ElrondNetwork/elrond-go-core/core/check" - "github.com/stretchr/testify/assert" -) - -func TestDisabledConnectionsWatcher_MethodsShouldNotPanic(t *testing.T) { - t.Parallel() - - defer func() { - r := recover() - if r != nil { - assert.Fail(t, fmt.Sprintf("should have not panic: %v", r)) - } - }() - - dcw := NewDisabledConnectionsWatcher() - assert.False(t, check.IfNil(dcw)) - dcw.NewKnownConnection("", "") - err := dcw.Close() - assert.Nil(t, err) -} diff --git a/p2p/libp2p/metrics/errors.go b/p2p/libp2p/metrics/errors.go deleted file mode 100644 index 1bbd6d5074d..00000000000 --- a/p2p/libp2p/metrics/errors.go +++ /dev/null @@ -1,5 +0,0 @@ -package metrics - -import "errors" - -var errInvalidValueForTimeToLiveParam = errors.New("invalid value for the time-to-live parameter") diff --git a/p2p/libp2p/metrics/export_test.go b/p2p/libp2p/metrics/export_test.go deleted file mode 100644 index 6ce8114afb4..00000000000 --- a/p2p/libp2p/metrics/export_test.go +++ /dev/null @@ -1,29 +0,0 @@ -package metrics - -import ( - "context" - "fmt" - "time" - - "github.com/ElrondNetwork/elrond-go-core/core" - "github.com/ElrondNetwork/elrond-go/storage/timecache" -) - -// NewPrintConnectionsWatcherWithHandler - -func NewPrintConnectionsWatcherWithHandler(timeToLive time.Duration, handler func(pid core.PeerID, connection string)) (*printConnectionsWatcher, error) { - if timeToLive < minTimeToLive { - return nil, fmt.Errorf("%w in NewPrintConnectionsWatcher, got: %d, minimum: %d", errInvalidValueForTimeToLiveParam, timeToLive, minTimeToLive) - } - - pcw := &printConnectionsWatcher{ - timeToLive: timeToLive, - timeCacher: timecache.NewTimeCache(timeToLive), - printHandler: handler, - } - - ctx, cancel := context.WithCancel(context.Background()) - pcw.cancel = cancel - go pcw.doSweep(ctx) - - return pcw, nil -} diff --git a/p2p/libp2p/metrics/factory/connectionWatcherFactory.go b/p2p/libp2p/metrics/factory/connectionWatcherFactory.go deleted file mode 100644 index 562bdfa2112..00000000000 --- a/p2p/libp2p/metrics/factory/connectionWatcherFactory.go +++ /dev/null @@ -1,21 +0,0 @@ -package factory - -import ( - "fmt" - "time" - - "github.com/ElrondNetwork/elrond-go/p2p" - "github.com/ElrondNetwork/elrond-go/p2p/libp2p/metrics" -) - -// NewConnectionsWatcher creates a new ConnectionWatcher instance based on the input parameters -func 
NewConnectionsWatcher(connectionsWatcherType string, timeToLive time.Duration) (p2p.ConnectionsWatcher, error) { - switch connectionsWatcherType { - case p2p.ConnectionWatcherTypePrint: - return metrics.NewPrintConnectionsWatcher(timeToLive) - case p2p.ConnectionWatcherTypeDisabled, p2p.ConnectionWatcherTypeEmpty: - return metrics.NewDisabledConnectionsWatcher(), nil - default: - return nil, fmt.Errorf("%w %s", errUnknownConnectionWatcherType, connectionsWatcherType) - } -} diff --git a/p2p/libp2p/metrics/factory/connectionWatcherFactory_test.go b/p2p/libp2p/metrics/factory/connectionWatcherFactory_test.go deleted file mode 100644 index 1dcb5980d84..00000000000 --- a/p2p/libp2p/metrics/factory/connectionWatcherFactory_test.go +++ /dev/null @@ -1,48 +0,0 @@ -package factory - -import ( - "errors" - "fmt" - "testing" - "time" - - "github.com/ElrondNetwork/elrond-go-core/core/check" - "github.com/ElrondNetwork/elrond-go/p2p" - "github.com/stretchr/testify/assert" -) - -func TestNewConnectionsWatcher(t *testing.T) { - t.Parallel() - - t.Run("print connections watcher", func(t *testing.T) { - t.Parallel() - - cw, err := NewConnectionsWatcher(p2p.ConnectionWatcherTypePrint, time.Second) - assert.Nil(t, err) - assert.False(t, check.IfNil(cw)) - assert.Equal(t, "*metrics.printConnectionsWatcher", fmt.Sprintf("%T", cw)) - }) - t.Run("disabled connections watcher", func(t *testing.T) { - t.Parallel() - - cw, err := NewConnectionsWatcher(p2p.ConnectionWatcherTypeDisabled, time.Second) - assert.Nil(t, err) - assert.False(t, check.IfNil(cw)) - assert.Equal(t, "*metrics.disabledConnectionsWatcher", fmt.Sprintf("%T", cw)) - }) - t.Run("empty connections watcher", func(t *testing.T) { - t.Parallel() - - cw, err := NewConnectionsWatcher(p2p.ConnectionWatcherTypeEmpty, time.Second) - assert.Nil(t, err) - assert.False(t, check.IfNil(cw)) - assert.Equal(t, "*metrics.disabledConnectionsWatcher", fmt.Sprintf("%T", cw)) - }) - t.Run("unknown type", func(t *testing.T) { - t.Parallel() - - cw, err := NewConnectionsWatcher("unknown", time.Second) - assert.True(t, errors.Is(err, errUnknownConnectionWatcherType)) - assert.True(t, check.IfNil(cw)) - }) -} diff --git a/p2p/libp2p/metrics/factory/errors.go b/p2p/libp2p/metrics/factory/errors.go deleted file mode 100644 index df1a5f63fb0..00000000000 --- a/p2p/libp2p/metrics/factory/errors.go +++ /dev/null @@ -1,5 +0,0 @@ -package factory - -import "errors" - -var errUnknownConnectionWatcherType = errors.New("unknown connection type") diff --git a/p2p/libp2p/metrics/printConnectionWatcher_test.go b/p2p/libp2p/metrics/printConnectionWatcher_test.go deleted file mode 100644 index c8226bee74b..00000000000 --- a/p2p/libp2p/metrics/printConnectionWatcher_test.go +++ /dev/null @@ -1,108 +0,0 @@ -package metrics - -import ( - "errors" - "fmt" - "testing" - "time" - - "github.com/ElrondNetwork/elrond-go-core/core" - "github.com/ElrondNetwork/elrond-go-core/core/check" - "github.com/stretchr/testify/assert" -) - -func TestNewPrintConnectionsWatcher(t *testing.T) { - t.Parallel() - - t.Run("invalid value for time to live parameter should error", func(t *testing.T) { - t.Parallel() - - pcw, err := NewPrintConnectionsWatcher(minTimeToLive - time.Nanosecond) - assert.True(t, check.IfNil(pcw)) - assert.True(t, errors.Is(err, errInvalidValueForTimeToLiveParam)) - }) - t.Run("should work", func(t *testing.T) { - t.Parallel() - - pcw, err := NewPrintConnectionsWatcher(minTimeToLive) - assert.False(t, check.IfNil(pcw)) - assert.Nil(t, err) - - _ = pcw.Close() - }) -} - -func 
TestPrintConnectionsWatcher_Close(t *testing.T) { - t.Parallel() - - t.Run("no iteration has been done", func(t *testing.T) { - t.Parallel() - - pcw, _ := NewPrintConnectionsWatcher(time.Hour) - err := pcw.Close() - - assert.Nil(t, err) - time.Sleep(time.Second) // allow the go routine to close - assert.True(t, pcw.goRoutineClosed.IsSet()) - }) - t.Run("iterations were done", func(t *testing.T) { - t.Parallel() - - pcw, _ := NewPrintConnectionsWatcher(time.Second) - time.Sleep(time.Second * 4) - err := pcw.Close() - - assert.Nil(t, err) - time.Sleep(time.Second) // allow the go routine to close - assert.True(t, pcw.goRoutineClosed.IsSet()) - }) - -} - -func TestPrintConnectionsWatcher_NewKnownConnection(t *testing.T) { - t.Parallel() - - t.Run("invalid connection", func(t *testing.T) { - providedPid := core.PeerID("pid") - connection := " " - numCalled := 0 - - handler := func(pid core.PeerID, conn string) { - numCalled++ - } - pcw, _ := NewPrintConnectionsWatcherWithHandler(time.Hour, handler) - - pcw.NewKnownConnection(providedPid, connection) - assert.Equal(t, 0, numCalled) - }) - t.Run("valid connection", func(t *testing.T) { - providedPid := core.PeerID("pid") - connection := "connection" - numCalled := 0 - - handler := func(pid core.PeerID, conn string) { - numCalled++ - assert.Equal(t, providedPid, pid) - assert.Equal(t, connection, conn) - } - pcw, _ := NewPrintConnectionsWatcherWithHandler(time.Hour, handler) - - pcw.NewKnownConnection(providedPid, connection) - assert.Equal(t, 1, numCalled) - pcw.NewKnownConnection(providedPid, connection) - assert.Equal(t, 1, numCalled) - }) -} - -func TestLogPrintHandler_shouldNotPanic(t *testing.T) { - t.Parallel() - - defer func() { - r := recover() - if r != nil { - assert.Fail(t, fmt.Sprintf("should have not panic: %v", r)) - } - }() - - logPrintHandler("pid", "connection") -} diff --git a/p2p/libp2p/metrics/printConnectionsWatcher.go b/p2p/libp2p/metrics/printConnectionsWatcher.go deleted file mode 100644 index b2e4d411a2b..00000000000 --- a/p2p/libp2p/metrics/printConnectionsWatcher.go +++ /dev/null @@ -1,102 +0,0 @@ -package metrics - -import ( - "context" - "fmt" - "strings" - "time" - - "github.com/ElrondNetwork/elrond-go-core/core" - "github.com/ElrondNetwork/elrond-go-core/core/atomic" - logger "github.com/ElrondNetwork/elrond-go-logger" - "github.com/ElrondNetwork/elrond-go/storage" - "github.com/ElrondNetwork/elrond-go/storage/timecache" -) - -const minTimeToLive = time.Second - -var log = logger.GetOrCreate("p2p/libp2p/metrics") - -type printConnectionsWatcher struct { - timeCacher storage.TimeCacher - goRoutineClosed atomic.Flag - timeToLive time.Duration - printHandler func(pid core.PeerID, connection string) - cancel func() -} - -// NewPrintConnectionsWatcher creates a new -func NewPrintConnectionsWatcher(timeToLive time.Duration) (*printConnectionsWatcher, error) { - if timeToLive < minTimeToLive { - return nil, fmt.Errorf("%w in NewPrintConnectionsWatcher, got: %d, minimum: %d", errInvalidValueForTimeToLiveParam, timeToLive, minTimeToLive) - } - - pcw := &printConnectionsWatcher{ - timeToLive: timeToLive, - timeCacher: timecache.NewTimeCache(timeToLive), - printHandler: logPrintHandler, - } - - ctx, cancel := context.WithCancel(context.Background()) - pcw.cancel = cancel - go pcw.doSweep(ctx) - - return pcw, nil -} - -func (pcw *printConnectionsWatcher) doSweep(ctx context.Context) { - timer := time.NewTimer(pcw.timeToLive) - defer func() { - timer.Stop() - pcw.goRoutineClosed.SetValue(true) - }() - - for { - 
timer.Reset(pcw.timeToLive) - - select { - case <-ctx.Done(): - log.Debug("printConnectionsWatcher's processing loop is closing...") - return - case <-timer.C: - } - - pcw.timeCacher.Sweep() - } -} - -// NewKnownConnection will add the known connection to the cache, printing it as necessary -func (pcw *printConnectionsWatcher) NewKnownConnection(pid core.PeerID, connection string) { - conn := strings.Trim(connection, " ") - if len(conn) == 0 { - return - } - - has := pcw.timeCacher.Has(pid.Pretty()) - err := pcw.timeCacher.Upsert(pid.Pretty(), pcw.timeToLive) - if err != nil { - log.Warn("programming error in printConnectionsWatcher.NewKnownConnection", "error", err) - return - } - if has { - return - } - - pcw.printHandler(pid, conn) -} - -// Close will close any go routines opened by this instance -func (pcw *printConnectionsWatcher) Close() error { - pcw.cancel() - - return nil -} - -func logPrintHandler(pid core.PeerID, connection string) { - log.Debug("new known peer", "pid", pid.Pretty(), "connection", connection) -} - -// IsInterfaceNil returns true if there is no value under the interface -func (pcw *printConnectionsWatcher) IsInterfaceNil() bool { - return pcw == nil -} diff --git a/p2p/libp2p/mockMessenger.go b/p2p/libp2p/mockMessenger.go deleted file mode 100644 index cb86c7fa987..00000000000 --- a/p2p/libp2p/mockMessenger.go +++ /dev/null @@ -1,44 +0,0 @@ -package libp2p - -import ( - "context" - - "github.com/ElrondNetwork/elrond-go/p2p" - "github.com/ElrondNetwork/elrond-go/p2p/libp2p/metrics/factory" - mocknet "github.com/libp2p/go-libp2p/p2p/net/mock" -) - -// NewMockMessenger creates a new sandbox testable instance of libP2P messenger -// It should not open ports on current machine -// Should be used only in testing! -func NewMockMessenger( - args ArgsNetworkMessenger, - mockNet mocknet.Mocknet, -) (*networkMessenger, error) { - if mockNet == nil { - return nil, p2p.ErrNilMockNet - } - - h, err := mockNet.GenPeer() - if err != nil { - return nil, err - } - - ctx, cancelFunc := context.WithCancel(context.Background()) - p2pNode := &networkMessenger{ - p2pHost: NewConnectableHost(h), - ctx: ctx, - cancelFunc: cancelFunc, - } - p2pNode.printConnectionsWatcher, err = factory.NewConnectionsWatcher(args.ConnectionWatcherType, ttlConnectionsWatcher) - if err != nil { - return nil, err - } - - err = addComponentsToNode(args, p2pNode, withoutMessageSigning) - if err != nil { - return nil, err - } - - return p2pNode, err -} diff --git a/p2p/libp2p/mutexHolder.go b/p2p/libp2p/mutexHolder.go deleted file mode 100644 index 4b53ef2c4d9..00000000000 --- a/p2p/libp2p/mutexHolder.go +++ /dev/null @@ -1,51 +0,0 @@ -package libp2p - -import ( - "sync" - - "github.com/ElrondNetwork/elrond-go/storage" - "github.com/ElrondNetwork/elrond-go/storage/lrucache" -) - -// MutexHolder holds a cache of mutexes: pairs of (key, *sync.Mutex) -type MutexHolder struct { - // generalMutex is used to serialize the access to the already concurrent safe lrucache - generalMutex sync.Mutex - mutexes storage.Cacher -} - -// NewMutexHolder creates a new instance of MutexHolder with specified capacity. -func NewMutexHolder(mutexesCapacity int) (*MutexHolder, error) { - mh := &MutexHolder{} - var err error - mh.mutexes, err = lrucache.NewCache(mutexesCapacity) - if err != nil { - return nil, err - } - - return mh, nil -} - -// Get returns a mutex for the provided key. If the key was not found, it will create a new mutex, save it in the -// cache and returns it. 
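The comment above describes the get-or-create behaviour of MutexHolder.Get. A short usage sketch of the per-key locking pattern, using only the API from this deleted file (NewMutexHolder and Get); processForKey and its caller are hypothetical, added for illustration:

package libp2p

// processForKey serializes work per key: concurrent callers passing the same
// key block on the same mutex, while callers with different keys proceed in
// parallel. Note that because the holder is backed by an LRU cache, once the
// capacity is exceeded two goroutines could, in principle, obtain different
// mutexes for the same key.
func processForKey(mh *MutexHolder, key string, work func()) {
	mut := mh.Get(key) // returns the cached *sync.Mutex or creates and caches a new one
	mut.Lock()
	defer mut.Unlock()

	work()
}

func exampleUsage() {
	mh, err := NewMutexHolder(1000) // capacity bounds the number of cached mutexes
	if err != nil {
		return
	}

	processForKey(mh, "peer-1", func() {
		// critical section guarded per key
	})
}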
-func (mh *MutexHolder) Get(key string) *sync.Mutex { - mh.generalMutex.Lock() - defer mh.generalMutex.Unlock() - - sliceKey := []byte(key) - val, ok := mh.mutexes.Get(sliceKey) - if !ok { - newMutex := &sync.Mutex{} - mh.mutexes.Put(sliceKey, newMutex, 0) - return newMutex - } - - mutex, ok := val.(*sync.Mutex) - if !ok { - newMutex := &sync.Mutex{} - mh.mutexes.Put(sliceKey, newMutex, 0) - return newMutex - } - - return mutex -} diff --git a/p2p/libp2p/mutexHolder_test.go b/p2p/libp2p/mutexHolder_test.go deleted file mode 100644 index cddae37cb0f..00000000000 --- a/p2p/libp2p/mutexHolder_test.go +++ /dev/null @@ -1,73 +0,0 @@ -package libp2p_test - -import ( - "sync" - "testing" - - "github.com/ElrondNetwork/elrond-go/p2p/libp2p" - "github.com/stretchr/testify/assert" -) - -func TestNewMutexHolder_InvalidCapacityShouldErr(t *testing.T) { - t.Parallel() - - mh, err := libp2p.NewMutexHolder(-1) - - assert.Nil(t, mh) - assert.NotNil(t, err) -} - -func TestNewMutexHolder_InvalidCapacityShouldWork(t *testing.T) { - t.Parallel() - - mh, err := libp2p.NewMutexHolder(10) - - assert.NotNil(t, mh) - assert.Nil(t, err) -} - -func TestMutexHolder_MutexNotFoundShouldCreate(t *testing.T) { - t.Parallel() - - mh, _ := libp2p.NewMutexHolder(10) - key := "key" - mut := mh.Get(key) - - assert.NotNil(t, mut) - assert.Equal(t, 1, mh.Mutexes().Len()) - addedMutex, _ := mh.Mutexes().Get([]byte(key)) - // pointer testing to not have the situation of creating new mutexes for each getMutex call - assert.True(t, mut == addedMutex) -} - -func TestMutexHolder_OtherObjectInCacheShouldRewriteWithNewMutexAndReturn(t *testing.T) { - t.Parallel() - - mh, _ := libp2p.NewMutexHolder(10) - key := "key" - mh.Mutexes().Put([]byte(key), "not a mutex value", 0) - mut := mh.Get(key) - - assert.NotNil(t, mut) - assert.Equal(t, 1, mh.Mutexes().Len()) - addedMutex, _ := mh.Mutexes().Get([]byte(key)) - // pointer testing to not have the situation of creating new mutexes for each getMutex call - assert.True(t, mut == addedMutex) -} - -func TestMutexHolder_MutexFoundShouldReturnIt(t *testing.T) { - t.Parallel() - - mh, _ := libp2p.NewMutexHolder(10) - key := "key" - mut := &sync.Mutex{} - mh.Mutexes().Put([]byte(key), mut, 0) - mutRecov := mh.Get(key) - - assert.NotNil(t, mutRecov) - assert.Equal(t, 1, mh.Mutexes().Len()) - addedMutex, _ := mh.Mutexes().Get([]byte(key)) - // pointer testing to not have the situation of creating new mutexes for each getMutex call - assert.True(t, mut == addedMutex) - assert.True(t, mut == mutRecov) -} diff --git a/p2p/libp2p/netMessenger.go b/p2p/libp2p/netMessenger.go deleted file mode 100644 index 3921cf85627..00000000000 --- a/p2p/libp2p/netMessenger.go +++ /dev/null @@ -1,1433 +0,0 @@ -package libp2p - -import ( - "context" - "encoding/hex" - "fmt" - "sort" - "strings" - "sync" - "time" - - "github.com/ElrondNetwork/elrond-go-core/core" - "github.com/ElrondNetwork/elrond-go-core/core/check" - "github.com/ElrondNetwork/elrond-go-core/core/throttler" - logger "github.com/ElrondNetwork/elrond-go-logger" - "github.com/ElrondNetwork/elrond-go/common" - "github.com/ElrondNetwork/elrond-go/config" - p2pDebug "github.com/ElrondNetwork/elrond-go/debug/p2p" - "github.com/ElrondNetwork/elrond-go/p2p" - p2pCrypto "github.com/ElrondNetwork/elrond-go/p2p/crypto" - "github.com/ElrondNetwork/elrond-go/p2p/data" - "github.com/ElrondNetwork/elrond-go/p2p/libp2p/connectionMonitor" - "github.com/ElrondNetwork/elrond-go/p2p/libp2p/disabled" - discoveryFactory 
"github.com/ElrondNetwork/elrond-go/p2p/libp2p/discovery/factory" - "github.com/ElrondNetwork/elrond-go/p2p/libp2p/metrics" - metricsFactory "github.com/ElrondNetwork/elrond-go/p2p/libp2p/metrics/factory" - "github.com/ElrondNetwork/elrond-go/p2p/libp2p/networksharding/factory" - "github.com/ElrondNetwork/elrond-go/p2p/loadBalancer" - pubsub "github.com/ElrondNetwork/go-libp2p-pubsub" - pubsubPb "github.com/ElrondNetwork/go-libp2p-pubsub/pb" - logging "github.com/ipfs/go-log" - "github.com/libp2p/go-libp2p" - "github.com/libp2p/go-libp2p-core/crypto" - "github.com/libp2p/go-libp2p-core/network" - "github.com/libp2p/go-libp2p-core/peer" - "github.com/libp2p/go-libp2p-core/protocol" -) - -const ( - // ListenAddrWithIp4AndTcp defines the listening address with ip v.4 and TCP - ListenAddrWithIp4AndTcp = "/ip4/0.0.0.0/tcp/" - - // ListenLocalhostAddrWithIp4AndTcp defines the local host listening ip v.4 address and TCP - ListenLocalhostAddrWithIp4AndTcp = "/ip4/127.0.0.1/tcp/" - - // DirectSendID represents the protocol ID for sending and receiving direct P2P messages - DirectSendID = protocol.ID("/erd/directsend/1.0.0") - - durationBetweenSends = time.Microsecond * 10 - durationCheckConnections = time.Second - refreshPeersOnTopic = time.Second * 3 - ttlPeersOnTopic = time.Second * 10 - ttlConnectionsWatcher = time.Hour * 2 - pubsubTimeCacheDuration = 10 * time.Minute - acceptMessagesInAdvanceDuration = 20 * time.Second // we are accepting the messages with timestamp in the future only for this delta - pollWaitForConnectionsInterval = time.Second - broadcastGoRoutines = 1000 - timeBetweenPeerPrints = time.Second * 20 - timeBetweenExternalLoggersCheck = time.Second * 20 - minRangePortValue = 1025 - noSignPolicy = pubsub.MessageSignaturePolicy(0) // should be used only in tests - msgBindError = "address already in use" - maxRetriesIfBindError = 10 -) - -type messageSigningConfig bool - -const ( - withMessageSigning messageSigningConfig = true - withoutMessageSigning messageSigningConfig = false -) - -// TODO remove the header size of the message when commit d3c5ecd3a3e884206129d9f2a9a4ddfd5e7c8951 from -// https://github.com/libp2p/go-libp2p-pubsub/pull/189/commits will be part of a new release -var messageHeader = 64 * 1024 // 64kB -var maxSendBuffSize = (1 << 21) - messageHeader -var log = logger.GetOrCreate("p2p/libp2p") - -var _ p2p.Messenger = (*networkMessenger)(nil) -var externalPackages = []string{"dht", "nat", "basichost", "pubsub"} - -func init() { - pubsub.TimeCacheDuration = pubsubTimeCacheDuration - - for _, external := range externalPackages { - _ = logger.GetOrCreate(fmt.Sprintf("external/%s", external)) - } -} - -// TODO refactor this struct to have be a wrapper (with logic) over a glue code -type networkMessenger struct { - p2pSigner - ctx context.Context - cancelFunc context.CancelFunc - p2pHost ConnectableHost - port int - pb *pubsub.PubSub - ds p2p.DirectSender - // TODO refactor this (connMonitor & connMonitorWrapper) - connMonitor ConnectionMonitor - connMonitorWrapper p2p.ConnectionMonitorWrapper - peerDiscoverer p2p.PeerDiscoverer - sharder p2p.Sharder - peerShardResolver p2p.PeerShardResolver - mutPeerResolver sync.RWMutex - mutTopics sync.RWMutex - processors map[string]*topicProcessors - topics map[string]*pubsub.Topic - subscriptions map[string]*pubsub.Subscription - outgoingPLB p2p.ChannelLoadBalancer - poc *peersOnChannel - goRoutinesThrottler *throttler.NumGoRoutinesThrottler - connectionsMetric *metrics.Connections - debugger p2p.Debugger - marshalizer 
p2p.Marshalizer - syncTimer p2p.SyncTimer - preferredPeersHolder p2p.PreferredPeersHolderHandler - printConnectionsWatcher p2p.ConnectionsWatcher - peersRatingHandler p2p.PeersRatingHandler -} - -// ArgsNetworkMessenger defines the options used to create a p2p wrapper -type ArgsNetworkMessenger struct { - ListenAddress string - Marshalizer p2p.Marshalizer - P2pConfig config.P2PConfig - SyncTimer p2p.SyncTimer - PreferredPeersHolder p2p.PreferredPeersHolderHandler - NodeOperationMode p2p.NodeOperation - PeersRatingHandler p2p.PeersRatingHandler - ConnectionWatcherType string -} - -// NewNetworkMessenger creates a libP2P messenger by opening a port on the current machine -func NewNetworkMessenger(args ArgsNetworkMessenger) (*networkMessenger, error) { - return newNetworkMessenger(args, withMessageSigning) -} - -func newNetworkMessenger(args ArgsNetworkMessenger, messageSigning messageSigningConfig) (*networkMessenger, error) { - if check.IfNil(args.Marshalizer) { - return nil, fmt.Errorf("%w when creating a new network messenger", p2p.ErrNilMarshalizer) - } - if check.IfNil(args.SyncTimer) { - return nil, fmt.Errorf("%w when creating a new network messenger", p2p.ErrNilSyncTimer) - } - if check.IfNil(args.PreferredPeersHolder) { - return nil, fmt.Errorf("%w when creating a new network messenger", p2p.ErrNilPreferredPeersHolder) - } - if check.IfNil(args.PeersRatingHandler) { - return nil, fmt.Errorf("%w when creating a new network messenger", p2p.ErrNilPeersRatingHandler) - } - - keyGen := p2pCrypto.NewIdentityGenerator() - p2pPrivKey, err := keyGen.CreateP2PPrivateKey(args.P2pConfig.Node.Seed) - if err != nil { - return nil, err - } - - setupExternalP2PLoggers() - - p2pNode, err := constructNodeWithPortRetry(args, p2pPrivKey) - if err != nil { - return nil, err - } - - err = addComponentsToNode(args, p2pNode, messageSigning) - if err != nil { - log.LogIfError(p2pNode.p2pHost.Close()) - return nil, err - } - - return p2pNode, nil -} - -func constructNode( - args ArgsNetworkMessenger, - p2pPrivKey *crypto.Secp256k1PrivateKey, -) (*networkMessenger, error) { - - port, err := getPort(args.P2pConfig.Node.Port, checkFreePort) - if err != nil { - return nil, err - } - - log.Debug("connectionWatcherType", "type", args.ConnectionWatcherType) - connWatcher, err := metricsFactory.NewConnectionsWatcher(args.ConnectionWatcherType, ttlConnectionsWatcher) - if err != nil { - return nil, err - } - - p2pSignerInstance, err := p2pCrypto.NewP2PSigner(p2pPrivKey) - if err != nil { - return nil, err - } - - address := fmt.Sprintf(args.ListenAddress+"%d", port) - opts := []libp2p.Option{ - libp2p.ListenAddrStrings(address), - libp2p.Identity(p2pPrivKey), - libp2p.DefaultMuxers, - libp2p.DefaultSecurity, - libp2p.DefaultTransports, - // we need the disable relay option in order to save the node's bandwidth as much as possible - libp2p.DisableRelay(), - libp2p.NATPortMap(), - } - - ctx, cancelFunc := context.WithCancel(context.Background()) - h, err := libp2p.New(opts...) 
- if err != nil { - cancelFunc() - return nil, err - } - - p2pNode := &networkMessenger{ - p2pSigner: p2pSignerInstance, - ctx: ctx, - cancelFunc: cancelFunc, - p2pHost: NewConnectableHost(h), - port: port, - printConnectionsWatcher: connWatcher, - peersRatingHandler: args.PeersRatingHandler, - } - - return p2pNode, nil -} - -func constructNodeWithPortRetry( - args ArgsNetworkMessenger, - p2pPrivKey *crypto.Secp256k1PrivateKey, -) (*networkMessenger, error) { - - var lastErr error - for i := 0; i < maxRetriesIfBindError; i++ { - p2pNode, err := constructNode(args, p2pPrivKey) - if err == nil { - return p2pNode, nil - } - - lastErr = err - if !strings.Contains(err.Error(), msgBindError) { - // not a bind error, return directly - return nil, err - } - - log.Debug("bind error in network messenger", "retry number", i+1, "error", err) - } - - return nil, lastErr -} - -func setupExternalP2PLoggers() { - for _, external := range externalPackages { - logLevel := logger.GetLoggerLogLevel("external/" + external) - if logLevel > logger.LogTrace { - continue - } - - _ = logging.SetLogLevel(external, "DEBUG") - } -} - -func addComponentsToNode( - args ArgsNetworkMessenger, - p2pNode *networkMessenger, - messageSigning messageSigningConfig, -) error { - var err error - - p2pNode.processors = make(map[string]*topicProcessors) - p2pNode.topics = make(map[string]*pubsub.Topic) - p2pNode.subscriptions = make(map[string]*pubsub.Subscription) - p2pNode.outgoingPLB = loadBalancer.NewOutgoingChannelLoadBalancer() - p2pNode.peerShardResolver = &unknownPeerShardResolver{} - p2pNode.marshalizer = args.Marshalizer - p2pNode.syncTimer = args.SyncTimer - p2pNode.preferredPeersHolder = args.PreferredPeersHolder - p2pNode.debugger = p2pDebug.NewP2PDebugger(core.PeerID(p2pNode.p2pHost.ID())) - p2pNode.peersRatingHandler = args.PeersRatingHandler - - err = p2pNode.createPubSub(messageSigning) - if err != nil { - return err - } - - err = p2pNode.createSharder(args) - if err != nil { - return err - } - - err = p2pNode.createDiscoverer(args.P2pConfig) - if err != nil { - return err - } - - err = p2pNode.createConnectionMonitor(args.P2pConfig) - if err != nil { - return err - } - - p2pNode.createConnectionsMetric() - - p2pNode.ds, err = NewDirectSender(p2pNode.ctx, p2pNode.p2pHost, p2pNode.directMessageHandler) - if err != nil { - return err - } - - p2pNode.goRoutinesThrottler, err = throttler.NewNumGoRoutinesThrottler(broadcastGoRoutines) - if err != nil { - return err - } - - p2pNode.printLogs() - - return nil -} - -func (netMes *networkMessenger) createPubSub(messageSigning messageSigningConfig) error { - optsPS := make([]pubsub.Option, 0) - if messageSigning == withoutMessageSigning { - log.Warn("signature verification is turned off in network messenger instance. NOT recommended in production environment") - optsPS = append(optsPS, pubsub.WithMessageSignaturePolicy(noSignPolicy)) - } - - var err error - netMes.pb, err = pubsub.NewGossipSub(netMes.ctx, netMes.p2pHost, optsPS...) 
- if err != nil { - return err - } - - netMes.poc, err = newPeersOnChannel( - netMes.peersRatingHandler, - netMes.pb.ListPeers, - refreshPeersOnTopic, - ttlPeersOnTopic) - if err != nil { - return err - } - - go func(plb p2p.ChannelLoadBalancer) { - for { - select { - case <-time.After(durationBetweenSends): - case <-netMes.ctx.Done(): - log.Debug("closing networkMessenger's send from channel load balancer go routine") - return - } - - sendableData := plb.CollectOneElementFromChannels() - if sendableData == nil { - continue - } - - netMes.mutTopics.RLock() - topic := netMes.topics[sendableData.Topic] - netMes.mutTopics.RUnlock() - - if topic == nil { - log.Warn("writing on a topic that the node did not register on - message dropped", - "topic", sendableData.Topic, - ) - - continue - } - - packedSendableDataBuff := netMes.createMessageBytes(sendableData.Buff) - if len(packedSendableDataBuff) == 0 { - continue - } - - errPublish := netMes.publish(topic, sendableData, packedSendableDataBuff) - if errPublish != nil { - log.Trace("error sending data", "error", errPublish) - } - } - }(netMes.outgoingPLB) - - return nil -} - -func (netMes *networkMessenger) publish(topic *pubsub.Topic, data *p2p.SendableData, packedSendableDataBuff []byte) error { - if data.Sk == nil { - return topic.Publish(netMes.ctx, packedSendableDataBuff) - } - - return topic.PublishWithSk(netMes.ctx, packedSendableDataBuff, data.Sk, data.ID) -} - -func (netMes *networkMessenger) createMessageBytes(buff []byte) []byte { - message := &data.TopicMessage{ - Version: currentTopicMessageVersion, - Payload: buff, - Timestamp: netMes.syncTimer.CurrentTime().Unix(), - } - - buffToSend, errMarshal := netMes.marshalizer.Marshal(message) - if errMarshal != nil { - log.Warn("error sending data", "error", errMarshal) - return nil - } - - return buffToSend -} - -func (netMes *networkMessenger) createSharder(argsNetMes ArgsNetworkMessenger) error { - args := factory.ArgsSharderFactory{ - PeerShardResolver: &unknownPeerShardResolver{}, - Pid: netMes.p2pHost.ID(), - P2pConfig: argsNetMes.P2pConfig, - PreferredPeersHolder: netMes.preferredPeersHolder, - NodeOperationMode: argsNetMes.NodeOperationMode, - } - - var err error - netMes.sharder, err = factory.NewSharder(args) - - return err -} - -func (netMes *networkMessenger) createDiscoverer(p2pConfig config.P2PConfig) error { - var err error - - args := discoveryFactory.ArgsPeerDiscoverer{ - Context: netMes.ctx, - Host: netMes.p2pHost, - Sharder: netMes.sharder, - P2pConfig: p2pConfig, - ConnectionsWatcher: netMes.printConnectionsWatcher, - } - - netMes.peerDiscoverer, err = discoveryFactory.NewPeerDiscoverer(args) - - return err -} - -func (netMes *networkMessenger) createConnectionMonitor(p2pConfig config.P2PConfig) error { - reconnecter, ok := netMes.peerDiscoverer.(p2p.Reconnecter) - if !ok { - return fmt.Errorf("%w when converting peerDiscoverer to reconnecter interface", p2p.ErrWrongTypeAssertion) - } - - sharder, ok := netMes.sharder.(connectionMonitor.Sharder) - if !ok { - return fmt.Errorf("%w in networkMessenger.createConnectionMonitor", p2p.ErrWrongTypeAssertions) - } - - args := connectionMonitor.ArgsConnectionMonitorSimple{ - Reconnecter: reconnecter, - Sharder: sharder, - ThresholdMinConnectedPeers: p2pConfig.Node.ThresholdMinConnectedPeers, - PreferredPeersHolder: netMes.preferredPeersHolder, - ConnectionsWatcher: netMes.printConnectionsWatcher, - } - var err error - netMes.connMonitor, err = connectionMonitor.NewLibp2pConnectionMonitorSimple(args) - if err != nil { - return err 
- } - - cmw := newConnectionMonitorWrapper( - netMes.p2pHost.Network(), - netMes.connMonitor, - &disabled.PeerDenialEvaluator{}, - ) - netMes.p2pHost.Network().Notify(cmw) - netMes.connMonitorWrapper = cmw - - go func() { - for { - cmw.CheckConnectionsBlocking() - select { - case <-time.After(durationCheckConnections): - case <-netMes.ctx.Done(): - log.Debug("peer monitoring go routine is stopping...") - return - } - } - }() - - return nil -} - -func (netMes *networkMessenger) createConnectionsMetric() { - netMes.connectionsMetric = metrics.NewConnections() - netMes.p2pHost.Network().Notify(netMes.connectionsMetric) -} - -func (netMes *networkMessenger) printLogs() { - addresses := make([]interface{}, 0) - for i, address := range netMes.p2pHost.Addrs() { - addresses = append(addresses, fmt.Sprintf("addr%d", i)) - addresses = append(addresses, address.String()+"/p2p/"+netMes.ID().Pretty()) - } - log.Info("listening on addresses", addresses...) - - go netMes.printLogsStats() - go netMes.checkExternalLoggers() -} - -func (netMes *networkMessenger) printLogsStats() { - for { - select { - case <-netMes.ctx.Done(): - log.Debug("closing networkMessenger.printLogsStats go routine") - return - case <-time.After(timeBetweenPeerPrints): - } - - conns := netMes.connectionsMetric.ResetNumConnections() - disconns := netMes.connectionsMetric.ResetNumDisconnections() - - peersInfo := netMes.GetConnectedPeersInfo() - log.Debug("network connection status", - "known peers", len(netMes.Peers()), - "connected peers", len(netMes.ConnectedPeers()), - "intra shard validators", peersInfo.NumIntraShardValidators, - "intra shard observers", peersInfo.NumIntraShardObservers, - "cross shard validators", peersInfo.NumCrossShardValidators, - "cross shard observers", peersInfo.NumCrossShardObservers, - "full history observers", peersInfo.NumFullHistoryObservers, - "unknown", len(peersInfo.UnknownPeers), - "seeders", len(peersInfo.Seeders), - "current shard", peersInfo.SelfShardID, - "validators histogram", netMes.mapHistogram(peersInfo.NumValidatorsOnShard), - "observers histogram", netMes.mapHistogram(peersInfo.NumObserversOnShard), - "preferred peers histogram", netMes.mapHistogram(peersInfo.NumPreferredPeersOnShard), - ) - - connsPerSec := conns / uint32(timeBetweenPeerPrints/time.Second) - disconnsPerSec := disconns / uint32(timeBetweenPeerPrints/time.Second) - - log.Debug("network connection metrics", - "connections/s", connsPerSec, - "disconnections/s", disconnsPerSec, - "connections", conns, - "disconnections", disconns, - "time", timeBetweenPeerPrints, - ) - } -} - -func (netMes *networkMessenger) mapHistogram(input map[uint32]int) string { - keys := make([]uint32, 0, len(input)) - for shard := range input { - keys = append(keys, shard) - } - sort.Slice(keys, func(i, j int) bool { - return keys[i] < keys[j] - }) - - vals := make([]string, 0, len(keys)) - for _, key := range keys { - var shard string - if key == core.MetachainShardId { - shard = "meta" - } else { - shard = fmt.Sprintf("shard %d", key) - } - - vals = append(vals, fmt.Sprintf("%s: %d", shard, input[key])) - } - - return strings.Join(vals, ", ") -} - -func (netMes *networkMessenger) checkExternalLoggers() { - for { - select { - case <-netMes.ctx.Done(): - log.Debug("closing networkMessenger.checkExternalLoggers go routine") - return - case <-time.After(timeBetweenExternalLoggersCheck): - } - - setupExternalP2PLoggers() - } -} - -// Close closes the host, connections and streams -func (netMes *networkMessenger) Close() error { - log.Debug("closing 
network messenger's host...") - - var err error - errHost := netMes.p2pHost.Close() - if errHost != nil { - err = errHost - log.Warn("networkMessenger.Close", - "component", "host", - "error", err) - } - - log.Debug("closing network messenger's print connection watcher...") - errConnWatcher := netMes.printConnectionsWatcher.Close() - if errConnWatcher != nil { - err = errConnWatcher - log.Warn("networkMessenger.Close", - "component", "connectionsWatcher", - "error", err) - } - - log.Debug("closing network messenger's outgoing load balancer...") - errOplb := netMes.outgoingPLB.Close() - if errOplb != nil { - err = errOplb - log.Warn("networkMessenger.Close", - "component", "outgoingPLB", - "error", err) - } - - log.Debug("closing network messenger's peers on channel...") - errPoc := netMes.poc.Close() - if errPoc != nil { - log.Warn("networkMessenger.Close", - "component", "peersOnChannel", - "error", errPoc) - } - - log.Debug("closing network messenger's connection monitor...") - errConnMonitor := netMes.connMonitor.Close() - if errConnMonitor != nil { - log.Warn("networkMessenger.Close", - "component", "connMonitor", - "error", errConnMonitor) - } - - log.Debug("closing network messenger's components through the context...") - netMes.cancelFunc() - - log.Debug("closing network messenger's debugger...") - errDebugger := netMes.debugger.Close() - if errDebugger != nil { - err = errDebugger - log.Warn("networkMessenger.Close", - "component", "debugger", - "error", err) - } - - log.Debug("closing network messenger's peerstore...") - errPeerStore := netMes.p2pHost.Peerstore().Close() - if errPeerStore != nil { - err = errPeerStore - log.Warn("networkMessenger.Close", - "component", "peerstore", - "error", err) - } - - if err == nil { - log.Info("network messenger closed successfully") - } - - return err -} - -// ID returns the messenger's ID -func (netMes *networkMessenger) ID() core.PeerID { - h := netMes.p2pHost - - return core.PeerID(h.ID()) -} - -// Peers returns the list of all known peers ID (including self) -func (netMes *networkMessenger) Peers() []core.PeerID { - peers := make([]core.PeerID, 0) - - for _, p := range netMes.p2pHost.Peerstore().Peers() { - peers = append(peers, core.PeerID(p)) - } - return peers -} - -// Addresses returns all addresses found in peerstore -func (netMes *networkMessenger) Addresses() []string { - addrs := make([]string, 0) - - for _, address := range netMes.p2pHost.Addrs() { - addrs = append(addrs, address.String()+"/p2p/"+netMes.ID().Pretty()) - } - - return addrs -} - -// ConnectToPeer tries to open a new connection to a peer -func (netMes *networkMessenger) ConnectToPeer(address string) error { - return netMes.p2pHost.ConnectToPeer(netMes.ctx, address) -} - -// Bootstrap will start the peer discovery mechanism -func (netMes *networkMessenger) Bootstrap() error { - err := netMes.peerDiscoverer.Bootstrap() - if err == nil { - log.Info("started the network discovery process...") - } - return err -} - -// WaitForConnections will wait the maxWaitingTime duration or until the target connected peers was achieved -func (netMes *networkMessenger) WaitForConnections(maxWaitingTime time.Duration, minNumOfPeers uint32) { - startTime := time.Now() - defer func() { - log.Debug("networkMessenger.WaitForConnections", - "waited", time.Since(startTime), "num connected peers", len(netMes.ConnectedPeers())) - }() - - if minNumOfPeers == 0 { - log.Debug("networkMessenger.WaitForConnections", "waiting", maxWaitingTime) - time.Sleep(maxWaitingTime) - return - } - - 
netMes.waitForConnections(maxWaitingTime, minNumOfPeers) -} - -func (netMes *networkMessenger) waitForConnections(maxWaitingTime time.Duration, minNumOfPeers uint32) { - log.Debug("networkMessenger.WaitForConnections", "waiting", maxWaitingTime, "min num of peers", minNumOfPeers) - ctxMaxWaitingTime, cancel := context.WithTimeout(context.Background(), maxWaitingTime) - defer cancel() - - for { - if netMes.shouldStopWaiting(ctxMaxWaitingTime, minNumOfPeers) { - return - } - } -} - -func (netMes *networkMessenger) shouldStopWaiting(ctxMaxWaitingTime context.Context, minNumOfPeers uint32) bool { - ctx, cancel := context.WithTimeout(context.Background(), pollWaitForConnectionsInterval) - defer cancel() - - select { - case <-ctxMaxWaitingTime.Done(): - return true - case <-ctx.Done(): - return int(minNumOfPeers) <= len(netMes.ConnectedPeers()) - } -} - -// IsConnected returns true if the current node is connected to the provided peer -func (netMes *networkMessenger) IsConnected(peerID core.PeerID) bool { - h := netMes.p2pHost - - connectedness := h.Network().Connectedness(peer.ID(peerID)) - - return connectedness == network.Connected -} - -// ConnectedPeers returns the current connected peers list -func (netMes *networkMessenger) ConnectedPeers() []core.PeerID { - h := netMes.p2pHost - - connectedPeers := make(map[core.PeerID]struct{}) - - for _, conn := range h.Network().Conns() { - p := core.PeerID(conn.RemotePeer()) - - if netMes.IsConnected(p) { - connectedPeers[p] = struct{}{} - } - } - - peerList := make([]core.PeerID, len(connectedPeers)) - - index := 0 - for k := range connectedPeers { - peerList[index] = k - index++ - } - - return peerList -} - -// ConnectedAddresses returns all connected peers' addresses -func (netMes *networkMessenger) ConnectedAddresses() []string { - h := netMes.p2pHost - conns := make([]string, 0) - - for _, c := range h.Network().Conns() { - conns = append(conns, c.RemoteMultiaddr().String()+"/p2p/"+c.RemotePeer().Pretty()) - } - return conns -} - -// PeerAddresses returns the peer's addresses or empty slice if the peer is unknown -func (netMes *networkMessenger) PeerAddresses(pid core.PeerID) []string { - h := netMes.p2pHost - result := make([]string, 0) - - // check if the peer is connected to return its connected address - for _, c := range h.Network().Conns() { - if string(c.RemotePeer()) == string(pid.Bytes()) { - result = append(result, c.RemoteMultiaddr().String()) - break - } - } - - // check in peerstore (maybe it is known but not connected) - addresses := h.Peerstore().Addrs(peer.ID(pid.Bytes())) - for _, addr := range addresses { - result = append(result, addr.String()) - } - - return result -} - -// ConnectedPeersOnTopic returns the connected peers on a provided topic -func (netMes *networkMessenger) ConnectedPeersOnTopic(topic string) []core.PeerID { - return netMes.poc.ConnectedPeersOnChannel(topic) -} - -// ConnectedFullHistoryPeersOnTopic returns the connected full history peers on a provided topic -func (netMes *networkMessenger) ConnectedFullHistoryPeersOnTopic(topic string) []core.PeerID { - peerList := netMes.ConnectedPeersOnTopic(topic) - fullHistoryList := make([]core.PeerID, 0) - for _, topicPeer := range peerList { - peerInfo := netMes.peerShardResolver.GetPeerInfo(topicPeer) - if peerInfo.PeerSubType == core.FullHistoryObserver { - fullHistoryList = append(fullHistoryList, topicPeer) - } - } - - return fullHistoryList -} - -// CreateTopic opens a new topic using pubsub infrastructure -func (netMes *networkMessenger) CreateTopic(name string, 
createChannelForTopic bool) error { - netMes.mutTopics.Lock() - defer netMes.mutTopics.Unlock() - _, found := netMes.topics[name] - if found { - return nil - } - - if name == common.ConnectionTopic { - return nil - } - - topic, err := netMes.pb.Join(name) - if err != nil { - return fmt.Errorf("%w for topic %s", err, name) - } - - netMes.topics[name] = topic - subscrRequest, err := topic.Subscribe() - if err != nil { - return fmt.Errorf("%w for topic %s", err, name) - } - - netMes.subscriptions[name] = subscrRequest - if createChannelForTopic { - err = netMes.outgoingPLB.AddChannel(name) - } - - // just a dummy func to consume messages received by the newly created topic - go func() { - var errSubscrNext error - for { - _, errSubscrNext = subscrRequest.Next(netMes.ctx) - if errSubscrNext != nil { - log.Debug("closed subscription", - "topic", subscrRequest.Topic(), - "err", errSubscrNext, - ) - return - } - } - }() - - return err -} - -// HasTopic returns true if the topic has been created -func (netMes *networkMessenger) HasTopic(name string) bool { - netMes.mutTopics.RLock() - _, found := netMes.topics[name] - netMes.mutTopics.RUnlock() - - return found -} - -// BroadcastOnChannelBlocking tries to send a byte buffer onto a topic using provided channel -// It is a blocking method. It needs to be launched on a go routine -func (netMes *networkMessenger) BroadcastOnChannelBlocking(channel string, topic string, buff []byte) error { - err := netMes.checkSendableData(buff) - if err != nil { - return err - } - - if !netMes.goRoutinesThrottler.CanProcess() { - return p2p.ErrTooManyGoroutines - } - - netMes.goRoutinesThrottler.StartProcessing() - - sendable := &p2p.SendableData{ - Buff: buff, - Topic: topic, - ID: netMes.p2pHost.ID(), - } - netMes.outgoingPLB.GetChannelOrDefault(channel) <- sendable - netMes.goRoutinesThrottler.EndProcessing() - return nil -} - -func (netMes *networkMessenger) checkSendableData(buff []byte) error { - if len(buff) > maxSendBuffSize { - return fmt.Errorf("%w, to be sent: %d, maximum: %d", p2p.ErrMessageTooLarge, len(buff), maxSendBuffSize) - } - if len(buff) == 0 { - return p2p.ErrEmptyBufferToSend - } - - return nil -} - -// BroadcastOnChannel tries to send a byte buffer onto a topic using provided channel -func (netMes *networkMessenger) BroadcastOnChannel(channel string, topic string, buff []byte) { - go func() { - err := netMes.BroadcastOnChannelBlocking(channel, topic, buff) - if err != nil { - log.Warn("p2p broadcast", "error", err.Error()) - } - }() -} - -// Broadcast tries to send a byte buffer onto a topic using the topic name as channel -func (netMes *networkMessenger) Broadcast(topic string, buff []byte) { - netMes.BroadcastOnChannel(topic, topic, buff) -} - -// BroadcastOnChannelBlockingUsingPrivateKey tries to send a byte buffer onto a topic using provided channel -// It is a blocking method. 
It needs to be launched on a go routine -func (netMes *networkMessenger) BroadcastOnChannelBlockingUsingPrivateKey( - channel string, - topic string, - buff []byte, - pid core.PeerID, - skBytes []byte, -) error { - id := peer.ID(pid) - sk, err := crypto.UnmarshalPrivateKey(skBytes) - if err != nil { - return err - } - - err = netMes.checkSendableData(buff) - if err != nil { - return err - } - - if !netMes.goRoutinesThrottler.CanProcess() { - return p2p.ErrTooManyGoroutines - } - - netMes.goRoutinesThrottler.StartProcessing() - - sendable := &p2p.SendableData{ - Buff: buff, - Topic: topic, - Sk: sk, - ID: id, - } - netMes.outgoingPLB.GetChannelOrDefault(channel) <- sendable - netMes.goRoutinesThrottler.EndProcessing() - return nil -} - -// BroadcastOnChannelUsingPrivateKey tries to send a byte buffer onto a topic using provided channel -func (netMes *networkMessenger) BroadcastOnChannelUsingPrivateKey( - channel string, - topic string, - buff []byte, - pid core.PeerID, - skBytes []byte, -) { - go func() { - err := netMes.BroadcastOnChannelBlockingUsingPrivateKey(channel, topic, buff, pid, skBytes) - if err != nil { - log.Warn("p2p broadcast using private key", "error", err.Error()) - } - }() -} - -// BroadcastUsingPrivateKey tries to send a byte buffer onto a topic using the topic name as channel -func (netMes *networkMessenger) BroadcastUsingPrivateKey( - topic string, - buff []byte, - pid core.PeerID, - skBytes []byte, -) { - netMes.BroadcastOnChannelUsingPrivateKey(topic, topic, buff, pid, skBytes) -} - -// RegisterMessageProcessor registers a message processor on a topic. The function allows registering multiple handlers -// on a topic. Each handler should be associated with a new identifier on the same topic. Using the same identifier on different -// topics is allowed. The order of handler calling on a particular topic is not deterministic. 
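A hypothetical usage of the semantics just described; the helper, the topic name, the identifiers and the two handlers are illustrative only, not part of this change set:

// registerExampleProcessors attaches two independent processors to the same
// topic under distinct identifiers, then demonstrates the duplicate case.
func registerExampleProcessors(netMes p2p.Messenger, handlerA, handlerB p2p.MessageProcessor) {
	_ = netMes.RegisterMessageProcessor("example_topic", "interceptor", handlerA)
	_ = netMes.RegisterMessageProcessor("example_topic", "debugger", handlerB)

	// re-using an identifier on the same topic fails; the returned error
	// wraps p2p.ErrMessageProcessorAlreadyDefined
	err := netMes.RegisterMessageProcessor("example_topic", "interceptor", handlerA)
	log.Trace("duplicate registration", "error", err)
}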
-func (netMes *networkMessenger) RegisterMessageProcessor(topic string, identifier string, handler p2p.MessageProcessor) error { - if check.IfNil(handler) { - return fmt.Errorf("%w when calling networkMessenger.RegisterMessageProcessor for topic %s", - p2p.ErrNilValidator, topic) - } - - netMes.mutTopics.Lock() - defer netMes.mutTopics.Unlock() - topicProcs := netMes.processors[topic] - if topicProcs == nil { - topicProcs = newTopicProcessors() - netMes.processors[topic] = topicProcs - - err := netMes.registerOnPubSub(topic, topicProcs) - if err != nil { - return err - } - } - - err := topicProcs.addTopicProcessor(identifier, handler) - if err != nil { - return fmt.Errorf("%w, topic %s", err, topic) - } - - return nil -} - -func (netMes *networkMessenger) registerOnPubSub(topic string, topicProcs *topicProcessors) error { - if topic == common.ConnectionTopic { - // do not allow broadcasts on this connection topic - return nil - } - - return netMes.pb.RegisterTopicValidator(topic, netMes.pubsubCallback(topicProcs, topic)) -} - -func (netMes *networkMessenger) pubsubCallback(topicProcs *topicProcessors, topic string) func(ctx context.Context, pid peer.ID, message *pubsub.Message) bool { - return func(ctx context.Context, pid peer.ID, message *pubsub.Message) bool { - fromConnectedPeer := core.PeerID(pid) - msg, err := netMes.transformAndCheckMessage(message, fromConnectedPeer, topic) - if err != nil { - log.Trace("p2p validator - new message", "error", err.Error(), "topic", topic) - return false - } - - identifiers, handlers := topicProcs.getList() - messageOk := true - for index, handler := range handlers { - err = handler.ProcessReceivedMessage(msg, fromConnectedPeer) - if err != nil { - log.Trace("p2p validator", - "error", err.Error(), - "topic", topic, - "originator", p2p.MessageOriginatorPid(msg), - "from connected peer", p2p.PeerIdToShortString(fromConnectedPeer), - "seq no", p2p.MessageOriginatorSeq(msg), - "topic identifier", identifiers[index], - ) - messageOk = false - } - } - netMes.processDebugMessage(topic, fromConnectedPeer, uint64(len(message.Data)), !messageOk) - - if messageOk { - netMes.peersRatingHandler.IncreaseRating(fromConnectedPeer) - } - - return messageOk - } -} - -func (netMes *networkMessenger) transformAndCheckMessage(pbMsg *pubsub.Message, pid core.PeerID, topic string) (p2p.MessageP2P, error) { - msg, errUnmarshal := NewMessage(pbMsg, netMes.marshalizer) - if errUnmarshal != nil { - // this error is so severe that we will need to blacklist both the originator and the connected peer as there is - // no way this node can communicate with them - pidFrom := core.PeerID(pbMsg.From) - netMes.blacklistPid(pid, common.WrongP2PMessageBlacklistDuration) - netMes.blacklistPid(pidFrom, common.WrongP2PMessageBlacklistDuration) - - return nil, errUnmarshal - } - - err := netMes.validMessageByTimestamp(msg) - if err != nil { - // not reprocessing nor re-broadcasting the same message over and over again - log.Trace("received an invalid message", - "originator pid", p2p.MessageOriginatorPid(msg), - "from connected pid", p2p.PeerIdToShortString(pid), - "sequence", hex.EncodeToString(msg.SeqNo()), - "timestamp", msg.Timestamp(), - "error", err, - ) - netMes.processDebugMessage(topic, pid, uint64(len(msg.Data())), true) - - return nil, err - } - - return msg, nil -} - -func (netMes *networkMessenger) blacklistPid(pid core.PeerID, banDuration time.Duration) { - if netMes.connMonitorWrapper.PeerDenialEvaluator().IsDenied(pid) { - return - } - if len(pid) == 0 { - return - } - - 
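The timestamp check performed below by validMessageByTimestamp keeps only messages inside a sliding acceptance window. A standalone sketch of that check, where both durations are assumed placeholder values rather than the node's actual configuration:

package main

import (
	"errors"
	"fmt"
	"time"
)

var (
	errMessageTooNew = errors.New("message too new")
	errMessageTooOld = errors.New("message too old")
)

const (
	pubsubTimeCacheDuration         = 10 * time.Minute // assumed value
	acceptMessagesInAdvanceDuration = 20 * time.Second // assumed value
)

// validByTimestamp rejects ts if it is further in the future than the
// advance allowance, or older than the pubsub time cache window.
func validByTimestamp(now time.Time, ts int64) error {
	if now.Add(acceptMessagesInAdvanceDuration).Unix() < ts {
		return errMessageTooNew
	}
	if ts < now.Unix()-int64(pubsubTimeCacheDuration.Seconds()) {
		return errMessageTooOld
	}
	return nil
}

func main() {
	now := time.Now()
	fmt.Println(validByTimestamp(now, now.Unix()))                 // <nil>
	fmt.Println(validByTimestamp(now, now.Add(time.Hour).Unix()))  // message too new
	fmt.Println(validByTimestamp(now, now.Add(-time.Hour).Unix())) // message too old
}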
log.Debug("blacklisted due to incompatible p2p message", - "pid", pid.Pretty(), - "time", banDuration, - ) - - err := netMes.connMonitorWrapper.PeerDenialEvaluator().UpsertPeerID(pid, banDuration) - if err != nil { - log.Warn("error blacklisting peer ID in network messnger", - "pid", pid.Pretty(), - "error", err.Error(), - ) - } -} - -// invalidMessageByTimestamp will check that the message time stamp should be in the interval -// (now-pubsubTimeCacheDuration+acceptMessagesInAdvanceDuration, now+acceptMessagesInAdvanceDuration) -func (netMes *networkMessenger) validMessageByTimestamp(msg p2p.MessageP2P) error { - now := netMes.syncTimer.CurrentTime() - isInFuture := now.Add(acceptMessagesInAdvanceDuration).Unix() < msg.Timestamp() - if isInFuture { - return fmt.Errorf("%w, self timestamp %d, message timestamp %d", - p2p.ErrMessageTooNew, now.Unix(), msg.Timestamp()) - } - - past := now.Unix() - int64(pubsubTimeCacheDuration.Seconds()) - if msg.Timestamp() < past { - return fmt.Errorf("%w, self timestamp %d, message timestamp %d", - p2p.ErrMessageTooOld, now.Unix(), msg.Timestamp()) - } - - return nil -} - -func (netMes *networkMessenger) processDebugMessage(topic string, fromConnectedPeer core.PeerID, size uint64, isRejected bool) { - if fromConnectedPeer == netMes.ID() { - netMes.debugger.AddOutgoingMessage(topic, size, isRejected) - } else { - netMes.debugger.AddIncomingMessage(topic, size, isRejected) - } -} - -// UnregisterAllMessageProcessors will unregister all message processors for topics -func (netMes *networkMessenger) UnregisterAllMessageProcessors() error { - netMes.mutTopics.Lock() - defer netMes.mutTopics.Unlock() - - for topic := range netMes.processors { - if topic == common.ConnectionTopic { - delete(netMes.processors, topic) - continue - } - - err := netMes.pb.UnregisterTopicValidator(topic) - if err != nil { - return err - } - - delete(netMes.processors, topic) - } - return nil -} - -// UnjoinAllTopics call close on all topics -func (netMes *networkMessenger) UnjoinAllTopics() error { - netMes.mutTopics.Lock() - defer netMes.mutTopics.Unlock() - - var errFound error - for topicName, t := range netMes.topics { - subscr := netMes.subscriptions[topicName] - if subscr != nil { - subscr.Cancel() - } - - err := t.Close() - if err != nil { - log.Warn("error closing topic", - "topic", topicName, - "error", err, - ) - errFound = err - } - - delete(netMes.topics, topicName) - } - - return errFound -} - -// UnregisterMessageProcessor unregisters a message processes on a topic -func (netMes *networkMessenger) UnregisterMessageProcessor(topic string, identifier string) error { - netMes.mutTopics.Lock() - defer netMes.mutTopics.Unlock() - - topicProcs := netMes.processors[topic] - if topicProcs == nil { - return nil - } - - err := topicProcs.removeTopicProcessor(identifier) - if err != nil { - return err - } - - identifiers, _ := topicProcs.getList() - if len(identifiers) == 0 { - netMes.processors[topic] = nil - - if topic != common.ConnectionTopic { // no validator registered for this topic - return netMes.pb.UnregisterTopicValidator(topic) - } - } - - return nil -} - -// SendToConnectedPeer sends a direct message to a connected peer -func (netMes *networkMessenger) SendToConnectedPeer(topic string, buff []byte, peerID core.PeerID) error { - err := netMes.checkSendableData(buff) - if err != nil { - return err - } - - buffToSend := netMes.createMessageBytes(buff) - if len(buffToSend) == 0 { - return nil - } - - if peerID == netMes.ID() { - return netMes.sendDirectToSelf(topic, 
buffToSend) - } - - err = netMes.ds.Send(topic, buffToSend, peerID) - netMes.debugger.AddOutgoingMessage(topic, uint64(len(buffToSend)), err != nil) - - return err -} - -func (netMes *networkMessenger) sendDirectToSelf(topic string, buff []byte) error { - msg := &pubsub.Message{ - Message: &pubsubPb.Message{ - From: netMes.ID().Bytes(), - Data: buff, - Seqno: netMes.ds.NextSeqno(), - Topic: &topic, - Signature: netMes.ID().Bytes(), - }, - } - - return netMes.directMessageHandler(msg, netMes.ID()) -} - -func (netMes *networkMessenger) directMessageHandler(message *pubsub.Message, fromConnectedPeer core.PeerID) error { - topic := *message.Topic - msg, err := netMes.transformAndCheckMessage(message, fromConnectedPeer, topic) - if err != nil { - return err - } - - netMes.mutTopics.RLock() - topicProcs := netMes.processors[topic] - netMes.mutTopics.RUnlock() - - if topicProcs == nil { - return fmt.Errorf("%w on directMessageHandler for topic %s", p2p.ErrNilValidator, topic) - } - identifiers, handlers := topicProcs.getList() - - go func(msg p2p.MessageP2P) { - if check.IfNil(msg) { - return - } - - // we won't recheck the message id against the cacher here as there might be collisions since we are using - // a separate sequence counter for direct sender - messageOk := true - for index, handler := range handlers { - errProcess := handler.ProcessReceivedMessage(msg, fromConnectedPeer) - if errProcess != nil { - log.Trace("p2p validator", - "error", errProcess.Error(), - "topic", msg.Topic(), - "originator", p2p.MessageOriginatorPid(msg), - "from connected peer", p2p.PeerIdToShortString(fromConnectedPeer), - "seq no", p2p.MessageOriginatorSeq(msg), - "topic identifier", identifiers[index], - ) - messageOk = false - } - } - - netMes.debugger.AddIncomingMessage(msg.Topic(), uint64(len(msg.Data())), !messageOk) - - if messageOk { - netMes.peersRatingHandler.IncreaseRating(fromConnectedPeer) - } - }(msg) - - return nil -} - -// IsConnectedToTheNetwork returns true if the current node is connected to the network -func (netMes *networkMessenger) IsConnectedToTheNetwork() bool { - netw := netMes.p2pHost.Network() - return netMes.connMonitor.IsConnectedToTheNetwork(netw) -} - -// SetThresholdMinConnectedPeers sets the minimum connected peers before triggering a new reconnection -func (netMes *networkMessenger) SetThresholdMinConnectedPeers(minConnectedPeers int) error { - if minConnectedPeers < 0 { - return p2p.ErrInvalidValue - } - - netw := netMes.p2pHost.Network() - netMes.connMonitor.SetThresholdMinConnectedPeers(minConnectedPeers, netw) - - return nil -} - -// ThresholdMinConnectedPeers returns the minimum connected peers before triggering a new reconnection -func (netMes *networkMessenger) ThresholdMinConnectedPeers() int { - return netMes.connMonitor.ThresholdMinConnectedPeers() -} - -// SetPeerShardResolver sets the peer shard resolver component that is able to resolve the link -// between p2p.PeerID and shardId -func (netMes *networkMessenger) SetPeerShardResolver(peerShardResolver p2p.PeerShardResolver) error { - if check.IfNil(peerShardResolver) { - return p2p.ErrNilPeerShardResolver - } - - err := netMes.sharder.SetPeerShardResolver(peerShardResolver) - if err != nil { - return err - } - - netMes.mutPeerResolver.Lock() - netMes.peerShardResolver = peerShardResolver - netMes.mutPeerResolver.Unlock() - - return nil -} - -// SetPeerDenialEvaluator sets the peer black list handler -// TODO decide if we continue on using setters or switch to options. 
Refactor if necessary -func (netMes *networkMessenger) SetPeerDenialEvaluator(handler p2p.PeerDenialEvaluator) error { - return netMes.connMonitorWrapper.SetPeerDenialEvaluator(handler) -} - -// GetConnectedPeersInfo gets the current connected peers information -func (netMes *networkMessenger) GetConnectedPeersInfo() *p2p.ConnectedPeersInfo { - peers := netMes.p2pHost.Network().Peers() - connPeerInfo := &p2p.ConnectedPeersInfo{ - UnknownPeers: make([]string, 0), - Seeders: make([]string, 0), - IntraShardValidators: make(map[uint32][]string), - IntraShardObservers: make(map[uint32][]string), - CrossShardValidators: make(map[uint32][]string), - CrossShardObservers: make(map[uint32][]string), - FullHistoryObservers: make(map[uint32][]string), - NumObserversOnShard: make(map[uint32]int), - NumValidatorsOnShard: make(map[uint32]int), - NumPreferredPeersOnShard: make(map[uint32]int), - } - - netMes.mutPeerResolver.RLock() - defer netMes.mutPeerResolver.RUnlock() - - selfPeerInfo := netMes.peerShardResolver.GetPeerInfo(netMes.ID()) - connPeerInfo.SelfShardID = selfPeerInfo.ShardID - - for _, p := range peers { - conns := netMes.p2pHost.Network().ConnsToPeer(p) - connString := "[invalid connection string]" - if len(conns) > 0 { - connString = conns[0].RemoteMultiaddr().String() + "/p2p/" + p.Pretty() - } - - pid := core.PeerID(p) - peerInfo := netMes.peerShardResolver.GetPeerInfo(pid) - switch peerInfo.PeerType { - case core.UnknownPeer: - if netMes.sharder.IsSeeder(pid) { - connPeerInfo.Seeders = append(connPeerInfo.Seeders, connString) - } else { - connPeerInfo.UnknownPeers = append(connPeerInfo.UnknownPeers, connString) - } - case core.ValidatorPeer: - connPeerInfo.NumValidatorsOnShard[peerInfo.ShardID]++ - if selfPeerInfo.ShardID != peerInfo.ShardID { - connPeerInfo.CrossShardValidators[peerInfo.ShardID] = append(connPeerInfo.CrossShardValidators[peerInfo.ShardID], connString) - connPeerInfo.NumCrossShardValidators++ - } else { - connPeerInfo.IntraShardValidators[peerInfo.ShardID] = append(connPeerInfo.IntraShardValidators[peerInfo.ShardID], connString) - connPeerInfo.NumIntraShardValidators++ - } - case core.ObserverPeer: - connPeerInfo.NumObserversOnShard[peerInfo.ShardID]++ - if peerInfo.PeerSubType == core.FullHistoryObserver { - connPeerInfo.FullHistoryObservers[peerInfo.ShardID] = append(connPeerInfo.FullHistoryObservers[peerInfo.ShardID], connString) - connPeerInfo.NumFullHistoryObservers++ - break - } - if selfPeerInfo.ShardID != peerInfo.ShardID { - connPeerInfo.CrossShardObservers[peerInfo.ShardID] = append(connPeerInfo.CrossShardObservers[peerInfo.ShardID], connString) - connPeerInfo.NumCrossShardObservers++ - break - } - - connPeerInfo.IntraShardObservers[peerInfo.ShardID] = append(connPeerInfo.IntraShardObservers[peerInfo.ShardID], connString) - connPeerInfo.NumIntraShardObservers++ - } - - if netMes.preferredPeersHolder.Contains(pid) { - connPeerInfo.NumPreferredPeersOnShard[peerInfo.ShardID]++ - } - } - - return connPeerInfo -} - -// Port returns the port that this network messenger is using -func (netMes *networkMessenger) Port() int { - return netMes.port -} - -// IsInterfaceNil returns true if there is no value under the interface -func (netMes *networkMessenger) IsInterfaceNil() bool { - return netMes == nil -} diff --git a/p2p/libp2p/netMessenger_test.go b/p2p/libp2p/netMessenger_test.go deleted file mode 100644 index 7da5d77ceee..00000000000 --- a/p2p/libp2p/netMessenger_test.go +++ /dev/null @@ -1,1984 +0,0 @@ -package libp2p_test - -import ( - "bytes" - "context" - 
"errors" - "fmt" - "runtime" - "strings" - "sync" - "sync/atomic" - "testing" - "time" - - "github.com/ElrondNetwork/elrond-go-core/core" - "github.com/ElrondNetwork/elrond-go-core/core/check" - "github.com/ElrondNetwork/elrond-go-core/marshal" - logger "github.com/ElrondNetwork/elrond-go-logger" - "github.com/ElrondNetwork/elrond-go/common" - "github.com/ElrondNetwork/elrond-go/config" - "github.com/ElrondNetwork/elrond-go/p2p" - "github.com/ElrondNetwork/elrond-go/p2p/crypto" - "github.com/ElrondNetwork/elrond-go/p2p/data" - "github.com/ElrondNetwork/elrond-go/p2p/libp2p" - "github.com/ElrondNetwork/elrond-go/p2p/message" - "github.com/ElrondNetwork/elrond-go/p2p/mock" - "github.com/ElrondNetwork/elrond-go/testscommon" - "github.com/ElrondNetwork/elrond-go/testscommon/p2pmocks" - pubsub "github.com/ElrondNetwork/go-libp2p-pubsub" - pb "github.com/ElrondNetwork/go-libp2p-pubsub/pb" - "github.com/libp2p/go-libp2p-core/network" - "github.com/libp2p/go-libp2p-core/peer" - "github.com/libp2p/go-libp2p-core/peerstore" - mocknet "github.com/libp2p/go-libp2p/p2p/net/mock" - "github.com/multiformats/go-multiaddr" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -var timeoutWaitResponses = time.Second * 2 - -func waitDoneWithTimeout(t *testing.T, chanDone chan bool, timeout time.Duration) { - select { - case <-chanDone: - return - case <-time.After(timeout): - assert.Fail(t, "timeout reached") - } -} - -func prepareMessengerForMatchDataReceive(messenger p2p.Messenger, matchData []byte, wg *sync.WaitGroup) { - _ = messenger.CreateTopic("test", false) - - _ = messenger.RegisterMessageProcessor("test", "identifier", - &mock.MessageProcessorStub{ - ProcessMessageCalled: func(message p2p.MessageP2P, _ core.PeerID) error { - if bytes.Equal(matchData, message.Data()) { - fmt.Printf("%s got the message\n", messenger.ID().Pretty()) - wg.Done() - } - - return nil - }, - }) -} - -func getConnectableAddress(messenger p2p.Messenger) string { - for _, addr := range messenger.Addresses() { - if strings.Contains(addr, "circuit") || strings.Contains(addr, "169.254") { - continue - } - - return addr - } - - return "" -} - -func createMockNetworkArgs() libp2p.ArgsNetworkMessenger { - return libp2p.ArgsNetworkMessenger{ - Marshalizer: &testscommon.ProtoMarshalizerMock{}, - ListenAddress: libp2p.ListenLocalhostAddrWithIp4AndTcp, - P2pConfig: config.P2PConfig{ - Node: config.NodeConfig{ - Port: "0", - }, - KadDhtPeerDiscovery: config.KadDhtPeerDiscoveryConfig{ - Enabled: false, - }, - Sharding: config.ShardingConfig{ - Type: p2p.NilListSharder, - }, - }, - SyncTimer: &libp2p.LocalSyncTimer{}, - PreferredPeersHolder: &p2pmocks.PeersHolderStub{}, - PeersRatingHandler: &p2pmocks.PeersRatingHandlerStub{}, - ConnectionWatcherType: p2p.ConnectionWatcherTypePrint, - } -} - -func createMockNetworkOf2() (mocknet.Mocknet, p2p.Messenger, p2p.Messenger) { - netw := mocknet.New() - - messenger1, _ := libp2p.NewMockMessenger(createMockNetworkArgs(), netw) - messenger2, _ := libp2p.NewMockMessenger(createMockNetworkArgs(), netw) - - _ = netw.LinkAll() - - return netw, messenger1, messenger2 -} - -func createMockNetworkOf3() (p2p.Messenger, p2p.Messenger, p2p.Messenger) { - netw := mocknet.New() - - messenger1, _ := libp2p.NewMockMessenger(createMockNetworkArgs(), netw) - messenger2, _ := libp2p.NewMockMessenger(createMockNetworkArgs(), netw) - messenger3, _ := libp2p.NewMockMessenger(createMockNetworkArgs(), netw) - - _ = netw.LinkAll() - - nscm1 := mock.NewNetworkShardingCollectorMock() - 
nscm1.PutPeerIdSubType(messenger1.ID(), core.FullHistoryObserver) - nscm1.PutPeerIdSubType(messenger2.ID(), core.FullHistoryObserver) - nscm1.PutPeerIdSubType(messenger3.ID(), core.RegularPeer) - _ = messenger1.SetPeerShardResolver(nscm1) - - nscm2 := mock.NewNetworkShardingCollectorMock() - nscm2.PutPeerIdSubType(messenger1.ID(), core.FullHistoryObserver) - nscm2.PutPeerIdSubType(messenger2.ID(), core.FullHistoryObserver) - nscm2.PutPeerIdSubType(messenger3.ID(), core.RegularPeer) - _ = messenger2.SetPeerShardResolver(nscm2) - - nscm3 := mock.NewNetworkShardingCollectorMock() - nscm3.PutPeerIdSubType(messenger1.ID(), core.FullHistoryObserver) - nscm3.PutPeerIdSubType(messenger2.ID(), core.FullHistoryObserver) - nscm3.PutPeerIdSubType(messenger3.ID(), core.RegularPeer) - _ = messenger3.SetPeerShardResolver(nscm3) - - return messenger1, messenger2, messenger3 -} - -func createMockMessenger() p2p.Messenger { - netw := mocknet.New() - - messenger, _ := libp2p.NewMockMessenger(createMockNetworkArgs(), netw) - - return messenger -} - -func containsPeerID(list []core.PeerID, searchFor core.PeerID) bool { - for _, pid := range list { - if bytes.Equal(pid.Bytes(), searchFor.Bytes()) { - return true - } - } - return false -} - -// ------- NewMemoryLibp2pMessenger - -func TestNewMemoryLibp2pMessenger_NilMockNetShouldErr(t *testing.T) { - args := createMockNetworkArgs() - messenger, err := libp2p.NewMockMessenger(args, nil) - - assert.Nil(t, messenger) - assert.Equal(t, p2p.ErrNilMockNet, err) -} - -func TestNewMemoryLibp2pMessenger_OkValsWithoutDiscoveryShouldWork(t *testing.T) { - netw := mocknet.New() - - messenger, err := libp2p.NewMockMessenger(createMockNetworkArgs(), netw) - - assert.Nil(t, err) - assert.False(t, check.IfNil(messenger)) - - _ = messenger.Close() -} - -// ------- NewNetworkMessenger - -func TestNewNetworkMessenger_NilMessengerShouldErr(t *testing.T) { - arg := createMockNetworkArgs() - arg.Marshalizer = nil - messenger, err := libp2p.NewNetworkMessenger(arg) - - assert.True(t, check.IfNil(messenger)) - assert.True(t, errors.Is(err, p2p.ErrNilMarshalizer)) -} - -func TestNewNetworkMessenger_NilPreferredPeersHolderShouldErr(t *testing.T) { - arg := createMockNetworkArgs() - arg.PreferredPeersHolder = nil - messenger, err := libp2p.NewNetworkMessenger(arg) - - assert.True(t, check.IfNil(messenger)) - assert.True(t, errors.Is(err, p2p.ErrNilPreferredPeersHolder)) -} - -func TestNewNetworkMessenger_NilPeersRatingHandlerShouldErr(t *testing.T) { - arg := createMockNetworkArgs() - arg.PeersRatingHandler = nil - mes, err := libp2p.NewNetworkMessenger(arg) - - assert.True(t, check.IfNil(mes)) - assert.True(t, errors.Is(err, p2p.ErrNilPeersRatingHandler)) -} - -func TestNewNetworkMessenger_NilSyncTimerShouldErr(t *testing.T) { - arg := createMockNetworkArgs() - arg.SyncTimer = nil - messenger, err := libp2p.NewNetworkMessenger(arg) - - assert.True(t, check.IfNil(messenger)) - assert.True(t, errors.Is(err, p2p.ErrNilSyncTimer)) -} - -func TestNewNetworkMessenger_WithDeactivatedKadDiscovererShouldWork(t *testing.T) { - arg := createMockNetworkArgs() - messenger, err := libp2p.NewNetworkMessenger(arg) - - assert.NotNil(t, messenger) - assert.Nil(t, err) - - _ = messenger.Close() -} - -func TestNewNetworkMessenger_WithKadDiscovererListsSharderInvalidTargetConnShouldErr(t *testing.T) { - arg := createMockNetworkArgs() - arg.P2pConfig.KadDhtPeerDiscovery = config.KadDhtPeerDiscoveryConfig{ - Enabled: true, - Type: "optimized", - RefreshIntervalInSec: 10, - ProtocolID: "/erd/kad/1.0.0", - 
InitialPeerList: nil, - BucketSize: 100, - RoutingTableRefreshIntervalInSec: 10, - } - arg.P2pConfig.Sharding.Type = p2p.ListsSharder - messenger, err := libp2p.NewNetworkMessenger(arg) - - assert.True(t, check.IfNil(messenger)) - assert.True(t, errors.Is(err, p2p.ErrInvalidValue)) -} - -func TestNewNetworkMessenger_WithKadDiscovererListSharderShouldWork(t *testing.T) { - arg := createMockNetworkArgs() - arg.P2pConfig.KadDhtPeerDiscovery = config.KadDhtPeerDiscoveryConfig{ - Enabled: true, - Type: "optimized", - RefreshIntervalInSec: 10, - ProtocolID: "/erd/kad/1.0.0", - InitialPeerList: nil, - BucketSize: 100, - RoutingTableRefreshIntervalInSec: 10, - } - arg.P2pConfig.Sharding = config.ShardingConfig{ - Type: p2p.NilListSharder, - TargetPeerCount: 10, - } - messenger, err := libp2p.NewNetworkMessenger(arg) - - assert.False(t, check.IfNil(messenger)) - assert.Nil(t, err) - - _ = messenger.Close() -} - -// ------- Messenger functionality - -func TestLibp2pMessenger_ConnectToPeerShouldCallUpgradedHost(t *testing.T) { - netw := mocknet.New() - - messenger, _ := libp2p.NewMockMessenger(createMockNetworkArgs(), netw) - _ = messenger.Close() - - wasCalled := false - - p := "peer" - - uhs := &mock.ConnectableHostStub{ - ConnectToPeerCalled: func(ctx context.Context, address string) error { - if p == address { - wasCalled = true - } - return nil - }, - } - - messenger.SetHost(uhs) - _ = messenger.ConnectToPeer(p) - assert.True(t, wasCalled) -} - -func TestLibp2pMessenger_IsConnectedShouldWork(t *testing.T) { - _, messenger1, messenger2 := createMockNetworkOf2() - - adr2 := messenger2.Addresses()[0] - - fmt.Printf("Connecting to %s...\n", adr2) - - _ = messenger1.ConnectToPeer(adr2) - - assert.True(t, messenger1.IsConnected(messenger2.ID())) - assert.True(t, messenger2.IsConnected(messenger1.ID())) - - _ = messenger1.Close() - _ = messenger2.Close() -} - -func TestLibp2pMessenger_CreateTopicOkValsShouldWork(t *testing.T) { - messenger := createMockMessenger() - - err := messenger.CreateTopic("test", true) - assert.Nil(t, err) - - _ = messenger.Close() -} - -func TestLibp2pMessenger_CreateTopicTwiceShouldNotErr(t *testing.T) { - messenger := createMockMessenger() - - _ = messenger.CreateTopic("test", false) - err := messenger.CreateTopic("test", false) - assert.Nil(t, err) - - _ = messenger.Close() -} - -func TestLibp2pMessenger_HasTopicIfHaveTopicShouldReturnTrue(t *testing.T) { - messenger := createMockMessenger() - - _ = messenger.CreateTopic("test", false) - - assert.True(t, messenger.HasTopic("test")) - - _ = messenger.Close() -} - -func TestLibp2pMessenger_HasTopicIfDoNotHaveTopicShouldReturnFalse(t *testing.T) { - messenger := createMockMessenger() - - _ = messenger.CreateTopic("test", false) - - assert.False(t, messenger.HasTopic("one topic")) - - _ = messenger.Close() -} - -func TestLibp2pMessenger_RegisterTopicValidatorOnInexistentTopicShouldWork(t *testing.T) { - messenger := createMockMessenger() - - err := messenger.RegisterMessageProcessor("test", "identifier", &mock.MessageProcessorStub{}) - - assert.Nil(t, err) - - _ = messenger.Close() -} - -func TestLibp2pMessenger_RegisterTopicValidatorWithNilHandlerShouldErr(t *testing.T) { - messenger := createMockMessenger() - - _ = messenger.CreateTopic("test", false) - - err := messenger.RegisterMessageProcessor("test", "identifier", nil) - - assert.True(t, errors.Is(err, p2p.ErrNilValidator)) - - _ = messenger.Close() -} - -func TestLibp2pMessenger_RegisterTopicValidatorOkValsShouldWork(t *testing.T) { - messenger := createMockMessenger() - 
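The nil-dependency constructor tests above all share one shape; a hypothetical table-driven equivalent (a sketch only, not part of this change set) could compress them into a single test:

func TestNewNetworkMessenger_NilDependenciesTableDriven(t *testing.T) {
	tests := []struct {
		name        string
		mutate      func(arg *libp2p.ArgsNetworkMessenger)
		expectedErr error
	}{
		{"nil marshalizer", func(arg *libp2p.ArgsNetworkMessenger) { arg.Marshalizer = nil }, p2p.ErrNilMarshalizer},
		{"nil preferred peers holder", func(arg *libp2p.ArgsNetworkMessenger) { arg.PreferredPeersHolder = nil }, p2p.ErrNilPreferredPeersHolder},
		{"nil peers rating handler", func(arg *libp2p.ArgsNetworkMessenger) { arg.PeersRatingHandler = nil }, p2p.ErrNilPeersRatingHandler},
		{"nil sync timer", func(arg *libp2p.ArgsNetworkMessenger) { arg.SyncTimer = nil }, p2p.ErrNilSyncTimer},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			arg := createMockNetworkArgs()
			tt.mutate(&arg)
			messenger, err := libp2p.NewNetworkMessenger(arg)
			assert.True(t, check.IfNil(messenger))
			assert.True(t, errors.Is(err, tt.expectedErr))
		})
	}
}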
- _ = messenger.CreateTopic("test", false) - - err := messenger.RegisterMessageProcessor("test", "identifier", &mock.MessageProcessorStub{}) - - assert.Nil(t, err) - - _ = messenger.Close() -} - -func TestLibp2pMessenger_RegisterTopicValidatorReregistrationShouldErr(t *testing.T) { - messenger := createMockMessenger() - _ = messenger.CreateTopic("test", false) - // registration - _ = messenger.RegisterMessageProcessor("test", "identifier", &mock.MessageProcessorStub{}) - // re-registration - err := messenger.RegisterMessageProcessor("test", "identifier", &mock.MessageProcessorStub{}) - - assert.True(t, errors.Is(err, p2p.ErrMessageProcessorAlreadyDefined)) - - _ = messenger.Close() -} - -func TestLibp2pMessenger_UnegisterTopicValidatorOnANotRegisteredTopicShouldNotErr(t *testing.T) { - messenger := createMockMessenger() - - _ = messenger.CreateTopic("test", false) - err := messenger.UnregisterMessageProcessor("test", "identifier") - - assert.Nil(t, err) - - _ = messenger.Close() -} - -func TestLibp2pMessenger_UnregisterTopicValidatorShouldWork(t *testing.T) { - messenger := createMockMessenger() - - _ = messenger.CreateTopic("test", false) - - // registration - _ = messenger.RegisterMessageProcessor("test", "identifier", &mock.MessageProcessorStub{}) - - // unregistration - err := messenger.UnregisterMessageProcessor("test", "identifier") - - assert.Nil(t, err) - - _ = messenger.Close() -} - -func TestLibp2pMessenger_UnregisterAllTopicValidatorShouldWork(t *testing.T) { - messenger := createMockMessenger() - _ = messenger.CreateTopic("test", false) - // registration - _ = messenger.CreateTopic("test1", false) - _ = messenger.RegisterMessageProcessor("test1", "identifier", &mock.MessageProcessorStub{}) - _ = messenger.CreateTopic("test2", false) - _ = messenger.RegisterMessageProcessor("test2", "identifier", &mock.MessageProcessorStub{}) - // unregistration - err := messenger.UnregisterAllMessageProcessors() - assert.Nil(t, err) - err = messenger.RegisterMessageProcessor("test1", "identifier", &mock.MessageProcessorStub{}) - assert.Nil(t, err) - err = messenger.RegisterMessageProcessor("test2", "identifier", &mock.MessageProcessorStub{}) - assert.Nil(t, err) - _ = messenger.Close() -} - -func TestLibp2pMessenger_RegisterUnregisterConcurrentlyShouldNotPanic(t *testing.T) { - defer func() { - r := recover() - if r != nil { - assert.Fail(t, fmt.Sprintf("should have not panic: %v", r)) - } - }() - - messenger := createMockMessenger() - topic := "test topic" - _ = messenger.CreateTopic(topic, false) - - numIdentifiers := 100 - identifiers := make([]string, 0, numIdentifiers) - for i := 0; i < numIdentifiers; i++ { - identifiers = append(identifiers, fmt.Sprintf("identifier%d", i)) - } - - wg := sync.WaitGroup{} - wg.Add(numIdentifiers * 3) - for i := 0; i < numIdentifiers; i++ { - go func(index int) { - _ = messenger.RegisterMessageProcessor(topic, identifiers[index], &mock.MessageProcessorStub{}) - wg.Done() - }(i) - - go func(index int) { - _ = messenger.UnregisterMessageProcessor(topic, identifiers[index]) - wg.Done() - }(i) - - go func() { - messenger.Broadcast(topic, []byte("buff")) - wg.Done() - }() - } - - wg.Wait() - _ = messenger.Close() -} - -func TestLibp2pMessenger_BroadcastDataLargeMessageShouldNotCallSend(t *testing.T) { - msg := make([]byte, libp2p.MaxSendBuffSize+1) - messenger, _ := libp2p.NewNetworkMessenger(createMockNetworkArgs()) - messenger.SetLoadBalancer(&mock.ChannelLoadBalancerStub{ - GetChannelOrDefaultCalled: func(pipe string) chan *p2p.SendableData { - assert.Fail(t, 
"should have not got to this line") - - return make(chan *p2p.SendableData, 1) - }, - CollectOneElementFromChannelsCalled: func() *p2p.SendableData { - return nil - }, - }) - - messenger.Broadcast("topic", msg) - - _ = messenger.Close() -} - -func TestLibp2pMessenger_BroadcastDataBetween2PeersShouldWork(t *testing.T) { - msg := []byte("test message") - - _, messenger1, messenger2 := createMockNetworkOf2() - - adr2 := messenger2.Addresses()[0] - - fmt.Printf("Connecting to %s...\n", adr2) - - _ = messenger1.ConnectToPeer(adr2) - - wg := &sync.WaitGroup{} - chanDone := make(chan bool) - wg.Add(2) - - go func() { - wg.Wait() - chanDone <- true - }() - - prepareMessengerForMatchDataReceive(messenger1, msg, wg) - prepareMessengerForMatchDataReceive(messenger2, msg, wg) - - fmt.Println("Delaying as to allow peers to announce themselves on the opened topic...") - time.Sleep(time.Second) - - fmt.Printf("sending message from %s...\n", messenger1.ID().Pretty()) - - messenger1.Broadcast("test", msg) - - waitDoneWithTimeout(t, chanDone, timeoutWaitResponses) - - _ = messenger1.Close() - _ = messenger2.Close() -} - -func TestLibp2pMessenger_BroadcastOnChannelBlockingShouldLimitNumberOfGoRoutines(t *testing.T) { - if testing.Short() { - t.Skip("this test does not perform well in TC with race detector on") - } - - msg := []byte("test message") - numBroadcasts := libp2p.BroadcastGoRoutines + 5 - - ch := make(chan *p2p.SendableData) - - wg := sync.WaitGroup{} - wg.Add(numBroadcasts) - - messenger, _ := libp2p.NewNetworkMessenger(createMockNetworkArgs()) - messenger.SetLoadBalancer(&mock.ChannelLoadBalancerStub{ - CollectOneElementFromChannelsCalled: func() *p2p.SendableData { - return nil - }, - GetChannelOrDefaultCalled: func(pipe string) chan *p2p.SendableData { - wg.Done() - return ch - }, - }) - - numErrors := uint32(0) - - for i := 0; i < numBroadcasts; i++ { - go func() { - err := messenger.BroadcastOnChannelBlocking("test", "test", msg) - if err == p2p.ErrTooManyGoroutines { - atomic.AddUint32(&numErrors, 1) - wg.Done() - } - }() - } - - wg.Wait() - - // cleanup stuck go routines that are trying to write on the ch channel - for i := 0; i < libp2p.BroadcastGoRoutines; i++ { - select { - case <-ch: - default: - } - } - - assert.True(t, atomic.LoadUint32(&numErrors) > 0) - - _ = messenger.Close() -} - -func TestLibp2pMessenger_BroadcastDataBetween2PeersWithLargeMsgShouldWork(t *testing.T) { - msg := bytes.Repeat([]byte{'A'}, libp2p.MaxSendBuffSize) - - _, messenger1, messenger2 := createMockNetworkOf2() - - adr2 := messenger2.Addresses()[0] - - fmt.Printf("Connecting to %s...\n", adr2) - - _ = messenger1.ConnectToPeer(adr2) - - wg := &sync.WaitGroup{} - chanDone := make(chan bool) - wg.Add(2) - - go func() { - wg.Wait() - chanDone <- true - }() - - prepareMessengerForMatchDataReceive(messenger1, msg, wg) - prepareMessengerForMatchDataReceive(messenger2, msg, wg) - - fmt.Println("Delaying as to allow peers to announce themselves on the opened topic...") - time.Sleep(time.Second) - - fmt.Printf("sending message from %s...\n", messenger1.ID().Pretty()) - - messenger1.Broadcast("test", msg) - - waitDoneWithTimeout(t, chanDone, timeoutWaitResponses) - - _ = messenger1.Close() - _ = messenger2.Close() -} - -func TestLibp2pMessenger_Peers(t *testing.T) { - _, messenger1, messenger2 := createMockNetworkOf2() - - adr2 := messenger2.Addresses()[0] - - fmt.Printf("Connecting to %s...\n", adr2) - - _ = messenger1.ConnectToPeer(adr2) - - // should know both peers - foundCurrent := false - foundConnected := false - 
- for _, p := range messenger1.Peers() { - fmt.Println(p.Pretty()) - - if p.Pretty() == messenger1.ID().Pretty() { - foundCurrent = true - } - if p.Pretty() == messenger2.ID().Pretty() { - foundConnected = true - } - } - - assert.True(t, foundCurrent && foundConnected) - - _ = messenger1.Close() - _ = messenger2.Close() -} - -func TestLibp2pMessenger_ConnectedPeers(t *testing.T) { - netw, messenger1, messenger2 := createMockNetworkOf2() - messenger3, _ := libp2p.NewMockMessenger(createMockNetworkArgs(), netw) - - _ = netw.LinkAll() - - adr2 := messenger2.Addresses()[0] - - fmt.Printf("Connecting to %s...\n", adr2) - - _ = messenger1.ConnectToPeer(adr2) - _ = messenger3.ConnectToPeer(adr2) - - // connected peers: 1 ----- 2 ----- 3 - - assert.Equal(t, []core.PeerID{messenger2.ID()}, messenger1.ConnectedPeers()) - assert.Equal(t, []core.PeerID{messenger2.ID()}, messenger3.ConnectedPeers()) - assert.Equal(t, 2, len(messenger2.ConnectedPeers())) - // no need to further test that messenger2 is connected to messenger1 and messenger3 as this was tested in first 2 asserts - - _ = messenger1.Close() - _ = messenger2.Close() - _ = messenger3.Close() -} - -func TestLibp2pMessenger_ConnectedAddresses(t *testing.T) { - netw, messenger1, messenger2 := createMockNetworkOf2() - messenger3, _ := libp2p.NewMockMessenger(createMockNetworkArgs(), netw) - - _ = netw.LinkAll() - - adr2 := messenger2.Addresses()[0] - - fmt.Printf("Connecting to %s...\n", adr2) - - _ = messenger1.ConnectToPeer(adr2) - _ = messenger3.ConnectToPeer(adr2) - - // connected peers: 1 ----- 2 ----- 3 - - foundAddr1 := false - foundAddr3 := false - - for _, addr := range messenger2.ConnectedAddresses() { - for _, address := range messenger1.Addresses() { - if addr == address { - foundAddr1 = true - } - } - - for _, address := range messenger3.Addresses() { - if addr == address { - foundAddr3 = true - } - } - } - - assert.True(t, foundAddr1) - assert.True(t, foundAddr3) - assert.Equal(t, 2, len(messenger2.ConnectedAddresses())) - // no need to further test that messenger2 is connected to messenger1 and messenger3 as this was tested in first 2 asserts - - _ = messenger1.Close() - _ = messenger2.Close() - _ = messenger3.Close() -} - -func TestLibp2pMessenger_PeerAddressConnectedPeerShouldWork(t *testing.T) { - netw, messenger1, messenger2 := createMockNetworkOf2() - messenger3, _ := libp2p.NewMockMessenger(createMockNetworkArgs(), netw) - - _ = netw.LinkAll() - - adr2 := messenger2.Addresses()[0] - - fmt.Printf("Connecting to %s...\n", adr2) - - _ = messenger1.ConnectToPeer(adr2) - _ = messenger3.ConnectToPeer(adr2) - - // connected peers: 1 ----- 2 ----- 3 - - defer func() { - _ = messenger1.Close() - _ = messenger2.Close() - _ = messenger3.Close() - }() - - addressesRecov := messenger2.PeerAddresses(messenger1.ID()) - for _, addr := range messenger1.Addresses() { - for _, addrRecov := range addressesRecov { - if strings.Contains(addr, addrRecov) { - // address returned is valid, test is successful - return - } - } - } - - assert.Fail(t, "Returned address is not valid!") -} - -func TestLibp2pMessenger_PeerAddressNotConnectedShouldReturnFromPeerstore(t *testing.T) { - netw := mocknet.New() - messenger, _ := libp2p.NewMockMessenger(createMockNetworkArgs(), netw) - - networkHandler := &mock.NetworkStub{ - ConnsCalled: func() []network.Conn { - return nil - }, - } - - peerstoreHandler := &mock.PeerstoreStub{ - AddrsCalled: func(p peer.ID) []multiaddr.Multiaddr { - return []multiaddr.Multiaddr{ - &mock.MultiaddrStub{ - StringCalled: func() 
string { - return "multiaddress 1" - }, - }, - &mock.MultiaddrStub{ - StringCalled: func() string { - return "multiaddress 2" - }, - }, - } - }, - } - - messenger.SetHost(&mock.ConnectableHostStub{ - NetworkCalled: func() network.Network { - return networkHandler - }, - PeerstoreCalled: func() peerstore.Peerstore { - return peerstoreHandler - }, - }) - - addresses := messenger.PeerAddresses("pid") - require.Equal(t, 2, len(addresses)) - assert.Equal(t, addresses[0], "multiaddress 1") - assert.Equal(t, addresses[1], "multiaddress 2") -} - -func TestLibp2pMessenger_PeerAddressDisconnectedPeerShouldWork(t *testing.T) { - netw, messenger1, messenger2 := createMockNetworkOf2() - messenger3, _ := libp2p.NewMockMessenger(createMockNetworkArgs(), netw) - - _ = netw.LinkAll() - - adr2 := messenger2.Addresses()[0] - - fmt.Printf("Connecting to %s...\n", adr2) - - _ = messenger1.ConnectToPeer(adr2) - _ = messenger3.ConnectToPeer(adr2) - - defer func() { - _ = messenger1.Close() - _ = messenger2.Close() - _ = messenger3.Close() - }() - - _ = netw.UnlinkPeers(peer.ID(messenger1.ID().Bytes()), peer.ID(messenger2.ID().Bytes())) - _ = netw.DisconnectPeers(peer.ID(messenger1.ID().Bytes()), peer.ID(messenger2.ID().Bytes())) - _ = netw.DisconnectPeers(peer.ID(messenger2.ID().Bytes()), peer.ID(messenger1.ID().Bytes())) - - // connected peers: 1 --x-- 2 ----- 3 - - assert.False(t, messenger2.IsConnected(messenger1.ID())) -} - -func TestLibp2pMessenger_PeerAddressUnknownPeerShouldReturnEmpty(t *testing.T) { - _, messenger1, _ := createMockNetworkOf2() - - defer func() { - _ = messenger1.Close() - }() - - adr1Recov := messenger1.PeerAddresses("unknown peer") - assert.Equal(t, 0, len(adr1Recov)) -} - -// ------- ConnectedPeersOnTopic - -func TestLibp2pMessenger_ConnectedPeersOnTopicInvalidTopicShouldRetEmptyList(t *testing.T) { - netw, messenger1, messenger2 := createMockNetworkOf2() - messenger3, _ := libp2p.NewMockMessenger(createMockNetworkArgs(), netw) - _ = netw.LinkAll() - - adr2 := messenger2.Addresses()[0] - fmt.Printf("Connecting to %s...\n", adr2) - - _ = messenger1.ConnectToPeer(adr2) - _ = messenger3.ConnectToPeer(adr2) - // connected peers: 1 ----- 2 ----- 3 - connPeers := messenger1.ConnectedPeersOnTopic("non-existent topic") - assert.Equal(t, 0, len(connPeers)) - - _ = messenger1.Close() - _ = messenger2.Close() - _ = messenger3.Close() -} - -func TestLibp2pMessenger_ConnectedPeersOnTopicOneTopicShouldWork(t *testing.T) { - netw, messenger1, messenger2 := createMockNetworkOf2() - messenger3, _ := libp2p.NewMockMessenger(createMockNetworkArgs(), netw) - messenger4, _ := libp2p.NewMockMessenger(createMockNetworkArgs(), netw) - _ = netw.LinkAll() - - adr2 := messenger2.Addresses()[0] - fmt.Printf("Connecting to %s...\n", adr2) - - _ = messenger1.ConnectToPeer(adr2) - _ = messenger3.ConnectToPeer(adr2) - _ = messenger4.ConnectToPeer(adr2) - // connected peers: 1 ----- 2 ----- 3 - // | - // 4 - // 1, 2, 3 should be on topic "topic123" - _ = messenger1.CreateTopic("topic123", false) - _ = messenger2.CreateTopic("topic123", false) - _ = messenger3.CreateTopic("topic123", false) - - // wait a bit for topic announcements - time.Sleep(time.Second) - - peersOnTopic123 := messenger2.ConnectedPeersOnTopic("topic123") - - assert.Equal(t, 2, len(peersOnTopic123)) - assert.True(t, containsPeerID(peersOnTopic123, messenger1.ID())) - assert.True(t, containsPeerID(peersOnTopic123, messenger3.ID())) - - _ = messenger1.Close() - _ = messenger2.Close() - _ = messenger3.Close() - _ = messenger4.Close() -} - -func 
TestLibp2pMessenger_ConnectedPeersOnTopicOneTopicDifferentViewsShouldWork(t *testing.T) { - netw, messenger1, messenger2 := createMockNetworkOf2() - messenger3, _ := libp2p.NewMockMessenger(createMockNetworkArgs(), netw) - messenger4, _ := libp2p.NewMockMessenger(createMockNetworkArgs(), netw) - _ = netw.LinkAll() - - adr2 := messenger2.Addresses()[0] - fmt.Printf("Connecting to %s...\n", adr2) - - _ = messenger1.ConnectToPeer(adr2) - _ = messenger3.ConnectToPeer(adr2) - _ = messenger4.ConnectToPeer(adr2) - // connected peers: 1 ----- 2 ----- 3 - // | - // 4 - // 1, 2, 3 should be on topic "topic123" - _ = messenger1.CreateTopic("topic123", false) - _ = messenger2.CreateTopic("topic123", false) - _ = messenger3.CreateTopic("topic123", false) - - // wait a bit for topic announcements - time.Sleep(time.Second) - - peersOnTopic123FromMessenger2 := messenger2.ConnectedPeersOnTopic("topic123") - peersOnTopic123FromMessenger4 := messenger4.ConnectedPeersOnTopic("topic123") - - // keep the same checks as in the test above to be 100% sure that the returned lists are correct - assert.Equal(t, 2, len(peersOnTopic123FromMessenger2)) - assert.True(t, containsPeerID(peersOnTopic123FromMessenger2, messenger1.ID())) - assert.True(t, containsPeerID(peersOnTopic123FromMessenger2, messenger3.ID())) - - assert.Equal(t, 1, len(peersOnTopic123FromMessenger4)) - assert.True(t, containsPeerID(peersOnTopic123FromMessenger4, messenger2.ID())) - - _ = messenger1.Close() - _ = messenger2.Close() - _ = messenger3.Close() - _ = messenger4.Close() -} - -func TestLibp2pMessenger_ConnectedPeersOnTopicTwoTopicsShouldWork(t *testing.T) { - netw, messenger1, messenger2 := createMockNetworkOf2() - messenger3, _ := libp2p.NewMockMessenger(createMockNetworkArgs(), netw) - messenger4, _ := libp2p.NewMockMessenger(createMockNetworkArgs(), netw) - _ = netw.LinkAll() - - adr2 := messenger2.Addresses()[0] - fmt.Printf("Connecting to %s...\n", adr2) - - _ = messenger1.ConnectToPeer(adr2) - _ = messenger3.ConnectToPeer(adr2) - _ = messenger4.ConnectToPeer(adr2) - // connected peers: 1 ----- 2 ----- 3 - // | - // 4 - // 1, 2, 3 should be on topic "topic123" - // 2, 4 should be on topic "topic24" - _ = messenger1.CreateTopic("topic123", false) - _ = messenger2.CreateTopic("topic123", false) - _ = messenger2.CreateTopic("topic24", false) - _ = messenger3.CreateTopic("topic123", false) - _ = messenger4.CreateTopic("topic24", false) - - // wait a bit for topic announcements - time.Sleep(time.Second) - - peersOnTopic123 := messenger2.ConnectedPeersOnTopic("topic123") - peersOnTopic24 := messenger2.ConnectedPeersOnTopic("topic24") - - // keep the same checks as in the test above to be 100% sure that the returned lists are correct - assert.Equal(t, 2, len(peersOnTopic123)) - assert.True(t, containsPeerID(peersOnTopic123, messenger1.ID())) - assert.True(t, containsPeerID(peersOnTopic123, messenger3.ID())) - - assert.Equal(t, 1, len(peersOnTopic24)) - assert.True(t, containsPeerID(peersOnTopic24, messenger4.ID())) - - _ = messenger1.Close() - _ = messenger2.Close() - _ = messenger3.Close() - _ = messenger4.Close() -} - -// ------- ConnectedFullHistoryPeersOnTopic - -func TestLibp2pMessenger_ConnectedFullHistoryPeersOnTopicShouldWork(t *testing.T) { - messenger1, messenger2, messenger3 := createMockNetworkOf3() - - adr2 := messenger2.Addresses()[0] - adr3 := messenger3.Addresses()[0] - fmt.Println("Connecting ...") - - _ = messenger1.ConnectToPeer(adr2) - _ = messenger3.ConnectToPeer(adr2) - _ = messenger1.ConnectToPeer(adr3) - // connected peers: 1 ----- 
2 - // | | - // 3 ------+ - - _ = messenger1.CreateTopic("topic123", false) - _ = messenger2.CreateTopic("topic123", false) - _ = messenger3.CreateTopic("topic123", false) - - // wait a bit for topic announcements - time.Sleep(time.Second) - - assert.Equal(t, 2, len(messenger1.ConnectedPeersOnTopic("topic123"))) - assert.Equal(t, 1, len(messenger1.ConnectedFullHistoryPeersOnTopic("topic123"))) - - assert.Equal(t, 2, len(messenger2.ConnectedPeersOnTopic("topic123"))) - assert.Equal(t, 1, len(messenger2.ConnectedFullHistoryPeersOnTopic("topic123"))) - - assert.Equal(t, 2, len(messenger3.ConnectedPeersOnTopic("topic123"))) - assert.Equal(t, 2, len(messenger3.ConnectedFullHistoryPeersOnTopic("topic123"))) - - _ = messenger1.Close() - _ = messenger2.Close() - _ = messenger3.Close() -} - -func TestLibp2pMessenger_ConnectedPeersShouldReturnUniquePeers(t *testing.T) { - pid1 := core.PeerID("pid1") - pid2 := core.PeerID("pid2") - pid3 := core.PeerID("pid3") - pid4 := core.PeerID("pid4") - - hs := &mock.ConnectableHostStub{ - NetworkCalled: func() network.Network { - return &mock.NetworkStub{ - ConnsCalled: func() []network.Conn { - // generate a mock list that contains duplicates - return []network.Conn{ - generateConnWithRemotePeer(pid1), - generateConnWithRemotePeer(pid1), - generateConnWithRemotePeer(pid2), - generateConnWithRemotePeer(pid1), - generateConnWithRemotePeer(pid4), - generateConnWithRemotePeer(pid3), - generateConnWithRemotePeer(pid1), - generateConnWithRemotePeer(pid3), - generateConnWithRemotePeer(pid4), - generateConnWithRemotePeer(pid2), - generateConnWithRemotePeer(pid1), - generateConnWithRemotePeer(pid1), - } - }, - ConnectednessCalled: func(id peer.ID) network.Connectedness { - return network.Connected - }, - } - }, - } - - netw := mocknet.New() - mes, _ := libp2p.NewMockMessenger(createMockNetworkArgs(), netw) - // we can safely close the host as the next operations will be done on a mock - _ = mes.Close() - - mes.SetHost(hs) - - peerList := mes.ConnectedPeers() - - assert.Equal(t, 4, len(peerList)) - assert.True(t, existInList(peerList, pid1)) - assert.True(t, existInList(peerList, pid2)) - assert.True(t, existInList(peerList, pid3)) - assert.True(t, existInList(peerList, pid4)) -} - -func existInList(list []core.PeerID, pid core.PeerID) bool { - for _, p := range list { - if bytes.Equal(p.Bytes(), pid.Bytes()) { - return true - } - } - - return false -} - -func generateConnWithRemotePeer(pid core.PeerID) network.Conn { - return &mock.ConnStub{ - RemotePeerCalled: func() peer.ID { - return peer.ID(pid) - }, - } -} - -func TestLibp2pMessenger_SendDirectWithMockNetToConnectedPeerShouldWork(t *testing.T) { - msg := []byte("test message") - - _, messenger1, messenger2 := createMockNetworkOf2() - - adr2 := messenger2.Addresses()[0] - - fmt.Printf("Connecting to %s...\n", adr2) - - _ = messenger1.ConnectToPeer(adr2) - - wg := &sync.WaitGroup{} - chanDone := make(chan bool) - wg.Add(1) - - go func() { - wg.Wait() - chanDone <- true - }() - - prepareMessengerForMatchDataReceive(messenger2, msg, wg) - - fmt.Println("Delaying as to allow peers to announce themselves on the opened topic...") - time.Sleep(time.Second) - - fmt.Printf("sending message from %s...\n", messenger1.ID().Pretty()) - - err := messenger1.SendToConnectedPeer("test", msg, messenger2.ID()) - - assert.Nil(t, err) - - waitDoneWithTimeout(t, chanDone, timeoutWaitResponses) - - _ = messenger1.Close() - _ = messenger2.Close() -} - -func TestLibp2pMessenger_SendDirectWithRealNetToConnectedPeerShouldWork(t *testing.T) { - 
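The uniqueness guarantee verified by TestLibp2pMessenger_ConnectedPeersShouldReturnUniquePeers above comes from collecting peers into a set before building the result slice; the core of that technique in isolation (a sketch, with plain strings standing in for peer IDs):

package main

import "fmt"

// uniquePeers drops duplicate entries while keeping first-seen order.
func uniquePeers(remotePeers []string) []string {
	seen := make(map[string]struct{})
	result := make([]string, 0, len(remotePeers))
	for _, p := range remotePeers {
		if _, ok := seen[p]; ok {
			continue // duplicate connection to the same peer
		}
		seen[p] = struct{}{}
		result = append(result, p)
	}
	return result
}

func main() {
	fmt.Println(uniquePeers([]string{"pid1", "pid1", "pid2", "pid1", "pid3"})) // [pid1 pid2 pid3]
}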
msg := []byte("test message") - - fmt.Println("Messenger 1:") - messenger1, _ := libp2p.NewNetworkMessenger(createMockNetworkArgs()) - - fmt.Println("Messenger 2:") - messenger2, _ := libp2p.NewNetworkMessenger(createMockNetworkArgs()) - - err := messenger1.ConnectToPeer(getConnectableAddress(messenger2)) - assert.Nil(t, err) - - wg := &sync.WaitGroup{} - chanDone := make(chan bool) - wg.Add(2) - - go func() { - wg.Wait() - chanDone <- true - }() - - prepareMessengerForMatchDataReceive(messenger1, msg, wg) - prepareMessengerForMatchDataReceive(messenger2, msg, wg) - - fmt.Println("Delaying as to allow peers to announce themselves on the opened topic...") - time.Sleep(time.Second) - - fmt.Printf("Messenger 1 is sending message from %s...\n", messenger1.ID().Pretty()) - err = messenger1.SendToConnectedPeer("test", msg, messenger2.ID()) - assert.Nil(t, err) - - time.Sleep(time.Second) - fmt.Printf("Messenger 2 is sending message from %s...\n", messenger2.ID().Pretty()) - err = messenger2.SendToConnectedPeer("test", msg, messenger1.ID()) - assert.Nil(t, err) - - waitDoneWithTimeout(t, chanDone, timeoutWaitResponses) - - _ = messenger1.Close() - _ = messenger2.Close() -} - -func TestLibp2pMessenger_SendDirectWithRealNetToSelfShouldWork(t *testing.T) { - msg := []byte("test message") - - fmt.Println("Messenger 1:") - mes, _ := libp2p.NewNetworkMessenger(createMockNetworkArgs()) - - wg := &sync.WaitGroup{} - chanDone := make(chan bool) - wg.Add(1) - - go func() { - wg.Wait() - chanDone <- true - }() - - prepareMessengerForMatchDataReceive(mes, msg, wg) - - fmt.Printf("Messenger 1 is sending message from %s to self...\n", mes.ID().Pretty()) - err := mes.SendToConnectedPeer("test", msg, mes.ID()) - assert.Nil(t, err) - - time.Sleep(time.Second) - - waitDoneWithTimeout(t, chanDone, timeoutWaitResponses) - - _ = mes.Close() -} - -// ------- Bootstrap - -func TestNetworkMessenger_BootstrapPeerDiscoveryShouldCallPeerBootstrapper(t *testing.T) { - wasCalled := false - - netw := mocknet.New() - pdm := &mock.PeerDiscovererStub{ - BootstrapCalled: func() error { - wasCalled = true - return nil - }, - CloseCalled: func() error { - return nil - }, - } - mes, _ := libp2p.NewMockMessenger(createMockNetworkArgs(), netw) - mes.SetPeerDiscoverer(pdm) - - _ = mes.Bootstrap() - - assert.True(t, wasCalled) - - _ = mes.Close() -} - -// ------- SetThresholdMinConnectedPeers - -func TestNetworkMessenger_SetThresholdMinConnectedPeersInvalidValueShouldErr(t *testing.T) { - messenger := createMockMessenger() - defer func() { - _ = messenger.Close() - }() - - err := messenger.SetThresholdMinConnectedPeers(-1) - - assert.Equal(t, p2p.ErrInvalidValue, err) -} - -func TestNetworkMessenger_SetThresholdMinConnectedPeersShouldWork(t *testing.T) { - messenger := createMockMessenger() - defer func() { - _ = messenger.Close() - }() - - minConnectedPeers := 56 - err := messenger.SetThresholdMinConnectedPeers(minConnectedPeers) - - assert.Nil(t, err) - assert.Equal(t, minConnectedPeers, messenger.ThresholdMinConnectedPeers()) -} - -// ------- IsConnectedToTheNetwork - -func TestNetworkMessenger_IsConnectedToTheNetworkRetFalse(t *testing.T) { - messenger := createMockMessenger() - defer func() { - _ = messenger.Close() - }() - - minConnectedPeers := 56 - _ = messenger.SetThresholdMinConnectedPeers(minConnectedPeers) - - assert.False(t, messenger.IsConnectedToTheNetwork()) -} - -func TestNetworkMessenger_IsConnectedToTheNetworkWithZeroRetTrue(t *testing.T) { - messenger := createMockMessenger() - defer func() { - _ = messenger.Close() 
- }() - - minConnectedPeers := 0 - _ = messenger.SetThresholdMinConnectedPeers(minConnectedPeers) - - assert.True(t, messenger.IsConnectedToTheNetwork()) -} - -// ------- SetPeerShardResolver - -func TestNetworkMessenger_SetPeerShardResolverNilShouldErr(t *testing.T) { - messenger := createMockMessenger() - defer func() { - _ = messenger.Close() - }() - - err := messenger.SetPeerShardResolver(nil) - - assert.Equal(t, p2p.ErrNilPeerShardResolver, err) -} - -func TestNetworkMessenger_SetPeerShardResolver(t *testing.T) { - messenger := createMockMessenger() - defer func() { - _ = messenger.Close() - }() - - err := messenger.SetPeerShardResolver(&mock.PeerShardResolverStub{}) - - assert.Nil(t, err) -} - -func TestNetworkMessenger_DoubleCloseShouldWork(t *testing.T) { - messenger := createMessenger() - - time.Sleep(time.Second) - - err := messenger.Close() - assert.Nil(t, err) - - err = messenger.Close() - assert.Nil(t, err) -} - -func TestNetworkMessenger_PreventReprocessingShouldWork(t *testing.T) { - args := libp2p.ArgsNetworkMessenger{ - ListenAddress: libp2p.ListenLocalhostAddrWithIp4AndTcp, - Marshalizer: &testscommon.ProtoMarshalizerMock{}, - P2pConfig: config.P2PConfig{ - Node: config.NodeConfig{ - Port: "0", - }, - KadDhtPeerDiscovery: config.KadDhtPeerDiscoveryConfig{ - Enabled: false, - }, - Sharding: config.ShardingConfig{ - Type: p2p.NilListSharder, - }, - }, - SyncTimer: &libp2p.LocalSyncTimer{}, - PreferredPeersHolder: &p2pmocks.PeersHolderStub{}, - PeersRatingHandler: &p2pmocks.PeersRatingHandlerStub{}, - ConnectionWatcherType: p2p.ConnectionWatcherTypePrint, - } - - mes, _ := libp2p.NewNetworkMessenger(args) - - numCalled := uint32(0) - handler := &mock.MessageProcessorStub{ - ProcessMessageCalled: func(message p2p.MessageP2P, fromConnectedPeer core.PeerID) error { - atomic.AddUint32(&numCalled, 1) - return nil - }, - } - - callBackFunc := mes.PubsubCallback(handler, "") - ctx := context.Background() - pid := peer.ID(mes.ID()) - timeStamp := time.Now().Unix() - 1 - timeStamp -= int64(libp2p.AcceptMessagesInAdvanceDuration.Seconds()) - timeStamp -= int64(libp2p.PubsubTimeCacheDuration.Seconds()) - - innerMessage := &data.TopicMessage{ - Payload: []byte("data"), - Timestamp: timeStamp, - } - buff, _ := args.Marshalizer.Marshal(innerMessage) - msg := &pubsub.Message{ - Message: &pb.Message{ - From: []byte(pid), - Data: buff, - Seqno: []byte{0, 0, 0, 1}, - Topic: nil, - Signature: nil, - Key: nil, - XXX_NoUnkeyedLiteral: struct{}{}, - XXX_unrecognized: nil, - XXX_sizecache: 0, - }, - ReceivedFrom: "", - ValidatorData: nil, - } - - assert.False(t, callBackFunc(ctx, pid, msg)) // this will not call - assert.False(t, callBackFunc(ctx, pid, msg)) // this will not call - assert.Equal(t, uint32(0), atomic.LoadUint32(&numCalled)) - - _ = mes.Close() -} - -func TestNetworkMessenger_PubsubCallbackNotMessageNotValidShouldNotCallHandler(t *testing.T) { - args := libp2p.ArgsNetworkMessenger{ - Marshalizer: &testscommon.ProtoMarshalizerMock{}, - ListenAddress: libp2p.ListenLocalhostAddrWithIp4AndTcp, - P2pConfig: config.P2PConfig{ - Node: config.NodeConfig{ - Port: "0", - }, - KadDhtPeerDiscovery: config.KadDhtPeerDiscoveryConfig{ - Enabled: false, - }, - Sharding: config.ShardingConfig{ - Type: p2p.NilListSharder, - }, - }, - SyncTimer: &libp2p.LocalSyncTimer{}, - PreferredPeersHolder: &p2pmocks.PeersHolderStub{}, - PeersRatingHandler: &p2pmocks.PeersRatingHandlerStub{}, - ConnectionWatcherType: p2p.ConnectionWatcherTypePrint, - } - - mes, _ := libp2p.NewNetworkMessenger(args) - numUpserts := 
int32(0) - _ = mes.SetPeerDenialEvaluator(&mock.PeerDenialEvaluatorStub{ - UpsertPeerIDCalled: func(pid core.PeerID, duration time.Duration) error { - atomic.AddInt32(&numUpserts, 1) - // any error thrown here should not impact the execution - return fmt.Errorf("expected error") - }, - IsDeniedCalled: func(pid core.PeerID) bool { - return false - }, - }) - - numCalled := uint32(0) - handler := &mock.MessageProcessorStub{ - ProcessMessageCalled: func(message p2p.MessageP2P, fromConnectedPeer core.PeerID) error { - atomic.AddUint32(&numCalled, 1) - return nil - }, - } - - callBackFunc := mes.PubsubCallback(handler, "") - ctx := context.Background() - pid := peer.ID(mes.ID()) - innerMessage := &data.TopicMessage{ - Payload: []byte("data"), - Timestamp: time.Now().Unix(), - } - buff, _ := args.Marshalizer.Marshal(innerMessage) - msg := &pubsub.Message{ - Message: &pb.Message{ - From: []byte("not a valid pid"), - Data: buff, - Seqno: []byte{0, 0, 0, 1}, - Topic: nil, - Signature: nil, - Key: nil, - XXX_NoUnkeyedLiteral: struct{}{}, - XXX_unrecognized: nil, - XXX_sizecache: 0, - }, - ReceivedFrom: "", - ValidatorData: nil, - } - - assert.False(t, callBackFunc(ctx, pid, msg)) - assert.Equal(t, uint32(0), atomic.LoadUint32(&numCalled)) - assert.Equal(t, int32(2), atomic.LoadInt32(&numUpserts)) - - _ = mes.Close() -} - -func TestNetworkMessenger_PubsubCallbackReturnsFalseIfHandlerErrors(t *testing.T) { - args := libp2p.ArgsNetworkMessenger{ - Marshalizer: &testscommon.ProtoMarshalizerMock{}, - ListenAddress: libp2p.ListenLocalhostAddrWithIp4AndTcp, - P2pConfig: config.P2PConfig{ - Node: config.NodeConfig{ - Port: "0", - }, - KadDhtPeerDiscovery: config.KadDhtPeerDiscoveryConfig{ - Enabled: false, - }, - Sharding: config.ShardingConfig{ - Type: p2p.NilListSharder, - }, - }, - SyncTimer: &libp2p.LocalSyncTimer{}, - PreferredPeersHolder: &p2pmocks.PeersHolderStub{}, - PeersRatingHandler: &p2pmocks.PeersRatingHandlerStub{}, - ConnectionWatcherType: p2p.ConnectionWatcherTypePrint, - } - - mes, _ := libp2p.NewNetworkMessenger(args) - - numCalled := uint32(0) - expectedErr := errors.New("expected error") - handler := &mock.MessageProcessorStub{ - ProcessMessageCalled: func(message p2p.MessageP2P, fromConnectedPeer core.PeerID) error { - atomic.AddUint32(&numCalled, 1) - return expectedErr - }, - } - - callBackFunc := mes.PubsubCallback(handler, "") - ctx := context.Background() - pid := peer.ID(mes.ID()) - innerMessage := &data.TopicMessage{ - Payload: []byte("data"), - Timestamp: time.Now().Unix(), - Version: libp2p.CurrentTopicMessageVersion, - } - buff, _ := args.Marshalizer.Marshal(innerMessage) - topic := "topic" - msg := &pubsub.Message{ - Message: &pb.Message{ - From: []byte(mes.ID()), - Data: buff, - Seqno: []byte{0, 0, 0, 1}, - Topic: &topic, - Signature: nil, - Key: nil, - XXX_NoUnkeyedLiteral: struct{}{}, - XXX_unrecognized: nil, - XXX_sizecache: 0, - }, - ReceivedFrom: "", - ValidatorData: nil, - } - - assert.False(t, callBackFunc(ctx, pid, msg)) - assert.Equal(t, uint32(1), atomic.LoadUint32(&numCalled)) - - _ = mes.Close() -} - -func TestNetworkMessenger_UnjoinAllTopicsShouldWork(t *testing.T) { - args := libp2p.ArgsNetworkMessenger{ - Marshalizer: &testscommon.ProtoMarshalizerMock{}, - ListenAddress: libp2p.ListenLocalhostAddrWithIp4AndTcp, - P2pConfig: config.P2PConfig{ - Node: config.NodeConfig{ - Port: "0", - }, - KadDhtPeerDiscovery: config.KadDhtPeerDiscoveryConfig{ - Enabled: false, - }, - Sharding: config.ShardingConfig{ - Type: p2p.NilListSharder, - }, - }, - SyncTimer: 
&libp2p.LocalSyncTimer{}, - PreferredPeersHolder: &p2pmocks.PeersHolderStub{}, - PeersRatingHandler: &p2pmocks.PeersRatingHandlerStub{}, - ConnectionWatcherType: p2p.ConnectionWatcherTypePrint, - } - - mes, _ := libp2p.NewNetworkMessenger(args) - - topic := "topic" - _ = mes.CreateTopic(topic, true) - assert.True(t, mes.HasTopic(topic)) - - err := mes.UnjoinAllTopics() - assert.Nil(t, err) - - assert.False(t, mes.HasTopic(topic)) -} - -func TestNetworkMessenger_ValidMessageByTimestampMessageTooOld(t *testing.T) { - args := createMockNetworkArgs() - now := time.Now() - args.SyncTimer = &mock.SyncTimerStub{ - CurrentTimeCalled: func() time.Time { - return now - }, - } - mes, _ := libp2p.NewNetworkMessenger(args) - - msg := &message.Message{ - TimestampField: now.Unix() - int64(libp2p.PubsubTimeCacheDuration.Seconds()) - 1, - } - err := mes.ValidMessageByTimestamp(msg) - - assert.True(t, errors.Is(err, p2p.ErrMessageTooOld)) -} - -func TestNetworkMessenger_ValidMessageByTimestampMessageAtLowerLimitShouldWork(t *testing.T) { - mes, _ := libp2p.NewNetworkMessenger(createMockNetworkArgs()) - - now := time.Now() - msg := &message.Message{ - TimestampField: now.Unix() - int64(libp2p.PubsubTimeCacheDuration.Seconds()) + int64(libp2p.AcceptMessagesInAdvanceDuration.Seconds()), - } - err := mes.ValidMessageByTimestamp(msg) - - assert.Nil(t, err) -} - -func TestNetworkMessenger_ValidMessageByTimestampMessageTooNew(t *testing.T) { - args := createMockNetworkArgs() - now := time.Now() - args.SyncTimer = &mock.SyncTimerStub{ - CurrentTimeCalled: func() time.Time { - return now - }, - } - mes, _ := libp2p.NewNetworkMessenger(args) - - msg := &message.Message{ - TimestampField: now.Unix() + int64(libp2p.AcceptMessagesInAdvanceDuration.Seconds()) + 1, - } - err := mes.ValidMessageByTimestamp(msg) - - assert.True(t, errors.Is(err, p2p.ErrMessageTooNew)) -} - -func TestNetworkMessenger_ValidMessageByTimestampMessageAtUpperLimitShouldWork(t *testing.T) { - args := createMockNetworkArgs() - now := time.Now() - args.SyncTimer = &mock.SyncTimerStub{ - CurrentTimeCalled: func() time.Time { - return now - }, - } - mes, _ := libp2p.NewNetworkMessenger(args) - - msg := &message.Message{ - TimestampField: now.Unix() + int64(libp2p.AcceptMessagesInAdvanceDuration.Seconds()), - } - err := mes.ValidMessageByTimestamp(msg) - - assert.Nil(t, err) -} - -func TestNetworkMessenger_GetConnectedPeersInfo(t *testing.T) { - netw := mocknet.New() - - peers := []peer.ID{ - "valI1", - "valC1", - "valC2", - "obsI1", - "obsI2", - "obsI3", - "obsC1", - "obsC2", - "obsC3", - "obsC4", - "unknown", - } - mes, _ := libp2p.NewMockMessenger(createMockNetworkArgs(), netw) - mes.SetHost(&mock.ConnectableHostStub{ - NetworkCalled: func() network.Network { - return &mock.NetworkStub{ - PeersCall: func() []peer.ID { - return peers - }, - ConnsToPeerCalled: func(p peer.ID) []network.Conn { - return make([]network.Conn, 0) - }, - } - }, - }) - selfShardID := uint32(0) - crossShardID := uint32(1) - _ = mes.SetPeerShardResolver(&mock.PeerShardResolverStub{ - GetPeerInfoCalled: func(pid core.PeerID) core.P2PPeerInfo { - pinfo := core.P2PPeerInfo{ - PeerType: core.UnknownPeer, - } - if pid.Pretty() == mes.ID().Pretty() { - pinfo.ShardID = selfShardID - pinfo.PeerType = core.ObserverPeer - return pinfo - } - - strPid := string(pid) - if strings.Contains(strPid, "I") { - pinfo.ShardID = selfShardID - } - if strings.Contains(strPid, "C") { - pinfo.ShardID = crossShardID - } - - if strings.Contains(strPid, "val") { - pinfo.PeerType = core.ValidatorPeer - } 
- - if strings.Contains(strPid, "obs") { - pinfo.PeerType = core.ObserverPeer - } - - return pinfo - }, - }) - - cpi := mes.GetConnectedPeersInfo() - - assert.Equal(t, 4, cpi.NumCrossShardObservers) - assert.Equal(t, 2, cpi.NumCrossShardValidators) - assert.Equal(t, 3, cpi.NumIntraShardObservers) - assert.Equal(t, 1, cpi.NumIntraShardValidators) - assert.Equal(t, 3, cpi.NumObserversOnShard[selfShardID]) - assert.Equal(t, 4, cpi.NumObserversOnShard[crossShardID]) - assert.Equal(t, 1, cpi.NumValidatorsOnShard[selfShardID]) - assert.Equal(t, 2, cpi.NumValidatorsOnShard[crossShardID]) - assert.Equal(t, selfShardID, cpi.SelfShardID) - assert.Equal(t, 1, len(cpi.UnknownPeers)) -} - -func TestNetworkMessenger_mapHistogram(t *testing.T) { - t.Parallel() - - args := createMockNetworkArgs() - netMes, _ := libp2p.NewNetworkMessenger(args) - - inp := map[uint32]int{ - 0: 5, - 1: 7, - 2: 9, - core.MetachainShardId: 11, - } - output := `shard 0: 5, shard 1: 7, shard 2: 9, meta: 11` - - require.Equal(t, output, netMes.MapHistogram(inp)) -} - -func TestNetworkMessenger_Bootstrap(t *testing.T) { - t.Skip("long test used to debug goroutines closing on the netMessenger") - - t.Parallel() - - _ = logger.SetLogLevel("*:DEBUG") - log := logger.GetOrCreate("internal tests") - - args := libp2p.ArgsNetworkMessenger{ - ListenAddress: libp2p.ListenLocalhostAddrWithIp4AndTcp, - Marshalizer: &marshal.GogoProtoMarshalizer{}, - P2pConfig: config.P2PConfig{ - Node: config.NodeConfig{ - Port: "0", - Seed: "", - MaximumExpectedPeerCount: 1, - ThresholdMinConnectedPeers: 1, - }, - KadDhtPeerDiscovery: config.KadDhtPeerDiscoveryConfig{ - Enabled: true, - Type: "optimized", - RefreshIntervalInSec: 10, - ProtocolID: "erd/kad/1.0.0", - InitialPeerList: []string{"/ip4/35.214.140.83/tcp/10000/p2p/16Uiu2HAm6hPymvkZyFgbvWaVBKhEoPjmXhkV32r9JaFvQ7Rk8ynU"}, - BucketSize: 10, - RoutingTableRefreshIntervalInSec: 5, - }, - Sharding: config.ShardingConfig{ - TargetPeerCount: 0, - MaxIntraShardValidators: 0, - MaxCrossShardValidators: 0, - MaxIntraShardObservers: 0, - MaxCrossShardObservers: 0, - MaxSeeders: 0, - Type: "NilListSharder", - }, - }, - SyncTimer: &mock.SyncTimerStub{}, - PeersRatingHandler: &p2pmocks.PeersRatingHandlerStub{}, - PreferredPeersHolder: &p2pmocks.PeersHolderStub{}, - ConnectionWatcherType: p2p.ConnectionWatcherTypePrint, - } - - netMes, err := libp2p.NewNetworkMessenger(args) - require.Nil(t, err) - - go func() { - time.Sleep(time.Second * 1) - goRoutinesNumberStart := runtime.NumGoroutine() - log.Info("before closing", "num goroutines", goRoutinesNumberStart) - - _ = netMes.Close() - }() - - _ = netMes.Bootstrap() - - time.Sleep(time.Second * 5) - - goRoutinesNumberStart := runtime.NumGoroutine() - core.DumpGoRoutinesToLog(goRoutinesNumberStart, log) -} - -func TestNetworkMessenger_WaitForConnections(t *testing.T) { - t.Parallel() - - t.Run("min num of peers is 0", func(t *testing.T) { - t.Parallel() - - startTime := time.Now() - _, mes1, mes2 := createMockNetworkOf2() - _ = mes1.ConnectToPeer(mes2.Addresses()[0]) - - defer func() { - _ = mes1.Close() - _ = mes2.Close() - }() - - timeToWait := time.Second * 3 - mes1.WaitForConnections(timeToWait, 0) - - assert.True(t, timeToWait <= time.Since(startTime)) - }) -
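The WaitForConnections subtests suggest a polling loop: check the connected-peer count at a fixed interval and return once the minimum is reached or the timeout expires, except that a minimum of 0 simply waits out the full duration (which is exactly what the first subtest asserts). A self-contained sketch of that loop; pollInterval mirrors the exported libp2p.PollWaitForConnectionsInterval, but the body is an assumption, not the production code:

package main

import (
	"fmt"
	"sync/atomic"
	"time"
)

const pollInterval = 50 * time.Millisecond // illustrative stand-in for PollWaitForConnectionsInterval

// waitForConnections blocks until connectedPeers() reaches minNumOfPeers or timeout elapses.
func waitForConnections(timeout time.Duration, minNumOfPeers int, connectedPeers func() int) {
	if minNumOfPeers == 0 {
		time.Sleep(timeout) // nothing to check: wait out the whole duration
		return
	}
	deadline := time.Now().Add(timeout)
	for time.Now().Before(deadline) {
		if connectedPeers() >= minNumOfPeers {
			return
		}
		time.Sleep(pollInterval)
	}
}

func main() {
	var peers int32
	go func() {
		time.Sleep(200 * time.Millisecond)
		atomic.StoreInt32(&peers, 2) // the second peer shows up while we are polling
	}()
	start := time.Now()
	waitForConnections(time.Second, 2, func() int { return int(atomic.LoadInt32(&peers)) })
	fmt.Println("returned after ~", time.Since(start).Round(pollInterval))
}

t.Run("min num of peers is 2", func(t *testing.T) { - t.Parallel() - - startTime := time.Now() - netw, mes1, mes2 := createMockNetworkOf2() - mes3, _ := libp2p.NewMockMessenger(createMockNetworkArgs(), netw) - _ = netw.LinkAll() - - _ = mes1.ConnectToPeer(mes2.Addresses()[0]) - go func() { - 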
time.Sleep(time.Second * 2) - _ = mes1.ConnectToPeer(mes3.Addresses()[0]) - }() - - defer func() { - _ = mes1.Close() - _ = mes2.Close() - _ = mes3.Close() - }() - - timeToWait := time.Second * 10 - mes1.WaitForConnections(timeToWait, 2) - - assert.True(t, timeToWait > time.Since(startTime)) - assert.True(t, libp2p.PollWaitForConnectionsInterval <= time.Since(startTime)) - }) - t.Run("min num of peers is 2 but we only connected to 1 peer", func(t *testing.T) { - t.Parallel() - - startTime := time.Now() - _, mes1, mes2 := createMockNetworkOf2() - - _ = mes1.ConnectToPeer(mes2.Addresses()[0]) - - defer func() { - _ = mes1.Close() - _ = mes2.Close() - }() - - timeToWait := time.Second * 10 - mes1.WaitForConnections(timeToWait, 2) - - assert.True(t, timeToWait < time.Since(startTime)) - }) -} - -func TestLibp2pMessenger_SignVerifyPayloadShouldWork(t *testing.T) { - fmt.Println("Messenger 1:") - messenger1, _ := libp2p.NewNetworkMessenger(createMockNetworkArgs()) - - fmt.Println("Messenger 2:") - messenger2, _ := libp2p.NewNetworkMessenger(createMockNetworkArgs()) - - err := messenger1.ConnectToPeer(getConnectableAddress(messenger2)) - assert.Nil(t, err) - - defer func() { - _ = messenger1.Close() - _ = messenger2.Close() - }() - - payload := []byte("payload") - sig, err := messenger1.Sign(payload) - assert.Nil(t, err) - - err = messenger2.Verify(payload, messenger1.ID(), sig) - assert.Nil(t, err) - - err = messenger1.Verify(payload, messenger1.ID(), sig) - assert.Nil(t, err) -} - -func TestLibp2pMessenger_ConnectionTopic(t *testing.T) { - t.Parallel() - - t.Run("create topic should work", func(t *testing.T) { - t.Parallel() - - netMes, _ := libp2p.NewNetworkMessenger(createMockNetworkArgs()) - - topic := common.ConnectionTopic - err := netMes.CreateTopic(topic, true) - assert.Nil(t, err) - assert.False(t, netMes.HasTopic(topic)) - assert.False(t, netMes.PubsubHasTopic(topic)) - - testTopic := "test topic" - err = netMes.CreateTopic(testTopic, true) - assert.Nil(t, err) - assert.True(t, netMes.HasTopic(testTopic)) - assert.True(t, netMes.PubsubHasTopic(testTopic)) - - err = netMes.UnjoinAllTopics() - assert.Nil(t, err) - assert.False(t, netMes.HasTopic(topic)) - assert.False(t, netMes.PubsubHasTopic(topic)) - assert.False(t, netMes.HasTopic(testTopic)) - assert.False(t, netMes.PubsubHasTopic(testTopic)) - - _ = netMes.Close() - }) - t.Run("register-unregister message processor should work", func(t *testing.T) { - t.Parallel() - - netMes, _ := libp2p.NewNetworkMessenger(createMockNetworkArgs()) - - identifier := "identifier" - topic := common.ConnectionTopic - err := netMes.RegisterMessageProcessor(topic, identifier, &mock.MessageProcessorStub{}) - assert.Nil(t, err) - assert.True(t, netMes.HasProcessorForTopic(topic)) - - err = netMes.UnregisterMessageProcessor(topic, identifier) - assert.Nil(t, err) - assert.False(t, netMes.HasProcessorForTopic(topic)) - - _ = netMes.Close() - }) - t.Run("unregister all processors should work", func(t *testing.T) { - t.Parallel() - - netMes, _ := libp2p.NewNetworkMessenger(createMockNetworkArgs()) - - topic := common.ConnectionTopic - err := netMes.RegisterMessageProcessor(topic, "identifier", &mock.MessageProcessorStub{}) - assert.Nil(t, err) - assert.True(t, netMes.HasProcessorForTopic(topic)) - - testTopic := "test topic" - err = netMes.RegisterMessageProcessor(testTopic, "identifier", &mock.MessageProcessorStub{}) - assert.Nil(t, err) - assert.True(t, netMes.HasProcessorForTopic(testTopic)) - - err = netMes.UnregisterAllMessageProcessors() - assert.Nil(t, 
err) - assert.False(t, netMes.HasProcessorForTopic(topic)) - assert.False(t, netMes.HasProcessorForTopic(testTopic)) - - _ = netMes.Close() - }) - t.Run("unregister all processors should work", func(t *testing.T) { - t.Parallel() - - netMes, _ := libp2p.NewNetworkMessenger(createMockNetworkArgs()) - - topic := common.ConnectionTopic - err := netMes.RegisterMessageProcessor(topic, "identifier", &mock.MessageProcessorStub{}) - assert.Nil(t, err) - assert.True(t, netMes.HasProcessorForTopic(topic)) - - testTopic := "test topic" - err = netMes.RegisterMessageProcessor(testTopic, "identifier", &mock.MessageProcessorStub{}) - assert.Nil(t, err) - assert.True(t, netMes.HasProcessorForTopic(testTopic)) - - err = netMes.UnregisterAllMessageProcessors() - assert.Nil(t, err) - assert.False(t, netMes.HasProcessorForTopic(topic)) - assert.False(t, netMes.HasProcessorForTopic(testTopic)) - - _ = netMes.Close() - }) -} - -func TestNetworkMessenger_BroadcastUsingPrivateKey(t *testing.T) { - if testing.Short() { - t.Skip("this is not a short test") - } - - msg := []byte("test message") - topic := "topic" - - interceptors := make([]*mock.MessageProcessorMock, 2) - - fmt.Println("Messenger 1:") - mes1, _ := libp2p.NewNetworkMessenger(createMockNetworkArgs()) - _ = mes1.CreateTopic(topic, true) - interceptors[0] = mock.NewMessageProcessorMock() - _ = mes1.RegisterMessageProcessor(topic, "", interceptors[0]) - - fmt.Println("Messenger 2:") - mes2, _ := libp2p.NewNetworkMessenger(createMockNetworkArgs()) - _ = mes2.CreateTopic(topic, true) - interceptors[1] = mock.NewMessageProcessorMock() - _ = mes2.RegisterMessageProcessor(topic, "", interceptors[1]) - - err := mes1.ConnectToPeer(getConnectableAddress(mes2)) - assert.Nil(t, err) - - time.Sleep(time.Second * 2) - - keyGen := crypto.NewIdentityGenerator() - skBuff, pid, err := keyGen.CreateRandomP2PIdentity() - assert.Nil(t, err) - fmt.Printf("new identity: %s\n", pid.Pretty()) - - mes1.BroadcastUsingPrivateKey(topic, msg, pid, skBuff) - - time.Sleep(time.Second * 2) - - for _, i := range interceptors { - messages := i.GetMessages() - - assert.Equal(t, 1, len(messages)) - assert.Equal(t, 1, messages[pid]) - assert.Equal(t, 0, messages[mes1.ID()]) - assert.Equal(t, 0, messages[mes2.ID()]) - } -} diff --git a/p2p/libp2p/networksharding/factory/sharderFactory.go b/p2p/libp2p/networksharding/factory/sharderFactory.go deleted file mode 100644 index ac9de8324df..00000000000 --- a/p2p/libp2p/networksharding/factory/sharderFactory.go +++ /dev/null @@ -1,79 +0,0 @@ -package factory - -import ( - "fmt" - - logger "github.com/ElrondNetwork/elrond-go-logger" - "github.com/ElrondNetwork/elrond-go/config" - "github.com/ElrondNetwork/elrond-go/p2p" - "github.com/ElrondNetwork/elrond-go/p2p/libp2p/networksharding" - "github.com/libp2p/go-libp2p-core/peer" -) - -var log = logger.GetOrCreate("p2p/networksharding/factory") - -// ArgsSharderFactory represents the argument for the sharder factory -type ArgsSharderFactory struct { - PeerShardResolver p2p.PeerShardResolver - Pid peer.ID - P2pConfig config.P2PConfig - PreferredPeersHolder p2p.PreferredPeersHolderHandler - NodeOperationMode p2p.NodeOperation -} - -// NewSharder creates new Sharder instances -func NewSharder(arg ArgsSharderFactory) (p2p.Sharder, error) { - shardingType := arg.P2pConfig.Sharding.Type - switch shardingType { - case p2p.ListsSharder: - return listSharder(arg) - case p2p.OneListSharder: - return oneListSharder(arg) - case p2p.NilListSharder: - return nilListSharder() - default: - return nil, fmt.Errorf("%w 
when selecting sharder: unknown %s value", p2p.ErrInvalidValue, shardingType) - } -} - -func listSharder(arg ArgsSharderFactory) (p2p.Sharder, error) { - switch arg.NodeOperationMode { - case p2p.NormalOperation, p2p.FullArchiveMode: - default: - return nil, fmt.Errorf("%w unknown node operation mode %s", p2p.ErrInvalidValue, arg.NodeOperationMode) - } - - log.Debug("using lists sharder", - "MaxConnectionCount", arg.P2pConfig.Sharding.TargetPeerCount, - "MaxIntraShardValidators", arg.P2pConfig.Sharding.MaxIntraShardValidators, - "MaxCrossShardValidators", arg.P2pConfig.Sharding.MaxCrossShardValidators, - "MaxIntraShardObservers", arg.P2pConfig.Sharding.MaxIntraShardObservers, - "MaxCrossShardObservers", arg.P2pConfig.Sharding.MaxCrossShardObservers, - "MaxFullHistoryObservers", arg.P2pConfig.Sharding.AdditionalConnections.MaxFullHistoryObservers, - "MaxSeeders", arg.P2pConfig.Sharding.MaxSeeders, - "node operation", arg.NodeOperationMode, - ) - argListsSharder := networksharding.ArgListsSharder{ - PeerResolver: arg.PeerShardResolver, - SelfPeerId: arg.Pid, - P2pConfig: arg.P2pConfig, - PreferredPeersHolder: arg.PreferredPeersHolder, - NodeOperationMode: arg.NodeOperationMode, - } - return networksharding.NewListsSharder(argListsSharder) -} - -func oneListSharder(arg ArgsSharderFactory) (p2p.Sharder, error) { - log.Debug("using one list sharder", - "MaxConnectionCount", arg.P2pConfig.Sharding.TargetPeerCount, - ) - return networksharding.NewOneListSharder( - arg.Pid, - int(arg.P2pConfig.Sharding.TargetPeerCount), - ) -} - -func nilListSharder() (p2p.Sharder, error) { - log.Debug("using nil list sharder") - return networksharding.NewNilListSharder(), nil -} diff --git a/p2p/libp2p/networksharding/factory/sharderFactory_test.go b/p2p/libp2p/networksharding/factory/sharderFactory_test.go deleted file mode 100644 index cdf9286d5de..00000000000 --- a/p2p/libp2p/networksharding/factory/sharderFactory_test.go +++ /dev/null @@ -1,116 +0,0 @@ -package factory - -import ( - "errors" - "reflect" - "strings" - "testing" - - "github.com/ElrondNetwork/elrond-go-core/core/check" - "github.com/ElrondNetwork/elrond-go/config" - "github.com/ElrondNetwork/elrond-go/p2p" - "github.com/ElrondNetwork/elrond-go/p2p/libp2p/networksharding" - "github.com/ElrondNetwork/elrond-go/p2p/mock" - "github.com/ElrondNetwork/elrond-go/testscommon/p2pmocks" - "github.com/stretchr/testify/assert" -) - -func createMockArg() ArgsSharderFactory { - return ArgsSharderFactory{ - - PeerShardResolver: &mock.PeerShardResolverStub{}, - Pid: "", - PreferredPeersHolder: &p2pmocks.PeersHolderStub{}, - P2pConfig: config.P2PConfig{ - Sharding: config.ShardingConfig{ - Type: "unknown", - TargetPeerCount: 6, - MaxIntraShardValidators: 1, - MaxCrossShardValidators: 1, - MaxIntraShardObservers: 1, - MaxCrossShardObservers: 1, - AdditionalConnections: config.AdditionalConnectionsConfig{ - MaxFullHistoryObservers: 1, - }, - }, - }, - NodeOperationMode: p2p.NormalOperation, - } -} - -func TestNewSharder_CreateListsSharderUnknownNodeOperationShouldError(t *testing.T) { - t.Parallel() - - arg := createMockArg() - arg.P2pConfig.Sharding.Type = p2p.ListsSharder - arg.NodeOperationMode = "" - sharder, err := NewSharder(arg) - - assert.True(t, errors.Is(err, p2p.ErrInvalidValue)) - assert.True(t, strings.Contains(err.Error(), "unknown node operation mode")) - assert.True(t, check.IfNil(sharder)) -} - -func TestNewSharder_CreateListsSharderShouldWork(t *testing.T) { - t.Parallel() - - arg := createMockArg() - arg.P2pConfig.Sharding.Type = 
p2p.ListsSharder - sharder, err := NewSharder(arg) - maxPeerCount := uint32(5) - maxValidators := uint32(1) - maxObservers := uint32(1) - - argListsSharder := networksharding.ArgListsSharder{ - PeerResolver: &mock.PeerShardResolverStub{}, - SelfPeerId: "", - P2pConfig: config.P2PConfig{ - Sharding: config.ShardingConfig{ - TargetPeerCount: maxPeerCount, - MaxIntraShardObservers: maxObservers, - MaxIntraShardValidators: maxValidators, - MaxCrossShardObservers: maxObservers, - MaxCrossShardValidators: maxValidators, - MaxSeeders: 0, - }, - }, - } - expectedSharder, _ := networksharding.NewListsSharder(argListsSharder) - assert.Nil(t, err) - assert.IsType(t, reflect.TypeOf(expectedSharder), reflect.TypeOf(sharder)) -} - -func TestNewSharder_CreateOneListSharderShouldWork(t *testing.T) { - t.Parallel() - - arg := createMockArg() - arg.P2pConfig.Sharding.Type = p2p.OneListSharder - sharder, err := NewSharder(arg) - maxPeerCount := 2 - - expectedSharder, _ := networksharding.NewOneListSharder("", maxPeerCount) - assert.Nil(t, err) - assert.IsType(t, reflect.TypeOf(expectedSharder), reflect.TypeOf(sharder)) -} - -func TestNewSharder_CreateNilListSharderShouldWork(t *testing.T) { - t.Parallel() - - arg := createMockArg() - arg.P2pConfig.Sharding.Type = p2p.NilListSharder - sharder, err := NewSharder(arg) - - expectedSharder := networksharding.NewNilListSharder() - assert.Nil(t, err) - assert.IsType(t, reflect.TypeOf(expectedSharder), reflect.TypeOf(sharder)) -} - -func TestNewSharder_CreateWithUnknownVariantShouldErr(t *testing.T) { - t.Parallel() - - arg := createMockArg() - sharder, err := NewSharder(arg) - - assert.True(t, errors.Is(err, p2p.ErrInvalidValue)) - assert.True(t, check.IfNil(sharder)) -} diff --git a/p2p/libp2p/networksharding/listsSharder.go b/p2p/libp2p/networksharding/listsSharder.go deleted file mode 100644 index 84f491ae19b..00000000000 --- a/p2p/libp2p/networksharding/listsSharder.go +++ /dev/null @@ -1,419 +0,0 @@ -package networksharding - -import ( - "fmt" - "math/big" - "math/bits" - "sort" - "strings" - "sync" - - "github.com/ElrondNetwork/elrond-go-core/core" - "github.com/ElrondNetwork/elrond-go-core/core/check" - logger "github.com/ElrondNetwork/elrond-go-logger" - "github.com/ElrondNetwork/elrond-go/config" - "github.com/ElrondNetwork/elrond-go/p2p" - "github.com/ElrondNetwork/elrond-go/p2p/libp2p/networksharding/sorting" - "github.com/libp2p/go-libp2p-core/peer" - kbucket "github.com/libp2p/go-libp2p-kbucket" -) - -var _ p2p.Sharder = (*listsSharder)(nil) - -const minAllowedConnectedPeersListSharder = 5 -const minAllowedValidators = 1 -const minAllowedObservers = 1 -const minUnknownPeers = 1 - -const intraShardValidators = 0 -const intraShardObservers = 10 -const crossShardValidators = 20 -const crossShardObservers = 30 -const seeders = 40 -const unknown = 50 -const fullHistoryObservers = 60 - -var log = logger.GetOrCreate("p2p/libp2p/networksharding") - -var leadingZerosCount = []int{ - 8, 7, 6, 6, 5, 5, 5, 5, - 4, 4, 4, 4, 4, 4, 4, 4, - 3, 3, 3, 3, 3, 3, 3, 3, - 3, 3, 3, 3, 3, 3, 3, 3, - 2, 2, 2, 2, 2, 2, 2, 2, - 2, 2, 2, 2, 2, 2, 2, 2, - 2, 2, 2, 2, 2, 2, 2, 2, - 2, 2, 2, 2, 2, 2, 2, 2, - 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1, 1, 1, 1, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 
0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, -} - -// this will fail if we have less than 256 values in the slice -var _ = leadingZerosCount[255] - -// ArgListsSharder represents the argument structure used in the initialization of a listsSharder implementation -type ArgListsSharder struct { - PeerResolver p2p.PeerShardResolver - SelfPeerId peer.ID - P2pConfig config.P2PConfig - PreferredPeersHolder p2p.PreferredPeersHolderHandler - NodeOperationMode p2p.NodeOperation -} - -// listsSharder is the struct able to compute an eviction list of connected peers id according to the -// provided parameters. It basically splits all connected peers into 3 lists: intra shard peers, cross shard peers -// and unknown peers by the following rule: both intra shard and cross shard lists are upper bounded to provided -// maximum levels, unknown list is able to fill the gap until maximum peer count value is fulfilled. -type listsSharder struct { - mutResolver sync.RWMutex - peerShardResolver p2p.PeerShardResolver - selfPeerId peer.ID - maxPeerCount int - maxIntraShardValidators int - maxCrossShardValidators int - maxIntraShardObservers int - maxCrossShardObservers int - maxSeeders int - maxFullHistoryObservers int - maxUnknown int - mutSeeders sync.RWMutex - seeders []string - computeDistance func(src peer.ID, dest peer.ID) *big.Int - preferredPeersHolder p2p.PreferredPeersHolderHandler -} - -type peersConnections struct { - maxPeerCount int - intraShardValidators int - crossShardValidators int - intraShardObservers int - crossShardObservers int - seeders int - fullHistoryObservers int - unknown int -} - -// NewListsSharder creates a new kad list based kad sharder instance -func NewListsSharder(arg ArgListsSharder) (*listsSharder, error) { - if check.IfNil(arg.PeerResolver) { - return nil, p2p.ErrNilPeerShardResolver - } - if arg.P2pConfig.Sharding.TargetPeerCount < minAllowedConnectedPeersListSharder { - return nil, fmt.Errorf("%w, maxPeerCount should be at least %d", p2p.ErrInvalidValue, minAllowedConnectedPeersListSharder) - } - if arg.P2pConfig.Sharding.MaxIntraShardValidators < minAllowedValidators { - return nil, fmt.Errorf("%w, maxIntraShardValidators should be at least %d", p2p.ErrInvalidValue, minAllowedValidators) - } - if arg.P2pConfig.Sharding.MaxCrossShardValidators < minAllowedValidators { - return nil, fmt.Errorf("%w, maxCrossShardValidators should be at least %d", p2p.ErrInvalidValue, minAllowedValidators) - } - if arg.P2pConfig.Sharding.MaxIntraShardObservers < minAllowedObservers { - return nil, fmt.Errorf("%w, maxIntraShardObservers should be at least %d", p2p.ErrInvalidValue, minAllowedObservers) - } - if arg.P2pConfig.Sharding.MaxCrossShardObservers < minAllowedObservers { - return nil, fmt.Errorf("%w, maxCrossShardObservers should be at least %d", p2p.ErrInvalidValue, minAllowedObservers) - } - if check.IfNil(arg.PreferredPeersHolder) { - return nil, fmt.Errorf("%w while creating a new listsShared", p2p.ErrNilPreferredPeersHolder) - } - peersConn, err := processNumConnections(arg) - if err != nil { - return nil, err - } - - ls := &listsSharder{ - peerShardResolver: arg.PeerResolver, - selfPeerId: arg.SelfPeerId, - maxPeerCount: peersConn.maxPeerCount, - computeDistance: computeDistanceByCountingBits, - maxIntraShardValidators: peersConn.intraShardValidators, - maxCrossShardValidators: 
peersConn.crossShardValidators, - maxIntraShardObservers: peersConn.intraShardObservers, - maxCrossShardObservers: peersConn.crossShardObservers, - maxSeeders: peersConn.seeders, - maxFullHistoryObservers: peersConn.fullHistoryObservers, - maxUnknown: peersConn.unknown, - preferredPeersHolder: arg.PreferredPeersHolder, - } - - return ls, nil -} - -func processNumConnections(arg ArgListsSharder) (peersConnections, error) { - peersConn := peersConnections{ - maxPeerCount: int(arg.P2pConfig.Sharding.TargetPeerCount), - intraShardValidators: int(arg.P2pConfig.Sharding.MaxIntraShardValidators), - crossShardValidators: int(arg.P2pConfig.Sharding.MaxCrossShardValidators), - intraShardObservers: int(arg.P2pConfig.Sharding.MaxIntraShardObservers), - crossShardObservers: int(arg.P2pConfig.Sharding.MaxCrossShardObservers), - seeders: int(arg.P2pConfig.Sharding.MaxSeeders), - fullHistoryObservers: 0, - } - if arg.NodeOperationMode == p2p.FullArchiveMode { - peersConn.fullHistoryObservers = int(arg.P2pConfig.Sharding.AdditionalConnections.MaxFullHistoryObservers) - peersConn.maxPeerCount += peersConn.fullHistoryObservers - } - - if peersConn.crossShardObservers+peersConn.intraShardObservers+peersConn.fullHistoryObservers == 0 { - log.Warn("No connections to observers are possible. This is NOT a recommended setting!") - } - - providedPeers := peersConn.intraShardValidators + peersConn.crossShardValidators + - peersConn.intraShardObservers + peersConn.crossShardObservers + - peersConn.seeders + peersConn.fullHistoryObservers - if providedPeers+minUnknownPeers > peersConn.maxPeerCount { - return peersConnections{}, fmt.Errorf("%w, maxValidators + maxObservers + seeders + full archive nodes should be less than %d", p2p.ErrInvalidValue, peersConn.maxPeerCount) - } - - peersConn.unknown = peersConn.maxPeerCount - providedPeers - - return peersConn, nil -} - -// ComputeEvictionList returns the eviction list -func (ls *listsSharder) ComputeEvictionList(pidList []peer.ID) []peer.ID { - peerDistances := ls.splitPeerIds(pidList) - - existingNumIntraShardValidators := len(peerDistances[intraShardValidators]) - existingNumIntraShardObservers := len(peerDistances[intraShardObservers]) - existingNumCrossShardValidators := len(peerDistances[crossShardValidators]) - existingNumCrossShardObservers := len(peerDistances[crossShardObservers]) - existingNumSeeders := len(peerDistances[seeders]) - existingNumFullHistoryObservers := len(peerDistances[fullHistoryObservers]) - existingNumUnknown := len(peerDistances[unknown]) - - var numIntraShardValidators, numCrossShardValidators int - var numIntraShardObservers, numCrossShardObservers int - var numFullHistoryObservers int - var numSeeders, numUnknown, remaining int - - numIntraShardValidators, remaining = computeUsedAndSpare(existingNumIntraShardValidators, ls.maxIntraShardValidators) - numCrossShardValidators, remaining = computeUsedAndSpare(existingNumCrossShardValidators, ls.maxCrossShardValidators+remaining) - numIntraShardObservers, remaining = computeUsedAndSpare(existingNumIntraShardObservers, ls.maxIntraShardObservers+remaining) - numCrossShardObservers, remaining = computeUsedAndSpare(existingNumCrossShardObservers, ls.maxCrossShardObservers+remaining) - numSeeders, _ = computeUsedAndSpare(existingNumSeeders, ls.maxSeeders) // we are not mixing remaining value. 
We are strict with the number of seeders - numFullHistoryObservers, _ = computeUsedAndSpare(existingNumFullHistoryObservers, ls.maxFullHistoryObservers) - numUnknown, _ = computeUsedAndSpare(existingNumUnknown, ls.maxUnknown+remaining) - - evictionProposed := evict(peerDistances[intraShardValidators], numIntraShardValidators) - e := evict(peerDistances[crossShardValidators], numCrossShardValidators) - evictionProposed = append(evictionProposed, e...) - e = evict(peerDistances[intraShardObservers], numIntraShardObservers) - evictionProposed = append(evictionProposed, e...) - e = evict(peerDistances[crossShardObservers], numCrossShardObservers) - evictionProposed = append(evictionProposed, e...) - e = evict(peerDistances[seeders], numSeeders) - evictionProposed = append(evictionProposed, e...) - e = evict(peerDistances[fullHistoryObservers], numFullHistoryObservers) - evictionProposed = append(evictionProposed, e...) - e = evict(peerDistances[unknown], numUnknown) - evictionProposed = append(evictionProposed, e...) - - return evictionProposed -} - -// computeUsedAndSpare returns how many of the existing entries are kept (used) and how much spare capacity remains -// if existing >= maximum, used will equal maximum and remaining will be 0 -func computeUsedAndSpare(existing int, maximum int) (int, int) { - if existing < maximum { - return existing, maximum - existing - } - - return maximum, 0 -} -
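The capacity cascade above is easiest to see with numbers: each category keeps at most its own maximum plus whatever the previous categories left unused, while seeders and full-history observers deliberately receive no spare room. A self-contained walk-through of computeUsedAndSpare (the peer counts are invented for illustration):

package main

import "fmt"

// computeUsedAndSpare keeps at most maximum entries and reports the unused capacity.
func computeUsedAndSpare(existing int, maximum int) (int, int) {
	if existing < maximum {
		return existing, maximum - existing
	}
	return maximum, 0
}

func main() {
	// 1 intra-shard validator connected, 3 allowed: keep 1, pass 2 spare slots on
	usedIntraVal, remaining := computeUsedAndSpare(1, 3)
	fmt.Println(usedIntraVal, remaining) // 1 2

	// 5 cross-shard validators connected, 2 allowed plus the 2 inherited slots: keep 4, evict 1
	usedCrossVal, remaining := computeUsedAndSpare(5, 2+remaining)
	fmt.Println(usedCrossVal, remaining) // 4 0
}

-// Has returns true if provided pid is among the provided list -func (ls *listsSharder) Has(pid peer.ID, list []peer.ID) bool { - return has(pid, list) -} - -func has(pid peer.ID, list []peer.ID) bool { - for _, p := range list { - if p == pid { - return true - } - } - - return false -} - -func (ls *listsSharder) splitPeerIds(peers []peer.ID) map[int]sorting.PeerDistances { - peerDistances := map[int]sorting.PeerDistances{ - intraShardValidators: {}, - intraShardObservers: {}, - crossShardValidators: {}, - crossShardObservers: {}, - fullHistoryObservers: {}, - seeders: {}, - unknown: {}, - } - - ls.mutResolver.RLock() - selfPeerInfo := ls.peerShardResolver.GetPeerInfo(core.PeerID(ls.selfPeerId)) - ls.mutResolver.RUnlock() - - for _, p := range peers { - pd := &sorting.PeerDistance{ - ID: p, - Distance: ls.computeDistance(p, ls.selfPeerId), - } - pid := core.PeerID(p) - isSeeder := ls.IsSeeder(pid) - if isSeeder { - peerDistances[seeders] = append(peerDistances[seeders], pd) - continue - } - - ls.mutResolver.RLock() - peerInfo := ls.peerShardResolver.GetPeerInfo(pid) - ls.mutResolver.RUnlock() - - if ls.preferredPeersHolder.Contains(pid) { - continue - } - - if peerInfo.PeerType == core.UnknownPeer { - peerDistances[unknown] = append(peerDistances[unknown], pd) - continue - } - - isCrossShard := peerInfo.ShardID != selfPeerInfo.ShardID - if isCrossShard { - switch peerInfo.PeerType { - case core.ValidatorPeer: - peerDistances[crossShardValidators] = append(peerDistances[crossShardValidators], pd) - case core.ObserverPeer: - peerDistances[crossShardObservers] = append(peerDistances[crossShardObservers], pd) - } - - continue - } - - switch peerInfo.PeerType { - case core.ValidatorPeer: - peerDistances[intraShardValidators] = append(peerDistances[intraShardValidators], pd) - case core.ObserverPeer: - shouldAppendToFullHistory := peerInfo.PeerSubType == core.FullHistoryObserver && ls.maxFullHistoryObservers > 0 - if shouldAppendToFullHistory { - peerDistances[fullHistoryObservers] = append(peerDistances[fullHistoryObservers], pd) - } else { - peerDistances[intraShardObservers] = append(peerDistances[intraShardObservers], pd) - 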
} - } - } - - return peerDistances -} - -func evict(distances sorting.PeerDistances, numKeep int) []peer.ID { - if numKeep < 0 { - numKeep = 0 - } - if numKeep >= len(distances) { - return make([]peer.ID, 0) - } - - sort.Sort(distances) - evictedPD := distances[numKeep:] - evictedPids := make([]peer.ID, len(evictedPD)) - for i, pd := range evictedPD { - evictedPids[i] = pd.ID - } - - return evictedPids -} - -// computes the kademlia distance between 2 provided peers by doing byte xor operations and counting the resulting bits -func computeDistanceByCountingBits(src peer.ID, dest peer.ID) *big.Int { - srcBuff := kbucket.ConvertPeerID(src) - destBuff := kbucket.ConvertPeerID(dest) - - cumulatedBits := 0 - for i := 0; i < len(srcBuff); i++ { - result := srcBuff[i] ^ destBuff[i] - cumulatedBits += bits.OnesCount8(result) - } - - return big.NewInt(0).SetInt64(int64(cumulatedBits)) -} - -// computes the kademlia distance between 2 provided peers by doing byte xor operations and applying log2 on the result -func computeDistanceLog2Based(src peer.ID, dest peer.ID) *big.Int { - srcBuff := kbucket.ConvertPeerID(src) - destBuff := kbucket.ConvertPeerID(dest) - - val := 0 - for i := 0; i < len(srcBuff); i++ { - result := srcBuff[i] ^ destBuff[i] - val += leadingZerosCount[result] - if result != 0 { - break - } - } - - val = len(srcBuff)*8 - val - - return big.NewInt(0).SetInt64(int64(val)) -} - -// IsSeeder returns true if the provided peer is a seeder -func (ls *listsSharder) IsSeeder(pid core.PeerID) bool { - ls.mutSeeders.RLock() - defer ls.mutSeeders.RUnlock() - - strPretty := pid.Pretty() - for _, seeder := range ls.seeders { - if strings.Contains(seeder, strPretty) { - return true - } - } - - return false -} - -// SetSeeders will set the seeders -func (ls *listsSharder) SetSeeders(addresses []string) { - ls.mutSeeders.Lock() - ls.seeders = addresses - ls.mutSeeders.Unlock() -} - -// SetPeerShardResolver sets the peer shard resolver for this sharder -func (ls *listsSharder) SetPeerShardResolver(psp p2p.PeerShardResolver) error { - if check.IfNil(psp) { - return p2p.ErrNilPeerShardResolver - } - - ls.mutResolver.Lock() - ls.peerShardResolver = psp - ls.mutResolver.Unlock() - - return nil -} - -// IsInterfaceNil returns true if there is no value under the interface -func (ls *listsSharder) IsInterfaceNil() bool { - return ls == nil -} diff --git a/p2p/libp2p/networksharding/listsSharder_test.go b/p2p/libp2p/networksharding/listsSharder_test.go deleted file mode 100644 index 0470db2fadf..00000000000 --- a/p2p/libp2p/networksharding/listsSharder_test.go +++ /dev/null @@ -1,572 +0,0 @@ -package networksharding - -import ( - "encoding/hex" - "errors" - "fmt" - "strings" - "testing" - - "github.com/ElrondNetwork/elrond-go-core/core" - "github.com/ElrondNetwork/elrond-go-core/core/check" - "github.com/ElrondNetwork/elrond-go/config" - "github.com/ElrondNetwork/elrond-go/p2p" - "github.com/ElrondNetwork/elrond-go/p2p/mock" - "github.com/ElrondNetwork/elrond-go/p2p/peersHolder" - "github.com/ElrondNetwork/elrond-go/testscommon/p2pmocks" - "github.com/libp2p/go-libp2p-core/peer" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -const crtShardId = uint32(0) -const crossShardId = uint32(1) - -const validatorMarker = "validator" -const observerMarker = "observer" -const unknownMarker = "unknown" -const seederMarker = "seeder" - -var crtPid = peer.ID(fmt.Sprintf("%d pid", crtShardId)) - -func createStringPeersShardResolver() *mock.PeerShardResolverStub { - return 
&mock.PeerShardResolverStub{ - GetPeerInfoCalled: func(pid core.PeerID) core.P2PPeerInfo { - strPid := string(pid) - pInfo := core.P2PPeerInfo{} - - if strings.Contains(strPid, fmt.Sprintf("%d", crtShardId)) { - pInfo.ShardID = crtShardId - } else { - pInfo.ShardID = crossShardId - } - - if strings.Contains(strPid, unknownMarker) { - pInfo.PeerType = core.UnknownPeer - } - if strings.Contains(strPid, validatorMarker) { - pInfo.PeerType = core.ValidatorPeer - } - if strings.Contains(strPid, observerMarker) { - pInfo.PeerType = core.ObserverPeer - } - - return pInfo - }, - } -} - -func countPeers(peers []peer.ID, shardID uint32, marker string) int { - counter := 0 - for _, pid := range peers { - if strings.Contains(string(pid), marker) && - strings.Contains(string(pid), fmt.Sprintf("%d", shardID)) { - counter++ - } - } - - return counter -} - -func createMockListSharderArguments() ArgListsSharder { - return ArgListsSharder{ - PeerResolver: createStringPeersShardResolver(), - SelfPeerId: crtPid, - PreferredPeersHolder: &p2pmocks.PeersHolderStub{}, - P2pConfig: config.P2PConfig{ - Sharding: config.ShardingConfig{ - TargetPeerCount: minAllowedConnectedPeersListSharder, - MaxIntraShardValidators: minAllowedValidators, - MaxCrossShardValidators: minAllowedValidators, - MaxIntraShardObservers: minAllowedObservers, - MaxCrossShardObservers: minAllowedObservers, - MaxSeeders: 0, - }, - }, - } -} - -func TestNewListsSharder_InvalidMinimumTargetPeerCountShouldErr(t *testing.T) { - t.Parallel() - - arg := createMockListSharderArguments() - arg.P2pConfig.Sharding.TargetPeerCount = minAllowedConnectedPeersListSharder - 1 - ls, err := NewListsSharder(arg) - - assert.True(t, check.IfNil(ls)) - assert.True(t, errors.Is(err, p2p.ErrInvalidValue)) - assert.True(t, strings.Contains(err.Error(), "maxPeerCount should be at least")) -} - -func TestNewListsSharder_NilPeerShardResolverShouldErr(t *testing.T) { - t.Parallel() - - arg := createMockListSharderArguments() - arg.PeerResolver = nil - ls, err := NewListsSharder(arg) - - assert.True(t, check.IfNil(ls)) - assert.True(t, errors.Is(err, p2p.ErrNilPeerShardResolver)) -} - -func TestNewListsSharder_InvalidIntraShardValidatorsShouldErr(t *testing.T) { - t.Parallel() - - arg := createMockListSharderArguments() - arg.P2pConfig.Sharding.MaxIntraShardValidators = minAllowedValidators - 1 - ls, err := NewListsSharder(arg) - - assert.True(t, check.IfNil(ls)) - assert.True(t, errors.Is(err, p2p.ErrInvalidValue)) -} - -func TestNewListsSharder_InvalidCrossShardValidatorsShouldErr(t *testing.T) { - t.Parallel() - - arg := createMockListSharderArguments() - arg.P2pConfig.Sharding.MaxCrossShardValidators = minAllowedValidators - 1 - ls, err := NewListsSharder(arg) - - assert.True(t, check.IfNil(ls)) - assert.True(t, errors.Is(err, p2p.ErrInvalidValue)) -} - -func TestNewListsSharder_InvalidIntraShardObserversShouldErr(t *testing.T) { - t.Parallel() - - arg := createMockListSharderArguments() - arg.P2pConfig.Sharding.MaxIntraShardObservers = minAllowedObservers - 1 - ls, err := NewListsSharder(arg) - - assert.True(t, check.IfNil(ls)) - assert.True(t, errors.Is(err, p2p.ErrInvalidValue)) -} - -func TestNewListsSharder_InvalidCrossShardObserversShouldErr(t *testing.T) { - t.Parallel() - - arg := createMockListSharderArguments() - arg.P2pConfig.Sharding.MaxCrossShardObservers = minAllowedObservers - 1 - ls, err := NewListsSharder(arg) - - assert.True(t, check.IfNil(ls)) - assert.True(t, errors.Is(err, p2p.ErrInvalidValue)) -} - -func 
TestNewListsSharder_NoRoomForUnknownShouldErr(t *testing.T) { - t.Parallel() - - arg := createMockListSharderArguments() - arg.P2pConfig.Sharding.MaxCrossShardObservers = minAllowedObservers + 1 - ls, err := NewListsSharder(arg) - - assert.True(t, check.IfNil(ls)) - assert.True(t, errors.Is(err, p2p.ErrInvalidValue)) -} - -func TestNewListsSharder_NilPreferredPeersShouldErr(t *testing.T) { - t.Parallel() - - arg := createMockListSharderArguments() - arg.PreferredPeersHolder = nil - ls, err := NewListsSharder(arg) - - assert.True(t, check.IfNil(ls)) - assert.True(t, errors.Is(err, p2p.ErrNilPreferredPeersHolder)) -} - -func TestNewListsSharder_NormalShouldWork(t *testing.T) { - t.Parallel() - - arg := createMockListSharderArguments() - arg.P2pConfig.Sharding.TargetPeerCount = 25 - arg.P2pConfig.Sharding.MaxIntraShardValidators = 6 - arg.P2pConfig.Sharding.MaxCrossShardValidators = 5 - arg.P2pConfig.Sharding.MaxIntraShardObservers = 4 - arg.P2pConfig.Sharding.MaxCrossShardObservers = 3 - arg.P2pConfig.Sharding.MaxSeeders = 2 - arg.P2pConfig.Sharding.AdditionalConnections.MaxFullHistoryObservers = 1 - ls, err := NewListsSharder(arg) - - assert.False(t, check.IfNil(ls)) - assert.Nil(t, err) - assert.Equal(t, 25, ls.maxPeerCount) - assert.Equal(t, 6, ls.maxIntraShardValidators) - assert.Equal(t, 5, ls.maxCrossShardValidators) - assert.Equal(t, 4, ls.maxIntraShardObservers) - assert.Equal(t, 3, ls.maxCrossShardObservers) - assert.Equal(t, 2, ls.maxSeeders) - assert.Equal(t, 0, ls.maxFullHistoryObservers) - assert.Equal(t, 5, ls.maxUnknown) -} - -func TestNewListsSharder_FullArchiveShouldWork(t *testing.T) { - t.Parallel() - - arg := createMockListSharderArguments() - arg.NodeOperationMode = p2p.FullArchiveMode - arg.P2pConfig.Sharding.TargetPeerCount = 25 - arg.P2pConfig.Sharding.MaxIntraShardValidators = 6 - arg.P2pConfig.Sharding.MaxCrossShardValidators = 5 - arg.P2pConfig.Sharding.MaxIntraShardObservers = 4 - arg.P2pConfig.Sharding.MaxCrossShardObservers = 3 - arg.P2pConfig.Sharding.MaxSeeders = 2 - arg.P2pConfig.Sharding.AdditionalConnections.MaxFullHistoryObservers = 1 - ls, err := NewListsSharder(arg) - - assert.False(t, check.IfNil(ls)) - assert.Nil(t, err) - assert.Equal(t, 26, ls.maxPeerCount) - assert.Equal(t, 6, ls.maxIntraShardValidators) - assert.Equal(t, 5, ls.maxCrossShardValidators) - assert.Equal(t, 4, ls.maxIntraShardObservers) - assert.Equal(t, 3, ls.maxCrossShardObservers) - assert.Equal(t, 2, ls.maxSeeders) - assert.Equal(t, 1, ls.maxFullHistoryObservers) - assert.Equal(t, 5, ls.maxUnknown) -} - -// ------- ComputeEvictionList - -func TestListsSharder_ComputeEvictionListNotReachedValidatorsShouldRetEmpty(t *testing.T) { - t.Parallel() - - arg := createMockListSharderArguments() - ls, _ := NewListsSharder(arg) - pidCrtShard := peer.ID(fmt.Sprintf("%d %s", crtShardId, validatorMarker)) - pidCrossShard := peer.ID(fmt.Sprintf("%d %s", crossShardId, validatorMarker)) - pids := []peer.ID{pidCrtShard, pidCrossShard} - - evictList := ls.ComputeEvictionList(pids) - - assert.Equal(t, 0, len(evictList)) -} - -func TestListsSharder_ComputeEvictionListNotReachedObserversShouldRetEmpty(t *testing.T) { - t.Parallel() - - arg := createMockListSharderArguments() - ls, _ := NewListsSharder(arg) - pidCrtShard := peer.ID(fmt.Sprintf("%d %s", crtShardId, observerMarker)) - pidCrossShard := peer.ID(fmt.Sprintf("%d %s", crossShardId, observerMarker)) - pids := []peer.ID{pidCrtShard, pidCrossShard} - - evictList := ls.ComputeEvictionList(pids) - - assert.Equal(t, 0, len(evictList)) -} - -func 
TestListsSharder_ComputeEvictionListNotReachedUnknownShouldRetEmpty(t *testing.T) { - t.Parallel() - - arg := createMockListSharderArguments() - ls, _ := NewListsSharder(arg) - pidUnknown := peer.ID(fmt.Sprintf("0 %s", unknownMarker)) - pids := []peer.ID{pidUnknown} - - evictList := ls.ComputeEvictionList(pids) - - assert.Equal(t, 0, len(evictList)) -} - -func TestListsSharder_ComputeEvictionListReachedIntraShardShouldSortAndEvict(t *testing.T) { - t.Parallel() - - arg := createMockListSharderArguments() - ls, _ := NewListsSharder(arg) - pidCrtShard1 := peer.ID(fmt.Sprintf("%d - 1 - %s", crtShardId, validatorMarker)) - pidCrtShard2 := peer.ID(fmt.Sprintf("%d - 2 - %s", crtShardId, validatorMarker)) - pids := []peer.ID{pidCrtShard2, pidCrtShard1} - - evictList := ls.ComputeEvictionList(pids) - - assert.Equal(t, 1, len(evictList)) - assert.Equal(t, pidCrtShard1, evictList[0]) -} - -func TestListsSharder_ComputeEvictionListUnknownPeersShouldFillTheGap(t *testing.T) { - t.Parallel() - - arg := createMockListSharderArguments() - arg.P2pConfig.Sharding.TargetPeerCount = 5 - ls, _ := NewListsSharder(arg) - - unknownPids := make([]peer.ID, arg.P2pConfig.Sharding.TargetPeerCount) - for i := 0; i < int(arg.P2pConfig.Sharding.TargetPeerCount); i++ { - unknownPids[i] = unknownMarker - } - newUnknownPid := peer.ID(unknownMarker) - unknownPids = append(unknownPids, newUnknownPid) - - evictList := ls.ComputeEvictionList(unknownPids) - - assert.Equal(t, 1, len(evictList)) - assert.Equal(t, unknownPids[0], evictList[0]) -} - -func TestListsSharder_ComputeEvictionListCrossShouldFillTheGap(t *testing.T) { - t.Parallel() - - arg := createMockListSharderArguments() - arg.P2pConfig.Sharding.TargetPeerCount = 5 - arg.P2pConfig.Sharding.MaxIntraShardValidators = 1 - arg.P2pConfig.Sharding.MaxCrossShardValidators = 1 - arg.P2pConfig.Sharding.MaxIntraShardObservers = 1 - arg.P2pConfig.Sharding.MaxCrossShardObservers = 1 - ls, _ := NewListsSharder(arg) - - pids := []peer.ID{ - peer.ID(fmt.Sprintf("%d %s", crossShardId, validatorMarker)), - peer.ID(fmt.Sprintf("%d %s", crossShardId, validatorMarker)), - peer.ID(fmt.Sprintf("%d %s", crossShardId, observerMarker)), - peer.ID(fmt.Sprintf("%d %s", crossShardId, observerMarker)), - peer.ID(fmt.Sprintf("%d %s", crossShardId, unknownMarker)), - } - - evictList := ls.ComputeEvictionList(pids) - - assert.Equal(t, 0, len(evictList)) -} - -func TestListsSharder_ComputeEvictionListEvictFromAllShouldWork(t *testing.T) { - t.Parallel() - - arg := createMockListSharderArguments() - arg.P2pConfig.Sharding.TargetPeerCount = 6 - arg.P2pConfig.Sharding.MaxIntraShardValidators = 1 - arg.P2pConfig.Sharding.MaxCrossShardValidators = 1 - arg.P2pConfig.Sharding.MaxIntraShardObservers = 1 - arg.P2pConfig.Sharding.MaxCrossShardObservers = 1 - arg.P2pConfig.Sharding.MaxSeeders = 1 - ls, _ := NewListsSharder(arg) - seeder := peer.ID(fmt.Sprintf("%d %s", crossShardId, seederMarker)) - ls.SetSeeders([]string{ - "ip6/" + seeder.Pretty(), - }) - - pids := []peer.ID{ - peer.ID(fmt.Sprintf("%d %s", crtShardId, validatorMarker)), - peer.ID(fmt.Sprintf("%d %s", crtShardId, validatorMarker)), - - peer.ID(fmt.Sprintf("%d %s", crossShardId, validatorMarker)), - peer.ID(fmt.Sprintf("%d %s", crossShardId, validatorMarker)), - peer.ID(fmt.Sprintf("%d %s", crossShardId, validatorMarker)), - - peer.ID(fmt.Sprintf("%d %s", crtShardId, observerMarker)), - peer.ID(fmt.Sprintf("%d %s", crtShardId, observerMarker)), - peer.ID(fmt.Sprintf("%d %s", crtShardId, observerMarker)), - peer.ID(fmt.Sprintf("%d %s", 
crtShardId, observerMarker)), - - peer.ID(fmt.Sprintf("%d %s", crossShardId, observerMarker)), - peer.ID(fmt.Sprintf("%d %s", crossShardId, observerMarker)), - peer.ID(fmt.Sprintf("%d %s", crossShardId, observerMarker)), - peer.ID(fmt.Sprintf("%d %s", crossShardId, observerMarker)), - peer.ID(fmt.Sprintf("%d %s", crossShardId, observerMarker)), - - peer.ID(fmt.Sprintf("%d %s", crossShardId, unknownMarker)), - peer.ID(fmt.Sprintf("%d %s", crossShardId, unknownMarker)), - peer.ID(fmt.Sprintf("%d %s", crossShardId, unknownMarker)), - peer.ID(fmt.Sprintf("%d %s", crossShardId, unknownMarker)), - peer.ID(fmt.Sprintf("%d %s", crossShardId, unknownMarker)), - peer.ID(fmt.Sprintf("%d %s", crossShardId, unknownMarker)), - - peer.ID(fmt.Sprintf("%d %s", crossShardId, seederMarker)), - peer.ID(fmt.Sprintf("%d %s", crossShardId, seederMarker)), - peer.ID(fmt.Sprintf("%d %s", crossShardId, seederMarker)), - peer.ID(fmt.Sprintf("%d %s", crossShardId, seederMarker)), - peer.ID(fmt.Sprintf("%d %s", crossShardId, seederMarker)), - peer.ID(fmt.Sprintf("%d %s", crossShardId, seederMarker)), - peer.ID(fmt.Sprintf("%d %s", crossShardId, seederMarker)), - } - - evictList := ls.ComputeEvictionList(pids) - - assert.Equal(t, 21, len(evictList)) - assert.Equal(t, 1, countPeers(evictList, crtShardId, validatorMarker)) - assert.Equal(t, 2, countPeers(evictList, crossShardId, validatorMarker)) - assert.Equal(t, 3, countPeers(evictList, crtShardId, observerMarker)) - assert.Equal(t, 4, countPeers(evictList, crossShardId, observerMarker)) - assert.Equal(t, 5, countPeers(evictList, crossShardId, unknownMarker)) - assert.Equal(t, 6, countPeers(evictList, crossShardId, seederMarker)) -} - -func TestListsSharder_ComputeEvictionListShouldNotContainPreferredPeers(t *testing.T) { - arg := createMockListSharderArguments() - pids := []peer.ID{ - "preferredPeer0", - "peer0", - "peer1", - "preferredPeer1", - "peer2", - "preferredPeer2", - } - arg.PreferredPeersHolder = &p2pmocks.PeersHolderStub{ - ContainsCalled: func(peerID core.PeerID) bool { - return strings.HasPrefix(string(peerID), "preferred") - }, - } - - ls, _ := NewListsSharder(arg) - seeder := peer.ID(fmt.Sprintf("%d %s", crossShardId, seederMarker)) - ls.SetSeeders([]string{ - "ip6/" + seeder.Pretty(), - }) - - evictList := ls.ComputeEvictionList(pids) - - for _, peerID := range evictList { - require.False(t, strings.HasPrefix(string(peerID), "preferred")) - } -} - -func TestListsSharder_ComputeEvictionListWithRealPreferredPeersHandler(t *testing.T) { - arg := createMockListSharderArguments() - - preferredHexPrefix := "preferred" - prefP0 := preferredHexPrefix + "preferredPeer0" - prefP1 := preferredHexPrefix + "preferredPeer1" - prefP2 := preferredHexPrefix + "preferredPeer2" - pubKeyHexSuffix := hex.EncodeToString([]byte("pubKey")) - pids := []peer.ID{ - peer.ID(core.PeerID(prefP0).Pretty()), - "peer0", - "peer1", - peer.ID(core.PeerID(prefP1).Pretty()), - "peer2", - peer.ID(core.PeerID(prefP2).Pretty()), - } - - prefPeers := []string{ - core.PeerID(prefP0).Pretty(), - core.PeerID(prefP1).Pretty(), - core.PeerID(prefP2).Pretty(), - } - - arg.PreferredPeersHolder, _ = peersHolder.NewPeersHolder(prefPeers) - for _, prefPid := range prefPeers { - peerId := core.PeerID(prefPid) - arg.PreferredPeersHolder.PutConnectionAddress(peerId, prefPid) - arg.PreferredPeersHolder.PutShardID(peerId, 0) - } - - arg.PeerResolver = &mock.PeerShardResolverStub{ - GetPeerInfoCalled: func(pid core.PeerID) core.P2PPeerInfo { - if strings.HasPrefix(string(pid), preferredHexPrefix) { - 
pkBytes, _ := hex.DecodeString(string(pid) + pubKeyHexSuffix) - return core.P2PPeerInfo{ - PeerType: 0, - PeerSubType: 0, - ShardID: 0, - PkBytes: pkBytes, - } - } - return core.P2PPeerInfo{} - }, - } - ls, _ := NewListsSharder(arg) - seeder := peer.ID(fmt.Sprintf("%d %s", crossShardId, seederMarker)) - ls.SetSeeders([]string{ - "ip6/" + seeder.Pretty(), - }) - - evictList := ls.ComputeEvictionList(pids) - for _, peerID := range evictList { - require.False(t, strings.HasPrefix(string(peerID), preferredHexPrefix)) - } - - found := arg.PreferredPeersHolder.Contains(core.PeerID(peer.ID(prefP0).Pretty())) - require.True(t, found) - - found = arg.PreferredPeersHolder.Contains(core.PeerID(peer.ID(prefP1).Pretty())) - require.True(t, found) - - found = arg.PreferredPeersHolder.Contains(core.PeerID(peer.ID(prefP2).Pretty())) - require.True(t, found) - - peers := arg.PreferredPeersHolder.Get() - expectedMap := map[uint32][]core.PeerID{ - 0: { - core.PeerID(peer.ID(prefP0).Pretty()), - core.PeerID(peer.ID(prefP1).Pretty()), - core.PeerID(peer.ID(prefP2).Pretty()), - }, - } - require.Equal(t, expectedMap, peers) -} - -// ------- Has - -func TestListsSharder_HasNotFound(t *testing.T) { - t.Parallel() - - list := []peer.ID{"pid1", "pid2", "pid3"} - ls := &listsSharder{} - - assert.False(t, ls.Has("pid4", list)) -} - -func TestListsSharder_HasEmpty(t *testing.T) { - t.Parallel() - - list := make([]peer.ID, 0) - lks := &listsSharder{} - - assert.False(t, lks.Has("pid4", list)) -} - -func TestListsSharder_HasFound(t *testing.T) { - t.Parallel() - - list := []peer.ID{"pid1", "pid2", "pid3"} - lks := &listsSharder{} - - assert.True(t, lks.Has("pid2", list)) -} - -// ------- computeDistance - -func TestComputeDistanceByCountingBits(t *testing.T) { - t.Parallel() - - // compute will be done on hashes. Impossible to predict the outcome in this test - assert.Equal(t, uint64(0), computeDistanceByCountingBits("", "").Uint64()) - assert.Equal(t, uint64(0), computeDistanceByCountingBits("a", "a").Uint64()) - assert.Equal(t, uint64(139), computeDistanceByCountingBits(peer.ID([]byte{0}), peer.ID([]byte{1})).Uint64()) - assert.Equal(t, uint64(130), computeDistanceByCountingBits(peer.ID([]byte{0}), peer.ID([]byte{255})).Uint64()) - assert.Equal(t, uint64(117), computeDistanceByCountingBits(peer.ID([]byte{0, 128}), peer.ID([]byte{255, 255})).Uint64()) -} - -func TestComputeDistanceLog2Based(t *testing.T) { - t.Parallel() - - // compute will be done on hashes. 
Impossible to predict the outcome in this test - assert.Equal(t, uint64(0), computeDistanceLog2Based("", "").Uint64()) - assert.Equal(t, uint64(0), computeDistanceLog2Based("a", "a").Uint64()) - assert.Equal(t, uint64(254), computeDistanceLog2Based(peer.ID([]byte{0}), peer.ID([]byte{1})).Uint64()) - assert.Equal(t, uint64(250), computeDistanceLog2Based(peer.ID([]byte{254}), peer.ID([]byte{255})).Uint64()) - assert.Equal(t, uint64(256), computeDistanceLog2Based(peer.ID([]byte{0, 128}), peer.ID([]byte{255, 255})).Uint64()) -} - -func TestListsSharder_SetPeerShardResolverNilShouldErr(t *testing.T) { - t.Parallel() - - arg := createMockListSharderArguments() - lks, _ := NewListsSharder(arg) - - err := lks.SetPeerShardResolver(nil) - - assert.Equal(t, p2p.ErrNilPeerShardResolver, err) -} - -func TestListsSharder_SetPeerShardResolverShouldWork(t *testing.T) { - t.Parallel() - - arg := createMockListSharderArguments() - lks, _ := NewListsSharder(arg) - newPeerShardResolver := &mock.PeerShardResolverStub{} - err := lks.SetPeerShardResolver(newPeerShardResolver) - - // pointer testing - assert.True(t, lks.peerShardResolver == newPeerShardResolver) - assert.Nil(t, err) -} diff --git a/p2p/libp2p/networksharding/nilListSharder.go b/p2p/libp2p/networksharding/nilListSharder.go deleted file mode 100644 index 6898d592e80..00000000000 --- a/p2p/libp2p/networksharding/nilListSharder.go +++ /dev/null @@ -1,46 +0,0 @@ -package networksharding - -import ( - "github.com/ElrondNetwork/elrond-go-core/core" - "github.com/ElrondNetwork/elrond-go/p2p" - "github.com/libp2p/go-libp2p-core/peer" -) - -var _ p2p.Sharder = (*nilListSharder)(nil) - -// nilListSharder will not cause connections trimming -type nilListSharder struct{} - -// NewNilListSharder returns a disabled sharder implementation -func NewNilListSharder() *nilListSharder { - return &nilListSharder{} -} - -// ComputeEvictionList will always output an empty list as to not cause connection trimming -func (nls *nilListSharder) ComputeEvictionList(_ []peer.ID) []peer.ID { - return make([]peer.ID, 0) -} - -// Has will output false, causing all peers to connect to each other -func (nls *nilListSharder) Has(_ peer.ID, _ []peer.ID) bool { - return false -} - -// SetPeerShardResolver will do nothing -func (nls *nilListSharder) SetPeerShardResolver(_ p2p.PeerShardResolver) error { - return nil -} - -// SetSeeders does nothing -func (nls *nilListSharder) SetSeeders(_ []string) { -} - -// IsSeeder returns false -func (nls *nilListSharder) IsSeeder(_ core.PeerID) bool { - return false -} - -// IsInterfaceNil returns true if there is no value under the interface -func (nls *nilListSharder) IsInterfaceNil() bool { - return nls == nil -} diff --git a/p2p/libp2p/networksharding/nilListSharder_test.go b/p2p/libp2p/networksharding/nilListSharder_test.go deleted file mode 100644 index 13a9c742768..00000000000 --- a/p2p/libp2p/networksharding/nilListSharder_test.go +++ /dev/null @@ -1,17 +0,0 @@ -package networksharding - -import ( - "testing" - - "github.com/ElrondNetwork/elrond-go-core/core/check" - "github.com/stretchr/testify/assert" -) - -func TestNilListSharderSharder(t *testing.T) { - nls := NewNilListSharder() - - assert.False(t, check.IfNil(nls)) - assert.Equal(t, 0, len(nls.ComputeEvictionList(nil))) - assert.False(t, nls.Has("", nil)) - assert.Nil(t, nls.SetPeerShardResolver(nil)) -} diff --git a/p2p/libp2p/networksharding/oneListSharder.go b/p2p/libp2p/networksharding/oneListSharder.go deleted file mode 100644 index fbae1e14771..00000000000 --- 
a/p2p/libp2p/networksharding/oneListSharder.go +++ /dev/null @@ -1,83 +0,0 @@ -package networksharding - -import ( - "fmt" - "math/big" - - "github.com/ElrondNetwork/elrond-go-core/core" - "github.com/ElrondNetwork/elrond-go/p2p" - "github.com/ElrondNetwork/elrond-go/p2p/libp2p/networksharding/sorting" - "github.com/libp2p/go-libp2p-core/peer" -) - -var _ p2p.Sharder = (*oneListSharder)(nil) - -const minAllowedConnectedPeersOneSharder = 3 - -type oneListSharder struct { - selfPeerId peer.ID - maxPeerCount int - computeDistance func(src peer.ID, dest peer.ID) *big.Int -} - -// NewOneListSharder creates a new sharder instance that is shard agnostic and uses one list -func NewOneListSharder( - selfPeerId peer.ID, - maxPeerCount int, -) (*oneListSharder, error) { - if maxPeerCount < minAllowedConnectedPeersOneSharder { - return nil, fmt.Errorf("%w, maxPeerCount should be at least %d", p2p.ErrInvalidValue, minAllowedConnectedPeersOneSharder) - } - - return &oneListSharder{ - selfPeerId: selfPeerId, - maxPeerCount: maxPeerCount, - computeDistance: computeDistanceByCountingBits, - }, nil -} - -// ComputeEvictionList returns the eviction list
func (ols *oneListSharder) ComputeEvictionList(pidList []peer.ID) []peer.ID { - list := ols.convertList(pidList) - evictionProposed := evict(list, ols.maxPeerCount) - - return evictionProposed -} - -func (ols *oneListSharder) convertList(peers []peer.ID) sorting.PeerDistances { - list := sorting.PeerDistances{} - - for _, p := range peers { - pd := &sorting.PeerDistance{ - ID: p, - Distance: ols.computeDistance(p, ols.selfPeerId), - } - list = append(list, pd) - } - - return list -} - -// Has returns true if provided pid is among the provided list -func (ols *oneListSharder) Has(pid peer.ID, list []peer.ID) bool { - return has(pid, list) -} -
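[Editor's note] The sharder above delegates peer ranking to computeDistanceByCountingBits, whose body is not part of this hunk; the distance tests earlier in the diff only pin its outputs on hashed inputs. The following is a minimal, self-contained sketch of the hash-then-XOR bit-counting idea, with hypothetical names, not the repository's implementation; the exact values asserted in the tests depend on the real function.

package main

import (
	"crypto/sha256"
	"fmt"
	"math/big"
	"math/bits"
)

// distanceByCountingBits hashes both peer IDs, XORs the digests and counts the
// set bits: identical IDs give 0, unrelated IDs roughly half of the 256 bits.
func distanceByCountingBits(a string, b string) *big.Int {
	ha := sha256.Sum256([]byte(a))
	hb := sha256.Sum256([]byte(b))

	counter := 0
	for i := range ha {
		counter += bits.OnesCount8(ha[i] ^ hb[i])
	}

	return big.NewInt(int64(counter))
}

func main() {
	fmt.Println(distanceByCountingBits("a", "a").Uint64()) // 0
	fmt.Println(distanceByCountingBits("a", "b").Uint64()) // some positive bit count
}

Sorting peers by such a distance and evicting the farthest ones is what ComputeEvictionList does through convertList and evict.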
-// SetPeerShardResolver sets the peer shard resolver for this sharder. Doesn't do anything in this implementation -func (ols *oneListSharder) SetPeerShardResolver(_ p2p.PeerShardResolver) error { - return nil -} - -// SetSeeders does nothing as all peers are treated equally in this implementation -func (ols *oneListSharder) SetSeeders(_ []string) { -} - -// IsSeeder returns false -func (ols *oneListSharder) IsSeeder(_ core.PeerID) bool { - return false -} - -// IsInterfaceNil returns true if there is no value under the interface -func (ols *oneListSharder) IsInterfaceNil() bool { - return ols == nil -} diff --git a/p2p/libp2p/networksharding/oneListSharder_test.go b/p2p/libp2p/networksharding/oneListSharder_test.go deleted file mode 100644 index 7b6bc831b4b..00000000000 --- a/p2p/libp2p/networksharding/oneListSharder_test.go +++ /dev/null @@ -1,121 +0,0 @@ -package networksharding - -import ( - "errors" - "testing" - - "github.com/ElrondNetwork/elrond-go-core/core/check" - "github.com/ElrondNetwork/elrond-go/p2p" - "github.com/libp2p/go-libp2p-core/peer" - "github.com/stretchr/testify/assert" -) - -func TestNewOneListSharder_InvalidMaxPeerCountShouldErr(t *testing.T) { - t.Parallel() - - ols, err := NewOneListSharder( - "", - minAllowedConnectedPeersOneSharder-1, - ) - - assert.True(t, check.IfNil(ols)) - assert.True(t, errors.Is(err, p2p.ErrInvalidValue)) -} - -func TestNewOneListSharder_ShouldWork(t *testing.T) { - t.Parallel() - - ols, err := NewOneListSharder( - "", - minAllowedConnectedPeersOneSharder, - ) - - assert.False(t, check.IfNil(ols)) - assert.Nil(t, err) -} - -// ------- ComputeEvictionList - -func TestOneListSharder_ComputeEvictionListNotReachedShouldRetEmpty(t *testing.T) { - t.Parallel() - - ols, _ := NewOneListSharder( - crtPid, - minAllowedConnectedPeersOneSharder, - ) - pid1 := peer.ID("pid1") - pid2 := peer.ID("pid2") - pids := []peer.ID{pid1, pid2} - - evictList := ols.ComputeEvictionList(pids) - - assert.Equal(t, 0, len(evictList)) -} - -func TestOneListSharder_ComputeEvictionListReachedIntraShardShouldSortAndEvict(t *testing.T) { - t.Parallel() - - ols, _ := NewOneListSharder( - crtPid, - minAllowedConnectedPeersOneSharder, - ) - pid1 := peer.ID("pid1") - pid2 := peer.ID("pid2") - pid3 := peer.ID("pid3") - pid4 := peer.ID("pid4") - pids := []peer.ID{pid1, pid2, pid3, pid4} - - evictList := ols.ComputeEvictionList(pids) - - assert.Equal(t, 1, len(evictList)) - assert.Equal(t, pid3, evictList[0]) -} - -// ------- Has - -func TestOneListSharder_HasNotFound(t *testing.T) { - t.Parallel() - - list := []peer.ID{"pid1", "pid2", "pid3"} - lnks := &oneListSharder{} - - assert.False(t, lnks.Has("pid4", list)) -} - -func TestOneListSharder_HasEmpty(t *testing.T) { - t.Parallel() - - list := make([]peer.ID, 0) - lnks := &oneListSharder{} - - assert.False(t, lnks.Has("pid4", list)) -} - -func TestOneListSharder_HasFound(t *testing.T) { - t.Parallel() - - list := []peer.ID{"pid1", "pid2", "pid3"} - lnks := &oneListSharder{} - - assert.True(t, lnks.Has("pid2", list)) -} - -func TestOneListSharder_SetPeerShardResolverShouldNotPanic(t *testing.T) { - t.Parallel() - - defer func() { - r := recover() - if r != nil { - assert.Fail(t, "should not have panicked") - } - }() - - ols, _ := NewOneListSharder( - "", - minAllowedConnectedPeersOneSharder, - ) - - err := ols.SetPeerShardResolver(nil) - - assert.Nil(t, err) -} diff --git a/p2p/libp2p/networksharding/sorting/peerDistances.go b/p2p/libp2p/networksharding/sorting/peerDistances.go deleted file mode 100644 index 87cd6e25923..00000000000 --- a/p2p/libp2p/networksharding/sorting/peerDistances.go +++ 
/dev/null @@ -1,32 +0,0 @@ -package sorting - -import ( - "math/big" - - "github.com/libp2p/go-libp2p-core/peer" -) - -// PeerDistance is a composite struct on top of a peer ID that also contains the kad distance measured -// against the current peer and held as a big.Int -type PeerDistance struct { - peer.ID - Distance *big.Int -} - -// PeerDistances represents a sortable peerDistance slice -type PeerDistances []*PeerDistance - -// Len returns the length of this slice -func (pd PeerDistances) Len() int { - return len(pd) -} - -// Less is used in sorting and reports whether the i-th element is less than the j-th element -func (pd PeerDistances) Less(i, j int) bool { - return pd[i].Distance.Cmp(pd[j].Distance) < 0 -} - -// Swap is used in sorting and swaps the values found on the i-th and j-th positions -func (pd PeerDistances) Swap(i, j int) { - pd[i], pd[j] = pd[j], pd[i] -} diff --git a/p2p/libp2p/networksharding/sorting/peerDistances_test.go b/p2p/libp2p/networksharding/sorting/peerDistances_test.go deleted file mode 100644 index a495cb6d293..00000000000 --- a/p2p/libp2p/networksharding/sorting/peerDistances_test.go +++ /dev/null @@ -1,38 +0,0 @@ -package sorting - -import ( - "fmt" - "math/big" - "sort" - "testing" - - "github.com/libp2p/go-libp2p-core/peer" - "github.com/stretchr/testify/assert" -) - -func createPeerDistance(distance int) *PeerDistance { - return &PeerDistance{ - ID: peer.ID(fmt.Sprintf("pid_%d", distance)), - Distance: big.NewInt(int64(distance)), - } -} - -func TestPeerDistances_Sort(t *testing.T) { - t.Parallel() - - pid4 := createPeerDistance(4) - pid0 := createPeerDistance(0) - pid100 := createPeerDistance(100) - pid1 := createPeerDistance(1) - pid2 := createPeerDistance(2) - - pids := PeerDistances{pid4, pid0, pid100, pid1, pid2} - sort.Sort(pids) - - assert.Equal(t, pid0, pids[0]) - assert.Equal(t, pid1, pids[1]) - assert.Equal(t, pid2, pids[2]) - assert.Equal(t, pid4, pids[3]) - assert.Equal(t, pid100, pids[4]) - assert.Equal(t, 5, len(pids)) -} diff --git a/p2p/libp2p/networksharding/sorting/sortedList.go b/p2p/libp2p/networksharding/sorting/sortedList.go deleted file mode 100644 index 13ab9da115a..00000000000 --- a/p2p/libp2p/networksharding/sorting/sortedList.go +++ /dev/null @@ -1,49 +0,0 @@ -package sorting - -import ( - "math/big" - "sort" - - "github.com/libp2p/go-libp2p-core/peer" -) - -// SortedID contains the peer data -type SortedID struct { - ID peer.ID - Key []byte - Shard uint32 - Distance *big.Int -} - -// SortedList holds a list of elements kept sorted with respect to the reference value -type SortedList struct { - Ref SortedID - Peers []SortedID -} - -// Len is the number of elements in the collection. -func (sl *SortedList) Len() int { - return len(sl.Peers) -} - -// Less reports whether the element with -// index i should sort before the element with index j. -func (sl *SortedList) Less(i int, j int) bool { - return sl.Peers[i].Distance.Cmp(sl.Peers[j].Distance) < 0 -} - -// Swap swaps the elements with indexes i and j. 
-func (sl *SortedList) Swap(i int, j int) { - sl.Peers[i], sl.Peers[j] = sl.Peers[j], sl.Peers[i] -} - -// SortedPeers gets the sorted list of peers -func (sl *SortedList) SortedPeers() []peer.ID { - sort.Sort(sl) - ret := make([]peer.ID, len(sl.Peers)) - - for i, id := range sl.Peers { - ret[i] = id.ID - } - return ret -} diff --git a/p2p/libp2p/peersOnChannel.go b/p2p/libp2p/peersOnChannel.go deleted file mode 100644 index 01ae7be96b3..00000000000 --- a/p2p/libp2p/peersOnChannel.go +++ /dev/null @@ -1,142 +0,0 @@ -package libp2p - -import ( - "context" - "sync" - "time" - - "github.com/ElrondNetwork/elrond-go-core/core" - "github.com/ElrondNetwork/elrond-go-core/core/check" - "github.com/ElrondNetwork/elrond-go/p2p" - "github.com/libp2p/go-libp2p-core/peer" -) - -// peersOnChannel manages peers on topics -// it buffers the data and refreshes the peers list continuously (in refreshInterval intervals) -type peersOnChannel struct { - mutPeers sync.RWMutex - peersRatingHandler p2p.PeersRatingHandler - peers map[string][]core.PeerID - lastUpdated map[string]time.Time - - refreshInterval time.Duration - ttlInterval time.Duration - fetchPeersHandler func(topic string) []peer.ID - getTimeHandler func() time.Time - cancelFunc context.CancelFunc -} - -// newPeersOnChannel returns a new peersOnChannel object -func newPeersOnChannel( - peersRatingHandler p2p.PeersRatingHandler, - fetchPeersHandler func(topic string) []peer.ID, - refreshInterval time.Duration, - ttlInterval time.Duration, -) (*peersOnChannel, error) { - - if check.IfNil(peersRatingHandler) { - return nil, p2p.ErrNilPeersRatingHandler - } - if fetchPeersHandler == nil { - return nil, p2p.ErrNilFetchPeersOnTopicHandler - } - if refreshInterval == 0 { - return nil, p2p.ErrInvalidDurationProvided - } - if ttlInterval == 0 { - return nil, p2p.ErrInvalidDurationProvided - } - - ctx, cancelFunc := context.WithCancel(context.Background()) - - poc := &peersOnChannel{ - peersRatingHandler: peersRatingHandler, - peers: make(map[string][]core.PeerID), - lastUpdated: make(map[string]time.Time), - refreshInterval: refreshInterval, - ttlInterval: ttlInterval, - fetchPeersHandler: fetchPeersHandler, - cancelFunc: cancelFunc, - } - poc.getTimeHandler = poc.clockTime - - go poc.refreshPeersOnAllKnownTopics(ctx) - - return poc, nil -} - -func (poc *peersOnChannel) clockTime() time.Time { - return time.Now() -} - -// ConnectedPeersOnChannel returns the known peers on a topic -// if the list was not initialized, it will trigger a manual fetch -func (poc *peersOnChannel) ConnectedPeersOnChannel(topic string) []core.PeerID { - poc.mutPeers.RLock() - peers := poc.peers[topic] - poc.mutPeers.RUnlock() - - if peers != nil { - return peers - } - - return poc.refreshPeersOnTopic(topic) -} - -// updateConnectedPeersOnTopic updates the connected peers on a topic and the last update timestamp -func (poc *peersOnChannel) updateConnectedPeersOnTopic(topic string, connectedPeers []core.PeerID) { - poc.mutPeers.Lock() - poc.peers[topic] = connectedPeers - poc.lastUpdated[topic] = poc.getTimeHandler() - poc.mutPeers.Unlock() -} -
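[Editor's note] A minimal usage sketch (not part of this diff) of how the component above is wired. newPeersOnChannel is unexported, so this would live in package libp2p, and it assumes the time and peer imports of the surrounding files plus the testscommon/p2pmocks stub used by the tests below; the fetch handler and intervals are illustrative only.

// illustrative, in-package sketch
func examplePeersOnChannelWiring() error {
	poc, err := newPeersOnChannel(
		&p2pmocks.PeersRatingHandlerStub{},          // rating handler, must be non-nil
		func(topic string) []peer.ID { return nil }, // fetch handler, e.g. backed by pubsub
		time.Second,    // refresh loop tick
		10*time.Second, // TTL after which a topic's cached peer list is refetched
	)
	if err != nil {
		return err
	}
	defer func() {
		_ = poc.Close() // stops the refresh go routine via the saved cancelFunc
	}()

	// a cache miss triggers a manual fetch; later calls are served from the
	// cache until the background loop refreshes it
	_ = poc.ConnectedPeersOnChannel("test_topic")

	return nil
}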
-// refreshPeersOnAllKnownTopics iterates each topic, fetching its last timestamp -// if the timestamp + ttlInterval < time.Now, it will trigger a fetch of connected peers on topic -func (poc *peersOnChannel) refreshPeersOnAllKnownTopics(ctx context.Context) { - for { - select { - case <-ctx.Done(): - log.Debug("refreshPeersOnAllKnownTopics's go routine is stopping...") - return - case <-time.After(poc.refreshInterval): - } - - listTopicsToBeRefreshed := make([]string, 0) - - // build required topic list - poc.mutPeers.RLock() - for topic, lastRefreshed := range poc.lastUpdated { - needsToBeRefreshed := poc.getTimeHandler().Sub(lastRefreshed) > poc.ttlInterval - if needsToBeRefreshed { - listTopicsToBeRefreshed = append(listTopicsToBeRefreshed, topic) - } - } - poc.mutPeers.RUnlock() - - for _, topic := range listTopicsToBeRefreshed { - _ = poc.refreshPeersOnTopic(topic) - } - } -} - -// refreshPeersOnTopic fetches the connected peers on a topic and updates the cached list -func (poc *peersOnChannel) refreshPeersOnTopic(topic string) []core.PeerID { - list := poc.fetchPeersHandler(topic) - connectedPeers := make([]core.PeerID, len(list)) - for i, pid := range list { - peerID := core.PeerID(pid) - connectedPeers[i] = peerID - poc.peersRatingHandler.AddPeer(peerID) - } - - poc.updateConnectedPeersOnTopic(topic, connectedPeers) - return connectedPeers -} - -// Close closes all underlying components -func (poc *peersOnChannel) Close() error { - poc.cancelFunc() - - return nil -} diff --git a/p2p/libp2p/peersOnChannel_test.go b/p2p/libp2p/peersOnChannel_test.go deleted file mode 100644 index b37be93e048..00000000000 --- a/p2p/libp2p/peersOnChannel_test.go +++ /dev/null @@ -1,174 +0,0 @@ -package libp2p - -import ( - "sync/atomic" - "testing" - "time" - - "github.com/ElrondNetwork/elrond-go-core/core" - coreAtomic "github.com/ElrondNetwork/elrond-go-core/core/atomic" - "github.com/ElrondNetwork/elrond-go/p2p" - "github.com/ElrondNetwork/elrond-go/testscommon/p2pmocks" - "github.com/libp2p/go-libp2p-core/peer" - "github.com/stretchr/testify/assert" -) - -func TestNewPeersOnChannel_NilPeersRatingHandlerShouldErr(t *testing.T) { - t.Parallel() - - poc, err := newPeersOnChannel(nil, nil, 1, 1) - - assert.Nil(t, poc) - assert.Equal(t, p2p.ErrNilPeersRatingHandler, err) -} - -func TestNewPeersOnChannel_NilFetchPeersHandlerShouldErr(t *testing.T) { - t.Parallel() - - poc, err := newPeersOnChannel(&p2pmocks.PeersRatingHandlerStub{}, nil, 1, 1) - - assert.Nil(t, poc) - assert.Equal(t, p2p.ErrNilFetchPeersOnTopicHandler, err) -} - -func TestNewPeersOnChannel_InvalidRefreshIntervalShouldErr(t *testing.T) { - t.Parallel() - - poc, err := newPeersOnChannel( - &p2pmocks.PeersRatingHandlerStub{}, - func(topic string) []peer.ID { - return nil - }, - 0, - 1) - - assert.Nil(t, poc) - assert.Equal(t, p2p.ErrInvalidDurationProvided, err) -} - -func TestNewPeersOnChannel_InvalidTTLIntervalShouldErr(t *testing.T) { - t.Parallel() - - poc, err := newPeersOnChannel( - &p2pmocks.PeersRatingHandlerStub{}, - func(topic string) []peer.ID { - return nil - }, - 1, - 0) - - assert.Nil(t, poc) - assert.Equal(t, p2p.ErrInvalidDurationProvided, err) -} - -func TestNewPeersOnChannel_OkValsShouldWork(t *testing.T) { - t.Parallel() - - poc, err := newPeersOnChannel( - &p2pmocks.PeersRatingHandlerStub{}, - func(topic string) []peer.ID { - return nil - }, - 1, - 1) - - assert.NotNil(t, poc) - assert.Nil(t, err) -} - -func TestPeersOnChannel_ConnectedPeersOnChannelMissingTopicShouldTriggerFetchAndReturn(t *testing.T) { - t.Parallel() - - retPeerIDs := []peer.ID{"peer1", "peer2"} - testTopic := "test_topic" - wasFetchCalled := atomic.Value{} - wasFetchCalled.Store(false) - - poc, _ := newPeersOnChannel( - &p2pmocks.PeersRatingHandlerStub{}, - func(topic string) []peer.ID { - if topic == testTopic { - wasFetchCalled.Store(true) - return retPeerIDs - } - return nil - }, - time.Second, - time.Second, - ) - - peers := poc.ConnectedPeersOnChannel(testTopic) - - assert.True(t, wasFetchCalled.Load().(bool)) - for idx, pid := range 
retPeerIDs { - assert.Equal(t, []byte(pid), peers[idx].Bytes()) - } -} - -func TestPeersOnChannel_ConnectedPeersOnChannelFindTopicShouldReturn(t *testing.T) { - t.Parallel() - - retPeerIDs := []core.PeerID{"peer1", "peer2"} - testTopic := "test_topic" - wasFetchCalled := atomic.Value{} - wasFetchCalled.Store(false) - - poc, _ := newPeersOnChannel( - &p2pmocks.PeersRatingHandlerStub{}, - func(topic string) []peer.ID { - wasFetchCalled.Store(true) - return nil - }, - time.Second, - time.Second, - ) - // manually put peers - poc.mutPeers.Lock() - poc.peers[testTopic] = retPeerIDs - poc.mutPeers.Unlock() - - peers := poc.ConnectedPeersOnChannel(testTopic) - - assert.False(t, wasFetchCalled.Load().(bool)) - for idx, pid := range retPeerIDs { - assert.Equal(t, []byte(pid), peers[idx].Bytes()) - } -} - -func TestPeersOnChannel_RefreshShouldBeDone(t *testing.T) { - t.Parallel() - - retPeerIDs := []core.PeerID{"peer1", "peer2"} - testTopic := "test_topic" - wasFetchCalled := coreAtomic.Flag{} - wasFetchCalled.Reset() - - refreshInterval := time.Millisecond * 100 - ttlInterval := time.Duration(2) - - poc, _ := newPeersOnChannel( - &p2pmocks.PeersRatingHandlerStub{}, - func(topic string) []peer.ID { - wasFetchCalled.SetValue(true) - return nil - }, - refreshInterval, - ttlInterval, - ) - poc.getTimeHandler = func() time.Time { - return time.Unix(0, 4) - } - // manually put peers - poc.mutPeers.Lock() - poc.peers[testTopic] = retPeerIDs - poc.lastUpdated[testTopic] = time.Unix(0, 1) - poc.mutPeers.Unlock() - - // wait for the go routine cycle to finish - time.Sleep(time.Second) - - assert.True(t, wasFetchCalled.IsSet()) - poc.mutPeers.Lock() - assert.Empty(t, poc.peers[testTopic]) - poc.mutPeers.Unlock() -} diff --git a/p2p/libp2p/ports.go b/p2p/libp2p/ports.go deleted file mode 100644 index db97b364e70..00000000000 --- a/p2p/libp2p/ports.go +++ /dev/null @@ -1,88 +0,0 @@ -package libp2p - -import ( - "fmt" - "net" - "strconv" - "strings" - - "github.com/ElrondNetwork/elrond-go-core/core/random" - "github.com/ElrondNetwork/elrond-go/p2p" -) - -func getPort(port string, handler func(int) error) (int, error) { - val, err := strconv.Atoi(port) - if err == nil { - if val < 0 { - return 0, fmt.Errorf("%w, %d does not represent a positive value for port", p2p.ErrInvalidPortValue, val) - } - - return val, nil - } - - ports := strings.Split(port, "-") - if len(ports) != 2 { - return 0, fmt.Errorf("%w, provided port string `%s` is not in the correct format, expected `start-end`", p2p.ErrInvalidPortsRangeString, port) - } - - startPort, err := strconv.Atoi(ports[0]) - if err != nil { - return 0, p2p.ErrInvalidStartingPortValue - } - - endPort, err := strconv.Atoi(ports[1]) - if err != nil { - return 0, p2p.ErrInvalidEndingPortValue - } - - if startPort < minRangePortValue { - return 0, fmt.Errorf("%w, provided starting port should be >= %d", p2p.ErrInvalidValue, minRangePortValue) - } - if endPort < startPort { - return 0, p2p.ErrEndPortIsSmallerThanStartPort - } - - return choosePort(startPort, endPort, handler) -} - -func choosePort(startPort int, endPort int, handler func(int) error) (int, error) { - log.Debug("generating random free port", - "range", fmt.Sprintf("%d-%d", startPort, endPort), - ) - - ports := make([]int, 0, endPort-startPort+1) - for i := startPort; i <= endPort; i++ { - ports = append(ports, i) - } - - ports = random.FisherYatesShuffle(ports, &random.ConcurrentSafeIntRandomizer{}) - for _, p := range ports { - err := handler(p) - if err != nil { - log.Trace("opening port error", - "port", 
p, "error", err) - continue - } - - log.Debug("free port chosen", "port", p) - return p, nil - } - - return 0, fmt.Errorf("%w, range %d-%d", p2p.ErrNoFreePortInRange, startPort, endPort) -} - -func checkFreePort(port int) error { - addr, err := net.ResolveTCPAddr("tcp", fmt.Sprintf("localhost:%d", port)) - if err != nil { - return err - } - - l, err := net.ListenTCP("tcp", addr) - if err != nil { - return err - } - - _ = l.Close() - - return nil -} diff --git a/p2p/libp2p/ports_test.go b/p2p/libp2p/ports_test.go deleted file mode 100644 index 8c2bb229cd7..00000000000 --- a/p2p/libp2p/ports_test.go +++ /dev/null @@ -1,143 +0,0 @@ -package libp2p - -import ( - "errors" - "fmt" - "net" - "testing" - - "github.com/ElrondNetwork/elrond-go/p2p" - "github.com/stretchr/testify/assert" -) - -func TestGetPort_InvalidStringShouldErr(t *testing.T) { - t.Parallel() - - port, err := getPort("NaN", checkFreePort) - - assert.Equal(t, 0, port) - assert.True(t, errors.Is(err, p2p.ErrInvalidPortsRangeString)) -} - -func TestGetPort_InvalidPortNumberShouldErr(t *testing.T) { - t.Parallel() - - port, err := getPort("-1", checkFreePort) - assert.Equal(t, 0, port) - assert.True(t, errors.Is(err, p2p.ErrInvalidPortValue)) -} - -func TestGetPort_SinglePortShouldWork(t *testing.T) { - t.Parallel() - - port, err := getPort("0", checkFreePort) - assert.Equal(t, 0, port) - assert.Nil(t, err) - - p := 3638 - port, err = getPort(fmt.Sprintf("%d", p), checkFreePort) - assert.Equal(t, p, port) - assert.Nil(t, err) -} - -func TestCheckFreePort_InvalidStartingPortShouldErr(t *testing.T) { - t.Parallel() - - port, err := getPort("NaN-10000", checkFreePort) - assert.Equal(t, 0, port) - assert.Equal(t, p2p.ErrInvalidStartingPortValue, err) - - port, err = getPort("1024-10000", checkFreePort) - assert.Equal(t, 0, port) - assert.True(t, errors.Is(err, p2p.ErrInvalidValue)) -} - -func TestCheckFreePort_InvalidEndingPortShouldErr(t *testing.T) { - t.Parallel() - - port, err := getPort("10000-NaN", checkFreePort) - assert.Equal(t, 0, port) - assert.Equal(t, p2p.ErrInvalidEndingPortValue, err) -} - -func TestGetPort_EndPortLargerThanSendPort(t *testing.T) { - t.Parallel() - - port, err := getPort("10000-9999", checkFreePort) - assert.Equal(t, 0, port) - assert.Equal(t, p2p.ErrEndPortIsSmallerThanStartPort, err) -} - -func TestGetPort_RangeOfOneShouldWork(t *testing.T) { - t.Parallel() - - port := 5000 - numCall := 0 - handler := func(p int) error { - if p != port { - assert.Fail(t, fmt.Sprintf("should have been %d", port)) - } - numCall++ - return nil - } - - result, err := getPort(fmt.Sprintf("%d-%d", port, port), handler) - assert.Nil(t, err) - assert.Equal(t, port, result) -} - -func TestGetPort_RangeOccupiedShouldErrorShouldWork(t *testing.T) { - t.Parallel() - - portStart := 5000 - portEnd := 10000 - portsTried := make(map[int]struct{}) - expectedErr := errors.New("expected error") - handler := func(p int) error { - portsTried[p] = struct{}{} - return expectedErr - } - - result, err := getPort(fmt.Sprintf("%d-%d", portStart, portEnd), handler) - - assert.True(t, errors.Is(err, p2p.ErrNoFreePortInRange)) - assert.Equal(t, portEnd-portStart+1, len(portsTried)) - assert.Equal(t, 0, result) -} - -func TestCheckFreePort_PortZeroAlwaysWorks(t *testing.T) { - err := checkFreePort(0) - - assert.Nil(t, err) -} - -func TestCheckFreePort_InvalidPortShouldErr(t *testing.T) { - err := checkFreePort(-1) - - assert.NotNil(t, err) -} - -func TestCheckFreePort_OccupiedPortShouldErr(t *testing.T) { - // 1. 
get a free port from OS, open a TCP listener - 2. get the allocated port - 3. test if that port is occupied - addr, err := net.ResolveTCPAddr("tcp", "localhost:0") - if err != nil { - assert.Fail(t, err.Error()) - return - } - - l, err := net.ListenTCP("tcp", addr) - if err != nil { - assert.Fail(t, err.Error()) - return - } - - port := l.Addr().(*net.TCPAddr).Port - - fmt.Printf("testing port %d\n", port) - err = checkFreePort(port) - assert.NotNil(t, err) - - _ = l.Close() -} diff --git a/p2p/libp2p/rand/factory/randFactory.go b/p2p/libp2p/rand/factory/randFactory.go deleted file mode 100644 index 2cb07565f90..00000000000 --- a/p2p/libp2p/rand/factory/randFactory.go +++ /dev/null @@ -1,17 +0,0 @@ -package factory - -import ( - cryptoRand "crypto/rand" - "io" - - "github.com/ElrondNetwork/elrond-go/p2p/libp2p/rand" -) - -// NewRandFactory will create a reader based on the provided seed string -func NewRandFactory(seed string) (io.Reader, error) { - if len(seed) == 0 { - return cryptoRand.Reader, nil - } - - return rand.NewSeedRandReader([]byte(seed)) -} diff --git a/p2p/libp2p/rand/factory/randFactory_test.go b/p2p/libp2p/rand/factory/randFactory_test.go deleted file mode 100644 index fefcb83135b..00000000000 --- a/p2p/libp2p/rand/factory/randFactory_test.go +++ /dev/null @@ -1,32 +0,0 @@ -package factory_test - -import ( - "crypto/rand" - "reflect" - "testing" - - rand2 "github.com/ElrondNetwork/elrond-go/p2p/libp2p/rand" - "github.com/ElrondNetwork/elrond-go/p2p/libp2p/rand/factory" - "github.com/stretchr/testify/assert" -) - -func TestNewRandFactory_EmptySeedShouldReturnCryptoRand(t *testing.T) { - t.Parallel() - - r, err := factory.NewRandFactory("") - - assert.Nil(t, err) - assert.True(t, r == rand.Reader) -} - -func TestNewRandFactory_NotEmptySeedShouldSeedRandReader(t *testing.T) { - t.Parallel() - - seed := "seed" - srrExpected, _ := rand2.NewSeedRandReader([]byte(seed)) - - r, err := factory.NewRandFactory(seed) - - assert.Nil(t, err) - assert.Equal(t, reflect.TypeOf(r), reflect.TypeOf(srrExpected)) -} diff --git a/p2p/libp2p/rand/seedRandReader.go b/p2p/libp2p/rand/seedRandReader.go deleted file mode 100644 index a6edcaee4de..00000000000 --- a/p2p/libp2p/rand/seedRandReader.go +++ /dev/null @@ -1,40 +0,0 @@ -package rand - -import ( - "crypto/sha256" - "encoding/binary" - "math/rand" - - "github.com/ElrondNetwork/elrond-go/p2p" -) - -type seedRandReader struct { - seedNumber int64 -} - -// NewSeedRandReader will return a new instance of a seed-based reader -// This is mostly used to generate predictable seeder addresses so other peers can connect to them -func NewSeedRandReader(seed []byte) (*seedRandReader, error) { - if len(seed) == 0 { - return nil, p2p.ErrEmptySeed - } - - seedHash := sha256.Sum256(seed) - seedNumber := binary.BigEndian.Uint64(seedHash[:]) - - return &seedRandReader{ - seedNumber: int64(seedNumber), - }, nil -} -
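[Editor's note] A minimal, self-contained sketch (not part of this diff) of what the seed-based reader above enables: feeding NewRandFactory a fixed seed yields a deterministic byte stream, so a seeder node restarted with the same seed derives the same libp2p key and peer ID. The libp2p calls mirror the go-libp2p-core API used elsewhere in the repository; treat the exact wiring as illustrative.

package main

import (
	"fmt"

	"github.com/ElrondNetwork/elrond-go/p2p/libp2p/rand/factory"
	libp2pCrypto "github.com/libp2p/go-libp2p-core/crypto"
	"github.com/libp2p/go-libp2p-core/peer"
)

func main() {
	// non-empty seed -> seedRandReader; empty seed -> crypto/rand.Reader
	reader, err := factory.NewRandFactory("my-seeder-seed")
	if err != nil {
		panic(err)
	}

	privateKey, _, err := libp2pCrypto.GenerateSecp256k1Key(reader)
	if err != nil {
		panic(err)
	}

	pid, err := peer.IDFromPrivateKey(privateKey)
	if err != nil {
		panic(err)
	}

	fmt.Println(pid.Pretty()) // stable across restarts for a fixed seed
}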
-// Read will read up to len(p) bytes. It will rotate the existing byte buffer (seed) until it fills up the provided -// p buffer -func (srr *seedRandReader) Read(p []byte) (n int, err error) { - if len(p) == 0 { - return 0, p2p.ErrEmptyBuffer - } - - randomizer := rand.New(rand.NewSource(srr.seedNumber)) - - return randomizer.Read(p) -} diff --git a/p2p/libp2p/rand/seedRandReader_test.go b/p2p/libp2p/rand/seedRandReader_test.go deleted file mode 100644 index d6dd249e8d2..00000000000 --- a/p2p/libp2p/rand/seedRandReader_test.go +++ /dev/null @@ -1,73 +0,0 @@ -package rand_test - -import ( - "testing" - - "github.com/ElrondNetwork/elrond-go/p2p" - "github.com/ElrondNetwork/elrond-go/p2p/libp2p/rand" - "github.com/stretchr/testify/assert" -) - -func TestNewSeedRandReader_NilSeedShouldErr(t *testing.T) { - t.Parallel() - - srr, err := rand.NewSeedRandReader(nil) - - assert.Nil(t, srr) - assert.Equal(t, p2p.ErrEmptySeed, err) -} - -func TestNewSeedRandReader_ShouldWork(t *testing.T) { - t.Parallel() - - seed := []byte("seed") - srr, err := rand.NewSeedRandReader(seed) - - assert.NotNil(t, srr) - assert.Nil(t, err) -} - -func TestSeedRandReader_ReadNilBufferShouldErr(t *testing.T) { - t.Parallel() - - seed := []byte("seed") - srr, _ := rand.NewSeedRandReader(seed) - - n, err := srr.Read(nil) - - assert.Equal(t, 0, n) - assert.Equal(t, err, p2p.ErrEmptyBuffer) -} - -func TestSeedRandReader_ReadShouldWork(t *testing.T) { - t.Parallel() - - seed := []byte("seed") - srr, _ := rand.NewSeedRandReader(seed) - - testTbl := []struct { - pSize int - p []byte - n int - err error - name string - }{ - {pSize: 1, p: []byte{15}, n: 1, err: nil, name: "1 character"}, - {pSize: 2, p: []byte{15, 210}, n: 2, err: nil, name: "2 characters"}, - {pSize: 4, p: []byte{15, 210, 236, 97}, n: 4, err: nil, name: "4 characters"}, - {pSize: 5, p: []byte{15, 210, 236, 97, 112}, n: 5, err: nil, name: "5 characters"}, - {pSize: 8, p: []byte{15, 210, 236, 97, 112, 165, 91, 186}, n: 8, err: nil, name: "8 characters"}, - } - - for _, tc := range testTbl { - t.Run(tc.name, func(t *testing.T) { - p := make([]byte, tc.pSize) - - n, err := srr.Read(p) - - assert.Equal(t, tc.p, p) - assert.Equal(t, tc.n, n) - assert.Equal(t, tc.err, err) - }) - } -} diff --git a/p2p/libp2p/topicProcessors.go b/p2p/libp2p/topicProcessors.go deleted file mode 100644 index 7abc14741ae..00000000000 --- a/p2p/libp2p/topicProcessors.go +++ /dev/null @@ -1,68 +0,0 @@ -package libp2p - -import ( - "fmt" - "sync" - - "github.com/ElrondNetwork/elrond-go/p2p" -) - -type topicProcessors struct { - processors map[string]p2p.MessageProcessor - mutProcessors sync.RWMutex -} - -func newTopicProcessors() *topicProcessors { - return &topicProcessors{ - processors: make(map[string]p2p.MessageProcessor), - } -} - -func (tp *topicProcessors) addTopicProcessor(identifier string, processor p2p.MessageProcessor) error { - tp.mutProcessors.Lock() - defer tp.mutProcessors.Unlock() - - _, alreadyExists := tp.processors[identifier] - if alreadyExists { - return fmt.Errorf("%w, in addTopicProcessor, identifier %s", - p2p.ErrMessageProcessorAlreadyDefined, - identifier, - ) - } - - tp.processors[identifier] = processor - - return nil -} - -func (tp *topicProcessors) removeTopicProcessor(identifier string) error { - tp.mutProcessors.Lock() - defer tp.mutProcessors.Unlock() - - _, alreadyExists := tp.processors[identifier] - if !alreadyExists { - return fmt.Errorf("%w, in removeTopicProcessor, identifier %s", - p2p.ErrMessageProcessorDoesNotExists, - identifier, - ) - } - - delete(tp.processors, identifier) - - 
return nil -} - -func (tp *topicProcessors) getList() ([]string, []p2p.MessageProcessor) { - tp.mutProcessors.RLock() - defer tp.mutProcessors.RUnlock() - - list := make([]p2p.MessageProcessor, 0, len(tp.processors)) - identifiers := make([]string, 0, len(tp.processors)) - - for identifier, handler := range tp.processors { - list = append(list, handler) - identifiers = append(identifiers, identifier) - } - - return identifiers, list -} diff --git a/p2p/libp2p/topicProcessors_test.go b/p2p/libp2p/topicProcessors_test.go deleted file mode 100644 index 8510328d012..00000000000 --- a/p2p/libp2p/topicProcessors_test.go +++ /dev/null @@ -1,131 +0,0 @@ -package libp2p - -import ( - "errors" - "testing" - - "github.com/ElrondNetwork/elrond-go-core/core" - "github.com/ElrondNetwork/elrond-go/p2p" - "github.com/ElrondNetwork/elrond-go/p2p/mock" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestNewTopicProcessors(t *testing.T) { - t.Parallel() - - tp := newTopicProcessors() - - assert.NotNil(t, tp) -} - -func TestTopicProcessorsAddShouldWork(t *testing.T) { - t.Parallel() - - tp := newTopicProcessors() - - identifier := "identifier" - proc := &mock.MessageProcessorStub{} - err := tp.addTopicProcessor(identifier, proc) - - assert.Nil(t, err) - require.Equal(t, 1, len(tp.processors)) - assert.True(t, proc == tp.processors[identifier]) // pointer testing -} - -func TestTopicProcessorsDoubleAddShouldErr(t *testing.T) { - t.Parallel() - - tp := newTopicProcessors() - - identifier := "identifier" - _ = tp.addTopicProcessor(identifier, &mock.MessageProcessorStub{}) - err := tp.addTopicProcessor(identifier, &mock.MessageProcessorStub{}) - - assert.True(t, errors.Is(err, p2p.ErrMessageProcessorAlreadyDefined)) - require.Equal(t, 1, len(tp.processors)) -} - -func TestTopicProcessorsRemoveInexistentShouldErr(t *testing.T) { - t.Parallel() - - tp := newTopicProcessors() - - identifier := "identifier" - err := tp.removeTopicProcessor(identifier) - - assert.True(t, errors.Is(err, p2p.ErrMessageProcessorDoesNotExists)) -} - -func TestTopicProcessorsRemoveShouldWork(t *testing.T) { - t.Parallel() - - tp := newTopicProcessors() - - identifier1 := "identifier1" - identifier2 := "identifier2" - _ = tp.addTopicProcessor(identifier1, &mock.MessageProcessorStub{}) - _ = tp.addTopicProcessor(identifier2, &mock.MessageProcessorStub{}) - - require.Equal(t, 2, len(tp.processors)) - - err := tp.removeTopicProcessor(identifier2) - - assert.Nil(t, err) - require.Equal(t, 1, len(tp.processors)) - - err = tp.removeTopicProcessor(identifier1) - - assert.Nil(t, err) - require.Equal(t, 0, len(tp.processors)) -} - -func TestTopicProcessorsGetListShouldWorkAndPreserveOrder(t *testing.T) { - t.Parallel() - - tp := newTopicProcessors() - - identifier1 := "identifier1" - identifier2 := "identifier2" - identifier3 := "identifier3" - handler1 := &mock.MessageProcessorStub{ - ProcessMessageCalled: func(message p2p.MessageP2P, fromConnectedPeer core.PeerID) error { - return nil - }, - } - handler2 := &mock.MessageProcessorStub{ - ProcessMessageCalled: func(message p2p.MessageP2P, fromConnectedPeer core.PeerID) error { - return nil - }, - } - handler3 := &mock.MessageProcessorStub{ - ProcessMessageCalled: func(message p2p.MessageP2P, fromConnectedPeer core.PeerID) error { - return nil - }, - } - - _ = tp.addTopicProcessor(identifier3, handler3) - _ = tp.addTopicProcessor(identifier1, handler1) - _ = tp.addTopicProcessor(identifier2, handler2) - - require.Equal(t, 3, len(tp.processors)) - - 
identifiers, handlers := tp.getList() - assert.ElementsMatch(t, identifiers, []string{identifier1, identifier2, identifier3}) - assert.ElementsMatch(t, handlers, []p2p.MessageProcessor{handler1, handler2, handler3}) - - _ = tp.removeTopicProcessor(identifier1) - identifiers, handlers = tp.getList() - assert.ElementsMatch(t, identifiers, []string{identifier2, identifier3}) - assert.ElementsMatch(t, handlers, []p2p.MessageProcessor{handler2, handler3}) - - _ = tp.removeTopicProcessor(identifier2) - identifiers, handlers = tp.getList() - assert.Equal(t, identifiers, []string{identifier3}) - assert.Equal(t, handlers, []p2p.MessageProcessor{handler3}) - - _ = tp.removeTopicProcessor(identifier3) - identifiers, handlers = tp.getList() - assert.Equal(t, identifiers, make([]string, 0)) - assert.Equal(t, handlers, make([]p2p.MessageProcessor, 0)) -} diff --git a/p2p/libp2p/unknownPeerShardResolver.go b/p2p/libp2p/unknownPeerShardResolver.go deleted file mode 100644 index 55af89c06b4..00000000000 --- a/p2p/libp2p/unknownPeerShardResolver.go +++ /dev/null @@ -1,25 +0,0 @@ -package libp2p - -import ( - "github.com/ElrondNetwork/elrond-go-core/core" - "github.com/ElrondNetwork/elrond-go/p2p" -) - -var _ p2p.PeerShardResolver = (*unknownPeerShardResolver)(nil) - -type unknownPeerShardResolver struct { -} - -// GetPeerInfo returns a P2PPeerInfo value holding an unknown peer value -func (upsr *unknownPeerShardResolver) GetPeerInfo(_ core.PeerID) core.P2PPeerInfo { - return core.P2PPeerInfo{ - PeerType: core.UnknownPeer, - PeerSubType: core.RegularPeer, - ShardID: 0, - } -} - -// IsInterfaceNil returns true if there is no value under the interface -func (upsr *unknownPeerShardResolver) IsInterfaceNil() bool { - return upsr == nil -} diff --git a/p2p/libp2p/unknownPeerShardResolver_test.go b/p2p/libp2p/unknownPeerShardResolver_test.go deleted file mode 100644 index d366dacd1ee..00000000000 --- a/p2p/libp2p/unknownPeerShardResolver_test.go +++ /dev/null @@ -1,31 +0,0 @@ -package libp2p - -import ( - "testing" - - "github.com/ElrondNetwork/elrond-go-core/core" - "github.com/ElrondNetwork/elrond-go-core/core/check" - "github.com/stretchr/testify/assert" -) - -func TestUnknownPeerShardResolver_IsInterfaceNil(t *testing.T) { - t.Parallel() - - var upsr *unknownPeerShardResolver - assert.True(t, check.IfNil(upsr)) - - upsr = &unknownPeerShardResolver{} - assert.False(t, check.IfNil(upsr)) -} - -func TestUnknownPeerShardResolver_GetPeerInfoShouldReturnUnknownId(t *testing.T) { - t.Parallel() - - upsr := &unknownPeerShardResolver{} - expectedPeerInfo := core.P2PPeerInfo{ - PeerType: core.UnknownPeer, - ShardID: 0, - } - - assert.Equal(t, expectedPeerInfo, upsr.GetPeerInfo("")) -} diff --git a/p2p/loadBalancer/export_test.go b/p2p/loadBalancer/export_test.go deleted file mode 100644 index ff21d7bb79e..00000000000 --- a/p2p/loadBalancer/export_test.go +++ /dev/null @@ -1,21 +0,0 @@ -package loadBalancer - -import ( - "github.com/ElrondNetwork/elrond-go/p2p" -) - -func (oplb *OutgoingChannelLoadBalancer) Chans() []chan *p2p.SendableData { - return oplb.chans -} - -func (oplb *OutgoingChannelLoadBalancer) Names() []string { - return oplb.names -} - -func (oplb *OutgoingChannelLoadBalancer) NamesChans() map[string]chan *p2p.SendableData { - return oplb.namesChans -} - -func DefaultSendChannel() string { - return defaultSendChannel -} diff --git a/p2p/loadBalancer/outgoingChannelLoadBalancer.go b/p2p/loadBalancer/outgoingChannelLoadBalancer.go deleted file mode 100644 index 6ef07f88056..00000000000 --- 
a/p2p/loadBalancer/outgoingChannelLoadBalancer.go +++ /dev/null @@ -1,162 +0,0 @@ -package loadBalancer - -import ( - "context" - "sync" - - logger "github.com/ElrondNetwork/elrond-go-logger" - "github.com/ElrondNetwork/elrond-go/p2p" -) - -var _ p2p.ChannelLoadBalancer = (*OutgoingChannelLoadBalancer)(nil) -var log = logger.GetOrCreate("p2p/loadbalancer") - -const defaultSendChannel = "default send channel" - -// OutgoingChannelLoadBalancer is a component that evenly balances requests to be sent -type OutgoingChannelLoadBalancer struct { - mut sync.RWMutex - chans []chan *p2p.SendableData - mainChan chan *p2p.SendableData - names []string - //namesChans is defined only for performance purposes, to allow fast searching by name -//iteration is done directly on slices as that is used very often and is about 50x -//faster than an iteration over a map - namesChans map[string]chan *p2p.SendableData - cancelFunc context.CancelFunc - ctx context.Context //we need the context saved here in order to call appendChannel from exported func AddChannel -} - -// NewOutgoingChannelLoadBalancer creates a new ChannelLoadBalancer instance -func NewOutgoingChannelLoadBalancer() *OutgoingChannelLoadBalancer { - ctx, cancelFunc := context.WithCancel(context.Background()) - - oclb := &OutgoingChannelLoadBalancer{ - chans: make([]chan *p2p.SendableData, 0), - names: make([]string, 0), - namesChans: make(map[string]chan *p2p.SendableData), - mainChan: make(chan *p2p.SendableData), - cancelFunc: cancelFunc, - ctx: ctx, - } - - oclb.appendChannel(defaultSendChannel) - - return oclb -} - -func (oplb *OutgoingChannelLoadBalancer) appendChannel(channel string) { - oplb.names = append(oplb.names, channel) - ch := make(chan *p2p.SendableData) - oplb.chans = append(oplb.chans, ch) - oplb.namesChans[channel] = ch - - go func() { - for { - var obj *p2p.SendableData - - select { - case obj = <-ch: - case <-oplb.ctx.Done(): - log.Debug("closing OutgoingChannelLoadBalancer's append channel go routine") - return - } - - oplb.mainChan <- obj - } - }() -} - -// AddChannel adds a new channel to the throttler, if it does not exist -func (oplb *OutgoingChannelLoadBalancer) AddChannel(channel string) error { - if channel == defaultSendChannel { - return p2p.ErrChannelCanNotBeReAdded - } - - oplb.mut.Lock() - defer oplb.mut.Unlock() - - for _, name := range oplb.names { - if name == channel { - return nil - } - } - - oplb.appendChannel(channel) - - return nil -} - -// RemoveChannel removes an existing channel from the throttler -func (oplb *OutgoingChannelLoadBalancer) RemoveChannel(channel string) error { - if channel == defaultSendChannel { - return p2p.ErrChannelCanNotBeDeleted - } - - oplb.mut.Lock() - defer oplb.mut.Unlock() - - index := -1 - - for idx, name := range oplb.names { - if name == channel { - index = idx - break - } - } - - if index == -1 { - return p2p.ErrChannelDoesNotExist - } - - sendableChan := oplb.chans[index] - - //remove the index-th element in the chan slice - copy(oplb.chans[index:], oplb.chans[index+1:]) - oplb.chans[len(oplb.chans)-1] = nil - oplb.chans = oplb.chans[:len(oplb.chans)-1] - - //remove the index-th element in the names slice - copy(oplb.names[index:], oplb.names[index+1:]) - oplb.names = oplb.names[:len(oplb.names)-1] - - close(sendableChan) - - delete(oplb.namesChans, channel) - - return nil -} -
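[Editor's note] A minimal, self-contained usage sketch (not part of this diff) of the fan-in design above: every named channel is drained by its own goroutine into mainChan, so a single consumer can service all producers. The channel name is illustrative; the calls are the ones defined in this file.

package main

import (
	"fmt"

	"github.com/ElrondNetwork/elrond-go/p2p"
	"github.com/ElrondNetwork/elrond-go/p2p/loadBalancer"
)

func main() {
	oclb := loadBalancer.NewOutgoingChannelLoadBalancer()
	defer func() { _ = oclb.Close() }()

	_ = oclb.AddChannel("transactions") // hypothetical channel name

	// producer: sends land on the named channel and are funneled into mainChan
	go func() {
		oclb.GetChannelOrDefault("transactions") <- &p2p.SendableData{Topic: "transactions"}
	}()

	// consumer: a single blocking collect call services all named channels
	obj := oclb.CollectOneElementFromChannels()
	fmt.Println(obj.Topic)
}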
-// GetChannelOrDefault fetches the required channel or the default if the channel is not present -func (oplb *OutgoingChannelLoadBalancer) GetChannelOrDefault(channel string) chan *p2p.SendableData { - oplb.mut.RLock() - defer oplb.mut.RUnlock() - - ch := oplb.namesChans[channel] - if ch != nil { - return ch - } - - return oplb.chans[0] -} - -// CollectOneElementFromChannels gets the waiting object from mainChan. It is a blocking call. -func (oplb *OutgoingChannelLoadBalancer) CollectOneElementFromChannels() *p2p.SendableData { - select { - case obj := <-oplb.mainChan: - return obj - case <-oplb.ctx.Done(): - return nil - } -} - -// Close finishes all started go routines in this instance -func (oplb *OutgoingChannelLoadBalancer) Close() error { - oplb.cancelFunc() - return nil -} - -// IsInterfaceNil returns true if there is no value under the interface -func (oplb *OutgoingChannelLoadBalancer) IsInterfaceNil() bool { - return oplb == nil -} diff --git a/p2p/loadBalancer/outgoingChannelLoadBalancer_test.go b/p2p/loadBalancer/outgoingChannelLoadBalancer_test.go deleted file mode 100644 index 1098bd6a7c3..00000000000 --- a/p2p/loadBalancer/outgoingChannelLoadBalancer_test.go +++ /dev/null @@ -1,301 +0,0 @@ -package loadBalancer_test - -import ( - "errors" - "sync" - "testing" - "time" - - "github.com/ElrondNetwork/elrond-go/p2p" - "github.com/ElrondNetwork/elrond-go/p2p/loadBalancer" - "github.com/stretchr/testify/assert" -) - -var errLenDifferent = errors.New("len different for names and chans") -var errLenDifferentNamesChans = errors.New("len different for names and namesChans") -var errMissingChannel = errors.New("missing channel") -var errChannelsMismatch = errors.New("channels mismatch") -var durationWait = time.Second * 2 - -func checkIntegrity(oclb *loadBalancer.OutgoingChannelLoadBalancer, name string) error { - if len(oclb.Names()) != len(oclb.Chans()) { - return errLenDifferent - } - - if len(oclb.Names()) != len(oclb.NamesChans()) { - return errLenDifferentNamesChans - } - - idxFound := -1 - for i, n := range oclb.Names() { - if n == name { - idxFound = i - break - } - } - - if idxFound == -1 && oclb.NamesChans()[name] == nil { - return errMissingChannel - } - - if oclb.NamesChans()[name] != oclb.Chans()[idxFound] { - return errChannelsMismatch - } - - return nil -} - -//------- NewOutgoingChannelLoadBalancer - -func TestNewOutgoingChannelLoadBalancer_ShouldNotProduceNil(t *testing.T) { - t.Parallel() - - oclb := loadBalancer.NewOutgoingChannelLoadBalancer() - - assert.NotNil(t, oclb) -} - -func TestNewOutgoingChannelLoadBalancer_ShouldAddDefaultChannel(t *testing.T) { - t.Parallel() - - oclb := loadBalancer.NewOutgoingChannelLoadBalancer() - - assert.Equal(t, 1, len(oclb.Names())) - assert.Nil(t, checkIntegrity(oclb, loadBalancer.DefaultSendChannel())) -} - -//------- AddChannel - -func TestOutgoingChannelLoadBalancer_AddChannelNewChannelShouldNotErrAndAddNewChannel(t *testing.T) { - t.Parallel() - - oclb := loadBalancer.NewOutgoingChannelLoadBalancer() - - err := oclb.AddChannel("test") - - assert.Nil(t, err) - assert.Equal(t, 2, len(oclb.Names())) - assert.Nil(t, checkIntegrity(oclb, loadBalancer.DefaultSendChannel())) - assert.Nil(t, checkIntegrity(oclb, "test")) -} - -func TestOutgoingChannelLoadBalancer_AddChannelDefaultChannelShouldErr(t *testing.T) { - t.Parallel() - - oclb := loadBalancer.NewOutgoingChannelLoadBalancer() - - err := oclb.AddChannel(loadBalancer.DefaultSendChannel()) - - assert.Equal(t, p2p.ErrChannelCanNotBeReAdded, err) -} - -func TestOutgoingChannelLoadBalancer_AddChannelReAddChannelShouldDoNothing(t *testing.T) { - t.Parallel() - - oclb := loadBalancer.NewOutgoingChannelLoadBalancer() - - _ = oclb.AddChannel("test") - err := 
oclb.AddChannel("test") - - assert.Nil(t, err) - assert.Equal(t, 2, len(oclb.Chans())) -} - -//------- RemoveChannel - -func TestOutgoingChannelLoadBalancer_RemoveChannelRemoveDefaultShouldErr(t *testing.T) { - t.Parallel() - - oclb := loadBalancer.NewOutgoingChannelLoadBalancer() - - err := oclb.RemoveChannel(loadBalancer.DefaultSendChannel()) - - assert.Equal(t, p2p.ErrChannelCanNotBeDeleted, err) -} - -func TestOutgoingChannelLoadBalancer_RemoveChannelRemoveNotFoundChannelShouldErr(t *testing.T) { - t.Parallel() - - oclb := loadBalancer.NewOutgoingChannelLoadBalancer() - - err := oclb.RemoveChannel("test") - - assert.Equal(t, p2p.ErrChannelDoesNotExist, err) -} - -func TestOutgoingChannelLoadBalancer_RemoveChannelRemoveLastChannelAddedShouldWork(t *testing.T) { - t.Parallel() - - oclb := loadBalancer.NewOutgoingChannelLoadBalancer() - - _ = oclb.AddChannel("test1") - _ = oclb.AddChannel("test2") - _ = oclb.AddChannel("test3") - - err := oclb.RemoveChannel("test3") - - assert.Nil(t, err) - - assert.Equal(t, 3, len(oclb.Names())) - assert.Nil(t, checkIntegrity(oclb, loadBalancer.DefaultSendChannel())) - assert.Nil(t, checkIntegrity(oclb, "test1")) - assert.Nil(t, checkIntegrity(oclb, "test2")) - assert.Equal(t, errMissingChannel, checkIntegrity(oclb, "test3")) -} - -func TestOutgoingChannelLoadBalancer_RemoveChannelRemoveFirstChannelAddedShouldWork(t *testing.T) { - t.Parallel() - - oclb := loadBalancer.NewOutgoingChannelLoadBalancer() - - _ = oclb.AddChannel("test1") - _ = oclb.AddChannel("test2") - _ = oclb.AddChannel("test3") - - err := oclb.RemoveChannel("test1") - - assert.Nil(t, err) - - assert.Equal(t, 3, len(oclb.Names())) - assert.Nil(t, checkIntegrity(oclb, loadBalancer.DefaultSendChannel())) - assert.Equal(t, errMissingChannel, checkIntegrity(oclb, "test1")) - assert.Nil(t, checkIntegrity(oclb, "test2")) - assert.Nil(t, checkIntegrity(oclb, "test3")) -} - -func TestOutgoingChannelLoadBalancer_RemoveChannelRemoveMiddleChannelAddedShouldWork(t *testing.T) { - t.Parallel() - - oclb := loadBalancer.NewOutgoingChannelLoadBalancer() - - _ = oclb.AddChannel("test1") - _ = oclb.AddChannel("test2") - _ = oclb.AddChannel("test3") - - err := oclb.RemoveChannel("test2") - - assert.Nil(t, err) - - assert.Equal(t, 3, len(oclb.Names())) - assert.Nil(t, checkIntegrity(oclb, loadBalancer.DefaultSendChannel())) - assert.Nil(t, checkIntegrity(oclb, "test1")) - assert.Equal(t, errMissingChannel, checkIntegrity(oclb, "test2")) - assert.Nil(t, checkIntegrity(oclb, "test3")) -} - -//------- GetChannelOrDefault - -func TestOutgoingChannelLoadBalancer_GetChannelOrDefaultNotFoundShouldReturnDefault(t *testing.T) { - t.Parallel() - - oclb := loadBalancer.NewOutgoingChannelLoadBalancer() - - _ = oclb.AddChannel("test1") - - channel := oclb.GetChannelOrDefault("missing channel") - - assert.True(t, oclb.NamesChans()[loadBalancer.DefaultSendChannel()] == channel) -} - -func TestOutgoingChannelLoadBalancer_GetChannelOrDefaultFoundShouldReturnChannel(t *testing.T) { - t.Parallel() - - oclb := loadBalancer.NewOutgoingChannelLoadBalancer() - - _ = oclb.AddChannel("test1") - - channel := oclb.GetChannelOrDefault("test1") - - assert.True(t, oclb.NamesChans()["test1"] == channel) -} - -//------- CollectOneElementFromChannels - -func TestOutgoingChannelLoadBalancer_CollectFromChannelsNoObjectsShouldWaitBlocking(t *testing.T) { - t.Parallel() - - oclb := loadBalancer.NewOutgoingChannelLoadBalancer() - - chanDone := make(chan struct{}) - - go func() { - _ = oclb.CollectOneElementFromChannels() - - chanDone <- 
struct{}{} - }() - - select { - case <-chanDone: - assert.Fail(t, "should not have received an object") - case <-time.After(durationWait): - } -} - -func TestOutgoingChannelLoadBalancer_CollectOneElementFromChannelsShouldWork(t *testing.T) { - t.Parallel() - - oclb := loadBalancer.NewOutgoingChannelLoadBalancer() - - _ = oclb.AddChannel("test") - - obj1 := &p2p.SendableData{Topic: "test"} - obj2 := &p2p.SendableData{Topic: "default"} - - chanDone := make(chan bool) - wg := sync.WaitGroup{} - wg.Add(3) - - //send on channel test - go func() { - oclb.GetChannelOrDefault("test") <- obj1 - wg.Done() - }() - - //send on default channel - go func() { - oclb.GetChannelOrDefault(loadBalancer.DefaultSendChannel()) <- obj2 - wg.Done() - }() - - //func to wait for sending and receiving to finish - go func() { - wg.Wait() - chanDone <- true - }() - - //func to periodically consume from channels - go func() { - foundObj1 := false - foundObj2 := false - - for { - obj := oclb.CollectOneElementFromChannels() - - if !foundObj1 { - if obj == obj1 { - foundObj1 = true - } - } - - if !foundObj2 { - if obj == obj2 { - foundObj2 = true - } - } - - if foundObj1 && foundObj2 { - break - } - } - - wg.Done() - }() - - select { - case <-chanDone: - return - case <-time.After(durationWait): - assert.Fail(t, "timeout") - return - } -} diff --git a/p2p/memp2p/errors.go b/p2p/memp2p/errors.go deleted file mode 100644 index 17e87241034..00000000000 --- a/p2p/memp2p/errors.go +++ /dev/null @@ -1,12 +0,0 @@ -package memp2p - -import "errors" - -// ErrNilNetwork signals that a nil was given where a memp2p.Network instance was expected -var ErrNilNetwork = errors.New("nil network") - -// ErrNotConnectedToNetwork signals that a peer tried to perform a network-related operation, but is not connected to any network -var ErrNotConnectedToNetwork = errors.New("not connected to network") - -// ErrReceivingPeerNotConnected signals that the receiving peer of a sending operation is not connected to the network -var ErrReceivingPeerNotConnected = errors.New("receiving peer not connected to network") diff --git a/p2p/memp2p/export_test.go b/p2p/memp2p/export_test.go deleted file mode 100644 index d779ae8e0ee..00000000000 --- a/p2p/memp2p/export_test.go +++ /dev/null @@ -1,11 +0,0 @@ -package memp2p - -import "github.com/ElrondNetwork/elrond-go/p2p" - -func (messenger *Messenger) TopicValidator(name string) p2p.MessageProcessor { - messenger.topicsMutex.RLock() - processor := messenger.topicValidators[name] - messenger.topicsMutex.RUnlock() - - return processor -} diff --git a/p2p/memp2p/message.go b/p2p/memp2p/message.go deleted file mode 100644 index 15ae88b49f7..00000000000 --- a/p2p/memp2p/message.go +++ /dev/null @@ -1,91 +0,0 @@ -package memp2p - -import ( - "encoding/binary" - - "github.com/ElrondNetwork/elrond-go-core/core" - "github.com/ElrondNetwork/elrond-go/p2p" -) - -var _ p2p.MessageP2P = (*message)(nil) - -// message represents a message to be sent through the in-memory network -// simulated by the Network struct. 
-type message struct { - from []byte - data []byte - seqNo []byte - topic string - signature []byte - key []byte - peer core.PeerID - payloadField []byte - timestampField int64 -} - -// newMessage constructs a new message instance from arguments -func newMessage(topic string, data []byte, peerID core.PeerID, seqNo uint64) *message { - empty := make([]byte, 0) - seqNoBytes := make([]byte, 8) - binary.BigEndian.PutUint64(seqNoBytes, seqNo) - - return &message{ - from: []byte(peerID), - data: data, - seqNo: seqNoBytes, - topic: topic, - signature: empty, - key: []byte(peerID), - peer: peerID, - } -} - -// From returns the message originator's peer ID -func (msg *message) From() []byte { - return msg.from -} - -// Data returns the message payload -func (msg *message) Data() []byte { - return msg.data -} - -// SeqNo returns the message sequence number -func (msg *message) SeqNo() []byte { - return msg.seqNo -} - -// Topic returns the topic on which the message was sent -func (msg *message) Topic() string { - return msg.topic -} - -// Signature returns the message signature -func (msg *message) Signature() []byte { - return msg.signature -} - -// Key returns the message public key (if it cannot be recovered from the From field) -func (msg *message) Key() []byte { - return msg.key -} - -// Peer returns the peer that originated the message -func (msg *message) Peer() core.PeerID { - return msg.peer -} - -// Payload returns the encapsulated message along with meta data such as timestamp -func (msg *message) Payload() []byte { - return msg.payloadField -} - -// Timestamp returns the message timestamp to prevent endless re-processing of the same message -func (msg *message) Timestamp() int64 { - return msg.timestampField -} - -// IsInterfaceNil returns true if there is no value under the interface -func (msg *message) IsInterfaceNil() bool { - return msg == nil -} diff --git a/p2p/memp2p/messenger.go b/p2p/memp2p/messenger.go deleted file mode 100644 index c7ba39f1aba..00000000000 --- a/p2p/memp2p/messenger.go +++ /dev/null @@ -1,404 +0,0 @@ -package memp2p - -import ( - "crypto/rand" - "encoding/base64" - "fmt" - "sync" - "sync/atomic" - - "github.com/ElrondNetwork/elrond-go-core/core" - "github.com/ElrondNetwork/elrond-go-core/core/check" - "github.com/ElrondNetwork/elrond-go-logger" - "github.com/ElrondNetwork/elrond-go/p2p" -) - -const maxQueueSize = 1000 - -var log = logger.GetOrCreate("p2p/memp2p") -
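[Editor's note] A minimal usage sketch (not part of this diff) of the in-memory messenger described below. The Network constructor is not shown in this hunk, so memp2p.NewNetwork() is an assumption to be checked against the actual package; the remaining calls use only functions visible in this diff.

package main

import (
	"fmt"

	"github.com/ElrondNetwork/elrond-go/p2p/memp2p"
)

func main() {
	// assumed constructor; the Network type itself is defined elsewhere in the package
	network := memp2p.NewNetwork()

	alice, err := memp2p.NewMessenger(network)
	if err != nil {
		panic(err)
	}

	bob, err := memp2p.NewMessenger(network)
	if err != nil {
		panic(err)
	}

	// the simulated network is a fully connected graph: no dialing is needed
	fmt.Println(alice.IsConnected(bob.ID()))  // true
	fmt.Println(len(alice.ConnectedPeers())) // 1, i.e. just bob
}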
-type Messenger struct { - network *Network - p2pID core.PeerID - address string - topics map[string]struct{} - topicValidators map[string]p2p.MessageProcessor - topicsMutex *sync.RWMutex - seqNo uint64 - processQueue chan p2p.MessageP2P - numReceived uint64 -} - -// NewMessenger constructs a new Messenger that is connected to the -// Network instance provided as argument. -func NewMessenger(network *Network) (*Messenger, error) { - if network == nil { - return nil, ErrNilNetwork - } - - buff := make([]byte, 32) - _, _ = rand.Reader.Read(buff) - ID := base64.StdEncoding.EncodeToString(buff) - Address := fmt.Sprintf("/memp2p/%s", ID) - - messenger := &Messenger{ - network: network, - p2pID: core.PeerID(ID), - address: Address, - topics: make(map[string]struct{}), - topicValidators: make(map[string]p2p.MessageProcessor), - topicsMutex: &sync.RWMutex{}, - processQueue: make(chan p2p.MessageP2P, maxQueueSize), - } - network.RegisterPeer(messenger) - go messenger.processFromQueue() - - return messenger, nil -} - -// ID returns the P2P ID of the messenger -func (messenger *Messenger) ID() core.PeerID { - return messenger.p2pID -} - -// Peers returns a slice containing the P2P IDs of all the other peers that it -// has knowledge of. Since this is an in-memory network structured as a fully -// connected graph, this function returns the list of the P2P IDs of all the -// peers in the network (assuming this Messenger is connected). -func (messenger *Messenger) Peers() []core.PeerID { - // If the messenger is connected to the network, it has knowledge of all - // other peers. - if !messenger.IsConnectedToNetwork() { - return []core.PeerID{} - } - return messenger.network.PeerIDs() -} - -// Addresses returns a list of all the physical addresses that this Messenger -// is bound to and listening to, depending on the available network interfaces -// of the machine. Being an in-memory simulation, the only possible address to -// return is an artificial one, built by the constructor NewMessenger(). -func (messenger *Messenger) Addresses() []string { - addresses := make([]string, 1) - addresses[0] = messenger.address - return addresses -} - -// ConnectToPeer usually does nothing, because peers connected to the in-memory -// network are already all connected to each other. This function will return -// an error if the Messenger is not connected to the network, though. -func (messenger *Messenger) ConnectToPeer(_ string) error { - if !messenger.IsConnectedToNetwork() { - return ErrNotConnectedToNetwork - } - // Do nothing, all peers are connected to each other already. - return nil -} - -// IsConnectedToNetwork returns true if this messenger is connected to the -// in-memory network, false otherwise. -func (messenger *Messenger) IsConnectedToNetwork() bool { - return messenger.network.IsPeerConnected(messenger.ID()) -} - -// IsConnected returns true if this Messenger is connected to the peer with the -// specified ID. It always returns true if the Messenger is connected to the -// network and false otherwise, regardless of the provided peer ID. -func (messenger *Messenger) IsConnected(_ core.PeerID) bool { - return messenger.IsConnectedToNetwork() -} - -// ConnectedPeers returns a slice of IDs belonging to the peers to which this -// Messenger is connected. If the Messenger is connected to the in₋memory -// network, then the function returns a slice containing the IDs of all the -// other peers connected to the network. Returns false if the Messenger is -// not connected. 
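NewMessenger above derives an artificial identity from 32 random bytes, base64-encodes it, and prefixes the result with /memp2p/. The same derivation in isolation:

```go
package main

import (
	"crypto/rand"
	"encoding/base64"
	"fmt"
)

func main() {
	// NewMessenger derives the peer ID from 32 random bytes and
	// builds an artificial /memp2p/<id> address from it.
	buff := make([]byte, 32)
	_, _ = rand.Read(buff)
	id := base64.StdEncoding.EncodeToString(buff)
	fmt.Printf("/memp2p/%s\n", id)
}
```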
-func (messenger *Messenger) ConnectedPeers() []core.PeerID { - if !messenger.IsConnectedToNetwork() { - return []core.PeerID{} - } - return messenger.network.PeerIDsExceptOne(messenger.ID()) -} - -// ConnectedAddresses returns a slice of peer addresses to which this Messenger -// is connected. If this Messenger is connected to the network, then the -// addresses of all the other peers in the network are returned. -func (messenger *Messenger) ConnectedAddresses() []string { - if !messenger.IsConnectedToNetwork() { - return []string{} - } - return messenger.network.ListAddressesExceptOne(messenger.ID()) -} - -// PeerAddresses creates the address string from a given peer ID. -func (messenger *Messenger) PeerAddresses(pid core.PeerID) []string { - return []string{fmt.Sprintf("/memp2p/%s", string(pid))} -} - -// ConnectedPeersOnTopic returns a slice of IDs belonging to the peers in the -// network that have declared their interest in the given topic and are -// listening to messages on that topic. -func (messenger *Messenger) ConnectedPeersOnTopic(topic string) []core.PeerID { - var filteredPeers []core.PeerID - if !messenger.IsConnectedToNetwork() { - return filteredPeers - } - - allPeersExceptThis := messenger.network.PeersExceptOne(messenger.ID()) - for _, peer := range allPeersExceptThis { - if peer.HasTopic(topic) { - filteredPeers = append(filteredPeers, peer.ID()) - } - } - - return filteredPeers -} - -// TrimConnections does nothing, as it is not applicable to the in-memory -// messenger. -func (messenger *Messenger) TrimConnections() { -} - -// Bootstrap does nothing, as it is not applicable to the in-memory messenger. -func (messenger *Messenger) Bootstrap(_ uint32) error { - return nil -} - -// CreateTopic adds the topic provided as argument to the list of topics of -// interest for this Messenger. It also registers a nil message validator to -// handle the messages received on this topic. -func (messenger *Messenger) CreateTopic(name string, _ bool) error { - messenger.topicsMutex.Lock() - defer messenger.topicsMutex.Unlock() - - _, found := messenger.topics[name] - if found { - return p2p.ErrTopicAlreadyExists - } - messenger.topics[name] = struct{}{} - - return nil -} - -// HasTopic returns true if this Messenger has declared interest in the given -// topic; returns false otherwise. -func (messenger *Messenger) HasTopic(name string) bool { - messenger.topicsMutex.RLock() - _, found := messenger.topics[name] - messenger.topicsMutex.RUnlock() - - return found -} - -// RegisterMessageProcessor sets the provided message processor to be the -// processor of received messages for the given topic. -func (messenger *Messenger) RegisterMessageProcessor(topic string, _ string, handler p2p.MessageProcessor) error { - if check.IfNil(handler) { - return p2p.ErrNilValidator - } - - messenger.topicsMutex.Lock() - defer messenger.topicsMutex.Unlock() - - _, found := messenger.topics[topic] - if !found { - return fmt.Errorf("%w RegisterMessageProcessor, topic: %s", p2p.ErrNilTopic, topic) - } - - validator := messenger.topicValidators[topic] - if !check.IfNil(validator) { - return p2p.ErrTopicValidatorOperationNotSupported - } - - messenger.topicValidators[topic] = handler - return nil -} - -// UnregisterMessageProcessor unsets the message processor for the given topic -// (sets it to nil). 
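CreateTopic and HasTopic above guard the topics map with a read-write mutex: writers take the full lock, readers the shared one. A self-contained sketch of that locking scheme, with topicRegistry as an illustrative stand-in for the Messenger:

```go
package main

import (
	"errors"
	"fmt"
	"sync"
)

// topicRegistry is a minimal sketch of the locking scheme used by
// CreateTopic/HasTopic above: writes take the full lock, reads the RLock.
type topicRegistry struct {
	mut    sync.RWMutex
	topics map[string]struct{}
}

func (r *topicRegistry) create(name string) error {
	r.mut.Lock()
	defer r.mut.Unlock()
	if _, found := r.topics[name]; found {
		return errors.New("topic already exists")
	}
	r.topics[name] = struct{}{}
	return nil
}

func (r *topicRegistry) has(name string) bool {
	r.mut.RLock()
	defer r.mut.RUnlock()
	_, found := r.topics[name]
	return found
}

func main() {
	reg := &topicRegistry{topics: make(map[string]struct{})}
	fmt.Println(reg.create("rocket"), reg.has("rocket")) // <nil> true
	fmt.Println(reg.create("rocket"))                    // topic already exists
}
```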
-func (messenger *Messenger) UnregisterMessageProcessor(topic string, _ string) error { - messenger.topicsMutex.Lock() - defer messenger.topicsMutex.Unlock() - - _, found := messenger.topics[topic] - if !found { - return fmt.Errorf("%w UnregisterMessageProcessor, topic: %s", p2p.ErrNilTopic, topic) - } - - validator := messenger.topicValidators[topic] - if check.IfNil(validator) { - return p2p.ErrTopicValidatorOperationNotSupported - } - - messenger.topicValidators[topic] = nil - return nil -} - -// OutgoingChannelLoadBalancer does nothing, as it is not applicable to the in-memory network. -func (messenger *Messenger) OutgoingChannelLoadBalancer() p2p.ChannelLoadBalancer { - return nil -} - -// BroadcastOnChannelBlocking sends the message to all peers in the network. It -// calls parametricBroadcast() with async=false, which means that peers will -// have their ReceiveMessage() function called synchronously. The call -// to parametricBroadcast() is done synchronously as well. This function should -// be called as a go-routine. -func (messenger *Messenger) BroadcastOnChannelBlocking(_ string, topic string, buff []byte) error { - return messenger.synchronousBroadcast(topic, buff) -} - -// BroadcastOnChannel sends the message to all peers in the network. It calls -// parametricBroadcast() with async=false, which means that peers will have -// their ReceiveMessage() function called synchronously. The call to -// parametricBroadcast() is done as a go-routine, which means this function is, -// in fact, non-blocking, but it is identical with BroadcastOnChannelBlocking() -// in all other regards. -func (messenger *Messenger) BroadcastOnChannel(_ string, topic string, buff []byte) { - err := messenger.synchronousBroadcast(topic, buff) - log.LogIfError(err) -} - -// Broadcast asynchronously sends the message to all peers in the network. It -// calls parametricBroadcast() with async=true, which means that peers will -// have their ReceiveMessage() function independently called as go-routines. -func (messenger *Messenger) Broadcast(topic string, buff []byte) { - err := messenger.synchronousBroadcast(topic, buff) - log.LogIfError(err) -} - -// synchronousBroadcast sends a message to all peers in the network in a synchronous way -func (messenger *Messenger) synchronousBroadcast(topic string, data []byte) error { - if !messenger.IsConnectedToNetwork() { - return ErrNotConnectedToNetwork - } - - seqNo := atomic.AddUint64(&messenger.seqNo, 1) - messageObject := newMessage(topic, data, messenger.ID(), seqNo) - - peers := messenger.network.Peers() - for _, peer := range peers { - peer.receiveMessage(messageObject) - } - - return nil -} - -func (messenger *Messenger) processFromQueue() { - for { - messageObject := <-messenger.processQueue - if check.IfNil(messageObject) { - continue - } - - topic := messageObject.Topic() - if topic == "" { - continue - } - - messenger.topicsMutex.Lock() - _, found := messenger.topics[topic] - if !found { - messenger.topicsMutex.Unlock() - continue - } - - // numReceived gets incremented because the message arrived on a registered topic - atomic.AddUint64(&messenger.numReceived, 1) - validator := messenger.topicValidators[topic] - if check.IfNil(validator) { - messenger.topicsMutex.Unlock() - continue - } - messenger.topicsMutex.Unlock() - - _ = validator.ProcessReceivedMessage(messageObject, messenger.p2pID) - } -} - -// SendToConnectedPeer sends a message directly to the peer specified by the ID. 
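synchronousBroadcast and processFromQueue above combine an atomic counter with a bounded channel drained by a single goroutine. A compact sketch of that consumer pattern, assuming a plain string payload instead of p2p.MessageP2P:

```go
package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

func main() {
	// Like processFromQueue above: a single goroutine drains a bounded
	// queue and counts the messages it accepts.
	queue := make(chan string, 1000)
	var numReceived uint64
	var wg sync.WaitGroup
	wg.Add(1)

	go func() {
		defer wg.Done()
		for msg := range queue {
			atomic.AddUint64(&numReceived, 1)
			_ = msg // hand the message to the registered processor here
		}
	}()

	for i := 0; i < 3; i++ {
		queue <- fmt.Sprintf("message %d", i)
	}
	close(queue)
	wg.Wait()
	fmt.Println(atomic.LoadUint64(&numReceived)) // 3
}
```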
-func (messenger *Messenger) SendToConnectedPeer(topic string, buff []byte, peerID core.PeerID) error { - if messenger.IsConnectedToNetwork() { - seqNo := atomic.AddUint64(&messenger.seqNo, 1) - messageObject := newMessage(topic, buff, messenger.ID(), seqNo) - - receivingPeer, peerFound := messenger.network.Peers()[peerID] - if !peerFound { - return ErrReceivingPeerNotConnected - } - - receivingPeer.receiveMessage(messageObject) - - return nil - } - - return ErrNotConnectedToNetwork -} - -// receiveMessage handles the received message by passing it to the message -// processor of the corresponding topic, given that this Messenger has -// previously registered a message processor for that topic. The Network will -// log the message only if the Network.LogMessages flag is set and only if the -// Messenger has the requested topic and MessageProcessor. -func (messenger *Messenger) receiveMessage(message p2p.MessageP2P) { - messenger.processQueue <- message -} - -// IsConnectedToTheNetwork returns true as this implementation is always connected to its network -func (messenger *Messenger) IsConnectedToTheNetwork() bool { - return true -} - -// SetThresholdMinConnectedPeers does nothing as this implementation is always connected to its network -func (messenger *Messenger) SetThresholdMinConnectedPeers(_ int) error { - return nil -} - -// ThresholdMinConnectedPeers always return 0 -func (messenger *Messenger) ThresholdMinConnectedPeers() int { - return 0 -} - -// NumMessagesReceived returns the number of messages received -func (messenger *Messenger) NumMessagesReceived() uint64 { - return atomic.LoadUint64(&messenger.numReceived) -} - -// SetPeerShardResolver is a dummy function, not setting anything -func (messenger *Messenger) SetPeerShardResolver(_ p2p.PeerShardResolver) error { - return nil -} - -// SetPeerDenialEvaluator does nothing -func (messenger *Messenger) SetPeerDenialEvaluator(_ p2p.PeerDenialEvaluator) error { - return nil -} - -// GetConnectedPeersInfo returns a nil object. Not implemented. -func (messenger *Messenger) GetConnectedPeersInfo() *p2p.ConnectedPeersInfo { - return nil -} - -// Close disconnects this Messenger from the network it was connected to. 
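SendToConnectedPeer above resolves the target peer from the network's peer map and returns ErrReceivingPeerNotConnected when the lookup fails. The deliver helper below is a hypothetical reduction of that flow to a map of channels:

```go
package main

import (
	"errors"
	"fmt"
)

// deliver sketches the SendToConnectedPeer flow above: resolve the
// receiving peer from the network's map, fail if it is unknown,
// otherwise hand the message over.
func deliver(peers map[string]chan string, peerID, msg string) error {
	receiver, found := peers[peerID]
	if !found {
		return errors.New("receiving peer not connected to network")
	}
	receiver <- msg
	return nil
}

func main() {
	peers := map[string]chan string{"peerA": make(chan string, 1)}
	fmt.Println(deliver(peers, "peerA", "hello")) // <nil>
	fmt.Println(deliver(peers, "peerB", "hello")) // receiving peer not connected to network
}
```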
-func (messenger *Messenger) Close() error { - messenger.network.UnregisterPeer(messenger.ID()) - return nil -} - -// IsInterfaceNil returns true if there is no value under the interface -func (messenger *Messenger) IsInterfaceNil() bool { - return messenger == nil -} diff --git a/p2p/memp2p/messenger_test.go b/p2p/memp2p/messenger_test.go deleted file mode 100644 index dd9c488f5c4..00000000000 --- a/p2p/memp2p/messenger_test.go +++ /dev/null @@ -1,186 +0,0 @@ -package memp2p_test - -import ( - "errors" - "fmt" - "testing" - "time" - - "github.com/ElrondNetwork/elrond-go/p2p" - "github.com/ElrondNetwork/elrond-go/p2p/memp2p" - "github.com/ElrondNetwork/elrond-go/p2p/mock" - "github.com/stretchr/testify/assert" -) - -func TestInitializingNetworkAndPeer(t *testing.T) { - network := memp2p.NewNetwork() - - peer, err := memp2p.NewMessenger(network) - assert.Nil(t, err) - - assert.Equal(t, 1, len(network.Peers())) - - assert.Equal(t, 1, len(peer.Addresses())) - assert.Equal(t, "/memp2p/"+string(peer.ID()), peer.Addresses()[0]) - - err = peer.Close() - assert.Nil(t, err) -} - -func TestRegisteringTopics(t *testing.T) { - network := memp2p.NewNetwork() - - messenger, err := memp2p.NewMessenger(network) - assert.Nil(t, err) - - processor := &mock.MessageProcessorStub{} - - // Cannot register a MessageProcessor to a topic that doesn't exist. - err = messenger.RegisterMessageProcessor("rocket", "", processor) - assert.True(t, errors.Is(err, p2p.ErrNilTopic)) - - // Create a proper topic. - assert.False(t, messenger.HasTopic("rocket")) - assert.Nil(t, messenger.CreateTopic("rocket", false)) - assert.True(t, messenger.HasTopic("rocket")) - - // The newly created topic has no MessageProcessor attached to it, so we - // attach one now. - assert.Nil(t, messenger.TopicValidator("rocket")) - err = messenger.RegisterMessageProcessor("rocket", "", processor) - assert.Nil(t, err) - assert.Equal(t, processor, messenger.TopicValidator("rocket")) - - // Cannot unregister a MessageProcessor from a topic that doesn't exist. - err = messenger.UnregisterMessageProcessor("albatross", "") - assert.True(t, errors.Is(err, p2p.ErrNilTopic)) - - // Cannot unregister a MessageProcessor from a topic that doesn't have a - // MessageProcessor, even if the topic itself exists. - err = messenger.CreateTopic("nitrous_oxide", false) - assert.Nil(t, err) - err = messenger.UnregisterMessageProcessor("nitrous_oxide", "") - assert.Equal(t, p2p.ErrTopicValidatorOperationNotSupported, err) - - // Unregister the MessageProcessor from a topic that exists and has a - // MessageProcessor. - err = messenger.UnregisterMessageProcessor("rocket", "") - assert.Nil(t, err) - assert.True(t, messenger.HasTopic("rocket")) - assert.Nil(t, messenger.TopicValidator("rocket")) - - // Disallow creating duplicate topics. - err = messenger.CreateTopic("more_rockets", false) - assert.Nil(t, err) - assert.True(t, messenger.HasTopic("more_rockets")) - err = messenger.CreateTopic("more_rockets", false) - assert.NotNil(t, err) -} - -func TestBroadcastingMessages(t *testing.T) { - network := memp2p.NewNetwork() - - numPeers := 4 - peers := make([]*memp2p.Messenger, numPeers) - for i := 0; i < numPeers; i++ { - peer, _ := memp2p.NewMessenger(network) - _ = peer.CreateTopic("rocket", false) - peers[i] = peer - } - - // Send a message to everybody. 
- _ = peers[0].BroadcastOnChannelBlocking("rocket", "rocket", []byte("launch the rocket")) - time.Sleep(1 * time.Second) - testReceivedMessages(t, peers, map[int]uint64{0: 1, 1: 1, 2: 1, 3: 1, 4: 1}) - - // Send a message after disconnecting. No new messages should get broadcast - err := peers[0].Close() - assert.Nil(t, err) - _ = peers[0].BroadcastOnChannelBlocking("rocket", "rocket", []byte("launch the rocket again")) - time.Sleep(1 * time.Second) - testReceivedMessages(t, peers, map[int]uint64{0: 1, 1: 1, 2: 1, 3: 1, 4: 1}) - - peers[2].Broadcast("rocket", []byte("launch another rocket")) - time.Sleep(1 * time.Second) - testReceivedMessages(t, peers, map[int]uint64{0: 1, 1: 2, 2: 2, 3: 2, 4: 2}) - - peers[2].Broadcast("nitrous_oxide", []byte("this message should not get broadcast")) - time.Sleep(1 * time.Second) - testReceivedMessages(t, peers, map[int]uint64{0: 1, 1: 2, 2: 2, 3: 2, 4: 2}) -} - -func testReceivedMessages(t *testing.T, peers []*memp2p.Messenger, receivedNumMap map[int]uint64) { - for idx, p := range peers { - val, found := receivedNumMap[idx] - if !found { - assert.Fail(t, fmt.Sprintf("number of messages received was not defined for index %d", idx)) - return - } - - assert.Equal(t, val, p.NumMessagesReceived(), "for peer on index %d", idx) - } -} - -func TestConnectivityAndTopics(t *testing.T) { - network := memp2p.NewNetwork() - - // Create 4 peers on the network, all listening to the topic "rocket". - numPeers := 4 - peers := make([]*memp2p.Messenger, numPeers) - for i := 0; i < numPeers; i++ { - peer, _ := memp2p.NewMessenger(network) - _ = peer.CreateTopic("rocket", false) - peers[i] = peer - } - - // Peers 2 and 3 also listen on the topic "carbohydrate" - _ = peers[2].CreateTopic("carbohydrate", false) - _ = peers[2].RegisterMessageProcessor("carbohydrate", "", &mock.MessageProcessorStub{}) - _ = peers[3].CreateTopic("carbohydrate", false) - _ = peers[3].RegisterMessageProcessor("carbohydrate", "", &mock.MessageProcessorStub{}) - - // Test to which peers is Peer0 connected, based on the topics they listen to. - peer0 := peers[0] - assert.Equal(t, numPeers, len(network.PeerIDs())) - assert.Equal(t, numPeers-1, len(peer0.ConnectedPeers())) - assert.Equal(t, numPeers-1, len(peer0.ConnectedPeersOnTopic("rocket"))) - assert.Equal(t, 2, len(peer0.ConnectedPeersOnTopic("carbohydrate"))) -} - -func TestSendingDirectMessages(t *testing.T) { - network := memp2p.NewNetwork() - - peer1, _ := memp2p.NewMessenger(network) - peer2, _ := memp2p.NewMessenger(network) - - // Peer1 attempts to send a direct message to Peer2 on topic "rocket", but - // Peer2 is not listening to this topic. - _ = peer1.SendToConnectedPeer("rocket", []byte("try to launch this rocket"), peer2.ID()) - time.Sleep(time.Millisecond * 100) - - // The same as above, but in reverse (Peer2 sends to Peer1). - _ = peer2.SendToConnectedPeer("rocket", []byte("try to launch this rocket"), peer1.ID()) - time.Sleep(time.Millisecond * 100) - - // Both peers did not get the message - assert.Equal(t, uint64(0), peer1.NumMessagesReceived()) - assert.Equal(t, uint64(0), peer2.NumMessagesReceived()) - - // Create a topic on Peer1. This doesn't help, because Peer2 still can't - // receive messages on topic "rocket". 
- _ = peer1.CreateTopic("nitrous_oxide", false) - _ = peer2.SendToConnectedPeer("rocket", []byte("try to launch this rocket"), peer1.ID()) - time.Sleep(time.Millisecond * 100) - - // peer1 still did not get the message - assert.Equal(t, uint64(0), peer1.NumMessagesReceived()) - - // Finally, create the topic "rocket" on Peer1 - // This allows it to receive a message on this topic from Peer2. - _ = peer1.CreateTopic("rocket", false) - _ = peer2.SendToConnectedPeer("rocket", []byte("try to launch this rocket"), peer1.ID()) - time.Sleep(time.Millisecond * 100) - - // Peer1 got the message - assert.Equal(t, uint64(1), peer1.NumMessagesReceived()) -} diff --git a/p2p/memp2p/network.go b/p2p/memp2p/network.go deleted file mode 100644 index 224140832be..00000000000 --- a/p2p/memp2p/network.go +++ /dev/null @@ -1,131 +0,0 @@ -package memp2p - -import ( - "fmt" - "sync" - - "github.com/ElrondNetwork/elrond-go-core/core" -) - -// Network provides in-memory connectivity for the Messenger -// struct. It simulates a network where each peer is connected to all the other -// peers. The peers are connected to the network if they are in the internal -// `peers` map; otherwise, they are disconnected. -type Network struct { - mutex sync.RWMutex - peers map[core.PeerID]*Messenger -} - -// NewNetwork constructs a new Network instance with an empty -// internal map of peers. -func NewNetwork() *Network { - network := Network{ - mutex: sync.RWMutex{}, - peers: make(map[core.PeerID]*Messenger), - } - - return &network -} - -// ListAddressesExceptOne provides the addresses of the known peers, except a specified one. -func (network *Network) ListAddressesExceptOne(peerIDToExclude core.PeerID) []string { - network.mutex.RLock() - resultingLength := len(network.peers) - 1 - addresses := make([]string, resultingLength) - idx := 0 - for _, peer := range network.peers { - if peer.ID() == peerIDToExclude { - continue - } - addresses[idx] = fmt.Sprintf("/memp2p/%s", peer.ID()) - idx++ - } - network.mutex.RUnlock() - - return addresses -} - -// Peers provides a copy of its internal map of peers -func (network *Network) Peers() map[core.PeerID]*Messenger { - peersCopy := make(map[core.PeerID]*Messenger) - - network.mutex.RLock() - for peerID, peer := range network.peers { - peersCopy[peerID] = peer - } - network.mutex.RUnlock() - - return peersCopy -} - -// PeersExceptOne provides a copy of its internal map of peers, excluding a specific peer. -func (network *Network) PeersExceptOne(peerIDToExclude core.PeerID) map[core.PeerID]*Messenger { - peersCopy := make(map[core.PeerID]*Messenger) - - network.mutex.RLock() - for peerID, peer := range network.peers { - if peerID == peerIDToExclude { - continue - } - peersCopy[peerID] = peer - } - network.mutex.RUnlock() - - return peersCopy -} - -// PeerIDs provides a copy of its internal slice of peerIDs -func (network *Network) PeerIDs() []core.PeerID { - network.mutex.RLock() - peerIDsCopy := make([]core.PeerID, len(network.peers)) - idx := 0 - for peerID := range network.peers { - peerIDsCopy[idx] = peerID - idx++ - } - network.mutex.RUnlock() - - return peerIDsCopy -} - -//PeerIDsExceptOne provides a copy of its internal slice of peerIDs, excluding a specific peer. 
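Network.Peers, PeersExceptOne and PeerIDs above all hand out copies assembled under the read lock, so callers can iterate safely while peers register and unregister. The snapshot method below sketches that copy-under-RLock pattern with an illustrative registry type:

```go
package main

import (
	"fmt"
	"sync"
)

// registry is an illustrative stand-in for Network: a mutex-guarded map
// whose accessors return defensive copies rather than the map itself.
type registry struct {
	mut   sync.RWMutex
	peers map[string]string
}

// snapshot returns a copy of the peer map taken under the read lock, so
// callers can iterate without racing registrations, as Network.Peers does.
func (r *registry) snapshot() map[string]string {
	r.mut.RLock()
	defer r.mut.RUnlock()
	cp := make(map[string]string, len(r.peers))
	for id, addr := range r.peers {
		cp[id] = addr
	}
	return cp
}

func main() {
	r := &registry{peers: map[string]string{"peerA": "/memp2p/peerA"}}
	fmt.Println(r.snapshot())
}
```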
-func (network *Network) PeerIDsExceptOne(peerIDToExclude core.PeerID) []core.PeerID { - network.mutex.RLock() - peerIDsCopy := make([]core.PeerID, len(network.peers)-1) - idx := 0 - for peerID := range network.peers { - if peerID == peerIDToExclude { - continue - } - peerIDsCopy[idx] = peerID - idx++ - } - network.mutex.RUnlock() - return peerIDsCopy -} - -// RegisterPeer adds a messenger to the Peers map and its PeerID to the peerIDs -// slice. -func (network *Network) RegisterPeer(messenger *Messenger) { - network.mutex.Lock() - network.peers[messenger.ID()] = messenger - network.mutex.Unlock() -} - -// UnregisterPeer removes a messenger from the Peers map and its PeerID from -// the peerIDs slice. -func (network *Network) UnregisterPeer(peerID core.PeerID) { - network.mutex.Lock() - delete(network.peers, peerID) - network.mutex.Unlock() -} - -// IsPeerConnected returns true if the peer represented by the provided ID is -// found in the inner `peers` map of the Network instance, which -// determines whether it is connected to the network or not. -func (network *Network) IsPeerConnected(peerID core.PeerID) bool { - network.mutex.RLock() - _, found := network.peers[peerID] - network.mutex.RUnlock() - return found -} diff --git a/p2p/message/directConnectionMessage.pb.go b/p2p/message/directConnectionMessage.pb.go deleted file mode 100644 index 9a2a6bb0aa9..00000000000 --- a/p2p/message/directConnectionMessage.pb.go +++ /dev/null @@ -1,379 +0,0 @@ -// Code generated by protoc-gen-gogo. DO NOT EDIT. -// source: directConnectionMessage.proto - -package message - -import ( - fmt "fmt" - _ "github.com/gogo/protobuf/gogoproto" - proto "github.com/gogo/protobuf/proto" - io "io" - math "math" - math_bits "math/bits" - reflect "reflect" - strings "strings" -) - -// Reference imports to suppress errors if they are not otherwise used. -var _ = proto.Marshal -var _ = fmt.Errorf -var _ = math.Inf - -// This is a compile-time assertion to ensure that this generated file -// is compatible with the proto package it is being compiled against. -// A compilation error at this line likely means your copy of the -// proto package needs to be updated. 
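The generated marshalling code that follows writes DirectConnectionInfo's single ShardId field as tag byte 0x0a (field number 1, wire type 2), then a varint length, then the raw bytes. A hand-rolled sketch of that wire format using the standard library's varint helper, with a hypothetical "0" shard identifier as payload:

```go
package main

import (
	"encoding/binary"
	"fmt"
)

func main() {
	// Field 1 (ShardId), wire type 2 (length-delimited):
	// tag byte 1<<3|2 = 0x0a, varint length, raw bytes.
	payload := []byte("0")
	msg := []byte{0x0a}

	lenBuf := make([]byte, binary.MaxVarintLen64)
	n := binary.PutUvarint(lenBuf, uint64(len(payload)))
	msg = append(msg, lenBuf[:n]...)
	msg = append(msg, payload...)

	fmt.Printf("% x\n", msg) // 0a 01 30
}
```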
-const _ = proto.GoGoProtoPackageIsVersion3 // please upgrade the proto package - -// DirectConnectionInfo represents the data regarding a new direct connection`s info -type DirectConnectionInfo struct { - ShardId string `protobuf:"bytes,1,opt,name=ShardId,proto3" json:"shardId"` -} - -func (m *DirectConnectionInfo) Reset() { *m = DirectConnectionInfo{} } -func (*DirectConnectionInfo) ProtoMessage() {} -func (*DirectConnectionInfo) Descriptor() ([]byte, []int) { - return fileDescriptor_f237562c19ebfede, []int{0} -} -func (m *DirectConnectionInfo) XXX_Unmarshal(b []byte) error { - return m.Unmarshal(b) -} -func (m *DirectConnectionInfo) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { - b = b[:cap(b)] - n, err := m.MarshalToSizedBuffer(b) - if err != nil { - return nil, err - } - return b[:n], nil -} -func (m *DirectConnectionInfo) XXX_Merge(src proto.Message) { - xxx_messageInfo_DirectConnectionInfo.Merge(m, src) -} -func (m *DirectConnectionInfo) XXX_Size() int { - return m.Size() -} -func (m *DirectConnectionInfo) XXX_DiscardUnknown() { - xxx_messageInfo_DirectConnectionInfo.DiscardUnknown(m) -} - -var xxx_messageInfo_DirectConnectionInfo proto.InternalMessageInfo - -func (m *DirectConnectionInfo) GetShardId() string { - if m != nil { - return m.ShardId - } - return "" -} - -func init() { - proto.RegisterType((*DirectConnectionInfo)(nil), "proto.DirectConnectionInfo") -} - -func init() { proto.RegisterFile("directConnectionMessage.proto", fileDescriptor_f237562c19ebfede) } - -var fileDescriptor_f237562c19ebfede = []byte{ - // 201 bytes of a gzipped FileDescriptorProto - 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0x92, 0x4d, 0xc9, 0x2c, 0x4a, - 0x4d, 0x2e, 0x71, 0xce, 0xcf, 0xcb, 0x4b, 0x4d, 0x2e, 0xc9, 0xcc, 0xcf, 0xf3, 0x4d, 0x2d, 0x2e, - 0x4e, 0x4c, 0x4f, 0xd5, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x62, 0x05, 0x53, 0x52, 0xba, 0xe9, - 0x99, 0x25, 0x19, 0xa5, 0x49, 0x7a, 0xc9, 0xf9, 0xb9, 0xfa, 0xe9, 0xf9, 0xe9, 0xf9, 0xfa, 0x60, - 0xe1, 0xa4, 0xd2, 0x34, 0x30, 0x0f, 0xcc, 0x01, 0xb3, 0x20, 0xba, 0x94, 0x6c, 0xb9, 0x44, 0x5c, - 0xd0, 0x8c, 0xf5, 0xcc, 0x4b, 0xcb, 0x17, 0x52, 0xe5, 0x62, 0x0f, 0xce, 0x48, 0x2c, 0x4a, 0xf1, - 0x4c, 0x91, 0x60, 0x54, 0x60, 0xd4, 0xe0, 0x74, 0xe2, 0x7e, 0x75, 0x4f, 0x9e, 0xbd, 0x18, 0x22, - 0x14, 0x04, 0x93, 0x73, 0x72, 0xbc, 0xf0, 0x50, 0x8e, 0xe1, 0xc6, 0x43, 0x39, 0x86, 0x0f, 0x0f, - 0xe5, 0x18, 0x1b, 0x1e, 0xc9, 0x31, 0xae, 0x78, 0x24, 0xc7, 0x78, 0xe2, 0x91, 0x1c, 0xe3, 0x85, - 0x47, 0x72, 0x8c, 0x37, 0x1e, 0xc9, 0x31, 0x3e, 0x78, 0x24, 0xc7, 0xf8, 0xe2, 0x91, 0x1c, 0xc3, - 0x87, 0x47, 0x72, 0x8c, 0x13, 0x1e, 0xcb, 0x31, 0x5c, 0x78, 0x2c, 0xc7, 0x70, 0xe3, 0xb1, 0x1c, - 0x43, 0x14, 0x7b, 0x2e, 0xc4, 0xf5, 0x49, 0x6c, 0x60, 0x87, 0x18, 0x03, 0x02, 0x00, 0x00, 0xff, - 0xff, 0x70, 0x6f, 0x2c, 0x03, 0xdf, 0x00, 0x00, 0x00, -} - -func (this *DirectConnectionInfo) Equal(that interface{}) bool { - if that == nil { - return this == nil - } - - that1, ok := that.(*DirectConnectionInfo) - if !ok { - that2, ok := that.(DirectConnectionInfo) - if ok { - that1 = &that2 - } else { - return false - } - } - if that1 == nil { - return this == nil - } else if this == nil { - return false - } - if this.ShardId != that1.ShardId { - return false - } - return true -} -func (this *DirectConnectionInfo) GoString() string { - if this == nil { - return "nil" - } - s := make([]string, 0, 5) - s = append(s, "&message.DirectConnectionInfo{") - s = append(s, "ShardId: "+fmt.Sprintf("%#v", this.ShardId)+",\n") - s = append(s, "}") - return 
strings.Join(s, "") -} -func valueToGoStringDirectConnectionMessage(v interface{}, typ string) string { - rv := reflect.ValueOf(v) - if rv.IsNil() { - return "nil" - } - pv := reflect.Indirect(rv).Interface() - return fmt.Sprintf("func(v %v) *%v { return &v } ( %#v )", typ, typ, pv) -} -func (m *DirectConnectionInfo) Marshal() (dAtA []byte, err error) { - size := m.Size() - dAtA = make([]byte, size) - n, err := m.MarshalToSizedBuffer(dAtA[:size]) - if err != nil { - return nil, err - } - return dAtA[:n], nil -} - -func (m *DirectConnectionInfo) MarshalTo(dAtA []byte) (int, error) { - size := m.Size() - return m.MarshalToSizedBuffer(dAtA[:size]) -} - -func (m *DirectConnectionInfo) MarshalToSizedBuffer(dAtA []byte) (int, error) { - i := len(dAtA) - _ = i - var l int - _ = l - if len(m.ShardId) > 0 { - i -= len(m.ShardId) - copy(dAtA[i:], m.ShardId) - i = encodeVarintDirectConnectionMessage(dAtA, i, uint64(len(m.ShardId))) - i-- - dAtA[i] = 0xa - } - return len(dAtA) - i, nil -} - -func encodeVarintDirectConnectionMessage(dAtA []byte, offset int, v uint64) int { - offset -= sovDirectConnectionMessage(v) - base := offset - for v >= 1<<7 { - dAtA[offset] = uint8(v&0x7f | 0x80) - v >>= 7 - offset++ - } - dAtA[offset] = uint8(v) - return base -} -func (m *DirectConnectionInfo) Size() (n int) { - if m == nil { - return 0 - } - var l int - _ = l - l = len(m.ShardId) - if l > 0 { - n += 1 + l + sovDirectConnectionMessage(uint64(l)) - } - return n -} - -func sovDirectConnectionMessage(x uint64) (n int) { - return (math_bits.Len64(x|1) + 6) / 7 -} -func sozDirectConnectionMessage(x uint64) (n int) { - return sovDirectConnectionMessage(uint64((x << 1) ^ uint64((int64(x) >> 63)))) -} -func (this *DirectConnectionInfo) String() string { - if this == nil { - return "nil" - } - s := strings.Join([]string{`&DirectConnectionInfo{`, - `ShardId:` + fmt.Sprintf("%v", this.ShardId) + `,`, - `}`, - }, "") - return s -} -func valueToStringDirectConnectionMessage(v interface{}) string { - rv := reflect.ValueOf(v) - if rv.IsNil() { - return "nil" - } - pv := reflect.Indirect(rv).Interface() - return fmt.Sprintf("*%v", pv) -} -func (m *DirectConnectionInfo) Unmarshal(dAtA []byte) error { - l := len(dAtA) - iNdEx := 0 - for iNdEx < l { - preIndex := iNdEx - var wire uint64 - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return ErrIntOverflowDirectConnectionMessage - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := dAtA[iNdEx] - iNdEx++ - wire |= uint64(b&0x7F) << shift - if b < 0x80 { - break - } - } - fieldNum := int32(wire >> 3) - wireType := int(wire & 0x7) - if wireType == 4 { - return fmt.Errorf("proto: DirectConnectionInfo: wiretype end group for non-group") - } - if fieldNum <= 0 { - return fmt.Errorf("proto: DirectConnectionInfo: illegal tag %d (wire type %d)", fieldNum, wire) - } - switch fieldNum { - case 1: - if wireType != 2 { - return fmt.Errorf("proto: wrong wireType = %d for field ShardId", wireType) - } - var stringLen uint64 - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return ErrIntOverflowDirectConnectionMessage - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := dAtA[iNdEx] - iNdEx++ - stringLen |= uint64(b&0x7F) << shift - if b < 0x80 { - break - } - } - intStringLen := int(stringLen) - if intStringLen < 0 { - return ErrInvalidLengthDirectConnectionMessage - } - postIndex := iNdEx + intStringLen - if postIndex < 0 { - return ErrInvalidLengthDirectConnectionMessage - } - if postIndex > l { - return io.ErrUnexpectedEOF - } - m.ShardId = 
string(dAtA[iNdEx:postIndex]) - iNdEx = postIndex - default: - iNdEx = preIndex - skippy, err := skipDirectConnectionMessage(dAtA[iNdEx:]) - if err != nil { - return err - } - if skippy < 0 { - return ErrInvalidLengthDirectConnectionMessage - } - if (iNdEx + skippy) < 0 { - return ErrInvalidLengthDirectConnectionMessage - } - if (iNdEx + skippy) > l { - return io.ErrUnexpectedEOF - } - iNdEx += skippy - } - } - - if iNdEx > l { - return io.ErrUnexpectedEOF - } - return nil -} -func skipDirectConnectionMessage(dAtA []byte) (n int, err error) { - l := len(dAtA) - iNdEx := 0 - depth := 0 - for iNdEx < l { - var wire uint64 - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return 0, ErrIntOverflowDirectConnectionMessage - } - if iNdEx >= l { - return 0, io.ErrUnexpectedEOF - } - b := dAtA[iNdEx] - iNdEx++ - wire |= (uint64(b) & 0x7F) << shift - if b < 0x80 { - break - } - } - wireType := int(wire & 0x7) - switch wireType { - case 0: - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return 0, ErrIntOverflowDirectConnectionMessage - } - if iNdEx >= l { - return 0, io.ErrUnexpectedEOF - } - iNdEx++ - if dAtA[iNdEx-1] < 0x80 { - break - } - } - case 1: - iNdEx += 8 - case 2: - var length int - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return 0, ErrIntOverflowDirectConnectionMessage - } - if iNdEx >= l { - return 0, io.ErrUnexpectedEOF - } - b := dAtA[iNdEx] - iNdEx++ - length |= (int(b) & 0x7F) << shift - if b < 0x80 { - break - } - } - if length < 0 { - return 0, ErrInvalidLengthDirectConnectionMessage - } - iNdEx += length - case 3: - depth++ - case 4: - if depth == 0 { - return 0, ErrUnexpectedEndOfGroupDirectConnectionMessage - } - depth-- - case 5: - iNdEx += 4 - default: - return 0, fmt.Errorf("proto: illegal wireType %d", wireType) - } - if iNdEx < 0 { - return 0, ErrInvalidLengthDirectConnectionMessage - } - if depth == 0 { - return iNdEx, nil - } - } - return 0, io.ErrUnexpectedEOF -} - -var ( - ErrInvalidLengthDirectConnectionMessage = fmt.Errorf("proto: negative length found during unmarshaling") - ErrIntOverflowDirectConnectionMessage = fmt.Errorf("proto: integer overflow") - ErrUnexpectedEndOfGroupDirectConnectionMessage = fmt.Errorf("proto: unexpected end of group") -) diff --git a/p2p/message/directConnectionMessage.proto b/p2p/message/directConnectionMessage.proto deleted file mode 100644 index 26eeec0be32..00000000000 --- a/p2p/message/directConnectionMessage.proto +++ /dev/null @@ -1,13 +0,0 @@ -syntax = "proto3"; - -package proto; - -option go_package = "message"; -option (gogoproto.stable_marshaler_all) = true; - -import "github.com/gogo/protobuf/gogoproto/gogo.proto"; - -// DirectConnectionInfo represents the data regarding a new direct connection`s info -message DirectConnectionInfo { - string ShardId = 1 [(gogoproto.jsontag) = "shardId"]; -} diff --git a/p2p/message/generate.go b/p2p/message/generate.go deleted file mode 100644 index d0b9445a167..00000000000 --- a/p2p/message/generate.go +++ /dev/null @@ -1,3 +0,0 @@ -//go:generate protoc -I=. -I=$GOPATH/src -I=$GOPATH/src/github.com/ElrondNetwork/protobuf/protobuf --gogoslick_out=. 
directConnectionMessage.proto - -package message diff --git a/p2p/message/message.go b/p2p/message/message.go deleted file mode 100644 index 8b6d9d80a42..00000000000 --- a/p2p/message/message.go +++ /dev/null @@ -1,71 +0,0 @@ -package message - -import ( - "github.com/ElrondNetwork/elrond-go-core/core" - "github.com/ElrondNetwork/elrond-go/p2p" -) - -var _ p2p.MessageP2P = (*Message)(nil) - -// Message is a data holder struct -type Message struct { - FromField []byte - DataField []byte - PayloadField []byte - SeqNoField []byte - TopicField string - SignatureField []byte - KeyField []byte - PeerField core.PeerID - TimestampField int64 -} - -// From returns the message originator's peer ID -func (m *Message) From() []byte { - return m.FromField -} - -// Data returns the useful message that was actually sent -func (m *Message) Data() []byte { - return m.DataField -} - -// Payload returns the encapsulated message along with meta data such as timestamp -func (m *Message) Payload() []byte { - return m.PayloadField -} - -// SeqNo returns the message sequence number -func (m *Message) SeqNo() []byte { - return m.SeqNoField -} - -// Topic returns the topic on which the message was sent -func (m *Message) Topic() string { - return m.TopicField -} - -// Signature returns the message signature -func (m *Message) Signature() []byte { - return m.SignatureField -} - -// Key returns the message public key (if it can not be recovered from From field) -func (m *Message) Key() []byte { - return m.KeyField -} - -// Peer returns the peer that originated the message -func (m *Message) Peer() core.PeerID { - return m.PeerField -} - -// Timestamp returns the message timestamp to prevent endless re-processing of the same message -func (m *Message) Timestamp() int64 { - return m.TimestampField -} - -// IsInterfaceNil returns true if there is no value under the interface -func (m *Message) IsInterfaceNil() bool { - return m == nil -} diff --git a/p2p/message/message_test.go b/p2p/message/message_test.go deleted file mode 100644 index 2956b982d75..00000000000 --- a/p2p/message/message_test.go +++ /dev/null @@ -1,41 +0,0 @@ -package message_test - -import ( - "testing" - - "github.com/ElrondNetwork/elrond-go-core/core" - "github.com/ElrondNetwork/elrond-go-core/core/check" - "github.com/ElrondNetwork/elrond-go/p2p/message" - "github.com/stretchr/testify/assert" -) - -func TestMessage_AllFieldsShouldWork(t *testing.T) { - t.Parallel() - - from := []byte("from") - data := []byte("data") - seqNo := []byte("seq no") - topic := "topic" - sig := []byte("sig") - key := []byte("key") - peer := core.PeerID("peer") - - msg := &message.Message{ - FromField: from, - DataField: data, - SeqNoField: seqNo, - TopicField: topic, - SignatureField: sig, - KeyField: key, - PeerField: peer, - } - - assert.False(t, check.IfNil(msg)) - assert.Equal(t, from, msg.From()) - assert.Equal(t, data, msg.Data()) - assert.Equal(t, seqNo, msg.SeqNo()) - assert.Equal(t, topic, msg.Topic()) - assert.Equal(t, sig, msg.Signature()) - assert.Equal(t, key, msg.Key()) - assert.Equal(t, peer, msg.Peer()) -} diff --git a/p2p/mock/channelLoadBalancerStub.go b/p2p/mock/channelLoadBalancerStub.go deleted file mode 100644 index c65be6bb5b7..00000000000 --- a/p2p/mock/channelLoadBalancerStub.go +++ /dev/null @@ -1,48 +0,0 @@ -package mock - -import ( - "github.com/ElrondNetwork/elrond-go/p2p" -) - -// ChannelLoadBalancerStub - -type ChannelLoadBalancerStub struct { - AddChannelCalled func(pipe string) error - RemoveChannelCalled func(pipe string) error - 
GetChannelOrDefaultCalled func(pipe string) chan *p2p.SendableData - CollectOneElementFromChannelsCalled func() *p2p.SendableData - CloseCalled func() error -} - -// AddChannel - -func (clbs *ChannelLoadBalancerStub) AddChannel(pipe string) error { - return clbs.AddChannelCalled(pipe) -} - -// RemoveChannel - -func (clbs *ChannelLoadBalancerStub) RemoveChannel(pipe string) error { - return clbs.RemoveChannelCalled(pipe) -} - -// GetChannelOrDefault - -func (clbs *ChannelLoadBalancerStub) GetChannelOrDefault(pipe string) chan *p2p.SendableData { - return clbs.GetChannelOrDefaultCalled(pipe) -} - -// CollectOneElementFromChannels - -func (clbs *ChannelLoadBalancerStub) CollectOneElementFromChannels() *p2p.SendableData { - return clbs.CollectOneElementFromChannelsCalled() -} - -// Close - -func (clbs *ChannelLoadBalancerStub) Close() error { - if clbs.CloseCalled != nil { - return clbs.CloseCalled() - } - - return nil -} - -// IsInterfaceNil returns true if there is no value under the interface -func (clbs *ChannelLoadBalancerStub) IsInterfaceNil() bool { - return clbs == nil -} diff --git a/p2p/mock/connManagerNotifieeStub.go b/p2p/mock/connManagerNotifieeStub.go deleted file mode 100644 index 69bbbe3cefa..00000000000 --- a/p2p/mock/connManagerNotifieeStub.go +++ /dev/null @@ -1,103 +0,0 @@ -package mock - -import ( - "context" - - "github.com/libp2p/go-libp2p-core/connmgr" - "github.com/libp2p/go-libp2p-core/network" - "github.com/libp2p/go-libp2p-core/peer" - "github.com/multiformats/go-multiaddr" -) - -// ConnManagerNotifieeStub - -type ConnManagerNotifieeStub struct { - UpsertTagCalled func(p peer.ID, tag string, upsert func(int) int) - ProtectCalled func(id peer.ID, tag string) - UnprotectCalled func(id peer.ID, tag string) (protected bool) - CloseCalled func() error - TagPeerCalled func(p peer.ID, tag string, val int) - UntagPeerCalled func(p peer.ID, tag string) - GetTagInfoCalled func(p peer.ID) *connmgr.TagInfo - TrimOpenConnsCalled func(ctx context.Context) - ListenCalled func(netw network.Network, ma multiaddr.Multiaddr) - ListenCloseCalled func(netw network.Network, ma multiaddr.Multiaddr) - ConnectedCalled func(netw network.Network, conn network.Conn) - DisconnectedCalled func(netw network.Network, conn network.Conn) - OpenedStreamCalled func(netw network.Network, stream network.Stream) - ClosedStreamCalled func(netw network.Network, stream network.Stream) -} - -// UpsertTag - -func (cmns *ConnManagerNotifieeStub) UpsertTag(p peer.ID, tag string, upsert func(int) int) { - cmns.UpsertTagCalled(p, tag, upsert) -} - -// Protect - -func (cmns *ConnManagerNotifieeStub) Protect(id peer.ID, tag string) { - cmns.ProtectCalled(id, tag) -} - -// Unprotect - -func (cmns *ConnManagerNotifieeStub) Unprotect(id peer.ID, tag string) (protected bool) { - return cmns.UnprotectCalled(id, tag) -} - -// Close - -func (cmns *ConnManagerNotifieeStub) Close() error { - return cmns.CloseCalled() -} - -// TagPeer - -func (cmns *ConnManagerNotifieeStub) TagPeer(p peer.ID, tag string, val int) { - cmns.TagPeerCalled(p, tag, val) -} - -// UntagPeer - -func (cmns *ConnManagerNotifieeStub) UntagPeer(p peer.ID, tag string) { - cmns.UntagPeerCalled(p, tag) -} - -// GetTagInfo - -func (cmns *ConnManagerNotifieeStub) GetTagInfo(p peer.ID) *connmgr.TagInfo { - return cmns.GetTagInfoCalled(p) -} - -// TrimOpenConns - -func (cmns *ConnManagerNotifieeStub) TrimOpenConns(ctx context.Context) { - cmns.TrimOpenConnsCalled(ctx) -} - -// Notifee - -func (cmns *ConnManagerNotifieeStub) Notifee() network.Notifiee { - 
return cmns -} - -// Listen - -func (cmns *ConnManagerNotifieeStub) Listen(netw network.Network, ma multiaddr.Multiaddr) { - cmns.ListenCalled(netw, ma) -} - -// ListenClose - -func (cmns *ConnManagerNotifieeStub) ListenClose(netw network.Network, ma multiaddr.Multiaddr) { - cmns.ListenCloseCalled(netw, ma) -} - -// Connected - -func (cmns *ConnManagerNotifieeStub) Connected(netw network.Network, conn network.Conn) { - cmns.ConnectedCalled(netw, conn) -} - -// Disconnected - -func (cmns *ConnManagerNotifieeStub) Disconnected(netw network.Network, conn network.Conn) { - cmns.DisconnectedCalled(netw, conn) -} - -// OpenedStream - -func (cmns *ConnManagerNotifieeStub) OpenedStream(netw network.Network, stream network.Stream) { - cmns.OpenedStreamCalled(netw, stream) -} - -// ClosedStream - -func (cmns *ConnManagerNotifieeStub) ClosedStream(netw network.Network, stream network.Stream) { - cmns.ClosedStreamCalled(netw, stream) -} diff --git a/p2p/mock/connStub.go b/p2p/mock/connStub.go deleted file mode 100644 index fca91f61b9a..00000000000 --- a/p2p/mock/connStub.go +++ /dev/null @@ -1,103 +0,0 @@ -package mock - -import ( - "context" - - libp2pCrypto "github.com/libp2p/go-libp2p-core/crypto" - "github.com/libp2p/go-libp2p-core/network" - "github.com/libp2p/go-libp2p-core/peer" - "github.com/multiformats/go-multiaddr" -) - -// ConnStub - -type ConnStub struct { - IDCalled func() string - CloseCalled func() error - LocalPeerCalled func() peer.ID - LocalPrivateKeyCalled func() libp2pCrypto.PrivKey - RemotePeerCalled func() peer.ID - RemotePublicKeyCalled func() libp2pCrypto.PubKey - LocalMultiaddrCalled func() multiaddr.Multiaddr - RemoteMultiaddrCalled func() multiaddr.Multiaddr - NewStreamCalled func(ctx context.Context) (network.Stream, error) - GetStreamsCalled func() []network.Stream - StatCalled func() network.ConnStats - ScopeCalled func() network.ConnScope -} - -// ID - -func (cs *ConnStub) ID() string { - if cs.IDCalled != nil { - return cs.IDCalled() - } - - return "" -} - -// Close - -func (cs *ConnStub) Close() error { - return cs.CloseCalled() -} - -// LocalPeer - -func (cs *ConnStub) LocalPeer() peer.ID { - return cs.LocalPeerCalled() -} - -// LocalPrivateKey - -func (cs *ConnStub) LocalPrivateKey() libp2pCrypto.PrivKey { - return cs.LocalPrivateKeyCalled() -} - -// RemotePeer - -func (cs *ConnStub) RemotePeer() peer.ID { - return cs.RemotePeerCalled() -} - -// RemotePublicKey - -func (cs *ConnStub) RemotePublicKey() libp2pCrypto.PubKey { - return cs.RemotePublicKeyCalled() -} - -// LocalMultiaddr - -func (cs *ConnStub) LocalMultiaddr() multiaddr.Multiaddr { - return cs.LocalMultiaddrCalled() -} - -// RemoteMultiaddr - -func (cs *ConnStub) RemoteMultiaddr() multiaddr.Multiaddr { - if cs.RemoteMultiaddrCalled != nil { - return cs.RemoteMultiaddrCalled() - } - - ma, _ := multiaddr.NewMultiaddr("/ip4/127.0.0.1/tcp/9999/p2p/16Uiu2HAkw5SNNtSvH1zJiQ6Gc3WoGNSxiyNueRKe6fuAuh57G3Bk") - return ma -} - -// NewStream - -func (cs *ConnStub) NewStream(ctx context.Context) (network.Stream, error) { - return cs.NewStreamCalled(ctx) -} - -// GetStreams - -func (cs *ConnStub) GetStreams() []network.Stream { - return cs.GetStreamsCalled() -} - -// Stat - -func (cs *ConnStub) Stat() network.ConnStats { - if cs.StatCalled != nil { - return cs.StatCalled() - } - - return network.ConnStats{} -} - -// Scope - -func (cs *ConnStub) Scope() network.ConnScope { - if cs.ScopeCalled != nil { - cs.ScopeCalled() - } - - return network.NullScope -} diff --git a/p2p/mock/connectableHostStub.go 
b/p2p/mock/connectableHostStub.go deleted file mode 100644 index 53038610f09..00000000000 --- a/p2p/mock/connectableHostStub.go +++ /dev/null @@ -1,172 +0,0 @@ -package mock - -import ( - "context" - "errors" - - "github.com/libp2p/go-libp2p-core/connmgr" - "github.com/libp2p/go-libp2p-core/event" - "github.com/libp2p/go-libp2p-core/network" - "github.com/libp2p/go-libp2p-core/peer" - "github.com/libp2p/go-libp2p-core/peerstore" - "github.com/libp2p/go-libp2p-core/protocol" - "github.com/multiformats/go-multiaddr" -) - -// ConnectableHostStub - -type ConnectableHostStub struct { - EventBusCalled func() event.Bus - IDCalled func() peer.ID - PeerstoreCalled func() peerstore.Peerstore - AddrsCalled func() []multiaddr.Multiaddr - NetworkCalled func() network.Network - MuxCalled func() protocol.Switch - ConnectCalled func(ctx context.Context, pi peer.AddrInfo) error - SetStreamHandlerCalled func(pid protocol.ID, handler network.StreamHandler) - SetStreamHandlerMatchCalled func(protocol.ID, func(string) bool, network.StreamHandler) - RemoveStreamHandlerCalled func(pid protocol.ID) - NewStreamCalled func(ctx context.Context, p peer.ID, pids ...protocol.ID) (network.Stream, error) - CloseCalled func() error - ConnManagerCalled func() connmgr.ConnManager - ConnectToPeerCalled func(ctx context.Context, address string) error - AddressToPeerInfoCalled func(address string) (*peer.AddrInfo, error) -} - -// EventBus - -func (hs *ConnectableHostStub) EventBus() event.Bus { - if hs.EventBusCalled != nil { - return hs.EventBusCalled() - } - - return &EventBusStub{} -} - -// ConnectToPeer - -func (hs *ConnectableHostStub) ConnectToPeer(ctx context.Context, address string) error { - if hs.ConnectToPeerCalled != nil { - return hs.ConnectToPeerCalled(ctx, address) - } - - return nil -} - -// ID - -func (hs *ConnectableHostStub) ID() peer.ID { - if hs.IDCalled != nil { - return hs.IDCalled() - } - - return "mock pid" -} - -// Peerstore - -func (hs *ConnectableHostStub) Peerstore() peerstore.Peerstore { - if hs.PeerstoreCalled != nil { - return hs.PeerstoreCalled() - } - - return nil -} - -// Addrs - -func (hs *ConnectableHostStub) Addrs() []multiaddr.Multiaddr { - if hs.AddrsCalled != nil { - return hs.AddrsCalled() - } - - return make([]multiaddr.Multiaddr, 0) -} - -// Network - -func (hs *ConnectableHostStub) Network() network.Network { - if hs.NetworkCalled != nil { - return hs.NetworkCalled() - } - - return &NetworkStub{} -} - -// Mux - -func (hs *ConnectableHostStub) Mux() protocol.Switch { - if hs.MuxCalled != nil { - return hs.MuxCalled() - } - - return nil -} - -// Connect - -func (hs *ConnectableHostStub) Connect(ctx context.Context, pi peer.AddrInfo) error { - if hs.ConnectCalled != nil { - return hs.ConnectCalled(ctx, pi) - } - - return nil -} - -// SetStreamHandler - -func (hs *ConnectableHostStub) SetStreamHandler(pid protocol.ID, handler network.StreamHandler) { - if hs.SetStreamHandlerCalled != nil { - hs.SetStreamHandlerCalled(pid, handler) - } -} - -// SetStreamHandlerMatch - -func (hs *ConnectableHostStub) SetStreamHandlerMatch(pid protocol.ID, handler func(string) bool, streamHandler network.StreamHandler) { - if hs.SetStreamHandlerMatchCalled != nil { - hs.SetStreamHandlerMatchCalled(pid, handler, streamHandler) - } -} - -// RemoveStreamHandler - -func (hs *ConnectableHostStub) RemoveStreamHandler(pid protocol.ID) { - if hs.RemoveStreamHandlerCalled != nil { - hs.RemoveStreamHandlerCalled(pid) - } -} - -// NewStream - -func (hs *ConnectableHostStub) NewStream(ctx context.Context, p peer.ID, 
pids ...protocol.ID) (network.Stream, error) { - if hs.NewStreamCalled != nil { - return hs.NewStreamCalled(ctx, p, pids...) - } - - return nil, errors.New("no stream") -} - -// Close - -func (hs *ConnectableHostStub) Close() error { - if hs.CloseCalled != nil { - return hs.CloseCalled() - } - - return nil -} - -// ConnManager - -func (hs *ConnectableHostStub) ConnManager() connmgr.ConnManager { - if hs.ConnManagerCalled != nil { - return hs.ConnManagerCalled() - } - - return nil -} - -// AddressToPeerInfo - -func (hs *ConnectableHostStub) AddressToPeerInfo(address string) (*peer.AddrInfo, error) { - if hs.AddressToPeerInfoCalled != nil { - return hs.AddressToPeerInfoCalled(address) - } - - multiAddr, err := multiaddr.NewMultiaddr(address) - if err != nil { - return nil, err - } - - return peer.AddrInfoFromP2pAddr(multiAddr) -} - -// IsInterfaceNil returns true if there is no value under the interface -func (hs *ConnectableHostStub) IsInterfaceNil() bool { - return hs == nil -} diff --git a/p2p/mock/connectionMonitorStub.go b/p2p/mock/connectionMonitorStub.go deleted file mode 100644 index 13b6702222e..00000000000 --- a/p2p/mock/connectionMonitorStub.go +++ /dev/null @@ -1,96 +0,0 @@ -package mock - -import ( - "github.com/libp2p/go-libp2p-core/network" - "github.com/multiformats/go-multiaddr" -) - -// ConnectionMonitorStub - -type ConnectionMonitorStub struct { - ListenCalled func(netw network.Network, ma multiaddr.Multiaddr) - ListenCloseCalled func(netw network.Network, ma multiaddr.Multiaddr) - ConnectedCalled func(netw network.Network, conn network.Conn) - DisconnectedCalled func(netw network.Network, conn network.Conn) - OpenedStreamCalled func(netw network.Network, stream network.Stream) - ClosedStreamCalled func(netw network.Network, stream network.Stream) - IsConnectedToTheNetworkCalled func(netw network.Network) bool - SetThresholdMinConnectedPeersCalled func(thresholdMinConnectedPeers int, netw network.Network) - ThresholdMinConnectedPeersCalled func() int -} - -// Listen - -func (cms *ConnectionMonitorStub) Listen(netw network.Network, ma multiaddr.Multiaddr) { - if cms.ListenCalled != nil { - cms.ListenCalled(netw, ma) - } -} - -// ListenClose - -func (cms *ConnectionMonitorStub) ListenClose(netw network.Network, ma multiaddr.Multiaddr) { - if cms.ListenCloseCalled != nil { - cms.ListenCloseCalled(netw, ma) - } -} - -// Connected - -func (cms *ConnectionMonitorStub) Connected(netw network.Network, conn network.Conn) { - if cms.ConnectedCalled != nil { - cms.ConnectedCalled(netw, conn) - } -} - -// Disconnected - -func (cms *ConnectionMonitorStub) Disconnected(netw network.Network, conn network.Conn) { - if cms.DisconnectedCalled != nil { - cms.DisconnectedCalled(netw, conn) - } -} - -// OpenedStream - -func (cms *ConnectionMonitorStub) OpenedStream(netw network.Network, stream network.Stream) { - if cms.OpenedStreamCalled != nil { - cms.OpenedStreamCalled(netw, stream) - } -} - -// ClosedStream - -func (cms *ConnectionMonitorStub) ClosedStream(netw network.Network, stream network.Stream) { - if cms.ClosedStreamCalled != nil { - cms.ClosedStreamCalled(netw, stream) - } -} - -// IsConnectedToTheNetwork - -func (cms *ConnectionMonitorStub) IsConnectedToTheNetwork(netw network.Network) bool { - if cms.IsConnectedToTheNetworkCalled != nil { - return cms.IsConnectedToTheNetworkCalled(netw) - } - - return false -} - -// SetThresholdMinConnectedPeers - -func (cms *ConnectionMonitorStub) SetThresholdMinConnectedPeers(thresholdMinConnectedPeers int, netw network.Network) { - if 
cms.SetThresholdMinConnectedPeersCalled != nil { - cms.SetThresholdMinConnectedPeersCalled(thresholdMinConnectedPeers, netw) - } -} - -// ThresholdMinConnectedPeers - -func (cms *ConnectionMonitorStub) ThresholdMinConnectedPeers() int { - if cms.ThresholdMinConnectedPeersCalled != nil { - return cms.ThresholdMinConnectedPeersCalled() - } - - return 0 -} - -// Close - -func (cms *ConnectionMonitorStub) Close() error { - return nil -} - -// IsInterfaceNil - -func (cms *ConnectionMonitorStub) IsInterfaceNil() bool { - return cms == nil -} diff --git a/p2p/mock/connectionsWatcherStub.go b/p2p/mock/connectionsWatcherStub.go deleted file mode 100644 index c6479167ae4..00000000000 --- a/p2p/mock/connectionsWatcherStub.go +++ /dev/null @@ -1,30 +0,0 @@ -package mock - -import "github.com/ElrondNetwork/elrond-go-core/core" - -// ConnectionsWatcherStub - -type ConnectionsWatcherStub struct { - NewKnownConnectionCalled func(pid core.PeerID, connection string) - CloseCalled func() error -} - -// NewKnownConnection - -func (stub *ConnectionsWatcherStub) NewKnownConnection(pid core.PeerID, connection string) { - if stub.NewKnownConnectionCalled != nil { - stub.NewKnownConnectionCalled(pid, connection) - } -} - -// Close - -func (stub *ConnectionsWatcherStub) Close() error { - if stub.CloseCalled != nil { - return stub.CloseCalled() - } - - return nil -} - -// IsInterfaceNil - -func (stub *ConnectionsWatcherStub) IsInterfaceNil() bool { - return stub == nil -} diff --git a/p2p/mock/contextProviderMock.go b/p2p/mock/contextProviderMock.go deleted file mode 100644 index 7156dc72887..00000000000 --- a/p2p/mock/contextProviderMock.go +++ /dev/null @@ -1,19 +0,0 @@ -package mock - -import ( - "context" -) - -// ContextProviderMock - -type ContextProviderMock struct { -} - -// Context - -func (*ContextProviderMock) Context() context.Context { - panic("implement me") -} - -// IsInterfaceNil returns true if there is no value under the interface -func (c *ContextProviderMock) IsInterfaceNil() bool { - return c == nil -} diff --git a/p2p/mock/eventBusStub.go b/p2p/mock/eventBusStub.go deleted file mode 100644 index e9e245b2bdd..00000000000 --- a/p2p/mock/eventBusStub.go +++ /dev/null @@ -1,41 +0,0 @@ -package mock - -import ( - "reflect" - - "github.com/libp2p/go-libp2p-core/event" -) - -// EventBusStub - -type EventBusStub struct { - SubscribeCalled func(eventType interface{}, opts ...event.SubscriptionOpt) (event.Subscription, error) - EmitterCalled func(eventType interface{}, opts ...event.EmitterOpt) (event.Emitter, error) - GetAllEventTypesCalled func() []reflect.Type -} - -// Subscribe - -func (ebs *EventBusStub) Subscribe(eventType interface{}, opts ...event.SubscriptionOpt) (event.Subscription, error) { - if ebs.SubscribeCalled != nil { - return ebs.SubscribeCalled(eventType, opts...) - } - - return &EventSubscriptionStub{}, nil -} - -// Emitter - -func (ebs *EventBusStub) Emitter(eventType interface{}, opts ...event.EmitterOpt) (event.Emitter, error) { - if ebs.EmitterCalled != nil { - return ebs.EmitterCalled(eventType, opts...) 
- } - - return nil, nil -} - -// GetAllEventTypes - -func (ebs *EventBusStub) GetAllEventTypes() []reflect.Type { - if ebs.GetAllEventTypesCalled != nil { - return ebs.GetAllEventTypesCalled() - } - - return make([]reflect.Type, 0) -} diff --git a/p2p/mock/eventSubscriptionStub.go b/p2p/mock/eventSubscriptionStub.go deleted file mode 100644 index d9713512f2a..00000000000 --- a/p2p/mock/eventSubscriptionStub.go +++ /dev/null @@ -1,25 +0,0 @@ -package mock - -// EventSubscriptionStub - -type EventSubscriptionStub struct { - CloseCalled func() error - OutCalled func() <-chan interface{} -} - -// Close - -func (ess *EventSubscriptionStub) Close() error { - if ess.CloseCalled != nil { - return ess.CloseCalled() - } - - return nil -} - -// Out - -func (ess *EventSubscriptionStub) Out() <-chan interface{} { - if ess.OutCalled != nil { - return ess.OutCalled() - } - - return make(chan interface{}) -} diff --git a/p2p/mock/kadDhtHandlerStub.go b/p2p/mock/kadDhtHandlerStub.go deleted file mode 100644 index 6f105c0213b..00000000000 --- a/p2p/mock/kadDhtHandlerStub.go +++ /dev/null @@ -1,17 +0,0 @@ -package mock - -import "context" - -// KadDhtHandlerStub - -type KadDhtHandlerStub struct { - BootstrapCalled func(ctx context.Context) error -} - -// Bootstrap - -func (kdhs *KadDhtHandlerStub) Bootstrap(ctx context.Context) error { - if kdhs.BootstrapCalled != nil { - return kdhs.BootstrapCalled(ctx) - } - - return nil -} diff --git a/p2p/mock/kadSharderStub.go b/p2p/mock/kadSharderStub.go deleted file mode 100644 index 0e16e59ea9e..00000000000 --- a/p2p/mock/kadSharderStub.go +++ /dev/null @@ -1,64 +0,0 @@ -package mock - -import ( - "github.com/ElrondNetwork/elrond-go-core/core" - "github.com/ElrondNetwork/elrond-go/p2p" - "github.com/libp2p/go-libp2p-core/peer" -) - -// KadSharderStub - -type KadSharderStub struct { - ComputeEvictListCalled func(pidList []peer.ID) []peer.ID - HasCalled func(pid peer.ID, list []peer.ID) bool - SetPeerShardResolverCalled func(psp p2p.PeerShardResolver) error - SetSeedersCalled func(addresses []string) - IsSeederCalled func(pid core.PeerID) bool -} - -// ComputeEvictionList - -func (kss *KadSharderStub) ComputeEvictionList(pidList []peer.ID) []peer.ID { - if kss.ComputeEvictListCalled != nil { - return kss.ComputeEvictListCalled(pidList) - } - - return make([]peer.ID, 0) -} - -// Has - -func (kss *KadSharderStub) Has(pid peer.ID, list []peer.ID) bool { - if kss.HasCalled != nil { - return kss.HasCalled(pid, list) - } - - return false -} - -// SetPeerShardResolver - -func (kss *KadSharderStub) SetPeerShardResolver(psp p2p.PeerShardResolver) error { - if kss.SetPeerShardResolverCalled != nil { - return kss.SetPeerShardResolverCalled(psp) - } - - return nil -} - -// SetSeeders - -func (kss *KadSharderStub) SetSeeders(addresses []string) { - if kss.SetSeedersCalled != nil { - kss.SetSeedersCalled(addresses) - } -} - -// IsSeeder - -func (kss *KadSharderStub) IsSeeder(pid core.PeerID) bool { - if kss.IsSeederCalled != nil { - return kss.IsSeederCalled(pid) - } - - return false -} - -// IsInterfaceNil - -func (kss *KadSharderStub) IsInterfaceNil() bool { - return kss == nil -} diff --git a/p2p/mock/marshalizerStub.go b/p2p/mock/marshalizerStub.go deleted file mode 100644 index 493f0313201..00000000000 --- a/p2p/mock/marshalizerStub.go +++ /dev/null @@ -1,22 +0,0 @@ -package mock - -// MarshalizerStub - -type MarshalizerStub struct { - MarshalCalled func(obj interface{}) ([]byte, error) - UnmarshalCalled func(obj interface{}, buff []byte) error -} - -// Marshal - -func (ms 
*MarshalizerStub) Marshal(obj interface{}) ([]byte, error) { - return ms.MarshalCalled(obj) -} - -// Unmarshal - -func (ms *MarshalizerStub) Unmarshal(obj interface{}, buff []byte) error { - return ms.UnmarshalCalled(obj, buff) -} - -// IsInterfaceNil - -func (ms *MarshalizerStub) IsInterfaceNil() bool { - return ms == nil -} diff --git a/p2p/mock/messageProcessorStub.go b/p2p/mock/messageProcessorStub.go deleted file mode 100644 index 9f404d38a74..00000000000 --- a/p2p/mock/messageProcessorStub.go +++ /dev/null @@ -1,25 +0,0 @@ -package mock - -import ( - "github.com/ElrondNetwork/elrond-go-core/core" - "github.com/ElrondNetwork/elrond-go/p2p" -) - -// MessageProcessorStub - -type MessageProcessorStub struct { - ProcessMessageCalled func(message p2p.MessageP2P, fromConnectedPeer core.PeerID) error -} - -// ProcessReceivedMessage - -func (mps *MessageProcessorStub) ProcessReceivedMessage(message p2p.MessageP2P, fromConnectedPeer core.PeerID) error { - if mps.ProcessMessageCalled != nil { - return mps.ProcessMessageCalled(message, fromConnectedPeer) - } - - return nil -} - -// IsInterfaceNil returns true if there is no value under the interface -func (mps *MessageProcessorStub) IsInterfaceNil() bool { - return mps == nil -} diff --git a/p2p/mock/multiaddrStub.go b/p2p/mock/multiaddrStub.go deleted file mode 100644 index 4a37f423b1c..00000000000 --- a/p2p/mock/multiaddrStub.go +++ /dev/null @@ -1,137 +0,0 @@ -package mock - -import "github.com/multiformats/go-multiaddr" - -// MultiaddrStub - -type MultiaddrStub struct { - MarshalJSONCalled func() ([]byte, error) - UnmarshalJSONCalled func(bytes []byte) error - MarshalTextCalled func() (text []byte, err error) - UnmarshalTextCalled func(text []byte) error - MarshalBinaryCalled func() (data []byte, err error) - UnmarshalBinaryCalled func(data []byte) error - EqualCalled func(multiaddr multiaddr.Multiaddr) bool - BytesCalled func() []byte - StringCalled func() string - ProtocolsCalled func() []multiaddr.Protocol - EncapsulateCalled func(multiaddr multiaddr.Multiaddr) multiaddr.Multiaddr - DecapsulateCalled func(multiaddr multiaddr.Multiaddr) multiaddr.Multiaddr - ValueForProtocolCalled func(code int) (string, error) -} - -// MarshalJSON - -func (mas *MultiaddrStub) MarshalJSON() ([]byte, error) { - if mas.MarshalJSONCalled != nil { - return mas.MarshalJSONCalled() - } - - return nil, nil -} - -// UnmarshalJSON - -func (mas *MultiaddrStub) UnmarshalJSON(bytes []byte) error { - if mas.UnmarshalJSONCalled != nil { - return mas.UnmarshalJSONCalled(bytes) - } - - return nil -} - -// MarshalText - -func (mas *MultiaddrStub) MarshalText() (text []byte, err error) { - if mas.MarshalTextCalled != nil { - return mas.MarshalTextCalled() - } - - return nil, err -} - -// UnmarshalText - -func (mas *MultiaddrStub) UnmarshalText(text []byte) error { - if mas.UnmarshalTextCalled != nil { - return mas.UnmarshalTextCalled(text) - } - - return nil -} - -// MarshalBinary - -func (mas *MultiaddrStub) MarshalBinary() (data []byte, err error) { - if mas.MarshalBinaryCalled != nil { - return mas.MarshalBinaryCalled() - } - - return nil, nil -} - -// UnmarshalBinary - -func (mas *MultiaddrStub) UnmarshalBinary(data []byte) error { - if mas.UnmarshalBinaryCalled != nil { - return mas.UnmarshalBinaryCalled(data) - } - - return nil -} - -// Equal - -func (mas *MultiaddrStub) Equal(multiaddr multiaddr.Multiaddr) bool { - if mas.EqualCalled != nil { - return mas.EqualCalled(multiaddr) - } - - return false -} - -// Bytes - -func (mas *MultiaddrStub) Bytes() []byte { - if 
mas.BytesCalled != nil { - return mas.BytesCalled() - } - - return nil -} - -// String - -func (mas *MultiaddrStub) String() string { - if mas.StringCalled != nil { - return mas.StringCalled() - } - - return "" -} - -// Protocols - -func (mas *MultiaddrStub) Protocols() []multiaddr.Protocol { - if mas.ProtocolsCalled != nil { - return mas.ProtocolsCalled() - } - - return nil -} - -// Encapsulate - -func (mas *MultiaddrStub) Encapsulate(multiaddr multiaddr.Multiaddr) multiaddr.Multiaddr { - if mas.EncapsulateCalled != nil { - return mas.EncapsulateCalled(multiaddr) - } - - return nil -} - -// Decapsulate - -func (mas *MultiaddrStub) Decapsulate(multiaddr multiaddr.Multiaddr) multiaddr.Multiaddr { - if mas.DecapsulateCalled != nil { - return mas.DecapsulateCalled(multiaddr) - } - - return nil -} - -// ValueForProtocol - -func (mas *MultiaddrStub) ValueForProtocol(code int) (string, error) { - if mas.ValueForProtocolCalled != nil { - return mas.ValueForProtocolCalled(code) - } - - return "", nil -} diff --git a/p2p/mock/networkShardingCollectorMock.go b/p2p/mock/networkShardingCollectorMock.go deleted file mode 100644 index 750f3dbffb6..00000000000 --- a/p2p/mock/networkShardingCollectorMock.go +++ /dev/null @@ -1,73 +0,0 @@ -package mock - -import ( - "sync" - - "github.com/ElrondNetwork/elrond-go-core/core" -) - -type networkShardingCollectorMock struct { - mutPeerIdPkMap sync.RWMutex - peerIdPkMap map[core.PeerID][]byte - - mutFallbackPkShardMap sync.RWMutex - fallbackPkShardMap map[string]uint32 - - mutFallbackPidShardMap sync.RWMutex - fallbackPidShardMap map[string]uint32 - - mutPeerIdSubType sync.RWMutex - peerIdSubType map[core.PeerID]uint32 -} - -// NewNetworkShardingCollectorMock - -func NewNetworkShardingCollectorMock() *networkShardingCollectorMock { - return &networkShardingCollectorMock{ - peerIdPkMap: make(map[core.PeerID][]byte), - peerIdSubType: make(map[core.PeerID]uint32), - fallbackPkShardMap: make(map[string]uint32), - fallbackPidShardMap: make(map[string]uint32), - } -} - -// UpdatePeerIdPublicKey - -func (nscm *networkShardingCollectorMock) UpdatePeerIDInfo(pid core.PeerID, pk []byte, shardID uint32) { - nscm.mutPeerIdPkMap.Lock() - nscm.peerIdPkMap[pid] = pk - nscm.mutPeerIdPkMap.Unlock() - - if shardID == core.AllShardId { - return - } - - nscm.mutFallbackPkShardMap.Lock() - nscm.fallbackPkShardMap[string(pk)] = shardID - nscm.mutFallbackPkShardMap.Unlock() - - nscm.mutFallbackPidShardMap.Lock() - nscm.fallbackPidShardMap[string(pid)] = shardID - nscm.mutFallbackPidShardMap.Unlock() -} - -// PutPeerIdSubType - -func (nscm *networkShardingCollectorMock) PutPeerIdSubType(pid core.PeerID, peerSubType core.P2PPeerSubType) { - nscm.mutPeerIdSubType.Lock() - nscm.peerIdSubType[pid] = uint32(peerSubType) - nscm.mutPeerIdSubType.Unlock() -} - -// GetPeerInfo - -func (nscm *networkShardingCollectorMock) GetPeerInfo(pid core.PeerID) core.P2PPeerInfo { - nscm.mutPeerIdSubType.Lock() - defer nscm.mutPeerIdSubType.Unlock() - - return core.P2PPeerInfo{ - PeerType: core.ObserverPeer, - PeerSubType: core.P2PPeerSubType(nscm.peerIdSubType[pid]), - } -} - -// IsInterfaceNil - -func (nscm *networkShardingCollectorMock) IsInterfaceNil() bool { - return nscm == nil -} diff --git a/p2p/mock/networkStub.go b/p2p/mock/networkStub.go deleted file mode 100644 index e6958410075..00000000000 --- a/p2p/mock/networkStub.go +++ /dev/null @@ -1,140 +0,0 @@ -package mock - -import ( - "context" - "errors" - - "github.com/jbenet/goprocess" - "github.com/libp2p/go-libp2p-core/network" - 
"github.com/libp2p/go-libp2p-core/peer" - "github.com/libp2p/go-libp2p-core/peerstore" - "github.com/multiformats/go-multiaddr" -) - -// NetworkStub - -type NetworkStub struct { - ConnsToPeerCalled func(p peer.ID) []network.Conn - ConnsCalled func() []network.Conn - ConnectednessCalled func(peer.ID) network.Connectedness - NotifyCalled func(network.Notifiee) - StopNotifyCalled func(network.Notifiee) - PeersCall func() []peer.ID - ClosePeerCall func(peer.ID) error - ResourceManagerCalled func() network.ResourceManager -} - -// ResourceManager - -func (ns *NetworkStub) ResourceManager() network.ResourceManager { - if ns.ResourceManagerCalled != nil { - return ns.ResourceManagerCalled() - } - - return nil -} - -// Peerstore - -func (ns *NetworkStub) Peerstore() peerstore.Peerstore { - return nil -} - -// LocalPeer - -func (ns *NetworkStub) LocalPeer() peer.ID { - return "not a peer" -} - -// DialPeer - -func (ns *NetworkStub) DialPeer(_ context.Context, _ peer.ID) (network.Conn, error) { - return nil, errors.New("dial error") -} - -// ClosePeer - -func (ns *NetworkStub) ClosePeer(pid peer.ID) error { - if ns.ClosePeerCall != nil { - return ns.ClosePeerCall(pid) - } - - return nil -} - -// Connectedness - -func (ns *NetworkStub) Connectedness(pid peer.ID) network.Connectedness { - if ns.ConnectednessCalled != nil { - return ns.ConnectednessCalled(pid) - } - - return network.NotConnected -} - -// Peers - -func (ns *NetworkStub) Peers() []peer.ID { - if ns.PeersCall != nil { - return ns.PeersCall() - } - - return make([]peer.ID, 0) -} - -// Conns - -func (ns *NetworkStub) Conns() []network.Conn { - if ns.ConnsCalled != nil { - return ns.ConnsCalled() - } - - return make([]network.Conn, 0) -} - -// ConnsToPeer - -func (ns *NetworkStub) ConnsToPeer(p peer.ID) []network.Conn { - if ns.ConnsToPeerCalled != nil { - return ns.ConnsToPeerCalled(p) - } - - return make([]network.Conn, 0) -} - -// Notify - -func (ns *NetworkStub) Notify(notifee network.Notifiee) { - if ns.NotifyCalled != nil { - ns.NotifyCalled(notifee) - } -} - -// StopNotify - -func (ns *NetworkStub) StopNotify(notifee network.Notifiee) { - if ns.StopNotifyCalled != nil { - ns.StopNotifyCalled(notifee) - } -} - -// Close - -func (ns *NetworkStub) Close() error { - return nil -} - -// SetStreamHandler - -func (ns *NetworkStub) SetStreamHandler(network.StreamHandler) {} - -// NewStream - -func (ns *NetworkStub) NewStream(context.Context, peer.ID) (network.Stream, error) { - return nil, errors.New("new stream error") -} - -// Listen - -func (ns *NetworkStub) Listen(...multiaddr.Multiaddr) error { - return nil -} - -// ListenAddresses - -func (ns *NetworkStub) ListenAddresses() []multiaddr.Multiaddr { - return make([]multiaddr.Multiaddr, 0) -} - -// InterfaceListenAddresses - -func (ns *NetworkStub) InterfaceListenAddresses() ([]multiaddr.Multiaddr, error) { - return make([]multiaddr.Multiaddr, 0), nil -} - -// Process - -func (ns *NetworkStub) Process() goprocess.Process { - return nil -} diff --git a/p2p/mock/p2pMessageMock.go b/p2p/mock/p2pMessageMock.go deleted file mode 100644 index db267138a83..00000000000 --- a/p2p/mock/p2pMessageMock.go +++ /dev/null @@ -1,68 +0,0 @@ -package mock - -import ( - "github.com/ElrondNetwork/elrond-go-core/core" -) - -// P2PMessageMock - -type P2PMessageMock struct { - FromField []byte - DataField []byte - SeqNoField []byte - TopicField string - SignatureField []byte - KeyField []byte - PeerField core.PeerID - PayloadField []byte - TimestampField int64 -} - -// From - -func (msg *P2PMessageMock) From() 
[]byte { - return msg.FromField -} - -// Data - -func (msg *P2PMessageMock) Data() []byte { - return msg.DataField -} - -// SeqNo - -func (msg *P2PMessageMock) SeqNo() []byte { - return msg.SeqNoField -} - -// Topic - -func (msg *P2PMessageMock) Topic() string { - return msg.TopicField -} - -// Signature - -func (msg *P2PMessageMock) Signature() []byte { - return msg.SignatureField -} - -// Key - -func (msg *P2PMessageMock) Key() []byte { - return msg.KeyField -} - -// Peer - -func (msg *P2PMessageMock) Peer() core.PeerID { - return msg.PeerField -} - -// Timestamp - -func (msg *P2PMessageMock) Timestamp() int64 { - return msg.TimestampField -} - -// Payload - -func (msg *P2PMessageMock) Payload() []byte { - return msg.PayloadField -} - -// IsInterfaceNil returns true if there is no value under the interface -func (msg *P2PMessageMock) IsInterfaceNil() bool { - return msg == nil -} diff --git a/p2p/mock/peerDenialEvaluatorStub.go b/p2p/mock/peerDenialEvaluatorStub.go deleted file mode 100644 index a78f84891ad..00000000000 --- a/p2p/mock/peerDenialEvaluatorStub.go +++ /dev/null @@ -1,32 +0,0 @@ -package mock - -import ( - "time" - - "github.com/ElrondNetwork/elrond-go-core/core" -) - -// PeerDenialEvaluatorStub - -type PeerDenialEvaluatorStub struct { - UpsertPeerIDCalled func(pid core.PeerID, duration time.Duration) error - IsDeniedCalled func(pid core.PeerID) bool -} - -// UpsertPeerID - -func (pdes *PeerDenialEvaluatorStub) UpsertPeerID(pid core.PeerID, duration time.Duration) error { - if pdes.UpsertPeerIDCalled != nil { - return pdes.UpsertPeerIDCalled(pid, duration) - } - - return nil -} - -// IsDenied - -func (pdes *PeerDenialEvaluatorStub) IsDenied(pid core.PeerID) bool { - return pdes.IsDeniedCalled(pid) -} - -// IsInterfaceNil - -func (pdes *PeerDenialEvaluatorStub) IsInterfaceNil() bool { - return pdes == nil -} diff --git a/p2p/mock/peerDiscovererStub.go b/p2p/mock/peerDiscovererStub.go deleted file mode 100644 index e81417aef06..00000000000 --- a/p2p/mock/peerDiscovererStub.go +++ /dev/null @@ -1,22 +0,0 @@ -package mock - -// PeerDiscovererStub - -type PeerDiscovererStub struct { - BootstrapCalled func() error - CloseCalled func() error -} - -// Bootstrap - -func (pds *PeerDiscovererStub) Bootstrap() error { - return pds.BootstrapCalled() -} - -// Name - -func (pds *PeerDiscovererStub) Name() string { - return "PeerDiscovererStub" -} - -// IsInterfaceNil returns true if there is no value under the interface -func (pds *PeerDiscovererStub) IsInterfaceNil() bool { - return pds == nil -} diff --git a/p2p/mock/peerShardResolverStub.go b/p2p/mock/peerShardResolverStub.go deleted file mode 100644 index 28e9f6baa26..00000000000 --- a/p2p/mock/peerShardResolverStub.go +++ /dev/null @@ -1,20 +0,0 @@ -package mock - -import ( - "github.com/ElrondNetwork/elrond-go-core/core" -) - -// PeerShardResolverStub - -type PeerShardResolverStub struct { - GetPeerInfoCalled func(pid core.PeerID) core.P2PPeerInfo -} - -// GetPeerInfo - -func (psrs *PeerShardResolverStub) GetPeerInfo(pid core.PeerID) core.P2PPeerInfo { - return psrs.GetPeerInfoCalled(pid) -} - -// IsInterfaceNil - -func (psrs *PeerShardResolverStub) IsInterfaceNil() bool { - return psrs == nil -} diff --git a/p2p/mock/peerstoreStub.go b/p2p/mock/peerstoreStub.go deleted file mode 100644 index ed068deb7f4..00000000000 --- a/p2p/mock/peerstoreStub.go +++ /dev/null @@ -1,278 +0,0 @@ -package mock - -import ( - "context" - "time" - - libp2pCrypto "github.com/libp2p/go-libp2p-core/crypto" - "github.com/libp2p/go-libp2p-core/peer" - 
"github.com/multiformats/go-multiaddr" -) - -// PeerstoreStub - -type PeerstoreStub struct { - CloseCalled func() error - AddAddrCalled func(p peer.ID, addr multiaddr.Multiaddr, ttl time.Duration) - AddAddrsCalled func(p peer.ID, addrs []multiaddr.Multiaddr, ttl time.Duration) - SetAddrCalled func(p peer.ID, addr multiaddr.Multiaddr, ttl time.Duration) - SetAddrsCalled func(p peer.ID, addrs []multiaddr.Multiaddr, ttl time.Duration) - UpdateAddrsCalled func(p peer.ID, oldTTL time.Duration, newTTL time.Duration) - AddrsCalled func(p peer.ID) []multiaddr.Multiaddr - AddrStreamCalled func(ctx context.Context, id peer.ID) <-chan multiaddr.Multiaddr - ClearAddrsCalled func(p peer.ID) - PeersWithAddrsCalled func() peer.IDSlice - PubKeyCalled func(id peer.ID) libp2pCrypto.PubKey - AddPubKeyCalled func(id peer.ID, key libp2pCrypto.PubKey) error - PrivKeyCalled func(id peer.ID) libp2pCrypto.PrivKey - AddPrivKeyCalled func(id peer.ID, key libp2pCrypto.PrivKey) error - PeersWithKeysCalled func() peer.IDSlice - GetCalled func(p peer.ID, key string) (interface{}, error) - PutCalled func(p peer.ID, key string, val interface{}) error - RecordLatencyCalled func(id peer.ID, duration time.Duration) - LatencyEWMACalled func(id peer.ID) time.Duration - GetProtocolsCalled func(id peer.ID) ([]string, error) - AddProtocolsCalled func(id peer.ID, s ...string) error - SetProtocolsCalled func(id peer.ID, s ...string) error - RemoveProtocolsCalled func(id peer.ID, s ...string) error - SupportsProtocolsCalled func(id peer.ID, s ...string) ([]string, error) - FirstSupportedProtocolCalled func(id peer.ID, s ...string) (string, error) - PeerInfoCalled func(id peer.ID) peer.AddrInfo - PeersCalled func() peer.IDSlice - RemovePeerCalled func(id peer.ID) -} - -// Close - -func (ps *PeerstoreStub) Close() error { - if ps.CloseCalled != nil { - return ps.CloseCalled() - } - - return nil -} - -// AddAddr - -func (ps *PeerstoreStub) AddAddr(p peer.ID, addr multiaddr.Multiaddr, ttl time.Duration) { - if ps.AddAddrCalled != nil { - ps.AddAddrCalled(p, addr, ttl) - } -} - -// AddAddrs - -func (ps *PeerstoreStub) AddAddrs(p peer.ID, addrs []multiaddr.Multiaddr, ttl time.Duration) { - if ps.AddAddrsCalled != nil { - ps.AddAddrsCalled(p, addrs, ttl) - } -} - -// SetAddr - -func (ps *PeerstoreStub) SetAddr(p peer.ID, addr multiaddr.Multiaddr, ttl time.Duration) { - if ps.SetAddrCalled != nil { - ps.SetAddrCalled(p, addr, ttl) - } -} - -// SetAddrs - -func (ps *PeerstoreStub) SetAddrs(p peer.ID, addrs []multiaddr.Multiaddr, ttl time.Duration) { - if ps.SetAddrsCalled != nil { - ps.SetAddrsCalled(p, addrs, ttl) - } -} - -// UpdateAddrs - -func (ps *PeerstoreStub) UpdateAddrs(p peer.ID, oldTTL time.Duration, newTTL time.Duration) { - if ps.UpdateAddrsCalled != nil { - ps.UpdateAddrsCalled(p, oldTTL, newTTL) - } -} - -// Addrs - -func (ps *PeerstoreStub) Addrs(p peer.ID) []multiaddr.Multiaddr { - if ps.AddrsCalled != nil { - return ps.AddrsCalled(p) - } - - return nil -} - -// AddrStream - -func (ps *PeerstoreStub) AddrStream(ctx context.Context, id peer.ID) <-chan multiaddr.Multiaddr { - if ps.AddrStreamCalled != nil { - return ps.AddrStreamCalled(ctx, id) - } - - return nil -} - -// ClearAddrs - -func (ps *PeerstoreStub) ClearAddrs(p peer.ID) { - if ps.ClearAddrsCalled != nil { - ps.ClearAddrsCalled(p) - } -} - -// PeersWithAddrs - -func (ps *PeerstoreStub) PeersWithAddrs() peer.IDSlice { - if ps.PeersWithAddrsCalled != nil { - return ps.PeersWithAddrsCalled() - } - - return nil -} - -// PubKey - -func (ps *PeerstoreStub) PubKey(id 
peer.ID) libp2pCrypto.PubKey { - if ps.PubKeyCalled != nil { - return ps.PubKeyCalled(id) - } - - return nil -} - -// AddPubKey - -func (ps *PeerstoreStub) AddPubKey(id peer.ID, key libp2pCrypto.PubKey) error { - if ps.AddPubKeyCalled != nil { - return ps.AddPubKeyCalled(id, key) - } - - return nil -} - -// PrivKey - -func (ps *PeerstoreStub) PrivKey(id peer.ID) libp2pCrypto.PrivKey { - if ps.PrivKeyCalled != nil { - return ps.PrivKeyCalled(id) - } - - return nil -} - -// AddPrivKey - -func (ps *PeerstoreStub) AddPrivKey(id peer.ID, key libp2pCrypto.PrivKey) error { - if ps.AddPrivKeyCalled != nil { - return ps.AddPrivKeyCalled(id, key) - } - - return nil -} - -// PeersWithKeys - -func (ps *PeerstoreStub) PeersWithKeys() peer.IDSlice { - if ps.PeersWithKeysCalled != nil { - return ps.PeersWithKeysCalled() - } - - return nil -} - -// Get - -func (ps *PeerstoreStub) Get(p peer.ID, key string) (interface{}, error) { - if ps.GetCalled != nil { - return ps.GetCalled(p, key) - } - - return nil, nil -} - -// Put - -func (ps *PeerstoreStub) Put(p peer.ID, key string, val interface{}) error { - if ps.PutCalled != nil { - return ps.PutCalled(p, key, val) - } - - return nil -} - -// RecordLatency - -func (ps *PeerstoreStub) RecordLatency(id peer.ID, duration time.Duration) { - if ps.RecordLatencyCalled != nil { - ps.RecordLatencyCalled(id, duration) - } -} - -// LatencyEWMA - -func (ps *PeerstoreStub) LatencyEWMA(id peer.ID) time.Duration { - if ps.LatencyEWMACalled != nil { - return ps.LatencyEWMACalled(id) - } - - return 0 -} - -// GetProtocols - -func (ps *PeerstoreStub) GetProtocols(id peer.ID) ([]string, error) { - if ps.GetProtocolsCalled != nil { - return ps.GetProtocolsCalled(id) - } - - return nil, nil -} - -// AddProtocols - -func (ps *PeerstoreStub) AddProtocols(id peer.ID, s ...string) error { - if ps.AddProtocolsCalled != nil { - return ps.AddProtocolsCalled(id, s...) - } - - return nil -} - -// SetProtocols - -func (ps *PeerstoreStub) SetProtocols(id peer.ID, s ...string) error { - if ps.SetProtocolsCalled != nil { - return ps.SetProtocolsCalled(id, s...) - } - - return nil -} - -// RemoveProtocols - -func (ps *PeerstoreStub) RemoveProtocols(id peer.ID, s ...string) error { - if ps.RemoveProtocolsCalled != nil { - return ps.RemoveProtocolsCalled(id, s...) - } - - return nil -} - -// SupportsProtocols - -func (ps *PeerstoreStub) SupportsProtocols(id peer.ID, s ...string) ([]string, error) { - if ps.SupportsProtocolsCalled != nil { - return ps.SupportsProtocolsCalled(id, s...) - } - - return nil, nil -} - -// FirstSupportedProtocol - -func (ps *PeerstoreStub) FirstSupportedProtocol(id peer.ID, s ...string) (string, error) { - if ps.FirstSupportedProtocolCalled != nil { - return ps.FirstSupportedProtocolCalled(id, s...) 
- } - - return "", nil -} - -// PeerInfo - -func (ps *PeerstoreStub) PeerInfo(id peer.ID) peer.AddrInfo { - if ps.PeerInfoCalled != nil { - return ps.PeerInfoCalled(id) - } - - return peer.AddrInfo{} -} - -// Peers - -func (ps *PeerstoreStub) Peers() peer.IDSlice { - if ps.PeersCalled != nil { - return ps.PeersCalled() - } - - return nil -} - -// RemovePeer - -func (ps *PeerstoreStub) RemovePeer(id peer.ID) { - if ps.RemovePeerCalled != nil { - ps.RemovePeerCalled(id) - } -} diff --git a/p2p/mock/reconnecterStub.go b/p2p/mock/reconnecterStub.go deleted file mode 100644 index d3bbaa82e3e..00000000000 --- a/p2p/mock/reconnecterStub.go +++ /dev/null @@ -1,22 +0,0 @@ -package mock - -import "context" - -// ReconnecterStub - -type ReconnecterStub struct { - ReconnectToNetworkCalled func(ctx context.Context) - PauseCall func() - ResumeCall func() -} - -// ReconnectToNetwork - -func (rs *ReconnecterStub) ReconnectToNetwork(ctx context.Context) { - if rs.ReconnectToNetworkCalled != nil { - rs.ReconnectToNetworkCalled(ctx) - } -} - -// IsInterfaceNil returns true if there is no value under the interface -func (rs *ReconnecterStub) IsInterfaceNil() bool { - return rs == nil -} diff --git a/p2p/mock/sharderStub.go b/p2p/mock/sharderStub.go deleted file mode 100644 index b1783ac1717..00000000000 --- a/p2p/mock/sharderStub.go +++ /dev/null @@ -1,43 +0,0 @@ -package mock - -import ( - "github.com/ElrondNetwork/elrond-go-core/core" - "github.com/ElrondNetwork/elrond-go/p2p" -) - -// SharderStub - -type SharderStub struct { - SetPeerShardResolverCalled func(psp p2p.PeerShardResolver) error - SetSeedersCalled func(addresses []string) - IsSeederCalled func(pid core.PeerID) bool -} - -// SetPeerShardResolver - -func (ss *SharderStub) SetPeerShardResolver(psp p2p.PeerShardResolver) error { - if ss.SetPeerShardResolverCalled != nil { - return ss.SetPeerShardResolverCalled(psp) - } - - return nil -} - -// SetSeeders - -func (ss *SharderStub) SetSeeders(addresses []string) { - if ss.SetSeedersCalled != nil { - ss.SetSeedersCalled(addresses) - } -} - -// IsSeeder - -func (ss *SharderStub) IsSeeder(pid core.PeerID) bool { - if ss.IsSeederCalled != nil { - return ss.IsSeederCalled(pid) - } - - return false -} - -// IsInterfaceNil - -func (ss *SharderStub) IsInterfaceNil() bool { - return ss == nil -} diff --git a/p2p/mock/streamMock.go b/p2p/mock/streamMock.go deleted file mode 100644 index f6beac99d80..00000000000 --- a/p2p/mock/streamMock.go +++ /dev/null @@ -1,162 +0,0 @@ -package mock - -import ( - "bytes" - "io" - "sync" - "time" - - "github.com/libp2p/go-libp2p-core/network" - "github.com/libp2p/go-libp2p-core/protocol" -) - -type streamMock struct { - mutData sync.Mutex - buffStream *bytes.Buffer - pid protocol.ID - streamClosed bool - canRead bool - conn network.Conn - id string -} - -// NewStreamMock - -func NewStreamMock() *streamMock { - return &streamMock{ - mutData: sync.Mutex{}, - buffStream: new(bytes.Buffer), - streamClosed: false, - canRead: false, - } -} - -// Read - -func (sm *streamMock) Read(p []byte) (int, error) { - // just a mock implementation of blocking read - for { - time.Sleep(time.Millisecond * 10) - - sm.mutData.Lock() - if sm.streamClosed { - sm.mutData.Unlock() - return 0, io.EOF - } - - if sm.canRead { - n, err := sm.buffStream.Read(p) - sm.canRead = false - sm.mutData.Unlock() - - return n, err - } - sm.mutData.Unlock() - } -} - -// Write - -func (sm *streamMock) Write(p []byte) (int, error) { - sm.mutData.Lock() - n, err := sm.buffStream.Write(p) - if err == nil { - sm.canRead = 
true - } - sm.mutData.Unlock() - - return n, err -} - -// Close - -func (sm *streamMock) Close() error { - sm.mutData.Lock() - defer sm.mutData.Unlock() - - sm.streamClosed = true - return nil -} - -// Reset - -func (sm *streamMock) Reset() error { - sm.mutData.Lock() - defer sm.mutData.Unlock() - - sm.buffStream.Reset() - sm.canRead = false - return nil -} - -// SetDeadline - -func (sm *streamMock) SetDeadline(time.Time) error { - panic("implement me") -} - -// SetReadDeadline - -func (sm *streamMock) SetReadDeadline(time.Time) error { - panic("implement me") -} - -// SetWriteDeadline - -func (sm *streamMock) SetWriteDeadline(time.Time) error { - panic("implement me") -} - -// Protocol - -func (sm *streamMock) Protocol() protocol.ID { - return sm.pid -} - -// SetProtocol - -func (sm *streamMock) SetProtocol(pid protocol.ID) error { - sm.pid = pid - - return nil -} - -// Stat - -func (sm *streamMock) Stat() network.Stats { - return network.Stats{ - Direction: network.DirOutbound, - } -} - -// Conn - -func (sm *streamMock) Conn() network.Conn { - return sm.conn -} - -// SetConn - -func (sm *streamMock) SetConn(conn network.Conn) { - sm.conn = conn -} - -// ID - -func (sm *streamMock) ID() string { - return sm.id -} - -// SetID - -func (sm *streamMock) SetID(id string) { - sm.id = id -} - -// CloseWrite - -func (sm *streamMock) CloseWrite() error { - sm.mutData.Lock() - defer sm.mutData.Unlock() - - sm.streamClosed = true - return nil -} - -// CloseRead - -func (sm *streamMock) CloseRead() error { - sm.mutData.Lock() - defer sm.mutData.Unlock() - - sm.streamClosed = true - return nil -} - -// Scope - -func (sm *streamMock) Scope() network.StreamScope { - return network.NullScope -} diff --git a/p2p/mock/syncTimerStub.go b/p2p/mock/syncTimerStub.go deleted file mode 100644 index aa7458b2a15..00000000000 --- a/p2p/mock/syncTimerStub.go +++ /dev/null @@ -1,22 +0,0 @@ -package mock - -import "time" - -// SyncTimerStub - -type SyncTimerStub struct { - CurrentTimeCalled func() time.Time -} - -// CurrentTime - -func (sts *SyncTimerStub) CurrentTime() time.Time { - if sts.CurrentTimeCalled != nil { - return sts.CurrentTimeCalled() - } - - return time.Time{} -} - -// IsInterfaceNil - -func (sts *SyncTimerStub) IsInterfaceNil() bool { - return sts == nil -} diff --git a/p2p/p2p.go b/p2p/p2p.go index e1692885bc6..f555be50eb1 100644 --- a/p2p/p2p.go +++ b/p2p/p2p.go @@ -1,362 +1,45 @@ package p2p import ( - "context" - "encoding/hex" - "io" - "time" - - "github.com/ElrondNetwork/elrond-go-core/core" - "github.com/libp2p/go-libp2p-core/crypto" - "github.com/libp2p/go-libp2p-core/peer" -) - -const displayLastPidChars = 12 - -const ( - // ListsSharder is the variant that uses lists - ListsSharder = "ListsSharder" - // OneListSharder is the variant that is shard agnostic and uses one list - OneListSharder = "OneListSharder" - // NilListSharder is the variant that will not do connection trimming - NilListSharder = "NilListSharder" -) - -// NodeOperation defines the p2p node operation -type NodeOperation string - -// NormalOperation defines the normal mode operation: either seeder, observer or validator -const NormalOperation NodeOperation = "normal operation" - -// FullArchiveMode defines the node operation as a full archive mode -const FullArchiveMode NodeOperation = "full archive mode" - -const ( - // ConnectionWatcherTypePrint - new connection found will be printed in the log file - ConnectionWatcherTypePrint = "print" - // ConnectionWatcherTypeDisabled - no connection watching should be made - 
ConnectionWatcherTypeDisabled = "disabled" - // ConnectionWatcherTypeEmpty - not set, no connection watching should be made - ConnectionWatcherTypeEmpty = "" + p2p "github.com/ElrondNetwork/elrond-go-p2p" + "github.com/ElrondNetwork/elrond-go-p2p/libp2p" + "github.com/ElrondNetwork/elrond-go-p2p/libp2p/crypto" + "github.com/ElrondNetwork/elrond-go-p2p/message" + "github.com/ElrondNetwork/elrond-go-p2p/peersHolder" + "github.com/ElrondNetwork/elrond-go-p2p/rating" ) -// MessageProcessor is the interface used to describe what a receive message processor should do -// All implementations that will be called from Messenger implementation will need to satisfy this interface -// If the function returns a non nil value, the received message will not be propagated to its connected peers -type MessageProcessor interface { - ProcessReceivedMessage(message MessageP2P, fromConnectedPeer core.PeerID) error - IsInterfaceNil() bool -} +// ArgsNetworkMessenger defines the options used to create a p2p wrapper +type ArgsNetworkMessenger = libp2p.ArgsNetworkMessenger -// SendableData represents the struct used in data throttler implementation -type SendableData struct { - Buff []byte - Topic string - Sk crypto.PrivKey - ID peer.ID +// NewNetworkMessenger creates a libP2P messenger by opening a port on the current machine +func NewNetworkMessenger(args ArgsNetworkMessenger) (p2p.Messenger, error) { + return libp2p.NewNetworkMessenger(args) } -// PeerDiscoverer defines the behaviour of a peer discovery mechanism -type PeerDiscoverer interface { - Bootstrap() error - Name() string - IsInterfaceNil() bool -} - -// Reconnecter defines the behaviour of a network reconnection mechanism -type Reconnecter interface { - ReconnectToNetwork(ctx context.Context) - IsInterfaceNil() bool -} - -// Messenger is the main struct used for communication with other peers -type Messenger interface { - io.Closer - - // ID is the Messenger's unique peer identifier across the network (a - // string). It is derived from the public key of the P2P credentials. - ID() core.PeerID - - // Peers is the list of IDs of peers known to the Messenger. - Peers() []core.PeerID - - // Addresses is the list of addresses that the Messenger is currently bound - // to and listening to. - Addresses() []string - - // ConnectToPeer explicitly connect to a specific peer with a known address (note that the - // address contains the peer ID). This function is usually not called - // manually, because any underlying implementation of the Messenger interface - // should be keeping connections to peers open. - ConnectToPeer(address string) error - - // IsConnected returns true if the Messenger are connected to a specific peer. - IsConnected(peerID core.PeerID) bool +// LocalSyncTimer uses the local system to provide the current time +type LocalSyncTimer = libp2p.LocalSyncTimer - // ConnectedPeers returns the list of IDs of the peers the Messenger is - // currently connected to. - ConnectedPeers() []core.PeerID +// Message is a data holder struct +type Message = message.Message - // ConnectedAddresses returns the list of addresses of the peers to which the - // Messenger is currently connected. 
- ConnectedAddresses() []string + +// DirectConnectionInfo represents the data regarding a new direct connection's info +type DirectConnectionInfo = message.DirectConnectionInfo - - // PeerAddresses returns the known addresses for the provided peer ID - PeerAddresses(pid core.PeerID) []string - - // ConnectedPeersOnTopic returns the IDs of the peers to which the Messenger - // is currently connected, but filtered by a topic they are registered to. - ConnectedPeersOnTopic(topic string) []core.PeerID - - // ConnectedFullHistoryPeersOnTopic returns the IDs of the full history peers to which the Messenger - // is currently connected, but filtered by a topic they are registered to. - ConnectedFullHistoryPeersOnTopic(topic string) []core.PeerID - - // Bootstrap runs the initialization phase which includes peer discovery, - // setting up initial connections and self-announcement in the network. - Bootstrap() error - - // CreateTopic defines a new topic for sending messages, and optionally - // creates a channel in the LoadBalancer for this topic (otherwise, the topic - // will use a default channel). - CreateTopic(name string, createChannelForTopic bool) error - - // HasTopic returns true if the Messenger has declared interest in a topic - // and it is listening to messages referencing it. - HasTopic(name string) bool - - // RegisterMessageProcessor adds the provided MessageProcessor to the list - // of handlers that are invoked whenever a message is received on the - // specified topic. - RegisterMessageProcessor(topic string, identifier string, handler MessageProcessor) error - - // UnregisterAllMessageProcessors removes all the MessageProcessor set by the - // Messenger from the list of registered handlers for the messages on the - // given topic. - UnregisterAllMessageProcessors() error - - // UnregisterMessageProcessor removes the MessageProcessor set by the - // Messenger from the list of registered handlers for the messages on the - // given topic. - UnregisterMessageProcessor(topic string, identifier string) error - - // BroadcastOnChannelBlocking asynchronously waits until it can send a - // message on the channel, but once it is able to, it synchronously sends the - // message, blocking until sending is completed. - BroadcastOnChannelBlocking(channel string, topic string, buff []byte) error - - // BroadcastOnChannel asynchronously sends a message on a given topic - // through a specified channel. - BroadcastOnChannel(channel string, topic string, buff []byte) - - // BroadcastUsingPrivateKey tries to send a byte buffer onto a topic using the topic name as channel - BroadcastUsingPrivateKey(topic string, buff []byte, pid core.PeerID, skBytes []byte) - - // Broadcast is a convenience function that calls BroadcastOnChannelBlocking, - // but implicitly sets the channel to be identical to the specified topic. - Broadcast(topic string, buff []byte) - - // SendToConnectedPeer asynchronously sends a message to a peer directly, - // bypassing pubsub and topics. It opens a new connection with the given - // peer, but reuses a connection and a stream if possible.
- SendToConnectedPeer(topic string, buff []byte, peerID core.PeerID) error - - IsConnectedToTheNetwork() bool - ThresholdMinConnectedPeers() int - SetThresholdMinConnectedPeers(minConnectedPeers int) error - SetPeerShardResolver(peerShardResolver PeerShardResolver) error - SetPeerDenialEvaluator(handler PeerDenialEvaluator) error - GetConnectedPeersInfo() *ConnectedPeersInfo - UnjoinAllTopics() error - Port() int - WaitForConnections(maxWaitingTime time.Duration, minNumOfPeers uint32) - Sign(payload []byte) ([]byte, error) - Verify(payload []byte, pid core.PeerID, signature []byte) error - SignUsingPrivateKey(skBytes []byte, payload []byte) ([]byte, error) - - // IsInterfaceNil returns true if there is no value under the interface - IsInterfaceNil() bool -} - -// MessageP2P defines what a p2p message can do (should return) -type MessageP2P interface { - From() []byte - Data() []byte - Payload() []byte - SeqNo() []byte - Topic() string - Signature() []byte - Key() []byte - Peer() core.PeerID - Timestamp() int64 - IsInterfaceNil() bool -} - -// ChannelLoadBalancer defines what a load balancer that uses chans should do -type ChannelLoadBalancer interface { - AddChannel(channel string) error - RemoveChannel(channel string) error - GetChannelOrDefault(channel string) chan *SendableData - CollectOneElementFromChannels() *SendableData - Close() error - IsInterfaceNil() bool -} - -// DirectSender defines a component that can send direct messages to connected peers -type DirectSender interface { - NextSeqno() []byte - Send(topic string, buff []byte, peer core.PeerID) error - IsInterfaceNil() bool -} - -// PeerDiscoveryFactory defines the factory for peer discoverer implementation -type PeerDiscoveryFactory interface { - CreatePeerDiscoverer() (PeerDiscoverer, error) - IsInterfaceNil() bool -} - -// MessageOriginatorPid will output the message peer id in a pretty format -// If it can, it will display the last displayLastPidChars (12) characters from the pid -func MessageOriginatorPid(msg MessageP2P) string { - return PeerIdToShortString(msg.Peer()) -} - -// PeerIdToShortString trims the first displayLastPidChars characters of the provided peer ID after -// converting the peer ID to string using the Pretty functionality -func PeerIdToShortString(pid core.PeerID) string { - prettyPid := pid.Pretty() - lenPrettyPid := len(prettyPid) - - if lenPrettyPid > displayLastPidChars { - return "..." 
+ prettyPid[lenPrettyPid-displayLastPidChars:] - } - - return prettyPid -} - -// MessageOriginatorSeq will output the sequence number as hex -func MessageOriginatorSeq(msg MessageP2P) string { - return hex.EncodeToString(msg.SeqNo()) -} - -// PeerShardResolver is able to resolve the link between the provided PeerID and the shardID -type PeerShardResolver interface { - GetPeerInfo(pid core.PeerID) core.P2PPeerInfo - IsInterfaceNil() bool -} - -// ConnectedPeersInfo represents the DTO structure used to output the metrics for connected peers -type ConnectedPeersInfo struct { - SelfShardID uint32 - UnknownPeers []string - Seeders []string - IntraShardValidators map[uint32][]string - IntraShardObservers map[uint32][]string - CrossShardValidators map[uint32][]string - CrossShardObservers map[uint32][]string - FullHistoryObservers map[uint32][]string - NumValidatorsOnShard map[uint32]int - NumObserversOnShard map[uint32]int - NumPreferredPeersOnShard map[uint32]int - NumIntraShardValidators int - NumIntraShardObservers int - NumCrossShardValidators int - NumCrossShardObservers int - NumFullHistoryObservers int -} - -// NetworkShardingCollector defines the updating methods used by the network sharding component -// The interface assures that the collected data will be used by the p2p network sharding components -type NetworkShardingCollector interface { - UpdatePeerIDInfo(pid core.PeerID, pk []byte, shardID uint32) - IsInterfaceNil() bool -} - -// SignerVerifier is used in higher level protocol authentication of 2 peers after the basic p2p connection has been made -type SignerVerifier interface { - Sign(message []byte) ([]byte, error) - Verify(message []byte, sig []byte, pk []byte) error - PublicKey() []byte - IsInterfaceNil() bool -} - -// Marshalizer defines the 2 basic operations: serialize (marshal) and deserialize (unmarshal) -type Marshalizer interface { - Marshal(obj interface{}) ([]byte, error) - Unmarshal(obj interface{}, buff []byte) error - IsInterfaceNil() bool -} - -// PreferredPeersHolderHandler defines the behavior of a component able to handle preferred peers operations -type PreferredPeersHolderHandler interface { - PutConnectionAddress(peerID core.PeerID, address string) - PutShardID(peerID core.PeerID, shardID uint32) - Get() map[uint32][]core.PeerID - Contains(peerID core.PeerID) bool - Remove(peerID core.PeerID) - Clear() - IsInterfaceNil() bool -} - -// PeerCounts represents the DTO structure used to output the count metrics for connected peers -type PeerCounts struct { - UnknownPeers int - IntraShardPeers int - CrossShardPeers int -} - -// Sharder defines the eviction computing process of unwanted peers -type Sharder interface { - SetSeeders(addresses []string) - IsSeeder(pid core.PeerID) bool - SetPeerShardResolver(psp PeerShardResolver) error - IsInterfaceNil() bool -} - -// PeerDenialEvaluator defines the behavior of a component that is able to decide if a peer ID is black listed or not -// TODO merge this interface with the PeerShardResolver => P2PProtocolHandler ? 
-// TODO move antiflooding inside network messenger -type PeerDenialEvaluator interface { - IsDenied(pid core.PeerID) bool - UpsertPeerID(pid core.PeerID, duration time.Duration) error - IsInterfaceNil() bool -} - -// ConnectionMonitorWrapper uses a connection monitor but checks if the peer is blacklisted or not -// TODO this should be removed after merging of the PeerShardResolver and BlacklistHandler -type ConnectionMonitorWrapper interface { - CheckConnectionsBlocking() - SetPeerDenialEvaluator(handler PeerDenialEvaluator) error - PeerDenialEvaluator() PeerDenialEvaluator - IsInterfaceNil() bool -} - -// Debugger represent a p2p debugger able to print p2p statistics (messages received/sent per topic) -type Debugger interface { - AddIncomingMessage(topic string, size uint64, isRejected bool) - AddOutgoingMessage(topic string, size uint64, isRejected bool) - Close() error - IsInterfaceNil() bool -} +// ArgPeersRatingHandler is the DTO used to create a new peers rating handler +type ArgPeersRatingHandler = rating.ArgPeersRatingHandler -// SyncTimer represent an entity able to tell the current time -type SyncTimer interface { - CurrentTime() time.Time - IsInterfaceNil() bool +// NewPeersRatingHandler returns a new peers rating handler +func NewPeersRatingHandler(args ArgPeersRatingHandler) (p2p.PeersRatingHandler, error) { + return rating.NewPeersRatingHandler(args) } -// ConnectionsWatcher represent an entity able to watch new connections -type ConnectionsWatcher interface { - NewKnownConnection(pid core.PeerID, connection string) - Close() error - IsInterfaceNil() bool +// NewPeersHolder returns a new instance of peersHolder +func NewPeersHolder(preferredConnectionAddresses []string) (p2p.PreferredPeersHolderHandler, error) { + return peersHolder.NewPeersHolder(preferredConnectionAddresses) } -// PeersRatingHandler represent an entity able to handle peers ratings -type PeersRatingHandler interface { - AddPeer(pid core.PeerID) - IncreaseRating(pid core.PeerID) - DecreaseRating(pid core.PeerID) - GetTopRatedPeersFromList(peers []core.PeerID, minNumOfPeersExpected int) []core.PeerID - IsInterfaceNil() bool +// NewRandomP2PIdentityGenerator creates a new identity generator +func NewRandomP2PIdentityGenerator() RandomP2PIdentityGenerator { + return crypto.NewIdentityGenerator() } diff --git a/p2p/peersHolder/connectionStringValidator/connectionStringValidator.go b/p2p/peersHolder/connectionStringValidator/connectionStringValidator.go deleted file mode 100644 index ce9e90c5616..00000000000 --- a/p2p/peersHolder/connectionStringValidator/connectionStringValidator.go +++ /dev/null @@ -1,29 +0,0 @@ -package connectionStringValidator - -import ( - "net" - - "github.com/ElrondNetwork/elrond-go-core/core" -) - -type connectionStringValidator struct { -} - -// NewConnectionStringValidator returns a new connection string validator -func NewConnectionStringValidator() *connectionStringValidator { - return &connectionStringValidator{} -} - -// IsValid checks either a connection string is a valid ip or peer id -func (csv *connectionStringValidator) IsValid(connStr string) bool { - return csv.isValidIP(connStr) || csv.isValidPeerID(connStr) -} - -func (csv *connectionStringValidator) isValidIP(connStr string) bool { - return net.ParseIP(connStr) != nil -} - -func (csv *connectionStringValidator) isValidPeerID(connStr string) bool { - _, err := core.NewPeerID(connStr) - return err == nil -} diff --git a/p2p/peersHolder/connectionStringValidator/connectionStringValidator_test.go 
b/p2p/peersHolder/connectionStringValidator/connectionStringValidator_test.go deleted file mode 100644 index ad9052dfa6b..00000000000 --- a/p2p/peersHolder/connectionStringValidator/connectionStringValidator_test.go +++ /dev/null @@ -1,55 +0,0 @@ -package connectionStringValidator - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestConnectionStringValidator_IsValid(t *testing.T) { - t.Parallel() - - csv := NewConnectionStringValidator() - assert.False(t, csv.IsValid("invalid string")) - assert.False(t, csv.IsValid("")) - - assert.True(t, csv.IsValid("5.22.219.242")) - assert.True(t, csv.IsValid("2031:0:130F:0:0:9C0:876A:130B")) - assert.True(t, csv.IsValid("16Uiu2HAm6yvbp1oZ6zjnWsn9FdRqBSaQkbhELyaThuq48ybdojvJ")) -} -func TestConnectionStringValidator_isValidIP(t *testing.T) { - t.Parallel() - - csv := NewConnectionStringValidator() - assert.False(t, csv.isValidIP("invalid ip")) - assert.False(t, csv.isValidIP("")) - assert.False(t, csv.isValidIP("a.b.c.d")) - assert.False(t, csv.isValidIP("10.0.0")) - assert.False(t, csv.isValidIP("10.0")) - assert.False(t, csv.isValidIP("10")) - assert.False(t, csv.isValidIP("2031:0:130F:0:0:9C0:876A")) - assert.False(t, csv.isValidIP("2031:0:130F:0:0:9C0")) - assert.False(t, csv.isValidIP("2031:0:130F:0:0")) - assert.False(t, csv.isValidIP("2031:0:130F:0")) - assert.False(t, csv.isValidIP("2031:0:130F")) - assert.False(t, csv.isValidIP("2031:0")) - assert.False(t, csv.isValidIP("16Uiu2HAm6yvbp1oZ6zjnWsn9FdRqBSaQkbhELyaThuq48ybdojvJ")) - - assert.True(t, csv.isValidIP("127.0.0.1")) - assert.True(t, csv.isValidIP("5.22.219.242")) - assert.True(t, csv.isValidIP("2031:0:130F:0:0:9C0:876A:130B")) -} - -func TestConnectionStringValidator_isValidPeerID(t *testing.T) { - t.Parallel() - - csv := NewConnectionStringValidator() - assert.False(t, csv.isValidPeerID("invalid peer id")) - assert.False(t, csv.isValidPeerID("")) - assert.False(t, csv.isValidPeerID("blaiu2HAm6yvbp1oZ6zjnWsn9FdRqBSaQkbhELyaThuq48ybdojvJ")) // first 3 chars altered - assert.False(t, csv.isValidPeerID("16Uiu2HAm6yvbp1oZ6zjnWsn9FdRqBSaQkbhELyaThuq48ybdobla")) // last 3 chars altered - assert.False(t, csv.isValidPeerID("16Uiu2HAm6yvbp1oZ6zjnWsn9FblaBSaQkbhELyaThuq48ybdojvJ")) // middle chars altered - assert.False(t, csv.isValidPeerID("5.22.219.242")) - - assert.True(t, csv.isValidPeerID("16Uiu2HAm6yvbp1oZ6zjnWsn9FdRqBSaQkbhELyaThuq48ybdojvJ")) -} diff --git a/p2p/peersHolder/peersHolder.go b/p2p/peersHolder/peersHolder.go deleted file mode 100644 index 70d31ea20a6..00000000000 --- a/p2p/peersHolder/peersHolder.go +++ /dev/null @@ -1,255 +0,0 @@ -package peersHolder - -import ( - "fmt" - "strings" - "sync" - - "github.com/ElrondNetwork/elrond-go-core/core" - "github.com/ElrondNetwork/elrond-go/p2p" - "github.com/ElrondNetwork/elrond-go/p2p/peersHolder/connectionStringValidator" -) - -type peerInfo struct { - pid core.PeerID - shardID uint32 -} - -type peerIDData struct { - connectionAddress string - shardID uint32 - index int -} - -type peersHolder struct { - preferredConnAddresses []string - connAddrToPeersInfo map[string][]*peerInfo - tempPeerIDsWaitingForShard map[core.PeerID]string - peerIDsPerShard map[uint32][]core.PeerID - peerIDs map[core.PeerID]*peerIDData - mut sync.RWMutex -} - -// NewPeersHolder returns a new instance of peersHolder -func NewPeersHolder(preferredConnectionAddresses []string) (*peersHolder, error) { - preferredConnections := make([]string, 0) - connAddrToPeerIDs := make(map[string][]*peerInfo) - - connectionValidator := 
connectionStringValidator.NewConnectionStringValidator() - - for _, connAddr := range preferredConnectionAddresses { - if !connectionValidator.IsValid(connAddr) { - return nil, fmt.Errorf("%w for preferred connection address %s", p2p.ErrInvalidValue, connAddr) - } - - preferredConnections = append(preferredConnections, connAddr) - connAddrToPeerIDs[connAddr] = nil - } - - return &peersHolder{ - preferredConnAddresses: preferredConnections, - connAddrToPeersInfo: connAddrToPeerIDs, - tempPeerIDsWaitingForShard: make(map[core.PeerID]string), - peerIDsPerShard: make(map[uint32][]core.PeerID), - peerIDs: make(map[core.PeerID]*peerIDData), - }, nil -} - -// PutConnectionAddress will perform the insert or the upgrade operation if the provided peerID is inside the preferred peers list -func (ph *peersHolder) PutConnectionAddress(peerID core.PeerID, connectionAddress string) { - ph.mut.Lock() - defer ph.mut.Unlock() - - knownConnection := ph.getKnownConnection(connectionAddress) - if len(knownConnection) == 0 { - return - } - - peersInfo := ph.connAddrToPeersInfo[knownConnection] - if peersInfo == nil { - ph.addNewPeerInfoToMaps(peerID, knownConnection) - return - } - - // if we have new peer for same connection, add it to maps - pInfo := ph.getPeerInfoForPeerID(peerID, peersInfo) - if pInfo == nil { - ph.addNewPeerInfoToMaps(peerID, knownConnection) - } -} - -func (ph *peersHolder) addNewPeerInfoToMaps(peerID core.PeerID, knownConnection string) { - ph.tempPeerIDsWaitingForShard[peerID] = knownConnection - - newPeerInfo := &peerInfo{ - pid: peerID, - shardID: core.AllShardId, // this will be overwritten once shard is available - } - - ph.connAddrToPeersInfo[knownConnection] = append(ph.connAddrToPeersInfo[knownConnection], newPeerInfo) -} - -func (ph *peersHolder) getPeerInfoForPeerID(peerID core.PeerID, peersInfo []*peerInfo) *peerInfo { - for _, pInfo := range peersInfo { - if peerID == pInfo.pid { - return pInfo - } - } - - return nil -} - -// PutShardID will perform the insert or the upgrade operation if the provided peerID is inside the preferred peers list -func (ph *peersHolder) PutShardID(peerID core.PeerID, shardID uint32) { - ph.mut.Lock() - defer ph.mut.Unlock() - - knownConnection, isWaitingForShardID := ph.tempPeerIDsWaitingForShard[peerID] - if !isWaitingForShardID { - return - } - - peersInfo, ok := ph.connAddrToPeersInfo[knownConnection] - if !ok || peersInfo == nil { - return - } - - pInfo := ph.getPeerInfoForPeerID(peerID, peersInfo) - if pInfo == nil { - return - } - - pInfo.shardID = shardID - - ph.peerIDsPerShard[shardID] = append(ph.peerIDsPerShard[shardID], peerID) - - ph.peerIDs[peerID] = &peerIDData{ - connectionAddress: knownConnection, - shardID: shardID, - index: len(ph.peerIDsPerShard[shardID]) - 1, - } - - delete(ph.tempPeerIDsWaitingForShard, peerID) -} - -// Get will return a map containing the preferred peer IDs, split by shard ID -func (ph *peersHolder) Get() map[uint32][]core.PeerID { - peerIDsPerShardCopy := make(map[uint32][]core.PeerID) - - ph.mut.RLock() - for shardId, peerIds := range ph.peerIDsPerShard { - peerIDsPerShardCopy[shardId] = peerIds - } - ph.mut.RUnlock() - - return peerIDsPerShardCopy -} - -// Contains returns true if the provided peer id is a preferred connection -func (ph *peersHolder) Contains(peerID core.PeerID) bool { - ph.mut.RLock() - defer ph.mut.RUnlock() - - _, found := ph.peerIDs[peerID] - return found -} - -// Remove will remove the provided peer ID from the inner members -func (ph *peersHolder) Remove(peerID core.PeerID) { - 
ph.mut.Lock() - defer ph.mut.Unlock() - - pidData, found := ph.peerIDs[peerID] - if !found { - return - } - - shard, index, _ := ph.getShardAndIndexForPeer(peerID) - ph.removePeerFromMapAtIndex(shard, index) - - connAddress := pidData.connectionAddress - - delete(ph.peerIDs, peerID) - - ph.removePeerInfoAtConnectionAddress(peerID, connAddress) - - _, isWaitingForShardID := ph.tempPeerIDsWaitingForShard[peerID] - if isWaitingForShardID { - delete(ph.tempPeerIDsWaitingForShard, peerID) - } -} - -// removePeerInfoAtConnectionAddress removes the entry associated with the provided pid from connAddrToPeersInfo map -// it never removes the map key as it may be reused on a further reconnection -func (ph *peersHolder) removePeerInfoAtConnectionAddress(peerID core.PeerID, connAddr string) { - peersInfo := ph.connAddrToPeersInfo[connAddr] - if peersInfo == nil { - return - } - - var index int - var pInfo *peerInfo - for index, pInfo = range peersInfo { - if peerID == pInfo.pid { - ph.removePeerFromPeersInfoAtIndex(peersInfo, index, connAddr) - return - } - } - -} - -func (ph *peersHolder) removePeerFromPeersInfoAtIndex(peersInfo []*peerInfo, index int, connAddr string) { - peersInfo = append(peersInfo[:index], peersInfo[index+1:]...) - if len(peersInfo) == 0 { - peersInfo = nil - } - - ph.connAddrToPeersInfo[connAddr] = peersInfo -} - -// getKnownConnection checks if the connection address string contains any of the initial preferred connection address -// if true, it returns it -// this function must be called under mutex protection -func (ph *peersHolder) getKnownConnection(connectionAddressStr string) string { - for _, preferredConnAddr := range ph.preferredConnAddresses { - if strings.Contains(connectionAddressStr, preferredConnAddr) { - return preferredConnAddr - } - } - - return "" -} - -// this function must be called under mutex protection -func (ph *peersHolder) removePeerFromMapAtIndex(shardID uint32, index int) { - ph.peerIDsPerShard[shardID] = append(ph.peerIDsPerShard[shardID][:index], ph.peerIDsPerShard[shardID][index+1:]...) 
- if len(ph.peerIDsPerShard[shardID]) == 0 { - delete(ph.peerIDsPerShard, shardID) - } -} - -// this function must be called under mutex protection -func (ph *peersHolder) getShardAndIndexForPeer(peerID core.PeerID) (uint32, int, bool) { - pidData, ok := ph.peerIDs[peerID] - if !ok { - return 0, 0, false - } - - return pidData.shardID, pidData.index, true -} - -// Clear will delete all the entries from the inner map -func (ph *peersHolder) Clear() { - ph.mut.Lock() - defer ph.mut.Unlock() - - ph.tempPeerIDsWaitingForShard = make(map[core.PeerID]string) - ph.peerIDsPerShard = make(map[uint32][]core.PeerID) - ph.peerIDs = make(map[core.PeerID]*peerIDData) - ph.connAddrToPeersInfo = make(map[string][]*peerInfo) -} - -// IsInterfaceNil returns true if there is no value under the interface -func (ph *peersHolder) IsInterfaceNil() bool { - return ph == nil -} diff --git a/p2p/peersHolder/peersHolder_test.go b/p2p/peersHolder/peersHolder_test.go deleted file mode 100644 index ca48fd5d35f..00000000000 --- a/p2p/peersHolder/peersHolder_test.go +++ /dev/null @@ -1,227 +0,0 @@ -package peersHolder - -import ( - "errors" - "testing" - - "github.com/ElrondNetwork/elrond-go-core/core" - "github.com/ElrondNetwork/elrond-go-core/core/check" - "github.com/ElrondNetwork/elrond-go/p2p" - "github.com/stretchr/testify/assert" -) - -func TestNewPeersHolder(t *testing.T) { - t.Parallel() - - t.Run("invalid addresses should error", func(t *testing.T) { - t.Parallel() - - preferredPeers := []string{"10.100.100", "invalid string"} - ph, err := NewPeersHolder(preferredPeers) - assert.True(t, check.IfNil(ph)) - assert.True(t, errors.Is(err, p2p.ErrInvalidValue)) - }) - t.Run("should work", func(t *testing.T) { - t.Parallel() - - ph, _ := NewPeersHolder([]string{"10.100.100.100"}) - assert.False(t, check.IfNil(ph)) - }) -} - -func TestPeersHolder_PutConnectionAddress(t *testing.T) { - t.Parallel() - - t.Run("not preferred should not add", func(t *testing.T) { - t.Parallel() - - preferredPeers := []string{"10.100.100.100"} - ph, _ := NewPeersHolder(preferredPeers) - assert.False(t, check.IfNil(ph)) - - unknownConnection := "/ip4/20.200.200.200/tcp/8080/p2p/some-random-pid" // preferredPeers[0] - providedPid := core.PeerID("provided pid") - ph.PutConnectionAddress(providedPid, unknownConnection) - - _, found := ph.tempPeerIDsWaitingForShard[providedPid] - assert.False(t, found) - - peers := ph.Get() - assert.Equal(t, 0, len(peers)) - }) - t.Run("new connection should add to intermediate maps", func(t *testing.T) { - t.Parallel() - - preferredPeers := []string{"10.100.100.100", "10.100.100.101"} - ph, _ := NewPeersHolder(preferredPeers) - assert.False(t, check.IfNil(ph)) - - newConnection := "/ip4/10.100.100.100/tcp/38191/p2p/some-random-pid" // preferredPeers[0] - providedPid := core.PeerID("provided pid") - ph.PutConnectionAddress(providedPid, newConnection) - - knownConnection, found := ph.tempPeerIDsWaitingForShard[providedPid] - assert.True(t, found) - assert.Equal(t, preferredPeers[0], knownConnection) - - peersInfo := ph.connAddrToPeersInfo[knownConnection] - assert.Equal(t, 1, len(peersInfo)) - assert.Equal(t, providedPid, peersInfo[0].pid) - assert.Equal(t, core.AllShardId, peersInfo[0].shardID) - - // not in the final map yet - peers := ph.Get() - assert.Equal(t, 0, len(peers)) - }) - t.Run("should save second pid on same address", func(t *testing.T) { - t.Parallel() - - preferredPeers := []string{"10.100.100.100", "10.100.100.101", "16Uiu2HAm6yvbp1oZ6zjnWsn9FdRqBSaQkbhELyaThuq48ybdojvJ"} - ph, _ := 
NewPeersHolder(preferredPeers) - assert.False(t, check.IfNil(ph)) - - newConnection := "/ip4/10.100.100.102/tcp/38191/p2p/16Uiu2HAm6yvbp1oZ6zjnWsn9FdRqBSaQkbhELyaThuq48ybdojvJ" // preferredPeers[2] - providedPid := core.PeerID("provided pid") - ph.PutConnectionAddress(providedPid, newConnection) - - knownConnection, found := ph.tempPeerIDsWaitingForShard[providedPid] - assert.True(t, found) - assert.Equal(t, preferredPeers[2], knownConnection) - - peersInfo := ph.connAddrToPeersInfo[knownConnection] - assert.Equal(t, 1, len(peersInfo)) - assert.Equal(t, providedPid, peersInfo[0].pid) - assert.Equal(t, core.AllShardId, peersInfo[0].shardID) - - ph.PutConnectionAddress(providedPid, newConnection) // try to update with same connection for coverage - - newPid := core.PeerID("new pid") - ph.PutConnectionAddress(newPid, newConnection) - knownConnection, found = ph.tempPeerIDsWaitingForShard[providedPid] - assert.True(t, found) - assert.Equal(t, preferredPeers[2], knownConnection) - - peersInfo = ph.connAddrToPeersInfo[knownConnection] - assert.Equal(t, 2, len(peersInfo)) - assert.Equal(t, newPid, peersInfo[1].pid) - assert.Equal(t, core.AllShardId, peersInfo[1].shardID) - - // not in the final map yet - peers := ph.Get() - assert.Equal(t, 0, len(peers)) - }) -} - -func TestPeersHolder_PutShardID(t *testing.T) { - t.Parallel() - - t.Run("peer not added in the waiting list should be skipped", func(t *testing.T) { - t.Parallel() - - preferredPeers := []string{"10.100.100.100"} - ph, _ := NewPeersHolder(preferredPeers) - assert.False(t, check.IfNil(ph)) - - providedPid := core.PeerID("provided pid") - providedShardID := uint32(123) - ph.PutShardID(providedPid, providedShardID) - - peers := ph.Get() - assert.Equal(t, 0, len(peers)) - }) - t.Run("peer not added in map should be skipped", func(t *testing.T) { - t.Parallel() - - preferredPeers := []string{"10.100.100.100"} - ph, _ := NewPeersHolder(preferredPeers) - assert.False(t, check.IfNil(ph)) - - providedPid := core.PeerID("provided pid") - providedShardID := uint32(123) - ph.tempPeerIDsWaitingForShard[providedPid] = preferredPeers[0] - ph.PutShardID(providedPid, providedShardID) - - peers := ph.Get() - assert.Equal(t, 0, len(peers)) - }) - t.Run("should work", func(t *testing.T) { - t.Parallel() - - preferredPeers := []string{"10.100.100.100", "10.100.100.101", "16Uiu2HAm6yvbp1oZ6zjnWsn9FdRqBSaQkbhELyaThuq48ybdojvJ"} - ph, _ := NewPeersHolder(preferredPeers) - assert.False(t, check.IfNil(ph)) - - newConnection := "/ip4/10.100.100.101/tcp/38191/p2p/some-random-pid" // preferredPeers[1] - providedPid := core.PeerID("provided pid") - ph.PutConnectionAddress(providedPid, newConnection) - - providedShardID := uint32(123) - ph.PutShardID(providedPid, providedShardID) - - peers := ph.Get() - assert.Equal(t, 1, len(peers)) - peersInShard, found := peers[providedShardID] - assert.True(t, found) - assert.Equal(t, providedPid, peersInShard[0]) - - pidData := ph.peerIDs[providedPid] - assert.Equal(t, preferredPeers[1], pidData.connectionAddress) - assert.Equal(t, providedShardID, pidData.shardID) - assert.Equal(t, 0, pidData.index) - - _, found = ph.tempPeerIDsWaitingForShard[providedPid] - assert.False(t, found) - }) -} - -func TestPeersHolder_Contains(t *testing.T) { - t.Parallel() - - preferredPeers := []string{"10.100.100.100", "10.100.100.101"} - ph, _ := NewPeersHolder(preferredPeers) - assert.False(t, check.IfNil(ph)) - - newConnection := "/ip4/10.100.100.101/tcp/38191/p2p/some-random-pid" // preferredPeers[1] - providedPid := core.PeerID("provided 
pid") - ph.PutConnectionAddress(providedPid, newConnection) - - providedShardID := uint32(123) - ph.PutShardID(providedPid, providedShardID) - - assert.True(t, ph.Contains(providedPid)) - - ph.Remove(providedPid) - assert.False(t, ph.Contains(providedPid)) - - unknownPid := core.PeerID("unknown pid") - ph.Remove(unknownPid) // for code coverage -} - -func TestPeersHolder_Clear(t *testing.T) { - t.Parallel() - - preferredPeers := []string{"10.100.100.100", "16Uiu2HAm6yvbp1oZ6zjnWsn9FdRqBSaQkbhELyaThuq48ybdojvJ"} - ph, _ := NewPeersHolder(preferredPeers) - assert.False(t, check.IfNil(ph)) - - newConnection1 := "/ip4/10.100.100.100/tcp/38191/p2p/some-random-pid" // preferredPeers[0] - providedPid1 := core.PeerID("provided pid 1") - ph.PutConnectionAddress(providedPid1, newConnection1) - providedShardID := uint32(123) - ph.PutShardID(providedPid1, providedShardID) - assert.True(t, ph.Contains(providedPid1)) - - newConnection2 := "/ip4/10.100.100.102/tcp/38191/p2p/16Uiu2HAm6yvbp1oZ6zjnWsn9FdRqBSaQkbhELyaThuq48ybdojvJ" // preferredPeers[1] - providedPid2 := core.PeerID("provided pid 2") - ph.PutConnectionAddress(providedPid2, newConnection2) - ph.PutShardID(providedPid2, providedShardID) - assert.True(t, ph.Contains(providedPid2)) - - peers := ph.Get() - assert.Equal(t, 1, len(peers)) - assert.Equal(t, 2, len(peers[providedShardID])) - - ph.Clear() - peers = ph.Get() - assert.Equal(t, 0, len(peers)) -} diff --git a/p2p/rating/peersRatingHandler.go b/p2p/rating/peersRatingHandler.go deleted file mode 100644 index be7935ef2d3..00000000000 --- a/p2p/rating/peersRatingHandler.go +++ /dev/null @@ -1,238 +0,0 @@ -package rating - -import ( - "fmt" - "sync" - - "github.com/ElrondNetwork/elrond-go-core/core" - "github.com/ElrondNetwork/elrond-go-core/core/check" - logger "github.com/ElrondNetwork/elrond-go-logger" - "github.com/ElrondNetwork/elrond-go/p2p" - "github.com/ElrondNetwork/elrond-go/storage" -) - -const ( - topRatedTier = "top rated tier" - badRatedTier = "bad rated tier" - defaultRating = int32(0) - minRating = -100 - maxRating = 100 - increaseFactor = 2 - decreaseFactor = -1 - minNumOfPeers = 1 - int32Size = 4 -) - -var log = logger.GetOrCreate("p2p/peersRatingHandler") - -// ArgPeersRatingHandler is the DTO used to create a new peers rating handler -type ArgPeersRatingHandler struct { - TopRatedCache storage.Cacher - BadRatedCache storage.Cacher -} - -type peersRatingHandler struct { - topRatedCache storage.Cacher - badRatedCache storage.Cacher - mut sync.Mutex -} - -// NewPeersRatingHandler returns a new peers rating handler -func NewPeersRatingHandler(args ArgPeersRatingHandler) (*peersRatingHandler, error) { - err := checkArgs(args) - if err != nil { - return nil, err - } - - prh := &peersRatingHandler{ - topRatedCache: args.TopRatedCache, - badRatedCache: args.BadRatedCache, - } - - return prh, nil -} - -func checkArgs(args ArgPeersRatingHandler) error { - if check.IfNil(args.TopRatedCache) { - return fmt.Errorf("%w for TopRatedCache", p2p.ErrNilCacher) - } - if check.IfNil(args.BadRatedCache) { - return fmt.Errorf("%w for BadRatedCache", p2p.ErrNilCacher) - } - - return nil -} - -// AddPeer adds a new peer to the cache with rating 0 -// this is called when a new peer is detected -func (prh *peersRatingHandler) AddPeer(pid core.PeerID) { - prh.mut.Lock() - defer prh.mut.Unlock() - - _, found := prh.getOldRating(pid) - if found { - return - } - - prh.topRatedCache.Put(pid.Bytes(), defaultRating, int32Size) -} - -// IncreaseRating increases the rating of a peer with the increase factor 
-func (prh *peersRatingHandler) IncreaseRating(pid core.PeerID) { - prh.mut.Lock() - defer prh.mut.Unlock() - - prh.updateRatingIfNeeded(pid, increaseFactor) -} - -// DecreaseRating decreases the rating of a peer with the decrease factor -func (prh *peersRatingHandler) DecreaseRating(pid core.PeerID) { - prh.mut.Lock() - defer prh.mut.Unlock() - - prh.updateRatingIfNeeded(pid, decreaseFactor) -} - -func (prh *peersRatingHandler) getOldRating(pid core.PeerID) (int32, bool) { - oldRating, found := prh.topRatedCache.Get(pid.Bytes()) - if found { - oldRatingInt, _ := oldRating.(int32) - return oldRatingInt, found - } - - oldRating, found = prh.badRatedCache.Get(pid.Bytes()) - if found { - oldRatingInt, _ := oldRating.(int32) - return oldRatingInt, found - } - - return defaultRating, found -} - -func (prh *peersRatingHandler) updateRatingIfNeeded(pid core.PeerID, updateFactor int32) { - oldRating, found := prh.getOldRating(pid) - if !found { - // new pid, add it with default rating - prh.topRatedCache.Put(pid.Bytes(), defaultRating, int32Size) - return - } - - decreasingUnderMin := oldRating == minRating && updateFactor == decreaseFactor - increasingOverMax := oldRating == maxRating && updateFactor == increaseFactor - shouldSkipUpdate := decreasingUnderMin || increasingOverMax - if shouldSkipUpdate { - return - } - - newRating := oldRating + updateFactor - if newRating > maxRating { - newRating = maxRating - } - - if newRating < minRating { - newRating = minRating - } - - prh.updateRating(pid, oldRating, newRating) -} - -func (prh *peersRatingHandler) updateRating(pid core.PeerID, oldRating, newRating int32) { - oldTier := computeRatingTier(oldRating) - newTier := computeRatingTier(newRating) - if newTier == oldTier { - if newTier == topRatedTier { - prh.topRatedCache.Put(pid.Bytes(), newRating, int32Size) - } else { - prh.badRatedCache.Put(pid.Bytes(), newRating, int32Size) - } - - return - } - - prh.movePeerToNewTier(newRating, pid) -} - -func computeRatingTier(peerRating int32) string { - if peerRating >= defaultRating { - return topRatedTier - } - - return badRatedTier -} - -func (prh *peersRatingHandler) movePeerToNewTier(newRating int32, pid core.PeerID) { - newTier := computeRatingTier(newRating) - if newTier == topRatedTier { - prh.badRatedCache.Remove(pid.Bytes()) - prh.topRatedCache.Put(pid.Bytes(), newRating, int32Size) - } else { - prh.topRatedCache.Remove(pid.Bytes()) - prh.badRatedCache.Put(pid.Bytes(), newRating, int32Size) - } -} - -// GetTopRatedPeersFromList returns a list of peers, searching them in the order of rating tiers -func (prh *peersRatingHandler) GetTopRatedPeersFromList(peers []core.PeerID, minNumOfPeersExpected int) []core.PeerID { - prh.mut.Lock() - defer prh.mut.Unlock() - - peersTopRated := make([]core.PeerID, 0) - defer prh.displayPeersRating(&peersTopRated, minNumOfPeersExpected) - - isListEmpty := len(peers) == 0 - if minNumOfPeersExpected < minNumOfPeers || isListEmpty { - return make([]core.PeerID, 0) - } - - peersTopRated, peersBadRated := prh.splitPeersByTiers(peers) - if len(peersTopRated) < minNumOfPeersExpected { - peersTopRated = append(peersTopRated, peersBadRated...) 
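The append on display here is the point of keeping two tiers: GetTopRatedPeersFromList returns top-rated peers first and only pads the result with bad-rated ones when the top tier cannot satisfy the requested minimum. A compact sketch of that selection rule, under the simplifying assumption that every known peer sits in exactly one of two sets (the real code consults both caches and silently drops peers found in neither):

```go
package main

import "fmt"

// pickPeers mirrors GetTopRatedPeersFromList's fallback rule: prefer
// top-rated peers, append bad-rated ones only when the top tier alone
// cannot cover the requested minimum.
func pickPeers(peers []string, isTopRated map[string]bool, minExpected int) []string {
	if minExpected < 1 || len(peers) == 0 {
		return nil
	}

	topRated := make([]string, 0, len(peers))
	badRated := make([]string, 0, len(peers))
	for _, p := range peers {
		if isTopRated[p] {
			topRated = append(topRated, p)
		} else {
			badRated = append(badRated, p)
		}
	}

	if len(topRated) < minExpected {
		topRated = append(topRated, badRated...)
	}
	return topRated
}

func main() {
	peers := []string{"pid1", "pid2", "pid3"}
	isTopRated := map[string]bool{"pid1": true}

	fmt.Println(pickPeers(peers, isTopRated, 1)) // [pid1]
	fmt.Println(pickPeers(peers, isTopRated, 2)) // [pid1 pid2 pid3]
}
```

Note that once the fallback triggers, all bad-rated peers are appended rather than just enough to reach the minimum; the caller receives the full candidate set ordered by tier.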
- } - - return peersTopRated -} - -func (prh *peersRatingHandler) displayPeersRating(peers *[]core.PeerID, minNumOfPeersExpected int) { - if log.GetLevel() != logger.LogTrace { - return - } - - strPeersRatings := "" - for _, peer := range *peers { - rating, ok := prh.topRatedCache.Get(peer.Bytes()) - if !ok { - rating, _ = prh.badRatedCache.Get(peer.Bytes()) - } - - ratingInt, ok := rating.(int32) - if ok { - strPeersRatings += fmt.Sprintf("\n peerID: %s, rating: %d", peer.Pretty(), ratingInt) - } else { - strPeersRatings += fmt.Sprintf("\n peerID: %s, rating: invalid", peer.Pretty()) - } - } - - log.Trace("Best peers to request from", "min requested", minNumOfPeersExpected, "peers ratings", strPeersRatings) -} - -func (prh *peersRatingHandler) splitPeersByTiers(peers []core.PeerID) ([]core.PeerID, []core.PeerID) { - topRated := make([]core.PeerID, 0) - badRated := make([]core.PeerID, 0) - - for _, peer := range peers { - if prh.topRatedCache.Has(peer.Bytes()) { - topRated = append(topRated, peer) - } - - if prh.badRatedCache.Has(peer.Bytes()) { - badRated = append(badRated, peer) - } - } - - return topRated, badRated -} - -// IsInterfaceNil returns true if there is no value under the interface -func (prh *peersRatingHandler) IsInterfaceNil() bool { - return prh == nil -} diff --git a/p2p/rating/peersRatingHandler_test.go b/p2p/rating/peersRatingHandler_test.go deleted file mode 100644 index 5070634847e..00000000000 --- a/p2p/rating/peersRatingHandler_test.go +++ /dev/null @@ -1,426 +0,0 @@ -package rating - -import ( - "bytes" - "errors" - "strings" - "testing" - - "github.com/ElrondNetwork/elrond-go-core/core" - "github.com/ElrondNetwork/elrond-go-core/core/check" - "github.com/ElrondNetwork/elrond-go/p2p" - "github.com/ElrondNetwork/elrond-go/testscommon" - "github.com/stretchr/testify/assert" -) - -func createMockArgs() ArgPeersRatingHandler { - return ArgPeersRatingHandler{ - TopRatedCache: &testscommon.CacherStub{}, - BadRatedCache: &testscommon.CacherStub{}, - } -} - -func TestNewPeersRatingHandler(t *testing.T) { - t.Parallel() - - t.Run("nil top rated cache should error", func(t *testing.T) { - t.Parallel() - - args := createMockArgs() - args.TopRatedCache = nil - - prh, err := NewPeersRatingHandler(args) - assert.True(t, errors.Is(err, p2p.ErrNilCacher)) - assert.True(t, strings.Contains(err.Error(), "TopRatedCache")) - assert.True(t, check.IfNil(prh)) - }) - t.Run("nil bad rated cache should error", func(t *testing.T) { - t.Parallel() - - args := createMockArgs() - args.BadRatedCache = nil - - prh, err := NewPeersRatingHandler(args) - assert.True(t, errors.Is(err, p2p.ErrNilCacher)) - assert.True(t, strings.Contains(err.Error(), "BadRatedCache")) - assert.True(t, check.IfNil(prh)) - }) - t.Run("should work", func(t *testing.T) { - t.Parallel() - - prh, err := NewPeersRatingHandler(createMockArgs()) - assert.Nil(t, err) - assert.False(t, check.IfNil(prh)) - }) -} - -func TestPeersRatingHandler_AddPeer(t *testing.T) { - t.Parallel() - - t.Run("new peer should add", func(t *testing.T) { - t.Parallel() - - wasCalled := false - providedPid := core.PeerID("provided pid") - args := createMockArgs() - args.TopRatedCache = &testscommon.CacherStub{ - GetCalled: func(key []byte) (value interface{}, ok bool) { - return nil, false - }, - PutCalled: func(key []byte, value interface{}, sizeInBytes int) (evicted bool) { - assert.True(t, bytes.Equal(providedPid.Bytes(), key)) - - wasCalled = true - return false - }, - } - args.BadRatedCache = &testscommon.CacherStub{ - GetCalled: func(key 
[]byte) (value interface{}, ok bool) { - return nil, false - }, - } - - prh, _ := NewPeersRatingHandler(args) - assert.False(t, check.IfNil(prh)) - - prh.AddPeer(providedPid) - assert.True(t, wasCalled) - }) - t.Run("peer in top rated should not add", func(t *testing.T) { - t.Parallel() - - wasCalled := false - providedPid := core.PeerID("provided pid") - args := createMockArgs() - args.TopRatedCache = &testscommon.CacherStub{ - GetCalled: func(key []byte) (value interface{}, ok bool) { - return nil, true - }, - PutCalled: func(key []byte, value interface{}, sizeInBytes int) (evicted bool) { - wasCalled = true - return false - }, - } - args.BadRatedCache = &testscommon.CacherStub{ - GetCalled: func(key []byte) (value interface{}, ok bool) { - return nil, false - }, - } - - prh, _ := NewPeersRatingHandler(args) - assert.False(t, check.IfNil(prh)) - - prh.AddPeer(providedPid) - assert.False(t, wasCalled) - }) - t.Run("peer in bad rated should not add", func(t *testing.T) { - t.Parallel() - - wasCalled := false - providedPid := core.PeerID("provided pid") - args := createMockArgs() - args.TopRatedCache = &testscommon.CacherStub{ - GetCalled: func(key []byte) (value interface{}, ok bool) { - return nil, false - }, - PutCalled: func(key []byte, value interface{}, sizeInBytes int) (evicted bool) { - wasCalled = true - return false - }, - } - args.BadRatedCache = &testscommon.CacherStub{ - GetCalled: func(key []byte) (value interface{}, ok bool) { - return nil, true - }, - } - - prh, _ := NewPeersRatingHandler(args) - assert.False(t, check.IfNil(prh)) - - prh.AddPeer(providedPid) - assert.False(t, wasCalled) - }) -} - -func TestPeersRatingHandler_IncreaseRating(t *testing.T) { - t.Parallel() - - t.Run("new peer should add to cache", func(t *testing.T) { - t.Parallel() - - wasCalled := false - providedPid := core.PeerID("provided pid") - args := createMockArgs() - args.TopRatedCache = &testscommon.CacherStub{ - GetCalled: func(key []byte) (value interface{}, ok bool) { - return nil, false - }, - PutCalled: func(key []byte, value interface{}, sizeInBytes int) (evicted bool) { - assert.True(t, bytes.Equal(providedPid.Bytes(), key)) - - wasCalled = true - return false - }, - } - args.BadRatedCache = &testscommon.CacherStub{ - GetCalled: func(key []byte) (value interface{}, ok bool) { - return nil, false - }, - } - prh, _ := NewPeersRatingHandler(args) - assert.False(t, check.IfNil(prh)) - - prh.IncreaseRating(providedPid) - assert.True(t, wasCalled) - }) - t.Run("should work", func(t *testing.T) { - t.Parallel() - - cacheMap := make(map[string]interface{}) - providedPid := core.PeerID("provided pid") - args := createMockArgs() - args.TopRatedCache = &testscommon.CacherStub{ - GetCalled: func(key []byte) (value interface{}, ok bool) { - val, found := cacheMap[string(key)] - return val, found - }, - PutCalled: func(key []byte, value interface{}, sizeInBytes int) (evicted bool) { - cacheMap[string(key)] = value - return false - }, - } - - prh, _ := NewPeersRatingHandler(args) - assert.False(t, check.IfNil(prh)) - - prh.IncreaseRating(providedPid) - val, found := cacheMap[string(providedPid.Bytes())] - assert.True(t, found) - assert.Equal(t, defaultRating, val) - - // exceed the limit - numOfCalls := 100 - for i := 0; i < numOfCalls; i++ { - prh.IncreaseRating(providedPid) - } - val, found = cacheMap[string(providedPid.Bytes())] - assert.True(t, found) - assert.Equal(t, int32(maxRating), val) - }) -} - -func TestPeersRatingHandler_DecreaseRating(t *testing.T) { - t.Parallel() - - t.Run("new peer should 
add to cache", func(t *testing.T) { - t.Parallel() - - wasCalled := false - providedPid := core.PeerID("provided pid") - args := createMockArgs() - args.TopRatedCache = &testscommon.CacherStub{ - GetCalled: func(key []byte) (value interface{}, ok bool) { - return nil, false - }, - PutCalled: func(key []byte, value interface{}, sizeInBytes int) (evicted bool) { - assert.True(t, bytes.Equal(providedPid.Bytes(), key)) - - wasCalled = true - return false - }, - } - args.BadRatedCache = &testscommon.CacherStub{ - GetCalled: func(key []byte) (value interface{}, ok bool) { - return nil, false - }, - } - prh, _ := NewPeersRatingHandler(args) - assert.False(t, check.IfNil(prh)) - - prh.DecreaseRating(providedPid) - assert.True(t, wasCalled) - }) - t.Run("should work", func(t *testing.T) { - t.Parallel() - - topRatedCacheMap := make(map[string]interface{}) - badRatedCacheMap := make(map[string]interface{}) - providedPid := core.PeerID("provided pid") - args := createMockArgs() - args.TopRatedCache = &testscommon.CacherStub{ - GetCalled: func(key []byte) (value interface{}, ok bool) { - val, found := topRatedCacheMap[string(key)] - return val, found - }, - PutCalled: func(key []byte, value interface{}, sizeInBytes int) (evicted bool) { - topRatedCacheMap[string(key)] = value - return false - }, - RemoveCalled: func(key []byte) { - delete(topRatedCacheMap, string(key)) - }, - } - args.BadRatedCache = &testscommon.CacherStub{ - GetCalled: func(key []byte) (value interface{}, ok bool) { - val, found := badRatedCacheMap[string(key)] - return val, found - }, - PutCalled: func(key []byte, value interface{}, sizeInBytes int) (evicted bool) { - badRatedCacheMap[string(key)] = value - return false - }, - RemoveCalled: func(key []byte) { - delete(badRatedCacheMap, string(key)) - }, - } - - prh, _ := NewPeersRatingHandler(args) - assert.False(t, check.IfNil(prh)) - - // first call just adds it with default rating - prh.DecreaseRating(providedPid) - val, found := topRatedCacheMap[string(providedPid.Bytes())] - assert.True(t, found) - assert.Equal(t, defaultRating, val) - - // exceed the limit - numOfCalls := 200 - for i := 0; i < numOfCalls; i++ { - prh.DecreaseRating(providedPid) - } - val, found = badRatedCacheMap[string(providedPid.Bytes())] - assert.True(t, found) - assert.Equal(t, int32(minRating), val) - - // move back to top tier - for i := 0; i < numOfCalls; i++ { - prh.IncreaseRating(providedPid) - } - _, found = badRatedCacheMap[string(providedPid.Bytes())] - assert.False(t, found) - - val, found = topRatedCacheMap[string(providedPid.Bytes())] - assert.True(t, found) - assert.Equal(t, int32(maxRating), val) - }) -} - -func TestPeersRatingHandler_GetTopRatedPeersFromList(t *testing.T) { - t.Parallel() - - t.Run("asking for 0 peers should return empty list", func(t *testing.T) { - t.Parallel() - - prh, _ := NewPeersRatingHandler(createMockArgs()) - assert.False(t, check.IfNil(prh)) - - res := prh.GetTopRatedPeersFromList([]core.PeerID{"pid"}, 0) - assert.Equal(t, 0, len(res)) - }) - t.Run("nil provided list should return empty list", func(t *testing.T) { - t.Parallel() - - prh, _ := NewPeersRatingHandler(createMockArgs()) - assert.False(t, check.IfNil(prh)) - - res := prh.GetTopRatedPeersFromList(nil, 1) - assert.Equal(t, 0, len(res)) - }) - t.Run("no peers in maps should return empty list", func(t *testing.T) { - t.Parallel() - - prh, _ := NewPeersRatingHandler(createMockArgs()) - assert.False(t, check.IfNil(prh)) - - providedListOfPeers := []core.PeerID{"pid 1", "pid 2"} - res := 
prh.GetTopRatedPeersFromList(providedListOfPeers, 5) - assert.Equal(t, 0, len(res)) - }) - t.Run("one peer in top rated, asking for one should work", func(t *testing.T) { - t.Parallel() - - providedPid := core.PeerID("provided pid") - args := createMockArgs() - args.TopRatedCache = &testscommon.CacherStub{ - LenCalled: func() int { - return 1 - }, - KeysCalled: func() [][]byte { - return [][]byte{providedPid.Bytes()} - }, - HasCalled: func(key []byte) bool { - return bytes.Equal(key, providedPid.Bytes()) - }, - } - prh, _ := NewPeersRatingHandler(args) - assert.False(t, check.IfNil(prh)) - - providedListOfPeers := []core.PeerID{providedPid, "another pid"} - res := prh.GetTopRatedPeersFromList(providedListOfPeers, 1) - assert.Equal(t, 1, len(res)) - assert.Equal(t, providedPid, res[0]) - }) - t.Run("one peer in each, asking for two should work", func(t *testing.T) { - t.Parallel() - - providedTopPid := core.PeerID("provided top pid") - providedBadPid := core.PeerID("provided bad pid") - args := createMockArgs() - args.TopRatedCache = &testscommon.CacherStub{ - LenCalled: func() int { - return 1 - }, - KeysCalled: func() [][]byte { - return [][]byte{providedTopPid.Bytes()} - }, - HasCalled: func(key []byte) bool { - return bytes.Equal(key, providedTopPid.Bytes()) - }, - } - args.BadRatedCache = &testscommon.CacherStub{ - LenCalled: func() int { - return 1 - }, - KeysCalled: func() [][]byte { - return [][]byte{providedBadPid.Bytes()} - }, - HasCalled: func(key []byte) bool { - return bytes.Equal(key, providedBadPid.Bytes()) - }, - } - prh, _ := NewPeersRatingHandler(args) - assert.False(t, check.IfNil(prh)) - - providedListOfPeers := []core.PeerID{providedTopPid, providedBadPid, "another pid"} - expectedListOfPeers := []core.PeerID{providedTopPid, providedBadPid} - res := prh.GetTopRatedPeersFromList(providedListOfPeers, 2) - assert.Equal(t, expectedListOfPeers, res) - }) - t.Run("should work", func(t *testing.T) { - t.Parallel() - - providedPid1, providedPid2, providedPid3 := core.PeerID("provided pid 1"), core.PeerID("provided pid 2"), core.PeerID("provided pid 3") - args := createMockArgs() - args.TopRatedCache = &testscommon.CacherStub{ - LenCalled: func() int { - return 3 - }, - KeysCalled: func() [][]byte { - return [][]byte{providedPid1.Bytes(), providedPid2.Bytes(), providedPid3.Bytes()} - }, - HasCalled: func(key []byte) bool { - has := bytes.Equal(key, providedPid1.Bytes()) || - bytes.Equal(key, providedPid2.Bytes()) || - bytes.Equal(key, providedPid3.Bytes()) - return has - }, - } - prh, _ := NewPeersRatingHandler(args) - assert.False(t, check.IfNil(prh)) - - providedListOfPeers := []core.PeerID{providedPid1, providedPid2, providedPid3, "another pid 1", "another pid 2"} - expectedListOfPeers := []core.PeerID{providedPid1, providedPid2, providedPid3} - res := prh.GetTopRatedPeersFromList(providedListOfPeers, 2) - assert.Equal(t, expectedListOfPeers, res) - }) -} diff --git a/p2p/readme.md b/p2p/readme.md deleted file mode 100644 index 4f2c002c1a6..00000000000 --- a/p2p/readme.md +++ /dev/null @@ -1,12 +0,0 @@ -# P2P protocol description - -The `Messenger` interface with its implementation are -used to define the way to communicate between Elrond nodes. - -There are 2 ways to send data to the other peers: -1. Broadcasting messages on a `pubsub` using topics; -1. Direct sending messages to the connected peers. - -The first type is used to send messages that has to reach every node -(from corresponding shard, metachain, consensus group, etc.) 
and the second type is -used to resolve requests comming from directly connected peers. diff --git a/process/block/baseProcess.go b/process/block/baseProcess.go index 86dbda2944b..f7beb246904 100644 --- a/process/block/baseProcess.go +++ b/process/block/baseProcess.go @@ -7,6 +7,7 @@ import ( "fmt" "math/big" "sort" + "sync" "time" "github.com/ElrondNetwork/elrond-go-core/core" @@ -23,9 +24,11 @@ import ( "github.com/ElrondNetwork/elrond-go/common" "github.com/ElrondNetwork/elrond-go/common/holders" "github.com/ElrondNetwork/elrond-go/common/logging" + "github.com/ElrondNetwork/elrond-go/config" "github.com/ElrondNetwork/elrond-go/consensus" "github.com/ElrondNetwork/elrond-go/dataRetriever" "github.com/ElrondNetwork/elrond-go/dblookupext" + debugFactory "github.com/ElrondNetwork/elrond-go/debug/factory" "github.com/ElrondNetwork/elrond-go/errors" "github.com/ElrondNetwork/elrond-go/outport" "github.com/ElrondNetwork/elrond-go/process" @@ -34,7 +37,7 @@ import ( "github.com/ElrondNetwork/elrond-go/sharding" "github.com/ElrondNetwork/elrond-go/sharding/nodesCoordinator" "github.com/ElrondNetwork/elrond-go/state" - "github.com/ElrondNetwork/elrond-go/storage/storageUnit" + "github.com/ElrondNetwork/elrond-go/storage/storageunit" ) var log = logger.GetOrCreate("process/block") @@ -78,6 +81,8 @@ type baseProcessor struct { blockChain data.ChainHandler hdrsForCurrBlock *hdrForBlock genesisNonce uint64 + mutProcessDebugger sync.RWMutex + processDebugger process.Debugger versionedHeaderFactory nodeFactory.VersionedHeaderFactory headerIntegrityVerifier process.HeaderIntegrityVerifier @@ -98,11 +103,11 @@ type baseProcessor struct { gasConsumedProvider gasConsumedProvider economicsData process.EconomicsDataHandler - processDataTriesOnCommitEpoch bool - lastRestartNonce uint64 - pruningDelay uint32 - processedMiniBlocksTracker process.ProcessedMiniBlocksTracker - receiptsRepository receiptsRepository + processDataTriesOnCommitEpoch bool + lastRestartNonce uint64 + pruningDelay uint32 + processedMiniBlocksTracker process.ProcessedMiniBlocksTracker + receiptsRepository receiptsRepository } type bootStorerDataArgs struct { @@ -1704,7 +1709,7 @@ func (bp *baseProcessor) commitTrieEpochRootHashIfNeeded(metaBlock *block.MetaBl if check.IfNil(trieEpochRootHashStorageUnit) { return nil } - _, isStorerDisabled := trieEpochRootHashStorageUnit.(*storageUnit.NilStorer) + _, isStorerDisabled := trieEpochRootHashStorageUnit.(*storageunit.NilStorer) if isStorerDisabled { return nil } @@ -1808,15 +1813,17 @@ func unmarshalUserAccount(address []byte, userAccountsBytes []byte, marshalizer // Close - closes all underlying components func (bp *baseProcessor) Close() error { - var err1, err2 error + var err1, err2, err3 error if !check.IfNil(bp.vmContainer) { err1 = bp.vmContainer.Close() } if !check.IfNil(bp.vmContainerFactory) { err2 = bp.vmContainerFactory.Close() } - if err1 != nil || err2 != nil { - return fmt.Errorf("vmContainer close error: %v, vmContainerFactory close error: %v", err1, err2) + err3 = bp.processDebugger.Close() + if err1 != nil || err2 != nil || err3 != nil { + return fmt.Errorf("vmContainer close error: %v, vmContainerFactory close error: %v, processDebugger close: %v", + err1, err2, err3) } return nil @@ -1996,3 +2003,30 @@ func displayCleanupErrorMessage(message string, shardID uint32, noncesToPrevFina "nonces to previous final", noncesToPrevFinal, "error", err.Error()) } + +// SetProcessDebugger sets the process debugger associated to this block processor +func (bp *baseProcessor) 
SetProcessDebugger(debugger process.Debugger) error { + if check.IfNil(debugger) { + return process.ErrNilProcessDebugger + } + + bp.mutProcessDebugger.Lock() + bp.processDebugger = debugger + bp.mutProcessDebugger.Unlock() + + return nil +} + +func (bp *baseProcessor) updateLastCommittedInDebugger(round uint64) { + bp.mutProcessDebugger.RLock() + bp.processDebugger.SetLastCommittedBlockRound(round) + bp.mutProcessDebugger.RUnlock() +} + +func createDisabledProcessDebugger() (process.Debugger, error) { + configs := config.ProcessDebugConfig{ + Enabled: false, + } + + return debugFactory.CreateProcessDebugger(configs) +} diff --git a/process/block/baseProcess_test.go b/process/block/baseProcess_test.go index 9bd291ba49f..d0e06ecd220 100644 --- a/process/block/baseProcess_test.go +++ b/process/block/baseProcess_test.go @@ -36,8 +36,8 @@ import ( "github.com/ElrondNetwork/elrond-go/process/mock" "github.com/ElrondNetwork/elrond-go/state" "github.com/ElrondNetwork/elrond-go/storage" - "github.com/ElrondNetwork/elrond-go/storage/memorydb" - "github.com/ElrondNetwork/elrond-go/storage/storageUnit" + "github.com/ElrondNetwork/elrond-go/storage/database" + "github.com/ElrondNetwork/elrond-go/storage/storageunit" "github.com/ElrondNetwork/elrond-go/testscommon" dataRetrieverMock "github.com/ElrondNetwork/elrond-go/testscommon/dataRetriever" "github.com/ElrondNetwork/elrond-go/testscommon/dblookupext" @@ -102,7 +102,7 @@ func createArgBaseProcessor( BlockSizeThrottler: &mock.BlockSizeThrottlerStub{}, Version: "softwareVersion", HistoryRepository: &dblookupext.HistoryRepositoryStub{}, - EnableRoundsHandler: &testscommon.EnableRoundsHandlerStub{}, + EnableRoundsHandler: &testscommon.EnableRoundsHandlerStub{}, GasHandler: &mock.GasHandlerMock{}, ScheduledTxsExecutionHandler: &testscommon.ScheduledTxsExecutionStub{}, ScheduledMiniBlocksEnableEpoch: 2, @@ -120,14 +120,14 @@ func createTestBlockchain() *testscommon.ChainHandlerStub { } func generateTestCache() storage.Cacher { - cache, _ := storageUnit.NewCache(storageUnit.CacheConfig{Type: storageUnit.LRUCache, Capacity: 1000, Shards: 1, SizeInBytes: 0}) + cache, _ := storageunit.NewCache(storageunit.CacheConfig{Type: storageunit.LRUCache, Capacity: 1000, Shards: 1, SizeInBytes: 0}) return cache } func generateTestUnit() storage.Storer { - storer, _ := storageUnit.NewStorageUnit( + storer, _ := storageunit.NewStorageUnit( generateTestCache(), - memorydb.New(), + database.NewMemDB(), ) return storer @@ -1803,6 +1803,7 @@ func TestBaseProcessor_commitTrieEpochRootHashIfNeededNilStorerShouldErr(t *test mb := &block.MetaBlock{Epoch: epoch} err := sp.CommitTrieEpochRootHashIfNeeded(mb, []byte("root")) + require.NotNil(t, err) require.True(t, strings.Contains(err.Error(), dataRetriever.ErrStorerNotFound.Error())) require.True(t, strings.Contains(err.Error(), dataRetriever.TrieEpochRootHashUnit.String())) } @@ -1813,7 +1814,7 @@ func TestBaseProcessor_commitTrieEpochRootHashIfNeededDisabledStorerShouldNotErr epoch := uint32(37) coreComponents, dataComponents, bootstrapComponents, statusComponents := createComponentHolderMocks() - dataComponents.Storage.AddStorer(dataRetriever.TrieEpochRootHashUnit, &storageUnit.NilStorer{}) + dataComponents.Storage.AddStorer(dataRetriever.TrieEpochRootHashUnit, &storageunit.NilStorer{}) arguments := CreateMockArguments(coreComponents, dataComponents, bootstrapComponents, statusComponents) sp, _ := blproc.NewShardProcessor(arguments) diff --git a/process/block/metablock.go b/process/block/metablock.go index 
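The debugger wiring introduced above follows one pattern throughout: processors are built with a disabled debugger (createDisabledProcessDebugger), SetProcessDebugger rejects nil and swaps under the write lock, and the hot commit path only takes a read lock. A stripped-down sketch of the same pattern, with a simplified stand-in interface rather than the real process.Debugger:

```go
package main

import (
	"errors"
	"fmt"
	"sync"
)

// debugger is a simplified stand-in for process.Debugger.
type debugger interface {
	SetLastCommittedBlockRound(round uint64)
}

// noopDebugger plays the role of the disabled debugger every
// processor starts with.
type noopDebugger struct{}

func (noopDebugger) SetLastCommittedBlockRound(uint64) {}

var errNilDebugger = errors.New("nil process debugger")

type processor struct {
	mut sync.RWMutex
	dbg debugger
}

// setDebugger mirrors SetProcessDebugger: reject nil, then swap under
// the write lock so concurrent commits never race the assignment.
func (p *processor) setDebugger(d debugger) error {
	if d == nil {
		return errNilDebugger
	}
	p.mut.Lock()
	p.dbg = d
	p.mut.Unlock()
	return nil
}

// onCommit mirrors updateLastCommittedInDebugger: a read lock suffices
// because only the field itself is guarded here, not the debugger's
// internal state.
func (p *processor) onCommit(round uint64) {
	p.mut.RLock()
	p.dbg.SetLastCommittedBlockRound(round)
	p.mut.RUnlock()
}

func main() {
	p := &processor{dbg: noopDebugger{}}
	fmt.Println(p.setDebugger(nil)) // nil process debugger

	if err := p.setDebugger(noopDebugger{}); err == nil {
		p.onCommit(42)
	}
}
```

Starting from a disabled debugger instead of a nil field is what lets updateLastCommittedInDebugger skip a nil check on every committed block.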
73141649284..e06e9c6242b 100644 --- a/process/block/metablock.go +++ b/process/block/metablock.go @@ -95,6 +95,11 @@ func NewMetaProcessor(arguments ArgMetaProcessor) (*metaProcessor, error) { pruningDelay = defaultPruningDelay } + processDebugger, err := createDisabledProcessDebugger() + if err != nil { + return nil, err + } + genesisHdr := arguments.DataComponents.Blockchain().GetGenesisHeader() base := &baseProcessor{ accountsDB: arguments.AccountsDB, @@ -136,6 +141,7 @@ func NewMetaProcessor(arguments ArgMetaProcessor) (*metaProcessor, error) { pruningDelay: pruningDelay, processedMiniBlocksTracker: arguments.ProcessedMiniBlocksTracker, receiptsRepository: arguments.ReceiptsRepository, + processDebugger: processDebugger, } mp := metaProcessor{ @@ -1270,6 +1276,8 @@ func (mp *metaProcessor) CommitBlock( "nonce", headerHandler.GetNonce(), "hash", headerHash) + mp.updateLastCommittedInDebugger(headerHandler.GetRound()) + notarizedHeadersHashes, errNotCritical := mp.updateCrossShardInfo(header) if errNotCritical != nil { log.Debug("updateCrossShardInfo", "error", errNotCritical.Error()) diff --git a/process/block/metablock_test.go b/process/block/metablock_test.go index 2b683516ed7..1f8d2fbf4bf 100644 --- a/process/block/metablock_test.go +++ b/process/block/metablock_test.go @@ -1007,6 +1007,20 @@ func TestMetaProcessor_CommitBlockOkValsShouldWork(t *testing.T) { mp, _ := blproc.NewMetaProcessor(arguments) + debuggerMethodWasCalled := false + debugger := &testscommon.ProcessDebuggerStub{ + SetLastCommittedBlockRoundCalled: func(round uint64) { + assert.Equal(t, hdr.Round, round) + debuggerMethodWasCalled = true + }, + } + + err := mp.SetProcessDebugger(nil) + assert.Equal(t, process.ErrNilProcessDebugger, err) + + err = mp.SetProcessDebugger(debugger) + assert.Nil(t, err) + mdp.HeadersCalled = func() dataRetriever.HeadersPool { cs := &mock.HeadersCacherStub{} cs.RegisterHandlerCalled = func(i func(header data.HeaderHandler, key []byte)) { @@ -1027,9 +1041,10 @@ func TestMetaProcessor_CommitBlockOkValsShouldWork(t *testing.T) { } mp.SetHdrForCurrentBlock([]byte("hdr_hash1"), &block.Header{}, true) - err := mp.CommitBlock(hdr, body) + err = mp.CommitBlock(hdr, body) assert.Nil(t, err) assert.True(t, forkDetectorAddCalled) + assert.True(t, debuggerMethodWasCalled) // this should sleep as there is an async call to display current header and block in CommitBlock time.Sleep(time.Second) } diff --git a/process/block/preprocess/miniBlockBuilder_test.go b/process/block/preprocess/miniBlockBuilder_test.go index 433a1e08e6d..bfe21fb9f0f 100644 --- a/process/block/preprocess/miniBlockBuilder_test.go +++ b/process/block/preprocess/miniBlockBuilder_test.go @@ -7,8 +7,6 @@ import ( "sync" "testing" - stateMock "github.com/ElrondNetwork/elrond-go/testscommon/state" - "github.com/ElrondNetwork/elrond-go-core/data" "github.com/ElrondNetwork/elrond-go-core/data/block" "github.com/ElrondNetwork/elrond-go-core/data/transaction" @@ -17,6 +15,7 @@ import ( "github.com/ElrondNetwork/elrond-go/testscommon" "github.com/ElrondNetwork/elrond-go/testscommon/economicsmocks" "github.com/ElrondNetwork/elrond-go/testscommon/hashingMocks" + stateMock "github.com/ElrondNetwork/elrond-go/testscommon/state" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) diff --git a/process/block/shardblock.go b/process/block/shardblock.go index 0fd2ff3b69c..10138272c41 100644 --- a/process/block/shardblock.go +++ b/process/block/shardblock.go @@ -82,46 +82,52 @@ func NewShardProcessor(arguments 
ArgShardProcessor) (*shardProcessor, error) { pruningDelay = defaultPruningDelay } + processDebugger, err := createDisabledProcessDebugger() + if err != nil { + return nil, err + } + base := &baseProcessor{ - accountsDB: arguments.AccountsDB, - blockSizeThrottler: arguments.BlockSizeThrottler, - forkDetector: arguments.ForkDetector, - hasher: arguments.CoreComponents.Hasher(), - marshalizer: arguments.CoreComponents.InternalMarshalizer(), - store: arguments.DataComponents.StorageService(), - shardCoordinator: arguments.BootstrapComponents.ShardCoordinator(), - nodesCoordinator: arguments.NodesCoordinator, - uint64Converter: arguments.CoreComponents.Uint64ByteSliceConverter(), - requestHandler: arguments.RequestHandler, - appStatusHandler: arguments.CoreComponents.StatusHandler(), - blockChainHook: arguments.BlockChainHook, - txCoordinator: arguments.TxCoordinator, - roundHandler: arguments.CoreComponents.RoundHandler(), - epochStartTrigger: arguments.EpochStartTrigger, - headerValidator: arguments.HeaderValidator, - bootStorer: arguments.BootStorer, - blockTracker: arguments.BlockTracker, - dataPool: arguments.DataComponents.Datapool(), - stateCheckpointModulus: arguments.Config.StateTriesConfig.CheckpointRoundsModulus, - blockChain: arguments.DataComponents.Blockchain(), - feeHandler: arguments.FeeHandler, - outportHandler: arguments.StatusComponents.OutportHandler(), - genesisNonce: genesisHdr.GetNonce(), - versionedHeaderFactory: arguments.BootstrapComponents.VersionedHeaderFactory(), - headerIntegrityVerifier: arguments.BootstrapComponents.HeaderIntegrityVerifier(), - historyRepo: arguments.HistoryRepository, + accountsDB: arguments.AccountsDB, + blockSizeThrottler: arguments.BlockSizeThrottler, + forkDetector: arguments.ForkDetector, + hasher: arguments.CoreComponents.Hasher(), + marshalizer: arguments.CoreComponents.InternalMarshalizer(), + store: arguments.DataComponents.StorageService(), + shardCoordinator: arguments.BootstrapComponents.ShardCoordinator(), + nodesCoordinator: arguments.NodesCoordinator, + uint64Converter: arguments.CoreComponents.Uint64ByteSliceConverter(), + requestHandler: arguments.RequestHandler, + appStatusHandler: arguments.CoreComponents.StatusHandler(), + blockChainHook: arguments.BlockChainHook, + txCoordinator: arguments.TxCoordinator, + roundHandler: arguments.CoreComponents.RoundHandler(), + epochStartTrigger: arguments.EpochStartTrigger, + headerValidator: arguments.HeaderValidator, + bootStorer: arguments.BootStorer, + blockTracker: arguments.BlockTracker, + dataPool: arguments.DataComponents.Datapool(), + stateCheckpointModulus: arguments.Config.StateTriesConfig.CheckpointRoundsModulus, + blockChain: arguments.DataComponents.Blockchain(), + feeHandler: arguments.FeeHandler, + outportHandler: arguments.StatusComponents.OutportHandler(), + genesisNonce: genesisHdr.GetNonce(), + versionedHeaderFactory: arguments.BootstrapComponents.VersionedHeaderFactory(), + headerIntegrityVerifier: arguments.BootstrapComponents.HeaderIntegrityVerifier(), + historyRepo: arguments.HistoryRepository, epochNotifier: arguments.CoreComponents.EpochNotifier(), enableEpochsHandler: arguments.CoreComponents.EnableEpochsHandler(), enableRoundsHandler: arguments.EnableRoundsHandler, - vmContainerFactory: arguments.VMContainersFactory, - vmContainer: arguments.VmContainer, - processDataTriesOnCommitEpoch: arguments.Config.Debug.EpochStart.ProcessDataTrieOnCommitEpoch, - gasConsumedProvider: arguments.GasHandler, - economicsData: arguments.CoreComponents.EconomicsData(), - 
scheduledTxsExecutionHandler: arguments.ScheduledTxsExecutionHandler, - pruningDelay: pruningDelay, - processedMiniBlocksTracker: arguments.ProcessedMiniBlocksTracker, - receiptsRepository: arguments.ReceiptsRepository, + vmContainerFactory: arguments.VMContainersFactory, + vmContainer: arguments.VmContainer, + processDataTriesOnCommitEpoch: arguments.Config.Debug.EpochStart.ProcessDataTrieOnCommitEpoch, + gasConsumedProvider: arguments.GasHandler, + economicsData: arguments.CoreComponents.EconomicsData(), + scheduledTxsExecutionHandler: arguments.ScheduledTxsExecutionHandler, + pruningDelay: pruningDelay, + processedMiniBlocksTracker: arguments.ProcessedMiniBlocksTracker, + receiptsRepository: arguments.ReceiptsRepository, + processDebugger: processDebugger, } sp := shardProcessor{ @@ -1032,6 +1038,8 @@ func (sp *shardProcessor) CommitBlock( "hash", headerHash, ) + sp.updateLastCommittedInDebugger(headerHandler.GetRound()) + errNotCritical := sp.updateCrossShardInfo(processedMetaHdrs) if errNotCritical != nil { log.Debug("updateCrossShardInfo", "error", errNotCritical.Error()) diff --git a/process/block/shardblock_test.go b/process/block/shardblock_test.go index fbf32f16e71..ff4b379aa94 100644 --- a/process/block/shardblock_test.go +++ b/process/block/shardblock_test.go @@ -2105,13 +2105,27 @@ func TestShardProcessor_CommitBlockOkValsShouldWork(t *testing.T) { arguments.BlockTracker = blockTrackerMock sp, _ := blproc.NewShardProcessor(arguments) + debuggerMethodWasCalled := false + debugger := &testscommon.ProcessDebuggerStub{ + SetLastCommittedBlockRoundCalled: func(round uint64) { + assert.Equal(t, hdr.Round, round) + debuggerMethodWasCalled = true + }, + } - err := sp.ProcessBlock(hdr, body, haveTime) + err := sp.SetProcessDebugger(nil) + assert.Equal(t, process.ErrNilProcessDebugger, err) + + err = sp.SetProcessDebugger(debugger) + assert.Nil(t, err) + + err = sp.ProcessBlock(hdr, body, haveTime) assert.Nil(t, err) err = sp.CommitBlock(hdr, body) assert.Nil(t, err) assert.True(t, forkDetectorAddCalled) assert.Equal(t, hdrHash, blkc.GetCurrentBlockHeaderHash()) + assert.True(t, debuggerMethodWasCalled) // this should sleep as there is an async call to display current hdr and block in CommitBlock time.Sleep(time.Second) } diff --git a/process/coordinator/process.go b/process/coordinator/process.go index 6d895e68cf4..ff532881e3b 100644 --- a/process/coordinator/process.go +++ b/process/coordinator/process.go @@ -24,7 +24,7 @@ import ( "github.com/ElrondNetwork/elrond-go/sharding" "github.com/ElrondNetwork/elrond-go/state" "github.com/ElrondNetwork/elrond-go/storage" - "github.com/ElrondNetwork/elrond-go/storage/timecache" + "github.com/ElrondNetwork/elrond-go/storage/cache" ) var _ process.TransactionCoordinator = (*transactionCoordinator)(nil) @@ -155,7 +155,7 @@ func NewTransactionCoordinator(args ArgTransactionCoordinator) (*transactionCoor tc.interimProcessors[value] = interProc } - tc.requestedItemsHandler = timecache.NewTimeCache(common.MaxWaitingTimeToReceiveRequestedItem) + tc.requestedItemsHandler = cache.NewTimeCache(common.MaxWaitingTimeToReceiveRequestedItem) tc.miniBlockPool.RegisterHandler(tc.receivedMiniBlock, core.UniqueIdentifier()) return tc, nil diff --git a/process/coordinator/process_test.go b/process/coordinator/process_test.go index 3381eb61f84..bafed34fe7c 100644 --- a/process/coordinator/process_test.go +++ b/process/coordinator/process_test.go @@ -28,8 +28,8 @@ import ( "github.com/ElrondNetwork/elrond-go/process/factory/shard" 
"github.com/ElrondNetwork/elrond-go/process/mock" "github.com/ElrondNetwork/elrond-go/storage" - "github.com/ElrondNetwork/elrond-go/storage/memorydb" - "github.com/ElrondNetwork/elrond-go/storage/storageUnit" + "github.com/ElrondNetwork/elrond-go/storage/database" + "github.com/ElrondNetwork/elrond-go/storage/storageunit" "github.com/ElrondNetwork/elrond-go/testscommon" dataRetrieverMock "github.com/ElrondNetwork/elrond-go/testscommon/dataRetriever" "github.com/ElrondNetwork/elrond-go/testscommon/hashingMocks" @@ -196,14 +196,14 @@ func initStore() *dataRetriever.ChainStorer { } func generateTestCache() storage.Cacher { - cache, _ := storageUnit.NewCache(storageUnit.CacheConfig{Type: storageUnit.LRUCache, Capacity: 1000, Shards: 1, SizeInBytes: 0}) + cache, _ := storageunit.NewCache(storageunit.CacheConfig{Type: storageunit.LRUCache, Capacity: 1000, Shards: 1, SizeInBytes: 0}) return cache } func generateTestUnit() storage.Storer { - storer, _ := storageUnit.NewStorageUnit( + storer, _ := storageunit.NewStorageUnit( generateTestCache(), - memorydb.New(), + database.NewMemDB(), ) return storer @@ -856,8 +856,8 @@ func TestTransactionCoordinator_CreateMbsAndProcessCrossShardTransactions(t *tes t.Parallel() tdp := initDataPool(txHash) - cacherCfg := storageUnit.CacheConfig{Capacity: 100, Type: storageUnit.LRUCache} - hdrPool, _ := storageUnit.NewCache(cacherCfg) + cacherCfg := storageunit.CacheConfig{Capacity: 100, Type: storageunit.LRUCache} + hdrPool, _ := storageunit.NewCache(cacherCfg) tdp.MiniBlocksCalled = func() storage.Cacher { return hdrPool } @@ -1044,8 +1044,8 @@ func TestTransactionCoordinator_CreateMbsAndProcessCrossShardTransactionsNilPreP t.Parallel() tdp := initDataPool(txHash) - cacherCfg := storageUnit.CacheConfig{Capacity: 100, Type: storageUnit.LRUCache} - hdrPool, _ := storageUnit.NewCache(cacherCfg) + cacherCfg := storageunit.CacheConfig{Capacity: 100, Type: storageunit.LRUCache} + hdrPool, _ := storageunit.NewCache(cacherCfg) tdp.MiniBlocksCalled = func() storage.Cacher { return hdrPool } diff --git a/process/errors.go b/process/errors.go index 69842defd1e..a743d6079ab 100644 --- a/process/errors.go +++ b/process/errors.go @@ -1131,6 +1131,9 @@ var ErrNilESDTGlobalSettingsHandler = errors.New("nil esdt global settings handl // ErrNilEnableEpochsHandler signals that a nil enable epochs handler has been provided var ErrNilEnableEpochsHandler = errors.New("nil enable epochs handler") +// ErrNilMultiSignerContainer signals that the given multisigner container is nil +var ErrNilMultiSignerContainer = errors.New("nil multiSigner container") + // ErrNilCrawlerAllowedAddress signals that no crawler allowed address was found var ErrNilCrawlerAllowedAddress = errors.New("nil crawler allowed address") @@ -1145,3 +1148,6 @@ var ErrPropertyTooLong = errors.New("property too long") // ErrPropertyTooShort signals that a heartbeat property was too short var ErrPropertyTooShort = errors.New("property too short") + +// ErrNilProcessDebugger signals that a nil process debugger was provided +var ErrNilProcessDebugger = errors.New("nil process debugger") diff --git a/process/factory/interceptorscontainer/baseInterceptorsContainerFactory.go b/process/factory/interceptorscontainer/baseInterceptorsContainerFactory.go index 7aeecea423f..67b93dd7684 100644 --- a/process/factory/interceptorscontainer/baseInterceptorsContainerFactory.go +++ b/process/factory/interceptorscontainer/baseInterceptorsContainerFactory.go @@ -100,7 +100,11 @@ func checkBaseParams( if 
coreComponents.MinTransactionVersion() == 0 { return process.ErrInvalidTransactionVersion } - if check.IfNil(cryptoComponents.MultiSigner()) { + multiSigner, err := cryptoComponents.GetMultiSigner(0) + if err != nil { + return err + } + if check.IfNil(multiSigner) { return process.ErrNilMultiSigVerifier } if check.IfNil(cryptoComponents.BlockSignKeyGen()) { diff --git a/process/factory/interceptorscontainer/metaInterceptorsContainerFactory_test.go b/process/factory/interceptorscontainer/metaInterceptorsContainerFactory_test.go index c42c2bf76c7..98369ebc6ee 100644 --- a/process/factory/interceptorscontainer/metaInterceptorsContainerFactory_test.go +++ b/process/factory/interceptorscontainer/metaInterceptorsContainerFactory_test.go @@ -13,6 +13,7 @@ import ( "github.com/ElrondNetwork/elrond-go/process/mock" "github.com/ElrondNetwork/elrond-go/storage" "github.com/ElrondNetwork/elrond-go/testscommon" + "github.com/ElrondNetwork/elrond-go/testscommon/cryptoMocks" dataRetrieverMock "github.com/ElrondNetwork/elrond-go/testscommon/dataRetriever" "github.com/ElrondNetwork/elrond-go/testscommon/p2pmocks" "github.com/ElrondNetwork/elrond-go/testscommon/shardingMocks" @@ -230,7 +231,7 @@ func TestNewMetaInterceptorsContainerFactory_NilMultiSignerShouldErr(t *testing. t.Parallel() coreComp, cryptoComp := createMockComponentHolders() - cryptoComp.MultiSig = nil + cryptoComp.MultiSigContainer = cryptoMocks.NewMultiSignerContainerMock(nil) args := getArgumentsMeta(coreComp, cryptoComp) icf, err := interceptorscontainer.NewMetaInterceptorsContainerFactory(args) diff --git a/process/factory/interceptorscontainer/shardInterceptorsContainerFactory_test.go b/process/factory/interceptorscontainer/shardInterceptorsContainerFactory_test.go index 127f53421b3..05e12772ca5 100644 --- a/process/factory/interceptorscontainer/shardInterceptorsContainerFactory_test.go +++ b/process/factory/interceptorscontainer/shardInterceptorsContainerFactory_test.go @@ -269,7 +269,7 @@ func TestNewShardInterceptorsContainerFactory_NilMultiSignerShouldErr(t *testing t.Parallel() coreComp, cryptoComp := createMockComponentHolders() - cryptoComp.MultiSig = nil + cryptoComp.MultiSigContainer = cryptoMocks.NewMultiSignerContainerMock(nil) args := getArgumentsShard(coreComp, cryptoComp) icf, err := interceptorscontainer.NewShardInterceptorsContainerFactory(args) @@ -687,12 +687,13 @@ func createMockComponentHolders() (*mock.CoreComponentsMock, *mock.CryptoCompone HardforkTriggerPubKeyField: providedHardforkPubKey, EnableEpochsHandlerField: &testscommon.EnableEpochsHandlerStub{}, } + multiSigner := cryptoMocks.NewMultiSigner() cryptoComponents := &mock.CryptoComponentsMock{ - BlockSig: &mock.SignerMock{}, - TxSig: &mock.SignerMock{}, - MultiSig: cryptoMocks.NewMultiSigner(21), - BlKeyGen: &mock.SingleSignKeyGenMock{}, - TxKeyGen: &mock.SingleSignKeyGenMock{}, + BlockSig: &mock.SignerMock{}, + TxSig: &mock.SignerMock{}, + MultiSigContainer: cryptoMocks.NewMultiSignerContainerMock(multiSigner), + BlKeyGen: &mock.SingleSignKeyGenMock{}, + TxKeyGen: &mock.SingleSignKeyGenMock{}, } return coreComponents, cryptoComponents diff --git a/process/headerCheck/errors.go b/process/headerCheck/errors.go index 443cea6dc84..152e5c62dfa 100644 --- a/process/headerCheck/errors.go +++ b/process/headerCheck/errors.go @@ -17,3 +17,9 @@ var ErrInvalidChainID = errors.New("invalid chain ID") // ErrNilHeaderVersionHandler signals that the provided header version handler is nil var ErrNilHeaderVersionHandler = errors.New("nil header version handler") + +// 
ErrIndexOutOfBounds signals that the given index is outside of expected bounds +var ErrIndexOutOfBounds = errors.New("index is out of bounds") + +// ErrIndexNotSelected signals that the given index is not selected +var ErrIndexNotSelected = errors.New("index is not selected") diff --git a/process/headerCheck/headerSignatureVerify.go b/process/headerCheck/headerSignatureVerify.go index 6a6fd78d59a..e424ef68677 100644 --- a/process/headerCheck/headerSignatureVerify.go +++ b/process/headerCheck/headerSignatureVerify.go @@ -10,6 +10,7 @@ import ( "github.com/ElrondNetwork/elrond-go-core/marshal" crypto "github.com/ElrondNetwork/elrond-go-crypto" logger "github.com/ElrondNetwork/elrond-go-logger" + cryptoCommon "github.com/ElrondNetwork/elrond-go/common/crypto" "github.com/ElrondNetwork/elrond-go/process" "github.com/ElrondNetwork/elrond-go/sharding/nodesCoordinator" ) @@ -23,7 +24,7 @@ type ArgsHeaderSigVerifier struct { Marshalizer marshal.Marshalizer Hasher hashing.Hasher NodesCoordinator nodesCoordinator.NodesCoordinator - MultiSigVerifier crypto.MultiSigVerifier + MultiSigContainer cryptoCommon.MultiSignerContainer SingleSigVerifier crypto.SingleSigner KeyGen crypto.KeyGenerator FallbackHeaderValidator process.FallbackHeaderValidator @@ -34,7 +35,7 @@ type HeaderSigVerifier struct { marshalizer marshal.Marshalizer hasher hashing.Hasher nodesCoordinator nodesCoordinator.NodesCoordinator - multiSigVerifier crypto.MultiSigVerifier + multiSigContainer cryptoCommon.MultiSignerContainer singleSigVerifier crypto.SingleSigner keyGen crypto.KeyGenerator fallbackHeaderValidator process.FallbackHeaderValidator @@ -51,7 +52,7 @@ func NewHeaderSigVerifier(arguments *ArgsHeaderSigVerifier) (*HeaderSigVerifier, marshalizer: arguments.Marshalizer, hasher: arguments.Hasher, nodesCoordinator: arguments.NodesCoordinator, - multiSigVerifier: arguments.MultiSigVerifier, + multiSigContainer: arguments.MultiSigContainer, singleSigVerifier: arguments.SingleSigVerifier, keyGen: arguments.KeyGen, fallbackHeaderValidator: arguments.FallbackHeaderValidator, @@ -71,7 +72,14 @@ func checkArgsHeaderSigVerifier(arguments *ArgsHeaderSigVerifier) error { if check.IfNil(arguments.Marshalizer) { return process.ErrNilMarshalizer } - if check.IfNil(arguments.MultiSigVerifier) { + if check.IfNil(arguments.MultiSigContainer) { + return process.ErrNilMultiSignerContainer + } + multiSigner, err := arguments.MultiSigContainer.GetMultiSigner(0) + if err != nil { + return err + } + if check.IfNil(multiSigner) { return process.ErrNilMultiSigVerifier } if check.IfNil(arguments.NodesCoordinator) { @@ -87,50 +95,70 @@ func checkArgsHeaderSigVerifier(arguments *ArgsHeaderSigVerifier) error { return nil } -// VerifySignature will check if signature is correct -func (hsv *HeaderSigVerifier) VerifySignature(header data.HeaderHandler) error { +func isIndexInBitmap(index uint16, bitmap []byte) error { + indexOutOfBounds := index >= uint16(len(bitmap)*8) + if indexOutOfBounds { + return ErrIndexOutOfBounds + } + + indexNotInBitmap := bitmap[index/8]&(1<<uint8(index%8)) == 0 + if indexNotInBitmap { + return ErrIndexNotSelected + } + + return nil +} + +func (hsv *HeaderSigVerifier) getConsensusSigners(header data.HeaderHandler) ([][]byte, error) { randSeed := header.GetPrevRandSeed() bitmap := header.GetPubKeysBitmap() if len(bitmap) == 0 { - return process.ErrNilPubKeysBitmap + return nil, process.ErrNilPubKeysBitmap } if bitmap[0]&1 == 0 { - return process.ErrBlockProposerSignatureMissing + return nil, process.ErrBlockProposerSignatureMissing } - // TODO: remove if start of epoch block needs to be validated by the new epoch nodes - epoch := header.GetEpoch() - if header.IsStartOfEpochBlock() && epoch > 0 { - epoch = epoch - 1 + // TODO: remove if start of epochForConsensus block needs to be validated by the new epochForConsensus nodes + epochForConsensus := header.GetEpoch() + if header.IsStartOfEpochBlock() && epochForConsensus > 0 { + epochForConsensus = epochForConsensus - 1 } consensusPubKeys, err := hsv.nodesCoordinator.GetConsensusValidatorsPublicKeys( randSeed, header.GetRound(), header.GetShardID(), - epoch, + epochForConsensus, ) if err != nil { - return err + return nil, err } err = 
hsv.verifyConsensusSize(consensusPubKeys, header) if err != nil { - return err + return nil, err } - verifier, err := hsv.multiSigVerifier.Create(consensusPubKeys, 0) - if err != nil { - return err + pubKeysSigners := make([][]byte, 0, len(consensusPubKeys)) + for i := range consensusPubKeys { + err = isIndexInBitmap(uint16(i), bitmap) + if err != nil { + continue + } + pubKeysSigners = append(pubKeysSigners, []byte(consensusPubKeys[i])) } - err = verifier.SetAggregatedSig(header.GetSignature()) + return pubKeysSigners, nil +} + +// VerifySignature will check if signature is correct +func (hsv *HeaderSigVerifier) VerifySignature(header data.HeaderHandler) error { + multiSigVerifier, err := hsv.multiSigContainer.GetMultiSigner(header.GetEpoch()) if err != nil { return err } - // get marshalled block header without signature and bitmap - // as this is the message that was signed headerCopy, err := hsv.copyHeaderWithoutSig(header) if err != nil { return err } @@ -141,7 +169,12 @@ func (hsv *HeaderSigVerifier) VerifySignature(header data.HeaderHandler) error { return err } - return verifier.Verify(hash, bitmap) + pubKeysSigners, err := hsv.getConsensusSigners(header) + if err != nil { + return err + } + + return multiSigVerifier.VerifyAggregatedSig(pubKeysSigners, hash, header.GetSignature()) } func (hsv *HeaderSigVerifier) verifyConsensusSize(consensusPubKeys []string, header data.HeaderHandler) error { diff --git a/process/headerCheck/headerSignatureVerify_test.go b/process/headerCheck/headerSignatureVerify_test.go index 2f0401375cf..7d01fbe656c 100644 --- a/process/headerCheck/headerSignatureVerify_test.go +++ b/process/headerCheck/headerSignatureVerify_test.go @@ -25,7 +25,7 @@ func createHeaderSigVerifierArgs() *ArgsHeaderSigVerifier { Marshalizer: &mock.MarshalizerMock{}, Hasher: &hashingMocks.HasherMock{}, NodesCoordinator: &shardingMocks.NodesCoordinatorMock{}, - MultiSigVerifier: cryptoMocks.NewMultiSigner(21), + MultiSigContainer: cryptoMocks.NewMultiSignerContainerMock(cryptoMocks.NewMultiSigner()), SingleSigVerifier: &mock.SignerMock{}, KeyGen: &mock.SingleSignKeyGenMock{}, FallbackHeaderValidator: &testscommon.FallBackHeaderValidatorStub{}, @@ -78,7 +78,7 @@ func TestNewHeaderSigVerifier_NilMultiSigShouldErr(t *testing.T) { t.Parallel() args := createHeaderSigVerifierArgs() - args.MultiSigVerifier = nil + args.MultiSigContainer = cryptoMocks.NewMultiSignerContainerMock(nil) hdrSigVerifier, err := NewHeaderSigVerifier(args) require.Nil(t, hdrSigVerifier) @@ -148,13 +148,13 @@ func TestHeaderSigVerifier_VerifyRandSeedOk(t *testing.T) { } pkAddr := []byte("aaa00000000000000000000000000000") - nodesCoordinator := &shardingMocks.NodesCoordinatorMock{ + nc := &shardingMocks.NodesCoordinatorMock{ ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (validators []nodesCoordinator.Validator, err error) { v, _ := nodesCoordinator.NewValidator(pkAddr, 1, defaultChancesSelection) return []nodesCoordinator.Validator{v}, nil }, } - args.NodesCoordinator = nodesCoordinator + args.NodesCoordinator = nc hdrSigVerifier, _ := NewHeaderSigVerifier(args) header := &dataBlock.Header{} @@ -183,13 +183,13 @@ func TestHeaderSigVerifier_VerifyRandSeedShouldErrWhenVerificationFails(t *testi } pkAddr := []byte("aaa00000000000000000000000000000") - nodesCoordinator := &shardingMocks.NodesCoordinatorMock{ + nc := &shardingMocks.NodesCoordinatorMock{ ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (validators 
[]nodesCoordinator.Validator, err error) { v, _ := nodesCoordinator.NewValidator(pkAddr, 1, defaultChancesSelection) return []nodesCoordinator.Validator{v}, nil }, } - args.NodesCoordinator = nodesCoordinator + args.NodesCoordinator = nc hdrSigVerifier, _ := NewHeaderSigVerifier(args) header := &dataBlock.Header{} @@ -229,13 +229,13 @@ func TestHeaderSigVerifier_VerifyRandSeedAndLeaderSignatureVerifyShouldErrWhenVa } pkAddr := []byte("aaa00000000000000000000000000000") - nodesCoordinator := &shardingMocks.NodesCoordinatorMock{ + nc := &shardingMocks.NodesCoordinatorMock{ ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (validators []nodesCoordinator.Validator, err error) { v, _ := nodesCoordinator.NewValidator(pkAddr, 1, defaultChancesSelection) return []nodesCoordinator.Validator{v}, nil }, } - args.NodesCoordinator = nodesCoordinator + args.NodesCoordinator = nc hdrSigVerifier, _ := NewHeaderSigVerifier(args) header := &dataBlock.Header{} @@ -268,13 +268,13 @@ func TestHeaderSigVerifier_VerifyRandSeedAndLeaderSignatureVerifyLeaderSigShould } pkAddr := []byte("aaa00000000000000000000000000000") - nodesCoordinator := &shardingMocks.NodesCoordinatorMock{ + nc := &shardingMocks.NodesCoordinatorMock{ ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (validators []nodesCoordinator.Validator, err error) { v, _ := nodesCoordinator.NewValidator(pkAddr, 1, defaultChancesSelection) return []nodesCoordinator.Validator{v}, nil }, } - args.NodesCoordinator = nodesCoordinator + args.NodesCoordinator = nc hdrSigVerifier, _ := NewHeaderSigVerifier(args) header := &dataBlock.Header{ LeaderSignature: leaderSig, @@ -304,13 +304,13 @@ func TestHeaderSigVerifier_VerifyRandSeedAndLeaderSignatureOk(t *testing.T) { } pkAddr := []byte("aaa00000000000000000000000000000") - nodesCoordinator := &shardingMocks.NodesCoordinatorMock{ + nc := &shardingMocks.NodesCoordinatorMock{ ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (validators []nodesCoordinator.Validator, err error) { v, _ := nodesCoordinator.NewValidator(pkAddr, 1, defaultChancesSelection) return []nodesCoordinator.Validator{v}, nil }, } - args.NodesCoordinator = nodesCoordinator + args.NodesCoordinator = nc hdrSigVerifier, _ := NewHeaderSigVerifier(args) header := &dataBlock.Header{} @@ -350,13 +350,13 @@ func TestHeaderSigVerifier_VerifyLeaderSignatureVerifyShouldErrWhenValidationFai } pkAddr := []byte("aaa00000000000000000000000000000") - nodesCoordinator := &shardingMocks.NodesCoordinatorMock{ + nc := &shardingMocks.NodesCoordinatorMock{ ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (validators []nodesCoordinator.Validator, err error) { v, _ := nodesCoordinator.NewValidator(pkAddr, 1, defaultChancesSelection) return []nodesCoordinator.Validator{v}, nil }, } - args.NodesCoordinator = nodesCoordinator + args.NodesCoordinator = nc hdrSigVerifier, _ := NewHeaderSigVerifier(args) header := &dataBlock.Header{} @@ -389,13 +389,13 @@ func TestHeaderSigVerifier_VerifyLeaderSignatureVerifyLeaderSigShouldErr(t *test } pkAddr := []byte("aaa00000000000000000000000000000") - nodesCoordinator := &shardingMocks.NodesCoordinatorMock{ + nc := &shardingMocks.NodesCoordinatorMock{ ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (validators []nodesCoordinator.Validator, err error) { v, _ := nodesCoordinator.NewValidator(pkAddr, 
1, defaultChancesSelection) return []nodesCoordinator.Validator{v}, nil }, } - args.NodesCoordinator = nodesCoordinator + args.NodesCoordinator = nc hdrSigVerifier, _ := NewHeaderSigVerifier(args) header := &dataBlock.Header{ LeaderSignature: leaderSig, @@ -425,13 +425,13 @@ func TestHeaderSigVerifier_VerifyLeaderSignatureOk(t *testing.T) { } pkAddr := []byte("aaa00000000000000000000000000000") - nodesCoordinator := &shardingMocks.NodesCoordinatorMock{ + nc := &shardingMocks.NodesCoordinatorMock{ ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (validators []nodesCoordinator.Validator, err error) { v, _ := nodesCoordinator.NewValidator(pkAddr, 1, defaultChancesSelection) return []nodesCoordinator.Validator{v}, nil }, } - args.NodesCoordinator = nodesCoordinator + args.NodesCoordinator = nc hdrSigVerifier, _ := NewHeaderSigVerifier(args) header := &dataBlock.Header{} @@ -482,13 +482,13 @@ func TestHeaderSigVerifier_VerifySignatureWrongSizeBitmapShouldErr(t *testing.T) args := createHeaderSigVerifierArgs() pkAddr := []byte("aaa00000000000000000000000000000") - nodesCoordinator := &shardingMocks.NodesCoordinatorMock{ + nc := &shardingMocks.NodesCoordinatorMock{ ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (validators []nodesCoordinator.Validator, err error) { v, _ := nodesCoordinator.NewValidator(pkAddr, 1, defaultChancesSelection) return []nodesCoordinator.Validator{v}, nil }, } - args.NodesCoordinator = nodesCoordinator + args.NodesCoordinator = nc hdrSigVerifier, _ := NewHeaderSigVerifier(args) header := &dataBlock.Header{ @@ -504,13 +504,13 @@ func TestHeaderSigVerifier_VerifySignatureNotEnoughSigsShouldErr(t *testing.T) { args := createHeaderSigVerifierArgs() pkAddr := []byte("aaa00000000000000000000000000000") - nodesCoordinator := &shardingMocks.NodesCoordinatorMock{ + nc := &shardingMocks.NodesCoordinatorMock{ ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (validators []nodesCoordinator.Validator, err error) { v, _ := nodesCoordinator.NewValidator(pkAddr, 1, defaultChancesSelection) return []nodesCoordinator.Validator{v, v, v, v, v}, nil }, } - args.NodesCoordinator = nodesCoordinator + args.NodesCoordinator = nc hdrSigVerifier, _ := NewHeaderSigVerifier(args) header := &dataBlock.Header{ @@ -527,23 +527,19 @@ func TestHeaderSigVerifier_VerifySignatureOk(t *testing.T) { wasCalled := false args := createHeaderSigVerifierArgs() pkAddr := []byte("aaa00000000000000000000000000000") - nodesCoordinator := &shardingMocks.NodesCoordinatorMock{ + nc := &shardingMocks.NodesCoordinatorMock{ ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (validators []nodesCoordinator.Validator, err error) { v, _ := nodesCoordinator.NewValidator(pkAddr, 1, defaultChancesSelection) return []nodesCoordinator.Validator{v}, nil }, } - args.NodesCoordinator = nodesCoordinator + args.NodesCoordinator = nc - args.MultiSigVerifier = &cryptoMocks.MultisignerMock{ - CreateCalled: func(pubKeys []string, index uint16) (signer crypto.MultiSigner, err error) { - return &cryptoMocks.MultisignerMock{ - VerifyCalled: func(msg []byte, bitmap []byte) error { - wasCalled = true - return nil - }}, nil - }, - } + args.MultiSigContainer = cryptoMocks.NewMultiSignerContainerMock(&cryptoMocks.MultisignerMock{ + VerifyAggregatedSigCalled: func(pubKeysSigners [][]byte, message []byte, aggSig []byte) error { + wasCalled = true + return nil 
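The stubbed VerifyAggregatedSigCalled in these tests receives exactly the signers that getConsensusSigners extracted from the header's bitmap, and that extraction is plain bit arithmetic: consensus member i signed iff bit i%8 of byte i/8 is set, with bit 0 required because it marks the block proposer. A self-contained check of that layout (signerIndices is a hypothetical helper written for this illustration):

```go
package main

import "fmt"

// signerIndices expands a pub-keys bitmap into the consensus indices
// that signed, using the same layout as isIndexInBitmap: member i maps
// to bit i%8 of byte i/8.
func signerIndices(bitmap []byte, consensusSize int) []int {
	signers := make([]int, 0, consensusSize)
	for i := 0; i < consensusSize && i < len(bitmap)*8; i++ {
		if bitmap[i/8]&(1<<uint8(i%8)) != 0 {
			signers = append(signers, i)
		}
	}
	return signers
}

func main() {
	// 0x07 = 0b00000111 marks members 0, 1 and 2; 0x01 in the second
	// byte marks member 8. Bit 0 being set means the proposer signed.
	bitmap := []byte{0x07, 0x01}
	fmt.Println(signerIndices(bitmap, 16)) // [0 1 2 8]
}
```

This also explains the WrongSizeBitmap test above: a bitmap with fewer than ceil(consensusSize/8) bytes leaves some consensus indices out of bounds, which the verifier rejects before any cryptography runs.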
+ }}) hdrSigVerifier, _ := NewHeaderSigVerifier(args) header := &dataBlock.Header{ @@ -561,7 +557,7 @@ func TestHeaderSigVerifier_VerifySignatureNotEnoughSigsShouldErrWhenFallbackThre wasCalled := false args := createHeaderSigVerifierArgs() pkAddr := []byte("aaa00000000000000000000000000000") - nodesCoordinator := &shardingMocks.NodesCoordinatorMock{ + nc := &shardingMocks.NodesCoordinatorMock{ ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (validators []nodesCoordinator.Validator, err error) { v, _ := nodesCoordinator.NewValidator(pkAddr, 1, defaultChancesSelection) return []nodesCoordinator.Validator{v, v, v, v, v}, nil @@ -573,18 +569,15 @@ func TestHeaderSigVerifier_VerifySignatureNotEnoughSigsShouldErrWhenFallbackThre }, } multiSigVerifier := &cryptoMocks.MultisignerMock{ - CreateCalled: func(pubKeys []string, index uint16) (signer crypto.MultiSigner, err error) { - return &cryptoMocks.MultisignerMock{ - VerifyCalled: func(msg []byte, bitmap []byte) error { - wasCalled = true - return nil - }}, nil + VerifyAggregatedSigCalled: func(pubKeysSigners [][]byte, message []byte, aggSig []byte) error { + wasCalled = true + return nil }, } - args.NodesCoordinator = nodesCoordinator + args.NodesCoordinator = nc args.FallbackHeaderValidator = fallbackHeaderValidator - args.MultiSigVerifier = multiSigVerifier + args.MultiSigContainer = cryptoMocks.NewMultiSignerContainerMock(multiSigVerifier) hdrSigVerifier, _ := NewHeaderSigVerifier(args) header := &dataBlock.MetaBlock{ @@ -602,7 +595,7 @@ func TestHeaderSigVerifier_VerifySignatureOkWhenFallbackThresholdCouldBeApplied( wasCalled := false args := createHeaderSigVerifierArgs() pkAddr := []byte("aaa00000000000000000000000000000") - nodesCoordinator := &shardingMocks.NodesCoordinatorMock{ + nc := &shardingMocks.NodesCoordinatorMock{ ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (validators []nodesCoordinator.Validator, err error) { v, _ := nodesCoordinator.NewValidator(pkAddr, 1, defaultChancesSelection) return []nodesCoordinator.Validator{v, v, v, v, v}, nil @@ -614,18 +607,14 @@ func TestHeaderSigVerifier_VerifySignatureOkWhenFallbackThresholdCouldBeApplied( }, } multiSigVerifier := &cryptoMocks.MultisignerMock{ - CreateCalled: func(pubKeys []string, index uint16) (signer crypto.MultiSigner, err error) { - return &cryptoMocks.MultisignerMock{ - VerifyCalled: func(msg []byte, bitmap []byte) error { - wasCalled = true - return nil - }}, nil - }, - } + VerifyAggregatedSigCalled: func(pubKeysSigners [][]byte, message []byte, aggSig []byte) error { + wasCalled = true + return nil + }} - args.NodesCoordinator = nodesCoordinator + args.NodesCoordinator = nc args.FallbackHeaderValidator = fallbackHeaderValidator - args.MultiSigVerifier = multiSigVerifier + args.MultiSigContainer = cryptoMocks.NewMultiSignerContainerMock(multiSigVerifier) hdrSigVerifier, _ := NewHeaderSigVerifier(args) header := &dataBlock.MetaBlock{ diff --git a/process/interceptors/factory/argInterceptedDataFactory.go b/process/interceptors/factory/argInterceptedDataFactory.go index 889accc3f1b..8cc8aa0abd7 100644 --- a/process/interceptors/factory/argInterceptedDataFactory.go +++ b/process/interceptors/factory/argInterceptedDataFactory.go @@ -34,7 +34,7 @@ type interceptedDataCryptoComponentsHolder interface { BlockSignKeyGen() crypto.KeyGenerator TxSingleSigner() crypto.SingleSigner BlockSigner() crypto.SingleSigner - MultiSigner() crypto.MultiSigner + GetMultiSigner(epoch uint32) 
(crypto.MultiSigner, error) PublicKey() crypto.PublicKey IsInterfaceNil() bool } diff --git a/process/interceptors/factory/interceptedDirectConnectionInfoFactory_test.go b/process/interceptors/factory/interceptedDirectConnectionInfoFactory_test.go index ac2b4ab5cac..eeac92dee16 100644 --- a/process/interceptors/factory/interceptedDirectConnectionInfoFactory_test.go +++ b/process/interceptors/factory/interceptedDirectConnectionInfoFactory_test.go @@ -6,7 +6,7 @@ import ( "testing" "github.com/ElrondNetwork/elrond-go-core/core/check" - "github.com/ElrondNetwork/elrond-go/p2p/message" + "github.com/ElrondNetwork/elrond-go/p2p" "github.com/ElrondNetwork/elrond-go/process" "github.com/stretchr/testify/assert" ) @@ -56,7 +56,7 @@ func TestNewInterceptedDirectConnectionInfoFactory(t *testing.T) { assert.Nil(t, err) assert.False(t, check.IfNil(idcif)) - msg := &message.DirectConnectionInfo{ + msg := &p2p.DirectConnectionInfo{ ShardId: "5", } msgBuff, _ := arg.CoreComponents.InternalMarshalizer().Marshal(msg) diff --git a/process/interceptors/factory/interceptedMetaHeaderDataFactory_test.go b/process/interceptors/factory/interceptedMetaHeaderDataFactory_test.go index 059c335652a..d7bc91d1bbc 100644 --- a/process/interceptors/factory/interceptedMetaHeaderDataFactory_test.go +++ b/process/interceptors/factory/interceptedMetaHeaderDataFactory_test.go @@ -74,11 +74,11 @@ func createMockComponentHolders() (*mock.CoreComponentsMock, *mock.CryptoCompone EnableEpochsHandlerField: &testscommon.EnableEpochsHandlerStub{}, } cryptoComponents := &mock.CryptoComponentsMock{ - BlockSig: createMockSigner(), - TxSig: createMockSigner(), - MultiSig: cryptoMocks.NewMultiSigner(21), - BlKeyGen: createMockKeyGen(), - TxKeyGen: createMockKeyGen(), + BlockSig: createMockSigner(), + TxSig: createMockSigner(), + MultiSigContainer: cryptoMocks.NewMultiSignerContainerMock(cryptoMocks.NewMultiSigner()), + BlKeyGen: createMockKeyGen(), + TxKeyGen: createMockKeyGen(), } return coreComponents, cryptoComponents diff --git a/process/interceptors/processor/directConnectionInfoInterceptorProcessor_test.go b/process/interceptors/processor/directConnectionInfoInterceptorProcessor_test.go index c7bf45dc972..76b8469c47a 100644 --- a/process/interceptors/processor/directConnectionInfoInterceptorProcessor_test.go +++ b/process/interceptors/processor/directConnectionInfoInterceptorProcessor_test.go @@ -7,11 +7,11 @@ import ( "github.com/ElrondNetwork/elrond-go-core/core/check" "github.com/ElrondNetwork/elrond-go-core/marshal" heartbeatMessages "github.com/ElrondNetwork/elrond-go/heartbeat" - "github.com/ElrondNetwork/elrond-go/p2p/message" + "github.com/ElrondNetwork/elrond-go/p2p" "github.com/ElrondNetwork/elrond-go/process" "github.com/ElrondNetwork/elrond-go/process/heartbeat" "github.com/ElrondNetwork/elrond-go/process/mock" - "github.com/ElrondNetwork/elrond-go/process/p2p" + processP2P "github.com/ElrondNetwork/elrond-go/process/p2p" "github.com/stretchr/testify/assert" ) @@ -87,17 +87,17 @@ func TestDirectConnectionInfoInterceptorProcessor_Save(t *testing.T) { assert.Nil(t, err) assert.False(t, check.IfNil(processor)) - msg := &message.DirectConnectionInfo{ + msg := &p2p.DirectConnectionInfo{ ShardId: "invalid shard", } marshaller := marshal.GogoProtoMarshalizer{} dataBuff, _ := marshaller.Marshal(msg) - arg := p2p.ArgInterceptedDirectConnectionInfo{ + arg := processP2P.ArgInterceptedDirectConnectionInfo{ Marshaller: &marshaller, DataBuff: dataBuff, NumOfShards: 10, } - data, _ := p2p.NewInterceptedDirectConnectionInfo(arg) + data, 
_ := processP2P.NewInterceptedDirectConnectionInfo(arg) err = processor.Save(data, "", "") assert.NotNil(t, err) @@ -118,17 +118,17 @@ func TestDirectConnectionInfoInterceptorProcessor_Save(t *testing.T) { assert.Nil(t, err) assert.False(t, check.IfNil(processor)) - msg := &message.DirectConnectionInfo{ + msg := &p2p.DirectConnectionInfo{ ShardId: "5", } marshaller := marshal.GogoProtoMarshalizer{} dataBuff, _ := marshaller.Marshal(msg) - arg := p2p.ArgInterceptedDirectConnectionInfo{ + arg := processP2P.ArgInterceptedDirectConnectionInfo{ Marshaller: &marshaller, DataBuff: dataBuff, NumOfShards: 10, } - data, _ := p2p.NewInterceptedDirectConnectionInfo(arg) + data, _ := processP2P.NewInterceptedDirectConnectionInfo(arg) err = processor.Save(data, "", "") assert.Nil(t, err) diff --git a/process/interface.go b/process/interface.go index 4ae73a5f779..ec847e18209 100644 --- a/process/interface.go +++ b/process/interface.go @@ -19,6 +19,7 @@ import ( "github.com/ElrondNetwork/elrond-go-core/marshal" "github.com/ElrondNetwork/elrond-go-crypto" "github.com/ElrondNetwork/elrond-go/common" + cryptoCommon "github.com/ElrondNetwork/elrond-go/common/crypto" "github.com/ElrondNetwork/elrond-go/epochStart" "github.com/ElrondNetwork/elrond-go/p2p" "github.com/ElrondNetwork/elrond-go/process/block/bootstrapStorage" @@ -1136,8 +1137,9 @@ type CryptoComponentsHolder interface { BlockSignKeyGen() crypto.KeyGenerator TxSingleSigner() crypto.SingleSigner BlockSigner() crypto.SingleSigner - MultiSigner() crypto.MultiSigner - SetMultiSigner(ms crypto.MultiSigner) error + GetMultiSigner(epoch uint32) (crypto.MultiSigner, error) + MultiSignerContainer() cryptoCommon.MultiSignerContainer + SetMultiSignerContainer(ms cryptoCommon.MultiSignerContainer) error PeerSignatureHandler() crypto.PeerSignatureHandler PublicKey() crypto.PublicKey Clone() interface{} @@ -1242,3 +1244,10 @@ type PeerAuthenticationPayloadValidator interface { ValidateTimestamp(payloadTimestamp int64) error IsInterfaceNil() bool } + +// Debugger defines what a process debugger implementation should do +type Debugger interface { + SetLastCommittedBlockRound(round uint64) + Close() error + IsInterfaceNil() bool +} diff --git a/process/mock/accountWrapperMock.go b/process/mock/accountWrapperMock.go deleted file mode 100644 index e3f46dad3e5..00000000000 --- a/process/mock/accountWrapperMock.go +++ /dev/null @@ -1,178 +0,0 @@ -package mock - -import ( - "math/big" - - "github.com/ElrondNetwork/elrond-go/common" - "github.com/ElrondNetwork/elrond-go/state" - vmcommon "github.com/ElrondNetwork/elrond-vm-common" -) - -// AccountWrapMock - -type AccountWrapMock struct { - MockValue int - dataTrie common.Trie - nonce uint64 - code []byte - codeMetadata []byte - codeHash []byte - rootHash []byte - address []byte - trackableDataTrie state.DataTrieTracker - - SetNonceWithJournalCalled func(nonce uint64) error `json:"-"` - SetCodeHashWithJournalCalled func(codeHash []byte) error `json:"-"` - SetCodeWithJournalCalled func(codeHash []byte) error `json:"-"` - AccountDataHandlerCalled func() vmcommon.AccountDataHandler `json:"-"` -} - -// NewAccountWrapMock - -func NewAccountWrapMock(adr []byte) *AccountWrapMock { - return &AccountWrapMock{ - address: adr, - trackableDataTrie: state.NewTrackableDataTrie([]byte("identifier"), nil), - } -} - -// HasNewCode - -func (awm *AccountWrapMock) HasNewCode() bool { - return false -} - -// SetUserName - -func (awm *AccountWrapMock) SetUserName(_ []byte) { -} - -// GetUserName - -func (awm *AccountWrapMock) GetUserName() 
[]byte { - return nil -} - -// AddToBalance - -func (awm *AccountWrapMock) AddToBalance(_ *big.Int) error { - return nil -} - -// SubFromBalance - -func (awm *AccountWrapMock) SubFromBalance(_ *big.Int) error { - return nil -} - -// GetBalance - -func (awm *AccountWrapMock) GetBalance() *big.Int { - return nil -} - -// ClaimDeveloperRewards - -func (awm *AccountWrapMock) ClaimDeveloperRewards([]byte) (*big.Int, error) { - return nil, nil -} - -// AddToDeveloperReward - -func (awm *AccountWrapMock) AddToDeveloperReward(*big.Int) { - -} - -// GetDeveloperReward - -func (awm *AccountWrapMock) GetDeveloperReward() *big.Int { - return nil -} - -// ChangeOwnerAddress - -func (awm *AccountWrapMock) ChangeOwnerAddress([]byte, []byte) error { - return nil -} - -// SetOwnerAddress - -func (awm *AccountWrapMock) SetOwnerAddress([]byte) { - -} - -// GetOwnerAddress - -func (awm *AccountWrapMock) GetOwnerAddress() []byte { - return nil -} - -// GetCodeHash - -func (awm *AccountWrapMock) GetCodeHash() []byte { - return awm.codeHash -} - -// RetrieveValueFromDataTrieTracker - -func (awm *AccountWrapMock) RetrieveValueFromDataTrieTracker(key []byte) ([]byte, error) { - return awm.trackableDataTrie.RetrieveValue(key) -} - -// SetCodeHash - -func (awm *AccountWrapMock) SetCodeHash(codeHash []byte) { - awm.codeHash = codeHash -} - -// SetCode - -func (awm *AccountWrapMock) SetCode(code []byte) { - awm.code = code -} - -// SetCodeMetadata - -func (awm *AccountWrapMock) SetCodeMetadata(codeMetadata []byte) { - awm.codeMetadata = codeMetadata -} - -// GetCodeMetadata - -func (awm *AccountWrapMock) GetCodeMetadata() []byte { - return awm.codeMetadata -} - -// GetRootHash - -func (awm *AccountWrapMock) GetRootHash() []byte { - return awm.rootHash -} - -// SetRootHash - -func (awm *AccountWrapMock) SetRootHash(rootHash []byte) { - awm.rootHash = rootHash -} - -// AddressBytes - -func (awm *AccountWrapMock) AddressBytes() []byte { - return awm.address -} - -// DataTrie - -func (awm *AccountWrapMock) DataTrie() common.Trie { - return awm.dataTrie -} - -// SetDataTrie - -func (awm *AccountWrapMock) SetDataTrie(trie common.Trie) { - awm.dataTrie = trie - awm.trackableDataTrie.SetDataTrie(trie) -} - -// DataTrieTracker - -func (awm *AccountWrapMock) DataTrieTracker() state.DataTrieTracker { - return awm.trackableDataTrie -} - -// AccountDataHandler - -func (awm *AccountWrapMock) AccountDataHandler() vmcommon.AccountDataHandler { - if awm.AccountDataHandlerCalled != nil { - return awm.AccountDataHandlerCalled() - } - return awm.trackableDataTrie -} - -// IncreaseNonce - -func (awm *AccountWrapMock) IncreaseNonce(val uint64) { - awm.nonce = awm.nonce + val -} - -// GetNonce - -func (awm *AccountWrapMock) GetNonce() uint64 { - return awm.nonce -} - -// IsInterfaceNil - -func (awm *AccountWrapMock) IsInterfaceNil() bool { - return awm == nil -} diff --git a/process/mock/cryptoComponentsMock.go b/process/mock/cryptoComponentsMock.go index 7c74300b2e1..60bd0918dca 100644 --- a/process/mock/cryptoComponentsMock.go +++ b/process/mock/cryptoComponentsMock.go @@ -1,21 +1,23 @@ package mock import ( + "errors" "sync" "github.com/ElrondNetwork/elrond-go-crypto" + cryptoCommon "github.com/ElrondNetwork/elrond-go/common/crypto" ) // CryptoComponentsMock - type CryptoComponentsMock struct { - BlockSig crypto.SingleSigner - TxSig crypto.SingleSigner - MultiSig crypto.MultiSigner - PeerSignHandler crypto.PeerSignatureHandler - BlKeyGen crypto.KeyGenerator - TxKeyGen crypto.KeyGenerator - PubKey crypto.PublicKey - mutMultiSig 
sync.RWMutex + BlockSig crypto.SingleSigner + TxSig crypto.SingleSigner + MultiSigContainer cryptoCommon.MultiSignerContainer + PeerSignHandler crypto.PeerSignatureHandler + BlKeyGen crypto.KeyGenerator + TxKeyGen crypto.KeyGenerator + PubKey crypto.PublicKey + mutMultiSig sync.RWMutex } // BlockSigner - @@ -28,18 +30,32 @@ func (ccm *CryptoComponentsMock) TxSingleSigner() crypto.SingleSigner { return ccm.TxSig } -// MultiSigner - -func (ccm *CryptoComponentsMock) MultiSigner() crypto.MultiSigner { +// GetMultiSigner - +func (ccm *CryptoComponentsMock) GetMultiSigner(epoch uint32) (crypto.MultiSigner, error) { ccm.mutMultiSig.RLock() defer ccm.mutMultiSig.RUnlock() - return ccm.MultiSig + + if ccm.MultiSigContainer == nil { + return nil, errors.New("multisigner container is nil") + } + + return ccm.MultiSigContainer.GetMultiSigner(epoch) } -// SetMultiSigner - -func (ccm *CryptoComponentsMock) SetMultiSigner(multiSigner crypto.MultiSigner) error { +// MultiSignerContainer - +func (ccm *CryptoComponentsMock) MultiSignerContainer() cryptoCommon.MultiSignerContainer { + ccm.mutMultiSig.RLock() + defer ccm.mutMultiSig.RUnlock() + + return ccm.MultiSigContainer +} + +// SetMultiSignerContainer - +func (ccm *CryptoComponentsMock) SetMultiSignerContainer(msc cryptoCommon.MultiSignerContainer) error { ccm.mutMultiSig.Lock() - ccm.MultiSig = multiSigner - ccm.mutMultiSig.Unlock() + defer ccm.mutMultiSig.Unlock() + + ccm.MultiSigContainer = msc return nil } @@ -69,14 +85,14 @@ func (ccm *CryptoComponentsMock) PublicKey() crypto.PublicKey { // Clone - func (ccm *CryptoComponentsMock) Clone() interface{} { return &CryptoComponentsMock{ - BlockSig: ccm.BlockSig, - TxSig: ccm.TxSig, - MultiSig: ccm.MultiSig, - PeerSignHandler: ccm.PeerSignHandler, - BlKeyGen: ccm.BlKeyGen, - TxKeyGen: ccm.TxKeyGen, - PubKey: ccm.PubKey, - mutMultiSig: sync.RWMutex{}, + BlockSig: ccm.BlockSig, + TxSig: ccm.TxSig, + MultiSigContainer: ccm.MultiSigContainer, + PeerSignHandler: ccm.PeerSignHandler, + BlKeyGen: ccm.BlKeyGen, + TxKeyGen: ccm.TxKeyGen, + PubKey: ccm.PubKey, + mutMultiSig: sync.RWMutex{}, } } diff --git a/process/mock/dataTrieTrackerStub.go b/process/mock/dataTrieTrackerStub.go deleted file mode 100644 index 36afb9f4a65..00000000000 --- a/process/mock/dataTrieTrackerStub.go +++ /dev/null @@ -1,50 +0,0 @@ -package mock - -import ( - "github.com/ElrondNetwork/elrond-go/common" -) - -// DataTrieTrackerStub - -type DataTrieTrackerStub struct { - ClearDataCachesCalled func() - DirtyDataCalled func() map[string][]byte - RetrieveValueCalled func(key []byte) ([]byte, error) - SaveKeyValueCalled func(key []byte, value []byte) error - SetDataTrieCalled func(tr common.Trie) - DataTrieCalled func() common.Trie -} - -// ClearDataCaches - -func (dtts *DataTrieTrackerStub) ClearDataCaches() { - dtts.ClearDataCachesCalled() -} - -// DirtyData - -func (dtts *DataTrieTrackerStub) DirtyData() map[string][]byte { - return dtts.DirtyDataCalled() -} - -// RetrieveValue - -func (dtts *DataTrieTrackerStub) RetrieveValue(key []byte) ([]byte, error) { - return dtts.RetrieveValueCalled(key) -} - -// SaveKeyValue - -func (dtts *DataTrieTrackerStub) SaveKeyValue(key []byte, value []byte) error { - return dtts.SaveKeyValueCalled(key, value) -} - -// SetDataTrie - -func (dtts *DataTrieTrackerStub) SetDataTrie(tr common.Trie) { - dtts.SetDataTrieCalled(tr) -} - -// DataTrie - -func (dtts *DataTrieTrackerStub) DataTrie() common.Trie { - return dtts.DataTrieCalled() -} - -// IsInterfaceNil returns true if there is no value under the interface 
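// ---------------------------------------------------------------------------
// Editor's note (not part of the diff): the refactor above replaces the single
// crypto.MultiSigner held by the crypto components with an epoch-scoped
// container queried via GetMultiSigner(epoch uint32) (crypto.MultiSigner, error).
// The sketch below shows one plausible container shape; only that method
// signature comes from the diff — the local MultiSigner stand-in and the
// sorted activation-epoch lookup are assumptions for illustration.
package sketch // hypothetical package, illustration only

import (
	"errors"
	"sort"
	"sync"
)

// MultiSigner is a local stand-in for crypto.MultiSigner; VerifyAggregatedSig
// matches the mock method used in the tests above.
type MultiSigner interface {
	VerifyAggregatedSig(pubKeysSigners [][]byte, message []byte, aggSig []byte) error
}

// epochMultiSigner pairs a signer with the epoch it becomes active in.
type epochMultiSigner struct {
	activationEpoch uint32
	signer          MultiSigner
}

type multiSignerContainer struct {
	mut     sync.RWMutex
	signers []epochMultiSigner // kept sorted ascending by activationEpoch
}

// GetMultiSigner returns the signer with the highest activation epoch that
// does not exceed the requested epoch.
func (c *multiSignerContainer) GetMultiSigner(epoch uint32) (MultiSigner, error) {
	c.mut.RLock()
	defer c.mut.RUnlock()

	// find the first signer activated after the requested epoch, then step back
	idx := sort.Search(len(c.signers), func(i int) bool {
		return c.signers[i].activationEpoch > epoch
	})
	if idx == 0 {
		return nil, errors.New("no multi signer configured for the given epoch")
	}
	return c.signers[idx-1].signer, nil
}
// ---------------------------------------------------------------------------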
-func (dtts *DataTrieTrackerStub) IsInterfaceNil() bool { - return dtts == nil -} diff --git a/process/mock/peerAccountHandlerMock.go b/process/mock/peerAccountHandlerMock.go index bc1b1fcdd66..a862fa4153e 100644 --- a/process/mock/peerAccountHandlerMock.go +++ b/process/mock/peerAccountHandlerMock.go @@ -274,11 +274,6 @@ func (p *PeerAccountHandlerMock) DataTrie() common.Trie { return nil } -// DataTrieTracker - -func (p *PeerAccountHandlerMock) DataTrieTracker() state.DataTrieTracker { - return nil -} - // GetConsecutiveProposerMisses - func (p *PeerAccountHandlerMock) GetConsecutiveProposerMisses() uint32 { if p.GetConsecutiveProposerMissesCalled != nil { diff --git a/process/p2p/interceptedDirectConnectionInfo.go b/process/p2p/interceptedDirectConnectionInfo.go index 1b5ec693565..bca83e7cc9f 100644 --- a/process/p2p/interceptedDirectConnectionInfo.go +++ b/process/p2p/interceptedDirectConnectionInfo.go @@ -7,7 +7,7 @@ import ( "github.com/ElrondNetwork/elrond-go-core/core/check" "github.com/ElrondNetwork/elrond-go-core/marshal" "github.com/ElrondNetwork/elrond-go/common" - "github.com/ElrondNetwork/elrond-go/p2p/message" + "github.com/ElrondNetwork/elrond-go/p2p" "github.com/ElrondNetwork/elrond-go/process" ) @@ -22,7 +22,7 @@ type ArgInterceptedDirectConnectionInfo struct { // interceptedDirectConnectionInfo is a wrapper over DirectConnectionInfo type interceptedDirectConnectionInfo struct { - directConnectionInfo message.DirectConnectionInfo + directConnectionInfo p2p.DirectConnectionInfo numOfShards uint32 } @@ -58,8 +58,8 @@ func checkArgs(args ArgInterceptedDirectConnectionInfo) error { return nil } -func createDirectConnectionInfo(marshaller marshal.Marshalizer, buff []byte) (*message.DirectConnectionInfo, error) { - directConnectionInfo := &message.DirectConnectionInfo{} +func createDirectConnectionInfo(marshaller marshal.Marshalizer, buff []byte) (*p2p.DirectConnectionInfo, error) { + directConnectionInfo := &p2p.DirectConnectionInfo{} err := marshaller.Unmarshal(directConnectionInfo, buff) if err != nil { return nil, err diff --git a/process/p2p/interceptedDirectConnectionInfo_test.go b/process/p2p/interceptedDirectConnectionInfo_test.go index ce3338df3da..c5a065ed08a 100644 --- a/process/p2p/interceptedDirectConnectionInfo_test.go +++ b/process/p2p/interceptedDirectConnectionInfo_test.go @@ -8,7 +8,7 @@ import ( "github.com/ElrondNetwork/elrond-go-core/core/check" "github.com/ElrondNetwork/elrond-go-core/marshal" - "github.com/ElrondNetwork/elrond-go/p2p/message" + "github.com/ElrondNetwork/elrond-go/p2p" "github.com/ElrondNetwork/elrond-go/process" "github.com/stretchr/testify/assert" ) @@ -17,7 +17,7 @@ const providedShard = "5" func createMockArgInterceptedDirectConnectionInfo() ArgInterceptedDirectConnectionInfo { marshaller := &marshal.GogoProtoMarshalizer{} - msg := &message.DirectConnectionInfo{ + msg := &p2p.DirectConnectionInfo{ ShardId: providedShard, } msgBuff, _ := marshaller.Marshal(msg) @@ -87,7 +87,7 @@ func Test_interceptedDirectConnectionInfo_CheckValidity(t *testing.T) { t.Parallel() args := createMockArgInterceptedDirectConnectionInfo() - msg := &message.DirectConnectionInfo{ + msg := &p2p.DirectConnectionInfo{ ShardId: "invalid shard", } msgBuff, _ := args.Marshaller.Marshal(msg) diff --git a/process/peer/process_test.go b/process/peer/process_test.go index 67a9aa11460..6d2ff85f6e9 100644 --- a/process/peer/process_test.go +++ b/process/peer/process_test.go @@ -304,7 +304,7 @@ func TestValidatorStatisticsProcessor_SaveInitialStateGetAccountReturnsInvalid(t 
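// ---------------------------------------------------------------------------
// Editor's note (not part of the diff): the intercepted message above now uses
// p2p.DirectConnectionInfo directly. Its CheckValidity body is not shown in
// this diff, but the tests (ShardId "5" accepted with NumOfShards 10,
// "invalid shard" rejected) imply a parse-and-bounds check roughly like the
// hedged sketch below; treat the exact rule (including how the metachain
// shard id is handled) as an assumption.
package sketch // hypothetical package, illustration only

import (
	"errors"
	"strconv"
)

// checkShardID validates a textual shard id against the known shard count.
func checkShardID(shardID string, numOfShards uint32) error {
	val, err := strconv.ParseUint(shardID, 10, 32) // "invalid shard" fails to parse
	if err != nil {
		return err
	}
	if uint32(val) >= numOfShards { // reject out-of-range shard ids
		return errors.New("invalid shard id")
	}
	return nil
}
// ---------------------------------------------------------------------------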
peerAdapter := &stateMock.AccountsStub{ LoadAccountCalled: func(address []byte) (handler vmcommon.AccountHandler, e error) { - return &mock.AccountWrapMock{}, nil + return &stateMock.AccountWrapMock{}, nil }, } @@ -487,7 +487,7 @@ func TestValidatorStatisticsProcessor_UpdatePeerStateGetExistingAccountInvalidTy adapter := getAccountsMock() adapter.LoadAccountCalled = func(address []byte) (handler vmcommon.AccountHandler, e error) { - return &mock.AccountWrapMock{}, nil + return &stateMock.AccountWrapMock{}, nil } arguments := createMockArguments() diff --git a/process/rewardTransaction/process.go b/process/rewardTransaction/process.go index 099124ec611..801b18b2e13 100644 --- a/process/rewardTransaction/process.go +++ b/process/rewardTransaction/process.go @@ -111,13 +111,13 @@ func (rtp *rewardTxProcessor) saveAccumulatedRewards( existingReward := big.NewInt(0) fullRewardKey := core.ElrondProtectedKeyPrefix + rewardKey - val, err := userAccount.DataTrieTracker().RetrieveValue([]byte(fullRewardKey)) + val, err := userAccount.RetrieveValue([]byte(fullRewardKey)) if err == nil { existingReward.SetBytes(val) } existingReward.Add(existingReward, rtx.Value) - _ = userAccount.DataTrieTracker().SaveKeyValue([]byte(fullRewardKey), existingReward.Bytes()) + _ = userAccount.SaveKeyValue([]byte(fullRewardKey), existingReward.Bytes()) } // IsInterfaceNil returns true if there is no value under the interface diff --git a/process/rewardTransaction/process_test.go b/process/rewardTransaction/process_test.go index 1a7d06b3437..feebdf4efca 100644 --- a/process/rewardTransaction/process_test.go +++ b/process/rewardTransaction/process_test.go @@ -245,14 +245,14 @@ func TestRewardTxProcessor_ProcessRewardTransactionToASmartContractShouldWork(t err := rtp.ProcessRewardTransaction(&rwdTx) assert.Nil(t, err) assert.True(t, saveAccountWasCalled) - val, err := userAccount.DataTrieTracker().RetrieveValue([]byte(core.ElrondProtectedKeyPrefix + rewardTransaction.RewardKey)) + val, err := userAccount.RetrieveValue([]byte(core.ElrondProtectedKeyPrefix + rewardTransaction.RewardKey)) assert.Nil(t, err) assert.True(t, rwdTx.Value.Cmp(big.NewInt(0).SetBytes(val)) == 0) err = rtp.ProcessRewardTransaction(&rwdTx) assert.Nil(t, err) assert.True(t, saveAccountWasCalled) - val, err = userAccount.DataTrieTracker().RetrieveValue([]byte(core.ElrondProtectedKeyPrefix + rewardTransaction.RewardKey)) + val, err = userAccount.RetrieveValue([]byte(core.ElrondProtectedKeyPrefix + rewardTransaction.RewardKey)) assert.Nil(t, err) rwdTx.Value.Add(rwdTx.Value, rwdTx.Value) assert.True(t, rwdTx.Value.Cmp(big.NewInt(0).SetBytes(val)) == 0) diff --git a/process/scToProtocol/stakingToPeer.go b/process/scToProtocol/stakingToPeer.go index cc2acc0f308..8f10bd1aad3 100644 --- a/process/scToProtocol/stakingToPeer.go +++ b/process/scToProtocol/stakingToPeer.go @@ -139,7 +139,7 @@ func (stp *stakingToPeer) getUserAccount(key []byte) (state.UserAccountHandler, } func (stp *stakingToPeer) getStorageFromAccount(userAcc state.UserAccountHandler, key []byte) []byte { - value, err := userAcc.DataTrieTracker().RetrieveValue(key) + value, err := userAcc.RetrieveValue(key) if err != nil { return nil } diff --git a/process/scToProtocol/stakingToPeer_test.go b/process/scToProtocol/stakingToPeer_test.go index f1274569222..7bb320ededb 100644 --- a/process/scToProtocol/stakingToPeer_test.go +++ b/process/scToProtocol/stakingToPeer_test.go @@ -316,7 +316,7 @@ func TestStakingToPeer_UpdateProtocolCannotSetRewardAddressShouldErr(t *testing. 
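// ---------------------------------------------------------------------------
// Editor's note (not part of the diff): with DataTrieTracker() removed from
// the account interfaces, RetrieveValue and SaveKeyValue are now called on the
// account handler itself. The reward accumulation above then follows this
// read-add-write pattern; keyValueAccount is a local stand-in for
// state.UserAccountHandler.
package sketch // hypothetical package, illustration only

import "math/big"

type keyValueAccount interface {
	RetrieveValue(key []byte) ([]byte, error)
	SaveKeyValue(key []byte, value []byte) error
}

// accumulate adds delta to the big.Int stored under key, treating a missing
// or unreadable key as zero, mirroring saveAccumulatedRewards.
func accumulate(acc keyValueAccount, key []byte, delta *big.Int) error {
	existing := big.NewInt(0)
	if val, err := acc.RetrieveValue(key); err == nil {
		existing.SetBytes(val)
	}
	existing.Add(existing, delta)
	return acc.SaveKeyValue(key, existing.Bytes())
}
// ---------------------------------------------------------------------------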
return userAcc, nil } retData, _ := json.Marshal(&stakingData) - _ = userAcc.DataTrieTracker().SaveKeyValue(offset, retData) + _ = userAcc.SaveKeyValue(offset, retData) arguments.BaseState = baseState arguments.ArgParser = argParser @@ -371,7 +371,7 @@ func TestStakingToPeer_UpdateProtocolEmptyDataShouldNotAddToTrie(t *testing.T) { baseState.LoadAccountCalled = func(address []byte) (vmcommon.AccountHandler, error) { return userAcc, nil } - _ = userAcc.DataTrieTracker().SaveKeyValue(offset, nil) + _ = userAcc.SaveKeyValue(offset, nil) arguments.BaseState = baseState arguments.ArgParser = argParser @@ -439,7 +439,7 @@ func TestStakingToPeer_UpdateProtocolCannotSaveAccountShouldErr(t *testing.T) { return userAcc, nil } retData, _ := json.Marshal(&stakingData) - _ = userAcc.DataTrieTracker().SaveKeyValue(offset, retData) + _ = userAcc.SaveKeyValue(offset, retData) arguments.BaseState = baseState arguments.ArgParser = argParser @@ -502,7 +502,7 @@ func TestStakingToPeer_UpdateProtocolCannotSaveAccountNonceShouldErr(t *testing. return userAcc, nil } retData, _ := json.Marshal(&stakingData) - _ = userAcc.DataTrieTracker().SaveKeyValue(offset, retData) + _ = userAcc.SaveKeyValue(offset, retData) arguments.BaseState = baseState arguments.ArgParser = argParser @@ -568,7 +568,7 @@ func TestStakingToPeer_UpdateProtocol(t *testing.T) { return userAcc, nil } retData, _ := json.Marshal(&stakingData) - _ = userAcc.DataTrieTracker().SaveKeyValue(offset, retData) + _ = userAcc.SaveKeyValue(offset, retData) arguments.BaseState = baseState stp, _ := NewStakingToPeer(arguments) @@ -627,7 +627,7 @@ func TestStakingToPeer_UpdateProtocolCannotSaveUnStakedNonceShouldErr(t *testing return userAcc, nil } retData, _ := json.Marshal(&stakingData) - _ = userAcc.DataTrieTracker().SaveKeyValue(offset, retData) + _ = userAcc.SaveKeyValue(offset, retData) arguments.BaseState = baseState arguments.ArgParser = argParser diff --git a/process/smartContract/hooks/blockChainHook.go b/process/smartContract/hooks/blockChainHook.go index 39ce677cd3f..ff3e6a1917a 100644 --- a/process/smartContract/hooks/blockChainHook.go +++ b/process/smartContract/hooks/blockChainHook.go @@ -27,7 +27,7 @@ import ( "github.com/ElrondNetwork/elrond-go/state" "github.com/ElrondNetwork/elrond-go/storage" "github.com/ElrondNetwork/elrond-go/storage/factory" - "github.com/ElrondNetwork/elrond-go/storage/storageUnit" + "github.com/ElrondNetwork/elrond-go/storage/storageunit" vmcommon "github.com/ElrondNetwork/elrond-vm-common" "github.com/ElrondNetwork/elrond-vm-common/parsers" ) @@ -726,13 +726,13 @@ func (bh *BlockChainHookImpl) ClearCompiledCodes() { func (bh *BlockChainHookImpl) makeCompiledSCStorage() error { if bh.nilCompiledSCStore { - bh.compiledScStorage = storageUnit.NewNilStorer() + bh.compiledScStorage = storageunit.NewNilStorer() return nil } dbConfig := factory.GetDBFromConfig(bh.configSCStorage.DB) dbConfig.FilePath = path.Join(bh.workingDir, defaultCompiledSCPath, bh.configSCStorage.DB.FilePath) - store, err := storageUnit.NewStorageUnitFromConf( + store, err := storageunit.NewStorageUnitFromConf( factory.GetCacherFromConfig(bh.configSCStorage.Cache), dbConfig, ) diff --git a/process/smartContract/hooks/blockChainHook_test.go b/process/smartContract/hooks/blockChainHook_test.go index a571559b64b..1d3d90d3652 100644 --- a/process/smartContract/hooks/blockChainHook_test.go +++ b/process/smartContract/hooks/blockChainHook_test.go @@ -22,7 +22,7 @@ import ( "github.com/ElrondNetwork/elrond-go/process/smartContract/hooks" 
"github.com/ElrondNetwork/elrond-go/state" "github.com/ElrondNetwork/elrond-go/storage" - "github.com/ElrondNetwork/elrond-go/storage/storageUnit" + "github.com/ElrondNetwork/elrond-go/storage/storageunit" "github.com/ElrondNetwork/elrond-go/testscommon" dataRetrieverMock "github.com/ElrondNetwork/elrond-go/testscommon/dataRetriever" "github.com/ElrondNetwork/elrond-go/testscommon/epochNotifier" @@ -42,7 +42,7 @@ func createMockBlockChainHookArgs() hooks.ArgBlockChainHook { arguments := hooks.ArgBlockChainHook{ Accounts: &stateMock.AccountsStub{ GetExistingAccountCalled: func(address []byte) (handler vmcommon.AccountHandler, e error) { - return &mock.AccountWrapMock{}, nil + return &stateMock.AccountWrapMock{}, nil }, }, PubkeyConv: mock.NewPubkeyConverterMock(32), @@ -356,7 +356,7 @@ func TestBlockChainHookImpl_GetStorageDataCannotRetrieveAccountValueExpectError( return nil, expectedErr }, } - account := &mock.AccountWrapMock{ + account := &stateMock.AccountWrapMock{ AccountDataHandlerCalled: func() vmcommon.AccountDataHandler { return dataTrieStub }, @@ -398,8 +398,8 @@ func TestBlockChainHookImpl_GetStorageDataShouldWork(t *testing.T) { variableIdentifier := []byte("variable") variableValue := []byte("value") - accnt := mock.NewAccountWrapMock(nil) - _ = accnt.DataTrieTracker().SaveKeyValue(variableIdentifier, variableValue) + accnt := stateMock.NewAccountWrapMock(nil) + _ = accnt.SaveKeyValue(variableIdentifier, variableValue) args := createMockBlockChainHookArgs() args.Accounts = &stateMock.AccountsStub{ @@ -889,7 +889,7 @@ func TestBlockChainHookImpl_IsPayableSCNonPayable(t *testing.T) { args := createMockBlockChainHookArgs() args.Accounts = &stateMock.AccountsStub{ GetExistingAccountCalled: func(address []byte) (vmcommon.AccountHandler, error) { - acc := &mock.AccountWrapMock{} + acc := &stateMock.AccountWrapMock{} acc.SetCodeMetadata([]byte{0, 0}) return acc, nil }, @@ -906,7 +906,7 @@ func TestBlockChainHookImpl_IsPayablePayable(t *testing.T) { args := createMockBlockChainHookArgs() args.Accounts = &stateMock.AccountsStub{ GetExistingAccountCalled: func(address []byte) (handler vmcommon.AccountHandler, e error) { - acc := &mock.AccountWrapMock{} + acc := &stateMock.AccountWrapMock{} acc.SetCodeMetadata([]byte{0, vmcommon.MetadataPayable}) return acc, nil }, @@ -928,7 +928,7 @@ func TestBlockChainHookImpl_IsPayablePayableBySC(t *testing.T) { args := createMockBlockChainHookArgs() args.Accounts = &stateMock.AccountsStub{ GetExistingAccountCalled: func(address []byte) (handler vmcommon.AccountHandler, e error) { - acc := &mock.AccountWrapMock{} + acc := &stateMock.AccountWrapMock{} acc.SetCodeMetadata([]byte{0, vmcommon.MetadataPayableBySC}) return acc, nil }, @@ -1100,11 +1100,11 @@ func TestBlockChainHookImpl_SaveCompiledCode(t *testing.T) { args.ConfigSCStorage = config.StorageConfig{ Cache: config.CacheConfig{ Capacity: 10, - Type: string(storageUnit.LRUCache), + Type: string(storageunit.LRUCache), }, DB: config.DBConfig{ FilePath: "test1", - Type: string(storageUnit.MemoryDB), + Type: string(storageunit.MemoryDB), MaxBatchSize: 1, MaxOpenFiles: 10, }, @@ -1137,11 +1137,11 @@ func TestBlockChainHookImpl_SaveCompiledCode(t *testing.T) { args.ConfigSCStorage = config.StorageConfig{ Cache: config.CacheConfig{ Capacity: 10, - Type: string(storageUnit.LRUCache), + Type: string(storageunit.LRUCache), }, DB: config.DBConfig{ FilePath: "test2", - Type: string(storageUnit.MemoryDB), + Type: string(storageunit.MemoryDB), MaxBatchSize: 1, MaxOpenFiles: 10, }, @@ -1284,7 +1284,7 @@ func 
TestBlockChainHookImpl_ProcessBuiltInFunction(t *testing.T) { args.Accounts = &stateMock.AccountsStub{ GetExistingAccountCalled: func(addressContainer []byte) (vmcommon.AccountHandler, error) { require.Equal(t, addrSender, addressContainer) - return mock.NewAccountWrapMock(addrSender), nil + return stateMock.NewAccountWrapMock(addrSender), nil }, LoadAccountCalled: func(addressContainer []byte) (vmcommon.AccountHandler, error) { @@ -1308,7 +1308,7 @@ func TestBlockChainHookImpl_ProcessBuiltInFunction(t *testing.T) { args.Accounts = &stateMock.AccountsStub{ GetExistingAccountCalled: func(addressContainer []byte) (vmcommon.AccountHandler, error) { require.Equal(t, addrSender, addressContainer) - return mock.NewAccountWrapMock(addrSender), nil + return stateMock.NewAccountWrapMock(addrSender), nil }, LoadAccountCalled: func(addressContainer []byte) (vmcommon.AccountHandler, error) { @@ -1348,12 +1348,12 @@ func TestBlockChainHookImpl_ProcessBuiltInFunction(t *testing.T) { args.Accounts = &stateMock.AccountsStub{ GetExistingAccountCalled: func(addressContainer []byte) (vmcommon.AccountHandler, error) { require.Equal(t, addrSender, addressContainer) - return mock.NewAccountWrapMock(addrSender), nil + return stateMock.NewAccountWrapMock(addrSender), nil }, LoadAccountCalled: func(addressContainer []byte) (vmcommon.AccountHandler, error) { require.Equal(t, addrReceiver, addressContainer) - return mock.NewAccountWrapMock(addrReceiver), nil + return stateMock.NewAccountWrapMock(addrReceiver), nil }, } @@ -1419,12 +1419,12 @@ func TestBlockChainHookImpl_ProcessBuiltInFunction(t *testing.T) { args.Accounts = &stateMock.AccountsStub{ GetExistingAccountCalled: func(addressContainer []byte) (vmcommon.AccountHandler, error) { require.Equal(t, addrSender, addressContainer) - return mock.NewAccountWrapMock(addrSender), nil + return stateMock.NewAccountWrapMock(addrSender), nil }, LoadAccountCalled: func(addressContainer []byte) (vmcommon.AccountHandler, error) { require.Equal(t, addrReceiver, addressContainer) - return mock.NewAccountWrapMock(addrReceiver), nil + return stateMock.NewAccountWrapMock(addrReceiver), nil }, SaveAccountCalled: func(account vmcommon.AccountHandler) error { isSender := bytes.Equal(addrSender, account.AddressBytes()) @@ -1453,7 +1453,7 @@ func TestBlockChainHookImpl_ProcessBuiltInFunction(t *testing.T) { args.Accounts = &stateMock.AccountsStub{ GetExistingAccountCalled: func(addressContainer []byte) (vmcommon.AccountHandler, error) { require.Equal(t, addrSender, addressContainer) - return mock.NewAccountWrapMock(addrSender), nil + return stateMock.NewAccountWrapMock(addrSender), nil }, LoadAccountCalled: func(addressContainer []byte) (vmcommon.AccountHandler, error) { @@ -1485,7 +1485,7 @@ func TestBlockChainHookImpl_ProcessBuiltInFunction(t *testing.T) { args.Accounts = &stateMock.AccountsStub{ GetExistingAccountCalled: func(addressContainer []byte) (vmcommon.AccountHandler, error) { require.Equal(t, addrSender, addressContainer) - return mock.NewAccountWrapMock(addrSender), nil + return stateMock.NewAccountWrapMock(addrSender), nil }, SaveAccountCalled: func(account vmcommon.AccountHandler) error { require.Equal(t, addrSender, account.AddressBytes()) @@ -1509,12 +1509,12 @@ func TestBlockChainHookImpl_ProcessBuiltInFunction(t *testing.T) { args.Accounts = &stateMock.AccountsStub{ GetExistingAccountCalled: func(addressContainer []byte) (vmcommon.AccountHandler, error) { require.Equal(t, addrSender, addressContainer) - return mock.NewAccountWrapMock(addrSender), nil + return 
stateMock.NewAccountWrapMock(addrSender), nil }, LoadAccountCalled: func(addressContainer []byte) (vmcommon.AccountHandler, error) { require.Equal(t, addrReceiver, addressContainer) - return mock.NewAccountWrapMock(addrReceiver), nil + return stateMock.NewAccountWrapMock(addrReceiver), nil }, SaveAccountCalled: func(account vmcommon.AccountHandler) error { isSender := bytes.Equal(addrSender, account.AddressBytes()) @@ -1587,8 +1587,8 @@ func TestBlockChainHookImpl_GetESDTToken(t *testing.T) { args.Accounts = &stateMock.AccountsStub{ GetExistingAccountCalled: func(addressContainer []byte) (vmcommon.AccountHandler, error) { require.Equal(t, address, addressContainer) - account := mock.NewAccountWrapMock(address) - _ = account.DataTrieTracker().SaveKeyValue(completeEsdtTokenKey, invalidUnmarshalledData) + account := stateMock.NewAccountWrapMock(address) + _ = account.SaveKeyValue(completeEsdtTokenKey, invalidUnmarshalledData) return account, nil }, @@ -1634,7 +1634,7 @@ func TestBlockChainHookImpl_GetESDTToken(t *testing.T) { args := createMockBlockChainHookArgs() args.Accounts = &stateMock.AccountsStub{ GetExistingAccountCalled: func(addressContainer []byte) (vmcommon.AccountHandler, error) { - addressHandler := mock.NewAccountWrapMock(address) + addressHandler := stateMock.NewAccountWrapMock(address) addressHandler.SetDataTrie(nil) return addressHandler, nil @@ -1658,7 +1658,7 @@ func TestBlockChainHookImpl_GetESDTToken(t *testing.T) { args := createMockBlockChainHookArgs() args.Accounts = &stateMock.AccountsStub{ GetExistingAccountCalled: func(addressContainer []byte) (vmcommon.AccountHandler, error) { - addressHandler := mock.NewAccountWrapMock(address) + addressHandler := stateMock.NewAccountWrapMock(address) addressHandler.SetDataTrie(&trie.TrieStub{ GetCalled: func(key []byte) ([]byte, error) { return make([]byte, 0), nil @@ -1687,10 +1687,10 @@ func TestBlockChainHookImpl_GetESDTToken(t *testing.T) { args := createMockBlockChainHookArgs() args.Accounts = &stateMock.AccountsStub{ GetExistingAccountCalled: func(addressContainer []byte) (vmcommon.AccountHandler, error) { - addressHandler := mock.NewAccountWrapMock(address) + addressHandler := stateMock.NewAccountWrapMock(address) buffToken, _ := args.Marshalizer.Marshal(testESDTData) key := append(completeEsdtTokenKey, big.NewInt(0).SetUint64(nftNonce).Bytes()...) 
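// ---------------------------------------------------------------------------
// Editor's note (not part of the diff): the key built just above stores an
// NFT entry under the fungible token key extended with the nonce's big-endian
// bytes. A hedged sketch of that derivation follows; note that nonce 0
// appends nothing, since big.NewInt(0).SetUint64(0).Bytes() is empty, which
// is how fungible tokens keep their unsuffixed key.
package sketch // hypothetical package, illustration only

import "math/big"

// esdtNFTStorageKey derives the per-nonce storage key from the complete
// token key, copying first to avoid mutating the caller's slice.
func esdtNFTStorageKey(completeTokenKey []byte, nonce uint64) []byte {
	nonceBytes := big.NewInt(0).SetUint64(nonce).Bytes()
	key := make([]byte, 0, len(completeTokenKey)+len(nonceBytes))
	key = append(key, completeTokenKey...)
	return append(key, nonceBytes...)
}
// ---------------------------------------------------------------------------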
- _ = addressHandler.DataTrieTracker().SaveKeyValue(key, buffToken) + _ = addressHandler.SaveKeyValue(key, buffToken) return addressHandler, nil }, @@ -1713,9 +1713,9 @@ func TestBlockChainHookImpl_GetESDTToken(t *testing.T) { args := createMockBlockChainHookArgs() args.Accounts = &stateMock.AccountsStub{ GetExistingAccountCalled: func(addressContainer []byte) (vmcommon.AccountHandler, error) { - addressHandler := mock.NewAccountWrapMock(address) + addressHandler := stateMock.NewAccountWrapMock(address) buffToken, _ := args.Marshalizer.Marshal(testESDTData) - _ = addressHandler.DataTrieTracker().SaveKeyValue(completeEsdtTokenKey, buffToken) + _ = addressHandler.SaveKeyValue(completeEsdtTokenKey, buffToken) return addressHandler, nil }, @@ -1739,7 +1739,7 @@ func TestBlockChainHookImpl_GetESDTToken(t *testing.T) { args := createMockBlockChainHookArgs() args.Accounts = &stateMock.AccountsStub{ GetExistingAccountCalled: func(addressContainer []byte) (vmcommon.AccountHandler, error) { - return mock.NewAccountWrapMock(address), nil + return stateMock.NewAccountWrapMock(address), nil }, } args.NFTStorageHandler = &testscommon.SimpleNFTStorageHandlerStub{ @@ -1767,7 +1767,7 @@ func TestBlockChainHookImpl_GetESDTToken(t *testing.T) { args := createMockBlockChainHookArgs() args.Accounts = &stateMock.AccountsStub{ GetExistingAccountCalled: func(addressContainer []byte) (vmcommon.AccountHandler, error) { - return mock.NewAccountWrapMock(address), nil + return stateMock.NewAccountWrapMock(address), nil }, } args.NFTStorageHandler = &testscommon.SimpleNFTStorageHandlerStub{ diff --git a/process/smartContract/process.go b/process/smartContract/process.go index 0d5f75001a7..7599493bbcf 100644 --- a/process/smartContract/process.go +++ b/process/smartContract/process.go @@ -262,7 +262,8 @@ func (sc *scProcessor) ExecuteSmartContractTransaction( duration := sw.GetMeasurement("execute") if duration > executeDurationAlarmThreshold { - log.Debug(fmt.Sprintf("scProcessor.ExecuteSmartContractTransaction(): execution took > %s", executeDurationAlarmThreshold), "sc", tx.GetRcvAddr(), "duration", duration, "returnCode", returnCode, "err", err, "data", string(tx.GetData())) + txHash := sc.computeTxHashUnsafe(tx) + log.Debug(fmt.Sprintf("scProcessor.ExecuteSmartContractTransaction(): execution took > %s", executeDurationAlarmThreshold), "tx hash", txHash, "sc", tx.GetRcvAddr(), "duration", duration, "returnCode", returnCode, "err", err, "data", string(tx.GetData())) } else { log.Trace("scProcessor.ExecuteSmartContractTransaction()", "sc", tx.GetRcvAddr(), "duration", duration, "returnCode", returnCode, "err", err, "data", string(tx.GetData())) } @@ -835,7 +836,8 @@ func (sc *scProcessor) ExecuteBuiltInFunction( duration := sw.GetMeasurement("executeBuiltIn") if duration > executeDurationAlarmThreshold { - log.Debug(fmt.Sprintf("scProcessor.ExecuteBuiltInFunction(): execution took > %s", executeDurationAlarmThreshold), "sc", tx.GetRcvAddr(), "duration", duration, "returnCode", returnCode, "err", err, "data", string(tx.GetData())) + txHash := sc.computeTxHashUnsafe(tx) + log.Debug(fmt.Sprintf("scProcessor.ExecuteBuiltInFunction(): execution took > %s", executeDurationAlarmThreshold), "tx hash", txHash, "sc", tx.GetRcvAddr(), "duration", duration, "returnCode", returnCode, "err", err, "data", string(tx.GetData())) } else { log.Trace("scProcessor.ExecuteBuiltInFunction()", "sc", tx.GetRcvAddr(), "duration", duration, "returnCode", returnCode, "err", err, "data", string(tx.GetData())) } @@ -1600,7 +1602,8 @@ func (sc 
*scProcessor) DeploySmartContract(tx data.TransactionHandler, acntSnd s duration := sw.GetMeasurement("deploy") if duration > executeDurationAlarmThreshold { - log.Debug(fmt.Sprintf("scProcessor.DeploySmartContract(): execution took > %s", executeDurationAlarmThreshold), "sc", tx.GetRcvAddr(), "duration", duration, "returnCode", returnCode, "err", err, "data", string(tx.GetData())) + txHash := sc.computeTxHashUnsafe(tx) + log.Debug(fmt.Sprintf("scProcessor.DeploySmartContract(): execution took > %s", executeDurationAlarmThreshold), "tx hash", txHash, "sc", tx.GetRcvAddr(), "duration", duration, "returnCode", returnCode, "err", err, "data", string(tx.GetData())) } else { log.Trace("scProcessor.DeploySmartContract()", "sc", tx.GetRcvAddr(), "duration", duration, "returnCode", returnCode, "err", err, "data", string(tx.GetData())) } @@ -2276,6 +2279,10 @@ func (sc *scProcessor) useLastTransferAsAsyncCallBackWhenNeeded( return false } + if sc.enableEpochsHandler.IsFixAsyncCallBackArgsListFlagEnabled() { + result.Data = append(result.Data, []byte("@"+core.ConvertToEvenHex(int(vmOutput.ReturnCode)))...) + } + addReturnDataToSCR(vmOutput, result) result.CallType = vmData.AsynchronousCallBack result.GasLimit, _ = core.SafeAddUint64(result.GasLimit, vmOutput.GasRemaining) @@ -2447,7 +2454,7 @@ func (sc *scProcessor) processSCOutputAccounts( continue } - err = acc.DataTrieTracker().SaveKeyValue(storeUpdate.Offset, storeUpdate.Data) + err = acc.SaveKeyValue(storeUpdate.Offset, storeUpdate.Data) if err != nil { log.Warn("saveKeyValue", "error", err) return false, nil, err @@ -2804,6 +2811,13 @@ func isReturnOKTxHandler( return bytes.HasPrefix(resultTx.GetData(), []byte(returnOkData)) } +// this function should only be called for logging reasons, since it does not perform sanity checks +func (sc *scProcessor) computeTxHashUnsafe(tx data.TransactionHandler) []byte { + txHash, _ := core.CalculateHash(sc.marshalizer, sc.hasher, tx) + + return txHash +} + // IsPayable returns if address is payable, smart contract ca set to false func (sc *scProcessor) IsPayable(sndAddress []byte, recvAddress []byte) (bool, error) { return sc.blockChainHook.IsPayable(sndAddress, recvAddress) diff --git a/process/smartContract/process_test.go b/process/smartContract/process_test.go index 88366b7388f..2d136f6de57 100644 --- a/process/smartContract/process_test.go +++ b/process/smartContract/process_test.go @@ -23,7 +23,7 @@ import ( "github.com/ElrondNetwork/elrond-go/process/mock" "github.com/ElrondNetwork/elrond-go/sharding" "github.com/ElrondNetwork/elrond-go/state" - "github.com/ElrondNetwork/elrond-go/storage/storageUnit" + "github.com/ElrondNetwork/elrond-go/storage/storageunit" "github.com/ElrondNetwork/elrond-go/storage/txcache" "github.com/ElrondNetwork/elrond-go/testscommon" "github.com/ElrondNetwork/elrond-go/testscommon/economicsmocks" @@ -101,7 +101,7 @@ func createMockSmartContractProcessorArguments() ArgsNewSmartContractProcessor { GasHandler: &testscommon.GasHandlerStub{ SetGasRefundedCalled: func(gasRefunded uint64, hash []byte) {}, }, - GasSchedule: testscommon.NewGasScheduleNotifierMock(gasSchedule), + GasSchedule: testscommon.NewGasScheduleNotifierMock(gasSchedule), EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{ IsSCDeployFlagEnabledField: true, }, @@ -1602,8 +1602,8 @@ func TestScProcessor_ExecuteSmartContractTransactionGasConsumedChecksError(t *te arguments.VmContainer = vm arguments.ArgsParser = argParser arguments.AccountsDB = accntState - arguments.VMOutputCacher, _ = 
storageUnit.NewCache(storageUnit.CacheConfig{ - Type: storageUnit.LRUCache, + arguments.VMOutputCacher, _ = storageunit.NewCache(storageunit.CacheConfig{ + Type: storageunit.LRUCache, Capacity: 10000, }) @@ -2749,6 +2749,10 @@ func TestScProcessor_CreateCrossShardTransactionsWithAsyncCalls(t *testing.T) { } shardCoordinator := mock.NewMultiShardsCoordinatorMock(5) arguments := createMockSmartContractProcessorArguments() + enableEpochsHandler := &testscommon.EnableEpochsHandlerStub{ + IsFixAsyncCallBackArgsListFlagEnabledField: false, + } + arguments.EnableEpochsHandler = enableEpochsHandler arguments.AccountsDB = accountsDB arguments.ShardCoordinator = shardCoordinator sc, err := NewSmartContractProcessor(arguments) @@ -2797,9 +2801,23 @@ func TestScProcessor_CreateCrossShardTransactionsWithAsyncCalls(t *testing.T) { require.Nil(t, err) require.Equal(t, len(outputAccounts), len(scTxs)) require.True(t, createdAsyncSCR) - lastScTx := scTxs[len(scTxs)-1].(*smartContractResult.SmartContractResult) - require.Equal(t, vmData.AsynchronousCallBack, lastScTx.CallType) + + t.Run("backwards compatibility", func(t *testing.T) { + require.Equal(t, vmData.AsynchronousCallBack, lastScTx.CallType) + require.Equal(t, []byte(nil), lastScTx.Data) + }) + enableEpochsHandler.IsFixAsyncCallBackArgsListFlagEnabledField = true + + _, scTxs, err = sc.processSCOutputAccounts(&vmcommon.VMOutput{GasRemaining: 1000}, vmData.AsynchronousCall, outputAccounts, tx, txHash) + require.Nil(t, err) + require.Equal(t, len(outputAccounts), len(scTxs)) + require.True(t, createdAsyncSCR) + lastScTx = scTxs[len(scTxs)-1].(*smartContractResult.SmartContractResult) + t.Run("fix enabled, data field is correctly populated", func(t *testing.T) { + require.Equal(t, vmData.AsynchronousCallBack, lastScTx.CallType) + require.Equal(t, []byte("@"+core.ConvertToEvenHex(int(vmcommon.Ok))), lastScTx.Data) + }) tx.Value = big.NewInt(0) scTxs, err = sc.processVMOutput(&vmcommon.VMOutput{GasRemaining: 1000}, txHash, tx, vmData.AsynchronousCall, 10000) @@ -2934,7 +2952,7 @@ func TestScProcessor_ProcessSmartContractResultBadAccType(t *testing.T) { accountsDB := &stateMock.AccountsStub{ LoadAccountCalled: func(address []byte) (handler vmcommon.AccountHandler, e error) { - return &mock.AccountWrapMock{}, nil + return &stateMock.AccountWrapMock{}, nil }, } shardCoordinator := mock.NewMultiShardsCoordinatorMock(5) diff --git a/process/sync/metablock_test.go b/process/sync/metablock_test.go index 182051abf79..c773bb2bd0d 100644 --- a/process/sync/metablock_test.go +++ b/process/sync/metablock_test.go @@ -404,6 +404,7 @@ func testMetaWithMissingStorer(missingUnit dataRetriever.UnitType) func(t *testi bs, err := sync.NewMetaBootstrap(args) assert.Nil(t, bs) + require.NotNil(t, err) require.True(t, strings.Contains(err.Error(), storage.ErrKeyNotFound.Error())) } } diff --git a/process/sync/shardblock.go b/process/sync/shardblock.go index c29a05a9b41..d1238dcd887 100644 --- a/process/sync/shardblock.go +++ b/process/sync/shardblock.go @@ -158,7 +158,7 @@ func isErrGetNodeFromDB(err error) bool { return false } - if strings.Contains(err.Error(), errors.ErrDBIsClosed.Error()) { + if strings.Contains(err.Error(), storage.ErrDBIsClosed.Error()) { return false } diff --git a/process/sync/shardblock_test.go b/process/sync/shardblock_test.go index e8ebe602d51..f61b759556a 100644 --- a/process/sync/shardblock_test.go +++ b/process/sync/shardblock_test.go @@ -23,8 +23,8 @@ import ( "github.com/ElrondNetwork/elrond-go/process/mock" 
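// ---------------------------------------------------------------------------
// Editor's note (not part of the diff): the storageUnit package is renamed to
// storageunit and memorydb.New() becomes database.NewMemDB() throughout this
// diff. The calls below are the ones visible in the updated tests, assembled
// into a minimal helper; the cache parameters are the test values, not
// recommendations.
package sketch // hypothetical package, illustration only

import (
	"github.com/ElrondNetwork/elrond-go/storage"
	"github.com/ElrondNetwork/elrond-go/storage/database"
	"github.com/ElrondNetwork/elrond-go/storage/storageunit"
)

// newMemoryStorer builds an in-memory storer fronted by an LRU cache, the
// same way generateTestUnit does after the rename.
func newMemoryStorer() (storage.Storer, error) {
	lruCache, err := storageunit.NewCache(storageunit.CacheConfig{
		Type:        storageunit.LRUCache,
		Capacity:    1000,
		Shards:      1,
		SizeInBytes: 0,
	})
	if err != nil {
		return nil, err
	}
	return storageunit.NewStorageUnit(lruCache, database.NewMemDB())
}
// ---------------------------------------------------------------------------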
"github.com/ElrondNetwork/elrond-go/process/sync" "github.com/ElrondNetwork/elrond-go/storage" - "github.com/ElrondNetwork/elrond-go/storage/memorydb" - "github.com/ElrondNetwork/elrond-go/storage/storageUnit" + "github.com/ElrondNetwork/elrond-go/storage/database" + "github.com/ElrondNetwork/elrond-go/storage/storageunit" "github.com/ElrondNetwork/elrond-go/testscommon" dataRetrieverMock "github.com/ElrondNetwork/elrond-go/testscommon/dataRetriever" "github.com/ElrondNetwork/elrond-go/testscommon/dblookupext" @@ -80,14 +80,14 @@ func createStore() *storageStubs.ChainStorerStub { } func generateTestCache() storage.Cacher { - cache, _ := storageUnit.NewCache(storageUnit.CacheConfig{Type: storageUnit.LRUCache, Capacity: 1000, Shards: 1, SizeInBytes: 0}) + cache, _ := storageunit.NewCache(storageunit.CacheConfig{Type: storageunit.LRUCache, Capacity: 1000, Shards: 1, SizeInBytes: 0}) return cache } func generateTestUnit() storage.Storer { - storer, _ := storageUnit.NewStorageUnit( + storer, _ := storageunit.NewStorageUnit( generateTestCache(), - memorydb.New(), + database.NewMemDB(), ) return storer @@ -460,6 +460,7 @@ func testShardWithMissingStorer(missingUnit dataRetriever.UnitType) func(t *test bs, err := sync.NewShardBootstrap(args) assert.Nil(t, bs) + require.NotNil(t, err) require.True(t, strings.Contains(err.Error(), storage.ErrKeyNotFound.Error())) } } diff --git a/process/throttle/antiflood/factory/p2pAntifloodAndBlacklistFactory.go b/process/throttle/antiflood/factory/p2pAntifloodAndBlacklistFactory.go index 651dcf7aa82..e39623784f0 100644 --- a/process/throttle/antiflood/factory/p2pAntifloodAndBlacklistFactory.go +++ b/process/throttle/antiflood/factory/p2pAntifloodAndBlacklistFactory.go @@ -16,9 +16,9 @@ import ( "github.com/ElrondNetwork/elrond-go/process/throttle/antiflood/disabled" "github.com/ElrondNetwork/elrond-go/process/throttle/antiflood/floodPreventers" "github.com/ElrondNetwork/elrond-go/statusHandler/p2pQuota" + "github.com/ElrondNetwork/elrond-go/storage/cache" storageFactory "github.com/ElrondNetwork/elrond-go/storage/factory" - "github.com/ElrondNetwork/elrond-go/storage/storageUnit" - "github.com/ElrondNetwork/elrond-go/storage/timecache" + "github.com/ElrondNetwork/elrond-go/storage/storageunit" ) var log = logger.GetOrCreate("p2p/antiflood/factory") @@ -64,13 +64,13 @@ func initP2PAntiFloodComponents( statusHandler core.AppStatusHandler, currentPid core.PeerID, ) (*AntiFloodComponents, error) { - cache := timecache.NewTimeCache(defaultSpan) - p2pPeerBlackList, err := timecache.NewPeerTimeCache(cache) + timeCache := cache.NewTimeCache(defaultSpan) + p2pPeerBlackList, err := cache.NewPeerTimeCache(timeCache) if err != nil { return nil, err } - publicKeysCache := timecache.NewTimeCache(defaultSpan) + publicKeysCache := cache.NewTimeCache(defaultSpan) fastReactingFloodPreventer, err := createFloodPreventer( ctx, @@ -207,7 +207,7 @@ func createFloodPreventer( selfPid core.PeerID, ) (process.FloodPreventer, error) { cacheConfig := storageFactory.GetCacherFromConfig(antifloodCacheConfig) - blackListCache, err := storageUnit.NewCache(cacheConfig) + blackListCache, err := storageunit.NewCache(cacheConfig) if err != nil { return nil, err } @@ -226,7 +226,7 @@ func createFloodPreventer( return nil, err } - antifloodCache, err := storageUnit.NewCache(cacheConfig) + antifloodCache, err := storageunit.NewCache(cacheConfig) if err != nil { return nil, err } diff --git a/process/throttle/antiflood/factory/p2pOutputAntiflood.go 
b/process/throttle/antiflood/factory/p2pOutputAntiflood.go index fd1c32f895f..9a1321f1a43 100644 --- a/process/throttle/antiflood/factory/p2pOutputAntiflood.go +++ b/process/throttle/antiflood/factory/p2pOutputAntiflood.go @@ -9,7 +9,7 @@ import ( "github.com/ElrondNetwork/elrond-go/process/throttle/antiflood/disabled" "github.com/ElrondNetwork/elrond-go/process/throttle/antiflood/floodPreventers" storageFactory "github.com/ElrondNetwork/elrond-go/storage/factory" - "github.com/ElrondNetwork/elrond-go/storage/storageUnit" + "github.com/ElrondNetwork/elrond-go/storage/storageunit" ) const outputReservedPercent = float32(0) @@ -25,7 +25,7 @@ func NewP2POutputAntiFlood(ctx context.Context, mainConfig config.Config) (proce func initP2POutputAntiFlood(ctx context.Context, mainConfig config.Config) (process.P2PAntifloodHandler, error) { cacheConfig := storageFactory.GetCacherFromConfig(mainConfig.Antiflood.Cache) - antifloodCache, err := storageUnit.NewCache(cacheConfig) + antifloodCache, err := storageunit.NewCache(cacheConfig) if err != nil { return nil, err } diff --git a/process/track/baseBlockTrack_test.go b/process/track/baseBlockTrack_test.go index 024ba5697c3..04ddeb10e93 100644 --- a/process/track/baseBlockTrack_test.go +++ b/process/track/baseBlockTrack_test.go @@ -17,8 +17,8 @@ import ( "github.com/ElrondNetwork/elrond-go/process/track" "github.com/ElrondNetwork/elrond-go/sharding" "github.com/ElrondNetwork/elrond-go/storage" - "github.com/ElrondNetwork/elrond-go/storage/memorydb" - "github.com/ElrondNetwork/elrond-go/storage/storageUnit" + "github.com/ElrondNetwork/elrond-go/storage/database" + "github.com/ElrondNetwork/elrond-go/storage/storageunit" "github.com/ElrondNetwork/elrond-go/testscommon" dataRetrieverMock "github.com/ElrondNetwork/elrond-go/testscommon/dataRetriever" "github.com/ElrondNetwork/elrond-go/testscommon/hashingMocks" @@ -85,9 +85,9 @@ func initStore() *dataRetriever.ChainStorer { } func generateStorageUnit() storage.Storer { - memDB := memorydb.New() + memDB := database.NewMemDB() - storer, _ := storageUnit.NewStorageUnit( + storer, _ := storageunit.NewStorageUnit( generateTestCache(), memDB, ) @@ -96,7 +96,7 @@ func generateStorageUnit() storage.Storer { } func generateTestCache() storage.Cacher { - cache, _ := storageUnit.NewCache(storageUnit.CacheConfig{Type: storageUnit.LRUCache, Capacity: 1000, Shards: 1, SizeInBytes: 0}) + cache, _ := storageunit.NewCache(storageunit.CacheConfig{Type: storageunit.LRUCache, Capacity: 1000, Shards: 1, SizeInBytes: 0}) return cache } diff --git a/process/transactionLog/process.go b/process/transactionLog/process.go index 949d541fbb3..1fcf05be87e 100644 --- a/process/transactionLog/process.go +++ b/process/transactionLog/process.go @@ -12,7 +12,7 @@ import ( logger "github.com/ElrondNetwork/elrond-go-logger" "github.com/ElrondNetwork/elrond-go/process" "github.com/ElrondNetwork/elrond-go/storage" - "github.com/ElrondNetwork/elrond-go/storage/storageUnit" + "github.com/ElrondNetwork/elrond-go/storage/storageunit" vmcommon "github.com/ElrondNetwork/elrond-vm-common" ) @@ -44,7 +44,7 @@ func NewTxLogProcessor(args ArgTxLogProcessor) (*txLogProcessor, error) { } if !args.SaveInStorageEnabled { - storer = storageUnit.NewNilStorer() + storer = storageunit.NewNilStorer() } if check.IfNil(args.Marshalizer) { diff --git a/process/txsimulator/txSimulator_test.go b/process/txsimulator/txSimulator_test.go index 031093ecbe1..1a9836f17a6 100644 --- a/process/txsimulator/txSimulator_test.go +++ b/process/txsimulator/txSimulator_test.go @@ 
-13,7 +13,7 @@ import ( "github.com/ElrondNetwork/elrond-go-core/data/transaction" "github.com/ElrondNetwork/elrond-go/process" "github.com/ElrondNetwork/elrond-go/process/mock" - "github.com/ElrondNetwork/elrond-go/storage/storageUnit" + "github.com/ElrondNetwork/elrond-go/storage/storageunit" "github.com/ElrondNetwork/elrond-go/storage/txcache" "github.com/ElrondNetwork/elrond-go/testscommon" "github.com/ElrondNetwork/elrond-go/testscommon/hashingMocks" @@ -144,8 +144,8 @@ func TestTransactionSimulator_getVMOutput(t *testing.T) { t.Parallel() args := getTxSimulatorArgs() - args.VMOutputCacher, _ = storageUnit.NewCache(storageUnit.CacheConfig{ - Type: storageUnit.LRUCache, + args.VMOutputCacher, _ = storageunit.NewCache(storageunit.CacheConfig{ + Type: storageunit.LRUCache, Capacity: 100, }) @@ -179,8 +179,8 @@ func TestTransactionSimulator_ProcessTxShouldIncludeScrsAndReceipts(t *testing.T } args := getTxSimulatorArgs() - args.VMOutputCacher, _ = storageUnit.NewCache(storageUnit.CacheConfig{ - Type: storageUnit.LRUCache, + args.VMOutputCacher, _ = storageunit.NewCache(storageunit.CacheConfig{ + Type: storageunit.LRUCache, Capacity: 100, }) diff --git a/process/txsimulator/wrappedAccountsDB.go b/process/txsimulator/wrappedAccountsDB.go index 856e9495614..97b13aa9d19 100644 --- a/process/txsimulator/wrappedAccountsDB.go +++ b/process/txsimulator/wrappedAccountsDB.go @@ -24,6 +24,16 @@ func NewReadOnlyAccountsDB(accountsDB state.AccountsAdapter) (*readOnlyAccountsD return &readOnlyAccountsDB{originalAccounts: accountsDB}, nil } +// SetSyncer returns nil for this implementation +func (r *readOnlyAccountsDB) SetSyncer(_ state.AccountsDBSyncer) error { + return nil +} + +// StartSnapshotIfNeeded returns nil for this implementation +func (r *readOnlyAccountsDB) StartSnapshotIfNeeded() error { + return nil +} + // GetCode returns the code for the given account func (r *readOnlyAccountsDB) GetCode(codeHash []byte) []byte { return r.originalAccounts.GetCode(codeHash) diff --git a/process/txsimulator/wrappedAccountsDB_test.go b/process/txsimulator/wrappedAccountsDB_test.go index 567619c05c9..748f991d422 100644 --- a/process/txsimulator/wrappedAccountsDB_test.go +++ b/process/txsimulator/wrappedAccountsDB_test.go @@ -7,7 +7,6 @@ import ( "github.com/ElrondNetwork/elrond-go-core/core" "github.com/ElrondNetwork/elrond-go-core/core/check" "github.com/ElrondNetwork/elrond-go/common" - "github.com/ElrondNetwork/elrond-go/process/mock" "github.com/ElrondNetwork/elrond-go/state" stateMock "github.com/ElrondNetwork/elrond-go/testscommon/state" vmcommon "github.com/ElrondNetwork/elrond-vm-common" @@ -106,7 +105,7 @@ func TestReadOnlyAccountsDB_WriteOperationsShouldNotCalled(t *testing.T) { func TestReadOnlyAccountsDB_ReadOperationsShouldWork(t *testing.T) { t.Parallel() - expectedAcc := &mock.AccountWrapMock{} + expectedAcc := &stateMock.AccountWrapMock{} expectedJournalLen := 37 expectedRootHash := []byte("root") diff --git a/sharding/mock/enableEpochsHandlerMock.go b/sharding/mock/enableEpochsHandlerMock.go index 4eca2905de3..cf0618ffcd2 100644 --- a/sharding/mock/enableEpochsHandlerMock.go +++ b/sharding/mock/enableEpochsHandlerMock.go @@ -536,6 +536,11 @@ func (mock *EnableEpochsHandlerMock) IsRefactorPeersMiniBlocksFlagEnabled() bool return mock.IsRefactorPeersMiniBlocksFlagEnabledField } +// IsFixAsyncCallBackArgsListFlagEnabled - +func (mock *EnableEpochsHandlerMock) IsFixAsyncCallBackArgsListFlagEnabled() bool { + return false +} + // IsInterfaceNil returns true if there is no value under the 
interface func (mock *EnableEpochsHandlerMock) IsInterfaceNil() bool { return mock == nil diff --git a/sharding/networksharding/peerShardMapper.go b/sharding/networksharding/peerShardMapper.go index 0a63dfef9fb..f969f27ab36 100644 --- a/sharding/networksharding/peerShardMapper.go +++ b/sharding/networksharding/peerShardMapper.go @@ -13,7 +13,7 @@ import ( "github.com/ElrondNetwork/elrond-go/p2p" "github.com/ElrondNetwork/elrond-go/sharding/nodesCoordinator" "github.com/ElrondNetwork/elrond-go/storage" - "github.com/ElrondNetwork/elrond-go/storage/lrucache" + "github.com/ElrondNetwork/elrond-go/storage/cache" ) const maxNumPidsPerPk = 3 @@ -75,12 +75,12 @@ func NewPeerShardMapper(arg ArgPeerShardMapper) (*PeerShardMapper, error) { return nil, p2p.ErrNilPreferredPeersHolder } - pkPeerId, err := lrucache.NewCache(arg.PeerIdPkCache.MaxSize()) + pkPeerId, err := cache.NewLRUCache(arg.PeerIdPkCache.MaxSize()) if err != nil { return nil, err } - peerIdSubTypeCache, err := lrucache.NewCache(arg.PeerIdPkCache.MaxSize()) + peerIdSubTypeCache, err := cache.NewLRUCache(arg.PeerIdPkCache.MaxSize()) if err != nil { return nil, err } diff --git a/sharding/nodesCoordinator/indexHashedNodesCoordinator_test.go b/sharding/nodesCoordinator/indexHashedNodesCoordinator_test.go index 709e170ba8f..cb0483b51a4 100644 --- a/sharding/nodesCoordinator/indexHashedNodesCoordinator_test.go +++ b/sharding/nodesCoordinator/indexHashedNodesCoordinator_test.go @@ -25,7 +25,7 @@ import ( "github.com/ElrondNetwork/elrond-go/epochStart" "github.com/ElrondNetwork/elrond-go/sharding/mock" "github.com/ElrondNetwork/elrond-go/state" - "github.com/ElrondNetwork/elrond-go/storage/lrucache" + "github.com/ElrondNetwork/elrond-go/storage/cache" "github.com/ElrondNetwork/elrond-go/testscommon/genericMocks" "github.com/ElrondNetwork/elrond-go/testscommon/hashingMocks" "github.com/ElrondNetwork/elrond-go/testscommon/nodeTypeProviderMock" @@ -444,7 +444,7 @@ func TestIndexHashedNodesCoordinator_ComputeValidatorsGroup400of400For10locksNoM getCounter := int32(0) putCounter := int32(0) - cache := &mock.NodesCoordinatorCacheMock{ + lruCache := &mock.NodesCoordinatorCacheMock{ PutCalled: func(key []byte, value interface{}, sizeInBytes int) (evicted bool) { atomic.AddInt32(&putCounter, 1) return false @@ -467,7 +467,7 @@ func TestIndexHashedNodesCoordinator_ComputeValidatorsGroup400of400For10locksNoM EligibleNodes: eligibleMap, WaitingNodes: waitingMap, SelfPublicKey: []byte("key"), - ConsensusGroupCache: cache, + ConsensusGroupCache: lruCache, ShuffledOutHandler: &mock.ShuffledOutHandlerStub{}, ChanStopNode: make(chan endProcess.ArgEndProcess), NodeTypeProvider: &nodeTypeProviderMock.NodeTypeProviderStub{}, @@ -523,7 +523,7 @@ func TestIndexHashedNodesCoordinator_ComputeValidatorsGroup400of400For10BlocksMe //consensusGroup := list[0:21] cacheMap := make(map[string]interface{}) - cache := &mock.NodesCoordinatorCacheMock{ + lruCache := &mock.NodesCoordinatorCacheMock{ PutCalled: func(key []byte, value interface{}, sizeInBytes int) (evicted bool) { mut.Lock() defer mut.Unlock() @@ -555,7 +555,7 @@ func TestIndexHashedNodesCoordinator_ComputeValidatorsGroup400of400For10BlocksMe EligibleNodes: eligibleMap, WaitingNodes: waitingMap, SelfPublicKey: []byte("key"), - ConsensusGroupCache: cache, + ConsensusGroupCache: lruCache, ShuffledOutHandler: &mock.ShuffledOutHandlerStub{}, ChanStopNode: make(chan endProcess.ArgEndProcess), NodeTypeProvider: &nodeTypeProviderMock.NodeTypeProviderStub{}, @@ -586,7 +586,7 @@ func 
TestIndexHashedNodesCoordinator_ComputeValidatorsGroup400of400For10BlocksMe func TestIndexHashedNodesCoordinator_ComputeValidatorsGroup63of400TestEqualSameParams(t *testing.T) { t.Skip("testing consistency - to be run manually") - cache := &mock.NodesCoordinatorCacheMock{ + lruCache := &mock.NodesCoordinatorCacheMock{ GetCalled: func(key []byte) (value interface{}, ok bool) { return nil, false }, @@ -627,7 +627,7 @@ func TestIndexHashedNodesCoordinator_ComputeValidatorsGroup63of400TestEqualSameP EligibleNodes: eligibleMap, WaitingNodes: waitingMap, SelfPublicKey: []byte("key"), - ConsensusGroupCache: cache, + ConsensusGroupCache: lruCache, ChanStopNode: make(chan endProcess.ArgEndProcess), NodeTypeProvider: &nodeTypeProviderMock.NodeTypeProviderStub{}, EnableEpochsHandler: &mock.EnableEpochsHandlerMock{}, @@ -844,9 +844,9 @@ func BenchmarkIndexHashedNodesCoordinator_ComputeValidatorsGroup63of400Recompute nodesPerShard := uint32(400) eligibleMap := createDummyNodesMap(nodesPerShard, 1, "eligible") - consensusGroupCache, _ := lrucache.NewCache(1) + consensusGroupCache, _ := cache.NewLRUCache(1) computeMemoryRequirements(consensusGroupCache, consensusGroupSize, eligibleMap, b) - consensusGroupCache, _ = lrucache.NewCache(1) + consensusGroupCache, _ = cache.NewLRUCache(1) runBenchmark(consensusGroupCache, consensusGroupSize, eligibleMap, b) } @@ -855,9 +855,9 @@ func BenchmarkIndexHashedNodesCoordinator_ComputeValidatorsGroup400of400Recomput nodesPerShard := uint32(400) eligibleMap := createDummyNodesMap(nodesPerShard, 1, "eligible") - consensusGroupCache, _ := lrucache.NewCache(1) + consensusGroupCache, _ := cache.NewLRUCache(1) computeMemoryRequirements(consensusGroupCache, consensusGroupSize, eligibleMap, b) - consensusGroupCache, _ = lrucache.NewCache(1) + consensusGroupCache, _ = cache.NewLRUCache(1) runBenchmark(consensusGroupCache, consensusGroupSize, eligibleMap, b) } @@ -866,9 +866,9 @@ func BenchmarkIndexHashedNodesCoordinator_ComputeValidatorsGroup63of400Memoizati nodesPerShard := uint32(400) eligibleMap := createDummyNodesMap(nodesPerShard, 1, "eligible") - consensusGroupCache, _ := lrucache.NewCache(10000) + consensusGroupCache, _ := cache.NewLRUCache(10000) computeMemoryRequirements(consensusGroupCache, consensusGroupSize, eligibleMap, b) - consensusGroupCache, _ = lrucache.NewCache(10000) + consensusGroupCache, _ = cache.NewLRUCache(10000) runBenchmark(consensusGroupCache, consensusGroupSize, eligibleMap, b) } @@ -877,9 +877,9 @@ func BenchmarkIndexHashedNodesCoordinator_ComputeValidatorsGroup400of400Memoizat nodesPerShard := uint32(400) eligibleMap := createDummyNodesMap(nodesPerShard, 1, "eligible") - consensusGroupCache, _ := lrucache.NewCache(1000) + consensusGroupCache, _ := cache.NewLRUCache(1000) computeMemoryRequirements(consensusGroupCache, consensusGroupSize, eligibleMap, b) - consensusGroupCache, _ = lrucache.NewCache(1000) + consensusGroupCache, _ = cache.NewLRUCache(1000) runBenchmark(consensusGroupCache, consensusGroupSize, eligibleMap, b) } diff --git a/state/accountsDB.go b/state/accountsDB.go index dc7a130e5cf..562193fcaee 100644 --- a/state/accountsDB.go +++ b/state/accountsDB.go @@ -10,6 +10,7 @@ import ( "time" "github.com/ElrondNetwork/elrond-go-core/core" + "github.com/ElrondNetwork/elrond-go-core/core/atomic" "github.com/ElrondNetwork/elrond-go-core/core/check" "github.com/ElrondNetwork/elrond-go-core/hashing" "github.com/ElrondNetwork/elrond-go-core/marshal" @@ -17,12 +18,14 @@ import ( "github.com/ElrondNetwork/elrond-go/common" 
"github.com/ElrondNetwork/elrond-go/common/holders" "github.com/ElrondNetwork/elrond-go/errors" + "github.com/ElrondNetwork/elrond-go/trie/keyBuilder" vmcommon "github.com/ElrondNetwork/elrond-vm-common" ) const ( - leavesChannelSize = 100 - lastSnapshotStarted = "lastSnapshot" + leavesChannelSize = 100 + missingNodesChannelSize = 100 + lastSnapshotStarted = "lastSnapshot" ) type loadingMeasurements struct { @@ -73,11 +76,13 @@ type AccountsDB struct { accountFactory AccountFactory storagePruningManager StoragePruningManager obsoleteDataTrieHashes map[string][][]byte + trieSyncer AccountsDBSyncer - lastSnapshot *snapshotInfo - lastRootHash []byte - dataTries common.TriesHolder - entries []JournalEntry + isSnapshotInProgress atomic.Flag + lastSnapshot *snapshotInfo + lastRootHash []byte + dataTries common.TriesHolder + entries []JournalEntry // TODO use mutOp only for critical sections, and refactor to parallelize as much as possible mutOp sync.RWMutex processingMode common.NodeProcessingMode @@ -109,7 +114,11 @@ func NewAccountsDB(args ArgsAccountsDB) (*AccountsDB, error) { return nil, err } - adb := &AccountsDB{ + return createAccountsDb(args), nil +} + +func createAccountsDb(args ArgsAccountsDB) *AccountsDB { + return &AccountsDB{ mainTrie: args.Trie, hasher: args.Hasher, marshaller: args.Marshaller, @@ -126,15 +135,8 @@ func NewAccountsDB(args ArgsAccountsDB) (*AccountsDB, error) { shouldSerializeSnapshots: args.ShouldSerializeSnapshots, lastSnapshot: &snapshotInfo{}, processStatusHandler: args.ProcessStatusHandler, + isSnapshotInProgress: atomic.Flag{}, } - - trieStorageManager := adb.mainTrie.GetStorageManager() - val, err := trieStorageManager.GetFromCurrentEpoch([]byte(common.ActiveDBKey)) - if err != nil || !bytes.Equal(val, []byte(common.ActiveDBVal)) { - startSnapshotAfterRestart(adb, args) - } - - return adb, nil } func checkArgsAccountsDB(args ArgsAccountsDB) error { @@ -160,17 +162,16 @@ func checkArgsAccountsDB(args ArgsAccountsDB) error { return nil } -func startSnapshotAfterRestart(adb AccountsAdapter, args ArgsAccountsDB) { - tsm := args.Trie.GetStorageManager() +func startSnapshotAfterRestart(adb AccountsAdapter, tsm common.StorageManager, processingMode common.NodeProcessingMode) { epoch, err := tsm.GetLatestStorageEpoch() if err != nil { log.Error("could not get latest storage epoch") } putActiveDBMarker := epoch == 0 && err == nil - isInImportDBMode := args.ProcessingMode == common.ImportDb + isInImportDBMode := processingMode == common.ImportDb putActiveDBMarker = putActiveDBMarker || isInImportDBMode if putActiveDBMarker { - log.Debug("marking activeDB", "epoch", epoch, "error", err, "processing mode", args.ProcessingMode) + log.Debug("marking activeDB", "epoch", epoch, "error", err, "processing mode", processingMode) err = tsm.Put([]byte(common.ActiveDBKey), []byte(common.ActiveDBVal)) handleLoggingWhenError("error while putting active DB value into main storer", err) return @@ -204,6 +205,42 @@ func handleLoggingWhenError(message string, err error, extraArguments ...interfa log.Warn(message, append(args, extraArguments...)...) 
} +// SetSyncer sets the given syncer as the syncer for the underlying trie +func (adb *AccountsDB) SetSyncer(syncer AccountsDBSyncer) error { + if check.IfNil(syncer) { + return ErrNilTrieSyncer + } + + adb.mutOp.Lock() + defer adb.mutOp.Unlock() + + adb.trieSyncer = syncer + return nil +} + +// StartSnapshotIfNeeded starts the snapshot if the previous snapshot process was not fully completed +func (adb *AccountsDB) StartSnapshotIfNeeded() error { + return startSnapshotIfNeeded(adb, adb.trieSyncer, adb.mainTrie.GetStorageManager(), adb.processingMode) +} + +func startSnapshotIfNeeded( + adb AccountsAdapter, + trieSyncer AccountsDBSyncer, + trieStorageManager common.StorageManager, + processingMode common.NodeProcessingMode, +) error { + if check.IfNil(trieSyncer) { + return ErrNilTrieSyncer + } + + val, err := trieStorageManager.GetFromCurrentEpoch([]byte(common.ActiveDBKey)) + if err != nil || !bytes.Equal(val, []byte(common.ActiveDBVal)) { + startSnapshotAfterRestart(adb, trieStorageManager, processingMode) + } + + return nil +} + // GetCode returns the code for the given account func (adb *AccountsDB) GetCode(codeHash []byte) []byte { if len(codeHash) == 0 { @@ -461,67 +498,34 @@ func (adb *AccountsDB) loadDataTrie(accountHandler baseAccountHandler) error { // SaveDataTrie is used to save the data trie (not committing it) and to recompute the new Root value // If data is not dirtied, method will not create its JournalEntries to keep track of data modification func (adb *AccountsDB) saveDataTrie(accountHandler baseAccountHandler) error { - if check.IfNil(accountHandler.DataTrieTracker()) { - return ErrNilTrackableDataTrie + oldValues, err := accountHandler.SaveDirtyData(adb.mainTrie) + if err != nil { + return err } - if len(accountHandler.DataTrieTracker().DirtyData()) == 0 { + if len(oldValues) == 0 { return nil } - log.Trace("accountsDB.SaveDataTrie", - "address", hex.EncodeToString(accountHandler.AddressBytes()), - "nonce", accountHandler.GetNonce(), - ) - - if check.IfNil(accountHandler.DataTrie()) { - newDataTrie, err := adb.mainTrie.Recreate(make([]byte, 0)) - if err != nil { - return err - } - - accountHandler.SetDataTrie(newDataTrie) - adb.dataTries.Put(accountHandler.AddressBytes(), newDataTrie) - } - - trackableDataTrie := accountHandler.DataTrieTracker() - dataTrie := trackableDataTrie.DataTrie() - oldValues := make(map[string][]byte) - - for k, v := range trackableDataTrie.DirtyData() { - val, err := dataTrie.Get([]byte(k)) - if err != nil { - return err - } - - oldValues[k] = val - - err = dataTrie.Update([]byte(k), v) - if err != nil { - return err - } - } - entry, err := NewJournalEntryDataTrieUpdates(oldValues, accountHandler) if err != nil { return err } adb.journalize(entry) - rootHash, err := trackableDataTrie.DataTrie().RootHash() + rootHash, err := accountHandler.DataTrie().RootHash() if err != nil { return err } - accountHandler.SetRootHash(rootHash) - trackableDataTrie.ClearDataCaches() - - log.Trace("accountsDB.SaveDataTrie", - "address", hex.EncodeToString(accountHandler.AddressBytes()), - "new root hash", accountHandler.GetRootHash(), - ) if check.IfNil(adb.dataTries.Get(accountHandler.AddressBytes())) { - adb.dataTries.Put(accountHandler.AddressBytes(), accountHandler.DataTrie()) + trie, ok := accountHandler.DataTrie().(common.Trie) + if !ok { + log.Warn("wrong type conversion", "trie type", fmt.Sprintf("%T", accountHandler.DataTrie())) + return nil + } + + adb.dataTries.Put(accountHandler.AddressBytes(), trie) } return nil @@ -1003,19 +1007,16 @@ func (adb 
*AccountsDB) recreateTrie(options common.RootHashHolder) error { // RecreateAllTries recreates all the tries from the accounts DB func (adb *AccountsDB) RecreateAllTries(rootHash []byte) (map[string]common.Trie, error) { leavesChannel := make(chan core.KeyValueHolder, leavesChannelSize) - err := adb.mainTrie.GetAllLeavesOnChannel(leavesChannel, context.Background(), rootHash) + err := adb.mainTrie.GetAllLeavesOnChannel(leavesChannel, context.Background(), rootHash, keyBuilder.NewDisabledKeyBuilder()) if err != nil { return nil, err } - recreatedTrie, err := adb.mainTrie.Recreate(rootHash) + allTries, err := adb.recreateMainTrie(rootHash) if err != nil { return nil, err } - allTries := make(map[string]common.Trie) - allTries[string(rootHash)] = recreatedTrie - for leaf := range leavesChannel { account := &userAccount{} err = adb.marshaller.Unmarshal(account, leaf.Value()) @@ -1037,6 +1038,18 @@ func (adb *AccountsDB) RecreateAllTries(rootHash []byte) (map[string]common.Trie return allTries, nil } +func (adb *AccountsDB) recreateMainTrie(rootHash []byte) (map[string]common.Trie, error) { + recreatedTrie, err := adb.mainTrie.Recreate(rootHash) + if err != nil { + return nil, err + } + + allTries := make(map[string]common.Trie) + allTries[string(rootHash)] = recreatedTrie + + return allTries, nil +} + // GetTrie returns the trie that has the given rootHash func (adb *AccountsDB) GetTrie(rootHash []byte) (common.Trie, error) { adb.mutOp.Lock() @@ -1092,52 +1105,101 @@ func (adb *AccountsDB) SnapshotState(rootHash []byte) { adb.mutOp.Lock() defer adb.mutOp.Unlock() - trieStorageManager := adb.mainTrie.GetStorageManager() - epoch, err := trieStorageManager.GetLatestStorageEpoch() - if err != nil { - log.Error("snapshotState error", "err", err.Error()) + trieStorageManager, epoch, shouldTakeSnapshot := adb.prepareSnapshot(rootHash) + if !shouldTakeSnapshot { return } - snapshotAlreadyTaken := bytes.Equal(adb.lastSnapshot.rootHash, rootHash) && adb.lastSnapshot.epoch == epoch - if !trieStorageManager.ShouldTakeSnapshot() || snapshotAlreadyTaken { + log.Info("starting snapshot user trie", "rootHash", rootHash, "epoch", epoch) + missingNodesChannel := make(chan []byte, missingNodesChannelSize) + errChan := make(chan error, 1) + stats := newSnapshotStatistics(1, 1) + go func() { + leavesChannel := make(chan core.KeyValueHolder, leavesChannelSize) + stats.NewSnapshotStarted() + trieStorageManager.TakeSnapshot(rootHash, rootHash, leavesChannel, missingNodesChannel, errChan, stats, epoch) + adb.snapshotUserAccountDataTrie(true, rootHash, leavesChannel, missingNodesChannel, errChan, stats, epoch) + + stats.SnapshotFinished() + }() + + go adb.syncMissingNodes(missingNodesChannel, stats) + + go adb.processSnapshotCompletion(stats, missingNodesChannel, errChan, rootHash, "snapshotState user trie", epoch) + + adb.waitForCompletionIfAppropriate(stats) +} + +func (adb *AccountsDB) prepareSnapshot(rootHash []byte) (common.StorageManager, uint32, bool) { + trieStorageManager, epoch, err := adb.getTrieStorageManagerAndLatestEpoch() + if err != nil { + log.Error("prepareSnapshot error", "err", err.Error()) + return nil, 0, false + } + + if !adb.shouldTakeSnapshot(trieStorageManager, rootHash, epoch) { log.Debug("skipping snapshot", "last snapshot rootHash", adb.lastSnapshot.rootHash, "rootHash", rootHash, "last snapshot epoch", adb.lastSnapshot.epoch, "epoch", epoch, + "isSnapshotInProgress", adb.isSnapshotInProgress.IsSet(), ) - return + return nil, 0, false } - log.Info("starting snapshot", "rootHash", rootHash, 
"epoch", epoch) - + adb.isSnapshotInProgress.SetValue(true) adb.lastSnapshot.rootHash = rootHash adb.lastSnapshot.epoch = epoch err = trieStorageManager.Put([]byte(lastSnapshotStarted), rootHash) handleLoggingWhenError("could not set lastSnapshotStarted", err, "rootHash", rootHash) - trieStorageManager.EnterPruningBufferingMode() - errChan := make(chan error, 1) - stats := newSnapshotStatistics(1) - go func() { - leavesChannel := make(chan core.KeyValueHolder, leavesChannelSize) - stats.NewSnapshotStarted() - trieStorageManager.TakeSnapshot(rootHash, rootHash, leavesChannel, errChan, stats, epoch) - adb.snapshotUserAccountDataTrie(true, rootHash, leavesChannel, errChan, stats, epoch) - trieStorageManager.ExitPruningBufferingMode() + return trieStorageManager, epoch, true +} - stats.wg.Done() - }() +func (adb *AccountsDB) getTrieStorageManagerAndLatestEpoch() (common.StorageManager, uint32, error) { + trieStorageManager := adb.mainTrie.GetStorageManager() + epoch, err := trieStorageManager.GetLatestStorageEpoch() + if err != nil { + return nil, 0, fmt.Errorf("%w while getting the latest storage epoch", err) + } - go adb.markActiveDBAfterSnapshot(stats, errChan, rootHash, "snapshotState user trie", epoch) + return trieStorageManager, epoch, nil +} - adb.waitForCompletionIfAppropriate(stats) +func (adb *AccountsDB) shouldTakeSnapshot(trieStorageManager common.StorageManager, rootHash []byte, epoch uint32) bool { + snapshotAlreadyTaken := bytes.Equal(adb.lastSnapshot.rootHash, rootHash) && adb.lastSnapshot.epoch == epoch + if snapshotAlreadyTaken { + return false + } + + if adb.isSnapshotInProgress.IsSet() { + return false + } + + return trieStorageManager.ShouldTakeSnapshot() } -func (adb *AccountsDB) markActiveDBAfterSnapshot(stats *snapshotStatistics, errChan chan error, rootHash []byte, message string, epoch uint32) { +func (adb *AccountsDB) finishSnapshotOperation( + rootHash []byte, + stats *snapshotStatistics, + missingNodesCh chan []byte, + message string, +) { + stats.WaitForSnapshotsToFinish() + close(missingNodesCh) + stats.WaitForSyncToFinish() + + adb.mainTrie.GetStorageManager().ExitPruningBufferingMode() + stats.PrintStats(message, rootHash) +} + +func (adb *AccountsDB) processSnapshotCompletion(stats *snapshotStatistics, missingNodesCh chan []byte, errChan chan error, rootHash []byte, message string, epoch uint32) { + adb.finishSnapshotOperation(rootHash, stats, missingNodesCh, message) + + defer adb.isSnapshotInProgress.Reset() trieStorageManager := adb.mainTrie.GetStorageManager() containsErrorDuringSnapshot := emptyErrChanReturningHadContained(errChan) @@ -1157,6 +1219,26 @@ func (adb *AccountsDB) markActiveDBAfterSnapshot(stats *snapshotStatistics, errC handleLoggingWhenError("error while putting active DB value into main storer", errPut) } +func (adb *AccountsDB) syncMissingNodes(missingNodesChan chan []byte, stats *snapshotStatistics) { + defer stats.SyncFinished() + + if check.IfNil(adb.trieSyncer) { + log.Error("nil trie syncer") + for missingNode := range missingNodesChan { + log.Warn("could not sync node", "hash", missingNode) + } + + return + } + + for missingNode := range missingNodesChan { + err := adb.trieSyncer.SyncAccounts(missingNode) + if err != nil { + log.Error("could not sync missing node", "error", err) + } + } +} + func emptyErrChanReturningHadContained(errChan chan error) bool { contained := false for { @@ -1173,6 +1255,7 @@ func (adb *AccountsDB) snapshotUserAccountDataTrie( isSnapshot bool, mainTrieRootHash []byte, leavesChannel chan 
core.KeyValueHolder, + missingNodesChannel chan []byte, errChan chan error, stats common.SnapshotStatisticsHandler, epoch uint32, @@ -1193,11 +1276,11 @@ func (adb *AccountsDB) snapshotUserAccountDataTrie( stats.NewDataTrie() if isSnapshot { - adb.mainTrie.GetStorageManager().TakeSnapshot(account.RootHash, mainTrieRootHash, nil, errChan, stats, epoch) + adb.mainTrie.GetStorageManager().TakeSnapshot(account.RootHash, mainTrieRootHash, nil, missingNodesChannel, errChan, stats, epoch) continue } - adb.mainTrie.GetStorageManager().SetCheckpoint(account.RootHash, mainTrieRootHash, nil, errChan, stats) + adb.mainTrie.GetStorageManager().SetCheckpoint(account.RootHash, mainTrieRootHash, nil, missingNodesChannel, errChan, stats) } } @@ -1214,21 +1297,23 @@ func (adb *AccountsDB) setStateCheckpoint(rootHash []byte) { log.Trace("accountsDB.SetStateCheckpoint", "root hash", rootHash) trieStorageManager.EnterPruningBufferingMode() - stats := newSnapshotStatistics(1) errChan := make(chan error, 1) + missingNodesChannel := make(chan []byte, missingNodesChannelSize) + stats := newSnapshotStatistics(1, 1) go func() { leavesChannel := make(chan core.KeyValueHolder, leavesChannelSize) stats.NewSnapshotStarted() - trieStorageManager.SetCheckpoint(rootHash, rootHash, leavesChannel, errChan, stats) - adb.snapshotUserAccountDataTrie(false, rootHash, leavesChannel, errChan, stats, 0) - trieStorageManager.ExitPruningBufferingMode() + trieStorageManager.SetCheckpoint(rootHash, rootHash, leavesChannel, missingNodesChannel, errChan, stats) + adb.snapshotUserAccountDataTrie(false, rootHash, leavesChannel, missingNodesChannel, errChan, stats, 0) - stats.wg.Done() + stats.SnapshotFinished() }() + go adb.syncMissingNodes(missingNodesChannel, stats) + // TODO decide if we need to take some actions whenever we hit an error that occurred in the checkpoint process // that will be present in the errChan var - go stats.PrintStats("setStateCheckpoint user trie", rootHash) + go adb.finishSnapshotOperation(rootHash, stats, missingNodesChannel, "setStateCheckpoint user trie") adb.waitForCompletionIfAppropriate(stats) } @@ -1255,7 +1340,7 @@ func (adb *AccountsDB) GetAllLeaves(leavesChannel chan core.KeyValueHolder, ctx adb.mutOp.Lock() defer adb.mutOp.Unlock() - return adb.mainTrie.GetAllLeavesOnChannel(leavesChannel, ctx, rootHash) + return adb.mainTrie.GetAllLeavesOnChannel(leavesChannel, ctx, rootHash, keyBuilder.NewKeyBuilder()) } // Close will handle the closing of the underlying components diff --git a/state/accountsDBApi.go b/state/accountsDBApi.go index e5b32a9a62e..d4d1985b514 100644 --- a/state/accountsDBApi.go +++ b/state/accountsDBApi.go @@ -75,6 +75,16 @@ func (accountsDB *accountsDBApi) doRecreateTrieWithBlockInfo(newBlockInfo common return newBlockInfo, nil } +// SetSyncer returns nil for this implementation +func (accountsDB *accountsDBApi) SetSyncer(_ AccountsDBSyncer) error { + return nil +} + +// StartSnapshotIfNeeded returns nil for this implementation +func (accountsDB *accountsDBApi) StartSnapshotIfNeeded() error { + return nil +} + // GetExistingAccount will call the inner accountsAdapter method after trying to recreate the trie func (accountsDB *accountsDBApi) GetExistingAccount(address []byte) (vmcommon.AccountHandler, error) { account, _, err := accountsDB.GetAccountWithBlockInfo(address, holders.NewRootHashHolderAsEmpty()) diff --git a/state/accountsDBApiWithHistory.go b/state/accountsDBApiWithHistory.go index df8125acc78..7d6bbe58eff 100644 --- a/state/accountsDBApiWithHistory.go +++ 
b/state/accountsDBApiWithHistory.go @@ -29,6 +29,16 @@ func NewAccountsDBApiWithHistory(innerAccountsAdapter AccountsAdapter) (*account }, nil } +// SetSyncer is not a permitted operation in this implementation and thus, does nothing +func (accountsDB *accountsDBApiWithHistory) SetSyncer(_ AccountsDBSyncer) error { + return nil +} + +// StartSnapshotIfNeeded is not a permitted operation in this implementation and thus, does nothing +func (accountsDB *accountsDBApiWithHistory) StartSnapshotIfNeeded() error { + return nil +} + // GetExistingAccount will return an error func (accountsDB *accountsDBApiWithHistory) GetExistingAccount(_ []byte) (vmcommon.AccountHandler, error) { return nil, ErrFunctionalityNotImplemented diff --git a/state/accountsDB_test.go b/state/accountsDB_test.go index 84058529f47..c075f4e5af5 100644 --- a/state/accountsDB_test.go +++ b/state/accountsDB_test.go @@ -17,6 +17,7 @@ import ( "github.com/ElrondNetwork/elrond-go-core/marshal" "github.com/ElrondNetwork/elrond-go/common" "github.com/ElrondNetwork/elrond-go/config" + "github.com/ElrondNetwork/elrond-go/process/mock" "github.com/ElrondNetwork/elrond-go/state" "github.com/ElrondNetwork/elrond-go/state/factory" "github.com/ElrondNetwork/elrond-go/state/storagePruningManager" @@ -305,7 +306,7 @@ func TestAccountsDB_SaveAccountSavesCodeAndDataTrieForUserAccount(t *testing.T) accCode := []byte("code") acc, _ := state.NewUserAccount([]byte("someAddress")) acc.SetCode(accCode) - _ = acc.DataTrieTracker().SaveKeyValue([]byte("key"), []byte("value")) + _ = acc.SaveKeyValue([]byte("key"), []byte("value")) err := adb.SaveAccount(acc) assert.Nil(t, err) @@ -769,7 +770,7 @@ func TestAccountsDB_LoadDataWithSomeValuesShouldWork(t *testing.T) { assert.Nil(t, err) // verify data - dataRecov, err := account.DataTrieTracker().RetrieveValue(keyRequired) + dataRecov, err := account.RetrieveValue(keyRequired) assert.Nil(t, err) assert.Equal(t, val, dataRecov) } @@ -807,6 +808,9 @@ func TestAccountsDB_CommitShouldCallCommitFromTrie(t *testing.T) { return nil }, + RootCalled: func() ([]byte, error) { + return nil, nil + }, }, nil }, GetStorageManagerCalled: func() common.StorageManager { @@ -817,7 +821,7 @@ func TestAccountsDB_CommitShouldCallCommitFromTrie(t *testing.T) { adb := generateAccountDBFromTrie(&trieStub) accnt, _ := adb.LoadAccount(make([]byte, 32)) - _ = accnt.(state.UserAccountHandler).DataTrieTracker().SaveKeyValue([]byte("dog"), []byte("puppy")) + _ = accnt.(state.UserAccountHandler).SaveKeyValue([]byte("dog"), []byte("puppy")) _ = adb.SaveAccount(accnt) _, err := adb.Commit() @@ -904,7 +908,7 @@ func TestAccountsDB_SnapshotState(t *testing.T) { trieStub := &trieMock.TrieStub{ GetStorageManagerCalled: func() common.StorageManager { return &testscommon.StorageManagerStub{ - TakeSnapshotCalled: func(_ []byte, _ []byte, _ chan core.KeyValueHolder, _ chan error, _ common.SnapshotStatisticsHandler, _ uint32) { + TakeSnapshotCalled: func(_ []byte, _ []byte, _ chan core.KeyValueHolder, _ chan []byte, _ chan error, _ common.SnapshotStatisticsHandler, _ uint32) { snapshotMut.Lock() takeSnapshotWasCalled = true snapshotMut.Unlock() @@ -933,7 +937,7 @@ func TestAccountsDB_SnapshotStateOnAClosedStorageManagerShouldNotMarkActiveDB(t ShouldTakeSnapshotCalled: func() bool { return true }, - TakeSnapshotCalled: func(_ []byte, _ []byte, ch chan core.KeyValueHolder, _ chan error, stats common.SnapshotStatisticsHandler, _ uint32) { + TakeSnapshotCalled: func(_ []byte, _ []byte, ch chan core.KeyValueHolder, _ chan []byte, _ chan error, stats
common.SnapshotStatisticsHandler, _ uint32) { close(ch) stats.SnapshotFinished() }, @@ -986,7 +990,7 @@ func TestAccountsDB_SnapshotStateWithErrorsShouldNotMarkActiveDB(t *testing.T) { ShouldTakeSnapshotCalled: func() bool { return true }, - TakeSnapshotCalled: func(_ []byte, _ []byte, ch chan core.KeyValueHolder, errChan chan error, stats common.SnapshotStatisticsHandler, _ uint32) { + TakeSnapshotCalled: func(_ []byte, _ []byte, ch chan core.KeyValueHolder, _ chan []byte, errChan chan error, stats common.SnapshotStatisticsHandler, _ uint32) { errChan <- expectedErr close(ch) stats.SnapshotFinished() @@ -1037,7 +1041,7 @@ func TestAccountsDB_SnapshotStateGetLatestStorageEpochErrDoesNotSnapshot(t *test GetLatestStorageEpochCalled: func() (uint32, error) { return 0, fmt.Errorf("new error") }, - TakeSnapshotCalled: func(_ []byte, _ []byte, _ chan core.KeyValueHolder, _ chan error, _ common.SnapshotStatisticsHandler, _ uint32) { + TakeSnapshotCalled: func(_ []byte, _ []byte, _ chan core.KeyValueHolder, _ chan []byte, _ chan error, _ common.SnapshotStatisticsHandler, _ uint32) { takeSnapshotCalled = true }, } @@ -1064,9 +1068,11 @@ func TestAccountsDB_SnapshotStateSnapshotSameRootHash(t *testing.T) { GetLatestStorageEpochCalled: func() (uint32, error) { return latestEpoch, nil }, - TakeSnapshotCalled: func(_ []byte, _ []byte, _ chan core.KeyValueHolder, _ chan error, _ common.SnapshotStatisticsHandler, _ uint32) { + TakeSnapshotCalled: func(_ []byte, _ []byte, leavesChan chan core.KeyValueHolder, _ chan []byte, _ chan error, stats common.SnapshotStatisticsHandler, _ uint32) { snapshotMutex.Lock() takeSnapshotCalled++ + close(leavesChan) + stats.SnapshotFinished() snapshotMutex.Unlock() }, } @@ -1129,6 +1135,41 @@ func TestAccountsDB_SnapshotStateSnapshotSameRootHash(t *testing.T) { snapshotMutex.Unlock() } +func TestAccountsDB_SnapshotStateSkipSnapshotIfSnapshotInProgress(t *testing.T) { + t.Parallel() + + rootHashes := [][]byte{[]byte("rootHash1"), []byte("rootHash2"), []byte("rootHash3"), []byte("rootHash4")} + latestEpoch := uint32(0) + snapshotMutex := sync.RWMutex{} + takeSnapshotCalled := 0 + trieStub := &trieMock.TrieStub{ + GetStorageManagerCalled: func() common.StorageManager { + return &testscommon.StorageManagerStub{ + GetLatestStorageEpochCalled: func() (uint32, error) { + return latestEpoch, nil + }, + TakeSnapshotCalled: func(_ []byte, _ []byte, leavesChan chan core.KeyValueHolder, _ chan []byte, _ chan error, stats common.SnapshotStatisticsHandler, _ uint32) { + snapshotMutex.Lock() + takeSnapshotCalled++ + close(leavesChan) + stats.SnapshotFinished() + snapshotMutex.Unlock() + }, + } + }, + } + adb := generateAccountDBFromTrie(trieStub) + waitForOpToFinish := time.Millisecond * 100 + + for _, rootHash := range rootHashes { + adb.SnapshotState(rootHash) + } + time.Sleep(waitForOpToFinish) + snapshotMutex.Lock() + assert.Equal(t, 1, takeSnapshotCalled) + snapshotMutex.Unlock() +} + func TestAccountsDB_SetStateCheckpointWithDataTries(t *testing.T) { t.Parallel() @@ -1168,7 +1209,7 @@ func TestAccountsDB_SetStateCheckpoint(t *testing.T) { trieStub := &trieMock.TrieStub{ GetStorageManagerCalled: func() common.StorageManager { return &testscommon.StorageManagerStub{ - SetCheckpointCalled: func(_ []byte, _ []byte, _ chan core.KeyValueHolder, _ chan error, _ common.SnapshotStatisticsHandler) { + SetCheckpointCalled: func(_ []byte, _ []byte, _ chan core.KeyValueHolder, _ chan []byte, _ chan error, _ common.SnapshotStatisticsHandler) { snapshotMut.Lock() setCheckPointWasCalled = true 
snapshotMut.Unlock() @@ -1260,14 +1301,14 @@ func TestAccountsDB_SaveAccountWithoutLoading(t *testing.T) { assert.Nil(t, err) userAcc := account.(state.UserAccountHandler) - err = userAcc.DataTrieTracker().SaveKeyValue(key, value) + err = userAcc.SaveKeyValue(key, value) assert.Nil(t, err) err = adb.SaveAccount(userAcc) assert.Nil(t, err) _, err = adb.Commit() assert.Nil(t, err) - err = userAcc.DataTrieTracker().SaveKeyValue(key1, value) + err = userAcc.SaveKeyValue(key1, value) assert.Nil(t, err) err = adb.SaveAccount(userAcc) assert.Nil(t, err) @@ -1278,11 +1319,11 @@ func TestAccountsDB_SaveAccountWithoutLoading(t *testing.T) { assert.Nil(t, err) userAcc = account.(state.UserAccountHandler) - returnedVal, err := userAcc.DataTrieTracker().RetrieveValue(key) + returnedVal, err := userAcc.RetrieveValue(key) assert.Nil(t, err) assert.Equal(t, value, returnedVal) - returnedVal, err = userAcc.DataTrieTracker().RetrieveValue(key1) + returnedVal, err = userAcc.RetrieveValue(key1) assert.Nil(t, err) assert.Equal(t, value, returnedVal) @@ -1310,7 +1351,7 @@ func TestAccountsDB_RecreateTrieInvalidatesJournalEntries(t *testing.T) { _ = adb.SaveAccount(acc) acc, _ = adb.LoadAccount(address) - _ = acc.(state.UserAccountHandler).DataTrieTracker().SaveKeyValue(key, value) + _ = acc.(state.UserAccountHandler).SaveKeyValue(key, value) _ = adb.SaveAccount(acc) assert.Equal(t, 5, adb.JournalLen()) @@ -1345,7 +1386,7 @@ func TestAccountsDB_GetAllLeaves(t *testing.T) { getAllLeavesCalled := false trieStub := &trieMock.TrieStub{ - GetAllLeavesOnChannelCalled: func(ch chan core.KeyValueHolder, ctx context.Context, rootHash []byte) error { + GetAllLeavesOnChannelCalled: func(ch chan core.KeyValueHolder, ctx context.Context, rootHash []byte, builder common.KeyBuilder) error { getAllLeavesCalled = true close(ch) @@ -1702,7 +1743,7 @@ func TestAccountsDB_RemoveAccountSetsObsoleteHashes(t *testing.T) { addr := make([]byte, 32) acc, _ := adb.LoadAccount(addr) userAcc := acc.(state.UserAccountHandler) - _ = userAcc.DataTrieTracker().SaveKeyValue([]byte("key"), []byte("value")) + _ = userAcc.SaveKeyValue([]byte("key"), []byte("value")) _ = adb.SaveAccount(userAcc) _, _ = adb.Commit() @@ -1711,7 +1752,7 @@ func TestAccountsDB_RemoveAccountSetsObsoleteHashes(t *testing.T) { userAcc = acc.(state.UserAccountHandler) userAcc.SetCode([]byte("code")) snapshot := adb.JournalLen() - hashes, _ := userAcc.DataTrieTracker().DataTrie().GetAllHashes() + hashes, _ := userAcc.DataTrie().(common.Trie).GetAllHashes() err := adb.RemoveAccount(addr) obsoleteHashes := adb.GetObsoleteHashes() @@ -1757,7 +1798,7 @@ func TestAccountsDB_RemoveAccountMarksObsoleteHashesForEviction(t *testing.T) { addr := make([]byte, 32) acc, _ := adb.LoadAccount(addr) userAcc := acc.(state.UserAccountHandler) - _ = userAcc.DataTrieTracker().SaveKeyValue([]byte("key"), []byte("value")) + _ = userAcc.SaveKeyValue([]byte("key"), []byte("value")) _ = adb.SaveAccount(userAcc) addr1 := make([]byte, 32) @@ -1766,7 +1807,7 @@ func TestAccountsDB_RemoveAccountMarksObsoleteHashesForEviction(t *testing.T) { _ = adb.SaveAccount(acc) rootHash, _ := adb.Commit() - hashes, _ := userAcc.DataTrieTracker().DataTrie().GetAllHashes() + hashes, _ := userAcc.DataTrie().(common.Trie).GetAllHashes() err := adb.RemoveAccount(addr) obsoleteHashes := adb.GetObsoleteHashes() @@ -2023,11 +2064,11 @@ func TestAccountsDB_SetStateCheckpointCommitsOnlyMissingData(t *testing.T) { allStateHashes = append(allStateHashes, mainTrieHashes...) 
acc, _ := adb.LoadAccount(accountsAddresses[0]) - dataTrie1Hashes, _ := acc.(state.UserAccountHandler).DataTrie().GetAllHashes() + dataTrie1Hashes, _ := acc.(state.UserAccountHandler).DataTrie().(common.Trie).GetAllHashes() allStateHashes = append(allStateHashes, dataTrie1Hashes...) acc, _ = adb.LoadAccount(accountsAddresses[1]) - dataTrie2Hashes, _ := acc.(state.UserAccountHandler).DataTrie().GetAllHashes() + dataTrie2Hashes, _ := acc.(state.UserAccountHandler).DataTrie().(common.Trie).GetAllHashes() allStateHashes = append(allStateHashes, dataTrie2Hashes...) for _, hash := range allStateHashes { @@ -2136,18 +2177,18 @@ func generateRandomByteArray(size int) []byte { func modifyDataTries(t *testing.T, accountsAddresses [][]byte, adb *state.AccountsDB) common.ModifiedHashes { acc, _ := adb.LoadAccount(accountsAddresses[0]) - err := acc.(state.UserAccountHandler).DataTrieTracker().SaveKeyValue([]byte("key1"), []byte("value1")) + err := acc.(state.UserAccountHandler).SaveKeyValue([]byte("key1"), []byte("value1")) assert.Nil(t, err) - err = acc.(state.UserAccountHandler).DataTrieTracker().SaveKeyValue([]byte("key2"), []byte("value2")) + err = acc.(state.UserAccountHandler).SaveKeyValue([]byte("key2"), []byte("value2")) assert.Nil(t, err) _ = adb.SaveAccount(acc) - newHashes, _ := acc.(state.UserAccountHandler).DataTrie().GetDirtyHashes() + newHashes, _ := acc.(state.UserAccountHandler).DataTrie().(common.Trie).GetDirtyHashes() acc, _ = adb.LoadAccount(accountsAddresses[1]) - err = acc.(state.UserAccountHandler).DataTrieTracker().SaveKeyValue([]byte("key2"), []byte("value2")) + err = acc.(state.UserAccountHandler).SaveKeyValue([]byte("key2"), []byte("value2")) assert.Nil(t, err) _ = adb.SaveAccount(acc) - newHashesDataTrie, _ := acc.(state.UserAccountHandler).DataTrie().GetDirtyHashes() + newHashesDataTrie, _ := acc.(state.UserAccountHandler).DataTrie().(common.Trie).GetDirtyHashes() mergeMaps(newHashes, newHashesDataTrie) return newHashes @@ -2359,7 +2400,7 @@ func TestAccountsDB_GetAccountFromBytesShouldLoadDataTrie(t *testing.T) { assert.Equal(t, dataTrie, account.DataTrie()) } -func TestAccountsDB_NewAccountsDbShouldSetActiveDB(t *testing.T) { +func TestAccountsDB_SetSyncerAndStartSnapshotIfNeeded(t *testing.T) { t.Parallel() rootHash := []byte("rootHash") @@ -2390,7 +2431,11 @@ func TestAccountsDB_NewAccountsDbShouldSetActiveDB(t *testing.T) { }, } - _ = generateAccountDBFromTrie(trieStub) + adb := generateAccountDBFromTrie(trieStub) + err := adb.SetSyncer(&mock.AccountsDBSyncerStub{}) + assert.Nil(t, err) + err = adb.StartSnapshotIfNeeded() + assert.Nil(t, err) assert.True(t, putCalled) }) @@ -2416,7 +2461,11 @@ func TestAccountsDB_NewAccountsDbShouldSetActiveDB(t *testing.T) { }, } - _ = generateAccountDBFromTrie(trieStub) + adb := generateAccountDBFromTrie(trieStub) + err := adb.SetSyncer(&mock.AccountsDBSyncerStub{}) + assert.Nil(t, err) + err = adb.StartSnapshotIfNeeded() + assert.Nil(t, err) }) t.Run("in import DB mode", func(t *testing.T) { putCalled := false @@ -2448,7 +2497,11 @@ func TestAccountsDB_NewAccountsDbShouldSetActiveDB(t *testing.T) { args.ProcessingMode = common.ImportDb args.Trie = trieStub - _, _ = state.NewAccountsDB(args) + adb, _ := state.NewAccountsDB(args) + err := adb.SetSyncer(&mock.AccountsDBSyncerStub{}) + assert.Nil(t, err) + err = adb.StartSnapshotIfNeeded() + assert.Nil(t, err) assert.True(t, putCalled) }) @@ -2474,7 +2527,7 @@ func TestAccountsDB_NewAccountsDbStartsSnapshotAfterRestart(t *testing.T) { ShouldTakeSnapshotCalled: func() bool { return true 
}, - TakeSnapshotCalled: func(_ []byte, _ []byte, _ chan core.KeyValueHolder, _ chan error, _ common.SnapshotStatisticsHandler, _ uint32) { + TakeSnapshotCalled: func(_ []byte, _ []byte, _ chan core.KeyValueHolder, _ chan []byte, _ chan error, _ common.SnapshotStatisticsHandler, _ uint32) { takeSnapshotCalled.SetValue(true) }, GetLatestStorageEpochCalled: func() (uint32, error) { @@ -2484,7 +2537,11 @@ }, } - _ = generateAccountDBFromTrie(trieStub) + adb := generateAccountDBFromTrie(trieStub) + err := adb.SetSyncer(&mock.AccountsDBSyncerStub{}) + assert.Nil(t, err) + err = adb.StartSnapshotIfNeeded() + assert.Nil(t, err) time.Sleep(time.Second) assert.True(t, takeSnapshotCalled.IsSet()) } diff --git a/state/baseAccount.go b/state/baseAccount.go index 3a0a376a893..383d288b2a9 100644 --- a/state/baseAccount.go +++ b/state/baseAccount.go @@ -25,7 +25,7 @@ func (ba *baseAccount) SetCode(code []byte) { } // DataTrie returns the trie that holds the current account's data -func (ba *baseAccount) DataTrie() common.Trie { +func (ba *baseAccount) DataTrie() common.DataTrieHandler { return ba.dataTrieTracker.DataTrie() } @@ -34,18 +34,31 @@ func (ba *baseAccount) SetDataTrie(trie common.Trie) { ba.dataTrieTracker.SetDataTrie(trie) } -// DataTrieTracker returns the trie wrapper used in managing the SC data -func (ba *baseAccount) DataTrieTracker() DataTrieTracker { - return ba.dataTrieTracker +// RetrieveValue fetches the value from a particular key searching the account data store in the data trie tracker +func (ba *baseAccount) RetrieveValue(key []byte) ([]byte, error) { + if check.IfNil(ba.dataTrieTracker) { + return nil, ErrNilTrackableDataTrie + } + + return ba.dataTrieTracker.RetrieveValue(key) } -// RetrieveValueFromDataTrieTracker fetches the value from a particular key searching the account data store in the data trie tracker -func (ba *baseAccount) RetrieveValueFromDataTrieTracker(key []byte) ([]byte, error) { +// SaveKeyValue adds the given key and value to the underlying trackable data trie +func (ba *baseAccount) SaveKeyValue(key []byte, value []byte) error { + if check.IfNil(ba.dataTrieTracker) { + return ErrNilTrackableDataTrie + } + + return ba.dataTrieTracker.SaveKeyValue(key, value) +} + +// SaveDirtyData triggers SaveDirtyData from the underlying trackableDataTrie +func (ba *baseAccount) SaveDirtyData(trie common.Trie) (map[string][]byte, error) { if check.IfNil(ba.dataTrieTracker) { return nil, ErrNilTrackableDataTrie } - return ba.dataTrieTracker.RetrieveValue(key) + return ba.dataTrieTracker.SaveDirtyData(trie) } // AccountDataHandler returns the account data handler diff --git a/state/baseAccount_test.go b/state/baseAccount_test.go index 8c8c0d8135c..410b12890b7 100644 --- a/state/baseAccount_test.go +++ b/state/baseAccount_test.go @@ -19,15 +19,6 @@ func TestBaseAccount_AddressContainer(t *testing.T) { assert.Equal(t, address, ba.AddressBytes()) } -func TestBaseAccount_DataTrieTracker(t *testing.T) { - t.Parallel() - - tracker := &trieMock.DataTrieTrackerStub{} - - ba := state.NewEmptyBaseAccount(nil, tracker) - assert.Equal(t, tracker, ba.DataTrieTracker()) -} - func TestBaseAccount_DataTrie(t *testing.T) { t.Parallel() diff --git a/state/errors.go b/state/errors.go index 2ee8ecb4f18..89cc3da65e0 100644 --- a/state/errors.go +++ b/state/errors.go @@ -152,3 +152,6 @@ var ErrNilBlockInfoProvider = errors.New("nil block info provider") // ErrFunctionalityNotImplemented signals that the functionality has
not been implemented yet var ErrFunctionalityNotImplemented = errors.New("functionality not implemented yet") + +// ErrNilTrieSyncer signals that the trie syncer is nil +var ErrNilTrieSyncer = errors.New("trie syncer is nil") diff --git a/state/interface.go b/state/interface.go index 54d4e9d9528..7d0ccbc2d22 100644 --- a/state/interface.go +++ b/state/interface.go @@ -72,9 +72,9 @@ type UserAccountHandler interface { SetRootHash([]byte) GetRootHash() []byte SetDataTrie(trie common.Trie) - DataTrie() common.Trie - DataTrieTracker() DataTrieTracker - RetrieveValueFromDataTrieTracker(key []byte) ([]byte, error) + DataTrie() common.DataTrieHandler + RetrieveValue(key []byte) ([]byte, error) + SaveKeyValue(key []byte, value []byte) error AddToBalance(value *big.Int) error SubFromBalance(value *big.Int) error GetBalance() *big.Int @@ -91,12 +91,11 @@ type UserAccountHandler interface { // DataTrieTracker models what how to manipulate data held by a SC account type DataTrieTracker interface { - ClearDataCaches() - DirtyData() map[string][]byte RetrieveValue(key []byte) ([]byte, error) SaveKeyValue(key []byte, value []byte) error SetDataTrie(tr common.Trie) - DataTrie() common.Trie + DataTrie() common.DataTrieHandler + SaveDirtyData(common.Trie) (map[string][]byte, error) IsInterfaceNil() bool } @@ -125,10 +124,18 @@ type AccountsAdapter interface { RecreateAllTries(rootHash []byte) (map[string]common.Trie, error) GetTrie(rootHash []byte) (common.Trie, error) GetStackDebugFirstEntry() []byte + SetSyncer(syncer AccountsDBSyncer) error + StartSnapshotIfNeeded() error Close() error IsInterfaceNil() bool } +// AccountsDBSyncer defines the methods for the accounts db syncer +type AccountsDBSyncer interface { + SyncAccounts(rootHash []byte) error + IsInterfaceNil() bool +} + // AccountsRepository handles the defined execution based on the query options type AccountsRepository interface { GetAccountWithBlockInfo(address []byte, options api.AccountQueryOptions) (vmcommon.AccountHandler, common.BlockInfo, error) @@ -157,8 +164,8 @@ type baseAccountHandler interface { SetRootHash([]byte) GetRootHash() []byte SetDataTrie(trie common.Trie) - DataTrie() common.Trie - DataTrieTracker() DataTrieTracker + DataTrie() common.DataTrieHandler + SaveDirtyData(trie common.Trie) (map[string][]byte, error) IsInterfaceNil() bool } diff --git a/state/journalEntries.go b/state/journalEntries.go index 586f72f863b..855f5daa1c6 100644 --- a/state/journalEntries.go +++ b/state/journalEntries.go @@ -6,6 +6,7 @@ import ( "github.com/ElrondNetwork/elrond-go-core/core/check" "github.com/ElrondNetwork/elrond-go-core/marshal" + "github.com/ElrondNetwork/elrond-go/common" vmcommon "github.com/ElrondNetwork/elrond-vm-common" ) @@ -187,8 +188,13 @@ func NewJournalEntryDataTrieUpdates(trieUpdates map[string][]byte, account baseA // Revert applies undo operation func (jedtu *journalEntryDataTrieUpdates) Revert() (vmcommon.AccountHandler, error) { + trie, ok := jedtu.account.DataTrie().(common.Trie) + if !ok { + return nil, fmt.Errorf("invalid trie, type is %T", jedtu.account.DataTrie()) + } + for key := range jedtu.trieUpdates { - err := jedtu.account.DataTrie().Update([]byte(key), jedtu.trieUpdates[key]) + err := trie.Update([]byte(key), jedtu.trieUpdates[key]) if err != nil { return nil, err } @@ -196,7 +202,7 @@ func (jedtu *journalEntryDataTrieUpdates) Revert() (vmcommon.AccountHandler, err log.Trace("revert data trie update", "key", []byte(key), "val", jedtu.trieUpdates[key]) } - rootHash, err := jedtu.account.DataTrie().RootHash() + 
rootHash, err := trie.RootHash() if err != nil { return nil, err } diff --git a/state/peerAccountsDB.go b/state/peerAccountsDB.go index c90e0053097..4bbc51b86a9 100644 --- a/state/peerAccountsDB.go +++ b/state/peerAccountsDB.go @@ -1,10 +1,6 @@ package state import ( - "bytes" - "fmt" - "sync" - "github.com/ElrondNetwork/elrond-go/common" ) @@ -21,33 +17,17 @@ func NewPeerAccountsDB(args ArgsAccountsDB) (*PeerAccountsDB, error) { } adb := &PeerAccountsDB{ - &AccountsDB{ - mainTrie: args.Trie, - hasher: args.Hasher, - marshaller: args.Marshaller, - accountFactory: args.AccountFactory, - entries: make([]JournalEntry, 0), - dataTries: NewDataTriesHolder(), - mutOp: sync.RWMutex{}, - loadCodeMeasurements: &loadingMeasurements{ - identifier: "load code", - }, - storagePruningManager: args.StoragePruningManager, - processingMode: args.ProcessingMode, - lastSnapshot: &snapshotInfo{}, - processStatusHandler: args.ProcessStatusHandler, - }, - } - - trieStorageManager := adb.mainTrie.GetStorageManager() - val, err := trieStorageManager.GetFromCurrentEpoch([]byte(common.ActiveDBKey)) - if err != nil || !bytes.Equal(val, []byte(common.ActiveDBVal)) { - startSnapshotAfterRestart(adb, args) + AccountsDB: createAccountsDb(args), } return adb, nil } +// StartSnapshotIfNeeded starts the snapshot if the previous snapshot process was not fully completed +func (adb *PeerAccountsDB) StartSnapshotIfNeeded() error { + return startSnapshotIfNeeded(adb, adb.trieSyncer, adb.mainTrie.GetStorageManager(), adb.processingMode) +} + // MarkSnapshotDone will mark that the snapshot process has been completed func (adb *PeerAccountsDB) MarkSnapshotDone() { trieStorageManager, epoch, err := adb.getTrieStorageManagerAndLatestEpoch() @@ -60,48 +40,26 @@ func (adb *PeerAccountsDB) MarkSnapshotDone() { handleLoggingWhenError("error while putting active DB value into main storer", err) } -func (adb *PeerAccountsDB) getTrieStorageManagerAndLatestEpoch() (common.StorageManager, uint32, error) { - trieStorageManager := adb.mainTrie.GetStorageManager() - epoch, err := trieStorageManager.GetLatestStorageEpoch() - if err != nil { - return nil, 0, fmt.Errorf("%w while getting the latest storage epoch", err) - } - - return trieStorageManager, epoch, nil -} - // SnapshotState triggers the snapshotting process of the state trie func (adb *PeerAccountsDB) SnapshotState(rootHash []byte) { - log.Trace("peerAccountsDB.SnapshotState", "root hash", rootHash) - trieStorageManager, epoch, err := adb.getTrieStorageManagerAndLatestEpoch() - if err != nil { - log.Error("SnapshotState error", "err", err.Error()) - return - } + adb.mutOp.Lock() + defer adb.mutOp.Unlock() - if !trieStorageManager.ShouldTakeSnapshot() { - log.Debug("skipping snapshot for rootHash", "hash", rootHash) + trieStorageManager, epoch, shouldTakeSnapshot := adb.prepareSnapshot(rootHash) + if !shouldTakeSnapshot { return } - log.Info("starting snapshot", "rootHash", rootHash, "epoch", epoch) - - adb.lastSnapshot.rootHash = rootHash - adb.lastSnapshot.epoch = epoch - err = trieStorageManager.Put([]byte(lastSnapshotStarted), rootHash) - if err != nil { - log.Warn("could not set lastSnapshotStarted", "err", err, "rootHash", rootHash) - } - - stats := newSnapshotStatistics(0) - - trieStorageManager.EnterPruningBufferingMode() - stats.NewSnapshotStarted() + log.Info("starting snapshot peer trie", "rootHash", rootHash, "epoch", epoch) + missingNodesChannel := make(chan []byte, missingNodesChannelSize) errChan := make(chan error, 1) - trieStorageManager.TakeSnapshot(rootHash, rootHash, 
nil, errChan, stats, epoch) - trieStorageManager.ExitPruningBufferingMode() + stats := newSnapshotStatistics(0, 1) + stats.NewSnapshotStarted() + trieStorageManager.TakeSnapshot(rootHash, rootHash, nil, missingNodesChannel, errChan, stats, epoch) - go adb.markActiveDBAfterSnapshot(stats, errChan, rootHash, "snapshotState peer trie", epoch) + go adb.syncMissingNodes(missingNodesChannel, stats) + + go adb.processSnapshotCompletion(stats, missingNodesChannel, errChan, rootHash, "snapshotState peer trie", epoch) adb.waitForCompletionIfAppropriate(stats) } @@ -111,32 +69,26 @@ func (adb *PeerAccountsDB) SetStateCheckpoint(rootHash []byte) { log.Trace("peerAccountsDB.SetStateCheckpoint", "root hash", rootHash) trieStorageManager := adb.mainTrie.GetStorageManager() - stats := newSnapshotStatistics(0) + missingNodesChannel := make(chan []byte, missingNodesChannelSize) + stats := newSnapshotStatistics(0, 1) trieStorageManager.EnterPruningBufferingMode() stats.NewSnapshotStarted() errChan := make(chan error, 1) - trieStorageManager.SetCheckpoint(rootHash, rootHash, nil, errChan, stats) - trieStorageManager.ExitPruningBufferingMode() + trieStorageManager.SetCheckpoint(rootHash, rootHash, nil, missingNodesChannel, errChan, stats) + + go adb.syncMissingNodes(missingNodesChannel, stats) // TODO decide if we need to take some actions whenever we hit an error that occurred in the checkpoint process // that will be present in the errChan var - go stats.PrintStats("setStateCheckpoint peer trie", rootHash) + go adb.finishSnapshotOperation(rootHash, stats, missingNodesChannel, "setStateCheckpoint peer trie") adb.waitForCompletionIfAppropriate(stats) } // RecreateAllTries recreates all the tries from the accounts DB func (adb *PeerAccountsDB) RecreateAllTries(rootHash []byte) (map[string]common.Trie, error) { - recreatedTrie, err := adb.mainTrie.Recreate(rootHash) - if err != nil { - return nil, err - } - - allTries := make(map[string]common.Trie) - allTries[string(rootHash)] = recreatedTrie - - return allTries, nil + return adb.recreateMainTrie(rootHash) } // IsInterfaceNil returns true if there is no value under the interface diff --git a/state/peerAccountsDB_test.go b/state/peerAccountsDB_test.go index 765ea67283d..ad2bfe06715 100644 --- a/state/peerAccountsDB_test.go +++ b/state/peerAccountsDB_test.go @@ -11,6 +11,7 @@ import ( "github.com/ElrondNetwork/elrond-go-core/core" "github.com/ElrondNetwork/elrond-go-core/core/check" "github.com/ElrondNetwork/elrond-go/common" + "github.com/ElrondNetwork/elrond-go/process/mock" "github.com/ElrondNetwork/elrond-go/state" "github.com/ElrondNetwork/elrond-go/testscommon" trieMock "github.com/ElrondNetwork/elrond-go/testscommon/trie" @@ -99,7 +100,7 @@ func TestNewPeerAccountsDB_SnapshotState(t *testing.T) { args.Trie = &trieMock.TrieStub{ GetStorageManagerCalled: func() common.StorageManager { return &testscommon.StorageManagerStub{ - TakeSnapshotCalled: func(_ []byte, _ []byte, _ chan core.KeyValueHolder, _ chan error, _ common.SnapshotStatisticsHandler, _ uint32) { + TakeSnapshotCalled: func(_ []byte, _ []byte, _ chan core.KeyValueHolder, _ chan []byte, _ chan error, _ common.SnapshotStatisticsHandler, _ uint32) { snapshotCalled = true }, } @@ -125,7 +126,7 @@ func TestNewPeerAccountsDB_SnapshotStateGetLatestStorageEpochErrDoesNotSnapshot( GetLatestStorageEpochCalled: func() (uint32, error) { return 0, fmt.Errorf("new error") }, - TakeSnapshotCalled: func(_ []byte, _ []byte, _ chan core.KeyValueHolder, _ chan error, _ common.SnapshotStatisticsHandler, _ uint32) { + 
TakeSnapshotCalled: func(_ []byte, _ []byte, _ chan core.KeyValueHolder, _ chan []byte, _ chan error, _ common.SnapshotStatisticsHandler, _ uint32) { snapshotCalled = true }, } @@ -147,7 +148,7 @@ func TestNewPeerAccountsDB_SetStateCheckpoint(t *testing.T) { args.Trie = &trieMock.TrieStub{ GetStorageManagerCalled: func() common.StorageManager { return &testscommon.StorageManagerStub{ - SetCheckpointCalled: func(_ []byte, _ []byte, _ chan core.KeyValueHolder, _ chan error, _ common.SnapshotStatisticsHandler) { + SetCheckpointCalled: func(_ []byte, _ []byte, _ chan core.KeyValueHolder, _ chan []byte, _ chan error, _ common.SnapshotStatisticsHandler) { checkpointCalled = true }, } @@ -187,7 +188,7 @@ func TestNewPeerAccountsDB_RecreateAllTries(t *testing.T) { assert.True(t, recreateCalled) } -func TestPeerAccountsDB_NewAccountsDbStartsSnapshotAfterRestart(t *testing.T) { +func TestPeerAccountsDB_SetSyncerAndStartSnapshotIfNeeded(t *testing.T) { t.Parallel() rootHash := []byte("rootHash") @@ -208,7 +209,7 @@ func TestPeerAccountsDB_NewAccountsDbStartsSnapshotAfterRestart(t *testing.T) { ShouldTakeSnapshotCalled: func() bool { return true }, - TakeSnapshotCalled: func(_ []byte, _ []byte, _ chan core.KeyValueHolder, _ chan error, _ common.SnapshotStatisticsHandler, _ uint32) { + TakeSnapshotCalled: func(_ []byte, _ []byte, _ chan core.KeyValueHolder, _ chan []byte, _ chan error, _ common.SnapshotStatisticsHandler, _ uint32) { mutex.Lock() takeSnapshotCalled = true mutex.Unlock() @@ -225,6 +226,10 @@ func TestPeerAccountsDB_NewAccountsDbStartsSnapshotAfterRestart(t *testing.T) { adb, err := state.NewPeerAccountsDB(args) assert.Nil(t, err) assert.NotNil(t, adb) + err = adb.SetSyncer(&mock.AccountsDBSyncerStub{}) + assert.Nil(t, err) + err = adb.StartSnapshotIfNeeded() + assert.Nil(t, err) time.Sleep(time.Second) mutex.RLock() @@ -321,7 +326,7 @@ func TestPeerAccountsDB_MarkSnapshotDone(t *testing.T) { } -func TestPeerAccountsDB_NewAccountsDbShouldSetActiveDB(t *testing.T) { +func TestPeerAccountsDB_SetSyncerAndStartSnapshotIfNeededMarksActiveDB(t *testing.T) { t.Parallel() rootHash := []byte("rootHash") @@ -354,7 +359,11 @@ func TestPeerAccountsDB_NewAccountsDbShouldSetActiveDB(t *testing.T) { args := createMockAccountsDBArgs() args.Trie = trieStub - _, _ = state.NewPeerAccountsDB(args) + adb, _ := state.NewPeerAccountsDB(args) + err := adb.SetSyncer(&mock.AccountsDBSyncerStub{}) + assert.Nil(t, err) + err = adb.StartSnapshotIfNeeded() + assert.Nil(t, err) assert.True(t, putCalled) }) @@ -382,7 +391,11 @@ func TestPeerAccountsDB_NewAccountsDbShouldSetActiveDB(t *testing.T) { args := createMockAccountsDBArgs() args.Trie = trieStub - _, _ = state.NewPeerAccountsDB(args) + adb, _ := state.NewPeerAccountsDB(args) + err := adb.SetSyncer(&mock.AccountsDBSyncerStub{}) + assert.Nil(t, err) + err = adb.StartSnapshotIfNeeded() + assert.Nil(t, err) }) t.Run("in import DB mode", func(t *testing.T) { putCalled := false @@ -413,7 +426,11 @@ func TestPeerAccountsDB_NewAccountsDbShouldSetActiveDB(t *testing.T) { args := createMockAccountsDBArgs() args.ProcessingMode = common.ImportDb args.Trie = trieStub - _, _ = state.NewPeerAccountsDB(args) + adb, _ := state.NewPeerAccountsDB(args) + err := adb.SetSyncer(&mock.AccountsDBSyncerStub{}) + assert.Nil(t, err) + err = adb.StartSnapshotIfNeeded() + assert.Nil(t, err) assert.True(t, putCalled) }) @@ -431,7 +448,7 @@ func TestPeerAccountsDB_SnapshotStateOnAClosedStorageManagerShouldNotMarkActiveD ShouldTakeSnapshotCalled: func() bool { return true }, - 
TakeSnapshotCalled: func(_ []byte, _ []byte, ch chan core.KeyValueHolder, _ chan error, stats common.SnapshotStatisticsHandler, _ uint32) { + TakeSnapshotCalled: func(_ []byte, _ []byte, ch chan core.KeyValueHolder, _ chan []byte, _ chan error, stats common.SnapshotStatisticsHandler, _ uint32) { stats.SnapshotFinished() }, IsClosedCalled: func() bool { diff --git a/state/snapshotStatistics.go b/state/snapshotStatistics.go index 5cf05decf19..8bc57a8e9bf 100644 --- a/state/snapshotStatistics.go +++ b/state/snapshotStatistics.go @@ -13,16 +13,21 @@ type snapshotStatistics struct { trieSize uint64 startTime time.Time - wg *sync.WaitGroup - mutex sync.RWMutex + wgSnapshot *sync.WaitGroup + wgSync *sync.WaitGroup + mutex sync.RWMutex } -func newSnapshotStatistics(delta int) *snapshotStatistics { - wg := &sync.WaitGroup{} - wg.Add(delta) +func newSnapshotStatistics(snapshotDelta int, syncDelta int) *snapshotStatistics { + wgSnapshot := &sync.WaitGroup{} + wgSnapshot.Add(snapshotDelta) + + wgSync := &sync.WaitGroup{} + wgSync.Add(syncDelta) return &snapshotStatistics{ - wg: wg, - startTime: time.Now(), + wgSnapshot: wgSnapshot, + wgSync: wgSync, + startTime: time.Now(), } } @@ -37,12 +42,12 @@ func (ss *snapshotStatistics) AddSize(size uint64) { // SnapshotFinished marks the ending of a snapshot goroutine func (ss *snapshotStatistics) SnapshotFinished() { - ss.wg.Done() + ss.wgSnapshot.Done() } // NewSnapshotStarted marks the starting of a new snapshot goroutine func (ss *snapshotStatistics) NewSnapshotStarted() { - ss.wg.Add(1) + ss.wgSnapshot.Add(1) } // NewDataTrie increases the data Tries counter @@ -55,13 +60,21 @@ func (ss *snapshotStatistics) NewDataTrie() { // WaitForSnapshotsToFinish will wait until the waitGroup counter is zero func (ss *snapshotStatistics) WaitForSnapshotsToFinish() { - ss.wg.Wait() + ss.wgSnapshot.Wait() +} + +// WaitForSyncToFinish will wait until the waitGroup counter is zero +func (ss *snapshotStatistics) WaitForSyncToFinish() { + ss.wgSync.Wait() +} + +// SyncFinished marks the end of the sync process +func (ss *snapshotStatistics) SyncFinished() { + ss.wgSync.Done() } // PrintStats will print the stats after the snapshot has finished func (ss *snapshotStatistics) PrintStats(identifier string, rootHash []byte) { - ss.wg.Wait() - ss.mutex.RLock() defer ss.mutex.RUnlock() diff --git a/state/snapshotStatistics_test.go b/state/snapshotStatistics_test.go index 60b2c882dfa..7f5d9293f43 100644 --- a/state/snapshotStatistics_test.go +++ b/state/snapshotStatistics_test.go @@ -22,7 +22,7 @@ func TestSnapshotStatistics_AddSize(t *testing.T) { func TestSnapshotStatistics_Concurrency(t *testing.T) { wg := &sync.WaitGroup{} ss := &snapshotStatistics{ - wg: wg, + wgSnapshot: wg, } numRuns := 100 diff --git a/state/storagePruningManager/evictionWaitingList/evictionWaitingList_test.go b/state/storagePruningManager/evictionWaitingList/evictionWaitingList_test.go index 1da59ba77e6..16d54998c0a 100644 --- a/state/storagePruningManager/evictionWaitingList/evictionWaitingList_test.go +++ b/state/storagePruningManager/evictionWaitingList/evictionWaitingList_test.go @@ -9,13 +9,13 @@ import ( "github.com/ElrondNetwork/elrond-go/common" "github.com/ElrondNetwork/elrond-go/state" "github.com/ElrondNetwork/elrond-go/storage" - "github.com/ElrondNetwork/elrond-go/storage/memorydb" + "github.com/ElrondNetwork/elrond-go/storage/database" "github.com/ElrondNetwork/elrond-go/testscommon" "github.com/stretchr/testify/assert" ) func getDefaultParameters() (uint, storage.Persister, 
marshal.Marshalizer) { - return 10, memorydb.New(), &testscommon.MarshalizerMock{} + return 10, database.NewMemDB(), &testscommon.MarshalizerMock{} } func TestNewEvictionWaitingList(t *testing.T) { @@ -297,8 +297,8 @@ func TestEvictionWaitingList_ShouldKeepHashInvalidKey(t *testing.T) { func TestNewEvictionWaitingList_Close(t *testing.T) { t.Parallel() - db := memorydb.New() - ewl, err := NewEvictionWaitingList(10, db, &testscommon.MarshalizerMock{}) + memDB := database.NewMemDB() + ewl, err := NewEvictionWaitingList(10, memDB, &testscommon.MarshalizerMock{}) assert.Nil(t, err) assert.NotNil(t, ewl) diff --git a/state/syncer/baseAccountsSyncer.go b/state/syncer/baseAccountsSyncer.go index 4969686d315..c581a345a11 100644 --- a/state/syncer/baseAccountsSyncer.go +++ b/state/syncer/baseAccountsSyncer.go @@ -32,6 +32,7 @@ type baseAccountsSyncer struct { name string maxHardCapForMissingNodes int checkNodesOnDisk bool + storageMarker trie.StorageMarker trieSyncerVersion int numTriesSynced int32 @@ -45,6 +46,7 @@ type ArgsNewBaseAccountsSyncer struct { Hasher hashing.Hasher Marshalizer marshal.Marshalizer TrieStorageManager common.StorageManager + StorageMarker trie.StorageMarker RequestHandler trie.RequestHandler Timeout time.Duration Cacher storage.Cacher diff --git a/state/syncer/userAccountsSyncer.go b/state/syncer/userAccountsSyncer.go index 0b5840533f3..bf17014291a 100644 --- a/state/syncer/userAccountsSyncer.go +++ b/state/syncer/userAccountsSyncer.go @@ -17,6 +17,7 @@ import ( "github.com/ElrondNetwork/elrond-go/process/factory" "github.com/ElrondNetwork/elrond-go/state" "github.com/ElrondNetwork/elrond-go/trie" + "github.com/ElrondNetwork/elrond-go/trie/keyBuilder" "github.com/ElrondNetwork/elrond-go/trie/statistics" ) @@ -88,6 +89,7 @@ func NewUserAccountsSyncer(args ArgsNewUserAccountsSyncer) (*userAccountsSyncer, maxHardCapForMissingNodes: args.MaxHardCapForMissingNodes, trieSyncerVersion: args.TrieSyncerVersion, checkNodesOnDisk: args.CheckNodesOnDisk, + storageMarker: args.StorageMarker, } u := &userAccountsSyncer{ @@ -132,7 +134,7 @@ func (u *userAccountsSyncer) SyncAccounts(rootHash []byte) error { return err } - mainTrie.MarkStorerAsSyncedAndActive() + u.storageMarker.MarkStorerAsSyncedAndActive(mainTrie.GetStorageManager()) return nil } @@ -211,7 +213,7 @@ func (u *userAccountsSyncer) syncAccountDataTries( } leavesChannel := make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity) - err = mainTrie.GetAllLeavesOnChannel(leavesChannel, context.Background(), mainRootHash) + err = mainTrie.GetAllLeavesOnChannel(leavesChannel, context.Background(), mainRootHash, keyBuilder.NewDisabledKeyBuilder()) if err != nil { return err } diff --git a/state/syncer/validatorAccountsSyncer.go b/state/syncer/validatorAccountsSyncer.go index c9d28b71aae..296762ef210 100644 --- a/state/syncer/validatorAccountsSyncer.go +++ b/state/syncer/validatorAccountsSyncer.go @@ -48,6 +48,7 @@ func NewValidatorAccountsSyncer(args ArgsNewValidatorAccountsSyncer) (*validator maxHardCapForMissingNodes: args.MaxHardCapForMissingNodes, trieSyncerVersion: args.TrieSyncerVersion, checkNodesOnDisk: args.CheckNodesOnDisk, + storageMarker: args.StorageMarker, } u := &validatorAccountsSyncer{ @@ -78,7 +79,7 @@ func (v *validatorAccountsSyncer) SyncAccounts(rootHash []byte) error { return err } - mainTrie.MarkStorerAsSyncedAndActive() + v.storageMarker.MarkStorerAsSyncedAndActive(mainTrie.GetStorageManager()) return nil } diff --git a/state/trackableDataTrie.go b/state/trackableDataTrie.go index 
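Both account syncers now receive a trie.StorageMarker and call it with the trie's storage manager instead of invoking MarkStorerAsSyncedAndActive on the trie itself, making the marking policy injectable (and mockable in tests). A hedged sketch of the shape implied by these hunks; the real interface lives in the trie package and may differ in detail:

// Assumed from usage above, not copied from the repo:
type StorageMarker interface {
	MarkStorerAsSyncedAndActive(storer common.StorageManager)
}

// The call-site pattern both syncers now share:
//   syncer.storageMarker.MarkStorerAsSyncedAndActive(mainTrie.GetStorageManager())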
233ca4f487f..7ab077b3a9c 100644 --- a/state/trackableDataTrie.go +++ b/state/trackableDataTrie.go @@ -2,40 +2,31 @@ package state import ( "github.com/ElrondNetwork/elrond-go-core/core" + "github.com/ElrondNetwork/elrond-go-core/core/check" "github.com/ElrondNetwork/elrond-go-core/data" "github.com/ElrondNetwork/elrond-go/common" ) // TrackableDataTrie wraps a PatriciaMerkelTrie adding modifying data capabilities -type TrackableDataTrie struct { +type trackableDataTrie struct { dirtyData map[string][]byte tr common.Trie identifier []byte } -// NewTrackableDataTrie returns an instance of DataTrieTracker -func NewTrackableDataTrie(identifier []byte, tr common.Trie) *TrackableDataTrie { - return &TrackableDataTrie{ +// NewTrackableDataTrie returns an instance of trackableDataTrie +func NewTrackableDataTrie(identifier []byte, tr common.Trie) *trackableDataTrie { + return &trackableDataTrie{ tr: tr, dirtyData: make(map[string][]byte), identifier: identifier, } } -// ClearDataCaches empties the dirtyData map and original map -func (tdaw *TrackableDataTrie) ClearDataCaches() { - tdaw.dirtyData = make(map[string][]byte) -} - -// DirtyData returns the map of (key, value) pairs that contain the data needed to be saved in the data trie -func (tdaw *TrackableDataTrie) DirtyData() map[string][]byte { - return tdaw.dirtyData -} - // RetrieveValue fetches the value from a particular key searching the account data store // The search starts with dirty map, continues with original map and ends with the trie // Data must have been retrieved from its trie -func (tdaw *TrackableDataTrie) RetrieveValue(key []byte) ([]byte, error) { +func (tdaw *trackableDataTrie) RetrieveValue(key []byte) ([]byte, error) { tailLength := len(key) + len(tdaw.identifier) // search in dirty data cache @@ -69,7 +60,7 @@ func trimValue(value []byte, tailLength int) ([]byte, error) { // SaveKeyValue stores in dirtyData the data keys "touched" // It does not care if the data is really dirty as calling this check here will be sub-optimal -func (tdaw *TrackableDataTrie) SaveKeyValue(key []byte, value []byte) error { +func (tdaw *trackableDataTrie) SaveKeyValue(key []byte, value []byte) error { var identifier []byte lenValue := uint64(len(value)) if lenValue > core.MaxLeafSize { @@ -85,16 +76,51 @@ func (tdaw *TrackableDataTrie) SaveKeyValue(key []byte, value []byte) error { } // SetDataTrie sets the internal data trie -func (tdaw *TrackableDataTrie) SetDataTrie(tr common.Trie) { +func (tdaw *trackableDataTrie) SetDataTrie(tr common.Trie) { tdaw.tr = tr } // DataTrie sets the internal data trie -func (tdaw *TrackableDataTrie) DataTrie() common.Trie { +func (tdaw *trackableDataTrie) DataTrie() common.DataTrieHandler { return tdaw.tr } +// SaveDirtyData saved the dirty data to the trie +func (tdaw *trackableDataTrie) SaveDirtyData(mainTrie common.Trie) (map[string][]byte, error) { + if len(tdaw.dirtyData) == 0 { + return map[string][]byte{}, nil + } + + if check.IfNil(tdaw.tr) { + newDataTrie, err := mainTrie.Recreate(make([]byte, 0)) + if err != nil { + return nil, err + } + + tdaw.tr = newDataTrie + } + + oldValues := make(map[string][]byte) + + for k, v := range tdaw.dirtyData { + val, err := tdaw.tr.Get([]byte(k)) + if err != nil { + return oldValues, err + } + + oldValues[k] = val + + err = tdaw.tr.Update([]byte(k), v) + if err != nil { + return oldValues, err + } + } + + tdaw.dirtyData = make(map[string][]byte) + return oldValues, nil +} + // IsInterfaceNil returns true if there is no value under the interface -func (tdaw 
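SaveDirtyData is the replacement for the removed ClearDataCaches/DirtyData pair: it lazily recreates the data trie from the main trie when none is set, writes every dirty (key, value) pair, returns the previous values, and resets the dirty cache. A sketch of a caller, illustrative only since the account-handler integration is outside this hunk:

oldValues, err := tdt.SaveDirtyData(mainTrie)
if err != nil {
	// oldValues contains the pre-update values written so far; a caller
	// could use them to revert the partially updated data trie
	return err
}
// on success the dirty cache is empty and the trie holds the new values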
*TrackableDataTrie) IsInterfaceNil() bool { +func (tdaw *trackableDataTrie) IsInterfaceNil() bool { return tdaw == nil } diff --git a/state/trackableDataTrie_test.go b/state/trackableDataTrie_test.go index 2ded76c2b39..28cf5869d53 100644 --- a/state/trackableDataTrie_test.go +++ b/state/trackableDataTrie_test.go @@ -33,29 +33,6 @@ func TestTrackableDataTrie_RetrieveValueNilDataTrieShouldErr(t *testing.T) { assert.NotNil(t, err) } -func TestTrackableDataTrie_RetrieveValueFoundInDirtyShouldWork(t *testing.T) { - t.Parallel() - - stringKey := "ABC" - identifier := []byte("identifier") - trie := &trieMock.TrieStub{} - tdaw := state.NewTrackableDataTrie(identifier, trie) - assert.NotNil(t, tdaw) - - tdaw.SetDataTrie(&trieMock.TrieStub{}) - key := []byte(stringKey) - val := []byte("123") - - trieVal := append(val, key...) - trieVal = append(trieVal, identifier...) - - tdaw.DirtyData()[stringKey] = trieVal - - retrievedVal, err := tdaw.RetrieveValue(key) - assert.Nil(t, err) - assert.Equal(t, val, retrievedVal) -} - func TestTrackableDataTrie_RetrieveValueFoundInTrieShouldWork(t *testing.T) { t.Parallel() @@ -141,9 +118,6 @@ func TestTrackableDataTrie_SaveKeyValueShouldSaveOnlyInDirty(t *testing.T) { keyExpected := []byte("key") value := []byte("value") - expectedVal := append(value, keyExpected...) - expectedVal = append(expectedVal, identifier...) - trie := &trieMock.TrieStub{ UpdateCalled: func(key, value []byte) error { return nil @@ -159,27 +133,9 @@ func TestTrackableDataTrie_SaveKeyValueShouldSaveOnlyInDirty(t *testing.T) { _ = mdaw.SaveKeyValue(keyExpected, value) // test in dirty - assert.Equal(t, expectedVal, mdaw.DirtyData()[string(keyExpected)]) -} - -func TestTrackableDataTrie_ClearDataCachesValidDataShouldWork(t *testing.T) { - t.Parallel() - - trie := &trieMock.TrieStub{} - mdaw := state.NewTrackableDataTrie([]byte("identifier"), trie) - assert.NotNil(t, mdaw) - - mdaw.SetDataTrie(&trieMock.TrieStub{}) - - assert.Equal(t, 0, len(mdaw.DirtyData())) - - // add something - _ = mdaw.SaveKeyValue([]byte("ABC"), []byte("123")) - assert.Equal(t, 1, len(mdaw.DirtyData())) - - // clear - mdaw.ClearDataCaches() - assert.Equal(t, 0, len(mdaw.DirtyData())) + retrievedVal, err := mdaw.RetrieveValue(keyExpected) + assert.Nil(t, err) + assert.Equal(t, value, retrievedVal) } func TestTrackableDataTrie_SetAndGetDataTrie(t *testing.T) { diff --git a/statusHandler/persister/persistentHandler.go b/statusHandler/persister/persistentHandler.go index 31777847f0b..22045c92dde 100644 --- a/statusHandler/persister/persistentHandler.go +++ b/statusHandler/persister/persistentHandler.go @@ -11,7 +11,7 @@ import ( "github.com/ElrondNetwork/elrond-go/common" "github.com/ElrondNetwork/elrond-go/statusHandler" "github.com/ElrondNetwork/elrond-go/storage" - "github.com/ElrondNetwork/elrond-go/storage/storageUnit" + "github.com/ElrondNetwork/elrond-go/storage/storageunit" ) var log = logger.GetOrCreate("statusHandler/persister") @@ -38,7 +38,7 @@ func NewPersistentStatusHandler( } psh := new(PersistentStatusHandler) - psh.store = storageUnit.NewNilStorer() + psh.store = storageunit.NewNilStorer() psh.uint64ByteSliceConverter = uint64ByteSliceConverter psh.marshalizer = marshalizer psh.persistentMetrics = &sync.Map{} diff --git a/statusHandler/statusMetricsProvider_test.go b/statusHandler/statusMetricsProvider_test.go index 360b7624e65..fb2d94e956d 100644 --- a/statusHandler/statusMetricsProvider_test.go +++ b/statusHandler/statusMetricsProvider_test.go @@ -332,7 +332,7 @@ func TestStatusMetrics_EnableEpochMetrics(t 
*testing.T) { common.MetricNodesToShufflePerShard: uint64(5), }, }, - common.MetricHeartbeatDisableEpoch: uint64(5), + common.MetricHeartbeatDisableEpoch: uint64(5), } epochsMetrics, _ := sm.EnableEpochsMetrics() diff --git a/storage/cache/cache.go b/storage/cache/cache.go new file mode 100644 index 00000000000..fb4fcc43788 --- /dev/null +++ b/storage/cache/cache.go @@ -0,0 +1,80 @@ +package cache + +import ( + "time" + + "github.com/ElrondNetwork/elrond-go-core/core" + "github.com/ElrondNetwork/elrond-go-storage/immunitycache" + "github.com/ElrondNetwork/elrond-go-storage/lrucache" + "github.com/ElrondNetwork/elrond-go-storage/lrucache/capacity" + "github.com/ElrondNetwork/elrond-go-storage/timecache" + "github.com/ElrondNetwork/elrond-go-storage/types" + "github.com/ElrondNetwork/elrond-go/storage" +) + +// ArgTimeCacher is the argument used to create a new timeCacher instance +type ArgTimeCacher = timecache.ArgTimeCacher + +// TimeCache is an alias for the imported TimeCache structure +type TimeCache = timecache.TimeCache + +// EvictionHandler is an alias to the imported EvictionHandler +type EvictionHandler = types.EvictionHandler + +// ImmunityCache is a cache-like structure +type ImmunityCache = immunitycache.ImmunityCache + +// CacheConfig holds cache configuration +type CacheConfig = immunitycache.CacheConfig + +// TimeCacher defines the cache that can keep a record for a bounded time +type TimeCacher interface { + Add(key string) error + Upsert(key string, span time.Duration) error + Has(key string) bool + Sweep() + IsInterfaceNil() bool +} + +// PeerBlackListCacher can determine if a certain peer id is or not blacklisted +type PeerBlackListCacher interface { + Upsert(pid core.PeerID, span time.Duration) error + Has(pid core.PeerID) bool + Sweep() + IsInterfaceNil() bool +} + +// NewTimeCache returns an instance of a time cache +func NewTimeCache(defaultSpan time.Duration) *timecache.TimeCache { + return timecache.NewTimeCache(defaultSpan) +} + +// NewTimeCacher creates a new timeCacher +func NewTimeCacher(arg ArgTimeCacher) (storage.Cacher, error) { + return timecache.NewTimeCacher(arg) +} + +// NewLRUCache returns an instance of a LRU cache +func NewLRUCache(size int) (storage.Cacher, error) { + return lrucache.NewCache(size) +} + +// NewPeerTimeCache returns an instance of a peer time cacher +func NewPeerTimeCache(cache TimeCacher) (PeerBlackListCacher, error) { + return timecache.NewPeerTimeCache(cache) +} + +// NewCapacityLRU constructs an LRU cache of the given size with a byte size capacity +func NewCapacityLRU(size int, byteCapacity int64) (storage.AdaptedSizedLRUCache, error) { + return capacity.NewCapacityLRU(size, byteCapacity) +} + +// NewLRUCacheWithEviction creates a new sized LRU cache instance with eviction function +func NewLRUCacheWithEviction(size int, onEvicted func(key interface{}, value interface{})) (storage.Cacher, error) { + return lrucache.NewCacheWithEviction(size, onEvicted) +} + +// NewImmunityCache creates a new cache +func NewImmunityCache(config CacheConfig) (*immunitycache.ImmunityCache, error) { + return immunitycache.NewImmunityCache(config) +} diff --git a/storage/clean/oldDatabaseCleaner.go b/storage/clean/oldDatabaseCleaner.go index e16c92507a0..7fff01ae2cd 100644 --- a/storage/clean/oldDatabaseCleaner.go +++ b/storage/clean/oldDatabaseCleaner.go @@ -15,7 +15,7 @@ import ( "github.com/ElrondNetwork/elrond-go/common" "github.com/ElrondNetwork/elrond-go/epochStart/notifier" "github.com/ElrondNetwork/elrond-go/storage" - 
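The new storage/cache package is a facade over elrond-go-storage: type aliases plus thin constructors, so the rest of the node stops importing lrucache, timecache and immunitycache directly. Usage sketch (sizes are illustrative):

package main

import (
	"fmt"

	"github.com/ElrondNetwork/elrond-go/storage/cache"
)

func main() {
	lru, err := cache.NewLRUCache(1000) // replaces lrucache.NewCache(1000)
	if err != nil {
		panic(err)
	}

	value := []byte("value")
	_ = lru.Put([]byte("key"), value, len(value)) // Cacher.Put reports whether an eviction occurred
	v, ok := lru.Get([]byte("key"))
	fmt.Println(v, ok)
}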
"github.com/ElrondNetwork/elrond-go/storage/factory/directoryhandler" + "github.com/ElrondNetwork/elrond-go/storage/directoryhandler" ) var log = logger.GetOrCreate("storage/clean") diff --git a/storage/constants.go b/storage/constants.go new file mode 100644 index 00000000000..5c93b86c128 --- /dev/null +++ b/storage/constants.go @@ -0,0 +1,35 @@ +package storage + +import ( + "github.com/ElrondNetwork/elrond-go-storage/storageUnit" +) + +// MaxRetriesToCreateDB represents the maximum number of times to try to create DB if it failed +const MaxRetriesToCreateDB = storageUnit.MaxRetriesToCreateDB + +// SleepTimeBetweenCreateDBRetries represents the number of seconds to sleep between DB creates +const SleepTimeBetweenCreateDBRetries = storageUnit.SleepTimeBetweenCreateDBRetries + +// PathShardPlaceholder represents the placeholder for the shard ID in paths +const PathShardPlaceholder = "[S]" + +// PathEpochPlaceholder represents the placeholder for the epoch number in paths +const PathEpochPlaceholder = "[E]" + +// PathIdentifierPlaceholder represents the placeholder for the identifier in paths +const PathIdentifierPlaceholder = "[I]" + +// TxPoolNumTxsToPreemptivelyEvict instructs tx pool eviction algorithm to remove this many transactions when eviction takes place +const TxPoolNumTxsToPreemptivelyEvict = uint32(1000) + +// DefaultDBPath is the default path for nodes databases +const DefaultDBPath = "db" + +// DefaultEpochString is the default folder root name for node per epoch databases +const DefaultEpochString = "Epoch" + +// DefaultStaticDbString is the default name for the static databases (not changing with epoch) +const DefaultStaticDbString = "Static" + +// DefaultShardString is the default folder root name for per shard databases +const DefaultShardString = "Shard" diff --git a/storage/database/db.go b/storage/database/db.go new file mode 100644 index 00000000000..5c65e3f95f7 --- /dev/null +++ b/storage/database/db.go @@ -0,0 +1,33 @@ +package database + +import ( + "github.com/ElrondNetwork/elrond-go-storage/leveldb" + "github.com/ElrondNetwork/elrond-go-storage/memorydb" + "github.com/ElrondNetwork/elrond-go/storage" +) + +// MemDB represents the memory database storage. 
It holds a map of key value pairs +// and a mutex to handle concurrent accesses to the map +type MemDB = memorydb.DB + +// NewMemDB creates a new memorydb object +func NewMemDB() *MemDB { + return memorydb.New() +} + +// NewlruDB creates a lruDB according to size +func NewlruDB(size uint32) (storage.Persister, error) { + return memorydb.NewlruDB(size) +} + +// NewLevelDB is a constructor for the leveldb persister +// It creates the files in the location given as parameter +func NewLevelDB(path string, batchDelaySeconds int, maxBatchSize int, maxOpenFiles int) (s *leveldb.DB, err error) { + return leveldb.NewDB(path, batchDelaySeconds, maxBatchSize, maxOpenFiles) +} + +// NewSerialDB is a constructor for the leveldb persister +// It creates the files in the location given as parameter +func NewSerialDB(path string, batchDelaySeconds int, maxBatchSize int, maxOpenFiles int) (s *leveldb.SerialDB, err error) { + return leveldb.NewSerialDB(path, batchDelaySeconds, maxBatchSize, maxOpenFiles) +} diff --git a/storage/factory/directoryhandler/directoryReader.go b/storage/directoryhandler/directoryReader.go similarity index 100% rename from storage/factory/directoryhandler/directoryReader.go rename to storage/directoryhandler/directoryReader.go diff --git a/storage/factory/directoryhandler/directoryReader_test.go b/storage/directoryhandler/directoryReader_test.go similarity index 100% rename from storage/factory/directoryhandler/directoryReader_test.go rename to storage/directoryhandler/directoryReader_test.go diff --git a/storage/disabled/persister.go b/storage/disabled/persister.go index da66061b600..91945094e30 100644 --- a/storage/disabled/persister.go +++ b/storage/disabled/persister.go @@ -1,6 +1,8 @@ package disabled -import "github.com/ElrondNetwork/elrond-go/storage" +import ( + "github.com/ElrondNetwork/elrond-go/storage" +) type persister struct{} diff --git a/storage/errors.go b/storage/errors.go index f0111509f49..632807ca8d5 100644 --- a/storage/errors.go +++ b/storage/errors.go @@ -3,34 +3,9 @@ package storage import ( "errors" "strings" -) - -// ErrNilPersister is raised when a nil persister is provided -var ErrNilPersister = errors.New("expected not nil persister") - -// ErrNilCacher is raised when a nil cacher is provided -var ErrNilCacher = errors.New("expected not nil cacher") - -// ErrNotSupportedCacheType is raised when an unsupported cache type is provided -var ErrNotSupportedCacheType = errors.New("not supported cache type") - -// ErrNotSupportedDBType is raised when an unsupported database type is provided -var ErrNotSupportedDBType = errors.New("not supported db type") - -// ErrNotSupportedHashType is raised when an unsupported hasher is provided -var ErrNotSupportedHashType = errors.New("hash type not supported") - -// ErrKeyNotFound is raised when a key is not found -var ErrKeyNotFound = errors.New("key not found") -// ErrInvalidBatch is raised when the used batch is invalid -var ErrInvalidBatch = errors.New("batch is invalid") - -// ErrInvalidNumOpenFiles is raised when the max num of open files is less than 1 -var ErrInvalidNumOpenFiles = errors.New("maxOpenFiles is invalid") - -// ErrEmptyKey is raised when a key is empty -var ErrEmptyKey = errors.New("key is empty") + storageErrors "github.com/ElrondNetwork/elrond-go-storage/common" +) // ErrInvalidNumberOfPersisters signals that an invalid number of persisters has been provided var ErrInvalidNumberOfPersisters = errors.New("invalid number of active persisters") @@ -50,9 +25,6 @@ var ErrDestroyingUnit = 
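storage/database plays the same facade role for persisters, which is why the tests above move from memorydb.New() to database.NewMemDB(). A sketch, assuming the usual Persister Put/Close signatures from this repo:

memDB := database.NewMemDB()
_ = memDB.Put([]byte("k"), []byte("v"))

// path, batchDelaySeconds, maxBatchSize, maxOpenFiles (illustrative values)
serialDB, err := database.NewSerialDB("db/serial", 2, 100, 10)
if err != nil {
	return err
}
defer func() { _ = serialDB.Close() }()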
errors.New("destroy unit didn't remove all the persister // ErrNilConfig signals that a nil configuration has been received var ErrNilConfig = errors.New("nil config") -// ErrInvalidConfig signals an invalid config -var ErrInvalidConfig = errors.New("invalid config") - // ErrNilShardCoordinator signals that a nil shard coordinator has been provided var ErrNilShardCoordinator = errors.New("nil shard coordinator") @@ -65,21 +37,6 @@ var ErrNilCustomDatabaseRemover = errors.New("custom database remover") // ErrNilStorageListProvider signals that a nil storage list provided has been provided var ErrNilStorageListProvider = errors.New("nil storage list provider") -// ErrEmptyPruningPathTemplate signals that an empty path template for pruning storers has been provided -var ErrEmptyPruningPathTemplate = errors.New("empty path template for pruning storers") - -// ErrEmptyStaticPathTemplate signals that an empty path template for static storers has been provided -var ErrEmptyStaticPathTemplate = errors.New("empty path template for static storers") - -// ErrInvalidPruningPathTemplate signals that an invalid path template for pruning storers has been provided -var ErrInvalidPruningPathTemplate = errors.New("invalid path template for pruning storers") - -// ErrInvalidStaticPathTemplate signals that an invalid path template for static storers has been provided -var ErrInvalidStaticPathTemplate = errors.New("invalid path template for static storers") - -// ErrInvalidDatabasePath signals that an invalid database path has been provided -var ErrInvalidDatabasePath = errors.New("invalid database path") - // ErrOldestEpochNotAvailable signals that fetching the oldest epoch is not available var ErrOldestEpochNotAvailable = errors.New("oldest epoch not available") @@ -104,36 +61,6 @@ var ErrNilMarshalizer = errors.New("nil marshalizer") // ErrWrongTypeAssertion is thrown when a wrong type assertion is spotted var ErrWrongTypeAssertion = errors.New("wrong type assertion") -// ErrFailedCacheEviction signals a failed eviction within a cache -var ErrFailedCacheEviction = errors.New("failed eviction within cache") - -// ErrImmuneItemsCapacityReached signals that capacity for immune items is reached -var ErrImmuneItemsCapacityReached = errors.New("capacity reached for immune items") - -// ErrItemAlreadyInCache signals that an item is already in cache -var ErrItemAlreadyInCache = errors.New("item already in cache") - -// ErrCacheSizeInvalid signals that size of cache is less than 1 -var ErrCacheSizeInvalid = errors.New("cache size is less than 1") - -// ErrCacheCapacityInvalid signals that capacity of cache is less than 1 -var ErrCacheCapacityInvalid = errors.New("cache capacity is less than 1") - -// ErrLRUCacheWithProvidedSize signals that a simple LRU cache is wanted but the user provided a positive size in bytes value -var ErrLRUCacheWithProvidedSize = errors.New("LRU cache does not support size in bytes") - -// ErrLRUCacheInvalidSize signals that the provided size in bytes value for LRU cache is invalid -var ErrLRUCacheInvalidSize = errors.New("wrong size in bytes value for LRU cache") - -// ErrNegativeSizeInBytes signals that the provided size in bytes value is negative -var ErrNegativeSizeInBytes = errors.New("negative size in bytes") - -// ErrNilTimeCache signals that a nil time cache has been provided -var ErrNilTimeCache = errors.New("nil time cache") - -// ErrNilTxGasHandler signals that a nil tx gas handler was provided -var ErrNilTxGasHandler = errors.New("nil tx gas handler") - // 
ErrCannotComputeStorageOldestEpoch signals an issue when computing the oldest epoch for storage var ErrCannotComputeStorageOldestEpoch = errors.New("could not compute the oldest epoch for storage") @@ -143,14 +70,26 @@ var ErrNilNodeTypeProvider = errors.New("nil node type provider") // ErrNilOldDataCleanerProvider signals that a nil old data cleaner provider has been provided var ErrNilOldDataCleanerProvider = errors.New("nil old data cleaner provider") -// ErrNilStoredDataFactory signals that a nil stored data factory has been provided -var ErrNilStoredDataFactory = errors.New("nil stored data factory") +// ErrKeyNotFound is raised when a key is not found +var ErrKeyNotFound = storageErrors.ErrKeyNotFound + +// ErrInvalidConfig signals an invalid config +var ErrInvalidConfig = storageErrors.ErrInvalidConfig + +// ErrCacheSizeInvalid signals that size of cache is less than 1 +var ErrCacheSizeInvalid = storageErrors.ErrCacheSizeInvalid -// ErrInvalidDefaultSpan signals that an invalid default span was provided -var ErrInvalidDefaultSpan = errors.New("invalid default span") +// ErrNotSupportedDBType is raised when an unsupported database type is provided +var ErrNotSupportedDBType = storageErrors.ErrNotSupportedDBType + +// ErrNotSupportedCacheType is raised when an unsupported cache type is provided +var ErrNotSupportedCacheType = storageErrors.ErrNotSupportedCacheType // ErrInvalidCacheExpiry signals that an invalid cache expiry was provided -var ErrInvalidCacheExpiry = errors.New("invalid cache expiry") +var ErrInvalidCacheExpiry = storageErrors.ErrInvalidCacheExpiry + +// ErrDBIsClosed is raised when the DB is closed +var ErrDBIsClosed = storageErrors.ErrDBIsClosed // ErrEpochKeepIsLowerThanNumActive signals that num epochs to keep is lower than num active epochs var ErrEpochKeepIsLowerThanNumActive = errors.New("num epochs to keep is lower than num active epochs") diff --git a/storage/factory/bootstrapDataProvider.go b/storage/factory/bootstrapDataProvider.go index 3c4f504d5e2..0139564586c 100644 --- a/storage/factory/bootstrapDataProvider.go +++ b/storage/factory/bootstrapDataProvider.go @@ -6,8 +6,8 @@ import ( "github.com/ElrondNetwork/elrond-go/process" "github.com/ElrondNetwork/elrond-go/process/block/bootstrapStorage" "github.com/ElrondNetwork/elrond-go/storage" - "github.com/ElrondNetwork/elrond-go/storage/lrucache" - "github.com/ElrondNetwork/elrond-go/storage/storageUnit" + "github.com/ElrondNetwork/elrond-go/storage/cache" + "github.com/ElrondNetwork/elrond-go/storage/storageunit" ) type bootstrapDataProvider struct { @@ -45,12 +45,12 @@ func (bdp *bootstrapDataProvider) LoadForPath( } }() - cacher, err := lrucache.NewCache(10) + cacher, err := cache.NewLRUCache(10) if err != nil { return nil, nil, err } - storer, err := storageUnit.NewStorageUnit(cacher, persister) + storer, err := storageunit.NewStorageUnit(cacher, persister) if err != nil { return nil, nil, err } diff --git a/storage/factory/bootstrapDataProvider_test.go b/storage/factory/bootstrapDataProvider_test.go index 58258068aee..bc35f104dc3 100644 --- a/storage/factory/bootstrapDataProvider_test.go +++ b/storage/factory/bootstrapDataProvider_test.go @@ -9,7 +9,7 @@ import ( "github.com/ElrondNetwork/elrond-go/common" "github.com/ElrondNetwork/elrond-go/process/block/bootstrapStorage" "github.com/ElrondNetwork/elrond-go/storage" - "github.com/ElrondNetwork/elrond-go/storage/memorydb" + "github.com/ElrondNetwork/elrond-go/storage/database" "github.com/ElrondNetwork/elrond-go/storage/mock" 
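Re-exporting the relocated sentinel errors keeps existing comparisons working: storage.ErrDBIsClosed is the same variable as the library's, not a copy, so errors.Is matches across both import paths. For example:

package main

import (
	"errors"
	"fmt"

	storageErrors "github.com/ElrondNetwork/elrond-go-storage/common"
	"github.com/ElrondNetwork/elrond-go/storage"
)

func main() {
	// prints true: the facade re-exports the same sentinel value, not a copy
	fmt.Println(errors.Is(storageErrors.ErrDBIsClosed, storage.ErrDBIsClosed))
}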
"github.com/stretchr/testify/require" ) @@ -54,7 +54,7 @@ func TestBootstrapDataProvider_LoadForPath_KeyNotFound(t *testing.T) { bdp, _ := NewBootstrapDataProvider(&mock.MarshalizerMock{}) persisterFactory := &mock.PersisterFactoryStub{ CreateCalled: func(_ string) (persister storage.Persister, e error) { - persister, e = memorydb.NewlruDB(20) + persister, e = database.NewlruDB(20) return }, } @@ -70,7 +70,7 @@ func TestBootstrapDataProvider_LoadForPath_ShouldWork(t *testing.T) { marshalizer := &mock.MarshalizerMock{} bdp, _ := NewBootstrapDataProvider(marshalizer) - persisterToUse := memorydb.New() + persisterToUse := database.NewMemDB() expectedRound := int64(37) roundNum := bootstrapStorage.RoundNum{Num: expectedRound} diff --git a/storage/factory/common.go b/storage/factory/common.go index 7af92324e92..2d9a064eb24 100644 --- a/storage/factory/common.go +++ b/storage/factory/common.go @@ -2,26 +2,26 @@ package factory import ( "github.com/ElrondNetwork/elrond-go/config" - "github.com/ElrondNetwork/elrond-go/storage/storageUnit" + "github.com/ElrondNetwork/elrond-go/storage/storageunit" ) // GetCacherFromConfig will return the cache config needed for storage unit from a config came from the toml file -func GetCacherFromConfig(cfg config.CacheConfig) storageUnit.CacheConfig { - return storageUnit.CacheConfig{ +func GetCacherFromConfig(cfg config.CacheConfig) storageunit.CacheConfig { + return storageunit.CacheConfig{ Name: cfg.Name, Capacity: cfg.Capacity, SizePerSender: cfg.SizePerSender, SizeInBytes: cfg.SizeInBytes, SizeInBytesPerSender: cfg.SizeInBytesPerSender, - Type: storageUnit.CacheType(cfg.Type), + Type: storageunit.CacheType(cfg.Type), Shards: cfg.Shards, } } // GetDBFromConfig will return the db config needed for storage unit from a config came from the toml file -func GetDBFromConfig(cfg config.DBConfig) storageUnit.DBConfig { - return storageUnit.DBConfig{ - Type: storageUnit.DBType(cfg.Type), +func GetDBFromConfig(cfg config.DBConfig) storageunit.DBConfig { + return storageunit.DBConfig{ + Type: storageunit.DBType(cfg.Type), MaxBatchSize: cfg.MaxBatchSize, BatchDelaySeconds: cfg.BatchDelaySeconds, MaxOpenFiles: cfg.MaxOpenFiles, diff --git a/storage/factory/common_test.go b/storage/factory/common_test.go index 17604ac009e..eb2dbdfa180 100644 --- a/storage/factory/common_test.go +++ b/storage/factory/common_test.go @@ -4,7 +4,7 @@ import ( "testing" "github.com/ElrondNetwork/elrond-go/config" - "github.com/ElrondNetwork/elrond-go/storage/storageUnit" + "github.com/ElrondNetwork/elrond-go/storage/storageunit" "github.com/stretchr/testify/assert" ) @@ -19,10 +19,10 @@ func TestGetCacherFromConfig(t *testing.T) { } storageCacheConfig := GetCacherFromConfig(cfg) - assert.Equal(t, storageUnit.CacheConfig{ + assert.Equal(t, storageunit.CacheConfig{ Capacity: cfg.Capacity, SizeInBytes: cfg.SizeInBytes, - Type: storageUnit.CacheType(cfg.Type), + Type: storageunit.CacheType(cfg.Type), Shards: cfg.Shards, }, storageCacheConfig) } @@ -38,8 +38,8 @@ func TestGetDBFromConfig(t *testing.T) { } storageDBConfig := GetDBFromConfig(cfg) - assert.Equal(t, storageUnit.DBConfig{ - Type: storageUnit.DBType(cfg.Type), + assert.Equal(t, storageunit.DBConfig{ + Type: storageunit.DBType(cfg.Type), MaxBatchSize: cfg.MaxBatchSize, BatchDelaySeconds: cfg.BatchDelaySeconds, MaxOpenFiles: cfg.MaxOpenFiles, diff --git a/storage/factory/openStorage.go b/storage/factory/openStorage.go index 669e2671001..83ff6392b11 100644 --- a/storage/factory/openStorage.go +++ b/storage/factory/openStorage.go @@ -5,12 
+5,11 @@ import ( "path/filepath" "time" - "github.com/ElrondNetwork/elrond-go/common" "github.com/ElrondNetwork/elrond-go/config" "github.com/ElrondNetwork/elrond-go/process/block/bootstrapStorage" "github.com/ElrondNetwork/elrond-go/storage" - "github.com/ElrondNetwork/elrond-go/storage/lrucache" - "github.com/ElrondNetwork/elrond-go/storage/storageUnit" + "github.com/ElrondNetwork/elrond-go/storage/cache" + "github.com/ElrondNetwork/elrond-go/storage/storageunit" ) const cacheSize = 10 @@ -68,12 +67,12 @@ func (o *openStorageUnits) GetMostRecentStorageUnit(dbConfig config.DBConfig) (s return nil, err } - cacher, err := lrucache.NewCache(cacheSize) + cacher, err := cache.NewLRUCache(cacheSize) if err != nil { return nil, err } - storer, err := storageUnit.NewStorageUnit(cacher, persister) + storer, err := storageunit.NewStorageUnit(cacher, persister) if err != nil { return nil, err } @@ -108,25 +107,25 @@ func (o *openStorageUnits) OpenDB(dbConfig config.DBConfig, shardID uint32, epoc return nil, err } - cache, err := lrucache.NewCache(cacheSize) + lruCache, err := cache.NewLRUCache(cacheSize) if err != nil { return nil, err } - return storageUnit.NewStorageUnit(cache, persister) + return storageunit.NewStorageUnit(lruCache, persister) } func createDB(persisterFactory *PersisterFactory, persisterPath string) (storage.Persister, error) { var persister storage.Persister var err error - for i := 0; i < common.MaxRetriesToCreateDB; i++ { + for i := 0; i < storage.MaxRetriesToCreateDB; i++ { persister, err = persisterFactory.Create(persisterPath) if err == nil { return persister, nil } log.Warn("Create Persister failed", "path", persisterPath, "error", err) //TODO: extract this in a parameter and inject it - time.Sleep(common.SleepTimeBetweenCreateDBRetries) + time.Sleep(storage.SleepTimeBetweenCreateDBRetries) } return nil, err } diff --git a/storage/factory/pathManager.go b/storage/factory/pathManager.go index 7029dc27004..a455d52c4da 100644 --- a/storage/factory/pathManager.go +++ b/storage/factory/pathManager.go @@ -4,7 +4,7 @@ import ( "fmt" "path/filepath" - "github.com/ElrondNetwork/elrond-go/common" + "github.com/ElrondNetwork/elrond-go/storage" "github.com/ElrondNetwork/elrond-go/storage/pathmanager" ) @@ -16,22 +16,22 @@ type ArgCreatePathManager struct { // CreatePathManager crates a path manager from provided working directory and chain ID func CreatePathManager(arg ArgCreatePathManager) (*pathmanager.PathManager, error) { - return CreatePathManagerFromSinglePathString(filepath.Join(arg.WorkingDir, common.DefaultDBPath, arg.ChainID)) + return CreatePathManagerFromSinglePathString(filepath.Join(arg.WorkingDir, storage.DefaultDBPath, arg.ChainID)) } // CreatePathManagerFromSinglePathString crates a path manager from provided path string func CreatePathManagerFromSinglePathString(dbPathWithChainID string) (*pathmanager.PathManager, error) { pathTemplateForPruningStorer := filepath.Join( dbPathWithChainID, - fmt.Sprintf("%s_%s", common.DefaultEpochString, common.PathEpochPlaceholder), - fmt.Sprintf("%s_%s", common.DefaultShardString, common.PathShardPlaceholder), - common.PathIdentifierPlaceholder) + fmt.Sprintf("%s_%s", storage.DefaultEpochString, storage.PathEpochPlaceholder), + fmt.Sprintf("%s_%s", storage.DefaultShardString, storage.PathShardPlaceholder), + storage.PathIdentifierPlaceholder) pathTemplateForStaticStorer := filepath.Join( dbPathWithChainID, - common.DefaultStaticDbString, - fmt.Sprintf("%s_%s", common.DefaultShardString, common.PathShardPlaceholder), - 
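createDB keeps its bounded retry loop; only the constants moved from common to storage. The shape, extracted for reference (sketch; log and persisterFactory as in the hunk above):

for i := 0; i < storage.MaxRetriesToCreateDB; i++ {
	persister, err = persisterFactory.Create(persisterPath)
	if err == nil {
		return persister, nil
	}
	log.Warn("Create Persister failed", "path", persisterPath, "error", err)
	time.Sleep(storage.SleepTimeBetweenCreateDBRetries) // back off before the next attempt
}
return nil, err // last error wins after exhausting the retries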
common.PathIdentifierPlaceholder) + storage.DefaultStaticDbString, + fmt.Sprintf("%s_%s", storage.DefaultShardString, storage.PathShardPlaceholder), + storage.PathIdentifierPlaceholder) return pathmanager.NewPathManager(pathTemplateForPruningStorer, pathTemplateForStaticStorer, dbPathWithChainID) } diff --git a/storage/factory/persisterFactory.go b/storage/factory/persisterFactory.go index c2d26526571..f4c0e993630 100644 --- a/storage/factory/persisterFactory.go +++ b/storage/factory/persisterFactory.go @@ -5,9 +5,8 @@ import ( "github.com/ElrondNetwork/elrond-go/config" "github.com/ElrondNetwork/elrond-go/storage" - "github.com/ElrondNetwork/elrond-go/storage/leveldb" - "github.com/ElrondNetwork/elrond-go/storage/memorydb" - "github.com/ElrondNetwork/elrond-go/storage/storageUnit" + "github.com/ElrondNetwork/elrond-go/storage/database" + "github.com/ElrondNetwork/elrond-go/storage/storageunit" ) // PersisterFactory is the factory which will handle creating new databases @@ -34,13 +33,13 @@ func (pf *PersisterFactory) Create(path string) (storage.Persister, error) { return nil, errors.New("invalid file path") } - switch storageUnit.DBType(pf.dbType) { - case storageUnit.LvlDB: - return leveldb.NewDB(path, pf.batchDelaySeconds, pf.maxBatchSize, pf.maxOpenFiles) - case storageUnit.LvlDBSerial: - return leveldb.NewSerialDB(path, pf.batchDelaySeconds, pf.maxBatchSize, pf.maxOpenFiles) - case storageUnit.MemoryDB: - return memorydb.New(), nil + switch storageunit.DBType(pf.dbType) { + case storageunit.LvlDB: + return database.NewLevelDB(path, pf.batchDelaySeconds, pf.maxBatchSize, pf.maxOpenFiles) + case storageunit.LvlDBSerial: + return database.NewSerialDB(path, pf.batchDelaySeconds, pf.maxBatchSize, pf.maxOpenFiles) + case storageunit.MemoryDB: + return database.NewMemDB(), nil default: return nil, storage.ErrNotSupportedDBType } diff --git a/storage/factory/pruningStorerFactory.go b/storage/factory/pruningStorerFactory.go index 9f33917f452..abaeab82b53 100644 --- a/storage/factory/pruningStorerFactory.go +++ b/storage/factory/pruningStorerFactory.go @@ -9,13 +9,14 @@ import ( logger "github.com/ElrondNetwork/elrond-go-logger" "github.com/ElrondNetwork/elrond-go/config" "github.com/ElrondNetwork/elrond-go/dataRetriever" + "github.com/ElrondNetwork/elrond-go/epochStart" "github.com/ElrondNetwork/elrond-go/storage" "github.com/ElrondNetwork/elrond-go/storage/clean" "github.com/ElrondNetwork/elrond-go/storage/databaseremover" "github.com/ElrondNetwork/elrond-go/storage/databaseremover/disabled" storageDisabled "github.com/ElrondNetwork/elrond-go/storage/disabled" "github.com/ElrondNetwork/elrond-go/storage/pruning" - "github.com/ElrondNetwork/elrond-go/storage/storageUnit" + "github.com/ElrondNetwork/elrond-go/storage/storageunit" ) var log = logger.GetOrCreate("storage/factory") @@ -42,7 +43,7 @@ type StorageServiceFactory struct { prefsConfig *config.PreferencesConfig shardCoordinator storage.ShardCoordinator pathManager storage.PathManagerHandler - epochStartNotifier storage.EpochStartNotifier + epochStartNotifier epochStart.EpochStartNotifier oldDataCleanerProvider clean.OldDataCleanerProvider createTrieEpochRootHashStorer bool currentEpoch uint32 @@ -55,7 +56,7 @@ func NewStorageServiceFactory( prefsConfig *config.PreferencesConfig, shardCoordinator storage.ShardCoordinator, pathManager storage.PathManagerHandler, - epochStartNotifier storage.EpochStartNotifier, + epochStartNotifier epochStart.EpochStartNotifier, nodeTypeProvider NodeTypeProviderHandler, currentEpoch uint32, 
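With the placeholder constants now sourced from the storage package (values unchanged), CreatePathManagerFromSinglePathString still yields the same templates. A worked example, under the assumption that [E], [S] and [I] are substituted with the epoch, shard and unit identifier:

pm, err := factory.CreatePathManagerFromSinglePathString("db/1")
// pruning storers: "db/1/Epoch_[E]/Shard_[S]/[I]"
//   e.g. epoch 7, shard 0, unit "Transactions" -> "db/1/Epoch_7/Shard_0/Transactions"
// static storers:  "db/1/Static/Shard_[S]/[I]"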
createTrieEpochRootHashStorer bool, @@ -201,7 +202,7 @@ func (psf *StorageServiceFactory) CreateForShard() (dataRetriever.StorageService shardID := core.GetShardIDString(psf.shardCoordinator.SelfId()) dbPath := psf.pathManager.PathForStatic(shardID, psf.generalConfig.MetaHdrNonceHashStorage.DB.FilePath) metaHdrHashNonceUnitConfig.FilePath = dbPath - metaHdrHashNonceUnit, err := storageUnit.NewStorageUnitFromConf( + metaHdrHashNonceUnit, err := storageunit.NewStorageUnitFromConf( GetCacherFromConfig(psf.generalConfig.MetaHdrNonceHashStorage.Cache), metaHdrHashNonceUnitConfig) if err != nil { @@ -213,7 +214,7 @@ func (psf *StorageServiceFactory) CreateForShard() (dataRetriever.StorageService shardID = core.GetShardIDString(psf.shardCoordinator.SelfId()) dbPath = psf.pathManager.PathForStatic(shardID, psf.generalConfig.ShardHdrNonceHashStorage.DB.FilePath) + shardID shardHdrHashNonceConfig.FilePath = dbPath - shardHdrHashNonceUnit, err := storageUnit.NewStorageUnitFromConf( + shardHdrHashNonceUnit, err := storageunit.NewStorageUnitFromConf( GetCacherFromConfig(psf.generalConfig.ShardHdrNonceHashStorage.Cache), shardHdrHashNonceConfig) if err != nil { @@ -224,7 +225,7 @@ func (psf *StorageServiceFactory) CreateForShard() (dataRetriever.StorageService shardId := core.GetShardIDString(psf.shardCoordinator.SelfId()) dbPath = psf.pathManager.PathForStatic(shardId, psf.generalConfig.Heartbeat.HeartbeatStorage.DB.FilePath) heartbeatDbConfig.FilePath = dbPath - heartbeatStorageUnit, err := storageUnit.NewStorageUnitFromConf( + heartbeatStorageUnit, err := storageunit.NewStorageUnitFromConf( GetCacherFromConfig(psf.generalConfig.Heartbeat.HeartbeatStorage.Cache), heartbeatDbConfig) if err != nil { @@ -235,7 +236,7 @@ func (psf *StorageServiceFactory) CreateForShard() (dataRetriever.StorageService shardId = core.GetShardIDString(psf.shardCoordinator.SelfId()) dbPath = psf.pathManager.PathForStatic(shardId, psf.generalConfig.StatusMetricsStorage.DB.FilePath) statusMetricsDbConfig.FilePath = dbPath - statusMetricsStorageUnit, err := storageUnit.NewStorageUnitFromConf( + statusMetricsStorageUnit, err := storageunit.NewStorageUnitFromConf( GetCacherFromConfig(psf.generalConfig.StatusMetricsStorage.Cache), statusMetricsDbConfig) if err != nil { @@ -370,20 +371,20 @@ func (psf *StorageServiceFactory) CreateForMeta() (dataRetriever.StorageService, shardID := core.GetShardIDString(core.MetachainShardId) dbPath := psf.pathManager.PathForStatic(shardID, psf.generalConfig.MetaHdrNonceHashStorage.DB.FilePath) metaHdrHashNonceUnitConfig.FilePath = dbPath - metaHdrHashNonceUnit, err := storageUnit.NewStorageUnitFromConf( + metaHdrHashNonceUnit, err := storageunit.NewStorageUnitFromConf( GetCacherFromConfig(psf.generalConfig.MetaHdrNonceHashStorage.Cache), metaHdrHashNonceUnitConfig) if err != nil { return nil, err } - shardHdrHashNonceUnits := make([]*storageUnit.Unit, psf.shardCoordinator.NumberOfShards()) + shardHdrHashNonceUnits := make([]*storageunit.Unit, psf.shardCoordinator.NumberOfShards()) for i := uint32(0); i < psf.shardCoordinator.NumberOfShards(); i++ { shardHdrHashNonceConfig := GetDBFromConfig(psf.generalConfig.ShardHdrNonceHashStorage.DB) shardID = core.GetShardIDString(core.MetachainShardId) dbPath = psf.pathManager.PathForStatic(shardID, psf.generalConfig.ShardHdrNonceHashStorage.DB.FilePath) + fmt.Sprintf("%d", i) shardHdrHashNonceConfig.FilePath = dbPath - shardHdrHashNonceUnits[i], err = storageUnit.NewStorageUnitFromConf( + shardHdrHashNonceUnits[i], err = storageunit.NewStorageUnitFromConf( 
GetCacherFromConfig(psf.generalConfig.ShardHdrNonceHashStorage.Cache), shardHdrHashNonceConfig) if err != nil { @@ -395,7 +396,7 @@ func (psf *StorageServiceFactory) CreateForMeta() (dataRetriever.StorageService, heartbeatDbConfig := GetDBFromConfig(psf.generalConfig.Heartbeat.HeartbeatStorage.DB) dbPath = psf.pathManager.PathForStatic(shardId, psf.generalConfig.Heartbeat.HeartbeatStorage.DB.FilePath) heartbeatDbConfig.FilePath = dbPath - heartbeatStorageUnit, err := storageUnit.NewStorageUnitFromConf( + heartbeatStorageUnit, err := storageunit.NewStorageUnitFromConf( GetCacherFromConfig(psf.generalConfig.Heartbeat.HeartbeatStorage.Cache), heartbeatDbConfig) if err != nil { @@ -406,7 +407,7 @@ func (psf *StorageServiceFactory) CreateForMeta() (dataRetriever.StorageService, shardId = core.GetShardIDString(psf.shardCoordinator.SelfId()) dbPath = psf.pathManager.PathForStatic(shardId, psf.generalConfig.StatusMetricsStorage.DB.FilePath) statusMetricsDbConfig.FilePath = dbPath - statusMetricsStorageUnit, err := storageUnit.NewStorageUnitFromConf( + statusMetricsStorageUnit, err := storageunit.NewStorageUnitFromConf( GetCacherFromConfig(psf.generalConfig.StatusMetricsStorage.Cache), statusMetricsDbConfig) if err != nil { @@ -578,7 +579,7 @@ func (psf *StorageServiceFactory) setupDbLookupExtensions(chainStorer *dataRetri miniblockHashByTxHashDbConfig := GetDBFromConfig(miniblockHashByTxHashConfig.DB) miniblockHashByTxHashDbConfig.FilePath = psf.pathManager.PathForStatic(shardID, miniblockHashByTxHashConfig.DB.FilePath) miniblockHashByTxHashCacherConfig := GetCacherFromConfig(miniblockHashByTxHashConfig.Cache) - miniblockHashByTxHashUnit, err := storageUnit.NewStorageUnitFromConf(miniblockHashByTxHashCacherConfig, miniblockHashByTxHashDbConfig) + miniblockHashByTxHashUnit, err := storageunit.NewStorageUnitFromConf(miniblockHashByTxHashCacherConfig, miniblockHashByTxHashDbConfig) if err != nil { return err } @@ -590,7 +591,7 @@ func (psf *StorageServiceFactory) setupDbLookupExtensions(chainStorer *dataRetri blockHashByRoundDBConfig := GetDBFromConfig(blockHashByRoundConfig.DB) blockHashByRoundDBConfig.FilePath = psf.pathManager.PathForStatic(shardID, blockHashByRoundConfig.DB.FilePath) blockHashByRoundCacherConfig := GetCacherFromConfig(blockHashByRoundConfig.Cache) - blockHashByRoundUnit, err := storageUnit.NewStorageUnitFromConf(blockHashByRoundCacherConfig, blockHashByRoundDBConfig) + blockHashByRoundUnit, err := storageunit.NewStorageUnitFromConf(blockHashByRoundCacherConfig, blockHashByRoundDBConfig) if err != nil { return err } @@ -602,7 +603,7 @@ func (psf *StorageServiceFactory) setupDbLookupExtensions(chainStorer *dataRetri epochByHashDbConfig := GetDBFromConfig(epochByHashConfig.DB) epochByHashDbConfig.FilePath = psf.pathManager.PathForStatic(shardID, epochByHashConfig.DB.FilePath) epochByHashCacherConfig := GetCacherFromConfig(epochByHashConfig.Cache) - epochByHashUnit, err := storageUnit.NewStorageUnitFromConf(epochByHashCacherConfig, epochByHashDbConfig) + epochByHashUnit, err := storageunit.NewStorageUnitFromConf(epochByHashCacherConfig, epochByHashDbConfig) if err != nil { return err } @@ -613,7 +614,7 @@ func (psf *StorageServiceFactory) setupDbLookupExtensions(chainStorer *dataRetri esdtSuppliesDbConfig := GetDBFromConfig(esdtSuppliesConfig.DB) esdtSuppliesDbConfig.FilePath = psf.pathManager.PathForStatic(shardID, esdtSuppliesConfig.DB.FilePath) esdtSuppliesCacherConfig := GetCacherFromConfig(esdtSuppliesConfig.Cache) - esdtSuppliesUnit, err := 
storageUnit.NewStorageUnitFromConf(esdtSuppliesCacherConfig, esdtSuppliesDbConfig) + esdtSuppliesUnit, err := storageunit.NewStorageUnitFromConf(esdtSuppliesCacherConfig, esdtSuppliesDbConfig) if err != nil { return err } @@ -659,14 +660,14 @@ func (psf *StorageServiceFactory) createPruningStorerArgs( func (psf *StorageServiceFactory) createTrieEpochRootHashStorerIfNeeded() (storage.Storer, error) { if !psf.createTrieEpochRootHashStorer { - return storageUnit.NewNilStorer(), nil + return storageunit.NewNilStorer(), nil } trieEpochRootHashDbConfig := GetDBFromConfig(psf.generalConfig.TrieEpochRootHashStorage.DB) shardId := core.GetShardIDString(psf.shardCoordinator.SelfId()) dbPath := psf.pathManager.PathForStatic(shardId, psf.generalConfig.TrieEpochRootHashStorage.DB.FilePath) trieEpochRootHashDbConfig.FilePath = dbPath - trieEpochRootHashStorageUnit, err := storageUnit.NewStorageUnitFromConf( + trieEpochRootHashStorageUnit, err := storageunit.NewStorageUnitFromConf( GetCacherFromConfig(psf.generalConfig.TrieEpochRootHashStorage.Cache), trieEpochRootHashDbConfig) if err != nil { @@ -683,7 +684,7 @@ func (psf *StorageServiceFactory) createTriePersister( shardID := core.GetShardIDString(psf.shardCoordinator.SelfId()) dbPath := psf.pathManager.PathForStatic(shardID, storageConfig.DB.FilePath) trieDBConfig.FilePath = dbPath - trieUnit, err := storageUnit.NewStorageUnitFromConf( + trieUnit, err := storageunit.NewStorageUnitFromConf( GetCacherFromConfig(storageConfig.Cache), trieDBConfig) if err != nil { diff --git a/storage/fifocache/export_test.go b/storage/fifocache/export_test.go deleted file mode 100644 index 24f07a1ae09..00000000000 --- a/storage/fifocache/export_test.go +++ /dev/null @@ -1,5 +0,0 @@ -package fifocache - -func (c *FIFOShardedCache) AddedDataHandlers() map[string]func(key []byte, value interface{}) { - return c.mapDataHandlers -} diff --git a/storage/fifocache/fifocacheSharded.go b/storage/fifocache/fifocacheSharded.go deleted file mode 100644 index 85fa427a18d..00000000000 --- a/storage/fifocache/fifocacheSharded.go +++ /dev/null @@ -1,151 +0,0 @@ -package fifocache - -import ( - "sync" - - cmap "github.com/ElrondNetwork/concurrent-map" - logger "github.com/ElrondNetwork/elrond-go-logger" - "github.com/ElrondNetwork/elrond-go/storage" -) - -var _ storage.Cacher = (*FIFOShardedCache)(nil) - -var log = logger.GetOrCreate("storage/fifocache") - -// FIFOShardedCache implements a First In First Out eviction cache -type FIFOShardedCache struct { - cache *cmap.ConcurrentMap - maxsize int - - mutAddedDataHandlers sync.RWMutex - mapDataHandlers map[string]func(key []byte, value interface{}) -} - -// NewShardedCache creates a new cache instance -func NewShardedCache(size int, shards int) (*FIFOShardedCache, error) { - cache := cmap.New(size, shards) - fifoShardedCache := &FIFOShardedCache{ - cache: cache, - maxsize: size, - mutAddedDataHandlers: sync.RWMutex{}, - mapDataHandlers: make(map[string]func(key []byte, value interface{})), - } - - return fifoShardedCache, nil -} - -// Clear is used to completely clear the cache. -func (c *FIFOShardedCache) Clear() { - keys := c.cache.Keys() - for _, key := range keys { - c.cache.Remove(key) - } -} - -// Put adds a value to the cache. Returns true if an eviction occurred. 
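Every storer built above follows one recipe with the renamed storageunit package: translate the toml-backed config via GetDBFromConfig/GetCacherFromConfig, point FilePath at the path manager's output, then construct the unit. Condensed sketch of that repeated pattern (cfg stands for any config section with DB and Cache parts, as in this factory):

dbConf := factory.GetDBFromConfig(cfg.DB) // storageunit.DBConfig
dbConf.FilePath = pathManager.PathForStatic(shardID, cfg.DB.FilePath)

unit, err := storageunit.NewStorageUnitFromConf(
	factory.GetCacherFromConfig(cfg.Cache), // storageunit.CacheConfig
	dbConf,
)
if err != nil {
	return err
}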
-// the int parameter for size is not used as, for now, fifo sharded cache can not count for its contained data size -func (c *FIFOShardedCache) Put(key []byte, value interface{}, _ int) (evicted bool) { - c.cache.Set(string(key), value) - c.callAddedDataHandlers(key, value) - - return true -} - -// RegisterHandler registers a new handler to be called when a new data is added -func (c *FIFOShardedCache) RegisterHandler(handler func(key []byte, value interface{}), id string) { - if handler == nil { - log.Error("attempt to register a nil handler to a cacher object") - return - } - - c.mutAddedDataHandlers.Lock() - c.mapDataHandlers[id] = handler - c.mutAddedDataHandlers.Unlock() -} - -// UnRegisterHandler removes the handler from the list -func (c *FIFOShardedCache) UnRegisterHandler(id string) { - c.mutAddedDataHandlers.Lock() - delete(c.mapDataHandlers, id) - c.mutAddedDataHandlers.Unlock() -} - -// Get looks up a key's value from the cache. -func (c *FIFOShardedCache) Get(key []byte) (value interface{}, ok bool) { - return c.cache.Get(string(key)) -} - -// Has checks if a key is in the cache, without updating the -// recent-ness or deleting it for being stale. -func (c *FIFOShardedCache) Has(key []byte) bool { - return c.cache.Has(string(key)) -} - -// Peek returns the key value (or undefined if not found) without updating -// the "recently used"-ness of the key. -func (c *FIFOShardedCache) Peek(key []byte) (value interface{}, ok bool) { - return c.cache.Get(string(key)) -} - -// HasOrAdd checks if a key is in the cache without updating the -// recent-ness or deleting it for being stale, and if not, adds the value. -// Returns whether the item existed before and whether it has been added. -func (c *FIFOShardedCache) HasOrAdd(key []byte, value interface{}, _ int) (has, added bool) { - added = c.cache.SetIfAbsent(string(key), value) - - if added { - c.callAddedDataHandlers(key, value) - } - - return !added, added -} - -func (c *FIFOShardedCache) callAddedDataHandlers(key []byte, value interface{}) { - c.mutAddedDataHandlers.RLock() - for _, handler := range c.mapDataHandlers { - go handler(key, value) - } - c.mutAddedDataHandlers.RUnlock() -} - -// Remove removes the provided key from the cache. -func (c *FIFOShardedCache) Remove(key []byte) { - c.cache.Remove(string(key)) -} - -// Keys returns a slice of the keys in the cache, from oldest to newest. -func (c *FIFOShardedCache) Keys() [][]byte { - res := c.cache.Keys() - r := make([][]byte, len(res)) - - for i := 0; i < len(res); i++ { - r[i] = []byte(res[i]) - } - - return r -} - -// Len returns the number of items in the cache. -func (c *FIFOShardedCache) Len() int { - return c.cache.Count() -} - -// SizeInBytesContained returns 0 -func (c *FIFOShardedCache) SizeInBytesContained() uint64 { - return 0 -} - -// MaxSize returns the maximum number of items which can be stored in cache. 
-func (c *FIFOShardedCache) MaxSize() int { - return c.maxsize -} - -// Close does nothing for this cacher implementation -func (c *FIFOShardedCache) Close() error { - return nil -} - -// IsInterfaceNil returns true if there is no value under the interface -func (c *FIFOShardedCache) IsInterfaceNil() bool { - return c == nil -} diff --git a/storage/fifocache/fifocacheSharded_test.go b/storage/fifocache/fifocacheSharded_test.go deleted file mode 100644 index b358fa1009d..00000000000 --- a/storage/fifocache/fifocacheSharded_test.go +++ /dev/null @@ -1,375 +0,0 @@ -package fifocache_test - -import ( - "bytes" - "fmt" - "sync" - "testing" - "time" - - "github.com/ElrondNetwork/elrond-go/storage/fifocache" - "github.com/stretchr/testify/assert" -) - -var timeoutWaitForWaitGroups = time.Second * 2 - -func TestFIFOShardedCache_PutNotPresent(t *testing.T) { - key, val := []byte("key"), []byte("value") - c, err := fifocache.NewShardedCache(10, 2) - - assert.Nil(t, err, "no error expected but got %s", err) - - l := c.Len() - - assert.Zero(t, l, "cache expected to be empty") - - c.Put(key, val, 0) - l = c.Len() - - assert.Equal(t, l, 1, "cache size expected 1 but found %d", l) -} - -func TestFIFOShardedCache_PutPresent(t *testing.T) { - key, val := []byte("key"), []byte("value") - c, err := fifocache.NewShardedCache(10, 2) - - assert.Nil(t, err, "no error expected but got %s", err) - - c.Put(key, val, 0) - c.Put(key, val, 0) - - l := c.Len() - assert.Equal(t, l, 1, "cache size expected 1 but found %d", l) -} - -func TestFIFOShardedCache_PutPresentRewrite(t *testing.T) { - key := []byte("key") - val1 := []byte("value1") - val2 := []byte("value2") - c, err := fifocache.NewShardedCache(10, 2) - - assert.Nil(t, err, "no error expected but got %s", err) - - c.Put(key, val1, 0) - c.Put(key, val2, 0) - - l := c.Len() - assert.Equal(t, l, 1, "cache size expected 1 but found %d", l) - recoveredVal, has := c.Get(key) - assert.True(t, has) - assert.Equal(t, val2, recoveredVal) -} - -func TestFIFOShardedCache_GetNotPresent(t *testing.T) { - key := []byte("key1") - c, err := fifocache.NewShardedCache(10, 2) - - assert.Nil(t, err, "no error expected but got %s", err) - - v, ok := c.Get(key) - - assert.False(t, ok, "value %s not expected to be found", v) -} - -func TestFIFOShardedCache_GetPresent(t *testing.T) { - key, val := []byte("key2"), []byte("value2") - c, err := fifocache.NewShardedCache(10, 2) - - assert.Nil(t, err, "no error expected but got %s", err) - - c.Put(key, val, 0) - - v, ok := c.Get(key) - - assert.True(t, ok, "value expected but not found") - assert.Equal(t, val, v) -} - -func TestFIFOShardedCache_HasNotPresent(t *testing.T) { - key := []byte("key3") - c, err := fifocache.NewShardedCache(10, 2) - - assert.Nil(t, err, "no error expected but got %s", err) - - found := c.Has(key) - - assert.False(t, found, "key %s not expected to be found", key) -} - -func TestFIFOShardedCache_HasPresent(t *testing.T) { - key, val := []byte("key4"), []byte("value4") - c, err := fifocache.NewShardedCache(10, 2) - - assert.Nil(t, err, "no error expected but got %s", err) - - c.Put(key, val, 0) - - found := c.Has(key) - - assert.True(t, found, "value expected but not found") -} - -func TestFIFOShardedCache_PeekNotPresent(t *testing.T) { - key := []byte("key5") - c, err := fifocache.NewShardedCache(10, 2) - - assert.Nil(t, err, "no error expected but got %s", err) - - _, ok := c.Peek(key) - - assert.False(t, ok, "not expected to find key %s", key) -} - -func TestFIFOShardedCache_PeekPresent(t *testing.T) { - key, val 
:= []byte("key6"), []byte("value6") - c, err := fifocache.NewShardedCache(10, 2) - - assert.Nil(t, err, "no error expected but got %s", err) - - c.Put(key, val, 0) - v, ok := c.Peek(key) - - assert.True(t, ok, "value expected but not found") - assert.Equal(t, val, v, "expected to find %s but found %s", val, v) -} - -func TestFIFOShardedCache_HasOrAddNotPresent(t *testing.T) { - key, val := []byte("key7"), []byte("value7") - c, err := fifocache.NewShardedCache(10, 2) - - assert.Nil(t, err, "no error expected but got %s", err) - - _, ok := c.Peek(key) - - assert.False(t, ok, "not expected to find key %s", key) - - c.HasOrAdd(key, val, 0) - v, ok := c.Peek(key) - - assert.True(t, ok, "value expected but not found") - assert.Equal(t, val, v, "expected to find %s but found %s", val, v) -} - -func TestFIFOShardedCache_HasOrAddPresent(t *testing.T) { - key, val := []byte("key8"), []byte("value8") - c, err := fifocache.NewShardedCache(10, 2) - - assert.Nil(t, err, "no error expected but got %s", err) - - _, ok := c.Peek(key) - - assert.False(t, ok, "not expected to find key %s", key) - - c.HasOrAdd(key, val, 0) - v, ok := c.Peek(key) - - assert.True(t, ok, "value expected but not found") - assert.Equal(t, val, v, "expected to find %s but found %s", val, v) -} - -func TestFIFOShardedCache_RemoveNotPresent(t *testing.T) { - key := []byte("key9") - c, err := fifocache.NewShardedCache(10, 2) - - assert.Nil(t, err, "no error expected but got %s", err) - - found := c.Has(key) - - assert.False(t, found, "not expected to find key %s", key) - - c.Remove(key) - found = c.Has(key) - - assert.False(t, found, "not expected to find key %s", key) -} - -func TestFIFOShardedCache_RemovePresent(t *testing.T) { - key, val := []byte("key10"), []byte("value10") - c, err := fifocache.NewShardedCache(10, 2) - - assert.Nil(t, err, "no error expected but got %s", err) - - c.Put(key, val, 0) - found := c.Has(key) - - assert.True(t, found, "expected to find key %s", key) - - c.Remove(key) - found = c.Has(key) - - assert.False(t, found, "not expected to find key %s", key) -} - -func TestFIFOShardedCache_Keys(t *testing.T) { - c, err := fifocache.NewShardedCache(10, 2) - - assert.Nil(t, err, "no error expected but got %s", err) - - for i := 0; i < 20; i++ { - key, val := []byte(fmt.Sprintf("key%d", i)), []byte(fmt.Sprintf("value%d", i)) - c.Put(key, val, 0) - } - - keys := c.Keys() - - // check also that cache size does not grow over the capacity - assert.True(t, 10 >= len(keys), "expected up to 10 stored keys but current found %d", len(keys)) -} - -func TestFIFOShardedCache_Len(t *testing.T) { - c, err := fifocache.NewShardedCache(10, 2) - - assert.Nil(t, err, "no error expected but got %s", err) - - for i := 0; i < 20; i++ { - key, val := []byte(fmt.Sprintf("key%d", i)), []byte(fmt.Sprintf("value%d", i)) - c.Put(key, val, 0) - } - - l := c.Len() - - assert.True(t, 10 >= l, "expected up to 10 stored keys but current size %d", l) -} - -func TestFIFOShardedCache_Clear(t *testing.T) { - c, err := fifocache.NewShardedCache(10, 2) - - assert.Nil(t, err, "no error expected but got %s", err) - - for i := 0; i < 5; i++ { - key, val := []byte(fmt.Sprintf("key%d", i)), []byte(fmt.Sprintf("value%d", i)) - c.Put(key, val, 0) - } - - l := c.Len() - - assert.Equal(t, 5, l, "expected size 5, got %d", l) - - c.Clear() - l = c.Len() - - assert.Zero(t, l, "expected size 0, got %d", l) -} - -func TestFIFOShardedCache_CloseDoesNotErr(t *testing.T) { - t.Parallel() - - c, _ := fifocache.NewShardedCache(10, 2) - - err := c.Close() - assert.Nil(t, 
err) -} - -func TestFIFOShardedCache_CacherRegisterAddedDataHandlerNilHandlerShouldIgnore(t *testing.T) { - t.Parallel() - - c, err := fifocache.NewShardedCache(100, 2) - assert.Nil(t, err) - c.RegisterHandler(nil, "") - - assert.Equal(t, 0, len(c.AddedDataHandlers())) -} - -func TestFIFOShardedCache_CacherRegisterPutAddedDataHandlerShouldWork(t *testing.T) { - t.Parallel() - - wg := sync.WaitGroup{} - wg.Add(1) - chDone := make(chan bool) - - f := func(key []byte, value interface{}) { - if !bytes.Equal([]byte("aaaa"), key) { - return - } - - wg.Done() - } - - go func() { - wg.Wait() - chDone <- true - }() - - c, err := fifocache.NewShardedCache(100, 2) - assert.Nil(t, err) - c.RegisterHandler(f, "") - c.Put([]byte("aaaa"), "bbbb", 0) - - select { - case <-chDone: - case <-time.After(timeoutWaitForWaitGroups): - assert.Fail(t, "should have been called") - return - } - - assert.Equal(t, 1, len(c.AddedDataHandlers())) -} - -func TestFIFOShardedCache_CacherRegisterHasOrAddAddedDataHandlerShouldWork(t *testing.T) { - t.Parallel() - - wg := sync.WaitGroup{} - wg.Add(1) - chDone := make(chan bool) - - f := func(key []byte, value interface{}) { - if !bytes.Equal([]byte("aaaa"), key) { - return - } - - wg.Done() - } - - go func() { - wg.Wait() - chDone <- true - }() - - c, err := fifocache.NewShardedCache(100, 2) - assert.Nil(t, err) - c.RegisterHandler(f, "") - c.HasOrAdd([]byte("aaaa"), "bbbb", 0) - - select { - case <-chDone: - case <-time.After(timeoutWaitForWaitGroups): - assert.Fail(t, "should have been called") - return - } - - assert.Equal(t, 1, len(c.AddedDataHandlers())) -} - -func TestFIFOShardedCache_CacherRegisterHasOrAddAddedDataHandlerNotAddedShouldNotCall(t *testing.T) { - t.Parallel() - - wg := sync.WaitGroup{} - wg.Add(1) - chDone := make(chan bool) - - f := func(key []byte, value interface{}) { - wg.Done() - } - - go func() { - wg.Wait() - chDone <- true - }() - - c, err := fifocache.NewShardedCache(100, 2) - assert.Nil(t, err) - //first add, no call - c.HasOrAdd([]byte("aaaa"), "bbbb", 0) - c.RegisterHandler(f, "") - //second add, should not call as the data was found - c.HasOrAdd([]byte("aaaa"), "bbbb", 0) - - select { - case <-chDone: - assert.Fail(t, "should have not been called") - return - case <-time.After(timeoutWaitForWaitGroups): - } - - assert.Equal(t, 1, len(c.AddedDataHandlers())) -} diff --git a/storage/immunitycache/cache.go b/storage/immunitycache/cache.go deleted file mode 100644 index 5ad63c67a09..00000000000 --- a/storage/immunitycache/cache.go +++ /dev/null @@ -1,311 +0,0 @@ -package immunitycache - -import ( - "sync" - - "github.com/ElrondNetwork/elrond-go-core/core/atomic" - logger "github.com/ElrondNetwork/elrond-go-logger" - "github.com/ElrondNetwork/elrond-go/storage" -) - -var _ storage.Cacher = (*ImmunityCache)(nil) - -var log = logger.GetOrCreate("storage/immunitycache") - -const hospitalityWarnThreshold = -10000 -const hospitalityUpperLimit = 10000 - -// ImmunityCache is a cache-like structure -type ImmunityCache struct { - config CacheConfig - chunks []*immunityChunk - hospitality atomic.Counter - mutex sync.RWMutex -} - -// NewImmunityCache creates a new cache -func NewImmunityCache(config CacheConfig) (*ImmunityCache, error) { - log.Debug("NewImmunityCache", "config", config.String()) - storage.MonitorNewCache(config.Name, uint64(config.MaxNumBytes)) - - err := config.Verify() - if err != nil { - return nil, err - } - - cache := ImmunityCache{ - config: config, - } - - cache.initializeChunksWithLock() - return &cache, nil -} - -func (ic 
*ImmunityCache) initializeChunksWithLock() { - ic.mutex.Lock() - defer ic.mutex.Unlock() - - config := ic.config - chunkConfig := config.getChunkConfig() - - ic.chunks = make([]*immunityChunk, config.NumChunks) - for i := uint32(0); i < config.NumChunks; i++ { - ic.chunks[i] = newImmunityChunk(chunkConfig) - } -} - -// ImmunizeKeys marks items as immune to eviction -func (ic *ImmunityCache) ImmunizeKeys(keys [][]byte) (numNowTotal, numFutureTotal int) { - immuneItemsCapacityReached := ic.CountImmune()+len(keys) > int(ic.config.MaxNumItems) - if immuneItemsCapacityReached { - log.Warn("ImmunityCache.ImmunizeKeys(): will not immunize", "err", storage.ErrImmuneItemsCapacityReached) - return - } - - groups := ic.groupKeysByChunk(keys) - - for chunkIndex, chunkKeys := range groups { - chunk := ic.getChunkByIndexWithLock(chunkIndex) - - numNow, numFuture := chunk.ImmunizeKeys(chunkKeys) - numNowTotal += numNow - numFutureTotal += numFuture - } - - return -} - -func (ic *ImmunityCache) groupKeysByChunk(keys [][]byte) map[uint32][][]byte { - groups := make(map[uint32][][]byte) - - for _, key := range keys { - chunkIndex := ic.getChunkIndexByKey(string(key)) - groups[chunkIndex] = append(groups[chunkIndex], key) - } - - return groups -} - -func (ic *ImmunityCache) getChunkIndexByKey(key string) uint32 { - return fnv32Hash(key) % ic.config.NumChunks -} - -// fnv32Hash implements https://en.wikipedia.org/wiki/Fowler–Noll–Vo_hash_function for 32 bits -func fnv32Hash(key string) uint32 { - hash := uint32(2166136261) - const prime32 = uint32(16777619) - for i := 0; i < len(key); i++ { - hash *= prime32 - hash ^= uint32(key[i]) - } - return hash -} - -func (ic *ImmunityCache) getChunkByIndexWithLock(index uint32) *immunityChunk { - ic.mutex.RLock() - defer ic.mutex.RUnlock() - return ic.chunks[index] -} - -func (ic *ImmunityCache) getChunkByKeyWithLock(key string) *immunityChunk { - ic.mutex.RLock() - defer ic.mutex.RUnlock() - - chunkIndex := ic.getChunkIndexByKey(key) - return ic.chunks[chunkIndex] -} - -// Get gets an item (payload) by key -func (ic *ImmunityCache) Get(key []byte) (value interface{}, ok bool) { - item, ok := ic.getItem(key) - if ok { - return item.payload, true - } - - return nil, false -} - -// getItem gets an item by key -func (ic *ImmunityCache) getItem(key []byte) (*cacheItem, bool) { - chunk := ic.getChunkByKeyWithLock(string(key)) - return chunk.GetItem(string(key)) -} - -// Has checks if an item exists -func (ic *ImmunityCache) Has(key []byte) bool { - chunk := ic.getChunkByKeyWithLock(string(key)) - _, ok := chunk.GetItem(string(key)) - return ok -} - -// Peek gets an item -func (ic *ImmunityCache) Peek(key []byte) (value interface{}, ok bool) { - return ic.Get(key) -} - -// HasOrAdd adds an item to the cache -func (ic *ImmunityCache) HasOrAdd(key []byte, value interface{}, sizeInBytes int) (has, added bool) { - item := newCacheItem(value, string(key), sizeInBytes) - chunk := ic.getChunkByKeyWithLock(string(key)) - has, added = chunk.AddItem(item) - if !has { - if added { - ic.hospitality.Increment() - } else { - ic.hospitality.Decrement() - } - } - - return has, added -} - -// Put adds an item to the cache -func (ic *ImmunityCache) Put(key []byte, value interface{}, sizeInBytes int) (evicted bool) { - ic.HasOrAdd(key, value, sizeInBytes) - return false -} - -// Remove removes an item -func (ic *ImmunityCache) Remove(key []byte) { - _ = ic.RemoveWithResult(key) -} - -// RemoveWithResult removes an item -// TODO: In the future, add this method to the "storage.Cacher" interface.
EN-6739. -func (ic *ImmunityCache) RemoveWithResult(key []byte) bool { - chunk := ic.getChunkByKeyWithLock(string(key)) - return chunk.RemoveItem(string(key)) -} - -// RemoveOldest is not implemented -func (ic *ImmunityCache) RemoveOldest() { - log.Error("ImmunityCache.RemoveOldest is not implemented") -} - -// Clear clears the cache -func (ic *ImmunityCache) Clear() { - // There is no need to explicitly remove each item for each chunk - // The garbage collector will remove the data from memory - ic.initializeChunksWithLock() -} - -// MaxSize returns the capacity of the cache -func (ic *ImmunityCache) MaxSize() int { - return int(ic.config.MaxNumItems) -} - -// Len is an alias for Count -func (ic *ImmunityCache) Len() int { - return ic.Count() -} - -// SizeInBytesContained returns 0 -func (ic *ImmunityCache) SizeInBytesContained() uint64 { - return 0 -} - -// Count returns the number of elements within the cache -func (ic *ImmunityCache) Count() int { - count := 0 - for _, chunk := range ic.getChunksWithLock() { - count += chunk.Count() - } - return count -} - -func (ic *ImmunityCache) getChunksWithLock() []*immunityChunk { - ic.mutex.RLock() - defer ic.mutex.RUnlock() - return ic.chunks -} - -// CountImmune returns the number of immunized (current or future) elements within the cache -func (ic *ImmunityCache) CountImmune() int { - count := 0 - for _, chunk := range ic.getChunksWithLock() { - count += chunk.CountImmune() - } - return count -} - -// NumBytes estimates the size of the cache, in bytes -func (ic *ImmunityCache) NumBytes() int { - numBytes := 0 - for _, chunk := range ic.getChunksWithLock() { - numBytes += chunk.NumBytes() - } - return numBytes -} - -// Keys returns all keys -func (ic *ImmunityCache) Keys() [][]byte { - count := ic.Count() - keys := make([][]byte, 0, count) - - for _, chunk := range ic.getChunksWithLock() { - keys = chunk.AppendKeys(keys) - } - - return keys -} - -// RegisterHandler is not implemented -func (ic *ImmunityCache) RegisterHandler(func(key []byte, value interface{}), string) { - log.Error("ImmunityCache.RegisterHandler is not implemented") -} - -// UnRegisterHandler is not implemented -func (ic *ImmunityCache) UnRegisterHandler(_ string) { - log.Error("ImmunityCache.UnRegisterHandler is not implemented") -} - -// ForEachItem iterates over the items in the cache -func (ic *ImmunityCache) ForEachItem(function storage.ForEachItem) { - for _, chunk := range ic.getChunksWithLock() { - chunk.ForEachItem(function) - } -} - -// Diagnose displays a summary of the internal state of the cache -func (ic *ImmunityCache) Diagnose(_ bool) { - count := ic.Count() - countImmune := ic.CountImmune() - numBytes := ic.NumBytes() - hospitality := ic.hospitality.Get() - - isNotHospitable := hospitality <= hospitalityWarnThreshold - if isNotHospitable { - // After emitting a Warn, we reset the hospitality indicator - log.Warn("ImmunityCache.Diagnose(): cache is not hospitable", - "name", ic.config.Name, - "count", count, - "countImmune", countImmune, - "numBytes", numBytes, - "hospitality", hospitality, - ) - ic.hospitality.Reset() - return - } - - if hospitality >= hospitalityUpperLimit { - ic.hospitality.Set(hospitalityUpperLimit) - } - - log.Trace("ImmunityCache.Diagnose()", - "name", ic.config.Name, - "count", count, - "countImmune", countImmune, - "numBytes", numBytes, - "hospitality", hospitality, - ) -} - -// Close does nothing for this cacher implementation -func (ic *ImmunityCache) Close() error { - return nil -} - -// IsInterfaceNil returns true if there
is no value under the interface -func (ic *ImmunityCache) IsInterfaceNil() bool { - return ic == nil -} diff --git a/storage/immunitycache/cacheItem.go b/storage/immunitycache/cacheItem.go deleted file mode 100644 index 59c194c399f..00000000000 --- a/storage/immunitycache/cacheItem.go +++ /dev/null @@ -1,28 +0,0 @@ -package immunitycache - -import ( - "github.com/ElrondNetwork/elrond-go-core/core/atomic" -) - -type cacheItem struct { - payload interface{} - key string - size int - isImmune atomic.Flag -} - -func newCacheItem(payload interface{}, key string, size int) *cacheItem { - return &cacheItem{ - payload: payload, - key: key, - size: size, - } -} - -func (item *cacheItem) isImmuneToEviction() bool { - return item.isImmune.IsSet() -} - -func (item *cacheItem) immunizeAgainstEviction() { - _ = item.isImmune.SetReturningPrevious() -} diff --git a/storage/immunitycache/cache_test.go b/storage/immunitycache/cache_test.go deleted file mode 100644 index f998386c9bf..00000000000 --- a/storage/immunitycache/cache_test.go +++ /dev/null @@ -1,331 +0,0 @@ -package immunitycache - -import ( - "errors" - "fmt" - "math" - "sync" - "testing" - - "github.com/ElrondNetwork/elrond-go/storage" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestNewImmunityCache(t *testing.T) { - config := CacheConfig{ - Name: "test", - NumChunks: 16, - MaxNumItems: math.MaxUint32, - MaxNumBytes: maxNumBytesUpperBound, - NumItemsToPreemptivelyEvict: 100, - } - - cache, err := NewImmunityCache(config) - require.Nil(t, err) - require.NotNil(t, cache) - require.Equal(t, math.MaxUint32, cache.MaxSize()) - - invalidConfig := config - invalidConfig.Name = "" - requireErrorOnNewCache(t, invalidConfig, storage.ErrInvalidConfig, "config.Name") - - invalidConfig = config - invalidConfig.NumChunks = 0 - requireErrorOnNewCache(t, invalidConfig, storage.ErrInvalidConfig, "config.NumChunks") - - invalidConfig = config - invalidConfig.MaxNumItems = 0 - requireErrorOnNewCache(t, invalidConfig, storage.ErrInvalidConfig, "config.MaxNumItems") - - invalidConfig = config - invalidConfig.MaxNumBytes = 0 - requireErrorOnNewCache(t, invalidConfig, storage.ErrInvalidConfig, "config.MaxNumBytes") - - invalidConfig = config - invalidConfig.NumItemsToPreemptivelyEvict = 0 - requireErrorOnNewCache(t, invalidConfig, storage.ErrInvalidConfig, "config.NumItemsToPreemptivelyEvict") -} - -func requireErrorOnNewCache(t *testing.T, config CacheConfig, errExpected error, errPartialMessage string) { - cache, errReceived := NewImmunityCache(config) - require.Nil(t, cache) - require.True(t, errors.Is(errReceived, errExpected)) - require.Contains(t, errReceived.Error(), errPartialMessage) -} - -func TestImmunityCache_ImmunizeAgainstEviction(t *testing.T) { - cache := newCacheToTest(1, 8, maxNumBytesUpperBound) - - cache.addTestItems("a", "b", "c", "d") - numNow, numFuture := cache.ImmunizeKeys(keysAsBytes([]string{"a", "b", "e", "f"})) - require.Equal(t, 2, numNow) - require.Equal(t, 2, numFuture) - require.Equal(t, 4, cache.Len()) - require.Equal(t, 4, cache.CountImmune()) - - cache.addTestItems("e", "f", "g", "h") - require.ElementsMatch(t, []string{"a", "b", "c", "d", "e", "f", "g", "h"}, keysAsStrings(cache.Keys())) - - cache.addTestItems("i", "j", "k", "l") - require.ElementsMatch(t, []string{"a", "b", "e", "f", "i", "j", "k", "l"}, keysAsStrings(cache.Keys())) - - require.Equal(t, 4, cache.CountImmune()) - cache.Remove([]byte("e")) - cache.Remove([]byte("f")) - require.Equal(t, 2, cache.CountImmune()) -} - -func 
TestImmunityCache_ImmunizeDoesNothingIfCapacityReached(t *testing.T) { - cache := newCacheToTest(1, 4, maxNumBytesUpperBound) - - numNow, numFuture := cache.ImmunizeKeys(keysAsBytes([]string{"a", "b", "c", "d"})) - require.Equal(t, 0, numNow) - require.Equal(t, 4, numFuture) - require.Equal(t, 4, cache.CountImmune()) - - numNow, numFuture = cache.ImmunizeKeys(keysAsBytes([]string{"e", "f", "g", "h"})) - require.Equal(t, 0, numNow) - require.Equal(t, 0, numFuture) - require.Equal(t, 4, cache.CountImmune()) -} - -func TestImmunityCache_AddThenRemove(t *testing.T) { - cache := newCacheToTest(1, 8, maxNumBytesUpperBound) - - _, _ = cache.HasOrAdd([]byte("a"), "foo-a", 1) - _, _ = cache.HasOrAdd([]byte("b"), "foo-b", 1) - _, _ = cache.HasOrAdd([]byte("c"), "foo-c", 0) - _ = cache.Put([]byte("d"), "foo-d", 0) // Same as HasOrAdd() - require.Equal(t, 4, cache.Len()) - require.Equal(t, 4, int(cache.hospitality.Get())) - require.True(t, cache.Has([]byte("a"))) - require.True(t, cache.Has([]byte("c"))) - - // Duplicates are not added - _, added := cache.HasOrAdd([]byte("a"), "foo-a", 1) - require.False(t, added) - - // Won't remove if not exists - ok := cache.RemoveWithResult([]byte("x")) - require.False(t, ok) - - cache.Remove([]byte("a")) - cache.Remove([]byte("c")) - require.Equal(t, 2, cache.Len()) - require.False(t, cache.Has([]byte("a"))) - require.False(t, cache.Has([]byte("c"))) - - cache.addTestItems("e", "f", "g", "h", "i", "j") - require.Equal(t, 8, cache.Len()) - - // Now eviction takes place - cache.addTestItems("k", "l", "m", "n") - require.Equal(t, 8, cache.Len()) - require.ElementsMatch(t, []string{"g", "h", "i", "j", "k", "l", "m", "n"}, keysAsStrings(cache.Keys())) - - cache.Clear() - require.Equal(t, 0, cache.Len()) -} - -func TestImmunityCache_Get(t *testing.T) { - cache := newCacheToTest(1, 8, maxNumBytesUpperBound) - - a := "foo-a" - b := "foo-b" - _, added := cache.HasOrAdd([]byte("a"), a, 1) - require.True(t, added) - _, added = cache.HasOrAdd([]byte("b"), b, 1) - require.True(t, added) - - item, ok := cache.Get([]byte("a")) - require.True(t, ok) - require.Equal(t, a, item) - - itemAsEmptyInterface, ok := cache.Get([]byte("a")) - require.True(t, ok) - require.Equal(t, a, itemAsEmptyInterface) - - itemAsEmptyInterface, ok = cache.Peek([]byte("a")) - require.True(t, ok) - require.Equal(t, a, itemAsEmptyInterface) - - item, ok = cache.Get([]byte("b")) - require.True(t, ok) - require.Equal(t, b, item) - - itemAsEmptyInterface, ok = cache.Get([]byte("b")) - require.True(t, ok) - require.Equal(t, b, itemAsEmptyInterface) - - itemAsEmptyInterface, ok = cache.Peek([]byte("b")) - require.True(t, ok) - require.Equal(t, b, itemAsEmptyInterface) - - item, ok = cache.Get([]byte("c")) - require.False(t, ok) - require.Nil(t, item) - - itemAsEmptyInterface, ok = cache.Get([]byte("c")) - require.False(t, ok) - require.Nil(t, itemAsEmptyInterface) - - itemAsEmptyInterface, ok = cache.Peek([]byte("c")) - require.False(t, ok) - require.Nil(t, itemAsEmptyInterface) -} - -func TestImmunityCache_AddThenRemove_ChangesNumBytes(t *testing.T) { - cache := newCacheToTest(1, 8, 1000) - - _, _ = cache.HasOrAdd([]byte("a"), "foo-a", 100) - _, _ = cache.HasOrAdd([]byte("b"), "foo-b", 300) - require.Equal(t, 400, cache.NumBytes()) - - _, _ = cache.HasOrAdd([]byte("c"), "foo-c", 400) - _, _ = cache.HasOrAdd([]byte("d"), "foo-d", 200) - require.Equal(t, 1000, cache.NumBytes()) - - // Eviction takes place - _, _ = cache.HasOrAdd([]byte("e"), "foo-e", 500) - // Edge case, added item overflows. 
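// Concretely, in the assertions below: only "a" (100 bytes) is evicted before "e" (500 bytes) is admitted, so the cache holds 300 + 400 + 200 + 500 = 1400 bytes against a 1000-byte limit.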
- // Should not be an issue in practice, when we preemptively evict a large number of items. - require.Equal(t, 1400, cache.NumBytes()) - require.ElementsMatch(t, []string{"b", "c", "d", "e"}, keysAsStrings(cache.Keys())) - - // "b" and "c" (300 + 400) will be evicted - _, _ = cache.HasOrAdd([]byte("f"), "foo-f", 400) - require.Equal(t, 1100, cache.NumBytes()) -} - -func TestImmunityCache_AddDoesNotWork_WhenFullWithImmune(t *testing.T) { - cache := newCacheToTest(1, 4, 1000) - - cache.addTestItems("a", "b", "c", "d") - numNow, numFuture := cache.ImmunizeKeys(keysAsBytes([]string{"a", "b", "c", "d"})) - require.Equal(t, 4, numNow) - require.Equal(t, 0, numFuture) - require.Equal(t, 4, int(cache.hospitality.Get())) - - _, added := cache.HasOrAdd([]byte("x"), "foo-x", 1) - require.False(t, added) - require.False(t, cache.Has([]byte("x"))) - require.Equal(t, 3, int(cache.hospitality.Get())) -} - -func TestImmunityCache_ForEachItem(t *testing.T) { - cache := newCacheToTest(1, 4, 1000) - - keys := make([]string, 0) - cache.addTestItems("a", "b", "c", "d") - cache.ForEachItem(func(key []byte, value interface{}) { - keys = append(keys, string(key)) - }) - - require.ElementsMatch(t, []string{"a", "b", "c", "d"}, keys) -} - -// This information about (hash to chunk) distribution is useful to write tests -func TestImmunityCache_Fnv32Hash(t *testing.T) { - // Cache with 2 chunks - require.Equal(t, 0, int(fnv32Hash("a")%2)) - require.Equal(t, 1, int(fnv32Hash("b")%2)) - require.Equal(t, 0, int(fnv32Hash("c")%2)) - require.Equal(t, 1, int(fnv32Hash("d")%2)) - - // Cache with 4 chunks - require.Equal(t, 2, int(fnv32Hash("a")%4)) - require.Equal(t, 1, int(fnv32Hash("b")%4)) - require.Equal(t, 0, int(fnv32Hash("c")%4)) - require.Equal(t, 3, int(fnv32Hash("d")%4)) -} - -func TestImmunityCache_DiagnoseAppliesLimitToHospitality(t *testing.T) { - cache := newCacheToTest(1, hospitalityUpperLimit*42, 1000) - - for i := 0; i < hospitalityUpperLimit*2; i++ { - cache.addTestItems(fmt.Sprintf("%d", i)) - require.Equal(t, i+1, int(cache.hospitality.Get())) - } - - require.Equal(t, hospitalityUpperLimit*2, int(cache.hospitality.Get())) - cache.Diagnose(false) - require.Equal(t, hospitalityUpperLimit, int(cache.hospitality.Get())) -} - -func TestImmunityCache_DiagnoseResetsHospitalityAfterWarn(t *testing.T) { - cache := newCacheToTest(1, 4, 1000) - cache.addTestItems("a", "b", "c", "d") - _, _ = cache.ImmunizeKeys(keysAsBytes([]string{"a", "b", "c", "d"})) - require.Equal(t, 4, int(cache.hospitality.Get())) - - cache.addTestItems("e", "f", "g", "h") - require.Equal(t, 0, int(cache.hospitality.Get())) - - for i := -1; i > hospitalityWarnThreshold; i-- { - cache.addTestItems("foo") - require.Equal(t, i, int(cache.hospitality.Get())) - } - - require.Equal(t, hospitalityWarnThreshold+1, int(cache.hospitality.Get())) - cache.Diagnose(false) - require.Equal(t, hospitalityWarnThreshold+1, int(cache.hospitality.Get())) - cache.addTestItems("foo") - require.Equal(t, hospitalityWarnThreshold, int(cache.hospitality.Get())) - cache.Diagnose(false) - require.Equal(t, 0, int(cache.hospitality.Get())) -} - -func TestImmunityCache_ClearConcurrentWithRangeOverChunks(t *testing.T) { - cache := newCacheToTest(16, 4, 1000) - require.Equal(t, 16, len(cache.chunks)) - - var wg sync.WaitGroup - wg.Add(2) - - go func() { - for i := 0; i < 10000; i++ { - cache.Clear() - } - wg.Done() - }() - - go func() { - for i := 0; i < 10000; i++ { - for _, chunk := range cache.getChunksWithLock() { - assert.Equal(t, 0, chunk.Count()) - } - } - wg.Done() - 
}() - - wg.Wait() -} - -func TestImmunityCache_CloseDoesNotErr(t *testing.T) { - cache := newCacheToTest(1, 4, 1000) - - err := cache.Close() - assert.Nil(t, err) -} - -func newCacheToTest(numChunks uint32, maxNumItems uint32, numMaxBytes uint32) *ImmunityCache { - cache, err := NewImmunityCache(CacheConfig{ - Name: "test", - NumChunks: numChunks, - MaxNumItems: maxNumItems, - MaxNumBytes: numMaxBytes, - NumItemsToPreemptivelyEvict: numChunks * 1, - }) - if err != nil { - panic(fmt.Sprintf("newCacheToTest(): %s", err)) - } - - return cache -} - -func (ic *ImmunityCache) addTestItems(keys ...string) { - for _, key := range keys { - _, _ = ic.HasOrAdd([]byte(key), fmt.Sprintf("foo-%s", key), 100) - } -} diff --git a/storage/immunitycache/chunk.go b/storage/immunitycache/chunk.go deleted file mode 100644 index 14d1aeda8d0..00000000000 --- a/storage/immunitycache/chunk.go +++ /dev/null @@ -1,283 +0,0 @@ -package immunitycache - -import ( - "container/list" - "sync" - - "github.com/ElrondNetwork/elrond-go-core/core" - "github.com/ElrondNetwork/elrond-go/storage" -) - -var emptyStruct struct{} - -type immunityChunk struct { - config immunityChunkConfig - items map[string]chunkItemWrapper - itemsAsList *list.List - immuneKeys map[string]struct{} - numBytes int - mutex sync.RWMutex -} - -type chunkItemWrapper struct { - item *cacheItem - listElement *list.Element -} - -func newImmunityChunk(config immunityChunkConfig) *immunityChunk { - log.Trace("newImmunityChunk", "config", config.String()) - - return &immunityChunk{ - config: config, - items: make(map[string]chunkItemWrapper), - itemsAsList: list.New(), - immuneKeys: make(map[string]struct{}), - } -} - -// ImmunizeKeys marks keys as immune to eviction -func (chunk *immunityChunk) ImmunizeKeys(keys [][]byte) (numNow, numFuture int) { - chunk.mutex.Lock() - defer chunk.mutex.Unlock() - - for _, key := range keys { - item, ok := chunk.getItemNoLock(string(key)) - - if ok { - // Item exists, immunize now! 
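// Once this flag is set, removeOldestNoLock (further below) skips the item during eviction scans.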
- item.immunizeAgainstEviction() - numNow++ - } else { - // Item not yet in cache, will be immunized in the future - numFuture++ - } - - // Regardless of the item's presence, we record the immune key - chunk.immuneKeys[string(key)] = emptyStruct - } - - return -} - -func (chunk *immunityChunk) getItemNoLock(key string) (*cacheItem, bool) { - wrapper, ok := chunk.items[key] - if !ok { - return nil, false - } - - return wrapper.item, true -} - -// AddItem adds an item to the chunk -func (chunk *immunityChunk) AddItem(item *cacheItem) (has, added bool) { - chunk.mutex.Lock() - defer chunk.mutex.Unlock() - - err := chunk.evictItemsIfCapacityExceededNoLock() - if err != nil { - // No more room for the new item - return false, false - } - - // Discard duplicates - if chunk.itemExistsNoLock(item) { - return true, false - } - - chunk.addItemNoLock(item) - chunk.immunizeItemOnAddNoLock(item) - chunk.trackNumBytesOnAddNoLock(item) - return false, true -} - -func (chunk *immunityChunk) evictItemsIfCapacityExceededNoLock() error { - if !chunk.isCapacityExceededNoLock() { - return nil - } - - numRemoved, err := chunk.evictItemsNoLock() - chunk.monitorEvictionNoLock(numRemoved, err) - return err -} - -func (chunk *immunityChunk) isCapacityExceededNoLock() bool { - tooManyItems := len(chunk.items) >= int(chunk.config.maxNumItems) - tooManyBytes := chunk.numBytes >= int(chunk.config.maxNumBytes) - return tooManyItems || tooManyBytes -} - -func (chunk *immunityChunk) evictItemsNoLock() (numRemoved int, err error) { - numToRemoveEachStep := int(chunk.config.numItemsToPreemptivelyEvict) - - // We perform the first step outside the loop, in order to detect & return any error - numRemovedInStep := chunk.removeOldestNoLock(numToRemoveEachStep) - numRemoved += numRemovedInStep - - if numRemovedInStep == 0 { - return 0, storage.ErrFailedCacheEviction - } - - for chunk.isCapacityExceededNoLock() && numRemovedInStep == numToRemoveEachStep { - numRemovedInStep = chunk.removeOldestNoLock(numToRemoveEachStep) - numRemoved += numRemovedInStep - } - - return numRemoved, nil -} - -func (chunk *immunityChunk) removeOldestNoLock(numToRemove int) int { - numRemoved := 0 - element := chunk.itemsAsList.Front() - - for element != nil && numRemoved < numToRemove { - item := element.Value.(*cacheItem) - - if item.isImmuneToEviction() { - element = element.Next() - continue - } - - elementToRemove := element - element = element.Next() - - chunk.removeNoLock(elementToRemove) - numRemoved++ - } - - return numRemoved -} - -func (chunk *immunityChunk) removeNoLock(element *list.Element) { - item := element.Value.(*cacheItem) - delete(chunk.items, item.key) - chunk.itemsAsList.Remove(element) - chunk.trackNumBytesOnRemoveNoLock(item) -} - -func (chunk *immunityChunk) monitorEvictionNoLock(numRemoved int, err error) { - if err != nil { - log.Trace("immunityChunk.monitorEviction()", "name", chunk.config.cacheName, "numRemoved", numRemoved, "err", err) - } -} - -func (chunk *immunityChunk) itemExistsNoLock(item *cacheItem) bool { - _, exists := chunk.items[item.key] - return exists -} - -// First, we insert (append) in the linked list; then in the map. -// In the map, we also need to hold a reference to the list element, to have O(1) removal.
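// This is the classic map-plus-linked-list cache layout: the list preserves insertion order for eviction, while keeping the list element next to the item lets removeNoLock unlink it in constant time.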
-func (chunk *immunityChunk) addItemNoLock(item *cacheItem) { - element := chunk.itemsAsList.PushBack(item) - chunk.items[item.key] = chunkItemWrapper{item: item, listElement: element} -} - -func (chunk *immunityChunk) immunizeItemOnAddNoLock(item *cacheItem) { - if _, immunize := chunk.immuneKeys[item.key]; immunize { - item.immunizeAgainstEviction() - // We do not remove the key from "immuneKeys", we hold it there until item's removal. - } -} - -func (chunk *immunityChunk) trackNumBytesOnAddNoLock(item *cacheItem) { - chunk.numBytes += item.size -} - -// GetItem gets an item from the chunk -func (chunk *immunityChunk) GetItem(key string) (*cacheItem, bool) { - chunk.mutex.RLock() - defer chunk.mutex.RUnlock() - return chunk.getItemNoLock(key) -} - -// RemoveItem removes an item from the chunk -// In order to improve the robustness of the cache, we'll also remove from "keysToImmunizeFuture", -// even if the item does not actually exist in the cache - to allow un-doing immunization intent (perhaps useful for rollbacks). -func (chunk *immunityChunk) RemoveItem(key string) bool { - chunk.mutex.Lock() - defer chunk.mutex.Unlock() - - delete(chunk.immuneKeys, key) - - wrapper, ok := chunk.items[key] - if !ok { - return false - } - - chunk.removeNoLock(wrapper.listElement) - return true -} - -func (chunk *immunityChunk) trackNumBytesOnRemoveNoLock(item *cacheItem) { - chunk.numBytes -= item.size - chunk.numBytes = core.MaxInt(chunk.numBytes, 0) -} - -// RemoveOldest removes a number of old items -func (chunk *immunityChunk) RemoveOldest(numToRemove int) int { - chunk.mutex.Lock() - defer chunk.mutex.Unlock() - return chunk.removeOldestNoLock(numToRemove) -} - -// Count counts the items -func (chunk *immunityChunk) Count() int { - chunk.mutex.RLock() - defer chunk.mutex.RUnlock() - return len(chunk.items) -} - -// CountImmune counts the immune items -func (chunk *immunityChunk) CountImmune() int { - chunk.mutex.RLock() - defer chunk.mutex.RUnlock() - return len(chunk.immuneKeys) -} - -// NumBytes gets the number of bytes stored -func (chunk *immunityChunk) NumBytes() int { - chunk.mutex.RLock() - defer chunk.mutex.RUnlock() - return chunk.numBytes -} - -// KeysInOrder gets the keys, in order -func (chunk *immunityChunk) KeysInOrder() [][]byte { - chunk.mutex.RLock() - defer chunk.mutex.RUnlock() - - keys := make([][]byte, 0, chunk.itemsAsList.Len()) - for element := chunk.itemsAsList.Front(); element != nil; element = element.Next() { - item := element.Value.(*cacheItem) - keys = append(keys, []byte(item.key)) - } - - return keys -} - -// AppendKeys accumulates keys in a given slice -func (chunk *immunityChunk) AppendKeys(keysAccumulator [][]byte) [][]byte { - chunk.mutex.RLock() - defer chunk.mutex.RUnlock() - - for key := range chunk.items { - keysAccumulator = append(keysAccumulator, []byte(key)) - } - - return keysAccumulator -} - -// ForEachItem iterates over the items in the chunk -func (chunk *immunityChunk) ForEachItem(function storage.ForEachItem) { - chunk.mutex.RLock() - defer chunk.mutex.RUnlock() - - for key, itemWrapper := range chunk.items { - function([]byte(key), itemWrapper.item.payload) - } -} - -// IsInterfaceNil returns true if there is no value under the interface -func (chunk *immunityChunk) IsInterfaceNil() bool { - return chunk == nil -} diff --git a/storage/immunitycache/chunk_test.go b/storage/immunitycache/chunk_test.go deleted file mode 100644 index 1141242ee6a..00000000000 --- a/storage/immunitycache/chunk_test.go +++ /dev/null @@ -1,103 +0,0 @@ -package 
immunitycache - -import ( - "math" - "testing" - - "github.com/stretchr/testify/require" -) - -func TestImmunityChunk_ImmunizeKeys(t *testing.T) { - chunk := newUnconstrainedChunkToTest() - - chunk.addTestItems("x", "y", "z") - require.Equal(t, 3, chunk.Count()) - - // No immune items, all removed - numRemoved := chunk.RemoveOldest(42) - require.Equal(t, 3, numRemoved) - require.Equal(t, 0, chunk.Count()) - - chunk.addTestItems("x", "y", "z") - require.Equal(t, 3, chunk.Count()) - - // Immunize some items - numNow, numFuture := chunk.ImmunizeKeys(keysAsBytes([]string{"x", "z"})) - require.Equal(t, 2, numNow) - require.Equal(t, 0, numFuture) - - numRemoved = chunk.RemoveOldest(42) - require.Equal(t, 1, numRemoved) - require.Equal(t, 2, chunk.Count()) - require.Equal(t, []string{"x", "z"}, keysAsStrings(chunk.KeysInOrder())) -} - -func TestImmunityChunk_AddItemIgnoresDuplicates(t *testing.T) { - chunk := newUnconstrainedChunkToTest() - chunk.addTestItems("x", "y", "z") - require.Equal(t, 3, chunk.Count()) - - has, added := chunk.AddItem(newCacheItem("foo", "a", 1)) - require.False(t, has) - require.True(t, added) - require.Equal(t, 4, chunk.Count()) - - has, added = chunk.AddItem(newCacheItem("bar", "x", 1)) - require.True(t, has) - require.False(t, added) - require.Equal(t, 4, chunk.Count()) -} - -func TestImmunityChunk_AddItemEvictsWhenTooMany(t *testing.T) { - chunk := newChunkToTest(3, math.MaxUint32) - chunk.addTestItems("x", "y", "z") - require.Equal(t, 3, chunk.Count()) - - chunk.addTestItems("a", "b") - require.Equal(t, []string{"z", "a", "b"}, keysAsStrings(chunk.KeysInOrder())) -} - -func TestImmunityChunk_AddItemDoesNotEvictImmuneItems(t *testing.T) { - chunk := newChunkToTest(3, math.MaxUint32) - chunk.addTestItems("x", "y", "z") - require.Equal(t, 3, chunk.Count()) - - _, _ = chunk.ImmunizeKeys(keysAsBytes([]string{"x", "y"})) - - chunk.addTestItems("a") - require.Equal(t, []string{"x", "y", "a"}, keysAsStrings(chunk.KeysInOrder())) - chunk.addTestItems("b") - require.Equal(t, []string{"x", "y", "b"}, keysAsStrings(chunk.KeysInOrder())) - - _, _ = chunk.ImmunizeKeys(keysAsBytes([]string{"b"})) - has, added := chunk.AddItem(newCacheItem("foo", "c", 1)) - require.False(t, has) - require.False(t, added) - require.Equal(t, []string{"x", "y", "b"}, keysAsStrings(chunk.KeysInOrder())) -} - -func newUnconstrainedChunkToTest() *immunityChunk { - chunk := newImmunityChunk(immunityChunkConfig{ - maxNumItems: math.MaxUint32, - maxNumBytes: maxNumBytesUpperBound, - numItemsToPreemptivelyEvict: math.MaxUint32, - }) - - return chunk -} - -func newChunkToTest(maxNumItems uint32, numMaxBytes uint32) *immunityChunk { - chunk := newImmunityChunk(immunityChunkConfig{ - maxNumItems: maxNumItems, - maxNumBytes: numMaxBytes, - numItemsToPreemptivelyEvict: 1, - }) - - return chunk -} - -func (chunk *immunityChunk) addTestItems(keys ...string) { - for _, key := range keys { - _, _ = chunk.AddItem(newCacheItem("foo", key, 100)) - } -} diff --git a/storage/immunitycache/config.go b/storage/immunitycache/config.go deleted file mode 100644 index 3c3672eeade..00000000000 --- a/storage/immunitycache/config.go +++ /dev/null @@ -1,84 +0,0 @@ -package immunitycache - -import ( - "encoding/json" - "fmt" - - "github.com/ElrondNetwork/elrond-go-core/core" - "github.com/ElrondNetwork/elrond-go/storage" -) - -const numChunksLowerBound = 1 -const numChunksUpperBound = 128 -const maxNumItemsLowerBound = 4 -const maxNumBytesLowerBound = maxNumItemsLowerBound * 1 -const maxNumBytesUpperBound = 1_073_741_824 // one GB 
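// For orientation, a self-contained sketch of how keys are routed to chunks, re-implementing the
// 32-bit FNV-1 hash used by getChunkIndexByKey (above) and cross-checking it against the standard
// library; the chunk count of 4 is a hypothetical value within the [numChunksLowerBound,
// numChunksUpperBound] range defined above:

package main

import (
	"fmt"
	"hash/fnv"
)

// fnv32 mirrors the deleted fnv32Hash: FNV-1, i.e. multiply by the prime, then XOR each byte.
func fnv32(key string) uint32 {
	hash := uint32(2166136261)
	const prime32 = 16777619
	for i := 0; i < len(key); i++ {
		hash *= prime32
		hash ^= uint32(key[i])
	}
	return hash
}

func main() {
	const numChunks = 4

	for _, key := range []string{"a", "b", "c", "d"} {
		// hash/fnv's New32 is also FNV-1, so both computations must agree.
		h := fnv.New32()
		_, _ = h.Write([]byte(key))
		if h.Sum32() != fnv32(key) {
			panic("hash mismatch")
		}
		fmt.Printf("%s -> chunk %d\n", key, fnv32(key)%numChunks)
	}
	// Prints: a -> chunk 2, b -> chunk 1, c -> chunk 0, d -> chunk 3,
	// matching TestImmunityCache_Fnv32Hash above.
}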
-const numItemsToPreemptivelyEvictLowerBound = 1 - -// CacheConfig holds cache configuration -type CacheConfig struct { - Name string - NumChunks uint32 - MaxNumItems uint32 - MaxNumBytes uint32 - NumItemsToPreemptivelyEvict uint32 -} - -// Verify verifies the validity of the configuration -func (config *CacheConfig) Verify() error { - if len(config.Name) == 0 { - return fmt.Errorf("%w: config.Name is invalid", storage.ErrInvalidConfig) - } - if config.NumChunks < numChunksLowerBound || config.NumChunks > numChunksUpperBound { - return fmt.Errorf("%w: config.NumChunks is invalid", storage.ErrInvalidConfig) - } - if config.MaxNumItems < maxNumItemsLowerBound { - return fmt.Errorf("%w: config.MaxNumItems is invalid", storage.ErrInvalidConfig) - } - if config.MaxNumBytes < maxNumBytesLowerBound || config.MaxNumBytes > maxNumBytesUpperBound { - return fmt.Errorf("%w: config.MaxNumBytes is invalid", storage.ErrInvalidConfig) - } - if config.NumItemsToPreemptivelyEvict < numItemsToPreemptivelyEvictLowerBound { - return fmt.Errorf("%w: config.NumItemsToPreemptivelyEvict is invalid", storage.ErrInvalidConfig) - } - - return nil -} - -func (config *CacheConfig) getChunkConfig() immunityChunkConfig { - numChunks := core.MaxUint32(config.NumChunks, 1) - - return immunityChunkConfig{ - cacheName: config.Name, - maxNumItems: config.MaxNumItems / numChunks, - maxNumBytes: config.MaxNumBytes / numChunks, - numItemsToPreemptivelyEvict: config.NumItemsToPreemptivelyEvict / numChunks, - } -} - -// String returns a readable representation of the object -func (config *CacheConfig) String() string { - bytes, err := json.Marshal(config) - if err != nil { - log.Error("CacheConfig.String()", "err", err) - } - - return string(bytes) -} - -type immunityChunkConfig struct { - cacheName string - maxNumItems uint32 - maxNumBytes uint32 - numItemsToPreemptivelyEvict uint32 -} - -// String returns a readable representation of the object -func (config *immunityChunkConfig) String() string { - return fmt.Sprintf( - "maxNumItems: %d, maxNumBytes: %d, numItemsToPreemptivelyEvict: %d", - config.maxNumItems, - config.maxNumBytes, - config.numItemsToPreemptivelyEvict, - ) -} diff --git a/storage/immunitycache/testutils_test.go b/storage/immunitycache/testutils_test.go deleted file mode 100644 index 30b61834961..00000000000 --- a/storage/immunitycache/testutils_test.go +++ /dev/null @@ -1,19 +0,0 @@ -package immunitycache - -func keysAsStrings(keys [][]byte) []string { - result := make([]string, len(keys)) - for i := 0; i < len(keys); i++ { - result[i] = string(keys[i]) - } - - return result -} - -func keysAsBytes(keys []string) [][]byte { - result := make([][]byte, len(keys)) - for i := 0; i < len(keys); i++ { - result[i] = []byte(keys[i]) - } - - return result -} diff --git a/storage/interface.go b/storage/interface.go index 48d68eb9b01..11bbda51caf 100644 --- a/storage/interface.go +++ b/storage/interface.go @@ -4,47 +4,10 @@ import ( "time" "github.com/ElrondNetwork/elrond-go-core/storage" + "github.com/ElrondNetwork/elrond-go-storage/types" "github.com/ElrondNetwork/elrond-go/config" - "github.com/ElrondNetwork/elrond-go/epochStart" ) -// Persister provides storage of data services in a database like construct -type Persister interface { - // Put add the value to the (key, val) persistence medium - Put(key, val []byte) error - // Get gets the value associated to the key - Get(key []byte) ([]byte, error) - // Has returns true if the given key is present in the persistence medium - Has(key []byte) error - // Close closes the 
files/resources associated to the persistence medium - Close() error - // Remove removes the data associated to the given key - Remove(key []byte) error - // Destroy removes the persistence medium stored data - Destroy() error - // DestroyClosed removes the already closed persistence medium stored data - DestroyClosed() error - RangeKeys(handler func(key []byte, val []byte) bool) - // IsInterfaceNil returns true if there is no value under the interface - IsInterfaceNil() bool -} - -// Batcher allows to batch the data first then write the batch to the persister in one go -type Batcher interface { - // Put inserts one entry - key, value pair - into the batch - Put(key []byte, val []byte) error - // Get returns the value from the batch - Get(key []byte) []byte - // Delete deletes the batch - Delete(key []byte) error - // Reset clears the contents of the batch - Reset() - // IsRemoved returns true if the provided key is marked for deletion - IsRemoved(key []byte) bool - // IsInterfaceNil returns true if there is no value under the interface - IsInterfaceNil() bool -} - // Cacher provides caching services type Cacher interface { // Clear is used to completely clear the cache. @@ -83,6 +46,43 @@ type Cacher interface { IsInterfaceNil() bool } +// Persister provides storage of data services in a database-like construct +type Persister interface { + // Put adds the value to the (key, val) persistence medium + Put(key, val []byte) error + // Get gets the value associated to the key + Get(key []byte) ([]byte, error) + // Has returns true if the given key is present in the persistence medium + Has(key []byte) error + // Close closes the files/resources associated to the persistence medium + Close() error + // Remove removes the data associated to the given key + Remove(key []byte) error + // Destroy removes the persistence medium stored data + Destroy() error + // DestroyClosed removes the already closed persistence medium stored data + DestroyClosed() error + RangeKeys(handler func(key []byte, val []byte) bool) + // IsInterfaceNil returns true if there is no value under the interface + IsInterfaceNil() bool +} + +// Batcher allows batching the data first, then writing the batch to the persister in one go +type Batcher interface { + // Put inserts one entry - key, value pair - into the batch + Put(key []byte, val []byte) error + // Get returns the value from the batch + Get(key []byte) []byte + // Delete deletes the entry for the provided key from the batch + Delete(key []byte) error + // Reset clears the contents of the batch + Reset() + // IsRemoved returns true if the provided key is marked for deletion + IsRemoved(key []byte) bool + // IsInterfaceNil returns true if there is no value under the interface + IsInterfaceNil() bool +} + // Storer provides storage services in a two layered storage construct, where the first layer is // represented by a cache and second layer by a persistent storage (DB-like) type Storer interface { @@ -109,12 +109,6 @@ type StorerWithPutInEpoch interface { SetEpochForPutOperation(epoch uint32) } -// EpochStartNotifier defines which actions should be done for handling new epoch's events -type EpochStartNotifier interface { - RegisterHandler(handler epochStart.ActionHandler) - IsInterfaceNil() bool -} - // PathManagerHandler defines which actions should be done for generating paths for databases directories type PathManagerHandler interface { PathForEpoch(shardId string, epoch uint32, identifier string) string @@ -154,12 +148,7 @@ type LatestStorageDataProviderHandler interface { } // LatestDataFromStorage
represents the DTO structure to return from storage -type LatestDataFromStorage struct { - Epoch uint32 - ShardID uint32 - LastRound int64 - EpochStartRound uint64 -} +type LatestDataFromStorage = types.LatestDataFromStorage // ShardCoordinator defines what a shard state coordinator should hold type ShardCoordinator interface { @@ -171,20 +160,25 @@ type ShardCoordinator interface { IsInterfaceNil() bool } -// ForEachItem is an iterator callback -type ForEachItem func(key []byte, value interface{}) +// TimeCacher defines the cache that can keep a record for a bounded time +type TimeCacher interface { + Add(key string) error + Upsert(key string, span time.Duration) error + Has(key string) bool + Sweep() + IsInterfaceNil() bool +} -// LRUCacheHandler is the interface for LRU cache. -type LRUCacheHandler interface { - Add(key, value interface{}) bool - Get(key interface{}) (value interface{}, ok bool) - Contains(key interface{}) (ok bool) - ContainsOrAdd(key, value interface{}) (ok, evicted bool) - Peek(key interface{}) (value interface{}, ok bool) - Remove(key interface{}) bool - Keys() []interface{} - Len() int - Purge() +// StoredDataFactory creates empty objects of the stored data type +type StoredDataFactory interface { + CreateEmpty() interface{} + IsInterfaceNil() bool +} + +// CustomDatabaseRemoverHandler defines the behaviour of a component that should tell if a database is removable or not +type CustomDatabaseRemoverHandler interface { + ShouldRemove(dbIdentifier string, epoch uint32) bool + IsInterfaceNil() bool } // SizedLRUCacheHandler is the interface for size capable LRU cache. @@ -201,36 +195,9 @@ type SizedLRUCacheHandler interface { Purge() } -// TimeCacher defines the cache that can keep a record for a bounded time -type TimeCacher interface { - Add(key string) error - Upsert(key string, span time.Duration) error - Has(key string) bool - Sweep() - IsInterfaceNil() bool -} - // AdaptedSizedLRUCache defines a cache that returns the evicted value type AdaptedSizedLRUCache interface { SizedLRUCacheHandler AddSizedAndReturnEvicted(key, value interface{}, sizeInBytes int64) map[interface{}]interface{} IsInterfaceNil() bool } - -// StoredDataFactory creates empty objects of the stored data type -type StoredDataFactory interface { - CreateEmpty() interface{} - IsInterfaceNil() bool -} - -// SerializedStoredData defines a data type that has the serialized data as a field -type SerializedStoredData interface { - GetSerialized() []byte - SetSerialized([]byte) -} - -// CustomDatabaseRemoverHandler defines the behaviour of a component that should tell if a database is removable or not -type CustomDatabaseRemoverHandler interface { - ShouldRemove(dbIdentifier string, epoch uint32) bool - IsInterfaceNil() bool -} diff --git a/storage/leveldb/batch.go b/storage/leveldb/batch.go deleted file mode 100644 index 846454bf433..00000000000 --- a/storage/leveldb/batch.go +++ /dev/null @@ -1,79 +0,0 @@ -package leveldb - -import ( - "sync" - - "github.com/ElrondNetwork/elrond-go/storage" - "github.com/syndtr/goleveldb/leveldb" -) - -var _ storage.Batcher = (*batch)(nil) - -type batch struct { - batch *leveldb.Batch - cachedData map[string][]byte - removedData map[string]struct{} - mutBatch sync.RWMutex -} - -// NewBatch creates a batch -func NewBatch() *batch { - return &batch{ - batch: &leveldb.Batch{}, - cachedData: make(map[string][]byte), - removedData: make(map[string]struct{}), - mutBatch: sync.RWMutex{}, - } -} - -// Put inserts one entry - key, value pair - into the batch -func (b *batch) 
Put(key []byte, val []byte) error { - b.mutBatch.Lock() - b.batch.Put(key, val) - b.cachedData[string(key)] = val - delete(b.removedData, string(key)) - b.mutBatch.Unlock() - return nil -} - -// Delete deletes the entry for the provided key from the batch -func (b *batch) Delete(key []byte) error { - b.mutBatch.Lock() - b.batch.Delete(key) - b.removedData[string(key)] = struct{}{} - delete(b.cachedData, string(key)) - b.mutBatch.Unlock() - return nil -} - -// Reset clears the contents of the batch -func (b *batch) Reset() { - b.mutBatch.Lock() - b.batch.Reset() - b.cachedData = make(map[string][]byte) - b.removedData = make(map[string]struct{}) - b.mutBatch.Unlock() -} - -// Get returns the value -func (b *batch) Get(key []byte) []byte { - b.mutBatch.RLock() - defer b.mutBatch.RUnlock() - - return b.cachedData[string(key)] -} - -// IsRemoved returns true if the key is marked for removal -func (b *batch) IsRemoved(key []byte) bool { - b.mutBatch.RLock() - defer b.mutBatch.RUnlock() - - _, found := b.removedData[string(key)] - - return found -} - -// IsInterfaceNil returns true if there is no value under the interface -func (b *batch) IsInterfaceNil() bool { - return b == nil -} diff --git a/storage/leveldb/common.go b/storage/leveldb/common.go deleted file mode 100644 index 5d9f827eb9b..00000000000 --- a/storage/leveldb/common.go +++ /dev/null @@ -1,137 +0,0 @@ -package leveldb - -import ( - "fmt" - "sync" - "sync/atomic" - "time" - - "github.com/syndtr/goleveldb/leveldb" - "github.com/syndtr/goleveldb/leveldb/errors" - "github.com/syndtr/goleveldb/leveldb/opt" -) - -const resourceUnavailable = "resource temporarily unavailable" -const maxRetries = 10 -const timeBetweenRetries = time.Second - -// loggingDBCounter this variable should be used only used in logging prints -var loggingDBCounter = uint32(0) - -func openLevelDB(path string, options *opt.Options) (*leveldb.DB, error) { - retries := 0 - for { - db, err := openOneTime(path, options) - if err == nil { - return db, nil - } - if err.Error() != resourceUnavailable { - return nil, err - } - - log.Debug("error opening DB", - "error", err, - "path", path, - "retry", retries, - ) - - time.Sleep(timeBetweenRetries) - retries++ - if retries > maxRetries { - return nil, fmt.Errorf("%w, retried %d number of times", err, maxRetries) - } - } -} - -func openOneTime(path string, options *opt.Options) (*leveldb.DB, error) { - db, errOpen := leveldb.OpenFile(path, options) - if errOpen == nil { - return db, nil - } - - if errors.IsCorrupted(errOpen) { - var errRecover error - log.Warn("corrupted DB file", - "path", path, - "error", errOpen, - ) - db, errRecover = leveldb.RecoverFile(path, options) - if errRecover != nil { - return nil, fmt.Errorf("%w while recovering DB %s, after the initial failure %s", - errRecover, - path, - errOpen.Error(), - ) - } - log.Info("DB file recovered", - "path", path, - ) - - return db, nil - } - - return nil, errOpen -} - -type baseLevelDb struct { - mutDb sync.RWMutex - path string - db *leveldb.DB -} - -func (bldb *baseLevelDb) getDbPointer() *leveldb.DB { - bldb.mutDb.RLock() - defer bldb.mutDb.RUnlock() - - return bldb.db -} - -func (bldb *baseLevelDb) makeDbPointerNilReturningLast() *leveldb.DB { - bldb.mutDb.Lock() - defer bldb.mutDb.Unlock() - - if bldb.db != nil { - crtCounter := atomic.AddUint32(&loggingDBCounter, ^uint32(0)) // subtract 1 - log.Debug("makeDbPointerNilReturningLast", "path", bldb.path, "nilled pointer", fmt.Sprintf("%p", bldb.db), "global db counter", crtCounter) - } - - db := bldb.db - 
bldb.db = nil - - return db -} - -// RangeKeys will call the handler function for each (key, value) pair -// If the handler returns true, the iteration will continue, otherwise will stop -func (bldb *baseLevelDb) RangeKeys(handler func(key []byte, value []byte) bool) { - if handler == nil { - return - } - - db := bldb.getDbPointer() - if db == nil { - return - } - - iterator := db.NewIterator(nil, nil) - for { - if !iterator.Next() { - break - } - - key := iterator.Key() - clonedKey := make([]byte, len(key)) - copy(clonedKey, key) - - val := iterator.Value() - clonedVal := make([]byte, len(val)) - copy(clonedVal, val) - - shouldContinue := handler(clonedKey, clonedVal) - if !shouldContinue { - break - } - } - - iterator.Release() -} diff --git a/storage/leveldb/leveldb.go b/storage/leveldb/leveldb.go deleted file mode 100644 index 61876c85119..00000000000 --- a/storage/leveldb/leveldb.go +++ /dev/null @@ -1,294 +0,0 @@ -package leveldb - -import ( - "context" - "fmt" - "os" - "runtime" - "sync" - "sync/atomic" - "time" - - "github.com/ElrondNetwork/elrond-go-core/core" - logger "github.com/ElrondNetwork/elrond-go-logger" - "github.com/ElrondNetwork/elrond-go/errors" - "github.com/ElrondNetwork/elrond-go/storage" - "github.com/syndtr/goleveldb/leveldb" - "github.com/syndtr/goleveldb/leveldb/opt" -) - -var _ storage.Persister = (*DB)(nil) - -// read + write + execute for owner only -const rwxOwner = 0700 -const mkdirAllFunction = "mkdirAll" -const openLevelDBFunction = "openLevelDB" - -var log = logger.GetOrCreate("storage/leveldb") - -// DB holds a pointer to the leveldb database and the path to where it is stored. -type DB struct { - *baseLevelDb - maxBatchSize int - batchDelaySeconds int - sizeBatch int - batch storage.Batcher - mutBatch sync.RWMutex - cancel context.CancelFunc -} - -// NewDB is a constructor for the leveldb persister -// It creates the files in the location given as parameter -func NewDB(path string, batchDelaySeconds int, maxBatchSize int, maxOpenFiles int) (s *DB, err error) { - constructorName := "NewDB" - - sw := core.NewStopWatch() - sw.Start(constructorName) - - sw.Start(mkdirAllFunction) - err = os.MkdirAll(path, rwxOwner) - if err != nil { - return nil, err - } - sw.Stop(mkdirAllFunction) - - if maxOpenFiles < 1 { - return nil, storage.ErrInvalidNumOpenFiles - } - - options := &opt.Options{ - // disable internal cache - BlockCacheCapacity: -1, - OpenFilesCacheCapacity: maxOpenFiles, - } - - sw.Start(openLevelDBFunction) - db, err := openLevelDB(path, options) - if err != nil { - return nil, fmt.Errorf("%w for path %s", err, path) - } - sw.Stop(openLevelDBFunction) - - bldb := &baseLevelDb{ - db: db, - path: path, - } - - ctx, cancel := context.WithCancel(context.Background()) - dbStore := &DB{ - baseLevelDb: bldb, - maxBatchSize: maxBatchSize, - batchDelaySeconds: batchDelaySeconds, - sizeBatch: 0, - cancel: cancel, - } - - dbStore.batch = dbStore.createBatch() - - go dbStore.batchTimeoutHandle(ctx) - - runtime.SetFinalizer(dbStore, func(db *DB) { - _ = db.Close() - }) - - crtCounter := atomic.AddUint32(&loggingDBCounter, 1) - sw.Stop(constructorName) - - logArguments := []interface{}{"path", path, "created pointer", fmt.Sprintf("%p", bldb.db), "global db counter", crtCounter} - logArguments = append(logArguments, sw.GetMeasurements()...) - log.Debug("opened level db persister", logArguments...) 
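// From here on, writes are funneled through s.batch: Put only appends to the in-memory batch, and the data reaches leveldb either when sizeBatch hits maxBatchSize or when the batchDelaySeconds timer in batchTimeoutHandle fires; the finalizer above is a safety net that closes (and therefore flushes) the DB if the owner never calls Close().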
- - return dbStore, nil -} - -func (s *DB) batchTimeoutHandle(ctx context.Context) { - interval := time.Duration(s.batchDelaySeconds) * time.Second - timer := time.NewTimer(interval) - defer timer.Stop() - - for { - timer.Reset(interval) - - select { - case <-timer.C: - s.mutBatch.Lock() - err := s.putBatch(s.batch) - if err != nil { - log.Warn("leveldb putBatch", "error", err.Error()) - s.mutBatch.Unlock() - continue - } - - s.batch.Reset() - s.sizeBatch = 0 - s.mutBatch.Unlock() - case <-ctx.Done(): - log.Debug("closing the timed batch handler", "path", s.path) - return - } - } -} - -func (s *DB) updateBatchWithIncrement() error { - s.mutBatch.Lock() - defer s.mutBatch.Unlock() - - s.sizeBatch++ - if s.sizeBatch < s.maxBatchSize { - return nil - } - - err := s.putBatch(s.batch) - if err != nil { - log.Warn("leveldb putBatch", "error", err.Error()) - return err - } - - s.batch.Reset() - s.sizeBatch = 0 - - return nil -} - -// Put adds the value to the (key, val) storage medium -func (s *DB) Put(key, val []byte) error { - err := s.batch.Put(key, val) - if err != nil { - return err - } - - return s.updateBatchWithIncrement() -} - -// Get returns the value associated to the key -func (s *DB) Get(key []byte) ([]byte, error) { - db := s.getDbPointer() - if db == nil { - return nil, errors.ErrDBIsClosed - } - - if s.batch.IsRemoved(key) { - return nil, storage.ErrKeyNotFound - } - - data := s.batch.Get(key) - if data != nil { - return data, nil - } - - data, err := db.Get(key, nil) - if err == leveldb.ErrNotFound { - return nil, storage.ErrKeyNotFound - } - if err != nil { - return nil, err - } - - return data, nil -} - -// Has returns nil if the given key is present in the persistence medium -func (s *DB) Has(key []byte) error { - db := s.getDbPointer() - if db == nil { - return errors.ErrDBIsClosed - } - - if s.batch.IsRemoved(key) { - return storage.ErrKeyNotFound - } - - data := s.batch.Get(key) - if data != nil { - return nil - } - - has, err := db.Has(key, nil) - if err != nil { - return err - } - - if has { - return nil - } - - return storage.ErrKeyNotFound -} - -// CreateBatch returns a batcher to be used for batch writing data to the database -func (s *DB) createBatch() storage.Batcher { - return NewBatch() -} - -// putBatch writes the Batch data into the database -func (s *DB) putBatch(b storage.Batcher) error { - dbBatch, ok := b.(*batch) - if !ok { - return storage.ErrInvalidBatch - } - - wopt := &opt.WriteOptions{ - Sync: true, - } - - db := s.getDbPointer() - if db == nil { - return errors.ErrDBIsClosed - } - - return db.Write(dbBatch.batch, wopt) -} - -// Close closes the files/resources associated to the storage medium -func (s *DB) Close() error { - s.mutBatch.Lock() - _ = s.putBatch(s.batch) - s.sizeBatch = 0 - s.mutBatch.Unlock() - - s.cancel() - db := s.makeDbPointerNilReturningLast() - if db != nil { - return db.Close() - } - - return nil -} - -// Remove removes the data associated to the given key -func (s *DB) Remove(key []byte) error { - s.mutBatch.Lock() - _ = s.batch.Delete(key) - s.mutBatch.Unlock() - - return s.updateBatchWithIncrement() -} - -// Destroy removes the storage medium stored data -func (s *DB) Destroy() error { - s.mutBatch.Lock() - s.batch.Reset() - s.sizeBatch = 0 - s.mutBatch.Unlock() - - s.cancel() - db := s.makeDbPointerNilReturningLast() - if db != nil { - err := db.Close() - if err != nil { - return err - } - } - - return os.RemoveAll(s.path) -} - -// DestroyClosed removes the already closed storage medium stored data -func (s *DB) DestroyClosed() 
error { - return os.RemoveAll(s.path) -} - -// IsInterfaceNil returns true if there is no value under the interface -func (s *DB) IsInterfaceNil() bool { - return s == nil -} diff --git a/storage/leveldb/leveldbSerial.go b/storage/leveldb/leveldbSerial.go deleted file mode 100644 index cc535f1f52e..00000000000 --- a/storage/leveldb/leveldbSerial.go +++ /dev/null @@ -1,351 +0,0 @@ -package leveldb - -import ( - "context" - "fmt" - "os" - "runtime" - "sync" - "sync/atomic" - "time" - - "github.com/ElrondNetwork/elrond-go-core/core" - "github.com/ElrondNetwork/elrond-go-core/core/closing" - "github.com/ElrondNetwork/elrond-go/errors" - "github.com/ElrondNetwork/elrond-go/storage" - "github.com/syndtr/goleveldb/leveldb" - "github.com/syndtr/goleveldb/leveldb/opt" -) - -var _ storage.Persister = (*SerialDB)(nil) - -// SerialDB holds a pointer to the leveldb database and the path to where it is stored. -type SerialDB struct { - *baseLevelDb - maxBatchSize int - batchDelaySeconds int - sizeBatch int - batch storage.Batcher - mutBatch sync.RWMutex - dbAccess chan serialQueryer - cancel context.CancelFunc - closer core.SafeCloser -} - -// NewSerialDB is a constructor for the leveldb persister -// It creates the files in the location given as parameter -func NewSerialDB(path string, batchDelaySeconds int, maxBatchSize int, maxOpenFiles int) (s *SerialDB, err error) { - constructorName := "NewSerialDB" - - sw := core.NewStopWatch() - sw.Start(constructorName) - - sw.Start(mkdirAllFunction) - err = os.MkdirAll(path, rwxOwner) - if err != nil { - return nil, err - } - sw.Stop(mkdirAllFunction) - - if maxOpenFiles < 1 { - return nil, storage.ErrInvalidNumOpenFiles - } - - options := &opt.Options{ - // disable internal cache - BlockCacheCapacity: -1, - OpenFilesCacheCapacity: maxOpenFiles, - } - - sw.Start(openLevelDBFunction) - db, err := openLevelDB(path, options) - if err != nil { - return nil, fmt.Errorf("%w for path %s", err, path) - } - sw.Stop(openLevelDBFunction) - - bldb := &baseLevelDb{ - db: db, - path: path, - } - - ctx, cancel := context.WithCancel(context.Background()) - dbStore := &SerialDB{ - baseLevelDb: bldb, - maxBatchSize: maxBatchSize, - batchDelaySeconds: batchDelaySeconds, - sizeBatch: 0, - dbAccess: make(chan serialQueryer), - cancel: cancel, - closer: closing.NewSafeChanCloser(), - } - - dbStore.batch = NewBatch() - - go dbStore.batchTimeoutHandle(ctx) - go dbStore.processLoop(ctx) - - runtime.SetFinalizer(dbStore, func(db *SerialDB) { - _ = db.Close() - }) - - crtCounter := atomic.AddUint32(&loggingDBCounter, 1) - sw.Stop(constructorName) - - logArguments := []interface{}{"path", path, "created pointer", fmt.Sprintf("%p", bldb.db), "global db counter", crtCounter} - logArguments = append(logArguments, sw.GetMeasurements()...) - log.Debug("opened serial level db persister", logArguments...) 
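// Same batching scheme as DB above, with one key difference: every Get, Has and batch flush is sent as a request over the dbAccess channel and served by the processLoop goroutine, so a single goroutine owns the leveldb handle; tryWriteInDbAccessChan pairs each send with closer.ChanClose() so callers get ErrDBIsClosed instead of blocking once the store shuts down.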
- - return dbStore, nil -} - -func (s *SerialDB) batchTimeoutHandle(ctx context.Context) { - interval := time.Duration(s.batchDelaySeconds) * time.Second - timer := time.NewTimer(interval) - defer timer.Stop() - - for { - timer.Reset(interval) - - select { - case <-timer.C: - err := s.putBatch() - if err != nil { - log.Warn("leveldb serial putBatch", "error", err.Error()) - continue - } - case <-ctx.Done(): - log.Debug("batchTimeoutHandle - closing", "path", s.path) - return - } - } -} - -func (s *SerialDB) updateBatchWithIncrement() error { - s.mutBatch.Lock() - s.sizeBatch++ - if s.sizeBatch < s.maxBatchSize { - s.mutBatch.Unlock() - return nil - } - s.mutBatch.Unlock() - - err := s.putBatch() - - return err -} - -// Put adds the value to the (key, val) storage medium -func (s *SerialDB) Put(key, val []byte) error { - if s.isClosed() { - return errors.ErrDBIsClosed - } - - s.mutBatch.RLock() - err := s.batch.Put(key, val) - s.mutBatch.RUnlock() - if err != nil { - return err - } - - return s.updateBatchWithIncrement() -} - -// Get returns the value associated to the key -func (s *SerialDB) Get(key []byte) ([]byte, error) { - if s.isClosed() { - return nil, errors.ErrDBIsClosed - } - - s.mutBatch.RLock() - if s.batch.IsRemoved(key) { - s.mutBatch.RUnlock() - return nil, storage.ErrKeyNotFound - } - - data := s.batch.Get(key) - s.mutBatch.RUnlock() - - if data != nil { - return data, nil - } - - ch := make(chan *pairResult) - req := &getAct{ - key: key, - resChan: ch, - } - - err := s.tryWriteInDbAccessChan(req) - if err != nil { - return nil, err - } - result := <-ch - close(ch) - - if result.err == leveldb.ErrNotFound { - return nil, storage.ErrKeyNotFound - } - if result.err != nil { - return nil, result.err - } - - return result.value, nil -} - -// Has returns nil if the given key is present in the persistence medium -func (s *SerialDB) Has(key []byte) error { - if s.isClosed() { - return errors.ErrDBIsClosed - } - - s.mutBatch.RLock() - if s.batch.IsRemoved(key) { - s.mutBatch.RUnlock() - return storage.ErrKeyNotFound - } - - data := s.batch.Get(key) - s.mutBatch.RUnlock() - - if data != nil { - return nil - } - - ch := make(chan error) - req := &hasAct{ - key: key, - resChan: ch, - } - - err := s.tryWriteInDbAccessChan(req) - if err != nil { - return err - } - result := <-ch - close(ch) - - return result -} - -func (s *SerialDB) tryWriteInDbAccessChan(req serialQueryer) error { - select { - case s.dbAccess <- req: - return nil - case <-s.closer.ChanClose(): - return errors.ErrDBIsClosed - } -} - -// putBatch writes the Batch data into the database -func (s *SerialDB) putBatch() error { - s.mutBatch.Lock() - dbBatch, ok := s.batch.(*batch) - if !ok { - s.mutBatch.Unlock() - return storage.ErrInvalidBatch - } - s.sizeBatch = 0 - s.batch = NewBatch() - s.mutBatch.Unlock() - - ch := make(chan error) - req := &putBatchAct{ - batch: dbBatch, - resChan: ch, - } - - err := s.tryWriteInDbAccessChan(req) - if err != nil { - return err - } - result := <-ch - close(ch) - - return result -} - -func (s *SerialDB) isClosed() bool { - db := s.getDbPointer() - - return db == nil -} - -// Close closes the files/resources associated to the storage medium -func (s *SerialDB) Close() error { - // calling close on the SafeCloser instance should be the last instruction called - // (just to close some go routines started as edge cases that would otherwise hang) - defer s.closer.Close() - - return s.doClose() -} - -// Remove removes the data associated to the given key -func (s *SerialDB) Remove(key []byte) 
error { - if s.isClosed() { - return errors.ErrDBIsClosed - } - - s.mutBatch.Lock() - _ = s.batch.Delete(key) - s.mutBatch.Unlock() - - return s.updateBatchWithIncrement() -} - -// Destroy removes the storage medium stored data -func (s *SerialDB) Destroy() error { - log.Debug("serialDB.Destroy", "path", s.path) - - // calling close on the SafeCloser instance should be the last instruction called - // (just to close some go routines started as edge cases that would otherwise hang) - defer s.closer.Close() - - err := s.doClose() - if err == nil { - return os.RemoveAll(s.path) - } - - return err -} - -// DestroyClosed removes the already closed storage medium stored data -func (s *SerialDB) DestroyClosed() error { - err := os.RemoveAll(s.path) - if err != nil { - log.Error("error destroy closed", "error", err, "path", s.path) - } - return err -} - -// doClose will handle the closing of the internal components -// must be called under mutex protection -// TODO: re-use this function in leveldb.go as well -func (s *SerialDB) doClose() error { - _ = s.putBatch() - s.cancel() - - db := s.makeDbPointerNilReturningLast() - if db != nil { - return db.Close() - } - - return nil -} - -func (s *SerialDB) processLoop(ctx context.Context) { - for { - select { - case queryer := <-s.dbAccess: - queryer.request(s) - case <-ctx.Done(): - log.Debug("processLoop - closing the leveldb process loop", "path", s.path) - return - } - } -} - -// IsInterfaceNil returns true if there is no value under the interface -func (s *SerialDB) IsInterfaceNil() bool { - return s == nil -} diff --git a/storage/leveldb/leveldbSerial_test.go b/storage/leveldb/leveldbSerial_test.go deleted file mode 100644 index 6f0a1707d8b..00000000000 --- a/storage/leveldb/leveldbSerial_test.go +++ /dev/null @@ -1,318 +0,0 @@ -package leveldb_test - -import ( - "fmt" - "math/big" - "testing" - "time" - - "github.com/ElrondNetwork/elrond-go/errors" - "github.com/ElrondNetwork/elrond-go/storage" - "github.com/ElrondNetwork/elrond-go/storage/leveldb" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func createSerialLevelDb(tb testing.TB, batchDelaySeconds int, maxBatchSize int, maxOpenFiles int) (p *leveldb.SerialDB) { - lvdb, err := leveldb.NewSerialDB(tb.TempDir(), batchDelaySeconds, maxBatchSize, maxOpenFiles) - - assert.Nil(tb, err, "Failed creating leveldb database file") - return lvdb -} - -func TestSerialDB_PutNoError(t *testing.T) { - key, val := []byte("key"), []byte("value") - ldb := createSerialLevelDb(t, 10, 1, 10) - - err := ldb.Put(key, val) - - assert.Nil(t, err, "error saving in DB") -} - -func TestSerialDB_GetErrorAfterPutBeforeTimeout(t *testing.T) { - key, val := []byte("key"), []byte("value") - ldb := createSerialLevelDb(t, 1, 100, 10) - - _ = ldb.Put(key, val) - v, err := ldb.Get(key) - - assert.Equal(t, val, v) - assert.Nil(t, err) -} - -func TestSerialDB_GetErrorOnFail(t *testing.T) { - ldb := createSerialLevelDb(t, 10, 1, 10) - _ = ldb.Destroy() - - v, err := ldb.Get([]byte("key")) - assert.Nil(t, v) - assert.NotNil(t, err) -} - -func TestSerialDB_MethodCallsAfterCloseOrDestroy(t *testing.T) { - t.Parallel() - - t.Run("when closing", func(t *testing.T) { - t.Parallel() - - testSerialDbAllMethodsShouldNotPanic(t, func(db *leveldb.SerialDB) { - _ = db.Close() - }) - }) - t.Run("when destroying", func(t *testing.T) { - t.Parallel() - - testSerialDbAllMethodsShouldNotPanic(t, func(db *leveldb.SerialDB) { - _ = db.Destroy() - }) - }) -} - -func testSerialDbAllMethodsShouldNotPanic(t *testing.T, 
closeHandler func(db *leveldb.SerialDB)) { - ldb := createSerialLevelDb(t, 10, 1, 10) - - defer func() { - r := recover() - if r != nil { - assert.Fail(t, fmt.Sprintf("should have not panic %v", r)) - } - }() - - closeHandler(ldb) - - _, err := ldb.Get([]byte("key1")) - assert.Equal(t, errors.ErrDBIsClosed, err) - - err = ldb.Has([]byte("key2")) - assert.Equal(t, errors.ErrDBIsClosed, err) - - err = ldb.Remove([]byte("key3")) - assert.Equal(t, errors.ErrDBIsClosed, err) - - err = ldb.Put([]byte("key4"), []byte("val")) - assert.Equal(t, errors.ErrDBIsClosed, err) - - ldb.RangeKeys(func(key []byte, value []byte) bool { - require.Fail(t, "should have not called range") - return false - }) -} - -func TestSerialDB_GetOKAfterPutWithTimeout(t *testing.T) { - key, val := []byte("key"), []byte("value") - ldb := createSerialLevelDb(t, 1, 100, 10) - - _ = ldb.Put(key, val) - time.Sleep(time.Second * 3) - v, err := ldb.Get(key) - - assert.Nil(t, err) - assert.Equal(t, val, v) -} - -func TestSerialDB_RemoveBeforeTimeoutOK(t *testing.T) { - key, val := []byte("key"), []byte("value") - ldb := createSerialLevelDb(t, 1, 100, 10) - - _ = ldb.Put(key, val) - _ = ldb.Remove(key) - time.Sleep(time.Second * 2) - v, err := ldb.Get(key) - - assert.Nil(t, v) - assert.Equal(t, storage.ErrKeyNotFound, err) -} - -func TestSerialDB_RemoveAfterTimeoutOK(t *testing.T) { - key, val := []byte("key"), []byte("value") - ldb := createSerialLevelDb(t, 1, 100, 10) - - _ = ldb.Put(key, val) - time.Sleep(time.Second * 2) - _ = ldb.Remove(key) - v, err := ldb.Get(key) - - assert.Nil(t, v) - assert.Equal(t, storage.ErrKeyNotFound, err) -} - -func TestSerialDB_GetPresent(t *testing.T) { - key, val := []byte("key1"), []byte("value1") - ldb := createSerialLevelDb(t, 10, 1, 10) - - _ = ldb.Put(key, val) - v, err := ldb.Get(key) - - assert.Nil(t, err, "error not expected, but got %s", err) - assert.Equalf(t, v, val, "read:%s but expected: %s", v, val) -} - -func TestSerialDB_GetNotPresent(t *testing.T) { - key := []byte("key2") - ldb := createSerialLevelDb(t, 10, 1, 10) - - v, err := ldb.Get(key) - - assert.NotNil(t, err, "error expected but got nil, value %s", v) -} - -func TestSerialDB_HasPresent(t *testing.T) { - key, val := []byte("key3"), []byte("value3") - ldb := createSerialLevelDb(t, 10, 1, 10) - - _ = ldb.Put(key, val) - err := ldb.Has(key) - - assert.Nil(t, err) -} - -func TestSerialDB_HasNotPresent(t *testing.T) { - key := []byte("key4") - ldb := createSerialLevelDb(t, 10, 1, 10) - - err := ldb.Has(key) - - assert.NotNil(t, err) - assert.Equal(t, err, storage.ErrKeyNotFound) -} - -func TestSerialDB_RemovePresent(t *testing.T) { - key, val := []byte("key5"), []byte("value5") - ldb := createSerialLevelDb(t, 10, 1, 10) - - _ = ldb.Put(key, val) - _ = ldb.Remove(key) - err := ldb.Has(key) - - assert.NotNil(t, err) - assert.Equal(t, err, storage.ErrKeyNotFound) -} - -func TestSerialDB_RemoveNotPresent(t *testing.T) { - key := []byte("key6") - ldb := createSerialLevelDb(t, 10, 1, 10) - - err := ldb.Remove(key) - - assert.Nil(t, err, "no error expected but got %s", err) -} - -func TestSerialDB_Close(t *testing.T) { - ldb := createSerialLevelDb(t, 10, 1, 10) - - err := ldb.Close() - - assert.Nil(t, err, "no error expected but got %s", err) -} - -func TestSerialDB_CloseTwice(t *testing.T) { - ldb := createSerialLevelDb(t, 10, 1, 10) - - _ = ldb.Close() - err := ldb.Close() - - assert.Nil(t, err) -} - -func TestSerialDB_Destroy(t *testing.T) { - ldb := createSerialLevelDb(t, 10, 1, 10) - - err := ldb.Destroy() - - assert.Nil(t, err, 
"no error expected but got %s", err) -} - -func TestSerialDB_SpecialValueTest(t *testing.T) { - t.Parallel() - - ldb := createSerialLevelDb(t, 100, 100, 10) - key := []byte("key") - removedValue := []byte("removed") // in old implementations we had a check against this value - randomValue := []byte("random") - t.Run("operations: put -> get of 'removed' value", func(t *testing.T) { - err := ldb.Put(key, removedValue) - require.Nil(t, err) - - recovered, err := ldb.Get(key) - assert.Nil(t, err) - assert.Equal(t, removedValue, recovered) - }) - t.Run("operations: put -> remove -> get of 'removed' value", func(t *testing.T) { - err := ldb.Put(key, removedValue) - require.Nil(t, err) - - err = ldb.Remove(key) - require.Nil(t, err) - - recovered, err := ldb.Get(key) - assert.Equal(t, storage.ErrKeyNotFound, err) - assert.Nil(t, recovered) - }) - t.Run("operations: put -> remove -> put -> get of 'removed' value", func(t *testing.T) { - err := ldb.Put(key, removedValue) - require.Nil(t, err) - - err = ldb.Remove(key) - require.Nil(t, err) - - err = ldb.Put(key, removedValue) - require.Nil(t, err) - - recovered, err := ldb.Get(key) - assert.Nil(t, err) - assert.Equal(t, removedValue, recovered) - }) - t.Run("operations: put -> remove -> put -> get of random value", func(t *testing.T) { - err := ldb.Put(key, randomValue) - require.Nil(t, err) - - err = ldb.Remove(key) - require.Nil(t, err) - - err = ldb.Put(key, randomValue) - require.Nil(t, err) - - recovered, err := ldb.Get(key) - assert.Nil(t, err) - assert.Equal(t, randomValue, recovered) - }) - - _ = ldb.Close() -} - -func BenchmarkSerialDB_SpecialValueTest(b *testing.B) { - ldb := createSerialLevelDb(b, 10000, 10000000, 10) - key := []byte("key") - removedValue := []byte("removed") // in old implementations we had a check against this value - - b.Run("put -> remove -> get", func(b *testing.B) { - for i := 0; i < b.N; i++ { - _ = ldb.Put(key, removedValue) - _ = ldb.Remove(key) - _, _ = ldb.Get(key) - } - }) - b.Run("put -> get", func(b *testing.B) { - for i := 0; i < b.N; i++ { - _ = ldb.Put(key, removedValue) - _, _ = ldb.Get(key) - } - }) - b.Run("put -> remove -> get with different keys", func(b *testing.B) { - for i := 0; i < b.N; i++ { - testKey := append(key, big.NewInt(int64(i)).Bytes()...) - - _ = ldb.Put(testKey, removedValue) - _ = ldb.Remove(testKey) - _, _ = ldb.Get(testKey) - } - }) - b.Run("put -> get with different keys", func(b *testing.B) { - for i := 0; i < b.N; i++ { - testKey := append(key, big.NewInt(int64(i)).Bytes()...) 
- - _ = ldb.Put(testKey, removedValue) - _, _ = ldb.Get(testKey) - } - }) -} diff --git a/storage/leveldb/leveldb_test.go b/storage/leveldb/leveldb_test.go deleted file mode 100644 index e538dfbe8a4..00000000000 --- a/storage/leveldb/leveldb_test.go +++ /dev/null @@ -1,412 +0,0 @@ -package leveldb_test - -import ( - "crypto/rand" - "fmt" - "os" - "path" - "testing" - "time" - - "github.com/ElrondNetwork/elrond-go/errors" - "github.com/ElrondNetwork/elrond-go/storage" - "github.com/ElrondNetwork/elrond-go/storage/leveldb" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func createLevelDb(t *testing.T, batchDelaySeconds int, maxBatchSize int, maxOpenFiles int) (p *leveldb.DB) { - lvdb, err := leveldb.NewDB(t.TempDir(), batchDelaySeconds, maxBatchSize, maxOpenFiles) - - assert.Nil(t, err, "Failed creating leveldb database file") - return lvdb -} - -func TestDB_CorruptdeDBShouldRecover(t *testing.T) { - dir := t.TempDir() - db, err := leveldb.NewDB(dir, 10, 1, 10) - require.Nil(t, err) - - key := []byte("key") - val := []byte("val") - err = db.Put(key, val) - require.Nil(t, err) - _ = db.Close() - - err = os.Remove(path.Join(dir, "MANIFEST-000000")) - require.Nil(t, err) - - dbRecovered, err := leveldb.NewDB(dir, 10, 1, 10) - if err != nil { - assert.Fail(t, fmt.Sprintf("should have not errored %s", err.Error())) - return - } - - valRecovered, err := dbRecovered.Get(key) - assert.Nil(t, err) - _ = dbRecovered.Close() - - assert.Equal(t, val, valRecovered) -} - -func TestDB_DoubleOpenShouldError(t *testing.T) { - dir := t.TempDir() - lvdb1, err := leveldb.NewDB(dir, 10, 1, 10) - require.Nil(t, err) - - defer func() { - _ = lvdb1.Close() - }() - - _, err = leveldb.NewDB(dir, 10, 1, 10) - assert.NotNil(t, err) -} - -func TestDB_DoubleOpenButClosedInTimeShouldWork(t *testing.T) { - dir := t.TempDir() - lvdb1, err := leveldb.NewDB(dir, 10, 1, 10) - require.Nil(t, err) - - defer func() { - _ = lvdb1.Close() - }() - - go func() { - time.Sleep(time.Second * 3) - _ = lvdb1.Close() - }() - - lvdb2, err := leveldb.NewDB(dir, 10, 1, 10) - assert.Nil(t, err) - assert.NotNil(t, lvdb2) - - _ = lvdb2.Close() -} - -func TestDB_PutNoError(t *testing.T) { - key, val := []byte("key"), []byte("value") - ldb := createLevelDb(t, 10, 1, 10) - - err := ldb.Put(key, val) - - assert.Nil(t, err, "error saving in DB") -} - -func TestDB_GetErrorAfterPutBeforeTimeout(t *testing.T) { - key, val := []byte("key"), []byte("value") - ldb := createLevelDb(t, 1, 100, 10) - - err := ldb.Put(key, val) - assert.Nil(t, err) - v, err := ldb.Get(key) - assert.Equal(t, val, v) - assert.Nil(t, err) -} - -func TestDB_GetOKAfterPutWithTimeout(t *testing.T) { - key, val := []byte("key"), []byte("value") - ldb := createLevelDb(t, 1, 100, 10) - - err := ldb.Put(key, val) - assert.Nil(t, err) - time.Sleep(time.Second * 3) - - v, err := ldb.Get(key) - assert.Nil(t, err) - assert.Equal(t, val, v) -} - -func TestDB_GetErrorOnFail(t *testing.T) { - ldb := createLevelDb(t, 1, 100, 10) - _ = ldb.Close() - - v, err := ldb.Get([]byte("key")) - assert.Nil(t, v) - assert.NotNil(t, err) -} - -func TestDB_RemoveBeforeTimeoutOK(t *testing.T) { - if testing.Short() { - t.Skip("this is not a short test") - } - - key, val := []byte("key"), []byte("value") - ldb := createLevelDb(t, 1, 100, 10) - - err := ldb.Put(key, val) - assert.Nil(t, err) - - _ = ldb.Remove(key) - time.Sleep(time.Second * 2) - - v, err := ldb.Get(key) - assert.Nil(t, v) - assert.Equal(t, storage.ErrKeyNotFound, err) -} - -func 
TestDB_RemoveAfterTimeoutOK(t *testing.T) { - key, val := []byte("key"), []byte("value") - ldb := createLevelDb(t, 1, 100, 10) - - err := ldb.Put(key, val) - assert.Nil(t, err) - time.Sleep(time.Second * 2) - - _ = ldb.Remove(key) - - v, err := ldb.Get(key) - assert.Nil(t, v) - assert.Equal(t, storage.ErrKeyNotFound, err) -} - -func TestDB_GetPresent(t *testing.T) { - key, val := []byte("key1"), []byte("value1") - ldb := createLevelDb(t, 10, 1, 10) - - err := ldb.Put(key, val) - - assert.Nil(t, err, "error saving in DB") - - v, err := ldb.Get(key) - - assert.Nil(t, err, "error not expected, but got %s", err) - assert.Equalf(t, v, val, "read:%s but expected: %s", v, val) -} - -func TestDB_GetNotPresent(t *testing.T) { - key := []byte("key2") - ldb := createLevelDb(t, 10, 1, 10) - - v, err := ldb.Get(key) - - assert.NotNil(t, err, "error expected but got nil, value %s", v) -} - -func TestDB_HasPresent(t *testing.T) { - key, val := []byte("key3"), []byte("value3") - ldb := createLevelDb(t, 10, 1, 10) - - err := ldb.Put(key, val) - - assert.Nil(t, err, "error saving in DB") - - err = ldb.Has(key) - - assert.Nil(t, err) -} - -func TestDB_HasNotPresent(t *testing.T) { - key := []byte("key4") - ldb := createLevelDb(t, 10, 1, 10) - - err := ldb.Has(key) - - assert.NotNil(t, err) - assert.Equal(t, err, storage.ErrKeyNotFound) -} - -func TestDB_RemovePresent(t *testing.T) { - key, val := []byte("key5"), []byte("value5") - ldb := createLevelDb(t, 10, 1, 10) - - err := ldb.Put(key, val) - - assert.Nil(t, err, "error saving in DB") - - err = ldb.Remove(key) - - assert.Nil(t, err, "no error expected but got %s", err) - - err = ldb.Has(key) - - assert.NotNil(t, err) - assert.Equal(t, err, storage.ErrKeyNotFound) -} - -func TestDB_RemoveNotPresent(t *testing.T) { - key := []byte("key6") - ldb := createLevelDb(t, 10, 1, 10) - - err := ldb.Remove(key) - - assert.Nil(t, err, "no error expected but got %s", err) -} - -func TestDB_Close(t *testing.T) { - ldb := createLevelDb(t, 10, 1, 10) - - err := ldb.Close() - - assert.Nil(t, err, "no error expected but got %s", err) -} - -func TestDB_Destroy(t *testing.T) { - ldb := createLevelDb(t, 10, 1, 10) - - err := ldb.Destroy() - - assert.Nil(t, err, "no error expected but got %s", err) -} - -func TestDB_RangeKeys(t *testing.T) { - ldb := createLevelDb(t, 1, 1, 10) - defer func() { - _ = ldb.Close() - }() - - keysVals := map[string][]byte{ - "key1": []byte("value1"), - "key2": []byte("value2"), - "key3": []byte("value3"), - "key4": []byte("value4"), - "key5": []byte("value5"), - "key6": []byte("value6"), - "key7": []byte("value7"), - } - - for key, val := range keysVals { - _ = ldb.Put([]byte(key), val) - } - - time.Sleep(time.Second * 2) - - recovered := make(map[string][]byte) - - handler := func(key []byte, val []byte) bool { - recovered[string(key)] = val - return true - } - - ldb.RangeKeys(handler) - - assert.Equal(t, keysVals, recovered) -} - -func TestDB_PutGetLargeValue(t *testing.T) { - t.Parallel() - - buffLargeValue := make([]byte, 32*1000000) // equivalent to ~1000000 hashes - key := []byte("key") - _, _ = rand.Read(buffLargeValue) - - ldb := createLevelDb(t, 1, 1, 10) - defer func() { - _ = ldb.Close() - }() - - err := ldb.Put(key, buffLargeValue) - assert.Nil(t, err) - - time.Sleep(time.Second * 2) - - recovered, err := ldb.Get(key) - assert.Nil(t, err) - - assert.Equal(t, buffLargeValue, recovered) -} - -func TestDB_MethodCallsAfterCloseOrDestroy(t *testing.T) { - t.Parallel() - - t.Run("when closing", func(t *testing.T) { - t.Parallel() - - 
testDbAllMethodsShouldNotPanic(t, func(db *leveldb.DB) { - _ = db.Close() - }) - }) - t.Run("when destroying", func(t *testing.T) { - t.Parallel() - - testDbAllMethodsShouldNotPanic(t, func(db *leveldb.DB) { - _ = db.Destroy() - }) - }) -} - -func testDbAllMethodsShouldNotPanic(t *testing.T, closeHandler func(db *leveldb.DB)) { - defer func() { - r := recover() - if r != nil { - assert.Fail(t, fmt.Sprintf("should have not panic %v", r)) - } - }() - - ldb := createLevelDb(t, 1, 1, 10) - closeHandler(ldb) - - err := ldb.Put([]byte("key1"), []byte("val1")) - require.Equal(t, errors.ErrDBIsClosed, err) - - _, err = ldb.Get([]byte("key2")) - require.Equal(t, errors.ErrDBIsClosed, err) - - err = ldb.Has([]byte("key3")) - require.Equal(t, errors.ErrDBIsClosed, err) - - ldb.RangeKeys(func(key []byte, value []byte) bool { - require.Fail(t, "should have not called range") - return false - }) - - err = ldb.Remove([]byte("key4")) - require.Equal(t, errors.ErrDBIsClosed, err) -} - -func TestDB_SpecialValueTest(t *testing.T) { - t.Parallel() - - ldb := createLevelDb(t, 100, 100, 10) - key := []byte("key") - removedValue := []byte("removed") // in old implementations we had a check against this value - randomValue := []byte("random") - t.Run("operations: put -> get of 'removed' value", func(t *testing.T) { - err := ldb.Put(key, removedValue) - require.Nil(t, err) - - recovered, err := ldb.Get(key) - assert.Nil(t, err) - assert.Equal(t, removedValue, recovered) - }) - t.Run("operations: put -> remove -> get of 'removed' value", func(t *testing.T) { - err := ldb.Put(key, removedValue) - require.Nil(t, err) - - err = ldb.Remove(key) - require.Nil(t, err) - - recovered, err := ldb.Get(key) - assert.Equal(t, storage.ErrKeyNotFound, err) - assert.Nil(t, recovered) - }) - t.Run("operations: put -> remove -> put -> get of 'removed' value", func(t *testing.T) { - err := ldb.Put(key, removedValue) - require.Nil(t, err) - - err = ldb.Remove(key) - require.Nil(t, err) - - err = ldb.Put(key, removedValue) - require.Nil(t, err) - - recovered, err := ldb.Get(key) - assert.Nil(t, err) - assert.Equal(t, removedValue, recovered) - }) - t.Run("operations: put -> remove -> put -> get of random value", func(t *testing.T) { - err := ldb.Put(key, randomValue) - require.Nil(t, err) - - err = ldb.Remove(key) - require.Nil(t, err) - - err = ldb.Put(key, randomValue) - require.Nil(t, err) - - recovered, err := ldb.Get(key) - assert.Nil(t, err) - assert.Equal(t, randomValue, recovered) - }) - - _ = ldb.Close() -} diff --git a/storage/leveldb/serialActions.go b/storage/leveldb/serialActions.go deleted file mode 100644 index 4f64d6a9751..00000000000 --- a/storage/leveldb/serialActions.go +++ /dev/null @@ -1,92 +0,0 @@ -package leveldb - -import ( - "github.com/ElrondNetwork/elrond-go/errors" - "github.com/ElrondNetwork/elrond-go/storage" - "github.com/syndtr/goleveldb/leveldb/opt" -) - -type putBatchAct struct { - batch *batch - resChan chan<- error -} - -type pairResult struct { - value []byte - err error -} - -type serialQueryer interface { - request(s *SerialDB) -} - -type getAct struct { - key []byte - resChan chan<- *pairResult -} - -type hasAct struct { - key []byte - resChan chan<- error -} - -func (p *putBatchAct) request(s *SerialDB) { - p.resChan <- p.doPutRequest(s) -} - -func (p *putBatchAct) doPutRequest(s *SerialDB) error { - db := s.getDbPointer() - if db == nil { - return errors.ErrDBIsClosed - } - - wopt := &opt.WriteOptions{ - Sync: true, - } - - return db.Write(p.batch.batch, wopt) -} - -func (g *getAct) request(s 
*SerialDB) { - data, err := g.doGetRequest(s) - - res := &pairResult{ - value: data, - err: err, - } - g.resChan <- res -} - -func (g *getAct) doGetRequest(s *SerialDB) ([]byte, error) { - db := s.getDbPointer() - if db == nil { - return nil, errors.ErrDBIsClosed - } - - return db.Get(g.key, nil) -} - -func (h *hasAct) request(s *SerialDB) { - has, err := h.doHasRequest(s) - - if err != nil { - h.resChan <- err - return - } - - if has { - h.resChan <- nil - return - } - - h.resChan <- storage.ErrKeyNotFound -} - -func (h *hasAct) doHasRequest(s *SerialDB) (bool, error) { - db := s.getDbPointer() - if db == nil { - return false, errors.ErrDBIsClosed - } - - return db.Has(h.key, nil) -} diff --git a/storage/lrucache/capacity/capacityLRUCache.go b/storage/lrucache/capacity/capacityLRUCache.go deleted file mode 100644 index 66301822aeb..00000000000 --- a/storage/lrucache/capacity/capacityLRUCache.go +++ /dev/null @@ -1,300 +0,0 @@ -package capacity - -import ( - "container/list" - "fmt" - "sync" - - logger "github.com/ElrondNetwork/elrond-go-logger" - "github.com/ElrondNetwork/elrond-go/storage" -) - -var log = logger.GetOrCreate("storage/lrucache/capacity") - -// capacityLRU implements a non thread safe LRU Cache with a max capacity size -type capacityLRU struct { - lock sync.Mutex - size int - maxCapacityInBytes int64 - currentCapacityInBytes int64 - //TODO investigate if we can replace this list with a binary tree. Check also the other implementation lruCache - evictList *list.List - items map[interface{}]*list.Element -} - -// entry is used to hold a value in the evictList -type entry struct { - key interface{} - value interface{} - size int64 -} - -// NewCapacityLRU constructs an CapacityLRU of the given size with a byte size capacity -func NewCapacityLRU(size int, byteCapacity int64) (*capacityLRU, error) { - if size < 1 { - return nil, storage.ErrCacheSizeInvalid - } - if byteCapacity < 1 { - return nil, storage.ErrCacheCapacityInvalid - } - c := &capacityLRU{ - size: size, - maxCapacityInBytes: byteCapacity, - evictList: list.New(), - items: make(map[interface{}]*list.Element), - } - return c, nil -} - -// Purge is used to completely clear the cache. -func (c *capacityLRU) Purge() { - c.lock.Lock() - defer c.lock.Unlock() - - c.items = make(map[interface{}]*list.Element) - c.evictList.Init() - c.currentCapacityInBytes = 0 -} - -// AddSized adds a value to the cache. Returns true if an eviction occurred. 
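Ahead of the AddSized body below, a sketch of the dual budget this cache enforces: an eviction fires once either the entry count exceeds size or the accumulated bytes exceed byteCapacity. The constructor arguments and byte sizes here are illustrative; the behaviour follows the shouldEvict/evictIfNeeded logic deleted further down in this hunk.

package main

import (
	"fmt"

	"github.com/ElrondNetwork/elrond-go/storage/lrucache/capacity"
)

func main() {
	// At most 3 entries AND at most 1000 bytes, whichever limit is hit first.
	cache, err := capacity.NewCapacityLRU(3, 1000)
	if err != nil {
		panic(err)
	}

	evicted := cache.AddSized("key1", []byte("value1"), 600)
	fmt.Println(evicted) // false: one entry, 600 bytes, both within budget

	evicted = cache.AddSized("key2", []byte("value2"), 600)
	fmt.Println(evicted) // true: 1200 bytes exceeds the byte budget, so key1 is evicted
}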
-func (c *capacityLRU) AddSized(key, value interface{}, sizeInBytes int64) bool { - c.lock.Lock() - defer c.lock.Unlock() - - c.addSized(key, value, sizeInBytes) - - return c.evictIfNeeded() -} - -func (c *capacityLRU) addSized(key interface{}, value interface{}, sizeInBytes int64) { - if sizeInBytes < 0 { - log.Error("size LRU cache add error", - "key", fmt.Sprintf("%v", key), - "value", fmt.Sprintf("%v", value), - "error", storage.ErrNegativeSizeInBytes, - ) - - return - } - - // Check for existing item - if ent, ok := c.items[key]; ok { - c.update(key, value, sizeInBytes, ent) - } else { - c.addNew(key, value, sizeInBytes) - } -} - -// AddSizedAndReturnEvicted adds the given key-value pair to the cache, and returns the evicted values -func (c *capacityLRU) AddSizedAndReturnEvicted(key, value interface{}, sizeInBytes int64) map[interface{}]interface{} { - c.lock.Lock() - defer c.lock.Unlock() - - c.addSized(key, value, sizeInBytes) - - evictedValues := make(map[interface{}]interface{}) - for c.shouldEvict() { - evicted := c.evictList.Back() - if evicted == nil { - continue - } - - c.removeElement(evicted) - evictedEntry, ok := evicted.Value.(*entry) - if !ok { - continue - } - - evictedValues[evictedEntry.key] = evictedEntry.value - } - - return evictedValues -} - -func (c *capacityLRU) addNew(key interface{}, value interface{}, sizeInBytes int64) { - ent := &entry{ - key: key, - value: value, - size: sizeInBytes, - } - e := c.evictList.PushFront(ent) - c.items[key] = e - c.currentCapacityInBytes += sizeInBytes -} - -func (c *capacityLRU) update(key interface{}, value interface{}, sizeInBytes int64, ent *list.Element) { - c.evictList.MoveToFront(ent) - - e := ent.Value.(*entry) - sizeDiff := sizeInBytes - e.size - e.value = value - e.size = sizeInBytes - c.currentCapacityInBytes += sizeDiff - - c.adjustSize(key, sizeInBytes) -} - -// Get looks up a key's value from the cache. -func (c *capacityLRU) Get(key interface{}) (interface{}, bool) { - c.lock.Lock() - defer c.lock.Unlock() - - if ent, ok := c.items[key]; ok { - c.evictList.MoveToFront(ent) - if ent.Value.(*entry) == nil { - return nil, false - } - - return ent.Value.(*entry).value, true - } - - return nil, false -} - -// Contains checks if a key is in the cache, without updating the recent-ness -// or deleting it for being stale. -func (c *capacityLRU) Contains(key interface{}) bool { - c.lock.Lock() - defer c.lock.Unlock() - - _, ok := c.items[key] - - return ok -} - -// AddSizedIfMissing checks if a key is in the cache without updating the -// recent-ness or deleting it for being stale, and if not, adds the value. -// Returns whether found and whether an eviction occurred. -func (c *capacityLRU) AddSizedIfMissing(key, value interface{}, sizeInBytes int64) (bool, bool) { - if sizeInBytes < 0 { - log.Error("size LRU cache contains or add error", - "key", fmt.Sprintf("%v", key), - "value", fmt.Sprintf("%v", value), - "error", "size in bytes is negative", - ) - - return false, false - } - - c.lock.Lock() - defer c.lock.Unlock() - - _, ok := c.items[key] - if ok { - return true, false - } - c.addNew(key, value, sizeInBytes) - evicted := c.evictIfNeeded() - - return false, evicted -} - -// Peek returns the key value (or undefined if not found) without updating -// the "recently used"-ness of the key. 
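To make the recency contract of Peek (implemented just below) concrete versus Get: a sketch, with illustrative keys and sizes, showing that peeking leaves the entry as the oldest eviction candidate, whereas Get moves it to the front of the eviction list.

package main

import (
	"fmt"

	"github.com/ElrondNetwork/elrond-go/storage/lrucache/capacity"
)

func main() {
	cache, _ := capacity.NewCapacityLRU(2, 1000)
	cache.AddSized("old", "a", 1)
	cache.AddSized("new", "b", 1)

	// Peek does not refresh recency: "old" stays the eviction candidate.
	_, _ = cache.Peek("old")
	cache.AddSized("third", "c", 1)
	fmt.Println(cache.Contains("old")) // false: "old" was evicted despite the Peek

	// A Get on "old" instead would have moved it to the front, protecting it.
}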
-func (c *capacityLRU) Peek(key interface{}) (interface{}, bool) { - c.lock.Lock() - defer c.lock.Unlock() - - ent, ok := c.items[key] - if ok { - return ent.Value.(*entry).value, true - } - return nil, ok -} - -// Remove removes the provided key from the cache, returning if the -// key was contained. -func (c *capacityLRU) Remove(key interface{}) bool { - c.lock.Lock() - defer c.lock.Unlock() - - if ent, ok := c.items[key]; ok { - c.removeElement(ent) - return true - } - return false -} - -// Keys returns a slice of the keys in the cache, from oldest to newest. -func (c *capacityLRU) Keys() []interface{} { - c.lock.Lock() - defer c.lock.Unlock() - - keys := make([]interface{}, len(c.items)) - i := 0 - for ent := c.evictList.Back(); ent != nil; ent = ent.Prev() { - keys[i] = ent.Value.(*entry).key - i++ - } - return keys -} - -// Len returns the number of items in the cache. -func (c *capacityLRU) Len() int { - c.lock.Lock() - defer c.lock.Unlock() - - return c.evictList.Len() -} - -// SizeInBytesContained returns the size in bytes of all contained elements -func (c *capacityLRU) SizeInBytesContained() uint64 { - c.lock.Lock() - defer c.lock.Unlock() - - return uint64(c.currentCapacityInBytes) -} - -// removeOldest removes the oldest item from the cache. -func (c *capacityLRU) removeOldest() { - ent := c.evictList.Back() - if ent != nil { - c.removeElement(ent) - } -} - -// removeElement is used to remove a given list element from the cache -func (c *capacityLRU) removeElement(e *list.Element) { - c.evictList.Remove(e) - kv := e.Value.(*entry) - delete(c.items, kv.key) - c.currentCapacityInBytes -= kv.size -} - -func (c *capacityLRU) adjustSize(key interface{}, sizeInBytes int64) { - element := c.items[key] - if element == nil || element.Value == nil || element.Value.(*entry) == nil { - return - } - - v := element.Value.(*entry) - c.currentCapacityInBytes -= v.size - v.size = sizeInBytes - element.Value = v - c.currentCapacityInBytes += sizeInBytes - c.evictIfNeeded() -} - -func (c *capacityLRU) shouldEvict() bool { - if c.evictList.Len() == 1 { - // keep at least one element, no matter how large it is - return false - } - - return c.evictList.Len() > c.size || c.currentCapacityInBytes > c.maxCapacityInBytes -} - -func (c *capacityLRU) evictIfNeeded() bool { - evicted := false - for c.shouldEvict() { - c.removeOldest() - evicted = true - } - - return evicted -} - -// IsInterfaceNil returns true if there is no value under the interface -func (c *capacityLRU) IsInterfaceNil() bool { - return c == nil -} diff --git a/storage/lrucache/capacity/capacityLRUCache_test.go b/storage/lrucache/capacity/capacityLRUCache_test.go deleted file mode 100644 index 87afd2c74d5..00000000000 --- a/storage/lrucache/capacity/capacityLRUCache_test.go +++ /dev/null @@ -1,499 +0,0 @@ -package capacity - -import ( - "testing" - - "github.com/ElrondNetwork/elrond-go-core/core/check" - "github.com/ElrondNetwork/elrond-go/storage" - "github.com/stretchr/testify/assert" -) - -func createDefaultCache() *capacityLRU { - cache, _ := NewCapacityLRU(100, 100) - return cache -} - -//------- NewCapacityLRU - -func TestNewCapacityLRU_WithInvalidSize(t *testing.T) { - t.Parallel() - - size := 0 - capacity := int64(1) - cache, err := NewCapacityLRU(size, capacity) - assert.True(t, check.IfNil(cache)) - assert.Equal(t, storage.ErrCacheSizeInvalid, err) -} - -func TestNewCapacityLRU_WithInvalidCapacity(t *testing.T) { - t.Parallel() - - size := 1 - capacity := int64(0) - cache, err := NewCapacityLRU(size, capacity) - assert.Nil(t, 
cache) - assert.Equal(t, storage.ErrCacheCapacityInvalid, err) -} - -func TestNewCapacityLRU(t *testing.T) { - t.Parallel() - - size := 1 - capacity := int64(5) - - cache, err := NewCapacityLRU(size, capacity) - assert.False(t, check.IfNil(cache)) - assert.Nil(t, err) - assert.Equal(t, size, cache.size) - assert.Equal(t, capacity, cache.maxCapacityInBytes) - assert.Equal(t, int64(0), cache.currentCapacityInBytes) - assert.NotNil(t, cache.evictList) - assert.NotNil(t, cache.items) -} - -//------- AddSized - -func TestCapacityLRUCache_AddSizedNegativeSizeInBytesShouldReturn(t *testing.T) { - t.Parallel() - - c := createDefaultCache() - data := []byte("test") - key := "key" - c.AddSized(key, data, -1) - - assert.Equal(t, 0, c.Len()) -} - -func TestCapacityLRUCache_AddSizedSimpleTestShouldWork(t *testing.T) { - t.Parallel() - - c := createDefaultCache() - data := []byte("test") - key := "key" - capacity := int64(5) - c.AddSized(key, data, capacity) - - v, ok := c.Get(key) - assert.True(t, ok) - assert.NotNil(t, v) - assert.Equal(t, data, v) - - keys := c.Keys() - assert.Equal(t, 1, len(keys)) - assert.Equal(t, key, keys[0]) -} - -func TestCapacityLRUCache_AddSizedEvictionByCacheSizeShouldWork(t *testing.T) { - t.Parallel() - - c, _ := NewCapacityLRU(3, 100000) - - keys := []string{"key1", "key2", "key3", "key4", "key5"} - - c.AddSized(keys[0], struct{}{}, 0) - assert.Equal(t, 1, c.Len()) - - c.AddSized(keys[1], struct{}{}, 0) - assert.Equal(t, 2, c.Len()) - - c.AddSized(keys[2], struct{}{}, 0) - assert.Equal(t, 3, c.Len()) - - c.AddSized(keys[3], struct{}{}, 0) - assert.Equal(t, 3, c.Len()) - assert.False(t, c.Contains(keys[0])) - assert.True(t, c.Contains(keys[3])) - - c.AddSized(keys[4], struct{}{}, 0) - assert.Equal(t, 3, c.Len()) - assert.False(t, c.Contains(keys[1])) - assert.True(t, c.Contains(keys[4])) -} - -func TestCapacityLRUCache_AddSizedEvictionBySizeInBytesShouldWork(t *testing.T) { - t.Parallel() - - c, _ := NewCapacityLRU(100000, 1000) - - keys := []string{"key1", "key2", "key3", "key4"} - - c.AddSized(keys[0], struct{}{}, 500) - assert.Equal(t, 1, c.Len()) - - c.AddSized(keys[1], struct{}{}, 500) - assert.Equal(t, 2, c.Len()) - - c.AddSized(keys[2], struct{}{}, 500) - assert.Equal(t, 2, c.Len()) - assert.False(t, c.Contains(keys[0])) - assert.True(t, c.Contains(keys[2])) - - c.AddSized(keys[3], struct{}{}, 500) - assert.Equal(t, 2, c.Len()) - assert.False(t, c.Contains(keys[1])) - assert.True(t, c.Contains(keys[3])) -} - -func TestCapacityLRUCache_AddSizedEvictionBySizeInBytesOneLargeElementShouldWork(t *testing.T) { - t.Parallel() - - c, _ := NewCapacityLRU(100000, 1000) - - keys := []string{"key1", "key2", "key3", "key4"} - - c.AddSized(keys[0], struct{}{}, 500) - assert.Equal(t, 1, c.Len()) - - c.AddSized(keys[1], struct{}{}, 500) - assert.Equal(t, 2, c.Len()) - - c.AddSized(keys[2], struct{}{}, 500) - assert.Equal(t, 2, c.Len()) - assert.False(t, c.Contains(keys[0])) - assert.True(t, c.Contains(keys[2])) - - c.AddSized(keys[3], struct{}{}, 500000) - assert.Equal(t, 1, c.Len()) - assert.False(t, c.Contains(keys[0])) - assert.False(t, c.Contains(keys[1])) - assert.False(t, c.Contains(keys[2])) - assert.True(t, c.Contains(keys[3])) -} - -func TestCapacityLRUCache_AddSizedEvictionBySizeInBytesOneLargeElementEvictedBySmallElementsShouldWork(t *testing.T) { - t.Parallel() - - c, _ := NewCapacityLRU(100000, 1000) - - keys := []string{"key1", "key2", "key3"} - - c.AddSized(keys[0], struct{}{}, 500000) - assert.Equal(t, 1, c.Len()) - - c.AddSized(keys[1], struct{}{}, 500) - 
assert.Equal(t, 1, c.Len()) - - c.AddSized(keys[2], struct{}{}, 500) - assert.Equal(t, 2, c.Len()) - assert.False(t, c.Contains(keys[0])) - assert.True(t, c.Contains(keys[1])) - assert.True(t, c.Contains(keys[2])) -} - -func TestCapacityLRUCache_AddSizedEvictionBySizeInBytesExistingOneLargeElementShouldWork(t *testing.T) { - t.Parallel() - - c, _ := NewCapacityLRU(100000, 1000) - - keys := []string{"key1", "key2"} - - c.AddSized(keys[0], struct{}{}, 500) - assert.Equal(t, 1, c.Len()) - - c.AddSized(keys[1], struct{}{}, 500) - assert.Equal(t, 2, c.Len()) - - c.AddSized(keys[0], struct{}{}, 500000) - assert.Equal(t, 1, c.Len()) - assert.True(t, c.Contains(keys[0])) - assert.False(t, c.Contains(keys[1])) -} - -//------- AddSizedIfMissing - -func TestCapacityLRUCache_AddSizedIfMissing(t *testing.T) { - t.Parallel() - - c := createDefaultCache() - data := []byte("data1") - key := "key" - - found, evicted := c.AddSizedIfMissing(key, data, 1) - assert.False(t, found) - assert.False(t, evicted) - - v, ok := c.Get(key) - assert.True(t, ok) - assert.NotNil(t, v) - assert.Equal(t, data, v) - - data2 := []byte("data2") - found, evicted = c.AddSizedIfMissing(key, data2, 1) - assert.True(t, found) - assert.False(t, evicted) - - v, ok = c.Get(key) - assert.True(t, ok) - assert.NotNil(t, v) - assert.Equal(t, data, v) -} - -func TestCapacityLRUCache_AddSizedIfMissingNegativeSizeInBytesShouldReturnFalse(t *testing.T) { - t.Parallel() - - c := createDefaultCache() - data := []byte("data1") - key := "key" - - has, evicted := c.AddSizedIfMissing(key, data, -1) - assert.False(t, has) - assert.False(t, evicted) - assert.Equal(t, 0, c.Len()) -} - -//------- Get - -func TestCapacityLRUCache_GetShouldWork(t *testing.T) { - t.Parallel() - - key := "key" - value := &struct{ A int }{A: 10} - - c := createDefaultCache() - c.AddSized(key, value, 0) - - recovered, exists := c.Get(key) - assert.True(t, value == recovered) //pointer testing - assert.True(t, exists) - - recovered, exists = c.Get("key not found") - assert.Nil(t, recovered) - assert.False(t, exists) -} - -//------- Purge - -func TestCapacityLRUCache_PurgeShouldWork(t *testing.T) { - t.Parallel() - - c, _ := NewCapacityLRU(100000, 1000) - - keys := []string{"key1", "key2"} - c.AddSized(keys[0], struct{}{}, 500) - c.AddSized(keys[1], struct{}{}, 500) - - c.Purge() - - assert.Equal(t, 0, c.Len()) - assert.Equal(t, int64(0), c.currentCapacityInBytes) -} - -//------- Peek - -func TestCapacityLRUCache_PeekNotFoundShouldWork(t *testing.T) { - t.Parallel() - - c, _ := NewCapacityLRU(100000, 1000) - val, found := c.Peek("key not found") - - assert.Nil(t, val) - assert.False(t, found) -} - -func TestCapacityLRUCache_PeekShouldWork(t *testing.T) { - t.Parallel() - - c, _ := NewCapacityLRU(100000, 1000) - key1 := "key1" - key2 := "key2" - val1 := &struct{}{} - - c.AddSized(key1, val1, 0) - c.AddSized(key2, struct{}{}, 0) - - //at this point key2 is more "recent" than key1 - assert.True(t, c.evictList.Front().Value.(*entry).key == key2) - - val, found := c.Peek(key1) - assert.True(t, val == val1) //pointer testing - assert.True(t, found) - - //recentness should not have been altered - assert.True(t, c.evictList.Front().Value.(*entry).key == key2) -} - -//------- Remove - -func TestCapacityLRUCache_RemoveNotFoundShouldWork(t *testing.T) { - t.Parallel() - - c, _ := NewCapacityLRU(100000, 1000) - removed := c.Remove("key not found") - - assert.False(t, removed) -} - -func TestCapacityLRUCache_RemovedShouldWork(t *testing.T) { - t.Parallel() - - c, _ := 
NewCapacityLRU(100000, 1000) - key1 := "key1" - key2 := "key2" - - c.AddSized(key1, struct{}{}, 0) - c.AddSized(key2, struct{}{}, 0) - - assert.Equal(t, 2, c.Len()) - - c.Remove(key1) - - assert.Equal(t, 1, c.Len()) - assert.True(t, c.Contains(key2)) -} - -// ---------- AddSizedAndReturnEvicted - -func TestCapacityLRUCache_AddSizedAndReturnEvictedNegativeSizeInBytesShouldReturn(t *testing.T) { - t.Parallel() - - c := createDefaultCache() - data := []byte("test") - key := "key" - c.AddSizedAndReturnEvicted(key, data, -1) - - assert.Equal(t, 0, c.Len()) -} - -func TestCapacityLRUCache_AddSizedAndReturnEvictedSimpleTestShouldWork(t *testing.T) { - t.Parallel() - - c := createDefaultCache() - data := []byte("test") - key := "key" - capacity := int64(5) - c.AddSizedAndReturnEvicted(key, data, capacity) - - v, ok := c.Get(key) - assert.True(t, ok) - assert.NotNil(t, v) - assert.Equal(t, data, v) - - keys := c.Keys() - assert.Equal(t, 1, len(keys)) - assert.Equal(t, key, keys[0]) -} - -func TestCapacityLRUCache_AddSizedAndReturnEvictedEvictionByCacheSizeShouldWork(t *testing.T) { - t.Parallel() - - c, _ := NewCapacityLRU(3, 100000) - - keys := []string{"key1", "key2", "key3", "key4", "key5"} - values := []string{"val1", "val2", "val3", "val4", "val5"} - - evicted := c.AddSizedAndReturnEvicted(keys[0], values[0], int64(len(values[0]))) - assert.Equal(t, 0, len(evicted)) - assert.Equal(t, 1, c.Len()) - - evicted = c.AddSizedAndReturnEvicted(keys[1], values[1], int64(len(values[1]))) - assert.Equal(t, 0, len(evicted)) - assert.Equal(t, 2, c.Len()) - - evicted = c.AddSizedAndReturnEvicted(keys[2], values[2], int64(len(values[2]))) - assert.Equal(t, 0, len(evicted)) - assert.Equal(t, 3, c.Len()) - - evicted = c.AddSizedAndReturnEvicted(keys[3], values[3], int64(len(values[3]))) - assert.Equal(t, 3, c.Len()) - assert.False(t, c.Contains(keys[0])) - assert.True(t, c.Contains(keys[3])) - assert.Equal(t, 1, len(evicted)) - assert.Equal(t, values[0], evicted[keys[0]]) - - evicted = c.AddSizedAndReturnEvicted(keys[4], values[4], int64(len(values[4]))) - assert.Equal(t, 3, c.Len()) - assert.False(t, c.Contains(keys[1])) - assert.True(t, c.Contains(keys[4])) - assert.Equal(t, 1, len(evicted)) - assert.Equal(t, values[1], evicted[keys[1]]) -} - -func TestCapacityLRUCache_AddSizedAndReturnEvictedEvictionBySizeInBytesShouldWork(t *testing.T) { - t.Parallel() - - c, _ := NewCapacityLRU(100000, 1000) - - keys := []string{"key1", "key2", "key3", "key4"} - values := []string{"val1", "val2", "val3", "val4"} - - evicted := c.AddSizedAndReturnEvicted(keys[0], values[0], 500) - assert.Equal(t, 0, len(evicted)) - assert.Equal(t, 1, c.Len()) - - evicted = c.AddSizedAndReturnEvicted(keys[1], values[1], 500) - assert.Equal(t, 0, len(evicted)) - assert.Equal(t, 2, c.Len()) - - evicted = c.AddSizedAndReturnEvicted(keys[2], values[2], 500) - assert.Equal(t, 2, c.Len()) - assert.False(t, c.Contains(keys[0])) - assert.True(t, c.Contains(keys[2])) - assert.Equal(t, 1, len(evicted)) - assert.Equal(t, values[0], evicted[keys[0]]) - - evicted = c.AddSizedAndReturnEvicted(keys[3], values[3], 500) - assert.Equal(t, 2, c.Len()) - assert.False(t, c.Contains(keys[1])) - assert.True(t, c.Contains(keys[3])) - assert.Equal(t, 1, len(evicted)) - assert.Equal(t, values[1], evicted[keys[1]]) -} - -func TestCapacityLRUCache_AddSizedAndReturnEvictedEvictionBySizeInBytesOneLargeElementShouldWork(t *testing.T) { - t.Parallel() - - c, _ := NewCapacityLRU(100000, 1000) - - keys := []string{"key1", "key2", "key3", "key4"} - values := []string{"val1", 
"val2", "val3", "val4"} - - evicted := c.AddSizedAndReturnEvicted(keys[0], values[0], 500) - assert.Equal(t, 0, len(evicted)) - assert.Equal(t, 1, c.Len()) - - evicted = c.AddSizedAndReturnEvicted(keys[1], values[1], 500) - assert.Equal(t, 0, len(evicted)) - assert.Equal(t, 2, c.Len()) - - evicted = c.AddSizedAndReturnEvicted(keys[2], values[2], 500) - assert.Equal(t, 2, c.Len()) - assert.False(t, c.Contains(keys[0])) - assert.True(t, c.Contains(keys[2])) - assert.Equal(t, 1, len(evicted)) - assert.Equal(t, values[0], evicted[keys[0]]) - - evicted = c.AddSizedAndReturnEvicted(keys[3], values[3], 500000) - assert.Equal(t, 1, c.Len()) - assert.False(t, c.Contains(keys[0])) - assert.False(t, c.Contains(keys[1])) - assert.False(t, c.Contains(keys[2])) - assert.True(t, c.Contains(keys[3])) - assert.Equal(t, 2, len(evicted)) - assert.Equal(t, values[1], evicted[keys[1]]) - assert.Equal(t, values[2], evicted[keys[2]]) -} - -func TestCapacityLRUCache_AddSizedAndReturnEvictedEvictionBySizeInBytesOneLargeElementEvictedBySmallElementsShouldWork(t *testing.T) { - t.Parallel() - - c, _ := NewCapacityLRU(100000, 1000) - - keys := []string{"key1", "key2", "key3"} - values := []string{"val1", "val2", "val3"} - - evicted := c.AddSizedAndReturnEvicted(keys[0], values[0], 500000) - assert.Equal(t, 0, len(evicted)) - assert.Equal(t, 1, c.Len()) - - evicted = c.AddSizedAndReturnEvicted(keys[1], values[1], 500) - assert.Equal(t, 1, c.Len()) - assert.Equal(t, 1, len(evicted)) - assert.Equal(t, values[0], evicted[keys[0]]) - - evicted = c.AddSizedAndReturnEvicted(keys[2], values[2], 500) - assert.Equal(t, 0, len(evicted)) - assert.Equal(t, 2, c.Len()) - assert.False(t, c.Contains(keys[0])) - assert.True(t, c.Contains(keys[1])) - assert.True(t, c.Contains(keys[2])) -} diff --git a/storage/lrucache/export_test.go b/storage/lrucache/export_test.go deleted file mode 100644 index 92889ed2690..00000000000 --- a/storage/lrucache/export_test.go +++ /dev/null @@ -1,5 +0,0 @@ -package lrucache - -func (c *lruCache) AddedDataHandlers() map[string]func(key []byte, value interface{}) { - return c.mapDataHandlers -} diff --git a/storage/lrucache/lrucache.go b/storage/lrucache/lrucache.go deleted file mode 100644 index d035280fc81..00000000000 --- a/storage/lrucache/lrucache.go +++ /dev/null @@ -1,195 +0,0 @@ -package lrucache - -import ( - "sync" - - logger "github.com/ElrondNetwork/elrond-go-logger" - "github.com/ElrondNetwork/elrond-go/storage" - "github.com/ElrondNetwork/elrond-go/storage/lrucache/capacity" - lru "github.com/hashicorp/golang-lru" -) - -var _ storage.Cacher = (*lruCache)(nil) - -var log = logger.GetOrCreate("storage/lrucache") - -// LRUCache implements a Least Recently Used eviction cache -type lruCache struct { - cache storage.SizedLRUCacheHandler - maxsize int - - mutAddedDataHandlers sync.RWMutex - mapDataHandlers map[string]func(key []byte, value interface{}) -} - -// NewCache creates a new LRU cache instance -func NewCache(size int) (*lruCache, error) { - cache, err := lru.New(size) - if err != nil { - return nil, err - } - - c := createLRUCache(size, cache) - - return c, nil -} - -// NewCacheWithEviction creates a new sized LRU cache instance with eviction function -func NewCacheWithEviction(size int, onEvicted func(key interface{}, value interface{})) (*lruCache, error) { - cache, err := lru.NewWithEvict(size, onEvicted) - if err != nil { - return nil, err - } - - c := createLRUCache(size, cache) - - return c, nil -} - -func createLRUCache(size int, cache *lru.Cache) *lruCache { - c := &lruCache{ - 
cache: &simpleLRUCacheAdapter{ - LRUCacheHandler: cache, - }, - maxsize: size, - mutAddedDataHandlers: sync.RWMutex{}, - mapDataHandlers: make(map[string]func(key []byte, value interface{})), - } - return c -} - -// NewCacheWithSizeInBytes creates a new sized LRU cache instance -func NewCacheWithSizeInBytes(size int, sizeInBytes int64) (*lruCache, error) { - cache, err := capacity.NewCapacityLRU(size, sizeInBytes) - if err != nil { - return nil, err - } - - c := &lruCache{ - cache: cache, - maxsize: size, - mutAddedDataHandlers: sync.RWMutex{}, - mapDataHandlers: make(map[string]func(key []byte, value interface{})), - } - - return c, nil -} - -// Clear is used to completely clear the cache. -func (c *lruCache) Clear() { - c.cache.Purge() -} - -// Put adds a value to the cache. Returns true if an eviction occurred. -func (c *lruCache) Put(key []byte, value interface{}, sizeInBytes int) (evicted bool) { - evicted = c.cache.AddSized(string(key), value, int64(sizeInBytes)) - - c.callAddedDataHandlers(key, value) - - return evicted -} - -// RegisterHandler registers a new handler to be called when a new data is added -func (c *lruCache) RegisterHandler(handler func(key []byte, value interface{}), id string) { - if handler == nil { - log.Error("attempt to register a nil handler to a cacher object") - return - } - - c.mutAddedDataHandlers.Lock() - c.mapDataHandlers[id] = handler - c.mutAddedDataHandlers.Unlock() -} - -// UnRegisterHandler removes the handler from the list -func (c *lruCache) UnRegisterHandler(id string) { - c.mutAddedDataHandlers.Lock() - delete(c.mapDataHandlers, id) - c.mutAddedDataHandlers.Unlock() -} - -// Get looks up a key's value from the cache. -func (c *lruCache) Get(key []byte) (value interface{}, ok bool) { - return c.cache.Get(string(key)) -} - -// Has checks if a key is in the cache, without updating the -// recent-ness or deleting it for being stale. -func (c *lruCache) Has(key []byte) bool { - return c.cache.Contains(string(key)) -} - -// Peek returns the key value (or undefined if not found) without updating -// the "recently used"-ness of the key. -func (c *lruCache) Peek(key []byte) (value interface{}, ok bool) { - v, ok := c.cache.Peek(string(key)) - - if !ok { - return nil, ok - } - - return v, ok -} - -// HasOrAdd checks if a key is in the cache without updating the -// recent-ness or deleting it for being stale, and if not, adds the value. -// Returns whether found and whether an eviction occurred. -func (c *lruCache) HasOrAdd(key []byte, value interface{}, sizeInBytes int) (has, added bool) { - has, _ = c.cache.AddSizedIfMissing(string(key), value, int64(sizeInBytes)) - - if !has { - c.callAddedDataHandlers(key, value) - } - - return has, !has -} - -func (c *lruCache) callAddedDataHandlers(key []byte, value interface{}) { - c.mutAddedDataHandlers.RLock() - for _, handler := range c.mapDataHandlers { - go handler(key, value) - } - c.mutAddedDataHandlers.RUnlock() -} - -// Remove removes the provided key from the cache. -func (c *lruCache) Remove(key []byte) { - c.cache.Remove(string(key)) -} - -// Keys returns a slice of the keys in the cache, from oldest to newest. -func (c *lruCache) Keys() [][]byte { - res := c.cache.Keys() - r := make([][]byte, len(res)) - - for i := 0; i < len(res); i++ { - r[i] = []byte(res[i].(string)) - } - - return r -} - -// Len returns the number of items in the cache. 
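Before the remaining accessors (Len follows below), a sketch of the notification hook this wrapper layers over the plain LRU: handlers registered through RegisterHandler run on their own goroutines after each Put (and after a HasOrAdd that actually adds), per callAddedDataHandlers above. The handler id and the sleep duration are illustrative.

package main

import (
	"fmt"
	"time"

	"github.com/ElrondNetwork/elrond-go/storage/lrucache"
)

func main() {
	cache, err := lrucache.NewCache(100)
	if err != nil {
		panic(err)
	}

	// Each registered handler is invoked asynchronously for every added entry.
	cache.RegisterHandler(func(key []byte, value interface{}) {
		fmt.Printf("added %s\n", string(key))
	}, "logger-id")

	cache.Put([]byte("key"), "value", 0) // size argument is ignored by the plain LRU adapter
	time.Sleep(100 * time.Millisecond)   // give the async handler time to fire
}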
-func (c *lruCache) Len() int { - return c.cache.Len() -} - -// SizeInBytesContained returns the size in bytes of all contained elements -func (c *lruCache) SizeInBytesContained() uint64 { - return c.cache.SizeInBytesContained() -} - -// MaxSize returns the maximum number of items which can be stored in cache. -func (c *lruCache) MaxSize() int { - return c.maxsize -} - -// Close does nothing for this cacher implementation -func (c *lruCache) Close() error { - return nil -} - -// IsInterfaceNil returns true if there is no value under the interface -func (c *lruCache) IsInterfaceNil() bool { - return c == nil -} diff --git a/storage/lrucache/lrucache_test.go b/storage/lrucache/lrucache_test.go deleted file mode 100644 index 911b1194c23..00000000000 --- a/storage/lrucache/lrucache_test.go +++ /dev/null @@ -1,420 +0,0 @@ -package lrucache_test - -import ( - "bytes" - "fmt" - "sync" - "testing" - "time" - - "github.com/ElrondNetwork/elrond-go-core/core/check" - "github.com/ElrondNetwork/elrond-go/storage" - "github.com/ElrondNetwork/elrond-go/storage/lrucache" - "github.com/stretchr/testify/assert" -) - -var timeoutWaitForWaitGroups = time.Second * 2 - -//------- NewCache - -func TestNewCache_BadSizeShouldErr(t *testing.T) { - t.Parallel() - - c, err := lrucache.NewCache(0) - - assert.True(t, check.IfNil(c)) - assert.NotNil(t, err) -} - -func TestNewCache_ShouldWork(t *testing.T) { - t.Parallel() - - c, err := lrucache.NewCache(1) - - assert.False(t, check.IfNil(c)) - assert.Nil(t, err) -} - -//------- NewCacheWithSizeInBytes - -func TestNewCacheWithSizeInBytes_BadSizeShouldErr(t *testing.T) { - t.Parallel() - - c, err := lrucache.NewCacheWithSizeInBytes(0, 100000) - - assert.True(t, check.IfNil(c)) - assert.Equal(t, storage.ErrCacheSizeInvalid, err) -} - -func TestNewCacheWithSizeInBytes_BadSizeInBytesShouldErr(t *testing.T) { - t.Parallel() - - c, err := lrucache.NewCacheWithSizeInBytes(1, 0) - - assert.True(t, check.IfNil(c)) - assert.Equal(t, storage.ErrCacheCapacityInvalid, err) -} - -func TestNewCacheWithSizeInBytes_ShouldWork(t *testing.T) { - t.Parallel() - - c, err := lrucache.NewCacheWithSizeInBytes(1, 100000) - - assert.False(t, check.IfNil(c)) - assert.Nil(t, err) -} - -func TestLRUCache_PutNotPresent(t *testing.T) { - t.Parallel() - - key, val := []byte("key"), []byte("value") - c, _ := lrucache.NewCache(10) - - l := c.Len() - - assert.Zero(t, l, "cache expected to be empty") - - c.Put(key, val, 0) - l = c.Len() - - assert.Equal(t, l, 1, "cache size expected 1 but found %d", l) -} - -func TestLRUCache_PutPresent(t *testing.T) { - t.Parallel() - - key, val := []byte("key"), []byte("value") - c, _ := lrucache.NewCache(10) - - c.Put(key, val, 0) - c.Put(key, val, 0) - - l := c.Len() - assert.Equal(t, l, 1, "cache size expected 1 but found %d", l) -} - -func TestLRUCache_PutPresentRewrite(t *testing.T) { - t.Parallel() - - key := []byte("key") - val1 := []byte("value1") - val2 := []byte("value2") - c, _ := lrucache.NewCache(10) - - c.Put(key, val1, 0) - c.Put(key, val2, 0) - - l := c.Len() - assert.Equal(t, l, 1, "cache size expected 1 but found %d", l) - recoveredVal, has := c.Get(key) - assert.True(t, has) - assert.Equal(t, val2, recoveredVal) -} - -func TestLRUCache_GetNotPresent(t *testing.T) { - t.Parallel() - - key := []byte("key1") - c, _ := lrucache.NewCache(10) - - v, ok := c.Get(key) - - assert.False(t, ok, "value %s not expected to be found", v) -} - -func TestLRUCache_GetPresent(t *testing.T) { - t.Parallel() - - key, val := []byte("key2"), []byte("value2") - c, _ := 
lrucache.NewCache(10) - - c.Put(key, val, 0) - - v, ok := c.Get(key) - - assert.True(t, ok, "value expected but not found") - assert.Equal(t, val, v) -} - -func TestLRUCache_HasNotPresent(t *testing.T) { - t.Parallel() - - key := []byte("key3") - c, _ := lrucache.NewCache(10) - - found := c.Has(key) - - assert.False(t, found, "key %s not expected to be found", key) -} - -func TestLRUCache_HasPresent(t *testing.T) { - t.Parallel() - - key, val := []byte("key4"), []byte("value4") - c, _ := lrucache.NewCache(10) - - c.Put(key, val, 0) - - found := c.Has(key) - - assert.True(t, found, "value expected but not found") -} - -func TestLRUCache_PeekNotPresent(t *testing.T) { - t.Parallel() - - key := []byte("key5") - c, _ := lrucache.NewCache(10) - - _, ok := c.Peek(key) - - assert.False(t, ok, "not expected to find key %s", key) -} - -func TestLRUCache_PeekPresent(t *testing.T) { - t.Parallel() - - key, val := []byte("key6"), []byte("value6") - c, _ := lrucache.NewCache(10) - - c.Put(key, val, 0) - v, ok := c.Peek(key) - - assert.True(t, ok, "value expected but not found") - assert.Equal(t, val, v, "expected to find %s but found %s", val, v) -} - -func TestLRUCache_HasOrAddNotPresent(t *testing.T) { - t.Parallel() - - key, val := []byte("key7"), []byte("value7") - c, _ := lrucache.NewCache(10) - - _, ok := c.Peek(key) - assert.False(t, ok, "not expected to find key %s", key) - - c.HasOrAdd(key, val, 0) - v, ok := c.Peek(key) - assert.True(t, ok, "value expected but not found") - assert.Equal(t, val, v, "expected to find %s but found %s", val, v) -} - -func TestLRUCache_HasOrAddPresent(t *testing.T) { - t.Parallel() - - key, val := []byte("key8"), []byte("value8") - c, _ := lrucache.NewCache(10) - - _, ok := c.Peek(key) - - assert.False(t, ok, "not expected to find key %s", key) - - c.HasOrAdd(key, val, 0) - v, ok := c.Peek(key) - - assert.True(t, ok, "value expected but not found") - assert.Equal(t, val, v, "expected to find %s but found %s", val, v) -} - -func TestLRUCache_RemoveNotPresent(t *testing.T) { - t.Parallel() - - key := []byte("key9") - c, _ := lrucache.NewCache(10) - - found := c.Has(key) - - assert.False(t, found, "not expected to find key %s", key) - - c.Remove(key) - found = c.Has(key) - - assert.False(t, found, "not expected to find key %s", key) -} - -func TestLRUCache_RemovePresent(t *testing.T) { - t.Parallel() - - key, val := []byte("key10"), []byte("value10") - c, _ := lrucache.NewCache(10) - - c.Put(key, val, 0) - found := c.Has(key) - - assert.True(t, found, "expected to find key %s", key) - - c.Remove(key) - found = c.Has(key) - - assert.False(t, found, "not expected to find key %s", key) -} - -func TestLRUCache_Keys(t *testing.T) { - t.Parallel() - - c, _ := lrucache.NewCache(10) - - for i := 0; i < 20; i++ { - key, val := []byte(fmt.Sprintf("key%d", i)), []byte(fmt.Sprintf("value%d", i)) - c.Put(key, val, 0) - } - - keys := c.Keys() - - // check also that cache size does not grow over the capacity - assert.Equal(t, 10, len(keys), "expected cache size 10 but current size %d", len(keys)) -} - -func TestLRUCache_Len(t *testing.T) { - t.Parallel() - - c, _ := lrucache.NewCache(10) - - for i := 0; i < 20; i++ { - key, val := []byte(fmt.Sprintf("key%d", i)), []byte(fmt.Sprintf("value%d", i)) - c.Put(key, val, 0) - } - - l := c.Len() - - assert.Equal(t, 10, l, "expected cache size 10 but current size %d", l) -} - -func TestLRUCache_Clear(t *testing.T) { - t.Parallel() - - c, _ := lrucache.NewCache(10) - - for i := 0; i < 5; i++ { - key, val := []byte(fmt.Sprintf("key%d", i)), 
[]byte(fmt.Sprintf("value%d", i)) - c.Put(key, val, 0) - } - - l := c.Len() - - assert.Equal(t, 5, l, "expected size 5, got %d", l) - - c.Clear() - l = c.Len() - - assert.Zero(t, l, "expected size 0, got %d", l) -} - -func TestLRUCache_CacherRegisterAddedDataHandlerNilHandlerShouldIgnore(t *testing.T) { - t.Parallel() - - c, _ := lrucache.NewCache(100) - c.RegisterHandler(nil, "") - - assert.Equal(t, 0, len(c.AddedDataHandlers())) -} - -func TestLRUCache_CacherRegisterPutAddedDataHandlerShouldWork(t *testing.T) { - t.Parallel() - - wg := sync.WaitGroup{} - wg.Add(1) - chDone := make(chan bool) - - f := func(key []byte, value interface{}) { - if !bytes.Equal([]byte("aaaa"), key) { - return - } - - wg.Done() - } - - go func() { - wg.Wait() - chDone <- true - }() - - c, _ := lrucache.NewCache(100) - c.RegisterHandler(f, "") - c.Put([]byte("aaaa"), "bbbb", 0) - - select { - case <-chDone: - case <-time.After(timeoutWaitForWaitGroups): - assert.Fail(t, "should have been called") - return - } - - assert.Equal(t, 1, len(c.AddedDataHandlers())) -} - -func TestLRUCache_CacherRegisterHasOrAddAddedDataHandlerShouldWork(t *testing.T) { - t.Parallel() - - wg := sync.WaitGroup{} - wg.Add(1) - chDone := make(chan bool) - - f := func(key []byte, value interface{}) { - if !bytes.Equal([]byte("aaaa"), key) { - return - } - - wg.Done() - } - - go func() { - wg.Wait() - chDone <- true - }() - - c, _ := lrucache.NewCache(100) - c.RegisterHandler(f, "") - c.HasOrAdd([]byte("aaaa"), "bbbb", 0) - - select { - case <-chDone: - case <-time.After(timeoutWaitForWaitGroups): - assert.Fail(t, "should have been called") - return - } - - assert.Equal(t, 1, len(c.AddedDataHandlers())) -} - -func TestLRUCache_CacherRegisterHasOrAddAddedDataHandlerNotAddedShouldNotCall(t *testing.T) { - t.Parallel() - - wg := sync.WaitGroup{} - wg.Add(1) - chDone := make(chan bool) - - f := func(key []byte, value interface{}) { - wg.Done() - } - - go func() { - wg.Wait() - chDone <- true - }() - - c, _ := lrucache.NewCache(100) - //first add, no call - c.HasOrAdd([]byte("aaaa"), "bbbb", 0) - c.RegisterHandler(f, "") - //second add, should not call as the data was found - c.HasOrAdd([]byte("aaaa"), "bbbb", 0) - - select { - case <-chDone: - assert.Fail(t, "should have not been called") - return - case <-time.After(timeoutWaitForWaitGroups): - } - - assert.Equal(t, 1, len(c.AddedDataHandlers())) -} - -func TestLRUCache_CloseShouldNotErr(t *testing.T) { - t.Parallel() - - c, _ := lrucache.NewCache(1) - - err := c.Close() - assert.Nil(t, err) -} diff --git a/storage/lrucache/simpleLRUCacheAdapter.go b/storage/lrucache/simpleLRUCacheAdapter.go deleted file mode 100644 index 4f0e5aad6ab..00000000000 --- a/storage/lrucache/simpleLRUCacheAdapter.go +++ /dev/null @@ -1,23 +0,0 @@ -package lrucache - -import "github.com/ElrondNetwork/elrond-go/storage" - -// simpleLRUCacheAdapter provides an adapter between LRUCacheHandler and SizeLRUCacheHandler -type simpleLRUCacheAdapter struct { - storage.LRUCacheHandler -} - -// AddSized calls the Add method without the size in bytes parameter -func (slca *simpleLRUCacheAdapter) AddSized(key, value interface{}, _ int64) bool { - return slca.Add(key, value) -} - -// AddSizedIfMissing calls ContainsOrAdd without the size in bytes parameter -func (slca *simpleLRUCacheAdapter) AddSizedIfMissing(key, value interface{}, _ int64) (ok, evicted bool) { - return slca.ContainsOrAdd(key, value) -} - -// SizeInBytesContained returns 0 -func (slca *simpleLRUCacheAdapter) SizeInBytesContained() uint64 { - return 0 -} diff 
--git a/storage/memorydb/lruMemoryDB.go b/storage/memorydb/lruMemoryDB.go deleted file mode 100644 index 1bf5d112257..00000000000 --- a/storage/memorydb/lruMemoryDB.go +++ /dev/null @@ -1,106 +0,0 @@ -package memorydb - -import ( - "github.com/ElrondNetwork/elrond-go/storage" - "github.com/ElrondNetwork/elrond-go/storage/lrucache" -) - -var _ storage.Persister = (*lruDB)(nil) - -// lruDB represents the memory database storage. It holds a LRU of key value pairs -// and a mutex to handle concurrent accesses to the map -type lruDB struct { - cacher storage.Cacher -} - -// NewlruDB creates a lruDB according to size -func NewlruDB(size uint32) (*lruDB, error) { - cacher, err := lrucache.NewCache(int(size)) - if err != nil { - return nil, err - } - - return &lruDB{cacher: cacher}, nil -} - -// Put adds the value to the (key, val) storage medium -func (l *lruDB) Put(key, val []byte) error { - _ = l.cacher.Put(key, val, len(val)) - return nil -} - -// Get gets the value associated to the key, or reports an error -func (l *lruDB) Get(key []byte) ([]byte, error) { - val, ok := l.cacher.Get(key) - if !ok { - return nil, storage.ErrKeyNotFound - } - - mrsVal, ok := val.([]byte) - if !ok { - return nil, storage.ErrKeyNotFound - } - return mrsVal, nil -} - -// Has returns true if the given key is present in the persistence medium, false otherwise -func (l *lruDB) Has(key []byte) error { - has := l.cacher.Has(key) - if has { - return nil - } - return storage.ErrKeyNotFound -} - -// Close closes the files/resources associated to the storage medium -func (l *lruDB) Close() error { - l.cacher.Clear() - return nil -} - -// Remove removes the data associated to the given key -func (l *lruDB) Remove(key []byte) error { - l.cacher.Remove(key) - return nil -} - -// Destroy removes the storage medium stored data -func (l *lruDB) Destroy() error { - l.cacher.Clear() - return nil -} - -// DestroyClosed removes the already closed storage medium stored data -func (l *lruDB) DestroyClosed() error { - return l.Destroy() -} - -// RangeKeys will iterate over all contained (key, value) pairs calling the provided handler -func (l *lruDB) RangeKeys(handler func(key []byte, value []byte) bool) { - if handler == nil { - return - } - - keys := l.cacher.Keys() - for _, k := range keys { - v, ok := l.cacher.Get(k) - if !ok { - continue - } - - vBuff, ok := v.([]byte) - if !ok { - continue - } - - shouldContinue := handler(k, vBuff) - if !shouldContinue { - return - } - } -} - -// IsInterfaceNil returns true if there is no value under the interface -func (l *lruDB) IsInterfaceNil() bool { - return l == nil -} diff --git a/storage/memorydb/lruMemoryDB_test.go b/storage/memorydb/lruMemoryDB_test.go deleted file mode 100644 index 88124f98378..00000000000 --- a/storage/memorydb/lruMemoryDB_test.go +++ /dev/null @@ -1,157 +0,0 @@ -package memorydb_test - -import ( - "testing" - - "github.com/ElrondNetwork/elrond-go/storage" - "github.com/ElrondNetwork/elrond-go/storage/memorydb" - "github.com/stretchr/testify/assert" -) - -func TestLruDB_LruDB_InitBadSize(t *testing.T) { - mdb, err := memorydb.NewlruDB(0) - assert.Nil(t, mdb) - assert.NotNil(t, err) -} - -func TestLruDB_PutNoError(t *testing.T) { - key, val := []byte("key"), []byte("value") - mdb, err := memorydb.NewlruDB(10000) - - assert.Nil(t, err, "failed to create memorydb: %s", err) - - err = mdb.Put(key, val) - - assert.Nil(t, err, "error saving in db") -} - -func TestLruDB_GetPresent(t *testing.T) { - key, val := []byte("key1"), []byte("value1") - mdb, err := 
memorydb.NewlruDB(10000) - - assert.Nil(t, err, "failed to create memorydb: %s", err) - - err = mdb.Put(key, val) - - assert.Nil(t, err, "error saving in db") - - v, err := mdb.Get(key) - - assert.Nil(t, err, "error not expected but got %s", err) - assert.Equal(t, val, v, "expected %s but got %s", val, v) -} - -func TestLruDB_GetNotPresent(t *testing.T) { - key := []byte("key2") - mdb, err := memorydb.NewlruDB(10000) - - assert.Nil(t, err, "failed to create memorydb: %s", err) - - v, err := mdb.Get(key) - - assert.NotNil(t, err, "error expected but got nil, value %s", v) -} - -func TestLruDB_HasPresent(t *testing.T) { - key, val := []byte("key3"), []byte("value3") - mdb, err := memorydb.NewlruDB(10000) - - assert.Nil(t, err, "failed to create memorydb: %s", err) - - err = mdb.Put(key, val) - - assert.Nil(t, err, "error saving in db") - - err = mdb.Has(key) - - assert.Nil(t, err, "error not expected but got %s", err) -} - -func TestLruDB_HasNotPresent(t *testing.T) { - key := []byte("key4") - mdb, err := memorydb.NewlruDB(10000) - - assert.Nil(t, err, "failed to create memorydb: %s", err) - - err = mdb.Has(key) - - assert.Equal(t, storage.ErrKeyNotFound, err) -} - -func TestLruDB_DeletePresent(t *testing.T) { - key, val := []byte("key5"), []byte("value5") - mdb, err := memorydb.NewlruDB(10000) - - assert.Nil(t, err, "failed to create memorydb: %s", err) - - err = mdb.Put(key, val) - - assert.Nil(t, err, "error saving in db") - - err = mdb.Remove(key) - - assert.Nil(t, err, "no error expected but got %s", err) - - err = mdb.Has(key) - - assert.Equal(t, storage.ErrKeyNotFound, err) -} - -func TestLruDB_DeleteNotPresent(t *testing.T) { - key := []byte("key6") - mdb, err := memorydb.NewlruDB(10000) - - assert.Nil(t, err, "failed to create memorydb: %s", err) - - err = mdb.Remove(key) - - assert.Nil(t, err, "no error expected but got %s", err) -} - -func TestLruDB_Close(t *testing.T) { - mdb, err := memorydb.NewlruDB(10000) - - assert.Nil(t, err, "failed to create memorydb: %s", err) - - err = mdb.Close() - - assert.Nil(t, err, "no error expected but got %s", err) -} - -func TestLruDB_Destroy(t *testing.T) { - mdb, err := memorydb.NewlruDB(10000) - - assert.Nil(t, err, "failed to create memorydb: %s", err) - - err = mdb.Destroy() - - assert.Nil(t, err, "no error expected but got %s", err) -} - -func TestLruDB_RangeKeys(t *testing.T) { - t.Parallel() - - mdb, _ := memorydb.NewlruDB(10000) - - keysVals := map[string][]byte{ - "key1": []byte("value1"), - "key2": []byte("value2"), - "key3": []byte("value3"), - "key4": []byte("value4"), - "key5": []byte("value5"), - "key6": []byte("value6"), - "key7": []byte("value7"), - } - - for key, val := range keysVals { - _ = mdb.Put([]byte(key), val) - } - - recovered := make(map[string][]byte) - mdb.RangeKeys(func(key []byte, value []byte) bool { - recovered[string(key)] = value - return true - }) - - assert.Equal(t, keysVals, recovered) -} diff --git a/storage/memorydb/memorydb.go b/storage/memorydb/memorydb.go deleted file mode 100644 index 582e335fbe9..00000000000 --- a/storage/memorydb/memorydb.go +++ /dev/null @@ -1,117 +0,0 @@ -package memorydb - -import ( - "encoding/base64" - "errors" - "fmt" - "sync" - - "github.com/ElrondNetwork/elrond-go/storage" -) - -var _ storage.Persister = (*DB)(nil) - -// DB represents the memory database storage. 
It holds a map of key value pairs -// and a mutex to handle concurrent accesses to the map -type DB struct { - db map[string][]byte - mutx sync.RWMutex -} - -// New creates a new memorydb object -func New() *DB { - return &DB{ - db: make(map[string][]byte), - mutx: sync.RWMutex{}, - } -} - -// Put adds the value to the (key, val) storage medium -func (s *DB) Put(key, val []byte) error { - s.mutx.Lock() - defer s.mutx.Unlock() - - s.db[string(key)] = val - - return nil -} - -// Get gets the value associated to the key, or reports an error -func (s *DB) Get(key []byte) ([]byte, error) { - s.mutx.RLock() - defer s.mutx.RUnlock() - - val, ok := s.db[string(key)] - - if !ok { - return nil, fmt.Errorf("key: %s not found", base64.StdEncoding.EncodeToString(key)) - } - - return val, nil -} - -// Has returns true if the given key is present in the persistence medium, false otherwise -func (s *DB) Has(key []byte) error { - s.mutx.RLock() - defer s.mutx.RUnlock() - - _, ok := s.db[string(key)] - - if !ok { - return errors.New("key not found") - } - return nil -} - -// Close closes the files/resources associated to the storage medium -func (s *DB) Close() error { - // nothing to do - return nil -} - -// Remove removes the data associated to the given key -func (s *DB) Remove(key []byte) error { - s.mutx.Lock() - defer s.mutx.Unlock() - - delete(s.db, string(key)) - - return nil -} - -// Destroy removes the storage medium stored data -func (s *DB) Destroy() error { - s.mutx.Lock() - defer s.mutx.Unlock() - - s.db = make(map[string][]byte) - - return nil -} - -// RangeKeys will iterate over all contained (key, value) pairs calling the provided handler -func (s *DB) RangeKeys(handler func(key []byte, value []byte) bool) { - if handler == nil { - return - } - - s.mutx.RLock() - defer s.mutx.RUnlock() - - for k, v := range s.db { - shouldContinue := handler([]byte(k), v) - if !shouldContinue { - return - } - } -} - -// DestroyClosed removes the storage medium stored data -func (s *DB) DestroyClosed() error { - return s.Destroy() -} - -// IsInterfaceNil returns true if there is no value under the interface -func (s *DB) IsInterfaceNil() bool { - return s == nil -} diff --git a/storage/memorydb/memorydb_test.go b/storage/memorydb/memorydb_test.go deleted file mode 100644 index 53e5f8e24e4..00000000000 --- a/storage/memorydb/memorydb_test.go +++ /dev/null @@ -1,122 +0,0 @@ -package memorydb_test - -import ( - "testing" - - "github.com/ElrondNetwork/elrond-go/storage/memorydb" - "github.com/stretchr/testify/assert" -) - -func TestPutNoError(t *testing.T) { - key, val := []byte("key"), []byte("value") - mdb := memorydb.New() - - err := mdb.Put(key, val) - assert.Nil(t, err, "error saving in db") -} - -func TestGetPresent(t *testing.T) { - key, val := []byte("key1"), []byte("value1") - mdb := memorydb.New() - - err := mdb.Put(key, val) - assert.Nil(t, err, "error saving in db") - - v, err := mdb.Get(key) - assert.Nil(t, err, "error not expected but got %s", err) - assert.Equal(t, val, v, "expected %s but got %s", val, v) -} - -func TestGetNotPresent(t *testing.T) { - key := []byte("key2") - mdb := memorydb.New() - - v, err := mdb.Get(key) - assert.NotNil(t, err, "error expected but got nil, value %s", v) -} - -func TestHasPresent(t *testing.T) { - key, val := []byte("key3"), []byte("value3") - mdb := memorydb.New() - - err := mdb.Put(key, val) - assert.Nil(t, err, "error saving in db") - - err = mdb.Has(key) - assert.Nil(t, err, "error not expected but got %s", err) -} - -func TestHasNotPresent(t *testing.T) { - 
key := []byte("key4") - mdb := memorydb.New() - - err := mdb.Has(key) - assert.NotNil(t, err) - assert.Contains(t, err.Error(), "key not found") -} - -func TestDeletePresent(t *testing.T) { - key, val := []byte("key5"), []byte("value5") - mdb := memorydb.New() - - err := mdb.Put(key, val) - assert.Nil(t, err, "error saving in db") - - err = mdb.Remove(key) - assert.Nil(t, err, "no error expected but got %s", err) - - err = mdb.Has(key) - assert.NotNil(t, err, "element not expected as already deleted") - assert.Contains(t, err.Error(), "key not found") -} - -func TestDeleteNotPresent(t *testing.T) { - key := []byte("key6") - mdb := memorydb.New() - - err := mdb.Remove(key) - assert.Nil(t, err, "no error expected but got %s", err) -} - -func TestClose(t *testing.T) { - mdb := memorydb.New() - - err := mdb.Close() - assert.Nil(t, err, "no error expected but got %s", err) -} - -func TestDestroy(t *testing.T) { - mdb := memorydb.New() - - err := mdb.Destroy() - assert.Nil(t, err, "no error expected but got %s", err) -} - -func Test_RangeKeys(t *testing.T) { - t.Parallel() - - mdb := memorydb.New() - - keysVals := map[string][]byte{ - "key1": []byte("value1"), - "key2": []byte("value2"), - "key3": []byte("value3"), - "key4": []byte("value4"), - "key5": []byte("value5"), - "key6": []byte("value6"), - "key7": []byte("value7"), - } - - for key, val := range keysVals { - _ = mdb.Put([]byte(key), val) - } - - recovered := make(map[string][]byte) - - mdb.RangeKeys(func(key []byte, value []byte) bool { - recovered[string(key)] = value - return true - }) - - assert.Equal(t, keysVals, recovered) -} diff --git a/storage/monitoring.go b/storage/monitoring.go deleted file mode 100644 index b6bd6a4af6c..00000000000 --- a/storage/monitoring.go +++ /dev/null @@ -1,17 +0,0 @@ -package storage - -import ( - "github.com/ElrondNetwork/elrond-go-core/core" - "github.com/ElrondNetwork/elrond-go-core/core/atomic" - logger "github.com/ElrondNetwork/elrond-go-logger" -) - -var log = logger.GetOrCreate("storage") - -var cumulatedSizeInBytes atomic.Counter - -// MonitorNewCache adds the size in the global cumulated size variable -func MonitorNewCache(tag string, sizeInBytes uint64) { - cumulatedSizeInBytes.Add(int64(sizeInBytes)) - log.Debug("MonitorNewCache", "name", tag, "capacity", core.ConvertBytes(sizeInBytes), "cumulated", core.ConvertBytes(cumulatedSizeInBytes.GetUint64())) -} diff --git a/storage/pathmanager/errors.go b/storage/pathmanager/errors.go new file mode 100644 index 00000000000..835b436efdc --- /dev/null +++ b/storage/pathmanager/errors.go @@ -0,0 +1,20 @@ +package pathmanager + +import ( + "errors" +) + +// ErrEmptyPruningPathTemplate signals that an empty path template for pruning storers has been provided +var ErrEmptyPruningPathTemplate = errors.New("empty path template for pruning storers") + +// ErrEmptyStaticPathTemplate signals that an empty path template for static storers has been provided +var ErrEmptyStaticPathTemplate = errors.New("empty path template for static storers") + +// ErrInvalidPruningPathTemplate signals that an invalid path template for pruning storers has been provided +var ErrInvalidPruningPathTemplate = errors.New("invalid path template for pruning storers") + +// ErrInvalidStaticPathTemplate signals that an invalid path template for static storers has been provided +var ErrInvalidStaticPathTemplate = errors.New("invalid path template for static storers") + +// ErrInvalidDatabasePath signals that an invalid database path has been provided +var ErrInvalidDatabasePath = 
errors.New("invalid database path") diff --git a/storage/pathmanager/pathManager.go b/storage/pathmanager/pathManager.go index 8b6d3644994..6c0520428b6 100644 --- a/storage/pathmanager/pathManager.go +++ b/storage/pathmanager/pathManager.go @@ -4,7 +4,6 @@ import ( "fmt" "strings" - "github.com/ElrondNetwork/elrond-go/common" "github.com/ElrondNetwork/elrond-go/storage" ) @@ -20,24 +19,24 @@ type PathManager struct { // NewPathManager will return a new instance of PathManager if the provided arguments are fine func NewPathManager(pruningPathTemplate string, staticPathTemplate string, databasePath string) (*PathManager, error) { if len(pruningPathTemplate) == 0 { - return nil, storage.ErrEmptyPruningPathTemplate + return nil, ErrEmptyPruningPathTemplate } - if !strings.Contains(pruningPathTemplate, common.PathEpochPlaceholder) || - !strings.Contains(pruningPathTemplate, common.PathShardPlaceholder) || - !strings.Contains(pruningPathTemplate, common.PathIdentifierPlaceholder) { - return nil, storage.ErrInvalidPruningPathTemplate + if !strings.Contains(pruningPathTemplate, storage.PathEpochPlaceholder) || + !strings.Contains(pruningPathTemplate, storage.PathShardPlaceholder) || + !strings.Contains(pruningPathTemplate, storage.PathIdentifierPlaceholder) { + return nil, ErrInvalidPruningPathTemplate } if len(staticPathTemplate) == 0 { - return nil, storage.ErrEmptyStaticPathTemplate + return nil, ErrEmptyStaticPathTemplate } - if !strings.Contains(staticPathTemplate, common.PathShardPlaceholder) || - !strings.Contains(staticPathTemplate, common.PathIdentifierPlaceholder) { - return nil, storage.ErrInvalidStaticPathTemplate + if !strings.Contains(staticPathTemplate, storage.PathShardPlaceholder) || + !strings.Contains(staticPathTemplate, storage.PathIdentifierPlaceholder) { + return nil, ErrInvalidStaticPathTemplate } if len(databasePath) == 0 { - return nil, storage.ErrInvalidDatabasePath + return nil, ErrInvalidDatabasePath } return &PathManager{ @@ -50,9 +49,9 @@ func NewPathManager(pruningPathTemplate string, staticPathTemplate string, datab // PathForEpoch will return the new path for a pruning storer func (pm *PathManager) PathForEpoch(shardId string, epoch uint32, identifier string) string { path := pm.pruningPathTemplate - path = strings.Replace(path, common.PathEpochPlaceholder, fmt.Sprintf("%d", epoch), 1) - path = strings.Replace(path, common.PathShardPlaceholder, shardId, 1) - path = strings.Replace(path, common.PathIdentifierPlaceholder, identifier, 1) + path = strings.Replace(path, storage.PathEpochPlaceholder, fmt.Sprintf("%d", epoch), 1) + path = strings.Replace(path, storage.PathShardPlaceholder, shardId, 1) + path = strings.Replace(path, storage.PathIdentifierPlaceholder, identifier, 1) return path } @@ -60,8 +59,8 @@ func (pm *PathManager) PathForEpoch(shardId string, epoch uint32, identifier str // PathForStatic will return the path for a static storer func (pm *PathManager) PathForStatic(shardId string, identifier string) string { path := pm.staticPathTemplate - path = strings.Replace(path, common.PathShardPlaceholder, shardId, 1) - path = strings.Replace(path, common.PathIdentifierPlaceholder, identifier, 1) + path = strings.Replace(path, storage.PathShardPlaceholder, shardId, 1) + path = strings.Replace(path, storage.PathIdentifierPlaceholder, identifier, 1) return path } diff --git a/storage/pathmanager/pathManager_test.go b/storage/pathmanager/pathManager_test.go index 23ebe34f829..63b5871e0df 100644 --- a/storage/pathmanager/pathManager_test.go +++ 
b/storage/pathmanager/pathManager_test.go @@ -3,7 +3,6 @@ package pathmanager_test import ( "testing" - "github.com/ElrondNetwork/elrond-go/storage" "github.com/ElrondNetwork/elrond-go/storage/pathmanager" "github.com/stretchr/testify/assert" ) @@ -13,7 +12,7 @@ func TestNewPathManager_EmptyPruningPathTemplateShouldErr(t *testing.T) { pm, err := pathmanager.NewPathManager("", "shard_[S]/[I]", "db") assert.Nil(t, pm) - assert.Equal(t, storage.ErrEmptyPruningPathTemplate, err) + assert.Equal(t, pathmanager.ErrEmptyPruningPathTemplate, err) } func TestNewPathManager_EmptyStaticPathTemplateShouldErr(t *testing.T) { @@ -21,7 +20,7 @@ func TestNewPathManager_EmptyStaticPathTemplateShouldErr(t *testing.T) { pm, err := pathmanager.NewPathManager("epoch_[E]/shard_[S]/[I]", "", "db") assert.Nil(t, pm) - assert.Equal(t, storage.ErrEmptyStaticPathTemplate, err) + assert.Equal(t, pathmanager.ErrEmptyStaticPathTemplate, err) } func TestNewPathManager_InvalidPruningPathTemplate_NoShardPlaceholder_ShouldErr(t *testing.T) { @@ -29,7 +28,7 @@ func TestNewPathManager_InvalidPruningPathTemplate_NoShardPlaceholder_ShouldErr( pm, err := pathmanager.NewPathManager("epoch_[E]/shard/[I]", "shard_[S]/[I]", "db") assert.Nil(t, pm) - assert.Equal(t, storage.ErrInvalidPruningPathTemplate, err) + assert.Equal(t, pathmanager.ErrInvalidPruningPathTemplate, err) } func TestNewPathManager_InvalidPruningPathTemplate_NoEpochPlaceholder_ShouldErr(t *testing.T) { @@ -37,7 +36,7 @@ func TestNewPathManager_InvalidPruningPathTemplate_NoEpochPlaceholder_ShouldErr( pm, err := pathmanager.NewPathManager("epoch/shard_[S]/[I]", "shard_[S]/[I]", "db") assert.Nil(t, pm) - assert.Equal(t, storage.ErrInvalidPruningPathTemplate, err) + assert.Equal(t, pathmanager.ErrInvalidPruningPathTemplate, err) } func TestNewPathManager_InvalidPathPruningTemplate_NoIdentifierPlaceholder_ShouldErr(t *testing.T) { @@ -45,7 +44,7 @@ func TestNewPathManager_InvalidPathPruningTemplate_NoIdentifierPlaceholder_Shoul pm, err := pathmanager.NewPathManager("epoch_[E]/shard_[S]", "shard_[S]/[I]", "db") assert.Nil(t, pm) - assert.Equal(t, storage.ErrInvalidPruningPathTemplate, err) + assert.Equal(t, pathmanager.ErrInvalidPruningPathTemplate, err) } func TestNewPathManager_InvalidStaticPathTemplate_NoShardPlaceholder_ShouldErr(t *testing.T) { @@ -53,7 +52,7 @@ func TestNewPathManager_InvalidStaticPathTemplate_NoShardPlaceholder_ShouldErr(t pm, err := pathmanager.NewPathManager("epoch_[E]/shard_[S]/[I]", "shard/[I]", "db") assert.Nil(t, pm) - assert.Equal(t, storage.ErrInvalidStaticPathTemplate, err) + assert.Equal(t, pathmanager.ErrInvalidStaticPathTemplate, err) } func TestNewPathManager_InvalidStaticPathTemplate_NoIdentifierPlaceholder_ShouldErr(t *testing.T) { @@ -61,7 +60,7 @@ func TestNewPathManager_InvalidStaticPathTemplate_NoIdentifierPlaceholder_Should pm, err := pathmanager.NewPathManager("epoch_[E]/shard_[S]/[I]", "shard_[S]", "db") assert.Nil(t, pm) - assert.Equal(t, storage.ErrInvalidStaticPathTemplate, err) + assert.Equal(t, pathmanager.ErrInvalidStaticPathTemplate, err) } func TestNewPathManager_OkValsShouldWork(t *testing.T) { diff --git a/storage/pruning/fullHistoryPruningStorer.go b/storage/pruning/fullHistoryPruningStorer.go index bbdeae9c6c1..9d706c672a4 100644 --- a/storage/pruning/fullHistoryPruningStorer.go +++ b/storage/pruning/fullHistoryPruningStorer.go @@ -7,7 +7,7 @@ import ( storageCore "github.com/ElrondNetwork/elrond-go-core/storage" "github.com/ElrondNetwork/elrond-go/storage" - "github.com/ElrondNetwork/elrond-go/storage/lrucache" + 
"github.com/ElrondNetwork/elrond-go/storage/cache" ) // FullHistoryPruningStorer represents a storer for full history nodes @@ -60,7 +60,7 @@ func initFullHistoryPruningStorer(args *FullHistoryStorerArgs, shardId string) ( args: args.StorerArgs, shardId: shardId, } - fhps.oldEpochsActivePersistersCache, err = lrucache.NewCacheWithEviction(int(args.NumOfOldActivePersisters), fhps.onEvicted) + fhps.oldEpochsActivePersistersCache, err = cache.NewLRUCacheWithEviction(int(args.NumOfOldActivePersisters), fhps.onEvicted) if err != nil { return nil, err } diff --git a/storage/pruning/fullHistoryPruningStorer_test.go b/storage/pruning/fullHistoryPruningStorer_test.go index 790624c65f7..7d7d507c734 100644 --- a/storage/pruning/fullHistoryPruningStorer_test.go +++ b/storage/pruning/fullHistoryPruningStorer_test.go @@ -15,8 +15,8 @@ import ( logger "github.com/ElrondNetwork/elrond-go-logger" "github.com/ElrondNetwork/elrond-go/config" "github.com/ElrondNetwork/elrond-go/storage" + "github.com/ElrondNetwork/elrond-go/storage/database" "github.com/ElrondNetwork/elrond-go/storage/factory" - "github.com/ElrondNetwork/elrond-go/storage/memorydb" "github.com/ElrondNetwork/elrond-go/storage/pathmanager" "github.com/ElrondNetwork/elrond-go/storage/pruning" "github.com/stretchr/testify/assert" @@ -115,7 +115,7 @@ func TestNewFullHistoryPruningStorer_GetAfterEvictShouldWork(t *testing.T) { t.Parallel() persistersByPath := make(map[string]storage.Persister) - persistersByPath["Epoch_0"] = memorydb.New() + persistersByPath["Epoch_0"] = database.NewMemDB() args := getDefaultArgs() args.DbPath = "Epoch_0" args.EpochsData.NumOfActivePersisters = 1 diff --git a/storage/pruning/pruningStorer.go b/storage/pruning/pruningStorer.go index 28d6774b9d8..bf0d778803e 100644 --- a/storage/pruning/pruningStorer.go +++ b/storage/pruning/pruningStorer.go @@ -16,10 +16,9 @@ import ( logger "github.com/ElrondNetwork/elrond-go-logger" "github.com/ElrondNetwork/elrond-go/common" "github.com/ElrondNetwork/elrond-go/epochStart/notifier" - elrondErrors "github.com/ElrondNetwork/elrond-go/errors" "github.com/ElrondNetwork/elrond-go/storage" "github.com/ElrondNetwork/elrond-go/storage/clean" - "github.com/ElrondNetwork/elrond-go/storage/storageUnit" + "github.com/ElrondNetwork/elrond-go/storage/storageunit" ) var _ storage.Storer = (*PruningStorer)(nil) @@ -133,7 +132,7 @@ func initPruningStorer( ) (*PruningStorer, error) { pdb := &PruningStorer{} - cache, err := storageUnit.NewCache(args.CacheConf) + suCache, err := storageunit.NewCache(args.CacheConf) if err != nil { return nil, err } @@ -147,7 +146,7 @@ func initPruningStorer( pdb.identifier = identifier pdb.persisterFactory = args.PersisterFactory pdb.shardCoordinator = args.ShardCoordinator - pdb.cacher = cache + pdb.cacher = suCache pdb.epochPrepareHdr = &block.MetaBlock{Epoch: epochForDefaultEpochPrepareHdr} pdb.epochForPutOperation = args.EpochsData.StartingEpoch pdb.pathManager = args.PathManager @@ -336,13 +335,17 @@ func (ps *PruningStorer) getPersisterToUse() *persisterData { returningPath := "" if persisterToUse != nil { returningPath = persisterToUse.path + log.Debug("active persister not found", + "epoch", ps.epochForPutOperation, + "used", persisterToUse.epoch, + "path", ps.dbPath, + "returning persister", returningPath) + } else { + log.Debug("active persister not found", + "epoch", ps.epochForPutOperation, + "path", ps.dbPath, + "returning persister", returningPath) } - - log.Debug("active persister not found", - "epoch", ps.epochForPutOperation, - "used", 
persisterToUse.epoch, - "path", ps.dbPath, - "returning persister", returningPath) } return persisterToUse @@ -436,7 +439,7 @@ func (ps *PruningStorer) Get(key []byte) ([]byte, error) { for idx := 0; idx < len(ps.activePersisters); idx++ { val, err := ps.activePersisters[idx].persister.Get(key) if err != nil { - if err == elrondErrors.ErrDBIsClosed { + if err == storage.ErrDBIsClosed { numClosedDbs++ } @@ -449,7 +452,7 @@ func (ps *PruningStorer) Get(key []byte) ([]byte, error) { } if numClosedDbs == len(ps.activePersisters) && len(ps.activePersisters) > 0 { - return nil, elrondErrors.ErrDBIsClosed + return nil, storage.ErrDBIsClosed } return nil, fmt.Errorf("key %s not found in %s", hex.EncodeToString(key), ps.identifier) diff --git a/storage/pruning/pruningStorerArgs.go b/storage/pruning/pruningStorerArgs.go index 3bd15253b62..c44eea91ffd 100644 --- a/storage/pruning/pruningStorerArgs.go +++ b/storage/pruning/pruningStorerArgs.go @@ -3,14 +3,14 @@ package pruning import ( "github.com/ElrondNetwork/elrond-go/storage" "github.com/ElrondNetwork/elrond-go/storage/clean" - "github.com/ElrondNetwork/elrond-go/storage/storageUnit" + "github.com/ElrondNetwork/elrond-go/storage/storageunit" ) // StorerArgs will hold the arguments needed for PruningStorer type StorerArgs struct { Identifier string ShardCoordinator storage.ShardCoordinator - CacheConf storageUnit.CacheConfig + CacheConf storageunit.CacheConfig PathManager storage.PathManagerHandler DbPath string PersisterFactory DbFactoryHandler diff --git a/storage/pruning/pruningStorer_test.go b/storage/pruning/pruningStorer_test.go index 467beb68213..5abbccb3234 100644 --- a/storage/pruning/pruningStorer_test.go +++ b/storage/pruning/pruningStorer_test.go @@ -20,14 +20,13 @@ import ( logger "github.com/ElrondNetwork/elrond-go-logger" "github.com/ElrondNetwork/elrond-go/config" "github.com/ElrondNetwork/elrond-go/storage" + "github.com/ElrondNetwork/elrond-go/storage/database" + "github.com/ElrondNetwork/elrond-go/storage/directoryhandler" "github.com/ElrondNetwork/elrond-go/storage/factory" - "github.com/ElrondNetwork/elrond-go/storage/factory/directoryhandler" - "github.com/ElrondNetwork/elrond-go/storage/leveldb" - "github.com/ElrondNetwork/elrond-go/storage/memorydb" "github.com/ElrondNetwork/elrond-go/storage/mock" "github.com/ElrondNetwork/elrond-go/storage/pathmanager" "github.com/ElrondNetwork/elrond-go/storage/pruning" - "github.com/ElrondNetwork/elrond-go/storage/storageUnit" + "github.com/ElrondNetwork/elrond-go/storage/storageunit" "github.com/ElrondNetwork/elrond-go/testscommon" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -35,13 +34,13 @@ import ( var log = logger.GetOrCreate("storage/pruning_test") -func getDummyConfig() (storageUnit.CacheConfig, storageUnit.DBConfig) { - cacheConf := storageUnit.CacheConfig{ +func getDummyConfig() (storageunit.CacheConfig, storageunit.DBConfig) { + cacheConf := storageunit.CacheConfig{ Capacity: 10, Type: "LRU", Shards: 3, } - dbConf := storageUnit.DBConfig{ + dbConf := storageunit.DBConfig{ FilePath: "path/Epoch_0/Shard_1", Type: "LvlDBSerial", BatchDelaySeconds: 500, @@ -63,7 +62,7 @@ func getDefaultArgs() *pruning.StorerArgs { persister, exists := persistersMap[path] if !exists { - persister = memorydb.New() + persister = database.NewMemDB() persistersMap[path] = persister } @@ -97,7 +96,7 @@ func getDefaultArgsSerialDB() *pruning.StorerArgs { cacheConf.Capacity = 40 persisterFactory := &mock.PersisterFactoryStub{ CreateCalled: func(path string) 
(storage.Persister, error) { - return leveldb.NewSerialDB(path, 1, 20, 10) + return database.NewSerialDB(path, 1, 20, 10) }, } pathManager := &testscommon.PathManagerStub{PathForEpochCalled: func(shardId string, epoch uint32, identifier string) string { @@ -270,14 +269,14 @@ func TestPruningStorer_Put_EpochWhichWasSetDoesNotExistShouldNotFind(t *testing. args := getDefaultArgs() persistersByPath := make(map[string]storage.Persister) - persistersByPath["Epoch_0"] = memorydb.New() + persistersByPath["Epoch_0"] = database.NewMemDB() args.PersisterFactory = &mock.PersisterFactoryStub{ // simulate an opening of an existing database from the file path by saving activePersisters in a map based on their path CreateCalled: func(path string) (storage.Persister, error) { if _, ok := persistersByPath[path]; ok { return persistersByPath[path], nil } - newPers := memorydb.New() + newPers := database.NewMemDB() persistersByPath[path] = newPers return newPers, nil @@ -301,17 +300,17 @@ func TestPruningStorer_Put_ShouldPutInSpecifiedEpoch(t *testing.T) { args := getDefaultArgs() persistersByPath := make(map[string]storage.Persister) - persistersByPath["Epoch_0"] = memorydb.New() + persistersByPath["Epoch_0"] = database.NewMemDB() expectedEpoch := uint32(37) key := fmt.Sprintf("Epoch_%d", expectedEpoch) - persistersByPath[key] = memorydb.New() + persistersByPath[key] = database.NewMemDB() args.PersisterFactory = &mock.PersisterFactoryStub{ // simulate an opening of an existing database from the file path by saving activePersisters in a map based on their path CreateCalled: func(path string) (storage.Persister, error) { if _, ok := persistersByPath[path]; ok { return persistersByPath[path], nil } - newPers := memorydb.New() + newPers := database.NewMemDB() persistersByPath[path] = newPers return newPers, nil @@ -442,7 +441,7 @@ func TestNewPruningStorer_GetDataFromClosedPersister(t *testing.T) { t.Parallel() persistersByPath := make(map[string]storage.Persister) - persistersByPath["Epoch_0"] = memorydb.New() + persistersByPath["Epoch_0"] = database.NewMemDB() args := getDefaultArgs() args.DbPath = "Epoch_0" args.PersisterFactory = &mock.PersisterFactoryStub{ @@ -451,7 +450,7 @@ func TestNewPruningStorer_GetDataFromClosedPersister(t *testing.T) { if _, ok := persistersByPath[path]; ok { return persistersByPath[path], nil } - newPers := memorydb.New() + newPers := database.NewMemDB() persistersByPath[path] = newPers return newPers, nil @@ -489,7 +488,7 @@ func TestNewPruningStorer_GetBulkFromEpoch(t *testing.T) { t.Parallel() persistersByPath := make(map[string]storage.Persister) - persistersByPath["Epoch_0"] = memorydb.New() + persistersByPath["Epoch_0"] = database.NewMemDB() args := getDefaultArgs() args.DbPath = "Epoch_0" args.PersisterFactory = &mock.PersisterFactoryStub{ @@ -498,7 +497,7 @@ func TestNewPruningStorer_GetBulkFromEpoch(t *testing.T) { if _, ok := persistersByPath[path]; ok { return persistersByPath[path], nil } - newPers := memorydb.New() + newPers := database.NewMemDB() persistersByPath[path] = newPers return newPers, nil @@ -543,7 +542,7 @@ func TestNewPruningStorer_ChangeEpochDbsShouldNotBeDeletedIfPruningIsDisabled(t if _, ok := persistersByPath[path]; ok { return persistersByPath[path], nil } - newPers := memorydb.New() + newPers := database.NewMemDB() persistersByPath[path] = newPers return newPers, nil @@ -624,7 +623,7 @@ func TestPruningStorer_SearchFirst(t *testing.T) { t.Parallel() persistersByPath := make(map[string]storage.Persister) - persistersByPath["Epoch_0"] = memorydb.New() 
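// Editor's note (not part of the diff): the hunks above and below swap the deleted
// storage/memorydb and storage/leveldb constructors for the consolidated
// storage/database package (database.NewMemDB, database.NewSerialDB). A minimal,
// hypothetical sketch of the renamed in-memory persister, assuming it keeps the
// same Persister method set shown in the deleted storage/memorydb sources above:

package main

import (
	"fmt"

	"github.com/ElrondNetwork/elrond-go/storage/database"
)

func main() {
	db := database.NewMemDB() // replaces the removed memorydb.New()
	_ = db.Put([]byte("key"), []byte("value"))

	val, err := db.Get([]byte("key")) // returns the stored value, or an error when missing
	fmt.Println(string(val), err)
}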
+ persistersByPath["Epoch_0"] = database.NewMemDB() args := getDefaultArgs() args.DbPath = "Epoch_0" args.PersisterFactory = &mock.PersisterFactoryStub{ @@ -633,7 +632,7 @@ func TestPruningStorer_SearchFirst(t *testing.T) { if _, ok := persistersByPath[path]; ok { return persistersByPath[path], nil } - newPers := memorydb.New() + newPers := database.NewMemDB() persistersByPath[path] = newPers return newPers, nil @@ -680,7 +679,7 @@ func TestPruningStorer_ChangeEpochWithKeepingFromOldestEpochInMetaBlock(t *testi t.Parallel() persistersByPath := make(map[string]storage.Persister) - persistersByPath["Epoch_0"] = memorydb.New() + persistersByPath["Epoch_0"] = database.NewMemDB() args := getDefaultArgs() args.DbPath = "Epoch_0" args.PersisterFactory = &mock.PersisterFactoryStub{ @@ -689,7 +688,7 @@ func TestPruningStorer_ChangeEpochWithKeepingFromOldestEpochInMetaBlock(t *testi if _, ok := persistersByPath[path]; ok { return persistersByPath[path], nil } - newPers := memorydb.New() + newPers := database.NewMemDB() persistersByPath[path] = newPers return newPers, nil @@ -750,7 +749,7 @@ func TestPruningStorer_ChangeEpochShouldUseMetaBlockFromEpochPrepare(t *testing. t.Parallel() persistersByPath := make(map[string]storage.Persister) - persistersByPath["Epoch_0"] = memorydb.New() + persistersByPath["Epoch_0"] = database.NewMemDB() args := getDefaultArgs() args.DbPath = "Epoch_0" args.PersisterFactory = &mock.PersisterFactoryStub{ @@ -759,7 +758,7 @@ func TestPruningStorer_ChangeEpochShouldUseMetaBlockFromEpochPrepare(t *testing. if _, ok := persistersByPath[path]; ok { return persistersByPath[path], nil } - newPers := memorydb.New() + newPers := database.NewMemDB() persistersByPath[path] = newPers return newPers, nil @@ -793,7 +792,7 @@ func TestPruningStorer_ChangeEpochWithExisting(t *testing.T) { t.Parallel() persistersByPath := make(map[string]storage.Persister) - persistersByPath["Epoch_0/Shard_0/id"] = memorydb.New() + persistersByPath["Epoch_0/Shard_0/id"] = database.NewMemDB() args := getDefaultArgs() args.DbPath = "Epoch_0" args.PersisterFactory = &mock.PersisterFactoryStub{ @@ -802,7 +801,7 @@ func TestPruningStorer_ChangeEpochWithExisting(t *testing.T) { if _, ok := persistersByPath[path]; ok { return persistersByPath[path], nil } - newPers := memorydb.New() + newPers := database.NewMemDB() persistersByPath[path] = newPers return newPers, nil diff --git a/storage/pruning/triePruningStorer.go b/storage/pruning/triePruningStorer.go index 66f7c0ff9f1..a653e1598f6 100644 --- a/storage/pruning/triePruningStorer.go +++ b/storage/pruning/triePruningStorer.go @@ -5,8 +5,9 @@ import ( "encoding/hex" "fmt" + "github.com/ElrondNetwork/elrond-go-core/core" "github.com/ElrondNetwork/elrond-go/common" - elrondErrors "github.com/ElrondNetwork/elrond-go/errors" + "github.com/ElrondNetwork/elrond-go/storage" ) const ( @@ -90,10 +91,10 @@ func (ps *triePruningStorer) PutInEpochWithoutCache(key []byte, data []byte, epo } // GetFromOldEpochsWithoutAddingToCache searches the old epochs for the given key without adding to the cache -func (ps *triePruningStorer) GetFromOldEpochsWithoutAddingToCache(key []byte) ([]byte, error) { +func (ps *triePruningStorer) GetFromOldEpochsWithoutAddingToCache(key []byte) ([]byte, core.OptionalUint32, error) { v, ok := ps.cacher.Get(key) if ok && !bytes.Equal([]byte(common.ActiveDBKey), key) { - return v.([]byte), nil + return v.([]byte), core.OptionalUint32{}, nil } ps.lock.RLock() @@ -103,21 +104,25 @@ func (ps *triePruningStorer) GetFromOldEpochsWithoutAddingToCache(key 
[]byte) ([ for idx := 1; idx < len(ps.activePersisters); idx++ { val, err := ps.activePersisters[idx].persister.Get(key) if err != nil { - if err == elrondErrors.ErrDBIsClosed { + if err == storage.ErrDBIsClosed { numClosedDbs++ } continue } - return val, nil + epoch := core.OptionalUint32{ + Value: ps.activePersisters[idx].epoch, + HasValue: true, + } + return val, epoch, nil } if numClosedDbs+1 == len(ps.activePersisters) && len(ps.activePersisters) > 1 { - return nil, elrondErrors.ErrDBIsClosed + return nil, core.OptionalUint32{}, storage.ErrDBIsClosed } - return nil, fmt.Errorf("key %s not found in %s", hex.EncodeToString(key), ps.identifier) + return nil, core.OptionalUint32{}, fmt.Errorf("key %s not found in %s", hex.EncodeToString(key), ps.identifier) } // GetFromLastEpoch searches only the last epoch storer for the given key diff --git a/storage/pruning/triePruningStorer_test.go b/storage/pruning/triePruningStorer_test.go index df97693c492..124f1cc9698 100644 --- a/storage/pruning/triePruningStorer_test.go +++ b/storage/pruning/triePruningStorer_test.go @@ -5,7 +5,6 @@ import ( "testing" "github.com/ElrondNetwork/elrond-go/common" - "github.com/ElrondNetwork/elrond-go/errors" "github.com/ElrondNetwork/elrond-go/storage" "github.com/ElrondNetwork/elrond-go/storage/mock" "github.com/ElrondNetwork/elrond-go/storage/pruning" @@ -14,7 +13,7 @@ import ( "github.com/stretchr/testify/require" ) -func TestTriePruningStorer_GetFromOldEpochsWithoutCacheSearchesOnlyOldEpochs(t *testing.T) { +func TestTriePruningStorer_GetFromOldEpochsWithoutCacheSearchesOnlyOldEpochsAndReturnsEpoch(t *testing.T) { t.Parallel() args := getDefaultArgs() @@ -38,13 +37,16 @@ func TestTriePruningStorer_GetFromOldEpochsWithoutCacheSearchesOnlyOldEpochs(t * assert.Nil(t, err) assert.Equal(t, 0, len(cacher.Keys())) - res, err := ps.GetFromOldEpochsWithoutAddingToCache(testKey1) + res, epoch, err := ps.GetFromOldEpochsWithoutAddingToCache(testKey1) assert.Equal(t, testVal1, res) assert.Nil(t, err) + assert.True(t, epoch.HasValue) + assert.Equal(t, uint32(0), epoch.Value) - res, err = ps.GetFromOldEpochsWithoutAddingToCache(testKey2) + res, epoch, err = ps.GetFromOldEpochsWithoutAddingToCache(testKey2) assert.Nil(t, res) assert.NotNil(t, err) + assert.False(t, epoch.HasValue) assert.True(t, strings.Contains(err.Error(), "not found")) } @@ -63,7 +65,7 @@ func TestTriePruningStorer_GetFromOldEpochsWithoutCacheLessActivePersisters(t *t assert.Equal(t, 1, ps.GetNumActivePersisters()) _ = ps.ChangeEpochSimple(1) - val, err := ps.GetFromOldEpochsWithoutAddingToCache(testKey1) + val, _, err := ps.GetFromOldEpochsWithoutAddingToCache(testKey1) assert.Nil(t, err) assert.Equal(t, testVal1, val) } @@ -86,7 +88,7 @@ func TestTriePruningStorer_GetFromOldEpochsWithoutCacheMoreActivePersisters(t *t _ = ps.ChangeEpochSimple(2) _ = ps.ChangeEpochSimple(3) - val, err := ps.GetFromOldEpochsWithoutAddingToCache(testKey1) + val, _, err := ps.GetFromOldEpochsWithoutAddingToCache(testKey1) assert.Nil(t, err) assert.Equal(t, testVal1, val) } @@ -105,7 +107,7 @@ func TestTriePruningStorer_GetFromOldEpochsWithoutCacheAllPersistersClosed(t *te if !exists { persister = &mock.PersisterStub{ GetCalled: func(key []byte) ([]byte, error) { - return nil, errors.ErrDBIsClosed + return nil, storage.ErrDBIsClosed }, } persistersMap[path] = persister @@ -122,9 +124,9 @@ func TestTriePruningStorer_GetFromOldEpochsWithoutCacheAllPersistersClosed(t *te _ = ps.ChangeEpochSimple(3) _ = ps.Close() - val, err := 
ps.GetFromOldEpochsWithoutAddingToCache([]byte("key")) + val, _, err := ps.GetFromOldEpochsWithoutAddingToCache([]byte("key")) assert.Nil(t, val) - assert.Equal(t, errors.ErrDBIsClosed, err) + assert.Equal(t, storage.ErrDBIsClosed, err) } func TestTriePruningStorer_GetFromOldEpochsWithoutCacheDoesNotSearchInCurrentStorer(t *testing.T) { @@ -145,7 +147,7 @@ func TestTriePruningStorer_GetFromOldEpochsWithoutCacheDoesNotSearchInCurrentSto assert.Nil(t, err) ps.ClearCache() - res, err := ps.GetFromOldEpochsWithoutAddingToCache(testKey1) + res, _, err := ps.GetFromOldEpochsWithoutAddingToCache(testKey1) assert.Nil(t, res) assert.NotNil(t, err) assert.True(t, strings.Contains(err.Error(), "not found")) diff --git a/storage/storageCacherAdapter/storageCacherAdapter.go b/storage/storageCacherAdapter/storageCacherAdapter.go deleted file mode 100644 index de678ac4e06..00000000000 --- a/storage/storageCacherAdapter/storageCacherAdapter.go +++ /dev/null @@ -1,284 +0,0 @@ -package storageCacherAdapter - -import ( - "math" - "sync" - - "github.com/ElrondNetwork/elrond-go-core/core/check" - "github.com/ElrondNetwork/elrond-go-core/marshal" - logger "github.com/ElrondNetwork/elrond-go-logger" - "github.com/ElrondNetwork/elrond-go/storage" -) - -var log = logger.GetOrCreate("storageCacherAdapter") - -type storageCacherAdapter struct { - cacher storage.AdaptedSizedLRUCache - db storage.Persister - lock sync.RWMutex - dbIsClosed bool - - storedDataFactory storage.StoredDataFactory - marshalizer marshal.Marshalizer - numValuesInStorage int -} - -// NewStorageCacherAdapter creates a new storageCacherAdapter -func NewStorageCacherAdapter( - cacher storage.AdaptedSizedLRUCache, - db storage.Persister, - storedDataFactory storage.StoredDataFactory, - marshalizer marshal.Marshalizer, -) (*storageCacherAdapter, error) { - if check.IfNil(cacher) { - return nil, storage.ErrNilCacher - } - if check.IfNil(db) { - return nil, storage.ErrNilPersister - } - if check.IfNil(marshalizer) { - return nil, storage.ErrNilMarshalizer - } - if check.IfNil(storedDataFactory) { - return nil, storage.ErrNilStoredDataFactory - } - - return &storageCacherAdapter{ - cacher: cacher, - db: db, - lock: sync.RWMutex{}, - storedDataFactory: storedDataFactory, - marshalizer: marshalizer, - numValuesInStorage: 0, - }, nil -} - -// Clear clears the cache -func (c *storageCacherAdapter) Clear() { - c.lock.Lock() - defer c.lock.Unlock() - - c.cacher.Purge() -} - -// Put adds the given value in the cacher. 
If the cacher is full, the evicted values will be persisted to the db -func (c *storageCacherAdapter) Put(key []byte, value interface{}, sizeInBytes int) bool { - c.lock.Lock() - defer c.lock.Unlock() - - evictedValues := c.cacher.AddSizedAndReturnEvicted(string(key), value, int64(sizeInBytes)) - - if c.dbIsClosed { - return len(evictedValues) != 0 - } - - for evictedKey, evictedVal := range evictedValues { - evictedKeyStr, ok := evictedKey.(string) - if !ok { - log.Warn("invalid key type", "key", evictedKey) - continue - } - - evictedValBytes := getBytes(evictedVal, c.marshalizer) - if len(evictedValBytes) == 0 { - continue - } - - err := c.db.Put([]byte(evictedKeyStr), evictedValBytes) - if err != nil { - log.Error("could not save to db", "error", err) - continue - } - - c.numValuesInStorage++ - } - - return len(evictedValues) != 0 -} - -func getBytes(data interface{}, marshalizer marshal.Marshalizer) []byte { - evictedVal, ok := data.(storage.SerializedStoredData) - if ok { - return evictedVal.GetSerialized() - } - - evictedValBytes, err := marshalizer.Marshal(data) - if err != nil { - log.Error("could not marshal value", "error", err) - return nil - } - - return evictedValBytes -} - -// Get returns the value at the given key -func (c *storageCacherAdapter) Get(key []byte) (interface{}, bool) { - c.lock.RLock() - defer c.lock.RUnlock() - - val, ok := c.cacher.Get(string(key)) - if ok { - return val, true - } - - if c.dbIsClosed { - return nil, false - } - - valBytes, err := c.db.Get(key) - if err != nil { - return nil, false - } - - storedData, err := c.getData(valBytes) - if err != nil { - log.Error("could not get data", "error", err) - return nil, false - } - - return storedData, true -} - -func (c *storageCacherAdapter) getData(serializedData []byte) (interface{}, error) { - storedData := c.storedDataFactory.CreateEmpty() - data, ok := storedData.(storage.SerializedStoredData) - if ok { - data.SetSerialized(serializedData) - return data, nil - } - - err := c.marshalizer.Unmarshal(storedData, serializedData) - if err != nil { - return nil, err - } - - return storedData, nil -} - -// Has checks if the given key is present in the storageUnit -func (c *storageCacherAdapter) Has(key []byte) bool { - c.lock.RLock() - defer c.lock.RUnlock() - - isPresent := c.cacher.Contains(string(key)) - if isPresent { - return true - } - - if c.dbIsClosed { - return false - } - - err := c.db.Has(key) - return err == nil -} - -// Peek returns the value at the given key by searching only in cacher -func (c *storageCacherAdapter) Peek(key []byte) (interface{}, bool) { - c.lock.RLock() - defer c.lock.RUnlock() - - return c.cacher.Peek(string(key)) -} - -// HasOrAdd checks if the value exists and adds it otherwise -func (c *storageCacherAdapter) HasOrAdd(key []byte, value interface{}, sizeInBytes int) (bool, bool) { - ok := c.Has(key) - if ok { - return true, false - } - - added := c.Put(key, value, sizeInBytes) - - return false, added -} - -// Remove deletes the given key from the storageUnit -func (c *storageCacherAdapter) Remove(key []byte) { - c.lock.Lock() - defer c.lock.Unlock() - - removed := c.cacher.Remove(string(key)) - if removed || c.dbIsClosed { - return - } - - err := c.db.Remove(key) - if err == nil { - c.numValuesInStorage-- - } -} - -// Keys returns all the keys present in the storageUnit -func (c *storageCacherAdapter) Keys() [][]byte { - c.lock.RLock() - defer c.lock.RUnlock() - - cacherKeys := c.cacher.Keys() - storedKeys := make([][]byte, 0, len(cacherKeys)) - for i := range cacherKeys { 
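// Editor's note (not part of the diff): the triePruningStorer hunks earlier in this
// diff widen GetFromOldEpochsWithoutAddingToCache to also return a core.OptionalUint32
// carrying the epoch the value was read from (left empty when the value came from the
// cache, as the updated test asserts). A minimal, hypothetical caller-side sketch,
// assuming only the signature and fields visible in those hunks:

package pruningsketch

import (
	"fmt"

	"github.com/ElrondNetwork/elrond-go-core/core"
)

// oldEpochsGetter is a hypothetical narrow interface matching the new method signature.
type oldEpochsGetter interface {
	GetFromOldEpochsWithoutAddingToCache(key []byte) ([]byte, core.OptionalUint32, error)
}

func lookup(storer oldEpochsGetter, key []byte) ([]byte, error) {
	val, epoch, err := storer.GetFromOldEpochsWithoutAddingToCache(key)
	if err != nil {
		// per the hunk: storage.ErrDBIsClosed when every old persister is closed,
		// otherwise a "key ... not found in ..." error
		return nil, err
	}
	if epoch.HasValue {
		fmt.Printf("value found in epoch %d\n", epoch.Value)
	} else {
		fmt.Println("value served from cache; originating epoch unknown")
	}
	return val, nil
}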
- key, ok := cacherKeys[i].(string) - if !ok { - continue - } - - storedKeys = append(storedKeys, []byte(key)) - } - - if c.dbIsClosed { - return storedKeys - } - - getKeys := func(key []byte, _ []byte) bool { - storedKeys = append(storedKeys, key) - return true - } - - c.db.RangeKeys(getKeys) - return storedKeys -} - -// Len returns the number of elements from the storageUnit -func (c *storageCacherAdapter) Len() int { - c.lock.RLock() - defer c.lock.RUnlock() - - cacheLen := c.cacher.Len() - return cacheLen + c.numValuesInStorage -} - -// SizeInBytesContained returns the number of bytes stored in the cache -func (c *storageCacherAdapter) SizeInBytesContained() uint64 { - c.lock.RLock() - defer c.lock.RUnlock() - - return c.cacher.SizeInBytesContained() -} - -// MaxSize returns MaxInt64 -func (c *storageCacherAdapter) MaxSize() int { - return math.MaxInt64 -} - -// RegisterHandler does nothing -func (c *storageCacherAdapter) RegisterHandler(_ func(_ []byte, _ interface{}), _ string) { -} - -// UnRegisterHandler does nothing -func (c *storageCacherAdapter) UnRegisterHandler(_ string) { -} - -// Close closes the underlying db -func (c *storageCacherAdapter) Close() error { - c.lock.Lock() - defer c.lock.Unlock() - - c.dbIsClosed = true - c.numValuesInStorage = 0 - return c.db.Close() -} - -// IsInterfaceNil returns true if there is no value under the interface -func (c *storageCacherAdapter) IsInterfaceNil() bool { - return c == nil -} diff --git a/storage/storageCacherAdapter/storageCacherAdapter_test.go b/storage/storageCacherAdapter/storageCacherAdapter_test.go deleted file mode 100644 index 314ade8ff4e..00000000000 --- a/storage/storageCacherAdapter/storageCacherAdapter_test.go +++ /dev/null @@ -1,699 +0,0 @@ -package storageCacherAdapter - -import ( - "fmt" - "math" - "testing" - - "github.com/ElrondNetwork/elrond-go-core/core/check" - "github.com/ElrondNetwork/elrond-go/process/mock" - "github.com/ElrondNetwork/elrond-go/storage" - storageMock "github.com/ElrondNetwork/elrond-go/storage/mock" - "github.com/ElrondNetwork/elrond-go/testscommon" - trieFactory "github.com/ElrondNetwork/elrond-go/trie/factory" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestNewStorageCacherAdapter_NilCacher(t *testing.T) { - t.Parallel() - - sca, err := NewStorageCacherAdapter( - nil, - &storageMock.PersisterStub{}, - trieFactory.NewTrieNodeFactory(), - &mock.MarshalizerMock{}, - ) - assert.Nil(t, sca) - assert.Equal(t, storage.ErrNilCacher, err) -} - -func TestNewStorageCacherAdapter_NilDB(t *testing.T) { - t.Parallel() - - sca, err := NewStorageCacherAdapter( - &storageMock.AdaptedSizedLruCacheStub{}, - nil, - trieFactory.NewTrieNodeFactory(), - &mock.MarshalizerMock{}, - ) - assert.True(t, check.IfNil(sca)) - assert.Equal(t, storage.ErrNilPersister, err) -} - -func TestNewStorageCacherAdapter_NilStoredDataFactory(t *testing.T) { - t.Parallel() - - sca, err := NewStorageCacherAdapter( - &storageMock.AdaptedSizedLruCacheStub{}, - &storageMock.PersisterStub{}, - nil, - &mock.MarshalizerMock{}, - ) - assert.Nil(t, sca) - assert.Equal(t, storage.ErrNilStoredDataFactory, err) -} - -func TestNewStorageCacherAdapter_NilMarshalizer(t *testing.T) { - t.Parallel() - - sca, err := NewStorageCacherAdapter( - &storageMock.AdaptedSizedLruCacheStub{}, - &storageMock.PersisterStub{}, - trieFactory.NewTrieNodeFactory(), - nil, - ) - assert.Nil(t, sca) - assert.Equal(t, storage.ErrNilMarshalizer, err) -} - -func TestStorageCacherAdapter_Clear(t *testing.T) { - t.Parallel() - - 
purgeCalled := false - sca, err := NewStorageCacherAdapter( - &storageMock.AdaptedSizedLruCacheStub{ - PurgeCalled: func() { - purgeCalled = true - }, - }, - &storageMock.PersisterStub{}, - trieFactory.NewTrieNodeFactory(), - &mock.MarshalizerMock{}, - ) - assert.Nil(t, err) - - sca.Clear() - assert.True(t, purgeCalled) -} - -func TestStorageCacherAdapter_Put(t *testing.T) { - t.Parallel() - - addedKey := "key1" - addedVal := []byte("value1") - addSizedAndReturnEvictedCalled := false - putCalled := false - sca, err := NewStorageCacherAdapter( - &storageMock.AdaptedSizedLruCacheStub{ - AddSizedAndReturnEvictedCalled: func(key, value interface{}, _ int64) map[interface{}]interface{} { - stringKey, ok := key.(string) - assert.True(t, ok) - assert.Equal(t, addedKey, stringKey) - - res := make(map[interface{}]interface{}) - res[100] = 10 - res[stringKey] = value - - addSizedAndReturnEvictedCalled = true - return res - }, - }, - &storageMock.PersisterStub{ - PutCalled: func(key, _ []byte) error { - assert.Equal(t, []byte(addedKey), key) - putCalled = true - return nil - }, - }, - trieFactory.NewTrieNodeFactory(), - &mock.MarshalizerMock{}, - ) - assert.Nil(t, err) - - evicted := sca.Put([]byte(addedKey), addedVal, 100) - assert.True(t, evicted) - assert.True(t, putCalled) - assert.True(t, addSizedAndReturnEvictedCalled) -} - -func TestStorageCacherAdapter_PutWithClosedDB(t *testing.T) { - t.Parallel() - - addedKey := "key1" - addedVal := []byte("value1") - addSizedAndReturnEvictedCalled := false - putCalled := false - sca, err := NewStorageCacherAdapter( - &storageMock.AdaptedSizedLruCacheStub{ - AddSizedAndReturnEvictedCalled: func(key, value interface{}, _ int64) map[interface{}]interface{} { - stringKey, ok := key.(string) - assert.True(t, ok) - assert.Equal(t, addedKey, stringKey) - - res := make(map[interface{}]interface{}) - res[100] = 10 - res[stringKey] = value - - addSizedAndReturnEvictedCalled = true - return res - }, - }, - &storageMock.PersisterStub{ - PutCalled: func(key, _ []byte) error { - assert.Equal(t, []byte(addedKey), key) - putCalled = true - return nil - }, - }, - trieFactory.NewTrieNodeFactory(), - &mock.MarshalizerMock{}, - ) - require.Nil(t, err) - - err = sca.Close() - require.Nil(t, err) - - evicted := sca.Put([]byte(addedKey), addedVal, 100) - assert.True(t, evicted) - assert.False(t, putCalled) - assert.True(t, addSizedAndReturnEvictedCalled) -} - -func TestStorageCacherAdapter_GetFoundInCacherShouldNotCallDbGet(t *testing.T) { - t.Parallel() - - keyString := "key" - cacherGetCalled := false - dbGetCalled := false - sca, err := NewStorageCacherAdapter( - &storageMock.AdaptedSizedLruCacheStub{ - GetCalled: func(key interface{}) (interface{}, bool) { - k, ok := key.(string) - assert.True(t, ok) - assert.Equal(t, keyString, k) - - cacherGetCalled = true - return []byte("val"), true - }, - }, - &storageMock.PersisterStub{ - GetCalled: func(_ []byte) ([]byte, error) { - dbGetCalled = true - return nil, nil - }, - }, - trieFactory.NewTrieNodeFactory(), - &mock.MarshalizerMock{}, - ) - assert.Nil(t, err) - - retrievedVal, _ := sca.Get([]byte(keyString)) - - assert.Equal(t, []byte("val"), retrievedVal) - assert.True(t, cacherGetCalled) - assert.False(t, dbGetCalled) -} - -type testStoredDataImpl struct { -} - -type testStoredData struct { - Key []byte - Value uint64 -} - -func (t *testStoredDataImpl) CreateEmpty() interface{} { - return &testStoredData{} -} - -func (t *testStoredDataImpl) IsInterfaceNil() bool { - return t == nil -} - -func 
TestStorageCacherAdapter_GetFromDb(t *testing.T) { - t.Parallel() - - testData := testStoredData{ - Key: []byte("key"), - Value: 100, - } - - marshalizer := &mock.MarshalizerMock{} - cacherGetCalled := false - dbGetCalled := false - sca, err := NewStorageCacherAdapter( - &storageMock.AdaptedSizedLruCacheStub{ - GetCalled: func(_ interface{}) (interface{}, bool) { - cacherGetCalled = true - return nil, false - }, - }, - &storageMock.PersisterStub{ - GetCalled: func(_ []byte) ([]byte, error) { - dbGetCalled = true - byteData, err := marshalizer.Marshal(testData) - return byteData, err - }, - }, - &testStoredDataImpl{}, - marshalizer, - ) - assert.Nil(t, err) - - retrievedVal, _ := sca.Get([]byte("key")) - - val, ok := retrievedVal.(*testStoredData) - assert.True(t, ok) - assert.Equal(t, testData.Key, val.Key) - assert.Equal(t, testData.Value, val.Value) - assert.True(t, cacherGetCalled) - assert.True(t, dbGetCalled) -} - -func TestStorageCacherAdapter_GetWithClosedDB(t *testing.T) { - t.Parallel() - - marshalizer := &mock.MarshalizerMock{} - cacherGetCalled := false - dbGetCalled := false - sca, err := NewStorageCacherAdapter( - &storageMock.AdaptedSizedLruCacheStub{ - GetCalled: func(_ interface{}) (interface{}, bool) { - cacherGetCalled = true - return nil, false - }, - }, - &storageMock.PersisterStub{ - GetCalled: func(_ []byte) ([]byte, error) { - dbGetCalled = true - return nil, nil - }, - }, - &testStoredDataImpl{}, - marshalizer, - ) - assert.Nil(t, err) - - err = sca.Close() - require.Nil(t, err) - - retrievedVal, _ := sca.Get([]byte("key")) - - val, ok := retrievedVal.(*testStoredData) - assert.False(t, ok) - assert.Nil(t, val) - assert.True(t, cacherGetCalled) - assert.False(t, dbGetCalled) -} - -func TestStorageCacherAdapter_HasReturnsIfFoundInCacher(t *testing.T) { - t.Parallel() - - containsCalled := false - hasCalled := false - sca, err := NewStorageCacherAdapter( - &storageMock.AdaptedSizedLruCacheStub{ - ContainsCalled: func(key interface{}) bool { - _, ok := key.(string) - assert.True(t, ok) - - containsCalled = true - return true - }, - }, - &storageMock.PersisterStub{ - HasCalled: func(_ []byte) error { - hasCalled = true - return nil - }, - }, - trieFactory.NewTrieNodeFactory(), - &mock.MarshalizerMock{}, - ) - assert.Nil(t, err) - - isPresent := sca.Has([]byte("key")) - - assert.True(t, isPresent) - assert.True(t, containsCalled) - assert.False(t, hasCalled) -} - -func TestStorageCacherAdapter_HasReturnsTrueIfFoundInDB(t *testing.T) { - t.Parallel() - - containsCalled := false - hasCalled := false - sca, err := NewStorageCacherAdapter( - &storageMock.AdaptedSizedLruCacheStub{ - ContainsCalled: func(_ interface{}) bool { - containsCalled = true - return false - }, - }, - &storageMock.PersisterStub{ - HasCalled: func(_ []byte) error { - hasCalled = true - return nil - }, - }, - trieFactory.NewTrieNodeFactory(), - &mock.MarshalizerMock{}, - ) - assert.Nil(t, err) - - isPresent := sca.Has([]byte("key")) - - assert.True(t, isPresent) - assert.True(t, containsCalled) - assert.True(t, hasCalled) -} - -func TestStorageCacherAdapter_HasReturnsFalseIfNotFound(t *testing.T) { - t.Parallel() - - containsCalled := false - hasCalled := false - sca, err := NewStorageCacherAdapter( - &storageMock.AdaptedSizedLruCacheStub{ - ContainsCalled: func(_ interface{}) bool { - containsCalled = true - return false - }, - }, - &storageMock.PersisterStub{ - HasCalled: func(_ []byte) error { - hasCalled = true - return fmt.Errorf("not found err") - }, - }, - trieFactory.NewTrieNodeFactory(), - 
&mock.MarshalizerMock{}, - ) - assert.Nil(t, err) - - isPresent := sca.Has([]byte("key")) - - assert.False(t, isPresent) - assert.True(t, containsCalled) - assert.True(t, hasCalled) -} - -func TestStorageCacherAdapter_HasWithClosedDB(t *testing.T) { - t.Parallel() - - containsCalled := false - hasCalled := false - sca, err := NewStorageCacherAdapter( - &storageMock.AdaptedSizedLruCacheStub{ - ContainsCalled: func(_ interface{}) bool { - containsCalled = true - return false - }, - }, - &storageMock.PersisterStub{ - HasCalled: func(_ []byte) error { - hasCalled = true - return nil - }, - }, - trieFactory.NewTrieNodeFactory(), - &mock.MarshalizerMock{}, - ) - assert.Nil(t, err) - - err = sca.Close() - require.Nil(t, err) - - isPresent := sca.Has([]byte("key")) - - assert.False(t, isPresent) - assert.True(t, containsCalled) - assert.False(t, hasCalled) -} - -func TestStorageCacherAdapter_Peek(t *testing.T) { - t.Parallel() - - peekCalled := false - sca, err := NewStorageCacherAdapter( - &storageMock.AdaptedSizedLruCacheStub{ - PeekCalled: func(key interface{}) (interface{}, bool) { - _, ok := key.(string) - assert.True(t, ok) - - peekCalled = true - return "value", true - }, - }, - &storageMock.PersisterStub{}, - trieFactory.NewTrieNodeFactory(), - &mock.MarshalizerMock{}, - ) - assert.Nil(t, err) - - val, ok := sca.Peek([]byte("key")) - - assert.True(t, peekCalled) - assert.True(t, ok) - assert.Equal(t, "value", val) -} - -func TestStorageCacherAdapter_RemoveFromCacherFirst(t *testing.T) { - t.Parallel() - - cacherRemoveCalled := false - dbRemoveCalled := false - sca, err := NewStorageCacherAdapter( - &storageMock.AdaptedSizedLruCacheStub{ - RemoveCalled: func(key interface{}) bool { - _, ok := key.(string) - assert.True(t, ok) - - cacherRemoveCalled = true - return true - }, - }, - &storageMock.PersisterStub{ - RemoveCalled: func(key []byte) error { - dbRemoveCalled = true - return nil - }, - }, - trieFactory.NewTrieNodeFactory(), - &mock.MarshalizerMock{}, - ) - assert.Nil(t, err) - - sca.Remove([]byte("key")) - - assert.True(t, cacherRemoveCalled) - assert.False(t, dbRemoveCalled) -} - -func TestStorageCacherAdapter_RemoveFromDb(t *testing.T) { - t.Parallel() - - cacherRemoveCalled := false - dbRemoveCalled := false - sca, err := NewStorageCacherAdapter( - &storageMock.AdaptedSizedLruCacheStub{ - RemoveCalled: func(_ interface{}) bool { - cacherRemoveCalled = true - return false - }, - }, - &storageMock.PersisterStub{ - RemoveCalled: func(_ []byte) error { - dbRemoveCalled = true - return nil - }, - }, - trieFactory.NewTrieNodeFactory(), - &mock.MarshalizerMock{}, - ) - assert.Nil(t, err) - - sca.Remove([]byte("key")) - - assert.True(t, cacherRemoveCalled) - assert.True(t, dbRemoveCalled) -} - -func TestStorageCacherAdapter_RemoveWithClosedDB(t *testing.T) { - t.Parallel() - - cacherRemoveCalled := false - dbRemoveCalled := false - sca, err := NewStorageCacherAdapter( - &storageMock.AdaptedSizedLruCacheStub{ - RemoveCalled: func(_ interface{}) bool { - cacherRemoveCalled = true - return false - }, - }, - &storageMock.PersisterStub{ - RemoveCalled: func(_ []byte) error { - dbRemoveCalled = true - return nil - }, - }, - trieFactory.NewTrieNodeFactory(), - &mock.MarshalizerMock{}, - ) - assert.Nil(t, err) - - err = sca.Close() - require.Nil(t, err) - - sca.Remove([]byte("key")) - - assert.True(t, cacherRemoveCalled) - assert.False(t, dbRemoveCalled) -} - -func TestStorageCacherAdapter_Keys(t *testing.T) { - t.Parallel() - - db := testscommon.NewMemDbMock() - _ = db.Put([]byte("key"), 
[]byte("val")) - sca, err := NewStorageCacherAdapter( - &storageMock.AdaptedSizedLruCacheStub{ - KeysCalled: func() []interface{} { - return []interface{}{"key2"} - }, - }, - db, - trieFactory.NewTrieNodeFactory(), - &mock.MarshalizerMock{}, - ) - assert.Nil(t, err) - - keys := sca.Keys() - assert.Equal(t, 2, len(keys)) -} - -func TestStorageCacherAdapter_KeysWithClosedDB(t *testing.T) { - t.Parallel() - - db := testscommon.NewMemDbMock() - _ = db.Put([]byte("key"), []byte("val")) - sca, err := NewStorageCacherAdapter( - &storageMock.AdaptedSizedLruCacheStub{ - KeysCalled: func() []interface{} { - return []interface{}{"key2"} - }, - }, - db, - trieFactory.NewTrieNodeFactory(), - &mock.MarshalizerMock{}, - ) - assert.Nil(t, err) - - err = sca.Close() - require.Nil(t, err) - - keys := sca.Keys() - assert.Equal(t, 1, len(keys)) - assert.Equal(t, []byte("key2"), keys[0]) -} - -func TestStorageCacherAdapter_Len(t *testing.T) { - t.Parallel() - - db := testscommon.NewMemDbMock() - sca, err := NewStorageCacherAdapter( - &storageMock.AdaptedSizedLruCacheStub{ - LenCalled: func() int { - return 3 - }, - AddSizedAndReturnEvictedCalled: func(key, value interface{}, sizeInBytes int64) map[interface{}]interface{} { - res := make(map[interface{}]interface{}) - res[key] = value - return res - }, - }, - db, - trieFactory.NewTrieNodeFactory(), - &mock.MarshalizerMock{}, - ) - assert.Nil(t, err) - - _ = sca.Put([]byte("key"), []byte("val"), 3) - numVals := sca.Len() - assert.Equal(t, 4, numVals) -} - -func TestStorageCacherAdapter_SizeInBytesContained(t *testing.T) { - t.Parallel() - - db := testscommon.NewMemDbMock() - _ = db.Put([]byte("key"), []byte("val")) - sca, err := NewStorageCacherAdapter( - &storageMock.AdaptedSizedLruCacheStub{ - SizeInBytesContainedCalled: func() uint64 { - return 1000 - }, - }, - db, - trieFactory.NewTrieNodeFactory(), - &mock.MarshalizerMock{}, - ) - assert.Nil(t, err) - - totalSize := sca.SizeInBytesContained() - assert.Equal(t, uint64(1000), totalSize) -} - -func TestStorageCacherAdapter_MaxSize(t *testing.T) { - t.Parallel() - - sca, err := NewStorageCacherAdapter( - &storageMock.AdaptedSizedLruCacheStub{}, - &storageMock.PersisterStub{}, - trieFactory.NewTrieNodeFactory(), - &mock.MarshalizerMock{}, - ) - assert.Nil(t, err) - - maxSize := sca.MaxSize() - assert.Equal(t, math.MaxInt64, maxSize) -} - -func TestStorageCacherAdapter_RegisterHandler(t *testing.T) { - t.Parallel() - - sca, err := NewStorageCacherAdapter( - &storageMock.AdaptedSizedLruCacheStub{}, - &storageMock.PersisterStub{}, - trieFactory.NewTrieNodeFactory(), - &mock.MarshalizerMock{}, - ) - assert.Nil(t, err) - sca.RegisterHandler(nil, "") -} - -func TestStorageCacherAdapter_UnRegisterHandler(t *testing.T) { - t.Parallel() - - sca, err := NewStorageCacherAdapter( - &storageMock.AdaptedSizedLruCacheStub{}, - &storageMock.PersisterStub{}, - trieFactory.NewTrieNodeFactory(), - &mock.MarshalizerMock{}, - ) - assert.Nil(t, err) - sca.UnRegisterHandler("") -} - -func TestStorageCacherAdapter_Close(t *testing.T) { - t.Parallel() - - closeCalled := false - sca, err := NewStorageCacherAdapter( - &storageMock.AdaptedSizedLruCacheStub{}, - &storageMock.PersisterStub{ - CloseCalled: func() error { - closeCalled = true - return nil - }, - }, - trieFactory.NewTrieNodeFactory(), - &mock.MarshalizerMock{}, - ) - assert.Nil(t, err) - - _ = sca.Close() - assert.True(t, closeCalled) -} diff --git a/storage/storageUnit/nilStorer.go b/storage/storageUnit/nilStorer.go deleted file mode 100644 index b7f85e7745a..00000000000 --- 
a/storage/storageUnit/nilStorer.go +++ /dev/null @@ -1,88 +0,0 @@ -package storageUnit - -import ( - storageCore "github.com/ElrondNetwork/elrond-go-core/storage" - "github.com/ElrondNetwork/elrond-go/storage" -) - -// NilStorer resembles a disabled implementation of the Storer interface -type NilStorer struct { -} - -// NewNilStorer will return a nil storer -func NewNilStorer() *NilStorer { - return new(NilStorer) -} - -// GetFromEpoch will do nothing -func (ns *NilStorer) GetFromEpoch(_ []byte, _ uint32) ([]byte, error) { - return nil, nil -} - -// GetBulkFromEpoch will do nothing -func (ns *NilStorer) GetBulkFromEpoch(_ [][]byte, _ uint32) ([]storageCore.KeyValuePair, error) { - return nil, nil -} - -// SearchFirst will do nothing -func (ns *NilStorer) SearchFirst(_ []byte) ([]byte, error) { - return nil, nil -} - -// Put will do nothing -func (ns *NilStorer) Put(_, _ []byte) error { - return nil -} - -// PutInEpoch will do nothing -func (ns *NilStorer) PutInEpoch(_, _ []byte, _ uint32) error { - return nil -} - -// GetOldestEpoch will return an error that signals that the oldest epoch fetching is not available -func (ns *NilStorer) GetOldestEpoch() (uint32, error) { - return 0, storage.ErrOldestEpochNotAvailable -} - -// Close will do nothing -func (ns *NilStorer) Close() error { - return nil -} - -// Get will do nothing -func (ns *NilStorer) Get(_ []byte) ([]byte, error) { - return nil, nil -} - -// Has will do nothing -func (ns *NilStorer) Has(_ []byte) error { - return nil -} - -// RemoveFromCurrentEpoch will do nothing -func (ns *NilStorer) RemoveFromCurrentEpoch(_ []byte) error { - return nil -} - -// Remove will do nothing -func (ns *NilStorer) Remove(_ []byte) error { - return nil -} - -// ClearCache will do nothing -func (ns *NilStorer) ClearCache() { -} - -// DestroyUnit will do nothing -func (ns *NilStorer) DestroyUnit() error { - return nil -} - -// RangeKeys does nothing -func (ns *NilStorer) RangeKeys(_ func(key []byte, val []byte) bool) { -} - -// IsInterfaceNil returns true if there is no value under the interface -func (ns *NilStorer) IsInterfaceNil() bool { - return ns == nil -} diff --git a/storage/storageUnit/storageunit.go b/storage/storageUnit/storageunit.go deleted file mode 100644 index 65f7eda7bfa..00000000000 --- a/storage/storageUnit/storageunit.go +++ /dev/null @@ -1,409 +0,0 @@ -package storageUnit - -import ( - "encoding/base64" - "encoding/json" - "fmt" - "sync" - "time" - - "github.com/ElrondNetwork/elrond-go-core/core/check" - "github.com/ElrondNetwork/elrond-go-core/hashing" - "github.com/ElrondNetwork/elrond-go-core/hashing/blake2b" - "github.com/ElrondNetwork/elrond-go-core/hashing/fnv" - "github.com/ElrondNetwork/elrond-go-core/hashing/keccak" - storageCore "github.com/ElrondNetwork/elrond-go-core/storage" - logger "github.com/ElrondNetwork/elrond-go-logger" - "github.com/ElrondNetwork/elrond-go/common" - "github.com/ElrondNetwork/elrond-go/storage" - "github.com/ElrondNetwork/elrond-go/storage/fifocache" - "github.com/ElrondNetwork/elrond-go/storage/leveldb" - "github.com/ElrondNetwork/elrond-go/storage/lrucache" - "github.com/ElrondNetwork/elrond-go/storage/memorydb" -) - -var _ storage.Storer = (*Unit)(nil) - -// CacheType represents the type of the supported caches -type CacheType string - -// DBType represents the type of the supported databases -type DBType string - -// HasherType represents the type of the supported hash functions -type HasherType string - -// LRUCache is currently the only supported Cache type -const ( - LRUCache CacheType = 
"LRU" - SizeLRUCache CacheType = "SizeLRU" - FIFOShardedCache CacheType = "FIFOSharded" -) - -var log = logger.GetOrCreate("storage/storageUnit") - -// LvlDB currently the only supported DBs -// More to be added -const ( - LvlDB DBType = "LvlDB" - LvlDBSerial DBType = "LvlDBSerial" - MemoryDB DBType = "MemoryDB" -) - -const ( - // Keccak is the string representation of the keccak hashing function - Keccak HasherType = "Keccak" - // Blake2b is the string representation of the blake2b hashing function - Blake2b HasherType = "Blake2b" - // Fnv is the string representation of the fnv hashing function - Fnv HasherType = "Fnv" -) - -const minimumSizeForLRUCache = 1024 - -// UnitConfig holds the configurable elements of the storage unit -type UnitConfig struct { - CacheConf CacheConfig - DBConf DBConfig -} - -// CacheConfig holds the configurable elements of a cache -type CacheConfig struct { - Name string - Type CacheType - SizeInBytes uint64 - SizeInBytesPerSender uint32 - Capacity uint32 - SizePerSender uint32 - Shards uint32 -} - -// String returns a readable representation of the object -func (config *CacheConfig) String() string { - bytes, err := json.Marshal(config) - if err != nil { - log.Error("CacheConfig.String()", "err", err) - } - - return string(bytes) -} - -// DBConfig holds the configurable elements of a database -type DBConfig struct { - FilePath string - Type DBType - BatchDelaySeconds int - MaxBatchSize int - MaxOpenFiles int -} - -// Unit represents a storer's data bank -// holding the cache and persistence unit -type Unit struct { - lock sync.RWMutex - persister storage.Persister - cacher storage.Cacher -} - -// Put adds data to both cache and persistence medium -func (u *Unit) Put(key, data []byte) error { - u.lock.Lock() - defer u.lock.Unlock() - - u.cacher.Put(key, data, len(data)) - - err := u.persister.Put(key, data) - if err != nil { - u.cacher.Remove(key) - return err - } - - return err -} - -// PutInEpoch will call the Put method as this storer doesn't handle epochs -func (u *Unit) PutInEpoch(key, data []byte, _ uint32) error { - return u.Put(key, data) -} - -// GetOldestEpoch will return an error that signals that the oldest epoch fetching is not available -func (u *Unit) GetOldestEpoch() (uint32, error) { - return 0, storage.ErrOldestEpochNotAvailable -} - -// Close will close unit -func (u *Unit) Close() error { - u.cacher.Clear() - - err := u.persister.Close() - if err != nil { - log.Error("cannot close storage unit persister", "error", err) - return err - } - - return nil -} - -// RangeKeys can iterate over the persisted (key, value) pairs calling the provided handler -func (u *Unit) RangeKeys(handler func(key []byte, value []byte) bool) { - u.persister.RangeKeys(handler) -} - -// Get searches the key in the cache. In case it is not found, -// it further searches it in the associated database. -// In case it is found in the database, the cache is updated with the value as well. 
-func (u *Unit) Get(key []byte) ([]byte, error) { - u.lock.Lock() - defer u.lock.Unlock() - - v, ok := u.cacher.Get(key) - var err error - - if !ok { - // not found in cache - // search it in second persistence medium - - v, err = u.persister.Get(key) - if err != nil { - return nil, err - } - - buff, okAssertion := v.([]byte) - if !okAssertion { - return nil, fmt.Errorf("key: %s is not a byte slice", base64.StdEncoding.EncodeToString(key)) - } - - // if found in persistence unit, add it in cache - u.cacher.Put(key, v, len(buff)) - } - - return v.([]byte), nil -} - -// GetFromEpoch will call the Get method as this storer doesn't handle epochs -func (u *Unit) GetFromEpoch(key []byte, _ uint32) ([]byte, error) { - return u.Get(key) -} - -// GetBulkFromEpoch will call the Get method for all keys as this storer doesn't handle epochs -func (u *Unit) GetBulkFromEpoch(keys [][]byte, _ uint32) ([]storageCore.KeyValuePair, error) { - results := make([]storageCore.KeyValuePair, 0, len(keys)) - for _, key := range keys { - value, err := u.Get(key) - if err != nil { - log.Warn("cannot get key from unit", - "key", key, - "error", err.Error(), - ) - continue - } - keyValue := storageCore.KeyValuePair{Key: key, Value: value} - results = append(results, keyValue) - } - return results, nil -} - -// Has checks if the key is in the Unit. -// It first checks the cache. If it is not found, it checks the db -func (u *Unit) Has(key []byte) error { - u.lock.RLock() - defer u.lock.RUnlock() - - has := u.cacher.Has(key) - if has { - return nil - } - - return u.persister.Has(key) -} - -// SearchFirst will call the Get method as this storer doesn't handle epochs -func (u *Unit) SearchFirst(key []byte) ([]byte, error) { - return u.Get(key) -} - -// RemoveFromCurrentEpoch removes the data associated to the given key from both cache and persistence medium -func (u *Unit) RemoveFromCurrentEpoch(key []byte) error { - return u.Remove(key) -} - -// Remove removes the data associated to the given key from both cache and persistence medium -func (u *Unit) Remove(key []byte) error { - u.lock.Lock() - defer u.lock.Unlock() - - u.cacher.Remove(key) - err := u.persister.Remove(key) - - return err -} - -// ClearCache cleans up the entire cache -func (u *Unit) ClearCache() { - u.cacher.Clear() -} - -// DestroyUnit cleans up the cache, and the db -func (u *Unit) DestroyUnit() error { - u.lock.Lock() - defer u.lock.Unlock() - - u.cacher.Clear() - return u.persister.Destroy() -} - -// IsInterfaceNil returns true if there is no value under the interface -func (u *Unit) IsInterfaceNil() bool { - return u == nil -} - -// NewStorageUnit is the constructor for the storage unit, creating a new storage unit -// from the given cacher and persister. -func NewStorageUnit(c storage.Cacher, p storage.Persister) (*Unit, error) { - if check.IfNil(p) { - return nil, storage.ErrNilPersister - } - if check.IfNil(c) { - return nil, storage.ErrNilCacher - } - - sUnit := &Unit{ - persister: p, - cacher: c, - } - - return sUnit, nil -} - -// NewStorageUnitFromConf creates a new storage unit from a storage unit config -func NewStorageUnitFromConf(cacheConf CacheConfig, dbConf DBConfig) (*Unit, error) { - var cache storage.Cacher - var db storage.Persister - var err error - - // TODO: if there will be a differentiation between the creation or opening of a DB, the DB could be destroyed - // in case of a failure while creating (not opening). 
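// Construction below proceeds in three guarded steps: the batch-size vs. cache-capacity
// sanity check, cache creation, then DB creation. Each step returns early on failure,
// so a half-initialized Unit is never handed back to the caller.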
- - if dbConf.MaxBatchSize > int(cacheConf.Capacity) { - return nil, storage.ErrCacheSizeIsLowerThanBatchSize - } - - cache, err = NewCache(cacheConf) - if err != nil { - return nil, err - } - - argDB := ArgDB{ - DBType: dbConf.Type, - Path: dbConf.FilePath, - BatchDelaySeconds: dbConf.BatchDelaySeconds, - MaxBatchSize: dbConf.MaxBatchSize, - MaxOpenFiles: dbConf.MaxOpenFiles, - } - db, err = NewDB(argDB) - if err != nil { - return nil, err - } - - return NewStorageUnit(cache, db) -} - -// NewCache creates a new cache from a cache config -func NewCache(config CacheConfig) (storage.Cacher, error) { - storage.MonitorNewCache(config.Name, config.SizeInBytes) - - cacheType := config.Type - capacity := config.Capacity - shards := config.Shards - sizeInBytes := config.SizeInBytes - - var cacher storage.Cacher - var err error - - switch cacheType { - case LRUCache: - if sizeInBytes != 0 { - return nil, storage.ErrLRUCacheWithProvidedSize - } - - cacher, err = lrucache.NewCache(int(capacity)) - case SizeLRUCache: - if sizeInBytes < minimumSizeForLRUCache { - return nil, fmt.Errorf("%w, provided %d, minimum %d", - storage.ErrLRUCacheInvalidSize, - sizeInBytes, - minimumSizeForLRUCache, - ) - } - - cacher, err = lrucache.NewCacheWithSizeInBytes(int(capacity), int64(sizeInBytes)) - case FIFOShardedCache: - cacher, err = fifocache.NewShardedCache(int(capacity), int(shards)) - if err != nil { - return nil, err - } - // add other implementations if required - default: - return nil, storage.ErrNotSupportedCacheType - } - - if err != nil { - return nil, err - } - - return cacher, nil -} - -// ArgDB is a structure that is used to create a new storage.Persister implementation -type ArgDB struct { - DBType DBType - Path string - BatchDelaySeconds int - MaxBatchSize int - MaxOpenFiles int -} - -// NewDB creates a new database from database config -func NewDB(argDB ArgDB) (storage.Persister, error) { - var db storage.Persister - var err error - - for i := 0; i < common.MaxRetriesToCreateDB; i++ { - switch argDB.DBType { - case LvlDB: - db, err = leveldb.NewDB(argDB.Path, argDB.BatchDelaySeconds, argDB.MaxBatchSize, argDB.MaxOpenFiles) - case LvlDBSerial: - db, err = leveldb.NewSerialDB(argDB.Path, argDB.BatchDelaySeconds, argDB.MaxBatchSize, argDB.MaxOpenFiles) - case MemoryDB: - db = memorydb.New() - default: - return nil, storage.ErrNotSupportedDBType - } - - if err == nil { - return db, nil - } - - // TODO: extract this in a parameter and inject it - time.Sleep(common.SleepTimeBetweenCreateDBRetries) - } - if err != nil { - return nil, err - } - - return db, nil -} - -// NewHasher will return a hasher implementation form the string HasherType -func (h HasherType) NewHasher() (hashing.Hasher, error) { - switch h { - case Keccak: - return keccak.NewKeccak(), nil - case Blake2b: - return blake2b.NewBlake2b(), nil - case Fnv: - return fnv.NewFnv(), nil - default: - return nil, storage.ErrNotSupportedHashType - } -} diff --git a/storage/storageUnit/storageunit_test.go b/storage/storageUnit/storageunit_test.go deleted file mode 100644 index 0b419376361..00000000000 --- a/storage/storageUnit/storageunit_test.go +++ /dev/null @@ -1,453 +0,0 @@ -package storageUnit_test - -import ( - "fmt" - "math/rand" - "strconv" - "testing" - - "github.com/ElrondNetwork/elrond-go/storage" - "github.com/ElrondNetwork/elrond-go/storage/lrucache" - "github.com/ElrondNetwork/elrond-go/storage/memorydb" - "github.com/ElrondNetwork/elrond-go/storage/storageUnit" - "github.com/stretchr/testify/assert" -) - -func logError(err error) { - 
if err != nil { - fmt.Println(err.Error()) - } -} - -func initStorageUnit(tb testing.TB, cSize int) *storageUnit.Unit { - mdb := memorydb.New() - cache, err2 := lrucache.NewCache(cSize) - assert.Nil(tb, err2, "no error expected but got %s", err2) - - sUnit, err := storageUnit.NewStorageUnit(cache, mdb) - assert.Nil(tb, err, "failed to create storage unit") - - return sUnit -} - -func TestStorageUnitNilPersister(t *testing.T) { - cache, err1 := lrucache.NewCache(10) - - assert.Nil(t, err1, "no error expected but got %s", err1) - - _, err := storageUnit.NewStorageUnit(cache, nil) - - assert.NotNil(t, err, "expected failure") -} - -func TestStorageUnitNilCacher(t *testing.T) { - mdb := memorydb.New() - - _, err1 := storageUnit.NewStorageUnit(nil, mdb) - assert.NotNil(t, err1, "expected failure") -} - -func TestStorageUnit(t *testing.T) { - cache, err1 := lrucache.NewCache(10) - mdb := memorydb.New() - - assert.Nil(t, err1, "no error expected but got %s", err1) - - _, err := storageUnit.NewStorageUnit(cache, mdb) - assert.Nil(t, err, "did not expect failure") -} - -func TestPutNotPresent(t *testing.T) { - key, val := []byte("key0"), []byte("value0") - s := initStorageUnit(t, 10) - err := s.Put(key, val) - - assert.Nil(t, err, "no error expected but got %s", err) - - err = s.Has(key) - - assert.Nil(t, err, "no error expected but got %s", err) -} - -func TestPutNotPresentCache(t *testing.T) { - key, val := []byte("key1"), []byte("value1") - s := initStorageUnit(t, 10) - err := s.Put(key, val) - - assert.Nil(t, err, "no error expected but got %s", err) - - s.ClearCache() - - err = s.Has(key) - - assert.Nil(t, err, "no error expected but got %s", err) -} - -func TestPutPresentShouldOverwriteValue(t *testing.T) { - key, val := []byte("key2"), []byte("value2") - s := initStorageUnit(t, 10) - err := s.Put(key, val) - - assert.Nil(t, err, "no error expected but got %s", err) - - newVal := []byte("value5") - err = s.Put(key, newVal) - assert.Nil(t, err, "no error expected but got %s", err) - - returnedVal, err := s.Get(key) - assert.Nil(t, err) - assert.Equal(t, newVal, returnedVal) -} - -func TestGetNotPresent(t *testing.T) { - key := []byte("key3") - s := initStorageUnit(t, 10) - v, err := s.Get(key) - - assert.NotNil(t, err, "expected to find no value, but found %s", v) -} - -func TestGetNotPresentCache(t *testing.T) { - key, val := []byte("key4"), []byte("value4") - s := initStorageUnit(t, 10) - err := s.Put(key, val) - - assert.Nil(t, err, "no error expected but got %s", err) - - s.ClearCache() - - v, err := s.Get(key) - - assert.Nil(t, err, "expected no error, but got %s", err) - assert.Equal(t, val, v, "expected %s but got %s", val, v) -} - -func TestGetPresent(t *testing.T) { - key, val := []byte("key5"), []byte("value4") - s := initStorageUnit(t, 10) - err := s.Put(key, val) - - assert.Nil(t, err, "no error expected but got %s", err) - - v, err := s.Get(key) - - assert.Nil(t, err, "expected no error, but got %s", err) - assert.Equal(t, val, v, "expected %s but got %s", val, v) -} - -func TestHasNotPresent(t *testing.T) { - key := []byte("key6") - s := initStorageUnit(t, 10) - err := s.Has(key) - - assert.NotNil(t, err) - assert.Equal(t, err, storage.ErrKeyNotFound) -} - -func TestHasNotPresentCache(t *testing.T) { - key, val := []byte("key7"), []byte("value7") - s := initStorageUnit(t, 10) - err := s.Put(key, val) - - assert.Nil(t, err, "no error expected but got %s", err) - - s.ClearCache() - - err = s.Has(key) - - assert.Nil(t, err, "expected no error, but got %s", err) -} - -func 
TestHasPresent(t *testing.T) { - key, val := []byte("key8"), []byte("value8") - s := initStorageUnit(t, 10) - err := s.Put(key, val) - - assert.Nil(t, err, "no error expected but got %s", err) - - err = s.Has(key) - - assert.Nil(t, err, "expected no error, but got %s", err) -} - -func TestDeleteNotPresent(t *testing.T) { - key := []byte("key12") - s := initStorageUnit(t, 10) - err := s.Remove(key) - - assert.Nil(t, err, "expected no error, but got %s", err) -} - -func TestDeleteNotPresentCache(t *testing.T) { - key, val := []byte("key13"), []byte("value13") - s := initStorageUnit(t, 10) - err := s.Put(key, val) - assert.Nil(t, err, "Could not put value in storage unit") - - err = s.Has(key) - - assert.Nil(t, err, "expected no error, but got %s", err) - - s.ClearCache() - - err = s.Remove(key) - assert.Nil(t, err, "expected no error, but got %s", err) - - err = s.Has(key) - - assert.NotNil(t, err) - assert.Contains(t, err.Error(), "key not found") -} - -func TestDeletePresent(t *testing.T) { - key, val := []byte("key14"), []byte("value14") - s := initStorageUnit(t, 10) - err := s.Put(key, val) - assert.Nil(t, err, "Could not put value in storage unit") - - err = s.Has(key) - - assert.Nil(t, err, "expected no error, but got %s", err) - - err = s.Remove(key) - - assert.Nil(t, err, "expected no error, but got %s", err) - - err = s.Has(key) - - assert.NotNil(t, err) - assert.Contains(t, err.Error(), "key not found") -} - -func TestClearCacheNotAffectPersist(t *testing.T) { - key, val := []byte("key15"), []byte("value15") - s := initStorageUnit(t, 10) - err := s.Put(key, val) - assert.Nil(t, err, "Could not put value in storage unit") - s.ClearCache() - - err = s.Has(key) - - assert.Nil(t, err, "no error expected, but got %s", err) -} - -func TestDestroyUnitNoError(t *testing.T) { - s := initStorageUnit(t, 10) - err := s.DestroyUnit() - assert.Nil(t, err, "no error expected, but got %s", err) -} - -func TestCreateCacheFromConfWrongType(t *testing.T) { - - cacher, err := storageUnit.NewCache(storageUnit.CacheConfig{Type: "NotLRU", Capacity: 100, Shards: 1, SizeInBytes: 0}) - - assert.NotNil(t, err, "error expected") - assert.Nil(t, cacher, "cacher expected to be nil, but got %s", cacher) -} - -func TestCreateCacheFromConfOK(t *testing.T) { - - cacher, err := storageUnit.NewCache(storageUnit.CacheConfig{Type: storageUnit.LRUCache, Capacity: 10, Shards: 1, SizeInBytes: 0}) - - assert.Nil(t, err, "no error expected but got %s", err) - assert.NotNil(t, cacher, "valid cacher expected but got nil") -} - -func TestCreateDBFromConfWrongType(t *testing.T) { - arg := storageUnit.ArgDB{ - DBType: "NotLvlDB", - Path: "test", - BatchDelaySeconds: 10, - MaxBatchSize: 10, - MaxOpenFiles: 10, - } - persister, err := storageUnit.NewDB(arg) - - assert.NotNil(t, err, "error expected") - assert.Nil(t, persister, "persister expected to be nil, but got %s", persister) -} - -func TestCreateDBFromConfWrongFileNameLvlDB(t *testing.T) { - if testing.Short() { - t.Skip("this is not a short test") - } - - arg := storageUnit.ArgDB{ - DBType: storageUnit.LvlDB, - Path: "", - BatchDelaySeconds: 10, - MaxBatchSize: 10, - MaxOpenFiles: 10, - } - persister, err := storageUnit.NewDB(arg) - assert.NotNil(t, err, "error expected") - assert.Nil(t, persister, "persister expected to be nil, but got %s", persister) -} - -func TestCreateDBFromConfLvlDBOk(t *testing.T) { - arg := storageUnit.ArgDB{ - DBType: storageUnit.LvlDB, - Path: t.TempDir(), - BatchDelaySeconds: 10, - MaxBatchSize: 10, - MaxOpenFiles: 10, - } - persister, err := 
storageUnit.NewDB(arg) - assert.Nil(t, err, "no error expected") - assert.NotNil(t, persister, "valid persister expected but got nil") - - err = persister.Destroy() - assert.Nil(t, err, "no error expected destroying the persister") -} - -func TestNewStorageUnit_FromConfWrongCacheSizeVsBatchSize(t *testing.T) { - - storer, err := storageUnit.NewStorageUnitFromConf(storageUnit.CacheConfig{ - Capacity: 10, - Type: storageUnit.LRUCache, - }, storageUnit.DBConfig{ - FilePath: "Blocks", - Type: storageUnit.LvlDB, - MaxBatchSize: 11, - BatchDelaySeconds: 1, - MaxOpenFiles: 10, - }) - - assert.NotNil(t, err, "error expected") - assert.Nil(t, storer, "storer expected to be nil but got %s", storer) -} - -func TestNewStorageUnit_FromConfWrongCacheConfig(t *testing.T) { - - storer, err := storageUnit.NewStorageUnitFromConf(storageUnit.CacheConfig{ - Capacity: 10, - Type: "NotLRU", - }, storageUnit.DBConfig{ - FilePath: "Blocks", - Type: storageUnit.LvlDB, - BatchDelaySeconds: 1, - MaxBatchSize: 1, - MaxOpenFiles: 10, - }) - - assert.NotNil(t, err, "error expected") - assert.Nil(t, storer, "storer expected to be nil but got %s", storer) -} - -func TestNewStorageUnit_FromConfWrongDBConfig(t *testing.T) { - storer, err := storageUnit.NewStorageUnitFromConf(storageUnit.CacheConfig{ - Capacity: 10, - Type: storageUnit.LRUCache, - }, storageUnit.DBConfig{ - FilePath: "Blocks", - Type: "NotLvlDB", - }) - - assert.NotNil(t, err, "error expected") - assert.Nil(t, storer, "storer expected to be nil but got %s", storer) -} - -func TestNewStorageUnit_FromConfLvlDBOk(t *testing.T) { - storer, err := storageUnit.NewStorageUnitFromConf(storageUnit.CacheConfig{ - Capacity: 10, - Type: storageUnit.LRUCache, - }, storageUnit.DBConfig{ - FilePath: "Blocks", - Type: storageUnit.LvlDB, - MaxBatchSize: 1, - BatchDelaySeconds: 1, - MaxOpenFiles: 10, - }) - - assert.Nil(t, err, "no error expected but got %s", err) - assert.NotNil(t, storer, "valid storer expected but got nil") - err = storer.DestroyUnit() - assert.Nil(t, err, "no error expected destroying the persister") -} - -func TestNewStorageUnit_ShouldWorkLvlDB(t *testing.T) { - storer, err := storageUnit.NewStorageUnitFromConf(storageUnit.CacheConfig{ - Capacity: 10, - Type: storageUnit.LRUCache, - }, storageUnit.DBConfig{ - FilePath: "Blocks", - Type: storageUnit.LvlDB, - BatchDelaySeconds: 1, - MaxBatchSize: 1, - MaxOpenFiles: 10, - }) - - assert.Nil(t, err, "no error expected but got %s", err) - assert.NotNil(t, storer, "valid storer expected but got nil") - err = storer.DestroyUnit() - assert.Nil(t, err, "no error expected destroying the persister") -} - -const ( - valuesInDb = 100000 -) - -func BenchmarkStorageUnit_Put(b *testing.B) { - b.StopTimer() - s := initStorageUnit(b, 1) - defer func() { - err := s.DestroyUnit() - logError(err) - }() - b.StartTimer() - - for i := 0; i < b.N; i++ { - b.StopTimer() - nr := rand.Intn(valuesInDb) - b.StartTimer() - - err := s.Put([]byte(strconv.Itoa(nr)), []byte(strconv.Itoa(nr))) - logError(err) - } -} - -func BenchmarkStorageUnit_GetWithDataBeingPresent(b *testing.B) { - b.StopTimer() - s := initStorageUnit(b, 1) - defer func() { - err := s.DestroyUnit() - logError(err) - }() - for i := 0; i < valuesInDb; i++ { - err := s.Put([]byte(strconv.Itoa(i)), []byte(strconv.Itoa(i))) - logError(err) - } - b.StartTimer() - - for i := 0; i < b.N; i++ { - b.StopTimer() - nr := rand.Intn(valuesInDb) - b.StartTimer() - - _, err := s.Get([]byte(strconv.Itoa(nr))) - logError(err) - } -} - -func 
BenchmarkStorageUnit_GetWithDataNotBeingPresent(b *testing.B) { - b.StopTimer() - s := initStorageUnit(b, 1) - defer func() { - err := s.DestroyUnit() - logError(err) - }() - for i := 0; i < valuesInDb; i++ { - err := s.Put([]byte(strconv.Itoa(i)), []byte(strconv.Itoa(i))) - logError(err) - } - b.StartTimer() - - for i := 0; i < b.N; i++ { - b.StopTimer() - nr := rand.Intn(valuesInDb) + valuesInDb - b.StartTimer() - - _, err := s.Get([]byte(strconv.Itoa(nr))) - logError(err) - } -} diff --git a/storage/storageunit/constants.go b/storage/storageunit/constants.go new file mode 100644 index 00000000000..ca17c374fde --- /dev/null +++ b/storage/storageunit/constants.go @@ -0,0 +1,21 @@ +package storageunit + +import "github.com/ElrondNetwork/elrond-go-storage/storageUnit" + +const ( + // LRUCache defines a cache identifier with least-recently-used eviction mechanism + LRUCache = storageUnit.LRUCache + // SizeLRUCache defines a cache identifier with least-recently-used eviction mechanism and fixed size in bytes + SizeLRUCache = storageUnit.SizeLRUCache +) + +// DB types that are currently supported +// More to be added +const ( + // LvlDB represents a levelDB storage identifier + LvlDB = storageUnit.LvlDB + // LvlDBSerial represents a levelDB storage with serialized operations identifier + LvlDBSerial = storageUnit.LvlDBSerial + // MemoryDB represents an in memory storage identifier + MemoryDB = storageUnit.MemoryDB +) diff --git a/storage/storageunit/storageunit.go b/storage/storageunit/storageunit.go new file mode 100644 index 00000000000..6a2404a2e02 --- /dev/null +++ b/storage/storageunit/storageunit.go @@ -0,0 +1,66 @@ +package storageunit + +import ( + "github.com/ElrondNetwork/elrond-go-core/marshal" + "github.com/ElrondNetwork/elrond-go-storage/storageCacherAdapter" + "github.com/ElrondNetwork/elrond-go-storage/storageUnit" + "github.com/ElrondNetwork/elrond-go/storage" +) + +// Unit represents a storer's data bank +// holding the cache and persistence unit +type Unit = storageUnit.Unit + +// CacheConfig holds the configurable elements of a cache +type CacheConfig = storageUnit.CacheConfig + +// ArgDB is a structure that is used to create a new storage.Persister implementation +type ArgDB = storageUnit.ArgDB + +// DBConfig holds the configurable elements of a database +type DBConfig = storageUnit.DBConfig + +// NilStorer is a no-op implementation of the Storer interface +type NilStorer = storageUnit.NilStorer + +// CacheType represents the type of the supported caches +type CacheType = storageUnit.CacheType + +// DBType represents the type of the supported databases +type DBType = storageUnit.DBType + +// NewStorageUnit is the constructor for the storage unit, creating a new storage unit +// from the given cacher and persister. 
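// A short usage sketch of this thin wrapper package: it only re-exports types and
// constructors that moved to elrond-go-storage, so existing call sites keep compiling
// unchanged. The config values below are illustrative.
package main

import (
	"fmt"

	"github.com/ElrondNetwork/elrond-go/storage/storageunit"
)

func main() {
	cache, err := storageunit.NewCache(storageunit.CacheConfig{
		Type:     storageunit.LRUCache,
		Capacity: 1000,
	})
	if err != nil {
		panic(err)
	}

	// MemoryDB needs no file path or batching configuration
	db, err := storageunit.NewDB(storageunit.ArgDB{DBType: storageunit.MemoryDB})
	if err != nil {
		panic(err)
	}

	unit, err := storageunit.NewStorageUnit(cache, db)
	if err != nil {
		panic(err)
	}
	defer func() { _ = unit.Close() }()

	_ = unit.Put([]byte("key"), []byte("value"))
	v, _ := unit.Get([]byte("key"))
	fmt.Println(string(v)) // value
}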
+func NewStorageUnit(c storage.Cacher, p storage.Persister) (*Unit, error) { + return storageUnit.NewStorageUnit(c, p) +} + +// NewCache creates a new cache from a cache config +func NewCache(config CacheConfig) (storage.Cacher, error) { + return storageUnit.NewCache(config) +} + +// NewDB creates a new database from database config +func NewDB(argDB ArgDB) (storage.Persister, error) { + return storageUnit.NewDB(argDB) +} + +// NewStorageUnitFromConf creates a new storage unit from a storage unit config +func NewStorageUnitFromConf(cacheConf CacheConfig, dbConf DBConfig) (*Unit, error) { + return storageUnit.NewStorageUnitFromConf(cacheConf, dbConf) +} + +// NewNilStorer will return a nil storer +func NewNilStorer() *NilStorer { + return storageUnit.NewNilStorer() +} + +// NewStorageCacherAdapter creates a new storageCacherAdapter +func NewStorageCacherAdapter( + cacher storage.AdaptedSizedLRUCache, + db storage.Persister, + storedDataFactory storage.StoredDataFactory, + marshaller marshal.Marshalizer, +) (storage.Cacher, error) { + return storageCacherAdapter.NewStorageCacherAdapter(cacher, db, storedDataFactory, marshaller) +} diff --git a/storage/timecache/export_test.go b/storage/timecache/export_test.go deleted file mode 100644 index e1e7b9be320..00000000000 --- a/storage/timecache/export_test.go +++ /dev/null @@ -1,32 +0,0 @@ -package timecache - -// Keys - -func (tc *TimeCache) Keys() []string { - tc.timeCache.Lock() - defer tc.timeCache.Unlock() - - keys := make([]string, 0, len(tc.timeCache.data)) - for key := range tc.timeCache.data { - keys = append(keys, key) - } - - return keys -} - -// Value - -func (tc *TimeCache) Value(key string) (*entry, bool) { - tc.timeCache.Lock() - defer tc.timeCache.Unlock() - - val, ok := tc.timeCache.data[key] - - return val, ok -} - -// NumRegisteredHandlers - -func (tc *timeCacher) NumRegisteredHandlers() int { - tc.mutAddedDataHandlers.RLock() - defer tc.mutAddedDataHandlers.RUnlock() - - return len(tc.mapDataHandlers) -} diff --git a/storage/timecache/peerTimeCache.go b/storage/timecache/peerTimeCache.go deleted file mode 100644 index 5f498274f55..00000000000 --- a/storage/timecache/peerTimeCache.go +++ /dev/null @@ -1,46 +0,0 @@ -package timecache - -import ( - "time" - - "github.com/ElrondNetwork/elrond-go-core/core" - "github.com/ElrondNetwork/elrond-go-core/core/check" - "github.com/ElrondNetwork/elrond-go/storage" -) - -type peerTimeCache struct { - timeCache storage.TimeCacher -} - -// NewPeerTimeCache creates a new peer time cache data structure instance -func NewPeerTimeCache(timeCache storage.TimeCacher) (*peerTimeCache, error) { - if check.IfNil(timeCache) { - return nil, storage.ErrNilTimeCache - } - - return &peerTimeCache{ - timeCache: timeCache, - }, nil -} - -// Upsert will add the pid and provided duration if not exists -// If the record exists, will update the duration if the provided duration is larger than existing -// Also, it will reset the contained timestamp to time.Now -func (ptc *peerTimeCache) Upsert(pid core.PeerID, duration time.Duration) error { - return ptc.timeCache.Upsert(string(pid), duration) -} - -// Sweep will call the inner time cache method -func (ptc *peerTimeCache) Sweep() { - ptc.timeCache.Sweep() -} - -// Has will call the inner time cache method with the provided pid as string -func (ptc *peerTimeCache) Has(pid core.PeerID) bool { - return ptc.timeCache.Has(string(pid)) -} - -// IsInterfaceNil returns true if there is no value under the interface -func (ptc *peerTimeCache) IsInterfaceNil() bool { - 
return ptc == nil -} diff --git a/storage/timecache/peerTimeCache_test.go b/storage/timecache/peerTimeCache_test.go deleted file mode 100644 index be021bbae94..00000000000 --- a/storage/timecache/peerTimeCache_test.go +++ /dev/null @@ -1,70 +0,0 @@ -package timecache - -import ( - "errors" - "testing" - "time" - - "github.com/ElrondNetwork/elrond-go-core/core" - "github.com/ElrondNetwork/elrond-go-core/core/check" - "github.com/ElrondNetwork/elrond-go/storage" - "github.com/ElrondNetwork/elrond-go/testscommon" - "github.com/stretchr/testify/assert" -) - -func TestNewPeerTimeCache_NilTimeCacheShouldErr(t *testing.T) { - t.Parallel() - - ptc, err := NewPeerTimeCache(nil) - - assert.Equal(t, storage.ErrNilTimeCache, err) - assert.True(t, check.IfNil(ptc)) -} - -func TestNewPeerTimeCache_ShouldWork(t *testing.T) { - t.Parallel() - - ptc, err := NewPeerTimeCache(&testscommon.TimeCacheStub{}) - - assert.Nil(t, err) - assert.False(t, check.IfNil(ptc)) -} - -func TestPeerTimeCache_Methods(t *testing.T) { - t.Parallel() - - pid := core.PeerID("test peer id") - unexpectedErr := errors.New("unexpected error") - updateWasCalled := false - hasWasCalled := false - sweepWasCalled := false - ptc, _ := NewPeerTimeCache(&testscommon.TimeCacheStub{ - UpsertCalled: func(key string, span time.Duration) error { - if key != string(pid) { - return unexpectedErr - } - - updateWasCalled = true - return nil - }, - HasCalled: func(key string) bool { - if key != string(pid) { - return false - } - - hasWasCalled = true - return true - }, - SweepCalled: func() { - sweepWasCalled = true - }, - }) - - assert.Nil(t, ptc.Upsert(pid, time.Second)) - assert.True(t, ptc.Has(pid)) - ptc.Sweep() - - assert.True(t, updateWasCalled) - assert.True(t, hasWasCalled) - assert.True(t, sweepWasCalled) -} diff --git a/storage/timecache/timeCache.go b/storage/timecache/timeCache.go deleted file mode 100644 index bc30f85b200..00000000000 --- a/storage/timecache/timeCache.go +++ /dev/null @@ -1,81 +0,0 @@ -package timecache - -import ( - "time" - - "github.com/ElrondNetwork/elrond-go/dataRetriever" - "github.com/ElrondNetwork/elrond-go/storage" -) - -var _ dataRetriever.RequestedItemsHandler = (*TimeCache)(nil) - -// TimeCache can retain an amount of string keys for a defined period of time -// sweeping (clean-up) is triggered each time a new item is added or a key is present in the time cache -// This data structure is concurrent safe. -type TimeCache struct { - timeCache *timeCacheCore -} - -// NewTimeCache creates a new time cache data structure instance -func NewTimeCache(defaultSpan time.Duration) *TimeCache { - return &TimeCache{ - timeCache: newTimeCacheCore(defaultSpan), - } -} - -// Add will store the key in the time cache -// Double adding the key is permitted. It will replace the data, if existing. It does not trigger sweep. -func (tc *TimeCache) Add(key string) error { - return tc.add(key, tc.timeCache.defaultSpan) -} - -func (tc *TimeCache) add(key string, duration time.Duration) error { - if len(key) == 0 { - return storage.ErrEmptyKey - } - - tc.timeCache.Lock() - defer tc.timeCache.Unlock() - - tc.timeCache.data[key] = &entry{ - timestamp: time.Now(), - span: duration, - } - return nil -} - -// AddWithSpan will store the key in the time cache with the provided span duration -// Double adding the key is permitted. It will replace the data, if existing. It does not trigger sweep. 
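// A short usage sketch of the span semantics documented on AddWithSpan above and on
// Upsert below, written against this package's API as it existed before the move to
// elrond-go-storage; keys and durations are illustrative.
package main

import (
	"fmt"
	"time"

	"github.com/ElrondNetwork/elrond-go/storage/timecache"
)

func main() {
	tc := timecache.NewTimeCache(time.Minute) // default span used by plain Add

	_ = tc.Add("peer-a")                         // retained for the default minute
	_ = tc.AddWithSpan("peer-b", 10*time.Minute) // explicit, longer span
	_ = tc.Upsert("peer-b", 5*time.Minute)       // span stays 10m (5m is smaller); timestamp resets
	_ = tc.Upsert("peer-b", 20*time.Minute)      // span grows to 20m; timestamp resets again

	tc.Sweep()                              // evicts only entries whose span has already elapsed
	fmt.Println(tc.Has("peer-a"), tc.Len()) // true 2
}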
-func (tc *TimeCache) AddWithSpan(key string, duration time.Duration) error { - return tc.add(key, duration) -} - -// Upsert will add the key and provided duration if not exists -// If the record exists, will update the duration if the provided duration is larger than existing -// Also, it will reset the contained timestamp to time.Now -func (tc *TimeCache) Upsert(key string, duration time.Duration) error { - _, err := tc.timeCache.upsert(key, nil, duration) - - return err -} - -// Sweep starts from the oldest element and will search each element if it is still valid to be kept. Sweep ends when -// it finds an element that is still valid -func (tc *TimeCache) Sweep() { - tc.timeCache.sweep() -} - -// Has returns if the key is still found in the time cache -func (tc *TimeCache) Has(key string) bool { - return tc.timeCache.has(key) -} - -// Len returns the number of elements which are still stored in the time cache -func (tc *TimeCache) Len() int { - return tc.timeCache.len() -} - -// IsInterfaceNil returns true if there is no value under the interface -func (tc *TimeCache) IsInterfaceNil() bool { - return tc == nil -} diff --git a/storage/timecache/timeCacheCore.go b/storage/timecache/timeCacheCore.go deleted file mode 100644 index 1cdd13a3ba9..00000000000 --- a/storage/timecache/timeCacheCore.go +++ /dev/null @@ -1,139 +0,0 @@ -package timecache - -import ( - "sync" - "time" - - "github.com/ElrondNetwork/elrond-go/storage" -) - -type entry struct { - timestamp time.Time - span time.Duration - value interface{} -} - -type timeCacheCore struct { - *sync.RWMutex - data map[string]*entry - defaultSpan time.Duration -} - -func newTimeCacheCore(defaultSpan time.Duration) *timeCacheCore { - return &timeCacheCore{ - RWMutex: &sync.RWMutex{}, - data: make(map[string]*entry), - defaultSpan: defaultSpan, - } -} - -// upsert will add the key, value and provided duration if not exists -// If the record exists, will update the duration if the provided duration is larger than existing -// Also, it will reset the contained timestamp to time.Now -// It returns if the value existed before this call. It also operates on the locker so the call is concurrent safe -func (tcc *timeCacheCore) upsert(key string, value interface{}, duration time.Duration) (bool, error) { - if len(key) == 0 { - return false, storage.ErrEmptyKey - } - - tcc.Lock() - defer tcc.Unlock() - - existing, found := tcc.data[key] - if found { - if existing.span < duration { - existing.span = duration - } - existing.timestamp = time.Now() - - return found, nil - } - - tcc.data[key] = &entry{ - timestamp: time.Now(), - span: duration, - value: value, - } - return found, nil -} - -// put will add the key, value and provided duration, overriding values if the data already existed -// It also operates on the locker so the call is concurrent safe -func (tcc *timeCacheCore) put(key string, value interface{}, duration time.Duration) error { - if len(key) == 0 { - return storage.ErrEmptyKey - } - - tcc.Lock() - defer tcc.Unlock() - - tcc.data[key] = &entry{ - timestamp: time.Now(), - span: duration, - value: value, - } - return nil -} - -// hasOrAdd will add the key, value and provided duration, if the key is not found -// It returns true if the value existed before this call and if it has been added or not. 
It also operates on the locker so the call is concurrent safe -func (tcc *timeCacheCore) hasOrAdd(key string, value interface{}, duration time.Duration) (bool, bool, error) { - if len(key) == 0 { - return false, false, storage.ErrEmptyKey - } - - tcc.Lock() - defer tcc.Unlock() - - _, found := tcc.data[key] - if found { - return true, false, nil - } - - tcc.data[key] = &entry{ - timestamp: time.Now(), - span: duration, - value: value, - } - return false, true, nil -} - -// sweep iterates over all contained elements checking if the element is still valid to be kept -// It also operates on the locker so the call is concurrent safe -func (tcc *timeCacheCore) sweep() { - tcc.Lock() - defer tcc.Unlock() - - for key, element := range tcc.data { - isOldElement := time.Since(element.timestamp) > element.span - if isOldElement { - delete(tcc.data, key) - } - } -} - -// has returns if the key is still found in the time cache -func (tcc *timeCacheCore) has(key string) bool { - tcc.RLock() - defer tcc.RUnlock() - - _, ok := tcc.data[key] - - return ok -} - -// len returns the number of elements which are still stored in the time cache -func (tcc *timeCacheCore) len() int { - tcc.RLock() - defer tcc.RUnlock() - - return len(tcc.data) -} - -// clear recreates the map, thus deleting any existing entries -// It also operates on the locker so the call is concurrent safe -func (tcc *timeCacheCore) clear() { - tcc.Lock() - tcc.data = make(map[string]*entry) - tcc.Unlock() -} diff --git a/storage/timecache/timeCacheCore_test.go b/storage/timecache/timeCacheCore_test.go deleted file mode 100644 index 149fb6b7d06..00000000000 --- a/storage/timecache/timeCacheCore_test.go +++ /dev/null @@ -1,50 +0,0 @@ -package timecache - -import ( - "fmt" - "sync" - "testing" - "time" - - "github.com/stretchr/testify/assert" -) - -func TestTimeCacheCore_ConcurrentOperations(t *testing.T) { - t.Parallel() - - tcc := newTimeCacheCore(time.Second) - numOperations := 1000 - wg := &sync.WaitGroup{} - wg.Add(numOperations) - for i := 0; i < numOperations; i++ { - go func(idx int) { - time.Sleep(time.Millisecond * 10) - - switch idx % 7 { - case 0: - _, err := tcc.upsert(fmt.Sprintf("key%d", idx), fmt.Sprintf("valuey%d", idx), time.Second) - assert.Nil(t, err) - case 1: - tcc.sweep() - case 2: - _ = tcc.has(fmt.Sprintf("key%d", idx)) - case 3: - _ = tcc.len() - case 4: - tcc.clear() - case 5: - err := tcc.put(fmt.Sprintf("key%d", idx), fmt.Sprintf("valuey%d", idx), time.Second) - assert.Nil(t, err) - case 6: - _, _, err := tcc.hasOrAdd(fmt.Sprintf("key%d", idx), fmt.Sprintf("valuey%d", idx), time.Second) - assert.Nil(t, err) - default: - assert.Fail(t, "test setup error, change the line 'switch idx % xxx {' from this test") - } - - wg.Done() - }(i) - } - - wg.Wait() -} diff --git a/storage/timecache/timeCache_test.go b/storage/timecache/timeCache_test.go deleted file mode 100644 index 9d3a68e4956..00000000000 --- a/storage/timecache/timeCache_test.go +++ /dev/null @@ -1,247 +0,0 @@ -package timecache - -import ( - "fmt" - "testing" - "time" - - "github.com/ElrondNetwork/elrond-go-core/core/check" - "github.com/ElrondNetwork/elrond-go/storage" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -// ------- Add - -func TestTimeCache_EmptyKeyShouldErr(t *testing.T) { - t.Parallel() - - tc := NewTimeCache(time.Second) - key := "" - - err := tc.Add(key) - - _, ok := tc.Value(key) - assert.Equal(t, storage.ErrEmptyKey, err) - assert.False(t, ok) -} - -func TestTimeCache_AddShouldWork(t *testing.T) { - 
t.Parallel() - - tc := NewTimeCache(time.Second) - key := "key1" - - err := tc.Add(key) - - keys := tc.Keys() - _, ok := tc.Value(key) - assert.Nil(t, err) - assert.Equal(t, key, keys[0]) - assert.True(t, ok) -} - -func TestTimeCache_DoubleAddShouldWork(t *testing.T) { - t.Parallel() - - tc := NewTimeCache(time.Second) - key := "key1" - - _ = tc.AddWithSpan(key, time.Second) - newSpan := time.Second * 4 - err := tc.AddWithSpan(key, newSpan) - assert.Nil(t, err) - - keys := tc.Keys() - s, ok := tc.Value(key) - assert.Equal(t, key, keys[0]) - assert.True(t, ok) - assert.Equal(t, newSpan, s.span) -} - -func TestTimeCache_DoubleAddAfterExpirationAndSweepShouldWork(t *testing.T) { - t.Parallel() - - tc := NewTimeCache(time.Millisecond) - key := "key1" - - _ = tc.Add(key) - time.Sleep(time.Second) - tc.Sweep() - err := tc.Add(key) - - keys := tc.Keys() - _, ok := tc.Value(key) - assert.Nil(t, err) - assert.Equal(t, key, keys[0]) - assert.True(t, ok) -} - -func TestTimeCache_AddWithSpanShouldWork(t *testing.T) { - t.Parallel() - - tc := NewTimeCache(time.Second) - key := "key1" - - duration := time.Second * 1638 - err := tc.AddWithSpan(key, duration) - - keys := tc.Keys() - _, ok := tc.Value(key) - assert.Nil(t, err) - assert.Equal(t, key, keys[0]) - assert.True(t, ok) - - spanRecovered, _ := tc.Value(key) - assert.Equal(t, duration, spanRecovered.span) -} - -// ------- Has - -func TestTimeCache_HasNotExistingShouldRetFalse(t *testing.T) { - t.Parallel() - - tc := NewTimeCache(time.Second) - key := "key1" - - exists := tc.Has(key) - - assert.False(t, exists) -} - -func TestTimeCache_HasExistsShouldRetTrue(t *testing.T) { - t.Parallel() - - tc := NewTimeCache(time.Second) - key := "key1" - _ = tc.Add(key) - - exists := tc.Has(key) - - assert.True(t, exists) -} - -func TestTimeCache_HasCheckEvictionIsDoneProperly(t *testing.T) { - t.Parallel() - - tc := NewTimeCache(time.Millisecond) - key1 := "key1" - key2 := "key2" - _ = tc.Add(key1) - _ = tc.Add(key2) - time.Sleep(time.Second) - tc.Sweep() - - exists1 := tc.Has(key1) - exists2 := tc.Has(key2) - - assert.False(t, exists1) - assert.False(t, exists2) - assert.Equal(t, 0, len(tc.Keys())) -} - -func TestTimeCache_HasCheckHandlingInconsistency(t *testing.T) { - t.Parallel() - - tc := NewTimeCache(time.Second) - key := "key1" - _ = tc.Add(key) - tc.timeCache.clear() - tc.Sweep() - - exists := tc.Has(key) - - assert.False(t, exists) - assert.Equal(t, 0, len(tc.Keys())) -} - -// ------- Upsert - -func TestTimeCache_UpsertEmptyKeyShouldErr(t *testing.T) { - t.Parallel() - - tc := NewTimeCache(time.Second) - err := tc.Upsert("", time.Second) - - assert.Equal(t, storage.ErrEmptyKey, err) -} - -func TestTimeCache_UpsertShouldAddIfMissing(t *testing.T) { - t.Parallel() - - tc := NewTimeCache(time.Second) - key := "key" - s := time.Second * 45 - err := tc.Upsert(key, s) - assert.Nil(t, err) - - recovered, ok := tc.Value(key) - require.True(t, ok) - assert.Equal(t, s, recovered.span) -} - -func TestTimeCache_UpsertLessSpanShouldNotUpdate(t *testing.T) { - t.Parallel() - - tc := NewTimeCache(time.Second) - key := "key" - highSpan := time.Second * 45 - lowSpan := time.Second * 44 - err := tc.Upsert(key, highSpan) - assert.Nil(t, err) - - err = tc.Upsert(key, lowSpan) - assert.Nil(t, err) - - recovered, ok := tc.Value(key) - require.True(t, ok) - assert.Equal(t, highSpan, recovered.span) -} - -func TestTimeCache_UpsertmoreSpanShouldUpdate(t *testing.T) { - t.Parallel() - - tc := NewTimeCache(time.Second) - key := "key" - highSpan := time.Second * 45 - lowSpan := 
time.Second * 44 - err := tc.Upsert(key, lowSpan) - assert.Nil(t, err) - - err = tc.Upsert(key, highSpan) - assert.Nil(t, err) - - recovered, ok := tc.Value(key) - require.True(t, ok) - assert.Equal(t, highSpan, recovered.span) -} - -func TestTimeCache_Len(t *testing.T) { - t.Parallel() - - tc := NewTimeCache(time.Second) - assert.Equal(t, 0, tc.Len()) - numTests := 10 - for i := 0; i < numTests; i++ { - _ = tc.Add(fmt.Sprintf("%d", i)) - assert.Equal(t, i+1, tc.Len()) - } -} - -// ------- IsInterfaceNil - -func TestTimeCache_IsInterfaceNilNotNil(t *testing.T) { - t.Parallel() - - tc := NewTimeCache(time.Second) - - assert.False(t, check.IfNil(tc)) -} - -func TestTimeCache_IsInterfaceNil(t *testing.T) { - t.Parallel() - - var tc *TimeCache - - assert.True(t, check.IfNil(tc)) -} diff --git a/storage/timecache/timeCacher.go b/storage/timecache/timeCacher.go deleted file mode 100644 index aeb620122f4..00000000000 --- a/storage/timecache/timeCacher.go +++ /dev/null @@ -1,221 +0,0 @@ -package timecache - -import ( - "context" - "math" - "sync" - "time" - - logger "github.com/ElrondNetwork/elrond-go-logger" - "github.com/ElrondNetwork/elrond-go/storage" -) - -var log = logger.GetOrCreate("storage/maptimecache") - -const minDuration = time.Second - -// ArgTimeCacher is the argument used to create a new timeCacher instance -type ArgTimeCacher struct { - DefaultSpan time.Duration - CacheExpiry time.Duration -} - -// timeCacher implements a time cacher with automatic sweeping mechanism -type timeCacher struct { - timeCache *timeCacheCore - cacheExpiry time.Duration - cancelFunc func() - - mutAddedDataHandlers sync.RWMutex - mapDataHandlers map[string]func(key []byte, value interface{}) -} - -// NewTimeCacher creates a new timeCacher -func NewTimeCacher(arg ArgTimeCacher) (*timeCacher, error) { - err := checkArg(arg) - if err != nil { - return nil, err - } - - tc := &timeCacher{ - timeCache: newTimeCacheCore(arg.DefaultSpan), - cacheExpiry: arg.CacheExpiry, - mapDataHandlers: make(map[string]func(key []byte, value interface{})), - } - - var ctx context.Context - ctx, tc.cancelFunc = context.WithCancel(context.Background()) - go tc.startSweeping(ctx) - - return tc, nil -} - -func checkArg(arg ArgTimeCacher) error { - if arg.DefaultSpan < minDuration { - return storage.ErrInvalidDefaultSpan - } - if arg.CacheExpiry < minDuration { - return storage.ErrInvalidCacheExpiry - } - - return nil -} - -// startSweeping handles sweeping the time cache -func (tc *timeCacher) startSweeping(ctx context.Context) { - timer := time.NewTimer(tc.cacheExpiry) - defer timer.Stop() - - for { - timer.Reset(tc.cacheExpiry) - - select { - case <-timer.C: - tc.timeCache.sweep() - case <-ctx.Done(): - log.Info("closing mapTimeCacher's sweep go routine...") - return - } - } -} - -// Clear deletes all stored data -func (tc *timeCacher) Clear() { - tc.timeCache.clear() -} - -// Put adds a value to the cache. 
It will always return false since the eviction did not occur -func (tc *timeCacher) Put(key []byte, value interface{}, _ int) (evicted bool) { - err := tc.timeCache.put(string(key), value, tc.timeCache.defaultSpan) - if err != nil { - log.Error("mapTimeCacher.Put", "key", key, "error", err) - return - } - - tc.callAddedDataHandlers(key, value) - - return false -} - -// Get returns a key's value from the cache -func (tc *timeCacher) Get(key []byte) (interface{}, bool) { - tc.timeCache.RLock() - defer tc.timeCache.RUnlock() - - v, ok := tc.timeCache.data[string(key)] - if !ok { - return nil, ok - } - - return v.value, ok -} - -// Has checks if a key is in the cache -func (tc *timeCacher) Has(key []byte) bool { - return tc.timeCache.has(string(key)) -} - -// Peek returns a key's value from the cache -func (tc *timeCacher) Peek(key []byte) (value interface{}, ok bool) { - return tc.Get(key) -} - -// HasOrAdd checks if a key is in the cache. -// If key exists, does not update the value. Otherwise, adds the key-value in the cache -func (tc *timeCacher) HasOrAdd(key []byte, value interface{}, _ int) (has, added bool) { - var err error - has, added, err = tc.timeCache.hasOrAdd(string(key), value, tc.timeCache.defaultSpan) - if err != nil { - log.Error("mapTimeCacher.HasOrAdd", "key", key, "error", err) - return - } - - if !has { - tc.callAddedDataHandlers(key, value) - } - - return -} - -// Remove removes the key from cache -func (tc *timeCacher) Remove(key []byte) { - if key == nil { - return - } - - tc.timeCache.Lock() - defer tc.timeCache.Unlock() - - delete(tc.timeCache.data, string(key)) -} - -// Keys returns all keys from cache -func (tc *timeCacher) Keys() [][]byte { - tc.timeCache.RLock() - defer tc.timeCache.RUnlock() - - keys := make([][]byte, len(tc.timeCache.data)) - idx := 0 - for k := range tc.timeCache.data { - keys[idx] = []byte(k) - idx++ - } - - return keys -} - -// Len returns the size of the cache -func (tc *timeCacher) Len() int { - return tc.timeCache.len() -} - -// SizeInBytesContained will always return 0 -func (tc *timeCacher) SizeInBytesContained() uint64 { - return 0 -} - -// MaxSize returns the maximum number of items which can be stored in cache. 
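// A hedged usage sketch for the sweeping cacher deleted in this file, against the
// pre-move API: both DefaultSpan and CacheExpiry must be at least one second, and
// Close stops the background sweep goroutine. Values are illustrative.
package main

import (
	"fmt"
	"time"

	"github.com/ElrondNetwork/elrond-go/storage/timecache"
)

func main() {
	cacher, err := timecache.NewTimeCacher(timecache.ArgTimeCacher{
		DefaultSpan: time.Minute,     // lifetime of each entry
		CacheExpiry: 2 * time.Minute, // period of the automatic background sweep
	})
	if err != nil {
		panic(err)
	}
	defer func() { _ = cacher.Close() }()

	_ = cacher.Put([]byte("k"), []byte("v"), 1) // always reports evicted == false
	v, ok := cacher.Get([]byte("k"))
	fmt.Println(ok, string(v.([]byte))) // true v - gone after the first sweep past DefaultSpan
}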
-func (tc *timeCacher) MaxSize() int { - return math.MaxInt32 -} - -// RegisterHandler registers a new handler to be called when a new data is added -func (tc *timeCacher) RegisterHandler(handler func(key []byte, value interface{}), id string) { - if handler == nil { - log.Error("attempt to register a nil handler to a cacher object", "id", id) - return - } - - tc.mutAddedDataHandlers.Lock() - tc.mapDataHandlers[id] = handler - tc.mutAddedDataHandlers.Unlock() -} - -// UnRegisterHandler removes the handler from the list -func (tc *timeCacher) UnRegisterHandler(id string) { - tc.mutAddedDataHandlers.Lock() - delete(tc.mapDataHandlers, id) - tc.mutAddedDataHandlers.Unlock() -} - -func (tc *timeCacher) callAddedDataHandlers(key []byte, value interface{}) { - tc.mutAddedDataHandlers.RLock() - for _, handler := range tc.mapDataHandlers { - go handler(key, value) - } - tc.mutAddedDataHandlers.RUnlock() -} - -// Close will close the internal sweep go routine -func (tc *timeCacher) Close() error { - if tc.cancelFunc != nil { - tc.cancelFunc() - } - - return nil -} - -// IsInterfaceNil returns true if there is no value under the interface -func (tc *timeCacher) IsInterfaceNil() bool { - return tc == nil -} diff --git a/storage/timecache/timeCacher_test.go b/storage/timecache/timeCacher_test.go deleted file mode 100644 index c7fc3c91644..00000000000 --- a/storage/timecache/timeCacher_test.go +++ /dev/null @@ -1,423 +0,0 @@ -package timecache_test - -import ( - "bytes" - "fmt" - "math" - "sort" - "sync" - "sync/atomic" - "testing" - "time" - - "github.com/ElrondNetwork/elrond-go/storage" - "github.com/ElrondNetwork/elrond-go/storage/timecache" - "github.com/stretchr/testify/assert" -) - -func createArgTimeCacher() timecache.ArgTimeCacher { - return timecache.ArgTimeCacher{ - DefaultSpan: time.Minute, - CacheExpiry: time.Minute, - } -} - -func createKeysVals(numOfPairs int) ([][]byte, [][]byte) { - keys := make([][]byte, numOfPairs) - vals := make([][]byte, numOfPairs) - for i := 0; i < numOfPairs; i++ { - keys[i] = []byte("k" + string(rune(i))) - vals[i] = []byte("v" + string(rune(i))) - } - - return keys, vals -} - -func TestNewTimeCache(t *testing.T) { - t.Parallel() - - t.Run("invalid DefaultSpan should error", func(t *testing.T) { - t.Parallel() - - arg := createArgTimeCacher() - arg.DefaultSpan = time.Second - time.Nanosecond - cacher, err := timecache.NewTimeCacher(arg) - assert.Nil(t, cacher) - assert.Equal(t, storage.ErrInvalidDefaultSpan, err) - }) - t.Run("invalid CacheExpiry should error", func(t *testing.T) { - t.Parallel() - - arg := createArgTimeCacher() - arg.CacheExpiry = time.Second - time.Nanosecond - cacher, err := timecache.NewTimeCacher(arg) - assert.Nil(t, cacher) - assert.Equal(t, storage.ErrInvalidCacheExpiry, err) - }) - t.Run("should work", func(t *testing.T) { - t.Parallel() - - cacher, err := timecache.NewTimeCacher(createArgTimeCacher()) - assert.Nil(t, err) - assert.False(t, cacher.IsInterfaceNil()) - }) -} - -func TestTimeCacher_Clear(t *testing.T) { - t.Parallel() - - cacher, _ := timecache.NewTimeCacher(createArgTimeCacher()) - assert.False(t, cacher.IsInterfaceNil()) - - numOfPairs := 3 - providedKeys, providedVals := createKeysVals(numOfPairs) - for i := 0; i < numOfPairs; i++ { - cacher.Put(providedKeys[i], providedVals[i], len(providedVals[i])) - } - assert.Equal(t, numOfPairs, cacher.Len()) - - cacher.Clear() - assert.Equal(t, 0, cacher.Len()) -} - -func TestTimeCacher_Close(t *testing.T) { - t.Parallel() - - cacher, _ := 
timecache.NewTimeCacher(createArgTimeCacher()) - assert.False(t, cacher.IsInterfaceNil()) - - err := cacher.Close() - assert.Nil(t, err) -} - -func TestTimeCacher_Get(t *testing.T) { - t.Parallel() - - cacher, _ := timecache.NewTimeCacher(createArgTimeCacher()) - assert.False(t, cacher.IsInterfaceNil()) - - providedKey, providedVal := []byte("key"), []byte("val") - cacher.Put(providedKey, providedVal, len(providedVal)) - - v, ok := cacher.Get(providedKey) - assert.True(t, ok) - assert.Equal(t, providedVal, v) - - v, ok = cacher.Get([]byte("missing key")) - assert.False(t, ok) - assert.Nil(t, v) -} - -func TestTimeCacher_Has(t *testing.T) { - t.Parallel() - - cacher, _ := timecache.NewTimeCacher(createArgTimeCacher()) - assert.False(t, cacher.IsInterfaceNil()) - - providedKey, providedVal := []byte("key"), []byte("val") - cacher.Put(providedKey, providedVal, len(providedVal)) - - assert.True(t, cacher.Has(providedKey)) - assert.False(t, cacher.Has([]byte("missing key"))) -} - -func TestTimeCacher_HasOrAdd(t *testing.T) { - t.Parallel() - - t.Run("empty or nil key should return false, false", func(t *testing.T) { - t.Parallel() - - cacher, _ := timecache.NewTimeCacher(createArgTimeCacher()) - cacher.RegisterHandler(func(key []byte, value interface{}) { - assert.Fail(t, "should have not added") - }, "test") - t.Run("nil key", func(t *testing.T) { - has, added := cacher.HasOrAdd(nil, nil, 0) - assert.False(t, has) - assert.False(t, added) - assert.Equal(t, 0, cacher.Len()) - time.Sleep(time.Second) - }) - t.Run("empty key", func(t *testing.T) { - has, added := cacher.HasOrAdd(make([]byte, 0), nil, 0) - assert.False(t, has) - assert.False(t, added) - assert.Equal(t, 0, cacher.Len()) - time.Sleep(time.Second) - }) - }) - t.Run("should work", func(t *testing.T) { - t.Parallel() - - cacher, _ := timecache.NewTimeCacher(createArgTimeCacher()) - assert.False(t, cacher.IsInterfaceNil()) - numAdded := int64(0) - cacher.RegisterHandler(func(key []byte, value interface{}) { - atomic.AddInt64(&numAdded, 1) - }, "test") - - providedKey, providedVal := []byte("key"), []byte("val") - has, added := cacher.HasOrAdd(providedKey, providedVal, len(providedVal)) - assert.False(t, has) - assert.True(t, added) - time.Sleep(time.Second) - assert.Equal(t, int64(1), atomic.LoadInt64(&numAdded)) - - has, added = cacher.HasOrAdd(providedKey, providedVal, len(providedVal)) - assert.True(t, has) - assert.False(t, added) - time.Sleep(time.Second) - assert.Equal(t, int64(1), atomic.LoadInt64(&numAdded)) - }) -} - -func TestTimeCacher_Keys(t *testing.T) { - t.Parallel() - - cacher, _ := timecache.NewTimeCacher(createArgTimeCacher()) - assert.False(t, cacher.IsInterfaceNil()) - - numOfPairs := 10 - providedKeys, providedVals := createKeysVals(numOfPairs) - for i := 0; i < numOfPairs; i++ { - cacher.Put(providedKeys[i], providedVals[i], len(providedVals[i])) - } - - receivedKeys := cacher.Keys() - assert.Equal(t, numOfPairs, len(receivedKeys)) - - sort.Slice(providedKeys, func(i, j int) bool { - return bytes.Compare(providedKeys[i], providedKeys[j]) < 0 - }) - sort.Slice(receivedKeys, func(i, j int) bool { - return bytes.Compare(receivedKeys[i], receivedKeys[j]) < 0 - }) - assert.Equal(t, providedKeys, receivedKeys) -} - -func TestTimeCacher_Evicted(t *testing.T) { - t.Parallel() - - arg := createArgTimeCacher() - arg.CacheExpiry = 2 * time.Second - arg.DefaultSpan = time.Second - cacher, _ := timecache.NewTimeCacher(arg) - assert.False(t, cacher.IsInterfaceNil()) - - numOfPairs := 2 - providedKeys, providedVals := 
createKeysVals(numOfPairs) - for i := 0; i < numOfPairs; i++ { - cacher.Put(providedKeys[i], providedVals[i], len(providedVals[i])) - } - assert.Equal(t, numOfPairs, cacher.Len()) - - time.Sleep(2 * arg.CacheExpiry) - assert.Equal(t, 0, cacher.Len()) - err := cacher.Close() - assert.Nil(t, err) -} - -func TestTimeCacher_Peek(t *testing.T) { - t.Parallel() - - cacher, _ := timecache.NewTimeCacher(createArgTimeCacher()) - assert.False(t, cacher.IsInterfaceNil()) - - providedKey, providedVal := []byte("key"), []byte("val") - cacher.Put(providedKey, providedVal, len(providedVal)) - - v, ok := cacher.Peek(providedKey) - assert.True(t, ok) - assert.Equal(t, providedVal, v) - - v, ok = cacher.Peek([]byte("missing key")) - assert.False(t, ok) - assert.Nil(t, v) -} - -func TestTimeCacher_Put(t *testing.T) { - t.Parallel() - - t.Run("empty or nil key should return false, false", func(t *testing.T) { - t.Parallel() - - cacher, _ := timecache.NewTimeCacher(createArgTimeCacher()) - cacher.RegisterHandler(func(key []byte, value interface{}) { - assert.Fail(t, "should have not added") - }, "test") - t.Run("nil key", func(t *testing.T) { - evicted := cacher.Put(nil, nil, 0) - assert.False(t, evicted) - assert.Equal(t, 0, cacher.Len()) - time.Sleep(time.Second) - }) - t.Run("empty key", func(t *testing.T) { - evicted := cacher.Put(make([]byte, 0), nil, 0) - assert.False(t, evicted) - assert.Equal(t, 0, cacher.Len()) - time.Sleep(time.Second) - }) - }) - t.Run("should work", func(t *testing.T) { - t.Parallel() - - cacher, _ := timecache.NewTimeCacher(createArgTimeCacher()) - assert.False(t, cacher.IsInterfaceNil()) - numAdded := int64(0) - cacher.RegisterHandler(func(key []byte, value interface{}) { - atomic.AddInt64(&numAdded, 1) - }, "test") - - numOfPairs := 2 - keys, vals := createKeysVals(numOfPairs) - evicted := cacher.Put(keys[0], vals[0], len(vals[0])) - assert.False(t, evicted) - assert.Equal(t, 1, cacher.Len()) - time.Sleep(time.Second) - assert.Equal(t, int64(1), atomic.LoadInt64(&numAdded)) - - evicted = cacher.Put(keys[0], vals[1], len(vals[1])) - assert.False(t, evicted) - assert.Equal(t, 1, cacher.Len()) - time.Sleep(time.Second) - assert.Equal(t, int64(2), atomic.LoadInt64(&numAdded)) - }) -} - -func TestTimeCacher_Remove(t *testing.T) { - t.Parallel() - - defer func() { - if r := recover(); r != nil { - assert.Fail(t, "should not panic") - } - }() - - cacher, _ := timecache.NewTimeCacher(createArgTimeCacher()) - assert.False(t, cacher.IsInterfaceNil()) - - providedKey, providedVal := []byte("key"), []byte("val") - cacher.Put(providedKey, providedVal, len(providedVal)) - assert.Equal(t, 1, cacher.Len()) - - cacher.Remove(nil) - assert.Equal(t, 1, cacher.Len()) - - cacher.Remove(providedKey) - assert.Equal(t, 0, cacher.Len()) - - cacher.Remove(providedKey) -} - -func TestTimeCacher_SizeInBytesContained(t *testing.T) { - t.Parallel() - - cacher, _ := timecache.NewTimeCacher(createArgTimeCacher()) - assert.False(t, cacher.IsInterfaceNil()) - - providedKey, providedVal := []byte("key"), []byte("val") - cacher.Put(providedKey, providedVal, len(providedVal)) - - assert.Zero(t, cacher.SizeInBytesContained()) -} - -func TestTimeCacher_RegisterHandler(t *testing.T) { - t.Parallel() - - cacher, _ := timecache.NewTimeCacher(createArgTimeCacher()) - assert.Equal(t, 0, cacher.NumRegisteredHandlers()) - - cacher.RegisterHandler(nil, "") - assert.Equal(t, 0, cacher.NumRegisteredHandlers()) - - cacher.RegisterHandler(func(key []byte, value interface{}) {}, "0") - assert.Equal(t, 1, 
cacher.NumRegisteredHandlers()) - - cacher.RegisterHandler(nil, "") - assert.Equal(t, 1, cacher.NumRegisteredHandlers()) -} - -func TestTimeCacher_UnRegisterHandler(t *testing.T) { - t.Parallel() - - cacher, _ := timecache.NewTimeCacher(createArgTimeCacher()) - assert.Equal(t, 0, cacher.NumRegisteredHandlers()) - - cacher.UnRegisterHandler("0") - - cacher.RegisterHandler(func(key []byte, value interface{}) {}, "0") - assert.Equal(t, 1, cacher.NumRegisteredHandlers()) - - cacher.UnRegisterHandler("1") - assert.Equal(t, 1, cacher.NumRegisteredHandlers()) - - cacher.UnRegisterHandler("0") - assert.Equal(t, 0, cacher.NumRegisteredHandlers()) -} - -func TestTimeCacher_MaxSize(t *testing.T) { - t.Parallel() - - cacher, _ := timecache.NewTimeCacher(createArgTimeCacher()) - assert.False(t, cacher.IsInterfaceNil()) - assert.Equal(t, math.MaxInt32, cacher.MaxSize()) -} - -func TestTimeCacher_ConcurrentOperations(t *testing.T) { - t.Parallel() - - tc, _ := timecache.NewTimeCacher(createArgTimeCacher()) - numOperations := 1000 - wg := &sync.WaitGroup{} - wg.Add(numOperations) - for i := 0; i < numOperations; i++ { - go func(idx int) { - time.Sleep(time.Millisecond * 10) - - switch idx % 14 { - case 0: - tc.Clear() - case 1: - _ = tc.Put(createKeyByteSlice(idx), createValueByteSlice(idx), 0) - case 2: - _, _ = tc.Get(createKeyByteSlice(idx)) - case 3: - _ = tc.Has([]byte(fmt.Sprintf("key%d", idx))) - case 4: - _, _ = tc.Peek(createKeyByteSlice(idx)) - case 5: - _, _ = tc.HasOrAdd(createKeyByteSlice(idx), createValueByteSlice(idx), 0) - case 6: - tc.Remove(createKeyByteSlice(idx)) - case 7: - _ = tc.Keys() - case 8: - _ = tc.Len() - case 9: - _ = tc.SizeInBytesContained() - case 10: - _ = tc.MaxSize() - case 11: - tc.RegisterHandler(nil, "") - case 12: - tc.UnRegisterHandler("") - case 13: - _ = tc.Close() - default: - assert.Fail(t, "test setup error, change the line 'switch idx % xxx {' from this test") - } - - wg.Done() - }(i) - } - - wg.Wait() -} - -func createKeyByteSlice(index int) []byte { - return []byte(fmt.Sprintf("key%d", index)) -} - -func createValueByteSlice(index int) []byte { - return []byte(fmt.Sprintf("value%d", index)) -} diff --git a/storage/txcache/benchmarks.sh b/storage/txcache/benchmarks.sh deleted file mode 100644 index a3f9fa36f77..00000000000 --- a/storage/txcache/benchmarks.sh +++ /dev/null @@ -1,2 +0,0 @@ -#!/bin/bash -go test -bench="BenchmarkSendersMap_GetSnapshotAscending$" -benchtime=1x diff --git a/storage/txcache/config.go b/storage/txcache/config.go deleted file mode 100644 index d9d14eadca0..00000000000 --- a/storage/txcache/config.go +++ /dev/null @@ -1,122 +0,0 @@ -package txcache - -import ( - "encoding/json" - "fmt" - - "github.com/ElrondNetwork/elrond-go/storage" -) - -const numChunksLowerBound = 1 -const numChunksUpperBound = 128 -const maxNumItemsLowerBound = 4 -const maxNumBytesLowerBound = maxNumItemsLowerBound * 1 -const maxNumBytesUpperBound = 1_073_741_824 // one GB -const maxNumItemsPerSenderLowerBound = 1 -const maxNumBytesPerSenderLowerBound = maxNumItemsPerSenderLowerBound * 1 -const maxNumBytesPerSenderUpperBound = 33_554_432 // 32 MB -const numTxsToPreemptivelyEvictLowerBound = 1 -const numSendersToPreemptivelyEvictLowerBound = 1 - -// ConfigSourceMe holds cache configuration -type ConfigSourceMe struct { - Name string - NumChunks uint32 - EvictionEnabled bool - NumBytesThreshold uint32 - NumBytesPerSenderThreshold uint32 - CountThreshold uint32 - CountPerSenderThreshold uint32 - NumSendersToPreemptivelyEvict uint32 -} - -type senderConstraints 
struct { - maxNumTxs uint32 - maxNumBytes uint32 -} - -// TODO: Upon further analysis and brainstorming, add some sensible minimum accepted values for the appropriate fields. -func (config *ConfigSourceMe) verify() error { - if len(config.Name) == 0 { - return fmt.Errorf("%w: config.Name is invalid", storage.ErrInvalidConfig) - } - if config.NumChunks < numChunksLowerBound || config.NumChunks > numChunksUpperBound { - return fmt.Errorf("%w: config.NumChunks is invalid", storage.ErrInvalidConfig) - } - if config.NumBytesPerSenderThreshold < maxNumBytesPerSenderLowerBound || config.NumBytesPerSenderThreshold > maxNumBytesPerSenderUpperBound { - return fmt.Errorf("%w: config.NumBytesPerSenderThreshold is invalid", storage.ErrInvalidConfig) - } - if config.CountPerSenderThreshold < maxNumItemsPerSenderLowerBound { - return fmt.Errorf("%w: config.CountPerSenderThreshold is invalid", storage.ErrInvalidConfig) - } - if config.EvictionEnabled { - if config.NumBytesThreshold < maxNumBytesLowerBound || config.NumBytesThreshold > maxNumBytesUpperBound { - return fmt.Errorf("%w: config.NumBytesThreshold is invalid", storage.ErrInvalidConfig) - } - if config.CountThreshold < maxNumItemsLowerBound { - return fmt.Errorf("%w: config.CountThreshold is invalid", storage.ErrInvalidConfig) - } - if config.NumSendersToPreemptivelyEvict < numSendersToPreemptivelyEvictLowerBound { - return fmt.Errorf("%w: config.NumSendersToPreemptivelyEvict is invalid", storage.ErrInvalidConfig) - } - } - - return nil -} - -func (config *ConfigSourceMe) getSenderConstraints() senderConstraints { - return senderConstraints{ - maxNumBytes: config.NumBytesPerSenderThreshold, - maxNumTxs: config.CountPerSenderThreshold, - } -} - -// String returns a readable representation of the object -func (config *ConfigSourceMe) String() string { - bytes, err := json.Marshal(config) - if err != nil { - log.Error("ConfigSourceMe.String()", "err", err) - } - - return string(bytes) -} - -// ConfigDestinationMe holds cache configuration -type ConfigDestinationMe struct { - Name string - NumChunks uint32 - MaxNumItems uint32 - MaxNumBytes uint32 - NumItemsToPreemptivelyEvict uint32 -} - -// TODO: Upon further analysis and brainstorming, add some sensible minimum accepted values for the appropriate fields. 
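[Editor's note] Both verify methods here wrap storage.ErrInvalidConfig via the %w verb, so code inside the txcache package can branch on the sentinel error instead of matching message text. A hedged sketch; the helper name is hypothetical:

import (
	"errors"

	"github.com/ElrondNetwork/elrond-go/storage"
)

func logIfInvalid(config ConfigSourceMe) {
	err := config.verify() // unexported, hence callable only within package txcache
	if errors.Is(err, storage.ErrInvalidConfig) {
		// every failed bound reports the same sentinel; the message names the offending field
		log.Error("invalid cache config", "err", err)
	}
}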
-func (config *ConfigDestinationMe) verify() error { - if len(config.Name) == 0 { - return fmt.Errorf("%w: config.Name is invalid", storage.ErrInvalidConfig) - } - if config.NumChunks < numChunksLowerBound || config.NumChunks > numChunksUpperBound { - return fmt.Errorf("%w: config.NumChunks is invalid", storage.ErrInvalidConfig) - } - if config.MaxNumItems < maxNumItemsLowerBound { - return fmt.Errorf("%w: config.MaxNumItems is invalid", storage.ErrInvalidConfig) - } - if config.MaxNumBytes < maxNumBytesLowerBound || config.MaxNumBytes > maxNumBytesUpperBound { - return fmt.Errorf("%w: config.MaxNumBytes is invalid", storage.ErrInvalidConfig) - } - if config.NumItemsToPreemptivelyEvict < numTxsToPreemptivelyEvictLowerBound { - return fmt.Errorf("%w: config.NumItemsToPreemptivelyEvict is invalid", storage.ErrInvalidConfig) - } - - return nil -} - -// String returns a readable representation of the object -func (config *ConfigDestinationMe) String() string { - bytes, err := json.Marshal(config) - if err != nil { - log.Error("ConfigDestinationMe.String()", "err", err) - } - - return string(bytes) -} diff --git a/storage/txcache/constants.go b/storage/txcache/constants.go deleted file mode 100644 index a76fb3d3cc0..00000000000 --- a/storage/txcache/constants.go +++ /dev/null @@ -1,9 +0,0 @@ -package txcache - -const estimatedNumOfSweepableSendersPerSelection = 100 - -const senderGracePeriodLowerBound = 2 - -const senderGracePeriodUpperBound = 2 - -const numEvictedTxsToDisplay = 3 diff --git a/storage/txcache/crossTxCache.go b/storage/txcache/crossTxCache.go deleted file mode 100644 index 1845de9294e..00000000000 --- a/storage/txcache/crossTxCache.go +++ /dev/null @@ -1,121 +0,0 @@ -package txcache - -import ( - "github.com/ElrondNetwork/elrond-go/storage" - "github.com/ElrondNetwork/elrond-go/storage/immunitycache" -) - -var _ storage.Cacher = (*CrossTxCache)(nil) - -// CrossTxCache holds cross-shard transactions (where destination == me) -type CrossTxCache struct { - *immunitycache.ImmunityCache - config ConfigDestinationMe -} - -// NewCrossTxCache creates a new transactions cache -func NewCrossTxCache(config ConfigDestinationMe) (*CrossTxCache, error) { - log.Debug("NewCrossTxCache", "config", config.String()) - - err := config.verify() - if err != nil { - return nil, err - } - - immunityCacheConfig := immunitycache.CacheConfig{ - Name: config.Name, - NumChunks: config.NumChunks, - MaxNumBytes: config.MaxNumBytes, - MaxNumItems: config.MaxNumItems, - NumItemsToPreemptivelyEvict: config.NumItemsToPreemptivelyEvict, - } - - immunityCache, err := immunitycache.NewImmunityCache(immunityCacheConfig) - if err != nil { - return nil, err - } - - cache := CrossTxCache{ - ImmunityCache: immunityCache, - config: config, - } - - return &cache, nil -} - -// ImmunizeTxsAgainstEviction marks items as non-evictable -func (cache *CrossTxCache) ImmunizeTxsAgainstEviction(keys [][]byte) { - numNow, numFuture := cache.ImmunityCache.ImmunizeKeys(keys) - log.Trace("CrossTxCache.ImmunizeTxsAgainstEviction()", - "name", cache.config.Name, - "len(keys)", len(keys), - "numNow", numNow, - "numFuture", numFuture, - ) - cache.Diagnose(false) -} - -// AddTx adds a transaction in the cache -func (cache *CrossTxCache) AddTx(tx *WrappedTransaction) (has, added bool) { - return cache.HasOrAdd(tx.TxHash, tx, int(tx.Size)) -} - -// GetByTxHash gets the transaction by hash -func (cache *CrossTxCache) GetByTxHash(txHash []byte) (*WrappedTransaction, bool) { - item, ok := cache.ImmunityCache.Get(txHash) - if !ok { - return nil, 
false - } - tx, ok := item.(*WrappedTransaction) - if !ok { - return nil, false - } - - return tx, true -} - -// Get returns the unwrapped payload of a WrappedTransaction -// Implemented for compatibility reasons (see txPoolsCleaner.go). -func (cache *CrossTxCache) Get(key []byte) (value interface{}, ok bool) { - wrapped, ok := cache.GetByTxHash(key) - if !ok { - return nil, false - } - - return wrapped.Tx, true -} - -// Peek returns the unwrapped payload of a WrappedTransaction -// Implemented for compatibility reasons (see transactions.go, common.go). -func (cache *CrossTxCache) Peek(key []byte) (value interface{}, ok bool) { - return cache.Get(key) -} - -// RemoveTxByHash removes tx by hash -func (cache *CrossTxCache) RemoveTxByHash(txHash []byte) bool { - return cache.RemoveWithResult(txHash) -} - -// ForEachTransaction iterates over the transactions in the cache -func (cache *CrossTxCache) ForEachTransaction(function ForEachTransaction) { - cache.ForEachItem(func(key []byte, item interface{}) { - tx, ok := item.(*WrappedTransaction) - if !ok { - return - } - - function(key, tx) - }) -} - -// GetTransactionsPoolForSender returns an empty slice, only to respect the interface -// CrossTxCache does not support transaction selection (not applicable, since transactions are already half-executed), -// thus does not handle nonces, nonce gaps etc. -func (cache *CrossTxCache) GetTransactionsPoolForSender(_ string) []*WrappedTransaction { - return make([]*WrappedTransaction, 0) -} - -// IsInterfaceNil returns true if there is no value under the interface -func (cache *CrossTxCache) IsInterfaceNil() bool { - return cache == nil -} diff --git a/storage/txcache/crossTxCache_test.go b/storage/txcache/crossTxCache_test.go deleted file mode 100644 index eca4a64b716..00000000000 --- a/storage/txcache/crossTxCache_test.go +++ /dev/null @@ -1,83 +0,0 @@ -package txcache - -import ( - "fmt" - "math" - "testing" - - "github.com/stretchr/testify/require" -) - -func TestCrossTxCache_DoImmunizeTxsAgainstEviction(t *testing.T) { - cache := newCrossTxCacheToTest(1, 8, math.MaxUint16) - - cache.addTestTxs("a", "b", "c", "d") - numNow, numFuture := cache.ImmunizeKeys(hashesAsBytes([]string{"a", "b", "e", "f"})) - require.Equal(t, 2, numNow) - require.Equal(t, 2, numFuture) - require.Equal(t, 4, cache.Len()) - - cache.addTestTxs("e", "f", "g", "h") - require.ElementsMatch(t, []string{"a", "b", "c", "d", "e", "f", "g", "h"}, hashesAsStrings(cache.Keys())) - - cache.addTestTxs("i", "j", "k", "l") - require.ElementsMatch(t, []string{"a", "b", "e", "f", "i", "j", "k", "l"}, hashesAsStrings(cache.Keys())) -} - -func TestCrossTxCache_Get(t *testing.T) { - cache := newCrossTxCacheToTest(1, 8, math.MaxUint16) - - cache.addTestTxs("a", "b", "c", "d") - a, ok := cache.GetByTxHash([]byte("a")) - require.True(t, ok) - require.NotNil(t, a) - - x, ok := cache.GetByTxHash([]byte("x")) - require.False(t, ok) - require.Nil(t, x) - - aTx, ok := cache.Get([]byte("a")) - require.True(t, ok) - require.NotNil(t, aTx) - require.Equal(t, a.Tx, aTx) - - xTx, ok := cache.Get([]byte("x")) - require.False(t, ok) - require.Nil(t, xTx) - - aTx, ok = cache.Peek([]byte("a")) - require.True(t, ok) - require.NotNil(t, aTx) - require.Equal(t, a.Tx, aTx) - - xTx, ok = cache.Peek([]byte("x")) - require.False(t, ok) - require.Nil(t, xTx) - - require.Equal(t, make([]*WrappedTransaction, 0), cache.GetTransactionsPoolForSender("")) -} - -func newCrossTxCacheToTest(numChunks uint32, maxNumItems uint32, numMaxBytes uint32) *CrossTxCache { - cache, err := 
NewCrossTxCache(ConfigDestinationMe{ - Name: "test", - NumChunks: numChunks, - MaxNumItems: maxNumItems, - MaxNumBytes: numMaxBytes, - NumItemsToPreemptivelyEvict: numChunks * 1, - }) - if err != nil { - panic(fmt.Sprintf("newCrossTxCacheToTest(): %s", err)) - } - - return cache -} - -func (cache *CrossTxCache) addTestTxs(hashes ...string) { - for _, hash := range hashes { - _, _ = cache.addTestTx(hash) - } -} - -func (cache *CrossTxCache) addTestTx(hash string) (ok, added bool) { - return cache.AddTx(createTx([]byte(hash), ".", uint64(42))) -} diff --git a/storage/txcache/disabledCache.go b/storage/txcache/disabledCache.go deleted file mode 100644 index f2072dbc50b..00000000000 --- a/storage/txcache/disabledCache.go +++ /dev/null @@ -1,133 +0,0 @@ -package txcache - -import ( - "github.com/ElrondNetwork/elrond-go/storage" -) - -var _ storage.Cacher = (*DisabledCache)(nil) - -// DisabledCache represents a disabled cache -type DisabledCache struct { -} - -// NewDisabledCache creates a new disabled cache -func NewDisabledCache() *DisabledCache { - return &DisabledCache{} -} - -// AddTx does nothing -func (cache *DisabledCache) AddTx(_ *WrappedTransaction) (ok bool, added bool) { - return false, false -} - -// GetByTxHash returns no transaction -func (cache *DisabledCache) GetByTxHash(_ []byte) (*WrappedTransaction, bool) { - return nil, false -} - -// SelectTransactionsWithBandwidth returns an empty slice -func (cache *DisabledCache) SelectTransactionsWithBandwidth(_ int, _ int, _ uint64) []*WrappedTransaction { - return make([]*WrappedTransaction, 0) -} - -// RemoveTxByHash does nothing -func (cache *DisabledCache) RemoveTxByHash(_ []byte) bool { - return false -} - -// Len returns zero -func (cache *DisabledCache) Len() int { - return 0 -} - -// SizeInBytesContained returns 0 -func (cache *DisabledCache) SizeInBytesContained() uint64 { - return 0 -} - -// NumBytes returns zero -func (cache *DisabledCache) NumBytes() int { - return 0 -} - -// ForEachTransaction does nothing -func (cache *DisabledCache) ForEachTransaction(_ ForEachTransaction) { -} - -// Clear does nothing -func (cache *DisabledCache) Clear() { -} - -// Put does nothing -func (cache *DisabledCache) Put(_ []byte, _ interface{}, _ int) (evicted bool) { - return false -} - -// Get returns no transaction -func (cache *DisabledCache) Get(_ []byte) (value interface{}, ok bool) { - return nil, false -} - -// Has returns false -func (cache *DisabledCache) Has(_ []byte) bool { - return false -} - -// Peek returns no transaction -func (cache *DisabledCache) Peek(_ []byte) (value interface{}, ok bool) { - return nil, false -} - -// HasOrAdd returns false, does nothing -func (cache *DisabledCache) HasOrAdd(_ []byte, _ interface{}, _ int) (has, added bool) { - return false, false -} - -// Remove does nothing -func (cache *DisabledCache) Remove(_ []byte) { -} - -// Keys returns an empty slice -func (cache *DisabledCache) Keys() [][]byte { - return make([][]byte, 0) -} - -// MaxSize returns zero -func (cache *DisabledCache) MaxSize() int { - return 0 -} - -// RegisterHandler does nothing -func (cache *DisabledCache) RegisterHandler(func(key []byte, value interface{}), string) { -} - -// UnRegisterHandler does nothing -func (cache *DisabledCache) UnRegisterHandler(string) { -} - -// NotifyAccountNonce does nothing -func (cache *DisabledCache) NotifyAccountNonce(_ []byte, _ uint64) { -} - -// ImmunizeTxsAgainstEviction does nothing -func (cache *DisabledCache) ImmunizeTxsAgainstEviction(_ [][]byte) { -} - -// Diagnose does nothing -func 
(cache *DisabledCache) Diagnose(_ bool) { -} - -// GetTransactionsPoolForSender returns an empty slice -func (cache *DisabledCache) GetTransactionsPoolForSender(_ string) []*WrappedTransaction { - return make([]*WrappedTransaction, 0) -} - -// Close does nothing -func (cache *DisabledCache) Close() error { - return nil -} - -// IsInterfaceNil returns true if there is no value under the interface -func (cache *DisabledCache) IsInterfaceNil() bool { - return cache == nil -} diff --git a/storage/txcache/disabledCache_test.go b/storage/txcache/disabledCache_test.go deleted file mode 100644 index a19e947aac3..00000000000 --- a/storage/txcache/disabledCache_test.go +++ /dev/null @@ -1,68 +0,0 @@ -package txcache - -import ( - "math" - "testing" - - "github.com/stretchr/testify/require" -) - -func TestDisabledCache_DoesNothing(t *testing.T) { - cache := NewDisabledCache() - - ok, added := cache.AddTx(nil) - require.False(t, ok) - require.False(t, added) - - tx, ok := cache.GetByTxHash([]byte{}) - require.Nil(t, tx) - require.False(t, ok) - - selection := cache.SelectTransactionsWithBandwidth(42, 42, math.MaxUint64) - require.Equal(t, 0, len(selection)) - - removed := cache.RemoveTxByHash([]byte{}) - require.False(t, removed) - - length := cache.Len() - require.Equal(t, 0, length) - - require.NotPanics(t, func() { cache.ForEachTransaction(func(_ []byte, _ *WrappedTransaction) {}) }) - - txs := cache.GetTransactionsPoolForSender("") - require.Equal(t, make([]*WrappedTransaction, 0), txs) - - cache.Clear() - - evicted := cache.Put(nil, nil, 0) - require.False(t, evicted) - - value, ok := cache.Get([]byte{}) - require.Nil(t, value) - require.False(t, ok) - - value, ok = cache.Peek([]byte{}) - require.Nil(t, value) - require.False(t, ok) - - has := cache.Has([]byte{}) - require.False(t, has) - - has, added = cache.HasOrAdd([]byte{}, nil, 0) - require.False(t, has) - require.False(t, added) - - cache.Remove([]byte{}) - - keys := cache.Keys() - require.Equal(t, 0, len(keys)) - - maxSize := cache.MaxSize() - require.Equal(t, 0, maxSize) - - require.NotPanics(t, func() { cache.RegisterHandler(func(_ []byte, _ interface{}) {}, "") }) - require.False(t, cache.IsInterfaceNil()) - - err := cache.Close() - require.Nil(t, err) -} diff --git a/storage/txcache/eviction.go b/storage/txcache/eviction.go deleted file mode 100644 index fc41f9c29f7..00000000000 --- a/storage/txcache/eviction.go +++ /dev/null @@ -1,128 +0,0 @@ -package txcache - -import ( - "github.com/ElrondNetwork/elrond-go-core/core" -) - -// doEviction does cache eviction -// We do not allow more evictions to start concurrently -func (cache *TxCache) doEviction() { - if cache.isEvictionInProgress.IsSet() { - return - } - - if !cache.isCapacityExceeded() { - return - } - - cache.evictionMutex.Lock() - defer cache.evictionMutex.Unlock() - - _ = cache.isEvictionInProgress.SetReturningPrevious() - defer cache.isEvictionInProgress.Reset() - - if !cache.isCapacityExceeded() { - return - } - - stopWatch := cache.monitorEvictionStart() - cache.makeSnapshotOfSenders() - - journal := evictionJournal{} - journal.passOneNumSteps, journal.passOneNumTxs, journal.passOneNumSenders = cache.evictSendersInLoop() - journal.evictionPerformed = true - cache.evictionJournal = journal - - cache.monitorEvictionEnd(stopWatch) - cache.destroySnapshotOfSenders() -} - -func (cache *TxCache) makeSnapshotOfSenders() { - cache.evictionSnapshotOfSenders = cache.txListBySender.getSnapshotAscending() -} - -func (cache *TxCache) destroySnapshotOfSenders() { - 
cache.evictionSnapshotOfSenders = nil -} - -func (cache *TxCache) isCapacityExceeded() bool { - return cache.areThereTooManyBytes() || cache.areThereTooManySenders() || cache.areThereTooManyTxs() -} - -func (cache *TxCache) areThereTooManyBytes() bool { - numBytes := cache.NumBytes() - tooManyBytes := numBytes > int(cache.config.NumBytesThreshold) - return tooManyBytes -} - -func (cache *TxCache) areThereTooManySenders() bool { - numSenders := cache.CountSenders() - tooManySenders := numSenders > uint64(cache.config.CountThreshold) - return tooManySenders -} - -func (cache *TxCache) areThereTooManyTxs() bool { - numTxs := cache.CountTx() - tooManyTxs := numTxs > uint64(cache.config.CountThreshold) - return tooManyTxs -} - -// This is called concurrently by two goroutines: the eviction one and the sweeping one -func (cache *TxCache) doEvictItems(txsToEvict [][]byte, sendersToEvict []string) (countTxs uint32, countSenders uint32) { - countTxs = cache.txByHash.RemoveTxsBulk(txsToEvict) - countSenders = cache.txListBySender.RemoveSendersBulk(sendersToEvict) - return -} - -func (cache *TxCache) evictSendersInLoop() (uint32, uint32, uint32) { - return cache.evictSendersWhile(cache.isCapacityExceeded) -} - -// evictSendersWhile removes senders (and their transactions) in a loop, as long as "shouldContinue" is true -// One batch of senders is removed in each step -func (cache *TxCache) evictSendersWhile(shouldContinue func() bool) (step uint32, numTxs uint32, numSenders uint32) { - if !shouldContinue() { - return - } - - snapshot := cache.evictionSnapshotOfSenders - snapshotLength := uint32(len(snapshot)) - batchSize := cache.config.NumSendersToPreemptivelyEvict - batchStart := uint32(0) - - for step = 0; shouldContinue(); step++ { - batchEnd := batchStart + batchSize - batchEndBounded := core.MinUint32(batchEnd, snapshotLength) - batch := snapshot[batchStart:batchEndBounded] - - numTxsEvictedInStep, numSendersEvictedInStep := cache.evictSendersAndTheirTxs(batch) - - numTxs += numTxsEvictedInStep - numSenders += numSendersEvictedInStep - batchStart += batchSize - - reachedEnd := batchStart >= snapshotLength - noTxsEvicted := numTxsEvictedInStep == 0 - incompleteBatch := numSendersEvictedInStep < batchSize - - shouldBreak := noTxsEvicted || incompleteBatch || reachedEnd - if shouldBreak { - break - } - } - - return -} - -// This is called concurrently by two goroutines: the eviction one and the sweeping one -func (cache *TxCache) evictSendersAndTheirTxs(listsToEvict []*txListForSender) (uint32, uint32) { - sendersToEvict := make([]string, 0, len(listsToEvict)) - txsToEvict := make([][]byte, 0, approximatelyCountTxInLists(listsToEvict)) - - for _, txList := range listsToEvict { - sendersToEvict = append(sendersToEvict, txList.sender) - txsToEvict = append(txsToEvict, txList.getTxHashes()...)
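// [Editor's note] Worked example (hedged) for the windowed loop in evictSendersWhile above,
// taken from the eviction tests below: with a snapshot of 200 single-tx senders, a
// CountThreshold of 100 and NumSendersToPreemptivelyEvict of 20, the loop evicts batches
// [0,20), [20,40), ... and capacity falls back to the threshold after 5 steps, i.e.
// 5 * 20 = 100 senders and 100 transactions removed, exactly the figures asserted in
// TestEviction_EvictSendersWhileTooManyTxs.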
- } - - return cache.doEvictItems(txsToEvict, sendersToEvict) -} diff --git a/storage/txcache/eviction_test.go b/storage/txcache/eviction_test.go deleted file mode 100644 index 1df1007f1a6..00000000000 --- a/storage/txcache/eviction_test.go +++ /dev/null @@ -1,308 +0,0 @@ -package txcache - -import ( - "math" - "sync" - "testing" - - "github.com/ElrondNetwork/elrond-go/dataRetriever" - "github.com/stretchr/testify/require" -) - -func TestEviction_EvictSendersWhileTooManyTxs(t *testing.T) { - config := ConfigSourceMe{ - Name: "untitled", - NumChunks: 16, - CountThreshold: 100, - CountPerSenderThreshold: math.MaxUint32, - NumSendersToPreemptivelyEvict: 20, - NumBytesThreshold: maxNumBytesUpperBound, - NumBytesPerSenderThreshold: maxNumBytesPerSenderUpperBound, - } - - txGasHandler, _ := dummyParams() - - cache, err := NewTxCache(config, txGasHandler) - require.Nil(t, err) - require.NotNil(t, cache) - - // 200 senders, each with 1 transaction - for index := 0; index < 200; index++ { - sender := string(createFakeSenderAddress(index)) - cache.AddTx(createTx([]byte{byte(index)}, sender, uint64(1))) - } - - require.Equal(t, int64(200), cache.txListBySender.counter.Get()) - require.Equal(t, int64(200), cache.txByHash.counter.Get()) - - cache.makeSnapshotOfSenders() - steps, nTxs, nSenders := cache.evictSendersInLoop() - - require.Equal(t, uint32(5), steps) - require.Equal(t, uint32(100), nTxs) - require.Equal(t, uint32(100), nSenders) - require.Equal(t, int64(100), cache.txListBySender.counter.Get()) - require.Equal(t, int64(100), cache.txByHash.counter.Get()) -} - -func TestEviction_EvictSendersWhileTooManyBytes(t *testing.T) { - numBytesPerTx := uint32(1000) - - config := ConfigSourceMe{ - Name: "untitled", - NumChunks: 16, - CountThreshold: math.MaxUint32, - CountPerSenderThreshold: math.MaxUint32, - NumBytesThreshold: numBytesPerTx * 100, - NumBytesPerSenderThreshold: maxNumBytesPerSenderUpperBound, - NumSendersToPreemptivelyEvict: 20, - } - txGasHandler, _ := dummyParams() - - cache, err := NewTxCache(config, txGasHandler) - require.Nil(t, err) - require.NotNil(t, cache) - - // 200 senders, each with 1 transaction - for index := 0; index < 200; index++ { - sender := string(createFakeSenderAddress(index)) - cache.AddTx(createTxWithParams([]byte{byte(index)}, sender, uint64(1), uint64(numBytesPerTx), 10000, 100*oneBillion)) - } - - require.Equal(t, int64(200), cache.txListBySender.counter.Get()) - require.Equal(t, int64(200), cache.txByHash.counter.Get()) - - cache.makeSnapshotOfSenders() - steps, nTxs, nSenders := cache.evictSendersInLoop() - - require.Equal(t, uint32(5), steps) - require.Equal(t, uint32(100), nTxs) - require.Equal(t, uint32(100), nSenders) - require.Equal(t, int64(100), cache.txListBySender.counter.Get()) - require.Equal(t, int64(100), cache.txByHash.counter.Get()) -} - -func TestEviction_DoEvictionDoneInPassTwo_BecauseOfCount(t *testing.T) { - config := ConfigSourceMe{ - Name: "untitled", - NumChunks: 16, - NumBytesThreshold: maxNumBytesUpperBound, - NumBytesPerSenderThreshold: maxNumBytesPerSenderUpperBound, - CountThreshold: 2, - CountPerSenderThreshold: math.MaxUint32, - NumSendersToPreemptivelyEvict: 2, - } - txGasHandler, _ := dummyParamsWithGasPrice(100 * oneBillion) - cache, err := NewTxCache(config, txGasHandler) - require.Nil(t, err) - require.NotNil(t, cache) - - cache.AddTx(createTxWithParams([]byte("hash-alice"), "alice", uint64(1), 1000, 100000, 100*oneBillion)) - cache.AddTx(createTxWithParams([]byte("hash-bob"), "bob", uint64(1), 1000, 100000, 100*oneBillion)) 
- cache.AddTx(createTxWithParams([]byte("hash-carol"), "carol", uint64(1), 1000, 100000, 700*oneBillion)) - - cache.doEviction() - require.Equal(t, uint32(2), cache.evictionJournal.passOneNumTxs) - require.Equal(t, uint32(2), cache.evictionJournal.passOneNumSenders) - require.Equal(t, uint32(1), cache.evictionJournal.passOneNumSteps) - - // Alice and Bob evicted. Carol still there. - _, ok := cache.GetByTxHash([]byte("hash-carol")) - require.True(t, ok) - require.Equal(t, uint64(1), cache.CountSenders()) - require.Equal(t, uint64(1), cache.CountTx()) -} - -func TestEviction_DoEvictionDoneInPassTwo_BecauseOfSize(t *testing.T) { - config := ConfigSourceMe{ - Name: "untitled", - NumChunks: 16, - CountThreshold: math.MaxUint32, - CountPerSenderThreshold: math.MaxUint32, - NumBytesThreshold: 1000, - NumBytesPerSenderThreshold: maxNumBytesPerSenderUpperBound, - NumSendersToPreemptivelyEvict: 2, - } - - txGasHandler, _ := dummyParamsWithGasPrice(oneBillion) - cache, err := NewTxCache(config, txGasHandler) - require.Nil(t, err) - require.NotNil(t, cache) - - cache.AddTx(createTxWithParams([]byte("hash-alice"), "alice", uint64(1), 128, 100000, oneBillion)) - cache.AddTx(createTxWithParams([]byte("hash-bob"), "bob", uint64(1), 128, 100000, oneBillion)) - cache.AddTx(createTxWithParams([]byte("hash-dave1"), "dave", uint64(3), 128, 40000000, oneBillion)) - cache.AddTx(createTxWithParams([]byte("hash-dave2"), "dave", uint64(1), 128, 50000, oneBillion)) - cache.AddTx(createTxWithParams([]byte("hash-dave3"), "dave", uint64(2), 128, 50000, oneBillion)) - cache.AddTx(createTxWithParams([]byte("hash-chris"), "chris", uint64(1), 128, 50000, oneBillion)) - cache.AddTx(createTxWithParams([]byte("hash-richard"), "richard", uint64(1), 128, 50000, uint64(1.2*oneBillion))) - cache.AddTx(createTxWithParams([]byte("hash-carol"), "carol", uint64(1), 128, 100000, 7*oneBillion)) - cache.AddTx(createTxWithParams([]byte("hash-eve"), "eve", uint64(1), 128, 50000, 4*oneBillion)) - - scoreAlice := cache.getScoreOfSender("alice") - scoreBob := cache.getScoreOfSender("bob") - scoreDave := cache.getScoreOfSender("dave") - scoreCarol := cache.getScoreOfSender("carol") - scoreEve := cache.getScoreOfSender("eve") - scoreChris := cache.getScoreOfSender("chris") - scoreRichard := cache.getScoreOfSender("richard") - - require.Equal(t, uint32(23), scoreAlice) - require.Equal(t, uint32(23), scoreBob) - require.Equal(t, uint32(7), scoreDave) - require.Equal(t, uint32(100), scoreCarol) - require.Equal(t, uint32(100), scoreEve) - require.Equal(t, uint32(33), scoreChris) - require.Equal(t, uint32(54), scoreRichard) - - cache.doEviction() - require.Equal(t, uint32(4), cache.evictionJournal.passOneNumTxs) - require.Equal(t, uint32(2), cache.evictionJournal.passOneNumSenders) - require.Equal(t, uint32(1), cache.evictionJournal.passOneNumSteps) - - // Alice and Bob evicted (lower score). Carol and Eve still there. 
- _, ok := cache.GetByTxHash([]byte("hash-carol")) - require.True(t, ok) - require.Equal(t, uint64(5), cache.CountSenders()) - require.Equal(t, uint64(5), cache.CountTx()) -} - -func TestEviction_doEvictionDoesNothingWhenAlreadyInProgress(t *testing.T) { - config := ConfigSourceMe{ - Name: "untitled", - NumChunks: 1, - CountThreshold: 0, - NumSendersToPreemptivelyEvict: 1, - NumBytesPerSenderThreshold: maxNumBytesPerSenderUpperBound, - CountPerSenderThreshold: math.MaxUint32, - } - - txGasHandler, _ := dummyParams() - cache, err := NewTxCache(config, txGasHandler) - require.Nil(t, err) - require.NotNil(t, cache) - - cache.AddTx(createTx([]byte("hash-alice"), "alice", uint64(1))) - - _ = cache.isEvictionInProgress.SetReturningPrevious() - cache.doEviction() - - require.False(t, cache.evictionJournal.evictionPerformed) -} - -func TestEviction_evictSendersInLoop_CoverLoopBreak_WhenSmallBatch(t *testing.T) { - config := ConfigSourceMe{ - Name: "untitled", - NumChunks: 1, - CountThreshold: 0, - NumSendersToPreemptivelyEvict: 42, - NumBytesPerSenderThreshold: maxNumBytesPerSenderUpperBound, - CountPerSenderThreshold: math.MaxUint32, - } - - txGasHandler, _ := dummyParams() - cache, err := NewTxCache(config, txGasHandler) - require.Nil(t, err) - require.NotNil(t, cache) - - cache.AddTx(createTx([]byte("hash-alice"), "alice", uint64(1))) - - cache.makeSnapshotOfSenders() - - steps, nTxs, nSenders := cache.evictSendersInLoop() - require.Equal(t, uint32(0), steps) - require.Equal(t, uint32(1), nTxs) - require.Equal(t, uint32(1), nSenders) -} - -func TestEviction_evictSendersWhile_ShouldContinueBreak(t *testing.T) { - config := ConfigSourceMe{ - Name: "untitled", - NumChunks: 1, - CountThreshold: 0, - NumSendersToPreemptivelyEvict: 1, - NumBytesPerSenderThreshold: maxNumBytesPerSenderUpperBound, - CountPerSenderThreshold: math.MaxUint32, - } - - txGasHandler, _ := dummyParams() - cache, err := NewTxCache(config, txGasHandler) - require.Nil(t, err) - require.NotNil(t, cache) - - cache.AddTx(createTx([]byte("hash-alice"), "alice", uint64(1))) - cache.AddTx(createTx([]byte("hash-bob"), "bob", uint64(1))) - - cache.makeSnapshotOfSenders() - - steps, nTxs, nSenders := cache.evictSendersWhile(func() bool { - return false - }) - - require.Equal(t, uint32(0), steps) - require.Equal(t, uint32(0), nTxs) - require.Equal(t, uint32(0), nSenders) -} - -// This seems to be the most reasonable "bad-enough" (not worst) scenario to benchmark: -// 25000 senders with 10 transactions each, with default "NumSendersToPreemptivelyEvict". -// ~1 second on average laptop. -func Test_AddWithEviction_UniformDistribution_25000x10(t *testing.T) { - config := ConfigSourceMe{ - Name: "untitled", - NumChunks: 16, - EvictionEnabled: true, - NumBytesThreshold: 1000000000, - CountThreshold: 240000, - NumSendersToPreemptivelyEvict: dataRetriever.TxPoolNumSendersToPreemptivelyEvict, - NumBytesPerSenderThreshold: maxNumBytesPerSenderUpperBound, - CountPerSenderThreshold: math.MaxUint32, - } - - txGasHandler, _ := dummyParams() - numSenders := 25000 - numTxsPerSender := 10 - - cache, err := NewTxCache(config, txGasHandler) - require.Nil(t, err) - require.NotNil(t, cache) - - addManyTransactionsWithUniformDistribution(cache, numSenders, numTxsPerSender) - - // Sometimes (due to map iteration non-determinism), more eviction happens - one more step of 100 senders. 
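// [Editor's note] On the expected bounds checked next: each extra eviction step removes
// NumSendersToPreemptivelyEvict senders at 10 transactions each. Assuming the
// dataRetriever constant equals 100, as the "one more step of 100 senders" comment above
// suggests, the final count asserted below lies in [240000 - 100*10, 240000],
// i.e. [239000, 240000].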
- require.LessOrEqual(t, uint32(cache.CountTx()), config.CountThreshold) - require.GreaterOrEqual(t, uint32(cache.CountTx()), config.CountThreshold-config.NumSendersToPreemptivelyEvict*uint32(numTxsPerSender)) -} - -func Test_EvictSendersAndTheirTxs_Concurrently(t *testing.T) { - cache := newUnconstrainedCacheToTest() - var wg sync.WaitGroup - - for i := 0; i < 10; i++ { - wg.Add(3) - - go func() { - cache.AddTx(createTx([]byte("alice-x"), "alice", 42)) - cache.AddTx(createTx([]byte("alice-y"), "alice", 43)) - cache.AddTx(createTx([]byte("bob-x"), "bob", 42)) - cache.AddTx(createTx([]byte("bob-y"), "bob", 43)) - cache.Remove([]byte("alice-x")) - cache.Remove([]byte("bob-x")) - wg.Done() - }() - - go func() { - snapshot := cache.txListBySender.getSnapshotAscending() - cache.evictSendersAndTheirTxs(snapshot) - wg.Done() - }() - - go func() { - snapshot := cache.txListBySender.getSnapshotAscending() - cache.evictSendersAndTheirTxs(snapshot) - wg.Done() - }() - } - - wg.Wait() -} diff --git a/storage/txcache/feeComputationHelper.go b/storage/txcache/feeComputationHelper.go deleted file mode 100644 index 66e365dc48c..00000000000 --- a/storage/txcache/feeComputationHelper.go +++ /dev/null @@ -1,80 +0,0 @@ -package txcache - -type feeHelper interface { - gasLimitShift() uint64 - gasPriceShift() uint64 - minPricePerUnit() uint64 - normalizedMinFee() uint64 - minGasPriceFactor() uint64 - IsInterfaceNil() bool -} - -type feeComputationHelper struct { - gasShiftingFactor uint64 - priceShiftingFactor uint64 - minFeeNormalized uint64 - minPPUNormalized uint64 - minPriceFactor uint64 -} - -const priceBinaryResolution = 10 -const gasBinaryResolution = 4 - -func newFeeComputationHelper(minPrice, minGasLimit, minPriceProcessing uint64) *feeComputationHelper { - feeComputeHelper := &feeComputationHelper{} - feeComputeHelper.initializeHelperParameters(minPrice, minGasLimit, minPriceProcessing) - return feeComputeHelper -} - -func (fch *feeComputationHelper) gasLimitShift() uint64 { - return fch.gasShiftingFactor -} - -func (fch *feeComputationHelper) gasPriceShift() uint64 { - return fch.priceShiftingFactor -} - -func (fch *feeComputationHelper) normalizedMinFee() uint64 { - return fch.minFeeNormalized -} - -func (fch *feeComputationHelper) minPricePerUnit() uint64 { - return fch.minPPUNormalized -} - -func (fch *feeComputationHelper) minGasPriceFactor() uint64 { - return fch.minPriceFactor -} - -func (fch *feeComputationHelper) initializeHelperParameters(minPrice, minGasLimit, minPriceProcessing uint64) { - fch.priceShiftingFactor = computeShiftMagnitude(minPrice, priceBinaryResolution) - x := minPriceProcessing >> fch.priceShiftingFactor - for x == 0 && fch.priceShiftingFactor > 0 { - fch.priceShiftingFactor-- - x = minPriceProcessing >> fch.priceShiftingFactor - } - - fch.gasShiftingFactor = computeShiftMagnitude(minGasLimit, gasBinaryResolution) - - fch.minPPUNormalized = minPriceProcessing >> fch.priceShiftingFactor - fch.minFeeNormalized = (minGasLimit >> fch.gasLimitShift()) * (minPrice >> fch.priceShiftingFactor) - fch.minPriceFactor = minPrice / minPriceProcessing -} - -// returns the maximum shift magnitude of the number in order to maintain the given binary resolution -func computeShiftMagnitude(x uint64, resolution uint8) uint64 { - m := uint64(0) - stopCondition := uint64(1) << resolution - shiftStep := uint64(1) - - for i := x; i > stopCondition; i >>= shiftStep { - m += shiftStep - } - - return m -} - -// IsInterfaceNil returns true if the underlying object is nil -func (fch 
*feeComputationHelper) IsInterfaceNil() bool { - return fch == nil -} diff --git a/storage/txcache/feeComputationHelper_test.go b/storage/txcache/feeComputationHelper_test.go deleted file mode 100644 index 9a015ccffb7..00000000000 --- a/storage/txcache/feeComputationHelper_test.go +++ /dev/null @@ -1,73 +0,0 @@ -package txcache - -import ( - "testing" - - "github.com/stretchr/testify/require" -) - -func Test_initializeHelperParameters(t *testing.T) { - fch := &feeComputationHelper{ - gasShiftingFactor: 0, - priceShiftingFactor: 0, - minFeeNormalized: 0, - minPPUNormalized: 0, - minPriceFactor: 0, - } - - fch.initializeHelperParameters(1<<20, 1<<10, 1<<10) - require.Equal(t, uint64(10), fch.priceShiftingFactor) - require.Equal(t, uint64(6), fch.gasShiftingFactor) - require.Equal(t, uint64(1<<10), fch.minPriceFactor) - require.Equal(t, uint64((1<<4)*(1<<10)), fch.minFeeNormalized) - require.Equal(t, uint64(1), fch.minPPUNormalized) - - fch.initializeHelperParameters(1<<22, 1<<17, 1<<7) - require.Equal(t, uint64(7), fch.priceShiftingFactor) - require.Equal(t, uint64(13), fch.gasShiftingFactor) - require.Equal(t, uint64(1<<15), fch.minPriceFactor) - require.Equal(t, uint64((1<<4)*(1<<15)), fch.minFeeNormalized) - require.Equal(t, uint64(1), fch.minPPUNormalized) - - fch.initializeHelperParameters(1<<20, 1<<3, 1<<15) - require.Equal(t, uint64(10), fch.priceShiftingFactor) - require.Equal(t, uint64(0), fch.gasShiftingFactor) - require.Equal(t, uint64(1<<5), fch.minPriceFactor) - require.Equal(t, uint64((1<<3)*(1<<10)), fch.minFeeNormalized) - require.Equal(t, uint64(1<<5), fch.minPPUNormalized) -} - -func Test_newFeeComputationHelper(t *testing.T) { - fch := newFeeComputationHelper(1<<20, 1<<10, 1<<10) - require.Equal(t, uint64(10), fch.priceShiftingFactor) - require.Equal(t, uint64(6), fch.gasShiftingFactor) - require.Equal(t, uint64(1<<10), fch.minPriceFactor) - require.Equal(t, uint64((1<<4)*(1<<10)), fch.minFeeNormalized) - require.Equal(t, uint64(1), fch.minPPUNormalized) -} - -func Test_getters(t *testing.T) { - fch := newFeeComputationHelper(1<<20, 1<<10, 1<<10) - gasShift := fch.gasLimitShift() - gasPriceShift := fch.gasPriceShift() - minFeeNormalized := fch.normalizedMinFee() - minPPUNormalized := fch.minPricePerUnit() - minGasPriceFactor := fch.minGasPriceFactor() - - require.Equal(t, uint64(10), gasPriceShift) - require.Equal(t, uint64(6), gasShift) - require.Equal(t, uint64(1<<10), minGasPriceFactor) - require.Equal(t, uint64((1<<4)*(1<<10)), minFeeNormalized) - require.Equal(t, uint64(1), minPPUNormalized) -} - -func Test_computeShiftMagnitude(t *testing.T) { - shift := computeShiftMagnitude(1<<20, 10) - require.Equal(t, uint64(10), shift) - - shift = computeShiftMagnitude(1<<12, 10) - require.Equal(t, uint64(2), shift) - - shift = computeShiftMagnitude(1<<8, 10) - require.Equal(t, uint64(0), shift) -} diff --git a/storage/txcache/interface.go b/storage/txcache/interface.go deleted file mode 100644 index ea0be076b69..00000000000 --- a/storage/txcache/interface.go +++ /dev/null @@ -1,23 +0,0 @@ -package txcache - -import ( - "github.com/ElrondNetwork/elrond-go-core/data" -) - -type scoreComputer interface { - computeScore(scoreParams senderScoreParams) uint32 -} - -// TxGasHandler handles a transaction gas and gas cost -type TxGasHandler interface { - SplitTxGasInCategories(tx data.TransactionWithFeeHandler) (uint64, uint64) - GasPriceForProcessing(tx data.TransactionWithFeeHandler) uint64 - GasPriceForMove(tx data.TransactionWithFeeHandler) uint64 - MinGasPrice() uint64 - 
MinGasLimit() uint64 - MinGasPriceForProcessing() uint64 - IsInterfaceNil() bool -} - -// ForEachTransaction is an iterator callback -type ForEachTransaction func(txHash []byte, value *WrappedTransaction) diff --git a/storage/txcache/maps/bucketSortedMap.go b/storage/txcache/maps/bucketSortedMap.go deleted file mode 100644 index 90a94162278..00000000000 --- a/storage/txcache/maps/bucketSortedMap.go +++ /dev/null @@ -1,342 +0,0 @@ -package maps - -import ( - "sync" -) - -// BucketSortedMap is a concurrent map split into chunks by key hash, whose items are additionally grouped into score chunks for sorted traversal -type BucketSortedMap struct { - mutex sync.RWMutex - nChunks uint32 - nScoreChunks uint32 - maxScore uint32 - chunks []*MapChunk - scoreChunks []*MapChunk -} - -// MapChunk is a mutex-protected map of items -type MapChunk struct { - items map[string]BucketSortedMapItem - mutex sync.RWMutex -} - -// NewBucketSortedMap creates a new map. -func NewBucketSortedMap(nChunks uint32, nScoreChunks uint32) *BucketSortedMap { - if nChunks == 0 { - nChunks = 1 - } - if nScoreChunks == 0 { - nScoreChunks = 1 - } - - sortedMap := BucketSortedMap{ - nChunks: nChunks, - nScoreChunks: nScoreChunks, - maxScore: nScoreChunks - 1, - } - - sortedMap.initializeChunks() - - return &sortedMap -} - -func (sortedMap *BucketSortedMap) initializeChunks() { - // Assignment is not an atomic operation, so we have to wrap this in a critical section - sortedMap.mutex.Lock() - defer sortedMap.mutex.Unlock() - - sortedMap.chunks = make([]*MapChunk, sortedMap.nChunks) - sortedMap.scoreChunks = make([]*MapChunk, sortedMap.nScoreChunks) - - for i := uint32(0); i < sortedMap.nChunks; i++ { - sortedMap.chunks[i] = &MapChunk{ - items: make(map[string]BucketSortedMapItem), - } - } - - for i := uint32(0); i < sortedMap.nScoreChunks; i++ { - sortedMap.scoreChunks[i] = &MapChunk{ - items: make(map[string]BucketSortedMapItem), - } - } -} - -// Set puts the item in the map -// This doesn't add the item to the score chunks (not necessary) -func (sortedMap *BucketSortedMap) Set(item BucketSortedMapItem) { - chunk := sortedMap.getChunk(item.GetKey()) - chunk.setItem(item) -} - -// NotifyScoreChange moves or adds the item to the corresponding score chunk -func (sortedMap *BucketSortedMap) NotifyScoreChange(item BucketSortedMapItem, newScore uint32) { - if newScore > sortedMap.maxScore { - newScore = sortedMap.maxScore - } - - newScoreChunk := sortedMap.getScoreChunks()[newScore] - if newScoreChunk != item.GetScoreChunk() { - removeFromScoreChunk(item) - newScoreChunk.setItem(item) - item.SetScoreChunk(newScoreChunk) - } -} - -func removeFromScoreChunk(item BucketSortedMapItem) { - currentScoreChunk := item.GetScoreChunk() - if currentScoreChunk != nil { - currentScoreChunk.removeItem(item) - } -} - -// Get retrieves an element from the map under the given key. -func (sortedMap *BucketSortedMap) Get(key string) (BucketSortedMapItem, bool) { - chunk := sortedMap.getChunk(key) - chunk.mutex.RLock() - val, ok := chunk.items[key] - chunk.mutex.RUnlock() - return val, ok -} - -// Has looks up an item under the specified key -func (sortedMap *BucketSortedMap) Has(key string) bool { - chunk := sortedMap.getChunk(key) - chunk.mutex.RLock() - _, ok := chunk.items[key] - chunk.mutex.RUnlock() - return ok -} - -// Remove removes an element from the map -func (sortedMap *BucketSortedMap) Remove(key string) (interface{}, bool) { - chunk := sortedMap.getChunk(key) - item := chunk.removeItemByKey(key) - if item != nil { - removeFromScoreChunk(item) - } - - return item, item != nil -} - -// getChunk returns the chunk holding the given key.
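// [Editor's note] A hedged illustration of the hash-then-modulo routing implemented by
// the getChunk function right below; variable names are hypothetical:
//
//	index := fnv32Hash("alice") % sortedMap.nChunks // the same key always maps to the same bucket
//	chunk := sortedMap.chunks[index]                // locking is then per chunk, not map-wide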
-func (sortedMap *BucketSortedMap) getChunk(key string) *MapChunk { - sortedMap.mutex.RLock() - defer sortedMap.mutex.RUnlock() - return sortedMap.chunks[fnv32Hash(key)%sortedMap.nChunks] -} - -// fnv32Hash implements https://en.wikipedia.org/wiki/Fowler–Noll–Vo_hash_function for 32 bits -func fnv32Hash(key string) uint32 { - hash := uint32(2166136261) - const prime32 = uint32(16777619) - for i := 0; i < len(key); i++ { - hash *= prime32 - hash ^= uint32(key[i]) - } - return hash -} - -// Clear clears the map -func (sortedMap *BucketSortedMap) Clear() { - // There is no need to explicitly remove each item for each chunk - // The garbage collector will remove the data from memory - sortedMap.initializeChunks() -} - -// Count returns the number of elements within the map -func (sortedMap *BucketSortedMap) Count() uint32 { - count := uint32(0) - for _, chunk := range sortedMap.getChunks() { - count += chunk.countItems() - } - return count -} - -// CountSorted returns the number of sorted elements within the map -func (sortedMap *BucketSortedMap) CountSorted() uint32 { - count := uint32(0) - for _, chunk := range sortedMap.getScoreChunks() { - count += chunk.countItems() - } - return count -} - -// ChunksCounts returns the number of elements by chunk -func (sortedMap *BucketSortedMap) ChunksCounts() []uint32 { - counts := make([]uint32, sortedMap.nChunks) - for i, chunk := range sortedMap.getChunks() { - counts[i] = chunk.countItems() - } - return counts -} - -// ScoreChunksCounts returns the number of elements by chunk -func (sortedMap *BucketSortedMap) ScoreChunksCounts() []uint32 { - counts := make([]uint32, sortedMap.nScoreChunks) - for i, chunk := range sortedMap.getScoreChunks() { - counts[i] = chunk.countItems() - } - return counts -} - -// SortedMapIterCb is an iterator callback -type SortedMapIterCb func(key string, value BucketSortedMapItem) - -// GetSnapshotAscending gets a snapshot of the items -func (sortedMap *BucketSortedMap) GetSnapshotAscending() []BucketSortedMapItem { - return sortedMap.getSortedSnapshot(sortedMap.fillSnapshotAscending) -} - -// GetSnapshotDescending gets a snapshot of the items -func (sortedMap *BucketSortedMap) GetSnapshotDescending() []BucketSortedMapItem { - return sortedMap.getSortedSnapshot(sortedMap.fillSnapshotDescending) -} - -// This applies a read lock on all chunks, so that they aren't mutated during snapshot -func (sortedMap *BucketSortedMap) getSortedSnapshot(fillSnapshot func(scoreChunks []*MapChunk, snapshot []BucketSortedMapItem)) []BucketSortedMapItem { - counter := uint32(0) - scoreChunks := sortedMap.getScoreChunks() - - for _, chunk := range scoreChunks { - chunk.mutex.RLock() - counter += uint32(len(chunk.items)) - } - - snapshot := make([]BucketSortedMapItem, counter) - fillSnapshot(scoreChunks, snapshot) - - for _, chunk := range scoreChunks { - chunk.mutex.RUnlock() - } - - return snapshot -} - -// This function should only be called under already read-locked score chunks -func (sortedMap *BucketSortedMap) fillSnapshotAscending(scoreChunks []*MapChunk, snapshot []BucketSortedMapItem) { - i := 0 - for _, chunk := range scoreChunks { - for _, item := range chunk.items { - snapshot[i] = item - i++ - } - } -} - -// This function should only be called under already read-locked score chunks -func (sortedMap *BucketSortedMap) fillSnapshotDescending(scoreChunks []*MapChunk, snapshot []BucketSortedMapItem) { - i := 0 - for chunkIndex := len(scoreChunks) - 1; chunkIndex >= 0; chunkIndex-- { - chunk := scoreChunks[chunkIndex] - for _, item := 
range chunk.items { - snapshot[i] = item - i++ - } - } -} - -// IterCbSortedAscending iterates over the sorted elements in the map -func (sortedMap *BucketSortedMap) IterCbSortedAscending(callback SortedMapIterCb) { - for _, chunk := range sortedMap.getScoreChunks() { - chunk.forEachItem(callback) - } -} - -// IterCbSortedDescending iterates over the sorted elements in the map -func (sortedMap *BucketSortedMap) IterCbSortedDescending(callback SortedMapIterCb) { - chunks := sortedMap.getScoreChunks() - for i := len(chunks) - 1; i >= 0; i-- { - chunk := chunks[i] - chunk.forEachItem(callback) - } -} - -// Keys returns all keys as []string -func (sortedMap *BucketSortedMap) Keys() []string { - count := sortedMap.Count() - // count is not exact anymore, since we are in a different lock than the one acquired by Count() (but is a good approximation) - keys := make([]string, 0, count) - - for _, chunk := range sortedMap.getChunks() { - keys = chunk.appendKeys(keys) - } - - return keys -} - -// KeysSorted returns all keys of the sorted items as []string -func (sortedMap *BucketSortedMap) KeysSorted() []string { - count := sortedMap.CountSorted() - // count is not exact anymore, since we are in a different lock than the one acquired by CountSorted() (but is a good approximation) - keys := make([]string, 0, count) - - for _, chunk := range sortedMap.getScoreChunks() { - keys = chunk.appendKeys(keys) - } - - return keys -} - -func (sortedMap *BucketSortedMap) getChunks() []*MapChunk { - sortedMap.mutex.RLock() - defer sortedMap.mutex.RUnlock() - return sortedMap.chunks -} - -func (sortedMap *BucketSortedMap) getScoreChunks() []*MapChunk { - sortedMap.mutex.RLock() - defer sortedMap.mutex.RUnlock() - return sortedMap.scoreChunks -} - -func (chunk *MapChunk) removeItem(item BucketSortedMapItem) { - chunk.mutex.Lock() - defer chunk.mutex.Unlock() - - key := item.GetKey() - delete(chunk.items, key) -} - -func (chunk *MapChunk) removeItemByKey(key string) BucketSortedMapItem { - chunk.mutex.Lock() - defer chunk.mutex.Unlock() - - item := chunk.items[key] - delete(chunk.items, key) - return item -} - -func (chunk *MapChunk) setItem(item BucketSortedMapItem) { - chunk.mutex.Lock() - defer chunk.mutex.Unlock() - - key := item.GetKey() - chunk.items[key] = item -} - -func (chunk *MapChunk) countItems() uint32 { - chunk.mutex.RLock() - defer chunk.mutex.RUnlock() - - return uint32(len(chunk.items)) -} - -func (chunk *MapChunk) forEachItem(callback SortedMapIterCb) { - chunk.mutex.RLock() - defer chunk.mutex.RUnlock() - - for key, value := range chunk.items { - callback(key, value) - } -} - -func (chunk *MapChunk) appendKeys(keysAccumulator []string) []string { - chunk.mutex.RLock() - defer chunk.mutex.RUnlock() - - for key := range chunk.items { - keysAccumulator = append(keysAccumulator, key) - } - - return keysAccumulator -} diff --git a/storage/txcache/maps/bucketSortedMapItem.go b/storage/txcache/maps/bucketSortedMapItem.go deleted file mode 100644 index 4ba551811ef..00000000000 --- a/storage/txcache/maps/bucketSortedMapItem.go +++ /dev/null @@ -1,8 +0,0 @@ -package maps - -// BucketSortedMapItem defines an item of the bucket sorted map -type BucketSortedMapItem interface { - GetKey() string - GetScoreChunk() *MapChunk - SetScoreChunk(*MapChunk) -} diff --git a/storage/txcache/maps/bucketSortedMap_test.go b/storage/txcache/maps/bucketSortedMap_test.go deleted file mode 100644 index 66d37f63737..00000000000 --- a/storage/txcache/maps/bucketSortedMap_test.go +++ /dev/null @@ -1,421 +0,0 @@ -package maps - 
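// [Editor's note] Hedged usage sketch of BucketSortedMap, relying on the dummyItem test
// helpers defined just below: an item becomes visible to the sorted views only after
// NotifyScoreChange files it into a score chunk.
//
//	myMap := NewBucketSortedMap(4, 100)
//	item := newScoredDummyItem("alice", 7)          // helper defined below
//	myMap.Set(item)                                 // stored in a key chunk; not yet sorted
//	myMap.NotifyScoreChange(item, 7)                // now counted by CountSorted() as well
//	fmt.Println(myMap.Count(), myMap.CountSorted()) // 1 1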
-import ( - "fmt" - "sync" - "testing" - - "github.com/ElrondNetwork/elrond-go-core/core/atomic" - "github.com/stretchr/testify/require" -) - -type dummyItem struct { - score atomic.Uint32 - key string - chunk *MapChunk - chunkMutex sync.RWMutex - mutex sync.Mutex -} - -func newDummyItem(key string) *dummyItem { - return &dummyItem{ - key: key, - } -} - -func newScoredDummyItem(key string, score uint32) *dummyItem { - item := &dummyItem{ - key: key, - } - item.score.Set(score) - return item -} - -func (item *dummyItem) GetKey() string { - return item.key -} - -func (item *dummyItem) GetScoreChunk() *MapChunk { - item.chunkMutex.RLock() - defer item.chunkMutex.RUnlock() - - return item.chunk -} - -func (item *dummyItem) SetScoreChunk(chunk *MapChunk) { - item.chunkMutex.Lock() - defer item.chunkMutex.Unlock() - - item.chunk = chunk -} - -func (item *dummyItem) simulateMutationThatChangesScore(myMap *BucketSortedMap) { - item.mutex.Lock() - myMap.NotifyScoreChange(item, item.score.Get()) - item.mutex.Unlock() -} - -func simulateMutationThatChangesScore(myMap *BucketSortedMap, key string) { - item, ok := myMap.Get(key) - if !ok { - return - } - - itemAsDummy := item.(*dummyItem) - itemAsDummy.simulateMutationThatChangesScore(myMap) -} - -func TestNewBucketSortedMap(t *testing.T) { - myMap := NewBucketSortedMap(4, 100) - require.Equal(t, uint32(4), myMap.nChunks) - require.Equal(t, 4, len(myMap.chunks)) - require.Equal(t, uint32(100), myMap.nScoreChunks) - require.Equal(t, 100, len(myMap.scoreChunks)) - - // 1 is minimum number of chunks - myMap = NewBucketSortedMap(0, 0) - require.Equal(t, uint32(1), myMap.nChunks) - require.Equal(t, uint32(1), myMap.nScoreChunks) -} - -func TestBucketSortedMap_Count(t *testing.T) { - myMap := NewBucketSortedMap(4, 100) - myMap.Set(newScoredDummyItem("a", 0)) - myMap.Set(newScoredDummyItem("b", 1)) - myMap.Set(newScoredDummyItem("c", 2)) - myMap.Set(newScoredDummyItem("d", 3)) - - simulateMutationThatChangesScore(myMap, "a") - simulateMutationThatChangesScore(myMap, "b") - simulateMutationThatChangesScore(myMap, "c") - simulateMutationThatChangesScore(myMap, "d") - - require.Equal(t, uint32(4), myMap.Count()) - require.Equal(t, uint32(4), myMap.CountSorted()) - - counts := myMap.ChunksCounts() - require.Equal(t, uint32(1), counts[0]) - require.Equal(t, uint32(1), counts[1]) - require.Equal(t, uint32(1), counts[2]) - require.Equal(t, uint32(1), counts[3]) - - counts = myMap.ScoreChunksCounts() - require.Equal(t, uint32(1), counts[0]) - require.Equal(t, uint32(1), counts[1]) - require.Equal(t, uint32(1), counts[2]) - require.Equal(t, uint32(1), counts[3]) -} - -func TestBucketSortedMap_Keys(t *testing.T) { - myMap := NewBucketSortedMap(4, 100) - myMap.Set(newDummyItem("a")) - myMap.Set(newDummyItem("b")) - myMap.Set(newDummyItem("c")) - - simulateMutationThatChangesScore(myMap, "a") - simulateMutationThatChangesScore(myMap, "b") - simulateMutationThatChangesScore(myMap, "c") - - require.Equal(t, 3, len(myMap.Keys())) - require.Equal(t, 3, len(myMap.KeysSorted())) -} - -func TestBucketSortedMap_KeysSorted(t *testing.T) { - myMap := NewBucketSortedMap(1, 4) - - myMap.Set(newScoredDummyItem("d", 3)) - myMap.Set(newScoredDummyItem("a", 0)) - myMap.Set(newScoredDummyItem("c", 2)) - myMap.Set(newScoredDummyItem("b", 1)) - myMap.Set(newScoredDummyItem("f", 5)) - myMap.Set(newScoredDummyItem("e", 4)) - - simulateMutationThatChangesScore(myMap, "d") - simulateMutationThatChangesScore(myMap, "e") - simulateMutationThatChangesScore(myMap, "f") - 
simulateMutationThatChangesScore(myMap, "a") - simulateMutationThatChangesScore(myMap, "b") - simulateMutationThatChangesScore(myMap, "c") - - keys := myMap.KeysSorted() - require.Equal(t, "a", keys[0]) - require.Equal(t, "b", keys[1]) - require.Equal(t, "c", keys[2]) - - counts := myMap.ScoreChunksCounts() - require.Equal(t, uint32(1), counts[0]) - require.Equal(t, uint32(1), counts[1]) - require.Equal(t, uint32(1), counts[2]) - require.Equal(t, uint32(3), counts[3]) -} - -func TestBucketSortedMap_ItemMovesOnNotifyScoreChange(t *testing.T) { - myMap := NewBucketSortedMap(4, 100) - - a := newScoredDummyItem("a", 1) - b := newScoredDummyItem("b", 42) - myMap.Set(a) - myMap.Set(b) - - simulateMutationThatChangesScore(myMap, "a") - simulateMutationThatChangesScore(myMap, "b") - - require.Equal(t, myMap.scoreChunks[1], a.GetScoreChunk()) - require.Equal(t, myMap.scoreChunks[42], b.GetScoreChunk()) - - a.score.Set(2) - b.score.Set(43) - simulateMutationThatChangesScore(myMap, "a") - simulateMutationThatChangesScore(myMap, "b") - - require.Equal(t, myMap.scoreChunks[2], a.GetScoreChunk()) - require.Equal(t, myMap.scoreChunks[43], b.GetScoreChunk()) -} - -func TestBucketSortedMap_Has(t *testing.T) { - myMap := NewBucketSortedMap(4, 100) - myMap.Set(newDummyItem("a")) - myMap.Set(newDummyItem("b")) - - require.True(t, myMap.Has("a")) - require.True(t, myMap.Has("b")) - require.False(t, myMap.Has("c")) -} - -func TestBucketSortedMap_Remove(t *testing.T) { - myMap := NewBucketSortedMap(4, 100) - myMap.Set(newDummyItem("a")) - myMap.Set(newDummyItem("b")) - - _, ok := myMap.Remove("b") - require.True(t, ok) - _, ok = myMap.Remove("x") - require.False(t, ok) - - require.True(t, myMap.Has("a")) - require.False(t, myMap.Has("b")) -} - -func TestBucketSortedMap_Clear(t *testing.T) { - myMap := NewBucketSortedMap(4, 100) - myMap.Set(newDummyItem("a")) - myMap.Set(newDummyItem("b")) - - myMap.Clear() - - require.Equal(t, uint32(0), myMap.Count()) - require.Equal(t, uint32(0), myMap.CountSorted()) -} - -func TestBucketSortedMap_IterCb(t *testing.T) { - myMap := NewBucketSortedMap(4, 100) - - myMap.Set(newScoredDummyItem("a", 15)) - myMap.Set(newScoredDummyItem("b", 101)) - myMap.Set(newScoredDummyItem("c", 3)) - simulateMutationThatChangesScore(myMap, "a") - simulateMutationThatChangesScore(myMap, "b") - simulateMutationThatChangesScore(myMap, "c") - - sorted := []string{"c", "a", "b"} - - i := 0 - myMap.IterCbSortedAscending(func(key string, value BucketSortedMapItem) { - require.Equal(t, sorted[i], key) - i++ - }) - - require.Equal(t, 3, i) - - i = len(sorted) - 1 - myMap.IterCbSortedDescending(func(key string, value BucketSortedMapItem) { - require.Equal(t, sorted[i], key) - i-- - }) - - require.Equal(t, 0, i+1) -} - -func TestBucketSortedMap_GetSnapshotAscending(t *testing.T) { - myMap := NewBucketSortedMap(4, 100) - - snapshot := myMap.GetSnapshotAscending() - require.Equal(t, []BucketSortedMapItem{}, snapshot) - - a := newScoredDummyItem("a", 15) - b := newScoredDummyItem("b", 101) - c := newScoredDummyItem("c", 3) - - myMap.Set(a) - myMap.Set(b) - myMap.Set(c) - - simulateMutationThatChangesScore(myMap, "a") - simulateMutationThatChangesScore(myMap, "b") - simulateMutationThatChangesScore(myMap, "c") - - snapshot = myMap.GetSnapshotAscending() - require.Equal(t, []BucketSortedMapItem{c, a, b}, snapshot) -} - -func TestBucketSortedMap_GetSnapshotDescending(t *testing.T) { - myMap := NewBucketSortedMap(4, 100) - - snapshot := myMap.GetSnapshotDescending() - require.Equal(t, []BucketSortedMapItem{}, 
snapshot) - - a := newScoredDummyItem("a", 15) - b := newScoredDummyItem("b", 101) - c := newScoredDummyItem("c", 3) - - myMap.Set(a) - myMap.Set(b) - myMap.Set(c) - - simulateMutationThatChangesScore(myMap, "a") - simulateMutationThatChangesScore(myMap, "b") - simulateMutationThatChangesScore(myMap, "c") - - snapshot = myMap.GetSnapshotDescending() - require.Equal(t, []BucketSortedMapItem{b, a, c}, snapshot) -} - -func TestBucketSortedMap_AddManyItems(t *testing.T) { - numGoroutines := 42 - numItemsPerGoroutine := 1000 - numScoreChunks := 100 - numItemsInScoreChunkPerGoroutine := numItemsPerGoroutine / numScoreChunks - numItemsInScoreChunk := numItemsInScoreChunkPerGoroutine * numGoroutines - - myMap := NewBucketSortedMap(16, uint32(numScoreChunks)) - - var waitGroup sync.WaitGroup - waitGroup.Add(numGoroutines) - - for i := 0; i < numGoroutines; i++ { - go func(i int) { - for j := 0; j < numItemsPerGoroutine; j++ { - key := fmt.Sprintf("%d_%d", i, j) - item := newScoredDummyItem(key, uint32(j%numScoreChunks)) - myMap.Set(item) - simulateMutationThatChangesScore(myMap, key) - } - - waitGroup.Done() - }(i) - } - - waitGroup.Wait() - - require.Equal(t, uint32(numGoroutines*numItemsPerGoroutine), myMap.CountSorted()) - - counts := myMap.ScoreChunksCounts() - for i := 0; i < numScoreChunks; i++ { - require.Equal(t, uint32(numItemsInScoreChunk), counts[i]) - } -} - -func TestBucketSortedMap_ClearConcurrentWithRead(t *testing.T) { - numChunks := uint32(4) - numScoreChunks := uint32(4) - myMap := NewBucketSortedMap(numChunks, numScoreChunks) - - var wg sync.WaitGroup - wg.Add(2) - - go func() { - defer wg.Done() - - for j := 0; j < 1000; j++ { - myMap.Clear() - } - }() - - go func() { - defer wg.Done() - - for j := 0; j < 1000; j++ { - require.Equal(t, uint32(0), myMap.Count()) - require.Equal(t, uint32(0), myMap.CountSorted()) - require.Len(t, myMap.ChunksCounts(), int(numChunks)) - require.Len(t, myMap.ScoreChunksCounts(), int(numScoreChunks)) - require.Len(t, myMap.Keys(), 0) - require.Len(t, myMap.KeysSorted(), 0) - require.Equal(t, false, myMap.Has("foobar")) - item, ok := myMap.Get("foobar") - require.Nil(t, item) - require.False(t, ok) - require.Len(t, myMap.GetSnapshotAscending(), 0) - myMap.IterCbSortedAscending(func(key string, item BucketSortedMapItem) { - }) - myMap.IterCbSortedDescending(func(key string, item BucketSortedMapItem) { - }) - } - }() - - wg.Wait() -} - -func TestBucketSortedMap_ClearConcurrentWithWrite(t *testing.T) { - myMap := NewBucketSortedMap(4, 4) - - var wg sync.WaitGroup - wg.Add(2) - - go func() { - for j := 0; j < 10000; j++ { - myMap.Clear() - } - - wg.Done() - }() - - go func() { - for j := 0; j < 10000; j++ { - myMap.Set(newDummyItem("foobar")) - _, _ = myMap.Remove("foobar") - myMap.NotifyScoreChange(newDummyItem("foobar"), 42) - simulateMutationThatChangesScore(myMap, "foobar") - } - - wg.Done() - }() - - wg.Wait() -} - -func TestBucketSortedMap_NoForgottenItemsOnConcurrentScoreChanges(t *testing.T) { - // This test helped us to find a memory leak occurring on concurrent score changes (concurrent movements across buckets) - - for i := 0; i < 1000; i++ { - myMap := NewBucketSortedMap(16, 16) - a := newScoredDummyItem("a", 0) - myMap.Set(a) - simulateMutationThatChangesScore(myMap, "a") - - var wg sync.WaitGroup - wg.Add(2) - - go func() { - a.score.Set(1) - simulateMutationThatChangesScore(myMap, "a") - wg.Done() - }() - - go func() { - a.score.Set(2) - simulateMutationThatChangesScore(myMap, "a") - wg.Done() - }() - - wg.Wait() - - require.Equal(t, 
uint32(1), myMap.CountSorted()) - require.Equal(t, uint32(1), myMap.Count()) - - _, _ = myMap.Remove("a") - - require.Equal(t, uint32(0), myMap.CountSorted()) - require.Equal(t, uint32(0), myMap.Count()) - } -} diff --git a/storage/txcache/maps/concurrentMap.go b/storage/txcache/maps/concurrentMap.go deleted file mode 100644 index 8ee101696ce..00000000000 --- a/storage/txcache/maps/concurrentMap.go +++ /dev/null @@ -1,180 +0,0 @@ -package maps - -import ( - "sync" -) - -// This implementation is a simplified version of: -// https://github.com/ElrondNetwork/concurrent-map, which is based on: -// https://github.com/orcaman/concurrent-map - -// ConcurrentMap is a thread-safe map of type string:Anything. -// To avoid lock bottlenecks this map is divided into several map chunks. -type ConcurrentMap struct { - mutex sync.RWMutex - nChunks uint32 - chunks []*concurrentMapChunk -} - -// concurrentMapChunk is a thread-safe string to anything map. -type concurrentMapChunk struct { - items map[string]interface{} - mutex sync.RWMutex -} - -// NewConcurrentMap creates a new concurrent map. -func NewConcurrentMap(nChunks uint32) *ConcurrentMap { - // We cannot have a map with no chunks - if nChunks == 0 { - nChunks = 1 - } - - m := ConcurrentMap{ - nChunks: nChunks, - } - - m.initializeChunks() - - return &m -} - -func (m *ConcurrentMap) initializeChunks() { - // Assignment is not an atomic operation, so we have to wrap this in a critical section - m.mutex.Lock() - defer m.mutex.Unlock() - - m.chunks = make([]*concurrentMapChunk, m.nChunks) - - for i := uint32(0); i < m.nChunks; i++ { - m.chunks[i] = &concurrentMapChunk{ - items: make(map[string]interface{}), - } - } -} - -// Set sets the given value under the specified key. -func (m *ConcurrentMap) Set(key string, value interface{}) { - chunk := m.getChunk(key) - chunk.mutex.Lock() - chunk.items[key] = value - chunk.mutex.Unlock() -} - -// SetIfAbsent sets the given value under the specified key if no value was associated with it. -func (m *ConcurrentMap) SetIfAbsent(key string, value interface{}) bool { - chunk := m.getChunk(key) - chunk.mutex.Lock() - _, ok := chunk.items[key] - if !ok { - chunk.items[key] = value - } - chunk.mutex.Unlock() - return !ok -} - -// Get retrieves an element from map under given key. -func (m *ConcurrentMap) Get(key string) (interface{}, bool) { - chunk := m.getChunk(key) - chunk.mutex.RLock() - val, ok := chunk.items[key] - chunk.mutex.RUnlock() - return val, ok -} - -// Has looks up an item under specified key. -func (m *ConcurrentMap) Has(key string) bool { - chunk := m.getChunk(key) - chunk.mutex.RLock() - _, ok := chunk.items[key] - chunk.mutex.RUnlock() - return ok -} - -// Remove removes an element from the map. 
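// A subtlety worth noting about Remove below (reviewer observation, not part of the
// original source): the boolean result is derived from "item != nil" rather than from
// the two-value map lookup, so a key stored with an explicit nil value is deleted yet
// reported as absent. This appears harmless for the tx cache, which stores only
// non-nil *WrappedTransaction values.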
-func (m *ConcurrentMap) Remove(key string) (interface{}, bool) { - chunk := m.getChunk(key) - chunk.mutex.Lock() - defer chunk.mutex.Unlock() - - item := chunk.items[key] - delete(chunk.items, key) - return item, item != nil -} - -func (m *ConcurrentMap) getChunk(key string) *concurrentMapChunk { - m.mutex.RLock() - defer m.mutex.RUnlock() - return m.chunks[fnv32(key)%m.nChunks] -} - -// fnv32 implements https://en.wikipedia.org/wiki/Fowler–Noll–Vo_hash_function for 32 bits -func fnv32(key string) uint32 { - hash := uint32(2166136261) - const prime32 = uint32(16777619) - for i := 0; i < len(key); i++ { - hash *= prime32 - hash ^= uint32(key[i]) - } - return hash -} - -// Clear clears the map -func (m *ConcurrentMap) Clear() { - // There is no need to explicitly remove each item for each chunk - // The garbage collector will remove the data from memory - m.initializeChunks() -} - -// Count returns the number of elements within the map -func (m *ConcurrentMap) Count() int { - count := 0 - chunks := m.getChunks() - - for _, chunk := range chunks { - chunk.mutex.RLock() - count += len(chunk.items) - chunk.mutex.RUnlock() - } - return count -} - -// Keys returns all keys as []string -func (m *ConcurrentMap) Keys() []string { - count := m.Count() - chunks := m.getChunks() - - // count is not exact anymore, since we are in a different lock than the one acquired by Count() (but is a good approximation) - keys := make([]string, 0, count) - - for _, chunk := range chunks { - chunk.mutex.RLock() - for key := range chunk.items { - keys = append(keys, key) - } - chunk.mutex.RUnlock() - } - - return keys -} - -// IterCb is an iterator callback -type IterCb func(key string, v interface{}) - -// IterCb iterates over the map (cheapest way to read all elements in a map) -func (m *ConcurrentMap) IterCb(fn IterCb) { - chunks := m.getChunks() - - for _, chunk := range chunks { - chunk.mutex.RLock() - for key, value := range chunk.items { - fn(key, value) - } - chunk.mutex.RUnlock() - } -} - -func (m *ConcurrentMap) getChunks() []*concurrentMapChunk { - m.mutex.RLock() - defer m.mutex.RUnlock() - return m.chunks -} diff --git a/storage/txcache/maps/concurrentMap_test.go b/storage/txcache/maps/concurrentMap_test.go deleted file mode 100644 index 705b87791ea..00000000000 --- a/storage/txcache/maps/concurrentMap_test.go +++ /dev/null @@ -1,160 +0,0 @@ -package maps - -import ( - "sync" - "testing" - - "github.com/stretchr/testify/require" -) - -func TestNewConcurrentMap(t *testing.T) { - myMap := NewConcurrentMap(4) - require.Equal(t, uint32(4), myMap.nChunks) - require.Equal(t, 4, len(myMap.chunks)) - - // 1 is minimum number of chunks - myMap = NewConcurrentMap(0) - require.Equal(t, uint32(1), myMap.nChunks) - require.Equal(t, 1, len(myMap.chunks)) -} - -func TestConcurrentMap_Get(t *testing.T) { - myMap := NewConcurrentMap(4) - myMap.Set("a", "foo") - myMap.Set("b", 42) - - a, ok := myMap.Get("a") - require.True(t, ok) - require.Equal(t, "foo", a) - - b, ok := myMap.Get("b") - require.True(t, ok) - require.Equal(t, 42, b) -} - -func TestConcurrentMap_Count(t *testing.T) { - myMap := NewConcurrentMap(4) - myMap.Set("a", "a") - myMap.Set("b", "b") - myMap.Set("c", "c") - - require.Equal(t, 3, myMap.Count()) -} - -func TestConcurrentMap_Keys(t *testing.T) { - myMap := NewConcurrentMap(4) - myMap.Set("1", 0) - myMap.Set("2", 0) - myMap.Set("3", 0) - myMap.Set("4", 0) - - require.Equal(t, 4, len(myMap.Keys())) -} - -func TestConcurrentMap_Has(t *testing.T) { - myMap := NewConcurrentMap(4) - myMap.SetIfAbsent("a", 
"a") - myMap.SetIfAbsent("b", "b") - - require.True(t, myMap.Has("a")) - require.True(t, myMap.Has("b")) - require.False(t, myMap.Has("c")) -} - -func TestConcurrentMap_Remove(t *testing.T) { - myMap := NewConcurrentMap(4) - myMap.SetIfAbsent("a", "a") - myMap.SetIfAbsent("b", "b") - - _, ok := myMap.Remove("b") - require.True(t, ok) - _, ok = myMap.Remove("x") - require.False(t, ok) - - require.True(t, myMap.Has("a")) - require.False(t, myMap.Has("b")) -} - -func TestConcurrentMap_Clear(t *testing.T) { - myMap := NewConcurrentMap(4) - myMap.SetIfAbsent("a", "a") - myMap.SetIfAbsent("b", "b") - - myMap.Clear() - - require.Equal(t, 0, myMap.Count()) -} - -func TestConcurrentMap_ClearConcurrentWithRead(t *testing.T) { - myMap := NewConcurrentMap(4) - - var wg sync.WaitGroup - wg.Add(2) - - go func() { - for j := 0; j < 1000; j++ { - myMap.Clear() - } - - wg.Done() - }() - - go func() { - for j := 0; j < 1000; j++ { - require.Equal(t, 0, myMap.Count()) - require.Len(t, myMap.Keys(), 0) - require.Equal(t, false, myMap.Has("foobar")) - item, ok := myMap.Get("foobar") - require.Nil(t, item) - require.False(t, ok) - myMap.IterCb(func(key string, item interface{}) { - }) - } - - wg.Done() - }() - - wg.Wait() -} - -func TestConcurrentMap_ClearConcurrentWithWrite(t *testing.T) { - myMap := NewConcurrentMap(4) - - var wg sync.WaitGroup - wg.Add(2) - - go func() { - for j := 0; j < 10000; j++ { - myMap.Clear() - } - - wg.Done() - }() - - go func() { - for j := 0; j < 10000; j++ { - myMap.Set("foobar", "foobar") - myMap.SetIfAbsent("foobar", "foobar") - _, _ = myMap.Remove("foobar") - } - - wg.Done() - }() - - wg.Wait() -} - -func TestConcurrentMap_IterCb(t *testing.T) { - myMap := NewConcurrentMap(4) - - myMap.Set("a", "a") - myMap.Set("b", "b") - myMap.Set("c", "c") - - i := 0 - myMap.IterCb(func(key string, value interface{}) { - i++ - }) - - require.Equal(t, 3, i) -} diff --git a/storage/txcache/monitoring.go b/storage/txcache/monitoring.go deleted file mode 100644 index 5fa556ab114..00000000000 --- a/storage/txcache/monitoring.go +++ /dev/null @@ -1,247 +0,0 @@ -package txcache - -import ( - "encoding/hex" - "fmt" - "strings" - - "github.com/ElrondNetwork/elrond-go-core/core" - logger "github.com/ElrondNetwork/elrond-go-logger" -) - -var log = logger.GetOrCreate("txcache") - -func (cache *TxCache) monitorEvictionWrtSenderLimit(sender []byte, evicted [][]byte) { - log.Trace("TxCache.AddTx() evict transactions wrt. limit by sender", "name", cache.name, "sender", sender, "num", len(evicted)) - - for i := 0; i < core.MinInt(len(evicted), numEvictedTxsToDisplay); i++ { - log.Trace("TxCache.AddTx() evict transactions wrt. 
limit by sender", "name", cache.name, "sender", sender, "tx", evicted[i]) - } -} - -func (cache *TxCache) monitorEvictionStart() *core.StopWatch { - log.Debug("TxCache: eviction started", "name", cache.name, "numBytes", cache.NumBytes(), "txs", cache.CountTx(), "senders", cache.CountSenders()) - cache.displaySendersHistogram() - sw := core.NewStopWatch() - sw.Start("eviction") - return sw -} - -func (cache *TxCache) monitorEvictionEnd(stopWatch *core.StopWatch) { - stopWatch.Stop("eviction") - duration := stopWatch.GetMeasurement("eviction") - log.Debug("TxCache: eviction ended", "name", cache.name, "duration", duration, "numBytes", cache.NumBytes(), "txs", cache.CountTx(), "senders", cache.CountSenders()) - cache.evictionJournal.display() - cache.displaySendersHistogram() -} - -func (cache *TxCache) monitorSelectionStart() *core.StopWatch { - log.Debug("TxCache: selection started", "name", cache.name, "numBytes", cache.NumBytes(), "txs", cache.CountTx(), "senders", cache.CountSenders()) - cache.displaySendersHistogram() - sw := core.NewStopWatch() - sw.Start("selection") - return sw -} - -func (cache *TxCache) monitorSelectionEnd(selection []*WrappedTransaction, stopWatch *core.StopWatch) { - stopWatch.Stop("selection") - duration := stopWatch.GetMeasurement("selection") - numSendersSelected := cache.numSendersSelected.Reset() - numSendersWithInitialGap := cache.numSendersWithInitialGap.Reset() - numSendersWithMiddleGap := cache.numSendersWithMiddleGap.Reset() - numSendersInGracePeriod := cache.numSendersInGracePeriod.Reset() - - log.Debug("TxCache: selection ended", "name", cache.name, "duration", duration, - "numTxSelected", len(selection), - "numSendersSelected", numSendersSelected, - "numSendersWithInitialGap", numSendersWithInitialGap, - "numSendersWithMiddleGap", numSendersWithMiddleGap, - "numSendersInGracePeriod", numSendersInGracePeriod, - ) -} - -type batchSelectionJournal struct { - copied int - isFirstBatch bool - hasInitialGap bool - hasMiddleGap bool - isGracePeriod bool -} - -func (cache *TxCache) monitorBatchSelectionEnd(journal batchSelectionJournal) { - if !journal.isFirstBatch { - return - } - - if journal.hasInitialGap { - cache.numSendersWithInitialGap.Increment() - } else if journal.hasMiddleGap { - // Currently, we only count middle gaps on first batch (for simplicity) - cache.numSendersWithMiddleGap.Increment() - } - - if journal.isGracePeriod { - cache.numSendersInGracePeriod.Increment() - } else if journal.copied > 0 { - cache.numSendersSelected.Increment() - } -} - -func (cache *TxCache) monitorSweepingStart() *core.StopWatch { - sw := core.NewStopWatch() - sw.Start("sweeping") - return sw -} - -func (cache *TxCache) monitorSweepingEnd(numTxs uint32, numSenders uint32, stopWatch *core.StopWatch) { - stopWatch.Stop("sweeping") - duration := stopWatch.GetMeasurement("sweeping") - log.Debug("TxCache: swept senders:", "name", cache.name, "duration", duration, "txs", numTxs, "senders", numSenders) - cache.displaySendersHistogram() -} - -func (cache *TxCache) displaySendersHistogram() { - backingMap := cache.txListBySender.backingMap - log.Debug("TxCache.sendersHistogram:", "chunks", backingMap.ChunksCounts(), "scoreChunks", backingMap.ScoreChunksCounts()) -} - -// evictionJournal keeps a short journal about the eviction process -// This is useful for debugging and reasoning about the eviction -type evictionJournal struct { - evictionPerformed bool - passOneNumTxs uint32 - passOneNumSenders uint32 - passOneNumSteps uint32 -} - -func (journal *evictionJournal) 
display() { - log.Debug("Eviction.pass1:", "txs", journal.passOneNumTxs, "senders", journal.passOneNumSenders, "steps", journal.passOneNumSteps) -} - -// Diagnose checks the state of the cache for inconsistencies and displays a summary -func (cache *TxCache) Diagnose(deep bool) { - cache.diagnoseShallowly() - if deep { - cache.diagnoseDeeply() - } -} - -func (cache *TxCache) diagnoseShallowly() { - sw := core.NewStopWatch() - sw.Start("diagnose") - - sizeInBytes := cache.NumBytes() - numTxsEstimate := int(cache.CountTx()) - numTxsInChunks := cache.txByHash.backingMap.Count() - txsKeys := cache.txByHash.backingMap.Keys() - numSendersEstimate := uint32(cache.CountSenders()) - numSendersInChunks := cache.txListBySender.backingMap.Count() - numSendersInScoreChunks := cache.txListBySender.backingMap.CountSorted() - sendersKeys := cache.txListBySender.backingMap.Keys() - sendersKeysSorted := cache.txListBySender.backingMap.KeysSorted() - sendersSnapshot := cache.txListBySender.getSnapshotAscending() - - sw.Stop("diagnose") - duration := sw.GetMeasurement("diagnose") - - fine := numSendersEstimate == numSendersInChunks && numSendersEstimate == numSendersInScoreChunks - fine = fine && (len(sendersKeys) == len(sendersKeysSorted) && len(sendersKeys) == len(sendersSnapshot)) - fine = fine && (int(numSendersEstimate) == len(sendersKeys)) - fine = fine && (numTxsEstimate == numTxsInChunks && numTxsEstimate == len(txsKeys)) - - log.Debug("TxCache.diagnoseShallowly()", "name", cache.name, "duration", duration, "fine", fine) - log.Debug("TxCache.Size:", "current", sizeInBytes, "max", cache.config.NumBytesThreshold) - log.Debug("TxCache.NumSenders:", "estimate", numSendersEstimate, "inChunks", numSendersInChunks, "inScoreChunks", numSendersInScoreChunks) - log.Debug("TxCache.NumSenders (continued):", "keys", len(sendersKeys), "keysSorted", len(sendersKeysSorted), "snapshot", len(sendersSnapshot)) - log.Debug("TxCache.NumTxs:", "estimate", numTxsEstimate, "inChunks", numTxsInChunks, "keys", len(txsKeys)) -} - -func (cache *TxCache) diagnoseDeeply() { - sw := core.NewStopWatch() - sw.Start("diagnose") - - journal := cache.checkInternalConsistency() - cache.displaySendersSummary() - - sw.Stop("diagnose") - duration := sw.GetMeasurement("diagnose") - - log.Debug("TxCache.diagnoseDeeply()", "name", cache.name, "duration", duration) - journal.display() - cache.displaySendersHistogram() -} - -type internalConsistencyJournal struct { - numInMapByHash int - numInMapBySender int - numMissingInMapByHash int -} - -func (journal *internalConsistencyJournal) isFine() bool { - return (journal.numInMapByHash == journal.numInMapBySender) && (journal.numMissingInMapByHash == 0) -} - -func (journal *internalConsistencyJournal) display() { - log.Debug("TxCache.internalConsistencyJournal:", "fine", journal.isFine(), "numInMapByHash", journal.numInMapByHash, "numInMapBySender", journal.numInMapBySender, "numMissingInMapByHash", journal.numMissingInMapByHash) -} - -func (cache *TxCache) checkInternalConsistency() internalConsistencyJournal { - internalMapByHash := cache.txByHash - internalMapBySender := cache.txListBySender - - senders := internalMapBySender.getSnapshotAscending() - numInMapByHash := len(internalMapByHash.keys()) - numInMapBySender := 0 - numMissingInMapByHash := 0 - - for _, sender := range senders { - numInMapBySender += int(sender.countTx()) - - for _, hash := range sender.getTxHashes() { - _, ok := internalMapByHash.getTx(string(hash)) - if !ok { - numMissingInMapByHash++ - } - } - } - - return 
internalConsistencyJournal{ - numInMapByHash: numInMapByHash, - numInMapBySender: numInMapBySender, - numMissingInMapByHash: numMissingInMapByHash, - } -} - -func (cache *TxCache) displaySendersSummary() { - if log.GetLevel() != logger.LogTrace { - return - } - - senders := cache.txListBySender.getSnapshotAscending() - if len(senders) == 0 { - return - } - - var builder strings.Builder - builder.WriteString("\n[#index (score)] address [nonce known / nonce vs lowestTxNonce] txs = numTxs, !numFailedSelections\n") - - for i, sender := range senders { - address := hex.EncodeToString([]byte(sender.sender)) - accountNonce := sender.accountNonce.Get() - accountNonceKnown := sender.accountNonceKnown.IsSet() - numFailedSelections := sender.numFailedSelections.Get() - score := sender.getLastComputedScore() - numTxs := sender.countTxWithLock() - - lowestTxNonce := -1 - lowestTx := sender.getLowestNonceTx() - if lowestTx != nil { - lowestTxNonce = int(lowestTx.Tx.GetNonce()) - } - - _, _ = fmt.Fprintf(&builder, "[#%d (%d)] %s [%t / %d vs %d] txs = %d, !%d\n", i, score, address, accountNonceKnown, accountNonce, lowestTxNonce, numTxs, numFailedSelections) - } - - summary := builder.String() - log.Debug("TxCache.displaySendersSummary()", "name", cache.name, "summary\n", summary) -} diff --git a/storage/txcache/score.go b/storage/txcache/score.go deleted file mode 100644 index 06bde537498..00000000000 --- a/storage/txcache/score.go +++ /dev/null @@ -1,67 +0,0 @@ -package txcache - -import ( - "math" -) - -var _ scoreComputer = (*defaultScoreComputer)(nil) - -// TODO (continued): The score formula should work even if minGasPrice = 0. -type senderScoreParams struct { - count uint64 - // Fee score is normalized - feeScore uint64 - gas uint64 -} - -type defaultScoreComputer struct { - txFeeHelper feeHelper - ppuDivider uint64 -} - -func newDefaultScoreComputer(txFeeHelper feeHelper) *defaultScoreComputer { - ppuScoreDivider := txFeeHelper.minGasPriceFactor() - ppuScoreDivider = ppuScoreDivider * ppuScoreDivider * ppuScoreDivider - - return &defaultScoreComputer{ - txFeeHelper: txFeeHelper, - ppuDivider: ppuScoreDivider, - } -} - -// computeScore computes the score of the sender, as an integer 0-100 -func (computer *defaultScoreComputer) computeScore(scoreParams senderScoreParams) uint32 { - rawScore := computer.computeRawScore(scoreParams) - truncatedScore := uint32(rawScore) - return truncatedScore -} - -// TODO (optimization): switch to integer operations (as opposed to float operations). -func (computer *defaultScoreComputer) computeRawScore(params senderScoreParams) float64 { - allParamsDefined := params.feeScore > 0 && params.gas > 0 && params.count > 0 - if !allParamsDefined { - return 0 - } - - ppuMin := computer.txFeeHelper.minPricePerUnit() - normalizedGas := params.gas >> computer.txFeeHelper.gasLimitShift() - if normalizedGas == 0 { - normalizedGas = 1 - } - ppuAvg := params.feeScore / normalizedGas - // (<< 3)^3 and >> 9 cancel each other; used to preserve a bit more resolution - ppuRatio := ppuAvg << 3 / ppuMin - ppuScore := ppuRatio * ppuRatio * ppuRatio >> 9 - ppuScoreAdjusted := float64(ppuScore) / float64(computer.ppuDivider) - - countPow2 := params.count * params.count - countScore := math.Log(float64(countPow2)+1) + 1 - - rawScore := ppuScoreAdjusted / countScore - // We apply the logistic function, - // and then subtract 0.5, since we only deal with positive scores, - // and then we multiply by 2, to have full [0..1] range. 
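// Worked example (added for clarity; the figures follow from the formula itself, not
// from the original sources): for rawScore = 1.0, the logistic term is
// 1/(1+e^-1) ≈ 0.731; subtracting 0.5 gives ≈ 0.231 and doubling gives ≈ 0.462.
// Scaled by numberOfScoreChunks (100, judging by the 0-100 range documented on
// computeScore), the sender lands in score chunk ≈ 46. As rawScore grows, the score
// approaches numberOfScoreChunks asymptotically, without ever reaching it.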
- asymptoticScore := (1/(1+math.Exp(-rawScore)) - 0.5) * 2 - score := asymptoticScore * float64(numberOfScoreChunks) - return score -} diff --git a/storage/txcache/score_test.go b/storage/txcache/score_test.go deleted file mode 100644 index 51e438e1d17..00000000000 --- a/storage/txcache/score_test.go +++ /dev/null @@ -1,157 +0,0 @@ -package txcache - -import ( - "strconv" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestDefaultScoreComputer_computeRawScore(t *testing.T) { - _, txFeeHelper := dummyParamsWithGasPrice(oneBillion) - computer := newDefaultScoreComputer(txFeeHelper) - - // 50k moveGas, 100Bil minPrice -> normalizedFee 8940 - score := computer.computeRawScore(senderScoreParams{count: 1, feeScore: 18000, gas: 100000}) - assert.InDelta(t, float64(16.8753739025), score, delta) - - score = computer.computeRawScore(senderScoreParams{count: 1, feeScore: 1500000, gas: 10000000}) - assert.InDelta(t, float64(9.3096887100), score, delta) - - score = computer.computeRawScore(senderScoreParams{count: 1, feeScore: 5000000, gas: 30000000}) - assert.InDelta(t, float64(12.7657690638), score, delta) - - score = computer.computeRawScore(senderScoreParams{count: 2, feeScore: 36000, gas: 200000}) - assert.InDelta(t, float64(11.0106052638), score, delta) - - score = computer.computeRawScore(senderScoreParams{count: 1000, feeScore: 18000000, gas: 100000000}) - assert.InDelta(t, float64(1.8520698299), score, delta) - - score = computer.computeRawScore(senderScoreParams{count: 10000, feeScore: 180000000, gas: 1000000000}) - assert.InDelta(t, float64(1.4129614707), score, delta) -} - -func BenchmarkScoreComputer_computeRawScore(b *testing.B) { - _, txFeeHelper := dummyParams() - computer := newDefaultScoreComputer(txFeeHelper) - - for i := 0; i < b.N; i++ { - for j := uint64(0); j < 10000000; j++ { - computer.computeRawScore(senderScoreParams{count: j, feeScore: uint64(float64(8000) * float64(j)), gas: 100000 * j}) - } - } -} - -func TestDefaultScoreComputer_computeRawScoreOfTxListForSender(t *testing.T) { - txGasHandler, txFeeHelper := dummyParamsWithGasPrice(oneBillion) - computer := newDefaultScoreComputer(txFeeHelper) - list := newUnconstrainedListToTest() - - list.AddTx(createTxWithParams([]byte("a"), ".", 1, 1000, 50000, oneBillion), txGasHandler, txFeeHelper) - list.AddTx(createTxWithParams([]byte("b"), ".", 1, 500, 100000, oneBillion), txGasHandler, txFeeHelper) - list.AddTx(createTxWithParams([]byte("c"), ".", 1, 500, 100000, oneBillion), txGasHandler, txFeeHelper) - - require.Equal(t, uint64(3), list.countTx()) - require.Equal(t, int64(2000), list.totalBytes.Get()) - require.Equal(t, int64(250000), list.totalGas.Get()) - require.Equal(t, int64(51588), list.totalFeeScore.Get()) - - scoreParams := list.getScoreParams() - rawScore := computer.computeRawScore(scoreParams) - require.InDelta(t, float64(12.4595615805), rawScore, delta) -} - -func TestDefaultScoreComputer_scoreFluctuatesDeterministicallyWhileTxListForSenderMutates(t *testing.T) { - txGasHandler, txFeeHelper := dummyParamsWithGasPrice(oneBillion) - computer := newDefaultScoreComputer(txFeeHelper) - list := newUnconstrainedListToTest() - - A := createTxWithParams([]byte("A"), ".", 1, 1000, 200000, oneBillion) - B := createTxWithParams([]byte("b"), ".", 1, 500, 100000, oneBillion) - C := createTxWithParams([]byte("c"), ".", 1, 500, 100000, oneBillion) - D := createTxWithParams([]byte("d"), ".", 1, 128, 50000, oneBillion) - - scoreNone := 
int(computer.computeScore(list.getScoreParams())) - list.AddTx(A, txGasHandler, txFeeHelper) - scoreA := int(computer.computeScore(list.getScoreParams())) - list.AddTx(B, txGasHandler, txFeeHelper) - scoreAB := int(computer.computeScore(list.getScoreParams())) - list.AddTx(C, txGasHandler, txFeeHelper) - scoreABC := int(computer.computeScore(list.getScoreParams())) - list.AddTx(D, txGasHandler, txFeeHelper) - scoreABCD := int(computer.computeScore(list.getScoreParams())) - - require.Equal(t, 0, scoreNone) - require.Equal(t, 18, scoreA) - require.Equal(t, 12, scoreAB) - require.Equal(t, 10, scoreABC) - require.Equal(t, 9, scoreABCD) - - list.RemoveTx(D) - scoreABC = int(computer.computeScore(list.getScoreParams())) - list.RemoveTx(C) - scoreAB = int(computer.computeScore(list.getScoreParams())) - list.RemoveTx(B) - scoreA = int(computer.computeScore(list.getScoreParams())) - list.RemoveTx(A) - scoreNone = int(computer.computeScore(list.getScoreParams())) - - require.Equal(t, 0, scoreNone) - require.Equal(t, 18, scoreA) - require.Equal(t, 12, scoreAB) - require.Equal(t, 10, scoreABC) -} - -func TestDefaultScoreComputer_DifferentSenders(t *testing.T) { - txGasHandler, txFeeHelper := dummyParamsWithGasPrice(oneBillion) - computer := newDefaultScoreComputer(txFeeHelper) - - A := createTxWithParams([]byte("a"), "a", 1, 128, 50000, oneBillion) // min value normal tx - B := createTxWithParams([]byte("b"), "b", 1, 128, 50000, uint64(1.5*oneBillion)) // 50% higher value normal tx - C := createTxWithParams([]byte("c"), "c", 1, 128, 10000000, oneBillion) // min value SC call - D := createTxWithParams([]byte("d"), "d", 1, 128, 10000000, uint64(1.5*oneBillion)) // 50% higher value SC call - - listA := newUnconstrainedListToTest() - listA.AddTx(A, txGasHandler, txFeeHelper) - scoreA := int(computer.computeScore(listA.getScoreParams())) - - listB := newUnconstrainedListToTest() - listB.AddTx(B, txGasHandler, txFeeHelper) - scoreB := int(computer.computeScore(listB.getScoreParams())) - - listC := newUnconstrainedListToTest() - listC.AddTx(C, txGasHandler, txFeeHelper) - scoreC := int(computer.computeScore(listC.getScoreParams())) - - listD := newUnconstrainedListToTest() - listD.AddTx(D, txGasHandler, txFeeHelper) - scoreD := int(computer.computeScore(listD.getScoreParams())) - - require.Equal(t, 33, scoreA) - require.Equal(t, 82, scoreB) - require.Equal(t, 15, scoreC) - require.Equal(t, 16, scoreD) - - // adding same type of transactions for each sender decreases the score - for i := 2; i < 1000; i++ { - A = createTxWithParams([]byte("a"+strconv.Itoa(i)), "a", uint64(i), 128, 50000, oneBillion) // min value normal tx - listA.AddTx(A, txGasHandler, txFeeHelper) - B = createTxWithParams([]byte("b"+strconv.Itoa(i)), "b", uint64(i), 128, 50000, uint64(1.5*oneBillion)) // 50% higher value normal tx - listB.AddTx(B, txGasHandler, txFeeHelper) - C = createTxWithParams([]byte("c"+strconv.Itoa(i)), "c", uint64(i), 128, 10000000, oneBillion) // min value SC call - listC.AddTx(C, txGasHandler, txFeeHelper) - D = createTxWithParams([]byte("d"+strconv.Itoa(i)), "d", uint64(i), 128, 10000000, uint64(1.5*oneBillion)) // 50% higher value SC call - listD.AddTx(D, txGasHandler, txFeeHelper) - } - - scoreA = int(computer.computeScore(listA.getScoreParams())) - scoreB = int(computer.computeScore(listB.getScoreParams())) - scoreC = int(computer.computeScore(listC.getScoreParams())) - scoreD = int(computer.computeScore(listD.getScoreParams())) - - require.Equal(t, 3, scoreA) - require.Equal(t, 12, scoreB) - require.Equal(t, 1, 
scoreC) - require.Equal(t, 1, scoreD) -} diff --git a/storage/txcache/sweeping.go b/storage/txcache/sweeping.go deleted file mode 100644 index 92255309aea..00000000000 --- a/storage/txcache/sweeping.go +++ /dev/null @@ -1,29 +0,0 @@ -package txcache - -func (cache *TxCache) initSweepable() { - cache.sweepingListOfSenders = make([]*txListForSender, 0, estimatedNumOfSweepableSendersPerSelection) -} - -func (cache *TxCache) collectSweepable(list *txListForSender) { - if !list.sweepable.IsSet() { - return - } - - cache.sweepingMutex.Lock() - cache.sweepingListOfSenders = append(cache.sweepingListOfSenders, list) - cache.sweepingMutex.Unlock() -} - -func (cache *TxCache) sweepSweepable() { - cache.sweepingMutex.Lock() - defer cache.sweepingMutex.Unlock() - - if len(cache.sweepingListOfSenders) == 0 { - return - } - - stopWatch := cache.monitorSweepingStart() - numTxs, numSenders := cache.evictSendersAndTheirTxs(cache.sweepingListOfSenders) - cache.initSweepable() - cache.monitorSweepingEnd(numTxs, numSenders, stopWatch) -} diff --git a/storage/txcache/sweeping_test.go b/storage/txcache/sweeping_test.go deleted file mode 100644 index a700f7a8755..00000000000 --- a/storage/txcache/sweeping_test.go +++ /dev/null @@ -1,118 +0,0 @@ -package txcache - -import ( - "math" - "testing" - - "github.com/stretchr/testify/require" -) - -func TestSweeping_CollectSweepable(t *testing.T) { - cache := newUnconstrainedCacheToTest() - - cache.AddTx(createTx([]byte("alice-42"), "alice", 42)) - cache.AddTx(createTx([]byte("bob-42"), "bob", 42)) - cache.AddTx(createTx([]byte("carol-42"), "carol", 42)) - - // Senders have no initial gaps - selection := cache.doSelectTransactions(1000, 1000, math.MaxUint64) - require.Equal(t, 3, len(selection)) - require.Equal(t, 0, len(cache.sweepingListOfSenders)) - - // Alice and Bob have initial gaps, Carol doesn't - cache.NotifyAccountNonce([]byte("alice"), 10) - cache.NotifyAccountNonce([]byte("bob"), 20) - - // 1st fail - selection = cache.doSelectTransactions(1000, 1000, math.MaxUint64) - require.Equal(t, 1, len(selection)) - require.Equal(t, 0, len(cache.sweepingListOfSenders)) - require.Equal(t, 1, cache.getNumFailedSelectionsOfSender("alice")) - require.Equal(t, 1, cache.getNumFailedSelectionsOfSender("bob")) - require.Equal(t, 0, cache.getNumFailedSelectionsOfSender("carol")) - - // 2nd fail, grace period, one grace transaction for Alice and Bob - selection = cache.doSelectTransactions(1000, 1000, math.MaxUint64) - require.Equal(t, 3, len(selection)) - require.Equal(t, 0, len(cache.sweepingListOfSenders)) - require.Equal(t, 2, cache.getNumFailedSelectionsOfSender("alice")) - require.Equal(t, 2, cache.getNumFailedSelectionsOfSender("bob")) - require.Equal(t, 0, cache.getNumFailedSelectionsOfSender("carol")) - - // 3rd fail, collect Alice and Bob as sweepables - selection = cache.doSelectTransactions(1000, 1000, math.MaxUint64) - require.Equal(t, 1, len(selection)) - require.Equal(t, 2, len(cache.sweepingListOfSenders)) - require.True(t, cache.isSenderSweepable("alice")) - require.True(t, cache.isSenderSweepable("bob")) - require.Equal(t, 3, cache.getNumFailedSelectionsOfSender("alice")) - require.Equal(t, 3, cache.getNumFailedSelectionsOfSender("bob")) - require.Equal(t, 0, cache.getNumFailedSelectionsOfSender("carol")) -} - -func TestSweeping_WhenSendersEscapeCollection(t *testing.T) { - cache := newUnconstrainedCacheToTest() - - cache.AddTx(createTx([]byte("alice-42"), "alice", 42)) - cache.AddTx(createTx([]byte("bob-42"), "bob", 42)) - 
cache.AddTx(createTx([]byte("carol-42"), "carol", 42)) - - // Senders have no initial gaps - selection := cache.doSelectTransactions(1000, 1000, math.MaxUint64) - require.Equal(t, 3, len(selection)) - require.Equal(t, 0, len(cache.sweepingListOfSenders)) - - // Alice and Bob have initial gaps, Carol doesn't - cache.NotifyAccountNonce([]byte("alice"), 10) - cache.NotifyAccountNonce([]byte("bob"), 20) - - // 1st fail - selection = cache.doSelectTransactions(1000, 1000, math.MaxUint64) - require.Equal(t, 1, len(selection)) - require.Equal(t, 0, len(cache.sweepingListOfSenders)) - require.Equal(t, 1, cache.getNumFailedSelectionsOfSender("alice")) - require.Equal(t, 1, cache.getNumFailedSelectionsOfSender("bob")) - require.Equal(t, 0, cache.getNumFailedSelectionsOfSender("carol")) - - // 2nd fail, grace period, one grace transaction for Alice and Bob - selection = cache.doSelectTransactions(1000, 1000, math.MaxUint64) - require.Equal(t, 3, len(selection)) - require.Equal(t, 0, len(cache.sweepingListOfSenders)) - require.Equal(t, 2, cache.getNumFailedSelectionsOfSender("alice")) - require.Equal(t, 2, cache.getNumFailedSelectionsOfSender("bob")) - require.Equal(t, 0, cache.getNumFailedSelectionsOfSender("carol")) - - // 3rd attempt, but with gaps resolved - // Alice and Bob escape and won't be collected as sweepables - cache.NotifyAccountNonce([]byte("alice"), 42) - cache.NotifyAccountNonce([]byte("bob"), 42) - - selection = cache.doSelectTransactions(1000, 1000, math.MaxUint64) - require.Equal(t, 3, len(selection)) - require.Equal(t, 0, len(cache.sweepingListOfSenders)) - require.Equal(t, 0, cache.getNumFailedSelectionsOfSender("alice")) - require.Equal(t, 0, cache.getNumFailedSelectionsOfSender("bob")) - require.Equal(t, 0, cache.getNumFailedSelectionsOfSender("carol")) -} - -func TestSweeping_SweepSweepable(t *testing.T) { - cache := newUnconstrainedCacheToTest() - - cache.AddTx(createTx([]byte("alice-42"), "alice", 42)) - cache.AddTx(createTx([]byte("bob-42"), "bob", 42)) - cache.AddTx(createTx([]byte("carol-42"), "carol", 42)) - - // Fake "Alice" and "Bob" as sweepable - cache.sweepingListOfSenders = []*txListForSender{ - cache.getListForSender("alice"), - cache.getListForSender("bob"), - } - - require.Equal(t, uint64(3), cache.CountTx()) - require.Equal(t, uint64(3), cache.CountSenders()) - - cache.sweepSweepable() - - require.Equal(t, uint64(1), cache.CountTx()) - require.Equal(t, uint64(1), cache.CountSenders()) -} diff --git a/storage/txcache/testutils_test.go b/storage/txcache/testutils_test.go deleted file mode 100644 index 76382eb7676..00000000000 --- a/storage/txcache/testutils_test.go +++ /dev/null @@ -1,192 +0,0 @@ -package txcache - -import ( - "encoding/binary" - "sync" - "testing" - "time" - - "github.com/ElrondNetwork/elrond-go-core/core" - "github.com/ElrondNetwork/elrond-go-core/data/transaction" -) - -const oneMilion = 1000000 -const oneBillion = oneMilion * 1000 -const delta = 0.00000001 -const estimatedSizeOfBoundedTxFields = uint64(128) - -func (cache *TxCache) areInternalMapsConsistent() bool { - journal := cache.checkInternalConsistency() - return journal.isFine() -} - -func (cache *TxCache) getHashesForSender(sender string) []string { - return cache.getListForSender(sender).getTxHashesAsStrings() -} - -func (cache *TxCache) getListForSender(sender string) *txListForSender { - return cache.txListBySender.testGetListForSender(sender) -} - -func (txMap *txListBySenderMap) testGetListForSender(sender string) *txListForSender { - list, ok := txMap.getListForSender(sender) - 
if !ok { - panic("sender not in cache") - } - - return list -} - -func (cache *TxCache) getScoreOfSender(sender string) uint32 { - list := cache.getListForSender(sender) - scoreParams := list.getScoreParams() - computer := cache.txListBySender.scoreComputer - return computer.computeScore(scoreParams) -} - -func (cache *TxCache) getNumFailedSelectionsOfSender(sender string) int { - return int(cache.getListForSender(sender).numFailedSelections.Get()) -} - -func (cache *TxCache) isSenderSweepable(sender string) bool { - for _, item := range cache.sweepingListOfSenders { - if item.sender == sender { - return true - } - } - - return false -} - -func (listForSender *txListForSender) getTxHashesAsStrings() []string { - hashes := listForSender.getTxHashes() - return hashesAsStrings(hashes) -} - -func hashesAsStrings(hashes [][]byte) []string { - result := make([]string, len(hashes)) - for i := 0; i < len(hashes); i++ { - result[i] = string(hashes[i]) - } - - return result -} - -func hashesAsBytes(hashes []string) [][]byte { - result := make([][]byte, len(hashes)) - for i := 0; i < len(hashes); i++ { - result[i] = []byte(hashes[i]) - } - - return result -} - -func addManyTransactionsWithUniformDistribution(cache *TxCache, nSenders int, nTransactionsPerSender int) { - for senderTag := 0; senderTag < nSenders; senderTag++ { - sender := createFakeSenderAddress(senderTag) - - for txNonce := nTransactionsPerSender; txNonce > 0; txNonce-- { - txHash := createFakeTxHash(sender, txNonce) - tx := createTx(txHash, string(sender), uint64(txNonce)) - cache.AddTx(tx) - } - } -} - -func createTx(hash []byte, sender string, nonce uint64) *WrappedTransaction { - tx := &transaction.Transaction{ - SndAddr: []byte(sender), - Nonce: nonce, - } - - return &WrappedTransaction{ - Tx: tx, - TxHash: hash, - Size: int64(estimatedSizeOfBoundedTxFields), - } -} -func createTxWithGasLimit(hash []byte, sender string, nonce uint64, gasLimit uint64) *WrappedTransaction { - tx := &transaction.Transaction{ - SndAddr: []byte(sender), - Nonce: nonce, - GasLimit: gasLimit, - } - - return &WrappedTransaction{ - Tx: tx, - TxHash: hash, - Size: int64(estimatedSizeOfBoundedTxFields), - } -} - -func createTxWithParams(hash []byte, sender string, nonce uint64, size uint64, gasLimit uint64, gasPrice uint64) *WrappedTransaction { - dataLength := int(size) - int(estimatedSizeOfBoundedTxFields) - if dataLength < 0 { - panic("createTxWithParams(): invalid length for dummy tx") - } - - tx := &transaction.Transaction{ - SndAddr: []byte(sender), - Nonce: nonce, - Data: make([]byte, dataLength), - GasLimit: gasLimit, - GasPrice: gasPrice, - } - - return &WrappedTransaction{ - Tx: tx, - TxHash: hash, - Size: int64(size), - } -} - -func createFakeSenderAddress(senderTag int) []byte { - bytes := make([]byte, 32) - binary.LittleEndian.PutUint64(bytes, uint64(senderTag)) - binary.LittleEndian.PutUint64(bytes[24:], uint64(senderTag)) - return bytes -} - -func createFakeTxHash(fakeSenderAddress []byte, nonce int) []byte { - bytes := make([]byte, 32) - copy(bytes, fakeSenderAddress) - binary.LittleEndian.PutUint64(bytes[8:], uint64(nonce)) - binary.LittleEndian.PutUint64(bytes[16:], uint64(nonce)) - return bytes -} - -func measureWithStopWatch(b *testing.B, function func()) { - sw := core.NewStopWatch() - sw.Start("time") - function() - sw.Stop("time") - - duration := sw.GetMeasurementsMap()["time"] - b.ReportMetric(duration, "time@stopWatch") -} - -// waitTimeout waits for the waitgroup for the specified max timeout. -// Returns true if waiting timed out. 
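// Usage sketch (added; assumes a test-local sync.WaitGroup named wg):
//   if waitTimeout(&wg, time.Second) { t.Error("timed out waiting for goroutines") }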
-// Reference: https://stackoverflow.com/a/32843750/1475331 -func waitTimeout(wg *sync.WaitGroup, timeout time.Duration) bool { - c := make(chan struct{}) - go func() { - defer close(c) - wg.Wait() - }() - select { - case <-c: - return false // completed normally - case <-time.After(timeout): - return true // timed out - } -} - -var _ scoreComputer = (*disabledScoreComputer)(nil) - -type disabledScoreComputer struct { -} - -func (computer *disabledScoreComputer) computeScore(_ senderScoreParams) uint32 { - return 0 -} diff --git a/storage/txcache/txByHashMap.go b/storage/txcache/txByHashMap.go deleted file mode 100644 index 2346cc364c9..00000000000 --- a/storage/txcache/txByHashMap.go +++ /dev/null @@ -1,101 +0,0 @@ -package txcache - -import ( - "github.com/ElrondNetwork/elrond-go-core/core/atomic" - "github.com/ElrondNetwork/elrond-go/storage/txcache/maps" -) - -// txByHashMap is a map-like structure for holding and accessing transactions by txHash -type txByHashMap struct { - backingMap *maps.ConcurrentMap - counter atomic.Counter - numBytes atomic.Counter -} - -// newTxByHashMap creates a new txByHashMap instance -func newTxByHashMap(nChunksHint uint32) *txByHashMap { - backingMap := maps.NewConcurrentMap(nChunksHint) - - return &txByHashMap{ - backingMap: backingMap, - } -} - -// addTx adds a transaction to the map -func (txMap *txByHashMap) addTx(tx *WrappedTransaction) bool { - added := txMap.backingMap.SetIfAbsent(string(tx.TxHash), tx) - if added { - txMap.counter.Increment() - txMap.numBytes.Add(tx.Size) - } - - return added -} - -// removeTx removes a transaction from the map -func (txMap *txByHashMap) removeTx(txHash string) (*WrappedTransaction, bool) { - item, removed := txMap.backingMap.Remove(txHash) - if !removed { - return nil, false - } - - tx, ok := item.(*WrappedTransaction) - if !ok { - return nil, false - } - - if removed { - txMap.counter.Decrement() - txMap.numBytes.Subtract(tx.Size) - } - - return tx, true -} - -// getTx gets a transaction from the map -func (txMap *txByHashMap) getTx(txHash string) (*WrappedTransaction, bool) { - txUntyped, ok := txMap.backingMap.Get(txHash) - if !ok { - return nil, false - } - - tx := txUntyped.(*WrappedTransaction) - return tx, true -} - -// RemoveTxsBulk removes transactions, in bulk -func (txMap *txByHashMap) RemoveTxsBulk(txHashes [][]byte) uint32 { - numRemoved := uint32(0) - - for _, txHash := range txHashes { - _, removed := txMap.removeTx(string(txHash)) - if removed { - numRemoved++ - } - } - - return numRemoved -} - -// forEach iterates over the transactions in the map -func (txMap *txByHashMap) forEach(function ForEachTransaction) { - txMap.backingMap.IterCb(func(key string, item interface{}) { - tx := item.(*WrappedTransaction) - function([]byte(key), tx) - }) -} - -func (txMap *txByHashMap) clear() { - txMap.backingMap.Clear() - txMap.counter.Set(0) -} - -func (txMap *txByHashMap) keys() [][]byte { - keys := txMap.backingMap.Keys() - keysAsBytes := make([][]byte, len(keys)) - for i := 0; i < len(keys); i++ { - keysAsBytes[i] = []byte(keys[i]) - } - - return keysAsBytes -} diff --git a/storage/txcache/txCache.go b/storage/txcache/txCache.go deleted file mode 100644 index d9772cb21a9..00000000000 --- a/storage/txcache/txCache.go +++ /dev/null @@ -1,327 +0,0 @@ -package txcache - -import ( - "sync" - - "github.com/ElrondNetwork/elrond-go-core/core/atomic" - "github.com/ElrondNetwork/elrond-go-core/core/check" - "github.com/ElrondNetwork/elrond-go/storage" -) - -var _ storage.Cacher = (*TxCache)(nil) - -// TxCache represents a 
cache-like structure (it has a fixed capacity and implements an eviction mechanism) for holding transactions -type TxCache struct { - name string - txListBySender *txListBySenderMap - txByHash *txByHashMap - config ConfigSourceMe - evictionMutex sync.Mutex - evictionJournal evictionJournal - evictionSnapshotOfSenders []*txListForSender - isEvictionInProgress atomic.Flag - numSendersSelected atomic.Counter - numSendersWithInitialGap atomic.Counter - numSendersWithMiddleGap atomic.Counter - numSendersInGracePeriod atomic.Counter - sweepingMutex sync.Mutex - sweepingListOfSenders []*txListForSender - mutTxOperation sync.Mutex -} - -// NewTxCache creates a new transaction cache -func NewTxCache(config ConfigSourceMe, txGasHandler TxGasHandler) (*TxCache, error) { - log.Debug("NewTxCache", "config", config.String()) - storage.MonitorNewCache(config.Name, uint64(config.NumBytesThreshold)) - - err := config.verify() - if err != nil { - return nil, err - } - if check.IfNil(txGasHandler) { - return nil, storage.ErrNilTxGasHandler - } - - // Note: for simplicity, we use the same "numChunks" for both internal concurrent maps - numChunks := config.NumChunks - senderConstraintsObj := config.getSenderConstraints() - txFeeHelper := newFeeComputationHelper(txGasHandler.MinGasPrice(), txGasHandler.MinGasLimit(), txGasHandler.MinGasPriceForProcessing()) - scoreComputerObj := newDefaultScoreComputer(txFeeHelper) - - txCache := &TxCache{ - name: config.Name, - txListBySender: newTxListBySenderMap(numChunks, senderConstraintsObj, scoreComputerObj, txGasHandler, txFeeHelper), - txByHash: newTxByHashMap(numChunks), - config: config, - evictionJournal: evictionJournal{}, - } - - txCache.initSweepable() - return txCache, nil -} - -// AddTx adds a transaction in the cache -// Eviction happens if maximum capacity is reached -func (cache *TxCache) AddTx(tx *WrappedTransaction) (ok bool, added bool) { - if tx == nil || check.IfNil(tx.Tx) { - return false, false - } - - if cache.config.EvictionEnabled { - cache.doEviction() - } - - cache.mutTxOperation.Lock() - addedInByHash := cache.txByHash.addTx(tx) - addedInBySender, evicted := cache.txListBySender.addTx(tx) - cache.mutTxOperation.Unlock() - if addedInByHash != addedInBySender { - // This can happen when two go-routines concur to add the same transaction: - // - A adds to "txByHash" - // - B won't add to "txByHash" (duplicate) - // - B adds to "txListBySender" - // - A won't add to "txListBySender" (duplicate) - log.Trace("TxCache.AddTx(): slight inconsistency detected:", "name", cache.name, "tx", tx.TxHash, "sender", tx.Tx.GetSndAddr(), "addedInByHash", addedInByHash, "addedInBySender", addedInBySender) - } - - if len(evicted) > 0 { - cache.monitorEvictionWrtSenderLimit(tx.Tx.GetSndAddr(), evicted) - cache.txByHash.RemoveTxsBulk(evicted) - } - - // The return value "added" is true even if the transaction was added, but then removed due to limits by sender. - This is to ensure that the onAdded() notification is triggered. 
- return true, addedInByHash || addedInBySender -} - -// GetByTxHash gets the transaction by hash -func (cache *TxCache) GetByTxHash(txHash []byte) (*WrappedTransaction, bool) { - tx, ok := cache.txByHash.getTx(string(txHash)) - return tx, ok -} - -// SelectTransactionsWithBandwidth selects a reasonably fair list of transactions to be included in the next miniblock -// It returns at most "numRequested" transactions -// Each sender gets the chance to give at least bandwidthPerSender gas worth of transactions, unless "numRequested" limit is reached before iterating over all senders -func (cache *TxCache) SelectTransactionsWithBandwidth(numRequested int, batchSizePerSender int, bandwidthPerSender uint64) []*WrappedTransaction { - result := cache.doSelectTransactions(numRequested, batchSizePerSender, bandwidthPerSender) - go cache.doAfterSelection() - return result -} - -func (cache *TxCache) doSelectTransactions(numRequested int, batchSizePerSender int, bandwidthPerSender uint64) []*WrappedTransaction { - stopWatch := cache.monitorSelectionStart() - - result := make([]*WrappedTransaction, numRequested) - resultFillIndex := 0 - resultIsFull := false - - snapshotOfSenders := cache.getSendersEligibleForSelection() - - for pass := 0; !resultIsFull; pass++ { - copiedInThisPass := 0 - - for _, txList := range snapshotOfSenders { - batchSizeWithScoreCoefficient := batchSizePerSender * int(txList.getLastComputedScore()+1) - // Reset happens on first pass only - isFirstBatch := pass == 0 - journal := txList.selectBatchTo(isFirstBatch, result[resultFillIndex:], batchSizeWithScoreCoefficient, bandwidthPerSender) - cache.monitorBatchSelectionEnd(journal) - - if isFirstBatch { - cache.collectSweepable(txList) - } - - resultFillIndex += journal.copied - copiedInThisPass += journal.copied - resultIsFull = resultFillIndex == numRequested - if resultIsFull { - break - } - } - - nothingCopiedThisPass := copiedInThisPass == 0 - - // No more passes needed - if nothingCopiedThisPass { - break - } - } - - result = result[:resultFillIndex] - cache.monitorSelectionEnd(result, stopWatch) - return result -} - -func (cache *TxCache) getSendersEligibleForSelection() []*txListForSender { - return cache.txListBySender.getSnapshotDescending() -} - -func (cache *TxCache) doAfterSelection() { - cache.sweepSweepable() - cache.Diagnose(false) -} - -// RemoveTxByHash removes tx by hash -func (cache *TxCache) RemoveTxByHash(txHash []byte) bool { - cache.mutTxOperation.Lock() - defer cache.mutTxOperation.Unlock() - - tx, foundInByHash := cache.txByHash.removeTx(string(txHash)) - if !foundInByHash { - return false - } - - foundInBySender := cache.txListBySender.removeTx(tx) - if !foundInBySender { - // This condition can arise often at high load & eviction, when two go-routines concur to remove the same transaction: - // - A = remove transactions upon commit / final - // - B = remove transactions due to high load (eviction) - // - // - A reaches "RemoveTxByHash()", then "cache.txByHash.removeTx()". 
- // - B reaches "cache.txByHash.RemoveTxsBulk()" - // - B reaches "cache.txListBySender.RemoveSendersBulk()" - // - A reaches "cache.txListBySender.removeTx()", but sender does not exist anymore - log.Trace("TxCache.RemoveTxByHash(): slight inconsistency detected: !foundInBySender", "name", cache.name, "tx", txHash) - } - - return true -} - -// NumBytes gets the approximate number of bytes stored in the cache -func (cache *TxCache) NumBytes() int { - return int(cache.txByHash.numBytes.GetUint64()) -} - -// CountTx gets the number of transactions in the cache -func (cache *TxCache) CountTx() uint64 { - return cache.txByHash.counter.GetUint64() -} - -// Len is an alias for CountTx -func (cache *TxCache) Len() int { - return int(cache.CountTx()) -} - -// SizeInBytesContained returns 0 -func (cache *TxCache) SizeInBytesContained() uint64 { - return 0 -} - -// CountSenders gets the number of senders in the cache -func (cache *TxCache) CountSenders() uint64 { - return cache.txListBySender.counter.GetUint64() -} - -// ForEachTransaction iterates over the transactions in the cache -func (cache *TxCache) ForEachTransaction(function ForEachTransaction) { - cache.txByHash.forEach(function) -} - -// GetTransactionsPoolForSender returns the list of transaction hashes for the sender -func (cache *TxCache) GetTransactionsPoolForSender(sender string) []*WrappedTransaction { - listForSender, ok := cache.txListBySender.getListForSender(sender) - if !ok { - return nil - } - - wrappedTxs := make([]*WrappedTransaction, listForSender.items.Len()) - for element, i := listForSender.items.Front(), 0; element != nil; element, i = element.Next(), i+1 { - tx := element.Value.(*WrappedTransaction) - wrappedTxs[i] = tx - } - - return wrappedTxs -} - -// Clear clears the cache -func (cache *TxCache) Clear() { - cache.mutTxOperation.Lock() - cache.txListBySender.clear() - cache.txByHash.clear() - cache.mutTxOperation.Unlock() -} - -// Put is not implemented -func (cache *TxCache) Put(_ []byte, _ interface{}, _ int) (evicted bool) { - log.Error("TxCache.Put is not implemented") - return false -} - -// Get gets a transaction (unwrapped) by hash -// Implemented for compatibility reasons (see txPoolsCleaner.go). -func (cache *TxCache) Get(key []byte) (value interface{}, ok bool) { - tx, ok := cache.GetByTxHash(key) - if ok { - return tx.Tx, true - } - return nil, false -} - -// Has checks if a transaction exists -func (cache *TxCache) Has(key []byte) bool { - _, ok := cache.GetByTxHash(key) - return ok -} - -// Peek gets a transaction (unwrapped) by hash -// Implemented for compatibility reasons (see transactions.go, common.go). 
-func (cache *TxCache) Peek(key []byte) (value interface{}, ok bool) { - tx, ok := cache.GetByTxHash(key) - if ok { - return tx.Tx, true - } - return nil, false -} - -// HasOrAdd is not implemented -func (cache *TxCache) HasOrAdd(_ []byte, _ interface{}, _ int) (has, added bool) { - log.Error("TxCache.HasOrAdd is not implemented") - return false, false -} - -// Remove removes tx by hash -func (cache *TxCache) Remove(key []byte) { - _ = cache.RemoveTxByHash(key) -} - -// Keys returns the tx hashes in the cache -func (cache *TxCache) Keys() [][]byte { - return cache.txByHash.keys() -} - -// MaxSize is not implemented -func (cache *TxCache) MaxSize() int { - // TODO: Should be analyzed if the returned value represents the max size of one cache in sharded cache configuration - return int(cache.config.CountThreshold) -} - -// RegisterHandler is not implemented -func (cache *TxCache) RegisterHandler(func(key []byte, value interface{}), string) { - log.Error("TxCache.RegisterHandler is not implemented") -} - -// UnRegisterHandler is not implemented -func (cache *TxCache) UnRegisterHandler(string) { - log.Error("TxCache.UnRegisterHandler is not implemented") -} - -// NotifyAccountNonce should be called by external components (such as interceptors and transactions processor) -// in order to inform the cache about initial nonce gap phenomena -func (cache *TxCache) NotifyAccountNonce(accountKey []byte, nonce uint64) { - cache.txListBySender.notifyAccountNonce(accountKey, nonce) -} - -// ImmunizeTxsAgainstEviction does nothing for this type of cache -func (cache *TxCache) ImmunizeTxsAgainstEviction(_ [][]byte) { -} - -// Close does nothing for this cacher implementation -func (cache *TxCache) Close() error { - return nil -} - -// IsInterfaceNil returns true if there is no value under the interface -func (cache *TxCache) IsInterfaceNil() bool { - return cache == nil -} diff --git a/storage/txcache/txCache_test.go b/storage/txcache/txCache_test.go deleted file mode 100644 index fdb0ee98b5a..00000000000 --- a/storage/txcache/txCache_test.go +++ /dev/null @@ -1,656 +0,0 @@ -package txcache - -import ( - "errors" - "fmt" - "math" - "sort" - "sync" - "testing" - "time" - - "github.com/ElrondNetwork/elrond-go-core/core" - "github.com/ElrondNetwork/elrond-go-core/core/check" - "github.com/ElrondNetwork/elrond-go/storage" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func Test_NewTxCache(t *testing.T) { - config := ConfigSourceMe{ - Name: "test", - NumChunks: 16, - NumBytesPerSenderThreshold: maxNumBytesPerSenderUpperBound, - CountPerSenderThreshold: math.MaxUint32, - } - - withEvictionConfig := ConfigSourceMe{ - Name: "test", - NumChunks: 16, - NumBytesPerSenderThreshold: maxNumBytesPerSenderUpperBound, - CountPerSenderThreshold: math.MaxUint32, - EvictionEnabled: true, - NumBytesThreshold: maxNumBytesUpperBound, - CountThreshold: math.MaxUint32, - NumSendersToPreemptivelyEvict: 100, - } - txGasHandler, _ := dummyParams() - - cache, err := NewTxCache(config, txGasHandler) - require.Nil(t, err) - require.NotNil(t, cache) - - badConfig := config - badConfig.Name = "" - requireErrorOnNewTxCache(t, badConfig, storage.ErrInvalidConfig, "config.Name", txGasHandler) - - badConfig = config - badConfig.NumChunks = 0 - requireErrorOnNewTxCache(t, badConfig, storage.ErrInvalidConfig, "config.NumChunks", txGasHandler) - - badConfig = config - badConfig.NumBytesPerSenderThreshold = 0 - requireErrorOnNewTxCache(t, badConfig, storage.ErrInvalidConfig, 
"config.NumBytesPerSenderThreshold", txGasHandler) - - badConfig = config - badConfig.CountPerSenderThreshold = 0 - requireErrorOnNewTxCache(t, badConfig, storage.ErrInvalidConfig, "config.CountPerSenderThreshold", txGasHandler) - - badConfig = config - cache, err = NewTxCache(config, nil) - require.Nil(t, cache) - require.Equal(t, storage.ErrNilTxGasHandler, err) - - badConfig = withEvictionConfig - badConfig.NumBytesThreshold = 0 - requireErrorOnNewTxCache(t, badConfig, storage.ErrInvalidConfig, "config.NumBytesThreshold", txGasHandler) - - badConfig = withEvictionConfig - badConfig.CountThreshold = 0 - requireErrorOnNewTxCache(t, badConfig, storage.ErrInvalidConfig, "config.CountThreshold", txGasHandler) - - badConfig = withEvictionConfig - badConfig.NumSendersToPreemptivelyEvict = 0 - requireErrorOnNewTxCache(t, badConfig, storage.ErrInvalidConfig, "config.NumSendersToPreemptivelyEvict", txGasHandler) -} - -func requireErrorOnNewTxCache(t *testing.T, config ConfigSourceMe, errExpected error, errPartialMessage string, txGasHandler TxGasHandler) { - cache, errReceived := NewTxCache(config, txGasHandler) - require.Nil(t, cache) - require.True(t, errors.Is(errReceived, errExpected)) - require.Contains(t, errReceived.Error(), errPartialMessage) -} - -func Test_AddTx(t *testing.T) { - cache := newUnconstrainedCacheToTest() - - tx := createTx([]byte("hash-1"), "alice", 1) - - ok, added := cache.AddTx(tx) - require.True(t, ok) - require.True(t, added) - require.True(t, cache.Has([]byte("hash-1"))) - - // Add it again (no-operation) - ok, added = cache.AddTx(tx) - require.True(t, ok) - require.False(t, added) - require.True(t, cache.Has([]byte("hash-1"))) - - foundTx, ok := cache.GetByTxHash([]byte("hash-1")) - require.True(t, ok) - require.Equal(t, tx, foundTx) -} - -func Test_AddNilTx_DoesNothing(t *testing.T) { - cache := newUnconstrainedCacheToTest() - - txHash := []byte("hash-1") - - ok, added := cache.AddTx(&WrappedTransaction{Tx: nil, TxHash: txHash}) - require.False(t, ok) - require.False(t, added) - - foundTx, ok := cache.GetByTxHash(txHash) - require.False(t, ok) - require.Nil(t, foundTx) -} - -func Test_AddTx_AppliesSizeConstraintsPerSenderForNumTransactions(t *testing.T) { - cache := newCacheToTest(maxNumBytesPerSenderUpperBound, 3) - - cache.AddTx(createTx([]byte("tx-alice-1"), "alice", 1)) - cache.AddTx(createTx([]byte("tx-alice-2"), "alice", 2)) - cache.AddTx(createTx([]byte("tx-alice-4"), "alice", 4)) - cache.AddTx(createTx([]byte("tx-bob-1"), "bob", 1)) - cache.AddTx(createTx([]byte("tx-bob-2"), "bob", 2)) - require.Equal(t, []string{"tx-alice-1", "tx-alice-2", "tx-alice-4"}, cache.getHashesForSender("alice")) - require.Equal(t, []string{"tx-bob-1", "tx-bob-2"}, cache.getHashesForSender("bob")) - require.True(t, cache.areInternalMapsConsistent()) - - cache.AddTx(createTx([]byte("tx-alice-3"), "alice", 3)) - require.Equal(t, []string{"tx-alice-1", "tx-alice-2", "tx-alice-3"}, cache.getHashesForSender("alice")) - require.Equal(t, []string{"tx-bob-1", "tx-bob-2"}, cache.getHashesForSender("bob")) - require.True(t, cache.areInternalMapsConsistent()) -} - -func Test_AddTx_AppliesSizeConstraintsPerSenderForNumBytes(t *testing.T) { - cache := newCacheToTest(1024, math.MaxUint32) - - cache.AddTx(createTxWithParams([]byte("tx-alice-1"), "alice", 1, 128, 42, 42)) - cache.AddTx(createTxWithParams([]byte("tx-alice-2"), "alice", 2, 512, 42, 42)) - cache.AddTx(createTxWithParams([]byte("tx-alice-4"), "alice", 3, 256, 42, 42)) - cache.AddTx(createTxWithParams([]byte("tx-bob-1"), "bob", 1, 
512, 42, 42)) - cache.AddTx(createTxWithParams([]byte("tx-bob-2"), "bob", 2, 513, 42, 42)) - - require.Equal(t, []string{"tx-alice-1", "tx-alice-2", "tx-alice-4"}, cache.getHashesForSender("alice")) - require.Equal(t, []string{"tx-bob-1"}, cache.getHashesForSender("bob")) - require.True(t, cache.areInternalMapsConsistent()) - - cache.AddTx(createTxWithParams([]byte("tx-alice-3"), "alice", 3, 256, 42, 42)) - cache.AddTx(createTxWithParams([]byte("tx-bob-2"), "bob", 3, 512, 42, 42)) - require.Equal(t, []string{"tx-alice-1", "tx-alice-2", "tx-alice-3"}, cache.getHashesForSender("alice")) - require.Equal(t, []string{"tx-bob-1", "tx-bob-2"}, cache.getHashesForSender("bob")) - require.True(t, cache.areInternalMapsConsistent()) -} - -func Test_RemoveByTxHash(t *testing.T) { - cache := newUnconstrainedCacheToTest() - - cache.AddTx(createTx([]byte("hash-1"), "alice", 1)) - cache.AddTx(createTx([]byte("hash-2"), "alice", 2)) - - removed := cache.RemoveTxByHash([]byte("hash-1")) - require.True(t, removed) - cache.Remove([]byte("hash-2")) - - foundTx, ok := cache.GetByTxHash([]byte("hash-1")) - require.False(t, ok) - require.Nil(t, foundTx) - - foundTx, ok = cache.GetByTxHash([]byte("hash-2")) - require.False(t, ok) - require.Nil(t, foundTx) -} - -func Test_CountTx_And_Len(t *testing.T) { - cache := newUnconstrainedCacheToTest() - - cache.AddTx(createTx([]byte("hash-1"), "alice", 1)) - cache.AddTx(createTx([]byte("hash-2"), "alice", 2)) - cache.AddTx(createTx([]byte("hash-3"), "alice", 3)) - - require.Equal(t, uint64(3), cache.CountTx()) - require.Equal(t, 3, cache.Len()) -} - -func Test_GetByTxHash_And_Peek_And_Get(t *testing.T) { - cache := newUnconstrainedCacheToTest() - - txHash := []byte("hash-1") - tx := createTx(txHash, "alice", 1) - cache.AddTx(tx) - - foundTx, ok := cache.GetByTxHash(txHash) - require.True(t, ok) - require.Equal(t, tx, foundTx) - - foundTxPeek, okPeek := cache.Peek(txHash) - require.True(t, okPeek) - require.Equal(t, tx.Tx, foundTxPeek) - - foundTxPeek, okPeek = cache.Peek([]byte("missing")) - require.False(t, okPeek) - require.Nil(t, foundTxPeek) - - foundTxGet, okGet := cache.Get(txHash) - require.True(t, okGet) - require.Equal(t, tx.Tx, foundTxGet) - - foundTxGet, okGet = cache.Get([]byte("missing")) - require.False(t, okGet) - require.Nil(t, foundTxGet) -} - -func Test_RemoveByTxHash_WhenMissing(t *testing.T) { - cache := newUnconstrainedCacheToTest() - removed := cache.RemoveTxByHash([]byte("missing")) - require.False(t, removed) -} - -func Test_RemoveByTxHash_RemovesFromByHash_WhenMapsInconsistency(t *testing.T) { - cache := newUnconstrainedCacheToTest() - - txHash := []byte("hash-1") - tx := createTx(txHash, "alice", 1) - cache.AddTx(tx) - - // Cause an inconsistency between the two internal maps (theoretically possible in case of misbehaving eviction) - cache.txListBySender.removeTx(tx) - - _ = cache.RemoveTxByHash(txHash) - require.Equal(t, 0, cache.txByHash.backingMap.Count()) -} - -func Test_Clear(t *testing.T) { - cache := newUnconstrainedCacheToTest() - - cache.AddTx(createTx([]byte("hash-alice-1"), "alice", 1)) - cache.AddTx(createTx([]byte("hash-bob-7"), "bob", 7)) - cache.AddTx(createTx([]byte("hash-alice-42"), "alice", 42)) - require.Equal(t, uint64(3), cache.CountTx()) - - cache.Clear() - require.Equal(t, uint64(0), cache.CountTx()) -} - -func Test_ForEachTransaction(t *testing.T) { - cache := newUnconstrainedCacheToTest() - - cache.AddTx(createTx([]byte("hash-alice-1"), "alice", 1)) - cache.AddTx(createTx([]byte("hash-bob-7"), "bob", 7)) - - counter := 0 - 
cache.ForEachTransaction(func(txHash []byte, value *WrappedTransaction) { - counter++ - }) - require.Equal(t, 2, counter) -} - -func Test_GetTransactionsPoolForSender(t *testing.T) { - cache := newUnconstrainedCacheToTest() - - txHashes1 := [][]byte{[]byte("hash-1"), []byte("hash-2")} - txSender1 := "alice" - wrappedTxs1 := []*WrappedTransaction{ - createTx(txHashes1[1], txSender1, 2), - createTx(txHashes1[0], txSender1, 1), - } - txHashes2 := [][]byte{[]byte("hash-3"), []byte("hash-4"), []byte("hash-5")} - txSender2 := "bob" - wrappedTxs2 := []*WrappedTransaction{ - createTx(txHashes2[1], txSender2, 4), - createTx(txHashes2[0], txSender2, 3), - createTx(txHashes2[2], txSender2, 5), - } - cache.AddTx(wrappedTxs1[0]) - cache.AddTx(wrappedTxs1[1]) - cache.AddTx(wrappedTxs2[0]) - cache.AddTx(wrappedTxs2[1]) - cache.AddTx(wrappedTxs2[2]) - - sort.Slice(wrappedTxs1, func(i, j int) bool { - return wrappedTxs1[i].Tx.GetNonce() < wrappedTxs1[j].Tx.GetNonce() - }) - txs := cache.GetTransactionsPoolForSender(txSender1) - require.Equal(t, wrappedTxs1, txs) - - sort.Slice(wrappedTxs2, func(i, j int) bool { - return wrappedTxs2[i].Tx.GetNonce() < wrappedTxs2[j].Tx.GetNonce() - }) - txs = cache.GetTransactionsPoolForSender(txSender2) - require.Equal(t, wrappedTxs2, txs) - - cache.RemoveTxByHash(txHashes2[0]) - expectedTxs := wrappedTxs2[1:] - txs = cache.GetTransactionsPoolForSender(txSender2) - require.Equal(t, expectedTxs, txs) -} - -func Test_SelectTransactions_Dummy(t *testing.T) { - cache := newUnconstrainedCacheToTest() - - cache.AddTx(createTx([]byte("hash-alice-4"), "alice", 4)) - cache.AddTx(createTx([]byte("hash-alice-3"), "alice", 3)) - cache.AddTx(createTx([]byte("hash-alice-2"), "alice", 2)) - cache.AddTx(createTx([]byte("hash-alice-1"), "alice", 1)) - cache.AddTx(createTx([]byte("hash-bob-7"), "bob", 7)) - cache.AddTx(createTx([]byte("hash-bob-6"), "bob", 6)) - cache.AddTx(createTx([]byte("hash-bob-5"), "bob", 5)) - cache.AddTx(createTx([]byte("hash-carol-1"), "carol", 1)) - - sorted := cache.SelectTransactionsWithBandwidth(10, 2, math.MaxUint64) - require.Len(t, sorted, 8) -} - -func Test_SelectTransactionsWithBandwidth_Dummy(t *testing.T) { - cache := newUnconstrainedCacheToTest() - cache.AddTx(createTxWithGasLimit([]byte("hash-alice-4"), "alice", 4, 100000)) - cache.AddTx(createTxWithGasLimit([]byte("hash-alice-3"), "alice", 3, 100000)) - cache.AddTx(createTxWithGasLimit([]byte("hash-alice-2"), "alice", 2, 500000)) - cache.AddTx(createTxWithGasLimit([]byte("hash-alice-1"), "alice", 1, 200000)) - cache.AddTx(createTxWithGasLimit([]byte("hash-bob-7"), "bob", 7, 100000)) - cache.AddTx(createTxWithGasLimit([]byte("hash-bob-6"), "bob", 6, 50000)) - cache.AddTx(createTxWithGasLimit([]byte("hash-bob-5"), "bob", 5, 50000)) - cache.AddTx(createTxWithGasLimit([]byte("hash-carol-1"), "carol", 1, 50000)) - - sorted := cache.SelectTransactionsWithBandwidth(5, 2, 200000) - numSelected := 1 + 1 + 3 // 1 alice, 1 carol, 3 bob - - require.Len(t, sorted, numSelected) -} - -func Test_SelectTransactions_BreaksAtNonceGaps(t *testing.T) { - cache := newUnconstrainedCacheToTest() - - cache.AddTx(createTx([]byte("hash-alice-1"), "alice", 1)) - cache.AddTx(createTx([]byte("hash-alice-2"), "alice", 2)) - cache.AddTx(createTx([]byte("hash-alice-3"), "alice", 3)) - cache.AddTx(createTx([]byte("hash-alice-5"), "alice", 5)) - cache.AddTx(createTx([]byte("hash-bob-42"), "bob", 42)) - cache.AddTx(createTx([]byte("hash-bob-44"), "bob", 44)) - cache.AddTx(createTx([]byte("hash-bob-45"), "bob", 45)) - 
cache.AddTx(createTx([]byte("hash-carol-7"), "carol", 7)) - cache.AddTx(createTx([]byte("hash-carol-8"), "carol", 8)) - cache.AddTx(createTx([]byte("hash-carol-10"), "carol", 10)) - cache.AddTx(createTx([]byte("hash-carol-11"), "carol", 11)) - - numSelected := 3 + 1 + 2 // 3 alice + 1 bob + 2 carol - - sorted := cache.SelectTransactionsWithBandwidth(10, 2, math.MaxUint64) - require.Len(t, sorted, numSelected) -} - -func Test_SelectTransactions(t *testing.T) { - cache := newUnconstrainedCacheToTest() - - // Add "nSenders" * "nTransactionsPerSender" transactions in the cache (in reversed nonce order) - nSenders := 1000 - nTransactionsPerSender := 100 - nTotalTransactions := nSenders * nTransactionsPerSender - nRequestedTransactions := math.MaxInt16 - - for senderTag := 0; senderTag < nSenders; senderTag++ { - sender := fmt.Sprintf("sender:%d", senderTag) - - for txNonce := nTransactionsPerSender; txNonce > 0; txNonce-- { - txHash := fmt.Sprintf("hash:%d:%d", senderTag, txNonce) - tx := createTx([]byte(txHash), sender, uint64(txNonce)) - cache.AddTx(tx) - } - } - - require.Equal(t, uint64(nTotalTransactions), cache.CountTx()) - - sorted := cache.SelectTransactionsWithBandwidth(nRequestedTransactions, 2, math.MaxUint64) - - require.Len(t, sorted, core.MinInt(nRequestedTransactions, nTotalTransactions)) - - // Check order - nonces := make(map[string]uint64, nSenders) - for _, tx := range sorted { - nonce := tx.Tx.GetNonce() - sender := string(tx.Tx.GetSndAddr()) - previousNonce := nonces[sender] - - require.LessOrEqual(t, previousNonce, nonce) - nonces[sender] = nonce - } -} - -func Test_Keys(t *testing.T) { - cache := newUnconstrainedCacheToTest() - - cache.AddTx(createTx([]byte("alice-x"), "alice", 42)) - cache.AddTx(createTx([]byte("alice-y"), "alice", 43)) - cache.AddTx(createTx([]byte("bob-x"), "bob", 42)) - cache.AddTx(createTx([]byte("bob-y"), "bob", 43)) - - keys := cache.Keys() - require.Equal(t, 4, len(keys)) - require.Contains(t, keys, []byte("alice-x")) - require.Contains(t, keys, []byte("alice-y")) - require.Contains(t, keys, []byte("bob-x")) - require.Contains(t, keys, []byte("bob-y")) -} - -func Test_AddWithEviction_UniformDistributionOfTxsPerSender(t *testing.T) { - txGasHandler, _ := dummyParams() - config := ConfigSourceMe{ - Name: "untitled", - NumChunks: 16, - EvictionEnabled: true, - NumBytesThreshold: maxNumBytesUpperBound, - CountThreshold: 100, - NumSendersToPreemptivelyEvict: 1, - NumBytesPerSenderThreshold: maxNumBytesPerSenderUpperBound, - CountPerSenderThreshold: math.MaxUint32, - } - - // 11 * 10 - cache, err := NewTxCache(config, txGasHandler) - require.Nil(t, err) - require.NotNil(t, cache) - - addManyTransactionsWithUniformDistribution(cache, 11, 10) - require.LessOrEqual(t, cache.CountTx(), uint64(100)) - - config = ConfigSourceMe{ - Name: "untitled", - NumChunks: 16, - EvictionEnabled: true, - NumBytesThreshold: maxNumBytesUpperBound, - CountThreshold: 250000, - NumSendersToPreemptivelyEvict: 1, - NumBytesPerSenderThreshold: maxNumBytesPerSenderUpperBound, - CountPerSenderThreshold: math.MaxUint32, - } - - // 100 * 1000 - cache, err = NewTxCache(config, txGasHandler) - require.Nil(t, err) - require.NotNil(t, cache) - - addManyTransactionsWithUniformDistribution(cache, 100, 1000) - require.LessOrEqual(t, cache.CountTx(), uint64(250000)) -} - -func Test_NotImplementedFunctions(t *testing.T) { - cache := newUnconstrainedCacheToTest() - - evicted := cache.Put(nil, nil, 0) - require.False(t, evicted) - - has, added := cache.HasOrAdd(nil, nil, 0) - require.False(t, 
has) - require.False(t, added) - - require.NotPanics(t, func() { cache.RegisterHandler(nil, "") }) - require.Zero(t, cache.MaxSize()) - - err := cache.Close() - require.Nil(t, err) -} - -func Test_IsInterfaceNil(t *testing.T) { - cache := newUnconstrainedCacheToTest() - require.False(t, check.IfNil(cache)) - - makeNil := func() storage.Cacher { - return nil - } - - thisIsNil := makeNil() - require.True(t, check.IfNil(thisIsNil)) -} - -func TestTxCache_ConcurrentMutationAndSelection(t *testing.T) { - cache := newUnconstrainedCacheToTest() - - // Alice will quickly move between two score buckets (chunks) - cheapTransaction := createTxWithParams([]byte("alice-x-o"), "alice", 0, 128, 50000, 100*oneBillion) - expensiveTransaction := createTxWithParams([]byte("alice-x-1"), "alice", 1, 128, 50000, 300*oneBillion) - cache.AddTx(cheapTransaction) - cache.AddTx(expensiveTransaction) - - wg := sync.WaitGroup{} - - // Simulate selection - wg.Add(1) - go func() { - for i := 0; i < 100; i++ { - fmt.Println("Selection", i) - cache.SelectTransactionsWithBandwidth(100, 100, math.MaxUint64) - } - - wg.Done() - }() - - // Simulate add / remove transactions - wg.Add(1) - go func() { - for i := 0; i < 100; i++ { - fmt.Println("Add / remove", i) - cache.Remove([]byte("alice-x-1")) - cache.AddTx(expensiveTransaction) - } - - wg.Done() - }() - - timedOut := waitTimeout(&wg, 1*time.Second) - require.False(t, timedOut, "Timed out. Perhaps deadlock?") -} - -func TestTxCache_TransactionIsAdded_EvenWhenInternalMapsAreInconsistent(t *testing.T) { - cache := newUnconstrainedCacheToTest() - - // Setup inconsistency: transaction already exists in map by hash, but not in map by sender - cache.txByHash.addTx(createTx([]byte("alice-x"), "alice", 42)) - - require.Equal(t, 1, cache.txByHash.backingMap.Count()) - require.True(t, cache.Has([]byte("alice-x"))) - ok, added := cache.AddTx(createTx([]byte("alice-x"), "alice", 42)) - require.True(t, ok) - require.True(t, added) - require.Equal(t, uint64(1), cache.CountSenders()) - require.Equal(t, []string{"alice-x"}, cache.getHashesForSender("alice")) - cache.Clear() - - // Setup inconsistency: transaction already exists in map by sender, but not in map by hash - cache.txListBySender.addTx(createTx([]byte("alice-x"), "alice", 42)) - - require.False(t, cache.Has([]byte("alice-x"))) - ok, added = cache.AddTx(createTx([]byte("alice-x"), "alice", 42)) - require.True(t, ok) - require.True(t, added) - require.Equal(t, uint64(1), cache.CountSenders()) - require.Equal(t, []string{"alice-x"}, cache.getHashesForSender("alice")) - cache.Clear() -} - -func TestTxCache_NoCriticalInconsistency_WhenConcurrentAdditionsAndRemovals(t *testing.T) { - cache := newUnconstrainedCacheToTest() - - // A lot of routines concur to add & remove THE FIRST transaction of a sender - for try := 0; try < 100; try++ { - var wg sync.WaitGroup - - for i := 0; i < 50; i++ { - wg.Add(1) - go func() { - cache.AddTx(createTx([]byte("alice-x"), "alice", 42)) - _ = cache.RemoveTxByHash([]byte("alice-x")) - wg.Done() - }() - } - - wg.Wait() - // In this case, there is the slight chance that: - // go A: add to map by hash - // go B: won't add in map by hash, already there - // go A: add to map by sender - // go A: remove from map by hash - // go A: remove from map by sender and delete empty sender - // go B: add to map by sender - // go B: can't remove from map by hash, not found - // go B: won't remove from map by sender (sender unknown) - - // Therefore, the number of senders could be 0 or 1 - require.Equal(t, 0, 
cache.txByHash.backingMap.Count()) - expectedCountConsistent := 0 - expectedCountSlightlyInconsistent := 1 - actualCount := int(cache.txListBySender.backingMap.Count()) - require.True(t, actualCount == expectedCountConsistent || actualCount == expectedCountSlightlyInconsistent) - - // A further addition works: - cache.AddTx(createTx([]byte("alice-x"), "alice", 42)) - require.True(t, cache.Has([]byte("alice-x"))) - require.Equal(t, []string{"alice-x"}, cache.getHashesForSender("alice")) - } - - cache.Clear() - - // A lot of routines concur to add & remove subsequent transactions of a sender - cache.AddTx(createTx([]byte("alice-w"), "alice", 41)) - - for try := 0; try < 100; try++ { - var wg sync.WaitGroup - - for i := 0; i < 50; i++ { - wg.Add(1) - go func() { - cache.AddTx(createTx([]byte("alice-x"), "alice", 42)) - _ = cache.RemoveTxByHash([]byte("alice-x")) - wg.Done() - }() - } - - wg.Wait() - - // In this case, there is the slight chance that: - // go A: add to map by hash - // go B: won't add in map by hash, already there - // go A: add to map by sender (existing sender/list) - // go A: remove from map by hash - // go A: remove from map by sender - // go B: add to map by sender (existing sender/list) - // go B: can't remove from map by hash, not found - // go B: won't remove from map by sender (sender unknown) - - // Therefore, Alice may have one or two transactions in her list. - require.Equal(t, 1, cache.txByHash.backingMap.Count()) - expectedTxsConsistent := []string{"alice-w"} - expectedTxsSlightlyInconsistent := []string{"alice-w", "alice-x"} - actualTxs := cache.getHashesForSender("alice") - require.True(t, assert.ObjectsAreEqual(expectedTxsConsistent, actualTxs) || assert.ObjectsAreEqual(expectedTxsSlightlyInconsistent, actualTxs)) - - // A further addition works: - cache.AddTx(createTx([]byte("alice-x"), "alice", 42)) - require.True(t, cache.Has([]byte("alice-w"))) - require.True(t, cache.Has([]byte("alice-x"))) - require.Equal(t, []string{"alice-w", "alice-x"}, cache.getHashesForSender("alice")) - } - - cache.Clear() -} - -func newUnconstrainedCacheToTest() *TxCache { - txGasHandler, _ := dummyParams() - cache, err := NewTxCache(ConfigSourceMe{ - Name: "test", - NumChunks: 16, - NumBytesPerSenderThreshold: maxNumBytesPerSenderUpperBound, - CountPerSenderThreshold: math.MaxUint32, - }, txGasHandler) - if err != nil { - panic(fmt.Sprintf("newUnconstrainedCacheToTest(): %s", err)) - } - - return cache -} - -func newCacheToTest(numBytesPerSenderThreshold uint32, countPerSenderThreshold uint32) *TxCache { - txGasHandler, _ := dummyParams() - cache, err := NewTxCache(ConfigSourceMe{ - Name: "test", - NumChunks: 16, - NumBytesPerSenderThreshold: numBytesPerSenderThreshold, - CountPerSenderThreshold: countPerSenderThreshold, - }, txGasHandler) - if err != nil { - panic(fmt.Sprintf("newCacheToTest(): %s", err)) - } - - return cache -} diff --git a/storage/txcache/txListBySenderMap.go b/storage/txcache/txListBySenderMap.go deleted file mode 100644 index 1ff04738d0f..00000000000 --- a/storage/txcache/txListBySenderMap.go +++ /dev/null @@ -1,171 +0,0 @@ -package txcache - -import ( - "sync" - - "github.com/ElrondNetwork/elrond-go-core/core/atomic" - "github.com/ElrondNetwork/elrond-go/storage/txcache/maps" -) - -const numberOfScoreChunks = uint32(100) - -// txListBySenderMap is a map-like structure for holding and accessing transactions by sender -type txListBySenderMap struct { - backingMap *maps.BucketSortedMap - senderConstraints senderConstraints - counter atomic.Counter - 
scoreComputer scoreComputer - txGasHandler TxGasHandler - txFeeHelper feeHelper - mutex sync.Mutex -} - -// newTxListBySenderMap creates a new instance of txListBySenderMap -func newTxListBySenderMap( - nChunksHint uint32, - senderConstraints senderConstraints, - scoreComputer scoreComputer, - txGasHandler TxGasHandler, - txFeeHelper feeHelper, -) *txListBySenderMap { - backingMap := maps.NewBucketSortedMap(nChunksHint, numberOfScoreChunks) - - return &txListBySenderMap{ - backingMap: backingMap, - senderConstraints: senderConstraints, - scoreComputer: scoreComputer, - txGasHandler: txGasHandler, - txFeeHelper: txFeeHelper, - } -} - -// addTx adds a transaction in the map, in the corresponding list (selected by its sender) -func (txMap *txListBySenderMap) addTx(tx *WrappedTransaction) (bool, [][]byte) { - sender := string(tx.Tx.GetSndAddr()) - listForSender := txMap.getOrAddListForSender(sender) - return listForSender.AddTx(tx, txMap.txGasHandler, txMap.txFeeHelper) -} - -// getOrAddListForSender gets or lazily creates a list (using double-checked locking pattern) -func (txMap *txListBySenderMap) getOrAddListForSender(sender string) *txListForSender { - listForSender, ok := txMap.getListForSender(sender) - if ok { - return listForSender - } - - txMap.mutex.Lock() - defer txMap.mutex.Unlock() - - listForSender, ok = txMap.getListForSender(sender) - if ok { - return listForSender - } - - return txMap.addSender(sender) -} - -func (txMap *txListBySenderMap) getListForSender(sender string) (*txListForSender, bool) { - listForSenderUntyped, ok := txMap.backingMap.Get(sender) - if !ok { - return nil, false - } - - listForSender := listForSenderUntyped.(*txListForSender) - return listForSender, true -} - -func (txMap *txListBySenderMap) addSender(sender string) *txListForSender { - listForSender := newTxListForSender(sender, &txMap.senderConstraints, txMap.notifyScoreChange) - - txMap.backingMap.Set(listForSender) - txMap.counter.Increment() - - return listForSender -} - -// This function should only be called in a critical section managed by a "txListForSender" -func (txMap *txListBySenderMap) notifyScoreChange(txList *txListForSender, scoreParams senderScoreParams) { - score := txMap.scoreComputer.computeScore(scoreParams) - txList.setLastComputedScore(score) - txMap.backingMap.NotifyScoreChange(txList, score) -} - -// removeTx removes a transaction from the map -func (txMap *txListBySenderMap) removeTx(tx *WrappedTransaction) bool { - sender := string(tx.Tx.GetSndAddr()) - - listForSender, ok := txMap.getListForSender(sender) - if !ok { - // This happens when a sender whose transactions were selected for processing is removed from cache in the meantime. - // When it comes to removing one of its transactions due to processing (committed / finalized block), the transaction no longer exists in the cache.
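The getOrAddListForSender function above is a textbook instance of the double-checked locking idiom. A generic, self-contained sketch of the pattern, with hypothetical names (lazyMap, getOrCreate) and sync.Map standing in for the concurrent BucketSortedMap:

    // lazyMap sketches double-checked locking: look up without the mutex first
    // (fast path), then re-check under the mutex before creating, so that two
    // concurrent callers cannot both create a value for the same key.
    type lazyMap struct {
    	mu sync.Mutex
    	m  sync.Map
    }

    func (lm *lazyMap) getOrCreate(key string) *list.List {
    	if v, ok := lm.m.Load(key); ok { // fast path: no lock taken
    		return v.(*list.List)
    	}

    	lm.mu.Lock()
    	defer lm.mu.Unlock()

    	// Re-check: another goroutine may have created the value between the
    	// first check and the lock acquisition.
    	if v, ok := lm.m.Load(key); ok {
    		return v.(*list.List)
    	}

    	value := list.New()
    	lm.m.Store(key, value)
    	return value
    }

(sync.Map.LoadOrStore would collapse this into a single call, but it constructs the value even when the key is already present; the explicit double check mirrors the structure used here, where creating a sender list is comparatively expensive.)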
- log.Trace("txListBySenderMap.removeTx() detected slight inconsistency: sender of tx not in cache", "tx", tx.TxHash, "sender", []byte(sender)) - return false - } - - isFound := listForSender.RemoveTx(tx) - isEmpty := listForSender.IsEmpty() - if isEmpty { - txMap.removeSender(sender) - } - - return isFound -} - -func (txMap *txListBySenderMap) removeSender(sender string) bool { - _, removed := txMap.backingMap.Remove(sender) - if removed { - txMap.counter.Decrement() - } - - return removed -} - -// RemoveSendersBulk removes senders, in bulk -func (txMap *txListBySenderMap) RemoveSendersBulk(senders []string) uint32 { - numRemoved := uint32(0) - - for _, senderKey := range senders { - if txMap.removeSender(senderKey) { - numRemoved++ - } - } - - return numRemoved -} - -func (txMap *txListBySenderMap) notifyAccountNonce(accountKey []byte, nonce uint64) { - sender := string(accountKey) - listForSender, ok := txMap.getListForSender(sender) - if !ok { - return - } - - listForSender.notifyAccountNonce(nonce) -} - -func (txMap *txListBySenderMap) getSnapshotAscending() []*txListForSender { - itemsSnapshot := txMap.backingMap.GetSnapshotAscending() - listsSnapshot := make([]*txListForSender, len(itemsSnapshot)) - - for i, item := range itemsSnapshot { - listsSnapshot[i] = item.(*txListForSender) - } - - return listsSnapshot -} - -func (txMap *txListBySenderMap) getSnapshotDescending() []*txListForSender { - itemsSnapshot := txMap.backingMap.GetSnapshotDescending() - listsSnapshot := make([]*txListForSender, len(itemsSnapshot)) - - for i, item := range itemsSnapshot { - listsSnapshot[i] = item.(*txListForSender) - } - - return listsSnapshot -} - -func (txMap *txListBySenderMap) clear() { - txMap.backingMap.Clear() - txMap.counter.Set(0) -} diff --git a/storage/txcache/txListBySenderMap_test.go b/storage/txcache/txListBySenderMap_test.go deleted file mode 100644 index d3393225889..00000000000 --- a/storage/txcache/txListBySenderMap_test.go +++ /dev/null @@ -1,184 +0,0 @@ -package txcache - -import ( - "fmt" - "math" - "sync" - "testing" - - "github.com/stretchr/testify/require" -) - -func TestSendersMap_AddTx_IncrementsCounter(t *testing.T) { - myMap := newSendersMapToTest() - - myMap.addTx(createTx([]byte("a"), "alice", uint64(1))) - myMap.addTx(createTx([]byte("aa"), "alice", uint64(2))) - myMap.addTx(createTx([]byte("b"), "bob", uint64(1))) - - // There are 2 senders - require.Equal(t, int64(2), myMap.counter.Get()) -} - -func TestSendersMap_RemoveTx_AlsoRemovesSenderWhenNoTransactionLeft(t *testing.T) { - myMap := newSendersMapToTest() - - txAlice1 := createTx([]byte("a1"), "alice", uint64(1)) - txAlice2 := createTx([]byte("a2"), "alice", uint64(2)) - txBob := createTx([]byte("b"), "bob", uint64(1)) - - myMap.addTx(txAlice1) - myMap.addTx(txAlice2) - myMap.addTx(txBob) - require.Equal(t, int64(2), myMap.counter.Get()) - require.Equal(t, uint64(2), myMap.testGetListForSender("alice").countTx()) - require.Equal(t, uint64(1), myMap.testGetListForSender("bob").countTx()) - - myMap.removeTx(txAlice1) - require.Equal(t, int64(2), myMap.counter.Get()) - require.Equal(t, uint64(1), myMap.testGetListForSender("alice").countTx()) - require.Equal(t, uint64(1), myMap.testGetListForSender("bob").countTx()) - - myMap.removeTx(txAlice2) - // All alice's transactions have been removed now - require.Equal(t, int64(1), myMap.counter.Get()) - - myMap.removeTx(txBob) - // Also Bob has no more transactions - require.Equal(t, int64(0), myMap.counter.Get()) -} - -func TestSendersMap_RemoveSender(t *testing.T) { - 
myMap := newSendersMapToTest() - - myMap.addTx(createTx([]byte("a"), "alice", uint64(1))) - require.Equal(t, int64(1), myMap.counter.Get()) - - // Bob is unknown - myMap.removeSender("bob") - require.Equal(t, int64(1), myMap.counter.Get()) - - myMap.removeSender("alice") - require.Equal(t, int64(0), myMap.counter.Get()) -} - -func TestSendersMap_RemoveSendersBulk_ConcurrentWithAddition(t *testing.T) { - myMap := newSendersMapToTest() - - var wg sync.WaitGroup - - wg.Add(1) - go func() { - defer wg.Done() - - for i := 0; i < 100; i++ { - numRemoved := myMap.RemoveSendersBulk([]string{"alice"}) - require.LessOrEqual(t, numRemoved, uint32(1)) - - numRemoved = myMap.RemoveSendersBulk([]string{"bob"}) - require.LessOrEqual(t, numRemoved, uint32(1)) - - numRemoved = myMap.RemoveSendersBulk([]string{"carol"}) - require.LessOrEqual(t, numRemoved, uint32(1)) - } - }() - - wg.Add(100) - for i := 0; i < 100; i++ { - go func(i int) { - myMap.addTx(createTx([]byte("a"), "alice", uint64(i))) - myMap.addTx(createTx([]byte("b"), "bob", uint64(i))) - myMap.addTx(createTx([]byte("c"), "carol", uint64(i))) - - wg.Done() - }(i) - } - - wg.Wait() -} - -func TestSendersMap_notifyAccountNonce(t *testing.T) { - myMap := newSendersMapToTest() - - // Discarded notification, since sender not added yet - myMap.notifyAccountNonce([]byte("alice"), 42) - - myMap.addTx(createTx([]byte("tx-42"), "alice", uint64(42))) - alice, _ := myMap.getListForSender("alice") - require.Equal(t, uint64(0), alice.accountNonce.Get()) - require.False(t, alice.accountNonceKnown.IsSet()) - - myMap.notifyAccountNonce([]byte("alice"), 42) - require.Equal(t, uint64(42), alice.accountNonce.Get()) - require.True(t, alice.accountNonceKnown.IsSet()) -} - -func BenchmarkSendersMap_GetSnapshotAscending(b *testing.B) { - if b.N > 10 { - fmt.Println("impractical benchmark: b.N too high") - return - } - - numSenders := 250000 - maps := make([]*txListBySenderMap, b.N) - for i := 0; i < b.N; i++ { - maps[i] = createTxListBySenderMap(numSenders) - } - - b.ResetTimer() - - for i := 0; i < b.N; i++ { - measureWithStopWatch(b, func() { - snapshot := maps[i].getSnapshotAscending() - require.Len(b, snapshot, numSenders) - }) - } -} - -func TestSendersMap_GetSnapshots_NoPanic_IfAlsoConcurrentMutation(t *testing.T) { - myMap := newSendersMapToTest() - - var wg sync.WaitGroup - - for i := 0; i < 100; i++ { - wg.Add(2) - - go func() { - for j := 0; j < 100; j++ { - myMap.getSnapshotAscending() - } - - wg.Done() - }() - - go func() { - for j := 0; j < 1000; j++ { - sender := fmt.Sprintf("Sender-%d", j) - myMap.removeSender(sender) - } - - wg.Done() - }() - } - - wg.Wait() -} - -func createTxListBySenderMap(numSenders int) *txListBySenderMap { - myMap := newSendersMapToTest() - for i := 0; i < numSenders; i++ { - sender := fmt.Sprintf("Sender-%d", i) - hash := createFakeTxHash([]byte(sender), 1) - myMap.addTx(createTx(hash, sender, uint64(1))) - } - - return myMap -} - -func newSendersMapToTest() *txListBySenderMap { - txGasHandler, txFeeHelper := dummyParams() - return newTxListBySenderMap(4, senderConstraints{ - maxNumBytes: math.MaxUint32, - maxNumTxs: math.MaxUint32, - }, &disabledScoreComputer{}, txGasHandler, txFeeHelper) -} diff --git a/storage/txcache/txListForSender.go b/storage/txcache/txListForSender.go deleted file mode 100644 index 07335bf53c2..00000000000 --- a/storage/txcache/txListForSender.go +++ /dev/null @@ -1,412 +0,0 @@ -package txcache - -import ( - "bytes" - "container/list" - "sync" - - "github.com/ElrondNetwork/elrond-go-core/core/atomic" - 
"github.com/ElrondNetwork/elrond-go/storage" - "github.com/ElrondNetwork/elrond-go/storage/txcache/maps" -) - -var _ maps.BucketSortedMapItem = (*txListForSender)(nil) - -// txListForSender represents a sorted list of transactions of a particular sender -type txListForSender struct { - copyDetectedGap bool - lastComputedScore atomic.Uint32 - accountNonceKnown atomic.Flag - sweepable atomic.Flag - copyPreviousNonce uint64 - sender string - items *list.List - copyBatchIndex *list.Element - constraints *senderConstraints - scoreChunk *maps.MapChunk - accountNonce atomic.Uint64 - totalBytes atomic.Counter - totalGas atomic.Counter - totalFeeScore atomic.Counter - numFailedSelections atomic.Counter - onScoreChange scoreChangeCallback - - scoreChunkMutex sync.RWMutex - mutex sync.RWMutex -} - -type scoreChangeCallback func(value *txListForSender, scoreParams senderScoreParams) - -// newTxListForSender creates a new (sorted) list of transactions -func newTxListForSender(sender string, constraints *senderConstraints, onScoreChange scoreChangeCallback) *txListForSender { - return &txListForSender{ - items: list.New(), - sender: sender, - constraints: constraints, - onScoreChange: onScoreChange, - } -} - -// AddTx adds a transaction in sender's list -// This is a "sorted" insert -func (listForSender *txListForSender) AddTx(tx *WrappedTransaction, gasHandler TxGasHandler, txFeeHelper feeHelper) (bool, [][]byte) { - // We don't allow concurrent interceptor goroutines to mutate a given sender's list - listForSender.mutex.Lock() - defer listForSender.mutex.Unlock() - - insertionPlace, err := listForSender.findInsertionPlace(tx) - if err != nil { - return false, nil - } - - if insertionPlace == nil { - listForSender.items.PushFront(tx) - } else { - listForSender.items.InsertAfter(tx, insertionPlace) - } - - listForSender.onAddedTransaction(tx, gasHandler, txFeeHelper) - evicted := listForSender.applySizeConstraints() - listForSender.triggerScoreChange() - return true, evicted -} - -// This function should only be used in critical section (listForSender.mutex) -func (listForSender *txListForSender) applySizeConstraints() [][]byte { - evictedTxHashes := make([][]byte, 0) - - // Iterate back to front - for element := listForSender.items.Back(); element != nil; element = element.Prev() { - if !listForSender.isCapacityExceeded() { - break - } - - listForSender.items.Remove(element) - listForSender.onRemovedListElement(element) - - // Keep track of removed transactions - value := element.Value.(*WrappedTransaction) - evictedTxHashes = append(evictedTxHashes, value.TxHash) - } - - return evictedTxHashes -} - -func (listForSender *txListForSender) isCapacityExceeded() bool { - maxBytes := int64(listForSender.constraints.maxNumBytes) - maxNumTxs := uint64(listForSender.constraints.maxNumTxs) - tooManyBytes := listForSender.totalBytes.Get() > maxBytes - tooManyTxs := listForSender.countTx() > maxNumTxs - - return tooManyBytes || tooManyTxs -} - -func (listForSender *txListForSender) onAddedTransaction(tx *WrappedTransaction, gasHandler TxGasHandler, txFeeHelper feeHelper) { - listForSender.totalBytes.Add(tx.Size) - listForSender.totalGas.Add(int64(estimateTxGas(tx))) - listForSender.totalFeeScore.Add(int64(estimateTxFeeScore(tx, gasHandler, txFeeHelper))) -} - -func (listForSender *txListForSender) triggerScoreChange() { - scoreParams := listForSender.getScoreParams() - listForSender.onScoreChange(listForSender, scoreParams) -} - -// This function should only be used in critical section (listForSender.mutex) -func 
(listForSender *txListForSender) getScoreParams() senderScoreParams { - fee := listForSender.totalFeeScore.GetUint64() - gas := listForSender.totalGas.GetUint64() - count := listForSender.countTx() - - return senderScoreParams{count: count, feeScore: fee, gas: gas} -} - -// This function should only be used in critical section (listForSender.mutex) -func (listForSender *txListForSender) findInsertionPlace(incomingTx *WrappedTransaction) (*list.Element, error) { - incomingNonce := incomingTx.Tx.GetNonce() - incomingGasPrice := incomingTx.Tx.GetGasPrice() - - for element := listForSender.items.Back(); element != nil; element = element.Prev() { - currentTx := element.Value.(*WrappedTransaction) - currentTxNonce := currentTx.Tx.GetNonce() - currentTxGasPrice := currentTx.Tx.GetGasPrice() - - if incomingTx.sameAs(currentTx) { - // The incoming transaction will be discarded - return nil, storage.ErrItemAlreadyInCache - } - - if currentTxNonce == incomingNonce { - if currentTxGasPrice > incomingGasPrice { - // The incoming transaction will be placed right after the existing one, which has the same nonce but a higher price. - // If the nonces are the same, but the incoming gas price is higher or equal, the search loop continues. - return element, nil - } - if currentTxGasPrice == incomingGasPrice { - // The incoming transaction will be placed right after the existing one, which has the same nonce and the same price - // (but a different hash, because of some other fields like receiver, value or data). - // This deterministically orders transactions having the same nonce and gas price, by hash. - if bytes.Compare(currentTx.TxHash, incomingTx.TxHash) < 0 { - return element, nil - } - } - } - - if currentTxNonce < incomingNonce { - // We've found the first transaction with a lower nonce than the incoming one, - // thus the incoming transaction will be placed right after this one. - return element, nil - } - } - - // The incoming transaction will be inserted at the head of the list.
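Taken together, the branches above define a total order over a sender's transactions. A hedged restatement as a standalone predicate; precedes is illustrative only, not a function in this file:

    // precedes reports whether a should come before b in a sender's list:
    // ascending nonce, then descending gas price, then ascending hash bytes
    // as a deterministic tie-breaker (matching the bytes.Compare branch above).
    func precedes(a, b *WrappedTransaction) bool {
    	if a.Tx.GetNonce() != b.Tx.GetNonce() {
    		return a.Tx.GetNonce() < b.Tx.GetNonce()
    	}
    	if a.Tx.GetGasPrice() != b.Tx.GetGasPrice() {
    		return a.Tx.GetGasPrice() > b.Tx.GetGasPrice()
    	}
    	return bytes.Compare(a.TxHash, b.TxHash) < 0
    }

This is exactly the order exercised by TestListForSender_AddTx_Sorts and TestListForSender_AddTx_SortsCorrectlyWhenSameNonceSamePrice further below.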
- return nil, nil -} - -// RemoveTx removes a transaction from the sender's list -func (listForSender *txListForSender) RemoveTx(tx *WrappedTransaction) bool { - // We don't allow concurrent interceptor goroutines to mutate a given sender's list - listForSender.mutex.Lock() - defer listForSender.mutex.Unlock() - - marker := listForSender.findListElementWithTx(tx) - isFound := marker != nil - if isFound { - listForSender.items.Remove(marker) - listForSender.onRemovedListElement(marker) - listForSender.triggerScoreChange() - } - - return isFound -} - -func (listForSender *txListForSender) onRemovedListElement(element *list.Element) { - value := element.Value.(*WrappedTransaction) - - listForSender.totalBytes.Subtract(value.Size) - listForSender.totalGas.Subtract(int64(estimateTxGas(value))) - listForSender.totalFeeScore.Subtract(int64(value.TxFeeScoreNormalized)) -} - -// This function should only be used in critical section (listForSender.mutex) -func (listForSender *txListForSender) findListElementWithTx(txToFind *WrappedTransaction) *list.Element { - txToFindHash := txToFind.TxHash - txToFindNonce := txToFind.Tx.GetNonce() - - for element := listForSender.items.Front(); element != nil; element = element.Next() { - value := element.Value.(*WrappedTransaction) - - if bytes.Equal(value.TxHash, txToFindHash) { - return element - } - - // Optimization: stop search at this point, since the list is sorted by nonce - if value.Tx.GetNonce() > txToFindNonce { - break - } - } - - return nil -} - -// IsEmpty checks whether the list is empty -func (listForSender *txListForSender) IsEmpty() bool { - return listForSender.countTxWithLock() == 0 -} - -// selectBatchTo copies a batch (usually small) of transactions into a destination slice, limited in both gas bandwidth and number of transactions -// It also updates the internal state used for copy operations -func (listForSender *txListForSender) selectBatchTo(isFirstBatch bool, destination []*WrappedTransaction, batchSize int, bandwidth uint64) batchSelectionJournal { - // We can't read from multiple goroutines at the same time - // And we can't mutate the sender's list while reading it - listForSender.mutex.Lock() - defer listForSender.mutex.Unlock() - - journal := batchSelectionJournal{} - - // Reset the internal state used for copy operations - if isFirstBatch { - hasInitialGap := listForSender.verifyInitialGapOnSelectionStart() - - listForSender.copyBatchIndex = listForSender.items.Front() - listForSender.copyPreviousNonce = 0 - listForSender.copyDetectedGap = hasInitialGap - - journal.isFirstBatch = true - journal.hasInitialGap = hasInitialGap - } - - element := listForSender.copyBatchIndex - availableSpace := len(destination) - detectedGap := listForSender.copyDetectedGap - previousNonce := listForSender.copyPreviousNonce - - // If a nonce gap is detected, no transaction is returned in this read. - // There is an exception though: if this is the first read operation for the sender in the current selection process and the sender is in the grace period, - // then one transaction will be returned. But subsequent reads for this sender will return nothing.
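A worked example of the gap and grace-period behaviour described just above, written as comments since it annotates this exact step; the numbers come from TestListForSender_SelectBatchTo_WhenInitialGap further below:

    // Suppose the account nonce is known to be 1 and the sender's queued nonces start at 10:
    //   - selection #1, first batch:   gap detected, nothing copied, numFailedSelections = 1
    //   - selection #1, later batches: still nothing copied; the failure is not counted again
    //   - selection #2, first batch:   numFailedSelections reaches 2, the grace period begins,
    //                                  batchSize is forced to 1 and a single transaction is copied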
- if detectedGap { - if isFirstBatch && listForSender.isInGracePeriod() { - journal.isGracePeriod = true - batchSize = 1 - } else { - batchSize = 0 - } - } - - copiedBandwidth := uint64(0) - lastTxGasLimit := uint64(0) - copied := 0 - for ; ; copied, copiedBandwidth = copied+1, copiedBandwidth+lastTxGasLimit { - if element == nil || copied == batchSize || copied == availableSpace || copiedBandwidth >= bandwidth { - break - } - - value := element.Value.(*WrappedTransaction) - txNonce := value.Tx.GetNonce() - lastTxGasLimit = value.Tx.GetGasLimit() - - if previousNonce > 0 && txNonce > previousNonce+1 { - listForSender.copyDetectedGap = true - journal.hasMiddleGap = true - break - } - - destination[copied] = value - element = element.Next() - previousNonce = txNonce - } - - listForSender.copyBatchIndex = element - listForSender.copyPreviousNonce = previousNonce - journal.copied = copied - return journal -} - -// getTxHashes returns the hashes of transactions in the list -func (listForSender *txListForSender) getTxHashes() [][]byte { - listForSender.mutex.RLock() - defer listForSender.mutex.RUnlock() - - result := make([][]byte, 0, listForSender.countTx()) - - for element := listForSender.items.Front(); element != nil; element = element.Next() { - value := element.Value.(*WrappedTransaction) - result = append(result, value.TxHash) - } - - return result -} - -// This function should only be used in critical section (listForSender.mutex) -func (listForSender *txListForSender) countTx() uint64 { - return uint64(listForSender.items.Len()) -} - -func (listForSender *txListForSender) countTxWithLock() uint64 { - listForSender.mutex.RLock() - defer listForSender.mutex.RUnlock() - return uint64(listForSender.items.Len()) -} - -func approximatelyCountTxInLists(lists []*txListForSender) uint64 { - count := uint64(0) - - for _, listForSender := range lists { - count += listForSender.countTxWithLock() - } - - return count -} - -// notifyAccountNonce does not update the "numFailedSelections" counter, -// since the notification comes at a time when we cannot actually detect whether the initial gap still exists or it was resolved. 
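The gap predicate referenced above is small enough to restate. A minimal sketch mirroring the hasInitialGap function below; hasGapExample is a hypothetical helper, not part of the file (the account nonce itself arrives via NotifyAccountNonce):

    // A gap exists only when the account nonce is known and the lowest queued
    // nonce exceeds it: accountNonce = 41 with a first queued nonce of 42 is a
    // gap, while a first queued nonce of 41 is not.
    func hasGapExample(nonceKnown bool, accountNonce, lowestQueuedNonce uint64) bool {
    	if !nonceKnown {
    		return false
    	}
    	return lowestQueuedNonce > accountNonce
    }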
-func (listForSender *txListForSender) notifyAccountNonce(nonce uint64) { - listForSender.accountNonce.Set(nonce) - _ = listForSender.accountNonceKnown.SetReturningPrevious() -} - -// This function should only be used in critical section (listForSender.mutex) -func (listForSender *txListForSender) verifyInitialGapOnSelectionStart() bool { - hasInitialGap := listForSender.hasInitialGap() - - if hasInitialGap { - listForSender.numFailedSelections.Increment() - - if listForSender.isGracePeriodExceeded() { - _ = listForSender.sweepable.SetReturningPrevious() - } - } else { - listForSender.numFailedSelections.Reset() - } - - return hasInitialGap -} - -// hasInitialGap should only be called at tx selection time, since only then we can detect initial gaps with certainty -// This function should only be used in critical section (listForSender.mutex) -func (listForSender *txListForSender) hasInitialGap() bool { - accountNonceKnown := listForSender.accountNonceKnown.IsSet() - if !accountNonceKnown { - return false - } - - firstTx := listForSender.getLowestNonceTx() - if firstTx == nil { - return false - } - - firstTxNonce := firstTx.Tx.GetNonce() - accountNonce := listForSender.accountNonce.Get() - hasGap := firstTxNonce > accountNonce - return hasGap -} - -// This function should only be used in critical section (listForSender.mutex) -func (listForSender *txListForSender) getLowestNonceTx() *WrappedTransaction { - front := listForSender.items.Front() - if front == nil { - return nil - } - - value := front.Value.(*WrappedTransaction) - return value -} - -// isInGracePeriod returns whether the sender is in the grace period due to a number of failed selections -func (listForSender *txListForSender) isInGracePeriod() bool { - numFailedSelections := listForSender.numFailedSelections.Get() - return numFailedSelections >= senderGracePeriodLowerBound && numFailedSelections <= senderGracePeriodUpperBound -} - -func (listForSender *txListForSender) isGracePeriodExceeded() bool { - numFailedSelections := listForSender.numFailedSelections.Get() - return numFailedSelections > senderGracePeriodUpperBound -} - -func (listForSender *txListForSender) getLastComputedScore() uint32 { - return listForSender.lastComputedScore.Get() -} - -func (listForSender *txListForSender) setLastComputedScore(score uint32) { - listForSender.lastComputedScore.Set(score) -} - -// GetKey returns the key -func (listForSender *txListForSender) GetKey() string { - return listForSender.sender -} - -// GetScoreChunk returns the score chunk the sender is currently in -func (listForSender *txListForSender) GetScoreChunk() *maps.MapChunk { - listForSender.scoreChunkMutex.RLock() - defer listForSender.scoreChunkMutex.RUnlock() - - return listForSender.scoreChunk -} - -// SetScoreChunk sets the score chunk the sender is currently in -func (listForSender *txListForSender) SetScoreChunk(scoreChunk *maps.MapChunk) { - listForSender.scoreChunkMutex.Lock() - listForSender.scoreChunk = scoreChunk - listForSender.scoreChunkMutex.Unlock() -} diff --git a/storage/txcache/txListForSender_test.go b/storage/txcache/txListForSender_test.go deleted file mode 100644 index 6dc44c5f1af..00000000000 --- a/storage/txcache/txListForSender_test.go +++ /dev/null @@ -1,443 +0,0 @@ -package txcache - -import ( - "math" - "testing" - - "github.com/ElrondNetwork/elrond-go-core/data/transaction" - "github.com/ElrondNetwork/elrond-go/testscommon/txcachemocks" - "github.com/stretchr/testify/require" -) - -func TestListForSender_AddTx_Sorts(t *testing.T) { - list := 
newUnconstrainedListToTest() - txGasHandler, txFeeHelper := dummyParams() - - list.AddTx(createTx([]byte("a"), ".", 1), txGasHandler, txFeeHelper) - list.AddTx(createTx([]byte("c"), ".", 3), txGasHandler, txFeeHelper) - list.AddTx(createTx([]byte("d"), ".", 4), txGasHandler, txFeeHelper) - list.AddTx(createTx([]byte("b"), ".", 2), txGasHandler, txFeeHelper) - - require.Equal(t, []string{"a", "b", "c", "d"}, list.getTxHashesAsStrings()) -} - -func TestListForSender_AddTx_GivesPriorityToHigherGas(t *testing.T) { - list := newUnconstrainedListToTest() - txGasHandler, txFeeHelper := dummyParams() - - list.AddTx(createTxWithParams([]byte("a"), ".", 1, 128, 42, 42), txGasHandler, txFeeHelper) - list.AddTx(createTxWithParams([]byte("b"), ".", 3, 128, 42, 100), txGasHandler, txFeeHelper) - list.AddTx(createTxWithParams([]byte("c"), ".", 3, 128, 42, 99), txGasHandler, txFeeHelper) - list.AddTx(createTxWithParams([]byte("d"), ".", 2, 128, 42, 42), txGasHandler, txFeeHelper) - list.AddTx(createTxWithParams([]byte("e"), ".", 3, 128, 42, 101), txGasHandler, txFeeHelper) - - require.Equal(t, []string{"a", "d", "e", "b", "c"}, list.getTxHashesAsStrings()) -} - -func TestListForSender_AddTx_SortsCorrectlyWhenSameNonceSamePrice(t *testing.T) { - list := newUnconstrainedListToTest() - txGasHandler, txFeeHelper := dummyParams() - - list.AddTx(createTxWithParams([]byte("a"), ".", 1, 128, 42, 42), txGasHandler, txFeeHelper) - list.AddTx(createTxWithParams([]byte("b"), ".", 3, 128, 42, 100), txGasHandler, txFeeHelper) - list.AddTx(createTxWithParams([]byte("c"), ".", 3, 128, 42, 100), txGasHandler, txFeeHelper) - list.AddTx(createTxWithParams([]byte("d"), ".", 3, 128, 42, 98), txGasHandler, txFeeHelper) - list.AddTx(createTxWithParams([]byte("e"), ".", 3, 128, 42, 101), txGasHandler, txFeeHelper) - list.AddTx(createTxWithParams([]byte("f"), ".", 2, 128, 42, 42), txGasHandler, txFeeHelper) - list.AddTx(createTxWithParams([]byte("g"), ".", 3, 128, 42, 99), txGasHandler, txFeeHelper) - - // In case of same-nonce, same-price transactions, the order is decided by the transaction hash (a deterministic tie-break) - require.Equal(t, []string{"a", "f", "e", "b", "c", "g", "d"}, list.getTxHashesAsStrings()) -} - -func TestListForSender_AddTx_IgnoresDuplicates(t *testing.T) { - list := newUnconstrainedListToTest() - txGasHandler, txFeeHelper := dummyParams() - - added, _ := list.AddTx(createTx([]byte("tx1"), ".", 1), txGasHandler, txFeeHelper) - require.True(t, added) - added, _ = list.AddTx(createTx([]byte("tx2"), ".", 2), txGasHandler, txFeeHelper) - require.True(t, added) - added, _ = list.AddTx(createTx([]byte("tx3"), ".", 3), txGasHandler, txFeeHelper) - require.True(t, added) - added, _ = list.AddTx(createTx([]byte("tx2"), ".", 2), txGasHandler, txFeeHelper) - require.False(t, added) -} - -func TestListForSender_AddTx_AppliesSizeConstraintsForNumTransactions(t *testing.T) { - list := newListToTest(math.MaxUint32, 3) - txGasHandler, txFeeHelper := dummyParams() - - list.AddTx(createTx([]byte("tx1"), ".", 1), txGasHandler, txFeeHelper) - list.AddTx(createTx([]byte("tx5"), ".", 5), txGasHandler, txFeeHelper) - list.AddTx(createTx([]byte("tx4"), ".", 4), txGasHandler, txFeeHelper) - list.AddTx(createTx([]byte("tx2"), ".", 2), txGasHandler, txFeeHelper) - require.Equal(t, []string{"tx1", "tx2", "tx4"}, list.getTxHashesAsStrings()) - - _, evicted := list.AddTx(createTx([]byte("tx3"), ".", 3), txGasHandler, txFeeHelper) - require.Equal(t, []string{"tx1", "tx2", "tx3"}, list.getTxHashesAsStrings()) - require.Equal(t, []string{"tx4"}, hashesAsStrings(evicted)) - - // 
Gives priority to higher gas - though undesirably to some extent, "tx3" is evicted - _, evicted = list.AddTx(createTxWithParams([]byte("tx2++"), ".", 2, 128, 42, 42), txGasHandler, txFeeHelper) - require.Equal(t, []string{"tx1", "tx2++", "tx2"}, list.getTxHashesAsStrings()) - require.Equal(t, []string{"tx3"}, hashesAsStrings(evicted)) - - // Though undesirably to some extent, "tx3++" is added, then evicted - _, evicted = list.AddTx(createTxWithParams([]byte("tx3++"), ".", 3, 128, 42, 42), txGasHandler, txFeeHelper) - require.Equal(t, []string{"tx1", "tx2++", "tx2"}, list.getTxHashesAsStrings()) - require.Equal(t, []string{"tx3++"}, hashesAsStrings(evicted)) -} - -func TestListForSender_AddTx_AppliesSizeConstraintsForNumBytes(t *testing.T) { - list := newListToTest(1024, math.MaxUint32) - txGasHandler, txFeeHelper := dummyParams() - - list.AddTx(createTxWithParams([]byte("tx1"), ".", 1, 128, 42, 42), txGasHandler, txFeeHelper) - list.AddTx(createTxWithParams([]byte("tx2"), ".", 2, 512, 42, 42), txGasHandler, txFeeHelper) - list.AddTx(createTxWithParams([]byte("tx3"), ".", 3, 256, 42, 42), txGasHandler, txFeeHelper) - _, evicted := list.AddTx(createTxWithParams([]byte("tx5"), ".", 4, 256, 42, 42), txGasHandler, txFeeHelper) - require.Equal(t, []string{"tx1", "tx2", "tx3"}, list.getTxHashesAsStrings()) - require.Equal(t, []string{"tx5"}, hashesAsStrings(evicted)) - - _, evicted = list.AddTx(createTxWithParams([]byte("tx5--"), ".", 4, 128, 42, 42), txGasHandler, txFeeHelper) - require.Equal(t, []string{"tx1", "tx2", "tx3", "tx5--"}, list.getTxHashesAsStrings()) - require.Equal(t, []string{}, hashesAsStrings(evicted)) - - _, evicted = list.AddTx(createTxWithParams([]byte("tx4"), ".", 4, 128, 42, 42), txGasHandler, txFeeHelper) - require.Equal(t, []string{"tx1", "tx2", "tx3", "tx4"}, list.getTxHashesAsStrings()) - require.Equal(t, []string{"tx5--"}, hashesAsStrings(evicted)) - - // Gives priority to higher gas - though undesirably to some extent, "tx4" is evicted - _, evicted = list.AddTx(createTxWithParams([]byte("tx3++"), ".", 3, 256, 42, 100), txGasHandler, txFeeHelper) - require.Equal(t, []string{"tx1", "tx2", "tx3++", "tx3"}, list.getTxHashesAsStrings()) - require.Equal(t, []string{"tx4"}, hashesAsStrings(evicted)) -} - -func TestListForSender_findTx(t *testing.T) { - list := newUnconstrainedListToTest() - txGasHandler, txFeeHelper := dummyParams() - - txA := createTx([]byte("A"), ".", 41) - txANewer := createTx([]byte("ANewer"), ".", 41) - txB := createTx([]byte("B"), ".", 42) - txD := createTx([]byte("none"), ".", 43) - list.AddTx(txA, txGasHandler, txFeeHelper) - list.AddTx(txANewer, txGasHandler, txFeeHelper) - list.AddTx(txB, txGasHandler, txFeeHelper) - - elementWithA := list.findListElementWithTx(txA) - elementWithANewer := list.findListElementWithTx(txANewer) - elementWithB := list.findListElementWithTx(txB) - noElementWithD := list.findListElementWithTx(txD) - - require.NotNil(t, elementWithA) - require.NotNil(t, elementWithANewer) - require.NotNil(t, elementWithB) - - require.Equal(t, txA, elementWithA.Value.(*WrappedTransaction)) - require.Equal(t, txANewer, elementWithANewer.Value.(*WrappedTransaction)) - require.Equal(t, txB, elementWithB.Value.(*WrappedTransaction)) - require.Nil(t, noElementWithD) -} - -func TestListForSender_findTx_CoverNonceComparisonOptimization(t *testing.T) { - list := newUnconstrainedListToTest() - txGasHandler, txFeeHelper := dummyParams() - list.AddTx(createTx([]byte("A"), ".", 42), txGasHandler, txFeeHelper) - - // Find one with a lower nonce, not 
added to cache
-	noElement := list.findListElementWithTx(createTx(nil, ".", 41))
-	require.Nil(t, noElement)
-}
-
-func TestListForSender_RemoveTransaction(t *testing.T) {
-	list := newUnconstrainedListToTest()
-	tx := createTx([]byte("a"), ".", 1)
-	txGasHandler, txFeeHelper := dummyParams()
-
-	list.AddTx(tx, txGasHandler, txFeeHelper)
-	require.Equal(t, 1, list.items.Len())
-
-	list.RemoveTx(tx)
-	require.Equal(t, 0, list.items.Len())
-}
-
-func TestListForSender_RemoveTransaction_NoPanicWhenTxMissing(t *testing.T) {
-	list := newUnconstrainedListToTest()
-	tx := createTx([]byte(""), ".", 1)
-
-	list.RemoveTx(tx)
-	require.Equal(t, 0, list.items.Len())
-}
-
-func TestListForSender_SelectBatchTo(t *testing.T) {
-	list := newUnconstrainedListToTest()
-	txGasHandler, txFeeHelper := dummyParams()
-
-	for index := 0; index < 100; index++ {
-		list.AddTx(createTx([]byte{byte(index)}, ".", uint64(index)), txGasHandler, txFeeHelper)
-	}
-
-	destination := make([]*WrappedTransaction, 1000)
-
-	// First batch
-	journal := list.selectBatchTo(true, destination, 50, math.MaxUint64)
-	require.Equal(t, 50, journal.copied)
-	require.NotNil(t, destination[49])
-	require.Nil(t, destination[50])
-
-	// Second batch
-	journal = list.selectBatchTo(false, destination[50:], 50, math.MaxUint64)
-	require.Equal(t, 50, journal.copied)
-	require.NotNil(t, destination[99])
-
-	// No third batch
-	journal = list.selectBatchTo(false, destination, 50, math.MaxUint64)
-	require.Equal(t, 0, journal.copied)
-
-	// Restart copy
-	journal = list.selectBatchTo(true, destination, 12345, math.MaxUint64)
-	require.Equal(t, 100, journal.copied)
-}
-
-func TestListForSender_SelectBatchToWithLimitedGasBandwidth(t *testing.T) {
-	list := newUnconstrainedListToTest()
-	txGasHandler, txFeeHelper := dummyParams()
-
-	for index := 0; index < 40; index++ {
-		wtx := createTx([]byte{byte(index)}, ".", uint64(index))
-		tx, _ := wtx.Tx.(*transaction.Transaction)
-		tx.GasLimit = 1000000
-		list.AddTx(wtx, txGasHandler, txFeeHelper)
-	}
-
-	destination := make([]*WrappedTransaction, 1000)
-
-	// First batch
-	journal := list.selectBatchTo(true, destination, 50, 500000)
-	require.Equal(t, 1, journal.copied)
-	require.NotNil(t, destination[0])
-	require.Nil(t, destination[1])
-
-	// Second batch
-	journal = list.selectBatchTo(false, destination[1:], 50, 20000000)
-	require.Equal(t, 20, journal.copied)
-	require.NotNil(t, destination[20])
-	require.Nil(t, destination[21])
-
-	// Third batch
-	journal = list.selectBatchTo(false, destination[21:], 20, math.MaxUint64)
-	require.Equal(t, 19, journal.copied)
-
-	// Restart copy
-	journal = list.selectBatchTo(true, destination[41:], 12345, math.MaxUint64)
-	require.Equal(t, 40, journal.copied)
-}
-
-func TestListForSender_SelectBatchTo_NoPanicWhenCornerCases(t *testing.T) {
-	list := newUnconstrainedListToTest()
-	txGasHandler, txFeeHelper := dummyParams()
-
-	for index := 0; index < 100; index++ {
-		list.AddTx(createTx([]byte{byte(index)}, ".", uint64(index)), txGasHandler, txFeeHelper)
-	}
-
-	// When empty destination
-	destination := make([]*WrappedTransaction, 0)
-	journal := list.selectBatchTo(true, destination, 10, math.MaxUint64)
-	require.Equal(t, 0, journal.copied)
-
-	// When small destination
-	destination = make([]*WrappedTransaction, 5)
-	journal = list.selectBatchTo(false, destination, 10, math.MaxUint64)
-	require.Equal(t, 5, journal.copied)
-}
-
-func TestListForSender_SelectBatchTo_WhenInitialGap(t *testing.T) {
-	list := newUnconstrainedListToTest()
-	txGasHandler, txFeeHelper :=
dummyParams() - list.notifyAccountNonce(1) - - for index := 10; index < 20; index++ { - list.AddTx(createTx([]byte{byte(index)}, ".", uint64(index)), txGasHandler, txFeeHelper) - } - - destination := make([]*WrappedTransaction, 1000) - - // First batch of selection, first failure - journal := list.selectBatchTo(true, destination, 50, math.MaxUint64) - require.Equal(t, 0, journal.copied) - require.Nil(t, destination[0]) - require.Equal(t, int64(1), list.numFailedSelections.Get()) - - // Second batch of selection, don't count failure again - journal = list.selectBatchTo(false, destination, 50, math.MaxUint64) - require.Equal(t, 0, journal.copied) - require.Nil(t, destination[0]) - require.Equal(t, int64(1), list.numFailedSelections.Get()) - - // First batch of another selection, second failure, enters grace period - journal = list.selectBatchTo(true, destination, 50, math.MaxUint64) - require.Equal(t, 1, journal.copied) - require.NotNil(t, destination[0]) - require.Nil(t, destination[1]) - require.Equal(t, int64(2), list.numFailedSelections.Get()) -} - -func TestListForSender_SelectBatchTo_WhenGracePeriodWithGapResolve(t *testing.T) { - list := newUnconstrainedListToTest() - txGasHandler, txFeeHelper := dummyParams() - list.notifyAccountNonce(1) - - for index := 2; index < 20; index++ { - list.AddTx(createTx([]byte{byte(index)}, ".", uint64(index)), txGasHandler, txFeeHelper) - } - - destination := make([]*WrappedTransaction, 1000) - - // Try a number of selections with failure, reach close to grace period - for i := 1; i < senderGracePeriodLowerBound; i++ { - journal := list.selectBatchTo(true, destination, math.MaxInt32, math.MaxUint64) - require.Equal(t, 0, journal.copied) - require.Equal(t, int64(i), list.numFailedSelections.Get()) - } - - // Try selection again. 
Failure will move the sender to grace period and return 1 transaction - journal := list.selectBatchTo(true, destination, math.MaxInt32, math.MaxUint64) - require.Equal(t, 1, journal.copied) - require.Equal(t, int64(senderGracePeriodLowerBound), list.numFailedSelections.Get()) - require.False(t, list.sweepable.IsSet()) - - // Now resolve the gap - list.AddTx(createTx([]byte("resolving-tx"), ".", 1), txGasHandler, txFeeHelper) - // Selection will be successful - journal = list.selectBatchTo(true, destination, math.MaxInt32, math.MaxUint64) - require.Equal(t, 19, journal.copied) - require.Equal(t, int64(0), list.numFailedSelections.Get()) - require.False(t, list.sweepable.IsSet()) -} - -func TestListForSender_SelectBatchTo_WhenGracePeriodWithNoGapResolve(t *testing.T) { - list := newUnconstrainedListToTest() - txGasHandler, txFeeHelper := dummyParams() - list.notifyAccountNonce(1) - - for index := 2; index < 20; index++ { - list.AddTx(createTx([]byte{byte(index)}, ".", uint64(index)), txGasHandler, txFeeHelper) - } - - destination := make([]*WrappedTransaction, 1000) - - // Try a number of selections with failure, reach close to grace period - for i := 1; i < senderGracePeriodLowerBound; i++ { - journal := list.selectBatchTo(true, destination, math.MaxInt32, math.MaxUint64) - require.Equal(t, 0, journal.copied) - require.Equal(t, int64(i), list.numFailedSelections.Get()) - } - - // Try a number of selections with failure, within the grace period - for i := senderGracePeriodLowerBound; i <= senderGracePeriodUpperBound; i++ { - journal := list.selectBatchTo(true, destination, math.MaxInt32, math.MaxUint64) - require.Equal(t, 1, journal.copied) - require.Equal(t, int64(i), list.numFailedSelections.Get()) - } - - // Grace period exceeded now - journal := list.selectBatchTo(true, destination, math.MaxInt32, math.MaxUint64) - require.Equal(t, 0, journal.copied) - require.Equal(t, int64(senderGracePeriodUpperBound+1), list.numFailedSelections.Get()) - require.True(t, list.sweepable.IsSet()) -} - -func TestListForSender_NotifyAccountNonce(t *testing.T) { - list := newUnconstrainedListToTest() - - require.Equal(t, uint64(0), list.accountNonce.Get()) - require.False(t, list.accountNonceKnown.IsSet()) - - list.notifyAccountNonce(42) - - require.Equal(t, uint64(42), list.accountNonce.Get()) - require.True(t, list.accountNonceKnown.IsSet()) -} - -func TestListForSender_hasInitialGap(t *testing.T) { - list := newUnconstrainedListToTest() - list.notifyAccountNonce(42) - txGasHandler, txFeeHelper := dummyParams() - - // No transaction, no gap - require.False(t, list.hasInitialGap()) - // One gap - list.AddTx(createTx([]byte("tx-43"), ".", 43), txGasHandler, txFeeHelper) - require.True(t, list.hasInitialGap()) - // Resolve gap - list.AddTx(createTx([]byte("tx-42"), ".", 42), txGasHandler, txFeeHelper) - require.False(t, list.hasInitialGap()) -} - -func TestListForSender_getTxHashes(t *testing.T) { - list := newUnconstrainedListToTest() - require.Len(t, list.getTxHashes(), 0) - txGasHandler, txFeeHelper := dummyParams() - - list.AddTx(createTx([]byte("A"), ".", 1), txGasHandler, txFeeHelper) - require.Len(t, list.getTxHashes(), 1) - - list.AddTx(createTx([]byte("B"), ".", 2), txGasHandler, txFeeHelper) - list.AddTx(createTx([]byte("C"), ".", 3), txGasHandler, txFeeHelper) - require.Len(t, list.getTxHashes(), 3) -} - -func TestListForSender_DetectRaceConditions(t *testing.T) { - list := newUnconstrainedListToTest() - txGasHandler, txFeeHelper := dummyParams() - - go func() { - // These are called concurrently 
with addition: during eviction, during removal etc. - approximatelyCountTxInLists([]*txListForSender{list}) - list.IsEmpty() - }() - - go func() { - list.AddTx(createTx([]byte("test"), ".", 42), txGasHandler, txFeeHelper) - }() -} - -func dummyParamsWithGasPriceAndGasLimit(minGasPrice uint64, minGasLimit uint64) (TxGasHandler, feeHelper) { - minPrice := minGasPrice - divisor := uint64(100) - minPriceProcessing := minGasPrice / divisor - txFeeHelper := newFeeComputationHelper(minPrice, minGasLimit, minPriceProcessing) - txGasHandler := &txcachemocks.TxGasHandlerMock{ - MinimumGasMove: minGasLimit, - MinimumGasPrice: minPrice, - GasProcessingDivisor: divisor, - } - return txGasHandler, txFeeHelper -} - -func dummyParamsWithGasPrice(minGasPrice uint64) (TxGasHandler, feeHelper) { - return dummyParamsWithGasPriceAndGasLimit(minGasPrice, 50000) -} - -func dummyParams() (TxGasHandler, feeHelper) { - minPrice := uint64(1000000000) - minGasLimit := uint64(50000) - return dummyParamsWithGasPriceAndGasLimit(minPrice, minGasLimit) -} - -func newUnconstrainedListToTest() *txListForSender { - return newTxListForSender(".", &senderConstraints{ - maxNumBytes: math.MaxUint32, - maxNumTxs: math.MaxUint32, - }, func(_ *txListForSender, _ senderScoreParams) {}) -} - -func newListToTest(maxNumBytes uint32, maxNumTxs uint32) *txListForSender { - return newTxListForSender(".", &senderConstraints{ - maxNumBytes: maxNumBytes, - maxNumTxs: maxNumTxs, - }, func(_ *txListForSender, _ senderScoreParams) {}) -} diff --git a/storage/txcache/txcache.go b/storage/txcache/txcache.go new file mode 100644 index 00000000000..05fc953b7ba --- /dev/null +++ b/storage/txcache/txcache.go @@ -0,0 +1,44 @@ +package txcache + +import ( + "github.com/ElrondNetwork/elrond-go-storage/txcache" +) + +// WrappedTransaction contains a transaction, its hash and extra information +type WrappedTransaction = txcache.WrappedTransaction + +// TxGasHandler handles a transaction gas and gas cost +type TxGasHandler = txcache.TxGasHandler + +// ForEachTransaction is an iterator callback +type ForEachTransaction = txcache.ForEachTransaction + +// ConfigDestinationMe holds cache configuration +type ConfigDestinationMe = txcache.ConfigDestinationMe + +// ConfigSourceMe holds cache configuration +type ConfigSourceMe = txcache.ConfigSourceMe + +// TxCache represents a cache-like structure (it has a fixed capacity and implements an eviction mechanism) for holding transactions +type TxCache = txcache.TxCache + +// DisabledCache represents a disabled cache +type DisabledCache = txcache.DisabledCache + +// CrossTxCache holds cross-shard transactions (where destination == me) +type CrossTxCache = txcache.CrossTxCache + +// NewTxCache creates a new transaction cache +func NewTxCache(config ConfigSourceMe, txGasHandler TxGasHandler) (*TxCache, error) { + return txcache.NewTxCache(config, txGasHandler) +} + +// NewDisabledCache creates a new disabled cache +func NewDisabledCache() *DisabledCache { + return txcache.NewDisabledCache() +} + +// NewCrossTxCache creates a new transactions cache +func NewCrossTxCache(config ConfigDestinationMe) (*CrossTxCache, error) { + return txcache.NewCrossTxCache(config) +} diff --git a/storage/txcache/wrappedTransaction.go b/storage/txcache/wrappedTransaction.go deleted file mode 100644 index 4491492efc3..00000000000 --- a/storage/txcache/wrappedTransaction.go +++ /dev/null @@ -1,72 +0,0 @@ -package txcache - -import ( - "bytes" - - "github.com/ElrondNetwork/elrond-go-core/data" -) - -const processFeeFactor = float64(0.8) // 80% 
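The 80% factor above discounts the processing-fee component of a transaction's score: computeProcessingGasPriceAdjustment, further down in this deleted file, multiplies minGasPriceFactor by processFeeFactor and divides by how far the paid processing gas price sits above the minimum, short-circuiting to 1 when minGasPriceFactor <= 2. A standalone sketch of that arithmetic follows; the function name and the plain uint64 parameters (in place of the TxGasHandler and feeHelper abstractions) are hypothetical, and the expected outputs mirror the deleted Test_computeProcessingGasPriceAdjustment.

package main

import "fmt"

const processFeeFactor = 0.8 // 80%, same constant as above

// adjustment sketches computeProcessingGasPriceAdjustment: the further a
// transaction's processing gas price sits above the minimum, the smaller
// the multiplier applied to its processing-fee component.
func adjustment(minGasPriceFactor, gasPriceForProcessing, minGasPriceForProcessing uint64) uint64 {
	if minGasPriceFactor <= 2 {
		return 1
	}
	actualPriceFactor := float64(1)
	if minGasPriceForProcessing != 0 {
		actualPriceFactor = float64(gasPriceForProcessing) / float64(minGasPriceForProcessing)
	}
	return uint64(float64(minGasPriceFactor) * processFeeFactor / actualPriceFactor)
}

func main() {
	// Paying exactly the minimum processing price: 100 * 0.8 / 1.0 = 80
	fmt.Println(adjustment(100, 1_000_000_000, 1_000_000_000))
	// Paying 1.5x the minimum: 100 * 0.8 / 1.5 = 53 (truncated)
	fmt.Println(adjustment(100, 1_500_000_000, 1_000_000_000))
}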
- -// WrappedTransaction contains a transaction, its hash and extra information -type WrappedTransaction struct { - Tx data.TransactionHandler - TxHash []byte - SenderShardID uint32 - ReceiverShardID uint32 - Size int64 - TxFeeScoreNormalized uint64 -} - -func (wrappedTx *WrappedTransaction) sameAs(another *WrappedTransaction) bool { - return bytes.Equal(wrappedTx.TxHash, another.TxHash) -} - -// estimateTxGas returns an approximation for the necessary computation units (gas units) -func estimateTxGas(tx *WrappedTransaction) uint64 { - gasLimit := tx.Tx.GetGasLimit() - return gasLimit -} - -// estimateTxFeeScore returns a normalized approximation for the cost of a transaction -func estimateTxFeeScore(tx *WrappedTransaction, txGasHandler TxGasHandler, txFeeHelper feeHelper) uint64 { - moveGas, processGas := txGasHandler.SplitTxGasInCategories(tx.Tx) - - normalizedMoveGas := moveGas >> txFeeHelper.gasLimitShift() - normalizedProcessGas := processGas >> txFeeHelper.gasLimitShift() - - normalizedGasPriceMove := txGasHandler.GasPriceForMove(tx.Tx) >> txFeeHelper.gasPriceShift() - normalizedGasPriceProcess := normalizeGasPriceProcessing(tx, txGasHandler, txFeeHelper) - - normalizedFeeMove := normalizedMoveGas * normalizedGasPriceMove - normalizedFeeProcess := normalizedProcessGas * normalizedGasPriceProcess - - adjustmentFactor := computeProcessingGasPriceAdjustment(tx, txGasHandler, txFeeHelper) - - tx.TxFeeScoreNormalized = normalizedFeeMove + normalizedFeeProcess*adjustmentFactor - - return tx.TxFeeScoreNormalized -} - -func normalizeGasPriceProcessing(tx *WrappedTransaction, txGasHandler TxGasHandler, txFeeHelper feeHelper) uint64 { - return txGasHandler.GasPriceForProcessing(tx.Tx) >> txFeeHelper.gasPriceShift() -} - -func computeProcessingGasPriceAdjustment( - tx *WrappedTransaction, - txGasHandler TxGasHandler, - txFeeHelper feeHelper, -) uint64 { - minPriceFactor := txFeeHelper.minGasPriceFactor() - - if minPriceFactor <= 2 { - return 1 - } - - actualPriceFactor := float64(1) - if txGasHandler.MinGasPriceForProcessing() != 0 { - actualPriceFactor = float64(txGasHandler.GasPriceForProcessing(tx.Tx)) / float64(txGasHandler.MinGasPriceForProcessing()) - } - - return uint64(float64(txFeeHelper.minGasPriceFactor()) * processFeeFactor / actualPriceFactor) -} diff --git a/storage/txcache/wrappedTransaction_test.go b/storage/txcache/wrappedTransaction_test.go deleted file mode 100644 index 9a543711501..00000000000 --- a/storage/txcache/wrappedTransaction_test.go +++ /dev/null @@ -1,75 +0,0 @@ -package txcache - -import ( - "testing" - - "github.com/ElrondNetwork/elrond-go/testscommon/txcachemocks" - "github.com/stretchr/testify/require" -) - -func Test_estimateTxFeeScore(t *testing.T) { - txGasHandler, txFeeHelper := dummyParamsWithGasPrice(100 * oneBillion) - A := createTxWithParams([]byte("a"), "a", 1, 200, 50000, 100*oneBillion) - B := createTxWithParams([]byte("b"), "b", 1, 200, 50000000, 100*oneBillion) - C := createTxWithParams([]byte("C"), "c", 1, 200, 1500000000, 100*oneBillion) - - scoreA := estimateTxFeeScore(A, txGasHandler, txFeeHelper) - scoreB := estimateTxFeeScore(B, txGasHandler, txFeeHelper) - scoreC := estimateTxFeeScore(C, txGasHandler, txFeeHelper) - require.Equal(t, uint64(8940), scoreA) - require.Equal(t, uint64(8940), A.TxFeeScoreNormalized) - require.Equal(t, uint64(6837580), scoreB) - require.Equal(t, uint64(6837580), B.TxFeeScoreNormalized) - require.Equal(t, uint64(205079820), scoreC) - require.Equal(t, uint64(205079820), C.TxFeeScoreNormalized) -} - -func 
Test_normalizeGasPriceProcessing(t *testing.T) { - txGasHandler, txFeeHelper := dummyParamsWithGasPriceAndDivisor(100*oneBillion, 100) - A := createTxWithParams([]byte("A"), "a", 1, 200, 1500000000, 100*oneBillion) - normalizedGasPriceProcess := normalizeGasPriceProcessing(A, txGasHandler, txFeeHelper) - require.Equal(t, uint64(7), normalizedGasPriceProcess) - - txGasHandler, txFeeHelper = dummyParamsWithGasPriceAndDivisor(100*oneBillion, 50) - normalizedGasPriceProcess = normalizeGasPriceProcessing(A, txGasHandler, txFeeHelper) - require.Equal(t, uint64(14), normalizedGasPriceProcess) - - txGasHandler, txFeeHelper = dummyParamsWithGasPriceAndDivisor(100*oneBillion, 1) - normalizedGasPriceProcess = normalizeGasPriceProcessing(A, txGasHandler, txFeeHelper) - require.Equal(t, uint64(745), normalizedGasPriceProcess) - - txGasHandler, txFeeHelper = dummyParamsWithGasPriceAndDivisor(100000, 100) - A = createTxWithParams([]byte("A"), "a", 1, 200, 1500000000, 100000) - normalizedGasPriceProcess = normalizeGasPriceProcessing(A, txGasHandler, txFeeHelper) - require.Equal(t, uint64(7), normalizedGasPriceProcess) -} - -func Test_computeProcessingGasPriceAdjustment(t *testing.T) { - txGasHandler, txFeeHelper := dummyParamsWithGasPriceAndDivisor(100*oneBillion, 100) - A := createTxWithParams([]byte("A"), "a", 1, 200, 1500000000, 100*oneBillion) - adjustment := computeProcessingGasPriceAdjustment(A, txGasHandler, txFeeHelper) - require.Equal(t, uint64(80), adjustment) - - A = createTxWithParams([]byte("A"), "a", 1, 200, 1500000000, 150*oneBillion) - adjustment = computeProcessingGasPriceAdjustment(A, txGasHandler, txFeeHelper) - expectedAdjustment := float64(100) * processFeeFactor / float64(1.5) - require.Equal(t, uint64(expectedAdjustment), adjustment) - - A = createTxWithParams([]byte("A"), "a", 1, 200, 1500000000, 110*oneBillion) - adjustment = computeProcessingGasPriceAdjustment(A, txGasHandler, txFeeHelper) - expectedAdjustment = float64(100) * processFeeFactor / float64(1.1) - require.Equal(t, uint64(expectedAdjustment), adjustment) -} - -func dummyParamsWithGasPriceAndDivisor(minGasPrice, processingPriceDivisor uint64) (TxGasHandler, feeHelper) { - minPrice := minGasPrice - minPriceProcessing := minGasPrice / processingPriceDivisor - minGasLimit := uint64(50000) - txFeeHelper := newFeeComputationHelper(minPrice, minGasLimit, minPriceProcessing) - txGasHandler := &txcachemocks.TxGasHandlerMock{ - MinimumGasMove: minGasLimit, - MinimumGasPrice: minPrice, - GasProcessingDivisor: processingPriceDivisor, - } - return txGasHandler, txFeeHelper -} diff --git a/testscommon/components/components.go b/testscommon/components/components.go new file mode 100644 index 00000000000..09899a6da75 --- /dev/null +++ b/testscommon/components/components.go @@ -0,0 +1,913 @@ +package components + +import ( + "fmt" + "math/big" + "testing" + "time" + + arwenConfig "github.com/ElrondNetwork/arwen-wasm-vm/v1_4/config" + "github.com/ElrondNetwork/elrond-go-core/data/block" + "github.com/ElrondNetwork/elrond-go-core/data/endProcess" + "github.com/ElrondNetwork/elrond-go-core/data/indexer" + crypto "github.com/ElrondNetwork/elrond-go-crypto" + logger "github.com/ElrondNetwork/elrond-go-logger" + "github.com/ElrondNetwork/elrond-go/common" + commonFactory "github.com/ElrondNetwork/elrond-go/common/factory" + "github.com/ElrondNetwork/elrond-go/config" + "github.com/ElrondNetwork/elrond-go/consensus/spos" + "github.com/ElrondNetwork/elrond-go/epochStart/bootstrap/disabled" + "github.com/ElrondNetwork/elrond-go/factory" + 
bootstrapComp "github.com/ElrondNetwork/elrond-go/factory/bootstrap" + consensusComp "github.com/ElrondNetwork/elrond-go/factory/consensus" + coreComp "github.com/ElrondNetwork/elrond-go/factory/core" + cryptoComp "github.com/ElrondNetwork/elrond-go/factory/crypto" + dataComp "github.com/ElrondNetwork/elrond-go/factory/data" + heartbeatComp "github.com/ElrondNetwork/elrond-go/factory/heartbeat" + "github.com/ElrondNetwork/elrond-go/factory/mock" + networkComp "github.com/ElrondNetwork/elrond-go/factory/network" + processComp "github.com/ElrondNetwork/elrond-go/factory/processing" + stateComp "github.com/ElrondNetwork/elrond-go/factory/state" + statusComp "github.com/ElrondNetwork/elrond-go/factory/status" + "github.com/ElrondNetwork/elrond-go/genesis" + "github.com/ElrondNetwork/elrond-go/genesis/data" + "github.com/ElrondNetwork/elrond-go/p2p" + p2pConfig "github.com/ElrondNetwork/elrond-go/p2p/config" + "github.com/ElrondNetwork/elrond-go/sharding" + "github.com/ElrondNetwork/elrond-go/sharding/nodesCoordinator" + "github.com/ElrondNetwork/elrond-go/state" + "github.com/ElrondNetwork/elrond-go/testscommon" + "github.com/ElrondNetwork/elrond-go/testscommon/dblookupext" + "github.com/ElrondNetwork/elrond-go/testscommon/hashingMocks" + "github.com/ElrondNetwork/elrond-go/testscommon/shardingMocks" + statusHandlerMock "github.com/ElrondNetwork/elrond-go/testscommon/statusHandler" + "github.com/ElrondNetwork/elrond-go/trie" + trieFactory "github.com/ElrondNetwork/elrond-go/trie/factory" + "github.com/ElrondNetwork/elrond-go/trie/hashesHolder" + "github.com/stretchr/testify/require" +) + +var log = logger.GetOrCreate("componentsMock") + +// TestHasher - +const TestHasher = "blake2b" + +// TestMarshalizer - +const TestMarshalizer = "json" + +// SignedBlocksThreshold - +const SignedBlocksThreshold = 0.025 + +// ConsecutiveMissedBlocksPenalty - +const ConsecutiveMissedBlocksPenalty = 1.1 + +// DummyPk - +const DummyPk = "629e1245577afb7717ccb46b6ff3649bdd6a1311514ad4a7695da13f801cc277ee24e730a7fa8aa6c612159b4328db17" + + "35692d0bded3a2264ba621d6bda47a981d60e17dd306d608e0875a0ba19639fb0844661f519472a175ca9ed2f33fbe16" + +// DummySk - +const DummySk = "cea01c0bf060187d90394802ff223078e47527dc8aa33a922744fb1d06029c4b" + +// LoadKeysFunc - +type LoadKeysFunc func(string, int) ([]byte, string, error) + +// GetCoreArgs - +func GetCoreArgs() coreComp.CoreComponentsFactoryArgs { + return coreComp.CoreComponentsFactoryArgs{ + Config: config.Config{ + EpochStartConfig: GetEpochStartConfig(), + PublicKeyPeerId: config.CacheConfig{ + Type: "LRU", + Capacity: 5000, + Shards: 16, + }, + PublicKeyShardId: config.CacheConfig{ + Type: "LRU", + Capacity: 5000, + Shards: 16, + }, + PeerIdShardId: config.CacheConfig{ + Type: "LRU", + Capacity: 5000, + Shards: 16, + }, + PeerHonesty: config.CacheConfig{ + Type: "LRU", + Capacity: 5000, + Shards: 16, + }, + GeneralSettings: config.GeneralSettingsConfig{ + ChainID: "undefined", + MinTransactionVersion: 1, + GenesisMaxNumberOfShards: 3, + }, + Marshalizer: config.MarshalizerConfig{ + Type: TestMarshalizer, + SizeCheckDelta: 0, + }, + Hasher: config.TypeConfig{ + Type: TestHasher, + }, + VmMarshalizer: config.TypeConfig{ + Type: TestMarshalizer, + }, + TxSignMarshalizer: config.TypeConfig{ + Type: TestMarshalizer, + }, + TxSignHasher: config.TypeConfig{ + Type: TestHasher, + }, + AddressPubkeyConverter: config.PubkeyConfig{ + Length: 32, + Type: "bech32", + SignatureLength: 0, + }, + ValidatorPubkeyConverter: config.PubkeyConfig{ + Length: 96, + Type: "hex", + 
SignatureLength: 48, + }, + Consensus: config.ConsensusConfig{ + Type: "bls", + }, + ValidatorStatistics: config.ValidatorStatisticsConfig{ + CacheRefreshIntervalInSec: uint32(100), + }, + SoftwareVersionConfig: config.SoftwareVersionConfig{ + PollingIntervalInMinutes: 30, + }, + Versions: config.VersionsConfig{ + DefaultVersion: "1", + VersionsByEpochs: nil, + Cache: config.CacheConfig{ + Type: "LRU", + Capacity: 1000, + Shards: 1, + }, + }, + PeersRatingConfig: config.PeersRatingConfig{ + TopRatedCacheCapacity: 1000, + BadRatedCacheCapacity: 1000, + }, + PoolsCleanersConfig: config.PoolsCleanersConfig{ + MaxRoundsToKeepUnprocessedMiniBlocks: 50, + MaxRoundsToKeepUnprocessedTransactions: 50, + }, + Hardfork: config.HardforkConfig{ + PublicKeyToListenFrom: DummyPk, + }, + HeartbeatV2: config.HeartbeatV2Config{ + HeartbeatExpiryTimespanInSec: 10, + }, + }, + ConfigPathsHolder: config.ConfigurationPathsHolder{ + GasScheduleDirectoryName: "../../cmd/node/config/gasSchedules", + }, + RatingsConfig: CreateDummyRatingsConfig(), + EconomicsConfig: CreateDummyEconomicsConfig(), + NodesFilename: "../mock/testdata/nodesSetupMock.json", + WorkingDirectory: "home", + ChanStopNodeProcess: make(chan endProcess.ArgEndProcess), + StatusHandlersFactory: &statusHandlerMock.StatusHandlersFactoryMock{}, + EpochConfig: config.EpochConfig{ + GasSchedule: config.GasScheduleConfig{ + GasScheduleByEpochs: []config.GasScheduleByEpochs{ + { + StartEpoch: 0, + FileName: "gasScheduleV1.toml", + }, + }, + }, + }, + RoundConfig: config.RoundConfig{ + RoundActivations: map[string]config.ActivationRoundByName{ + "Example": { + Round: "18446744073709551615", + }, + }, + }, + } +} + +// GetConsensusArgs - +func GetConsensusArgs(shardCoordinator sharding.Coordinator) consensusComp.ConsensusComponentsFactoryArgs { + coreComponents := GetCoreComponents() + networkComponents := GetNetworkComponents() + stateComponents := GetStateComponents(coreComponents, shardCoordinator) + cryptoComponents := GetCryptoComponents(coreComponents) + dataComponents := GetDataComponents(coreComponents, shardCoordinator) + processComponents := GetProcessComponents( + shardCoordinator, + coreComponents, + networkComponents, + dataComponents, + cryptoComponents, + stateComponents, + ) + statusComponents := GetStatusComponents( + coreComponents, + networkComponents, + dataComponents, + stateComponents, + shardCoordinator, + processComponents.NodesCoordinator(), + ) + + args := spos.ScheduledProcessorWrapperArgs{ + SyncTimer: coreComponents.SyncTimer(), + Processor: processComponents.BlockProcessor(), + RoundTimeDurationHandler: coreComponents.RoundHandler(), + } + scheduledProcessor, _ := spos.NewScheduledProcessorWrapper(args) + + return consensusComp.ConsensusComponentsFactoryArgs{ + Config: testscommon.GetGeneralConfig(), + BootstrapRoundIndex: 0, + CoreComponents: coreComponents, + NetworkComponents: networkComponents, + CryptoComponents: cryptoComponents, + DataComponents: dataComponents, + ProcessComponents: processComponents, + StateComponents: stateComponents, + StatusComponents: statusComponents, + ScheduledProcessor: scheduledProcessor, + } +} + +// GetCryptoArgs - +func GetCryptoArgs(coreComponents factory.CoreComponentsHolder) cryptoComp.CryptoComponentsFactoryArgs { + args := cryptoComp.CryptoComponentsFactoryArgs{ + Config: config.Config{ + GeneralSettings: config.GeneralSettingsConfig{ChainID: "undefined"}, + Consensus: config.ConsensusConfig{ + Type: "bls", + }, + MultisigHasher: config.TypeConfig{Type: "blake2b"}, + 
PublicKeyPIDSignature: config.CacheConfig{ + Capacity: 1000, + Type: "LRU", + }, + Hasher: config.TypeConfig{Type: "blake2b"}, + }, + SkIndex: 0, + ValidatorKeyPemFileName: "validatorKey.pem", + CoreComponentsHolder: coreComponents, + ActivateBLSPubKeyMessageVerification: false, + KeyLoader: &mock.KeyLoaderStub{ + LoadKeyCalled: DummyLoadSkPkFromPemFile([]byte(DummySk), DummyPk, nil), + }, + EnableEpochs: config.EnableEpochs{ + BLSMultiSignerEnableEpoch: []config.MultiSignerConfig{{EnableEpoch: 0, Type: "no-KOSK"}}, + }, + } + + return args +} + +// GetDataArgs - +func GetDataArgs(coreComponents factory.CoreComponentsHolder, shardCoordinator sharding.Coordinator) dataComp.DataComponentsFactoryArgs { + return dataComp.DataComponentsFactoryArgs{ + Config: testscommon.GetGeneralConfig(), + PrefsConfig: config.PreferencesConfig{ + FullArchive: false, + }, + ShardCoordinator: shardCoordinator, + Core: coreComponents, + EpochStartNotifier: &mock.EpochStartNotifierStub{}, + CurrentEpoch: 0, + CreateTrieEpochRootHashStorer: false, + } +} + +// GetCoreComponents - +func GetCoreComponents() factory.CoreComponentsHolder { + coreArgs := GetCoreArgs() + coreComponentsFactory, _ := coreComp.NewCoreComponentsFactory(coreArgs) + coreComponents, err := coreComp.NewManagedCoreComponents(coreComponentsFactory) + if err != nil { + fmt.Println("getCoreComponents NewManagedCoreComponents", "error", err.Error()) + return nil + } + err = coreComponents.Create() + if err != nil { + fmt.Println("getCoreComponents Create", "error", err.Error()) + } + return coreComponents +} + +// GetHeartbeatFactoryArgs - +func GetHeartbeatFactoryArgs(shardCoordinator sharding.Coordinator) heartbeatComp.HeartbeatComponentsFactoryArgs { + coreComponents := GetCoreComponents() + networkComponents := GetNetworkComponents() + dataComponents := GetDataComponents(coreComponents, shardCoordinator) + cryptoComponents := GetCryptoComponents(coreComponents) + stateComponents := GetStateComponents(coreComponents, shardCoordinator) + processComponents := GetProcessComponents( + shardCoordinator, + coreComponents, + networkComponents, + dataComponents, + cryptoComponents, + stateComponents, + ) + + return heartbeatComp.HeartbeatComponentsFactoryArgs{ + Config: config.Config{ + Heartbeat: config.HeartbeatConfig{ + MinTimeToWaitBetweenBroadcastsInSec: 20, + MaxTimeToWaitBetweenBroadcastsInSec: 25, + HeartbeatRefreshIntervalInSec: 60, + HideInactiveValidatorIntervalInSec: 3600, + DurationToConsiderUnresponsiveInSec: 60, + HeartbeatStorage: config.StorageConfig{ + Cache: config.CacheConfig{ + Capacity: 10000, + Type: "LRU", + Shards: 1, + }, + DB: config.DBConfig{ + FilePath: "HeartbeatStorage", + Type: "MemoryDB", + BatchDelaySeconds: 30, + MaxBatchSize: 6, + MaxOpenFiles: 10, + }, + }, + }, + ValidatorStatistics: config.ValidatorStatisticsConfig{ + CacheRefreshIntervalInSec: uint32(100), + }, + }, + Prefs: config.Preferences{}, + AppVersion: "test", + GenesisTime: time.Time{}, + RedundancyHandler: &mock.RedundancyHandlerStub{ + ObserverPrivateKeyCalled: func() crypto.PrivateKey { + return &mock.PrivateKeyStub{ + GeneratePublicHandler: func() crypto.PublicKey { + return &mock.PublicKeyMock{} + }, + } + }, + }, + CoreComponents: coreComponents, + DataComponents: dataComponents, + NetworkComponents: networkComponents, + CryptoComponents: cryptoComponents, + ProcessComponents: processComponents, + } +} + +// GetNetworkFactoryArgs - +func GetNetworkFactoryArgs() networkComp.NetworkComponentsFactoryArgs { + p2pCfg := p2pConfig.P2PConfig{ + Node: 
p2pConfig.NodeConfig{ + Port: "0", + }, + KadDhtPeerDiscovery: p2pConfig.KadDhtPeerDiscoveryConfig{ + Enabled: false, + Type: "optimized", + RefreshIntervalInSec: 10, + ProtocolID: "erd/kad/1.0.0", + InitialPeerList: []string{"peer0", "peer1"}, + BucketSize: 10, + RoutingTableRefreshIntervalInSec: 5, + }, + Sharding: p2pConfig.ShardingConfig{ + TargetPeerCount: 10, + MaxIntraShardValidators: 10, + MaxCrossShardValidators: 10, + MaxIntraShardObservers: 10, + MaxCrossShardObservers: 10, + MaxSeeders: 2, + Type: "NilListSharder", + AdditionalConnections: p2pConfig.AdditionalConnectionsConfig{ + MaxFullHistoryObservers: 10, + }, + }, + } + + mainConfig := config.Config{ + PeerHonesty: config.CacheConfig{ + Type: "LRU", + Capacity: 5000, + Shards: 16, + }, + Debug: config.DebugConfig{ + Antiflood: config.AntifloodDebugConfig{ + Enabled: true, + CacheSize: 100, + IntervalAutoPrintInSeconds: 1, + }, + }, + PeersRatingConfig: config.PeersRatingConfig{ + TopRatedCacheCapacity: 1000, + BadRatedCacheCapacity: 1000, + }, + PoolsCleanersConfig: config.PoolsCleanersConfig{ + MaxRoundsToKeepUnprocessedMiniBlocks: 50, + MaxRoundsToKeepUnprocessedTransactions: 50, + }, + } + + appStatusHandler := statusHandlerMock.NewAppStatusHandlerMock() + + return networkComp.NetworkComponentsFactoryArgs{ + P2pConfig: p2pCfg, + MainConfig: mainConfig, + StatusHandler: appStatusHandler, + Marshalizer: &mock.MarshalizerMock{}, + RatingsConfig: config.RatingsConfig{ + General: config.General{}, + ShardChain: config.ShardChain{}, + MetaChain: config.MetaChain{}, + PeerHonesty: config.PeerHonestyConfig{ + DecayCoefficient: 0.9779, + DecayUpdateIntervalInSeconds: 10, + MaxScore: 100, + MinScore: -100, + BadPeerThreshold: -80, + UnitValue: 1.0, + }, + }, + Syncer: &p2p.LocalSyncTimer{}, + NodeOperationMode: p2p.NormalOperation, + } +} + +func getNewTrieStorageManagerArgs() trie.NewTrieStorageManagerArgs { + return trie.NewTrieStorageManagerArgs{ + MainStorer: testscommon.CreateMemUnit(), + CheckpointsStorer: testscommon.CreateMemUnit(), + Marshalizer: &mock.MarshalizerMock{}, + Hasher: &hashingMocks.HasherMock{}, + GeneralConfig: config.TrieStorageManagerConfig{SnapshotsGoroutineNum: 1}, + CheckpointHashesHolder: hashesHolder.NewCheckpointHashesHolder(10, 32), + IdleProvider: &testscommon.ProcessStatusHandlerStub{}, + } +} + +// GetStateFactoryArgs - +func GetStateFactoryArgs(coreComponents factory.CoreComponentsHolder, shardCoordinator sharding.Coordinator) stateComp.StateComponentsFactoryArgs { + tsm, _ := trie.NewTrieStorageManager(getNewTrieStorageManagerArgs()) + storageManagerUser, _ := trie.NewTrieStorageManagerWithoutPruning(tsm) + tsm, _ = trie.NewTrieStorageManager(getNewTrieStorageManagerArgs()) + storageManagerPeer, _ := trie.NewTrieStorageManagerWithoutPruning(tsm) + + trieStorageManagers := make(map[string]common.StorageManager) + trieStorageManagers[trieFactory.UserAccountTrie] = storageManagerUser + trieStorageManagers[trieFactory.PeerAccountTrie] = storageManagerPeer + + triesHolder := state.NewDataTriesHolder() + trieUsers, _ := trie.NewTrie(storageManagerUser, coreComponents.InternalMarshalizer(), coreComponents.Hasher(), 5) + triePeers, _ := trie.NewTrie(storageManagerPeer, coreComponents.InternalMarshalizer(), coreComponents.Hasher(), 5) + triesHolder.Put([]byte(trieFactory.UserAccountTrie), trieUsers) + triesHolder.Put([]byte(trieFactory.PeerAccountTrie), triePeers) + + stateComponentsFactoryArgs := stateComp.StateComponentsFactoryArgs{ + Config: GetGeneralConfig(), + ShardCoordinator: shardCoordinator, + 
Core: coreComponents, + StorageService: disabled.NewChainStorer(), + ProcessingMode: common.Normal, + ChainHandler: &testscommon.ChainHandlerStub{}, + } + + return stateComponentsFactoryArgs +} + +// GetProcessComponentsFactoryArgs - +func GetProcessComponentsFactoryArgs(shardCoordinator sharding.Coordinator) processComp.ProcessComponentsFactoryArgs { + coreComponents := GetCoreComponents() + networkComponents := GetNetworkComponents() + dataComponents := GetDataComponents(coreComponents, shardCoordinator) + cryptoComponents := GetCryptoComponents(coreComponents) + stateComponents := GetStateComponents(coreComponents, shardCoordinator) + processArgs := GetProcessArgs( + shardCoordinator, + coreComponents, + dataComponents, + cryptoComponents, + stateComponents, + networkComponents, + ) + return processArgs +} + +//GetBootStrapFactoryArgs - +func GetBootStrapFactoryArgs() bootstrapComp.BootstrapComponentsFactoryArgs { + coreComponents := GetCoreComponents() + networkComponents := GetNetworkComponents() + cryptoComponents := GetCryptoComponents(coreComponents) + return bootstrapComp.BootstrapComponentsFactoryArgs{ + Config: testscommon.GetGeneralConfig(), + WorkingDir: "home", + CoreComponents: coreComponents, + CryptoComponents: cryptoComponents, + NetworkComponents: networkComponents, + PrefConfig: config.Preferences{ + Preferences: config.PreferencesConfig{ + DestinationShardAsObserver: "0", + ConnectionWatcherType: "print", + }, + }, + ImportDbConfig: config.ImportDbConfig{ + IsImportDBMode: false, + }, + RoundConfig: config.RoundConfig{}, + FlagsConfig: config.ContextFlagsConfig{ + ForceStartFromNetwork: false, + }, + } +} + +// GetProcessArgs - +func GetProcessArgs( + shardCoordinator sharding.Coordinator, + coreComponents factory.CoreComponentsHolder, + dataComponents factory.DataComponentsHolder, + cryptoComponents factory.CryptoComponentsHolder, + stateComponents factory.StateComponentsHolder, + networkComponents factory.NetworkComponentsHolder, +) processComp.ProcessComponentsFactoryArgs { + + gasSchedule := arwenConfig.MakeGasMapForTests() + // TODO: check if these could be initialized by MakeGasMapForTests() + gasSchedule["BuiltInCost"]["SaveUserName"] = 1 + gasSchedule["BuiltInCost"]["SaveKeyValue"] = 1 + gasSchedule["BuiltInCost"]["ESDTTransfer"] = 1 + gasSchedule["BuiltInCost"]["ESDTBurn"] = 1 + gasSchedule[common.MetaChainSystemSCsCost] = FillGasMapMetaChainSystemSCsCosts(1) + + gasScheduleNotifier := &testscommon.GasScheduleNotifierMock{ + GasSchedule: gasSchedule, + } + + nc := &shardingMocks.NodesCoordinatorMock{} + statusComponents := GetStatusComponents( + coreComponents, + networkComponents, + dataComponents, + stateComponents, + shardCoordinator, + nc, + ) + + bootstrapComponentsFactoryArgs := GetBootStrapFactoryArgs() + bootstrapComponentsFactory, _ := bootstrapComp.NewBootstrapComponentsFactory(bootstrapComponentsFactoryArgs) + bootstrapComponents, _ := bootstrapComp.NewTestManagedBootstrapComponents(bootstrapComponentsFactory) + _ = bootstrapComponents.Create() + _ = bootstrapComponents.SetShardCoordinator(shardCoordinator) + + return processComp.ProcessComponentsFactoryArgs{ + Config: testscommon.GetGeneralConfig(), + AccountsParser: &mock.AccountsParserStub{ + InitialAccountsCalled: func() []genesis.InitialAccountHandler { + addrConverter, _ := commonFactory.NewPubkeyConverter(config.PubkeyConfig{ + Length: 32, + Type: "bech32", + SignatureLength: 0, + }) + balance := big.NewInt(0) + acc1 := data.InitialAccount{ + Address: 
"erd1ulhw20j7jvgfgak5p05kv667k5k9f320sgef5ayxkt9784ql0zssrzyhjp", + Supply: big.NewInt(0).Mul(big.NewInt(2500000000), big.NewInt(1000000000000)), + Balance: balance, + StakingValue: big.NewInt(0).Mul(big.NewInt(2500000000), big.NewInt(1000000000000)), + Delegation: &data.DelegationData{ + Address: "", + Value: big.NewInt(0), + }, + } + acc2 := data.InitialAccount{ + Address: "erd17c4fs6mz2aa2hcvva2jfxdsrdknu4220496jmswer9njznt22eds0rxlr4", + Supply: big.NewInt(0).Mul(big.NewInt(2500000000), big.NewInt(1000000000000)), + Balance: balance, + StakingValue: big.NewInt(0).Mul(big.NewInt(2500000000), big.NewInt(1000000000000)), + Delegation: &data.DelegationData{ + Address: "", + Value: big.NewInt(0), + }, + } + acc3 := data.InitialAccount{ + Address: "erd10d2gufxesrp8g409tzxljlaefhs0rsgjle3l7nq38de59txxt8csj54cd3", + Supply: big.NewInt(0).Mul(big.NewInt(2500000000), big.NewInt(1000000000000)), + Balance: balance, + StakingValue: big.NewInt(0).Mul(big.NewInt(2500000000), big.NewInt(1000000000000)), + Delegation: &data.DelegationData{ + Address: "", + Value: big.NewInt(0), + }, + } + + acc1Bytes, _ := addrConverter.Decode(acc1.Address) + acc1.SetAddressBytes(acc1Bytes) + acc2Bytes, _ := addrConverter.Decode(acc2.Address) + acc2.SetAddressBytes(acc2Bytes) + acc3Bytes, _ := addrConverter.Decode(acc3.Address) + acc3.SetAddressBytes(acc3Bytes) + initialAccounts := []genesis.InitialAccountHandler{&acc1, &acc2, &acc3} + + return initialAccounts + }, + GenerateInitialTransactionsCalled: func(shardCoordinator sharding.Coordinator, initialIndexingData map[uint32]*genesis.IndexingData) ([]*block.MiniBlock, map[uint32]*indexer.Pool, error) { + txsPool := make(map[uint32]*indexer.Pool) + for i := uint32(0); i < shardCoordinator.NumberOfShards(); i++ { + txsPool[i] = &indexer.Pool{} + } + + return make([]*block.MiniBlock, 4), txsPool, nil + }, + }, + SmartContractParser: &mock.SmartContractParserStub{}, + GasSchedule: gasScheduleNotifier, + NodesCoordinator: nc, + Data: dataComponents, + CoreData: coreComponents, + Crypto: cryptoComponents, + State: stateComponents, + Network: networkComponents, + StatusComponents: statusComponents, + BootstrapComponents: bootstrapComponents, + RequestedItemsHandler: &testscommon.RequestedItemsHandlerStub{}, + WhiteListHandler: &testscommon.WhiteListHandlerStub{}, + WhiteListerVerifiedTxs: &testscommon.WhiteListHandlerStub{}, + MaxRating: 100, + ImportStartHandler: &testscommon.ImportStartHandlerStub{}, + SystemSCConfig: &config.SystemSmartContractsConfig{ + ESDTSystemSCConfig: config.ESDTSystemSCConfig{ + BaseIssuingCost: "1000", + OwnerAddress: "erd1fpkcgel4gcmh8zqqdt043yfcn5tyx8373kg6q2qmkxzu4dqamc0swts65c", + }, + GovernanceSystemSCConfig: config.GovernanceSystemSCConfig{ + V1: config.GovernanceSystemSCConfigV1{ + ProposalCost: "500", + NumNodes: 100, + MinQuorum: 50, + MinPassThreshold: 50, + MinVetoThreshold: 50, + }, + Active: config.GovernanceSystemSCConfigActive{ + ProposalCost: "500", + MinQuorum: "50", + MinPassThreshold: "50", + MinVetoThreshold: "50", + }, + FirstWhitelistedAddress: "erd1vxy22x0fj4zv6hktmydg8vpfh6euv02cz4yg0aaws6rrad5a5awqgqky80", + }, + StakingSystemSCConfig: config.StakingSystemSCConfig{ + GenesisNodePrice: "2500000000000000000000", + MinStakeValue: "1", + UnJailValue: "1", + MinStepValue: "1", + UnBondPeriod: 0, + NumRoundsWithoutBleed: 0, + MaximumPercentageToBleed: 0, + BleedPercentagePerRound: 0, + MaxNumberOfNodesForStake: 10, + ActivateBLSPubKeyMessageVerification: false, + MinUnstakeTokensValue: "1", + }, + 
DelegationManagerSystemSCConfig: config.DelegationManagerSystemSCConfig{ + MinCreationDeposit: "100", + MinStakeAmount: "100", + ConfigChangeAddress: "erd1vxy22x0fj4zv6hktmydg8vpfh6euv02cz4yg0aaws6rrad5a5awqgqky80", + }, + DelegationSystemSCConfig: config.DelegationSystemSCConfig{ + MinServiceFee: 0, + MaxServiceFee: 100, + }, + }, + Version: "v1.0.0", + HistoryRepo: &dblookupext.HistoryRepositoryStub{}, + } +} + +// GetStatusComponents - +func GetStatusComponents( + coreComponents factory.CoreComponentsHolder, + networkComponents factory.NetworkComponentsHolder, + dataComponents factory.DataComponentsHolder, + stateComponents factory.StateComponentsHolder, + shardCoordinator sharding.Coordinator, + nodesCoordinator nodesCoordinator.NodesCoordinator, +) factory.StatusComponentsHandler { + indexerURL := "url" + elasticUsername := "user" + elasticPassword := "pass" + statusArgs := statusComp.StatusComponentsFactoryArgs{ + Config: testscommon.GetGeneralConfig(), + ExternalConfig: config.ExternalConfig{ + ElasticSearchConnector: config.ElasticSearchConfig{ + Enabled: false, + URL: indexerURL, + Username: elasticUsername, + Password: elasticPassword, + EnabledIndexes: []string{"transactions", "blocks"}, + }, + }, + EconomicsConfig: config.EconomicsConfig{}, + ShardCoordinator: shardCoordinator, + NodesCoordinator: nodesCoordinator, + EpochStartNotifier: coreComponents.EpochStartNotifierWithConfirm(), + CoreComponents: coreComponents, + DataComponents: dataComponents, + NetworkComponents: networkComponents, + StateComponents: stateComponents, + IsInImportMode: false, + } + + statusComponentsFactory, _ := statusComp.NewStatusComponentsFactory(statusArgs) + managedStatusComponents, err := statusComp.NewManagedStatusComponents(statusComponentsFactory) + if err != nil { + log.Error("getStatusComponents NewManagedStatusComponents", "error", err.Error()) + return nil + } + err = managedStatusComponents.Create() + if err != nil { + log.Error("getStatusComponents Create", "error", err.Error()) + return nil + } + return managedStatusComponents +} + +// GetStatusComponentsFactoryArgsAndProcessComponents - +func GetStatusComponentsFactoryArgsAndProcessComponents(shardCoordinator sharding.Coordinator) (statusComp.StatusComponentsFactoryArgs, factory.ProcessComponentsHolder) { + coreComponents := GetCoreComponents() + networkComponents := GetNetworkComponents() + dataComponents := GetDataComponents(coreComponents, shardCoordinator) + cryptoComponents := GetCryptoComponents(coreComponents) + stateComponents := GetStateComponents(coreComponents, shardCoordinator) + processComponents := GetProcessComponents( + shardCoordinator, + coreComponents, + networkComponents, + dataComponents, + cryptoComponents, + stateComponents, + ) + + indexerURL := "url" + elasticUsername := "user" + elasticPassword := "pass" + return statusComp.StatusComponentsFactoryArgs{ + Config: testscommon.GetGeneralConfig(), + ExternalConfig: config.ExternalConfig{ + ElasticSearchConnector: config.ElasticSearchConfig{ + Enabled: false, + URL: indexerURL, + Username: elasticUsername, + Password: elasticPassword, + EnabledIndexes: []string{"transactions", "blocks"}, + }, + }, + EconomicsConfig: config.EconomicsConfig{}, + ShardCoordinator: mock.NewMultiShardsCoordinatorMock(2), + NodesCoordinator: &shardingMocks.NodesCoordinatorMock{}, + EpochStartNotifier: &mock.EpochStartNotifierStub{}, + CoreComponents: coreComponents, + DataComponents: dataComponents, + NetworkComponents: networkComponents, + StateComponents: stateComponents, + 
IsInImportMode: false, + }, processComponents +} + +// GetNetworkComponents - +func GetNetworkComponents() factory.NetworkComponentsHolder { + networkArgs := GetNetworkFactoryArgs() + networkComponentsFactory, _ := networkComp.NewNetworkComponentsFactory(networkArgs) + networkComponents, _ := networkComp.NewManagedNetworkComponents(networkComponentsFactory) + + _ = networkComponents.Create() + + return networkComponents +} + +// GetDataComponents - +func GetDataComponents(coreComponents factory.CoreComponentsHolder, shardCoordinator sharding.Coordinator) factory.DataComponentsHolder { + dataArgs := GetDataArgs(coreComponents, shardCoordinator) + dataComponentsFactory, _ := dataComp.NewDataComponentsFactory(dataArgs) + dataComponents, _ := dataComp.NewManagedDataComponents(dataComponentsFactory) + _ = dataComponents.Create() + return dataComponents +} + +// GetCryptoComponents - +func GetCryptoComponents(coreComponents factory.CoreComponentsHolder) factory.CryptoComponentsHolder { + cryptoArgs := GetCryptoArgs(coreComponents) + cryptoComponentsFactory, _ := cryptoComp.NewCryptoComponentsFactory(cryptoArgs) + cryptoComponents, err := cryptoComp.NewManagedCryptoComponents(cryptoComponentsFactory) + if err != nil { + log.Error("getCryptoComponents NewManagedCryptoComponents", "error", err.Error()) + return nil + } + + err = cryptoComponents.Create() + if err != nil { + log.Error("getCryptoComponents Create", "error", err.Error()) + return nil + } + return cryptoComponents +} + +// GetStateComponents - +func GetStateComponents(coreComponents factory.CoreComponentsHolder, shardCoordinator sharding.Coordinator) factory.StateComponentsHolder { + stateArgs := GetStateFactoryArgs(coreComponents, shardCoordinator) + stateComponentsFactory, err := stateComp.NewStateComponentsFactory(stateArgs) + if err != nil { + log.Error("getStateComponents NewStateComponentsFactory", "error", err.Error()) + return nil + } + + stateComponents, err := stateComp.NewManagedStateComponents(stateComponentsFactory) + if err != nil { + log.Error("getStateComponents NewManagedStateComponents", "error", err.Error()) + return nil + } + err = stateComponents.Create() + if err != nil { + log.Error("getStateComponents Create", "error", err.Error()) + return nil + } + return stateComponents +} + +// GetProcessComponents - +func GetProcessComponents( + shardCoordinator sharding.Coordinator, + coreComponents factory.CoreComponentsHolder, + networkComponents factory.NetworkComponentsHolder, + dataComponents factory.DataComponentsHolder, + cryptoComponents factory.CryptoComponentsHolder, + stateComponents factory.StateComponentsHolder, +) factory.ProcessComponentsHolder { + processArgs := GetProcessArgs( + shardCoordinator, + coreComponents, + dataComponents, + cryptoComponents, + stateComponents, + networkComponents, + ) + processComponentsFactory, _ := processComp.NewProcessComponentsFactory(processArgs) + managedProcessComponents, err := processComp.NewManagedProcessComponents(processComponentsFactory) + if err != nil { + log.Error("getProcessComponents NewManagedProcessComponents", "error", err.Error()) + return nil + } + err = managedProcessComponents.Create() + if err != nil { + log.Error("getProcessComponents Create", "error", err.Error()) + return nil + } + return managedProcessComponents +} + +// DummyLoadSkPkFromPemFile - +func DummyLoadSkPkFromPemFile(sk []byte, pk string, err error) LoadKeysFunc { + return func(_ string, _ int) ([]byte, string, error) { + return sk, pk, err + } +} + +// FillGasMapMetaChainSystemSCsCosts - 
+func FillGasMapMetaChainSystemSCsCosts(value uint64) map[string]uint64 { + gasMap := make(map[string]uint64) + gasMap["Stake"] = value + gasMap["UnStake"] = value + gasMap["UnBond"] = value + gasMap["Claim"] = value + gasMap["Get"] = value + gasMap["ChangeRewardAddress"] = value + gasMap["ChangeValidatorKeys"] = value + gasMap["UnJail"] = value + gasMap["ESDTIssue"] = value + gasMap["ESDTOperations"] = value + gasMap["Proposal"] = value + gasMap["Vote"] = value + gasMap["DelegateVote"] = value + gasMap["RevokeVote"] = value + gasMap["CloseProposal"] = value + gasMap["DelegationOps"] = value + gasMap["UnStakeTokens"] = value + gasMap["UnBondTokens"] = value + gasMap["DelegationMgrOps"] = value + gasMap["GetAllNodeStates"] = value + gasMap["ValidatorToDelegation"] = value + gasMap["FixWaitingListSize"] = value + + return gasMap +} + +// SetShardCoordinator - +func SetShardCoordinator(t *testing.T, bootstrapComponents factory.BootstrapComponentsHolder, coordinator sharding.Coordinator) { + type testBootstrapComponents interface { + SetShardCoordinator(shardCoordinator sharding.Coordinator) error + } + + testBootstrap, ok := bootstrapComponents.(testBootstrapComponents) + require.True(t, ok) + + _ = testBootstrap.SetShardCoordinator(coordinator) +} diff --git a/testscommon/components/configs.go b/testscommon/components/configs.go new file mode 100644 index 00000000000..a97f00fec12 --- /dev/null +++ b/testscommon/components/configs.go @@ -0,0 +1,237 @@ +package components + +import ( + "github.com/ElrondNetwork/elrond-go/config" +) + +// GetGeneralConfig - +func GetGeneralConfig() config.Config { + return config.Config{ + AddressPubkeyConverter: config.PubkeyConfig{ + Length: 32, + Type: "hex", + SignatureLength: 0, + }, + ValidatorPubkeyConverter: config.PubkeyConfig{ + Length: 96, + Type: "hex", + SignatureLength: 0, + }, + StateTriesConfig: config.StateTriesConfig{ + CheckpointRoundsModulus: 5, + AccountsStatePruningEnabled: true, + PeerStatePruningEnabled: true, + MaxStateTrieLevelInMemory: 5, + MaxPeerTrieLevelInMemory: 5, + }, + EvictionWaitingList: config.EvictionWaitingListConfig{ + HashesSize: 100, + RootHashesSize: 100, + DB: config.DBConfig{ + FilePath: "EvictionWaitingList", + Type: "MemoryDB", + BatchDelaySeconds: 30, + MaxBatchSize: 6, + MaxOpenFiles: 10, + }, + }, + AccountsTrieStorage: config.StorageConfig{ + Cache: config.CacheConfig{ + Capacity: 10000, + Type: "LRU", + Shards: 1, + }, + DB: config.DBConfig{ + FilePath: "AccountsTrie/MainDB", + Type: "MemoryDB", + BatchDelaySeconds: 30, + MaxBatchSize: 6, + MaxOpenFiles: 10, + }, + }, + AccountsTrieCheckpointsStorage: config.StorageConfig{ + Cache: config.CacheConfig{ + Capacity: 10000, + Type: "LRU", + Shards: 1, + }, + DB: config.DBConfig{ + FilePath: "AccountsTrieCheckpoints", + Type: "MemoryDB", + BatchDelaySeconds: 30, + MaxBatchSize: 6, + MaxOpenFiles: 10, + }, + }, + PeerAccountsTrieStorage: config.StorageConfig{ + Cache: config.CacheConfig{ + Capacity: 10000, + Type: "LRU", + Shards: 1, + }, + DB: config.DBConfig{ + FilePath: "PeerAccountsTrie/MainDB", + Type: "MemoryDB", + BatchDelaySeconds: 30, + MaxBatchSize: 6, + MaxOpenFiles: 10, + }, + }, + PeerAccountsTrieCheckpointsStorage: config.StorageConfig{ + Cache: config.CacheConfig{ + Capacity: 10000, + Type: "LRU", + Shards: 1, + }, + DB: config.DBConfig{ + FilePath: "PeerAccountsTrieCheckpoints", + Type: "MemoryDB", + BatchDelaySeconds: 30, + MaxBatchSize: 6, + MaxOpenFiles: 10, + }, + }, + TrieStorageManagerConfig: config.TrieStorageManagerConfig{ + PruningBufferLen: 
1000, + SnapshotsBufferLen: 10, + SnapshotsGoroutineNum: 1, + }, + VirtualMachine: config.VirtualMachineServicesConfig{ + Querying: config.QueryVirtualMachineConfig{ + NumConcurrentVMs: 1, + VirtualMachineConfig: config.VirtualMachineConfig{ + ArwenVersions: []config.ArwenVersionByEpoch{ + {StartEpoch: 0, Version: "v0.3"}, + }, + }, + }, + Execution: config.VirtualMachineConfig{ + ArwenVersions: []config.ArwenVersionByEpoch{ + {StartEpoch: 0, Version: "v0.3"}, + }, + }, + GasConfig: config.VirtualMachineGasConfig{ + ShardMaxGasPerVmQuery: 1_500_000_000, + MetaMaxGasPerVmQuery: 0, + }, + }, + SmartContractsStorageForSCQuery: config.StorageConfig{ + Cache: config.CacheConfig{ + Capacity: 10000, + Type: "LRU", + Shards: 1, + }, + }, + SmartContractDataPool: config.CacheConfig{ + Capacity: 10000, + Type: "LRU", + Shards: 1, + }, + PeersRatingConfig: config.PeersRatingConfig{ + TopRatedCacheCapacity: 1000, + BadRatedCacheCapacity: 1000, + }, + PoolsCleanersConfig: config.PoolsCleanersConfig{ + MaxRoundsToKeepUnprocessedMiniBlocks: 50, + MaxRoundsToKeepUnprocessedTransactions: 50, + }, + BuiltInFunctions: config.BuiltInFunctionsConfig{ + AutomaticCrawlerAddresses: []string{ + "erd1he8wwxn4az3j82p7wwqsdk794dm7hcrwny6f8dfegkfla34udx7qrf7xje", //shard 0 + "erd1fpkcgel4gcmh8zqqdt043yfcn5tyx8373kg6q2qmkxzu4dqamc0swts65c", //shard 1 + "erd1najnxxweyw6plhg8efql330nttrj6l5cf87wqsuym85s9ha0hmdqnqgenp", //shard 2 + }, + MaxNumAddressesInTransferRole: 100, + }, + } +} + +// GetEpochStartConfig - +func GetEpochStartConfig() config.EpochStartConfig { + return config.EpochStartConfig{ + MinRoundsBetweenEpochs: 20, + RoundsPerEpoch: 20, + MaxShuffledOutRestartThreshold: 0.2, + MinShuffledOutRestartThreshold: 0.1, + MinNumConnectedPeersToStart: 2, + MinNumOfPeersToConsiderBlockValid: 2, + } +} + +// CreateDummyEconomicsConfig - +func CreateDummyEconomicsConfig() config.EconomicsConfig { + return config.EconomicsConfig{ + GlobalSettings: config.GlobalSettings{ + GenesisTotalSupply: "20000000000000000000000000", + MinimumInflation: 0, + YearSettings: []*config.YearSetting{ + { + Year: 0, + MaximumInflation: 0.01, + }, + }, + }, + RewardsSettings: config.RewardsSettings{ + RewardsConfigByEpoch: []config.EpochRewardSettings{ + { + LeaderPercentage: 0.1, + ProtocolSustainabilityPercentage: 0.1, + ProtocolSustainabilityAddress: "erd1932eft30w753xyvme8d49qejgkjc09n5e49w4mwdjtm0neld797su0dlxp", + TopUpFactor: 0.25, + TopUpGradientPoint: "3000000000000000000000000", + }, + }, + }, + FeeSettings: config.FeeSettings{ + GasLimitSettings: []config.GasLimitSetting{ + { + MaxGasLimitPerBlock: "1500000000", + MaxGasLimitPerMiniBlock: "1500000000", + MaxGasLimitPerMetaBlock: "15000000000", + MaxGasLimitPerMetaMiniBlock: "15000000000", + MaxGasLimitPerTx: "1500000000", + MinGasLimit: "50000", + }, + }, + MinGasPrice: "1000000000", + GasPerDataByte: "1500", + GasPriceModifier: 1, + }, + } +} + +// CreateDummyRatingsConfig - +func CreateDummyRatingsConfig() config.RatingsConfig { + return config.RatingsConfig{ + General: config.General{ + StartRating: 5000001, + MaxRating: 10000000, + MinRating: 1, + SignedBlocksThreshold: SignedBlocksThreshold, + SelectionChances: []*config.SelectionChance{ + {MaxThreshold: 0, ChancePercent: 5}, + {MaxThreshold: 2500000, ChancePercent: 19}, + {MaxThreshold: 7500000, ChancePercent: 20}, + {MaxThreshold: 10000000, ChancePercent: 21}, + }, + }, + ShardChain: config.ShardChain{ + RatingSteps: config.RatingSteps{ + HoursToMaxRatingFromStartRating: 2, + ProposerValidatorImportance: 1, + 
ProposerDecreaseFactor: -4, + ValidatorDecreaseFactor: -4, + ConsecutiveMissedBlocksPenalty: ConsecutiveMissedBlocksPenalty, + }, + }, + MetaChain: config.MetaChain{ + RatingSteps: config.RatingSteps{ + HoursToMaxRatingFromStartRating: 2, + ProposerValidatorImportance: 1, + ProposerDecreaseFactor: -4, + ValidatorDecreaseFactor: -4, + ConsecutiveMissedBlocksPenalty: ConsecutiveMissedBlocksPenalty, + }, + }, + } +} diff --git a/testscommon/components/default.go b/testscommon/components/default.go new file mode 100644 index 00000000000..25162922da6 --- /dev/null +++ b/testscommon/components/default.go @@ -0,0 +1,145 @@ +package components + +import ( + "time" + + crypto "github.com/ElrondNetwork/elrond-go-crypto" + "github.com/ElrondNetwork/elrond-go/common" + "github.com/ElrondNetwork/elrond-go/factory/mock" + "github.com/ElrondNetwork/elrond-go/sharding" + "github.com/ElrondNetwork/elrond-go/testscommon" + "github.com/ElrondNetwork/elrond-go/testscommon/cryptoMocks" + dataRetrieverMock "github.com/ElrondNetwork/elrond-go/testscommon/dataRetriever" + "github.com/ElrondNetwork/elrond-go/testscommon/economicsmocks" + "github.com/ElrondNetwork/elrond-go/testscommon/nodeTypeProviderMock" + "github.com/ElrondNetwork/elrond-go/testscommon/p2pmocks" + "github.com/ElrondNetwork/elrond-go/testscommon/shardingMocks" + stateMock "github.com/ElrondNetwork/elrond-go/testscommon/state" + statusHandlerMock "github.com/ElrondNetwork/elrond-go/testscommon/statusHandler" + "github.com/ElrondNetwork/elrond-go/testscommon/storage" + trieMock "github.com/ElrondNetwork/elrond-go/testscommon/trie" + trieFactory "github.com/ElrondNetwork/elrond-go/trie/factory" +) + +// GetDefaultCoreComponents - +func GetDefaultCoreComponents() *mock.CoreComponentsMock { + return &mock.CoreComponentsMock{ + IntMarsh: &testscommon.MarshalizerMock{}, + TxMarsh: &testscommon.MarshalizerMock{}, + VmMarsh: &testscommon.MarshalizerMock{}, + Hash: &testscommon.HasherStub{}, + UInt64ByteSliceConv: testscommon.NewNonceHashConverterMock(), + AddrPubKeyConv: testscommon.NewPubkeyConverterMock(32), + ValPubKeyConv: testscommon.NewPubkeyConverterMock(32), + PathHdl: &testscommon.PathManagerStub{}, + ChainIdCalled: func() string { + return "chainID" + }, + MinTransactionVersionCalled: func() uint32 { + return 1 + }, + AppStatusHdl: &statusHandlerMock.AppStatusHandlerStub{}, + WatchdogTimer: &testscommon.WatchdogMock{}, + AlarmSch: &testscommon.AlarmSchedulerStub{}, + NtpSyncTimer: &testscommon.SyncTimerStub{}, + RoundHandlerField: &testscommon.RoundHandlerMock{}, + EconomicsHandler: &economicsmocks.EconomicsHandlerStub{}, + RatingsConfig: &testscommon.RatingsInfoMock{}, + RatingHandler: &testscommon.RaterMock{}, + NodesConfig: &testscommon.NodesSetupStub{}, + StartTime: time.Time{}, + NodeTypeProviderField: &nodeTypeProviderMock.NodeTypeProviderStub{}, + } +} + +// GetDefaultCryptoComponents - +func GetDefaultCryptoComponents() *mock.CryptoComponentsMock { + return &mock.CryptoComponentsMock{ + PubKey: &mock.PublicKeyMock{}, + PrivKey: &mock.PrivateKeyStub{}, + PubKeyString: "pubKey", + PrivKeyBytes: []byte("privKey"), + PubKeyBytes: []byte("pubKey"), + BlockSig: &mock.SinglesignMock{}, + TxSig: &mock.SinglesignMock{}, + MultiSigContainer: cryptoMocks.NewMultiSignerContainerMock(&cryptoMocks.MultisignerMock{}), + PeerSignHandler: &mock.PeerSignatureHandler{}, + BlKeyGen: &mock.KeyGenMock{}, + TxKeyGen: &mock.KeyGenMock{}, + MsgSigVerifier: &testscommon.MessageSignVerifierMock{}, + } +} + +// GetDefaultNetworkComponents - +func 
GetDefaultNetworkComponents() *mock.NetworkComponentsMock { + return &mock.NetworkComponentsMock{ + Messenger: &p2pmocks.MessengerStub{}, + InputAntiFlood: &mock.P2PAntifloodHandlerStub{}, + OutputAntiFlood: &mock.P2PAntifloodHandlerStub{}, + PeerBlackList: &mock.PeerBlackListHandlerStub{}, + } +} + +// GetDefaultStateComponents - +func GetDefaultStateComponents() *testscommon.StateComponentsMock { + return &testscommon.StateComponentsMock{ + PeersAcc: &stateMock.AccountsStub{}, + Accounts: &stateMock.AccountsStub{}, + Tries: &trieMock.TriesHolderStub{}, + StorageManagers: map[string]common.StorageManager{ + "0": &testscommon.StorageManagerStub{}, + trieFactory.UserAccountTrie: &testscommon.StorageManagerStub{}, + trieFactory.PeerAccountTrie: &testscommon.StorageManagerStub{}, + }, + } +} + +// GetDefaultDataComponents - +func GetDefaultDataComponents() *mock.DataComponentsMock { + return &mock.DataComponentsMock{ + Blkc: &testscommon.ChainHandlerStub{}, + Storage: &storage.ChainStorerStub{}, + DataPool: &dataRetrieverMock.PoolsHolderMock{}, + MiniBlockProvider: &mock.MiniBlocksProviderStub{}, + } +} + +// GetDefaultProcessComponents - +func GetDefaultProcessComponents(shardCoordinator sharding.Coordinator) *mock.ProcessComponentsMock { + return &mock.ProcessComponentsMock{ + NodesCoord: &shardingMocks.NodesCoordinatorMock{}, + ShardCoord: shardCoordinator, + IntContainer: &testscommon.InterceptorsContainerStub{}, + ResFinder: &mock.ResolversFinderStub{}, + RoundHandlerField: &testscommon.RoundHandlerMock{}, + EpochTrigger: &testscommon.EpochStartTriggerStub{}, + EpochNotifier: &mock.EpochStartNotifierStub{}, + ForkDetect: &mock.ForkDetectorMock{}, + BlockProcess: &mock.BlockProcessorStub{}, + BlackListHdl: &testscommon.TimeCacheStub{}, + BootSore: &mock.BootstrapStorerMock{}, + HeaderSigVerif: &mock.HeaderSigVerifierStub{}, + HeaderIntegrVerif: &mock.HeaderIntegrityVerifierStub{}, + ValidatorStatistics: &mock.ValidatorStatisticsProcessorStub{}, + ValidatorProvider: &mock.ValidatorsProviderStub{}, + BlockTrack: &mock.BlockTrackerStub{}, + PendingMiniBlocksHdl: &mock.PendingMiniBlocksHandlerStub{}, + ReqHandler: &testscommon.RequestHandlerStub{}, + TxLogsProcess: &mock.TxLogProcessorMock{}, + HeaderConstructValidator: &mock.HeaderValidatorStub{}, + PeerMapper: &p2pmocks.NetworkShardingCollectorStub{}, + FallbackHdrValidator: &testscommon.FallBackHeaderValidatorStub{}, + NodeRedundancyHandlerInternal: &mock.RedundancyHandlerStub{ + IsRedundancyNodeCalled: func() bool { + return false + }, + IsMainMachineActiveCalled: func() bool { + return false + }, + ObserverPrivateKeyCalled: func() crypto.PrivateKey { + return &mock.PrivateKeyStub{} + }, + }, + HardforkTriggerField: &testscommon.HardforkTriggerStub{}, + } +} diff --git a/testscommon/cryptoMocks/multiSignerContainerMock.go b/testscommon/cryptoMocks/multiSignerContainerMock.go new file mode 100644 index 00000000000..f76dbd8236d --- /dev/null +++ b/testscommon/cryptoMocks/multiSignerContainerMock.go @@ -0,0 +1,23 @@ +package cryptoMocks + +import crypto "github.com/ElrondNetwork/elrond-go-crypto" + +// MultiSignerContainerMock - +type MultiSignerContainerMock struct { + MultiSigner crypto.MultiSigner +} + +// NewMultiSignerContainerMock - +func NewMultiSignerContainerMock(multiSigner crypto.MultiSigner) *MultiSignerContainerMock { + return &MultiSignerContainerMock{MultiSigner: multiSigner} +} + +// GetMultiSigner - +func (mscm *MultiSignerContainerMock) GetMultiSigner(_ uint32) (crypto.MultiSigner, error) { + return mscm.MultiSigner, nil +} 
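[Reviewer note, not part of the patch: the container mock above ignores the epoch argument and always hands back the wrapped signer, which is exactly how GetDefaultCryptoComponents wires it via NewMultiSignerContainerMock(&cryptoMocks.MultisignerMock{}). A minimal usage sketch under that assumption — the test name and epoch value are illustrative, not taken from this diff:

package components_test

import (
	"testing"

	"github.com/ElrondNetwork/elrond-go/testscommon/cryptoMocks"
	"github.com/stretchr/testify/assert"
)

func TestMultiSignerContainerMock_AnyEpochReturnsWrappedSigner(t *testing.T) {
	wrapped := &cryptoMocks.MultisignerMock{}
	container := cryptoMocks.NewMultiSignerContainerMock(wrapped)

	// the epoch parameter is ignored by the mock, so any value resolves
	// to the signer passed at construction time
	signer, err := container.GetMultiSigner(7)
	assert.Nil(t, err)
	assert.Equal(t, wrapped, signer)
}
]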
+ +// IsInterfaceNil - +func (mscm *MultiSignerContainerMock) IsInterfaceNil() bool { + return mscm == nil +} diff --git a/testscommon/cryptoMocks/multiSignerContainerStub.go b/testscommon/cryptoMocks/multiSignerContainerStub.go new file mode 100644 index 00000000000..048c124b009 --- /dev/null +++ b/testscommon/cryptoMocks/multiSignerContainerStub.go @@ -0,0 +1,22 @@ +package cryptoMocks + +import crypto "github.com/ElrondNetwork/elrond-go-crypto" + +// MultiSignerContainerStub - +type MultiSignerContainerStub struct { + GetMultiSignerCalled func(epoch uint32) (crypto.MultiSigner, error) +} + +// GetMultiSigner - +func (stub *MultiSignerContainerStub) GetMultiSigner(epoch uint32) (crypto.MultiSigner, error) { + if stub.GetMultiSignerCalled != nil { + return stub.GetMultiSignerCalled(epoch) + } + + return nil, nil +} + +// IsInterfaceNil - +func (stub *MultiSignerContainerStub) IsInterfaceNil() bool { + return stub == nil +} diff --git a/testscommon/cryptoMocks/multiSignerStub.go b/testscommon/cryptoMocks/multiSignerStub.go new file mode 100644 index 00000000000..61f9262a108 --- /dev/null +++ b/testscommon/cryptoMocks/multiSignerStub.go @@ -0,0 +1,50 @@ +package cryptoMocks + +// MultiSignerStub implements crypto multisigner +type MultiSignerStub struct { + VerifyAggregatedSigCalled func(pubKeysSigners [][]byte, message []byte, aggSig []byte) error + CreateSignatureShareCalled func(privateKeyBytes []byte, message []byte) ([]byte, error) + VerifySignatureShareCalled func(publicKey []byte, message []byte, sig []byte) error + AggregateSigsCalled func(pubKeysSigners [][]byte, signatures [][]byte) ([]byte, error) +} + +// VerifyAggregatedSig - +func (stub *MultiSignerStub) VerifyAggregatedSig(pubKeysSigners [][]byte, message []byte, aggSig []byte) error { + if stub.VerifyAggregatedSigCalled != nil { + return stub.VerifyAggregatedSigCalled(pubKeysSigners, message, aggSig) + } + + return nil +} + +// CreateSignatureShare - +func (stub *MultiSignerStub) CreateSignatureShare(privateKeyBytes []byte, message []byte) ([]byte, error) { + if stub.CreateSignatureShareCalled != nil { + return stub.CreateSignatureShareCalled(privateKeyBytes, message) + } + + return nil, nil +} + +// VerifySignatureShare - +func (stub *MultiSignerStub) VerifySignatureShare(publicKey []byte, message []byte, sig []byte) error { + if stub.VerifySignatureShareCalled != nil { + return stub.VerifySignatureShareCalled(publicKey, message, sig) + } + + return nil +} + +// AggregateSigs - +func (stub *MultiSignerStub) AggregateSigs(pubKeysSigners [][]byte, signatures [][]byte) ([]byte, error) { + if stub.AggregateSigsCalled != nil { + return stub.AggregateSigsCalled(pubKeysSigners, signatures) + } + + return nil, nil +} + +// IsInterfaceNil - +func (stub *MultiSignerStub) IsInterfaceNil() bool { + return stub == nil +} diff --git a/testscommon/cryptoMocks/multisignerMock.go b/testscommon/cryptoMocks/multisignerMock.go index 5b02490eed3..bce88e5b330 100644 --- a/testscommon/cryptoMocks/multisignerMock.go +++ b/testscommon/cryptoMocks/multisignerMock.go @@ -1,154 +1,56 @@ package cryptoMocks import ( - "github.com/ElrondNetwork/elrond-go-crypto" + "bytes" ) const signatureSize = 48 // MultisignerMock is used to mock the multisignature scheme type MultisignerMock struct { - aggSig []byte - sigs [][]byte - pubkeys []string - selfId uint16 - - VerifyCalled func(msg []byte, bitmap []byte) error - CommitmentHashCalled func(index uint16) ([]byte, error) - CreateSignatureShareCalled func(msg []byte, bitmap []byte) ([]byte, error) - 
VerifySignatureShareCalled func(index uint16, sig []byte, msg []byte, bitmap []byte) error - AggregateSigsCalled func(bitmap []byte) ([]byte, error) - SignatureShareCalled func(index uint16) ([]byte, error) - CreateCalled func(pubKeys []string, index uint16) (crypto.MultiSigner, error) - ResetCalled func(pubKeys []string, index uint16) error - CreateAndAddSignatureShareForKeyCalled func(message []byte, privateKey crypto.PrivateKey, pubKeyBytes []byte) ([]byte, error) - StoreSignatureShareCalled func(index uint16, sig []byte) error + CreateSignatureShareCalled func(privateKeyBytes []byte, message []byte) ([]byte, error) + VerifySignatureShareCalled func(publicKey []byte, message []byte, sig []byte) error + AggregateSigsCalled func(pubKeysSigners [][]byte, signatures [][]byte) ([]byte, error) + VerifyAggregatedSigCalled func(pubKeysSigners [][]byte, message []byte, aggSig []byte) error } // NewMultiSigner - -func NewMultiSigner(consensusSize uint32) *MultisignerMock { - multisigner := &MultisignerMock{} - multisigner.sigs = make([][]byte, consensusSize) - multisigner.pubkeys = make([]string, consensusSize) - - multisigner.aggSig = make([]byte, signatureSize) - copy(multisigner.aggSig, "aggregated signature") - - return multisigner -} - -// Create - -func (mm *MultisignerMock) Create(pubKeys []string, index uint16) (crypto.MultiSigner, error) { - if mm.CreateCalled != nil { - return mm.CreateCalled(pubKeys, index) - } - - multiSig := NewMultiSigner(uint32(len(pubKeys))) - - multiSig.selfId = index - multiSig.pubkeys = pubKeys - - return multiSig, nil -} - -// Reset - -func (mm *MultisignerMock) Reset(pubKeys []string, index uint16) error { - if mm.ResetCalled != nil { - return mm.ResetCalled(pubKeys, index) - } - - mm.sigs = make([][]byte, len(pubKeys)) - mm.pubkeys = make([]string, len(pubKeys)) - mm.selfId = index - mm.pubkeys = pubKeys - - for i := 0; i < len(pubKeys); i++ { - mm.sigs[i] = mm.aggSig - } - - mm.selfId = index - mm.pubkeys = pubKeys - - return nil -} - -// SetAggregatedSig - -func (mm *MultisignerMock) SetAggregatedSig(aggSig []byte) error { - mm.aggSig = aggSig - - return nil +func NewMultiSigner() *MultisignerMock { + return &MultisignerMock{} } -// Verify - -func (mm *MultisignerMock) Verify(msg []byte, bitmap []byte) error { - if mm.VerifyCalled != nil { - return mm.VerifyCalled(msg, bitmap) - } - - return nil -} - -// CreateSignatureShare creates a partial signature -func (mm *MultisignerMock) CreateSignatureShare(msg []byte, bitmap []byte) ([]byte, error) { +// CreateSignatureShare - +func (mm *MultisignerMock) CreateSignatureShare(privateKeyBytes []byte, message []byte) ([]byte, error) { if mm.CreateSignatureShareCalled != nil { - return mm.CreateSignatureShareCalled(msg, bitmap) + return mm.CreateSignatureShareCalled(privateKeyBytes, message) } - return mm.aggSig, nil -} - -// StoreSignatureShare - -func (mm *MultisignerMock) StoreSignatureShare(index uint16, sig []byte) error { - if mm.StoreSignatureShareCalled != nil { - return mm.StoreSignatureShareCalled(index, sig) - } - - if index >= uint16(len(mm.pubkeys)) { - return crypto.ErrIndexOutOfBounds - } - - mm.sigs[index] = sig - return nil + return bytes.Repeat([]byte{0xAA}, signatureSize), nil } // VerifySignatureShare - -func (mm *MultisignerMock) VerifySignatureShare(index uint16, sig []byte, msg []byte, bitmap []byte) error { +func (mm *MultisignerMock) VerifySignatureShare(publicKey []byte, message []byte, sig []byte) error { if mm.VerifySignatureShareCalled != nil { - return 
mm.VerifySignatureShareCalled(index, sig, msg, bitmap) + return mm.VerifySignatureShareCalled(publicKey, message, sig) } - return nil } // AggregateSigs - -func (mm *MultisignerMock) AggregateSigs(bitmap []byte) ([]byte, error) { +func (mm *MultisignerMock) AggregateSigs(pubKeysSigners [][]byte, signatures [][]byte) ([]byte, error) { if mm.AggregateSigsCalled != nil { - return mm.AggregateSigsCalled(bitmap) - } - - return mm.aggSig, nil -} - -// SignatureShare - -func (mm *MultisignerMock) SignatureShare(index uint16) ([]byte, error) { - if mm.SignatureShareCalled != nil { - return mm.SignatureShareCalled(index) - } - - if index >= uint16(len(mm.sigs)) { - return nil, crypto.ErrIndexOutOfBounds + return mm.AggregateSigsCalled(pubKeysSigners, signatures) } - return mm.sigs[index], nil + return bytes.Repeat([]byte{0xAA}, signatureSize), nil } -// CreateAndAddSignatureShareForKey - -func (mm *MultisignerMock) CreateAndAddSignatureShareForKey(message []byte, privateKey crypto.PrivateKey, pubKeyBytes []byte) ([]byte, error) { - if mm.CreateAndAddSignatureShareForKeyCalled != nil { - return mm.CreateAndAddSignatureShareForKeyCalled(message, privateKey, pubKeyBytes) +// VerifyAggregatedSig - +func (mm *MultisignerMock) VerifyAggregatedSig(pubKeysSigners [][]byte, message []byte, aggSig []byte) error { + if mm.VerifyAggregatedSigCalled != nil { + return mm.VerifyAggregatedSigCalled(pubKeysSigners, message, aggSig) } - - return nil, nil + return nil } // IsInterfaceNil - diff --git a/testscommon/cryptoMocks/multisignerStub.go b/testscommon/cryptoMocks/multisignerStub.go deleted file mode 100644 index 4f2d4291bf2..00000000000 --- a/testscommon/cryptoMocks/multisignerStub.go +++ /dev/null @@ -1,112 +0,0 @@ -package cryptoMocks - -import crypto "github.com/ElrondNetwork/elrond-go-crypto" - -// MultisignerStub - -type MultisignerStub struct { - CreateCalled func(pubKeys []string, index uint16) (crypto.MultiSigner, error) - SetAggregatedSigCalled func(bytes []byte) error - VerifyCalled func(msg []byte, bitmap []byte) error - ResetCalled func(pubKeys []string, index uint16) error - CreateSignatureShareCalled func(msg []byte, bitmap []byte) ([]byte, error) - StoreSignatureShareCalled func(index uint16, sig []byte) error - SignatureShareCalled func(index uint16) ([]byte, error) - VerifySignatureShareCalled func(index uint16, sig []byte, msg []byte, bitmap []byte) error - AggregateSigsCalled func(bitmap []byte) ([]byte, error) - CreateAndAddSignatureShareForKeyCalled func(message []byte, privateKey crypto.PrivateKey, pubKeyBytes []byte) ([]byte, error) -} - -// Create - -func (mss *MultisignerStub) Create(pubKeys []string, index uint16) (crypto.MultiSigner, error) { - if mss.CreateCalled != nil { - return mss.CreateCalled(pubKeys, index) - } - - return nil, nil -} - -// SetAggregatedSig - -func (mss *MultisignerStub) SetAggregatedSig(bytes []byte) error { - if mss.SetAggregatedSigCalled != nil { - return mss.SetAggregatedSigCalled(bytes) - } - - return nil -} - -// Verify - -func (mss *MultisignerStub) Verify(msg []byte, bitmap []byte) error { - if mss.VerifyCalled != nil { - return mss.VerifyCalled(msg, bitmap) - } - - return nil -} - -// Reset - -func (mss *MultisignerStub) Reset(pubKeys []string, index uint16) error { - if mss.ResetCalled != nil { - return mss.ResetCalled(pubKeys, index) - } - - return nil -} - -// CreateSignatureShare - -func (mss *MultisignerStub) CreateSignatureShare(msg []byte, bitmap []byte) ([]byte, error) { - if mss.CreateSignatureShareCalled != nil { - return 
mss.CreateSignatureShareCalled(msg, bitmap) - } - - return nil, nil -} - -// StoreSignatureShare - -func (mss *MultisignerStub) StoreSignatureShare(index uint16, sig []byte) error { - if mss.StoreSignatureShareCalled != nil { - return mss.StoreSignatureShareCalled(index, sig) - } - - return nil -} - -// SignatureShare - -func (mss *MultisignerStub) SignatureShare(index uint16) ([]byte, error) { - if mss.SignatureShareCalled != nil { - return mss.SignatureShareCalled(index) - } - - return nil, nil -} - -// VerifySignatureShare - -func (mss *MultisignerStub) VerifySignatureShare(index uint16, sig []byte, msg []byte, bitmap []byte) error { - if mss.VerifySignatureShareCalled != nil { - return mss.VerifySignatureShareCalled(index, sig, msg, bitmap) - } - - return nil -} - -// AggregateSigs - -func (mss *MultisignerStub) AggregateSigs(bitmap []byte) ([]byte, error) { - if mss.AggregateSigsCalled != nil { - return mss.AggregateSigsCalled(bitmap) - } - - return nil, nil -} - -// CreateAndAddSignatureShareForKey - -func (mss *MultisignerStub) CreateAndAddSignatureShareForKey(message []byte, privateKey crypto.PrivateKey, pubKeyBytes []byte) ([]byte, error) { - if mss.CreateAndAddSignatureShareForKeyCalled != nil { - return mss.CreateAndAddSignatureShareForKeyCalled(message, privateKey, pubKeyBytes) - } - - return nil, nil -} - -// IsInterfaceNil - -func (mss *MultisignerStub) IsInterfaceNil() bool { - return mss == nil -} diff --git a/testscommon/dataRetriever/poolFactory.go b/testscommon/dataRetriever/poolFactory.go index 165d202c102..18636ed9ce0 100644 --- a/testscommon/dataRetriever/poolFactory.go +++ b/testscommon/dataRetriever/poolFactory.go @@ -12,10 +12,8 @@ import ( "github.com/ElrondNetwork/elrond-go/dataRetriever/dataPool/headersCache" "github.com/ElrondNetwork/elrond-go/dataRetriever/shardedData" "github.com/ElrondNetwork/elrond-go/dataRetriever/txpool" - "github.com/ElrondNetwork/elrond-go/storage/lrucache/capacity" - "github.com/ElrondNetwork/elrond-go/storage/storageCacherAdapter" - "github.com/ElrondNetwork/elrond-go/storage/storageUnit" - "github.com/ElrondNetwork/elrond-go/storage/timecache" + "github.com/ElrondNetwork/elrond-go/storage/cache" + "github.com/ElrondNetwork/elrond-go/storage/storageunit" "github.com/ElrondNetwork/elrond-go/testscommon/txcachemocks" "github.com/ElrondNetwork/elrond-go/trie/factory" ) @@ -32,7 +30,7 @@ func panicIfError(message string, err error) { func CreateTxPool(numShards uint32, selfShard uint32) (dataRetriever.ShardedDataCacherNotifier, error) { return txpool.NewShardedTxPool( txpool.ArgShardedTxPool{ - Config: storageUnit.CacheConfig{ + Config: storageunit.CacheConfig{ Capacity: 100_000, SizePerSender: 1_000_000_000, SizeInBytes: 1_000_000_000, @@ -57,14 +55,14 @@ func CreatePoolsHolder(numShards uint32, selfShard uint32) dataRetriever.PoolsHo txPool, err := CreateTxPool(numShards, selfShard) panicIfError("CreatePoolsHolder", err) - unsignedTxPool, err := shardedData.NewShardedData("unsignedTxPool", storageUnit.CacheConfig{ + unsignedTxPool, err := shardedData.NewShardedData("unsignedTxPool", storageunit.CacheConfig{ Capacity: 100000, SizeInBytes: 1000000000, Shards: 1, }) panicIfError("CreatePoolsHolder", err) - rewardsTxPool, err := shardedData.NewShardedData("rewardsTxPool", storageUnit.CacheConfig{ + rewardsTxPool, err := shardedData.NewShardedData("rewardsTxPool", storageunit.CacheConfig{ Capacity: 300, SizeInBytes: 300000, Shards: 1, @@ -77,31 +75,31 @@ func CreatePoolsHolder(numShards uint32, selfShard uint32) dataRetriever.PoolsHo }) 
panicIfError("CreatePoolsHolder", err) - cacherConfig := storageUnit.CacheConfig{Capacity: 100000, Type: storageUnit.LRUCache, Shards: 1} - txBlockBody, err := storageUnit.NewCache(cacherConfig) + cacherConfig := storageunit.CacheConfig{Capacity: 100000, Type: storageunit.LRUCache, Shards: 1} + txBlockBody, err := storageunit.NewCache(cacherConfig) panicIfError("CreatePoolsHolder", err) - cacherConfig = storageUnit.CacheConfig{Capacity: 100000, Type: storageUnit.LRUCache, Shards: 1} - peerChangeBlockBody, err := storageUnit.NewCache(cacherConfig) + cacherConfig = storageunit.CacheConfig{Capacity: 100000, Type: storageunit.LRUCache, Shards: 1} + peerChangeBlockBody, err := storageunit.NewCache(cacherConfig) panicIfError("CreatePoolsHolder", err) - cacherConfig = storageUnit.CacheConfig{Capacity: 50000, Type: storageUnit.LRUCache} - cacher, err := capacity.NewCapacityLRU(10, 10000) + cacherConfig = storageunit.CacheConfig{Capacity: 50000, Type: storageunit.LRUCache} + cacher, err := cache.NewCapacityLRU(10, 10000) panicIfError("Create trieSync cacher", err) tempDir, _ := ioutil.TempDir("", "integrationTests") - cfg := storageUnit.ArgDB{ + cfg := storageunit.ArgDB{ Path: tempDir, - DBType: storageUnit.LvlDBSerial, + DBType: storageunit.LvlDBSerial, BatchDelaySeconds: 4, MaxBatchSize: 10000, MaxOpenFiles: 10, } - persister, err := storageUnit.NewDB(cfg) + persister, err := storageunit.NewDB(cfg) panicIfError("Create trieSync DB", err) tnf := factory.NewTrieNodeFactory() - adaptedTrieNodesStorage, err := storageCacherAdapter.NewStorageCacherAdapter( + adaptedTrieNodesStorage, err := storageunit.NewStorageCacherAdapter( cacher, persister, tnf, @@ -109,24 +107,24 @@ func CreatePoolsHolder(numShards uint32, selfShard uint32) dataRetriever.PoolsHo ) panicIfError("Create AdaptedTrieNodesStorage", err) - trieNodesChunks, err := storageUnit.NewCache(cacherConfig) + trieNodesChunks, err := storageunit.NewCache(cacherConfig) panicIfError("CreatePoolsHolder", err) - cacherConfig = storageUnit.CacheConfig{Capacity: 50000, Type: storageUnit.LRUCache} - smartContracts, err := storageUnit.NewCache(cacherConfig) + cacherConfig = storageunit.CacheConfig{Capacity: 50000, Type: storageunit.LRUCache} + smartContracts, err := storageunit.NewCache(cacherConfig) panicIfError("CreatePoolsHolder", err) - peerAuthPool, err := timecache.NewTimeCacher(timecache.ArgTimeCacher{ + peerAuthPool, err := cache.NewTimeCacher(cache.ArgTimeCacher{ DefaultSpan: 60 * time.Second, CacheExpiry: 60 * time.Second, }) panicIfError("CreatePoolsHolder", err) - cacherConfig = storageUnit.CacheConfig{Capacity: 50000, Type: storageUnit.LRUCache} - heartbeatPool, err := storageUnit.NewCache(cacherConfig) + cacherConfig = storageunit.CacheConfig{Capacity: 50000, Type: storageunit.LRUCache} + heartbeatPool, err := storageunit.NewCache(cacherConfig) panicIfError("CreatePoolsHolder", err) - validatorsInfo, err := shardedData.NewShardedData("validatorsInfoPool", storageUnit.CacheConfig{ + validatorsInfo, err := shardedData.NewShardedData("validatorsInfoPool", storageunit.CacheConfig{ Capacity: 300, SizeInBytes: 300000, Shards: 1, @@ -161,14 +159,14 @@ func CreatePoolsHolder(numShards uint32, selfShard uint32) dataRetriever.PoolsHo func CreatePoolsHolderWithTxPool(txPool dataRetriever.ShardedDataCacherNotifier) dataRetriever.PoolsHolder { var err error - unsignedTxPool, err := shardedData.NewShardedData("unsignedTxPool", storageUnit.CacheConfig{ + unsignedTxPool, err := shardedData.NewShardedData("unsignedTxPool", storageunit.CacheConfig{ Capacity: 
100000, SizeInBytes: 1000000000, Shards: 1, }) panicIfError("CreatePoolsHolderWithTxPool", err) - rewardsTxPool, err := shardedData.NewShardedData("rewardsTxPool", storageUnit.CacheConfig{ + rewardsTxPool, err := shardedData.NewShardedData("rewardsTxPool", storageunit.CacheConfig{ Capacity: 300, SizeInBytes: 300000, Shards: 1, @@ -181,40 +179,40 @@ func CreatePoolsHolderWithTxPool(txPool dataRetriever.ShardedDataCacherNotifier) }) panicIfError("CreatePoolsHolderWithTxPool", err) - cacherConfig := storageUnit.CacheConfig{Capacity: 100000, Type: storageUnit.LRUCache, Shards: 1} - txBlockBody, err := storageUnit.NewCache(cacherConfig) + cacherConfig := storageunit.CacheConfig{Capacity: 100000, Type: storageunit.LRUCache, Shards: 1} + txBlockBody, err := storageunit.NewCache(cacherConfig) panicIfError("CreatePoolsHolderWithTxPool", err) - cacherConfig = storageUnit.CacheConfig{Capacity: 100000, Type: storageUnit.LRUCache, Shards: 1} - peerChangeBlockBody, err := storageUnit.NewCache(cacherConfig) + cacherConfig = storageunit.CacheConfig{Capacity: 100000, Type: storageunit.LRUCache, Shards: 1} + peerChangeBlockBody, err := storageunit.NewCache(cacherConfig) panicIfError("CreatePoolsHolderWithTxPool", err) - cacherConfig = storageUnit.CacheConfig{Capacity: 50000, Type: storageUnit.LRUCache} - trieNodes, err := storageUnit.NewCache(cacherConfig) + cacherConfig = storageunit.CacheConfig{Capacity: 50000, Type: storageunit.LRUCache} + trieNodes, err := storageunit.NewCache(cacherConfig) panicIfError("CreatePoolsHolderWithTxPool", err) - trieNodesChunks, err := storageUnit.NewCache(cacherConfig) + trieNodesChunks, err := storageunit.NewCache(cacherConfig) panicIfError("CreatePoolsHolderWithTxPool", err) - cacherConfig = storageUnit.CacheConfig{Capacity: 50000, Type: storageUnit.LRUCache} - smartContracts, err := storageUnit.NewCache(cacherConfig) + cacherConfig = storageunit.CacheConfig{Capacity: 50000, Type: storageunit.LRUCache} + smartContracts, err := storageunit.NewCache(cacherConfig) panicIfError("CreatePoolsHolderWithTxPool", err) - validatorsInfo, err := shardedData.NewShardedData("validatorsInfoPool", storageUnit.CacheConfig{ + validatorsInfo, err := shardedData.NewShardedData("validatorsInfoPool", storageunit.CacheConfig{ Capacity: 300, SizeInBytes: 300000, Shards: 1, }) panicIfError("CreatePoolsHolderWithTxPool", err) - peerAuthPool, err := timecache.NewTimeCacher(timecache.ArgTimeCacher{ + peerAuthPool, err := cache.NewTimeCacher(cache.ArgTimeCacher{ DefaultSpan: peerAuthDuration, CacheExpiry: peerAuthDuration, }) panicIfError("CreatePoolsHolderWithTxPool", err) - cacherConfig = storageUnit.CacheConfig{Capacity: 50000, Type: storageUnit.LRUCache} - heartbeatPool, err := storageUnit.NewCache(cacherConfig) + cacherConfig = storageunit.CacheConfig{Capacity: 50000, Type: storageunit.LRUCache} + heartbeatPool, err := storageunit.NewCache(cacherConfig) panicIfError("CreatePoolsHolderWithTxPool", err) currentBlockTransactions := dataPool.NewCurrentBlockTransactionsPool() diff --git a/testscommon/dataRetriever/poolsHolderMock.go b/testscommon/dataRetriever/poolsHolderMock.go index 29bae65f787..d2814d2c954 100644 --- a/testscommon/dataRetriever/poolsHolderMock.go +++ b/testscommon/dataRetriever/poolsHolderMock.go @@ -11,8 +11,8 @@ import ( "github.com/ElrondNetwork/elrond-go/dataRetriever/shardedData" "github.com/ElrondNetwork/elrond-go/dataRetriever/txpool" "github.com/ElrondNetwork/elrond-go/storage" - "github.com/ElrondNetwork/elrond-go/storage/storageUnit" - 
"github.com/ElrondNetwork/elrond-go/storage/timecache" + "github.com/ElrondNetwork/elrond-go/storage/cache" + "github.com/ElrondNetwork/elrond-go/storage/storageunit" "github.com/ElrondNetwork/elrond-go/testscommon/txcachemocks" ) @@ -41,7 +41,7 @@ func NewPoolsHolderMock() *PoolsHolderMock { holder.transactions, err = txpool.NewShardedTxPool( txpool.ArgShardedTxPool{ - Config: storageUnit.CacheConfig{ + Config: storageunit.CacheConfig{ Capacity: 100000, SizePerSender: 1000, SizeInBytes: 1000000000, @@ -58,14 +58,14 @@ func NewPoolsHolderMock() *PoolsHolderMock { ) panicIfError("NewPoolsHolderMock", err) - holder.unsignedTransactions, err = shardedData.NewShardedData("unsignedTxPool", storageUnit.CacheConfig{ + holder.unsignedTransactions, err = shardedData.NewShardedData("unsignedTxPool", storageunit.CacheConfig{ Capacity: 10000, SizeInBytes: 1000000000, Shards: 1, }) panicIfError("NewPoolsHolderMock", err) - holder.rewardTransactions, err = shardedData.NewShardedData("rewardsTxPool", storageUnit.CacheConfig{ + holder.rewardTransactions, err = shardedData.NewShardedData("rewardsTxPool", storageunit.CacheConfig{ Capacity: 100, SizeInBytes: 100000, Shards: 1, @@ -75,34 +75,34 @@ func NewPoolsHolderMock() *PoolsHolderMock { holder.headers, err = headersCache.NewHeadersPool(config.HeadersPoolConfig{MaxHeadersPerShard: 1000, NumElementsToRemoveOnEviction: 100}) panicIfError("NewPoolsHolderMock", err) - holder.miniBlocks, err = storageUnit.NewCache(storageUnit.CacheConfig{Type: storageUnit.LRUCache, Capacity: 10000, Shards: 1, SizeInBytes: 0}) + holder.miniBlocks, err = storageunit.NewCache(storageunit.CacheConfig{Type: storageunit.LRUCache, Capacity: 10000, Shards: 1, SizeInBytes: 0}) panicIfError("NewPoolsHolderMock", err) - holder.peerChangesBlocks, err = storageUnit.NewCache(storageUnit.CacheConfig{Type: storageUnit.LRUCache, Capacity: 10000, Shards: 1, SizeInBytes: 0}) + holder.peerChangesBlocks, err = storageunit.NewCache(storageunit.CacheConfig{Type: storageunit.LRUCache, Capacity: 10000, Shards: 1, SizeInBytes: 0}) panicIfError("NewPoolsHolderMock", err) holder.currBlockTxs = dataPool.NewCurrentBlockTransactionsPool() holder.currEpochValidatorInfo = dataPool.NewCurrentEpochValidatorInfoPool() - holder.trieNodes, err = storageUnit.NewCache(storageUnit.CacheConfig{Type: storageUnit.SizeLRUCache, Capacity: 900000, Shards: 1, SizeInBytes: 314572800}) + holder.trieNodes, err = storageunit.NewCache(storageunit.CacheConfig{Type: storageunit.SizeLRUCache, Capacity: 900000, Shards: 1, SizeInBytes: 314572800}) panicIfError("NewPoolsHolderMock", err) - holder.trieNodesChunks, err = storageUnit.NewCache(storageUnit.CacheConfig{Type: storageUnit.SizeLRUCache, Capacity: 900000, Shards: 1, SizeInBytes: 314572800}) + holder.trieNodesChunks, err = storageunit.NewCache(storageunit.CacheConfig{Type: storageunit.SizeLRUCache, Capacity: 900000, Shards: 1, SizeInBytes: 314572800}) panicIfError("NewPoolsHolderMock", err) - holder.smartContracts, err = storageUnit.NewCache(storageUnit.CacheConfig{Type: storageUnit.LRUCache, Capacity: 10000, Shards: 1, SizeInBytes: 0}) + holder.smartContracts, err = storageunit.NewCache(storageunit.CacheConfig{Type: storageunit.LRUCache, Capacity: 10000, Shards: 1, SizeInBytes: 0}) panicIfError("NewPoolsHolderMock", err) - holder.peerAuthentications, err = timecache.NewTimeCacher(timecache.ArgTimeCacher{ + holder.peerAuthentications, err = cache.NewTimeCacher(cache.ArgTimeCacher{ DefaultSpan: 10 * time.Second, CacheExpiry: 10 * time.Second, }) panicIfError("NewPoolsHolderMock", 
err) - holder.heartbeats, err = storageUnit.NewCache(storageUnit.CacheConfig{Type: storageUnit.LRUCache, Capacity: 10000, Shards: 1, SizeInBytes: 0}) + holder.heartbeats, err = storageunit.NewCache(storageunit.CacheConfig{Type: storageunit.LRUCache, Capacity: 10000, Shards: 1, SizeInBytes: 0}) panicIfError("NewPoolsHolderMock", err) - holder.validatorsInfo, err = shardedData.NewShardedData("validatorsInfoPool", storageUnit.CacheConfig{ + holder.validatorsInfo, err = shardedData.NewShardedData("validatorsInfoPool", storageunit.CacheConfig{ Capacity: 100, SizeInBytes: 100000, Shards: 1, diff --git a/testscommon/enableEpochsHandlerStub.go b/testscommon/enableEpochsHandlerStub.go index 302c69d9fc0..9c3f56bf158 100644 --- a/testscommon/enableEpochsHandlerStub.go +++ b/testscommon/enableEpochsHandlerStub.go @@ -111,6 +111,7 @@ type EnableEpochsHandlerStub struct { IsSetSenderInEeiOutputTransferFlagEnabledField bool IsChangeDelegationOwnerFlagEnabledField bool IsRefactorPeersMiniBlocksFlagEnabledField bool + IsFixAsyncCallBackArgsListFlagEnabledField bool } // ResetPenalizedTooMuchGasFlag - @@ -960,6 +961,14 @@ func (stub *EnableEpochsHandlerStub) IsRefactorPeersMiniBlocksFlagEnabled() bool return stub.IsRefactorPeersMiniBlocksFlagEnabledField } +// IsFixAsyncCallBackArgsListFlagEnabled - +func (stub *EnableEpochsHandlerStub) IsFixAsyncCallBackArgsListFlagEnabled() bool { + stub.RLock() + defer stub.RUnlock() + + return stub.IsFixAsyncCallBackArgsListFlagEnabledField +} + // IsInterfaceNil - func (stub *EnableEpochsHandlerStub) IsInterfaceNil() bool { return stub == nil diff --git a/testscommon/generalConfig.go b/testscommon/generalConfig.go index a31f1b6b2ae..bee4c8387f9 100644 --- a/testscommon/generalConfig.go +++ b/testscommon/generalConfig.go @@ -2,7 +2,7 @@ package testscommon import ( "github.com/ElrondNetwork/elrond-go/config" - "github.com/ElrondNetwork/elrond-go/storage/storageUnit" + "github.com/ElrondNetwork/elrond-go/storage/storageunit" ) // GetGeneralConfig returns the common configuration used for testing @@ -75,7 +75,7 @@ func GetGeneralConfig() config.Config { RootHashesSize: 100, DB: config.DBConfig{ FilePath: AddTimestampSuffix("EvictionWaitingList"), - Type: string(storageUnit.MemoryDB), + Type: string(storageunit.MemoryDB), BatchDelaySeconds: 30, MaxBatchSize: 6, MaxOpenFiles: 10, @@ -85,7 +85,7 @@ func GetGeneralConfig() config.Config { Cache: getLRUCacheConfig(), DB: config.DBConfig{ FilePath: AddTimestampSuffix("AccountsTrie"), - Type: string(storageUnit.MemoryDB), + Type: string(storageunit.MemoryDB), BatchDelaySeconds: 30, MaxBatchSize: 6, MaxOpenFiles: 10, @@ -95,7 +95,7 @@ func GetGeneralConfig() config.Config { Cache: getLRUCacheConfig(), DB: config.DBConfig{ FilePath: AddTimestampSuffix("AccountsTrieCheckpoints"), - Type: string(storageUnit.MemoryDB), + Type: string(storageunit.MemoryDB), BatchDelaySeconds: 30, MaxBatchSize: 6, MaxOpenFiles: 10, @@ -105,7 +105,7 @@ func GetGeneralConfig() config.Config { Cache: getLRUCacheConfig(), DB: config.DBConfig{ FilePath: AddTimestampSuffix("PeerAccountsTrie"), - Type: string(storageUnit.MemoryDB), + Type: string(storageunit.MemoryDB), BatchDelaySeconds: 30, MaxBatchSize: 6, MaxOpenFiles: 10, @@ -115,7 +115,7 @@ func GetGeneralConfig() config.Config { Cache: getLRUCacheConfig(), DB: config.DBConfig{ FilePath: AddTimestampSuffix("PeerAccountsTrieCheckpoints"), - Type: string(storageUnit.MemoryDB), + Type: string(storageunit.MemoryDB), BatchDelaySeconds: 30, MaxBatchSize: 6, MaxOpenFiles: 10, @@ -166,7 +166,7 @@ func 
GetGeneralConfig() config.Config { TrieSyncStorage: config.TrieSyncStorageConfig{ DB: config.DBConfig{ FilePath: AddTimestampSuffix("TrieSync"), - Type: string(storageUnit.MemoryDB), + Type: string(storageunit.MemoryDB), BatchDelaySeconds: 2, MaxBatchSize: 1000, MaxOpenFiles: 10, @@ -181,7 +181,7 @@ func GetGeneralConfig() config.Config { Cache: getLRUCacheConfig(), DB: config.DBConfig{ FilePath: AddTimestampSuffix("Transactions"), - Type: string(storageUnit.MemoryDB), + Type: string(storageunit.MemoryDB), BatchDelaySeconds: 30, MaxBatchSize: 6, MaxOpenFiles: 10, @@ -191,7 +191,7 @@ func GetGeneralConfig() config.Config { Cache: getLRUCacheConfig(), DB: config.DBConfig{ FilePath: AddTimestampSuffix("MiniBlocks"), - Type: string(storageUnit.MemoryDB), + Type: string(storageunit.MemoryDB), BatchDelaySeconds: 30, MaxBatchSize: 6, MaxOpenFiles: 10, @@ -201,7 +201,7 @@ func GetGeneralConfig() config.Config { Cache: getLRUCacheConfig(), DB: config.DBConfig{ FilePath: AddTimestampSuffix("ShardHdrHashNonce"), - Type: string(storageUnit.MemoryDB), + Type: string(storageunit.MemoryDB), BatchDelaySeconds: 30, MaxBatchSize: 6, MaxOpenFiles: 10, @@ -211,7 +211,7 @@ func GetGeneralConfig() config.Config { Cache: getLRUCacheConfig(), DB: config.DBConfig{ FilePath: AddTimestampSuffix("MetaBlock"), - Type: string(storageUnit.MemoryDB), + Type: string(storageunit.MemoryDB), BatchDelaySeconds: 30, MaxBatchSize: 6, MaxOpenFiles: 10, @@ -221,7 +221,7 @@ func GetGeneralConfig() config.Config { Cache: getLRUCacheConfig(), DB: config.DBConfig{ FilePath: AddTimestampSuffix("MetaHdrHashNonce"), - Type: string(storageUnit.MemoryDB), + Type: string(storageunit.MemoryDB), BatchDelaySeconds: 30, MaxBatchSize: 6, MaxOpenFiles: 10, @@ -231,7 +231,7 @@ func GetGeneralConfig() config.Config { Cache: getLRUCacheConfig(), DB: config.DBConfig{ FilePath: AddTimestampSuffix("UnsignedTransactions"), - Type: string(storageUnit.MemoryDB), + Type: string(storageunit.MemoryDB), BatchDelaySeconds: 30, MaxBatchSize: 6, MaxOpenFiles: 10, @@ -241,7 +241,7 @@ func GetGeneralConfig() config.Config { Cache: getLRUCacheConfig(), DB: config.DBConfig{ FilePath: AddTimestampSuffix("RewardTransactions"), - Type: string(storageUnit.MemoryDB), + Type: string(storageunit.MemoryDB), BatchDelaySeconds: 30, MaxBatchSize: 6, MaxOpenFiles: 10, @@ -251,7 +251,7 @@ func GetGeneralConfig() config.Config { Cache: getLRUCacheConfig(), DB: config.DBConfig{ FilePath: AddTimestampSuffix("BlockHeaders"), - Type: string(storageUnit.MemoryDB), + Type: string(storageunit.MemoryDB), BatchDelaySeconds: 30, MaxBatchSize: 6, MaxOpenFiles: 10, @@ -262,7 +262,7 @@ func GetGeneralConfig() config.Config { Cache: getLRUCacheConfig(), DB: config.DBConfig{ FilePath: AddTimestampSuffix("HeartbeatStorage"), - Type: string(storageUnit.MemoryDB), + Type: string(storageunit.MemoryDB), BatchDelaySeconds: 30, MaxBatchSize: 6, MaxOpenFiles: 10, @@ -289,7 +289,7 @@ func GetGeneralConfig() config.Config { Cache: getLRUCacheConfig(), DB: config.DBConfig{ FilePath: AddTimestampSuffix("StatusMetricsStorageDB"), - Type: string(storageUnit.MemoryDB), + Type: string(storageunit.MemoryDB), BatchDelaySeconds: 30, MaxBatchSize: 6, MaxOpenFiles: 10, @@ -299,7 +299,7 @@ func GetGeneralConfig() config.Config { Cache: getLRUCacheConfig(), DB: config.DBConfig{ FilePath: AddTimestampSuffix("SmartContractsStorage"), - Type: string(storageUnit.MemoryDB), + Type: string(storageunit.MemoryDB), BatchDelaySeconds: 30, MaxBatchSize: 6, MaxOpenFiles: 10, @@ -309,7 +309,7 @@ func GetGeneralConfig() 
config.Config { Cache: getLRUCacheConfig(), DB: config.DBConfig{ FilePath: AddTimestampSuffix("SmartContractsStorageSimulate"), - Type: string(storageUnit.MemoryDB), + Type: string(storageunit.MemoryDB), BatchDelaySeconds: 30, MaxBatchSize: 6, MaxOpenFiles: 10, @@ -319,7 +319,7 @@ func GetGeneralConfig() config.Config { Cache: getLRUCacheConfig(), DB: config.DBConfig{ FilePath: AddTimestampSuffix("PeerBlocks"), - Type: string(storageUnit.MemoryDB), + Type: string(storageunit.MemoryDB), BatchDelaySeconds: 30, MaxBatchSize: 6, MaxOpenFiles: 10, @@ -329,7 +329,7 @@ func GetGeneralConfig() config.Config { Cache: getLRUCacheConfig(), DB: config.DBConfig{ FilePath: AddTimestampSuffix("BootstrapData"), - Type: string(storageUnit.MemoryDB), + Type: string(storageunit.MemoryDB), BatchDelaySeconds: 1, MaxBatchSize: 6, MaxOpenFiles: 10, @@ -341,7 +341,7 @@ func GetGeneralConfig() config.Config { Cache: getLRUCacheConfig(), DB: config.DBConfig{ FilePath: AddTimestampSuffix("Logs"), - Type: string(storageUnit.MemoryDB), + Type: string(storageunit.MemoryDB), BatchDelaySeconds: 2, MaxBatchSize: 100, MaxOpenFiles: 10, @@ -352,7 +352,7 @@ func GetGeneralConfig() config.Config { Cache: getLRUCacheConfig(), DB: config.DBConfig{ FilePath: AddTimestampSuffix("Receipts"), - Type: string(storageUnit.MemoryDB), + Type: string(storageunit.MemoryDB), BatchDelaySeconds: 30, MaxBatchSize: 6, MaxOpenFiles: 10, @@ -362,7 +362,7 @@ func GetGeneralConfig() config.Config { Cache: getLRUCacheConfig(), DB: config.DBConfig{ FilePath: AddTimestampSuffix("ScheduledSCRs"), - Type: string(storageUnit.MemoryDB), + Type: string(storageunit.MemoryDB), BatchDelaySeconds: 30, MaxBatchSize: 6, MaxOpenFiles: 10, diff --git a/testscommon/mainFactoryMocks/bootstrapComponentsStub.go b/testscommon/mainFactoryMocks/bootstrapComponentsStub.go index 3b391794dc7..d4bfff4c887 100644 --- a/testscommon/mainFactoryMocks/bootstrapComponentsStub.go +++ b/testscommon/mainFactoryMocks/bootstrapComponentsStub.go @@ -68,6 +68,12 @@ func (bcs *BootstrapComponentsStub) HeaderIntegrityVerifier() nodeFactory.Header return bcs.HdrIntegrityVerifier } +// SetShardCoordinator - +func (bcs *BootstrapComponentsStub) SetShardCoordinator(shardCoordinator sharding.Coordinator) error { + bcs.ShCoordinator = shardCoordinator + return nil +} + // String - func (bcs *BootstrapComponentsStub) String() string { return "BootstrapComponentsStub" diff --git a/testscommon/processDebuggerStub.go b/testscommon/processDebuggerStub.go new file mode 100644 index 00000000000..b8e0d56d88d --- /dev/null +++ b/testscommon/processDebuggerStub.go @@ -0,0 +1,28 @@ +package testscommon + +// ProcessDebuggerStub - +type ProcessDebuggerStub struct { + SetLastCommittedBlockRoundCalled func(round uint64) + CloseCalled func() error +} + +// SetLastCommittedBlockRound - +func (stub *ProcessDebuggerStub) SetLastCommittedBlockRound(round uint64) { + if stub.SetLastCommittedBlockRoundCalled != nil { + stub.SetLastCommittedBlockRoundCalled(round) + } +} + +// Close - +func (stub *ProcessDebuggerStub) Close() error { + if stub.CloseCalled != nil { + return stub.CloseCalled() + } + + return nil +} + +// IsInterfaceNil - +func (stub *ProcessDebuggerStub) IsInterfaceNil() bool { + return stub == nil +} diff --git a/testscommon/snapshotPruningStorerMock.go b/testscommon/snapshotPruningStorerMock.go index 8467c9701b1..8f15c718981 100644 --- a/testscommon/snapshotPruningStorerMock.go +++ b/testscommon/snapshotPruningStorerMock.go @@ -1,5 +1,7 @@ package testscommon +import 
"github.com/ElrondNetwork/elrond-go-core/core" + // SnapshotPruningStorerMock - type SnapshotPruningStorerMock struct { *MemDbMock @@ -11,8 +13,10 @@ func NewSnapshotPruningStorerMock() *SnapshotPruningStorerMock { } // GetFromOldEpochsWithoutAddingToCache - -func (spsm *SnapshotPruningStorerMock) GetFromOldEpochsWithoutAddingToCache(key []byte) ([]byte, error) { - return spsm.Get(key) +func (spsm *SnapshotPruningStorerMock) GetFromOldEpochsWithoutAddingToCache(key []byte) ([]byte, core.OptionalUint32, error) { + val, err := spsm.Get(key) + + return val, core.OptionalUint32{}, err } // PutInEpoch - diff --git a/testscommon/state/accountWrapperMock.go b/testscommon/state/accountWrapperMock.go index e2e72f18233..e8a0cd6cdcf 100644 --- a/testscommon/state/accountWrapperMock.go +++ b/testscommon/state/accountWrapperMock.go @@ -6,6 +6,7 @@ import ( "github.com/ElrondNetwork/elrond-go/common" "github.com/ElrondNetwork/elrond-go/state" + vmcommon "github.com/ElrondNetwork/elrond-vm-common" ) var _ state.UserAccountHandler = (*AccountWrapMock)(nil) @@ -13,7 +14,6 @@ var _ state.UserAccountHandler = (*AccountWrapMock)(nil) // AccountWrapMock - type AccountWrapMock struct { AccountWrapMockData - dataTrie common.Trie nonce uint64 code []byte CodeHash []byte @@ -22,9 +22,23 @@ type AccountWrapMock struct { address []byte trackableDataTrie state.DataTrieTracker - SetNonceWithJournalCalled func(nonce uint64) error `json:"-"` - SetCodeHashWithJournalCalled func(codeHash []byte) error `json:"-"` - SetCodeWithJournalCalled func([]byte) error `json:"-"` + SetNonceWithJournalCalled func(nonce uint64) error `json:"-"` + SetCodeHashWithJournalCalled func(codeHash []byte) error `json:"-"` + SetCodeWithJournalCalled func([]byte) error `json:"-"` + AccountDataHandlerCalled func() vmcommon.AccountDataHandler `json:"-"` +} + +// NewAccountWrapMock - +func NewAccountWrapMock(adr []byte) *AccountWrapMock { + return &AccountWrapMock{ + address: adr, + trackableDataTrie: state.NewTrackableDataTrie([]byte("identifier"), nil), + } +} + +// SetTrackableDataTrie - +func (awm *AccountWrapMock) SetTrackableDataTrie(tdt state.DataTrieTracker) { + awm.trackableDataTrie = tdt } // SetUserName - @@ -81,14 +95,6 @@ func (awm *AccountWrapMock) GetOwnerAddress() []byte { return nil } -// NewAccountWrapMock - -func NewAccountWrapMock(adr []byte) *AccountWrapMock { - return &AccountWrapMock{ - address: adr, - trackableDataTrie: state.NewTrackableDataTrie([]byte("identifier"), nil), - } -} - // IsInterfaceNil - func (awm *AccountWrapMock) IsInterfaceNil() bool { return awm == nil @@ -109,11 +115,16 @@ func (awm *AccountWrapMock) SetCode(code []byte) { awm.code = code } -// RetrieveValueFromDataTrieTracker - -func (awm *AccountWrapMock) RetrieveValueFromDataTrieTracker(key []byte) ([]byte, error) { +// RetrieveValue - +func (awm *AccountWrapMock) RetrieveValue(key []byte) ([]byte, error) { return awm.trackableDataTrie.RetrieveValue(key) } +// SaveKeyValue - +func (awm *AccountWrapMock) SaveKeyValue(key []byte, value []byte) error { + return awm.trackableDataTrie.SaveKeyValue(key, value) +} + // HasNewCode - func (awm *AccountWrapMock) HasNewCode() bool { return len(awm.code) > 0 @@ -145,26 +156,33 @@ func (awm *AccountWrapMock) AddressBytes() []byte { } // DataTrie - -func (awm *AccountWrapMock) DataTrie() common.Trie { - return awm.dataTrie +func (awm *AccountWrapMock) DataTrie() common.DataTrieHandler { + return awm.trackableDataTrie.DataTrie() +} + +// SaveDirtyData - +func (awm *AccountWrapMock) SaveDirtyData(trie common.Trie) 
(map[string][]byte, error) { + return awm.trackableDataTrie.SaveDirtyData(trie) } // SetDataTrie - func (awm *AccountWrapMock) SetDataTrie(trie common.Trie) { - awm.dataTrie = trie awm.trackableDataTrie.SetDataTrie(trie) } -// DataTrieTracker - -func (awm *AccountWrapMock) DataTrieTracker() state.DataTrieTracker { - return awm.trackableDataTrie -} - //IncreaseNonce adds the given value to the current nonce func (awm *AccountWrapMock) IncreaseNonce(val uint64) { awm.nonce = awm.nonce + val } +// AccountDataHandler - +func (awm *AccountWrapMock) AccountDataHandler() vmcommon.AccountDataHandler { + if awm.AccountDataHandlerCalled != nil { + return awm.AccountDataHandlerCalled() + } + return awm.trackableDataTrie +} + // GetNonce gets the nonce of the account func (awm *AccountWrapMock) GetNonce() uint64 { return awm.nonce diff --git a/testscommon/state/accountsAdapterStub.go b/testscommon/state/accountsAdapterStub.go index 908941b5cb3..0d18fbd8bcf 100644 --- a/testscommon/state/accountsAdapterStub.go +++ b/testscommon/state/accountsAdapterStub.go @@ -39,6 +39,26 @@ type AccountsStub struct { GetAccountWithBlockInfoCalled func(address []byte, options common.RootHashHolder) (vmcommon.AccountHandler, common.BlockInfo, error) GetCodeWithBlockInfoCalled func(codeHash []byte, options common.RootHashHolder) ([]byte, common.BlockInfo, error) CloseCalled func() error + SetSyncerCalled func(syncer state.AccountsDBSyncer) error + StartSnapshotIfNeededCalled func() error +} + +// SetSyncer - +func (as *AccountsStub) SetSyncer(syncer state.AccountsDBSyncer) error { + if as.SetSyncerCalled != nil { + return as.SetSyncerCalled(syncer) + } + + return nil +} + +// StartSnapshotIfNeeded - +func (as *AccountsStub) StartSnapshotIfNeeded() error { + if as.StartSnapshotIfNeededCalled != nil { + return as.StartSnapshotIfNeededCalled() + } + + return nil } // GetTrie - diff --git a/testscommon/state/testTriePruningStorer.go b/testscommon/state/testTriePruningStorer.go index 833d1cf1834..bd487874c7b 100644 --- a/testscommon/state/testTriePruningStorer.go +++ b/testscommon/state/testTriePruningStorer.go @@ -5,21 +5,21 @@ import ( "github.com/ElrondNetwork/elrond-go/sharding" "github.com/ElrondNetwork/elrond-go/storage" - "github.com/ElrondNetwork/elrond-go/storage/memorydb" + "github.com/ElrondNetwork/elrond-go/storage/database" storageMock "github.com/ElrondNetwork/elrond-go/storage/mock" "github.com/ElrondNetwork/elrond-go/storage/pruning" - "github.com/ElrondNetwork/elrond-go/storage/storageUnit" + "github.com/ElrondNetwork/elrond-go/storage/storageunit" "github.com/ElrondNetwork/elrond-go/testscommon" ) // CreateTestingTriePruningStorer creates a new trie pruning storer that is used for testing func CreateTestingTriePruningStorer(coordinator sharding.Coordinator, notifier pruning.EpochStartNotifier) (storage.Storer, *persisterMap, error) { - cacheConf := storageUnit.CacheConfig{ + cacheConf := storageunit.CacheConfig{ Capacity: 10, Type: "LRU", Shards: 3, } - dbConf := storageUnit.DBConfig{ + dbConf := storageunit.DBConfig{ FilePath: "path/Epoch_0/Shard_1", Type: "LvlDBSerial", BatchDelaySeconds: 500, @@ -77,7 +77,7 @@ func (pm *persisterMap) GetPersister(path string) storage.Persister { persister, exists := pm.persisters[path] if !exists { - persister = memorydb.New() + persister = database.NewMemDB() pm.persisters[path] = persister } diff --git a/testscommon/state/userAccountStub.go b/testscommon/state/userAccountStub.go index edb8454db2b..55235b36b59 100644 --- a/testscommon/state/userAccountStub.go +++ 
b/testscommon/state/userAccountStub.go @@ -12,9 +12,9 @@ var _ state.UserAccountHandler = (*UserAccountStub)(nil) // UserAccountStub - type UserAccountStub struct { - Balance *big.Int - AddToBalanceCalled func(value *big.Int) error - DataTrieTrackerCalled func() state.DataTrieTracker + Balance *big.Int + AddToBalanceCalled func(value *big.Int) error + RetrieveValueCalled func(_ []byte) ([]byte, error) } // HasNewCode - @@ -133,23 +133,29 @@ func (u *UserAccountStub) SetDataTrie(_ common.Trie) { } // DataTrie - -func (u *UserAccountStub) DataTrie() common.Trie { +func (u *UserAccountStub) DataTrie() common.DataTrieHandler { return nil } -// RetrieveValueFromDataTrieTracker - -func (u *UserAccountStub) RetrieveValueFromDataTrieTracker(_ []byte) ([]byte, error) { +// RetrieveValue - +func (u *UserAccountStub) RetrieveValue(key []byte) ([]byte, error) { + if u.RetrieveValueCalled != nil { + return u.RetrieveValueCalled(key) + } + return nil, nil } -// DataTrieTracker - -func (u *UserAccountStub) DataTrieTracker() state.DataTrieTracker { - if u.DataTrieTrackerCalled != nil { - return u.DataTrieTrackerCalled() - } +// SaveKeyValue - +func (u *UserAccountStub) SaveKeyValue(_ []byte, _ []byte) error { return nil } +// SaveDirtyData - +func (u *UserAccountStub) SaveDirtyData(_ common.Trie) (map[string][]byte, error) { + return nil, nil +} + // IsInterfaceNil - func (u *UserAccountStub) IsInterfaceNil() bool { return false diff --git a/testscommon/storageManagerStub.go b/testscommon/storageManagerStub.go index 3be9ea09084..40498a52c3d 100644 --- a/testscommon/storageManagerStub.go +++ b/testscommon/storageManagerStub.go @@ -12,8 +12,8 @@ type StorageManagerStub struct { PutInEpochWithoutCacheCalled func([]byte, []byte, uint32) error GetCalled func([]byte) ([]byte, error) GetFromCurrentEpochCalled func([]byte) ([]byte, error) - TakeSnapshotCalled func([]byte, []byte, chan core.KeyValueHolder, chan error, common.SnapshotStatisticsHandler, uint32) - SetCheckpointCalled func([]byte, []byte, chan core.KeyValueHolder, chan error, common.SnapshotStatisticsHandler) + TakeSnapshotCalled func([]byte, []byte, chan core.KeyValueHolder, chan []byte, chan error, common.SnapshotStatisticsHandler, uint32) + SetCheckpointCalled func([]byte, []byte, chan core.KeyValueHolder, chan []byte, chan error, common.SnapshotStatisticsHandler) GetDbThatContainsHashCalled func([]byte) common.DBWriteCacher IsPruningEnabledCalled func() bool IsPruningBlockedCalled func() bool @@ -81,12 +81,13 @@ func (sms *StorageManagerStub) TakeSnapshot( rootHash []byte, mainTrieRootHash []byte, leavesChan chan core.KeyValueHolder, + missingNodesChan chan []byte, errChan chan error, stats common.SnapshotStatisticsHandler, epoch uint32, ) { if sms.TakeSnapshotCalled != nil { - sms.TakeSnapshotCalled(rootHash, mainTrieRootHash, leavesChan, errChan, stats, epoch) + sms.TakeSnapshotCalled(rootHash, mainTrieRootHash, leavesChan, missingNodesChan, errChan, stats, epoch) } } @@ -95,11 +96,12 @@ func (sms *StorageManagerStub) SetCheckpoint( rootHash []byte, mainTrieRootHash []byte, leavesChan chan core.KeyValueHolder, + missingNodesChan chan []byte, errChan chan error, stats common.SnapshotStatisticsHandler, ) { if sms.SetCheckpointCalled != nil { - sms.SetCheckpointCalled(rootHash, mainTrieRootHash, leavesChan, errChan, stats) + sms.SetCheckpointCalled(rootHash, mainTrieRootHash, leavesChan, missingNodesChan, errChan, stats) } } diff --git a/testscommon/trie/dataTrieTrackerStub.go b/testscommon/trie/dataTrieTrackerStub.go index 05454e53b07..d8fba985b32 
100644 --- a/testscommon/trie/dataTrieTrackerStub.go +++ b/testscommon/trie/dataTrieTrackerStub.go @@ -6,42 +6,54 @@ import ( // DataTrieTrackerStub - type DataTrieTrackerStub struct { - ClearDataCachesCalled func() - DirtyDataCalled func() map[string][]byte - RetrieveValueCalled func(key []byte) ([]byte, error) - SaveKeyValueCalled func(key []byte, value []byte) error - SetDataTrieCalled func(tr common.Trie) - DataTrieCalled func() common.Trie -} - -// ClearDataCaches - -func (dtts *DataTrieTrackerStub) ClearDataCaches() { - dtts.ClearDataCachesCalled() -} - -// DirtyData - -func (dtts *DataTrieTrackerStub) DirtyData() map[string][]byte { - return dtts.DirtyDataCalled() + RetrieveValueCalled func(key []byte) ([]byte, error) + SaveKeyValueCalled func(key []byte, value []byte) error + SetDataTrieCalled func(tr common.Trie) + DataTrieCalled func() common.Trie + SaveDirtyDataCalled func(trie common.Trie) (map[string][]byte, error) } // RetrieveValue - func (dtts *DataTrieTrackerStub) RetrieveValue(key []byte) ([]byte, error) { - return dtts.RetrieveValueCalled(key) + if dtts.RetrieveValueCalled != nil { + return dtts.RetrieveValueCalled(key) + } + + return []byte{}, nil } // SaveKeyValue - func (dtts *DataTrieTrackerStub) SaveKeyValue(key []byte, value []byte) error { - return dtts.SaveKeyValueCalled(key, value) + if dtts.SaveKeyValueCalled != nil { + return dtts.SaveKeyValueCalled(key, value) + } + + return nil } // SetDataTrie - func (dtts *DataTrieTrackerStub) SetDataTrie(tr common.Trie) { - dtts.SetDataTrieCalled(tr) + if dtts.SetDataTrieCalled != nil { + dtts.SetDataTrieCalled(tr) + } } // DataTrie - -func (dtts *DataTrieTrackerStub) DataTrie() common.Trie { - return dtts.DataTrieCalled() +func (dtts *DataTrieTrackerStub) DataTrie() common.DataTrieHandler { + if dtts.DataTrieCalled != nil { + return dtts.DataTrieCalled() + } + + return nil +} + +// SaveDirtyData - +func (dtts *DataTrieTrackerStub) SaveDirtyData(mainTrie common.Trie) (map[string][]byte, error) { + if dtts.SaveDirtyDataCalled != nil { + return dtts.SaveDirtyDataCalled(mainTrie) + } + + return map[string][]byte{}, nil } // IsInterfaceNil returns true if there is no value under the interface diff --git a/testscommon/trie/snapshotPruningStorerStub.go b/testscommon/trie/snapshotPruningStorerStub.go index f9e8cb55645..1d03641437c 100644 --- a/testscommon/trie/snapshotPruningStorerStub.go +++ b/testscommon/trie/snapshotPruningStorerStub.go @@ -1,13 +1,14 @@ package trie import ( + "github.com/ElrondNetwork/elrond-go-core/core" "github.com/ElrondNetwork/elrond-go/testscommon" ) // SnapshotPruningStorerStub - type SnapshotPruningStorerStub struct { *testscommon.MemDbMock - GetFromOldEpochsWithoutAddingToCacheCalled func(key []byte) ([]byte, error) + GetFromOldEpochsWithoutAddingToCacheCalled func(key []byte) ([]byte, core.OptionalUint32, error) GetFromLastEpochCalled func(key []byte) ([]byte, error) GetFromCurrentEpochCalled func(key []byte) ([]byte, error) GetFromEpochCalled func(key []byte, epoch uint32) ([]byte, error) @@ -18,12 +19,12 @@ type SnapshotPruningStorerStub struct { } // GetFromOldEpochsWithoutAddingToCache - -func (spss *SnapshotPruningStorerStub) GetFromOldEpochsWithoutAddingToCache(key []byte) ([]byte, error) { +func (spss *SnapshotPruningStorerStub) GetFromOldEpochsWithoutAddingToCache(key []byte) ([]byte, core.OptionalUint32, error) { if spss.GetFromOldEpochsWithoutAddingToCacheCalled != nil { return spss.GetFromOldEpochsWithoutAddingToCacheCalled(key) } - return nil, nil + return nil, core.OptionalUint32{}, 
nil } // PutInEpoch - diff --git a/testscommon/trie/trieStub.go b/testscommon/trie/trieStub.go index 261abf94944..b0dcb1a99b1 100644 --- a/testscommon/trie/trieStub.go +++ b/testscommon/trie/trieStub.go @@ -12,26 +12,25 @@ var errNotImplemented = errors.New("not implemented") // TrieStub - type TrieStub struct { - GetCalled func(key []byte) ([]byte, error) - UpdateCalled func(key, value []byte) error - DeleteCalled func(key []byte) error - RootCalled func() ([]byte, error) - CommitCalled func() error - RecreateCalled func(root []byte) (common.Trie, error) - RecreateFromEpochCalled func(options common.RootHashHolder) (common.Trie, error) - GetObsoleteHashesCalled func() [][]byte - AppendToOldHashesCalled func([][]byte) - GetSerializedNodesCalled func([]byte, uint64) ([][]byte, uint64, error) - GetAllHashesCalled func() ([][]byte, error) - GetAllLeavesOnChannelCalled func(leavesChannel chan core.KeyValueHolder, ctx context.Context, rootHash []byte) error - GetProofCalled func(key []byte) ([][]byte, []byte, error) - VerifyProofCalled func(rootHash []byte, key []byte, proof [][]byte) (bool, error) - GetStorageManagerCalled func() common.StorageManager - GetSerializedNodeCalled func(bytes []byte) ([]byte, error) - GetNumNodesCalled func() common.NumNodesDTO - GetOldRootCalled func() []byte - MarkStorerAsSyncedAndActiveCalled func() - CloseCalled func() error + GetCalled func(key []byte) ([]byte, error) + UpdateCalled func(key, value []byte) error + DeleteCalled func(key []byte) error + RootCalled func() ([]byte, error) + CommitCalled func() error + RecreateCalled func(root []byte) (common.Trie, error) + RecreateFromEpochCalled func(options common.RootHashHolder) (common.Trie, error) + GetObsoleteHashesCalled func() [][]byte + AppendToOldHashesCalled func([][]byte) + GetSerializedNodesCalled func([]byte, uint64) ([][]byte, uint64, error) + GetAllHashesCalled func() ([][]byte, error) + GetAllLeavesOnChannelCalled func(leavesChannel chan core.KeyValueHolder, ctx context.Context, rootHash []byte, keyBuilder common.KeyBuilder) error + GetProofCalled func(key []byte) ([][]byte, []byte, error) + VerifyProofCalled func(rootHash []byte, key []byte, proof [][]byte) (bool, error) + GetStorageManagerCalled func() common.StorageManager + GetSerializedNodeCalled func(bytes []byte) ([]byte, error) + GetNumNodesCalled func() common.NumNodesDTO + GetOldRootCalled func() []byte + CloseCalled func() error } // GetStorageManager - @@ -62,9 +61,9 @@ func (ts *TrieStub) VerifyProof(rootHash []byte, key []byte, proof [][]byte) (bo } // GetAllLeavesOnChannel - -func (ts *TrieStub) GetAllLeavesOnChannel(leavesChannel chan core.KeyValueHolder, ctx context.Context, rootHash []byte) error { +func (ts *TrieStub) GetAllLeavesOnChannel(leavesChannel chan core.KeyValueHolder, ctx context.Context, rootHash []byte, keyBuilder common.KeyBuilder) error { if ts.GetAllLeavesOnChannelCalled != nil { - return ts.GetAllLeavesOnChannelCalled(leavesChannel, ctx, rootHash) + return ts.GetAllLeavesOnChannelCalled(leavesChannel, ctx, rootHash, keyBuilder) } return nil @@ -205,13 +204,6 @@ func (ts *TrieStub) GetOldRoot() []byte { return nil } -// MarkStorerAsSyncedAndActive - -func (ts *TrieStub) MarkStorerAsSyncedAndActive() { - if ts.MarkStorerAsSyncedAndActiveCalled != nil { - ts.MarkStorerAsSyncedAndActiveCalled() - } -} - // Close - func (ts *TrieStub) Close() error { if ts.CloseCalled != nil { diff --git a/testscommon/utils.go b/testscommon/utils.go index 9423b1b3e48..6324b158b48 100644 --- a/testscommon/utils.go +++ 
b/testscommon/utils.go @@ -5,8 +5,8 @@ import ( "time" "github.com/ElrondNetwork/elrond-go/storage" - "github.com/ElrondNetwork/elrond-go/storage/memorydb" - "github.com/ElrondNetwork/elrond-go/storage/storageUnit" + "github.com/ElrondNetwork/elrond-go/storage/database" + "github.com/ElrondNetwork/elrond-go/storage/storageunit" ) // HashSize holds the size of a typical hash used by the protocol @@ -23,9 +23,9 @@ func CreateMemUnit() storage.Storer { capacity := uint32(10) shards := uint32(1) sizeInBytes := uint64(0) - cache, _ := storageUnit.NewCache(storageUnit.CacheConfig{Type: storageUnit.LRUCache, Capacity: capacity, Shards: shards, SizeInBytes: sizeInBytes}) - persist, _ := memorydb.NewlruDB(100000) - unit, _ := storageUnit.NewStorageUnit(cache, persist) + cache, _ := storageunit.NewCache(storageunit.CacheConfig{Type: storageunit.LRUCache, Capacity: capacity, Shards: shards, SizeInBytes: sizeInBytes}) + persist, _ := database.NewlruDB(100000) + unit, _ := storageunit.NewStorageUnit(cache, persist) return unit } diff --git a/trie/branchNode.go b/trie/branchNode.go index 620c51242a1..3e1db2b72c5 100644 --- a/trie/branchNode.go +++ b/trie/branchNode.go @@ -5,6 +5,7 @@ import ( "encoding/hex" "fmt" "io" + "strings" "sync" "github.com/ElrondNetwork/elrond-go-core/core" @@ -338,6 +339,7 @@ func (bn *branchNode) commitCheckpoint( func (bn *branchNode) commitSnapshot( db common.DBWriteCacher, leavesChan chan core.KeyValueHolder, + missingNodesChan chan []byte, ctx context.Context, stats common.SnapshotStatisticsHandler, idleProvider IdleNodeProvider, @@ -354,6 +356,11 @@ func (bn *branchNode) commitSnapshot( for i := range bn.children { err = resolveIfCollapsed(bn, byte(i), db) if err != nil { + if strings.Contains(err.Error(), common.GetNodeFromDBErrorString) { + log.Error(err.Error()) + missingNodesChan <- bn.EncodedChildren[i] + continue + } return err } @@ -361,7 +368,7 @@ func (bn *branchNode) commitSnapshot( continue } - err = bn.children[i].commitSnapshot(db, leavesChan, ctx, stats, idleProvider) + err = bn.children[i].commitSnapshot(db, leavesChan, missingNodesChan, ctx, stats, idleProvider) if err != nil { return err } @@ -773,7 +780,8 @@ func (bn *branchNode) loadChildren(getNode func([]byte) (node, error)) ([][]byte func (bn *branchNode) getAllLeavesOnChannel( leavesChannel chan core.KeyValueHolder, - key []byte, db common.DBWriteCacher, + keyBuilder common.KeyBuilder, + db common.DBWriteCacher, marshalizer marshal.Marshalizer, chanClose chan struct{}, ctx context.Context, @@ -801,8 +809,9 @@ func (bn *branchNode) getAllLeavesOnChannel( continue } - childKey := append(key, byte(i)) - err = bn.children[i].getAllLeavesOnChannel(leavesChannel, childKey, db, marshalizer, chanClose, ctx) + clonedKeyBuilder := keyBuilder.Clone() + clonedKeyBuilder.BuildKey([]byte{byte(i)}) + err = bn.children[i].getAllLeavesOnChannel(leavesChannel, clonedKeyBuilder, db, marshalizer, chanClose, ctx) if err != nil { return err } diff --git a/trie/branchNode_test.go b/trie/branchNode_test.go index 2e0675c04c5..15068c3d153 100644 --- a/trie/branchNode_test.go +++ b/trie/branchNode_test.go @@ -12,7 +12,7 @@ import ( "github.com/ElrondNetwork/elrond-go/common" "github.com/ElrondNetwork/elrond-go/config" elrondErrors "github.com/ElrondNetwork/elrond-go/errors" - "github.com/ElrondNetwork/elrond-go/storage/lrucache" + "github.com/ElrondNetwork/elrond-go/storage/cache" "github.com/ElrondNetwork/elrond-go/testscommon" "github.com/ElrondNetwork/elrond-go/testscommon/hashingMocks" trieMock 
"github.com/ElrondNetwork/elrond-go/testscommon/trie" @@ -1033,7 +1033,7 @@ func TestBranchNode_getChildrenCollapsedBn(t *testing.T) { db := testscommon.NewMemDbMock() bn, collapsedBn := getBnAndCollapsedBn(getTestMarshalizerAndHasher()) - _ = bn.commitSnapshot(db, nil, context.Background(), &trieMock.MockStatistics{}, &testscommon.ProcessStatusHandlerStub{}) + _ = bn.commitSnapshot(db, nil, nil, context.Background(), &trieMock.MockStatistics{}, &testscommon.ProcessStatusHandlerStub{}) children, err := collapsedBn.getChildren(db) assert.Nil(t, err) @@ -1067,7 +1067,7 @@ func TestBranchNode_loadChildren(t *testing.T) { tr := initTrie() _ = tr.root.setRootHash() nodes, _ := getEncodedTrieNodesAndHashes(tr) - nodesCacher, _ := lrucache.NewCache(100) + nodesCacher, _ := cache.NewLRUCache(100) for i := range nodes { n, _ := NewInterceptedTrieNode(nodes[i], marsh, hasher) nodesCacher.Put(n.hash, n, len(n.GetSerialized())) @@ -1233,8 +1233,8 @@ func TestBranchNode_printShouldNotPanicEvenIfNodeIsCollapsed(t *testing.T) { db := testscommon.NewMemDbMock() bn, collapsedBn := getBnAndCollapsedBn(getTestMarshalizerAndHasher()) - _ = bn.commitSnapshot(db, nil, context.Background(), &trieMock.MockStatistics{}, &testscommon.ProcessStatusHandlerStub{}) - _ = collapsedBn.commitSnapshot(db, nil, context.Background(), &trieMock.MockStatistics{}, &testscommon.ProcessStatusHandlerStub{}) + _ = bn.commitSnapshot(db, nil, nil, context.Background(), &trieMock.MockStatistics{}, &testscommon.ProcessStatusHandlerStub{}) + _ = collapsedBn.commitSnapshot(db, nil, nil, context.Background(), &trieMock.MockStatistics{}, &testscommon.ProcessStatusHandlerStub{}) bn.print(bnWriter, 0, db) collapsedBn.print(collapsedBnWriter, 0, db) @@ -1271,7 +1271,7 @@ func TestBranchNode_getAllHashesResolvesCollapsed(t *testing.T) { db := testscommon.NewMemDbMock() bn, collapsedBn := getBnAndCollapsedBn(getTestMarshalizerAndHasher()) - _ = bn.commitSnapshot(db, nil, context.Background(), &trieMock.MockStatistics{}, &testscommon.ProcessStatusHandlerStub{}) + _ = bn.commitSnapshot(db, nil, nil, context.Background(), &trieMock.MockStatistics{}, &testscommon.ProcessStatusHandlerStub{}) hashes, err := collapsedBn.getAllHashes(db) assert.Nil(t, err) @@ -1355,7 +1355,7 @@ func TestBranchNode_commitContextDone(t *testing.T) { err := bn.commitCheckpoint(db, db, nil, nil, ctx, &trieMock.MockStatistics{}, &testscommon.ProcessStatusHandlerStub{}) assert.Equal(t, elrondErrors.ErrContextClosing, err) - err = bn.commitSnapshot(db, nil, ctx, &trieMock.MockStatistics{}, &testscommon.ProcessStatusHandlerStub{}) + err = bn.commitSnapshot(db, nil, nil, ctx, &trieMock.MockStatistics{}, &testscommon.ProcessStatusHandlerStub{}) assert.Equal(t, elrondErrors.ErrContextClosing, err) } diff --git a/trie/doubleListSync_test.go b/trie/doubleListSync_test.go index f6d72bc5edb..1979c221b6d 100644 --- a/trie/doubleListSync_test.go +++ b/trie/doubleListSync_test.go @@ -13,8 +13,8 @@ import ( "github.com/ElrondNetwork/elrond-go/config" "github.com/ElrondNetwork/elrond-go/errors" "github.com/ElrondNetwork/elrond-go/storage" - "github.com/ElrondNetwork/elrond-go/storage/memorydb" - "github.com/ElrondNetwork/elrond-go/storage/storageUnit" + "github.com/ElrondNetwork/elrond-go/storage/database" + "github.com/ElrondNetwork/elrond-go/storage/storageunit" "github.com/ElrondNetwork/elrond-go/testscommon" "github.com/ElrondNetwork/elrond-go/testscommon/hashingMocks" "github.com/ElrondNetwork/elrond-go/trie/hashesHolder" @@ -29,9 +29,9 @@ func createMemUnit() storage.Storer { 
capacity := uint32(10) shards := uint32(1) sizeInBytes := uint64(0) - cache, _ := storageUnit.NewCache(storageUnit.CacheConfig{Type: storageUnit.LRUCache, Capacity: capacity, Shards: shards, SizeInBytes: sizeInBytes}) - persist, _ := memorydb.NewlruDB(100000) - unit, _ := storageUnit.NewStorageUnit(cache, persist) + cache, _ := storageunit.NewCache(storageunit.CacheConfig{Type: storageunit.LRUCache, Capacity: capacity, Shards: shards, SizeInBytes: sizeInBytes}) + persist, _ := database.NewlruDB(100000) + unit, _ := storageunit.NewStorageUnit(cache, persist) return unit } @@ -69,8 +69,8 @@ func createInMemoryTrieFromDB(db storage.Persister) (common.Trie, storage.Storer capacity := uint32(10) shards := uint32(1) sizeInBytes := uint64(0) - cache, _ := storageUnit.NewCache(storageUnit.CacheConfig{Type: storageUnit.LRUCache, Capacity: capacity, Shards: shards, SizeInBytes: sizeInBytes}) - unit, _ := storageUnit.NewStorageUnit(cache, db) + cache, _ := storageunit.NewCache(storageunit.CacheConfig{Type: storageunit.LRUCache, Capacity: capacity, Shards: shards, SizeInBytes: sizeInBytes}) + unit, _ := storageunit.NewStorageUnit(cache, db) tsm, _ := createTrieStorageManager(unit) tr, _ := NewTrie(tsm, marshalizer, hasherMock, 6) diff --git a/trie/errors.go b/trie/errors.go index ec3a44f84bc..ba9abe19a81 100644 --- a/trie/errors.go +++ b/trie/errors.go @@ -52,9 +52,6 @@ var ErrNilLeafNode = errors.New("the leaf node is nil") // ErrNilNode is raised when we reach a nil node var ErrNilNode = errors.New("the node is nil") -// ErrInvalidLength signals that length of the array is invalid -var ErrInvalidLength = errors.New("invalid array length") - // ErrWrongTypeAssertion signals that wrong type was provided var ErrWrongTypeAssertion = errors.New("wrong type assertion") diff --git a/trie/extensionNode.go b/trie/extensionNode.go index fd1e03b49e2..dce3b265968 100644 --- a/trie/extensionNode.go +++ b/trie/extensionNode.go @@ -6,6 +6,7 @@ import ( "encoding/hex" "fmt" "io" + "strings" "sync" "github.com/ElrondNetwork/elrond-go-core/core" @@ -244,6 +245,7 @@ func (en *extensionNode) commitCheckpoint( func (en *extensionNode) commitSnapshot( db common.DBWriteCacher, leavesChan chan core.KeyValueHolder, + missingNodesChan chan []byte, ctx context.Context, stats common.SnapshotStatisticsHandler, idleProvider IdleNodeProvider, @@ -258,13 +260,22 @@ func (en *extensionNode) commitSnapshot( } err = resolveIfCollapsed(en, 0, db) + isMissingNodeErr := false if err != nil { - return err + isMissingNodeErr = strings.Contains(err.Error(), common.GetNodeFromDBErrorString) + if !isMissingNodeErr { + return err + } } - err = en.child.commitSnapshot(db, leavesChan, ctx, stats, idleProvider) - if err != nil { - return err + if isMissingNodeErr { + log.Error(err.Error()) + missingNodesChan <- en.EncodedChild + } else { + err = en.child.commitSnapshot(db, leavesChan, missingNodesChan, ctx, stats, idleProvider) + if err != nil { + return err + } } return en.saveToStorage(db, stats) @@ -639,7 +650,8 @@ func (en *extensionNode) loadChildren(getNode func([]byte) (node, error)) ([][]b func (en *extensionNode) getAllLeavesOnChannel( leavesChannel chan core.KeyValueHolder, - key []byte, db common.DBWriteCacher, + keyBuilder common.KeyBuilder, + db common.DBWriteCacher, marshalizer marshal.Marshalizer, chanClose chan struct{}, ctx context.Context, @@ -662,8 +674,8 @@ func (en *extensionNode) getAllLeavesOnChannel( return err } - childKey := append(key, en.Key...) 
- err = en.child.getAllLeavesOnChannel(leavesChannel, childKey, db, marshalizer, chanClose, ctx) + keyBuilder.BuildKey(en.Key) + err = en.child.getAllLeavesOnChannel(leavesChannel, keyBuilder.Clone(), db, marshalizer, chanClose, ctx) if err != nil { return err } diff --git a/trie/extensionNode_test.go b/trie/extensionNode_test.go index 97b89088282..a64a0f7c6f5 100644 --- a/trie/extensionNode_test.go +++ b/trie/extensionNode_test.go @@ -8,7 +8,7 @@ import ( "github.com/ElrondNetwork/elrond-go/common" elrondErrors "github.com/ElrondNetwork/elrond-go/errors" - "github.com/ElrondNetwork/elrond-go/storage/lrucache" + "github.com/ElrondNetwork/elrond-go/storage/cache" "github.com/ElrondNetwork/elrond-go/testscommon" "github.com/ElrondNetwork/elrond-go/testscommon/hashingMocks" trieMock "github.com/ElrondNetwork/elrond-go/testscommon/trie" @@ -806,7 +806,7 @@ func TestExtensionNode_loadChildren(t *testing.T) { _ = tr.Update([]byte("ddog"), []byte("cat")) _ = tr.root.setRootHash() nodes, _ := getEncodedTrieNodesAndHashes(tr) - nodesCacher, _ := lrucache.NewCache(100) + nodesCacher, _ := cache.NewLRUCache(100) for i := range nodes { n, _ := NewInterceptedTrieNode(nodes[i], marsh, hasher) nodesCacher.Put(n.hash, n, len(n.GetSerialized())) @@ -902,7 +902,7 @@ func TestExtensionNode_printShouldNotPanicEvenIfNodeIsCollapsed(t *testing.T) { db := testscommon.NewMemDbMock() en, collapsedEn := getEnAndCollapsedEn() _ = en.commitDirty(0, 5, db, db) - _ = collapsedEn.commitSnapshot(db, nil, context.Background(), &trieMock.MockStatistics{}, &testscommon.ProcessStatusHandlerStub{}) + _ = collapsedEn.commitSnapshot(db, nil, nil, context.Background(), &trieMock.MockStatistics{}, &testscommon.ProcessStatusHandlerStub{}) en.print(enWriter, 0, db) collapsedEn.print(collapsedEnWriter, 0, db) @@ -915,7 +915,7 @@ func TestExtensionNode_getDirtyHashesFromCleanNode(t *testing.T) { db := testscommon.NewMemDbMock() en, _ := getEnAndCollapsedEn() - _ = en.commitSnapshot(db, nil, context.Background(), &trieMock.MockStatistics{}, &testscommon.ProcessStatusHandlerStub{}) + _ = en.commitSnapshot(db, nil, nil, context.Background(), &trieMock.MockStatistics{}, &testscommon.ProcessStatusHandlerStub{}) dirtyHashes := make(common.ModifiedHashes) err := en.getDirtyHashes(dirtyHashes) @@ -940,7 +940,7 @@ func TestExtensionNode_getAllHashesResolvesCollapsed(t *testing.T) { trieNodes := 5 db := testscommon.NewMemDbMock() en, collapsedEn := getEnAndCollapsedEn() - _ = en.commitSnapshot(db, nil, context.Background(), &trieMock.MockStatistics{}, &testscommon.ProcessStatusHandlerStub{}) + _ = en.commitSnapshot(db, nil, nil, context.Background(), &trieMock.MockStatistics{}, &testscommon.ProcessStatusHandlerStub{}) hashes, err := collapsedEn.getAllHashes(db) assert.Nil(t, err) @@ -1025,7 +1025,7 @@ func TestExtensionNode_commitContextDone(t *testing.T) { err := en.commitCheckpoint(db, db, nil, nil, ctx, &trieMock.MockStatistics{}, &testscommon.ProcessStatusHandlerStub{}) assert.Equal(t, elrondErrors.ErrContextClosing, err) - err = en.commitSnapshot(db, nil, ctx, &trieMock.MockStatistics{}, &testscommon.ProcessStatusHandlerStub{}) + err = en.commitSnapshot(db, nil, nil, ctx, &trieMock.MockStatistics{}, &testscommon.ProcessStatusHandlerStub{}) assert.Equal(t, elrondErrors.ErrContextClosing, err) } diff --git a/trie/factory/trieCreator_test.go b/trie/factory/trieCreator_test.go index 05311baad44..89e1aea61bd 100644 --- a/trie/factory/trieCreator_test.go +++ b/trie/factory/trieCreator_test.go @@ -146,6 +146,7 @@ func 
TestTrieCreator_CreateWithNilMainStorerShouldErr(t *testing.T) { createArgs.MainStorer = nil _, tr, err := tf.Create(createArgs) require.Nil(t, tr) + require.NotNil(t, err) require.True(t, strings.Contains(err.Error(), trie.ErrNilStorer.Error())) } @@ -160,6 +161,7 @@ func TestTrieCreator_CreateWithNilCheckpointsStorerShouldErr(t *testing.T) { createArgs.CheckpointsStorer = nil _, tr, err := tf.Create(createArgs) require.Nil(t, tr) + require.NotNil(t, err) require.True(t, strings.Contains(err.Error(), trie.ErrNilStorer.Error())) } @@ -194,6 +196,7 @@ func testWithMissingStorer(missingUnit dataRetriever.UnitType) func(t *testing.T }) require.True(t, check.IfNil(holder)) require.Nil(t, storageManager) + require.NotNil(t, err) require.True(t, strings.Contains(err.Error(), storage.ErrKeyNotFound.Error())) require.True(t, strings.Contains(err.Error(), missingUnit.String())) } diff --git a/trie/interface.go b/trie/interface.go index 4fdb7f6b886..6ef7438042c 100644 --- a/trie/interface.go +++ b/trie/interface.go @@ -39,7 +39,7 @@ type node interface { isValid() bool setDirty(bool) loadChildren(func([]byte) (node, error)) ([][]byte, []node, error) - getAllLeavesOnChannel(chan core.KeyValueHolder, []byte, common.DBWriteCacher, marshal.Marshalizer, chan struct{}, context.Context) error + getAllLeavesOnChannel(chan core.KeyValueHolder, common.KeyBuilder, common.DBWriteCacher, marshal.Marshalizer, chan struct{}, context.Context) error getAllHashes(db common.DBWriteCacher) ([][]byte, error) getNextHashAndKey([]byte) (bool, []byte, []byte) getNumNodes() common.NumNodesDTO @@ -47,7 +47,7 @@ type node interface { commitDirty(level byte, maxTrieLevelInMemory uint, originDb common.DBWriteCacher, targetDb common.DBWriteCacher) error commitCheckpoint(originDb common.DBWriteCacher, targetDb common.DBWriteCacher, checkpointHashes CheckpointHashesHolder, leavesChan chan core.KeyValueHolder, ctx context.Context, stats common.SnapshotStatisticsHandler, idleProvider IdleNodeProvider) error - commitSnapshot(originDb common.DBWriteCacher, leavesChan chan core.KeyValueHolder, ctx context.Context, stats common.SnapshotStatisticsHandler, idleProvider IdleNodeProvider) error + commitSnapshot(originDb common.DBWriteCacher, leavesChan chan core.KeyValueHolder, missingNodesChan chan []byte, ctx context.Context, stats common.SnapshotStatisticsHandler, idleProvider IdleNodeProvider) error getMarshalizer() marshal.Marshalizer setMarshalizer(marshal.Marshalizer) @@ -64,7 +64,7 @@ type dbWithGetFromEpoch interface { type snapshotNode interface { commitCheckpoint(originDb common.DBWriteCacher, targetDb common.DBWriteCacher, checkpointHashes CheckpointHashesHolder, leavesChan chan core.KeyValueHolder, ctx context.Context, stats common.SnapshotStatisticsHandler, idleProvider IdleNodeProvider) error - commitSnapshot(originDb common.DBWriteCacher, leavesChan chan core.KeyValueHolder, ctx context.Context, stats common.SnapshotStatisticsHandler, idleProvider IdleNodeProvider) error + commitSnapshot(originDb common.DBWriteCacher, leavesChan chan core.KeyValueHolder, missingNodesChan chan []byte, ctx context.Context, stats common.SnapshotStatisticsHandler, idleProvider IdleNodeProvider) error } // RequestHandler defines the methods through which request to data can be made @@ -97,7 +97,7 @@ type epochStorer interface { type snapshotPruningStorer interface { common.DBWriteCacher - GetFromOldEpochsWithoutAddingToCache(key []byte) ([]byte, error) + GetFromOldEpochsWithoutAddingToCache(key []byte) ([]byte, core.OptionalUint32, error) 
GetFromLastEpoch(key []byte) ([]byte, error) PutInEpoch(key []byte, data []byte, epoch uint32) error PutInEpochWithoutCache(key []byte, data []byte, epoch uint32) error @@ -122,3 +122,8 @@ type IdleNodeProvider interface { type storageManagerExtension interface { RemoveFromCheckpointHashesHolder(hash []byte) } + +// StorageMarker is used to mark the given storer as synced and active +type StorageMarker interface { + MarkStorerAsSyncedAndActive(storer common.StorageManager) +}
diff --git a/trie/keyBuilder/disabledKeyBuilder.go b/trie/keyBuilder/disabledKeyBuilder.go new file mode 100644 index 00000000000..603bc8561d2 --- /dev/null +++ b/trie/keyBuilder/disabledKeyBuilder.go @@ -0,0 +1,28 @@ +package keyBuilder + +import ( + "github.com/ElrondNetwork/elrond-go/common" +) + +type disabledKeyBuilder struct { +} + +// NewDisabledKeyBuilder creates a new disabled key builder. This should be used when the key is not needed +func NewDisabledKeyBuilder() *disabledKeyBuilder { + return &disabledKeyBuilder{} +} + +// BuildKey does nothing for this implementation +func (dkb *disabledKeyBuilder) BuildKey(_ []byte) { + +} + +// GetKey returns an empty byte array for this implementation +func (dkb *disabledKeyBuilder) GetKey() ([]byte, error) { + return []byte{}, nil +} + +// Clone returns a new disabled key builder +func (dkb *disabledKeyBuilder) Clone() common.KeyBuilder { + return &disabledKeyBuilder{} +}
diff --git a/trie/keyBuilder/errors.go b/trie/keyBuilder/errors.go new file mode 100644 index 00000000000..6c894397541 --- /dev/null +++ b/trie/keyBuilder/errors.go @@ -0,0 +1,6 @@ +package keyBuilder + +import "errors" + +// ErrInvalidLength signals that the length of the array is invalid +var ErrInvalidLength = errors.New("invalid array length")
diff --git a/trie/keyBuilder/keyBuilder.go b/trie/keyBuilder/keyBuilder.go new file mode 100644 index 00000000000..84228a3cb77 --- /dev/null +++ b/trie/keyBuilder/keyBuilder.go @@ -0,0 +1,61 @@ +package keyBuilder + +import ( + "github.com/ElrondNetwork/elrond-go/common" +) + +const ( + // NibbleSize marks the size of a byte nibble + NibbleSize = 4 + + keyLength = 32 +) + +type keyBuilder struct { + key []byte +} + +// NewKeyBuilder creates a new key builder. This is used for building trie keys when traversing the trie. +// Use this only if you traverse the trie from the root, otherwise hexToTrieKeyBytes might fail +func NewKeyBuilder() *keyBuilder { + return &keyBuilder{ + key: make([]byte, 0, keyLength), + } +} + +// BuildKey appends the given byte array to the existing key +func (kb *keyBuilder) BuildKey(keyPart []byte) { + kb.key = append(kb.key, keyPart...) +} + +// GetKey transforms the key from hex to trie key, and returns it. +// It is mandatory that GetKey always returns a new byte slice, not a pointer to an existing one +func (kb *keyBuilder) GetKey() ([]byte, error) { + return hexToTrieKeyBytes(kb.key) +} + +// Clone returns a new KeyBuilder with the same key +func (kb *keyBuilder) Clone() common.KeyBuilder { + return &keyBuilder{ + key: kb.key, + } +} + +// hexToTrieKeyBytes transforms hex nibbles into key bytes. The hex terminator is removed from the end of the hex slice, +// and then the hex slice is reversed when forming the key bytes.
+func hexToTrieKeyBytes(hex []byte) ([]byte, error) { + hex = hex[:len(hex)-1] + length := len(hex) + if length%2 != 0 { + return nil, ErrInvalidLength + } + + key := make([]byte, length/2) + hexSliceIndex := 0 + for i := len(key) - 1; i >= 0; i-- { + key[i] = hex[hexSliceIndex+1]<<NibbleSize | hex[hexSliceIndex] + hexSliceIndex += 2 + } + + return key, nil +}
diff --git a/trie/node.go b/trie/node.go --- a/trie/node.go +++ b/trie/node.go @@ ... @@ func keyBytesToHex(str []byte) []byte { for i := hexLength - 2; i > 0; i -= 2 { - nibbles[i] = str[hexSliceIndex] >> nibbleSize + nibbles[i] = str[hexSliceIndex] >> keyBuilder.NibbleSize nibbles[i-1] = str[hexSliceIndex] & nibbleMask hexSliceIndex++ } @@ -219,25 +219,6 @@ func keyBytesToHex(str []byte) []byte { return nibbles } -// hexToKeyBytes transforms hex nibbles into key bytes. The hex terminator is removed from the end of the hex slice, -// and then the hex slice is reversed when forming the key bytes. -func hexToKeyBytes(hex []byte) ([]byte, error) { - hex = hex[:len(hex)-1] - length := len(hex) - if length%2 != 0 { - return nil, ErrInvalidLength - } - - key := make([]byte, length/2) - hexSliceIndex := 0 - for i := len(key) - 1; i >= 0; i-- { - key[i] = hex[hexSliceIndex+1]<<nibbleSize | hex[hexSliceIndex] - hexSliceIndex += 2 - } - - return key, nil -}
diff --git a/trie/snapshotTrieStorageManager.go b/trie/snapshotTrieStorageManager.go --- a/trie/snapshotTrieStorageManager.go +++ b/trie/snapshotTrieStorageManager.go @@ ... @@ +func (stsm *snapshotTrieStorageManager) putInPreviousEpochIfNeeded(key []byte, val []byte, epoch core.OptionalUint32) { + if !epoch.HasValue { + return + } + + if stsm.epoch == 0 { + return + } + + if epoch.Value >= stsm.epoch-1 { + return + } + + if bytes.Equal(key, []byte(common.ActiveDBKey)) || bytes.Equal(key, []byte(common.TrieSyncedKey)) { + return + } + + log.Trace("put missing hash in snapshot storer", "hash", key, "epoch", stsm.epoch-1) + err := stsm.mainSnapshotStorer.PutInEpoch(key, val, stsm.epoch-1) + if err != nil { + log.Warn("can not put in epoch", + "error", err, + "epoch", stsm.epoch-1, + ) + } +} + // Put adds the given value to the main storer func (stsm *snapshotTrieStorageManager) Put(key, data []byte) error { stsm.storageOperationMutex.Lock()
diff --git a/trie/snapshotTrieStorageManager_test.go b/trie/snapshotTrieStorageManager_test.go index 7681fe71b5b..74fafb8c42c 100644 --- a/trie/snapshotTrieStorageManager_test.go +++ b/trie/snapshotTrieStorageManager_test.go @@ -4,11 +4,15 @@ import ( "strings" "testing" + "github.com/ElrondNetwork/elrond-go-core/core" + "github.com/ElrondNetwork/elrond-go/common" "github.com/ElrondNetwork/elrond-go/testscommon/trie" "github.com/stretchr/testify/assert" ) func TestNewSnapshotTrieStorageManagerInvalidStorerType(t *testing.T) { + t.Parallel() + _, trieStorage := newEmptyTrie() stsm, err := newSnapshotTrieStorageManager(trieStorage, 0) @@ -17,6 +21,8 @@ func TestNewSnapshotTrieStorageManagerInvalidStorerType(t *testing.T) { } func TestNewSnapshotTrieStorageManager(t *testing.T) { + t.Parallel() + _, trieStorage := newEmptyTrie() trieStorage.mainStorer = &trie.SnapshotPruningStorerStub{} stsm, err := newSnapshotTrieStorageManager(trieStorage, 0) @@ -25,12 +31,14 @@ func TestNewSnapshotTrieStorageManager(t *testing.T) { } func TestNewSnapshotTrieStorageManager_GetFromOldEpochsWithoutCache(t *testing.T) { + t.Parallel() + _, trieStorage := newEmptyTrie() getFromOldEpochsWithoutCacheCalled := false trieStorage.mainStorer = &trie.SnapshotPruningStorerStub{ - GetFromOldEpochsWithoutAddingToCacheCalled: func(_ []byte) ([]byte, error) { + GetFromOldEpochsWithoutAddingToCacheCalled: func(_ []byte) ([]byte, core.OptionalUint32, error) { getFromOldEpochsWithoutCacheCalled = true - return nil, nil + return nil, core.OptionalUint32{}, nil }, } stsm, _ := newSnapshotTrieStorageManager(trieStorage, 0) @@ -40,6 +48,8 @@ } func TestNewSnapshotTrieStorageManager_PutWithoutCache(t *testing.T) { + t.Parallel() + _, trieStorage := newEmptyTrie() putWithoutCacheCalled := false trieStorage.mainStorer = &trie.SnapshotPruningStorerStub{ @@ -55,6 +65,8 @@ func
TestNewSnapshotTrieStorageManager_PutWithoutCache(t *testing.T) { } func TestNewSnapshotTrieStorageManager_GetFromLastEpoch(t *testing.T) { + t.Parallel() + _, trieStorage := newEmptyTrie() getFromLastEpochCalled := false trieStorage.mainStorer = &trie.SnapshotPruningStorerStub{ @@ -68,3 +80,132 @@ func TestNewSnapshotTrieStorageManager_GetFromLastEpoch(t *testing.T) { _, _ = stsm.GetFromLastEpoch([]byte("key")) assert.True(t, getFromLastEpochCalled) } + +func TestSnapshotTrieStorageManager_AlsoAddInPreviousEpoch(t *testing.T) { + t.Parallel() + + t.Run("HasValue is false", func(t *testing.T) { + val := []byte("val") + _, trieStorage := newEmptyTrie() + trieStorage.mainStorer = &trie.SnapshotPruningStorerStub{ + GetFromOldEpochsWithoutAddingToCacheCalled: func(_ []byte) ([]byte, core.OptionalUint32, error) { + return val, core.OptionalUint32{}, nil + }, + PutInEpochCalled: func(_ []byte, _ []byte, _ uint32) error { + assert.Fail(t, "this should not have been called") + return nil + }, + } + stsm, _ := newSnapshotTrieStorageManager(trieStorage, 5) + + returnedVal, _ := stsm.Get([]byte("key")) + assert.Equal(t, val, returnedVal) + }) + t.Run("epoch is previous epoch", func(t *testing.T) { + val := []byte("val") + _, trieStorage := newEmptyTrie() + trieStorage.mainStorer = &trie.SnapshotPruningStorerStub{ + GetFromOldEpochsWithoutAddingToCacheCalled: func(_ []byte) ([]byte, core.OptionalUint32, error) { + epoch := core.OptionalUint32{ + Value: 4, + HasValue: true, + } + return []byte("val"), epoch, nil + }, + PutInEpochCalled: func(_ []byte, _ []byte, _ uint32) error { + assert.Fail(t, "this should not have been called") + return nil + }, + } + stsm, _ := newSnapshotTrieStorageManager(trieStorage, 5) + + returnedVal, _ := stsm.Get([]byte("key")) + assert.Equal(t, val, returnedVal) + }) + t.Run("epoch is 0", func(t *testing.T) { + val := []byte("val") + _, trieStorage := newEmptyTrie() + trieStorage.mainStorer = &trie.SnapshotPruningStorerStub{ + GetFromOldEpochsWithoutAddingToCacheCalled: func(_ []byte) ([]byte, core.OptionalUint32, error) { + epoch := core.OptionalUint32{ + Value: 4, + HasValue: true, + } + return []byte("val"), epoch, nil + }, + PutInEpochCalled: func(_ []byte, _ []byte, _ uint32) error { + assert.Fail(t, "this should not have been called") + return nil + }, + } + stsm, _ := newSnapshotTrieStorageManager(trieStorage, 0) + + returnedVal, _ := stsm.Get([]byte("key")) + assert.Equal(t, val, returnedVal) + }) + t.Run("key is ActiveDBKey", func(t *testing.T) { + val := []byte("val") + _, trieStorage := newEmptyTrie() + trieStorage.mainStorer = &trie.SnapshotPruningStorerStub{ + GetFromOldEpochsWithoutAddingToCacheCalled: func(_ []byte) ([]byte, core.OptionalUint32, error) { + epoch := core.OptionalUint32{ + Value: 3, + HasValue: true, + } + return []byte("val"), epoch, nil + }, + PutInEpochCalled: func(_ []byte, _ []byte, _ uint32) error { + assert.Fail(t, "this should not have been called") + return nil + }, + } + stsm, _ := newSnapshotTrieStorageManager(trieStorage, 5) + + returnedVal, _ := stsm.Get([]byte(common.ActiveDBKey)) + assert.Equal(t, val, returnedVal) + }) + t.Run("key is TrieSyncedKey", func(t *testing.T) { + val := []byte("val") + _, trieStorage := newEmptyTrie() + trieStorage.mainStorer = &trie.SnapshotPruningStorerStub{ + GetFromOldEpochsWithoutAddingToCacheCalled: func(_ []byte) ([]byte, core.OptionalUint32, error) { + epoch := core.OptionalUint32{ + Value: 3, + HasValue: true, + } + return []byte("val"), epoch, nil + }, + PutInEpochCalled: func(_ []byte, _ 
[]byte, _ uint32) error { + assert.Fail(t, "this should not have been called") + return nil + }, + } + stsm, _ := newSnapshotTrieStorageManager(trieStorage, 5) + + returnedVal, _ := stsm.Get([]byte(common.TrieSyncedKey)) + assert.Equal(t, val, returnedVal) + }) + t.Run("add in previous epoch", func(t *testing.T) { + val := []byte("val") + putInEpochCalled := false + _, trieStorage := newEmptyTrie() + trieStorage.mainStorer = &trie.SnapshotPruningStorerStub{ + GetFromOldEpochsWithoutAddingToCacheCalled: func(_ []byte) ([]byte, core.OptionalUint32, error) { + epoch := core.OptionalUint32{ + Value: 3, + HasValue: true, + } + return []byte("val"), epoch, nil + }, + PutInEpochCalled: func(_ []byte, _ []byte, _ uint32) error { + putInEpochCalled = true + return nil + }, + } + stsm, _ := newSnapshotTrieStorageManager(trieStorage, 5) + + returnedVal, _ := stsm.Get([]byte("key")) + assert.Equal(t, val, returnedVal) + assert.True(t, putInEpochCalled) + }) +}
diff --git a/trie/storageMarker/disabledStorageMarker_test.go b/trie/storageMarker/disabledStorageMarker_test.go new file mode 100644 index 00000000000..a3e745b45f8 --- /dev/null +++ b/trie/storageMarker/disabledStorageMarker_test.go @@ -0,0 +1,16 @@ +package storageMarker + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestDisabledStorageMarker_MarkStorerAsSyncedAndActive(t *testing.T) { + t.Parallel() + + dsm := NewDisabledStorageMarker() + assert.NotNil(t, dsm) + + dsm.MarkStorerAsSyncedAndActive(nil) +}
diff --git a/trie/storageMarker/disabledStorageMarker.go b/trie/storageMarker/disabledStorageMarker.go new file mode 100644 index 00000000000..9c92a1181a1 --- /dev/null +++ b/trie/storageMarker/disabledStorageMarker.go @@ -0,0 +1,15 @@ +package storageMarker + +import "github.com/ElrondNetwork/elrond-go/common" + +type disabledStorageMarker struct { +} + +// NewDisabledStorageMarker creates a new instance of disabledStorageMarker +func NewDisabledStorageMarker() *disabledStorageMarker { + return &disabledStorageMarker{} +} + +// MarkStorerAsSyncedAndActive does nothing for this implementation +func (dsm *disabledStorageMarker) MarkStorerAsSyncedAndActive(_ common.StorageManager) { +}
diff --git a/trie/storageMarker/trieStorageMarker.go b/trie/storageMarker/trieStorageMarker.go new file mode 100644 index 00000000000..62c7a950cb5 --- /dev/null +++ b/trie/storageMarker/trieStorageMarker.go @@ -0,0 +1,41 @@ +package storageMarker + +import ( + logger "github.com/ElrondNetwork/elrond-go-logger" + "github.com/ElrondNetwork/elrond-go/common" +) + +var log = logger.GetOrCreate("trie") + +type trieStorageMarker struct { +} + +// NewTrieStorageMarker creates a new instance of trieStorageMarker +func NewTrieStorageMarker() *trieStorageMarker { + return &trieStorageMarker{} +} + +// MarkStorerAsSyncedAndActive marks the storage as synced and active +func (sm *trieStorageMarker) MarkStorerAsSyncedAndActive(storer common.StorageManager) { + epoch, err := storer.GetLatestStorageEpoch() + if err != nil { + log.Error("getLatestStorageEpoch error", "error", err) + } + + err = storer.Put([]byte(common.TrieSyncedKey), []byte(common.TrieSyncedVal)) + if err != nil { + log.Error("error while putting trieSynced value into main storer after sync", "error", err) + } + log.Debug("set trieSyncedKey in epoch", "epoch", epoch) + + lastEpoch := epoch - 1 + if epoch == 0 { + lastEpoch = 0 + } + + err = storer.PutInEpochWithoutCache([]byte(common.ActiveDBKey), []byte(common.ActiveDBVal), lastEpoch) + if err != nil { + log.Error("error while putting
activeDB value into main storer after sync", "error", err) + } + log.Debug("set activeDB in epoch", "epoch", lastEpoch) +} diff --git a/trie/storageMarker/trieStorageMarker_test.go b/trie/storageMarker/trieStorageMarker_test.go new file mode 100644 index 00000000000..132234dbd17 --- /dev/null +++ b/trie/storageMarker/trieStorageMarker_test.go @@ -0,0 +1,70 @@ +package storageMarker + +import ( + "testing" + + "github.com/ElrondNetwork/elrond-go/common" + "github.com/ElrondNetwork/elrond-go/testscommon" + "github.com/stretchr/testify/assert" +) + +func TestTrieStorageMarker_MarkStorerAsSyncedAndActive(t *testing.T) { + t.Parallel() + + t.Run("mark storer as synced and active epoch 5", func(t *testing.T) { + sm := NewTrieStorageMarker() + assert.NotNil(t, sm) + + trieSyncedKeyPut := false + activeDbKeyPut := false + storer := &testscommon.StorageManagerStub{ + GetLatestStorageEpochCalled: func() (uint32, error) { + return 5, nil + }, + PutCalled: func(key []byte, val []byte) error { + assert.Equal(t, []byte(common.TrieSyncedKey), key) + assert.Equal(t, []byte(common.TrieSyncedVal), val) + trieSyncedKeyPut = true + return nil + }, + PutInEpochWithoutCacheCalled: func(key []byte, val []byte, epoch uint32) error { + assert.Equal(t, []byte(common.ActiveDBKey), key) + assert.Equal(t, []byte(common.ActiveDBVal), val) + assert.Equal(t, uint32(4), epoch) + activeDbKeyPut = true + return nil + }, + } + sm.MarkStorerAsSyncedAndActive(storer) + assert.True(t, trieSyncedKeyPut) + assert.True(t, activeDbKeyPut) + }) + t.Run("mark storer as synced and active epoch 0", func(t *testing.T) { + sm := NewTrieStorageMarker() + assert.NotNil(t, sm) + + trieSyncedKeyPut := false + activeDbKeyPut := false + storer := &testscommon.StorageManagerStub{ + GetLatestStorageEpochCalled: func() (uint32, error) { + return 0, nil + }, + PutCalled: func(key []byte, val []byte) error { + assert.Equal(t, []byte(common.TrieSyncedKey), key) + assert.Equal(t, []byte(common.TrieSyncedVal), val) + trieSyncedKeyPut = true + return nil + }, + PutInEpochWithoutCacheCalled: func(key []byte, val []byte, epoch uint32) error { + assert.Equal(t, []byte(common.ActiveDBKey), key) + assert.Equal(t, []byte(common.ActiveDBVal), val) + assert.Equal(t, uint32(0), epoch) + activeDbKeyPut = true + return nil + }, + } + sm.MarkStorerAsSyncedAndActive(storer) + assert.True(t, trieSyncedKeyPut) + assert.True(t, activeDbKeyPut) + }) +} diff --git a/trie/sync_test.go b/trie/sync_test.go index 5ee5f8b404f..37bc6e26337 100644 --- a/trie/sync_test.go +++ b/trie/sync_test.go @@ -214,7 +214,7 @@ func TestTrieSync_FoundInStorageShouldNotRequest(t *testing.T) { }, } - err = bn.commitSnapshot(db, nil, context.Background(), &trieMock.MockStatistics{}, &testscommon.ProcessStatusHandlerStub{}) + err = bn.commitSnapshot(db, nil, nil, context.Background(), &trieMock.MockStatistics{}, &testscommon.ProcessStatusHandlerStub{}) require.Nil(t, err) arg := createMockArgument(timeout) diff --git a/trie/trieStorageManager.go b/trie/trieStorageManager.go index f6c0331019c..6a7ec0b5520 100644 --- a/trie/trieStorageManager.go +++ b/trie/trieStorageManager.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "fmt" + "strings" "sync" "time" @@ -37,6 +38,7 @@ type snapshotsQueueEntry struct { rootHash []byte mainTrieRootHash []byte leavesChan chan core.KeyValueHolder + missingNodesChan chan []byte errChan chan error stats common.SnapshotStatisticsHandler epoch uint32 @@ -320,6 +322,7 @@ func (tsm *trieStorageManager) TakeSnapshot( rootHash []byte, mainTrieRootHash []byte, leavesChan 
chan core.KeyValueHolder, + missingNodesChan chan []byte, errChan chan error, stats common.SnapshotStatisticsHandler, epoch uint32, @@ -351,6 +354,7 @@ func (tsm *trieStorageManager) TakeSnapshot( mainTrieRootHash: mainTrieRootHash, errChan: errChan, leavesChan: leavesChan, + missingNodesChan: missingNodesChan, stats: stats, epoch: epoch, } @@ -370,6 +374,7 @@ func (tsm *trieStorageManager) SetCheckpoint( rootHash []byte, mainTrieRootHash []byte, leavesChan chan core.KeyValueHolder, + missingNodesChan chan []byte, errChan chan error, stats common.SnapshotStatisticsHandler, ) { @@ -398,6 +403,7 @@ func (tsm *trieStorageManager) SetCheckpoint( rootHash: rootHash, mainTrieRootHash: mainTrieRootHash, leavesChan: leavesChan, + missingNodesChan: missingNodesChan, errChan: errChan, stats: stats, } @@ -431,17 +437,6 @@ func (tsm *trieStorageManager) takeSnapshot(snapshotEntry *snapshotsQueueEntry, log.Trace("trie snapshot started", "rootHash", snapshotEntry.rootHash) - newRoot, err := newSnapshotNode(tsm, msh, hsh, snapshotEntry.rootHash) - if err != nil { - writeInChanNonBlocking(snapshotEntry.errChan, err) - treatSnapshotError(err, - "trie storage manager: newSnapshotNode takeSnapshot", - snapshotEntry.rootHash, - snapshotEntry.mainTrieRootHash, - ) - return - } - stsm, err := newSnapshotTrieStorageManager(tsm, snapshotEntry.epoch) if err != nil { writeInChanNonBlocking(snapshotEntry.errChan, err) @@ -452,7 +447,18 @@ func (tsm *trieStorageManager) takeSnapshot(snapshotEntry *snapshotsQueueEntry, return } - err = newRoot.commitSnapshot(stsm, snapshotEntry.leavesChan, ctx, snapshotEntry.stats, tsm.idleProvider) + newRoot, err := newSnapshotNode(stsm, msh, hsh, snapshotEntry.rootHash, snapshotEntry.missingNodesChan) + if err != nil { + writeInChanNonBlocking(snapshotEntry.errChan, err) + treatSnapshotError(err, + "trie storage manager: newSnapshotNode takeSnapshot", + snapshotEntry.rootHash, + snapshotEntry.mainTrieRootHash, + ) + return + } + + err = newRoot.commitSnapshot(stsm, snapshotEntry.leavesChan, snapshotEntry.missingNodesChan, ctx, snapshotEntry.stats, tsm.idleProvider) if err != nil { writeInChanNonBlocking(snapshotEntry.errChan, err) treatSnapshotError(err, @@ -479,7 +485,7 @@ func (tsm *trieStorageManager) takeCheckpoint(checkpointEntry *snapshotsQueueEnt log.Trace("trie checkpoint started", "rootHash", checkpointEntry.rootHash) - newRoot, err := newSnapshotNode(tsm, msh, hsh, checkpointEntry.rootHash) + newRoot, err := newSnapshotNode(tsm, msh, hsh, checkpointEntry.rootHash, checkpointEntry.missingNodesChan) if err != nil { writeInChanNonBlocking(checkpointEntry.errChan, err) treatSnapshotError(err, @@ -516,9 +522,13 @@ func newSnapshotNode( msh marshal.Marshalizer, hsh hashing.Hasher, rootHash []byte, + missingNodesCh chan []byte, ) (snapshotNode, error) { newRoot, err := getNodeFromDBAndDecode(rootHash, db, msh, hsh) if err != nil { + if strings.Contains(err.Error(), common.GetNodeFromDBErrorString) { + missingNodesCh <- rootHash + } return nil, err } diff --git a/trie/trieStorageManagerFactory_test.go b/trie/trieStorageManagerFactory_test.go index 204003a8964..f157522516f 100644 --- a/trie/trieStorageManagerFactory_test.go +++ b/trie/trieStorageManagerFactory_test.go @@ -81,7 +81,7 @@ func TestTrieStorageManager_SerialFuncShadowingCallsExpectedImpl(t *testing.T) { IsPruningEnabledCalled: func() bool { return true }, - TakeSnapshotCalled: func(_ []byte, _ []byte, _ chan core.KeyValueHolder, _ chan error, _ common.SnapshotStatisticsHandler, _ uint32) { + TakeSnapshotCalled: func(_ 
[]byte, _ []byte, _ chan core.KeyValueHolder, _ chan []byte, _ chan error, _ common.SnapshotStatisticsHandler, _ uint32) { assert.Fail(t, shouldNotHaveBeenCalledErr.Error()) }, GetLatestStorageEpochCalled: func() (uint32, error) { @@ -133,7 +133,7 @@ func TestTrieStorageManager_SerialFuncShadowingCallsExpectedImpl(t *testing.T) { assert.True(t, getCalled) leavesCh := make(chan core.KeyValueHolder) - tsm.SetCheckpoint(nil, nil, leavesCh, make(chan error, 1), &trieMock.MockStatistics{}) + tsm.SetCheckpoint(nil, nil, leavesCh, nil, make(chan error, 1), &trieMock.MockStatistics{}) select { case <-leavesCh: @@ -163,7 +163,7 @@ func testTsmWithoutSnapshot( _ = tsm.PutInEpochWithoutCache([]byte("hash"), []byte("val"), 0) leavesCh := make(chan core.KeyValueHolder) - tsm.TakeSnapshot(nil, nil, leavesCh, make(chan error, 1), &trieMock.MockStatistics{}, 10) + tsm.TakeSnapshot(nil, nil, leavesCh, nil, make(chan error, 1), &trieMock.MockStatistics{}, 10) select { case <-leavesCh: diff --git a/trie/trieStorageManagerInEpoch.go b/trie/trieStorageManagerInEpoch.go index c7e68280d74..b9bda19dfd1 100644 --- a/trie/trieStorageManagerInEpoch.go +++ b/trie/trieStorageManagerInEpoch.go @@ -59,7 +59,7 @@ func (tsmie *trieStorageManagerInEpoch) Get(key []byte) ([]byte, error) { epoch := tsmie.epoch - i val, err := tsmie.mainStorer.GetFromEpoch(key, epoch) - treatGetFromEpochError(err) + treatGetFromEpochError(err, epoch) if len(val) != 0 { return val, nil } @@ -68,15 +68,15 @@ func (tsmie *trieStorageManagerInEpoch) Get(key []byte) ([]byte, error) { return nil, ErrKeyNotFound } -func treatGetFromEpochError(err error) { +func treatGetFromEpochError(err error, epoch uint32) { if err == nil { return } if errors.IsClosingError(err) { - log.Debug("trieStorageManagerInEpoch closing err", "error", err.Error()) + log.Debug("trieStorageManagerInEpoch closing err", "error", err.Error(), "epoch", epoch) return } - log.Warn("trieStorageManagerInEpoch", "error", err.Error()) + log.Warn("trieStorageManagerInEpoch", "error", err.Error(), "epoch", epoch) } diff --git a/trie/trieStorageManagerInEpoch_test.go b/trie/trieStorageManagerInEpoch_test.go index 3e9514a4297..20ee0e2862a 100644 --- a/trie/trieStorageManagerInEpoch_test.go +++ b/trie/trieStorageManagerInEpoch_test.go @@ -4,7 +4,7 @@ import ( "strings" "testing" - "github.com/ElrondNetwork/elrond-go/storage/memorydb" + "github.com/ElrondNetwork/elrond-go/storage/database" "github.com/ElrondNetwork/elrond-go/testscommon" "github.com/ElrondNetwork/elrond-go/testscommon/trie" "github.com/stretchr/testify/assert" @@ -25,6 +25,7 @@ func TestNewTrieStorageManagerInEpochInvalidStorageManagerType(t *testing.T) { tsmie, err := newTrieStorageManagerInEpoch(trieStorage, 0) assert.Nil(t, tsmie) + assert.NotNil(t, err) assert.True(t, strings.Contains(err.Error(), "invalid storage manager, type is")) } @@ -32,10 +33,11 @@ func TestNewTrieStorageManagerInEpochInvalidStorerType(t *testing.T) { t.Parallel() _, trieStorage := newEmptyTrie() - trieStorage.mainStorer = memorydb.New() + trieStorage.mainStorer = database.NewMemDB() tsmie, err := newTrieStorageManagerInEpoch(trieStorage, 0) assert.Nil(t, tsmie) + assert.NotNil(t, err) assert.True(t, strings.Contains(err.Error(), "invalid storer, type is")) } diff --git a/trie/trieStorageManagerWithoutCheckpoints.go b/trie/trieStorageManagerWithoutCheckpoints.go index 55f5d07ae4b..95f3c285671 100644 --- a/trie/trieStorageManagerWithoutCheckpoints.go +++ b/trie/trieStorageManagerWithoutCheckpoints.go @@ -27,6 +27,7 @@ func (tsm 
*trieStorageManagerWithoutCheckpoints) SetCheckpoint( _ []byte, _ []byte, chLeaves chan core.KeyValueHolder, + _ chan []byte, _ chan error, stats common.SnapshotStatisticsHandler, ) { diff --git a/trie/trieStorageManagerWithoutCheckpoints_test.go b/trie/trieStorageManagerWithoutCheckpoints_test.go index ce1c95a7f25..670c77825e1 100644 --- a/trie/trieStorageManagerWithoutCheckpoints_test.go +++ b/trie/trieStorageManagerWithoutCheckpoints_test.go @@ -25,11 +25,11 @@ func TestTrieStorageManagerWithoutCheckpoints_SetCheckpoint(t *testing.T) { tsm, _ := trie.NewTrieStorageManager(getNewTrieStorageManagerArgs()) ts, _ := trie.NewTrieStorageManagerWithoutCheckpoints(tsm) - ts.SetCheckpoint([]byte("rootHash"), make([]byte, 0), nil, errChan, &trieMock.MockStatistics{}) + ts.SetCheckpoint([]byte("rootHash"), make([]byte, 0), nil, nil, errChan, &trieMock.MockStatistics{}) assert.Equal(t, uint32(0), ts.PruningBlockingOperations()) chLeaves := make(chan core.KeyValueHolder) - ts.SetCheckpoint([]byte("rootHash"), make([]byte, 0), chLeaves, errChan, &trieMock.MockStatistics{}) + ts.SetCheckpoint([]byte("rootHash"), make([]byte, 0), chLeaves, nil, errChan, &trieMock.MockStatistics{}) assert.Equal(t, uint32(0), ts.PruningBlockingOperations()) select { diff --git a/trie/trieStorageManagerWithoutSnapshot.go b/trie/trieStorageManagerWithoutSnapshot.go index bfa5d51e2b0..f8ce1b2e7fc 100644 --- a/trie/trieStorageManagerWithoutSnapshot.go +++ b/trie/trieStorageManagerWithoutSnapshot.go @@ -37,7 +37,7 @@ func (tsm *trieStorageManagerWithoutSnapshot) PutInEpochWithoutCache(key []byte, } // TakeSnapshot does nothing, as snapshots are disabled for this implementation -func (tsm *trieStorageManagerWithoutSnapshot) TakeSnapshot(_ []byte, _ []byte, leavesChan chan core.KeyValueHolder, _ chan error, stats common.SnapshotStatisticsHandler, _ uint32) { +func (tsm *trieStorageManagerWithoutSnapshot) TakeSnapshot(_ []byte, _ []byte, leavesChan chan core.KeyValueHolder, _ chan []byte, _ chan error, stats common.SnapshotStatisticsHandler, _ uint32) { safelyCloseChan(leavesChan) stats.SnapshotFinished() } diff --git a/trie/trieStorageManagerWithoutSnapshot_test.go b/trie/trieStorageManagerWithoutSnapshot_test.go index 9cfc3c2d9dc..bb8312d5469 100644 --- a/trie/trieStorageManagerWithoutSnapshot_test.go +++ b/trie/trieStorageManagerWithoutSnapshot_test.go @@ -79,7 +79,7 @@ func TestTrieStorageManagerWithoutSnapshot_TakeSnapshot(t *testing.T) { errChan := make(chan error, 1) leavesCh := make(chan core.KeyValueHolder) - ts.TakeSnapshot(nil, nil, leavesCh, errChan, &trieMock.MockStatistics{}, 10) + ts.TakeSnapshot(nil, nil, leavesCh, nil, errChan, &trieMock.MockStatistics{}, 10) select { case <-leavesCh: diff --git a/trie/trieStorageManager_test.go b/trie/trieStorageManager_test.go index 844b6742eb7..debad895362 100644 --- a/trie/trieStorageManager_test.go +++ b/trie/trieStorageManager_test.go @@ -91,7 +91,7 @@ func TestTrieCheckpoint(t *testing.T) { errChan := make(chan error, 1) trieStorage.AddDirtyCheckpointHashes(rootHash, dirtyHashes) - trieStorage.SetCheckpoint(rootHash, []byte{}, nil, errChan, &trieMock.MockStatistics{}) + trieStorage.SetCheckpoint(rootHash, []byte{}, nil, nil, errChan, &trieMock.MockStatistics{}) trie.WaitForOperationToComplete(trieStorage) val, err = trieStorage.GetFromCheckpoint(rootHash) @@ -108,7 +108,7 @@ func TestTrieStorageManager_SetCheckpointNilErrorChan(t *testing.T) { rootHash := []byte("rootHash") leavesChan := make(chan core.KeyValueHolder) - ts.SetCheckpoint(rootHash, rootHash, leavesChan, 
nil, &trieMock.MockStatistics{}) + ts.SetCheckpoint(rootHash, rootHash, leavesChan, nil, nil, &trieMock.MockStatistics{}) _, ok := <-leavesChan assert.False(t, ok) @@ -126,7 +126,7 @@ func TestTrieStorageManager_SetCheckpointClosedDb(t *testing.T) { rootHash := []byte("rootHash") leavesChan := make(chan core.KeyValueHolder) errChan := make(chan error, 1) - ts.SetCheckpoint(rootHash, rootHash, leavesChan, errChan, &trieMock.MockStatistics{}) + ts.SetCheckpoint(rootHash, rootHash, leavesChan, nil, errChan, &trieMock.MockStatistics{}) _, ok := <-leavesChan assert.False(t, ok) @@ -142,7 +142,7 @@ func TestTrieStorageManager_SetCheckpointEmptyTrieRootHash(t *testing.T) { rootHash := make([]byte, 32) leavesChan := make(chan core.KeyValueHolder) errChan := make(chan error, 1) - ts.SetCheckpoint(rootHash, rootHash, leavesChan, errChan, &trieMock.MockStatistics{}) + ts.SetCheckpoint(rootHash, rootHash, leavesChan, nil, errChan, &trieMock.MockStatistics{}) _, ok := <-leavesChan assert.False(t, ok) @@ -160,7 +160,7 @@ func TestTrieCheckpoint_DoesNotSaveToCheckpointStorageIfNotDirty(t *testing.T) { assert.Nil(t, val) errChan := make(chan error, 1) - trieStorage.SetCheckpoint(rootHash, []byte{}, nil, errChan, &trieMock.MockStatistics{}) + trieStorage.SetCheckpoint(rootHash, []byte{}, nil, nil, errChan, &trieMock.MockStatistics{}) trie.WaitForOperationToComplete(trieStorage) val, err = trieStorage.GetFromCheckpoint(rootHash) @@ -311,7 +311,7 @@ func TestTrieStorageManager_TakeSnapshotNilErrorChan(t *testing.T) { rootHash := []byte("rootHash") leavesChan := make(chan core.KeyValueHolder) - ts.TakeSnapshot(rootHash, rootHash, leavesChan, nil, &trieMock.MockStatistics{}, 0) + ts.TakeSnapshot(rootHash, rootHash, leavesChan, nil, nil, &trieMock.MockStatistics{}, 0) _, ok := <-leavesChan assert.False(t, ok) @@ -329,7 +329,7 @@ func TestTrieStorageManager_TakeSnapshotClosedDb(t *testing.T) { rootHash := []byte("rootHash") leavesChan := make(chan core.KeyValueHolder) errChan := make(chan error, 1) - ts.TakeSnapshot(rootHash, rootHash, leavesChan, errChan, &trieMock.MockStatistics{}, 0) + ts.TakeSnapshot(rootHash, rootHash, leavesChan, nil, errChan, &trieMock.MockStatistics{}, 0) _, ok := <-leavesChan assert.False(t, ok) @@ -345,7 +345,7 @@ func TestTrieStorageManager_TakeSnapshotEmptyTrieRootHash(t *testing.T) { rootHash := make([]byte, 32) leavesChan := make(chan core.KeyValueHolder) errChan := make(chan error, 1) - ts.TakeSnapshot(rootHash, rootHash, leavesChan, errChan, &trieMock.MockStatistics{}, 0) + ts.TakeSnapshot(rootHash, rootHash, leavesChan, nil, errChan, &trieMock.MockStatistics{}, 0) _, ok := <-leavesChan assert.False(t, ok) @@ -356,12 +356,14 @@ func TestTrieStorageManager_TakeSnapshotWithGetNodeFromDBError(t *testing.T) { t.Parallel() args := getNewTrieStorageManagerArgs() + args.MainStorer = testscommon.NewSnapshotPruningStorerMock() ts, _ := trie.NewTrieStorageManager(args) rootHash := []byte("rootHash") leavesChan := make(chan core.KeyValueHolder) errChan := make(chan error, 1) - ts.TakeSnapshot(rootHash, rootHash, leavesChan, errChan, &trieMock.MockStatistics{}, 0) + missingNodesChan := make(chan []byte, 2) + ts.TakeSnapshot(rootHash, rootHash, leavesChan, missingNodesChan, errChan, &trieMock.MockStatistics{}, 0) _, ok := <-leavesChan assert.False(t, ok) diff --git a/update/factory/accountDBSyncerContainerFactory.go b/update/factory/accountDBSyncerContainerFactory.go index ae82871ead7..1e84e9312a5 100644 --- a/update/factory/accountDBSyncerContainerFactory.go +++ 
b/update/factory/accountDBSyncerContainerFactory.go @@ -13,6 +13,7 @@ import ( "github.com/ElrondNetwork/elrond-go/state/syncer" "github.com/ElrondNetwork/elrond-go/storage" "github.com/ElrondNetwork/elrond-go/trie" + "github.com/ElrondNetwork/elrond-go/trie/storageMarker" "github.com/ElrondNetwork/elrond-go/update" containers "github.com/ElrondNetwork/elrond-go/update/container" "github.com/ElrondNetwork/elrond-go/update/genesis" @@ -147,6 +148,7 @@ func (a *accountDBSyncersContainerFactory) createUserAccountsSyncer(shardId uint MaxHardCapForMissingNodes: a.maxHardCapForMissingNodes, TrieSyncerVersion: a.trieSyncerVersion, CheckNodesOnDisk: a.checkNodesOnDisk, + StorageMarker: storageMarker.NewTrieStorageMarker(), }, ShardId: shardId, Throttler: thr, @@ -174,6 +176,7 @@ func (a *accountDBSyncersContainerFactory) createValidatorAccountsSyncer(shardId MaxHardCapForMissingNodes: a.maxHardCapForMissingNodes, TrieSyncerVersion: a.trieSyncerVersion, CheckNodesOnDisk: a.checkNodesOnDisk, + StorageMarker: storageMarker.NewTrieStorageMarker(), }, } accountSyncer, err := syncer.NewValidatorAccountsSyncer(args) diff --git a/update/factory/dataTrieFactory.go b/update/factory/dataTrieFactory.go index 4e61e63423a..06c156b9ae6 100644 --- a/update/factory/dataTrieFactory.go +++ b/update/factory/dataTrieFactory.go @@ -12,9 +12,9 @@ import ( "github.com/ElrondNetwork/elrond-go/config" "github.com/ElrondNetwork/elrond-go/sharding" "github.com/ElrondNetwork/elrond-go/state" + "github.com/ElrondNetwork/elrond-go/storage/database" storageFactory "github.com/ElrondNetwork/elrond-go/storage/factory" - "github.com/ElrondNetwork/elrond-go/storage/memorydb" - "github.com/ElrondNetwork/elrond-go/storage/storageUnit" + "github.com/ElrondNetwork/elrond-go/storage/storageunit" "github.com/ElrondNetwork/elrond-go/trie" "github.com/ElrondNetwork/elrond-go/trie/hashesHolder/disabled" "github.com/ElrondNetwork/elrond-go/update" @@ -56,7 +56,7 @@ func NewDataTrieFactory(args ArgsNewDataTrieFactory) (*dataTrieFactory, error) { dbConfig := storageFactory.GetDBFromConfig(args.StorageConfig.DB) dbConfig.FilePath = path.Join(args.SyncFolder, args.StorageConfig.DB.FilePath) - accountsTrieStorage, err := storageUnit.NewStorageUnitFromConf( + accountsTrieStorage, err := storageunit.NewStorageUnitFromConf( storageFactory.GetCacherFromConfig(args.StorageConfig.Cache), dbConfig, ) @@ -65,7 +65,7 @@ func NewDataTrieFactory(args ArgsNewDataTrieFactory) (*dataTrieFactory, error) { } tsmArgs := trie.NewTrieStorageManagerArgs{ MainStorer: accountsTrieStorage, - CheckpointsStorer: memorydb.New(), + CheckpointsStorer: database.NewMemDB(), Marshalizer: args.Marshalizer, Hasher: args.Hasher, GeneralConfig: config.TrieStorageManagerConfig{ diff --git a/update/factory/exportHandlerFactory.go b/update/factory/exportHandlerFactory.go index cec8db108a2..8ccaca368ba 100644 --- a/update/factory/exportHandlerFactory.go +++ b/update/factory/exportHandlerFactory.go @@ -23,9 +23,9 @@ import ( "github.com/ElrondNetwork/elrond-go/sharding/nodesCoordinator" "github.com/ElrondNetwork/elrond-go/state" "github.com/ElrondNetwork/elrond-go/storage" + "github.com/ElrondNetwork/elrond-go/storage/cache" storageFactory "github.com/ElrondNetwork/elrond-go/storage/factory" - "github.com/ElrondNetwork/elrond-go/storage/storageUnit" - "github.com/ElrondNetwork/elrond-go/storage/timecache" + "github.com/ElrondNetwork/elrond-go/storage/storageunit" "github.com/ElrondNetwork/elrond-go/trie" "github.com/ElrondNetwork/elrond-go/update" 
"github.com/ElrondNetwork/elrond-go/update/genesis" @@ -157,7 +157,11 @@ func NewExportHandlerFactory(args ArgsExporter) (*exportHandlerFactory, error) { if check.IfNil(args.ExistingResolvers) { return nil, update.ErrNilResolverContainer } - if check.IfNil(args.CryptoComponents.MultiSigner()) { + multiSigner, err := args.CryptoComponents.GetMultiSigner(0) + if err != nil { + return nil, err + } + if check.IfNil(multiSigner) { return nil, update.ErrNilMultiSigner } if check.IfNil(args.NodesCoordinator) { @@ -214,7 +218,7 @@ func NewExportHandlerFactory(args ArgsExporter) (*exportHandlerFactory, error) { if args.NumConcurrentTrieSyncers < 1 { return nil, update.ErrInvalidNumConcurrentTrieSyncers } - err := trie.CheckTrieSyncerVersion(args.TrieSyncerVersion) + err = trie.CheckTrieSyncerVersion(args.TrieSyncerVersion) if err != nil { return nil, err } @@ -531,7 +535,7 @@ func (e *exportHandlerFactory) createInterceptors() error { DataPool: e.dataPool, MaxTxNonceDeltaAllowed: math.MaxInt32, TxFeeHandler: &disabled.FeeHandler{}, - BlockBlackList: timecache.NewTimeCache(time.Second), + BlockBlackList: cache.NewTimeCache(time.Second), HeaderSigVerifier: e.headerSigVerifier, HeaderIntegrityVerifier: e.headerIntegrityVerifier, SizeCheckDelta: math.MaxUint32, @@ -559,7 +563,7 @@ func (e *exportHandlerFactory) createInterceptors() error { func createStorer(storageConfig config.StorageConfig, folder string) (storage.Storer, error) { dbConfig := storageFactory.GetDBFromConfig(storageConfig.DB) dbConfig.FilePath = path.Join(folder, storageConfig.DB.FilePath) - accountsTrieStorage, err := storageUnit.NewStorageUnitFromConf( + accountsTrieStorage, err := storageunit.NewStorageUnitFromConf( storageFactory.GetCacherFromConfig(storageConfig.Cache), dbConfig, ) diff --git a/update/factory/fullSyncInterceptors.go b/update/factory/fullSyncInterceptors.go index 42b00245b60..0d492a38516 100644 --- a/update/factory/fullSyncInterceptors.go +++ b/update/factory/fullSyncInterceptors.go @@ -246,7 +246,11 @@ func checkBaseParams( if check.IfNil(cryptoComponents.BlockSigner()) { return process.ErrNilSingleSigner } - if check.IfNil(cryptoComponents.MultiSigner()) { + multiSigner, err := cryptoComponents.GetMultiSigner(0) + if err != nil { + return err + } + if check.IfNil(multiSigner) { return process.ErrNilMultiSigVerifier } if check.IfNil(shardCoordinator) { diff --git a/update/genesis/export.go b/update/genesis/export.go index ec3fde205eb..2fa9433f019 100644 --- a/update/genesis/export.go +++ b/update/genesis/export.go @@ -20,6 +20,7 @@ import ( "github.com/ElrondNetwork/elrond-go/sharding" "github.com/ElrondNetwork/elrond-go/sharding/nodesCoordinator" "github.com/ElrondNetwork/elrond-go/state" + "github.com/ElrondNetwork/elrond-go/trie/keyBuilder" "github.com/ElrondNetwork/elrond-go/update" ) @@ -293,7 +294,7 @@ func (se *stateExport) exportTrie(key string, trie common.Trie) error { } leavesChannel := make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity) - err = trie.GetAllLeavesOnChannel(leavesChannel, context.Background(), rootHash) + err = trie.GetAllLeavesOnChannel(leavesChannel, context.Background(), rootHash, keyBuilder.NewKeyBuilder()) if err != nil { return err } diff --git a/update/genesis/export_test.go b/update/genesis/export_test.go index 2f903fce291..bbe7320042a 100644 --- a/update/genesis/export_test.go +++ b/update/genesis/export_test.go @@ -302,7 +302,7 @@ func TestStateExport_ExportTrieShouldExportNodesSetupJson(t *testing.T) { RootCalled: func() ([]byte, error) { return []byte{}, 
nil }, - GetAllLeavesOnChannelCalled: func(ch chan core.KeyValueHolder, ctx context.Context, rootHash []byte) error { + GetAllLeavesOnChannelCalled: func(ch chan core.KeyValueHolder, ctx context.Context, rootHash []byte, keyBuilder common.KeyBuilder) error { mm := &mock.MarshalizerMock{} valInfo := &state.ValidatorInfo{List: string(common.EligibleList)} pacB, _ := mm.Marshal(valInfo) diff --git a/update/process/shardBlock_test.go b/update/process/shardBlock_test.go index 885102f165d..50055faaf1d 100644 --- a/update/process/shardBlock_test.go +++ b/update/process/shardBlock_test.go @@ -11,8 +11,8 @@ import ( "github.com/ElrondNetwork/elrond-go-core/data/transaction" "github.com/ElrondNetwork/elrond-go/dataRetriever" "github.com/ElrondNetwork/elrond-go/storage" - "github.com/ElrondNetwork/elrond-go/storage/memorydb" - "github.com/ElrondNetwork/elrond-go/storage/storageUnit" + "github.com/ElrondNetwork/elrond-go/storage/database" + "github.com/ElrondNetwork/elrond-go/storage/storageunit" "github.com/ElrondNetwork/elrond-go/testscommon" "github.com/ElrondNetwork/elrond-go/testscommon/hashingMocks" "github.com/ElrondNetwork/elrond-go/update" @@ -22,14 +22,14 @@ import ( ) func generateTestCache() storage.Cacher { - cache, _ := storageUnit.NewCache(storageUnit.CacheConfig{Type: storageUnit.LRUCache, Capacity: 1000, Shards: 1, SizeInBytes: 0}) + cache, _ := storageunit.NewCache(storageunit.CacheConfig{Type: storageunit.LRUCache, Capacity: 1000, Shards: 1, SizeInBytes: 0}) return cache } func generateTestUnit() storage.Storer { - storer, _ := storageUnit.NewStorageUnit( + storer, _ := storageunit.NewStorageUnit( generateTestCache(), - memorydb.New(), + database.NewMemDB(), ) return storer diff --git a/update/sync/syncHeaders_test.go b/update/sync/syncHeaders_test.go index 3bb36e98209..a3f61502930 100644 --- a/update/sync/syncHeaders_test.go +++ b/update/sync/syncHeaders_test.go @@ -12,8 +12,8 @@ import ( "github.com/ElrondNetwork/elrond-go/dataRetriever/dataPool/headersCache" "github.com/ElrondNetwork/elrond-go/process" "github.com/ElrondNetwork/elrond-go/storage" - "github.com/ElrondNetwork/elrond-go/storage/memorydb" - "github.com/ElrondNetwork/elrond-go/storage/storageUnit" + "github.com/ElrondNetwork/elrond-go/storage/database" + "github.com/ElrondNetwork/elrond-go/storage/storageunit" "github.com/ElrondNetwork/elrond-go/testscommon" "github.com/ElrondNetwork/elrond-go/testscommon/hashingMocks" storageStubs "github.com/ElrondNetwork/elrond-go/testscommon/storage" @@ -36,14 +36,14 @@ func createMockHeadersSyncHandlerArgs() ArgsNewHeadersSyncHandler { } func generateTestCache() storage.Cacher { - cache, _ := storageUnit.NewCache(storageUnit.CacheConfig{Type: storageUnit.LRUCache, Capacity: 1000, Shards: 1, SizeInBytes: 0}) + cache, _ := storageunit.NewCache(storageunit.CacheConfig{Type: storageunit.LRUCache, Capacity: 1000, Shards: 1, SizeInBytes: 0}) return cache } func generateTestUnit() storage.Storer { - storer, _ := storageUnit.NewStorageUnit( + storer, _ := storageunit.NewStorageUnit( generateTestCache(), - memorydb.New(), + database.NewMemDB(), ) return storer @@ -278,5 +278,4 @@ func TestSyncEpochStartMetaHeader_ReceiveHeaderOk(t *testing.T) { metaBlockSync, err := headersSyncHandler.GetEpochStartMetaBlock() require.Nil(t, err) require.Equal(t, meta, metaBlockSync) - } diff --git a/update/sync/syncTransactions_test.go b/update/sync/syncTransactions_test.go index 6adbe65ac71..da155fc8639 100644 --- a/update/sync/syncTransactions_test.go +++ b/update/sync/syncTransactions_test.go @@ -108,6 
+108,7 @@ func testWithMissingStorer(missingUnit dataRetriever.UnitType) func(t *testing.T } pendingTxsSyncer, err := NewTransactionsSyncer(args) + require.NotNil(t, err) require.True(t, strings.Contains(err.Error(), storage.ErrKeyNotFound.Error())) require.True(t, strings.Contains(err.Error(), missingUnit.String())) require.True(t, check.IfNil(pendingTxsSyncer))
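A note on the KeyBuilder change threaded through this patch: branchNode.getAllLeavesOnChannel now clones the builder before descending into each child so sibling subtrees never share key state, while extensionNode extends the key and passes a clone downward. Below is a minimal self-contained sketch of that clone-then-extend pattern; the type is a simplified stand-in for common.KeyBuilder, not the repository's implementation, and it deliberately skips the hex-to-trie-key conversion.

package main

import "fmt"

// keyBuilder is a simplified stand-in for the patch's common.KeyBuilder
// contract (BuildKey / Clone); the nibble-to-byte step is omitted here.
type keyBuilder struct {
	key []byte
}

// BuildKey appends a key part to the accumulated path.
func (kb *keyBuilder) BuildKey(part []byte) {
	kb.key = append(kb.key, part...)
}

// Clone copies the accumulated key so siblings cannot alias each other.
// (The patch's Clone reuses the backing slice; copying is the defensive
// variant of the same idea.)
func (kb *keyBuilder) Clone() *keyBuilder {
	cloned := make([]byte, len(kb.key), len(kb.key)+1)
	copy(cloned, kb.key)
	return &keyBuilder{key: cloned}
}

func main() {
	root := &keyBuilder{}
	root.BuildKey([]byte{0x0a}) // path accumulated so far

	// A branch node hands each child its own clone extended with the child
	// index, mirroring clonedKeyBuilder.BuildKey([]byte{byte(i)}) above.
	for i := byte(0); i < 3; i++ {
		child := root.Clone()
		child.BuildKey([]byte{i})
		fmt.Printf("child %d key: %x\n", i, child.key)
	}
}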
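The new hexToTrieKeyBytes in trie/keyBuilder/keyBuilder.go is the old hexToKeyBytes moved out of trie/node.go: drop the hex terminator, then fold pairs of nibbles back into bytes in reverse order. A standalone round-trip sketch follows; the terminator value 16 is an assumption about what keyBytesToHex appends, while NibbleSize = 4 comes from the patch itself.

package main

import (
	"errors"
	"fmt"
)

const (
	nibbleSize    = 4  // keyBuilder.NibbleSize in the patch
	hexTerminator = 16 // assumed terminator nibble appended by keyBytesToHex
)

var errInvalidLength = errors.New("invalid array length")

// hexToTrieKey mirrors hexToTrieKeyBytes: strip the terminator, then pack
// two nibbles per byte, filling the output from its last byte backwards.
func hexToTrieKey(hex []byte) ([]byte, error) {
	hex = hex[:len(hex)-1]
	if len(hex)%2 != 0 {
		return nil, errInvalidLength
	}

	key := make([]byte, len(hex)/2)
	hexSliceIndex := 0
	for i := len(key) - 1; i >= 0; i-- {
		key[i] = hex[hexSliceIndex+1]<<nibbleSize | hex[hexSliceIndex]
		hexSliceIndex += 2
	}
	return key, nil
}

func main() {
	// keyBytesToHex(0xab) yields {0x0b, 0x0a, terminator}; packing the
	// nibbles back must return the original byte.
	key, err := hexToTrieKey([]byte{0x0b, 0x0a, hexTerminator})
	fmt.Printf("key=%x err=%v\n", key, err) // prints key=ab err=<nil>
}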
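Callers of GetAllLeavesOnChannel now choose a key builder explicitly: update/genesis/export.go passes keyBuilder.NewKeyBuilder() because it needs full leaf keys, while keyBuilder.NewDisabledKeyBuilder() suits callers that only consume values. The consumer side of the channel is unchanged; here is a sketch of the usual drain loop, with a simplified stand-in for core.KeyValueHolder.

package main

import (
	"context"
	"fmt"
)

// keyValue is a simplified core.KeyValueHolder for this sketch.
type keyValue struct {
	key   []byte
	value []byte
}

// drainLeaves consumes what the trie traversal pushes on the channel; the
// producer closes the channel when the traversal finishes, so ranging until
// close (or context cancellation) is sufficient.
func drainLeaves(ctx context.Context, leaves <-chan keyValue) {
	for {
		select {
		case leaf, ok := <-leaves:
			if !ok {
				return
			}
			fmt.Printf("leaf %x -> %s\n", leaf.key, leaf.value)
		case <-ctx.Done():
			return
		}
	}
}

func main() {
	leaves := make(chan keyValue, 2)

	// In the real flow, trie.GetAllLeavesOnChannel(leaves, ctx, rootHash,
	// keyBuilder.NewKeyBuilder()) would feed this channel from a goroutine.
	leaves <- keyValue{key: []byte{0xab}, value: []byte("account state")}
	close(leaves)

	drainLeaves(context.Background(), leaves)
}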
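The commitSnapshot changes in branchNode.go and extensionNode.go turn a missing child into a recoverable event: the error is logged, the child's encoded hash goes to missingNodesChan, and the walk continues instead of aborting the whole snapshot. A condensed sketch of that control flow, with toy node types; the error-string matching mirrors the patch's strings.Contains check on common.GetNodeFromDBErrorString.

package main

import (
	"errors"
	"fmt"
	"strings"
)

// getNodeFromDBError stands in for common.GetNodeFromDBErrorString.
const getNodeFromDBError = "getNodeFromDB error"

type childNode struct {
	encodedHash []byte
	missing     bool
}

func resolveIfCollapsed(c childNode) error {
	if c.missing {
		return fmt.Errorf("%s for key %x", getNodeFromDBError, c.encodedHash)
	}
	return nil
}

// commitSnapshot walks the children of a branch; a missing node is reported
// on missingNodesChan and skipped, while any other error still aborts.
func commitSnapshot(children []childNode, missingNodesChan chan<- []byte) error {
	for _, c := range children {
		err := resolveIfCollapsed(c)
		if err != nil {
			if strings.Contains(err.Error(), getNodeFromDBError) {
				missingNodesChan <- c.encodedHash
				continue
			}
			return err
		}
		// the real implementation recurses into the resolved child here
	}
	return nil
}

func main() {
	missing := make(chan []byte, 4)
	children := []childNode{
		{encodedHash: []byte{0x01}},
		{encodedHash: []byte{0x02}, missing: true},
		{encodedHash: []byte{0x03}},
	}

	if err := commitSnapshot(children, missing); err != nil && !errors.Is(err, nil) {
		fmt.Println("snapshot error:", err)
	}
	close(missing)

	for hash := range missing {
		fmt.Printf("missing node to re-request: %x\n", hash)
	}
}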
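TakeSnapshot and SetCheckpoint grew a missingNodesChan parameter, and the tests pass a buffered chan []byte. Whoever triggers the snapshot is expected to drain that channel and hand the hashes to a syncer; the wiring below is an assumption for illustration, not the repository's actual consumer.

package main

import (
	"fmt"
	"sync"
)

// drainMissingNodes collects every hash the snapshot goroutine reports; in a
// real node these hashes would then be requested from the network and synced.
func drainMissingNodes(missingNodesChan <-chan []byte, wg *sync.WaitGroup, out *[][]byte) {
	defer wg.Done()
	for hash := range missingNodesChan {
		*out = append(*out, hash)
	}
}

func main() {
	// Buffered, like the make(chan []byte, 2) used in
	// TestTrieStorageManager_TakeSnapshotWithGetNodeFromDBError.
	missingNodesChan := make(chan []byte, 10)

	var (
		wg     sync.WaitGroup
		hashes [][]byte
	)
	wg.Add(1)
	go drainMissingNodes(missingNodesChan, &wg, &hashes)

	// ts.TakeSnapshot(rootHash, rootHash, leavesChan, missingNodesChan,
	// errChan, stats, epoch) would run here and push unresolved hashes.
	missingNodesChan <- []byte("root hash that could not be loaded")
	close(missingNodesChan)
	wg.Wait()

	fmt.Printf("%d hash(es) to re-sync: %s\n", len(hashes), hashes[0])
}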
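GetFromOldEpochsWithoutAddingToCache now also reports, as a core.OptionalUint32, the epoch in which the value was found, and the snapshot storage manager uses that to copy old values forward into the previous epoch so the next snapshot finds them locally. The guard order in this sketch is inferred from the subtests in snapshotTrieStorageManager_test.go rather than lifted verbatim from the implementation, and the marker-key strings are illustrative stand-ins for common.ActiveDBKey and common.TrieSyncedKey.

package main

import "fmt"

// optionalUint32 mirrors core.OptionalUint32.
type optionalUint32 struct {
	Value    uint32
	HasValue bool
}

// shouldPutInPreviousEpoch reproduces the behaviour the tests pin down:
// skip when the epoch is unknown, when the snapshot epoch is 0, when the
// value already sits in the previous epoch, and for the marker keys.
func shouldPutInPreviousEpoch(key string, foundIn optionalUint32, snapshotEpoch uint32) bool {
	if !foundIn.HasValue {
		return false // "HasValue is false" subtest
	}
	if snapshotEpoch == 0 {
		return false // "epoch is 0" subtest; also avoids uint32 underflow below
	}
	if foundIn.Value >= snapshotEpoch-1 {
		return false // "epoch is previous epoch" subtest
	}
	if key == "activeDB" || key == "trieSynced" {
		return false // ActiveDBKey / TrieSyncedKey subtests
	}
	return true // "add in previous epoch" subtest
}

func main() {
	epoch3 := optionalUint32{Value: 3, HasValue: true}
	epoch4 := optionalUint32{Value: 4, HasValue: true}

	fmt.Println(shouldPutInPreviousEpoch("key", optionalUint32{}, 5)) // false
	fmt.Println(shouldPutInPreviousEpoch("key", epoch4, 5))           // false
	fmt.Println(shouldPutInPreviousEpoch("activeDB", epoch3, 5))      // false
	fmt.Println(shouldPutInPreviousEpoch("key", epoch3, 5))           // true: copy into epoch 4
}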
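Finally, the StorageMarker extraction replaces TrieStub's removed MarkStorerAsSyncedAndActiveCalled hook: account syncers now receive a trie/storageMarker implementation, with the disabled variant as a no-op. The only subtle part is the epoch arithmetic checked by the tests, sketched below; note that the patch computes epoch - 1 before the zero check, which transiently underflows the uint32 even though the result is then overwritten, whereas checking first avoids that entirely.

package main

import "fmt"

// lastEpochForActiveDB returns the epoch that receives the ActiveDB marker:
// one behind the latest storage epoch, floored at zero, matching the
// epoch-5 -> 4 and epoch-0 -> 0 expectations in trieStorageMarker_test.go.
func lastEpochForActiveDB(latestStorageEpoch uint32) uint32 {
	if latestStorageEpoch == 0 {
		return 0 // checking before subtracting avoids uint32 underflow
	}
	return latestStorageEpoch - 1
}

func main() {
	fmt.Println(lastEpochForActiveDB(5)) // 4
	fmt.Println(lastEpochForActiveDB(0)) // 0
}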